├── libs ├── __init__.py ├── datasets │ ├── __init__.py │ └── voc.py ├── models │ ├── __init__.py │ ├── resnet.py │ └── pspnet.py ├── utils │ ├── __init__.py │ ├── crf.py │ └── metric.py └── caffe.proto ├── data ├── models │ └── .gitignore └── datasets │ ├── cityscapes │ └── labels.txt │ ├── voc12 │ └── labels.txt │ └── ade20k │ └── labels.txt ├── docs ├── demo.png ├── voc12_ms.json └── voc12_ss.json ├── config ├── ade20k.yaml ├── voc12_ss.yaml ├── cityscapes.yaml └── voc12_ms.yaml ├── LICENSE ├── .gitignore ├── README.md ├── draw_model.py ├── demo.py ├── convert.py └── eval.py /libs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/models/.gitignore: -------------------------------------------------------------------------------- 1 | *.caffemodel 2 | *.pth -------------------------------------------------------------------------------- /docs/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kazuto1011/pspnet-pytorch/HEAD/docs/demo.png -------------------------------------------------------------------------------- /libs/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .voc import * 3 | -------------------------------------------------------------------------------- /libs/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .resnet import * 3 | from .pspnet import * 4 | -------------------------------------------------------------------------------- /libs/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .crf import * 3 | from .metric import * 4 | -------------------------------------------------------------------------------- /data/datasets/cityscapes/labels.txt: -------------------------------------------------------------------------------- 1 | 255 unlabeled 2 | 0 Road 3 | 1 Sidewalk 4 | 2 Building 5 | 3 Wall 6 | 4 Fence 7 | 5 Pole 8 | 6 Traffic light 9 | 7 Traffic sign 10 | 8 Vegetation 11 | 9 Terrain 12 | 10 Sky 13 | 11 Person 14 | 12 Rider 15 | 13 Car 16 | 14 Truck 17 | 15 Bus 18 | 16 Train 19 | 17 Motorcycle 20 | 18 Bicycle -------------------------------------------------------------------------------- /data/datasets/voc12/labels.txt: -------------------------------------------------------------------------------- 1 | 0 __background__ 2 | 1 aeroplane 3 | 2 bicycle 4 | 3 bird 5 | 4 boat 6 | 5 bottle 7 | 6 bus 8 | 7 car 9 | 8 cat 10 | 9 chair 11 | 10 cow 12 | 11 diningtable 13 | 12 dog 14 | 13 horse 15 | 14 motorbike 16 | 15 person 17 | 16 pottedplant 18 | 17 sheep 19 | 18 sofa 20 | 19 train 21 | 20 tvmonitor -------------------------------------------------------------------------------- /config/ade20k.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 'ade20k' 2 | DATASET_ROOT: 3 | CAFFE_MODEL: 'data/models/pspnet50_ADE20K.caffemodel' 4 | PYTORCH_MODEL: 'data/models/pspnet50_ADE20K.pth' 5 | LABELS: 'data/datasets/ade20k/labels.txt' 6 | N_CLASSES: 150 7 | N_BLOCKS: [3, 4, 6, 3] 8 | PYRAMIDS: [6, 3, 2, 1] 9 | IMAGE: 10 | SIZE: 11 | BASE: 512 12 | TRAIN: 473 13 | TEST: 473 14 | MEAN: 15 | R: 123.68 16 | G: 116.779 17 | B: 
103.939 18 | SCALES: [0.5, 0.75, 1, 1.25, 1.5, 1.75] 19 | NUM_WORKERS: 1 -------------------------------------------------------------------------------- /config/voc12_ss.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 'voc12' 2 | DATASET_ROOT: '/media/kazuto1011/Extra/VOCdevkit' 3 | CAFFE_MODEL: 'data/models/pspnet101_VOC2012.caffemodel' 4 | PYTORCH_MODEL: 'data/models/pspnet101_VOC2012.pth' 5 | LABELS: 'data/datasets/voc12/labels.txt' 6 | N_CLASSES: 21 7 | N_BLOCKS: [3, 4, 23, 3] 8 | PYRAMIDS: [6, 3, 2, 1] 9 | IMAGE: 10 | SIZE: 11 | BASE: 512 12 | TRAIN: 473 13 | TEST: 473 14 | MEAN: 15 | R: 123.68 16 | G: 116.779 17 | B: 103.939 18 | SCALES: [1,] 19 | NUM_WORKERS: 1 -------------------------------------------------------------------------------- /config/cityscapes.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 'cityscapes' 2 | DATASET_ROOT: 3 | CAFFE_MODEL: 'data/models/pspnet101_cityscapes.caffemodel' 4 | PYTORCH_MODEL: 'data/models/pspnet101_cityscapes.pth' 5 | LABELS: 'data/datasets/cityscapes/labels.txt' 6 | N_CLASSES: 19 7 | N_BLOCKS: [3, 4, 23, 3] 8 | PYRAMIDS: [6, 3, 2, 1] 9 | IMAGE: 10 | SIZE: 11 | BASE: 2048 12 | TRAIN: 713 13 | TEST: 713 14 | MEAN: 15 | R: 123.68 16 | G: 116.779 17 | B: 103.939 18 | SCALES: [0.5, 0.75, 1, 1.25, 1.5, 1.75] 19 | NUM_WORKERS: 1 -------------------------------------------------------------------------------- /config/voc12_ms.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 'voc12' 2 | DATASET_ROOT: '/media/kazuto1011/Extra/VOCdevkit' 3 | CAFFE_MODEL: 'data/models/pspnet101_VOC2012.caffemodel' 4 | PYTORCH_MODEL: 'data/models/pspnet101_VOC2012.pth' 5 | LABELS: 'data/datasets/voc12/labels.txt' 6 | N_CLASSES: 21 7 | N_BLOCKS: [3, 4, 23, 3] 8 | PYRAMIDS: [6, 3, 2, 1] 9 | IMAGE: 10 | SIZE: 11 | BASE: 512 12 | TRAIN: 473 13 | TEST: 473 14 | MEAN: 15 | R: 123.68 16 | G: 116.779 17 | B: 103.939 18 | SCALES: [0.5, 0.75, 1, 1.25, 1.5, 1.75] 19 | NUM_WORKERS: 1 -------------------------------------------------------------------------------- /libs/utils/crf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Kazuto Nakashima 5 | # URL: http://kazuto1011.github.io 6 | # Created: 2017-11-01 7 | 8 | import numpy as np 9 | import pydensecrf.densecrf as dcrf 10 | import pydensecrf.utils as utils 11 | 12 | MAX_ITER = 10 13 | POS_W = 3 14 | POS_XY_STD = 3 15 | Bi_W = 4 16 | Bi_XY_STD = 49 17 | Bi_RGB_STD = 5 18 | 19 | 20 | def dense_crf(img, probs): 21 | c = probs.shape[0] 22 | h = probs.shape[1] 23 | w = probs.shape[2] 24 | 25 | U = utils.unary_from_softmax(probs) 26 | U = np.ascontiguousarray(U) 27 | 28 | img = np.ascontiguousarray(img) 29 | 30 | d = dcrf.DenseCRF2D(w, h, c) 31 | d.setUnaryEnergy(U) 32 | d.addPairwiseGaussian(sxy=POS_XY_STD, compat=POS_W) 33 | d.addPairwiseBilateral(sxy=Bi_XY_STD, srgb=Bi_RGB_STD, rgbim=img, compat=Bi_W) 34 | 35 | Q = d.inference(MAX_ITER) 36 | Q = np.array(Q).reshape((c, h, w)) 37 | 38 | return Q 39 | -------------------------------------------------------------------------------- /docs/voc12_ms.json: -------------------------------------------------------------------------------- 1 | { 2 | "Class IoU": { 3 | "0": 0.9723835140393139, 4 | "1": 0.932284405801355, 5 | "2": 0.7468137710962681, 6 | "3": 0.9218783356840369, 7 | "4": 0.8484579004810219, 8 | "5": 0.8696297249201347, 9 | 
"6": 0.9720973207300103, 10 | "7": 0.935915049578212, 11 | "8": 0.9374628071322245, 12 | "9": 0.6875070250549213, 13 | "10": 0.9582069281376151, 14 | "11": 0.8308168870071049, 15 | "12": 0.9282477592335033, 16 | "13": 0.9304011384698115, 17 | "14": 0.9140842413972792, 18 | "15": 0.9055205679928403, 19 | "16": 0.8116604297212983, 20 | "17": 0.9352167645389193, 21 | "18": 0.8234079050857073, 22 | "19": 0.9300925249646486, 23 | "20": 0.8099681859474088 24 | }, 25 | "FreqW Acc": 0.9531461900878117, 26 | "Mean Acc": 0.9279972291530902, 27 | "Mean IoU": 0.8858120565244588, 28 | "Overall Acc": 0.9756045013064952 29 | } -------------------------------------------------------------------------------- /docs/voc12_ss.json: -------------------------------------------------------------------------------- 1 | { 2 | "Class IoU": { 3 | "0": 0.9720507018787903, 4 | "1": 0.9296281094338285, 5 | "2": 0.7234742052854309, 6 | "3": 0.9225243177338182, 7 | "4": 0.8318045051638611, 8 | "5": 0.8782490336084968, 9 | "6": 0.9715514433351778, 10 | "7": 0.9402895154307185, 11 | "8": 0.9372223191974904, 12 | "9": 0.6759331602494226, 13 | "10": 0.9526220286282768, 14 | "11": 0.8278434084687414, 15 | "12": 0.9248270465951416, 16 | "13": 0.9248037370339303, 17 | "14": 0.9131082290202047, 18 | "15": 0.9030004679899833, 19 | "16": 0.7972657605334218, 20 | "17": 0.9282957360023752, 21 | "18": 0.8207954040092147, 22 | "19": 0.934780815109344, 23 | "20": 0.7932057422450404 24 | }, 25 | "FreqW Acc": 0.9521466591274718, 26 | "Mean Acc": 0.9302564874727208, 27 | "Mean IoU": 0.8811083660453672, 28 | "Overall Acc": 0.9749297938026947 29 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Kazuto Nakashima 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /libs/utils/metric.py: -------------------------------------------------------------------------------- 1 | # Originally written by wkentaro 2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py 3 | 4 | import numpy as np 5 | 6 | 7 | def _fast_hist(label_true, label_pred, n_class): 8 | mask = (label_true >= 0) & (label_true < n_class) 9 | hist = np.bincount( 10 | n_class * label_true[mask].astype(int) + label_pred[mask], 11 | minlength=n_class ** 2, 12 | ).reshape(n_class, n_class) 13 | return hist 14 | 15 | 16 | def scores(label_trues, label_preds, n_class): 17 | hist = np.zeros((n_class, n_class)) 18 | for lt, lp in zip(label_trues, label_preds): 19 | hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) 20 | acc = np.diag(hist).sum() / hist.sum() 21 | acc_cls = np.diag(hist) / hist.sum(axis=1) 22 | acc_cls = np.nanmean(acc_cls) 23 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 24 | mean_iu = np.nanmean(iu) 25 | freq = hist.sum(axis=1) / hist.sum() 26 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 27 | cls_iu = dict(zip(range(n_class), iu)) 28 | 29 | return ( 30 | { 31 | "Overall Acc": acc, 32 | "Mean Acc": acc_cls, 33 | "FreqW Acc": fwavacc, 34 | "Mean IoU": mean_iu, 35 | }, 36 | cls_iu, 37 | ) 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | .vscode/ 104 | .style.yapf 105 | results.json 106 | model.pdf -------------------------------------------------------------------------------- /libs/datasets/voc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import os.path as osp 5 | import sys 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import torch.utils.data as data 11 | from PIL import Image 12 | 13 | 14 | class VOCSegmentation(data.Dataset): 15 | def __init__( 16 | self, 17 | root, 18 | image_set, 19 | transform=None, 20 | target_transform=None, 21 | dataset_name="VOC2007", 22 | ): 23 | self.root = root 24 | self.image_set = image_set 25 | self.transform = transform 26 | self.target_transform = target_transform 27 | self.mean_rgb = np.array([123.68, 116.779, 103.939]) 28 | 29 | self._annopath = osp.join( 30 | self.root, dataset_name, "SegmentationClass", "%s.png" 31 | ) 32 | self._imgpath = osp.join(self.root, dataset_name, "JPEGImages", "%s.jpg") 33 | self._imgsetpath = osp.join( 34 | self.root, dataset_name, "ImageSets", "Segmentation", "%s.txt" 35 | ) 36 | 37 | with open(self._imgsetpath % self.image_set) as f: 38 | self.ids = f.readlines() 39 | self.ids = [x.strip("\n") for x in self.ids] 40 | 41 | def __getitem__(self, index): 42 | img_id = self.ids[index] 43 | 44 | target = Image.open(self._annopath % img_id) 45 | img = Image.open(self._imgpath % img_id).convert("RGB") 46 | 47 | if self.transform is not None: 48 | img = self.transform(img) 49 | 50 | if self.target_transform is not None: 51 | target = self.target_transform(target) 52 | 53 | img = (np.array(img) - self.mean_rgb).transpose(2, 0, 1) 54 | img = torch.from_numpy(img.astype(np.float32)) 55 | target = np.array(target, dtype=np.int32) 56 | target[target == 255] = -1 57 | target = torch.from_numpy(target.astype(np.int64)) 58 | 59 | return img, target 60 | 61 | def __len__(self): 62 | return len(self.ids) 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PSPNet with PyTorch 2 | 3 | Unofficial implementation of "Pyramid Scene Parsing Network" (https://arxiv.org/abs/1612.01105). This repository is just for caffe to pytorch model conversion and evaluation. 
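At a glance, the intended workflow is: convert a released caffemodel to a ```.pth``` file, then run the demo or the evaluation. A minimal sketch (the commands are documented in the sections below; ```config/ade20k.yaml``` is one of the shipped configs, and ```<path to image>``` is a placeholder):

```sh
# 1. Convert a released caffemodel to a PyTorch .pth file
python convert.py -c config/ade20k.yaml
# 2. Run single-image inference with the converted model
python demo.py -c config/ade20k.yaml -i <path to image>
# 3. Evaluate on PASCAL VOC 2012 (single-scale setting)
python eval.py -c config/voc12_ss.yaml
```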
4 | 5 | ### Requirements 6 | 7 | * pytorch 8 | * click 9 | * addict 10 | * pydensecrf 11 | * protobuf 12 | 13 | ## Preparation 14 | Instead of building the author's Caffe implementation, you can convert the off-the-shelf caffemodels to PyTorch models via the ```caffe.proto```. 15 | 16 | ### 1. Compile the ```caffe.proto``` for the Python API 17 | This step can be skipped; it is described here for reference.
18 | Download [the author's ```caffe.proto```](https://github.com/hszhao/PSPNet/blob/master/src/caffe/proto/caffe.proto) into the ```libs``` directory; do not use the one from the original Caffe. 19 | ```sh 20 | # For the protoc command 21 | pip install protobuf 22 | # This generates ./caffe_pb2.py 23 | protoc --python_out=. caffe.proto 24 | ``` 25 | 26 | ### 2. Model conversion 27 | 28 | 1. Find the caffemodels on [the author's page](https://github.com/hszhao/PSPNet#usage) (e.g. pspnet50_ADE20K.caffemodel) and store them in the ```data/models/``` directory. 29 | 2. Convert the caffemodels to ```.pth``` files. 30 | 31 | ```sh 32 | python convert.py -c <path to config.yaml> 33 | ``` 34 | 35 | ## Demo 36 | 37 | ```sh 38 | python demo.py -c <path to config.yaml> -i <path to image> 39 | ``` 40 | * With the ```--no-cuda``` option, this runs on the CPU. 41 | * With the ```--crf``` option, you can apply CRF post-processing. 42 | 43 | ![demo](docs/demo.png) 44 | 45 | ## Evaluation 46 | 47 | PASCAL VOC2012 only. Please set the dataset path in ```config/voc12_ss.yaml``` and ```config/voc12_ms.yaml```. 48 | 49 | ```sh 50 | python eval.py -c config/voc12_ss.yaml 51 | ``` 52 | 53 | 88.1% mIoU (SS) and 88.6% mIoU (MS) on the validation set (per-class scores are in ```docs/voc12_ss.json``` and ```docs/voc12_ms.json```).
54 | *NOTE: about 3 points lower than the Caffe implementation. WIP* 55 | 56 | * SS: averaged prediction with flipping (2x) 57 | * MS: averaged prediction with multi-scaling (6x) and flipping (2x) 58 | * Both: no CRF post-processing 59 | 60 | ## References 61 | 62 | * Official implementation: https://github.com/hszhao/PSPNet 63 | * Chainer implementation: https://github.com/mitmul/chainer-pspnet -------------------------------------------------------------------------------- /draw_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Kazuto Nakashima 5 | # URL: http://kazuto1011.github.io 6 | # Created: 2017-11-06 7 | 8 | import torch 9 | from graphviz import Digraph 10 | from torch.autograd import Variable 11 | 12 | from libs.models import * 13 | 14 | 15 | def make_dot(var, params): 16 | 17 | param_map = {id(v): k for k, v in params.items()} 18 | 19 | node_attr = dict( 20 | style="filled", 21 | shape="box", 22 | align="left", 23 | fontsize="12", 24 | ranksep="0.1", 25 | height="0.2", 26 | ) 27 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12")) 28 | seen = set() 29 | 30 | def size_to_str(size): 31 | return "(" + (", ").join(["%d" % v for v in size]) + ")" 32 | 33 | def add_nodes(var): 34 | if var not in seen: 35 | if torch.is_tensor(var): 36 | dot.node(str(id(var)), size_to_str(var.size()), fillcolor="orange") 37 | elif hasattr(var, "variable"): 38 | u = var.variable 39 | dot.node(str(id(var)), size_to_str(u.size()), fillcolor="lightblue") 40 | else: 41 | dot.node(str(id(var)), str(type(var).__name__.replace("Backward", ""))) 42 | seen.add(var) 43 | if hasattr(var, "next_functions"): 44 | for u in var.next_functions: 45 | if u[0] is not None: 46 | dot.edge(str(id(u[0])), str(id(var))) 47 | add_nodes(u[0]) 48 | if hasattr(var, "saved_tensors"): 49 | for t in var.saved_tensors: 50 | dot.edge(str(id(t)), str(id(var))) 51 | add_nodes(t) 52 | 53 | add_nodes(var.grad_fn) 54 | 55 | return dot 56 | 57 | 58 | if __name__ == "__main__": 59 | # Define a model 60 | model = PSPNet(n_classes=6, n_blocks=[3, 4, 6, 3], pyramids=[6, 3, 2, 1]) 61 | 62 | # Build a computational graph from x to y 63 | x = torch.randn(2, 3, 512, 512) 64 | y1, y2 = model(Variable(x)) 65 | g = make_dot(y1 + y2, model.state_dict()) 66 | g.view(filename="model", cleanup=True) 67 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Kazuto Nakashima 5 | # URL: http://kazuto1011.github.io 6 | # Created: 2017-11-15 7 | 8 | import click 9 | import cv2 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import yaml 16 | from addict import Dict 17 | from torch.autograd import Variable 18 | 19 | from libs.models import PSPNet 20 | from libs.utils import dense_crf 21 | 22 | 23 | @click.command() 24 | @click.option("--config", "-c", required=True) 25 | @click.option("--image-path", "-i", required=True) 26 | @click.option("--cuda/--no-cuda", default=True) 27 | @click.option("--crf", is_flag=True) 28 | def main(config, image_path, cuda, crf): 29 | CONFIG = Dict(yaml.load(open(config))) 30 | 31 | cuda = cuda and torch.cuda.is_available() 32 | 33 | # Label list 34 | with open(CONFIG.LABELS) as f: 35 | classes = {} 36 | for label in f: 37 | label = 
label.rstrip().split("\t") 38 | classes[int(label[0])] = label[1].split(",")[0] 39 | 40 | # Load a model 41 | state_dict = torch.load(CONFIG.PYTORCH_MODEL) 42 | 43 | # Model 44 | model = PSPNet( 45 | n_classes=CONFIG.N_CLASSES, n_blocks=CONFIG.N_BLOCKS, pyramids=CONFIG.PYRAMIDS 46 | ) 47 | model.load_state_dict(state_dict) 48 | model.eval() 49 | if cuda: 50 | model.cuda() 51 | 52 | image_size = (CONFIG.IMAGE.SIZE.TEST,) * 2 53 | 54 | # Image preprocessing 55 | image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(float) 56 | image = cv2.resize(image, image_size) 57 | image_original = image.astype(np.uint8) 58 | image = image[..., ::-1] - np.array( 59 | [CONFIG.IMAGE.MEAN.R, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.B] 60 | ) 61 | image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0) 62 | image = image.cuda() if cuda else image 63 | 64 | # Inference 65 | output = model(Variable(image, volatile=True)) 66 | 67 | output = F.upsample(output, size=image_size, mode="bilinear") 68 | output = F.softmax(output, dim=1) 69 | output = output[0].cpu().data.numpy() 70 | 71 | if crf: 72 | output = dense_crf(image_original, output) 73 | labelmap = np.argmax(output.transpose(1, 2, 0), axis=2) 74 | 75 | labels = np.unique(labelmap) 76 | 77 | rows = np.floor(np.sqrt(len(labels) + 1)) 78 | cols = np.ceil((len(labels) + 1) / rows) 79 | 80 | plt.figure(figsize=(10, 10)) 81 | ax = plt.subplot(rows, cols, 1) 82 | ax.set_title("Input image") 83 | ax.imshow(image_original[:, :, ::-1]) 84 | ax.set_xticks([]) 85 | ax.set_yticks([]) 86 | 87 | for i, label in enumerate(labels): 88 | print("{0:3d}: {1}".format(label, classes[label])) 89 | mask = labelmap == label 90 | ax = plt.subplot(rows, cols, i + 2) 91 | ax.set_title(classes[label]) 92 | ax.imshow(image_original[:, :, ::-1]) 93 | ax.imshow(mask.astype(np.float32), alpha=0.5, cmap="viridis") 94 | ax.set_xticks([]) 95 | ax.set_yticks([]) 96 | 97 | plt.tight_layout() 98 | plt.show() 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /libs/models/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Kazuto Nakashima 5 | # URL: http://kazuto1011.github.io 6 | # Created: 2017-11-19 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.model_zoo as model_zoo 11 | import torch.nn.functional as F 12 | from collections import OrderedDict 13 | 14 | 15 | class _ConvBatchNormReLU(nn.Sequential): 16 | """Convolution Unit""" 17 | 18 | def __init__( 19 | self, 20 | in_channels, 21 | out_channels, 22 | kernel_size, 23 | stride, 24 | padding, 25 | dilation, 26 | relu=True, 27 | ): 28 | super(_ConvBatchNormReLU, self).__init__() 29 | self.add_module( 30 | "conv", 31 | nn.Conv2d( 32 | in_channels=in_channels, 33 | out_channels=out_channels, 34 | kernel_size=kernel_size, 35 | stride=stride, 36 | padding=padding, 37 | dilation=dilation, 38 | bias=False, 39 | ), 40 | ) 41 | self.add_module( 42 | "bn", nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.95, affine=True) 43 | ) 44 | if relu: 45 | self.add_module("relu", nn.ReLU()) 46 | 47 | def forward(self, x): 48 | return super(_ConvBatchNormReLU, self).forward(x) 49 | 50 | 51 | class _Bottleneck(nn.Module): 52 | """Bottleneck Unit""" 53 | 54 | def __init__( 55 | self, in_channels, mid_channels, out_channels, stride, dilation, downsample 56 | ): 57 | super(_Bottleneck, self).__init__() 58 | self.reduce = 
_ConvBatchNormReLU(in_channels, mid_channels, 1, 1, 0, 1) 59 | self.conv3x3 = _ConvBatchNormReLU( 60 | mid_channels, mid_channels, 3, stride, dilation, dilation 61 | ) 62 | self.increase = _ConvBatchNormReLU( 63 | mid_channels, out_channels, 1, 1, 0, 1, relu=False 64 | ) 65 | self.downsample = downsample 66 | if self.downsample: 67 | self.proj = _ConvBatchNormReLU( 68 | in_channels, out_channels, 1, stride, 0, 1, relu=False 69 | ) 70 | 71 | def forward(self, x): 72 | h = self.reduce(x) 73 | h = self.conv3x3(h) 74 | h = self.increase(h) 75 | if self.downsample: 76 | h += self.proj(x) 77 | else: 78 | h += x 79 | return F.relu(h) 80 | 81 | 82 | class _ResBlock(nn.Sequential): 83 | """Residual Block""" 84 | 85 | def __init__( 86 | self, n_layers, in_channels, mid_channels, out_channels, stride, dilation 87 | ): 88 | super(_ResBlock, self).__init__() 89 | self.add_module( 90 | "block1", 91 | _Bottleneck( 92 | in_channels, mid_channels, out_channels, stride, dilation, True 93 | ), 94 | ) 95 | for i in range(2, n_layers + 1): 96 | self.add_module( 97 | "block" + str(i), 98 | _Bottleneck( 99 | out_channels, mid_channels, out_channels, 1, dilation, False 100 | ), 101 | ) 102 | 103 | def __call__(self, x): 104 | return super(_ResBlock, self).forward(x) 105 | -------------------------------------------------------------------------------- /libs/models/pspnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Kazuto Nakashima 5 | # URL: http://kazuto1011.github.io 6 | # Created: 2017-11-15 7 | 8 | from __future__ import absolute_import 9 | 10 | from collections import OrderedDict 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from .resnet import _ConvBatchNormReLU, _ResBlock 17 | 18 | 19 | class _DilatedFCN(nn.Module): 20 | """ResNet-based Dilated FCN""" 21 | 22 | def __init__(self, n_blocks): 23 | super(_DilatedFCN, self).__init__() 24 | self.layer1 = nn.Sequential( 25 | OrderedDict( 26 | [ 27 | ("conv1", _ConvBatchNormReLU(3, 64, 3, 2, 1, 1)), 28 | ("conv2", _ConvBatchNormReLU(64, 64, 3, 1, 1, 1)), 29 | ("conv3", _ConvBatchNormReLU(64, 128, 3, 1, 1, 1)), 30 | ("pool", nn.MaxPool2d(3, 2, 1)), 31 | ] 32 | ) 33 | ) 34 | self.layer2 = _ResBlock(n_blocks[0], 128, 64, 256, 1, 1) 35 | self.layer3 = _ResBlock(n_blocks[1], 256, 128, 512, 2, 1) 36 | self.layer4 = _ResBlock(n_blocks[2], 512, 256, 1024, 1, 2) 37 | self.layer5 = _ResBlock(n_blocks[3], 1024, 512, 2048, 1, 4) 38 | 39 | def forward(self, x): 40 | h = self.layer1(x) 41 | h = self.layer2(h) 42 | h = self.layer3(h) 43 | h1 = self.layer4(h) 44 | h2 = self.layer5(h1) 45 | if self.training: 46 | return h1, h2 47 | else: 48 | return h2 49 | 50 | 51 | class _PyramidPoolModule(nn.Sequential): 52 | """Pyramid Pooling Module""" 53 | 54 | def __init__(self, in_channels, pyramids=[6, 3, 2, 1]): 55 | super(_PyramidPoolModule, self).__init__() 56 | out_channels = in_channels // len(pyramids) 57 | self.stages = nn.Module() 58 | for i, p in enumerate(pyramids): 59 | self.stages.add_module( 60 | "s{}".format(i), 61 | nn.Sequential( 62 | OrderedDict( 63 | [ 64 | ("pool", nn.AdaptiveAvgPool2d(output_size=p)), 65 | ( 66 | "conv", 67 | _ConvBatchNormReLU( 68 | in_channels, out_channels, 1, 1, 0, 1 69 | ), 70 | ), 71 | ] 72 | ) 73 | ), 74 | ) 75 | 76 | def forward(self, x): 77 | hs = [x] 78 | height, width = x.size()[2:] 79 | for stage in self.stages.children(): 80 | h = stage(x) 81 | h = F.upsample(h, (height, width), 
mode="bilinear") 82 | hs.append(h) 83 | return torch.cat(hs, dim=1) 84 | 85 | 86 | class PSPNet(nn.Module): 87 | """Pyramid Scene Parsing Network""" 88 | 89 | def __init__(self, n_classes, n_blocks, pyramids): 90 | super(PSPNet, self).__init__() 91 | self.n_classes = n_classes 92 | self.fcn = _DilatedFCN(n_blocks=n_blocks) 93 | self.ppm = _PyramidPoolModule(in_channels=2048, pyramids=pyramids) 94 | # Main branch 95 | self.final = nn.Sequential( 96 | OrderedDict( 97 | [ 98 | ("conv5_4", _ConvBatchNormReLU(4096, 512, 3, 1, 1, 1)), 99 | ("drop5_4", nn.Dropout2d(p=0.1)), 100 | ("conv6", nn.Conv2d(512, n_classes, 1, stride=1, padding=0)), 101 | ] 102 | ) 103 | ) 104 | # Auxiliary branch 105 | self.aux = nn.Sequential( 106 | OrderedDict( 107 | [ 108 | ("conv4_aux", _ConvBatchNormReLU(1024, 256, 3, 1, 1, 1)), 109 | ("drop4_aux", nn.Dropout2d(p=0.1)), 110 | ("conv6_1", nn.Conv2d(256, n_classes, 1, stride=1, padding=0)), 111 | ] 112 | ) 113 | ) 114 | 115 | def forward(self, x): 116 | if self.training: 117 | aux, h = self.fcn(x) 118 | aux = self.aux(aux) 119 | else: 120 | h = self.fcn(x) 121 | h = self.ppm(h) 122 | h = self.final(h) 123 | 124 | if self.training: 125 | return aux, h 126 | else: 127 | return h 128 | 129 | 130 | if __name__ == "__main__": 131 | model = PSPNet(n_classes=150, n_blocks=[3, 4, 6, 3], pyramids=[6, 3, 2, 1]) 132 | print(list(model.named_children())) 133 | model.eval() 134 | image = torch.autograd.Variable(torch.randn(1, 3, 473, 473)) 135 | print(model(image).size()) 136 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Kazuto Nakashima 5 | # URL: http://kazuto1011.github.io 6 | # Created: 2017-11-15 7 | 8 | from __future__ import print_function 9 | 10 | import re 11 | from collections import OrderedDict 12 | 13 | import click 14 | import numpy as np 15 | import torch 16 | import yaml 17 | from addict import Dict 18 | 19 | from libs import caffe_pb2 20 | from libs.models import PSPNet 21 | 22 | 23 | def parse_caffemodel(model_path): 24 | caffemodel = caffe_pb2.NetParameter() 25 | with open(model_path, "rb") as f: 26 | caffemodel.MergeFromString(f.read()) 27 | 28 | # Check trainable layers 29 | print(set([(layer.type, len(layer.blobs)) for layer in caffemodel.layer])) 30 | 31 | params = OrderedDict() 32 | for layer in caffemodel.layer: 33 | print("{} ({}): {}".format(layer.name, layer.type, len(layer.blobs))) 34 | 35 | # Convolution or Dilated Convolution 36 | if "Convolution" in layer.type: 37 | params[layer.name] = {} 38 | params[layer.name]["kernel_size"] = layer.convolution_param.kernel_size[0] 39 | params[layer.name]["stride"] = layer.convolution_param.stride[0] 40 | params[layer.name]["weight"] = list(layer.blobs[0].data) 41 | if len(layer.blobs) == 2: 42 | params[layer.name]["bias"] = list(layer.blobs[1].data) 43 | if len(layer.convolution_param.pad) == 1: # or [] 44 | params[layer.name]["padding"] = layer.convolution_param.pad[0] 45 | else: 46 | params[layer.name]["padding"] = 0 47 | if isinstance(layer.convolution_param.dilation, int): # or [] 48 | params[layer.name]["dilation"] = layer.convolution_param.dilation 49 | else: 50 | params[layer.name]["dilation"] = 1 51 | 52 | # Batch Normalization 53 | elif "BN" in layer.type: 54 | params[layer.name] = {} 55 | params[layer.name]["weight"] = list(layer.blobs[0].data) 56 | params[layer.name]["bias"] = list(layer.blobs[1].data) 57 | 
params[layer.name]["running_mean"] = list(layer.blobs[2].data) 58 | params[layer.name]["running_var"] = list(layer.blobs[3].data) 59 | params[layer.name]["eps"] = layer.bn_param.eps 60 | params[layer.name]["momentum"] = layer.bn_param.momentum 61 | 62 | return params 63 | 64 | 65 | # Hard coded translater 66 | def translate_layer_name(source): 67 | def conv_or_bn(source): 68 | if "bn" in source: 69 | return ".bn" 70 | else: 71 | return ".conv" 72 | 73 | source = re.split("[_/]", source) 74 | layer = int(source[0][4]) # Remove "conv" 75 | target = "" 76 | 77 | if layer == 1: 78 | target += "fcn.layer{}.conv{}".format(layer, source[1]) 79 | target += conv_or_bn(source) 80 | elif layer in range(2, 6): 81 | block = int(source[1]) 82 | # Auxirally layer 83 | if layer == 4 and len(source) == 3 and source[2] == "bn": 84 | target += "aux.conv4_aux.bn" 85 | elif layer == 4 and len(source) == 2: 86 | target += "aux.conv4_aux.conv" 87 | # Pyramid pooling modules 88 | elif layer == 5 and block == 3 and "pool" in source[2]: 89 | pyramid = {1: 3, 2: 2, 3: 1, 6: 0}[int(source[2][4])] 90 | target += "ppm.stages.s{}.conv".format(pyramid) 91 | target += conv_or_bn(source) 92 | # Last convolutions 93 | elif layer == 5 and block == 4: 94 | target += "final.conv5_4" 95 | target += conv_or_bn(source) 96 | else: 97 | target += "fcn.layer{}".format(layer) 98 | target += ".block{}".format(block) 99 | if source[2] == "3x3": 100 | target += ".conv3x3" 101 | else: 102 | target += ".{}".format(source[3]) 103 | target += conv_or_bn(source) 104 | elif layer == 6: 105 | if len(source) == 1: 106 | target += "final.conv6" 107 | else: 108 | target += "aux.conv6_1" 109 | 110 | return target 111 | 112 | 113 | @click.command() 114 | @click.option("--config", "-c", required=True) 115 | def main(config): 116 | WHITELIST = ["kernel_size", "stride", "padding", "dilation", "eps", "momentum"] 117 | CONFIG = Dict(yaml.load(open(config))) 118 | 119 | params = parse_caffemodel(CONFIG.CAFFE_MODEL) 120 | 121 | model = PSPNet( 122 | n_classes=CONFIG.N_CLASSES, n_blocks=CONFIG.N_BLOCKS, pyramids=CONFIG.PYRAMIDS 123 | ) 124 | model.eval() 125 | own_state = model.state_dict() 126 | 127 | report = [] 128 | state_dict = OrderedDict() 129 | for layer_name, layer_dict in params.items(): 130 | for param_name, values in layer_dict.items(): 131 | if param_name in WHITELIST: 132 | attribute = translate_layer_name(layer_name) 133 | attribute = eval("model." + attribute + "." + param_name) 134 | message = " ".join( 135 | [ 136 | layer_name.ljust(25), 137 | "->", 138 | param_name, 139 | "pytorch: " + str(attribute), 140 | "caffe: " + str(values), 141 | ] 142 | ) 143 | print(message, end="") 144 | if isinstance(attribute, tuple): 145 | if attribute[0] != values: 146 | report.append(message) 147 | else: 148 | if abs(attribute - values) > 1e-4: 149 | report.append(message) 150 | print(": Checked!") 151 | continue 152 | param_name = translate_layer_name(layer_name) + "." 
+ param_name 153 | if param_name in own_state: 154 | print(layer_name.ljust(25), "->", param_name, end="") 155 | values = torch.FloatTensor(values) 156 | values = values.view_as(own_state[param_name]) 157 | state_dict[param_name] = values 158 | print(": Copied!") 159 | 160 | print("Inconsistent parameters (*_3x3 dilation and momentum can be ignored):") 161 | print(*report, sep="\n") 162 | 163 | # Check 164 | model.load_state_dict(state_dict) 165 | torch.save(state_dict, CONFIG.PYTORCH_MODEL) 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Kazuto Nakashima 5 | # URL: http://kazuto1011.github.io 6 | # Created: 2018-01-24 7 | 8 | import json 9 | import pickle 10 | from math import ceil 11 | 12 | import click 13 | import cv2 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import yaml 19 | from addict import Dict 20 | from PIL import Image 21 | from torch.autograd import Variable 22 | from torchvision import transforms 23 | from tqdm import tqdm 24 | 25 | from libs.datasets import VOCSegmentation 26 | from libs.models import PSPNet 27 | from libs.utils import scores 28 | 29 | 30 | def pad_image(image, crop_size): 31 | new_h, new_w = image.shape[2:] 32 | pad_h = max(crop_size - new_h, 0) 33 | pad_w = max(crop_size - new_w, 0) 34 | padded_image = torch.FloatTensor(1, 3, new_h + pad_h, new_w + pad_w).zero_() 35 | for i in range(3): # RGB 36 | padded_image[:, [i], ...] = F.pad( 37 | image[:, [i], ...], 38 | pad=(0, pad_w, 0, pad_h), # Pad right and bottom 39 | mode="constant", 40 | value=0, 41 | ).data 42 | return padded_image 43 | 44 | 45 | def to_cuda(tensors, cuda): 46 | return tensors.cuda() if cuda else tensors 47 | 48 | 49 | def to_var(tensors, cuda): 50 | tensors = to_cuda(tensors, cuda) 51 | variables = Variable(tensors, volatile=True) 52 | return variables 53 | 54 | 55 | def flip(x, dim=3): 56 | xsize = x.size() 57 | dim = x.dim() + dim if dim < 0 else dim 58 | x = x.view(-1, *xsize[dim:]) 59 | x = x.view(x.size(0), x.size(1), -1)[ 60 | :, 61 | getattr( 62 | torch.arange(x.size(1) - 1, -1, -1), ("cpu", "cuda")[x.is_cuda] 63 | )().long(), 64 | :, 65 | ] 66 | return x.view(xsize) 67 | 68 | 69 | def tile_predict(image, model, crop_size, cuda, n_classes): 70 | # Original MATLAB script 71 | # https://github.com/hszhao/PSPNet/blob/master/evaluation/scale_process.m 72 | pad_h, pad_w = image.shape[2:] 73 | stride_rate = 2 / 3.0 74 | stride = int(ceil(crop_size * stride_rate)) 75 | h_grid = int(ceil((pad_h - crop_size) / float(stride)) + 1) 76 | w_grid = int(ceil((pad_w - crop_size) / float(stride)) + 1) 77 | count = to_cuda(torch.FloatTensor(1, 1, pad_h, pad_w).zero_(), cuda) 78 | prediction = to_cuda(torch.FloatTensor(1, n_classes, pad_h, pad_w).zero_(), cuda) 79 | for ih in range(h_grid): 80 | for iw in range(w_grid): 81 | sh, sw = ih * stride, iw * stride 82 | eh, ew = min(sh + crop_size, pad_h), min(sw + crop_size, pad_w) 83 | sh, sw = eh - crop_size, ew - crop_size # Stay within image size 84 | image_sub = image[..., sh:eh, sw:ew] 85 | image_sub = pad_image(image_sub, crop_size) 86 | image_sub = to_var(image_sub, cuda) 87 | output = model(image_sub) 88 | output = F.upsample(output, size=(crop_size,) * 2, mode="bilinear") 89 | count[..., sh:eh, sw:ew] += 1 90 | prediction[..., 
sh:eh, sw:ew] += output.data 91 | prediction /= count # Normalize overlayed parts 92 | return prediction 93 | 94 | 95 | @click.command() 96 | @click.option("--config", "-c", required=True) 97 | @click.option("--cuda/--no-cuda", default=True) 98 | @click.option("--show", is_flag=True) 99 | def main(config, cuda, show): 100 | CONFIG = Dict(yaml.load(open(config))) 101 | 102 | cuda = cuda and torch.cuda.is_available() 103 | 104 | dataset = VOCSegmentation( 105 | root=CONFIG.DATASET_ROOT, image_set="val", dataset_name="VOC2012" 106 | ) 107 | 108 | dataloader = torch.utils.data.DataLoader( 109 | dataset=dataset, 110 | batch_size=1, #! DO NOT CHANGE 111 | num_workers=CONFIG.NUM_WORKERS, 112 | pin_memory=False, 113 | shuffle=False, 114 | ) 115 | 116 | # Load a model 117 | state_dict = torch.load(CONFIG.PYTORCH_MODEL) 118 | 119 | # Model 120 | model = PSPNet( 121 | n_classes=CONFIG.N_CLASSES, n_blocks=CONFIG.N_BLOCKS, pyramids=CONFIG.PYRAMIDS 122 | ) 123 | model.load_state_dict(state_dict) 124 | model = nn.DataParallel(model) 125 | model.eval() 126 | if cuda: 127 | model.cuda() 128 | 129 | crop_size = CONFIG.IMAGE.SIZE.TEST 130 | targets, outputs = [], [] 131 | 132 | for image, target in tqdm( 133 | dataloader, total=len(dataloader), leave=False, dynamic_ncols=True 134 | ): 135 | 136 | h, w = image.size()[2:] 137 | outputs_ = [] 138 | 139 | for scale in CONFIG.SCALES: 140 | 141 | # Resize 142 | long_side = int(scale * CONFIG.IMAGE.SIZE.BASE) 143 | new_h = long_side 144 | new_w = long_side 145 | if h > w: 146 | new_w = int(long_side * w / h) 147 | else: 148 | new_h = int(long_side * h / w) 149 | image_ = F.upsample(image, size=(new_h, new_w), mode="bilinear").data 150 | 151 | # Predict (w/ flipping) 152 | if long_side <= crop_size: 153 | # Padding evaluation 154 | image_ = pad_image(image_, crop_size) 155 | image_ = to_var(image_, cuda) 156 | output = torch.cat( 157 | (model(image_), flip(model(flip(image_)))) # C, H, W # C, H, W 158 | ) 159 | output = F.upsample(output, size=(crop_size,) * 2, mode="bilinear") 160 | # Revert to original size 161 | output = output[..., 0:new_h, 0:new_w] 162 | output = F.upsample(output, size=(h, w), mode="bilinear") 163 | outputs_ += [o for o in output.data] # 2 x [C, H, W] 164 | else: 165 | # Sliced evaluation 166 | image_ = pad_image(image_, crop_size) 167 | output = torch.cat( 168 | ( 169 | tile_predict(image_, model, crop_size, cuda, CONFIG.N_CLASSES), 170 | flip( 171 | tile_predict( 172 | flip(image_), model, crop_size, cuda, CONFIG.N_CLASSES 173 | ) 174 | ), 175 | ) 176 | ) 177 | # Revert to original size 178 | output = output[..., 0:new_h, 0:new_w] 179 | output = F.upsample(output, size=(h, w), mode="bilinear") 180 | outputs_ += [o for o in output.data] # 2 x [C, H, W] 181 | 182 | # Average 183 | output = torch.stack(outputs_, dim=0) # 2x#scales, C, H, W 184 | output = torch.mean(output, dim=0) # C, H, W 185 | output = torch.max(output, dim=0)[1] # H, W 186 | output = output.cpu().numpy() 187 | target = target.squeeze(0).numpy() 188 | 189 | if show: 190 | res_gt = np.concatenate((output, target), 1) 191 | mask = (res_gt >= 0)[..., None] 192 | res_gt[res_gt < 0] = 0 193 | res_gt = np.uint8(res_gt / float(CONFIG.N_CLASSES) * 255) 194 | res_gt = cv2.applyColorMap(res_gt, cv2.COLORMAP_JET) 195 | res_gt = np.uint8(res_gt * mask) 196 | img = np.uint8(image.numpy()[0].transpose(1, 2, 0) + dataset.mean_rgb)[ 197 | ..., ::-1 198 | ] 199 | img_res_gt = np.concatenate((img, res_gt), 1) 200 | cv2.imshow("result", img_res_gt) 201 | cv2.waitKey(10) 202 | 203 | 
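# Accumulate the per-image prediction/target pairs; scores() later sums them into one confusion matrix over the whole validation set.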
outputs.append(output) 204 | targets.append(target) 205 | 206 | score, class_iou = scores(targets, outputs, n_class=CONFIG.N_CLASSES) 207 | 208 | for k, v in score.items(): 209 | print(k, v) 210 | 211 | score["Class IoU"] = {} 212 | for i in range(CONFIG.N_CLASSES): 213 | score["Class IoU"][i] = class_iou[i] 214 | 215 | with open("results.json", "w") as f: 216 | json.dump(score, f, indent=4, sort_keys=True) 217 | 218 | 219 | if __name__ == "__main__": 220 | main() 221 | -------------------------------------------------------------------------------- /data/datasets/ade20k/labels.txt: -------------------------------------------------------------------------------- 1 | 0 wall 2 | 1 building, edifice 3 | 2 sky 4 | 3 floor, flooring 5 | 4 tree 6 | 5 ceiling 7 | 6 road, route 8 | 7 bed 9 | 8 windowpane, window 10 | 9 grass 11 | 10 cabinet 12 | 11 sidewalk, pavement 13 | 12 person, individual, someone, somebody, mortal, soul 14 | 13 earth, ground 15 | 14 door, double door 16 | 15 table 17 | 16 mountain, mount 18 | 17 plant, flora, plant life 19 | 18 curtain, drape, drapery, mantle, pall 20 | 19 chair 21 | 20 car, auto, automobile, machine, motorcar 22 | 21 water 23 | 22 painting, picture 24 | 23 sofa, couch, lounge 25 | 24 shelf 26 | 25 house 27 | 26 sea 28 | 27 mirror 29 | 28 rug, carpet, carpeting 30 | 29 field 31 | 30 armchair 32 | 31 seat 33 | 32 fence, fencing 34 | 33 desk 35 | 34 rock, stone 36 | 35 wardrobe, closet, press 37 | 36 lamp 38 | 37 bathtub, bathing tub, bath, tub 39 | 38 railing, rail 40 | 39 cushion 41 | 40 base, pedestal, stand 42 | 41 box 43 | 42 column, pillar 44 | 43 signboard, sign 45 | 44 chest of drawers, chest, bureau, dresser 46 | 45 counter 47 | 46 sand 48 | 47 sink 49 | 48 skyscraper 50 | 49 fireplace, hearth, open fireplace 51 | 50 refrigerator, icebox 52 | 51 grandstand, covered stand 53 | 52 path 54 | 53 stairs, steps 55 | 54 runway 56 | 55 case, display case, showcase, vitrine 57 | 56 pool table, billiard table, snooker table 58 | 57 pillow 59 | 58 screen door, screen 60 | 59 stairway, staircase 61 | 60 river 62 | 61 bridge, span 63 | 62 bookcase 64 | 63 blind, screen 65 | 64 coffee table, cocktail table 66 | 65 toilet, can, commode, crapper, pot, potty, stool, throne 67 | 66 flower 68 | 67 book 69 | 68 hill 70 | 69 bench 71 | 70 countertop 72 | 71 stove, kitchen stove, range, kitchen range, cooking stove 73 | 72 palm, palm tree 74 | 73 kitchen island 75 | 74 computer, computing machine, computing device, data processor, electronic computer, information processing system 76 | 75 swivel chair 77 | 76 boat 78 | 77 bar 79 | 78 arcade machine 80 | 79 hovel, hut, hutch, shack, shanty 81 | 80 bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle 82 | 81 towel 83 | 82 light, light source 84 | 83 truck, motortruck 85 | 84 tower 86 | 85 chandelier, pendant, pendent 87 | 86 awning, sunshade, sunblind 88 | 87 streetlight, street lamp 89 | 88 booth, cubicle, stall, kiosk 90 | 89 television, television receiver, television set, tv, tv set, idiot box, boob tube, telly, goggle box 91 | 90 airplane, aeroplane, plane 92 | 91 dirt track 93 | 92 apparel, wearing apparel, dress, clothes 94 | 93 pole 95 | 94 land, ground, soil 96 | 95 bannister, banister, balustrade, balusters, handrail 97 | 96 escalator, moving staircase, moving stairway 98 | 97 ottoman, pouf, pouffe, puff, hassock 99 | 98 bottle 100 | 99 buffet, counter, sideboard 101 | 100 poster, posting, placard, notice, bill, card 102 | 101 stage 103 | 102 van 104 | 103 ship 105 
| 104 fountain 106 | 105 conveyer belt, conveyor belt, conveyer, conveyor, transporter 107 | 106 canopy 108 | 107 washer, automatic washer, washing machine 109 | 108 plaything, toy 110 | 109 swimming pool, swimming bath, natatorium 111 | 110 stool 112 | 111 barrel, cask 113 | 112 basket, handbasket 114 | 113 waterfall, falls 115 | 114 tent, collapsible shelter 116 | 115 bag 117 | 116 minibike, motorbike 118 | 117 cradle 119 | 118 oven 120 | 119 ball 121 | 120 food, solid food 122 | 121 step, stair 123 | 122 tank, storage tank 124 | 123 trade name, brand name, brand, marque 125 | 124 microwave, microwave oven 126 | 125 pot, flowerpot 127 | 126 animal, animate being, beast, brute, creature, fauna 128 | 127 bicycle, bike, wheel, cycle 129 | 128 lake 130 | 129 dishwasher, dish washer, dishwashing machine 131 | 130 screen, silver screen, projection screen 132 | 131 blanket, cover 133 | 132 sculpture 134 | 133 hood, exhaust hood 135 | 134 sconce 136 | 135 vase 137 | 136 traffic light, traffic signal, stoplight 138 | 137 tray 139 | 138 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 140 | 139 fan 141 | 140 pier, wharf, wharfage, dock 142 | 141 crt screen 143 | 142 plate 144 | 143 monitor, monitoring device 145 | 144 bulletin board, notice board 146 | 145 shower 147 | 146 radiator 148 | 147 glass, drinking glass 149 | 148 clock 150 | 149 flag -------------------------------------------------------------------------------- /libs/caffe.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package caffe; 4 | 5 | // Specifies the shape (dimensions) of a Blob. 6 | message BlobShape { 7 | repeated int64 dim = 1 [packed = true]; 8 | } 9 | 10 | message BlobProto { 11 | optional BlobShape shape = 7; 12 | repeated float data = 5 [packed = true]; 13 | repeated float diff = 6 [packed = true]; 14 | repeated double double_data = 8 [packed = true]; 15 | repeated double double_diff = 9 [packed = true]; 16 | 17 | // 4D dimensions -- deprecated. Use "shape" instead. 18 | optional int32 num = 1 [default = 0]; 19 | optional int32 channels = 2 [default = 0]; 20 | optional int32 height = 3 [default = 0]; 21 | optional int32 width = 4 [default = 0]; 22 | } 23 | 24 | // The BlobProtoVector is simply a way to pass multiple blobproto instances 25 | // around. 
26 | message BlobProtoVector { 27 | repeated BlobProto blobs = 1; 28 | } 29 | 30 | message Datum { 31 | optional int32 channels = 1; 32 | optional int32 height = 2; 33 | optional int32 width = 3; 34 | // the actual image data, in bytes 35 | optional bytes data = 4; 36 | optional int32 label = 5; 37 | // Optionally, the datum could also hold float data. 38 | repeated float float_data = 6; 39 | // If true data contains an encoded image that need to be decoded 40 | optional bool encoded = 7 [default = false]; 41 | } 42 | 43 | message FillerParameter { 44 | // The filler type. 45 | optional string type = 1 [default = 'constant']; 46 | optional float value = 2 [default = 0]; // the value in constant filler 47 | optional float min = 3 [default = 0]; // the min value in uniform filler 48 | optional float max = 4 [default = 1]; // the max value in uniform filler 49 | optional float mean = 5 [default = 0]; // the mean value in Gaussian filler 50 | optional float std = 6 [default = 1]; // the std value in Gaussian filler 51 | // The expected number of non-zero output weights for a given input in 52 | // Gaussian filler -- the default -1 means don't perform sparsification. 53 | optional int32 sparse = 7 [default = -1]; 54 | // Normalize the filler variance by fan_in, fan_out, or their average. 55 | // Applies to 'xavier' and 'msra' fillers. 56 | enum VarianceNorm { 57 | FAN_IN = 0; 58 | FAN_OUT = 1; 59 | AVERAGE = 2; 60 | } 61 | optional VarianceNorm variance_norm = 8 [default = FAN_IN]; 62 | } 63 | 64 | message NetParameter { 65 | optional string name = 1; // consider giving the network a name 66 | // The input blobs to the network. 67 | repeated string input = 3; 68 | // The shape of the input blobs. 69 | repeated BlobShape input_shape = 8; 70 | 71 | // 4D input dimensions -- deprecated. Use "shape" instead. 72 | // If specified, for each input blob there should be four 73 | // values specifying the num, channels, height and width of the input blob. 74 | // Thus, there should be a total of (4 * #input) numbers. 75 | repeated int32 input_dim = 4; 76 | 77 | // Whether the network will force every layer to carry out backward operation. 78 | // If set False, then whether to carry out backward is determined 79 | // automatically according to the net structure and learning rates. 80 | optional bool force_backward = 5 [default = false]; 81 | // The current "state" of the network, including the phase, level, and stage. 82 | // Some layers may be included/excluded depending on this state and the states 83 | // specified in the layers' include and exclude fields. 84 | optional NetState state = 6; 85 | 86 | // Print debugging information about results while running Net::Forward, 87 | // Net::Backward, and Net::Update. 88 | optional bool debug_info = 7 [default = false]; 89 | 90 | // The layers that make up the net. Each of their configurations, including 91 | // connectivity and behavior, is specified as a LayerParameter. 92 | repeated LayerParameter layer = 100; // ID 100 so layers are printed last. 93 | 94 | // DEPRECATED: use 'layer' instead. 95 | repeated V1LayerParameter layers = 2; 96 | } 97 | 98 | // NOTE 99 | // Update the next available ID when you add a new SolverParameter field. 
100 | // 101 | // SolverParameter next available ID: 41 (last added: type) 102 | message SolverParameter { 103 | ////////////////////////////////////////////////////////////////////////////// 104 | // Specifying the train and test networks 105 | // 106 | // Exactly one train net must be specified using one of the following fields: 107 | // train_net_param, train_net, net_param, net 108 | // One or more test nets may be specified using any of the following fields: 109 | // test_net_param, test_net, net_param, net 110 | // If more than one test net field is specified (e.g., both net and 111 | // test_net are specified), they will be evaluated in the field order given 112 | // above: (1) test_net_param, (2) test_net, (3) net_param/net. 113 | // A test_iter must be specified for each test_net. 114 | // A test_level and/or a test_stage may also be specified for each test_net. 115 | ////////////////////////////////////////////////////////////////////////////// 116 | 117 | // Proto filename for the train net, possibly combined with one or more 118 | // test nets. 119 | optional string net = 24; 120 | // Inline train net param, possibly combined with one or more test nets. 121 | optional NetParameter net_param = 25; 122 | 123 | optional string train_net = 1; // Proto filename for the train net. 124 | repeated string test_net = 2; // Proto filenames for the test nets. 125 | optional NetParameter train_net_param = 21; // Inline train net params. 126 | repeated NetParameter test_net_param = 22; // Inline test net params. 127 | 128 | // The states for the train/test nets. Must be unspecified or 129 | // specified once per net. 130 | // 131 | // By default, all states will have solver = true; 132 | // train_state will have phase = TRAIN, 133 | // and all test_state's will have phase = TEST. 134 | // Other defaults are set according to the NetState defaults. 135 | optional NetState train_state = 26; 136 | repeated NetState test_state = 27; 137 | 138 | // The number of iterations for each test net. 139 | repeated int32 test_iter = 3; 140 | 141 | // The number of iterations between two testing phases. 142 | optional int32 test_interval = 4 [default = 0]; 143 | optional bool test_compute_loss = 19 [default = false]; 144 | // If true, run an initial test pass before the first iteration, 145 | // ensuring memory availability and printing the starting value of the loss. 146 | optional bool test_initialization = 32 [default = true]; 147 | optional float base_lr = 5; // The base learning rate 148 | // the number of iterations between displaying info. If display = 0, no info 149 | // will be displayed. 150 | optional int32 display = 6; 151 | // Display the loss averaged over the last average_loss iterations 152 | optional int32 average_loss = 33 [default = 1]; 153 | optional int32 max_iter = 7; // the maximum number of iterations 154 | // accumulate gradients over `iter_size` x `batch_size` instances 155 | optional int32 iter_size = 36 [default = 1]; 156 | 157 | // The learning rate decay policy. The currently implemented learning rate 158 | // policies are as follows: 159 | // - fixed: always return base_lr. 160 | // - step: return base_lr * gamma ^ (floor(iter / step)) 161 | // - exp: return base_lr * gamma ^ iter 162 | // - inv: return base_lr * (1 + gamma * iter) ^ (- power) 163 | // - multistep: similar to step but it allows non uniform steps defined by 164 | // stepvalue 165 | // - poly: the effective learning rate follows a polynomial decay, to be 166 | // zero by the max_iter. 
return base_lr (1 - iter/max_iter) ^ (power) 167 | // - sigmoid: the effective learning rate follows a sigmoid decay 168 | // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) 169 | // 170 | // where base_lr, max_iter, gamma, step, stepvalue and power are defined 171 | // in the solver parameter protocol buffer, and iter is the current iteration. 172 | optional string lr_policy = 8; 173 | optional float gamma = 9; // The parameter to compute the learning rate. 174 | optional float power = 10; // The parameter to compute the learning rate. 175 | optional float momentum = 11; // The momentum value. 176 | optional float weight_decay = 12; // The weight decay. 177 | // regularization types supported: L1 and L2 178 | // controlled by weight_decay 179 | optional string regularization_type = 29 [default = "L2"]; 180 | // the stepsize for learning rate policy "step" 181 | optional int32 stepsize = 13; 182 | // the stepsize for learning rate policy "multistep" 183 | repeated int32 stepvalue = 34; 184 | 185 | // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, 186 | // whenever their actual L2 norm is larger. 187 | optional float clip_gradients = 35 [default = -1]; 188 | 189 | optional int32 snapshot = 14 [default = 0]; // The snapshot interval 190 | optional string snapshot_prefix = 15; // The prefix for the snapshot. 191 | // whether to snapshot diff in the results or not. Snapshotting diff will help 192 | // debugging but the final protocol buffer size will be much larger. 193 | optional bool snapshot_diff = 16 [default = false]; 194 | enum SnapshotFormat { 195 | HDF5 = 0; 196 | BINARYPROTO = 1; 197 | } 198 | optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; 199 | // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU by default. 200 | enum SolverMode { 201 | CPU = 0; 202 | GPU = 1; 203 | } 204 | optional SolverMode solver_mode = 17 [default = GPU]; 205 | // the device_id that will be used in GPU mode. Use device_id = 0 by default. 206 | optional int32 device_id = 18 [default = 0]; 207 | // If non-negative, the seed with which the Solver will initialize the Caffe 208 | // random number generator -- useful for reproducible results. Otherwise, 209 | // (and by default) initialize using a seed derived from the system clock. 210 | optional int64 random_seed = 20 [default = -1]; 211 | 212 | // type of the solver 213 | optional string type = 40 [default = "SGD"]; 214 | 215 | // numerical stability for RMSProp, AdaGrad, AdaDelta and Adam 216 | optional float delta = 31 [default = 1e-8]; 217 | // parameters for the Adam solver 218 | optional float momentum2 = 39 [default = 0.999]; 219 | 220 | // RMSProp decay value 221 | // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) 222 | optional float rms_decay = 38; 223 | 224 | // If true, print information about the state of the net that may help with 225 | // debugging learning problems. 226 | optional bool debug_info = 23 [default = false]; 227 | 228 | // If false, don't save a snapshot after training finishes. 229 | optional bool snapshot_after_train = 28 [default = true]; 230 | 231 | // DEPRECATED: old solver enum types, use string instead 232 | enum SolverType { 233 | SGD = 0; 234 | NESTEROV = 1; 235 | ADAGRAD = 2; 236 | RMSPROP = 3; 237 | ADADELTA = 4; 238 | ADAM = 5; 239 | } 240 | // DEPRECATED: use type instead of solver_type 241 | optional SolverType solver_type = 30 [default = SGD]; 242 | }
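//
// Worked example for SolverParameter above (illustrative values, not defaults):
//   lr_policy: "step"  base_lr: 0.01  gamma: 0.1  stepsize: 20000
// yields an effective rate of 0.01 * 0.1 ^ floor(iter / 20000), i.e. 0.01 for
// iterations 0-19999, 0.001 for iterations 20000-39999, and so on.
//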
243 | 244 | // A message that stores the solver snapshots 245 | message SolverState { 246 | optional int32 iter = 1; // The current iteration 247 | optional string learned_net = 2; // The file that stores the learned net. 248 | repeated BlobProto history = 3; // The history for sgd solvers 249 | optional int32 current_step = 4 [default = 0]; // The current step for learning rate 250 | } 251 | 252 | enum Phase { 253 | TRAIN = 0; 254 | TEST = 1; 255 | } 256 | 257 | message NetState { 258 | optional Phase phase = 1 [default = TEST]; 259 | optional int32 level = 2 [default = 0]; 260 | repeated string stage = 3; 261 | } 262 | 263 | message NetStateRule { 264 | // Set phase to require that the NetState have a particular phase (TRAIN or TEST) 265 | // to meet this rule. 266 | optional Phase phase = 1; 267 | 268 | // Set the minimum and/or maximum levels in which the layer should be used. 269 | // Leave undefined to meet the rule regardless of level. 270 | optional int32 min_level = 2; 271 | optional int32 max_level = 3; 272 | 273 | // Customizable sets of stages to include or exclude. 274 | // The net must have ALL of the specified stages and NONE of the specified 275 | // "not_stage"s to meet the rule. 276 | // (Use multiple NetStateRules to specify conjunctions of stages.) 277 | repeated string stage = 4; 278 | repeated string not_stage = 5; 279 | } 280 | 281 | // Specifies training parameters (multipliers on global learning constants, 282 | // and the name and other settings used for weight sharing). 283 | message ParamSpec { 284 | // The names of the parameter blobs -- useful for sharing parameters among 285 | // layers, but never required otherwise. To share a parameter between two 286 | // layers, give it a (non-empty) name. 287 | optional string name = 1; 288 | 289 | // Whether to require shared weights to have the same shape, or just the same 290 | // count -- defaults to STRICT if unspecified. 291 | optional DimCheckMode share_mode = 2; 292 | enum DimCheckMode { 293 | // STRICT (default) requires that num, channels, height, width each match. 294 | STRICT = 0; 295 | // PERMISSIVE requires only the count (num*channels*height*width) to match. 296 | PERMISSIVE = 1; 297 | } 298 | 299 | // The multiplier on the global learning rate for this parameter. 300 | optional float lr_mult = 3 [default = 1.0]; 301 | 302 | // The multiplier on the global weight decay for this parameter. 303 | optional float decay_mult = 4 [default = 1.0]; 304 | } 305 | 306 | // NOTE 307 | // Update the next available ID when you add a new LayerParameter field. 308 | // 309 | // LayerParameter next available layer-specific ID: 153 (last added: bn_param) 310 | message LayerParameter { 311 | optional string name = 1; // the layer name 312 | optional string type = 2; // the layer type 313 | repeated string bottom = 3; // the name of each bottom blob 314 | repeated string top = 4; // the name of each top blob 315 | 316 | // The train / test phase for computation. 317 | optional Phase phase = 10; 318 | 319 | // The amount of weight to assign each top blob in the objective.
320 | // Each layer assigns a default value, usually of either 0 or 1, 321 | // to each top blob. 322 | repeated float loss_weight = 5; 323 | 324 | // Specifies training parameters (multipliers on global learning constants, 325 | // and the name and other settings used for weight sharing). 326 | repeated ParamSpec param = 6; 327 | 328 | // The blobs containing the numeric parameters of the layer. 329 | repeated BlobProto blobs = 7; 330 | 331 | // Specifies on which bottoms the backpropagation should be skipped. 332 | // The size must be either 0 or equal to the number of bottoms. 333 | repeated bool propagate_down = 11; 334 | 335 | // Rules controlling whether and when a layer is included in the network, 336 | // based on the current NetState. You may specify a non-zero number of rules 337 | // to include OR exclude, but not both. If no include or exclude rules are 338 | // specified, the layer is always included. If the current NetState meets 339 | // ANY (i.e., one or more) of the specified rules, the layer is 340 | // included/excluded. 341 | repeated NetStateRule include = 8; 342 | repeated NetStateRule exclude = 9; 343 | 344 | // Parameters for data pre-processing. 345 | optional TransformationParameter transform_param = 100; 346 | 347 | // Parameters shared by loss layers. 348 | optional LossParameter loss_param = 101; 349 | 350 | // Layer type-specific parameters. 351 | // 352 | // Note: certain layers may have more than one computational engine 353 | // for their implementation. These layers include an Engine type and 354 | // engine parameter for selecting the implementation. 355 | // The default for the engine is set by the ENGINE switch at compile-time. 356 | optional AccuracyParameter accuracy_param = 102; 357 | optional AdaptiveBiasChannelParameter adaptive_bias_channel_param = 148; 358 | optional ArgMaxParameter argmax_param = 103; 359 | optional BatchNormParameter batch_norm_param = 139; 360 | optional BNParameter bn_param = 152; 361 | optional BiasParameter bias_param = 141; 362 | optional BiasChannelParameter bias_channel_param = 149; 363 | optional ConcatParameter concat_param = 104; 364 | optional ContrastiveLossParameter contrastive_loss_param = 105; 365 | optional ConvolutionParameter convolution_param = 106; 366 | optional DataParameter data_param = 107; 367 | optional DenseCRFParameter dense_crf_param = 146; 368 | optional DomainTransformParameter domain_transform_param = 147; 369 | optional DropoutParameter dropout_param = 108; 370 | optional DummyDataParameter dummy_data_param = 109; 371 | optional EltwiseParameter eltwise_param = 110; 372 | optional ELUParameter elu_param = 140; 373 | optional EmbedParameter embed_param = 137; 374 | optional ExpParameter exp_param = 111; 375 | optional FlattenParameter flatten_param = 135; 376 | optional HDF5DataParameter hdf5_data_param = 112; 377 | optional HDF5OutputParameter hdf5_output_param = 113; 378 | optional HingeLossParameter hinge_loss_param = 114; 379 | optional ImageDataParameter image_data_param = 115; 380 | optional InfogainLossParameter infogain_loss_param = 116; 381 | optional InnerProductParameter inner_product_param = 117; 382 | optional InterpParameter interp_param = 143; 383 | optional LogParameter log_param = 134; 384 | optional LRNParameter lrn_param = 118; 385 | optional MatReadParameter mat_read_param = 151; 386 | optional MatWriteParameter mat_write_param = 145; 387 | optional MemoryDataParameter memory_data_param = 119; 388 | optional MVNParameter mvn_param = 120; 389 | optional PoolingParameter 
pooling_param = 121; 390 | optional PowerParameter power_param = 122; 391 | optional PReLUParameter prelu_param = 131; 392 | optional PythonParameter python_param = 130; 393 | optional ReductionParameter reduction_param = 136; 394 | optional ReLUParameter relu_param = 123; 395 | optional ReshapeParameter reshape_param = 133; 396 | optional ScaleParameter scale_param = 142; 397 | optional SegAccuracyParameter seg_accuracy_param = 144; 398 | optional SigmoidParameter sigmoid_param = 124; 399 | optional SoftmaxParameter softmax_param = 125; 400 | optional SPPParameter spp_param = 132; 401 | optional SliceParameter slice_param = 126; 402 | optional TanHParameter tanh_param = 127; 403 | optional ThresholdParameter threshold_param = 128; 404 | optional TileParameter tile_param = 138; 405 | optional UniqueLabelParameter unique_label_param = 150; 406 | optional WindowDataParameter window_data_param = 129; 407 | } 408 | 409 | // Message that stores parameters used to apply transformation 410 | // to the data layer's data 411 | message TransformationParameter { 412 | // For data pre-processing, we can do simple scaling and subtracting the 413 | // data mean, if provided. Note that the mean subtraction is always carried 414 | // out before scaling. 415 | optional float scale = 1 [default = 1]; 416 | // Specify if we want to randomly mirror data. 417 | optional bool mirror = 2 [default = false]; 418 | // Specify if we would like to randomly crop an image. 419 | optional uint32 crop_size = 3 [default = 0]; 420 | // mean_file and mean_value cannot be specified at the same time 421 | optional string mean_file = 4; 422 | // if specified, it can be repeated once (would subtract it from all the channels) 423 | // or can be repeated the same number of times as channels 424 | // (would subtract them from the corresponding channel) 425 | repeated float mean_value = 5; 426 | // Force the decoded image to have 3 color channels. 427 | optional bool force_color = 6 [default = false]; 428 | // Force the decoded image to have 1 color channel. 429 | optional bool force_gray = 7 [default = false]; 430 | // Scaling factors for randomly scaling input images during data augmentation 431 | repeated float scale_factors = 8; 432 | // the width of the cropped region 433 | optional uint32 crop_width = 9 [default = 0]; 434 | // the height of the cropped region 435 | optional uint32 crop_height = 10 [default = 0]; 436 | 437 | } 438 | 439 | // Message that stores parameters shared by loss layers 440 | message LossParameter { 441 | // If specified, ignore instances with the given label. 442 | optional int32 ignore_label = 1; 443 | // How to normalize the loss for loss layers that aggregate across batches, 444 | // spatial dimensions, or other dimensions. Currently only implemented in 445 | // SoftmaxWithLoss layer. 446 | enum NormalizationMode { 447 | // Divide by the number of examples in the batch times spatial dimensions. 448 | // Outputs that receive the ignore label will NOT be ignored in computing 449 | // the normalization factor. 450 | FULL = 0; 451 | // Divide by the total number of output locations that do not take the 452 | // ignore_label. If ignore_label is not set, this behaves like FULL. 453 | VALID = 1; 454 | // Divide by the batch size. 455 | BATCH_SIZE = 2; 456 | // Do not normalize the loss. 457 | NONE = 3; 458 | } 459 | optional NormalizationMode normalization = 3 [default = VALID]; 460 | // Deprecated. Ignored if normalization is specified. If normalization 461 | // is not specified, then setting this to false will be equivalent to 462 | // normalization = BATCH_SIZE to be consistent with previous behavior. 463 | optional bool normalize = 2; 464 | }
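//
// Example for LossParameter above (illustrative): a typical semantic-
// segmentation setting is
//   loss_param { ignore_label: 255 normalization: VALID }
// which averages the loss over labeled pixels only, skipping void (255) pixels.
//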
465 | 466 | // Messages that store parameters used by individual layer types follow, in 467 | // alphabetical order. 468 | 469 | message AccuracyParameter { 470 | // When computing accuracy, count as correct by comparing the true label to 471 | // the top k scoring classes. By default, only compare to the top scoring 472 | // class (i.e. argmax). 473 | optional uint32 top_k = 1 [default = 1]; 474 | 475 | // The "label" axis of the prediction blob, whose argmax corresponds to the 476 | // predicted label -- may be negative to index from the end (e.g., -1 for the 477 | // last axis). For example, if axis == 1 and the predictions are 478 | // (N x C x H x W), the label blob is expected to contain N*H*W ground truth 479 | // labels with integer values in {0, 1, ..., C-1}. 480 | optional int32 axis = 2 [default = 1]; 481 | 482 | // If specified, ignore instances with the given label. 483 | optional int32 ignore_label = 3; 484 | } 485 | 486 | message AdaptiveBiasChannelParameter { 487 | optional int32 num_iter = 1 [default = 1]; 488 | optional float bg_portion = 2 [default = 0.2]; 489 | optional float fg_portion = 3 [default = 0.2]; 490 | optional bool suppress_others = 4 [default = true]; 491 | optional float margin_others = 5 [default = 1e-5]; 492 | } 493 | 494 | message ArgMaxParameter { 495 | // If true, produce pairs (argmax, maxval) 496 | optional bool out_max_val = 1 [default = false]; 497 | optional uint32 top_k = 2 [default = 1]; 498 | // The axis along which to maximize -- may be negative to index from the 499 | // end (e.g., -1 for the last axis). 500 | // By default ArgMaxLayer maximizes over the flattened trailing dimensions 501 | // for each index of the first / num dimension. 502 | optional int32 axis = 3; 503 | } 504 | 505 | message BiasChannelParameter { 506 | // Score biases. Separate values for BG / FG 507 | optional float bg_bias = 1 [default = 1.]; 508 | optional float fg_bias = 2 [default = 2.]; 509 | // will ignore labels with this value when adding bias 510 | repeated int32 ignore_label = 3; 511 | enum LabelType { 512 | IMAGE = 1; 513 | PIXEL = 2; 514 | } 515 | optional LabelType label_type = 4 [default = IMAGE]; 516 | // Whether or not the dataset defines a generic background label. 517 | // The default value is defined for PASCAL VOC segmentation 518 | optional int32 background_label = 6 [default = 0]; 519 | } 520 | 521 | message ConcatParameter { 522 | // The axis along which to concatenate -- may be negative to index from the 523 | // end (e.g., -1 for the last axis). Other axes must have the 524 | // same dimension for all the bottom blobs. 525 | // By default, ConcatLayer concatenates blobs along the "channels" axis (1). 526 | optional int32 axis = 2 [default = 1]; 527 | 528 | // DEPRECATED: alias for "axis" -- does not support negative indexing. 529 | optional uint32 concat_dim = 1 [default = 1]; 530 | } 531 |
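//
// Example for ConcatParameter above (illustrative): merging two feature maps
// channel-wise --
//   layer { type: "Concat" bottom: "feat_a" bottom: "feat_b" top: "merged"
//           concat_param { axis: 1 } }
// all dimensions other than the channel axis must match between the bottoms.
//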
532 | message BatchNormParameter { 533 | // If false, accumulate global mean/variance values via a moving average. If 534 | // true, use those accumulated values instead of computing mean/variance 535 | // across the batch. 536 | optional bool use_global_stats = 1; 537 | // How much does the moving average decay each iteration? 538 | optional float moving_average_fraction = 2 [default = .999]; 539 | // Small value to add to the variance estimate so that we don't divide by 540 | // zero. 541 | optional float eps = 3 [default = 1e-5]; 542 | optional bool update_global_stats = 4 [default = false]; 543 | } 544 | 545 | message BNParameter { 546 | optional FillerParameter slope_filler = 1; 547 | optional FillerParameter bias_filler = 2; 548 | optional float momentum = 3 [default = 0.9]; 549 | optional float eps = 4 [default = 1e-5]; 550 | // If true, will use the moving average mean and std for training and test. 551 | // Will override the lr_param and freeze all the parameters. 552 | // Make sure to initialize the layer properly with pretrained parameters. 553 | optional bool frozen = 5 [default = false]; 554 | enum Engine { 555 | DEFAULT = 0; 556 | CAFFE = 1; 557 | CUDNN = 2; 558 | } 559 | optional Engine engine = 6 [default = DEFAULT]; 560 | } 561 | 562 | message BiasParameter { 563 | // The first axis of bottom[0] (the first input Blob) along which to apply 564 | // bottom[1] (the second input Blob). May be negative to index from the end 565 | // (e.g., -1 for the last axis). 566 | // 567 | // For example, if bottom[0] is 4D with shape 100x3x40x60, the output 568 | // top[0] will have the same shape, and bottom[1] may have any of the 569 | // following shapes (for the given value of axis): 570 | // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 571 | // (axis == 1 == -3) 3; 3x40; 3x40x60 572 | // (axis == 2 == -2) 40; 40x60 573 | // (axis == 3 == -1) 60 574 | // Furthermore, bottom[1] may have the empty shape (regardless of the value of 575 | // "axis") -- a scalar bias. 576 | optional int32 axis = 1 [default = 1]; 577 | 578 | // (num_axes is ignored unless just one bottom is given and the bias is 579 | // a learned parameter of the layer. Otherwise, num_axes is determined by the 580 | // number of axes of the second bottom.) 581 | // The number of axes of the input (bottom[0]) covered by the bias 582 | // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. 583 | // Set num_axes := 0 to add a zero-axis Blob: a scalar. 584 | optional int32 num_axes = 2 [default = 1]; 585 | 586 | // (filler is ignored unless just one bottom is given and the bias is 587 | // a learned parameter of the layer.) 588 | // The initialization for the learned bias parameter. 589 | // Default is the zero (0) initialization, resulting in the BiasLayer 590 | // initially performing the identity operation. 591 | optional FillerParameter filler = 3; 592 | } 593 | 594 | message ContrastiveLossParameter { 595 | // margin for dissimilar pair 596 | optional float margin = 1 [default = 1.0]; 597 | // The first implementation of this cost did not exactly match the cost of 598 | // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2. 599 | // legacy_version = false (the default) uses (margin - d)^2 as proposed in the 600 | // Hadsell paper. New models should probably use this version. 601 | // legacy_version = true uses (margin - d^2).
This is kept to support / 602 | // reproduce existing models and results 603 | optional bool legacy_version = 2 [default = false]; 604 | } 605 | 606 | message ConvolutionParameter { 607 | optional uint32 num_output = 1; // The number of outputs for the layer 608 | optional bool bias_term = 2 [default = true]; // whether to have bias terms 609 | 610 | // Pad, kernel size, and stride are all given as a single value for equal 611 | // dimensions in all spatial dimensions, or once per spatial dimension. 612 | repeated uint32 pad = 3; // The padding size; defaults to 0 613 | repeated uint32 kernel_size = 4; // The kernel size 614 | repeated uint32 stride = 6; // The stride; defaults to 1 615 | // Factor used to dilate the kernel, (implicitly) zero-filling the resulting 616 | // holes. (Kernel dilation is sometimes referred to by its use in the 617 | // algorithme à trous from Holschneider et al. 1987.) 618 | repeated uint32 dilation = 18; // The dilation; defaults to 1 619 | 620 | // For 2D convolution only, the *_h and *_w versions may also be used to 621 | // specify both spatial dimensions. 622 | optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only) 623 | optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only) 624 | optional uint32 kernel_h = 11; // The kernel height (2D only) 625 | optional uint32 kernel_w = 12; // The kernel width (2D only) 626 | optional uint32 stride_h = 13; // The stride height (2D only) 627 | optional uint32 stride_w = 14; // The stride width (2D only) 628 | 629 | optional uint32 group = 5 [default = 1]; // The group size for group conv 630 | 631 | optional FillerParameter weight_filler = 7; // The filler for the weight 632 | optional FillerParameter bias_filler = 8; // The filler for the bias 633 | enum Engine { 634 | DEFAULT = 0; 635 | CAFFE = 1; 636 | CUDNN = 2; 637 | } 638 | optional Engine engine = 15 [default = DEFAULT]; 639 | 640 | // The axis to interpret as "channels" when performing convolution. 641 | // Preceding dimensions are treated as independent inputs; 642 | // succeeding dimensions are treated as "spatial". 643 | // With (N, C, H, W) inputs, and axis == 1 (the default), we perform 644 | // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for 645 | // groups g>1) filters across the spatial axes (H, W) of the input. 646 | // With (N, C, D, H, W) inputs, and axis == 1, we perform 647 | // N independent 3D convolutions, sliding (C/g)-channels 648 | // filters across the spatial axes (D, H, W) of the input. 649 | optional int32 axis = 16 [default = 1]; 650 | 651 | // Whether to force use of the general ND convolution, even if a specific 652 | // implementation for blobs of the appropriate number of spatial dimensions 653 | // is available. (Currently, there is only a 2D-specific convolution 654 | // implementation; for input blobs with num_axes != 2, this option is 655 | // ignored and the ND implementation will be used.) 656 | optional bool force_nd_im2col = 17 [default = false]; 657 | }
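//
// Example for ConvolutionParameter above (illustrative): a dilated 3x3
// convolution in the style of DeepLab/PSPNet --
//   convolution_param { num_output: 512 kernel_size: 3 dilation: 2 pad: 2 }
// enlarges the receptive field without reducing resolution; for a 3x3 kernel,
// setting pad equal to dilation preserves the spatial size (num_output here is
// an arbitrary value).
//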
658 | 659 | message DataParameter { 660 | enum DB { 661 | LEVELDB = 0; 662 | LMDB = 1; 663 | } 664 | // Specify the data source. 665 | optional string source = 1; 666 | // Specify the batch size. 667 | optional uint32 batch_size = 4; 668 | // The rand_skip variable is for the data layer to skip a few data points 669 | // to prevent all asynchronous sgd clients from starting at the same point. The skip 670 | // point would be set as rand_skip * rand(0,1). Note that rand_skip should not 671 | // be larger than the number of keys in the database. 672 | // DEPRECATED. Each solver accesses a different subset of the database. 673 | optional uint32 rand_skip = 7 [default = 0]; 674 | optional DB backend = 8 [default = LEVELDB]; 675 | // DEPRECATED. See TransformationParameter. For data pre-processing, we can do 676 | // simple scaling and subtracting the data mean, if provided. Note that the 677 | // mean subtraction is always carried out before scaling. 678 | optional float scale = 2 [default = 1]; 679 | optional string mean_file = 3; 680 | // DEPRECATED. See TransformationParameter. Specify if we would like to randomly 681 | // crop an image. 682 | optional uint32 crop_size = 5 [default = 0]; 683 | // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror 684 | // data. 685 | optional bool mirror = 6 [default = false]; 686 | // Force the encoded image to have 3 color channels 687 | optional bool force_encoded_color = 9 [default = false]; 688 | // Prefetch queue (Number of batches to prefetch to host memory, increase if 689 | // data access bandwidth varies). 690 | optional uint32 prefetch = 10 [default = 4]; 691 | } 692 | 693 | message DenseCRFParameter { 694 | // max number of iterations for message passing 695 | optional int32 max_iter = 1 [default = 10]; 696 | // positional std and weight for "Positional" filter (color-independent) 697 | repeated float pos_xy_std = 2; 698 | repeated float pos_w = 3; 699 | // positional std, color std and weight for Bilateral filter 700 | repeated float bi_xy_std = 4; 701 | repeated float bi_rgb_std = 5; 702 | repeated float bi_w = 6; 703 | // output is probability or score (score = log(prob)) 704 | optional bool output_probability = 7 [default = true]; 705 | } 706 | 707 | message DomainTransformParameter { 708 | // Max number of iterations for filtering. 709 | optional int32 num_iter = 1 [default = 3]; 710 | // Standard deviation for spatial domain. 711 | optional float spatial_sigma = 2 [default = 50]; 712 | // Standard deviation for range domain. 713 | optional float range_sigma = 3 [default = 5]; 714 | // minimum weight value (to avoid zero gradient for ref_grad_data) 715 | optional float min_weight = 4 [default = 0]; 716 | } 717 | 718 | message DropoutParameter { 719 | optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio 720 | } 721 | 722 | // DummyDataLayer fills any number of arbitrarily shaped blobs with random 723 | // (or constant) data generated by "Fillers" (see "message FillerParameter"). 724 | message DummyDataParameter { 725 | // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N 726 | // shape fields, and 0, 1 or N data_fillers. 727 | // 728 | // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. 729 | // If 1 data_filler is specified, it is applied to all top blobs. If N are 730 | // specified, the ith is applied to the ith top blob. 731 | repeated FillerParameter data_filler = 1; 732 | repeated BlobShape shape = 6; 733 | 734 | // 4D dimensions -- deprecated. Use "shape" instead.
735 | repeated uint32 num = 2; 736 | repeated uint32 channels = 3; 737 | repeated uint32 height = 4; 738 | repeated uint32 width = 5; 739 | } 740 | 741 | message EltwiseParameter { 742 | enum EltwiseOp { 743 | PROD = 0; 744 | SUM = 1; 745 | MAX = 2; 746 | } 747 | optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation 748 | repeated float coeff = 2; // blob-wise coefficient for SUM operation 749 | 750 | // Whether to use an asymptotically slower (for >2 inputs) but stabler method 751 | // of computing the gradient for the PROD operation. (No effect for SUM op.) 752 | optional bool stable_prod_grad = 3 [default = true]; 753 | } 754 | 755 | // Message that stores parameters used by ELULayer 756 | message ELUParameter { 757 | // Described in: 758 | // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate 759 | // Deep Network Learning by Exponential Linear Units (ELUs). arXiv 760 | optional float alpha = 1 [default = 1]; 761 | } 762 | 763 | // Message that stores parameters used by EmbedLayer 764 | message EmbedParameter { 765 | optional uint32 num_output = 1; // The number of outputs for the layer 766 | // The input is given as integers to be interpreted as one-hot 767 | // vector indices with dimension num_input. Hence num_input should be 768 | // 1 greater than the maximum possible input value. 769 | optional uint32 input_dim = 2; 770 | 771 | optional bool bias_term = 3 [default = true]; // Whether to use a bias term 772 | optional FillerParameter weight_filler = 4; // The filler for the weight 773 | optional FillerParameter bias_filler = 5; // The filler for the bias 774 | 775 | } 776 | 777 | // Message that stores parameters used by ExpLayer 778 | message ExpParameter { 779 | // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. 780 | // Or if base is set to the default (-1), base is set to e, 781 | // so y = exp(shift + scale * x). 782 | optional float base = 1 [default = -1.0]; 783 | optional float scale = 2 [default = 1.0]; 784 | optional float shift = 3 [default = 0.0]; 785 | } 786 | 787 | /// Message that stores parameters used by FlattenLayer 788 | message FlattenParameter { 789 | // The first axis to flatten: all preceding axes are retained in the output. 790 | // May be negative to index from the end (e.g., -1 for the last axis). 791 | optional int32 axis = 1 [default = 1]; 792 | 793 | // The last axis to flatten: all following axes are retained in the output. 794 | // May be negative to index from the end (e.g., the default -1 for the last 795 | // axis). 796 | optional int32 end_axis = 2 [default = -1]; 797 | } 798 | 799 | // Message that stores parameters used by HDF5DataLayer 800 | message HDF5DataParameter { 801 | // Specify the data source. 802 | optional string source = 1; 803 | // Specify the batch size. 804 | optional uint32 batch_size = 2; 805 | 806 | // Specify whether to shuffle the data. 807 | // If shuffle == true, the ordering of the HDF5 files is shuffled, 808 | // and the ordering of data within any given HDF5 file is shuffled, 809 | // but data between different files are not interleaved; all of a file's 810 | // data are output (in a random order) before moving onto another file. 
811 | optional bool shuffle = 3 [default = false]; 812 | } 813 | 814 | message HDF5OutputParameter { 815 | optional string file_name = 1; 816 | } 817 | 818 | message HingeLossParameter { 819 | enum Norm { 820 | L1 = 1; 821 | L2 = 2; 822 | } 823 | // Specify the norm to use: L1 or L2 824 | optional Norm norm = 1 [default = L1]; 825 | } 826 | 827 | message ImageDataParameter { 828 | // Specify the data source. 829 | optional string source = 1; 830 | // Specify the batch size. 831 | optional uint32 batch_size = 4 [default = 1]; 832 | // The rand_skip variable is for the data layer to skip a few data points 833 | // to prevent all asynchronous sgd clients from starting at the same point. The skip 834 | // point would be set as rand_skip * rand(0,1). Note that rand_skip should not 835 | // be larger than the number of keys in the database. 836 | optional uint32 rand_skip = 7 [default = 0]; 837 | // Whether or not ImageLayer should shuffle the list of files at every epoch. 838 | optional bool shuffle = 8 [default = false]; 839 | // It will also resize images if new_height or new_width are not zero. 840 | optional uint32 new_height = 9 [default = 0]; 841 | optional uint32 new_width = 10 [default = 0]; 842 | // Specify if the images are color or gray 843 | optional bool is_color = 11 [default = true]; 844 | 845 | // This is the value set for pixels or images where we don't know the label 846 | optional int32 ignore_label = 15 [default = 255]; 847 | enum LabelType { 848 | NONE = 0; 849 | IMAGE = 1; 850 | PIXEL = 2; 851 | } 852 | optional LabelType label_type = 16 [default = IMAGE]; 853 | 854 | // DEPRECATED. See TransformationParameter. For data pre-processing, we can do 855 | // simple scaling and subtracting the data mean, if provided. Note that the 856 | // mean subtraction is always carried out before scaling. 857 | optional float scale = 2 [default = 1]; 858 | optional string mean_file = 3; 859 | // DEPRECATED. See TransformationParameter. Specify if we would like to randomly 860 | // crop an image. 861 | optional uint32 crop_size = 5 [default = 0]; 862 | // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror 863 | // data. 864 | optional bool mirror = 6 [default = false]; 865 | optional string root_folder = 12 [default = ""]; 866 | } 867 | 868 | message InfogainLossParameter { 869 | // Specify the infogain matrix source. 870 | optional string source = 1; 871 | } 872 | 873 | message InnerProductParameter { 874 | optional uint32 num_output = 1; // The number of outputs for the layer 875 | optional bool bias_term = 2 [default = true]; // whether to have bias terms 876 | optional FillerParameter weight_filler = 3; // The filler for the weight 877 | optional FillerParameter bias_filler = 4; // The filler for the bias 878 | 879 | // The first axis to be lumped into a single inner product computation; 880 | // all preceding axes are retained in the output. 881 | // May be negative to index from the end (e.g., -1 for the last axis). 882 | optional int32 axis = 5 [default = 1]; 883 | // Specify whether to transpose the weight matrix or not. 884 | // If transpose == true, any operations will be performed on the transpose 885 | // of the weight matrix. The weight matrix itself is not going to be transposed 886 | // but rather the transfer flag of operations will be toggled accordingly. 887 | optional bool transpose = 6 [default = false]; 888 | }
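//
// Example for ImageDataParameter above (illustrative): reading image/label
// pairs for semantic segmentation --
//   image_data_param { source: "train_list.txt" batch_size: 1
//                      label_type: PIXEL ignore_label: 255 }
// where each line of the (hypothetical) list file would name an image and its
// per-pixel label map.
//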
889 | 890 | message InterpParameter { 891 | optional int32 height = 1 [default = 0]; // Height of output 892 | optional int32 width = 2 [default = 0]; // Width of output 893 | optional int32 zoom_factor = 3 [default = 1]; // zoom factor 894 | optional int32 shrink_factor = 4 [default = 1]; // shrink factor 895 | optional int32 pad_beg = 5 [default = 0]; // padding at the beginning of the input 896 | optional int32 pad_end = 6 [default = 0]; // padding at the end of the input 897 | } 898 | 899 | // Message that stores parameters used by LogLayer 900 | message LogParameter { 901 | // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. 902 | // Or if base is set to the default (-1), base is set to e, 903 | // so y = ln(shift + scale * x) = log_e(shift + scale * x) 904 | optional float base = 1 [default = -1.0]; 905 | optional float scale = 2 [default = 1.0]; 906 | optional float shift = 3 [default = 0.0]; 907 | } 908 | 909 | // Message that stores parameters used by LRNLayer 910 | message LRNParameter { 911 | optional uint32 local_size = 1 [default = 5]; 912 | optional float alpha = 2 [default = 1.]; 913 | optional float beta = 3 [default = 0.75]; 914 | enum NormRegion { 915 | ACROSS_CHANNELS = 0; 916 | WITHIN_CHANNEL = 1; 917 | } 918 | optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; 919 | optional float k = 5 [default = 1.]; 920 | enum Engine { 921 | DEFAULT = 0; 922 | CAFFE = 1; 923 | CUDNN = 2; 924 | } 925 | optional Engine engine = 6 [default = DEFAULT]; 926 | } 927 | 928 | message MatReadParameter { 929 | required string prefix = 1; 930 | optional string source = 2 [default = ""]; 931 | optional int32 strip = 3 [default = 0]; 932 | optional int32 batch_size = 4 [default = 1]; 933 | } 934 | 935 | message MatWriteParameter { 936 | required string prefix = 1; 937 | optional string source = 2 [default = ""]; 938 | optional int32 strip = 3 [default = 0]; 939 | optional int32 period = 4 [default = 1]; 940 | } 941 | 942 | message MemoryDataParameter { 943 | optional uint32 batch_size = 1; 944 | optional uint32 channels = 2; 945 | optional uint32 height = 3; 946 | optional uint32 width = 4; 947 | } 948 | 949 | message MVNParameter { 950 | // This parameter can be set to false to normalize mean only 951 | optional bool normalize_variance = 1 [default = true]; 952 | 953 | // This parameter can be set to true to perform DNN-like MVN 954 | optional bool across_channels = 2 [default = false]; 955 | 956 | // Epsilon for not dividing by zero while normalizing variance 957 | optional float eps = 3 [default = 1e-9]; 958 | } 959 | 960 | message PoolingParameter { 961 | enum PoolMethod { 962 | MAX = 0; 963 | AVE = 1; 964 | STOCHASTIC = 2; 965 | } 966 | optional PoolMethod pool = 1 [default = MAX]; // The pooling method 967 | // Pad, kernel size, and stride are all given as a single value for equal 968 | // dimensions in height and width or as Y, X pairs.
969 | optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) 970 | optional uint32 pad_h = 9 [default = 0]; // The padding height 971 | optional uint32 pad_w = 10 [default = 0]; // The padding width 972 | optional uint32 kernel_size = 2; // The kernel size (square) 973 | optional uint32 kernel_h = 5; // The kernel height 974 | optional uint32 kernel_w = 6; // The kernel width 975 | optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) 976 | optional uint32 stride_h = 7; // The stride height 977 | optional uint32 stride_w = 8; // The stride width 978 | enum Engine { 979 | DEFAULT = 0; 980 | CAFFE = 1; 981 | CUDNN = 2; 982 | } 983 | optional Engine engine = 11 [default = DEFAULT]; 984 | // If global_pooling then it will pool over the size of the bottom by doing 985 | // kernel_h = bottom->height and kernel_w = bottom->width 986 | optional bool global_pooling = 12 [default = false]; 987 | } 988 | 989 | message PowerParameter { 990 | // PowerLayer computes outputs y = (shift + scale * x) ^ power. 991 | optional float power = 1 [default = 1.0]; 992 | optional float scale = 2 [default = 1.0]; 993 | optional float shift = 3 [default = 0.0]; 994 | } 995 | 996 | message PythonParameter { 997 | optional string module = 1; 998 | optional string layer = 2; 999 | // This value is set to the attribute `param_str` of the `PythonLayer` object 1000 | // in Python before calling the `setup()` method. This could be a number, 1001 | // string, dictionary in Python dict format, JSON, etc. You may parse this 1002 | // string in the `setup` method and use it in `forward` and `backward`. 1003 | optional string param_str = 3 [default = '']; 1004 | // Whether this PythonLayer is shared among worker solvers during data parallelism. 1005 | // If true, each worker solver sequentially runs forward from this layer. 1006 | // This value should be set true if you are using it as a data layer. 1007 | optional bool share_in_parallel = 4 [default = false]; 1008 | } 1009 | 1010 | // Message that stores parameters used by ReductionLayer 1011 | message ReductionParameter { 1012 | enum ReductionOp { 1013 | SUM = 1; 1014 | ASUM = 2; 1015 | SUMSQ = 3; 1016 | MEAN = 4; 1017 | } 1018 | 1019 | optional ReductionOp operation = 1 [default = SUM]; // reduction operation 1020 | 1021 | // The first axis to reduce to a scalar -- may be negative to index from the 1022 | // end (e.g., -1 for the last axis). 1023 | // (Currently, only reduction along ALL "tail" axes is supported; reduction 1024 | // of axis M through N, where N < num_axes - 1, is unsupported.) 1025 | // Suppose we have an n-axis bottom Blob with shape: 1026 | // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)). 1027 | // If axis == m, the output Blob will have shape 1028 | // (d0, d1, d2, ..., d(m-1)), 1029 | // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1)) 1030 | // times, each including (dm * d(m+1) * ... * d(n-1)) individual data. 1031 | // If axis == 0 (the default), the output Blob always has the empty shape 1032 | // (count 1), performing reduction across the entire input -- 1033 | // often useful for creating new loss functions. 1034 | optional int32 axis = 2 [default = 0]; 1035 | 1036 | optional float coeff = 3 [default = 1.0]; // coefficient for output 1037 | } 1038 |
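//
// Example for ReLUParameter below (illustrative): a leaky ReLU is expressed as
//   relu_param { negative_slope: 0.01 }
// where 0.01 is an arbitrary slope, not a default.
//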
1039 | // Message that stores parameters used by ReLULayer 1040 | message ReLUParameter { 1041 | // Allow non-zero slope for negative inputs to speed up optimization 1042 | // Described in: 1043 | // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities 1044 | // improve neural network acoustic models. In ICML Workshop on Deep Learning 1045 | // for Audio, Speech, and Language Processing. 1046 | optional float negative_slope = 1 [default = 0]; 1047 | enum Engine { 1048 | DEFAULT = 0; 1049 | CAFFE = 1; 1050 | CUDNN = 2; 1051 | } 1052 | optional Engine engine = 2 [default = DEFAULT]; 1053 | } 1054 | 1055 | message ReshapeParameter { 1056 | // Specify the output dimensions. If some of the dimensions are set to 0, 1057 | // the corresponding dimension from the bottom layer is used (unchanged). 1058 | // Exactly one dimension may be set to -1, in which case its value is 1059 | // inferred from the count of the bottom blob and the remaining dimensions. 1060 | // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8: 1061 | // 1062 | // layer { 1063 | // type: "Reshape" bottom: "input" top: "output" 1064 | // reshape_param { ... } 1065 | // } 1066 | // 1067 | // If "input" is 2D with shape 2 x 8, then the following reshape_param 1068 | // specifications are all equivalent, producing a 3D blob "output" with shape 1069 | // 2 x 2 x 4: 1070 | // 1071 | // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } 1072 | // reshape_param { shape { dim: 0 dim: 2 dim: 4 } } 1073 | // reshape_param { shape { dim: 0 dim: 2 dim: -1 } } 1074 | // reshape_param { shape { dim: -1 dim: 0 dim: 2 } } 1075 | // 1076 | optional BlobShape shape = 1; 1077 | 1078 | // axis and num_axes control the portion of the bottom blob's shape that is 1079 | // replaced by (included in) the reshape. By default (axis == 0 and 1080 | // num_axes == -1), the entire bottom blob shape is included in the reshape, 1081 | // and hence the shape field must specify the entire output shape. 1082 | // 1083 | // axis may be non-zero to retain some portion of the beginning of the input 1084 | // shape (and may be negative to index from the end; e.g., -1 to begin the 1085 | // reshape after the last axis, including nothing in the reshape, 1086 | // -2 to include only the last axis, etc.). 1087 | // 1088 | // For example, suppose "input" is a 2D blob with shape 2 x 8. 1089 | // Then the following ReshapeLayer specifications are all equivalent, 1090 | // producing a blob "output" with shape 2 x 2 x 4: 1091 | // 1092 | // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } 1093 | // reshape_param { shape { dim: 2 dim: 4 } axis: 1 } 1094 | // reshape_param { shape { dim: 2 dim: 4 } axis: -3 } 1095 | // 1096 | // num_axes specifies the extent of the reshape. 1097 | // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on 1098 | // input axes in the range [axis, axis+num_axes]. 1099 | // num_axes may also be -1, the default, to include all remaining axes 1100 | // (starting from axis). 1101 | // 1102 | // For example, suppose "input" is a 2D blob with shape 2 x 8. 1103 | // Then the following ReshapeLayer specifications are equivalent, 1104 | // producing a blob "output" with shape 1 x 2 x 8. 1105 | // 1106 | // reshape_param { shape { dim: 1 dim: 2 dim: 8 } } 1107 | // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 } 1108 | // reshape_param { shape { dim: 1 } num_axes: 0 } 1109 | // 1110 | // On the other hand, these would produce output blob shape 2 x 1 x 8: 1111 | // 1112 | // reshape_param { shape { dim: 2 dim: 1 dim: 8 } } 1113 | // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 } 1114 | // 1115 | optional int32 axis = 2 [default = 0]; 1116 | optional int32 num_axes = 3 [default = -1]; 1117 | }
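//
// Note on usage (illustrative): Caffe's BatchNormLayer only normalizes, so nets
// commonly follow it with a Scale layer to restore a learned per-channel
// affine transform, e.g.
//   layer { type: "BatchNorm" ... }
//   layer { type: "Scale" scale_param { bias_term: true } ... }
//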
1118 | 1119 | message ScaleParameter { 1120 | // The first axis of bottom[0] (the first input Blob) along which to apply 1121 | // bottom[1] (the second input Blob). May be negative to index from the end 1122 | // (e.g., -1 for the last axis). 1123 | // 1124 | // For example, if bottom[0] is 4D with shape 100x3x40x60, the output 1125 | // top[0] will have the same shape, and bottom[1] may have any of the 1126 | // following shapes (for the given value of axis): 1127 | // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 1128 | // (axis == 1 == -3) 3; 3x40; 3x40x60 1129 | // (axis == 2 == -2) 40; 40x60 1130 | // (axis == 3 == -1) 60 1131 | // Furthermore, bottom[1] may have the empty shape (regardless of the value of 1132 | // "axis") -- a scalar multiplier. 1133 | optional int32 axis = 1 [default = 1]; 1134 | 1135 | // (num_axes is ignored unless just one bottom is given and the scale is 1136 | // a learned parameter of the layer. Otherwise, num_axes is determined by the 1137 | // number of axes of the second bottom.) 1138 | // The number of axes of the input (bottom[0]) covered by the scale 1139 | // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. 1140 | // Set num_axes := 0 to multiply with a zero-axis Blob: a scalar. 1141 | optional int32 num_axes = 2 [default = 1]; 1142 | 1143 | // (filler is ignored unless just one bottom is given and the scale is 1144 | // a learned parameter of the layer.) 1145 | // The initialization for the learned scale parameter. 1146 | // Default is the unit (1) initialization, resulting in the ScaleLayer 1147 | // initially performing the identity operation. 1148 | optional FillerParameter filler = 3; 1149 | 1150 | // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but 1151 | // may be more efficient). Initialized with bias_filler (defaults to 0). 1152 | optional bool bias_term = 4 [default = false]; 1153 | optional FillerParameter bias_filler = 5; 1154 | } 1155 | 1156 | message SegAccuracyParameter { 1157 | enum AccuracyMetric { 1158 | PixelAccuracy = 0; 1159 | ClassAccuracy = 1; 1160 | PixelIOU = 2; 1161 | } 1162 | optional AccuracyMetric metric = 1 [default = PixelAccuracy]; 1163 | // will ignore pixels with this value when computing accuracy 1164 | repeated int32 ignore_label = 2; 1165 | optional bool reset = 3 [default = true]; 1166 | } 1167 | 1168 | message SigmoidParameter { 1169 | enum Engine { 1170 | DEFAULT = 0; 1171 | CAFFE = 1; 1172 | CUDNN = 2; 1173 | } 1174 | optional Engine engine = 1 [default = DEFAULT]; 1175 | } 1176 | 1177 | message SliceParameter { 1178 | // The axis along which to slice -- may be negative to index from the end 1179 | // (e.g., -1 for the last axis). 1180 | // By default, SliceLayer slices blobs along the "channels" axis (1). 1181 | optional int32 axis = 3 [default = 1]; 1182 | repeated uint32 slice_point = 2; 1183 | 1184 | // DEPRECATED: alias for "axis" -- does not support negative indexing.
1185 | optional uint32 slice_dim = 1 [default = 1]; 1186 | } 1187 | 1188 | // Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer 1189 | message SoftmaxParameter { 1190 | enum Engine { 1191 | DEFAULT = 0; 1192 | CAFFE = 1; 1193 | CUDNN = 2; 1194 | } 1195 | optional Engine engine = 1 [default = DEFAULT]; 1196 | 1197 | // The axis along which to perform the softmax -- may be negative to index 1198 | // from the end (e.g., -1 for the last axis). 1199 | // Any other axes will be evaluated as independent softmaxes. 1200 | optional int32 axis = 2 [default = 1]; 1201 | } 1202 | 1203 | message TanHParameter { 1204 | enum Engine { 1205 | DEFAULT = 0; 1206 | CAFFE = 1; 1207 | CUDNN = 2; 1208 | } 1209 | optional Engine engine = 1 [default = DEFAULT]; 1210 | } 1211 | 1212 | // Message that stores parameters used by TileLayer 1213 | message TileParameter { 1214 | // The index of the axis to tile. 1215 | optional int32 axis = 1 [default = 1]; 1216 | 1217 | // The number of copies (tiles) of the blob to output. 1218 | optional int32 tiles = 2; 1219 | } 1220 | 1221 | // Message that stores parameters used by ThresholdLayer 1222 | message ThresholdParameter { 1223 | optional float threshold = 1 [default = 0]; // Strictly positive values 1224 | } 1225 | 1226 | message UniqueLabelParameter { 1227 | required int32 max_labels = 1; 1228 | repeated int32 ignore_label = 2; 1229 | repeated float force_label = 3; 1230 | } 1231 | 1232 | message WindowDataParameter { 1233 | // Specify the data source. 1234 | optional string source = 1; 1235 | // For data pre-processing, we can do simple scaling and subtracting the 1236 | // data mean, if provided. Note that the mean subtraction is always carried 1237 | // out before scaling. 1238 | optional float scale = 2 [default = 1]; 1239 | optional string mean_file = 3; 1240 | // Specify the batch size. 1241 | optional uint32 batch_size = 4; 1242 | // Specify if we would like to randomly crop an image. 1243 | optional uint32 crop_size = 5 [default = 0]; 1244 | // Specify if we want to randomly mirror data. 1245 | optional bool mirror = 6 [default = false]; 1246 | // Foreground (object) overlap threshold 1247 | optional float fg_threshold = 7 [default = 0.5]; 1248 | // Background (non-object) overlap threshold 1249 | optional float bg_threshold = 8 [default = 0.5]; 1250 | // Fraction of batch that should be foreground objects 1251 | optional float fg_fraction = 9 [default = 0.25]; 1252 | // Amount of contextual padding to add around a window 1253 | // (used only by the window_data_layer) 1254 | optional uint32 context_pad = 10 [default = 0]; 1255 | // Mode for cropping out a detection window 1256 | // warp: cropped window is warped to a fixed size and aspect ratio 1257 | // square: the tightest square around the window is cropped 1258 | optional string crop_mode = 11 [default = "warp"]; 1259 | // cache_images: will load all images in memory for faster access 1260 | optional bool cache_images = 12 [default = false]; 1261 | // append root_folder to locate images 1262 | optional string root_folder = 13 [default = ""]; 1263 | } 1264 | 1265 | message SPPParameter { 1266 | enum PoolMethod { 1267 | MAX = 0; 1268 | AVE = 1; 1269 | STOCHASTIC = 2; 1270 | } 1271 | optional uint32 pyramid_height = 1; 1272 | optional PoolMethod pool = 2 [default = MAX]; // The pooling method 1273 | enum Engine { 1274 | DEFAULT = 0; 1275 | CAFFE = 1; 1276 | CUDNN = 2; 1277 | } 1278 | optional Engine engine = 6 [default = DEFAULT]; 1279 | } 1280 | 1281 | // DEPRECATED: use LayerParameter. 
1282 | message V1LayerParameter { 1283 | repeated string bottom = 2; 1284 | repeated string top = 3; 1285 | optional string name = 4; 1286 | repeated NetStateRule include = 32; 1287 | repeated NetStateRule exclude = 33; 1288 | enum LayerType { 1289 | NONE = 0; 1290 | ABSVAL = 35; 1291 | ACCURACY = 1; 1292 | ARGMAX = 30; 1293 | BNLL = 2; 1294 | CONCAT = 3; 1295 | CONTRASTIVE_LOSS = 37; 1296 | CONVOLUTION = 4; 1297 | DATA = 5; 1298 | DECONVOLUTION = 39; 1299 | DROPOUT = 6; 1300 | DUMMY_DATA = 32; 1301 | EUCLIDEAN_LOSS = 7; 1302 | ELTWISE = 25; 1303 | EXP = 38; 1304 | FLATTEN = 8; 1305 | HDF5_DATA = 9; 1306 | HDF5_OUTPUT = 10; 1307 | HINGE_LOSS = 28; 1308 | IM2COL = 11; 1309 | IMAGE_DATA = 12; 1310 | INFOGAIN_LOSS = 13; 1311 | INNER_PRODUCT = 14; 1312 | LRN = 15; 1313 | MEMORY_DATA = 29; 1314 | MULTINOMIAL_LOGISTIC_LOSS = 16; 1315 | MVN = 34; 1316 | POOLING = 17; 1317 | POWER = 26; 1318 | RELU = 18; 1319 | SIGMOID = 19; 1320 | SIGMOID_CROSS_ENTROPY_LOSS = 27; 1321 | SILENCE = 36; 1322 | SOFTMAX = 20; 1323 | SOFTMAX_LOSS = 21; 1324 | SPLIT = 22; 1325 | SLICE = 33; 1326 | TANH = 23; 1327 | WINDOW_DATA = 24; 1328 | THRESHOLD = 31; 1329 | } 1330 | optional LayerType type = 5; 1331 | repeated BlobProto blobs = 6; 1332 | repeated string param = 1001; 1333 | repeated DimCheckMode blob_share_mode = 1002; 1334 | enum DimCheckMode { 1335 | STRICT = 0; 1336 | PERMISSIVE = 1; 1337 | } 1338 | repeated float blobs_lr = 7; 1339 | repeated float weight_decay = 8; 1340 | repeated float loss_weight = 35; 1341 | optional AccuracyParameter accuracy_param = 27; 1342 | optional ArgMaxParameter argmax_param = 23; 1343 | optional ConcatParameter concat_param = 9; 1344 | optional ContrastiveLossParameter contrastive_loss_param = 40; 1345 | optional ConvolutionParameter convolution_param = 10; 1346 | optional DataParameter data_param = 11; 1347 | optional DropoutParameter dropout_param = 12; 1348 | optional DummyDataParameter dummy_data_param = 26; 1349 | optional EltwiseParameter eltwise_param = 24; 1350 | optional ExpParameter exp_param = 41; 1351 | optional HDF5DataParameter hdf5_data_param = 13; 1352 | optional HDF5OutputParameter hdf5_output_param = 14; 1353 | optional HingeLossParameter hinge_loss_param = 29; 1354 | optional ImageDataParameter image_data_param = 15; 1355 | optional InfogainLossParameter infogain_loss_param = 16; 1356 | optional InnerProductParameter inner_product_param = 17; 1357 | optional LRNParameter lrn_param = 18; 1358 | optional MemoryDataParameter memory_data_param = 22; 1359 | optional MVNParameter mvn_param = 34; 1360 | optional PoolingParameter pooling_param = 19; 1361 | optional PowerParameter power_param = 21; 1362 | optional ReLUParameter relu_param = 30; 1363 | optional SigmoidParameter sigmoid_param = 38; 1364 | optional SoftmaxParameter softmax_param = 39; 1365 | optional SliceParameter slice_param = 31; 1366 | optional TanHParameter tanh_param = 37; 1367 | optional ThresholdParameter threshold_param = 25; 1368 | optional WindowDataParameter window_data_param = 20; 1369 | optional TransformationParameter transform_param = 36; 1370 | optional LossParameter loss_param = 42; 1371 | optional V0LayerParameter layer = 1; 1372 | } 1373 | 1374 | // DEPRECATED: V0LayerParameter is the old way of specifying layer parameters 1375 | // in Caffe. We keep this message type around for legacy support. 
1376 | message V0LayerParameter { 1377 | optional string name = 1; // the layer name 1378 | optional string type = 2; // the string to specify the layer type 1379 | 1380 | // Parameters to specify layers with inner products. 1381 | optional uint32 num_output = 3; // The number of outputs for the layer 1382 | optional bool biasterm = 4 [default = true]; // whether to have bias terms 1383 | optional FillerParameter weight_filler = 5; // The filler for the weight 1384 | optional FillerParameter bias_filler = 6; // The filler for the bias 1385 | 1386 | optional uint32 pad = 7 [default = 0]; // The padding size 1387 | optional uint32 kernelsize = 8; // The kernel size 1388 | optional uint32 group = 9 [default = 1]; // The group size for group conv 1389 | optional uint32 stride = 10 [default = 1]; // The stride 1390 | enum PoolMethod { 1391 | MAX = 0; 1392 | AVE = 1; 1393 | STOCHASTIC = 2; 1394 | } 1395 | optional PoolMethod pool = 11 [default = MAX]; // The pooling method 1396 | optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio 1397 | 1398 | optional uint32 local_size = 13 [default = 5]; // for local response norm 1399 | optional float alpha = 14 [default = 1.]; // for local response norm 1400 | optional float beta = 15 [default = 0.75]; // for local response norm 1401 | optional float k = 22 [default = 1.]; 1402 | 1403 | // For data layers, specify the data source 1404 | optional string source = 16; 1405 | // For data pre-processing, we can do simple scaling and subtracting the 1406 | // data mean, if provided. Note that the mean subtraction is always carried 1407 | // out before scaling. 1408 | optional float scale = 17 [default = 1]; 1409 | optional string meanfile = 18; 1410 | // For data layers, specify the batch size. 1411 | optional uint32 batchsize = 19; 1412 | // For data layers, specify if we would like to randomly crop an image. 1413 | optional uint32 cropsize = 20 [default = 0]; 1414 | // For data layers, specify if we want to randomly mirror data. 1415 | optional bool mirror = 21 [default = false]; 1416 | 1417 | // The blobs containing the numeric parameters of the layer 1418 | repeated BlobProto blobs = 50; 1419 | // The ratio that is multiplied on the global learning rate. If you want to 1420 | // set the learning ratio for one blob, you need to set it for all blobs. 1421 | repeated float blobs_lr = 51; 1422 | // The weight decay that is multiplied on the global weight decay. 1423 | repeated float weight_decay = 52; 1424 | 1425 | // The rand_skip variable is for the data layer to skip a few data points 1426 | // to prevent all asynchronous sgd clients from starting at the same point. The skip 1427 | // point would be set as rand_skip * rand(0,1). Note that rand_skip should not 1428 | // be larger than the number of keys in the database.
1429 | optional uint32 rand_skip = 53 [default = 0]; 1430 | 1431 | // Fields related to detection (det_*) 1432 | // foreground (object) overlap threshold 1433 | optional float det_fg_threshold = 54 [default = 0.5]; 1434 | // background (non-object) overlap threshold 1435 | optional float det_bg_threshold = 55 [default = 0.5]; 1436 | // Fraction of batch that should be foreground objects 1437 | optional float det_fg_fraction = 56 [default = 0.25]; 1438 | 1439 | // optional bool OBSOLETE_can_clobber = 57 [default = true]; 1440 | 1441 | // Amount of contextual padding to add around a window 1442 | // (used only by the window_data_layer) 1443 | optional uint32 det_context_pad = 58 [default = 0]; 1444 | 1445 | // Mode for cropping out a detection window 1446 | // warp: cropped window is warped to a fixed size and aspect ratio 1447 | // square: the tightest square around the window is cropped 1448 | optional string det_crop_mode = 59 [default = "warp"]; 1449 | 1450 | // For ReshapeLayer, one needs to specify the new dimensions. 1451 | optional int32 new_num = 60 [default = 0]; 1452 | optional int32 new_channels = 61 [default = 0]; 1453 | optional int32 new_height = 62 [default = 0]; 1454 | optional int32 new_width = 63 [default = 0]; 1455 | 1456 | // Whether or not ImageLayer should shuffle the list of files at every epoch. 1457 | // It will also resize images if new_height or new_width are not zero. 1458 | optional bool shuffle_images = 64 [default = false]; 1459 | 1460 | // For ConcatLayer, one needs to specify the dimension for concatenation, and 1461 | // the other dimensions must be the same for all the bottom blobs. 1462 | // By default it will concatenate blobs along the channels dimension. 1463 | optional uint32 concat_dim = 65 [default = 1]; 1464 | 1465 | optional HDF5OutputParameter hdf5_output_param = 1001; 1466 | } 1467 | 1468 | message PReLUParameter { 1469 | // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers: 1470 | // Surpassing Human-Level Performance on ImageNet Classification, 2015. 1471 | 1472 | // Initial value of a_i. Default is a_i=0.25 for all i. 1473 | optional FillerParameter filler = 1; 1474 | // Whether or not slope parameters are shared across channels. 1475 | optional bool channel_shared = 2 [default = false]; 1476 | } 1477 | --------------------------------------------------------------------------------