├── logs └── .gitignore ├── result ├── .gitignore ├── 91-3.png ├── shadow_removal.jpg └── detected_shadow.jpg ├── checkpoints └── .gitkeep ├── dataset ├── test │ ├── test_A │ │ └── .gitignore │ ├── test_B │ │ └── .gitignore │ └── test_C │ │ └── .gitignore └── train │ ├── train_A │ └── .gitignore │ ├── train_B │ └── .gitignore │ └── train_C │ └── .gitignore ├── README.md ├── utils ├── data_loader.py └── ISTD_transforms.py ├── models └── ST_CGAN.py ├── test.py └── train.py /logs/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /result/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/test/test_A/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/test/test_B/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/test/test_C/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/train/train_A/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/train/train_B/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/train/train_C/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /result/91-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IsHYuhi/ST-CGAN_Stacked_Conditional_Generative_Adversarial_Networks/HEAD/result/91-3.png -------------------------------------------------------------------------------- /result/shadow_removal.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IsHYuhi/ST-CGAN_Stacked_Conditional_Generative_Adversarial_Networks/HEAD/result/shadow_removal.jpg -------------------------------------------------------------------------------- /result/detected_shadow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IsHYuhi/ST-CGAN_Stacked_Conditional_Generative_Adversarial_Networks/HEAD/result/detected_shadow.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ST-CGAN: Stacked Conditional Generative Adversarial Networks for Jointly Learning Shadow Detection and Shadow Removal with PyTorch 2 | 3 | This repository is unofficial implementation of [Stacked Conditional Generative Adversarial Networks for Jointly Learning Shadow Detection and Shadow Removal](https://arxiv.org/abs/1712.02478) [Wang+, **CVPR** 2018] with PyTorch. 4 | 5 | **Official Dataset and Code(coming soon...) is [here](https://github.com/DeepInsight-PCALab/ST-CGAN).** 6 | 7 | ## Requirements 8 | * Python3.x 9 | * PyTorch 1.5.0 10 | * pillow 11 | * matplotlib 12 | 13 | ## Usage 14 | * Set datasets under ```./dataset```. You can Download datasets from [here](https://github.com/DeepInsight-PCALab/ST-CGAN). 15 | 16 | Then, 17 | ### Training 18 | ``` 19 | python3 train.py 20 | ``` 21 | ### Testing 22 | When Testing images from ISTD dataset. 23 | ``` 24 | python3 test.py -l 25 | ``` 26 | When you would like to test your own image. 27 | ``` 28 | python3 test.py -l -i -o 29 | ``` 30 | 31 | 32 | ## Results 33 | Here is a result from test sets. 34 | ![](https://github.com/IsHYuhi/ST-CGAN_Stacked_Conditional_Generative_Adversarial_Networks/blob/master/result/91-3.png) 35 | (Left to right: input, ground truth, shadow removal, ground truth shadow, shadow detection) 36 | 37 | ### Shadow Detection 38 | Here are some results from validation set. 39 | ![](https://github.com/IsHYuhi/ST-CGAN_Stacked_Conditional_Generative_Adversarial_Networks/blob/master/result/detected_shadow.jpg) 40 | (Top to bottom: ground truth, shadow detection) 41 | 42 | ### Shadow Removal 43 | Here are some results from validation set. 44 | ![](https://github.com/IsHYuhi/ST-CGAN_Stacked_Conditional_Generative_Adversarial_Networks/blob/master/result/shadow_removal.jpg) 45 | (Top to bottom: input, ground truth, shadow removal) 46 | 47 | ## Trained model 48 | You can download from [here](https://drive.google.com/drive/folders/1J1l21k5AoUXHxic-Bj3eXBFP--YzjFXO?usp=sharing). 49 | 50 | ## References 51 | * Stacked Conditional Generative Adversarial Networks for Jointly Learning Shadow Detection and Shadow Removal, Jifeng Wang, Xiang Li, Le Hui, Jian Yang, **Nanjing University of Science and Technology**, [[arXiv]](https://arxiv.org/abs/1712.02478) 52 | -------------------------------------------------------------------------------- /utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import torch.utils.data as data 5 | from . import ISTD_transforms 6 | from PIL import Image 7 | import random 8 | from torchvision import transforms 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def make_datapath_list(phase="train", rate=0.8): 13 | """ 14 | make filepath list for train, validation and test images 15 | """ 16 | random.seed(44) 17 | 18 | rootpath = './dataset/' + phase + '/' 19 | files_name = os.listdir(rootpath + phase + '_A') 20 | 21 | if phase=='train': 22 | random.shuffle(files_name) 23 | elif phase=='test': 24 | files_name.sort() 25 | 26 | path_A = [] 27 | path_B = [] 28 | path_C = [] 29 | 30 | for name in files_name: 31 | path_A.append(rootpath + phase + '_A/'+name) 32 | path_B.append(rootpath + phase + '_B/'+name) 33 | path_C.append(rootpath + phase + '_C/'+name) 34 | 35 | num = len(path_A) 36 | 37 | if phase=='train': 38 | path_A, path_A_val = path_A[:int(num*rate)], path_A[int(num*rate):] 39 | path_B, path_B_val = path_B[:int(num*rate)], path_B[int(num*rate):] 40 | path_C, path_C_val = path_C[:int(num*rate)], path_C[int(num*rate):] 41 | path_list = {'path_A': path_A, 'path_B': path_B, 'path_C': path_C} 42 | path_list_val = {'path_A': path_A_val, 'path_B': path_B_val, 'path_C': path_C_val} 43 | return path_list, path_list_val 44 | 45 | elif phase=='test': 46 | path_list = {'path_A': path_A, 'path_B': path_B, 'path_C': path_C} 47 | return path_list 48 | 49 | class ImageTransformOwn(): 50 | """ 51 | preprocessing images for own images 52 | """ 53 | def __init__(self, size=256, mean=(0.5, ), std=(0.5, )): 54 | self.data_transform = transforms.Compose([transforms.ToTensor(), 55 | transforms.Normalize(mean, std)]) 56 | 57 | def __call__(self, img): 58 | return self.data_transform(img) 59 | 60 | 61 | class ImageTransform(): 62 | """ 63 | preprocessing images 64 | """ 65 | def __init__(self, size=286, crop_size=256, mean=(0.5, ), std=(0.5, )): 66 | self.data_transform = {'train': ISTD_transforms.Compose([ISTD_transforms.Scale(size=size), 67 | ISTD_transforms.RandomCrop(size=crop_size), 68 | ISTD_transforms.RandomHorizontalFlip(p=0.5), 69 | ISTD_transforms.ToTensor(), 70 | ISTD_transforms.Normalize(mean, std)]), 71 | 72 | 'val': ISTD_transforms.Compose([ISTD_transforms.Scale(size=size), 73 | ISTD_transforms.RandomCrop(size=crop_size), 74 | ISTD_transforms.ToTensor(), 75 | ISTD_transforms.Normalize(mean, std)]), 76 | 77 | 'test': ISTD_transforms.Compose([ISTD_transforms.Scale(size=size), 78 | ISTD_transforms.RandomCrop(size=crop_size), 79 | ISTD_transforms.ToTensor(), 80 | ISTD_transforms.Normalize(mean, std)])} 81 | 82 | def __call__(self, phase, img): 83 | return self.data_transform[phase](img) 84 | 85 | 86 | class ImageDataset(data.Dataset): 87 | """ 88 | Dataset class. Inherit Dataset class from PyTrorch. 89 | """ 90 | def __init__(self, img_list, img_transform, phase): 91 | self.img_list = img_list 92 | self.img_transform = img_transform 93 | self.phase = phase 94 | 95 | def __len__(self): 96 | return len(self.img_list['path_A']) 97 | 98 | def __getitem__(self, index): 99 | ''' 100 | get tensor type preprocessed Image 101 | ''' 102 | img = Image.open(self.img_list['path_A'][index]).convert('RGB') 103 | gt_shadow = Image.open(self.img_list['path_B'][index]) 104 | gt = Image.open(self.img_list['path_C'][index]).convert('RGB') 105 | 106 | img, gt_shadow, gt = self.img_transform(self.phase, [img, gt_shadow, gt]) 107 | 108 | return img, gt_shadow, gt 109 | 110 | if __name__ == '__main__': 111 | img = Image.open('../dataset/train/train_A/test.png').convert('RGB') 112 | gt_shadow = Image.open('../dataset/train/train_B/test.png') 113 | gt = Image.open('../dataset/train/train_C/test.png').convert('RGB') 114 | 115 | print(img.size) 116 | print(gt_shadow.size) 117 | print(gt.size) 118 | 119 | f = plt.figure() 120 | f.add_subplot(1, 3, 1) 121 | plt.imshow(img) 122 | f.add_subplot(1, 3, 2) 123 | plt.imshow(gt_shadow, cmap='gray') 124 | f.add_subplot(1, 3, 3) 125 | plt.imshow(gt) 126 | 127 | img_transforms = ImageTransform(size=286, crop_size=256, mean=(0.5, ), std=(0.5, )) 128 | img, gt_shadow, gt = img_transforms([img, gt_shadow, gt]) 129 | 130 | print(img.shape) 131 | print(gt_shadow.shape) 132 | print(gt.shape) 133 | 134 | 135 | f.add_subplot(2, 3, 4) 136 | plt.imshow(transforms.ToPILImage()(img).convert('RGB')) 137 | f.add_subplot(2, 3, 5) 138 | plt.imshow(transforms.ToPILImage()(gt_shadow).convert('L'), cmap='gray') 139 | f.add_subplot(2, 3, 6) 140 | plt.imshow(transforms.ToPILImage()(gt).convert('RGB')) 141 | f.tight_layout() 142 | plt.show() -------------------------------------------------------------------------------- /utils/ISTD_transforms.py: -------------------------------------------------------------------------------- 1 | #refered https://github.com/pytorch/vision/blob/master/torchvision/transforms/transforms.py 2 | 3 | import math 4 | import numbers 5 | import random 6 | import warnings 7 | from collections.abc import Sequence 8 | from typing import Tuple, List, Optional 9 | 10 | import torch 11 | from PIL import Image 12 | from torch import Tensor 13 | import torchvision.transforms.functional as F 14 | 15 | 16 | class Compose(object): 17 | def __init__(self, transforms): 18 | self.transforms = transforms 19 | 20 | def __call__(self, img): 21 | for t in self.transforms: 22 | img = t(img) 23 | return img 24 | 25 | def __repr__(self): 26 | format_string = self.__class__.__name__ + '(' 27 | for t in self.transforms: 28 | format_string += '\n' 29 | format_string += ' {0}'.format(t) 30 | format_string += '\n)' 31 | return format_string 32 | 33 | 34 | class ToTensor(object): 35 | def __call__(self, pic): 36 | return F.to_tensor(pic[0]), F.to_tensor(pic[1]), F.to_tensor(pic[2]) 37 | 38 | def __repr__(self): 39 | return self.__class__.__name__ + '()' 40 | 41 | 42 | class Scale(object): 43 | def __init__(self, size, interpolation=Image.BILINEAR): 44 | self.size = size 45 | self.interpolation = interpolation 46 | 47 | def __call__(self, imgs): 48 | output = [] 49 | for img in imgs: 50 | w, h = img.size 51 | if (w <= h and w == self.size) or (h <= w and h == self.size): 52 | output.append(img) 53 | continue 54 | if w < h: 55 | ow = self.size 56 | oh = int(self.size * h / w) 57 | output.append(img.resize((ow, oh), self.interpolation)) 58 | continue 59 | else: 60 | oh = self.size 61 | ow = int(self.size * w / h) 62 | output.append(img.resize((ow, oh), self.interpolation)) 63 | return output[0], output[1], output[2] 64 | 65 | 66 | class Normalize(object): 67 | def __init__(self, mean, std, inplace=False): 68 | self.mean = mean 69 | self.std = std 70 | self.inplace = inplace 71 | 72 | def __call__(self, tensor): 73 | return F.normalize(tensor[0], self.mean, self.std, self.inplace), F.normalize(tensor[1], self.mean, self.std, self.inplace), F.normalize(tensor[2], self.mean, self.std, self.inplace) 74 | 75 | def __repr__(self): 76 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 77 | 78 | 79 | class CenterCrop(torch.nn.Module): 80 | def __init__(self, size): 81 | super().__init__() 82 | if isinstance(size, numbers.Number): 83 | self.size = (int(size), int(size)) 84 | elif isinstance(size, Sequence) and len(size) == 1: 85 | self.size = (size[0], size[0]) 86 | else: 87 | if len(size) != 2: 88 | raise ValueError("Please provide only two dimensions (h, w) for size.") 89 | 90 | self.size = size 91 | 92 | def forward(self, img): 93 | return F.center_crop(img[0], self.size), F.center_crop(img[1], self.size), F.center_crop(img[2], self.size) 94 | 95 | def __repr__(self): 96 | return self.__class__.__name__ + '(size={0})'.format(self.size) 97 | 98 | 99 | class RandomCrop(torch.nn.Module): 100 | @staticmethod 101 | def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: 102 | w, h = img.size 103 | th, tw = output_size 104 | if w == tw and h == th: 105 | return 0, 0, h, w 106 | 107 | i = torch.randint(0, h - th + 1, size=(1, )).item() 108 | j = torch.randint(0, w - tw + 1, size=(1, )).item() 109 | return i, j, th, tw 110 | 111 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): 112 | super().__init__() 113 | if isinstance(size, numbers.Number): 114 | self.size = (int(size), int(size)) 115 | elif isinstance(size, Sequence) and len(size) == 1: 116 | self.size = (size[0], size[0]) 117 | else: 118 | if len(size) != 2: 119 | raise ValueError("Please provide only two dimensions (h, w) for size.") 120 | 121 | # cast to tuple for torchscript 122 | self.size = tuple(size) 123 | self.padding = padding 124 | self.pad_if_needed = pad_if_needed 125 | self.fill = fill 126 | self.padding_mode = padding_mode 127 | 128 | def forward(self, img): 129 | if self.padding is not None: 130 | img[0] = F.pad(img[0], self.padding, self.fill, self.padding_mode) 131 | 132 | width, height = img[0].size 133 | # pad the width if needed 134 | if self.pad_if_needed and width < self.size[1]: 135 | padding = [self.size[1] - width, 0] 136 | img[0] = F.pad(img[0], padding, self.fill, self.padding_mode) 137 | # pad the height if needed 138 | if self.pad_if_needed and height < self.size[0]: 139 | padding = [0, self.size[0] - height] 140 | img[0] = F.pad(img[0], padding, self.fill, self.padding_mode) 141 | 142 | i, j, h, w = self.get_params(img[0], self.size) 143 | 144 | return F.crop(img[0], i, j, h, w), F.crop(img[1], i, j, h, w), F.crop(img[2], i, j, h, w) 145 | 146 | def __repr__(self): 147 | return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) 148 | 149 | 150 | class RandomHorizontalFlip(torch.nn.Module): 151 | def __init__(self, p=0.5): 152 | super().__init__() 153 | self.p = p 154 | 155 | def forward(self, img): 156 | if torch.rand(1) < self.p: 157 | return F.hflip(img[0]), F.hflip(img[1]), F.hflip(img[2]) 158 | return img[0], img[1], img[2] 159 | 160 | def __repr__(self): 161 | return self.__class__.__name__ + '(p={})'.format(self.p) -------------------------------------------------------------------------------- /models/ST_CGAN.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | 5 | def weights_init(init_type='gaussian'): 6 | def init_fun(m): 7 | classname = m.__class__.__name__ 8 | if (classname.find('Conv') == 0 or classname.find( 9 | 'Linear') == 0) and hasattr(m, 'weight'): 10 | if init_type == 'gaussian': 11 | nn.init.normal_(m.weight, 0.0, 0.02) 12 | elif init_type == 'xavier': 13 | nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) 14 | elif init_type == 'kaiming': 15 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 16 | elif init_type == 'orthogonal': 17 | nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) 18 | elif init_type == 'default': 19 | pass 20 | else: 21 | assert 0, "Unsupported initialization: {}".format(init_type) 22 | if hasattr(m, 'bias') and m.bias is not None: 23 | nn.init.constant_(m.bias, 0.0) 24 | 25 | return init_fun 26 | 27 | 28 | class Cvi(nn.Module): 29 | def __init__(self, in_channels, out_channels, before=None, after=False, kernel_size=4, stride=2, 30 | padding=1, dilation=1, groups=1, bias=False): 31 | super(Cvi, self).__init__() 32 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 33 | self.conv.apply(weights_init('gaussian')) 34 | 35 | if after=='BN': 36 | self.after = nn.BatchNorm2d(out_channels) 37 | elif after=='Tanh': 38 | self.after = torch.tanh 39 | elif after=='sigmoid': 40 | self.after = torch.sigmoid 41 | 42 | if before=='ReLU': 43 | self.before = nn.ReLU(inplace=True) 44 | elif before=='LReLU': 45 | self.before = nn.LeakyReLU(negative_slope=0.2, inplace=True) 46 | 47 | def forward(self, x): 48 | 49 | if hasattr(self, 'before'): 50 | x = self.before(x) 51 | 52 | x = self.conv(x) 53 | 54 | if hasattr(self, 'after'): 55 | x = self.after(x) 56 | 57 | return x 58 | 59 | 60 | class CvTi(nn.Module): 61 | def __init__(self, in_channels, out_channels, before=None, after=False, kernel_size=4, stride=2, 62 | padding=1, dilation=1, groups=1, bias=False): 63 | super(CvTi, self).__init__() 64 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias) 65 | self.conv.apply(weights_init('gaussian')) 66 | 67 | if after=='BN': 68 | self.after = nn.BatchNorm2d(out_channels) 69 | elif after=='Tanh': 70 | self.after = torch.tanh 71 | elif after=='sigmoid': 72 | self.after = torch.sigmoid 73 | 74 | if before=='ReLU': 75 | self.before = nn.ReLU(inplace=True) 76 | elif before=='LReLU': 77 | self.before = nn.LeakyReLU(negative_slope=0.2, inplace=True) 78 | 79 | def forward(self, x): 80 | 81 | if hasattr(self, 'before'): 82 | x = self.before(x) 83 | 84 | x = self.conv(x) 85 | 86 | if hasattr(self, 'after'): 87 | x = self.after(x) 88 | 89 | return x 90 | 91 | class Generator(nn.Module): 92 | def __init__(self, input_channels=3, output_channels=1): 93 | super(Generator, self).__init__() 94 | 95 | self.Cv0 = Cvi(input_channels, 64) 96 | 97 | self.Cv1 = Cvi(64, 128, before='LReLU', after='BN') 98 | 99 | self.Cv2 = Cvi(128, 256, before='LReLU', after='BN') 100 | 101 | self.Cv3 = Cvi(256, 512, before='LReLU', after='BN') 102 | 103 | self.Cv4 = Cvi(512, 512, before='LReLU', after='BN') 104 | 105 | self.Cv5 = Cvi(512, 512, before='LReLU') 106 | 107 | self.CvT6 = CvTi(512, 512, before='ReLU', after='BN') 108 | 109 | self.CvT7 = CvTi(1024, 512, before='ReLU', after='BN') 110 | 111 | self.CvT8 = CvTi(1024, 256, before='ReLU', after='BN') 112 | 113 | self.CvT9 = CvTi(512, 128, before='ReLU', after='BN') 114 | 115 | self.CvT10 = CvTi(256, 64, before='ReLU', after='BN') 116 | 117 | self.CvT11 = CvTi(128, output_channels, before='ReLU', after='Tanh') 118 | 119 | def forward(self, input): 120 | #encoder 121 | x0 = self.Cv0(input) 122 | x1 = self.Cv1(x0) 123 | x2 = self.Cv2(x1) 124 | x3 = self.Cv3(x2) 125 | x4_1 = self.Cv4(x3) 126 | x4_2 = self.Cv4(x4_1) 127 | x4_3 = self.Cv4(x4_2) 128 | x5 = self.Cv5(x4_3) 129 | 130 | #decoder 131 | x6 = self.CvT6(x5) 132 | 133 | cat1_1 = torch.cat([x6, x4_3], dim=1) 134 | x7_1 = self.CvT7(cat1_1) 135 | cat1_2 = torch.cat([x7_1, x4_2], dim=1) 136 | x7_2 = self.CvT7(cat1_2) 137 | cat1_3 = torch.cat([x7_2, x4_1], dim=1) 138 | x7_3 = self.CvT7(cat1_3) 139 | 140 | cat2 = torch.cat([x7_3, x3], dim=1) 141 | x8 = self.CvT8(cat2) 142 | 143 | cat3 = torch.cat([x8, x2], dim=1) 144 | x9 = self.CvT9(cat3) 145 | 146 | cat4 = torch.cat([x9, x1], dim=1) 147 | x10 = self.CvT10(cat4) 148 | 149 | cat5 = torch.cat([x10, x0], dim=1) 150 | out = self.CvT11(cat5) 151 | 152 | return out 153 | 154 | class Discriminator(nn.Module): 155 | def __init__(self, input_channels=4): 156 | super(Discriminator, self).__init__() 157 | 158 | self.Cv0 = Cvi(input_channels, 64) 159 | 160 | self.Cv1 = Cvi(64, 128, before='LReLU', after='BN') 161 | 162 | self.Cv2 = Cvi(128, 256, before='LReLU', after='BN') 163 | 164 | self.Cv3 = Cvi(256, 512, before='LReLU', after='BN') 165 | 166 | self.Cv4 = Cvi(512, 1, before='LReLU', after='sigmoid') 167 | 168 | def forward(self, input): 169 | x0 = self.Cv0(input) 170 | x1 = self.Cv1(x0) 171 | x2 = self.Cv2(x1) 172 | x3 = self.Cv3(x2) 173 | out = self.Cv4(x3) 174 | 175 | return out 176 | 177 | if __name__ == '__main__': 178 | #BCHW 179 | size = (3, 3, 256, 256) 180 | input = torch.ones(size) 181 | l1 = nn.L1Loss() 182 | input.requires_grad = True 183 | 184 | #convolution test 185 | conv = Cvi(3, 3) 186 | conv2 = Cvi(3, 3, before='ReLU', after='BN') 187 | output = conv(input) 188 | output2 = conv2(output) 189 | print(output.shape) 190 | print(output2.shape) 191 | loss = l1(output, torch.randn(3, 3, 128, 128)) 192 | loss.backward() 193 | print(loss.item()) 194 | 195 | convT = CvTi(3, 3) 196 | outputT = convT(output) 197 | print(outputT.shape) 198 | 199 | 200 | #Generator test 201 | model = Generator() 202 | output = model(input) 203 | print(output.shape) 204 | loss = l1(output, torch.randn(3, 1, 256, 256)) 205 | loss.backward() 206 | print(loss.item()) 207 | 208 | #Discriminator test 209 | size = (3, 4, 256, 256) 210 | input = torch.ones(size) 211 | l1 = nn.L1Loss() 212 | input.requires_grad = True 213 | model = Discriminator() 214 | output = model(input) 215 | print(output.shape) 216 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from utils.data_loader import make_datapath_list, ImageDataset, ImageTransform, ImageTransformOwn 2 | from models.ST_CGAN import Generator, Discriminator 3 | from torchvision.utils import make_grid 4 | from torchvision.utils import save_image 5 | from torchvision import models 6 | from torchvision import transforms 7 | from torch.autograd import Variable 8 | from collections import OrderedDict 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | import matplotlib.pyplot as plt 13 | import torch.optim as optim 14 | import torch.nn as nn 15 | import numpy as np 16 | import argparse 17 | import time 18 | import torch 19 | import os 20 | 21 | torch.manual_seed(44) 22 | # choose your device 23 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 24 | 25 | def get_parser(): 26 | parser = argparse.ArgumentParser( 27 | prog='ST-CGAN: Stacked Conditional Generative Adversarial Networks for Jointly Learning Shadow Detection and Shadow Removal', 28 | usage='python3 main.py', 29 | description='This module demonstrates shadow detection and removal using ST-CGAN.', 30 | add_help=True) 31 | 32 | parser.add_argument('-l', '--load', type=str, default=None, help='the number of checkpoints') 33 | parser.add_argument('-i', '--image_path', type=str, default=None, help='file path of image you want to test') 34 | parser.add_argument('-o', '--out_path', type=str, default='./test_result', help='saving path') 35 | parser.add_argument('-s', '--image_size', type=int, default=286) 36 | parser.add_argument('-cs', '--crop_size', type=int, default=256) 37 | parser.add_argument('-rs', '--resized_size', type=int, default=256) 38 | 39 | return parser 40 | 41 | def fix_model_state_dict(state_dict): 42 | ''' 43 | remove 'module.' of dataparallel 44 | ''' 45 | new_state_dict = OrderedDict() 46 | for k, v in state_dict.items(): 47 | name = k 48 | if name.startswith('module.'): 49 | name = name[7:] 50 | new_state_dict[name] = v 51 | return new_state_dict 52 | 53 | def check_dir(): 54 | if not os.path.exists('./test_result'): 55 | os.mkdir('./test_result') 56 | if not os.path.exists('./test_result/detected_shadow'): 57 | os.mkdir('./test_result/detected_shadow') 58 | if not os.path.exists('./test_result/shadow_removal_image'): 59 | os.mkdir('./test_result/shadow_removal_image') 60 | if not os.path.exists('./test_result/grid'): 61 | os.mkdir('./test_result/grid') 62 | 63 | def unnormalize(x): 64 | x = x.transpose(1, 3) 65 | #mean, std 66 | x = x * torch.Tensor((0.5, )) + torch.Tensor((0.5, )) 67 | x = x.transpose(1, 3) 68 | return x 69 | 70 | def test(G1, G2, test_dataset): 71 | ''' 72 | this module test dataset from ISTD dataset 73 | ''' 74 | check_dir() 75 | 76 | device = "cuda" if torch.cuda.is_available() else "cpu" 77 | 78 | G1.to(device) 79 | G2.to(device) 80 | 81 | """use GPU in parallel""" 82 | if device == 'cuda': 83 | G1 = torch.nn.DataParallel(G1) 84 | G2 = torch.nn.DataParallel(G2) 85 | print("parallel mode") 86 | 87 | print("device:{}".format(device)) 88 | 89 | G1.eval() 90 | G2.eval() 91 | 92 | for n, (img, gt_shadow, gt) in enumerate([test_dataset[i] for i in range(test_dataset.__len__())]): 93 | 94 | print(test_dataset.img_list['path_A'][n].split('/')[4][:-4]) 95 | 96 | img = torch.unsqueeze(img, dim=0) 97 | gt_shadow = torch.unsqueeze(gt_shadow, dim=0) 98 | gt = torch.unsqueeze(gt, dim=0) 99 | 100 | with torch.no_grad(): 101 | detected_shadow = G1(img.to(device)) 102 | detected_shadow = detected_shadow.to(torch.device('cpu')) 103 | concat = torch.cat([img, detected_shadow], dim=1) 104 | shadow_removal_image = G2(concat.to(device)) 105 | shadow_removal_image = shadow_removal_image.to(torch.device('cpu')) 106 | 107 | 108 | grid = make_grid(torch.cat([unnormalize(img), unnormalize(gt), unnormalize(shadow_removal_image), 109 | unnormalize(torch.cat([gt_shadow, gt_shadow, gt_shadow], dim=1)), 110 | unnormalize(torch.cat([detected_shadow, detected_shadow, detected_shadow], dim=1))], 111 | dim=0)) 112 | 113 | save_image(grid, './test_result/grid/'+test_dataset.img_list['path_A'][n].split('/')[4]) 114 | 115 | detected_shadow = transforms.ToPILImage(mode='L')(unnormalize(detected_shadow)[0, :, :, :]) 116 | detected_shadow.save('./test_result/detected_shadow/'+test_dataset.img_list['path_A'][n].split('/')[4]) 117 | 118 | shadow_removal_image = transforms.ToPILImage(mode='RGB')(unnormalize(shadow_removal_image)[0, :, :, :]) 119 | shadow_removal_image.save('./test_result/shadow_removal_image/'+test_dataset.img_list['path_A'][n].split('/')[4]) 120 | 121 | def test_own_image(G1, G2, path, out_path, size, img_transform): 122 | img = Image.open(path).convert('RGB') 123 | width, height = img.width, img.height 124 | img = img.resize((size, size), Image.LANCZOS) 125 | img = img_transform(img) 126 | img = torch.unsqueeze(img, dim=0) 127 | 128 | device = "cuda" if torch.cuda.is_available() else "cpu" 129 | 130 | G1.to(device) 131 | G2.to(device) 132 | 133 | """use GPU in parallel""" 134 | if device == 'cuda': 135 | G1 = torch.nn.DataParallel(G1) 136 | G2 = torch.nn.DataParallel(G2) 137 | print("parallel mode") 138 | 139 | print("device:{}".format(device)) 140 | 141 | G1.eval() 142 | G2.eval() 143 | 144 | with torch.no_grad(): 145 | detected_shadow = G1(img.to(device)) 146 | detected_shadow = detected_shadow.to(torch.device('cpu')) 147 | concat = torch.cat([img, detected_shadow], dim=1) 148 | shadow_removal_image = G2(concat.to(device)) 149 | shadow_removal_image = shadow_removal_image.to(torch.device('cpu')) 150 | 151 | 152 | grid = make_grid(torch.cat([unnormalize(img), 153 | unnormalize(torch.cat([detected_shadow, detected_shadow, detected_shadow], dim=1)), 154 | unnormalize(shadow_removal_image)], 155 | dim=0)) 156 | 157 | save_image(grid, out_path + '/grid_' + path.split('/')[-1]) 158 | 159 | detected_shadow = transforms.ToPILImage(mode='L')(unnormalize(detected_shadow)[0, :, :, :]) 160 | detected_shadow= detected_shadow.resize((width, height), Image.LANCZOS) 161 | detected_shadow.save(out_path + '/detected_shadow_' + path.split('/')[-1]) 162 | 163 | shadow_removal_image = transforms.ToPILImage(mode='RGB')(unnormalize(shadow_removal_image)[0, :, :, :]) 164 | shadow_removal_image = shadow_removal_image.resize((width, height), Image.LANCZOS) 165 | shadow_removal_image.save(out_path + '/shadow_removal_image_' + path.split('/')[-1]) 166 | 167 | def main(parser): 168 | G1 = Generator(input_channels=3, output_channels=1) 169 | G2 = Generator(input_channels=4, output_channels=3) 170 | 171 | '''load''' 172 | if parser.load is not None: 173 | print('load checkpoint ' + parser.load) 174 | 175 | G1_weights = torch.load('./checkpoints/ST-CGAN_G1_'+parser.load+'.pth') 176 | G1.load_state_dict(fix_model_state_dict(G1_weights)) 177 | 178 | G2_weights = torch.load('./checkpoints/ST-CGAN_G2_'+parser.load+'.pth') 179 | G2.load_state_dict(fix_model_state_dict(G2_weights)) 180 | 181 | mean = (0.5,) 182 | std = (0.5,) 183 | 184 | size = parser.image_size 185 | crop_size = parser.crop_size 186 | resized_size = parser.resized_size 187 | 188 | # test own image 189 | if parser.image_path is not None: 190 | print('test ' + parser.image_path) 191 | test_own_image(G1, G2, parser.image_path, parser.out_path, resized_size, img_transform=ImageTransformOwn(size=size, mean=mean, std=std)) 192 | 193 | # test images from the ISTD dataset 194 | else: 195 | print('test ISTD dataset') 196 | test_img_list = make_datapath_list(phase='test') 197 | test_dataset = ImageDataset(img_list=test_img_list, 198 | img_transform=ImageTransform(size=size, crop_size=crop_size, mean=mean, std=std), 199 | phase='test') 200 | test(G1, G2, test_dataset) 201 | 202 | if __name__ == "__main__": 203 | parser = get_parser().parse_args() 204 | main(parser) 205 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utils.data_loader import make_datapath_list, ImageDataset, ImageTransform 2 | from models.ST_CGAN import Generator, Discriminator 3 | from torchvision.utils import make_grid 4 | from torchvision.utils import save_image 5 | from torch.autograd import Variable 6 | from collections import OrderedDict 7 | from torchvision import models 8 | from tqdm import tqdm 9 | 10 | import matplotlib.pyplot as plt 11 | import torch.optim as optim 12 | import torch.nn as nn 13 | import numpy as np 14 | import argparse 15 | import time 16 | import torch 17 | import os 18 | 19 | torch.manual_seed(44) 20 | # choose your device 21 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 22 | 23 | def get_parser(): 24 | parser = argparse.ArgumentParser( 25 | prog='ST-CGAN: Stacked Conditional Generative Adversarial Networks for Jointly Learning Shadow Detection and Shadow Removal', 26 | usage='python3 main.py', 27 | description='This module demonstrates shadow detection and removal using ST-CGAN.', 28 | add_help=True) 29 | 30 | parser.add_argument('-e', '--epoch', type=int, default=10000, help='Number of epochs') 31 | parser.add_argument('-b', '--batch_size', type=int, default=8, help='Batch size') 32 | parser.add_argument('-l', '--load', type=str, default=None, help='the number of checkpoints') 33 | parser.add_argument('-hor', '--hold_out_ratio', type=float, default=0.8, help='training-validation ratio') 34 | parser.add_argument('-s', '--image_size', type=int, default=286) 35 | parser.add_argument('-cs', '--crop_size', type=int, default=256) 36 | parser.add_argument('-lr', '--lr', type=float, default=2e-4) 37 | 38 | return parser 39 | 40 | def fix_model_state_dict(state_dict): 41 | ''' 42 | remove 'module.' of dataparallel 43 | ''' 44 | new_state_dict = OrderedDict() 45 | for k, v in state_dict.items(): 46 | name = k 47 | if name.startswith('module.'): 48 | name = name[7:] 49 | new_state_dict[name] = v 50 | return new_state_dict 51 | 52 | def set_requires_grad(nets, requires_grad=False): 53 | for net in nets: 54 | if net is not None: 55 | for param in net.parameters(): 56 | param.requires_grad = requires_grad 57 | 58 | def unnormalize(x): 59 | x = x.transpose(1, 3) 60 | #mean, std 61 | x = x * torch.Tensor((0.5, )) + torch.Tensor((0.5, )) 62 | x = x.transpose(1, 3) 63 | return x 64 | 65 | def evaluate(G1, G2, dataset, device, filename): 66 | img, gt_shadow, gt = zip(*[dataset[i] for i in range(8)]) 67 | img = torch.stack(img) 68 | gt_shadow = torch.stack(gt_shadow) 69 | gt = torch.stack(gt) 70 | 71 | with torch.no_grad(): 72 | detected_shadow = G1(img.to(device)) 73 | detected_shadow = detected_shadow.to(torch.device('cpu')) 74 | concat = torch.cat([img, detected_shadow], dim=1) 75 | shadow_removal_image = G2(concat.to(device)) 76 | shadow_removal_image = shadow_removal_image.to(torch.device('cpu')) 77 | 78 | grid_detect = make_grid(torch.cat((unnormalize(gt_shadow), unnormalize(detected_shadow)), dim=0)) 79 | grid_removal = make_grid(torch.cat((unnormalize(img), unnormalize(gt), unnormalize(shadow_removal_image)), dim=0)) 80 | 81 | save_image(grid_detect, filename+'_detect.jpg') 82 | save_image(grid_removal, filename+'_removal.jpg') 83 | 84 | def plot_log(data, save_model_name='model'): 85 | plt.cla() 86 | plt.plot(data['G'], label='G_loss ') 87 | plt.plot(data['D'], label='D_loss ') 88 | plt.legend() 89 | plt.xlabel('epoch') 90 | plt.ylabel('loss') 91 | plt.title('Loss') 92 | plt.savefig('./logs/'+save_model_name+'.png') 93 | 94 | def check_dir(): 95 | if not os.path.exists('./logs'): 96 | os.mkdir('./logs') 97 | if not os.path.exists('./checkpoints'): 98 | os.mkdir('./checkpoints') 99 | if not os.path.exists('./result'): 100 | os.mkdir('./result') 101 | 102 | def train_model(G1, G2, D1, D2, dataloader, val_dataset, num_epochs, parser, save_model_name='model'): 103 | 104 | check_dir() 105 | 106 | device = "cuda" if torch.cuda.is_available() else "cpu" 107 | Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor 108 | 109 | G1.to(device) 110 | G2.to(device) 111 | D1.to(device) 112 | D2.to(device) 113 | 114 | """use GPU in parallel""" 115 | if device == 'cuda': 116 | G1 = torch.nn.DataParallel(G1) 117 | G2 = torch.nn.DataParallel(G2) 118 | D1 = torch.nn.DataParallel(D1) 119 | D2 = torch.nn.DataParallel(D2) 120 | print("parallel mode") 121 | 122 | print("device:{}".format(device)) 123 | 124 | lr = parser.lr 125 | beta1, beta2 = 0.5, 0.999 126 | 127 | optimizerG = torch.optim.Adam([{'params': G1.parameters()}, {'params': G2.parameters()}], 128 | lr=lr, 129 | betas=(beta1, beta2)) 130 | optimizerD = torch.optim.Adam([{'params': D1.parameters()}, {'params': D2.parameters()}], 131 | lr=lr, 132 | betas=(beta1, beta2)) 133 | 134 | criterionGAN = nn.BCEWithLogitsLoss().to(device) 135 | criterionL1 = nn.L1Loss().to(device) 136 | 137 | torch.backends.cudnn.benchmark = True 138 | 139 | mini_batch_size = parser.batch_size 140 | num_train_imgs = len(dataloader.dataset) 141 | batch_size = dataloader.batch_size 142 | lambda_dict = {'lambda1':5, 'lambda2':0.1, 'lambda3':0.1} 143 | 144 | iteration = 1 145 | g_losses = [] 146 | d_losses = [] 147 | 148 | for epoch in range(num_epochs+1): 149 | 150 | G1.train() 151 | G2.train() 152 | D1.train() 153 | D2.train() 154 | t_epoch_start = time.time() 155 | 156 | epoch_g_loss = 0.0 157 | epoch_d_loss = 0.0 158 | 159 | print('-----------') 160 | print('Epoch {}/{}'.format(epoch, num_epochs)) 161 | print('(train)') 162 | 163 | for images, gt_shadow, gt in tqdm(dataloader): 164 | 165 | # if size of minibatch is 1, an error would be occured. 166 | if images.size()[0] == 1: 167 | continue 168 | 169 | images = images.to(device) 170 | gt = gt.to(device) 171 | gt_shadow = gt_shadow.to(device) 172 | 173 | mini_batch_size = images.size()[0] 174 | 175 | # Train Discriminator 176 | set_requires_grad([D1, D2], True) # enable backprop$ 177 | optimizerD.zero_grad() 178 | 179 | # for D1 180 | detected_shadow = G1(images) 181 | fake1 = torch.cat([images, detected_shadow], dim=1) 182 | real1 = torch.cat([images, gt_shadow], dim=1) 183 | out_D1_fake = D1(fake1.detach()) 184 | out_D1_real = D1(real1)# .detach() is not required as real1 doesn't have grad 185 | 186 | # for D2 187 | shadow_removal_image = G2(fake1) 188 | fake2 = torch.cat([fake1, shadow_removal_image], dim=1) 189 | real2 = torch.cat([real1, gt], dim=1) 190 | out_D2_fake = D2(fake2.detach()) 191 | out_D2_real = D2(real2)# .detach() is not required as real2 doesn't have grad 192 | 193 | # L_CGAN1 194 | label_D1_fake = Variable(Tensor(np.zeros(out_D1_fake.size())), requires_grad=True) 195 | label_D1_real = Variable(Tensor(np.ones(out_D1_fake.size())), requires_grad=True) 196 | 197 | loss_D1_fake = criterionGAN(out_D1_fake, label_D1_fake) 198 | loss_D1_real = criterionGAN(out_D1_real, label_D1_real) 199 | D_L_CGAN1 = loss_D1_fake + loss_D1_real 200 | 201 | # L_CGAN2 202 | label_D2_fake = Variable(Tensor(np.zeros(out_D2_fake.size())), requires_grad=True) 203 | label_D2_real = Variable(Tensor(np.ones(out_D2_fake.size())), requires_grad=True) 204 | 205 | loss_D2_fake = criterionGAN(out_D2_fake, label_D2_fake) 206 | loss_D2_real = criterionGAN(out_D2_real, label_D2_real) 207 | D_L_CGAN2 = loss_D2_fake + loss_D2_real 208 | 209 | # total 210 | D_loss = lambda_dict['lambda2'] * D_L_CGAN1 + lambda_dict['lambda3'] * D_L_CGAN2 211 | D_loss.backward() 212 | optimizerD.step() 213 | 214 | # Train Generator 215 | set_requires_grad([D1, D2], False) 216 | optimizerG.zero_grad() 217 | 218 | # L_CGAN1 219 | fake1 = torch.cat([images, detected_shadow], dim=1) 220 | out_D1_fake = D1(fake1.detach()) 221 | G_L_CGAN1 = criterionGAN(out_D1_fake, label_D1_real) 222 | 223 | # L_data1 224 | G_L_data1 = criterionL1(detected_shadow, gt_shadow) 225 | 226 | # L_CGAN2 227 | fake2 = torch.cat([fake1, shadow_removal_image], dim=1) 228 | out_D2_fake = D2(fake2.detach()) 229 | G_L_CGAN2 = criterionGAN(out_D2_fake, label_D2_real) 230 | 231 | #L_data2 232 | G_L_data2 = criterionL1(gt, shadow_removal_image) 233 | 234 | #total 235 | G_loss = G_L_data1 + lambda_dict['lambda1'] * G_L_data2 + lambda_dict['lambda2'] * G_L_CGAN1 + lambda_dict['lambda3'] * G_L_CGAN2 236 | G_loss.backward() 237 | optimizerG.step() 238 | 239 | epoch_d_loss += D_loss.item() 240 | epoch_g_loss += G_loss.item() 241 | 242 | t_epoch_finish = time.time() 243 | print('-----------') 244 | print('epoch {} || Epoch_D_Loss:{:.4f} || Epoch_G_Loss:{:.4f}'.format(epoch, epoch_d_loss/batch_size, epoch_g_loss/batch_size)) 245 | print('timer: {:.4f} sec.'.format(t_epoch_finish - t_epoch_start)) 246 | 247 | d_losses += [epoch_d_loss/batch_size] 248 | g_losses += [epoch_g_loss/batch_size] 249 | t_epoch_start = time.time() 250 | plot_log({'G':g_losses, 'D':d_losses}, save_model_name) 251 | 252 | if(epoch%10 == 0): 253 | torch.save(G1.state_dict(), 'checkpoints/'+save_model_name+'_G1_'+str(epoch)+'.pth') 254 | torch.save(G2.state_dict(), 'checkpoints/'+save_model_name+'_G2_'+str(epoch)+'.pth') 255 | torch.save(D1.state_dict(), 'checkpoints/'+save_model_name+'_D1_'+str(epoch)+'.pth') 256 | torch.save(D2.state_dict(), 'checkpoints/'+save_model_name+'_D2_'+str(epoch)+'.pth') 257 | G1.eval() 258 | G2.eval() 259 | evaluate(G1, G2, val_dataset, device, '{:s}/val_{:d}'.format('result', epoch)) 260 | 261 | return G1, G2, D1, D2 262 | 263 | 264 | 265 | def main(parser): 266 | G1 = Generator(input_channels=3, output_channels=1) 267 | G2 = Generator(input_channels=4, output_channels=3) 268 | D1 = Discriminator(input_channels=4) 269 | D2 = Discriminator(input_channels=7) 270 | 271 | '''load''' 272 | if parser.load is not None: 273 | print('load checkpoint ' + parser.load) 274 | 275 | G1_weights = torch.load('./checkpoints/ST-CGAN_G1_'+parser.load+'.pth') 276 | G1.load_state_dict(fix_model_state_dict(G1_weights)) 277 | 278 | G2_weights = torch.load('./checkpoints/ST-CGAN_G2_'+parser.load+'.pth') 279 | G2.load_state_dict(fix_model_state_dict(G2_weights)) 280 | 281 | D1_weights = torch.load('./checkpoints/ST-CGAN_D1_'+parser.load+'.pth') 282 | D1.load_state_dict(fix_model_state_dict(D1_weights)) 283 | 284 | D2_weights = torch.load('./checkpoints/ST-CGAN_D2_'+parser.load+'.pth') 285 | D2.load_state_dict(fix_model_state_dict(D2_weights)) 286 | 287 | train_img_list, val_img_list = make_datapath_list(phase='train', rate=parser.hold_out_ratio) 288 | 289 | mean = (0.5,) 290 | std = (0.5,) 291 | size = parser.image_size 292 | crop_size = parser.crop_size 293 | batch_size = parser.batch_size 294 | num_epochs = parser.epoch 295 | 296 | train_dataset = ImageDataset(img_list=train_img_list, 297 | img_transform=ImageTransform(size=size, crop_size=crop_size, mean=mean, std=std), 298 | phase='train') 299 | val_dataset = ImageDataset(img_list=val_img_list, 300 | img_transform=ImageTransform(size=size, crop_size=crop_size, mean=mean, std=std), 301 | phase='val') 302 | 303 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) #num_workers=4 304 | 305 | G1_update, G2_update, D1_update, D2_update = train_model(G1, G2, D1, D2, dataloader=train_dataloader, 306 | val_dataset=val_dataset, num_epochs=num_epochs, 307 | parser=parser, save_model_name='ST-CGAN') 308 | 309 | if __name__ == "__main__": 310 | parser = get_parser().parse_args() 311 | main(parser) 312 | --------------------------------------------------------------------------------