├── FCN-8 ├── __pycache__ │ ├── model.cpython-36.pyc │ └── voc.cpython-36.pyc ├── data_loader.py ├── main.py ├── model.py ├── models │ └── vgg.py ├── solver.py ├── test.py ├── test2.py ├── utils.py └── voc.py ├── LRNN ├── __pycache__ │ ├── data_loader.cpython-36.pyc │ ├── model.cpython-36.pyc │ ├── solver.cpython-36.pyc │ └── utils.cpython-36.pyc ├── data_loader.py ├── main.py ├── model.py ├── solver.py ├── test.py ├── utils.py └── voc.py ├── README.md ├── data_loader.py ├── fig1.PNG ├── fig2.PNG ├── main.py ├── model.py ├── solver.py ├── utils.py └── voc_loader.py /FCN-8/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieltan07/spatialaffinitynetwork/b1b2e6fac23eec7bfe910768e4979abb9e46bebc/FCN-8/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /FCN-8/__pycache__/voc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieltan07/spatialaffinitynetwork/b1b2e6fac23eec7bfe910768e4979abb9e46bebc/FCN-8/__pycache__/voc.cpython-36.pyc -------------------------------------------------------------------------------- /FCN-8/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import random 4 | from torch.utils.data import Dataset 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | from torchvision.datasets import ImageFolder 8 | from PIL import Image 9 | import h5py 10 | import numpy as np 11 | 12 | class FlowersDataset(Dataset): 13 | def __init__(self, image_path, transform, mode): 14 | self.transform = transform 15 | self.mode = mode 16 | self.data = h5py.File(image_path, 'r') 17 | self.num_data = self.data["train_images"].shape[0] 18 | self.attr2idx = {} 19 | self.idx2attr = {} 20 | 21 | print ('Start preprocessing dataset..!') 22 | self.preprocess() 23 | print ('Finished preprocessing dataset..!') 24 | 25 | if self.mode == 'train': 26 | self.num_data = self.data["train_images"].shape[0] 27 | elif self.mode == 'test': 28 | self.num_data = self.data["test_images"].shape[0] 29 | 30 | def preprocess(self): 31 | main_colors = ["blue","orange","pink","purple","red","white","yellow"] 32 | for i, attr in enumerate(main_colors): 33 | self.attr2idx[attr] = i 34 | self.idx2attr[i] = attr 35 | 36 | 37 | def __getitem__(self, index): 38 | image = Image.fromarray(np.uint8(self.data[self.mode+"_images"][index])) 39 | feature = np.float32(self.data[self.mode+"_feature"][index]) 40 | identity = int(self.data[self.mode+"_class"][index]) 41 | 42 | 43 | return self.transform(image), torch.FloatTensor(feature), identity 44 | 45 | def __len__(self): 46 | return self.num_data 47 | 48 | def get_loader(image_path, crop_size, image_size, batch_size, dataset='Flowers', mode='train'): 49 | """Build and return data loader.""" 50 | 51 | if mode == 'train': 52 | transform = transforms.Compose([ 53 | transforms.RandomCrop(crop_size), 54 | transforms.Scale(image_size), 55 | transforms.RandomHorizontalFlip(), 56 | transforms.ToTensor(), 57 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 58 | else: 59 | transform = transforms.Compose([ 60 | transforms.CenterCrop(crop_size), 61 | transforms.Scale(image_size), 62 | transforms.ToTensor(), 63 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 64 | 65 | if dataset == 'Flowers': 66 | dataset = 
FlowersDataset(image_path, transform, mode) 67 | 68 | shuffle = False 69 | if mode == 'train': 70 | shuffle = True 71 | 72 | data_loader = DataLoader(dataset=dataset, 73 | batch_size=batch_size, 74 | shuffle=shuffle) 75 | return data_loader 76 | -------------------------------------------------------------------------------- /FCN-8/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from solver import Solver 4 | from data_loader import get_loader 5 | from torch.backends import cudnn 6 | from utils import * 7 | 8 | 9 | def str2bool(v): 10 | return v.lower() in ('true') 11 | 12 | def main(config): 13 | # For fast training 14 | cudnn.benchmark = True 15 | 16 | # Create directories if not exist 17 | mkdir(config.log_path) 18 | mkdir(config.model_save_path) 19 | mkdir(config.sample_path) 20 | mkdir(config.result_path) 21 | data_loader = {} 22 | 23 | if config.dataset == 'Flowers': 24 | data_loader['Flowers'] = get_loader(config.flowers_image_path, config.flowers_crop_size, 25 | config.image_size, config.batch_size, 'Flowers', config.mode) 26 | 27 | # Solver 28 | solver = Solver(data_loader, vars(config)) 29 | 30 | if config.mode == 'train': 31 | if config.dataset in ['CelebA', 'Flowers', 'RaFD']: 32 | solver.train() 33 | elif config.mode == 'test': 34 | if config.dataset in ['CelebA', 'Flowers', 'RaFD']: 35 | solver.test() 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | 40 | # Model hyper-parameters 41 | parser.add_argument('--c_dim', type=int, default=5) 42 | parser.add_argument('--i_dim', type=int, default=103) 43 | parser.add_argument('--c2_dim', type=int, default=8) 44 | parser.add_argument('--celebA_crop_size', type=int, default=178) 45 | parser.add_argument('--rafd_crop_size', type=int, default=256) 46 | parser.add_argument('--flowers_crop_size', type=int, default=100) 47 | parser.add_argument('--image_size', type=int, default=128) 48 | parser.add_argument('--g_conv_dim', type=int, default=64)#16) 49 | parser.add_argument('--d_conv_dim', type=int, default=64) 50 | parser.add_argument('--g_repeat_num', type=int, default=6) 51 | parser.add_argument('--d_repeat_num', type=int, default=6) 52 | parser.add_argument('--g_lr', type=float, default=0.0001) 53 | parser.add_argument('--d_lr', type=float, default=0.0001) 54 | parser.add_argument('--lambda_cls', type=float, default=1) 55 | parser.add_argument('--lambda_rec', type=float, default=10) 56 | parser.add_argument('--lambda_gp', type=float, default=10) 57 | parser.add_argument('--d_train_repeat', type=int, default=5) 58 | 59 | # Training settings 60 | parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA','Flowers', 'RaFD', 'Both']) 61 | parser.add_argument('--num_epochs', type=int, default=20) 62 | parser.add_argument('--num_epochs_decay', type=int, default=10) 63 | parser.add_argument('--num_iters', type=int, default=200000) 64 | parser.add_argument('--num_iters_decay', type=int, default=100000) 65 | parser.add_argument('--batch_size', type=int, default=16) 66 | parser.add_argument('--num_workers', type=int, default=1) 67 | parser.add_argument('--beta1', type=float, default=0.5) 68 | parser.add_argument('--beta2', type=float, default=0.999) 69 | parser.add_argument('--pretrained_model', type=str, default=None) 70 | 71 | # Test settings 72 | parser.add_argument('--test_model', type=str, default='20_1000') 73 | 74 | # Misc 75 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) 
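# A minimal usage sketch (hedged): one way these flags combine on the command
# line. The HDF5 path is just the author's default and is assumed to exist
# locally; every value shown is a flag defined in this parser.
#
#   python main.py --dataset Flowers --mode train --batch_size 16 \
#       --image_size 128 --flowers_crop_size 100 \
#       --flowers_image_path ../StarGAN/7386_flowers.hdf5
#
# Because main() passes vars(config) to Solver, each argument defined here
# ends up as an attribute on the Solver instance (see Solver.__init__ in
# solver.py).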
76 | parser.add_argument('--use_tensorboard', type=str2bool, default=False) 77 | 78 | # Path 79 | parser.add_argument('--celebA_image_path', type=str, default='./data/CelebA_nocrop/images') 80 | parser.add_argument('--rafd_image_path', type=str, default='./data/RaFD/train') 81 | parser.add_argument('--flowers_image_path', type=str, default='../StarGAN/7386_flowers.hdf5') 82 | parser.add_argument('--metadata_path', type=str, default='./data/list_attr_celeba.txt') 83 | parser.add_argument('--log_path', type=str, default='./stargan/logs') 84 | parser.add_argument('--model_save_path', type=str, default='./stargan/models') 85 | parser.add_argument('--sample_path', type=str, default='./stargan/samples') 86 | parser.add_argument('--result_path', type=str, default='./stargan/results') 87 | 88 | # Step size 89 | parser.add_argument('--log_step', type=int, default=10) 90 | parser.add_argument('--sample_step', type=int, default=500) 91 | parser.add_argument('--model_save_step', type=int, default=1000) 92 | 93 | config = parser.parse_args() 94 | 95 | args = vars(config) 96 | print('------------ Options -------------') 97 | for k, v in sorted(args.items()): 98 | print('%s: %s' % (str(k), str(v))) 99 | print('-------------- End ----------------') 100 | 101 | main(config) -------------------------------------------------------------------------------- /FCN-8/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchvision 6 | from torch.autograd import Variable 7 | 8 | # https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py 9 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 10 | """Make a 2D bilinear kernel suitable for upsampling""" 11 | factor = (kernel_size + 1) // 2 12 | if kernel_size % 2 == 1: 13 | center = factor - 1 14 | else: 15 | center = factor - 0.5 16 | og = np.ogrid[:kernel_size, :kernel_size] 17 | filt = (1 - np.abs(og[0] - center) / factor) * \ 18 | (1 - np.abs(og[1] - center) / factor) 19 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 20 | dtype=np.float64) 21 | weight[range(in_channels), range(out_channels), :, :] = filt 22 | return torch.from_numpy(weight).float() 23 | 24 | 25 | class FCN32s(nn.Module): 26 | def __init__(self, n_class=21): 27 | super(FCN32s, self).__init__() 28 | # conv1 29 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) 30 | self.relu1_1 = nn.ReLU(inplace=True) 31 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 32 | self.relu1_2 = nn.ReLU(inplace=True) 33 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 34 | 35 | # conv2 36 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 37 | self.relu2_1 = nn.ReLU(inplace=True) 38 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 39 | self.relu2_2 = nn.ReLU(inplace=True) 40 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 41 | 42 | # conv3 43 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 44 | self.relu3_1 = nn.ReLU(inplace=True) 45 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 46 | self.relu3_2 = nn.ReLU(inplace=True) 47 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 48 | self.relu3_3 = nn.ReLU(inplace=True) 49 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 50 | 51 | # conv4 52 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 53 | self.relu4_1 = nn.ReLU(inplace=True) 54 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 55 | self.relu4_2 = 
nn.ReLU(inplace=True) 56 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 57 | self.relu4_3 = nn.ReLU(inplace=True) 58 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 59 | 60 | # conv5 61 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 62 | self.relu5_1 = nn.ReLU(inplace=True) 63 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 64 | self.relu5_2 = nn.ReLU(inplace=True) 65 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 66 | self.relu5_3 = nn.ReLU(inplace=True) 67 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 68 | 69 | # fc6 70 | self.fc6 = nn.Conv2d(512, 4096, 7) 71 | self.relu6 = nn.ReLU(inplace=True) 72 | self.drop6 = nn.Dropout2d() 73 | 74 | # fc7 75 | self.fc7 = nn.Conv2d(4096, 4096, 1) 76 | self.relu7 = nn.ReLU(inplace=True) 77 | self.drop7 = nn.Dropout2d() 78 | 79 | self.score_fr = nn.Conv2d(4096, n_class, 1) 80 | self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32, 81 | bias=False) 82 | 83 | def forward(self, x): 84 | h = x 85 | h = self.relu1_1(self.conv1_1(h)) 86 | h = self.relu1_2(self.conv1_2(h)) 87 | h = self.pool1(h) 88 | 89 | h = self.relu2_1(self.conv2_1(h)) 90 | h = self.relu2_2(self.conv2_2(h)) 91 | h = self.pool2(h) 92 | 93 | h = self.relu3_1(self.conv3_1(h)) 94 | h = self.relu3_2(self.conv3_2(h)) 95 | h = self.relu3_3(self.conv3_3(h)) 96 | h = self.pool3(h) 97 | 98 | h = self.relu4_1(self.conv4_1(h)) 99 | h = self.relu4_2(self.conv4_2(h)) 100 | h = self.relu4_3(self.conv4_3(h)) 101 | h = self.pool4(h) 102 | 103 | h = self.relu5_1(self.conv5_1(h)) 104 | h = self.relu5_2(self.conv5_2(h)) 105 | h = self.relu5_3(self.conv5_3(h)) 106 | h = self.pool5(h) 107 | 108 | h = self.relu6(self.fc6(h)) 109 | h = self.drop6(h) 110 | 111 | h = self.relu7(self.fc7(h)) 112 | h = self.drop7(h) 113 | 114 | h = self.score_fr(h) 115 | 116 | h = self.upscore(h) 117 | h = h[:, :, 19:19 + x.size()[2], 19:19 + x.size()[3]].contiguous() 118 | 119 | return h 120 | 121 | def initialize_weights(self): 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | m.weight.data.normal_(0.0, 0.02) 125 | if m.bias is not None: 126 | m.bias.data.zero_() 127 | if isinstance(m, nn.ConvTranspose2d): 128 | assert m.kernel_size[0] == m.kernel_size[1] 129 | initial_weight = get_upsampling_weight( 130 | m.in_channels, m.out_channels, m.kernel_size[0]) 131 | m.weight.data.copy_(initial_weight) 132 | 133 | def copy_params_from_vgg16(self, vgg16): 134 | features = [ 135 | self.conv1_1, self.relu1_1, 136 | self.conv1_2, self.relu1_2, 137 | self.pool1, 138 | self.conv2_1, self.relu2_1, 139 | self.conv2_2, self.relu2_2, 140 | self.pool2, 141 | self.conv3_1, self.relu3_1, 142 | self.conv3_2, self.relu3_2, 143 | self.conv3_3, self.relu3_3, 144 | self.pool3, 145 | self.conv4_1, self.relu4_1, 146 | self.conv4_2, self.relu4_2, 147 | self.conv4_3, self.relu4_3, 148 | self.pool4, 149 | self.conv5_1, self.relu5_1, 150 | self.conv5_2, self.relu5_2, 151 | self.conv5_3, self.relu5_3, 152 | self.pool5, 153 | ] 154 | for l1, l2 in zip(vgg16.features, features): 155 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 156 | assert l1.weight.size() == l2.weight.size() 157 | assert l1.bias.size() == l2.bias.size() 158 | l2.weight.data = l1.weight.data 159 | l2.bias.data = l1.bias.data 160 | for i, name in zip([0, 3], ['fc6', 'fc7']): 161 | l1 = vgg16.classifier[i] 162 | l2 = getattr(self, name) 163 | l2.weight.data = l1.weight.data.view(l2.weight.size()) 164 | l2.bias.data = l1.bias.data.view(l2.bias.size()) 165 | 166 | 167 | 168 | class FCN16s(nn.Module): 
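    # Summary of the FCN-16s variant built below: the 1/32-resolution class
    # scores (score_fr) are upsampled 2x (upscore2), fused with a 1x1-scored
    # copy of pool4 (score_pool4), and the sum is upsampled 16x back to the
    # input resolution. The crop offsets (5 and 27 in forward) together with
    # padding=100 on conv1_1 keep the feature maps aligned, as in the
    # reference Caffe FCN implementation.
    #
    # Minimal, hypothetical usage sketch (names assumed, not part of the repo):
    #   model = FCN16s(n_class=21)
    #   model.copy_params_from_fcn32s(trained_fcn32s)  # warm-start from FCN-32s
    #   scores = model(images)  # same spatial size as `images`, n_class channels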
169 | 170 | # pretrained_model = \ 171 | # osp.expanduser('~/data/models/pytorch/fcn16s_from_caffe.pth') 172 | 173 | # @classmethod 174 | # def download(cls): 175 | # return fcn.data.cached_download( 176 | # url='http://drive.google.com/uc?id=0B9P1L--7Wd2vVGE3TkRMbWlNRms', 177 | # path=cls.pretrained_model, 178 | # md5='991ea45d30d632a01e5ec48002cac617', 179 | # ) 180 | 181 | def __init__(self, n_class=21): 182 | super(FCN16s, self).__init__() 183 | # conv1 184 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) 185 | self.relu1_1 = nn.ReLU(inplace=True) 186 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 187 | self.relu1_2 = nn.ReLU(inplace=True) 188 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 189 | 190 | # conv2 191 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 192 | self.relu2_1 = nn.ReLU(inplace=True) 193 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 194 | self.relu2_2 = nn.ReLU(inplace=True) 195 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 196 | 197 | # conv3 198 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 199 | self.relu3_1 = nn.ReLU(inplace=True) 200 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 201 | self.relu3_2 = nn.ReLU(inplace=True) 202 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 203 | self.relu3_3 = nn.ReLU(inplace=True) 204 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 205 | 206 | # conv4 207 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 208 | self.relu4_1 = nn.ReLU(inplace=True) 209 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 210 | self.relu4_2 = nn.ReLU(inplace=True) 211 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 212 | self.relu4_3 = nn.ReLU(inplace=True) 213 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 214 | 215 | # conv5 216 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 217 | self.relu5_1 = nn.ReLU(inplace=True) 218 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 219 | self.relu5_2 = nn.ReLU(inplace=True) 220 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 221 | self.relu5_3 = nn.ReLU(inplace=True) 222 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 223 | 224 | # fc6 225 | self.fc6 = nn.Conv2d(512, 4096, 7) 226 | self.relu6 = nn.ReLU(inplace=True) 227 | self.drop6 = nn.Dropout2d() 228 | 229 | # fc7 230 | self.fc7 = nn.Conv2d(4096, 4096, 1) 231 | self.relu7 = nn.ReLU(inplace=True) 232 | self.drop7 = nn.Dropout2d() 233 | 234 | self.score_fr = nn.Conv2d(4096, n_class, 1) 235 | self.score_pool4 = nn.Conv2d(512, n_class, 1) 236 | 237 | self.upscore2 = nn.ConvTranspose2d( 238 | n_class, n_class, 4, stride=2, bias=False) 239 | self.upscore16 = nn.ConvTranspose2d( 240 | n_class, n_class, 32, stride=16, bias=False) 241 | 242 | self._initialize_weights() 243 | 244 | def _initialize_weights(self): 245 | for m in self.modules(): 246 | if isinstance(m, nn.Conv2d): 247 | m.weight.data.zero_() 248 | if m.bias is not None: 249 | m.bias.data.zero_() 250 | if isinstance(m, nn.ConvTranspose2d): 251 | assert m.kernel_size[0] == m.kernel_size[1] 252 | initial_weight = get_upsampling_weight( 253 | m.in_channels, m.out_channels, m.kernel_size[0]) 254 | m.weight.data.copy_(initial_weight) 255 | 256 | def forward(self, x): 257 | h = x 258 | h = self.relu1_1(self.conv1_1(h)) 259 | h = self.relu1_2(self.conv1_2(h)) 260 | h = self.pool1(h) 261 | 262 | h = self.relu2_1(self.conv2_1(h)) 263 | h = self.relu2_2(self.conv2_2(h)) 264 | h = self.pool2(h) 265 | 266 | h = self.relu3_1(self.conv3_1(h)) 267 | h = self.relu3_2(self.conv3_2(h)) 268 | h = 
self.relu3_3(self.conv3_3(h)) 269 | h = self.pool3(h) 270 | 271 | h = self.relu4_1(self.conv4_1(h)) 272 | h = self.relu4_2(self.conv4_2(h)) 273 | h = self.relu4_3(self.conv4_3(h)) 274 | h = self.pool4(h) 275 | pool4 = h # 1/16 276 | 277 | h = self.relu5_1(self.conv5_1(h)) 278 | h = self.relu5_2(self.conv5_2(h)) 279 | h = self.relu5_3(self.conv5_3(h)) 280 | h = self.pool5(h) 281 | 282 | h = self.relu6(self.fc6(h)) 283 | h = self.drop6(h) 284 | 285 | h = self.relu7(self.fc7(h)) 286 | h = self.drop7(h) 287 | 288 | h = self.score_fr(h) 289 | h = self.upscore2(h) 290 | upscore2 = h # 1/16 291 | 292 | h = self.score_pool4(pool4) 293 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 294 | score_pool4c = h # 1/16 295 | 296 | h = upscore2 + score_pool4c 297 | 298 | h = self.upscore16(h) 299 | h = h[:, :, 27:27 + x.size()[2], 27:27 + x.size()[3]].contiguous() 300 | 301 | return h 302 | 303 | def copy_params_from_fcn32s(self, fcn32s): 304 | for name, l1 in fcn32s.named_children(): 305 | try: 306 | l2 = getattr(self, name) 307 | l2.weight # skip ReLU / Dropout 308 | except Exception: 309 | continue 310 | assert l1.weight.size() == l2.weight.size() 311 | assert l1.bias.size() == l2.bias.size() 312 | l2.weight.data.copy_(l1.weight.data) 313 | l2.bias.data.copy_(l1.bias.data) 314 | 315 | 316 | class FCN8s(nn.Module): 317 | 318 | # pretrained_model = \ 319 | # osp.expanduser('~/data/models/pytorch/fcn8s_from_caffe.pth') 320 | 321 | # @classmethod 322 | # def download(cls): 323 | # return fcn.data.cached_download( 324 | # url='http://drive.google.com/uc?id=0B9P1L--7Wd2vT0FtdThWREhjNkU', 325 | # path=cls.pretrained_model, 326 | # md5='dbd9bbb3829a3184913bccc74373afbb', 327 | # ) 328 | 329 | def __init__(self, n_class=21): 330 | super(FCN8s, self).__init__() 331 | # conv1 332 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) 333 | self.relu1_1 = nn.ReLU(inplace=True) 334 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 335 | self.relu1_2 = nn.ReLU(inplace=True) 336 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 337 | 338 | # conv2 339 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 340 | self.relu2_1 = nn.ReLU(inplace=True) 341 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 342 | self.relu2_2 = nn.ReLU(inplace=True) 343 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 344 | 345 | # conv3 346 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 347 | self.relu3_1 = nn.ReLU(inplace=True) 348 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 349 | self.relu3_2 = nn.ReLU(inplace=True) 350 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 351 | self.relu3_3 = nn.ReLU(inplace=True) 352 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 353 | 354 | # conv4 355 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 356 | self.relu4_1 = nn.ReLU(inplace=True) 357 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 358 | self.relu4_2 = nn.ReLU(inplace=True) 359 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 360 | self.relu4_3 = nn.ReLU(inplace=True) 361 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 362 | 363 | # conv5 364 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 365 | self.relu5_1 = nn.ReLU(inplace=True) 366 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 367 | self.relu5_2 = nn.ReLU(inplace=True) 368 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 369 | self.relu5_3 = nn.ReLU(inplace=True) 370 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 371 | 372 | # fc6 373 | self.fc6 = nn.Conv2d(512, 4096, 7) 374 | 
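        # fc6/fc7 correspond to VGG16's fully-connected layers recast as
        # convolutions (7x7x512 -> 4096 and 1x1x4096 -> 4096), which is what
        # makes the network fully convolutional. Rough size check for a
        # hypothetical 500x500 input (an illustrative assumption): padding=100
        # on conv1_1 grows the map to ~698x698, the five poolings bring it
        # down to ~22x22 at pool5, and the 7x7 fc6 leaves a ~16x16 score map
        # that the transposed convolutions below upsample back to the input
        # size.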
self.relu6 = nn.ReLU(inplace=True) 375 | self.drop6 = nn.Dropout2d() 376 | 377 | # fc7 378 | self.fc7 = nn.Conv2d(4096, 4096, 1) 379 | self.relu7 = nn.ReLU(inplace=True) 380 | self.drop7 = nn.Dropout2d() 381 | 382 | self.score_fr = nn.Conv2d(4096, n_class, 1) 383 | self.score_pool3 = nn.Conv2d(256, n_class, 1) 384 | self.score_pool4 = nn.Conv2d(512, n_class, 1) 385 | 386 | self.upscore2 = nn.ConvTranspose2d( 387 | n_class, n_class, 4, stride=2, bias=False) 388 | self.upscore8 = nn.ConvTranspose2d( 389 | n_class, n_class, 16, stride=8, bias=False) 390 | self.upscore_pool4 = nn.ConvTranspose2d( 391 | n_class, n_class, 4, stride=2, bias=False) 392 | 393 | self._initialize_weights() 394 | 395 | def _initialize_weights(self): 396 | for m in self.modules(): 397 | if isinstance(m, nn.Conv2d): 398 | m.weight.data.zero_() 399 | if m.bias is not None: 400 | m.bias.data.zero_() 401 | if isinstance(m, nn.ConvTranspose2d): 402 | assert m.kernel_size[0] == m.kernel_size[1] 403 | initial_weight = get_upsampling_weight( 404 | m.in_channels, m.out_channels, m.kernel_size[0]) 405 | m.weight.data.copy_(initial_weight) 406 | 407 | def forward(self, x): 408 | h = x 409 | h = self.relu1_1(self.conv1_1(h)) 410 | h = self.relu1_2(self.conv1_2(h)) 411 | h = self.pool1(h) 412 | 413 | h = self.relu2_1(self.conv2_1(h)) 414 | h = self.relu2_2(self.conv2_2(h)) 415 | h = self.pool2(h) 416 | 417 | h = self.relu3_1(self.conv3_1(h)) 418 | h = self.relu3_2(self.conv3_2(h)) 419 | h = self.relu3_3(self.conv3_3(h)) 420 | h = self.pool3(h) 421 | pool3 = h # 1/8 422 | 423 | h = self.relu4_1(self.conv4_1(h)) 424 | h = self.relu4_2(self.conv4_2(h)) 425 | h = self.relu4_3(self.conv4_3(h)) 426 | h = self.pool4(h) 427 | pool4 = h # 1/16 428 | 429 | h = self.relu5_1(self.conv5_1(h)) 430 | h = self.relu5_2(self.conv5_2(h)) 431 | h = self.relu5_3(self.conv5_3(h)) 432 | h = self.pool5(h) 433 | 434 | h = self.relu6(self.fc6(h)) 435 | h = self.drop6(h) 436 | 437 | h = self.relu7(self.fc7(h)) 438 | h = self.drop7(h) 439 | 440 | h = self.score_fr(h) 441 | h = self.upscore2(h) 442 | upscore2 = h # 1/16 443 | 444 | h = self.score_pool4(pool4) 445 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 446 | score_pool4c = h # 1/16 447 | 448 | h = upscore2 + score_pool4c # 1/16 449 | h = self.upscore_pool4(h) 450 | upscore_pool4 = h # 1/8 451 | 452 | h = self.score_pool3(pool3) 453 | h = h[:, :, 454 | 9:9 + upscore_pool4.size()[2], 455 | 9:9 + upscore_pool4.size()[3]] 456 | score_pool3c = h # 1/8 457 | 458 | h = upscore_pool4 + score_pool3c # 1/8 459 | 460 | h = self.upscore8(h) 461 | h = h[:, :, 31:31 + x.size()[2], 31:31 + x.size()[3]].contiguous() 462 | 463 | return h 464 | 465 | def copy_params_from_fcn16s(self, fcn16s): 466 | for name, l1 in fcn16s.named_children(): 467 | try: 468 | l2 = getattr(self, name) 469 | l2.weight # skip ReLU / Dropout 470 | except Exception: 471 | continue 472 | assert l1.weight.size() == l2.weight.size() 473 | l2.weight.data.copy_(l1.weight.data) 474 | if l1.bias is not None: 475 | assert l1.bias.size() == l2.bias.size() 476 | l2.bias.data.copy_(l1.bias.data) 477 | 478 | 479 | class FCN8sAtOnce(FCN8s): 480 | 481 | # pretrained_model = \ 482 | # osp.expanduser('~/data/models/pytorch/fcn8s-atonce_from_caffe.pth') 483 | 484 | # @classmethod 485 | # def download(cls): 486 | # return fcn.data.cached_download( 487 | # url='http://drive.google.com/uc?id=0B9P1L--7Wd2vblE1VUIxV1o2d2M', 488 | # path=cls.pretrained_model, 489 | # md5='bfed4437e941fef58932891217fe6464', 490 | # ) 491 | 492 | def forward(self, x): 
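        # Identical to FCN8s.forward except that the skip branches are scaled
        # before scoring (pool4 * 0.01 and pool3 * 0.0001 below); this is the
        # "at once" trick from the reference FCN implementation that lets all
        # stages be trained jointly instead of staging FCN-32s -> FCN-16s ->
        # FCN-8s. Correspondingly, this class provides copy_params_from_vgg16
        # and is initialized straight from VGG16 rather than from a trained
        # FCN-16s.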
493 | h = x 494 | h = self.relu1_1(self.conv1_1(h)) 495 | h = self.relu1_2(self.conv1_2(h)) 496 | h = self.pool1(h) 497 | 498 | h = self.relu2_1(self.conv2_1(h)) 499 | h = self.relu2_2(self.conv2_2(h)) 500 | h = self.pool2(h) 501 | 502 | h = self.relu3_1(self.conv3_1(h)) 503 | h = self.relu3_2(self.conv3_2(h)) 504 | h = self.relu3_3(self.conv3_3(h)) 505 | h = self.pool3(h) 506 | pool3 = h # 1/8 507 | 508 | h = self.relu4_1(self.conv4_1(h)) 509 | h = self.relu4_2(self.conv4_2(h)) 510 | h = self.relu4_3(self.conv4_3(h)) 511 | h = self.pool4(h) 512 | pool4 = h # 1/16 513 | 514 | h = self.relu5_1(self.conv5_1(h)) 515 | h = self.relu5_2(self.conv5_2(h)) 516 | h = self.relu5_3(self.conv5_3(h)) 517 | h = self.pool5(h) 518 | 519 | h = self.relu6(self.fc6(h)) 520 | h = self.drop6(h) 521 | 522 | h = self.relu7(self.fc7(h)) 523 | h = self.drop7(h) 524 | 525 | h = self.score_fr(h) 526 | h = self.upscore2(h) 527 | upscore2 = h # 1/16 528 | 529 | h = self.score_pool4(pool4 * 0.01) # XXX: scaling to train at once 530 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 531 | score_pool4c = h # 1/16 532 | 533 | h = upscore2 + score_pool4c # 1/16 534 | h = self.upscore_pool4(h) 535 | upscore_pool4 = h # 1/8 536 | 537 | h = self.score_pool3(pool3 * 0.0001) # XXX: scaling to train at once 538 | h = h[:, :, 539 | 9:9 + upscore_pool4.size()[2], 540 | 9:9 + upscore_pool4.size()[3]] 541 | score_pool3c = h # 1/8 542 | 543 | h = upscore_pool4 + score_pool3c # 1/8 544 | 545 | h = self.upscore8(h) 546 | h = h[:, :, 31:31 + x.size()[2], 31:31 + x.size()[3]].contiguous() 547 | 548 | return h 549 | 550 | def copy_params_from_vgg16(self, vgg16): 551 | features = [ 552 | self.conv1_1, self.relu1_1, 553 | self.conv1_2, self.relu1_2, 554 | self.pool1, 555 | self.conv2_1, self.relu2_1, 556 | self.conv2_2, self.relu2_2, 557 | self.pool2, 558 | self.conv3_1, self.relu3_1, 559 | self.conv3_2, self.relu3_2, 560 | self.conv3_3, self.relu3_3, 561 | self.pool3, 562 | self.conv4_1, self.relu4_1, 563 | self.conv4_2, self.relu4_2, 564 | self.conv4_3, self.relu4_3, 565 | self.pool4, 566 | self.conv5_1, self.relu5_1, 567 | self.conv5_2, self.relu5_2, 568 | self.conv5_3, self.relu5_3, 569 | self.pool5, 570 | ] 571 | for l1, l2 in zip(vgg16.features, features): 572 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 573 | assert l1.weight.size() == l2.weight.size() 574 | assert l1.bias.size() == l2.bias.size() 575 | l2.weight.data.copy_(l1.weight.data) 576 | l2.bias.data.copy_(l1.bias.data) 577 | for i, name in zip([0, 3], ['fc6', 'fc7']): 578 | l1 = vgg16.classifier[i] 579 | l2 = getattr(self, name) 580 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size())) 581 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size())) 582 | 583 | class VGG16Modified(nn.Module): 584 | 585 | def __init__(self, n_classes=21): 586 | super(VGG16Modified, self).__init__() 587 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 588 | self.relu1_1 = nn.ReLU(inplace=True) 589 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 590 | self.relu1_2 = nn.ReLU(inplace=True) 591 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 128 592 | 593 | # conv2 594 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 595 | self.relu2_1 = nn.ReLU(inplace=True) 596 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 597 | self.relu2_2 = nn.ReLU(inplace=True) 598 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 64 599 | 600 | # conv3 601 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 602 | self.relu3_1 = nn.ReLU(inplace=True) 603 | 
self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 604 | self.relu3_2 = nn.ReLU(inplace=True) 605 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 606 | self.relu3_3 = nn.ReLU(inplace=True) 607 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 32 608 | 609 | # conv4 610 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 611 | self.relu4_1 = nn.ReLU(inplace=True) 612 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 613 | self.relu4_2 = nn.ReLU(inplace=True) 614 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 615 | self.relu4_3 = nn.ReLU(inplace=True) 616 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 16 617 | 618 | # conv5 619 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 620 | self.relu5_1 = nn.ReLU(inplace=True) 621 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 622 | self.relu5_2 = nn.ReLU(inplace=True) 623 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 624 | self.relu5_3 = nn.ReLU(inplace=True) 625 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 8 626 | 627 | 628 | self.conv6s_re = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 629 | nn.ReLU(inplace=True), 630 | nn.Upsample(scale_factor=2, mode='bilinear')) 631 | self.conv6_3 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 632 | nn.ReLU(inplace=True)) 633 | self.conv6_2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 634 | nn.ReLU(inplace=True)) 635 | self.conv6_1 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 636 | nn.ReLU(inplace=True), 637 | nn.Upsample(scale_factor=2, mode='bilinear')) 638 | 639 | self.conv7_3 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 640 | nn.ReLU(inplace=True)) 641 | self.conv7_2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 642 | nn.ReLU(inplace=True)) 643 | self.conv7_1 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=True), 644 | nn.ReLU(inplace=True), 645 | nn.Upsample(scale_factor=2, mode='bilinear')) 646 | 647 | self.conv8_3 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 648 | nn.ReLU(inplace=True)) 649 | self.conv8_2 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 650 | nn.ReLU(inplace=True)) 651 | self.conv8_1 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=True), 652 | nn.ReLU(inplace=True), 653 | nn.Upsample(scale_factor=2, mode='bilinear')) 654 | 655 | self.conv9 = nn.Sequential(nn.Conv2d(128, 32*3*4, kernel_size=3, stride=1, padding=1, bias=True), 656 | nn.Tanh()) 657 | 658 | self.conv10 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=True), 659 | nn.ReLU(inplace=True)) 660 | 661 | self.conv11 = nn.Sequential(nn.Conv2d(64, n_classes, kernel_size=3, stride=1, padding=1, bias=True), 662 | nn.Upsample(scale_factor=2, mode='bilinear')) 663 | 664 | 665 | 666 | self.coarse_conv_in = nn.Sequential(nn.Conv2d(n_classes, 32, kernel_size=3, stride=1, padding=1, bias=True), 667 | nn.ReLU(inplace=True), 668 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=True), 669 | nn.ReLU(inplace=True), 670 | nn.AvgPool2d(kernel_size=2, stride=2)) 671 | 672 | 673 | def to_tridiagonal_multidim(self, w): 674 | N,W,C,D = w.size() 675 | tmp_w = w.unsqueeze(2).expand([N,W,W,C,D]) 676 | 677 | eye_a = Variable(torch.diag(torch.ones(W-1).cuda(),diagonal=-1)) 678 | eye_b = 
Variable(torch.diag(torch.ones(W).cuda(),diagonal=0)) 679 | eye_c = Variable(torch.diag(torch.ones(W-1).cuda(),diagonal=1)) 680 | 681 | 682 | tmp_eye_a = eye_a.unsqueeze(-1).unsqueeze(0).expand([N,W,W,C]) 683 | a = tmp_w[:,:,:,:,0] * tmp_eye_a 684 | tmp_eye_b = eye_b.unsqueeze(-1).unsqueeze(0).expand([N,W,W,C]) 685 | b = tmp_w[:,:,:,:,1] * tmp_eye_b 686 | tmp_eye_c = eye_c.unsqueeze(-1).unsqueeze(0).expand([N,W,W,C]) 687 | c = tmp_w[:,:,:,:,2] * tmp_eye_c 688 | 689 | return a+b+c 690 | def forward(self, x, coarse_segmentation): 691 | h = x 692 | h = self.relu1_1(self.conv1_1(h)) 693 | h = self.relu1_2(self.conv1_2(h)) 694 | h = self.pool1(h) 695 | 696 | h = self.relu2_1(self.conv2_1(h)) 697 | h = self.relu2_2(self.conv2_2(h)) 698 | h = self.pool2(h) 699 | 700 | conv3_1 = self.relu3_1(self.conv3_1(h)) 701 | conv3_2 = self.relu3_2(self.conv3_2(conv3_1)) 702 | conv3_3= self.relu3_3(self.conv3_3(conv3_2)) 703 | h = self.pool3(conv3_3) 704 | pool3 = h # 1/8 705 | 706 | conv4_1 = self.relu4_1(self.conv4_1(h)) 707 | conv4_2 = self.relu4_2(self.conv4_2(conv4_1)) 708 | conv4_3 = self.relu4_3(self.conv4_3(conv4_2)) 709 | h = self.pool4(conv4_3) 710 | pool4 = h # 1/16 711 | 712 | conv5_1 = self.relu5_1(self.conv5_1(h)) 713 | conv5_2 = self.relu5_2(self.conv5_2(conv5_1)) 714 | conv5_3 = self.relu5_3(self.conv5_3(conv5_2)) 715 | h = self.pool5(conv5_3) 716 | 717 | 718 | 719 | conv6_re = self.conv6s_re(h) 720 | 721 | 722 | 723 | skip_1 = conv5_3 + conv6_re 724 | conv6_3 = self.conv6_3(skip_1) 725 | skip_2 = conv5_2 + conv6_3 726 | conv6_2 = self.conv6_2(skip_2) 727 | skip_3 = conv5_1 + conv6_2 728 | conv6_1 = self.conv6_1(skip_3) 729 | 730 | skip_4 = conv4_3 + conv6_1 731 | conv7_3 = self.conv7_3(skip_4) 732 | skip_5 = conv4_2 + conv7_3 733 | conv7_2 = self.conv7_2(skip_5) 734 | skip_6 = conv4_1 + conv7_2 735 | conv7_1 = self.conv7_1(skip_6) 736 | 737 | skip_7 = conv3_3 + conv7_1 738 | conv8_3 = self.conv8_3(skip_7) 739 | skip_8 = conv3_2 + conv8_3 740 | conv8_2 = self.conv8_2(skip_8) 741 | skip_9 = conv3_1 + conv8_2 742 | conv8_1 = self.conv8_1(skip_9) 743 | 744 | conv9 = self.conv9(conv8_1) 745 | 746 | N,C,H,W = conv9.size() 747 | four_directions = C // 4 748 | conv9_reshaped_W = conv9.permute(0,2,3,1) 749 | # conv9_reshaped_H = conv9.permute(0,3,2,1) 750 | 751 | conv_x1_flat = conv9_reshaped_W[:,:,:,0:four_directions].contiguous() 752 | conv_y1_flat = conv9_reshaped_W[:,:,:,four_directions:2*four_directions].contiguous() 753 | conv_x2_flat = conv9_reshaped_W[:,:,:,2*four_directions:3*four_directions].contiguous() 754 | conv_y2_flat = conv9_reshaped_W[:,:,:,3*four_directions:4*four_directions].contiguous() 755 | 756 | w_x1 = conv_x1_flat.view(N,H,W,four_directions//3,3) # N, H, W, 32, 3 757 | w_y1 = conv_y1_flat.view(N,H,W,four_directions//3,3) # N, H, W, 32, 3 758 | w_x2 = conv_x2_flat.view(N,H,W,four_directions//3,3) # N, H, W, 32, 3 759 | w_y2 = conv_y2_flat.view(N,H,W,four_directions//3,3) # N, H, W, 32, 3 760 | 761 | rnn_h1 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 762 | rnn_h2 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 763 | rnn_h3 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 764 | rnn_h4 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 765 | 766 | x_t = self.coarse_conv_in(coarse_segmentation).permute(0,2,3,1) 767 | 768 | # horizontal 769 | for i in range(W): 770 | #left to right 771 | tmp_w = w_x1[:,:,i,:,:] # N, H, 1, 32, 3 772 | tmp_w = self.to_tridiagonal_multidim(tmp_w) # N, H, W, 32 773 | # tmp_x = 
x_t[:,:,i,:].unsqueeze(1) 774 | # tmp_x = tmp_x.expand([batch, W, H, 32]) 775 | 776 | w_h_prev_1 = torch.sum(tmp_w * rnn_h1[:,:,i-1,:].clone().unsqueeze(1).expand([N, W, H, 32]),dim=2) 777 | w_x_curr_1 = (1 - torch.sum(tmp_w, dim=2)) * x_t[:,:,i,:] 778 | rnn_h1[:,:,i,:] = w_x_curr_1 + w_h_prev_1 779 | 780 | 781 | #right to left 782 | # tmp_w = w_x1[:,:,i,:,:] # N, H, 1, 32, 3 783 | # tmp_w = to_tridiagonal_multidim(tmp_w) 784 | w_h_prev_2 = torch.sum(tmp_w * rnn_h2[:,:,i-1,:].clone().unsqueeze(1).expand([N, W, H, 32]),dim=2) 785 | w_x_curr_2 = (1 - torch.sum(tmp_w, dim=2)) * x_t[:,:,W - i-1,:] 786 | rnn_h2[:,:,i,:] = w_x_curr_2 + w_h_prev_2 787 | 788 | w_y1_T = w_y1.transpose(1,2) 789 | x_t_T = x_t.transpose(1,2) 790 | 791 | for i in range(H): 792 | # up to down 793 | tmp_w = w_y1_T[:,:,i,:,:] # N, W, 1, 32, 3 794 | tmp_w = self.to_tridiagonal_multidim(tmp_w) # N, W, H, 32 795 | 796 | w_h_prev_3 = torch.sum(tmp_w * rnn_h3[:,:,i-1,:].clone().unsqueeze(1).expand([N, H, W, 32]),dim=2) 797 | w_x_curr_3 = (1 - torch.sum(tmp_w, dim=2)) * x_t_T[:,:,i,:] 798 | rnn_h3[:,:,i,:] = w_x_curr_3 + w_h_prev_3 799 | 800 | # down to up 801 | w_h_prev_4 = torch.sum(tmp_w * rnn_h4[:,:,i-1,:].clone().unsqueeze(1).expand([N, H, W, 32]),dim=2) 802 | w_x_curr_4 = (1 - torch.sum(tmp_w, dim=2)) * x_t[:,:,H-i-1,:] 803 | rnn_h4[:,:,i,:] = w_x_curr_4 + w_h_prev_4 804 | 805 | rnn_h3 = rnn_h3.transpose(1,2) 806 | rnn_h4 = rnn_h4.transpose(1,2) 807 | 808 | 809 | 810 | rnn_h5 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 811 | rnn_h6 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 812 | rnn_h7 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 813 | rnn_h8 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 814 | 815 | # horizontal 816 | for i in range(W): 817 | #left to right 818 | tmp_w = w_x2[:,:,i,:,:] # N, H, 1, 32, 3 819 | tmp_w = self.to_tridiagonal_multidim(tmp_w) # N, H, W, 32 820 | # tmp_x = x_t[:,:,i,:].unsqueeze(1) 821 | # tmp_x = tmp_x.expand([batch, W, H, 32]) 822 | 823 | w_h_prev_5 = torch.sum(tmp_w * rnn_h5[:,:,i-1,:].clone().unsqueeze(1).expand([N, W, H, 32]),dim=2) 824 | w_x_curr_5 = (1 - torch.sum(tmp_w, dim=2)) * rnn_h1[:,:,i,:] 825 | rnn_h5[:,:,i,:] = w_x_curr_5 + w_h_prev_5 826 | 827 | 828 | #right to left 829 | # tmp_w = w_x1[:,:,i,:,:] # N, H, 1, 32, 3 830 | # tmp_w = to_tridiagonal_multidim(tmp_w) 831 | w_h_prev_6 = torch.sum(tmp_w * rnn_h6[:,:,i-1,:].clone().unsqueeze(1).expand([N, W, H, 32]),dim=2) 832 | w_x_curr_6 = (1 - torch.sum(tmp_w, dim=2)) * rnn_h2[:,:,W - i-1,:] 833 | rnn_h6[:,:,i,:] = w_x_curr_6 + w_h_prev_6 834 | 835 | w_y2_T = w_y2.transpose(1,2) 836 | rnn_h3_T = rnn_h3.transpose(1,2) 837 | rnn_h4_T = rnn_h4.transpose(1,2) 838 | for i in range(H): 839 | # up to down 840 | tmp_w = w_y2_T[:,:,i,:,:] # N, W, 1, 32, 3 841 | tmp_w = self.to_tridiagonal_multidim(tmp_w) # N, W, H, 32 842 | 843 | w_h_prev_7 = torch.sum(tmp_w * rnn_h7[:,:,i-1,:].clone().unsqueeze(1).expand([N, H, W, 32]),dim=2) 844 | w_x_curr_7 = (1 - torch.sum(tmp_w, dim=2)) * rnn_h3_T[:,:,i,:] 845 | rnn_h7[:,:,i,:] = w_x_curr_7 + w_h_prev_7 846 | 847 | # down to up 848 | w_h_prev_8 = torch.sum(tmp_w * rnn_h8[:,:,i-1,:].clone().unsqueeze(1).expand([N, H, W, 32]),dim=2) 849 | w_x_curr_8 = (1 - torch.sum(tmp_w, dim=2)) * rnn_h4_T[:,:,H-i-1,:] 850 | rnn_h8[:,:,i,:] = w_x_curr_8 + w_h_prev_8 851 | 852 | rnn_h3 = rnn_h3.transpose(1,2) 853 | rnn_h4 = rnn_h4.transpose(1,2) 854 | 855 | concat6 = 
torch.cat([rnn_h5.unsqueeze(4),rnn_h6.unsqueeze(4),rnn_h7.unsqueeze(4),rnn_h8.unsqueeze(4)],dim=4) 856 | elt_max = torch.max(concat6, dim=4)[0] 857 | elt_max_reordered = elt_max.permute(0,3,1,2) 858 | conv10 = self.conv10(elt_max_reordered) 859 | conv11 = self.conv11(conv10) 860 | return conv11 861 | 862 | def copy_params_from_vgg16(self, vgg_model_file): 863 | features = [ 864 | self.conv1_1, self.relu1_1, 865 | self.conv1_2, self.relu1_2, 866 | self.pool1, 867 | self.conv2_1, self.relu2_1, 868 | self.conv2_2, self.relu2_2, 869 | self.pool2, 870 | self.conv3_1, self.relu3_1, 871 | self.conv3_2, self.relu3_2, 872 | self.conv3_3, self.relu3_3, 873 | self.pool3, 874 | self.conv4_1, self.relu4_1, 875 | self.conv4_2, self.relu4_2, 876 | self.conv4_3, self.relu4_3, 877 | self.pool4, 878 | self.conv5_1, self.relu5_1, 879 | self.conv5_2, self.relu5_2, 880 | self.conv5_3, self.relu5_3, 881 | self.pool5, 882 | ] 883 | 884 | 885 | vgg16 = torchvision.models.vgg16(pretrained=False) 886 | state_dict = torch.load(vgg_model_file) 887 | vgg16.load_state_dict(state_dict) 888 | 889 | 890 | for l1, l2 in zip(vgg16.features, features): 891 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 892 | assert l1.weight.size() == l2.weight.size() 893 | assert l1.bias.size() == l2.bias.size() 894 | l2.weight.data.copy_(l1.weight.data) 895 | l2.bias.data.copy_(l1.bias.data) 896 | 897 | -------------------------------------------------------------------------------- /FCN-8/models/vgg.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import fcn 4 | 5 | import torch 6 | import torchvision 7 | 8 | 9 | def VGG16(pretrained=False): 10 | model = torchvision.models.vgg16(pretrained=False) 11 | if not pretrained: 12 | return model 13 | model_file = _get_vgg16_pretrained_model() 14 | state_dict = torch.load(model_file) 15 | model.load_state_dict(state_dict) 16 | return model 17 | 18 | 19 | def _get_vgg16_pretrained_model(): 20 | return fcn.data.cached_download( 21 | url='http://drive.google.com/uc?id=0B9P1L--7Wd2vLTJZMXpIRkVVRFk', 22 | path='./vgg16_from_caffe.pth', 23 | md5='aa75b158f4181e7f6230029eb96c1b13', 24 | ) 25 | -------------------------------------------------------------------------------- /FCN-8/solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import os 6 | import time 7 | import datetime 8 | from torch.autograd import grad 9 | from torch.autograd import Variable 10 | from torchvision.utils import save_image 11 | from torchvision import transforms 12 | from model import Generator 13 | from model import Discriminator 14 | from PIL import Image 15 | 16 | 17 | class Solver(object): 18 | DEFAULTS = {} 19 | def __init__(self, data_loader, config): 20 | self.__dict__.update(Solver.DEFAULTS, **config) 21 | self.flowers_loader = data_loader["Flowers"] 22 | # Build tensorboard if use 23 | self.build_model() 24 | if self.use_tensorboard: 25 | self.build_tensorboard() 26 | 27 | # Start with trained model 28 | if self.pretrained_model: 29 | self.load_pretrained_model() 30 | 31 | def build_model(self): 32 | # Define a generator and a discriminator 33 | self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num) 34 | self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num) 35 | 36 | # Optimizers 37 | self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, 
[self.beta1, self.beta2]) 38 | self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2]) 39 | 40 | # Print networks 41 | self.print_network(self.G, 'G') 42 | self.print_network(self.D, 'D') 43 | 44 | if torch.cuda.is_available(): 45 | self.G.cuda() 46 | self.D.cuda() 47 | 48 | def train(self): 49 | """Train StarGAN within a single dataset.""" 50 | 51 | # Set dataloader 52 | if self.dataset == 'CelebA': 53 | self.data_loader = self.celebA_loader 54 | elif self.dataset == 'Flowers': 55 | self.data_loader = self.flowers_loader 56 | else: 57 | self.data_loader = self.rafd_loader 58 | 59 | # The number of iterations per epoch 60 | iters_per_epoch = len(self.data_loader) 61 | 62 | fixed_x = [] 63 | real_c = [] 64 | for i, (images, labels, identity) in enumerate(self.data_loader): 65 | fixed_x.append(images) 66 | real_c.append(labels) 67 | if i == 0: 68 | break 69 | 70 | # Fixed inputs and target domain labels for debugging 71 | fixed_x = torch.cat(fixed_x, dim=0) 72 | fixed_x = self.to_var(fixed_x, volatile=True) 73 | real_c = torch.cat(real_c, dim=0) 74 | 75 | if self.dataset == 'CelebA': 76 | fixed_c_list = self.make_celeb_labels(real_c) 77 | elif self.dataset == 'Flowers': 78 | fixed_c_list = self.make_flowers_labels(real_c) 79 | elif self.dataset == 'RaFD': 80 | fixed_c_list = [] 81 | for i in range(self.c_dim): 82 | fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c_dim) 83 | fixed_c_list.append(self.to_var(fixed_c, volatile=True)) 84 | 85 | # lr cache for decaying 86 | g_lr = self.g_lr 87 | d_lr = self.d_lr 88 | 89 | # Start with trained model if exists 90 | if self.pretrained_model: 91 | start = int(self.pretrained_model.split('_')[0]) 92 | else: 93 | start = 0 94 | 95 | # Start training 96 | start_time = time.time() 97 | for e in range(start, self.num_epochs): 98 | for i, (real_x, real_label, identity) in enumerate(self.data_loader): 99 | 100 | # Generate fake labels randomly (target domain labels) 101 | rand_idx = torch.randperm(real_label.size(0)) 102 | fake_label = real_label[rand_idx] 103 | 104 | real_c = real_label.clone() 105 | fake_c = fake_label.clone() 106 | 107 | # Convert tensor to variable 108 | real_x = self.to_var(real_x) 109 | real_c = self.to_var(real_c) # input for the generator 110 | fake_c = self.to_var(fake_c) 111 | real_label = self.to_var(real_label) # this is same as real_c if dataset == 'CelebA' 112 | fake_label = self.to_var(fake_label) 113 | identity = self.to_var(identity) 114 | 115 | 116 | 117 | # ================== Train D ================== # 118 | 119 | # Compute loss with real images 120 | real_gan, real_feature = self.D(real_x) 121 | 122 | d_loss_real = -torch.mean(real_gan) 123 | 124 | d_loss_feature = F.binary_cross_entropy_with_logits( 125 | real_feature, real_label, size_average=False) / real_x.size(0) 126 | 127 | # Compute loss with fake images 128 | fake_x = self.G(real_x, fake_c) 129 | fake_x = Variable(fake_x.data) 130 | fake_gan, fake_feature = self.D(fake_x) 131 | 132 | d_loss_fake = torch.mean(fake_gan) 133 | d_loss_gan = d_loss_real + d_loss_fake 134 | 135 | # Backward + Optimize 136 | d_loss = d_loss_gan + d_loss_feature 137 | self.reset_grad() 138 | d_loss.backward() 139 | self.d_optimizer.step() 140 | 141 | # Compute classification accuracy of the discriminator 142 | if (i+1) % self.log_step == 0: 143 | accuracies = self.compute_accuracy(real_feature, real_label, self.dataset) 144 | log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()] 145 | if self.dataset == 'CelebA': 146 | 
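                        # Note: compute_accuracy returns a per-attribute
                        # vector for the multi-label cases (CelebA / Flowers)
                        # and a scalar for RaFD. 'Flowers' is not 'CelebA', so
                        # it falls through to the else branch and reuses the
                        # "8 emotional expressions" label even though its
                        # seven attributes are flower colors.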
print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='') 147 | else: 148 | print('Classification Acc (8 emotional expressions): ', end='') 149 | print(log) 150 | 151 | # Compute gradient penalty 152 | alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x) 153 | interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True) 154 | out, out_cls = self.D(interpolated) 155 | 156 | grad = torch.autograd.grad(outputs=out, 157 | inputs=interpolated, 158 | grad_outputs=torch.ones(out.size()).cuda(), 159 | retain_graph=True, 160 | create_graph=True, 161 | only_inputs=True)[0] 162 | 163 | grad = grad.view(grad.size(0), -1) 164 | grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1)) 165 | d_loss_gp = torch.mean((grad_l2norm - 1)**2) 166 | 167 | # Backward + Optimize 168 | d_loss = self.lambda_gp * d_loss_gp 169 | self.reset_grad() 170 | d_loss.backward() 171 | self.d_optimizer.step() 172 | 173 | # Logging 174 | loss = {} 175 | loss['D/loss_real'] = d_loss_real.data[0] 176 | loss['D/loss_fake'] = d_loss_fake.data[0] 177 | loss['D/loss_feature'] = d_loss_feature.data[0] 178 | # loss['D/loss_gp'] = d_loss_gp.data[0] 179 | 180 | # ================== Train G ================== # 181 | if (i+1) % self.d_train_repeat == 0: 182 | 183 | # Original-to-target and target-to-original domain 184 | fake_x = self.G(real_x, fake_c) 185 | rec_x = self.G(fake_x, real_c) 186 | 187 | # Compute losses 188 | fake_gan, fake_features = self.D(fake_x) 189 | g_loss_gan = - torch.mean(fake_gan) 190 | g_loss_cycle = torch.mean(torch.abs(real_x - rec_x)) 191 | 192 | 193 | g_loss_feature = F.binary_cross_entropy_with_logits( 194 | fake_features, fake_label, size_average=False) / fake_x.size(0) 195 | 196 | # Backward + Optimize 197 | g_loss = g_loss_gan + 10 * g_loss_cycle + g_loss_feature 198 | 199 | self.reset_grad() 200 | g_loss.backward() 201 | self.g_optimizer.step() 202 | 203 | # Logging 204 | loss['G/loss_gan'] = g_loss_gan.data[0] 205 | loss['G/loss_cycle'] = 10 * g_loss_cycle.data[0] 206 | loss['G/loss_feature'] = g_loss_feature.data[0] 207 | 208 | # Print out log info 209 | if (i+1) % self.log_step == 0: 210 | elapsed = time.time() - start_time 211 | elapsed = str(datetime.timedelta(seconds=elapsed)) 212 | 213 | log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format( 214 | elapsed, e+1, self.num_epochs, i+1, iters_per_epoch) 215 | 216 | for tag, value in loss.items(): 217 | log += ", {}: {:.4f}".format(tag, value) 218 | print(log) 219 | 220 | if self.use_tensorboard: 221 | for tag, value in loss.items(): 222 | self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1) 223 | 224 | # Translate fixed images for debugging 225 | if (i+1) % self.sample_step == 0: 226 | fake_image_list = [fixed_x] 227 | for fixed_c in fixed_c_list: 228 | fake_image_list.append(self.G(fixed_x, fixed_c)) 229 | fake_images = torch.cat(fake_image_list, dim=3) 230 | save_image(self.denorm(fake_images.data), 231 | os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0) 232 | print('Translated images and saved into {}..!'.format(self.sample_path)) 233 | 234 | # Save model checkpoints 235 | if (i+1) % self.model_save_step == 0: 236 | torch.save(self.G.state_dict(), 237 | os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1))) 238 | torch.save(self.D.state_dict(), 239 | os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1))) 240 | 241 | # Decay learning rate 242 | if (e+1) > (self.num_epochs - self.num_epochs_decay): 243 | g_lr -= 
(self.g_lr / float(self.num_epochs_decay)) 244 | d_lr -= (self.d_lr / float(self.num_epochs_decay)) 245 | self.update_lr(g_lr, d_lr) 246 | print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) 247 | 248 | def print_network(self, model, name): 249 | num_params = 0 250 | for p in model.parameters(): 251 | num_params += p.numel() 252 | print(name) 253 | print(model) 254 | print("The number of parameters: {}".format(num_params)) 255 | 256 | def load_pretrained_model(self): 257 | self.G.load_state_dict(torch.load(os.path.join( 258 | self.model_save_path, '{}_G.pth'.format(self.pretrained_model)))) 259 | self.D.load_state_dict(torch.load(os.path.join( 260 | self.model_save_path, '{}_D.pth'.format(self.pretrained_model)))) 261 | print('loaded trained models (step: {})..!'.format(self.pretrained_model)) 262 | 263 | def build_tensorboard(self): 264 | from logger import Logger 265 | self.logger = Logger(self.log_path) 266 | 267 | def update_lr(self, g_lr, d_lr): 268 | for param_group in self.g_optimizer.param_groups: 269 | param_group['lr'] = g_lr 270 | for param_group in self.d_optimizer.param_groups: 271 | param_group['lr'] = d_lr 272 | 273 | def reset_grad(self): 274 | self.g_optimizer.zero_grad() 275 | self.d_optimizer.zero_grad() 276 | 277 | def to_var(self, x, volatile=False): 278 | if torch.cuda.is_available(): 279 | x = x.cuda() 280 | return Variable(x, volatile=volatile) 281 | 282 | def denorm(self, x): 283 | out = (x + 1) / 2 284 | return out.clamp_(0, 1) 285 | 286 | def threshold(self, x): 287 | x = x.clone() 288 | x[x >= 0.5] = 1 289 | x[x < 0.5] = 0 290 | return x 291 | 292 | def compute_accuracy(self, x, y, dataset): 293 | if dataset == 'CelebA': 294 | x = F.sigmoid(x) 295 | predicted = self.threshold(x) 296 | correct = (predicted == y).float() 297 | accuracy = torch.mean(correct, dim=0) * 100.0 298 | elif dataset == 'Flowers': 299 | x = F.sigmoid(x) 300 | predicted = self.threshold(x) 301 | correct = (predicted == y).float() 302 | accuracy = torch.mean(correct, dim=0) * 100.0 303 | 304 | else: 305 | _, predicted = torch.max(x, dim=1) 306 | correct = (predicted == y).float() 307 | accuracy = torch.mean(correct) * 100.0 308 | return accuracy 309 | 310 | def one_hot(self, labels, dim): 311 | """Convert label indices to one-hot vector""" 312 | batch_size = labels.size(0) 313 | out = torch.zeros(batch_size, dim) 314 | out[np.arange(batch_size), labels.long()] = 1 315 | return out 316 | 317 | def make_celeb_labels(self, real_c): 318 | """Generate domain labels for CelebA for debugging/testing. 
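        With the default c_dim of 5 this returns 5 single-attribute label
        sets plus, for CelebA, 4 multi-attribute combinations (9 in total).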
319 | 320 | if dataset == 'CelebA': 321 | return single and multiple attribute changes 322 | elif dataset == 'Both': 323 | return single attribute changes 324 | """ 325 | y = [torch.FloatTensor([1, 0, 0]), # black hair 326 | torch.FloatTensor([0, 1, 0]), # blond hair 327 | torch.FloatTensor([0, 0, 1])] # brown hair 328 | 329 | fixed_c_list = [] 330 | 331 | # single attribute transfer 332 | for i in range(self.c_dim): 333 | fixed_c = real_c.clone() 334 | for c in fixed_c: 335 | if i < 3: 336 | c[:3] = y[i] 337 | else: 338 | c[i] = 0 if c[i] == 1 else 1 # opposite value 339 | fixed_c_list.append(self.to_var(fixed_c, volatile=True)) 340 | 341 | # multi-attribute transfer (H+G, H+A, G+A, H+G+A) 342 | if self.dataset == 'CelebA': 343 | for i in range(4): 344 | fixed_c = real_c.clone() 345 | for c in fixed_c: 346 | if i in [0, 1, 3]: # Hair color to brown 347 | c[:3] = y[2] 348 | if i in [0, 2, 3]: # Gender 349 | c[3] = 0 if c[3] == 1 else 1 350 | if i in [1, 2, 3]: # Aged 351 | c[4] = 0 if c[4] == 1 else 1 352 | fixed_c_list.append(self.to_var(fixed_c, volatile=True)) 353 | return fixed_c_list 354 | 355 | def make_flowers_labels(self,real_c): 356 | """Generate domain labels for CelebA for debugging/testing. 357 | 358 | if dataset == 'CelebA': 359 | return single and multiple attribute changes 360 | elif dataset == 'Both': 361 | return single attribute changes 362 | """ 363 | 364 | fixed_c_list = [] 365 | 366 | # single attribute transfer 367 | for i in range(self.c_dim): 368 | fixed_c = real_c.clone() 369 | for c in fixed_c: 370 | c[:] = torch.FloatTensor(np.eye(self.c_dim)[i]) 371 | fixed_c_list.append(self.to_var(fixed_c, volatile=True)) 372 | 373 | return fixed_c_list 374 | def train_multi(self): 375 | """Train StarGAN with multiple datasets. 376 | In the code below, 1 is related to CelebA and 2 is releated to RaFD. 
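        The label vector fed to G is the concatenation [c_celebA, c_rafd,
        mask]: the half for the inactive dataset is zero-filled and the 2-d
        mask ([1, 0] for CelebA batches, [0, 1] for RaFD batches) tells the
        generator which half is valid.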
377 | """ 378 | # Fixed imagse and labels for debugging 379 | fixed_x = [] 380 | real_c = [] 381 | 382 | for i, (images, labels) in enumerate(self.celebA_loader): 383 | fixed_x.append(images) 384 | real_c.append(labels) 385 | if i == 2: 386 | break 387 | 388 | fixed_x = torch.cat(fixed_x, dim=0) 389 | fixed_x = self.to_var(fixed_x, volatile=True) 390 | real_c = torch.cat(real_c, dim=0) 391 | fixed_c1_list = self.make_celeb_labels(real_c) 392 | 393 | fixed_c2_list = [] 394 | for i in range(self.c2_dim): 395 | fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c2_dim) 396 | fixed_c2_list.append(self.to_var(fixed_c, volatile=True)) 397 | 398 | fixed_zero1 = self.to_var(torch.zeros(fixed_x.size(0), self.c2_dim)) # zero vector when training with CelebA 399 | fixed_mask1 = self.to_var(self.one_hot(torch.zeros(fixed_x.size(0)), 2)) # mask vector: [1, 0] 400 | fixed_zero2 = self.to_var(torch.zeros(fixed_x.size(0), self.c_dim)) # zero vector when training with RaFD 401 | fixed_mask2 = self.to_var(self.one_hot(torch.ones(fixed_x.size(0)), 2)) # mask vector: [0, 1] 402 | 403 | # lr cache for decaying 404 | g_lr = self.g_lr 405 | d_lr = self.d_lr 406 | 407 | # data iterator 408 | data_iter1 = iter(self.celebA_loader) 409 | data_iter2 = iter(self.rafd_loader) 410 | 411 | # Start with trained model 412 | if self.pretrained_model: 413 | start = int(self.pretrained_model) + 1 414 | else: 415 | start = 0 416 | 417 | # # Start training 418 | start_time = time.time() 419 | for i in range(start, self.num_iters): 420 | 421 | # Fetch mini-batch images and labels 422 | try: 423 | real_x1, real_label1 = next(data_iter1) 424 | except: 425 | data_iter1 = iter(self.celebA_loader) 426 | real_x1, real_label1 = next(data_iter1) 427 | 428 | try: 429 | real_x2, real_label2 = next(data_iter2) 430 | except: 431 | data_iter2 = iter(self.rafd_loader) 432 | real_x2, real_label2 = next(data_iter2) 433 | 434 | # Generate fake labels randomly (target domain labels) 435 | rand_idx = torch.randperm(real_label1.size(0)) 436 | fake_label1 = real_label1[rand_idx] 437 | rand_idx = torch.randperm(real_label2.size(0)) 438 | fake_label2 = real_label2[rand_idx] 439 | 440 | real_c1 = real_label1.clone() 441 | fake_c1 = fake_label1.clone() 442 | zero1 = torch.zeros(real_x1.size(0), self.c2_dim) 443 | mask1 = self.one_hot(torch.zeros(real_x1.size(0)), 2) 444 | 445 | real_c2 = self.one_hot(real_label2, self.c2_dim) 446 | fake_c2 = self.one_hot(fake_label2, self.c2_dim) 447 | zero2 = torch.zeros(real_x2.size(0), self.c_dim) 448 | mask2 = self.one_hot(torch.ones(real_x2.size(0)), 2) 449 | 450 | # Convert tensor to variable 451 | real_x1 = self.to_var(real_x1) 452 | real_c1 = self.to_var(real_c1) 453 | fake_c1 = self.to_var(fake_c1) 454 | mask1 = self.to_var(mask1) 455 | zero1 = self.to_var(zero1) 456 | 457 | real_x2 = self.to_var(real_x2) 458 | real_c2 = self.to_var(real_c2) 459 | fake_c2 = self.to_var(fake_c2) 460 | mask2 = self.to_var(mask2) 461 | zero2 = self.to_var(zero2) 462 | 463 | real_label1 = self.to_var(real_label1) 464 | fake_label1 = self.to_var(fake_label1) 465 | real_label2 = self.to_var(real_label2) 466 | fake_label2 = self.to_var(fake_label2) 467 | 468 | # ================== Train D ================== # 469 | 470 | # Real images (CelebA) 471 | out_real, out_cls = self.D(real_x1) 472 | out_cls1 = out_cls[:, :self.c_dim] # celebA part 473 | d_loss_real = - torch.mean(out_real) 474 | d_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, real_label1, size_average=False) / real_x1.size(0) 475 | 476 | # Real images 
(RaFD) 477 | out_real, out_cls = self.D(real_x2) 478 | out_cls2 = out_cls[:, self.c_dim:] # rafd part 479 | d_loss_real += - torch.mean(out_real) 480 | d_loss_cls += F.cross_entropy(out_cls2, real_label2) 481 | 482 | # Compute classification accuracy of the discriminator 483 | if (i+1) % self.log_step == 0: 484 | accuracies = self.compute_accuracy(out_cls1, real_label1, 'CelebA') 485 | log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()] 486 | print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='') 487 | print(log) 488 | accuracies = self.compute_accuracy(out_cls2, real_label2, 'RaFD') 489 | log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()] 490 | print('Classification Acc (8 emotional expressions): ', end='') 491 | print(log) 492 | 493 | # Fake images (CelebA) 494 | fake_c = torch.cat([fake_c1, zero1, mask1], dim=1) 495 | fake_x1 = self.G(real_x1, fake_c) 496 | fake_x1 = Variable(fake_x1.data) 497 | out_fake, _ = self.D(fake_x1) 498 | d_loss_fake = torch.mean(out_fake) 499 | 500 | # Fake images (RaFD) 501 | fake_c = torch.cat([zero2, fake_c2, mask2], dim=1) 502 | fake_x2 = self.G(real_x2, fake_c) 503 | out_fake, _ = self.D(fake_x2) 504 | d_loss_fake += torch.mean(out_fake) 505 | 506 | # Backward + Optimize 507 | d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls 508 | self.reset_grad() 509 | d_loss.backward() 510 | self.d_optimizer.step() 511 | 512 | # Compute gradient penalty 513 | if (i+1) % 2 == 0: 514 | real_x = real_x1 515 | fake_x = fake_x1 516 | else: 517 | real_x = real_x2 518 | fake_x = fake_x2 519 | 520 | alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x) 521 | interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True) 522 | out, out_cls = self.D(interpolated) 523 | 524 | if (i+1) % 2 == 0: 525 | out_cls = out_cls[:, :self.c_dim] # CelebA 526 | else: 527 | out_cls = out_cls[:, self.c_dim:] # RaFD 528 | 529 | grad = torch.autograd.grad(outputs=out, 530 | inputs=interpolated, 531 | grad_outputs=torch.ones(out.size()).cuda(), 532 | retain_graph=True, 533 | create_graph=True, 534 | only_inputs=True)[0] 535 | 536 | grad = grad.view(grad.size(0), -1) 537 | grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1)) 538 | d_loss_gp = torch.mean((grad_l2norm - 1)**2) 539 | 540 | # Backward + Optimize 541 | d_loss = self.lambda_gp * d_loss_gp 542 | self.reset_grad() 543 | d_loss.backward() 544 | self.d_optimizer.step() 545 | 546 | # Logging 547 | loss = {} 548 | loss['D/loss_real'] = d_loss_real.data[0] 549 | loss['D/loss_fake'] = d_loss_fake.data[0] 550 | loss['D/loss_cls'] = d_loss_cls.data[0] 551 | loss['D/loss_gp'] = d_loss_gp.data[0] 552 | 553 | # ================== Train G ================== # 554 | if (i+1) % self.d_train_repeat == 0: 555 | # Original-to-target and target-to-original domain (CelebA) 556 | fake_c = torch.cat([fake_c1, zero1, mask1], dim=1) 557 | real_c = torch.cat([real_c1, zero1, mask1], dim=1) 558 | fake_x1 = self.G(real_x1, fake_c) 559 | rec_x1 = self.G(fake_x1, real_c) 560 | 561 | # Compute losses 562 | out, out_cls = self.D(fake_x1) 563 | out_cls1 = out_cls[:, :self.c_dim] 564 | g_loss_fake = - torch.mean(out) 565 | g_loss_rec = torch.mean(torch.abs(real_x1 - rec_x1)) 566 | g_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, fake_label1, size_average=False) / fake_x1.size(0) 567 | 568 | # Original-to-target and target-to-original domain (RaFD) 569 | fake_c = torch.cat([zero2, fake_c2, mask2], dim=1) 570 | real_c = torch.cat([zero2, real_c2, 
mask2], dim=1) 571 | fake_x2 = self.G(real_x2, fake_c) 572 | rec_x2 = self.G(fake_x2, real_c) 573 | 574 | # Compute losses 575 | out, out_cls = self.D(fake_x2) 576 | out_cls2 = out_cls[:, self.c_dim:] 577 | g_loss_fake += - torch.mean(out) 578 | g_loss_rec += torch.mean(torch.abs(real_x2 - rec_x2)) 579 | g_loss_cls += F.cross_entropy(out_cls2, fake_label2) 580 | 581 | # Backward + Optimize 582 | g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec 583 | self.reset_grad() 584 | g_loss.backward() 585 | self.g_optimizer.step() 586 | 587 | # Logging 588 | loss['G/loss_fake'] = g_loss_fake.data[0] 589 | loss['G/loss_cls'] = g_loss_cls.data[0] 590 | loss['G/loss_rec'] = g_loss_rec.data[0] 591 | 592 | # Print out log info 593 | if (i+1) % self.log_step == 0: 594 | elapsed = time.time() - start_time 595 | elapsed = str(datetime.timedelta(seconds=elapsed)) 596 | 597 | log = "Elapsed [{}], Iter [{}/{}]".format( 598 | elapsed, i+1, self.num_iters) 599 | 600 | for tag, value in loss.items(): 601 | log += ", {}: {:.4f}".format(tag, value) 602 | print(log) 603 | 604 | if self.use_tensorboard: 605 | for tag, value in loss.items(): 606 | self.logger.scalar_summary(tag, value, i+1) 607 | 608 | # Translate the images (debugging) 609 | if (i+1) % self.sample_step == 0: 610 | fake_image_list = [fixed_x] 611 | 612 | # Changing hair color, gender, and age 613 | for j in range(self.c_dim): 614 | fake_c = torch.cat([fixed_c1_list[j], fixed_zero1, fixed_mask1], dim=1) 615 | fake_image_list.append(self.G(fixed_x, fake_c)) 616 | # Changing emotional expressions 617 | for j in range(self.c2_dim): 618 | fake_c = torch.cat([fixed_zero2, fixed_c2_list[j], fixed_mask2], dim=1) 619 | fake_image_list.append(self.G(fixed_x, fake_c)) 620 | fake = torch.cat(fake_image_list, dim=3) 621 | 622 | # Save the translated images 623 | save_image(self.denorm(fake.data), 624 | os.path.join(self.sample_path, '{}_fake.png'.format(i+1)), nrow=1, padding=0) 625 | 626 | # Save model checkpoints 627 | if (i+1) % self.model_save_step == 0: 628 | torch.save(self.G.state_dict(), 629 | os.path.join(self.model_save_path, '{}_G.pth'.format(i+1))) 630 | torch.save(self.D.state_dict(), 631 | os.path.join(self.model_save_path, '{}_D.pth'.format(i+1))) 632 | 633 | # Decay learning rate 634 | decay_step = 1000 635 | if (i+1) > (self.num_iters - self.num_iters_decay) and (i+1) % decay_step==0: 636 | g_lr -= (self.g_lr / float(self.num_iters_decay) * decay_step) 637 | d_lr -= (self.d_lr / float(self.num_iters_decay) * decay_step) 638 | self.update_lr(g_lr, d_lr) 639 | print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) 640 | 641 | def test(self): 642 | """Facial attribute transfer on CelebA or facial expression synthesis on RaFD.""" 643 | # Load trained parameters 644 | G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model)) 645 | self.G.load_state_dict(torch.load(G_path)) 646 | self.G.eval() 647 | 648 | if self.dataset == 'CelebA': 649 | data_loader = self.celebA_loader 650 | else: 651 | data_loader = self.rafd_loader 652 | 653 | for i, (real_x, org_c) in enumerate(data_loader): 654 | real_x = self.to_var(real_x, volatile=True) 655 | 656 | if self.dataset == 'CelebA': 657 | target_c_list = self.make_celeb_labels(org_c) 658 | else: 659 | target_c_list = [] 660 | for j in range(self.c_dim): 661 | target_c = self.one_hot(torch.ones(real_x.size(0)) * j, self.c_dim) 662 | target_c_list.append(self.to_var(target_c, volatile=True)) 663 | 664 | # Start translations 665 | fake_image_list = 
[real_x] 666 | for target_c in target_c_list: 667 | fake_image_list.append(self.G(real_x, target_c)) 668 | fake_images = torch.cat(fake_image_list, dim=3) 669 | save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1)) 670 | save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0) 671 | print('Translated test images and saved into "{}"..!'.format(save_path)) 672 | 673 | def test_multi(self): 674 | """Facial attribute transfer and expression synthesis on CelebA.""" 675 | # Load trained parameters 676 | G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model)) 677 | self.G.load_state_dict(torch.load(G_path)) 678 | self.G.eval() 679 | 680 | for i, (real_x, org_c) in enumerate(self.celebA_loader): 681 | 682 | # Prepare input images and target domain labels 683 | real_x = self.to_var(real_x, volatile=True) 684 | target_c1_list = self.make_celeb_labels(org_c) 685 | target_c2_list = [] 686 | for j in range(self.c2_dim): 687 | target_c = self.one_hot(torch.ones(real_x.size(0)) * j, self.c2_dim) 688 | target_c2_list.append(self.to_var(target_c, volatile=True)) 689 | 690 | # Zero vectors and mask vectors 691 | zero1 = self.to_var(torch.zeros(real_x.size(0), self.c2_dim)) # zero vector for rafd expressions 692 | mask1 = self.to_var(self.one_hot(torch.zeros(real_x.size(0)), 2)) # mask vector: [1, 0] 693 | zero2 = self.to_var(torch.zeros(real_x.size(0), self.c_dim)) # zero vector for celebA attributes 694 | mask2 = self.to_var(self.one_hot(torch.ones(real_x.size(0)), 2)) # mask vector: [0, 1] 695 | 696 | # Changing hair color, gender, and age 697 | fake_image_list = [real_x] 698 | for j in range(self.c_dim): 699 | target_c = torch.cat([target_c1_list[j], zero1, mask1], dim=1) 700 | fake_image_list.append(self.G(real_x, target_c)) 701 | 702 | # Changing emotional expressions 703 | for j in range(self.c2_dim): 704 | target_c = torch.cat([zero2, target_c2_list[j], mask2], dim=1) 705 | fake_image_list.append(self.G(real_x, target_c)) 706 | fake_images = torch.cat(fake_image_list, dim=3) 707 | 708 | # Save the translated images 709 | save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1)) 710 | save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0) 711 | print('Translated test images and saved into "{}"..!'.format(save_path)) -------------------------------------------------------------------------------- /FCN-8/test.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torch.utils.data import DataLoader 3 | from PIL import Image 4 | import numpy as np 5 | from voc import * 6 | import matplotlib.pyplot as plt 7 | from model import * 8 | from torch.autograd import Variable 9 | 10 | if __name__ == '__main__': 11 | voc_root = "E:/Jonathan" 12 | data_loader = DataLoader(VOC2012ClassSeg(voc_root, split='train', transform=True), 13 | batch_size=5, shuffle=False, 14 | num_workers=4, pin_memory=True) 15 | 16 | 17 | model = FCN8s(n_class=21) 18 | if torch.cuda.is_available(): 19 | model = model.cuda() 20 | 21 | model_file = "./model_weights/fcn8s_from_caffe.pth" 22 | 23 | model_data = torch.load(model_file) 24 | model.load_state_dict(model_data) 25 | model.eval() 26 | for batch_idx, (data, target) in enumerate(data_loader): 27 | 28 | if torch.cuda.is_available(): 29 | data, target = data.cuda(), target.cuda() 30 | 31 | data, target = Variable(data, volatile=True), Variable(target) 32 | score = model(data) 33 | ## lbl_pred = score.data 34 | ## 
print(lbl_pred.size()) 35 | ## break 36 | imgs = data.data.cpu() 37 | lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :] 38 | lbl_true = target.data.cpu() 39 | print(lbl_pred.shape) 40 | for img, lt, lp in zip(imgs, lbl_true, lbl_pred): 41 | img, lt = data_loader.dataset.untransform(img, lt) 42 | print(img.shape) 43 | print(lt.shape) 44 | print(lp.shape) 45 | plt.subplot(131) 46 | plt.imshow(img) 47 | plt.subplot(132) 48 | plt.imshow(lp) 49 | plt.subplot(133) 50 | plt.imshow(lt) 51 | plt.show() 52 | 53 | -------------------------------------------------------------------------------- /FCN-8/test2.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torch.utils.data import DataLoader 3 | from PIL import Image 4 | import numpy as np 5 | from voc import * 6 | import matplotlib.pyplot as plt 7 | from model import * 8 | from torch.autograd import Variable 9 | import time 10 | 11 | if __name__ == '__main__': 12 | voc_root = "E:/Jonathan" 13 | data_loader = DataLoader(VOC2012ClassSeg(voc_root, split='train', transform=True), 14 | batch_size=1, shuffle=False, 15 | num_workers=0, pin_memory=True) 16 | 17 | 18 | model = FCN8s(n_class=21) 19 | 20 | guidance_module = VGG16Modified() 21 | if torch.cuda.is_available(): 22 | model = model.cuda() 23 | guidance_module = guidance_module.cuda() 24 | 25 | model_file = "./model_weights/fcn8s_from_caffe.pth" 26 | vgg_model_file = "./model_weights/vgg16_from_caffe.pth" 27 | 28 | 29 | model_data = torch.load(model_file) 30 | model.load_state_dict(model_data) 31 | ## model.eval() 32 | 33 | guidance_module.copy_params_from_vgg16(vgg_model_file) 34 | 35 | 36 | for batch_idx, (data, target) in enumerate(data_loader): 37 | t_num = target.numpy() 38 | print(np.sum(t_num==0) / (128*128)) 39 | if torch.cuda.is_available(): 40 | data, target = data.cuda(), target.cuda() 41 | 42 | data, target = Variable(data, volatile=True), Variable(target) 43 | st = time.time() 44 | coarse_map = model(data) 45 | refined = guidance_module(data,coarse_map) 46 | print(refined.size(), time.time()-st) 47 | 48 | 49 | break 50 | 51 | 52 | 53 | ## lbl_pred = score.data 54 | ## print(lbl_pred.size()) 55 | ## break 56 | ## imgs = data.data.cpu() 57 | ## lbl_pred = coarse_map.data.max(1)[1].cpu().numpy()[:, :, :] 58 | ## lbl_true = target.data.cpu() 59 | ## for img, lt, lp in zip(imgs, lbl_true, lbl_pred): 60 | ## img, lt = data_loader.dataset.untransform(img, lt) 61 | ## print(img.shape) 62 | ## print(lt.shape) 63 | ## print(lp.shape) 64 | ## plt.subplot(131) 65 | ## plt.imshow(img) 66 | ## plt.subplot(132) 67 | ## plt.imshow(lp) 68 | ## plt.subplot(133) 69 | ## plt.imshow(lt) 70 | ## plt.show() 71 | 72 | 73 | -------------------------------------------------------------------------------- /FCN-8/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def mkdir(directory): 4 | if not os.path.exists(directory): 5 | os.makedirs(directory) -------------------------------------------------------------------------------- /FCN-8/voc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import collections 4 | import os.path as osp 5 | import numbers 6 | import random 7 | import math 8 | import numpy as np 9 | import PIL.Image 10 | import scipy.io 11 | import torch 12 | from torch.utils import data 13 | from torchvision import transforms 14 | 15 | class RandomCropGenerator(object): 16 | def 
__call__(self, img): 17 | self.x1 = random.uniform(0, 1) 18 | self.y1 = random.uniform(0, 1) 19 | return img 20 | 21 | class RandomCrop(object): 22 | def __init__(self, size, padding=0, gen=None): 23 | if isinstance(size, numbers.Number): 24 | self.size = (int(size), int(size)) 25 | else: 26 | self.size = size 27 | self.padding = padding 28 | self._gen = gen 29 | 30 | def __call__(self, img): 31 | if self.padding > 0: 32 | img = ImageOps.expand(img, border=self.padding, fill=0) 33 | w, h = img.size 34 | th, tw = self.size 35 | if w == tw and h == th: 36 | return img 37 | 38 | if self._gen is not None: 39 | x1 = math.floor(self._gen.x1 * (w - tw)) 40 | y1 = math.floor(self._gen.y1 * (h - th)) 41 | else: 42 | x1 = random.randint(0, w - tw) 43 | y1 = random.randint(0, h - th) 44 | 45 | return img.crop((x1, y1, x1 + tw, y1 + th)) 46 | 47 | class VOCClassSegBase(data.Dataset): 48 | 49 | class_names = np.array([ 50 | 'background', 51 | 'aeroplane', 52 | 'bicycle', 53 | 'bird', 54 | 'boat', 55 | 'bottle', 56 | 'bus', 57 | 'car', 58 | 'cat', 59 | 'chair', 60 | 'cow', 61 | 'diningtable', 62 | 'dog', 63 | 'horse', 64 | 'motorbike', 65 | 'person', 66 | 'potted plant', 67 | 'sheep', 68 | 'sofa', 69 | 'train', 70 | 'tv/monitor', 71 | ]) 72 | mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 73 | 74 | def __init__(self, root, split='train', transform=False): 75 | self.root = root 76 | self.split = split 77 | self._transform = transform 78 | 79 | # VOC2011 and others are subset of VOC2012 80 | dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012') 81 | print(dataset_dir) 82 | self.files = collections.defaultdict(list) 83 | for split in ['train', 'val']: 84 | imgsets_file = osp.join( 85 | dataset_dir, 'ImageSets/Segmentation/%s.txt' % split) 86 | for did in open(imgsets_file): 87 | did = did.strip() 88 | img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did) 89 | lbl_file = osp.join( 90 | dataset_dir, 'SegmentationClass/%s.png' % did) 91 | self.files[split].append({ 92 | 'img': img_file, 93 | 'lbl': lbl_file, 94 | }) 95 | print(len(self.files["train"])) 96 | 97 | 98 | 99 | def __len__(self): 100 | return len(self.files[self.split]) 101 | 102 | def __getitem__(self, index): 103 | data_file = self.files[self.split][index] 104 | # load image 105 | img_file = data_file['img'] 106 | img_pil = PIL.Image.open(img_file) 107 | 108 | gen = RandomCropGenerator() 109 | onlyBgPatch = True 110 | while onlyBgPatch: 111 | 112 | transform_img = transforms.Compose([ 113 | gen, 114 | RandomCrop(128, gen=gen), 115 | transforms.Resize([128, 128])]) 116 | 117 | img = np.array(transform_img(img_pil),dtype=np.uint8) 118 | 119 | transform_mask = transforms.Compose([ 120 | RandomCrop(128, gen=gen)]) 121 | 122 | # load label 123 | lbl_file = data_file['lbl'] 124 | lbl_pil = PIL.Image.open(lbl_file) 125 | 126 | lbl_cropped = transform_mask(lbl_pil) 127 | lbl = np.array(transform_mask(lbl_pil), dtype=np.int32) 128 | lbl[lbl == 255] = -1 129 | unique_vals = np.unique(lbl) 130 | if len(unique_vals) > 2: 131 | onlyBgPatch = False 132 | for i in unique_vals: 133 | percentage_covered = np.sum(lbl==i) / (128*128) 134 | 135 | if percentage_covered >= 0.9: 136 | onlyBgPatch = True 137 | break 138 | 139 | if self._transform: 140 | return self.transform(img, lbl) 141 | else: 142 | return img, lbl 143 | 144 | def transform(self, img, lbl): 145 | img = img[:, :, ::-1] # RGB -> BGR 146 | img = img.astype(np.float64) 147 | img -= self.mean_bgr 148 | img = img.transpose(2, 0, 1) 149 | img = torch.from_numpy(img).float() 150 | 
lbl = torch.from_numpy(lbl).long() 151 | return img, lbl 152 | 153 | def untransform(self, img, lbl): 154 | img = img.numpy() 155 | img = img.transpose(1, 2, 0) 156 | img += self.mean_bgr 157 | img = img.astype(np.uint8) 158 | img = img[:, :, ::-1] 159 | lbl = lbl.numpy() 160 | return img, lbl 161 | 162 | 163 | class VOC2011ClassSeg(VOCClassSegBase): 164 | 165 | def __init__(self, root, split='train', transform=False): 166 | super(VOC2011ClassSeg, self).__init__( 167 | root, split=split, transform=transform) 168 | pkg_root = osp.join(osp.dirname(osp.realpath(__file__)), '..') 169 | imgsets_file = osp.join( 170 | pkg_root, 'ext/fcn.berkeleyvision.org', 171 | 'data/pascal/seg11valid.txt') 172 | dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012') 173 | for did in open(imgsets_file): 174 | did = did.strip() 175 | img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did) 176 | lbl_file = osp.join(dataset_dir, 'SegmentationClass/%s.png' % did) 177 | self.files['seg11valid'].append({'img': img_file, 'lbl': lbl_file}) 178 | 179 | 180 | class VOC2012ClassSeg(VOCClassSegBase): 181 | 182 | url = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar' # NOQA 183 | 184 | def __init__(self, root, split='train', transform=False): 185 | super(VOC2012ClassSeg, self).__init__( 186 | root, split=split, transform=transform) 187 | 188 | 189 | class SBDClassSeg(VOCClassSegBase): 190 | 191 | # XXX: It must be renamed to benchmark.tar to be extracted. 192 | url = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz' # NOQA 193 | 194 | def __init__(self, root, split='train', transform=False): 195 | self.root = root 196 | self.split = split 197 | self._transform = transform 198 | 199 | dataset_dir = osp.join(self.root, 'VOC/benchmark_RELEASE/dataset') 200 | self.files = collections.defaultdict(list) 201 | for split in ['train', 'val']: 202 | imgsets_file = osp.join(dataset_dir, '%s.txt' % split) 203 | for did in open(imgsets_file): 204 | did = did.strip() 205 | img_file = osp.join(dataset_dir, 'img/%s.jpg' % did) 206 | lbl_file = osp.join(dataset_dir, 'cls/%s.mat' % did) 207 | self.files[split].append({ 208 | 'img': img_file, 209 | 'lbl': lbl_file, 210 | }) 211 | 212 | def __getitem__(self, index): 213 | data_file = self.files[self.split][index] 214 | # load image 215 | img_file = data_file['img'] 216 | img = PIL.Image.open(img_file) 217 | img = np.array(img, dtype=np.uint8) 218 | # load label 219 | lbl_file = data_file['lbl'] 220 | mat = scipy.io.loadmat(lbl_file) 221 | lbl = mat['GTcls'][0]['Segmentation'][0].astype(np.int32) 222 | lbl[lbl == 255] = -1 223 | if self._transform: 224 | return self.transform(img, lbl) 225 | else: 226 | return img, lbl 227 | -------------------------------------------------------------------------------- /LRNN/__pycache__/data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieltan07/spatialaffinitynetwork/b1b2e6fac23eec7bfe910768e4979abb9e46bebc/LRNN/__pycache__/data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /LRNN/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieltan07/spatialaffinitynetwork/b1b2e6fac23eec7bfe910768e4979abb9e46bebc/LRNN/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- 
/LRNN/__pycache__/solver.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieltan07/spatialaffinitynetwork/b1b2e6fac23eec7bfe910768e4979abb9e46bebc/LRNN/__pycache__/solver.cpython-36.pyc -------------------------------------------------------------------------------- /LRNN/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieltan07/spatialaffinitynetwork/b1b2e6fac23eec7bfe910768e4979abb9e46bebc/LRNN/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /LRNN/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import random 4 | from torch.utils.data import Dataset 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | from torchvision.datasets import ImageFolder 8 | from PIL import Image 9 | import h5py 10 | import numpy as np 11 | import collections 12 | 13 | 14 | class RandomCropGenerator(object): 15 | def __call__(self, img): 16 | self.x1 = random.uniform(0, 1) 17 | self.y1 = random.uniform(0, 1) 18 | return img 19 | 20 | class RandomCrop(object): 21 | def __init__(self, size, padding=0, gen=None): 22 | if isinstance(size, numbers.Number): 23 | self.size = (int(size), int(size)) 24 | else: 25 | self.size = size 26 | self.padding = padding 27 | self._gen = gen 28 | 29 | def __call__(self, img): 30 | if self.padding > 0: 31 | img = ImageOps.expand(img, border=self.padding, fill=0) 32 | w, h = img.size 33 | th, tw = self.size 34 | if w == tw and h == th: 35 | return img 36 | 37 | if self._gen is not None: 38 | x1 = math.floor(self._gen.x1 * (w - tw)) 39 | y1 = math.floor(self._gen.y1 * (h - th)) 40 | else: 41 | x1 = random.randint(0, w - tw) 42 | y1 = random.randint(0, h - th) 43 | 44 | return img.crop((x1, y1, x1 + tw, y1 + th)) 45 | 46 | class FlowersDataset(Dataset): 47 | def __init__(self, image_path, transform, mode): 48 | self.transform = transform 49 | self.mode = mode 50 | self.data = h5py.File(image_path, 'r') 51 | self.num_data = self.data["train_images"].shape[0] 52 | self.attr2idx = {} 53 | self.idx2attr = {} 54 | 55 | print ('Start preprocessing dataset..!') 56 | self.preprocess() 57 | print ('Finished preprocessing dataset..!') 58 | 59 | if self.mode == 'train': 60 | self.num_data = self.data["train_images"].shape[0] 61 | elif self.mode == 'test': 62 | self.num_data = self.data["test_images"].shape[0] 63 | 64 | def preprocess(self): 65 | main_colors = ["blue","orange","pink","purple","red","white","yellow"] 66 | for i, attr in enumerate(main_colors): 67 | self.attr2idx[attr] = i 68 | self.idx2attr[i] = attr 69 | 70 | 71 | def __getitem__(self, index): 72 | image = Image.fromarray(np.uint8(self.data[self.mode+"_images"][index])) 73 | feature = np.float32(self.data[self.mode+"_feature"][index]) 74 | identity = int(self.data[self.mode+"_class"][index]) 75 | 76 | 77 | return self.transform(image), torch.FloatTensor(feature), identity 78 | 79 | def __len__(self): 80 | return self.num_data 81 | 82 | 83 | class PascalVOC2012(Dataset): 84 | 85 | class_names = np.array([ 86 | 'background', 87 | 'aeroplane', 88 | 'bicycle', 89 | 'bird', 90 | 'boat', 91 | 'bottle', 92 | 'bus', 93 | 'car', 94 | 'cat', 95 | 'chair', 96 | 'cow', 97 | 'diningtable', 98 | 'dog', 99 | 'horse', 100 | 'motorbike', 101 | 'person', 102 | 'potted plant', 103 | 
'sheep', 104 | 'sofa', 105 | 'train', 106 | 'tv/monitor', 107 | ]) 108 | mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 109 | 110 | def __init__(self, root, transform=None, mode='train'): 111 | self.root = root 112 | 113 | if mode == 'train': 114 | self.split = 'train' 115 | else: 116 | self.split = 'val' 117 | 118 | 119 | self.transform = transform 120 | 121 | # VOC2011 and others are subset of VOC2012 122 | dataset_dir = os.path.join(self.root, 'VOC/VOCdevkit/VOC2012') 123 | self.files = collections.defaultdict(list) 124 | for split in ['train', 'val']: 125 | imgsets_file = os.path.join( 126 | dataset_dir, 'ImageSets/Segmentation/%s.txt' % split) 127 | for did in open(imgsets_file): 128 | did = did.strip() 129 | img_file = os.path.join(dataset_dir, 'JPEGImages/%s.jpg' % did) 130 | lbl_file = os.path.join( 131 | dataset_dir, 'SegmentationClass/%s.png' % did) 132 | self.files[split].append({ 133 | 'img': img_file, 134 | 'lbl': lbl_file, 135 | }) 136 | def __len__(self): 137 | return len(self.files[self.split]) 138 | 139 | def __getitem__(self, index): 140 | data_file = self.files[self.split][index] 141 | # load image 142 | img_file = data_file['img'] 143 | img = Image.open(img_file) 144 | # img = np.array(img, dtype=np.uint8) 145 | # load label 146 | lbl_file = data_file['lbl'] 147 | lbl = Image.open(lbl_file) 148 | lbl = np.array(lbl, dtype=np.int32) 149 | lbl[lbl == 255] = -1 150 | 151 | img, lbl = self.preprocess(img,lbl) 152 | 153 | img = self.transform(img) 154 | im_size = img.size() 155 | 156 | rand = torch.zeros(1, im_size[1], im_size[2]) 157 | rand.random_(0,2) 158 | rand = rand.expand(3, im_size[1], im_size[2]) 159 | img_corrupted = img * rand 160 | lbl = torch.from_numpy(lbl).long() 161 | 162 | return img_corrupted, img 163 | 164 | def preprocess(self,img, lbl): 165 | # img = img[:, :, ::-1] # RGB -> BGR 166 | # img = img.astype(np.float64) 167 | # img -= self.mean_bgr 168 | # img = img.transpose(2, 0, 1) 169 | 170 | return img, lbl 171 | 172 | def postprocess(self, img, lbl): 173 | # img = img.numpy() 174 | # img = img.transpose(1, 2, 0) 175 | # img += self.mean_bgr 176 | # img = img.astype(np.uint8) 177 | # img = img[:, :, ::-1] 178 | # lbl = lbl.numpy() 179 | return img, lbl 180 | 181 | 182 | 183 | def get_loader(image_path, crop_size, image_size, batch_size, dataset='PascalVOC2012', mode='train'): 184 | """Build and return data loader.""" 185 | 186 | if mode == 'train': 187 | transform = transforms.Compose([ 188 | # transforms.RandomCrop(crop_size), 189 | transforms.Scale([image_size, image_size]), 190 | # transforms.RandomHorizontalFlip(), 191 | transforms.ToTensor()]) 192 | else: 193 | transform = transforms.Compose([ 194 | transforms.CenterCrop(crop_size), 195 | transforms.Scale(image_size), 196 | transforms.ToTensor()]) 197 | 198 | if dataset == 'PascalVOC2012': 199 | dataset = PascalVOC2012(image_path, transform, mode) 200 | 201 | shuffle = False 202 | if mode == 'train': 203 | shuffle = True 204 | 205 | data_loader = DataLoader(dataset=dataset, 206 | batch_size=batch_size, 207 | shuffle=shuffle) 208 | return data_loader 209 | -------------------------------------------------------------------------------- /LRNN/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from solver import Solver 4 | from data_loader import get_loader 5 | from torch.backends import cudnn 6 | from utils import * 7 | 8 | 9 | def str2bool(v): 10 | return v.lower() in ('true') 11 | 12 | def main(config): 13 
| # For fast training 14 | cudnn.benchmark = True 15 | 16 | # Create directories if not exist 17 | mkdir(config.log_path) 18 | mkdir(config.model_save_path) 19 | mkdir(config.sample_path) 20 | mkdir(config.result_path) 21 | 22 | data_loader = get_loader(config.data_path, config.image_size, 23 | config.image_size, config.batch_size, 'PascalVOC2012', config.mode) 24 | 25 | # Solver 26 | solver = Solver(data_loader, vars(config)) 27 | 28 | if config.mode == 'train': 29 | solver.train() 30 | elif config.mode == 'test': 31 | solver.test() 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser() 35 | 36 | # Model hyper-parameters 37 | parser.add_argument('--image_size', type=int, default=96) 38 | parser.add_argument('--lr', type=float, default=0.0001) 39 | 40 | # Training settings 41 | 42 | parser.add_argument('--num_epochs', type=int, default=100) 43 | parser.add_argument('--num_epochs_decay', type=int, default=60) 44 | parser.add_argument('--batch_size', type=int, default=20) 45 | parser.add_argument('--pretrained_model', type=str, default=None) 46 | 47 | # Misc 48 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) 49 | parser.add_argument('--use_tensorboard', type=str2bool, default=False) 50 | 51 | # Path 52 | parser.add_argument('--data_path', type=str, default='E:/Jonathan') 53 | parser.add_argument('--log_path', type=str, default='./lrnn/logs') 54 | parser.add_argument('--model_save_path', type=str, default='./lrnn/models') 55 | parser.add_argument('--sample_path', type=str, default='./lrnn/samples') 56 | parser.add_argument('--result_path', type=str, default='./lrnn/results') 57 | 58 | # Step size 59 | parser.add_argument('--log_step', type=int, default=10) 60 | parser.add_argument('--sample_step', type=int, default=100) 61 | parser.add_argument('--model_save_step', type=int, default=1000) 62 | 63 | config = parser.parse_args() 64 | 65 | args = vars(config) 66 | print('------------ Options -------------') 67 | for k, v in sorted(args.items()): 68 | print('%s: %s' % (str(k), str(v))) 69 | print('-------------- End ----------------') 70 | 71 | main(config) -------------------------------------------------------------------------------- /LRNN/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class LRNN(nn.Module): 7 | 8 | def __init__(self): 9 | super(LRNN, self).__init__() 10 | 11 | self.scale_1 = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=True) 12 | self.scale_2 = nn.MaxPool2d(2, stride=2) 13 | self.scale_3 = nn.MaxPool2d(2, stride=2) 14 | self.scale_4 = nn.MaxPool2d(2, stride=2) 15 | 16 | self.scale_2_resize = nn.Upsample(scale_factor=2, mode='bilinear') 17 | self.scale_3_resize = nn.Upsample(scale_factor=4, mode='bilinear') 18 | self.scale_4_resize = nn.Upsample(scale_factor=8, mode='bilinear') 19 | self.multi_conv = nn.Conv2d(12, 16, kernel_size=3, stride=1, padding=1, bias=True) 20 | 21 | 22 | 23 | self.conv2 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2, bias=True), 24 | nn.ReLU(inplace=True)) 25 | 26 | self.conv3 = nn.Sequential(nn.MaxPool2d(2, stride=2), 27 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=True), 28 | nn.ReLU(inplace=True)) 29 | 30 | self.conv4 = nn.Sequential(nn.MaxPool2d(2, stride=2), 31 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=True), 32 | nn.ReLU(inplace=True)) 33 | 34 | self.conv5 = 
nn.Sequential(nn.MaxPool2d(2, stride=2), 35 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=True), 36 | nn.ReLU(inplace=True)) 37 | 38 | self.conv6 = nn.Sequential(nn.MaxPool2d(2, stride=2), 39 | nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=True), 40 | nn.ReLU(inplace=True)) 41 | 42 | self.conv6s_re = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear'), 43 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True), 44 | nn.ReLU(inplace=True), 45 | nn.Upsample(scale_factor=2, mode='bilinear')) 46 | 47 | # concat conv6s_re with conv4 48 | self.conv7_re = nn.Sequential(nn.Conv2d(96, 32, kernel_size=3, stride=1, padding=1, bias=True), 49 | nn.ReLU(inplace=True), 50 | nn.Upsample(scale_factor=2, mode='bilinear')) 51 | # concat conv7_re with conv 3 52 | self.conv8_re = nn.Sequential(nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1, bias=True), 53 | nn.ReLU(inplace=True), 54 | nn.Upsample(scale_factor=2, mode='bilinear')) 55 | 56 | # concat conv8_re with conv 2 57 | self.conv9 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=True), 58 | nn.Tanh()) 59 | 60 | 61 | self.conv10 = nn.Sequential(nn.Conv2d(16, 64, kernel_size=3, stride=1, padding=1, bias=True), 62 | nn.ReLU(inplace=True)) 63 | 64 | self.conv11 = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1, bias=True), 65 | nn.ReLU(inplace=True)) 66 | 67 | def sample_multiscale(self,x): 68 | scale_1 = self.scale_1(x) 69 | scale_2 = self.scale_2(scale_1) 70 | scale_3 = self.scale_3(scale_2) 71 | scale_4 = self.scale_4(scale_3) 72 | 73 | scale_2 = self.scale_2_resize(scale_2) 74 | scale_3 = self.scale_3_resize(scale_3) 75 | scale_4 = self.scale_4_resize(scale_4) 76 | multi = torch.cat([scale_1, scale_2, scale_3, scale_4], dim=1) 77 | 78 | return self.multi_conv(multi) 79 | 80 | def flip(self,x, dim): 81 | xsize = x.size() 82 | dim = x.dim() + dim if dim < 0 else dim 83 | x = x.view(-1, *xsize[dim:]) 84 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, 85 | -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] 86 | return x.view(xsize) 87 | 88 | def forward(self, x): 89 | multiscale_input = self.sample_multiscale(x) 90 | conv2 = self.conv2(x) 91 | conv3 = self.conv3(conv2) 92 | conv4 = self.conv4(conv3) 93 | conv5 = self.conv5(conv4) 94 | conv6 = self.conv6(conv5) 95 | conv6s_re = self.conv6s_re(conv6) 96 | 97 | concat3 = torch.cat([conv6s_re, conv4],dim=1) 98 | conv7_re = self.conv7_re(concat3) 99 | 100 | concat4 = torch.cat([conv7_re, conv3],dim=1) 101 | conv8_re = self.conv8_re(concat4) 102 | 103 | concat5 = torch.cat([conv8_re, conv2], dim=1) 104 | conv9 = self.conv9(concat5) 105 | 106 | conv4_bn_x1 = conv9[:,0:16,:,:] 107 | conv4_bn_y1 = conv9[:,16:32,:,:] 108 | conv4_bn_x2 = conv9[:,32:48,:,:] 109 | conv4_bn_y2 = conv9[:,48:64,:,:] 110 | 111 | N, C, H, W = x.size() 112 | 113 | 114 | rnn_h1 = ((1- conv4_bn_x1[:,:,:,0])*multiscale_input[:,:,:,0]).unsqueeze(3) 115 | rnn_h2 = ((1- conv4_bn_x1[:,:,:,0])*multiscale_input[:,:,:,W-1]).unsqueeze(3) 116 | rnn_h3 = ((1- conv4_bn_y1[:,:,0,:])*multiscale_input[:,:,0,:]).unsqueeze(2) 117 | rnn_h4 = ((1- conv4_bn_y1[:,:,0,:])*multiscale_input[:,:,H-1,:]).unsqueeze(2) 118 | 119 | for i in range(1,W): 120 | rnn_h1_t = conv4_bn_x1[:,:,:,i]*rnn_h1[:,:,:,i-1] + (1 - conv4_bn_x1[:,:,:,i])*multiscale_input[:,:,:,i] 121 | rnn_h2_t = conv4_bn_x1[:,:,:,i]*rnn_h2[:,:,:,i-1] + (1 - conv4_bn_x1[:,:,:,i])*multiscale_input[:,:,:,W-i-1] 122 | 123 | rnn_h1 = torch.cat([rnn_h1, rnn_h1_t.unsqueeze(3)], dim=3) 124 | rnn_h2 
= torch.cat([rnn_h2, rnn_h2_t.unsqueeze(3)], dim=3) 125 | 126 | for i in range(1,H): 127 | rnn_h3_t = conv4_bn_x1[:,:,i,:]*rnn_h3[:,:,i-1,:] + (1 - conv4_bn_x1[:,:,i,:])*multiscale_input[:,:,i,:] 128 | rnn_h4_t = conv4_bn_x1[:,:,i,:]*rnn_h4[:,:,i-1,:] + (1 - conv4_bn_x1[:,:,i,:])*multiscale_input[:,:,H-i-1,:] 129 | 130 | rnn_h3 = torch.cat([rnn_h3, rnn_h3_t.unsqueeze(2)], dim=2) 131 | rnn_h4 = torch.cat([rnn_h4, rnn_h4_t.unsqueeze(2)], dim=2) 132 | 133 | rnn_h5 = ((1- conv4_bn_x2[:,:,:,0])*rnn_h1[:,:,:,0]).unsqueeze(3) 134 | rnn_h6 = ((1- conv4_bn_x2[:,:,:,0])*rnn_h2[:,:,:,W-1]).unsqueeze(3) 135 | rnn_h7 = ((1- conv4_bn_y2[:,:,0,:])*rnn_h3[:,:,0,:]).unsqueeze(2) 136 | rnn_h8 = ((1- conv4_bn_y2[:,:,0,:])*rnn_h4[:,:,H-1,:]).unsqueeze(2) 137 | 138 | for i in range(1,W): 139 | rnn_h5_t = conv4_bn_x2[:,:,:,i]*rnn_h5[:,:,:,i-1] + (1 - conv4_bn_x2[:,:,:,i])*rnn_h1[:,:,:,i] 140 | rnn_h6_t = conv4_bn_x2[:,:,:,i]*rnn_h6[:,:,:,i-1] + (1 - conv4_bn_x2[:,:,:,i])*rnn_h2[:,:,:,W-i-1] 141 | rnn_h5 = torch.cat([rnn_h5, rnn_h5_t.unsqueeze(3)], dim=3) 142 | rnn_h6 = torch.cat([rnn_h6, rnn_h6_t.unsqueeze(3)], dim=3) 143 | for i in range(1,H): 144 | rnn_h7_t = conv4_bn_y2[:,:,i,:]*rnn_h7[:,:,i-1,:] + (1 - conv4_bn_y2[:,:,i,:])*rnn_h3[:,:,i,:] 145 | rnn_h8_t = conv4_bn_y2[:,:,i,:]*rnn_h8[:,:,i-1,:] + (1 - conv4_bn_y2[:,:,i,:])*rnn_h4[:,:,H-i-1,:] 146 | 147 | rnn_h7 = torch.cat([rnn_h7, rnn_h7_t.unsqueeze(2)], dim=2) 148 | rnn_h8 = torch.cat([rnn_h8, rnn_h8_t.unsqueeze(2)], dim=2) 149 | concat6 = torch.cat([rnn_h5.unsqueeze(4),rnn_h6.unsqueeze(4),rnn_h7.unsqueeze(4),rnn_h8.unsqueeze(4)],dim=4) 150 | elt_max = torch.max(concat6, dim=4)[0] 151 | conv10 = self.conv10(elt_max) 152 | conv11 = self.conv11(conv10) 153 | 154 | return conv11 155 | 156 | def _initialize_weights(self): 157 | for m in self.modules(): 158 | if isinstance(m, nn.Conv2d): 159 | m.weight.data.zero_() 160 | if m.bias is not None: 161 | m.bias.data.zero_() 162 | if isinstance(m, nn.ConvTranspose2d): 163 | assert m.kernel_size[0] == m.kernel_size[1] 164 | initial_weight = get_upsampling_weight( 165 | m.in_channels, m.out_channels, m.kernel_size[0]) 166 | m.weight.data.copy_(initial_weight) 167 | 168 | class SimpleRNN(nn.Module): 169 | def __init__(self, hidden_size): 170 | super(SimpleRNN, self).__init__() 171 | self.hidden_size = hidden_size 172 | 173 | self.inp = nn.Linear(1, hidden_size) 174 | self.rnn = nn.LSTM(hidden_size, hidden_size, 2, dropout=0.05) 175 | self.out = nn.Linear(hidden_size, 1) 176 | 177 | def step(self, input, hidden=None): 178 | input = self.inp(input.view(1, -1)).unsqueeze(1) 179 | output, hidden = self.rnn(input, hidden) 180 | output = self.out(output.squeeze(1)) 181 | return output, hidden 182 | 183 | def forward(self, inputs, hidden=None, force=True, steps=0): 184 | if force or steps == 0: steps = len(inputs) 185 | outputs = Variable(torch.zeros(steps, 1, 1)) 186 | for i in range(steps): 187 | if force or i == 0: 188 | input = inputs[i] 189 | else: 190 | input = output 191 | output, hidden = self.step(input, hidden) 192 | outputs[i] = output 193 | return outputs, hidden -------------------------------------------------------------------------------- /LRNN/solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import os 6 | import time 7 | import datetime 8 | from torch.autograd import grad 9 | from torch.autograd import Variable 10 | from torchvision.utils import save_image 11 | 
from torchvision import transforms 12 | from model import * 13 | from PIL import Image 14 | 15 | 16 | class Solver(object): 17 | DEFAULTS = {} 18 | def __init__(self, data_loader, config): 19 | self.__dict__.update(Solver.DEFAULTS, **config) 20 | self.data_loader = data_loader 21 | # Build tensorboard if use 22 | self.build_model() 23 | if self.use_tensorboard: 24 | self.build_tensorboard() 25 | 26 | # Start with trained model 27 | if self.pretrained_model: 28 | self.load_pretrained_model() 29 | 30 | def build_model(self): 31 | # Define a generator and a discriminator 32 | self.LRNN = LRNN() 33 | 34 | # Optimizers 35 | self.optimizer = torch.optim.Adam(self.LRNN.parameters(), self.lr) 36 | 37 | # Print networks 38 | self.print_network(self.LRNN, 'LRNN') 39 | 40 | if torch.cuda.is_available(): 41 | self.LRNN.cuda() 42 | 43 | 44 | def train(self): 45 | """Train StarGAN within a single dataset.""" 46 | loss_fn = torch.nn.MSELoss(size_average=False) 47 | # The number of iterations per epoch 48 | iters_per_epoch = len(self.data_loader) 49 | 50 | fixed_x = [] 51 | real_c = [] 52 | for i, (images, target) in enumerate(self.data_loader): 53 | fixed_x.append(images) 54 | if i == 16: 55 | break 56 | 57 | # Fixed inputs and target domain labels for debugging 58 | fixed_x = torch.cat(fixed_x, dim=0) 59 | fixed_x = self.to_var(fixed_x, volatile=True) 60 | 61 | # lr cache for decaying 62 | lr = self.lr 63 | 64 | # Start with trained model if exists 65 | if self.pretrained_model: 66 | start = int(self.pretrained_model.split('_')[0]) 67 | else: 68 | start = 0 69 | 70 | # Start training 71 | start_time = time.time() 72 | for e in range(start, self.num_epochs): 73 | for i, (images, target) in enumerate(self.data_loader): 74 | 75 | # Convert tensor to variable 76 | images = self.to_var(images) 77 | target = self.to_var(target) 78 | 79 | refined_images = self.LRNN(images) 80 | 81 | mse_loss = loss_fn(refined_images, target) / images.size(0) 82 | 83 | self.reset_grad() 84 | mse_loss.backward() 85 | self.optimizer.step() 86 | 87 | # # Compute classification accuracy of the discriminator 88 | # if (i+1) % self.log_step == 0: 89 | # accuracies = self.compute_accuracy(real_feature, real_label, self.dataset) 90 | # log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()] 91 | # if self.dataset == 'CelebA': 92 | # print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='') 93 | # else: 94 | # print('Classification Acc (8 emotional expressions): ', end='') 95 | # print(log) 96 | 97 | 98 | # Logging 99 | loss = {} 100 | loss['loss'] = mse_loss.data[0] 101 | 102 | # Print out log info 103 | if (i+1) % self.log_step == 0: 104 | elapsed = time.time() - start_time 105 | elapsed = str(datetime.timedelta(seconds=elapsed)) 106 | 107 | log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format( 108 | elapsed, e+1, self.num_epochs, i+1, iters_per_epoch) 109 | 110 | for tag, value in loss.items(): 111 | log += ", {}: {:.4f}".format(tag, value) 112 | print(log) 113 | 114 | if self.use_tensorboard: 115 | for tag, value in loss.items(): 116 | self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1) 117 | 118 | # Translate fixed images for debugging 119 | if (i+1) % self.sample_step == 0: 120 | fake_image_list = [fixed_x] 121 | refined_images = self.LRNN(fixed_x) 122 | fake_image_list.append(refined_images) 123 | fake_images = torch.cat(fake_image_list, dim=3) 124 | save_image(fake_images.data, 125 | os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0) 126 | 
print('Translated images and saved into {}..!'.format(self.sample_path)) 127 | 128 | # Save model checkpoints 129 | if (i+1) % self.model_save_step == 0: 130 | torch.save(self.LRNN.state_dict(), 131 | os.path.join(self.model_save_path, '{}_{}_LRNN.pth'.format(e+1, i+1))) 132 | 133 | # Decay learning rate 134 | if (e+1) > (self.num_epochs - self.num_epochs_decay): 135 | lr -= (self.lr / float(self.num_epochs_decay)) 136 | self.update_lr(lr) 137 | print ('Decay learning rate to lr: {}.'.format(lr)) 138 | 139 | def print_network(self, model, name): 140 | num_params = 0 141 | for p in model.parameters(): 142 | num_params += p.numel() 143 | print(name) 144 | print(model) 145 | print("The number of parameters: {}".format(num_params)) 146 | 147 | def load_pretrained_model(self): 148 | self.LRNN.load_state_dict(torch.load(os.path.join( 149 | self.model_save_path, '{}_LRNN.pth'.format(self.pretrained_model)))) 150 | # self.D.load_state_dict(torch.load(os.path.join( 151 | # self.model_save_path, '{}_D.pth'.format(self.pretrained_model)))) 152 | print('loaded trained models (step: {})..!'.format(self.pretrained_model)) 153 | 154 | def build_tensorboard(self): 155 | from logger import Logger 156 | self.logger = Logger(self.log_path) 157 | 158 | def update_lr(self, lr): 159 | for param_group in self.optimizer.param_groups: 160 | param_group['lr'] = lr 161 | # for param_group in self.d_optimizer.param_groups: 162 | # param_group['lr'] = d_lr 163 | 164 | def reset_grad(self): 165 | self.optimizer.zero_grad() 166 | 167 | def to_var(self, x, volatile=False): 168 | if torch.cuda.is_available(): 169 | x = x.cuda() 170 | return Variable(x, volatile=volatile) 171 | 172 | def denorm(self, x): 173 | out = (x + 1) / 2 174 | return out.clamp_(0, 1) 175 | 176 | def threshold(self, x): 177 | x = x.clone() 178 | x[x >= 0.5] = 1 179 | x[x < 0.5] = 0 180 | return x 181 | 182 | def compute_accuracy(self, x, y, dataset): 183 | if dataset == 'CelebA': 184 | x = F.sigmoid(x) 185 | predicted = self.threshold(x) 186 | correct = (predicted == y).float() 187 | accuracy = torch.mean(correct, dim=0) * 100.0 188 | elif dataset == 'Flowers': 189 | x = F.sigmoid(x) 190 | predicted = self.threshold(x) 191 | correct = (predicted == y).float() 192 | accuracy = torch.mean(correct, dim=0) * 100.0 193 | 194 | else: 195 | _, predicted = torch.max(x, dim=1) 196 | correct = (predicted == y).float() 197 | accuracy = torch.mean(correct) * 100.0 198 | return accuracy 199 | 200 | def one_hot(self, labels, dim): 201 | """Convert label indices to one-hot vector""" 202 | batch_size = labels.size(0) 203 | out = torch.zeros(batch_size, dim) 204 | out[np.arange(batch_size), labels.long()] = 1 205 | return out 206 | 207 | def make_celeb_labels(self, real_c): 208 | """Generate domain labels for CelebA for debugging/testing. 
209 | 210 | if dataset == 'CelebA': 211 | return single and multiple attribute changes 212 | elif dataset == 'Both': 213 | return single attribute changes 214 | """ 215 | y = [torch.FloatTensor([1, 0, 0]), # black hair 216 | torch.FloatTensor([0, 1, 0]), # blond hair 217 | torch.FloatTensor([0, 0, 1])] # brown hair 218 | 219 | fixed_c_list = [] 220 | 221 | # single attribute transfer 222 | for i in range(self.c_dim): 223 | fixed_c = real_c.clone() 224 | for c in fixed_c: 225 | if i < 3: 226 | c[:3] = y[i] 227 | else: 228 | c[i] = 0 if c[i] == 1 else 1 # opposite value 229 | fixed_c_list.append(self.to_var(fixed_c, volatile=True)) 230 | 231 | # multi-attribute transfer (H+G, H+A, G+A, H+G+A) 232 | if self.dataset == 'CelebA': 233 | for i in range(4): 234 | fixed_c = real_c.clone() 235 | for c in fixed_c: 236 | if i in [0, 1, 3]: # Hair color to brown 237 | c[:3] = y[2] 238 | if i in [0, 2, 3]: # Gender 239 | c[3] = 0 if c[3] == 1 else 1 240 | if i in [1, 2, 3]: # Aged 241 | c[4] = 0 if c[4] == 1 else 1 242 | fixed_c_list.append(self.to_var(fixed_c, volatile=True)) 243 | return fixed_c_list 244 | 245 | def make_flowers_labels(self,real_c): 246 | """Generate domain labels for CelebA for debugging/testing. 247 | 248 | if dataset == 'CelebA': 249 | return single and multiple attribute changes 250 | elif dataset == 'Both': 251 | return single attribute changes 252 | """ 253 | 254 | fixed_c_list = [] 255 | 256 | # single attribute transfer 257 | for i in range(self.c_dim): 258 | fixed_c = real_c.clone() 259 | for c in fixed_c: 260 | c[:] = torch.FloatTensor(np.eye(self.c_dim)[i]) 261 | fixed_c_list.append(self.to_var(fixed_c, volatile=True)) 262 | 263 | return fixed_c_list 264 | 265 | def test(self): 266 | """Facial attribute transfer on CelebA or facial expression synthesis on RaFD.""" 267 | # Load trained parameters 268 | G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model)) 269 | self.G.load_state_dict(torch.load(G_path)) 270 | self.G.eval() 271 | 272 | if self.dataset == 'CelebA': 273 | data_loader = self.celebA_loader 274 | else: 275 | data_loader = self.rafd_loader 276 | 277 | for i, (real_x, org_c) in enumerate(data_loader): 278 | real_x = self.to_var(real_x, volatile=True) 279 | 280 | if self.dataset == 'CelebA': 281 | target_c_list = self.make_celeb_labels(org_c) 282 | else: 283 | target_c_list = [] 284 | for j in range(self.c_dim): 285 | target_c = self.one_hot(torch.ones(real_x.size(0)) * j, self.c_dim) 286 | target_c_list.append(self.to_var(target_c, volatile=True)) 287 | 288 | # Start translations 289 | fake_image_list = [real_x] 290 | for target_c in target_c_list: 291 | fake_image_list.append(self.G(real_x, target_c)) 292 | fake_images = torch.cat(fake_image_list, dim=3) 293 | save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1)) 294 | save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0) 295 | print('Translated test images and saved into "{}"..!'.format(save_path)) 296 | 297 | def test_multi(self): 298 | """Facial attribute transfer and expression synthesis on CelebA.""" 299 | # Load trained parameters 300 | G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model)) 301 | self.G.load_state_dict(torch.load(G_path)) 302 | self.G.eval() 303 | 304 | for i, (real_x, org_c) in enumerate(self.celebA_loader): 305 | 306 | # Prepare input images and target domain labels 307 | real_x = self.to_var(real_x, volatile=True) 308 | target_c1_list = self.make_celeb_labels(org_c) 309 | target_c2_list = [] 310 | for j in 
range(self.c2_dim): 311 | target_c = self.one_hot(torch.ones(real_x.size(0)) * j, self.c2_dim) 312 | target_c2_list.append(self.to_var(target_c, volatile=True)) 313 | 314 | # Zero vectors and mask vectors 315 | zero1 = self.to_var(torch.zeros(real_x.size(0), self.c2_dim)) # zero vector for rafd expressions 316 | mask1 = self.to_var(self.one_hot(torch.zeros(real_x.size(0)), 2)) # mask vector: [1, 0] 317 | zero2 = self.to_var(torch.zeros(real_x.size(0), self.c_dim)) # zero vector for celebA attributes 318 | mask2 = self.to_var(self.one_hot(torch.ones(real_x.size(0)), 2)) # mask vector: [0, 1] 319 | 320 | # Changing hair color, gender, and age 321 | fake_image_list = [real_x] 322 | for j in range(self.c_dim): 323 | target_c = torch.cat([target_c1_list[j], zero1, mask1], dim=1) 324 | fake_image_list.append(self.G(real_x, target_c)) 325 | 326 | # Changing emotional expressions 327 | for j in range(self.c2_dim): 328 | target_c = torch.cat([zero2, target_c2_list[j], mask2], dim=1) 329 | fake_image_list.append(self.G(real_x, target_c)) 330 | fake_images = torch.cat(fake_image_list, dim=3) 331 | 332 | # Save the translated images 333 | save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1)) 334 | save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0) 335 | print('Translated test images and saved into "{}"..!'.format(save_path)) -------------------------------------------------------------------------------- /LRNN/test.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torch.utils.data import DataLoader 3 | from PIL import Image 4 | import numpy as np 5 | from data_loader import * 6 | import matplotlib.pyplot as plt 7 | from model import * 8 | from torch.autograd import Variable 9 | 10 | if __name__ == '__main__': 11 | voc_root = "E:/Jonathan" 12 | data_loader = get_loader(voc_root, 96, 13 | 96, 1, 'PascalVOC2012', 'train') 14 | 15 | a = np.random.randn(20,3,96,96) 16 | a = Variable(torch.from_numpy(a).float()).cuda() 17 | 18 | gen = LRNN().cuda() 19 | 20 | gen(a) 21 | 22 | ## for batch_idx, (data, target) in enumerate(data_loader): 23 | ## 24 | ## print(data.size()) 25 | ## 26 | ## plt.subplot(131) 27 | ## plt.imshow(img) 28 | ## plt.subplot(132) 29 | ## plt.imshow(lp) 30 | ## plt.subplot(133) 31 | ## plt.imshow(lt) 32 | ## plt.show() 33 | ## break 34 | -------------------------------------------------------------------------------- /LRNN/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def mkdir(directory): 4 | if not os.path.exists(directory): 5 | os.makedirs(directory) -------------------------------------------------------------------------------- /LRNN/voc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import collections 4 | import os.path as osp 5 | 6 | import numpy as np 7 | import PIL.Image 8 | import scipy.io 9 | import torch 10 | from torch.utils import data 11 | 12 | 13 | class VOCClassSegBase(data.Dataset): 14 | 15 | class_names = np.array([ 16 | 'background', 17 | 'aeroplane', 18 | 'bicycle', 19 | 'bird', 20 | 'boat', 21 | 'bottle', 22 | 'bus', 23 | 'car', 24 | 'cat', 25 | 'chair', 26 | 'cow', 27 | 'diningtable', 28 | 'dog', 29 | 'horse', 30 | 'motorbike', 31 | 'person', 32 | 'potted plant', 33 | 'sheep', 34 | 'sofa', 35 | 'train', 36 | 'tv/monitor', 37 | ]) 38 | mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 39 | 40 | def 
__init__(self, root, split='train', transform=False): 41 | self.root = root 42 | self.split = split 43 | self._transform = transform 44 | 45 | # VOC2011 and others are subset of VOC2012 46 | dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012') 47 | print(dataset_dir) 48 | self.files = collections.defaultdict(list) 49 | for split in ['train', 'val']: 50 | imgsets_file = osp.join( 51 | dataset_dir, 'ImageSets/Segmentation/%s.txt' % split) 52 | for did in open(imgsets_file): 53 | did = did.strip() 54 | img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did) 55 | lbl_file = osp.join( 56 | dataset_dir, 'SegmentationClass/%s.png' % did) 57 | self.files[split].append({ 58 | 'img': img_file, 59 | 'lbl': lbl_file, 60 | }) 61 | 62 | def __len__(self): 63 | return len(self.files[self.split]) 64 | 65 | def __getitem__(self, index): 66 | data_file = self.files[self.split][index] 67 | # load image 68 | img_file = data_file['img'] 69 | img = PIL.Image.open(img_file) 70 | img = np.array(img, dtype=np.uint8) 71 | # load label 72 | lbl_file = data_file['lbl'] 73 | lbl = PIL.Image.open(lbl_file) 74 | lbl = np.array(lbl, dtype=np.int32) 75 | lbl[lbl == 255] = -1 76 | if self._transform: 77 | return self.transform(img, lbl) 78 | else: 79 | return img, lbl 80 | 81 | def transform(self, img, lbl): 82 | img = img[:, :, ::-1] # RGB -> BGR 83 | img = img.astype(np.float64) 84 | img -= self.mean_bgr 85 | img = img.transpose(2, 0, 1) 86 | img = torch.from_numpy(img).float() 87 | lbl = torch.from_numpy(lbl).long() 88 | return img, lbl 89 | 90 | def untransform(self, img, lbl): 91 | img = img.numpy() 92 | img = img.transpose(1, 2, 0) 93 | img += self.mean_bgr 94 | img = img.astype(np.uint8) 95 | img = img[:, :, ::-1] 96 | lbl = lbl.numpy() 97 | return img, lbl 98 | 99 | 100 | class VOC2011ClassSeg(VOCClassSegBase): 101 | 102 | def __init__(self, root, split='train', transform=False): 103 | super(VOC2011ClassSeg, self).__init__( 104 | root, split=split, transform=transform) 105 | pkg_root = osp.join(osp.dirname(osp.realpath(__file__)), '..') 106 | imgsets_file = osp.join( 107 | pkg_root, 'ext/fcn.berkeleyvision.org', 108 | 'data/pascal/seg11valid.txt') 109 | dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012') 110 | for did in open(imgsets_file): 111 | did = did.strip() 112 | img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did) 113 | lbl_file = osp.join(dataset_dir, 'SegmentationClass/%s.png' % did) 114 | self.files['seg11valid'].append({'img': img_file, 'lbl': lbl_file}) 115 | 116 | 117 | class VOC2012ClassSeg(VOCClassSegBase): 118 | 119 | url = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar' # NOQA 120 | 121 | def __init__(self, root, split='train', transform=False): 122 | super(VOC2012ClassSeg, self).__init__( 123 | root, split=split, transform=transform) 124 | 125 | 126 | class SBDClassSeg(VOCClassSegBase): 127 | 128 | # XXX: It must be renamed to benchmark.tar to be extracted. 
129 | url = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz' # NOQA 130 | 131 | def __init__(self, root, split='train', transform=False): 132 | self.root = root 133 | self.split = split 134 | self._transform = transform 135 | 136 | dataset_dir = osp.join(self.root, 'VOC/benchmark_RELEASE/dataset') 137 | self.files = collections.defaultdict(list) 138 | for split in ['train', 'val']: 139 | imgsets_file = osp.join(dataset_dir, '%s.txt' % split) 140 | for did in open(imgsets_file): 141 | did = did.strip() 142 | img_file = osp.join(dataset_dir, 'img/%s.jpg' % did) 143 | lbl_file = osp.join(dataset_dir, 'cls/%s.mat' % did) 144 | self.files[split].append({ 145 | 'img': img_file, 146 | 'lbl': lbl_file, 147 | }) 148 | 149 | def __getitem__(self, index): 150 | data_file = self.files[self.split][index] 151 | # load image 152 | img_file = data_file['img'] 153 | img = PIL.Image.open(img_file) 154 | img = np.array(img, dtype=np.uint8) 155 | # load label 156 | lbl_file = data_file['lbl'] 157 | mat = scipy.io.loadmat(lbl_file) 158 | lbl = mat['GTcls'][0]['Segmentation'][0].astype(np.int32) 159 | lbl[lbl == 255] = -1 160 | if self._transform: 161 | return self.transform(img, lbl) 162 | else: 163 | return img, lbl 164 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spatial Affinity Networks 2 | PyTorch implementation of the paper "Learning Affinity via Spatial Propagation Networks" 3 | 4 | [Work in Progress] 5 | 6 | FCN-8 code and model adapted from https://github.com/wkentaro/pytorch-fcn 7 | 8 | The paper proposes to learn a spatial affinity matrix by constructing a row-wise / column-wise linear propagation model in which each pixel in the current row/column incorporates information from its three adjacent pixels in the previous row/column. The idea is to train a CNN that generates the weights of a recursive filter conditioned on an input image. Applying this in all four directions yields global, densely connected pairwise relations, as shown in Figure 1. 9 | 10 | 11 | ![alt text](https://github.com/danieltan07/spatialaffinitynetwork/blob/master/fig1.PNG) 12 | 13 | 14 | As shown in Figure 2, the model is separated into two modules: (1) a guidance network that outputs the weights / elements of the transformation matrix; (2) a propagation module that takes the weights produced by the guidance network and applies them as a recursive filter to refine the result. 15 | 16 | ![alt text](https://github.com/danieltan07/spatialaffinitynetwork/blob/master/fig2.PNG) 17 | 18 | ## Some Implementation Details 19 | 20 | The guidance network outputs a tensor of size H x W x (C x 3 weights x 4 directions). We need to convert this tensor into a tridiagonal matrix so that a matrix product applies, to each pixel, the weights of its three adjacent pixels in the previous row/column.
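Before the conversion code below, here is a minimal, self-contained sketch of how such a tridiagonal matrix would be used in a single left-to-right propagation step. This is an illustration only (one channel, a single column of height H, hypothetical variable names), not the repository's actual propagation code:

```python
import torch

H = 5                                  # height of the feature map (illustrative)
w = torch.rand(H, 3)                   # 3 affinity weights per pixel of the current column
w = w / torch.sum(torch.abs(w), dim=1, keepdim=True)  # normalize, as in to_tridiagonal_multidim below

# Tridiagonal H x H matrix: row i mixes pixels i-1, i, i+1 of the previous column.
A = (torch.diag(w[1:, 0], diagonal=-1)
     + torch.diag(w[:, 1], diagonal=0)
     + torch.diag(w[:-1, 2], diagonal=1))

h_prev = torch.rand(H)                 # hidden state of the previous column
x_cur = torch.rand(H)                  # input feature of the current column

# One recurrence step: propagated information plus the locally gated input.
h_cur = A @ h_prev + (1.0 - A.sum(dim=1)) * x_cur
print(h_cur.shape)                     # torch.Size([5])
```

The model itself runs this recurrence over batches and channels, in all four directions, with the weights predicted per pixel by the guidance network; the helper below builds the batched tridiagonal matrices from those predicted weights.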
21 | 22 | ```python 23 | def to_tridiagonal_multidim(self, w): 24 | # this function converts the weight vectors to a tridiagonal matrix 25 | 26 | N,W,C,D = w.size() 27 | 28 | # normalize the weights to stabilize the model 29 | tmp_w = w / torch.sum(torch.abs(w),dim=3).unsqueeze(-1) 30 | tmp_w = tmp_w.unsqueeze(2).expand([N,W,W,C,D]) 31 | 32 | # three identity matrices, one normal, one shifted left and the other shifted right 33 | eye_a = Variable(torch.diag(torch.ones(W-1).cuda(),diagonal=-1)) 34 | eye_b = Variable(torch.diag(torch.ones(W).cuda(),diagonal=0)) 35 | eye_c = Variable(torch.diag(torch.ones(W-1).cuda(),diagonal=1)) 36 | 37 | tmp_eye_a = eye_a.unsqueeze(-1).unsqueeze(0).expand([N,W,W,C]) 38 | a = tmp_w[:,:,:,:,0] * tmp_eye_a 39 | tmp_eye_b = eye_b.unsqueeze(-1).unsqueeze(0).expand([N,W,W,C]) 40 | b = tmp_w[:,:,:,:,1] * tmp_eye_b 41 | tmp_eye_c = eye_c.unsqueeze(-1).unsqueeze(0).expand([N,W,W,C]) 42 | c = tmp_w[:,:,:,:,2] * tmp_eye_c 43 | 44 | return a+b+c 45 | ``` 46 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import random 4 | from torch.utils.data import Dataset 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | from torchvision.datasets import ImageFolder 8 | from PIL import Image 9 | import h5py 10 | import numpy as np 11 | import collections 12 | import numbers 13 | import math 14 | 15 | class RandomCropGenerator(object): 16 | def __call__(self, img): 17 | self.x1 = random.uniform(0, 1) 18 | self.y1 = random.uniform(0, 1) 19 | return img 20 | 21 | class RandomCrop(object): 22 | def __init__(self, size, padding=0, gen=None): 23 | if isinstance(size, numbers.Number): 24 | self.size = (int(size), int(size)) 25 | else: 26 | self.size = size 27 | self.padding = padding 28 | self._gen = gen 29 | 30 | def __call__(self, img): 31 | if self.padding > 0: 32 | img = ImageOps.expand(img, border=self.padding, fill=0) 33 | w, h = img.size 34 | th, tw = self.size 35 | if w == tw and h == th: 36 | return img 37 | 38 | if self._gen is not None: 39 | x1 = math.floor(self._gen.x1 * (w - tw)) 40 | y1 = math.floor(self._gen.y1 * (h - th)) 41 | else: 42 | x1 = random.randint(0, w - tw) 43 | y1 = random.randint(0, h - th) 44 | 45 | return img.crop((x1, y1, x1 + tw, y1 + th)) 46 | 47 | class PascalVOC2012(Dataset): 48 | 49 | class_names = np.array([ 50 | 'background', 51 | 'aeroplane', 52 | 'bicycle', 53 | 'bird', 54 | 'boat', 55 | 'bottle', 56 | 'bus', 57 | 'car', 58 | 'cat', 59 | 'chair', 60 | 'cow', 61 | 'diningtable', 62 | 'dog', 63 | 'horse', 64 | 'motorbike', 65 | 'person', 66 | 'potted plant', 67 | 'sheep', 68 | 'sofa', 69 | 'train', 70 | 'tv/monitor', 71 | ]) 72 | mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 73 | ''' 74 | color map 75 | 0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle # 6=bus, 7=car, 8=cat, 9=chair, 10=cow, 11=diningtable, 76 | 12=dog, 13=horse, 14=motorbike, 15=person # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor 77 | ''' 78 | palette = [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 128, 79 | 128, 128, 128, 64, 0, 0, 192, 0, 0, 64, 128, 0, 192, 128, 0, 64, 0, 128, 192, 0, 128, 80 | 64, 128, 128, 192, 128, 128, 0, 64, 0, 128, 64, 0, 0, 192, 0, 128, 192, 0, 0, 64, 128] 81 | 82 | 83 | def __init__(self, root, transform, crop_size, image_size, mode='train'): 84 | self.root = root 85 | 86 | if 
mode == 'train': 87 | self.split = 'train' 88 | else: 89 | self.split = 'val' 90 | 91 | self.crop_size = crop_size 92 | self.image_size=image_size 93 | 94 | self._transform = transform 95 | zero_pad = 256 * 3 - len(self.palette) 96 | for i in range(zero_pad): 97 | self.palette.append(0) 98 | # VOC2011 and others are subset of VOC2012 99 | dataset_dir = os.path.join(self.root, 'VOC/VOCdevkit/VOC2012') 100 | self.files = collections.defaultdict(list) 101 | for split in ['train', 'val']: 102 | imgsets_file = os.path.join( 103 | dataset_dir, 'ImageSets/Segmentation/%s.txt' % split) 104 | for did in open(imgsets_file): 105 | did = did.strip() 106 | img_file = os.path.join(dataset_dir, 'JPEGImages/%s.jpg' % did) 107 | lbl_file = os.path.join( 108 | dataset_dir, 'SegmentationClass/%s.png' % did) 109 | self.files[split].append({ 110 | 'img': img_file, 111 | 'lbl': lbl_file, 112 | }) 113 | def __len__(self): 114 | return len(self.files[self.split]) 115 | 116 | def colorize_mask(self,mask): 117 | # mask: numpy array of the mask 118 | 119 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 120 | new_mask.putpalette(self.palette) 121 | 122 | return new_mask.convert('RGB') 123 | def colorize_mask_batch(self,masks): 124 | color_masks = np.zeros((masks.shape[0],3,masks.shape[1],masks.shape[2])) 125 | toTensor = transforms.ToTensor() 126 | for i in range(masks.shape[0]): 127 | color_masks[i] = np.array(self.colorize_mask(masks[i])).transpose(2,0,1) 128 | 129 | return torch.from_numpy(color_masks).float() 130 | 131 | def __getitem__(self, index): 132 | data_file = self.files[self.split][index] 133 | # load image 134 | img_file = data_file['img'] 135 | img_pil = Image.open(img_file) 136 | gen = RandomCropGenerator() 137 | onlyBgPatch = True 138 | while onlyBgPatch: 139 | 140 | transform_img = transforms.Compose([ 141 | gen, 142 | RandomCrop(self.crop_size, gen=gen), 143 | transforms.Resize([self.image_size, self.image_size])]) 144 | 145 | img = np.array(transform_img(img_pil),dtype=np.uint8) 146 | 147 | transform_mask = transforms.Compose([ 148 | RandomCrop(self.crop_size, gen=gen), 149 | transforms.Resize([self.image_size, self.image_size],interpolation=Image.NEAREST)]) 150 | 151 | # load label 152 | lbl_file = data_file['lbl'] 153 | lbl_pil = Image.open(lbl_file) 154 | 155 | lbl_cropped = transform_mask(lbl_pil) 156 | lbl = np.array(transform_mask(lbl_pil), dtype=np.int32) 157 | lbl[lbl == 255] = -1 158 | unique_vals = np.unique(lbl) 159 | 160 | if len(unique_vals) >= 2: 161 | onlyBgPatch = False 162 | # for i in unique_vals: 163 | # percentage_covered = np.sum(lbl==i) / (self.image_size*self.image_size) 164 | 165 | # if percentage_covered >= 0.98: 166 | # onlyBgPatch = True 167 | # break 168 | 169 | if self._transform: 170 | return self.transform(img, lbl) 171 | else: 172 | return img, lbl 173 | 174 | def transform(self, img, lbl): 175 | img = img[:, :, ::-1] # RGB -> BGR 176 | img = img.astype(np.float64) 177 | img -= self.mean_bgr 178 | img = img.transpose(2, 0, 1) 179 | img = torch.from_numpy(img).float() 180 | lbl = torch.from_numpy(lbl).long() 181 | return img, lbl 182 | 183 | def untransform(self, img): 184 | img = img.transpose(1, 2, 0) 185 | img += self.mean_bgr 186 | img = img.astype(np.uint8) 187 | img = img[:, :, ::-1] / 255 188 | img = img.transpose(2, 0, 1) 189 | return img 190 | def untransform_batch(self, img): 191 | img = img.numpy() 192 | for i in range(img.shape[0]): 193 | img[i] = self.untransform(img[i]) 194 | 195 | return img 196 | 197 | 198 | def get_loader(image_path, 
crop_size, image_size, batch_size, transform=False, dataset='PascalVOC2012', mode='train'): 199 | """Build and return data loader.""" 200 | 201 | if dataset == 'PascalVOC2012': 202 | dataset = PascalVOC2012(image_path, transform, crop_size, image_size, mode) 203 | 204 | shuffle = False 205 | if mode == 'train': 206 | shuffle = True 207 | 208 | data_loader = DataLoader(dataset=dataset, 209 | batch_size=batch_size, 210 | shuffle=shuffle) 211 | return data_loader 212 | -------------------------------------------------------------------------------- /fig1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieltan07/spatialaffinitynetwork/b1b2e6fac23eec7bfe910768e4979abb9e46bebc/fig1.PNG -------------------------------------------------------------------------------- /fig2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieltan07/spatialaffinitynetwork/b1b2e6fac23eec7bfe910768e4979abb9e46bebc/fig2.PNG -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from solver import Solver 4 | from data_loader import get_loader 5 | from torch.backends import cudnn 6 | from utils import * 7 | 8 | 9 | def str2bool(v): 10 | return v.lower() in ('true',) 11 | 12 | def main(config): 13 | # For fast training 14 | cudnn.benchmark = True 15 | 16 | # Create directories if they do not exist 17 | mkdir(config.log_path) 18 | mkdir(config.model_save_path) 19 | mkdir(config.sample_path) 20 | mkdir(config.result_path) 21 | 22 | data_loader = get_loader(config.data_path, config.crop_size, 23 | config.image_size, config.batch_size, transform=True, dataset='PascalVOC2012', mode=config.mode) 24 | 25 | # Solver 26 | solver = Solver(data_loader, vars(config)) 27 | 28 | if config.mode == 'train': 29 | solver.train() 30 | elif config.mode == 'test': 31 | solver.test() 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | 37 | # Model hyper-parameters 38 | parser.add_argument('--image_size', type=int, default=128) 39 | parser.add_argument('--crop_size', type=int, default=128) 40 | parser.add_argument('--num_classes', type=int, default=21) 41 | parser.add_argument('--lr', type=float, default=0.0001) 42 | 43 | # Training settings 44 | 45 | parser.add_argument('--num_epochs', type=int, default=100) 46 | parser.add_argument('--num_epochs_decay', type=int, default=60) 47 | parser.add_argument('--batch_size', type=int, default=20) 48 | parser.add_argument('--pretrained_model', type=str, default=None) 49 | 50 | # Misc 51 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) 52 | parser.add_argument('--use_tensorboard', type=str2bool, default=False) 53 | 54 | # Path 55 | parser.add_argument('--data_path', type=str, default='E:/Jonathan') 56 | parser.add_argument('--fcn_model_path', type=str, default='E:/Jonathan/SpatialAffinity/FCN-8/model_weights/fcn8s_from_caffe.pth') 57 | parser.add_argument('--vgg_model_path', type=str, default='E:/Jonathan/SpatialAffinity/FCN-8/model_weights/vgg16_from_caffe.pth') 58 | parser.add_argument('--log_path', type=str, default='./spatial_affinity/logs') 59 | parser.add_argument('--model_save_path', type=str, default='./spatial_affinity/models') 60 | parser.add_argument('--sample_path', type=str, default='./spatial_affinity/samples') 61 |
parser.add_argument('--result_path', type=str, default='./spatial_affinity/results') 62 | 63 | # Step size 64 | parser.add_argument('--log_step', type=int, default=10) 65 | parser.add_argument('--sample_step', type=int, default=70) 66 | parser.add_argument('--model_save_step', type=int, default=1000) 67 | config = parser.parse_args() 68 | 69 | args = vars(config) 70 | print('------------ Options -------------') 71 | for k, v in sorted(args.items()): 72 | print('%s: %s' % (str(k), str(v))) 73 | print('-------------- End ----------------') 74 | 75 | main(config) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchvision 6 | from torch.autograd import Variable 7 | 8 | # https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py 9 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 10 | """Make a 2D bilinear kernel suitable for upsampling""" 11 | factor = (kernel_size + 1) // 2 12 | if kernel_size % 2 == 1: 13 | center = factor - 1 14 | else: 15 | center = factor - 0.5 16 | og = np.ogrid[:kernel_size, :kernel_size] 17 | filt = (1 - np.abs(og[0] - center) / factor) * \ 18 | (1 - np.abs(og[1] - center) / factor) 19 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 20 | dtype=np.float64) 21 | weight[range(in_channels), range(out_channels), :, :] = filt 22 | return torch.from_numpy(weight).float() 23 | 24 | 25 | class FCN32s(nn.Module): 26 | def __init__(self, n_class=21): 27 | super(FCN32s, self).__init__() 28 | # conv1 29 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) 30 | self.relu1_1 = nn.ReLU(inplace=True) 31 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 32 | self.relu1_2 = nn.ReLU(inplace=True) 33 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 34 | 35 | # conv2 36 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 37 | self.relu2_1 = nn.ReLU(inplace=True) 38 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 39 | self.relu2_2 = nn.ReLU(inplace=True) 40 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 41 | 42 | # conv3 43 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 44 | self.relu3_1 = nn.ReLU(inplace=True) 45 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 46 | self.relu3_2 = nn.ReLU(inplace=True) 47 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 48 | self.relu3_3 = nn.ReLU(inplace=True) 49 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 50 | 51 | # conv4 52 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 53 | self.relu4_1 = nn.ReLU(inplace=True) 54 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 55 | self.relu4_2 = nn.ReLU(inplace=True) 56 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 57 | self.relu4_3 = nn.ReLU(inplace=True) 58 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 59 | 60 | # conv5 61 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 62 | self.relu5_1 = nn.ReLU(inplace=True) 63 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 64 | self.relu5_2 = nn.ReLU(inplace=True) 65 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 66 | self.relu5_3 = nn.ReLU(inplace=True) 67 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 68 | 69 | # fc6 70 | self.fc6 = nn.Conv2d(512, 4096, 7) 71 | self.relu6 = nn.ReLU(inplace=True) 72 | self.drop6 = nn.Dropout2d() 73 | 74 | # fc7 75 | self.fc7 = 
nn.Conv2d(4096, 4096, 1) 76 | self.relu7 = nn.ReLU(inplace=True) 77 | self.drop7 = nn.Dropout2d() 78 | 79 | self.score_fr = nn.Conv2d(4096, n_class, 1) 80 | self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32, 81 | bias=False) 82 | 83 | def forward(self, x): 84 | h = x 85 | h = self.relu1_1(self.conv1_1(h)) 86 | h = self.relu1_2(self.conv1_2(h)) 87 | h = self.pool1(h) 88 | 89 | h = self.relu2_1(self.conv2_1(h)) 90 | h = self.relu2_2(self.conv2_2(h)) 91 | h = self.pool2(h) 92 | 93 | h = self.relu3_1(self.conv3_1(h)) 94 | h = self.relu3_2(self.conv3_2(h)) 95 | h = self.relu3_3(self.conv3_3(h)) 96 | h = self.pool3(h) 97 | 98 | h = self.relu4_1(self.conv4_1(h)) 99 | h = self.relu4_2(self.conv4_2(h)) 100 | h = self.relu4_3(self.conv4_3(h)) 101 | h = self.pool4(h) 102 | 103 | h = self.relu5_1(self.conv5_1(h)) 104 | h = self.relu5_2(self.conv5_2(h)) 105 | h = self.relu5_3(self.conv5_3(h)) 106 | h = self.pool5(h) 107 | 108 | h = self.relu6(self.fc6(h)) 109 | h = self.drop6(h) 110 | 111 | h = self.relu7(self.fc7(h)) 112 | h = self.drop7(h) 113 | 114 | h = self.score_fr(h) 115 | 116 | h = self.upscore(h) 117 | h = h[:, :, 19:19 + x.size()[2], 19:19 + x.size()[3]].contiguous() 118 | 119 | return h 120 | 121 | def initialize_weights(self): 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | m.weight.data.normal_(0.0, 0.02) 125 | if m.bias is not None: 126 | m.bias.data.zero_() 127 | if isinstance(m, nn.ConvTranspose2d): 128 | assert m.kernel_size[0] == m.kernel_size[1] 129 | initial_weight = get_upsampling_weight( 130 | m.in_channels, m.out_channels, m.kernel_size[0]) 131 | m.weight.data.copy_(initial_weight) 132 | 133 | def copy_params_from_vgg16(self, vgg16): 134 | features = [ 135 | self.conv1_1, self.relu1_1, 136 | self.conv1_2, self.relu1_2, 137 | self.pool1, 138 | self.conv2_1, self.relu2_1, 139 | self.conv2_2, self.relu2_2, 140 | self.pool2, 141 | self.conv3_1, self.relu3_1, 142 | self.conv3_2, self.relu3_2, 143 | self.conv3_3, self.relu3_3, 144 | self.pool3, 145 | self.conv4_1, self.relu4_1, 146 | self.conv4_2, self.relu4_2, 147 | self.conv4_3, self.relu4_3, 148 | self.pool4, 149 | self.conv5_1, self.relu5_1, 150 | self.conv5_2, self.relu5_2, 151 | self.conv5_3, self.relu5_3, 152 | self.pool5, 153 | ] 154 | for l1, l2 in zip(vgg16.features, features): 155 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 156 | assert l1.weight.size() == l2.weight.size() 157 | assert l1.bias.size() == l2.bias.size() 158 | l2.weight.data = l1.weight.data 159 | l2.bias.data = l1.bias.data 160 | for i, name in zip([0, 3], ['fc6', 'fc7']): 161 | l1 = vgg16.classifier[i] 162 | l2 = getattr(self, name) 163 | l2.weight.data = l1.weight.data.view(l2.weight.size()) 164 | l2.bias.data = l1.bias.data.view(l2.bias.size()) 165 | 166 | 167 | 168 | class FCN16s(nn.Module): 169 | 170 | # pretrained_model = \ 171 | # osp.expanduser('~/data/models/pytorch/fcn16s_from_caffe.pth') 172 | 173 | # @classmethod 174 | # def download(cls): 175 | # return fcn.data.cached_download( 176 | # url='http://drive.google.com/uc?id=0B9P1L--7Wd2vVGE3TkRMbWlNRms', 177 | # path=cls.pretrained_model, 178 | # md5='991ea45d30d632a01e5ec48002cac617', 179 | # ) 180 | 181 | def __init__(self, n_class=21): 182 | super(FCN16s, self).__init__() 183 | # conv1 184 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) 185 | self.relu1_1 = nn.ReLU(inplace=True) 186 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 187 | self.relu1_2 = nn.ReLU(inplace=True) 188 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 
189 | 190 | # conv2 191 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 192 | self.relu2_1 = nn.ReLU(inplace=True) 193 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 194 | self.relu2_2 = nn.ReLU(inplace=True) 195 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 196 | 197 | # conv3 198 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 199 | self.relu3_1 = nn.ReLU(inplace=True) 200 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 201 | self.relu3_2 = nn.ReLU(inplace=True) 202 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 203 | self.relu3_3 = nn.ReLU(inplace=True) 204 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 205 | 206 | # conv4 207 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 208 | self.relu4_1 = nn.ReLU(inplace=True) 209 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 210 | self.relu4_2 = nn.ReLU(inplace=True) 211 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 212 | self.relu4_3 = nn.ReLU(inplace=True) 213 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 214 | 215 | # conv5 216 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 217 | self.relu5_1 = nn.ReLU(inplace=True) 218 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 219 | self.relu5_2 = nn.ReLU(inplace=True) 220 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 221 | self.relu5_3 = nn.ReLU(inplace=True) 222 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 223 | 224 | # fc6 225 | self.fc6 = nn.Conv2d(512, 4096, 7) 226 | self.relu6 = nn.ReLU(inplace=True) 227 | self.drop6 = nn.Dropout2d() 228 | 229 | # fc7 230 | self.fc7 = nn.Conv2d(4096, 4096, 1) 231 | self.relu7 = nn.ReLU(inplace=True) 232 | self.drop7 = nn.Dropout2d() 233 | 234 | self.score_fr = nn.Conv2d(4096, n_class, 1) 235 | self.score_pool4 = nn.Conv2d(512, n_class, 1) 236 | 237 | self.upscore2 = nn.ConvTranspose2d( 238 | n_class, n_class, 4, stride=2, bias=False) 239 | self.upscore16 = nn.ConvTranspose2d( 240 | n_class, n_class, 32, stride=16, bias=False) 241 | 242 | self._initialize_weights() 243 | 244 | def _initialize_weights(self): 245 | for m in self.modules(): 246 | if isinstance(m, nn.Conv2d): 247 | m.weight.data.zero_() 248 | if m.bias is not None: 249 | m.bias.data.zero_() 250 | if isinstance(m, nn.ConvTranspose2d): 251 | assert m.kernel_size[0] == m.kernel_size[1] 252 | initial_weight = get_upsampling_weight( 253 | m.in_channels, m.out_channels, m.kernel_size[0]) 254 | m.weight.data.copy_(initial_weight) 255 | 256 | def forward(self, x): 257 | h = x 258 | h = self.relu1_1(self.conv1_1(h)) 259 | h = self.relu1_2(self.conv1_2(h)) 260 | h = self.pool1(h) 261 | 262 | h = self.relu2_1(self.conv2_1(h)) 263 | h = self.relu2_2(self.conv2_2(h)) 264 | h = self.pool2(h) 265 | 266 | h = self.relu3_1(self.conv3_1(h)) 267 | h = self.relu3_2(self.conv3_2(h)) 268 | h = self.relu3_3(self.conv3_3(h)) 269 | h = self.pool3(h) 270 | 271 | h = self.relu4_1(self.conv4_1(h)) 272 | h = self.relu4_2(self.conv4_2(h)) 273 | h = self.relu4_3(self.conv4_3(h)) 274 | h = self.pool4(h) 275 | pool4 = h # 1/16 276 | 277 | h = self.relu5_1(self.conv5_1(h)) 278 | h = self.relu5_2(self.conv5_2(h)) 279 | h = self.relu5_3(self.conv5_3(h)) 280 | h = self.pool5(h) 281 | 282 | h = self.relu6(self.fc6(h)) 283 | h = self.drop6(h) 284 | 285 | h = self.relu7(self.fc7(h)) 286 | h = self.drop7(h) 287 | 288 | h = self.score_fr(h) 289 | h = self.upscore2(h) 290 | upscore2 = h # 1/16 291 | 292 | h = self.score_pool4(pool4) 293 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 294 | score_pool4c = h # 1/16 295 
| 296 | h = upscore2 + score_pool4c 297 | 298 | h = self.upscore16(h) 299 | h = h[:, :, 27:27 + x.size()[2], 27:27 + x.size()[3]].contiguous() 300 | 301 | return h 302 | 303 | def copy_params_from_fcn32s(self, fcn32s): 304 | for name, l1 in fcn32s.named_children(): 305 | try: 306 | l2 = getattr(self, name) 307 | l2.weight # skip ReLU / Dropout 308 | except Exception: 309 | continue 310 | assert l1.weight.size() == l2.weight.size() 311 | assert l1.bias.size() == l2.bias.size() 312 | l2.weight.data.copy_(l1.weight.data) 313 | l2.bias.data.copy_(l1.bias.data) 314 | 315 | 316 | class FCN8s(nn.Module): 317 | 318 | # pretrained_model = \ 319 | # osp.expanduser('~/data/models/pytorch/fcn8s_from_caffe.pth') 320 | 321 | # @classmethod 322 | # def download(cls): 323 | # return fcn.data.cached_download( 324 | # url='http://drive.google.com/uc?id=0B9P1L--7Wd2vT0FtdThWREhjNkU', 325 | # path=cls.pretrained_model, 326 | # md5='dbd9bbb3829a3184913bccc74373afbb', 327 | # ) 328 | 329 | def __init__(self, n_class=21): 330 | super(FCN8s, self).__init__() 331 | # conv1 332 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) 333 | self.relu1_1 = nn.ReLU(inplace=True) 334 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 335 | self.relu1_2 = nn.ReLU(inplace=True) 336 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 337 | 338 | # conv2 339 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 340 | self.relu2_1 = nn.ReLU(inplace=True) 341 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 342 | self.relu2_2 = nn.ReLU(inplace=True) 343 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 344 | 345 | # conv3 346 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 347 | self.relu3_1 = nn.ReLU(inplace=True) 348 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 349 | self.relu3_2 = nn.ReLU(inplace=True) 350 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 351 | self.relu3_3 = nn.ReLU(inplace=True) 352 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 353 | 354 | # conv4 355 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 356 | self.relu4_1 = nn.ReLU(inplace=True) 357 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 358 | self.relu4_2 = nn.ReLU(inplace=True) 359 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 360 | self.relu4_3 = nn.ReLU(inplace=True) 361 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 362 | 363 | # conv5 364 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 365 | self.relu5_1 = nn.ReLU(inplace=True) 366 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 367 | self.relu5_2 = nn.ReLU(inplace=True) 368 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 369 | self.relu5_3 = nn.ReLU(inplace=True) 370 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 371 | 372 | # fc6 373 | self.fc6 = nn.Conv2d(512, 4096, 7) 374 | self.relu6 = nn.ReLU(inplace=True) 375 | self.drop6 = nn.Dropout2d() 376 | 377 | # fc7 378 | self.fc7 = nn.Conv2d(4096, 4096, 1) 379 | self.relu7 = nn.ReLU(inplace=True) 380 | self.drop7 = nn.Dropout2d() 381 | 382 | self.score_fr = nn.Conv2d(4096, n_class, 1) 383 | self.score_pool3 = nn.Conv2d(256, n_class, 1) 384 | self.score_pool4 = nn.Conv2d(512, n_class, 1) 385 | 386 | self.upscore2 = nn.ConvTranspose2d( 387 | n_class, n_class, 4, stride=2, bias=False) 388 | self.upscore8 = nn.ConvTranspose2d( 389 | n_class, n_class, 16, stride=8, bias=False) 390 | self.upscore_pool4 = nn.ConvTranspose2d( 391 | n_class, n_class, 4, stride=2, bias=False) 392 | 393 | self._initialize_weights() 394 | 395 | def _initialize_weights(self): 396 | 
for m in self.modules(): 397 | if isinstance(m, nn.Conv2d): 398 | m.weight.data.zero_() 399 | if m.bias is not None: 400 | m.bias.data.zero_() 401 | if isinstance(m, nn.ConvTranspose2d): 402 | assert m.kernel_size[0] == m.kernel_size[1] 403 | initial_weight = get_upsampling_weight( 404 | m.in_channels, m.out_channels, m.kernel_size[0]) 405 | m.weight.data.copy_(initial_weight) 406 | 407 | def forward(self, x): 408 | h = x 409 | h = self.relu1_1(self.conv1_1(h)) 410 | h = self.relu1_2(self.conv1_2(h)) 411 | h = self.pool1(h) 412 | 413 | h = self.relu2_1(self.conv2_1(h)) 414 | h = self.relu2_2(self.conv2_2(h)) 415 | h = self.pool2(h) 416 | 417 | h = self.relu3_1(self.conv3_1(h)) 418 | h = self.relu3_2(self.conv3_2(h)) 419 | h = self.relu3_3(self.conv3_3(h)) 420 | h = self.pool3(h) 421 | pool3 = h # 1/8 422 | 423 | h = self.relu4_1(self.conv4_1(h)) 424 | h = self.relu4_2(self.conv4_2(h)) 425 | h = self.relu4_3(self.conv4_3(h)) 426 | h = self.pool4(h) 427 | pool4 = h # 1/16 428 | 429 | h = self.relu5_1(self.conv5_1(h)) 430 | h = self.relu5_2(self.conv5_2(h)) 431 | h = self.relu5_3(self.conv5_3(h)) 432 | h = self.pool5(h) 433 | 434 | h = self.relu6(self.fc6(h)) 435 | h = self.drop6(h) 436 | 437 | h = self.relu7(self.fc7(h)) 438 | h = self.drop7(h) 439 | 440 | h = self.score_fr(h) 441 | h = self.upscore2(h) 442 | upscore2 = h # 1/16 443 | 444 | h = self.score_pool4(pool4) 445 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 446 | score_pool4c = h # 1/16 447 | 448 | h = upscore2 + score_pool4c # 1/16 449 | h = self.upscore_pool4(h) 450 | upscore_pool4 = h # 1/8 451 | 452 | h = self.score_pool3(pool3) 453 | h = h[:, :, 454 | 9:9 + upscore_pool4.size()[2], 455 | 9:9 + upscore_pool4.size()[3]] 456 | score_pool3c = h # 1/8 457 | 458 | h = upscore_pool4 + score_pool3c # 1/8 459 | 460 | h = self.upscore8(h) 461 | h = h[:, :, 31:31 + x.size()[2], 31:31 + x.size()[3]].contiguous() 462 | 463 | return h 464 | 465 | def copy_params_from_fcn16s(self, fcn16s): 466 | for name, l1 in fcn16s.named_children(): 467 | try: 468 | l2 = getattr(self, name) 469 | l2.weight # skip ReLU / Dropout 470 | except Exception: 471 | continue 472 | assert l1.weight.size() == l2.weight.size() 473 | l2.weight.data.copy_(l1.weight.data) 474 | if l1.bias is not None: 475 | assert l1.bias.size() == l2.bias.size() 476 | l2.bias.data.copy_(l1.bias.data) 477 | 478 | 479 | class FCN8sAtOnce(FCN8s): 480 | 481 | # pretrained_model = \ 482 | # osp.expanduser('~/data/models/pytorch/fcn8s-atonce_from_caffe.pth') 483 | 484 | # @classmethod 485 | # def download(cls): 486 | # return fcn.data.cached_download( 487 | # url='http://drive.google.com/uc?id=0B9P1L--7Wd2vblE1VUIxV1o2d2M', 488 | # path=cls.pretrained_model, 489 | # md5='bfed4437e941fef58932891217fe6464', 490 | # ) 491 | 492 | def forward(self, x): 493 | h = x 494 | h = self.relu1_1(self.conv1_1(h)) 495 | h = self.relu1_2(self.conv1_2(h)) 496 | h = self.pool1(h) 497 | 498 | h = self.relu2_1(self.conv2_1(h)) 499 | h = self.relu2_2(self.conv2_2(h)) 500 | h = self.pool2(h) 501 | 502 | h = self.relu3_1(self.conv3_1(h)) 503 | h = self.relu3_2(self.conv3_2(h)) 504 | h = self.relu3_3(self.conv3_3(h)) 505 | h = self.pool3(h) 506 | pool3 = h # 1/8 507 | 508 | h = self.relu4_1(self.conv4_1(h)) 509 | h = self.relu4_2(self.conv4_2(h)) 510 | h = self.relu4_3(self.conv4_3(h)) 511 | h = self.pool4(h) 512 | pool4 = h # 1/16 513 | 514 | h = self.relu5_1(self.conv5_1(h)) 515 | h = self.relu5_2(self.conv5_2(h)) 516 | h = self.relu5_3(self.conv5_3(h)) 517 | h = self.pool5(h) 518 | 519 | h = 
self.relu6(self.fc6(h)) 520 | h = self.drop6(h) 521 | 522 | h = self.relu7(self.fc7(h)) 523 | h = self.drop7(h) 524 | 525 | h = self.score_fr(h) 526 | h = self.upscore2(h) 527 | upscore2 = h # 1/16 528 | 529 | h = self.score_pool4(pool4 * 0.01) # XXX: scaling to train at once 530 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 531 | score_pool4c = h # 1/16 532 | 533 | h = upscore2 + score_pool4c # 1/16 534 | h = self.upscore_pool4(h) 535 | upscore_pool4 = h # 1/8 536 | 537 | h = self.score_pool3(pool3 * 0.0001) # XXX: scaling to train at once 538 | h = h[:, :, 539 | 9:9 + upscore_pool4.size()[2], 540 | 9:9 + upscore_pool4.size()[3]] 541 | score_pool3c = h # 1/8 542 | 543 | h = upscore_pool4 + score_pool3c # 1/8 544 | 545 | h = self.upscore8(h) 546 | h = h[:, :, 31:31 + x.size()[2], 31:31 + x.size()[3]].contiguous() 547 | 548 | return h 549 | 550 | def copy_params_from_vgg16(self, vgg16): 551 | features = [ 552 | self.conv1_1, self.relu1_1, 553 | self.conv1_2, self.relu1_2, 554 | self.pool1, 555 | self.conv2_1, self.relu2_1, 556 | self.conv2_2, self.relu2_2, 557 | self.pool2, 558 | self.conv3_1, self.relu3_1, 559 | self.conv3_2, self.relu3_2, 560 | self.conv3_3, self.relu3_3, 561 | self.pool3, 562 | self.conv4_1, self.relu4_1, 563 | self.conv4_2, self.relu4_2, 564 | self.conv4_3, self.relu4_3, 565 | self.pool4, 566 | self.conv5_1, self.relu5_1, 567 | self.conv5_2, self.relu5_2, 568 | self.conv5_3, self.relu5_3, 569 | self.pool5, 570 | ] 571 | for l1, l2 in zip(vgg16.features, features): 572 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 573 | assert l1.weight.size() == l2.weight.size() 574 | assert l1.bias.size() == l2.bias.size() 575 | l2.weight.data.copy_(l1.weight.data) 576 | l2.bias.data.copy_(l1.bias.data) 577 | for i, name in zip([0, 3], ['fc6', 'fc7']): 578 | l1 = vgg16.classifier[i] 579 | l2 = getattr(self, name) 580 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size())) 581 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size())) 582 | 583 | class VGG16Modified(nn.Module): 584 | 585 | def __init__(self, n_classes=21): 586 | super(VGG16Modified, self).__init__() 587 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 588 | self.relu1_1 = nn.ReLU(inplace=True) 589 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 590 | self.relu1_2 = nn.ReLU(inplace=True) 591 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 128 592 | 593 | # conv2 594 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 595 | self.relu2_1 = nn.ReLU(inplace=True) 596 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 597 | self.relu2_2 = nn.ReLU(inplace=True) 598 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 64 599 | 600 | # conv3 601 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 602 | self.relu3_1 = nn.ReLU(inplace=True) 603 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 604 | self.relu3_2 = nn.ReLU(inplace=True) 605 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 606 | self.relu3_3 = nn.ReLU(inplace=True) 607 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 32 608 | 609 | # conv4 610 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 611 | self.relu4_1 = nn.ReLU(inplace=True) 612 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 613 | self.relu4_2 = nn.ReLU(inplace=True) 614 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 615 | self.relu4_3 = nn.ReLU(inplace=True) 616 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 16 617 | 618 | # conv5 619 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 620 | self.relu5_1 
= nn.ReLU(inplace=True) 621 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 622 | self.relu5_2 = nn.ReLU(inplace=True) 623 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 624 | self.relu5_3 = nn.ReLU(inplace=True) 625 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 8 626 | 627 | 628 | self.conv6s_re = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 629 | nn.ReLU(inplace=True), 630 | nn.Upsample(scale_factor=2, mode='bilinear')) 631 | self.conv6_3 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 632 | nn.ReLU(inplace=True)) 633 | self.conv6_2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 634 | nn.ReLU(inplace=True)) 635 | self.conv6_1 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 636 | nn.ReLU(inplace=True), 637 | nn.Upsample(scale_factor=2, mode='bilinear')) 638 | 639 | self.conv7_3 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 640 | nn.ReLU(inplace=True)) 641 | self.conv7_2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 642 | nn.ReLU(inplace=True)) 643 | self.conv7_1 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=True), 644 | nn.ReLU(inplace=True), 645 | nn.Upsample(scale_factor=2, mode='bilinear')) 646 | 647 | self.conv8_3 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 648 | nn.ReLU(inplace=True)) 649 | self.conv8_2 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 650 | nn.ReLU(inplace=True)) 651 | self.conv8_1 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=True), 652 | nn.ReLU(inplace=True), 653 | nn.Upsample(scale_factor=2, mode='bilinear')) 654 | 655 | self.conv9 = nn.Sequential(nn.Conv2d(128, 32*3*4, kernel_size=3, stride=1, padding=1, bias=True), 656 | nn.Tanh()) 657 | 658 | self.conv10 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=True), 659 | nn.ReLU(inplace=True)) 660 | 661 | self.conv11 = nn.Sequential(nn.Conv2d(64, n_classes, kernel_size=3, stride=1, padding=1, bias=True), 662 | nn.Upsample(scale_factor=2, mode='bilinear')) 663 | 664 | 665 | 666 | self.coarse_conv_in = nn.Sequential(nn.Conv2d(n_classes, 32, kernel_size=3, stride=1, padding=1, bias=True), 667 | nn.ReLU(inplace=True), 668 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=True), 669 | nn.ReLU(inplace=True), 670 | nn.AvgPool2d(kernel_size=2, stride=2)) 671 | 672 | 673 | def to_tridiagonal_multidim(self, w): 674 | N,W,C,D = w.size() 675 | tmp_w = w / torch.sum(torch.abs(w),dim=3).unsqueeze(-1) 676 | tmp_w = tmp_w.unsqueeze(2).expand([N,W,W,C,D]) 677 | 678 | eye_a = Variable(torch.diag(torch.ones(W-1).cuda(),diagonal=-1)) 679 | eye_b = Variable(torch.diag(torch.ones(W).cuda(),diagonal=0)) 680 | eye_c = Variable(torch.diag(torch.ones(W-1).cuda(),diagonal=1)) 681 | 682 | 683 | tmp_eye_a = eye_a.unsqueeze(-1).unsqueeze(0).expand([N,W,W,C]) 684 | a = tmp_w[:,:,:,:,0] * tmp_eye_a 685 | tmp_eye_b = eye_b.unsqueeze(-1).unsqueeze(0).expand([N,W,W,C]) 686 | b = tmp_w[:,:,:,:,1] * tmp_eye_b 687 | tmp_eye_c = eye_c.unsqueeze(-1).unsqueeze(0).expand([N,W,W,C]) 688 | c = tmp_w[:,:,:,:,2] * tmp_eye_c 689 | 690 | return a+b+c 691 | def forward(self, x, coarse_segmentation): 692 | h = x 693 | h = self.relu1_1(self.conv1_1(h)) 694 | h = self.relu1_2(self.conv1_2(h)) 695 | h = self.pool1(h) 696 | 697 | h = 
self.relu2_1(self.conv2_1(h)) 698 | h = self.relu2_2(self.conv2_2(h)) 699 | h = self.pool2(h) 700 | 701 | conv3_1 = self.relu3_1(self.conv3_1(h)) 702 | conv3_2 = self.relu3_2(self.conv3_2(conv3_1)) 703 | conv3_3= self.relu3_3(self.conv3_3(conv3_2)) 704 | h = self.pool3(conv3_3) 705 | pool3 = h # 1/8 706 | 707 | conv4_1 = self.relu4_1(self.conv4_1(h)) 708 | conv4_2 = self.relu4_2(self.conv4_2(conv4_1)) 709 | conv4_3 = self.relu4_3(self.conv4_3(conv4_2)) 710 | h = self.pool4(conv4_3) 711 | pool4 = h # 1/16 712 | 713 | conv5_1 = self.relu5_1(self.conv5_1(h)) 714 | conv5_2 = self.relu5_2(self.conv5_2(conv5_1)) 715 | conv5_3 = self.relu5_3(self.conv5_3(conv5_2)) 716 | h = self.pool5(conv5_3) 717 | 718 | 719 | 720 | conv6_re = self.conv6s_re(h) 721 | 722 | 723 | 724 | skip_1 = conv5_3 + conv6_re 725 | conv6_3 = self.conv6_3(skip_1) 726 | skip_2 = conv5_2 + conv6_3 727 | conv6_2 = self.conv6_2(skip_2) 728 | skip_3 = conv5_1 + conv6_2 729 | conv6_1 = self.conv6_1(skip_3) 730 | 731 | skip_4 = conv4_3 + conv6_1 732 | conv7_3 = self.conv7_3(skip_4) 733 | skip_5 = conv4_2 + conv7_3 734 | conv7_2 = self.conv7_2(skip_5) 735 | skip_6 = conv4_1 + conv7_2 736 | conv7_1 = self.conv7_1(skip_6) 737 | 738 | skip_7 = conv3_3 + conv7_1 739 | conv8_3 = self.conv8_3(skip_7) 740 | skip_8 = conv3_2 + conv8_3 741 | conv8_2 = self.conv8_2(skip_8) 742 | skip_9 = conv3_1 + conv8_2 743 | conv8_1 = self.conv8_1(skip_9) 744 | 745 | conv9 = self.conv9(conv8_1) 746 | 747 | N,C,H,W = conv9.size() 748 | four_directions = C // 4 749 | conv9_reshaped_W = conv9.permute(0,2,3,1) 750 | # conv9_reshaped_H = conv9.permute(0,3,2,1) 751 | 752 | conv_x1_flat = conv9_reshaped_W[:,:,:,0:four_directions].contiguous() 753 | conv_y1_flat = conv9_reshaped_W[:,:,:,four_directions:2*four_directions].contiguous() 754 | conv_x2_flat = conv9_reshaped_W[:,:,:,2*four_directions:3*four_directions].contiguous() 755 | conv_y2_flat = conv9_reshaped_W[:,:,:,3*four_directions:4*four_directions].contiguous() 756 | 757 | w_x1 = conv_x1_flat.view(N,H,W,four_directions//3,3) # N, H, W, 32, 3 758 | w_y1 = conv_y1_flat.view(N,H,W,four_directions//3,3) # N, H, W, 32, 3 759 | w_x2 = conv_x2_flat.view(N,H,W,four_directions//3,3) # N, H, W, 32, 3 760 | w_y2 = conv_y2_flat.view(N,H,W,four_directions//3,3) # N, H, W, 32, 3 761 | 762 | rnn_h1 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 763 | rnn_h2 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 764 | rnn_h3 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 765 | rnn_h4 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 766 | 767 | x_t = self.coarse_conv_in(coarse_segmentation).permute(0,2,3,1) 768 | 769 | 770 | # horizontal 771 | for i in range(W): 772 | #left to right 773 | tmp_w = w_x1[:,:,i,:,:] # N, H, 1, 32, 3 774 | tmp_w = self.to_tridiagonal_multidim(tmp_w) # N, H, W, 32 775 | # tmp_x = x_t[:,:,i,:].unsqueeze(1) 776 | # tmp_x = tmp_x.expand([batch, W, H, 32]) 777 | if i == 0 : 778 | w_h_prev = 0 779 | else: 780 | w_h_prev = torch.sum(tmp_w * rnn_h1[:,:,i-1,:].clone().unsqueeze(1).expand([N, W, H, 32]),dim=2) 781 | 782 | 783 | w_x_curr = (1 - torch.sum(tmp_w, dim=2)) * x_t[:,:,i,:] 784 | 785 | rnn_h1[:,:,i,:] = w_x_curr + w_h_prev 786 | 787 | 788 | #right to left 789 | # tmp_w = w_x1[:,:,i,:,:] # N, H, 1, 32, 3 790 | # tmp_w = to_tridiagonal_multidim(tmp_w) 791 | 792 | if i == 0 : 793 | w_h_prev = 0 794 | else: 795 | w_h_prev = torch.sum(tmp_w * rnn_h2[:,:,W - i,:].clone().unsqueeze(1).expand([N, W, H, 32]),dim=2) 796 | 797 | 798 | w_x_curr = (1 - 
torch.sum(tmp_w, dim=2)) * x_t[:,:,W - i-1,:] 799 | rnn_h2[:,:,W - i-1,:] = w_x_curr + w_h_prev 800 | 801 | w_y1_T = w_y1.transpose(1,2) 802 | x_t_T = x_t.transpose(1,2) 803 | 804 | for i in range(H): 805 | # up to down 806 | tmp_w = w_y1_T[:,:,i,:,:] # N, W, 1, 32, 3 807 | tmp_w = self.to_tridiagonal_multidim(tmp_w) # N, W, H, 32 808 | 809 | if i == 0 : 810 | w_h_prev = 0 811 | else: 812 | w_h_prev = torch.sum(tmp_w * rnn_h3[:,:,i-1,:].clone().unsqueeze(1).expand([N, H, W, 32]),dim=2) 813 | 814 | w_x_curr = (1 - torch.sum(tmp_w, dim=2)) * x_t_T[:,:,i,:] 815 | rnn_h3[:,:,i,:] = w_x_curr + w_h_prev 816 | 817 | # down to up 818 | if i == 0 : 819 | w_h_prev = 0 820 | else: 821 | w_h_prev = torch.sum(tmp_w * rnn_h4[:,:,H - i,:].clone().unsqueeze(1).expand([N, H, W, 32]),dim=2) 822 | 823 | w_x_curr = (1 - torch.sum(tmp_w, dim=2)) * x_t[:,:,H-i-1,:] 824 | rnn_h4[:,:,H-i-1,:] = w_x_curr + w_h_prev 825 | 826 | rnn_h3 = rnn_h3.transpose(1,2) 827 | rnn_h4 = rnn_h4.transpose(1,2) 828 | 829 | rnn_h5 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 830 | rnn_h6 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 831 | rnn_h7 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 832 | rnn_h8 = Variable(torch.zeros((N, H, W, four_directions//3)).cuda()) 833 | 834 | # horizontal 835 | for i in range(W): 836 | #left to right 837 | tmp_w = w_x2[:,:,i,:,:] # N, H, 1, 32, 3 838 | tmp_w = self.to_tridiagonal_multidim(tmp_w) # N, H, W, 32 839 | # tmp_x = x_t[:,:,i,:].unsqueeze(1) 840 | # tmp_x = tmp_x.expand([batch, W, H, 32]) 841 | if i == 0 : 842 | w_h_prev = 0 843 | else: 844 | w_h_prev = torch.sum(tmp_w * rnn_h5[:,:,i-1,:].clone().unsqueeze(1).expand([N, W, H, 32]),dim=2) 845 | w_x_curr = (1 - torch.sum(tmp_w, dim=2)) * rnn_h1[:,:,i,:] 846 | rnn_h5[:,:,i,:] = w_x_curr + w_h_prev 847 | 848 | 849 | #right to left 850 | # tmp_w = w_x1[:,:,i,:,:] # N, H, 1, 32, 3 851 | # tmp_w = to_tridiagonal_multidim(tmp_w) 852 | if i == 0 : 853 | w_h_prev = 0 854 | else: 855 | w_h_prev = torch.sum(tmp_w * rnn_h6[:,:,W-i,:].clone().unsqueeze(1).expand([N, W, H, 32]),dim=2) 856 | w_x_curr = (1 - torch.sum(tmp_w, dim=2)) * rnn_h2[:,:,W - i-1,:] 857 | rnn_h6[:,:,W - i-1,:] = w_x_curr + w_h_prev 858 | 859 | w_y2_T = w_y2.transpose(1,2) 860 | rnn_h3_T = rnn_h3.transpose(1,2) 861 | rnn_h4_T = rnn_h4.transpose(1,2) 862 | for i in range(H): 863 | # up to down 864 | tmp_w = w_y2_T[:,:,i,:,:] # N, W, 1, 32, 3 865 | tmp_w = self.to_tridiagonal_multidim(tmp_w) # N, W, H, 32 866 | if i == 0 : 867 | w_h_prev = 0 868 | else: 869 | w_h_prev = torch.sum(tmp_w * rnn_h7[:,:,i-1,:].clone().unsqueeze(1).expand([N, H, W, 32]),dim=2) 870 | w_x_curr = (1 - torch.sum(tmp_w, dim=2)) * rnn_h3_T[:,:,i,:] 871 | rnn_h7[:,:,i,:] = w_x_curr + w_h_prev 872 | 873 | # down to up 874 | if i == 0 : 875 | w_h_prev = 0 876 | else: 877 | w_h_prev = torch.sum(tmp_w * rnn_h8[:,:,H-i,:].clone().unsqueeze(1).expand([N, H, W, 32]),dim=2) 878 | w_x_curr = (1 - torch.sum(tmp_w, dim=2)) * rnn_h4_T[:,:,H-i-1,:] 879 | rnn_h8[:,:,H-i-1,:] = w_x_curr + w_h_prev 880 | 881 | rnn_h3 = rnn_h3.transpose(1,2) 882 | rnn_h4 = rnn_h4.transpose(1,2) 883 | 884 | concat6 = torch.cat([rnn_h5.unsqueeze(4),rnn_h6.unsqueeze(4),rnn_h7.unsqueeze(4),rnn_h8.unsqueeze(4)],dim=4) 885 | elt_max = torch.max(concat6, dim=4)[0] 886 | elt_max_reordered = elt_max.permute(0,3,1,2) 887 | conv10 = self.conv10(elt_max_reordered) 888 | conv11 = self.conv11(conv10) 889 | return conv11 890 | 891 | def copy_params_from_vgg16(self, vgg_model_file): 892 | features = [ 893 | 
self.conv1_1, self.relu1_1, 894 | self.conv1_2, self.relu1_2, 895 | self.pool1, 896 | self.conv2_1, self.relu2_1, 897 | self.conv2_2, self.relu2_2, 898 | self.pool2, 899 | self.conv3_1, self.relu3_1, 900 | self.conv3_2, self.relu3_2, 901 | self.conv3_3, self.relu3_3, 902 | self.pool3, 903 | self.conv4_1, self.relu4_1, 904 | self.conv4_2, self.relu4_2, 905 | self.conv4_3, self.relu4_3, 906 | self.pool4, 907 | self.conv5_1, self.relu5_1, 908 | self.conv5_2, self.relu5_2, 909 | self.conv5_3, self.relu5_3, 910 | self.pool5, 911 | ] 912 | 913 | 914 | vgg16 = torchvision.models.vgg16(pretrained=False) 915 | state_dict = torch.load(vgg_model_file) 916 | vgg16.load_state_dict(state_dict) 917 | 918 | 919 | for l1, l2 in zip(vgg16.features, features): 920 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 921 | assert l1.weight.size() == l2.weight.size() 922 | assert l1.bias.size() == l2.bias.size() 923 | l2.weight.data.copy_(l1.weight.data) 924 | l2.bias.data.copy_(l1.bias.data) 925 | 926 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import os 6 | import time 7 | import datetime 8 | from torch.autograd import grad 9 | from torch.autograd import Variable 10 | from torchvision.utils import save_image 11 | from torchvision import transforms 12 | from model import * 13 | from PIL import Image 14 | from data_loader import PascalVOC2012 15 | 16 | class CrossEntropyLoss2d(nn.Module): 17 | def __init__(self, weight=None, size_average=True, ignore_index=255): 18 | super(CrossEntropyLoss2d, self).__init__() 19 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index) 20 | 21 | def forward(self, inputs, targets): 22 | return self.nll_loss(F.log_softmax(inputs), targets) 23 | 24 | 25 | class Solver(object): 26 | DEFAULTS = {} 27 | def __init__(self, data_loader, config): 28 | self.__dict__.update(Solver.DEFAULTS, **config) 29 | self.data_loader = data_loader 30 | # Build tensorboard if use 31 | self.build_model() 32 | if self.use_tensorboard: 33 | self.build_tensorboard() 34 | 35 | # Start with trained model 36 | if self.pretrained_model: 37 | self.load_pretrained_model() 38 | 39 | def build_model(self): 40 | # Define a generator and a discriminator 41 | self.FCN8 = FCN8s(n_class=self.num_classes) 42 | self.guidance_module = VGG16Modified(n_classes=self.num_classes) 43 | # Optimizers 44 | self.optimizer = torch.optim.Adam(self.guidance_module.parameters(), self.lr) 45 | 46 | # Print networks 47 | self.print_network(self.guidance_module, 'Guidance Network') 48 | 49 | if torch.cuda.is_available(): 50 | self.FCN8.cuda() 51 | self.guidance_module.cuda() 52 | 53 | model_data = torch.load(self.fcn_model_path) 54 | self.FCN8.load_state_dict(model_data) 55 | self.FCN8.eval() 56 | 57 | self.guidance_module.copy_params_from_vgg16(self.vgg_model_path) 58 | 59 | 60 | def train(self): 61 | # The number of iterations per epoch 62 | print("start training") 63 | iters_per_epoch = len(self.data_loader) 64 | 65 | fixed_x = [] 66 | fixed_target = [] 67 | for i, (images, target) in enumerate(self.data_loader): 68 | fixed_x.append(images) 69 | fixed_target.append(target) 70 | if i == 1: 71 | break 72 | print("sample data") 73 | # Fixed inputs and target domain labels for debugging 74 | fixed_x = torch.cat(fixed_x, dim=0) 75 | fixed_x = self.to_var(fixed_x, volatile=True) 76 | 
fixed_target = torch.cat(fixed_target, dim=0) 77 | # lr cache for decaying 78 | lr = self.lr 79 | 80 | # Start with trained model if exists 81 | if self.pretrained_model: 82 | start = int(self.pretrained_model.split('_')[0]) 83 | else: 84 | start = 0 85 | 86 | # Start training 87 | start_time = time.time() 88 | 89 | criterion = CrossEntropyLoss2d(size_average=False, ignore_index=-1).cuda() 90 | for e in range(start, self.num_epochs): 91 | 92 | for i, (images, target) in enumerate(self.data_loader): 93 | N = images.size(0) 94 | 95 | # Convert tensor to variable 96 | images = self.to_var(images) 97 | target = self.to_var(target) 98 | 99 | coarse_map = self.FCN8(images) 100 | 101 | refined_map= self.guidance_module(images,coarse_map) 102 | 103 | # tmp = refined_map.data.cpu().numpy() 104 | 105 | # assert refined_map.size()[2:] == target.size()[1:] 106 | # assert refined_map.size()[1] == self.num_classes 107 | softmax_ce_loss = criterion(refined_map, target) / N 108 | 109 | self.reset_grad() 110 | softmax_ce_loss.backward() 111 | 112 | self.optimizer.step() 113 | 114 | 115 | # # Compute classification accuracy of the discriminator 116 | # if (i+1) % self.log_step == 0: 117 | # accuracies = self.compute_accuracy(real_feature, real_label, self.dataset) 118 | # log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()] 119 | # if self.dataset == 'CelebA': 120 | # print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='') 121 | # else: 122 | # print('Classification Acc (8 emotional expressions): ', end='') 123 | # print(log) 124 | 125 | 126 | # Logging 127 | loss = {} 128 | loss['loss'] = softmax_ce_loss.data[0] 129 | 130 | # Print out log info 131 | if (i+1) % self.log_step == 0: 132 | elapsed = time.time() - start_time 133 | elapsed = str(datetime.timedelta(seconds=elapsed)) 134 | 135 | log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format( 136 | elapsed, e+1, self.num_epochs, i+1, iters_per_epoch) 137 | 138 | for tag, value in loss.items(): 139 | log += ", {}: {:.4f}".format(tag, value) 140 | print(log) 141 | 142 | if self.use_tensorboard: 143 | for tag, value in loss.items(): 144 | self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1) 145 | 146 | # Translate fixed images for debugging 147 | if (i+1) % self.sample_step == 0: 148 | fake_image_list = [torch.from_numpy(self.data_loader.dataset.untransform_batch(fixed_x.data.cpu()))] 149 | coarse_map = self.FCN8(fixed_x) 150 | refined_map= self.guidance_module(fixed_x,coarse_map) 151 | 152 | lbl_pred = coarse_map.data.max(1)[1].cpu().numpy() 153 | lbl_pred_refined = refined_map.data.max(1)[1].cpu().numpy() 154 | lbl_pred = self.data_loader.dataset.colorize_mask_batch(lbl_pred) 155 | lbl_pred_refined = self.data_loader.dataset.colorize_mask_batch(lbl_pred_refined) 156 | lbl_true = self.data_loader.dataset.colorize_mask_batch(fixed_target.numpy()) 157 | # print(lbl_pred.size()) 158 | # print(lbl_pred_refined.size()) 159 | # print(lbl_true.size()) 160 | fake_image_list.append(lbl_pred) 161 | fake_image_list.append(lbl_pred_refined) 162 | fake_image_list.append(lbl_true) 163 | # fake_image_list.append(lbl_pred_refined.unsqueeze(1).expand(fixed_x.size()).float()) 164 | # fake_image_list.append(lbl_true) 165 | fake_images = torch.cat(fake_image_list, dim=3) 166 | save_image(fake_images, 167 | os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0) 168 | print('Translated images and saved into {}..!'.format(self.sample_path)) 169 | 170 | del coarse_map, refined_map, lbl_pred, 
lbl_pred_refined, fake_image_list 171 | 172 | # Save model checkpoints 173 | if (i+1) % self.model_save_step == 0: 174 | torch.save(self.guidance_module.state_dict(), 175 | os.path.join(self.model_save_path, '{}_{}_spatial.pth'.format(e+1, i+1))) 176 | 177 | # Decay learning rate 178 | if (e+1) > (self.num_epochs - self.num_epochs_decay): 179 | lr -= (self.lr / float(self.num_epochs_decay)) 180 | self.update_lr(lr) 181 | print ('Decay learning rate to lr: {}.'.format(lr)) 182 | def labels_to_rgb(self,labels): 183 | return 184 | def print_network(self, model, name): 185 | num_params = 0 186 | for p in model.parameters(): 187 | num_params += p.numel() 188 | print(name) 189 | print(model) 190 | print("The number of parameters: {}".format(num_params)) 191 | 192 | def load_pretrained_model(self): 193 | self.guidance_module.load_state_dict(torch.load(os.path.join( 194 | self.model_save_path, '{}_spatial.pth'.format(self.pretrained_model)))) 195 | 196 | print('loaded trained models (step: {})..!'.format(self.pretrained_model)) 197 | 198 | def build_tensorboard(self): 199 | from logger import Logger 200 | self.logger = Logger(self.log_path) 201 | 202 | def update_lr(self, lr): 203 | for param_group in self.optimizer.param_groups: 204 | param_group['lr'] = lr 205 | 206 | def reset_grad(self): 207 | self.optimizer.zero_grad() 208 | 209 | def to_var(self, x, volatile=False): 210 | if torch.cuda.is_available(): 211 | x = x.cuda() 212 | return Variable(x, volatile=volatile) 213 | 214 | def denorm(self, x): 215 | out = (x + 1) / 2 216 | return out.clamp_(0, 1) 217 | 218 | def threshold(self, x): 219 | x = x.clone() 220 | x[x >= 0.5] = 1 221 | x[x < 0.5] = 0 222 | return x 223 | 224 | def compute_accuracy(self, x, y, dataset): 225 | if dataset == 'CelebA': 226 | x = F.sigmoid(x) 227 | predicted = self.threshold(x) 228 | correct = (predicted == y).float() 229 | accuracy = torch.mean(correct, dim=0) * 100.0 230 | elif dataset == 'Flowers': 231 | x = F.sigmoid(x) 232 | predicted = self.threshold(x) 233 | correct = (predicted == y).float() 234 | accuracy = torch.mean(correct, dim=0) * 100.0 235 | 236 | else: 237 | _, predicted = torch.max(x, dim=1) 238 | correct = (predicted == y).float() 239 | accuracy = torch.mean(correct) * 100.0 240 | return accuracy 241 | 242 | def one_hot(self, labels, dim): 243 | """Convert label indices to one-hot vector""" 244 | batch_size = labels.size(0) 245 | out = torch.zeros(batch_size, dim) 246 | out[np.arange(batch_size), labels.long()] = 1 247 | return out 248 | 249 | def test(self): 250 | """Facial attribute transfer on CelebA or facial expression synthesis on RaFD.""" 251 | # Load trained parameters 252 | G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model)) 253 | self.G.load_state_dict(torch.load(G_path)) 254 | self.G.eval() 255 | 256 | if self.dataset == 'CelebA': 257 | data_loader = self.celebA_loader 258 | else: 259 | data_loader = self.rafd_loader 260 | 261 | for i, (real_x, org_c) in enumerate(data_loader): 262 | real_x = self.to_var(real_x, volatile=True) 263 | 264 | if self.dataset == 'CelebA': 265 | target_c_list = self.make_celeb_labels(org_c) 266 | else: 267 | target_c_list = [] 268 | for j in range(self.c_dim): 269 | target_c = self.one_hot(torch.ones(real_x.size(0)) * j, self.c_dim) 270 | target_c_list.append(self.to_var(target_c, volatile=True)) 271 | 272 | # Start translations 273 | fake_image_list = [real_x] 274 | for target_c in target_c_list: 275 | fake_image_list.append(self.G(real_x, target_c)) 276 | fake_images = 
torch.cat(fake_image_list, dim=3) 277 | save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1)) 278 | save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0) 279 | print('Translated test images and saved into "{}"..!'.format(save_path)) 280 | 281 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def mkdir(directory): 4 | if not os.path.exists(directory): 5 | os.makedirs(directory) -------------------------------------------------------------------------------- /voc_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import collections 4 | import os.path as osp 5 | import numbers 6 | import random 7 | import math 8 | import numpy as np 9 | import PIL.Image 10 | import scipy.io 11 | import torch 12 | from torch.utils import data 13 | from torchvision import transforms 14 | 15 | class RandomCropGenerator(object): 16 | def __call__(self, img): 17 | self.x1 = random.uniform(0, 1) 18 | self.y1 = random.uniform(0, 1) 19 | return img 20 | 21 | class RandomCrop(object): 22 | def __init__(self, size, padding=0, gen=None): 23 | if isinstance(size, numbers.Number): 24 | self.size = (int(size), int(size)) 25 | else: 26 | self.size = size 27 | self.padding = padding 28 | self._gen = gen 29 | 30 | def __call__(self, img): 31 | if self.padding > 0: 32 | img = ImageOps.expand(img, border=self.padding, fill=0) 33 | w, h = img.size 34 | th, tw = self.size 35 | if w == tw and h == th: 36 | return img 37 | 38 | if self._gen is not None: 39 | x1 = math.floor(self._gen.x1 * (w - tw)) 40 | y1 = math.floor(self._gen.y1 * (h - th)) 41 | else: 42 | x1 = random.randint(0, w - tw) 43 | y1 = random.randint(0, h - th) 44 | 45 | return img.crop((x1, y1, x1 + tw, y1 + th)) 46 | 47 | class VOCClassSegBase(data.Dataset): 48 | 49 | class_names = np.array([ 50 | 'background', 51 | 'aeroplane', 52 | 'bicycle', 53 | 'bird', 54 | 'boat', 55 | 'bottle', 56 | 'bus', 57 | 'car', 58 | 'cat', 59 | 'chair', 60 | 'cow', 61 | 'diningtable', 62 | 'dog', 63 | 'horse', 64 | 'motorbike', 65 | 'person', 66 | 'potted plant', 67 | 'sheep', 68 | 'sofa', 69 | 'train', 70 | 'tv/monitor', 71 | ]) 72 | mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 73 | 74 | def __init__(self, root, split='train', transform=False): 75 | self.root = root 76 | self.split = split 77 | self._transform = transform 78 | 79 | # VOC2011 and others are subset of VOC2012 80 | dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012') 81 | print(dataset_dir) 82 | self.files = collections.defaultdict(list) 83 | for split in ['train', 'val']: 84 | imgsets_file = osp.join( 85 | dataset_dir, 'ImageSets/Segmentation/%s.txt' % split) 86 | for did in open(imgsets_file): 87 | did = did.strip() 88 | img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did) 89 | lbl_file = osp.join( 90 | dataset_dir, 'SegmentationClass/%s.png' % did) 91 | self.files[split].append({ 92 | 'img': img_file, 93 | 'lbl': lbl_file, 94 | }) 95 | print(len(self.files["train"])) 96 | 97 | 98 | 99 | def __len__(self): 100 | return len(self.files[self.split]) 101 | 102 | def __getitem__(self, index): 103 | data_file = self.files[self.split][index] 104 | # load image 105 | img_file = data_file['img'] 106 | img_pil = PIL.Image.open(img_file) 107 | 108 | gen = RandomCropGenerator() 109 | onlyBgPatch = True 110 | while onlyBgPatch: 111 | 112 | 
transform_img = transforms.Compose([ 113 | gen, 114 | RandomCrop(128, gen=gen), 115 | transforms.Resize([128, 128])]) 116 | 117 | img = np.array(transform_img(img_pil),dtype=np.uint8) 118 | 119 | transform_mask = transforms.Compose([ 120 | RandomCrop(128, gen=gen)]) 121 | 122 | # load label 123 | lbl_file = data_file['lbl'] 124 | lbl_pil = PIL.Image.open(lbl_file) 125 | 126 | lbl_cropped = transform_mask(lbl_pil) 127 | lbl = np.array(transform_mask(lbl_pil), dtype=np.int32) 128 | lbl[lbl == 255] = -1 129 | unique_vals = np.unique(lbl) 130 | if len(unique_vals) > 2: 131 | onlyBgPatch = False 132 | for i in unique_vals: 133 | percentage_covered = np.sum(lbl==i) / (128*128) 134 | 135 | if percentage_covered >= 0.9: 136 | onlyBgPatch = True 137 | break 138 | 139 | if self._transform: 140 | return self.transform(img, lbl) 141 | else: 142 | return img, lbl 143 | 144 | def transform(self, img, lbl): 145 | img = img[:, :, ::-1] # RGB -> BGR 146 | img = img.astype(np.float64) 147 | img -= self.mean_bgr 148 | img = img.transpose(2, 0, 1) 149 | img = torch.from_numpy(img).float() 150 | lbl = torch.from_numpy(lbl).long() 151 | return img, lbl 152 | 153 | def untransform(self, img, lbl): 154 | img = img.numpy() 155 | img = img.transpose(1, 2, 0) 156 | img += self.mean_bgr 157 | img = img.astype(np.uint8) 158 | img = img[:, :, ::-1] 159 | lbl = lbl.numpy() 160 | return img, lbl 161 | 162 | 163 | class VOC2011ClassSeg(VOCClassSegBase): 164 | 165 | def __init__(self, root, split='train', transform=False): 166 | super(VOC2011ClassSeg, self).__init__( 167 | root, split=split, transform=transform) 168 | pkg_root = osp.join(osp.dirname(osp.realpath(__file__)), '..') 169 | imgsets_file = osp.join( 170 | pkg_root, 'ext/fcn.berkeleyvision.org', 171 | 'data/pascal/seg11valid.txt') 172 | dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012') 173 | for did in open(imgsets_file): 174 | did = did.strip() 175 | img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did) 176 | lbl_file = osp.join(dataset_dir, 'SegmentationClass/%s.png' % did) 177 | self.files['seg11valid'].append({'img': img_file, 'lbl': lbl_file}) 178 | 179 | 180 | class VOC2012ClassSeg(VOCClassSegBase): 181 | 182 | url = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar' # NOQA 183 | 184 | def __init__(self, root, split='train', transform=False): 185 | super(VOC2012ClassSeg, self).__init__( 186 | root, split=split, transform=transform) 187 | 188 | 189 | class SBDClassSeg(VOCClassSegBase): 190 | 191 | # XXX: It must be renamed to benchmark.tar to be extracted. 
192 | url = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz' # NOQA 193 | 194 | def __init__(self, root, split='train', transform=False): 195 | self.root = root 196 | self.split = split 197 | self._transform = transform 198 | 199 | dataset_dir = osp.join(self.root, 'VOC/benchmark_RELEASE/dataset') 200 | self.files = collections.defaultdict(list) 201 | for split in ['train', 'val']: 202 | imgsets_file = osp.join(dataset_dir, '%s.txt' % split) 203 | for did in open(imgsets_file): 204 | did = did.strip() 205 | img_file = osp.join(dataset_dir, 'img/%s.jpg' % did) 206 | lbl_file = osp.join(dataset_dir, 'cls/%s.mat' % did) 207 | self.files[split].append({ 208 | 'img': img_file, 209 | 'lbl': lbl_file, 210 | }) 211 | 212 | def __getitem__(self, index): 213 | data_file = self.files[self.split][index] 214 | # load image 215 | img_file = data_file['img'] 216 | img = PIL.Image.open(img_file) 217 | img = np.array(img, dtype=np.uint8) 218 | # load label 219 | lbl_file = data_file['lbl'] 220 | mat = scipy.io.loadmat(lbl_file) 221 | lbl = mat['GTcls'][0]['Segmentation'][0].astype(np.int32) 222 | lbl[lbl == 255] = -1 223 | if self._transform: 224 | return self.transform(img, lbl) 225 | else: 226 | return img, lbl 227 | --------------------------------------------------------------------------------