├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── ci.yml
├── .gitignore
├── .gitmodules
├── .readme
│   └── fcn8s_iter28000.jpg
├── LICENSE
├── MANIFEST.in
├── README.md
├── examples
│   └── voc
│       ├── .gitignore
│       ├── README.md
│       ├── download_dataset.sh
│       ├── evaluate.py
│       ├── learning_curve.py
│       ├── model_caffe_to_pytorch.py
│       ├── speedtest.py
│       ├── summarize_logs.py
│       ├── train_fcn16s.py
│       ├── train_fcn32s.py
│       ├── train_fcn8s.py
│       ├── train_fcn8s_atonce.py
│       └── view_log
├── requirements.txt
├── setup.cfg
├── setup.py
├── tests
│   └── models_tests
│       └── test_fcn32s.py
└── torchfcn
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   └── voc.py
    ├── ext
    │   └── fcn.berkeleyvision.org
    │       ├── README.md
    │       ├── data
    │       │   ├── nyud
    │       │   │   ├── README.md
    │       │   │   ├── classes.txt
    │       │   │   ├── test.txt
    │       │   │   ├── train.txt
    │       │   │   └── val.txt
    │       │   ├── pascal-context
    │       │   │   ├── README.md
    │       │   │   ├── classes-400.txt
    │       │   │   └── classes-59.txt
    │       │   ├── pascal
    │       │   │   ├── README.md
    │       │   │   ├── classes.txt
    │       │   │   └── seg11valid.txt
    │       │   └── sift-flow
    │       │       ├── README.md
    │       │       ├── classes.txt
    │       │       ├── test.txt
    │       │       └── trainval.txt
    │       ├── ilsvrc-nets
    │       │   └── README.md
    │       ├── infer.py
    │       ├── nyud-fcn32s-color-d
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── test.prototxt
    │       │   └── trainval.prototxt
    │       ├── nyud-fcn32s-color-hha
    │       │   ├── caffemodel-url
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── test.prototxt
    │       │   └── trainval.prototxt
    │       ├── nyud-fcn32s-color
    │       │   ├── caffemodel-url
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── test.prototxt
    │       │   └── trainval.prototxt
    │       ├── nyud-fcn32s-hha
    │       │   ├── caffemodel-url
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── test.prototxt
    │       │   └── trainval.prototxt
    │       ├── nyud_layers.py
    │       ├── pascalcontext-fcn16s
    │       │   ├── caffemodel-url
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── train.prototxt
    │       │   └── val.prototxt
    │       ├── pascalcontext-fcn32s
    │       │   ├── caffemodel-url
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── train.prototxt
    │       │   └── val.prototxt
    │       ├── pascalcontext-fcn8s
    │       │   ├── caffemodel-url
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── train.prototxt
    │       │   └── val.prototxt
    │       ├── pascalcontext_layers.py
    │       ├── score.py
    │       ├── siftflow-fcn16s
    │       │   ├── caffemodel-url
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── test.prototxt
    │       │   └── trainval.prototxt
    │       ├── siftflow-fcn32s
    │       │   ├── caffemodel-url
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── test.prototxt
    │       │   └── trainval.prototxt
    │       ├── siftflow-fcn8s
    │       │   ├── caffemodel-url
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── test.prototxt
    │       │   └── trainval.prototxt
    │       ├── siftflow_layers.py
    │       ├── surgery.py
    │       ├── voc-fcn-alexnet
    │       │   ├── caffemodel-url
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── train.prototxt
    │       │   └── val.prototxt
    │       ├── voc-fcn16s
    │       │   ├── caffemodel-url
    │       │   ├── deploy.prototxt
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── train.prototxt
    │       │   └── val.prototxt
    │       ├── voc-fcn32s
    │       │   ├── caffemodel-url
    │       │   ├── deploy.prototxt
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── train.prototxt
    │       │   └── val.prototxt
    │       ├── voc-fcn8s-atonce
    │       │   ├── caffemodel-url
    │       │   ├── deploy.prototxt
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── train.prototxt
    │       │   └── val.prototxt
    │       ├── voc-fcn8s
    │       │   ├── caffemodel-url
    │       │   ├── deploy.prototxt
    │       │   ├── net.py
    │       │   ├── solve.py
    │       │   ├── solver.prototxt
    │       │   ├── train.prototxt
    │       │   └── val.prototxt
    │       ├── voc_helper.py
    │       └── voc_layers.py
    ├── models
    │   ├── __init__.py
    │   ├── fcn16s.py
    │   ├── fcn32s.py
    │   ├── fcn8s.py
    │   └── vgg.py
    ├── trainer.py
    └── utils.py

/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms 2 | 3 | github: [wkentaro] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: [3.8] 16 | 17 | steps: 18 | - uses: actions/checkout@v1 19 | 20 | - name: Update submodules 21 | run: | 22 | git submodule update --init --recursive 23 | 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v1 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | if [ "${{ matrix.python-version }}" = "2.7" ]; then 33 | pip install numpy==1.16.5 34 | fi 35 | pip install -r requirements.txt 36 | 37 | - name: Install main 38 | run: | 39 | pip install . 40 | 41 | - name: Lint with flake8 42 | run: | 43 | pip install flake8 44 | flake8 . 45 | 46 | - name: Test with pytest 47 | run: | 48 | pip install pytest 49 | pytest tests 50 | 51 | - name: Install from dist 52 | run: | 53 | rm -f dist/*.tar.gz 54 | python setup.py sdist 55 | pip install dist/*.tar.gz 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | *.py[cdo] 3 | 4 | build/ 5 | dist/ 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "github2pypi"] 2 | path = github2pypi 3 | url = https://github.com/wkentaro/github2pypi.git 4 | -------------------------------------------------------------------------------- /.readme/fcn8s_iter28000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wkentaro/pytorch-fcn/4b988509bfd1a613d5e8482595a1c6654047d69b/.readme/fcn8s_iter28000.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 - 2019 Kentaro Wada. 
2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include requirements.txt 3 | recursive-include torchfcn/ext * 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-fcn 2 | 3 | [![PyPI Version](https://img.shields.io/pypi/v/torchfcn.svg)](https://pypi.python.org/pypi/torchfcn) 4 | [![Python Versions](https://img.shields.io/pypi/pyversions/torchfcn.svg)](https://pypi.org/project/torchfcn) 5 | [![GitHub Actions](https://github.com/wkentaro/pytorch-fcn/workflows/CI/badge.svg)](https://github.com/wkentaro/pytorch-fcn/actions) 6 | 7 | PyTorch implementation of [Fully Convolutional Networks](https://github.com/shelhamer/fcn.berkeleyvision.org). 8 | 9 | 10 | ## Requirements 11 | 12 | - [pytorch](https://github.com/pytorch/pytorch) >= 0.2.0 13 | - [torchvision](https://github.com/pytorch/vision) >= 0.1.8 14 | - [fcn](https://github.com/wkentaro/fcn) >= 6.1.5 15 | - [Pillow](https://github.com/python-pillow/Pillow) 16 | - [scipy](https://github.com/scipy/scipy) 17 | - [tqdm](https://github.com/tqdm/tqdm) 18 | 19 | 20 | ## Installation 21 | 22 | ```bash 23 | git clone https://github.com/wkentaro/pytorch-fcn.git 24 | cd pytorch-fcn 25 | pip install . 26 | 27 | # or 28 | 29 | pip install torchfcn 30 | ``` 31 | 32 | 33 | ## Training 34 | 35 | See [VOC example](examples/voc). 36 | 37 | 38 | ## Accuracy 39 | 40 | At `10fdec9`. 
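Mean IU is measured on the `seg11valid` split (PASCAL VOC 2011 validation images that do not overlap SBD train); the table below can be reproduced with `examples/voc/evaluate.py`.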
41 | 
42 | | Model | Implementation | epoch | iteration | Mean IU | Pretrained Model |
43 | |:-----:|:--------------:|:-----:|:---------:|:-------:|:----------------:|
44 | | FCN32s | [Original](https://github.com/shelhamer/fcn.berkeleyvision.org/tree/main/voc-fcn32s) | - | - | **63.63** | [Download](https://github.com/wkentaro/pytorch-fcn/blob/45c6b2d3f553cbe6369822d17a7a51dfe9328662/torchfcn/models/fcn32s.py#L34) |
45 | | FCN32s | Ours | 11 | 96000 | 62.84 | |
46 | | FCN16s | [Original](https://github.com/shelhamer/fcn.berkeleyvision.org/tree/main/voc-fcn16s) | - | - | **65.01** | [Download](https://github.com/wkentaro/pytorch-fcn/blob/45c6b2d3f553cbe6369822d17a7a51dfe9328662/torchfcn/models/fcn16s.py#L17) |
47 | | FCN16s | Ours | 11 | 96000 | 64.91 | |
48 | | FCN8s | [Original](https://github.com/shelhamer/fcn.berkeleyvision.org/tree/main/voc-fcn8s) | - | - | **65.51** | [Download](https://github.com/wkentaro/pytorch-fcn/blob/45c6b2d3f553cbe6369822d17a7a51dfe9328662/torchfcn/models/fcn8s.py#L17) |
49 | | FCN8s | Ours | 7 | 60000 | 65.49 | |
50 | | FCN8sAtOnce | [Original](https://github.com/shelhamer/fcn.berkeleyvision.org/tree/main/voc-fcn8s-atonce) | - | - | **65.40** | |
51 | | FCN8sAtOnce | Ours | 11 | 96000 | 64.74 | |
52 | 
53 | 
54 | Visualization of the validation results of FCN8s (`.readme/fcn8s_iter28000.jpg`).
55 | 
56 | 
57 | ## Cite This Project
58 | 
59 | If you use this project in your research or wish to refer to the baseline results published in the README, please use the following BibTeX entry.
60 | 
61 | ```bibtex
62 | @misc{pytorch-fcn2017,
63 |   author = {Kentaro Wada},
64 |   title = {{pytorch-fcn: PyTorch Implementation of Fully Convolutional Networks}},
65 |   howpublished = {\url{https://github.com/wkentaro/pytorch-fcn}},
66 |   year = {2017}
67 | }
68 | ```
69 | 
--------------------------------------------------------------------------------
/examples/voc/.gitignore:
--------------------------------------------------------------------------------
1 | /logs
2 | 
--------------------------------------------------------------------------------
/examples/voc/README.md:
--------------------------------------------------------------------------------
1 | # VOC Example
2 | 
3 | 
4 | ## Training
5 | 
6 | 
7 | ```bash
8 | ./download_dataset.sh
9 | 
10 | ./train_fcn32s.py -g 0
11 | ./train_fcn16s.py -g 0
12 | ./train_fcn8s.py -g 0
13 | ./train_fcn8s_atonce.py -g 0
14 | 
15 | ./view_log logs/XXX/log.csv
16 | ```
17 | 
18 | 
19 | ## Speed
20 | 
21 | At test time, the PyTorch implementation is faster than the [Chainer implementation](https://github.com/wkentaro/fcn) for static inputs and slower for dynamic ones.
22 | (In earlier benchmarks the Chainer implementation was much slower, but that was fixed in [wkentaro/fcn#90](https://github.com/wkentaro/fcn/pull/90).)
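The gap between static and dynamic inputs on the PyTorch side is largely down to cuDNN autotuning, which `speedtest.py` enables only when input shapes are fixed:

```python
# from speedtest.py: autotuned kernels pay off when every input has the
# same shape, but are re-selected whenever the shape changes
torch.backends.cudnn.benchmark = not dynamic_input
```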
23 | 24 | ```bash 25 | # Titan X (Pascal) 26 | # chainer==2.0.2 27 | # pytorch==0.2.0.post2 28 | # pytorch-fcn==1.7.0 29 | 30 | % cd examples/voc 31 | 32 | % ./speedtest.py --gpu 2 33 | ==> Benchmark: gpu=2, times=1000, dynamic_input=False 34 | ==> Testing FCN32s with Chainer 35 | Elapsed time: 45.95 [s / 1000 evals] 36 | Hz: 21.76 [hz] 37 | ==> Testing FCN32s with PyTorch 38 | Elapsed time: 42.63 [s / 1000 evals] 39 | Hz: 23.46 [hz] 40 | 41 | % ./speedtest.py --gpu 3 --dynamic-input 42 | ==> Benchmark: gpu=3, times=1000, dynamic_input=True 43 | ==> Testing FCN32s with Chainer 44 | Elapsed time: 47.68 [s / 1000 evals] 45 | Hz: 20.97 [hz] 46 | ==> Testing FCN32s with PyTorch 47 | Elapsed time: 54.49 [s / 1000 evals] 48 | Hz: 18.35 [hz] 49 | ``` 50 | 51 | 52 | ## Caffe to PyTorch model 53 | 54 | ``` 55 | git clone https://github.com/BVLC/caffe.git 56 | cd caffe 57 | cp Makefile.config.example Makefile.config 58 | vim Makefile.config # edit as you like 59 | make -j 60 | make pycaffe 61 | export PYTHONPATH=$(pwd)/python:$PYTHONPATH 62 | cd .. 63 | 64 | cd pytorch-fcn 65 | cd examples/voc 66 | ./model_caffe_to_pytorch.py 67 | ``` 68 | -------------------------------------------------------------------------------- /examples/voc/download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DIR=~/data/datasets/VOC 4 | 5 | mkdir -p $DIR 6 | cd $DIR 7 | 8 | if [ ! -e benchmark_RELEASE ]; then 9 | wget http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz -O benchmark.tar 10 | tar -xvf benchmark.tar 11 | fi 12 | 13 | if [ ! -e VOCdevkit/VOC2012 ]; then 14 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 15 | tar -xvf VOCtrainval_11-May-2012.tar 16 | fi 17 | -------------------------------------------------------------------------------- /examples/voc/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | import os.path as osp 6 | 7 | import fcn 8 | import numpy as np 9 | import skimage.io 10 | import torch 11 | from torch.autograd import Variable 12 | import torchfcn 13 | import tqdm 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('model_file', help='Model path') 19 | parser.add_argument('-g', '--gpu', type=int, default=0) 20 | args = parser.parse_args() 21 | 22 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 23 | model_file = args.model_file 24 | 25 | root = osp.expanduser('~/data/datasets') 26 | val_loader = torch.utils.data.DataLoader( 27 | torchfcn.datasets.VOC2011ClassSeg( 28 | root, split='seg11valid', transform=True), 29 | batch_size=1, shuffle=False, 30 | num_workers=4, pin_memory=True) 31 | 32 | n_class = len(val_loader.dataset.class_names) 33 | 34 | if osp.basename(model_file).startswith('fcn32s'): 35 | model = torchfcn.models.FCN32s(n_class=21) 36 | elif osp.basename(model_file).startswith('fcn16s'): 37 | model = torchfcn.models.FCN16s(n_class=21) 38 | elif osp.basename(model_file).startswith('fcn8s'): 39 | if osp.basename(model_file).startswith('fcn8s-atonce'): 40 | model = torchfcn.models.FCN8sAtOnce(n_class=21) 41 | else: 42 | model = torchfcn.models.FCN8s(n_class=21) 43 | else: 44 | raise ValueError 45 | if torch.cuda.is_available(): 46 | model = model.cuda() 47 | print('==> Loading %s model file: %s' % 48 | (model.__class__.__name__, model_file)) 49 | model_data = torch.load(model_file) 50 | 
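    # Checkpoints written by torchfcn.Trainer wrap the weights in a dict under
    # 'model_state_dict', while converted Caffe models are plain state dicts,
    # so try the bare load first and fall back to the wrapped form.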
try: 51 | model.load_state_dict(model_data) 52 | except Exception: 53 | model.load_state_dict(model_data['model_state_dict']) 54 | model.eval() 55 | 56 | print('==> Evaluating with VOC2011ClassSeg seg11valid') 57 | visualizations = [] 58 | label_trues, label_preds = [], [] 59 | for batch_idx, (data, target) in tqdm.tqdm(enumerate(val_loader), 60 | total=len(val_loader), 61 | ncols=80, leave=False): 62 | if torch.cuda.is_available(): 63 | data, target = data.cuda(), target.cuda() 64 | data, target = Variable(data, volatile=True), Variable(target) 65 | score = model(data) 66 | 67 | imgs = data.data.cpu() 68 | lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :] 69 | lbl_true = target.data.cpu() 70 | for img, lt, lp in zip(imgs, lbl_true, lbl_pred): 71 | img, lt = val_loader.dataset.untransform(img, lt) 72 | label_trues.append(lt) 73 | label_preds.append(lp) 74 | if len(visualizations) < 9: 75 | viz = fcn.utils.visualize_segmentation( 76 | lbl_pred=lp, lbl_true=lt, img=img, n_class=n_class, 77 | label_names=val_loader.dataset.class_names) 78 | visualizations.append(viz) 79 | metrics = torchfcn.utils.label_accuracy_score( 80 | label_trues, label_preds, n_class=n_class) 81 | metrics = np.array(metrics) 82 | metrics *= 100 83 | print('''\ 84 | Accuracy: {0} 85 | Accuracy Class: {1} 86 | Mean IU: {2} 87 | FWAV Accuracy: {3}'''.format(*metrics)) 88 | 89 | viz = fcn.utils.get_tile_image(visualizations) 90 | skimage.io.imsave('viz_evaluate.png', viz) 91 | 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /examples/voc/learning_curve.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import division 4 | 5 | import argparse 6 | import os.path as osp 7 | 8 | import matplotlib 9 | 10 | matplotlib.use('Agg') 11 | 12 | import matplotlib.pyplot as plt # NOQA 13 | import pandas # NOQA 14 | import seaborn # NOQA 15 | 16 | 17 | def learning_curve(log_file): 18 | print('==> Plotting log file: %s' % log_file) 19 | 20 | df = pandas.read_csv(log_file) 21 | 22 | colors = ['red', 'green', 'blue', 'purple', 'orange'] 23 | colors = seaborn.xkcd_palette(colors) 24 | 25 | plt.figure(figsize=(20, 6), dpi=300) 26 | 27 | row_min = df.min() 28 | row_max = df.max() 29 | 30 | # initialize DataFrame for train 31 | columns = [ 32 | 'epoch', 33 | 'iteration', 34 | 'train/loss', 35 | 'train/acc', 36 | 'train/acc_cls', 37 | 'train/mean_iu', 38 | 'train/fwavacc', 39 | ] 40 | df_train = df[columns] 41 | if hasattr(df_train, 'rolling'): 42 | df_train = df_train.rolling(window=10).mean() 43 | else: 44 | df_train = pandas.rolling_mean(df_train, window=10) 45 | df_train = df_train.dropna() 46 | iter_per_epoch = df_train[df_train['epoch'] == 1]['iteration'].values[0] 47 | df_train['epoch_detail'] = df_train['iteration'] / iter_per_epoch 48 | 49 | # initialize DataFrame for val 50 | columns = [ 51 | 'epoch', 52 | 'iteration', 53 | 'valid/loss', 54 | 'valid/acc', 55 | 'valid/acc_cls', 56 | 'valid/mean_iu', 57 | 'valid/fwavacc', 58 | ] 59 | df_valid = df[columns] 60 | df_valid = df_valid.dropna() 61 | df_valid['epoch_detail'] = df_valid['iteration'] / iter_per_epoch 62 | 63 | data_frames = {'train': df_train, 'valid': df_valid} 64 | 65 | n_row = 2 66 | n_col = 3 67 | for i, split in enumerate(['train', 'valid']): 68 | df_split = data_frames[split] 69 | 70 | # loss 71 | plt.subplot(n_row, n_col, i * n_col + 1) 72 | plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 
0)) 73 | plt.plot(df_split['epoch_detail'], df_split['%s/loss' % split], '-', 74 | markersize=1, color=colors[0], alpha=.5, 75 | label='%s loss' % split) 76 | plt.xlim((0, row_max['epoch'])) 77 | plt.ylim((min(row_min['train/loss'], row_min['valid/loss']), 78 | max(row_max['train/loss'], row_max['valid/loss']))) 79 | plt.xlabel('epoch') 80 | plt.ylabel('%s loss' % split) 81 | 82 | # loss (log) 83 | plt.subplot(n_row, n_col, i * n_col + 2) 84 | plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0)) 85 | plt.semilogy(df_split['epoch_detail'], df_split['%s/loss' % split], 86 | '-', markersize=1, color=colors[0], alpha=.5, 87 | label='%s loss' % split) 88 | plt.xlim((0, row_max['epoch'])) 89 | plt.ylim((min(row_min['train/loss'], row_min['valid/loss']), 90 | max(row_max['train/loss'], row_max['valid/loss']))) 91 | plt.xlabel('epoch') 92 | plt.ylabel('%s loss (log)' % split) 93 | 94 | # lbl accuracy 95 | plt.subplot(n_row, n_col, i * n_col + 3) 96 | plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0)) 97 | plt.plot(df_split['epoch_detail'], df_split['%s/acc' % split], 98 | '-', markersize=1, color=colors[1], alpha=.5, 99 | label='%s accuracy' % split) 100 | plt.plot(df_split['epoch_detail'], df_split['%s/acc_cls' % split], 101 | '-', markersize=1, color=colors[2], alpha=.5, 102 | label='%s accuracy class' % split) 103 | plt.plot(df_split['epoch_detail'], df_split['%s/mean_iu' % split], 104 | '-', markersize=1, color=colors[3], alpha=.5, 105 | label='%s mean IU' % split) 106 | plt.plot(df_split['epoch_detail'], df_split['%s/fwavacc' % split], 107 | '-', markersize=1, color=colors[4], alpha=.5, 108 | label='%s fwav accuracy' % split) 109 | plt.legend() 110 | plt.xlim((0, row_max['epoch'])) 111 | plt.ylim((0, 1)) 112 | plt.xlabel('epoch') 113 | plt.ylabel('%s label accuracy' % split) 114 | 115 | out_file = osp.splitext(log_file)[0] + '.png' 116 | plt.savefig(out_file) 117 | print('==> Wrote figure to: %s' % out_file) 118 | 119 | 120 | def main(): 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument('log_file') 123 | args = parser.parse_args() 124 | 125 | log_file = args.log_file 126 | 127 | learning_curve(log_file) 128 | 129 | 130 | if __name__ == '__main__': 131 | main() 132 | -------------------------------------------------------------------------------- /examples/voc/model_caffe_to_pytorch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os.path as osp 4 | import pkg_resources 5 | import sys 6 | 7 | import torch 8 | 9 | # FIXME: must be after import torch 10 | import caffe 11 | 12 | import torchfcn 13 | 14 | 15 | models = [ 16 | ('fcn32s', 'FCN32s', []), 17 | ('fcn16s', 'FCN16s', []), 18 | ('fcn8s', 'FCN8s', []), 19 | ('fcn8s-atonce', 'FCN8sAtOnce', ['scale_pool4', 'scale_pool3']), 20 | ] 21 | 22 | 23 | for name_lower, name_upper, blacklists in models: 24 | print('==> Loading caffe model of %s' % name_upper) 25 | pkg_root = pkg_resources.get_distribution('torchfcn').location 26 | sys.path.insert( 27 | 0, osp.join(pkg_root, 'torchfcn/ext/fcn.berkeleyvision.org')) 28 | caffe_prototxt = osp.join( 29 | pkg_root, 30 | 'torchfcn/ext/fcn.berkeleyvision.org/voc-%s/deploy.prototxt' % 31 | name_lower) 32 | caffe_model_path = osp.expanduser( 33 | '~/data/models/caffe/%s-heavy-pascal.caffemodel' % name_lower) 34 | caffe_model = caffe.Net(caffe_prototxt, caffe_model_path, caffe.TEST) 35 | 36 | torch_model = getattr(torchfcn.models, name_upper)() 37 | 38 | torch_model_params = torch_model.parameters() 
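    # Transplant parameters layer by layer: Caffe stores each layer's blobs as
    # a (weight, bias) pair keyed by layer name, and the PyTorch models use
    # matching attribute names, so every copy can be shape-checked one-to-one.
    # Note: `params.iteritems()` makes this a Python 2 script.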
39 | for name, p1 in caffe_model.params.iteritems(): 40 | if name in blacklists: 41 | continue 42 | l2 = getattr(torch_model, name) 43 | p2 = l2.weight 44 | assert p1[0].data.shape == tuple(p2.data.size()) 45 | print('%s: %s -> %s' % (name, p1[0].data.shape, p2.data.size())) 46 | p2.data = torch.from_numpy(p1[0].data) 47 | if len(p1) == 2: 48 | p2 = l2.bias 49 | assert p1[1].data.shape == tuple(p2.data.size()) 50 | print('%s: %s -> %s' % (name, p1[1].data.shape, p2.data.size())) 51 | p2.data = torch.from_numpy(p1[1].data) 52 | 53 | torch_model_path = osp.expanduser( 54 | '~/data/models/pytorch/%s-heavy-pascal.pth' % name_lower) 55 | torch.save(torch_model.state_dict(), torch_model_path) 56 | print('==> Saved pytorch model: %s' % torch_model_path) 57 | -------------------------------------------------------------------------------- /examples/voc/speedtest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import time 5 | 6 | import numpy as np 7 | import six 8 | 9 | 10 | def bench_chainer(gpu, times, dynamic_input=False): 11 | import chainer 12 | import fcn 13 | print('==> Testing FCN32s with Chainer') 14 | chainer.cuda.get_device(gpu).use() 15 | 16 | chainer.config.train = False 17 | chainer.config.enable_backprop = False 18 | 19 | if dynamic_input: 20 | x_data = np.random.random((1, 3, 480, 640)).astype(np.float32) 21 | x_data = chainer.cuda.to_gpu(x_data) 22 | x1 = chainer.Variable(x_data) 23 | x_data = np.random.random((1, 3, 640, 480)).astype(np.float32) 24 | x_data = chainer.cuda.to_gpu(x_data) 25 | x2 = chainer.Variable(x_data) 26 | else: 27 | x_data = np.random.random((1, 3, 480, 640)).astype(np.float32) 28 | x_data = chainer.cuda.to_gpu(x_data) 29 | x1 = chainer.Variable(x_data) 30 | 31 | model = fcn.models.FCN32s() 32 | model.train = False 33 | model.to_gpu() 34 | 35 | for i in six.moves.range(5): 36 | model(x1) 37 | chainer.cuda.Stream().synchronize() 38 | t_start = time.time() 39 | for i in six.moves.range(times): 40 | if dynamic_input: 41 | if i % 2 == 1: 42 | model(x1) 43 | else: 44 | model(x2) 45 | else: 46 | model(x1) 47 | chainer.cuda.Stream().synchronize() 48 | elapsed_time = time.time() - t_start 49 | 50 | print('Elapsed time: %.2f [s / %d evals]' % (elapsed_time, times)) 51 | print('Hz: %.2f [hz]' % (times / elapsed_time)) 52 | 53 | 54 | def bench_pytorch(gpu, times, dynamic_input=False): 55 | import torch 56 | import torchfcn.models 57 | print('==> Testing FCN32s with PyTorch') 58 | torch.cuda.set_device(gpu) 59 | torch.backends.cudnn.benchmark = not dynamic_input 60 | 61 | model = torchfcn.models.FCN32s() 62 | model.eval() 63 | model = model.cuda() 64 | 65 | if dynamic_input: 66 | x_data = np.random.random((1, 3, 480, 640)) 67 | x1 = torch.autograd.Variable(torch.from_numpy(x_data).float(), 68 | volatile=True).cuda() 69 | x_data = np.random.random((1, 3, 640, 480)) 70 | x2 = torch.autograd.Variable(torch.from_numpy(x_data).float(), 71 | volatile=True).cuda() 72 | else: 73 | x_data = np.random.random((1, 3, 480, 640)) 74 | x1 = torch.autograd.Variable(torch.from_numpy(x_data).float(), 75 | volatile=True).cuda() 76 | 77 | for i in six.moves.range(5): 78 | model(x1) 79 | torch.cuda.synchronize() 80 | t_start = time.time() 81 | for i in six.moves.range(times): 82 | if dynamic_input: 83 | if i % 2 == 1: 84 | model(x1) 85 | else: 86 | model(x2) 87 | else: 88 | model(x1) 89 | torch.cuda.synchronize() 90 | elapsed_time = time.time() - t_start 91 | 92 | print('Elapsed time: %.2f [s / %d 
evals]' % (elapsed_time, times)) 93 | print('Hz: %.2f [hz]' % (times / elapsed_time)) 94 | 95 | 96 | def main(): 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument('--gpu', type=int, default=0) 99 | parser.add_argument('--times', type=int, default=1000) 100 | parser.add_argument('--dynamic-input', action='store_true') 101 | args = parser.parse_args() 102 | 103 | print('==> Benchmark: gpu=%d, times=%d, dynamic_input=%s' % 104 | (args.gpu, args.times, args.dynamic_input)) 105 | bench_chainer(args.gpu, args.times, args.dynamic_input) 106 | bench_pytorch(args.gpu, args.times, args.dynamic_input) 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /examples/voc/summarize_logs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import os.path as osp 5 | 6 | import pandas as pd 7 | import tabulate 8 | import yaml 9 | 10 | 11 | def main(): 12 | logs_dir = 'logs' 13 | 14 | headers = [ 15 | 'name', 16 | 'model', 17 | 'git_hash', 18 | 'pretrained_model', 19 | 'epoch', 20 | 'iteration', 21 | 'valid/mean_iu', 22 | ] 23 | rows = [] 24 | for log in os.listdir(logs_dir): 25 | log_dir = osp.join(logs_dir, log) 26 | if not osp.isdir(log_dir): 27 | continue 28 | try: 29 | log_file = osp.join(log_dir, 'log.csv') 30 | df = pd.read_csv(log_file) 31 | columns = [c for c in df.columns if not c.startswith('train')] 32 | df = df[columns] 33 | df = df.set_index(['epoch', 'iteration']) 34 | index_best = df['valid/mean_iu'].idxmax() 35 | row_best = df.loc[index_best].dropna() 36 | 37 | with open(osp.join(log_dir, 'config.yaml')) as f: 38 | config = yaml.load(f) 39 | except Exception: 40 | continue 41 | rows.append([ 42 | osp.join(logs_dir, log), 43 | config['model'], 44 | config['git_hash'], 45 | config.get('pretrained_model', None), 46 | row_best.index[0][0], 47 | row_best.index[0][1], 48 | 100 * row_best['valid/mean_iu'].values[0], 49 | ]) 50 | rows.sort(key=lambda x: x[-1], reverse=True) 51 | print(tabulate.tabulate(rows, headers=headers)) 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /examples/voc/train_fcn16s.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import datetime 5 | import os 6 | import os.path as osp 7 | 8 | import torch 9 | import yaml 10 | 11 | import torchfcn 12 | 13 | from train_fcn32s import get_parameters 14 | from train_fcn32s import git_hash 15 | 16 | 17 | here = osp.dirname(osp.abspath(__file__)) 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser( 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 23 | ) 24 | parser.add_argument('-g', '--gpu', type=int, required=True, help='gpu id') 25 | parser.add_argument('--resume', help='checkpoint path') 26 | # configurations (same configuration as original work) 27 | # https://github.com/shelhamer/fcn.berkeleyvision.org 28 | parser.add_argument( 29 | '--max-iteration', type=int, default=100000, help='max iteration' 30 | ) 31 | parser.add_argument( 32 | '--lr', type=float, default=1.0e-12, help='learning rate', 33 | ) 34 | parser.add_argument( 35 | '--weight-decay', type=float, default=0.0005, help='weight decay', 36 | ) 37 | parser.add_argument( 38 | '--momentum', type=float, default=0.99, help='momentum', 39 | ) 40 | parser.add_argument( 41 | '--pretrained-model', 
42 | default=torchfcn.models.FCN32s.download(), 43 | help='pretrained model of FCN32s', 44 | ) 45 | args = parser.parse_args() 46 | 47 | args.model = 'FCN16s' 48 | args.git_hash = git_hash() 49 | 50 | now = datetime.datetime.now() 51 | args.out = osp.join(here, 'logs', now.strftime('%Y%m%d_%H%M%S.%f')) 52 | 53 | os.makedirs(args.out) 54 | with open(osp.join(args.out, 'config.yaml'), 'w') as f: 55 | yaml.safe_dump(args.__dict__, f, default_flow_style=False) 56 | 57 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 58 | cuda = torch.cuda.is_available() 59 | 60 | torch.manual_seed(1337) 61 | if cuda: 62 | torch.cuda.manual_seed(1337) 63 | 64 | # 1. dataset 65 | 66 | root = osp.expanduser('~/data/datasets') 67 | kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {} 68 | train_loader = torch.utils.data.DataLoader( 69 | torchfcn.datasets.SBDClassSeg(root, split='train', transform=True), 70 | batch_size=1, shuffle=True, **kwargs) 71 | val_loader = torch.utils.data.DataLoader( 72 | torchfcn.datasets.VOC2011ClassSeg( 73 | root, split='seg11valid', transform=True), 74 | batch_size=1, shuffle=False, **kwargs) 75 | 76 | # 2. model 77 | 78 | model = torchfcn.models.FCN16s(n_class=21) 79 | start_epoch = 0 80 | start_iteration = 0 81 | if args.resume: 82 | checkpoint = torch.load(args.resume) 83 | model.load_state_dict(checkpoint['model_state_dict']) 84 | start_epoch = checkpoint['epoch'] 85 | start_iteration = checkpoint['iteration'] 86 | else: 87 | fcn32s = torchfcn.models.FCN32s() 88 | state_dict = torch.load(args.pretrained_model) 89 | try: 90 | fcn32s.load_state_dict(state_dict) 91 | except RuntimeError: 92 | fcn32s.load_state_dict(state_dict['model_state_dict']) 93 | model.copy_params_from_fcn32s(fcn32s) 94 | if cuda: 95 | model = model.cuda() 96 | 97 | # 3. 
optimizer 98 | 99 | optim = torch.optim.SGD( 100 | [ 101 | {'params': get_parameters(model, bias=False)}, 102 | {'params': get_parameters(model, bias=True), 103 | 'lr': args.lr * 2, 'weight_decay': 0}, 104 | ], 105 | lr=args.lr, 106 | momentum=args.momentum, 107 | weight_decay=args.weight_decay) 108 | if args.resume: 109 | optim.load_state_dict(checkpoint['optim_state_dict']) 110 | 111 | trainer = torchfcn.Trainer( 112 | cuda=cuda, 113 | model=model, 114 | optimizer=optim, 115 | train_loader=train_loader, 116 | val_loader=val_loader, 117 | out=args.out, 118 | max_iter=args.max_iteration, 119 | interval_validate=4000, 120 | ) 121 | trainer.epoch = start_epoch 122 | trainer.iteration = start_iteration 123 | trainer.train() 124 | 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /examples/voc/train_fcn32s.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import datetime 5 | import os 6 | import os.path as osp 7 | import shlex 8 | import subprocess 9 | 10 | import torch 11 | import yaml 12 | 13 | import torchfcn 14 | 15 | 16 | def git_hash(): 17 | cmd = 'git log -n 1 --pretty="%h"' 18 | ret = subprocess.check_output(shlex.split(cmd)).strip() 19 | if isinstance(ret, bytes): 20 | ret = ret.decode() 21 | return ret 22 | 23 | 24 | def get_parameters(model, bias=False): 25 | import torch.nn as nn 26 | modules_skipped = ( 27 | nn.ReLU, 28 | nn.MaxPool2d, 29 | nn.Dropout2d, 30 | nn.Sequential, 31 | torchfcn.models.FCN32s, 32 | torchfcn.models.FCN16s, 33 | torchfcn.models.FCN8s, 34 | ) 35 | for m in model.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | if bias: 38 | yield m.bias 39 | else: 40 | yield m.weight 41 | elif isinstance(m, nn.ConvTranspose2d): 42 | # weight is frozen because it is just a bilinear upsampling 43 | if bias: 44 | assert m.bias is None 45 | elif isinstance(m, modules_skipped): 46 | continue 47 | else: 48 | raise ValueError('Unexpected module: %s' % str(m)) 49 | 50 | 51 | here = osp.dirname(osp.abspath(__file__)) 52 | 53 | 54 | def main(): 55 | parser = argparse.ArgumentParser( 56 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 57 | ) 58 | parser.add_argument('-g', '--gpu', type=int, required=True, help='gpu id') 59 | parser.add_argument('--resume', help='checkpoint path') 60 | # configurations (same configuration as original work) 61 | # https://github.com/shelhamer/fcn.berkeleyvision.org 62 | parser.add_argument( 63 | '--max-iteration', type=int, default=100000, help='max iteration' 64 | ) 65 | parser.add_argument( 66 | '--lr', type=float, default=1.0e-10, help='learning rate', 67 | ) 68 | parser.add_argument( 69 | '--weight-decay', type=float, default=0.0005, help='weight decay', 70 | ) 71 | parser.add_argument( 72 | '--momentum', type=float, default=0.99, help='momentum', 73 | ) 74 | args = parser.parse_args() 75 | 76 | args.model = 'FCN32s' 77 | args.git_hash = git_hash() 78 | 79 | now = datetime.datetime.now() 80 | args.out = osp.join(here, 'logs', now.strftime('%Y%m%d_%H%M%S.%f')) 81 | 82 | os.makedirs(args.out) 83 | with open(osp.join(args.out, 'config.yaml'), 'w') as f: 84 | yaml.safe_dump(args.__dict__, f, default_flow_style=False) 85 | 86 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 87 | cuda = torch.cuda.is_available() 88 | 89 | torch.manual_seed(1337) 90 | if cuda: 91 | torch.cuda.manual_seed(1337) 92 | 93 | # 1. 
dataset 94 | 95 | root = osp.expanduser('~/data/datasets') 96 | kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {} 97 | train_loader = torch.utils.data.DataLoader( 98 | torchfcn.datasets.SBDClassSeg(root, split='train', transform=True), 99 | batch_size=1, shuffle=True, **kwargs) 100 | val_loader = torch.utils.data.DataLoader( 101 | torchfcn.datasets.VOC2011ClassSeg( 102 | root, split='seg11valid', transform=True), 103 | batch_size=1, shuffle=False, **kwargs) 104 | 105 | # 2. model 106 | 107 | model = torchfcn.models.FCN32s(n_class=21) 108 | start_epoch = 0 109 | start_iteration = 0 110 | if args.resume: 111 | checkpoint = torch.load(args.resume) 112 | model.load_state_dict(checkpoint['model_state_dict']) 113 | start_epoch = checkpoint['epoch'] 114 | start_iteration = checkpoint['iteration'] 115 | else: 116 | vgg16 = torchfcn.models.VGG16(pretrained=True) 117 | model.copy_params_from_vgg16(vgg16) 118 | if cuda: 119 | model = model.cuda() 120 | 121 | # 3. optimizer 122 | 123 | optim = torch.optim.SGD( 124 | [ 125 | {'params': get_parameters(model, bias=False)}, 126 | {'params': get_parameters(model, bias=True), 127 | 'lr': args.lr * 2, 'weight_decay': 0}, 128 | ], 129 | lr=args.lr, 130 | momentum=args.momentum, 131 | weight_decay=args.weight_decay) 132 | if args.resume: 133 | optim.load_state_dict(checkpoint['optim_state_dict']) 134 | 135 | trainer = torchfcn.Trainer( 136 | cuda=cuda, 137 | model=model, 138 | optimizer=optim, 139 | train_loader=train_loader, 140 | val_loader=val_loader, 141 | out=args.out, 142 | max_iter=args.max_iteration, 143 | interval_validate=4000, 144 | ) 145 | trainer.epoch = start_epoch 146 | trainer.iteration = start_iteration 147 | trainer.train() 148 | 149 | 150 | if __name__ == '__main__': 151 | main() 152 | -------------------------------------------------------------------------------- /examples/voc/train_fcn8s.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import datetime 5 | import os 6 | import os.path as osp 7 | 8 | import torch 9 | import yaml 10 | 11 | import torchfcn 12 | 13 | from train_fcn32s import get_parameters 14 | from train_fcn32s import git_hash 15 | 16 | 17 | here = osp.dirname(osp.abspath(__file__)) 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser( 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 23 | ) 24 | parser.add_argument('-g', '--gpu', type=int, required=True, help='gpu id') 25 | parser.add_argument('--resume', help='checkpoint path') 26 | # configurations (same configuration as original work) 27 | # https://github.com/shelhamer/fcn.berkeleyvision.org 28 | parser.add_argument( 29 | '--max-iteration', type=int, default=100000, help='max iteration' 30 | ) 31 | parser.add_argument( 32 | '--lr', type=float, default=1.0e-14, help='learning rate', 33 | ) 34 | parser.add_argument( 35 | '--weight-decay', type=float, default=0.0005, help='weight decay', 36 | ) 37 | parser.add_argument( 38 | '--momentum', type=float, default=0.99, help='momentum', 39 | ) 40 | parser.add_argument( 41 | '--pretrained-model', 42 | default=torchfcn.models.FCN16s.download(), 43 | help='pretrained model of FCN16s', 44 | ) 45 | args = parser.parse_args() 46 | 47 | args.model = 'FCN8s' 48 | args.git_hash = git_hash() 49 | 50 | now = datetime.datetime.now() 51 | args.out = osp.join(here, 'logs', now.strftime('%Y%m%d_%H%M%S.%f')) 52 | 53 | os.makedirs(args.out) 54 | with open(osp.join(args.out, 'config.yaml'), 'w') as f: 55 | 
yaml.safe_dump(args.__dict__, f, default_flow_style=False) 56 | 57 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 58 | cuda = torch.cuda.is_available() 59 | 60 | torch.manual_seed(1337) 61 | if cuda: 62 | torch.cuda.manual_seed(1337) 63 | 64 | # 1. dataset 65 | 66 | root = osp.expanduser('~/data/datasets') 67 | kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {} 68 | train_loader = torch.utils.data.DataLoader( 69 | torchfcn.datasets.SBDClassSeg(root, split='train', transform=True), 70 | batch_size=1, shuffle=True, **kwargs) 71 | val_loader = torch.utils.data.DataLoader( 72 | torchfcn.datasets.VOC2011ClassSeg( 73 | root, split='seg11valid', transform=True), 74 | batch_size=1, shuffle=False, **kwargs) 75 | 76 | # 2. model 77 | 78 | model = torchfcn.models.FCN8s(n_class=21) 79 | start_epoch = 0 80 | start_iteration = 0 81 | if args.resume: 82 | checkpoint = torch.load(args.resume) 83 | model.load_state_dict(checkpoint['model_state_dict']) 84 | start_epoch = checkpoint['epoch'] 85 | start_iteration = checkpoint['iteration'] 86 | else: 87 | fcn16s = torchfcn.models.FCN16s() 88 | state_dict = torch.load(args.pretrained_model) 89 | try: 90 | fcn16s.load_state_dict(state_dict) 91 | except RuntimeError: 92 | fcn16s.load_state_dict(state_dict['model_state_dict']) 93 | model.copy_params_from_fcn16s(fcn16s) 94 | if cuda: 95 | model = model.cuda() 96 | 97 | # 3. optimizer 98 | 99 | optim = torch.optim.SGD( 100 | [ 101 | {'params': get_parameters(model, bias=False)}, 102 | {'params': get_parameters(model, bias=True), 103 | 'lr': args.lr * 2, 'weight_decay': 0}, 104 | ], 105 | lr=args.lr, 106 | momentum=args.momentum, 107 | weight_decay=args.weight_decay) 108 | if args.resume: 109 | optim.load_state_dict(checkpoint['optim_state_dict']) 110 | 111 | trainer = torchfcn.Trainer( 112 | cuda=cuda, 113 | model=model, 114 | optimizer=optim, 115 | train_loader=train_loader, 116 | val_loader=val_loader, 117 | out=args.out, 118 | max_iter=args.max_iteration, 119 | interval_validate=4000, 120 | ) 121 | trainer.epoch = start_epoch 122 | trainer.iteration = start_iteration 123 | trainer.train() 124 | 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /examples/voc/train_fcn8s_atonce.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import datetime 5 | import os 6 | import os.path as osp 7 | 8 | import torch 9 | import yaml 10 | 11 | import torchfcn 12 | 13 | from train_fcn32s import get_parameters 14 | from train_fcn32s import git_hash 15 | 16 | 17 | here = osp.dirname(osp.abspath(__file__)) 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser( 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 23 | ) 24 | parser.add_argument('-g', '--gpu', type=int, required=True, help='gpu id') 25 | parser.add_argument('--resume', help='checkpoint path') 26 | # configurations (same configuration as original work) 27 | # https://github.com/shelhamer/fcn.berkeleyvision.org 28 | parser.add_argument( 29 | '--max-iteration', type=int, default=100000, help='max iteration' 30 | ) 31 | parser.add_argument( 32 | '--lr', type=float, default=1.0e-10, help='learning rate', 33 | ) 34 | parser.add_argument( 35 | '--weight-decay', type=float, default=0.0005, help='weight decay', 36 | ) 37 | parser.add_argument( 38 | '--momentum', type=float, default=0.99, help='momentum', 39 | ) 40 | args = parser.parse_args() 41 | 42 | args.model = 
'FCN8sAtOnce' 43 | args.git_hash = git_hash() 44 | 45 | now = datetime.datetime.now() 46 | args.out = osp.join(here, 'logs', now.strftime('%Y%m%d_%H%M%S.%f')) 47 | 48 | os.makedirs(args.out) 49 | with open(osp.join(args.out, 'config.yaml'), 'w') as f: 50 | yaml.safe_dump(args.__dict__, f, default_flow_style=False) 51 | 52 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 53 | cuda = torch.cuda.is_available() 54 | 55 | torch.manual_seed(1337) 56 | if cuda: 57 | torch.cuda.manual_seed(1337) 58 | 59 | # 1. dataset 60 | 61 | root = osp.expanduser('~/data/datasets') 62 | kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {} 63 | train_loader = torch.utils.data.DataLoader( 64 | torchfcn.datasets.SBDClassSeg(root, split='train', transform=True), 65 | batch_size=1, shuffle=True, **kwargs) 66 | val_loader = torch.utils.data.DataLoader( 67 | torchfcn.datasets.VOC2011ClassSeg( 68 | root, split='seg11valid', transform=True), 69 | batch_size=1, shuffle=False, **kwargs) 70 | 71 | # 2. model 72 | 73 | model = torchfcn.models.FCN8sAtOnce(n_class=21) 74 | start_epoch = 0 75 | start_iteration = 0 76 | if args.resume: 77 | checkpoint = torch.load(args.resume) 78 | model.load_state_dict(checkpoint['model_state_dict']) 79 | start_epoch = checkpoint['epoch'] 80 | start_iteration = checkpoint['iteration'] 81 | else: 82 | vgg16 = torchfcn.models.VGG16(pretrained=True) 83 | model.copy_params_from_vgg16(vgg16) 84 | if cuda: 85 | model = model.cuda() 86 | 87 | # 3. optimizer 88 | 89 | optim = torch.optim.SGD( 90 | [ 91 | {'params': get_parameters(model, bias=False)}, 92 | {'params': get_parameters(model, bias=True), 93 | 'lr': args.lr * 2, 'weight_decay': 0}, 94 | ], 95 | lr=args.lr, 96 | momentum=args.momentum, 97 | weight_decay=args.weight_decay) 98 | if args.resume: 99 | optim.load_state_dict(checkpoint['optim_state_dict']) 100 | 101 | trainer = torchfcn.Trainer( 102 | cuda=cuda, 103 | model=model, 104 | optimizer=optim, 105 | train_loader=train_loader, 106 | val_loader=val_loader, 107 | out=args.out, 108 | max_iter=args.max_iteration, 109 | interval_validate=4000, 110 | ) 111 | trainer.epoch = start_epoch 112 | trainer.iteration = start_iteration 113 | trainer.train() 114 | 115 | 116 | if __name__ == '__main__': 117 | main() 118 | -------------------------------------------------------------------------------- /examples/voc/view_log: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os.path as osp 7 | import sys 8 | import time 9 | 10 | import pandas as pd 11 | 12 | 13 | def print_bar(title='', width=80): 14 | if title: 15 | title = ' ' + title + ' ' 16 | length = len(title) 17 | if length % 2 == 0: 18 | length_left = length_right = length // 2 19 | else: 20 | length_left = length // 2 21 | length_right = length - length_left 22 | print('=' * (width // 2 - 1 - length_left) + 23 | title + '=' * (width // 2 - length_right)) 24 | 25 | 26 | def main(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('log_file') 29 | parser.add_argument('--all', action='store_true') 30 | parser.add_argument('-1', '--once', action='store_true') 31 | args = parser.parse_args() 32 | 33 | log_file = args.log_file 34 | 35 | pd.set_option('display.width', 200) 36 | pd.set_option('display.float_format', lambda x: '%.3f' % x) 37 | 38 | while True: 39 | try: 40 | ext = osp.splitext(log_file)[-1] 41 | if ext == '.json': 42 | df = pd.read_json(log_file) 43 | elif ext == '.csv': 
44 | df = pd.read_csv(log_file) 45 | else: 46 | print('Unsupported file extension: {}'.format(log_file)) 47 | sys.exit(1) 48 | df = df.set_index(['epoch', 'iteration']) 49 | 50 | train_cols, valid_cols, else_cols = [], [], [] 51 | for col in df.columns: 52 | if col.startswith('validation/') or col.startswith('valid/'): 53 | valid_cols.append(col) 54 | elif col.startswith('train/'): 55 | train_cols.append(col) 56 | else: 57 | else_cols.append(col) 58 | 59 | if args.all: 60 | print(df.to_string()) 61 | break 62 | 63 | width = len(' '.join(['epoch', 'iteration']) + ' '.join(train_cols + else_cols) + ' ') 64 | 65 | print(chr(27) + "[2J") 66 | 67 | log_dir = osp.dirname(log_file) 68 | param_files = [osp.join(log_dir, f) for f in ['params.yaml', 'config.yaml']] 69 | exists = [osp.exists(f) for f in param_files] 70 | if any(exists): 71 | import yaml 72 | param_file = param_files[exists.index(True)] 73 | data = yaml.load(open(param_file)) 74 | print_bar('params', width=width) 75 | print(yaml.safe_dump(data, default_flow_style=False)) 76 | print_bar('', width=width) 77 | 78 | print('log_file: %s' % log_file) 79 | 80 | if df.empty: 81 | time.sleep(1) 82 | continue 83 | 84 | try: 85 | df_train = df[train_cols + else_cols].dropna(thresh=len(train_cols)) 86 | except: 87 | df_train = df[train_cols + else_cols].dropna() 88 | if not df_train.empty: 89 | print_bar('train', width=width) 90 | print(df_train.tail(n=5)) 91 | print() 92 | 93 | try: 94 | df_valid = df[valid_cols + else_cols].dropna(thresh=len(valid_cols)) 95 | except: 96 | df_valid = df[valid_cols + else_cols].dropna() 97 | if not df_valid.empty: 98 | print_bar('valid', width=width) 99 | print(df_valid.tail(n=3)) 100 | print() 101 | 102 | for col in valid_cols: 103 | if 'loss' in col: 104 | print_bar('min:%s' % col, width=width) 105 | idx = df[col].idxmin() 106 | else: 107 | print_bar('max:%s' % col, width=width) 108 | idx = df[col].idxmax() 109 | try: 110 | print(df.ix[idx][valid_cols + else_cols].dropna(thresh=len(valid_cols))) 111 | except: 112 | print(df.ix[idx][valid_cols + else_cols].dropna()) 113 | 114 | print_bar(width=width) 115 | 116 | if args.once: 117 | break 118 | time.sleep(1) 119 | except KeyboardInterrupt: 120 | break 121 | except IOError as e: 122 | print(chr(27) + "[2J") 123 | print(e) 124 | time.sleep(1) 125 | continue 126 | except Exception as e: 127 | print(e) 128 | break 129 | 130 | 131 | if __name__ == '__main__': 132 | main() 133 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fcn>=6.1.5 2 | numpy 3 | Pillow 4 | pytz 5 | scipy 6 | torch>=0.2.0 7 | torchvision>=0.1.8 8 | tqdm 9 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .eggs,*.egg,build,torchfcn/ext/*/* 3 | ignore = H304,W504 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import distutils.spawn 6 | import shlex 7 | import subprocess 8 | import sys 9 | 10 | from setuptools import find_packages 11 | from setuptools import setup 12 | 13 | 14 | version = '1.9.7' 15 | 16 | 17 | if sys.argv[1] == 'release': 18 | if not distutils.spawn.find_executable('twine'): 19 | print( 20 | 
            'Please install twine:\n\n\tpip install twine\n',
21 |             file=sys.stderr,
22 |         )
23 |         sys.exit(1)
24 | 
25 |     commands = [
26 |         'git pull origin main',
27 |         'git tag v{:s}'.format(version),
28 |         'git push origin main --tags',
29 |         'python setup.py sdist',
30 |         'twine upload dist/torchfcn-{:s}.tar.gz'.format(version),
31 |     ]
32 |     for cmd in commands:
33 |         print('+ {}'.format(cmd))
34 |         subprocess.check_call(shlex.split(cmd))
35 |     sys.exit(0)
36 | 
37 | 
38 | def get_long_description():
39 |     with open('README.md') as f:
40 |         long_description = f.read()
41 | 
42 |     try:
43 |         import github2pypi
44 | 
45 |         return github2pypi.replace_url(
46 |             slug='wkentaro/pytorch-fcn', content=long_description
47 |         )
48 |     except Exception:
49 |         return long_description
50 | 
51 | 
52 | def get_install_requires():
53 |     with open('requirements.txt') as f:
54 |         return [req.strip() for req in f]
55 | 
56 | 
57 | setup(
58 |     name='torchfcn',
59 |     version=version,
60 |     packages=find_packages(exclude=['github2pypi']),
61 |     install_requires=get_install_requires(),
62 |     description='PyTorch Implementation of Fully Convolutional Networks.',
63 |     long_description=get_long_description(),
64 |     long_description_content_type='text/markdown',
65 |     package_data={'torchfcn': ['ext/*']},
66 |     include_package_data=True,
67 |     author='Kentaro Wada',
68 |     author_email='www.kentaro.wada@gmail.com',
69 |     license='MIT',
70 |     url='https://github.com/wkentaro/pytorch-fcn',
71 |     classifiers=[
72 |         'Development Status :: 5 - Production/Stable',
73 |         'Intended Audience :: Developers',
74 |         'Natural Language :: English',
75 |         'License :: OSI Approved :: MIT License',
76 |         'Programming Language :: Python',
77 |         'Programming Language :: Python :: 2.7',
78 |         'Programming Language :: Python :: 3.5',
79 |         'Programming Language :: Python :: 3.6',
80 |         'Programming Language :: Python :: 3.7',
81 |         'Programming Language :: Python :: Implementation :: CPython',
82 |     ],
83 | )
84 | 
--------------------------------------------------------------------------------
/tests/models_tests/test_fcn32s.py:
--------------------------------------------------------------------------------
1 | # FIXME: Import order causes error:
2 | #     ImportError: dlopen: cannot load any more object with static TLS
3 | # https://github.com/pytorch/pytorch/issues/2083
4 | import torch
5 | 
6 | import numpy as np
7 | import skimage.data
8 | 
9 | from torchfcn.models.fcn32s import get_upsampling_weight
10 | 
11 | 
12 | def test_get_upsampling_weight():
13 |     src = skimage.data.coffee()
14 |     x = src.transpose(2, 0, 1)
15 |     x = x[np.newaxis, :, :, :]
16 |     x = torch.from_numpy(x).float()
17 |     x = torch.autograd.Variable(x)
18 | 
19 |     in_channels = 3
20 |     out_channels = 3
21 |     kernel_size = 4
22 | 
23 |     m = torch.nn.ConvTranspose2d(
24 |         in_channels, out_channels, kernel_size, stride=2, bias=False)
25 |     m.weight.data = get_upsampling_weight(
26 |         in_channels, out_channels, kernel_size)
27 | 
28 |     y = m(x)
29 | 
30 |     y = y.data.numpy()
31 |     y = y[0]
32 |     y = y.transpose(1, 2, 0)
33 |     dst = y.astype(np.uint8)
34 | 
35 |     assert abs(src.shape[0] * 2 - dst.shape[0]) <= 2
36 |     assert abs(src.shape[1] * 2 - dst.shape[1]) <= 2
37 | 
38 |     return src, dst
39 | 
40 | 
41 | if __name__ == '__main__':
42 |     import matplotlib.pyplot as plt
43 | 
44 |     src, dst = test_get_upsampling_weight()
45 |     plt.subplot(121)
46 |     plt.imshow(src)
47 |     plt.title('x1: {}'.format(src.shape))
48 |     plt.subplot(122)
49 |     plt.imshow(dst)
50 |     plt.title('x2: {}'.format(dst.shape))
51 |     plt.show()
52 | 
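# For reference, a sketch of how the fixed bilinear kernel used above is
# built (see `get_upsampling_weight` in torchfcn/models/fcn32s.py): with
#   factor = (size + 1) // 2
#   center = factor - 1 if size is odd else factor - 0.5
# each entry is
#   weight[i, j] = (1 - |i - center| / factor) * (1 - |j - center| / factor)
# so a stride-2 ConvTranspose2d with this kernel performs 2x bilinear
# upsampling, which is what the shape asserts above check.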
-------------------------------------------------------------------------------- /torchfcn/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from . import datasets 3 | from . import models 4 | from . import utils 5 | from .trainer import Trainer 6 | -------------------------------------------------------------------------------- /torchfcn/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .voc import SBDClassSeg # NOQA 2 | from .voc import VOC2011ClassSeg # NOQA 3 | from .voc import VOC2012ClassSeg # NOQA 4 | -------------------------------------------------------------------------------- /torchfcn/datasets/voc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import collections 4 | import os.path as osp 5 | 6 | import numpy as np 7 | import PIL.Image 8 | import scipy.io 9 | import torch 10 | from torch.utils import data 11 | 12 | 13 | class VOCClassSegBase(data.Dataset): 14 | 15 | class_names = np.array([ 16 | 'background', 17 | 'aeroplane', 18 | 'bicycle', 19 | 'bird', 20 | 'boat', 21 | 'bottle', 22 | 'bus', 23 | 'car', 24 | 'cat', 25 | 'chair', 26 | 'cow', 27 | 'diningtable', 28 | 'dog', 29 | 'horse', 30 | 'motorbike', 31 | 'person', 32 | 'potted plant', 33 | 'sheep', 34 | 'sofa', 35 | 'train', 36 | 'tv/monitor', 37 | ]) 38 | mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 39 | 40 | def __init__(self, root, split='train', transform=False): 41 | self.root = root 42 | self.split = split 43 | self._transform = transform 44 | 45 | # VOC2011 and others are subset of VOC2012 46 | dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012') 47 | self.files = collections.defaultdict(list) 48 | for split in ['train', 'val']: 49 | imgsets_file = osp.join( 50 | dataset_dir, 'ImageSets/Segmentation/%s.txt' % split) 51 | for did in open(imgsets_file): 52 | did = did.strip() 53 | img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did) 54 | lbl_file = osp.join( 55 | dataset_dir, 'SegmentationClass/%s.png' % did) 56 | self.files[split].append({ 57 | 'img': img_file, 58 | 'lbl': lbl_file, 59 | }) 60 | 61 | def __len__(self): 62 | return len(self.files[self.split]) 63 | 64 | def __getitem__(self, index): 65 | data_file = self.files[self.split][index] 66 | # load image 67 | img_file = data_file['img'] 68 | img = PIL.Image.open(img_file) 69 | img = np.array(img, dtype=np.uint8) 70 | # load label 71 | lbl_file = data_file['lbl'] 72 | lbl = PIL.Image.open(lbl_file) 73 | lbl = np.array(lbl, dtype=np.int32) 74 | lbl[lbl == 255] = -1 75 | if self._transform: 76 | return self.transform(img, lbl) 77 | else: 78 | return img, lbl 79 | 80 | def transform(self, img, lbl): 81 | img = img[:, :, ::-1] # RGB -> BGR 82 | img = img.astype(np.float64) 83 | img -= self.mean_bgr 84 | img = img.transpose(2, 0, 1) 85 | img = torch.from_numpy(img).float() 86 | lbl = torch.from_numpy(lbl).long() 87 | return img, lbl 88 | 89 | def untransform(self, img, lbl): 90 | img = img.numpy() 91 | img = img.transpose(1, 2, 0) 92 | img += self.mean_bgr 93 | img = img.astype(np.uint8) 94 | img = img[:, :, ::-1] 95 | lbl = lbl.numpy() 96 | return img, lbl 97 | 98 | 99 | class VOC2011ClassSeg(VOCClassSegBase): 100 | 101 | def __init__(self, root, split='train', transform=False): 102 | super(VOC2011ClassSeg, self).__init__( 103 | root, split=split, transform=transform) 104 | pkg_root = osp.join(osp.dirname(osp.realpath(__file__)), 
'..') 105 | imgsets_file = osp.join( 106 | pkg_root, 'ext/fcn.berkeleyvision.org', 107 | 'data/pascal/seg11valid.txt') 108 | dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012') 109 | for did in open(imgsets_file): 110 | did = did.strip() 111 | img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did) 112 | lbl_file = osp.join(dataset_dir, 'SegmentationClass/%s.png' % did) 113 | self.files['seg11valid'].append({'img': img_file, 'lbl': lbl_file}) 114 | 115 | 116 | class VOC2012ClassSeg(VOCClassSegBase): 117 | 118 | url = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar' # NOQA 119 | 120 | def __init__(self, root, split='train', transform=False): 121 | super(VOC2012ClassSeg, self).__init__( 122 | root, split=split, transform=transform) 123 | 124 | 125 | class SBDClassSeg(VOCClassSegBase): 126 | 127 | # XXX: It must be renamed to benchmark.tar to be extracted. 128 | url = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz' # NOQA 129 | 130 | def __init__(self, root, split='train', transform=False): 131 | self.root = root 132 | self.split = split 133 | self._transform = transform 134 | 135 | dataset_dir = osp.join(self.root, 'VOC/benchmark_RELEASE/dataset') 136 | self.files = collections.defaultdict(list) 137 | for split in ['train', 'val']: 138 | imgsets_file = osp.join(dataset_dir, '%s.txt' % split) 139 | for did in open(imgsets_file): 140 | did = did.strip() 141 | img_file = osp.join(dataset_dir, 'img/%s.jpg' % did) 142 | lbl_file = osp.join(dataset_dir, 'cls/%s.mat' % did) 143 | self.files[split].append({ 144 | 'img': img_file, 145 | 'lbl': lbl_file, 146 | }) 147 | 148 | def __getitem__(self, index): 149 | data_file = self.files[self.split][index] 150 | # load image 151 | img_file = data_file['img'] 152 | img = PIL.Image.open(img_file) 153 | img = np.array(img, dtype=np.uint8) 154 | # load label 155 | lbl_file = data_file['lbl'] 156 | mat = scipy.io.loadmat(lbl_file) 157 | lbl = mat['GTcls'][0]['Segmentation'][0].astype(np.int32) 158 | lbl[lbl == 255] = -1 159 | if self._transform: 160 | return self.transform(img, lbl) 161 | else: 162 | return img, lbl 163 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/README.md: -------------------------------------------------------------------------------- 1 | # Fully Convolutional Networks for Semantic Segmentation 2 | 3 | This is the reference implementation of the models and code for the fully convolutional networks (FCNs) in the [PAMI FCN](https://arxiv.org/abs/1605.06211) and [CVPR FCN](http://www.cv-foundation.org/openaccess/content_cvpr_2015/html/Long_Fully_Convolutional_Networks_2015_CVPR_paper.html) papers: 4 | 5 | Fully Convolutional Models for Semantic Segmentation 6 | Evan Shelhamer*, Jonathan Long*, Trevor Darrell 7 | PAMI 2016 8 | arXiv:1605.06211 9 | 10 | Fully Convolutional Models for Semantic Segmentation 11 | Jonathan Long*, Evan Shelhamer*, Trevor Darrell 12 | CVPR 2015 13 | arXiv:1411.4038 14 | 15 | **Note that this is a work in progress and the final, reference version is coming soon.** 16 | Please ask Caffe and FCN usage questions on the [caffe-users mailing list](https://groups.google.com/forum/#!forum/caffe-users). 17 | 18 | Refer to [these slides](https://docs.google.com/presentation/d/10XodYojlW-1iurpUsMoAZknQMS36p7lVIfFZ-Z7V_aY/edit?usp=sharing) for a summary of the approach. 19 | 20 | These models are compatible with `BVLC/caffe:master`. 
21 | Compatibility has held since `master@8c66fa5` with the merge of PRs #3613 and #3570. 22 | The code and models here are available under the same license as Caffe (BSD-2) and the Caffe-bundled models (that is, unrestricted use; see the [BVLC model license](http://caffe.berkeleyvision.org/model_zoo.html#bvlc-model-license)). 23 | 24 | **PASCAL VOC models**: trained online with high momentum for a ~5 point boost in mean intersection-over-union over the original models. 25 | These models are trained using extra data from [Hariharan et al.](http://www.cs.berkeley.edu/~bharath2/codes/SBD/download.html), but excluding SBD val. 26 | FCN-32s is fine-tuned from the [ILSVRC-trained VGG-16 model](https://github.com/BVLC/caffe/wiki/Model-Zoo#models-used-by-the-vgg-team-in-ilsvrc-2014), and the finer strides are then fine-tuned in turn. 27 | The "at-once" FCN-8s is fine-tuned from VGG-16 all-at-once by scaling the skip connections to better condition optimization. 28 | 29 | * [FCN-32s PASCAL](voc-fcn32s): single stream, 32 pixel prediction stride net, scoring 63.6 mIU on seg11valid 30 | * [FCN-16s PASCAL](voc-fcn16s): two stream, 16 pixel prediction stride net, scoring 65.0 mIU on seg11valid 31 | * [FCN-8s PASCAL](voc-fcn8s): three stream, 8 pixel prediction stride net, scoring 65.5 mIU on seg11valid and 67.2 mIU on seg12test 32 | * [FCN-8s PASCAL at-once](voc-fcn8s-atonce): all-at-once, three stream, 8 pixel prediction stride net, scoring 65.4 mIU on seg11valid 33 | 34 | [FCN-AlexNet PASCAL](voc-fcn-alexnet): AlexNet (CaffeNet) architecture, single stream, 32 pixel prediction stride net, scoring 48.0 mIU on seg11valid. 35 | Unlike the FCN-32/16/8s models, this network is trained with gradient accumulation, normalized loss, and standard momentum. 36 | (Note: when both FCN-32s/FCN-VGG16 and FCN-AlexNet are trained in this same way FCN-VGG16 is far better; see Table 1 of the paper.) 37 | 38 | To reproduce the validation scores, use the [seg11valid](https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/data/pascal/seg11valid.txt) split defined by the paper in footnote 7. Since SBD train and PASCAL VOC 2011 segval intersect, we only evaluate on the non-intersecting set for validation purposes. 39 | 40 | **NYUDv2 models**: trained online with high momentum on color, depth, and HHA features (from Gupta et al. https://github.com/s-gupta/rcnn-depth). 41 | These models demonstrate FCNs for multi-modal input. 42 | 43 | * [FCN-32s NYUDv2 Color](nyud-fcn32s-color): single stream, 32 pixel prediction stride net on color/BGR input 44 | * [FCN-32s NYUDv2 HHA](nyud-fcn32s-hha): single stream, 32 pixel prediction stride net on HHA input 45 | * [FCN-32s NYUDv2 Early Color-Depth](nyud-fcn32s-color-d): single stream, 32 pixel prediction stride net on early fusion of color and (log) depth for 4-channel input 46 | * [FCN-32s NYUDv2 Late Color-HHA](nyud-fcn32s-color-hha): single stream, 32 pixel prediction stride net by late fusion of FCN-32s NYUDv2 Color and FCN-32s NYUDv2 HHA 47 | 48 | **SIFT Flow models**: trained online with high momentum for joint semantic class and geometric class segmentation. 49 | These models demonstrate FCNs for multi-task output. 
50 | 51 | * [FCN-32s SIFT Flow](siftflow-fcn32s): single stream, 32 pixel prediction stride net 52 | * [FCN-16s SIFT Flow](siftflow-fcn16s): two stream, 16 pixel prediction stride net 53 | * [FCN-8s SIFT Flow](siftflow-fcn8s): three stream, 8 pixel prediction stride net 54 | 55 | *Note*: in this release, the evaluation of the semantic classes is not quite right at the moment due to an issue with missing classes. 56 | This will be corrected soon. 57 | The evaluation of the geometric classes is fine. 58 | 59 | **PASCAL-Context models**: trained online with high momentum on an object and scene labeling of PASCAL VOC. 60 | 61 | * [FCN-32s PASCAL-Context](pascalcontext-fcn32s): single stream, 32 pixel prediction stride net 62 | * [FCN-16s PASCAL-Context](pascalcontext-fcn16s): two stream, 16 pixel prediction stride net 63 | * [FCN-8s PASCAL-Context](pascalcontext-fcn8s): three stream, 8 pixel prediction stride net 64 | 65 | ## Frequently Asked Questions 66 | 67 | **Is learning the interpolation necessary?** In our original experiments the interpolation layers were initialized to bilinear kernels and then learned. 68 | In follow-up experiments, and this reference implementation, the bilinear kernels are fixed. 69 | There is no significant difference in accuracy in our experiments, and fixing these parameters gives a slight speed-up. 70 | Note that in our networks there is only one interpolation kernel per output class, and results may differ for higher-dimensional and non-linear interpolation, for which learning may help further. 71 | 72 | **Why pad the input?**: The 100 pixel input padding guarantees that the network output can be aligned to the input for any input size in the given datasets, for instance PASCAL VOC. 73 | The alignment is handled automatically by net specification and the crop layer. 74 | It is possible, though less convenient, to calculate the exact offsets necessary and do away with this amount of padding. 75 | 76 | **Why are all the outputs/gradients/parameters zero?**: This is almost universally due to not initializing the weights as needed. 77 | To reproduce our FCN training, or train your own FCNs, it is crucial to transplant the weights from the corresponding ILSVRC net such as VGG16. 78 | The included `surgery.transplant()` method can help with this. 79 | 80 | **What about FCN-GoogLeNet?**: a reference FCN-GoogLeNet for PASCAL VOC is coming soon. 81 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/nyud/README.md: -------------------------------------------------------------------------------- 1 | # NYUDv2: NYU Depth Dataset V2 2 | 3 | NYUDv2 has a curated semantic segmentation challenge with RGB-D inputs and full scene labels of objects and surfaces. 4 | While there are many labels, we follow the 40 class task defined by 5 | 6 | > Perceptual Organization and Recognition of Indoor Scenes from RGB-D Images. 7 | Saurabh Gupta, Pablo Arbelaez, and Jitendra Malik. 8 | CVPR 2013 9 | 10 | at http://www.cs.berkeley.edu/~sgupta/pdf/GuptaArbelaezMalikCVPR13.pdf . 11 | To reproduce the results of our paper, you must make use of the data from Gupta et al. at http://people.eecs.berkeley.edu/~sgupta/cvpr13/data.tgz . 12 | 13 | Refer to `classes.txt` for the listing of classes in model output order. 14 | Refer to `../nyud_layers.py` for the Python data layer for this dataset. 15 | 16 | See the dataset site: http://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html.
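
For orientation, here is a minimal sketch of reading one of these label maps into the convention used by the data layer (classes 0-39, void mapped to the 255 ignore label); the `.mat` path and image index are illustrative, following the `load_label()` logic in `../nyud_layers.py`:

```python
import numpy as np
import scipy.io

# illustrative path; see ../nyud_layers.py for the actual data layer
mat = scipy.io.loadmat('nyud/segmentation/img_5001.mat')
label = mat['segmentation'].astype(np.uint8)  # raw labels: 0 is void, 1-40 are classes
label -= 1  # uint8 wrap-around: classes become 0-39, void 0 becomes 255
```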
17 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/nyud/classes.txt: -------------------------------------------------------------------------------- 1 | wall 2 | floor 3 | cabinet 4 | bed 5 | chair 6 | sofa 7 | table 8 | door 9 | window 10 | bookshelf 11 | picture 12 | counter 13 | blinds 14 | desk 15 | shelves 16 | curtain 17 | dresser 18 | pillow 19 | mirror 20 | floor mat 21 | clothes 22 | ceiling 23 | books 24 | refridgerator 25 | television 26 | paper 27 | towel 28 | shower curtain 29 | box 30 | whiteboard 31 | person 32 | night stand 33 | toilet 34 | sink 35 | lamp 36 | bathtub 37 | bag 38 | otherstructure 39 | otherfurniture 40 | otherprop 41 | 42 | and 0 is void (and converted to 255 by the NYUDSegDataLayer) 43 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/nyud/test.txt: -------------------------------------------------------------------------------- 1 | 5133 2 | 6002 3 | 6314 4 | 6298 5 | 5193 6 | 5434 7 | 5312 8 | 6022 9 | 5669 10 | 6082 11 | 5711 12 | 5780 13 | 6422 14 | 5840 15 | 5760 16 | 5532 17 | 5009 18 | 5430 19 | 6217 20 | 6155 21 | 6364 22 | 6151 23 | 5850 24 | 6032 25 | 6023 26 | 5945 27 | 5396 28 | 6407 29 | 6084 30 | 5513 31 | 5001 32 | 5283 33 | 5153 34 | 5127 35 | 5046 36 | 5355 37 | 5469 38 | 5679 39 | 6081 40 | 6083 41 | 6118 42 | 5813 43 | 5677 44 | 5209 45 | 5091 46 | 6338 47 | 5465 48 | 5171 49 | 5801 50 | 6329 51 | 5970 52 | 6255 53 | 5870 54 | 5531 55 | 5523 56 | 5580 57 | 6096 58 | 6413 59 | 5413 60 | 5056 61 | 5781 62 | 6038 63 | 5187 64 | 5059 65 | 6164 66 | 5018 67 | 6039 68 | 6303 69 | 6144 70 | 6088 71 | 5972 72 | 5733 73 | 5210 74 | 6340 75 | 5861 76 | 6414 77 | 5550 78 | 5992 79 | 5842 80 | 5035 81 | 5694 82 | 5183 83 | 5699 84 | 6080 85 | 5621 86 | 5810 87 | 5770 88 | 5636 89 | 5212 90 | 5014 91 | 5975 92 | 6145 93 | 6202 94 | 5198 95 | 5168 96 | 6277 97 | 6103 98 | 5697 99 | 6332 100 | 5189 101 | 5718 102 | 6384 103 | 5194 104 | 5583 105 | 5079 106 | 5776 107 | 5593 108 | 5634 109 | 5871 110 | 5566 111 | 5280 112 | 6107 113 | 6432 114 | 5388 115 | 5569 116 | 5119 117 | 6257 118 | 5087 119 | 5351 120 | 5837 121 | 5663 122 | 6262 123 | 5539 124 | 6291 125 | 5335 126 | 5180 127 | 5633 128 | 5579 129 | 5858 130 | 6443 131 | 5764 132 | 5713 133 | 6147 134 | 5015 135 | 5384 136 | 5326 137 | 5352 138 | 6117 139 | 5357 140 | 6347 141 | 5511 142 | 5016 143 | 6010 144 | 5002 145 | 5132 146 | 5533 147 | 6153 148 | 5220 149 | 5078 150 | 5030 151 | 6349 152 | 6400 153 | 5174 154 | 5334 155 | 6335 156 | 5476 157 | 6365 158 | 6410 159 | 5551 160 | 5768 161 | 5717 162 | 5785 163 | 5076 164 | 5330 165 | 5126 166 | 6098 167 | 5386 168 | 6058 169 | 5086 170 | 6330 171 | 5570 172 | 6109 173 | 5477 174 | 5620 175 | 5389 176 | 5917 177 | 6146 178 | 5725 179 | 5038 180 | 5555 181 | 5057 182 | 5521 183 | 5516 184 | 5862 185 | 6162 186 | 5775 187 | 5771 188 | 5766 189 | 5299 190 | 6287 191 | 6248 192 | 5846 193 | 6135 194 | 6228 195 | 6356 196 | 5571 197 | 5537 198 | 5190 199 | 6207 200 | 5606 201 | 5567 202 | 5411 203 | 5965 204 | 5926 205 | 6424 206 | 5686 207 | 5039 208 | 5710 209 | 5296 210 | 5727 211 | 6247 212 | 6180 213 | 6127 214 | 5843 215 | 5967 216 | 6285 217 | 6152 218 | 6369 219 | 5927 220 | 6388 221 | 5772 222 | 5786 223 | 6396 224 | 5517 225 | 6104 226 | 6124 227 | 6399 228 | 5031 229 | 6256 230 | 6353 231 | 5645 232 | 5033 233 | 5211 234 | 5515 235 | 5617 236 | 5690 237 | 5524 238 | 5538 239 | 5763 240 | 
5271 241 | 5298 242 | 5782 243 | 5814 244 | 6119 245 | 5591 246 | 6048 247 | 5959 248 | 6254 249 | 5134 250 | 6130 251 | 6304 252 | 5708 253 | 5118 254 | 5839 255 | 5562 256 | 6279 257 | 6091 258 | 5062 259 | 6206 260 | 5359 261 | 5994 262 | 6249 263 | 5202 264 | 5196 265 | 5549 266 | 5397 267 | 5559 268 | 5803 269 | 5463 270 | 6089 271 | 6102 272 | 5358 273 | 5557 274 | 5316 275 | 5769 276 | 5470 277 | 5995 278 | 6079 279 | 5395 280 | 6170 281 | 5800 282 | 5432 283 | 6165 284 | 5637 285 | 6108 286 | 6171 287 | 6175 288 | 5037 289 | 5568 290 | 6278 291 | 6401 292 | 6205 293 | 5446 294 | 5508 295 | 5976 296 | 5851 297 | 6148 298 | 5184 299 | 6411 300 | 6409 301 | 5519 302 | 6090 303 | 6446 304 | 6021 305 | 5869 306 | 6294 307 | 6423 308 | 6354 309 | 5777 310 | 6211 311 | 6193 312 | 5907 313 | 5656 314 | 6204 315 | 5328 316 | 6305 317 | 5061 318 | 6078 319 | 6302 320 | 6004 321 | 6129 322 | 6408 323 | 6128 324 | 5761 325 | 6433 326 | 6156 327 | 6126 328 | 6093 329 | 6136 330 | 5155 331 | 5362 332 | 5520 333 | 5195 334 | 6220 335 | 6355 336 | 6034 337 | 5471 338 | 5250 339 | 5784 340 | 6258 341 | 6131 342 | 6261 343 | 6337 344 | 5063 345 | 5186 346 | 5823 347 | 5207 348 | 5650 349 | 5188 350 | 5822 351 | 5041 352 | 6260 353 | 5208 354 | 5734 355 | 6412 356 | 6297 357 | 6441 358 | 5612 359 | 5778 360 | 5838 361 | 5835 362 | 6216 363 | 5435 364 | 5961 365 | 5518 366 | 5509 367 | 5689 368 | 6442 369 | 6049 370 | 5021 371 | 6449 372 | 5688 373 | 5329 374 | 6052 375 | 5857 376 | 5934 377 | 5592 378 | 5767 379 | 5638 380 | 5431 381 | 5327 382 | 5526 383 | 6053 384 | 5973 385 | 5706 386 | 5363 387 | 5445 388 | 5933 389 | 5036 390 | 6250 391 | 5284 392 | 5137 393 | 6275 394 | 6295 395 | 5671 396 | 6391 397 | 6444 398 | 5332 399 | 6234 400 | 5556 401 | 6210 402 | 5433 403 | 6208 404 | 5657 405 | 5833 406 | 5028 407 | 5759 408 | 5560 409 | 5385 410 | 5090 411 | 6395 412 | 5297 413 | 5726 414 | 6286 415 | 5361 416 | 5285 417 | 6331 418 | 5301 419 | 5960 420 | 5029 421 | 6308 422 | 6336 423 | 5946 424 | 6288 425 | 6149 426 | 5201 427 | 5928 428 | 5664 429 | 5279 430 | 5693 431 | 5302 432 | 5154 433 | 5117 434 | 6386 435 | 6150 436 | 6265 437 | 5783 438 | 5773 439 | 6166 440 | 6398 441 | 6209 442 | 6167 443 | 5282 444 | 5522 445 | 6447 446 | 5977 447 | 5473 448 | 5672 449 | 6307 450 | 6339 451 | 6092 452 | 5129 453 | 5821 454 | 5670 455 | 5651 456 | 6235 457 | 5658 458 | 5475 459 | 5364 460 | 6306 461 | 5441 462 | 5317 463 | 5448 464 | 5191 465 | 5510 466 | 5273 467 | 5558 468 | 6290 469 | 6075 470 | 6181 471 | 5845 472 | 5860 473 | 5728 474 | 5185 475 | 5182 476 | 5356 477 | 6100 478 | 5932 479 | 6003 480 | 5676 481 | 5802 482 | 5762 483 | 6368 484 | 6101 485 | 5993 486 | 6445 487 | 6229 488 | 5712 489 | 5464 490 | 5199 491 | 6212 492 | 6233 493 | 6385 494 | 5744 495 | 5687 496 | 6421 497 | 5947 498 | 5962 499 | 5088 500 | 5315 501 | 5594 502 | 6094 503 | 6203 504 | 5387 505 | 5681 506 | 6389 507 | 6431 508 | 6194 509 | 5698 510 | 5060 511 | 6430 512 | 6280 513 | 6263 514 | 5765 515 | 5561 516 | 5971 517 | 5032 518 | 5779 519 | 5603 520 | 6095 521 | 5724 522 | 5604 523 | 5192 524 | 5906 525 | 6057 526 | 5042 527 | 5197 528 | 5311 529 | 6011 530 | 5181 531 | 5774 532 | 5221 533 | 5525 534 | 5732 535 | 6012 536 | 5804 537 | 5613 538 | 6174 539 | 5300 540 | 5462 541 | 5167 542 | 5414 543 | 6157 544 | 5680 545 | 6390 546 | 6123 547 | 6077 548 | 5991 549 | 6195 550 | 6001 551 | 5619 552 | 5707 553 | 6201 554 | 5472 555 | 5077 556 | 6219 557 | 6218 558 | 5089 559 | 5084 560 | 5325 561 | 5047 562 | 5281 563 | 
5605 564 | 6387 565 | 5668 566 | 5333 567 | 5731 568 | 5564 569 | 5644 570 | 5709 571 | 6179 572 | 5310 573 | 5678 574 | 6397 575 | 6448 576 | 5442 577 | 5264 578 | 5859 579 | 5331 580 | 6276 581 | 6394 582 | 6192 583 | 5974 584 | 5173 585 | 6176 586 | 5125 587 | 6154 588 | 5172 589 | 6230 590 | 6125 591 | 5175 592 | 5918 593 | 6299 594 | 5512 595 | 5743 596 | 6315 597 | 6184 598 | 5200 599 | 6106 600 | 6259 601 | 5390 602 | 5565 603 | 5169 604 | 5443 605 | 5222 606 | 5017 607 | 5176 608 | 5043 609 | 6264 610 | 6163 611 | 6033 612 | 5834 613 | 6292 614 | 5673 615 | 6227 616 | 5447 617 | 5787 618 | 5444 619 | 5607 620 | 5360 621 | 6158 622 | 5412 623 | 5581 624 | 5582 625 | 5635 626 | 5466 627 | 6226 628 | 5844 629 | 5272 630 | 5034 631 | 5811 632 | 6293 633 | 5836 634 | 5852 635 | 6348 636 | 5908 637 | 6076 638 | 5040 639 | 6196 640 | 5919 641 | 6099 642 | 5128 643 | 6183 644 | 5474 645 | 5085 646 | 6289 647 | 5563 648 | 5841 649 | 6182 650 | 5812 651 | 5131 652 | 5935 653 | 5966 654 | 5618 -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/nyud/train.txt: -------------------------------------------------------------------------------- 1 | 5449 2 | 6140 3 | 5902 4 | 5543 5 | 6392 6 | 5425 7 | 5121 8 | 5506 9 | 5696 10 | 6239 11 | 6143 12 | 5485 13 | 5990 14 | 5322 15 | 6138 16 | 5986 17 | 5756 18 | 5323 19 | 5158 20 | 5921 21 | 5855 22 | 5478 23 | 5898 24 | 5415 25 | 6054 26 | 5161 27 | 5318 28 | 5218 29 | 5460 30 | 6056 31 | 6313 32 | 5595 33 | 5256 34 | 5353 35 | 5044 36 | 5177 37 | 6029 38 | 5980 39 | 5493 40 | 5528 41 | 5904 42 | 5895 43 | 5881 44 | 5275 45 | 5829 46 | 5426 47 | 6334 48 | 5548 49 | 5988 50 | 5714 51 | 5254 52 | 5309 53 | 5253 54 | 5255 55 | 5983 56 | 5752 57 | 5005 58 | 6240 59 | 5546 60 | 5695 61 | 5684 62 | 5751 63 | 6274 64 | 5882 65 | 5730 66 | 5495 67 | 5489 68 | 5749 69 | 6244 70 | 5599 71 | 5503 72 | 5319 73 | 5418 74 | 5454 75 | 5937 76 | 5416 77 | 5989 78 | 5505 79 | 6352 80 | 6237 81 | 6139 82 | 5901 83 | 5421 84 | 5498 85 | 5602 86 | 5083 87 | 5944 88 | 5456 89 | 6122 90 | 6333 91 | 5417 92 | 5981 93 | 5165 94 | 6417 95 | 5758 96 | 5527 97 | 5082 98 | 5805 99 | 5308 100 | 5828 101 | 5120 102 | 5214 103 | 5530 104 | 6026 105 | 5452 106 | 5008 107 | 5251 108 | 6047 109 | 6238 110 | 6008 111 | 5925 112 | 5873 113 | 6366 114 | 5156 115 | 5875 116 | 6311 117 | 6224 118 | 6169 119 | 5922 120 | 5877 121 | 5615 122 | 5896 123 | 5715 124 | 5890 125 | 6141 126 | 5179 127 | 5215 128 | 5685 129 | 6246 130 | 5641 131 | 5058 132 | 5807 133 | 5122 134 | 5423 135 | 5716 136 | 5652 137 | 5262 138 | 5978 139 | 5429 140 | 5542 141 | 5598 142 | 5984 143 | 5354 144 | 5261 145 | 6044 146 | 5003 147 | 5888 148 | 5422 149 | 5124 150 | 5219 151 | 6009 152 | 6087 153 | 5892 154 | 6168 155 | 5616 156 | 5754 157 | 5547 158 | 5393 159 | 5889 160 | 5750 161 | 5963 162 | 5500 163 | 5004 164 | 5303 165 | 6269 166 | 6243 167 | 5885 168 | 5019 169 | 5757 170 | 6267 171 | 5809 172 | 5321 173 | 5529 174 | 5643 175 | 5748 176 | 5501 177 | 6137 178 | 5213 179 | 5259 180 | 5596 181 | 5745 182 | 5653 183 | 6418 184 | 5507 185 | 5136 186 | 5453 187 | 6367 188 | 5544 189 | 6046 190 | 6271 191 | 5252 192 | 5488 193 | 5480 194 | 5080 195 | 5504 196 | 5274 197 | 5578 198 | 5920 199 | 5654 200 | 5924 201 | 5260 202 | 5394 203 | 6041 204 | 5263 205 | 6223 206 | 5642 207 | 6121 208 | 5497 209 | 5939 210 | 5491 211 | 5825 212 | 5753 213 | 5320 214 | 5487 215 | 6042 216 | 6270 217 | 5940 218 | 5157 219 | 5479 220 | 5496 221 | 5639 222 | 
5392 223 | 6177 224 | 5614 225 | 5451 226 | 6312 227 | 6199 228 | 5667 229 | 5666 230 | 6198 231 | 5006 232 | 5427 233 | 5887 234 | 5755 235 | 6200 236 | 5461 237 | 6120 238 | 5982 239 | 6416 240 | 5277 241 | 5884 242 | 6142 243 | 6268 244 | 5880 245 | 6266 246 | 5166 247 | 5258 248 | 5420 249 | 5490 250 | 5135 251 | 5655 252 | 5391 253 | 5682 254 | 5853 255 | 5905 256 | 6045 257 | 5576 258 | 5827 259 | 5492 260 | 5943 261 | 5574 262 | 5307 263 | 5428 264 | 5874 265 | 6006 266 | 5458 267 | 5883 268 | 6030 269 | 5808 270 | 5964 271 | 5305 272 | 5159 273 | 5540 274 | 6178 275 | 6024 276 | 5484 277 | 5832 278 | 6031 279 | 5459 280 | 6028 281 | 5729 282 | 5601 283 | 6415 284 | 5483 285 | 5324 286 | 5894 287 | 5830 288 | 6025 289 | 5854 290 | 5164 291 | 6350 292 | 5903 293 | 6296 294 | 5600 295 | 5486 296 | 5007 297 | 6055 298 | 5747 299 | 5872 300 | 5856 301 | 5482 302 | 5424 303 | 5987 304 | 6222 305 | 5597 306 | 5876 307 | 5824 308 | 5178 309 | 6085 310 | 5979 311 | 6197 312 | 5985 313 | 5572 314 | 5899 315 | 5020 316 | 6241 317 | 5276 318 | 5938 319 | 5806 320 | 6272 321 | 6043 322 | 5502 323 | 5893 324 | 6105 325 | 5160 326 | 5886 327 | 6007 328 | 5923 329 | 5942 330 | 5665 331 | 6225 332 | 5577 333 | 5257 334 | 6273 335 | 5481 336 | 5162 337 | 5217 338 | 5457 339 | 6245 340 | 5879 341 | 6005 342 | 6309 343 | 5575 344 | 5494 345 | 5900 346 | 5216 347 | 5304 348 | 5499 349 | 5746 350 | 5545 351 | 5045 352 | 6236 353 | 5278 354 | 6242 355 | 5123 356 | 5450 357 | 5306 358 | 5419 359 | 5897 360 | 5831 361 | 6086 362 | 5891 363 | 5455 364 | 6351 365 | 5878 366 | 5826 367 | 5081 368 | 6420 369 | 6393 370 | 6040 371 | 5573 372 | 6310 373 | 5640 374 | 5936 375 | 5541 376 | 6221 377 | 5163 378 | 6027 379 | 5941 380 | 5683 381 | 6419 -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/nyud/val.txt: -------------------------------------------------------------------------------- 1 | 5010 2 | 5011 3 | 5012 4 | 5013 5 | 5022 6 | 5023 7 | 5024 8 | 5025 9 | 5026 10 | 5027 11 | 5048 12 | 5049 13 | 5050 14 | 5051 15 | 5052 16 | 5053 17 | 5054 18 | 5055 19 | 5064 20 | 5065 21 | 5066 22 | 5067 23 | 5068 24 | 5069 25 | 5070 26 | 5071 27 | 5072 28 | 5073 29 | 5074 30 | 5075 31 | 5092 32 | 5093 33 | 5094 34 | 5095 35 | 5096 36 | 5097 37 | 5098 38 | 5099 39 | 5100 40 | 5101 41 | 5102 42 | 5103 43 | 5104 44 | 5105 45 | 5106 46 | 5107 47 | 5108 48 | 5109 49 | 5110 50 | 5111 51 | 5112 52 | 5113 53 | 5114 54 | 5115 55 | 5116 56 | 5130 57 | 5138 58 | 5139 59 | 5140 60 | 5141 61 | 5142 62 | 5143 63 | 5144 64 | 5145 65 | 5146 66 | 5147 67 | 5148 68 | 5149 69 | 5150 70 | 5151 71 | 5152 72 | 5170 73 | 5203 74 | 5204 75 | 5205 76 | 5206 77 | 5223 78 | 5224 79 | 5225 80 | 5226 81 | 5227 82 | 5228 83 | 5229 84 | 5230 85 | 5231 86 | 5232 87 | 5233 88 | 5234 89 | 5235 90 | 5236 91 | 5237 92 | 5238 93 | 5239 94 | 5240 95 | 5241 96 | 5242 97 | 5243 98 | 5244 99 | 5245 100 | 5246 101 | 5247 102 | 5248 103 | 5249 104 | 5265 105 | 5266 106 | 5267 107 | 5268 108 | 5269 109 | 5270 110 | 5286 111 | 5287 112 | 5288 113 | 5289 114 | 5290 115 | 5291 116 | 5292 117 | 5293 118 | 5294 119 | 5295 120 | 5313 121 | 5314 122 | 5336 123 | 5337 124 | 5338 125 | 5339 126 | 5340 127 | 5341 128 | 5342 129 | 5343 130 | 5344 131 | 5345 132 | 5346 133 | 5347 134 | 5348 135 | 5349 136 | 5350 137 | 5365 138 | 5366 139 | 5367 140 | 5368 141 | 5369 142 | 5370 143 | 5371 144 | 5372 145 | 5373 146 | 5374 147 | 5375 148 | 5376 149 | 5377 150 | 5378 151 | 5379 152 | 5380 153 | 5381 154 | 
5382 155 | 5383 156 | 5398 157 | 5399 158 | 5400 159 | 5401 160 | 5402 161 | 5403 162 | 5404 163 | 5405 164 | 5406 165 | 5407 166 | 5408 167 | 5409 168 | 5410 169 | 5436 170 | 5437 171 | 5438 172 | 5439 173 | 5440 174 | 5467 175 | 5468 176 | 5514 177 | 5534 178 | 5535 179 | 5536 180 | 5552 181 | 5553 182 | 5554 183 | 5584 184 | 5585 185 | 5586 186 | 5587 187 | 5588 188 | 5589 189 | 5590 190 | 5608 191 | 5609 192 | 5610 193 | 5611 194 | 5622 195 | 5623 196 | 5624 197 | 5625 198 | 5626 199 | 5627 200 | 5628 201 | 5629 202 | 5630 203 | 5631 204 | 5632 205 | 5646 206 | 5647 207 | 5648 208 | 5649 209 | 5659 210 | 5660 211 | 5661 212 | 5662 213 | 5674 214 | 5675 215 | 5691 216 | 5692 217 | 5700 218 | 5701 219 | 5702 220 | 5703 221 | 5704 222 | 5705 223 | 5719 224 | 5720 225 | 5721 226 | 5722 227 | 5723 228 | 5735 229 | 5736 230 | 5737 231 | 5738 232 | 5739 233 | 5740 234 | 5741 235 | 5742 236 | 5788 237 | 5789 238 | 5790 239 | 5791 240 | 5792 241 | 5793 242 | 5794 243 | 5795 244 | 5796 245 | 5797 246 | 5798 247 | 5799 248 | 5815 249 | 5816 250 | 5817 251 | 5818 252 | 5819 253 | 5820 254 | 5847 255 | 5848 256 | 5849 257 | 5863 258 | 5864 259 | 5865 260 | 5866 261 | 5867 262 | 5868 263 | 5909 264 | 5910 265 | 5911 266 | 5912 267 | 5913 268 | 5914 269 | 5915 270 | 5916 271 | 5929 272 | 5930 273 | 5931 274 | 5948 275 | 5949 276 | 5950 277 | 5951 278 | 5952 279 | 5953 280 | 5954 281 | 5955 282 | 5956 283 | 5957 284 | 5958 285 | 5968 286 | 5969 287 | 5996 288 | 5997 289 | 5998 290 | 5999 291 | 6000 292 | 6013 293 | 6014 294 | 6015 295 | 6016 296 | 6017 297 | 6018 298 | 6019 299 | 6020 300 | 6035 301 | 6036 302 | 6037 303 | 6050 304 | 6051 305 | 6059 306 | 6060 307 | 6061 308 | 6062 309 | 6063 310 | 6064 311 | 6065 312 | 6066 313 | 6067 314 | 6068 315 | 6069 316 | 6070 317 | 6071 318 | 6072 319 | 6073 320 | 6074 321 | 6097 322 | 6110 323 | 6111 324 | 6112 325 | 6113 326 | 6114 327 | 6115 328 | 6116 329 | 6132 330 | 6133 331 | 6134 332 | 6159 333 | 6160 334 | 6161 335 | 6172 336 | 6173 337 | 6185 338 | 6186 339 | 6187 340 | 6188 341 | 6189 342 | 6190 343 | 6191 344 | 6213 345 | 6214 346 | 6215 347 | 6231 348 | 6232 349 | 6251 350 | 6252 351 | 6253 352 | 6281 353 | 6282 354 | 6283 355 | 6284 356 | 6300 357 | 6301 358 | 6316 359 | 6317 360 | 6318 361 | 6319 362 | 6320 363 | 6321 364 | 6322 365 | 6323 366 | 6324 367 | 6325 368 | 6326 369 | 6327 370 | 6328 371 | 6341 372 | 6342 373 | 6343 374 | 6344 375 | 6345 376 | 6346 377 | 6357 378 | 6358 379 | 6359 380 | 6360 381 | 6361 382 | 6362 383 | 6363 384 | 6370 385 | 6371 386 | 6372 387 | 6373 388 | 6374 389 | 6375 390 | 6376 391 | 6377 392 | 6378 393 | 6379 394 | 6380 395 | 6381 396 | 6382 397 | 6383 398 | 6402 399 | 6403 400 | 6404 401 | 6405 402 | 6406 403 | 6425 404 | 6426 405 | 6427 406 | 6428 407 | 6429 408 | 6434 409 | 6435 410 | 6436 411 | 6437 412 | 6438 413 | 6439 414 | 6440 415 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/pascal-context/README.md: -------------------------------------------------------------------------------- 1 | # PASCAL-Context 2 | 3 | PASCAL-Context is a full object and scene labeling of PASCAL VOC 2010. 4 | It includes both object (cat, dog, ...) and surface (sky, grass, ...) classes. 5 | 6 | We follow the 59 class task defined by 7 | 8 | > The Role of Context for Object Detection and Semantic Segmentation in the Wild. 9 | Roozbeh Mottaghi, Xianjie Chen, Xiaobai Liu, Nam-Gyu Cho, Seong-Whan Lee, Sanja Fidler, Raquel Urtasun, and Alan Yuille. 
10 | CVPR 2014 11 | 12 | which selects the 59 most common classes for learning and evaluation. 13 | 14 | Refer to `classes-59.txt` for the listing of classes in model output order. 15 | Refer to `../pascalcontext_layers.py` for the Python data layer for this dataset. 16 | 17 | Note that care must be taken to map the raw class annotations into the 59 class task, as handled by our data layer. 18 | 19 | See the dataset site: http://www.cs.stanford.edu/~roozbeh/pascal-context/ 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/pascal-context/classes-59.txt: -------------------------------------------------------------------------------- 1 | 0: background 2 | 1: aeroplane 3 | 2: bicycle 4 | 3: bird 5 | 4: boat 6 | 5: bottle 7 | 6: bus 8 | 7: car 9 | 8: cat 10 | 9: chair 11 | 10: cow 12 | 11: diningtable 13 | 12: dog 14 | 13: horse 15 | 14: motorbike 16 | 15: person 17 | 16: pottedplant 18 | 17: sheep 19 | 18: sofa 20 | 19: train 21 | 20: tvmonitor 22 | 21: bag 23 | 22: bed 24 | 23: bench 25 | 24: book 26 | 25: building 27 | 26: cabinet 28 | 27: ceiling 29 | 28: clothes 30 | 29: computer 31 | 30: cup 32 | 31: door 33 | 32: fence 34 | 33: floor 35 | 34: flower 36 | 35: food 37 | 36: grass 38 | 37: ground 39 | 38: keyboard 40 | 39: light 41 | 40: mountain 42 | 41: mouse 43 | 42: curtain 44 | 43: platform 45 | 44: sign 46 | 45: plate 47 | 46: road 48 | 47: rock 49 | 48: shelves 50 | 49: sidewalk 51 | 50: sky 52 | 51: snow 53 | 52: bedcloth 54 | 53: track 55 | 54: tree 56 | 55: truck 57 | 56: wall 58 | 57: water 59 | 58: window 60 | 59: wood 61 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/pascal/README.md: -------------------------------------------------------------------------------- 1 | # PASCAL VOC and SBD 2 | 3 | PASCAL VOC is a standard recognition dataset and benchmark with detection and semantic segmentation challenges. 4 | The semantic segmentation challenge annotates 20 object classes and background. 5 | The Semantic Boundary Dataset (SBD) is a further annotation of the PASCAL VOC data that provides more semantic segmentation and instance segmentation masks. 6 | 7 | PASCAL VOC has a private test set and [leaderboard for semantic segmentation](http://host.robots.ox.ac.uk:8080/leaderboard/displaylb.php?challengeid=11&compid=6). 8 | 9 | The train/val/test splits of PASCAL VOC segmentation challenge and SBD diverge. 10 | Most notably VOC 2011 segval intersects with SBD train. 11 | Care must be taken for proper evaluation by excluding images from the train or val splits. 12 | 13 | We train on the 8,498 images of SBD train. 14 | We validate on the non-intersecting set defined in the included `seg11valid.txt`. 15 | 16 | Refer to `classes.txt` for the listing of classes in model output order. 17 | Refer to `../voc_layers.py` for the Python data layer for this dataset. 
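
To make the overlap concrete, a non-intersecting validation list like `seg11valid.txt` can be derived by a simple set difference; a sketch with illustrative paths for the SBD and VOC split files:

```python
# keep only VOC 2011 segval images that never appear in SBD train
sbd_train = set(open('benchmark_RELEASE/dataset/train.txt').read().split())
voc11_val = set(open('VOCdevkit/VOC2011/ImageSets/Segmentation/val.txt').read().split())
seg11valid = sorted(voc11_val - sbd_train)
```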
18 | 19 | See the dataset sites for download: 20 | 21 | - PASCAL VOC 2012: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/ 22 | - SBD: see [homepage](http://home.bharathh.info/home/sbd) or [direct download](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz) 23 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/pascal/classes.txt: -------------------------------------------------------------------------------- 1 | background 2 | aeroplane 3 | bicycle 4 | bird 5 | boat 6 | bottle 7 | bus 8 | car 9 | cat 10 | chair 11 | cow 12 | diningtable 13 | dog 14 | horse 15 | motorbike 16 | person 17 | pottedplant 18 | sheep 19 | sofa 20 | train 21 | tvmonitor 22 | 23 | and 255 is the ignore label that marks pixels excluded from learning and 24 | evaluation by the PASCAL VOC ground truth. 25 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/sift-flow/README.md: -------------------------------------------------------------------------------- 1 | # SIFT Flow 2 | 3 | SIFT Flow is a semantic segmentation dataset with two labelings: 4 | 5 | - semantic classes, such as "cat" or "dog" 6 | - geometric classes, consisting of "horizontal, vertical, and sky" 7 | 8 | Refer to `classes.txt` for the listing of classes in model output order. 9 | Refer to `../siftflow_layers.py` for the Python data layer for this dataset. 10 | 11 | Note that the dataset has a number of issues, including unannotated images and missing classes from the test set. 12 | The provided splits exclude the unannotated images. 13 | As noted in the paper, care must be taken for proper evaluation by excluding the missing classes. 14 | 15 | Download the dataset: 16 | http://www.cs.unc.edu/~jtighe/Papers/ECCV10/siftflow/SiftFlowDataset.zip 17 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/sift-flow/classes.txt: -------------------------------------------------------------------------------- 1 | Semantic and geometric segmentation classes for scenes. 2 | 3 | Semantic: 0 is void and 1–33 are classes. 4 | 5 | 01 awning 6 | 02 balcony 7 | 03 bird 8 | 04 boat 9 | 05 bridge 10 | 06 building 11 | 07 bus 12 | 08 car 13 | 09 cow 14 | 10 crosswalk 15 | 11 desert 16 | 12 door 17 | 13 fence 18 | 14 field 19 | 15 grass 20 | 16 moon 21 | 17 mountain 22 | 18 person 23 | 19 plant 24 | 20 pole 25 | 21 river 26 | 22 road 27 | 23 rock 28 | 24 sand 29 | 25 sea 30 | 26 sidewalk 31 | 27 sign 32 | 28 sky 33 | 29 staircase 34 | 30 streetlight 35 | 31 sun 36 | 32 tree 37 | 33 window 38 | 39 | Geometric: -1 is void and 1–3 are classes. 40 | 41 | 01 sky 42 | 02 horizontal 43 | 03 vertical 44 | 45 | N.B. Three classes (cow, desert, and moon) are absent from the test set, so 46 | they are excluded from evaluation. The highway_bost181 and street_urb506 images 47 | are missing annotations so these are likewise excluded from evaluation.
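
For scoring, excluding these classes amounts to masking out entries with no ground-truth pixels before averaging; a minimal sketch in Python (illustrative, not the actual score.py code), assuming a confusion matrix hist with ground truth on rows and predictions on columns:

    import numpy as np

    def mean_iu_present(hist):
        # score only classes with ground-truth pixels in the test set,
        # so cow, desert, and moon drop out of the semantic mean IU
        tp = np.diag(hist)
        gt = hist.sum(axis=1)
        pred = hist.sum(axis=0)
        present = gt > 0
        iu = tp[present] / (gt[present] + pred[present] - tp[present])
        return iu.mean()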
48 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/data/sift-flow/test.txt: -------------------------------------------------------------------------------- 1 | coast_natu975 2 | insidecity_art947 3 | insidecity_urb781 4 | highway_bost374 5 | coast_n203085 6 | insidecity_a223049 7 | mountain_nat116 8 | street_art861 9 | mountain_land188 10 | street_par177 11 | opencountry_natu524 12 | forest_natu29 13 | highway_gre37 14 | street_bost77 15 | insidecity_art1125 16 | street_urb521 17 | highway_bost178 18 | street_art760 19 | street_urb885 20 | insidecity_art829 21 | coast_natu804 22 | mountain_sharp44 23 | coast_natu649 24 | opencountry_land691 25 | insidecity_hous35 26 | tallbuilding_art1719 27 | mountain_n736026 28 | mountain_moun41 29 | insidecity_urban992 30 | opencountry_land295 31 | tallbuilding_art527 32 | highway_art238 33 | forest_for114 34 | coast_land296 35 | tallbuilding_sky7 36 | mountain_n44009 37 | tallbuilding_art1316 38 | forest_nat717 39 | highway_bost164 40 | street_par29 41 | forest_natc52 42 | tallbuilding_art1004 43 | coast_sun14 44 | opencountry_land206 45 | opencountry_land364 46 | mountain_n219015 47 | highway_a836030 48 | forest_nat324 49 | opencountry_land493 50 | insidecity_art1598 51 | street_street27 52 | insidecity_a48009 53 | coast_cdmc889 54 | street_gre295 55 | tallbuilding_a538076 56 | street_boston378 57 | highway_urb759 58 | street_par151 59 | tallbuilding_urban1003 60 | tallbuilding_urban16 61 | highway_bost151 62 | opencountry_nat965 63 | highway_gre661 64 | forest_for42 65 | opencountry_n18002 66 | insidecity_art646 67 | highway_gre55 68 | coast_n295051 69 | forest_bost103 70 | highway_n480036 71 | mountain_land4 72 | forest_nat130 73 | coast_nat643 74 | insidecity_urb250 75 | street_gre11 76 | street_boston271 77 | opencountry_n490003 78 | mountain_nat762 79 | street_par86 80 | coast_arnat59 81 | mountain_land787 82 | highway_gre472 83 | opencountry_tell67 84 | mountain_sharp66 85 | opencountry_land534 86 | insidecity_gre290 87 | highway_bost307 88 | opencountry_n213059 89 | forest_nat220 90 | forest_cdmc348 91 | tallbuilding_art900 92 | insidecity_art569 93 | street_urb200 94 | coast_natu468 95 | coast_n672069 96 | insidecity_hous109 97 | forest_land862 98 | opencountry_natu65 99 | tallbuilding_a805096 100 | opencountry_n291058 101 | forest_natu439 102 | coast_nat799 103 | tallbuilding_urban991 104 | tallbuilding_sky17 105 | opencountry_land638 106 | opencountry_natu563 107 | tallbuilding_urb733 108 | forest_cdmc451 109 | mountain_n371066 110 | mountain_n213081 111 | mountain_nat57 112 | tallbuilding_a463068 113 | forest_natu848 114 | tallbuilding_art306 115 | insidecity_boston92 116 | insidecity_urb584 117 | tallbuilding_urban1126 118 | coast_n286045 119 | street_gre179 120 | coast_nat1091 121 | opencountry_nat615 122 | coast_nat901 123 | forest_cdmc291 124 | mountain_natu568 125 | mountain_n18070 126 | street_bost136 127 | tallbuilding_art425 128 | coast_bea3 129 | tallbuilding_art1616 130 | insidecity_art690 131 | highway_gre492 132 | highway_bost320 133 | forest_nat400 134 | highway_par23 135 | tallbuilding_a212033 136 | forest_natu994 137 | tallbuilding_archi296 138 | highway_gre413 139 | tallbuilding_a279033 140 | insidecity_art1277 141 | coast_cdmc948 142 | forest_for15 143 | street_par68 144 | mountain_natu786 145 | opencountry_open61 146 | opencountry_nat423 147 | mountain_land143 148 | tallbuilding_a487066 149 | tallbuilding_art1751 150 | insidecity_hous79 151 | 
street_par118 152 | highway_bost293 153 | mountain_n213021 154 | opencountry_nat802 155 | coast_n384099 156 | opencountry_natu998 157 | mountain_n344042 158 | coast_nat1265 159 | forest_text44 160 | forest_for84 161 | insidecity_a807066 162 | opencountry_nat1117 163 | coast_sun42 164 | insidecity_par180 165 | opencountry_land923 166 | highway_art580 167 | street_art1328 168 | coast_cdmc838 169 | opencountry_land660 170 | opencountry_cdmc354 171 | coast_natu825 172 | opencountry_natu38 173 | mountain_nat30 174 | coast_n199066 175 | forest_text124 176 | forest_land222 177 | tallbuilding_city56 178 | tallbuilding_city22 179 | opencountry_fie36 180 | mountain_ski24 181 | coast_cdmc997 182 | insidecity_boston232 183 | opencountry_land575 184 | opencountry_land797 185 | insidecity_urb362 186 | forest_nat1033 187 | mountain_nat891 188 | street_hexp3 189 | tallbuilding_art1474 190 | tallbuilding_urban73 191 | opencountry_natu852 192 | mountain_nat1008 193 | coast_nat294 194 | mountain_sharp20 195 | opencountry_fie14 196 | mountain_land275 197 | forest_land760 198 | coast_land374 199 | mountain_nat426 200 | highway_gre141 -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/ilsvrc-nets/README.md: -------------------------------------------------------------------------------- 1 | # ILSVRC Networks 2 | 3 | These classification networks are trained on ILSVRC for object recognition. 4 | We cast these nets into fully convolutional form to make use of their parameters as pre-training. 5 | 6 | To reproduce our FCNs, or train your own on your own data, you need to first collect the corresponding base network. 7 | 8 | - [VGG16](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-md) 9 | - [CaffeNet](https://github.com/BVLC/caffe/tree/master/models/bvlc_reference_caffenet) 10 | - [BVLC GoogLeNet](https://github.com/BVLC/caffe/tree/master/models/bvlc_googlenet) 11 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/infer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | import caffe 5 | 6 | # load image, switch to BGR, subtract mean, and make dims C x H x W for Caffe 7 | im = Image.open('pascal/VOC2010/JPEGImages/2007_000129.jpg') 8 | in_ = np.array(im, dtype=np.float32) 9 | in_ = in_[:,:,::-1] 10 | in_ -= np.array((104.00698793,116.66876762,122.67891434)) 11 | in_ = in_.transpose((2,0,1)) 12 | 13 | # load net 14 | net = caffe.Net('voc-fcn8s/deploy.prototxt', 'voc-fcn8s/fcn8s-heavy-pascal.caffemodel', caffe.TEST) 15 | # shape for input (data blob is N x C x H x W), set data 16 | net.blobs['data'].reshape(1, *in_.shape) 17 | net.blobs['data'].data[...] 
= in_ 18 | # run net and take argmax for prediction 19 | net.forward() 20 | out = net.blobs['score'].data[0].argmax(axis=0) 21 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-color-d/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split, tops): 15 | n = caffe.NetSpec() 16 | n.color, n.depth, n.label = L.Python(module='nyud_layers', 17 | layer='NYUDSegDataLayer', ntop=3, 18 | param_str=str(dict(nyud_dir='../data/nyud', split=split, 19 | tops=tops, seed=1337))) 20 | n.data = L.Concat(n.color, n.depth) 21 | 22 | # the base net 23 | n.conv1_1_bgrd, n.relu1_1 = conv_relu(n.data, 64, pad=100) 24 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 25 | n.pool1 = max_pool(n.relu1_2) 26 | 27 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 28 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 29 | n.pool2 = max_pool(n.relu2_2) 30 | 31 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 32 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 33 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 34 | n.pool3 = max_pool(n.relu3_3) 35 | 36 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 37 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 38 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 39 | n.pool4 = max_pool(n.relu4_3) 40 | 41 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 42 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 43 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 44 | n.pool5 = max_pool(n.relu5_3) 45 | 46 | # fully conv 47 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 48 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 49 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 50 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 51 | 52 | n.score_fr = L.Convolution(n.drop7, num_output=40, kernel_size=1, pad=0, 53 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 54 | n.upscore = L.Deconvolution(n.score_fr, 55 | convolution_param=dict(num_output=40, kernel_size=64, stride=32, 56 | bias_term=False), 57 | param=[dict(lr_mult=0)]) 58 | n.score = crop(n.upscore, n.data) 59 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 60 | loss_param=dict(normalize=False, ignore_label=255)) 61 | 62 | return n.to_proto() 63 | 64 | def make_net(): 65 | tops = ['color', 'depth', 'label'] 66 | with open('trainval.prototxt', 'w') as f: 67 | f.write(str(fcn('trainval', tops))) 68 | 69 | with open('test.prototxt', 'w') as f: 70 | f.write(str(fcn('test', tops))) 71 | 72 | if __name__ == '__main__': 73 | make_net() 74 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-color-d/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | 
setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../ilsvrc-nets/vgg16-fcn.caffemodel' 15 | base_net = caffe.Net('../ilsvrc-nets/vgg16fcn.prototxt', '../vgg16fc.caffemodel', 16 | caffe.TEST) 17 | 18 | # init 19 | caffe.set_device(int(sys.argv[1])) 20 | caffe.set_mode_gpu() 21 | 22 | solver = caffe.SGDSolver('solver.prototxt') 23 | surgery.transplant(solver.net, base_net) 24 | 25 | # surgeries 26 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 27 | surgery.interp(solver.net, interp_layers) 28 | 29 | solver.net.params['conv1_1_bgrd'][0].data[:, :3] = base_net.params['conv1_1'][0].data 30 | solver.net.params['conv1_1_bgrd'][0].data[:, 3] = np.mean(base_net.params['conv1_1'][0].data, axis=1) 31 | solver.net.params['conv1_1_bgrd'][1].data[...] = base_net.params['conv1_1'][1].data 32 | 33 | del base_net 34 | 35 | # scoring 36 | test = np.loadtxt('../data/nyud/test.txt', dtype=str) 37 | 38 | for _ in range(50): 39 | solver.step(2000) 40 | score.seg_tests(solver, False, test, layer='score') 41 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-color-d/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "trainval.prototxt" 2 | test_net: "test.prototxt" 3 | test_iter: 654 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-10 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 300000 16 | weight_decay: 0.0005 17 | snapshot: 2000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-color-hha/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/nyud-fcn32s-color-hha-heavy.caffemodel 2 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-color-hha/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def modality_fcn(net_spec, data, modality): 15 | n = net_spec 16 | # the base net 17 | n['conv1_1' + modality], n['relu1_1' + modality] = conv_relu(n[data], 64, 18 | pad=100) 19 | n['conv1_2' + modality], n['relu1_2' + modality] = conv_relu(n['relu1_1' + 20 | modality], 64) 21 | n['pool1' + modality] = max_pool(n['relu1_2' + modality]) 22 | 23 | n['conv2_1' + modality], n['relu2_1' + modality] = conv_relu(n['pool1' + 24 | modality], 128) 25 | n['conv2_2' + modality], n['relu2_2' + modality] = conv_relu(n['relu2_1' + 26 | modality], 128) 27 | n['pool2' + modality] = max_pool(n['relu2_2' + modality]) 28 | 29 | n['conv3_1' + modality], n['relu3_1' + 
modality] = conv_relu(n['pool2' + 30 | modality], 256) 31 | n['conv3_2' + modality], n['relu3_2' + modality] = conv_relu(n['relu3_1' + 32 | modality], 256) 33 | n['conv3_3' + modality], n['relu3_3' + modality] = conv_relu(n['relu3_2' + 34 | modality], 256) 35 | n['pool3' + modality] = max_pool(n['relu3_3' + modality]) 36 | 37 | n['conv4_1' + modality], n['relu4_1' + modality] = conv_relu(n['pool3' + 38 | modality], 512) 39 | n['conv4_2' + modality], n['relu4_2' + modality] = conv_relu(n['relu4_1' + 40 | modality], 512) 41 | n['conv4_3' + modality], n['relu4_3' + modality] = conv_relu(n['relu4_2' + 42 | modality], 512) 43 | n['pool4' + modality] = max_pool(n['relu4_3' + modality]) 44 | 45 | n['conv5_1' + modality], n['relu5_1' + modality] = conv_relu(n['pool4' + 46 | modality], 512) 47 | n['conv5_2' + modality], n['relu5_2' + modality] = conv_relu(n['relu5_1' + 48 | modality], 512) 49 | n['conv5_3' + modality], n['relu5_3' + modality] = conv_relu(n['relu5_2' + 50 | modality], 512) 51 | n['pool5' + modality] = max_pool(n['relu5_3' + modality]) 52 | 53 | # fully conv 54 | n['fc6' + modality], n['relu6' + modality] = conv_relu( 55 | n['pool5' + modality], 4096, ks=7, pad=0) 56 | n['drop6' + modality] = L.Dropout( 57 | n['relu6' + modality], dropout_ratio=0.5, in_place=True) 58 | n['fc7' + modality], n['relu7' + modality] = conv_relu( 59 | n['drop6' + modality], 4096, ks=1, pad=0) 60 | n['drop7' + modality] = L.Dropout( 61 | n['relu7' + modality], dropout_ratio=0.5, in_place=True) 62 | n['score_fr' + modality] = L.Convolution( 63 | n['drop7' + modality], num_output=40, kernel_size=1, pad=0, 64 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 65 | return n 66 | 67 | def fcn(split, tops): 68 | n = caffe.NetSpec() 69 | n.color, n.hha, n.label = L.Python(module='nyud_layers', 70 | layer='NYUDSegDataLayer', ntop=3, 71 | param_str=str(dict(nyud_dir='../data/nyud', split=split, 72 | tops=tops, seed=1337))) 73 | n = modality_fcn(n, 'color', 'color') 74 | n = modality_fcn(n, 'hha', 'hha') 75 | n.score_fused = L.Eltwise(n.score_frcolor, n.score_frhha, 76 | operation=P.Eltwise.SUM, coeff=[0.5, 0.5]) 77 | n.upscore = L.Deconvolution(n.score_fused, 78 | convolution_param=dict(num_output=40, kernel_size=64, stride=32, 79 | bias_term=False), 80 | param=[dict(lr_mult=0)]) 81 | n.score = crop(n.upscore, n.color) 82 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 83 | loss_param=dict(normalize=False, ignore_label=255)) 84 | return n.to_proto() 85 | 86 | def make_net(): 87 | tops = ['color', 'hha', 'label'] 88 | with open('trainval.prototxt', 'w') as f: 89 | f.write(str(fcn('trainval', tops))) 90 | 91 | with open('test.prototxt', 'w') as f: 92 | f.write(str(fcn('test', tops))) 93 | 94 | if __name__ == '__main__': 95 | make_net() 96 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-color-hha/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | color_proto = '../nyud-rgb-32s/trainval.prototxt' 15 | color_weights = '../nyud-rgb-32s/nyud-rgb-32s-28k.caffemodel' 16 | hha_proto = '../nyud-hha-32s/trainval.prototxt' 17 | hha_weights = '../nyud-hha-32s/nyud-hha-32s-60k.caffemodel' 18 | 19 | # init 20 | caffe.set_device(int(sys.argv[1])) 21 | 
caffe.set_mode_gpu() 22 | 23 | solver = caffe.SGDSolver('solver.prototxt') 24 | 25 | # surgeries 26 | color_net = caffe.Net(color_proto, color_weights, caffe.TEST) 27 | surgery.transplant(solver.net, color_net, suffix='color') 28 | del color_net 29 | 30 | hha_net = caffe.Net(hha_proto, hha_weights, caffe.TEST) 31 | surgery.transplant(solver.net, hha_net, suffix='hha') 32 | del hha_net 33 | 34 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 35 | surgery.interp(solver.net, interp_layers) 36 | 37 | # scoring 38 | test = np.loadtxt('../data/nyud/test.txt', dtype=str) 39 | 40 | for _ in range(50): 41 | solver.step(2000) 42 | score.seg_tests(solver, False, test, layer='score') 43 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-color-hha/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "trainval.prototxt" 2 | test_net: "test.prototxt" 3 | test_iter: 654 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-12 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 300000 16 | weight_decay: 0.0005 17 | snapshot: 2000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-color/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/nyud-fcn32s-color-heavy.caffemodel 2 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-color/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split, tops): 15 | n = caffe.NetSpec() 16 | n.data, n.label = L.Python(module='nyud_layers', 17 | layer='NYUDSegDataLayer', ntop=2, 18 | param_str=str(dict(nyud_dir='../data/nyud', split=split, 19 | tops=tops, seed=1337))) 20 | 21 | # the base net 22 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 23 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 24 | n.pool1 = max_pool(n.relu1_2) 25 | 26 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 27 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 28 | n.pool2 = max_pool(n.relu2_2) 29 | 30 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 31 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 32 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 33 | n.pool3 = max_pool(n.relu3_3) 34 | 35 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 36 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 37 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 38 | n.pool4 = max_pool(n.relu4_3) 39 | 40 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 41 | n.conv5_2, n.relu5_2 = 
conv_relu(n.relu5_1, 512) 42 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 43 | n.pool5 = max_pool(n.relu5_3) 44 | 45 | # fully conv 46 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 47 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 48 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 49 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 50 | 51 | n.score_fr = L.Convolution(n.drop7, num_output=40, kernel_size=1, pad=0, 52 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 53 | n.upscore = L.Deconvolution(n.score_fr, 54 | convolution_param=dict(num_output=40, kernel_size=64, stride=32, 55 | bias_term=False), 56 | param=[dict(lr_mult=0)]) 57 | n.score = crop(n.upscore, n.data) 58 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 59 | loss_param=dict(normalize=False, ignore_label=255)) 60 | 61 | return n.to_proto() 62 | 63 | def make_net(): 64 | tops = ['color', 'label'] 65 | with open('trainval.prototxt', 'w') as f: 66 | f.write(str(fcn('trainval', tops))) 67 | 68 | with open('test.prototxt', 'w') as f: 69 | f.write(str(fcn('test', tops))) 70 | 71 | if __name__ == '__main__': 72 | make_net() 73 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-color/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../ilsvrc-nets/vgg16-fcn.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | test = np.loadtxt('../data/nyud/test.txt', dtype=str) 29 | 30 | for _ in range(50): 31 | solver.step(2000) 32 | score.seg_tests(solver, False, test, layer='score') 33 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-color/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "trainval.prototxt" 2 | test_net: "test.prototxt" 3 | test_iter: 654 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-10 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 300000 16 | weight_decay: 0.0005 17 | snapshot: 2000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-hha/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/nyud-fcn32s-hha-heavy.caffemodel 2 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-hha/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | 
from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split, tops): 15 | n = caffe.NetSpec() 16 | n.data, n.label = L.Python(module='nyud_layers', 17 | layer='NYUDSegDataLayer', ntop=2, 18 | param_str=str(dict(nyud_dir='../data/nyud', split=split, 19 | tops=tops, seed=1337))) 20 | 21 | # the base net 22 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 23 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 24 | n.pool1 = max_pool(n.relu1_2) 25 | 26 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 27 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 28 | n.pool2 = max_pool(n.relu2_2) 29 | 30 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 31 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 32 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 33 | n.pool3 = max_pool(n.relu3_3) 34 | 35 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 36 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 37 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 38 | n.pool4 = max_pool(n.relu4_3) 39 | 40 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 41 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 42 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 43 | n.pool5 = max_pool(n.relu5_3) 44 | 45 | # fully conv 46 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 47 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 48 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 49 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 50 | 51 | n.score_fr = L.Convolution(n.drop7, num_output=40, kernel_size=1, pad=0, 52 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 53 | n.upscore = L.Deconvolution(n.score_fr, 54 | convolution_param=dict(num_output=40, kernel_size=64, stride=32, 55 | bias_term=False), 56 | param=[dict(lr_mult=0)]) 57 | n.score = crop(n.upscore, n.data) 58 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 59 | loss_param=dict(normalize=False, ignore_label=255)) 60 | 61 | return n.to_proto() 62 | 63 | def make_net(): 64 | tops = ['hha', 'label'] 65 | with open('trainval.prototxt', 'w') as f: 66 | f.write(str(fcn('trainval', tops))) 67 | 68 | with open('test.prototxt', 'w') as f: 69 | f.write(str(fcn('test', tops))) 70 | 71 | if __name__ == '__main__': 72 | make_net() 73 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-hha/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../ilsvrc-nets/vgg16-fcn.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | test = np.loadtxt('../data/nyud/test.txt', 
dtype=str) 29 | 30 | for _ in range(50): 31 | solver.step(2000) 32 | score.seg_tests(solver, False, test, layer='score') 33 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud-fcn32s-hha/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "trainval.prototxt" 2 | test_net: "test.prototxt" 3 | test_iter: 654 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-10 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 300000 16 | weight_decay: 0.0005 17 | snapshot: 2000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/nyud_layers.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | 3 | import numpy as np 4 | from PIL import Image 5 | import scipy.io 6 | 7 | import random 8 | 9 | class NYUDSegDataLayer(caffe.Layer): 10 | """ 11 | Load (input image, label image) pairs from NYUDv2 12 | one-at-a-time while reshaping the net to preserve dimensions. 13 | 14 | The labels follow the 40 class task defined by 15 | 16 | S. Gupta, R. Girshick, P. Arbelaez, and J. Malik. Learning rich features 17 | from RGB-D images for object detection and segmentation. ECCV 2014. 18 | 19 | with 0 as the void label and 1-40 the classes. 20 | 21 | Use this to feed data to a fully convolutional network. 22 | """ 23 | 24 | def setup(self, bottom, top): 25 | """ 26 | Setup data layer according to parameters: 27 | 28 | - nyud_dir: path to NYUDv2 dir 29 | - split: train / val / test 30 | - tops: list of tops to output from {color, depth, hha, label} 31 | - randomize: load in random order (default: True) 32 | - seed: seed for randomization (default: None / current time) 33 | 34 | for NYUDv2 semantic segmentation. 35 | 36 | example: params = dict(nyud_dir="/path/to/NYUDVOC2011", split="val", 37 | tops=['color', 'hha', 'label']) 38 | """ 39 | # config 40 | params = eval(self.param_str)
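# --- editor's aside, not upstream code ---
# param_str arrives as the repr of a plain dict (net.py passes
# param_str=str(dict(...))), so eval() reconstructs it -- at the cost of
# executing whatever the prototxt happens to contain. A stricter parse of the
# same strings would be:
#     import ast
#     params = ast.literal_eval(self.param_str)
# ast.literal_eval accepts only Python literals (dicts, lists, tuples,
# strings, numbers) and raises on anything else.
# --- end of aside ---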
41 | self.nyud_dir = params['nyud_dir'] 42 | self.split = params['split'] 43 | self.tops = params['tops'] 44 | self.random = params.get('randomize', True) 45 | self.seed = params.get('seed', None) 46 | 47 | # store top data for reshape + forward 48 | self.data = {} 49 | 50 | # means 51 | self.mean_bgr = np.array((116.190, 97.203, 92.318), dtype=np.float32) 52 | self.mean_hha = np.array((132.431, 94.076, 118.477), dtype=np.float32) 53 | self.mean_logd = np.array((7.844,), dtype=np.float32) 54 | 55 | # tops: check configuration 56 | if len(top) != len(self.tops): 57 | raise Exception("Need to define {} tops for all outputs.".format(len(self.tops))) 58 | # data layers have no bottoms 59 | if len(bottom) != 0: 60 | raise Exception("Do not define a bottom.") 61 | 62 | # load indices for images and labels 63 | split_f = '{}/{}.txt'.format(self.nyud_dir, self.split) 64 | self.indices = open(split_f, 'r').read().splitlines() 65 | self.idx = 0 66 | 67 | # make eval deterministic 68 | if 'train' not in self.split: 69 | self.random = False 70 | 71 | # randomization: seed and pick 72 | if self.random: 73 | random.seed(self.seed) 74 | self.idx = random.randint(0, len(self.indices)-1) 75 | 76 | def reshape(self, bottom, top): 77 | # load data for tops and reshape tops to fit (1 is the batch dim) 78 | for i, t in enumerate(self.tops): 79 | self.data[t] = self.load(t, self.indices[self.idx]) 80 | top[i].reshape(1, *self.data[t].shape) 81 | 82 | def forward(self, bottom, top): 83 | # assign output 84 | for i, t in enumerate(self.tops): 85 | top[i].data[...] = self.data[t] 86 | 87 | # pick next input 88 | if self.random: 89 | self.idx = random.randint(0, len(self.indices)-1) 90 | else: 91 | self.idx += 1 92 | if self.idx == len(self.indices): 93 | self.idx = 0 94 | 95 | def backward(self, top, propagate_down, bottom): 96 | pass 97 | 98 | def load(self, top, idx): 99 | if top == 'color': 100 | return self.load_image(idx) 101 | elif top == 'label': 102 | return self.load_label(idx) 103 | elif top == 'depth': 104 | return self.load_depth(idx) 105 | elif top == 'hha': 106 | return self.load_hha(idx) 107 | else: 108 | raise Exception("Unknown output type: {}".format(top)) 109 | 110 | def load_image(self, idx): 111 | """ 112 | Load input image and preprocess for Caffe: 113 | - cast to float 114 | - switch channels RGB -> BGR 115 | - subtract mean 116 | - transpose to channel x height x width order 117 | """ 118 | im = Image.open('{}/data/images/img_{}.png'.format(self.nyud_dir, idx)) 119 | in_ = np.array(im, dtype=np.float32) 120 | in_ = in_[:,:,::-1] 121 | in_ -= self.mean_bgr 122 | in_ = in_.transpose((2,0,1)) 123 | return in_ 124 | 125 | def load_label(self, idx): 126 | """ 127 | Load label image as 1 x height x width integer array of label indices. 128 | Shift labels so that classes are 0-39 and void is 255 (to ignore it). 129 | The leading singleton dimension is required by the loss. 130 | """ 131 | label = scipy.io.loadmat('{}/segmentation/img_{}.mat'.format(self.nyud_dir, idx))['segmentation'].astype(np.uint8) 132 | label -= 1 # shift labels: classes 1-40 -> 0-39, void 0 wraps to 255 (uint8) 133 | label = label[np.newaxis, ...] 134 | return label 135 | 136 | def load_depth(self, idx): 137 | """ 138 | Load pre-processed depth for NYUDv2 segmentation set.
139 | """ 140 | im = Image.open('{}/data/depth/img_{}.png'.format(self.nyud_dir, idx)) 141 | d = np.array(im, dtype=np.float32) 142 | d = np.log(d) 143 | d -= self.mean_logd 144 | d = d[np.newaxis, ...] 145 | return d 146 | 147 | def load_hha(self, idx): 148 | """ 149 | Load HHA features from Gupta et al. ECCV14. 150 | See https://github.com/s-gupta/rcnn-depth/blob/master/rcnn/saveHHA.m 151 | """ 152 | im = Image.open('{}/data/hha/img_{}.png'.format(self.nyud_dir, idx)) 153 | hha = np.array(im, dtype=np.float32) 154 | hha -= self.mean_hha 155 | hha = hha.transpose((2,0,1)) 156 | return hha 157 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn16s/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/pascalcontext-fcn16s-heavy.caffemodel 2 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn16s/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split): 15 | n = caffe.NetSpec() 16 | n.data, n.label = L.Python(module='pascalcontext_layers', 17 | layer='PASCALContextSegDataLayer', ntop=2, 18 | param_str=str(dict(voc_dir='../../data/pascal', 19 | context_dir='../../data/pascal-context', split=split, 20 | seed=1337))) 21 | 22 | # the base net 23 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 24 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 25 | n.pool1 = max_pool(n.relu1_2) 26 | 27 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 28 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 29 | n.pool2 = max_pool(n.relu2_2) 30 | 31 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 32 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 33 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 34 | n.pool3 = max_pool(n.relu3_3) 35 | 36 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 37 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 38 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 39 | n.pool4 = max_pool(n.relu4_3) 40 | 41 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 42 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 43 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 44 | n.pool5 = max_pool(n.relu5_3) 45 | 46 | # fully conv 47 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 48 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 49 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 50 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 51 | 52 | n.score_fr = L.Convolution(n.drop7, num_output=60, kernel_size=1, pad=0, 53 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 54 | n.upscore2 = L.Deconvolution(n.score_fr, 55 | convolution_param=dict(num_output=60, kernel_size=4, stride=2, 56 | bias_term=False), 57 | param=[dict(lr_mult=0)]) 58 | 59 | n.score_pool4 = L.Convolution(n.pool4, num_output=60, kernel_size=1, pad=0, 60 
| param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 61 | n.score_pool4c = crop(n.score_pool4, n.upscore2) 62 | n.fuse_pool4 = L.Eltwise(n.upscore2, n.score_pool4c, 63 | operation=P.Eltwise.SUM) 64 | n.upscore16 = L.Deconvolution(n.fuse_pool4, 65 | convolution_param=dict(num_output=60, kernel_size=32, stride=16, 66 | bias_term=False), 67 | param=[dict(lr_mult=0)]) 68 | 69 | n.score = crop(n.upscore16, n.data) 70 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 71 | loss_param=dict(normalize=False, ignore_label=255)) 72 | 73 | return n.to_proto() 74 | 75 | def make_net(): 76 | with open('train.prototxt', 'w') as f: 77 | f.write(str(fcn('train'))) 78 | 79 | with open('val.prototxt', 'w') as f: 80 | f.write(str(fcn('val'))) 81 | 82 | if __name__ == '__main__': 83 | make_net() 84 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn16s/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../pascalcontext-fcn32s/pascalcontext-fcn32s.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | val = np.loadtxt('../data/pascal/VOC2010/ImageSets/Main/val.txt', dtype=str) 29 | 30 | for _ in range(50): 31 | solver.step(8000) 32 | score.seg_tests(solver, False, val, layer='score') 33 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn16s/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "train.prototxt" 2 | test_net: "val.prototxt" 3 | test_iter: 5105 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-12 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 300000 16 | weight_decay: 0.0005 17 | snapshot: 4000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn32s/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/pascalcontext-fcn32s-heavy.caffemodel 2 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn32s/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | 
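# --- editor's aside, not upstream code ---
# The two param dicts in conv_relu follow the standard Caffe VGG fine-tuning
# recipe: the first entry governs the weights (lr_mult=1, weight decay on),
# the second the bias (twice the learning rate, no decay). In the generated
# prototxt each convolution carries:
#     param { lr_mult: 1 decay_mult: 1 }   # weights
#     param { lr_mult: 2 decay_mult: 0 }   # bias
# --- end of aside ---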
def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split): 15 | n = caffe.NetSpec() 16 | n.data, n.label = L.Python(module='pascalcontext_layers', 17 | layer='PASCALContextSegDataLayer', ntop=2, 18 | param_str=str(dict(voc_dir='../../data/pascal', 19 | context_dir='../../data/pascal-context', split=split, 20 | seed=1337))) 21 | 22 | # the base net 23 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 24 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 25 | n.pool1 = max_pool(n.relu1_2) 26 | 27 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 28 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 29 | n.pool2 = max_pool(n.relu2_2) 30 | 31 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 32 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 33 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 34 | n.pool3 = max_pool(n.relu3_3) 35 | 36 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 37 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 38 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 39 | n.pool4 = max_pool(n.relu4_3) 40 | 41 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 42 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 43 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 44 | n.pool5 = max_pool(n.relu5_3) 45 | 46 | # fully conv 47 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 48 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 49 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 50 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 51 | 52 | n.score_fr = L.Convolution(n.drop7, num_output=60, kernel_size=1, pad=0, 53 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 54 | n.upscore = L.Deconvolution(n.score_fr, 55 | convolution_param=dict(num_output=60, kernel_size=64, stride=32, 56 | bias_term=False), 57 | param=[dict(lr_mult=0)]) 58 | n.score = crop(n.upscore, n.data) 59 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 60 | loss_param=dict(normalize=False, ignore_label=255)) 61 | 62 | 63 | return n.to_proto() 64 | 65 | def make_net(): 66 | with open('train.prototxt', 'w') as f: 67 | f.write(str(fcn('train'))) 68 | 69 | with open('val.prototxt', 'w') as f: 70 | f.write(str(fcn('val'))) 71 | 72 | if __name__ == '__main__': 73 | make_net() 74 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn32s/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../ilsvrc-nets/vgg16-fcn.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | val = np.loadtxt('../data/pascal/VOC2010/ImageSets/Main/val.txt', dtype=str) 29 | 30 | for _ in range(50): 31 | solver.step(8000) 32 | score.seg_tests(solver, False, val, layer='score') 33 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn32s/solver.prototxt: 
-------------------------------------------------------------------------------- 1 | train_net: "train.prototxt" 2 | test_net: "val.prototxt" 3 | test_iter: 5105 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-10 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 300000 16 | weight_decay: 0.0005 17 | snapshot: 4000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn8s/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/pascalcontext-fcn8s-heavy.caffemodel 2 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn8s/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split): 15 | n = caffe.NetSpec() 16 | n.data, n.label = L.Python(module='pascalcontext_layers', 17 | layer='PASCALContextSegDataLayer', ntop=2, 18 | param_str=str(dict(voc_dir='../../data/pascal', 19 | context_dir='../../data/pascal-context', split=split, 20 | seed=1337))) 21 | 22 | # the base net 23 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 24 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 25 | n.pool1 = max_pool(n.relu1_2) 26 | 27 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 28 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 29 | n.pool2 = max_pool(n.relu2_2) 30 | 31 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 32 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 33 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 34 | n.pool3 = max_pool(n.relu3_3) 35 | 36 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 37 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 38 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 39 | n.pool4 = max_pool(n.relu4_3) 40 | 41 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 42 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 43 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 44 | n.pool5 = max_pool(n.relu5_3) 45 | 46 | # fully conv 47 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 48 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 49 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 50 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 51 | 52 | n.score_fr = L.Convolution(n.drop7, num_output=60, kernel_size=1, pad=0, 53 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 54 | n.upscore2 = L.Deconvolution(n.score_fr, 55 | convolution_param=dict(num_output=60, kernel_size=4, stride=2, 56 | bias_term=False), 57 | param=[dict(lr_mult=0)]) 58 | 59 | n.score_pool4 = L.Convolution(n.pool4, num_output=60, kernel_size=1, pad=0, 60 | 
param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 61 | n.score_pool4c = crop(n.score_pool4, n.upscore2) 62 | n.fuse_pool4 = L.Eltwise(n.upscore2, n.score_pool4c, 63 | operation=P.Eltwise.SUM) 64 | n.upscore_pool4 = L.Deconvolution(n.fuse_pool4, 65 | convolution_param=dict(num_output=60, kernel_size=4, stride=2, 66 | bias_term=False), 67 | param=[dict(lr_mult=0)]) 68 | 69 | n.score_pool3 = L.Convolution(n.pool3, num_output=60, kernel_size=1, pad=0, 70 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 71 | n.score_pool3c = crop(n.score_pool3, n.upscore_pool4) 72 | n.fuse_pool3 = L.Eltwise(n.upscore_pool4, n.score_pool3c, 73 | operation=P.Eltwise.SUM) 74 | n.upscore8 = L.Deconvolution(n.fuse_pool3, 75 | convolution_param=dict(num_output=60, kernel_size=16, stride=8, 76 | bias_term=False), 77 | param=[dict(lr_mult=0)]) 78 | 79 | n.score = crop(n.upscore8, n.data) 80 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 81 | loss_param=dict(normalize=False, ignore_label=255)) 82 | 83 | return n.to_proto() 84 | 85 | def make_net(): 86 | with open('train.prototxt', 'w') as f: 87 | f.write(str(fcn('train'))) 88 | 89 | with open('val.prototxt', 'w') as f: 90 | f.write(str(fcn('val'))) 91 | 92 | if __name__ == '__main__': 93 | make_net() 94 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn8s/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../pascalcontext-fcn16s/pascalcontext-fcn16s.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | val = np.loadtxt('../data/pascal/VOC2010/ImageSets/Main/val.txt', dtype=str) 29 | 30 | for _ in range(50): 31 | solver.step(8000) 32 | score.seg_tests(solver, False, val, layer='score') 33 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext-fcn8s/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "train.prototxt" 2 | test_net: "val.prototxt" 3 | test_iter: 5105 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-14 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 300000 16 | weight_decay: 0.0005 17 | snapshot: 4000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/pascalcontext_layers.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | 3 | import numpy as np 4 | from PIL import Image 5 | import scipy.io 6 | 7 | import random 8 | 9 | class PASCALContextSegDataLayer(caffe.Layer): 10 | """ 11 | Load (input image, label image) pairs from 
PASCAL-Context 12 | one-at-a-time while reshaping the net to preserve dimensions. 13 | 14 | The labels follow the 59 class task defined by 15 | 16 | R. Mottaghi, X. Chen, X. Liu, N.-G. Cho, S.-W. Lee, S. Fidler, R. 17 | Urtasun, and A. Yuille. The Role of Context for Object Detection and 18 | Semantic Segmentation in the Wild. CVPR 2014. 19 | 20 | Use this to feed data to a fully convolutional network. 21 | """ 22 | 23 | def setup(self, bottom, top): 24 | """ 25 | Setup data layer according to parameters: 26 | 27 | - voc_dir: path to PASCAL VOC dir (must contain 2010) 28 | - context_dir: path to PASCAL-Context annotations 29 | - split: train / val / test 30 | - randomize: load in random order (default: True) 31 | - seed: seed for randomization (default: None / current time) 32 | 33 | for PASCAL-Context semantic segmentation. 34 | 35 | example: params = dict(voc_dir="/path/to/PASCAL", split="val") 36 | """ 37 | # config 38 | params = eval(self.param_str) 39 | self.voc_dir = params['voc_dir'] + '/VOC2010' 40 | self.context_dir = params['context_dir'] 41 | self.split = params['split'] 42 | self.mean = np.array((104.007, 116.669, 122.679), dtype=np.float32) 43 | self.random = params.get('randomize', True) 44 | self.seed = params.get('seed', None) 45 | 46 | # load labels and resolve inconsistencies by mapping to full 400 labels 47 | self.labels_400 = [label.replace(' ','') for idx, label in np.genfromtxt(self.context_dir + '/labels.txt', delimiter=':', dtype=None)] 48 | self.labels_59 = [label.replace(' ','') for idx, label in np.genfromtxt(self.context_dir + '/59_labels.txt', delimiter=':', dtype=None)] 49 | for main_label, task_label in zip(('table', 'bedclothes', 'cloth'), ('diningtable', 'bedcloth', 'clothes')): 50 | self.labels_59[self.labels_59.index(task_label)] = main_label 51 | 52 | # two tops: data and label 53 | if len(top) != 2: 54 | raise Exception("Need to define two tops: data and label.") 55 | # data layers have no bottoms 56 | if len(bottom) != 0: 57 | raise Exception("Do not define a bottom.") 58 | 59 | # load indices for images and labels 60 | split_f = '{}/ImageSets/Main/{}.txt'.format(self.voc_dir, 61 | self.split) 62 | self.indices = open(split_f, 'r').read().splitlines() 63 | self.idx = 0 64 | 65 | # make eval deterministic 66 | if 'train' not in self.split: 67 | self.random = False 68 | 69 | # randomization: seed and pick 70 | if self.random: 71 | random.seed(self.seed) 72 | self.idx = random.randint(0, len(self.indices)-1) 73 | 74 | def reshape(self, bottom, top): 75 | # load image + label image pair 76 | self.data = self.load_image(self.indices[self.idx]) 77 | self.label = self.load_label(self.indices[self.idx]) 78 | # reshape tops to fit (leading 1 is for batch dimension) 79 | top[0].reshape(1, *self.data.shape) 80 | top[1].reshape(1, *self.label.shape) 81 | 82 | def forward(self, bottom, top): 83 | # assign output 84 | top[0].data[...] = self.data 85 | top[1].data[...] 
= self.label 86 | 87 | # pick next input 88 | if self.random: 89 | self.idx = random.randint(0, len(self.indices)-1) 90 | else: 91 | self.idx += 1 92 | if self.idx == len(self.indices): 93 | self.idx = 0 94 | 95 | def backward(self, top, propagate_down, bottom): 96 | pass 97 | 98 | def load_image(self, idx): 99 | """ 100 | Load input image and preprocess for Caffe: 101 | - cast to float 102 | - switch channels RGB -> BGR 103 | - subtract mean 104 | - transpose to channel x height x width order 105 | """ 106 | im = Image.open('{}/JPEGImages/{}.jpg'.format(self.voc_dir, idx)) 107 | in_ = np.array(im, dtype=np.float32) 108 | in_ = in_[:,:,::-1] 109 | in_ -= self.mean 110 | in_ = in_.transpose((2,0,1)) 111 | return in_ 112 | 113 | def load_label(self, idx): 114 | """ 115 | Load label image as 1 x height x width integer array of label indices. 116 | The leading singleton dimension is required by the loss. 117 | The full 400 labels are translated to the 59 class task labels. 118 | """ 119 | label_400 = scipy.io.loadmat('{}/trainval/{}.mat'.format(self.context_dir, idx))['LabelMap'] 120 | label = np.zeros_like(label_400, dtype=np.uint8) 121 | for idx, l in enumerate(self.labels_59): 122 | idx_400 = self.labels_400.index(l) + 1 123 | label[label_400 == idx_400] = idx + 1 124 | label = label[np.newaxis, ...] 125 | return label 126 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/score.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import caffe 3 | import numpy as np 4 | import os 5 | import sys 6 | from datetime import datetime 7 | from PIL import Image 8 | 9 | def fast_hist(a, b, n): 10 | k = (a >= 0) & (a < n) 11 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n**2).reshape(n, n) 12 | 13 | def compute_hist(net, save_dir, dataset, layer='score', gt='label'): 14 | n_cl = net.blobs[layer].channels 15 | if save_dir: 16 | os.mkdir(save_dir) 17 | hist = np.zeros((n_cl, n_cl)) 18 | loss = 0 19 | for idx in dataset: 20 | net.forward() 21 | hist += fast_hist(net.blobs[gt].data[0, 0].flatten(), 22 | net.blobs[layer].data[0].argmax(0).flatten(), 23 | n_cl) 24 | 25 | if save_dir: 26 | im = Image.fromarray(net.blobs[layer].data[0].argmax(0).astype(np.uint8), mode='P') 27 | im.save(os.path.join(save_dir, idx + '.png')) 28 | # compute the loss as well 29 | loss += net.blobs['loss'].data.flat[0] 30 | return hist, loss / len(dataset) 31 | 32 | def seg_tests(solver, save_format, dataset, layer='score', gt='label'): 33 | print '>>>', datetime.now(), 'Begin seg tests' 34 | solver.test_nets[0].share_with(solver.net) 35 | do_seg_tests(solver.test_nets[0], solver.iter, save_format, dataset, layer, gt) 36 | 37 | def do_seg_tests(net, iter, save_format, dataset, layer='score', gt='label'): 38 | n_cl = net.blobs[layer].channels 39 | if save_format: 40 | save_format = save_format.format(iter) 41 | hist, loss = compute_hist(net, save_format, dataset, layer, gt) 42 | # mean loss 43 | print '>>>', datetime.now(), 'Iteration', iter, 'loss', loss 44 | # overall accuracy 45 | acc = np.diag(hist).sum() / hist.sum() 46 | print '>>>', datetime.now(), 'Iteration', iter, 'overall accuracy', acc 47 | # per-class accuracy 48 | acc = np.diag(hist) / hist.sum(1) 49 | print '>>>', datetime.now(), 'Iteration', iter, 'mean accuracy', np.nanmean(acc) 50 | # per-class IU 51 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 52 | print '>>>', datetime.now(), 
'Iteration', iter, 'mean IU', np.nanmean(iu) 53 | freq = hist.sum(1) / hist.sum() 54 | print '>>>', datetime.now(), 'Iteration', iter, 'fwavacc', \ 55 | (freq[freq > 0] * iu[freq > 0]).sum() 56 | return hist 57 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn16s/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/siftflow-fcn16s-heavy.caffemodel 2 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn16s/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split): 15 | n = caffe.NetSpec() 16 | n.data, n.sem, n.geo = L.Python(module='siftflow_layers', 17 | layer='SIFTFlowSegDataLayer', ntop=3, 18 | param_str=str(dict(siftflow_dir='../data/sift-flow', 19 | split=split, seed=1337))) 20 | 21 | # the base net 22 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 23 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 24 | n.pool1 = max_pool(n.relu1_2) 25 | 26 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 27 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 28 | n.pool2 = max_pool(n.relu2_2) 29 | 30 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 31 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 32 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 33 | n.pool3 = max_pool(n.relu3_3) 34 | 35 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 36 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 37 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 38 | n.pool4 = max_pool(n.relu4_3) 39 | 40 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 41 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 42 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 43 | n.pool5 = max_pool(n.relu5_3) 44 | 45 | # fully conv 46 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 47 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 48 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 49 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 50 | 51 | n.score_fr_sem = L.Convolution(n.drop7, num_output=33, kernel_size=1, pad=0, 52 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 53 | n.upscore2_sem = L.Deconvolution(n.score_fr_sem, 54 | convolution_param=dict(num_output=33, kernel_size=4, stride=2, 55 | bias_term=False), 56 | param=[dict(lr_mult=0)]) 57 | 58 | n.score_pool4_sem = L.Convolution(n.pool4, num_output=33, kernel_size=1, pad=0, 59 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 60 | n.score_pool4_semc = crop(n.score_pool4_sem, n.upscore2_sem) 61 | n.fuse_pool4_sem = L.Eltwise(n.upscore2_sem, n.score_pool4_semc, 62 | operation=P.Eltwise.SUM) 63 | n.upscore16_sem = L.Deconvolution(n.fuse_pool4_sem, 64 | convolution_param=dict(num_output=33, kernel_size=32, stride=16, 65 | bias_term=False), 66 | param=[dict(lr_mult=0)]) 67 | 68 | 
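# --- editor's aside, not upstream code ---
# Every Deconvolution in these nets is declared with bias_term=False and
# lr_mult=0: the solve.py scripts freeze the layers as fixed bilinear
# interpolation by filling their kernels via surgery.interp, which calls
# surgery.upsample_filt. For the 4x4, stride-2 case above, upsample_filt(4)
# is the outer product of
#     [0.25, 0.75, 0.75, 0.25]
# with itself -- plain bilinear upsampling by a factor of 2.
# --- end of aside ---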
n.score_sem = crop(n.upscore16_sem, n.data) 69 | # loss to make score happy (o.w. loss_sem) 70 | n.loss = L.SoftmaxWithLoss(n.score_sem, n.sem, 71 | loss_param=dict(normalize=False, ignore_label=255)) 72 | 73 | n.score_fr_geo = L.Convolution(n.drop7, num_output=3, kernel_size=1, pad=0, 74 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 75 | 76 | n.upscore2_geo = L.Deconvolution(n.score_fr_geo, 77 | convolution_param=dict(num_output=3, kernel_size=4, stride=2, 78 | bias_term=False), 79 | param=[dict(lr_mult=0)]) 80 | 81 | n.score_pool4_geo = L.Convolution(n.pool4, num_output=3, kernel_size=1, pad=0, 82 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 83 | n.score_pool4_geoc = crop(n.score_pool4_geo, n.upscore2_geo) 84 | n.fuse_pool4_geo = L.Eltwise(n.upscore2_geo, n.score_pool4_geoc, 85 | operation=P.Eltwise.SUM) 86 | n.upscore16_geo = L.Deconvolution(n.fuse_pool4_geo, 87 | convolution_param=dict(num_output=3, kernel_size=32, stride=16, 88 | bias_term=False), 89 | param=[dict(lr_mult=0)]) 90 | 91 | n.score_geo = crop(n.upscore16_geo, n.data) 92 | n.loss_geo = L.SoftmaxWithLoss(n.score_geo, n.geo, 93 | loss_param=dict(normalize=False, ignore_label=255)) 94 | 95 | return n.to_proto() 96 | 97 | def make_net(): 98 | with open('trainval.prototxt', 'w') as f: 99 | f.write(str(fcn('trainval'))) 100 | 101 | with open('test.prototxt', 'w') as f: 102 | f.write(str(fcn('test'))) 103 | 104 | if __name__ == '__main__': 105 | make_net() 106 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn16s/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../siftflow-fcn32s/siftflow-fcn32s.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | test = np.loadtxt('../data/sift-flow/test.txt', dtype=str) 29 | 30 | for _ in range(50): 31 | solver.step(2000) 32 | # N.B. metrics on the semantic labels are off b.c. 
of missing classes; 33 | # score manually from the histogram instead for proper evaluation 34 | score.seg_tests(solver, False, test, layer='score_sem', gt='sem') 35 | score.seg_tests(solver, False, test, layer='score_geo', gt='geo') 36 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn16s/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "trainval.prototxt" 2 | test_net: "test.prototxt" 3 | test_iter: 200 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-12 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 300000 16 | weight_decay: 0.0005 17 | test_initialization: false 18 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn32s/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/siftflow-fcn32s-heavy.caffemodel 2 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn32s/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split): 15 | n = caffe.NetSpec() 16 | n.data, n.sem, n.geo = L.Python(module='siftflow_layers', 17 | layer='SIFTFlowSegDataLayer', ntop=3, 18 | param_str=str(dict(siftflow_dir='../data/sift-flow', 19 | split=split, seed=1337))) 20 | 21 | # the base net 22 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 23 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 24 | n.pool1 = max_pool(n.relu1_2) 25 | 26 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 27 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 28 | n.pool2 = max_pool(n.relu2_2) 29 | 30 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 31 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 32 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 33 | n.pool3 = max_pool(n.relu3_3) 34 | 35 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 36 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 37 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 38 | n.pool4 = max_pool(n.relu4_3) 39 | 40 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 41 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 42 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 43 | n.pool5 = max_pool(n.relu5_3) 44 | 45 | # fully conv 46 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 47 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 48 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 49 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 50 | 51 | n.score_fr_sem = L.Convolution(n.drop7, num_output=33, kernel_size=1, pad=0, 52 | param=[dict(lr_mult=1, decay_mult=1), 
dict(lr_mult=2, decay_mult=0)]) 53 | n.upscore_sem = L.Deconvolution(n.score_fr_sem, 54 | convolution_param=dict(num_output=33, kernel_size=64, stride=32, 55 | bias_term=False), 56 | param=[dict(lr_mult=0)]) 57 | n.score_sem = crop(n.upscore_sem, n.data) 58 | # loss to make score happy (o.w. loss_sem) 59 | n.loss = L.SoftmaxWithLoss(n.score_sem, n.sem, 60 | loss_param=dict(normalize=False, ignore_label=255)) 61 | 62 | n.score_fr_geo = L.Convolution(n.drop7, num_output=3, kernel_size=1, pad=0, 63 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 64 | n.upscore_geo = L.Deconvolution(n.score_fr_geo, 65 | convolution_param=dict(num_output=3, kernel_size=64, stride=32, 66 | bias_term=False), 67 | param=[dict(lr_mult=0)]) 68 | n.score_geo = crop(n.upscore_geo, n.data) 69 | n.loss_geo = L.SoftmaxWithLoss(n.score_geo, n.geo, 70 | loss_param=dict(normalize=False, ignore_label=255)) 71 | 72 | return n.to_proto() 73 | 74 | def make_net(): 75 | with open('trainval.prototxt', 'w') as f: 76 | f.write(str(fcn('trainval'))) 77 | 78 | with open('test.prototxt', 'w') as f: 79 | f.write(str(fcn('test'))) 80 | 81 | if __name__ == '__main__': 82 | make_net() 83 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn32s/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../ilsvrc-nets/vgg16-fcn.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | test = np.loadtxt('../data/sift-flow/test.txt', dtype=str) 29 | 30 | for _ in range(50): 31 | solver.step(2000) 32 | # N.B. metrics on the semantic labels are off b.c. 
of missing classes; 33 | # score manually from the histogram instead for proper evaluation 34 | score.seg_tests(solver, False, test, layer='score_sem', gt='sem') 35 | score.seg_tests(solver, False, test, layer='score_geo', gt='geo') 36 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn32s/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "trainval.prototxt" 2 | test_net: "test.prototxt" 3 | test_iter: 200 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-10 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 300000 16 | weight_decay: 0.0005 17 | test_initialization: false 18 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn8s/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/siftflow-fcn8s-heavy.caffemodel 2 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn8s/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split): 15 | n = caffe.NetSpec() 16 | n.data, n.sem, n.geo = L.Python(module='siftflow_layers', 17 | layer='SIFTFlowSegDataLayer', ntop=3, 18 | param_str=str(dict(siftflow_dir='../data/sift-flow', 19 | split=split, seed=1337))) 20 | 21 | # the base net 22 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 23 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 24 | n.pool1 = max_pool(n.relu1_2) 25 | 26 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 27 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 28 | n.pool2 = max_pool(n.relu2_2) 29 | 30 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 31 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 32 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 33 | n.pool3 = max_pool(n.relu3_3) 34 | 35 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 36 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 37 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 38 | n.pool4 = max_pool(n.relu4_3) 39 | 40 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 41 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 42 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 43 | n.pool5 = max_pool(n.relu5_3) 44 | 45 | # fully conv 46 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 47 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 48 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 49 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 50 | 51 | n.score_fr_sem = L.Convolution(n.drop7, num_output=33, kernel_size=1, pad=0, 52 | param=[dict(lr_mult=1, decay_mult=1), 
dict(lr_mult=2, decay_mult=0)]) 53 | n.upscore2_sem = L.Deconvolution(n.score_fr_sem, 54 | convolution_param=dict(num_output=33, kernel_size=4, stride=2, 55 | bias_term=False), 56 | param=[dict(lr_mult=0)]) 57 | 58 | n.score_pool4_sem = L.Convolution(n.pool4, num_output=33, kernel_size=1, pad=0, 59 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 60 | n.score_pool4_semc = crop(n.score_pool4_sem, n.upscore2_sem) 61 | n.fuse_pool4_sem = L.Eltwise(n.upscore2_sem, n.score_pool4_semc, 62 | operation=P.Eltwise.SUM) 63 | n.upscore_pool4_sem = L.Deconvolution(n.fuse_pool4_sem, 64 | convolution_param=dict(num_output=33, kernel_size=4, stride=2, 65 | bias_term=False), 66 | param=[dict(lr_mult=0)]) 67 | 68 | n.score_pool3_sem = L.Convolution(n.pool3, num_output=33, kernel_size=1, 69 | pad=0, param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, 70 | decay_mult=0)]) 71 | n.score_pool3_semc = crop(n.score_pool3_sem, n.upscore_pool4_sem) 72 | n.fuse_pool3_sem = L.Eltwise(n.upscore_pool4_sem, n.score_pool3_semc, 73 | operation=P.Eltwise.SUM) 74 | n.upscore8_sem = L.Deconvolution(n.fuse_pool3_sem, 75 | convolution_param=dict(num_output=33, kernel_size=16, stride=8, 76 | bias_term=False), 77 | param=[dict(lr_mult=0)]) 78 | 79 | n.score_sem = crop(n.upscore8_sem, n.data) 80 | # loss to make score happy (o.w. loss_sem) 81 | n.loss = L.SoftmaxWithLoss(n.score_sem, n.sem, 82 | loss_param=dict(normalize=False, ignore_label=255)) 83 | 84 | n.score_fr_geo = L.Convolution(n.drop7, num_output=3, kernel_size=1, pad=0, 85 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 86 | 87 | n.upscore2_geo = L.Deconvolution(n.score_fr_geo, 88 | convolution_param=dict(num_output=3, kernel_size=4, stride=2, 89 | bias_term=False), 90 | param=[dict(lr_mult=0)]) 91 | 92 | n.score_pool4_geo = L.Convolution(n.pool4, num_output=3, kernel_size=1, pad=0, 93 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 94 | n.score_pool4_geoc = crop(n.score_pool4_geo, n.upscore2_geo) 95 | n.fuse_pool4_geo = L.Eltwise(n.upscore2_geo, n.score_pool4_geoc, 96 | operation=P.Eltwise.SUM) 97 | n.upscore_pool4_geo = L.Deconvolution(n.fuse_pool4_geo, 98 | convolution_param=dict(num_output=3, kernel_size=4, stride=2, 99 | bias_term=False), 100 | param=[dict(lr_mult=0)]) 101 | 102 | n.score_pool3_geo = L.Convolution(n.pool3, num_output=3, kernel_size=1, 103 | pad=0, param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, 104 | decay_mult=0)]) 105 | n.score_pool3_geoc = crop(n.score_pool3_geo, n.upscore_pool4_geo) 106 | n.fuse_pool3_geo = L.Eltwise(n.upscore_pool4_geo, n.score_pool3_geoc, 107 | operation=P.Eltwise.SUM) 108 | n.upscore8_geo = L.Deconvolution(n.fuse_pool3_geo, 109 | convolution_param=dict(num_output=3, kernel_size=16, stride=8, 110 | bias_term=False), 111 | param=[dict(lr_mult=0)]) 112 | 113 | n.score_geo = crop(n.upscore8_geo, n.data) 114 | n.loss_geo = L.SoftmaxWithLoss(n.score_geo, n.geo, 115 | loss_param=dict(normalize=False, ignore_label=255)) 116 | 117 | return n.to_proto() 118 | 119 | def make_net(): 120 | with open('trainval.prototxt', 'w') as f: 121 | f.write(str(fcn('trainval'))) 122 | 123 | with open('test.prototxt', 'w') as f: 124 | f.write(str(fcn('test'))) 125 | 126 | if __name__ == '__main__': 127 | make_net() 128 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn8s/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import 
surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../siftflow-fcn16s/siftflow-fcn16s.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | test = np.loadtxt('../data/sift-flow/test.txt', dtype=str) 29 | 30 | for _ in range(50): 31 | solver.step(2000) 32 | # N.B. metrics on the semantic labels are off b.c. of missing classes; 33 | # score manually from the histogram instead for proper evaluation 34 | score.seg_tests(solver, False, test, layer='score_sem', gt='sem') 35 | score.seg_tests(solver, False, test, layer='score_geo', gt='geo') 36 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow-fcn8s/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "trainval.prototxt" 2 | test_net: "test.prototxt" 3 | test_iter: 200 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-12 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 300000 16 | weight_decay: 0.0005 17 | test_initialization: false 18 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/siftflow_layers.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | 3 | import numpy as np 4 | from PIL import Image 5 | import scipy.io 6 | 7 | import random 8 | 9 | class SIFTFlowSegDataLayer(caffe.Layer): 10 | """ 11 | Load (input image, label image) pairs from SIFT Flow 12 | one-at-a-time while reshaping the net to preserve dimensions. 13 | 14 | This data layer has three tops: 15 | 16 | 1. the data, pre-processed 17 | 2. the semantic labels 0-32 and void 255 18 | 3. the geometric labels 0-2 and void 255 19 | 20 | Use this to feed data to a fully convolutional network. 21 | """ 22 | 23 | def setup(self, bottom, top): 24 | """ 25 | Setup data layer according to parameters: 26 | 27 | - siftflow_dir: path to SIFT Flow dir 28 | - split: train / val / test 29 | - randomize: load in random order (default: True) 30 | - seed: seed for randomization (default: None / current time) 31 | 32 | for semantic segmentation of object and geometric classes. 
33 | 34 | example: params = dict(siftflow_dir="/path/to/siftflow", split="val") 35 | """ 36 | # config 37 | params = eval(self.param_str) 38 | self.siftflow_dir = params['siftflow_dir'] 39 | self.split = params['split'] 40 | self.mean = np.array((114.578, 115.294, 108.353), dtype=np.float32) 41 | self.random = params.get('randomize', True) 42 | self.seed = params.get('seed', None) 43 | 44 | # three tops: data, semantic, geometric 45 | if len(top) != 3: 46 | raise Exception("Need to define three tops: data, semantic label, and geometric label.") 47 | # data layers have no bottoms 48 | if len(bottom) != 0: 49 | raise Exception("Do not define a bottom.") 50 | 51 | # load indices for images and labels 52 | split_f = '{}/{}.txt'.format(self.siftflow_dir, self.split) 53 | self.indices = open(split_f, 'r').read().splitlines() 54 | self.idx = 0 55 | 56 | # make eval deterministic 57 | if 'train' not in self.split: 58 | self.random = False 59 | 60 | # randomization: seed and pick 61 | if self.random: 62 | random.seed(self.seed) 63 | self.idx = random.randint(0, len(self.indices)-1) 64 | 65 | def reshape(self, bottom, top): 66 | # load image + label image pair 67 | self.data = self.load_image(self.indices[self.idx]) 68 | self.label_semantic = self.load_label(self.indices[self.idx], label_type='semantic') 69 | self.label_geometric = self.load_label(self.indices[self.idx], label_type='geometric') 70 | # reshape tops to fit (leading 1 is for batch dimension) 71 | top[0].reshape(1, *self.data.shape) 72 | top[1].reshape(1, *self.label_semantic.shape) 73 | top[2].reshape(1, *self.label_geometric.shape) 74 | 75 | def forward(self, bottom, top): 76 | # assign output 77 | top[0].data[...] = self.data 78 | top[1].data[...] = self.label_semantic 79 | top[2].data[...] = self.label_geometric 80 | 81 | # pick next input 82 | if self.random: 83 | self.idx = random.randint(0, len(self.indices)-1) 84 | else: 85 | self.idx += 1 86 | if self.idx == len(self.indices): 87 | self.idx = 0 88 | 89 | def backward(self, top, propagate_down, bottom): 90 | pass 91 | 92 | def load_image(self, idx): 93 | """ 94 | Load input image and preprocess for Caffe: 95 | - cast to float 96 | - switch channels RGB -> BGR 97 | - subtract mean 98 | - transpose to channel x height x width order 99 | """ 100 | im = Image.open('{}/Images/spatial_envelope_256x256_static_8outdoorcategories/{}.jpg'.format(self.siftflow_dir, idx)) 101 | in_ = np.array(im, dtype=np.float32) 102 | in_ = in_[:,:,::-1] 103 | in_ -= self.mean 104 | in_ = in_.transpose((2,0,1)) 105 | return in_ 106 | 107 | def load_label(self, idx, label_type=None): 108 | """ 109 | Load label image as 1 x height x width integer array of label indices. 110 | The leading singleton dimension is required by the loss. 111 | """ 112 | if label_type == 'semantic': 113 | label = scipy.io.loadmat('{}/SemanticLabels/spatial_envelope_256x256_static_8outdoorcategories/{}.mat'.format(self.siftflow_dir, idx))['S'] 114 | elif label_type == 'geometric': 115 | label = scipy.io.loadmat('{}/GeoLabels/spatial_envelope_256x256_static_8outdoorcategories/{}.mat'.format(self.siftflow_dir, idx))['S'] 116 | label[label == -1] = 0 117 | else: 118 | raise Exception("Unknown label type: {}. Pick semantic or geometric.".format(label_type)) 119 | label = label.astype(np.uint8) 120 | label -= 1 # rotate labels so classes start at 0, void is 255 121 | label = label[np.newaxis, ...] 
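# --- editor's aside, not upstream code ---
# On a uint8 array the in-place `label -= 1` above wraps the void value 0
# around to 255, e.g.
#     >>> np.array([0, 1, 33], dtype=np.uint8) - 1
#     array([255,   0,  32], dtype=uint8)
# which is exactly the ignore_label=255 that the SoftmaxWithLoss layers skip.
# --- end of aside ---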
122 | return label.copy() 123 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/surgery.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import caffe 3 | import numpy as np 4 | 5 | def transplant(new_net, net, suffix=''): 6 | """ 7 | Transfer weights by copying matching parameters, coercing parameters of 8 | incompatible shape, and dropping unmatched parameters. 9 | 10 | The coercion is useful to convert fully connected layers to their 11 | equivalent convolutional layers, since the weights are the same and only 12 | the shapes are different. In particular, equivalent fully connected and 13 | convolution layers have shapes O x I and O x I x H x W respectively for O 14 | output channels, I input channels, H kernel height, and W kernel width. 15 | 16 | Both the `net` and `new_net` arguments must be instantiated `caffe.Net`s. 17 | """ 18 | for p in net.params: 19 | p_new = p + suffix 20 | if p_new not in new_net.params: 21 | print 'dropping', p 22 | continue 23 | for i in range(len(net.params[p])): 24 | if i > (len(new_net.params[p_new]) - 1): 25 | print 'dropping', p, i 26 | break 27 | if net.params[p][i].data.shape != new_net.params[p_new][i].data.shape: 28 | print 'coercing', p, i, 'from', net.params[p][i].data.shape, 'to', new_net.params[p_new][i].data.shape 29 | else: 30 | print 'copying', p, ' -> ', p_new, i 31 | new_net.params[p_new][i].data.flat = net.params[p][i].data.flat 32 | 33 | def upsample_filt(size): 34 | """ 35 | Make a 2D bilinear kernel suitable for upsampling of the given (h, w) size. 36 | """ 37 | factor = (size + 1) // 2 38 | if size % 2 == 1: 39 | center = factor - 1 40 | else: 41 | center = factor - 0.5 42 | og = np.ogrid[:size, :size] 43 | return (1 - abs(og[0] - center) / factor) * \ 44 | (1 - abs(og[1] - center) / factor) 45 | 46 | def interp(net, layers): 47 | """ 48 | Set weights of each layer in layers to bilinear kernels for interpolation. 49 | """ 50 | for l in layers: 51 | m, k, h, w = net.params[l][0].data.shape 52 | if m != k and k != 1: 53 | raise ValueError( 54 | 'input + output channels need to be the same or |output| == 1') 55 | if h != w: 56 | raise ValueError( 57 | 'filters need to be square') 58 | filt = upsample_filt(h) 59 | net.params[l][0].data[range(m), range(k), :, :] = filt 60 | 61 | def expand_score(new_net, new_layer, net, layer): 62 | """ 63 | Transplant an old score layer's parameters, with k < k' classes, into a new 64 | score layer with k' classes s.t. the first k are the old classes. 65 | """ 66 | old_cl = net.params[layer][0].num 67 | new_net.params[new_layer][0].data[:old_cl][...] = net.params[layer][0].data 68 | new_net.params[new_layer][1].data[0,0,0,:old_cl][...]
= net.params[layer][1].data 69 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn-alexnet/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/fcn-alexnet-pascal.caffemodel 2 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn-alexnet/net.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../../python') 3 | 4 | import caffe 5 | from caffe import layers as L, params as P 6 | from caffe.coord_map import crop 7 | 8 | def conv_relu(bottom, ks, nout, stride=1, pad=0, group=1): 9 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 10 | num_output=nout, pad=pad, group=group) 11 | return conv, L.ReLU(conv, in_place=True) 12 | 13 | def max_pool(bottom, ks, stride=1): 14 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 15 | 16 | def fcn(split): 17 | n = caffe.NetSpec() 18 | pydata_params = dict(split=split, mean=(104.00699, 116.66877, 122.67892), 19 | seed=1337) 20 | if split == 'train': 21 | pydata_params['sbdd_dir'] = '../data/sbdd/dataset' 22 | pylayer = 'SBDDSegDataLayer' 23 | else: 24 | pydata_params['voc_dir'] = '../data/pascal/VOC2011' 25 | pylayer = 'VOCSegDataLayer' 26 | n.data, n.label = L.Python(module='voc_layers', layer=pylayer, 27 | ntop=2, param_str=str(pydata_params)) 28 | 29 | # the base net 30 | n.conv1, n.relu1 = conv_relu(n.data, 11, 96, stride=4, pad=100) 31 | n.pool1 = max_pool(n.relu1, 3, stride=2) 32 | n.norm1 = L.LRN(n.pool1, local_size=5, alpha=1e-4, beta=0.75) 33 | n.conv2, n.relu2 = conv_relu(n.norm1, 5, 256, pad=2, group=2) 34 | n.pool2 = max_pool(n.relu2, 3, stride=2) 35 | n.norm2 = L.LRN(n.pool2, local_size=5, alpha=1e-4, beta=0.75) 36 | n.conv3, n.relu3 = conv_relu(n.norm2, 3, 384, pad=1) 37 | n.conv4, n.relu4 = conv_relu(n.relu3, 3, 384, pad=1, group=2) 38 | n.conv5, n.relu5 = conv_relu(n.relu4, 3, 256, pad=1, group=2) 39 | n.pool5 = max_pool(n.relu5, 3, stride=2) 40 | 41 | # fully conv 42 | n.fc6, n.relu6 = conv_relu(n.pool5, 6, 4096) 43 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 44 | n.fc7, n.relu7 = conv_relu(n.drop6, 1, 4096) 45 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 46 | 47 | n.score_fr = L.Convolution(n.drop7, num_output=21, kernel_size=1, pad=0, 48 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 49 | n.upscore = L.Deconvolution(n.score_fr, 50 | convolution_param=dict(num_output=21, kernel_size=63, stride=32, 51 | bias_term=False), 52 | param=[dict(lr_mult=0)]) 53 | n.score = crop(n.upscore, n.data) 54 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 55 | loss_param=dict(normalize=True, ignore_label=255)) 56 | 57 | return n.to_proto() 58 | 59 | def make_net(): 60 | with open('train.prototxt', 'w') as f: 61 | f.write(str(fcn('train'))) 62 | 63 | with open('val.prototxt', 'w') as f: 64 | f.write(str(fcn('seg11valid'))) 65 | 66 | if __name__ == '__main__': 67 | make_net() 68 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn-alexnet/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | 
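# setproctitle is an optional dependency: when available, the process is
# renamed after the current directory (e.g. "voc-fcn-alexnet") so concurrent
# runs are easy to tell apart in ps/top; the except below simply skips this
# nicety when the package is missing.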
setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../ilsvrc-nets/alexnet-fcn.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | val = np.loadtxt('../data/segvalid11.txt', dtype=str) 29 | 30 | for _ in range(25): 31 | solver.step(4000) 32 | score.seg_tests(solver, False, val, layer='score') 33 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn-alexnet/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "train.prototxt" 2 | test_net: "val.prototxt" 3 | test_iter: 736 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for normalized softmax 10 | base_lr: 1e-4 11 | # standard momentum 12 | momentum: 0.9 13 | # gradient accumulation 14 | iter_size: 20 15 | max_iter: 100000 16 | weight_decay: 0.0005 17 | snapshot: 4000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn-alexnet/train.prototxt: -------------------------------------------------------------------------------- 1 | layer { 2 | name: "data" 3 | type: "Python" 4 | top: "data" 5 | top: "label" 6 | python_param { 7 | module: "voc_layers" 8 | layer: "SBDDSegDataLayer" 9 | param_str: "{\'sbdd_dir\': \'../data/sbdd/dataset\', \'seed\': 1337, \'split\': \'train\', \'mean\': (104.00699, 116.66877, 122.67892)}" 10 | } 11 | } 12 | layer { 13 | name: "conv1" 14 | type: "Convolution" 15 | bottom: "data" 16 | top: "conv1" 17 | convolution_param { 18 | num_output: 96 19 | pad: 100 20 | kernel_size: 11 21 | group: 1 22 | stride: 4 23 | } 24 | } 25 | layer { 26 | name: "relu1" 27 | type: "ReLU" 28 | bottom: "conv1" 29 | top: "conv1" 30 | } 31 | layer { 32 | name: "pool1" 33 | type: "Pooling" 34 | bottom: "conv1" 35 | top: "pool1" 36 | pooling_param { 37 | pool: MAX 38 | kernel_size: 3 39 | stride: 2 40 | } 41 | } 42 | layer { 43 | name: "norm1" 44 | type: "LRN" 45 | bottom: "pool1" 46 | top: "norm1" 47 | lrn_param { 48 | local_size: 5 49 | alpha: 0.0001 50 | beta: 0.75 51 | } 52 | } 53 | layer { 54 | name: "conv2" 55 | type: "Convolution" 56 | bottom: "norm1" 57 | top: "conv2" 58 | convolution_param { 59 | num_output: 256 60 | pad: 2 61 | kernel_size: 5 62 | group: 2 63 | stride: 1 64 | } 65 | } 66 | layer { 67 | name: "relu2" 68 | type: "ReLU" 69 | bottom: "conv2" 70 | top: "conv2" 71 | } 72 | layer { 73 | name: "pool2" 74 | type: "Pooling" 75 | bottom: "conv2" 76 | top: "pool2" 77 | pooling_param { 78 | pool: MAX 79 | kernel_size: 3 80 | stride: 2 81 | } 82 | } 83 | layer { 84 | name: "norm2" 85 | type: "LRN" 86 | bottom: "pool2" 87 | top: "norm2" 88 | lrn_param { 89 | local_size: 5 90 | alpha: 0.0001 91 | beta: 0.75 92 | } 93 | } 94 | layer { 95 | name: "conv3" 96 | type: "Convolution" 97 | bottom: "norm2" 98 | top: "conv3" 99 | convolution_param { 100 | num_output: 384 101 | pad: 1 102 | kernel_size: 3 103 | group: 1 104 | stride: 1 105 | } 106 | } 107 | layer { 108 | name: "relu3" 109 | type: "ReLU" 110 | 
bottom: "conv3" 111 | top: "conv3" 112 | } 113 | layer { 114 | name: "conv4" 115 | type: "Convolution" 116 | bottom: "conv3" 117 | top: "conv4" 118 | convolution_param { 119 | num_output: 384 120 | pad: 1 121 | kernel_size: 3 122 | group: 2 123 | stride: 1 124 | } 125 | } 126 | layer { 127 | name: "relu4" 128 | type: "ReLU" 129 | bottom: "conv4" 130 | top: "conv4" 131 | } 132 | layer { 133 | name: "conv5" 134 | type: "Convolution" 135 | bottom: "conv4" 136 | top: "conv5" 137 | convolution_param { 138 | num_output: 256 139 | pad: 1 140 | kernel_size: 3 141 | group: 2 142 | stride: 1 143 | } 144 | } 145 | layer { 146 | name: "relu5" 147 | type: "ReLU" 148 | bottom: "conv5" 149 | top: "conv5" 150 | } 151 | layer { 152 | name: "pool5" 153 | type: "Pooling" 154 | bottom: "conv5" 155 | top: "pool5" 156 | pooling_param { 157 | pool: MAX 158 | kernel_size: 3 159 | stride: 2 160 | } 161 | } 162 | layer { 163 | name: "fc6" 164 | type: "Convolution" 165 | bottom: "pool5" 166 | top: "fc6" 167 | convolution_param { 168 | num_output: 4096 169 | pad: 0 170 | kernel_size: 6 171 | group: 1 172 | stride: 1 173 | } 174 | } 175 | layer { 176 | name: "relu6" 177 | type: "ReLU" 178 | bottom: "fc6" 179 | top: "fc6" 180 | } 181 | layer { 182 | name: "drop6" 183 | type: "Dropout" 184 | bottom: "fc6" 185 | top: "fc6" 186 | dropout_param { 187 | dropout_ratio: 0.5 188 | } 189 | } 190 | layer { 191 | name: "fc7" 192 | type: "Convolution" 193 | bottom: "fc6" 194 | top: "fc7" 195 | convolution_param { 196 | num_output: 4096 197 | pad: 0 198 | kernel_size: 1 199 | group: 1 200 | stride: 1 201 | } 202 | } 203 | layer { 204 | name: "relu7" 205 | type: "ReLU" 206 | bottom: "fc7" 207 | top: "fc7" 208 | } 209 | layer { 210 | name: "drop7" 211 | type: "Dropout" 212 | bottom: "fc7" 213 | top: "fc7" 214 | dropout_param { 215 | dropout_ratio: 0.5 216 | } 217 | } 218 | layer { 219 | name: "score_fr" 220 | type: "Convolution" 221 | bottom: "fc7" 222 | top: "score_fr" 223 | param { 224 | lr_mult: 1 225 | decay_mult: 1 226 | } 227 | param { 228 | lr_mult: 2 229 | decay_mult: 0 230 | } 231 | convolution_param { 232 | num_output: 21 233 | pad: 0 234 | kernel_size: 1 235 | } 236 | } 237 | layer { 238 | name: "upscore" 239 | type: "Deconvolution" 240 | bottom: "score_fr" 241 | top: "upscore" 242 | param { 243 | lr_mult: 0 244 | } 245 | convolution_param { 246 | num_output: 21 247 | bias_term: false 248 | kernel_size: 63 249 | stride: 32 250 | } 251 | } 252 | layer { 253 | name: "score" 254 | type: "Crop" 255 | bottom: "upscore" 256 | bottom: "data" 257 | top: "score" 258 | crop_param { 259 | axis: 2 260 | offset: 18 261 | } 262 | } 263 | layer { 264 | name: "loss" 265 | type: "SoftmaxWithLoss" 266 | bottom: "score" 267 | bottom: "label" 268 | top: "loss" 269 | loss_param { 270 | ignore_label: 255 271 | normalize: true 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn-alexnet/val.prototxt: -------------------------------------------------------------------------------- 1 | layer { 2 | name: "data" 3 | type: "Python" 4 | top: "data" 5 | top: "label" 6 | python_param { 7 | module: "voc_layers" 8 | layer: "VOCSegDataLayer" 9 | param_str: "{\'voc_dir\': \'../data/pascal/VOC2011\', \'seed\': 1337, \'split\': \'seg11valid\', \'mean\': (104.00699, 116.66877, 122.67892)}" 10 | } 11 | } 12 | layer { 13 | name: "conv1" 14 | type: "Convolution" 15 | bottom: "data" 16 | top: "conv1" 17 | convolution_param { 18 | num_output: 96 19 | pad: 100 20 | 
kernel_size: 11 21 | group: 1 22 | stride: 4 23 | } 24 | } 25 | layer { 26 | name: "relu1" 27 | type: "ReLU" 28 | bottom: "conv1" 29 | top: "conv1" 30 | } 31 | layer { 32 | name: "pool1" 33 | type: "Pooling" 34 | bottom: "conv1" 35 | top: "pool1" 36 | pooling_param { 37 | pool: MAX 38 | kernel_size: 3 39 | stride: 2 40 | } 41 | } 42 | layer { 43 | name: "norm1" 44 | type: "LRN" 45 | bottom: "pool1" 46 | top: "norm1" 47 | lrn_param { 48 | local_size: 5 49 | alpha: 0.0001 50 | beta: 0.75 51 | } 52 | } 53 | layer { 54 | name: "conv2" 55 | type: "Convolution" 56 | bottom: "norm1" 57 | top: "conv2" 58 | convolution_param { 59 | num_output: 256 60 | pad: 2 61 | kernel_size: 5 62 | group: 2 63 | stride: 1 64 | } 65 | } 66 | layer { 67 | name: "relu2" 68 | type: "ReLU" 69 | bottom: "conv2" 70 | top: "conv2" 71 | } 72 | layer { 73 | name: "pool2" 74 | type: "Pooling" 75 | bottom: "conv2" 76 | top: "pool2" 77 | pooling_param { 78 | pool: MAX 79 | kernel_size: 3 80 | stride: 2 81 | } 82 | } 83 | layer { 84 | name: "norm2" 85 | type: "LRN" 86 | bottom: "pool2" 87 | top: "norm2" 88 | lrn_param { 89 | local_size: 5 90 | alpha: 0.0001 91 | beta: 0.75 92 | } 93 | } 94 | layer { 95 | name: "conv3" 96 | type: "Convolution" 97 | bottom: "norm2" 98 | top: "conv3" 99 | convolution_param { 100 | num_output: 384 101 | pad: 1 102 | kernel_size: 3 103 | group: 1 104 | stride: 1 105 | } 106 | } 107 | layer { 108 | name: "relu3" 109 | type: "ReLU" 110 | bottom: "conv3" 111 | top: "conv3" 112 | } 113 | layer { 114 | name: "conv4" 115 | type: "Convolution" 116 | bottom: "conv3" 117 | top: "conv4" 118 | convolution_param { 119 | num_output: 384 120 | pad: 1 121 | kernel_size: 3 122 | group: 2 123 | stride: 1 124 | } 125 | } 126 | layer { 127 | name: "relu4" 128 | type: "ReLU" 129 | bottom: "conv4" 130 | top: "conv4" 131 | } 132 | layer { 133 | name: "conv5" 134 | type: "Convolution" 135 | bottom: "conv4" 136 | top: "conv5" 137 | convolution_param { 138 | num_output: 256 139 | pad: 1 140 | kernel_size: 3 141 | group: 2 142 | stride: 1 143 | } 144 | } 145 | layer { 146 | name: "relu5" 147 | type: "ReLU" 148 | bottom: "conv5" 149 | top: "conv5" 150 | } 151 | layer { 152 | name: "pool5" 153 | type: "Pooling" 154 | bottom: "conv5" 155 | top: "pool5" 156 | pooling_param { 157 | pool: MAX 158 | kernel_size: 3 159 | stride: 2 160 | } 161 | } 162 | layer { 163 | name: "fc6" 164 | type: "Convolution" 165 | bottom: "pool5" 166 | top: "fc6" 167 | convolution_param { 168 | num_output: 4096 169 | pad: 0 170 | kernel_size: 6 171 | group: 1 172 | stride: 1 173 | } 174 | } 175 | layer { 176 | name: "relu6" 177 | type: "ReLU" 178 | bottom: "fc6" 179 | top: "fc6" 180 | } 181 | layer { 182 | name: "drop6" 183 | type: "Dropout" 184 | bottom: "fc6" 185 | top: "fc6" 186 | dropout_param { 187 | dropout_ratio: 0.5 188 | } 189 | } 190 | layer { 191 | name: "fc7" 192 | type: "Convolution" 193 | bottom: "fc6" 194 | top: "fc7" 195 | convolution_param { 196 | num_output: 4096 197 | pad: 0 198 | kernel_size: 1 199 | group: 1 200 | stride: 1 201 | } 202 | } 203 | layer { 204 | name: "relu7" 205 | type: "ReLU" 206 | bottom: "fc7" 207 | top: "fc7" 208 | } 209 | layer { 210 | name: "drop7" 211 | type: "Dropout" 212 | bottom: "fc7" 213 | top: "fc7" 214 | dropout_param { 215 | dropout_ratio: 0.5 216 | } 217 | } 218 | layer { 219 | name: "score_fr" 220 | type: "Convolution" 221 | bottom: "fc7" 222 | top: "score_fr" 223 | param { 224 | lr_mult: 1 225 | decay_mult: 1 226 | } 227 | param { 228 | lr_mult: 2 229 | decay_mult: 0 230 | } 231 | convolution_param { 
232 | num_output: 21 233 | pad: 0 234 | kernel_size: 1 235 | } 236 | } 237 | layer { 238 | name: "upscore" 239 | type: "Deconvolution" 240 | bottom: "score_fr" 241 | top: "upscore" 242 | param { 243 | lr_mult: 0 244 | } 245 | convolution_param { 246 | num_output: 21 247 | bias_term: false 248 | kernel_size: 63 249 | stride: 32 250 | } 251 | } 252 | layer { 253 | name: "score" 254 | type: "Crop" 255 | bottom: "upscore" 256 | bottom: "data" 257 | top: "score" 258 | crop_param { 259 | axis: 2 260 | offset: 18 261 | } 262 | } 263 | layer { 264 | name: "loss" 265 | type: "SoftmaxWithLoss" 266 | bottom: "score" 267 | bottom: "label" 268 | top: "loss" 269 | loss_param { 270 | ignore_label: 255 271 | normalize: true 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn16s/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/fcn16s-heavy-pascal.caffemodel -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn16s/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split): 15 | n = caffe.NetSpec() 16 | pydata_params = dict(split=split, mean=(104.00699, 116.66877, 122.67892), 17 | seed=1337) 18 | if split == 'train': 19 | pydata_params['sbdd_dir'] = '../../data/sbdd/dataset' 20 | pylayer = 'SBDDSegDataLayer' 21 | else: 22 | pydata_params['voc_dir'] = '../../data/pascal/VOC2011' 23 | pylayer = 'VOCSegDataLayer' 24 | n.data, n.label = L.Python(module='voc_layers', layer=pylayer, 25 | ntop=2, param_str=str(pydata_params)) 26 | 27 | # the base net 28 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 29 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 30 | n.pool1 = max_pool(n.relu1_2) 31 | 32 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 33 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 34 | n.pool2 = max_pool(n.relu2_2) 35 | 36 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 37 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 38 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 39 | n.pool3 = max_pool(n.relu3_3) 40 | 41 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 42 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 43 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 44 | n.pool4 = max_pool(n.relu4_3) 45 | 46 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 47 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 48 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 49 | n.pool5 = max_pool(n.relu5_3) 50 | 51 | # fully conv 52 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 53 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 54 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 55 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 56 | n.score_fr = L.Convolution(n.drop7, num_output=21, kernel_size=1, pad=0, 57 | param=[dict(lr_mult=1, decay_mult=1), 
dict(lr_mult=2, decay_mult=0)]) 58 | n.upscore2 = L.Deconvolution(n.score_fr, 59 | convolution_param=dict(num_output=21, kernel_size=4, stride=2, 60 | bias_term=False), 61 | param=[dict(lr_mult=0)]) 62 | 63 | n.score_pool4 = L.Convolution(n.pool4, num_output=21, kernel_size=1, pad=0, 64 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 65 | n.score_pool4c = crop(n.score_pool4, n.upscore2) 66 | n.fuse_pool4 = L.Eltwise(n.upscore2, n.score_pool4c, 67 | operation=P.Eltwise.SUM) 68 | n.upscore16 = L.Deconvolution(n.fuse_pool4, 69 | convolution_param=dict(num_output=21, kernel_size=32, stride=16, 70 | bias_term=False), 71 | param=[dict(lr_mult=0)]) 72 | 73 | n.score = crop(n.upscore16, n.data) 74 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 75 | loss_param=dict(normalize=False, ignore_label=255)) 76 | 77 | return n.to_proto() 78 | 79 | def make_net(): 80 | with open('train.prototxt', 'w') as f: 81 | f.write(str(fcn('train'))) 82 | 83 | with open('val.prototxt', 'w') as f: 84 | f.write(str(fcn('seg11valid'))) 85 | 86 | if __name__ == '__main__': 87 | make_net() 88 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn16s/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../voc-fcn32s/voc-fcn32s.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | val = np.loadtxt('../data/segvalid11.txt', dtype=str) 29 | 30 | for _ in range(25): 31 | solver.step(4000) 32 | score.seg_tests(solver, False, val, layer='score') 33 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn16s/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "train.prototxt" 2 | test_net: "val.prototxt" 3 | test_iter: 736 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-12 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 100000 16 | weight_decay: 0.0005 17 | snapshot: 4000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn32s/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/fcn32s-heavy-pascal.caffemodel -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn32s/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = 
L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split): 15 | n = caffe.NetSpec() 16 | pydata_params = dict(split=split, mean=(104.00699, 116.66877, 122.67892), 17 | seed=1337) 18 | if split == 'train': 19 | pydata_params['sbdd_dir'] = '../data/sbdd/dataset' 20 | pylayer = 'SBDDSegDataLayer' 21 | else: 22 | pydata_params['voc_dir'] = '../data/pascal/VOC2011' 23 | pylayer = 'VOCSegDataLayer' 24 | n.data, n.label = L.Python(module='voc_layers', layer=pylayer, 25 | ntop=2, param_str=str(pydata_params)) 26 | 27 | # the base net 28 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 29 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 30 | n.pool1 = max_pool(n.relu1_2) 31 | 32 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 33 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 34 | n.pool2 = max_pool(n.relu2_2) 35 | 36 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 37 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 38 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 39 | n.pool3 = max_pool(n.relu3_3) 40 | 41 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 42 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 43 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 44 | n.pool4 = max_pool(n.relu4_3) 45 | 46 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 47 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 48 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 49 | n.pool5 = max_pool(n.relu5_3) 50 | 51 | # fully conv 52 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 53 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 54 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 55 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 56 | n.score_fr = L.Convolution(n.drop7, num_output=21, kernel_size=1, pad=0, 57 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 58 | n.upscore = L.Deconvolution(n.score_fr, 59 | convolution_param=dict(num_output=21, kernel_size=64, stride=32, 60 | bias_term=False), 61 | param=[dict(lr_mult=0)]) 62 | n.score = crop(n.upscore, n.data) 63 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 64 | loss_param=dict(normalize=False, ignore_label=255)) 65 | 66 | return n.to_proto() 67 | 68 | def make_net(): 69 | with open('train.prototxt', 'w') as f: 70 | f.write(str(fcn('train'))) 71 | 72 | with open('val.prototxt', 'w') as f: 73 | f.write(str(fcn('seg11valid'))) 74 | 75 | if __name__ == '__main__': 76 | make_net() 77 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn32s/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../ilsvrc-nets/vgg16-fcn.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 
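# interp() (defined in surgery.py) fills every deconvolution layer whose name
# contains 'up' with the fixed bilinear kernel from upsample_filt(); combined
# with lr_mult: 0 on those layers in the prototxt, the upsampling stays frozen
# as plain bilinear interpolation for the whole run.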
26 | 27 | # scoring 28 | val = np.loadtxt('../data/segvalid11.txt', dtype=str) 29 | 30 | for _ in range(25): 31 | solver.step(4000) 32 | score.seg_tests(solver, False, val, layer='score') 33 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn32s/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "train.prototxt" 2 | test_net: "val.prototxt" 3 | test_iter: 736 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-10 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 100000 16 | weight_decay: 0.0005 17 | snapshot: 4000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn8s-atonce/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/fcn8s-atonce-pascal.caffemodel 2 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn8s-atonce/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split): 15 | n = caffe.NetSpec() 16 | pydata_params = dict(split=split, mean=(104.00699, 116.66877, 122.67892), 17 | seed=1337) 18 | if split == 'train': 19 | pydata_params['sbdd_dir'] = '../data/sbdd/dataset' 20 | pylayer = 'SBDDSegDataLayer' 21 | else: 22 | pydata_params['voc_dir'] = '../data/pascal/VOC2011' 23 | pylayer = 'VOCSegDataLayer' 24 | n.data, n.label = L.Python(module='voc_layers', layer=pylayer, 25 | ntop=2, param_str=str(pydata_params)) 26 | 27 | # the base net 28 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 29 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 30 | n.pool1 = max_pool(n.relu1_2) 31 | 32 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 33 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 34 | n.pool2 = max_pool(n.relu2_2) 35 | 36 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 37 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 38 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 39 | n.pool3 = max_pool(n.relu3_3) 40 | 41 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 42 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 43 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 44 | n.pool4 = max_pool(n.relu4_3) 45 | 46 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 47 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 48 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 49 | n.pool5 = max_pool(n.relu5_3) 50 | 51 | # fully conv 52 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 53 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 54 | n.fc7, n.relu7 = 
conv_relu(n.drop6, 4096, ks=1, pad=0) 55 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 56 | 57 | n.score_fr = L.Convolution(n.drop7, num_output=21, kernel_size=1, pad=0, 58 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 59 | n.upscore2 = L.Deconvolution(n.score_fr, 60 | convolution_param=dict(num_output=21, kernel_size=4, stride=2, 61 | bias_term=False), 62 | param=[dict(lr_mult=0)]) 63 | 64 | # scale pool4 skip for compatibility 65 | n.scale_pool4 = L.Scale(n.pool4, filler=dict(type='constant', 66 | value=0.01), param=[dict(lr_mult=0)]) 67 | n.score_pool4 = L.Convolution(n.scale_pool4, num_output=21, kernel_size=1, pad=0, 68 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 69 | n.score_pool4c = crop(n.score_pool4, n.upscore2) 70 | n.fuse_pool4 = L.Eltwise(n.upscore2, n.score_pool4c, 71 | operation=P.Eltwise.SUM) 72 | n.upscore_pool4 = L.Deconvolution(n.fuse_pool4, 73 | convolution_param=dict(num_output=21, kernel_size=4, stride=2, 74 | bias_term=False), 75 | param=[dict(lr_mult=0)]) 76 | 77 | # scale pool3 skip for compatibility 78 | n.scale_pool3 = L.Scale(n.pool3, filler=dict(type='constant', 79 | value=0.0001), param=[dict(lr_mult=0)]) 80 | n.score_pool3 = L.Convolution(n.scale_pool3, num_output=21, kernel_size=1, pad=0, 81 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 82 | n.score_pool3c = crop(n.score_pool3, n.upscore_pool4) 83 | n.fuse_pool3 = L.Eltwise(n.upscore_pool4, n.score_pool3c, 84 | operation=P.Eltwise.SUM) 85 | n.upscore8 = L.Deconvolution(n.fuse_pool3, 86 | convolution_param=dict(num_output=21, kernel_size=16, stride=8, 87 | bias_term=False), 88 | param=[dict(lr_mult=0)]) 89 | 90 | n.score = crop(n.upscore8, n.data) 91 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 92 | loss_param=dict(normalize=False, ignore_label=255)) 93 | 94 | return n.to_proto() 95 | 96 | def make_net(): 97 | with open('train.prototxt', 'w') as f: 98 | f.write(str(fcn('train'))) 99 | 100 | with open('val.prototxt', 'w') as f: 101 | f.write(str(fcn('seg11valid'))) 102 | 103 | if __name__ == '__main__': 104 | make_net() 105 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn8s-atonce/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../ilsvrc-nets/vgg16-fcn.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | val = np.loadtxt('../data/segvalid11.txt', dtype=str) 29 | 30 | for _ in range(75): 31 | solver.step(4000) 32 | score.seg_tests(solver, False, val, layer='score') 33 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn8s-atonce/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "train.prototxt" 2 | test_net: "val.prototxt" 3 | test_iter: 736 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 
999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-10 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 300000 16 | weight_decay: 0.0005 17 | snapshot: 4000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn8s/caffemodel-url: -------------------------------------------------------------------------------- 1 | http://dl.caffe.berkeleyvision.org/fcn8s-heavy-pascal.caffemodel -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn8s/net.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | from caffe import layers as L, params as P 3 | from caffe.coord_map import crop 4 | 5 | def conv_relu(bottom, nout, ks=3, stride=1, pad=1): 6 | conv = L.Convolution(bottom, kernel_size=ks, stride=stride, 7 | num_output=nout, pad=pad, 8 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 9 | return conv, L.ReLU(conv, in_place=True) 10 | 11 | def max_pool(bottom, ks=2, stride=2): 12 | return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride) 13 | 14 | def fcn(split): 15 | n = caffe.NetSpec() 16 | pydata_params = dict(split=split, mean=(104.00699, 116.66877, 122.67892), 17 | seed=1337) 18 | if split == 'train': 19 | pydata_params['sbdd_dir'] = '../data/sbdd/dataset' 20 | pylayer = 'SBDDSegDataLayer' 21 | else: 22 | pydata_params['voc_dir'] = '../data/pascal/VOC2011' 23 | pylayer = 'VOCSegDataLayer' 24 | n.data, n.label = L.Python(module='voc_layers', layer=pylayer, 25 | ntop=2, param_str=str(pydata_params)) 26 | 27 | # the base net 28 | n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100) 29 | n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64) 30 | n.pool1 = max_pool(n.relu1_2) 31 | 32 | n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128) 33 | n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128) 34 | n.pool2 = max_pool(n.relu2_2) 35 | 36 | n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256) 37 | n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256) 38 | n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256) 39 | n.pool3 = max_pool(n.relu3_3) 40 | 41 | n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512) 42 | n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512) 43 | n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512) 44 | n.pool4 = max_pool(n.relu4_3) 45 | 46 | n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512) 47 | n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512) 48 | n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512) 49 | n.pool5 = max_pool(n.relu5_3) 50 | 51 | # fully conv 52 | n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0) 53 | n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True) 54 | n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0) 55 | n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True) 56 | n.score_fr = L.Convolution(n.drop7, num_output=21, kernel_size=1, pad=0, 57 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 58 | n.upscore2 = L.Deconvolution(n.score_fr, 59 | convolution_param=dict(num_output=21, kernel_size=4, stride=2, 60 | bias_term=False), 61 | param=[dict(lr_mult=0)]) 62 | 63 | n.score_pool4 = L.Convolution(n.pool4, num_output=21, kernel_size=1, pad=0, 64 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 65 | n.score_pool4c = 
crop(n.score_pool4, n.upscore2) 66 | n.fuse_pool4 = L.Eltwise(n.upscore2, n.score_pool4c, 67 | operation=P.Eltwise.SUM) 68 | n.upscore_pool4 = L.Deconvolution(n.fuse_pool4, 69 | convolution_param=dict(num_output=21, kernel_size=4, stride=2, 70 | bias_term=False), 71 | param=[dict(lr_mult=0)]) 72 | 73 | n.score_pool3 = L.Convolution(n.pool3, num_output=21, kernel_size=1, pad=0, 74 | param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)]) 75 | n.score_pool3c = crop(n.score_pool3, n.upscore_pool4) 76 | n.fuse_pool3 = L.Eltwise(n.upscore_pool4, n.score_pool3c, 77 | operation=P.Eltwise.SUM) 78 | n.upscore8 = L.Deconvolution(n.fuse_pool3, 79 | convolution_param=dict(num_output=21, kernel_size=16, stride=8, 80 | bias_term=False), 81 | param=[dict(lr_mult=0)]) 82 | 83 | n.score = crop(n.upscore8, n.data) 84 | n.loss = L.SoftmaxWithLoss(n.score, n.label, 85 | loss_param=dict(normalize=False, ignore_label=255)) 86 | 87 | return n.to_proto() 88 | 89 | def make_net(): 90 | with open('train.prototxt', 'w') as f: 91 | f.write(str(fcn('train'))) 92 | 93 | with open('val.prototxt', 'w') as f: 94 | f.write(str(fcn('seg11valid'))) 95 | 96 | if __name__ == '__main__': 97 | make_net() 98 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn8s/solve.py: -------------------------------------------------------------------------------- 1 | import caffe 2 | import surgery, score 3 | 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | try: 9 | import setproctitle 10 | setproctitle.setproctitle(os.path.basename(os.getcwd())) 11 | except: 12 | pass 13 | 14 | weights = '../voc-fcn16s/voc-fcn16s.caffemodel' 15 | 16 | # init 17 | caffe.set_device(int(sys.argv[1])) 18 | caffe.set_mode_gpu() 19 | 20 | solver = caffe.SGDSolver('solver.prototxt') 21 | solver.net.copy_from(weights) 22 | 23 | # surgeries 24 | interp_layers = [k for k in solver.net.params.keys() if 'up' in k] 25 | surgery.interp(solver.net, interp_layers) 26 | 27 | # scoring 28 | val = np.loadtxt('../data/segvalid11.txt', dtype=str) 29 | 30 | for _ in range(25): 31 | solver.step(4000) 32 | score.seg_tests(solver, False, val, layer='score') 33 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc-fcn8s/solver.prototxt: -------------------------------------------------------------------------------- 1 | train_net: "train.prototxt" 2 | test_net: "val.prototxt" 3 | test_iter: 736 4 | # make test net, but don't invoke it from the solver itself 5 | test_interval: 999999999 6 | display: 20 7 | average_loss: 20 8 | lr_policy: "fixed" 9 | # lr for unnormalized softmax 10 | base_lr: 1e-14 11 | # high momentum 12 | momentum: 0.99 13 | # no gradient accumulation 14 | iter_size: 1 15 | max_iter: 100000 16 | weight_decay: 0.0005 17 | snapshot: 4000 18 | snapshot_prefix: "snapshot/train" 19 | test_initialization: false 20 | -------------------------------------------------------------------------------- /torchfcn/ext/fcn.berkeleyvision.org/voc_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import glob 4 | import numpy as np 5 | 6 | from PIL import Image 7 | 8 | 9 | class voc: 10 | def __init__(self, data_path): 11 | # data_path is /path/to/PASCAL/VOC2011 12 | self.dir = data_path 13 | self.classes = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 14 | 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 15 | 'diningtable', 'dog', 
'horse', 'motorbike', 'person', 16 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] 17 | # for paletting (stored as _palette so it does not shadow the palette() method below) 18 | reference_idx = '2008_000666' 19 | palette_im = Image.open('{}/SegmentationClass/{}.png'.format( 20 | self.dir, reference_idx)) 21 | self._palette = palette_im.palette 22 | 23 | def load_image(self, idx): 24 | im = Image.open('{}/JPEGImages/{}.jpg'.format(self.dir, idx)) 25 | return im 26 | 27 | def load_label(self, idx): 28 | """ 29 | Load label image as 1 x height x width integer array of label indices. 30 | The leading singleton dimension is required by the loss. 31 | """ 32 | label = Image.open('{}/SegmentationClass/{}.png'.format(self.dir, idx)) 33 | label = np.array(label, dtype=np.uint8) 34 | label = label[np.newaxis, ...] 35 | return label 36 | 37 | def palette(self, label_im): 38 | ''' 39 | Transfer the VOC color palette to an output mask for visualization. 40 | ''' 41 | if label_im.ndim == 3: 42 | label_im = label_im[0] 43 | label = Image.fromarray(label_im, mode='P') 44 | label.palette = copy.copy(self._palette) 45 | return label 46 | -------------------------------------------------------------------------------- /torchfcn/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .fcn32s import FCN32s 3 | from .fcn16s import FCN16s 4 | from .fcn8s import FCN8s 5 | from .fcn8s import FCN8sAtOnce 6 | from .vgg import VGG16 7 | -------------------------------------------------------------------------------- /torchfcn/models/fcn16s.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import fcn 4 | import torch.nn as nn 5 | 6 | from .fcn32s import get_upsampling_weight 7 | 8 | 9 | class FCN16s(nn.Module): 10 | 11 | pretrained_model = \ 12 | osp.expanduser('~/data/models/pytorch/fcn16s_from_caffe.pth') 13 | 14 | @classmethod 15 | def download(cls): 16 | return fcn.data.cached_download( 17 | url='http://drive.google.com/uc?id=1bctu58B6YH9bu6lBBSBB2rUeGlGhYLoP', # NOQA 18 | path=cls.pretrained_model, 19 | md5='a2d4035f669f09483b39c9a14a0d6670', 20 | ) 21 | 22 | def __init__(self, n_class=21): 23 | super(FCN16s, self).__init__() 24 | # conv1 25 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) 26 | self.relu1_1 = nn.ReLU(inplace=True) 27 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 28 | self.relu1_2 = nn.ReLU(inplace=True) 29 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 30 | 31 | # conv2 32 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 33 | self.relu2_1 = nn.ReLU(inplace=True) 34 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 35 | self.relu2_2 = nn.ReLU(inplace=True) 36 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 37 | 38 | # conv3 39 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 40 | self.relu3_1 = nn.ReLU(inplace=True) 41 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 42 | self.relu3_2 = nn.ReLU(inplace=True) 43 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 44 | self.relu3_3 = nn.ReLU(inplace=True) 45 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 46 | 47 | # conv4 48 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 49 | self.relu4_1 = nn.ReLU(inplace=True) 50 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 51 | self.relu4_2 = nn.ReLU(inplace=True) 52 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 53 | self.relu4_3 = nn.ReLU(inplace=True) 54 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 55 | 56 | # conv5 57 | self.conv5_1 =
nn.Conv2d(512, 512, 3, padding=1) 58 | self.relu5_1 = nn.ReLU(inplace=True) 59 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 60 | self.relu5_2 = nn.ReLU(inplace=True) 61 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 62 | self.relu5_3 = nn.ReLU(inplace=True) 63 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 64 | 65 | # fc6 66 | self.fc6 = nn.Conv2d(512, 4096, 7) 67 | self.relu6 = nn.ReLU(inplace=True) 68 | self.drop6 = nn.Dropout2d() 69 | 70 | # fc7 71 | self.fc7 = nn.Conv2d(4096, 4096, 1) 72 | self.relu7 = nn.ReLU(inplace=True) 73 | self.drop7 = nn.Dropout2d() 74 | 75 | self.score_fr = nn.Conv2d(4096, n_class, 1) 76 | self.score_pool4 = nn.Conv2d(512, n_class, 1) 77 | 78 | self.upscore2 = nn.ConvTranspose2d( 79 | n_class, n_class, 4, stride=2, bias=False) 80 | self.upscore16 = nn.ConvTranspose2d( 81 | n_class, n_class, 32, stride=16, bias=False) 82 | 83 | self._initialize_weights() 84 | 85 | def _initialize_weights(self): 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | m.weight.data.zero_() 89 | if m.bias is not None: 90 | m.bias.data.zero_() 91 | if isinstance(m, nn.ConvTranspose2d): 92 | assert m.kernel_size[0] == m.kernel_size[1] 93 | initial_weight = get_upsampling_weight( 94 | m.in_channels, m.out_channels, m.kernel_size[0]) 95 | m.weight.data.copy_(initial_weight) 96 | 97 | def forward(self, x): 98 | h = x 99 | h = self.relu1_1(self.conv1_1(h)) 100 | h = self.relu1_2(self.conv1_2(h)) 101 | h = self.pool1(h) 102 | 103 | h = self.relu2_1(self.conv2_1(h)) 104 | h = self.relu2_2(self.conv2_2(h)) 105 | h = self.pool2(h) 106 | 107 | h = self.relu3_1(self.conv3_1(h)) 108 | h = self.relu3_2(self.conv3_2(h)) 109 | h = self.relu3_3(self.conv3_3(h)) 110 | h = self.pool3(h) 111 | 112 | h = self.relu4_1(self.conv4_1(h)) 113 | h = self.relu4_2(self.conv4_2(h)) 114 | h = self.relu4_3(self.conv4_3(h)) 115 | h = self.pool4(h) 116 | pool4 = h # 1/16 117 | 118 | h = self.relu5_1(self.conv5_1(h)) 119 | h = self.relu5_2(self.conv5_2(h)) 120 | h = self.relu5_3(self.conv5_3(h)) 121 | h = self.pool5(h) 122 | 123 | h = self.relu6(self.fc6(h)) 124 | h = self.drop6(h) 125 | 126 | h = self.relu7(self.fc7(h)) 127 | h = self.drop7(h) 128 | 129 | h = self.score_fr(h) 130 | h = self.upscore2(h) 131 | upscore2 = h # 1/16 132 | 133 | h = self.score_pool4(pool4) 134 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 135 | score_pool4c = h # 1/16 136 | 137 | h = upscore2 + score_pool4c 138 | 139 | h = self.upscore16(h) 140 | h = h[:, :, 27:27 + x.size()[2], 27:27 + x.size()[3]].contiguous() 141 | 142 | return h 143 | 144 | def copy_params_from_fcn32s(self, fcn32s): 145 | for name, l1 in fcn32s.named_children(): 146 | try: 147 | l2 = getattr(self, name) 148 | l2.weight # skip ReLU / Dropout 149 | except Exception: 150 | continue 151 | assert l1.weight.size() == l2.weight.size() 152 | assert l1.bias.size() == l2.bias.size() 153 | l2.weight.data.copy_(l1.weight.data) 154 | l2.bias.data.copy_(l1.bias.data) 155 | -------------------------------------------------------------------------------- /torchfcn/models/vgg.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import fcn 4 | 5 | import torch 6 | import torchvision 7 | 8 | 9 | def VGG16(pretrained=False): 10 | model = torchvision.models.vgg16(pretrained=False) 11 | if not pretrained: 12 | return model 13 | model_file = _get_vgg16_pretrained_model() 14 | state_dict = torch.load(model_file) 15 | model.load_state_dict(state_dict) 16 | return 
model 17 | 18 | 19 | def _get_vgg16_pretrained_model(): 20 | return fcn.data.cached_download( 21 | url='http://drive.google.com/uc?id=1adDBTGY3GcEB_47dvcibajyzi872RAs3', 22 | path=osp.expanduser('~/data/models/pytorch/vgg16_from_caffe.pth'), 23 | md5='aa75b158f4181e7f6230029eb96c1b13', 24 | ) 25 | -------------------------------------------------------------------------------- /torchfcn/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def _fast_hist(label_true, label_pred, n_class): 5 | mask = (label_true >= 0) & (label_true < n_class) 6 | hist = np.bincount( 7 | n_class * label_true[mask].astype(int) + 8 | label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class) 9 | return hist 10 | 11 | 12 | def label_accuracy_score(label_trues, label_preds, n_class): 13 | """Return evaluation metrics for semantic segmentation: 14 | 15 | - overall pixel accuracy 16 | - mean per-class accuracy 17 | - mean IU (intersection over union) 18 | - frequency-weighted average accuracy (fwavacc) 19 | """ 20 | hist = np.zeros((n_class, n_class)) 21 | for lt, lp in zip(label_trues, label_preds): 22 | hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) 23 | acc = np.diag(hist).sum() / hist.sum() 24 | with np.errstate(divide='ignore', invalid='ignore'): 25 | acc_cls = np.diag(hist) / hist.sum(axis=1) 26 | acc_cls = np.nanmean(acc_cls) 27 | with np.errstate(divide='ignore', invalid='ignore'): 28 | iu = np.diag(hist) / ( 29 | hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) 30 | ) 31 | mean_iu = np.nanmean(iu) 32 | freq = hist.sum(axis=1) / hist.sum() 33 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 34 | return acc, acc_cls, mean_iu, fwavacc 35 | --------------------------------------------------------------------------------
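A minimal usage sketch for the metrics above. The toy arrays are invented for illustration; only numpy and the label_accuracy_score function shown in torchfcn/utils.py are assumed:

import numpy as np

from torchfcn.utils import label_accuracy_score

# toy ground truth and prediction for one 2 x 3 image with 2 classes
label_true = np.array([[0, 0, 1],
                       [1, 1, 0]])
label_pred = np.array([[0, 1, 1],
                       [1, 1, 0]])

acc, acc_cls, mean_iu, fwavacc = label_accuracy_score(
    [label_true], [label_pred], n_class=2)
# overall accuracy: 5 of 6 pixels correct -> acc == 5/6
# per-class accuracy: class 0 -> 2/3, class 1 -> 3/3, so acc_cls == 5/6
# IU: class 0 -> 2/3, class 1 -> 3/4, so mean_iu == 17/24
# both classes cover half the true pixels, so fwavacc == 17/24 as well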