├── libs
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── voc.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   └── pspnet.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── crf.py
│   │   └── metric.py
│   └── caffe.proto
├── data
│   ├── models
│   │   └── .gitignore
│   └── datasets
│       ├── cityscapes
│       │   └── labels.txt
│       ├── voc12
│       │   └── labels.txt
│       └── ade20k
│           └── labels.txt
├── docs
│   ├── demo.png
│   ├── voc12_ms.json
│   └── voc12_ss.json
├── config
│   ├── ade20k.yaml
│   ├── voc12_ss.yaml
│   ├── cityscapes.yaml
│   └── voc12_ms.yaml
├── LICENSE
├── .gitignore
├── README.md
├── draw_model.py
├── demo.py
├── convert.py
└── eval.py
/libs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/models/.gitignore:
--------------------------------------------------------------------------------
1 | *.caffemodel
2 | *.pth
--------------------------------------------------------------------------------
/docs/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kazuto1011/pspnet-pytorch/HEAD/docs/demo.png
--------------------------------------------------------------------------------
/libs/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from .voc import *
3 |
--------------------------------------------------------------------------------
/libs/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from .resnet import *
3 | from .pspnet import *
4 |
--------------------------------------------------------------------------------
/libs/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from .crf import *
3 | from .metric import *
4 |
--------------------------------------------------------------------------------
/data/datasets/cityscapes/labels.txt:
--------------------------------------------------------------------------------
1 | 255 unlabeled
2 | 0 Road
3 | 1 Sidewalk
4 | 2 Building
5 | 3 Wall
6 | 4 Fence
7 | 5 Pole
8 | 6 Traffic light
9 | 7 Traffic sign
10 | 8 Vegetation
11 | 9 Terrain
12 | 10 Sky
13 | 11 Person
14 | 12 Rider
15 | 13 Car
16 | 14 Truck
17 | 15 Bus
18 | 16 Train
19 | 17 Motorcycle
20 | 18 Bicycle
--------------------------------------------------------------------------------
/data/datasets/voc12/labels.txt:
--------------------------------------------------------------------------------
1 | 0 __background__
2 | 1 aeroplane
3 | 2 bicycle
4 | 3 bird
5 | 4 boat
6 | 5 bottle
7 | 6 bus
8 | 7 car
9 | 8 cat
10 | 9 chair
11 | 10 cow
12 | 11 diningtable
13 | 12 dog
14 | 13 horse
15 | 14 motorbike
16 | 15 person
17 | 16 pottedplant
18 | 17 sheep
19 | 18 sofa
20 | 19 train
21 | 20 tvmonitor
--------------------------------------------------------------------------------
/config/ade20k.yaml:
--------------------------------------------------------------------------------
1 | DATASET: 'ade20k'
2 | DATASET_ROOT:
3 | CAFFE_MODEL: 'data/models/pspnet50_ADE20K.caffemodel'
4 | PYTORCH_MODEL: 'data/models/pspnet50_ADE20K.pth'
5 | LABELS: 'data/datasets/ade20k/labels.txt'
6 | N_CLASSES: 150
7 | N_BLOCKS: [3, 4, 6, 3]
8 | PYRAMIDS: [6, 3, 2, 1]
9 | IMAGE:
10 | SIZE:
11 | BASE: 512
12 | TRAIN: 473
13 | TEST: 473
14 | MEAN:
15 | R: 123.68
16 | G: 116.779
17 | B: 103.939
18 | SCALES: [0.5, 0.75, 1, 1.25, 1.5, 1.75]
19 | NUM_WORKERS: 1
--------------------------------------------------------------------------------
/config/voc12_ss.yaml:
--------------------------------------------------------------------------------
1 | DATASET: 'voc12'
2 | DATASET_ROOT: '/media/kazuto1011/Extra/VOCdevkit'
3 | CAFFE_MODEL: 'data/models/pspnet101_VOC2012.caffemodel'
4 | PYTORCH_MODEL: 'data/models/pspnet101_VOC2012.pth'
5 | LABELS: 'data/datasets/voc12/labels.txt'
6 | N_CLASSES: 21
7 | N_BLOCKS: [3, 4, 23, 3]
8 | PYRAMIDS: [6, 3, 2, 1]
9 | IMAGE:
10 | SIZE:
11 | BASE: 512
12 | TRAIN: 473
13 | TEST: 473
14 | MEAN:
15 | R: 123.68
16 | G: 116.779
17 | B: 103.939
18 | SCALES: [1,]
19 | NUM_WORKERS: 1
--------------------------------------------------------------------------------
/config/cityscapes.yaml:
--------------------------------------------------------------------------------
1 | DATASET: 'cityscapes'
2 | DATASET_ROOT:
3 | CAFFE_MODEL: 'data/models/pspnet101_cityscapes.caffemodel'
4 | PYTORCH_MODEL: 'data/models/pspnet101_cityscapes.pth'
5 | LABELS: 'data/datasets/cityscapes/labels.txt'
6 | N_CLASSES: 19
7 | N_BLOCKS: [3, 4, 23, 3]
8 | PYRAMIDS: [6, 3, 2, 1]
9 | IMAGE:
10 | SIZE:
11 | BASE: 2048
12 | TRAIN: 713
13 | TEST: 713
14 | MEAN:
15 | R: 123.68
16 | G: 116.779
17 | B: 103.939
18 | SCALES: [0.5, 0.75, 1, 1.25, 1.5, 1.75]
19 | NUM_WORKERS: 1
--------------------------------------------------------------------------------
/config/voc12_ms.yaml:
--------------------------------------------------------------------------------
1 | DATASET: 'voc12'
2 | DATASET_ROOT: '/media/kazuto1011/Extra/VOCdevkit'
3 | CAFFE_MODEL: 'data/models/pspnet101_VOC2012.caffemodel'
4 | PYTORCH_MODEL: 'data/models/pspnet101_VOC2012.pth'
5 | LABELS: 'data/datasets/voc12/labels.txt'
6 | N_CLASSES: 21
7 | N_BLOCKS: [3, 4, 23, 3]
8 | PYRAMIDS: [6, 3, 2, 1]
9 | IMAGE:
10 | SIZE:
11 | BASE: 512
12 | TRAIN: 473
13 | TEST: 473
14 | MEAN:
15 | R: 123.68
16 | G: 116.779
17 | B: 103.939
18 | SCALES: [0.5, 0.75, 1, 1.25, 1.5, 1.75]
19 | NUM_WORKERS: 1
--------------------------------------------------------------------------------
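For reference, every script in this repository loads these YAML files through addict's `Dict`, which exposes the nested keys above as attributes. A minimal sketch of how a config is consumed, mirroring `demo.py` and `eval.py` (newer PyYAML versions expect an explicit loader):

```python
import yaml
from addict import Dict

# Attribute-style access to the nested keys shown above;
# pass Loader=yaml.SafeLoader on PyYAML >= 5.1.
CONFIG = Dict(yaml.load(open("config/voc12_ss.yaml")))

print(CONFIG.N_CLASSES)        # 21
print(CONFIG.IMAGE.SIZE.TEST)  # 473
print(CONFIG.IMAGE.MEAN.R)     # 123.68
```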
/libs/utils/crf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Kazuto Nakashima
5 | # URL: http://kazuto1011.github.io
6 | # Created: 2017-11-01
7 |
8 | import numpy as np
9 | import pydensecrf.densecrf as dcrf
10 | import pydensecrf.utils as utils
11 |
12 | MAX_ITER = 10
13 | POS_W = 3
14 | POS_XY_STD = 3
15 | Bi_W = 4
16 | Bi_XY_STD = 49
17 | Bi_RGB_STD = 5
18 |
19 |
20 | def dense_crf(img, probs):
21 | c = probs.shape[0]
22 | h = probs.shape[1]
23 | w = probs.shape[2]
24 |
25 | U = utils.unary_from_softmax(probs)
26 | U = np.ascontiguousarray(U)
27 |
28 | img = np.ascontiguousarray(img)
29 |
30 | d = dcrf.DenseCRF2D(w, h, c)
31 | d.setUnaryEnergy(U)
32 | d.addPairwiseGaussian(sxy=POS_XY_STD, compat=POS_W)
33 | d.addPairwiseBilateral(sxy=Bi_XY_STD, srgb=Bi_RGB_STD, rgbim=img, compat=Bi_W)
34 |
35 | Q = d.inference(MAX_ITER)
36 | Q = np.array(Q).reshape((c, h, w))
37 |
38 | return Q
39 |
--------------------------------------------------------------------------------
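A usage sketch for `dense_crf`, with shapes inferred from the code and from how `demo.py` calls it: `img` is an H x W x 3 `uint8` image and `probs` is the C x H x W softmax output of the network for that image. The random inputs below are placeholders:

```python
import numpy as np
from libs.utils import dense_crf

h, w, c = 473, 473, 21
img = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)  # H x W x 3, uint8
probs = np.random.rand(c, h, w).astype(np.float32)
probs /= probs.sum(axis=0, keepdims=True)                   # per-pixel distribution

refined = dense_crf(img, probs)        # C x H x W refined probabilities
labelmap = np.argmax(refined, axis=0)  # H x W hard labels
```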
/docs/voc12_ms.json:
--------------------------------------------------------------------------------
1 | {
2 | "Class IoU": {
3 | "0": 0.9723835140393139,
4 | "1": 0.932284405801355,
5 | "2": 0.7468137710962681,
6 | "3": 0.9218783356840369,
7 | "4": 0.8484579004810219,
8 | "5": 0.8696297249201347,
9 | "6": 0.9720973207300103,
10 | "7": 0.935915049578212,
11 | "8": 0.9374628071322245,
12 | "9": 0.6875070250549213,
13 | "10": 0.9582069281376151,
14 | "11": 0.8308168870071049,
15 | "12": 0.9282477592335033,
16 | "13": 0.9304011384698115,
17 | "14": 0.9140842413972792,
18 | "15": 0.9055205679928403,
19 | "16": 0.8116604297212983,
20 | "17": 0.9352167645389193,
21 | "18": 0.8234079050857073,
22 | "19": 0.9300925249646486,
23 | "20": 0.8099681859474088
24 | },
25 | "FreqW Acc": 0.9531461900878117,
26 | "Mean Acc": 0.9279972291530902,
27 | "Mean IoU": 0.8858120565244588,
28 | "Overall Acc": 0.9756045013064952
29 | }
--------------------------------------------------------------------------------
/docs/voc12_ss.json:
--------------------------------------------------------------------------------
1 | {
2 | "Class IoU": {
3 | "0": 0.9720507018787903,
4 | "1": 0.9296281094338285,
5 | "2": 0.7234742052854309,
6 | "3": 0.9225243177338182,
7 | "4": 0.8318045051638611,
8 | "5": 0.8782490336084968,
9 | "6": 0.9715514433351778,
10 | "7": 0.9402895154307185,
11 | "8": 0.9372223191974904,
12 | "9": 0.6759331602494226,
13 | "10": 0.9526220286282768,
14 | "11": 0.8278434084687414,
15 | "12": 0.9248270465951416,
16 | "13": 0.9248037370339303,
17 | "14": 0.9131082290202047,
18 | "15": 0.9030004679899833,
19 | "16": 0.7972657605334218,
20 | "17": 0.9282957360023752,
21 | "18": 0.8207954040092147,
22 | "19": 0.934780815109344,
23 | "20": 0.7932057422450404
24 | },
25 | "FreqW Acc": 0.9521466591274718,
26 | "Mean Acc": 0.9302564874727208,
27 | "Mean IoU": 0.8811083660453672,
28 | "Overall Acc": 0.9749297938026947
29 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Kazuto Nakashima
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/libs/utils/metric.py:
--------------------------------------------------------------------------------
1 | # Originally written by wkentaro
2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
3 |
4 | import numpy as np
5 |
6 |
7 | def _fast_hist(label_true, label_pred, n_class):
8 | mask = (label_true >= 0) & (label_true < n_class)
9 | hist = np.bincount(
10 | n_class * label_true[mask].astype(int) + label_pred[mask],
11 | minlength=n_class ** 2,
12 | ).reshape(n_class, n_class)
13 | return hist
14 |
15 |
16 | def scores(label_trues, label_preds, n_class):
17 | hist = np.zeros((n_class, n_class))
18 | for lt, lp in zip(label_trues, label_preds):
19 | hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
20 | acc = np.diag(hist).sum() / hist.sum()
21 | acc_cls = np.diag(hist) / hist.sum(axis=1)
22 | acc_cls = np.nanmean(acc_cls)
23 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
24 | mean_iu = np.nanmean(iu)
25 | freq = hist.sum(axis=1) / hist.sum()
26 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
27 | cls_iu = dict(zip(range(n_class), iu))
28 |
29 | return (
30 | {
31 | "Overall Acc": acc,
32 | "Mean Acc": acc_cls,
33 | "FreqW Acc": fwavacc,
34 | "Mean IoU": mean_iu,
35 | },
36 | cls_iu,
37 | )
38 |
--------------------------------------------------------------------------------
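A tiny sanity check for `scores` with hypothetical 2x2 label maps (not real predictions). It takes iterables of ground-truth and predicted label arrays plus the class count, and returns a summary dict and per-class IoU:

```python
import numpy as np
from libs.utils import scores

gt   = [np.array([[0, 1], [1, 2]])]  # ground truth
pred = [np.array([[0, 1], [1, 1]])]  # prediction with one wrong pixel

summary, class_iou = scores(gt, pred, n_class=3)
print(summary["Overall Acc"])  # 0.75 (3 of 4 pixels correct)
print(class_iou[1])            # 0.666... (intersection 2, union 3)
```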
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | .vscode/
104 | .style.yapf
105 | results.json
106 | model.pdf
--------------------------------------------------------------------------------
/libs/datasets/voc.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import os.path as osp
5 | import sys
6 |
7 | import cv2
8 | import numpy as np
9 | import torch
10 | import torch.utils.data as data
11 | from PIL import Image
12 |
13 |
14 | class VOCSegmentation(data.Dataset):
15 | def __init__(
16 | self,
17 | root,
18 | image_set,
19 | transform=None,
20 | target_transform=None,
21 | dataset_name="VOC2007",
22 | ):
23 | self.root = root
24 | self.image_set = image_set
25 | self.transform = transform
26 | self.target_transform = target_transform
27 | self.mean_rgb = np.array([123.68, 116.779, 103.939])
28 |
29 | self._annopath = osp.join(
30 | self.root, dataset_name, "SegmentationClass", "%s.png"
31 | )
32 | self._imgpath = osp.join(self.root, dataset_name, "JPEGImages", "%s.jpg")
33 | self._imgsetpath = osp.join(
34 | self.root, dataset_name, "ImageSets", "Segmentation", "%s.txt"
35 | )
36 |
37 | with open(self._imgsetpath % self.image_set) as f:
38 | self.ids = f.readlines()
39 | self.ids = [x.strip("\n") for x in self.ids]
40 |
41 | def __getitem__(self, index):
42 | img_id = self.ids[index]
43 |
44 | target = Image.open(self._annopath % img_id)
45 | img = Image.open(self._imgpath % img_id).convert("RGB")
46 |
47 | if self.transform is not None:
48 | img = self.transform(img)
49 |
50 | if self.target_transform is not None:
51 | target = self.target_transform(target)
52 |
53 | img = (np.array(img) - self.mean_rgb).transpose(2, 0, 1)
54 | img = torch.from_numpy(img.astype(np.float32))
55 | target = np.array(target, dtype=np.int32)
56 | target[target == 255] = -1
57 | target = torch.from_numpy(target.astype(np.int64))
58 |
59 | return img, target
60 |
61 | def __len__(self):
62 | return len(self.ids)
63 |
--------------------------------------------------------------------------------
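A sketch of how `eval.py` instantiates this dataset. The `VOCdevkit` path below is a placeholder; point it at your own copy (this is what `DATASET_ROOT` in the configs refers to):

```python
import torch
from libs.datasets import VOCSegmentation

dataset = VOCSegmentation(
    root="/path/to/VOCdevkit", image_set="val", dataset_name="VOC2012"
)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

image, target = next(iter(loader))
print(image.size())   # [1, 3, H, W] float32, mean-subtracted RGB
print(target.size())  # [1, H, W] int64, void pixels (255) mapped to -1
```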
/README.md:
--------------------------------------------------------------------------------
1 | # PSPNet with PyTorch
2 |
3 | Unofficial implementation of "Pyramid Scene Parsing Network" (https://arxiv.org/abs/1612.01105). This repository covers only Caffe-to-PyTorch model conversion and evaluation.
4 |
5 | ### Requirements
6 |
7 | * pytorch
8 | * click
9 | * addict
10 | * pydensecrf
11 | * protobuf
12 |
13 | ## Preparation
14 | Instead of building the author's Caffe implementation, you can convert the off-the-shelf caffemodels to PyTorch models via the provided ```caffe.proto```.
15 |
16 | ### 1. Compile the ```caffe.proto``` for the Python API
17 | This step can be skipped; it is documented here for reference.
18 | Download [the author's ```caffe.proto```](https://github.com/hszhao/PSPNet/blob/master/src/caffe/proto/caffe.proto) into the ```libs``` directory; note that it differs from the one in the original Caffe.
19 | ```sh
20 | # Python runtime for the generated module (the protoc compiler itself comes from your system's protobuf package)
21 | pip install protobuf
22 | # This generates ./caffe_pb2.py
23 | protoc --python_out=. caffe.proto
24 | ```
25 |
26 | ### 2. Model conversion
27 |
28 | 1. Find the caffemodels on [the author's page](https://github.com/hszhao/PSPNet#usage) (e.g. pspnet50_ADE20K.caffemodel) and store them in the ```data/models/``` directory.
29 | 2. Convert the caffemodels to ```.pth``` files.
30 |
31 | ```sh
32 | python convert.py -c <path to a config .yaml file>
33 | ```
34 |
35 | ## Demo
36 |
37 | ```sh
38 | python demo.py -c <path to a config .yaml file> -i <path to an image>
39 | ```
40 | * With the ```--no-cuda``` option, this runs on the CPU.
41 | * With the ```--crf``` option, CRF post-processing is applied to the output.
42 |
43 | 
44 |
45 | ## Evaluation
46 |
47 | PASCAL VOC2012 only. Please set the dataset path (```DATASET_ROOT```) in ```config/voc12_ss.yaml``` (single-scale) and ```config/voc12_ms.yaml``` (multi-scale).
48 |
49 | ```sh
50 | python eval.py -c config/voc12_ss.yaml
51 | ```
52 |
53 | 88.1% mIoU (SS) and 88.6% mIoU (MS) on the validation set.
54 | *NOTE: about 3 points lower than the Caffe implementation. WIP*
55 |
56 | * SS: averaged prediction with flipping (2x)
57 | * MS: averaged prediction with multi-scaling (6x) and flipping (2x)
58 | * Both: No CRF post-processing
59 |
60 | ## References
61 |
62 | * Official implementation: https://github.com/hszhao/PSPNet
63 | * Chainer implementation: https://github.com/mitmul/chainer-pspnet
--------------------------------------------------------------------------------
/draw_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Kazuto Nakashima
5 | # URL: http://kazuto1011.github.io
6 | # Created: 2017-11-06
7 |
8 | import torch
9 | from graphviz import Digraph
10 | from torch.autograd import Variable
11 |
12 | from libs.models import *
13 |
14 |
15 | def make_dot(var, params):
16 |
17 | param_map = {id(v): k for k, v in params.items()}
18 |
19 | node_attr = dict(
20 | style="filled",
21 | shape="box",
22 | align="left",
23 | fontsize="12",
24 | ranksep="0.1",
25 | height="0.2",
26 | )
27 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
28 | seen = set()
29 |
30 | def size_to_str(size):
31 | return "(" + (", ").join(["%d" % v for v in size]) + ")"
32 |
33 | def add_nodes(var):
34 | if var not in seen:
35 | if torch.is_tensor(var):
36 | dot.node(str(id(var)), size_to_str(var.size()), fillcolor="orange")
37 | elif hasattr(var, "variable"):
38 | u = var.variable
39 | dot.node(str(id(var)), size_to_str(u.size()), fillcolor="lightblue")
40 | else:
41 | dot.node(str(id(var)), str(type(var).__name__.replace("Backward", "")))
42 | seen.add(var)
43 | if hasattr(var, "next_functions"):
44 | for u in var.next_functions:
45 | if u[0] is not None:
46 | dot.edge(str(id(u[0])), str(id(var)))
47 | add_nodes(u[0])
48 | if hasattr(var, "saved_tensors"):
49 | for t in var.saved_tensors:
50 | dot.edge(str(id(t)), str(id(var)))
51 | add_nodes(t)
52 |
53 | add_nodes(var.grad_fn)
54 |
55 | return dot
56 |
57 |
58 | if __name__ == "__main__":
59 | # Define a model
60 | model = PSPNet(n_classes=6, n_blocks=[3, 4, 6, 3], pyramids=[6, 3, 2, 1])
61 |
62 | # Build a computational graph from x to y
63 | x = torch.randn(2, 3, 512, 512)
64 | y1, y2 = model(Variable(x))
65 | g = make_dot(y1 + y2, model.state_dict())
66 | g.view(filename="model", cleanup=True)
67 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Kazuto Nakashima
5 | # URL: http://kazuto1011.github.io
6 | # Created: 2017-11-15
7 |
8 | import click
9 | import cv2
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import yaml
16 | from addict import Dict
17 | from torch.autograd import Variable
18 |
19 | from libs.models import PSPNet
20 | from libs.utils import dense_crf
21 |
22 |
23 | @click.command()
24 | @click.option("--config", "-c", required=True)
25 | @click.option("--image-path", "-i", required=True)
26 | @click.option("--cuda/--no-cuda", default=True)
27 | @click.option("--crf", is_flag=True)
28 | def main(config, image_path, cuda, crf):
29 | CONFIG = Dict(yaml.load(open(config)))
30 |
31 | cuda = cuda and torch.cuda.is_available()
32 |
33 | # Label list
34 | with open(CONFIG.LABELS) as f:
35 | classes = {}
36 | for label in f:
37 | label = label.rstrip().split("\t")
38 | classes[int(label[0])] = label[1].split(",")[0]
39 |
40 | # Load a model
41 | state_dict = torch.load(CONFIG.PYTORCH_MODEL)
42 |
43 | # Model
44 | model = PSPNet(
45 | n_classes=CONFIG.N_CLASSES, n_blocks=CONFIG.N_BLOCKS, pyramids=CONFIG.PYRAMIDS
46 | )
47 | model.load_state_dict(state_dict)
48 | model.eval()
49 | if cuda:
50 | model.cuda()
51 |
52 | image_size = (CONFIG.IMAGE.SIZE.TEST,) * 2
53 |
54 | # Image preprocessing
55 | image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(float)
56 | image = cv2.resize(image, image_size)
57 | image_original = image.astype(np.uint8)
58 | image = image[..., ::-1] - np.array(
59 | [CONFIG.IMAGE.MEAN.R, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.B]
60 | )
61 | image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
62 | image = image.cuda() if cuda else image
63 |
64 | # Inference
65 | output = model(Variable(image, volatile=True))
66 |
67 | output = F.upsample(output, size=image_size, mode="bilinear")
68 | output = F.softmax(output, dim=1)
69 | output = output[0].cpu().data.numpy()
70 |
71 | if crf:
72 | output = dense_crf(image_original, output)
73 | labelmap = np.argmax(output.transpose(1, 2, 0), axis=2)
74 |
75 | labels = np.unique(labelmap)
76 |
77 | rows = np.floor(np.sqrt(len(labels) + 1))
78 | cols = np.ceil((len(labels) + 1) / rows)
79 |
80 | plt.figure(figsize=(10, 10))
81 | ax = plt.subplot(rows, cols, 1)
82 | ax.set_title("Input image")
83 | ax.imshow(image_original[:, :, ::-1])
84 | ax.set_xticks([])
85 | ax.set_yticks([])
86 |
87 | for i, label in enumerate(labels):
88 | print("{0:3d}: {1}".format(label, classes[label]))
89 | mask = labelmap == label
90 | ax = plt.subplot(rows, cols, i + 2)
91 | ax.set_title(classes[label])
92 | ax.imshow(image_original[:, :, ::-1])
93 | ax.imshow(mask.astype(np.float32), alpha=0.5, cmap="viridis")
94 | ax.set_xticks([])
95 | ax.set_yticks([])
96 |
97 | plt.tight_layout()
98 | plt.show()
99 |
100 |
101 | if __name__ == "__main__":
102 | main()
103 |
--------------------------------------------------------------------------------
/libs/models/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Kazuto Nakashima
5 | # URL: http://kazuto1011.github.io
6 | # Created: 2017-11-19
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.model_zoo as model_zoo
11 | import torch.nn.functional as F
12 | from collections import OrderedDict
13 |
14 |
15 | class _ConvBatchNormReLU(nn.Sequential):
16 | """Convolution Unit"""
17 |
18 | def __init__(
19 | self,
20 | in_channels,
21 | out_channels,
22 | kernel_size,
23 | stride,
24 | padding,
25 | dilation,
26 | relu=True,
27 | ):
28 | super(_ConvBatchNormReLU, self).__init__()
29 | self.add_module(
30 | "conv",
31 | nn.Conv2d(
32 | in_channels=in_channels,
33 | out_channels=out_channels,
34 | kernel_size=kernel_size,
35 | stride=stride,
36 | padding=padding,
37 | dilation=dilation,
38 | bias=False,
39 | ),
40 | )
41 | self.add_module(
42 | "bn", nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.95, affine=True)
43 | )
44 | if relu:
45 | self.add_module("relu", nn.ReLU())
46 |
47 | def forward(self, x):
48 | return super(_ConvBatchNormReLU, self).forward(x)
49 |
50 |
51 | class _Bottleneck(nn.Module):
52 | """Bottleneck Unit"""
53 |
54 | def __init__(
55 | self, in_channels, mid_channels, out_channels, stride, dilation, downsample
56 | ):
57 | super(_Bottleneck, self).__init__()
58 | self.reduce = _ConvBatchNormReLU(in_channels, mid_channels, 1, 1, 0, 1)
59 | self.conv3x3 = _ConvBatchNormReLU(
60 | mid_channels, mid_channels, 3, stride, dilation, dilation
61 | )
62 | self.increase = _ConvBatchNormReLU(
63 | mid_channels, out_channels, 1, 1, 0, 1, relu=False
64 | )
65 | self.downsample = downsample
66 | if self.downsample:
67 | self.proj = _ConvBatchNormReLU(
68 | in_channels, out_channels, 1, stride, 0, 1, relu=False
69 | )
70 |
71 | def forward(self, x):
72 | h = self.reduce(x)
73 | h = self.conv3x3(h)
74 | h = self.increase(h)
75 | if self.downsample:
76 | h += self.proj(x)
77 | else:
78 | h += x
79 | return F.relu(h)
80 |
81 |
82 | class _ResBlock(nn.Sequential):
83 | """Residual Block"""
84 |
85 | def __init__(
86 | self, n_layers, in_channels, mid_channels, out_channels, stride, dilation
87 | ):
88 | super(_ResBlock, self).__init__()
89 | self.add_module(
90 | "block1",
91 | _Bottleneck(
92 | in_channels, mid_channels, out_channels, stride, dilation, True
93 | ),
94 | )
95 | for i in range(2, n_layers + 1):
96 | self.add_module(
97 | "block" + str(i),
98 | _Bottleneck(
99 | out_channels, mid_channels, out_channels, 1, dilation, False
100 | ),
101 | )
102 |
103 | def __call__(self, x):
104 | return super(_ResBlock, self).forward(x)
105 |
--------------------------------------------------------------------------------
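A minimal shape check for the blocks above (assuming this repository layout for the import). The arguments mirror `layer2` of the dilated FCN in `pspnet.py`: 3 bottlenecks, 128 -> 256 channels, stride 1, no dilation:

```python
import torch
from torch.autograd import Variable
from libs.models.resnet import _ResBlock

block = _ResBlock(
    n_layers=3, in_channels=128, mid_channels=64,
    out_channels=256, stride=1, dilation=1,
)
x = Variable(torch.randn(1, 128, 60, 60))
print(block(x).size())  # torch.Size([1, 256, 60, 60]); stride 1 keeps 60x60
```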
/libs/models/pspnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Kazuto Nakashima
5 | # URL: http://kazuto1011.github.io
6 | # Created: 2017-11-15
7 |
8 | from __future__ import absolute_import
9 |
10 | from collections import OrderedDict
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 | from .resnet import _ConvBatchNormReLU, _ResBlock
17 |
18 |
19 | class _DilatedFCN(nn.Module):
20 | """ResNet-based Dilated FCN"""
21 |
22 | def __init__(self, n_blocks):
23 | super(_DilatedFCN, self).__init__()
24 | self.layer1 = nn.Sequential(
25 | OrderedDict(
26 | [
27 | ("conv1", _ConvBatchNormReLU(3, 64, 3, 2, 1, 1)),
28 | ("conv2", _ConvBatchNormReLU(64, 64, 3, 1, 1, 1)),
29 | ("conv3", _ConvBatchNormReLU(64, 128, 3, 1, 1, 1)),
30 | ("pool", nn.MaxPool2d(3, 2, 1)),
31 | ]
32 | )
33 | )
34 | self.layer2 = _ResBlock(n_blocks[0], 128, 64, 256, 1, 1)
35 | self.layer3 = _ResBlock(n_blocks[1], 256, 128, 512, 2, 1)
36 | self.layer4 = _ResBlock(n_blocks[2], 512, 256, 1024, 1, 2)
37 | self.layer5 = _ResBlock(n_blocks[3], 1024, 512, 2048, 1, 4)
38 |
39 | def forward(self, x):
40 | h = self.layer1(x)
41 | h = self.layer2(h)
42 | h = self.layer3(h)
43 | h1 = self.layer4(h)
44 | h2 = self.layer5(h1)
45 | if self.training:
46 | return h1, h2
47 | else:
48 | return h2
49 |
50 |
51 | class _PyramidPoolModule(nn.Sequential):
52 | """Pyramid Pooling Module"""
53 |
54 | def __init__(self, in_channels, pyramids=[6, 3, 2, 1]):
55 | super(_PyramidPoolModule, self).__init__()
56 | out_channels = in_channels // len(pyramids)
57 | self.stages = nn.Module()
58 | for i, p in enumerate(pyramids):
59 | self.stages.add_module(
60 | "s{}".format(i),
61 | nn.Sequential(
62 | OrderedDict(
63 | [
64 | ("pool", nn.AdaptiveAvgPool2d(output_size=p)),
65 | (
66 | "conv",
67 | _ConvBatchNormReLU(
68 | in_channels, out_channels, 1, 1, 0, 1
69 | ),
70 | ),
71 | ]
72 | )
73 | ),
74 | )
75 |
76 | def forward(self, x):
77 | hs = [x]
78 | height, width = x.size()[2:]
79 | for stage in self.stages.children():
80 | h = stage(x)
81 | h = F.upsample(h, (height, width), mode="bilinear")
82 | hs.append(h)
83 | return torch.cat(hs, dim=1)
84 |
85 |
86 | class PSPNet(nn.Module):
87 | """Pyramid Scene Parsing Network"""
88 |
89 | def __init__(self, n_classes, n_blocks, pyramids):
90 | super(PSPNet, self).__init__()
91 | self.n_classes = n_classes
92 | self.fcn = _DilatedFCN(n_blocks=n_blocks)
93 | self.ppm = _PyramidPoolModule(in_channels=2048, pyramids=pyramids)
94 | # Main branch
95 | self.final = nn.Sequential(
96 | OrderedDict(
97 | [
98 | ("conv5_4", _ConvBatchNormReLU(4096, 512, 3, 1, 1, 1)),
99 | ("drop5_4", nn.Dropout2d(p=0.1)),
100 | ("conv6", nn.Conv2d(512, n_classes, 1, stride=1, padding=0)),
101 | ]
102 | )
103 | )
104 | # Auxiliary branch
105 | self.aux = nn.Sequential(
106 | OrderedDict(
107 | [
108 | ("conv4_aux", _ConvBatchNormReLU(1024, 256, 3, 1, 1, 1)),
109 | ("drop4_aux", nn.Dropout2d(p=0.1)),
110 | ("conv6_1", nn.Conv2d(256, n_classes, 1, stride=1, padding=0)),
111 | ]
112 | )
113 | )
114 |
115 | def forward(self, x):
116 | if self.training:
117 | aux, h = self.fcn(x)
118 | aux = self.aux(aux)
119 | else:
120 | h = self.fcn(x)
121 | h = self.ppm(h)
122 | h = self.final(h)
123 |
124 | if self.training:
125 | return aux, h
126 | else:
127 | return h
128 |
129 |
130 | if __name__ == "__main__":
131 | model = PSPNet(n_classes=150, n_blocks=[3, 4, 6, 3], pyramids=[6, 3, 2, 1])
132 | print(list(model.named_children()))
133 | model.eval()
134 | image = torch.autograd.Variable(torch.randn(1, 3, 473, 473))
135 | print(model(image).size())
136 |
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Kazuto Nakashima
5 | # URL: http://kazuto1011.github.io
6 | # Created: 2017-11-15
7 |
8 | from __future__ import print_function
9 |
10 | import re
11 | from collections import OrderedDict
12 |
13 | import click
14 | import numpy as np
15 | import torch
16 | import yaml
17 | from addict import Dict
18 |
19 | from libs import caffe_pb2
20 | from libs.models import PSPNet
21 |
22 |
23 | def parse_caffemodel(model_path):
24 | caffemodel = caffe_pb2.NetParameter()
25 | with open(model_path, "rb") as f:
26 | caffemodel.MergeFromString(f.read())
27 |
28 | # Check trainable layers
29 | print(set([(layer.type, len(layer.blobs)) for layer in caffemodel.layer]))
30 |
31 | params = OrderedDict()
32 | for layer in caffemodel.layer:
33 | print("{} ({}): {}".format(layer.name, layer.type, len(layer.blobs)))
34 |
35 | # Convolution or Dilated Convolution
36 | if "Convolution" in layer.type:
37 | params[layer.name] = {}
38 | params[layer.name]["kernel_size"] = layer.convolution_param.kernel_size[0]
39 | params[layer.name]["stride"] = layer.convolution_param.stride[0]
40 | params[layer.name]["weight"] = list(layer.blobs[0].data)
41 | if len(layer.blobs) == 2:
42 | params[layer.name]["bias"] = list(layer.blobs[1].data)
43 | if len(layer.convolution_param.pad) == 1: # or []
44 | params[layer.name]["padding"] = layer.convolution_param.pad[0]
45 | else:
46 | params[layer.name]["padding"] = 0
47 | if isinstance(layer.convolution_param.dilation, int): # or []
48 | params[layer.name]["dilation"] = layer.convolution_param.dilation
49 | else:
50 | params[layer.name]["dilation"] = 1
51 |
52 | # Batch Normalization
53 | elif "BN" in layer.type:
54 | params[layer.name] = {}
55 | params[layer.name]["weight"] = list(layer.blobs[0].data)
56 | params[layer.name]["bias"] = list(layer.blobs[1].data)
57 | params[layer.name]["running_mean"] = list(layer.blobs[2].data)
58 | params[layer.name]["running_var"] = list(layer.blobs[3].data)
59 | params[layer.name]["eps"] = layer.bn_param.eps
60 | params[layer.name]["momentum"] = layer.bn_param.momentum
61 |
62 | return params
63 |
64 |
65 | # Hard-coded layer-name translator (Caffe -> PyTorch)
66 | def translate_layer_name(source):
67 | def conv_or_bn(source):
68 | if "bn" in source:
69 | return ".bn"
70 | else:
71 | return ".conv"
72 |
73 | source = re.split("[_/]", source)
74 | layer = int(source[0][4]) # Remove "conv"
75 | target = ""
76 |
77 | if layer == 1:
78 | target += "fcn.layer{}.conv{}".format(layer, source[1])
79 | target += conv_or_bn(source)
80 | elif layer in range(2, 6):
81 | block = int(source[1])
82 | # Auxiliary layer
83 | if layer == 4 and len(source) == 3 and source[2] == "bn":
84 | target += "aux.conv4_aux.bn"
85 | elif layer == 4 and len(source) == 2:
86 | target += "aux.conv4_aux.conv"
87 | # Pyramid pooling modules
88 | elif layer == 5 and block == 3 and "pool" in source[2]:
89 | pyramid = {1: 3, 2: 2, 3: 1, 6: 0}[int(source[2][4])]
90 | target += "ppm.stages.s{}.conv".format(pyramid)
91 | target += conv_or_bn(source)
92 | # Last convolutions
93 | elif layer == 5 and block == 4:
94 | target += "final.conv5_4"
95 | target += conv_or_bn(source)
96 | else:
97 | target += "fcn.layer{}".format(layer)
98 | target += ".block{}".format(block)
99 | if source[2] == "3x3":
100 | target += ".conv3x3"
101 | else:
102 | target += ".{}".format(source[3])
103 | target += conv_or_bn(source)
104 | elif layer == 6:
105 | if len(source) == 1:
106 | target += "final.conv6"
107 | else:
108 | target += "aux.conv6_1"
109 |
110 | return target
111 |
112 |
113 | @click.command()
114 | @click.option("--config", "-c", required=True)
115 | def main(config):
116 | WHITELIST = ["kernel_size", "stride", "padding", "dilation", "eps", "momentum"]
117 | CONFIG = Dict(yaml.load(open(config)))
118 |
119 | params = parse_caffemodel(CONFIG.CAFFE_MODEL)
120 |
121 | model = PSPNet(
122 | n_classes=CONFIG.N_CLASSES, n_blocks=CONFIG.N_BLOCKS, pyramids=CONFIG.PYRAMIDS
123 | )
124 | model.eval()
125 | own_state = model.state_dict()
126 |
127 | report = []
128 | state_dict = OrderedDict()
129 | for layer_name, layer_dict in params.items():
130 | for param_name, values in layer_dict.items():
131 | if param_name in WHITELIST:
132 | attribute = translate_layer_name(layer_name)
133 | attribute = eval("model." + attribute + "." + param_name)
134 | message = " ".join(
135 | [
136 | layer_name.ljust(25),
137 | "->",
138 | param_name,
139 | "pytorch: " + str(attribute),
140 | "caffe: " + str(values),
141 | ]
142 | )
143 | print(message, end="")
144 | if isinstance(attribute, tuple):
145 | if attribute[0] != values:
146 | report.append(message)
147 | else:
148 | if abs(attribute - values) > 1e-4:
149 | report.append(message)
150 | print(": Checked!")
151 | continue
152 | param_name = translate_layer_name(layer_name) + "." + param_name
153 | if param_name in own_state:
154 | print(layer_name.ljust(25), "->", param_name, end="")
155 | values = torch.FloatTensor(values)
156 | values = values.view_as(own_state[param_name])
157 | state_dict[param_name] = values
158 | print(": Copied!")
159 |
160 | print("Inconsistent parameters (*_3x3 dilation and momentum can be ignored):")
161 | print(*report, sep="\n")
162 |
163 | # Check
164 | model.load_state_dict(state_dict)
165 | torch.save(state_dict, CONFIG.PYTORCH_MODEL)
166 |
167 |
168 | if __name__ == "__main__":
169 | main()
170 |
--------------------------------------------------------------------------------
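Worked examples of the hard-coded translation above. The Caffe layer names follow my reading of the author's prototxt naming convention, so treat them as illustrative rather than authoritative:

```python
from convert import translate_layer_name

print(translate_layer_name("conv1_1_3x3_s2"))      # fcn.layer1.conv1.conv
print(translate_layer_name("conv5_3_pool6_conv"))  # ppm.stages.s0.conv.conv
print(translate_layer_name("conv6"))               # final.conv6
```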
/eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Kazuto Nakashima
5 | # URL: http://kazuto1011.github.io
6 | # Created: 2018-01-24
7 |
8 | import json
9 | import pickle
10 | from math import ceil
11 |
12 | import click
13 | import cv2
14 | import numpy as np
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import yaml
19 | from addict import Dict
20 | from PIL import Image
21 | from torch.autograd import Variable
22 | from torchvision import transforms
23 | from tqdm import tqdm
24 |
25 | from libs.datasets import VOCSegmentation
26 | from libs.models import PSPNet
27 | from libs.utils import scores
28 |
29 |
30 | def pad_image(image, crop_size):
31 | new_h, new_w = image.shape[2:]
32 | pad_h = max(crop_size - new_h, 0)
33 | pad_w = max(crop_size - new_w, 0)
34 | padded_image = torch.FloatTensor(1, 3, new_h + pad_h, new_w + pad_w).zero_()
35 | for i in range(3): # RGB
36 | padded_image[:, [i], ...] = F.pad(
37 | image[:, [i], ...],
38 | pad=(0, pad_w, 0, pad_h), # Pad right and bottom
39 | mode="constant",
40 | value=0,
41 | ).data
42 | return padded_image
43 |
44 |
45 | def to_cuda(tensors, cuda):
46 | return tensors.cuda() if cuda else tensors
47 |
48 |
49 | def to_var(tensors, cuda):
50 | tensors = to_cuda(tensors, cuda)
51 | variables = Variable(tensors, volatile=True)
52 | return variables
53 |
54 |
55 | def flip(x, dim=3):
56 | xsize = x.size()
57 | dim = x.dim() + dim if dim < 0 else dim
58 | x = x.view(-1, *xsize[dim:])
59 | x = x.view(x.size(0), x.size(1), -1)[
60 | :,
61 | getattr(
62 | torch.arange(x.size(1) - 1, -1, -1), ("cpu", "cuda")[x.is_cuda]
63 | )().long(),
64 | :,
65 | ]
66 | return x.view(xsize)
67 |
68 |
69 | def tile_predict(image, model, crop_size, cuda, n_classes):
70 | # Original MATLAB script
71 | # https://github.com/hszhao/PSPNet/blob/master/evaluation/scale_process.m
72 | pad_h, pad_w = image.shape[2:]
73 | stride_rate = 2 / 3.0
74 | stride = int(ceil(crop_size * stride_rate))
75 | h_grid = int(ceil((pad_h - crop_size) / float(stride)) + 1)
76 | w_grid = int(ceil((pad_w - crop_size) / float(stride)) + 1)
77 | count = to_cuda(torch.FloatTensor(1, 1, pad_h, pad_w).zero_(), cuda)
78 | prediction = to_cuda(torch.FloatTensor(1, n_classes, pad_h, pad_w).zero_(), cuda)
79 | for ih in range(h_grid):
80 | for iw in range(w_grid):
81 | sh, sw = ih * stride, iw * stride
82 | eh, ew = min(sh + crop_size, pad_h), min(sw + crop_size, pad_w)
83 | sh, sw = eh - crop_size, ew - crop_size # Stay within image size
84 | image_sub = image[..., sh:eh, sw:ew]
85 | image_sub = pad_image(image_sub, crop_size)
86 | image_sub = to_var(image_sub, cuda)
87 | output = model(image_sub)
88 | output = F.upsample(output, size=(crop_size,) * 2, mode="bilinear")
89 | count[..., sh:eh, sw:ew] += 1
90 | prediction[..., sh:eh, sw:ew] += output.data
91 | prediction /= count # Normalize overlayed parts
92 | return prediction
93 |
94 |
95 | @click.command()
96 | @click.option("--config", "-c", required=True)
97 | @click.option("--cuda/--no-cuda", default=True)
98 | @click.option("--show", is_flag=True)
99 | def main(config, cuda, show):
100 | CONFIG = Dict(yaml.load(open(config)))
101 |
102 | cuda = cuda and torch.cuda.is_available()
103 |
104 | dataset = VOCSegmentation(
105 | root=CONFIG.DATASET_ROOT, image_set="val", dataset_name="VOC2012"
106 | )
107 |
108 | dataloader = torch.utils.data.DataLoader(
109 | dataset=dataset,
110 | batch_size=1, #! DO NOT CHANGE
111 | num_workers=CONFIG.NUM_WORKERS,
112 | pin_memory=False,
113 | shuffle=False,
114 | )
115 |
116 | # Load a model
117 | state_dict = torch.load(CONFIG.PYTORCH_MODEL)
118 |
119 | # Model
120 | model = PSPNet(
121 | n_classes=CONFIG.N_CLASSES, n_blocks=CONFIG.N_BLOCKS, pyramids=CONFIG.PYRAMIDS
122 | )
123 | model.load_state_dict(state_dict)
124 | model = nn.DataParallel(model)
125 | model.eval()
126 | if cuda:
127 | model.cuda()
128 |
129 | crop_size = CONFIG.IMAGE.SIZE.TEST
130 | targets, outputs = [], []
131 |
132 | for image, target in tqdm(
133 | dataloader, total=len(dataloader), leave=False, dynamic_ncols=True
134 | ):
135 |
136 | h, w = image.size()[2:]
137 | outputs_ = []
138 |
139 | for scale in CONFIG.SCALES:
140 |
141 | # Resize
142 | long_side = int(scale * CONFIG.IMAGE.SIZE.BASE)
143 | new_h = long_side
144 | new_w = long_side
145 | if h > w:
146 | new_w = int(long_side * w / h)
147 | else:
148 | new_h = int(long_side * h / w)
149 | image_ = F.upsample(image, size=(new_h, new_w), mode="bilinear").data
150 |
151 | # Predict (w/ flipping)
152 | if long_side <= crop_size:
153 | # Padding evaluation
154 | image_ = pad_image(image_, crop_size)
155 | image_ = to_var(image_, cuda)
156 | output = torch.cat(
157 | (model(image_), flip(model(flip(image_)))) # C, H, W # C, H, W
158 | )
159 | output = F.upsample(output, size=(crop_size,) * 2, mode="bilinear")
160 | # Revert to original size
161 | output = output[..., 0:new_h, 0:new_w]
162 | output = F.upsample(output, size=(h, w), mode="bilinear")
163 | outputs_ += [o for o in output.data] # 2 x [C, H, W]
164 | else:
165 | # Sliced evaluation
166 | image_ = pad_image(image_, crop_size)
167 | output = torch.cat(
168 | (
169 | tile_predict(image_, model, crop_size, cuda, CONFIG.N_CLASSES),
170 | flip(
171 | tile_predict(
172 | flip(image_), model, crop_size, cuda, CONFIG.N_CLASSES
173 | )
174 | ),
175 | )
176 | )
177 | # Revert to original size
178 | output = output[..., 0:new_h, 0:new_w]
179 | output = F.upsample(output, size=(h, w), mode="bilinear")
180 | outputs_ += [o for o in output.data] # 2 x [C, H, W]
181 |
182 | # Average
183 | output = torch.stack(outputs_, dim=0) # 2x#scales, C, H, W
184 | output = torch.mean(output, dim=0) # C, H, W
185 | output = torch.max(output, dim=0)[1] # H, W
186 | output = output.cpu().numpy()
187 | target = target.squeeze(0).numpy()
188 |
189 | if show:
190 | res_gt = np.concatenate((output, target), 1)
191 | mask = (res_gt >= 0)[..., None]
192 | res_gt[res_gt < 0] = 0
193 | res_gt = np.uint8(res_gt / float(CONFIG.N_CLASSES) * 255)
194 | res_gt = cv2.applyColorMap(res_gt, cv2.COLORMAP_JET)
195 | res_gt = np.uint8(res_gt * mask)
196 | img = np.uint8(image.numpy()[0].transpose(1, 2, 0) + dataset.mean_rgb)[
197 | ..., ::-1
198 | ]
199 | img_res_gt = np.concatenate((img, res_gt), 1)
200 | cv2.imshow("result", img_res_gt)
201 | cv2.waitKey(10)
202 |
203 | outputs.append(output)
204 | targets.append(target)
205 |
206 | score, class_iou = scores(targets, outputs, n_class=CONFIG.N_CLASSES)
207 |
208 | for k, v in score.items():
209 | print(k, v)
210 |
211 | score["Class IoU"] = {}
212 | for i in range(CONFIG.N_CLASSES):
213 | score["Class IoU"][i] = class_iou[i]
214 |
215 | with open("results.json", "w") as f:
216 | json.dump(score, f, indent=4, sort_keys=True)
217 |
218 |
219 | if __name__ == "__main__":
220 | main()
221 |
--------------------------------------------------------------------------------
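The `flip` helper in `eval.py` reverses a tensor along one dimension with an index-based gather, which predates `torch.flip`. On PyTorch >= 0.4 the built-in is equivalent; a quick check under that assumption:

```python
import torch
from eval import flip  # the gather-based helper above

x = torch.arange(24.0).view(1, 2, 3, 4)
assert torch.equal(flip(x, dim=3), torch.flip(x, dims=[3]))
```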
/data/datasets/ade20k/labels.txt:
--------------------------------------------------------------------------------
1 | 0 wall
2 | 1 building, edifice
3 | 2 sky
4 | 3 floor, flooring
5 | 4 tree
6 | 5 ceiling
7 | 6 road, route
8 | 7 bed
9 | 8 windowpane, window
10 | 9 grass
11 | 10 cabinet
12 | 11 sidewalk, pavement
13 | 12 person, individual, someone, somebody, mortal, soul
14 | 13 earth, ground
15 | 14 door, double door
16 | 15 table
17 | 16 mountain, mount
18 | 17 plant, flora, plant life
19 | 18 curtain, drape, drapery, mantle, pall
20 | 19 chair
21 | 20 car, auto, automobile, machine, motorcar
22 | 21 water
23 | 22 painting, picture
24 | 23 sofa, couch, lounge
25 | 24 shelf
26 | 25 house
27 | 26 sea
28 | 27 mirror
29 | 28 rug, carpet, carpeting
30 | 29 field
31 | 30 armchair
32 | 31 seat
33 | 32 fence, fencing
34 | 33 desk
35 | 34 rock, stone
36 | 35 wardrobe, closet, press
37 | 36 lamp
38 | 37 bathtub, bathing tub, bath, tub
39 | 38 railing, rail
40 | 39 cushion
41 | 40 base, pedestal, stand
42 | 41 box
43 | 42 column, pillar
44 | 43 signboard, sign
45 | 44 chest of drawers, chest, bureau, dresser
46 | 45 counter
47 | 46 sand
48 | 47 sink
49 | 48 skyscraper
50 | 49 fireplace, hearth, open fireplace
51 | 50 refrigerator, icebox
52 | 51 grandstand, covered stand
53 | 52 path
54 | 53 stairs, steps
55 | 54 runway
56 | 55 case, display case, showcase, vitrine
57 | 56 pool table, billiard table, snooker table
58 | 57 pillow
59 | 58 screen door, screen
60 | 59 stairway, staircase
61 | 60 river
62 | 61 bridge, span
63 | 62 bookcase
64 | 63 blind, screen
65 | 64 coffee table, cocktail table
66 | 65 toilet, can, commode, crapper, pot, potty, stool, throne
67 | 66 flower
68 | 67 book
69 | 68 hill
70 | 69 bench
71 | 70 countertop
72 | 71 stove, kitchen stove, range, kitchen range, cooking stove
73 | 72 palm, palm tree
74 | 73 kitchen island
75 | 74 computer, computing machine, computing device, data processor, electronic computer, information processing system
76 | 75 swivel chair
77 | 76 boat
78 | 77 bar
79 | 78 arcade machine
80 | 79 hovel, hut, hutch, shack, shanty
81 | 80 bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle
82 | 81 towel
83 | 82 light, light source
84 | 83 truck, motortruck
85 | 84 tower
86 | 85 chandelier, pendant, pendent
87 | 86 awning, sunshade, sunblind
88 | 87 streetlight, street lamp
89 | 88 booth, cubicle, stall, kiosk
90 | 89 television, television receiver, television set, tv, tv set, idiot box, boob tube, telly, goggle box
91 | 90 airplane, aeroplane, plane
92 | 91 dirt track
93 | 92 apparel, wearing apparel, dress, clothes
94 | 93 pole
95 | 94 land, ground, soil
96 | 95 bannister, banister, balustrade, balusters, handrail
97 | 96 escalator, moving staircase, moving stairway
98 | 97 ottoman, pouf, pouffe, puff, hassock
99 | 98 bottle
100 | 99 buffet, counter, sideboard
101 | 100 poster, posting, placard, notice, bill, card
102 | 101 stage
103 | 102 van
104 | 103 ship
105 | 104 fountain
106 | 105 conveyer belt, conveyor belt, conveyer, conveyor, transporter
107 | 106 canopy
108 | 107 washer, automatic washer, washing machine
109 | 108 plaything, toy
110 | 109 swimming pool, swimming bath, natatorium
111 | 110 stool
112 | 111 barrel, cask
113 | 112 basket, handbasket
114 | 113 waterfall, falls
115 | 114 tent, collapsible shelter
116 | 115 bag
117 | 116 minibike, motorbike
118 | 117 cradle
119 | 118 oven
120 | 119 ball
121 | 120 food, solid food
122 | 121 step, stair
123 | 122 tank, storage tank
124 | 123 trade name, brand name, brand, marque
125 | 124 microwave, microwave oven
126 | 125 pot, flowerpot
127 | 126 animal, animate being, beast, brute, creature, fauna
128 | 127 bicycle, bike, wheel, cycle
129 | 128 lake
130 | 129 dishwasher, dish washer, dishwashing machine
131 | 130 screen, silver screen, projection screen
132 | 131 blanket, cover
133 | 132 sculpture
134 | 133 hood, exhaust hood
135 | 134 sconce
136 | 135 vase
137 | 136 traffic light, traffic signal, stoplight
138 | 137 tray
139 | 138 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
140 | 139 fan
141 | 140 pier, wharf, wharfage, dock
142 | 141 crt screen
143 | 142 plate
144 | 143 monitor, monitoring device
145 | 144 bulletin board, notice board
146 | 145 shower
147 | 146 radiator
148 | 147 glass, drinking glass
149 | 148 clock
150 | 149 flag
--------------------------------------------------------------------------------
/libs/caffe.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package caffe;
4 |
5 | // Specifies the shape (dimensions) of a Blob.
6 | message BlobShape {
7 | repeated int64 dim = 1 [packed = true];
8 | }
9 |
10 | message BlobProto {
11 | optional BlobShape shape = 7;
12 | repeated float data = 5 [packed = true];
13 | repeated float diff = 6 [packed = true];
14 | repeated double double_data = 8 [packed = true];
15 | repeated double double_diff = 9 [packed = true];
16 |
17 | // 4D dimensions -- deprecated. Use "shape" instead.
18 | optional int32 num = 1 [default = 0];
19 | optional int32 channels = 2 [default = 0];
20 | optional int32 height = 3 [default = 0];
21 | optional int32 width = 4 [default = 0];
22 | }
23 |
24 | // The BlobProtoVector is simply a way to pass multiple blobproto instances
25 | // around.
26 | message BlobProtoVector {
27 | repeated BlobProto blobs = 1;
28 | }
29 |
30 | message Datum {
31 | optional int32 channels = 1;
32 | optional int32 height = 2;
33 | optional int32 width = 3;
34 | // the actual image data, in bytes
35 | optional bytes data = 4;
36 | optional int32 label = 5;
37 | // Optionally, the datum could also hold float data.
38 | repeated float float_data = 6;
39 | // If true data contains an encoded image that need to be decoded
40 | optional bool encoded = 7 [default = false];
41 | }
42 |
43 | message FillerParameter {
44 | // The filler type.
45 | optional string type = 1 [default = 'constant'];
46 | optional float value = 2 [default = 0]; // the value in constant filler
47 | optional float min = 3 [default = 0]; // the min value in uniform filler
48 | optional float max = 4 [default = 1]; // the max value in uniform filler
49 | optional float mean = 5 [default = 0]; // the mean value in Gaussian filler
50 | optional float std = 6 [default = 1]; // the std value in Gaussian filler
51 | // The expected number of non-zero output weights for a given input in
52 | // Gaussian filler -- the default -1 means don't perform sparsification.
53 | optional int32 sparse = 7 [default = -1];
54 | // Normalize the filler variance by fan_in, fan_out, or their average.
55 | // Applies to 'xavier' and 'msra' fillers.
56 | enum VarianceNorm {
57 | FAN_IN = 0;
58 | FAN_OUT = 1;
59 | AVERAGE = 2;
60 | }
61 | optional VarianceNorm variance_norm = 8 [default = FAN_IN];
62 | }
63 |
64 | message NetParameter {
65 | optional string name = 1; // consider giving the network a name
66 | // The input blobs to the network.
67 | repeated string input = 3;
68 | // The shape of the input blobs.
69 | repeated BlobShape input_shape = 8;
70 |
71 | // 4D input dimensions -- deprecated. Use "shape" instead.
72 | // If specified, for each input blob there should be four
73 | // values specifying the num, channels, height and width of the input blob.
74 | // Thus, there should be a total of (4 * #input) numbers.
75 | repeated int32 input_dim = 4;
76 |
77 | // Whether the network will force every layer to carry out backward operation.
78 | // If set False, then whether to carry out backward is determined
79 | // automatically according to the net structure and learning rates.
80 | optional bool force_backward = 5 [default = false];
81 | // The current "state" of the network, including the phase, level, and stage.
82 | // Some layers may be included/excluded depending on this state and the states
83 | // specified in the layers' include and exclude fields.
84 | optional NetState state = 6;
85 |
86 | // Print debugging information about results while running Net::Forward,
87 | // Net::Backward, and Net::Update.
88 | optional bool debug_info = 7 [default = false];
89 |
90 | // The layers that make up the net. Each of their configurations, including
91 | // connectivity and behavior, is specified as a LayerParameter.
92 | repeated LayerParameter layer = 100; // ID 100 so layers are printed last.
93 |
94 | // DEPRECATED: use 'layer' instead.
95 | repeated V1LayerParameter layers = 2;
96 | }
97 |
98 | // NOTE
99 | // Update the next available ID when you add a new SolverParameter field.
100 | //
101 | // SolverParameter next available ID: 41 (last added: type)
102 | message SolverParameter {
103 | //////////////////////////////////////////////////////////////////////////////
104 | // Specifying the train and test networks
105 | //
106 | // Exactly one train net must be specified using one of the following fields:
107 | // train_net_param, train_net, net_param, net
108 | // One or more test nets may be specified using any of the following fields:
109 | // test_net_param, test_net, net_param, net
110 | // If more than one test net field is specified (e.g., both net and
111 | // test_net are specified), they will be evaluated in the field order given
112 | // above: (1) test_net_param, (2) test_net, (3) net_param/net.
113 | // A test_iter must be specified for each test_net.
114 | // A test_level and/or a test_stage may also be specified for each test_net.
115 | //////////////////////////////////////////////////////////////////////////////
116 |
117 | // Proto filename for the train net, possibly combined with one or more
118 | // test nets.
119 | optional string net = 24;
120 | // Inline train net param, possibly combined with one or more test nets.
121 | optional NetParameter net_param = 25;
122 |
123 | optional string train_net = 1; // Proto filename for the train net.
124 | repeated string test_net = 2; // Proto filenames for the test nets.
125 | optional NetParameter train_net_param = 21; // Inline train net params.
126 | repeated NetParameter test_net_param = 22; // Inline test net params.
127 |
128 | // The states for the train/test nets. Must be unspecified or
129 | // specified once per net.
130 | //
131 | // By default, all states will have solver = true;
132 | // train_state will have phase = TRAIN,
133 | // and all test_state's will have phase = TEST.
134 | // Other defaults are set according to the NetState defaults.
135 | optional NetState train_state = 26;
136 | repeated NetState test_state = 27;
137 |
138 | // The number of iterations for each test net.
139 | repeated int32 test_iter = 3;
140 |
141 | // The number of iterations between two testing phases.
142 | optional int32 test_interval = 4 [default = 0];
143 | optional bool test_compute_loss = 19 [default = false];
144 | // If true, run an initial test pass before the first iteration,
145 | // ensuring memory availability and printing the starting value of the loss.
146 | optional bool test_initialization = 32 [default = true];
147 | optional float base_lr = 5; // The base learning rate
148 | // the number of iterations between displaying info. If display = 0, no info
149 | // will be displayed.
150 | optional int32 display = 6;
151 | // Display the loss averaged over the last average_loss iterations
152 | optional int32 average_loss = 33 [default = 1];
153 | optional int32 max_iter = 7; // the maximum number of iterations
154 | // accumulate gradients over `iter_size` x `batch_size` instances
155 | optional int32 iter_size = 36 [default = 1];
156 |
157 | // The learning rate decay policy. The currently implemented learning rate
158 | // policies are as follows:
159 | // - fixed: always return base_lr.
160 | // - step: return base_lr * gamma ^ (floor(iter / step))
161 | // - exp: return base_lr * gamma ^ iter
162 | // - inv: return base_lr * (1 + gamma * iter) ^ (- power)
163 | //    - multistep: similar to step but it allows non-uniform steps defined by
164 | //      stepvalue
165 | //    - poly: the effective learning rate follows a polynomial decay, to be
166 | //      zero by the max_iter. return base_lr * (1 - iter/max_iter) ^ (power)
167 | //    - sigmoid: the effective learning rate follows a sigmoid decay
168 | //      return base_lr * (1 / (1 + exp(-gamma * (iter - stepsize))))
169 | //
170 | // where base_lr, max_iter, gamma, step, stepvalue and power are defined
171 | // in the solver parameter protocol buffer, and iter is the current iteration.
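//
// A worked example of the "step" policy (values illustrative, not from this
// repository): with base_lr = 0.01, gamma = 0.1 and stepsize = 20000, the
// learning rate is 0.01 for iterations [0, 20000), 0.001 for [20000, 40000),
// 0.0001 for [40000, 60000), and so on.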
172 | optional string lr_policy = 8;
173 | optional float gamma = 9; // The parameter to compute the learning rate.
174 | optional float power = 10; // The parameter to compute the learning rate.
175 | optional float momentum = 11; // The momentum value.
176 | optional float weight_decay = 12; // The weight decay.
177 | // regularization types supported: L1 and L2
178 | // controlled by weight_decay
179 | optional string regularization_type = 29 [default = "L2"];
180 | // the stepsize for learning rate policy "step"
181 | optional int32 stepsize = 13;
182 | // the stepsize for learning rate policy "multistep"
183 | repeated int32 stepvalue = 34;
184 |
185 | // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,
186 | // whenever their actual L2 norm is larger.
187 | optional float clip_gradients = 35 [default = -1];
188 |
189 | optional int32 snapshot = 14 [default = 0]; // The snapshot interval
190 | optional string snapshot_prefix = 15; // The prefix for the snapshot.
191 | // whether to snapshot diff in the results or not. Snapshotting diff will help
192 | // debugging but the final protocol buffer size will be much larger.
193 | optional bool snapshot_diff = 16 [default = false];
194 | enum SnapshotFormat {
195 | HDF5 = 0;
196 | BINARYPROTO = 1;
197 | }
198 | optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO];
199 | // The mode the solver will use: 0 for CPU and 1 for GPU. GPU is the default.
200 | enum SolverMode {
201 | CPU = 0;
202 | GPU = 1;
203 | }
204 | optional SolverMode solver_mode = 17 [default = GPU];
205 | // The device_id that will be used in GPU mode. device_id = 0 is the default.
206 | optional int32 device_id = 18 [default = 0];
207 | // If non-negative, the seed with which the Solver will initialize the Caffe
208 | // random number generator -- useful for reproducible results. Otherwise,
209 | // (and by default) initialize using a seed derived from the system clock.
210 | optional int64 random_seed = 20 [default = -1];
211 |
212 | // type of the solver
213 | optional string type = 40 [default = "SGD"];
214 |
215 | // numerical stability for RMSProp, AdaGrad, AdaDelta and Adam
216 | optional float delta = 31 [default = 1e-8];
217 | // parameters for the Adam solver
218 | optional float momentum2 = 39 [default = 0.999];
219 |
220 | // RMSProp decay value
221 | // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
222 | optional float rms_decay = 38;
223 |
224 | // If true, print information about the state of the net that may help with
225 | // debugging learning problems.
226 | optional bool debug_info = 23 [default = false];
227 |
228 | // If false, don't save a snapshot after training finishes.
229 | optional bool snapshot_after_train = 28 [default = true];
230 |
231 | // DEPRECATED: old solver enum types, use string instead
232 | enum SolverType {
233 | SGD = 0;
234 | NESTEROV = 1;
235 | ADAGRAD = 2;
236 | RMSPROP = 3;
237 | ADADELTA = 4;
238 | ADAM = 5;
239 | }
240 | // DEPRECATED: use type instead of solver_type
241 | optional SolverType solver_type = 30 [default = SGD];
242 | }
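// A minimal solver definition in prototxt form, as a sketch; the file name
// and hyper-parameters below are hypothetical, chosen only to show how the
// fields above combine:
//
//   net: "train_val.prototxt"
//   base_lr: 0.01
//   lr_policy: "step"
//   gamma: 0.1
//   stepsize: 20000
//   momentum: 0.9
//   weight_decay: 0.0005
//   max_iter: 100000
//   snapshot: 10000
//   snapshot_prefix: "data/models/snapshot"
//   solver_mode: GPU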
243 |
244 | // A message that stores the solver snapshots
245 | message SolverState {
246 | optional int32 iter = 1; // The current iteration
247 | optional string learned_net = 2; // The file that stores the learned net.
248 | repeated BlobProto history = 3; // The history for sgd solvers
249 | optional int32 current_step = 4 [default = 0]; // The current step for learning rate
250 | }
251 |
252 | enum Phase {
253 | TRAIN = 0;
254 | TEST = 1;
255 | }
256 |
257 | message NetState {
258 | optional Phase phase = 1 [default = TEST];
259 | optional int32 level = 2 [default = 0];
260 | repeated string stage = 3;
261 | }
262 |
263 | message NetStateRule {
264 | // Set phase to require the NetState have a particular phase (TRAIN or TEST)
265 | // to meet this rule.
266 | optional Phase phase = 1;
267 |
268 | // Set the minimum and/or maximum levels in which the layer should be used.
269 | // Leave undefined to meet the rule regardless of level.
270 | optional int32 min_level = 2;
271 | optional int32 max_level = 3;
272 |
273 | // Customizable sets of stages to include or exclude.
274 | // The net must have ALL of the specified stages and NONE of the specified
275 | // "not_stage"s to meet the rule.
276 | // (Use multiple NetStateRules to specify conjunctions of stages.)
277 | repeated string stage = 4;
278 | repeated string not_stage = 5;
279 | }
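// Illustrative prototxt usage (the stage names are made up): a layer active
// only during training would carry
//
//   include { phase: TRAIN }
//
// while a layer used only when the net runs with stage "deploy" and without
// stage "debug" would carry
//
//   include { stage: "deploy" not_stage: "debug" }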
280 |
281 | // Specifies training parameters (multipliers on global learning constants,
282 | // and the name and other settings used for weight sharing).
283 | message ParamSpec {
284 | // The names of the parameter blobs -- useful for sharing parameters among
285 | // layers, but never required otherwise. To share a parameter between two
286 | // layers, give it a (non-empty) name.
287 | optional string name = 1;
288 |
289 | // Whether to require shared weights to have the same shape, or just the same
290 | // count -- defaults to STRICT if unspecified.
291 | optional DimCheckMode share_mode = 2;
292 | enum DimCheckMode {
293 | // STRICT (default) requires that num, channels, height, width each match.
294 | STRICT = 0;
295 | // PERMISSIVE requires only the count (num*channels*height*width) to match.
296 | PERMISSIVE = 1;
297 | }
298 |
299 | // The multiplier on the global learning rate for this parameter.
300 | optional float lr_mult = 3 [default = 1.0];
301 |
302 | // The multiplier on the global weight decay for this parameter.
303 | optional float decay_mult = 4 [default = 1.0];
304 | }
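// As a sketch (the blob name is hypothetical): two layers can share one
// weight blob by declaring the same non-empty param name, here with its
// learning and decay disabled:
//
//   param { name: "shared_weights" lr_mult: 0 decay_mult: 0 }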
305 |
306 | // NOTE
307 | // Update the next available ID when you add a new LayerParameter field.
308 | //
309 | // LayerParameter next available layer-specific ID: 153 (last added: bn_param)
310 | message LayerParameter {
311 | optional string name = 1; // the layer name
312 | optional string type = 2; // the layer type
313 | repeated string bottom = 3; // the name of each bottom blob
314 | repeated string top = 4; // the name of each top blob
315 |
316 | // The train / test phase for computation.
317 | optional Phase phase = 10;
318 |
319 | // The amount of weight to assign each top blob in the objective.
320 | // Each layer assigns a default value, usually of either 0 or 1,
321 | // to each top blob.
322 | repeated float loss_weight = 5;
323 |
324 | // Specifies training parameters (multipliers on global learning constants,
325 | // and the name and other settings used for weight sharing).
326 | repeated ParamSpec param = 6;
327 |
328 | // The blobs containing the numeric parameters of the layer.
329 | repeated BlobProto blobs = 7;
330 |
331 | // Specifies on which bottoms the backpropagation should be skipped.
332 | // The size must be either 0 or equal to the number of bottoms.
333 | repeated bool propagate_down = 11;
334 |
335 | // Rules controlling whether and when a layer is included in the network,
336 | // based on the current NetState. You may specify a non-zero number of rules
337 | // to include OR exclude, but not both. If no include or exclude rules are
338 | // specified, the layer is always included. If the current NetState meets
339 | // ANY (i.e., one or more) of the specified rules, the layer is
340 | // included/excluded.
341 | repeated NetStateRule include = 8;
342 | repeated NetStateRule exclude = 9;
343 |
344 | // Parameters for data pre-processing.
345 | optional TransformationParameter transform_param = 100;
346 |
347 | // Parameters shared by loss layers.
348 | optional LossParameter loss_param = 101;
349 |
350 | // Layer type-specific parameters.
351 | //
352 | // Note: certain layers may have more than one computational engine
353 | // for their implementation. These layers include an Engine type and
354 | // engine parameter for selecting the implementation.
355 | // The default for the engine is set by the ENGINE switch at compile-time.
356 | optional AccuracyParameter accuracy_param = 102;
357 | optional AdaptiveBiasChannelParameter adaptive_bias_channel_param = 148;
358 | optional ArgMaxParameter argmax_param = 103;
359 | optional BatchNormParameter batch_norm_param = 139;
360 | optional BNParameter bn_param = 152;
361 | optional BiasParameter bias_param = 141;
362 | optional BiasChannelParameter bias_channel_param = 149;
363 | optional ConcatParameter concat_param = 104;
364 | optional ContrastiveLossParameter contrastive_loss_param = 105;
365 | optional ConvolutionParameter convolution_param = 106;
366 | optional DataParameter data_param = 107;
367 | optional DenseCRFParameter dense_crf_param = 146;
368 | optional DomainTransformParameter domain_transform_param = 147;
369 | optional DropoutParameter dropout_param = 108;
370 | optional DummyDataParameter dummy_data_param = 109;
371 | optional EltwiseParameter eltwise_param = 110;
372 | optional ELUParameter elu_param = 140;
373 | optional EmbedParameter embed_param = 137;
374 | optional ExpParameter exp_param = 111;
375 | optional FlattenParameter flatten_param = 135;
376 | optional HDF5DataParameter hdf5_data_param = 112;
377 | optional HDF5OutputParameter hdf5_output_param = 113;
378 | optional HingeLossParameter hinge_loss_param = 114;
379 | optional ImageDataParameter image_data_param = 115;
380 | optional InfogainLossParameter infogain_loss_param = 116;
381 | optional InnerProductParameter inner_product_param = 117;
382 | optional InterpParameter interp_param = 143;
383 | optional LogParameter log_param = 134;
384 | optional LRNParameter lrn_param = 118;
385 | optional MatReadParameter mat_read_param = 151;
386 | optional MatWriteParameter mat_write_param = 145;
387 | optional MemoryDataParameter memory_data_param = 119;
388 | optional MVNParameter mvn_param = 120;
389 | optional PoolingParameter pooling_param = 121;
390 | optional PowerParameter power_param = 122;
391 | optional PReLUParameter prelu_param = 131;
392 | optional PythonParameter python_param = 130;
393 | optional ReductionParameter reduction_param = 136;
394 | optional ReLUParameter relu_param = 123;
395 | optional ReshapeParameter reshape_param = 133;
396 | optional ScaleParameter scale_param = 142;
397 | optional SegAccuracyParameter seg_accuracy_param = 144;
398 | optional SigmoidParameter sigmoid_param = 124;
399 | optional SoftmaxParameter softmax_param = 125;
400 | optional SPPParameter spp_param = 132;
401 | optional SliceParameter slice_param = 126;
402 | optional TanHParameter tanh_param = 127;
403 | optional ThresholdParameter threshold_param = 128;
404 | optional TileParameter tile_param = 138;
405 | optional UniqueLabelParameter unique_label_param = 150;
406 | optional WindowDataParameter window_data_param = 129;
407 | }
408 |
409 | // Message that stores parameters used to apply transformation
410 | // to the data layer's data
411 | message TransformationParameter {
412 | // For data pre-processing, we can do simple scaling and subtracting the
413 | // data mean, if provided. Note that the mean subtraction is always carried
414 | // out before scaling.
415 | optional float scale = 1 [default = 1];
416 | // Specify if we want to randomly mirror data.
417 | optional bool mirror = 2 [default = false];
418 | // Specify if we would like to randomly crop an image.
419 | optional uint32 crop_size = 3 [default = 0];
420 | // mean_file and mean_value cannot be specified at the same time
421 | optional string mean_file = 4;
422 | // If specified, can be repeated once (would subtract it from all the channels)
423 | // or can be repeated the same number of times as channels
424 | // (would subtract them from the corresponding channel)
425 | repeated float mean_value = 5;
426 | // Force the decoded image to have 3 color channels.
427 | optional bool force_color = 6 [default = false];
428 | // Force the decoded image to have 1 color channel.
429 | optional bool force_gray = 7 [default = false];
430 | // Scaling factors for randomly scaling input images during data augmentation.
431 | repeated float scale_factors = 8;
432 | // the width of the cropped region
433 | optional uint32 crop_width = 9 [default = 0];
434 | // the height of the cropped region
435 | optional uint32 crop_height = 10 [default = 0];
436 |
437 | }
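// A transform_param sketch (mean values illustrative; Caffe pipelines
// commonly use BGR channel order):
//
//   transform_param {
//     mirror: true
//     crop_size: 473
//     mean_value: 104.008
//     mean_value: 116.669
//     mean_value: 122.675
//   }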
438 |
439 | // Message that stores parameters shared by loss layers
440 | message LossParameter {
441 | // If specified, ignore instances with the given label.
442 | optional int32 ignore_label = 1;
443 | // How to normalize the loss for loss layers that aggregate across batches,
444 | // spatial dimensions, or other dimensions. Currently only implemented in
445 | // SoftmaxWithLoss layer.
446 | enum NormalizationMode {
447 | // Divide by the number of examples in the batch times spatial dimensions.
448 | // Outputs that receive the ignore label will NOT be ignored in computing
449 | // the normalization factor.
450 | FULL = 0;
451 | // Divide by the total number of output locations that do not take the
452 | // ignore_label. If ignore_label is not set, this behaves like FULL.
453 | VALID = 1;
454 | // Divide by the batch size.
455 | BATCH_SIZE = 2;
456 | // Do not normalize the loss.
457 | NONE = 3;
458 | }
459 | optional NormalizationMode normalization = 3 [default = VALID];
460 | // Deprecated. Ignored if normalization is specified. If normalization
461 | // is not specified, then setting this to false will be equivalent to
462 | // normalization = BATCH_SIZE to be consistent with previous behavior.
463 | optional bool normalize = 2;
464 | }
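// A worked example of the normalization modes (numbers illustrative): for a
// 2 x 21 x 4 x 4 prediction (a batch of 2 with 16 pixels each) in which 5 of
// the 32 labels equal ignore_label, the loss is divided by 32 under FULL,
// by 27 under VALID, by 2 under BATCH_SIZE, and by 1 under NONE.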
465 |
466 | // Messages that store parameters used by individual layer types follow, in
467 | // alphabetical order.
468 |
469 | message AccuracyParameter {
470 | // When computing accuracy, count as correct by comparing the true label to
471 | // the top k scoring classes. By default, only compare to the top scoring
472 | // class (i.e. argmax).
473 | optional uint32 top_k = 1 [default = 1];
474 |
475 | // The "label" axis of the prediction blob, whose argmax corresponds to the
476 | // predicted label -- may be negative to index from the end (e.g., -1 for the
477 | // last axis). For example, if axis == 1 and the predictions are
478 | // (N x C x H x W), the label blob is expected to contain N*H*W ground truth
479 | // labels with integer values in {0, 1, ..., C-1}.
480 | optional int32 axis = 2 [default = 1];
481 |
482 | // If specified, ignore instances with the given label.
483 | optional int32 ignore_label = 3;
484 | }
485 |
486 | message AdaptiveBiasChannelParameter {
487 | optional int32 num_iter = 1 [default = 1];
488 | optional float bg_portion = 2 [default = 0.2];
489 | optional float fg_portion = 3 [default = 0.2];
490 | optional bool suppress_others = 4 [default = true];
491 | optional float margin_others = 5 [default = 1e-5];
492 | }
493 |
494 | message ArgMaxParameter {
495 | // If true produce pairs (argmax, maxval)
496 | optional bool out_max_val = 1 [default = false];
497 | optional uint32 top_k = 2 [default = 1];
498 | // The axis along which to maximize -- may be negative to index from the
499 | // end (e.g., -1 for the last axis).
500 | // By default ArgMaxLayer maximizes over the flattened trailing dimensions
501 | // for each index of the first / num dimension.
502 | optional int32 axis = 3;
503 | }
504 |
505 | message BiasChannelParameter {
506 | // Score biases. Separate values for BG / FG
507 | optional float bg_bias = 1 [default = 1.];
508 | optional float fg_bias = 2 [default = 2.];
509 | // will ignore labels with this value when adding bias
510 | repeated int32 ignore_label = 3;
511 | enum LabelType {
512 | IMAGE = 1;
513 | PIXEL = 2;
514 | }
515 | optional LabelType label_type = 4 [default = IMAGE];
516 | // Whether the dataset defines a generic background label.
517 | // The default value is defined for PASCAL VOC segmentation
518 | optional int32 background_label = 6 [default = 0];
519 | }
520 |
521 | message ConcatParameter {
522 | // The axis along which to concatenate -- may be negative to index from the
523 | // end (e.g., -1 for the last axis). Other axes must have the
524 | // same dimension for all the bottom blobs.
525 | // By default, ConcatLayer concatenates blobs along the "channels" axis (1).
526 | optional int32 axis = 2 [default = 1];
527 |
528 | // DEPRECATED: alias for "axis" -- does not support negative indexing.
529 | optional uint32 concat_dim = 1 [default = 1];
530 | }
531 |
532 | message BatchNormParameter {
533 | // If false, accumulate global mean/variance values via a moving average. If
534 | // true, use those accumulated values instead of computing mean/variance
535 | // across the batch.
536 | optional bool use_global_stats = 1;
537 | // How much does the moving average decay each iteration?
538 | optional float moving_average_fraction = 2 [default = .999];
539 | // Small value to add to the variance estimate so that we don't divide by
540 | // zero.
541 | optional float eps = 3 [default = 1e-5];
542 | optional bool update_global_stats = 4 [default = false];
543 | }
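// Behavioral note (summarizing the BVLC reference implementation): when
// use_global_stats is left unset, the layer computes batch statistics in the
// TRAIN phase and uses the accumulated moving averages in the TEST phase.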
544 |
545 | message BNParameter {
546 | optional FillerParameter slope_filler = 1;
547 | optional FillerParameter bias_filler = 2;
548 | optional float momentum = 3 [default = 0.9];
549 | optional float eps = 4 [default = 1e-5];
550 | // If true, will use the moving average mean and std for training and test.
551 | // Will override the lr_param and freeze all the parameters.
552 | // Make sure to initialize the layer properly with pretrained parameters.
553 | optional bool frozen = 5 [default = false];
554 | enum Engine {
555 | DEFAULT = 0;
556 | CAFFE = 1;
557 | CUDNN = 2;
558 | }
559 | optional Engine engine = 6 [default = DEFAULT];
560 | }
561 |
562 | message BiasParameter {
563 | // The first axis of bottom[0] (the first input Blob) along which to apply
564 | // bottom[1] (the second input Blob). May be negative to index from the end
565 | // (e.g., -1 for the last axis).
566 | //
567 | // For example, if bottom[0] is 4D with shape 100x3x40x60, the output
568 | // top[0] will have the same shape, and bottom[1] may have any of the
569 | // following shapes (for the given value of axis):
570 | // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60
571 | // (axis == 1 == -3) 3; 3x40; 3x40x60
572 | // (axis == 2 == -2) 40; 40x60
573 | // (axis == 3 == -1) 60
574 | // Furthermore, bottom[1] may have the empty shape (regardless of the value of
575 | // "axis") -- a scalar bias.
576 | optional int32 axis = 1 [default = 1];
577 |
578 | // (num_axes is ignored unless just one bottom is given and the bias is
579 | // a learned parameter of the layer. Otherwise, num_axes is determined by the
580 | // number of axes of the second bottom.)
581 | // The number of axes of the input (bottom[0]) covered by the bias
582 | // parameter, or -1 to cover all axes of bottom[0] starting from `axis`.
583 | // Set num_axes := 0, to add a zero-axis Blob: a scalar.
584 | optional int32 num_axes = 2 [default = 1];
585 |
586 | // (filler is ignored unless just one bottom is given and the bias is
587 | // a learned parameter of the layer.)
588 | // The initialization for the learned bias parameter.
589 | // Default is the zero (0) initialization, resulting in the BiasLayer
590 | // initially performing the identity operation.
591 | optional FillerParameter filler = 3;
592 | }
593 |
594 | message ContrastiveLossParameter {
595 | // margin for dissimilar pair
596 | optional float margin = 1 [default = 1.0];
597 | // The first implementation of this cost did not exactly match the cost of
598 | // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2.
599 | // legacy_version = false (the default) uses (margin - d)^2 as proposed in the
600 | // Hadsell paper. New models should probably use this version.
601 | // legacy_version = true uses (margin - d^2). This is kept to support /
602 | // reproduce existing models and results
603 | optional bool legacy_version = 2 [default = false];
604 | }
605 |
606 | message ConvolutionParameter {
607 | optional uint32 num_output = 1; // The number of outputs for the layer
608 | optional bool bias_term = 2 [default = true]; // whether to have bias terms
609 |
610 | // Pad, kernel size, and stride are all given as a single value for equal
611 | // dimensions in all spatial dimensions, or once per spatial dimension.
612 | repeated uint32 pad = 3; // The padding size; defaults to 0
613 | repeated uint32 kernel_size = 4; // The kernel size
614 | repeated uint32 stride = 6; // The stride; defaults to 1
615 | // Factor used to dilate the kernel, (implicitly) zero-filling the resulting
616 | // holes. (Kernel dilation is sometimes referred to by its use in the
617 | // algorithme à trous from Holschneider et al. 1987.)
618 | repeated uint32 dilation = 18; // The dilation; defaults to 1
619 |
620 | // For 2D convolution only, the *_h and *_w versions may also be used to
621 | // specify both spatial dimensions.
622 | optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only)
623 | optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only)
624 | optional uint32 kernel_h = 11; // The kernel height (2D only)
625 | optional uint32 kernel_w = 12; // The kernel width (2D only)
626 | optional uint32 stride_h = 13; // The stride height (2D only)
627 | optional uint32 stride_w = 14; // The stride width (2D only)
628 |
629 | optional uint32 group = 5 [default = 1]; // The group size for group conv
630 |
631 | optional FillerParameter weight_filler = 7; // The filler for the weight
632 | optional FillerParameter bias_filler = 8; // The filler for the bias
633 | enum Engine {
634 | DEFAULT = 0;
635 | CAFFE = 1;
636 | CUDNN = 2;
637 | }
638 | optional Engine engine = 15 [default = DEFAULT];
639 |
640 | // The axis to interpret as "channels" when performing convolution.
641 | // Preceding dimensions are treated as independent inputs;
642 | // succeeding dimensions are treated as "spatial".
643 | // With (N, C, H, W) inputs, and axis == 1 (the default), we perform
644 | // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for
645 | // groups g>1) filters across the spatial axes (H, W) of the input.
646 | // With (N, C, D, H, W) inputs, and axis == 1, we perform
647 | // N independent 3D convolutions, sliding (C/g)-channels
648 | // filters across the spatial axes (D, H, W) of the input.
649 | optional int32 axis = 16 [default = 1];
650 |
651 | // Whether to force use of the general ND convolution, even if a specific
652 | // implementation for blobs of the appropriate number of spatial dimensions
653 | // is available. (Currently, there is only a 2D-specific convolution
654 | // implementation; for input blobs with num_axes != 2, this option is
655 | // ignored and the ND implementation will be used.)
656 | optional bool force_nd_im2col = 17 [default = false];
657 | }
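// A dilated-convolution sketch (values illustrative): with kernel_size = 3
// and dilation = 2 the effective kernel extent is 3 + (3 - 1) * (2 - 1) = 5,
// so pad = 2 preserves the spatial size at stride 1:
//
//   convolution_param {
//     num_output: 512
//     kernel_size: 3
//     dilation: 2
//     pad: 2
//   }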
658 |
659 | message DataParameter {
660 | enum DB {
661 | LEVELDB = 0;
662 | LMDB = 1;
663 | }
664 | // Specify the data source.
665 | optional string source = 1;
666 | // Specify the batch size.
667 | optional uint32 batch_size = 4;
668 | // The rand_skip variable is for the data layer to skip a few data points
669 | // to prevent all asynchronous sgd clients from starting at the same point.
670 | // The skip point would be set as rand_skip * rand(0,1). Note that rand_skip
671 | // should not be larger than the number of keys in the database.
672 | // DEPRECATED. Each solver accesses a different subset of the database.
673 | optional uint32 rand_skip = 7 [default = 0];
674 | optional DB backend = 8 [default = LEVELDB];
675 | // DEPRECATED. See TransformationParameter. For data pre-processing, we can do
676 | // simple scaling and subtracting the data mean, if provided. Note that the
677 | // mean subtraction is always carried out before scaling.
678 | optional float scale = 2 [default = 1];
679 | optional string mean_file = 3;
680 | // DEPRECATED. See TransformationParameter. Specify if we would like to randomly
681 | // crop an image.
682 | optional uint32 crop_size = 5 [default = 0];
683 | // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror
684 | // data.
685 | optional bool mirror = 6 [default = false];
686 | // Force the encoded image to have 3 color channels
687 | optional bool force_encoded_color = 9 [default = false];
688 | // Prefetch queue (Number of batches to prefetch to host memory, increase if
689 | // data access bandwidth varies).
690 | optional uint32 prefetch = 10 [default = 4];
691 | }
692 |
693 | message DenseCRFParameter {
694 | // max number of iterations for message passing
695 | optional int32 max_iter = 1 [default = 10];
696 | // positional std and weight for "Positional" filter (color-independent)
697 | repeated float pos_xy_std = 2;
698 | repeated float pos_w = 3;
699 | // positional std, color std and weight for Bilateral filter
700 | repeated float bi_xy_std = 4;
701 | repeated float bi_rgb_std = 5;
702 | repeated float bi_w = 6;
703 | // output is probability or score (score = log(prob))
704 | optional bool output_probability = 7 [default = true];
705 | }
706 |
707 | message DomainTransformParameter {
708 | // Max number of iterations for filtering.
709 | optional int32 num_iter = 1 [default = 3];
710 | // Standard deviation for spatial domain.
711 | optional float spatial_sigma = 2 [default = 50];
712 | // Standard deviation for range domain.
713 | optional float range_sigma = 3 [default = 5];
714 | // minimum weight value (to avoid zero gradient for ref_grad_data)
715 | optional float min_weight = 4 [default = 0];
716 | }
717 |
718 | message DropoutParameter {
719 | optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio
720 | }
721 |
722 | // DummyDataLayer fills any number of arbitrarily shaped blobs with random
723 | // (or constant) data generated by "Fillers" (see "message FillerParameter").
724 | message DummyDataParameter {
725 | // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N
726 | // shape fields, and 0, 1 or N data_fillers.
727 | //
728 | // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used.
729 | // If 1 data_filler is specified, it is applied to all top blobs. If N are
730 | // specified, the ith is applied to the ith top blob.
731 | repeated FillerParameter data_filler = 1;
732 | repeated BlobShape shape = 6;
733 |
734 | // 4D dimensions -- deprecated. Use "shape" instead.
735 | repeated uint32 num = 2;
736 | repeated uint32 channels = 3;
737 | repeated uint32 height = 4;
738 | repeated uint32 width = 5;
739 | }
740 |
741 | message EltwiseParameter {
742 | enum EltwiseOp {
743 | PROD = 0;
744 | SUM = 1;
745 | MAX = 2;
746 | }
747 | optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation
748 | repeated float coeff = 2; // blob-wise coefficient for SUM operation
749 |
750 | // Whether to use an asymptotically slower (for >2 inputs) but stabler method
751 | // of computing the gradient for the PROD operation. (No effect for SUM op.)
752 | optional bool stable_prod_grad = 3 [default = true];
753 | }
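// For instance, a weighted difference of two bottoms can be written as a
// SUM with per-blob coefficients (an illustrative snippet):
//
//   eltwise_param { operation: SUM coeff: 1 coeff: -1 }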
754 |
755 | // Message that stores parameters used by ELULayer
756 | message ELUParameter {
757 | // Described in:
758 | // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate
759 | // Deep Network Learning by Exponential Linear Units (ELUs). arXiv
760 | optional float alpha = 1 [default = 1];
761 | }
762 |
763 | // Message that stores parameters used by EmbedLayer
764 | message EmbedParameter {
765 | optional uint32 num_output = 1; // The number of outputs for the layer
766 | // The input is given as integers to be interpreted as one-hot
767 | // vector indices with dimension input_dim. Hence input_dim should be
768 | // 1 greater than the maximum possible input value.
769 | optional uint32 input_dim = 2;
770 |
771 | optional bool bias_term = 3 [default = true]; // Whether to use a bias term
772 | optional FillerParameter weight_filler = 4; // The filler for the weight
773 | optional FillerParameter bias_filler = 5; // The filler for the bias
774 |
775 | }
776 |
777 | // Message that stores parameters used by ExpLayer
778 | message ExpParameter {
779 | // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0.
780 | // Or if base is set to the default (-1), base is set to e,
781 | // so y = exp(shift + scale * x).
782 | optional float base = 1 [default = -1.0];
783 | optional float scale = 2 [default = 1.0];
784 | optional float shift = 3 [default = 0.0];
785 | }
786 |
787 | // Message that stores parameters used by FlattenLayer
788 | message FlattenParameter {
789 | // The first axis to flatten: all preceding axes are retained in the output.
790 | // May be negative to index from the end (e.g., -1 for the last axis).
791 | optional int32 axis = 1 [default = 1];
792 |
793 | // The last axis to flatten: all following axes are retained in the output.
794 | // May be negative to index from the end (e.g., the default -1 for the last
795 | // axis).
796 | optional int32 end_axis = 2 [default = -1];
797 | }
798 |
799 | // Message that stores parameters used by HDF5DataLayer
800 | message HDF5DataParameter {
801 | // Specify the data source.
802 | optional string source = 1;
803 | // Specify the batch size.
804 | optional uint32 batch_size = 2;
805 |
806 | // Specify whether to shuffle the data.
807 | // If shuffle == true, the ordering of the HDF5 files is shuffled,
808 | // and the ordering of data within any given HDF5 file is shuffled,
809 | // but data between different files are not interleaved; all of a file's
810 | // data are output (in a random order) before moving onto another file.
811 | optional bool shuffle = 3 [default = false];
812 | }
813 |
814 | message HDF5OutputParameter {
815 | optional string file_name = 1;
816 | }
817 |
818 | message HingeLossParameter {
819 | enum Norm {
820 | L1 = 1;
821 | L2 = 2;
822 | }
823 | // Specify the Norm to use L1 or L2
824 | optional Norm norm = 1 [default = L1];
825 | }
826 |
827 | message ImageDataParameter {
828 | // Specify the data source.
829 | optional string source = 1;
830 | // Specify the batch size.
831 | optional uint32 batch_size = 4 [default = 1];
832 | // The rand_skip variable is for the data layer to skip a few data points
833 | // to prevent all asynchronous sgd clients from starting at the same point.
834 | // The skip point would be set as rand_skip * rand(0,1). Note that rand_skip
835 | // should not be larger than the number of keys in the database.
836 | optional uint32 rand_skip = 7 [default = 0];
837 | // Whether or not ImageLayer should shuffle the list of files at every epoch.
838 | optional bool shuffle = 8 [default = false];
839 | // It will also resize images if new_height or new_width are not zero.
840 | optional uint32 new_height = 9 [default = 0];
841 | optional uint32 new_width = 10 [default = 0];
842 | // Specify if the images are color or gray
843 | optional bool is_color = 11 [default = true];
844 |
845 | // This is the value set for pixels or images where we don't know the label
846 | optional int32 ignore_label = 15 [default = 255];
847 | enum LabelType {
848 | NONE = 0;
849 | IMAGE = 1;
850 | PIXEL = 2;
851 | }
852 | optional LabelType label_type = 16 [default = IMAGE];
853 |
854 | // DEPRECATED. See TransformationParameter. For data pre-processing, we can do
855 | // simple scaling and subtracting the data mean, if provided. Note that the
856 | // mean subtraction is always carried out before scaling.
857 | optional float scale = 2 [default = 1];
858 | optional string mean_file = 3;
859 | // DEPRECATED. See TransformationParameter. Specify if we would like to randomly
860 | // crop an image.
861 | optional uint32 crop_size = 5 [default = 0];
862 | // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror
863 | // data.
864 | optional bool mirror = 6 [default = false];
865 | optional string root_folder = 12 [default = ""];
866 | }
867 |
868 | message InfogainLossParameter {
869 | // Specify the infogain matrix source.
870 | optional string source = 1;
871 | }
872 |
873 | message InnerProductParameter {
874 | optional uint32 num_output = 1; // The number of outputs for the layer
875 | optional bool bias_term = 2 [default = true]; // whether to have bias terms
876 | optional FillerParameter weight_filler = 3; // The filler for the weight
877 | optional FillerParameter bias_filler = 4; // The filler for the bias
878 |
879 | // The first axis to be lumped into a single inner product computation;
880 | // all preceding axes are retained in the output.
881 | // May be negative to index from the end (e.g., -1 for the last axis).
882 | optional int32 axis = 5 [default = 1];
883 | // Specify whether to transpose the weight matrix or not.
884 | // If transpose == true, any operations will be performed on the transpose
885 | // of the weight matrix. The weight matrix itself is not going to be transposed;
886 | // rather, the transpose flag of the operations will be toggled accordingly.
887 | optional bool transpose = 6 [default = false];
888 | }
889 |
890 | message InterpParameter {
891 | optional int32 height = 1 [default = 0]; // Height of output
892 | optional int32 width = 2 [default = 0]; // Width of output
893 | optional int32 zoom_factor = 3 [default = 1]; // zoom factor
894 | optional int32 shrink_factor = 4 [default = 1]; // shrink factor
895 | optional int32 pad_beg = 5 [default = 0]; // padding at begin of input
896 | optional int32 pad_end = 6 [default = 0]; // padding at end of input
897 | }
898 |
899 | // Message that stores parameters used by LogLayer
900 | message LogParameter {
901 | // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0.
902 | // Or if base is set to the default (-1), base is set to e,
903 | // so y = ln(shift + scale * x) = log_e(shift + scale * x)
904 | optional float base = 1 [default = -1.0];
905 | optional float scale = 2 [default = 1.0];
906 | optional float shift = 3 [default = 0.0];
907 | }
908 |
909 | // Message that stores parameters used by LRNLayer
910 | message LRNParameter {
911 | optional uint32 local_size = 1 [default = 5];
912 | optional float alpha = 2 [default = 1.];
913 | optional float beta = 3 [default = 0.75];
914 | enum NormRegion {
915 | ACROSS_CHANNELS = 0;
916 | WITHIN_CHANNEL = 1;
917 | }
918 | optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS];
919 | optional float k = 5 [default = 1.];
920 | enum Engine {
921 | DEFAULT = 0;
922 | CAFFE = 1;
923 | CUDNN = 2;
924 | }
925 | optional Engine engine = 6 [default = DEFAULT];
926 | }
927 |
928 | message MatReadParameter {
929 | required string prefix = 1;
930 | optional string source = 2 [default = ""];
931 | optional int32 strip = 3 [default = 0];
932 | optional int32 batch_size = 4 [default = 1];
933 | }
934 |
935 | message MatWriteParameter {
936 | required string prefix = 1;
937 | optional string source = 2 [default = ""];
938 | optional int32 strip = 3 [default = 0];
939 | optional int32 period = 4 [default = 1];
940 | }
941 |
942 | message MemoryDataParameter {
943 | optional uint32 batch_size = 1;
944 | optional uint32 channels = 2;
945 | optional uint32 height = 3;
946 | optional uint32 width = 4;
947 | }
948 |
949 | message MVNParameter {
950 | // This parameter can be set to false to normalize mean only
951 | optional bool normalize_variance = 1 [default = true];
952 |
953 | // This parameter can be set to true to perform DNN-like MVN
954 | optional bool across_channels = 2 [default = false];
955 |
956 | // Epsilon for not dividing by zero while normalizing variance
957 | optional float eps = 3 [default = 1e-9];
958 | }
959 |
960 | message PoolingParameter {
961 | enum PoolMethod {
962 | MAX = 0;
963 | AVE = 1;
964 | STOCHASTIC = 2;
965 | }
966 | optional PoolMethod pool = 1 [default = MAX]; // The pooling method
967 | // Pad, kernel size, and stride are all given as a single value for equal
968 | // dimensions in height and width or as Y, X pairs.
969 | optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X)
970 | optional uint32 pad_h = 9 [default = 0]; // The padding height
971 | optional uint32 pad_w = 10 [default = 0]; // The padding width
972 | optional uint32 kernel_size = 2; // The kernel size (square)
973 | optional uint32 kernel_h = 5; // The kernel height
974 | optional uint32 kernel_w = 6; // The kernel width
975 | optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X)
976 | optional uint32 stride_h = 7; // The stride height
977 | optional uint32 stride_w = 8; // The stride width
978 | enum Engine {
979 | DEFAULT = 0;
980 | CAFFE = 1;
981 | CUDNN = 2;
982 | }
983 | optional Engine engine = 11 [default = DEFAULT];
984 | // If global_pooling is set, the layer pools over the full extent of the bottom
985 | // by setting kernel_h = bottom->height and kernel_w = bottom->width
986 | optional bool global_pooling = 12 [default = false];
987 | }
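// A global average pooling sketch, e.g. the coarsest level of a pyramid
// pooling module (illustrative):
//
//   pooling_param { pool: AVE global_pooling: true }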
988 |
989 | message PowerParameter {
990 | // PowerLayer computes outputs y = (shift + scale * x) ^ power.
991 | optional float power = 1 [default = 1.0];
992 | optional float scale = 2 [default = 1.0];
993 | optional float shift = 3 [default = 0.0];
994 | }
995 |
996 | message PythonParameter {
997 | optional string module = 1;
998 | optional string layer = 2;
999 | // This value is set to the attribute `param_str` of the `PythonLayer` object
1000 | // in Python before calling the `setup()` method. This could be a number,
1001 | // string, dictionary in Python dict format, JSON, etc. You may parse this
1002 | // string in the `setup` method and use it in `forward` and `backward`.
1003 | optional string param_str = 3 [default = ''];
1004 | // Whether this PythonLayer is shared among worker solvers during data parallelism.
1005 | // If true, each worker solver sequentially runs forward from this layer.
1006 | // This value should be set true if you are using it as a data layer.
1007 | optional bool share_in_parallel = 4 [default = false];
1008 | }
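// A sketch of a Python data layer (the module, class and param_str below
// are hypothetical):
//
//   layer {
//     name: "data"  type: "Python"  top: "data"  top: "label"
//     python_param {
//       module: "my_data_layer"
//       layer: "MyDataLayer"
//       param_str: "{'batch_size': 8}"
//       share_in_parallel: true
//     }
//   }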
1009 |
1010 | // Message that stores parameters used by ReductionLayer
1011 | message ReductionParameter {
1012 | enum ReductionOp {
1013 | SUM = 1;
1014 | ASUM = 2;
1015 | SUMSQ = 3;
1016 | MEAN = 4;
1017 | }
1018 |
1019 | optional ReductionOp operation = 1 [default = SUM]; // reduction operation
1020 |
1021 | // The first axis to reduce to a scalar -- may be negative to index from the
1022 | // end (e.g., -1 for the last axis).
1023 | // (Currently, only reduction along ALL "tail" axes is supported; reduction
1024 | // of axis M through N, where N < num_axes - 1, is unsupported.)
1025 | // Suppose we have an n-axis bottom Blob with shape:
1026 | // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)).
1027 | // If axis == m, the output Blob will have shape
1028 | // (d0, d1, d2, ..., d(m-1)),
1029 | // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1))
1030 | // times, each including (dm * d(m+1) * ... * d(n-1)) individual data.
1031 | // If axis == 0 (the default), the output Blob always has the empty shape
1032 | // (count 1), performing reduction across the entire input --
1033 | // often useful for creating new loss functions.
1034 | optional int32 axis = 2 [default = 0];
1035 |
1036 | optional float coeff = 3 [default = 1.0]; // coefficient for output
1037 | }
1038 |
1039 | // Message that stores parameters used by ReLULayer
1040 | message ReLUParameter {
1041 | // Allow non-zero slope for negative inputs to speed up optimization
1042 | // Described in:
1043 | // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities
1044 | // improve neural network acoustic models. In ICML Workshop on Deep Learning
1045 | // for Audio, Speech, and Language Processing.
1046 | optional float negative_slope = 1 [default = 0];
1047 | enum Engine {
1048 | DEFAULT = 0;
1049 | CAFFE = 1;
1050 | CUDNN = 2;
1051 | }
1052 | optional Engine engine = 2 [default = DEFAULT];
1053 | }
1054 |
1055 | message ReshapeParameter {
1056 | // Specify the output dimensions. If some of the dimensions are set to 0,
1057 | // the corresponding dimension from the bottom layer is used (unchanged).
1058 | // Exactly one dimension may be set to -1, in which case its value is
1059 | // inferred from the count of the bottom blob and the remaining dimensions.
1060 | // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8:
1061 | //
1062 | // layer {
1063 | // type: "Reshape" bottom: "input" top: "output"
1064 | // reshape_param { ... }
1065 | // }
1066 | //
1067 | // If "input" is 2D with shape 2 x 8, then the following reshape_param
1068 | // specifications are all equivalent, producing a 3D blob "output" with shape
1069 | // 2 x 2 x 4:
1070 | //
1071 | // reshape_param { shape { dim: 2 dim: 2 dim: 4 } }
1072 | // reshape_param { shape { dim: 0 dim: 2 dim: 4 } }
1073 | // reshape_param { shape { dim: 0 dim: 2 dim: -1 } }
1074 | // reshape_param { shape { dim: -1 dim: 0 dim: 2 } }
1075 | //
1076 | optional BlobShape shape = 1;
1077 |
1078 | // axis and num_axes control the portion of the bottom blob's shape that are
1079 | // replaced by (included in) the reshape. By default (axis == 0 and
1080 | // num_axes == -1), the entire bottom blob shape is included in the reshape,
1081 | // and hence the shape field must specify the entire output shape.
1082 | //
1083 | // axis may be non-zero to retain some portion of the beginning of the input
1084 | // shape (and may be negative to index from the end; e.g., -1 to begin the
1085 | // reshape after the last axis, including nothing in the reshape,
1086 | // -2 to include only the last axis, etc.).
1087 | //
1088 | // For example, suppose "input" is a 2D blob with shape 2 x 8.
1089 | // Then the following ReshapeLayer specifications are all equivalent,
1090 | // producing a blob "output" with shape 2 x 2 x 4:
1091 | //
1092 | // reshape_param { shape { dim: 2 dim: 2 dim: 4 } }
1093 | // reshape_param { shape { dim: 2 dim: 4 } axis: 1 }
1094 | // reshape_param { shape { dim: 2 dim: 4 } axis: -3 }
1095 | //
1096 | // num_axes specifies the extent of the reshape.
1097 | // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on
1098 | // input axes in the range [axis, axis+num_axes].
1099 | // num_axes may also be -1, the default, to include all remaining axes
1100 | // (starting from axis).
1101 | //
1102 | // For example, suppose "input" is a 2D blob with shape 2 x 8.
1103 | // Then the following ReshapeLayer specifications are equivalent,
1104 | // producing a blob "output" with shape 1 x 2 x 8:
1105 | //
1106 | // reshape_param { shape { dim: 1 dim: 2 dim: 8 } }
1107 | // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 }
1108 | // reshape_param { shape { dim: 1 } num_axes: 0 }
1109 | //
1110 | // On the other hand, these would produce output blob shape 2 x 1 x 8:
1111 | //
1112 | // reshape_param { shape { dim: 2 dim: 1 dim: 8 } }
1113 | // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 }
1114 | //
1115 | optional int32 axis = 2 [default = 0];
1116 | optional int32 num_axes = 3 [default = -1];
1117 | }
1118 |
1119 | message ScaleParameter {
1120 | // The first axis of bottom[0] (the first input Blob) along which to apply
1121 | // bottom[1] (the second input Blob). May be negative to index from the end
1122 | // (e.g., -1 for the last axis).
1123 | //
1124 | // For example, if bottom[0] is 4D with shape 100x3x40x60, the output
1125 | // top[0] will have the same shape, and bottom[1] may have any of the
1126 | // following shapes (for the given value of axis):
1127 | // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60
1128 | // (axis == 1 == -3) 3; 3x40; 3x40x60
1129 | // (axis == 2 == -2) 40; 40x60
1130 | // (axis == 3 == -1) 60
1131 | // Furthermore, bottom[1] may have the empty shape (regardless of the value of
1132 | // "axis") -- a scalar multiplier.
1133 | optional int32 axis = 1 [default = 1];
1134 |
1135 | // (num_axes is ignored unless just one bottom is given and the scale is
1136 | // a learned parameter of the layer. Otherwise, num_axes is determined by the
1137 | // number of axes of the second bottom.)
1138 | // The number of axes of the input (bottom[0]) covered by the scale
1139 | // parameter, or -1 to cover all axes of bottom[0] starting from `axis`.
1140 | // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar.
1141 | optional int32 num_axes = 2 [default = 1];
1142 |
1143 | // (filler is ignored unless just one bottom is given and the scale is
1144 | // a learned parameter of the layer.)
1145 | // The initialization for the learned scale parameter.
1146 | // Default is the unit (1) initialization, resulting in the ScaleLayer
1147 | // initially performing the identity operation.
1148 | optional FillerParameter filler = 3;
1149 |
1150 | // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but
1151 | // may be more efficient). Initialized with bias_filler (defaults to 0).
1152 | optional bool bias_term = 4 [default = false];
1153 | optional FillerParameter bias_filler = 5;
1154 | }
1155 |
1156 | message SegAccuracyParameter {
1157 | enum AccuracyMetric {
1158 | PixelAccuracy = 0;
1159 | ClassAccuracy = 1;
1160 | PixelIOU = 2;
1161 | }
1162 | optional AccuracyMetric metric = 1 [default = PixelAccuracy];
1163 | // will ignore pixels with this value when computing accuracy
1164 | repeated int32 ignore_label = 2;
1165 | optional bool reset = 3 [default = true];
1166 | }
1167 |
1168 | message SigmoidParameter {
1169 | enum Engine {
1170 | DEFAULT = 0;
1171 | CAFFE = 1;
1172 | CUDNN = 2;
1173 | }
1174 | optional Engine engine = 1 [default = DEFAULT];
1175 | }
1176 |
1177 | message SliceParameter {
1178 | // The axis along which to slice -- may be negative to index from the end
1179 | // (e.g., -1 for the last axis).
1180 | // By default, SliceLayer slices blobs along the "channels" axis (1).
1181 | optional int32 axis = 3 [default = 1];
1182 | repeated uint32 slice_point = 2;
1183 |
1184 | // DEPRECATED: alias for "axis" -- does not support negative indexing.
1185 | optional uint32 slice_dim = 1 [default = 1];
1186 | }
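// A worked example (shapes illustrative): slicing a 2 x 10 x H x W bottom
// along axis 1 with slice_point: 4 produces two tops of shape
// 2 x 4 x H x W and 2 x 6 x H x W; N slice points always yield N + 1 tops.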
1187 |
1188 | // Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer
1189 | message SoftmaxParameter {
1190 | enum Engine {
1191 | DEFAULT = 0;
1192 | CAFFE = 1;
1193 | CUDNN = 2;
1194 | }
1195 | optional Engine engine = 1 [default = DEFAULT];
1196 |
1197 | // The axis along which to perform the softmax -- may be negative to index
1198 | // from the end (e.g., -1 for the last axis).
1199 | // Any other axes will be evaluated as independent softmaxes.
1200 | optional int32 axis = 2 [default = 1];
1201 | }
1202 |
1203 | message TanHParameter {
1204 | enum Engine {
1205 | DEFAULT = 0;
1206 | CAFFE = 1;
1207 | CUDNN = 2;
1208 | }
1209 | optional Engine engine = 1 [default = DEFAULT];
1210 | }
1211 |
1212 | // Message that stores parameters used by TileLayer
1213 | message TileParameter {
1214 | // The index of the axis to tile.
1215 | optional int32 axis = 1 [default = 1];
1216 |
1217 | // The number of copies (tiles) of the blob to output.
1218 | optional int32 tiles = 2;
1219 | }
1220 |
1221 | // Message that stores parameters used by ThresholdLayer
1222 | message ThresholdParameter {
1223 | optional float threshold = 1 [default = 0]; // Strictly positive values
1224 | }
1225 |
1226 | message UniqueLabelParameter {
1227 | required int32 max_labels = 1;
1228 | repeated int32 ignore_label = 2;
1229 | repeated float force_label = 3;
1230 | }
1231 |
1232 | message WindowDataParameter {
1233 | // Specify the data source.
1234 | optional string source = 1;
1235 | // For data pre-processing, we can do simple scaling and subtracting the
1236 | // data mean, if provided. Note that the mean subtraction is always carried
1237 | // out before scaling.
1238 | optional float scale = 2 [default = 1];
1239 | optional string mean_file = 3;
1240 | // Specify the batch size.
1241 | optional uint32 batch_size = 4;
1242 | // Specify if we would like to randomly crop an image.
1243 | optional uint32 crop_size = 5 [default = 0];
1244 | // Specify if we want to randomly mirror data.
1245 | optional bool mirror = 6 [default = false];
1246 | // Foreground (object) overlap threshold
1247 | optional float fg_threshold = 7 [default = 0.5];
1248 | // Background (non-object) overlap threshold
1249 | optional float bg_threshold = 8 [default = 0.5];
1250 | // Fraction of batch that should be foreground objects
1251 | optional float fg_fraction = 9 [default = 0.25];
1252 | // Amount of contextual padding to add around a window
1253 | // (used only by the window_data_layer)
1254 | optional uint32 context_pad = 10 [default = 0];
1255 | // Mode for cropping out a detection window
1256 | // warp: cropped window is warped to a fixed size and aspect ratio
1257 | // square: the tightest square around the window is cropped
1258 | optional string crop_mode = 11 [default = "warp"];
1259 | // cache_images: will load all images in memory for faster access
1260 | optional bool cache_images = 12 [default = false];
1261 | // append root_folder to locate images
1262 | optional string root_folder = 13 [default = ""];
1263 | }
1264 |
1265 | message SPPParameter {
1266 | enum PoolMethod {
1267 | MAX = 0;
1268 | AVE = 1;
1269 | STOCHASTIC = 2;
1270 | }
1271 | optional uint32 pyramid_height = 1;
1272 | optional PoolMethod pool = 2 [default = MAX]; // The pooling method
1273 | enum Engine {
1274 | DEFAULT = 0;
1275 | CAFFE = 1;
1276 | CUDNN = 2;
1277 | }
1278 | optional Engine engine = 6 [default = DEFAULT];
1279 | }
1280 |
1281 | // DEPRECATED: use LayerParameter.
1282 | message V1LayerParameter {
1283 | repeated string bottom = 2;
1284 | repeated string top = 3;
1285 | optional string name = 4;
1286 | repeated NetStateRule include = 32;
1287 | repeated NetStateRule exclude = 33;
1288 | enum LayerType {
1289 | NONE = 0;
1290 | ABSVAL = 35;
1291 | ACCURACY = 1;
1292 | ARGMAX = 30;
1293 | BNLL = 2;
1294 | CONCAT = 3;
1295 | CONTRASTIVE_LOSS = 37;
1296 | CONVOLUTION = 4;
1297 | DATA = 5;
1298 | DECONVOLUTION = 39;
1299 | DROPOUT = 6;
1300 | DUMMY_DATA = 32;
1301 | EUCLIDEAN_LOSS = 7;
1302 | ELTWISE = 25;
1303 | EXP = 38;
1304 | FLATTEN = 8;
1305 | HDF5_DATA = 9;
1306 | HDF5_OUTPUT = 10;
1307 | HINGE_LOSS = 28;
1308 | IM2COL = 11;
1309 | IMAGE_DATA = 12;
1310 | INFOGAIN_LOSS = 13;
1311 | INNER_PRODUCT = 14;
1312 | LRN = 15;
1313 | MEMORY_DATA = 29;
1314 | MULTINOMIAL_LOGISTIC_LOSS = 16;
1315 | MVN = 34;
1316 | POOLING = 17;
1317 | POWER = 26;
1318 | RELU = 18;
1319 | SIGMOID = 19;
1320 | SIGMOID_CROSS_ENTROPY_LOSS = 27;
1321 | SILENCE = 36;
1322 | SOFTMAX = 20;
1323 | SOFTMAX_LOSS = 21;
1324 | SPLIT = 22;
1325 | SLICE = 33;
1326 | TANH = 23;
1327 | WINDOW_DATA = 24;
1328 | THRESHOLD = 31;
1329 | }
1330 | optional LayerType type = 5;
1331 | repeated BlobProto blobs = 6;
1332 | repeated string param = 1001;
1333 | repeated DimCheckMode blob_share_mode = 1002;
1334 | enum DimCheckMode {
1335 | STRICT = 0;
1336 | PERMISSIVE = 1;
1337 | }
1338 | repeated float blobs_lr = 7;
1339 | repeated float weight_decay = 8;
1340 | repeated float loss_weight = 35;
1341 | optional AccuracyParameter accuracy_param = 27;
1342 | optional ArgMaxParameter argmax_param = 23;
1343 | optional ConcatParameter concat_param = 9;
1344 | optional ContrastiveLossParameter contrastive_loss_param = 40;
1345 | optional ConvolutionParameter convolution_param = 10;
1346 | optional DataParameter data_param = 11;
1347 | optional DropoutParameter dropout_param = 12;
1348 | optional DummyDataParameter dummy_data_param = 26;
1349 | optional EltwiseParameter eltwise_param = 24;
1350 | optional ExpParameter exp_param = 41;
1351 | optional HDF5DataParameter hdf5_data_param = 13;
1352 | optional HDF5OutputParameter hdf5_output_param = 14;
1353 | optional HingeLossParameter hinge_loss_param = 29;
1354 | optional ImageDataParameter image_data_param = 15;
1355 | optional InfogainLossParameter infogain_loss_param = 16;
1356 | optional InnerProductParameter inner_product_param = 17;
1357 | optional LRNParameter lrn_param = 18;
1358 | optional MemoryDataParameter memory_data_param = 22;
1359 | optional MVNParameter mvn_param = 34;
1360 | optional PoolingParameter pooling_param = 19;
1361 | optional PowerParameter power_param = 21;
1362 | optional ReLUParameter relu_param = 30;
1363 | optional SigmoidParameter sigmoid_param = 38;
1364 | optional SoftmaxParameter softmax_param = 39;
1365 | optional SliceParameter slice_param = 31;
1366 | optional TanHParameter tanh_param = 37;
1367 | optional ThresholdParameter threshold_param = 25;
1368 | optional WindowDataParameter window_data_param = 20;
1369 | optional TransformationParameter transform_param = 36;
1370 | optional LossParameter loss_param = 42;
1371 | optional V0LayerParameter layer = 1;
1372 | }
1373 |
1374 | // DEPRECATED: V0LayerParameter is the old way of specifying layer parameters
1375 | // in Caffe. We keep this message type around for legacy support.
1376 | message V0LayerParameter {
1377 | optional string name = 1; // the layer name
1378 | optional string type = 2; // the string to specify the layer type
1379 |
1380 | // Parameters to specify layers with inner products.
1381 | optional uint32 num_output = 3; // The number of outputs for the layer
1382 | optional bool biasterm = 4 [default = true]; // whether to have bias terms
1383 | optional FillerParameter weight_filler = 5; // The filler for the weight
1384 | optional FillerParameter bias_filler = 6; // The filler for the bias
1385 |
1386 | optional uint32 pad = 7 [default = 0]; // The padding size
1387 | optional uint32 kernelsize = 8; // The kernel size
1388 | optional uint32 group = 9 [default = 1]; // The group size for group conv
1389 | optional uint32 stride = 10 [default = 1]; // The stride
1390 | enum PoolMethod {
1391 | MAX = 0;
1392 | AVE = 1;
1393 | STOCHASTIC = 2;
1394 | }
1395 | optional PoolMethod pool = 11 [default = MAX]; // The pooling method
1396 | optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio
1397 |
1398 | optional uint32 local_size = 13 [default = 5]; // for local response norm
1399 | optional float alpha = 14 [default = 1.]; // for local response norm
1400 | optional float beta = 15 [default = 0.75]; // for local response norm
1401 | optional float k = 22 [default = 1.];
1402 |
1403 | // For data layers, specify the data source
1404 | optional string source = 16;
1405 | // For data pre-processing, we can do simple scaling and subtracting the
1406 | // data mean, if provided. Note that the mean subtraction is always carried
1407 | // out before scaling.
1408 | optional float scale = 17 [default = 1];
1409 | optional string meanfile = 18;
1410 | // For data layers, specify the batch size.
1411 | optional uint32 batchsize = 19;
1412 | // For data layers, specify if we would like to randomly crop an image.
1413 | optional uint32 cropsize = 20 [default = 0];
1414 | // For data layers, specify if we want to randomly mirror data.
1415 | optional bool mirror = 21 [default = false];
1416 |
1417 | // The blobs containing the numeric parameters of the layer
1418 | repeated BlobProto blobs = 50;
1419 | // The ratio that is multiplied on the global learning rate. If you want to
1420 | // set the learning ratio for one blob, you need to set it for all blobs.
1421 | repeated float blobs_lr = 51;
1422 | // The weight decay that is multiplied on the global weight decay.
1423 | repeated float weight_decay = 52;
1424 |
1425 | // The rand_skip variable is for the data layer to skip a few data points
1426 | // to prevent all asynchronous sgd clients from starting at the same point.
1427 | // The skip point would be set as rand_skip * rand(0,1). Note that rand_skip
1428 | // should not be larger than the number of keys in the database.
1429 | optional uint32 rand_skip = 53 [default = 0];
1430 |
1431 | // Fields related to detection (det_*)
1432 | // foreground (object) overlap threshold
1433 | optional float det_fg_threshold = 54 [default = 0.5];
1434 | // background (non-object) overlap threshold
1435 | optional float det_bg_threshold = 55 [default = 0.5];
1436 | // Fraction of batch that should be foreground objects
1437 | optional float det_fg_fraction = 56 [default = 0.25];
1438 |
1439 | // optional bool OBSOLETE_can_clobber = 57 [default = true];
1440 |
1441 | // Amount of contextual padding to add around a window
1442 | // (used only by the window_data_layer)
1443 | optional uint32 det_context_pad = 58 [default = 0];
1444 |
1445 | // Mode for cropping out a detection window
1446 | // warp: cropped window is warped to a fixed size and aspect ratio
1447 | // square: the tightest square around the window is cropped
1448 | optional string det_crop_mode = 59 [default = "warp"];
1449 |
1450 | // For ReshapeLayer, one needs to specify the new dimensions.
1451 | optional int32 new_num = 60 [default = 0];
1452 | optional int32 new_channels = 61 [default = 0];
1453 | optional int32 new_height = 62 [default = 0];
1454 | optional int32 new_width = 63 [default = 0];
1455 |
1456 | // Whether or not ImageLayer should shuffle the list of files at every epoch.
1457 | // It will also resize images if new_height or new_width are not zero.
1458 | optional bool shuffle_images = 64 [default = false];
1459 |
1460 | // For ConcatLayer, one needs to specify the dimension for concatenation, and
1461 | // the other dimensions must be the same for all the bottom blobs.
1462 | // By default it will concatenate blobs along the channels dimension.
1463 | optional uint32 concat_dim = 65 [default = 1];
1464 |
1465 | optional HDF5OutputParameter hdf5_output_param = 1001;
1466 | }
1467 |
1468 | message PReLUParameter {
1469 | // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers:
1470 | // Surpassing Human-Level Performance on ImageNet Classification, 2015.
1471 |
1472 | // Initial value of a_i. Default is a_i=0.25 for all i.
1473 | optional FillerParameter filler = 1;
1474 | // Whether or not slope parameters are shared across channels.
1475 | optional bool channel_shared = 2 [default = false];
1476 | }
1477 |
--------------------------------------------------------------------------------