├── requirements.txt
├── arch
│   ├── __init__.py
│   ├── discriminators.py
│   ├── generators.py
│   └── ops.py
├── examples
│   ├── model_image.png
│   ├── result_square.png
│   ├── VOC
│   │   ├── 2007_002105_gt.png
│   │   ├── 2007_002105_img.jpg
│   │   ├── 2007_002105_semiSup.png
│   │   └── 2007_002105_imgFromSeg.jpg
│   ├── ACDC
│   │   ├── patient001_frame01_2_gt.png
│   │   ├── patient001_frame01_2_img.jpg
│   │   ├── patient001_frame01_2_seg2img.jpg
│   │   └── patient001_frame01_2_genLabels.png
│   └── Cityscapes
│       ├── dusseldorf_000176_000019_leftImg8bit_gt.png
│       ├── dusseldorf_000176_000019_leftImg8bit_img.jpg
│       ├── dusseldorf_000176_000019_leftImg8bit_seg2img.jpg
│       └── dusseldorf_000176_000019_leftImg8bit_genLabels.png
├── .gitignore
├── data
│   └── Datasets.txt
├── LICENSE
├── main.py
├── testing.py
├── README.md
├── data_utils
│   ├── augmentations.py
│   ├── __init__.py
│   └── dataloader.py
├── validation.py
├── utils.py
└── model.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | scipy
4 | tensorboardX
5 | pillow
6 |
--------------------------------------------------------------------------------
/arch/__init__.py:
--------------------------------------------------------------------------------
1 | from .generators import define_Gen
2 | from .discriminators import define_Dis
3 | from .ops import set_grad
--------------------------------------------------------------------------------
/examples/model_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/model_image.png
--------------------------------------------------------------------------------
/examples/result_square.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/result_square.png
--------------------------------------------------------------------------------
/examples/VOC/2007_002105_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/VOC/2007_002105_gt.png
--------------------------------------------------------------------------------
/examples/VOC/2007_002105_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/VOC/2007_002105_img.jpg
--------------------------------------------------------------------------------
/examples/VOC/2007_002105_semiSup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/VOC/2007_002105_semiSup.png
--------------------------------------------------------------------------------
/examples/ACDC/patient001_frame01_2_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/ACDC/patient001_frame01_2_gt.png
--------------------------------------------------------------------------------
/examples/VOC/2007_002105_imgFromSeg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/VOC/2007_002105_imgFromSeg.jpg
--------------------------------------------------------------------------------
/examples/ACDC/patient001_frame01_2_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/ACDC/patient001_frame01_2_img.jpg
--------------------------------------------------------------------------------
/examples/ACDC/patient001_frame01_2_seg2img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/ACDC/patient001_frame01_2_seg2img.jpg
--------------------------------------------------------------------------------
/examples/ACDC/patient001_frame01_2_genLabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/ACDC/patient001_frame01_2_genLabels.png
--------------------------------------------------------------------------------
/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_gt.png
--------------------------------------------------------------------------------
/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_img.jpg
--------------------------------------------------------------------------------
/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_seg2img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_seg2img.jpg
--------------------------------------------------------------------------------
/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_genLabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_genLabels.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.iml
2 | *.xml
3 | *.DS_Store
4 | # datasets
5 | datasets/VOC2012/*
6 | datasets/horse2zebra/*
7 | datasets/Cityscapes/*
8 | arch/__pycache__/
9 | arch/__pycache__/*
10 | __pycache__/
11 | 
--------------------------------------------------------------------------------
/data/Datasets.txt:
--------------------------------------------------------------------------------
1 | File containing the links to public segmentation benchmarks.
2 | - PASCAL VOC 2012 Segmentation (http://host.robots.ox.ac.uk/pascal/VOC/voc2012/)
3 | - Cityscapes (https://www.cityscapes-dataset.com)
4 | - ACDC (https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html)
5 |
6 | Or you can download the preprocessed datasets from the link mentioned in the README
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Arnab Kumar Mondal
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | import model as md
4 | from utils import create_link
5 | from testing import test
6 | from validation import validation
7 |
8 |
9 | # To get arguments from commandline
10 | def get_args():
11 | parser = ArgumentParser(description='cycleGAN PyTorch')
12 | parser.add_argument('--epochs', type=int, default=400)
13 | parser.add_argument('--decay_epoch', type=int, default=100)
14 | parser.add_argument('--batch_size', type=int, default=2)
15 | parser.add_argument('--lr', type=float, default=.0002)
16 | # parser.add_argument('--load_height', type=int, default=286)
17 | # parser.add_argument('--load_width', type=int, default=286)
18 | parser.add_argument('--gpu_ids', type=str, default='0')
19 | parser.add_argument('--crop_height', type=int, default=None)
20 | parser.add_argument('--crop_width', type=int, default=None)
21 | parser.add_argument('--lamda_img', type=float, default=0.5) # For image_cycle_loss
22 | parser.add_argument('--lamda_gt', type=float, default=0.1) # For gt_cycle_loss
23 | parser.add_argument('--lamda_perceptual', type=float, default=0) # For image cycle perceptual loss
24 | # parser.add_argument('--idt_coef', type=float, default=0.5)
25 | # parser.add_argument('--omega', type=int, default=5)
26 | parser.add_argument('--lab_CE_weight', type=float, default=1)
27 | parser.add_argument('--lab_MSE_weight', type=float, default=1)
28 | parser.add_argument('--lab_perceptual_weight', type=float, default=0)
29 | parser.add_argument('--adversarial_weight', type=float, default=1.0)
30 | parser.add_argument('--discriminator_weight', type=float, default=1.0)
31 | parser.add_argument('--training', type=bool, default=False)
32 | parser.add_argument('--testing', type=bool, default=False)
33 | parser.add_argument('--validation', type=bool, default=False)
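# NOTE: with argparse, type=bool converts any non-empty string (even 'False') to True,
# so pass --training/--testing/--validation only when you want to enable that mode
# (as in the README examples).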
34 | parser.add_argument('--model', type=str, default='supervised_model')
35 | parser.add_argument('--results_dir', type=str, default='./results')
36 | parser.add_argument('--validation_dir', type=str, default='./val_results')
37 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/semisupervised_cycleGAN')
38 | parser.add_argument('--dataset',type=str,choices=['voc2012', 'cityscapes', 'acdc'],default='voc2012')
39 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
40 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
41 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
42 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
43 | parser.add_argument('--gen_net', type=str, default='deeplab')
44 | parser.add_argument('--dis_net', type=str, default='fc_disc')
45 | args = parser.parse_args()
46 | return args
47 |
48 |
49 | def main():
50 | args = get_args()
51 | # set gpu ids
52 | str_ids = args.gpu_ids.split(',')
53 | args.gpu_ids = []
54 | for str_id in str_ids:
55 | id = int(str_id)
56 | if id >= 0:
57 | args.gpu_ids.append(id)
58 |
59 | ### For setting the image dimensions for different datasets
60 | if args.crop_height is None and args.crop_width is None:
61 | if args.dataset == 'voc2012':
62 | args.crop_height = args.crop_width = 320
63 | elif args.dataset == 'acdc':
64 | args.crop_height = args.crop_width = 256
65 | elif args.dataset == 'cityscapes':
66 | args.crop_height = 512
67 | args.crop_width = 1024
68 |
69 | if args.training:
70 | if args.model == "semisupervised_cycleGAN":
71 | print("Training semi-supervised cycleGAN")
72 | model = md.semisuper_cycleGAN(args)
73 | model.train(args)
74 | if args.model == "supervised_model":
75 | print("Training base model")
76 | model = md.supervised_model(args)
77 | model.train(args)
78 | if args.testing:
79 | print("Testing")
80 | test(args)
81 | if args.validation:
82 | print("Validating")
83 | validation(args)
84 |
85 |
86 | if __name__ == '__main__':
87 | main()
88 |
--------------------------------------------------------------------------------
/arch/discriminators.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from .ops import conv_norm_lrelu, get_norm_layer, init_network
3 | import torch
4 | from torch.nn import functional as F
5 | import functools
6 |
7 |
8 | class FCDiscriminator(nn.Module):
9 | ### This has been adapted directly from this repo:
10 | ### https://github.com/hfslyc/AdvSemiSeg
11 |
12 | def __init__(self, num_classes, ndf = 64):
13 | super(FCDiscriminator, self).__init__()
14 |
15 | self.conv1 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
16 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
17 | self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
18 | self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
19 | self.classifier = nn.Conv2d(ndf*8, 1, kernel_size=4, stride=2, padding=1)
20 |
21 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
22 | #self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear')
23 | #self.sigmoid = nn.Sigmoid()
24 |
25 |
26 | def forward(self, x):
27 | x = self.conv1(x)
28 | x = self.leaky_relu(x)
29 | x = self.conv2(x)
30 | x = self.leaky_relu(x)
31 | x = self.conv3(x)
32 | x = self.leaky_relu(x)
33 | x = self.conv4(x)
34 | x = self.leaky_relu(x)
35 | x = self.classifier(x)
36 | #x = self.up_sample(x)
37 | #x = self.sigmoid(x)
38 |
39 | return x
40 |
41 |
42 | class NLayerDiscriminator(nn.Module):
43 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_bias=False):
44 | super(NLayerDiscriminator, self).__init__()
45 | dis_model = [nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
46 | nn.LeakyReLU(0.2, True)]
47 | nf_mult = 1
48 | nf_mult_prev = 1
49 | for n in range(1, n_layers):
50 | nf_mult_prev = nf_mult
51 | nf_mult = min(2**n, 8)
52 | dis_model += [conv_norm_lrelu(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=2,
53 | norm_layer= norm_layer, padding=1, bias=use_bias)]
54 | nf_mult_prev = nf_mult
55 | nf_mult = min(2**n_layers, 8)
56 | dis_model += [conv_norm_lrelu(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=1,
57 | norm_layer= norm_layer, padding=1, bias=use_bias)]
58 | dis_model += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)]
59 |
60 | self.dis_model = nn.Sequential(*dis_model)
61 |
62 | def forward(self, input):
63 | return self.dis_model(input)
64 |
65 |
66 | class PixelDiscriminator(nn.Module):
67 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_bias=False):
68 | super(PixelDiscriminator, self).__init__()
69 | dis_model = [
70 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
71 | nn.LeakyReLU(0.2, True),
72 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
73 | norm_layer(ndf * 2),
74 | nn.LeakyReLU(0.2, True),
75 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
76 |
77 | self.dis_model = nn.Sequential(*dis_model)
78 |
79 | def forward(self, input):
80 | return self.dis_model(input)
81 |
82 |
83 |
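# Factory for the discriminators used in this repo: builds the requested network
# ('n_layers', 'pixel', or 'fc_disc') and returns it wrapped by init_network
# (imported from .ops, not shown here; presumably handles weight initialization and GPU placement).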
84 | def define_Dis(input_nc, ndf, netD, n_layers_D=3, norm='batch', gpu_ids=[0]):
85 | dis_net = None
86 | norm_layer = get_norm_layer(norm_type=norm)
87 | if type(norm_layer) == functools.partial:
88 | use_bias = norm_layer.func == nn.InstanceNorm2d
89 | else:
90 | use_bias = norm_layer == nn.InstanceNorm2d
91 |
92 | if netD == 'n_layers':
93 | dis_net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_bias=use_bias)
94 | elif netD == 'pixel':
95 | dis_net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_bias=use_bias)
96 | elif netD == 'fc_disc':
97 | dis_net = FCDiscriminator(input_nc, ndf)
98 | else:
99 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
100 |
101 | return init_network(dis_net, gpu_ids)
102 |
--------------------------------------------------------------------------------
/testing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch import nn
5 | from torch.autograd import Variable
6 | import torchvision
7 | import torchvision.datasets as dsets
8 | from torch.utils.data import DataLoader
9 | import torchvision.transforms as transforms
10 | import utils
11 | from PIL import Image
12 | from arch import define_Gen
13 | from data_utils import VOCDataset, CityscapesDataset, ACDCDataset, get_transformation
14 |
15 | root = './data/VOC2012test/VOC2012'
16 | root_cityscapes = './data/Cityscape'
17 | root_acdc = './data/ACDC'
18 |
19 | def test(args):
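# Runs the image->segmentation generator (Gsi) from the chosen checkpoint over the test split
# and saves the colorized predictions under <results_dir>/supervised or <results_dir>/unsupervised.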
20 |
21 | ### For selecting the number of channels
22 | if args.dataset == 'voc2012':
23 | n_channels = 21
24 | elif args.dataset == 'cityscapes':
25 | n_channels = 20
26 | elif args.dataset == 'acdc':
27 | n_channels = 4
28 |
29 | transform = get_transformation((args.crop_height, args.crop_width), resize=True, dataset=args.dataset)
30 |
31 | ## let the choice of dataset configurable
32 | if args.dataset == 'voc2012':
33 | test_set = VOCDataset(root_path=root, name='test', ratio=0.5, transformation=transform, augmentation=None)
34 | elif args.dataset == 'cityscapes':
35 | test_set = CityscapesDataset(root_path=root_cityscapes, name='test', ratio=0.5, transformation=transform, augmentation=None)
36 | elif args.dataset == 'acdc':
37 | test_set = ACDCDataset(root_path=root_acdc, name='test', ratio=0.5, transformation=transform, augmentation=None)
38 |
39 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)
40 |
41 | Gsi = define_Gen(input_nc=3, output_nc=n_channels, ngf=args.ngf, netG='resnet_9blocks_softmax',
42 | norm=args.norm, use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids)
43 |
44 | ### activation_softmax
45 | activation_softmax = nn.Softmax2d()
46 |
47 | if(args.model == 'supervised_model'):
48 |
49 | ### loading the checkpoint
50 | try:
51 | ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir))
52 | Gsi.load_state_dict(ckpt['Gsi'])
53 |
54 | except:
55 | print(' [*] No checkpoint!')
56 |
57 | ### run
58 | Gsi.eval()
59 | for i, (image_test, image_name) in enumerate(test_loader):
60 | image_test = utils.cuda(image_test, args.gpu_ids)
61 | seg_map = Gsi(image_test)
62 | seg_map = activation_softmax(seg_map)
63 |
64 | prediction = seg_map.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy() ### To convert from 22 --> 1 channel
65 | for j in range(prediction.shape[0]):
66 | new_img = prediction[j] ### Taking a particular image from the batch
67 | new_img = utils.colorize_mask(new_img, args.dataset) ### So as to convert it back to a paletted image
68 |
69 | ### Now the new_img is PIL.Image
70 | new_img.save(os.path.join(args.results_dir+'/supervised/'+image_name[j]+'.png'))
71 |
72 |
73 | print('Epoch-', str(i+1), ' Done!')
74 |
75 | elif(args.model == 'semisupervised_cycleGAN'):
76 |
77 | ### loading the checkpoint
78 | try:
79 | ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir))
80 | Gsi.load_state_dict(ckpt['Gsi'])
81 |
82 | except:
83 | print(' [*] No checkpoint!')
84 |
85 | ### run
86 | Gsi.eval()
87 | for i, (image_test, image_name) in enumerate(test_loader):
88 | image_test = utils.cuda(image_test, args.gpu_ids)
89 | seg_map = Gsi(image_test)
90 | seg_map = activation_softmax(seg_map)
91 |
92 | prediction = seg_map.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy() ### To convert from 22 --> 1 channel
93 | for j in range(prediction.shape[0]):
94 | new_img = prediction[j] ### Taking a particular image from the batch
95 | new_img = utils.colorize_mask(new_img, args.dataset) ### So as to convert it back to a paletted image
96 |
97 | ### Now the new_img is PIL.Image
98 | new_img.save(os.path.join(args.results_dir+'/unsupervised/'+image_name[j]+'.png'))
99 |
100 | print('Epoch-', str(i+1), ' Done!')
101 |
102 |
103 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Revisiting Cycle-GAN for semi-supervised segmentation
2 | This repo contains the official Pytorch implementation of the paper: [Revisiting CycleGAN for semi-supervised segmentation](https://arxiv.org/abs/1908.11569)
3 |
4 | ## Contents
5 | 1. [Summary of the Model](#1-summary-of-the-model)
6 | 2. [Setup instructions and dependencies](#2-setup-instructions-and-dependencies)
7 | 3. [Repository Overview](#3-repository-overview)
8 | 4. [Running the model](#4-running-the-model)
9 | 5. [Some results of the paper](#5-some-results-of-the-paper)
10 | 6. [Contact](#6-contact)
11 | 7. [License](#7-license)
12 |
13 | ## 1. Summary of the Model
14 | The following shows the training procedure for our proposed model
15 |
16 |
17 |
18 | We propose a training procedure for semi-supervised segmentation based on the principles of image-to-image translation with GANs. The proposed procedure has been evaluated on three segmentation datasets, namely **VOC**, **Cityscapes** and **ACDC**. Our semisupervised models achieve a *2-4% improvement in mean IoU* compared to the supervised model trained on the same amount of labeled data. For further details regarding the model and training procedure, please refer to the paper.
19 |
20 | ## 2. Setup instructions and dependencies
21 | The code has been written in *Python 3.6* and *Pytorch v1.0* with *Torchvision v0.3*. You can install all the dependencies for the model by running the following command in a virtual environment:
22 | ```
23 | pip install -r requirements.txt
24 | ```
25 | For training/testing the model, you must first download all three datasets. You can download the processed versions of the datasets [here](https://drive.google.com/drive/folders/1v_ahAjNGPHjw9G15dlxjImtolne6HZ0B?usp=sharing). Also, for storing the results of validation/testing, checkpoints and tensorboard logs, the directory structure must be organized as follows:
26 |
27 | .
28 | ├── arch
29 | │ ├── ...
30 | ├── data # Follow the way the dataset has been placed here
31 | │ ├── ACDC # Here the ACDC dataset must be placed
32 | │ └── Cityscape # Here the Cityscapes dataset must be placed
33 | │ └── VOC2012 # Here the VOC train/val dataset must be placed
34 | │ └── VOC2012test # Here the VOC test dataset must be placed
35 | ├── data_utils
36 | │ ├── ...
37 | ├── checkpoints # Create this directory to store checkpoints
38 | ├── examples
39 | ├── main.py
40 | ├── model.py
41 | ├── README.md
42 | ├── testing.py
43 | ├── utils.py
44 | ├── validation.py
45 | ├── results # Create this directory for storing the results
46 | │ ├── supervised # Directory for storing supervised results
47 | │ └── unsupervised # Directory for storing semisupervised results
48 | └── tensorboard_results # Create this directory to store tensorboard log curves
49 |
50 |
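The checkpoint and result directories are not created by the code itself. A minimal sketch for setting them up, assuming the default `--checkpoint_dir`, `--results_dir` and `--validation_dir` paths from `main.py` and the sub-directories written to by `testing.py`/`validation.py`:

```
mkdir -p checkpoints/semisupervised_cycleGAN tensorboard_results
mkdir -p results/supervised results/unsupervised
mkdir -p val_results/supervised
mkdir -p val_results/unsupervised/generated_labels val_results/unsupervised/regenerated_labels
mkdir -p val_results/unsupervised/regenerated_image val_results/unsupervised/image_from_labels
```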
51 | ## 3. Repository Overview
52 | The following describes the important files in the repository and their functions:
53 |
54 | - `arch` : Stores the architectures of the generators and discriminators used in our model
55 | - `data_utils` : Contains the dataloaders and helper functions for data processing
56 | - `main.py` : Stores the various hyperparameter settings and defaults
57 | - `model.py` : Stores the training procedures for both the supervised and semisupervised models, along with the checkpointing and logging details
58 | - `utils.py` : Stores the helper functions required for training
59 |
60 | ## 4. Running the model
61 | You can configure the various defaults specified in the `main.py` file. You can also modify the supervision percentage on the dataset by changing the dataloader call in the `model.py` file.
62 |
63 | For training/validation/testing our proposed semisupervised model:
64 |
65 | ```
66 | python main.py --model 'semisupervised_cycleGAN' --dataset 'voc2012' --gpu_ids '0' --training True
67 | ```
68 |
69 | Similar commands for *validation* and *testing* can be run by replacing `--training` with `--validation` and `--testing` respectively, as shown below.
70 |
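For example:

```
python main.py --model 'semisupervised_cycleGAN' --dataset 'voc2012' --gpu_ids '0' --validation True
python main.py --model 'semisupervised_cycleGAN' --dataset 'voc2012' --gpu_ids '0' --testing True
```
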
71 | ## 5. Some results of the paper
72 | Some of the results produced by our semisupervised model are shown below. *For more such results, please see the main paper and its supplementary section.*
73 |
74 |
75 |
76 | *The generated labels are the labels obtained from the semisupervised model, while the generated images are obtained by passing the labels through the label-to-image network.*
77 |
78 | ## 6. Contact
79 | If you have found our research work helpful, please consider citing the original paper.
80 |
81 | If you have any doubts regarding the codebase, you can open an issue or email us at sanu.arnab@gmail.com / aagarwal@ma.iitr.ac.in
82 |
83 | ## 7. License
84 | This repository is licensed under the MIT license.
85 |
--------------------------------------------------------------------------------
/data_utils/augmentations.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 | import random
4 | import numpy as np
5 | import torchvision.transforms.functional as tf
6 | from PIL import Image, ImageOps
7 |
8 |
9 | class Compose(object):
10 | def __init__(self, augmentations):
11 | self.augmentations = augmentations
12 | self.PIL2Numpy = False
13 |
14 | def __call__(self, img, mask):
15 | if isinstance(img, np.ndarray):
16 | img = Image.fromarray(img, mode="RGB")
17 | mask = Image.fromarray(mask, mode="L")
18 | self.PIL2Numpy = True
19 |
20 | assert img.size == mask.size
21 | for a in self.augmentations:
22 | img, mask = a(img, mask)
23 |
24 | if self.PIL2Numpy:
25 | img, mask = np.array(img), np.array(mask, dtype=np.uint8)
26 |
27 | return img, mask
28 |
29 |
30 | class RandomCrop(object):
31 | def __init__(self, size, padding=0):
32 | if isinstance(size, numbers.Number):
33 | self.size = (int(size), int(size))
34 | else:
35 | self.size = size
36 | self.padding = padding
37 |
38 | def __call__(self, img, mask):
39 | if self.padding > 0:
40 | img = ImageOps.expand(img, border=self.padding, fill=0)
41 | mask = ImageOps.expand(mask, border=self.padding, fill=0)
42 |
43 | assert img.size == mask.size
44 | w, h = img.size
45 | th, tw = self.size
46 | if w == tw and h == th:
47 | return img, mask
48 | if w < tw or h < th:
49 | return (
50 | img.resize((tw, th), Image.BILINEAR),
51 | mask.resize((tw, th), Image.NEAREST),
52 | )
53 |
54 | x1 = random.randint(0, w - tw)
55 | y1 = random.randint(0, h - th)
56 | return (
57 | img.crop((x1, y1, x1 + tw, y1 + th)),
58 | mask.crop((x1, y1, x1 + tw, y1 + th)),
59 | )
60 |
61 |
62 | class CenterCrop(object):
63 | def __init__(self, size):
64 | if isinstance(size, numbers.Number):
65 | self.size = (int(size), int(size))
66 | else:
67 | self.size = size
68 |
69 | def __call__(self, img, mask):
70 | assert img.size == mask.size
71 | w, h = img.size
72 | th, tw = self.size
73 | x1 = int(round((w - tw) / 2.))
74 | y1 = int(round((h - th) / 2.))
75 | return (
76 | img.crop((x1, y1, x1 + tw, y1 + th)),
77 | mask.crop((x1, y1, x1 + tw, y1 + th)),
78 | )
79 |
80 |
81 | class RandomRotate(object):
82 | def __init__(self, degree):
83 | self.degree = degree
84 |
85 | def __call__(self, img, mask):
86 | rotate_degree = random.random() * 2 * self.degree - self.degree
87 | return (
88 | tf.affine(img,
89 | translate=(0, 0),
90 | scale=1.0,
91 | angle=rotate_degree,
92 | resample=Image.BILINEAR,
93 | fillcolor=(0, 0, 0),
94 | shear=0.0),
95 | tf.affine(mask,
96 | translate=(0, 0),
97 | scale=1.0,
98 | angle=rotate_degree,
99 | resample=Image.NEAREST,
100 | fillcolor=250,
101 | shear=0.0))
102 |
103 |
104 | class Scale(object):
105 | def __init__(self, size):
106 | self.size = size
107 |
108 | def __call__(self, img, mask):
109 | assert img.size == mask.size
110 | w, h = img.size
111 | if (w >= h and w == self.size) or (h >= w and h == self.size):
112 | return img, mask
113 | if w > h:
114 | ow = self.size
115 | oh = int(self.size * h / w)
116 | return (
117 | img.resize((ow, oh), Image.BILINEAR),
118 | mask.resize((ow, oh), Image.NEAREST),
119 | )
120 | else:
121 | oh = self.size
122 | ow = int(self.size * w / h)
123 | return (
124 | img.resize((ow, oh), Image.BILINEAR),
125 | mask.resize((ow, oh), Image.NEAREST),
126 | )
127 |
128 |
129 | class RandomSizedCrop(object):
130 | def __init__(self, size):
131 | self.size = size
132 |
133 | def __call__(self, img, mask):
134 | assert img.size == mask.size
135 | for attempt in range(10):
136 | area = img.size[0] * img.size[1]
137 | target_area = random.uniform(0.45, 1.0) * area
138 | aspect_ratio = random.uniform(0.5, 2)
139 |
140 | w = int(round(math.sqrt(target_area * aspect_ratio)))
141 | h = int(round(math.sqrt(target_area / aspect_ratio)))
142 |
143 | if random.random() < 0.5:
144 | w, h = h, w
145 |
146 | if w <= img.size[0] and h <= img.size[1]:
147 | x1 = random.randint(0, img.size[0] - w)
148 | y1 = random.randint(0, img.size[1] - h)
149 |
150 | img = img.crop((x1, y1, x1 + w, y1 + h))
151 | mask = mask.crop((x1, y1, x1 + w, y1 + h))
152 | assert img.size == (w, h)
153 |
154 | return (
155 | img.resize((self.size, self.size), Image.BILINEAR),
156 | mask.resize((self.size, self.size), Image.NEAREST),
157 | )
158 |
159 | # Fallback
160 | scale = Scale(self.size)
161 | crop = CenterCrop(self.size)
162 | return crop(*scale(img, mask))
163 |
164 |
165 | class RandomSized(object):
166 | def __init__(self, size):
167 | self.size = size
168 | self.scale = Scale(self.size)
169 | self.crop = RandomCrop(self.size)
170 |
171 | def __call__(self, img, mask):
172 | assert img.size == mask.size
173 |
174 | w = int(random.uniform(0.5, 2) * img.size[0])
175 | h = int(random.uniform(0.5, 2) * img.size[1])
176 |
177 | img, mask = (
178 | img.resize((w, h), Image.BILINEAR),
179 | mask.resize((w, h), Image.NEAREST),
180 | )
181 |
182 | return self.crop(*self.scale(img, mask))
183 |
--------------------------------------------------------------------------------
/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import *
2 | import numpy as np
3 | import torch
4 | import random
5 | from PIL import ImageOps, ImageFilter, ImageEnhance
6 | from .dataloader import VOCDataset, CityscapesDataset, ACDCDataset
7 |
8 |
9 | def colormap(n):
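# Builds an n-entry color map in the standard PASCAL VOC style: the bits of each class
# index are spread across the R, G and B channels (a bit-reversal trick).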
10 | cmap = np.zeros([n, 3]).astype(np.uint8)
11 |
12 | for i in np.arange(n):
13 | r, g, b = np.zeros(3)
14 |
15 | for j in np.arange(8):
16 | r = r + (1 << (7 - j)) * ((i & (1 << (3 * j))) >> (3 * j))
17 | g = g + (1 << (7 - j)) * ((i & (1 << (3 * j + 1))) >> (3 * j + 1))
18 | b = b + (1 << (7 - j)) * ((i & (1 << (3 * j + 2))) >> (3 * j + 2))
19 |
20 | cmap[i, :] = np.array([r, g, b])
21 |
22 | return cmap
23 |
24 |
25 | class Relabel:
26 | def __init__(self, olabel, nlabel):
27 | '''
28 | Converts a particular label value in the tensor to another
29 |
30 | Parameters
31 | ----------
32 | olabel : old label which needs to be changed
33 | nlabel : new label which will replace the old label
34 | '''
35 | self.olabel = olabel
36 | self.nlabel = nlabel
37 |
38 | def __call__(self, tensor):
39 | assert isinstance(tensor, torch.LongTensor), 'tensor needs to be LongTensor'
40 | tensor[tensor == self.olabel] = self.nlabel
41 | return tensor
42 |
43 |
44 | class ToLabel:
45 |
46 | def __call__(self, image):
47 | '''
48 | Converts a PIL label image of dim (H x W) into a LongTensor of dim (1 x H x W)
49 | '''
50 | return torch.from_numpy(np.array(image)).long().unsqueeze(0)
51 |
52 |
53 | class Colorize:
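# Turns a 2D map of class indices into a 3 x H x W ByteTensor of RGB values, using the
# first n entries of the color map above (label 0 / background is left black).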
54 |
55 | def __init__(self, n=22):
56 | self.cmap = colormap(256)
57 | self.cmap[n] = self.cmap[-1]
58 | self.cmap = torch.from_numpy(self.cmap[:n])
59 |
60 | def __call__(self, gray_image):
61 | size = gray_image.size()
62 |
63 | try:
64 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
65 | except:
66 | color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0)
67 |
68 | for label in range(1, len(self.cmap)):
69 | mask = gray_image == label
70 |
71 | color_image[0][mask] = self.cmap[label][0]
72 | color_image[1][mask] = self.cmap[label][1]
73 | color_image[2][mask] = self.cmap[label][2]
74 |
75 | return color_image
76 |
77 |
78 | def PILaugment(img, mask):
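# Jointly augments a (PIL image, PIL mask) pair: random crop covering 45-75% of the
# original size, random vertical/horizontal flips, random rotation in [-45, 45] degrees,
# and an occasional Gaussian blur applied to the image only.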
79 | if random.random() > 0.2:
80 | (w, h) = img.size
81 | (w_, h_) = mask.size
82 | assert (w == w_ and h == h_), 'The size should be the same.'
83 | crop = random.uniform(0.45, 0.75)
84 | W = int(crop * w)
85 | H = int(crop * h)
86 | start_x = w - W
87 | start_y = h - H
88 | x_pos = int(random.uniform(0, start_x))
89 | y_pos = int(random.uniform(0, start_y))
90 | img = img.crop((x_pos, y_pos, x_pos + W, y_pos + H))
91 | mask = mask.crop((x_pos, y_pos, x_pos + W, y_pos + H))
92 |
93 | if random.random() > 0.2:
94 | img = ImageOps.flip(img)
95 | mask = ImageOps.flip(mask)
96 |
97 | if random.random() > 0.2:
98 | img = ImageOps.mirror(img)
99 | mask = ImageOps.mirror(mask)
100 |
101 | if random.random() > 0.2:
102 | angle = random.random() * 90 - 45
103 | img = img.rotate(angle)
104 | mask = mask.rotate(angle)
105 | if random.random() > 0.95:
106 | img = img.filter(ImageFilter.GaussianBlur(2))
107 |
108 | if random.random() > 0.95:
109 | img = ImageEnhance.Contrast(img).enhance(1)
110 |
111 | if random.random() > 0.95:
112 | img = ImageEnhance.Brightness(img).enhance(1)
113 |
114 | return img, mask
115 |
116 |
117 | def get_transformation(size, resize=False, dataset = 'voc2012'):
118 | '''
119 | Returns a pair of transformations ({'img': ..., 'gt': ...}) that produce tensors of the given size.
120 | If resize=True the inputs are resized before cropping; otherwise they are only center-cropped.
121 | '''
122 |
123 | assert dataset in ['voc2012', 'cityscapes', 'acdc'], 'The dataset name must be set correctly in the get_transformation function'
124 | if dataset == 'voc2012' or dataset == 'cityscapes':
125 | if resize:
126 | transfom_lst = [
127 | Resize(size),
128 | CenterCrop(size),
129 | ToTensor(),
130 | Normalize([.5, .5, .5], [.5, .5, .5])
131 | ]
132 | else:
133 | transfom_lst = [
134 | CenterCrop(size),
135 | ToTensor(),
136 | Normalize([.5, .5, .5], [.5, .5, .5])
137 | ]
138 | elif dataset == 'acdc':
139 | if resize:
140 | transfom_lst = [
141 | Resize(size),
142 | CenterCrop(size),
143 | ToTensor(),
144 | Normalize([.5], [.5])
145 | ]
146 | else:
147 | transfom_lst = [
148 | CenterCrop(size),
149 | ToTensor(),
150 | Normalize([.5], [.5])
151 | ]
152 |
153 | input_transform = Compose(transfom_lst)
154 |
155 | ### This is because we don't need the Relabel function for the case of cityscapes dataset
156 | if dataset == 'voc2012':
157 | if resize:
158 | target_transform = Compose([
159 | Resize(size, interpolation=0),
160 | CenterCrop(size),
161 | ToLabel(),
162 | Relabel(255, 0) ## So as to replace the 255(boundaries) label as 0
163 | ])
164 |
165 | else:
166 | target_transform = Compose([
167 | CenterCrop(size),
168 | ToLabel(),
169 | Relabel(255, 0) ## So as to replace the 255(boundaries) label as 0
170 | ])
171 |
172 | elif dataset == 'cityscapes' or dataset == 'acdc':
173 | if resize:
174 | target_transform = Compose([
175 | Resize(size, interpolation=0),
176 | CenterCrop(size),
177 | ToLabel()
178 | ])
179 |
180 | else:
181 | target_transform = Compose([
182 | CenterCrop(size),
183 | ToLabel(),
184 | ])
185 |
186 | transform = {'img': input_transform, 'gt': target_transform}
187 | return transform
188 |
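# Example usage (a sketch mirroring testing.py, with VOC's default 320x320 crop from main.py):
#   transform = get_transformation((320, 320), resize=True, dataset='voc2012')
#   img_tensor = transform['img'](pil_image)   # normalized FloatTensor
#   gt_tensor = transform['gt'](pil_mask)      # LongTensor of class labels (255 relabelled to 0)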
--------------------------------------------------------------------------------
/validation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch import nn
5 | from torch.autograd import Variable
6 | import torchvision
7 | import torchvision.datasets as dsets
8 | from torch.utils.data import DataLoader
9 | import torchvision.transforms as transforms
10 | import utils
11 | from utils import make_one_hot
12 | from PIL import Image
13 | from arch import define_Gen
14 | from data_utils import VOCDataset, CityscapesDataset, ACDCDataset, get_transformation
15 |
16 | root = './data/VOC2012'
17 | root_cityscapes = './data/Cityscape'
18 | root_acdc = './data/ACDC'
19 |
20 | def validation(args):
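# Runs the image->segmentation generator (Gsi) on the validation split and saves colorized
# predictions; for the semisupervised cycleGAN it additionally saves images regenerated from
# the predicted labels (via Gis) and labels regenerated from ground-truth-driven images,
# all under <validation_dir>/supervised or <validation_dir>/unsupervised.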
21 |
22 | ### For selecting the number of channels
23 | if args.dataset == 'voc2012':
24 | n_channels = 21
25 | elif args.dataset == 'cityscapes':
26 | n_channels = 20
27 | elif args.dataset == 'acdc':
28 | n_channels = 4
29 |
30 | transform = get_transformation((args.crop_height, args.crop_width), resize=True, dataset=args.dataset)
31 |
32 | ## let the choice of dataset configurable
33 | if args.dataset == 'voc2012':
34 | val_set = VOCDataset(root_path=root, name='val', ratio=0.5, transformation=transform, augmentation=None)
35 | elif args.dataset == 'cityscapes':
36 | val_set = CityscapesDataset(root_path=root_cityscapes, name='val', ratio=0.5, transformation=transform, augmentation=None)
37 | elif args.dataset == 'acdc':
38 | val_set = ACDCDataset(root_path=root_acdc, name='val', ratio=0.5, transformation=transform, augmentation=None)
39 |
40 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
41 |
42 | Gsi = define_Gen(input_nc=3, output_nc=n_channels, ngf=args.ngf, netG='deeplab',
43 | norm=args.norm, use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids)
44 |
45 | Gis = define_Gen(input_nc=n_channels, output_nc=3, ngf=args.ngf, netG='deeplab',
46 | norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
47 |
48 | ### best_iou
49 | best_iou = 0
50 |
51 | ### Interpolation
52 | interp = nn.Upsample(size = (args.crop_height, args.crop_width), mode='bilinear', align_corners=True)
53 |
54 | ### Softmax activation
55 | activation_softmax = nn.Softmax2d()
56 | activation_tanh = nn.Tanh()
57 |
58 | if(args.model == 'supervised_model'):
59 |
60 | ### loading the checkpoint
61 | try:
62 | ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir))
63 | Gsi.load_state_dict(ckpt['Gsi'])
64 | best_iou = ckpt['best_iou']
65 |
66 | except:
67 | print(' [*] No checkpoint!')
68 |
69 | ### run
70 | Gsi.eval()
71 | for i, (image_test, real_segmentation, image_name) in enumerate(val_loader):
72 | image_test = utils.cuda(image_test, args.gpu_ids)
73 | seg_map = Gsi(image_test)
74 | seg_map = interp(seg_map)
75 | seg_map = activation_softmax(seg_map)
76 |
77 | prediction = seg_map.data.max(1)[1].squeeze_(1).cpu().numpy() ### To convert from 22 --> 1 channel
78 | for j in range(prediction.shape[0]):
79 | new_img = prediction[j] ### Taking a particular image from the batch
80 | new_img = utils.colorize_mask(new_img, args.dataset) ### So as to convert it back to a paletted image
81 |
82 | ### Now the new_img is PIL.Image
83 | new_img.save(os.path.join(args.validation_dir+'/supervised/'+image_name[j]+'.png'))
84 |
85 |
86 | print('Epoch-', str(i+1), ' Done!')
87 |
88 | print('The iou of the resulting segment maps: ', str(best_iou))
89 |
90 |
91 | elif(args.model == 'semisupervised_cycleGAN'):
92 |
93 | ### loading the checkpoint
94 | try:
95 | ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir))
96 | Gsi.load_state_dict(ckpt['Gsi'])
97 | Gis.load_state_dict(ckpt['Gis'])
98 | best_iou = ckpt['best_iou']
99 |
100 | except:
101 | print(' [*] No checkpoint!')
102 |
103 | ### run
104 | Gsi.eval()
105 | for i, (image_test, real_segmentation, image_name) in enumerate(val_loader):
106 | image_test, real_segmentation = utils.cuda([image_test, real_segmentation], args.gpu_ids)
107 | seg_map = Gsi(image_test)
108 | seg_map = interp(seg_map)
109 | seg_map = activation_softmax(seg_map)
110 | fake_img = Gis(seg_map).detach()
111 | fake_img = interp(fake_img)
112 | fake_img = activation_tanh(fake_img)
113 |
114 | fake_img_from_labels = Gis(make_one_hot(real_segmentation, args.dataset, args.gpu_ids).float()).detach()
115 | fake_img_from_labels = interp(fake_img_from_labels)
116 | fake_img_from_labels = activation_tanh(fake_img_from_labels)
117 | fake_label_regenerated = Gsi(fake_img_from_labels).detach()
118 | fake_label_regenerated = interp(fake_label_regenerated)
119 | fake_label_regenerated = activation_softmax(fake_label_regenerated)
120 |
121 | prediction = seg_map.data.max(1)[1].squeeze_(1).cpu().numpy() ### To convert from 22 --> 1 channel
122 | fake_regenerated_label = fake_label_regenerated.data.max(1)[1].squeeze_(1).cpu().numpy()
123 |
124 | fake_img = fake_img.cpu()
125 | fake_img_from_labels = fake_img_from_labels.cpu()
126 |
127 | ### Now i am going to revert back the transformation on these images
128 | if args.dataset == 'voc2012' or args.dataset == 'cityscapes':
129 | trans_mean = [0.5, 0.5, 0.5]
130 | trans_std = [0.5, 0.5, 0.5]
131 | for k in range(3):
132 | fake_img[:, k, :, :] = ((fake_img[:, k, :, :] * trans_std[k]) + trans_mean[k])
133 | fake_img_from_labels[:, k, :, :] = ((fake_img_from_labels[:, k, :, :] * trans_std[k]) + trans_mean[k])
134 |
135 | elif args.dataset == 'acdc':
136 | trans_mean = [0.5]
137 | trans_std = [0.5]
138 | for k in range(1):
139 | fake_img[:, k, :, :] = ((fake_img[:, k, :, :] * trans_std[k]) + trans_mean[k])
140 | fake_img_from_labels[:, k, :, :] = ((fake_img_from_labels[:, k, :, :] * trans_std[k]) + trans_mean[k])
141 |
142 | for j in range(prediction.shape[0]):
143 | new_img = prediction[j] ### Taking a particular image from the batch
144 | new_img = utils.colorize_mask(new_img, args.dataset) ### So as to convert it back to a paletted image
145 |
146 | regen_label = fake_regenerated_label[j]
147 | regen_label = utils.colorize_mask(regen_label, args.dataset)
148 |
149 | ### Now the new_img is PIL.Image
150 | new_img.save(os.path.join(args.validation_dir+'/unsupervised/generated_labels/'+image_name[j]+'.png'))
151 | regen_label.save(os.path.join(args.validation_dir+'/unsupervised/regenerated_labels/'+image_name[j]+'.png'))
152 | torchvision.utils.save_image(fake_img[j], os.path.join(args.validation_dir+'/unsupervised/regenerated_image/'+image_name[j]+'.jpg'))
153 | torchvision.utils.save_image(fake_img_from_labels[j], os.path.join(args.validation_dir+'/unsupervised/image_from_labels/'+image_name[j]+'.jpg'))
154 |
155 | print('Epoch-', str(i+1), ' Done!')
156 |
157 | print('The iou of the resulting segment maps: ', str(best_iou))
158 |
159 |
--------------------------------------------------------------------------------
/data_utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | import torch
5 | import scipy.misc as m
6 | import scipy.io as sio
7 |
8 | from torch.utils.data import Dataset
9 | from utils import recursive_glob
10 | from PIL import Image
11 | from .augmentations import *
12 | import re
13 |
14 |
15 | class VOCDataset(Dataset):
16 | '''
17 | We assume that there will be txt to note all the image names
18 |
19 |
20 | color map:
21 | 0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle # 6=bus, 7=car, 8=cat, 9=chair, 10=cow, 11=diningtable,
22 | 12=dog, 13=horse, 14=motorbike, 15=person # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor,
23 | 21=boundaries(self-defined)
24 |
25 | Also it will return an image and ground truth as it is present in the image form, so ground truth won't
26 | be in one-hot form and rather would be a 2D tensor. To convert the labels in one-hot form in the training
27 | code we will be calling the function 'make_one_hot' function of utils.py
28 |
29 | '''
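# Example (a sketch): in the training/validation code the 2D ground truth returned here is
# converted to one-hot with utils.make_one_hot, called as in validation.py:
#   one_hot = make_one_hot(gt_batch, args.dataset, args.gpu_ids)
# where gt_batch is the label tensor produced by the DataLoader and the one-hot output has
# one channel per class (21 for VOC).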
30 |
31 | def __init__(self, root_path, name='label', ratio=0.5, transformation=None, augmentation=None):
32 | super(VOCDataset, self).__init__()
33 | self.root_path = root_path
34 | self.ratio = ratio
35 | self.name = name
36 | assert transformation is not None, 'transformation must be provided, given None'
37 | self.transformation = transformation
38 | self.augmentation = augmentation
39 | assert name in ('label', 'unlabel','val', 'test'), 'dataset name should be restricted in "label", "unlabel", "test" and "val", given %s' % name
40 | assert 0 <= ratio <= 1, 'the ratio between "labeled" and "unlabeled" should be between 0 and 1, given %.1f' % ratio
41 | np.random.seed(1) ### Because of this we are not getting repeated images for labelled and unlabelled data
42 |
43 | if self.name != 'test':
44 | train_imgs = pd.read_table(os.path.join(self.root_path, 'ImageSets/Segmentation', 'trainvalAug.txt')).values.reshape(-1)
45 |
46 | val_imgs = pd.read_table(os.path.join(self.root_path, 'ImageSets/Segmentation', 'val.txt')).values.reshape(-1)
47 |
48 | labeled_imgs = np.random.choice(train_imgs, size=int(self.ratio * train_imgs.__len__()), replace=False)
49 | labeled_imgs = list(labeled_imgs)
50 |
51 | unlabeled_imgs = [x for x in train_imgs if x not in labeled_imgs]
52 |
53 | ### Now here we equalize the lengths of labelled and unlabelled imgs by just repeating up some images
54 | if self.ratio > 0.5:
55 | new_ratio = round((self.ratio/(1-self.ratio + 1e-6)), 1)
56 | excess_ratio = new_ratio - 1
57 | new_list_1 = unlabeled_imgs * int(excess_ratio)
58 | new_list_2 = list(np.random.choice(np.array(unlabeled_imgs), size=int((excess_ratio - int(excess_ratio))*unlabeled_imgs.__len__()), replace=False))
59 | unlabeled_imgs += (new_list_1 + new_list_2)
60 | elif self.ratio < 0.5:
61 | new_ratio = round(((1-self.ratio)/(self.ratio + 1e-6)), 1)
62 | excess_ratio = new_ratio - 1
63 | new_list_1 = labeled_imgs * int(excess_ratio)
64 | new_list_2 = list(np.random.choice(np.array(labeled_imgs), size=int((excess_ratio - int(excess_ratio))*labeled_imgs.__len__()), replace=False))
65 | labeled_imgs += (new_list_1 + new_list_2)
66 |
67 | if self.name == 'test':
68 | test_imgs = pd.read_table(os.path.join(self.root_path, 'ImageSets/Segmentation', 'test.txt')).values.reshape(-1)
69 | # test_imgs = np.array(test_imgs)
70 |
71 | if self.name == 'label':
72 | self.imgs = labeled_imgs
73 | elif self.name == 'unlabel':
74 | self.imgs = unlabeled_imgs
75 | elif self.name == 'val':
76 | self.imgs = val_imgs
77 | elif self.name == 'test':
78 | self.imgs = test_imgs
79 | else:
80 | raise ValueError('{} not defined'.format(self.name))
81 |
82 | self.gts = self.imgs
83 |
84 | def __getitem__(self, index):
85 |
86 | if self.name == 'test':
87 | img_path = os.path.join(self.root_path, 'JPEGImages', self.imgs[index] + '.jpg')
88 |
89 | img = Image.open(img_path).convert('RGB')
90 |
91 | if self.augmentation is not None:
92 | img = self.augmentation(img)
93 |
94 | if self.transformation:
95 | img = self.transformation['img'](img)
96 |
97 | return img, self.imgs[index]
98 |
99 | else:
100 | img_path = os.path.join(self.root_path, 'JPEGImages', self.imgs[index] + '.jpg')
101 | gt_path = os.path.join(self.root_path, 'SegmentationClassAug', self.gts[index] + '.png')
102 |
103 | img = Image.open(img_path).convert('RGB')
104 | gt = Image.open(gt_path) #.convert('P')
105 |
106 | if self.augmentation is not None:
107 | img, gt = self.augmentation(img, gt)
108 |
109 | if self.transformation:
110 | img = self.transformation['img'](img)
111 | gt = self.transformation['gt'](gt)
112 |
113 | return img, gt, self.imgs[index]
114 |
115 | def __len__(self):
116 | return len(self.imgs)
117 |
118 |
119 | class CityscapesDataset(Dataset):
120 |
121 | def __init__(
122 | self,
123 | root_path,
124 | name="train",
125 | ratio=0.5,
126 | transformation=None,
127 | augmentation=None
128 | ):
129 | self.root = root_path
130 | self.name = name
131 | assert transformation is not None, 'transformation must be provided, given None'
132 | self.transformation = transformation
133 | self.augmentation = augmentation
134 | self.n_classes = 20
135 | self.ratio = ratio
136 | self.files = {}
137 |
138 | assert name in ('label', 'unlabel','val', 'test'), 'dataset name should be restricted in "label", "unlabel", "test" and "val", given %s' % name
139 |
140 | if self.name != 'test':
141 | self.images_base = os.path.join(self.root, "leftImg8bit")
142 | self.annotations_base = os.path.join(
143 | self.root, "gtFine", 'trainval'
144 | )
145 | else:
146 | self.images_base = os.path.join(self.root, "leftImg8bit", 'test')
147 | # self.annotations_base = os.path.join(
148 | # self.root, "gtFine", 'test'
149 | # )
150 |
151 | np.random.seed(1)
152 |
153 | if self.name != 'test':
154 | train_imgs = recursive_glob(rootdir=os.path.join(self.images_base, 'train'), suffix=".png")
155 | train_imgs = np.array(train_imgs)
156 |
157 | val_imgs = recursive_glob(rootdir=os.path.join(self.images_base, 'val'), suffix=".png")
158 |
159 | labeled_imgs = np.random.choice(train_imgs, size=int(self.ratio * train_imgs.__len__()), replace=False)
160 | labeled_imgs = list(labeled_imgs)
161 |
162 | unlabeled_imgs = [x for x in train_imgs if x not in labeled_imgs]
163 |
164 | ### Now here we equalize the lengths of labelled and unlabelled imgs by just repeating up some images
165 | if self.ratio > 0.5:
166 | new_ratio = round((self.ratio/(1-self.ratio + 1e-6)), 1)
167 | excess_ratio = new_ratio - 1
168 | new_list_1 = unlabeled_imgs * int(excess_ratio)
169 | new_list_2 = list(np.random.choice(np.array(unlabeled_imgs), size=int((excess_ratio - int(excess_ratio))*unlabeled_imgs.__len__()), replace=False))
170 | unlabeled_imgs += (new_list_1 + new_list_2)
171 | elif self.ratio < 0.5:
172 | new_ratio = round(((1-self.ratio)/(self.ratio + 1e-6)), 1)
173 | excess_ratio = new_ratio - 1
174 | new_list_1 = labeled_imgs * int(excess_ratio)
175 | new_list_2 = list(np.random.choice(np.array(labeled_imgs), size=int((excess_ratio - int(excess_ratio))*labeled_imgs.__len__()), replace=False))
176 | labeled_imgs += (new_list_1 + new_list_2)
177 |
178 | else:
179 | test_imgs = recursive_glob(rootdir=self.images_base, suffix=".png")
180 |
181 | if(self.name == 'label'):
182 | self.files[name] = list(labeled_imgs)
183 | elif(self.name == 'unlabel'):
184 | self.files[name] = list(unlabeled_imgs)
185 | elif(self.name == 'val'):
186 | self.files[name] = list(val_imgs)
187 | elif(self.name == 'test'):
188 | self.files[name] = list(test_imgs)
189 |
190 | '''
191 | This pattern for the various classes has been borrowed from the official repo for Cityscapes dataset
192 | You can see it here: https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py
193 | '''
194 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
195 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
196 | self.class_names = [
197 | "road",
198 | "sidewalk",
199 | "building",
200 | "wall",
201 | "fence",
202 | "pole",
203 | "traffic_light",
204 | "traffic_sign",
205 | "vegetation",
206 | "terrain",
207 | "sky",
208 | "person",
209 | "rider",
210 | "car",
211 | "truck",
212 | "bus",
213 | "train",
214 | "motorcycle",
215 | "bicycle",
216 | "unlabelled"
217 | ]
218 |
219 | self.ignore_index = 250
220 | self.class_map = dict(zip(self.valid_classes, range(19)))
221 |
222 | if not self.files[name]:
223 | raise Exception(
224 | "No files for name=[%s] found in %s" % (name, self.images_base)
225 | )
226 |
227 | print("Found %d %s images" % (len(self.files[name]), name))
228 |
229 | def __len__(self):
230 | """__len__"""
231 | return len(self.files[self.name])
232 |
233 | def __getitem__(self, index):
234 | """__getitem__
235 | :param index:
236 | """
237 | if self.name == 'test':
238 | img_path = self.files[self.name][index].rstrip()
239 |
240 | img = Image.open(img_path).convert('RGB')
241 |
242 | if self.augmentation is not None:
243 | img = self.augmentation(img)
244 |
245 | if self.transformation:
246 | img = self.transformation['img'](img)
247 |
248 | return img, re.sub(r'.*/', '', img_path[34:]).rstrip('.png') ### These numbers have been hard coded so as to get a suitable name for the model
249 |
250 | else:
251 | img_path = self.files[self.name][index].rstrip()
252 | lbl_path = os.path.join(
253 | self.annotations_base,
254 | img_path.split(os.sep)[-2],
255 | os.path.basename(img_path)[:-15] + "gtFine_labelIds.png",
256 | )
257 |
258 | img = Image.open(img_path).convert('RGB')
259 | lbl = Image.open(lbl_path)
260 |
261 | if self.augmentation is not None:
262 | img, lbl = self.augmentation(img, lbl)
263 |
264 | if self.transformation:
265 | img = self.transformation['img'](img)
266 | lbl = self.transformation['gt'](lbl)
267 |
268 | lbl = self.encode_segmap(lbl)
269 |
270 | return img, lbl, re.sub(r'.*/', '', img_path[38:]).rstrip('.png') ### These numbers have been hard coded so as to get a suitable name for the model
271 |
272 | def encode_segmap(self, mask):
273 | # Put all void classes to ignore index
274 | for _voidc in self.void_classes:
275 | mask[mask == _voidc] = self.ignore_index
276 | for _validc in self.valid_classes:
277 | mask[mask == _validc] = self.class_map[_validc]
278 | mask[mask == self.ignore_index] = 19 ### Just a mapping between the two color values
279 | return mask
280 |
281 |
282 | class ACDCDataset(Dataset):
283 | '''
284 | The dataloader for ACDC dataset
285 | '''
286 |
287 | split_ratio = [0.85, 0.15]
288 | '''
289 | this split ratio is for the train (including the labeled and unlabeled) and the val dataset
290 | '''
291 |
292 | def __init__(self, root_path, name='label', ratio=0.5, transformation=None, augmentation=None):
293 | super(ACDCDataset, self).__init__()
294 | self.root = root_path
295 | self.name = name
296 | assert transformation is not None, 'transformation must be provided, given None'
297 | self.transformation = transformation
298 | self.augmentation = augmentation
299 | self.ratio = ratio
300 | self.files = {}
301 |
302 | if self.name != 'test':
303 | self.images_base = os.path.join(self.root, 'training')
304 | self.annotations_base = os.path.join(self.root, 'training_gt')
305 | else:
306 | self.images_base = os.path.join(self.root, 'testing')
307 | # self.annotations_base = os.path.join(
308 | # self.root, "gtFine", 'test'
309 | # )
310 |
311 | np.random.seed(1)
312 |
313 | if self.name != 'test':
314 | total_imgs = os.listdir(self.images_base)
315 | total_imgs = np.array(total_imgs)
316 |
317 | train_imgs = np.random.choice(total_imgs, size=int(self.__class__.split_ratio[0] * total_imgs.__len__()),replace=False)
318 | val_imgs = [x for x in total_imgs if x not in train_imgs]
319 |
320 | labeled_imgs = np.random.choice(train_imgs, size=int(self.ratio * train_imgs.__len__()), replace=False)
321 | labeled_imgs = list(labeled_imgs)
322 |
323 | unlabeled_imgs = [x for x in train_imgs if x not in labeled_imgs]
324 |
325 | ### Now here we equalize the lengths of labelled and unlabelled imgs by just repeating up some images
326 | if self.ratio > 0.5:
327 | new_ratio = round((self.ratio/(1-self.ratio + 1e-6)), 1)
328 | excess_ratio = new_ratio - 1
329 | new_list_1 = unlabeled_imgs * int(excess_ratio)
330 | new_list_2 = list(np.random.choice(np.array(unlabeled_imgs), size=int((excess_ratio - int(excess_ratio))*unlabeled_imgs.__len__()), replace=False))
331 | unlabeled_imgs += (new_list_1 + new_list_2)
332 | elif self.ratio < 0.5:
333 | new_ratio = round(((1-self.ratio)/(self.ratio + 1e-6)), 1)
334 | excess_ratio = new_ratio - 1
335 | new_list_1 = labeled_imgs * int(excess_ratio)
336 | new_list_2 = list(np.random.choice(np.array(labeled_imgs), size=int((excess_ratio - int(excess_ratio))*labeled_imgs.__len__()), replace=False))
337 | labeled_imgs += (new_list_1 + new_list_2)
338 |
339 | else:
340 | test_imgs = os.listdir(self.images_base)
341 |
342 | if(self.name == 'label'):
343 | self.files[name] = list(labeled_imgs)
344 | elif(self.name == 'unlabel'):
345 | self.files[name] = list(unlabeled_imgs)
346 | elif(self.name == 'val'):
347 | self.files[name] = list(val_imgs)
348 | elif(self.name == 'test'):
349 | self.files[name] = list(test_imgs)
350 |
351 | if not self.files[name]:
352 | raise Exception(
353 | "No files for name=[%s] found in %s" % (name, self.images_base)
354 | )
355 |
356 | print("Found %d %s images" % (len(self.files[name]), name))
357 |
358 | def __len__(self):
359 | """__len__"""
360 | return len(self.files[self.name])
361 |
362 | def __getitem__(self, index):
363 | """__getitem__
364 | :param index:
365 | """
366 | if self.name == 'test':
367 | img_path = os.path.join(self.images_base, self.files[self.name][index])
368 |
369 | img = Image.open(img_path)
370 |
371 | if self.augmentation is not None:
372 | img = self.augmentation(img)
373 |
374 | if self.transformation:
375 | img = self.transformation['img'](img)
376 |
377 |             return img, os.path.splitext(self.files[self.name][index])[0]
378 |
379 | else:
380 | img_path = os.path.join(self.images_base, self.files[self.name][index])
381 |             lbl_path = os.path.join(self.annotations_base, os.path.splitext(self.files[self.name][index])[0] + '.png')
382 |
383 | img = Image.open(img_path)
384 | lbl = Image.open(lbl_path)
385 |
386 | if self.augmentation is not None:
387 | img, lbl = self.augmentation(img, lbl)
388 |
389 | if self.transformation:
390 | img = self.transformation['img'](img)
391 | lbl = self.transformation['gt'](lbl)
392 |
393 |             return img, lbl, os.path.splitext(self.files[self.name][index])[0]
394 |
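# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# Wiring the labeled ACDC split into a DataLoader. The transformation dict needs
# 'img' and 'gt' callables, as expected by __getitem__ above; in the project itself
# this dict comes from data_utils.get_transformation, so the transforms below are
# only stand-ins.
if __name__ == '__main__':
    import torch
    from torch.utils.data import DataLoader
    import torchvision.transforms as T

    transformation = {
        'img': T.Compose([T.Resize((128, 128)), T.ToTensor()]),                                      # image -> float tensor
        'gt': lambda m: torch.from_numpy(np.array(m.resize((128, 128), Image.NEAREST), dtype='int64')),  # mask -> int tensor
    }
    labeled_set = ACDCDataset('./data/ACDC', name='label', ratio=0.5,
                              transformation=transformation)
    labeled_loader = DataLoader(labeled_set, batch_size=4, shuffle=True)
    img, gt, img_name = next(iter(labeled_loader))    # one labeled batch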
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import shutil
4 |
5 | import numpy as np
6 | import torch.nn as nn
7 | import torch
8 | from torch.autograd import Variable
9 | from PIL import Image
10 | from torchvision import models
11 | from collections import namedtuple
12 | from torchvision import transforms
13 |
14 |
15 | '''
16 | The palettes below convert a single-channel segmentation map of class indices back to a paletted (color) image
17 | '''
18 | palette = [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 128,
19 | 128, 128, 128, 64, 0, 0, 192, 0, 0, 64, 128, 0, 192, 128, 0, 64, 0, 128, 192, 0, 128,
20 | 64, 128, 128, 192, 128, 128, 0, 64, 0, 128, 64, 0, 0, 192, 0, 128, 192, 0, 0, 64, 128]
21 |
22 | cityscape_palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153,
23 | 250, 170, 30, 220, 220, 0, 107, 142, 35, 152, 251, 152, 0, 130, 180, 220, 20, 60,
24 | 255, 0, 0, 0, 0, 142, 0, 0, 70, 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
25 |
26 | acdc_palette = [0, 0, 0, 128, 64, 128, 70, 70, 70, 250, 170, 30]
27 |
28 | zero_pad = 256 * 3 - len(palette)
29 | for i in range(zero_pad):
30 | palette.append(0)
31 |
32 | zero_pad = 256 * 3 - len(cityscape_palette)
33 | for i in range(zero_pad):
34 | cityscape_palette.append(0)
35 |
36 | zero_pad = 256 * 3 - len(acdc_palette)
37 | for i in range(zero_pad):
38 | acdc_palette.append(0)
39 |
40 |
41 | def colorize_mask(mask, dataset):
42 | '''
43 |     Converts a one-channel segmentation mask back to a paletted (color) image
44 | '''
45 | # mask: numpy array of the mask
46 | assert dataset in ('voc2012', 'cityscapes', 'acdc')
47 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
48 | if (dataset == 'voc2012'):
49 | new_mask.putpalette(palette)
50 | elif (dataset == 'cityscapes'):
51 | new_mask.putpalette(cityscape_palette)
52 | elif (dataset == 'acdc'):
53 | new_mask.putpalette(acdc_palette)
54 |
55 | return new_mask
56 |
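# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# colorize_mask expects a 2-D numpy array of class indices and returns a paletted
# PIL image; e.g. for a random Cityscapes-style prediction map:
if __name__ == '__main__':
    dummy_pred = np.random.randint(0, 20, size=(256, 512))   # H x W map of Cityscapes train ids
    color_mask = colorize_mask(dummy_pred, 'cityscapes')      # paletted PIL image
    color_mask.save('dummy_pred_color.png')                   # viewable color mask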
57 | ### Converts a paletted image into a 3-channel tensor image
58 | ### (a single-channel paletted image cannot be displayed directly with its colors)
59 | def PIL_to_tensor(img, dataset):
60 | '''
61 | Here img is of the type PIL.Image
62 | '''
63 | assert dataset in ('voc2012', 'cityscapes', 'acdc')
64 | img_arr = np.array(img, dtype='float32')
65 | new_arr = np.zeros([3, img_arr.shape[0], img_arr.shape[1]], dtype='float32')
66 |
67 | if (dataset == 'voc2012'):
68 | for i in range(img_arr.shape[0]):
69 | for j in range(img_arr.shape[1]):
70 | # new_arr[i, :, :] = img_arr
71 | index = int(img_arr[i, j]*3)
72 | new_arr[0, i, j] = palette[index]
73 | new_arr[1, i, j] = palette[index+1]
74 | new_arr[2, i, j] = palette[index+2]
75 | elif (dataset == 'cityscapes'):
76 | for i in range(img_arr.shape[0]):
77 | for j in range(img_arr.shape[1]):
78 | # new_arr[i, :, :] = img_arr
79 | index = int(img_arr[i, j]*3)
80 | new_arr[0, i, j] = cityscape_palette[index]
81 | new_arr[1, i, j] = cityscape_palette[index+1]
82 | new_arr[2, i, j] = cityscape_palette[index+2]
83 | elif (dataset == 'acdc'):
84 | for i in range(img_arr.shape[0]):
85 | for j in range(img_arr.shape[1]):
86 | # new_arr[i, :, :] = img_arr
87 | index = int(img_arr[i, j]*3)
88 | new_arr[0, i, j] = acdc_palette[index]
89 | new_arr[1, i, j] = acdc_palette[index+1]
90 | new_arr[2, i, j] = acdc_palette[index+2]
91 |
92 | return_tensor = torch.tensor(new_arr)
93 |
94 | return return_tensor
95 |
96 | def smoothen_label(label, alpha, gpu_id):
97 | '''
98 |     Applies label smoothing to the classification labels
99 |
100 |     label : tensor of shape batch_size x num_classes x H x W filled with zeros and ones
101 | '''
102 | torch.manual_seed(0)
103 | try:
104 | smoothen_array = -1*alpha + torch.rand([label.shape[0], label.shape[1], label.shape[2], label.shape[3]]) * (2*alpha)
105 | smoothen_array = cuda(smoothen_array, gpu_id)
106 | label = label + smoothen_array
107 | except:
108 | smoothen_array = -1*alpha + torch.rand([label.shape[0], label.shape[1], label.shape[2], label.shape[3]]) * (2*alpha)
109 | label = label + smoothen_array
110 |
111 | return label
112 |
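# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# smoothen_label perturbs hard 0/1 one-hot targets by a uniform value in
# [-alpha, +alpha], i.e. simple label smoothing for the GAN targets. gpu_id follows
# the same list-of-device-ids convention used elsewhere in this file.
if __name__ == '__main__':
    hard = torch.zeros(2, 4, 8, 8)
    hard[:, 0] = 1.0                               # every pixel labeled as class 0
    soft = smoothen_label(hard, alpha=0.1, gpu_id=[0])
    assert soft.shape == hard.shape                # values now lie roughly in [-0.1, 1.1]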
113 | '''
114 | Used to apply Gaussian noise to the input of the discriminator
115 | '''
116 | class GaussianNoise(nn.Module):
117 | """Gaussian noise regularizer.
118 |
119 | Args:
120 | sigma (float, optional): relative standard deviation used to generate the
121 | noise. Relative means that it will be multiplied by the magnitude of
122 |             the value you are adding the noise to. This means that sigma can be
123 | the same regardless of the scale of the vector.
124 | is_relative_detach (bool, optional): whether to detach the variable before
125 | computing the scale of the noise. If `False` then the scale of the noise
126 | won't be seen as a constant but something to optimize: this will bias the
127 | network to generate vectors with smaller values.
128 | """
129 |
130 | def __init__(self, sigma=0.1, is_relative_detach=True):
131 | super().__init__()
132 | self.sigma = sigma
133 | self.is_relative_detach = is_relative_detach
134 |
135 | def forward(self, x):
136 | if self.training and self.sigma != 0:
137 | scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
138 |             sampled_noise = torch.randn_like(x) * scale  # sample noise on the same device/dtype as x
139 | x = x + sampled_noise
140 | return x
141 |
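# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# The noise layer is a no-op in eval mode and adds input-scaled Gaussian noise in
# train mode; it is meant to sit in front of a discriminator.
if __name__ == '__main__':
    noisy = GaussianNoise(sigma=0.1)
    x = torch.ones(1, 3, 4, 4)
    noisy.train()
    assert not torch.equal(noisy(x), x)   # noise added during training
    noisy.eval()
    assert torch.equal(noisy(x), x)       # identity at evaluation time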
142 | '''
143 | This will be used for calculation of perceptual losses
144 | '''
145 | class Vgg16(torch.nn.Module):
146 | def __init__(self, requires_grad=False):
147 | super(Vgg16, self).__init__()
148 | vgg_pretrained_features = models.vgg16(pretrained=True).features
149 | self.slice1 = torch.nn.Sequential()
150 | self.slice2 = torch.nn.Sequential()
151 | self.slice3 = torch.nn.Sequential()
152 | self.slice4 = torch.nn.Sequential()
153 | for x in range(4):
154 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
155 | for x in range(4, 9):
156 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
157 | for x in range(9, 16):
158 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
159 | for x in range(16, 23):
160 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
161 | if not requires_grad:
162 | for param in self.parameters():
163 | param.requires_grad = False
164 |
165 | def forward(self, X):
166 | h = self.slice1(X)
167 | h_relu1_2 = h
168 | h = self.slice2(h)
169 | h_relu2_2 = h
170 | h = self.slice3(h)
171 | h_relu3_3 = h
172 | h = self.slice4(h)
173 | h_relu4_3 = h
174 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
175 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
176 | return out
177 |
178 | '''
179 | The definition of the perceptual loss
180 | '''
181 | def perceptual_loss(x, y, gpu_ids):
182 | """
183 | Calculates the perceptual loss on the basis of the VGG network
184 | Parameters:
185 | x, y: the images between which perceptual loss is to be calculated
186 | """
187 |
188 |     ### x and y are images in the range -1 to 1; map them to [0, 1] and apply the
189 |     ### ImageNet normalization expected by the pretrained VGG
190 |
191 |     u = x*0.5 + 0.5
192 |     v = y*0.5 + 0.5
193 |     trans_mean = [0.485, 0.456, 0.406]
194 |     trans_std = [0.229, 0.224, 0.225]
195 |
196 |     for i in range(3):
197 |         u[:, i, :, :] = (u[:, i, :, :] - trans_mean[i]) / trans_std[i]
198 |         v[:, i, :, :] = (v[:, i, :, :] - trans_mean[i]) / trans_std[i]
199 |
200 |     ### u and v are now normalized the same way as the VGG training data
201 |
202 | vgg = Vgg16(requires_grad=False).cuda(gpu_ids[0])
203 |
204 | features_y = vgg(v)
205 | features_x = vgg(u)
206 |
207 | mse_loss = nn.MSELoss()
208 | loss = mse_loss(features_y.relu2_2, features_x.relu2_2)
209 |
210 | return loss
211 |
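# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# perceptual_loss compares VGG relu2_2 features of two image batches in [-1, 1].
# It moves a freshly built VGG16 to gpu_ids[0], so a CUDA device and the
# torchvision VGG16 weights are required to actually run it.
if __name__ == '__main__' and torch.cuda.is_available():
    fake = torch.tanh(torch.randn(2, 3, 128, 128)).cuda(0)   # stand-in generator output
    real = torch.tanh(torch.randn(2, 3, 128, 128)).cuda(0)   # stand-in real images
    loss = perceptual_loss(fake, real, gpu_ids=[0])
    print('perceptual loss:', loss.item())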
212 |
213 | # To make directories
214 | def mkdir(paths):
215 | for path in paths:
216 | if not os.path.isdir(path):
217 | os.makedirs(path)
218 |
219 |
220 | # To make cuda tensor
221 | def cuda(xs, gpu_id):
222 | if torch.cuda.is_available():
223 | if not isinstance(xs, (list, tuple)):
224 | return xs.cuda(int(gpu_id[0]))
225 | else:
226 | return [x.cuda(int(gpu_id[0])) for x in xs]
227 | return xs
228 |
229 |
230 | # For Pytorch datasets loader
231 | def create_link(dataset_dir):
232 | dirs = {}
233 | dirs['trainA'] = os.path.join(dataset_dir, 'ltrainA')
234 | dirs['trainB'] = os.path.join(dataset_dir, 'ltrainB')
235 | dirs['testA'] = os.path.join(dataset_dir, 'ltestA')
236 | dirs['testB'] = os.path.join(dataset_dir, 'ltestB')
237 | mkdir(dirs.values())
238 |
239 | for key in dirs:
240 | try:
241 | os.remove(os.path.join(dirs[key], 'Link'))
242 | except:
243 | pass
244 | os.symlink(os.path.abspath(os.path.join(dataset_dir, key)),
245 | os.path.join(dirs[key], 'Link'))
246 |
247 | return dirs
248 |
249 |
250 | def get_traindata_link(dataset_dir):
251 | dirs = {}
252 | dirs['trainA'] = os.path.join(dataset_dir, 'ltrainA')
253 | dirs['trainB'] = os.path.join(dataset_dir, 'ltrainB')
254 | return dirs
255 |
256 |
257 | def get_testdata_link(dataset_dir):
258 | dirs = {}
259 | dirs['testA'] = os.path.join(dataset_dir, 'ltestA')
260 | dirs['testB'] = os.path.join(dataset_dir, 'ltestB')
261 | return dirs
262 |
263 |
264 | # To save the checkpoint
265 | def save_checkpoint(state, save_path):
266 | torch.save(state, save_path)
267 |
268 |
269 | # To load the checkpoint
270 | def load_checkpoint(ckpt_path, map_location='cpu'):
271 | ckpt = torch.load(ckpt_path, map_location=map_location)
272 | print(' [*] Loading checkpoint from %s succeed!' % ckpt_path)
273 | return ckpt
274 |
275 |
276 | # To store up to 50 generated images in a pool and sample from it once it is full
277 | # (the image-buffer strategy of Shrivastava et al.)
278 | class Sample_from_Pool(object):
279 | def __init__(self, max_elements=50):
280 | self.max_elements = max_elements
281 | self.cur_elements = 0
282 | self.items = []
283 |
284 | def __call__(self, in_items):
285 | return_items = []
286 | for in_item in in_items:
287 | if self.cur_elements < self.max_elements:
288 | self.items.append(in_item)
289 | self.cur_elements = self.cur_elements + 1
290 | return_items.append(in_item)
291 | else:
292 | if np.random.ranf() > 0.5:
293 | idx = np.random.randint(0, self.max_elements)
294 | tmp = copy.copy(self.items[idx])
295 | self.items[idx] = in_item
296 | return_items.append(tmp)
297 | else:
298 | return_items.append(in_item)
299 | return return_items
300 |
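# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# The pool stores up to max_elements previously generated samples and, once full,
# returns either the fresh sample or a random historical one with probability 0.5,
# which stabilizes discriminator training.
if __name__ == '__main__':
    pool = Sample_from_Pool(max_elements=50)
    for step in range(100):
        fake_batch = [torch.randn(3, 8, 8)]     # pretend generator output
        fake_for_dis = pool(fake_batch)          # list of tensors fed to the discriminator
    assert len(pool.items) == 50                 # pool never grows beyond max_elements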
301 |
302 | def recursive_glob(rootdir=".", suffix=""):
303 | """Performs recursive glob with given suffix and rootdir
304 | :param rootdir is the root directory
305 | :param suffix is the suffix to be searched
306 | """
307 | return [
308 | os.path.join(looproot, filename)
309 | for looproot, _, filenames in os.walk(rootdir)
310 | for filename in filenames
311 | if filename.endswith(suffix)
312 | ]
313 |
314 | def make_one_hot(labels, dataname, gpu_id):
315 |     '''
316 |     Converts an integer label tensor to a one-hot encoded tensor.
317 |
318 |     Parameters
319 |     ----------
320 |     labels : torch.LongTensor (optionally on GPU)
321 |         N x 1 x H x W, where N is batch size.
322 |         Each value is an integer representing the correct class.
323 |     dataname : str, one of 'voc2012', 'cityscapes', 'acdc'.
324 |         Determines the number of classes C; gpu_id gives the target device ids.
325 |
326 |     Returns
327 |     -------
328 |     target : torch.FloatTensor (optionally on GPU)
329 |         N x C x H x W, where C is the number of classes. One-hot encoded.
330 |     '''
331 |     assert dataname in ('voc2012', 'cityscapes', 'acdc'), 'dataset name should be one of: voc2012, cityscapes, acdc; given {}'.format(dataname)
332 |
333 | if dataname == 'voc2012':
334 | C = 21
335 | elif dataname == 'cityscapes':
336 | C = 20
337 | elif dataname == 'acdc':
338 | C = 4
339 | else:
340 | raise NotImplementedError
341 |
342 | labels = labels.long()
343 | try:
344 | one_hot = torch.FloatTensor(labels.size(0), C, labels.size(2), labels.size(3)).zero_()
345 | one_hot = cuda(one_hot, gpu_id)
346 | except:
347 | one_hot = torch.FloatTensor(labels.size(0), C, labels.size(2), labels.size(3)).zero_()
348 | target = one_hot.scatter_(1, labels.data, 1)
349 |
350 | return target
351 |
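# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# make_one_hot scatters integer labels of shape N x 1 x H x W into an N x C x H x W
# one-hot tensor, with C chosen from the dataset name.
if __name__ == '__main__':
    labels = cuda(torch.randint(0, 4, (2, 1, 8, 8)), [0])   # ACDC has 4 classes
    one_hot = make_one_hot(labels, 'acdc', gpu_id=[0])
    assert one_hot.shape == (2, 4, 8, 8)
    assert torch.all(one_hot.sum(dim=1) == 1)               # exactly one class per pixel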
352 |
353 | # Adapted from score written by wkentaro
354 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
355 |
356 |
357 | class runningScore(object):
358 | def __init__(self, n_classes, dataset):
359 | self.n_classes = n_classes
360 | self.confusion_matrix = np.zeros((n_classes, n_classes))
361 | self.dataset = dataset
362 |
363 | def _fast_hist(self, label_true, label_pred, n_class):
364 | mask = (label_true >= 0) & (label_true < n_class)
365 | hist = np.bincount(
366 | n_class * label_true[mask].astype(int) + label_pred[mask],
367 | minlength=n_class ** 2,
368 | ).reshape(n_class, n_class)
369 | return hist
370 |
371 | def update(self, label_trues, label_preds):
372 | for lt, lp in zip(label_trues, label_preds):
373 | self.confusion_matrix += self._fast_hist(
374 | lt.flatten(), lp.flatten(), self.n_classes
375 | )
376 |
377 | def get_scores(self):
378 | """Returns accuracy score evaluation result.
379 | - overall accuracy
380 | - mean accuracy
381 | - mean IU
382 | - fwavacc
383 | """
384 | hist = self.confusion_matrix
385 | acc = np.diag(hist).sum() / hist.sum()
386 | acc_cls = np.diag(hist) / hist.sum(axis=1)
387 | acc_cls = np.nanmean(acc_cls)
388 | if self.dataset == 'voc2012':
389 | iu = np.diag(hist[1:self.n_classes, 1:self.n_classes]) / (hist[1:self.n_classes, 1:self.n_classes].sum(axis=1) + hist[1:self.n_classes, 1:self.n_classes].sum(axis=0) - np.diag(hist[1:self.n_classes, 1:self.n_classes]))
390 | elif self.dataset == 'cityscapes':
391 | iu = np.diag(hist[0:self.n_classes-1, 0:self.n_classes-1]) / (hist[0:self.n_classes-1, 0:self.n_classes-1].sum(axis=1) + hist[0:self.n_classes-1, 0:self.n_classes-1].sum(axis=0) - np.diag(hist[0:self.n_classes-1, 0:self.n_classes-1]))
392 | elif self.dataset == 'acdc':
393 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
394 | mean_iu = np.nanmean(iu)
395 | # freq = hist.sum(axis=1) / hist.sum()
396 | # fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
397 | if self.dataset == 'voc2012' or self.dataset == 'cityscapes':
398 | cls_iu = dict(zip(range(self.n_classes-1), iu))
399 | elif self.dataset == 'acdc':
400 | cls_iu = dict(zip(range(self.n_classes), iu))
401 |
402 | return (
403 | {
404 | "Overall Acc: \t": acc,
405 | "Mean Acc : \t": acc_cls,
406 | "Mean IoU : \t": mean_iu,
407 | },
408 | cls_iu,
409 | )
410 |
411 | def reset(self):
412 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
413 |
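# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# runningScore accumulates a confusion matrix over (ground-truth, prediction) pairs
# of integer label maps and reports overall accuracy, mean accuracy and mean IoU;
# for 'acdc' the IoU is averaged over all n_classes.
if __name__ == '__main__':
    scorer = runningScore(n_classes=4, dataset='acdc')
    gt = np.random.randint(0, 4, (2, 8, 8))      # stand-in ground-truth label maps
    pred = np.random.randint(0, 4, (2, 8, 8))    # stand-in predicted label maps
    scorer.update(gt, pred)
    overall, per_class_iou = scorer.get_scores()
    print(overall, per_class_iou)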
414 |
415 | class averageMeter(object):
416 | """Computes and stores the average and current value"""
417 |
418 | def __init__(self):
419 | self.reset()
420 |
421 | def reset(self):
422 | self.val = 0
423 | self.avg = 0
424 | self.sum = 0
425 | self.count = 0
426 |
427 | def update(self, val, n=1):
428 | self.val = val
429 | self.sum += val * n
430 | self.count += n
431 | self.avg = self.sum / self.count
432 |
433 |
434 | class LambdaLR():
435 | def __init__(self, epochs, offset, decay_epoch):
436 | self.epochs = epochs
437 | self.offset = offset
438 | self.decay_epoch = decay_epoch
439 |
440 | def step(self, epoch):
441 | return 1.0 - max(0, epoch + self.offset - self.decay_epoch)/(self.epochs - self.decay_epoch)
442 |
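# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# LambdaLR.step returns the multiplicative LR factor: 1.0 until decay_epoch, then a
# linear decay down to 0 at the final epoch. It plugs directly into
# torch.optim.lr_scheduler.LambdaLR as the lr_lambda callable.
if __name__ == '__main__':
    dummy_params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(dummy_params, lr=2e-4)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=LambdaLR(epochs=200, offset=0, decay_epoch=100).step)
    for epoch in range(200):
        optimizer.step()       # training step would go here
        scheduler.step()       # LR stays constant for 100 epochs, then decays linearly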
443 | def print_networks(nets, names):
444 | print('------------Number of Parameters---------------')
445 | i=0
446 | for net in nets:
447 | num_params = 0
448 | for param in net.parameters():
449 | num_params += param.numel()
450 | print('[Network %s] Total number of parameters : %.3f M' % (names[i], num_params / 1e6))
451 | i=i+1
452 | print('-----------------------------------------------')
453 |
--------------------------------------------------------------------------------
/arch/generators.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | from torch import nn
4 | from .ops import conv_norm_relu, dconv_norm_relu, ResidualBlock, get_norm_layer, init_network, InitialBlock, RegularBottleneck, UpsamplingBottleneck, DownsamplingBottleneck, SSnbt, Downsample_Block_led, APN
5 |
6 |
7 | class UnetSkipConnectionBlock(nn.Module):
8 | def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False,
9 | innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
10 | super(UnetSkipConnectionBlock, self).__init__()
11 | self.outermost = outermost
12 | if type(norm_layer) == functools.partial:
13 | use_bias = norm_layer.func == nn.InstanceNorm2d
14 | else:
15 | use_bias = norm_layer == nn.InstanceNorm2d
16 | if input_nc is None:
17 | input_nc = outer_nc
18 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
19 |
20 | if outermost:
21 | upconv = nn.ConvTranspose2d(inner_nc*2, outer_nc, kernel_size=4, stride=2, padding=1)
22 | down = [downconv]
23 | up = [nn.ReLU(True), upconv]
24 | model = down + [submodule] + up
25 | elif innermost:
26 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
27 | down = [nn.LeakyReLU(0.2, True), downconv]
28 | up = [nn.ReLU(True), upconv, norm_layer(outer_nc)]
29 | model = down + up
30 | else:
31 | upconv = nn.ConvTranspose2d(inner_nc*2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
32 | down = [nn.LeakyReLU(0.2, True), downconv, norm_layer(inner_nc)]
33 | up = [nn.ReLU(True), upconv, norm_layer(outer_nc)]
34 |
35 | if use_dropout:
36 | model = down + [submodule] + up + [nn.Dropout(0.5)]
37 | else:
38 | model = down + [submodule] + up
39 |
40 | self.model = nn.Sequential(*model)
41 |
42 | def forward(self, x):
43 | if self.outermost:
44 | return self.model(x)
45 | else:
46 | return torch.cat([x, self.model(x)], 1)
47 |
48 | class UnetGenerator(nn.Module):
49 | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
50 | norm_layer=nn.BatchNorm2d, use_dropout=False):
51 | super(UnetGenerator, self).__init__()
52 |
53 | unet_block = UnetSkipConnectionBlock(ngf*8, ngf*8, submodule=None, norm_layer=norm_layer, innermost=True)
54 | for i in range(num_downs - 5):
55 | unet_block = UnetSkipConnectionBlock(ngf*8, ngf*8, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
56 | unet_block = UnetSkipConnectionBlock(ngf*4, ngf*8, submodule=unet_block, norm_layer=norm_layer)
57 | unet_block = UnetSkipConnectionBlock(ngf*2, ngf*4, submodule=unet_block, norm_layer=norm_layer)
58 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_layer=norm_layer)
59 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
60 | self.unet_model = unet_block
61 |
62 | def forward(self, input):
63 | return self.unet_model(input)
64 |
65 | class ResnetGenerator(nn.Module):
66 | def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=True, num_blocks=6, softmax=False):
67 | super(ResnetGenerator, self).__init__()
68 | if type(norm_layer) == functools.partial:
69 | use_bias = norm_layer.func == nn.InstanceNorm2d
70 | else:
71 | use_bias = norm_layer == nn.InstanceNorm2d
72 |
73 | res_model = [nn.ReflectionPad2d(3),
74 | conv_norm_relu(input_nc, ngf * 1, 7, norm_layer=norm_layer, bias=use_bias),
75 | conv_norm_relu(ngf * 1, ngf * 2, 3, 2, 1, norm_layer=norm_layer, bias=use_bias),
76 | conv_norm_relu(ngf * 2, ngf * 4, 3, 2, 1, norm_layer=norm_layer, bias=use_bias)]
77 |
78 | for i in range(num_blocks):
79 | res_model += [ResidualBlock(ngf * 4, norm_layer, use_dropout, use_bias)]
80 |
81 | if softmax:
82 | res_model += [dconv_norm_relu(ngf * 4, ngf * 2, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias),
83 | dconv_norm_relu(ngf * 2, ngf * 1, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias),
84 | nn.ReflectionPad2d(3),
85 | nn.Conv2d(ngf, output_nc, 7)]
86 | else:
87 | res_model += [dconv_norm_relu(ngf * 4, ngf * 2, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias),
88 | dconv_norm_relu(ngf * 2, ngf * 1, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias),
89 | nn.ReflectionPad2d(3),
90 | nn.Conv2d(ngf, output_nc, 7),
91 | nn.Tanh()]
92 | self.res_model = nn.Sequential(*res_model)
93 |
94 | def forward(self, x):
95 | return self.res_model(x)
96 |
97 |
98 | class ENet(nn.Module):
99 | '''
100 | num_classes : The number of classes to segment the image into
101 | encoder_relu : whether to include relu or prelu for encoder part
102 | decoder_relu : whether to include relu or prelu for decoder part
103 | '''
104 | def __init__(self, num_classes, encoder_relu=False, decoder_relu=True):
105 | super().__init__()
106 |
107 | self.initial_block = InitialBlock(3, 16, padding=1, relu=encoder_relu)
108 |
109 | # Stage 1 - Encoder
110 | self.downsample1_0 = DownsamplingBottleneck(
111 | 16,
112 | 64,
113 | padding=1,
114 | return_indices=True,
115 | dropout_prob=0.01,
116 | relu=encoder_relu)
117 | self.regular1_1 = RegularBottleneck(
118 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu)
119 | self.regular1_2 = RegularBottleneck(
120 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu)
121 | self.regular1_3 = RegularBottleneck(
122 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu)
123 | self.regular1_4 = RegularBottleneck(
124 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu)
125 |
126 | # Stage 2 - Encoder
127 | self.downsample2_0 = DownsamplingBottleneck(
128 | 64,
129 | 128,
130 | padding=1,
131 | return_indices=True,
132 | dropout_prob=0.1,
133 | relu=encoder_relu)
134 | self.regular2_1 = RegularBottleneck(
135 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu)
136 | self.dilated2_2 = RegularBottleneck(
137 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
138 | self.asymmetric2_3 = RegularBottleneck(
139 | 128,
140 | kernel_size=5,
141 | padding=2,
142 | asymmetric=True,
143 | dropout_prob=0.1,
144 | relu=encoder_relu)
145 | self.dilated2_4 = RegularBottleneck(
146 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
147 | self.regular2_5 = RegularBottleneck(
148 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu)
149 | self.dilated2_6 = RegularBottleneck(
150 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
151 | self.asymmetric2_7 = RegularBottleneck(
152 | 128,
153 | kernel_size=5,
154 | asymmetric=True,
155 | padding=2,
156 | dropout_prob=0.1,
157 | relu=encoder_relu)
158 | self.dilated2_8 = RegularBottleneck(
159 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)
160 |
161 | # Stage 3 - Encoder
162 | self.regular3_0 = RegularBottleneck(
163 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu)
164 | self.dilated3_1 = RegularBottleneck(
165 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
166 | self.asymmetric3_2 = RegularBottleneck(
167 | 128,
168 | kernel_size=5,
169 | padding=2,
170 | asymmetric=True,
171 | dropout_prob=0.1,
172 | relu=encoder_relu)
173 | self.dilated3_3 = RegularBottleneck(
174 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
175 | self.regular3_4 = RegularBottleneck(
176 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu)
177 | self.dilated3_5 = RegularBottleneck(
178 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
179 | self.asymmetric3_6 = RegularBottleneck(
180 | 128,
181 | kernel_size=5,
182 | asymmetric=True,
183 | padding=2,
184 | dropout_prob=0.1,
185 | relu=encoder_relu)
186 | self.dilated3_7 = RegularBottleneck(
187 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)
188 |
189 | # Stage 4 - Decoder
190 | self.upsample4_0 = UpsamplingBottleneck(
191 | 128, 64, padding=1, dropout_prob=0.1, relu=decoder_relu)
192 | self.regular4_1 = RegularBottleneck(
193 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu)
194 | self.regular4_2 = RegularBottleneck(
195 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu)
196 |
197 | # Stage 5 - Decoder
198 | self.upsample5_0 = UpsamplingBottleneck(
199 | 64, 16, padding=1, dropout_prob=0.1, relu=decoder_relu)
200 | self.regular5_1 = RegularBottleneck(
201 | 16, padding=1, dropout_prob=0.1, relu=decoder_relu)
202 | self.transposed_conv = nn.ConvTranspose2d(
203 | 16,
204 | num_classes,
205 | kernel_size=3,
206 | stride=2,
207 | padding=1,
208 | output_padding=1,
209 | bias=False)
210 |
211 | def forward(self, x):
212 | # Initial block
213 | x = self.initial_block(x)
214 |
215 | # Stage 1 - Encoder
216 | x, max_indices1_0 = self.downsample1_0(x)
217 | x = self.regular1_1(x)
218 | x = self.regular1_2(x)
219 | x = self.regular1_3(x)
220 | x = self.regular1_4(x)
221 |
222 | # Stage 2 - Encoder
223 | x, max_indices2_0 = self.downsample2_0(x)
224 | x = self.regular2_1(x)
225 | x = self.dilated2_2(x)
226 | x = self.asymmetric2_3(x)
227 | x = self.dilated2_4(x)
228 | x = self.regular2_5(x)
229 | x = self.dilated2_6(x)
230 | x = self.asymmetric2_7(x)
231 | x = self.dilated2_8(x)
232 |
233 | # Stage 3 - Encoder
234 | x = self.regular3_0(x)
235 | x = self.dilated3_1(x)
236 | x = self.asymmetric3_2(x)
237 | x = self.dilated3_3(x)
238 | x = self.regular3_4(x)
239 | x = self.dilated3_5(x)
240 | x = self.asymmetric3_6(x)
241 | x = self.dilated3_7(x)
242 |
243 | # Stage 4 - Decoder
244 | x = self.upsample4_0(x, max_indices2_0)
245 | x = self.regular4_1(x)
246 | x = self.regular4_2(x)
247 |
248 | # Stage 5 - Decoder
249 | x = self.upsample5_0(x, max_indices1_0)
250 | x = self.regular5_1(x)
251 | x = self.transposed_conv(x)
252 |
253 | return x
254 |
255 |
256 | class LEDNet(nn.Module):
257 | '''
258 | Implementation of the LEDNet network
259 | '''
260 |
261 | def __init__(self, in_channels=3, n_classes=22, encoder_relu=False, decoder_relu=True, image_dim=128):
262 | super().__init__()
263 |
264 | ### Encoder
265 | self.downsample_1 = Downsample_Block_led(in_channels, 32, kernel_size=3, padding=1, relu=encoder_relu)
266 |
267 | self.ssnbt_1_1 = SSnbt(32, kernel_size=3, padding=1, groups=4, relu=encoder_relu, drop_prob=0.01)
268 | self.ssnbt_1_2 = SSnbt(32, kernel_size=3, padding=1, groups=4, relu=encoder_relu, drop_prob=0.01)
269 | self.ssnbt_1_3 = SSnbt(32, kernel_size=3, padding=1, groups=4, relu=encoder_relu, drop_prob=0.01)
270 |
271 | self.downsample_2 = Downsample_Block_led(32, 64, kernel_size=3, padding=1, relu=encoder_relu)
272 |
273 | self.ssnbt_2_1 = SSnbt(64, kernel_size=3, padding=1, groups=4, relu=encoder_relu, drop_prob=0.01)
274 | self.ssnbt_2_2 = SSnbt(64, kernel_size=3, padding=1, groups=4, relu=encoder_relu, drop_prob=0.01)
275 |
276 | self.downsample_3 = Downsample_Block_led(64, 128, kernel_size=3, padding=1, relu=encoder_relu)
277 |
278 | self.ssnbt_3_1 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=1, relu=encoder_relu, drop_prob=0.1)
279 | self.ssnbt_3_2 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=2, relu=encoder_relu, drop_prob=0.1)
280 | self.ssnbt_3_3 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=5, relu=encoder_relu, drop_prob=0.1)
281 | self.ssnbt_3_4 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=9, relu=encoder_relu, drop_prob=0.1)
282 | self.ssnbt_3_5 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=2, relu=encoder_relu, drop_prob=0.1)
283 | self.ssnbt_3_6 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=5, relu=encoder_relu, drop_prob=0.1)
284 | self.ssnbt_3_7 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=9, relu=encoder_relu, drop_prob=0.1)
285 | self.ssnbt_3_8 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=17, relu=encoder_relu, drop_prob=0.1)
286 |
287 | ### Decoder
288 | self.apn = APN(128, n_classes, image_dim= image_dim // 8, relu=decoder_relu)
289 |
290 | def forward(self, x):
291 |
292 | ### Encoder
293 | x = self.downsample_1(x)
294 | x = self.ssnbt_1_1(x)
295 | x = self.ssnbt_1_2(x)
296 | x = self.ssnbt_1_3(x)
297 |
298 | x = self.downsample_2(x)
299 | x = self.ssnbt_2_1(x)
300 | x = self.ssnbt_2_2(x)
301 |
302 | x = self.downsample_3(x)
303 | x = self.ssnbt_3_1(x)
304 | x = self.ssnbt_3_2(x)
305 | x = self.ssnbt_3_3(x)
306 | x = self.ssnbt_3_4(x)
307 | x = self.ssnbt_3_5(x)
308 | x = self.ssnbt_3_6(x)
309 | x = self.ssnbt_3_7(x)
310 | x = self.ssnbt_3_8(x)
311 |
312 | ### Decoder
313 | out = self.apn(x)
314 |
315 | return out
316 |
317 |
318 | ### Code for deeplab model
319 | ### This code is modified from: https://github.com/hfslyc/AdvSemiSeg
320 | class Bottleneck(nn.Module):
321 | expansion = 4
322 |
323 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
324 | super(Bottleneck, self).__init__()
325 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change
326 | self.bn1 = nn.BatchNorm2d(planes)
327 | for i in self.bn1.parameters():
328 | i.requires_grad = False
329 |
330 | padding = dilation
331 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change
332 | padding=padding, bias=False, dilation = dilation)
333 | self.bn2 = nn.BatchNorm2d(planes)
334 | for i in self.bn2.parameters():
335 | i.requires_grad = False
336 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
337 | self.bn3 = nn.BatchNorm2d(planes * 4)
338 | for i in self.bn3.parameters():
339 | i.requires_grad = False
340 | self.relu = nn.ReLU(inplace=True)
341 | self.downsample = downsample
342 | self.stride = stride
343 |
344 |
345 | def forward(self, x):
346 | residual = x
347 |
348 | out = self.conv1(x)
349 | out = self.bn1(out)
350 | out = self.relu(out)
351 |
352 | out = self.conv2(out)
353 | out = self.bn2(out)
354 | out = self.relu(out)
355 |
356 | out = self.conv3(out)
357 | out = self.bn3(out)
358 |
359 | if self.downsample is not None:
360 | residual = self.downsample(x)
361 |
362 | out += residual
363 | out = self.relu(out)
364 |
365 | return out
366 |
367 | class Classifier_Module(nn.Module):
368 |
369 | def __init__(self, dilation_series, padding_series, num_classes):
370 | super(Classifier_Module, self).__init__()
371 | self.conv2d_list = nn.ModuleList()
372 | for dilation, padding in zip(dilation_series, padding_series):
373 | self.conv2d_list.append(nn.Conv2d(2048, num_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias = True))
374 |
375 | for m in self.conv2d_list:
376 | m.weight.data.normal_(0, 0.01)
377 |
378 | def forward(self, x):
379 | out = self.conv2d_list[0](x)
380 | for i in range(len(self.conv2d_list)-1):
381 | out += self.conv2d_list[i+1](x)
382 | return out
383 |
384 | class ResNet(nn.Module):
385 | def __init__(self, in_channels, block, layers, num_classes):
386 | self.inplanes = 64
387 | super(ResNet, self).__init__()
388 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3,
389 | bias=False)
390 | self.bn1 = nn.BatchNorm2d(64)
391 | for i in self.bn1.parameters():
392 | i.requires_grad = False
393 | self.relu = nn.ReLU(inplace=True)
394 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change
395 | self.layer1 = self._make_layer(block, 64, layers[0])
396 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
397 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
398 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
399 | self.layer5 = self._make_pred_layer(Classifier_Module, [6,12,18,24],[6,12,18,24],num_classes)
400 |
401 | for m in self.modules():
402 | if isinstance(m, nn.Conv2d):
403 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
404 | m.weight.data.normal_(0, 0.01)
405 | elif isinstance(m, nn.BatchNorm2d):
406 | m.weight.data.fill_(1)
407 | m.bias.data.zero_()
408 | # for i in m.parameters():
409 | # i.requires_grad = False
410 |
411 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
412 | downsample = None
413 | if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:
414 | downsample = nn.Sequential(
415 | nn.Conv2d(self.inplanes, planes * block.expansion,
416 | kernel_size=1, stride=stride, bias=False),
417 | nn.BatchNorm2d(planes * block.expansion))
418 | for i in downsample._modules['1'].parameters():
419 | i.requires_grad = False
420 | layers = []
421 | layers.append(block(self.inplanes, planes, stride,dilation=dilation, downsample=downsample))
422 | self.inplanes = planes * block.expansion
423 | for i in range(1, blocks):
424 | layers.append(block(self.inplanes, planes, dilation=dilation))
425 |
426 | return nn.Sequential(*layers)
427 | def _make_pred_layer(self,block, dilation_series, padding_series,num_classes):
428 | return block(dilation_series,padding_series,num_classes)
429 |
430 | def forward(self, x):
431 | x = self.conv1(x)
432 | x = self.bn1(x)
433 | x = self.relu(x)
434 | x = self.maxpool(x)
435 | x = self.layer1(x)
436 | x = self.layer2(x)
437 | x = self.layer3(x)
438 | x = self.layer4(x)
439 | x = self.layer5(x)
440 |
441 | return x
442 |
443 | def get_1x_lr_params_NOscale(self):
444 | """
445 | This generator returns all the parameters of the net except for
446 | the last classification layer. Note that for each batchnorm layer,
447 |         requires_grad is set to False in this file, therefore this function does not return
448 | any batchnorm parameter
449 | """
450 | b = []
451 |
452 | b.append(self.conv1)
453 | b.append(self.bn1)
454 | b.append(self.layer1)
455 | b.append(self.layer2)
456 | b.append(self.layer3)
457 | b.append(self.layer4)
458 |
459 |
460 | for i in range(len(b)):
461 | for j in b[i].modules():
462 | jj = 0
463 | for k in j.parameters():
464 | jj+=1
465 | if k.requires_grad:
466 | yield k
467 |
468 | def get_10x_lr_params(self):
469 | """
470 | This generator returns all the parameters for the last layer of the net,
471 | which does the classification of pixel into classes
472 | """
473 | b = []
474 | b.append(self.layer5.parameters())
475 |
476 | for j in range(len(b)):
477 | for i in b[j]:
478 | yield i
479 |
480 |
481 |
482 | def optim_parameters(self, args):
483 | return [{'params': self.get_1x_lr_params_NOscale(), 'lr': args.learning_rate},
484 | {'params': self.get_10x_lr_params(), 'lr': 10*args.learning_rate}]
485 |
486 |
487 | def define_Gen(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, gpu_ids=[0]):
488 | gen_net = None
489 | norm_layer = get_norm_layer(norm_type=norm)
490 |
491 | if netG == 'resnet_9blocks':
492 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=9, softmax=False)
493 | elif netG == 'resnet_9blocks_softmax':
494 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=9, softmax=True)
495 | elif netG == 'resnet_6blocks':
496 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=6, softmax=False)
497 | elif netG == 'resnet_6blocks_softmax':
498 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=6, softmax=True)
499 | elif netG == 'unet_128':
500 | gen_net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
501 | elif netG == 'unet_256':
502 | gen_net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
503 | elif netG == 'enet':
504 | ### For ENet the input_nc is always 3
505 | gen_net = ENet(num_classes = output_nc, encoder_relu=False, decoder_relu=True)
506 | elif netG == 'lednet_128':
507 | gen_net = LEDNet(in_channels=input_nc, n_classes= output_nc, encoder_relu=False, decoder_relu=True, image_dim=128)
508 | elif netG == 'lednet_256':
509 | gen_net = LEDNet(in_channels=input_nc, n_classes= output_nc, encoder_relu=False, decoder_relu=True, image_dim=256)
510 | elif netG == 'deeplab':
511 | gen_net = ResNet(in_channels=input_nc, block=Bottleneck, layers=[3, 4, 23, 3], num_classes=output_nc)
512 | else:
513 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
514 |
515 | return init_network(gen_net, gpu_ids)
516 |
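# --- Illustrative usage sketch (added for clarity, not part of the original file) ---
# define_Gen builds one of the generator variants above and initializes it via
# init_network. For the image-to-segmentation direction, output_nc equals the
# number of classes (e.g. 20 for the Cityscapes mapping used in this project).
# gpu_ids=[] keeps the sketch on the CPU.
if __name__ == '__main__':
    Gsi = define_Gen(input_nc=3, output_nc=20, ngf=64, netG='resnet_9blocks_softmax',
                     norm='instance', use_dropout=False, gpu_ids=[])
    dummy = torch.randn(1, 3, 128, 128)
    logits = Gsi(dummy)
    print(logits.shape)   # torch.Size([1, 20, 128, 128])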
--------------------------------------------------------------------------------
/arch/ops.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from torch.nn import init
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch
6 |
7 | def get_norm_layer(norm_type='instance'):
8 | if norm_type == 'batch':
9 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
10 | elif norm_type == 'instance':
11 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
12 | else:
13 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
14 | return norm_layer
15 |
16 | def init_weights(net, init_type='normal', gain=0.02):
17 | def init_func(m):
18 | classname = m.__class__.__name__
19 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
20 |             init.normal_(m.weight.data, 0.0, gain)
21 |             if hasattr(m, 'bias') and m.bias is not None:
22 |                 init.constant_(m.bias.data, 0.0)
23 |         elif classname.find('BatchNorm2d') != -1:
24 |             init.normal_(m.weight.data, 1.0, gain)
25 |             init.constant_(m.bias.data, 0.0)
26 |
27 | print('Network initialized with weights sampled from N(0,0.02).')
28 | net.apply(init_func)
29 |
30 |
31 | def init_network(net, gpu_ids=[]):
32 | if len(gpu_ids) > 0:
33 | assert(torch.cuda.is_available())
34 | net.cuda(gpu_ids[0])
35 |     ### When the pretrained Deeplab model is used this random init is not needed (its weights get overwritten), so it can be commented out in that case
36 | init_weights(net)
37 | return net
38 |
39 |
40 | def conv_norm_lrelu(in_dim, out_dim, kernel_size, stride = 1, padding=0,
41 | norm_layer = nn.BatchNorm2d, bias = False):
42 | return nn.Sequential(
43 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias = bias),
44 | norm_layer(out_dim), nn.LeakyReLU(0.2,True))
45 |
46 | def conv_norm_relu(in_dim, out_dim, kernel_size, stride = 1, padding=0,
47 | norm_layer = nn.BatchNorm2d, bias = False):
48 | return nn.Sequential(
49 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias = bias),
50 | norm_layer(out_dim), nn.ReLU(True))
51 |
52 | def dconv_norm_relu(in_dim, out_dim, kernel_size, stride = 1, padding=0, output_padding=0,
53 | norm_layer = nn.BatchNorm2d, bias = False):
54 | return nn.Sequential(
55 | nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride,
56 | padding, output_padding, bias = bias),
57 | norm_layer(out_dim), nn.ReLU(True))
58 |
59 | class ResidualBlock(nn.Module):
60 | def __init__(self, dim, norm_layer, use_dropout, use_bias):
61 | super(ResidualBlock, self).__init__()
62 | res_block = [nn.ReflectionPad2d(1),
63 | conv_norm_relu(dim, dim, kernel_size=3,
64 | norm_layer= norm_layer, bias=use_bias)]
65 | if use_dropout:
66 | res_block += [nn.Dropout(0.5)]
67 | res_block += [nn.ReflectionPad2d(1),
68 | nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias),
69 | norm_layer(dim)]
70 |
71 | self.res_block = nn.Sequential(*res_block)
72 |
73 | def forward(self, x):
74 | return x + self.res_block(x)
75 |
76 |
77 | def set_grad(nets, requires_grad=False):
78 | for net in nets:
79 | for param in net.parameters():
80 | param.requires_grad = requires_grad
81 |
82 |
83 | #######
84 |
85 | ### A huge thanks to the repo: https://github.com/davidtvs/PyTorch-ENet
86 | ### A lot of ideas are taken from this repo
87 |
88 | '''
89 | Here we define the building blocks required for the ENet model
90 | '''
91 |
92 | class InitialBlock(nn.Module):
93 | '''
94 | This will be initial block of the network
95 | '''
96 |
97 | def __init__(self, input_dim = 3, output_dim = 16, kernel_size = 3, relu = True, padding=0, bias=False):
98 | super().__init__()
99 |
100 | if relu:
101 | activation = nn.ReLU()
102 | else:
103 | activation = nn.PReLU()
104 |
105 | ### So here we will define the main branch which will give the output by passing through the conv block
106 | self.main_branch = nn.Conv2d(in_channels = input_dim,
107 | out_channels = output_dim - 3,
108 | kernel_size = kernel_size,
109 | stride= 2,
110 | padding= padding,
111 | bias = bias)
112 |
113 | ### Extension branch, which is basically MaxPool2d in this case
114 | self.ext_branch = nn.MaxPool2d(kernel_size = kernel_size,
115 | stride= 2,
116 | padding= padding)
117 |
118 | ### Some of the extra basic blocks of activations and other things
119 | self.batch_norm = nn.BatchNorm2d(output_dim)
120 | self.activation = activation
121 |
122 | def forward(self, x):
123 | main = self.main_branch(x)
124 | ext = self.ext_branch(x)
125 |
126 | out = torch.cat((main, ext), 1) ### For concatenating their outputs
127 |
128 | out = self.batch_norm(out)
129 | out = self.activation(out)
130 |
131 | return out
132 |
133 | class RegularBottleneck(nn.Module):
134 | '''
135 |     This block also has two branches:
136 |     Main branch -->
137 |         This is simply the input, passed through unchanged
138 |     Extension Branch -->
139 |         1. 1*1 conv to reduce the number of channels by the internal_ratio specified by the user
140 |         2. conv layer, which can be a regular, atrous or asymmetric convolution
141 |         3. 1*1 conv to expand the channels back from channels // internal_ratio to the original count
142 |         4. Regularizer, which in this case is dropout
143 | '''
144 |
145 | def __init__(self, channels, internal_ratio = 4, kernel_size = 3, dilation = 1, padding = 1, asymmetric = False, dropout_prob=0, bias=False, relu=True):
146 | super().__init__()
147 |
148 | ### Now we need to ensure if the internal_ratio has a reasonable value
149 | assert internal_ratio > 1 and internal_ratio < channels , 'The value of the internal_ratio is not correct'
150 |
151 | internal_channels = channels // internal_ratio
152 |
153 | if relu:
154 | activation = nn.ReLU()
155 | else:
156 | activation = nn.PReLU()
157 |
158 | self.ext_conv1 = nn.Sequential(
159 | nn.Conv2d(in_channels = channels, out_channels = internal_channels, kernel_size = 1, stride = 1, padding = 0, bias=bias),
160 | nn.BatchNorm2d(internal_channels),
161 | activation)
162 |
163 | ### So here we are going to see for the case if we are given asymmetric convolutions
164 | ### So basically it is gonna be like, example for 3*3 we will break it into 3*1 and 1*3
165 | ### and then perform normal convolutions, first 3*1 and then 1*3
166 | if asymmetric:
167 | self.ext_conv2 = nn.Sequential(
168 | nn.Conv2d(internal_channels, internal_channels, (kernel_size, 1), stride = (1, 1), padding = (padding, 0), dilation=dilation, bias=bias),
169 | nn.BatchNorm2d(internal_channels),
170 | activation,
171 | nn.Conv2d(internal_channels, internal_channels, kernel_size = (1, kernel_size), stride=(1,1), padding=(0, padding), dilation=dilation, bias=bias),
172 | nn.BatchNorm2d(internal_channels),
173 | activation
174 | )
175 | else:
176 | self.ext_conv2 = nn.Sequential(
177 | nn.Conv2d(internal_channels, internal_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=bias),
178 | nn.BatchNorm2d(internal_channels),
179 | activation
180 | )
181 |
182 | ### Now we are gonna implement the 1*1 expansion code
183 | self.ext_conv3 = nn.Sequential(nn.Conv2d(internal_channels, channels, kernel_size=1, stride=1, padding=0, bias=bias),
184 | nn.BatchNorm2d(channels),
185 | activation)
186 |
187 | self.ext_regularizer = nn.Dropout2d(p = dropout_prob)
188 |
189 | self.out_relu = activation
190 |
191 |
192 | def forward(self, x):
193 |
194 | main = x
195 |
196 | ext = self.ext_conv1(x)
197 | ext = self.ext_conv2(ext)
198 | ext = self.ext_conv3(ext)
199 | ext = self.ext_regularizer(ext)
200 |
201 | out = ext + main
202 |
203 | ### Activation applying after the addition
204 | return self.out_relu(out)
205 |
206 |
207 | class DownsamplingBottleneck(nn.Module):
208 | '''
209 | This is the Downsampling bottleneck layer to downsample the input
210 | '''
211 |
212 | def __init__(self, in_channels, out_channels, internal_ratio=4, kernel_size = 3, padding=0, return_indices = False, dropout_prob = 0, bias=False, relu = True):
213 | super().__init__()
214 |
215 | self.return_indices = return_indices
216 |
217 | ### Now we need to ensure if the internal_ratio has a reasonable value
218 | assert internal_ratio > 1 and internal_ratio < in_channels , 'The value of the internal_ratio is not correct'
219 |
220 | internal_channels = in_channels // internal_ratio
221 |
222 | if relu:
223 | activation = nn.ReLU()
224 | else:
225 | activation = nn.PReLU()
226 |
227 | self.main_max1 = nn.MaxPool2d(kernel_size = kernel_size, stride=2, padding=padding, return_indices=return_indices)
228 |
229 | ### The extension branch
230 | self.ext_conv1 = nn.Sequential(
231 | nn.Conv2d(in_channels, internal_channels, kernel_size=2, stride=2, bias=bias),
232 | nn.BatchNorm2d(internal_channels),
233 | activation)
234 |
235 | ### Now we are gonna do the conv operation
236 | self.ext_conv2 = nn.Sequential(
237 | nn.Conv2d(internal_channels, internal_channels, kernel_size= kernel_size, stride=1, padding=padding, bias=bias),
238 | nn.BatchNorm2d(internal_channels),
239 | activation
240 | )
241 |
242 | self.ext_conv3 = nn.Sequential(
243 | nn.Conv2d(internal_channels, out_channels, kernel_size=1, stride=1, bias=bias),
244 | nn.BatchNorm2d(out_channels),
245 | activation
246 | )
247 |
248 | self.ext_reg = nn.Dropout2d(p=dropout_prob)
249 |
250 | self.out_activation = activation
251 |
252 | def forward(self, x):
253 | if self.return_indices:
254 | main, max_indices = self.main_max1(x)
255 | else:
256 | main = self.main_max1(x)
257 |
258 | ext = self.ext_conv1(x)
259 | ext = self.ext_conv2(ext)
260 | ext = self.ext_conv3(ext)
261 | ext = self.ext_reg(ext)
262 |
263 | ### Main branch channel padding
264 | n, ch_ext, h, w = ext.size()
265 | ch_main = main.size()[1]
266 | padding = torch.zeros(n, ch_ext - ch_main, h, w)
267 |
268 |         ### Before concatenating, the zero padding must live on the same device
269 |         ### as the main branch
270 | if main.is_cuda:
271 | padding = padding.cuda()
272 |
273 |
274 | ### Concatenating
275 | main = torch.cat((main, padding), 1)
276 |
277 | out = main + ext
278 |
279 | if self.return_indices:
280 | return self.out_activation(out), max_indices
281 | else:
282 | return self.out_activation(out)
283 |
284 |
285 | class UpsamplingBottleneck(nn.Module):
286 | '''
287 | This is the Upsampling bottleneck as is displayed in the paper
288 | '''
289 |
290 | def __init__(self, in_channels, out_channels, internal_ratio = 4, kernel_size = 3, padding=0, dropout_prob = 0, bias=False, relu=True):
291 | super().__init__()
292 |
293 | ### Now we need to ensure if the internal_ratio has a reasonable value
294 | assert internal_ratio > 1 and internal_ratio < in_channels , 'The value of the internal_ratio is not correct'
295 |
296 | internal_channels = in_channels // internal_ratio
297 |
298 | if relu:
299 | activation = nn.ReLU()
300 | else:
301 | activation = nn.PReLU()
302 |
303 | self.main_conv1 = nn.Sequential(
304 | nn.Conv2d(in_channels, out_channels, kernel_size= kernel_size, padding=padding, bias = bias),
305 | nn.BatchNorm2d(out_channels)
306 | )
307 |
308 | self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2)
309 |
310 | ### Extension branch
311 | self.ext_conv1 = nn.Sequential(
312 | nn.Conv2d(in_channels, internal_channels, kernel_size= 1, bias= bias),
313 | nn.BatchNorm2d(internal_channels),
314 | activation
315 | )
316 |
317 | self.ext_conv2 = nn.Sequential(
318 | nn.ConvTranspose2d(internal_channels, internal_channels, kernel_size=kernel_size, stride=2, padding=padding, output_padding=1, bias=bias),
319 | nn.BatchNorm2d(internal_channels),
320 | activation
321 | )
322 |
323 | self.ext_conv3 = nn.Sequential(
324 | nn.Conv2d(internal_channels, out_channels, kernel_size=1, bias=bias),
325 | nn.BatchNorm2d(out_channels),
326 | activation
327 | )
328 |
329 | self.ext_reg = nn.Dropout2d(p=dropout_prob)
330 |
331 | self.out_activation = activation
332 |
333 | def forward(self, x, max_indices):
334 | main = self.main_conv1(x)
335 | main = self.main_unpool1(main, max_indices)
336 |
337 | ext = self.ext_conv1(x)
338 | ext = self.ext_conv2(ext)
339 | ext = self.ext_conv3(ext)
340 | ext = self.ext_reg(ext)
341 |
342 | out = main+ext
343 |
344 | return self.out_activation(out)
345 |
346 |
347 | #####
348 | '''
349 | So here writing up the helper functions for the case of LEDNet
350 | '''
351 |
352 | class Channel_Split(nn.Module):
353 | '''
354 | Used to split the channels of a tensor
355 | '''
356 | def __init__(self, channels):
357 | super().__init__()
358 |
359 | self.channels = channels
360 |
361 | def forward(self, x):
362 | x_left = x[:, 0:self.channels // 2, :, :]
363 | x_right = x[:, self.channels // 2 : , :, :]
364 |
365 | return x_left, x_right
366 |
367 |
368 | class ShuffleBlock(nn.Module):
369 | def __init__(self, groups):
370 | super(ShuffleBlock, self).__init__()
371 | self.groups = groups
372 |
373 | def forward(self, x):
374 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]'''
375 | N,C,H,W = x.size()
376 | g = self.groups
377 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).contiguous().view(N,C,H,W)
378 |
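# --- Illustrative sanity check (added for clarity, not part of the original file) ---
# Channel shuffle only permutes channels, so no values are lost; with groups=2 the
# channel order [0, 1, 2, 3] becomes [0, 2, 1, 3].
if __name__ == '__main__':
    x = torch.arange(4, dtype=torch.float32).view(1, 4, 1, 1)
    shuffled = ShuffleBlock(groups=2)(x)
    assert shuffled.flatten().tolist() == [0.0, 2.0, 1.0, 3.0]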
379 |
380 | class SSnbt(nn.Module):
381 | '''
382 | The class for the implementation split-shuffle-non-bottleneck block
383 | '''
384 |
385 | def __init__(self, channels, kernel_size = 3, padding = 0, groups=4, dilation=1, bias= True, relu=False, drop_prob=0):
386 | '''
387 | groups: parameter for determining the shuffling order in the shuffleBlock
388 | '''
389 | super().__init__()
390 |
391 | if relu:
392 | activation = nn.ReLU()
393 | else:
394 | activation = nn.PReLU()
395 |
396 |         ### First the channels are split, so separate left and right branches
397 |         ### of the block are defined
398 |
399 | ### Also in this model the convolutions are asymmetric in nature
400 |
401 | mid_channels = channels // 2
402 |
403 | self.split = Channel_Split(channels)
404 |
405 | self.l_conv1 = nn.Conv2d(mid_channels, mid_channels, kernel_size=(kernel_size, 1), stride=1, padding=(padding, 0), dilation = 1, bias=bias)
406 |
407 | self.l_conv2 = nn.Sequential(
408 | nn.Conv2d(mid_channels, mid_channels, kernel_size = (1, kernel_size), stride=1, padding=(0, padding), dilation=1, bias=bias),
409 | nn.BatchNorm2d(mid_channels),
410 | activation
411 | )
412 |
413 | self.r_conv1 = nn.Conv2d(mid_channels, mid_channels, kernel_size = (1, kernel_size), stride=1, padding=(0, padding), dilation=1, bias=bias)
414 |
415 | self.r_conv2 = nn.Sequential(
416 | nn.Conv2d(mid_channels, mid_channels, kernel_size = (kernel_size, 1), stride=1, padding=(padding, 0), dilation=1, bias=bias),
417 | nn.BatchNorm2d(mid_channels),
418 | activation
419 | )
420 |
421 | self.l_conv3 = nn.Sequential(
422 | nn.Conv2d(mid_channels, mid_channels, kernel_size=(kernel_size, 1), stride=1, padding=(padding + dilation - 1, 0), dilation = dilation, bias=bias),
423 | activation
424 | )
425 |
426 | self.l_conv4 = nn.Sequential(
427 | nn.Conv2d(mid_channels, mid_channels, kernel_size = (1, kernel_size), stride=1, padding=(0, padding + dilation - 1), dilation=dilation, bias=bias),
428 | nn.BatchNorm2d(mid_channels),
429 | activation
430 | )
431 |
432 | self.r_conv3 = nn.Sequential(
433 | nn.Conv2d(mid_channels, mid_channels, kernel_size = (1, kernel_size), stride=1, padding=(0, padding + dilation - 1), dilation=dilation, bias=bias),
434 | activation
435 | )
436 |
437 | self.r_conv4 = nn.Sequential(
438 | nn.Conv2d(mid_channels, mid_channels, kernel_size = (kernel_size, 1), stride=1, padding=(padding + dilation - 1, 0), dilation=dilation, bias=bias),
439 | nn.BatchNorm2d(mid_channels),
440 | activation
441 | )
442 |
443 | self.regularizer = nn.Dropout2d(p=drop_prob)
444 |
445 | self.out_activation = activation
446 | self.shuffle = ShuffleBlock(groups=groups)
447 |
448 | def forward(self, x):
449 | main = x
450 |
451 | l_input, r_input = self.split(x)
452 |
453 | l_input = self.l_conv1(l_input)
454 | l_input = self.l_conv2(l_input)
455 | r_input = self.r_conv1(r_input)
456 | r_input = self.r_conv2(r_input)
457 |
458 | l_input = self.l_conv3(l_input)
459 | l_input = self.l_conv4(l_input)
460 | r_input = self.r_conv3(r_input)
461 | r_input = self.r_conv4(r_input)
462 |
463 | ext = torch.cat((l_input, r_input), 1)
464 |
465 | ext = self.regularizer(ext)
466 |
467 | out = main + ext
468 | out = self.out_activation(out)
469 | out = self.shuffle(out)
470 |
471 | return out
472 |
473 |
474 | class Downsample_Block_led(nn.Module):
475 | '''
476 | The downsampling block for the LEDNet
477 |
478 | So basically we will do two operations in parallel, conv and max pool and then concat both of them
479 | '''
480 |
481 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=0, bias=True, relu=True):
482 | super().__init__()
483 |
484 | if relu:
485 | activation = nn.ReLU()
486 | else:
487 | activation = nn.PReLU()
488 |
489 | self.conv = nn.Conv2d(in_channels, out_channels-in_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias)
490 |
491 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
492 | self.out_activation = activation
493 |
494 | def forward(self, x):
495 | main = self.conv(x)
496 | ext = self.maxpool(x)
497 |
498 | out = torch.cat((main, ext), 1)
499 | out = self.out_activation(out)
500 | return out
501 |
502 | class APN(nn.Module):
503 | '''
504 | This is the Attention Pyramid Network including the whole upsampling block
505 | '''
506 |
507 | def __init__(self, in_channels, out_channels, image_dim, bias=True, relu=True):
508 | '''
509 | image_dim = dimension of the incoming image
510 | '''
511 | super().__init__()
512 |
513 | if relu:
514 | activation = nn.ReLU()
515 | else:
516 | activation = nn.PReLU()
517 |
518 | self.conv_33 = nn.Sequential(
519 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, bias=bias),
520 | nn.BatchNorm2d(in_channels),
521 | activation
522 | )
523 |
524 | self.conv_55 = nn.Sequential(
525 | nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=2, padding=2, bias=bias),
526 | nn.BatchNorm2d(in_channels),
527 | activation
528 | )
529 |
530 | self.conv_77 = nn.Sequential(
531 | nn.Conv2d(in_channels, in_channels, kernel_size=7, stride=2, padding=3, bias=bias),
532 | nn.BatchNorm2d(in_channels),
533 | activation
534 | )
535 |
536 | self.conv_77_toC = nn.Sequential(
537 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
538 | nn.BatchNorm2d(out_channels),
539 | activation
540 | )
541 |
542 | self.conv_55_toC = nn.Sequential(
543 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
544 | nn.BatchNorm2d(out_channels),
545 | activation
546 | )
547 |
548 | self.conv_33_toC = nn.Sequential(
549 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
550 | nn.BatchNorm2d(out_channels),
551 | activation
552 | )
553 |
554 | self.conv_11 = nn.Sequential(
555 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
556 | nn.BatchNorm2d(out_channels),
557 | activation
558 | )
559 |
560 | self.global_pool = nn.AvgPool2d(kernel_size = (image_dim, image_dim), stride=0, padding=0)
561 | self.pool_conv_toC = nn.Sequential(
562 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
563 | nn.BatchNorm2d(out_channels),
564 | activation
565 | )
566 |
567 | self.upsample_times2 = nn.UpsamplingBilinear2d(scale_factor=2)
568 | self.upsample_globalPool = nn.UpsamplingBilinear2d(size = (image_dim, image_dim))
569 |
570 | def forward(self, x):
571 | output_conv33 = self.conv_33(x)
572 |
573 | output_conv55 = self.conv_55(output_conv33)
574 |
575 | output_conv77 = self.conv_77(output_conv55)
576 | output_conv77_toC = self.conv_77_toC(output_conv77)
577 | output_conv77_upsample = self.upsample_times2(output_conv77_toC)
578 |
579 | output_conv55_toC = self.conv_55_toC(output_conv55)
580 | output_conv55_toC = output_conv55_toC + output_conv77_upsample
581 | output_conv55_upsample = self.upsample_times2(output_conv55_toC)
582 |
583 | output_conv33_toC = self.conv_33_toC(output_conv33)
584 | output_conv33_toC = output_conv33_toC + output_conv55_upsample
585 | output_conv33_upsample = self.upsample_times2(output_conv33_toC)
586 |
587 | output_conv11 = self.conv_11(x)
588 |
589 | output_global_pool = self.global_pool(x)
590 | output_pool_conv_toC = self.pool_conv_toC(output_global_pool)
591 | output_pool_upsample = self.upsample_globalPool(output_pool_conv_toC)
592 |
593 | output_conv11 = output_conv11 * output_conv33_upsample
594 | output_conv11 = output_conv11 + output_pool_upsample
595 |
596 | final_output = F.upsample_bilinear(output_conv11, scale_factor=8)
597 |
598 | return final_output
599 |
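# A rough shape walk through APN.forward, assuming an illustrative input of 128 channels at
# 32x32 (image_dim must equal the spatial size of the tensor fed to the module): the
# 3x3/5x5/7x7 pyramid downsamples to 16, 8 and 4; the 1x1 "toC" projections and the x2
# upsamplings fuse the levels back to 32x32; that map weights the full-resolution conv_11
# branch, the globally pooled branch is added, and the result is bilinearly upsampled x8.
#
#     import torch
#     apn = APN(in_channels=128, out_channels=21, image_dim=32)
#     x = torch.randn(1, 128, 32, 32)
#     print(apn(x).shape)  # torch.Size([1, 21, 256, 256])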
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import itertools
3 | import numpy as np
4 |
5 | import torch
6 | from torch import nn
7 | from torch.autograd import Variable
8 | import torchvision
9 | import torchvision.datasets as dsets
10 | from torch.utils.data import DataLoader
11 | import torchvision.transforms as transforms
12 | import utils
13 | # from utils import perceptual_loss
14 | from arch import define_Gen, define_Dis, set_grad
15 | from data_utils import VOCDataset, CityscapesDataset, ACDCDataset, get_transformation
16 | from utils import make_one_hot
17 | from tensorboardX import SummaryWriter
18 |
19 | '''
20 | Class for CycleGAN with train() as a member function
21 |
22 | '''
23 | root = './data/VOC2012'
24 | root_cityscapes = "./data/Cityscape"
25 | root_acdc = './data/ACDC'
26 |
27 | ### The location for tensorboard visualizations
28 | tensorboard_loc = './tensorboard_results/first_run'
29 |
30 | ### The location of the pretrained model weights
31 | pretrained_loc = 'resnet101COCO-41f33a49.pth'
32 |
33 | class supervised_model(object):
34 | def __init__(self, args):
35 |
36 | if args.dataset == 'voc2012':
37 | self.n_channels = 21
38 | elif args.dataset == 'cityscapes':
39 | self.n_channels = 20
40 | elif args.dataset == 'acdc':
41 | self.n_channels = 4
42 |
43 | # Define the network
44 | self.Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf, netG='deeplab', norm=args.norm,
45 | use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids) # for image to segmentation
46 |
47 | ### Load the pretrained weights into Gsi
48 | ### These are only used for VOC and Cityscapes
49 | if args.dataset != 'acdc':
50 | saved_state_dict = torch.load(pretrained_loc)
51 | new_params = self.Gsi.state_dict().copy()
52 | for name, param in new_params.items():
53 | # print(name)
54 | if name in saved_state_dict and param.size() == saved_state_dict[name].size():
55 | new_params[name].copy_(saved_state_dict[name])
56 | # print('copy {}'.format(name))
57 | # self.Gsi.load_state_dict(new_params)
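# Note: the tensors returned by state_dict() share storage with the model parameters, so the
# in-place copy_() above already writes the pretrained weights into Gsi; presumably that is why
# the explicit load_state_dict call is left commented out.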
58 |
59 | utils.print_networks([self.Gsi], ['Gsi'])
60 |
61 | ### Define an interpolation function to resize the network output to the label map size
62 | self.interp = nn.Upsample(size = (args.crop_height, args.crop_width), mode='bilinear', align_corners=True)
63 | self.interp_val = nn.Upsample(size = (512, 512), mode='bilinear', align_corners=True)
64 |
65 | self.CE = nn.CrossEntropyLoss()
66 | self.activation_softmax = nn.Softmax2d()
67 | self.gsi_optimizer = torch.optim.Adam(self.Gsi.parameters(), lr=args.lr, betas=(0.9, 0.999))
68 |
69 | ### writer for tensorboard
70 | self.writer_supervised = SummaryWriter(tensorboard_loc + '_supervised')
71 | self.running_metrics_val = utils.runningScore(self.n_channels, args.dataset)
72 |
73 | self.args = args
74 |
75 | if not os.path.isdir(args.checkpoint_dir):
76 | os.makedirs(args.checkpoint_dir)
77 |
78 | try:
79 | ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir))
80 | self.start_epoch = ckpt['epoch']
81 | self.Gsi.load_state_dict(ckpt['Gsi'])
82 | self.gsi_optimizer.load_state_dict(ckpt['gsi_optimizer'])
83 | self.best_iou = ckpt['best_iou']
84 | except:
85 | print(' [*] No checkpoint!')
86 | self.start_epoch = 0
87 | self.best_iou = -100
88 |
89 | def train(self, args):
90 |
91 | transform = get_transformation((self.args.crop_height, self.args.crop_width), resize=True, dataset=args.dataset)
92 | val_transform = get_transformation((512, 512), resize=True, dataset=args.dataset)
93 |
94 | # make the choice of dataset configurable
95 | if self.args.dataset == 'voc2012':
96 | labeled_set = VOCDataset(root_path=root, name='label', ratio=1.0, transformation=transform,
97 | augmentation=None)
98 | val_set = VOCDataset(root_path=root, name='val', ratio=0.5, transformation=val_transform,
99 | augmentation=None)
100 | labeled_loader = DataLoader(labeled_set, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
101 | val_loader = DataLoader(val_set, batch_size=self.args.batch_size, shuffle=True)
102 | elif self.args.dataset == 'cityscapes':
103 | labeled_set = CityscapesDataset(root_path=root_cityscapes, name='label', ratio=0.5, transformation=transform,
104 | augmentation=None)
105 | val_set = CityscapesDataset(root_path=root_cityscapes, name='val', ratio=0.5, transformation=transform,
106 | augmentation=None)
107 | labeled_loader = DataLoader(labeled_set, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
108 | val_loader = DataLoader(val_set, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
109 | elif self.args.dataset == 'acdc':
110 | labeled_set = ACDCDataset(root_path=root_acdc, name='label', ratio=0.5, transformation=transform,
111 | augmentation=None)
112 | val_set = ACDCDataset(root_path=root_acdc, name='val', ratio=0.5, transformation=transform,
113 | augmentation=None)
114 | labeled_loader = DataLoader(labeled_set, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
115 | val_loader = DataLoader(val_set, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
116 |
117 | img_fake_sample = utils.Sample_from_Pool()
118 | gt_fake_sample = utils.Sample_from_Pool()
119 |
120 | for epoch in range(self.start_epoch, self.args.epochs):
121 | self.Gsi.train()
122 | for i, (l_img, l_gt, img_name) in enumerate(labeled_loader):
123 | # step
124 | step = epoch * len(labeled_loader) + i + 1
125 |
126 | self.gsi_optimizer.zero_grad()
127 |
128 | l_img, l_gt = utils.cuda([l_img, l_gt], args.gpu_ids)
129 |
130 | lab_gt = self.Gsi(l_img)
131 |
132 | lab_gt = self.interp(lab_gt) ### To get the output of model same as labels
133 |
134 | # CE losses
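# nn.CrossEntropyLoss expects logits of shape [N, n_channels, H, W] and integer class targets
# of shape [N, H, W]; l_gt carries a singleton channel dimension, hence the squeeze(1) below.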
135 | fullsupervisedloss = self.CE(lab_gt, l_gt.squeeze(1))
136 |
137 | fullsupervisedloss.backward()
138 | self.gsi_optimizer.step()
139 |
140 | print("Epoch: (%3d) (%5d/%5d) | Crossentropy Loss:%.2e" %
141 | (epoch, i + 1, len(labeled_loader), fullsupervisedloss.item()))
142 |
143 | self.writer_supervised.add_scalars('Supervised Loss', {'CE Loss ':fullsupervisedloss}, len(labeled_loader)*epoch + i)
144 |
145 | ### For getting the IoU for the image
146 | self.Gsi.eval()
147 | with torch.no_grad():
148 | for i, (val_img, val_gt, _) in enumerate(val_loader):
149 | val_img, val_gt = utils.cuda([val_img, val_gt], args.gpu_ids)
150 |
151 | outputs = self.Gsi(val_img)
152 | outputs = self.interp_val(outputs)
153 | outputs = self.activation_softmax(outputs)
154 |
155 | pred = outputs.data.max(1)[1].cpu().numpy()
156 | gt = val_gt.squeeze().data.cpu().numpy()
157 |
158 | self.running_metrics_val.update(gt, pred)
159 |
160 | score, class_iou = self.running_metrics_val.get_scores()
161 |
162 | self.running_metrics_val.reset()
163 |
164 | ### Display the images generated by the generator on tensorboard
165 | val_img, val_gt, _ = next(iter(val_loader))
166 | val_img, val_gt = utils.cuda([val_img, val_gt], args.gpu_ids)
167 | with torch.no_grad():
168 | fake = self.Gsi(val_img).detach()
169 | fake = self.interp_val(fake)
170 | fake = self.activation_softmax(fake)
171 | fake_prediction = fake.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
172 | val_gt = val_gt.cpu()
173 |
174 | ### display_tensor is the final tensor that will be displayed on tensorboard
175 | display_tensor = torch.zeros([fake.shape[0], 3, fake.shape[2], fake.shape[3]])
176 | display_tensor_gt = torch.zeros([val_gt.shape[0], 3, val_gt.shape[2], val_gt.shape[3]])
177 | for i in range(fake_prediction.shape[0]):
178 | new_img = fake_prediction[i]
179 | new_img = utils.colorize_mask(new_img, self.args.dataset) ### So this is the generated image in PIL.Image format
180 | img_tensor = utils.PIL_to_tensor(new_img, self.args.dataset)
181 | display_tensor[i, :, :, :] = img_tensor
182 |
183 | display_tensor_gt[i, :, :, :] = val_gt[i]
184 |
185 | self.writer_supervised.add_image('Generated segmented image', torchvision.utils.make_grid(display_tensor, nrow=2, normalize=True), epoch)
186 | self.writer_supervised.add_image('Ground truth for the image', torchvision.utils.make_grid(display_tensor_gt, nrow=2, normalize=True), epoch)
187 |
188 |
189 | if score["Mean IoU : \t"] >= self.best_iou:
190 | self.best_iou = score["Mean IoU : \t"]
191 | # Override the latest checkpoint
192 | utils.save_checkpoint({'epoch': epoch + 1,
193 | 'Gsi': self.Gsi.state_dict(),
194 | 'gsi_optimizer': self.gsi_optimizer.state_dict(),
195 | 'best_iou': self.best_iou,
196 | 'class_iou': class_iou},
197 | '%s/latest_supervised_model.ckpt' % (self.args.checkpoint_dir))
198 |
199 | self.writer_supervised.close()
200 |
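# A minimal sketch of how this class is presumably driven (e.g. from main.py, not shown here);
# the attribute names follow how args is read above, the values are illustrative assumptions.
#
#     import argparse
#     args = argparse.Namespace(dataset='voc2012', ngf=64, norm='instance', no_dropout=False,
#                               gpu_ids=[0], crop_height=320, crop_width=320, lr=1e-4,
#                               batch_size=2, epochs=100, checkpoint_dir='./checkpoint/voc2012')
#     model = supervised_model(args)
#     model.train(args)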
201 |
202 | class semisuper_cycleGAN(object):
203 | def __init__(self, args):
204 |
205 | if args.dataset == 'voc2012':
206 | self.n_channels = 21
207 | elif args.dataset == 'cityscapes':
208 | self.n_channels = 20
209 | elif args.dataset == 'acdc':
210 | self.n_channels = 4
211 |
212 | # Define the network
213 | #####################################################
214 | # for segmentation to image
215 | self.Gis = define_Gen(input_nc=self.n_channels, output_nc=3, ngf=args.ngf, netG='deeplab',
216 | norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
217 | # for image to segmentation
218 | self.Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf, netG='deeplab',
219 | norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
220 | self.Di = define_Dis(input_nc=3, ndf=args.ndf, netD='pixel', n_layers_D=3,
221 | norm=args.norm, gpu_ids=args.gpu_ids)
222 | self.Ds = define_Dis(input_nc=self.n_channels, ndf=args.ndf, netD='pixel', n_layers_D=3,
223 | norm=args.norm, gpu_ids=args.gpu_ids) # for voc 2012, there are 21 classes
224 |
225 | self.old_Gis = define_Gen(input_nc=self.n_channels, output_nc=3, ngf=args.ngf, netG='resnet_9blocks',
226 | norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
227 | self.old_Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf, netG='resnet_9blocks_softmax',
228 | norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
229 | self.old_Di = define_Dis(input_nc=3, ndf=args.ndf, netD='pixel', n_layers_D=3,
230 | norm=args.norm, gpu_ids=args.gpu_ids)
231 |
232 | ### To put the pretrained weights in Gis and Gsi
233 | # if args.dataset != 'acdc':
234 | # saved_state_dict = torch.load(pretrained_loc)
235 | # new_params_Gsi = self.Gsi.state_dict().copy()
236 | # # new_params_Gis = self.Gis.state_dict().copy()
237 | # for name, param in new_params_Gsi.items():
238 | # # print(name)
239 | # if name in saved_state_dict and param.size() == saved_state_dict[name].size():
240 | # new_params_Gsi[name].copy_(saved_state_dict[name])
241 | # # print('copy {}'.format(name))
242 | # self.Gsi.load_state_dict(new_params_Gsi)
243 | # for name, param in new_params_Gis.items():
244 | # # print(name)
245 | # if name in saved_state_dict and param.size() == saved_state_dict[name].size():
246 | # new_params_Gis[name].copy_(saved_state_dict[name])
247 | # # print('copy {}'.format(name))
248 | # # self.Gis.load_state_dict(new_params_Gis)
249 |
250 | ### Load pretrained weights for the old (resnet) generators, which are used for the extra cycle loss
251 | if args.dataset == 'voc2012':
252 | try:
253 | ckpt_for_Arnab_loss = utils.load_checkpoint('./ckpt_for_Arnab_loss.ckpt')
254 | self.old_Gis.load_state_dict(ckpt_for_Arnab_loss['Gis'])
255 | self.old_Gsi.load_state_dict(ckpt_for_Arnab_loss['Gsi'])
256 | except:
257 | print('**There is an error in loading the ckpt_for_Arnab_loss**')
258 |
259 |
260 | utils.print_networks([self.Gsi], ['Gsi'])
261 |
262 |
263 | utils.print_networks([self.Gis,self.Gsi,self.Di,self.Ds], ['Gis','Gsi','Di','Ds'])
264 |
265 | self.args = args
266 |
267 | ### interpolation
268 | self.interp = nn.Upsample((args.crop_height, args.crop_width), mode='bilinear', align_corners=True)
269 |
270 | self.MSE = nn.MSELoss()
271 | self.L1 = nn.L1Loss()
272 | self.CE = nn.CrossEntropyLoss()
273 | self.activation_softmax = nn.Softmax2d()
274 | self.activation_tanh = nn.Tanh()
275 | self.activation_sigmoid = nn.Sigmoid()
276 |
277 | ### Tensorboard writer
278 | self.writer_semisuper = SummaryWriter(tensorboard_loc + '_semisuper')
279 | self.running_metrics_val = utils.runningScore(self.n_channels, args.dataset)
280 |
281 | ### For adding gaussian noise
282 | self.gauss_noise = utils.GaussianNoise(sigma = 0.2)
283 |
284 | # Optimizers
285 | #####################################################
286 | self.g_optimizer = torch.optim.Adam(itertools.chain(self.Gis.parameters(),self.Gsi.parameters()), lr=args.lr, betas=(0.5, 0.999))
287 | self.d_optimizer = torch.optim.Adam(itertools.chain(self.Di.parameters(),self.Ds.parameters()), lr=args.lr, betas=(0.5, 0.999))
288 |
289 |
290 | self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.g_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
291 | self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.d_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
292 |
293 | # Try loading checkpoint
294 | #####################################################
295 | if not os.path.isdir(args.checkpoint_dir):
296 | os.makedirs(args.checkpoint_dir)
297 |
298 | try:
299 | ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir))
300 | self.start_epoch = ckpt['epoch']
301 | self.Di.load_state_dict(ckpt['Di'])
302 | self.Ds.load_state_dict(ckpt['Ds'])
303 | self.Gis.load_state_dict(ckpt['Gis'])
304 | self.Gsi.load_state_dict(ckpt['Gsi'])
305 | self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
306 | self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
307 | self.best_iou = ckpt['best_iou']
308 | except:
309 | print(' [*] No checkpoint!')
310 | self.start_epoch = 0
311 | self.best_iou = -100
312 |
313 |
314 |
315 | def train(self, args):
316 | transform = get_transformation((args.crop_height, args.crop_width), resize=True, dataset=args.dataset)
317 |
318 | # make the choice of dataset configurable
319 | if self.args.dataset == 'voc2012':
320 | labeled_set = VOCDataset(root_path=root, name='label', ratio=0.1, transformation=transform,
321 | augmentation=None)
322 | unlabeled_set = VOCDataset(root_path=root, name='unlabel', ratio=0.1, transformation=transform,
323 | augmentation=None)
324 | val_set = VOCDataset(root_path=root, name='val', ratio=0.5, transformation=transform,
325 | augmentation=None)
326 | elif self.args.dataset == 'cityscapes':
327 | labeled_set = CityscapesDataset(root_path=root_cityscapes, name='label', ratio=0.5, transformation=transform,
328 | augmentation=None)
329 | unlabeled_set = CityscapesDataset(root_path=root_cityscapes, name='unlabel', ratio=0.5, transformation=transform,
330 | augmentation=None)
331 | val_set = CityscapesDataset(root_path=root_cityscapes, name='val', ratio=0.5, transformation=transform,
332 | augmentation=None)
333 | elif self.args.dataset == 'acdc':
334 | labeled_set = ACDCDataset(root_path=root_acdc, name='label', ratio=0.5, transformation=transform,
335 | augmentation=None)
336 | unlabeled_set = ACDCDataset(root_path=root_acdc, name='unlabel', ratio=0.5, transformation=transform,
337 | augmentation=None)
338 | val_set = ACDCDataset(root_path=root_acdc, name='val', ratio=0.5, transformation=transform,
339 | augmentation=None)
340 |
341 | '''
342 | https://discuss.pytorch.org/t/about-the-relation-between-batch-size-and-length-of-data-loader/10510
343 | ^^ drop_last=True is used so that every batch has the same size; the final,
344 | smaller batch is simply dropped
345 | '''
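# e.g. (hypothetical numbers) with 103 labeled images and batch_size=4, drop_last=True yields
# 25 full batches per epoch and silently discards the remaining 3 samples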
346 | labeled_loader = DataLoader(labeled_set, batch_size=args.batch_size, shuffle=True, drop_last=True)
347 | unlabeled_loader = DataLoader(unlabeled_set, batch_size=args.batch_size, shuffle=True, drop_last=True)
348 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, drop_last=True)
349 |
350 | new_img_fake_sample = utils.Sample_from_Pool()
351 | img_fake_sample = utils.Sample_from_Pool()
352 | gt_fake_sample = utils.Sample_from_Pool()
353 |
354 | img_dis_loss, gt_dis_loss, unsupervisedloss, fullsupervisedloss = 0, 0, 0, 0
355 |
356 | ### counter regulates how often the discriminators are updated relative to the generators (with the "% 1" check below they are updated every iteration)
357 | counter = 0
358 |
359 | for epoch in range(self.start_epoch, args.epochs):
360 | lr = self.g_optimizer.param_groups[0]['lr']
361 | print('learning rate = %.7f' % lr)
362 |
363 | self.Gsi.train()
364 | self.Gis.train()
365 |
366 | # if (epoch+1)%10 == 0:
367 | # args.lamda_img = args.lamda_img + 0.08
368 | # args.lamda_gt = args.lamda_gt + 0.04
369 |
370 | for i, ((l_img, l_gt, _), (unl_img, _, _)) in enumerate(zip(labeled_loader, unlabeled_loader)):
371 | # step
372 | step = epoch * min(len(labeled_loader), len(unlabeled_loader)) + i + 1
373 |
374 | l_img, unl_img, l_gt = utils.cuda([l_img, unl_img, l_gt], args.gpu_ids)
375 |
376 | # Generator Computations
377 | ##################################################
378 |
379 | set_grad([self.Di, self.Ds, self.old_Di], False)
380 | set_grad([self.old_Gsi, self.old_Gis], False)
381 | self.g_optimizer.zero_grad()
382 |
383 | # Forward pass through generators
384 | ##################################################
385 | fake_img = self.Gis(make_one_hot(l_gt, args.dataset, args.gpu_ids).float())
386 | fake_gt = self.Gsi(unl_img.float()) ### having 21 channels
387 | lab_gt = self.Gsi(l_img) ### having 21 channels
388 |
389 | ### Getting the outputs of the model to correct dimensions
390 | fake_img = self.interp(fake_img)
391 | fake_gt = self.interp(fake_gt)
392 | lab_gt = self.interp(lab_gt)
393 |
394 | # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0) ### will get into no channels
395 | # fake_gt = fake_gt.unsqueeze(1) ### will get into 1 channel only
396 | # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids)
397 |
398 | lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1))
399 |
400 | ### Again applying activations
401 | lab_gt = self.activation_softmax(lab_gt)
402 | fake_gt = self.activation_softmax(fake_gt)
403 | # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)
404 | # fake_gt = fake_gt.unsqueeze(1)
405 | # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids)
406 | # fake_img = self.activation_tanh(fake_img)
407 |
408 | recon_img = self.Gis(fake_gt.float())
409 | recon_lab_img = self.Gis(lab_gt.float())
410 | recon_gt = self.Gsi(fake_img.float())
411 |
412 | ### Getting the outputs of the model to correct dimensions
413 | recon_img = self.interp(recon_img)
414 | recon_lab_img = self.interp(recon_lab_img)
415 | recon_gt = self.interp(recon_gt)
416 |
417 | ### Extra loss term between the reconstructions produced by the old resnet generators and those from the deeplab generators
418 | resnet_fake_gt = self.old_Gsi(unl_img.float())
419 | resnet_lab_gt = self.old_Gsi(l_img)
420 | resnet_lab_gt = self.activation_softmax(resnet_lab_gt)
421 | resnet_fake_gt = self.activation_softmax(resnet_fake_gt)
422 | resnet_recon_img = self.old_Gis(resnet_fake_gt.float())
423 | resnet_recon_lab_img = self.old_Gis(resnet_lab_gt.float())
424 |
425 | ## Applying the tanh activations
426 | # recon_img = self.activation_tanh(recon_img)
427 | # recon_lab_img = self.activation_tanh(recon_lab_img)
428 |
429 | # Adversarial losses
430 | ###################################################
431 | fake_img_dis = self.Di(fake_img)
432 | resnet_fake_img_dis = self.old_Di(recon_img)
433 |
434 | ### Convert the soft Gsi output to hard one-hot labels before passing it to Ds, matching the format of the real labels
435 | fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)
436 | fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1)
437 | fake_gt_discriminator = make_one_hot(fake_gt_discriminator, args.dataset, args.gpu_ids)
438 | fake_gt_dis = self.Ds(fake_gt_discriminator.float())
439 | # lab_gt_dis = self.Ds(lab_gt)
440 |
441 | real_label_gt = utils.cuda(Variable(torch.ones(fake_gt_dis.size())), args.gpu_ids)
442 | real_label_img = utils.cuda(Variable(torch.ones(fake_img_dis.size())), args.gpu_ids)
443 |
444 | # A cross-entropy loss would arguably be better suited here for classification; MSE (least-squares GAN) is used instead.
445 | img_gen_loss = self.MSE(fake_img_dis, real_label_img)
446 | gt_gen_loss = self.MSE(fake_gt_dis, real_label_gt)
447 | # gt_label_gen_loss = self.MSE(lab_gt_dis, real_label)
448 |
449 |
450 | # Cycle consistency losses
451 | ###################################################
452 | resnet_img_cycle_loss = self.MSE(resnet_fake_img_dis, real_label_img)
453 | # img_cycle_loss = self.L1(recon_img, unl_img)
454 | # img_cycle_loss_perceptual = perceptual_loss(recon_img, unl_img, args.gpu_ids)
455 | gt_cycle_loss = self.CE(recon_gt, l_gt.squeeze(1))
456 | # lab_img_cycle_loss = self.L1(recon_lab_img, l_img) * args.lamda
457 |
458 | # Total generators losses
459 | ###################################################
460 | # lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1))
461 | lab_loss_MSE = self.L1(fake_img, l_img)
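# note that, despite its name, lab_loss_MSE above is an L1 reconstruction loss between the
# generated image and the real labeled image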
462 | # lab_loss_perceptual = perceptual_loss(fake_img, l_img, args.gpu_ids)
463 |
464 | fullsupervisedloss = args.lab_CE_weight*lab_loss_CE + args.lab_MSE_weight*lab_loss_MSE
465 |
466 | unsupervisedloss = args.adversarial_weight*(img_gen_loss + gt_gen_loss) + resnet_img_cycle_loss + gt_cycle_loss*args.lamda_gt
467 |
468 | gen_loss = fullsupervisedloss + unsupervisedloss
469 |
470 | # Update generators
471 | ###################################################
472 | gen_loss.backward()
473 |
474 | self.g_optimizer.step()
475 |
476 |
477 | if counter % 1 == 0:
478 | # Discriminator Computations
479 | #################################################
480 |
481 | set_grad([self.Di, self.Ds, self.old_Di], True)
482 | self.d_optimizer.zero_grad()
483 |
484 | # Sample from history of generated images
485 | #################################################
486 | if torch.rand(1) < 0.0:
487 | fake_img = self.gauss_noise(fake_img.cpu())
488 | fake_gt = self.gauss_noise(fake_gt.cpu())
489 |
490 | recon_img = Variable(torch.Tensor(new_img_fake_sample([recon_img.cpu().data.numpy()])[0]))
491 | fake_img = Variable(torch.Tensor(img_fake_sample([fake_img.cpu().data.numpy()])[0]))
492 | # lab_gt = Variable(torch.Tensor(gt_fake_sample([lab_gt.cpu().data.numpy()])[0]))
493 | fake_gt = Variable(torch.Tensor(gt_fake_sample([fake_gt.cpu().data.numpy()])[0]))
494 |
495 | recon_img, fake_img, fake_gt = utils.cuda([recon_img, fake_img, fake_gt], args.gpu_ids)
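# the numpy round-trip through Sample_from_Pool also detaches these tensors from the generator
# graph, so the discriminator update below does not backpropagate into Gis/Gsi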
496 |
497 | # Forward pass through discriminators
498 | #################################################
499 | unl_img_dis = self.Di(unl_img)
500 | fake_img_dis = self.Di(fake_img)
501 | resnet_recon_img_dis = self.old_Di(resnet_recon_img)
502 | resnet_fake_img_dis = self.old_Di(recon_img)
503 |
504 | # lab_gt_dis = self.Ds(lab_gt)
505 |
506 | l_gt = make_one_hot(l_gt, args.dataset, args.gpu_ids)
507 | real_gt_dis = self.Ds(l_gt.float())
508 |
509 | fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)
510 | fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1)
511 | fake_gt_discriminator = make_one_hot(fake_gt_discriminator, args.dataset, args.gpu_ids)
512 | fake_gt_dis = self.Ds(fake_gt_discriminator.float())
513 |
514 | real_label_img = utils.cuda(Variable(torch.ones(unl_img_dis.size())), args.gpu_ids)
515 | fake_label_img = utils.cuda(Variable(torch.zeros(fake_img_dis.size())), args.gpu_ids)
516 | real_label_gt = utils.cuda(Variable(torch.ones(real_gt_dis.size())), args.gpu_ids)
517 | fake_label_gt = utils.cuda(Variable(torch.zeros(fake_gt_dis.size())), args.gpu_ids)
518 |
519 | # Discriminator losses
520 | ##################################################
521 | img_dis_real_loss = self.MSE(unl_img_dis, real_label_img)
522 | img_dis_fake_loss = self.MSE(fake_img_dis, fake_label_img)
523 | gt_dis_real_loss = self.MSE(real_gt_dis, real_label_gt)
524 | gt_dis_fake_loss = self.MSE(fake_gt_dis, fake_label_gt)
525 | # lab_gt_dis_fake_loss = self.MSE(lab_gt_dis, fake_label)
526 |
527 | cycle_img_dis_real_loss = self.MSE(resnet_recon_img_dis, real_label_img)
528 | cycle_img_dis_fake_loss = self.MSE(resnet_fake_img_dis, fake_label_img)
529 |
530 | # Total discriminators losses
531 | img_dis_loss = (img_dis_real_loss + img_dis_fake_loss)*0.5
532 | gt_dis_loss = (gt_dis_real_loss + gt_dis_fake_loss)*0.5
533 | # lab_gt_dis_loss = (gt_dis_real_loss + lab_gt_dis_fake_loss)*0.33
534 | cycle_img_dis_loss = cycle_img_dis_real_loss + cycle_img_dis_fake_loss
535 |
536 | # Update discriminators
537 | ##################################################
538 | discriminator_loss = args.discriminator_weight * (img_dis_loss + gt_dis_loss) + cycle_img_dis_loss
539 | discriminator_loss.backward()
540 |
541 | # lab_gt_dis_loss.backward()
542 | self.d_optimizer.step()
543 |
544 | print("Epoch: (%3d) (%5d/%5d) | Dis Loss:%.2e | Unlab Gen Loss:%.2e | Lab Gen loss:%.2e" %
545 | (epoch, i + 1, min(len(labeled_loader), len(unlabeled_loader)),
546 | img_dis_loss + gt_dis_loss, unsupervisedloss, fullsupervisedloss))
547 |
548 | self.writer_semisuper.add_scalars('Dis Loss', {'img_dis_loss':img_dis_loss, 'gt_dis_loss':gt_dis_loss, 'cycle_img_dis_loss':cycle_img_dis_loss}, len(labeled_loader)*epoch + i)
549 | self.writer_semisuper.add_scalars('Unlabelled Loss', {'img_gen_loss': img_gen_loss, 'gt_gen_loss':gt_gen_loss, 'img_cycle_loss':resnet_img_cycle_loss, 'gt_cycle_loss':gt_cycle_loss}, len(labeled_loader)*epoch + i)
550 | self.writer_semisuper.add_scalars('Labelled Loss', {'lab_loss_CE':lab_loss_CE, 'lab_loss_MSE':lab_loss_MSE}, len(labeled_loader)*epoch + i)
551 |
552 | counter += 1
553 |
554 | ### For getting the mean IoU
555 | self.Gsi.eval()
556 | self.Gis.eval()
557 | with torch.no_grad():
558 | for i, (val_img, val_gt, _) in enumerate(val_loader):
559 | val_img, val_gt = utils.cuda([val_img, val_gt], args.gpu_ids)
560 |
561 | outputs = self.Gsi(val_img)
562 | outputs = self.interp(outputs)
563 | outputs = self.activation_softmax(outputs)
564 |
565 | pred = outputs.data.max(1)[1].cpu().numpy()
566 | gt = val_gt.squeeze().data.cpu().numpy()
567 |
568 | self.running_metrics_val.update(gt, pred)
569 |
570 | score, class_iou = self.running_metrics_val.get_scores()
571 |
572 | self.running_metrics_val.reset()
573 |
574 | print('The mIoU for the epoch is: ', score["Mean IoU : \t"])
575 |
576 | ### Display the images generated by the generator on tensorboard, using validation images
577 | val_image, val_gt, _ = next(iter(val_loader))
578 | val_image, val_gt = utils.cuda([val_image, val_gt], args.gpu_ids)
579 | with torch.no_grad():
580 | fake_label = self.Gsi(val_image).detach()
581 | fake_label = self.interp(fake_label)
582 | fake_label = self.activation_softmax(fake_label)
583 | fake_label = fake_label.data.max(1)[1].squeeze_(1).squeeze_(0)
584 | fake_label = fake_label.unsqueeze(1)
585 | fake_label = make_one_hot(fake_label, args.dataset, args.gpu_ids)
586 | fake_img = self.Gis(fake_label).detach()
587 | fake_img = self.interp(fake_img)
588 | # fake_img = self.activation_tanh(fake_img)
589 |
590 | fake_img_from_labels = self.Gis(make_one_hot(val_gt, args.dataset, args.gpu_ids).float()).detach()
591 | fake_img_from_labels = self.interp(fake_img_from_labels)
592 | # fake_img_from_labels = self.activation_tanh(fake_img_from_labels)
593 | fake_label_regenerated = self.Gsi(fake_img_from_labels).detach()
594 | fake_label_regenerated = self.interp(fake_label_regenerated)
595 | fake_label_regenerated = self.activation_softmax(fake_label_regenerated)
596 | fake_prediction_label = fake_label.data.max(1)[1].squeeze_(1).cpu().numpy()
597 | fake_regenerated_label = fake_label_regenerated.data.max(1)[1].squeeze_(1).cpu().numpy()
598 | val_gt = val_gt.cpu()
599 |
600 | fake_img = fake_img.cpu()
601 | fake_img_from_labels = fake_img_from_labels.cpu()
602 | ### Revert the normalization that was applied to these images
603 | if self.args.dataset == 'voc2012' or self.args.dataset == 'cityscapes':
604 | trans_mean = [0.5, 0.5, 0.5]
605 | trans_std = [0.5, 0.5, 0.5]
606 | for i in range(3):
607 | fake_img[:, i, :, :] = ((fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i])
608 | fake_img_from_labels[:, i, :, :] = ((fake_img_from_labels[:, i, :, :] * trans_std[i]) + trans_mean[i])
609 |
610 | elif self.args.dataset == 'acdc':
611 | trans_mean = [0.5]
612 | trans_std = [0.5]
613 | for i in range(1):
614 | fake_img[:, i, :, :] = ((fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i])
615 | fake_img_from_labels[:, i, :, :] = ((fake_img_from_labels[:, i, :, :] * trans_std[i]) + trans_mean[i])
616 |
617 | ### display_tensor is the final tensor that will be displayed on tensorboard
618 | display_tensor_label = torch.zeros([fake_label.shape[0], 3, fake_label.shape[2], fake_label.shape[3]])
619 | display_tensor_gt = torch.zeros([val_gt.shape[0], 3, val_gt.shape[2], val_gt.shape[3]])
620 | display_tensor_regen_label = torch.zeros([fake_label_regenerated.shape[0], 3, fake_label_regenerated.shape[2], fake_label_regenerated.shape[3]])
621 | for i in range(fake_prediction_label.shape[0]):
622 | new_img_label = fake_prediction_label[i]
623 | new_img_label = utils.colorize_mask(new_img_label, self.args.dataset) ### So this is the generated image in PIL.Image format
624 | img_tensor_label = utils.PIL_to_tensor(new_img_label, self.args.dataset)
625 | display_tensor_label[i, :, :, :] = img_tensor_label
626 |
627 | display_tensor_gt[i, :, :, :] = val_gt[i]
628 |
629 | regen_label = fake_regenerated_label[i]
630 | regen_label = utils.colorize_mask(regen_label, self.args.dataset)
631 | regen_tensor_label = utils.PIL_to_tensor(regen_label, self.args.dataset)
632 | display_tensor_regen_label[i, :, :, :] = regen_tensor_label
633 |
634 | self.writer_semisuper.add_image('Generated segmented image: ', torchvision.utils.make_grid(display_tensor_label, nrow=2, normalize=True), epoch)
635 | self.writer_semisuper.add_image('Generated image back from segmentation: ', torchvision.utils.make_grid(fake_img, nrow=2, normalize=True), epoch)
636 | self.writer_semisuper.add_image('Ground truth for the image: ', torchvision.utils.make_grid(display_tensor_gt, nrow=2, normalize=True), epoch)
637 | self.writer_semisuper.add_image('Image generated from val labels: ', torchvision.utils.make_grid(fake_img_from_labels, nrow=2, normalize=True), epoch)
638 | self.writer_semisuper.add_image('Labels generated back from the cycle: ', torchvision.utils.make_grid(display_tensor_regen_label, nrow=2, normalize=True), epoch)
639 |
640 |
641 | if score["Mean IoU : \t"] >= self.best_iou:
642 | self.best_iou = score["Mean IoU : \t"]
643 |
644 | # Override the latest checkpoint
645 | #######################################################
646 | utils.save_checkpoint({'epoch': epoch + 1,
647 | 'Di': self.Di.state_dict(),
648 | 'Ds': self.Ds.state_dict(),
649 | 'Gis': self.Gis.state_dict(),
650 | 'Gsi': self.Gsi.state_dict(),
651 | 'd_optimizer': self.d_optimizer.state_dict(),
652 | 'g_optimizer': self.g_optimizer.state_dict(),
653 | 'best_iou': self.best_iou,
654 | 'class_iou': class_iou},
655 | '%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir))
656 |
657 | # Update learning rates
658 | ########################
659 | self.g_lr_scheduler.step()
660 | self.d_lr_scheduler.step()
661 |
662 | self.writer_semisuper.close()
663 |
--------------------------------------------------------------------------------