├── .deepsource.toml ├── .gitignore ├── LICENSE ├── README.md ├── code ├── README.md ├── conda_environment.yml ├── evaluate.py ├── lib │ ├── __init__.py │ ├── archs │ │ ├── README.md │ │ ├── __init__.py │ │ ├── instance_counter.py │ │ ├── modules │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── conv_gru.py │ │ │ ├── coord_conv.py │ │ │ ├── recurrent_hourglass.py │ │ │ ├── renet.py │ │ │ ├── utils.py │ │ │ └── vgg16.py │ │ ├── reseg.py │ │ └── stacked_recurrent_hourglass.py │ ├── dataset.py │ ├── losses │ │ ├── __init__.py │ │ ├── dice.py │ │ └── discriminative.py │ ├── model.py │ ├── prediction.py │ ├── preprocess.py │ └── utils.py ├── pred.py ├── pred_list.py ├── settings │ ├── CVPPP │ │ ├── README.md │ │ ├── __init__.py │ │ ├── data_settings.py │ │ ├── model_settings.py │ │ └── training_settings.py │ ├── README.md │ └── __init__.py └── train.py ├── data ├── README.md ├── metadata │ ├── CVPPP │ │ ├── means_and_stds.txt │ │ ├── training.lst │ │ └── validation.lst │ └── README.md ├── processed │ ├── CVPPP │ │ └── README.md │ └── README.md ├── raw │ ├── CVPPP │ │ └── README.md │ └── README.md └── scripts │ ├── CVPPP │ ├── 1-create_annotations.py │ ├── 1-remove_alpha.sh │ ├── 2-get_image_means-stds.py │ ├── 2-get_image_paths.py │ ├── 2-get_image_shapes.py │ ├── 2-get_number_of_instances.py │ ├── 3-create_dataset.py │ ├── prepare.sh │ └── utils.py │ └── README.md ├── models ├── CVPPP │ └── README.md └── README.md ├── outputs ├── CVPPP │ └── README.md └── README.md └── samples ├── CVPPP ├── README.md ├── plant007_rgb-fg_mask.png ├── plant007_rgb-ins_mask_color.png ├── plant007_rgb.png ├── plant031_rgb-fg_mask.png ├── plant031_rgb-ins_mask_color.png ├── plant031_rgb.png ├── plant033_rgb-fg_mask.png ├── plant033_rgb-ins_mask_color.png └── plant033_rgb.png └── README.md /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | [[analyzers]] 4 | name = "python" 5 | enabled = true 6 | 7 | [analyzers.meta] 8 | runtime_version = "2.x.x" 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/metadata/*/number_of_instances.txt 2 | data/metadata/*/image_shapes.txt 3 | data/metadata/*/*image_paths.txt 4 | data/processed/CVPPP/*/ 5 | data/raw/CVPPP/*/ 6 | !data/processed/*.md 7 | !data/processed/CVPPP/*.md 8 | !data/raw/CVPPP/*.md 9 | data/raw/*/*.zip 10 | models/CVPPP/*/ 11 | !models/CVPPP/*.md 12 | outputs/CVPPP/*/ 13 | !outputs/CVPPP/*.md 14 | 15 | !*README* 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Sphinx documentation 26 | docs/_build/ 27 | 28 | # Jupyter Notebook 29 | .ipynb_checkpoints 30 | 31 | # OS generated files # 32 | ###################### 33 | .DS_Store 34 | ehthumbs.db 35 | Icon 36 | Thumbs.db 37 | .tmtags 38 | .idea 39 | vendor.tags 40 | tmtagsHistory 41 | *.sublime-project 42 | *.sublime-workspace 43 | .bundle 44 | 45 | *~ 46 | *.swp 47 | 48 | .pynative 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Instance Segmentation with a Discriminative Loss Function 2 | 3 | This repository implements [Semantic Instance Segmentation with a Discriminative Loss Function](https://arxiv.org/abs/1708.02551) with some enhancements. 
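The discriminative loss that gives the paper its name pulls pixel embeddings belonging to the same instance toward their instance mean and pushes the means of different instances apart, with both terms hinged at fixed margins. The snippet below is a simplified, single-image NumPy sketch of the variance and distance terms, written only to illustrate the idea; the margins `delta_v` and `delta_d` and all variable names are illustrative, and the actual batched PyTorch implementation lives in `code/lib/losses/discriminative.py`.

```python
import numpy as np


def discriminative_loss_sketch(embeddings, instance_labels,
                               delta_v=0.5, delta_d=1.5):
    """Simplified sketch of the variance and distance terms for one image.

    embeddings      : (N, D) array of foreground pixel embeddings
    instance_labels : (N,) array of ground-truth instance ids
    """
    instance_ids = np.unique(instance_labels)
    means = np.stack([embeddings[instance_labels == i].mean(axis=0)
                      for i in instance_ids])

    # Variance term: pull embeddings that are farther than delta_v
    # toward their instance mean.
    l_var = 0.0
    for mean, i in zip(means, instance_ids):
        dists = np.linalg.norm(embeddings[instance_labels == i] - mean, axis=1)
        l_var += np.mean(np.maximum(dists - delta_v, 0.0) ** 2)
    l_var /= len(instance_ids)

    # Distance term: push instance means that are closer than 2 * delta_d
    # away from each other.
    l_dist, n_pairs = 0.0, 0
    for a in range(len(instance_ids)):
        for b in range(a + 1, len(instance_ids)):
            d = np.linalg.norm(means[a] - means[b])
            l_dist += np.maximum(2.0 * delta_d - d, 0.0) ** 2
            n_pairs += 1
    l_dist /= max(n_pairs, 1)

    return l_var + l_dist
```

The enhancements over the reference paper are the following: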
4 | 5 | * The reference paper does not predict a semantic segmentation mask; instead, it uses the ground-truth semantic segmentation mask. This code predicts the semantic segmentation mask as well, similar to [Towards End-to-End Lane Detection: an Instance Segmentation Approach](https://arxiv.org/abs/1802.05591). 6 | * The reference paper predicts the number of instances implicitly: it predicts embeddings for instances and obtains the number of instances as a result of clustering. Instead, this code predicts the number of instances as an output of the network. 7 | * The reference paper uses a segmentation network based on [ResNet-38](https://arxiv.org/abs/1512.03385). Instead, this code uses either [ReSeg](https://arxiv.org/abs/1511.07053) with skip connections based on the first seven convolutional layers of [VGG16](https://arxiv.org/abs/1409.1556) as the segmentation network, or an augmented version of [Stacked Recurrent Hourglass](https://arxiv.org/abs/1806.02070). 8 | * This code uses [KMeans Clustering](http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans); however, the reference paper uses "a fast variant of the mean-shift algorithm". 9 | 10 | ---------------------------- 11 | 12 | ## Modules 13 | 14 | * [Convolutional GRU](code/lib/archs/modules/README.md#3) 15 | * [Coordinate Convolution](code/lib/archs/modules/README.md#44) 16 | * [AddCoordinates](code/lib/archs/modules/README.md#44) 17 | * [CoordConv](code/lib/archs/modules/README.md#81) 18 | * [CoordConvTranspose](code/lib/archs/modules/README.md#112) 19 | * [CoordConvNet](code/lib/archs/modules/README.md#145) 20 | * [Recurrent Hourglass](code/lib/archs/modules/README.md#187) 21 | * [ReNet](code/lib/archs/modules/README.md#225) 22 | * [VGG16](code/lib/archs/modules/README.md#258) 23 | * [VGG16](code/lib/archs/modules/README.md#258) 24 | * [SkipVGG16](code/lib/archs/modules/README.md#304) 25 | 26 | ## Networks 27 | 28 | * [ReSeg](code/lib/archs/README.md#3) 29 | * [Stacked Recurrent Hourglass](code/lib/archs/README.md#47) 30 | 31 | ---------------------------- 32 | 33 | In the prediction phase, the network takes an image as input and outputs a semantic segmentation mask, the number of instances and embeddings for all pixels in the image. Then, foreground embeddings (which correspond to instances) are selected using the semantic segmentation mask, and these foreground embeddings are clustered into "the number of instances" groups. 34 | 35 | # Installation 36 | 37 | * Clone this repository : `git clone --recursive https://github.com/Wizaron/instance-segmentation-pytorch.git` 38 | * Install ImageMagick : `sudo apt install imagemagick` 39 | * Download and install [Anaconda](https://www.anaconda.com/download/) or [Miniconda](https://conda.io/miniconda.html) 40 | * Create a conda environment : `conda env create -f instance-segmentation-pytorch/code/conda_environment.yml` 41 | 42 | ## Data 43 | 44 | ### CVPPP 45 | 46 | * Download the [CVPPP dataset](https://www.plant-phenotyping.org/datasets-download) and extract the downloaded zip file (`CVPPP2017_LSC_training.zip`) to `instance-segmentation-pytorch/data/raw/CVPPP/` 47 | * This work uses the *A1* subset of the dataset. 48 | 49 | ## Code Structure 50 | 51 | * **code**: Code for training and evaluation. 52 | * **lib** 53 | * **lib/archs**: Stores network architectures. 54 | * **lib/archs/modules**: Stores basic modules for architectures. 55 | * **lib/model.py**: Defines the model (optimization, criterion, fit, predict, test, etc.). 56 | * **lib/dataset.py**: Data loading, augmentation, minibatching procedures.
57 | * **lib/preprocess.py**, **lib/utils.py**: Data augmentation methods. 58 | * **lib/prediction.py**: Prediction module. 59 | * **lib/losses/dice.py**: Dice loss for foreground semantic segmentation. 60 | * **lib/losses/discriminative.py**: [Discriminative loss](https://arxiv.org/pdf/1708.02551.pdf) for instance segmentation. 61 | * **settings** 62 | * **settings/CVPPP/data_settings.py**: Defines settings about the data. 63 | * **settings/CVPPP/model_settings.py**: Defines settings about the model (hyper-parameters). 64 | * **settings/CVPPP/training_settings.py**: Defines settings for training (optimization method, weight decay, augmentation, etc.). 65 | * **train.py**: Training script. 66 | * **pred.py**: Prediction script for a single image. 67 | * **pred_list.py**: Prediction script for a list of images. 68 | * **evaluate.py**: Evaluation script. Calculates SBD (symmetric best dice), |DiC| (absolute difference in count) and Foreground Dice (Dice score for semantic segmentation) as defined in the [paper](http://eprints.nottingham.ac.uk/34197/1/MVAP-D-15-00134_Revised_manuscript.pdf). 69 | * **data**: Stores data and scripts to prepare the dataset for training and evaluation. 70 | * **metadata/CVPPP**: Stores metadata, such as training, validation and test splits, image shapes, etc. 71 | * **processed/CVPPP**: Stores the processed form of the data. 72 | * **raw/CVPPP**: Stores the raw form of the data. 73 | * **scripts**: Stores scripts to prepare the dataset. 74 | * **scripts/CVPPP**: For the CVPPP dataset. 75 | * **scripts/CVPPP/1-create_annotations.py**: Saves annotations as numpy arrays to `processed/CVPPP/semantic-annotations/` and `processed/CVPPP/instance-annotations`. 76 | * **scripts/CVPPP/1-remove_alpha.sh**: Removes alpha channels from images (in order to run this script, `imagemagick` should be installed). 77 | * **scripts/CVPPP/2-get_image_means-stds.py**: Calculates and prints channel-wise means and standard deviations from the training subset. 78 | * **scripts/CVPPP/2-get_image_shapes.py**: Saves image shapes to `metadata/CVPPP/image_shapes.txt`. 79 | * **scripts/CVPPP/2-get_number_of_instances.py**: Saves the number of instances in each image to `metadata/CVPPP/number_of_instances.txt`. 80 | * **scripts/CVPPP/2-get_image_paths.py**: Saves image paths to `metadata/CVPPP/training_image_paths.txt` and `metadata/CVPPP/validation_image_paths.txt`. 81 | * **scripts/CVPPP/3-create_dataset.py**: Creates an LMDB dataset at `processed/CVPPP/lmdb/`. 82 | * **scripts/CVPPP/prepare.sh**: Runs the scripts above in a sequential manner. 83 | * **models/CVPPP**: Stores checkpoints of the trained models. 84 | * **outputs/CVPPP**: Stores predictions of the trained models. 85 | 86 | ## Data Preparation 87 | 88 | Data should be prepared prior to training and evaluation. 89 | 90 | * Activate the previously created conda environment : `source activate ins-seg-pytorch` or `conda activate ins-seg-pytorch` 91 | 92 | ### CVPPP 93 | 94 | * Place the extracted dataset in `instance-segmentation-pytorch/data/raw/CVPPP/`. Hence, the raw dataset should be found at `instance-segmentation-pytorch/data/raw/CVPPP/CVPPP2017_LSC_training/`. 95 | * In order to prepare the data, go to `instance-segmentation-pytorch/data/scripts/CVPPP/` and run `sh prepare.sh`. 96 | 97 | ## Visdom Server 98 | 99 | Start a [Visdom](https://github.com/facebookresearch/visdom) server in a `screen` or `tmux`.
100 | 101 | * Activate the previously created conda environment : `source activate ins-seg-pytorch` or `conda activate ins-seg-pytorch` 102 | 103 | * Start the visdom server : `python -m visdom.server` 104 | 105 | * We can access the visdom server at `http://localhost:8097` 106 | 107 | ## Training 108 | 109 | * Activate the previously created conda environment : `source activate ins-seg-pytorch` or `conda activate ins-seg-pytorch` 110 | 111 | * Go to `instance-segmentation-pytorch/code/` and run `train.py`. 112 | 113 | ``` 114 | usage: train.py [-h] [--model MODEL] [--usegpu] [--nepochs NEPOCHS] 115 | [--batchsize BATCHSIZE] [--debug] [--nworkers NWORKERS] 116 | --dataset DATASET 117 | 118 | optional arguments: 119 | -h, --help show this help message and exit 120 | --model MODEL Filepath of trained model (to continue training) 121 | [Default: ''] 122 | --usegpu Enables cuda to train on gpu [Default: False] 123 | --nepochs NEPOCHS Number of epochs to train for [Default: 600] 124 | --batchsize BATCHSIZE 125 | Batch size [Default: 2] 126 | --debug Activates debug mode [Default: False] 127 | --nworkers NWORKERS Number of workers for data loading (0 to do it using 128 | main process) [Default : 2] 129 | --dataset DATASET Name of the dataset which is "CVPPP" 130 | ``` 131 | 132 | Debug mode plots pixel embeddings to visdom; it reduces the embeddings to two dimensions using [TSNE](http://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html). Hence, it slows training down. 133 | 134 | As training continues, models are saved to `instance-segmentation-pytorch/models/CVPPP`. 135 | 136 | ## Evaluation 137 | 138 | After training is completed, we can make predictions and evaluate them. 139 | 140 | * Activate the previously created conda environment : `source activate ins-seg-pytorch` or `conda activate ins-seg-pytorch` 141 | 142 | * Go to `instance-segmentation-pytorch/code/`. 143 | * Run `pred_list.py`. 144 | 145 | ``` 146 | usage: pred_list.py [-h] --lst LST --model MODEL [--usegpu] 147 | [--n_workers N_WORKERS] --dataset DATASET 148 | 149 | optional arguments: 150 | -h, --help show this help message and exit 151 | --lst LST Text file that contains image paths 152 | --model MODEL Path of the model 153 | --usegpu Enables cuda to predict on gpu 154 | --dataset DATASET Name of the dataset which is "CVPPP" 155 | ``` 156 | 157 | For example: `python pred_list.py --lst ../data/metadata/CVPPP/validation_image_paths.txt --model ../models/CVPPP/2018-3-4_16-15_jcmaxwell_29-937494/model_155_0.123682662845.pth --usegpu --n_workers 4 --dataset CVPPP` 158 | 159 | * Predictions are written to the `outputs` directory. 160 | * After prediction is completed, we can run `evaluate.py`; it prints the output metrics to stdout. 161 | 162 | ``` 163 | usage: evaluate.py [-h] --pred_dir PRED_DIR --dataset DATASET 164 | 165 | optional arguments: 166 | -h, --help show this help message and exit 167 | --pred_dir PRED_DIR Prediction directory 168 | --dataset DATASET Name of the dataset which is "CVPPP" 169 | ``` 170 | 171 | For example: `python evaluate.py --pred_dir ../outputs/CVPPP/2018-3-4_16-15_jcmaxwell_29-937494-model_155_0.123682662845/validation/ --dataset CVPPP` 172 | 173 | ## Prediction 174 | 175 | After training is complete, we can use `pred.py` to make predictions for a single image. 176 | 177 | * Activate the previously created conda environment : `source activate ins-seg-pytorch` or `conda activate ins-seg-pytorch` 178 | 179 | * Go to `instance-segmentation-pytorch/code/`. 180 | * Run `pred.py`.
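Conceptually, `pred.py` follows the pipeline described above: the network outputs a foreground mask, an instance count and per-pixel embeddings, and the foreground embeddings are then grouped with KMeans. The snippet below is a minimal sketch of that clustering step, assuming hypothetical variables `embeddings` of shape `(H, W, D)`, a boolean `fg_mask` of shape `(H, W)` and a predicted `n_instances`; the actual logic lives in `lib/prediction.py`.

```python
import numpy as np
from sklearn.cluster import KMeans


def cluster_foreground(embeddings, fg_mask, n_instances):
    """Cluster foreground pixel embeddings into `n_instances` groups."""
    height, width, _ = embeddings.shape
    fg_embeddings = embeddings[fg_mask]  # (N_fg, D)

    labels = KMeans(n_clusters=n_instances).fit_predict(fg_embeddings)

    # Paint per-pixel instance ids; 0 is background, instances start at 1.
    instance_mask = np.zeros((height, width), dtype='uint8')
    instance_mask[fg_mask] = labels + 1

    return instance_mask
```

The command-line interface of `pred.py` is the following: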
181 | 182 | ``` 183 | usage: pred.py [-h] --image IMAGE --model MODEL [--usegpu] --output OUTPUT 184 | [--n_workers N_WORKERS] --dataset DATASET 185 | 186 | optional arguments: 187 | -h, --help show this help message and exit 188 | --image IMAGE Path of the image 189 | --model MODEL Path of the model 190 | --usegpu Enables cuda to predict on gpu 191 | --output OUTPUT Path of the output directory 192 | --dataset DATASET Name of the dataset which is "CVPPP" 193 | ``` 194 | 195 | ## Results 196 | 197 | ### CVPPP 198 | 199 | #### Scores on validation subset (28 images) 200 | 201 | | SBD | \|DiC\| | Foreground Dice | 202 | |:-------------:|:-------------:|:----------------:| 203 | | 87.9 | 0.5 | 96.8 | 204 | 205 | #### Sample Predictions 206 | 207 | ![plant007 image](samples/CVPPP/plant007_rgb.png) ![plant007 image](samples/CVPPP/plant007_rgb-ins_mask_color.png) ![plant007 image](samples/CVPPP/plant007_rgb-fg_mask.png) 208 | ![plant031 image](samples/CVPPP/plant031_rgb.png?raw=true "plant031 image") ![plant031 image](samples/CVPPP/plant031_rgb-ins_mask_color.png?raw=true "plant031 instance segmentation") ![plant031 image](samples/CVPPP/plant031_rgb-fg_mask.png?raw=true "plant031 foreground segmentation") 209 | ![plant033 image](samples/CVPPP/plant033_rgb.png?raw=true "plant033 image") ![plant033 image](samples/CVPPP/plant033_rgb-ins_mask_color.png?raw=true "plant033 instance segmentation") ![plant033 image](samples/CVPPP/plant033_rgb-fg_mask.png?raw=true "plant033 foreground segmentation") 210 | 211 | # References 212 | 213 | * [VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION](https://arxiv.org/abs/1409.1556) 214 | * [ReNet: A Recurrent Neural Network Based Alternative to Convolutional Networks](https://arxiv.org/abs/1505.00393) 215 | * [DELVING DEEPER INTO CONVOLUTIONAL NETWORKS FOR LEARNING VIDEO REPRESENTATIONS](https://arxiv.org/abs/1511.06432) 216 | * [ReSeg: A Recurrent Neural Network-based Model for Semantic Segmentation](https://arxiv.org/abs/1511.07053) 217 | * [Semantic Instance Segmentation with a Discriminative Loss Function](https://arxiv.org/abs/1708.02551) 218 | * [Instance Segmentation and Tracking with Cosine Embeddings and Recurrent Hourglass Networks](https://arxiv.org/abs/1806.02070) 219 | * [An intriguing failing of convolutional neural networks and the CoordConv solution](https://arxiv.org/abs/1807.03247) 220 | * [Leaf segmentation in plant phenotyping: A collation study](http://eprints.nottingham.ac.uk/34197/1/MVAP-D-15-00134_Revised_manuscript.pdf) 221 | -------------------------------------------------------------------------------- /code/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/code/README.md -------------------------------------------------------------------------------- /code/conda_environment.yml: -------------------------------------------------------------------------------- 1 | name: ins-seg-pytorch 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - backports=1.0=py27h63c9359_1 7 | - backports.functools_lru_cache=1.4=py27he8db605_1 8 | - backports.shutil_get_terminal_size=1.0.0=py27h5bc021e_2 9 | - backports_abc=0.5=py27h7b3c97b_0 10 | - bzip2=1.0.6=h6d464ef_2 11 | - ca-certificates=2017.08.26=h1d4fec5_0 12 | - cairo=1.14.10=hdf128ce_6 13 | - certifi=2017.11.5=py27h71e7faf_0 14 | - cffi=1.11.2=py27ha7929c6_0 15 | - cudatoolkit=8.0=3 16 | - cycler=0.10.0=py27hc7354d3_0 
17 | - dbus=1.12.2=hc3f9b76_1 18 | - decorator=4.1.2=py27h1544723_0 19 | - enum34=1.1.6=py27h99a27e9_1 20 | - expat=2.2.5=he0dffb1_0 21 | - ffmpeg=3.4=h7264315_0 22 | - fontconfig=2.12.4=h88586e7_1 23 | - freetype=2.8=hab7d2ae_1 24 | - functools32=3.2.3.2=py27h4ead58f_1 25 | - glib=2.53.6=h5d9569c_2 26 | - graphite2=1.3.10=hc526e54_0 27 | - gst-plugins-base=1.12.4=h33fb286_0 28 | - gstreamer=1.12.4=hb53b477_0 29 | - harfbuzz=1.5.0=h2545bd6_0 30 | - hdf5=1.10.1=h9caa474_1 31 | - icu=58.2=h9c2bf20_1 32 | - intel-openmp=2018.0.0=hc7b2577_8 33 | - ipython=5.4.1=py27h36c99b6_1 34 | - ipython_genutils=0.2.0=py27h89fb69b_0 35 | - jasper=1.900.1=hd497a04_4 36 | - jpeg=9b=h024ee3a_2 37 | - libedit=3.1=heed3624_0 38 | - libffi=3.2.1=hd88cf55_4 39 | - libgcc-ng=7.2.0=h7cc24e2_2 40 | - libgfortran-ng=7.2.0=h9f7466a_2 41 | - libopus=1.2.1=hb9ed12e_0 42 | - libpng=1.6.32=hbd3595f_4 43 | - libprotobuf=3.4.1=h5b8497f_0 44 | - libstdcxx-ng=7.2.0=h7a57d05_2 45 | - libtiff=4.0.9=h28f6b97_0 46 | - libvpx=1.6.1=h888fd40_0 47 | - libxcb=1.12=hcd93eb1_4 48 | - libxml2=2.9.4=h2e8b1d7_6 49 | - matplotlib=2.1.2=py27h0e671d2_0 50 | - mkl=2018.0.1=h19d6760_4 51 | - ncurses=6.0=h9df7e31_2 52 | - ninja=1.8.2=py27h6bb024c_1 53 | - numpy=1.13.3=py27hbcc08e0_0 54 | - olefile=0.44=py27h4bd3e3c_0 55 | - opencv=3.3.1=py27h6cbbc71_1 56 | - openssl=1.0.2n=hb7f436b_0 57 | - pathlib2=2.3.0=py27h6e9d198_0 58 | - pcre=8.41=hc27e229_1 59 | - pexpect=4.3.0=py27hdeba8d9_0 60 | - pickleshare=0.7.4=py27h09770e1_0 61 | - pillow=4.3.0=py27h353bd0c_1 62 | - pip=9.0.1=py27ha730c48_4 63 | - pixman=0.34.0=hceecf20_3 64 | - prompt_toolkit=1.0.15=py27h1b593e1_0 65 | - ptyprocess=0.5.2=py27h4ccb14c_0 66 | - pycparser=2.18=py27hefa08c5_1 67 | - pygments=2.2.0=py27h4a8b6f5_0 68 | - pyparsing=2.2.0=py27hf1513f8_1 69 | - pyqt=5.6.0=py27h4b1e83c_5 70 | - python=2.7.14=h1571d57_29 71 | - python-dateutil=2.6.1=py27h4ca5741_1 72 | - pytz=2017.3=py27h001bace_0 73 | - qt=5.6.2=h974d657_12 74 | - readline=7.0=ha6073c6_4 75 | - scandir=1.6=py27hf7388dc_0 76 | - scikit-learn=0.19.1=py27h445a80a_0 77 | - scipy=1.0.0=py27hf5f0f52_0 78 | - setuptools=36.5.0=py27h68b189e_0 79 | - simplegeneric=0.8.1=py27h19e43cd_0 80 | - singledispatch=3.4.0.3=py27h9bcb476_0 81 | - sip=4.18.1=py27he9ba0ab_2 82 | - six=1.11.0=py27h5f960f1_1 83 | - sqlite=3.20.1=hb898158_2 84 | - ssl_match_hostname=3.5.0.1=py27h4ec10b9_2 85 | - subprocess32=3.2.7=py27h373dbce_0 86 | - tk=8.6.7=hc745277_3 87 | - tornado=4.5.3=py27_0 88 | - traitlets=4.3.2=py27hd6ce930_0 89 | - wcwidth=0.1.7=py27h9e3e1ab_0 90 | - wheel=0.30.0=py27h2bc6bb2_1 91 | - xz=5.2.3=h55aa19d_2 92 | - zlib=1.2.11=ha838bed_2 93 | - pytorch=0.4.0=py27_cuda8.0.61_cudnn7.1.2_1 94 | - torchvision=0.2.1=py27_1 95 | - pip: 96 | - backports.ssl-match-hostname==3.5.0.1 97 | - chardet==3.0.4 98 | - convertdate==2.1.1 99 | - dateparser==0.6.0 100 | - ephem==3.7.6.0 101 | - idna==2.6 102 | - iso3166==0.8 103 | - lmdb==0.93 104 | - pathlib==1.0.1 105 | - pylibdmtx==0.1.7 106 | - pyyaml==3.12 107 | - pyzmq==16.0.4 108 | - regex==2018.7.11 109 | - requests==2.18.4 110 | - ruamel.ordereddict==0.4.13 111 | - ruamel.yaml==0.15.26 112 | - torch==0.4.0 113 | - torchfile==0.1.0 114 | - tzlocal==1.4 115 | - umalqurra==0.2 116 | - urllib3==1.22 117 | - visdom==0.1.7 118 | - xlrd==1.0.0 119 | prefix: /home/jcmaxwell/Libraries/miniconda2/envs/ins-seg-pytorch 120 | 121 | -------------------------------------------------------------------------------- /code/evaluate.py: -------------------------------------------------------------------------------- 1 
| import argparse 2 | import numpy as np 3 | from PIL import Image 4 | import os 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--pred_dir', required=True, help='Prediction directory') 8 | parser.add_argument('--dataset', type=str, 9 | help='Name of the dataset which is "CVPPP"', 10 | required=True) 11 | opt = parser.parse_args() 12 | 13 | assert opt.dataset in ['CVPPP', ] 14 | 15 | pred_dir = opt.pred_dir 16 | 17 | 18 | def calc_dic(n_objects_gt, n_objects_pred): 19 | return np.abs(n_objects_gt - n_objects_pred) 20 | 21 | 22 | def calc_dice(gt_seg, pred_seg): 23 | 24 | nom = 2 * np.sum(gt_seg * pred_seg) 25 | denom = np.sum(gt_seg) + np.sum(pred_seg) 26 | 27 | dice = float(nom) / float(denom) 28 | return dice 29 | 30 | 31 | def calc_bd(ins_seg_gt, ins_seg_pred): 32 | 33 | gt_object_idxes = list(set(np.unique(ins_seg_gt)).difference([0])) 34 | pred_object_idxes = list(set(np.unique(ins_seg_pred)).difference([0])) 35 | 36 | best_dices = [] 37 | for gt_idx in gt_object_idxes: 38 | _gt_seg = (ins_seg_gt == gt_idx).astype('bool') 39 | dices = [] 40 | for pred_idx in pred_object_idxes: 41 | _pred_seg = (ins_seg_pred == pred_idx).astype('bool') 42 | 43 | dice = calc_dice(_gt_seg, _pred_seg) 44 | dices.append(dice) 45 | best_dice = np.max(dices) 46 | best_dices.append(best_dice) 47 | 48 | best_dice = np.mean(best_dices) 49 | 50 | return best_dice 51 | 52 | 53 | def calc_sbd(ins_seg_gt, ins_seg_pred): 54 | 55 | _dice1 = calc_bd(ins_seg_gt, ins_seg_pred) 56 | _dice2 = calc_bd(ins_seg_pred, ins_seg_gt) 57 | return min(_dice1, _dice2) 58 | 59 | 60 | if opt.dataset == 'CVPPP': 61 | names = np.loadtxt('../data/metadata/CVPPP/validation_image_paths.txt', 62 | dtype='str', delimiter=',') 63 | names = np.array([os.path.splitext(os.path.basename(n))[0] for n in names]) 64 | n_objects_gts = np.loadtxt( 65 | '../data/metadata/CVPPP/number_of_instances.txt', 66 | dtype='str', 67 | delimiter=',') 68 | img_dir = '../data/raw/CVPPP/CVPPP2017_LSC_training/training/A1' 69 | 70 | dics, sbds, fg_dices = [], [], [] 71 | for name in names: 72 | if not os.path.isfile( 73 | '{}/{}/{}-n_objects.npy'.format(pred_dir, name, name)): 74 | continue 75 | 76 | n_objects_gt = int(n_objects_gts[n_objects_gts[:, 0] == name.replace('_rgb', '')][0][1]) 77 | n_objects_pred = np.load( 78 | '{}/{}/{}-n_objects.npy'.format(pred_dir, name, name)) 79 | 80 | ins_seg_gt = np.array(Image.open( 81 | os.path.join(img_dir, name.replace('_rgb', '') + '_label.png'))) 82 | ins_seg_pred = np.array(Image.open(os.path.join( 83 | pred_dir, name, name + '-ins_mask.png'))) 84 | 85 | fg_seg_gt = np.array( 86 | Image.open( 87 | os.path.join( 88 | img_dir, 89 | name.replace('_rgb', '') + 90 | '_fg.png'))) 91 | fg_seg_pred = np.array(Image.open(os.path.join( 92 | pred_dir, name, name + '-fg_mask.png'))) 93 | 94 | fg_seg_gt = (fg_seg_gt == 1).astype('bool') 95 | fg_seg_pred = (fg_seg_pred == 255).astype('bool') 96 | 97 | sbd = calc_sbd(ins_seg_gt, ins_seg_pred) 98 | sbds.append(sbd) 99 | 100 | dic = calc_dic(n_objects_gt, n_objects_pred) 101 | dics.append(dic) 102 | 103 | fg_dice = calc_dice(fg_seg_gt, fg_seg_pred) 104 | fg_dices.append(fg_dice) 105 | 106 | mean_dic = np.mean(dics) 107 | mean_sbd = np.mean(sbds) 108 | mean_fg_dice = np.mean(fg_dices) 109 | 110 | print 'MEAN SBD : ', mean_sbd 111 | print 'MEAN |DIC| : ', mean_dic 112 | print 'MEAN FG DICE : ', mean_fg_dice 113 | -------------------------------------------------------------------------------- /code/lib/__init__.py: 
-------------------------------------------------------------------------------- 1 | from dataset import SegDataset, AlignCollate 2 | from model import Model 3 | from prediction import Prediction 4 | -------------------------------------------------------------------------------- /code/lib/archs/README.md: -------------------------------------------------------------------------------- 1 | # Architectures 2 | 3 | ## ReSeg 4 | 5 | * Proposed in [ReSeg: A Recurrent Neural Network-based Model for Semantic Segmentation](https://arxiv.org/pdf/1511.07053.pdf) 6 | * Can be found at `reseg.py` 7 | 8 | ``` 9 | ReSeg Module (with modifications) as defined in 'ReSeg: A Recurrent 10 | Neural Network-based Model for Semantic Segmentation' 11 | (https://arxiv.org/pdf/1511.07053.pdf). 12 | 13 | * VGG16 with skip Connections as base network 14 | * Two ReNet layers 15 | * Two transposed convolutional layers for upsampling 16 | * Three heads for semantic segmentation, instance segmentation and 17 | instance counting. 18 | 19 | Args: 20 | n_classes (int): Number of semantic classes 21 | use_instance_seg (bool, optional): If `False`, does not perform 22 | instance segmentation. Default: `True` 23 | pretrained (bool, optional): If `True`, initializes weights of the 24 | VGG16 using weights trained on ImageNet. Default: `True` 25 | use_coordinates (bool, optional): If `True`, adds coordinate 26 | information to input image and hidden state. Default: `False` 27 | usegpu (bool, optional): If `True`, runs operations on GPU 28 | Default: `True` 29 | 30 | Shape: 31 | - Input: `(N, C_{in}, H_{in}, W_{in})` 32 | - Output: 33 | - Semantic Seg: `(N, N_{class}, H_{in}, W_{in})` 34 | - Instance Seg: `(N, 32, H_{in}, W_{in})` 35 | - Instance Cnt: `(N, 1)` 36 | 37 | Examples: 38 | >>> reseg = ReSeg(3, True, True, True, False) 39 | >>> input = torch.randn(8, 3, 64, 64) 40 | >>> outputs = reseg(input) 41 | 42 | >>> reseg = ReSeg(3, True, True, True, True).cuda() 43 | >>> input = torch.randn(8, 3, 64, 64).cuda() 44 | >>> outputs = reseg(input) 45 | ``` 46 | 47 | ## Stacked Recurrent Hourglass 48 | 49 | * Proposed in [Instance Segmentation and Tracking with Cosine Embeddings and Recurrent Hourglass Networks](https://arxiv.org/pdf/1806.02070.pdf) 50 | * Can be found at `stacked_recurrent_hourglass.py` 51 | 52 | ``` 53 | Stacked Recurrent Hourglass Module for instance segmentation 54 | as defined in 'Instance Segmentation and Tracking with Cosine 55 | Embeddings and Recurrent Hourglass Networks' 56 | (https://arxiv.org/pdf/1806.02070.pdf). 57 | 58 | * First four layers of VGG16 59 | * Two RecurrentHourglass layers 60 | * Two ReNet layers 61 | * Two transposed convolutional layers for upsampling 62 | * Three heads for semantic segmentation, instance segmentation and 63 | instance counting. 64 | 65 | Args: 66 | n_classes (int): Number of semantic classes 67 | use_instance_seg (bool, optional): If `False`, does not perform 68 | instance segmentation. Default: `True` 69 | pretrained (bool, optional): If `True`, initializes weights of the 70 | VGG16 using weights trained on ImageNet. Default: `True` 71 | use_coordinates (bool, optional): If `True`, adds coordinate 72 | information to input image and hidden state. 
Default: `False` 73 | usegpu (bool, optional): If `True`, runs operations on GPU 74 | Default: `True` 75 | 76 | Shape: 77 | - Input: `(N, C_{in}, H_{in}, W_{in})` 78 | - Output: 79 | - Semantic Seg: `(N, N_{class}, H_{in}, W_{in})` 80 | - Instance Seg: `(N, 32, H_{in}, W_{in})` 81 | - Instance Cnt: `(N, 1)` 82 | 83 | Examples: 84 | >>> srhg = StackedRecurrentHourglass(4, True, True, True, False) 85 | >>> input = torch.randn(8, 3, 64, 64) 86 | >>> outputs = srhg(input) 87 | 88 | >>> srhg = StackedRecurrentHourglass(4, True, True, True, True) 89 | >>> srhg = srhg.cuda() 90 | >>> input = torch.randn(8, 3, 64, 64).cuda() 91 | >>> outputs = srhg(input) 92 | ``` 93 | -------------------------------------------------------------------------------- /code/lib/archs/__init__.py: -------------------------------------------------------------------------------- 1 | from stacked_recurrent_hourglass import StackedRecurrentHourglass 2 | from reseg import ReSeg 3 | -------------------------------------------------------------------------------- /code/lib/archs/instance_counter.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from modules.coord_conv import CoordConvNet 3 | 4 | 5 | class InstanceCounter(nn.Module): 6 | 7 | r"""Instance Counter Module. Basically, it is a convolutional network 8 | to count instances for a given feature map. 9 | 10 | Args: 11 | input_n_filters (int): Number of channels in the input image 12 | use_coordinates (bool, optional): If `True`, adds coordinate 13 | information to input image and hidden state. Default: `False` 14 | usegpu (bool, optional): If `True`, runs operations on GPU 15 | Default: `True` 16 | 17 | Shape: 18 | - Input: `(N, C_{in}, H_{in}, W_{in})` 19 | - Output: `(N, 1)` 20 | 21 | Examples: 22 | >>> ins_cnt = InstanceCounter(3, True, False) 23 | >>> input = torch.randn(8, 3, 64, 64) 24 | >>> output = ins_cnt(input) 25 | 26 | >>> ins_cnt = InstanceCounter(3, True, True).cuda() 27 | >>> input = torch.randn(8, 3, 64, 64).cuda() 28 | >>> output = ins_cnt(input) 29 | """ 30 | 31 | def __init__(self, input_n_filters, use_coordinates=False, 32 | usegpu=True): 33 | super(InstanceCounter, self).__init__() 34 | 35 | self.input_n_filters = input_n_filters 36 | self.n_filters = 32 37 | self.use_coordinates = use_coordinates 38 | self.usegpu = usegpu 39 | 40 | self.__generate_cnn() 41 | 42 | self.output = nn.Sequential() 43 | self.output.add_module('linear', nn.Linear(self.n_filters, 44 | 1)) 45 | self.output.add_module('sigmoid', nn.Sigmoid()) 46 | 47 | def __generate_cnn(self): 48 | 49 | self.cnn = nn.Sequential() 50 | self.cnn.add_module('pool1', nn.MaxPool2d(2, stride=2)) 51 | self.cnn.add_module('conv1', nn.Conv2d(self.input_n_filters, 52 | self.n_filters, 53 | kernel_size=(3, 3), 54 | stride=(1, 1), 55 | padding=(1, 1))) 56 | self.cnn.add_module('relu1', nn.ReLU()) 57 | self.cnn.add_module('conv2', nn.Conv2d(self.n_filters, 58 | self.n_filters, 59 | kernel_size=(3, 3), 60 | stride=(1, 1), 61 | padding=(1, 1))) 62 | self.cnn.add_module('relu2', nn.ReLU()) 63 | self.cnn.add_module('pool2', nn.MaxPool2d(2, stride=2)) 64 | self.cnn.add_module('conv3', nn.Conv2d(self.n_filters, 65 | self.n_filters, 66 | kernel_size=(3, 3), 67 | stride=(1, 1), 68 | padding=(1, 1))) 69 | self.cnn.add_module('relu3', nn.ReLU()) 70 | self.cnn.add_module('conv4', nn.Conv2d(self.n_filters, 71 | self.n_filters, 72 | kernel_size=(3, 3), 73 | stride=(1, 1), 74 | padding=(1, 1))) 75 | self.cnn.add_module('relu4', nn.ReLU()) 76 | 
self.cnn.add_module('pool3', nn.AdaptiveAvgPool2d((1, 1))) 77 | # b, nf, 1, 1 78 | 79 | if self.use_coordinates: 80 | self.cnn = CoordConvNet(self.cnn, with_r=True, 81 | usegpu=self.usegpu) 82 | 83 | def forward(self, x): 84 | 85 | x = self.cnn(x) 86 | if self.use_coordinates: 87 | x = x[-1] 88 | x = x.squeeze(3).squeeze(2) 89 | x = self.output(x) 90 | 91 | return x 92 | -------------------------------------------------------------------------------- /code/lib/archs/modules/README.md: -------------------------------------------------------------------------------- 1 | # Modules 2 | 3 | ## ConvGRUCell 4 | 5 | * Proposed in [Delving Deeper into Convolutional Networks for Learning Video Representations](https://arxiv.org/pdf/1511.06432.pdf) 6 | * Can be found at `conv_gru.py` 7 | 8 | ``` 9 | Convolutional GRU Module as defined in 'Delving Deeper into 10 | Convolutional Networks for Learning Video Representations' 11 | (https://arxiv.org/pdf/1511.06432.pdf). 12 | 13 | Args: 14 | input_size (int): Number of channels in the input image 15 | hidden_size (int): Number of channels produced by the ConvGRU 16 | kernel_size (int or tuple): Size of the convolving kernel 17 | use_coordinates (bool, optional): If `True`, adds coordinate 18 | information to input image and hidden state. Default: `False` 19 | usegpu (bool, optional): If `True`, runs operations on GPU 20 | Default: `True` 21 | 22 | Shape: 23 | - Input: 24 | - `x` : `(N, C_{in}, H_{in}, W_{in})` 25 | - `hidden` : `(N, C_{out}, H_{in}, W_{in})` or `None` 26 | - Output: `next_hidden` : `(N, C_{out}, H_{in}, W_{in})` 27 | 28 | Examples: 29 | >>> n_hidden = 16 30 | >>> conv_gru = ConvGRUCell(3, n_hidden, 3, True, False) 31 | >>> input = torch.randn(8, 3, 64, 64) 32 | >>> hidden = torch.rand(8, n_hidden, 64, 64) 33 | >>> output = conv_gru(input, None) 34 | >>> output = conv_gru(input, hidden) 35 | 36 | >>> n_hidden = 16 37 | >>> conv_gru = ConvGRUCell(3, n_hidden, 3, True, True).cuda() 38 | >>> input = torch.randn(8, 3, 64, 64).cuda() 39 | >>> hidden = torch.rand(8, n_hidden, 64, 64).cuda() 40 | >>> output = conv_gru(input, None) 41 | >>> output = conv_gru(input, hidden) 42 | ``` 43 | 44 | ## AddCoordinates 45 | 46 | * Proposed in [An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution](https://arxiv.org/pdf/1807.03247.pdf) 47 | * Can be found at `coord_conv.py` 48 | 49 | ``` 50 | Coordinate Adder Module as defined in 'An Intriguing Failing of 51 | Convolutional Neural Networks and the CoordConv Solution' 52 | (https://arxiv.org/pdf/1807.03247.pdf). 53 | 54 | This module concatenates coordinate information (`x`, `y`, and `r`) with 55 | given input tensor. 56 | 57 | `x` and `y` coordinates are scaled to `[-1, 1]` range where origin is the 58 | center. `r` is the Euclidean distance from the center and is scaled to 59 | `[0, 1]`. 60 | 61 | Args: 62 | with_r (bool, optional): If `True`, adds radius (`r`) coordinate 63 | information to input image. 
Default: `False` 64 | usegpu (bool, optional): If `True`, runs operations on GPU 65 | Default: `True` 66 | 67 | Shape: 68 | - Input: `(N, C_{in}, H_{in}, W_{in})` 69 | - Output: `(N, (C_{in} + 2) or (C_{in} + 3), H_{in}, W_{in})` 70 | 71 | Examples: 72 | >>> coord_adder = AddCoordinates(True, False) 73 | >>> input = torch.randn(8, 3, 64, 64) 74 | >>> output = coord_adder(input) 75 | 76 | >>> coord_adder = AddCoordinates(True, True).cuda() 77 | >>> input = torch.randn(8, 3, 64, 64).cuda() 78 | >>> output = coord_adder(input) 79 | ``` 80 | 81 | ## CoordConv 82 | 83 | * Proposed in [An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution](https://arxiv.org/pdf/1807.03247.pdf) 84 | * Can be found at `coord_conv.py` 85 | 86 | ``` 87 | 2D Convolution Module Using Extra Coordinate Information as defined 88 | in 'An Intriguing Failing of Convolutional Neural Networks and the 89 | CoordConv Solution' (https://arxiv.org/pdf/1807.03247.pdf). 90 | 91 | Args: 92 | Same as `torch.nn.Conv2d` with two additional arguments 93 | with_r (bool, optional): If `True`, adds radius (`r`) coordinate 94 | information to input image. Default: `False` 95 | usegpu (bool, optional): If `True`, runs operations on GPU 96 | Default: `True` 97 | 98 | Shape: 99 | - Input: `(N, C_{in}, H_{in}, W_{in})` 100 | - Output: `(N, C_{out}, H_{out}, W_{out})` 101 | 102 | Examples: 103 | >>> coord_conv = CoordConv(3, 16, 3, with_r=True, usegpu=False) 104 | >>> input = torch.randn(8, 3, 64, 64) 105 | >>> output = coord_conv(input) 106 | 107 | >>> coord_conv = CoordConv(3, 16, 3, with_r=True, usegpu=True).cuda() 108 | >>> input = torch.randn(8, 3, 64, 64).cuda() 109 | >>> output = coord_conv(input) 110 | ``` 111 | 112 | ## CoordConvTranspose 113 | 114 | * Proposed in [An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution](https://arxiv.org/pdf/1807.03247.pdf) 115 | * Can be found at `coord_conv.py` 116 | 117 | ``` 118 | 2D Transposed Convolution Module Using Extra Coordinate Information 119 | as defined in 'An Intriguing Failing of Convolutional Neural Networks and 120 | the CoordConv Solution' (https://arxiv.org/pdf/1807.03247.pdf). 121 | 122 | Args: 123 | Same as `torch.nn.ConvTranspose2d` with two additional arguments 124 | with_r (bool, optional): If `True`, adds radius (`r`) coordinate 125 | information to input image. Default: `False` 126 | usegpu (bool, optional): If `True`, runs operations on GPU 127 | Default: `True` 128 | 129 | Shape: 130 | - Input: `(N, C_{in}, H_{in}, W_{in})` 131 | - Output: `(N, C_{out}, H_{out}, W_{out})` 132 | 133 | Examples: 134 | >>> coord_conv_tr = CoordConvTranspose(3, 16, 3, with_r=True, 135 | >>> usegpu=False) 136 | >>> input = torch.randn(8, 3, 64, 64) 137 | >>> output = coord_conv_tr(input) 138 | 139 | >>> coord_conv_tr = CoordConvTranspose(3, 16, 3, with_r=True, 140 | >>> usegpu=True).cuda() 141 | >>> input = torch.randn(8, 3, 64, 64).cuda() 142 | >>> output = coord_conv_tr(input) 143 | ``` 144 | 145 | ## CoordConvNet 146 | 147 | * Proposed in [An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution](https://arxiv.org/pdf/1807.03247.pdf) 148 | * Can be found at `coord_conv.py` 149 | 150 | ``` 151 | Improves 2D Convolutions inside a ConvNet by processing extra 152 | coordinate information as defined in 'An Intriguing Failing of 153 | Convolutional Neural Networks and the CoordConv Solution' 154 | (https://arxiv.org/pdf/1807.03247.pdf). 
155 | 156 | This module adds coordinate information to inputs of each 2D convolution 157 | module (`torch.nn.Conv2d`). 158 | 159 | Assumption: ConvNet Model must contain single `Sequential` container 160 | (`torch.nn.modules.container.Sequential`). 161 | 162 | Args: 163 | cnn_model: A ConvNet model that must contain single `Sequential` 164 | container (`torch.nn.modules.container.Sequential`). 165 | with_r (bool, optional): If `True`, adds radius (`r`) coordinate 166 | information to input image. Default: `False` 167 | usegpu (bool, optional): If `True`, runs operations on GPU 168 | Default: `True` 169 | 170 | Shape: 171 | - Input: Same as the input of the model. 172 | - Output: A list that contains all outputs (including 173 | intermediate outputs) of the model. 174 | 175 | Examples: 176 | >>> cnn_model = ... 177 | >>> cnn_model = CoordConvNet(cnn_model, True, False) 178 | >>> input = torch.randn(8, 3, 64, 64) 179 | >>> outputs = cnn_model(input) 180 | 181 | >>> cnn_model = ... 182 | >>> cnn_model = CoordConvNet(cnn_model, True, True).cuda() 183 | >>> input = torch.randn(8, 3, 64, 64).cuda() 184 | >>> outputs = cnn_model(input) 185 | ``` 186 | 187 | ## RecurrentHourglass 188 | 189 | * Proposed in [Instance Segmentation and Tracking with Cosine Embeddings and Recurrent Hourglass Networks](https://arxiv.org/pdf/1806.02070.pdf) 190 | * Can be found at `recurrent_hourglass.py` 191 | 192 | ``` 193 | RecurrentHourglass Module as defined in 194 | 'Instance Segmentation and Tracking with Cosine Embeddings and Recurrent 195 | Hourglass Networks' (https://arxiv.org/pdf/1806.02070.pdf). 196 | 197 | Args: 198 | input_n_filters (int): Number of channels in the input image 199 | hidden_n_filters (int): Number of channels produced by Convolutional 200 | GRU module 201 | kernel_size (int or tuple): Size of the convolving kernels 202 | n_levels (int): Number of timesteps to unroll Convolutional GRU 203 | module 204 | embedding_size (int): Number of channels produced by Recurrent 205 | Hourglass module 206 | use_coordinates (bool, optional): If `True`, adds coordinate 207 | information to input image and hidden state. Default: `False` 208 | usegpu (bool, optional): If `True`, runs operations on GPU 209 | Default: `True` 210 | 211 | Shape: 212 | - Input: `(N, C_{in}, H_{in}, W_{in})` 213 | - Output: `(N, C_{out}, H_{in}, W_{in})` 214 | 215 | Examples: 216 | >>> hg = RecurrentHourglass(3, 16, 3, 5, 32, True, False) 217 | >>> input = torch.randn(8, 3, 64, 64) 218 | >>> output = hg(input) 219 | 220 | >>> hg = RecurrentHourglass(3, 16, 3, 5, 32, True, True).cuda() 221 | >>> input = torch.randn(8, 3, 64, 64).cuda() 222 | >>> output = hg(input) 223 | ``` 224 | 225 | ## ReNet 226 | 227 | * Proposed in [ReNet: A Recurrent Neural Network Based Alternative to Convolutional Networks](https://arxiv.org/pdf/1505.00393.pdf) 228 | * Can be found at `renet.py` 229 | 230 | ``` 231 | ReNet Module as defined in 'ReNet: A Recurrent Neural 232 | Network Based Alternative to Convolutional Networks' 233 | (https://arxiv.org/pdf/1505.00393.pdf). 234 | 235 | Args: 236 | n_input (int): Number of channels in the input image 237 | n_units (int): Number of channels produced by ReNet 238 | patch_size (tuple): Patch size in the input of ReNet 239 | use_coordinates (bool, optional): If `True`, adds coordinate 240 | information to input image and hidden state. 
Default: `False` 241 | usegpu (bool, optional): If `True`, runs operations on GPU 242 | Default: `True` 243 | 244 | Shape: 245 | - Input: `(N, C_{in}, H_{in}, W_{in})` 246 | - Output: `(N, C_{out}, H_{out}, W_{out})` 247 | 248 | Examples: 249 | >>> renet = ReNet(3, 16, (2, 2), True, False) 250 | >>> input = torch.randn(8, 3, 64, 64) 251 | >>> output = renet(input) 252 | 253 | >>> renet = ReNet(3, 16, (2, 2), True, True).cuda() 254 | >>> input = torch.randn(8, 3, 64, 64).cuda() 255 | >>> output = renet(input) 256 | ``` 257 | 258 | ## VGG16 259 | 260 | * Proposed in [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/pdf/1409.1556.pdf) 261 | * Can be found at `vgg16.py` 262 | 263 | ``` 264 | A module that augments VGG16 as defined in 'Very Deep Convolutional 265 | Networks for Large-Scale Image Recognition' 266 | (https://arxiv.org/pdf/1409.1556.pdf). 267 | 268 | 1. It can return first `n_layers` of the VGG16. 269 | 2. It can add coordinates to feature maps prior to each convolution. 270 | 3. It can return all outputs (including intermediate outputs) of the 271 | VGG16. 272 | 273 | Args: 274 | n_layers (int): Use first `n_layers` layers of the VGG16 275 | pretrained (bool, optional): If `True`, initializes weights of the 276 | VGG16 using weights trained on ImageNet. Default: `True` 277 | use_coordinates (bool, optional): If `True`, adds `x`, `y` and radius 278 | (`r`) coordinates to feature maps prior to each convolution. 279 | Weights to process these coordinates are initialized as zero. 280 | Default: `False` 281 | return_intermediate_outputs (bool, optional): If `True`, return 282 | outputs of the each layer in the VGG16 as a list otherwise 283 | return output of the last layer of first `n_layers` layers 284 | of the VGG16. Default: `False` 285 | usegpu (bool, optional): If `True`, runs operations on GPU 286 | Default: `True` 287 | 288 | Shape: 289 | - Input: `(N, C_{in}, H_{in}, W_{in})` 290 | - Output: Output of the last layer of the selected subpart of VGG16 291 | or the list that contains outputs of the each layer depending on 292 | `return_intermediate_outputs` 293 | 294 | Examples: 295 | >>> vgg16 = VGG16(16, True, True, True, False) 296 | >>> input = torch.randn(8, 3, 64, 64) 297 | >>> output = vgg16(input) 298 | 299 | >>> vgg16 = VGG16(16, True, True, True, True).cuda() 300 | >>> input = torch.randn(8, 3, 64, 64).cuda() 301 | >>> output = vgg16(input) 302 | ``` 303 | 304 | ## SkipVGG16 305 | 306 | * Proposed in [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/pdf/1409.1556.pdf) 307 | * Can be found at `vgg16.py` 308 | 309 | ``` 310 | A module that returns output of 7th convolutional layer of the 311 | VGG16 along with outputs of the 2nd and 4th convolutional layers. 312 | 313 | Args: 314 | pretrained (bool, optional): If `True`, initializes weights of the 315 | VGG16 using weights trained on ImageNet. Default: `True` 316 | use_coordinates (bool, optional): If `True`, adds `x`, `y` and radius 317 | (`r`) coordinates to feature maps prior to each convolution. 318 | Weights to process these coordinates are initialized as zero. 319 | Default: `False` 320 | usegpu (bool, optional): If `True`, runs operations on GPU 321 | Default: `True` 322 | 323 | Shape: 324 | - Input: `(N, C_{in}, H_{in}, W_{in})` 325 | - Output: List of outputs of the 2nd, 4th and 7th convolutional 326 | layers of the VGG16, respectively. 
327 | 328 | Examples: 329 | >>> vgg16 = SkipVGG16(True, True, False) 330 | >>> input = torch.randn(8, 3, 64, 64) 331 | >>> output = vgg16(input) 332 | 333 | >>> vgg16 = SkipVGG16(True, True, True).cuda() 334 | >>> input = torch.randn(8, 3, 64, 64).cuda() 335 | >>> output = vgg16(input) 336 | ``` 337 | -------------------------------------------------------------------------------- /code/lib/archs/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/code/lib/archs/modules/__init__.py -------------------------------------------------------------------------------- /code/lib/archs/modules/conv_gru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from coord_conv import CoordConv 6 | 7 | # Adapted from: https://github.com/bionick87/ConvGRUCell-pytorch 8 | 9 | 10 | class ConvGRUCell(nn.Module): 11 | 12 | r"""Convolutional GRU Module as defined in 'Delving Deeper into 13 | Convolutional Networks for Learning Video Representations' 14 | (https://arxiv.org/pdf/1511.06432.pdf). 15 | 16 | Args: 17 | input_size (int): Number of channels in the input image 18 | hidden_size (int): Number of channels produced by the ConvGRU 19 | kernel_size (int or tuple): Size of the convolving kernel 20 | use_coordinates (bool, optional): If `True`, adds coordinate 21 | information to input image and hidden state. Default: `False` 22 | usegpu (bool, optional): If `True`, runs operations on GPU 23 | Default: `True` 24 | 25 | Shape: 26 | - Input: 27 | - `x` : `(N, C_{in}, H_{in}, W_{in})` 28 | - `hidden` : `(N, C_{out}, H_{in}, W_{in})` or `None` 29 | - Output: `next_hidden` : `(N, C_{out}, H_{in}, W_{in})` 30 | 31 | Examples: 32 | >>> n_hidden = 16 33 | >>> conv_gru = ConvGRUCell(3, n_hidden, 3, True, False) 34 | >>> input = torch.randn(8, 3, 64, 64) 35 | >>> hidden = torch.rand(8, n_hidden, 64, 64) 36 | >>> output = conv_gru(input, None) 37 | >>> output = conv_gru(input, hidden) 38 | 39 | >>> n_hidden = 16 40 | >>> conv_gru = ConvGRUCell(3, n_hidden, 3, True, True).cuda() 41 | >>> input = torch.randn(8, 3, 64, 64).cuda() 42 | >>> hidden = torch.rand(8, n_hidden, 64, 64).cuda() 43 | >>> output = conv_gru(input, None) 44 | >>> output = conv_gru(input, hidden) 45 | """ 46 | 47 | def __init__(self, input_size, hidden_size, kernel_size, 48 | use_coordinates=False, usegpu=True): 49 | super(ConvGRUCell, self).__init__() 50 | 51 | self.input_size = input_size 52 | self.hidden_size = hidden_size 53 | self.kernel_size = kernel_size 54 | self.use_coordinates = use_coordinates 55 | self.usegpu = usegpu 56 | 57 | _n_inputs = self.input_size + self.hidden_size 58 | if self.use_coordinates: 59 | self.conv_gates = CoordConv(_n_inputs, 60 | 2 * self.hidden_size, 61 | self.kernel_size, 62 | padding=self.kernel_size // 2, 63 | with_r=True, 64 | usegpu=self.usegpu) 65 | 66 | self.conv_ct = CoordConv(_n_inputs, self.hidden_size, 67 | self.kernel_size, 68 | padding=self.kernel_size // 2, 69 | with_r=True, 70 | usegpu=self.usegpu) 71 | else: 72 | self.conv_gates = nn.Conv2d(_n_inputs, 73 | 2 * self.hidden_size, 74 | self.kernel_size, 75 | padding=self.kernel_size // 2) 76 | 77 | self.conv_ct = nn.Conv2d(_n_inputs, self.hidden_size, 78 | self.kernel_size, 79 | padding=self.kernel_size // 2) 80 | 81 | def forward(self, x, 
hidden): 82 | 83 | batch_size, _, height, width = x.size() 84 | 85 | if hidden is None: 86 | size_h = [batch_size, self.hidden_size, height, width] 87 | hidden = Variable(torch.zeros(size_h)) 88 | 89 | if self.usegpu: 90 | hidden = hidden.cuda() 91 | 92 | c1 = self.conv_gates(torch.cat((x, hidden), dim=1)) 93 | rt, ut = c1.chunk(2, 1) 94 | 95 | reset_gate = F.sigmoid(rt) 96 | update_gate = F.sigmoid(ut) 97 | 98 | gated_hidden = torch.mul(reset_gate, hidden) 99 | 100 | ct = F.tanh(self.conv_ct(torch.cat((x, gated_hidden), dim=1))) 101 | 102 | next_h = torch.mul(update_gate, hidden) + (1 - update_gate) * ct 103 | 104 | return next_h 105 | 106 | 107 | if __name__ == '__main__': 108 | def test(use_coordinates, usegpu): 109 | n_timesteps, batch_size, n_channels = 3, 8, 3 110 | hidden_size, kernel_size, image_size = 64, 3, 32 111 | max_epoch = 10 112 | 113 | model = ConvGRUCell(n_channels, hidden_size, kernel_size, 114 | use_coordinates, usegpu) 115 | 116 | input = Variable(torch.rand(batch_size, n_channels, 117 | image_size, image_size)) 118 | hidden = Variable(torch.rand(batch_size, hidden_size, 119 | image_size, image_size)) 120 | 121 | if usegpu: 122 | model = model.cuda() 123 | input = input.cuda() 124 | hidden = hidden.cuda() 125 | 126 | print '\n* Model :\n\n', model 127 | 128 | out1 = model(input, None) 129 | out2 = model(input, hidden) 130 | 131 | print '\n* Success!' 132 | 133 | print '\n### CPU without coordinates ###' 134 | test(False, False) 135 | print '\n### CPU with coordinates ###' 136 | test(True, False) 137 | print '\n### GPU without coordinates ###' 138 | test(False, True) 139 | print '\n### GPU with coordinates ###' 140 | test(True, True) 141 | -------------------------------------------------------------------------------- /code/lib/archs/modules/coord_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | 5 | 6 | class AddCoordinates(object): 7 | 8 | r"""Coordinate Adder Module as defined in 'An Intriguing Failing of 9 | Convolutional Neural Networks and the CoordConv Solution' 10 | (https://arxiv.org/pdf/1807.03247.pdf). 11 | 12 | This module concatenates coordinate information (`x`, `y`, and `r`) with 13 | given input tensor. 14 | 15 | `x` and `y` coordinates are scaled to `[-1, 1]` range where origin is the 16 | center. `r` is the Euclidean distance from the center and is scaled to 17 | `[0, 1]`. 18 | 19 | Args: 20 | with_r (bool, optional): If `True`, adds radius (`r`) coordinate 21 | information to input image. 
Default: `False` 22 | usegpu (bool, optional): If `True`, runs operations on GPU 23 | Default: `True` 24 | 25 | Shape: 26 | - Input: `(N, C_{in}, H_{in}, W_{in})` 27 | - Output: `(N, (C_{in} + 2) or (C_{in} + 3), H_{in}, W_{in})` 28 | 29 | Examples: 30 | >>> coord_adder = AddCoordinates(True, False) 31 | >>> input = torch.randn(8, 3, 64, 64) 32 | >>> output = coord_adder(input) 33 | 34 | >>> coord_adder = AddCoordinates(True, True).cuda() 35 | >>> input = torch.randn(8, 3, 64, 64).cuda() 36 | >>> output = coord_adder(input) 37 | """ 38 | 39 | def __init__(self, with_r=False, usegpu=True): 40 | self.with_r = with_r 41 | self.usegpu = usegpu 42 | 43 | def __call__(self, image): 44 | batch_size, _, image_height, image_width = image.size() 45 | 46 | y_coords = 2.0 * torch.arange(image_height).unsqueeze( 47 | 1).expand(image_height, image_width) / (image_height - 1.0) - 1.0 48 | x_coords = 2.0 * torch.arange(image_width).unsqueeze( 49 | 0).expand(image_height, image_width) / (image_width - 1.0) - 1.0 50 | 51 | coords = torch.stack((y_coords, x_coords), dim=0) 52 | 53 | if self.with_r: 54 | rs = ((y_coords ** 2) + (x_coords ** 2)) ** 0.5 55 | rs = rs / torch.max(rs) 56 | rs = torch.unsqueeze(rs, dim=0) 57 | coords = torch.cat((coords, rs), dim=0) 58 | 59 | coords = torch.unsqueeze(coords, dim=0).repeat(batch_size, 1, 1, 1) 60 | coords = Variable(coords) 61 | 62 | if self.usegpu: 63 | coords = coords.cuda() 64 | 65 | image = torch.cat((coords, image), dim=1) 66 | 67 | return image 68 | 69 | 70 | class CoordConv(nn.Module): 71 | 72 | r"""2D Convolution Module Using Extra Coordinate Information as defined 73 | in 'An Intriguing Failing of Convolutional Neural Networks and the 74 | CoordConv Solution' (https://arxiv.org/pdf/1807.03247.pdf). 75 | 76 | Args: 77 | Same as `torch.nn.Conv2d` with two additional arguments 78 | with_r (bool, optional): If `True`, adds radius (`r`) coordinate 79 | information to input image. Default: `False` 80 | usegpu (bool, optional): If `True`, runs operations on GPU 81 | Default: `True` 82 | 83 | Shape: 84 | - Input: `(N, C_{in}, H_{in}, W_{in})` 85 | - Output: `(N, C_{out}, H_{out}, W_{out})` 86 | 87 | Examples: 88 | >>> coord_conv = CoordConv(3, 16, 3, with_r=True, usegpu=False) 89 | >>> input = torch.randn(8, 3, 64, 64) 90 | >>> output = coord_conv(input) 91 | 92 | >>> coord_conv = CoordConv(3, 16, 3, with_r=True, usegpu=True).cuda() 93 | >>> input = torch.randn(8, 3, 64, 64).cuda() 94 | >>> output = coord_conv(input) 95 | """ 96 | 97 | def __init__(self, in_channels, out_channels, kernel_size, 98 | stride=1, padding=0, dilation=1, groups=1, bias=True, 99 | with_r=False, usegpu=True): 100 | super(CoordConv, self).__init__() 101 | 102 | in_channels += 2 103 | if with_r: 104 | in_channels += 1 105 | 106 | self.conv_layer = nn.Conv2d(in_channels, out_channels, 107 | kernel_size, stride=stride, 108 | padding=padding, dilation=dilation, 109 | groups=groups, bias=bias) 110 | 111 | self.coord_adder = AddCoordinates(with_r, usegpu) 112 | 113 | def forward(self, x): 114 | x = self.coord_adder(x) 115 | x = self.conv_layer(x) 116 | 117 | return x 118 | 119 | 120 | class CoordConvTranspose(nn.Module): 121 | 122 | r"""2D Transposed Convolution Module Using Extra Coordinate Information 123 | as defined in 'An Intriguing Failing of Convolutional Neural Networks and 124 | the CoordConv Solution' (https://arxiv.org/pdf/1807.03247.pdf). 
125 | 126 | Args: 127 | Same as `torch.nn.ConvTranspose2d` with two additional arguments 128 | with_r (bool, optional): If `True`, adds radius (`r`) coordinate 129 | information to input image. Default: `False` 130 | usegpu (bool, optional): If `True`, runs operations on GPU 131 | Default: `True` 132 | 133 | Shape: 134 | - Input: `(N, C_{in}, H_{in}, W_{in})` 135 | - Output: `(N, C_{out}, H_{out}, W_{out})` 136 | 137 | Examples: 138 | >>> coord_conv_tr = CoordConvTranspose(3, 16, 3, with_r=True, 139 | >>> usegpu=False) 140 | >>> input = torch.randn(8, 3, 64, 64) 141 | >>> output = coord_conv_tr(input) 142 | 143 | >>> coord_conv_tr = CoordConvTranspose(3, 16, 3, with_r=True, 144 | >>> usegpu=True).cuda() 145 | >>> input = torch.randn(8, 3, 64, 64).cuda() 146 | >>> output = coord_conv_tr(input) 147 | """ 148 | 149 | def __init__(self, in_channels, out_channels, kernel_size, 150 | stride=1, padding=0, output_padding=0, groups=1, bias=True, 151 | dilation=1, with_r=False, usegpu=True): 152 | super(CoordConvTranspose, self).__init__() 153 | 154 | in_channels += 2 155 | if with_r: 156 | in_channels += 1 157 | 158 | self.conv_tr_layer = nn.ConvTranspose2d(in_channels, out_channels, 159 | kernel_size, stride=stride, 160 | padding=padding, 161 | output_padding=output_padding, 162 | groups=groups, bias=bias, 163 | dilation=dilation) 164 | 165 | self.coord_adder = AddCoordinates(with_r, usegpu) 166 | 167 | def forward(self, x): 168 | x = self.coord_adder(x) 169 | x = self.conv_tr_layer(x) 170 | 171 | return x 172 | 173 | 174 | class CoordConvNet(nn.Module): 175 | 176 | r"""Improves 2D Convolutions inside a ConvNet by processing extra 177 | coordinate information as defined in 'An Intriguing Failing of 178 | Convolutional Neural Networks and the CoordConv Solution' 179 | (https://arxiv.org/pdf/1807.03247.pdf). 180 | 181 | This module adds coordinate information to inputs of each 2D convolution 182 | module (`torch.nn.Conv2d`). 183 | 184 | Assumption: ConvNet Model must contain single `Sequential` container 185 | (`torch.nn.modules.container.Sequential`). 186 | 187 | Args: 188 | cnn_model: A ConvNet model that must contain single `Sequential` 189 | container (`torch.nn.modules.container.Sequential`). 190 | with_r (bool, optional): If `True`, adds radius (`r`) coordinate 191 | information to input image. Default: `False` 192 | usegpu (bool, optional): If `True`, runs operations on GPU 193 | Default: `True` 194 | 195 | Shape: 196 | - Input: Same as the input of the model. 197 | - Output: A list that contains all outputs (including 198 | intermediate outputs) of the model. 199 | 200 | Examples: 201 | >>> cnn_model = ... 202 | >>> cnn_model = CoordConvNet(cnn_model, True, False) 203 | >>> input = torch.randn(8, 3, 64, 64) 204 | >>> outputs = cnn_model(input) 205 | 206 | >>> cnn_model = ... 
207 | >>> cnn_model = CoordConvNet(cnn_model, True, True).cuda() 208 | >>> input = torch.randn(8, 3, 64, 64).cuda() 209 | >>> outputs = cnn_model(input) 210 | """ 211 | 212 | def __init__(self, cnn_model, with_r=False, usegpu=True): 213 | super(CoordConvNet, self).__init__() 214 | 215 | self.with_r = with_r 216 | 217 | self.cnn_model = cnn_model 218 | self.__get_model() 219 | self.__update_weights() 220 | 221 | self.coord_adder = AddCoordinates(self.with_r, usegpu) 222 | 223 | def __get_model(self): 224 | for module in list(self.cnn_model.modules()): 225 | if module.__class__ == torch.nn.modules.container.Sequential: 226 | self.cnn_model = module 227 | break 228 | 229 | def __update_weights(self): 230 | coord_channels = 2 231 | if self.with_r: 232 | coord_channels += 1 233 | 234 | for l in list(self.cnn_model.modules()): 235 | if l.__str__().startswith('Conv2d'): 236 | weights = l.weight.data 237 | 238 | out_channels, in_channels, k_height, k_width = weights.size() 239 | 240 | coord_weights = torch.zeros(out_channels, coord_channels, 241 | k_height, k_width) 242 | 243 | weights = torch.cat((coord_weights, weights), dim=1) 244 | weights = nn.Parameter(weights) 245 | 246 | l.weight = weights 247 | l.in_channels += coord_channels 248 | 249 | def __get_outputs(self, x): 250 | outputs = [] 251 | for layer_name, layer in self.cnn_model._modules.items(): 252 | if layer.__str__().startswith('Conv2d'): 253 | x = self.coord_adder(x) 254 | x = layer(x) 255 | outputs.append(x) 256 | 257 | return outputs 258 | 259 | def forward(self, x): 260 | return self.__get_outputs(x) 261 | -------------------------------------------------------------------------------- /code/lib/archs/modules/recurrent_hourglass.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | from conv_gru import ConvGRUCell 4 | from coord_conv import CoordConv 5 | from utils import ListModule 6 | 7 | 8 | class RecurrentHourglass(nn.Module): 9 | 10 | r"""RecurrentHourglass Module as defined in 11 | 'Instance Segmentation and Tracking with Cosine Embeddings and Recurrent 12 | Hourglass Networks' (https://arxiv.org/pdf/1806.02070.pdf). 13 | 14 | Args: 15 | input_n_filters (int): Number of channels in the input image 16 | hidden_n_filters (int): Number of channels produced by Convolutional 17 | GRU module 18 | kernel_size (int or tuple): Size of the convolving kernels 19 | n_levels (int): Number of timesteps to unroll Convolutional GRU 20 | module 21 | embedding_size (int): Number of channels produced by Recurrent 22 | Hourglass module 23 | use_coordinates (bool, optional): If `True`, adds coordinate 24 | information to input image and hidden state. Default: `False` 25 | usegpu (bool, optional): If `True`, runs operations on GPU 26 | Default: `True` 27 | 28 | Shape: 29 | - Input: `(N, C_{in}, H_{in}, W_{in})` 30 | - Output: `(N, C_{out}, H_{in}, W_{in})` 31 | 32 | Examples: 33 | >>> hg = RecurrentHourglass(3, 16, 3, 5, 32, True, False) 34 | >>> input = torch.randn(8, 3, 64, 64) 35 | >>> output = hg(input) 36 | 37 | >>> hg = RecurrentHourglass(3, 16, 3, 5, 32, True, True).cuda() 38 | >>> input = torch.randn(8, 3, 64, 64).cuda() 39 | >>> output = hg(input) 40 | """ 41 | 42 | def __init__(self, input_n_filters, hidden_n_filters, kernel_size, 43 | n_levels, embedding_size, use_coordinates=False, 44 | usegpu=True): 45 | super(RecurrentHourglass, self).__init__() 46 | 47 | assert n_levels >= 1, 'n_levels should be greater than or equal to 1.' 
48 | 49 | self.input_n_filters = input_n_filters 50 | self.hidden_n_filters = hidden_n_filters 51 | self.kernel_size = kernel_size 52 | self.n_levels = n_levels 53 | self.embedding_size = embedding_size 54 | self.use_coordinates = use_coordinates 55 | self.usegpu = usegpu 56 | 57 | self.convgru_cell = ConvGRUCell(self.hidden_n_filters, 58 | self.hidden_n_filters, 59 | self.kernel_size, 60 | self.use_coordinates, 61 | self.usegpu) 62 | 63 | self.__generate_pre_post_convs() 64 | 65 | def __generate_pre_post_convs(self): 66 | 67 | if self.use_coordinates: 68 | def __get_conv(input_n_filters, output_n_filters): 69 | return CoordConv(input_n_filters, output_n_filters, 70 | self.kernel_size, 71 | padding=self.kernel_size // 2, 72 | with_r=True, 73 | usegpu=self.usegpu) 74 | else: 75 | def __get_conv(input_n_filters, output_n_filters): 76 | return nn.Conv2d(input_n_filters, output_n_filters, 77 | self.kernel_size, 78 | padding=self.kernel_size // 2) 79 | 80 | # Pre Conv Layers 81 | self.pre_conv_layers = [__get_conv(self.input_n_filters, 82 | self.hidden_n_filters), ] 83 | for _ in range(self.n_levels - 1): 84 | self.pre_conv_layers.append(__get_conv(self.hidden_n_filters, 85 | self.hidden_n_filters)) 86 | self.pre_conv_layers = ListModule(*self.pre_conv_layers) 87 | 88 | # Post Conv Layers 89 | self.post_conv_layers = [__get_conv(self.hidden_n_filters, 90 | self.embedding_size), ] 91 | for _ in range(self.n_levels - 1): 92 | self.post_conv_layers.append(__get_conv(self.hidden_n_filters, 93 | self.hidden_n_filters)) 94 | self.post_conv_layers = ListModule(*self.post_conv_layers) 95 | 96 | def forward_encoding(self, x): 97 | 98 | convgru_outputs = [] 99 | hidden = None 100 | for i in range(self.n_levels): 101 | x = F.relu(self.pre_conv_layers[i](x)) 102 | hidden = self.convgru_cell(x, hidden) 103 | convgru_outputs.append(hidden) 104 | 105 | return convgru_outputs 106 | 107 | def forward_decoding(self, convgru_outputs): 108 | 109 | _last_conv_layer = self.post_conv_layers[self.n_levels - 1] 110 | _last_output = convgru_outputs[self.n_levels - 1] 111 | 112 | post_feature_map = F.relu(_last_conv_layer(_last_output)) 113 | for i in range(self.n_levels - 1)[::-1]: 114 | post_feature_map += convgru_outputs[i] 115 | post_feature_map = self.post_conv_layers[i](post_feature_map) 116 | post_feature_map = F.relu(post_feature_map) 117 | 118 | return post_feature_map 119 | 120 | def forward(self, x): 121 | 122 | x = self.forward_encoding(x) 123 | x = self.forward_decoding(x) 124 | 125 | return x 126 | 127 | if __name__ == '__main__': 128 | from torch.autograd import Variable 129 | import torch 130 | import time 131 | 132 | def test(use_coordinates, usegpu): 133 | 134 | n_epochs, batch_size, image_size = 10, 4, 36 135 | 136 | input_n_filters, hidden_n_filters = 3, 64 137 | kernel_size, n_levels, embedding_size = 3, 4, 8 138 | 139 | hg = RecurrentHourglass(input_n_filters, hidden_n_filters, 140 | kernel_size, n_levels, 141 | embedding_size, use_coordinates, 142 | usegpu) 143 | 144 | input = Variable(torch.rand(batch_size, input_n_filters, 145 | image_size, image_size)) 146 | 147 | if usegpu: 148 | hg = hg.cuda() 149 | input = input.cuda() 150 | 151 | print hg 152 | 153 | output = hg(input) 154 | 155 | print input.size(), output.size() 156 | 157 | print '\n### CPU without Coords ###' 158 | test(False, False) 159 | print '\n### CPU with Coords ###' 160 | test(True, False) 161 | print '\n### GPU without Coords ###' 162 | test(False, True) 163 | print '\n### GPU with Coords ###' 164 | test(True, True) 165 | 
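A minimal sketch of the data flow implemented by `RecurrentHourglass` above, to make explicit how `forward_encoding` and `forward_decoding` fit together: the level-specific pre-convolutions chain on their own outputs while a single shared `ConvGRUCell` accumulates hidden state, and the decoder walks the levels back down, adding each stored state as a skip connection before the matching post-convolution (only `post_conv_layers[0]` maps to `embedding_size`). The `pre`, `post` and `gru` modules below are simplified stand-ins (1x1 convolutions, with a plain convolution in place of the real CoordConv/ConvGRU layers), so this illustrates the wiring only, not the module itself.

```python
import torch
from torch import nn
from torch.nn import functional as F

# Stand-ins for pre_conv_layers, post_conv_layers and the shared ConvGRUCell;
# only the encode/decode skip-connection pattern of RecurrentHourglass is kept.
n_levels, hidden, embedding = 3, 8, 4
pre = nn.ModuleList([nn.Conv2d(3 if i == 0 else hidden, hidden, 1)
                     for i in range(n_levels)])
post = nn.ModuleList([nn.Conv2d(hidden, embedding if i == 0 else hidden, 1)
                      for i in range(n_levels)])
gru = nn.Conv2d(hidden, hidden, 1)  # placeholder for the recurrent cell

x = torch.randn(2, 3, 16, 16)

# Encoding: pre-convs chain on their own outputs, the shared cell keeps state.
states = []
for i in range(n_levels):
    x = F.relu(pre[i](x))
    h = gru(x) if not states else gru(x) + states[-1]  # crude recurrence stand-in
    states.append(h)

# Decoding: deepest level first, then add stored states as additive skips.
y = F.relu(post[n_levels - 1](states[-1]))
for i in reversed(range(n_levels - 1)):
    y = F.relu(post[i](y + states[i]))

print(y.shape)  # (2, 4, 16, 16): spatial size preserved, embedding_size channels
```

Spatial resolution is preserved end to end, which matches the module's documented output shape `(N, C_{out}, H_{in}, W_{in})`.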
-------------------------------------------------------------------------------- /code/lib/archs/modules/renet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | from coord_conv import AddCoordinates 4 | 5 | 6 | class ReNet(nn.Module): 7 | 8 | r"""ReNet Module as defined in 'ReNet: A Recurrent Neural 9 | Network Based Alternative to Convolutional Networks' 10 | (https://arxiv.org/pdf/1505.00393.pdf). 11 | 12 | Args: 13 | n_input (int): Number of channels in the input image 14 | n_units (int): Number of channels produced by ReNet 15 | patch_size (tuple): Patch size in the input of ReNet 16 | use_coordinates (bool, optional): If `True`, adds coordinate 17 | information to input image and hidden state. Default: `False` 18 | usegpu (bool, optional): If `True`, runs operations on GPU 19 | Default: `True` 20 | 21 | Shape: 22 | - Input: `(N, C_{in}, H_{in}, W_{in})` 23 | - Output: `(N, C_{out}, H_{out}, W_{out})` 24 | 25 | Examples: 26 | >>> renet = ReNet(3, 16, (2, 2), True, False) 27 | >>> input = torch.randn(8, 3, 64, 64) 28 | >>> output = renet(input) 29 | 30 | >>> renet = ReNet(3, 16, (2, 2), True, True).cuda() 31 | >>> input = torch.randn(8, 3, 64, 64).cuda() 32 | >>> output = renet(input) 33 | """ 34 | 35 | def __init__(self, n_input, n_units, patch_size=(1, 1), 36 | use_coordinates=False, usegpu=True): 37 | super(ReNet, self).__init__() 38 | 39 | self.use_coordinates = use_coordinates 40 | self.usegpu = usegpu 41 | 42 | # Determine whether to do tiling and patch sizes 43 | self.patch_size_height = int(patch_size[0]) 44 | self.patch_size_width = int(patch_size[1]) 45 | 46 | assert self.patch_size_height >= 1 47 | assert self.patch_size_width >= 1 48 | 49 | self.tiling = False if ((self.patch_size_height == 1) and ( 50 | self.patch_size_width == 1)) else True 51 | 52 | if self.use_coordinates: 53 | self.coord_adder = AddCoordinates(with_r=True, 54 | usegpu=self.usegpu) 55 | 56 | # Determine RNNs 57 | # Horizontal RNN 58 | rnn_hor_n_inputs = n_input * self.patch_size_height * \ 59 | self.patch_size_width 60 | if self.use_coordinates: 61 | rnn_hor_n_inputs += 3 62 | 63 | self.rnn_hor = nn.GRU(rnn_hor_n_inputs, n_units, 64 | num_layers=1, batch_first=True, 65 | bidirectional=True) 66 | 67 | # Vertical RNN 68 | self.rnn_ver = nn.GRU(n_units * 2, n_units, 69 | num_layers=1, batch_first=True, 70 | bidirectional=True) 71 | 72 | def __tile(self, x): 73 | 74 | if (x.size(2) % self.patch_size_height) == 0: 75 | n_height_padding = 0 76 | else: 77 | n_height_padding = self.patch_size_height - \ 78 | x.size(2) % self.patch_size_height 79 | if (x.size(3) % self.patch_size_width) == 0: 80 | n_width_padding = 0 81 | else: 82 | n_width_padding = self.patch_size_width - \ 83 | x.size(3) % self.patch_size_width 84 | 85 | n_top_padding = n_height_padding / 2 86 | n_bottom_padding = n_height_padding - n_top_padding 87 | 88 | n_left_padding = n_width_padding / 2 89 | n_right_padding = n_width_padding - n_left_padding 90 | 91 | x = F.pad(x, (n_left_padding, n_right_padding, 92 | n_top_padding, n_bottom_padding)) 93 | 94 | b, n_filters, n_height, n_width = x.size() 95 | 96 | assert n_height % self.patch_size_height == 0 97 | assert n_width % self.patch_size_width == 0 98 | 99 | new_height = n_height / self.patch_size_height 100 | new_width = n_width / self.patch_size_width 101 | 102 | x = x.view(b, n_filters, new_height, self.patch_size_height, 103 | new_width, self.patch_size_width) 104 | x = x.permute(0, 2, 
4, 1, 3, 5) 105 | x = x.contiguous() 106 | x = x.view(b, new_height, new_width, self.patch_size_height * 107 | self.patch_size_width * n_filters) 108 | x = x.permute(0, 3, 1, 2) 109 | x = x.contiguous() 110 | 111 | return x 112 | 113 | def __swap_hw(self, x): 114 | 115 | # x : b, nf, h, w 116 | x = x.permute(0, 1, 3, 2) 117 | x = x.contiguous() 118 | # x : b, nf, w, h 119 | 120 | return x 121 | 122 | def rnn_forward(self, x, hor_or_ver): 123 | 124 | # x : b, nf, h, w 125 | assert hor_or_ver in ['hor', 'ver'] 126 | 127 | if hor_or_ver == 'ver': 128 | x = self.__swap_hw(x) 129 | 130 | x = x.permute(0, 2, 3, 1) 131 | x = x.contiguous() 132 | b, n_height, n_width, n_filters = x.size() 133 | # x : b, h, w, nf 134 | 135 | x = x.view(b * n_height, n_width, n_filters) 136 | # x : b * h, w, nf 137 | if hor_or_ver == 'hor': 138 | x, _ = self.rnn_hor(x) 139 | elif hor_or_ver == 'ver': 140 | x, _ = self.rnn_ver(x) 141 | 142 | x = x.contiguous() 143 | x = x.view(b, n_height, n_width, -1) 144 | # x : b, h, w, nf 145 | 146 | x = x.permute(0, 3, 1, 2) 147 | x = x.contiguous() 148 | # x : b, nf, h, w 149 | 150 | if hor_or_ver == 'ver': 151 | x = self.__swap_hw(x) 152 | 153 | return x 154 | 155 | def forward(self, x): 156 | 157 | # x : b, nf, h, w 158 | if self.tiling: 159 | x = self.__tile(x) 160 | 161 | if self.use_coordinates: 162 | x = self.coord_adder(x) 163 | 164 | x = self.rnn_forward(x, 'hor') 165 | x = self.rnn_forward(x, 'ver') 166 | 167 | return x 168 | -------------------------------------------------------------------------------- /code/lib/archs/modules/utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | # https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219 4 | 5 | 6 | class ListModule(nn.Module): 7 | 8 | def __init__(self, *args): 9 | super(ListModule, self).__init__() 10 | 11 | idx = 0 12 | for module in args: 13 | self.add_module(str(idx), module) 14 | idx += 1 15 | 16 | def __getitem__(self, idx): 17 | 18 | if idx < 0 or idx >= len(self._modules): 19 | raise IndexError('index {} is out of range'.format(idx)) 20 | 21 | it = iter(self._modules.values()) 22 | for i in range(idx): 23 | next(it) 24 | 25 | return next(it) 26 | 27 | def __iter__(self): 28 | return iter(self._modules.values()) 29 | 30 | def __len__(self): 31 | return len(self._modules) 32 | -------------------------------------------------------------------------------- /code/lib/archs/modules/vgg16.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.models as models 3 | from coord_conv import CoordConvNet 4 | 5 | 6 | class VGG16(nn.Module): 7 | 8 | r"""A module that augments VGG16 as defined in 'Very Deep Convolutional 9 | Networks for Large-Scale Image Recognition' 10 | (https://arxiv.org/pdf/1409.1556.pdf). 11 | 12 | 1. It can return first `n_layers` of the VGG16. 13 | 2. It can add coordinates to feature maps prior to each convolution. 14 | 3. It can return all outputs (including intermediate outputs) of the 15 | VGG16. 16 | 17 | Args: 18 | n_layers (int): Use first `n_layers` layers of the VGG16 19 | pretrained (bool, optional): If `True`, initializes weights of the 20 | VGG16 using weights trained on ImageNet. Default: `True` 21 | use_coordinates (bool, optional): If `True`, adds `x`, `y` and radius 22 | (`r`) coordinates to feature maps prior to each convolution. 23 | Weights to process these coordinates are initialized as zero. 
24 | Default: `False` 25 | return_intermediate_outputs (bool, optional): If `True`, return 26 | outputs of the each layer in the VGG16 as a list otherwise 27 | return output of the last layer of first `n_layers` layers 28 | of the VGG16. Default: `False` 29 | usegpu (bool, optional): If `True`, runs operations on GPU 30 | Default: `True` 31 | 32 | Shape: 33 | - Input: `(N, C_{in}, H_{in}, W_{in})` 34 | - Output: Output of the last layer of the selected subpart of VGG16 35 | or the list that contains outputs of the each layer depending on 36 | `return_intermediate_outputs` 37 | 38 | Examples: 39 | >>> vgg16 = VGG16(16, True, True, True, False) 40 | >>> input = torch.randn(8, 3, 64, 64) 41 | >>> output = vgg16(input) 42 | 43 | >>> vgg16 = VGG16(16, True, True, True, True).cuda() 44 | >>> input = torch.randn(8, 3, 64, 64).cuda() 45 | >>> output = vgg16(input) 46 | """ 47 | 48 | def __init__(self, n_layers, pretrained=True, use_coordinates=False, 49 | return_intermediate_outputs=False, usegpu=True): 50 | super(VGG16, self).__init__() 51 | 52 | self.use_coordinates = use_coordinates 53 | self.return_intermediate_outputs = return_intermediate_outputs 54 | 55 | self.cnn = models.__dict__['vgg16'](pretrained=pretrained) 56 | self.cnn = nn.Sequential(*list(self.cnn.children())[0]) 57 | self.cnn = nn.Sequential(*list(self.cnn.children())[: n_layers]) 58 | 59 | if self.use_coordinates: 60 | self.cnn = CoordConvNet(self.cnn, True, usegpu) 61 | 62 | def __get_outputs(self, x): 63 | if self.use_coordinates: 64 | return self.cnn(x) 65 | 66 | outputs = [] 67 | for i, layer in enumerate(self.cnn.children()): 68 | x = layer(x) 69 | outputs.append(x) 70 | 71 | return outputs 72 | 73 | def forward(self, x): 74 | outputs = self.__get_outputs(x) 75 | 76 | if self.return_intermediate_outputs: 77 | return outputs 78 | 79 | return outputs[-1] 80 | 81 | 82 | class SkipVGG16(nn.Module): 83 | 84 | r"""A module that returns output of 7th convolutional layer of the 85 | VGG16 along with outputs of the 2nd and 4th convolutional layers. 86 | 87 | Args: 88 | pretrained (bool, optional): If `True`, initializes weights of the 89 | VGG16 using weights trained on ImageNet. Default: `True` 90 | use_coordinates (bool, optional): If `True`, adds `x`, `y` and radius 91 | (`r`) coordinates to feature maps prior to each convolution. 92 | Weights to process these coordinates are initialized as zero. 93 | Default: `False` 94 | usegpu (bool, optional): If `True`, runs operations on GPU 95 | Default: `True` 96 | 97 | Shape: 98 | - Input: `(N, C_{in}, H_{in}, W_{in})` 99 | - Output: List of outputs of the 2nd, 4th and 7th convolutional 100 | layers of the VGG16, respectively. 
101 | 102 | Examples: 103 | >>> vgg16 = SkipVGG16(True, True, False) 104 | >>> input = torch.randn(8, 3, 64, 64) 105 | >>> output = vgg16(input) 106 | 107 | >>> vgg16 = SkipVGG16(True, True, True).cuda() 108 | >>> input = torch.randn(8, 3, 64, 64).cuda() 109 | >>> output = vgg16(input) 110 | """ 111 | 112 | def __init__(self, pretrained=True, use_coordinates=False, 113 | usegpu=True): 114 | super(SkipVGG16, self).__init__() 115 | 116 | self.use_coordinates = use_coordinates 117 | 118 | self.outputs = [3, 8] 119 | self.n_filters = [64, 128] 120 | 121 | self.model = VGG16(n_layers=16, pretrained=pretrained, 122 | use_coordinates=self.use_coordinates, 123 | return_intermediate_outputs=True, 124 | usegpu=usegpu) 125 | 126 | def forward(self, x): 127 | 128 | if self.use_coordinates: 129 | outs = self.model(x) 130 | out = [o for i, o in enumerate(outs) if i in self.outputs] 131 | out.append(outs[-1]) 132 | else: 133 | out = [] 134 | for i, layer in enumerate(list(self.model.children())[0]): 135 | x = layer(x) 136 | if i in self.outputs: 137 | out.append(x) 138 | out.append(x) 139 | 140 | return out 141 | 142 | if __name__ == '__main__': 143 | from torch.autograd import Variable 144 | import torch 145 | import time 146 | 147 | def test(use_coordinates, usegpu, skip): 148 | 149 | batch_size, image_size = 4, 128 150 | 151 | if skip: 152 | vgg16 = SkipVGG16(False, use_coordinates, usegpu) 153 | else: 154 | vgg16 = VGG16(16, False, use_coordinates, False, usegpu) 155 | 156 | input = Variable(torch.rand(batch_size, 3, image_size, 157 | image_size)) 158 | 159 | if usegpu: 160 | vgg16 = vgg16.cuda() 161 | input = input.cuda() 162 | 163 | print '\nModel :\n\n', vgg16 164 | 165 | output = vgg16(input) 166 | 167 | if isinstance(output, list): 168 | print '\n* N outputs : ', len(output) 169 | for o in output: 170 | print '** Output shape : ', o.size() 171 | else: 172 | print '\n** Output Shape : ', output.size() 173 | 174 | print '\n### COORDS + GPU + SKIP' 175 | test(True, True, True) 176 | print '\n### COORDS + GPU' 177 | test(True, True, False) 178 | print '\n### COORDS + CPU + SKIP' 179 | test(True, False, True) 180 | print '\n### COORDS + CPU' 181 | test(True, False, False) 182 | print '\n### GPU + SKIP' 183 | test(False, True, True) 184 | print '\n### GPU' 185 | test(False, True, False) 186 | print '\n### CPU + SKIP' 187 | test(False, False, True) 188 | print '\n### CPU' 189 | test(False, False, False) 190 | -------------------------------------------------------------------------------- /code/lib/archs/reseg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from modules.vgg16 import SkipVGG16 4 | from modules.renet import ReNet 5 | from instance_counter import InstanceCounter 6 | 7 | 8 | class ReSeg(nn.Module): 9 | 10 | r"""ReSeg Module (with modifications) as defined in 'ReSeg: A Recurrent 11 | Neural Network-based Model for Semantic Segmentation' 12 | (https://arxiv.org/pdf/1511.07053.pdf). 13 | 14 | * VGG16 with skip Connections as base network 15 | * Two ReNet layers 16 | * Two transposed convolutional layers for upsampling 17 | * Three heads for semantic segmentation, instance segmentation and 18 | instance counting. 19 | 20 | Args: 21 | n_classes (int): Number of semantic classes 22 | use_instance_seg (bool, optional): If `False`, does not perform 23 | instance segmentation. Default: `True` 24 | pretrained (bool, optional): If `True`, initializes weights of the 25 | VGG16 using weights trained on ImageNet. 
Default: `True` 26 | use_coordinates (bool, optional): If `True`, adds coordinate 27 | information to input image and hidden state. Default: `False` 28 | usegpu (bool, optional): If `True`, runs operations on GPU 29 | Default: `True` 30 | 31 | Shape: 32 | - Input: `(N, C_{in}, H_{in}, W_{in})` 33 | - Output: 34 | - Semantic Seg: `(N, N_{class}, H_{in}, W_{in})` 35 | - Instance Seg: `(N, 32, H_{in}, W_{in})` 36 | - Instance Cnt: `(N, 1)` 37 | 38 | Examples: 39 | >>> reseg = ReSeg(3, True, True, True, False) 40 | >>> input = torch.randn(8, 3, 64, 64) 41 | >>> outputs = reseg(input) 42 | 43 | >>> reseg = ReSeg(3, True, True, True, True).cuda() 44 | >>> input = torch.randn(8, 3, 64, 64).cuda() 45 | >>> outputs = reseg(input) 46 | """ 47 | 48 | def __init__(self, n_classes, use_instance_seg=True, pretrained=True, 49 | use_coordinates=False, usegpu=True): 50 | super(ReSeg, self).__init__() 51 | 52 | self.n_classes = n_classes 53 | self.use_instance_seg = use_instance_seg 54 | 55 | # Encoder 56 | # BaseCNN 57 | self.cnn = SkipVGG16(pretrained=pretrained, 58 | use_coordinates=use_coordinates, 59 | usegpu=usegpu) 60 | 61 | # ReNets 62 | self.renet1 = ReNet(256, 100, use_coordinates=use_coordinates, 63 | usegpu=usegpu) 64 | self.renet2 = ReNet(100 * 2, 100, use_coordinates=use_coordinates, 65 | usegpu=usegpu) 66 | 67 | # Decoder 68 | self.upsampling1 = nn.ConvTranspose2d(100 * 2, 100, 69 | kernel_size=(2, 2), 70 | stride=(2, 2)) 71 | self.relu1 = nn.ReLU() 72 | self.upsampling2 = nn.ConvTranspose2d(100 + self.cnn.n_filters[1], 73 | 100, kernel_size=(2, 2), 74 | stride=(2, 2)) 75 | self.relu2 = nn.ReLU() 76 | 77 | # Semantic Segmentation 78 | self.sem_seg_output = nn.Conv2d(100 + self.cnn.n_filters[0], 79 | self.n_classes, kernel_size=(1, 1), 80 | stride=(1, 1)) 81 | 82 | # Instance Segmentation 83 | if self.use_instance_seg: 84 | self.ins_seg_output = nn.Conv2d(100 + self.cnn.n_filters[0], 85 | 32, kernel_size=(1, 1), 86 | stride=(1, 1)) 87 | 88 | # Instance Counting 89 | self.ins_cls_cnn = InstanceCounter(100 * 2, use_coordinates, 90 | usegpu=usegpu) 91 | 92 | def forward(self, x): 93 | 94 | # Encoder 95 | # BaseCNN 96 | first_skip, second_skip, x_enc = self.cnn(x) 97 | 98 | # ReNets 99 | x_enc = self.renet1(x_enc) 100 | x_enc = self.renet2(x_enc) 101 | 102 | # Decoder 103 | x_dec = self.relu1(self.upsampling1(x_enc)) 104 | x_dec = torch.cat((x_dec, second_skip), dim=1) 105 | x_dec = self.relu2(self.upsampling2(x_dec)) 106 | x_dec = torch.cat((x_dec, first_skip), dim=1) 107 | 108 | # Semantic Segmentation 109 | sem_seg_out = self.sem_seg_output(x_dec) 110 | 111 | # Instance Segmentation 112 | if self.use_instance_seg: 113 | ins_seg_out = self.ins_seg_output(x_dec) 114 | else: 115 | ins_seg_out = None 116 | 117 | # Instance Counting 118 | ins_cls_out = self.ins_cls_cnn(x_enc) 119 | 120 | return sem_seg_out, ins_seg_out, ins_cls_out 121 | -------------------------------------------------------------------------------- /code/lib/archs/stacked_recurrent_hourglass.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from modules.vgg16 import VGG16 5 | from modules.recurrent_hourglass import RecurrentHourglass 6 | from modules.renet import ReNet 7 | from instance_counter import InstanceCounter 8 | 9 | 10 | class StackedRecurrentHourglass(nn.Module): 11 | 12 | r"""Stacked Recurrent Hourglass Module for instance segmentation 13 | as defined in 'Instance Segmentation and Tracking with Cosine 14 | Embeddings and Recurrent 
Hourglass Networks' 15 | (https://arxiv.org/pdf/1806.02070.pdf). 16 | 17 | * First four layers of VGG16 18 | * Two RecurrentHourglass layers 19 | * Two ReNet layers 20 | * Two transposed convolutional layers for upsampling 21 | * Three heads for semantic segmentation, instance segmentation and 22 | instance counting. 23 | 24 | Args: 25 | n_classes (int): Number of semantic classes 26 | use_instance_seg (bool, optional): If `False`, does not perform 27 | instance segmentation. Default: `True` 28 | pretrained (bool, optional): If `True`, initializes weights of the 29 | VGG16 using weights trained on ImageNet. Default: `True` 30 | use_coordinates (bool, optional): If `True`, adds coordinate 31 | information to input image and hidden state. Default: `False` 32 | usegpu (bool, optional): If `True`, runs operations on GPU 33 | Default: `True` 34 | 35 | Shape: 36 | - Input: `(N, C_{in}, H_{in}, W_{in})` 37 | - Output: 38 | - Semantic Seg: `(N, N_{class}, H_{in}, W_{in})` 39 | - Instance Seg: `(N, 32, H_{in}, W_{in})` 40 | - Instance Cnt: `(N, 1)` 41 | 42 | Examples: 43 | >>> srhg = StackedRecurrentHourglass(4, True, True, True, False) 44 | >>> input = torch.randn(8, 3, 64, 64) 45 | >>> outputs = srhg(input) 46 | 47 | >>> srhg = StackedRecurrentHourglass(4, True, True, True, True) 48 | >>> srhg = srhg.cuda() 49 | >>> input = torch.randn(8, 3, 64, 64).cuda() 50 | >>> outputs = srhg(input) 51 | """ 52 | 53 | def __init__(self, n_classes, use_instance_seg=True, pretrained=True, 54 | use_coordinates=False, usegpu=True): 55 | super(StackedRecurrentHourglass, self).__init__() 56 | 57 | self.n_classes = n_classes 58 | self.use_instance_seg = use_instance_seg 59 | self.use_coords = use_coordinates 60 | self.pretrained = pretrained 61 | self.usegpu = usegpu 62 | 63 | # Encoder 64 | # BaseCNN 65 | self.base_cnn = self.__generate_base_cnn() 66 | 67 | # Encoder Stacked Hourglass 68 | self.enc_stacked_hourglass = self.__generate_enc_stacked_hg(64, 3) 69 | 70 | # ReNets 71 | self.stacked_renet = self.__generate_stacked_renet(64, 2) 72 | 73 | # Decoder 74 | self.decoder = self.__generate_decoder(64) 75 | 76 | # Heads 77 | self.semantic_seg, self.instance_seg, self.instance_count = \ 78 | self.__generate_heads(64, 32) 79 | 80 | def __generate_base_cnn(self): 81 | 82 | base_cnn = VGG16(n_layers=4, pretrained=self.pretrained, 83 | use_coordinates=self.use_coords, 84 | return_intermediate_outputs=False, 85 | usegpu=self.usegpu) 86 | 87 | return base_cnn 88 | 89 | def __generate_enc_stacked_hg(self, input_n_filters, n_levels): 90 | 91 | stacked_hourglass = nn.Sequential() 92 | stacked_hourglass.add_module('Hourglass_1', 93 | RecurrentHourglass( 94 | input_n_filters=input_n_filters, 95 | hidden_n_filters=64, 96 | kernel_size=3, 97 | n_levels=n_levels, 98 | embedding_size=64, 99 | use_coordinates=self.use_coords, 100 | usegpu=self.usegpu)) 101 | stacked_hourglass.add_module('pool_1', 102 | nn.MaxPool2d(2, stride=2)) 103 | stacked_hourglass.add_module('Hourglass_2', 104 | RecurrentHourglass( 105 | input_n_filters=64, 106 | hidden_n_filters=64, 107 | kernel_size=3, 108 | n_levels=n_levels, 109 | embedding_size=64, 110 | use_coordinates=self.use_coords, 111 | usegpu=self.usegpu)) 112 | stacked_hourglass.add_module('pool_2', 113 | nn.MaxPool2d(2, stride=2)) 114 | 115 | return stacked_hourglass 116 | 117 | def __generate_stacked_renet(self, input_n_filters, n_renets): 118 | 119 | assert n_renets >= 1, 'n_renets should be 1 at least.' 
120 | 121 | renet = nn.Sequential() 122 | renet.add_module('ReNet_1', ReNet(input_n_filters, 32, 123 | patch_size=(1, 1), 124 | use_coordinates=self.use_coords, 125 | usegpu=self.usegpu)) 126 | for i in range(1, n_renets): 127 | renet.add_module('ReNet_{}'.format(i + 1), 128 | ReNet(32 * 2, 32, patch_size=(1, 1), 129 | use_coordinates=self.use_coords, 130 | usegpu=self.usegpu)) 131 | 132 | return renet 133 | 134 | def __generate_decoder(self, input_n_filters): 135 | 136 | decoder = nn.Sequential() 137 | decoder.add_module('ConvTranspose_1', 138 | nn.ConvTranspose2d(input_n_filters, 139 | 64, 140 | kernel_size=(2, 2), 141 | stride=(2, 2))) 142 | decoder.add_module('ReLU_1', nn.ReLU()) 143 | decoder.add_module('ConvTranspose_2', 144 | nn.ConvTranspose2d(64, 64, 145 | kernel_size=(2, 2), 146 | stride=(2, 2))) 147 | decoder.add_module('ReLU_2', nn.ReLU()) 148 | 149 | return decoder 150 | 151 | def __generate_heads(self, input_n_filters, embedding_size): 152 | 153 | semantic_segmentation = nn.Sequential() 154 | semantic_segmentation.add_module('Conv_1', 155 | nn.Conv2d(input_n_filters, 156 | self.n_classes, 157 | kernel_size=(1, 1), 158 | stride=(1, 1))) 159 | 160 | if self.use_instance_seg: 161 | instance_segmentation = nn.Sequential() 162 | instance_segmentation.add_module('Conv_1', 163 | nn.Conv2d(input_n_filters, 164 | embedding_size, 165 | kernel_size=(1, 1), 166 | stride=(1, 1))) 167 | else: 168 | instance_segmentation = None 169 | 170 | instance_counting = InstanceCounter(input_n_filters, 171 | use_coordinates=self.use_coords, 172 | usegpu=self.usegpu) 173 | 174 | return semantic_segmentation, instance_segmentation, instance_counting 175 | 176 | def forward(self, x): 177 | 178 | x = self.base_cnn(x) 179 | x = self.enc_stacked_hourglass(x) 180 | x = self.stacked_renet(x) 181 | x = self.decoder(x) 182 | 183 | sem_seg_out = self.semantic_seg(x) 184 | if self.use_instance_seg: 185 | ins_seg_out = self.instance_seg(x) 186 | else: 187 | ins_seg_out = None 188 | 189 | ins_count_out = self.instance_count(x) 190 | 191 | return sem_seg_out, ins_seg_out, ins_count_out 192 | -------------------------------------------------------------------------------- /code/lib/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import random 4 | 5 | from PIL import Image 6 | import lmdb 7 | import sys 8 | import numpy as np 9 | from StringIO import StringIO 10 | 11 | from utils import ImageUtilities as IU 12 | 13 | 14 | class SegDataset(Dataset): 15 | """Dataset Reader""" 16 | 17 | def __init__(self, lmdb_path): 18 | 19 | self._lmdb_path = lmdb_path 20 | 21 | self.env = lmdb.open(self._lmdb_path, max_readers=1, 22 | readonly=True, lock=False, 23 | readahead=False, meminit=False) 24 | 25 | if not self.env: 26 | print 'Cannot read lmdb from {}'.format(self._lmdb_path) 27 | sys.exit(0) 28 | 29 | with self.env.begin(write=False) as txn: 30 | self.n_samples = int(txn.get('num-samples')) 31 | 32 | def __load_data(self, index): 33 | 34 | with self.env.begin(write=False) as txn: 35 | image_key = 'image-{}'.format(index + 1) 36 | semantic_ann_key = 'semantic-annotation-{}'.format(index + 1) 37 | instance_ann_key = 'instance-annotation-{}'.format(index + 1) 38 | height_key = 'height-{}'.format(index + 1) 39 | width_key = 'width-{}'.format(index + 1) 40 | n_objects_key = 'n_objects-{}'.format(index + 1) 41 | 42 | img = txn.get(image_key) 43 | img = Image.open(StringIO(img)) 44 | 45 | height = int(txn.get(height_key)) 46 | 
width = int(txn.get(width_key)) 47 | n_objects = int(txn.get(n_objects_key)) 48 | 49 | semantic_annotation = np.fromstring(txn.get(semantic_ann_key), 50 | dtype=np.uint8) 51 | semantic_annotation = semantic_annotation.reshape(height, width) 52 | 53 | instance_annotation = np.fromstring(txn.get(instance_ann_key), 54 | dtype=np.uint8) 55 | instance_annotation = instance_annotation.reshape(height, width, 56 | n_objects) 57 | 58 | return img, semantic_annotation, instance_annotation, n_objects 59 | 60 | def __getitem__(self, index): 61 | 62 | assert index <= len(self), 'index range error' 63 | 64 | image, semantic_annotation, instance_annotation, n_objects \ 65 | = self.__load_data(index) 66 | 67 | return image, semantic_annotation, instance_annotation, \ 68 | n_objects 69 | 70 | def __len__(self): 71 | return self.n_samples 72 | 73 | 74 | class AlignCollate(object): 75 | 76 | def __init__(self, mode, n_classes, max_n_objects, mean, std, image_height, 77 | image_width, random_hor_flipping=True, 78 | random_ver_flipping=True, random_transposing=True, 79 | random_90x_rotation=True, random_rotation=True, 80 | random_color_jittering=True, random_grayscaling=True, 81 | random_channel_swapping=True, random_gamma=True, 82 | random_resolution=True): 83 | 84 | self._mode = mode 85 | self.n_classes = n_classes 86 | self.max_n_objects = max_n_objects 87 | 88 | assert self._mode in ['training', 'test'] 89 | 90 | self.mean = mean 91 | self.std = std 92 | self.image_height = image_height 93 | self.image_width = image_width 94 | 95 | self.random_horizontal_flipping = random_hor_flipping 96 | self.random_vertical_flipping = random_ver_flipping 97 | self.random_transposing = random_transposing 98 | self.random_90x_rotation = random_90x_rotation 99 | self.random_rotation = random_rotation 100 | self.random_color_jittering = random_color_jittering 101 | self.random_grayscaling = random_grayscaling 102 | self.random_channel_swapping = random_channel_swapping 103 | self.random_gamma = random_gamma 104 | self.random_resolution = random_resolution 105 | 106 | if self._mode == 'training': 107 | if self.random_horizontal_flipping: 108 | self.horizontal_flipper = IU.image_random_horizontal_flipper() 109 | if self.random_vertical_flipping: 110 | self.vertical_flipper = IU.image_random_vertical_flipper() 111 | if self.random_transposing: 112 | self.transposer = IU.image_random_transposer() 113 | if self.random_rotation: 114 | self.image_rotator = IU.image_random_rotator(random_bg=True) 115 | self.annotation_rotator = IU.image_random_rotator(Image.NEAREST, 116 | random_bg=False) 117 | if self.random_90x_rotation: 118 | self.image_rotator_90x = IU.image_random_90x_rotator() 119 | self.annotation_rotator_90x = IU.image_random_90x_rotator(Image.NEAREST) 120 | if self.random_color_jittering: 121 | self.color_jitter = IU.image_random_color_jitter( 122 | brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2) 123 | if self.random_grayscaling: 124 | self.grayscaler = IU.image_random_grayscaler(p=0.3) 125 | if self.random_channel_swapping: 126 | self.channel_swapper = IU.image_random_channel_swapper(p=0.5) 127 | if self.random_gamma: 128 | self.gamma_adjuster = IU.image_random_gamma([0.7, 1.3], gain=1) 129 | if self.random_resolution: 130 | self.resolution_degrader = IU.image_random_resolution([0.7, 1.3]) 131 | 132 | self.img_resizer = IU.image_resizer(self.image_height, 133 | self.image_width) 134 | self.ann_resizer = IU.image_resizer(self.image_height, 135 | self.image_width, 136 | interpolation=Image.NEAREST) 137 | else: 138 | 
self.img_resizer = IU.image_resizer(self.image_height, 139 | self.image_width) 140 | self.ann_resizer = IU.image_resizer(self.image_height, 141 | self.image_width, 142 | interpolation=Image.NEAREST) 143 | 144 | self.image_normalizer = IU.image_normalizer(self.mean, self.std) 145 | 146 | def __preprocess(self, image, semantic_annotation, instance_annotation): 147 | 148 | # Augmentation 149 | if self._mode == 'training': 150 | instance_annotation = list(instance_annotation.transpose(2, 0, 1)) 151 | n_objects = len(instance_annotation) 152 | 153 | if self.random_resolution: 154 | image = self.resolution_degrader(image) 155 | 156 | if self.random_horizontal_flipping: 157 | is_flip = random.random() < 0.5 158 | image = self.horizontal_flipper(image, is_flip) 159 | 160 | for i in range(n_objects): 161 | _ann = instance_annotation[i].copy() 162 | _ann = self.horizontal_flipper(_ann, is_flip) 163 | instance_annotation[i] = _ann 164 | 165 | semantic_annotation = self.horizontal_flipper( 166 | semantic_annotation, is_flip) 167 | 168 | if self.random_vertical_flipping: 169 | is_flip = random.random() < 0.5 170 | image = self.vertical_flipper(image, is_flip) 171 | 172 | for i in range(n_objects): 173 | _ann = instance_annotation[i].copy() 174 | _ann = self.vertical_flipper(_ann, is_flip) 175 | instance_annotation[i] = _ann 176 | 177 | semantic_annotation = self.vertical_flipper( 178 | semantic_annotation, is_flip) 179 | 180 | if self.random_transposing: 181 | is_trans = random.random() < 0.5 182 | image = self.transposer(image, is_trans) 183 | 184 | for i in range(n_objects): 185 | _ann = instance_annotation[i].copy() 186 | _ann = self.transposer(_ann, is_trans) 187 | instance_annotation[i] = _ann 188 | 189 | semantic_annotation = self.transposer( 190 | semantic_annotation, is_trans) 191 | 192 | if self.random_90x_rotation: 193 | rot_angle = np.random.choice([0, 90, 180, 270]) 194 | rot_expand = True 195 | image = self.image_rotator_90x(image, rot_angle, rot_expand) 196 | 197 | for i in range(n_objects): 198 | _ann = instance_annotation[i].copy() 199 | _ann = self.annotation_rotator_90x(_ann, rot_angle, rot_expand) 200 | instance_annotation[i] = _ann 201 | 202 | semantic_annotation = self.annotation_rotator_90x(semantic_annotation, 203 | rot_angle, rot_expand) 204 | 205 | if self.random_rotation: 206 | rot_angle = int(np.random.rand() * 10) 207 | if np.random.rand() >= 0.5: 208 | rot_angle = -1 * rot_angle 209 | # rot_expand = np.random.rand() < 0.5 210 | rot_expand = True 211 | image = self.image_rotator(image, rot_angle, rot_expand) 212 | 213 | for i in range(n_objects): 214 | _ann = instance_annotation[i].copy() 215 | _ann = self.annotation_rotator(_ann, rot_angle, rot_expand) 216 | instance_annotation[i] = _ann 217 | 218 | semantic_annotation = self.annotation_rotator(semantic_annotation, 219 | rot_angle, rot_expand) 220 | 221 | if self.random_color_jittering: 222 | image = self.color_jitter(image) 223 | 224 | if self.random_gamma: 225 | image = self.gamma_adjuster(image) 226 | 227 | if self.random_channel_swapping: 228 | image = self.channel_swapper(image) 229 | 230 | if self.random_grayscaling: 231 | image = self.grayscaler(image) 232 | 233 | instance_annotation = np.array( 234 | instance_annotation).transpose(1, 2, 0) 235 | 236 | # Resize Images 237 | image = self.img_resizer(image) 238 | 239 | # Resize Instance Annotations 240 | ann_height, ann_width, n_objects = instance_annotation.shape 241 | instance_annotation_resized = [] 242 | 243 | height_ratio = 1.0 * self.image_height / ann_height 244 
| width_ratio = 1.0 * self.image_width / ann_width 245 | 246 | for i in range(n_objects): 247 | instance_ann_img = Image.fromarray(instance_annotation[:, :, i]) 248 | instance_ann_img = self.ann_resizer(instance_ann_img) 249 | instance_ann_img = np.array(instance_ann_img) 250 | 251 | instance_annotation_resized.append(instance_ann_img) 252 | 253 | # Fill Instance Annotations with zeros 254 | for i in range(self.max_n_objects - n_objects): 255 | zero = np.zeros((ann_height, ann_width), 256 | dtype=np.uint8) 257 | zero = Image.fromarray(zero) 258 | zero = self.ann_resizer(zero) 259 | zero = np.array(zero) 260 | instance_annotation_resized.append(zero.copy()) 261 | 262 | instance_annotation_resized = np.stack( 263 | instance_annotation_resized, axis=0) 264 | instance_annotation_resized = instance_annotation_resized.transpose( 265 | 1, 2, 0) 266 | 267 | # Resize Semantic Anntations 268 | semantic_annotation = self.ann_resizer( 269 | Image.fromarray(semantic_annotation)) 270 | semantic_annotation = np.array(semantic_annotation) 271 | 272 | # Image Normalization 273 | image = self.image_normalizer(image) 274 | 275 | return (image, semantic_annotation, instance_annotation_resized) 276 | 277 | def __call__(self, batch): 278 | images, semantic_annotations, instance_annotations, \ 279 | n_objects = zip(*batch) 280 | 281 | images = list(images) 282 | semantic_annotations = list(semantic_annotations) 283 | instance_annotations = list(instance_annotations) 284 | 285 | # max_n_objects = np.max(n_objects) 286 | 287 | bs = len(images) 288 | for i in range(bs): 289 | image, semantic_annotation, instance_annotation = \ 290 | self.__preprocess(images[i], 291 | semantic_annotations[i], 292 | instance_annotations[i]) 293 | 294 | images[i] = image 295 | semantic_annotations[i] = semantic_annotation 296 | instance_annotations[i] = instance_annotation 297 | 298 | images = torch.stack(images) 299 | 300 | instance_annotations = np.array( 301 | instance_annotations, 302 | dtype='int') # bs, h, w, n_ins 303 | 304 | semantic_annotations = np.array( 305 | semantic_annotations, dtype='int') # bs, h, w 306 | semantic_annotations_one_hot = np.eye(self.n_classes, dtype='int') 307 | semantic_annotations_one_hot = \ 308 | semantic_annotations_one_hot[semantic_annotations.flatten()].reshape( 309 | semantic_annotations.shape[0], semantic_annotations.shape[1], 310 | semantic_annotations.shape[2], self.n_classes) 311 | 312 | instance_annotations = torch.LongTensor(instance_annotations) 313 | instance_annotations = instance_annotations.permute(0, 3, 1, 2) 314 | 315 | semantic_annotations_one_hot = torch.LongTensor( 316 | semantic_annotations_one_hot) 317 | semantic_annotations_one_hot = semantic_annotations_one_hot.permute( 318 | 0, 3, 1, 2) 319 | 320 | n_objects = torch.IntTensor(n_objects) 321 | 322 | return (images, semantic_annotations_one_hot, instance_annotations, 323 | n_objects) 324 | 325 | 326 | if __name__ == '__main__': 327 | ds = SegDataset('../../data/processed/CVPPP/lmdb/training-lmdb/') 328 | image, semantic_annotation, instance_annotation, n_objects = ds[5] 329 | 330 | print image.size 331 | print semantic_annotation.shape 332 | print instance_annotation.shape 333 | print n_objects 334 | print np.unique(semantic_annotation) 335 | print np.unique(instance_annotation) 336 | 337 | ac = AlignCollate('training', 9, 120, [0.0, 0.0, 0.0], 338 | [1.0, 1.0, 1.0], 256, 512) 339 | 340 | loader = torch.utils.data.DataLoader(ds, batch_size=3, 341 | shuffle=False, 342 | num_workers=0, 343 | pin_memory=False, 344 | 
collate_fn=ac) 345 | loader = iter(loader) 346 | 347 | images, semantic_annotations, instance_annotations, \ 348 | n_objects = loader.next() 349 | 350 | print images.size() 351 | print semantic_annotations.size() 352 | print instance_annotations.size() 353 | print n_objects.size() 354 | print n_objects 355 | -------------------------------------------------------------------------------- /code/lib/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from dice import DiceLoss, DiceCoefficient 2 | from discriminative import DiscriminativeLoss 3 | -------------------------------------------------------------------------------- /code/lib/losses/dice.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.loss import _assert_no_grad, _Loss, _WeightedLoss 2 | from torch.nn import functional as F 3 | import torch 4 | import numpy as np 5 | 6 | # https://github.com/rogertrullo/pytorch/blob/rogertrullo-dice_loss/torch/nn/functional.py#L708 7 | 8 | 9 | def dice_coefficient(input, target, smooth=1.0): 10 | """input : is a torch variable of size BatchxnclassesxHxW representing 11 | log probabilities for each class 12 | target : is a 1-hot representation of the groundtruth, shoud have same size 13 | as the input""" 14 | 15 | assert input.size() == target.size(), 'Input sizes must be equal.' 16 | assert input.dim() == 4, 'Input must be a 4D Tensor.' 17 | uniques = np.unique(target.data.cpu().numpy()) 18 | assert set(list(uniques)) <= set( 19 | [0, 1]), 'Target must only contain zeros and ones.' 20 | assert smooth > 0, 'Smooth must be greater than 0.' 21 | 22 | probs = F.softmax(input, dim=1) 23 | target_f = target.float() 24 | 25 | num = probs * target_f # b, c, h, w -- p*g 26 | num = torch.sum(num, dim=3) # b, c, h 27 | num = torch.sum(num, dim=2) # b, c 28 | 29 | den1 = probs * probs # b, c, h, w -- p^2 30 | den1 = torch.sum(den1, dim=3) # b, c, h 31 | den1 = torch.sum(den1, dim=2) # b, c 32 | 33 | den2 = target_f * target_f # b, c, h, w -- g^2 34 | den2 = torch.sum(den2, dim=3) # b, c, h 35 | den2 = torch.sum(den2, dim=2) # b, c 36 | 37 | dice = (2 * num + smooth) / (den1 + den2 + smooth) 38 | 39 | return dice 40 | 41 | 42 | def dice_loss(input, target, optimize_bg=False, weight=None, 43 | smooth=1.0, size_average=True, reduce=True): 44 | """input : is a torch variable of size BatchxnclassesxHxW representing 45 | log probabilities for each class 46 | target : is a 1-hot representation of the groundtruth, shoud have same size 47 | as the input 48 | 49 | weight (Variable, optional): a manual rescaling weight given to each 50 | class. 
If given, has to be a Variable of size "nclasses""" 51 | 52 | dice = dice_coefficient(input, target, smooth=smooth) 53 | 54 | if not optimize_bg: 55 | # we ignore bg dice val, and take the fg 56 | dice = dice[:, 1:] 57 | 58 | if not isinstance(weight, type(None)): 59 | if not optimize_bg: 60 | weight = weight[1:] # ignore bg weight 61 | weight = weight.size(0) * weight / weight.sum() # normalize fg weights 62 | dice = dice * weight # weighting 63 | 64 | # loss is calculated using mean over fg dice vals 65 | dice_loss = 1 - dice.mean(1) 66 | 67 | if not reduce: 68 | return dice_loss 69 | 70 | if size_average: 71 | return dice_loss.mean() 72 | 73 | return dice_loss.sum() 74 | 75 | 76 | class DiceLoss(_WeightedLoss): 77 | 78 | def __init__(self, optimize_bg=False, weight=None, 79 | smooth=1.0, size_average=True, reduce=True): 80 | """input : is a torch variable of size BatchxnclassesxHxW representing 81 | log probabilities for each class 82 | target : is a 1-hot representation of the groundtruth, shoud have same 83 | size as the input 84 | 85 | weight (Variable, optional): a manual rescaling weight given to each 86 | class. If given, has to be a Variable of size "nclasses""" 87 | 88 | super(DiceLoss, self).__init__(weight, size_average) 89 | self.optimize_bg = optimize_bg 90 | self.smooth = smooth 91 | self.reduce = reduce 92 | 93 | def forward(self, input, target): 94 | _assert_no_grad(target) 95 | return dice_loss(input, target, optimize_bg=self.optimize_bg, 96 | weight=self.weight, smooth=self.smooth, 97 | size_average=self.size_average, 98 | reduce=self.reduce) 99 | 100 | 101 | class DiceCoefficient(torch.nn.Module): 102 | 103 | def __init__(self, smooth=1.0): 104 | """input : is a torch variable of size BatchxnclassesxHxW representing 105 | log probabilities for each class 106 | target : is a 1-hot representation of the groundtruth, shoud have same 107 | size as the input""" 108 | super(DiceCoefficient, self).__init__() 109 | 110 | self.smooth = smooth 111 | 112 | def forward(self, input, target): 113 | _assert_no_grad(target) 114 | return dice_coefficient(input, target, smooth=self.smooth) 115 | 116 | 117 | if __name__ == '__main__': 118 | from torch.autograd import Variable 119 | input = torch.FloatTensor([[-3, -1, 100, -20], [-5, -20, 5, 5]]) 120 | input = Variable(input.unsqueeze(2).unsqueeze(3)) 121 | target = torch.IntTensor([[0, 0, 1, 0], [0, 0, 0, 1]]) 122 | target = Variable(target.unsqueeze(2).unsqueeze(3)) 123 | 124 | weight = Variable(torch.FloatTensor(np.array([1.0, 1.0, 1.0, 1.0]))) 125 | 126 | dice_loss_1 = DiceLoss(weight=weight) 127 | # dice_loss_2 = DiceLoss(size_average=False) 128 | # dice_loss_3 = DiceLoss(reduce=False) 129 | 130 | print dice_loss_1(input, target) 131 | # print dice_loss_2(input, target) 132 | # print dice_loss_3(input, target) 133 | print dice_coefficient(input, target, smooth=1.0) 134 | -------------------------------------------------------------------------------- /code/lib/losses/discriminative.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.loss import _assert_no_grad, _Loss 2 | from torch.autograd import Variable 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def calculate_means(pred, gt, n_objects, max_n_objects, usegpu): 8 | """pred: bs, height * width, n_filters 9 | gt: bs, height * width, n_instances""" 10 | 11 | bs, n_loc, n_filters = pred.size() 12 | n_instances = gt.size(2) 13 | 14 | pred_repeated = pred.unsqueeze(2).expand( 15 | bs, n_loc, n_instances, n_filters) # bs, 
n_loc, n_instances, n_filters 16 | # bs, n_loc, n_instances, 1 17 | gt_expanded = gt.unsqueeze(3) 18 | 19 | pred_masked = pred_repeated * gt_expanded 20 | 21 | means = [] 22 | for i in range(bs): 23 | _n_objects_sample = n_objects[i] 24 | # n_loc, n_objects, n_filters 25 | _pred_masked_sample = pred_masked[i, :, : _n_objects_sample] 26 | # n_loc, n_objects, 1 27 | _gt_expanded_sample = gt_expanded[i, :, : _n_objects_sample] 28 | 29 | _mean_sample = _pred_masked_sample.sum( 30 | 0) / _gt_expanded_sample.sum(0) # n_objects, n_filters 31 | if (max_n_objects - _n_objects_sample) != 0: 32 | n_fill_objects = int(max_n_objects - _n_objects_sample) 33 | _fill_sample = torch.zeros(n_fill_objects, n_filters) 34 | if usegpu: 35 | _fill_sample = _fill_sample.cuda() 36 | _fill_sample = Variable(_fill_sample) 37 | _mean_sample = torch.cat((_mean_sample, _fill_sample), dim=0) 38 | means.append(_mean_sample) 39 | 40 | means = torch.stack(means) 41 | 42 | # means = pred_masked.sum(1) / gt_expanded.sum(1) 43 | # # bs, n_instances, n_filters 44 | 45 | return means 46 | 47 | 48 | def calculate_variance_term(pred, gt, means, n_objects, delta_v, norm=2): 49 | """pred: bs, height * width, n_filters 50 | gt: bs, height * width, n_instances 51 | means: bs, n_instances, n_filters""" 52 | 53 | bs, n_loc, n_filters = pred.size() 54 | n_instances = gt.size(2) 55 | 56 | # bs, n_loc, n_instances, n_filters 57 | means = means.unsqueeze(1).expand(bs, n_loc, n_instances, n_filters) 58 | # bs, n_loc, n_instances, n_filters 59 | pred = pred.unsqueeze(2).expand(bs, n_loc, n_instances, n_filters) 60 | # bs, n_loc, n_instances, n_filters 61 | gt = gt.unsqueeze(3).expand(bs, n_loc, n_instances, n_filters) 62 | 63 | _var = (torch.clamp(torch.norm((pred - means), norm, 3) - 64 | delta_v, min=0.0) ** 2) * gt[:, :, :, 0] 65 | 66 | var_term = 0.0 67 | for i in range(bs): 68 | _var_sample = _var[i, :, :n_objects[i]] # n_loc, n_objects 69 | _gt_sample = gt[i, :, :n_objects[i], 0] # n_loc, n_objects 70 | 71 | var_term += torch.sum(_var_sample) / torch.sum(_gt_sample) 72 | var_term = var_term / bs 73 | 74 | return var_term 75 | 76 | 77 | def calculate_distance_term(means, n_objects, delta_d, norm=2, usegpu=True): 78 | """means: bs, n_instances, n_filters""" 79 | 80 | bs, n_instances, n_filters = means.size() 81 | 82 | dist_term = 0.0 83 | for i in range(bs): 84 | _n_objects_sample = int(n_objects[i]) 85 | 86 | if _n_objects_sample <= 1: 87 | continue 88 | 89 | _mean_sample = means[i, : _n_objects_sample, :] # n_objects, n_filters 90 | means_1 = _mean_sample.unsqueeze(1).expand( 91 | _n_objects_sample, _n_objects_sample, n_filters) 92 | means_2 = means_1.permute(1, 0, 2) 93 | 94 | diff = means_1 - means_2 # n_objects, n_objects, n_filters 95 | 96 | _norm = torch.norm(diff, norm, 2) 97 | 98 | margin = 2 * delta_d * (1.0 - torch.eye(_n_objects_sample)) 99 | if usegpu: 100 | margin = margin.cuda() 101 | margin = Variable(margin) 102 | 103 | _dist_term_sample = torch.sum( 104 | torch.clamp(margin - _norm, min=0.0) ** 2) 105 | _dist_term_sample = _dist_term_sample / \ 106 | (_n_objects_sample * (_n_objects_sample - 1)) 107 | dist_term += _dist_term_sample 108 | 109 | dist_term = dist_term / bs 110 | 111 | return dist_term 112 | 113 | 114 | def calculate_regularization_term(means, n_objects, norm): 115 | """means: bs, n_instances, n_filters""" 116 | 117 | bs, n_instances, n_filters = means.size() 118 | 119 | reg_term = 0.0 120 | for i in range(bs): 121 | _mean_sample = means[i, : n_objects[i], :] # n_objects, n_filters 122 | _norm = 
torch.norm(_mean_sample, norm, 1) 123 | reg_term += torch.mean(_norm) 124 | reg_term = reg_term / bs 125 | 126 | return reg_term 127 | 128 | 129 | def discriminative_loss(input, target, n_objects, 130 | max_n_objects, delta_v, delta_d, norm, usegpu): 131 | """input: bs, n_filters, fmap, fmap 132 | target: bs, n_instances, fmap, fmap 133 | n_objects: bs""" 134 | 135 | alpha = beta = 1.0 136 | gamma = 0.001 137 | 138 | bs, n_filters, height, width = input.size() 139 | n_instances = target.size(1) 140 | 141 | input = input.permute(0, 2, 3, 1).contiguous().view( 142 | bs, height * width, n_filters) 143 | target = target.permute(0, 2, 3, 1).contiguous().view( 144 | bs, height * width, n_instances) 145 | 146 | cluster_means = calculate_means( 147 | input, target, n_objects, max_n_objects, usegpu) 148 | 149 | var_term = calculate_variance_term( 150 | input, target, cluster_means, n_objects, delta_v, norm) 151 | dist_term = calculate_distance_term( 152 | cluster_means, n_objects, delta_d, norm, usegpu) 153 | reg_term = calculate_regularization_term(cluster_means, n_objects, norm) 154 | 155 | loss = alpha * var_term + beta * dist_term + gamma * reg_term 156 | 157 | return loss 158 | 159 | 160 | class DiscriminativeLoss(_Loss): 161 | 162 | def __init__(self, delta_var, delta_dist, norm, 163 | size_average=True, reduce=True, usegpu=True): 164 | super(DiscriminativeLoss, self).__init__(size_average) 165 | self.reduce = reduce 166 | 167 | assert self.size_average 168 | assert self.reduce 169 | 170 | self.delta_var = float(delta_var) 171 | self.delta_dist = float(delta_dist) 172 | self.norm = int(norm) 173 | self.usegpu = usegpu 174 | 175 | assert self.norm in [1, 2] 176 | 177 | def forward(self, input, target, n_objects, max_n_objects): 178 | _assert_no_grad(target) 179 | return discriminative_loss(input, target, n_objects, max_n_objects, 180 | self.delta_var, self.delta_dist, self.norm, 181 | self.usegpu) 182 | -------------------------------------------------------------------------------- /code/lib/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import visdom 4 | from sklearn.manifold import TSNE 5 | import torch 6 | import torch.optim as optim 7 | from torch.optim.lr_scheduler import ReduceLROnPlateau 8 | from torch.autograd import Variable 9 | import torch.backends.cudnn as cudnn 10 | import numpy as np 11 | from itertools import ifilter 12 | 13 | from archs import ReSeg 14 | from archs import StackedRecurrentHourglass as SRecHg 15 | from losses import DiceLoss, DiceCoefficient, DiscriminativeLoss 16 | 17 | 18 | class Model(object): 19 | 20 | def __init__(self, dataset, model_name, n_classes, max_n_objects, 21 | use_instance_segmentation=False, use_coords=False, 22 | load_model_path='', usegpu=True): 23 | 24 | self.dataset = dataset 25 | self.model_name = model_name 26 | self.n_classes = n_classes 27 | self.max_n_objects = max_n_objects 28 | self.use_instance_segmentation = use_instance_segmentation 29 | self.use_coords = use_coords 30 | self.load_model_path = load_model_path 31 | self.usegpu = usegpu 32 | 33 | assert self.dataset in ['CVPPP', ] 34 | assert self.model_name in ['ReSeg', 'StackedRecurrentHourglass'] 35 | 36 | if self.dataset == 'CVPPP': 37 | if self.model_name == 'ReSeg': 38 | self.model = ReSeg(self.n_classes, 39 | self.use_instance_segmentation, 40 | pretrained=True, 41 | use_coordinates=self.use_coords, 42 | usegpu=self.usegpu) 43 | elif self.model_name == 'StackedRecurrentHourglass': 44 | self.model = 
SRecHg(self.n_classes, 45 | self.use_instance_segmentation, 46 | self.use_coords, 47 | pretrained=True, 48 | usegpu=self.usegpu) 49 | 50 | self.__load_weights() 51 | 52 | if self.usegpu: 53 | cudnn.benchmark = True 54 | self.model.cuda() 55 | # self.model = torch.nn.DataParallel(self.model, 56 | # device_ids=range(self.ngpus)) 57 | 58 | print self.model 59 | 60 | self.vis = visdom.Visdom() 61 | self.training_metric_vis, self.test_metric_vis = None, None 62 | if self.use_instance_segmentation: 63 | self.instance_seg_vis = None 64 | 65 | def __load_weights(self): 66 | 67 | if self.load_model_path != '': 68 | assert os.path.isfile(self.load_model_path), 'Model : {} does not \ 69 | exists!'.format(self.load_model_path) 70 | print 'Loading model from {}'.format(self.load_model_path) 71 | 72 | model_state_dict = self.model.state_dict() 73 | 74 | if self.usegpu: 75 | pretrained_state_dict = torch.load(self.load_model_path) 76 | else: 77 | pretrained_state_dict = torch.load( 78 | self.load_model_path, map_location=lambda storage, 79 | loc: storage) 80 | 81 | model_state_dict.update(pretrained_state_dict) 82 | self.model.load_state_dict(model_state_dict) 83 | 84 | def __define_variable(self, tensor, volatile=False): 85 | if volatile: 86 | with torch.no_grad(): 87 | return Variable(tensor) 88 | 89 | return Variable(tensor) 90 | 91 | def __define_input_variables( 92 | self, features, fg_labels, ins_labels, n_objects, mode): 93 | 94 | volatile = True 95 | if mode == 'training': 96 | volatile = False 97 | 98 | features_var = self.__define_variable(features, volatile=volatile) 99 | fg_labels_var = self.__define_variable(fg_labels, volatile=volatile) 100 | ins_labels_var = self.__define_variable(ins_labels, volatile=volatile) 101 | n_objects_var = self.__define_variable(n_objects, volatile=volatile) 102 | 103 | return features_var, fg_labels_var, ins_labels_var, n_objects_var 104 | 105 | def __define_criterion(self, class_weights, delta_var, 106 | delta_dist, norm=2, optimize_bg=False, 107 | criterion='CE'): 108 | assert criterion in ['CE', 'Dice', 'Multi', None] 109 | 110 | smooth = 1.0 111 | 112 | # Discriminative Loss 113 | if self.use_instance_segmentation: 114 | self.criterion_discriminative = DiscriminativeLoss( 115 | delta_var, delta_dist, norm, self.usegpu) 116 | if self.usegpu: 117 | self.criterion_discriminative = \ 118 | self.criterion_discriminative.cuda() 119 | 120 | # FG Segmentation Loss 121 | if class_weights is not None: 122 | class_weights = self.__define_variable( 123 | torch.FloatTensor(class_weights)) 124 | if criterion in ['CE', 'Multi']: 125 | self.criterion_ce = torch.nn.CrossEntropyLoss(class_weights) 126 | if criterion in ['Dice', 'Multi']: 127 | self.criterion_dice = DiceLoss( 128 | optimize_bg=optimize_bg, weight=class_weights, 129 | smooth=smooth) 130 | else: 131 | if criterion in ['CE', 'Multi']: 132 | self.criterion_ce = torch.nn.CrossEntropyLoss() 133 | if criterion in ['Dice', 'Multi']: 134 | self.criterion_dice = DiceLoss( 135 | optimize_bg=optimize_bg, smooth=smooth) 136 | 137 | # MSE Loss 138 | self.criterion_mse = torch.nn.MSELoss() 139 | 140 | if self.usegpu: 141 | if criterion in ['CE', 'Multi']: 142 | self.criterion_ce = self.criterion_ce.cuda() 143 | if criterion in ['Dice', 'Multi']: 144 | self.criterion_dice = self.criterion_dice.cuda() 145 | 146 | self.criterion_mse = self.criterion_mse.cuda() 147 | 148 | def __define_optimizer(self, learning_rate, weight_decay, 149 | lr_drop_factor, lr_drop_patience, optimizer='Adam'): 150 | assert optimizer in ['RMSprop', 
'Adam', 'Adadelta', 'SGD'] 151 | 152 | parameters = ifilter(lambda p: p.requires_grad, 153 | self.model.parameters()) 154 | 155 | if optimizer == 'RMSprop': 156 | self.optimizer = optim.RMSprop( 157 | parameters, lr=learning_rate, weight_decay=weight_decay) 158 | elif optimizer == 'Adadelta': 159 | self.optimizer = optim.Adadelta( 160 | parameters, lr=learning_rate, weight_decay=weight_decay) 161 | elif optimizer == 'Adam': 162 | self.optimizer = optim.Adam( 163 | parameters, lr=learning_rate, weight_decay=weight_decay) 164 | elif optimizer == 'SGD': 165 | self.optimizer = optim.SGD( 166 | parameters, lr=learning_rate, momentum=0.9, 167 | weight_decay=weight_decay) 168 | 169 | self.lr_scheduler = ReduceLROnPlateau( 170 | self.optimizer, mode='min', factor=lr_drop_factor, 171 | patience=lr_drop_patience, verbose=True) 172 | 173 | @staticmethod 174 | def __get_loss_averager(): 175 | return averager() 176 | 177 | def __minibatch(self, train_test_iter, clip_grad_norm, 178 | criterion_type, train_cnn=True, mode='training', 179 | debug=False): 180 | assert mode in ['training', 181 | 'test'], 'Mode must be either "training" or "test"' 182 | 183 | if mode == 'training': 184 | for param in self.model.parameters(): 185 | param.requires_grad = True 186 | if not train_cnn: 187 | for param in self.model.cnn.parameters(): 188 | param.requires_grad = False 189 | self.model.train() 190 | else: 191 | for param in self.model.parameters(): 192 | param.requires_grad = False 193 | self.model.eval() 194 | 195 | cpu_images, cpu_sem_seg_annotations, \ 196 | cpu_ins_seg_annotations, cpu_n_objects = train_test_iter.next() 197 | cpu_images = cpu_images.contiguous() 198 | cpu_sem_seg_annotations = cpu_sem_seg_annotations.contiguous() 199 | cpu_ins_seg_annotations = cpu_ins_seg_annotations.contiguous() 200 | cpu_n_objects = cpu_n_objects.contiguous() 201 | 202 | if self.usegpu: 203 | gpu_images = cpu_images.cuda(async=True) 204 | gpu_sem_seg_annotations = cpu_sem_seg_annotations.cuda(async=True) 205 | gpu_ins_seg_annotations = cpu_ins_seg_annotations.cuda(async=True) 206 | gpu_n_objects = cpu_n_objects.cuda(async=True) 207 | else: 208 | gpu_images = cpu_images 209 | gpu_sem_seg_annotations = cpu_sem_seg_annotations 210 | gpu_ins_seg_annotations = cpu_ins_seg_annotations 211 | gpu_n_objects = cpu_n_objects 212 | 213 | gpu_images, gpu_sem_seg_annotations, \ 214 | gpu_ins_seg_annotations, gpu_n_objects = \ 215 | self.__define_input_variables(gpu_images, 216 | gpu_sem_seg_annotations, 217 | gpu_ins_seg_annotations, 218 | gpu_n_objects, mode) 219 | 220 | gpu_n_objects = gpu_n_objects.unsqueeze(dim=1) 221 | 222 | gpu_n_objects_normalized = gpu_n_objects.float() / self.max_n_objects 223 | 224 | sem_seg_predictions, ins_seg_predictions, \ 225 | n_objects_predictions = self.model(gpu_images) 226 | 227 | if mode == 'test': 228 | if debug: 229 | _vis_prob = np.random.rand() 230 | if _vis_prob > 0.7: 231 | if self.use_instance_segmentation: 232 | sem_seg_preds = np.argmax( 233 | sem_seg_predictions.data.cpu().numpy(), axis=1) 234 | seg_preds = ins_seg_predictions.data.cpu().numpy() 235 | 236 | _bs, _n_feats = seg_preds.shape[:2] 237 | 238 | _sample_idx = np.random.randint(_bs) 239 | _sem_seg_preds_sample = sem_seg_preds[_sample_idx] 240 | _seg_preds_sample = seg_preds[_sample_idx] 241 | 242 | fg_ins_embeddings = np.stack( 243 | [_seg_preds_sample[i][np.where( 244 | _sem_seg_preds_sample == 1)] 245 | for i in range(_n_feats)], axis=1) 246 | _n_fg_samples = fg_ins_embeddings.shape[0] 247 | if _n_fg_samples > 0: 248 | 
fg_ins_embeddings = \ 249 | fg_ins_embeddings[np.random.choice( 250 | range(_n_fg_samples), size=400)] 251 | 252 | tsne = TSNE(n_components=2, random_state=0) 253 | fg_ins_embeddings_vis = tsne.fit_transform( 254 | fg_ins_embeddings) 255 | 256 | if self.instance_seg_vis: 257 | self.vis.scatter(X=fg_ins_embeddings_vis, 258 | win=self.instance_seg_vis, 259 | opts={'title': 260 | 'Predicted Embeddings \ 261 | for Foreground \ 262 | Predictions', 263 | 'markersize': 2}) 264 | else: 265 | self.instance_seg_vis =\ 266 | self.vis.scatter(X=fg_ins_embeddings_vis, 267 | opts={'title': 268 | 'Predicted \ 269 | Embeddings for \ 270 | Foreground \ 271 | Predictions', 272 | 'markersize': 2}) 273 | 274 | cost = 0.0 275 | out_metrics = dict() 276 | 277 | if self.use_instance_segmentation: 278 | disc_cost = self.criterion_discriminative( 279 | ins_seg_predictions, gpu_ins_seg_annotations.float(), 280 | cpu_n_objects, self.max_n_objects) 281 | cost += disc_cost 282 | out_metrics['Discriminative Cost'] = disc_cost.data 283 | 284 | if criterion_type in ['CE', 'Multi']: 285 | _, gpu_sem_seg_annotations_criterion_ce = \ 286 | gpu_sem_seg_annotations.max(1) 287 | ce_cost = self.criterion_ce( 288 | sem_seg_predictions.permute(0, 2, 3, 1).contiguous().view( 289 | -1, self.n_classes), 290 | gpu_sem_seg_annotations_criterion_ce.view(-1)) 291 | cost += ce_cost 292 | out_metrics['CE Cost'] = ce_cost.data 293 | if criterion_type in ['Dice', 'Multi']: 294 | dice_cost = self.criterion_dice( 295 | sem_seg_predictions, gpu_sem_seg_annotations) 296 | cost += dice_cost 297 | out_metrics['Dice Cost'] = dice_cost.data 298 | 299 | mse_cost = self.criterion_mse( 300 | n_objects_predictions, gpu_n_objects_normalized) 301 | cost += mse_cost 302 | out_metrics['MSE Cost'] = mse_cost.data 303 | 304 | if mode == 'training': 305 | self.model.zero_grad() 306 | cost.backward() 307 | if clip_grad_norm != 0: 308 | torch.nn.utils.clip_grad_norm_( 309 | self.model.parameters(), clip_grad_norm) 310 | self.optimizer.step() 311 | 312 | return out_metrics 313 | 314 | def __test(self, test_loader, criterion_type, epoch, debug): 315 | 316 | n_minibatches = len(test_loader) 317 | 318 | test_iter = iter(test_loader) 319 | 320 | out_metrics = dict() 321 | for minibatch_index in range(n_minibatches): 322 | mb_out_metrics = self.__minibatch( 323 | test_iter, 0.0, criterion_type, train_cnn=False, mode='test', 324 | debug=debug) 325 | for mk, mv in mb_out_metrics.iteritems(): 326 | if mk not in out_metrics: 327 | out_metrics[mk] = [] 328 | out_metrics[mk].append(mv) 329 | 330 | test_metric_vis_data, test_metric_vis_legend = [], [] 331 | metrics_as_str = 'Testing: [METRIC]' 332 | for mk, mv in out_metrics.iteritems(): 333 | out_metrics[mk] = torch.stack(mv, dim=0).mean() 334 | metrics_as_str += ' {} : {} |'.format(mk, out_metrics[mk]) 335 | 336 | test_metric_vis_data.append(out_metrics[mk]) 337 | test_metric_vis_legend.append(mk) 338 | 339 | print metrics_as_str 340 | 341 | test_metric_vis_data = np.expand_dims( 342 | np.array(test_metric_vis_data), 0) 343 | 344 | if self.test_metric_vis: 345 | self.vis.line(X=np.array([epoch]), 346 | Y=test_metric_vis_data, 347 | win=self.test_metric_vis, 348 | update='append') 349 | else: 350 | self.test_metric_vis = self.vis.line(X=np.array([epoch]), 351 | Y=test_metric_vis_data, 352 | opts={'legend': 353 | test_metric_vis_legend, 354 | 'title': 'Test Metrics', 355 | 'showlegend': True, 356 | 'xlabel': 'Epoch', 357 | 'ylabel': 'Metric'}) 358 | 359 | return out_metrics 360 | 361 | def fit(self, criterion_type, 
delta_var, delta_dist, norm, 362 | learning_rate, weight_decay, clip_grad_norm, 363 | lr_drop_factor, lr_drop_patience, optimize_bg, optimizer, 364 | train_cnn, n_epochs, class_weights, train_loader, test_loader, 365 | model_save_path, debug): 366 | 367 | assert criterion_type in ['CE', 'Dice', 'Multi'] 368 | 369 | training_log_file = open(os.path.join( 370 | model_save_path, 'training.log'), 'w') 371 | validation_log_file = open(os.path.join( 372 | model_save_path, 'validation.log'), 'w') 373 | 374 | training_log_file.write('Epoch,Cost\n') 375 | validation_log_file.write('Epoch,Cost\n') 376 | 377 | self.__define_criterion(class_weights, delta_var, delta_dist, 378 | norm=norm, optimize_bg=optimize_bg, 379 | criterion=criterion_type) 380 | self.__define_optimizer(learning_rate, weight_decay, 381 | lr_drop_factor, lr_drop_patience, 382 | optimizer=optimizer) 383 | 384 | self.__test(test_loader, criterion_type, -1.0, debug) 385 | 386 | best_val_cost = np.Inf 387 | for epoch in range(n_epochs): 388 | epoch_start = time.time() 389 | 390 | train_iter = iter(train_loader) 391 | n_minibatches = len(train_loader) 392 | 393 | train_out_metrics = dict() 394 | 395 | minibatch_index = 0 396 | while minibatch_index < n_minibatches: 397 | mb_out_metrics = self.__minibatch(train_iter, clip_grad_norm, 398 | criterion_type, 399 | train_cnn=train_cnn, 400 | mode='training', debug=debug) 401 | for mk, mv in mb_out_metrics.iteritems(): 402 | if mk not in train_out_metrics: 403 | train_out_metrics[mk] = [] 404 | train_out_metrics[mk].append(mv) 405 | 406 | minibatch_index += 1 407 | 408 | epoch_end = time.time() 409 | epoch_duration = epoch_end - epoch_start 410 | 411 | training_metric_vis_data, training_metric_vis_legend = [], [] 412 | 413 | print 'Epoch : [{}/{}] - [{}]'.format(epoch, 414 | n_epochs, epoch_duration) 415 | metrics_as_str = 'Training: [METRIC]' 416 | for mk, mv in train_out_metrics.iteritems(): 417 | train_out_metrics[mk] = torch.stack(mv, dim=0).mean() 418 | metrics_as_str += ' {} : {} |'.format(mk, 419 | train_out_metrics[mk]) 420 | 421 | training_metric_vis_data.append(train_out_metrics[mk]) 422 | training_metric_vis_legend.append(mk) 423 | 424 | print metrics_as_str 425 | 426 | training_metric_vis_data = np.expand_dims( 427 | np.array(training_metric_vis_data), 0) 428 | 429 | if self.training_metric_vis: 430 | self.vis.line(X=np.array([epoch]), 431 | Y=training_metric_vis_data, 432 | win=self.training_metric_vis, update='append') 433 | else: 434 | self.training_metric_vis = self.vis.line( 435 | X=np.array([epoch]), Y=training_metric_vis_data, 436 | opts={'legend': training_metric_vis_legend, 437 | 'title': 'Training Metrics', 438 | 'showlegend': True, 'xlabel': 'Epoch', 439 | 'ylabel': 'Metric'}) 440 | 441 | val_out_metrics = self.__test( 442 | test_loader, criterion_type, epoch, debug) 443 | if self.use_instance_segmentation: 444 | val_cost = val_out_metrics['Discriminative Cost'] 445 | train_cost = train_out_metrics['Discriminative Cost'] 446 | elif criterion_type in ['Dice', 'Multi']: 447 | val_cost = val_out_metrics['Dice Cost'] 448 | train_cost = train_out_metrics['Dice Cost'] 449 | else: 450 | val_cost = val_out_metrics['CE Cost'] 451 | train_cost = train_out_metrics['CE Cost'] 452 | 453 | self.lr_scheduler.step(val_cost) 454 | 455 | is_best_model = val_cost <= best_val_cost 456 | 457 | if is_best_model: 458 | best_val_cost = val_cost 459 | torch.save(self.model.state_dict(), os.path.join( 460 | model_save_path, 'model_{}_{}.pth'.format(epoch, 461 | val_cost))) 462 | 463 | 
training_log_file.write('{},{}\n'.format(epoch, train_cost)) 464 | validation_log_file.write('{},{}\n'.format(epoch, val_cost)) 465 | training_log_file.flush() 466 | validation_log_file.flush() 467 | 468 | training_log_file.close() 469 | validation_log_file.close() 470 | 471 | def predict(self, images): 472 | 473 | assert len(images.size()) == 4 # b, c, h, w 474 | 475 | for param in self.model.parameters(): 476 | param.requires_grad = False 477 | self.model.eval() 478 | 479 | images = images.contiguous() 480 | if self.usegpu: 481 | images = images.cuda(async=True) 482 | 483 | images = self.__define_variable(images, volatile=True) 484 | 485 | sem_seg_predictions, ins_seg_predictions, n_objects_predictions = \ 486 | self.model(images) 487 | 488 | sem_seg_predictions = torch.nn.functional.softmax( 489 | sem_seg_predictions, dim=1) 490 | 491 | n_objects_predictions = n_objects_predictions * self.max_n_objects 492 | n_objects_predictions = torch.round(n_objects_predictions).int() 493 | 494 | sem_seg_predictions = sem_seg_predictions.data.cpu() 495 | ins_seg_predictions = ins_seg_predictions.data.cpu() 496 | n_objects_predictions = n_objects_predictions.data.cpu() 497 | 498 | return sem_seg_predictions, ins_seg_predictions, n_objects_predictions 499 | 500 | 501 | class averager(object): 502 | """Compute average for `torch.Variable` and `torch.Tensor`.""" 503 | 504 | def __init__(self): 505 | self.reset() 506 | 507 | def add(self, v): 508 | if isinstance(v, Variable): 509 | count = v.data.numel() 510 | v = v.data.sum() 511 | elif isinstance(v, torch.Tensor): 512 | count = v.numel() 513 | v = v.sum() 514 | 515 | self.n_count += count 516 | self.sum += v 517 | 518 | def reset(self): 519 | self.n_count = 0 520 | self.sum = 0 521 | 522 | def val(self): 523 | res = 0 524 | if self.n_count != 0: 525 | res = self.sum / float(self.n_count) 526 | return res 527 | -------------------------------------------------------------------------------- /code/lib/prediction.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from PIL import Image 4 | import cv2 5 | from sklearn.cluster import KMeans 6 | from utils import ImageUtilities 7 | from archs.modules.coord_conv import AddCoordinates 8 | 9 | 10 | class Prediction(object): 11 | 12 | def __init__(self, resize_height, resize_width, mean, 13 | std, use_coordinates, model, n_workers): 14 | 15 | self.normalizer = ImageUtilities.image_normalizer(mean, std) 16 | self.use_coordinates = use_coordinates 17 | 18 | self.resize_height = resize_height 19 | self.resize_width = resize_width 20 | self.model = model 21 | 22 | self.n_workers = n_workers 23 | 24 | self.img_resizer = ImageUtilities.image_resizer( 25 | self.resize_height, self.resize_width) 26 | 27 | if self.use_coordinates: 28 | self.coordinate_adder = AddCoordinates(with_r=True, 29 | usegpu=False) 30 | 31 | def get_image(self, image_path): 32 | 33 | img = ImageUtilities.read_image(image_path) 34 | image_width, image_height = img.size 35 | 36 | img = self.img_resizer(img) 37 | img = self.normalizer(img) 38 | 39 | return img, image_height, image_width 40 | 41 | def get_annotation(self, annotation_path): 42 | 43 | img = ImageUtilities.read_image(annotation_path) 44 | return img 45 | 46 | def upsample_prediction(self, prediction, image_height, image_width): 47 | 48 | return cv2.resize(prediction, (image_width, image_height), 49 | interpolation=cv2.INTER_NEAREST) 50 | 51 | def cluster(self, sem_seg_prediction, ins_seg_prediction, 52 
| n_objects_prediction): 53 | 54 | seg_height, seg_width = ins_seg_prediction.shape[1:] 55 | 56 | sem_seg_prediction = sem_seg_prediction.cpu().numpy() 57 | sem_seg_prediction = sem_seg_prediction.argmax(0).astype(np.uint8) 58 | 59 | embeddings = ins_seg_prediction.cpu() 60 | if self.use_coordinates: 61 | embeddings = self.coordinate_adder(embeddings) 62 | embeddings = embeddings.numpy() 63 | embeddings = embeddings.transpose(1, 2, 0) # h, w, c 64 | 65 | n_objects_prediction = n_objects_prediction.cpu().numpy()[0] 66 | 67 | embeddings = np.stack([embeddings[:, :, i][sem_seg_prediction != 0] 68 | for i in range(embeddings.shape[2])], axis=1) 69 | 70 | 71 | labels = KMeans(n_clusters=n_objects_prediction, 72 | n_init=35, max_iter=500, 73 | n_jobs=self.n_workers).fit_predict(embeddings) 74 | 75 | instance_mask = np.zeros((seg_height, seg_width), dtype=np.uint8) 76 | 77 | fg_coords = np.where(sem_seg_prediction != 0) 78 | for si in range(len(fg_coords[0])): 79 | y_coord = fg_coords[0][si] 80 | x_coord = fg_coords[1][si] 81 | _label = labels[si] + 1 82 | instance_mask[y_coord, x_coord] = _label 83 | 84 | return sem_seg_prediction, instance_mask, n_objects_prediction 85 | 86 | def predict(self, image_path): 87 | 88 | image, image_height, image_width = self.get_image(image_path) 89 | image = image.unsqueeze(0) 90 | 91 | sem_seg_prediction, ins_seg_prediction, n_objects_prediction = \ 92 | self.model.predict(image) 93 | 94 | sem_seg_prediction = sem_seg_prediction.squeeze(0) 95 | ins_seg_prediction = ins_seg_prediction.squeeze(0) 96 | n_objects_prediction = n_objects_prediction.squeeze(0) 97 | 98 | sem_seg_prediction, ins_seg_prediction, \ 99 | n_objects_prediction = self.cluster(sem_seg_prediction, 100 | ins_seg_prediction, 101 | n_objects_prediction) 102 | 103 | sem_seg_prediction = self.upsample_prediction( 104 | sem_seg_prediction, image_height, image_width) 105 | ins_seg_prediction = self.upsample_prediction( 106 | ins_seg_prediction, image_height, image_width) 107 | 108 | raw_image_pil = ImageUtilities.read_image(image_path) 109 | raw_image = np.array(raw_image_pil) 110 | 111 | return raw_image, sem_seg_prediction, ins_seg_prediction, \ 112 | n_objects_prediction 113 | -------------------------------------------------------------------------------- /code/lib/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | try: 4 | import accimage 5 | except ImportError: 6 | accimage = None 7 | import random 8 | import math 9 | import numbers 10 | import collections 11 | import numpy as np 12 | 13 | 14 | def _is_pil_image(img): 15 | if accimage is not None: 16 | return isinstance(img, (Image.Image, accimage.Image)) 17 | else: 18 | return isinstance(img, Image.Image) 19 | 20 | def crop(img, i, j, h, w): 21 | """Crop the given PIL Image. 22 | Args: 23 | img (PIL Image): Image to be cropped. 24 | i: Upper pixel coordinate. 25 | j: Left pixel coordinate. 26 | h: Height of the cropped image. 27 | w: Width of the cropped image. 28 | Returns: 29 | PIL Image: Cropped image. 30 | """ 31 | if not _is_pil_image(img): 32 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 33 | 34 | return img.crop((j, i, j + w, i + h)) 35 | 36 | def resize(img, size, interpolation=Image.BILINEAR): 37 | """Resize the input PIL Image to the given size. 38 | Args: 39 | img (PIL Image): Image to be resized. 40 | size (sequence or int): Desired output size. 
If size is a sequence like 41 | (h, w), the output size will be matched to this. If size is an int, 42 | the smaller edge of the image will be matched to this number maintaing 43 | the aspect ratio. i.e, if height > width, then image will be rescaled to 44 | (size * height / width, size) 45 | interpolation (int, optional): Desired interpolation. Default is 46 | ``PIL.Image.BILINEAR`` 47 | Returns: 48 | PIL Image: Resized image. 49 | """ 50 | if not _is_pil_image(img): 51 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 52 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 53 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 54 | 55 | if isinstance(size, int): 56 | w, h = img.size 57 | if (w <= h and w == size) or (h <= w and h == size): 58 | return img 59 | if w < h: 60 | ow = size 61 | oh = int(size * h / w) 62 | return img.resize((ow, oh), interpolation) 63 | else: 64 | oh = size 65 | ow = int(size * w / h) 66 | return img.resize((ow, oh), interpolation) 67 | else: 68 | return img.resize(size[::-1], interpolation) 69 | 70 | def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): 71 | """Crop the given PIL Image and resize it to desired size. 72 | Notably used in RandomResizedCrop. 73 | Args: 74 | img (PIL Image): Image to be cropped. 75 | i: Upper pixel coordinate. 76 | j: Left pixel coordinate. 77 | h: Height of the cropped image. 78 | w: Width of the cropped image. 79 | size (sequence or int): Desired output size. Same semantics as ``scale``. 80 | interpolation (int, optional): Desired interpolation. Default is 81 | ``PIL.Image.BILINEAR``. 82 | Returns: 83 | PIL Image: Cropped image. 84 | """ 85 | assert _is_pil_image(img), 'img should be PIL Image' 86 | img = crop(img, i, j, h, w) 87 | img = resize(img, size, interpolation) 88 | return img 89 | 90 | class RandomResizedCrop(object): 91 | """Crop the given PIL Image to random size and aspect ratio. 92 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 93 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 94 | is finally resized to given size. 95 | This is popularly used to train the Inception networks. 96 | Args: 97 | size: expected output size of each edge 98 | scale: range of size of the origin size cropped 99 | ratio: range of aspect ratio of the origin aspect ratio cropped 100 | interpolation: Default: PIL.Image.BILINEAR 101 | """ 102 | 103 | def __init__(self, size_height, size_width, interpolation=Image.BILINEAR): 104 | self.size = (size_height, size_width) 105 | self.interpolation = interpolation 106 | 107 | @staticmethod 108 | def get_params(img, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)): 109 | """Get parameters for ``crop`` for a random sized crop. 110 | Args: 111 | img (PIL Image): Image to be cropped. 112 | scale (tuple): range of size of the origin size cropped 113 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 114 | Returns: 115 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 116 | sized crop. 
117 | """ 118 | for attempt in range(10): 119 | area = img.size[0] * img.size[1] 120 | target_area = random.uniform(*scale) * area 121 | aspect_ratio = random.uniform(*ratio) 122 | 123 | w = int(round(math.sqrt(target_area * aspect_ratio))) 124 | h = int(round(math.sqrt(target_area / aspect_ratio))) 125 | 126 | if random.random() < 0.5: 127 | w, h = h, w 128 | 129 | if w <= img.size[0] and h <= img.size[1]: 130 | i = random.randint(0, img.size[1] - h) 131 | j = random.randint(0, img.size[0] - w) 132 | return i, j, h, w 133 | 134 | # Fallback 135 | w = min(img.size[0], img.size[1]) 136 | i = (img.size[1] - w) // 2 137 | j = (img.size[0] - w) // 2 138 | return i, j, w, w 139 | 140 | def __call__(self, img, params): 141 | """ 142 | Args: 143 | img (PIL Image): Image to be flipped. 144 | Returns: 145 | PIL Image: Randomly cropped and resize image. 146 | """ 147 | i, j, h, w = params 148 | return resized_crop(img, i, j, h, w, self.size, self.interpolation) 149 | 150 | 151 | ### HORIZONTAL FLIPPING ### 152 | 153 | def hflip(img): 154 | """Horizontally flip the given PIL Image. 155 | Args: 156 | img (PIL Image): Image to be flipped. 157 | Returns: 158 | PIL Image: Horizontally flipped image. 159 | """ 160 | 161 | is_numpy = isinstance(img, np.ndarray) 162 | 163 | if not _is_pil_image(img): 164 | if is_numpy: 165 | img = Image.fromarray(img) 166 | else: 167 | raise TypeError( 168 | 'img should be PIL Image or numpy array. \ 169 | Got {}'.format(type(img))) 170 | 171 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 172 | 173 | if is_numpy: 174 | img = np.array(img) 175 | 176 | return img 177 | 178 | class RandomHorizontalFlip(object): 179 | """Horizontally flip the given PIL Image.""" 180 | 181 | def __call__(self, img, flip): 182 | """ 183 | Args: 184 | img (PIL Image): Image to be flipped. 185 | Returns: 186 | PIL Image: Flipped image. 187 | """ 188 | if flip: 189 | return hflip(img) 190 | return img 191 | 192 | 193 | ### VERTICAL FLIPPING ### 194 | 195 | def vflip(img): 196 | """Vertically flip the given PIL Image. 197 | Args: 198 | img (PIL Image): Image to be flipped. 199 | Returns: 200 | PIL Image: Vertically flipped image. 201 | """ 202 | 203 | is_numpy = isinstance(img, np.ndarray) 204 | 205 | if not _is_pil_image(img): 206 | if is_numpy: 207 | img = Image.fromarray(img) 208 | else: 209 | raise TypeError( 210 | 'img should be PIL Image or numpy array. \ 211 | Got {}'.format(type(img))) 212 | 213 | img = img.transpose(Image.FLIP_TOP_BOTTOM) 214 | 215 | if is_numpy: 216 | img = np.array(img) 217 | 218 | return img 219 | 220 | class RandomVerticalFlip(object): 221 | """Vertically flip the given PIL Image.""" 222 | 223 | def __call__(self, img, flip): 224 | """ 225 | Args: 226 | img (PIL Image): Image to be flipped. 227 | Returns: 228 | PIL Image: Flipped image. 229 | """ 230 | if flip: 231 | return vflip(img) 232 | return img 233 | 234 | 235 | # TRANSPOSE # 236 | 237 | def transpose(img): 238 | """Transpose the given PIL Image. 239 | Args: 240 | img (PIL Image): Image to be transposed. 241 | Returns: 242 | PIL Image: Transposed image. 243 | """ 244 | 245 | is_numpy = isinstance(img, np.ndarray) 246 | 247 | if not _is_pil_image(img): 248 | if is_numpy: 249 | img = Image.fromarray(img) 250 | else: 251 | raise TypeError( 252 | 'img should be PIL Image or numpy array. 
\ 253 | Got {}'.format(type(img))) 254 | 255 | img = img.transpose(Image.TRANSPOSE) 256 | 257 | if is_numpy: 258 | img = np.array(img) 259 | 260 | return img 261 | 262 | 263 | class RandomTranspose(object): 264 | """Transpose the given PIL Image.""" 265 | 266 | def __call__(self, img, trans): 267 | """ 268 | Args: 269 | img (PIL Image): Image to be transposed. 270 | Returns: 271 | PIL Image: Transposed image. 272 | """ 273 | if trans: 274 | return transpose(img) 275 | return img 276 | 277 | 278 | ### RANDOM ROTATION ### 279 | 280 | def rotate(img, angle, resample=Image.BILINEAR, expand=True): 281 | 282 | is_numpy = isinstance(img, np.ndarray) 283 | 284 | if not _is_pil_image(img): 285 | if is_numpy: 286 | img = Image.fromarray(img) 287 | else: 288 | raise TypeError( 289 | 'img should be PIL Image or numpy array. \ 290 | Got {}'.format(type(img))) 291 | 292 | img = img.rotate(angle, resample=resample, expand=expand) 293 | 294 | if is_numpy: 295 | img = np.array(img) 296 | 297 | return img 298 | 299 | def rotate_with_random_bg(img, angle, resample=Image.BILINEAR, expand=True): 300 | 301 | is_numpy = isinstance(img, np.ndarray) 302 | 303 | if not _is_pil_image(img): 304 | if is_numpy: 305 | img = Image.fromarray(img) 306 | else: 307 | raise TypeError( 308 | 'img should be PIL Image or numpy array. \ 309 | Got {}'.format(type(img))) 310 | 311 | img_np = np.array(img) 312 | 313 | img = img.convert('RGBA') 314 | img = rotate(img, angle, resample=resample, expand=expand) 315 | 316 | key = np.random.choice([0, 1, 2, 3]) 317 | if key == 0: 318 | bg = Image.new('RGBA', img.size, (255, ) * 4) # White image 319 | elif key == 1: 320 | bg = Image.new('RGBA', img.size, (0, 0, 0, 255)) # Black image 321 | elif key == 2: 322 | mean_color = map(int, img_np.mean((0, 1))) 323 | bg = Image.new('RGBA', img.size, (mean_color[0], mean_color[1], mean_color[2], 255)) # Mean 324 | elif key == 3: 325 | median_color = map(int, np.median(img_np, (0, 1))) 326 | bg = Image.new('RGBA', img.size, (median_color[0], median_color[1], median_color[2], 255)) # Median 327 | 328 | img = Image.composite(img, bg, img) 329 | img = img.convert('RGB') 330 | 331 | if is_numpy: 332 | img = np.array(img) 333 | 334 | return img 335 | 336 | class RandomRotate(object): 337 | 338 | def __init__(self, interpolation=Image.BILINEAR, random_bg=True): 339 | self.interpolation = interpolation 340 | self.random_bg = random_bg 341 | 342 | def __call__(self, img, angle, expand): 343 | if self.random_bg: 344 | return rotate_with_random_bg(img, angle, resample=self.interpolation, expand=expand) 345 | else: 346 | return rotate(img, angle, resample=self.interpolation, expand=expand) 347 | 348 | ### RANDOM CHANNEL SWAPING ### 349 | 350 | def swap_channels(img): 351 | 352 | if not _is_pil_image(img): 353 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 354 | 355 | img_np = np.array(img) 356 | 357 | channel_idxes = np.random.choice([0, 1, 2], 3, True) 358 | 359 | return Image.fromarray(img_np[:, :, channel_idxes]) 360 | 361 | class RandomChannelSwap(object): 362 | 363 | def __init__(self, prob): 364 | self.prob = prob 365 | 366 | def __call__(self, img): 367 | if np.random.rand() >= self.prob: 368 | return img 369 | 370 | return swap_channels(img) 371 | 372 | ### GAMMA CORRECTION ### 373 | 374 | def adjust_gamma(img, gamma, gain=1): 375 | """Perform gamma correction on an image. 376 | Also known as Power Law Transform. 
Intensities in RGB mode are adjusted 377 | based on the following equation: 378 | I_out = 255 * gain * ((I_in / 255) ** gamma) 379 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 380 | Args: 381 | img (PIL Image): PIL Image to be adjusted. 382 | gamma (float): Non negative real number. gamma larger than 1 make the 383 | shadows darker, while gamma smaller than 1 make dark regions 384 | lighter. 385 | gain (float): The constant multiplier. 386 | """ 387 | if not _is_pil_image(img): 388 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 389 | 390 | if gamma < 0: 391 | raise ValueError('Gamma should be a non-negative real number') 392 | 393 | gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 394 | img = img.point(gamma_map) # use PIL's point-function to accelerate this part 395 | 396 | return img 397 | 398 | class RandomGamma(object): 399 | 400 | def __init__(self, gamma_range, gain=1): 401 | self.min_gamma = gamma_range[0] 402 | self.max_gamma = gamma_range[1] 403 | 404 | self.gain = gain 405 | 406 | def __call__(self, img): 407 | gamma = np.random.rand() * (self.max_gamma - self.min_gamma) + self.min_gamma 408 | return adjust_gamma(img, gamma=gamma, gain=self.gain) 409 | 410 | ### RESOLUTION ### 411 | 412 | def random_resolution(img, ratio): 413 | 414 | if not _is_pil_image(img): 415 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 416 | 417 | img_size = np.array(img.size) 418 | new_size = (img_size * ratio).astype('int') 419 | 420 | img = img.resize(new_size, Image.ANTIALIAS) 421 | img = img.resize(img_size, Image.ANTIALIAS) 422 | 423 | return img 424 | 425 | class RandomResolution(object): 426 | 427 | def __init__(self, ratio_range): 428 | self.ratio_range = np.arange(ratio_range[0], ratio_range[1], 0.05) 429 | 430 | def __call__(self, img): 431 | _range = np.random.choice(self.ratio_range) 432 | return random_resolution(img, _range) 433 | -------------------------------------------------------------------------------- /code/lib/utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torchvision.transforms as transforms 3 | from StringIO import StringIO 4 | 5 | from preprocess import RandomResizedCrop, RandomHorizontalFlip, \ 6 | RandomVerticalFlip, RandomTranspose, RandomRotate, \ 7 | RandomChannelSwap, RandomGamma, RandomResolution 8 | 9 | 10 | class ImageUtilities(object): 11 | 12 | @staticmethod 13 | def read_image(image_path, is_raw=False): 14 | if is_raw: 15 | img = Image.open(StringIO(image_path)) 16 | else: 17 | img = Image.open(image_path).convert('RGB') 18 | img_copy = img.copy() 19 | img.close() 20 | return img_copy 21 | 22 | @staticmethod 23 | def image_resizer(height, width, interpolation=Image.BILINEAR): 24 | return transforms.Resize((height, width), interpolation=interpolation) 25 | 26 | @staticmethod 27 | def image_random_cropper_and_resizer(height, width, interpolation=Image.BILINEAR): 28 | return RandomResizedCrop(height, width, interpolation=interpolation) 29 | 30 | @staticmethod 31 | def image_random_horizontal_flipper(): 32 | return RandomHorizontalFlip() 33 | 34 | @staticmethod 35 | def image_random_vertical_flipper(): 36 | return RandomVerticalFlip() 37 | 38 | @staticmethod 39 | def image_random_transposer(): 40 | return RandomTranspose() 41 | 42 | @staticmethod 43 | def image_normalizer(mean, std): 44 | return transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 45 | 46 | 
@staticmethod 47 | def image_random_rotator(interpolation=Image.BILINEAR, random_bg=True): 48 | return RandomRotate(interpolation=interpolation, random_bg=random_bg) 49 | 50 | @staticmethod 51 | def image_random_90x_rotator(interpolation=Image.BILINEAR): 52 | return RandomRotate(interpolation=interpolation, random_bg=False) 53 | 54 | @staticmethod 55 | def image_random_color_jitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2): 56 | return transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) 57 | 58 | @staticmethod 59 | def image_random_grayscaler(p=0.5): 60 | return transforms.RandomGrayscale(p=p) 61 | 62 | @staticmethod 63 | def image_random_channel_swapper(p=0.5): 64 | return RandomChannelSwap(prob=p) 65 | 66 | @staticmethod 67 | def image_random_gamma(gamma_range, gain=1): 68 | return RandomGamma(gamma_range, gain=gain) 69 | 70 | @staticmethod 71 | def image_random_resolution(ratio_range): 72 | return RandomResolution(ratio_range) 73 | -------------------------------------------------------------------------------- /code/pred.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | 4 | import matplotlib.pylab as plt 5 | import os 6 | import sys 7 | import argparse 8 | import numpy as np 9 | from PIL import Image 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--image', required=True, help='Path of the image') 14 | parser.add_argument('--model', required=True, help='Path of the model') 15 | parser.add_argument('--usegpu', action='store_true', 16 | help='Enables cuda to predict on gpu') 17 | parser.add_argument('--output', required=True, 18 | help='Path of the output directory') 19 | parser.add_argument('--dataset', type=str, 20 | help='Name of the dataset which is "CVPPP"', 21 | required=True) 22 | opt = parser.parse_args() 23 | 24 | assert opt.dataset in ['CVPPP', ] 25 | 26 | image_path = opt.image 27 | model_path = opt.model 28 | output_path = opt.output 29 | 30 | try: 31 | os.makedirs(output_path) 32 | except BaseException: 33 | pass 34 | 35 | model_dir = os.path.dirname(model_path) 36 | sys.path.insert(0, model_dir) 37 | 38 | from lib import Model, Prediction 39 | 40 | if opt.dataset == 'CVPPP': 41 | from settings import CVPPPModelSettings 42 | ms = CVPPPModelSettings() 43 | 44 | model = Model(opt.dataset, ms.MODEL_NAME, ms.N_CLASSES, ms.MAX_N_OBJECTS, 45 | use_instance_segmentation=ms.USE_INSTANCE_SEGMENTATION, 46 | use_coords=ms.USE_COORDINATES, load_model_path=opt.model, 47 | usegpu=opt.usegpu) 48 | 49 | prediction = Prediction(ms.IMAGE_HEIGHT, ms.IMAGE_WIDTH, 50 | ms.MEAN, ms.STD, False, model, 51 | 1) 52 | image, fg_seg_pred, ins_seg_pred, n_objects_pred = prediction.predict( 53 | image_path) 54 | 55 | fg_seg_pred = fg_seg_pred * 255 56 | 57 | _n_clusters = len(np.unique(ins_seg_pred.flatten())) - 1 # discard bg 58 | colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, _n_clusters)] 59 | ins_seg_pred_color = np.zeros( 60 | (ins_seg_pred.shape[0], ins_seg_pred.shape[1], 3), dtype=np.uint8) 61 | for i in range(_n_clusters): 62 | ins_seg_pred_color[ins_seg_pred == ( 63 | i + 1)] = (np.array(colors[i][:3]) * 255).astype('int') 64 | 65 | image_name = os.path.splitext(os.path.basename(image_path))[0] 66 | 67 | image_pil = Image.fromarray(image) 68 | fg_seg_pred_pil = Image.fromarray(fg_seg_pred) 69 | ins_seg_pred_pil = Image.fromarray(ins_seg_pred) 70 | ins_seg_pred_color_pil = Image.fromarray(ins_seg_pred_color) 71 
| 72 | image_pil.save(os.path.join(output_path, image_name + '.png')) 73 | fg_seg_pred_pil.save(os.path.join(output_path, image_name + '-fg_mask.png')) 74 | ins_seg_pred_pil.save(os.path.join(output_path, image_name + '-ins_mask.png')) 75 | ins_seg_pred_color_pil.save(os.path.join( 76 | output_path, image_name + '-ins_mask_color.png')) 77 | np.save( 78 | os.path.join( 79 | output_path, 80 | image_name + 81 | '-n_objects.npy'), 82 | n_objects_pred) 83 | -------------------------------------------------------------------------------- /code/pred_list.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | 4 | import matplotlib.pylab as plt 5 | import os 6 | import sys 7 | import argparse 8 | import numpy as np 9 | from PIL import Image 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--lst', required=True, help='Text file that contains image paths') 14 | parser.add_argument('--model', required=True, help='Path of the model') 15 | parser.add_argument('--usegpu', action='store_true', 16 | help='Enables cuda to predict on gpu') 17 | parser.add_argument('--dataset', type=str, 18 | help='Name of the dataset which is "CVPPP"', 19 | required=True) 20 | opt = parser.parse_args() 21 | 22 | assert opt.dataset in ['CVPPP', ] 23 | 24 | images_list = np.loadtxt(opt.lst, dtype='str', delimiter=',') 25 | model_path = opt.model 26 | 27 | _subset = os.path.basename(opt.lst).split('_')[0] 28 | _model_name = os.path.splitext(os.path.basename(model_path))[0] 29 | _model_dir = os.path.basename(os.path.dirname(model_path)) 30 | 31 | output_path = os.path.abspath(os.path.join(os.path.abspath(__file__), os.path.pardir, 32 | os.path.pardir, 'outputs', opt.dataset, 33 | _model_dir + '-' + _model_name, _subset)) 34 | 35 | image_names = [os.path.splitext(os.path.basename(img))[0] \ 36 | for img in images_list] 37 | 38 | try: 39 | os.makedirs(output_path) 40 | except BaseException: 41 | pass 42 | 43 | model_dir = os.path.dirname(model_path) 44 | sys.path.insert(0, model_dir) 45 | 46 | from lib import Model, Prediction 47 | 48 | if opt.dataset == 'CVPPP': 49 | from settings import CVPPPModelSettings 50 | ms = CVPPPModelSettings() 51 | 52 | model = Model(opt.dataset, ms.MODEL_NAME, ms.N_CLASSES, ms.MAX_N_OBJECTS, 53 | use_instance_segmentation=ms.USE_INSTANCE_SEGMENTATION, 54 | use_coords=ms.USE_COORDINATES, load_model_path=opt.model, 55 | usegpu=opt.usegpu) 56 | 57 | prediction = Prediction(ms.IMAGE_HEIGHT, ms.IMAGE_WIDTH, 58 | ms.MEAN, ms.STD, False, model, 59 | 1) 60 | 61 | for image_name, image_path in zip(image_names, images_list): 62 | image, fg_seg_pred, ins_seg_pred, n_objects_pred = \ 63 | prediction.predict(image_path) 64 | 65 | _output_path = os.path.join(output_path, image_name) 66 | 67 | try: 68 | os.makedirs(_output_path) 69 | except BaseException: 70 | pass 71 | 72 | fg_seg_pred = fg_seg_pred * 255 73 | 74 | _n_clusters = len(np.unique(ins_seg_pred.flatten())) - 1 # discard bg 75 | colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, _n_clusters)] 76 | ins_seg_pred_color = np.zeros( 77 | (ins_seg_pred.shape[0], ins_seg_pred.shape[1], 3), dtype=np.uint8) 78 | for i in range(_n_clusters): 79 | ins_seg_pred_color[ins_seg_pred == ( 80 | i + 1)] = (np.array(colors[i][:3]) * 255).astype('int') 81 | 82 | image_pil = Image.fromarray(image) 83 | fg_seg_pred_pil = Image.fromarray(fg_seg_pred) 84 | ins_seg_pred_pil = Image.fromarray(ins_seg_pred) 85 | ins_seg_pred_color_pil = 
Image.fromarray(ins_seg_pred_color) 86 | 87 | image_pil.save(os.path.join(_output_path, image_name + '.png')) 88 | fg_seg_pred_pil.save(os.path.join(_output_path, image_name + '-fg_mask.png')) 89 | ins_seg_pred_pil.save(os.path.join(_output_path, image_name + '-ins_mask.png')) 90 | ins_seg_pred_color_pil.save(os.path.join( 91 | _output_path, image_name + '-ins_mask_color.png')) 92 | np.save( 93 | os.path.join( 94 | _output_path, 95 | image_name + 96 | '-n_objects.npy'), 97 | n_objects_pred) 98 | -------------------------------------------------------------------------------- /code/settings/CVPPP/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/code/settings/CVPPP/README.md -------------------------------------------------------------------------------- /code/settings/CVPPP/__init__.py: -------------------------------------------------------------------------------- 1 | from data_settings import DataSettings 2 | from model_settings import ModelSettings 3 | from training_settings import TrainingSettings 4 | -------------------------------------------------------------------------------- /code/settings/CVPPP/data_settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | 5 | class DataSettings(object): 6 | 7 | def __init__(self): 8 | 9 | self.BASE_PATH = os.path.abspath( 10 | os.path.join( 11 | __file__, 12 | os.path.pardir, 13 | os.path.pardir, 14 | os.path.pardir, 15 | os.path.pardir)) 16 | # self.CLASS_WEIGHTS = np.loadtxt(os.path.join(self.BASE_PATH, 'data', 17 | # 'metadata', 18 | # 'class_weights.txt'), 19 | # dtype='float', delimiter=',')[:, 1] 20 | self.CLASS_WEIGHTS = None 21 | # Assign it to None in order to disable class weighting 22 | 23 | self.MAX_N_OBJECTS = 20 24 | 25 | self.N_CLASSES = 1 + 1 26 | -------------------------------------------------------------------------------- /code/settings/CVPPP/model_settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from data_settings import DataSettings 4 | 5 | 6 | class ModelSettings(DataSettings): 7 | 8 | def __init__(self): 9 | super(ModelSettings, self).__init__() 10 | 11 | # self.MEAN = [0.485, 0.456, 0.406] 12 | # self.STD = [0.229, 0.224, 0.225] 13 | self.MEAN = [0.521697844321, 0.389775426267, 0.206216114391] 14 | self.STD = [0.212398291819, 0.151755427041, 0.113022107204] 15 | 16 | self.MODEL_NAME = 'ReSeg' # 'ReSeg' or 'StackedRecurrentHourglass' 17 | 18 | self.USE_INSTANCE_SEGMENTATION = True 19 | self.USE_COORDINATES = False 20 | 21 | self.IMAGE_HEIGHT = 256 22 | self.IMAGE_WIDTH = 256 23 | 24 | self.DELTA_VAR = 0.5 25 | self.DELTA_DIST = 1.5 26 | self.NORM = 2 27 | -------------------------------------------------------------------------------- /code/settings/CVPPP/training_settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from model_settings import ModelSettings 3 | 4 | 5 | class TrainingSettings(ModelSettings): 6 | 7 | def __init__(self): 8 | super(TrainingSettings, self).__init__() 9 | 10 | self.TRAINING_LMDB = os.path.join( 11 | self.BASE_PATH, 12 | 'data', 13 | 'processed', 14 | 'CVPPP', 15 | 'lmdb', 16 | 'training-lmdb') 17 | self.VALIDATION_LMDB = os.path.join( 18 | self.BASE_PATH, 19 | 'data', 20 | 'processed', 21 | 'CVPPP', 22 | 'lmdb', 23 | 
'validation-lmdb') 24 | 25 | self.TRAIN_CNN = True 26 | 27 | self.OPTIMIZER = 'Adadelta' 28 | # optimizer - one of : 'RMSprop', 'Adam', 'Adadelta', 'SGD' 29 | self.LEARNING_RATE = 1.0 30 | self.LR_DROP_FACTOR = 0.1 31 | self.LR_DROP_PATIENCE = 20 32 | self.WEIGHT_DECAY = 0.001 33 | # weight decay - use 0 to disable it 34 | self.CLIP_GRAD_NORM = 10.0 35 | # max l2 norm of gradient of parameters - use 0 to disable it 36 | 37 | self.HORIZONTAL_FLIPPING = False 38 | self.VERTICAL_FLIPPING = False 39 | self.TRANSPOSING = False 40 | self.ROTATION_90X = False 41 | self.ROTATION = False 42 | self.COLOR_JITTERING = False 43 | self.GRAYSCALING = False 44 | self.CHANNEL_SWAPPING = False 45 | self.GAMMA_ADJUSTMENT = False 46 | self.RESOLUTION_DEGRADING = False 47 | 48 | self.CRITERION = 'Multi' 49 | # criterion - One of 'CE', 'Dice', 'Multi' 50 | self.OPTIMIZE_BG = False 51 | 52 | # self.RANDOM_CROPPING = False 53 | # CROP_SCALE and CROP_AR is used iff self.RANDOM_CROPPING is True 54 | # self.CROP_SCALE = (1.0, 1.0) 55 | # Choose it carefully - have a look at 56 | # lib/preprocess.py -> RandomResizedCrop 57 | # self.CROP_AR = (1.0, 1.0) 58 | # Choose it carefully - have a look 59 | # at lib/preprocess.py -> RandomResizedCrop 60 | 61 | self.SEED = 23 62 | -------------------------------------------------------------------------------- /code/settings/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/code/settings/README.md -------------------------------------------------------------------------------- /code/settings/__init__.py: -------------------------------------------------------------------------------- 1 | from CVPPP import DataSettings as CVPPPDataSettings 2 | from CVPPP import ModelSettings as CVPPPModelSettings 3 | from CVPPP import TrainingSettings as CVPPPTrainingSettings 4 | -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import os 4 | import getpass 5 | import datetime 6 | import shutil 7 | import numpy as np 8 | import torch 9 | from lib import SegDataset, Model, AlignCollate 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--model', default='', 13 | help="Filepath of trained model (to continue training) \ 14 | [Default: '']") 15 | parser.add_argument('--usegpu', action='store_true', 16 | help='Enables cuda to train on gpu [Default: False]') 17 | parser.add_argument('--nepochs', type=int, default=600, 18 | help='Number of epochs to train for [Default: 600]') 19 | parser.add_argument('--batchsize', type=int, 20 | default=2, help='Batch size [Default: 2]') 21 | parser.add_argument('--debug', action='store_true', 22 | help='Activates debug mode [Default: False]') 23 | parser.add_argument('--nworkers', type=int, 24 | help='Number of workers for data loading \ 25 | (0 to do it using main process) [Default : 2]', 26 | default=2) 27 | parser.add_argument('--dataset', type=str, 28 | help='Name of the dataset which is "CVPPP"', 29 | required=True) 30 | opt = parser.parse_args() 31 | 32 | assert opt.dataset in ['CVPPP', ] 33 | 34 | if opt.dataset == 'CVPPP': 35 | from settings import CVPPPTrainingSettings 36 | ts = CVPPPTrainingSettings() 37 | 38 | 39 | def generate_run_id(): 40 | 41 | username = getpass.getuser() 42 | 43 | now = datetime.datetime.now() 44 
| date = map(str, [now.year, now.month, now.day]) 45 | coarse_time = map(str, [now.hour, now.minute]) 46 | fine_time = map(str, [now.second, now.microsecond]) 47 | 48 | run_id = '_'.join(['-'.join(date), '-'.join(coarse_time), 49 | username, '-'.join(fine_time)]) 50 | return run_id 51 | 52 | 53 | RUN_ID = generate_run_id() 54 | model_save_path = os.path.abspath(os.path.join(os.path.abspath(__file__), 55 | os.path.pardir, os.path.pardir, 56 | 'models', opt.dataset, RUN_ID)) 57 | os.makedirs(model_save_path) 58 | 59 | CODE_BASE_DIR = os.path.abspath(os.path.join( 60 | os.path.abspath(__file__), os.path.pardir)) 61 | shutil.copytree(os.path.join(CODE_BASE_DIR, 'settings'), 62 | os.path.join(model_save_path, 'settings')) 63 | shutil.copytree(os.path.join(CODE_BASE_DIR, 'lib'), 64 | os.path.join(model_save_path, 'lib')) 65 | 66 | if torch.cuda.is_available() and not opt.usegpu: 67 | print 'WARNING: You have a CUDA device, so you should probably \ 68 | run with --usegpu' 69 | 70 | # Load Seeds 71 | random.seed(ts.SEED) 72 | np.random.seed(ts.SEED) 73 | torch.manual_seed(ts.SEED) 74 | 75 | # Define Data Loaders 76 | pin_memory = False 77 | if opt.usegpu: 78 | pin_memory = True 79 | 80 | train_dataset = SegDataset(ts.TRAINING_LMDB) 81 | assert train_dataset 82 | 83 | train_align_collate = AlignCollate( 84 | 'training', 85 | ts.N_CLASSES, 86 | ts.MAX_N_OBJECTS, 87 | ts.MEAN, 88 | ts.STD, 89 | ts.IMAGE_HEIGHT, 90 | ts.IMAGE_WIDTH, 91 | random_hor_flipping=ts.HORIZONTAL_FLIPPING, 92 | random_ver_flipping=ts.VERTICAL_FLIPPING, 93 | random_transposing=ts.TRANSPOSING, 94 | random_90x_rotation=ts.ROTATION_90X, 95 | random_rotation=ts.ROTATION, 96 | random_color_jittering=ts.COLOR_JITTERING, 97 | random_grayscaling=ts.GRAYSCALING, 98 | random_channel_swapping=ts.CHANNEL_SWAPPING, 99 | random_gamma=ts.GAMMA_ADJUSTMENT, 100 | random_resolution=ts.RESOLUTION_DEGRADING) 101 | 102 | train_loader = torch.utils.data.DataLoader(train_dataset, 103 | batch_size=opt.batchsize, 104 | shuffle=True, 105 | num_workers=opt.nworkers, 106 | pin_memory=pin_memory, 107 | collate_fn=train_align_collate) 108 | 109 | test_dataset = SegDataset(ts.VALIDATION_LMDB) 110 | assert test_dataset 111 | 112 | test_align_collate = AlignCollate( 113 | 'test', 114 | ts.N_CLASSES, 115 | ts.MAX_N_OBJECTS, 116 | ts.MEAN, 117 | ts.STD, 118 | ts.IMAGE_HEIGHT, 119 | ts.IMAGE_WIDTH, 120 | random_hor_flipping=ts.HORIZONTAL_FLIPPING, 121 | random_ver_flipping=ts.VERTICAL_FLIPPING, 122 | random_transposing=ts.TRANSPOSING, 123 | random_90x_rotation=ts.ROTATION_90X, 124 | random_rotation=ts.ROTATION, 125 | random_color_jittering=ts.COLOR_JITTERING, 126 | random_grayscaling=ts.GRAYSCALING, 127 | random_channel_swapping=ts.CHANNEL_SWAPPING, 128 | random_gamma=ts.GAMMA_ADJUSTMENT, 129 | random_resolution=ts.RESOLUTION_DEGRADING) 130 | 131 | test_loader = torch.utils.data.DataLoader(test_dataset, 132 | batch_size=opt.batchsize, 133 | shuffle=False, 134 | num_workers=opt.nworkers, 135 | pin_memory=pin_memory, 136 | collate_fn=test_align_collate) 137 | 138 | # Define Model 139 | model = Model(opt.dataset, ts.MODEL_NAME, ts.N_CLASSES, ts.MAX_N_OBJECTS, 140 | use_instance_segmentation=ts.USE_INSTANCE_SEGMENTATION, 141 | use_coords=ts.USE_COORDINATES, load_model_path=opt.model, 142 | usegpu=opt.usegpu) 143 | 144 | # Train Model 145 | model.fit(ts.CRITERION, ts.DELTA_VAR, ts.DELTA_DIST, ts.NORM, ts.LEARNING_RATE, 146 | ts.WEIGHT_DECAY, ts.CLIP_GRAD_NORM, ts.LR_DROP_FACTOR, 147 | ts.LR_DROP_PATIENCE, ts.OPTIMIZE_BG, ts.OPTIMIZER, ts.TRAIN_CNN, 148 | opt.nepochs, 
ts.CLASS_WEIGHTS, train_loader, test_loader, 149 | model_save_path, opt.debug) 150 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/data/README.md -------------------------------------------------------------------------------- /data/metadata/CVPPP/means_and_stds.txt: -------------------------------------------------------------------------------- 1 | RGB MEANS : 0.521697844321 0.389775426267 0.206216114391 2 | RGB STDS : 0.212398291819 0.151755427041 0.113022107204 3 | -------------------------------------------------------------------------------- /data/metadata/CVPPP/training.lst: -------------------------------------------------------------------------------- 1 | plant026 2 | plant016 3 | plant129 4 | plant090 5 | plant030 6 | plant012 7 | plant083 8 | plant123 9 | plant119 10 | plant107 11 | plant153 12 | plant138 13 | plant118 14 | plant134 15 | plant115 16 | plant105 17 | plant094 18 | plant029 19 | plant043 20 | plant145 21 | plant006 22 | plant057 23 | plant137 24 | plant040 25 | plant159 26 | plant001 27 | plant017 28 | plant099 29 | plant078 30 | plant149 31 | plant113 32 | plant141 33 | plant021 34 | plant086 35 | plant151 36 | plant015 37 | plant079 38 | plant130 39 | plant127 40 | plant104 41 | plant010 42 | plant067 43 | plant088 44 | plant059 45 | plant070 46 | plant069 47 | plant008 48 | plant128 49 | plant062 50 | plant144 51 | plant148 52 | plant146 53 | plant055 54 | plant120 55 | plant024 56 | plant048 57 | plant053 58 | plant005 59 | plant038 60 | plant156 61 | plant037 62 | plant108 63 | plant064 64 | plant045 65 | plant050 66 | plant143 67 | plant054 68 | plant085 69 | plant036 70 | plant124 71 | plant132 72 | plant027 73 | plant121 74 | plant051 75 | plant133 76 | plant047 77 | plant061 78 | plant091 79 | plant139 80 | plant076 81 | plant068 82 | plant116 83 | plant063 84 | plant046 85 | plant060 86 | plant084 87 | plant109 88 | plant154 89 | plant002 90 | plant114 91 | plant058 92 | plant161 93 | plant071 94 | plant110 95 | plant032 96 | plant020 97 | plant044 98 | plant042 99 | plant082 100 | plant142 101 | -------------------------------------------------------------------------------- /data/metadata/CVPPP/validation.lst: -------------------------------------------------------------------------------- 1 | plant031 2 | plant152 3 | plant072 4 | plant033 5 | plant035 6 | plant018 7 | plant013 8 | plant022 9 | plant073 10 | plant101 11 | plant052 12 | plant049 13 | plant011 14 | plant080 15 | plant039 16 | plant077 17 | plant092 18 | plant126 19 | plant106 20 | plant096 21 | plant098 22 | plant147 23 | plant135 24 | plant100 25 | plant089 26 | plant065 27 | plant102 28 | plant007 29 | -------------------------------------------------------------------------------- /data/metadata/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/data/metadata/README.md -------------------------------------------------------------------------------- /data/processed/CVPPP/README.md: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/data/processed/CVPPP/README.md -------------------------------------------------------------------------------- /data/processed/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/data/processed/README.md -------------------------------------------------------------------------------- /data/raw/CVPPP/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/data/raw/CVPPP/README.md -------------------------------------------------------------------------------- /data/raw/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/data/raw/README.md -------------------------------------------------------------------------------- /data/scripts/CVPPP/1-create_annotations.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import cv2 4 | from PIL import Image 5 | import numpy as np 6 | 7 | DATA_DIR = os.path.abspath(os.path.join(__file__, os.path.pardir, 8 | os.path.pardir, os.path.pardir)) 9 | ANN_DIR = os.path.join(DATA_DIR, 'raw', 'CVPPP', 'CVPPP2017_LSC_training', 10 | 'training', 'A1') 11 | IMG_DIR = os.path.join(DATA_DIR, 'raw', 'CVPPP', 'CVPPP2017_LSC_training', 12 | 'training', 'A1') 13 | SEMANTIC_OUTPUT_DIR = os.path.join(DATA_DIR, 'processed', 'CVPPP', 14 | 'semantic-annotations') 15 | INSTANCE_OUTPUT_DIR = os.path.join(DATA_DIR, 'processed', 'CVPPP', 16 | 'instance-annotations') 17 | 18 | try: 19 | os.makedirs(SEMANTIC_OUTPUT_DIR) 20 | except BaseException: 21 | pass 22 | 23 | try: 24 | os.makedirs(INSTANCE_OUTPUT_DIR) 25 | except BaseException: 26 | pass 27 | 28 | image_paths = glob.glob(os.path.join(IMG_DIR, '*_rgb.png')) 29 | 30 | for image_path in image_paths: 31 | img = Image.open(image_path) 32 | img_width, img_height = img.size 33 | 34 | image_name = os.path.splitext(os.path.basename(image_path))[ 35 | 0].split('_')[0] 36 | annotation_path = os.path.join(ANN_DIR, image_name + '_label.png') 37 | 38 | if not os.path.isfile(annotation_path): 39 | continue 40 | 41 | annotation = np.array(Image.open(annotation_path)) 42 | 43 | assert len(annotation.shape) == 2 44 | assert np.array(img).shape[:2] == annotation.shape[:2] 45 | 46 | instance_values = set(np.unique(annotation)).difference([0]) 47 | n_instances = len(instance_values) 48 | 49 | if n_instances == 0: 50 | continue 51 | 52 | instance_mask = np.zeros( 53 | (img_height, img_width, n_instances), dtype=np.uint8) 54 | 55 | for i, v in enumerate(instance_values): 56 | _mask = np.zeros((img_height, img_width), dtype=np.uint8) 57 | _mask[annotation == v] = 1 58 | instance_mask[:, :, i] = _mask 59 | 60 | semantic_mask = instance_mask.sum(2) 61 | semantic_mask[semantic_mask != 0] = 1 62 | semantic_mask = semantic_mask.astype(np.uint8) 63 | 64 | np.save(os.path.join(INSTANCE_OUTPUT_DIR, image_name + '.npy'), 65 | instance_mask) 66 | np.save(os.path.join(SEMANTIC_OUTPUT_DIR, image_name + '.npy'), 67 | semantic_mask) 68 | -------------------------------------------------------------------------------- 
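A minimal usage sketch (not part of the repository) for the masks written by 1-create_annotations.py above: it loads one saved annotation pair and checks the layout that the LMDB creation and data-loading code later rely on. The image name plant007 and the relative data/processed/CVPPP paths are illustrative assumptions based on the script's default output locations.

import os
import numpy as np

# Assumed default output locations of 1-create_annotations.py; adjust if the
# script was run from a different working directory.
base = os.path.join('data', 'processed', 'CVPPP')

# 'plant007' is an illustrative file name; any image processed by the script works.
ins = np.load(os.path.join(base, 'instance-annotations', 'plant007.npy'))
sem = np.load(os.path.join(base, 'semantic-annotations', 'plant007.npy'))

# Instance mask: (height, width, n_instances), one binary channel per leaf.
assert ins.ndim == 3 and set(np.unique(ins)).issubset({0, 1})
# Semantic mask: (height, width), 1 on any leaf pixel, i.e. the union of all channels.
assert np.array_equal(sem, (ins.sum(2) > 0).astype(np.uint8))

print('number of instances: {}'.format(ins.shape[2]))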
/data/scripts/CVPPP/1-remove_alpha.sh: -------------------------------------------------------------------------------- 1 | for f in ../../raw/CVPPP/CVPPP2017_LSC_training/training/A1/*_rgb.png; 2 | do 3 | convert $f -alpha off $f 4 | done 5 | -------------------------------------------------------------------------------- /data/scripts/CVPPP/2-get_image_means-stds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from PIL import Image 5 | 6 | DATA_DIR = os.path.abspath(os.path.join(__file__, os.path.pardir, 7 | os.path.pardir, os.path.pardir)) 8 | IMG_DIR = os.path.join(DATA_DIR, 'raw', 'CVPPP', 'CVPPP2017_LSC_training', 9 | 'training', 'A1') 10 | 11 | image_list = np.loadtxt(os.path.join(DATA_DIR, 'metadata', 'CVPPP', 12 | 'training.lst'), 13 | dtype='str', delimiter=',') 14 | 15 | reds, greens, blues = [], [], [] 16 | for image_name in image_list: 17 | img = np.array(Image.open(os.path.join(IMG_DIR, image_name + '_rgb.png'))) 18 | r, g, b = np.split(img, 3, axis=2) 19 | 20 | r = r.flatten() 21 | g = g.flatten() 22 | b = b.flatten() 23 | 24 | reds.extend(r) 25 | greens.extend(g) 26 | blues.extend(b) 27 | 28 | reds = np.array(reds) 29 | greens = np.array(greens) 30 | blues = np.array(blues) 31 | 32 | red_mean = np.mean(reds) / 255. 33 | green_mean = np.mean(greens) / 255. 34 | blue_mean = np.mean(blues) / 255. 35 | 36 | red_std = np.std(reds) / 255. 37 | green_std = np.std(greens) / 255. 38 | blue_std = np.std(blues) / 255. 39 | 40 | with open(os.path.join(DATA_DIR, 'metadata', 'CVPPP', 'means_and_stds.txt'), 'w') as fp: 41 | fp.write('RGB MEANS : {} {} {}\n'.format(red_mean, green_mean, blue_mean)) 42 | fp.write('RGB STDS : {} {} {}\n'.format(red_std, green_std, blue_std)) 43 | -------------------------------------------------------------------------------- /data/scripts/CVPPP/2-get_image_paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | DATA_DIR = os.path.abspath(os.path.join(__file__, os.path.pardir, 5 | os.path.pardir, os.path.pardir)) 6 | IMG_DIR = os.path.join(DATA_DIR, 'raw', 'CVPPP', 'CVPPP2017_LSC_training', 7 | 'training', 'A1') 8 | METADATA_OUTPUT_DIR = os.path.join(DATA_DIR, 'metadata', 'CVPPP') 9 | 10 | SUBSETS = ['train', 'val'] 11 | SUBSET_NAMES= ['training', 'validation'] 12 | 13 | for si, subset in enumerate(SUBSETS): 14 | lst = os.path.join(METADATA_OUTPUT_DIR, SUBSET_NAMES[si] + '.lst') 15 | image_names = np.loadtxt(lst, dtype='str', delimiter=',') 16 | 17 | image_paths = [] 18 | for image_name in image_names: 19 | _dir = image_name.split('_')[0] 20 | image_path = os.path.join(IMG_DIR, image_name + '_rgb.png') 21 | image_paths.append(image_path) 22 | 23 | np.savetxt(os.path.join(METADATA_OUTPUT_DIR, SUBSET_NAMES[si] + '_image_paths.txt'), 24 | image_paths, fmt='%s', delimiter=',') 25 | -------------------------------------------------------------------------------- /data/scripts/CVPPP/2-get_image_shapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from PIL import Image 5 | 6 | DATA_DIR = os.path.abspath(os.path.join(__file__, os.path.pardir, 7 | os.path.pardir, os.path.pardir)) 8 | ANN_DIR = os.path.join(DATA_DIR, 'processed', 'CVPPP', 'instance-annotations') 9 | IMG_DIR = os.path.join(DATA_DIR, 'raw', 'CVPPP', 'CVPPP2017_LSC_training', 10 | 'training', 'A1') 11 | OUTPUT_DIR = os.path.join(DATA_DIR, 
'metadata', 'CVPPP') 12 | 13 | annotation_files = glob.glob(os.path.join(ANN_DIR, '*.npy')) 14 | 15 | image_shapes = [] 16 | for f in annotation_files: 17 | image_name = os.path.splitext(os.path.basename(f))[0] 18 | ann_size = np.load(f).shape[:2] 19 | image_path = os.path.join(IMG_DIR, image_name + '_rgb.png') 20 | img_size = Image.open(image_path).size 21 | img_size = (img_size[1], img_size[0]) 22 | 23 | assert ann_size == img_size 24 | 25 | image_shapes.append([image_name, ann_size[0], ann_size[1]]) 26 | 27 | np.savetxt(os.path.join(OUTPUT_DIR, 'image_shapes.txt'), image_shapes, 28 | fmt='%s', delimiter=',') 29 | -------------------------------------------------------------------------------- /data/scripts/CVPPP/2-get_number_of_instances.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from PIL import Image 5 | 6 | DATA_DIR = os.path.abspath(os.path.join(__file__, os.path.pardir, 7 | os.path.pardir, os.path.pardir)) 8 | ANN_DIR = os.path.join(DATA_DIR, 'processed', 'CVPPP', 'instance-annotations') 9 | OUTPUT_DIR = os.path.join(DATA_DIR, 'metadata', 'CVPPP') 10 | 11 | annotation_files = glob.glob(os.path.join(ANN_DIR, '*.npy')) 12 | 13 | number_of_instances = [] 14 | for f in annotation_files: 15 | image_name = os.path.splitext(os.path.basename(f))[0] 16 | n_instances = np.load(f).shape[-1] 17 | 18 | number_of_instances.append([image_name, n_instances]) 19 | 20 | np.savetxt(os.path.join(OUTPUT_DIR, 'number_of_instances.txt'), 21 | number_of_instances, fmt='%s', delimiter=',') 22 | -------------------------------------------------------------------------------- /data/scripts/CVPPP/3-create_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from PIL import Image 5 | from utils import create_dataset 6 | 7 | DATA_DIR = os.path.abspath(os.path.join(__file__, os.path.pardir, 8 | os.path.pardir, os.path.pardir)) 9 | SEMANTIC_ANN_DIR = os.path.join(DATA_DIR, 'processed', 'CVPPP', 10 | 'semantic-annotations') 11 | INSTANCE_ANN_DIR = os.path.join(DATA_DIR, 'processed', 'CVPPP', 12 | 'instance-annotations') 13 | IMG_DIR = os.path.join(DATA_DIR, 'raw', 'CVPPP', 'CVPPP2017_LSC_training', 14 | 'training', 'A1') 15 | OUT_DIR = os.path.join(DATA_DIR, 'processed', 'CVPPP', 'lmdb') 16 | 17 | try: 18 | os.makedirs(OUT_DIR) 19 | except BaseException: 20 | pass 21 | 22 | for subset in ['training', 'validation']: 23 | lst_filepath = os.path.join(DATA_DIR, 'metadata', 'CVPPP', subset + '.lst') 24 | lst = np.loadtxt(lst_filepath, dtype='str', delimiter=' ') 25 | 26 | img_paths = [] 27 | ins_ann_paths = [] 28 | semantic_ann_paths = [] 29 | for image_name in lst: 30 | img_path = os.path.join(IMG_DIR, image_name + '_rgb.png') 31 | ins_ann_path = os.path.join(INSTANCE_ANN_DIR, image_name + '.npy') 32 | sem_ann_path = os.path.join(SEMANTIC_ANN_DIR, image_name + '.npy') 33 | 34 | if os.path.isfile(img_path) and os.path.isfile( 35 | ins_ann_path) and os.path.isfile(sem_ann_path): 36 | img_paths.append(img_path) 37 | ins_ann_paths.append(ins_ann_path) 38 | semantic_ann_paths.append(sem_ann_path) 39 | 40 | out_path = os.path.join(OUT_DIR, '{}-lmdb'.format(subset)) 41 | 42 | create_dataset(out_path, img_paths, semantic_ann_paths, ins_ann_paths) 43 | -------------------------------------------------------------------------------- /data/scripts/CVPPP/prepare.sh: 
--------------------------------------------------------------------------------
 1 | # 1. Create Semantic and Instance Masks.
 2 | echo "1. Creating semantic and instance masks"
 3 | python 1-create_annotations.py
 4 | 
 5 | # 2. Remove alpha channels from images.
 6 | echo "2. Removing alpha channels from images"
 7 | sh 1-remove_alpha.sh
 8 | 
 9 | # 3. Get Image Paths.
10 | echo "3. Saving image paths"
11 | python 2-get_image_paths.py
12 | 
13 | # 4. Get Image Shapes.
14 | echo "4. Calculating image shapes"
15 | python 2-get_image_shapes.py
16 | 
17 | # 5. Get Image Means and Standard Deviations.
18 | echo "5. Calculating means and standard deviations per channel"
19 | python 2-get_image_means-stds.py
20 | 
21 | # 6. Get Number of Instances.
22 | echo "6. Calculating number of instances in images"
23 | python 2-get_number_of_instances.py
24 | 
25 | # 7. Create LMDB.
26 | echo "7. Creating LMDB"
27 | python 3-create_dataset.py
28 | 
--------------------------------------------------------------------------------
/data/scripts/CVPPP/utils.py:
--------------------------------------------------------------------------------
 1 | import lmdb
 2 | import numpy as np
 3 | 
 4 | 
 5 | def write_cache(env, cache):
 6 |     with env.begin(write=True) as txn:  # flush the cached key/value pairs in one write transaction
 7 |         for k, v in cache.iteritems():
 8 |             txn.put(k, v)
 9 | 
10 | 
11 | def create_dataset(
12 |         output_path,
13 |         image_paths,
14 |         semantic_annotation_paths,
15 |         instance_annotation_paths):
16 | 
17 |     n_images = len(image_paths)
18 | 
19 |     assert(n_images == len(semantic_annotation_paths) == len(instance_annotation_paths))
20 | 
21 |     print 'Number of Images : {}'.format(n_images)
22 | 
23 |     env = lmdb.open(output_path, map_size=1099511627776)  # map_size: 1 TiB upper bound for the memory map
24 |     cache = {}
25 |     n_images_cntr = 1
26 |     for i in xrange(n_images):
27 |         image_path = image_paths[i]
28 |         semantic_annotation_path = semantic_annotation_paths[i]
29 |         instance_annotation_path = instance_annotation_paths[i]
30 | 
31 |         image = open(image_path, 'rb').read()  # store the raw encoded image bytes
32 | 
33 |         semantic_annotation = np.load(semantic_annotation_path)
34 |         semantic_annotation_height = semantic_annotation.shape[0]
35 |         semantic_annotation_width = semantic_annotation.shape[1]
36 | 
37 |         instance_annotation = np.load(instance_annotation_path)
38 | 
39 |         n_objects = instance_annotation.shape[2]
40 | 
41 |         cache['image-{}'.format(n_images_cntr)] = image
42 |         cache['semantic-annotation-{}'.format(n_images_cntr)
43 |               ] = semantic_annotation.tostring()
44 |         cache['instance-annotation-{}'.format(n_images_cntr)
45 |               ] = instance_annotation.tostring()
46 |         cache['height-{}'.format(n_images_cntr)
47 |               ] = str(semantic_annotation_height)
48 |         cache['width-{}'.format(n_images_cntr)
49 |               ] = str(semantic_annotation_width)
50 |         cache['n_objects-{}'.format(n_images_cntr)] = str(n_objects)
51 | 
52 |         if n_images_cntr % 50 == 0:
53 |             write_cache(env, cache)
54 |             cache = {}
55 |             print 'Processed %d / %d' % (n_images_cntr, n_images)
56 |         n_images_cntr += 1
57 | 
58 |     cache['num-samples'] = str(n_images)
59 |     write_cache(env, cache)
60 |     print 'Created dataset with {} samples'.format(n_images)
61 | 
--------------------------------------------------------------------------------
/data/scripts/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/data/scripts/README.md
--------------------------------------------------------------------------------
/models/CVPPP/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/models/CVPPP/README.md -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/models/README.md -------------------------------------------------------------------------------- /outputs/CVPPP/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/outputs/CVPPP/README.md -------------------------------------------------------------------------------- /outputs/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/outputs/README.md -------------------------------------------------------------------------------- /samples/CVPPP/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/samples/CVPPP/README.md -------------------------------------------------------------------------------- /samples/CVPPP/plant007_rgb-fg_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/samples/CVPPP/plant007_rgb-fg_mask.png -------------------------------------------------------------------------------- /samples/CVPPP/plant007_rgb-ins_mask_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/samples/CVPPP/plant007_rgb-ins_mask_color.png -------------------------------------------------------------------------------- /samples/CVPPP/plant007_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/samples/CVPPP/plant007_rgb.png -------------------------------------------------------------------------------- /samples/CVPPP/plant031_rgb-fg_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/samples/CVPPP/plant031_rgb-fg_mask.png -------------------------------------------------------------------------------- /samples/CVPPP/plant031_rgb-ins_mask_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/samples/CVPPP/plant031_rgb-ins_mask_color.png -------------------------------------------------------------------------------- /samples/CVPPP/plant031_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/samples/CVPPP/plant031_rgb.png 
-------------------------------------------------------------------------------- /samples/CVPPP/plant033_rgb-fg_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/samples/CVPPP/plant033_rgb-fg_mask.png -------------------------------------------------------------------------------- /samples/CVPPP/plant033_rgb-ins_mask_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/samples/CVPPP/plant033_rgb-ins_mask_color.png -------------------------------------------------------------------------------- /samples/CVPPP/plant033_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/samples/CVPPP/plant033_rgb.png -------------------------------------------------------------------------------- /samples/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wizaron/instance-segmentation-pytorch/f61c903606f0a00d24cbbea73fd3164e1dfa85fa/samples/README.md --------------------------------------------------------------------------------
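
The LMDB written by `create_dataset` in `data/scripts/CVPPP/utils.py` stores each sample under plain string keys (`image-i`, `semantic-annotation-i`, `instance-annotation-i`, `height-i`, `width-i`, `n_objects-i`) plus a global `num-samples` entry, with the annotation arrays serialized via `tostring()`. The sketch below reads one sample back for sanity-checking that layout; it is illustrative only, assumes the project's Python 2 runtime, a single-channel `(height, width)` semantic mask, and a `uint8` annotation dtype, and `LMDB_PATH` / `ANN_DTYPE` are placeholders — the actual dtype is set by `1-create_annotations.py`, which is not shown here.

```python
# Minimal read-back sketch for the LMDB produced by 3-create_dataset.py.
# LMDB_PATH and ANN_DTYPE are assumptions; adjust them to your setup.
import io

import lmdb
import numpy as np
from PIL import Image

LMDB_PATH = 'data/processed/CVPPP/lmdb/training-lmdb'  # assumed output location
ANN_DTYPE = np.uint8  # assumed dtype of the saved .npy annotations

env = lmdb.open(LMDB_PATH, readonly=True, lock=False)
with env.begin(write=False) as txn:
    n_samples = int(txn.get('num-samples'))
    i = 1  # create_dataset indexes samples starting from 1

    image = Image.open(io.BytesIO(txn.get('image-{}'.format(i))))
    height = int(txn.get('height-{}'.format(i)))
    width = int(txn.get('width-{}'.format(i)))
    n_objects = int(txn.get('n_objects-{}'.format(i)))

    # Annotations were serialized with ndarray.tostring(); rebuild them from
    # raw bytes using the recorded height/width and instance count.
    semantic_mask = np.frombuffer(txn.get('semantic-annotation-{}'.format(i)),
                                  dtype=ANN_DTYPE).reshape(height, width)
    instance_masks = np.frombuffer(txn.get('instance-annotation-{}'.format(i)),
                                   dtype=ANN_DTYPE).reshape(height, width, n_objects)

    print('{} samples in LMDB; sample {}: image size {}, {} instances'.format(
        n_samples, i, image.size, n_objects))
```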