├── .gitignore
├── WHAT_asset
│   ├── al.png
│   ├── com.png
│   ├── ep.png
│   ├── label.png
│   └── psnr.png
├── WHAT_src
│   ├── config.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_camvid.py
│   │   └── data_nyu.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── mse.py
│   │   └── mse_var.py
│   ├── main.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── aleatoric.py
│   │   ├── combined.py
│   │   ├── common.py
│   │   ├── epistemic.py
│   │   └── normal.py
│   ├── op.py
│   └── util.py
└── readme.md
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 |
4 | WHAT_exp
5 | WHAT_papr
6 | WHAT_ref
7 | WHAT_plan.md
8 |
9 | *.pyc
10 | *.xml
11 |
--------------------------------------------------------------------------------
/WHAT_asset/al.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_asset/al.png
--------------------------------------------------------------------------------
/WHAT_asset/com.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_asset/com.png
--------------------------------------------------------------------------------
/WHAT_asset/ep.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_asset/ep.png
--------------------------------------------------------------------------------
/WHAT_asset/label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_asset/label.png
--------------------------------------------------------------------------------
/WHAT_asset/psnr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_asset/psnr.png
--------------------------------------------------------------------------------
/WHAT_src/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from distutils.util import strtobool
4 |
5 | parser = argparse.ArgumentParser()
6 |
7 | # Environment
8 | parser.add_argument("--is_train", type=strtobool, default='true')
9 | parser.add_argument("--tensorboard", type=strtobool, default='true')
10 | parser.add_argument('--cpu', action='store_true', help='use cpu only')
11 | parser.add_argument('--gpu', type=int, default=0)
12 | parser.add_argument("--num_gpu", type=int, default=1)
13 | parser.add_argument("--num_work", type=int, default=8)
14 | parser.add_argument("--exp_dir", type=str, default="../WHAT_exp")
15 | parser.add_argument("--exp_load", type=str, default=None)
16 |
17 | # Data
18 | parser.add_argument("--data_dir", type=str, default="/mnt/sda")
19 | parser.add_argument("--data_name", type=str, default="fashion_mnist")
20 | parser.add_argument('--batch_size', type=int, default=32)
21 | parser.add_argument('--rgb_range', type=int, default=1)
22 |
23 | # Model
24 | parser.add_argument('--uncertainty', default='normal',
25 | choices=('normal', 'epistemic', 'aleatoric', 'combined'))
26 | parser.add_argument('--in_channels', type=int, default=1)
27 | parser.add_argument('--n_feats', type=int, default=32)
28 | parser.add_argument('--var_weight', type=float, default=1.)
29 | parser.add_argument('--drop_rate', type=float, default=0.2)
30 |
31 | # Train
32 | parser.add_argument("--epochs", type=int, default=200)
33 | parser.add_argument("--lr", type=float, default=1e-3)
34 | parser.add_argument("--decay", type=str, default='50-100-150-200')
35 | parser.add_argument("--gamma", type=float, default=0.5)
36 | parser.add_argument("--optimizer", type=str, default='rmsprop',
37 | choices=('sgd', 'adam', 'rmsprop'))
38 | parser.add_argument("--weight_decay", type=float, default=1e-4)
39 | parser.add_argument("--momentum", type=float, default=0.9)
40 | parser.add_argument("--betas", type=tuple, default=(0.9, 0.999))
41 | parser.add_argument("--epsilon", type=float, default=1e-8)
42 |
43 | # Test
44 | parser.add_argument('--n_samples', type=int, default=25)
45 |
46 |
47 | def save_args(obj, defaults, kwargs):
48 | for k, v in defaults.items():  # dict.iteritems() was Python 2 only
49 | if k in kwargs: v = kwargs[k]
50 | setattr(obj, k, v)
51 |
52 |
53 | def get_config():
54 | config = parser.parse_args()
55 | config.data_dir = os.path.expanduser(config.data_dir)
56 | return config
57 |
--------------------------------------------------------------------------------
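The parser above can also be driven programmatically, which is handy for quick experiments. A minimal sketch, assuming only the option names defined in config.py; the values are illustrative:

```python
# Build a config without touching the command line; values are examples only.
from config import parser

config = parser.parse_args(['--uncertainty', 'combined', '--batch_size', '64'])
print(config.uncertainty, config.batch_size, config.lr)  # combined 64 0.001
```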
/WHAT_src/data/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import torchvision.datasets as dset
4 | import torchvision.transforms as transforms
5 |
6 |
7 | def get_dataloader(config):
8 | data_dir = config.data_dir
9 | batch_size = config.batch_size
10 |
11 | trans = transforms.Compose([transforms.ToTensor(),
12 | transforms.Normalize((0.0,), (1.0,))])
13 |
14 | if config.data_name == 'mnist':
15 | train_dataset = dset.MNIST(root=data_dir, train=True, transform=trans, download=True)
16 | test_dataset = dset.MNIST(root=data_dir, train=False, transform=trans, download=True)
17 | elif config.data_name == 'fashion_mnist':
18 | train_dataset = dset.FashionMNIST(root=data_dir, train=True, transform=trans, download=True)
19 | test_dataset = dset.FashionMNIST(root=data_dir, train=False, transform=trans, download=True)
20 |     else:
21 |         raise ValueError('unknown data_name: {}'.format(config.data_name))
22 | 
23 |     train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
24 |                               num_workers=config.num_work, shuffle=True)
25 |     test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
26 |                              num_workers=config.num_work, shuffle=False)
27 | 
28 |     print('==>>> total training batch number: {}'.format(len(train_loader)))
29 |     print('==>>> total testing batch number: {}'.format(len(test_loader)))
30 | 
31 |     data_loader = {'train': train_loader, 'test': test_loader}
32 | 
33 |     return data_loader
34 | 
--------------------------------------------------------------------------------
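A hypothetical smoke test for `get_dataloader`: any object exposing the four fields read above works in place of the argparse config.

```python
# Quick check that the loaders come up; SimpleNamespace stands in for argparse.
from types import SimpleNamespace
from data import get_dataloader

cfg = SimpleNamespace(data_dir='./data', data_name='fashion_mnist',
                      batch_size=32, num_work=2)
loaders = get_dataloader(cfg)
images, labels = next(iter(loaders['train']))
print(images.shape)  # torch.Size([32, 1, 28, 28])
```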
/WHAT_src/data/data_camvid.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/what/e3fe42ac8568bdaf28cf7fde112a8f95368097b9/WHAT_src/data/data_camvid.py
--------------------------------------------------------------------------------
/WHAT_src/data/data_nyu.py:
--------------------------------------------------------------------------------
1 | '''
2 | https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
3 | '''
4 |
5 | import h5py
6 | import torch
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 | FILE_PATH = '/Users/chokiheum/Data/nyu_v2/nyu_depth_v2_labeled.mat'
10 |
11 |
12 | class NYU_v2(Dataset):
13 | def __init__(self, file_path, transform=None):
14 | self.transform = transform
15 | self.file_path = file_path
16 |
17 | def __getitem__(self, idx):
18 | with h5py.File(self.file_path, 'r') as f:
19 | depth = torch.from_numpy(f['depths'][idx].astype('float32'))
20 | image = torch.from_numpy(f['images'][idx].astype('float32'))
21 | label = torch.from_numpy(f['labels'][idx].astype('float32'))
22 |
23 | sample = {'depth': depth, 'image': image, 'label': label}
24 |
25 | if self.transform:
26 | sample = self.transform(sample)
27 |
28 | return sample
29 |
30 | def __len__(self):
31 | with h5py.File(self.file_path, 'r') as f:
32 | length = len(f['images'])
33 | return length
34 |
35 |
36 | if __name__ == '__main__':
37 | nyu_dataset = NYU_v2(file_path=FILE_PATH)
38 | dataloader = DataLoader(nyu_dataset,
39 | batch_size=4, shuffle=True, num_workers=4)
40 |
41 | for i_batch, sample_batched in enumerate(dataloader):
42 | print(sample_batched['image'])
43 |
--------------------------------------------------------------------------------
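`NYU_v2` reopens the HDF5 file on every `__getitem__`, which is fork-safe with `num_workers > 0` but repeats the open cost. A sketch of an alternative (not part of the original file, assuming the imports and class above) that lazily caches one handle per worker:

```python
# Hypothetical variant: open the file once per DataLoader worker, after fork.
class NYU_v2_cached(NYU_v2):
    def __getitem__(self, idx):
        if not hasattr(self, '_f'):
            self._f = h5py.File(self.file_path, 'r')  # one handle per worker
        depth = torch.from_numpy(self._f['depths'][idx].astype('float32'))
        image = torch.from_numpy(self._f['images'][idx].astype('float32'))
        label = torch.from_numpy(self._f['labels'][idx].astype('float32'))
        sample = {'depth': depth, 'image': image, 'label': label}
        return self.transform(sample) if self.transform else sample
```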
/WHAT_src/loss/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from importlib import import_module
4 |
5 |
6 | class Loss(nn.Module):
7 | def __init__(self, config):
8 | super(Loss, self).__init__()
9 | print('Preparing loss function...')
10 |
11 | self.num_gpu = config.num_gpu
12 | self.losses = []
13 | self.loss_module = nn.ModuleList()
14 |
15 | if config.uncertainty == 'epistemic' or config.uncertainty == 'normal':
16 | module = import_module('loss.mse')
17 | loss_function = getattr(module, 'MSE')()
18 | else:
19 | module = import_module('loss.mse_var')
20 | loss_function = getattr(module, 'MSE_VAR')(
21 | var_weight=config.var_weight)
22 |
23 | self.losses.append({'function': loss_function})
24 |         self.loss_module.append(loss_function)
25 | self.loss_module.to(config.device)
26 | if not config.cpu and config.num_gpu > 1:
27 | self.loss_module = nn.DataParallel(
28 | self.loss_module, range(self.num_gpu))
29 |
30 | def forward(self, results, label):
31 | losses = []
32 | for i, l in enumerate(self.losses):
33 | if l['function'] is not None:
34 | loss = l['function'](results, label)
35 | effective_loss = loss
36 | losses.append(effective_loss)
37 |
38 | loss_sum = sum(losses)
39 | 
40 |         return loss_sum
41 | 
--------------------------------------------------------------------------------
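A hedged sanity check of the `Loss` wrapper; the `SimpleNamespace` carries only the config fields `Loss.__init__` actually reads, and the all-zero tensors make the expected combined loss exactly zero:

```python
import torch
from types import SimpleNamespace
from loss import Loss

cfg = SimpleNamespace(num_gpu=1, cpu=True, device='cpu',
                      uncertainty='combined', var_weight=1.0)
criterion = Loss(cfg)
zeros = torch.zeros(2, 1, 4, 4)
# MSE_VAR on zeros: 0.5 * (exp(0) * 0 + 0) = 0
print(criterion({'mean': zeros, 'var': zeros}, zeros).item())  # 0.0
```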
/WHAT_src/loss/mse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MSE(nn.Module):
7 | def __init__(self):
8 | super(MSE, self).__init__()
9 |
10 | def forward(self, results, label):
11 | mean = results['mean']
12 | loss = F.mse_loss(mean, label)
13 | return loss
14 |
--------------------------------------------------------------------------------
/WHAT_src/loss/mse_var.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MSE_VAR(nn.Module):
7 | def __init__(self, var_weight):
8 | super(MSE_VAR, self).__init__()
9 | self.var_weight = var_weight
10 |
11 | def forward(self, results, label):
12 | mean, var = results['mean'], results['var']
13 | var = self.var_weight * var
14 |
15 | loss1 = torch.mul(torch.exp(-var), (mean - label) ** 2)
16 | loss2 = var
17 | loss = .5 * (loss1 + loss2)
18 | return loss.mean()
19 |
20 |
--------------------------------------------------------------------------------
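`MSE_VAR` is the heteroscedastic aleatoric loss of Kendall & Gal, with the network predicting the log-variance (scaled by `var_weight`) for numerical stability:

```latex
\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N}
    \Big( \tfrac{1}{2}\exp(-s_i)\,\lVert y_i - \hat{y}_i \rVert^2
        + \tfrac{1}{2}\,s_i \Big),
\qquad s_i := \log \hat{\sigma}_i^2 .
```

Predicting the log-variance keeps the `exp(-var)` weighting positive and avoids dividing by a variance that could hit zero.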
/WHAT_src/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from config import get_config
4 | from data import get_dataloader
5 | from op import Operator
6 | from util import Checkpoint
7 |
8 |
9 | def main(config):
10 |     config.device = torch.device('cuda:{}'.format(config.gpu)
11 |                                  if torch.cuda.is_available() and not config.cpu else 'cpu')
12 |
13 | # load data_loader
14 | data_loader = get_dataloader(config)
15 | check_point = Checkpoint(config)
16 | operator = Operator(config, check_point)
17 |
18 | if config.is_train:
19 | operator.train(data_loader)
20 | else:
21 | operator.test(data_loader)
22 |
23 |
24 | if __name__ == "__main__":
25 | config = get_config()
26 | main(config)
27 |
--------------------------------------------------------------------------------
/WHAT_src/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel as P
7 |
8 |
9 | class Model(nn.Module):
10 | def __init__(self, config):
11 | super(Model, self).__init__()
12 | print('Making model...')
13 |
14 | self.is_train = config.is_train
15 | self.num_gpu = config.num_gpu
16 | self.uncertainty = config.uncertainty
17 | self.n_samples = config.n_samples
18 | module = import_module('model.' + config.uncertainty)
19 | self.model = module.make_model(config).to(config.device)
20 |
21 | def forward(self, input):
22 | if self.model.training:
23 | if self.num_gpu > 1:
24 | return P.data_parallel(self.model, input,
25 | list(range(self.num_gpu)))
26 | else:
27 | return self.model.forward(input)
28 | else:
29 | forward_func = self.model.forward
30 | if self.uncertainty == 'normal':
31 | return forward_func(input)
32 |             elif self.uncertainty == 'aleatoric':
33 | return self.test_aleatoric(input, forward_func)
34 | elif self.uncertainty == 'epistemic':
35 | return self.test_epistemic(input, forward_func)
36 | elif self.uncertainty == 'combined':
37 | return self.test_combined(input, forward_func)
38 |
39 | def test_aleatoric(self, input, forward_func):
40 | results = forward_func(input)
41 | mean1 = results['mean']
42 | var1 = torch.exp(results['var'])
43 |         var1_norm = var1 / var1.max()  # normalize to [0, 1] for visualization
44 | results = {'mean': mean1, 'var': var1_norm}
45 | return results
46 |
47 | def test_epistemic(self, input, forward_func):
48 | mean1s = []
49 | mean2s = []
50 |
51 | for i_sample in range(self.n_samples):
52 | results = forward_func(input)
53 | mean1 = results['mean']
54 |             mean1s.append(mean1 ** 2)  # collect y^2 samples
55 |             mean2s.append(mean1)       # collect y samples
56 |
57 |         mean1s_ = torch.stack(mean1s, dim=0).mean(dim=0)  # E[y^2]
58 |         mean2s_ = torch.stack(mean2s, dim=0).mean(dim=0)  # E[y]
59 | 
60 |         var1 = mean1s_ - mean2s_ ** 2  # Var[y] = E[y^2] - (E[y])^2
61 | var1_norm = var1 / var1.max()
62 | results = {'mean': mean2s_, 'var': var1_norm}
63 | return results
64 |
65 | def test_combined(self, input, forward_func):
66 | mean1s = []
67 | mean2s = []
68 | var1s = []
69 |
70 | for i_sample in range(self.n_samples):
71 | results = forward_func(input)
72 | mean1 = results['mean']
73 | mean1s.append(mean1 ** 2)
74 | mean2s.append(mean1)
75 | var1 = results['var']
76 | var1s.append(torch.exp(var1))
77 |
78 | mean1s_ = torch.stack(mean1s, dim=0).mean(dim=0)
79 | mean2s_ = torch.stack(mean2s, dim=0).mean(dim=0)
80 |
81 |         var1s_ = torch.stack(var1s, dim=0).mean(dim=0)  # mean aleatoric variance
82 |         var2 = mean1s_ - mean2s_ ** 2                   # epistemic variance
83 |         var_ = var1s_ + var2                            # total predictive variance
84 | var_norm = var_ / var_.max()
85 | results = {'mean': mean2s_, 'var': var_norm}
86 | return results
87 |
88 | def save(self, ckpt, epoch):
89 | save_dirs = [os.path.join(ckpt.model_dir, 'model_latest.pt')]
90 | save_dirs.append(
91 | os.path.join(ckpt.model_dir, 'model_{}.pt'.format(epoch)))
92 | for s in save_dirs:
93 | torch.save(self.model.state_dict(), s)
94 |
95 | def load(self, ckpt, cpu=False):
96 | epoch = ckpt.last_epoch
97 | kwargs = {}
98 | if cpu:
99 | kwargs = {'map_location': lambda storage, loc: storage}
100 | if epoch == -1:
101 | load_from = torch.load(
102 | os.path.join(ckpt.model_dir, 'model_latest.pt'), **kwargs)
103 | else:
104 | load_from = torch.load(
105 | os.path.join(ckpt.model_dir, 'model_{}.pt'.format(epoch)), **kwargs)
106 | if load_from:
107 | self.model.load_state_dict(load_from, strict=False)
108 |
--------------------------------------------------------------------------------
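The `test_epistemic` bookkeeping above is the usual MC-dropout moment estimate, Var[y] ≈ E[y²] − (E[y])² over `n_samples` stochastic passes. A small numerical check, with random tensors standing in for the forward passes:

```python
import torch

T = 25
samples = torch.randn(T, 4, 1, 28, 28)        # stand-in for T forward passes
mean = samples.mean(dim=0)                    # plays the role of mean2s_
var = (samples ** 2).mean(dim=0) - mean ** 2  # mean1s_ - mean2s_ ** 2
assert torch.allclose(var, samples.var(dim=0, unbiased=False), atol=1e-5)
```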
/WHAT_src/model/aleatoric.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from model import common
5 |
6 |
7 | def make_model(args):
8 | return ALEATORIC(args)
9 |
10 |
11 | class ALEATORIC(nn.Module):
12 | def __init__(self, config):
13 | super(ALEATORIC, self).__init__()
14 | self.drop_rate = config.drop_rate
15 | in_channels = config.in_channels
16 | filter_config = (64, 128)
17 |
18 | self.encoders = nn.ModuleList()
19 | self.decoders_mean = nn.ModuleList()
20 | self.decoders_var = nn.ModuleList()
21 |
22 | # setup number of conv-bn-relu blocks per module and number of filters
23 | encoder_n_layers = (2, 2, 3, 3, 3)
24 | encoder_filter_config = (in_channels,) + filter_config
25 | decoder_n_layers = (3, 3, 3, 2, 1)
26 | decoder_filter_config = filter_config[::-1] + (filter_config[0],)
27 |
28 | for i in range(0, 2):
29 | # encoder architecture
30 | self.encoders.append(_Encoder(encoder_filter_config[i],
31 | encoder_filter_config[i + 1],
32 | encoder_n_layers[i]))
33 |
34 | # decoder architecture
35 | self.decoders_mean.append(_Decoder(decoder_filter_config[i],
36 | decoder_filter_config[i + 1],
37 | decoder_n_layers[i]))
38 |
39 | # decoder architecture
40 | self.decoders_var.append(_Decoder(decoder_filter_config[i],
41 | decoder_filter_config[i + 1],
42 | decoder_n_layers[i]))
43 |
44 | # final classifier (equivalent to a fully connected layer)
45 | self.classifier_mean = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)
46 | self.classifier_var = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)
47 |
48 | def forward(self, x):
49 | indices = []
50 | unpool_sizes = []
51 | feat = x
52 |
53 | # encoder path, keep track of pooling indices and features size
54 | for i in range(0, 2):
55 | (feat, ind), size = self.encoders[i](feat)
56 | if i == 1:
57 |                 feat = F.dropout(feat, p=self.drop_rate, training=self.training)  # F.dropout defaults to training=True; tie it to the module mode
58 | indices.append(ind)
59 | unpool_sizes.append(size)
60 |
61 | feat_mean = feat
62 | feat_var = feat
63 | # decoder path, upsampling with corresponding indices and size
64 | for i in range(0, 2):
65 | feat_mean = self.decoders_mean[i](feat_mean, indices[1 - i], unpool_sizes[1 - i])
66 | feat_var = self.decoders_var[i](feat_var, indices[1 - i], unpool_sizes[1 - i])
67 | if i == 0:
68 |                 feat_mean = F.dropout(feat_mean, p=self.drop_rate, training=self.training)
69 |                 feat_var = F.dropout(feat_var, p=self.drop_rate, training=self.training)
70 |
71 | output_mean = self.classifier_mean(feat_mean)
72 | output_var = self.classifier_var(feat_var)
73 |
74 | results = {'mean': output_mean, 'var': output_var}
75 | return results
76 |
77 |
78 | class _Encoder(nn.Module):
79 | def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
80 | """Encoder layer follows VGG rules + keeps pooling indices
81 | Args:
82 | n_in_feat (int): number of input features
83 | n_out_feat (int): number of output features
84 |             n_blocks (int): number of conv-batchnorm-relu blocks inside the encoder
85 |         """
86 | 
87 | super(_Encoder, self).__init__()
88 |
89 | layers = [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
90 | nn.BatchNorm2d(n_out_feat),
91 | nn.ReLU()]
92 |
93 | if n_blocks > 1:
94 | layers += [nn.Conv2d(n_out_feat, n_out_feat, 3, 1, 1),
95 | nn.BatchNorm2d(n_out_feat),
96 | nn.ReLU()]
97 |
98 | self.features = nn.Sequential(*layers)
99 |
100 | def forward(self, x):
101 | output = self.features(x)
102 | return F.max_pool2d(output, 2, 2, return_indices=True), output.size()
103 |
104 |
105 | class _Decoder(nn.Module):
106 | """Decoder layer decodes the features by unpooling with respect to
107 |     the pooling indices of the corresponding encoder part.
108 | Args:
109 | n_in_feat (int): number of input features
110 | n_out_feat (int): number of output features
111 |         n_blocks (int): number of conv-batchnorm-relu blocks inside the decoder
112 |     """
113 | 
114 |
115 | def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
116 | super(_Decoder, self).__init__()
117 |
118 | layers = [nn.Conv2d(n_in_feat, n_in_feat, 3, 1, 1),
119 | nn.BatchNorm2d(n_in_feat),
120 | nn.ReLU()]
121 |
122 | if n_blocks > 1:
123 | layers += [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
124 | nn.BatchNorm2d(n_out_feat),
125 | nn.ReLU()]
126 |
127 | self.features = nn.Sequential(*layers)
128 |
129 | def forward(self, x, indices, size):
130 | unpooled = F.max_unpool2d(x, indices, 2, 2, 0, size)
131 | return self.features(unpooled)
132 |
--------------------------------------------------------------------------------
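A hypothetical smoke test for the aleatoric model (only the two config fields read in `__init__` are supplied): the two decoder heads return a mean and a log-variance map with the input's shape.

```python
import torch
from types import SimpleNamespace
from model.aleatoric import make_model

net = make_model(SimpleNamespace(drop_rate=0.2, in_channels=1))
out = net(torch.randn(2, 1, 28, 28))
print(out['mean'].shape, out['var'].shape)  # torch.Size([2, 1, 28, 28]) twice
```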
/WHAT_src/model/combined.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from model import common
5 |
6 |
7 | def make_model(args):
8 | return COMBINED(args)
9 |
10 |
11 | class COMBINED(nn.Module):
12 | def __init__(self, config):
13 | super(COMBINED, self).__init__()
14 | self.drop_rate = config.drop_rate
15 | in_channels = config.in_channels
16 | filter_config = (64, 128)
17 |
18 | self.encoders = nn.ModuleList()
19 | self.decoders_mean = nn.ModuleList()
20 | self.decoders_var = nn.ModuleList()
21 |
22 | # setup number of conv-bn-relu blocks per module and number of filters
23 | encoder_n_layers = (2, 2, 3, 3, 3)
24 | encoder_filter_config = (in_channels,) + filter_config
25 | decoder_n_layers = (3, 3, 3, 2, 1)
26 | decoder_filter_config = filter_config[::-1] + (filter_config[0],)
27 |
28 | for i in range(0, 2):
29 | # encoder architecture
30 | self.encoders.append(_Encoder(encoder_filter_config[i],
31 | encoder_filter_config[i + 1],
32 | encoder_n_layers[i]))
33 |
34 | # decoder architecture
35 | self.decoders_mean.append(_Decoder(decoder_filter_config[i],
36 | decoder_filter_config[i + 1],
37 | decoder_n_layers[i]))
38 |
39 | # decoder architecture
40 | self.decoders_var.append(_Decoder(decoder_filter_config[i],
41 | decoder_filter_config[i + 1],
42 | decoder_n_layers[i]))
43 |
44 | # final classifier (equivalent to a fully connected layer)
45 | self.classifier_mean = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)
46 | self.classifier_var = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)
47 |
48 | def forward(self, x):
49 | indices = []
50 | unpool_sizes = []
51 | feat = x
52 |
53 | # encoder path, keep track of pooling indices and features size
54 | for i in range(0, 2):
55 | (feat, ind), size = self.encoders[i](feat)
56 | if i == 1:
57 |                 feat = F.dropout(feat, p=self.drop_rate, training=True)  # MC dropout: stays active at test time
58 | indices.append(ind)
59 | unpool_sizes.append(size)
60 |
61 | feat_mean = feat
62 | feat_var = feat
63 | # decoder path, upsampling with corresponding indices and size
64 | for i in range(0, 2):
65 | feat_mean = self.decoders_mean[i](feat_mean, indices[1 - i], unpool_sizes[1 - i])
66 | feat_var = self.decoders_var[i](feat_var, indices[1 - i], unpool_sizes[1 - i])
67 | if i == 0:
68 | feat_mean = F.dropout(feat_mean, p=self.drop_rate, training=True)
69 | feat_var = F.dropout(feat_var, p=self.drop_rate, training=True)
70 |
71 | output_mean = self.classifier_mean(feat_mean)
72 | output_var = self.classifier_var(feat_var)
73 |
74 | results = {'mean': output_mean, 'var': output_var}
75 | return results
76 |
77 |
78 | class _Encoder(nn.Module):
79 | def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
80 | """Encoder layer follows VGG rules + keeps pooling indices
81 | Args:
82 | n_in_feat (int): number of input features
83 | n_out_feat (int): number of output features
84 |             n_blocks (int): number of conv-batchnorm-relu blocks inside the encoder
85 |         """
86 | 
87 | super(_Encoder, self).__init__()
88 |
89 | layers = [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
90 | nn.BatchNorm2d(n_out_feat),
91 | nn.ReLU()]
92 |
93 | if n_blocks > 1:
94 | layers += [nn.Conv2d(n_out_feat, n_out_feat, 3, 1, 1),
95 | nn.BatchNorm2d(n_out_feat),
96 | nn.ReLU()]
97 |
98 | self.features = nn.Sequential(*layers)
99 |
100 | def forward(self, x):
101 | output = self.features(x)
102 | return F.max_pool2d(output, 2, 2, return_indices=True), output.size()
103 |
104 |
105 | class _Decoder(nn.Module):
106 | """Decoder layer decodes the features by unpooling with respect to
107 |     the pooling indices of the corresponding encoder part.
108 | Args:
109 | n_in_feat (int): number of input features
110 | n_out_feat (int): number of output features
111 |         n_blocks (int): number of conv-batchnorm-relu blocks inside the decoder
112 |     """
113 | 
114 |
115 | def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
116 | super(_Decoder, self).__init__()
117 |
118 | layers = [nn.Conv2d(n_in_feat, n_in_feat, 3, 1, 1),
119 | nn.BatchNorm2d(n_in_feat),
120 | nn.ReLU()]
121 |
122 | if n_blocks > 1:
123 | layers += [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
124 | nn.BatchNorm2d(n_out_feat),
125 | nn.ReLU()]
126 |
127 | self.features = nn.Sequential(*layers)
128 |
129 | def forward(self, x, indices, size):
130 | unpooled = F.max_unpool2d(x, indices, 2, 2, 0, size)
131 | return self.features(unpooled)
132 |
--------------------------------------------------------------------------------
/WHAT_src/model/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
9 | return nn.Conv2d(
10 | in_channels, out_channels, kernel_size,
11 | padding=(kernel_size//2), bias=bias)
12 |
13 |
14 | class BasicBlock(nn.Sequential):
15 | def __init__(
16 | self, conv, in_channels, out_channels, kernel_size,
17 | bias=False, bn=True, act=nn.ReLU(True)):
18 |
19 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
20 | if bn:
21 | m.append(nn.BatchNorm2d(out_channels))
22 | if act is not None:
23 | m.append(act)
24 |
25 | super(BasicBlock, self).__init__(*m)
26 |
27 |
28 | class ResBlock(nn.Module):
29 | def __init__(
30 | self, conv, n_feats, kernel_size,
31 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
32 |
33 | super(ResBlock, self).__init__()
34 | m = []
35 | for i in range(2):
36 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
37 | if bn:
38 | m.append(nn.BatchNorm2d(n_feats))
39 | if i == 0:
40 | m.append(act)
41 |
42 | self.body = nn.Sequential(*m)
43 | self.res_scale = res_scale
44 |
45 | def forward(self, x):
46 | res = self.body(x).mul(self.res_scale)
47 | res += x
48 |
49 | return res
50 |
51 |
52 | class Upsampler(nn.Sequential):
53 | def __init__(self, conv, scale, n_feats,
54 | bn=False, act=False, bias=True):
55 | m = []
56 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
57 | for _ in range(int(math.log(scale, 2))):
58 | m.append(conv(n_feats, 4 * n_feats, 3, bias))
59 | m.append(nn.PixelShuffle(2))
60 | if bn:
61 | m.append(nn.BatchNorm2d(n_feats))
62 | if act == 'relu':
63 | m.append(nn.ReLU(True))
64 | elif act == 'prelu':
65 | m.append(nn.PReLU(n_feats))
66 |
67 | elif scale == 3:
68 | m.append(conv(n_feats, 9 * n_feats, 3, bias))
69 | m.append(nn.PixelShuffle(3))
70 | if bn:
71 | m.append(nn.BatchNorm2d(n_feats))
72 | if act == 'relu':
73 | m.append(nn.ReLU(True))
74 | elif act == 'prelu':
75 | m.append(nn.PReLU(n_feats))
76 | else:
77 | raise NotImplementedError
78 |
79 | super(Upsampler, self).__init__(*m)
80 |
--------------------------------------------------------------------------------
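The helpers in common.py are imported by the model files but not used by them; a brief sketch of how they compose (shapes are illustrative):

```python
import torch
from model import common

block = common.ResBlock(common.default_conv, n_feats=32, kernel_size=3)
up = common.Upsampler(common.default_conv, scale=2, n_feats=32)
y = up(block(torch.randn(1, 32, 16, 16)))
print(y.shape)  # torch.Size([1, 32, 32, 32])
```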
/WHAT_src/model/epistemic.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from model import common
5 |
6 |
7 | def make_model(args):
8 | return EPISTEMIC(args)
9 |
10 |
11 | class EPISTEMIC(nn.Module):
12 | def __init__(self, config):
13 | super(EPISTEMIC, self).__init__()
14 | self.drop_rate = config.drop_rate
15 | in_channels = config.in_channels
16 | filter_config = (64, 128)
17 |
18 | self.encoders = nn.ModuleList()
19 | self.decoders = nn.ModuleList()
20 |
21 | # setup number of conv-bn-relu blocks per module and number of filters
22 | encoder_n_layers = (2, 2, 3, 3, 3)
23 | encoder_filter_config = (in_channels,) + filter_config
24 | decoder_n_layers = (3, 3, 3, 2, 1)
25 | decoder_filter_config = filter_config[::-1] + (filter_config[0],)
26 |
27 | for i in range(0, 2):
28 | # encoder architecture
29 | self.encoders.append(_Encoder(encoder_filter_config[i],
30 | encoder_filter_config[i + 1],
31 | encoder_n_layers[i]))
32 |
33 | # decoder architecture
34 | self.decoders.append(_Decoder(decoder_filter_config[i],
35 | decoder_filter_config[i + 1],
36 | decoder_n_layers[i]))
37 |
38 | # final classifier (equivalent to a fully connected layer)
39 | self.classifier = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)
40 |
41 | def forward(self, x):
42 | indices = []
43 | unpool_sizes = []
44 | feat = x
45 |
46 | # encoder path, keep track of pooling indices and features size
47 | for i in range(0, 2):
48 | (feat, ind), size = self.encoders[i](feat)
49 | if i == 1:
50 |                 feat = F.dropout(feat, p=self.drop_rate, training=True)  # MC dropout: stays active at test time
51 | indices.append(ind)
52 | unpool_sizes.append(size)
53 |
54 | # decoder path, upsampling with corresponding indices and size
55 | for i in range(0, 2):
56 | feat = self.decoders[i](feat, indices[1 - i], unpool_sizes[1 - i])
57 | if i == 0:
58 | feat = F.dropout(feat, p=self.drop_rate, training=True)
59 |
60 | output = self.classifier(feat)
61 | results = {'mean': output}
62 |
63 | return results
64 |
65 |
66 | class _Encoder(nn.Module):
67 | def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
68 | """Encoder layer follows VGG rules + keeps pooling indices
69 | Args:
70 | n_in_feat (int): number of input features
71 | n_out_feat (int): number of output features
72 |             n_blocks (int): number of conv-batchnorm-relu blocks inside the encoder
73 |         """
74 | 
75 | super(_Encoder, self).__init__()
76 |
77 | layers = [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
78 | nn.BatchNorm2d(n_out_feat),
79 | nn.ReLU()]
80 |
81 | if n_blocks > 1:
82 | layers += [nn.Conv2d(n_out_feat, n_out_feat, 3, 1, 1),
83 | nn.BatchNorm2d(n_out_feat),
84 | nn.ReLU()]
85 |
86 | self.features = nn.Sequential(*layers)
87 |
88 | def forward(self, x):
89 | output = self.features(x)
90 | return F.max_pool2d(output, 2, 2, return_indices=True), output.size()
91 |
92 |
93 | class _Decoder(nn.Module):
94 | """Decoder layer decodes the features by unpooling with respect to
95 |     the pooling indices of the corresponding encoder part.
96 | Args:
97 | n_in_feat (int): number of input features
98 | n_out_feat (int): number of output features
99 |         n_blocks (int): number of conv-batchnorm-relu blocks inside the decoder
100 |     """
101 | 
102 |
103 | def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
104 | super(_Decoder, self).__init__()
105 |
106 | layers = [nn.Conv2d(n_in_feat, n_in_feat, 3, 1, 1),
107 | nn.BatchNorm2d(n_in_feat),
108 | nn.ReLU()]
109 |
110 | if n_blocks > 1:
111 | layers += [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
112 | nn.BatchNorm2d(n_out_feat),
113 | nn.ReLU()]
114 |
115 | self.features = nn.Sequential(*layers)
116 |
117 | def forward(self, x, indices, size):
118 | unpooled = F.max_unpool2d(x, indices, 2, 2, 0, size)
119 | return self.features(unpooled)
120 |
--------------------------------------------------------------------------------
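The `training=True` passed to `F.dropout` in epistemic.py and combined.py is what makes the MC sampling in `Model.test_epistemic` work: the functional flag, not the module's train/eval mode, decides whether a fresh mask is drawn, so dropout stays stochastic under `model.eval()`. A tiny demonstration:

```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 8)
with torch.no_grad():
    a = F.dropout(x, p=0.5, training=True)
    b = F.dropout(x, p=0.5, training=True)
print(torch.equal(a, b))  # almost always False: two independent masks
```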
/WHAT_src/model/normal.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from model import common
5 |
6 |
7 | def make_model(args):
8 | return NORMAL(args)
9 |
10 |
11 | class NORMAL(nn.Module):
12 | def __init__(self, config):
13 | super(NORMAL, self).__init__()
14 | self.drop_rate = config.drop_rate
15 | in_channels = config.in_channels
16 | filter_config = (64, 128)
17 |
18 | self.encoders = nn.ModuleList()
19 | self.decoders = nn.ModuleList()
20 |
21 | # setup number of conv-bn-relu blocks per module and number of filters
22 | encoder_n_layers = (2, 2, 2, 2, 2)
23 | encoder_filter_config = (in_channels,) + filter_config
24 | decoder_n_layers = (2, 2, 2, 2, 1)
25 | decoder_filter_config = filter_config[::-1] + (filter_config[0],)
26 |
27 | for i in range(0, 2):
28 | # encoder architecture
29 | self.encoders.append(_Encoder(encoder_filter_config[i],
30 | encoder_filter_config[i + 1],
31 | encoder_n_layers[i]))
32 |
33 | # decoder architecture
34 | self.decoders.append(_Decoder(decoder_filter_config[i],
35 | decoder_filter_config[i + 1],
36 | decoder_n_layers[i]))
37 |
38 | # final classifier (equivalent to a fully connected layer)
39 | self.classifier = nn.Conv2d(filter_config[0], in_channels, 3, 1, 1)
40 |
41 | def forward(self, x):
42 | indices = []
43 | unpool_sizes = []
44 | feat = x
45 |
46 | # encoder path, keep track of pooling indices and features size
47 | for i in range(0, 2):
48 | (feat, ind), size = self.encoders[i](feat)
49 | if i == 1:
50 |                 feat = F.dropout(feat, p=self.drop_rate, training=self.training)  # F.dropout defaults to training=True; tie it to the module mode
51 | indices.append(ind)
52 | unpool_sizes.append(size)
53 |
54 | # decoder path, upsampling with corresponding indices and size
55 | for i in range(0, 2):
56 | feat = self.decoders[i](feat, indices[1 - i], unpool_sizes[1 - i])
57 | if i == 0:
58 |                 feat = F.dropout(feat, p=self.drop_rate, training=self.training)
59 |
60 | output = self.classifier(feat)
61 | results = {'mean': output}
62 |
63 | return results
64 |
65 |
66 | class _Encoder(nn.Module):
67 | def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
68 | """Encoder layer follows VGG rules + keeps pooling indices
69 | Args:
70 | n_in_feat (int): number of input features
71 | n_out_feat (int): number of output features
72 |             n_blocks (int): number of conv-batchnorm-relu blocks inside the encoder
73 |         """
74 | 
75 | super(_Encoder, self).__init__()
76 |
77 | layers = [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
78 | nn.BatchNorm2d(n_out_feat),
79 | nn.ReLU(inplace=True)]
80 |
81 | if n_blocks > 1:
82 | layers += [nn.Conv2d(n_out_feat, n_out_feat, 3, 1, 1),
83 | nn.BatchNorm2d(n_out_feat),
84 | nn.ReLU(inplace=True)]
85 | self.features = nn.Sequential(*layers)
86 |
87 | def forward(self, x):
88 | output = self.features(x)
89 | return F.max_pool2d(output, 2, 2, return_indices=True), output.size()
90 |
91 |
92 | class _Decoder(nn.Module):
93 | """Decoder layer decodes the features by unpooling with respect to
94 |     the pooling indices of the corresponding encoder part.
95 | Args:
96 | n_in_feat (int): number of input features
97 | n_out_feat (int): number of output features
98 |         n_blocks (int): number of conv-batchnorm-relu blocks inside the decoder
99 |     """
100 | 
101 |
102 | def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
103 | super(_Decoder, self).__init__()
104 |
105 | layers = [nn.Conv2d(n_in_feat, n_in_feat, 3, 1, 1),
106 | nn.BatchNorm2d(n_in_feat),
107 | nn.ReLU(inplace=True)]
108 |
109 | if n_blocks > 1:
110 | layers += [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
111 | nn.BatchNorm2d(n_out_feat),
112 | nn.ReLU(inplace=True)]
113 | self.features = nn.Sequential(*layers)
114 |
115 | def forward(self, x, indices, size):
116 | unpooled = F.max_unpool2d(x, indices, 2, 2, 0, size)
117 | return self.features(unpooled)
118 |
--------------------------------------------------------------------------------
/WHAT_src/op.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.tensorboard import SummaryWriter
3 | 
4 | from model import Model
5 | from loss import Loss
6 | from util import make_optimizer, calc_psnr, summary
7 |
8 |
9 | class Operator:
10 |     def __init__(self, config, check_point):
11 | self.config = config
12 | self.epochs = config.epochs
13 | self.uncertainty = config.uncertainty
14 |         self.ckpt = check_point
15 | self.tensorboard = config.tensorboard
16 | if self.tensorboard:
17 |             self.summary_writer = SummaryWriter(self.ckpt.log_dir, flush_secs=300)
18 |
19 | # set model, criterion, optimizer
20 | self.model = Model(config)
21 | summary(self.model, config_file=self.ckpt.config_file)
22 |
23 | # set criterion, optimizer
24 | self.criterion = Loss(config)
25 | self.optimizer = make_optimizer(config, self.model)
26 |
27 | # load ckpt, model, optimizer
28 | if self.ckpt.exp_load is not None or not config.is_train:
29 | print("Loading model... ")
30 | self.load(self.ckpt)
31 |             print('Resumed at epoch {}, step {}'.format(self.ckpt.last_epoch, self.ckpt.global_step))
32 |
33 | def train(self, data_loader):
34 | last_epoch = self.ckpt.last_epoch
35 | train_batch_num = len(data_loader['train'])
36 |
37 | for epoch in range(last_epoch, self.epochs):
38 | for batch_idx, batch_data in enumerate(data_loader['train']):
39 | batch_input, batch_label = batch_data
40 | batch_input = batch_input.to(self.config.device)
41 | batch_label = batch_label.to(self.config.device)
42 |
43 | # forward
44 | batch_results = self.model(batch_input)
45 |                 loss = self.criterion(batch_results, batch_input)  # autoencoder: the input is its own target
46 |
47 | # backward
48 | self.optimizer.zero_grad()
49 | loss.backward()
50 | self.optimizer.step()
51 | print('Epoch: {:03d}/{:03d}, Iter: {:03d}/{:03d}, Loss: {:5f}'
52 | .format(epoch, self.config.epochs,
53 | batch_idx, train_batch_num,
54 | loss.item()))
55 |
56 | # use tensorboard
57 | if self.tensorboard:
58 | current_global_step = self.ckpt.step()
59 | self.summary_writer.add_scalar('train/loss',
60 | loss, current_global_step)
61 | self.summary_writer.add_images("train/input_img",
62 | batch_input,
63 | current_global_step)
64 | self.summary_writer.add_images("train/mean_img",
65 | torch.clamp(batch_results['mean'], 0., 1.),
66 | current_global_step)
67 |
68 | # use tensorboard
69 | if self.tensorboard:
70 | print(self.optimizer.get_lr(), epoch)
71 | self.summary_writer.add_scalar('epoch_lr',
72 | self.optimizer.get_lr(), epoch)
73 |
74 | # test model & save model
75 | self.optimizer.schedule()
76 | self.save(self.ckpt, epoch)
77 | self.test(data_loader)
78 | self.model.train()
79 |
80 |         if self.tensorboard: self.summary_writer.close()
81 | 
82 | def test(self, data_loader):
83 | with torch.no_grad():
84 | self.model.eval()
85 |
86 | total_psnr = 0.
87 | psnrs = []
88 | test_batch_num = len(data_loader['test'])
89 | for batch_idx, batch_data in enumerate(data_loader['test']):
90 | batch_input, batch_label = batch_data
91 | batch_input = batch_input.to(self.config.device)
92 | batch_label = batch_label.to(self.config.device)
93 |
94 | # forward
95 | batch_results = self.model(batch_input)
96 | current_psnr = calc_psnr(batch_results['mean'], batch_input)
97 | psnrs.append(current_psnr)
98 | total_psnr = sum(psnrs) / len(psnrs)
99 | print("Test iter: {:03d}/{:03d}, Total: {:5f}, Current: {:05f}".format(
100 | batch_idx, test_batch_num,
101 | total_psnr, psnrs[batch_idx]))
102 |
103 | # use tensorboard
104 | if self.tensorboard:
105 | self.summary_writer.add_scalar('test/psnr',
106 | total_psnr, self.ckpt.last_epoch)
107 | self.summary_writer.add_images("test/input_img",
108 | batch_input, self.ckpt.last_epoch)
109 | self.summary_writer.add_images("test/mean_img",
110 | torch.clamp(batch_results['mean'], 0., 1.),
111 | self.ckpt.last_epoch)
112 | if not self.uncertainty == 'normal':
113 | self.summary_writer.add_images("test/var_img",
114 | batch_results['var'],
115 | self.ckpt.last_epoch)
116 |
117 | def load(self, ckpt):
118 | ckpt.load() # load ckpt
119 | self.model.load(ckpt) # load model
120 | self.optimizer.load(ckpt) # load optimizer
121 |
122 | def save(self, ckpt, epoch):
123 | ckpt.save(epoch) # save ckpt: global_step, last_epoch
124 | self.model.save(ckpt, epoch) # save model: weight
125 | self.optimizer.save(ckpt) # save optimizer:
126 |
127 |
128 |
--------------------------------------------------------------------------------
/WHAT_src/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import sys
4 | from datetime import datetime
5 | from functools import reduce
6 |
7 | import torch
8 | import torch.optim as optim
9 | import torch.optim.lr_scheduler as lrs
10 | from torch.nn.modules.module import _addindent
11 |
12 |
13 | class Checkpoint:
14 | def __init__(self, config):
15 | self.global_step = 0
16 | self.last_epoch = 0
17 | self.config = config
18 | self.exp_dir = config.exp_dir
19 | self.exp_load = config.exp_load
20 | exp_type = config.uncertainty
21 | now = datetime.now().strftime('%m%d_%H%M')
22 |
23 | if config.exp_load is None:
24 | dir_fmt = '{}/{}_{}'.format(config.data_name, exp_type, now)
25 | else:
26 | dir_fmt = '{}/{}_{}'.format(config.data_name, exp_type, self.exp_load)
27 |
28 | self.model_dir = os.path.join(self.exp_dir, dir_fmt, 'model')
29 | self.log_dir = os.path.join(self.exp_dir, dir_fmt, 'log')
30 | self.save_dir = os.path.join(self.exp_dir, dir_fmt, 'save')
31 | self.ckpt_dir = os.path.join(self.log_dir, 'ckpt.pt')
32 |
33 | os.makedirs(self.model_dir, exist_ok=True)
34 | os.makedirs(self.log_dir, exist_ok=True)
35 | os.makedirs(self.save_dir, exist_ok=True)
36 |
37 | # save config
38 | self.config_file = os.path.join(self.log_dir, 'config.txt')
39 | with open(self.config_file, 'w') as f:
40 | for k, v in vars(config).items():
41 | f.writelines('{}: {} \n'.format(k, v))
42 |
43 | def step(self):
44 | self.global_step += 1
45 | return self.global_step
46 |
47 | def save(self, epoch):
48 | self.last_epoch = epoch
49 | save_ckpt = {'global_step': self.global_step,
50 | 'last_epoch': self.last_epoch}
51 | torch.save(save_ckpt, self.ckpt_dir)
52 |
53 | def load(self):
54 | load_ckpt = torch.load(self.ckpt_dir)
55 | self.global_step = load_ckpt['global_step']
56 | self.last_epoch = load_ckpt['last_epoch']
57 |
58 |
59 | def calc_psnr(output, label, rgb_range=1.):
60 | if label.nelement() == 1: return 0
61 |
62 | diff = (output - label) / rgb_range
63 | mse = diff.pow(2).mean()
64 |
65 | return -10 * math.log10(mse)
66 |
67 |
68 | def make_optimizer(config, model):
69 | trainable = filter(lambda x: x.requires_grad, model.parameters())
70 | kwargs_optimizer = {'lr': config.lr, 'weight_decay': config.weight_decay}
71 |
72 | if config.optimizer == 'sgd':
73 | optimizer_class = optim.SGD
74 | kwargs_optimizer['momentum'] = config.momentum
75 | elif config.optimizer == 'adam':
76 | optimizer_class = optim.Adam
77 | kwargs_optimizer['betas'] = config.betas
78 | kwargs_optimizer['eps'] = config.epsilon
79 | elif config.optimizer == 'rmsprop':
80 | optimizer_class = optim.RMSprop
81 | kwargs_optimizer['eps'] = config.epsilon
82 |
83 | # scheduler
84 | milestones = list(map(lambda x: int(x), config.decay.split('-')))
85 | kwargs_scheduler = {'milestones': milestones, 'gamma': config.gamma}
86 | scheduler_class = lrs.MultiStepLR
87 |
88 | class CustomOptimizer(optimizer_class):
89 | def __init__(self, *args, **kwargs):
90 | super(CustomOptimizer, self).__init__(*args, **kwargs)
91 |
92 | def _register_scheduler(self, scheduler_class, **kwargs):
93 | self.scheduler = scheduler_class(self, **kwargs)
94 |
95 | def save(self, ckpt):
96 | save_dir = os.path.join(ckpt.model_dir, 'optimizer.pt')
97 | torch.save(self.state_dict(), save_dir)
98 |
99 | def load(self, ckpt):
100 | load_dir = os.path.join(ckpt.model_dir, 'optimizer.pt')
101 | epoch = ckpt.last_epoch
102 | self.load_state_dict(torch.load(load_dir))
103 | if epoch > 1:
104 | for _ in range(epoch): self.scheduler.step()
105 |
106 | def schedule(self):
107 | self.scheduler.step()
108 |
109 | def get_lr(self):
110 | return self.scheduler.get_lr()[0]
111 |
112 | def get_last_epoch(self):
113 | return self.scheduler.last_epoch
114 |
115 | optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
116 | optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
117 | return optimizer
118 |
119 |
120 | def summary(model, config_file, file=sys.stdout):
121 | def repr(model):
122 | # We treat the extra repr like the sub-module, one item per line
123 | extra_lines = []
124 | extra_repr = model.extra_repr()
125 | # empty string will be split into list ['']
126 | if extra_repr:
127 | extra_lines = extra_repr.split('\n')
128 | child_lines = []
129 | total_params = 0
130 | for key, module in model._modules.items():
131 | mod_str, num_params = repr(module)
132 | mod_str = _addindent(mod_str, 2)
133 | child_lines.append('(' + key + '): ' + mod_str)
134 | total_params += num_params
135 | lines = extra_lines + child_lines
136 |
137 | for name, p in model._parameters.items():
138 | if hasattr(p, 'shape'):
139 | total_params += reduce(lambda x, y: x * y, p.shape)
140 |
141 | main_str = model._get_name() + '('
142 | if lines:
143 | # simple one-liner info, which most builtin Modules will use
144 | if len(extra_lines) == 1 and not child_lines:
145 | main_str += extra_lines[0]
146 | else:
147 | main_str += '\n ' + '\n '.join(lines) + '\n'
148 |
149 | main_str += ')'
150 | if file is sys.stdout:
151 | main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
152 | else:
153 | main_str += ', {:,} params'.format(total_params)
154 | return main_str, total_params
155 |
156 |     string, count = repr(model)
157 |     with open(config_file, 'a') as f:
158 |         print(string, file=f)
159 | if file is not None:
160 | print(string, file=sys.stdout)
161 | file.flush()
162 |
163 | return count
164 |
165 |
--------------------------------------------------------------------------------
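For reference, `calc_psnr` above relies on the signal being normalized to `rgb_range`, so the usual definition collapses to a pure function of the MSE:

```latex
\mathrm{PSNR}
  = 10 \log_{10} \frac{\mathrm{MAX}^2}{\mathrm{MSE}}
  = -10 \log_{10} \mathrm{MSE}
\qquad (\mathrm{MAX} = \texttt{rgb\_range} = 1).
```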
/readme.md:
--------------------------------------------------------------------------------
1 | # What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?
2 |
3 | PyTorch implementation of ["What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?", NIPS 2017](https://arxiv.org/abs/1703.04977)
4 |
5 |
6 |
7 | ## 1. Usage
8 |
9 | ```
10 | # Data Tree
11 | config.data_dir/
12 | └── config.data_name/
13 |
14 | # Project Tree
15 | WHAT
16 | ├── WHAT_src/
17 | │ ├── data/ *.py
18 | │ ├── loss/ *.py
19 | │ ├── model/ *.py
20 | │ └── *.py
21 | └── WHAT_exp/
22 | ├── log/
23 | ├── model/
24 | └── save/
25 | ```
26 |
27 |
28 |
29 | ### 1.1 Train
30 |
31 | ```
32 | # L2 loss only
33 | python main.py --uncertainty "normal" --drop_rate 0.
34 | 
35 | # Epistemic / Aleatoric
36 | python main.py --uncertainty "epistemic"  # or "aleatoric"
37 | 
38 | # Epistemic + Aleatoric
39 | python main.py --uncertainty "combined"
40 | ```
41 |
42 |
43 |
44 | ### 1.2 Test
45 |
46 | ```
47 | # L2 loss only
48 | python main.py --is_train false --uncertainty "normal"
49 | 
50 | # Epistemic
51 | python main.py --is_train false --uncertainty "epistemic" --n_samples 25  # or 5, 50
52 | 
53 | # Aleatoric
54 | python main.py --is_train false --uncertainty "aleatoric"
55 | 
56 | # Epistemic + Aleatoric
57 | python main.py --is_train false --uncertainty "combined" --n_samples 25  # or 5, 50
58 | ```
59 |
60 |
61 |
62 | ### 1.3 Requirements
63 |
64 | - Python 3.7
65 | - PyTorch >= 1.0
66 | - torchvision
67 | - h5py (only for the NYU_v2 loader in data_nyu.py)
68 | - distutils (standard library; used for `strtobool`)
69 |
70 |
71 |
72 | ## 2. Experiment
73 |
74 | This is not an official implementation.
75 |
76 |
77 |
78 | ### 2.1 Network & Dataset
79 |
80 | - Autoencoder based on [Bayesian Segnet](https://arxiv.org/abs/1511.02680)
81 |
82 |   - Network depth 2 (5 in the paper)
83 |   - Drop rate 0.2 (0.5 in the paper)
84 | 
85 | - Fashion MNIST / MNIST
86 |
87 | - Input = Label (for autoencoder)
88 |
89 |
90 |
91 | ### 2.2 Results
92 |
93 | #### 2.2.1 PSNR
94 |
95 | Combined > Aleatoric > Normal (w/o dropout) > Epistemic > Normal (w/ dropout)
96 |
97 | ![psnr](WHAT_asset/psnr.png)
98 |
99 |
100 |
101 | #### 2.2.2 Images
102 | 
103 | Input / Label
104 | ![label](WHAT_asset/label.png)
105 | 
106 | Combined
107 | ![com](WHAT_asset/com.png)
108 | 
109 | Aleatoric
110 | ![al](WHAT_asset/al.png)
111 | 
112 | Epistemic
113 | ![ep](WHAT_asset/ep.png)
--------------------------------------------------------------------------------