├── .gitignore
├── WHAT_asset
│   ├── al.png
│   ├── com.png
│   ├── ep.png
│   ├── label.png
│   └── psnr.png
├── WHAT_src
│   ├── config.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_camvid.py
│   │   └── data_nyu.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── mse.py
│   │   └── mse_var.py
│   ├── main.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── aleatoric.py
│   │   ├── combined.py
│   │   ├── common.py
│   │   ├── epistemic.py
│   │   └── normal.py
│   ├── op.py
│   └── util.py
└── readme.md

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
.idea

WHAT_exp
WHAT_papr
WHAT_ref
WHAT_plan.md

*.pyc
*.xml
--------------------------------------------------------------------------------
/WHAT_asset/al.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_asset/al.png
--------------------------------------------------------------------------------
/WHAT_asset/com.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_asset/com.png
--------------------------------------------------------------------------------
/WHAT_asset/ep.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_asset/ep.png
--------------------------------------------------------------------------------
/WHAT_asset/label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_asset/label.png
--------------------------------------------------------------------------------
/WHAT_asset/psnr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_asset/psnr.png
--------------------------------------------------------------------------------
/WHAT_src/config.py:
--------------------------------------------------------------------------------
import argparse
import os
from distutils.util import strtobool

parser = argparse.ArgumentParser()

# Environment
parser.add_argument("--is_train", type=strtobool, default='true')
parser.add_argument("--tensorboard", type=strtobool, default='true')
parser.add_argument('--cpu', action='store_true', help='use cpu only')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument("--num_gpu", type=int, default=1)
parser.add_argument("--num_work", type=int, default=8)
parser.add_argument("--exp_dir", type=str, default="../WHAT_exp")
parser.add_argument("--exp_load", type=str, default=None)

# Data
parser.add_argument("--data_dir", type=str, default="/mnt/sda")
parser.add_argument("--data_name", type=str, default="fashion_mnist")
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--rgb_range', type=int, default=1)

# Model
parser.add_argument('--uncertainty', default='normal',
                    choices=('normal', 'epistemic', 'aleatoric', 'combined'))
parser.add_argument('--in_channels', type=int, default=1)
parser.add_argument('--n_feats', type=int, default=32)
parser.add_argument('--var_weight', type=float, default=1.)
parser.add_argument('--drop_rate', type=float, default=0.2)

# Train
parser.add_argument("--epochs", type=int, default=200)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--decay", type=str, default='50-100-150-200')
parser.add_argument("--gamma", type=float, default=0.5)
parser.add_argument("--optimizer", type=str, default='rmsprop',
                    choices=('sgd', 'adam', 'rmsprop'))
parser.add_argument("--weight_decay", type=float, default=1e-4)
parser.add_argument("--momentum", type=float, default=0.9)
# note: type=tuple would split a CLI string into characters; only the default is usable
parser.add_argument("--betas", type=tuple, default=(0.9, 0.999))
parser.add_argument("--epsilon", type=float, default=1e-8)

# Test
parser.add_argument('--n_samples', type=int, default=25)


def save_args(obj, defaults, kwargs):
    for k, v in defaults.items():  # .iteritems() is Python 2; this repo targets Python 3
        if k in kwargs:
            v = kwargs[k]
        setattr(obj, k, v)


def get_config():
    config = parser.parse_args()
    config.data_dir = os.path.expanduser(config.data_dir)
    return config
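For orientation, a minimal sketch of how these flags are consumed downstream (the command line in the comment is illustrative, not from the repo):

    # e.g. python main.py --uncertainty combined --batch_size 64
    from config import get_config

    config = get_config()        # parses sys.argv with the parser above
    print(config.uncertainty)    # -> 'combined'
    print(config.data_dir)       # '~' has already been expanded by get_config()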
--------------------------------------------------------------------------------
/WHAT_src/data/__init__.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms


def get_dataloader(config):
    data_dir = config.data_dir
    batch_size = config.batch_size

    trans = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.0,), (1.0,))])

    if config.data_name == 'mnist':
        train_dataset = dset.MNIST(root=data_dir, train=True, transform=trans, download=True)
        test_dataset = dset.MNIST(root=data_dir, train=False, transform=trans, download=True)
    elif config.data_name == 'fashion_mnist':
        train_dataset = dset.FashionMNIST(root=data_dir, train=True, transform=trans, download=True)
        test_dataset = dset.FashionMNIST(root=data_dir, train=False, transform=trans, download=True)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              num_workers=config.num_work, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                             num_workers=config.num_work, shuffle=False)

    print('==>>> total training batch number: {}'.format(len(train_loader)))
    print('==>>> total testing batch number: {}'.format(len(test_loader)))

    data_loader = {'train': train_loader, 'test': test_loader}

    return data_loader
--------------------------------------------------------------------------------
/WHAT_src/data/data_camvid.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_src/data/data_camvid.py
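A minimal sketch of consuming the returned dict (shapes assume the fashion_mnist default):

    from config import get_config
    from data import get_dataloader

    config = get_config()                      # defaults: fashion_mnist, batch_size=32
    loader = get_dataloader(config)
    images, labels = next(iter(loader['train']))
    print(images.shape)                        # torch.Size([32, 1, 28, 28])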
--------------------------------------------------------------------------------
/WHAT_src/data/data_nyu.py:
--------------------------------------------------------------------------------
'''
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
'''

import h5py
import torch
from torch.utils.data import Dataset, DataLoader

FILE_PATH = '/Users/chokiheum/Data/nyu_v2/nyu_depth_v2_labeled.mat'


class NYU_v2(Dataset):
    def __init__(self, file_path, transform=None):
        self.transform = transform
        self.file_path = file_path

    def __getitem__(self, idx):
        with h5py.File(self.file_path, 'r') as f:
            depth = torch.from_numpy(f['depths'][idx].astype('float32'))
            image = torch.from_numpy(f['images'][idx].astype('float32'))
            label = torch.from_numpy(f['labels'][idx].astype('float32'))

        sample = {'depth': depth, 'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample

    def __len__(self):
        with h5py.File(self.file_path, 'r') as f:
            length = len(f['images'])
        return length


if __name__ == '__main__':
    nyu_dataset = NYU_v2(file_path=FILE_PATH)
    dataloader = DataLoader(nyu_dataset,
                            batch_size=4, shuffle=True, num_workers=4)

    for i_batch, sample_batched in enumerate(dataloader):
        print(sample_batched['image'])
--------------------------------------------------------------------------------
/WHAT_src/loss/__init__.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from importlib import import_module


class Loss(nn.Module):
    def __init__(self, config):
        super(Loss, self).__init__()
        print('Preparing loss function...')

        self.num_gpu = config.num_gpu
        self.losses = []
        self.loss_module = nn.ModuleList()

        if config.uncertainty == 'epistemic' or config.uncertainty == 'normal':
            module = import_module('loss.mse')
            loss_function = getattr(module, 'MSE')()
        else:
            module = import_module('loss.mse_var')
            loss_function = getattr(module, 'MSE_VAR')(
                var_weight=config.var_weight)

        self.losses.append({'function': loss_function})

        self.loss_module.to(config.device)
        if not config.cpu and config.num_gpu > 1:
            self.loss_module = nn.DataParallel(
                self.loss_module, range(self.num_gpu))

    def forward(self, results, label):
        losses = []
        for i, l in enumerate(self.losses):
            if l['function'] is not None:
                loss = l['function'](results, label)
                losses.append(loss)

        # only one loss term is ever registered above, so the sum is a no-op
        # (a stray multi-loss logging branch that referenced an undefined
        # self.log attribute has been removed)
        loss_sum = sum(losses)
        return loss_sum
--------------------------------------------------------------------------------
/WHAT_src/loss/mse.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class MSE(nn.Module):
    def __init__(self):
        super(MSE, self).__init__()

    def forward(self, results, label):
        mean = results['mean']
        loss = F.mse_loss(mean, label)
        return loss
--------------------------------------------------------------------------------
/WHAT_src/loss/mse_var.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class MSE_VAR(nn.Module):
    def __init__(self, var_weight):
        super(MSE_VAR, self).__init__()
        self.var_weight = var_weight

    def forward(self, results, label):
        # 'var' is the predicted log variance s = log(sigma^2)
        mean, var = results['mean'], results['var']
        var = self.var_weight * var

        loss1 = torch.mul(torch.exp(-var), (mean - label) ** 2)
        loss2 = var
        loss = .5 * (loss1 + loss2)
        return loss.mean()
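MSE_VAR above is the learned loss attenuation objective from the paper: the network's 'var' head is trained as a log variance s = log sigma^2 (scaled here by var_weight), giving the numerically stable form

    L = \frac{1}{N} \sum_i \left[ \tfrac{1}{2} \exp(-s_i)\,(y_i - \hat{y}_i)^2 + \tfrac{1}{2} s_i \right]

The exp(-s) factor down-weights the residual where the network predicts high noise, while the +s/2 term penalizes declaring large variance everywhere.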
--------------------------------------------------------------------------------
/WHAT_src/main.py:
--------------------------------------------------------------------------------
import torch

from config import get_config
from data import get_dataloader
from op import Operator
from util import Checkpoint


def main(config):
    config.device = torch.device('cuda:{}'.format(config.gpu)
                                 if torch.cuda.is_available() else 'cpu')

    # load data_loader
    data_loader = get_dataloader(config)
    check_point = Checkpoint(config)
    operator = Operator(config, check_point)

    if config.is_train:
        operator.train(data_loader)
    else:
        operator.test(data_loader)


if __name__ == "__main__":
    config = get_config()
    main(config)
--------------------------------------------------------------------------------
/WHAT_src/model/__init__.py:
--------------------------------------------------------------------------------
import os
from importlib import import_module

import torch
import torch.nn as nn
import torch.nn.parallel as P


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        print('Making model...')

        self.is_train = config.is_train
        self.num_gpu = config.num_gpu
        self.uncertainty = config.uncertainty
        self.n_samples = config.n_samples
        module = import_module('model.' + config.uncertainty)
        self.model = module.make_model(config).to(config.device)

    def forward(self, input):
        if self.model.training:
            if self.num_gpu > 1:
                return P.data_parallel(self.model, input,
                                       list(range(self.num_gpu)))
            else:
                return self.model.forward(input)
        else:
            forward_func = self.model.forward
            if self.uncertainty == 'normal':
                return forward_func(input)
            elif self.uncertainty == 'aleatoric':
                return self.test_aleatoric(input, forward_func)
            elif self.uncertainty == 'epistemic':
                return self.test_epistemic(input, forward_func)
            elif self.uncertainty == 'combined':
                return self.test_combined(input, forward_func)

    def test_aleatoric(self, input, forward_func):
        # single pass; the network's 'var' head predicts log(sigma^2)
        results = forward_func(input)
        mean1 = results['mean']
        var1 = torch.exp(results['var'])
        var1_norm = var1 / var1.max()
        results = {'mean': mean1, 'var': var1_norm}
        return results

    def test_epistemic(self, input, forward_func):
        mean1s = []  # squared predictions, for E[y^2]
        mean2s = []  # raw predictions, for E[y]

        for i_sample in range(self.n_samples):
            results = forward_func(input)
            mean1 = results['mean']
            mean1s.append(mean1 ** 2)
            mean2s.append(mean1)

        mean1s_ = torch.stack(mean1s, dim=0).mean(dim=0)
        mean2s_ = torch.stack(mean2s, dim=0).mean(dim=0)

        # predictive variance of the MC-dropout samples: E[y^2] - E[y]^2
        var1 = mean1s_ - mean2s_ ** 2
        var1_norm = var1 / var1.max()
        results = {'mean': mean2s_, 'var': var1_norm}
        return results

    def test_combined(self, input, forward_func):
        mean1s = []
        mean2s = []
        var1s = []

        for i_sample in range(self.n_samples):
            results = forward_func(input)
            mean1 = results['mean']
            mean1s.append(mean1 ** 2)
            mean2s.append(mean1)
            var1 = results['var']
            var1s.append(torch.exp(var1))

        mean1s_ = torch.stack(mean1s, dim=0).mean(dim=0)
        mean2s_ = torch.stack(mean2s, dim=0).mean(dim=0)

        # aleatoric term (mean predicted variance) + epistemic term (sample variance)
        var1s_ = torch.stack(var1s, dim=0).mean(dim=0)
        var2 = mean1s_ - mean2s_ ** 2
        var_ = var1s_ + var2
        var_norm = var_ / var_.max()
        results = {'mean': mean2s_, 'var': var_norm}
        return results

    def save(self, ckpt, epoch):
        save_dirs = [os.path.join(ckpt.model_dir, 'model_latest.pt')]
        save_dirs.append(
            os.path.join(ckpt.model_dir, 'model_{}.pt'.format(epoch)))
        for s in save_dirs:
            torch.save(self.model.state_dict(), s)

    def load(self, ckpt, cpu=False):
        epoch = ckpt.last_epoch
        kwargs = {}
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        if epoch == -1:
            load_from = torch.load(
                os.path.join(ckpt.model_dir, 'model_latest.pt'), **kwargs)
        else:
            load_from = torch.load(
                os.path.join(ckpt.model_dir, 'model_{}.pt'.format(epoch)), **kwargs)
        if load_from:
            self.model.load_state_dict(load_from, strict=False)
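test_epistemic and test_combined above implement the paper's Monte-Carlo approximations over T = n_samples stochastic forward passes with dropout kept on. In that notation the combined predictive variance is

    \operatorname{Var}[y] \approx \underbrace{\frac{1}{T}\sum_{t=1}^{T} \exp(s_t)}_{\text{aleatoric}} + \underbrace{\frac{1}{T}\sum_{t=1}^{T} \hat{y}_t^{\,2} - \left(\frac{1}{T}\sum_{t=1}^{T} \hat{y}_t\right)^{2}}_{\text{epistemic}}

with test_epistemic keeping only the second term. Both maps are divided by their maximum before visualization, so they show relative rather than absolute uncertainty.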
--------------------------------------------------------------------------------
/WHAT_src/model/aleatoric.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

from model import common


def make_model(args):
    return ALEATORIC(args)


class ALEATORIC(nn.Module):
    def __init__(self, config):
        super(ALEATORIC, self).__init__()
        self.drop_rate = config.drop_rate
        in_channels = config.in_channels
        filter_config = (64, 128)

        self.encoders = nn.ModuleList()
        self.decoders_mean = nn.ModuleList()
        self.decoders_var = nn.ModuleList()

        # setup number of conv-bn-relu blocks per module and number of filters
        encoder_n_layers = (2, 2, 3, 3, 3)
        encoder_filter_config = (in_channels,) + filter_config
        decoder_n_layers = (3, 3, 3, 2, 1)
        decoder_filter_config = filter_config[::-1] + (filter_config[0],)

        for i in range(0, 2):
            # encoder architecture
            self.encoders.append(_Encoder(encoder_filter_config[i],
                                          encoder_filter_config[i + 1],
                                          encoder_n_layers[i]))

            # decoder architecture: one head for the mean ...
            self.decoders_mean.append(_Decoder(decoder_filter_config[i],
                                               decoder_filter_config[i + 1],
                                               decoder_n_layers[i]))

            # ... and one head for the (log) variance
            self.decoders_var.append(_Decoder(decoder_filter_config[i],
                                              decoder_filter_config[i + 1],
                                              decoder_n_layers[i]))

        # final classifier (equivalent to a fully connected layer)
        self.classifier_mean = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)
        self.classifier_var = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)

    def forward(self, x):
        indices = []
        unpool_sizes = []
        feat = x

        # encoder path, keep track of pooling indices and features size
        for i in range(0, 2):
            (feat, ind), size = self.encoders[i](feat)
            if i == 1:
                # note: F.dropout defaults to training=True, so this stays active in eval mode too
                feat = F.dropout(feat, p=self.drop_rate)
            indices.append(ind)
            unpool_sizes.append(size)

        feat_mean = feat
        feat_var = feat
        # decoder path, upsampling with corresponding indices and size
        for i in range(0, 2):
            feat_mean = self.decoders_mean[i](feat_mean, indices[1 - i], unpool_sizes[1 - i])
            feat_var = self.decoders_var[i](feat_var, indices[1 - i], unpool_sizes[1 - i])
            if i == 0:
                feat_mean = F.dropout(feat_mean, p=self.drop_rate)
                feat_var = F.dropout(feat_var, p=self.drop_rate)

        output_mean = self.classifier_mean(feat_mean)
        output_var = self.classifier_var(feat_var)

        results = {'mean': output_mean, 'var': output_var}
        return results


class _Encoder(nn.Module):
    def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
        """Encoder layer follows VGG rules + keeps pooling indices
        Args:
            n_in_feat (int): number of input features
            n_out_feat (int): number of output features
            n_blocks (int): number of conv-batch-relu blocks inside the encoder
        """
        super(_Encoder, self).__init__()

        layers = [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
                  nn.BatchNorm2d(n_out_feat),
                  nn.ReLU()]

        if n_blocks > 1:
            layers += [nn.Conv2d(n_out_feat, n_out_feat, 3, 1, 1),
                       nn.BatchNorm2d(n_out_feat),
                       nn.ReLU()]

        self.features = nn.Sequential(*layers)

    def forward(self, x):
        output = self.features(x)
        return F.max_pool2d(output, 2, 2, return_indices=True), output.size()


class _Decoder(nn.Module):
    """Decoder layer decodes the features by unpooling with respect to
    the pooling indices of the corresponding encoder part.
    Args:
        n_in_feat (int): number of input features
        n_out_feat (int): number of output features
        n_blocks (int): number of conv-batch-relu blocks inside the decoder
    """

    def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
        super(_Decoder, self).__init__()

        layers = [nn.Conv2d(n_in_feat, n_in_feat, 3, 1, 1),
                  nn.BatchNorm2d(n_in_feat),
                  nn.ReLU()]

        if n_blocks > 1:
            layers += [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
                       nn.BatchNorm2d(n_out_feat),
                       nn.ReLU()]

        self.features = nn.Sequential(*layers)

    def forward(self, x, indices, size):
        unpooled = F.max_unpool2d(x, indices, 2, 2, 0, size)
        return self.features(unpooled)
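A quick shape check of the two-headed model (a sketch: SimpleNamespace stands in for the parsed config, 28x28 matches the (Fashion-)MNIST inputs, and it assumes you run from WHAT_src):

    import torch
    from types import SimpleNamespace
    from model.aleatoric import ALEATORIC

    cfg = SimpleNamespace(drop_rate=0.2, in_channels=1)  # minimal stand-in config
    net = ALEATORIC(cfg)
    out = net(torch.randn(2, 1, 28, 28))
    print(out['mean'].shape, out['var'].shape)           # both torch.Size([2, 1, 28, 28])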
--------------------------------------------------------------------------------
/WHAT_src/model/combined.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

from model import common


def make_model(args):
    return COMBINED(args)


class COMBINED(nn.Module):
    # identical to ALEATORIC except that dropout is forced on (training=True),
    # so the aleatoric 'var' head can be combined with MC-dropout sampling
    def __init__(self, config):
        super(COMBINED, self).__init__()
        self.drop_rate = config.drop_rate
        in_channels = config.in_channels
        filter_config = (64, 128)

        self.encoders = nn.ModuleList()
        self.decoders_mean = nn.ModuleList()
        self.decoders_var = nn.ModuleList()

        # setup number of conv-bn-relu blocks per module and number of filters
        encoder_n_layers = (2, 2, 3, 3, 3)
        encoder_filter_config = (in_channels,) + filter_config
        decoder_n_layers = (3, 3, 3, 2, 1)
        decoder_filter_config = filter_config[::-1] + (filter_config[0],)

        for i in range(0, 2):
            # encoder architecture
            self.encoders.append(_Encoder(encoder_filter_config[i],
                                          encoder_filter_config[i + 1],
                                          encoder_n_layers[i]))

            # decoder architecture: one head for the mean ...
            self.decoders_mean.append(_Decoder(decoder_filter_config[i],
                                               decoder_filter_config[i + 1],
                                               decoder_n_layers[i]))

            # ... and one head for the (log) variance
            self.decoders_var.append(_Decoder(decoder_filter_config[i],
                                              decoder_filter_config[i + 1],
                                              decoder_n_layers[i]))

        # final classifier (equivalent to a fully connected layer)
        self.classifier_mean = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)
        self.classifier_var = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)

    def forward(self, x):
        indices = []
        unpool_sizes = []
        feat = x

        # encoder path, keep track of pooling indices and features size
        for i in range(0, 2):
            (feat, ind), size = self.encoders[i](feat)
            if i == 1:
                # training=True keeps dropout stochastic at test time (MC dropout)
                feat = F.dropout(feat, p=self.drop_rate, training=True)
            indices.append(ind)
            unpool_sizes.append(size)

        feat_mean = feat
        feat_var = feat
        # decoder path, upsampling with corresponding indices and size
        for i in range(0, 2):
            feat_mean = self.decoders_mean[i](feat_mean, indices[1 - i], unpool_sizes[1 - i])
            feat_var = self.decoders_var[i](feat_var, indices[1 - i], unpool_sizes[1 - i])
            if i == 0:
                feat_mean = F.dropout(feat_mean, p=self.drop_rate, training=True)
                feat_var = F.dropout(feat_var, p=self.drop_rate, training=True)

        output_mean = self.classifier_mean(feat_mean)
        output_var = self.classifier_var(feat_var)

        results = {'mean': output_mean, 'var': output_var}
        return results


class _Encoder(nn.Module):
    def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
        """Encoder layer follows VGG rules + keeps pooling indices
        Args:
            n_in_feat (int): number of input features
            n_out_feat (int): number of output features
            n_blocks (int): number of conv-batch-relu blocks inside the encoder
        """
        super(_Encoder, self).__init__()

        layers = [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
                  nn.BatchNorm2d(n_out_feat),
                  nn.ReLU()]

        if n_blocks > 1:
            layers += [nn.Conv2d(n_out_feat, n_out_feat, 3, 1, 1),
                       nn.BatchNorm2d(n_out_feat),
                       nn.ReLU()]

        self.features = nn.Sequential(*layers)

    def forward(self, x):
        output = self.features(x)
        return F.max_pool2d(output, 2, 2, return_indices=True), output.size()
class _Decoder(nn.Module):
    """Decoder layer decodes the features by unpooling with respect to
    the pooling indices of the corresponding encoder part.
    Args:
        n_in_feat (int): number of input features
        n_out_feat (int): number of output features
        n_blocks (int): number of conv-batch-relu blocks inside the decoder
    """

    def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
        super(_Decoder, self).__init__()

        layers = [nn.Conv2d(n_in_feat, n_in_feat, 3, 1, 1),
                  nn.BatchNorm2d(n_in_feat),
                  nn.ReLU()]

        if n_blocks > 1:
            layers += [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
                       nn.BatchNorm2d(n_out_feat),
                       nn.ReLU()]

        self.features = nn.Sequential(*layers)

    def forward(self, x, indices, size):
        unpooled = F.max_unpool2d(x, indices, 2, 2, 0, size)
        return self.features(unpooled)
--------------------------------------------------------------------------------
/WHAT_src/model/common.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)


class BasicBlock(nn.Sequential):
    def __init__(
        self, conv, in_channels, out_channels, kernel_size,
        bias=False, bn=True, act=nn.ReLU(True)):

        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)

        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats,
                 bn=False, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:    # is scale a power of 2?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))

        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)
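common.py is imported by the model files but none of its blocks are referenced in them; for completeness, a sketch of Upsampler usage (values illustrative):

    import torch
    from model.common import Upsampler, default_conv

    up2 = Upsampler(default_conv, scale=2, n_feats=32)   # conv to 4x channels, then PixelShuffle(2)
    x = torch.randn(1, 32, 14, 14)
    print(up2(x).shape)                                  # torch.Size([1, 32, 28, 28])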
--------------------------------------------------------------------------------
/WHAT_src/model/epistemic.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

from model import common


def make_model(args):
    return EPISTEMIC(args)


class EPISTEMIC(nn.Module):
    def __init__(self, config):
        super(EPISTEMIC, self).__init__()
        self.drop_rate = config.drop_rate
        in_channels = config.in_channels
        filter_config = (64, 128)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()

        # setup number of conv-bn-relu blocks per module and number of filters
        encoder_n_layers = (2, 2, 3, 3, 3)
        encoder_filter_config = (in_channels,) + filter_config
        decoder_n_layers = (3, 3, 3, 2, 1)
        decoder_filter_config = filter_config[::-1] + (filter_config[0],)

        for i in range(0, 2):
            # encoder architecture
            self.encoders.append(_Encoder(encoder_filter_config[i],
                                          encoder_filter_config[i + 1],
                                          encoder_n_layers[i]))

            # decoder architecture
            self.decoders.append(_Decoder(decoder_filter_config[i],
                                          decoder_filter_config[i + 1],
                                          decoder_n_layers[i]))

        # final classifier (equivalent to a fully connected layer)
        self.classifier = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)

    def forward(self, x):
        indices = []
        unpool_sizes = []
        feat = x

        # encoder path, keep track of pooling indices and features size
        for i in range(0, 2):
            (feat, ind), size = self.encoders[i](feat)
            if i == 1:
                # training=True keeps dropout stochastic at test time (MC dropout)
                feat = F.dropout(feat, p=self.drop_rate, training=True)
            indices.append(ind)
            unpool_sizes.append(size)

        # decoder path, upsampling with corresponding indices and size
        for i in range(0, 2):
            feat = self.decoders[i](feat, indices[1 - i], unpool_sizes[1 - i])
            if i == 0:
                feat = F.dropout(feat, p=self.drop_rate, training=True)

        output = self.classifier(feat)
        results = {'mean': output}

        return results


class _Encoder(nn.Module):
    def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
        """Encoder layer follows VGG rules + keeps pooling indices
        Args:
            n_in_feat (int): number of input features
            n_out_feat (int): number of output features
            n_blocks (int): number of conv-batch-relu blocks inside the encoder
        """
        super(_Encoder, self).__init__()

        layers = [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
                  nn.BatchNorm2d(n_out_feat),
                  nn.ReLU()]

        if n_blocks > 1:
            layers += [nn.Conv2d(n_out_feat, n_out_feat, 3, 1, 1),
                       nn.BatchNorm2d(n_out_feat),
                       nn.ReLU()]

        self.features = nn.Sequential(*layers)

    def forward(self, x):
        output = self.features(x)
        return F.max_pool2d(output, 2, 2, return_indices=True), output.size()


class _Decoder(nn.Module):
    """Decoder layer decodes the features by unpooling with respect to
    the pooling indices of the corresponding encoder part.
    Args:
        n_in_feat (int): number of input features
        n_out_feat (int): number of output features
        n_blocks (int): number of conv-batch-relu blocks inside the decoder
    """

    def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
        super(_Decoder, self).__init__()

        layers = [nn.Conv2d(n_in_feat, n_in_feat, 3, 1, 1),
                  nn.BatchNorm2d(n_in_feat),
                  nn.ReLU()]

        if n_blocks > 1:
            layers += [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
                       nn.BatchNorm2d(n_out_feat),
                       nn.ReLU()]

        self.features = nn.Sequential(*layers)

    def forward(self, x, indices, size):
        unpooled = F.max_unpool2d(x, indices, 2, 2, 0, size)
        return self.features(unpooled)
--------------------------------------------------------------------------------
/WHAT_src/model/normal.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

from model import common


def make_model(args):
    return NORMAL(args)


class NORMAL(nn.Module):
    def __init__(self, config):
        super(NORMAL, self).__init__()
        self.drop_rate = config.drop_rate
        in_channels = config.in_channels
        filter_config = (64, 128)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()

        # setup number of conv-bn-relu blocks per module and number of filters
        encoder_n_layers = (2, 2, 2, 2, 2)
        encoder_filter_config = (in_channels,) + filter_config
        decoder_n_layers = (2, 2, 2, 2, 1)
        decoder_filter_config = filter_config[::-1] + (filter_config[0],)

        for i in range(0, 2):
            # encoder architecture
            self.encoders.append(_Encoder(encoder_filter_config[i],
                                          encoder_filter_config[i + 1],
                                          encoder_n_layers[i]))

            # decoder architecture
            self.decoders.append(_Decoder(decoder_filter_config[i],
                                          decoder_filter_config[i + 1],
                                          decoder_n_layers[i]))

        # final classifier (equivalent to a fully connected layer)
        self.classifier = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)

    def forward(self, x):
        indices = []
        unpool_sizes = []
        feat = x

        # encoder path, keep track of pooling indices and features size
        for i in range(0, 2):
            (feat, ind), size = self.encoders[i](feat)
            if i == 1:
                # note: F.dropout defaults to training=True, so this stays active in eval mode too
                feat = F.dropout(feat, p=self.drop_rate)
            indices.append(ind)
            unpool_sizes.append(size)

        # decoder path, upsampling with corresponding indices and size
        for i in range(0, 2):
            feat = self.decoders[i](feat, indices[1 - i], unpool_sizes[1 - i])
            if i == 0:
                feat = F.dropout(feat, p=self.drop_rate)

        output = self.classifier(feat)
        results = {'mean': output}

        return results


class _Encoder(nn.Module):
    def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
        """Encoder layer follows VGG rules + keeps pooling indices
        Args:
            n_in_feat (int): number of input features
            n_out_feat (int): number of output features
            n_blocks (int): number of conv-batch-relu blocks inside the encoder
        """
        super(_Encoder, self).__init__()

        layers = [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
                  nn.BatchNorm2d(n_out_feat),
                  nn.ReLU(inplace=True)]

        if n_blocks > 1:
            layers += [nn.Conv2d(n_out_feat, n_out_feat, 3, 1, 1),
                       nn.BatchNorm2d(n_out_feat),
                       nn.ReLU(inplace=True)]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        output = self.features(x)
        return F.max_pool2d(output, 2, 2, return_indices=True), output.size()


class _Decoder(nn.Module):
    """Decoder layer decodes the features by unpooling with respect to
    the pooling indices of the corresponding encoder part.
    Args:
        n_in_feat (int): number of input features
        n_out_feat (int): number of output features
        n_blocks (int): number of conv-batch-relu blocks inside the decoder
    """

    def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
        super(_Decoder, self).__init__()

        layers = [nn.Conv2d(n_in_feat, n_in_feat, 3, 1, 1),
                  nn.BatchNorm2d(n_in_feat),
                  nn.ReLU(inplace=True)]

        if n_blocks > 1:
            layers += [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
                       nn.BatchNorm2d(n_out_feat),
                       nn.ReLU(inplace=True)]
        self.features = nn.Sequential(*layers)

    def forward(self, x, indices, size):
        unpooled = F.max_unpool2d(x, indices, 2, 2, 0, size)
        return self.features(unpooled)
--------------------------------------------------------------------------------
/WHAT_src/op.py:
--------------------------------------------------------------------------------
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from model import *
from loss import Loss
from util import make_optimizer, calc_psnr, summary


class Operator:
    def __init__(self, config, check_point):
        self.config = config
        self.epochs = config.epochs
        self.uncertainty = config.uncertainty
        self.ckpt = check_point
        self.tensorboard = config.tensorboard
        if self.tensorboard:
            self.summary_writer = SummaryWriter(self.ckpt.log_dir, 300)

        # set model
        self.model = Model(config)
        summary(self.model, config_file=self.ckpt.config_file)

        # set criterion, optimizer
        self.criterion = Loss(config)
        self.optimizer = make_optimizer(config, self.model)

        # load ckpt, model, optimizer
        if self.ckpt.exp_load is not None or not config.is_train:
            print("Loading model...")
            self.load(self.ckpt)
            print(self.ckpt.last_epoch, self.ckpt.global_step)

    def train(self, data_loader):
        last_epoch = self.ckpt.last_epoch
        train_batch_num = len(data_loader['train'])

        for epoch in range(last_epoch, self.epochs):
            for batch_idx, batch_data in enumerate(data_loader['train']):
                batch_input, batch_label = batch_data
                batch_input = batch_input.to(self.config.device)
                batch_label = batch_label.to(self.config.device)

                # forward (autoencoder: the input is its own reconstruction target)
                batch_results = self.model(batch_input)
                loss = self.criterion(batch_results, batch_input)

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                print('Epoch: {:03d}/{:03d}, Iter: {:03d}/{:03d}, Loss: {:5f}'
                      .format(epoch, self.config.epochs,
                              batch_idx, train_batch_num,
                              loss.item()))

                # use tensorboard
                if self.tensorboard:
                    current_global_step = self.ckpt.step()
                    self.summary_writer.add_scalar('train/loss',
                                                   loss, current_global_step)
                    self.summary_writer.add_images("train/input_img",
                                                   batch_input,
                                                   current_global_step)
                    self.summary_writer.add_images("train/mean_img",
                                                   torch.clamp(batch_results['mean'], 0., 1.),
                                                   current_global_step)

            # use tensorboard
            if self.tensorboard:
                print(self.optimizer.get_lr(), epoch)
                self.summary_writer.add_scalar('epoch_lr',
                                               self.optimizer.get_lr(), epoch)

            # test model & save model
            self.optimizer.schedule()
            self.save(self.ckpt, epoch)
            self.test(data_loader)
            self.model.train()

        self.summary_writer.close()

    def test(self, data_loader):
        with torch.no_grad():
            self.model.eval()

            total_psnr = 0.
            psnrs = []
            test_batch_num = len(data_loader['test'])
            for batch_idx, batch_data in enumerate(data_loader['test']):
                batch_input, batch_label = batch_data
                batch_input = batch_input.to(self.config.device)
                batch_label = batch_label.to(self.config.device)

                # forward (PSNR is measured against the input, see above)
                batch_results = self.model(batch_input)
                current_psnr = calc_psnr(batch_results['mean'], batch_input)
                psnrs.append(current_psnr)
                total_psnr = sum(psnrs) / len(psnrs)
                print("Test iter: {:03d}/{:03d}, Total: {:5f}, Current: {:05f}".format(
                    batch_idx, test_batch_num,
                    total_psnr, psnrs[batch_idx]))

            # use tensorboard
            if self.tensorboard:
                self.summary_writer.add_scalar('test/psnr',
                                               total_psnr, self.ckpt.last_epoch)
                self.summary_writer.add_images("test/input_img",
                                               batch_input, self.ckpt.last_epoch)
                self.summary_writer.add_images("test/mean_img",
                                               torch.clamp(batch_results['mean'], 0., 1.),
                                               self.ckpt.last_epoch)
                if not self.uncertainty == 'normal':
                    self.summary_writer.add_images("test/var_img",
                                                   batch_results['var'],
                                                   self.ckpt.last_epoch)

    def load(self, ckpt):
        ckpt.load()                   # load ckpt
        self.model.load(ckpt)         # load model
        self.optimizer.load(ckpt)     # load optimizer

    def save(self, ckpt, epoch):
        ckpt.save(epoch)              # save ckpt: global_step, last_epoch
        self.model.save(ckpt, epoch)  # save model: weight
        self.optimizer.save(ckpt)     # save optimizer
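calc_psnr (defined in util.py below) scores the reconstruction with the standard definition for signals in [0, r], where r = rgb_range (1 by default):

    \mathrm{PSNR} = -10 \log_{10} \frac{1}{N} \sum_i \left( \frac{\hat{y}_i - y_i}{r} \right)^{2}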
--------------------------------------------------------------------------------
/WHAT_src/util.py:
--------------------------------------------------------------------------------
import os
import math
import sys
from datetime import datetime
from functools import reduce

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
from torch.nn.modules.module import _addindent


class Checkpoint:
    def __init__(self, config):
        self.global_step = 0
        self.last_epoch = 0
        self.config = config
        self.exp_dir = config.exp_dir
        self.exp_load = config.exp_load
        exp_type = config.uncertainty
        now = datetime.now().strftime('%m%d_%H%M')

        if config.exp_load is None:
            dir_fmt = '{}/{}_{}'.format(config.data_name, exp_type, now)
        else:
            dir_fmt = '{}/{}_{}'.format(config.data_name, exp_type, self.exp_load)

        self.model_dir = os.path.join(self.exp_dir, dir_fmt, 'model')
        self.log_dir = os.path.join(self.exp_dir, dir_fmt, 'log')
        self.save_dir = os.path.join(self.exp_dir, dir_fmt, 'save')
        self.ckpt_dir = os.path.join(self.log_dir, 'ckpt.pt')

        os.makedirs(self.model_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.save_dir, exist_ok=True)

        # save config
        self.config_file = os.path.join(self.log_dir, 'config.txt')
        with open(self.config_file, 'w') as f:
            for k, v in vars(config).items():
                f.writelines('{}: {} \n'.format(k, v))

    def step(self):
        self.global_step += 1
        return self.global_step

    def save(self, epoch):
        self.last_epoch = epoch
        save_ckpt = {'global_step': self.global_step,
                     'last_epoch': self.last_epoch}
        torch.save(save_ckpt, self.ckpt_dir)

    def load(self):
        load_ckpt = torch.load(self.ckpt_dir)
        self.global_step = load_ckpt['global_step']
        self.last_epoch = load_ckpt['last_epoch']


def calc_psnr(output, label, rgb_range=1.):
    if label.nelement() == 1: return 0

    diff = (output - label) / rgb_range
    mse = diff.pow(2).mean()

    return -10 * math.log10(mse)


def make_optimizer(config, model):
    trainable = filter(lambda x: x.requires_grad, model.parameters())
    kwargs_optimizer = {'lr': config.lr, 'weight_decay': config.weight_decay}

    if config.optimizer == 'sgd':
        optimizer_class = optim.SGD
        kwargs_optimizer['momentum'] = config.momentum
    elif config.optimizer == 'adam':
        optimizer_class = optim.Adam
        kwargs_optimizer['betas'] = config.betas
        kwargs_optimizer['eps'] = config.epsilon
    elif config.optimizer == 'rmsprop':
        optimizer_class = optim.RMSprop
        kwargs_optimizer['eps'] = config.epsilon

    # scheduler
    milestones = list(map(lambda x: int(x), config.decay.split('-')))
    kwargs_scheduler = {'milestones': milestones, 'gamma': config.gamma}
    scheduler_class = lrs.MultiStepLR

    class CustomOptimizer(optimizer_class):
        def __init__(self, *args, **kwargs):
            super(CustomOptimizer, self).__init__(*args, **kwargs)

        def _register_scheduler(self, scheduler_class, **kwargs):
            self.scheduler = scheduler_class(self, **kwargs)

        def save(self, ckpt):
            save_dir = os.path.join(ckpt.model_dir, 'optimizer.pt')
            torch.save(self.state_dict(), save_dir)

        def load(self, ckpt):
            load_dir = os.path.join(ckpt.model_dir, 'optimizer.pt')
            epoch = ckpt.last_epoch
            self.load_state_dict(torch.load(load_dir))
            if epoch > 1:
                # fast-forward the schedule to the resumed epoch
                for _ in range(epoch): self.scheduler.step()

        def schedule(self):
            self.scheduler.step()

        def get_lr(self):
            return self.scheduler.get_lr()[0]

        def get_last_epoch(self):
            return self.scheduler.last_epoch

    optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
    optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
    return optimizer


def summary(model, config_file, file=sys.stdout):
    def repr(model):
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = model.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        total_params = 0
        for key, module in model._modules.items():
            mod_str, num_params = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
            total_params += num_params
        lines = extra_lines + child_lines

        for name, p in model._parameters.items():
            if hasattr(p, 'shape'):
                total_params += reduce(lambda x, y: x * y, p.shape)

        main_str = model._get_name() + '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n ' + '\n '.join(lines) + '\n'

        main_str += ')'
        if file is sys.stdout:
            main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
        else:
            main_str += ', {:,} params'.format(total_params)
        return main_str, total_params

    string, count = repr(model)
    with open(config_file, 'a') as f:  # close the handle instead of leaking it
        print(string, file=f)

    if file is not None:
        print(string, file=sys.stdout)
        file.flush()

    return count
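With the config.py defaults (lr=1e-3, decay='50-100-150-200', gamma=0.5), the MultiStepLR registered above halves the learning rate at each milestone; a sketch of the resulting schedule:

    # epoch:  0-49    50-99   100-149   150-199
    # lr:     1e-3    5e-4    2.5e-4    1.25e-4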
") 30 | self.load(self.ckpt) 31 | print(self.ckpt.last_epoch, self.ckpt.global_step) 32 | 33 | def train(self, data_loader): 34 | last_epoch = self.ckpt.last_epoch 35 | train_batch_num = len(data_loader['train']) 36 | 37 | for epoch in range(last_epoch, self.epochs): 38 | for batch_idx, batch_data in enumerate(data_loader['train']): 39 | batch_input, batch_label = batch_data 40 | batch_input = batch_input.to(self.config.device) 41 | batch_label = batch_label.to(self.config.device) 42 | 43 | # forward 44 | batch_results = self.model(batch_input) 45 | loss = self.criterion(batch_results, batch_input) 46 | 47 | # backward 48 | self.optimizer.zero_grad() 49 | loss.backward() 50 | self.optimizer.step() 51 | print('Epoch: {:03d}/{:03d}, Iter: {:03d}/{:03d}, Loss: {:5f}' 52 | .format(epoch, self.config.epochs, 53 | batch_idx, train_batch_num, 54 | loss.item())) 55 | 56 | # use tensorboard 57 | if self.tensorboard: 58 | current_global_step = self.ckpt.step() 59 | self.summary_writer.add_scalar('train/loss', 60 | loss, current_global_step) 61 | self.summary_writer.add_images("train/input_img", 62 | batch_input, 63 | current_global_step) 64 | self.summary_writer.add_images("train/mean_img", 65 | torch.clamp(batch_results['mean'], 0., 1.), 66 | current_global_step) 67 | 68 | # use tensorboard 69 | if self.tensorboard: 70 | print(self.optimizer.get_lr(), epoch) 71 | self.summary_writer.add_scalar('epoch_lr', 72 | self.optimizer.get_lr(), epoch) 73 | 74 | # test model & save model 75 | self.optimizer.schedule() 76 | self.save(self.ckpt, epoch) 77 | self.test(data_loader) 78 | self.model.train() 79 | 80 | self.summary_writer.close() 81 | 82 | def test(self, data_loader): 83 | with torch.no_grad(): 84 | self.model.eval() 85 | 86 | total_psnr = 0. 87 | psnrs = [] 88 | test_batch_num = len(data_loader['test']) 89 | for batch_idx, batch_data in enumerate(data_loader['test']): 90 | batch_input, batch_label = batch_data 91 | batch_input = batch_input.to(self.config.device) 92 | batch_label = batch_label.to(self.config.device) 93 | 94 | # forward 95 | batch_results = self.model(batch_input) 96 | current_psnr = calc_psnr(batch_results['mean'], batch_input) 97 | psnrs.append(current_psnr) 98 | total_psnr = sum(psnrs) / len(psnrs) 99 | print("Test iter: {:03d}/{:03d}, Total: {:5f}, Current: {:05f}".format( 100 | batch_idx, test_batch_num, 101 | total_psnr, psnrs[batch_idx])) 102 | 103 | # use tensorboard 104 | if self.tensorboard: 105 | self.summary_writer.add_scalar('test/psnr', 106 | total_psnr, self.ckpt.last_epoch) 107 | self.summary_writer.add_images("test/input_img", 108 | batch_input, self.ckpt.last_epoch) 109 | self.summary_writer.add_images("test/mean_img", 110 | torch.clamp(batch_results['mean'], 0., 1.), 111 | self.ckpt.last_epoch) 112 | if not self.uncertainty == 'normal': 113 | self.summary_writer.add_images("test/var_img", 114 | batch_results['var'], 115 | self.ckpt.last_epoch) 116 | 117 | def load(self, ckpt): 118 | ckpt.load() # load ckpt 119 | self.model.load(ckpt) # load model 120 | self.optimizer.load(ckpt) # load optimizer 121 | 122 | def save(self, ckpt, epoch): 123 | ckpt.save(epoch) # save ckpt: global_step, last_epoch 124 | self.model.save(ckpt, epoch) # save model: weight 125 | self.optimizer.save(ckpt) # save optimizer: 126 | 127 | 128 | -------------------------------------------------------------------------------- /WHAT_src/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import sys 4 | from datetime import 
### 1.1 Train

```
# L2 loss only
python main.py --uncertainty "normal" --drop_rate 0.

# Epistemic / Aleatoric
python main.py --uncertainty ["epistemic", "aleatoric"]

# Epistemic + Aleatoric
python main.py --uncertainty "combined"
```


### 1.2 Test

```
# L2 loss only
python main.py --is_train false --uncertainty "normal"

# Epistemic
python main.py --is_train false --uncertainty "epistemic" --n_samples 25 [or 5, 50]

# Aleatoric
python main.py --is_train false --uncertainty "aleatoric"

# Epistemic + Aleatoric
python main.py --is_train false --uncertainty "combined" --n_samples 25 [or 5, 50]
```


### 1.3 Requirements

- Python 3.7
- PyTorch >= 1.0
- Torchvision
- distutils
## 2. Experiment

This is not an official implementation.
### 2.1 Network & Dataset

- Autoencoder based on [Bayesian SegNet](https://arxiv.org/abs/1511.02680)
  - Network depth 2 (paper: 5)
  - Drop rate 0.2 (paper: 0.5)
- Fashion MNIST / MNIST
- Input = Label (for the autoencoder)


### 2.2 Results

#### 2.2.1 PSNR

Combined > Aleatoric > Normal (w/o dropout) > Epistemic > Normal (w/ dropout)

<img src="./WHAT_asset/psnr.png" alt="drawing"/>


#### 2.2.2 Images

<img src="./WHAT_asset/label.png" alt="drawing"/> Input / Label

<img src="./WHAT_asset/com.png" alt="drawing"/> Combined

<img src="./WHAT_asset/al.png" alt="drawing"/> Aleatoric

<img src="./WHAT_asset/ep.png" alt="drawing"/> Epistemic
--------------------------------------------------------------------------------