├── .gitignore ├── LICENSE ├── README.md ├── examples ├── curve_regression │ ├── .gitignore │ ├── train_and_test_epistemic.py │ └── train_and_test_epistemic_aleatoric.py ├── map_synthesis │ ├── .gitignore │ ├── preprocess.py │ └── train_and_test_pix2pix.py ├── miccai_endovis_segmentation │ ├── .gitignore │ ├── preprocess.py │ └── train_and_test_epistemic.py └── mnist_classification │ ├── .gitignore │ └── train_and_test_epistemic.py ├── figs ├── bayesian_unet.gif ├── curve_regression_1.png ├── curve_regression_2.png ├── map_synthesis_samples.png ├── miccai_endovis_segmentation.png ├── miccai_endovis_segmentation_samples.png ├── miccai_heatmap_regression.png ├── mnist_classification.png └── unet.png ├── pytorch_bcnn ├── __init__.py ├── data │ ├── __init__.py │ ├── augmentor │ │ ├── __init__.py │ │ ├── image.py │ │ └── volume.py │ ├── io │ │ ├── __init__.py │ │ └── mhd.py │ └── normalizer │ │ ├── __init__.py │ │ ├── image.py │ │ └── volume.py ├── datasets │ ├── __init__.py │ ├── image.py │ └── volume.py ├── extensions │ ├── __init__.py │ ├── log_report.py │ ├── print_report.py │ └── validator.py ├── functions │ ├── __init__.py │ ├── accuracy │ │ ├── __init__.py │ │ ├── discrete_dice.py │ │ └── discrete_jaccard.py │ ├── crop.py │ ├── loss │ │ ├── __init__.py │ │ ├── _helper.py │ │ ├── dice.py │ │ ├── jaccard.py │ │ ├── noised_cross_entropy.py │ │ ├── noised_mean_squared_error.py │ │ ├── sigmoid_cross_entropy.py │ │ ├── sigmoid_soft_cross_entropy.py │ │ └── softmax_cross_entropy.py │ └── stride_pooling.py ├── inference │ ├── __init__.py │ └── inferencer.py ├── initializers │ ├── __init__.py │ └── bilinear_upsample.py ├── links │ ├── __init__.py │ ├── classifier.py │ ├── connection │ │ ├── __init__.py │ │ └── pixel_shuffle_upsampler.py │ ├── mc_sampler.py │ ├── noise │ │ ├── __init__.py │ │ └── mc_dropout.py │ └── regressor.py ├── models │ ├── __init__.py │ ├── discriminators │ │ ├── __init__.py │ │ ├── discriminator_base.py │ │ └── patch_discriminator.py │ └── unet │ │ ├── __init__.py │ │ ├── _helper.py │ │ ├── bayesian_unet.py │ │ ├── unet.py │ │ └── unet_base.py ├── updaters │ ├── __init__.py │ └── cgan │ │ ├── __init__.py │ │ ├── _replay_buffer.py │ │ ├── dcgan.py │ │ └── lsgan.py ├── utils │ └── __init__.py └── visualizer │ ├── __init__.py │ └── image.py ├── recipe ├── AIO.def ├── AIO.def.in ├── Makefile └── python_packages.txt ├── requirements.txt ├── setup.py └── tests ├── lenna.png ├── test_augmentator_2d.py ├── test_augmentator_3d.py ├── test_cross_entropy.py ├── test_dataset.py ├── test_dice_loss.py ├── test_discriminator.py ├── test_inferencer.py ├── test_initializer.py ├── test_mc_sampler.py ├── test_model.py ├── test_normalizer.py ├── test_seed.py └── test_visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # mydata 107 | *.mhd 108 | *.mha 109 | *.raw -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /examples/curve_regression/.gitignore: -------------------------------------------------------------------------------- 1 | /logs 2 | -------------------------------------------------------------------------------- /examples/curve_regression/train_and_test_epistemic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from pytorch_bcnn.links.noise import MCDropout 12 | from pytorch_bcnn.links import Regressor 13 | from pytorch_bcnn.links import MCSampler 14 | from pytorch_bcnn.inference import Inferencer 15 | from pytorch_bcnn.utils import fixed_seed 16 | from pytorch_trainer import iterators 17 | from pytorch_trainer import dataset 18 | from pytorch_trainer import training 19 | from pytorch_trainer.training import extensions 20 | 21 | 22 | class Dataset(dataset.DatasetMixin): 23 | 24 | def __init__(self, 25 | func=lambda x: x*np.sin(x), 26 | n_samples=100, 27 | x_lim=(-5, 5), 28 | dtype=np.float32): 29 | 30 | assert callable(func) 31 | assert isinstance(x_lim, (list, tuple)) 32 | 33 | x = np.random.rand(n_samples, 1)*(x_lim[1] - x_lim[0]) + x_lim[0] 34 | x = np.sort(x, axis=0) 35 | t = func(x) 36 | 37 | self._func = func 38 | self._x_lim = x_lim 39 | self._x = x.astype(dtype) 40 | self._t = t.astype(dtype) 41 | 42 | @property 43 | def x(self): # NOTE: input 44 | return self._x 45 | 46 | @property 47 | def y(self): # NOTE: observation 48 | return self._t 49 | 50 | @property 51 | def t(self): # NOTE: ground-truth 52 | return self._t 53 | 54 | def __len__(self): 55 | return len(self._x) 56 | 57 | @dataset.convert_to_tensor 58 | def get_example(self, i): 59 | return self.x[i], self.y[i] 60 | 61 | 62 | class BayesianMLP(nn.Module): 63 | 64 | def __init__(self, n_in, n_units, n_out, drop_ratio): 65 | super(BayesianMLP, self).__init__() 66 | 67 | self.n_in = n_in 68 | self.n_units = n_units 69 | self.n_out = n_out 70 | self.drop_ratio = drop_ratio 71 | 72 | self.l1 = nn.Linear(n_in, n_units) 73 | self.l2 = nn.Linear(n_units, n_units) 74 | self.l3 = nn.Linear(n_units, n_out) 75 | 76 | self.dropout = MCDropout(drop_ratio) 77 | 78 | def forward(self, x): 79 | h = F.relu(self.l1(x)) 80 | h = self.dropout(h) 81 | h = F.relu(self.l2(h)) 82 | h = self.dropout(h) 83 | return self.l3(h) 84 | 85 | 86 | def train_phase(predictor, train, valid, args): 87 | 88 | # visualize 89 | plt.rcParams['font.size'] = 18 90 | plt.figure(figsize=(13,5)) 91 | ax = sns.scatterplot(x=train.x.ravel(), y=train.y.ravel(), color='blue', s=55, alpha=0.3) 92 | ax.plot(train.x.ravel(), train.t.ravel(), color='red', linewidth=2) 93 | ax.set_xlabel('x') 94 | ax.set_ylabel('y') 95 | ax.set_xlim(-10, 10) 96 | ax.set_ylim(-15, 15) 97 | plt.legend(['Observation', 'Ground-truth']) 98 | plt.title('Training data set') 99 | plt.tight_layout() 100 | plt.savefig(os.path.join(args.out, 'train_dataset.png')) 101 | plt.close() 102 | 103 | # setup iterators 104 | train_iter = iterators.SerialIterator(train, args.batchsize, shuffle=True) 105 | valid_iter = iterators.SerialIterator(valid, args.batchsize, repeat=False, shuffle=False) 106 | 107 | # setup a model 108 | device = torch.device(args.gpu) 109 | 110 | model = Regressor(predictor) 111 | model.to(device) 112 | 113 | # setup an
optimizer 114 | optimizer = torch.optim.Adam(model.parameters(), 115 | weight_decay=max(args.decay, 0)) 116 | 117 | # setup a trainer 118 | updater = training.updaters.StandardUpdater( 119 | train_iter, optimizer, model, device=device) 120 | trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) 121 | 122 | trainer.extend(extensions.Evaluator(valid_iter, model, device=args.gpu)) 123 | 124 | # trainer.extend(DumpGraph(model, 'main/loss')) 125 | 126 | frequency = args.epoch if args.frequency == -1 else max(1, args.frequency) 127 | trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch')) 128 | 129 | trainer.extend(extensions.LogReport()) 130 | 131 | if args.plot and extensions.PlotReport.available(): 132 | trainer.extend( 133 | extensions.PlotReport(['main/loss', 'validation/main/loss'], 134 | 'epoch', file_name='loss.png')) 135 | trainer.extend( 136 | extensions.PlotReport( 137 | ['main/accuracy', 'validation/main/accuracy'], 138 | 'epoch', file_name='accuracy.png')) 139 | 140 | trainer.extend(extensions.PrintReport( 141 | ['epoch', 'iteration', 'main/loss', 'validation/main/loss', 142 | 'main/accuracy', 'validation/main/accuracy', 'elapsed_time'])) 143 | 144 | trainer.extend(extensions.ProgressBar()) 145 | 146 | if args.resume: 147 | trainer.load_state_dict(torch.load(args.resume)) 148 | 149 | trainer.run() 150 | 151 | torch.save(predictor.state_dict(), os.path.join(args.out, 'predictor.pth')) 152 | 153 | 154 | def test_phase(predictor, test, args): 155 | 156 | # setup an iterator 157 | test_iter = iterators.SerialIterator(test, args.batchsize, repeat=False, shuffle=False) 158 | 159 | # setup an inferencer 160 | predictor.load_state_dict(torch.load(os.path.join(args.out, 'predictor.pth'))) 161 | 162 | model = MCSampler(predictor, 163 | mc_iteration=args.mc_iteration, 164 | activation=None, 165 | reduce_mean=None, 166 | reduce_var=None) 167 | 168 | device = torch.device(args.gpu) 169 | model.to(device) 170 | 171 | infer = Inferencer(test_iter, model, device=args.gpu) 172 | 173 | pred, uncert = infer.run() 174 | 175 | 176 | # visualize 177 | x = test.x.ravel() 178 | t = test.t.ravel() 179 | pred = pred.ravel() 180 | uncert = uncert.ravel() 181 | 182 | plt.rcParams['font.size'] = 18 183 | plt.figure(figsize=(13,5)) 184 | ax = sns.scatterplot(x=x, y=pred, color='blue', s=75) 185 | ax.errorbar(x, pred, yerr=uncert, fmt='none', capsize=10, ecolor='gray', linewidth=1.5) 186 | ax.plot(x, t, color='red', linewidth=1.5) 187 | ax.set_xlabel('x') 188 | ax.set_ylabel('y') 189 | ax.set_xlim(-10, 10) 190 | ax.set_ylim(-15, 15) 191 | plt.legend(['Prediction', 'Predicted variance', 'Ground-truth']) 192 | plt.title('Result on testing data set') 193 | plt.tight_layout() 194 | plt.savefig(os.path.join(args.out, 'eval.png')) 195 | plt.close() 196 | 197 | 198 | def main(): 199 | parser = argparse.ArgumentParser(description='Example: Uncertainty estimates in regression') 200 | parser.add_argument('--batchsize', '-b', type=int, default=50, 201 | help='Number of samples in each mini-batch') 202 | parser.add_argument('--epoch', '-e', type=int, default=300, 203 | help='Number of sweeps over the dataset to train') 204 | parser.add_argument('--frequency', '-f', type=int, default=-1, 205 | help='Frequency of taking a snapshot') 206 | parser.add_argument('--gpu', '-g', type=str, default='cuda:0', 207 | help='GPU Device') 208 | parser.add_argument('--out', '-o', default='logs', 209 | help='Directory to output the log files') 210 | parser.add_argument('--resume', '-r', default='', 211 |
help='Resume the training from snapshot') 212 | parser.add_argument('--unit', '-u', type=int, default=20, 213 | help='Number of units') 214 | parser.add_argument('--noplot', dest='plot', action='store_false', 215 | help='Disable PlotReport extension') 216 | parser.add_argument('--test_on_test', action='store_true', 217 | help='Switch to the testing phase on test dataset') 218 | parser.add_argument('--test_on_valid', action='store_true', 219 | help='Switch to the testing phase on valid dataset') 220 | parser.add_argument('--mc_iteration', type=int, default=50, 221 | help='Number of Monte Carlo dropout iterations') 222 | parser.add_argument('--decay', type=float, default=-1, 223 | help='Weight of L2 regularization') 224 | parser.add_argument('--seed', type=int, default=0, 225 | help='Fix the random seed') 226 | args = parser.parse_args() 227 | 228 | 229 | os.makedirs(args.out, exist_ok=True) 230 | 231 | with fixed_seed(args.seed, strict=False): 232 | 233 | # setup a predictor 234 | predictor = BayesianMLP(n_in=1, n_units=args.unit, n_out=1, drop_ratio=0.1) 235 | 236 | # setup dataset 237 | train = Dataset(x_lim=(-5, 5), n_samples=1000) 238 | valid = Dataset(x_lim=(-5, 5), n_samples=1000) 239 | test = Dataset(x_lim=(-10, 10), n_samples=500) 240 | 241 | # run 242 | if args.test_on_test: 243 | test_phase(predictor, test, args) 244 | elif args.test_on_valid: 245 | test_phase(predictor, valid, args) 246 | else: 247 | train_phase(predictor, train, valid, args) 248 | 249 | if __name__ == '__main__': 250 | main() 251 | -------------------------------------------------------------------------------- /examples/map_synthesis/.gitignore: -------------------------------------------------------------------------------- 1 | /logs 2 | /preprocessed 3 | /temp 4 | -------------------------------------------------------------------------------- /examples/map_synthesis/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import cv2 4 | import numpy as np 5 | import tqdm 6 | import urllib.request 7 | from shutil import copyfile 8 | import tarfile 9 | 10 | from pytorch_bcnn.data import load_image, save_image 11 | 12 | 13 | def my_hook(t): # https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py 14 | """Wraps tqdm instance. 15 | Don't forget to close() or __exit__() 16 | the tqdm instance once you're done with it (easiest using `with` syntax). 17 | Example 18 | ------- 19 | >>> with tqdm(...) as t: 20 | ... reporthook = my_hook(t) 21 | ... urllib.urlretrieve(..., reporthook=reporthook) 22 | """ 23 | last_b = [0] 24 | 25 | def update_to(b=1, bsize=1, tsize=None): 26 | """ 27 | b : int, optional 28 | Number of blocks transferred so far [default: 1]. 29 | bsize : int, optional 30 | Size of each block (in tqdm units) [default: 1]. 31 | tsize : int, optional 32 | Total size (in tqdm units). If [default: None] remains unchanged.
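Aside on the curve-regression example above: `MCSampler` is what turns the `--mc_iteration` stochastic forward passes into a prediction and an uncertainty estimate. As a rough mental model — a minimal sketch, not the library implementation; the helper name `mc_predict` and the `torch.no_grad` wrapper are assumptions here — the reported statistics correspond to:

import torch

def mc_predict(predictor, x, mc_iteration=50):
    # MCDropout keeps dropout active at test time, so each forward pass
    # draws a different sample from the approximate posterior over networks.
    with torch.no_grad():
        samples = torch.stack([predictor(x) for _ in range(mc_iteration)])  # (T, B, ...)
    return samples.mean(dim=0), samples.var(dim=0)  # prediction, epistemic variance

The error bars drawn in `test_phase` are this per-point variance, which is why they tend to grow outside the training interval (-5, 5).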
33 | """ 34 | if tsize is not None: 35 | t.total = tsize 36 | t.update((b - last_b[0]) * bsize) 37 | last_b[0] = b 38 | 39 | return update_to 40 | 41 | def download(url, out): 42 | 43 | os.makedirs(os.path.dirname(out), exist_ok=True) 44 | 45 | if not os.path.exists(out): 46 | with tqdm.tqdm(unit='B', unit_scale=True, miniters=1, ncols=80) as t: 47 | urllib.request.urlretrieve (url, out, reporthook=my_hook(t)) 48 | 49 | 50 | def preprocess_map(root, size, out): 51 | 52 | os.makedirs(out, exist_ok=True) 53 | 54 | files = glob.glob(os.path.join(root, '*.jpg')) 55 | 56 | for i, f in enumerate(files): 57 | 58 | img, _ = load_image(f) 59 | _, w, _ = img.shape 60 | 61 | img_a = img[:,:w//2,:].astype(np.float32) 62 | img_b = img[:,w//2:,:].astype(np.float32) 63 | 64 | img_a = cv2.resize(img_a, size) 65 | img_b = cv2.resize(img_b, size) 66 | 67 | img_a /= 127.5 68 | img_b /= 127.5 69 | 70 | img_a -= 1. 71 | img_b -= 1. 72 | 73 | save_image(os.path.join(out, '%04d_a.mha' % i), img_a) 74 | save_image(os.path.join(out, '%04d_b.mha' % i), img_b) 75 | 76 | 77 | if __name__ == '__main__': 78 | 79 | url = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz' 80 | temp = os.path.join('./temp', os.path.basename(url)) 81 | download(url, temp) 82 | 83 | with tarfile.open(temp, 'r:*') as tar: 84 | tar.extractall('./temp') 85 | 86 | preprocess_map('./temp/maps/train', (286,286), './preprocessed/train') 87 | preprocess_map('./temp/maps/val', (256,256), './preprocessed/val') 88 | -------------------------------------------------------------------------------- /examples/miccai_endovis_segmentation/.gitignore: -------------------------------------------------------------------------------- 1 | /preprocessed 2 | /logs 3 | -------------------------------------------------------------------------------- /examples/miccai_endovis_segmentation/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import tqdm 5 | import glob 6 | import zipfile 7 | import urllib.request 8 | from shutil import copyfile 9 | 10 | def my_hook(t): # https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py 11 | """Wraps tqdm instance. 12 | Don't forget to close() or __exit__() 13 | the tqdm instance once you're done with it (easiest using `with` syntax). 14 | Example 15 | ------- 16 | >>> with tqdm(...) as t: 17 | ... reporthook = my_hook(t) 18 | ... urllib.urlretrieve(..., reporthook=reporthook) 19 | """ 20 | last_b = [0] 21 | 22 | def update_to(b=1, bsize=1, tsize=None): 23 | """ 24 | b : int, optional 25 | Number of blocks transferred so far [default: 1]. 26 | bsize : int, optional 27 | Size of each block (in tqdm units) [default: 1]. 28 | tsize : int, optional 29 | Total size (in tqdm units). If [default: None] remains unchanged. 
30 | """ 31 | if tsize is not None: 32 | t.total = tsize 33 | t.update((b - last_b[0]) * bsize) 34 | last_b[0] = b 35 | 36 | return update_to 37 | 38 | def download(url, out): 39 | 40 | os.makedirs(os.path.dirname(out), exist_ok=True) 41 | 42 | if not os.path.exists(out): 43 | with tqdm.tqdm(unit='B', unit_scale=True, miniters=1, ncols=80) as t: 44 | urllib.request.urlretrieve (url, out, reporthook=my_hook(t)) 45 | 46 | def unzip(zip_file, out): 47 | 48 | os.makedirs(os.path.dirname(out), exist_ok=True) 49 | 50 | with zipfile.ZipFile(zip_file) as existing_zip: 51 | existing_zip.extractall(out) 52 | 53 | def preprocess_images(files, out_dir): 54 | 55 | commonpath = os.path.commonpath(files) 56 | 57 | for f in tqdm.tqdm(files): 58 | out = os.path.join(out_dir, os.path.relpath(f, commonpath)) 59 | os.makedirs(os.path.dirname(out), exist_ok=True) 60 | copyfile(f, out) 61 | 62 | def preprocess_labels(files, out_dir, binary=True): 63 | 64 | commonpath = os.path.commonpath(files) 65 | 66 | for f in tqdm.tqdm(files): 67 | out = os.path.join(out_dir, os.path.relpath(f, commonpath)) 68 | os.makedirs(os.path.dirname(out), exist_ok=True) 69 | 70 | src = cv2.imread(f) 71 | src = src[:,:,0] 72 | 73 | dst = np.zeros(src.shape, src.dtype) 74 | 75 | if binary: 76 | dst[src!=0] = 1 77 | else: 78 | dst[src== 70] = 1 79 | dst[src==160] = 2 80 | 81 | cv2.imwrite(out, dst) 82 | 83 | def preprocess_train(out_dir, temp_dir=None): 84 | 85 | if temp_dir is None: 86 | temp_dir = os.path.join(out_dir, 'temp') 87 | 88 | train_url = 'http://opencas.webarchiv.kit.edu/data/endovis15_ins/Segmentation_Rigid_Training.zip' 89 | train_zip = os.path.join(temp_dir, os.path.basename(train_url)) 90 | train_dir = os.path.join(temp_dir, 'train', 'image_and_label') 91 | 92 | download(train_url, train_zip) 93 | unzip(train_zip, train_dir) 94 | 95 | train_image_files = glob.glob(os.path.join(train_dir, '**', '*_raw.png'), recursive=True) 96 | print('# train images:', len(train_image_files)) 97 | preprocess_images(train_image_files, os.path.join(out_dir, 'train')) 98 | 99 | train_label_files = glob.glob(os.path.join(train_dir, '**', '*_class.png'), recursive=True) 100 | print('# train labels:', len(train_label_files)) 101 | preprocess_labels(train_label_files, os.path.join(out_dir, 'train')) 102 | 103 | def preprocess_test(out_dir, temp_dir=None): 104 | 105 | if temp_dir is None: 106 | temp_dir = os.path.join(out_dir, 'temp') 107 | 108 | test_image_url = 'http://opencas.webarchiv.kit.edu/data/endovis15_ins/Segmentation_Rigid_Testing_Revision.zip' 109 | test_image_zip = os.path.join(temp_dir, os.path.basename(test_image_url)) 110 | test_image_dir = os.path.join(temp_dir, 'test', 'image') 111 | 112 | download(test_image_url, test_image_zip) 113 | unzip(test_image_zip, test_image_dir) 114 | 115 | test_label_url = 'http://opencas.webarchiv.kit.edu/data/endovis15_ins/Segmentation_Rigid_Testing_GT.zip' 116 | test_label_zip = os.path.join(temp_dir, os.path.basename(test_label_url)) 117 | test_label_dir = os.path.join(temp_dir, 'test', 'label') 118 | 119 | download(test_label_url, test_label_zip) 120 | unzip(test_label_zip, test_label_dir) 121 | 122 | test_image_files = glob.glob(os.path.join(test_image_dir, '**', '*_raw.png'), recursive=True) 123 | print('# test images:', len(test_image_files)) 124 | preprocess_images(test_image_files, os.path.join(out_dir, 'test')) 125 | 126 | test_label_files = glob.glob(os.path.join(test_label_dir, '**', '*_class.png'), recursive=True) 127 | print('# test labels:', len(test_label_files)) 128 | 
preprocess_labels(test_label_files, os.path.join(out_dir, 'test')) 129 | 130 | 131 | if __name__ == '__main__': 132 | 133 | out_dir = './preprocessed' 134 | 135 | preprocess_train(out_dir) 136 | preprocess_test(out_dir) 137 | 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /examples/mnist_classification/.gitignore: -------------------------------------------------------------------------------- 1 | /logs 2 | -------------------------------------------------------------------------------- /figs/bayesian_unet.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuta-hi/pytorch_bayesian_unet/bb22b44c64f5d83d78aa93880da97e0e6168dc1c/figs/bayesian_unet.gif -------------------------------------------------------------------------------- /figs/curve_regression_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuta-hi/pytorch_bayesian_unet/bb22b44c64f5d83d78aa93880da97e0e6168dc1c/figs/curve_regression_1.png -------------------------------------------------------------------------------- /figs/curve_regression_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuta-hi/pytorch_bayesian_unet/bb22b44c64f5d83d78aa93880da97e0e6168dc1c/figs/curve_regression_2.png -------------------------------------------------------------------------------- /figs/map_synthesis_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuta-hi/pytorch_bayesian_unet/bb22b44c64f5d83d78aa93880da97e0e6168dc1c/figs/map_synthesis_samples.png -------------------------------------------------------------------------------- /figs/miccai_endovis_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuta-hi/pytorch_bayesian_unet/bb22b44c64f5d83d78aa93880da97e0e6168dc1c/figs/miccai_endovis_segmentation.png -------------------------------------------------------------------------------- /figs/miccai_endovis_segmentation_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuta-hi/pytorch_bayesian_unet/bb22b44c64f5d83d78aa93880da97e0e6168dc1c/figs/miccai_endovis_segmentation_samples.png -------------------------------------------------------------------------------- /figs/miccai_heatmap_regression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuta-hi/pytorch_bayesian_unet/bb22b44c64f5d83d78aa93880da97e0e6168dc1c/figs/miccai_heatmap_regression.png -------------------------------------------------------------------------------- /figs/mnist_classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuta-hi/pytorch_bayesian_unet/bb22b44c64f5d83d78aa93880da97e0e6168dc1c/figs/mnist_classification.png -------------------------------------------------------------------------------- /figs/unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuta-hi/pytorch_bayesian_unet/bb22b44c64f5d83d78aa93880da97e0e6168dc1c/figs/unet.png -------------------------------------------------------------------------------- /pytorch_bcnn/__init__.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import data # NOQA 4 | from . import datasets # NOQA 5 | from . import extensions # NOQA 6 | from . import functions # NOQA 7 | from . import inference # NOQA 8 | from . import initializers # NOQA 9 | from . import links # NOQA 10 | from . import models # NOQA 11 | from . import utils # NOQA 12 | from . import visualizer # NOQA 13 | from . import updaters # NOQA 14 | -------------------------------------------------------------------------------- /pytorch_bcnn/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .io import load_image, save_image # NOQA 4 | -------------------------------------------------------------------------------- /pytorch_bcnn/data/augmentor/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from collections import OrderedDict 5 | from abc import ABCMeta, abstractmethod 6 | import json 7 | 8 | _channel_axis = 0 9 | 10 | 11 | class DataAugmentor(object): 12 | """ Data augmentor for image and volume data. 13 | This class manages the operations. 14 | """ 15 | 16 | def __init__(self, n_dim=None): 17 | 18 | assert n_dim is None or n_dim in [2, 3] 19 | 20 | self._n_dim = n_dim 21 | self._operations = [] 22 | 23 | def add(self, op): 24 | 25 | assert isinstance(op, Operation) 26 | if self._n_dim is None: # NOTE: auto set 27 | self._n_dim = op.ndim 28 | self._operations.append(op) 29 | 30 | def get(self): 31 | return self._operations 32 | 33 | def preprocess(self, x): 34 | 35 | is_expanded = False 36 | 37 | if x is not None: 38 | if isinstance(x, list): 39 | if x[0].ndim == self._n_dim: 40 | x = [np.expand_dims(x_i, _channel_axis) for x_i in x] 41 | is_expanded = True 42 | assert x[0].ndim == self._n_dim + 1, '`x[0].ndim` must be `self._n_dim + 1`' 43 | else: 44 | if x.ndim == self._n_dim: 45 | x = np.expand_dims(x, _channel_axis) 46 | is_expanded = True 47 | assert x.ndim == self._n_dim + 1, '`x.ndim` must be `self._n_dim + 1`' 48 | 49 | return x, is_expanded 50 | 51 | def postprocess(self, x, is_expanded): 52 | 53 | if not is_expanded: 54 | return x 55 | 56 | if x is not None: 57 | if isinstance(x, list): 58 | x = [np.rollaxis(x_i, _channel_axis, 0)[0] for x_i in x] 59 | else: 60 | x = np.rollaxis(x, _channel_axis, 0)[0] 61 | return x 62 | 63 | def apply(self, x=None, y=None): 64 | 65 | x, is_expanded_x = self.preprocess(x) 66 | y, is_expanded_y = self.preprocess(y) 67 | 68 | for op in self._operations: 69 | x, y = op.apply(x, y) 70 | 71 | x = self.postprocess(x, is_expanded_x) 72 | y = self.postprocess(y, is_expanded_y) 73 | 74 | assert(x is not None or y is not None) 75 | if x is None: 76 | return y 77 | if y is None: 78 | return x 79 | return x, y 80 | 81 | def __call__(self, x=None, y=None): 82 | return self.apply(x, y) 83 | 84 | def summary(self, out=None): 85 | 86 | ret = OrderedDict() 87 | 88 | for op in self._operations: 89 | name = op.__class__.__name__ 90 | 91 | if name in ret.keys(): 92 | cnt = 1 93 | while True: 94 | _name = '%s_%d' % (name, cnt) 95 | if _name not in ret.keys(): break 96 | cnt += 1 97 | name = _name 98 | 99 | args = op.summary().copy() 100 | ignore_keys = ['__class__', 'self'] 101 | for key in ignore_keys: 102 | if key in args.keys(): 103 | args.pop(key) 104 | 105 | ret[name] = args 106 | 107 | 
if out is None: 108 | return ret 109 | 110 | with open(out, 'w', encoding='utf-8') as f: 111 | json.dump(ret, f, ensure_ascii=False, indent=4) 112 | 113 | return ret 114 | 115 | 116 | class Operation(metaclass=ABCMeta): 117 | """ Base class of operations 118 | """ 119 | def __init__(self, *args, **kwargs): 120 | self._args = locals() 121 | 122 | def preprocess_core(self, x): 123 | if x is None: 124 | return x 125 | elif isinstance(x, list): 126 | return x 127 | else: 128 | return [x] # NOTE: to list 129 | 130 | def preprocess(self, x, y): 131 | x = self.preprocess_core(x) 132 | y = self.preprocess_core(y) 133 | return x, y 134 | 135 | def postprocess_core(self, x): 136 | if x is None: 137 | return x 138 | elif len(x) == 1: 139 | return x[0] 140 | else: 141 | return x 142 | 143 | def postprocess(self, x, y): 144 | x = self.postprocess_core(x) 145 | y = self.postprocess_core(y) 146 | return x, y 147 | 148 | @abstractmethod 149 | def apply_core(self, x, y): 150 | raise NotImplementedError() 151 | 152 | def apply(self, x=None, y=None): 153 | x, y = self.preprocess(x, y) 154 | x, y = self.apply_core(x, y) 155 | x, y = self.postprocess(x, y) 156 | return x, y 157 | 158 | @property 159 | @abstractmethod 160 | def ndim(self): 161 | raise NotImplementedError() 162 | 163 | def summary(self): 164 | return self._args 165 | 166 | 167 | from .image import Flip as Flip2D # NOQA 168 | from .image import Crop as Crop2D # NOQA 169 | from .image import ResizeCrop as ResizeCrop2D # NOQA 170 | from .image import Affine as Affine2D # NOQA 171 | 172 | from .volume import Flip as Flip3D # NOQA 173 | from .volume import Crop as Crop3D # NOQA 174 | from .volume import ResizeCrop as ResizeCrop3D # NOQA 175 | from .volume import Affine as Affine3D # NOQA 176 | -------------------------------------------------------------------------------- /pytorch_bcnn/data/io/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import os 5 | import cv2 6 | 7 | from . import mhd 8 | 9 | 10 | def load_image(filename): 11 | """ Load a two/three dimensional image from given filename 12 | 13 | Args: 14 | filename (str) 15 | 16 | Returns: 17 | numpy.ndarray: An image 18 | list of float: Spacing 19 | """ 20 | 21 | _, ext = os.path.splitext(os.path.basename(filename)) 22 | 23 | if ext in ('.mha', '.mhd'): 24 | [img, img_header] = mhd.read(filename) 25 | spacing = img_header['ElementSpacing'] 26 | img.flags.writeable = True 27 | if img.ndim == 3: 28 | img = np.transpose(img, (1, 2, 0)) 29 | 30 | elif ext in ('.png', '.jpg', '.bmp'): 31 | img = cv2.imread(filename) 32 | spacing = None 33 | 34 | else: 35 | raise NotImplementedError() 36 | 37 | return img, spacing 38 | 39 | 40 | def save_image(filename, image, spacing=None): 41 | """ Save a two/three dimensional image 42 | 43 | Args: 44 | filename (str) 45 | image (numpy.ndarray): A two/three dimensional image 46 | spacing (list of float, optional): Spacing. Defaults to None. 
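To make the augmentor API above concrete: a pipeline is assembled by `add()`-ing operations and is then called on image/label pairs, with `summary()` dumping the configuration to JSON. The sketch below uses the module's re-exported names (`Flip2D`, `Crop2D`); their constructor arguments are assumptions — see `pytorch_bcnn/data/augmentor/image.py` for the actual signatures.

import numpy as np
from pytorch_bcnn.data.augmentor import DataAugmentor, Flip2D, Crop2D

augmentor = DataAugmentor(n_dim=2)
augmentor.add(Flip2D(axis=2))          # argument name is an assumption
augmentor.add(Crop2D(size=(64, 64)))   # argument name is an assumption

x = np.random.rand(1, 128, 128).astype(np.float32)        # (channel, h, w)
y = (np.random.rand(1, 128, 128) > 0.5).astype(np.int32)  # paired label
x_aug, y_aug = augmentor(x, y)  # the same random transform is applied to both
augmentor.summary('augmentation.json')  # dump the pipeline configuration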
47 | """ 48 | 49 | dirname = os.path.dirname(filename) 50 | if dirname != '': 51 | os.makedirs(dirname, exist_ok=True) 52 | _, ext = os.path.splitext(os.path.basename(filename)) 53 | 54 | if ext in ('.mha', '.mhd'): 55 | header = {} 56 | if spacing is not None: 57 | header['ElementSpacing'] = spacing 58 | if image.ndim == 2: 59 | header['TransformMatrix'] = '1 0 0 1' 60 | header['Offset'] = '0 0' 61 | header['CenterOfRotation'] = '0 0' 62 | elif image.ndim == 3: 63 | image = image.transpose((2, 0, 1)) 64 | header['TransformMatrix'] = '1 0 0 0 1 0 0 0 1' 65 | header['Offset'] = '0 0 0' 66 | header['CenterOfRotation'] = '0 0 0' 67 | else: 68 | raise NotImplementedError() 69 | mhd.write(filename, image, header) 70 | 71 | elif ext in ('.png', '.jpg', '.bmp'): 72 | cv2.imwrite(filename, image) 73 | 74 | else: 75 | raise NotImplementedError() 76 | -------------------------------------------------------------------------------- /pytorch_bcnn/data/normalizer/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from abc import ABCMeta, abstractmethod 4 | 5 | from .. import augmentor 6 | 7 | class Normalizer(augmentor.DataAugmentor): 8 | 9 | def apply(self, x): 10 | 11 | x, is_expanded_x = self.preprocess(x) 12 | 13 | for op in self._operations: 14 | x = op.apply(x) 15 | 16 | x = self.postprocess(x, is_expanded_x) 17 | 18 | return x 19 | 20 | def __call__(self, x): 21 | return self.apply(x) 22 | 23 | class Operation(augmentor.Operation): 24 | 25 | def preprocess(self, x): 26 | x = self.preprocess_core(x) 27 | return x 28 | 29 | def postprocess(self, x): 30 | x = self.postprocess_core(x) 31 | return x 32 | 33 | @abstractmethod 34 | def apply_core(self, x): 35 | raise NotImplementedError() 36 | 37 | def apply(self, x): 38 | x = self.preprocess(x) 39 | x = self.apply_core(x) 40 | x = self.postprocess(x) 41 | return x 42 | 43 | from .image import Quantize as Quantize2D # NOQA 44 | from .image import Clip as Clip2D # NOQA 45 | from .image import Subtract as Subtract2D # NOQA 46 | from .image import Divide as Divide2D # NOQA 47 | 48 | from .volume import Quantize as Quantize3D # NOQA 49 | from .volume import Clip as Clip3D # NOQA 50 | from .volume import Subtract as Subtract3D # NOQA 51 | from .volume import Divide as Divide3D # NOQA 52 | -------------------------------------------------------------------------------- /pytorch_bcnn/data/normalizer/image.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from . import Operation 5 | 6 | _row_axis = 1 7 | _col_axis = 2 8 | _channel_axis = 0 9 | 10 | def quantize(x, n_bit, x_min=0., x_max=1., rescale=True): 11 | 12 | n_discrete_values = 1 << n_bit 13 | scale = (n_discrete_values - 1) / (x_max - x_min) 14 | quantized = np.round(x * scale) - np.round(x_min * scale) 15 | quantized = np.clip(quantized, 0., n_discrete_values) 16 | 17 | if not rescale: 18 | return quantized 19 | 20 | quantized /= scale 21 | return quantized + x_min 22 | 23 | class Quantize(Operation): 24 | """ Quantize the given images to specific resolution. 25 | 26 | Non-linearity and overfitting make the neural networks sensitive to tiny noises in high-dimensional data [1]. 27 | Quantizing it to a necessary and sufficient level may be effective, especially for medical images which have >16 bits information. 28 | [1] Goodfellow et al., "Explaining and harnessing adversarial examples.", 2014. 
https://arxiv.org/abs/1412.6572 29 | 30 | Args: 31 | n_bit (int): Number of bits. 32 | x_min (float, optional): Minimum value in the input domain. Defaults to 0. 33 | x_max (float, optional): Maximum value in the input domain. Defaults to 1. 34 | rescale (bool, optional): If True, output value is rescaled to input domain. Defaults to True. 35 | """ 36 | def __init__(self, n_bit, x_min=0., x_max=1., rescale=True): 37 | self._args = locals() 38 | self._n_bit = n_bit 39 | self._x_min = x_min 40 | self._x_max = x_max 41 | self._rescale = rescale 42 | self._ndim = 2 43 | 44 | @property 45 | def ndim(self): 46 | return self._ndim 47 | 48 | def apply_core(self, x): 49 | x = [quantize(x_i, self._n_bit, 50 | self._x_min, self._x_max, self._rescale) 51 | for x_i in x] 52 | return x 53 | 54 | 55 | def clip(x, param, scale=1.): 56 | 57 | if isinstance(param, str): 58 | if param == 'minmax': 59 | param = (np.min(x), np.max(x)) 60 | elif param == 'ch_minmax': 61 | tmp = np.swapaxes(x, _channel_axis, 0) 62 | tmp = np.reshape(tmp, (len(tmp), -1)) 63 | tmp_shape = [len(tmp)] + [1] * (x.ndim - 1) 64 | param = (np.min(tmp, axis=1).reshape(tmp_shape), 65 | np.max(tmp, axis=1).reshape(tmp_shape)) 66 | else: 67 | raise NotImplementedError('unsupported parameters..') 68 | 69 | assert isinstance(param, (list, tuple)) 70 | x = (x - param[0]) / (param[1] - param[0]) # [0, 1] 71 | x = np.clip(x, 0., 1.) 72 | 73 | return x * scale 74 | 75 | class Clip(Operation): 76 | """ Clip (limit) the values in given images. 77 | 78 | Args: 79 | param (tuple or str): Tuple of minimum and maximum values. 80 | If 'minmax' or 'ch_minmax', minimum and maximum values are automatically estimated. 81 | 'ch_minmax' is the channel-wise minmax normalization. 82 | """ 83 | def __init__(self, param): 84 | self._args = locals() 85 | self._param = param 86 | self._ndim = 2 87 | 88 | @property 89 | def ndim(self): 90 | return self._ndim 91 | 92 | def apply_core(self, x): 93 | x = [clip(x_i, self._param) for x_i in x] 94 | return x 95 | 96 | 97 | def subtract(x, param): 98 | 99 | if isinstance(param, str): 100 | if param == 'mean': # NOTE: for z-score normalization 101 | param = np.mean(x) 102 | elif param == 'ch_mean': 103 | tmp = np.swapaxes(x, _channel_axis, 0) 104 | tmp = np.reshape(tmp, (len(tmp), -1)) 105 | tmp_shape = [len(tmp)] + [1] * (x.ndim - 1) 106 | param = np.mean(tmp, axis=1).reshape(tmp_shape) 107 | else: 108 | raise NotImplementedError('unsupported parameters..') 109 | 110 | x -= param 111 | 112 | return x 113 | 114 | class Subtract(Operation): 115 | """ Subtract a value or tensor from given images. 116 | 117 | Args: 118 | param (float, numpy.ndarray or str): A value or tensor. 119 | If 'mean' or 'ch_mean', subtracting values are automatically estimated. 120 | 'ch_mean' is to subtract the channel-wise mean. 
121 | """ 122 | def __init__(self, param): 123 | self._args = locals() 124 | self._param = param 125 | self._ndim = 2 126 | 127 | @property 128 | def ndim(self): 129 | return self._ndim 130 | 131 | def apply_core(self, x): 132 | x = [subtract(x_i, self._param) for x_i in x] 133 | return x 134 | 135 | 136 | def divide(x, param): 137 | 138 | if isinstance(param, str): 139 | if param == 'std': # NOTE: for z-score normalization 140 | param = np.std(x) 141 | elif param == 'ch_std': 142 | tmp = np.swapaxes(x, _channel_axis, 0) 143 | tmp = np.reshape(tmp, (len(tmp), -1)) 144 | tmp_shape = [len(tmp)] + [1] * (x.ndim - 1) 145 | param = np.std(tmp, axis=1).reshape(tmp_shape) 146 | else: 147 | raise NotImplementedError('unsupported parameters..') 148 | 149 | x /= param 150 | 151 | return x 152 | 153 | class Divide(Operation): 154 | """ Divide the given images by a value or tensor 155 | 156 | Args: 157 | param (float, numpy.ndarray or str): A value or tensor. 158 | If 'std' or 'ch_std', deviding values are automatically estimated. 159 | 'ch_std' is to divide the channel-wise standard deviation. 160 | """ 161 | def __init__(self, param): 162 | self._args = locals() 163 | self._param = param 164 | self._ndim = 2 165 | 166 | @property 167 | def ndim(self): 168 | return self._ndim 169 | 170 | def apply_core(self, x): 171 | x = [divide(x_i, self._param) for x_i in x] 172 | return x 173 | -------------------------------------------------------------------------------- /pytorch_bcnn/data/normalizer/volume.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import image 4 | 5 | _row_axis = 1 6 | _col_axis = 2 7 | _depth_axis = 3 8 | _channel_axis = 0 9 | 10 | assert image._channel_axis == _channel_axis 11 | 12 | class Quantize(image.Quantize): 13 | """ Quantize the given images to specific resolution. 14 | 15 | Non-linearity and overfitting make the neural networks sensitive to tiny noises in high-dimensional data [1]. 16 | Quantizing it to a necessary and sufficient level may be effective, especially for medical images which have >16 bits information. 17 | [1] Goodfellow et al., "Explaining and harnessing adversarial examples.", 2014. https://arxiv.org/abs/1412.6572 18 | 19 | Args: 20 | n_bit (int): Number of bits. 21 | x_min (float, optional): Minimum value in the input domain. Defaults to 0.. 22 | x_max (float, optional): Maximum value in the input domain. Defaults to 1.. 23 | rescale (bool, optional): If True, output value is rescaled to input domain. Defaults to True. 24 | """ 25 | def __init__(self, n_bit, x_min=0., x_max=1., rescale=True): 26 | super().__init__(n_bit, x_min, x_max, rescale) 27 | self._ndim = 3 28 | 29 | class Clip(image.Clip): 30 | """ Clip (limit) the values in given images. 31 | 32 | Args: 33 | param (tuple or str): Tuple of minimum and maximum values. 34 | If 'minmax' or 'ch_minmax', minimum and maximum values are automatically estimated. 35 | 'ch_minmax' is the channel-wise minmax normalization. 36 | """ 37 | def __init__(self, param): 38 | super().__init__(param) 39 | self._ndim = 3 40 | 41 | class Subtract(image.Subtract): 42 | """ Subtract a value or tensor from given images. 43 | 44 | Args: 45 | param (float, numpy.ndarray or str): A value or tensor. 46 | If 'mean' or 'ch_mean', subtracting values are automatically estimated. 47 | 'ch_mean' is to subtract the channel-wise mean. 
48 | """ 49 | def __init__(self, param): 50 | super().__init__(param) 51 | self._ndim = 3 52 | 53 | class Divide(image.Divide): 54 | """ Divide the given images by a value or tensor 55 | 56 | Args: 57 | param (float, numpy.ndarray or str): A value or tensor. 58 | If 'std' or 'ch_std', deviding values are automatically estimated. 59 | 'ch_std' is to divide the channel-wise standard deviation. 60 | """ 61 | def __init__(self, param): 62 | super().__init__(param) 63 | self._ndim = 3 64 | 65 | -------------------------------------------------------------------------------- /pytorch_bcnn/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from pytorch_trainer.dataset import DatasetMixin 5 | from pytorch_trainer.dataset import convert_to_tensor 6 | import tqdm 7 | import glob 8 | import warnings 9 | from abc import ABCMeta, abstractmethod 10 | from collections import OrderedDict 11 | 12 | from ..data import load_image # NOQA 13 | 14 | class BaseDataset(DatasetMixin, metaclass=ABCMeta): 15 | """ Base class of dataset 16 | 17 | Args: 18 | root (str): Directory to the dataset 19 | patients (list, optional): List of patient names. Defaults to []. 20 | classes (None or list, optional): List of class names. Defaults to None. 21 | dtypes (dict, optional): An dictionary of data types. Defaults to {}. 22 | filenames (dict, optional): An dictionary of wildcard to filenames. 23 | Each filename can be a format string using '{root}' and '{patient}'. Defaults to {}. 24 | normalizer (callable, optional): An callable function for normalization. Defaults to None. 25 | augmentor (callable, optional): An callable function for data augmentation. Defaults to None. 26 | """ 27 | def __init__(self, 28 | root, 29 | patients=[], 30 | classes=None, 31 | dtypes={}, 32 | filenames={}, 33 | normalizer=None, 34 | augmentor=None): 35 | 36 | super(BaseDataset, self).__init__() 37 | 38 | assert isinstance(patients, (list, np.ndarray)), \ 39 | 'please specify the patient names..' 40 | if classes is not None: 41 | if isinstance(classes, list): 42 | classes = np.asarray(classes) 43 | assert isinstance(classes, np.ndarray), \ 44 | 'class names should be list or np.ndarray..' 45 | assert isinstance(dtypes, dict), \ 46 | 'please specify the dtype per each file..' 47 | assert isinstance(filenames, dict), \ 48 | 'please specify the filename per each file..' 49 | if normalizer is not None: 50 | assert callable(normalizer), 'normalizer should be callable..' 51 | if augmentor is not None: 52 | assert callable(augmentor), 'augmentor should be callable..' 53 | 54 | # initialize 55 | files = OrderedDict() 56 | file_sizes = [] 57 | 58 | for key in filenames.keys(): 59 | 60 | files[key] = [] 61 | for p in tqdm.tqdm(patients, desc='Collecting %s files' % key, ncols=80): 62 | files[key].extend( 63 | glob.glob(filenames[key].format(root=root, patient=p))) 64 | 65 | if len(files[key]) == 0: 66 | warnings.warn('%s files are not found.. ' % key) 67 | file_sizes.append(len(files[key])) 68 | 69 | assert all(file_sizes[0] == s for s in file_sizes), \ 70 | 'the number of files must be the same..' 
71 | 72 | self._root = root 73 | self._patients = patients 74 | self._classes = classes 75 | self._dtypes = dtypes 76 | self._filenames = filenames 77 | self._files = files 78 | self._normalizer = normalizer 79 | self._augmentor = augmentor 80 | 81 | def __len__(self): 82 | key = list(self._files.keys())[0] 83 | return len(self._files[key]) 84 | 85 | @property 86 | def classes(self): 87 | return self._classes 88 | 89 | @property 90 | def n_classes(self): 91 | if self.classes is None: 92 | return None 93 | return len(self.classes) 94 | 95 | @property 96 | def files(self): 97 | return self._files 98 | 99 | @property 100 | def dtypes(self): 101 | return self._dtypes 102 | 103 | @property 104 | def normalizer(self): 105 | return self._normalizer 106 | 107 | @property 108 | def augmentor(self): 109 | return self._augmentor 110 | 111 | @augmentor.deleter 112 | def augmentor(self): 113 | self._augmentor = None 114 | 115 | @classmethod 116 | @abstractmethod 117 | def normalize(self, **kwargs): 118 | raise NotImplementedError() 119 | 120 | @classmethod 121 | @abstractmethod 122 | def denormalize(self, **kwargs): 123 | raise NotImplementedError() 124 | 125 | @classmethod 126 | @abstractmethod 127 | @convert_to_tensor 128 | def get_example(self, i): 129 | raise NotImplementedError() 130 | 131 | @classmethod 132 | @abstractmethod 133 | def __copy__(self): 134 | """Copy the class instance""" 135 | raise NotImplementedError() 136 | 137 | from .volume import VolumeDataset # NOQA 138 | from .image import ImageDataset # NOQA 139 | 140 | def train_valid_split(train, valid_ratio): 141 | 142 | if isinstance(train, BaseDataset): 143 | 144 | valid = train.__copy__() 145 | 146 | n_samples = len(train) 147 | 148 | valid_indices = np.random.choice(np.arange(n_samples), 149 | int(valid_ratio * n_samples), 150 | replace=False) 151 | files = train.files 152 | 153 | for key in files.keys(): 154 | valid._files[key] = np.asarray(files[key])[valid_indices] 155 | train._files[key] = np.delete( 156 | np.asarray(files[key]), valid_indices) 157 | 158 | elif isinstance(train, (list, np.ndarray)): 159 | 160 | valid = np.asarray(train) 161 | 162 | n_samples = len(train) 163 | 164 | valid_indices = np.random.choice(np.arange(n_samples), 165 | int(valid_ratio * n_samples), 166 | replace=False) 167 | 168 | valid = valid[valid_indices] 169 | train = np.delete(train, valid_indices) 170 | 171 | assert len(train) + len(valid) == n_samples 172 | 173 | return train, valid 174 | 175 | 176 | def load_crossval_list(xls_file, index): 177 | import pandas as pd 178 | from distutils.version import LooseVersion 179 | 180 | if LooseVersion(pd.__version__) >= LooseVersion('0.21.0'): 181 | df = pd.read_excel(xls_file, sheet_name=index) 182 | else: 183 | df = pd.read_excel(xls_file, sheetname=index) 184 | 185 | train = df['train'].dropna().tolist() 186 | valid = df['valid'].dropna().tolist() 187 | test = df['test'].dropna().tolist() 188 | 189 | return {'train': train, 'valid': valid, 'test': test} 190 | -------------------------------------------------------------------------------- /pytorch_bcnn/datasets/image.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from collections import OrderedDict 5 | from inspect import signature 6 | 7 | from . import BaseDataset 8 | from . 
import convert_to_tensor 9 | from ..data import load_image 10 | 11 | _supported_filetypes = [ 12 | 'image', 13 | 'label', 14 | 'mask', 15 | ] 16 | 17 | _default_dtypes = OrderedDict({ 18 | 'image': np.float32, 19 | 'label': np.int32, 20 | 'mask': np.uint8, 21 | }) 22 | 23 | _default_filenames = OrderedDict({ 24 | 'image': '{root}/{patient}/image.mha', 25 | 'label': '{root}/{patient}/label.mha', 26 | 'mask': '{root}/{patient}/mask.mha', 27 | }) 28 | 29 | _default_mask_cvals = OrderedDict({ 30 | 'image': 0, 31 | 'label': 0, 32 | }) 33 | 34 | _channel_axis = 0 35 | 36 | 37 | def _inspect_n_args(func): 38 | sig = signature(func) 39 | return len(sig.parameters) 40 | 41 | 42 | class ImageDataset(BaseDataset): 43 | """ Dataset for two-dimensional images 44 | 45 | Args: 46 | root (str): Directory to the dataset 47 | patients (list, optional): List of patient names. Defaults to []. 48 | classes (None or list, optional): List of class names. Defaults to None. 49 | dtypes (dict, optional): A dictionary of data types. 50 | Defaults to {'image': np.float32, 'label': np.int32, 'mask': np.uint8}. 51 | filenames (dict, optional): A dictionary of wildcards to filenames. 52 | Each filename can be a format string using '{root}' and '{patient}'. 53 | Defaults to {'image': '{root}/{patient}/image.mha', 54 | 'label': '{root}/{patient}/label.mha', 'mask': '{root}/{patient}/mask.mha'}. 55 | normalizer (callable, optional): A callable function for normalization. Defaults to None. 56 | augmentor (callable, optional): A callable function for data augmentation. Defaults to None. 57 | mask_cvals (dict, optional): Value used for points outside the mask. 58 | Defaults to {'image': 0, 'label': 0} 59 | """ 60 | def __init__(self, 61 | root, 62 | patients=[], 63 | classes=None, 64 | dtypes=_default_dtypes, 65 | filenames=_default_filenames, 66 | normalizer=None, 67 | augmentor=None, 68 | mask_cvals=_default_mask_cvals): 69 | 70 | for key in filenames.keys(): 71 | if key not in _supported_filetypes: 72 | raise KeyError('unsupported filetype..
<%s>' % key) 73 | 74 | super(ImageDataset, self).__init__( 75 | root, patients, classes, dtypes, 76 | filenames, normalizer, augmentor) 77 | 78 | self._mask_cvals = mask_cvals 79 | 80 | def normalize(self, x, y=None): 81 | 82 | # reshape 83 | if x.ndim == 2: 84 | x = x[np.newaxis] 85 | elif x.ndim == 3: 86 | x = np.transpose(x, (2, 0, 1)) # [c, w, h] 87 | 88 | if y is not None: 89 | # NOTE: assume that `y` is categorical label 90 | if y.dtype in [np.int32, np.int64]: 91 | if y.ndim == 3: 92 | if y.shape[-1] in [1, 3]: 93 | y = y[:, :, 0] # NOTE: ad-hoc 94 | else: 95 | pass 96 | 97 | # NOTE: assume that `y` is continuous label (e.g., heatmap) 98 | elif y.dtype in [np.float32, np.float64]: 99 | if y.ndim == 2: 100 | y = y[np.newaxis] 101 | elif y.ndim == 3: 102 | y = np.transpose(y, (2, 0, 1)) # [c, w, h] 103 | 104 | else: 105 | raise NotImplementedError('unsupported dtype..') 106 | 107 | # normalizer 108 | if self.normalizer is not None: 109 | if _inspect_n_args(self.normalizer) == 2: 110 | x, y = self.normalizer(x, y) 111 | else: 112 | x = self.normalizer(x) 113 | 114 | return x, y 115 | 116 | def denormalize(self, x, y=None): 117 | raise NotImplementedError() 118 | 119 | def masking(self, x, y, mask): 120 | 121 | if x.ndim -1 != mask.ndim: 122 | mask = np.squeeze(mask, -1) 123 | 124 | x[:, mask==0] = self._mask_cvals['image'] 125 | if y is not None: 126 | y[mask==0] = self._mask_cvals['label'] 127 | 128 | return x, y 129 | 130 | def load_images(self, i): 131 | 132 | images, spacings = {}, {} 133 | 134 | for key in self.files.keys(): 135 | 136 | images[key], spacings[key] = \ 137 | load_image(self.files[key][i]) 138 | 139 | images[key] = images[key].astype(self.dtypes[key]) 140 | 141 | return images, spacings 142 | 143 | @convert_to_tensor 144 | def get_example(self, i): 145 | 146 | # load 147 | images, _ = self.load_images(i) 148 | 149 | image = images['image'] 150 | label = images.get('label') 151 | mask = images.get('mask') 152 | 153 | # transfrom 154 | image, label = self.normalize(image, label) 155 | 156 | # masking 157 | if mask is not None: 158 | image, label = self.masking(image, label, mask) 159 | 160 | # augment 161 | if self.augmentor is not None: 162 | if _inspect_n_args(self.augmentor) == 2: 163 | image, label = self.augmentor(image, label) 164 | else: 165 | image = self.augmentor(image) 166 | 167 | # return 168 | if label is None: 169 | return image 170 | 171 | return image, label 172 | 173 | def __copy__(self): 174 | 175 | return ImageDataset( 176 | self._root, 177 | self._patients, 178 | self._classes, 179 | self._dtypes, 180 | self._filenames, 181 | self._normalizer, 182 | self._augmentor, 183 | self._mask_cvals) 184 | -------------------------------------------------------------------------------- /pytorch_bcnn/datasets/volume.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | 5 | from .image import ImageDataset 6 | from .image import _inspect_n_args 7 | 8 | class VolumeDataset(ImageDataset): 9 | """ Dataset for three-dimensional images 10 | 11 | Args: 12 | root (str): Directory to the dataset 13 | patients (list, optional): List of patient names. Defaults to []. 14 | classes (None or list, optional): List of class names. Defaults to None. 15 | dtypes (dict, optional): An dictionary of data types. 16 | Defaults to {'image': np.float32, 'label': np.int32, 'mask': np.uint8}. 17 | filenames (dict, optional): An dictionary of wildcard to filenames. 
18 | Each filename can be a format string using '{root}' and '{patient}'. 19 | Defaults to {'image': '{root}/{patient}/image.mha', 20 | 'label': '{root}/{patient}/label.mha', 'mask': '{root}/{patient}/mask.mha'}. 21 | normalizer (callable, optional): A callable function for normalization. Defaults to None. 22 | augmentor (callable, optional): A callable function for data augmentation. Defaults to None. 23 | mask_cvals (dict, optional): Value used for points outside the mask. 24 | Defaults to {'image': 0, 'label': 0} 25 | """ 26 | 27 | def normalize(self, x, y=None): 28 | 29 | # reshape 30 | if x.ndim == 3: 31 | x = x[np.newaxis] 32 | elif x.ndim == 4: 33 | x = np.transpose(x, (3, 0, 1, 2)) # [c, w, h, d] 34 | 35 | if y is not None: 36 | # NOTE: assume that `y` is categorical label 37 | if y.dtype in [np.int32, np.int64]: 38 | if y.ndim == 4: 39 | if y.shape[-1] == 1: 40 | y = y[:, :, :, 0] 41 | else: 42 | pass 43 | 44 | # NOTE: assume that `y` is continuous label (e.g., heatmap) 45 | elif y.dtype in [np.float32, np.float64]: 46 | if y.ndim == 3: 47 | y = y[np.newaxis] 48 | elif y.ndim == 4: 49 | y = np.transpose(y, (3, 0, 1, 2)) # [c, w, h, d] 50 | 51 | else: 52 | raise NotImplementedError('unsupported dtype..') 53 | 54 | # normalizer 55 | if self.normalizer is not None: 56 | if _inspect_n_args(self.normalizer) == 2: 57 | x, y = self.normalizer(x, y) 58 | else: 59 | x = self.normalizer(x) 60 | 61 | return x, y 62 | 63 | def denormalize(self, x, y=None): 64 | raise NotImplementedError() 65 | 66 | 67 | class VolumeSliceDataset(ImageDataset): 68 | pass 69 | -------------------------------------------------------------------------------- /pytorch_bcnn/extensions/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .validator import Validator # NOQA 4 | from .log_report import LogReport # NOQA 5 | from .print_report import PrintReport # NOQA 6 | -------------------------------------------------------------------------------- /pytorch_bcnn/extensions/log_report.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import json 4 | import os 5 | import six 6 | from pytorch_trainer.training import extensions 7 | 8 | 9 | class LogReport(extensions.LogReport): 10 | 11 | """__init__(\ 12 | keys=None, trigger=(1, 'iteration'), postprocess=None, json_name='log', csv_name='log.csv') 13 | Trainer extension to output the accumulated results to a log file. 14 | This extension accumulates the observations of the trainer to 15 | a ``DictSummary`` at a regular interval specified by a supplied 16 | trigger, and writes them into a log file in JSON format. 17 | There are two triggers to handle this extension. One is the trigger to 18 | invoke this extension, which is used to handle the timing of accumulating 19 | the results. It is set to ``(1, 'iteration')`` by default. The other is the 20 | trigger to determine when to emit the result. When this trigger returns 21 | True, this extension appends the summary of accumulated values to the list 22 | of past summaries, and writes the list to the log file. Then, this 23 | extension makes a new fresh summary object which is used until the next 24 | time that the trigger fires. 25 | It also adds some entries to each result dictionary. 26 | - ``'epoch'`` and ``'iteration'`` are the epoch and iteration counts at the 27 | output, respectively.
28 |     - ``'elapsed_time'`` is the elapsed time in seconds since the training
29 |       begins. The value is taken from :attr:`Trainer.elapsed_time`.
30 | 
31 |     Args:
32 |         keys (iterable of strs): Keys of values to accumulate. If this is None,
33 |             all the values are accumulated and output to the log file.
34 |         trigger: Trigger that decides when to aggregate the result and output
35 |             the values. This is distinct from the trigger of this extension
36 |             itself. If it is a tuple in the form ``<int>, 'epoch'`` or
37 |             ``<int>, 'iteration'``, it is passed to :class:`IntervalTrigger`.
38 |         postprocess: Callback to postprocess the result dictionaries. Each
39 |             result dictionary is passed to this callback on the output. This
40 |             callback can modify the result dictionaries, which are used to
41 |             output to the log file.
42 |         json_name (str): Name of the JSON log file under the output
43 |             directory. It can be a format string: the last result dictionary is
44 |             passed for the formatting. For example, users can use '{iteration}'
45 |             to separate the log files for different iterations. If the log name
46 |             is None, it does not output the log to any file.
47 |         csv_name (str): Name of the CSV log file under the output
48 |             directory.
49 |     """
50 | 
51 |     def __init__(self, keys=None, trigger=(1, 'iteration'), postprocess=None,
52 |                  json_name='log', csv_name='log.csv', **kwargs):
53 | 
54 |         super(LogReport, self).__init__(
55 |             keys, trigger, postprocess, json_name, **kwargs)
56 | 
57 |         self._log_json_name = self._log_name
58 |         self._log_csv_name = csv_name
59 | 
60 |     def _write_json_log(self, path, _dict, indent=4):
61 |         """ Append a record in JSON format to the end of a JSON log file.
62 |         NOTE: The original implementation rewrites the whole log on every
63 |         save, so per-iteration logging becomes increasingly slow; this
64 |         version appends only the new entry in place.
65 |         NOTE: Assumes the file contains a JSON array (a list of Python
66 |         dicts) ending in ']'.
67 | """ 68 | 69 | with open(path, 'ab') as fp: 70 | fp.seek(0, 2) # Go to the end of file 71 | if fp.tell() == 0: # Check if file is empty 72 | new_ending = json.dumps(_dict, indent=indent) 73 | new_ending = new_ending.split('\n') 74 | new_ending = [' '*indent + x for x in new_ending] 75 | new_ending = '\n'.join(new_ending) 76 | new_ending = '[\n' + new_ending + '\n]' 77 | fp.write(new_ending.encode()) 78 | 79 | else: 80 | fp.seek(-2, 2) 81 | fp.truncate() # Remove the last two character 82 | 83 | new_ending = json.dumps(_dict, indent=indent) 84 | new_ending = new_ending.split('\n') 85 | new_ending = [' ' * indent + x for x in new_ending] 86 | new_ending = '\n'.join(new_ending) 87 | new_ending = ',\n' + new_ending + '\n]' 88 | fp.write(new_ending.encode()) 89 | 90 | def _accumulate_observations(self, trainer): 91 | 92 | keys = self._keys 93 | observation = trainer.observation 94 | summary = self._summary 95 | 96 | if keys is None: 97 | summary.add(observation) 98 | else: 99 | summary.add({k: observation[k] for k in keys if k in observation}) 100 | 101 | return summary 102 | 103 | def initialize(self, trainer): 104 | 105 | keys = self._keys 106 | summary = self._accumulate_observations(trainer) 107 | 108 | # make header 109 | if keys is None: 110 | self._keys = ['epoch', 'iteration', 'elapsed_time'] 111 | self._keys.extend(sorted(summary._summaries.keys())) 112 | else: 113 | self._keys = ['epoch', 'iteration', 'elapsed_time'] 114 | for k in keys: 115 | if k not in self._keys: 116 | self._keys.append(k) 117 | 118 | self._log_csv_name = os.path.join(trainer.out, self._log_csv_name) 119 | self._log_json_name = os.path.join(trainer.out, self._log_json_name) 120 | 121 | os.makedirs(os.path.dirname(self._log_csv_name), exist_ok=True) 122 | os.makedirs(os.path.dirname(self._log_json_name), exist_ok=True) 123 | 124 | with open(self._log_csv_name, 'w+') as fp: 125 | fp.write(','.join(self._keys) + '\n') 126 | 127 | self.__call__(trainer) 128 | 129 | def _update(self, data): 130 | entry = {key: data[key] if key in data else None for key in self._keys} 131 | 132 | # write CSV file 133 | with open(self._log_csv_name, 'a') as fp: 134 | temp_list = [] 135 | for h in self._keys: 136 | if h in data.keys(): 137 | temp_list.append(str(data[h])) 138 | else: 139 | temp_list.append(','.join(' ')) 140 | fp.write(','.join(temp_list) + '\n') 141 | 142 | # write JSON file 143 | self._write_json_log(self._log_json_name, entry) 144 | 145 | def _init_trigger(self, trainer): 146 | return trainer.updater.iteration == 0 147 | 148 | def __call__(self, trainer): 149 | 150 | summary = self._accumulate_observations(trainer) 151 | updater = trainer.updater 152 | 153 | if self._trigger(trainer) or \ 154 | self._init_trigger(trainer): 155 | 156 | # output the result 157 | stats = summary.compute_mean() 158 | stats_cpu = {} 159 | for name, value in six.iteritems(stats): 160 | stats_cpu[name] = float(value) # copy to CPU 161 | 162 | stats_cpu['epoch'] = updater.epoch 163 | stats_cpu['iteration'] = updater.iteration 164 | stats_cpu['elapsed_time'] = trainer.elapsed_time 165 | 166 | if self._postprocess is not None: 167 | self._postprocess(stats_cpu) 168 | 169 | self._log.append(stats_cpu) 170 | 171 | # write to the log file 172 | self._update(stats_cpu) 173 | 174 | # reset the summary for the next output 175 | self._init_summary() 176 | -------------------------------------------------------------------------------- /pytorch_bcnn/extensions/print_report.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import sys 5 | from pytorch_trainer.training import extensions 6 | from pytorch_trainer.training.extensions import log_report as log_report_module 7 | from pytorch_trainer.training.extensions import util 8 | 9 | 10 | class PrintReport(extensions.PrintReport): 11 | 12 | """Trainer extension to print the accumulated results. 13 | 14 | This extension uses the log accumulated by a :class:`LogReport` extension 15 | to print specified entries of the log in a human-readable format. 16 | 17 | Args: 18 | entries (list of str): List of keys of observations to print. 19 | n_step (int): Number of steps to print the log. 20 | log_report (str or LogReport): Log report to accumulate the 21 | observations. This is either the name of a LogReport extensions 22 | registered to the trainer, or a LogReport instance to use 23 | internally. 24 | out: Stream to print the bar. Standard output is used by default. 25 | 26 | """ 27 | 28 | def __init__(self, entries, n_step=1, log_report='LogReport', out=sys.stdout): 29 | 30 | super(PrintReport, self).__init__( 31 | entries, log_report, out) 32 | 33 | self._n_step = n_step 34 | 35 | def __call__(self, trainer): 36 | out = self._out 37 | 38 | if self._header: 39 | out.write(self._header) 40 | self._header = None 41 | 42 | log_report = self._log_report 43 | if isinstance(log_report, str): 44 | log_report = trainer.get_extension(log_report) 45 | elif isinstance(log_report, log_report_module.LogReport): 46 | log_report(trainer) # update the log report 47 | else: 48 | raise TypeError('log report has a wrong type %s' % 49 | type(log_report)) 50 | 51 | log = log_report.log 52 | log_len = self._log_len 53 | while len(log) > log_len: 54 | # delete the printed contents from the current cursor 55 | if os.name == 'nt': 56 | util.erase_console(0, 0) 57 | else: 58 | out.write('\033[J') 59 | self._print(log[log_len]) 60 | log_len += self._n_step 61 | self._log_len = log_len 62 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import loss # NOQA 4 | from . import accuracy # NOQA 5 | from .crop import crop # NOQA 6 | from .stride_pooling import stride_pooling_2d # NOQA 7 | from .stride_pooling import stride_pooling_3d # NOQA 8 | from .stride_pooling import stride_pooling_nd # NOQA 9 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/accuracy/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .discrete_dice import softmax_discrete_dice 4 | from .discrete_jaccard import softmax_discrete_jaccard 5 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/accuracy/discrete_dice.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import six 4 | import numpy as np 5 | import torch 6 | 7 | from ..loss._helper import to_onehot 8 | 9 | def _check_type_forward(x, t): 10 | assert x.shape == t.shape, 'x.shape != t.shape..' 
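# --- Editor's note -------------------------------------------------------
# A minimal usage sketch for `softmax_discrete_dice` (defined below).
# The shapes are hypothetical and assume a multi-class segmentation
# output of shape (batch, n_class, H, W):
#
#     import torch
#     from pytorch_bcnn.functions.accuracy import softmax_discrete_dice
#
#     logits = torch.randn(2, 4, 32, 32)          # raw network outputs
#     labels = torch.randint(0, 4, (2, 32, 32))   # integer class map
#     score = softmax_discrete_dice(logits, labels)  # scalar tensor
# -------------------------------------------------------------------------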
11 | 
12 | 
13 | def _discrete_dice(y, t, n_class, normalize=True,
14 |                    ignore_label=-1, eps=1e-08):
15 |     """ Dice coefficient
16 |     NOTE: This is not a differentiable function.
17 |     See also: ~pytorch_bcnn.functions.loss.dice
18 |     """
19 |     b = y.shape[0]
20 |     c = n_class
21 | 
22 |     t_onehot = to_onehot(t, n_class=c)
23 |     t_onehot = t_onehot.view(b, c, -1)
24 | 
25 |     y_onehot = to_onehot(y, n_class=c)
26 |     y_onehot = y_onehot.view(b, c, -1)
27 | 
28 |     if ignore_label != -1:
29 |         t_onehot = torch.cat((t_onehot[:, :ignore_label], t_onehot[:, ignore_label + 1:]), dim=1)
30 |         y_onehot = torch.cat((y_onehot[:, :ignore_label], y_onehot[:, ignore_label + 1:]), dim=1)
31 | 
32 |     intersection = y_onehot * t_onehot
33 |     cardinality = y_onehot + t_onehot
34 | 
35 |     if normalize:  # NOTE: channel-wise
36 |         intersection = torch.sum(intersection, dim=-1)
37 |         cardinality = torch.sum(cardinality, dim=-1)
38 |         ret = (2. * intersection / (cardinality + eps))
39 |         ret = torch.mean(ret, dim=1)
40 | 
41 |     else:
42 |         intersection = torch.sum(intersection, dim=(0, 2))
43 |         cardinality = torch.sum(cardinality, dim=(0, 2))
44 |         ret = (2. * intersection / (cardinality + eps))
45 | 
46 |     return torch.mean(ret)
47 | 
48 | 
49 | def softmax_discrete_dice(y, t, normalize=True, ignore_label=-1, eps=1e-8):
50 |     """ Dice coefficient with Softmax pre-activates.
51 | 
52 |     Args:
53 |         y (~torch.Tensor): Logits
54 |         t (~torch.Tensor): Ground-truth label
55 |         normalize (bool, optional): If True, calculate the dice coefficients for each class and take the average. Defaults to True.
56 |         ignore_label (int, optional): Defaults to -1.
57 |         eps (float, optional): Defaults to 1e-08.
58 | 
59 |     NOTE: This is not a differentiable function.
60 |     See also: ~pytorch_bcnn.functions.loss.dice
61 |     """
62 |     n_class = y.shape[1]
63 |     y = torch.argmax(y, dim=1)
64 |     return _discrete_dice(y, t, n_class, normalize=normalize,
65 |                           ignore_label=ignore_label,
66 |                           eps=eps)  # NOTE: fixed; `_discrete_dice` returns a tensor, not a callable
67 | 
68 | def sigmoid_discrete_dice(y, t, eps=1e-8):
69 |     raise NotImplementedError()
70 | 
--------------------------------------------------------------------------------
/pytorch_bcnn/functions/accuracy/discrete_jaccard.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import six
4 | import numpy as np
5 | import torch
6 | 
7 | from ..loss._helper import to_onehot
8 | 
9 | def _check_type_forward(x, t):
10 |     assert x.shape == t.shape, 'x.shape != t.shape..'
11 | 
12 | 
13 | def _discrete_jaccard(y, t, n_class, normalize=True,
14 |                       ignore_label=-1, eps=1e-08):
15 |     """ Jaccard index
16 |     NOTE: This is not a differentiable function.
17 |     See also: ~pytorch_bcnn.functions.loss.jaccard
18 |     """
19 |     b = y.shape[0]
20 |     c = n_class
21 | 
22 |     t_onehot = to_onehot(t, n_class=c)
23 |     t_onehot = t_onehot.reshape(b, c, -1)
24 | 
25 |     y_onehot = to_onehot(y, n_class=c)
26 |     y_onehot = y_onehot.reshape(b, c, -1)
27 | 
28 |     if ignore_label != -1:
29 |         t_onehot = torch.cat((t_onehot[:, :ignore_label], t_onehot[:, ignore_label + 1:]), dim=1)
30 |         y_onehot = torch.cat((y_onehot[:, :ignore_label], y_onehot[:, ignore_label + 1:]), dim=1)
31 | 
32 |     intersection = y_onehot * t_onehot
33 |     cardinality = y_onehot + t_onehot
34 | 
35 |     if normalize:  # NOTE: channel-wise
36 |         intersection = torch.sum(intersection, dim=-1)
37 |         cardinality = torch.sum(cardinality, dim=-1)
38 |         union = cardinality - intersection
39 |         ret = intersection / (union + eps)  # NOTE: fixed; the Jaccard index is I/U, without the Dice factor of 2
40 |         ret = torch.mean(ret, dim=1)
41 | 
42 |     else:
43 |         intersection = torch.sum(intersection, dim=(0, 2))
44 |         cardinality = torch.sum(cardinality, dim=(0, 2))
45 |         union = cardinality - intersection
46 |         ret = intersection / (union + eps)  # NOTE: fixed; see above
47 | 
48 |     return torch.mean(ret)
49 | 
50 | 
51 | def softmax_discrete_jaccard(y, t, normalize=True, ignore_label=-1, eps=1e-8):
52 |     """ Jaccard index with Softmax pre-activates.
53 | 
54 |     Args:
55 |         y (~torch.Tensor): Logits
56 |         t (~torch.Tensor): Ground-truth label
57 |         normalize (bool, optional): If True, calculate the jaccard indices for each class and take the average. Defaults to True.
58 |         ignore_label (int, optional): Defaults to -1.
59 |         eps (float, optional): Defaults to 1e-08.
60 | 
61 |     NOTE: This is not a differentiable function.
62 |     See also: ~pytorch_bcnn.functions.loss.jaccard
63 |     """
64 |     n_class = y.shape[1]
65 |     y = torch.argmax(y, dim=1)
66 |     return _discrete_jaccard(y, t, n_class, normalize=normalize,
67 |                              ignore_label=ignore_label,
68 |                              eps=eps)  # NOTE: fixed; `_discrete_jaccard` returns a tensor, not a callable
69 | 
70 | def sigmoid_discrete_jaccard(y, t, eps=1e-8):
71 |     raise NotImplementedError()
72 | 
--------------------------------------------------------------------------------
/pytorch_bcnn/functions/crop.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | def contiguous(func):
4 | 
5 |     def wrap(*args, **kwards):
6 | 
7 |         ret = func(*args, **kwards)
8 | 
9 |         return ret.contiguous()
10 | 
11 |     return wrap
12 | 
13 | 
14 | @contiguous
15 | def crop_2d(x, shape):
16 |     left = (x.shape[2] - shape[2]) // 2
17 |     top = (x.shape[3] - shape[3]) // 2
18 |     right = left + shape[2]
19 |     bottom = top + shape[3]
20 |     assert left >= 0 and top >= 0 and \
21 |         right <= x.shape[2] and bottom <= x.shape[3], \
22 |         'Cropping shape is larger than the input shape.\n'\
23 |         'Input shape:{}, Cropping shape:{}, (L,R,T,B):({},{},{},{})'.format(
24 |             x.shape, shape, left, right, top, bottom)
25 |     return x[:, :, left:right, top:bottom]
26 | 
27 | 
28 | @contiguous
29 | def crop_3d(x, shape):
30 |     left = (x.shape[2] - shape[2]) // 2
31 |     top = (x.shape[3] - shape[3]) // 2
32 |     near = (x.shape[4] - shape[4]) // 2
33 |     right = left + shape[2]
34 |     bottom = top + shape[3]
35 |     far = near + shape[4]
36 |     assert left >= 0 and top >= 0 and near >= 0 and \
37 |         right <= x.shape[2] and bottom <= x.shape[3] and far <= x.shape[4],\
38 |         'Cropping shape is larger than the input shape.\n' \
39 |         'Input shape:{}, Cropping shape:{}, (L,R,T,B,N,F):({},{},{},{},{},{})'.format(
40 |             x.shape, shape, left, right, top, bottom, near, far)
41 |     return x[:, :, left:right, top:bottom, near:far]
42 | 
43 | 
44 | @contiguous
45 | def crop_nd(x, shape):
46 |     slices = [slice(0, x.shape[0]), slice(0, x.shape[1])]
47 |     for n in range(2, x.ndim):
48 |         start = (x.shape[n] - shape[n]) // 2
49 |         end = start + shape[n]
50 |         assert start >= 0 and end <= x.shape[n], \
51 |             'Cropping shape is larger than the input shape.\n' \
52 |             'Dimension: {}, Cropping shape: {}, (Start, End): ({},{})'.format(
53 |                 n, x.shape, start, end)
54 |         slices.append(slice(start, end))
55 |     return x[tuple(slices)]
56 | 
57 | 
58 | def crop(x, shape, ndim=None):
59 |     """ Spatially crop x to the given shape
60 | 
61 |     Args:
62 |         x (ndarray or Variable): Input tensor
63 |         shape (tuple): Desired spatial shape
64 |         ndim (int, optional): Input dimensions. If None, this will be estimated automatically.
65 |             Defaults to None.
66 | 67 | Returns: 68 | ndarray or Variable: Cropped tensor 69 | """ 70 | 71 | if ndim is None: 72 | ndim = x.dim() - 2 73 | 74 | if len(shape) == ndim: 75 | shape = (None, None, ) + tuple(shape) 76 | elif len(shape) == (ndim + 2): 77 | pass 78 | else: 79 | raise ValueError('`len(shape)` must be equal to `x.dim` or `x.dim-2`..') 80 | 81 | if x.shape[2:] == shape[2:]: 82 | return x 83 | 84 | if ndim == 2: 85 | return crop_2d(x, shape) 86 | elif ndim == 3: 87 | return crop_3d(x, shape) 88 | 89 | return crop_nd(x, shape) 90 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .softmax_cross_entropy import softmax_cross_entropy # NOQA 4 | from .sigmoid_cross_entropy import sigmoid_cross_entropy # NOQA 5 | from .sigmoid_soft_cross_entropy import sigmoid_soft_cross_entropy # NOQA 6 | 7 | from .noised_mean_squared_error import noised_mean_squared_error # NOQA 8 | from .noised_cross_entropy import noised_softmax_cross_entropy # NOQA 9 | from .noised_cross_entropy import noised_sigmoid_cross_entropy # NOQA 10 | from .noised_cross_entropy import noised_sigmoid_soft_cross_entropy # NOQA 11 | 12 | from .dice import softmax_dice # NOQA 13 | from .dice import softmax_dice_loss # NOQA 14 | from .jaccard import softmax_jaccard # NOQA 15 | from .jaccard import softmax_jaccard_loss # NOQA 16 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/loss/_helper.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | def to_onehot(t, n_class): 6 | 7 | dtype = t.dtype 8 | device = t.device 9 | 10 | t_onehot = torch.eye(n_class, dtype=dtype, device=device)[t] 11 | 12 | axes = tuple(range(t_onehot.dim())) 13 | axes = (axes[0], axes[-1],) + axes[1:-1] 14 | 15 | return t_onehot.permute(axes).contiguous() 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/loss/dice.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | from ._helper import to_onehot 6 | 7 | def _check_type_forward(x, t): 8 | assert t.dim() == x.dim() - 1, 't.dim() != x.dim() - 1..' 9 | assert x.shape[0] == t.shape[0], 'x.shape[0] != t.shape[0]..' 10 | assert x.shape[2:] == t.shape[1:], 'x.shape[2:] != t.shape[1:]..' 11 | 12 | 13 | def dice(y, t, normalize=True, class_weight=None, 14 | ignore_label=-1, reduce='mean', eps=1e-08): 15 | """ Differentable Dice coefficient. 16 | See: https://arxiv.org/pdf/1606.04797.pdf 17 | 18 | Args: 19 | y (~torch.Tensor): Probability 20 | t (~torch.Tensor): Ground-truth label 21 | normalize (bool, optional): If True, calculate the dice coefficients for each class and take the average. Defaults to True. 22 | class_weight (list or ndarray, optional): Defaults to None. 23 | ignore_label (int, optional): Defaults to -1. 24 | reduce (str, optional): Defaults to 'mean'. 25 | eps (float, optional): Defaults to 1e-08. 
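
    Example (an editor's sketch; shapes are hypothetical):

        >>> import torch
        >>> y = torch.softmax(torch.randn(2, 3, 8, 8), dim=1)  # probabilities
        >>> t = torch.randint(0, 3, (2, 8, 8))                 # integer labels
        >>> score = dice(y, t)  # scalar tensor; 1 - score is the loss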
26 | """ 27 | _check_type_forward(y, t) 28 | 29 | device = y.device 30 | dtype = y.dtype 31 | 32 | if class_weight is not None: 33 | class_weight = torch.as_tensor(class_weight, dtype=dtype, device=device) 34 | 35 | b, c = y.shape[:2] 36 | t_onehot = to_onehot(t, n_class=c) 37 | 38 | y = y.view(b, c, -1) 39 | t_onehot = t_onehot.view(b, c, -1) 40 | 41 | if ignore_label != -1: 42 | t_onehot = torch.cat( (t_onehot[:, :ignore_label], t_onehot[:, ignore_label + 1:]), dim=1) 43 | y = torch.cat( (y[:, :ignore_label], y[:, ignore_label + 1:]), dim=1) 44 | 45 | intersection = y * t_onehot 46 | cardinality = y + t_onehot 47 | 48 | if normalize: # NOTE: channel-wise 49 | intersection = torch.sum(intersection, dim=-1) 50 | cardinality = torch.sum(cardinality, dim=-1) 51 | ret = (2. * intersection / (cardinality + eps)) 52 | if class_weight is not None: 53 | ret *= class_weight 54 | ret = torch.mean(ret, dim=1) 55 | 56 | else: 57 | intersection = torch.sum(intersection, dim=(0, 2)) 58 | cardinality = torch.sum(cardinality, dim=(0, 2)) 59 | ret = (2. * intersection / (cardinality + eps)) 60 | if class_weight is not None: 61 | ret *= class_weight 62 | 63 | if reduce == 'mean': 64 | ret = torch.mean(ret) 65 | else: 66 | raise NotImplementedError('unsupported reduce type..') 67 | 68 | return ret 69 | 70 | 71 | def softmax_dice(y, t, normalize=True, class_weight=None, 72 | ignore_label=-1, reduce='mean', eps=1e-08): 73 | """ Differentable Dice coefficient with Softmax pre-activates. 74 | See: https://arxiv.org/pdf/1606.04797.pdf 75 | 76 | Args: 77 | y (~torch.Tensor): Probability 78 | t (~torch.Tensor): Ground-truth label 79 | normalize (bool, optional): If True, calculate the dice coefficients for each class and take the average. Defaults to True. 80 | class_weight (list or ndarray, optional): Defaults to None. 81 | ignore_label (int, optional): Defaults to -1. 82 | reduce (str, optional): Defaults to 'mean'. 83 | eps (float, optional): Defaults to 1e-08. 84 | """ 85 | y = torch.softmax(y, dim=1) 86 | return dice(y, t, normalize, class_weight, 87 | ignore_label, reduce, eps) 88 | 89 | 90 | def softmax_dice_loss(y, t, normalize=True, class_weight=None, 91 | ignore_label=-1, reduce='mean', eps=1e-08): 92 | """ Differentable Dice-coefficient loss with Softmax pre-activates. 93 | See: https://arxiv.org/pdf/1606.04797.pdf 94 | 95 | Args: 96 | y (~torch.Tensor): Probability 97 | t (~torch.Tensor): Ground-truth label 98 | normalize (bool, optional): If True, calculate the dice coefficients for each class and take the average. Defaults to True. 99 | class_weight (list or ndarray, optional): Defaults to None. 100 | ignore_label (int, optional): Defaults to -1. 101 | reduce (str, optional): Defaults to 'mean'. 102 | eps (float, optional): Defaults to 1e-08. 103 | """ 104 | return 1.0 - softmax_dice(y, t, normalize, class_weight, 105 | ignore_label, reduce, eps) 106 | 107 | 108 | def sigmoid_dice(y, t, *args, **kwards): 109 | raise NotImplementedError() 110 | 111 | 112 | def sigmoid_dice_loss(y, t, *args, **kwards): 113 | raise NotImplementedError() 114 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/loss/jaccard.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | from ._helper import to_onehot 6 | 7 | def _check_type_forward(x, t): 8 | assert t.dim() == x.dim() - 1, 't.dim() != x.dim() - 1..' 
9 |     assert x.shape[0] == t.shape[0], 'x.shape[0] != t.shape[0]..'
10 |     assert x.shape[2:] == t.shape[1:], 'x.shape[2:] != t.shape[1:]..'
11 | 
12 | 
13 | def jaccard(y, t, normalize=True, class_weight=None,
14 |             ignore_label=-1, reduce='mean', eps=1e-08):
15 |     """ Differentiable Jaccard index.
16 | 
17 |     Args:
18 |         y (~torch.Tensor): Probability
19 |         t (~torch.Tensor): Ground-truth label
20 |         normalize (bool, optional): If True, calculate the jaccard indices for each class and take the average. Defaults to True.
21 |         class_weight (list or ndarray, optional): Defaults to None.
22 |         ignore_label (int, optional): Defaults to -1.
23 |         reduce (str, optional): Defaults to 'mean'.
24 |         eps (float, optional): Defaults to 1e-08.
25 |     """
26 |     _check_type_forward(y, t)
27 | 
28 |     device = y.device
29 |     dtype = y.dtype
30 | 
31 |     if class_weight is not None:
32 |         class_weight = torch.as_tensor(class_weight, dtype=dtype, device=device)
33 | 
34 |     b, c = y.shape[:2]
35 |     t_onehot = to_onehot(t, n_class=c)
36 | 
37 |     y = y.view(b, c, -1)
38 |     t_onehot = t_onehot.view(b, c, -1)
39 | 
40 |     if ignore_label != -1:
41 |         t_onehot = torch.cat((t_onehot[:, :ignore_label], t_onehot[:, ignore_label + 1:]), dim=1)
42 |         y = torch.cat((y[:, :ignore_label], y[:, ignore_label + 1:]), dim=1)
43 | 
44 |     intersection = y * t_onehot
45 |     cardinality = y + t_onehot
46 | 
47 |     if normalize:  # NOTE: channel-wise
48 |         intersection = torch.sum(intersection, dim=-1)
49 |         cardinality = torch.sum(cardinality, dim=-1)
50 |         union = cardinality - intersection
51 |         ret = intersection / (union + eps)  # NOTE: fixed; the Jaccard index is I/U, without the Dice factor of 2
52 |         if class_weight is not None:
53 |             ret *= class_weight
54 |         ret = torch.mean(ret, dim=1)
55 | 
56 |     else:
57 |         intersection = torch.sum(intersection, dim=(0, 2))
58 |         cardinality = torch.sum(cardinality, dim=(0, 2))
59 |         union = cardinality - intersection
60 |         ret = intersection / (union + eps)  # NOTE: fixed; see above
61 |         if class_weight is not None:
62 |             ret *= class_weight
63 | 
64 |     if reduce == 'mean':
65 |         ret = torch.mean(ret)
66 |     else:
67 |         raise NotImplementedError('unsupported reduce type..')
68 | 
69 |     return ret
70 | 
71 | 
72 | def softmax_jaccard(y, t, normalize=True, class_weight=None,
73 |                     ignore_label=-1, reduce='mean', eps=1e-08):
74 |     """ Differentiable Jaccard index with Softmax pre-activates.
75 | 
76 |     Args:
77 |         y (~torch.Tensor): Probability
78 |         t (~torch.Tensor): Ground-truth label
79 |         normalize (bool, optional): If True, calculate the jaccard indices for each class and take the average. Defaults to True.
80 |         class_weight (list or ndarray, optional): Defaults to None.
81 |         ignore_label (int, optional): Defaults to -1.
82 |         reduce (str, optional): Defaults to 'mean'.
83 |         eps (float, optional): Defaults to 1e-08.
84 |     """
85 |     y = torch.softmax(y, dim=1)
86 |     return jaccard(y, t, normalize, class_weight,
87 |                    ignore_label, reduce, eps)
88 | 
89 | 
90 | def softmax_jaccard_loss(y, t, normalize=True, class_weight=None,
91 |                          ignore_label=-1, reduce='mean', eps=1e-08):
92 |     """ Differentiable Jaccard-index loss with Softmax pre-activates.
93 | 
94 |     Args:
95 |         y (~torch.Tensor): Probability
96 |         t (~torch.Tensor): Ground-truth label
97 |         normalize (bool, optional): If True, calculate the jaccard indices for each class and take the average. Defaults to True.
98 |         class_weight (list or ndarray, optional): Defaults to None.
99 |         ignore_label (int, optional): Defaults to -1.
100 |         reduce (str, optional): Defaults to 'mean'.
101 |         eps (float, optional): Defaults to 1e-08.
102 | """ 103 | return 1.0 - softmax_jaccard(y, t, normalize, class_weight, 104 | ignore_label, reduce, eps) 105 | 106 | 107 | def sigmoid_jaccard(y, t, *args, **kwards): 108 | raise NotImplementedError() 109 | 110 | 111 | def sigmoid_jaccard_loss(y, t, *args, **kwards): 112 | raise NotImplementedError() 113 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/loss/noised_cross_entropy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | from .softmax_cross_entropy import softmax_cross_entropy 6 | from .sigmoid_cross_entropy import sigmoid_cross_entropy 7 | from .sigmoid_soft_cross_entropy import sigmoid_soft_cross_entropy 8 | 9 | def noised_softmax_cross_entropy(y, t, mc_iteration, 10 | normalize=True, class_weight=None, 11 | ignore_label=-1, reduce='mean'): 12 | """ Softmax Cross-entropy for aleatoric uncertainty estimates. 13 | See: https://arxiv.org/pdf/1703.04977.pdf 14 | 15 | Args: 16 | y (list of ~torch.Tensor): logits and sigma 17 | t (~torch.Tensor): ground-truth 18 | mc_iteration (int): number of iteration of MCMC. 19 | normalize (bool, optional): Defaults to True. 20 | reduce (str, optional): Defaults to 'mean'. 21 | 22 | Returns: 23 | [~torch.Tensor]: Loss value. 24 | """ 25 | 26 | assert isinstance(y, (list, tuple)) 27 | 28 | logits, log_std = y 29 | 30 | assert logits.shape[0] == log_std.shape[0] 31 | assert log_std.shape[1] in (logits.shape[1], 1) 32 | assert logits.shape[2:] == log_std.shape[2:] 33 | 34 | dtype = logits.dtype 35 | device = logits.device 36 | 37 | ret = [] 38 | 39 | # std = torch.sqrt(torch.exp(log_var)) 40 | std = torch.exp(log_std) 41 | 42 | for _ in range(mc_iteration): 43 | noise = std * torch.empty(std.shape, dtype=dtype, device=device).normal_(0., 1.) 44 | loss = softmax_cross_entropy(logits + noise, t, 45 | normalize=normalize, 46 | class_weight=class_weight, 47 | ignore_label=ignore_label, 48 | reduce=reduce) 49 | ret.append(loss[None]) 50 | 51 | ret = torch.cat(ret, dim=0) 52 | 53 | if reduce == 'mean': 54 | return torch.mean(ret) 55 | 56 | return ret 57 | 58 | 59 | def noised_sigmoid_cross_entropy(y, t, mc_iteration, normalize=True, reduce='mean'): 60 | """ Sigmoid Cross-entropy for aleatoric uncertainty estimates. 61 | 62 | Args: 63 | y (list of ~torch.Tensor): logits and sigma 64 | t (~torch.Tensor): ground-truth 65 | mc_iteration (int): number of iteration of MCMC. 66 | normalize (bool, optional): Defaults to True. 67 | reduce (str, optional): Defaults to 'mean'. 68 | 69 | Returns: 70 | [~torch.Tensor]: Loss value. 71 | """ 72 | assert isinstance(y, (list, tuple)) 73 | 74 | logits, log_std = y 75 | 76 | assert logits.shape[0] == log_std.shape[0] 77 | assert log_std.shape[1] in (logits.shape[1], 1) 78 | assert logits.shape[2:] == log_std.shape[2:] 79 | assert logits.shape == t.shape 80 | 81 | dtype = logits.dtype 82 | device = logits.device 83 | 84 | ret = [] 85 | 86 | # std = torch.sqrt(torch.exp(log_var)) 87 | std = torch.exp(log_std) 88 | 89 | for _ in range(mc_iteration): 90 | noise = std * torch.empty(std.shape, dtype=dtype, device=device).normal_(0., 1.) 
91 | loss = sigmoid_cross_entropy(logits + noise, t, 92 | normalize=normalize, 93 | reduce=reduce) 94 | ret.append(loss[None]) 95 | 96 | ret = torch.cat(ret, dim=0) 97 | 98 | if reduce == 'mean': 99 | return torch.mean(ret) 100 | 101 | return ret 102 | 103 | 104 | def noised_sigmoid_soft_cross_entropy(y, t, mc_iteration, normalize=True, reduce='mean'): 105 | """ Sigmoid Soft Cross-entropy for aleatoric uncertainty estimates. 106 | 107 | Args: 108 | y (list of ~torch.Tensor): logits and sigma 109 | t (~torch.Tensor): ground-truth 110 | mc_iteration (int): number of iteration of MCMC. 111 | normalize (bool, optional): Defaults to True. 112 | reduce (str, optional): Defaults to 'mean'. 113 | 114 | Returns: 115 | [~torch.Tensor]: Loss value. 116 | """ 117 | assert isinstance(y, (list, tuple)) 118 | 119 | logits, log_std = y 120 | 121 | assert logits.shape == log_std.shape 122 | assert logits.shape == t.shape 123 | 124 | dtype = logits.dtype 125 | device = logits.device 126 | 127 | ret = [] 128 | 129 | # std = torch.sqrt(torch.exp(log_var)) 130 | std = torch.exp(log_std) 131 | 132 | for _ in range(mc_iteration): 133 | noise = std * torch.empty(std.shape, dtype=dtype, device=device).normal_(0., 1.) 134 | loss = sigmoid_soft_cross_entropy(logits + noise, t, 135 | normalize=normalize, 136 | reduce=reduce) 137 | ret.append(loss[None]) 138 | 139 | ret = torch.cat(ret, dim=0) 140 | 141 | if reduce == 'mean': 142 | return torch.mean(ret) 143 | 144 | return ret 145 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/loss/noised_mean_squared_error.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy 4 | import torch 5 | 6 | def _check_type_forward(logits, log_var, t): 7 | assert logits.dim() == t.dim(), 'logits.dim() != t.dim()..' 8 | assert log_var.dim() == t.dim(), 'log_var.dim() != t.dim()..' 9 | assert logits.shape == t.shape, 'logits.shape != t.shape..' 10 | assert log_var.shape[0] == t.shape[0], 'log_var.shape[0] != t.shape[0]..' 11 | assert log_var.shape[2:] == t.shape[2:], 'log_var.shape[0] != t.shape[0]..' 12 | 13 | 14 | def noised_squared_error(y, t, normalize=False): 15 | 16 | assert isinstance(y, (list,tuple)) 17 | logits, log_var = y 18 | 19 | _check_type_forward(logits, log_var, t) 20 | 21 | loss = torch.exp(- log_var) * (logits - t)**2. + log_var 22 | 23 | if normalize: 24 | count = loss.numel() 25 | else: 26 | count = len(loss) 27 | 28 | loss = torch.sum(loss / count) 29 | 30 | return loss 31 | 32 | 33 | def noised_mean_squared_error(y, t): 34 | """ Mean squared error for aleatoric uncertainty estimates. 35 | See: https://arxiv.org/pdf/1703.04977.pdf 36 | 37 | Args: 38 | y (list of ~torch.Tensor): logits and sigma 39 | t (~torch.Tensor): ground-truth 40 | 41 | Returns: 42 | [~torch.Tensor]: Loss value. 43 | """ 44 | return noised_squared_error(y, t, normalize=True) 45 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/loss/sigmoid_cross_entropy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | _reduce_table = { 8 | 'mean': 'sum', 9 | 'no': 'none', 10 | } 11 | 12 | def _check_type_forward(x, t): 13 | assert x.shape == t.shape, 'x.shape != t.shape..' 
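# --- Editor's note -------------------------------------------------------
# `torch.log1p(torch.exp(x))` in the loss below can overflow for large
# positive logits. A numerically stable drop-in (a sketch, not the
# author's code) expresses the same quantity with softplus:
#
#     import torch.nn.functional as F
#     log1p_exp = F.softplus(x)   # == log(1 + exp(x)), computed stably
#
# The rest of the loss, t * (log1p_exp - x) + (1 - t) * log1p_exp,
# is unchanged.
# -------------------------------------------------------------------------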
14 | 15 | def sigmoid_cross_entropy(x, t, normalize=True, reduce='mean'): 16 | 17 | _check_type_forward(x, t) 18 | 19 | _reduce = _reduce_table[reduce] 20 | 21 | log1p_exp = torch.log1p(torch.exp(x)) 22 | loss = t * (log1p_exp - x) + (1 - t) * log1p_exp 23 | 24 | if _reduce == 'sum': 25 | if normalize: 26 | count = t.numel() 27 | else: 28 | count = len(t) 29 | count = max(count, 1.) 30 | 31 | loss /= count 32 | 33 | return loss 34 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/loss/sigmoid_soft_cross_entropy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | _reduce_table = { 8 | 'mean': 'sum', 9 | 'no': 'none', 10 | } 11 | 12 | def _check_type_forward(x, t): 13 | assert x.shape == t.shape, 'x.shape != t.shape..' 14 | 15 | def sigmoid_soft_cross_entropy(x, t, normalize=True, reduce='mean'): 16 | 17 | _check_type_forward(x, t) 18 | 19 | _reduce = _reduce_table[reduce] 20 | 21 | log1p_exp = torch.log1p(torch.exp(x)) 22 | loss = t * (log1p_exp - x) + (1 - t) * log1p_exp 23 | 24 | if _reduce == 'sum': 25 | if normalize: 26 | count = t.numel() 27 | else: 28 | count = len(t) 29 | count = max(count, 1.) 30 | 31 | loss /= count 32 | 33 | return loss 34 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/loss/softmax_cross_entropy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | _reduce_table = { 8 | 'mean': 'sum', 9 | 'no': 'none', 10 | } 11 | 12 | def _check_type_forward(x, t): 13 | assert t.dim() == x.dim() - 1, 't.dim() != x.dim() - 1..' 14 | assert x.shape[0] == t.shape[0], 'x.shape[0] != t.shape[0]..' 15 | assert x.shape[2:] == t.shape[1:], 'x.shape[2:] != t.shape[1:]..' 16 | 17 | def softmax_cross_entropy(x, t, normalize=True, class_weight=None, 18 | ignore_label=-1, reduce='mean'): 19 | 20 | _check_type_forward(x, t) 21 | 22 | _reduce = _reduce_table[reduce] 23 | 24 | log_p = F.log_softmax(x, dim=1) 25 | loss = F.nll_loss(log_p, t, class_weight, None, ignore_label, None, _reduce) 26 | 27 | if _reduce == 'sum': 28 | if normalize: 29 | count = t.numel() 30 | else: 31 | count = len(t) 32 | count = max(count, 1.) 33 | 34 | loss /= count 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /pytorch_bcnn/functions/stride_pooling.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | def contiguous(func): 4 | 5 | def wrap(*args, **kwards): 6 | 7 | ret = func(*args, **kwards) 8 | 9 | return ret.contiguous() 10 | 11 | return wrap 12 | 13 | 14 | def _pair(x, ndim=2): 15 | if hasattr(x, '__getitem__'): 16 | return x 17 | return [x] * ndim 18 | 19 | 20 | @contiguous 21 | def stride_pooling_2d(x, stride): 22 | stride = _pair(stride, 2) 23 | return x[:, :, ::stride[0], ::stride[1]] 24 | 25 | 26 | @contiguous 27 | def stride_pooling_3d(x, stride): 28 | stride = _pair(stride, 3) 29 | return x[:, :, ::stride[0], ::stride[1], ::stride[2]] 30 | 31 | 32 | def stride_pooling_nd(x, stride): 33 | """ Spatial pooling by stride. 
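    For example (an editor's sketch), a stride of 2 simply keeps every
    second element along each spatial axis:

        >>> import torch
        >>> x = torch.arange(16.).view(1, 1, 4, 4)
        >>> stride_pooling_nd(x, 2).shape
        torch.Size([1, 1, 2, 2])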
34 | 35 | Args: 36 | x (ndarray or Variable): Input tensor 37 | stride (tuple or int): Stride length 38 | 39 | Returns: 40 | ndarray or Variable: Output tensor 41 | """ 42 | 43 | ndim = x.dim() - 2 44 | 45 | if ndim == 2: 46 | return stride_pooling_2d(x, stride) 47 | elif ndim == 3: 48 | return stride_pooling_3d(x, stride) 49 | else: 50 | raise NotImplementedError('unsupported nd pooling..') 51 | -------------------------------------------------------------------------------- /pytorch_bcnn/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .inferencer import Inferencer # NOQA 4 | -------------------------------------------------------------------------------- /pytorch_bcnn/inference/inferencer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | 5 | import torch 6 | from pytorch_trainer.dataset import convert 7 | from pytorch_trainer.dataset import iterator as iterator_module 8 | from pytorch_trainer import iterators 9 | from pytorch_trainer import reporter as reporter_module 10 | from pytorch_trainer.training import extension 11 | 12 | import copy 13 | import six 14 | import tqdm 15 | import sys 16 | import traceback 17 | 18 | 19 | def _concat_arrays(arrays): 20 | """Concat CPU and GPU array 21 | 22 | Args: 23 | arrays (numpy.array or torch.Tensor): CPU or GPU array 24 | """ 25 | # torch 26 | if isinstance(arrays[0], torch.Tensor): 27 | return torch.cat(arrays) 28 | 29 | # numpy 30 | if not isinstance(arrays[0], np.ndarray): 31 | arrays = np.asarray(arrays) 32 | 33 | return np.concatenate(arrays) 34 | 35 | 36 | def _split_predictions(pred): 37 | """split preditions into list of array(s). 38 | Args: 39 | pred (list): A list of preditions. 40 | 41 | Returns: 42 | List of array(s) 43 | """ 44 | if len(pred) == 0: 45 | raise ValueError('prediction is empty') 46 | 47 | first_elem = pred[0] 48 | 49 | if isinstance(first_elem, (tuple, list)): 50 | result = [] 51 | 52 | for i in six.moves.range(len(first_elem)): 53 | result.append(_concat_arrays([example[i] for example in pred])) 54 | 55 | return tuple(result) 56 | 57 | elif isinstance(first_elem, dict): 58 | result = {} 59 | 60 | for key in first_elem: 61 | result[key] = _concat_arrays([example[key] for example in pred]) 62 | 63 | return result 64 | 65 | else: 66 | return _concat_arrays(pred) 67 | 68 | 69 | def _variable_to_array(var, to_numpy=True): 70 | 71 | if isinstance(var, (tuple, list)): 72 | array = var 73 | 74 | if to_numpy: 75 | array = [v.detach().cpu().numpy() for v in array] 76 | 77 | return tuple(array) 78 | 79 | elif isinstance(var, dict): 80 | array = {} 81 | for key, v in var.items(): 82 | if to_numpy: 83 | v = v.detach().cpu().numpy() 84 | array[key] = v 85 | 86 | return array 87 | else: 88 | array = var 89 | 90 | if to_numpy: 91 | array = array.detach().cpu().numpy() 92 | 93 | return array 94 | 95 | 96 | class Inferencer(object): 97 | """ The inferencing loop for PyTorch. 98 | 99 | Args: 100 | iterator: Dataset iterator for the training dataset. It can also be a 101 | dictionary that maps strings to iterators. 102 | If this is just an iterator, then the 103 | iterator is registered by the name ``'main'``. 104 | model: Model to predict outputs. It can also be a dictionary 105 | that maps strings to models. 106 | If this is just an model, then the model is 107 | registered by the name ``'main'``. 
108 | converter (optional): Converter function to build input arrays. Each batch 109 | extracted by the main iterator and the ``device`` option are passed 110 | to this function. :func:`chainer.dataset.concat_examples` is used 111 | by default. 112 | device (int, optional): Device to which the training data is sent. Negative value 113 | indicates the host memory (CPU). Defaults to None. 114 | to_numpy (bool, optional): Allow the PyTorch's output tensor to be converted to Numpy. Defaults to True. 115 | """ 116 | 117 | def __init__(self, iterator, model, 118 | converter=convert.concat_examples, 119 | device=None, to_numpy=True): 120 | 121 | if device is not None: 122 | device = torch.device(device) 123 | 124 | if isinstance(iterator, iterator_module.Iterator): 125 | iterator = {'main': iterator} 126 | self._iterators = iterator 127 | 128 | if not isinstance(model, dict): 129 | model = {'main': model} 130 | self._model = model 131 | 132 | self.observation = {} 133 | reporter = reporter_module.Reporter() 134 | for name, target in six.iteritems(self._model): 135 | reporter.add_observer(name, target) 136 | reporter.add_observers( 137 | name + '/', target.named_children()) 138 | self.reporter = reporter 139 | 140 | self.converter = converter 141 | self.device = device 142 | self.to_numpy = to_numpy 143 | 144 | def get_model(self, name): 145 | return self._model[name] 146 | 147 | def get_iterator(self, name): 148 | return self._iterators[name] 149 | 150 | def predict(self, model, batch): 151 | ret = self.predict_core(model, batch) 152 | return ret 153 | 154 | def predict_core(self, model, batch): 155 | in_arrays = self.converter(batch, self.device) 156 | 157 | for m in self._model.values(): 158 | m.eval() 159 | 160 | with torch.no_grad(): 161 | if isinstance(in_arrays, tuple): 162 | y = model(*in_arrays) 163 | elif isinstance(in_arrays, dict): 164 | y = model(**in_arrays) 165 | else: 166 | y = model(in_arrays) 167 | 168 | return _variable_to_array(y, to_numpy=self.to_numpy) 169 | 170 | def finalize(self): 171 | for iterator in six.itervalues(self._iterators): 172 | iterator.finalize() 173 | 174 | def run(self): 175 | reporter = self.reporter 176 | 177 | iterator = self._iterators['main'] 178 | model = self._model['main'] 179 | 180 | if hasattr(iterator, 'reset'): 181 | iterator.reset() 182 | it = iterator 183 | else: 184 | it = copy.copy(iterator) 185 | 186 | rets = [] 187 | 188 | try: 189 | for batch in tqdm.tqdm(it, desc='inference', 190 | total=len(it.dataset) // it.batch_size, 191 | ncols=80, leave=False): 192 | with reporter.scope(self.observation): 193 | pred = self.predict(model, batch) 194 | rets.append(pred) 195 | 196 | except Exception as e: 197 | print('Exception in main inference loop: {}'.format(e), 198 | file=sys.stderr) 199 | print('Traceback (most recent call last):', file=sys.stderr) 200 | traceback.print_tb(sys.exc_info()[2]) 201 | six.reraise(*sys.exc_info()) 202 | 203 | finally: 204 | pass 205 | 206 | return _split_predictions(rets) 207 | 208 | def __del__(self): 209 | self.finalize() 210 | -------------------------------------------------------------------------------- /pytorch_bcnn/initializers/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .bilinear_upsample import bilinear_upsample # NOQA 4 | -------------------------------------------------------------------------------- /pytorch_bcnn/initializers/bilinear_upsample.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def _kernel_center(ksize): 8 | 9 | center = [None] * len(ksize) 10 | factor = [None] * len(ksize) 11 | 12 | for i, s in enumerate(ksize): 13 | 14 | factor[i] = (s + 1) // 2 15 | if s % 2 == 1: 16 | center[i] = factor[i] - 1 17 | else: 18 | center[i] = factor[i] - 0.5 19 | 20 | return center, factor 21 | 22 | 23 | def _bilinear_kernel_2d(ksize): 24 | """ Get a kernel upsampling by bilinear interpolation 25 | 26 | Args: 27 | ksize (list of int): Kernel size. 28 | 29 | Returns: 30 | numpy.ndarray: A kernel. 31 | 32 | See also: 33 | https://arxiv.org/pdf/1411.4038.pdf 34 | https://github.com/d2l-ai/d2l-en/blob/master/chapter_computer-vision/fcn.md#initialize-the-transposed-convolution-layer 35 | """ 36 | 37 | assert len(ksize) == 2 38 | 39 | og = np.ogrid[:ksize[0], :ksize[1]] 40 | center, factor = _kernel_center(ksize) 41 | 42 | kernel = (1 - abs(og[0] - center[0]) / factor[0]) * \ 43 | (1 - abs(og[1] - center[1]) / factor[1]) 44 | 45 | return kernel 46 | 47 | 48 | def _bilinear_kernel_3d(ksize): 49 | 50 | assert len(ksize) == 3 51 | 52 | og = np.ogrid[:ksize[0], :ksize[1], :ksize[2]] 53 | center, factor = _kernel_center(ksize) 54 | 55 | kernel = (1 - abs(og[0] - center[0]) / factor[0]) * \ 56 | (1 - abs(og[1] - center[1]) / factor[1]) * \ 57 | (1 - abs(og[2] - center[2]) / factor[2]) 58 | 59 | return kernel 60 | 61 | 62 | def _bilinear_kernel_nd(ksize, dtype=np.float32): 63 | 64 | if len(ksize) == 2: 65 | kernel = _bilinear_kernel_2d(ksize) 66 | elif len(ksize) == 3: 67 | kernel = _bilinear_kernel_3d(ksize) 68 | else: 69 | raise NotImplementedError() 70 | 71 | return kernel.astype(dtype) 72 | 73 | 74 | def bilinear_upsample(tensor, gain=1.): 75 | """ Initializer of Bilinear upsampling kernel for convolutional weights. 76 | See also: https://arxiv.org/pdf/1411.4038.pdf 77 | """ 78 | shape = list(tensor.shape) 79 | 80 | dtype = tensor.dtype 81 | device = tensor.device 82 | 83 | if shape[0] != shape[1]: 84 | raise ValueError( 85 | 'The number of input and output channels are NOT same..') 86 | 87 | with torch.no_grad(): 88 | 89 | ksize = shape[2:] 90 | kernel = gain * _bilinear_kernel_nd(ksize) 91 | 92 | weight = np.zeros(shape) 93 | weight[range(shape[0]), range(shape[1]), ...] = kernel 94 | 95 | tensor[...] 
= torch.as_tensor(weight, dtype=dtype, device=device) 96 | -------------------------------------------------------------------------------- /pytorch_bcnn/links/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classifier import Classifier # NOQA 4 | from .regressor import Regressor # NOQA 5 | from .mc_sampler import MCSampler # NOQA 6 | -------------------------------------------------------------------------------- /pytorch_bcnn/links/classifier.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | from pytorch_trainer import reporter 6 | from functools import partial 7 | 8 | def softmax_cross_entropy(y, t): 9 | import torch.nn.functional as F 10 | log_p = F.log_softmax(y, dim=1) 11 | loss = F.nll_loss(log_p, t, reduction='sum') 12 | return loss 13 | 14 | 15 | def accuracy(y, t): 16 | pred = y.argmax(1).reshape(t.shape) 17 | acc = (pred == t).mean(dtype=y.dtype) 18 | return acc 19 | 20 | 21 | def _get_value(args, kwargs, key): 22 | 23 | if not (isinstance(key, (int, str))): 24 | raise TypeError('key must be int or str, but is %s' % 25 | type(key)) 26 | 27 | if isinstance(key, int): 28 | if not (-len(args) <= key < len(args)): 29 | msg = 'key %d is out of bounds' % key 30 | raise ValueError(msg) 31 | value = args[key] 32 | 33 | elif isinstance(key, str): 34 | if key not in kwargs: 35 | msg = 'key "%s" is not found' % key 36 | raise ValueError(msg) 37 | value = kwargs[key] 38 | 39 | return value 40 | 41 | def get_values(args, kwargs, keys): 42 | 43 | getter = partial(_get_value, 44 | args=args, kwargs=kwargs) 45 | 46 | if isinstance(keys, (list,tuple)): 47 | return [getter(key=key) for key in keys] 48 | 49 | return getter(key=keys) 50 | 51 | 52 | class Classifier(nn.Module): 53 | 54 | """A simple classifier model. 55 | This is an example of chain that wraps another chain. It computes the 56 | loss and accuracy based on a given input/label pair. 57 | Args: 58 | predictor (~chainer.Link): Predictor network. 59 | lossfun (callable): 60 | Loss function. 61 | You can specify one of loss functions from 62 | :doc:`built-in loss functions `, or 63 | your own loss function (see the example below). 64 | It should not be an 65 | :doc:`loss functions with parameters ` 66 | (i.e., :class:`~chainer.Link` instance). 67 | The function must accept two argument (an output from predictor 68 | and its ground truth labels), and return a loss. 69 | Returned value must be a Variable derived from the input Variable 70 | to perform backpropagation on the variable. 71 | accfun (callable): 72 | Function that computes accuracy. 73 | You can specify one of evaluation functions from 74 | :doc:`built-in evaluation functions `, or 75 | your own evaluation function. 76 | The signature of the function is the same as ``lossfun``. 77 | activation (callable): 78 | Function that apply final activation functions to preditions. 79 | You can specify one of evaluation functions from 80 | :doc:`built-in activation functions `, or 81 | your own activation function. 82 | x_keys (tuple, int or str): Key to specify input variable from arguments. 83 | When it is ``int``, a variable in positional arguments is used. 84 | And when it is ``str``, a variable in keyword arguments is used. 85 | If you use multiple variables, please specify ``tuple`` of ``int`` or ``str``. 
86 | t_keys (tuple, int or str): Key to specify label variable from arguments. 87 | When it is ``int``, a variable in positional arguments is used. 88 | And when it is ``str``, a variable in keyword arguments is used. 89 | If you use multiple variables, please specify ``tuple`` of ``int`` or ``str``. 90 | 91 | Attributes: 92 | predictor (~chainer.Link): Predictor network. 93 | lossfun (callable): 94 | Loss function. 95 | See the description in the arguments for details. 96 | accfun (callable): 97 | Function that computes accuracy. 98 | See the description in the arguments for details. 99 | activation (callable): 100 | Activation function after the predictor output. 101 | See the description in the arguments for details. 102 | x (~chainer.Variable or tuple): Inputs for the last minibatch. 103 | y (~chainer.Variable or tuple): Predictions for the last minibatch. 104 | t (~chainer.Variable or tuple): Labels for the last minibatch. 105 | loss (~chainer.Variable): Loss value for the last minibatch. 106 | accuracy (~chainer.Variable): Accuracy for the last minibatch. 107 | 108 | .. note:: 109 | This link uses :func:`chainer.softmax_cross_entropy` with 110 | default arguments as a loss function (specified by ``lossfun``), 111 | if users do not explicitly change it. In particular, the loss function 112 | does not support double backpropagation. 113 | If you need second or higher order differentiation, you need to turn 114 | it on with ``enable_double_backprop=True``: 115 | >>> import chainer.functions as F 116 | >>> import chainer.links as L 117 | >>> 118 | >>> def lossfun(x, t): 119 | ... return F.softmax_cross_entropy( 120 | ... x, t, enable_double_backprop=True) 121 | >>> 122 | >>> predictor = L.Linear(10) 123 | >>> model = L.Classifier(predictor, lossfun=lossfun) 124 | """ 125 | 126 | def __init__(self, predictor, 127 | lossfun=nn.CrossEntropyLoss(reduction='mean'), 128 | accfun=accuracy, 129 | activation=None, 130 | x_keys=(0), t_keys=(-1)): 131 | 132 | super(Classifier, self).__init__() 133 | 134 | assert callable(predictor), 'predictor should be callable..' 135 | if lossfun is not None: 136 | assert callable(lossfun), 'lossfun should be callable..' 137 | if accfun is not None: 138 | assert callable(accfun), 'accfun should be callable..' 139 | if activation is not None: 140 | assert callable(activation), 'activation should be callable..' 141 | 142 | 143 | self.add_module('predictor', predictor) 144 | 145 | self.lossfun = lossfun 146 | self.accfun = accfun 147 | self.activation = activation 148 | 149 | self.x_keys = x_keys 150 | self.t_keys = t_keys 151 | 152 | self._reset() 153 | 154 | def _reset(self): 155 | 156 | self.x = None 157 | self.y = None 158 | self.t = None 159 | 160 | self.loss = None 161 | self.accuracy = None 162 | 163 | def forward(self, *args, **kwargs): 164 | """Computes the loss value for input and label pair. 165 | It also computes accuracy and stores it to the attribute. 166 | Args: 167 | args (list of ~chainer.Variable): Input minibatch. 168 | kwargs (dict of ~chainer.Variable): Input minibatch. 169 | When ``label_key`` is ``int``, the corresponding element in ``args`` 170 | is treated as ground truth labels. And when it is ``str``, the 171 | element in ``kwargs`` is used. 172 | The all elements of ``args`` and ``kwargs`` except the ground truth 173 | labels are features. 174 | It feeds features to the predictor and compare the result 175 | with ground truth labels. 176 | .. 
note:: 177 | We set ``None`` to the attributes ``y``, ``loss`` and ``accuracy`` 178 | each time before running the predictor, to avoid unnecessary memory 179 | consumption. Note that the variables set on those attributes hold 180 | the whole computation graph when they are computed. The graph 181 | stores interim values on memory required for back-propagation. 182 | We need to clear the attributes to free those values. 183 | Returns: 184 | ~chainer.Variable: Loss value. 185 | """ 186 | 187 | self._reset() 188 | 189 | n_args = len(args) + len(kwargs) 190 | x = get_values(args, kwargs, self.x_keys) 191 | t = get_values(args, kwargs, self.t_keys) if n_args > 1 else None 192 | 193 | # predict, and then apply final activation 194 | y = self.predictor(x) 195 | 196 | if self.activation is not None: 197 | y = self.activation(y) 198 | 199 | # preserve 200 | self.x = x 201 | self.y = y 202 | self.t = t 203 | 204 | 205 | # if only input `x` is exist, return the predictions 206 | if t is None: 207 | return y 208 | 209 | # if ground-truth label `t` is exist, evaluate the loss and accuracy. 210 | # return the loss during training, otherwise return the predictions. 211 | if self.lossfun is not None: 212 | self.loss = self.lossfun(y, t) 213 | reporter.report({'loss': self.loss}, self) 214 | 215 | if self.accfun is not None: 216 | self.accuracy = self.accfun(y, t) 217 | reporter.report({'accuracy': self.accuracy}, self) 218 | 219 | if self.training: 220 | 221 | if self.loss is None: 222 | raise ValueError('loss is None..') 223 | 224 | return self.loss 225 | 226 | else: 227 | return self.y 228 | -------------------------------------------------------------------------------- /pytorch_bcnn/links/connection/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .pixel_shuffle_upsampler import PixelShuffleUpsampler2D # NOQA 4 | from .pixel_shuffle_upsampler import PixelShuffleUpsampler3D # NOQA 5 | -------------------------------------------------------------------------------- /pytorch_bcnn/links/connection/pixel_shuffle_upsampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class PixelShuffleUpsampler2D(nn.Conv2d): 9 | """Pixel Shuffler for the super resolution. 10 | This upsampler is effective upsampling method compared with the deconvolution. 11 | The deconvolution has a problem of the checkerboard artifact. 12 | A detail of this problem shows the following. 
13 | http://distill.pub/2016/deconv-checkerboard/ 14 | 15 | See also: 16 | https://arxiv.org/abs/1609.05158 17 | """ 18 | ndim = 2 19 | 20 | def __init__(self, in_channels, out_channels, resolution, kernel_size, stride=1, 21 | padding=0, dilation=1, groups=1, 22 | bias=True, padding_mode='zeros'): 23 | 24 | m = resolution ** self.ndim 25 | 26 | super(PixelShuffleUpsampler2D, self).__init__( 27 | in_channels, out_channels * m, kernel_size, stride, 28 | padding, dilation, groups, bias, padding_mode) 29 | 30 | self.resolution = resolution 31 | self.out_channels = out_channels 32 | 33 | def extra_repr(self): 34 | s = ('{in_channels}, {out_channels}, resolution={resolution}' 35 | ', kernel_size={kernel_size}, stride={stride}') 36 | if self.padding != (0,) * len(self.padding): 37 | s += ', padding={padding}' 38 | if self.dilation != (1,) * len(self.dilation): 39 | s += ', dilation={dilation}' 40 | if self.output_padding != (0,) * len(self.output_padding): 41 | s += ', output_padding={output_padding}' 42 | if self.groups != 1: 43 | s += ', groups={groups}' 44 | if self.bias is None: 45 | s += ', bias=False' 46 | if self.padding_mode != 'zeros': 47 | s += ', padding_mode={padding_mode}' 48 | return s.format(**self.__dict__) 49 | 50 | def forward(self, x): 51 | r = self.resolution 52 | out = super().forward(x) 53 | batchsize = out.shape[0] 54 | in_channels = out.shape[1] 55 | out_channels = self.out_channels 56 | 57 | in_shape = out.shape[2:] 58 | out_shape = tuple(s * r for s in in_shape) 59 | 60 | r_tuple = tuple(self.resolution for _ in range(self.ndim)) 61 | out = out.view((batchsize, out_channels,) + r_tuple + in_shape) 62 | out = out.permute(self.make_transpose_indices()).contiguous() 63 | out = out.view((batchsize, out_channels, ) + out_shape) 64 | return out 65 | 66 | def make_transpose_indices(self): 67 | si = [0, 1] 68 | si.extend([2 * (i + 1) + 1 for i in range(self.ndim)]) 69 | si.extend([2 * (i + 1) for i in range(self.ndim)]) 70 | return si 71 | 72 | 73 | class PixelShuffleUpsampler3D(nn.Conv3d): 74 | """Pixel Shuffler for the super resolution. 75 | This upsampler is effective upsampling method compared with the deconvolution. 76 | The deconvolution has a problem of the checkerboard artifact. 77 | A detail of this problem shows the following. 
78 | http://distill.pub/2016/deconv-checkerboard/ 79 | 80 | See also: 81 | https://arxiv.org/abs/1609.05158 82 | """ 83 | ndim = 3 84 | 85 | def __init__(self, in_channels, out_channels, resolution, kernel_size, stride=1, 86 | padding=0, dilation=1, groups=1, 87 | bias=True, padding_mode='zeros'): 88 | 89 | m = resolution ** self.ndim 90 | 91 | super(PixelShuffleUpsampler3D, self).__init__( 92 | in_channels, out_channels * m, kernel_size, stride, 93 | padding, dilation, groups, bias, padding_mode) 94 | 95 | self.resolution = resolution 96 | self.out_channels = out_channels 97 | 98 | def extra_repr(self): 99 | s = ('{in_channels}, {out_channels}, resolution={resolution}' 100 | ', kernel_size={kernel_size}, stride={stride}') 101 | if self.padding != (0,) * len(self.padding): 102 | s += ', padding={padding}' 103 | if self.dilation != (1,) * len(self.dilation): 104 | s += ', dilation={dilation}' 105 | if self.output_padding != (0,) * len(self.output_padding): 106 | s += ', output_padding={output_padding}' 107 | if self.groups != 1: 108 | s += ', groups={groups}' 109 | if self.bias is None: 110 | s += ', bias=False' 111 | if self.padding_mode != 'zeros': 112 | s += ', padding_mode={padding_mode}' 113 | return s.format(**self.__dict__) 114 | 115 | def forward(self, x): 116 | r = self.resolution 117 | out = super().forward(x) 118 | batchsize = out.shape[0] 119 | in_channels = out.shape[1] 120 | out_channels = self.out_channels 121 | 122 | in_shape = out.shape[2:] 123 | out_shape = tuple(s * r for s in in_shape) 124 | 125 | r_tuple = tuple(self.resolution for _ in range(self.ndim)) 126 | out = out.view((batchsize, out_channels,) + r_tuple + in_shape) 127 | out = out.permute(self.make_transpose_indices()).contiguous() 128 | out = out.view((batchsize, out_channels, ) + out_shape) 129 | return out 130 | 131 | def make_transpose_indices(self): 132 | si = [0, 1] 133 | si.extend([2 * (i + 1) + 1 for i in range(self.ndim)]) 134 | si.extend([2 * (i + 1) for i in range(self.ndim)]) 135 | return si 136 | -------------------------------------------------------------------------------- /pytorch_bcnn/links/mc_sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import six 4 | import warnings 5 | from functools import partial 6 | from itertools import starmap 7 | from itertools import chain 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .classifier import get_values 14 | 15 | def _concat_variables(arrays): 16 | 17 | return torch.cat([array[None] \ 18 | for array in arrays], dim=0) 19 | 20 | def _concat_samples(samples): 21 | 22 | first_elem = samples[0] 23 | 24 | if isinstance(first_elem, (tuple, list)): 25 | result = [] 26 | 27 | for i in six.moves.range(len(first_elem)): 28 | result.append(_concat_variables( 29 | [example[i] for example in samples])) 30 | 31 | return list(result) 32 | 33 | else: 34 | return _concat_variables(samples) 35 | 36 | 37 | def _predict(samples, mode='variance', 38 | reduce_mean=None, reduce_var=None, 39 | eps=1e-8): 40 | 41 | mean = torch.mean(samples, dim=0) 42 | 43 | if mode == 'variance': 44 | var = torch.var(samples, dim=0) 45 | elif mode == 'entropy': 46 | var = - mean * torch.log2(mean + eps) 47 | else: 48 | raise NotImplementedError('unsupported mode..') 49 | 50 | if reduce_mean is not None: 51 | mean = reduce_mean(mean) 52 | 53 | if reduce_var is not None: 54 | var = reduce_var(var) 55 | 56 | return mean, var 57 | 58 | 59 | class 
MCSampler(nn.Module): 60 | """ Monte Carlo estimation to approximate the predictive distribution. 61 | The predictive variance serves as a metric of uncertainty. 62 | 63 | Args: 64 | predictor (~torch.nn.Module): Predictor network. 65 | mc_iteration (int): Number of Monte Carlo iterations, i.e., stochastic forward passes per input. 66 | activation (list or callable, optional): Activation function. If the predictor returns multiple outputs, 67 | this must be a list of activations. Defaults to partial(torch.softmax, dim=1). 68 | reduce_mean (list or callable, optional): Reduce function for the mean tensor. If the predictor returns multiple outputs, 69 | this must be a list of callable functions. Defaults to partial(torch.argmax, dim=1). 70 | reduce_var (list or callable, optional): Reduce function for the variance tensor. If the predictor returns multiple outputs, 71 | this must be a list of callable functions. Defaults to partial(torch.mean, dim=1). 72 | mode (str, optional): Method for calculating uncertainty, one of ``{'variance', 'entropy'}``. Defaults to 'variance'. 73 | eps (float, optional): Epsilon value for numerical stability. Defaults to 1e-8. 74 | x_keys (tuple, int or str, optional): Key to specify the input variable from the arguments. 75 | When it is ``int``, a variable in the positional arguments is used. 76 | And when it is ``str``, a variable in the keyword arguments is used. 77 | If you use multiple variables, please specify a ``tuple`` of ``int`` or ``str``. Defaults to (0). 78 | 79 | See also: https://arxiv.org/pdf/1506.02142.pdf 80 | https://arxiv.org/pdf/1511.02680.pdf 81 | """ 82 | def __init__(self, 83 | predictor, 84 | mc_iteration, 85 | activation=partial(torch.softmax, dim=1), 86 | reduce_mean=partial(torch.argmax, dim=1), 87 | reduce_var=partial(torch.mean, dim=1), 88 | mode='variance', 89 | eps=1e-8, 90 | x_keys=(0), 91 | ): 92 | super(MCSampler, self).__init__() 93 | 94 | assert callable(predictor), 'predictor should be callable..'
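        # NOTE (editor's sketch): intended usage -- wrap a predictor that
        # contains MCDropout layers, switch to eval mode, and a single call
        # then yields the predictive mean and an uncertainty map:
        #
        #   model = MCSampler(predictor, mc_iteration=10)
        #   model.eval()
        #   mean, var = model(x)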
95 | 96 | self.add_module('predictor', predictor) 97 | 98 | self.activation = activation 99 | self.mc_iteration = mc_iteration 100 | self.reduce_mean = reduce_mean 101 | self.reduce_var = reduce_var 102 | self.mode = mode 103 | self.eps = eps 104 | 105 | self.x_keys = x_keys 106 | 107 | 108 | def forward(self, *args, **kwargs): 109 | 110 | if self.training: 111 | warnings.warn('Monte Carlo sampling is skipped during the training phase; returning the raw predictor output..') 112 | return self.predictor(*args, **kwargs) 113 | 114 | 115 | x = get_values(args, kwargs, self.x_keys) 116 | 117 | # Monte Carlo sampling with stochastic forward passes 118 | mc_samples = [] 119 | activation = self.activation 120 | 121 | for _ in range(self.mc_iteration): 122 | 123 | logits = self.predictor(x) 124 | 125 | if activation is None: 126 | y = logits 127 | elif isinstance(logits, (list,tuple)): 128 | assert isinstance(activation, (list,tuple)) 129 | assert len(logits) == len(activation) 130 | 131 | y = list(starmap(lambda f, x: f(x), \ 132 | zip(activation, logits))) 133 | else: 134 | y = activation(logits) 135 | 136 | mc_samples.append(y) 137 | 138 | mc_samples = _concat_samples(mc_samples) 139 | 140 | 141 | # uncertainty estimates 142 | reduce_mean = self.reduce_mean 143 | reduce_var = self.reduce_var 144 | 145 | if isinstance(mc_samples, list): 146 | if reduce_mean is None: 147 | reduce_mean = [None] * len(mc_samples) 148 | if reduce_var is None: 149 | reduce_var = [None] * len(mc_samples) 150 | 151 | assert isinstance(reduce_mean, (list,tuple)) 152 | assert isinstance(reduce_var, (list,tuple)) 153 | 154 | ret = list(starmap(lambda _samples, _reduce_m, _reduce_v: 155 | _predict(_samples, self.mode, 156 | _reduce_m, _reduce_v, 157 | self.eps), 158 | zip(mc_samples, reduce_mean, reduce_var))) 159 | 160 | ret = list(chain.from_iterable(ret)) 161 | 162 | else: 163 | ret = _predict(mc_samples, self.mode, 164 | self.reduce_mean, self.reduce_var, 165 | self.eps) 166 | 167 | return ret 168 | -------------------------------------------------------------------------------- /pytorch_bcnn/links/noise/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .mc_dropout import MCDropout # NOQA 4 | -------------------------------------------------------------------------------- /pytorch_bcnn/links/noise/mc_dropout.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class MCDropout(nn.Dropout): 8 | """ 9 | Drops elements of the input tensor randomly. 10 | This module drops input elements randomly with probability ``p`` and 11 | scales the remaining elements by factor ``1 / (1 - p)``. 12 | Args: 13 | p: probability of an element to be zeroed. Default: 0.5 14 | inplace: If set to ``True``, will do this operation in-place. Default: ``False`` 15 | 16 | Shape: 17 | - Input: :math:`(*)`. Input can be of any shape 18 | - Output: :math:`(*)`. Output is of the same shape as input 19 | 20 | Examples:: 21 | 22 | >>> m = MCDropout(p=0.2) 23 | >>> input = torch.randn(20, 16) 24 | >>> output = m(input) 25 | 26 | See the paper by Y. Gal and Z. Ghahramani: `Dropout as a Bayesian approximation: \ 27 | Representing model uncertainty in deep learning \ 28 | <https://arxiv.org/abs/1506.02142>`_. 29 | 30 | See also: A. Kendall et al.: `Bayesian SegNet: Model Uncertainty \ 31 | in Deep Convolutional Encoder-Decoder Architectures for Scene Understanding \ 32 | <https://arxiv.org/abs/1511.02680>`_.
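    Unlike ``nn.Dropout``, dropout here stays active in ``eval()`` mode,
    so repeated forward passes remain stochastic (a minimal sketch)::

        >>> m = MCDropout(p=0.5)
        >>> m.eval()
        >>> x = torch.ones(4, 8)
        >>> torch.equal(m(x), m(x))  # typically False: masks are resampled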
33 | """ 34 | 35 | def forward(self, input): 36 | return F.dropout(input, self.p, True, self.inplace)  # `training` flag is hard-coded to True, so dropout stays active at inference time 37 | -------------------------------------------------------------------------------- /pytorch_bcnn/links/regressor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch.nn as nn 4 | 5 | from .classifier import Classifier 6 | 7 | class Regressor(Classifier): 8 | """ A simple regressor model. 9 | It computes the loss and accuracy based on a given input/label pair. 10 | 11 | Args: 12 | predictor (~torch.nn.Module): Predictor network. 13 | lossfun (callable): 14 | Loss function. 15 | You can specify one of the loss functions from 16 | :doc:`built-in loss functions `, or 17 | your own loss function (see the example below). 18 | A loss module such as 19 | :class:`torch.nn.MSELoss` (the default) 20 | is also accepted. 21 | The function must accept two arguments (an output from the predictor 22 | and its ground-truth labels), and return a loss. 23 | The returned value must be a tensor derived from the input tensor 24 | so that backpropagation can be performed on it. 25 | accfun (callable): 26 | Function that computes accuracy. 27 | You can specify one of the evaluation functions from 28 | :doc:`built-in evaluation functions `, or 29 | your own evaluation function. 30 | The signature of the function is the same as ``lossfun``. 31 | activation (callable): 32 | Function that applies a final activation to the predictions. 33 | You can specify one of the activation functions from 34 | :doc:`built-in activation functions `, or 35 | your own activation function. 36 | x_keys (tuple, int or str): Key to specify the input variable from the arguments. 37 | When it is ``int``, a variable in the positional arguments is used. 38 | And when it is ``str``, a variable in the keyword arguments is used. 39 | If you use multiple variables, please specify a ``tuple`` of ``int`` or ``str``. 40 | t_keys (tuple, int or str): Key to specify the label variable from the arguments. 41 | When it is ``int``, a variable in the positional arguments is used. 42 | And when it is ``str``, a variable in the keyword arguments is used. 43 | If you use multiple variables, please specify a ``tuple`` of ``int`` or ``str``. 44 | 45 | Attributes: 46 | predictor (~torch.nn.Module): Predictor network. 47 | lossfun (callable): 48 | Loss function. 49 | See the description in the arguments for details. 50 | accfun (callable): 51 | Function that computes accuracy. 52 | See the description in the arguments for details. 53 | activation (callable): 54 | Activation function after the predictor output. 55 | See the description in the arguments for details. 56 | x (~torch.Tensor or tuple): Inputs for the last minibatch. 57 | y (~torch.Tensor or tuple): Predictions for the last minibatch. 58 | t (~torch.Tensor or tuple): Labels for the last minibatch. 59 | loss (~torch.Tensor): Loss value for the last minibatch. 60 | accuracy (~torch.Tensor): Accuracy for the last minibatch.
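    Example (an editor's sketch, assuming a toy linear predictor)::

        >>> import torch, torch.nn as nn
        >>> model = Regressor(nn.Linear(4, 1))
        >>> x, t = torch.randn(8, 4), torch.randn(8, 1)
        >>> loss = model(x, t)   # training mode: returns the MSE loss
        >>> model.eval()
        >>> y = model(x)         # inference mode: returns the predictions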
61 | 62 | See also: ~chainer_bcnn.links.Classifier 63 | """ 64 | 65 | def __init__(self, predictor, 66 | lossfun=nn.MSELoss(reduction='mean'), 67 | accfun=nn.L1Loss(reduction='mean'), 68 | activation=None, 69 | x_keys=(0), t_keys=(-1)): 70 | 71 | super(Regressor, self).__init__( 72 | predictor, lossfun, accfun, activation, 73 | x_keys, t_keys 74 | ) 75 | -------------------------------------------------------------------------------- /pytorch_bcnn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from abc import ABCMeta, abstractmethod 5 | import json 6 | 7 | def _len_children(chain): 8 | if not hasattr(chain, 'children'): 9 | return None 10 | return len([l for l in chain.children()]) 11 | 12 | def _freeze_layers(chain, name=None, startwith=None, endwith=None, 13 | recursive=True, verbose=False): 14 | 15 | for l_name, l in chain.named_children(): 16 | 17 | flag = False 18 | if name is not None: 19 | flag = flag or (l_name == name) 20 | if startwith is not None: 21 | flag = flag or l_name.startswith(startwith) 22 | if endwith is not None: 23 | flag = flag or l_name.endswith(endwith) 24 | 25 | if flag: 26 | l = getattr(chain, l_name) 27 | l.requires_grad_(False) 28 | if verbose == True: 29 | print('disabled update:', l_name) 30 | 31 | if recursive and hasattr(l, 'children'): 32 | _freeze_layers(l, name, 33 | startwith, endwith, 34 | recursive, verbose) 35 | 36 | def _show_statistics(chain): 37 | 38 | def _show_statistics_depth(chain, depth): 39 | 40 | depth += 1 41 | 42 | for name, l in chain.named_children(): 43 | l = getattr(chain, name) 44 | print('--'*depth, name) 45 | 46 | if hasattr(l, 'children'): 47 | _show_statistics_depth(l, depth) 48 | 49 | if not hasattr(chain, 'children') or _len_children(chain) == 0: 50 | 51 | # parameters 52 | print(' '*depth, '(params)') 53 | for name, p in chain.named_parameters(): 54 | summary = [' '*depth + ' %s:' % name] 55 | if p.data is not None: 56 | summary.append('%.3e +- %.3e' % ((p.data.mean()), (p.data.std()))) 57 | summary.append(list(p.data.shape)) 58 | if hasattr(p, 'requires_grad'): 59 | if not p.requires_grad: summary.append('freeze') 60 | else: 61 | summary.append(None) 62 | print(*summary) 63 | 64 | 65 | for name, l in chain.named_children(): 66 | print(name) 67 | _show_statistics_depth(l, depth=0) 68 | 69 | 70 | class Model(torch.nn.Module, metaclass=ABCMeta): 71 | """ Base class of Models (e.g., U-Net) 72 | """ 73 | 74 | def freeze_layers(self, name=None, 75 | startwith=None, endwith=None, 76 | recursive=True, verbose=False): 77 | _freeze_layers(self, name, 78 | startwith, endwith, 79 | recursive, verbose) 80 | 81 | def show_statistics(self): 82 | _show_statistics(self) 83 | 84 | def count_params(self): 85 | return sum(p.numel() for p in self.parameters()) 86 | 87 | def count_trainable_params(self): 88 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 89 | 90 | def count_freezed_params(self): 91 | return sum(p.numel() for p in self.parameters() if not p.requires_grad) 92 | 93 | def save_args(self, out): 94 | args = self._args.copy() 95 | ignore_keys = ['__class__', 'self'] 96 | for key in ignore_keys: 97 | if key in args.keys(): 98 | args.pop(key) 99 | 100 | with open(out, 'w', encoding='utf-8') as f: 101 | json.dump(args, f, ensure_ascii=False, indent=4) 102 | 103 | def __getitem__(self, name): 104 | return getattr(self, name) 105 | 106 | @abstractmethod 107 | def forward(self, x, **kwargs): 108 | ''' 
109 | Args: 110 | x (~chainer.Variable) 111 | kwargs: Optional arguments will be contained. 112 | Return: 113 | o (~chainer.Variable) 114 | ''' 115 | raise NotImplementedError() 116 | 117 | 118 | from .unet import UNetBase # NOQA 119 | from .unet import UNet # NOQA 120 | from .unet import BayesianUNet # NOQA 121 | 122 | from .discriminators import DiscriminatorBase # NOQA 123 | from .discriminators import PatchDiscriminator # NOQA 124 | -------------------------------------------------------------------------------- /pytorch_bcnn/models/discriminators/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .discriminator_base import DiscriminatorBase 4 | from .patch_discriminator import PatchDiscriminator 5 | 6 | _supported_models = { 7 | 'discriminator_base': DiscriminatorBase, 8 | 'patch_discriminator': PatchDiscriminator, 9 | } 10 | -------------------------------------------------------------------------------- /pytorch_bcnn/models/discriminators/discriminator_base.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import copy 4 | import warnings 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .. import Model 10 | from ..unet.unet_base import UNetBaseBlock 11 | from ..unet._helper import conv, _default_conv_param 12 | from ..unet._helper import norm, _default_norm_param 13 | from ..unet._helper import pool, _default_pool_param 14 | from ..unet._helper import activation, _default_activation_param 15 | from ..unet._helper import dropout, _default_dropout_param 16 | from ..unet._helper import initializer 17 | from ...functions import crop 18 | 19 | 20 | class Block(UNetBaseBlock): 21 | """ Convolution blocks """ 22 | pass 23 | 24 | 25 | class DiscriminatorBase(Model): 26 | """ Base class of discriminator 27 | 28 | Args: 29 | ndim (int): Number of spatial dimensions. 30 | in_channels (int): Number of input channels. 31 | nlayer (int, optional): Number of layers. 32 | Defaults to 4. 33 | nfilter (list or int, optional): Number of filters. 34 | Defaults to 64. 35 | ninner (list or int, optional): Number of layers in UNetBlock. 36 | Defaults to 1. 37 | conv_param (dict, optional): Hyperparameter of convolution layer. 38 | Defaults to {'name':'conv', 'ksize': 3, 'stride': 1, 'pad': 1, 39 | 'initialW': {'name': 'he_normal', 'scale': 1.0}, 'initial_bias': {'name': 'zero'}}. 40 | pool_param (dict, optional): Hyperparameter of pooling layer. 41 | Defaults to {'name': 'max', 'ksize': 2, 'stride': 2}. 42 | norm_param (dict or None, optional): Hyperparameter of normalization layer. 43 | Defaults to {'name': 'batch'}. 44 | activation_param (dict, optional): Hyperparameter of activation layer. 45 | Defaults to {'name': 'relu'}. 46 | dropout_param (dict or None, optional): Hyperparameter of dropout layer. 47 | Defaults to {'name': 'dropout', 'ratio': .5}. 48 | dropout_enables (list or tuple, optional): Set whether to apply dropout for each layer. 49 | If None, apply the dropout in all layers. 50 | Defaults to None. 51 | residual (bool, optional): Enable the residual learning. 52 | Defaults to False. 53 | preserve_color (bool, optional): If True, the normalization will be discarded in the first layer. 54 | Defaults to False. 
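    Example (an editor's sketch with the default settings)::

        >>> import torch
        >>> dis = DiscriminatorBase(ndim=2, in_channels=3)
        >>> h = dis(torch.randn(1, 3, 64, 64))
        >>> h.shape   # 4 layers, pooled 3 times: 64 -> 8
        torch.Size([1, 512, 8, 8])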
55 | 56 | See: https://arxiv.org/pdf/1406.2661.pdf 57 | """ 58 | def __init__(self, 59 | ndim, 60 | in_channels, 61 | nlayer=4, 62 | nfilter=64, 63 | ninner=1, 64 | conv_param=_default_conv_param, 65 | pool_param=_default_pool_param, 66 | norm_param=_default_norm_param, 67 | activation_param=_default_activation_param, 68 | dropout_param=_default_dropout_param, 69 | dropout_enables=None, 70 | residual=False, 71 | preserve_color=False 72 | ): 73 | 74 | super(DiscriminatorBase, self).__init__() 75 | 76 | self._args = locals() 77 | 78 | self._ndim = ndim 79 | self._nlayer = nlayer 80 | 81 | if isinstance(nfilter, int): 82 | nfilter = [nfilter*(2**i) for i in range(nlayer)] 83 | assert len(nfilter) == nlayer 84 | self._nfilter = nfilter 85 | 86 | if isinstance(ninner, int): 87 | ninner = [ninner]*nlayer 88 | assert len(ninner) == nlayer 89 | self._ninner = ninner 90 | 91 | self._conv_param = conv_param 92 | self._pool_param = pool_param 93 | self._norm_param = norm_param 94 | self._activation_param = activation_param, 95 | self._dropout_param = dropout_param 96 | 97 | if dropout_enables is None: 98 | dropout_enables = [True]*nlayer 99 | assert isinstance(dropout_enables, (list,tuple)) 100 | self._dropout_enables = dropout_enables 101 | 102 | self._residual = residual 103 | self._preserve_color = preserve_color 104 | 105 | self._pool = pool(ndim, pool_param) 106 | self._activation = activation(activation_param) 107 | self._dropout = dropout(dropout_param) 108 | 109 | # down 110 | for i in range(nlayer): 111 | 112 | self.add_module('block_%d' % i, 113 | Block(ndim, 114 | in_channels if i == 0 else nfilter[i-1], 115 | nfilter[i], 116 | conv_param, 117 | None if preserve_color and i == 0 else norm_param, 118 | activation_param, 119 | ninner[i], 120 | residual)) 121 | 122 | def forward(self, x): 123 | 124 | h = x 125 | 126 | # down 127 | for i in range(self._nlayer): 128 | 129 | if i != 0: 130 | h = self._pool(h) 131 | 132 | h = self['block_%d' % (i)](h) 133 | 134 | if self._dropout_enables[i]: 135 | h = self._dropout(h) 136 | 137 | return h 138 | 139 | -------------------------------------------------------------------------------- /pytorch_bcnn/models/discriminators/patch_discriminator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .discriminator_base import DiscriminatorBase 4 | from .discriminator_base import conv, crop 5 | from .discriminator_base import _default_conv_param 6 | from .discriminator_base import _default_pool_param 7 | from .discriminator_base import _default_norm_param 8 | from .discriminator_base import _default_activation_param 9 | from .discriminator_base import _default_dropout_param 10 | 11 | 12 | class PatchDiscriminator(DiscriminatorBase): 13 | """ Patch based discriminator (Markovian discriminator) 14 | 15 | Args: 16 | ndim (int): Number of spatial dimensions. 17 | in_channels (int): Number of input channels. 18 | out_channels (int): Number of output channels. 19 | nlayer (int, optional): Number of layers. 20 | Defaults to 4. 21 | nfilter (list or int, optional): Number of filters. 22 | Defaults to 64. 23 | ninner (list or int, optional): Number of layers in UNetBlock. 24 | Defaults to 1. 25 | conv_param (dict, optional): Hyperparameter of convolution layer. 26 | Defaults to {'name':'conv', 'ksize': 3, 'stride': 1, 'pad': 1, 27 | 'initialW': {'name': 'he_normal', 'scale': 1.0}, 'initial_bias': {'name': 'zero'}}. 28 | pool_param (dict, optional): Hyperparameter of pooling layer. 
29 | Defaults to {'name': 'max', 'ksize': 2, 'stride': 2}. 30 | norm_param (dict or None, optional): Hyperparameter of normalization layer. 31 | Defaults to {'name': 'batch'}. 32 | activation_param (dict, optional): Hyperparameter of activation layer. 33 | Defaults to {'name': 'relu'}. 34 | dropout_param (dict or None, optional): Hyperparameter of dropout layer. 35 | Defaults to {'name': 'dropout', 'ratio': .5}. 36 | dropout_enables (list or tuple, optional): Set whether to apply dropout for each layer. 37 | If None, apply the dropout in all layers. 38 | Defaults to None. 39 | residual (bool, optional): Enable the residual learning. 40 | Defaults to False. 41 | preserve_color (bool, optional): If True, the normalization will be discarded in the first layer. 42 | Defaults to False. 43 | 44 | See: https://arxiv.org/pdf/1611.07004.pdf 45 | """ 46 | def __init__(self, 47 | ndim, 48 | in_channels, 49 | out_channels, 50 | nlayer=4, 51 | nfilter=64, 52 | ninner=1, 53 | conv_param=_default_conv_param, 54 | pool_param=_default_pool_param, 55 | norm_param=_default_norm_param, 56 | activation_param=_default_activation_param, 57 | dropout_param=_default_dropout_param, 58 | dropout_enables=None, 59 | residual=False, 60 | preserve_color=False 61 | ): 62 | 63 | super(PatchDiscriminator, self).__init__( 64 | ndim, 65 | in_channels, 66 | nlayer, 67 | nfilter, 68 | ninner, 69 | conv_param, 70 | pool_param, 71 | norm_param, 72 | activation_param, 73 | dropout_param, 74 | dropout_enables, 75 | residual, 76 | preserve_color) 77 | self._args = locals() 78 | 79 | self._out_channels = out_channels 80 | 81 | conv_out_param = { 82 | 'name': 'conv', 83 | 'kernel_size': 3, 84 | 'stride': 1, 85 | 'padding': 1, 86 | 'padding_mode': conv_param.get('padding_mode', 'zeros'), 87 | 'bias': conv_param.get('bias', True), 88 | 'initialW': conv_param.get('initialW', None), 89 | 'initial_bias': conv_param.get('initial_bias', None), 90 | 'hook': conv_param.get('hook', None), 91 | } 92 | 93 | self.add_module('conv_out', 94 | conv(ndim, 95 | self._nfilter[-1], 96 | out_channels, 97 | conv_out_param)) 98 | 99 | def forward(self, x): 100 | 101 | h = super().forward(x) 102 | 103 | out = self['conv_out'](h) 104 | out = crop(out, h.shape) 105 | 106 | return out 107 | -------------------------------------------------------------------------------- /pytorch_bcnn/models/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .unet_base import UNetBase 4 | from .unet import UNet 5 | from .bayesian_unet import BayesianUNet 6 | 7 | _supported_models = { 8 | 'unet_base': UNetBase, 9 | 'unet': UNet, 10 | 'bayesian_unet': BayesianUNet, 11 | } 12 | -------------------------------------------------------------------------------- /pytorch_bcnn/models/unet/_helper.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from functools import partial 7 | from inspect import isfunction 8 | import copy 9 | 10 | from ...links.noise import MCDropout 11 | from ...functions import stride_pooling_nd 12 | from ...initializers import bilinear_upsample 13 | from ...links.connection import PixelShuffleUpsampler2D 14 | from ...links.connection import PixelShuffleUpsampler3D 15 | 16 | 17 | # supported functions 18 | 19 | _supported_convs_2d = { 20 | 'conv': nn.Conv2d, 21 | # 'deformable': DeformableConvolution2D, # NOTE: 
unsupported in PyTorch 22 | } 23 | 24 | _supported_convs_3d = { 25 | 'conv': nn.Conv3d, 26 | } 27 | 28 | _supported_upconvs_2d = { 29 | 'deconv': nn.ConvTranspose2d, 30 | 'pixel_shuffle': PixelShuffleUpsampler2D, 31 | } 32 | 33 | _supported_upconvs_3d = { 34 | 'deconv': nn.ConvTranspose3d, 35 | 'pixel_shuffle': PixelShuffleUpsampler3D, 36 | } 37 | 38 | _supported_pools_2d = { 39 | 'none': lambda x: x, 40 | 'max': F.max_pool2d, 41 | 'average': F.avg_pool2d, 42 | 'stride': stride_pooling_nd, 43 | } 44 | 45 | _supported_pools_3d = { 46 | 'none': lambda x: x, 47 | 'max': F.max_pool3d, 48 | 'average': F.avg_pool3d, 49 | 'stride': stride_pooling_nd, 50 | } 51 | 52 | _supported_norms_2d = { 53 | 'batch': nn.BatchNorm2d, 54 | 'instance': nn.InstanceNorm2d, 55 | } 56 | 57 | _supported_norms_3d = { 58 | 'batch': nn.BatchNorm3d, 59 | 'instance': nn.InstanceNorm3d, 60 | } 61 | 62 | _supported_activations = { 63 | 'none': lambda x: x, 64 | 'identity': lambda x: x, 65 | 'relu': F.relu, 66 | 'leaky_relu': F.leaky_relu, 67 | 'tanh': F.tanh, 68 | 'sigmoid': F.sigmoid, 69 | # 'clipped_relu': F.clipped_relu, # NOTE: unsupported in PyTorch 70 | # 'crelu': F.crelu, 71 | 'elu': F.elu, 72 | # 'hard_sigmoid': F.hard_sigmoid, 73 | 'softplus': F.softplus, 74 | 'softmax': F.softmax, 75 | 'log_softmax': F.log_softmax, 76 | # 'maxout': F.maxout, 77 | # 'swish': F.swish, 78 | 'selu': F.selu, 79 | 'rrelu': F.rrelu, 80 | 'prelu': F.prelu, 81 | } 82 | 83 | _supported_dropouts = { 84 | 'none': lambda x: x, 85 | 'dropout': nn.Dropout, 86 | 'mc_dropout': MCDropout, 87 | } 88 | 89 | _supported_initializers = { 90 | 'zero': nn.init.zeros_, 91 | 'identity': nn.init.eye_, 92 | 'constant': nn.init.constant_, 93 | 'one': nn.init.ones_, 94 | 'normal': nn.init.normal_, 95 | # 'lecun_normal': , # NOTE: unsupported in PyTorch 96 | 'glorot_normal': nn.init.xavier_normal_, 97 | 'he_normal': nn.init.kaiming_normal_, 98 | 'orthogonal': nn.init.orthogonal_, 99 | 'uniform': nn.init.uniform_, 100 | # 'lecun_uniform': , 101 | 'glorot_uniform': nn.init.xavier_uniform_, 102 | 'he_uniform': nn.init.kaiming_uniform_, 103 | 'bilinear': bilinear_upsample, 104 | } 105 | 106 | _supported_link_hooks = { 107 | 'spectral_normalization': nn.utils.spectral_norm, 108 | 'weight_standardization': nn.utils.weight_norm, 109 | } 110 | 111 | # default parameters 112 | 113 | _default_conv_param = { 114 | 'name':'conv', 115 | 'kernel_size': 3, 116 | 'stride': 1, 117 | 'padding': 1, 118 | 'padding_mode': 'reflect', 119 | 'initialW': {'name': 'he_normal'}, 120 | 'initial_bias': {'name': 'zero'}, 121 | } 122 | 123 | _default_upconv_param = { 124 | 'name':'deconv', 125 | 'kernel_size': 3, 126 | 'stride': 2, 127 | 'padding': 0, 128 | # 'padding_mode': 'reflect', # NOTE: unsupported in PyTorch 129 | 'initialW': {'name': 'bilinear'}, 130 | 'initial_bias': {'name': 'zero'}, 131 | } 132 | 133 | _default_pool_param = { 134 | 'name': 'max', 135 | 'kernel_size': 2, 136 | 'stride': 2, 137 | } 138 | 139 | _default_norm_param = { 140 | 'name': 'batch' 141 | } 142 | 143 | _default_activation_param = { 144 | 'name': 'relu', 145 | # 'inplace': True 146 | } 147 | 148 | _default_dropout_param = { 149 | 'name': 'dropout', 150 | 'p': .5, 151 | } 152 | 153 | 154 | def _mapper(param, supported): 155 | assert isinstance(param, dict) 156 | assert isinstance(supported, dict) 157 | 158 | param = copy.deepcopy(param) 159 | 160 | if 'name' not in param.keys(): 161 | raise ValueError('"name" must be in param.keys()..') 162 | 163 | name = param.pop('name') 164 | 165 | if name not in 
supported.keys(): 166 | raise KeyError('"%s" is not supported.. Available: %s' 167 | % (name, supported.keys())) 168 | 169 | func = supported[name] 170 | 171 | if isfunction(func): 172 | return partial(func, **param) 173 | elif issubclass(func, (nn.Module)): 174 | return func(**param) 175 | else: 176 | raise ValueError('unsupported class type.. <%s>' % func.__class__) 177 | 178 | 179 | def pool(ndim, param): 180 | """ Return a function of the pool layer """ 181 | if ndim == 2: 182 | supported_pools = _supported_pools_2d 183 | else: 184 | supported_pools = _supported_pools_3d 185 | 186 | return _mapper(param, supported_pools) 187 | 188 | def activation(param): 189 | """ Return a function of the activation layer """ 190 | return _mapper(param, _supported_activations) 191 | 192 | def dropout(param): 193 | """ Return a function of the activation layer """ 194 | return _mapper(param, _supported_dropouts) 195 | 196 | def initializer(param): 197 | """ Return a function of the initializer """ 198 | return _mapper(param, _supported_initializers) 199 | 200 | def norm(ndim, size, param): 201 | """ Return a object of the normalization layer """ 202 | param = copy.deepcopy(param) 203 | param['num_features'] = size 204 | 205 | if ndim == 2: 206 | supported_norms = _supported_norms_2d 207 | else: 208 | supported_norms = _supported_norms_3d 209 | 210 | return _mapper(param, supported_norms) 211 | 212 | def link_hook(param): 213 | """ Return a function of the link hook """ 214 | return _mapper(param, _supported_link_hooks) 215 | 216 | def conv(ndim, in_channels, out_channels, param): 217 | """ Return a object of the convolution layer """ 218 | conv_param = copy.deepcopy(param) 219 | 220 | initialW_param = conv_param.pop('initialW', None) 221 | initial_bias_param = conv_param.pop('initial_bias', None) 222 | hook_param = conv_param.pop('hook', None) 223 | 224 | conv_param['in_channels'] = in_channels 225 | conv_param['out_channels'] = out_channels 226 | 227 | if ndim == 2: 228 | supported_convs = _supported_convs_2d 229 | else: 230 | supported_convs = _supported_convs_3d 231 | 232 | link = _mapper(conv_param, supported_convs) 233 | 234 | if link.weight is not None \ 235 | and initialW_param is not None: 236 | initialW = initializer(initialW_param) 237 | initialW(link.weight.data) 238 | 239 | if link.bias is not None \ 240 | and initial_bias_param is not None: 241 | initial_bias = initializer(initial_bias_param) 242 | initial_bias(link.bias.data) 243 | 244 | if hook_param is not None: 245 | hook = link_hook(hook_param) 246 | link = hook(link) 247 | 248 | return link 249 | 250 | def upconv(ndim, in_channels, out_channels, param): 251 | """ Return a object of the up-convolution layer """ 252 | conv_param = copy.deepcopy(param) 253 | 254 | initialW_param = conv_param.pop('initialW', None) 255 | initial_bias_param = conv_param.pop('initial_bias', None) 256 | hook_param = conv_param.pop('hook', None) 257 | 258 | conv_param['in_channels'] = in_channels 259 | conv_param['out_channels'] = out_channels 260 | 261 | if ndim == 2: 262 | supported_upconvs = _supported_upconvs_2d 263 | else: 264 | supported_upconvs = _supported_upconvs_3d 265 | 266 | link = _mapper(conv_param, supported_upconvs) 267 | 268 | if link.weight is not None \ 269 | and initialW_param is not None: 270 | initialW = initializer(initialW_param) 271 | initialW(link.weight.data) 272 | 273 | if link.bias is not None \ 274 | and initial_bias_param is not None: 275 | initial_bias = initializer(initial_bias_param) 276 | initial_bias(link.bias.data) 277 | 
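    # NOTE (editor's comment): the initializers above overwrite the freshly
    # created layer's weight/bias tensors in place; the optional hook below
    # (e.g., spectral normalization) wraps the module and returns it.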
278 | if hook_param is not None: 279 | hook = link_hook(hook_param) 280 | link = hook(link) 281 | 282 | return link 283 | -------------------------------------------------------------------------------- /pytorch_bcnn/models/unet/bayesian_unet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch  # NOTE: required by `torch.mean` when reporting sigma below 3 | import torch.nn.functional as F 4 | from pytorch_trainer import reporter 5 | import warnings 6 | 7 | from .unet_base import UNetBase 8 | from ._helper import conv 9 | from ._helper import _default_conv_param 10 | from ._helper import _default_norm_param 11 | from ._helper import _default_upconv_param 12 | from ._helper import _default_pool_param 13 | from ._helper import _default_activation_param 14 | from ._helper import _default_dropout_param 15 | from ...functions import crop 16 | 17 | 18 | def _check_dropout_param(param): 19 | 20 | name = param['name'] 21 | if name == 'dropout': 22 | warnings.warn('`%s` is not supported in BayesianUNet.. \ 23 | Use ``mc_dropout`` instead.' % name) 24 | param['name'] = 'mc_dropout' 25 | 26 | 27 | class BayesianUNet(UNetBase): 28 | """ Bayesian U-Net 29 | 30 | Args: 31 | ndim (int): Number of spatial dimensions. 32 | in_channels (int): Number of input channels. 33 | out_channels (int): Number of output channels. 34 | nlayer (int, optional): Number of layers. 35 | Defaults to 5. 36 | nfilter (list or int, optional): Number of filters. 37 | Defaults to 32. 38 | ninner (list or int, optional): Number of layers in UNetBlock. 39 | Defaults to 2. 40 | sigma (bool, optional): If True, the network concurrently outputs the sigma. 41 | Defaults to False. 42 | sigma_channels (int or None, optional): Number of channels for the sigma. 43 | If None, this is set equal to the number of output channels automatically. 44 | Defaults to None. 45 | conv_param (dict, optional): Hyperparameter of convolution layer. 46 | Defaults to {'name':'conv', 'ksize': 3, 'stride': 1, 'pad': 1, 47 | 'initialW': {'name': 'he_normal', 'scale': 1.0}, 'initial_bias': {'name': 'zero'}}. 48 | pool_param (dict, optional): Hyperparameter of pooling layer. 49 | Defaults to {'name': 'max', 'ksize': 2, 'stride': 2}. 50 | upconv_param (dict, optional): Hyperparameter of up-convolution layer. 51 | Defaults to {'name':'deconv', 'ksize': 3, 'stride': 2, 'pad': 0, 52 | 'initialW': {'name': 'bilinear', 'scale': 1.0}, 'initial_bias': {'name': 'zero'}}. 53 | norm_param (dict or None, optional): Hyperparameter of normalization layer. 54 | Defaults to {'name': 'batch'}. 55 | activation_param (dict, optional): Hyperparameter of activation layer. 56 | Defaults to {'name': 'relu'}. 57 | dropout_param (dict or None, optional): Hyperparameter of dropout layer. 58 | Defaults to {'name': 'mc_dropout', 'p': .5,}. 59 | dropout_enables (list or tuple, optional): Set whether to apply dropout for each layer. 60 | If None, apply the dropout in all layers. 61 | Defaults to None. 62 | residual (bool, optional): Enable the residual learning. 63 | Defaults to False. 64 | preserve_color (bool, optional): If True, the normalization will be discarded in the first layer. 65 | Defaults to False. 66 | exp_ninner (str, optional): Specify the number of layers in ExpansionBlock. 67 | If 'same', it is set to the same value as `ninner`. 68 | Defaults to 'same'. 69 | exp_norm_param (str, optional): Specify the hyperparameter of normalization layer in ExpansionBlock. 70 | If 'same', it is set to the same value as `norm_param`. 71 | Defaults to 'same'.
72 | exp_activation_param (str, optional): Specify the hyperparameter of normalization layer in ExpansionBlock. 73 | If 'same', it is set to the same value as `activation_param`. 74 | Defaults to 'same'. 75 | exp_dropout_param (str, optional): Specify the hyperparameter of normalization layer in ExpansionBlock. 76 | If 'same', it is set to the same value as `dropout_param`. 77 | Defaults to 'same'. 78 | 79 | See also: ~chainer_bcnn.links.mc_sampler 80 | ~chainer_bcnn.functions.mc_dropout 81 | """ 82 | 83 | def __init__(self, 84 | ndim, 85 | in_channels, 86 | out_channels, 87 | nlayer=5, 88 | nfilter=32, 89 | ninner=2, 90 | sigma=False, 91 | sigma_channels=None, 92 | conv_param=_default_conv_param, 93 | pool_param=_default_pool_param, 94 | upconv_param=_default_upconv_param, 95 | norm_param=_default_norm_param, 96 | activation_param=_default_activation_param, 97 | dropout_param={'name': 'mc_dropout', 'p': .5,}, 98 | dropout_enables=None, 99 | residual=False, 100 | preserve_color=False, 101 | exp_ninner='same', 102 | exp_norm_param='same', 103 | exp_activation_param='same', 104 | exp_dropout_param='same', 105 | ): 106 | 107 | _check_dropout_param(dropout_param) 108 | if exp_dropout_param != 'same': 109 | _check_dropout_param(exp_dropout_param) 110 | 111 | return_all_latent = False 112 | 113 | super(BayesianUNet, self).__init__( 114 | ndim, 115 | in_channels, 116 | nlayer, 117 | nfilter, 118 | ninner, 119 | conv_param, 120 | pool_param, 121 | upconv_param, 122 | norm_param, 123 | activation_param, 124 | dropout_param, 125 | dropout_enables, 126 | residual, 127 | preserve_color, 128 | exp_ninner, 129 | exp_norm_param, 130 | exp_activation_param, 131 | exp_dropout_param, 132 | return_all_latent) 133 | self._args = locals() 134 | 135 | if sigma_channels is None: 136 | sigma_channels = out_channels 137 | 138 | self._out_channels = out_channels 139 | self._sigma = sigma 140 | self._sigma_channels = sigma_channels 141 | 142 | conv_out_param = { 143 | 'name': 'conv', 144 | 'kernel_size': 3, 145 | 'stride': 1, 146 | 'padding': 1, 147 | 'padding_mode': conv_param.get('padding_mode', 'zeros'), 148 | 'bias': conv_param.get('bias', True), 149 | 'initialW': conv_param.get('initialW', None), 150 | 'initial_bias': conv_param.get('initial_bias', None), 151 | 'hook': conv_param.get('hook', None), 152 | } 153 | 154 | conv_out_nfilter_in = self._nfilter[0] 155 | if self._exp_ninner[0] == 0: 156 | conv_out_nfilter_in += self._nfilter[1] 157 | 158 | self.add_module('conv_out', 159 | conv(ndim, 160 | conv_out_nfilter_in, 161 | out_channels, 162 | conv_out_param)) 163 | 164 | 165 | if sigma: 166 | conv_sigma_param = { 167 | 'name': 'conv', 168 | 'kernel_size': 3, 169 | 'stride': 1, 170 | 'padding': 1, 171 | 'padding_mode': conv_param.get('padding_mode', 'zeros'), 172 | 'bias': False, 173 | 'initialW': {'name': 'zero'}, 174 | 'hook': conv_param.get('hook', None), 175 | } 176 | 177 | self.add_module('conv_sigma', 178 | conv(ndim, 179 | conv_out_nfilter_in, 180 | sigma_channels, 181 | conv_sigma_param)) 182 | 183 | def forward(self, x): 184 | 185 | h = super().forward(x) 186 | 187 | out = self['conv_out'](h) 188 | out = crop(out, x.shape) 189 | 190 | if not self._sigma: 191 | return out 192 | 193 | sigma = self['conv_sigma'](h) 194 | sigma = crop(sigma, x.shape) 195 | 196 | reporter.report({'sigma': torch.mean(sigma)}, self) 197 | 198 | return out, sigma 199 | -------------------------------------------------------------------------------- /pytorch_bcnn/models/unet/unet.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import warnings 4 | import copy 5 | 6 | from .unet_base import UNetBase 7 | from ._helper import conv 8 | from ._helper import _default_conv_param 9 | from ._helper import _default_norm_param 10 | from ._helper import _default_upconv_param 11 | from ._helper import _default_pool_param 12 | from ._helper import _default_activation_param 13 | from ._helper import _default_dropout_param 14 | from ...functions import crop 15 | 16 | class UNet(UNetBase): 17 | """ U-Net model 18 | 19 | Args: 20 | ndim (int): Number of spatial dimensions. 21 | in_channels (int): Number of input channels. 22 | out_channels (int): Number of output channels. 23 | nlayer (int, optional): Number of layers. 24 | Defaults to 5. 25 | nfilter (list or int, optional): Number of filters. 26 | Defaults to 32. 27 | ninner (list or int, optional): Number of layers in UNetBlock. 28 | Defaults to 2. 29 | conv_param (dict, optional): Hyperparameter of convolution layer. 30 | Defaults to {'name':'conv', 'ksize': 3, 'stride': 1, 'pad': 1, 31 | 'initialW': {'name': 'he_normal', 'scale': 1.0}, 'initial_bias': {'name': 'zero'}}. 32 | pool_param (dict, optional): Hyperparameter of pooling layer. 33 | Defaults to {'name': 'max', 'ksize': 2, 'stride': 2}. 34 | upconv_param (dict, optional): Hyperparameter of up-convolution layer. 35 | Defaults to {'name':'deconv', 'ksize': 3, 'stride': 2, 'pad': 0, 36 | 'initialW': {'name': 'bilinear', 'scale': 1.0}, 'initial_bias': {'name': 'zero'}}. 37 | norm_param (dict or None, optional): Hyperparameter of normalization layer. 38 | Defaults to {'name': 'batch'}. 39 | activation_param (dict, optional): Hyperparameter of activation layer. 40 | Defaults to {'name': 'relu'}. 41 | dropout_param (dict or None, optional): Hyperparameter of dropout layer. 42 | Defaults to {'name': 'dropout', 'ratio': .5}. 43 | dropout_enables (list or tuple, optional): Set whether to apply dropout for each layer. 44 | If None, apply the dropout in all layers. 45 | Defaults to None. 46 | residual (bool, optional): Enable the residual learning. 47 | Defaults to False. 48 | preserve_color (bool, optional): If True, the normalization will be discarded in the first layer. 49 | Defaults to False. 50 | exp_ninner (str, optional): Specify the number of layers in ExpansionBlock. 51 | If 'same', it is set to the same value as `ninner`. 52 | Defaults to 'same'. 53 | exp_norm_param (str, optional): Specify the hyperparameter of normalization layer in ExpansionBlock. 54 | If 'same', it is set to the same value as `norm_param`. 55 | Defaults to 'same'. 56 | exp_activation_param (str, optional): Specify the hyperparameter of normalization layer in ExpansionBlock. 57 | If 'same', it is set to the same value as `activation_param`. 58 | Defaults to 'same'. 59 | exp_dropout_param (str, optional): Specify the hyperparameter of normalization layer in ExpansionBlock. 60 | If 'same', it is set to the same value as `dropout_param`. 61 | Defaults to 'same'. 
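    Example (an editor's sketch)::

        >>> import torch
        >>> model = UNet(ndim=2, in_channels=1, out_channels=2)
        >>> y = model(torch.randn(1, 1, 128, 128))
        >>> y.shape   # the output is cropped back to the input's spatial size
        torch.Size([1, 2, 128, 128])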
62 | 63 | See also: https://arxiv.org/abs/1505.04597 64 | """ 65 | def __init__(self, 66 | ndim, 67 | in_channels, 68 | out_channels, 69 | nlayer=5, 70 | nfilter=32, 71 | ninner=2, 72 | conv_param=_default_conv_param, 73 | pool_param=_default_pool_param, 74 | upconv_param=_default_upconv_param, 75 | norm_param=_default_norm_param, 76 | activation_param=_default_activation_param, 77 | dropout_param=_default_dropout_param, 78 | dropout_enables=None, 79 | residual=False, 80 | preserve_color=False, 81 | exp_ninner='same', 82 | exp_norm_param='same', 83 | exp_activation_param='same', 84 | exp_dropout_param='same', 85 | ): 86 | 87 | return_all_latent = False 88 | 89 | super(UNet, self).__init__( 90 | ndim, 91 | in_channels, 92 | nlayer, 93 | nfilter, 94 | ninner, 95 | conv_param, 96 | pool_param, 97 | upconv_param, 98 | norm_param, 99 | activation_param, 100 | dropout_param, 101 | dropout_enables, 102 | residual, 103 | preserve_color, 104 | exp_ninner, 105 | exp_norm_param, 106 | exp_activation_param, 107 | exp_dropout_param, 108 | return_all_latent) 109 | self._args = locals() 110 | 111 | self._out_channels = out_channels 112 | 113 | conv_out_param = { 114 | 'name': 'conv', 115 | 'kernel_size': 3, 116 | 'stride': 1, 117 | 'padding': 1, 118 | 'padding_mode': conv_param.get('padding_mode', 'zeros'), 119 | 'bias': conv_param.get('bias', True), 120 | 'initialW': conv_param.get('initialW', None), 121 | 'initial_bias': conv_param.get('initial_bias', None), 122 | 'hook': conv_param.get('hook', None), 123 | } 124 | 125 | conv_out_nfilter_in = self._nfilter[0] 126 | if exp_ninner[0] == 0: 127 | conv_out_nfilter_in += self._nfilter[1] 128 | 129 | self.add_module('conv_out', 130 | conv(ndim, 131 | conv_out_nfilter_in, 132 | out_channels, 133 | conv_out_param)) 134 | 135 | def forward(self, x): 136 | 137 | h = super().forward(x) 138 | out = self['conv_out'](h) 139 | out = crop(out, x.shape) 140 | 141 | return out 142 | -------------------------------------------------------------------------------- /pytorch_bcnn/updaters/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .cgan import DCGANUpdater 4 | from .cgan import LSGANUpdater 5 | -------------------------------------------------------------------------------- /pytorch_bcnn/updaters/cgan/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from abc import ABCMeta, abstractmethod 4 | 5 | from pytorch_trainer.training import StandardUpdater 6 | from pytorch_trainer import reporter 7 | from pytorch_trainer.dataset import convert 8 | 9 | 10 | def _update(optimizer, in_arrays, loss_func): 11 | 12 | if isinstance(in_arrays, tuple): 13 | optimizer.update(loss_func, *in_arrays) 14 | elif isinstance(in_arrays, dict): 15 | optimizer.update(loss_func, **in_arrays) 16 | else: 17 | optimizer.update(loss_func, in_arrays) 18 | 19 | 20 | class CGANUpdater(StandardUpdater, metaclass=ABCMeta): 21 | """ Base class of updater for conditional GANs 22 | 23 | Args: 24 | iterator: Dataset iterator for the training dataset. 25 | optimizer (dict): Optimizers to update parameters. It should be a dictionary 26 | that has `gen` and `dis` keys. Note that `gen` and `dis` means the generator 27 | and discniminator, respectively. 28 | model (dict): Generative and discriminative models. It should be a dictionary 29 | that has `gen` and `dis` keys. 
Note that `gen` and `dis` means the generator 30 | and discniminator, respectively. 31 | alpha (float): Loss scaling factor for balancing the conditional loss. 32 | converter (optional): Converter function to build input arrays. Defaults to `convert.concat_examples`. 33 | device (int, optional): Device to which the training data is sent. Negative value 34 | indicates the host memory (CPU). Defaults to None. 35 | loss_func: Conditional loss function. `lossfun` attribute of the optimizer's target link for 36 | the generator is used by default. Defaults to None. 37 | """ 38 | 39 | 40 | 41 | def __init__(self, iterator, optimizer, model, alpha, 42 | converter=convert.concat_examples, 43 | device=None, loss_func=None): 44 | 45 | assert isinstance(optimizer, dict) 46 | 47 | super(CGANUpdater, self).__init__( 48 | iterator, optimizer, model, converter, 49 | device, loss_func) 50 | 51 | self.alpha = alpha 52 | 53 | @property 54 | def discriminator(self): 55 | return self._models['dis'] 56 | 57 | @property 58 | def generator(self): 59 | return self._models['gen'] 60 | 61 | def conditional_lossfun(self, y_fake, y_true): 62 | 63 | model = self.generator 64 | 65 | if hasattr(model, 'lossfun'): 66 | lossfun = model.lossfun 67 | else: 68 | lossfun = self.loss_func 69 | 70 | loss = lossfun(y_fake, y_true) 71 | reporter.report({'loss_cond': loss}) 72 | return loss 73 | 74 | 75 | @abstractmethod 76 | def discriminative_lossfun(self, *args, **kwargs): 77 | raise NotImplementedError() 78 | 79 | @abstractmethod 80 | def generative_lossfun(self, *args, **kwargs): 81 | raise NotImplementedError() 82 | 83 | @abstractmethod 84 | def update_core(self): 85 | raise NotImplementedError() 86 | 87 | 88 | from .dcgan import DCGANUpdater # NOQA 89 | from .lsgan import LSGANUpdater # NOQA 90 | -------------------------------------------------------------------------------- /pytorch_bcnn/updaters/cgan/_replay_buffer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import torch 5 | 6 | class ReplayBuffer(object): 7 | """ Buffer for handling the experience replay. 
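    Past generator outputs are pooled in the buffer; on each call, an
    incoming sample is swapped with a randomly chosen buffered one with
    probability ``p``, which is known to stabilize discriminator training.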
8 | Args: 9 | size (int): buffer size 10 | p (float): probability of evoking a past experience 11 | 12 | See also: 13 | https://arxiv.org/pdf/1612.07828.pdf 14 | https://arxiv.org/pdf/1703.10593.pdf 15 | """ 16 | 17 | def __init__(self, size, p=0.5): 18 | self.size = size 19 | self.p = p 20 | self._buffer = [] 21 | 22 | 23 | @property 24 | def buffer(self): 25 | if len(self._buffer) == 0: 26 | return None 27 | return self._buffer 28 | 29 | def __call__(self, samples): 30 | 31 | if not isinstance(samples, torch.Tensor): 32 | samples = torch.as_tensor(samples) 33 | 34 | n_samples = len(samples) 35 | 36 | if self.size == 0: 37 | return samples 38 | 39 | if len(self._buffer) < self.size: 40 | if len(self._buffer) == 0: 41 | self._buffer = samples 42 | else: self._buffer = torch.cat((self._buffer, samples))  # NOTE: `else` avoids storing the first batch twice 43 | return samples 44 | 45 | # evoke the memory 46 | random_bool = np.random.rand(n_samples) < self.p 47 | replay_indices = np.random.randint(0, len(self._buffer), size=n_samples)[random_bool] 48 | sample_indices = np.random.randint(0, n_samples, size=n_samples)[random_bool] 49 | 50 | self._buffer[replay_indices], samples[sample_indices] \ 51 | = samples[sample_indices], self._buffer[replay_indices] # swap 52 | 53 | return samples 54 | 55 | 56 | if __name__ == '__main__': 57 | 58 | import numpy as np 59 | import torch 60 | 61 | buffer = ReplayBuffer(10) 62 | print(buffer.buffer) 63 | 64 | for i in range(20): 65 | a = buffer(torch.as_tensor(np.zeros((2,3,4,5)) + i)) 66 | print(i, a) 67 | print(a.shape) 68 | print(a.__class__) 69 | 70 | print(len(buffer.buffer)) 71 | print(buffer.buffer.shape) 72 | -------------------------------------------------------------------------------- /pytorch_bcnn/updaters/cgan/dcgan.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy 4 | import torch 5 | import torch.nn.functional as F 6 | from pytorch_trainer import reporter 7 | from pytorch_trainer.dataset import convert 8 | 9 | from . import CGANUpdater 10 | from ._replay_buffer import ReplayBuffer 11 | 12 | class DCGANUpdater(CGANUpdater): 13 | """ Updater for DCGAN 14 | 15 | Args: 16 | iterator: Dataset iterator for the training dataset. 17 | optimizer (dict): Optimizers to update parameters. It should be a dictionary 18 | that has `gen` and `dis` keys. Note that `gen` and `dis` mean the generator 19 | and discriminator, respectively. 20 | model (dict): Generative and discriminative models. It should be a dictionary 21 | that has `gen` and `dis` keys. Note that `gen` and `dis` mean the generator 22 | and discriminator, respectively. 23 | alpha (float): Loss scaling factor for balancing the conditional loss. 24 | buffer_size (int, optional): Size of the buffer that handles the experience replay. Defaults to 0. 25 | converter (optional): Converter function to build input arrays. Defaults to `convert.concat_examples`. 26 | device (int, optional): Device to which the training data is sent. A negative value 27 | indicates the host memory (CPU). Defaults to None. 28 | loss_func: Conditional loss function. The `lossfun` attribute of the optimizer's target link for 29 | the generator is used by default. Defaults to None.
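    Note (editor's addition): the ``softplus`` formulation used in the loss
    functions below is the numerically stable form of binary cross-entropy
    on logits, since ``softplus(-x) == -log(sigmoid(x))``. A minimal sketch
    of the equivalence for a single-channel output::

        >>> p = torch.randn(4, 1, 8, 8)
        >>> a = F.softplus(-p).mean()
        >>> b = F.binary_cross_entropy_with_logits(p, torch.ones_like(p))
        >>> torch.allclose(a, b)
        True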
30 | 31 | See also: 32 | https://arxiv.org/pdf/1511.06434.pdf 33 | """ 34 | def __init__(self, iterator, optimizer, model, alpha, buffer_size=0, 35 | converter=convert.concat_examples, 36 | device=None, loss_func=None): 37 | 38 | super(DCGANUpdater, self).__init__( 39 | iterator, optimizer, model, alpha, converter, 40 | device, loss_func) 41 | 42 | self._buffer = ReplayBuffer(buffer_size) 43 | self._buffer_size = buffer_size 44 | 45 | def discriminative_lossfun(self, p_real, p_fake): 46 | size = p_real.numel() / p_real.shape[1] 47 | loss = (torch.sum(F.softplus(-p_real)) / size \ 48 | + torch.sum(F.softplus(p_fake)) / size) * 0.5 # NOTE: equivalent to binary cross entropy 49 | reporter.report({'loss_dis': loss}) 50 | return loss 51 | 52 | def generative_lossfun(self, p_fake): 53 | size = p_fake.numel() / p_fake.shape[1] 54 | loss = torch.sum(F.softplus(-p_fake)) / size 55 | reporter.report({'loss_gen': loss}) 56 | return loss 57 | 58 | def update_core(self): 59 | 60 | iterator = self._iterators['main'] 61 | batch = iterator.next() 62 | in_arrays = convert._call_converter(self.converter, batch, self.device) 63 | 64 | opt_dis = self._optimizers['dis'] 65 | opt_gen = self._optimizers['gen'] 66 | 67 | for model in self._models.values(): 68 | model.train() 69 | 70 | x_real, y_real = in_arrays 71 | 72 | # generative 73 | self.discriminator.requires_grad_(False) 74 | 75 | y_fake = self.generator(x_real) 76 | xy_fake = torch.cat((x_real, y_fake), dim=1) 77 | p_fake = self.discriminator(xy_fake) 78 | 79 | loss_gen = self.generative_lossfun(p_fake) \ 80 | + self.alpha * self.conditional_lossfun(y_fake, y_real) 81 | 82 | opt_gen.zero_grad() 83 | loss_gen.backward() 84 | opt_gen.step() 85 | 86 | # discriminative 87 | # NOTE: deallocate intermediate variable nodes related to the generator 88 | # with `detach` method 89 | self.discriminator.requires_grad_(True) 90 | 91 | y_fake_old = self._buffer(y_fake.detach()) 92 | 93 | xy_fake = torch.cat((x_real, y_fake_old), dim=1) 94 | p_fake = self.discriminator(xy_fake) 95 | 96 | xy_real = torch.cat((x_real, y_real), dim=1) 97 | p_real = self.discriminator(xy_real) 98 | 99 | loss_dis = self.discriminative_lossfun(p_real, p_fake) 100 | 101 | opt_dis.zero_grad() 102 | loss_dis.backward() 103 | opt_dis.step() 104 | -------------------------------------------------------------------------------- /pytorch_bcnn/updaters/cgan/lsgan.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from pytorch_trainer import reporter 6 | 7 | from .dcgan import DCGANUpdater 8 | 9 | 10 | class LSGANUpdater(DCGANUpdater): 11 | """ Updater for Least Square GAN (LSGAN) 12 | 13 | Args: 14 | iterator: Dataset iterator for the training dataset. 15 | optimizer (dict): Optimizers to update parameters. It should be a dictionary 16 | that has `gen` and `dis` keys. Note that `gen` and `dis` means the generator 17 | and discniminator, respectively. 18 | model (dict): Generative and discriminative models. It should be a dictionary 19 | that has `gen` and `dis` keys. Note that `gen` and `dis` means the generator 20 | and discniminator, respectively. 21 | alpha (float): Loss scaling factor for balancing the conditional loss. 22 | buffer_size (int, optional): Size of buffer, which handles the experience replay. Defaults to 0. 23 | converter (optional): Converter function to build input arrays. Defaults to `convert.concat_examples`. 
24 | device (int, optional): Device to which the training data is sent. Negative value 25 | indicates the host memory (CPU). Defaults to None. 26 | loss_func: Conditional loss function. `lossfun` attribute of the optimizer's target link for 27 | the generator is used by default. Defaults to None. 28 | 29 | See also: 30 | https://arxiv.org/pdf/1611.04076.pdf 31 | """ 32 | def discriminative_lossfun(self, p_real, p_fake): 33 | t_1 = torch.ones(p_real.shape, dtype=p_real.dtype, device=p_real.device) 34 | t_0 = torch.zeros(p_fake.shape, dtype=p_fake.dtype, device=p_fake.device) 35 | loss = (F.mse_loss(p_real, t_1) \ 36 | + F.mse_loss(p_fake, t_0)) * 0.5 37 | reporter.report({'loss_dis': loss}) 38 | return loss 39 | 40 | def generative_lossfun(self, p_fake): 41 | t_1 = torch.ones(p_fake.shape, dtype=p_fake.dtype, device=p_fake.device) 42 | loss = F.mse_loss(p_fake, t_1) 43 | reporter.report({'loss_gen': loss}) 44 | return loss 45 | -------------------------------------------------------------------------------- /pytorch_bcnn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import numpy as np 5 | import contextlib 6 | import warnings 7 | import tempfile 8 | import shutil 9 | import argparse 10 | import json 11 | 12 | 13 | @contextlib.contextmanager 14 | def fixed_seed(seed, strict=False): 15 | """Fix random seed to improve the reproducibility. 16 | 17 | Args: 18 | seed (float): Random seed 19 | strict (bool, optional): If True, cuDNN works under deterministic mode. 20 | Defaults to False. 21 | 22 | TODO: Even if `strict` is set to True, the reproducibility cannot be guaranteed under the `MultiprocessIterator`. 23 | If your dataset has stochastic behavior, such as data augmentation, you should use the `SerialIterator` or `MultithreadIterator`. 24 | """ 25 | 26 | import random 27 | import torch 28 | import copy 29 | 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | if torch.cuda.is_available(): 34 | torch.cuda.manual_seed(seed) 35 | 36 | if strict: 37 | warnings.warn('Even if `strict` is set to True, the reproducibility cannot be guaranteed under the `MultiprocessIterator`. \ 38 | If your dataset has stochastic behavior such as data augmentation, you should use the `SerialIterator` or `MultithreadIterator`.') 39 | 40 | _deterministic = copy.copy(torch.backends.cudnn.deterministic) 41 | _benchmark = copy.copy(torch.backends.cudnn.benchmark) 42 | 43 | torch.backends.cudnn.deterministic = True 44 | torch.backends.cudnn.benchmark = False 45 | 46 | yield 47 | 48 | if strict: 49 | torch.backends.cudnn.deterministic = _deterministic 50 | torch.backends.cudnn.benchmark = _benchmark 51 | 52 | 53 | # https://github.com/chainer/chainerui/blob/master/chainerui/utils/tempdir.py 54 | @contextlib.contextmanager 55 | def tempdir(**kwargs): 56 | # A context manager that defines a lifetime of a temporary directory. 
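    # Usage sketch (editor's addition):
    #
    #   with tempdir(prefix='args', dir=out) as tempd:
    #       ...  # files created under `tempd` are removed when the block exits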
57 |     ignore_errors = kwargs.pop('ignore_errors', False)
58 | 
59 |     temp_dir = tempfile.mkdtemp(**kwargs)
60 |     try:
61 |         yield temp_dir
62 |     finally:
63 |         shutil.rmtree(temp_dir, ignore_errors=ignore_errors)
64 | 
65 | # https://github.com/chainer/chainerui/blob/master/chainerui/utils/save_args.py
66 | def convert_dict(conditions):
67 |     if isinstance(conditions, argparse.Namespace):
68 |         return vars(conditions)
69 |     return conditions
70 | 
71 | # https://github.com/chainer/chainerui/blob/master/chainerui/utils/save_args.py
72 | def save_args(conditions, out_path):
73 |     """A util function to save experiment conditions for a job table.
74 | 
75 |     Args:
76 |         conditions (:class:`argparse.Namespace` or dict): Experiment conditions
77 |             to show on a job table. Keys are shown as the table header and values
78 |             are shown in a job row.
79 |         out_path (str): Output directory name to save conditions.
80 | 
81 |     """
82 | 
83 |     args = convert_dict(conditions)
84 | 
85 |     try:
86 |         os.makedirs(out_path)
87 |     except OSError:
88 |         pass
89 | 
90 |     with tempdir(prefix='args', dir=out_path) as tempd:
91 |         path = os.path.join(tempd, 'args.json')
92 |         with open(path, 'w') as f:
93 |             json.dump(args, f, indent=4)
94 | 
95 |         new_path = os.path.join(out_path, 'args')
96 |         shutil.move(path, new_path)
97 | 
98 | 
99 | # https://github.com/chainer/chainer/blob/v7.1.0/chainer/training/extensions/_snapshot.py
100 | def _find_snapshot_files(fmt, path):
101 |     '''Only the prefix and suffix are matched.
102 |     TODO(kuenishi): currently clean format strings such as
103 |     "snapshot{.iteration}.npz" can be parsed, but tricky (or
104 |     invalid) formats like "snapshot{{.iteration}}.npz" are hard to
105 |     detect and to properly show errors; so far they are just ignored or fail.
106 |     Args:
107 |         fmt (str): format string to match with file names of
108 |             existing snapshots, where only the prefix and suffix are
109 |             examined. Also, files' staleness is judged
110 |             by timestamps. The default is mtime.
111 |         path (str): a directory path to search for snapshot files.
112 |     Returns:
113 |         A sorted list of ``(mtime, filename)`` pairs for the files
114 |         whose names matched the format ``fmt`` directly under ``path``.
115 |     '''
116 |     prefix = fmt.split('{')[0]
117 |     suffix = fmt.split('}')[-1]
118 | 
119 |     matched_files = (file for file in os.listdir(path)
120 |                      if file.startswith(prefix) and file.endswith(suffix))
121 | 
122 |     def _prepend_mtime(f):
123 |         t = os.stat(os.path.join(path, f)).st_mtime
124 |         return (t, f)
125 | 
126 |     return sorted(_prepend_mtime(file) for file in matched_files)
127 | 
128 | # https://github.com/chainer/chainer/blob/v7.1.0/chainer/training/extensions/_snapshot.py
129 | def _find_latest_snapshot(fmt, path):
130 |     """Finds the latest snapshot in a directory.
131 |     Args:
132 |         fmt (str): format string to match with file names of
133 |             existing snapshots, where only the prefix and suffix are
134 |             examined. Also, files' staleness is judged
135 |             by timestamps. The default is mtime.
136 |         path (str): a directory path to search for snapshot files.
137 |     Returns:
138 |         The latest snapshot file, i.e. the file with the newest
139 |         ``mtime`` that matches the format ``fmt`` directly under
140 |         ``path``. If no such file is found, it returns ``None``.
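    Example (illustrative):
        With ``fmt='snapshot_iter_{.iteration}'`` and files
        ``snapshot_iter_100`` and ``snapshot_iter_200`` under ``path``,
        the one with the newer ``mtime`` is returned.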
141 | """ 142 | snapshot_files = _find_snapshot_files(fmt, path) 143 | 144 | if len(snapshot_files) > 0: 145 | _, filename = snapshot_files[-1] 146 | return filename 147 | 148 | return None 149 | 150 | 151 | def find_latest_snapshot(fmt, path, return_fullpath=True): 152 | '''Alias of :func:`_find_latest_snapshot` 153 | ''' 154 | ret = _find_latest_snapshot(fmt, path) 155 | 156 | if ret is None: 157 | raise FileNotFoundError('cannot find snapshot for <%s>' % 158 | os.path.join(path, fmt)) 159 | 160 | if return_fullpath: 161 | return os.path.join(path, ret) 162 | 163 | return ret 164 | -------------------------------------------------------------------------------- /pytorch_bcnn/visualizer/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from abc import ABCMeta, abstractmethod 4 | 5 | 6 | class Visualizer(metaclass=ABCMeta): 7 | """ Base class of visualizers 8 | """ 9 | 10 | def __init__(self, *args, **kwargs): 11 | self._examples = None 12 | self.reset() 13 | 14 | def reset(self): 15 | self._examples = [] 16 | 17 | @property 18 | def n_examples(self): 19 | return len(self._examples) 20 | 21 | @abstractmethod 22 | def add_example(self, x, **kwargs): 23 | raise NotImplementedError() 24 | 25 | @abstractmethod 26 | def add_batch(self, x, **kwargs): 27 | raise NotImplementedError() 28 | 29 | @abstractmethod 30 | def save(self, filename): 31 | raise NotImplementedError() 32 | 33 | 34 | from .image import ImageVisualizer # NOQA 35 | -------------------------------------------------------------------------------- /recipe/AIO.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 3 | 4 | %environment 5 | SALMON_USER=/win/salmon/user 6 | export SALMON_USER 7 | SCALLOP_USER=/win/scallop/user 8 | export SCALLOP_USER 9 | %setup 10 | 11 | %post 12 | # Change repository location 13 | sed -i.bak -e "s%http://archive.ubuntu.com/ubuntu/%http://ftp.jaist.ac.jp/pub/Linux/ubuntu/%g" /etc/apt/sources.list 14 | cat /etc/apt/sources.list 15 | 16 | #binding points 17 | mkdir -p /win/salmon/user /win/scallop/user 18 | 19 | # python 20 | export PIP_INDEX_URL=http://cl-hammerhead:3141/root/pypi/+simple/ 21 | export PIP_TRUSTED_HOST=cl-hammerhead 22 | apt-get update && apt-get -y install locales && locale-gen en_US.UTF-8 23 | apt-get update -y && \ 24 | apt-get install -y --no-install-recommends \ 25 | wget ca-certificates openssl \ 26 | python3-dev \ 27 | python3-wheel \ 28 | python3-setuptools && \ 29 | 30 | # pip 31 | wget https://bootstrap.pypa.io/get-pip.py 32 | python3 get-pip.py 33 | # timezone 34 | apt-get install -y tzdata 35 | ln -fs /usr/share/zoneinfo/Asia/Tokyo /etc/localtime && dpkg-reconfigure -f noninteractive tzdata 36 | 37 | # libdeflate 38 | apt-get install -y git cmake 39 | git clone --recursive http://octopus.naist.jp/suzuki/pylibdeflate.git 40 | pip3 install ./pylibdeflate 41 | 42 | # Keras 43 | apt-get install -y libcupti-dev graphviz 44 | pip3 install tensorflow-gpu keras 45 | 46 | # packages 47 | apt-get install -y graphviz python3-tk htop pigz 48 | pip3 install bokeh chainerui click fasteners graphviz h5py imageio matplotlib numba numpy opencv-python openpyxl pandas pathlib pillow plotly progressbar2 psutil pydot pydot-ng scikit-image scikit-learn scipy seaborn SimpleITK sphinx tqdm vtk xlrd 49 | 50 | # Chainer 51 | pip3 install --no-cache-dir cupy-cuda100 chainer 52 | 53 | # PyTorch 54 | pip3 
--------------------------------------------------------------------------------
/recipe/AIO.def:
--------------------------------------------------------------------------------
1 | Bootstrap: docker
2 | From: nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
3 | 
4 | %environment
5 |     SALMON_USER=/win/salmon/user
6 |     export SALMON_USER
7 |     SCALLOP_USER=/win/scallop/user
8 |     export SCALLOP_USER
9 | %setup
10 | 
11 | %post
12 |     # Change the repository location
13 |     sed -i.bak -e "s%http://archive.ubuntu.com/ubuntu/%http://ftp.jaist.ac.jp/pub/Linux/ubuntu/%g" /etc/apt/sources.list
14 |     cat /etc/apt/sources.list
15 | 
16 |     # binding points
17 |     mkdir -p /win/salmon/user /win/scallop/user
18 | 
19 |     # python
20 |     export PIP_INDEX_URL=http://cl-hammerhead:3141/root/pypi/+simple/
21 |     export PIP_TRUSTED_HOST=cl-hammerhead
22 |     apt-get update && apt-get -y install locales && locale-gen en_US.UTF-8
23 |     apt-get update -y && \
24 |     apt-get install -y --no-install-recommends \
25 |         wget ca-certificates openssl \
26 |         python3-dev \
27 |         python3-wheel \
28 |         python3-setuptools && \
29 | 
30 |     # pip
31 |     wget https://bootstrap.pypa.io/get-pip.py
32 |     python3 get-pip.py
33 |     # timezone
34 |     apt-get install -y tzdata
35 |     ln -fs /usr/share/zoneinfo/Asia/Tokyo /etc/localtime && dpkg-reconfigure -f noninteractive tzdata
36 | 
37 |     # libdeflate
38 |     apt-get install -y git cmake
39 |     git clone --recursive http://octopus.naist.jp/suzuki/pylibdeflate.git
40 |     pip3 install ./pylibdeflate
41 | 
42 |     # Keras
43 |     apt-get install -y libcupti-dev graphviz
44 |     pip3 install tensorflow-gpu keras
45 | 
46 |     # packages
47 |     apt-get install -y graphviz python3-tk htop pigz
48 |     pip3 install bokeh chainerui click fasteners graphviz h5py imageio matplotlib numba numpy opencv-python openpyxl pandas pathlib pillow plotly progressbar2 psutil pydot pydot-ng scikit-image scikit-learn scipy seaborn SimpleITK sphinx tqdm vtk xlrd
49 | 
50 |     # Chainer
51 |     pip3 install --no-cache-dir cupy-cuda100 chainer
52 | 
53 |     # PyTorch
54 |     pip3 install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp35-cp35m-linux_x86_64.whl
55 |     pip3 install torchvision
56 | 
57 |     apt-get clean
58 | %runscript
59 | 
60 | %test
61 | 
--------------------------------------------------------------------------------
/recipe/AIO.def.in:
--------------------------------------------------------------------------------
1 | Bootstrap: docker
2 | From: nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
3 | 
4 | %environment
5 |     SALMON_USER=/win/salmon/user
6 |     export SALMON_USER
7 |     SCALLOP_USER=/win/scallop/user
8 |     export SCALLOP_USER
9 | %setup
10 | 
11 | %post
12 |     # Change the repository location
13 |     sed -i.bak -e "s%http://archive.ubuntu.com/ubuntu/%http://ftp.jaist.ac.jp/pub/Linux/ubuntu/%g" /etc/apt/sources.list
14 |     cat /etc/apt/sources.list
15 | 
16 |     # binding points
17 |     mkdir -p /win/salmon/user /win/scallop/user
18 | 
19 |     # python
20 |     export PIP_INDEX_URL=http://cl-hammerhead:3141/root/pypi/+simple/
21 |     export PIP_TRUSTED_HOST=cl-hammerhead
22 |     apt-get update && apt-get -y install locales && locale-gen en_US.UTF-8
23 |     apt-get update -y && \
24 |     apt-get install -y --no-install-recommends \
25 |         wget ca-certificates openssl \
26 |         python3-dev \
27 |         python3-wheel \
28 |         python3-setuptools && \
29 | 
30 |     # pip
31 |     wget https://bootstrap.pypa.io/get-pip.py
32 |     python3 get-pip.py
33 |     # timezone
34 |     apt-get install -y tzdata
35 |     ln -fs /usr/share/zoneinfo/Asia/Tokyo /etc/localtime && dpkg-reconfigure -f noninteractive tzdata
36 | 
37 |     # libdeflate
38 |     apt-get install -y git cmake
39 |     git clone --recursive http://octopus.naist.jp/suzuki/pylibdeflate.git
40 |     pip3 install ./pylibdeflate
41 | 
42 |     # Keras
43 |     apt-get install -y libcupti-dev graphviz
44 |     pip3 install tensorflow-gpu keras
45 | 
46 |     # packages
47 |     apt-get install -y graphviz python3-tk htop pigz
48 |     pip3 install python_packages
49 | 
50 |     # Chainer
51 |     pip3 install --no-cache-dir cupy-cuda100 chainer
52 | 
53 |     # PyTorch
54 |     pip3 install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp35-cp35m-linux_x86_64.whl
55 |     pip3 install torchvision
56 | 
57 |     apt-get clean
58 | %runscript
59 | 
60 | %test
61 | 
--------------------------------------------------------------------------------
/recipe/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: all
2 | 
3 | all: AIO.simg
4 | 
5 | AIO.def: AIO.def.in python_packages.txt
6 | 	sed "s/python_packages/`sort python_packages.txt | sed 's/\r//g' | perl -pe 's/\n/ /g'`/g" AIO.def.in > AIO.def
7 | 
8 | AIO.simg: AIO.def
9 | 	sudo singularity build AIO.simg AIO.def
10 | 
--------------------------------------------------------------------------------
/recipe/python_packages.txt:
--------------------------------------------------------------------------------
1 | bokeh
2 | click
3 | fasteners
4 | graphviz
5 | h5py
6 | imageio
7 | matplotlib
8 | numba
9 | numpy
10 | opencv-python
11 | pandas
12 | pathlib
13 | pillow
14 | plotly
15 | progressbar2
16 | psutil
17 | pydot
18 | pydot-ng
19 | scikit-image
20 | scikit-learn
21 | scipy
22 | seaborn
23 | sphinx
24 | SimpleITK
25 | tqdm
26 | vtk
27 | chainerui
28 | openpyxl
29 | xlrd
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | six
2 | pandas
3 | scipy
4 | scikit-image==0.12.3
5 | scikit_learn
6 | numpy
7 | opencv_python
8 | SimpleITK
9 | joblib
10 | tqdm
11 | openpyxl
12 | matplotlib==3.0.3
13 | seaborn
14 | tensorboardX
15 | torch
16 | pytorch-trainer
17 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | from setuptools import find_packages
4 | try:
5 |     from setuptools import setup
6 | except ImportError:
7 |     from distutils.core import setup
8 | 
9 | setup(
10 |     name='pytorch_bcnn',
11 |     version='1.1.0',
12 |     description='Bayesian Convolutional Neural Networks',
13 |     long_description=open('README.md').read(),
14 |     author='yuta-hi',
15 |     packages=find_packages(),
16 |     include_package_data=True,
17 |     scripts=[],
18 |     install_requires=open('requirements.txt').readlines(),
19 |     url='https://github.com/yuta-hi/pytorch_bayesian_unet',
20 |     license='MIT',
21 |     classifiers=[
22 |         'Programming Language :: Python :: 3',
23 |     ],
24 | )
25 | 
--------------------------------------------------------------------------------
/tests/lenna.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuta-hi/pytorch_bayesian_unet/bb22b44c64f5d83d78aa93880da97e0e6168dc1c/tests/lenna.png
--------------------------------------------------------------------------------
/tests/test_augmentator_2d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_bcnn.data.augmentor import DataAugmentor, Crop2D, Flip2D, Affine2D, ResizeCrop2D
3 | 
4 | import cv2
5 | import time
6 | import matplotlib.pyplot as plt
7 | 
8 | 
9 | def main():
10 |     augmentor = DataAugmentor()
11 |     augmentor.add(ResizeCrop2D(resize_size=(400, 500),
12 |                                crop_size=(300, 400)))
13 |     augmentor.add(Flip2D(axis=2))
14 |     augmentor.add(Affine2D(rotation=15.,
15 |                            translate=(10., 10.),
16 |                            shear=0.25,
17 |                            zoom=(0.8, 1.2),
18 |                            keep_aspect_ratio=True,
19 |                            fill_mode=('nearest', 'constant'),
20 |                            cval=(0., 0.),
21 |                            interp_order=(3, 0)))
22 | 
23 |     augmentor.summary('augment.json')
24 | 
25 |     x_in = cv2.imread('lenna.png').astype(np.float32) / 255.
26 |     x_in = x_in[:, :, ::-1]  # BGR -> RGB
27 |     x_in = np.transpose(x_in, (2, 0, 1))  # HWC -> CHW
28 |     y_in = x_in[0, ...]
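    # x_in: (3, H, W) RGB float image in [0, 1]; y_in: (H, W) single channel,
    # passed as a label-like second input so both outputs of `apply` are checked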
29 |     print(x_in.shape)
30 |     print(y_in.shape)
31 | 
32 |     tic = time.time()
33 |     x_out, y_out = augmentor.apply(x_in, y_in)
34 |     print('time: %f [sec]' % (time.time()-tic))
35 | 
36 |     print(x_out.shape)
37 |     print(y_out.shape)
38 | 
39 |     plt.subplot(2, 2, 1)
40 |     plt.imshow(np.transpose(x_in, (1, 2, 0)))
41 |     plt.xlabel('x')
42 |     plt.ylabel('y')
43 |     plt.subplot(2, 2, 2)
44 |     plt.imshow(np.transpose(x_out, (1, 2, 0)))
45 |     plt.xlabel('x')
46 |     plt.ylabel('y')
47 |     plt.subplot(2, 2, 3)
48 |     plt.imshow(y_in)
49 |     plt.xlabel('x')
50 |     plt.ylabel('y')
51 |     plt.subplot(2, 2, 4)
52 |     plt.imshow(y_out)
53 |     plt.xlabel('x')
54 |     plt.ylabel('y')
55 |     plt.show()
56 | 
57 | 
58 | if __name__ == '__main__':
59 |     main()
60 | 
--------------------------------------------------------------------------------
/tests/test_augmentator_3d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_bcnn.data.augmentor import DataAugmentor, Crop3D, Flip3D, Affine3D
3 | from pytorch_bcnn.data import load_image, save_image
4 | import time
5 | import argparse
6 | 
7 | def main():
8 |     parser = argparse.ArgumentParser()
9 |     parser.add_argument('--image', type=str, default='image.mhd')
10 |     parser.add_argument('--label', type=str, default='label.mhd')
11 |     args = parser.parse_args()
12 | 
13 |     augmentor = DataAugmentor()
14 |     augmentor.add(Crop3D(size=(100, 200, 300)))
15 |     augmentor.add(Flip3D(axis=2))
16 |     augmentor.add(Affine3D(
17 |         rotation=(15., 15., 15.),
18 |         translate=(10., 10., 10.),
19 |         shear=(np.pi / 8, np.pi / 8, np.pi / 8),
20 |         zoom=(0.8, 1.2),
21 |         keep_aspect_ratio=True,
22 |         fill_mode=('constant', 'constant'),
23 |         cval=(-3000., -1.),
24 |         interp_order=(0, 0)))
25 | 
26 |     augmentor.summary('augment.json')
27 | 
28 |     x_in, spacing = load_image(args.image)
29 |     x_in = np.expand_dims(x_in, axis=0)  # add channel-axis
30 |     x_in = x_in.astype(np.float32)
31 | 
32 |     y_in, _ = load_image(args.label)
33 |     y_in = y_in.astype(np.float32)
34 | 
35 |     tic = time.time()
36 |     x_out, y_out = augmentor.apply(x_in, y_in)
37 |     print('time: %f [sec]' % (time.time()-tic))
38 | 
39 |     save_image('x_out.mha', x_out[0, :, :, :], spacing)
40 |     save_image('y_out.mha', y_out, spacing)
41 | 
42 | 
43 | if __name__ == '__main__':
44 |     main()
45 | 
--------------------------------------------------------------------------------
/tests/test_cross_entropy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import chainer  # NOTE: chainer is imported on purpose; its softmax_cross_entropy is the reference
3 | 
4 | import torch
5 | from torch.nn.functional import cross_entropy
6 | 
7 | from pytorch_bcnn.functions.loss import softmax_cross_entropy
8 | 
9 | def test():
10 | 
11 |     b, c, w, h = 5, 10, 20, 30
12 | 
13 |     x = np.random.rand(b, c, w, h).astype(np.float32)
14 |     t = np.random.randint(0, c, (b, w, h)).astype(np.int32)
15 | 
16 |     ret = chainer.functions.softmax_cross_entropy(x, t, normalize=False)
17 |     print(ret.data)
18 | 
19 |     t = t.astype(np.int64)
20 |     ret = softmax_cross_entropy(torch.as_tensor(x), torch.as_tensor(t), normalize=False)
21 |     print(ret)
22 | 
23 |     ret = cross_entropy(torch.as_tensor(x), torch.as_tensor(t), reduction='sum')
24 |     print(ret)
25 | 
26 | if __name__ == '__main__':
27 |     test()
28 | 
--------------------------------------------------------------------------------
/tests/test_dataset.py:
--------------------------------------------------------------------------------
1 | from pytorch_bcnn.data import load_image, save_image
2 | from pytorch_bcnn.datasets import ImageDataset, VolumeDataset
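# NOTE: this test needs real data; it expects a layout like
#   {root}/{patient}/slice/*image*.mhd (2-D) and {root}/{patient}/*image*.mhd (3-D)
# (see the `filenames` dicts below), passed in via --root.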
3 | import numpy as np
4 | from collections import OrderedDict
5 | import matplotlib.pyplot as plt
6 | import argparse
7 | 
8 | patient_list = ['k1565', 'k1585']
9 | 
10 | class_list = ['background', 'pelvis', 'femur', 'adductor_muscles',
11 |               'biceps_femoris_muscle', 'gluteus_maximus_muscle',
12 |               'gluteus_medius_muscle', 'gluteus_minimus_muscle',
13 |               'gracilis_muscle', 'iliacus_muscle', 'obturator_externus_muscle',
14 |               'obturator_internus_muscle', 'pectineus_muscle',
15 |               'piriformis_muscle', 'psoas_major_muscle',
16 |               'rectus_femoris_muscle', 'sartorius_muscle',
17 |               'semimembranosus_muscle', 'semitendinosus_muscle',
18 |               'tensor_fasciae_latae_muscle',
19 |               'vastus_lateralis_muscle_and_vastus_intermedius_muscle',
20 |               'vastus_medialis_muscle', 'sacrum']
21 | 
22 | dtypes = OrderedDict({
23 |     'image': np.float32,
24 |     'label': np.int64,
25 |     'mask' : np.uint8,
26 | })
27 | 
28 | mask_cvals = OrderedDict({
29 |     'image': -3000,
30 |     'label': 0,
31 | })
32 | 
33 | 
34 | def test_2d(root):
35 | 
36 |     filenames = OrderedDict({
37 |         'image': '{root}/{patient}/slice/*image*.mhd',
38 |         'label': '{root}/{patient}/slice/*mask*.mhd',
39 |         'mask' : '{root}/{patient}/slice/*skin*.mhd',
40 |     })
41 | 
42 |     dataset = ImageDataset(root,
43 |                            patients=patient_list, classes=class_list,
44 |                            dtypes=dtypes, filenames=filenames,
45 |                            mask_cvals=mask_cvals)
46 | 
47 |     print('# dataset:', len(dataset))
48 |     print('# classes:', dataset.n_classes)
49 | 
50 |     sample = dataset.get_example(0)
51 | 
52 |     print(sample[0].shape)
53 |     print(sample[1].shape)
54 | 
55 |     plt.subplot(1, 2, 1)
56 |     plt.imshow(sample[0][0, :, :], cmap='gray')
57 |     plt.colorbar()
58 |     plt.subplot(1, 2, 2)
59 |     plt.imshow(sample[1][:, :])
60 |     plt.colorbar()
61 |     plt.show()
62 | 
63 | 
64 | def test_3d(root):
65 | 
66 |     filenames = OrderedDict({
67 |         'image': '{root}/{patient}/*image*.mhd',
68 |         'label': '{root}/{patient}/*mask*.mhd',
69 |         'mask' : '{root}/{patient}/*skin*.mhd',
70 |     })
71 | 
72 |     dataset = VolumeDataset(root,
73 |                             patients=patient_list, classes=class_list,
74 |                             dtypes=dtypes, filenames=filenames,
75 |                             mask_cvals=mask_cvals)
76 | 
77 |     print('# dataset:', len(dataset))
78 |     print('# classes:', dataset.n_classes)
79 | 
80 |     sample = dataset.get_example(0)
81 | 
82 |     print(sample[0].shape)
83 |     print(sample[1].shape)
84 | 
85 |     plt.subplot(1, 2, 1)
86 |     plt.imshow(sample[0][0, :, :, 100], cmap='gray')
87 |     plt.colorbar()
88 |     plt.subplot(1, 2, 2)
89 |     plt.imshow(sample[1][:, :, 100])
90 |     plt.colorbar()
91 |     plt.show()
92 | 
93 | 
94 | def main():
95 | 
96 |     parser = argparse.ArgumentParser()
97 |     parser.add_argument('--root', type=str, default='')
98 |     args = parser.parse_args()
99 | 
100 |     test_2d(args.root)
101 |     test_3d(args.root)
102 | 
103 | if __name__ == "__main__":
104 |     main()
105 | 
--------------------------------------------------------------------------------
/tests/test_dice_loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | from pytorch_bcnn.functions.accuracy.discrete_dice import _discrete_dice
5 | from pytorch_bcnn.functions.loss.dice import dice
6 | from pytorch_bcnn.functions.loss._helper import to_onehot
7 | 
8 | 
9 | def _dice(y, t):
10 |     y = y.astype(bool)
11 |     t = t.astype(bool)
12 |     # dice = 2 * |y & t| / (|y| + |t|)
13 |     return 2.
* np.logical_and(y, t).sum() / (y.sum() + t.sum()) 14 | 15 | if __name__ == '__main__': 16 | 17 | n_class = 3 18 | 19 | y = np.random.randint(0, n_class, (10, 100,200)).astype(np.int64) 20 | t = np.random.randint(0, n_class, (10, 100,200)).astype(np.int64) 21 | 22 | d_all = [] 23 | for i in range(1, n_class): 24 | d = _dice(y==i, t==i) 25 | print('class %d:' % i, d) 26 | d_all.append(d[np.newaxis]) 27 | 28 | print('mean:', np.mean(np.concatenate(d_all))) 29 | 30 | y = torch.as_tensor(y) 31 | t = torch.as_tensor(t) 32 | d = _discrete_dice(y, t, n_class, normalize=False, ignore_label=0) 33 | print(d) 34 | 35 | y = to_onehot(y, n_class).float() 36 | d = dice(y, t, normalize=False, ignore_label=0) 37 | print(d) 38 | -------------------------------------------------------------------------------- /tests/test_discriminator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import subprocess 4 | import warnings 5 | import torch 6 | from torchviz import make_dot 7 | from pytorch_bcnn.models import PatchDiscriminator 8 | 9 | def main(): 10 | 11 | conv_param = { 12 | 'name':'conv', 13 | 'kernel_size': 3, 14 | 'stride': 1, 15 | 'padding': 1, 16 | 'initialW': {'name': 'normal', 'std': 0.02}, 17 | 'initial_bias': {'name': 'zero'}, 18 | 'hook': {'name': 'spectral_normalization'} 19 | } 20 | 21 | pool_param = { 22 | 'name': 'stride', 23 | 'stride': 2 24 | } 25 | 26 | norm_param = { 27 | 'name': 'batch' 28 | } 29 | 30 | activation_param = { 31 | 'name': 'leaky_relu' 32 | } 33 | 34 | dropout_param = { 35 | 'name': 'none' 36 | } 37 | 38 | model = PatchDiscriminator( 39 | ndim=2, 40 | in_channels=1, 41 | out_channels=1, 42 | nlayer=5, 43 | nfilter=32, 44 | conv_param=conv_param, 45 | pool_param=pool_param, 46 | norm_param=norm_param, 47 | activation_param=activation_param, 48 | dropout_param=dropout_param) 49 | 50 | x = np.random.rand(2, 1, 256, 256).astype(np.float32) 51 | y = model(torch.as_tensor(x)) 52 | print(y.shape) 53 | 54 | dot = make_dot(y, params=dict(model.named_parameters())) 55 | dot.render('graph_2d_discriminator', format='pdf') 56 | 57 | model.save_args('2d_discriminator.json') 58 | model.show_statistics() 59 | print('-----') 60 | 61 | print(model.count_params()) 62 | print('-----') 63 | 64 | # check the shape of the first left singular vector. 
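    # (spectral normalization keeps a running estimate of the first left singular
    # vector as a buffer; for a conv weight viewed as an (out_channels, -1) matrix,
    # `weight_u` is expected to have shape (out_channels,))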
65 |     vector_name = 'weight_u'
66 |     print(vector_name)
67 |     print(getattr(model.block_0.conv_0, vector_name).shape)
68 |     print('-----')
69 | 
70 | 
71 | if __name__ == '__main__':
72 |     main()
73 | 
--------------------------------------------------------------------------------
/tests/test_inferencer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pytorch_trainer.iterators import SerialIterator
4 | from pytorch_bcnn.models import UNet, BayesianUNet
5 | from pytorch_bcnn.links import MCSampler
6 | from pytorch_bcnn.inference import Inferencer
7 | 
8 | from pytorch_trainer.dataset import DatasetMixin
9 | from pytorch_trainer.dataset import convert_to_tensor
10 | 
11 | 
12 | class Dataset(DatasetMixin):
13 | 
14 |     def __init__(self, n_samples, shape, dtype=np.float32):
15 |         self._n_samples = n_samples
16 |         self._shape = shape
17 |         self._dtype = dtype
18 | 
19 |     def __len__(self):
20 |         return self._n_samples
21 | 
22 |     @convert_to_tensor
23 |     def get_example(self, i):
24 |         return np.random.rand(*self._shape).astype(self._dtype)
25 | 
26 | 
27 | def test(predictor, shape, batch_size, gpu, to_numpy):
28 | 
29 |     print('------')
30 | 
31 |     n_samples = 10
32 |     dataset = Dataset(n_samples, shape)
33 | 
34 |     model = MCSampler(predictor, mc_iteration=5)
35 |     model.eval()
36 | 
37 |     device = torch.device(gpu)
38 |     model.to(device)
39 | 
40 |     iterator = SerialIterator(dataset, batch_size, repeat=False)
41 | 
42 |     infer = Inferencer(iterator, model, device=gpu, to_numpy=to_numpy)
43 | 
44 |     ret = infer.run()
45 | 
46 |     if isinstance(ret, (list, tuple)):
47 |         for r in ret:
48 |             print(r.shape)
49 |             print(r.__class__)
50 |     else:
51 |         print(ret.shape)
52 |         print(ret.__class__)
53 | 
54 | 
55 | def main():
56 |     test(BayesianUNet(ndim=2, in_channels=1, out_channels=5),
57 |          (1, 200, 300),
58 |          batch_size=2,
59 |          gpu='cuda',
60 |          to_numpy=True)
61 | 
62 |     test(BayesianUNet(ndim=3, in_channels=1, out_channels=5, nlayer=3),
63 |          (1, 64, 64, 64),
64 |          batch_size=2,
65 |          gpu='cuda',
66 |          to_numpy=True)
67 | 
68 | 
69 | if __name__ == '__main__':
70 |     main()
71 | 
--------------------------------------------------------------------------------
/tests/test_initializer.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from pytorch_bcnn.initializers import bilinear_upsample
6 | from pytorch_bcnn.functions import crop
7 | import matplotlib.pyplot as plt
8 | 
9 | def main():
10 |     x = cv2.imread('lenna.png')
11 |     x = cv2.resize(x, (64, 64))
12 |     x = np.transpose(x, (2,0,1))
13 |     x = np.expand_dims(x, axis=0).astype(np.float32)
14 | 
15 |     c = x.shape[1]
16 | 
17 |     deconv = nn.ConvTranspose2d(c, c, kernel_size=(3,3), stride=2, padding=(0,0), bias=False)
18 |     bilinear_upsample(deconv.weight)  # initialize the kernel so the deconvolution behaves as bilinear upsampling
19 | 
20 |     y = deconv(torch.as_tensor(x))
21 |     y = y.detach().numpy()
22 | 
23 |     plt.subplot(1,2,1)
24 |     plt.imshow(x[0,0], cmap='gray')
25 |     plt.colorbar()
26 |     plt.subplot(1,2,2)
27 |     plt.imshow(y[0,0], cmap='gray')
28 |     plt.colorbar()
29 |     plt.show()
30 | 
31 | 
32 | if __name__ == '__main__':
33 |     main()
34 | 
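tests/test_mc_sampler.py below checks `MCSampler` against a direct reference computation: given T stochastic forward passes y_1, ..., y_T, the predictive mean is their per-element average and the predictive uncertainty is their per-element variance. The same reference in plain numpy (shapes are illustrative):

import numpy as np

samples = np.random.rand(10, 5, 2)  # T=10 Monte Carlo samples of a (5, 2) output
mean_pred = samples.mean(axis=0)    # predictive mean, shape (5, 2)
var_pred = samples.var(axis=0)      # predictive uncertainty, shape (5, 2)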
--------------------------------------------------------------------------------
/tests/test_mc_sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pytorch_bcnn.links import MCSampler
4 | 
5 | def _calc_uncertainty_from_mc_samples(samples):
6 | 
7 |     mean_pred = samples.mean(dim=0, keepdim=False)
8 |     var_pred = samples.var(dim=0, keepdim=False)
9 | 
10 |     return mean_pred, var_pred
11 | 
12 | 
13 | def main():
14 |     mc_iteration = 10
15 |     mc_samples = np.random.rand(1, 10, 2).astype(np.float32)
16 |     mc_samples = np.repeat(mc_samples, mc_iteration, axis=0)
17 |     mc_samples = torch.Tensor(mc_samples)
18 | 
19 |     _mean, _var = _calc_uncertainty_from_mc_samples(mc_samples)
20 | 
21 |     print('reference')
22 |     print(_mean)
23 |     print(_var)
24 |     print('------')
25 | 
26 |     sampler = MCSampler(lambda x: x, mc_iteration, lambda x: x, None, None)
27 |     sampler.eval()
28 | 
29 |     mean, var = sampler(mc_samples[0])
30 | 
31 |     print('mc_sampler')
32 |     print(mean)
33 |     print(var)
34 | 
35 |     print((np.abs(mean-_mean)))
36 |     print((np.abs(var-_var)))
37 |     print('------')
38 | 
39 | if __name__ == '__main__':
40 |     main()
41 | 
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import subprocess
4 | import warnings
5 | 
6 | import torch
7 | from torchviz import make_dot
8 | from pytorch_bcnn.models import UNet, BayesianUNet
9 | 
10 | def main():
11 |     model = BayesianUNet(ndim=2, in_channels=1, out_channels=10)
12 |     x = np.random.rand(2, 1, 200, 300).astype(np.float32)
13 |     y = model(torch.Tensor(x))
14 | 
15 |     dot = make_dot(y, params=dict(model.named_parameters()))
16 |     dot.render('graph_2d_unet', format='pdf')
17 | 
18 |     print(y.shape)
19 |     model.save_args('2d_unet.json')
20 |     model.show_statistics()
21 |     print(model.count_params())
22 | 
23 |     print('-----')
24 | 
25 |     model = BayesianUNet(ndim=3, in_channels=1, out_channels=10, nlayer=3)
26 |     x = np.random.rand(2, 1, 20, 30, 10).astype(np.float32)
27 |     y = model(torch.Tensor(x))
28 | 
29 |     dot = make_dot(y, params=dict(model.named_parameters()))
30 |     dot.render('graph_3d_unet', format='pdf')
31 | 
32 |     print(y.shape)
33 |     model.save_args('3d_unet.json')
34 |     model.show_statistics()
35 |     print(model.count_params())
36 | 
37 |     print('-----')
38 | 
39 |     model.freeze_layers('upconv', verbose=True)
40 |     print('-----')
41 |     model.freeze_layers('upconv', recursive=False, verbose=True)
42 |     print('-----')
43 |     model.freeze_layers(startwith='upconv', verbose=True)
44 |     print('-----')
45 |     model.freeze_layers(endwith='norm', verbose=True)
46 | 
47 |     model.show_statistics()
48 | 
49 | 
50 | if __name__ == '__main__':
51 |     main()
52 | 
--------------------------------------------------------------------------------
/tests/test_normalizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytorch_bcnn.data.normalizer import Normalizer, Quantize2D, Clip2D, Subtract2D, Divide2D
3 | 
4 | import cv2
5 | import time
6 | import matplotlib.pyplot as plt
7 | 
8 | 
9 | def main():
10 |     normalizer = Normalizer()
11 |     normalizer.add(Clip2D('ch_minmax'))
12 |     # normalizer.add(Quantize2D(n_bit=3))
13 |     # normalizer.add(Subtract2D('ch_mean'))
14 |     # normalizer.add(Divide2D('ch_std'))
15 |     normalizer.add(Subtract2D(0.5))
16 |     normalizer.add(Divide2D(0.5))
17 |     normalizer.summary('norm.json')
18 | 
19 |     x_in = cv2.imread('lenna.png').astype(np.float32)
20 |     x_in = x_in[:, :, ::-1]
21 |     x_in = np.transpose(x_in, (2, 0, 1))
22 |     print(x_in.shape)
23 | 
24 |     tic = time.time()
25 |     x_out = normalizer.apply(x_in)
26 |     print('time: %f [sec]' % (time.time()-tic))
27 | 
28 |     print(x_out.shape)
29 | 
30 |     plt.imshow(np.transpose(x_out, (1, 2, 0))[:,:,0])
31
| plt.xlabel('x') 32 | plt.ylabel('y') 33 | plt.colorbar() 34 | plt.show() 35 | 36 | 37 | if __name__ == '__main__': 38 | main() 39 | -------------------------------------------------------------------------------- /tests/test_seed.py: -------------------------------------------------------------------------------- 1 | from pytorch_bcnn.utils import fixed_seed 2 | 3 | 4 | def main(): 5 | import random 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | print(np.random.rand(10)) 11 | print(random.random()) 12 | print(nn.Conv2d(2, 2, 3, 1, 0).weight.data) 13 | print(torch.backends.cudnn.deterministic) 14 | print(torch.backends.cudnn.benchmark) 15 | 16 | print('------') 17 | 18 | with fixed_seed(0, True): 19 | print(np.random.rand(10)) 20 | print(random.random()) 21 | print(nn.Conv2d(2, 2, 3, 1, 0).weight.data) 22 | print(torch.backends.cudnn.deterministic) 23 | print(torch.backends.cudnn.benchmark) 24 | 25 | print(np.random.rand(10)) 26 | print(random.random()) 27 | print(nn.Conv2d(2, 2, 3, 1, 0).weight.data) 28 | 29 | print(torch.backends.cudnn.deterministic) 30 | print(torch.backends.cudnn.benchmark) 31 | 32 | if __name__ == '__main__': 33 | main() 34 | -------------------------------------------------------------------------------- /tests/test_visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pytorch_bcnn.visualizer import ImageVisualizer 4 | from pytorch_bcnn.visualizer.image import _default_cmap 5 | 6 | 7 | def test_classification_sparse(): 8 | 9 | _categorical_cmaps = { 10 | 'y': _default_cmap, 11 | 't': _default_cmap, 12 | } 13 | 14 | _categorical_clims = { 15 | 'x': (0., 1.), 16 | } 17 | 18 | _categorical_transforms = { 19 | 'x': lambda x: x, 20 | 'y': lambda x: x, 21 | 't': lambda x: x, 22 | } 23 | 24 | # NOTE: out is an unexpected argument. 25 | visualizer = ImageVisualizer(transforms=_categorical_transforms, 26 | cmaps=_categorical_cmaps, 27 | clims=_categorical_clims) 28 | 29 | x = torch.as_tensor(np.random.rand(3, 100, 200)) 30 | y = torch.as_tensor(np.random.randint(0, 10, (100, 200))) 31 | t = torch.as_tensor(np.random.randint(0, 10, (100, 200))) 32 | 33 | for _ in range(3): 34 | visualizer.add_example(x, y, t) 35 | visualizer.save('test_classification_sparse.png') 36 | 37 | 38 | def test_classification(): 39 | 40 | _categorical_cmaps = { 41 | 'y': _default_cmap, 42 | 't': _default_cmap, 43 | } 44 | 45 | _categorical_clims = { 46 | 'x': (0., 1.), 47 | } 48 | 49 | _categorical_transforms = { 50 | 'x': lambda x: x, 51 | 'y': lambda x: np.argmax(x, axis=0), 52 | 't': lambda x: np.argmax(x, axis=0), 53 | } 54 | 55 | # NOTE: out is an unexpected argument. 
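    # (`transforms`, `cmaps` and `clims` are keyed by the visualizer's input
    # names 'x', 'y' and 't'; here y/t are score maps, so argmax recovers a
    # label image before the categorical colormap is applied)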
56 | visualizer = ImageVisualizer(transforms=_categorical_transforms, 57 | cmaps=_categorical_cmaps, 58 | clims=_categorical_clims) 59 | 60 | x = torch.as_tensor(np.random.rand(3, 100, 200)) 61 | y = torch.as_tensor(np.random.rand(10, 100, 200)) 62 | t = torch.as_tensor(np.random.rand(10, 100, 200)) 63 | 64 | for _ in range(3): 65 | visualizer.add_example(x, y, t) 66 | visualizer.save('test_classification.png') 67 | 68 | 69 | def test_regression(): 70 | 71 | _regression_cmaps = None 72 | 73 | _regression_clims = { 74 | 'x': (0., 1.), 75 | 'y': (0., 1.), 76 | 't': (0., 1.), 77 | } 78 | 79 | _regression_transforms = None 80 | 81 | visualizer = ImageVisualizer(transforms=_regression_transforms, 82 | cmaps=_regression_cmaps, 83 | clims=_regression_clims) 84 | 85 | x = torch.as_tensor(np.random.rand(3, 100, 200)) 86 | y = torch.as_tensor(np.random.rand(5, 100, 200)) 87 | t = torch.as_tensor(np.random.rand(5, 100, 200)) 88 | 89 | for _ in range(3): 90 | visualizer.add_example(x, y, t) 91 | visualizer.save('test_regression.png') 92 | 93 | 94 | def main(): 95 | test_classification_sparse() 96 | test_classification() 97 | test_regression() 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | --------------------------------------------------------------------------------
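Taken together, the tests sketch the intended inference workflow: wrap a `BayesianUNet` in `MCSampler` and drive it with `Inferencer`. A minimal end-to-end sketch, assuming the APIs exercised in tests/test_inferencer.py (`dataset` stands for any `DatasetMixin`, such as the random stub in that file; that `run()` yields the mean/uncertainty pair is inferred from test_mc_sampler.py, not guaranteed here):

import torch
from pytorch_trainer.iterators import SerialIterator
from pytorch_bcnn.models import BayesianUNet
from pytorch_bcnn.links import MCSampler
from pytorch_bcnn.inference import Inferencer

predictor = BayesianUNet(ndim=2, in_channels=1, out_channels=5)
model = MCSampler(predictor, mc_iteration=10)  # MC dropout at test time
model.eval()
model.to(torch.device('cuda'))

iterator = SerialIterator(dataset, batch_size=2, repeat=False)
ret = Inferencer(iterator, model, device='cuda', to_numpy=True).run()
# with MCSampler, `ret` holds the predictive mean and uncertainty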