├── README.md
├── components
│   ├── font.ttc
│   ├── sample_images
│   │   ├── 000.jpg … 199.jpg   (200 sample images)
│   │   └── sample.txt
│   └── vocab.json
├── loader
│   ├── __init__.py
│   ├── caffe_model.py
│   ├── caption_helper.py
│   ├── caption_model.py
│   ├── data_loader.py
│   ├── feature_loader.py
│   ├── model_loader.py
│   ├── vqa_data_loader.py
│   ├── vqa_model.py
│   └── vqa_resnet.py
├── result
│   └── pytorch_resnet18_places365
│       ├── decompose.npy
│       └── snapshot
│           └── 14.pth
├── script
│   ├── dlbroden.sh
│   └── dlzoo.sh
├── settings.py
├── test.py
├── train.py
├── util
│   ├── __init__.py
│   ├── clean.py
│   ├── experiments.py
│   ├── feature_decoder.py
│   ├── feature_operation.py
│   ├── image_operation.py
│   ├── imagenet_categories.py
│   ├── places365_categories.py
│   ├── upsample.py
│   └── vecquantile.py
└── visualize
    ├── __init__.py
    ├── html.py
    └── plot.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# IBD: Interpretable Basis Decomposition for Visual Explanation

## Introduction
This repository contains the demo code for the ECCV'18 paper "Interpretable Basis Decomposition for Visual Explanation".

## Download
* Clone the IBD code from GitHub
```
git clone https://github.com/CSAILVision/IBD
cd IBD
```
* Download the Broden dataset (~1 GB of disk space) and the example pretrained model. If you have already downloaded these, you can create a symbolic link to your existing copy.
```
./script/dlbroden.sh
./script/dlzoo.sh
```

Note that AlexNet models take 227x227 image input, while VGG, ResNet, and GoogLeNet take 224x224 image input.
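As a concrete illustration of that sizing note, here is a minimal, self-contained preprocessing sketch using torchvision. This is not the repo's own loading code (`loader/` and `settings.py` handle that internally), and the ImageNet mean/std values are a standard assumption, not taken from this repo:

```
from PIL import Image
from torchvision import transforms

# Illustrative only: crop to the input size the architecture expects
# (224 for VGG/ResNet/GoogLeNet, 227 for AlexNet).
input_size = 224  # use 227 for AlexNet

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    # Standard ImageNet statistics; assumed here, not read from settings.py.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

img = Image.open('components/sample_images/000.jpg').convert('RGB')
batch = preprocess(img).unsqueeze(0)  # shape: (1, 3, input_size, input_size)
```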
## Requirements

* Python environment

```
pip3 install numpy scikit-learn scipy scikit-image matplotlib easydict torch torchvision
```

Note: this repo was written for PyTorch 0.3.1 ([PyTorch](http://pytorch.org/), [Torchvision](https://github.com/pytorch/vision)).

## Run IBD in PyTorch

* You can configure `settings.py` to load your own model or to change the default parameters (an illustrative sketch of such a configuration follows this section).

* Run IBD

```
python3 test.py
```
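For the sketch mentioned above: the variable names below are hypothetical placeholders meant only to convey the kind of options a `settings.py` for this pipeline exposes; check the actual file for the real names and defaults.

```
# Hypothetical sketch -- these names are placeholders, not the
# actual variables in this repo's settings.py.
MODEL = 'pytorch_resnet18_places365'  # pretrained model to decompose
DATASET_PATH = 'dataset/broden1_224'  # Broden copy from dlbroden.sh
FEATURE_LAYER = 'layer4'              # layer whose features are explained
BATCH_SIZE = 128
GPU = True                            # set False to run on CPU
OUTPUT_FOLDER = 'result/' + MODEL     # where the HTML report is written
```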
## IBD Result

* At the end of the dissection script, an HTML-formatted report summarizing the interpretable units of the tested network will be generated inside the `result` folder.

## Train Concept Basis

* If you want to train the concept basis yourself, delete the pretrained files first.
```
rm result/pytorch_resnet18_places365/snapshot/14.pth
rm result/pytorch_resnet18_places365/decompose.npy
```

* Run the training script.

```
python3 train.py
```

* Then run IBD again.

```
python3 test.py
```
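Training regenerates `result/pytorch_resnet18_places365/decompose.npy`, which holds the learned concept basis. A quick sanity check after training might look like the following sketch; the array's exact layout is not documented here, so inspect it before relying on it:

```
import numpy as np

# Load the decomposition written by train.py. allow_pickle is a
# precaution in case the file stores an object array (an assumption).
basis = np.load('result/pytorch_resnet18_places365/decompose.npy',
                allow_pickle=True)
print(type(basis), getattr(basis, 'shape', None))
```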
## Reference
If you find the code useful, please cite this paper:
```
@inproceedings{IBD2018,
  title={Interpretable Basis Decomposition for Visual Explanation},
  author={Zhou, Bolei* and Sun, Yiyou* and Bau, David* and Torralba, Antonio},
  booktitle={European Conference on Computer Vision},
  year={2018}
}
```

--------------------------------------------------------------------------------
/components/ (binary assets)
--------------------------------------------------------------------------------
font.ttc and the images under components/sample_images/ are binary files; this dump represents each of them only by its raw URL pinned to commit 6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d, e.g.
https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/000.jpg
https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/184.jpg -------------------------------------------------------------------------------- /components/sample_images/185.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/185.jpg -------------------------------------------------------------------------------- /components/sample_images/186.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/186.jpg -------------------------------------------------------------------------------- /components/sample_images/187.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/187.jpg -------------------------------------------------------------------------------- /components/sample_images/188.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/188.jpg -------------------------------------------------------------------------------- /components/sample_images/189.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/189.jpg -------------------------------------------------------------------------------- /components/sample_images/190.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/190.jpg -------------------------------------------------------------------------------- /components/sample_images/191.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/191.jpg -------------------------------------------------------------------------------- /components/sample_images/192.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/192.jpg -------------------------------------------------------------------------------- /components/sample_images/193.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/193.jpg -------------------------------------------------------------------------------- /components/sample_images/194.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/194.jpg -------------------------------------------------------------------------------- /components/sample_images/195.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/195.jpg -------------------------------------------------------------------------------- /components/sample_images/196.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/196.jpg -------------------------------------------------------------------------------- /components/sample_images/197.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/197.jpg -------------------------------------------------------------------------------- /components/sample_images/198.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/198.jpg -------------------------------------------------------------------------------- /components/sample_images/199.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/components/sample_images/199.jpg -------------------------------------------------------------------------------- /components/sample_images/sample.txt: -------------------------------------------------------------------------------- 1 | 000.jpg 2 | 001.jpg 3 | 002.jpg 4 | 003.jpg 5 | 004.jpg 6 | 005.jpg 7 | 006.jpg 8 | 007.jpg 9 | 008.jpg 10 | 009.jpg 11 | 010.jpg 12 | 011.jpg 13 | 012.jpg 14 | 013.jpg 15 | 014.jpg 16 | 015.jpg 17 | 016.jpg 18 | 017.jpg 19 | 018.jpg 20 | 019.jpg 21 | 020.jpg -------------------------------------------------------------------------------- /loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/loader/__init__.py -------------------------------------------------------------------------------- /loader/caffe_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | import settings 7 | 8 | class FCView(nn.Module): 9 | def __init__(self): 10 | super(FCView, self).__init__() 11 | 12 | def forward(self, x): 13 | nB = x.data.size(0) 14 | x = x.view(nB,-1) 15 | return x 16 | def __repr__(self): 17 | return 'view(nB, -1)' 18 | 19 | class LRN(nn.Module): 20 | def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True): 21 | super(LRN, self).__init__() 22 | self.ACROSS_CHANNELS = ACROSS_CHANNELS 23 | if ACROSS_CHANNELS: 24 | self.average=nn.AvgPool3d(kernel_size=(int(local_size), 1, 1), 25 | stride=1,padding=(int((local_size-1.0)/2),0,0)) 26 | else: 27 | self.average=nn.AvgPool2d(kernel_size=int(local_size), 28 | stride=1, 29 | padding=int((local_size-1.0)/2)) 30 | self.alpha = alpha 31 | self.beta = beta 32 | 33 | 34 | def forward(self, x): 35 | if self.ACROSS_CHANNELS: 36 | div = x.pow(2).unsqueeze(1) 37 | div = self.average(div).squeeze(1) 38 | div = div.mul(self.alpha).add(1.0).pow(self.beta) 39 | else: 40 | div = x.pow(2) 41 | div = self.average(div) 42 | div = 
div.mul(self.alpha).add(1.0).pow(self.beta) 43 | x = x.div(div) 44 | return x 45 | 46 | class Eltwise(nn.Module): 47 | def __init__(self, operation='+'): 48 | super(Eltwise, self).__init__() 49 | self.operation = operation 50 | 51 | def forward(self, x1, x2): 52 | if self.operation == '+' or self.operation == 'SUM': 53 | x = x1 + x2 54 | if self.operation == '*' or self.operation == 'MUL': 55 | x = x1 * x2 56 | if self.operation == '/' or self.operation == 'DIV': 57 | x = x1 / x2 58 | return x 59 | 60 | class CaffeNetCAM(nn.Module): 61 | def __init__(self): 62 | super(CaffeNetCAM, self).__init__() 63 | self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4) 64 | self.relu1 = nn.ReLU(inplace=True) 65 | self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2) 66 | self.norm1 = LRN(local_size=5, alpha=0.0001, beta=0.75) 67 | 68 | self.conv2 = nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2) 69 | self.relu2 = nn.ReLU(inplace=True) 70 | self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2) 71 | self.norm2 = LRN(local_size=5, alpha=0.0001, beta=0.75) 72 | 73 | self.conv3 = nn.Conv2d(256, 384, kernel_size=3, padding=1) 74 | self.relu3 = nn.ReLU(inplace=True) 75 | self.conv4 = nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2) 76 | self.relu4 = nn.ReLU(inplace=True) 77 | self.conv5 = nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2) 78 | self.relu5 = nn.ReLU(inplace=True) 79 | 80 | self.CAM_conv = nn.Conv2d(256, 1024, kernel_size=3, padding=1, groups=2) 81 | self.CAM_relu = nn.ReLU(inplace=True) 82 | self.CAM_pool = nn.AvgPool2d(kernel_size=13, stride=13, padding=0, ceil_mode=False, count_include_pad=True) 83 | self.CAM_fc = nn.Sequential( 84 | FCView(), 85 | nn.Linear(1024, settings.NUM_CLASSES), 86 | ) 87 | self.prob = nn.Softmax() 88 | 89 | def forward(self, x): 90 | x = self.norm1(self.pool1(self.relu1(self.conv1(x)))) 91 | x = self.norm2(self.pool2(self.relu2(self.conv2(x)))) 92 | x = self.relu5(self.conv5(self.relu4(self.conv4(self.relu3(self.conv3(x)))))) 93 | x = self.prob(self.CAM_fc(self.CAM_pool(self.CAM_relu(self.CAM_conv(x))))) 94 | return x 95 | 96 | 97 | class VGG16CAM(nn.Module): 98 | def __init__(self): 99 | super(VGG16CAM, self).__init__() 100 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 101 | self.relu1_1 = nn.ReLU() 102 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 103 | self.relu1_2 = nn.ReLU() 104 | self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 105 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 106 | self.relu2_1 = nn.ReLU() 107 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 108 | self.relu2_2 = nn.ReLU() 109 | self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 110 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 111 | self.relu3_1 = nn.ReLU() 112 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 113 | self.relu3_2 = nn.ReLU() 114 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 115 | self.relu3_3 = nn.ReLU() 116 | self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 117 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 118 | self.relu4_1 = nn.ReLU() 119 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 120 | self.relu4_2 = nn.ReLU() 121 | 
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 122 | self.relu4_3 = nn.ReLU() 123 | self.pool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 124 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 125 | self.relu5_1 = nn.ReLU() 126 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 127 | self.relu5_2 = nn.ReLU() 128 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 129 | self.relu5_3 = nn.ReLU() 130 | self.CAM_conv = nn.Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2) 131 | self.CAM_relu = nn.ReLU() 132 | self.CAM_pool = nn.AvgPool2d(kernel_size=14, stride=14, padding=0, ceil_mode=False, count_include_pad=True) 133 | self.CAM_fc = nn.Sequential( 134 | FCView(), 135 | nn.Linear(in_features=1024, out_features=365) 136 | ) 137 | self.prob = nn.Softmax() 138 | 139 | def forward(self, x): 140 | x = self.pool1(self.relu1_2(self.conv1_2(self.relu1_1(self.conv1_1(x))))) 141 | x = self.pool2(self.relu2_2(self.conv2_2(self.relu2_1(self.conv2_1(x))))) 142 | x = self.pool3(self.relu3_3(self.conv3_3(self.relu3_2(self.conv3_2(self.relu3_1(self.conv3_1(x))))))) 143 | x = self.pool4(self.relu4_3(self.conv4_3(self.relu4_2(self.conv4_2(self.relu4_1(self.conv4_1(x))))))) 144 | x =self.relu5_3(self.conv5_3(self.relu5_2(self.conv5_2(self.relu5_1(self.conv5_1(x)))))) 145 | x = self.prob(self.CAM_fc(self.CAM_pool(self.CAM_relu(self.CAM_conv(x))))) 146 | return x 147 | 148 | 149 | class VGG16(nn.Module): 150 | def __init__(self): 151 | super(VGG16, self).__init__() 152 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 153 | self.relu1_1 = nn.ReLU() 154 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 155 | self.relu1_2 = nn.ReLU() 156 | self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 157 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 158 | self.relu2_1 = nn.ReLU() 159 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 160 | self.relu2_2 = nn.ReLU() 161 | self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 162 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 163 | self.relu3_1 = nn.ReLU() 164 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 165 | self.relu3_2 = nn.ReLU() 166 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 167 | self.relu3_3 = nn.ReLU() 168 | self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 169 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 170 | self.relu4_1 = nn.ReLU() 171 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 172 | self.relu4_2 = nn.ReLU() 173 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 174 | self.relu4_3 = nn.ReLU() 175 | self.pool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 176 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 177 | self.relu5_1 = nn.ReLU() 178 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 179 | self.relu5_2 = nn.ReLU() 180 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1)) 181 | self.relu5_3 = nn.ReLU() 182 | self.pool5 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 183 | 184 | self.fc6 = nn.Sequential( 185 | FCView(), 186 | nn.Linear(25088, 4096), 187 | ) 188 | self.relu6 = nn.ReLU(inplace=True) 189 | self.fc7 = nn.Linear(in_features=4096, out_features=4096) 190 | self.relu7 = nn.ReLU(inplace=True) 191 | self.fc8a = nn.Linear(in_features=4096, out_features=365) 192 | self.prob = nn.Softmax() 193 | 194 | def forward(self, x): 195 | x = self.pool1(self.relu1_2(self.conv1_2(self.relu1_1(self.conv1_1(x))))) 196 | x = self.pool2(self.relu2_2(self.conv2_2(self.relu2_1(self.conv2_1(x))))) 197 | x = self.pool3(self.relu3_3(self.conv3_3(self.relu3_2(self.conv3_2(self.relu3_1(self.conv3_1(x))))))) 198 | x = self.pool4(self.relu4_3(self.conv4_3(self.relu4_2(self.conv4_2(self.relu4_1(self.conv4_1(x))))))) 199 | x = self.relu5_3(self.conv5_3(self.relu5_2(self.conv5_2(self.relu5_1(self.conv5_1(x)))))) 200 | x = self.fc8a(self.relu7(self.fc7(self.relu6(self.fc6(self.pool5(x)))))) 201 | 202 | return x 203 | 204 | 205 | class CaffeNet_David(nn.Module): 206 | def __init__(self, dropout=False, bn=False): 207 | super(CaffeNet_David, self).__init__() 208 | self.dropout = dropout 209 | self.bn = bn 210 | if self.bn: 211 | self.bn1 = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 212 | self.bn2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True) 213 | self.bn3 = nn.BatchNorm2d(384, eps=1e-05, momentum=0.9, affine=True) 214 | self.bn4 = nn.BatchNorm2d(384, eps=1e-05, momentum=0.9, affine=True) 215 | self.bn5 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True) 216 | self.bn6 = nn.BatchNorm2d(4096, eps=1e-05, momentum=0.9, affine=True) 217 | self.bn7 = nn.BatchNorm2d(4096, eps=1e-05, momentum=0.9, affine=True) 218 | self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4) 219 | self.relu1 = nn.ReLU(inplace=True) 220 | self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2) 221 | 222 | self.conv2 = nn.Conv2d(96, 256, kernel_size=5, padding=2) 223 | self.relu2 = nn.ReLU(inplace=True) 224 | self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2) 225 | 226 | self.conv3 = nn.Conv2d(256, 384, kernel_size=3, padding=1) 227 | self.relu3 = nn.ReLU(inplace=True) 228 | self.conv4 = nn.Conv2d(384, 384, kernel_size=3, padding=1) 229 | self.relu4 = nn.ReLU(inplace=True) 230 | self.conv5 = nn.Conv2d(384, 256, kernel_size=3, padding=1) 231 | self.relu5 = nn.ReLU(inplace=True) 232 | self.pool5 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1)) 233 | 234 | if dropout: 235 | self.drop1 = nn.Dropout() 236 | self.drop2 = nn.Dropout() 237 | self.fc6 = nn.Sequential( 238 | FCView(), 239 | nn.Linear(9216, 4096), 240 | ) 241 | self.relu6 = nn.ReLU(inplace=True) 242 | self.fc7 = nn.Linear(in_features=4096, out_features=4096) 243 | self.relu7 = nn.ReLU(inplace=True) 244 | self.fc8 = nn.Linear(in_features=4096, out_features=365) 245 | 246 | def forward(self, x): 247 | if self.bn: 248 | x = self.pool1(self.relu1(self.bn1(self.conv1(x)))) 249 | x = self.pool2(self.relu2(self.bn2(self.conv2(x)))) 250 | x = self.relu5(self.bn5(self.conv5(self.relu4(self.bn4(self.conv4(self.relu3(self.bn3(self.conv3(x))))))))) 251 | if self.dropout: 252 | x = self.fc8(self.relu7(self.bn7(self.fc7(self.drop2(self.relu6(self.bn6(self.fc6(self.drop1(self.pool5(x)))))))))) 253 | else: 254 | x = self.fc8(self.relu7(self.bn7(self.fc7(self.relu6(self.bn6(self.fc6(self.pool5(x)))))))) 255 | else: 256 | x = self.pool1(self.relu1(self.conv1(x))) 257 | x = 
self.pool2(self.relu2(self.conv2(x))) 258 | x = self.relu5(self.conv5(self.relu4(self.conv4(self.relu3(self.conv3(x)))))) 259 | if self.dropout: 260 | x = self.fc8(self.relu7(self.fc7(self.drop2(self.relu6(self.fc6(self.drop1(self.pool5(x)))))))) 261 | else: 262 | x = self.fc8(self.relu7(self.fc7(self.relu6(self.fc6(self.pool5(x)))))) 263 | return x 264 | 265 | def load_state_dict(self, state_dict, strict=True): 266 | """Copies parameters and buffers from :attr:`state_dict` into 267 | this module and its descendants. If :attr:`strict` is ``True`` then 268 | the keys of :attr:`state_dict` must exactly match the keys returned 269 | by this module's :func:`state_dict()` function. 270 | 271 | Arguments: 272 | state_dict (dict): A dict containing parameters and 273 | persistent buffers. 274 | strict (bool): Strictly enforce that the keys in :attr:`state_dict` 275 | match the keys returned by this module's `:func:`state_dict()` 276 | function. 277 | """ 278 | own_state = self.state_dict() 279 | for name, param in state_dict.items(): 280 | if name in own_state: 281 | if isinstance(param, nn.Parameter): 282 | # backwards compatibility for serialized parameters 283 | param = param.data 284 | try: 285 | own_state[name].copy_(param) 286 | except Exception: 287 | raise RuntimeError('While copying the parameter named {}, ' 288 | 'whose dimensions in the model are {} and ' 289 | 'whose dimensions in the checkpoint are {}.' 290 | .format(name, own_state[name].size(), param.size())) 291 | elif strict: 292 | raise KeyError('unexpected key "{}" in state_dict' 293 | .format(name)) 294 | if strict: 295 | missing = set(own_state.keys()) - set(state_dict.keys()) 296 | if len(missing) > 0: 297 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 298 | -------------------------------------------------------------------------------- /loader/caption_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Class for generating captions from an image-to-text model. 16 | Adapted from https://github.com/tensorflow/models/blob/master/im2txt/im2txt/inference_utils/caption_generator.py""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import torch 23 | from torch.autograd import Variable 24 | from torch.nn.functional import log_softmax 25 | import heapq 26 | import math 27 | 28 | 29 | class Caption(object): 30 | """Represents a complete or partial caption.""" 31 | 32 | def __init__(self, sentence, state, logprob, score, metadata=None): 33 | """Initializes the Caption. 34 | 35 | Args: 36 | sentence: List of word ids in the caption. 37 | state: Model state after generating the previous word. 38 | logprob: Log-probability of the caption. 
39 | score: Score of the caption. 40 | metadata: Optional metadata associated with the partial sentence. If not 41 | None, a list of strings with the same length as 'sentence'. 42 | """ 43 | self.sentence = sentence 44 | self.state = state 45 | self.logprob = logprob 46 | self.score = score 47 | self.metadata = metadata 48 | 49 | def __cmp__(self, other): 50 | """Compares Captions by score.""" 51 | assert isinstance(other, Caption) 52 | if self.score == other.score: 53 | return 0 54 | elif self.score < other.score: 55 | return -1 56 | else: 57 | return 1 58 | 59 | # For Python 3 compatibility (__cmp__ is deprecated). 60 | def __lt__(self, other): 61 | assert isinstance(other, Caption) 62 | return self.score < other.score 63 | 64 | # Also for Python 3 compatibility. 65 | def __eq__(self, other): 66 | assert isinstance(other, Caption) 67 | return self.score == other.score 68 | 69 | 70 | class CaptionGenerator(object): 71 | 72 | """Class to generate captions from an image-to-text model.""" 73 | 74 | def __init__(self, 75 | embedder, 76 | rnn, 77 | classifier, 78 | eos_id, 79 | beam_size=1, 80 | max_caption_length=20, 81 | length_normalization_factor=0.0): 82 | """Initializes the generator. 83 | 84 | Args: 85 | model: recurrent model, with inputs: (input, state) and outputs len(vocab) values 86 | beam_size: Beam size to use when generating captions. 87 | max_caption_length: The maximum caption length before stopping the search. 88 | length_normalization_factor: If != 0, a number x such that captions are 89 | scored by logprob/length^x, rather than logprob. This changes the 90 | relative scores of captions depending on their lengths. For example, if 91 | x > 0 then longer captions will be favored. 92 | """ 93 | self.embedder = embedder 94 | self.rnn = rnn 95 | self.classifier = classifier 96 | self.eos_id = eos_id 97 | self.beam_size = beam_size 98 | self.max_caption_length = max_caption_length 99 | self.length_normalization_factor = length_normalization_factor 100 | 101 | def beam_search(self, rnn_input, initial_state=None): 102 | """Runs beam search caption generation on a single image. 103 | 104 | Args: 105 | initial_state: An initial state for the recurrent model 106 | 107 | Returns: 108 | A list of Caption sorted by descending score. 109 | """ 110 | batch_size = rnn_input.data.shape[1] 111 | word_vars = [] 112 | def get_topk_words(embeddings, state): 113 | output, new_states = self.rnn(embeddings, state) 114 | output = self.classifier(output.squeeze(0)) 115 | logprobs = log_softmax(output) 116 | logprobs, words = logprobs.topk(self.beam_size, 1) 117 | word_vars.append(words) 118 | return words.data, logprobs.data, new_states 119 | 120 | partial_captions = TopN(self.beam_size) 121 | complete_captions = TopN(self.beam_size) 122 | 123 | words, logprobs, new_state = get_topk_words(rnn_input, initial_state) 124 | 125 | for k in range(self.beam_size): 126 | cap = Caption( 127 | sentence=[words[:, k]], 128 | state=new_state, 129 | logprob=logprobs[:, k], 130 | score=logprobs[:, k]) 131 | partial_captions.push(cap) 132 | 133 | # Run beam search. 
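# Each pass below expands every surviving partial caption with its beam_size best
# next words (up to beam_size**2 candidates), and the TopN heap keeps only the best
# beam_size of them. Note that nothing in this loop ever pushes a finished caption
# into complete_captions, so the fallback after the loop
# (complete_captions = partial_captions) always applies.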
134 | for _ in range(self.max_caption_length - 1): 135 | partial_captions_list = partial_captions.extract() 136 | partial_captions.reset() 137 | # input_feed = torch.LongTensor([c.sentence[-1] for c in partial_captions_list]) 138 | # if rnn_input.is_cuda: 139 | # input_feed = input_feed.cuda() 140 | # input_feed = Variable(input_feed, volatile=True) 141 | input_feed = Variable(partial_captions_list[0].sentence[-1], volatile=True) 142 | state_feed = [c.state for c in partial_captions_list] 143 | if isinstance(state_feed[0], tuple): 144 | state_feed_h, state_feed_c = zip(*state_feed) 145 | state_feed = (torch.cat(state_feed_h, 1), 146 | torch.cat(state_feed_c, 1)) 147 | else: 148 | state_feed = torch.cat(state_feed, 1) 149 | 150 | embeddings = self.embedder(input_feed).view(1, len(input_feed), -1) 151 | words, logprobs, new_states = get_topk_words( 152 | embeddings, state_feed) 153 | for i, partial_caption in enumerate(partial_captions_list): 154 | if isinstance(new_states, tuple): 155 | state = (new_states[0].narrow(1, i, batch_size), 156 | new_states[1].narrow(1, i, batch_size)) 157 | else: 158 | state = new_states[i] 159 | for k in range(self.beam_size): 160 | w = words[:, k] 161 | sentence = partial_caption.sentence + [w] 162 | logprob = partial_caption.logprob + logprobs[:, k] 163 | score = logprob 164 | beam = Caption(sentence, state, logprob, score) 165 | partial_captions.push(beam) 166 | if partial_captions.size() == 0: 167 | # We have run out of partial candidates; happens when beam_size 168 | # = 1. 169 | break 170 | 171 | # If we have no complete captions then fall back to the partial captions. 172 | # But never output a mixture of complete and partial captions because a 173 | # partial caption could have a higher score than all the complete 174 | # captions. 175 | if not complete_captions.size(): 176 | complete_captions = partial_captions 177 | 178 | caps = complete_captions.extract(sort=True) 179 | 180 | return [[ list(sent) for sent in c.sentence ]for c in caps], [c.score for c in caps], word_vars 181 | 182 | class TopN(object): 183 | """Maintains the top n elements of an incrementally provided set.""" 184 | 185 | def __init__(self, n): 186 | self._n = n 187 | self._data = [] 188 | 189 | def size(self): 190 | assert self._data is not None 191 | return len(self._data) 192 | 193 | def push(self, x): 194 | """Pushes a new element.""" 195 | assert self._data is not None 196 | if len(self._data) < self._n: 197 | heapq.heappush(self._data, x) 198 | else: 199 | heapq.heappushpop(self._data, x) 200 | 201 | def extract(self, sort=False): 202 | """Extracts all elements from the TopN. This is a destructive operation. 203 | 204 | The only method that can be called immediately after extract() is reset(). 205 | 206 | Args: 207 | sort: Whether to return the elements in descending sorted order. 208 | 209 | Returns: 210 | A list of data; the top n elements provided to the set. 
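With sort=True the elements come back best-first: the reverse sort relies on the elements' __lt__ (for Caption, comparison by score), giving descending score order.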
211 | """ 212 | assert self._data is not None 213 | data = self._data 214 | self._data = None 215 | if sort: 216 | data.sort(reverse=True) 217 | return data 218 | 219 | def reset(self): 220 | """Returns the TopN to an empty state.""" 221 | self._data = [] 222 | -------------------------------------------------------------------------------- /loader/caption_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | import settings 5 | import string 6 | from torch.nn.utils.rnn import pack_padded_sequence 7 | from torch.autograd import Variable 8 | from loader.caption_helper import CaptionGenerator 9 | import numpy as np 10 | 11 | __UNK_TOKEN = 'UNK' 12 | __PAD_TOKEN = 'PAD' 13 | __EOS_TOKEN = 'EOS' 14 | 15 | def simple_tokenize(captions): 16 | processed = [] 17 | for j, s in enumerate(captions): 18 | txt = str(s).lower().translate(string.punctuation).strip().split() 19 | processed.append(txt) 20 | return processed 21 | 22 | def create_target(vocab): 23 | word2idx = {word: idx for idx, word in enumerate(vocab)} 24 | unk = word2idx[__UNK_TOKEN] 25 | 26 | def get_caption(captions): 27 | captions = simple_tokenize(captions) 28 | caption = captions[0] 29 | targets = [] 30 | for w in caption: 31 | targets.append(word2idx.get(w, unk)) 32 | return torch.Tensor(targets) 33 | return get_caption 34 | 35 | 36 | def create_batches(vocab, max_length=settings.BATCH_SIZE): 37 | padding = vocab.index(__PAD_TOKEN) 38 | eos = vocab.index(__EOS_TOKEN) 39 | 40 | def collate(img_cap): 41 | imgs, caps = img_cap 42 | imgs = torch.cat([img.unsqueeze(0) for img in imgs], 0) 43 | lengths = [min(len(c) + 1, max_length) for c in caps] 44 | batch_length = max(lengths) 45 | cap_tensor = torch.LongTensor(batch_length, len(caps)).fill_(padding) 46 | for i, c in enumerate(caps): 47 | end_cap = lengths[i] 48 | if end_cap < batch_length: 49 | cap_tensor[end_cap, i] = eos 50 | 51 | cap_tensor[:end_cap, i].copy_(c[:end_cap]) 52 | 53 | return (imgs, (cap_tensor, lengths)) 54 | return collate 55 | 56 | 57 | 58 | class CaptionModel(nn.Module): 59 | 60 | def __init__(self, cnn=None, vocab=None, voc_size=10003, embedding_size=256, rnn_size=256, num_layers=2, 61 | share_embedding_weights=False): 62 | super(CaptionModel, self).__init__() 63 | self.vocab = vocab 64 | if cnn: 65 | self.cnn = cnn 66 | else: 67 | self.cnn = models.__dict__[settings.CNN_MODEL](pretrained=True) 68 | self.cnn.fc = nn.Linear(self.cnn.fc.in_features, embedding_size) 69 | self.rnn = nn.LSTM(embedding_size, rnn_size, num_layers=num_layers) 70 | self.classifier = nn.Linear(rnn_size, voc_size) 71 | self.embedder = nn.Embedding(voc_size, embedding_size) 72 | if share_embedding_weights: 73 | self.embedder.weight = self.classifier.weight 74 | 75 | 76 | def forward(self, imgs, captions ): 77 | captions = torch.LongTensor(captions).transpose(1,0)[:,:-1] 78 | if settings.GPU: 79 | captions = captions.cuda() 80 | embeddings = self.embedder(Variable(captions)) 81 | 82 | img_feats = self.cnn(imgs).unsqueeze(1) 83 | embeddings = torch.cat([img_feats, embeddings], 1) 84 | feats, state = self.rnn(embeddings.transpose(1,0)) 85 | pred = self.classifier(feats.transpose(1,0)).max(2)[0] 86 | 87 | return pred, state 88 | 89 | def generate(self, img,eos_token='EOS',max_caption_length=20): 90 | cap_gen = CaptionGenerator(embedder=self.embedder, 91 | rnn=self.rnn, 92 | classifier=self.classifier, 93 | eos_id=self.vocab.index(eos_token), 94 | 
max_caption_length=max_caption_length) 95 | img_feats = self.cnn(img).unsqueeze(0) 96 | sentences, score, vars = cap_gen.beam_search(img_feats) 97 | words = [[self.vocab[sentences[0][wid][b]] for wid in range(max_caption_length)] for b in 98 | range(img.data.shape[0])] 99 | 100 | return sentences, words 101 | 102 | def save_checkpoint(self, filename): 103 | torch.save({'embedder_dict': self.embedder.state_dict(), 104 | 'rnn_dict': self.rnn.state_dict(), 105 | 'cnn_dict': self.cnn.state_dict(), 106 | 'classifier_dict': self.classifier.state_dict(), 107 | 'vocab': self.vocab}, 108 | filename) 109 | 110 | def load_checkpoint(self, filename): 111 | if not settings.GPU: 112 | cpnt = torch.load(filename, map_location=lambda storage, loc: storage ) 113 | else: 114 | cpnt = torch.load(filename) 115 | 116 | if 'cnn_dict' in cpnt: 117 | self.cnn.load_state_dict(cpnt['cnn_dict']) 118 | self.embedder.load_state_dict(cpnt['embedder_dict']) 119 | self.rnn.load_state_dict(cpnt['rnn_dict']) 120 | self.vocab=cpnt['vocab'] 121 | self.classifier.load_state_dict(cpnt['classifier_dict']) 122 | 123 | def finetune_cnn(self, allow=True): 124 | for p in self.cnn.parameters(): 125 | p.requires_grad = allow 126 | for p in self.cnn.fc.parameters(): 127 | p.requires_grad = True 128 | -------------------------------------------------------------------------------- /loader/model_loader.py: -------------------------------------------------------------------------------- 1 | import settings 2 | import torch 3 | import torchvision 4 | import torch.nn as nn 5 | from util.feature_operation import hook_feature, hook_grad 6 | 7 | def loadmodel(): 8 | if settings.APP == "vqa": 9 | from loader.vqa_resnet import resnet152 10 | from loader.vqa_model import VQANet, SimpleVQANet 11 | net_status = torch.load(settings.MODEL_FILE) 12 | vqa_net = nn.DataParallel(SimpleVQANet(net_status['metadata']['num_tokens'])).cuda() 13 | vqa_net.load_state_dict(net_status['weights']) 14 | vqa_net.cnn = resnet152(pretrained=True) 15 | for name in settings.FEATURE_NAMES: 16 | vqa_net.cnn._modules.get(name).register_forward_hook(hook_feature) 17 | vqa_net.cnn._modules.get(name).register_backward_hook(hook_grad) 18 | # if settings.GPU: 19 | # vqa_net.cnn.cuda() 20 | model = vqa_net 21 | elif settings.APP == "imagecap": 22 | from loader.caption_model import CaptionModel 23 | model = CaptionModel() 24 | model.load_checkpoint(settings.MODEL_FILE) 25 | for name in settings.FEATURE_NAMES: 26 | model.cnn._modules.get(name).register_forward_hook(hook_feature) 27 | model.cnn._modules.get(name).register_backward_hook(hook_grad) 28 | elif settings.APP == "classification": 29 | if settings.CAFFE_MODEL: 30 | from loader.caffe_model import CaffeNet_David, VGG16, CaffeNetCAM, VGG16CAM 31 | if settings.CNN_MODEL == "alexnet": 32 | model = CaffeNet_David() 33 | elif settings.CNN_MODEL == "vgg16": 34 | model = VGG16() 35 | elif settings.CNN_MODEL == "caffenetCAM": 36 | model = CaffeNetCAM() 37 | elif settings.CNN_MODEL == "vgg16CAM": 38 | model = VGG16CAM() 39 | model.load_state_dict(torch.load(settings.MODEL_FILE)) 40 | else: 41 | if settings.MODEL_FILE is None: 42 | model = torchvision.models.__dict__[settings.CNN_MODEL](pretrained=True) 43 | else: 44 | checkpoint = torch.load(settings.MODEL_FILE) 45 | if type(checkpoint).__name__ == 'OrderedDict' or type(checkpoint).__name__ == 'dict': 46 | model = torchvision.models.__dict__[settings.CNN_MODEL](num_classes=settings.NUM_CLASSES) 47 | if settings.MODEL_PARALLEL: 48 | state_dict = {str.replace(k, 'module.', ''): v for 
k, v in checkpoint[ 49 | 'state_dict'].items()} # the data parallel layer will add 'module' before each layer name 50 | else: 51 | state_dict = checkpoint 52 | model.load_state_dict(state_dict) 53 | else: 54 | model = checkpoint 55 | for name in settings.FEATURE_NAMES: 56 | model._modules.get(name).register_forward_hook(hook_feature) 57 | model._modules.get(name).register_backward_hook(hook_grad) 58 | if settings.GPU: 59 | model.cuda() 60 | model.eval() 61 | return model 62 | -------------------------------------------------------------------------------- /loader/vqa_data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path 4 | import re 5 | 6 | from PIL import Image 7 | import h5py 8 | import torch 9 | import torch.utils.data as data 10 | import torchvision.transforms as transforms 11 | import settings 12 | import os 13 | import numpy as np 14 | def load_cache(): 15 | pass 16 | 17 | def collate_fn(batch): 18 | # put question lengths in descending order so that we can use packed sequences later 19 | batch.sort(key=lambda x: x[-1], reverse=True) 20 | return data.dataloader.default_collate(batch) 21 | 22 | 23 | class VQA(data.Dataset): 24 | """ VQA dataset, open-ended """ 25 | def __init__(self, questions_path, answers_path, image_path, answerable_only=False, cache_path=None): 26 | super(VQA, self).__init__() 27 | # vocab 28 | self.token_to_index = None 29 | self.answer_to_index = None 30 | self.questions = None 31 | self.raw_questions = None 32 | self.answers = None 33 | self.raw_answers = None 34 | self.image_features_path = None 35 | self.coco_id_to_index = None 36 | self.coco_ids = None 37 | self.answerable_only = None 38 | self.answerable = None 39 | 40 | if cache_path is not None: 41 | if os.path.exists(cache_path): 42 | self.load_cache(cache_path) 43 | else: 44 | self.init(questions_path, answers_path, image_path, answerable_only=False) 45 | self.save_cache(cache_path) 46 | else: 47 | self.init(questions_path, answers_path, image_path, answerable_only=False) 48 | 49 | def init(self, questions_path, answers_path, image_path, answerable_only): 50 | with open(questions_path, 'r') as fd: 51 | questions_json = json.load(fd) 52 | with open(answers_path, 'r') as fd: 53 | answers_json = json.load(fd) 54 | with open(settings.VOCAB_FILE, 'r') as fd: 55 | vocab_json = json.load(fd) 56 | self._check_integrity(questions_json, answers_json) 57 | 58 | # vocab 59 | self.vocab = vocab_json 60 | self.token_to_index = self.vocab['question'] 61 | self.questions_tokens = list(self.vocab['question'].keys()) 62 | self.answer_to_index = self.vocab['answer'] 63 | self.answer_tokens = list(self.vocab['answer']) 64 | 65 | 66 | # q and a 67 | self.raw_questions = list(prepare_questions(questions_json)) 68 | self.raw_answers = list(prepare_answers(answers_json)) 69 | self.questions = [self._encode_question(q) for q in self.raw_questions] 70 | self.answers = [self._encode_answers(a) for a in self.raw_answers] 71 | 72 | # v 73 | # self.image_path = image_path 74 | # self.coco_id_to_index = np.load(settings.VQA_IMAGE_INDEX_FILE) 75 | self.coco_ids = [q['image_id'] for q in questions_json['questions']] 76 | 77 | def save_cache(self, cache_path): 78 | dict = { 79 | "questions" : self.questions, 80 | "answers" : self.answers, 81 | "token_to_index" : self.token_to_index, 82 | "answer_to_index" : self.answer_to_index, 83 | "image_features_path" : self.image_features_path, 84 | "coco_id_to_index" : self.coco_id_to_index, 85 | 
"coco_ids" : self.coco_ids, 86 | "answerable_only" : self.answerable_only, 87 | "answerable" : self.answerable, 88 | } 89 | np.save(cache_path, dict) 90 | 91 | def load_cache(self, cache_path): 92 | dict = np.load(cache_path) 93 | 94 | @property 95 | def max_question_length(self): 96 | if not hasattr(self, '_max_length'): 97 | self._max_length = max(map(len, self.raw_questions)) 98 | return self._max_length 99 | 100 | @property 101 | def num_tokens(self): 102 | return len(self.token_to_index) + 1 # add 1 for token at index 0 103 | 104 | def _create_coco_id_to_index(self): 105 | """ Create a mapping from a COCO image id into the corresponding index into the h5 file """ 106 | with h5py.File(self.image_features_path, 'r') as features_file: 107 | coco_ids = features_file['ids'][()] 108 | coco_id_to_index = {id: i for i, id in enumerate(coco_ids)} 109 | return coco_id_to_index 110 | 111 | def _check_integrity(self, questions, answers): 112 | """ Verify that we are using the correct data """ 113 | qa_pairs = list(zip(questions['questions'], answers['annotations'])) 114 | assert all(q['question_id'] == a['question_id'] for q, a in qa_pairs), 'Questions not aligned with answers' 115 | assert all(q['image_id'] == a['image_id'] for q, a in qa_pairs), 'Image id of question and answer don\'t match' 116 | # assert questions['data_type'] == answers['data_type'], 'Mismatched data types' 117 | # assert questions['data_subtype'] == answers['data_subtype'], 'Mismatched data subtypes' 118 | 119 | def _encode_question(self, question): 120 | """ Turn a question into a vector of indices and a question length """ 121 | vec = torch.zeros(self.max_question_length).long() 122 | for i, token in enumerate(question): 123 | index = self.token_to_index.get(token, 0) 124 | vec[i] = index 125 | return vec, len(question) 126 | 127 | def _encode_answers(self, answers): 128 | """ Turn an answer into a vector """ 129 | # answer vec will be a vector of answer counts to determine which answers will contribute to the loss. 130 | # this should be multiplied with 0.1 * negative log-likelihoods that a model produces and then summed up 131 | # to get the loss that is weighted by how many humans gave that answer 132 | answer_vec = torch.zeros(len(self.answer_to_index)) 133 | for answer in answers: 134 | index = self.answer_to_index.get(answer) 135 | if index is not None: 136 | answer_vec[index] += 1 137 | return answer_vec 138 | 139 | 140 | def __getitem__(self, item): 141 | q_org = self.raw_questions[item] 142 | a_org = self.raw_answers[item] 143 | q, q_length = self.questions[item] 144 | a = self.answers[item] 145 | image_id = self.coco_ids[item] 146 | v = os.path.join(settings.VQA_IMG_PATH, '%06d.jpg' % image_id) 147 | return v, q, a, item, q_length, q_org, a_org 148 | 149 | def __len__(self): 150 | return len(self.questions) 151 | 152 | 153 | # this is used for normalizing questions 154 | _special_chars = re.compile('[^a-z0-9 ]*') 155 | 156 | # these try to emulate the original normalization scheme for answers 157 | _period_strip = re.compile(r'(?!<=\d)(\.)(?!\d)') 158 | _comma_strip = re.compile(r'(\d)(,)(\d)') 159 | _punctuation_chars = re.escape(r';/[]"{}()=+\_-><@`,?!') 160 | _punctuation = re.compile(r'([{}])'.format(re.escape(_punctuation_chars))) 161 | _punctuation_with_a_space = re.compile(r'(?<= )([{0}])|([{0}])(?= )'.format(_punctuation_chars)) 162 | 163 | 164 | def prepare_questions(questions_json): 165 | """ Tokenize and normalize questions from a given question json in the usual VQA format. 
""" 166 | questions = [q['question'] for q in questions_json['questions']] 167 | for question in questions: 168 | question = question.lower()[:-1] 169 | yield question.split(' ') 170 | 171 | 172 | def prepare_answers(answers_json): 173 | """ Normalize answers from a given answer json in the usual VQA format. """ 174 | answers = [[a['answer'] for a in ans_dict['answers']] for ans_dict in answers_json['annotations']] 175 | # The only normalization that is applied to both machine generated answers as well as 176 | # ground truth answers is replacing most punctuation with space (see [0] and [1]). 177 | # Since potential machine generated answers are just taken from most common answers, applying the other 178 | # normalizations is not needed, assuming that the human answers are already normalized. 179 | # [0]: http://visualqa.org/evaluation.html 180 | # [1]: https://github.com/VT-vision-lab/VQA/blob/3849b1eae04a0ffd83f56ad6f70ebd0767e09e0f/PythonEvaluationTools/vqaEvaluation/vqaEval.py#L96 181 | 182 | def process_punctuation(s): 183 | # the original is somewhat broken, so things that look odd here might just be to mimic that behaviour 184 | # this version should be faster since we use re instead of repeated operations on str's 185 | if _punctuation.search(s) is None: 186 | return s 187 | s = _punctuation_with_a_space.sub('', s) 188 | if re.search(_comma_strip, s) is not None: 189 | s = s.replace(',', '') 190 | s = _punctuation.sub(' ', s) 191 | s = _period_strip.sub('', s) 192 | return s.strip() 193 | 194 | for answer_list in answers: 195 | yield list(map(process_punctuation, answer_list)) 196 | 197 | 198 | class CocoImages(data.Dataset): 199 | """ Dataset for MSCOCO images located in a folder on the filesystem """ 200 | def __init__(self, path, transform=None): 201 | super(CocoImages, self).__init__() 202 | self.path = path 203 | self.id_to_filename = self._find_images() 204 | self.sorted_ids = sorted(self.id_to_filename.keys()) # used for deterministic iteration order 205 | print('found {} images in {}'.format(len(self), self.path)) 206 | self.transform = transform 207 | 208 | def _find_images(self): 209 | id_to_filename = {} 210 | for filename in os.listdir(self.path): 211 | if not filename.endswith('.jpg'): 212 | continue 213 | id_and_extension = filename.split('_')[-1] 214 | id = int(id_and_extension.split('.')[0]) 215 | id_to_filename[id] = filename 216 | return id_to_filename 217 | 218 | def __getitem__(self, item): 219 | id = self.sorted_ids[item] 220 | path = os.path.join(self.path, self.id_to_filename[id]) 221 | img = Image.open(path).convert('RGB') 222 | 223 | if self.transform is not None: 224 | img = self.transform(img) 225 | return id, img 226 | 227 | def __len__(self): 228 | return len(self.sorted_ids) 229 | 230 | 231 | class Composite(data.Dataset): 232 | """ Dataset that is a composite of several Dataset objects. Useful for combining splits of a dataset. 
""" 233 | def __init__(self, *datasets): 234 | self.datasets = datasets 235 | 236 | def __getitem__(self, item): 237 | current = self.datasets[0] 238 | for d in self.datasets: 239 | if item < len(d): 240 | return d[item] 241 | item -= len(d) 242 | else: 243 | raise IndexError('Index too large for composite dataset') 244 | 245 | def __len__(self): 246 | return sum(map(len, self.datasets)) 247 | -------------------------------------------------------------------------------- /loader/vqa_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | import settings 7 | 8 | 9 | class SimpleVQANet(nn.Module): 10 | def __init__(self, embedding_tokens): 11 | super(SimpleVQANet, self).__init__() 12 | question_features = 1024 13 | concat_feature = 1024 14 | vision_features = settings.OUTPUT_FEATURE_SIZE 15 | 16 | self.fc_q = nn.Linear(question_features, concat_feature) 17 | self.pool_v = nn.AvgPool2d(7) 18 | self.fc_v = nn.Linear(vision_features, concat_feature) 19 | 20 | self.text = TextProcessor( 21 | embedding_tokens=embedding_tokens, 22 | embedding_features=300, 23 | lstm_features=question_features, 24 | drop=0.5, 25 | ) 26 | 27 | self.classifier = nn.Linear(concat_feature, settings.MAX_ANSWERS) 28 | 29 | def forward(self, v, q, q_len): 30 | q = self.text(q, list(q_len.data)) 31 | q = self.fc_q(q) 32 | 33 | v = self.pool_v(v).squeeze() 34 | if len(v.shape) == 1: 35 | v = v.unsqueeze(0) 36 | v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8) 37 | v = self.fc_v(v) 38 | 39 | answer = self.classifier(v * q) 40 | return answer 41 | 42 | class VQANet(nn.Module): 43 | """ Re-implementation of ``Show, Ask, Attend, and Answer: A Strong Baseline For Visual Question Answering'' [0] 44 | 45 | [0]: https://arxiv.org/abs/1704.03162 46 | """ 47 | 48 | def __init__(self, embedding_tokens): 49 | super(VQANet, self).__init__() 50 | question_features = 1024 51 | vision_features = settings.OUTPUT_FEATURE_SIZE 52 | glimpses = 2 53 | self.cnn = None 54 | 55 | self.text = TextProcessor( 56 | embedding_tokens=embedding_tokens, 57 | embedding_features=300, 58 | lstm_features=question_features, 59 | drop=0.5, 60 | ) 61 | self.attention = Attention( 62 | v_features=vision_features, 63 | q_features=question_features, 64 | mid_features=512, 65 | glimpses=2, 66 | drop=0.5, 67 | ) 68 | self.classifier = Classifier( 69 | in_features=glimpses * vision_features + question_features, 70 | mid_features=1024, 71 | out_features=settings.MAX_ANSWERS, 72 | drop=0.5, 73 | ) 74 | 75 | for m in self.modules(): 76 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 77 | init.xavier_uniform(m.weight) 78 | if m.bias is not None: 79 | m.bias.data.zero_() 80 | 81 | 82 | def forward(self, v, q, q_len): 83 | q = self.text(q, list(q_len.data)) 84 | 85 | v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8) 86 | a = self.attention(v, q) 87 | v = apply_attention(v, a) 88 | 89 | combined = torch.cat([v, q], dim=1) 90 | answer = self.classifier(combined) 91 | return answer 92 | 93 | 94 | class Classifier(nn.Sequential): 95 | def __init__(self, in_features, mid_features, out_features, drop=0.0): 96 | super(Classifier, self).__init__() 97 | self.add_module('drop1', nn.Dropout(drop)) 98 | self.add_module('lin1', nn.Linear(in_features, mid_features)) 99 | self.add_module('relu', nn.ReLU()) 100 | self.add_module('drop2', 
nn.Dropout(drop)) 101 | self.add_module('lin2', nn.Linear(mid_features, out_features)) 102 | 103 | 104 | class TextProcessor(nn.Module): 105 | def __init__(self, embedding_tokens, embedding_features, lstm_features, drop=0.0): 106 | super(TextProcessor, self).__init__() 107 | self.embedding = nn.Embedding(embedding_tokens, embedding_features, padding_idx=0) 108 | self.drop = nn.Dropout(drop) 109 | self.tanh = nn.Tanh() 110 | self.lstm = nn.LSTM(input_size=embedding_features, 111 | hidden_size=lstm_features, 112 | num_layers=1) 113 | self.features = lstm_features 114 | 115 | self._init_lstm(self.lstm.weight_ih_l0) 116 | self._init_lstm(self.lstm.weight_hh_l0) 117 | self.lstm.bias_ih_l0.data.zero_() 118 | self.lstm.bias_hh_l0.data.zero_() 119 | 120 | init.xavier_uniform(self.embedding.weight) 121 | 122 | def _init_lstm(self, weight): 123 | for w in weight.chunk(4, 0): 124 | init.xavier_uniform(w) 125 | 126 | def forward(self, q, q_len): 127 | embedded = self.embedding(q) 128 | tanhed = self.tanh(self.drop(embedded)) 129 | packed = pack_padded_sequence(tanhed, q_len, batch_first=True) 130 | _, (_, c) = self.lstm(packed) 131 | return c.squeeze(0) 132 | 133 | 134 | class Attention(nn.Module): 135 | def __init__(self, v_features, q_features, mid_features, glimpses, drop=0.0): 136 | super(Attention, self).__init__() 137 | self.v_conv = nn.Conv2d(v_features, mid_features, 1, bias=False) # let self.lin take care of bias 138 | self.q_lin = nn.Linear(q_features, mid_features) 139 | self.x_conv = nn.Conv2d(mid_features, glimpses, 1) 140 | 141 | self.drop = nn.Dropout(drop) 142 | self.relu = nn.ReLU(inplace=True) 143 | 144 | def forward(self, v, q): 145 | v = self.v_conv(self.drop(v)) 146 | q = self.q_lin(self.drop(q)) 147 | q = tile_2d_over_nd(q, v) 148 | x = self.relu(v + q) 149 | x = self.x_conv(self.drop(x)) 150 | return x 151 | 152 | 153 | def apply_attention(input, attention): 154 | """ Apply any number of attention maps over the input. 155 | The attention map has to have the same size in all dimensions except dim=1. 156 | """ 157 | n, c = input.size()[:2] 158 | glimpses = attention.size(1) 159 | 160 | # flatten the spatial dims into the third dim, since we don't need to care about how they are arranged 161 | input = input.view(n, c, -1) 162 | attention = attention.view(n, glimpses, -1) 163 | s = input.size(2) 164 | 165 | # apply a softmax to each attention map separately 166 | # since softmax only takes 2d inputs, we have to collapse the first two dimensions together 167 | # so that each glimpse is normalized separately 168 | attention = attention.view(n * glimpses, -1) 169 | attention = F.softmax(attention) 170 | 171 | # apply the weighting by creating a new dim to tile both tensors over 172 | target_size = [n, glimpses, c, s] 173 | input = input.view(n, 1, c, s).expand(*target_size) 174 | attention = attention.view(n, glimpses, 1, s).expand(*target_size) 175 | weighted = input * attention 176 | # sum over only the spatial dimension 177 | weighted_mean = weighted.sum(dim=3) 178 | # the shape at this point is (n, glimpses, c, 1) 179 | return weighted_mean.view(n, -1) 180 | 181 | 182 | def tile_2d_over_nd(feature_vector, feature_map): 183 | """ Repeat the same feature vector over all spatial positions of a given feature map. 184 | The feature vector should have the same batch size and number of features as the feature map. 
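For example, tiling an (n, c) vector over an (n, c, h, w) feature map returns an (n, c, h, w) tensor in which every spatial position holds a copy of the vector.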
185 | """ 186 | n, c = feature_vector.size() 187 | spatial_size = feature_map.dim() - 2 188 | tiled = feature_vector.view(n, c, *([1] * spatial_size)).expand_as(feature_map) 189 | return tiled 190 | -------------------------------------------------------------------------------- /loader/vqa_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://github.com/Cyanogenoid/pytorch-resnet/releases/download/hosting/resnet152-95e0e999.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | "3x3 convolution with padding" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = nn.BatchNorm2d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = nn.BatchNorm2d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None): 62 | super(Bottleneck, self).__init__() 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 66 | padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * 4) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class ResNet(nn.Module): 98 | def __init__(self, block, layers, num_classes=1000): 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, 
stride=2, padding=0, ceil_mode=True) # change 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 110 | self.avgpool = nn.AvgPool2d(7) 111 | self.fc = nn.Linear(512 * block.expansion, num_classes) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | m.weight.data.fill_(1) 119 | m.bias.data.zero_() 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | 149 | x = F.adaptive_avg_pool2d(x, 1) 150 | x = x.view(x.size(0), -1) 151 | x = self.fc(x) 152 | 153 | return x 154 | 155 | 156 | def resnet18(pretrained=False): 157 | """Constructs a ResNet-18 model. 158 | 159 | Args: 160 | pretrained (bool): If True, returns a model pre-trained on ImageNet 161 | """ 162 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 163 | if pretrained: 164 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 165 | return model 166 | 167 | 168 | def resnet34(pretrained=False): 169 | """Constructs a ResNet-34 model. 170 | 171 | Args: 172 | pretrained (bool): If True, returns a model pre-trained on ImageNet 173 | """ 174 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 175 | if pretrained: 176 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 177 | return model 178 | 179 | 180 | def resnet50(pretrained=False): 181 | """Constructs a ResNet-50 model. 182 | 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 187 | if pretrained: 188 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 189 | return model 190 | 191 | 192 | def resnet101(pretrained=False): 193 | """Constructs a ResNet-101 model. 194 | 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 199 | if pretrained: 200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 201 | return model 202 | 203 | 204 | def resnet152(pretrained=False): 205 | """Constructs a ResNet-152 model. 
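    Weights are fetched from the ``model_urls`` dict above; note that the
    resnet152 entry points at a GitHub-hosted mirror rather than the s3
    bucket used for the other architectures.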
206 | 207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 211 | if pretrained: 212 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 213 | return model 214 | -------------------------------------------------------------------------------- /result/pytorch_resnet18_places365/decompose.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/result/pytorch_resnet18_places365/decompose.npy -------------------------------------------------------------------------------- /result/pytorch_resnet18_places365/snapshot/14.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/result/pytorch_resnet18_places365/snapshot/14.pth -------------------------------------------------------------------------------- /script/dlbroden.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # Start from parent directory of script 5 | cd "$(dirname "$(dirname "$(readlink -f "$0")")")" 6 | 7 | # Download broden1_224 8 | if [ ! -f dataset/broden1_224/index.csv ] 9 | then 10 | 11 | echo "Downloading broden1_224" 12 | mkdir -p dataset 13 | pushd dataset 14 | wget --progress=bar \ 15 | http://netdissect.csail.mit.edu/data/broden1_224.zip \ 16 | -O broden1_224.zip 17 | unzip broden1_224.zip 18 | rm broden1_224.zip 19 | popd 20 | 21 | fi 22 | 23 | # Download broden1_227 24 | if [ ! -f dataset/broden1_227/index.csv ] 25 | then 26 | 27 | echo "Downloading broden1_227" 28 | mkdir -p dataset 29 | pushd dataset 30 | wget --progress=bar \ 31 | http://netdissect.csail.mit.edu/data/broden1_227.zip \ 32 | -O broden1_227.zip 33 | unzip broden1_227.zip 34 | rm broden1_227.zip 35 | popd 36 | 37 | fi 38 | -------------------------------------------------------------------------------- /script/dlzoo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # Start from parent directory of script 5 | 6 | echo "Downloading" 7 | mkdir -p zoo 8 | pushd zoo 9 | wget http://places2.csail.mit.edu/models_places365/resnet18_places365.pth.tar 10 | popd 11 | 12 | echo "done" 13 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | ######### global settings ######### 2 | GPU = True # running on GPU is highly suggested 3 | CLEAN = False # set to "True" if you want to clean the temporary large files after generating result 4 | APP = "classification" # Do not change! mode choide: "classification", "imagecap", "vqa". Currently "imagecap" and "vqa" are not supported. 5 | CATAGORIES = ["object", "part"] # Do not change! concept categories that are chosen to detect: "object", "part", "scene", "material", "texture", "color" 6 | 7 | CAM_THRESHOLD = 0.5 # the threshold used for CAM visualization 8 | FONT_PATH = "components/font.ttc" # font file path 9 | FONT_SIZE = 26 # font size 10 | SEG_RESOLUTION = 7 # the resolution of cam map 11 | BASIS_NUM = 7 # In decomposition, this is to decide how many concepts are used to interpret the weight vector of a class. 
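# Illustrative sketch (comments only) of what BASIS_NUM controls: the unit-norm
# class weight w is greedily approximated by BASIS_NUM concept-classifier
# weights q_ci plus a residual r,
#     w ~= s_1 * q_c1 + ... + s_BASIS_NUM * q_c_BASIS_NUM + r,
# via the Gram-Schmidt decomposition invoked in test.py
# (fo.decompose_Gram_Schmidt with MAX=BASIS_NUM).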
12 | 13 | EPOCHS = 15 # max epochs to train the concept classifier 14 | SNAPSHOT_FREQ = 5 # the frequence of making snapshot 15 | SINGLE_LABEL = False # Do not change. 16 | COMPRESSED_INDEX = True # Do not change. 17 | WORKERS = 4 # how much thread is used to extract images 18 | BATCH_SIZE = 128 # batch size when extracting image feature 19 | FEAT_BATCH_SIZE = 16 # feature's batch in training feature classifier 20 | TALLY_BATCH_SIZE = 4 # batch size when tallying concept 21 | TALLY_AHEAD = 4 # size of prefetching batch when tallying concept 22 | INDEX_FILE = 'index.csv' # image index of concept dataset 23 | 24 | CAFFE_MODEL = False # whether the model is transferred from "*.caffemodel". 25 | CNN_MODEL = 'resnet18' # model arch: resnet18, alexnet, resnet50, densenet161, etc... 26 | DATASET = 'places365' # model trained on: places365 or imagenet 27 | OUTPUT_FOLDER = "result/pytorch_"+CNN_MODEL+"_"+DATASET # where output file exists 28 | 29 | DATASET_PATH = 'components/sample_images' # where sample image folder exists 30 | DATASET_INDEX_FILE = 'components/sample_images/sample.txt' # a file list of sample images 31 | if (not CNN_MODEL.endswith('CAM')) and (CNN_MODEL == "alexnet" or CNN_MODEL.startswith('vgg')): 32 | GRAD_CAM = True # to decide if we have to decompose the grad-CAM for the chosen model 33 | else: 34 | GRAD_CAM = False 35 | if DATASET == 'places365': 36 | NUM_CLASSES = 365 # class amount of dataset 37 | if CNN_MODEL == 'resnet18': 38 | MODEL_FILE = 'zoo/resnet18_places365.pth.tar' # model filee's path 39 | MODEL_PARALLEL = True # if the model is trained by multi-GPUs 40 | 41 | 42 | if CNN_MODEL != 'alexnet' and CNN_MODEL != 'caffenetCAM': 43 | DATA_DIRECTORY = 'dataset/broden1_224' # concept dataset's path 44 | IMG_SIZE = 224 # image's size in the concept dataset 45 | else: 46 | DATA_DIRECTORY = 'dataset/broden1_227' 47 | IMG_SIZE = 227 48 | 49 | if CAFFE_MODEL: 50 | FEATURE_NAMES = ['pool5'] # the layer to be decomposed 51 | else: 52 | if 'resnet' in CNN_MODEL: 53 | FEATURE_NAMES = ['layer4'] 54 | elif CNN_MODEL == 'densenet161' or CNN_MODEL == 'alexnet' or CNN_MODEL.startswith('vgg'): 55 | FEATURE_NAMES = ['features'] 56 | 57 | 58 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import settings 4 | from loader.model_loader import loadmodel 5 | from util.feature_operation import FeatureOperator 6 | from util.clean import clean 7 | from util.feature_decoder import SingleSigmoidFeatureClassifier 8 | from util.image_operation import * 9 | from PIL import Image 10 | import numpy as np 11 | from scipy.misc import imresize, imread 12 | from visualize.plot import random_color 13 | from torch.autograd import Variable as V 14 | import torch 15 | 16 | 17 | model = loadmodel() 18 | fo = FeatureOperator() 19 | 20 | features, _ = fo.feature_extraction(model=model) 21 | 22 | for layer_id, layer in enumerate(settings.FEATURE_NAMES): 23 | feat_clf = SingleSigmoidFeatureClassifier(feature=features[layer_id], layer=layer, fo=fo) 24 | feat_clf.load_snapshot(14, unbiased=True) 25 | 26 | if not settings.GRAD_CAM: 27 | fo.weight_decompose(model, feat_clf, feat_labels=[l['name'] for l in fo.data.label]) 28 | 29 | with open(settings.DATASET_INDEX_FILE) as f: 30 | image_list = f.readlines() 31 | predictions = [] 32 | outpath = os.path.join(settings.OUTPUT_FOLDER, 'html', 'image') 33 | if not os.path.exists(outpath): 34 | os.makedirs(outpath) 35 | for 
image_ind, file in enumerate(image_list): 36 | print("generating figure on %03d" % image_ind) 37 | image_file = os.path.join(settings.DATASET_PATH, file.strip()) 38 | 39 | # feature extraction 40 | org_img = imread(image_file) 41 | org_img = imresize(org_img, (settings.IMG_SIZE, settings.IMG_SIZE)) 42 | if org_img.shape.__len__() == 2: 43 | org_img = org_img[:, :, None].repeat(3, axis=2) 44 | img_feat, img_grad, prediction_ind, prediction = fo.single_feature_extraction(model, org_img) 45 | if settings.COMPRESSED_INDEX: 46 | try: 47 | labels = [fo.data.label[concept] for concept in feat_clf.valid_concepts] 48 | except Exception: 49 | labels = [fo.data.label[concept] for concept in np.load('cache/valid_concept.npy')] 50 | 51 | else: 52 | labels = fo.data.label 53 | h, w, u = img_feat.shape 54 | 55 | # feature classification 56 | seg_resolution = settings.SEG_RESOLUTION 57 | img_feat_resized = np.zeros((seg_resolution, seg_resolution, u)) 58 | for i in range(u): 59 | img_feat_resized[:, :, i] = imresize(img_feat[:, :, i], (seg_resolution, seg_resolution), mode="F") 60 | img_feat_resized.shape = (seg_resolution * seg_resolution, u) 61 | 62 | concept_predicted = feat_clf.fc(V(torch.FloatTensor(img_feat_resized))) 63 | concept_predicted = concept_predicted.data.numpy().reshape(seg_resolution, seg_resolution, -1) 64 | # concept_predicted_reg = (concept_predicted - np.min(concept_predicted, 2, keepdims=True)) / np.max( 65 | # concept_predicted, 2, keepdims=True) 66 | 67 | concept_inds = concept_predicted.argmax(2) 68 | concept_colors = np.array(random_color(concept_predicted.shape[2])) * 256 69 | 70 | # feature visualization 71 | vis_size = settings.IMG_SIZE 72 | margin = int(vis_size / 30) 73 | img_cam = fo.cam_mat(img_feat * img_grad.mean((0, 1))[None, None, :], above_zero=False) 74 | vis_cam = vis_cam_mask(img_cam, org_img, vis_size) 75 | CONCEPT_CAM_TOPN = settings.BASIS_NUM 76 | CONCEPT_CAM_BOTTOMN = 0 77 | 78 | if settings.GRAD_CAM: 79 | weight_clf = feat_clf.fc.weight.data.numpy() 80 | weight_concept = weight_clf # np.maximum(weight_clf, 0) 81 | weight_concept = weight_concept / np.linalg.norm(weight_concept, axis=1)[:, None] 82 | target_weight = img_grad.mean((0, 1)) 83 | target_weight = target_weight / np.linalg.norm(target_weight) 84 | rankings, scores, coefficients, residuals = fo.decompose_Gram_Schmidt(weight_concept, 85 | target_weight[None, :], 86 | MAX=settings.BASIS_NUM) 87 | ranking = rankings[0] 88 | residual = residuals[0] 89 | d_e = np.linalg.norm(residuals[0]) ** 2 90 | 91 | component_weights = np.vstack( 92 | [coefficients[0][:settings.BASIS_NUM, None] * weight_concept[ranking], residual[None, :]]) 93 | a = img_feat.mean((0, 1)) 94 | a /= np.linalg.norm(a) 95 | qcas = np.dot(component_weights, a) 96 | combination_score = sum(abs(qcas)) 97 | inds = qcas[:-1].argsort()[:-CONCEPT_CAM_TOPN - 1:-1] 98 | concept_masks_ind = ranking[inds] 99 | scores_topn = coefficients[0][inds] 100 | contribution = qcas[inds] 101 | else: 102 | weight_label, weight_concept = fo.weight_extraction(model, feat_clf) 103 | 104 | rankings, errvar, coefficients, residuals_T = np.load( 105 | os.path.join(settings.OUTPUT_FOLDER, "decompose.npy")) 106 | ranking = rankings[prediction_ind].astype(int) 107 | residual = residuals_T.T[prediction_ind] 108 | d_e = np.linalg.norm(residual) ** 2 109 | component_weights = np.vstack( 110 | [coefficients[prediction_ind][:settings.BASIS_NUM, None] * weight_concept[ranking], 111 | residual[None, :]]) 112 | a = img_feat.mean((0, 1)) 113 | a /= np.linalg.norm(a) 114 | qcas = 
np.dot(component_weights, a) 115 | combination_score = sum(qcas) 116 | inds = qcas[:-1].argsort()[:-CONCEPT_CAM_TOPN - 1:-1] 117 | concept_masks_ind = ranking[inds] 118 | scores_topn = coefficients[prediction_ind][inds] 119 | contribution = qcas[inds] 120 | 121 | 122 | concept_masks = concept_predicted[:, :, concept_masks_ind] 123 | concept_masks = concept_masks * ((scores_topn > 0) * 1)[None, None, :] 124 | concept_masks = (np.maximum(concept_masks, 0)) / np.max(concept_masks) 125 | 126 | vis_concept_cam = [] 127 | for i in range(CONCEPT_CAM_TOPN + CONCEPT_CAM_BOTTOMN): 128 | vis_concept_cam.append(vis_cam_mask(concept_masks[:, :, i], org_img, vis_size, font_text=None)) 129 | 130 | vis_img = Image.fromarray(org_img).resize((vis_size, vis_size), resample=Image.BILINEAR) 131 | vis_bm = big_margin(vis_size) 132 | vis = imconcat([vis_img, vis_cam, vis_bm] + vis_concept_cam[:3], vis_size, vis_size, margin=margin) 133 | captions = [ 134 | "%s(%4.2f%%)" % (labels[concept_masks_ind[i]]['name'], contribution[i] * 100 / combination_score) 135 | for i in range(3)] 136 | captions = ["%s(%.2f) " % (prediction, combination_score)] + captions 137 | vis_headline = headline2(captions, vis_size, vis.height // 5, vis.width, margin=margin) 138 | vis = imstack([vis_headline, vis]) 139 | 140 | predictions.append(prediction) 141 | vis.save(os.path.join(outpath, "%03d.jpg" % image_ind)) 142 | 143 | f = open(os.path.join(settings.OUTPUT_FOLDER, 'html', 'result.html'), 'w') 144 | f.write("\n\n\n\n\n\n\n\n") 145 | for ind in range(len(predictions)): 146 | headline = "

%d %s
\n" % (ind, predictions[ind]) 147 | imageline = "\n" % (ind, settings.IMG_SIZE, os.path.join('image', '%03d.jpg' % ind)) 148 | f.write(headline) 149 | f.write(imageline) 150 | f.write("\n\n") 151 | f.close() 152 | 153 | 154 | if settings.CLEAN: 155 | clean() 156 | 157 | 158 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import settings 3 | from loader.model_loader import loadmodel 4 | from util.feature_operation import FeatureOperator 5 | from util.feature_decoder import SingleSigmoidFeatureClassifier 6 | from util.image_operation import * 7 | 8 | model = loadmodel() 9 | fo = FeatureOperator() 10 | features, _ = fo.feature_extraction(model=model) 11 | 12 | for layer_id, layer in enumerate(settings.FEATURE_NAMES): 13 | settings.GPU = False 14 | feat_clf = SingleSigmoidFeatureClassifier(feature=features[layer_id], layer=layer, fo=fo) 15 | feat_clf.run() 16 | 17 | 18 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/util/__init__.py -------------------------------------------------------------------------------- /util/clean.py: -------------------------------------------------------------------------------- 1 | import settings 2 | import os 3 | 4 | def clean(): 5 | filelist = [f for f in os.listdir(settings.OUTPUT_FOLDER) if f.endswith('mmap')] 6 | for f in filelist: 7 | os.remove(os.path.join(settings.OUTPUT_FOLDER, f)) 8 | -------------------------------------------------------------------------------- /util/experiments.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim 6 | import torch.utils.data 7 | import torch.utils.data.distributed 8 | import torchvision.transforms as transforms 9 | import torchvision.datasets as datasets 10 | from easydict import EasyDict as edict 11 | import settings 12 | import numpy as np 13 | 14 | EXP_SETTINGS = edict({ 15 | "print_freq": 10, 16 | "batch_size": 64, 17 | "data_imagenet": "/home/sunyiyou/dataset/imagenet/", 18 | "data_places365": "/home/sunyiyou/dataset/places365_standard/", 19 | "workers": 12, 20 | }) 21 | 22 | class AverageMeter(object): 23 | """Computes and stores the average and current value""" 24 | def __init__(self): 25 | self.reset() 26 | 27 | def reset(self): 28 | self.val = 0 29 | self.avg = 0 30 | self.sum = 0 31 | self.count = 0 32 | 33 | def update(self, val, n=1): 34 | self.val = val 35 | self.sum += val * n 36 | self.count += n 37 | self.avg = self.sum / self.count 38 | 39 | def test_clf_power(feat_clf, model): 40 | def accuracy(output, target, topk=(1,)): 41 | """Computes the precision@k for the specified values of k""" 42 | maxk = max(topk) 43 | batch_size = target.size(0) 44 | 45 | _, pred = output.topk(maxk, 1, True, True) 46 | pred = pred.t() 47 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 48 | 49 | res = [] 50 | for k in topk: 51 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 52 | res.append(correct_k.mul_(100.0 / batch_size)) 53 | return res 54 | 55 | def validate(val_loader, model, criterion): 56 | batch_time = AverageMeter() 57 | losses = AverageMeter() 58 | top1 = AverageMeter() 59 | top5 = AverageMeter() 
60 | 61 | # switch to evaluate mode 62 | model.eval() 63 | model.cuda() 64 | 65 | end = time.time() 66 | for i, (input, target) in enumerate(val_loader): 67 | target = target.cuda() 68 | input = input.cuda() 69 | input_var = torch.autograd.Variable(input, volatile=True) 70 | target_var = torch.autograd.Variable(target, volatile=True) 71 | 72 | # compute output 73 | output = model(input_var) 74 | loss = criterion(output, target_var) 75 | 76 | # measure accuracy and record loss 77 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 78 | losses.update(loss.data[0], input.size(0)) 79 | top1.update(prec1[0], input.size(0)) 80 | top5.update(prec5[0], input.size(0)) 81 | 82 | # measure elapsed time 83 | batch_time.update(time.time() - end) 84 | end = time.time() 85 | 86 | if i % EXP_SETTINGS.print_freq == 0: 87 | print('Test: [{0}/{1}]\t' 88 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 89 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 90 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 91 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 92 | i, len(val_loader), batch_time=batch_time, loss=losses, 93 | top1=top1, top5=top5)) 94 | 95 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 96 | .format(top1=top1, top5=top5)) 97 | 98 | return top1.avg 99 | 100 | # weight_label = list(model.parameters())[-2].data.numpy() 101 | # norm_weight_label = np.linalg.norm(weight_label, axis=1) 102 | # rankings, errvar, coefficients, residuals_T = np.load(os.path.join(settings.OUTPUT_FOLDER, "decompose_pos.npy")) 103 | # new_weight = (weight_label / norm_weight_label[:, None] - residuals_T.T) * norm_weight_label[:, None] 104 | # new_weight = residuals_T.T * norm_weight_label[:, None] 105 | # model.fc.weight.data.copy_(torch.Tensor(new_weight)) 106 | criterion = nn.CrossEntropyLoss().cuda() 107 | valdir = os.path.join(EXP_SETTINGS.data_places365, 'val') 108 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 109 | std=[0.229, 0.224, 0.225]) 110 | val_loader = torch.utils.data.DataLoader( 111 | datasets.ImageFolder(valdir, transforms.Compose([ 112 | transforms.Resize(256), 113 | transforms.CenterCrop(224), 114 | transforms.ToTensor(), 115 | normalize, 116 | ])), 117 | batch_size=EXP_SETTINGS.batch_size, shuffle=False, 118 | num_workers=EXP_SETTINGS.workers, pin_memory=True) 119 | 120 | prec1 = validate(val_loader, model, criterion) 121 | print(prec1) 122 | -------------------------------------------------------------------------------- /util/feature_decoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn as nn, optim as optim 7 | from torch.autograd import Variable as V 8 | import torch.nn.functional as F 9 | from sklearn.metrics import average_precision_score 10 | import settings 11 | from loader.feature_loader import feature_loader, concept_loader, concept_loader_factory, ConceptDataset 12 | 13 | 14 | class FeatureClassifier(nn.Module): 15 | def __init__(self,): 16 | super(FeatureClassifier, self).__init__() 17 | self.epoch = 0 18 | 19 | def forward(self, input): 20 | raise NotImplementedError 21 | 22 | def load_snapshot(self, epoch, unbiased=True): 23 | self.epoch = epoch 24 | self.load_state_dict(torch.load(os.path.join(settings.OUTPUT_FOLDER, "snapshot", "%d.pth" % epoch))) 25 | if unbiased: 26 | w = self.fc.weight.data 27 | w_ub = w - w.mean(1)[:, None] 28 | self.fc.weight.data.copy_(w_ub) 29 | 30 | def save_snapshot(self, epoch): 31 | 
torch.save(self.state_dict(), os.path.join(settings.OUTPUT_FOLDER, "snapshot", "%d.pth" % epoch)) 32 | 33 | def val(self): 34 | raise NotImplementedError 35 | 36 | def train(self): 37 | raise NotImplementedError 38 | 39 | class SemanticFeatureClassifier(FeatureClassifier): 40 | def __init__(self, feat_len, concept_size): 41 | super(FeatureClassifier, self).__init__() 42 | self.fc = nn.Linear(feat_len, concept_size) 43 | 44 | def forward(self, input): 45 | return self.fc(input) 46 | 47 | def val(self, feature, layer, fo): 48 | filename = os.path.join(settings.OUTPUT_FOLDER, "%s-val-accuracy.npy" % layer) 49 | if os.path.exists(filename): 50 | label_count, label_top1_accuracy, label_top5_accuracy, epoch_error, top1, top5 = np.load(filename) 51 | else: 52 | loss = nn.CrossEntropyLoss() 53 | if settings.GPU: 54 | loss.cuda() 55 | feat_loader_train, feat_loader_test = feature_loader(feature, layer, fo.data, len(fo.data.label), split=True) 56 | epoch_error = 0 57 | label_count = torch.zeros(len(fo.data.label)) 58 | label_top1_accuracy = torch.zeros(len(fo.data.label)).fill_(-1e-10) 59 | label_top5_accuracy = torch.zeros(len(fo.data.label)).fill_(-1e-10) 60 | start_time = time.time() 61 | last_batch_time = start_time 62 | for i, (feat, label) in enumerate(feat_loader_test): 63 | if type(feat) == int: 64 | continue 65 | batch_time = time.time() 66 | rate = i * settings.FEAT_BATCH_SIZE / (batch_time - start_time + 1e-15) 67 | batch_rate = settings.FEAT_BATCH_SIZE / (batch_time - last_batch_time + 1e-15) 68 | last_batch_time = batch_time 69 | 70 | feat_var = V(feat, requires_grad=True) 71 | if settings.GPU: 72 | feat_var.cuda() 73 | out = self.forward(feat_var) 74 | err = loss(out, V(label)) 75 | epoch_error += 1 / (i + 1) * (err.data[0] - epoch_error) 76 | top5 = label[:, None] == torch.topk(out, 5, 1)[1].data 77 | for b_i in range(len(label)): 78 | label_count[label[b_i]] += 1 79 | label_top1_accuracy[label[b_i]] += 1 / label_count[label[b_i]] * (top5[b_i, 0] - label_top1_accuracy[label[b_i]]) 80 | label_top5_accuracy[label[b_i]] += 1 / label_count[label[b_i]] * (float(top5[b_i].sum()) - label_top5_accuracy[label[b_i]]) 81 | top1 = ((label_top1_accuracy >= 0).float() * label_top1_accuracy).sum() / (label_top1_accuracy >= 0).sum() 82 | top5 = ((label_top5_accuracy >= 0).float() * label_top5_accuracy).sum() / (label_top5_accuracy >= 0).sum() 83 | print("val epoch [%d/%d]: batch error %.4f, overall error %.4f, top <1> %.4f, top <5> %.4f, item per second %.4f, %.4f" % ( 84 | i + 1, len(feat_loader_test), err.data[0], epoch_error, top1, top5, batch_rate, rate)) 85 | if settings.GPU: 86 | label_count, label_top1_accuracy, label_top5_accuracy = label_count.cpu().numpy(), label_top1_accuracy.cpu().numpy(), label_top5_accuracy.cpu().numpy() 87 | else: 88 | label_count, label_top1_accuracy, label_top5_accuracy = label_count.numpy(), label_top1_accuracy.numpy(), label_top5_accuracy.numpy() 89 | np.save(filename, (label_count, label_top1_accuracy, label_top5_accuracy, epoch_error, top1, top5)) 90 | import matplotlib.pyplot as plt 91 | plt.figure(figsize=(24,4)) 92 | plt.plot(range(len(label_top1_accuracy[label_top1_accuracy >= 0])), label_top1_accuracy[label_top1_accuracy >= 0], label='top 1') 93 | plt.plot(range(len(label_top5_accuracy[label_top5_accuracy >= 0])), label_top5_accuracy[label_top5_accuracy >= 0], label='top 5') 94 | plt.legend(loc='upper right') 95 | plt.title("%s \n error:%.4f, top <1> %.4f, top <5> %.4f" % (settings.CNN_MODEL, epoch_error, top1, top5)) 96 | plt.tight_layout() 97 | 
plt.savefig(os.path.join(settings.OUTPUT_FOLDER, 'html', 'image', "accuracy_distribute.jpg")) 98 | 99 | def train(self, feature, layer, fo): 100 | optimizer = optim.SGD(self.parameters(), lr=0.02) 101 | # loss = nn.MSELoss() 102 | loss = nn.CrossEntropyLoss() 103 | if settings.GPU: 104 | loss.cuda() 105 | feat_loader_train, feat_loader_test = concept_loader(feature, layer, fo.data, len(fo.data.label)) 106 | for epoch in range(self.epoch+1, settings.EPOCHS+1): 107 | 108 | # training 109 | training_epoch_error = 0 110 | start_time = time.time() 111 | last_batch_time = start_time 112 | for i, (feat, label) in enumerate(feat_loader_train): 113 | if type(feat) == int: 114 | continue 115 | # feat = feat.view((settings.FEAT_BATCH_SIZE * settings.SEG_RESOLUTION ** 2, -1)) 116 | # label = label.view((settings.FEAT_BATCH_SIZE * settings.SEG_RESOLUTION ** 2)).long() 117 | batch_time = time.time() 118 | rate = i * settings.FEAT_BATCH_SIZE / (batch_time - start_time + 1e-15) 119 | batch_rate = settings.FEAT_BATCH_SIZE / (batch_time - last_batch_time + 1e-15) 120 | last_batch_time = batch_time 121 | 122 | feat_var = V(feat, requires_grad=True) 123 | if settings.GPU: 124 | feat_var.cuda() 125 | out = self.forward(feat_var) 126 | err = loss(out, V(label)) 127 | training_epoch_error += 1 / (i + 1) * (err.data[0] - training_epoch_error) 128 | print("training epoch [%d/%d][%d/%d]: batch error %.6f, overall error %.6f, item per second %.4f, %.4f" % ( 129 | epoch, settings.EPOCHS, i + 1, len(feat_loader_train), err.data[0], training_epoch_error, batch_rate, rate)) 130 | optimizer.zero_grad() 131 | err.backward() 132 | optimizer.step() 133 | 134 | # validation 135 | if feat_loader_test: 136 | val_epoch_error = 0 137 | for i, (feat, label) in enumerate(feat_loader_test): 138 | feat_var = V(feat, volatile=True) 139 | if settings.GPU: 140 | feat_var.cuda() 141 | out = self.forward(feat_var) 142 | err = loss(out, V(label)) 143 | val_epoch_error += 1 / (i + 1) * (err.data[0] - val_epoch_error) 144 | print("validation epoch [%d/%d][%d/%d]: batch error %.6f, overall error %.6f" % ( 145 | epoch, settings.EPOCHS, i + 1, len(feat_loader_test), err.data[0], val_epoch_error)) 146 | 147 | if epoch % settings.SNAPSHOT_FREQ == 0: 148 | self.save_snapshot(epoch) 149 | 150 | class IndexLinear(nn.Linear): 151 | def __init__(self, in_features, out_features, bias=True): 152 | super(IndexLinear, self).__init__(in_features, out_features, bias) 153 | 154 | def forward(self, input, id=None): 155 | if id is not None: 156 | return F.linear(input, self.weight[id:id+1, :], self.bias[id:id+1]) 157 | else: 158 | return F.linear(input, self.weight, self.bias) 159 | 160 | from torch.nn.modules.loss import _Loss 161 | class NegWLoss(_Loss): 162 | 163 | def __init__(self, size_average=True, alpha=0.01): 164 | super(NegWLoss, self).__init__(size_average) 165 | self.alpha = alpha 166 | 167 | def forward(self, weight): 168 | return self.alpha * weight.mean()#F.relu(-weight).sum() 169 | 170 | 171 | 172 | class SingleSigmoidFeatureClassifier(FeatureClassifier): 173 | def __init__(self, feature=None, layer=None, fo=None): 174 | super(SingleSigmoidFeatureClassifier, self).__init__() 175 | 176 | self.dataset = ConceptDataset(feature, layer, fo.data, len(fo.data.label), ) 177 | # self.feat_loader = feature_loader(feature, layer, fo.data, len(fo.data.label)) 178 | self.loader_factory = concept_loader_factory(feature, layer, fo.data, len(fo.data.label), concept_dataset=self.dataset) 179 | self.valid_concepts = self.dataset.concept_count.nonzero()[0] 180 | 
self.feat = feature 181 | self.layer_name = layer 182 | self.fo = fo 183 | self.concept_size = len(self.valid_concepts) 184 | self.display_epoch = 100 185 | if feature is None: 186 | self.fc = IndexLinear(1024, 660) 187 | else: 188 | self.fc = IndexLinear(feature.shape[1], self.concept_size) 189 | # self.fc.weight 190 | self.sig = nn.Sigmoid() 191 | 192 | self.loss_mse = nn.MSELoss() 193 | self.loss_weight = NegWLoss() 194 | self.optimizer = optim.SGD(self.parameters(), lr=1e-2) 195 | if settings.GPU: 196 | self.loss_mse.cuda() 197 | self.loss_weight.cuda() 198 | 199 | def forward(self, input, id=None): 200 | return self.sig(self.fc(input, id)) 201 | 202 | def run(self): 203 | history = np.memmap(os.path.join(settings.OUTPUT_FOLDER, "mAP_table.mmap"), dtype=float, mode='w+', shape=(self.concept_size, settings.EPOCHS)) 204 | neg_test_loaders = self.loader_factory.negative_test_concept_loader(sample_ratio=20, verbose=False) 205 | neg_scores = None 206 | for epoch in range(self.epoch, settings.EPOCHS): 207 | concept_train_loaders, concept_val_loaders = self.loader_factory.negative_mining_loader(neg_scores=neg_scores) 208 | neg_scores = [None] * self.concept_size 209 | for c_i in range(self.concept_size): 210 | train_loader = concept_train_loaders[c_i] 211 | val_loader = concept_val_loaders[c_i] 212 | test_loader = neg_test_loaders[c_i] 213 | self.train(train_loader, c_i, epoch) 214 | neg_scores[c_i] = self.test(test_loader, c_i, epoch) 215 | history[c_i, epoch] = self.val(val_loader, c_i, epoch) 216 | self.save_snapshot(epoch) 217 | np.save(os.path.join(settings.OUTPUT_FOLDER, "snapshot", "neg_scores.npy"), neg_scores) 218 | 219 | def run_naive(self): 220 | # history = np.memmap(os.path.join(settings.OUTPUT_FOLDER, "mAP_table.mmap"), dtype=float, mode='w+', shape=(self.concept_size, settings.EPOCHS)) 221 | for epoch in range(self.epoch, settings.EPOCHS): 222 | concept_train_loaders, concept_val_loaders = self.loader_factory.random_concept_loader() 223 | for c_i in range(self.concept_size): 224 | train_loader = concept_train_loaders[c_i] 225 | # val_loader = concept_val_loaders[c_i] 226 | self.train(train_loader, c_i, epoch) 227 | # history[c_i, epoch] = self.val(val_loader, c_i, epoch) 228 | self.save_snapshot(epoch) 229 | 230 | def run_fix_eval(self): 231 | concept_val_loaders = self.loader_factory.fixed_val_loader() 232 | aps = np.zeros(self.concept_size) 233 | for c_i in range(self.concept_size): 234 | val_loader = concept_val_loaders[c_i] 235 | aps[c_i] = self.val(val_loader, c_i, settings.EPOCHS) 236 | np.save(os.path.join(settings.OUTPUT_FOLDER, "mAP_val.npy"), aps) 237 | print("mAP {:4.4f}%".format(aps.mean())) 238 | 239 | def val(self, val_loader, c_i, epoch): 240 | val_epoch_error = 0 241 | val_label = np.zeros(len(val_loader.dataset)) 242 | val_score = np.zeros(len(val_loader.dataset)) 243 | for i, (ind, feat, label) in enumerate(val_loader): 244 | start_ind = i * settings.BATCH_SIZE 245 | end_ind = i * settings.BATCH_SIZE + len(feat) 246 | feat_var = V(feat, requires_grad=True) 247 | if settings.GPU: 248 | feat_var.cuda() 249 | out = self.forward(feat_var, c_i) 250 | err = self.loss_mse(out, V(label)) 251 | val_epoch_error += 1 / (i + 1) * (err.data[0] - val_epoch_error) 252 | val_label[start_ind:end_ind] = label.numpy() 253 | val_score[start_ind:end_ind] = out[:,0].data.numpy() 254 | if i % self.display_epoch == 0: 255 | print("val epoch [%d/%d][%d/%d][%d/%d]: batch error %.6f, overall error %.6f" % ( 256 | c_i + 1, len(self.valid_concepts), 257 | epoch, settings.EPOCHS, i + 1, 
len(val_loader), err.data[0], val_epoch_error, 258 | )) 259 | AP = average_precision_score(val_label, val_score) 260 | print("Concept {:d}, epoch{:d}, AP: {:4.2f}%".format(c_i, epoch, AP * 100)) 261 | return AP 262 | 263 | # def run_val_resize(self, size): 264 | # 265 | 266 | def train(self, train_loader, c_i, epoch): 267 | training_epoch_error = 0 268 | start_time = time.time() 269 | last_batch_time = start_time 270 | for i, (ind, feat, label) in enumerate(train_loader): 271 | batch_time = time.time() 272 | rate = i * settings.FEAT_BATCH_SIZE / (batch_time - start_time + 1e-15) 273 | batch_rate = settings.FEAT_BATCH_SIZE / (batch_time - last_batch_time + 1e-15) 274 | last_batch_time = batch_time 275 | 276 | feat_var = V(feat, requires_grad=True) 277 | if settings.GPU: 278 | feat_var.cuda() 279 | out = self.forward(feat_var, c_i) 280 | err_mse = self.loss_mse(out, V(label)) 281 | # err_weight = self.loss_weight(self.fc.weight[c_i]) 282 | err_weight = V(torch.FloatTensor([0])) 283 | err = err_mse #+ err_weight 284 | training_epoch_error += 1 / (i + 1) * (err_mse.data[0] - training_epoch_error) 285 | if i % self.display_epoch == 0: 286 | print("training epoch [%d/%d][%d/%d][%d/%d]: mse error %.5f, weight loss %.5f overall mse error %.5f, item per second %.4f, %.4f" % ( 287 | c_i + 1, len(self.valid_concepts), 288 | epoch, settings.EPOCHS, i + 1, len(train_loader), err_mse.data[0], err_weight.data[0], training_epoch_error, 289 | batch_rate, rate)) 290 | self.optimizer.zero_grad() 291 | err.backward() 292 | self.optimizer.step() 293 | # self.fc.weight[c_i].data.copy_(F.relu(self.fc.weight[c_i]).data) 294 | 295 | def test(self, test_loader, c_i, epoch): 296 | training_epoch_error = 0 297 | start_time = time.time() 298 | last_batch_time = start_time 299 | prediction_scores = np.zeros(len(test_loader.dataset)) 300 | for i, (ind, feat, label) in enumerate(test_loader): 301 | batch_time = time.time() 302 | rate = i * settings.FEAT_BATCH_SIZE / (batch_time - start_time + 1e-15) 303 | batch_rate = settings.FEAT_BATCH_SIZE / (batch_time - last_batch_time + 1e-15) 304 | last_batch_time = batch_time 305 | 306 | start_ind = i * settings.BATCH_SIZE 307 | end_ind = i * settings.BATCH_SIZE + len(feat) 308 | 309 | feat_var = V(feat, requires_grad=True) 310 | if settings.GPU: 311 | feat_var.cuda() 312 | out = self.forward(feat_var, c_i) 313 | err = self.loss_mse(out, V(label)) 314 | prediction_scores[start_ind:end_ind] = out.data.squeeze().numpy() 315 | training_epoch_error += 1 / (i + 1) * (err.data[0] - training_epoch_error) 316 | if i % self.display_epoch == 0: 317 | print( 318 | "testing epoch [%d/%d][%d/%d][%d/%d]: batch error %.6f, overall error %.6f, item per second %.4f, %.4f" % ( 319 | c_i + 1, len(self.valid_concepts), 320 | epoch, settings.EPOCHS, i + 1, len(test_loader), err.data[0], training_epoch_error, 321 | batch_rate, rate)) 322 | return prediction_scores 323 | -------------------------------------------------------------------------------- /util/image_operation.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | from PIL import ImageDraw, ImageFont 3 | import skimage.measure 4 | import numpy as np 5 | import settings 6 | import cv2 7 | from scipy.misc import imresize 8 | from PIL import Image 9 | 10 | def imconcat(imgs, w, h, margin=0): 11 | w = sum([img.width for img in imgs]) 12 | ret = PIL.Image.new("RGB", (w + (len(imgs) - 1) * margin, imgs[0].height), color=(255,255,255)) 13 | w_pre = 0 14 | for i, img in enumerate(imgs): 15 | 
ret.paste(img, (w_pre+margin*int(bool(i)), 0)) 16 | w_pre += img.width+margin*int(bool(i)) 17 | # ret = PIL.Image.new("RGB", (len(imgs) * w + (len(imgs) - 1) * margin,h), color=(255,255,255)) 18 | # for i, img in enumerate(imgs): 19 | # ret.paste(img, ((w+margin)*i,0)) 20 | return ret 21 | 22 | 23 | def imstack(imgs): 24 | h = sum([img.height for img in imgs]) 25 | ret = PIL.Image.new("RGB", (imgs[0].width, h)) 26 | h_pre = 0 27 | for i, img in enumerate(imgs): 28 | ret.paste(img, (0, h_pre)) 29 | h_pre += img.height 30 | return ret 31 | 32 | 33 | 34 | def vis_cam_mask(cam_mat, org_img, vis_size, font_text=None): 35 | cam_mask = 255 * imresize(cam_mat, (settings.IMG_SIZE, settings.IMG_SIZE), mode="F") 36 | cam_mask = cv2.applyColorMap(np.uint8(cam_mask), cv2.COLORMAP_JET)[:, :, ::-1] 37 | vis_cam = cam_mask * 0.5 + org_img * 0.5 38 | vis_cam = Image.fromarray(vis_cam.astype(np.uint8)) 39 | vis_cam = vis_cam.resize((vis_size, vis_size), resample=Image.BILINEAR) 40 | 41 | # if font_text is not None: 42 | # font = ImageFont.truetype(settings.FONT_PATH, settings.FONT_SIZE+4) 43 | # draw = ImageDraw.Draw(vis_cam) 44 | # fw, fh = draw.textsize(font_text) 45 | # coord = np.array(np.unravel_index(cam_mat.argmax(),cam_mat.shape)) * vis_size / cam_mat.shape[0] 46 | # draw.text((coord[1], coord[0]), font_text, font=font, fill=(240, 240, 240, 255)) 47 | 48 | return vis_cam 49 | 50 | def label_seg(img, vis_size, labels, concept_inds, cam=None): 51 | h, w = concept_inds.shape 52 | grid_size = vis_size / settings.SEG_RESOLUTION 53 | draw = ImageDraw.Draw(img) 54 | font = ImageFont.truetype(settings.FONT_PATH, settings.FONT_SIZE) 55 | X, Y = np.meshgrid(np.arange(7) * grid_size, np.arange(7) * grid_size) 56 | 57 | label_img = skimage.measure.label(concept_inds, connectivity=1) 58 | cpt_groups = skimage.measure.regionprops(label_img) 59 | for cpt_group in cpt_groups: 60 | # y_start, x_start, y_end, x_end = cpt_group.bbox 61 | label = labels[concept_inds[tuple(cpt_group.coords[0])]]['name'] 62 | fw, fh = draw.textsize(label) 63 | coord = np.array(cpt_group.centroid)[::-1] * grid_size 64 | draw.text((coord[0] + (grid_size - fw) / 2, coord[1] + (grid_size - fh) / 2 ), label, font=font, fill=(0,0,0, 255)) 65 | 66 | contours = skimage.measure.find_contours(cam, cam.max() * settings.CAM_THRESHOLD) 67 | for contour in contours: 68 | draw.line(list((contour[:,::-1] * vis_size / cam.shape[0]).ravel()), fill=(255, 200, 0)) 69 | 70 | 71 | def headline(captions, vis_size, height, width, margin=3): 72 | vis_headline = Image.fromarray(np.full((height, width, 3), 255, dtype=np.int8), mode="RGB") 73 | draw = ImageDraw.Draw(vis_headline) 74 | font = ImageFont.truetype(settings.FONT_PATH, settings.FONT_SIZE) 75 | for i in range(len(captions)): 76 | label = captions[i] 77 | fw, fh = draw.textsize(label) 78 | coord = ((vis_size+margin) * i, 0) 79 | draw.text((coord[0] + (vis_size - fw) / 2, coord[1] + (height - fh) / 2), label, font=font, 80 | fill=(0, 0, 0, 255)) 81 | return vis_headline 82 | 83 | def headline2(captions, vis_size, height, width, margin=3): 84 | vis_headline = Image.fromarray(np.full((height, width, 3), 255, dtype=np.int8), mode="RGB") 85 | draw = ImageDraw.Draw(vis_headline) 86 | font = ImageFont.truetype(settings.FONT_PATH, settings.FONT_SIZE) 87 | for i in range(len(captions)): 88 | label = captions[i] 89 | fw, fh = draw.textsize(label) 90 | if i == 0: 91 | draw.text(((vis_size * 2 - fw*1.85) / 2, (height - fh*1.85) / 2), label, font=font, fill=(0, 0, 0, 255)) 92 | else: 93 | coord = (vis_size *7 // 3 + 
(vis_size + margin) * (i - 1), 0) 94 | draw.text((coord[0] + (vis_size - fw*1.85) / 2, coord[1] + (height - fh*1.85) / 2), label, font=font, fill=(0, 0, 0, 255)) 95 | 96 | return vis_headline 97 | 98 | def big_margin(vis_size): 99 | w = vis_size // 3 100 | h = vis_size 101 | canvas = Image.fromarray(np.full((h, w, 3), 255, dtype=np.int8), mode="RGB") 102 | draw = ImageDraw.Draw(canvas) 103 | font = ImageFont.truetype(settings.FONT_PATH, settings.FONT_SIZE) 104 | label = "=" 105 | draw.text(((w-settings.FONT_SIZE*0.5) / 2, (vis_size - settings.FONT_SIZE * 1.8) / 2), label, font=font, fill=(0, 0, 0, 255)) 106 | return canvas -------------------------------------------------------------------------------- /util/places365_categories.py: -------------------------------------------------------------------------------- 1 | places365_categories = { 2 | 0: 'airfield', 3 | 1: 'airplane_cabin', 4 | 2: 'airport_terminal', 5 | 3: 'alcove', 6 | 4: 'alley', 7 | 5: 'amphitheater', 8 | 6: 'amusement_arcade', 9 | 7: 'amusement_park', 10 | 8: 'apartment_building/outdoor', 11 | 9: 'aquarium', 12 | 10: 'aqueduct', 13 | 11: 'arcade', 14 | 12: 'arch', 15 | 13: 'archaelogical_excavation', 16 | 14: 'archive', 17 | 15: 'arena/hockey', 18 | 16: 'arena/performance', 19 | 17: 'arena/rodeo', 20 | 18: 'army_base', 21 | 19: 'art_gallery', 22 | 20: 'art_school', 23 | 21: 'art_studio', 24 | 22: 'artists_loft', 25 | 23: 'assembly_line', 26 | 24: 'athletic_field/outdoor', 27 | 25: 'atrium/public', 28 | 26: 'attic', 29 | 27: 'auditorium', 30 | 28: 'auto_factory', 31 | 29: 'auto_showroom', 32 | 30: 'badlands', 33 | 31: 'bakery/shop', 34 | 32: 'balcony/exterior', 35 | 33: 'balcony/interior', 36 | 34: 'ball_pit', 37 | 35: 'ballroom', 38 | 36: 'bamboo_forest', 39 | 37: 'bank_vault', 40 | 38: 'banquet_hall', 41 | 39: 'bar', 42 | 40: 'barn', 43 | 41: 'barndoor', 44 | 42: 'baseball_field', 45 | 43: 'basement', 46 | 44: 'basketball_court/indoor', 47 | 45: 'bathroom', 48 | 46: 'bazaar/indoor', 49 | 47: 'bazaar/outdoor', 50 | 48: 'beach', 51 | 49: 'beach_house', 52 | 50: 'beauty_salon', 53 | 51: 'bedchamber', 54 | 52: 'bedroom', 55 | 53: 'beer_garden', 56 | 54: 'beer_hall', 57 | 55: 'berth', 58 | 56: 'biology_laboratory', 59 | 57: 'boardwalk', 60 | 58: 'boat_deck', 61 | 59: 'boathouse', 62 | 60: 'bookstore', 63 | 61: 'booth/indoor', 64 | 62: 'botanical_garden', 65 | 63: 'bow_window/indoor', 66 | 64: 'bowling_alley', 67 | 65: 'boxing_ring', 68 | 66: 'bridge', 69 | 67: 'building_facade', 70 | 68: 'bullring', 71 | 69: 'burial_chamber', 72 | 70: 'bus_interior', 73 | 71: 'bus_station/indoor', 74 | 72: 'butchers_shop', 75 | 73: 'butte', 76 | 74: 'cabin/outdoor', 77 | 75: 'cafeteria', 78 | 76: 'campsite', 79 | 77: 'campus', 80 | 78: 'canal/natural', 81 | 79: 'canal/urban', 82 | 80: 'candy_store', 83 | 81: 'canyon', 84 | 82: 'car_interior', 85 | 83: 'carrousel', 86 | 84: 'castle', 87 | 85: 'catacomb', 88 | 86: 'cemetery', 89 | 87: 'chalet', 90 | 88: 'chemistry_lab', 91 | 89: 'childs_room', 92 | 90: 'church/indoor', 93 | 91: 'church/outdoor', 94 | 92: 'classroom', 95 | 93: 'clean_room', 96 | 94: 'cliff', 97 | 95: 'closet', 98 | 96: 'clothing_store', 99 | 97: 'coast', 100 | 98: 'cockpit', 101 | 99: 'coffee_shop', 102 | 100: 'computer_room', 103 | 101: 'conference_center', 104 | 102: 'conference_room', 105 | 103: 'construction_site', 106 | 104: 'corn_field', 107 | 105: 'corral', 108 | 106: 'corridor', 109 | 107: 'cottage', 110 | 108: 'courthouse', 111 | 109: 'courtyard', 112 | 110: 'creek', 113 | 111: 'crevasse', 114 | 112: 'crosswalk', 115 | 
113: 'dam', 116 | 114: 'delicatessen', 117 | 115: 'department_store', 118 | 116: 'desert/sand', 119 | 117: 'desert/vegetation', 120 | 118: 'desert_road', 121 | 119: 'diner/outdoor', 122 | 120: 'dining_hall', 123 | 121: 'dining_room', 124 | 122: 'discotheque', 125 | 123: 'doorway/outdoor', 126 | 124: 'dorm_room', 127 | 125: 'downtown', 128 | 126: 'dressing_room', 129 | 127: 'driveway', 130 | 128: 'drugstore', 131 | 129: 'elevator/door', 132 | 130: 'elevator_lobby', 133 | 131: 'elevator_shaft', 134 | 132: 'embassy', 135 | 133: 'engine_room', 136 | 134: 'entrance_hall', 137 | 135: 'escalator/indoor', 138 | 136: 'excavation', 139 | 137: 'fabric_store', 140 | 138: 'farm', 141 | 139: 'fastfood_restaurant', 142 | 140: 'field/cultivated', 143 | 141: 'field/wild', 144 | 142: 'field_road', 145 | 143: 'fire_escape', 146 | 144: 'fire_station', 147 | 145: 'fishpond', 148 | 146: 'flea_market/indoor', 149 | 147: 'florist_shop/indoor', 150 | 148: 'food_court', 151 | 149: 'football_field', 152 | 150: 'forest/broadleaf', 153 | 151: 'forest_path', 154 | 152: 'forest_road', 155 | 153: 'formal_garden', 156 | 154: 'fountain', 157 | 155: 'galley', 158 | 156: 'garage/indoor', 159 | 157: 'garage/outdoor', 160 | 158: 'gas_station', 161 | 159: 'gazebo/exterior', 162 | 160: 'general_store/indoor', 163 | 161: 'general_store/outdoor', 164 | 162: 'gift_shop', 165 | 163: 'glacier', 166 | 164: 'golf_course', 167 | 165: 'greenhouse/indoor', 168 | 166: 'greenhouse/outdoor', 169 | 167: 'grotto', 170 | 168: 'gymnasium/indoor', 171 | 169: 'hangar/indoor', 172 | 170: 'hangar/outdoor', 173 | 171: 'harbor', 174 | 172: 'hardware_store', 175 | 173: 'hayfield', 176 | 174: 'heliport', 177 | 175: 'highway', 178 | 176: 'home_office', 179 | 177: 'home_theater', 180 | 178: 'hospital', 181 | 179: 'hospital_room', 182 | 180: 'hot_spring', 183 | 181: 'hotel/outdoor', 184 | 182: 'hotel_room', 185 | 183: 'house', 186 | 184: 'hunting_lodge/outdoor', 187 | 185: 'ice_cream_parlor', 188 | 186: 'ice_floe', 189 | 187: 'ice_shelf', 190 | 188: 'ice_skating_rink/indoor', 191 | 189: 'ice_skating_rink/outdoor', 192 | 190: 'iceberg', 193 | 191: 'igloo', 194 | 192: 'industrial_area', 195 | 193: 'inn/outdoor', 196 | 194: 'islet', 197 | 195: 'jacuzzi/indoor', 198 | 196: 'jail_cell', 199 | 197: 'japanese_garden', 200 | 198: 'jewelry_shop', 201 | 199: 'junkyard', 202 | 200: 'kasbah', 203 | 201: 'kennel/outdoor', 204 | 202: 'kindergarden_classroom', 205 | 203: 'kitchen', 206 | 204: 'lagoon', 207 | 205: 'lake/natural', 208 | 206: 'landfill', 209 | 207: 'landing_deck', 210 | 208: 'laundromat', 211 | 209: 'lawn', 212 | 210: 'lecture_room', 213 | 211: 'legislative_chamber', 214 | 212: 'library/indoor', 215 | 213: 'library/outdoor', 216 | 214: 'lighthouse', 217 | 215: 'living_room', 218 | 216: 'loading_dock', 219 | 217: 'lobby', 220 | 218: 'lock_chamber', 221 | 219: 'locker_room', 222 | 220: 'mansion', 223 | 221: 'manufactured_home', 224 | 222: 'market/indoor', 225 | 223: 'market/outdoor', 226 | 224: 'marsh', 227 | 225: 'martial_arts_gym', 228 | 226: 'mausoleum', 229 | 227: 'medina', 230 | 228: 'mezzanine', 231 | 229: 'moat/water', 232 | 230: 'mosque/outdoor', 233 | 231: 'motel', 234 | 232: 'mountain', 235 | 233: 'mountain_path', 236 | 234: 'mountain_snowy', 237 | 235: 'movie_theater/indoor', 238 | 236: 'museum/indoor', 239 | 237: 'museum/outdoor', 240 | 238: 'music_studio', 241 | 239: 'natural_history_museum', 242 | 240: 'nursery', 243 | 241: 'nursing_home', 244 | 242: 'oast_house', 245 | 243: 'ocean', 246 | 244: 'office', 247 | 245: 'office_building', 248 | 246: 
'office_cubicles', 249 | 247: 'oilrig', 250 | 248: 'operating_room', 251 | 249: 'orchard', 252 | 250: 'orchestra_pit', 253 | 251: 'pagoda', 254 | 252: 'palace', 255 | 253: 'pantry', 256 | 254: 'park', 257 | 255: 'parking_garage/indoor', 258 | 256: 'parking_garage/outdoor', 259 | 257: 'parking_lot', 260 | 258: 'pasture', 261 | 259: 'patio', 262 | 260: 'pavilion', 263 | 261: 'pet_shop', 264 | 262: 'pharmacy', 265 | 263: 'phone_booth', 266 | 264: 'physics_laboratory', 267 | 265: 'picnic_area', 268 | 266: 'pier', 269 | 267: 'pizzeria', 270 | 268: 'playground', 271 | 269: 'playroom', 272 | 270: 'plaza', 273 | 271: 'pond', 274 | 272: 'porch', 275 | 273: 'promenade', 276 | 274: 'pub/indoor', 277 | 275: 'racecourse', 278 | 276: 'raceway', 279 | 277: 'raft', 280 | 278: 'railroad_track', 281 | 279: 'rainforest', 282 | 280: 'reception', 283 | 281: 'recreation_room', 284 | 282: 'repair_shop', 285 | 283: 'residential_neighborhood', 286 | 284: 'restaurant', 287 | 285: 'restaurant_kitchen', 288 | 286: 'restaurant_patio', 289 | 287: 'rice_paddy', 290 | 288: 'river', 291 | 289: 'rock_arch', 292 | 290: 'roof_garden', 293 | 291: 'rope_bridge', 294 | 292: 'ruin', 295 | 293: 'runway', 296 | 294: 'sandbox', 297 | 295: 'sauna', 298 | 296: 'schoolhouse', 299 | 297: 'science_museum', 300 | 298: 'server_room', 301 | 299: 'shed', 302 | 300: 'shoe_shop', 303 | 301: 'shopfront', 304 | 302: 'shopping_mall/indoor', 305 | 303: 'shower', 306 | 304: 'ski_resort', 307 | 305: 'ski_slope', 308 | 306: 'sky', 309 | 307: 'skyscraper', 310 | 308: 'slum', 311 | 309: 'snowfield', 312 | 310: 'soccer_field', 313 | 311: 'stable', 314 | 312: 'stadium/baseball', 315 | 313: 'stadium/football', 316 | 314: 'stadium/soccer', 317 | 315: 'stage/indoor', 318 | 316: 'stage/outdoor', 319 | 317: 'staircase', 320 | 318: 'storage_room', 321 | 319: 'street', 322 | 320: 'subway_station/platform', 323 | 321: 'supermarket', 324 | 322: 'sushi_bar', 325 | 323: 'swamp', 326 | 324: 'swimming_hole', 327 | 325: 'swimming_pool/indoor', 328 | 326: 'swimming_pool/outdoor', 329 | 327: 'synagogue/outdoor', 330 | 328: 'television_room', 331 | 329: 'television_studio', 332 | 330: 'temple/asia', 333 | 331: 'throne_room', 334 | 332: 'ticket_booth', 335 | 333: 'topiary_garden', 336 | 334: 'tower', 337 | 335: 'toyshop', 338 | 336: 'train_interior', 339 | 337: 'train_station/platform', 340 | 338: 'tree_farm', 341 | 339: 'tree_house', 342 | 340: 'trench', 343 | 341: 'tundra', 344 | 342: 'underwater/ocean_deep', 345 | 343: 'utility_room', 346 | 344: 'valley', 347 | 345: 'vegetable_garden', 348 | 346: 'veterinarians_office', 349 | 347: 'viaduct', 350 | 348: 'village', 351 | 349: 'vineyard', 352 | 350: 'volcano', 353 | 351: 'volleyball_court/outdoor', 354 | 352: 'waiting_room', 355 | 353: 'water_park', 356 | 354: 'water_tower', 357 | 355: 'waterfall', 358 | 356: 'watering_hole', 359 | 357: 'wave', 360 | 358: 'wet_bar', 361 | 359: 'wheat_field', 362 | 360: 'wind_farm', 363 | 361: 'windmill', 364 | 362: 'yard', 365 | 363: 'youth_hostel', 366 | 364: 'zen_garden' 367 | } 368 | -------------------------------------------------------------------------------- /util/upsample.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage.filters import gaussian_filter 2 | from scipy.interpolate import RectBivariateSpline 3 | from scipy.ndimage.interpolation import zoom 4 | import numpy 5 | 6 | def upsampleL(fieldmap, activation_data, reduction=1, shape=None, 7 | scaleshape=None, out=None): 8 | ''' 9 | Applies a bilinear upsampling. 
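    `fieldmap` is an (offset, size, step) triple mapping activation
    coordinates back to input-image coordinates; see compose_fieldmap
    below for the exact convention.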
10 | ''' 11 | offset, size, step = fieldmap 12 | input_count = activation_data.shape[0] 13 | if len(activation_data.shape) == 2: 14 | ay, ax = centered_arange(fieldmap, activation_data.shape, reduction) 15 | if shape is None: 16 | shape = upsampled_shape( 17 | fieldmap, activation_data.shape, reduction) 18 | else: 19 | ay, ax = centered_arange(fieldmap, activation_data.shape[1:], reduction) 20 | if shape is None: 21 | shape = upsampled_shape( 22 | fieldmap, activation_data.shape[1:], reduction) 23 | if scaleshape is not None: 24 | iy, ix = full_arange(scaleshape) 25 | # TODO: consider treaing each point as a center of a pixel 26 | iy *= shape[0] / scaleshape[0] 27 | ix *= shape[1] / scaleshape[1] 28 | else: 29 | iy, ix = full_arange(shape) 30 | if out is None: 31 | out = numpy.empty((input_count, len(iy), len(ix)), 32 | dtype=activation_data.dtype) 33 | if len(activation_data.shape) == 2: 34 | f = RectBivariateSpline(ay, ax, activation_data, kx=1, ky=1) 35 | return f(iy, ix, grid=True) 36 | else: 37 | for z in range(input_count): 38 | f = RectBivariateSpline(ay, ax, activation_data[z], kx=1, ky=1) 39 | out[z] = f(iy, ix, grid=True) 40 | return out 41 | 42 | def upsampleC(fieldmap, activation_data, shape=None, out=None): 43 | ''' 44 | Applies a bicubic upsampling. 45 | ''' 46 | offset, size, step = fieldmap 47 | input_count = activation_data.shape[0] 48 | ay, ax = centered_arange(fieldmap, activation_data.shape[1:]) 49 | if shape is None: 50 | shape = upsampled_shape(fieldmap, activation_data.shape[1:]) 51 | iy, ix = full_arange(shape) 52 | if out is None: 53 | out = numpy.empty((input_count,) + shape, 54 | dtype=activation_data.dtype) 55 | for z in range(input_count): 56 | f = RectBivariateSpline(ay, ax, activation_data[z], kx=3, ky=3) 57 | out[z] = f(iy, ix, grid=True) 58 | return out 59 | 60 | def upsampleG(fieldmap, activation_data, shape=None): 61 | ''' 62 | Upsampling utility functions 63 | ''' 64 | offset, size, step = fieldmap 65 | input_count = activation_data.shape[0] 66 | if shape is None: 67 | shape = upsampled_shape(fieldmap, activation_data.shape[1:]) 68 | activations = numpy.zeros((input_count,) + shape) 69 | activations[(slice(None),) + 70 | centered_slice(fieldmap, activation_data.shape[1:])] = ( 71 | activation_data * numpy.prod(step)) 72 | blurred = gaussian_filter( 73 | activations, 74 | sigma=(0, ) + tuple(t // 1.414 for o, s, t in zip(*fieldmap)), 75 | mode='constant') 76 | return blurred 77 | 78 | def topo_sort(layers): 79 | # First, build a links-from and also a links-to graph 80 | links_from = {} 81 | links_to = {} 82 | for layer in layers: 83 | for bot in layer.bottom: 84 | if bot not in links_from: 85 | links_from[bot] = [] 86 | links_from[bot].append(layer) 87 | for top in layer.top: 88 | if top not in links_to: 89 | links_to[top] = [] 90 | links_to[top].append(layer) 91 | # Now do a DFS to figure out the ordering (using links-from) 92 | visited = set() 93 | ordering = [] 94 | stack = [] 95 | for seed in links_from: 96 | if seed not in visited: 97 | stack.append((seed, True)) 98 | stack.append((seed, False)) 99 | visited.add(seed) 100 | while stack: 101 | (blob, completed) = stack.pop() 102 | if completed: 103 | ordering.append(blob) 104 | elif blob in links_from: 105 | for layer in links_from[blob]: 106 | for t in layer.top: 107 | if t not in visited: 108 | stack.append((t, True)) 109 | stack.append((t, False)) 110 | visited.add(t) 111 | # Return a result in front-to-back order, with incoming links for each 112 | return list((blob, links_to[blob] if blob in 
115 | def composed_fieldmap(layers, end):
116 |     ts = topo_sort(layers)
117 |     fm_record = {}
118 |     for blob, layers in ts:
119 |         # Compute fieldmaps on all the edges that go to this blob.
120 |         all_fms = [
121 |             (compose_fieldmap(fm_record[bot][0], layer_fieldmap(layer)),
122 |              fm_record[bot][1] + [(bot, layer)])
123 |             for layer in layers for bot in layer.bottom if bot != blob]
124 |         # And take the max fieldmap.
125 |         fm_record[blob] = max_fieldmap(all_fms)
126 |         if blob == end:
127 |             return fm_record[blob]
128 | 
129 | def max_fieldmap(maps):
130 |     biggest, bp = None, None
131 |     for fm, path in maps:
132 |         if biggest is None:
133 |             biggest, bp = fm, path
134 |         elif fm[1][0] > biggest[1][0]:
135 |             biggest, bp = fm, path
136 |     # When there is no biggest, for example when maps is the empty list,
137 |     # use the trivial identity fieldmap with no path.
138 |     if biggest is None:
139 |         return ((0, 0), (1, 1), (1, 1)), []
140 |     return biggest, bp
141 | 
142 | def shortest_layer_path(start, end, layers):
143 |     # First, build a blob-to-outgoing-layer graph
144 |     links_from = {}
145 |     for layer in layers:
146 |         for bot in layer.bottom:
147 |             if bot not in links_from:
148 |                 links_from[bot] = []
149 |             links_from[bot].append(layer)
150 |     # Then do a BFS on the graph to find the shortest path to 'end'
151 |     queue = [(s, []) for s in start]
152 |     visited = set(start)
153 |     while queue:
154 |         (blob, path) = queue.pop(0)
155 |         for layer in links_from.get(blob, []):
156 |             for t in layer.top:
157 |                 if t == end:
158 |                     return path + [layer]
159 |                 if t not in visited:
160 |                     queue.append((t, path + [layer]))
161 |                     visited.add(t)
162 |     return None
163 | 
164 | def upsampled_shape(fieldmap, shape, reduction=1):
165 |     # Given the shape of a layer's activation and a fieldmap describing
166 |     # the transformation to original image space, returns the shape of
167 |     # the corresponding input image.
168 |     return tuple(((w - 1) * t + s + 2 * o) // reduction
169 |                  for (o, s, t), w in zip(zip(*fieldmap), shape))
170 | 
171 | def make_mask_set(image_shape, fieldmap, activation_data,
172 |                   output=None, sigma=0.1, threshold=0.5, percentile=None):
173 |     """Creates a set of receptive field masks with uniform thresholds
174 |     over a range of inputs.
175 |     """
176 |     offset, shape, step = fieldmap
177 |     input_count = activation_data.shape[0]
178 |     activations = numpy.zeros((input_count,) + image_shape)
179 |     activations[(slice(None),) +
180 |                 centered_slice(fieldmap, activation_data.shape[1:])] = (
181 |                     activation_data)
182 |     blurred = gaussian_filter(
183 |         activations,
184 |         sigma=(0, ) + tuple(s * sigma for s in shape),
185 |         mode='constant')
186 |     if percentile is not None:
187 |         limit = numpy.percentile(blurred, percentile)
188 |         return blurred > limit
189 |     else:
190 |         maximum = blurred.ravel().max()
191 |         return (blurred > maximum * threshold)
192 | 
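# --- Illustrative check (not part of the original file) ---
# A hypothetical conv stack with fieldmap offset=(-2, -2), size=(11, 11),
# step=(4, 4): a 13x13 activation grid maps back to (13-1)*4 + 11 - 4 = 55
# pixels per side, so upsampled_shape recovers the input resolution.
fieldmap = ((-2, -2), (11, 11), (4, 4))
print(upsampled_shape(fieldmap, (13, 13)))  # -> (55, 55)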
193 | def safezoom(array, ratio, output=None, order=0):
194 |     '''Like scipy.ndimage.zoom, but does not crash when the first dimension
195 |     of the array is of size 1, as happens often with segmentations'''
196 |     dtype = array.dtype
197 |     if array.dtype == numpy.float16:
198 |         array = array.astype(numpy.float32)
199 |     if array.shape[0] == 1:
200 |         if output is not None:
201 |             output = output[0,...]
202 |         result = zoom(array[0,...], ratio[1:],
203 |                       output=output, order=order)
204 |         if output is None:
205 |             output = result[numpy.newaxis]
206 |     else:
207 |         result = zoom(array, ratio, output=output, order=order)
208 |         if output is None:
209 |             output = result
210 |     return output.astype(dtype)
211 | 
212 | def receptive_field(location, fieldmap):
213 |     """Computes the receptive field of a specific location.
214 | 
215 |     Parameters
216 |     ----------
217 |     location: tuple
218 |         The x-y position of the unit being queried.
219 |     fieldmap:
220 |         The (offset, size, step) tuple fieldmap representing the
221 |         receptive field map for the layer being queried.
222 |     """
223 |     return compose_fieldmap(fieldmap, (location, (1, 1), (1, 1)))[:2]
224 | 
225 | 
226 | def proto_getattr(p, a, d):
227 |     hf = True
228 |     # Try using HasField to detect the presence of a field;
229 |     # if there is no HasField, then just use getattr.
230 |     try:
231 |         hf = p.HasField(a)
232 |     except (AttributeError, ValueError):
233 |         pass
234 |     if hf:
235 |         return getattr(p, a, d)
236 |     return d
237 | 
238 | def wh_attr(layer, attrname, default=0, minval=0):
239 |     if not hasattr(default, '__len__'):
240 |         default = (default, default)
241 |     val = proto_getattr(layer, attrname, None)
242 |     if val is None or val == []:
243 |         h = max(minval, getattr(layer, attrname + '_h', default[0]))
244 |         w = max(minval, getattr(layer, attrname + '_w', default[1]))
245 |     elif hasattr(val, '__len__'):
246 |         h = val[0]
247 |         w = val[1] if len(val) >= 2 else h
248 |     else:
249 |         h = val
250 |         w = val
251 |     return (h, w)
252 | 
253 | def layer_fieldmap(layer):
254 |     # Only convolutional and pooling layers affect geometry.
255 |     if layer.type == 'Convolution' or layer.type == 'Pooling':
256 |         if layer.type == 'Pooling':
257 |             config = layer.pooling_param
258 |             if config.global_pooling:
259 |                 return ((0, 0), (None, None), (1, 1))
260 |         else:
261 |             config = layer.convolution_param
262 |         size = wh_attr(config, 'kernel_size', wh_attr(config, 'kernel', 1))
263 |         stride = wh_attr(config, 'stride', 1, minval=1)
264 |         padding = wh_attr(config, 'pad', 0)
265 |         neg_padding = tuple((-x) for x in padding)
266 |         return (neg_padding, size, stride)
267 |     # All other layers just pass through geometry unchanged.
268 |     return ((0, 0), (1, 1), (1, 1))
269 | 
270 | def layerarray_fieldmap(layerarray):
271 |     fieldmap = ((0, 0), (1, 1), (1, 1))
272 |     for layer in layerarray:
273 |         fieldmap = compose_fieldmap(fieldmap, layer_fieldmap(layer))
274 |     return fieldmap
275 | 
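# --- Illustrative usage sketch (not part of the original file) ---
# A (1, H, W) segmentation keeps its singleton leading axis while safezoom
# nearest-neighbor resizes the spatial axes.
seg = numpy.arange(4).reshape(1, 2, 2)
print(safezoom(seg, (1, 2.0, 2.0)).shape)  # -> (1, 4, 4)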
276 | # rf1 is the lower layer, rf2 is the higher layer
277 | def compose_fieldmap(rf1, rf2):
278 |     """Composes two stacked fieldmaps.
279 | 
280 |     Fieldmaps are represented as triples of (offset, size, step),
281 |     where each is an (x, y) pair.
282 | 
283 |     To find the pixel range corresponding to output pixel (x, y), just
284 |     do the following:
285 |         start_x = x * step[0] + offset[0]
286 |         limit_x = start_x + size[0]
287 |         start_y = y * step[1] + offset[1]
288 |         limit_y = start_y + size[1]
289 | 
290 |     Parameters
291 |     ----------
292 |     rf1: tuple
293 |         The lower-layer receptive fieldmap, a tuple of (offset, size, step).
294 |     rf2: tuple
295 |         The higher-layer receptive fieldmap, a tuple of (offset, size, step).
296 |     """
297 |     if rf1 is None:
298 |         raise ValueError('rf1 must be a fieldmap, not None')
299 |     offset1, size1, step1 = rf1
300 |     offset2, size2, step2 = rf2
301 | 
302 |     size = tuple((size2c - 1) * step1c + size1c
303 |                  for size1c, step1c, size2c in zip(size1, step1, size2))
304 |     offset = tuple(offset2c * step1c + offset1c
305 |                    for offset2c, step1c, offset1c in zip(offset2, step1, offset1))
306 |     step = tuple(step2c * step1c
307 |                  for step1c, step2c in zip(step1, step2))
308 |     return (offset, size, step)
309 | 
310 | def _cropped_slices(offset, size, limit):
311 |     corner = 0
312 |     if offset < 0:
313 |         corner = -offset
314 |         size += offset
315 |         offset = 0
316 |     if limit - offset < size:
317 |         size = limit - offset
318 |     return (slice(corner, corner + size), slice(offset, offset + size))
319 | 
320 | def crop_field(image_data, fieldmap, location):
321 |     """Crops image_data to the specified receptive field.
322 | 
323 |     Together fieldmap and location specify a receptive field on the image,
324 |     which may overlap the edge.  This returns a crop to that shape, including
325 |     any zero padding necessary to fill out the shape beyond the image edge.
326 |     """
327 |     offset, size = receptive_field(location, fieldmap)
328 |     return crop_rectangle(image_data, offset, size)
329 | 
330 | def crop_rectangle(image_data, offset, size):
331 |     coloraxis = 0 if image_data.ndim <= 2 else 1
332 |     allcolors = () if not coloraxis else (slice(None),) * coloraxis
333 |     colordepth = () if not coloraxis else (image_data.shape[0], )
334 |     result = numpy.zeros(colordepth + size)
335 |     (xto, xfrom), (yto, yfrom) = (_cropped_slices(
336 |         o, s, l) for o, s, l in zip(offset, size, image_data.shape[coloraxis:]))
337 |     result[allcolors + (xto, yto)] = image_data[allcolors + (xfrom, yfrom)]
338 |     return result
339 | 
340 | def center_location(fieldmap, location):
341 |     if isinstance(location, numpy.ndarray):
342 |         offset, size, step = fieldmap
343 |         broadcast = (numpy.newaxis, ) * (len(location.shape) - 1) + (
344 |             slice(None),)
345 |         step = numpy.array(step)[broadcast]
346 |         offset = numpy.array(offset)[broadcast]
347 |         size = numpy.array(size)[broadcast]
348 |         return location * step + offset + size // 2
349 |     else:
350 |         offset, shape = receptive_field(location, fieldmap)
351 |         return tuple(o + s // 2 for o, s in zip(offset, shape))
352 | 
353 | def centered_slice(fieldmap, activation_shape, reduction=1):
354 |     offset, size, step = fieldmap
355 |     r = reduction
356 |     return tuple(slice((s // 2 + o) // r, (s // 2 + o + a * t) // r, t // r)
357 |                  for o, s, t, a in zip(offset, size, step, activation_shape))
358 | 
359 | def centered_arange(fieldmap, activation_shape, reduction=1):
360 |     offset, size, step = fieldmap
361 |     r = reduction
362 |     return tuple(numpy.arange(
363 |         (s // 2 + o) // r, (s // 2 + o + a * t) // r, t // r)[:a]  # Hack to avoid a+1 points
364 |         for o, s, t, a in zip(offset, size, step, activation_shape))
365 | 
366 | def full_arange(output_shape):
367 |     return tuple(numpy.arange(o) for o in output_shape)
368 | 
369 | 
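# --- Illustrative check (not part of the original file) ---
# Composing two hypothetical layers, each with kernel 3, stride 2, pad 0
# (fieldmap ((0, 0), (3, 3), (2, 2))), yields kernel 7 and stride 4 overall,
# and receptive_field projects one unit of the stacked output back to pixels.
fm = compose_fieldmap(((0, 0), (3, 3), (2, 2)), ((0, 0), (3, 3), (2, 2)))
print(fm)                            # -> ((0, 0), (7, 7), (4, 4))
print(receptive_field((1, 1), fm))   # -> ((4, 4), (7, 7))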
-------------------------------------------------------------------------------- /util/vecquantile.py: --------------------------------------------------------------------------------
1 | import numpy
2 | 
3 | class QuantileVector:
4 |     """
5 |     Streaming randomized quantile computation for numpy.
6 | 
7 |     Add any amount of data repeatedly via add(data).  At any time,
8 |     quantile estimates (or old-style percentiles) can be read out using
9 |     quantiles(q) or percentiles(p).
10 | 
11 |     Accuracy scales according to resolution: the default is to
12 |     set resolution to be accurate to better than 0.1%,
13 |     while limiting storage to about 50,000 samples.
14 | 
15 |     Good for computing quantiles of huge data without using much memory.
16 |     Works well on arbitrary data with probability near 1.
17 | 
18 |     Based on the optimal KLL quantile algorithm by Karnin, Lang, and Liberty
19 |     from FOCS 2016.  http://ieee-focs.org/FOCS-2016-Papers/3933a071.pdf
20 |     """
21 | 
22 |     def __init__(self, depth=1, resolution=24 * 1024, buffersize=None,
23 |                  dtype=None, seed=None):
24 |         self.resolution = resolution
25 |         self.depth = depth
26 |         # Default buffersize: 128 samples (and smaller than resolution).
27 |         if buffersize is None:
28 |             buffersize = min(128, (resolution + 7) // 8)
29 |         self.buffersize = buffersize
30 |         self.samplerate = 1.0
31 |         self.data = [numpy.zeros(shape=(depth, resolution), dtype=dtype)]
32 |         self.firstfree = [0]
33 |         self.random = numpy.random.RandomState(seed)
34 |         self.extremes = numpy.empty(shape=(depth, 2), dtype=dtype)
35 |         self.extremes.fill(numpy.nan)
36 |         self.size = 0
37 | 
38 |     def add(self, incoming):
39 |         assert len(incoming.shape) == 2
40 |         assert incoming.shape[1] == self.depth
41 |         self.size += incoming.shape[0]
42 |         # If we are not yet subsampling, add every incoming row.
43 |         if self.samplerate >= 1.0:
44 |             self._add_every(incoming)
45 |             return
46 |         # If we are sampling, then subsample a large chunk at a time.
47 |         self._scan_extremes(incoming)
48 |         chunksize = int(numpy.ceil(self.buffersize / self.samplerate))
49 |         for index in range(0, len(incoming), chunksize):
50 |             batch = incoming[index:index+chunksize]
51 |             sample = batch[self.random.binomial(1, self.samplerate, len(batch)).astype(bool)]
52 |             self._add_every(sample)
53 | 
54 |     def _add_every(self, incoming):
55 |         supplied = len(incoming)
56 |         index = 0
57 |         while index < supplied:
58 |             ff = self.firstfree[0]
59 |             available = self.data[0].shape[1] - ff
60 |             if available == 0:
61 |                 if not self._shift():
62 |                     # If we shifted by subsampling, then subsample.
63 |                     incoming = incoming[index:]
64 |                     if self.samplerate >= 0.5:
65 |                         print('SAMPLING')
66 |                         self._scan_extremes(incoming)
67 |                     incoming = incoming[self.random.binomial(1, 0.5,
68 |                         len(incoming)).astype(bool)]
69 |                     index = 0
70 |                     supplied = len(incoming)
71 |                 ff = self.firstfree[0]
72 |                 available = self.data[0].shape[1] - ff
73 |             copycount = min(available, supplied - index)
74 |             self.data[0][:,ff:ff + copycount] = numpy.transpose(
75 |                 incoming[index:index + copycount,:])
76 |             self.firstfree[0] += copycount
77 |             index += copycount
78 | 
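# --- Illustrative sketch (not part of the original file) ---
# The compaction step used by _shift below, in isolation: sort a full
# buffer, then keep every other sample starting at a random offset; each
# survivor then stands for two samples, so its weight doubles.
demo = numpy.array([[5., 1., 4., 2., 3., 0.]])
demo.sort()
offset = numpy.random.RandomState(0).binomial(1, 0.5)
print(demo[:, offset::2])  # e.g. [[0., 2., 4.]] when offset == 0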
79 |     def _shift(self):
80 |         index = 0
81 |         # If remaining space at the current layer is less than half prev
82 |         # buffer size (rounding up), then we need to shift it up to ensure
83 |         # enough space for future shifting.
84 |         while self.data[index].shape[1] - self.firstfree[index] < (
85 |                 -(-self.data[index-1].shape[1] // 2) if index else 1):
86 |             if index + 1 >= len(self.data):
87 |                 return self._expand()
88 |             data = self.data[index][:,0:self.firstfree[index]]
89 |             data.sort()
90 |             if index == 0 and self.samplerate >= 1.0:
91 |                 self._update_extremes(data[:,0], data[:,-1])
92 |             offset = self.random.binomial(1, 0.5)
93 |             position = self.firstfree[index + 1]
94 |             subset = data[:,offset::2]
95 |             self.data[index + 1][:,position:position + subset.shape[1]] = subset
96 |             self.firstfree[index] = 0
97 |             self.firstfree[index + 1] += subset.shape[1]
98 |             index += 1
99 |         return True
100 | 
101 |     def _scan_extremes(self, incoming):
102 |         # When sampling, we need to scan every item still to get extremes
103 |         self._update_extremes(
104 |             numpy.nanmin(incoming, axis=0),
105 |             numpy.nanmax(incoming, axis=0))
106 | 
107 |     def _update_extremes(self, minr, maxr):
108 |         self.extremes[:,0] = numpy.nanmin(
109 |             [self.extremes[:, 0], minr], axis=0)
110 |         self.extremes[:,-1] = numpy.nanmax(
111 |             [self.extremes[:, -1], maxr], axis=0)
112 | 
113 |     def minmax(self):
114 |         if self.firstfree[0]:
115 |             self._scan_extremes(self.data[0][:,:self.firstfree[0]].transpose())
116 |         return self.extremes.copy()
117 | 
118 |     def _expand(self):
119 |         cap = self._next_capacity()
120 |         if cap > 0:
121 |             # First, make a new layer of the proper capacity.
122 |             self.data.insert(0, numpy.empty(
123 |                 shape=(self.depth, cap), dtype=self.data[-1].dtype))
124 |             self.firstfree.insert(0, 0)
125 |         else:
126 |             # Unless we're so big we are just subsampling.
127 |             assert self.firstfree[0] == 0
128 |             self.samplerate *= 0.5
129 |         for index in range(1, len(self.data)):
130 |             # Scan for existing data that needs to be moved down a level.
131 |             amount = self.firstfree[index]
132 |             if amount == 0:
133 |                 continue
134 |             position = self.firstfree[index-1]
135 |             # Move data down if it would leave enough empty space there
136 |             # This is the key invariant: enough empty space to fit half
137 |             # of the previous level's buffer size (rounding up)
138 |             if self.data[index-1].shape[1] - (amount + position) >= (
139 |                     -(-self.data[index-2].shape[1] // 2) if (index-1) else 1):
140 |                 self.data[index-1][:,position:position + amount] = (
141 |                     self.data[index][:,:amount])
142 |                 self.firstfree[index-1] += amount
143 |                 self.firstfree[index] = 0
144 |             else:
145 |                 # Scrunch the data if it would not.
146 |                 data = self.data[index][:,:amount]
147 |                 data.sort()
148 |                 if index == 1:
149 |                     self._update_extremes(data[:,0], data[:,-1])
150 |                 offset = self.random.binomial(1, 0.5)
151 |                 scrunched = data[:,offset::2]
152 |                 self.data[index][:,:scrunched.shape[1]] = scrunched
153 |                 self.firstfree[index] = scrunched.shape[1]
154 |         return cap > 0
155 | 
156 |     def _next_capacity(self):
157 |         cap = numpy.ceil(self.resolution * numpy.power(0.67, len(self.data)))
158 |         if cap < 2:
159 |             return 0
160 |         return max(self.buffersize, int(cap))
161 | 
162 |     def _weighted_summary(self, sort=True):
163 |         if self.firstfree[0]:
164 |             self._scan_extremes(self.data[0][:,:self.firstfree[0]].transpose())
165 |         size = sum(self.firstfree) + 2
166 |         weights = numpy.empty(
167 |             shape=(size), dtype='float32')  # floating point
168 |         summary = numpy.empty(
169 |             shape=(self.depth, size), dtype=self.data[-1].dtype)
170 |         weights[0:2] = 0
171 |         summary[:,0:2] = self.extremes
172 |         index = 2
173 |         for level, ff in enumerate(self.firstfree):
174 |             if ff == 0:
175 |                 continue
176 |             summary[:,index:index + ff] = self.data[level][:,:ff]
177 |             weights[index:index + ff] = numpy.power(2.0, level)
178 |             index += ff
179 |         assert index == summary.shape[1]
180 |         if sort:
181 |             order = numpy.argsort(summary)
182 |             summary = summary[numpy.arange(self.depth)[:,None], order]
183 |             weights = weights[order]
184 |         return (summary, weights)
185 | 
186 |     def quantiles(self, quantiles, old_style=False):
187 |         if self.size == 0:
188 |             return numpy.full((self.depth, len(quantiles)), numpy.nan)
189 |         summary, weights = self._weighted_summary()
190 |         cumweights = numpy.cumsum(weights, axis=-1) - weights / 2
191 |         if old_style:
192 |             # To be compatible with numpy.percentile
193 |             cumweights -= cumweights[:,0:1]
194 |             cumweights /= cumweights[:,-1:]
195 |         else:
196 |             cumweights /= numpy.sum(weights, axis=-1, keepdims=True)
197 |         result = numpy.empty(shape=(self.depth, len(quantiles)))
198 |         for d in range(self.depth):
199 |             result[d] = numpy.interp(quantiles, cumweights[d], summary[d])
200 |         return result
201 | 
202 |     def integrate(self, fun):
203 |         result = None
204 |         for level, ff in enumerate(self.firstfree):
205 |             if ff == 0:
206 |                 continue
207 |             term = numpy.sum(
208 |                 fun(self.data[level][:,:ff]) * numpy.power(2.0, level),
209 |                 axis=-1)
210 |             if result is None:
211 |                 result = term
212 |             else:
213 |                 result += term
214 |         if result is not None:
215 |             result /= self.samplerate
216 |         return result
217 | 
218 |     def percentiles(self, percentiles):
219 |         return self.quantiles(percentiles, old_style=True)
220 | 
221 |     def readout(self, count, old_style=True):
222 |         return self.quantiles(
223 |             numpy.linspace(0.0, 1.0, count), old_style=old_style)
224 | 
225 | 
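# --- Illustrative usage sketch (not part of the original file) ---
# Stream a million standard-normal samples through one QuantileVector
# channel and read off an estimated median; it should land near 0.
qv = QuantileVector(depth=1, resolution=8 * 1024, seed=1)
for _ in range(100):
    qv.add(numpy.random.randn(10000, 1))
print(qv.quantiles([0.5]))  # approximately [[0.0]]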
226 | if __name__ == '__main__':
227 |     import time
228 |     # An adversarial case: we keep finding more numbers in the middle
229 |     # as the stream goes on.
230 |     amount = 10000000
231 |     percentiles = 1000
232 |     data = numpy.arange(float(amount))
233 |     data[1::2] = data[-1::-2] + (len(data) - 1)
234 |     data /= 2
235 |     depth = 50
236 |     alldata = data[:,None] + (numpy.arange(depth) * amount)[None, :]
237 |     actual_sum = numpy.sum(alldata * alldata, axis=0)
238 |     amt = amount // depth
239 |     for r in range(depth):
240 |         numpy.random.shuffle(alldata[r*amt:r*amt+amt,r])
241 |     # data[::2] = data[-2::-2]
242 |     # numpy.random.shuffle(data)
243 |     starttime = time.time()
244 |     qc = QuantileVector(depth=depth, resolution=8 * 1024)
245 |     qc.add(alldata)
246 |     ro = qc.readout(1001)
247 |     endtime = time.time()
248 |     # print('ro', ro)
249 |     # print(ro - numpy.linspace(0, amount, percentiles+1))
250 |     gt = numpy.linspace(0, amount, percentiles+1)[None,:] + (
251 |         numpy.arange(qc.depth) * amount)[:,None]
252 |     print("Maximum relative deviation among %d percentiles:" % percentiles, (
253 |         numpy.max(abs(ro - gt) / amount) * percentiles))
254 |     print("Minmax error %f, %f" % (
255 |         max(abs(qc.minmax()[:,0] - numpy.arange(qc.depth) * amount)),
256 |         max(abs(qc.minmax()[:, -1] - (numpy.arange(qc.depth)+1) * amount + 1))))
257 |     print("Integral error:", numpy.max(numpy.abs(
258 |         qc.integrate(lambda x: x * x)
259 |         - actual_sum) / actual_sum))
260 |     print("Count error: ", (qc.integrate(lambda x: numpy.ones(x.shape[-1])
261 |         ) - qc.size) / (0.0 + qc.size))
262 |     print("Time", (endtime - starttime))
263 | 
-------------------------------------------------------------------------------- /visualize/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSAILVision/IBD/6fda25a4aea5ba0f2cea13f5619f5aaac7c4186d/visualize/__init__.py
-------------------------------------------------------------------------------- /visualize/html.py: --------------------------------------------------------------------------------
1 | import os
2 | import settings
3 | def html_cam_seg(output_file, image_folder, inds, info):
4 |     f = open(output_file, 'w')
5 |     f.write("<!DOCTYPE html>\n<html>\n<head>\n<meta charset='utf-8'>\n</head>\n<body>\n\n\n")
6 |     # f.write("")
7 |     for ind in inds:
8 |         if settings.APP == "vqa":
9 |             headline = "<div>\n<b>%d %s? real: %s, prediction %s</b>\n</div>\n" % (ind, info[ind][0], info[ind][1], info[ind][2])
10 |         elif settings.APP == "imagecap":
11 |             sents, highlight_id = info[ind]
12 |             sents = list(sents)
13 |             sents[highlight_id] = "<b>%s</b>" % sents[highlight_id]
14 |             headline = "<div>\n<b>%d %s</b>\n</div>\n" % (ind, ' '.join(sents))
15 |         else:
16 |             headline = "<div>\n<b>%d %s</b>\n</div>\n" % (ind, info[ind])
17 |         imageline = "<img id='%d' height='%d' src='%s'>\n" % (ind, settings.IMG_SIZE, os.path.join(image_folder, '%03d.jpg' % ind))
18 |         f.write(headline)
19 |         f.write(imageline)
20 |         f.write("\n\n")
21 |     f.close()
22 | 
-------------------------------------------------------------------------------- /visualize/plot.py: --------------------------------------------------------------------------------
1 | 
2 | import os
3 | import torch
4 | import colorsys
5 | import settings
6 | import numpy as np
7 | from sklearn.manifold import TSNE, SpectralEmbedding
8 | from loader.feature_loader import concept_loader
9 | import matplotlib.pyplot as plt
10 | 
11 | 
12 | def random_color(labels, type="soft"):
13 |     if type == "soft":
14 |         HSVcolors = [(np.random.uniform(low=0.1, high=0.9),
15 |                       np.random.uniform(low=0.2, high=0.6),
16 |                       np.random.uniform(low=0.6, high=0.95)) for _ in range(labels)]
17 |     elif type == "bright":
18 |         HSVcolors = [(np.random.uniform(low=0.0, high=1),
19 |                       np.random.uniform(low=0.2, high=1),
20 |                       np.random.uniform(low=0.9, high=1)) for _ in range(labels)]
21 |     return [colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]) for HSVcolor in HSVcolors]
22 | 
23 | 
24 | # Generate random colormap
25 | def rand_cmap(nlabels, type='bright', first_color_black=True, last_color_black=False, verbose=True):
26 |     """
27 |     Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks
28 |     :param nlabels: Number of labels (size of colormap)
29 |     :param type: 'bright' for strong colors, 'soft' for pastel colors
30 |     :param first_color_black: Option to use first color as black, True or False
31 |     :param last_color_black: Option to use last color as black, True or False
32 |     :param verbose: Prints the number of labels and shows the colormap. True or False
33 |     :return: colormap for matplotlib
34 |     """
35 |     from matplotlib.colors import LinearSegmentedColormap
36 | 
37 | 
38 |     if type not in ('bright', 'soft'):
39 |         print('Please choose "bright" or "soft" for type')
40 |         return
41 | 
42 |     if verbose:
43 |         print('Number of labels: ' + str(nlabels))
44 | 
45 |     # Generate color map for bright colors, based on hsv
46 |     if type == 'bright':
47 |         randHSVcolors = [(np.random.uniform(low=0.0, high=1),
48 |                           np.random.uniform(low=0.2, high=1),
49 |                           np.random.uniform(low=0.9, high=1)) for i in range(nlabels)]
50 | 
51 |         # Convert HSV list to RGB
52 |         randRGBcolors = []
53 |         for HSVcolor in randHSVcolors:
54 |             randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]))
55 | 
56 |         if first_color_black:
57 |             randRGBcolors[0] = [0, 0, 0]
58 | 
59 |         if last_color_black:
60 |             randRGBcolors[-1] = [0, 0, 0]
61 | 
62 |         random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
63 | 
64 |     # Generate soft pastel colors, by limiting the RGB spectrum
65 |     if type == 'soft':
66 |         low = 0.6
67 |         high = 0.95
68 |         randRGBcolors = [(np.random.uniform(low=low, high=high),
69 |                           np.random.uniform(low=low, high=high),
70 |                           np.random.uniform(low=low, high=high)) for i in range(nlabels)]
71 | 
72 |         if first_color_black:
73 |             randRGBcolors[0] = [0, 0, 0]
74 | 
75 |         if last_color_black:
76 |             randRGBcolors[-1] = [0, 0, 0]
77 |         random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
78 | 
79 |     # Display colorbar
80 |     if verbose:
81 |         from matplotlib import colors, colorbar
82 |         from matplotlib import pyplot as plt
83 |         fig, ax = plt.subplots(1, 1, figsize=(15, 0.5))
84 | 
85 |         bounds = np.linspace(0, nlabels, nlabels + 1)
86 |         norm = colors.BoundaryNorm(bounds, nlabels)
87 | 
88 |         cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None,
89 |                                    boundaries=bounds, format='%1i', orientation=u'horizontal')
90 | 
91 |     return random_colormap
92 | 
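# --- Illustrative usage sketch (not part of the original file) ---
# A 20-label segmentation colormap whose label 0 is black; verbose=False
# skips the preview colorbar. `label_image` below is a hypothetical array.
cmap = rand_cmap(20, type='bright', first_color_black=True, verbose=False)
# plt.imshow(label_image, cmap=cmap)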
93 | def embedding_map_sample(feat, layer, fo):
94 |     sample_size = 1000
95 |     feat_loader_train, feat_loader_test = concept_loader(feat, layer, fo.data, len(fo.data.label), split=True, batch_size=sample_size)
96 | 
97 |     feat, label = next(iter(feat_loader_test))
98 |     feat_nse = TSNE(n_components=2, verbose=2).fit_transform(feat.numpy())
99 |     # feat_nse = SpectralEmbedding(n_components=2).fit_transform(feat)
100 |     HSVcolors = [(np.random.uniform(low=0.0, high=1),
101 |                   np.random.uniform(low=0.2, high=1),
102 |                   np.random.uniform(low=0.9, high=1)) for i in range(label.max() + 1)]
103 |     RGBcolors = np.array([colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]) for HSVcolor in HSVcolors])
104 |     plt.scatter(feat_nse[:,0], feat_nse[:,1], c=RGBcolors[label.numpy()], alpha=.5)
105 |     plt.show()
106 | 
107 | 
108 | def embedding_map(feat2d, c_map, labelcat):
109 | 
110 |     marker_cat = ["x", "o", "v", "+", "s", "*"]
111 |     HSVcolors = [(np.random.uniform(low=0.0, high=1),
112 |                   np.random.uniform(low=0.2, high=1),
113 |                   np.random.uniform(low=0.9, high=1)) for i in range(len(c_map))]
114 |     RGBcolors = [colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]) for HSVcolor in HSVcolors]
115 | 
116 |     for cid, ind_set in enumerate(c_map):
117 |         if len(ind_set) == 0:
118 |             continue
119 |         ind_set = np.array(list(ind_set))
120 |         cat = labelcat[cid].nonzero()[0][0]
121 |         points = feat2d[ind_set[:, 0], ind_set[:, 1], ind_set[:, 2]]
122 |         plt.scatter(points[:,0], points[:,1], c=RGBcolors[cid], marker=marker_cat[cat], alpha=.5)
123 | 
124 |     plt.savefig(os.path.join(settings.OUTPUT_FOLDER, 'image', 'feat2d.jpg'))
125 | 
126 | 
127 | # weight_vis(model, imagenet_categories.values(), feat_clf, [l['name'] for l in fo.data.label], save_path='cachecache/nse.npy')
128 | 
129 | def weight_vis(model, model_labels, feat_clf, feat_clf_labels, save_path=None):
130 |     params = list(model.parameters())
131 |     weight_softmax = params[-2].data.numpy()
132 |     if settings.GPU:
133 |         weight_clf = feat_clf.fc.weight.data.cpu().numpy()
134 |     else:
135 |         weight_clf = feat_clf.fc.weight.data.numpy()
136 | 
137 |     if save_path and os.path.exists(save_path):
138 |         nse = np.load(save_path)
139 |     else:
140 |         nse = TSNE(n_components=2, verbose=2).fit_transform(np.concatenate([weight_softmax, weight_clf]))
141 |         if save_path:
142 |             np.save(save_path, nse)
143 |     plt.figure()
144 |     plt.scatter(nse[:len(weight_softmax),0], nse[:len(weight_softmax),1], 10, c='r')
145 |     plt.scatter(nse[len(weight_softmax):, 0], nse[len(weight_softmax):, 1], 10, c='b')
146 |     for i, label in enumerate(model_labels):
147 |         plt.text(nse[i,0], nse[i,1], label, fontdict={'size': 6, 'color': 'r'})
148 |     for i, label in enumerate(feat_clf_labels):
149 |         plt.text(nse[i+len(weight_softmax), 0], nse[i+len(weight_softmax), 1], label, fontdict={'size': 6, 'color': 'b'})
150 |     plt.show()
151 | 
152 | def image_summary(fo, model, feat_clf):
153 |     outpath = os.path.join(settings.OUTPUT_FOLDER, 'html', 'image')
154 |     if not os.path.exists(outpath):
155 |         os.makedirs(outpath)
156 |     if settings.APP == "classification" or settings.APP == "imagecap":
157 |         # image_file = '/home/sunyiyou/PycharmProjects/test/example4.jpg'
158 |         # vis, prediction = fo.instance_cam_by_file(model, image_file, feat_clf)
159 |         # vis.save('/home/sunyiyou/PycharmProjects/test/%03d.jpg' % 1)
160 |         with open(settings.DATASET_INDEX_FILE) as f:
161 |             image_list = f.readlines()
162 |         predictions = []
163 |         for i, file in enumerate(image_list):
164 |             print("generating visualization on %03d" % i)
165 |             image_file = os.path.join(settings.DATASET_PATH, file.strip())
166 |             vis, prediction = fo.instance_cam_by_file(model, image_file, feat_clf)
167 |             predictions.append(prediction)
168 |             vis.save(os.path.join(settings.OUTPUT_FOLDER, 'html', 'image', "%03d.jpg" % i))
169 | 
170 |     elif settings.APP == "vqa":
171 |         from loader.vqa_data_loader import VQA, collate_fn
172 |         vqa_test = VQA(settings.VQA_QUESTIONS_FILE, settings.VQA_ANSWERS_FILE, settings.VQA_IMG_PATH)
173 |         loader = torch.utils.data.DataLoader(
174 |             vqa_test,
175 |             batch_size=1,
176 |             shuffle=False,
177 |             num_workers=1,
178 |             collate_fn=collate_fn,
179 |         )
180 |         predictions = []
181 |         for i, (v, q, a, idx, q_len, q_org, a_org) in enumerate(loader):
182 |             print("generating visualization on %03d" % i)
183 |             vis, prediction = fo.instance_cam_by_file(model, v[0], feat_clf, other_params=(q, q_len, a))
184 |             vis.save(os.path.join(settings.OUTPUT_FOLDER, 'html', 'image', "%03d.jpg" % i))
185 |             # question = []
186 |             # for q_i in range(q_len[0]):
187 |             #     question.append(vqa_test.questions_tokens[q[0][q_i] - 1])
188 |             prediction_answer = vqa_test.answer_tokens[prediction.data[0]]
189 |             predictions.append((' '.join([qw[0] for qw in q_org]), a_org[0][0], prediction_answer))
190 |     # elif settings.APP == "imagecap":
191 | 
192 |     import pickle
193 |     with open(os.path.join(settings.OUTPUT_FOLDER, 'prediction.pickle'), 'wb') as f:
194 |         pickle.dump(predictions, f)
195 | 
196 |     return predictions
197 | 
198 | 
"imagecap": 201 | # image_file ='/home/sunyiyou/PycharmProjects/test/example4.jpg' 202 | # vis, prediction = fo.instance_cam_by_file(model, image_file, feat_clf, fig_style=1) 203 | # vis.save('test.jpg') 204 | with open(settings.DATASET_INDEX_FILE) as f: 205 | image_list = f.readlines() 206 | predictions = [] 207 | outpath = os.path.join(settings.OUTPUT_FOLDER, 'html', 'image') 208 | if not os.path.exists(outpath): 209 | os.makedirs(outpath) 210 | for i, file in enumerate(image_list): 211 | print("generating figure on %03d" % i) 212 | image_file = os.path.join(settings.DATASET_PATH, file.strip()) 213 | vis, prediction = fo.instance_cam_by_file(model, image_file, feat_clf, fig_style=1) 214 | predictions.append(prediction) 215 | vis.save(os.path.join(outpath, "%03d.jpg" % i)) 216 | import pickle 217 | with open(os.path.join(settings.OUTPUT_FOLDER, 'prediction.pickle'), 'wb') as f: 218 | pickle.dump(predictions, f) 219 | 220 | return predictions 221 | --------------------------------------------------------------------------------