├── requirements.txt ├── arch ├── __init__.py ├── discriminators.py ├── generators.py └── ops.py ├── examples ├── model_image.png ├── result_square.png ├── VOC │ ├── 2007_002105_gt.png │ ├── 2007_002105_img.jpg │ ├── 2007_002105_semiSup.png │ └── 2007_002105_imgFromSeg.jpg ├── ACDC │ ├── patient001_frame01_2_gt.png │ ├── patient001_frame01_2_img.jpg │ ├── patient001_frame01_2_seg2img.jpg │ └── patient001_frame01_2_genLabels.png └── Cityscapes │ ├── dusseldorf_000176_000019_leftImg8bit_gt.png │ ├── dusseldorf_000176_000019_leftImg8bit_img.jpg │ ├── dusseldorf_000176_000019_leftImg8bit_seg2img.jpg │ └── dusseldorf_000176_000019_leftImg8bit_genLabels.png ├── .gitignore ├── data └── Datasets.txt ├── LICENSE ├── main.py ├── testing.py ├── README.md ├── data_utils ├── augmentations.py ├── __init__.py └── dataloader.py ├── validation.py ├── utils.py └── model.py /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | scipy 4 | tensorboardX 5 | pillow 6 | -------------------------------------------------------------------------------- /arch/__init__.py: -------------------------------------------------------------------------------- 1 | from .generators import define_Gen 2 | from .discriminators import define_Dis 3 | from .ops import set_grad -------------------------------------------------------------------------------- /examples/model_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/model_image.png -------------------------------------------------------------------------------- /examples/result_square.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/result_square.png -------------------------------------------------------------------------------- /examples/VOC/2007_002105_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/VOC/2007_002105_gt.png -------------------------------------------------------------------------------- /examples/VOC/2007_002105_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/VOC/2007_002105_img.jpg -------------------------------------------------------------------------------- /examples/VOC/2007_002105_semiSup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/VOC/2007_002105_semiSup.png -------------------------------------------------------------------------------- /examples/ACDC/patient001_frame01_2_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/ACDC/patient001_frame01_2_gt.png -------------------------------------------------------------------------------- /examples/VOC/2007_002105_imgFromSeg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/VOC/2007_002105_imgFromSeg.jpg -------------------------------------------------------------------------------- /examples/ACDC/patient001_frame01_2_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/ACDC/patient001_frame01_2_img.jpg -------------------------------------------------------------------------------- /examples/ACDC/patient001_frame01_2_seg2img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/ACDC/patient001_frame01_2_seg2img.jpg -------------------------------------------------------------------------------- /examples/ACDC/patient001_frame01_2_genLabels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/ACDC/patient001_frame01_2_genLabels.png -------------------------------------------------------------------------------- /examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_gt.png -------------------------------------------------------------------------------- /examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_img.jpg -------------------------------------------------------------------------------- /examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_seg2img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_seg2img.jpg -------------------------------------------------------------------------------- /examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_genLabels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arnab39/Semi-supervised-segmentation-cycleGAN/HEAD/examples/Cityscapes/dusseldorf_000176_000019_leftImg8bit_genLabels.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.iml 2 | *.xml 3 | *.DS_Store 4 | # datasets 5 | datasets/VOC2012/* 6 | datasets/horse2zebra/* 7 | datasets/Cityscapes/* 8 | datasets/VOC2012/* 9 | datasets/Cityspaces/* 10 | arch/__pycache__/ 11 | arch/__pycache__/* 12 | datasets/horse2zebra/* 13 | __pycache__/ 14 | -------------------------------------------------------------------------------- /data/Datasets.txt: -------------------------------------------------------------------------------- 1 | File containing the links to public segmentation benchmarks. 2 | - PASCAL VOC 2012 Segmentation (http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) 3 | - Cityscapes (https://www.cityscapes-dataset.com) 4 | - ACDC (https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html) 5 | 6 | Or you can download the preprocessed datasets from the link mentioned in the README 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Arnab Kumar Mondal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | import model as md 4 | from utils import create_link 5 | from testing import test 6 | from validation import validation 7 | 8 | 9 | # To get arguments from commandline 10 | def get_args(): 11 | parser = ArgumentParser(description='cycleGAN PyTorch') 12 | parser.add_argument('--epochs', type=int, default=400) 13 | parser.add_argument('--decay_epoch', type=int, default=100) 14 | parser.add_argument('--batch_size', type=int, default=2) 15 | parser.add_argument('--lr', type=float, default=.0002) 16 | # parser.add_argument('--load_height', type=int, default=286) 17 | # parser.add_argument('--load_width', type=int, default=286) 18 | parser.add_argument('--gpu_ids', type=str, default='0') 19 | parser.add_argument('--crop_height', type=int, default=None) 20 | parser.add_argument('--crop_width', type=int, default=None) 21 | parser.add_argument('--lamda_img', type=int, default=0.5) # For image_cycle_loss 22 | parser.add_argument('--lamda_gt', type=int, default=0.1) # For gt_cycle_loss 23 | parser.add_argument('--lamda_perceptual', type=int, default=0) # For image cycle perceptual loss 24 | # parser.add_argument('--idt_coef', type=float, default=0.5) 25 | # parser.add_argument('--omega', type=int, default=5) 26 | parser.add_argument('--lab_CE_weight', type=int, default=1) 27 | parser.add_argument('--lab_MSE_weight', type=int, default=1) 28 | parser.add_argument('--lab_perceptual_weight', type=int, default=0) 29 | parser.add_argument('--adversarial_weight', type=int, default=1.0) 30 | parser.add_argument('--discriminator_weight', type=int, default=1.0) 31 | parser.add_argument('--training', type=bool, default=False) 32 | parser.add_argument('--testing', type=bool, default=False) 33 | parser.add_argument('--validation', type=bool, default=False) 34 | parser.add_argument('--model', type=str, default='supervised_model') 35 | parser.add_argument('--results_dir', type=str, default='./results') 36 | parser.add_argument('--validation_dir', type=str, default='./val_results') 37 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/semisupervised_cycleGAN') 38 | parser.add_argument('--dataset',type=str,choices=['voc2012', 'cityscapes', 'acdc'],default='voc2012') 39 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 40 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 41 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 42 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 43 | parser.add_argument('--gen_net', type=str, default='deeplab') 44 | parser.add_argument('--dis_net', type=str, default='fc_disc') 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def main(): 50 | args = get_args() 51 | # set gpu ids 52 | str_ids = args.gpu_ids.split(',') 53 | args.gpu_ids = [] 54 | for str_id in str_ids: 55 | id = int(str_id) 56 | if id >= 0: 57 | args.gpu_ids.append(id) 58 | 59 | ### For setting the image dimensions for different datasets 60 | if args.crop_height == None and args.crop_width == None: 61 | if args.dataset == 'voc2012': 62 | args.crop_height = args.crop_width = 320 63 | elif args.dataset == 'acdc': 64 | args.crop_height = args.crop_width = 256 65 | elif args.dataset == 'cityscapes': 66 | args.crop_height = 512 67 | args.crop_width = 1024 68 | 69 | if args.training: 70 | if args.model == "semisupervised_cycleGAN": 71 | print("Training semi-supervised cycleGAN") 72 | model = md.semisuper_cycleGAN(args) 73 | model.train(args) 74 | if args.model == "supervised_model": 75 | print("Training base model") 76 | model = md.supervised_model(args) 77 | model.train(args) 78 | if args.testing: 79 | print("Testing") 80 | test(args) 81 | if args.validation: 82 | print("Validating") 83 | validation(args) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /arch/discriminators.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .ops import conv_norm_lrelu, get_norm_layer, init_network 3 | import torch 4 | from torch.nn import functional as F 5 | import functools 6 | 7 | 8 | class FCDiscriminator(nn.Module): 9 | ### This has been adapted directly from this repo: 10 | ### https://github.com/hfslyc/AdvSemiSeg 11 | 12 | def __init__(self, num_classes, ndf = 64): 13 | super(FCDiscriminator, self).__init__() 14 | 15 | self.conv1 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1) 16 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 17 | self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 18 | self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 19 | self.classifier = nn.Conv2d(ndf*8, 1, kernel_size=4, stride=2, padding=1) 20 | 21 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 22 | #self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') 23 | #self.sigmoid = nn.Sigmoid() 24 | 25 | 26 | def forward(self, x): 27 | x = self.conv1(x) 28 | x = self.leaky_relu(x) 29 | x = self.conv2(x) 30 | x = self.leaky_relu(x) 31 | x = self.conv3(x) 32 | x = self.leaky_relu(x) 33 | x = self.conv4(x) 34 | x = self.leaky_relu(x) 35 | x = self.classifier(x) 36 | #x = self.up_sample(x) 37 | #x = self.sigmoid(x) 38 | 39 | return x 40 | 41 | 42 | class NLayerDiscriminator(nn.Module): 43 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_bias=False): 44 | super(NLayerDiscriminator, self).__init__() 45 | dis_model = [nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1), 46 | nn.LeakyReLU(0.2, True)] 47 | nf_mult = 1 48 | nf_mult_prev = 1 49 | for n in range(1, n_layers): 50 | nf_mult_prev = nf_mult 51 | nf_mult = min(2**n, 8) 52 | dis_model += [conv_norm_lrelu(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=2, 53 | norm_layer= norm_layer, padding=1, bias=use_bias)] 54 | nf_mult_prev = nf_mult 55 | nf_mult = min(2**n_layers, 8) 56 | dis_model += [conv_norm_lrelu(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=1, 57 | norm_layer= norm_layer, padding=1, bias=use_bias)] 58 | dis_model += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)] 59 | 60 | self.dis_model = nn.Sequential(*dis_model) 61 | 62 | def forward(self, input): 63 | return self.dis_model(input) 64 | 65 | 66 | class PixelDiscriminator(nn.Module): 67 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_bias=False): 68 | super(PixelDiscriminator, self).__init__() 69 | dis_model = [ 70 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 71 | nn.LeakyReLU(0.2, True), 72 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 73 | norm_layer(ndf * 2), 74 | nn.LeakyReLU(0.2, True), 75 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 76 | 77 | self.dis_model = nn.Sequential(*dis_model) 78 | 79 | def forward(self, input): 80 | return self.dis_model(input) 81 | 82 | 83 | 84 | def define_Dis(input_nc, ndf, netD, n_layers_D=3, norm='batch', gpu_ids=[0]): 85 | dis_net = None 86 | norm_layer = get_norm_layer(norm_type=norm) 87 | if type(norm_layer) == functools.partial: 88 | use_bias = norm_layer.func == nn.InstanceNorm2d 89 | else: 90 | use_bias = norm_layer == nn.InstanceNorm2d 91 | 92 | if netD == 'n_layers': 93 | dis_net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_bias=use_bias) 94 | elif netD == 'pixel': 95 | dis_net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_bias=use_bias) 96 | elif netD == 'fc_disc': 97 | dis_net = FCDiscriminator(input_nc, ndf) 98 | else: 99 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) 100 | 101 | return init_network(dis_net, gpu_ids) 102 | -------------------------------------------------------------------------------- /testing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import torchvision 7 | import torchvision.datasets as dsets 8 | from torch.utils.data import DataLoader 9 | import torchvision.transforms as transforms 10 | import utils 11 | from PIL import Image 12 | from arch import define_Gen 13 | from data_utils import VOCDataset, CityscapesDataset, ACDCDataset, get_transformation 14 | 15 | root = './data/VOC2012test/VOC2012' 16 | root_cityscapes = './data/Cityscape' 17 | root_acdc = './data/ACDC' 18 | 19 | def test(args): 20 | 21 | ### For selecting the number of channels 22 | if args.dataset == 'voc2012': 23 | n_channels = 21 24 | elif args.dataset == 'cityscapes': 25 | n_channels = 20 26 | elif args.dataset == 'acdc': 27 | n_channels = 4 28 | 29 | transform = get_transformation((args.crop_height, args.crop_width), resize=True, dataset=args.dataset) 30 | 31 | ## let the choice of dataset configurable 32 | if args.dataset == 'voc2012': 33 | test_set = VOCDataset(root_path=root, name='test', ratio=0.5, transformation=transform, augmentation=None) 34 | elif args.dataset == 'cityscapes': 35 | test_set = CityscapesDataset(root_path=root_cityscapes, name='test', ratio=0.5, transformation=transform, augmentation=None) 36 | elif args.dataset == 'acdc': 37 | test_set = ACDCDataset(root_path=root_acdc, name='test', ratio=0.5, transformation=transform, augmentation=None) 38 | 39 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False) 40 | 41 | Gsi = define_Gen(input_nc=3, output_nc=n_channels, ngf=args.ngf, netG='resnet_9blocks_softmax', 42 | norm=args.norm, use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids) 43 | 44 | ### activation_softmax 45 | activation_softmax = nn.Softmax2d() 46 | 47 | if(args.model == 'supervised_model'): 48 | 49 | ### loading the checkpoint 50 | try: 51 | ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir)) 52 | Gsi.load_state_dict(ckpt['Gsi']) 53 | 54 | except: 55 | print(' [*] No checkpoint!') 56 | 57 | ### run 58 | Gsi.eval() 59 | for i, (image_test, image_name) in enumerate(test_loader): 60 | image_test = utils.cuda(image_test, args.gpu_ids) 61 | seg_map = Gsi(image_test) 62 | seg_map = activation_softmax(seg_map) 63 | 64 | prediction = seg_map.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy() ### To convert from 22 --> 1 channel 65 | for j in range(prediction.shape[0]): 66 | new_img = prediction[j] ### Taking a particular image from the batch 67 | new_img = utils.colorize_mask(new_img, args.dataset) ### So as to convert it back to a paletted image 68 | 69 | ### Now the new_img is PIL.Image 70 | new_img.save(os.path.join(args.results_dir+'/supervised/'+image_name[j]+'.png')) 71 | 72 | 73 | print('Epoch-', str(i+1), ' Done!') 74 | 75 | elif(args.model == 'semisupervised_cycleGAN'): 76 | 77 | ### loading the checkpoint 78 | try: 79 | ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir)) 80 | Gsi.load_state_dict(ckpt['Gsi']) 81 | 82 | except: 83 | print(' [*] No checkpoint!') 84 | 85 | ### run 86 | Gsi.eval() 87 | for i, (image_test, image_name) in enumerate(test_loader): 88 | image_test = utils.cuda(image_test, args.gpu_ids) 89 | seg_map = Gsi(image_test) 90 | seg_map = activation_softmax(seg_map) 91 | 92 | prediction = seg_map.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy() ### To convert from 22 --> 1 channel 93 | for j in range(prediction.shape[0]): 94 | new_img = prediction[j] ### Taking a particular image from the batch 95 | new_img = utils.colorize_mask(new_img, args.dataset) ### So as to convert it back to a paletted image 96 | 97 | ### Now the new_img is PIL.Image 98 | new_img.save(os.path.join(args.results_dir+'/unsupervised/'+image_name[j]+'.png')) 99 | 100 | print('Epoch-', str(i+1), ' Done!') 101 | 102 | 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Revisting Cycle-GAN for semi-supervised segmentation 2 | This repo contains the official Pytorch implementation of the paper: [Revisiting CycleGAN for semi-supervised segmentation](https://arxiv.org/abs/1908.11569) 3 | 4 | ## Contents 5 | 1. [Summary of the Model](#1-summary-of-the-model) 6 | 2. [Setup instructions and dependancies](#2-setup-instructions-and-dependancies) 7 | 3. [Repository Overview](#3-repository-overview) 8 | 4. [Running the model](#4-running-the-model) 9 | 5. [Some results of the paper](#5-some-results-of-the-paper) 10 | 6. [Contact](#6-contact) 11 | 7. [License](#7-license) 12 | 13 | ## 1. Summary of the Model 14 | The following shows the training procedure for our proposed model 15 | 16 | 17 | 18 | We propose a training procedure for semi-supervised segmentation using the principles of image-to-image translation using GANs. The proposed procedure has been evaluated on three segmentation datasets, namely **VOC**, **Cityscapes**, **ACDC**. We are easily able to achieve *2-4% improvement in the mean IoU for all of our semisupervised model* as compared to the supervised model on the same amount of data. For further information regarding the model, training procedure details you may refer to the paper for further details. 19 | 20 | ## 2. Setup instructions and dependancies 21 | The code has been written in *Python 3.6* and *Pytorch v1.0* with *Torchvision v0.3*. You can install all the dependancies for the model by running the following command in a virtual environment 22 | ``` 23 | pip install -r requirements.txt 24 | ``` 25 | For training/testing the model, you must first download all the 3 datasets. You can download the processed version of the dataset [here](https://drive.google.com/drive/folders/1v_ahAjNGPHjw9G15dlxjImtolne6HZ0B?usp=sharing). Also for storing the results of the validation/testing datasets, checkpoints and tensorboard logs, the directory structure must in the following way: 26 | 27 | . 28 | ├── arch 29 | │ ├── ... 30 | ├── data # Follow the way the dataset has been placed here 31 | │ ├── ACDC # Here the ACDC dataset must be placed 32 | │ └── Cityscape # Here the Cityscapes dataset must be placed 33 | │ └── VOC2012 # Here the VOC train/val dataset must be placed 34 | │ └── VOC2012test # Here the VOC test dataset must be placed 35 | ├── data_utils 36 | │ ├── ... 37 | ├── checkpoints # Create this directory to store checkpoints 38 | ├── examples 39 | ├── main.py 40 | ├── model.py 41 | ├── README.md 42 | ├── testing.py 43 | ├── utils.py 44 | ├── validation.py 45 | ├── results # Create this directory for storing the results 46 | │ ├── supervised # Directory for storing supervised results 47 | │ └── unsupervised # Directory for storing semisupervised results 48 | └── tensorboard_results # Create this directory to store tensorboard log curves 49 | 50 | 51 | ## 3. Repository Overview 52 | The following are the information regarding the various important files in the directory and their function: 53 | 54 | - `arch` : The directory stores the architectures for the generators and discriminators used in our model 55 | - `data_utils` : The dataloaders and also helper functions for data processing 56 | - `main.py` : Stores the various hyperparameter information and default settings 57 | - `model.py` : Stores the training procedure for both supervised and semisupervised model, and also checkpointing and logging details 58 | - `utils.py` : Stores the helper functions required for training 59 | 60 | ## 4. Running the model 61 | You configure the various defaults that are being specified in the `main.py` file. And also modify the supervision percentage on the dataset by modifying the dataloader calling function in the `model.py` file. 62 | 63 | For training/validation/testing the our proposed semisupervised model: 64 | 65 | ``` 66 | python main.py --model 'semisupervised_cycleGAN' --dataset 'voc2012' --gpu_ids '0' --training True 67 | ``` 68 | 69 | Similar commands for the *validation* and *testing* can be put up by replacing `--training` with `--validation` and `--testing` respectively. 70 | 71 | ## 5. Some results of the paper 72 | Some of the results produced by our semisupervised model are as follows. *For more such results, consider seeing the main paper and also its supplementary section* 73 | 74 | 75 | 76 | *The generated labels are the labels obtained from semisupervised model while the generated image are the images obtained by passing labels from label to image network* 77 | 78 | ## 6. Contact 79 | If you have found our research work helpful, please consider citing the original paper. 80 | 81 | If you have any doubt regarding the codebase, you can open up an issue or mail at sanu.arnab@gmail.com / aagarwal@ma.iitr.ac.in 82 | 83 | ## 7. License 84 | This repository is licensed under MIT license 85 | -------------------------------------------------------------------------------- /data_utils/augmentations.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | import numpy as np 5 | import torchvision.transforms.functional as tf 6 | from PIL import Image, ImageOps 7 | 8 | 9 | class Compose(object): 10 | def __init__(self, augmentations): 11 | self.augmentations = augmentations 12 | self.PIL2Numpy = False 13 | 14 | def __call__(self, img, mask): 15 | if isinstance(img, np.ndarray): 16 | img = Image.fromarray(img, mode="RGB") 17 | mask = Image.fromarray(mask, mode="L") 18 | self.PIL2Numpy = True 19 | 20 | assert img.size == mask.size 21 | for a in self.augmentations: 22 | img, mask = a(img, mask) 23 | 24 | if self.PIL2Numpy: 25 | img, mask = np.array(img), np.array(mask, dtype=np.uint8) 26 | 27 | return img, mask 28 | 29 | 30 | class RandomCrop(object): 31 | def __init__(self, size, padding=0): 32 | if isinstance(size, numbers.Number): 33 | self.size = (int(size), int(size)) 34 | else: 35 | self.size = size 36 | self.padding = padding 37 | 38 | def __call__(self, img, mask): 39 | if self.padding > 0: 40 | img = ImageOps.expand(img, border=self.padding, fill=0) 41 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 42 | 43 | assert img.size == mask.size 44 | w, h = img.size 45 | th, tw = self.size 46 | if w == tw and h == th: 47 | return img, mask 48 | if w < tw or h < th: 49 | return ( 50 | img.resize((tw, th), Image.BILINEAR), 51 | mask.resize((tw, th), Image.NEAREST), 52 | ) 53 | 54 | x1 = random.randint(0, w - tw) 55 | y1 = random.randint(0, h - th) 56 | return ( 57 | img.crop((x1, y1, x1 + tw, y1 + th)), 58 | mask.crop((x1, y1, x1 + tw, y1 + th)), 59 | ) 60 | 61 | 62 | class CenterCrop(object): 63 | def __init__(self, size): 64 | if isinstance(size, numbers.Number): 65 | self.size = (int(size), int(size)) 66 | else: 67 | self.size = size 68 | 69 | def __call__(self, img, mask): 70 | assert img.size == mask.size 71 | w, h = img.size 72 | th, tw = self.size 73 | x1 = int(round((w - tw) / 2.)) 74 | y1 = int(round((h - th) / 2.)) 75 | return ( 76 | img.crop((x1, y1, x1 + tw, y1 + th)), 77 | mask.crop((x1, y1, x1 + tw, y1 + th)), 78 | ) 79 | 80 | 81 | class RandomRotate(object): 82 | def __init__(self, degree): 83 | self.degree = degree 84 | 85 | def __call__(self, img, mask): 86 | rotate_degree = random.random() * 2 * self.degree - self.degree 87 | return ( 88 | tf.affine(img, 89 | translate=(0, 0), 90 | scale=1.0, 91 | angle=rotate_degree, 92 | resample=Image.BILINEAR, 93 | fillcolor=(0, 0, 0), 94 | shear=0.0), 95 | tf.affine(mask, 96 | translate=(0, 0), 97 | scale=1.0, 98 | angle=rotate_degree, 99 | resample=Image.NEAREST, 100 | fillcolor=250, 101 | shear=0.0)) 102 | 103 | 104 | class Scale(object): 105 | def __init__(self, size): 106 | self.size = size 107 | 108 | def __call__(self, img, mask): 109 | assert img.size == mask.size 110 | w, h = img.size 111 | if (w >= h and w == self.size) or (h >= w and h == self.size): 112 | return img, mask 113 | if w > h: 114 | ow = self.size 115 | oh = int(self.size * h / w) 116 | return ( 117 | img.resize((ow, oh), Image.BILINEAR), 118 | mask.resize((ow, oh), Image.NEAREST), 119 | ) 120 | else: 121 | oh = self.size 122 | ow = int(self.size * w / h) 123 | return ( 124 | img.resize((ow, oh), Image.BILINEAR), 125 | mask.resize((ow, oh), Image.NEAREST), 126 | ) 127 | 128 | 129 | class RandomSizedCrop(object): 130 | def __init__(self, size): 131 | self.size = size 132 | 133 | def __call__(self, img, mask): 134 | assert img.size == mask.size 135 | for attempt in range(10): 136 | area = img.size[0] * img.size[1] 137 | target_area = random.uniform(0.45, 1.0) * area 138 | aspect_ratio = random.uniform(0.5, 2) 139 | 140 | w = int(round(math.sqrt(target_area * aspect_ratio))) 141 | h = int(round(math.sqrt(target_area / aspect_ratio))) 142 | 143 | if random.random() < 0.5: 144 | w, h = h, w 145 | 146 | if w <= img.size[0] and h <= img.size[1]: 147 | x1 = random.randint(0, img.size[0] - w) 148 | y1 = random.randint(0, img.size[1] - h) 149 | 150 | img = img.crop((x1, y1, x1 + w, y1 + h)) 151 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 152 | assert img.size == (w, h) 153 | 154 | return ( 155 | img.resize((self.size, self.size), Image.BILINEAR), 156 | mask.resize((self.size, self.size), Image.NEAREST), 157 | ) 158 | 159 | # Fallback 160 | scale = Scale(self.size) 161 | crop = CenterCrop(self.size) 162 | return crop(*scale(img, mask)) 163 | 164 | 165 | class RandomSized(object): 166 | def __init__(self, size): 167 | self.size = size 168 | self.scale = Scale(self.size) 169 | self.crop = RandomCrop(self.size) 170 | 171 | def __call__(self, img, mask): 172 | assert img.size == mask.size 173 | 174 | w = int(random.uniform(0.5, 2) * img.size[0]) 175 | h = int(random.uniform(0.5, 2) * img.size[1]) 176 | 177 | img, mask = ( 178 | img.resize((w, h), Image.BILINEAR), 179 | mask.resize((w, h), Image.NEAREST), 180 | ) 181 | 182 | return self.crop(*self.scale(img, mask)) 183 | -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import * 2 | import numpy as np 3 | import torch 4 | import random 5 | from PIL import ImageOps, ImageFilter, ImageEnhance 6 | from .dataloader import VOCDataset, CityscapesDataset, ACDCDataset 7 | 8 | 9 | def colormap(n): 10 | cmap = np.zeros([n, 3]).astype(np.uint8) 11 | 12 | for i in np.arange(n): 13 | r, g, b = np.zeros(3) 14 | 15 | for j in np.arange(8): 16 | r = r + (1 << (7 - j)) * ((i & (1 << (3 * j))) >> (3 * j)) 17 | g = g + (1 << (7 - j)) * ((i & (1 << (3 * j + 1))) >> (3 * j + 1)) 18 | b = b + (1 << (7 - j)) * ((i & (1 << (3 * j + 2))) >> (3 * j + 2)) 19 | 20 | cmap[i, :] = np.array([r, g, b]) 21 | 22 | return cmap 23 | 24 | 25 | class Relabel: 26 | def __init__(self, olabel, nlabel): 27 | ''' 28 | Converts a particular label value in the tensor to another 29 | 30 | Parameters 31 | ---------- 32 | olabel : old label which needs to be changed 33 | nlabel : new label which will replace the old label 34 | ''' 35 | self.olabel = olabel 36 | self.nlabel = nlabel 37 | 38 | def __call__(self, tensor): 39 | assert isinstance(tensor, torch.LongTensor), 'tensor needs to be LongTensor' 40 | tensor[tensor == self.olabel] = self.nlabel 41 | return tensor 42 | 43 | 44 | class ToLabel: 45 | 46 | def __call__(self, image): 47 | ''' 48 | Used to change the image from dim(N x H x W) to dim(N x 1 x H x W) 49 | ''' 50 | return torch.from_numpy(np.array(image)).long().unsqueeze(0) 51 | 52 | 53 | class Colorize: 54 | 55 | def __init__(self, n=22): 56 | self.cmap = colormap(256) 57 | self.cmap[n] = self.cmap[-1] 58 | self.cmap = torch.from_numpy(self.cmap[:n]) 59 | 60 | def __call__(self, gray_image): 61 | size = gray_image.size() 62 | 63 | try: 64 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 65 | except: 66 | color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0) 67 | 68 | for label in range(1, len(self.cmap)): 69 | mask = gray_image == label 70 | 71 | color_image[0][mask] = self.cmap[label][0] 72 | color_image[1][mask] = self.cmap[label][1] 73 | color_image[2][mask] = self.cmap[label][2] 74 | 75 | return color_image 76 | 77 | 78 | def PILaugment(img, mask): 79 | if random.random() > 0.2: 80 | (w, h) = img.size 81 | (w_, h_) = mask.size 82 | assert (w == w_ and h == h_), 'The size should be the same.' 83 | crop = random.uniform(0.45, 0.75) 84 | W = int(crop * w) 85 | H = int(crop * h) 86 | start_x = w - W 87 | start_y = h - H 88 | x_pos = int(random.uniform(0, start_x)) 89 | y_pos = int(random.uniform(0, start_y)) 90 | img = img.crop((x_pos, y_pos, x_pos + W, y_pos + H)) 91 | mask = mask.crop((x_pos, y_pos, x_pos + W, y_pos + H)) 92 | 93 | if random.random() > 0.2: 94 | img = ImageOps.flip(img) 95 | mask = ImageOps.flip(mask) 96 | 97 | if random.random() > 0.2: 98 | img = ImageOps.mirror(img) 99 | mask = ImageOps.mirror(mask) 100 | 101 | if random.random() > 0.2: 102 | angle = random.random() * 90 - 45 103 | img = img.rotate(angle) 104 | mask = mask.rotate(angle) 105 | if random.random() > 0.95: 106 | img = img.filter(ImageFilter.GaussianBlur(2)) 107 | 108 | if random.random() > 0.95: 109 | img = ImageEnhance.Contrast(img).enhance(1) 110 | 111 | if random.random() > 0.95: 112 | img = ImageEnhance.Brightness(img).enhance(1) 113 | 114 | return img, mask 115 | 116 | 117 | def get_transformation(size, resize=False, dataset = 'voc2012'): 118 | ''' 119 | Used to return a transformation based on the size given, that is, returns an apt sized tensor 120 | If resize = True then Resizing else CenterCrop 121 | ''' 122 | 123 | assert dataset in ['voc2012', 'cityscapes', 'acdc'], 'The dataset name must be set correctly in the get_transformation function' 124 | if dataset == 'voc2012' or dataset == 'cityscapes': 125 | if resize: 126 | transfom_lst = [ 127 | Resize(size), 128 | CenterCrop(size), 129 | ToTensor(), 130 | Normalize([.5, .5, .5], [.5, .5, .5]) 131 | ] 132 | else: 133 | transfom_lst = [ 134 | CenterCrop(size), 135 | ToTensor(), 136 | Normalize([.5, .5, .5], [.5, .5, .5]) 137 | ] 138 | elif dataset == 'acdc': 139 | if resize: 140 | transfom_lst = [ 141 | Resize(size), 142 | CenterCrop(size), 143 | ToTensor(), 144 | Normalize([.5], [.5]) 145 | ] 146 | else: 147 | transfom_lst = [ 148 | CenterCrop(size), 149 | ToTensor(), 150 | Normalize([.5], [.5]) 151 | ] 152 | 153 | input_transform = Compose(transfom_lst) 154 | 155 | ### This is because we don't need the Relabel function for the case of cityscapes dataset 156 | if dataset == 'voc2012': 157 | if resize: 158 | target_transform = Compose([ 159 | Resize(size, interpolation=0), 160 | CenterCrop(size), 161 | ToLabel(), 162 | Relabel(255, 0) ## So as to replace the 255(boundaries) label as 0 163 | ]) 164 | 165 | else: 166 | target_transform = Compose([ 167 | CenterCrop(size), 168 | ToLabel(), 169 | Relabel(255, 0) ## So as to replace the 255(boundaries) label as 0 170 | ]) 171 | 172 | elif dataset == 'cityscapes' or dataset == 'acdc': 173 | if resize: 174 | target_transform = Compose([ 175 | Resize(size, interpolation=0), 176 | CenterCrop(size), 177 | ToLabel() 178 | ]) 179 | 180 | else: 181 | target_transform = Compose([ 182 | CenterCrop(size), 183 | ToLabel(), 184 | ]) 185 | 186 | transform = {'img': input_transform, 'gt': target_transform} 187 | return transform 188 | -------------------------------------------------------------------------------- /validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import torchvision 7 | import torchvision.datasets as dsets 8 | from torch.utils.data import DataLoader 9 | import torchvision.transforms as transforms 10 | import utils 11 | from utils import make_one_hot 12 | from PIL import Image 13 | from arch import define_Gen 14 | from data_utils import VOCDataset, CityscapesDataset, ACDCDataset, get_transformation 15 | 16 | root = './data/VOC2012' 17 | root_cityscapes = './data/Cityscape' 18 | root_acdc = './data/ACDC' 19 | 20 | def validation(args): 21 | 22 | ### For selecting the number of channels 23 | if args.dataset == 'voc2012': 24 | n_channels = 21 25 | elif args.dataset == 'cityscapes': 26 | n_channels = 20 27 | elif args.dataset == 'acdc': 28 | n_channels = 4 29 | 30 | transform = get_transformation((args.crop_height, args.crop_width), resize=True, dataset=args.dataset) 31 | 32 | ## let the choice of dataset configurable 33 | if args.dataset == 'voc2012': 34 | val_set = VOCDataset(root_path=root, name='val', ratio=0.5, transformation=transform, augmentation=None) 35 | elif args.dataset == 'cityscapes': 36 | val_set = CityscapesDataset(root_path=root_cityscapes, name='val', ratio=0.5, transformation=transform, augmentation=None) 37 | elif args.dataset == 'acdc': 38 | val_set = ACDCDataset(root_path=root_acdc, name='val', ratio=0.5, transformation=transform, augmentation=None) 39 | 40 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False) 41 | 42 | Gsi = define_Gen(input_nc=3, output_nc=n_channels, ngf=args.ngf, netG='deeplab', 43 | norm=args.norm, use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids) 44 | 45 | Gis = define_Gen(input_nc=n_channels, output_nc=3, ngf=args.ngf, netG='deeplab', 46 | norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids) 47 | 48 | ### best_iou 49 | best_iou = 0 50 | 51 | ### Interpolation 52 | interp = nn.Upsample(size = (args.crop_height, args.crop_width), mode='bilinear', align_corners=True) 53 | 54 | ### Softmax activation 55 | activation_softmax = nn.Softmax2d() 56 | activation_tanh = nn.Tanh() 57 | 58 | if(args.model == 'supervised_model'): 59 | 60 | ### loading the checkpoint 61 | try: 62 | ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir)) 63 | Gsi.load_state_dict(ckpt['Gsi']) 64 | best_iou = ckpt['best_iou'] 65 | 66 | except: 67 | print(' [*] No checkpoint!') 68 | 69 | ### run 70 | Gsi.eval() 71 | for i, (image_test, real_segmentation, image_name) in enumerate(val_loader): 72 | image_test = utils.cuda(image_test, args.gpu_ids) 73 | seg_map = Gsi(image_test) 74 | seg_map = interp(seg_map) 75 | seg_map = activation_softmax(seg_map) 76 | 77 | prediction = seg_map.data.max(1)[1].squeeze_(1).cpu().numpy() ### To convert from 22 --> 1 channel 78 | for j in range(prediction.shape[0]): 79 | new_img = prediction[j] ### Taking a particular image from the batch 80 | new_img = utils.colorize_mask(new_img, args.dataset) ### So as to convert it back to a paletted image 81 | 82 | ### Now the new_img is PIL.Image 83 | new_img.save(os.path.join(args.validation_dir+'/supervised/'+image_name[j]+'.png')) 84 | 85 | 86 | print('Epoch-', str(i+1), ' Done!') 87 | 88 | print('The iou of the resulting segment maps: ', str(best_iou)) 89 | 90 | 91 | elif(args.model == 'semisupervised_cycleGAN'): 92 | 93 | ### loading the checkpoint 94 | try: 95 | ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir)) 96 | Gsi.load_state_dict(ckpt['Gsi']) 97 | Gis.load_state_dict(ckpt['Gis']) 98 | best_iou = ckpt['best_iou'] 99 | 100 | except: 101 | print(' [*] No checkpoint!') 102 | 103 | ### run 104 | Gsi.eval() 105 | for i, (image_test, real_segmentation, image_name) in enumerate(val_loader): 106 | image_test, real_segmentation = utils.cuda([image_test, real_segmentation], args.gpu_ids) 107 | seg_map = Gsi(image_test) 108 | seg_map = interp(seg_map) 109 | seg_map = activation_softmax(seg_map) 110 | fake_img = Gis(seg_map).detach() 111 | fake_img = interp(fake_img) 112 | fake_img = activation_tanh(fake_img) 113 | 114 | fake_img_from_labels = Gis(make_one_hot(real_segmentation, args.dataset, args.gpu_ids).float()).detach() 115 | fake_img_from_labels = interp(fake_img_from_labels) 116 | fake_img_from_labels = activation_tanh(fake_img_from_labels) 117 | fake_label_regenerated = Gsi(fake_img_from_labels).detach() 118 | fake_label_regenerated = interp(fake_label_regenerated) 119 | fake_label_regenerated = activation_softmax(fake_label_regenerated) 120 | 121 | prediction = seg_map.data.max(1)[1].squeeze_(1).cpu().numpy() ### To convert from 22 --> 1 channel 122 | fake_regenerated_label = fake_label_regenerated.data.max(1)[1].squeeze_(1).cpu().numpy() 123 | 124 | fake_img = fake_img.cpu() 125 | fake_img_from_labels = fake_img_from_labels.cpu() 126 | 127 | ### Now i am going to revert back the transformation on these images 128 | if args.dataset == 'voc2012' or args.dataset == 'cityscapes': 129 | trans_mean = [0.5, 0.5, 0.5] 130 | trans_std = [0.5, 0.5, 0.5] 131 | for k in range(3): 132 | fake_img[:, k, :, :] = ((fake_img[:, k, :, :] * trans_std[k]) + trans_mean[k]) 133 | fake_img_from_labels[:, k, :, :] = ((fake_img_from_labels[:, k, :, :] * trans_std[k]) + trans_mean[k]) 134 | 135 | elif args.dataset == 'acdc': 136 | trans_mean = [0.5] 137 | trans_std = [0.5] 138 | for k in range(1): 139 | fake_img[:, k, :, :] = ((fake_img[:, k, :, :] * trans_std[k]) + trans_mean[k]) 140 | fake_img_from_labels[:, k, :, :] = ((fake_img_from_labels[:, k, :, :] * trans_std[k]) + trans_mean[k]) 141 | 142 | for j in range(prediction.shape[0]): 143 | new_img = prediction[j] ### Taking a particular image from the batch 144 | new_img = utils.colorize_mask(new_img, args.dataset) ### So as to convert it back to a paletted image 145 | 146 | regen_label = fake_regenerated_label[j] 147 | regen_label = utils.colorize_mask(regen_label, args.dataset) 148 | 149 | ### Now the new_img is PIL.Image 150 | new_img.save(os.path.join(args.validation_dir+'/unsupervised/generated_labels/'+image_name[j]+'.png')) 151 | regen_label.save(os.path.join(args.validation_dir+'/unsupervised/regenerated_labels/'+image_name[j]+'.png')) 152 | torchvision.utils.save_image(fake_img[j], os.path.join(args.validation_dir+'/unsupervised/regenerated_image/'+image_name[j]+'.jpg')) 153 | torchvision.utils.save_image(fake_img_from_labels[j], os.path.join(args.validation_dir+'/unsupervised/image_from_labels/'+image_name[j]+'.jpg')) 154 | 155 | print('Epoch-', str(i+1), ' Done!') 156 | 157 | print('The iou of the resulting segment maps: ', str(best_iou)) 158 | 159 | -------------------------------------------------------------------------------- /data_utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import torch 5 | import scipy.misc as m 6 | import scipy.io as sio 7 | 8 | from torch.utils.data import Dataset 9 | from utils import recursive_glob 10 | from PIL import Image 11 | from .augmentations import * 12 | import re 13 | 14 | 15 | class VOCDataset(Dataset): 16 | ''' 17 | We assume that there will be txt to note all the image names 18 | 19 | 20 | color map: 21 | 0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle # 6=bus, 7=car, 8=cat, 9=chair, 10=cow, 11=diningtable, 22 | 12=dog, 13=horse, 14=motorbike, 15=person # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor, 23 | 21=boundaries(self-defined) 24 | 25 | Also it will return an image and ground truth as it is present in the image form, so ground truth won't 26 | be in one-hot form and rather would be a 2D tensor. To convert the labels in one-hot form in the training 27 | code we will be calling the function 'make_one_hot' function of utils.py 28 | 29 | ''' 30 | 31 | def __init__(self, root_path, name='label', ratio=0.5, transformation=None, augmentation=None): 32 | super(VOCDataset, self).__init__() 33 | self.root_path = root_path 34 | self.ratio = ratio 35 | self.name = name 36 | assert transformation is not None, 'transformation must be provided, give None' 37 | self.transformation = transformation 38 | self.augmentation = augmentation 39 | assert name in ('label', 'unlabel','val', 'test'), 'dataset name should be restricted in "label", "unlabel", "test" and "val", given %s' % name 40 | assert 0 <= ratio <= 1, 'the ratio between "labeled" and "unlabeled" should be between 0 and 1, given %.1f' % ratio 41 | np.random.seed(1) ### Because of this we are not getting repeated images for labelled and unlabelled data 42 | 43 | if self.name != 'test': 44 | train_imgs = pd.read_table(os.path.join(self.root_path, 'ImageSets/Segmentation', 'trainvalAug.txt')).values.reshape(-1) 45 | 46 | val_imgs = pd.read_table(os.path.join(self.root_path, 'ImageSets/Segmentation', 'val.txt')).values.reshape(-1) 47 | 48 | labeled_imgs = np.random.choice(train_imgs, size=int(self.ratio * train_imgs.__len__()), replace=False) 49 | labeled_imgs = list(labeled_imgs) 50 | 51 | unlabeled_imgs = [x for x in train_imgs if x not in labeled_imgs] 52 | 53 | ### Now here we equalize the lengths of labelled and unlabelled imgs by just repeating up some images 54 | if self.ratio > 0.5: 55 | new_ratio = round((self.ratio/(1-self.ratio + 1e-6)), 1) 56 | excess_ratio = new_ratio - 1 57 | new_list_1 = unlabeled_imgs * int(excess_ratio) 58 | new_list_2 = list(np.random.choice(np.array(unlabeled_imgs), size=int((excess_ratio - int(excess_ratio))*unlabeled_imgs.__len__()), replace=False)) 59 | unlabeled_imgs += (new_list_1 + new_list_2) 60 | elif self.ratio < 0.5: 61 | new_ratio = round(((1-self.ratio)/(self.ratio + 1e-6)), 1) 62 | excess_ratio = new_ratio - 1 63 | new_list_1 = labeled_imgs * int(excess_ratio) 64 | new_list_2 = list(np.random.choice(np.array(labeled_imgs), size=int((excess_ratio - int(excess_ratio))*labeled_imgs.__len__()), replace=False)) 65 | labeled_imgs += (new_list_1 + new_list_2) 66 | 67 | if self.name == 'test': 68 | test_imgs = pd.read_table(os.path.join(self.root_path, 'ImageSets/Segmentation', 'test.txt')).values.reshape(-1) 69 | # test_imgs = np.array(test_imgs) 70 | 71 | if self.name == 'label': 72 | self.imgs = labeled_imgs 73 | elif self.name == 'unlabel': 74 | self.imgs = unlabeled_imgs 75 | elif self.name == 'val': 76 | self.imgs = val_imgs 77 | elif self.name == 'test': 78 | self.imgs = test_imgs 79 | else: 80 | raise ('{} not defined'.format(self.name)) 81 | 82 | self.gts = self.imgs 83 | 84 | def __getitem__(self, index): 85 | 86 | if self.name == 'test': 87 | img_path = os.path.join(self.root_path, 'JPEGImages', self.imgs[index] + '.jpg') 88 | 89 | img = Image.open(img_path).convert('RGB') 90 | 91 | if self.augmentation is not None: 92 | img = self.augmentation(img) 93 | 94 | if self.transformation: 95 | img = self.transformation['img'](img) 96 | 97 | return img, self.imgs[index] 98 | 99 | else: 100 | img_path = os.path.join(self.root_path, 'JPEGImages', self.imgs[index] + '.jpg') 101 | gt_path = os.path.join(self.root_path, 'SegmentationClassAug', self.gts[index] + '.png') 102 | 103 | img = Image.open(img_path).convert('RGB') 104 | gt = Image.open(gt_path) #.convert('P') 105 | 106 | if self.augmentation is not None: 107 | img, gt = self.augmentation(img, gt) 108 | 109 | if self.transformation: 110 | img = self.transformation['img'](img) 111 | gt = self.transformation['gt'](gt) 112 | 113 | return img, gt, self.imgs[index] 114 | 115 | def __len__(self): 116 | return len(self.imgs) 117 | 118 | 119 | class CityscapesDataset(Dataset): 120 | 121 | def __init__( 122 | self, 123 | root_path, 124 | name="train", 125 | ratio=0.5, 126 | transformation=False, 127 | augmentation=None 128 | ): 129 | self.root = root_path 130 | self.name = name 131 | assert transformation is not None, 'transformation must be provided, give None' 132 | self.transformation = transformation 133 | self.augmentation = augmentation 134 | self.n_classes = 20 135 | self.ratio = ratio 136 | self.files = {} 137 | 138 | assert name in ('label', 'unlabel','val', 'test'), 'dataset name should be restricted in "label", "unlabel", "test" and "val", given %s' % name 139 | 140 | if self.name != 'test': 141 | self.images_base = os.path.join(self.root, "leftImg8bit") 142 | self.annotations_base = os.path.join( 143 | self.root, "gtFine", 'trainval' 144 | ) 145 | else: 146 | self.images_base = os.path.join(self.root, "leftImg8bit", 'test') 147 | # self.annotations_base = os.path.join( 148 | # self.root, "gtFine", 'test' 149 | # ) 150 | 151 | np.random.seed(1) 152 | 153 | if self.name != 'test': 154 | train_imgs = recursive_glob(rootdir=os.path.join(self.images_base, 'train'), suffix=".png") 155 | train_imgs = np.array(train_imgs) 156 | 157 | val_imgs = recursive_glob(rootdir=os.path.join(self.images_base, 'val'), suffix=".png") 158 | 159 | labeled_imgs = np.random.choice(train_imgs, size=int(self.ratio * train_imgs.__len__()), replace=False) 160 | labeled_imgs = list(labeled_imgs) 161 | 162 | unlabeled_imgs = [x for x in train_imgs if x not in labeled_imgs] 163 | 164 | ### Now here we equalize the lengths of labelled and unlabelled imgs by just repeating up some images 165 | if self.ratio > 0.5: 166 | new_ratio = round((self.ratio/(1-self.ratio + 1e-6)), 1) 167 | excess_ratio = new_ratio - 1 168 | new_list_1 = unlabeled_imgs * int(excess_ratio) 169 | new_list_2 = list(np.random.choice(np.array(unlabeled_imgs), size=int((excess_ratio - int(excess_ratio))*unlabeled_imgs.__len__()), replace=False)) 170 | unlabeled_imgs += (new_list_1 + new_list_2) 171 | elif self.ratio < 0.5: 172 | new_ratio = round(((1-self.ratio)/(self.ratio + 1e-6)), 1) 173 | excess_ratio = new_ratio - 1 174 | new_list_1 = labeled_imgs * int(excess_ratio) 175 | new_list_2 = list(np.random.choice(np.array(labeled_imgs), size=int((excess_ratio - int(excess_ratio))*labeled_imgs.__len__()), replace=False)) 176 | labeled_imgs += (new_list_1 + new_list_2) 177 | 178 | else: 179 | test_imgs = recursive_glob(rootdir=self.images_base, suffix=".png") 180 | 181 | if(self.name == 'label'): 182 | self.files[name] = list(labeled_imgs) 183 | elif(self.name == 'unlabel'): 184 | self.files[name] = list(unlabeled_imgs) 185 | elif(self.name == 'val'): 186 | self.files[name] = list(val_imgs) 187 | elif(self.name == 'test'): 188 | self.files[name] = list(test_imgs) 189 | 190 | ''' 191 | This pattern for the various classes has been borrowed from the official repo for Cityscapes dataset 192 | You can see it here: https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 193 | ''' 194 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 195 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] 196 | self.class_names = [ 197 | "road", 198 | "sidewalk", 199 | "building", 200 | "wall", 201 | "fence", 202 | "pole", 203 | "traffic_light", 204 | "traffic_sign", 205 | "vegetation", 206 | "terrain", 207 | "sky", 208 | "person", 209 | "rider", 210 | "car", 211 | "truck", 212 | "bus", 213 | "train", 214 | "motorcycle", 215 | "bicycle", 216 | "unlabelled" 217 | ] 218 | 219 | self.ignore_index = 250 220 | self.class_map = dict(zip(self.valid_classes, range(19))) 221 | 222 | if not self.files[name]: 223 | raise Exception( 224 | "No files for name=[%s] found in %s" % (name, self.images_base) 225 | ) 226 | 227 | print("Found %d %s images" % (len(self.files[name]), name)) 228 | 229 | def __len__(self): 230 | """__len__""" 231 | return len(self.files[self.name]) 232 | 233 | def __getitem__(self, index): 234 | """__getitem__ 235 | :param index: 236 | """ 237 | if self.name == 'test': 238 | img_path = self.files[self.name][index].rstrip() 239 | 240 | img = Image.open(img_path).convert('RGB') 241 | 242 | if self.augmentation is not None: 243 | img = self.augmentation(img) 244 | 245 | if self.transformation: 246 | img = self.transformation['img'](img) 247 | 248 | return img, re.sub(r'.*/', '', img_path[34:]).rstrip('.png') ### These numbers have been hard coded so as to get a suitable name for the model 249 | 250 | else: 251 | img_path = self.files[self.name][index].rstrip() 252 | lbl_path = os.path.join( 253 | self.annotations_base, 254 | img_path.split(os.sep)[-2], 255 | os.path.basename(img_path)[:-15] + "gtFine_labelIds.png", 256 | ) 257 | 258 | img = Image.open(img_path).convert('RGB') 259 | lbl = Image.open(lbl_path) 260 | 261 | if self.augmentation is not None: 262 | img, lbl = self.augmentation(img, lbl) 263 | 264 | if self.transformation: 265 | img = self.transformation['img'](img) 266 | lbl = self.transformation['gt'](lbl) 267 | 268 | lbl = self.encode_segmap(lbl) 269 | 270 | return img, lbl, re.sub(r'.*/', '', img_path[38:]).rstrip('.png') ### These numbers have been hard coded so as to get a suitable name for the model 271 | 272 | def encode_segmap(self, mask): 273 | # Put all void classes to ignore index 274 | for _voidc in self.void_classes: 275 | mask[mask == _voidc] = self.ignore_index 276 | for _validc in self.valid_classes: 277 | mask[mask == _validc] = self.class_map[_validc] 278 | mask[mask == self.ignore_index] = 19 ### Just a mapping between the two color values 279 | return mask 280 | 281 | 282 | class ACDCDataset(Dataset): 283 | ''' 284 | The dataloader for ACDC dataset 285 | ''' 286 | 287 | split_ratio = [0.85, 0.15] 288 | ''' 289 | this split ratio is for the train (including the labeled and unlabeled) and the val dataset 290 | ''' 291 | 292 | def __init__(self, root_path, name='label', ratio=0.5, transformation=None, augmentation=None): 293 | super(ACDCDataset, self).__init__() 294 | self.root = root_path 295 | self.name = name 296 | assert transformation is not None, 'transformation must be provided, give None' 297 | self.transformation = transformation 298 | self.augmentation = augmentation 299 | self.ratio = ratio 300 | self.files = {} 301 | 302 | if self.name != 'test': 303 | self.images_base = os.path.join(self.root, 'training') 304 | self.annotations_base = os.path.join(self.root, 'training_gt') 305 | else: 306 | self.images_base = os.path.join(self.root, 'testing') 307 | # self.annotations_base = os.path.join( 308 | # self.root, "gtFine", 'test' 309 | # ) 310 | 311 | np.random.seed(1) 312 | 313 | if self.name != 'test': 314 | total_imgs = os.listdir(self.images_base) 315 | total_imgs = np.array(total_imgs) 316 | 317 | train_imgs = np.random.choice(total_imgs, size=int(self.__class__.split_ratio[0] * total_imgs.__len__()),replace=False) 318 | val_imgs = [x for x in total_imgs if x not in train_imgs] 319 | 320 | labeled_imgs = np.random.choice(train_imgs, size=int(self.ratio * train_imgs.__len__()), replace=False) 321 | labeled_imgs = list(labeled_imgs) 322 | 323 | unlabeled_imgs = [x for x in train_imgs if x not in labeled_imgs] 324 | 325 | ### Now here we equalize the lengths of labelled and unlabelled imgs by just repeating up some images 326 | if self.ratio > 0.5: 327 | new_ratio = round((self.ratio/(1-self.ratio + 1e-6)), 1) 328 | excess_ratio = new_ratio - 1 329 | new_list_1 = unlabeled_imgs * int(excess_ratio) 330 | new_list_2 = list(np.random.choice(np.array(unlabeled_imgs), size=int((excess_ratio - int(excess_ratio))*unlabeled_imgs.__len__()), replace=False)) 331 | unlabeled_imgs += (new_list_1 + new_list_2) 332 | elif self.ratio < 0.5: 333 | new_ratio = round(((1-self.ratio)/(self.ratio + 1e-6)), 1) 334 | excess_ratio = new_ratio - 1 335 | new_list_1 = labeled_imgs * int(excess_ratio) 336 | new_list_2 = list(np.random.choice(np.array(labeled_imgs), size=int((excess_ratio - int(excess_ratio))*labeled_imgs.__len__()), replace=False)) 337 | labeled_imgs += (new_list_1 + new_list_2) 338 | 339 | else: 340 | test_imgs = os.listdir(self.images_base) 341 | 342 | if(self.name == 'label'): 343 | self.files[name] = list(labeled_imgs) 344 | elif(self.name == 'unlabel'): 345 | self.files[name] = list(unlabeled_imgs) 346 | elif(self.name == 'val'): 347 | self.files[name] = list(val_imgs) 348 | elif(self.name == 'test'): 349 | self.files[name] = list(test_imgs) 350 | 351 | if not self.files[name]: 352 | raise Exception( 353 | "No files for name=[%s] found in %s" % (name, self.images_base) 354 | ) 355 | 356 | print("Found %d %s images" % (len(self.files[name]), name)) 357 | 358 | def __len__(self): 359 | """__len__""" 360 | return len(self.files[self.name]) 361 | 362 | def __getitem__(self, index): 363 | """__getitem__ 364 | :param index: 365 | """ 366 | if self.name == 'test': 367 | img_path = os.path.join(self.images_base, self.files[self.name][index]) 368 | 369 | img = Image.open(img_path) 370 | 371 | if self.augmentation is not None: 372 | img = self.augmentation(img) 373 | 374 | if self.transformation: 375 | img = self.transformation['img'](img) 376 | 377 | return img, self.files[self.name][index].rstrip('.jpg') 378 | 379 | else: 380 | img_path = os.path.join(self.images_base, self.files[self.name][index]) 381 | lbl_path = os.path.join(self.annotations_base, self.files[self.name][index].rstrip('.jpg') + '.png') 382 | 383 | img = Image.open(img_path) 384 | lbl = Image.open(lbl_path) 385 | 386 | if self.augmentation is not None: 387 | img, lbl = self.augmentation(img, lbl) 388 | 389 | if self.transformation: 390 | img = self.transformation['img'](img) 391 | lbl = self.transformation['gt'](lbl) 392 | 393 | return img, lbl, self.files[self.name][index].rstrip('.jpg') 394 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import shutil 4 | 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch 8 | from torch.autograd import Variable 9 | from PIL import Image 10 | from torchvision import models 11 | from collections import namedtuple 12 | from torchvision import transforms 13 | 14 | 15 | ''' 16 | The palette is used to convert the segmentation map having values from 0-21 back to paletted image 17 | ''' 18 | palette = [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 128, 19 | 128, 128, 128, 64, 0, 0, 192, 0, 0, 64, 128, 0, 192, 128, 0, 64, 0, 128, 192, 0, 128, 20 | 64, 128, 128, 192, 128, 128, 0, 64, 0, 128, 64, 0, 0, 192, 0, 128, 192, 0, 0, 64, 128] 21 | 22 | cityscape_palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 23 | 250, 170, 30, 220, 220, 0, 107, 142, 35, 152, 251, 152, 0, 130, 180, 220, 20, 60, 24 | 255, 0, 0, 0, 0, 142, 0, 0, 70, 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 25 | 26 | acdc_palette = [0, 0, 0, 128, 64, 128, 70, 70, 70, 250, 170, 30] 27 | 28 | zero_pad = 256 * 3 - len(palette) 29 | for i in range(zero_pad): 30 | palette.append(0) 31 | 32 | zero_pad = 256 * 3 - len(cityscape_palette) 33 | for i in range(zero_pad): 34 | cityscape_palette.append(0) 35 | 36 | zero_pad = 256 * 3 - len(acdc_palette) 37 | for i in range(zero_pad): 38 | acdc_palette.append(0) 39 | 40 | 41 | def colorize_mask(mask, dataset): 42 | ''' 43 | Used to convert the segmentation of one channel(mask) back to a paletted image 44 | ''' 45 | # mask: numpy array of the mask 46 | assert dataset in ('voc2012', 'cityscapes', 'acdc') 47 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 48 | if (dataset == 'voc2012'): 49 | new_mask.putpalette(palette) 50 | elif (dataset == 'cityscapes'): 51 | new_mask.putpalette(cityscape_palette) 52 | elif (dataset == 'acdc'): 53 | new_mask.putpalette(acdc_palette) 54 | 55 | return new_mask 56 | 57 | ### To convert a paletted image to a tensor image of 3 dimension 58 | ### This is because a simple paletted image cannot be viewed with all the details 59 | def PIL_to_tensor(img, dataset): 60 | ''' 61 | Here img is of the type PIL.Image 62 | ''' 63 | assert dataset in ('voc2012', 'cityscapes', 'acdc') 64 | img_arr = np.array(img, dtype='float32') 65 | new_arr = np.zeros([3, img_arr.shape[0], img_arr.shape[1]], dtype='float32') 66 | 67 | if (dataset == 'voc2012'): 68 | for i in range(img_arr.shape[0]): 69 | for j in range(img_arr.shape[1]): 70 | # new_arr[i, :, :] = img_arr 71 | index = int(img_arr[i, j]*3) 72 | new_arr[0, i, j] = palette[index] 73 | new_arr[1, i, j] = palette[index+1] 74 | new_arr[2, i, j] = palette[index+2] 75 | elif (dataset == 'cityscapes'): 76 | for i in range(img_arr.shape[0]): 77 | for j in range(img_arr.shape[1]): 78 | # new_arr[i, :, :] = img_arr 79 | index = int(img_arr[i, j]*3) 80 | new_arr[0, i, j] = cityscape_palette[index] 81 | new_arr[1, i, j] = cityscape_palette[index+1] 82 | new_arr[2, i, j] = cityscape_palette[index+2] 83 | elif (dataset == 'acdc'): 84 | for i in range(img_arr.shape[0]): 85 | for j in range(img_arr.shape[1]): 86 | # new_arr[i, :, :] = img_arr 87 | index = int(img_arr[i, j]*3) 88 | new_arr[0, i, j] = acdc_palette[index] 89 | new_arr[1, i, j] = acdc_palette[index+1] 90 | new_arr[2, i, j] = acdc_palette[index+2] 91 | 92 | return_tensor = torch.tensor(new_arr) 93 | 94 | return return_tensor 95 | 96 | def smoothen_label(label, alpha, gpu_id): 97 | ''' 98 | For smoothening of the classification labels 99 | 100 | labels : tensor having dimensrions: batch_size*21*H*W filled with zeroes and ones 101 | ''' 102 | torch.manual_seed(0) 103 | try: 104 | smoothen_array = -1*alpha + torch.rand([label.shape[0], label.shape[1], label.shape[2], label.shape[3]]) * (2*alpha) 105 | smoothen_array = cuda(smoothen_array, gpu_id) 106 | label = label + smoothen_array 107 | except: 108 | smoothen_array = -1*alpha + torch.rand([label.shape[0], label.shape[1], label.shape[2], label.shape[3]]) * (2*alpha) 109 | label = label + smoothen_array 110 | 111 | return label 112 | 113 | ''' 114 | To be used to apply gaussian noise in the input to the discriminator 115 | ''' 116 | class GaussianNoise(nn.Module): 117 | """Gaussian noise regularizer. 118 | 119 | Args: 120 | sigma (float, optional): relative standard deviation used to generate the 121 | noise. Relative means that it will be multiplied by the magnitude of 122 | the value your are adding the noise to. This means that sigma can be 123 | the same regardless of the scale of the vector. 124 | is_relative_detach (bool, optional): whether to detach the variable before 125 | computing the scale of the noise. If `False` then the scale of the noise 126 | won't be seen as a constant but something to optimize: this will bias the 127 | network to generate vectors with smaller values. 128 | """ 129 | 130 | def __init__(self, sigma=0.1, is_relative_detach=True): 131 | super().__init__() 132 | self.sigma = sigma 133 | self.is_relative_detach = is_relative_detach 134 | 135 | def forward(self, x): 136 | if self.training and self.sigma != 0: 137 | scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x 138 | sampled_noise = torch.zeros(x.size()).normal_() * scale 139 | x = x + sampled_noise 140 | return x 141 | 142 | ''' 143 | This will be used for calculation of perceptual losses 144 | ''' 145 | class Vgg16(torch.nn.Module): 146 | def __init__(self, requires_grad=False): 147 | super(Vgg16, self).__init__() 148 | vgg_pretrained_features = models.vgg16(pretrained=True).features 149 | self.slice1 = torch.nn.Sequential() 150 | self.slice2 = torch.nn.Sequential() 151 | self.slice3 = torch.nn.Sequential() 152 | self.slice4 = torch.nn.Sequential() 153 | for x in range(4): 154 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 155 | for x in range(4, 9): 156 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 157 | for x in range(9, 16): 158 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 159 | for x in range(16, 23): 160 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 161 | if not requires_grad: 162 | for param in self.parameters(): 163 | param.requires_grad = False 164 | 165 | def forward(self, X): 166 | h = self.slice1(X) 167 | h_relu1_2 = h 168 | h = self.slice2(h) 169 | h_relu2_2 = h 170 | h = self.slice3(h) 171 | h_relu3_3 = h 172 | h = self.slice4(h) 173 | h_relu4_3 = h 174 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']) 175 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3) 176 | return out 177 | 178 | ''' 179 | The definition of the perceptual loss 180 | ''' 181 | def perceptual_loss(x, y, gpu_ids): 182 | """ 183 | Calculates the perceptual loss on the basis of the VGG network 184 | Parameters: 185 | x, y: the images between which perceptual loss is to be calculated 186 | """ 187 | 188 | ### Considering the fact in this case x,y both are images in the range -1 to 1 and we need normal distribution 189 | ### before passing through VGG 190 | 191 | u = x*0.5 + 0.5 192 | v = y*0.5 + 0.5 193 | trans_mean = [0.485, 0.456, 0.406] 194 | trans_std = [0.229, 0.224, 0.225] 195 | 196 | for i in range(3): 197 | u[:, i, :, :] = u[:, i, :, :]*trans_std[i] + trans_mean[i] 198 | v[:, i, :, :] = v[:, i, :, :]*trans_std[i] + trans_mean[i] 199 | 200 | ### Now this is normal distribution 201 | 202 | vgg = Vgg16(requires_grad=False).cuda(gpu_ids[0]) 203 | 204 | features_y = vgg(v) 205 | features_x = vgg(u) 206 | 207 | mse_loss = nn.MSELoss() 208 | loss = mse_loss(features_y.relu2_2, features_x.relu2_2) 209 | 210 | return loss 211 | 212 | 213 | # To make directories 214 | def mkdir(paths): 215 | for path in paths: 216 | if not os.path.isdir(path): 217 | os.makedirs(path) 218 | 219 | 220 | # To make cuda tensor 221 | def cuda(xs, gpu_id): 222 | if torch.cuda.is_available(): 223 | if not isinstance(xs, (list, tuple)): 224 | return xs.cuda(int(gpu_id[0])) 225 | else: 226 | return [x.cuda(int(gpu_id[0])) for x in xs] 227 | return xs 228 | 229 | 230 | # For Pytorch datasets loader 231 | def create_link(dataset_dir): 232 | dirs = {} 233 | dirs['trainA'] = os.path.join(dataset_dir, 'ltrainA') 234 | dirs['trainB'] = os.path.join(dataset_dir, 'ltrainB') 235 | dirs['testA'] = os.path.join(dataset_dir, 'ltestA') 236 | dirs['testB'] = os.path.join(dataset_dir, 'ltestB') 237 | mkdir(dirs.values()) 238 | 239 | for key in dirs: 240 | try: 241 | os.remove(os.path.join(dirs[key], 'Link')) 242 | except: 243 | pass 244 | os.symlink(os.path.abspath(os.path.join(dataset_dir, key)), 245 | os.path.join(dirs[key], 'Link')) 246 | 247 | return dirs 248 | 249 | 250 | def get_traindata_link(dataset_dir): 251 | dirs = {} 252 | dirs['trainA'] = os.path.join(dataset_dir, 'ltrainA') 253 | dirs['trainB'] = os.path.join(dataset_dir, 'ltrainB') 254 | return dirs 255 | 256 | 257 | def get_testdata_link(dataset_dir): 258 | dirs = {} 259 | dirs['testA'] = os.path.join(dataset_dir, 'ltestA') 260 | dirs['testB'] = os.path.join(dataset_dir, 'ltestB') 261 | return dirs 262 | 263 | 264 | # To save the checkpoint 265 | def save_checkpoint(state, save_path): 266 | torch.save(state, save_path) 267 | 268 | 269 | # To load the checkpoint 270 | def load_checkpoint(ckpt_path, map_location='cpu'): 271 | ckpt = torch.load(ckpt_path, map_location=map_location) 272 | print(' [*] Loading checkpoint from %s succeed!' % ckpt_path) 273 | return ckpt 274 | 275 | 276 | # To store 50 generated image in a pool and sample from it when it is full 277 | # Shrivastava et al’s strategy 278 | class Sample_from_Pool(object): 279 | def __init__(self, max_elements=50): 280 | self.max_elements = max_elements 281 | self.cur_elements = 0 282 | self.items = [] 283 | 284 | def __call__(self, in_items): 285 | return_items = [] 286 | for in_item in in_items: 287 | if self.cur_elements < self.max_elements: 288 | self.items.append(in_item) 289 | self.cur_elements = self.cur_elements + 1 290 | return_items.append(in_item) 291 | else: 292 | if np.random.ranf() > 0.5: 293 | idx = np.random.randint(0, self.max_elements) 294 | tmp = copy.copy(self.items[idx]) 295 | self.items[idx] = in_item 296 | return_items.append(tmp) 297 | else: 298 | return_items.append(in_item) 299 | return return_items 300 | 301 | 302 | def recursive_glob(rootdir=".", suffix=""): 303 | """Performs recursive glob with given suffix and rootdir 304 | :param rootdir is the root directory 305 | :param suffix is the suffix to be searched 306 | """ 307 | return [ 308 | os.path.join(looproot, filename) 309 | for looproot, _, filenames in os.walk(rootdir) 310 | for filename in filenames 311 | if filename.endswith(suffix) 312 | ] 313 | 314 | def make_one_hot(labels, dataname, gpu_id): 315 | ''' 316 | Converts an integer label torch.autograd.Variable to a one-hot Variable. 317 | 318 | Parameters 319 | ---------- 320 | labels : torch.autograd.Variable of torch.cuda.LongTensor 321 | N x 1 x H x W, where N is batch size. 322 | Each value is an integer representing correct classification. 323 | C : integer. 324 | number of classes in labels. 325 | 326 | Returns 327 | ------- 328 | target : torch.autograd.Variable of torch.cuda.FloatTensor 329 | N x C x H x W, where C is class number. One-hot encoded. 330 | ''' 331 | assert dataname in ('voc2012', 'cityscapes', 'acdc'), 'dataset name should be one of the following: \'voc2012\',given {}'.format(dataname) 332 | 333 | if dataname == 'voc2012': 334 | C = 21 335 | elif dataname == 'cityscapes': 336 | C = 20 337 | elif dataname == 'acdc': 338 | C = 4 339 | else: 340 | raise NotImplementedError 341 | 342 | labels = labels.long() 343 | try: 344 | one_hot = torch.FloatTensor(labels.size(0), C, labels.size(2), labels.size(3)).zero_() 345 | one_hot = cuda(one_hot, gpu_id) 346 | except: 347 | one_hot = torch.FloatTensor(labels.size(0), C, labels.size(2), labels.size(3)).zero_() 348 | target = one_hot.scatter_(1, labels.data, 1) 349 | 350 | return target 351 | 352 | 353 | # Adapted from score written by wkentaro 354 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py 355 | 356 | 357 | class runningScore(object): 358 | def __init__(self, n_classes, dataset): 359 | self.n_classes = n_classes 360 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 361 | self.dataset = dataset 362 | 363 | def _fast_hist(self, label_true, label_pred, n_class): 364 | mask = (label_true >= 0) & (label_true < n_class) 365 | hist = np.bincount( 366 | n_class * label_true[mask].astype(int) + label_pred[mask], 367 | minlength=n_class ** 2, 368 | ).reshape(n_class, n_class) 369 | return hist 370 | 371 | def update(self, label_trues, label_preds): 372 | for lt, lp in zip(label_trues, label_preds): 373 | self.confusion_matrix += self._fast_hist( 374 | lt.flatten(), lp.flatten(), self.n_classes 375 | ) 376 | 377 | def get_scores(self): 378 | """Returns accuracy score evaluation result. 379 | - overall accuracy 380 | - mean accuracy 381 | - mean IU 382 | - fwavacc 383 | """ 384 | hist = self.confusion_matrix 385 | acc = np.diag(hist).sum() / hist.sum() 386 | acc_cls = np.diag(hist) / hist.sum(axis=1) 387 | acc_cls = np.nanmean(acc_cls) 388 | if self.dataset == 'voc2012': 389 | iu = np.diag(hist[1:self.n_classes, 1:self.n_classes]) / (hist[1:self.n_classes, 1:self.n_classes].sum(axis=1) + hist[1:self.n_classes, 1:self.n_classes].sum(axis=0) - np.diag(hist[1:self.n_classes, 1:self.n_classes])) 390 | elif self.dataset == 'cityscapes': 391 | iu = np.diag(hist[0:self.n_classes-1, 0:self.n_classes-1]) / (hist[0:self.n_classes-1, 0:self.n_classes-1].sum(axis=1) + hist[0:self.n_classes-1, 0:self.n_classes-1].sum(axis=0) - np.diag(hist[0:self.n_classes-1, 0:self.n_classes-1])) 392 | elif self.dataset == 'acdc': 393 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 394 | mean_iu = np.nanmean(iu) 395 | # freq = hist.sum(axis=1) / hist.sum() 396 | # fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 397 | if self.dataset == 'voc2012' or self.dataset == 'cityscapes': 398 | cls_iu = dict(zip(range(self.n_classes-1), iu)) 399 | elif self.dataset == 'acdc': 400 | cls_iu = dict(zip(range(self.n_classes), iu)) 401 | 402 | return ( 403 | { 404 | "Overall Acc: \t": acc, 405 | "Mean Acc : \t": acc_cls, 406 | "Mean IoU : \t": mean_iu, 407 | }, 408 | cls_iu, 409 | ) 410 | 411 | def reset(self): 412 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 413 | 414 | 415 | class averageMeter(object): 416 | """Computes and stores the average and current value""" 417 | 418 | def __init__(self): 419 | self.reset() 420 | 421 | def reset(self): 422 | self.val = 0 423 | self.avg = 0 424 | self.sum = 0 425 | self.count = 0 426 | 427 | def update(self, val, n=1): 428 | self.val = val 429 | self.sum += val * n 430 | self.count += n 431 | self.avg = self.sum / self.count 432 | 433 | 434 | class LambdaLR(): 435 | def __init__(self, epochs, offset, decay_epoch): 436 | self.epochs = epochs 437 | self.offset = offset 438 | self.decay_epoch = decay_epoch 439 | 440 | def step(self, epoch): 441 | return 1.0 - max(0, epoch + self.offset - self.decay_epoch)/(self.epochs - self.decay_epoch) 442 | 443 | def print_networks(nets, names): 444 | print('------------Number of Parameters---------------') 445 | i=0 446 | for net in nets: 447 | num_params = 0 448 | for param in net.parameters(): 449 | num_params += param.numel() 450 | print('[Network %s] Total number of parameters : %.3f M' % (names[i], num_params / 1e6)) 451 | i=i+1 452 | print('-----------------------------------------------') 453 | -------------------------------------------------------------------------------- /arch/generators.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | from torch import nn 4 | from .ops import conv_norm_relu, dconv_norm_relu, ResidualBlock, get_norm_layer, init_network, InitialBlock, RegularBottleneck, UpsamplingBottleneck, DownsamplingBottleneck, SSnbt, Downsample_Block_led, APN 5 | 6 | 7 | class UnetSkipConnectionBlock(nn.Module): 8 | def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, 9 | innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 10 | super(UnetSkipConnectionBlock, self).__init__() 11 | self.outermost = outermost 12 | if type(norm_layer) == functools.partial: 13 | use_bias = norm_layer.func == nn.InstanceNorm2d 14 | else: 15 | use_bias = norm_layer == nn.InstanceNorm2d 16 | if input_nc is None: 17 | input_nc = outer_nc 18 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 19 | 20 | if outermost: 21 | upconv = nn.ConvTranspose2d(inner_nc*2, outer_nc, kernel_size=4, stride=2, padding=1) 22 | down = [downconv] 23 | up = [nn.ReLU(True), upconv] 24 | model = down + [submodule] + up 25 | elif innermost: 26 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 27 | down = [nn.LeakyReLU(0.2, True), downconv] 28 | up = [nn.ReLU(True), upconv, norm_layer(outer_nc)] 29 | model = down + up 30 | else: 31 | upconv = nn.ConvTranspose2d(inner_nc*2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) 32 | down = [nn.LeakyReLU(0.2, True), downconv, norm_layer(inner_nc)] 33 | up = [nn.ReLU(True), upconv, norm_layer(outer_nc)] 34 | 35 | if use_dropout: 36 | model = down + [submodule] + up + [nn.Dropout(0.5)] 37 | else: 38 | model = down + [submodule] + up 39 | 40 | self.model = nn.Sequential(*model) 41 | 42 | def forward(self, x): 43 | if self.outermost: 44 | return self.model(x) 45 | else: 46 | return torch.cat([x, self.model(x)], 1) 47 | 48 | class UnetGenerator(nn.Module): 49 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 50 | norm_layer=nn.BatchNorm2d, use_dropout=False): 51 | super(UnetGenerator, self).__init__() 52 | 53 | unet_block = UnetSkipConnectionBlock(ngf*8, ngf*8, submodule=None, norm_layer=norm_layer, innermost=True) 54 | for i in range(num_downs - 5): 55 | unet_block = UnetSkipConnectionBlock(ngf*8, ngf*8, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 56 | unet_block = UnetSkipConnectionBlock(ngf*4, ngf*8, submodule=unet_block, norm_layer=norm_layer) 57 | unet_block = UnetSkipConnectionBlock(ngf*2, ngf*4, submodule=unet_block, norm_layer=norm_layer) 58 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_layer=norm_layer) 59 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 60 | self.unet_model = unet_block 61 | 62 | def forward(self, input): 63 | return self.unet_model(input) 64 | 65 | class ResnetGenerator(nn.Module): 66 | def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=True, num_blocks=6, softmax=False): 67 | super(ResnetGenerator, self).__init__() 68 | if type(norm_layer) == functools.partial: 69 | use_bias = norm_layer.func == nn.InstanceNorm2d 70 | else: 71 | use_bias = norm_layer == nn.InstanceNorm2d 72 | 73 | res_model = [nn.ReflectionPad2d(3), 74 | conv_norm_relu(input_nc, ngf * 1, 7, norm_layer=norm_layer, bias=use_bias), 75 | conv_norm_relu(ngf * 1, ngf * 2, 3, 2, 1, norm_layer=norm_layer, bias=use_bias), 76 | conv_norm_relu(ngf * 2, ngf * 4, 3, 2, 1, norm_layer=norm_layer, bias=use_bias)] 77 | 78 | for i in range(num_blocks): 79 | res_model += [ResidualBlock(ngf * 4, norm_layer, use_dropout, use_bias)] 80 | 81 | if softmax: 82 | res_model += [dconv_norm_relu(ngf * 4, ngf * 2, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias), 83 | dconv_norm_relu(ngf * 2, ngf * 1, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias), 84 | nn.ReflectionPad2d(3), 85 | nn.Conv2d(ngf, output_nc, 7)] 86 | else: 87 | res_model += [dconv_norm_relu(ngf * 4, ngf * 2, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias), 88 | dconv_norm_relu(ngf * 2, ngf * 1, 3, 2, 1, 1, norm_layer=norm_layer, bias=use_bias), 89 | nn.ReflectionPad2d(3), 90 | nn.Conv2d(ngf, output_nc, 7), 91 | nn.Tanh()] 92 | self.res_model = nn.Sequential(*res_model) 93 | 94 | def forward(self, x): 95 | return self.res_model(x) 96 | 97 | 98 | class ENet(nn.Module): 99 | ''' 100 | num_classes : The number of classes to segment the image into 101 | encoder_relu : whether to include relu or prelu for encoder part 102 | decoder_relu : whether to include relu or prelu for decoder part 103 | ''' 104 | def __init__(self, num_classes, encoder_relu=False, decoder_relu=True): 105 | super().__init__() 106 | 107 | self.initial_block = InitialBlock(3, 16, padding=1, relu=encoder_relu) 108 | 109 | # Stage 1 - Encoder 110 | self.downsample1_0 = DownsamplingBottleneck( 111 | 16, 112 | 64, 113 | padding=1, 114 | return_indices=True, 115 | dropout_prob=0.01, 116 | relu=encoder_relu) 117 | self.regular1_1 = RegularBottleneck( 118 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 119 | self.regular1_2 = RegularBottleneck( 120 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 121 | self.regular1_3 = RegularBottleneck( 122 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 123 | self.regular1_4 = RegularBottleneck( 124 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 125 | 126 | # Stage 2 - Encoder 127 | self.downsample2_0 = DownsamplingBottleneck( 128 | 64, 129 | 128, 130 | padding=1, 131 | return_indices=True, 132 | dropout_prob=0.1, 133 | relu=encoder_relu) 134 | self.regular2_1 = RegularBottleneck( 135 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 136 | self.dilated2_2 = RegularBottleneck( 137 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 138 | self.asymmetric2_3 = RegularBottleneck( 139 | 128, 140 | kernel_size=5, 141 | padding=2, 142 | asymmetric=True, 143 | dropout_prob=0.1, 144 | relu=encoder_relu) 145 | self.dilated2_4 = RegularBottleneck( 146 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 147 | self.regular2_5 = RegularBottleneck( 148 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 149 | self.dilated2_6 = RegularBottleneck( 150 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 151 | self.asymmetric2_7 = RegularBottleneck( 152 | 128, 153 | kernel_size=5, 154 | asymmetric=True, 155 | padding=2, 156 | dropout_prob=0.1, 157 | relu=encoder_relu) 158 | self.dilated2_8 = RegularBottleneck( 159 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 160 | 161 | # Stage 3 - Encoder 162 | self.regular3_0 = RegularBottleneck( 163 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 164 | self.dilated3_1 = RegularBottleneck( 165 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 166 | self.asymmetric3_2 = RegularBottleneck( 167 | 128, 168 | kernel_size=5, 169 | padding=2, 170 | asymmetric=True, 171 | dropout_prob=0.1, 172 | relu=encoder_relu) 173 | self.dilated3_3 = RegularBottleneck( 174 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 175 | self.regular3_4 = RegularBottleneck( 176 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 177 | self.dilated3_5 = RegularBottleneck( 178 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 179 | self.asymmetric3_6 = RegularBottleneck( 180 | 128, 181 | kernel_size=5, 182 | asymmetric=True, 183 | padding=2, 184 | dropout_prob=0.1, 185 | relu=encoder_relu) 186 | self.dilated3_7 = RegularBottleneck( 187 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 188 | 189 | # Stage 4 - Decoder 190 | self.upsample4_0 = UpsamplingBottleneck( 191 | 128, 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 192 | self.regular4_1 = RegularBottleneck( 193 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 194 | self.regular4_2 = RegularBottleneck( 195 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 196 | 197 | # Stage 5 - Decoder 198 | self.upsample5_0 = UpsamplingBottleneck( 199 | 64, 16, padding=1, dropout_prob=0.1, relu=decoder_relu) 200 | self.regular5_1 = RegularBottleneck( 201 | 16, padding=1, dropout_prob=0.1, relu=decoder_relu) 202 | self.transposed_conv = nn.ConvTranspose2d( 203 | 16, 204 | num_classes, 205 | kernel_size=3, 206 | stride=2, 207 | padding=1, 208 | output_padding=1, 209 | bias=False) 210 | 211 | def forward(self, x): 212 | # Initial block 213 | x = self.initial_block(x) 214 | 215 | # Stage 1 - Encoder 216 | x, max_indices1_0 = self.downsample1_0(x) 217 | x = self.regular1_1(x) 218 | x = self.regular1_2(x) 219 | x = self.regular1_3(x) 220 | x = self.regular1_4(x) 221 | 222 | # Stage 2 - Encoder 223 | x, max_indices2_0 = self.downsample2_0(x) 224 | x = self.regular2_1(x) 225 | x = self.dilated2_2(x) 226 | x = self.asymmetric2_3(x) 227 | x = self.dilated2_4(x) 228 | x = self.regular2_5(x) 229 | x = self.dilated2_6(x) 230 | x = self.asymmetric2_7(x) 231 | x = self.dilated2_8(x) 232 | 233 | # Stage 3 - Encoder 234 | x = self.regular3_0(x) 235 | x = self.dilated3_1(x) 236 | x = self.asymmetric3_2(x) 237 | x = self.dilated3_3(x) 238 | x = self.regular3_4(x) 239 | x = self.dilated3_5(x) 240 | x = self.asymmetric3_6(x) 241 | x = self.dilated3_7(x) 242 | 243 | # Stage 4 - Decoder 244 | x = self.upsample4_0(x, max_indices2_0) 245 | x = self.regular4_1(x) 246 | x = self.regular4_2(x) 247 | 248 | # Stage 5 - Decoder 249 | x = self.upsample5_0(x, max_indices1_0) 250 | x = self.regular5_1(x) 251 | x = self.transposed_conv(x) 252 | 253 | return x 254 | 255 | 256 | class LEDNet(nn.Module): 257 | ''' 258 | Implementation of the LEDNet network 259 | ''' 260 | 261 | def __init__(self, in_channels=3, n_classes=22, encoder_relu=False, decoder_relu=True, image_dim=128): 262 | super().__init__() 263 | 264 | ### Encoder 265 | self.downsample_1 = Downsample_Block_led(in_channels, 32, kernel_size=3, padding=1, relu=encoder_relu) 266 | 267 | self.ssnbt_1_1 = SSnbt(32, kernel_size=3, padding=1, groups=4, relu=encoder_relu, drop_prob=0.01) 268 | self.ssnbt_1_2 = SSnbt(32, kernel_size=3, padding=1, groups=4, relu=encoder_relu, drop_prob=0.01) 269 | self.ssnbt_1_3 = SSnbt(32, kernel_size=3, padding=1, groups=4, relu=encoder_relu, drop_prob=0.01) 270 | 271 | self.downsample_2 = Downsample_Block_led(32, 64, kernel_size=3, padding=1, relu=encoder_relu) 272 | 273 | self.ssnbt_2_1 = SSnbt(64, kernel_size=3, padding=1, groups=4, relu=encoder_relu, drop_prob=0.01) 274 | self.ssnbt_2_2 = SSnbt(64, kernel_size=3, padding=1, groups=4, relu=encoder_relu, drop_prob=0.01) 275 | 276 | self.downsample_3 = Downsample_Block_led(64, 128, kernel_size=3, padding=1, relu=encoder_relu) 277 | 278 | self.ssnbt_3_1 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=1, relu=encoder_relu, drop_prob=0.1) 279 | self.ssnbt_3_2 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=2, relu=encoder_relu, drop_prob=0.1) 280 | self.ssnbt_3_3 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=5, relu=encoder_relu, drop_prob=0.1) 281 | self.ssnbt_3_4 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=9, relu=encoder_relu, drop_prob=0.1) 282 | self.ssnbt_3_5 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=2, relu=encoder_relu, drop_prob=0.1) 283 | self.ssnbt_3_6 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=5, relu=encoder_relu, drop_prob=0.1) 284 | self.ssnbt_3_7 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=9, relu=encoder_relu, drop_prob=0.1) 285 | self.ssnbt_3_8 = SSnbt(128, kernel_size=3, padding=1, groups=8, dilation=17, relu=encoder_relu, drop_prob=0.1) 286 | 287 | ### Decoder 288 | self.apn = APN(128, n_classes, image_dim= image_dim // 8, relu=decoder_relu) 289 | 290 | def forward(self, x): 291 | 292 | ### Encoder 293 | x = self.downsample_1(x) 294 | x = self.ssnbt_1_1(x) 295 | x = self.ssnbt_1_2(x) 296 | x = self.ssnbt_1_3(x) 297 | 298 | x = self.downsample_2(x) 299 | x = self.ssnbt_2_1(x) 300 | x = self.ssnbt_2_2(x) 301 | 302 | x = self.downsample_3(x) 303 | x = self.ssnbt_3_1(x) 304 | x = self.ssnbt_3_2(x) 305 | x = self.ssnbt_3_3(x) 306 | x = self.ssnbt_3_4(x) 307 | x = self.ssnbt_3_5(x) 308 | x = self.ssnbt_3_6(x) 309 | x = self.ssnbt_3_7(x) 310 | x = self.ssnbt_3_8(x) 311 | 312 | ### Decoder 313 | out = self.apn(x) 314 | 315 | return out 316 | 317 | 318 | ### Code for deeplab model 319 | ### This code is modified from: https://github.com/hfslyc/AdvSemiSeg 320 | class Bottleneck(nn.Module): 321 | expansion = 4 322 | 323 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 324 | super(Bottleneck, self).__init__() 325 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 326 | self.bn1 = nn.BatchNorm2d(planes) 327 | for i in self.bn1.parameters(): 328 | i.requires_grad = False 329 | 330 | padding = dilation 331 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 332 | padding=padding, bias=False, dilation = dilation) 333 | self.bn2 = nn.BatchNorm2d(planes) 334 | for i in self.bn2.parameters(): 335 | i.requires_grad = False 336 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 337 | self.bn3 = nn.BatchNorm2d(planes * 4) 338 | for i in self.bn3.parameters(): 339 | i.requires_grad = False 340 | self.relu = nn.ReLU(inplace=True) 341 | self.downsample = downsample 342 | self.stride = stride 343 | 344 | 345 | def forward(self, x): 346 | residual = x 347 | 348 | out = self.conv1(x) 349 | out = self.bn1(out) 350 | out = self.relu(out) 351 | 352 | out = self.conv2(out) 353 | out = self.bn2(out) 354 | out = self.relu(out) 355 | 356 | out = self.conv3(out) 357 | out = self.bn3(out) 358 | 359 | if self.downsample is not None: 360 | residual = self.downsample(x) 361 | 362 | out += residual 363 | out = self.relu(out) 364 | 365 | return out 366 | 367 | class Classifier_Module(nn.Module): 368 | 369 | def __init__(self, dilation_series, padding_series, num_classes): 370 | super(Classifier_Module, self).__init__() 371 | self.conv2d_list = nn.ModuleList() 372 | for dilation, padding in zip(dilation_series, padding_series): 373 | self.conv2d_list.append(nn.Conv2d(2048, num_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias = True)) 374 | 375 | for m in self.conv2d_list: 376 | m.weight.data.normal_(0, 0.01) 377 | 378 | def forward(self, x): 379 | out = self.conv2d_list[0](x) 380 | for i in range(len(self.conv2d_list)-1): 381 | out += self.conv2d_list[i+1](x) 382 | return out 383 | 384 | class ResNet(nn.Module): 385 | def __init__(self, in_channels, block, layers, num_classes): 386 | self.inplanes = 64 387 | super(ResNet, self).__init__() 388 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, 389 | bias=False) 390 | self.bn1 = nn.BatchNorm2d(64) 391 | for i in self.bn1.parameters(): 392 | i.requires_grad = False 393 | self.relu = nn.ReLU(inplace=True) 394 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 395 | self.layer1 = self._make_layer(block, 64, layers[0]) 396 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 397 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 398 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 399 | self.layer5 = self._make_pred_layer(Classifier_Module, [6,12,18,24],[6,12,18,24],num_classes) 400 | 401 | for m in self.modules(): 402 | if isinstance(m, nn.Conv2d): 403 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 404 | m.weight.data.normal_(0, 0.01) 405 | elif isinstance(m, nn.BatchNorm2d): 406 | m.weight.data.fill_(1) 407 | m.bias.data.zero_() 408 | # for i in m.parameters(): 409 | # i.requires_grad = False 410 | 411 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 412 | downsample = None 413 | if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4: 414 | downsample = nn.Sequential( 415 | nn.Conv2d(self.inplanes, planes * block.expansion, 416 | kernel_size=1, stride=stride, bias=False), 417 | nn.BatchNorm2d(planes * block.expansion)) 418 | for i in downsample._modules['1'].parameters(): 419 | i.requires_grad = False 420 | layers = [] 421 | layers.append(block(self.inplanes, planes, stride,dilation=dilation, downsample=downsample)) 422 | self.inplanes = planes * block.expansion 423 | for i in range(1, blocks): 424 | layers.append(block(self.inplanes, planes, dilation=dilation)) 425 | 426 | return nn.Sequential(*layers) 427 | def _make_pred_layer(self,block, dilation_series, padding_series,num_classes): 428 | return block(dilation_series,padding_series,num_classes) 429 | 430 | def forward(self, x): 431 | x = self.conv1(x) 432 | x = self.bn1(x) 433 | x = self.relu(x) 434 | x = self.maxpool(x) 435 | x = self.layer1(x) 436 | x = self.layer2(x) 437 | x = self.layer3(x) 438 | x = self.layer4(x) 439 | x = self.layer5(x) 440 | 441 | return x 442 | 443 | def get_1x_lr_params_NOscale(self): 444 | """ 445 | This generator returns all the parameters of the net except for 446 | the last classification layer. Note that for each batchnorm layer, 447 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 448 | any batchnorm parameter 449 | """ 450 | b = [] 451 | 452 | b.append(self.conv1) 453 | b.append(self.bn1) 454 | b.append(self.layer1) 455 | b.append(self.layer2) 456 | b.append(self.layer3) 457 | b.append(self.layer4) 458 | 459 | 460 | for i in range(len(b)): 461 | for j in b[i].modules(): 462 | jj = 0 463 | for k in j.parameters(): 464 | jj+=1 465 | if k.requires_grad: 466 | yield k 467 | 468 | def get_10x_lr_params(self): 469 | """ 470 | This generator returns all the parameters for the last layer of the net, 471 | which does the classification of pixel into classes 472 | """ 473 | b = [] 474 | b.append(self.layer5.parameters()) 475 | 476 | for j in range(len(b)): 477 | for i in b[j]: 478 | yield i 479 | 480 | 481 | 482 | def optim_parameters(self, args): 483 | return [{'params': self.get_1x_lr_params_NOscale(), 'lr': args.learning_rate}, 484 | {'params': self.get_10x_lr_params(), 'lr': 10*args.learning_rate}] 485 | 486 | 487 | def define_Gen(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, gpu_ids=[0]): 488 | gen_net = None 489 | norm_layer = get_norm_layer(norm_type=norm) 490 | 491 | if netG == 'resnet_9blocks': 492 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=9, softmax=False) 493 | elif netG == 'resnet_9blocks_softmax': 494 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=9, softmax=True) 495 | elif netG == 'resnet_6blocks': 496 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=6, softmax=False) 497 | elif netG == 'resnet_6blocks_softmax': 498 | gen_net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, num_blocks=6, softmax=True) 499 | elif netG == 'unet_128': 500 | gen_net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 501 | elif netG == 'unet_256': 502 | gen_net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 503 | elif netG == 'enet': 504 | ### For ENet the input_nc is always 3 505 | gen_net = ENet(num_classes = output_nc, encoder_relu=False, decoder_relu=True) 506 | elif netG == 'lednet_128': 507 | gen_net = LEDNet(in_channels=input_nc, n_classes= output_nc, encoder_relu=False, decoder_relu=True, image_dim=128) 508 | elif netG == 'lednet_256': 509 | gen_net = LEDNet(in_channels=input_nc, n_classes= output_nc, encoder_relu=False, decoder_relu=True, image_dim=256) 510 | elif netG == 'deeplab': 511 | gen_net = ResNet(in_channels=input_nc, block=Bottleneck, layers=[3, 4, 23, 3], num_classes=output_nc) 512 | else: 513 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 514 | 515 | return init_network(gen_net, gpu_ids) 516 | -------------------------------------------------------------------------------- /arch/ops.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import init 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch 6 | 7 | def get_norm_layer(norm_type='instance'): 8 | if norm_type == 'batch': 9 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 10 | elif norm_type == 'instance': 11 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 12 | else: 13 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 14 | return norm_layer 15 | 16 | def init_weights(net, init_type='normal', gain=0.02): 17 | def init_func(m): 18 | classname = m.__class__.__name__ 19 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 20 | init.normal(m.weight.data, 0.0, gain) 21 | if hasattr(m, 'bias') and m.bias is not None: 22 | init.constant(m.bias.data, 0.0) 23 | elif classname.find('BatchNorm2d') != -1: 24 | init.normal(m.weight.data, 1.0, gain) 25 | init.constant(m.bias.data, 0.0) 26 | 27 | print('Network initialized with weights sampled from N(0,0.02).') 28 | net.apply(init_func) 29 | 30 | 31 | def init_network(net, gpu_ids=[]): 32 | if len(gpu_ids) > 0: 33 | assert(torch.cuda.is_available()) 34 | net.cuda(gpu_ids[0]) 35 | ### I have commented this out only while we are using the Deeplab model because as such it is not required in that 36 | init_weights(net) 37 | return net 38 | 39 | 40 | def conv_norm_lrelu(in_dim, out_dim, kernel_size, stride = 1, padding=0, 41 | norm_layer = nn.BatchNorm2d, bias = False): 42 | return nn.Sequential( 43 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias = bias), 44 | norm_layer(out_dim), nn.LeakyReLU(0.2,True)) 45 | 46 | def conv_norm_relu(in_dim, out_dim, kernel_size, stride = 1, padding=0, 47 | norm_layer = nn.BatchNorm2d, bias = False): 48 | return nn.Sequential( 49 | nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias = bias), 50 | norm_layer(out_dim), nn.ReLU(True)) 51 | 52 | def dconv_norm_relu(in_dim, out_dim, kernel_size, stride = 1, padding=0, output_padding=0, 53 | norm_layer = nn.BatchNorm2d, bias = False): 54 | return nn.Sequential( 55 | nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride, 56 | padding, output_padding, bias = bias), 57 | norm_layer(out_dim), nn.ReLU(True)) 58 | 59 | class ResidualBlock(nn.Module): 60 | def __init__(self, dim, norm_layer, use_dropout, use_bias): 61 | super(ResidualBlock, self).__init__() 62 | res_block = [nn.ReflectionPad2d(1), 63 | conv_norm_relu(dim, dim, kernel_size=3, 64 | norm_layer= norm_layer, bias=use_bias)] 65 | if use_dropout: 66 | res_block += [nn.Dropout(0.5)] 67 | res_block += [nn.ReflectionPad2d(1), 68 | nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias), 69 | norm_layer(dim)] 70 | 71 | self.res_block = nn.Sequential(*res_block) 72 | 73 | def forward(self, x): 74 | return x + self.res_block(x) 75 | 76 | 77 | def set_grad(nets, requires_grad=False): 78 | for net in nets: 79 | for param in net.parameters(): 80 | param.requires_grad = requires_grad 81 | 82 | 83 | ####### 84 | 85 | ### A huge thanks to the repo: https://github.com/davidtvs/PyTorch-ENet 86 | ### A lot of ideas are being taken up from this repo 87 | 88 | ''' 89 | Here we start off for writing the required blocks of code for the enet model 90 | ''' 91 | 92 | class InitialBlock(nn.Module): 93 | ''' 94 | This will be initial block of the network 95 | ''' 96 | 97 | def __init__(self, input_dim = 3, output_dim = 16, kernel_size = 3, relu = True, padding=0, bias=False): 98 | super().__init__() 99 | 100 | if relu: 101 | activation = nn.ReLU() 102 | else: 103 | activation = nn.PReLU() 104 | 105 | ### So here we will define the main branch which will give the output by passing through the conv block 106 | self.main_branch = nn.Conv2d(in_channels = input_dim, 107 | out_channels = output_dim - 3, 108 | kernel_size = kernel_size, 109 | stride= 2, 110 | padding= padding, 111 | bias = bias) 112 | 113 | ### Extension branch, which is basically MaxPool2d in this case 114 | self.ext_branch = nn.MaxPool2d(kernel_size = kernel_size, 115 | stride= 2, 116 | padding= padding) 117 | 118 | ### Some of the extra basic blocks of activations and other things 119 | self.batch_norm = nn.BatchNorm2d(output_dim) 120 | self.activation = activation 121 | 122 | def forward(self, x): 123 | main = self.main_branch(x) 124 | ext = self.ext_branch(x) 125 | 126 | out = torch.cat((main, ext), 1) ### For concatenating their outputs 127 | 128 | out = self.batch_norm(out) 129 | out = self.activation(out) 130 | 131 | return out 132 | 133 | class RegularBottleneck(nn.Module): 134 | ''' 135 | So in this also there are two branches: 136 | Main branch --> 137 | This is basically simply the input 138 | Extension Branch --> 139 | 1. 1*1 conv to decrease the spatial dimensions to a internal_ratio specified by the user 140 | 2. conv layer, which can be regular, atrous or asymmetric convolution 141 | 3. 1*1 to increase the spatial dimension back from internal_ratio to the original size 142 | 4. Regularizer which in this case is dropout 143 | ''' 144 | 145 | def __init__(self, channels, internal_ratio = 4, kernel_size = 3, dilation = 1, padding = 1, asymmetric = False, dropout_prob=0, bias=False, relu=True): 146 | super().__init__() 147 | 148 | ### Now we need to ensure if the internal_ratio has a reasonable value 149 | assert internal_ratio > 1 and internal_ratio < channels , 'The value of the internal_ratio is not correct' 150 | 151 | internal_channels = channels // internal_ratio 152 | 153 | if relu: 154 | activation = nn.ReLU() 155 | else: 156 | activation = nn.PReLU() 157 | 158 | self.ext_conv1 = nn.Sequential( 159 | nn.Conv2d(in_channels = channels, out_channels = internal_channels, kernel_size = 1, stride = 1, padding = 0, bias=bias), 160 | nn.BatchNorm2d(internal_channels), 161 | activation) 162 | 163 | ### So here we are going to see for the case if we are given asymmetric convolutions 164 | ### So basically it is gonna be like, example for 3*3 we will break it into 3*1 and 1*3 165 | ### and then perform normal convolutions, first 3*1 and then 1*3 166 | if asymmetric: 167 | self.ext_conv2 = nn.Sequential( 168 | nn.Conv2d(internal_channels, internal_channels, (kernel_size, 1), stride = (1, 1), padding = (padding, 0), dilation=dilation, bias=bias), 169 | nn.BatchNorm2d(internal_channels), 170 | activation, 171 | nn.Conv2d(internal_channels, internal_channels, kernel_size = (1, kernel_size), stride=(1,1), padding=(0, padding), dilation=dilation, bias=bias), 172 | nn.BatchNorm2d(internal_channels), 173 | activation 174 | ) 175 | else: 176 | self.ext_conv2 = nn.Sequential( 177 | nn.Conv2d(internal_channels, internal_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=bias), 178 | nn.BatchNorm2d(internal_channels), 179 | activation 180 | ) 181 | 182 | ### Now we are gonna implement the 1*1 expansion code 183 | self.ext_conv3 = nn.Sequential(nn.Conv2d(internal_channels, channels, kernel_size=1, stride=1, padding=0, bias=bias), 184 | nn.BatchNorm2d(channels), 185 | activation) 186 | 187 | self.ext_regularizer = nn.Dropout2d(p = dropout_prob) 188 | 189 | self.out_relu = activation 190 | 191 | 192 | def forward(self, x): 193 | 194 | main = x 195 | 196 | ext = self.ext_conv1(x) 197 | ext = self.ext_conv2(ext) 198 | ext = self.ext_conv3(ext) 199 | ext = self.ext_regularizer(ext) 200 | 201 | out = ext + main 202 | 203 | ### Activation applying after the addition 204 | return self.out_relu(out) 205 | 206 | 207 | class DownsamplingBottleneck(nn.Module): 208 | ''' 209 | This is the Downsampling bottleneck layer to downsample the input 210 | ''' 211 | 212 | def __init__(self, in_channels, out_channels, internal_ratio=4, kernel_size = 3, padding=0, return_indices = False, dropout_prob = 0, bias=False, relu = True): 213 | super().__init__() 214 | 215 | self.return_indices = return_indices 216 | 217 | ### Now we need to ensure if the internal_ratio has a reasonable value 218 | assert internal_ratio > 1 and internal_ratio < in_channels , 'The value of the internal_ratio is not correct' 219 | 220 | internal_channels = in_channels // internal_ratio 221 | 222 | if relu: 223 | activation = nn.ReLU() 224 | else: 225 | activation = nn.PReLU() 226 | 227 | self.main_max1 = nn.MaxPool2d(kernel_size = kernel_size, stride=2, padding=padding, return_indices=return_indices) 228 | 229 | ### The extension branch 230 | self.ext_conv1 = nn.Sequential( 231 | nn.Conv2d(in_channels, internal_channels, kernel_size=2, stride=2, bias=bias), 232 | nn.BatchNorm2d(internal_channels), 233 | activation) 234 | 235 | ### Now we are gonna do the conv operation 236 | self.ext_conv2 = nn.Sequential( 237 | nn.Conv2d(internal_channels, internal_channels, kernel_size= kernel_size, stride=1, padding=padding, bias=bias), 238 | nn.BatchNorm2d(internal_channels), 239 | activation 240 | ) 241 | 242 | self.ext_conv3 = nn.Sequential( 243 | nn.Conv2d(internal_channels, out_channels, kernel_size=1, stride=1, bias=bias), 244 | nn.BatchNorm2d(out_channels), 245 | activation 246 | ) 247 | 248 | self.ext_reg = nn.Dropout2d(p=dropout_prob) 249 | 250 | self.out_activation = activation 251 | 252 | def forward(self, x): 253 | if self.return_indices: 254 | main, max_indices = self.main_max1(x) 255 | else: 256 | main = self.main_max1(x) 257 | 258 | ext = self.ext_conv1(x) 259 | ext = self.ext_conv2(ext) 260 | ext = self.ext_conv3(ext) 261 | ext = self.ext_reg(ext) 262 | 263 | ### Main branch channel padding 264 | n, ch_ext, h, w = ext.size() 265 | ch_main = main.size()[1] 266 | padding = torch.zeros(n, ch_ext - ch_main, h, w) 267 | 268 | ### So before concatenating we would first have to see if the main is on gpu cause they both must be 269 | ### on the same memory 270 | if main.is_cuda: 271 | padding = padding.cuda() 272 | 273 | 274 | ### Concatenating 275 | main = torch.cat((main, padding), 1) 276 | 277 | out = main + ext 278 | 279 | if self.return_indices: 280 | return self.out_activation(out), max_indices 281 | else: 282 | return self.out_activation(out) 283 | 284 | 285 | class UpsamplingBottleneck(nn.Module): 286 | ''' 287 | This is the Upsampling bottleneck as is displayed in the paper 288 | ''' 289 | 290 | def __init__(self, in_channels, out_channels, internal_ratio = 4, kernel_size = 3, padding=0, dropout_prob = 0, bias=False, relu=True): 291 | super().__init__() 292 | 293 | ### Now we need to ensure if the internal_ratio has a reasonable value 294 | assert internal_ratio > 1 and internal_ratio < in_channels , 'The value of the internal_ratio is not correct' 295 | 296 | internal_channels = in_channels // internal_ratio 297 | 298 | if relu: 299 | activation = nn.ReLU() 300 | else: 301 | activation = nn.PReLU() 302 | 303 | self.main_conv1 = nn.Sequential( 304 | nn.Conv2d(in_channels, out_channels, kernel_size= kernel_size, padding=padding, bias = bias), 305 | nn.BatchNorm2d(out_channels) 306 | ) 307 | 308 | self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2) 309 | 310 | ### Extension branch 311 | self.ext_conv1 = nn.Sequential( 312 | nn.Conv2d(in_channels, internal_channels, kernel_size= 1, bias= bias), 313 | nn.BatchNorm2d(internal_channels), 314 | activation 315 | ) 316 | 317 | self.ext_conv2 = nn.Sequential( 318 | nn.ConvTranspose2d(internal_channels, internal_channels, kernel_size=kernel_size, stride=2, padding=padding, output_padding=1, bias=bias), 319 | nn.BatchNorm2d(internal_channels), 320 | activation 321 | ) 322 | 323 | self.ext_conv3 = nn.Sequential( 324 | nn.Conv2d(internal_channels, out_channels, kernel_size=1, bias=bias), 325 | nn.BatchNorm2d(out_channels), 326 | activation 327 | ) 328 | 329 | self.ext_reg = nn.Dropout2d(p=dropout_prob) 330 | 331 | self.out_activation = activation 332 | 333 | def forward(self, x, max_indices): 334 | main = self.main_conv1(x) 335 | main = self.main_unpool1(main, max_indices) 336 | 337 | ext = self.ext_conv1(x) 338 | ext = self.ext_conv2(ext) 339 | ext = self.ext_conv3(ext) 340 | ext = self.ext_reg(ext) 341 | 342 | out = main+ext 343 | 344 | return self.out_activation(out) 345 | 346 | 347 | ##### 348 | ''' 349 | So here writing up the helper functions for the case of LEDNet 350 | ''' 351 | 352 | class Channel_Split(nn.Module): 353 | ''' 354 | Used to split the channels of a tensor 355 | ''' 356 | def __init__(self, channels): 357 | super().__init__() 358 | 359 | self.channels = channels 360 | 361 | def forward(self, x): 362 | x_left = x[:, 0:self.channels // 2, :, :] 363 | x_right = x[:, self.channels // 2 : , :, :] 364 | 365 | return x_left, x_right 366 | 367 | 368 | class ShuffleBlock(nn.Module): 369 | def __init__(self, groups): 370 | super(ShuffleBlock, self).__init__() 371 | self.groups = groups 372 | 373 | def forward(self, x): 374 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 375 | N,C,H,W = x.size() 376 | g = self.groups 377 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).contiguous().view(N,C,H,W) 378 | 379 | 380 | class SSnbt(nn.Module): 381 | ''' 382 | The class for the implementation split-shuffle-non-bottleneck block 383 | ''' 384 | 385 | def __init__(self, channels, kernel_size = 3, padding = 0, groups=4, dilation=1, bias= True, relu=False, drop_prob=0): 386 | ''' 387 | groups: parameter for determining the shuffling order in the shuffleBlock 388 | ''' 389 | super().__init__() 390 | 391 | if relu: 392 | activation = nn.ReLU() 393 | else: 394 | activation = nn.PReLU() 395 | 396 | ### Now firstly there will be channel split and hence we have to write for left and right parts of 397 | ### our model 398 | 399 | ### Also in this model the convolutions are asymmetric in nature 400 | 401 | mid_channels = channels // 2 402 | 403 | self.split = Channel_Split(channels) 404 | 405 | self.l_conv1 = nn.Conv2d(mid_channels, mid_channels, kernel_size=(kernel_size, 1), stride=1, padding=(padding, 0), dilation = 1, bias=bias) 406 | 407 | self.l_conv2 = nn.Sequential( 408 | nn.Conv2d(mid_channels, mid_channels, kernel_size = (1, kernel_size), stride=1, padding=(0, padding), dilation=1, bias=bias), 409 | nn.BatchNorm2d(mid_channels), 410 | activation 411 | ) 412 | 413 | self.r_conv1 = nn.Conv2d(mid_channels, mid_channels, kernel_size = (1, kernel_size), stride=1, padding=(0, padding), dilation=1, bias=bias) 414 | 415 | self.r_conv2 = nn.Sequential( 416 | nn.Conv2d(mid_channels, mid_channels, kernel_size = (kernel_size, 1), stride=1, padding=(padding, 0), dilation=1, bias=bias), 417 | nn.BatchNorm2d(mid_channels), 418 | activation 419 | ) 420 | 421 | self.l_conv3 = nn.Sequential( 422 | nn.Conv2d(mid_channels, mid_channels, kernel_size=(kernel_size, 1), stride=1, padding=(padding + dilation - 1, 0), dilation = dilation, bias=bias), 423 | activation 424 | ) 425 | 426 | self.l_conv4 = nn.Sequential( 427 | nn.Conv2d(mid_channels, mid_channels, kernel_size = (1, kernel_size), stride=1, padding=(0, padding + dilation - 1), dilation=dilation, bias=bias), 428 | nn.BatchNorm2d(mid_channels), 429 | activation 430 | ) 431 | 432 | self.r_conv3 = nn.Sequential( 433 | nn.Conv2d(mid_channels, mid_channels, kernel_size = (1, kernel_size), stride=1, padding=(0, padding + dilation - 1), dilation=dilation, bias=bias), 434 | activation 435 | ) 436 | 437 | self.r_conv4 = nn.Sequential( 438 | nn.Conv2d(mid_channels, mid_channels, kernel_size = (kernel_size, 1), stride=1, padding=(padding + dilation - 1, 0), dilation=dilation, bias=bias), 439 | nn.BatchNorm2d(mid_channels), 440 | activation 441 | ) 442 | 443 | self.regularizer = nn.Dropout2d(p=drop_prob) 444 | 445 | self.out_activation = activation 446 | self.shuffle = ShuffleBlock(groups=groups) 447 | 448 | def forward(self, x): 449 | main = x 450 | 451 | l_input, r_input = self.split(x) 452 | 453 | l_input = self.l_conv1(l_input) 454 | l_input = self.l_conv2(l_input) 455 | r_input = self.r_conv1(r_input) 456 | r_input = self.r_conv2(r_input) 457 | 458 | l_input = self.l_conv3(l_input) 459 | l_input = self.l_conv4(l_input) 460 | r_input = self.r_conv3(r_input) 461 | r_input = self.r_conv4(r_input) 462 | 463 | ext = torch.cat((l_input, r_input), 1) 464 | 465 | ext = self.regularizer(ext) 466 | 467 | out = main + ext 468 | out = self.out_activation(out) 469 | out = self.shuffle(out) 470 | 471 | return out 472 | 473 | 474 | class Downsample_Block_led(nn.Module): 475 | ''' 476 | The downsampling block for the LEDNet 477 | 478 | So basically we will do two operations in parallel, conv and max pool and then concat both of them 479 | ''' 480 | 481 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=0, bias=True, relu=True): 482 | super().__init__() 483 | 484 | if relu: 485 | activation = nn.ReLU() 486 | else: 487 | activation = nn.PReLU() 488 | 489 | self.conv = nn.Conv2d(in_channels, out_channels-in_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias) 490 | 491 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 492 | self.out_activation = activation 493 | 494 | def forward(self, x): 495 | main = self.conv(x) 496 | ext = self.maxpool(x) 497 | 498 | out = torch.cat((main, ext), 1) 499 | out = self.out_activation(out) 500 | return out 501 | 502 | class APN(nn.Module): 503 | ''' 504 | This is the Attention Pyramid Network including the whole upsampling block 505 | ''' 506 | 507 | def __init__(self, in_channels, out_channels, image_dim, bias=True, relu=True): 508 | ''' 509 | image_dim = dimension of the incoming image 510 | ''' 511 | super().__init__() 512 | 513 | if relu: 514 | activation = nn.ReLU() 515 | else: 516 | activation = nn.PReLU() 517 | 518 | self.conv_33 = nn.Sequential( 519 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, bias=bias), 520 | nn.BatchNorm2d(in_channels), 521 | activation 522 | ) 523 | 524 | self.conv_55 = nn.Sequential( 525 | nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=2, padding=2, bias=bias), 526 | nn.BatchNorm2d(in_channels), 527 | activation 528 | ) 529 | 530 | self.conv_77 = nn.Sequential( 531 | nn.Conv2d(in_channels, in_channels, kernel_size=7, stride=2, padding=3, bias=bias), 532 | nn.BatchNorm2d(in_channels), 533 | activation 534 | ) 535 | 536 | self.conv_77_toC = nn.Sequential( 537 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), 538 | nn.BatchNorm2d(out_channels), 539 | activation 540 | ) 541 | 542 | self.conv_55_toC = nn.Sequential( 543 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), 544 | nn.BatchNorm2d(out_channels), 545 | activation 546 | ) 547 | 548 | self.conv_33_toC = nn.Sequential( 549 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), 550 | nn.BatchNorm2d(out_channels), 551 | activation 552 | ) 553 | 554 | self.conv_11 = nn.Sequential( 555 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), 556 | nn.BatchNorm2d(out_channels), 557 | activation 558 | ) 559 | 560 | self.global_pool = nn.AvgPool2d(kernel_size = (image_dim, image_dim), stride=0, padding=0) 561 | self.pool_conv_toC = nn.Sequential( 562 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), 563 | nn.BatchNorm2d(out_channels), 564 | activation 565 | ) 566 | 567 | self.upsample_times2 = nn.UpsamplingBilinear2d(scale_factor=2) 568 | self.upsample_globalPool = nn.UpsamplingBilinear2d(size = (image_dim, image_dim)) 569 | 570 | def forward(self, x): 571 | output_conv33 = self.conv_33(x) 572 | 573 | output_conv55 = self.conv_55(output_conv33) 574 | 575 | output_conv77 = self.conv_77(output_conv55) 576 | output_conv77_toC = self.conv_77_toC(output_conv77) 577 | output_conv77_upsample = self.upsample_times2(output_conv77_toC) 578 | 579 | output_conv55_toC = self.conv_55_toC(output_conv55) 580 | output_conv55_toC = output_conv55_toC + output_conv77_upsample 581 | output_conv55_upsample = self.upsample_times2(output_conv55_toC) 582 | 583 | output_conv33_toC = self.conv_33_toC(output_conv33) 584 | output_conv33_toC = output_conv33_toC + output_conv55_upsample 585 | output_conv33_upsample = self.upsample_times2(output_conv33_toC) 586 | 587 | output_conv11 = self.conv_11(x) 588 | 589 | output_global_pool = self.global_pool(x) 590 | output_pool_conv_toC = self.pool_conv_toC(output_global_pool) 591 | output_pool_upsample = self.upsample_globalPool(output_pool_conv_toC) 592 | 593 | output_conv11 = output_conv11 * output_conv33_upsample 594 | output_conv11 = output_conv11 + output_pool_upsample 595 | 596 | final_output = F.upsample_bilinear(output_conv11, scale_factor=8) 597 | 598 | return final_output 599 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import itertools 3 | import numpy as np 4 | 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Variable 8 | import torchvision 9 | import torchvision.datasets as dsets 10 | from torch.utils.data import DataLoader 11 | import torchvision.transforms as transforms 12 | import utils 13 | # from utils import perceptual_loss 14 | from arch import define_Gen, define_Dis, set_grad 15 | from data_utils import VOCDataset, CityscapesDataset, ACDCDataset, get_transformation 16 | from utils import make_one_hot 17 | from tensorboardX import SummaryWriter 18 | 19 | ''' 20 | Class for CycleGAN with train() as a member function 21 | 22 | ''' 23 | root = './data/VOC2012' 24 | root_cityscapes = "./data/Cityscape" 25 | root_acdc = './data/ACDC' 26 | 27 | ### The location for tensorboard visualizations 28 | tensorboard_loc = './tensorboard_results/first_run' 29 | 30 | ### The location from where we can get the pretrained model 31 | pretrained_loc = 'resnet101COCO-41f33a49.pth' 32 | 33 | class supervised_model(object): 34 | def __init__(self, args): 35 | 36 | if args.dataset == 'voc2012': 37 | self.n_channels = 21 38 | elif args.dataset == 'cityscapes': 39 | self.n_channels = 20 40 | elif args.dataset == 'acdc': 41 | self.n_channels = 4 42 | 43 | # Define the network 44 | self.Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf, netG='deeplab', norm=args.norm, 45 | use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids) # for image to segmentation 46 | 47 | ### Now we put in the pretrained weights in Gsi 48 | ### These will only be used in the case of VOC and cityscapes 49 | if args.dataset != 'acdc': 50 | saved_state_dict = torch.load(pretrained_loc) 51 | new_params = self.Gsi.state_dict().copy() 52 | for name, param in new_params.items(): 53 | # print(name) 54 | if name in saved_state_dict and param.size() == saved_state_dict[name].size(): 55 | new_params[name].copy_(saved_state_dict[name]) 56 | # print('copy {}'.format(name)) 57 | # self.Gsi.load_state_dict(new_params) 58 | 59 | utils.print_networks([self.Gsi], ['Gsi']) 60 | 61 | ###Defining an interpolation function so as to match the output of network to feature map size 62 | self.interp = nn.Upsample(size = (args.crop_height, args.crop_width), mode='bilinear', align_corners=True) 63 | self.interp_val = nn.Upsample(size = (512, 512), mode='bilinear', align_corners=True) 64 | 65 | self.CE = nn.CrossEntropyLoss() 66 | self.activation_softmax = nn.Softmax2d() 67 | self.gsi_optimizer = torch.optim.Adam(self.Gsi.parameters(), lr=args.lr, betas=(0.9, 0.999)) 68 | 69 | ### writer for tensorboard 70 | self.writer_supervised = SummaryWriter(tensorboard_loc + '_supervised') 71 | self.running_metrics_val = utils.runningScore(self.n_channels, args.dataset) 72 | 73 | self.args = args 74 | 75 | if not os.path.isdir(args.checkpoint_dir): 76 | os.makedirs(args.checkpoint_dir) 77 | 78 | try: 79 | ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir)) 80 | self.start_epoch = ckpt['epoch'] 81 | self.Gsi.load_state_dict(ckpt['Gsi']) 82 | self.gsi_optimizer.load_state_dict(ckpt['gsi_optimizer']) 83 | self.best_iou = ckpt['best_iou'] 84 | except: 85 | print(' [*] No checkpoint!') 86 | self.start_epoch = 0 87 | self.best_iou = -100 88 | 89 | def train(self, args): 90 | 91 | transform = get_transformation((self.args.crop_height, self.args.crop_width), resize=True, dataset=args.dataset) 92 | val_transform = get_transformation((512, 512), resize=True, dataset=args.dataset) 93 | 94 | # let the choice of dataset configurable 95 | if self.args.dataset == 'voc2012': 96 | labeled_set = VOCDataset(root_path=root, name='label', ratio=1.0, transformation=transform, 97 | augmentation=None) 98 | val_set = VOCDataset(root_path=root, name='val', ratio=0.5, transformation=val_transform, 99 | augmentation=None) 100 | labeled_loader = DataLoader(labeled_set, batch_size=self.args.batch_size, shuffle=True, drop_last=True) 101 | val_loader = DataLoader(val_set, batch_size=self.args.batch_size, shuffle=True) 102 | elif self.args.dataset == 'cityscapes': 103 | labeled_set = CityscapesDataset(root_path=root_cityscapes, name='label', ratio=0.5, transformation=transform, 104 | augmentation=None) 105 | val_set = CityscapesDataset(root_path=root_cityscapes, name='val', ratio=0.5, transformation=transform, 106 | augmentation=None) 107 | labeled_loader = DataLoader(labeled_set, batch_size=self.args.batch_size, shuffle=True, drop_last=True) 108 | val_loader = DataLoader(val_set, batch_size=self.args.batch_size, shuffle=True, drop_last=True) 109 | elif self.args.dataset == 'acdc': 110 | labeled_set = ACDCDataset(root_path=root_acdc, name='label', ratio=0.5, transformation=transform, 111 | augmentation=None) 112 | val_set = ACDCDataset(root_path=root_acdc, name='val', ratio=0.5, transformation=transform, 113 | augmentation=None) 114 | labeled_loader = DataLoader(labeled_set, batch_size=self.args.batch_size, shuffle=True, drop_last=True) 115 | val_loader = DataLoader(val_set, batch_size=self.args.batch_size, shuffle=True, drop_last=True) 116 | 117 | img_fake_sample = utils.Sample_from_Pool() 118 | gt_fake_sample = utils.Sample_from_Pool() 119 | 120 | for epoch in range(self.start_epoch, self.args.epochs): 121 | self.Gsi.train() 122 | for i, (l_img, l_gt, img_name) in enumerate(labeled_loader): 123 | # step 124 | step = epoch * len(labeled_loader) + i + 1 125 | 126 | self.gsi_optimizer.zero_grad() 127 | 128 | l_img, l_gt = utils.cuda([l_img, l_gt], args.gpu_ids) 129 | 130 | lab_gt = self.Gsi(l_img) 131 | 132 | lab_gt = self.interp(lab_gt) ### To get the output of model same as labels 133 | 134 | # CE losses 135 | fullsupervisedloss = self.CE(lab_gt, l_gt.squeeze(1)) 136 | 137 | fullsupervisedloss.backward() 138 | self.gsi_optimizer.step() 139 | 140 | print("Epoch: (%3d) (%5d/%5d) | Crossentropy Loss:%.2e" % 141 | (epoch, i + 1, len(labeled_loader), fullsupervisedloss.item())) 142 | 143 | self.writer_supervised.add_scalars('Supervised Loss', {'CE Loss ':fullsupervisedloss}, len(labeled_loader)*epoch + i) 144 | 145 | ### For getting the IoU for the image 146 | self.Gsi.eval() 147 | with torch.no_grad(): 148 | for i, (val_img, val_gt, _) in enumerate(val_loader): 149 | val_img, val_gt = utils.cuda([val_img, val_gt], args.gpu_ids) 150 | 151 | outputs = self.Gsi(val_img) 152 | outputs = self.interp_val(outputs) 153 | outputs = self.activation_softmax(outputs) 154 | 155 | pred = outputs.data.max(1)[1].cpu().numpy() 156 | gt = val_gt.squeeze().data.cpu().numpy() 157 | 158 | self.running_metrics_val.update(gt, pred) 159 | 160 | score, class_iou = self.running_metrics_val.get_scores() 161 | 162 | self.running_metrics_val.reset() 163 | 164 | ### For displaying the images generated by generator on tensorboard 165 | val_img, val_gt, _ = iter(val_loader).next() 166 | val_img, val_gt = utils.cuda([val_img, val_gt], args.gpu_ids) 167 | with torch.no_grad(): 168 | fake = self.Gsi(val_img).detach() 169 | fake = self.interp_val(fake) 170 | fake = self.activation_softmax(fake) 171 | fake_prediction = fake.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy() 172 | val_gt = val_gt.cpu() 173 | 174 | ### display_tensor is the final tensor that will be displayed on tensorboard 175 | display_tensor = torch.zeros([fake.shape[0], 3, fake.shape[2], fake.shape[3]]) 176 | display_tensor_gt = torch.zeros([val_gt.shape[0], 3, val_gt.shape[2], val_gt.shape[3]]) 177 | for i in range(fake_prediction.shape[0]): 178 | new_img = fake_prediction[i] 179 | new_img = utils.colorize_mask(new_img, self.args.dataset) ### So this is the generated image in PIL.Image format 180 | img_tensor = utils.PIL_to_tensor(new_img, self.args.dataset) 181 | display_tensor[i, :, :, :] = img_tensor 182 | 183 | display_tensor_gt[i, :, :, :] = val_gt[i] 184 | 185 | self.writer_supervised.add_image('Generated segmented image', torchvision.utils.make_grid(display_tensor, nrow=2, normalize=True), epoch) 186 | self.writer_supervised.add_image('Ground truth for the image', torchvision.utils.make_grid(display_tensor_gt, nrow=2, normalize=True), epoch) 187 | 188 | 189 | if score["Mean IoU : \t"] >= self.best_iou: 190 | self.best_iou = score["Mean IoU : \t"] 191 | # Override the latest checkpoint 192 | utils.save_checkpoint({'epoch': epoch + 1, 193 | 'Gsi': self.Gsi.state_dict(), 194 | 'gsi_optimizer': self.gsi_optimizer.state_dict(), 195 | 'best_iou': self.best_iou, 196 | 'class_iou': class_iou}, 197 | '%s/latest_supervised_model.ckpt' % (self.args.checkpoint_dir)) 198 | 199 | self.writer_supervised.close() 200 | 201 | 202 | class semisuper_cycleGAN(object): 203 | def __init__(self, args): 204 | 205 | if args.dataset == 'voc2012': 206 | self.n_channels = 21 207 | elif args.dataset == 'cityscapes': 208 | self.n_channels = 20 209 | elif args.dataset == 'acdc': 210 | self.n_channels = 4 211 | 212 | # Define the network 213 | ##################################################### 214 | # for segmentaion to image 215 | self.Gis = define_Gen(input_nc=self.n_channels, output_nc=3, ngf=args.ngf, netG='deeplab', 216 | norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids) 217 | # for image to segmentation 218 | self.Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf, netG='deeplab', 219 | norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids) 220 | self.Di = define_Dis(input_nc=3, ndf=args.ndf, netD='pixel', n_layers_D=3, 221 | norm=args.norm, gpu_ids=args.gpu_ids) 222 | self.Ds = define_Dis(input_nc=self.n_channels, ndf=args.ndf, netD='pixel', n_layers_D=3, 223 | norm=args.norm, gpu_ids=args.gpu_ids) # for voc 2012, there are 21 classes 224 | 225 | self.old_Gis = define_Gen(input_nc=self.n_channels, output_nc=3, ngf=args.ngf, netG='resnet_9blocks', 226 | norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids) 227 | self.old_Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf, netG='resnet_9blocks_softmax', 228 | norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids) 229 | self.old_Di = define_Dis(input_nc=3, ndf=args.ndf, netD='pixel', n_layers_D=3, 230 | norm=args.norm, gpu_ids=args.gpu_ids) 231 | 232 | ### To put the pretrained weights in Gis and Gsi 233 | # if args.dataset != 'acdc': 234 | # saved_state_dict = torch.load(pretrained_loc) 235 | # new_params_Gsi = self.Gsi.state_dict().copy() 236 | # # new_params_Gis = self.Gis.state_dict().copy() 237 | # for name, param in new_params_Gsi.items(): 238 | # # print(name) 239 | # if name in saved_state_dict and param.size() == saved_state_dict[name].size(): 240 | # new_params_Gsi[name].copy_(saved_state_dict[name]) 241 | # # print('copy {}'.format(name)) 242 | # self.Gsi.load_state_dict(new_params_Gsi) 243 | # for name, param in new_params_Gis.items(): 244 | # # print(name) 245 | # if name in saved_state_dict and param.size() == saved_state_dict[name].size(): 246 | # new_params_Gis[name].copy_(saved_state_dict[name]) 247 | # # print('copy {}'.format(name)) 248 | # # self.Gis.load_state_dict(new_params_Gis) 249 | 250 | ### This is just so as to get pretrained methods for the case of Gis 251 | if args.dataset == 'voc2012': 252 | try: 253 | ckpt_for_Arnab_loss = utils.load_checkpoint('./ckpt_for_Arnab_loss.ckpt') 254 | self.old_Gis.load_state_dict(ckpt_for_Arnab_loss['Gis']) 255 | self.old_Gsi.load_state_dict(ckpt_for_Arnab_loss['Gsi']) 256 | except: 257 | print('**There is an error in loading the ckpt_for_Arnab_loss**') 258 | 259 | 260 | utils.print_networks([self.Gsi], ['Gsi']) 261 | 262 | 263 | utils.print_networks([self.Gis,self.Gsi,self.Di,self.Ds], ['Gis','Gsi','Di','Ds']) 264 | 265 | self.args = args 266 | 267 | ### interpolation 268 | self.interp = nn.Upsample((args.crop_height, args.crop_width), mode='bilinear', align_corners=True) 269 | 270 | self.MSE = nn.MSELoss() 271 | self.L1 = nn.L1Loss() 272 | self.CE = nn.CrossEntropyLoss() 273 | self.activation_softmax = nn.Softmax2d() 274 | self.activation_tanh = nn.Tanh() 275 | self.activation_sigmoid = nn.Sigmoid() 276 | 277 | ### Tensorboard writer 278 | self.writer_semisuper = SummaryWriter(tensorboard_loc + '_semisuper') 279 | self.running_metrics_val = utils.runningScore(self.n_channels, args.dataset) 280 | 281 | ### For adding gaussian noise 282 | self.gauss_noise = utils.GaussianNoise(sigma = 0.2) 283 | 284 | # Optimizers 285 | ##################################################### 286 | self.g_optimizer = torch.optim.Adam(itertools.chain(self.Gis.parameters(),self.Gsi.parameters()), lr=args.lr, betas=(0.5, 0.999)) 287 | self.d_optimizer = torch.optim.Adam(itertools.chain(self.Di.parameters(),self.Ds.parameters()), lr=args.lr, betas=(0.5, 0.999)) 288 | 289 | 290 | self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.g_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step) 291 | self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.d_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step) 292 | 293 | # Try loading checkpoint 294 | ##################################################### 295 | if not os.path.isdir(args.checkpoint_dir): 296 | os.makedirs(args.checkpoint_dir) 297 | 298 | try: 299 | ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir)) 300 | self.start_epoch = ckpt['epoch'] 301 | self.Di.load_state_dict(ckpt['Di']) 302 | self.Ds.load_state_dict(ckpt['Ds']) 303 | self.Gis.load_state_dict(ckpt['Gis']) 304 | self.Gsi.load_state_dict(ckpt['Gsi']) 305 | self.d_optimizer.load_state_dict(ckpt['d_optimizer']) 306 | self.g_optimizer.load_state_dict(ckpt['g_optimizer']) 307 | self.best_iou = ckpt['best_iou'] 308 | except: 309 | print(' [*] No checkpoint!') 310 | self.start_epoch = 0 311 | self.best_iou = -100 312 | 313 | 314 | 315 | def train(self, args): 316 | transform = get_transformation((args.crop_height, args.crop_width), resize=True, dataset=args.dataset) 317 | 318 | # let the choice of dataset configurable 319 | if self.args.dataset == 'voc2012': 320 | labeled_set = VOCDataset(root_path=root, name='label', ratio=0.1, transformation=transform, 321 | augmentation=None) 322 | unlabeled_set = VOCDataset(root_path=root, name='unlabel', ratio=0.1, transformation=transform, 323 | augmentation=None) 324 | val_set = VOCDataset(root_path=root, name='val', ratio=0.5, transformation=transform, 325 | augmentation=None) 326 | elif self.args.dataset == 'cityscapes': 327 | labeled_set = CityscapesDataset(root_path=root_cityscapes, name='label', ratio=0.5, transformation=transform, 328 | augmentation=None) 329 | unlabeled_set = CityscapesDataset(root_path=root_cityscapes, name='unlabel', ratio=0.5, transformation=transform, 330 | augmentation=None) 331 | val_set = CityscapesDataset(root_path=root_cityscapes, name='val', ratio=0.5, transformation=transform, 332 | augmentation=None) 333 | elif self.args.dataset == 'acdc': 334 | labeled_set = ACDCDataset(root_path=root_acdc, name='label', ratio=0.5, transformation=transform, 335 | augmentation=None) 336 | unlabeled_set = ACDCDataset(root_path=root_acdc, name='unlabel', ratio=0.5, transformation=transform, 337 | augmentation=None) 338 | val_set = ACDCDataset(root_path=root_acdc, name='val', ratio=0.5, transformation=transform, 339 | augmentation=None) 340 | 341 | ''' 342 | https://discuss.pytorch.org/t/about-the-relation-between-batch-size-and-length-of-data-loader/10510 343 | ^^ The reason for using drop_last=True so as to obtain an even size of all the batches and 344 | deleting the last batch with less images 345 | ''' 346 | labeled_loader = DataLoader(labeled_set, batch_size=args.batch_size, shuffle=True, drop_last=True) 347 | unlabeled_loader = DataLoader(unlabeled_set, batch_size=args.batch_size, shuffle=True, drop_last=True) 348 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, drop_last=True) 349 | 350 | new_img_fake_sample = utils.Sample_from_Pool() 351 | img_fake_sample = utils.Sample_from_Pool() 352 | gt_fake_sample = utils.Sample_from_Pool() 353 | 354 | img_dis_loss, gt_dis_loss, unsupervisedloss, fullsupervisedloss = 0, 0, 0, 0 355 | 356 | ### Variable to regulate the frequency of update between Discriminators and Generators 357 | counter = 0 358 | 359 | for epoch in range(self.start_epoch, args.epochs): 360 | lr = self.g_optimizer.param_groups[0]['lr'] 361 | print('learning rate = %.7f' % lr) 362 | 363 | self.Gsi.train() 364 | self.Gis.train() 365 | 366 | # if (epoch+1)%10 == 0: 367 | # args.lamda_img = args.lamda_img + 0.08 368 | # args.lamda_gt = args.lamda_gt + 0.04 369 | 370 | for i, ((l_img, l_gt, _), (unl_img, _, _)) in enumerate(zip(labeled_loader, unlabeled_loader)): 371 | # step 372 | step = epoch * min(len(labeled_loader), len(unlabeled_loader)) + i + 1 373 | 374 | l_img, unl_img, l_gt = utils.cuda([l_img, unl_img, l_gt], args.gpu_ids) 375 | 376 | # Generator Computations 377 | ################################################## 378 | 379 | set_grad([self.Di, self.Ds, self.old_Di], False) 380 | set_grad([self.old_Gsi, self.old_Gis], False) 381 | self.g_optimizer.zero_grad() 382 | 383 | # Forward pass through generators 384 | ################################################## 385 | fake_img = self.Gis(make_one_hot(l_gt, args.dataset, args.gpu_ids).float()) 386 | fake_gt = self.Gsi(unl_img.float()) ### having 21 channels 387 | lab_gt = self.Gsi(l_img) ### having 21 channels 388 | 389 | ### Getting the outputs of the model to correct dimensions 390 | fake_img = self.interp(fake_img) 391 | fake_gt = self.interp(fake_gt) 392 | lab_gt = self.interp(lab_gt) 393 | 394 | # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0) ### will get into no channels 395 | # fake_gt = fake_gt.unsqueeze(1) ### will get into 1 channel only 396 | # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids) 397 | 398 | lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1)) 399 | 400 | ### Again applying activations 401 | lab_gt = self.activation_softmax(lab_gt) 402 | fake_gt = self.activation_softmax(fake_gt) 403 | # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0) 404 | # fake_gt = fake_gt.unsqueeze(1) 405 | # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids) 406 | # fake_img = self.activation_tanh(fake_img) 407 | 408 | recon_img = self.Gis(fake_gt.float()) 409 | recon_lab_img = self.Gis(lab_gt.float()) 410 | recon_gt = self.Gsi(fake_img.float()) 411 | 412 | ### Getting the outputs of the model to correct dimensions 413 | recon_img = self.interp(recon_img) 414 | recon_lab_img = self.interp(recon_lab_img) 415 | recon_gt = self.interp(recon_gt) 416 | 417 | ### This is for the case of the new loss between the recon_img from resnet and deeplab network 418 | resnet_fake_gt = self.old_Gsi(unl_img.float()) 419 | resnet_lab_gt = self.old_Gsi(l_img) 420 | resnet_lab_gt = self.activation_softmax(resnet_lab_gt) 421 | resnet_fake_gt = self.activation_softmax(resnet_fake_gt) 422 | resnet_recon_img = self.old_Gis(resnet_fake_gt.float()) 423 | resnet_recon_lab_img = self.old_Gis(resnet_lab_gt.float()) 424 | 425 | ## Applying the tanh activations 426 | # recon_img = self.activation_tanh(recon_img) 427 | # recon_lab_img = self.activation_tanh(recon_lab_img) 428 | 429 | # Adversarial losses 430 | ################################################### 431 | fake_img_dis = self.Di(fake_img) 432 | resnet_fake_img_dis = self.old_Di(recon_img) 433 | 434 | ### For passing different type of input to Ds 435 | fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0) 436 | fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1) 437 | fake_gt_discriminator = make_one_hot(fake_gt_discriminator, args.dataset, args.gpu_ids) 438 | fake_gt_dis = self.Ds(fake_gt_discriminator.float()) 439 | # lab_gt_dis = self.Ds(lab_gt) 440 | 441 | real_label_gt = utils.cuda(Variable(torch.ones(fake_gt_dis.size())), args.gpu_ids) 442 | real_label_img = utils.cuda(Variable(torch.ones(fake_img_dis.size())), args.gpu_ids) 443 | 444 | # here is much better to have a cross entropy loss for classification. 445 | img_gen_loss = self.MSE(fake_img_dis, real_label_img) 446 | gt_gen_loss = self.MSE(fake_gt_dis, real_label_gt) 447 | # gt_label_gen_loss = self.MSE(lab_gt_dis, real_label) 448 | 449 | 450 | # Cycle consistency losses 451 | ################################################### 452 | resnet_img_cycle_loss = self.MSE(resnet_fake_img_dis, real_label_img) 453 | # img_cycle_loss = self.L1(recon_img, unl_img) 454 | # img_cycle_loss_perceptual = perceptual_loss(recon_img, unl_img, args.gpu_ids) 455 | gt_cycle_loss = self.CE(recon_gt, l_gt.squeeze(1)) 456 | # lab_img_cycle_loss = self.L1(recon_lab_img, l_img) * args.lamda 457 | 458 | # Total generators losses 459 | ################################################### 460 | # lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1)) 461 | lab_loss_MSE = self.L1(fake_img, l_img) 462 | # lab_loss_perceptual = perceptual_loss(fake_img, l_img, args.gpu_ids) 463 | 464 | fullsupervisedloss = args.lab_CE_weight*lab_loss_CE + args.lab_MSE_weight*lab_loss_MSE 465 | 466 | unsupervisedloss = args.adversarial_weight*(img_gen_loss + gt_gen_loss) + resnet_img_cycle_loss + gt_cycle_loss*args.lamda_gt 467 | 468 | gen_loss = fullsupervisedloss + unsupervisedloss 469 | 470 | # Update generators 471 | ################################################### 472 | gen_loss.backward() 473 | 474 | self.g_optimizer.step() 475 | 476 | 477 | if counter % 1 == 0: 478 | # Discriminator Computations 479 | ################################################# 480 | 481 | set_grad([self.Di, self.Ds, self.old_Di], True) 482 | self.d_optimizer.zero_grad() 483 | 484 | # Sample from history of generated images 485 | ################################################# 486 | if torch.rand(1) < 0.0: 487 | fake_img = self.gauss_noise(fake_img.cpu()) 488 | fake_gt = self.gauss_noise(fake_gt.cpu()) 489 | 490 | recon_img = Variable(torch.Tensor(new_img_fake_sample([recon_img.cpu().data.numpy()])[0])) 491 | fake_img = Variable(torch.Tensor(img_fake_sample([fake_img.cpu().data.numpy()])[0])) 492 | # lab_gt = Variable(torch.Tensor(gt_fake_sample([lab_gt.cpu().data.numpy()])[0])) 493 | fake_gt = Variable(torch.Tensor(gt_fake_sample([fake_gt.cpu().data.numpy()])[0])) 494 | 495 | recon_img, fake_img, fake_gt = utils.cuda([recon_img, fake_img, fake_gt], args.gpu_ids) 496 | 497 | # Forward pass through discriminators 498 | ################################################# 499 | unl_img_dis = self.Di(unl_img) 500 | fake_img_dis = self.Di(fake_img) 501 | resnet_recon_img_dis = self.old_Di(resnet_recon_img) 502 | resnet_fake_img_dis = self.old_Di(recon_img) 503 | 504 | # lab_gt_dis = self.Ds(lab_gt) 505 | 506 | l_gt = make_one_hot(l_gt, args.dataset, args.gpu_ids) 507 | real_gt_dis = self.Ds(l_gt.float()) 508 | 509 | fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0) 510 | fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1) 511 | fake_gt_discriminator = make_one_hot(fake_gt_discriminator, args.dataset, args.gpu_ids) 512 | fake_gt_dis = self.Ds(fake_gt_discriminator.float()) 513 | 514 | real_label_img = utils.cuda(Variable(torch.ones(unl_img_dis.size())), args.gpu_ids) 515 | fake_label_img = utils.cuda(Variable(torch.zeros(fake_img_dis.size())), args.gpu_ids) 516 | real_label_gt = utils.cuda(Variable(torch.ones(real_gt_dis.size())), args.gpu_ids) 517 | fake_label_gt = utils.cuda(Variable(torch.zeros(fake_gt_dis.size())), args.gpu_ids) 518 | 519 | # Discriminator losses 520 | ################################################## 521 | img_dis_real_loss = self.MSE(unl_img_dis, real_label_img) 522 | img_dis_fake_loss = self.MSE(fake_img_dis, fake_label_img) 523 | gt_dis_real_loss = self.MSE(real_gt_dis, real_label_gt) 524 | gt_dis_fake_loss = self.MSE(fake_gt_dis, fake_label_gt) 525 | # lab_gt_dis_fake_loss = self.MSE(lab_gt_dis, fake_label) 526 | 527 | cycle_img_dis_real_loss = self.MSE(resnet_recon_img_dis, real_label_img) 528 | cycle_img_dis_fake_loss = self.MSE(resnet_fake_img_dis, fake_label_img) 529 | 530 | # Total discriminators losses 531 | img_dis_loss = (img_dis_real_loss + img_dis_fake_loss)*0.5 532 | gt_dis_loss = (gt_dis_real_loss + gt_dis_fake_loss)*0.5 533 | # lab_gt_dis_loss = (gt_dis_real_loss + lab_gt_dis_fake_loss)*0.33 534 | cycle_img_dis_loss = cycle_img_dis_real_loss + cycle_img_dis_fake_loss 535 | 536 | # Update discriminators 537 | ################################################## 538 | discriminator_loss = args.discriminator_weight * (img_dis_loss + gt_dis_loss) + cycle_img_dis_loss 539 | discriminator_loss.backward() 540 | 541 | # lab_gt_dis_loss.backward() 542 | self.d_optimizer.step() 543 | 544 | print("Epoch: (%3d) (%5d/%5d) | Dis Loss:%.2e | Unlab Gen Loss:%.2e | Lab Gen loss:%.2e" % 545 | (epoch, i + 1, min(len(labeled_loader), len(unlabeled_loader)), 546 | img_dis_loss + gt_dis_loss, unsupervisedloss, fullsupervisedloss)) 547 | 548 | self.writer_semisuper.add_scalars('Dis Loss', {'img_dis_loss':img_dis_loss, 'gt_dis_loss':gt_dis_loss, 'cycle_img_dis_loss':cycle_img_dis_loss}, len(labeled_loader)*epoch + i) 549 | self.writer_semisuper.add_scalars('Unlabelled Loss', {'img_gen_loss': img_gen_loss, 'gt_gen_loss':gt_gen_loss, 'img_cycle_loss':resnet_img_cycle_loss, 'gt_cycle_loss':gt_cycle_loss}, len(labeled_loader)*epoch + i) 550 | self.writer_semisuper.add_scalars('Labelled Loss', {'lab_loss_CE':lab_loss_CE, 'lab_loss_MSE':lab_loss_MSE}, len(labeled_loader)*epoch + i) 551 | 552 | counter += 1 553 | 554 | ### For getting the mean IoU 555 | self.Gsi.eval() 556 | self.Gis.eval() 557 | with torch.no_grad(): 558 | for i, (val_img, val_gt, _) in enumerate(val_loader): 559 | val_img, val_gt = utils.cuda([val_img, val_gt], args.gpu_ids) 560 | 561 | outputs = self.Gsi(val_img) 562 | outputs = self.interp(outputs) 563 | outputs = self.activation_softmax(outputs) 564 | 565 | pred = outputs.data.max(1)[1].cpu().numpy() 566 | gt = val_gt.squeeze().data.cpu().numpy() 567 | 568 | self.running_metrics_val.update(gt, pred) 569 | 570 | score, class_iou = self.running_metrics_val.get_scores() 571 | 572 | self.running_metrics_val.reset() 573 | 574 | print('The mIoU for the epoch is: ', score["Mean IoU : \t"]) 575 | 576 | ### For displaying the images generated by generator on tensorboard using validation images 577 | val_image, val_gt, _ = iter(val_loader).next() 578 | val_image, val_gt = utils.cuda([val_image, val_gt], args.gpu_ids) 579 | with torch.no_grad(): 580 | fake_label = self.Gsi(val_image).detach() 581 | fake_label = self.interp(fake_label) 582 | fake_label = self.activation_softmax(fake_label) 583 | fake_label = fake_label.data.max(1)[1].squeeze_(1).squeeze_(0) 584 | fake_label = fake_label.unsqueeze(1) 585 | fake_label = make_one_hot(fake_label, args.dataset, args.gpu_ids) 586 | fake_img = self.Gis(fake_label).detach() 587 | fake_img = self.interp(fake_img) 588 | # fake_img = self.activation_tanh(fake_img) 589 | 590 | fake_img_from_labels = self.Gis(make_one_hot(val_gt, args.dataset, args.gpu_ids).float()).detach() 591 | fake_img_from_labels = self.interp(fake_img_from_labels) 592 | # fake_img_from_labels = self.activation_tanh(fake_img_from_labels) 593 | fake_label_regenerated = self.Gsi(fake_img_from_labels).detach() 594 | fake_label_regenerated = self.interp(fake_label_regenerated) 595 | fake_label_regenerated = self.activation_softmax(fake_label_regenerated) 596 | fake_prediction_label = fake_label.data.max(1)[1].squeeze_(1).cpu().numpy() 597 | fake_regenerated_label = fake_label_regenerated.data.max(1)[1].squeeze_(1).cpu().numpy() 598 | val_gt = val_gt.cpu() 599 | 600 | fake_img = fake_img.cpu() 601 | fake_img_from_labels = fake_img_from_labels.cpu() 602 | ### Now i am going to revert back the transformation on these images 603 | if self.args.dataset == 'voc2012' or self.args.dataset == 'cityscapes': 604 | trans_mean = [0.5, 0.5, 0.5] 605 | trans_std = [0.5, 0.5, 0.5] 606 | for i in range(3): 607 | fake_img[:, i, :, :] = ((fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i]) 608 | fake_img_from_labels[:, i, :, :] = ((fake_img_from_labels[:, i, :, :] * trans_std[i]) + trans_mean[i]) 609 | 610 | elif self.args.dataset == 'acdc': 611 | trans_mean = [0.5] 612 | trans_std = [0.5] 613 | for i in range(1): 614 | fake_img[:, i, :, :] = ((fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i]) 615 | fake_img_from_labels[:, i, :, :] = ((fake_img_from_labels[:, i, :, :] * trans_std[i]) + trans_mean[i]) 616 | 617 | ### display_tensor is the final tensor that will be displayed on tensorboard 618 | display_tensor_label = torch.zeros([fake_label.shape[0], 3, fake_label.shape[2], fake_label.shape[3]]) 619 | display_tensor_gt = torch.zeros([val_gt.shape[0], 3, val_gt.shape[2], val_gt.shape[3]]) 620 | display_tensor_regen_label = torch.zeros([fake_label_regenerated.shape[0], 3, fake_label_regenerated.shape[2], fake_label_regenerated.shape[3]]) 621 | for i in range(fake_prediction_label.shape[0]): 622 | new_img_label = fake_prediction_label[i] 623 | new_img_label = utils.colorize_mask(new_img_label, self.args.dataset) ### So this is the generated image in PIL.Image format 624 | img_tensor_label = utils.PIL_to_tensor(new_img_label, self.args.dataset) 625 | display_tensor_label[i, :, :, :] = img_tensor_label 626 | 627 | display_tensor_gt[i, :, :, :] = val_gt[i] 628 | 629 | regen_label = fake_regenerated_label[i] 630 | regen_label = utils.colorize_mask(regen_label, self.args.dataset) 631 | regen_tensor_label = utils.PIL_to_tensor(regen_label, self.args.dataset) 632 | display_tensor_regen_label[i, :, :, :] = regen_tensor_label 633 | 634 | self.writer_semisuper.add_image('Generated segmented image: ', torchvision.utils.make_grid(display_tensor_label, nrow=2, normalize=True), epoch) 635 | self.writer_semisuper.add_image('Generated image back from segmentation: ', torchvision.utils.make_grid(fake_img, nrow=2, normalize=True), epoch) 636 | self.writer_semisuper.add_image('Ground truth for the image: ', torchvision.utils.make_grid(display_tensor_gt, nrow=2, normalize=True), epoch) 637 | self.writer_semisuper.add_image('Image generated from val labels: ', torchvision.utils.make_grid(fake_img_from_labels, nrow=2, normalize=True), epoch) 638 | self.writer_semisuper.add_image('Labels generated back from the cycle: ', torchvision.utils.make_grid(display_tensor_regen_label, nrow=2, normalize=True), epoch) 639 | 640 | 641 | if score["Mean IoU : \t"] >= self.best_iou: 642 | self.best_iou = score["Mean IoU : \t"] 643 | 644 | # Override the latest checkpoint 645 | ####################################################### 646 | utils.save_checkpoint({'epoch': epoch + 1, 647 | 'Di': self.Di.state_dict(), 648 | 'Ds': self.Ds.state_dict(), 649 | 'Gis': self.Gis.state_dict(), 650 | 'Gsi': self.Gsi.state_dict(), 651 | 'd_optimizer': self.d_optimizer.state_dict(), 652 | 'g_optimizer': self.g_optimizer.state_dict(), 653 | 'best_iou': self.best_iou, 654 | 'class_iou': class_iou}, 655 | '%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir)) 656 | 657 | # Update learning rates 658 | ######################## 659 | self.g_lr_scheduler.step() 660 | self.d_lr_scheduler.step() 661 | 662 | self.writer_semisuper.close() 663 | --------------------------------------------------------------------------------