├── networks
│   ├── resnext_modify
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── resnext101_regular.py
│   │   ├── resnext101_5out.py
│   │   └── resnext_101_32x4d_.py
│   ├── DeepLabV3.py
│   └── TVSD.py
├── joint_transforms.py
├── config.py
├── utils
│   ├── IRNN_Forward_cuda.cu
│   └── IRNN_Backward_cuda.cu
├── misc.py
├── evaluate.py
├── infer.py
├── losses.py
├── README.md
├── dataset
│   └── VShadow_crosspairwise.py
└── train.py

/networks/resnext_modify/__init__.py:
--------------------------------------------------------------------------------
from .resnext101_regular import ResNeXt101
--------------------------------------------------------------------------------
/networks/resnext_modify/config.py:
--------------------------------------------------------------------------------
resnext_101_32_path = '/home/chenzhihao/shadow_detection/shadow-MT/backbone_pth/resnext_101_32x4d.pth'
--------------------------------------------------------------------------------
/networks/resnext_modify/resnext101_regular.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

from .config import resnext_101_32_path
from .resnext_101_32x4d_ import resnext_101_32x4d
import torch._utils


class ResNeXt101(nn.Module):
    def __init__(self, pretrained=True):
        super(ResNeXt101, self).__init__()
        net = resnext_101_32x4d
        if pretrained:
            net.load_state_dict(torch.load(resnext_101_32_path))
        net = list(net.children())
        self.layer0 = nn.Sequential(*net[:3])
        self.layer1 = nn.Sequential(*net[3: 5])
        self.layer2 = net[5]
        self.layer3 = net[6]
        self.layer4 = net[7]

    def forward(self, x):
        layer0 = self.layer0(x)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        return layer4
--------------------------------------------------------------------------------
/networks/resnext_modify/resnext101_5out.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

from .resnext_101_32x4d_ import resnext_101_32x4d
from .config import resnext_101_32_path

class ResNeXt101(nn.Module):
    def __init__(self):
        super(ResNeXt101, self).__init__()
        net = resnext_101_32x4d
        net.load_state_dict(torch.load(resnext_101_32_path))
        net = list(net.children())
        self.layer0 = nn.Sequential(*net[:3])
        self.layer1 = nn.Sequential(*net[3: 5])
        self.layer2 = net[5]
        self.layer3 = net[6]
        self.layer4 = net[7]

    def forward(self, x):
        layers = []
        layer0 = self.layer0(x)
        layers.append(layer0)
        layer1 = self.layer1(layer0)
        layers.append(layer1)
        layer2 = self.layer2(layer1)
        layers.append(layer2)
        layer3 = self.layer3(layer2)
        layers.append(layer3)
        layer4 = self.layer4(layer3)
        layers.append(layer4)
        return layers
--------------------------------------------------------------------------------
/joint_transforms.py:
--------------------------------------------------------------------------------
import random

from PIL import Image


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask, manual_random=None):
        assert img.size == mask.size
        for t in self.transforms:
            img, mask = t(img, mask, manual_random)
        return img, mask

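# Usage sketch (an illustration inferred from how train.py and the dataset class use
# these transforms; not part of the original file): reusing one `manual_random` draw
# for two frames of a clip makes the flip decision identical for both, e.g.
#
#     jt = Compose([RandomHorizontallyFlip(), Resize((416, 416))])
#     r = random.random()
#     frame_a, mask_a = jt(frame_a, mask_a, manual_random=r)
#     frame_b, mask_b = jt(frame_b, mask_b, manual_random=r)
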
class RandomHorizontallyFlip(object): 18 | def __call__(self, img, mask, manual_random=None): 19 | if manual_random is None: 20 | if random.random() < 0.5: 21 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT) 22 | return img, mask 23 | else: 24 | if manual_random < 0.5: 25 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT) 26 | return img, mask 27 | 28 | 29 | class Resize(object): 30 | def __init__(self, size): 31 | self.size = tuple(reversed(size)) # size: (h, w) 32 | 33 | def __call__(self, img, mask, manual_random=None): 34 | assert img.size == mask.size 35 | return img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST) 36 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | ''' 3 | Dataset root form: ($path, $type, $name) 4 | path: dataset path 5 | type: "image" or "video" 6 | name: just for easy mark 7 | ''' 8 | ### Saliency datasets 9 | # DUT_OMRON_training_root = ('/home/ext/chenzhihao/Datasets/saliency_dataset/DUT-OMRON', 'image', 'DUT-OMRON') 10 | # MSRA10K_training_root = ('/home/ext/chenzhihao/Datasets/saliency_dataset/MSRA10K', 'image', 'MSRA10K') 11 | # DAVIS_training_root = ('/home/ext/chenzhihao/Datasets/saliency_dataset/DAVIS_train', 'video', 'DAVIS_train') 12 | # DAVIS_validation_root = ('/home/ext/chenzhihao/Datasets/saliency_dataset/DAVIS_val', 'video', 'DAVIS_val') 13 | 14 | # Shadow datasets 15 | # SBU_training_root = ('/home/ext/chenzhihao/Datasets/SBU-shadow/SBUTrain4KRecoveredSmall', 'image', 'SBU_train') 16 | # SBU_testing_root = ('/home/ext/chenzhihao/Datasets/SBU-shadow/SBU-Test', 'image', 'SBU_test') 17 | ViSha_training_root = ('/home/ext/chenzhihao/Datasets/ViSha/train', 'video', 'ViSD_train') 18 | ViSha_validation_root = ('/home/ext/chenzhihao/Datasets/ViSha/test', 'video', 'ViSD_test') 19 | 20 | 21 | ''' 22 | Pretrained single model path 23 | ''' 24 | # PDBM_single_path = '/home/ext/chenzhihao/code/video_shadow/models_saliency/PDBM_single_256/50000.pth' 25 | # DeepLabV3_path = '/home/ext/chenzhihao/code/video_shadow/models/deeplabv3/20.pth' 26 | # FPN_path = '/home/ext/chenzhihao/code/video_shadow/models/FPN/20.pth' -------------------------------------------------------------------------------- /utils/IRNN_Forward_cuda.cu: -------------------------------------------------------------------------------- 1 | #define CUDA_KERNEL_LOOP(i, n) \ 2 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 3 | i < (n); \ 4 | i += blockDim.x * gridDim.x) 5 | 6 | #define INDEX(b,c,h,w,channels,height,width) ((b * channels + c) * height + h) * width+ w 7 | 8 | extern "C" __global__ void IRNNForward( 9 | const float* input_feature, 10 | 11 | const float* weight_up, 12 | const float* weight_right, 13 | const float* weight_down, 14 | const float* weight_left, 15 | 16 | const float* bias_up, 17 | const float* bias_right, 18 | const float* bias_down, 19 | const float* bias_left, 20 | 21 | float* output_up, 22 | float* output_right, 23 | float* output_down, 24 | float* output_left, 25 | 26 | const int channels, 27 | const int height, 28 | const int width, 29 | const int n){ 30 | 31 | CUDA_KERNEL_LOOP(index,n){ 32 | int w = index % width; 33 | int h = index / width % height; 34 | int c = index / width / height % channels; 35 | int b = index / width / height / channels; 36 | 37 | float temp = 0; 38 | 39 | // left 40 | output_left[index] = 
input_feature[INDEX(b, c, h, width-1, channels, height, width)] > 0 ? input_feature[INDEX(b, c, h, width-1, channels, height, width)] : 0; 41 | for (int i = width-2; i>=w; i--) 42 | { 43 | temp = output_left[index] * weight_left[c] + bias_left[c] + input_feature[INDEX(b, c, h, i, channels, height, width)]; 44 | output_left[index] = (temp > 0)? temp : 0; 45 | } 46 | 47 | // right 48 | output_right[index] = input_feature[INDEX(b, c, h, 0, channels, height, width)] > 0 ? input_feature[INDEX(b, c, h, 0, channels, height, width)] : 0; 49 | for (int i = 1; i <= w; i++) 50 | { 51 | temp = output_right[index] * weight_right[c] + bias_right[c] + input_feature[INDEX(b, c, h, i, channels, height, width)]; 52 | output_right[index] = (temp > 0)? temp : 0; 53 | } 54 | 55 | // up 56 | output_up[index] = input_feature[INDEX(b,c,height-1,w,channels,height,width)] > 0 ? input_feature[INDEX(b,c,height-1,w,channels,height,width)] : 0; 57 | for (int i = height-2; i >= h; i--) 58 | { 59 | temp = output_up[index] * weight_up[c] + bias_up[c] + input_feature[INDEX(b, c, i, w, channels, height, width)]; 60 | output_up[index] = (temp > 0)? temp : 0; 61 | } 62 | 63 | // down 64 | output_down[index] = input_feature[INDEX(b, c, 0, w, channels, height, width)] > 0 ? input_feature[INDEX(b, c, 0, w, channels, height, width)] : 0; 65 | for (int i = 1; i <= h; i++) 66 | { 67 | temp = output_down[index] * weight_down[c] + bias_down[c] + input_feature[INDEX(b, c, i, w, channels, height, width)]; 68 | output_down[index] = (temp > 0)? temp : 0; 69 | } 70 | } 71 | } -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from medpy import metric 4 | import torch 5 | 6 | 7 | class AvgMeter(object): 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.avg = self.sum / self.count 22 | 23 | 24 | def check_mkdir(dir_name): 25 | if not os.path.exists(dir_name): 26 | os.makedirs(dir_name) 27 | 28 | 29 | def _sigmoid(x): 30 | return 1 / (1 + np.exp(-x)) 31 | 32 | def cal_precision_recall_mae(prediction, gt): 33 | # input should be np array with data type uint8 34 | assert prediction.dtype == np.uint8 35 | assert gt.dtype == np.uint8 36 | assert prediction.shape == gt.shape 37 | 38 | eps = 1e-4 39 | 40 | prediction = prediction / 255. 41 | gt = gt / 255. 42 | 43 | mae = np.mean(np.abs(prediction - gt)) 44 | 45 | hard_gt = np.zeros(prediction.shape) 46 | hard_gt[gt > 0.5] = 1 47 | t = np.sum(hard_gt) 48 | 49 | precision, recall = [], [] 50 | # calculating precision and recall at 255 different binarizing thresholds 51 | for threshold in range(256): 52 | threshold = threshold / 255. 
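        # At each of the 256 thresholds the prediction is binarized below and compared
        # against the binarized ground truth `hard_gt`; the per-threshold
        # precision/recall pairs collected here are what cal_fmeasure() later uses to
        # report the maximum F-beta score (beta^2 = 0.3) over all thresholds.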
53 | 54 | hard_prediction = np.zeros(prediction.shape) 55 | hard_prediction[prediction > threshold] = 1 56 | 57 | tp = np.sum(hard_prediction * hard_gt) 58 | p = np.sum(hard_prediction) 59 | 60 | precision.append((tp + eps) / (p + eps)) 61 | recall.append((tp + eps) / (t + eps)) 62 | 63 | return precision, recall, mae 64 | 65 | 66 | def cal_fmeasure(precision, recall): 67 | assert len(precision) == 256 68 | assert len(recall) == 256 69 | beta_square = 0.3 70 | max_fmeasure = max([(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)]) 71 | 72 | return max_fmeasure 73 | 74 | def cal_Jaccard(prediction, gt): 75 | # input should be np array with data type uint8 76 | assert prediction.dtype == np.uint8 77 | assert gt.dtype == np.uint8 78 | assert prediction.shape == gt.shape 79 | 80 | prediction = prediction / 255. 81 | gt = gt / 255. 82 | 83 | pred = (prediction > 0.5) 84 | gt = (gt > 0.5) 85 | Jaccard = metric.binary.jc(pred, gt) 86 | 87 | return Jaccard 88 | 89 | def cal_BER(prediction, label, thr = 127.5): 90 | prediction = (prediction > thr) 91 | label = (label > thr) 92 | prediction_tmp = prediction.astype(np.float) 93 | label_tmp = label.astype(np.float) 94 | TP = np.sum(prediction_tmp * label_tmp) 95 | TN = np.sum((1 - prediction_tmp) * (1 - label_tmp)) 96 | Np = np.sum(label_tmp) 97 | Nn = np.sum((1-label_tmp)) 98 | BER = 0.5 * (2 - TP / Np - TN / Nn) * 100 99 | shadow_BER = (1 - TP / Np) * 100 100 | non_shadow_BER = (1 - TN / Nn) * 100 101 | 102 | return BER, shadow_BER, non_shadow_BER 103 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | from misc import check_mkdir, cal_precision_recall_mae, AvgMeter, cal_fmeasure, cal_Jaccard, cal_BER 5 | from tqdm import tqdm 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--models', type=str, default='TVSD', help='model name') 10 | parser.add_argument('--snapshot', type=str, default='12', help='model name') 11 | tmp_args = parser.parse_args() 12 | 13 | root_path = f'/home/ext/chenzhihao/code/video_shadow/models/{tmp_args.models}/predict_{tmp_args.snapshot}' 14 | save_path = f'/home/ext/chenzhihao/code/video_shadow/models/{tmp_args.models}/predict_fuse_{tmp_args.snapshot}' 15 | 16 | gt_path = '/home/ext/chenzhihao/Datasets/ViSha/test/labels' 17 | input_path = '/home/ext/chenzhihao/Datasets/ViSha/test/images' 18 | 19 | precision_record, recall_record, = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)] 20 | mae_record = AvgMeter() 21 | Jaccard_record = AvgMeter() 22 | BER_record = AvgMeter() 23 | shadow_BER_record = AvgMeter() 24 | non_shadow_BER_record = AvgMeter() 25 | 26 | video_list = os.listdir(root_path) 27 | for video in tqdm(video_list): 28 | gt_list = os.listdir(os.path.join(gt_path, video)) 29 | img_list = [f for f in os.listdir(os.path.join(root_path, video)) if f.split('_', 1)[0]+'.png' in gt_list] # include overlap images 30 | img_set = list(set([img.split('_', 1)[0] for img in img_list])) # remove repeat 31 | for img_prefix in img_set: 32 | # jump exist images 33 | check_mkdir(os.path.join(save_path, video)) 34 | save_name = os.path.join(save_path, video, '{}.png'.format(img_prefix)) 35 | # if not os.path.exists(os.path.join(save_path, video, save_name)): 36 | imgs = [img for img in img_list if img.split('_', 1)[0] == img_prefix] # imgs waited for fuse 37 
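        # Each ground-truth frame can have several predictions named
        # `{frame}_by{query}.png` (one per query frame paired by infer.py); the loop
        # below averages those grayscale maps into a single fused mask, which is then
        # saved and scored against the ground truth.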
| fuse = [] 38 | for img_path in imgs: 39 | img = np.array(Image.open(os.path.join(root_path, video, img_path)).convert('L')).astype(np.float32) 40 | # if np.max(img) > 0: # normalize prediction mask 41 | # img = (img - np.min(img)) / (np.max(img) - np.min(img)) * 255 42 | fuse.append(img) 43 | fuse = (sum(fuse) / len(imgs)).astype(np.uint8) 44 | # save image 45 | print(f'Save:{save_name}') 46 | Image.fromarray(fuse).save(save_name) 47 | # else: 48 | # print(f'Exist:{save_name}') 49 | # fuse = np.array(Image.open(save_name).convert('L')).astype(np.uint8) 50 | # calculate metric 51 | gt = np.array(Image.open(os.path.join(gt_path, video, img_prefix+'.png'))) 52 | precision, recall, mae = cal_precision_recall_mae(fuse, gt) 53 | Jaccard = cal_Jaccard(fuse, gt) 54 | Jaccard_record.update(Jaccard) 55 | BER, shadow_BER, non_shadow_BER = cal_BER(fuse, gt) 56 | BER_record.update(BER) 57 | shadow_BER_record.update(shadow_BER) 58 | non_shadow_BER_record.update(non_shadow_BER) 59 | for pidx, pdata in enumerate(zip(precision, recall)): 60 | p, r = pdata 61 | precision_record[pidx].update(p) 62 | recall_record[pidx].update(r) 63 | mae_record.update(mae) 64 | 65 | fmeasure = cal_fmeasure([precord.avg for precord in precision_record], 66 | [rrecord.avg for rrecord in recall_record]) 67 | log = 'MAE:{}, F-beta:{}, Jaccard:{}, BER:{}, SBER:{}, non-SBER:{}'.format(mae_record.avg, fmeasure, Jaccard_record.avg, BER_record.avg, shadow_BER_record.avg, non_shadow_BER_record.avg) 68 | print(log) 69 | 70 | 71 | -------------------------------------------------------------------------------- /networks/DeepLabV3.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from .resnext_modify import resnext101_regular 5 | 6 | # DeeplabV3 plus 7 | class DeepLabV3(nn.Module): 8 | def __init__(self, num_classes=1, pretrained=True): 9 | super(DeepLabV3, self).__init__() 10 | self.backbone = resnext101_regular.ResNeXt101() 11 | aspp_dilate = [12, 24, 36] 12 | # aspp_dilate = [6, 12, 18] 13 | self.aspp = ASPP(2048, aspp_dilate) 14 | self.final_pre = nn.Conv2d(256, 1, 1) 15 | initialize_weights(self.aspp, self.final_pre) 16 | 17 | def forward(self, x): 18 | # x shape: B, C, W, H 19 | x_size = x.size() 20 | x0 = self.backbone.layer0(x) 21 | x1 = self.backbone.layer1(x0) # x1 is lower feature 22 | x2 = self.backbone.layer2(x1) 23 | x3 = self.backbone.layer3(x2) 24 | x4 = self.backbone.layer4(x3) 25 | fea = self.aspp(x4) 26 | predict = self.final_pre(fea) 27 | predict = F.interpolate(predict, size=x_size[2:], mode='bilinear', align_corners=False) 28 | return x1, fea, predict 29 | 30 | class ASPP(nn.Module): 31 | def __init__(self, in_channels, atrous_rates): 32 | super(ASPP, self).__init__() 33 | out_channels = 256 34 | modules = [] 35 | modules.append(nn.Sequential( 36 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 37 | nn.BatchNorm2d(out_channels), 38 | nn.ReLU(inplace=True))) 39 | 40 | rate1, rate2, rate3 = tuple(atrous_rates) 41 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 42 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 43 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 44 | modules.append(ASPPPooling(in_channels, out_channels)) 45 | 46 | self.convs = nn.ModuleList(modules) 47 | 48 | self.project = nn.Sequential( 49 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 50 | nn.BatchNorm2d(out_channels), 51 | nn.ReLU(inplace=True), 52 | nn.Dropout(0.1)) 53 | 54 
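    # Note on the forward pass below: the five parallel branches (a 1x1 conv, three
    # 3x3 atrous convs with rates given by `atrous_rates` (12/24/36 here), and a
    # global-average-pooling branch) each emit 256 channels; their outputs are
    # concatenated to 5*256 channels and `self.project` maps them back to 256.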
| def forward(self, x): 55 | res = [] 56 | for conv in self.convs: 57 | res.append(conv(x)) 58 | res = torch.cat(res, dim=1) 59 | return self.project(res) 60 | 61 | class ASPPConv(nn.Sequential): 62 | def __init__(self, in_channels, out_channels, dilation): 63 | modules = [ 64 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 65 | nn.BatchNorm2d(out_channels), 66 | nn.ReLU(inplace=True) 67 | ] 68 | super(ASPPConv, self).__init__(*modules) 69 | 70 | class ASPPPooling(nn.Sequential): 71 | def __init__(self, in_channels, out_channels): 72 | super(ASPPPooling, self).__init__( 73 | nn.AdaptiveAvgPool2d(1), 74 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 75 | nn.BatchNorm2d(out_channels), 76 | nn.ReLU(inplace=True)) 77 | 78 | def forward(self, x): 79 | size = x.shape[-2:] 80 | x = super(ASPPPooling, self).forward(x) 81 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 82 | 83 | def initialize_weights(*models): 84 | for model in models: 85 | for module in model.modules(): 86 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 87 | nn.init.kaiming_normal_(module.weight) 88 | if module.bias is not None: 89 | module.bias.data.zero_() 90 | elif isinstance(module, nn.BatchNorm2d): 91 | module.weight.data.fill_(1) 92 | module.bias.data.zero_() -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import torch 5 | from PIL import Image 6 | from torchvision import transforms 7 | 8 | from config import ViSha_validation_root 9 | from misc import check_mkdir 10 | from networks.TVSD import TVSD 11 | import argparse 12 | from tqdm import tqdm 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--snapshot', type=str, default='2', help='snapshot') 16 | parser.add_argument('--models', type=str, default='TVSD', help='model name') 17 | tmp_args = parser.parse_args() 18 | 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 20 | 21 | ckpt_path = './models' 22 | exp_name = tmp_args.models 23 | args = { 24 | 'snapshot': tmp_args.snapshot, 25 | 'scale': 416, 26 | 'test_adjacent': 5, 27 | 'input_folder': 'images', 28 | 'label_folder': 'labels' 29 | } 30 | 31 | img_transform = transforms.Compose([ 32 | transforms.Resize((args['scale'], args['scale'])), 33 | transforms.ToTensor(), 34 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 35 | ]) 36 | target_transform = transforms.ToTensor() 37 | 38 | root = ViSha_validation_root[0] 39 | 40 | to_pil = transforms.ToPILImage() 41 | 42 | 43 | def main(): 44 | net = TVSD().cuda() 45 | 46 | if len(args['snapshot']) > 0: 47 | check_point = torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')) 48 | net.load_state_dict(check_point['model']) 49 | 50 | net.eval() 51 | with torch.no_grad(): 52 | video_list = os.listdir(os.path.join(root, args['input_folder'])) 53 | for video in tqdm(video_list): 54 | # all images 55 | img_list = [os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, args['input_folder'], video)) if 56 | f.endswith('.jpg')] 57 | # need evaluation images 58 | img_eval_list = [os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, args['label_folder'], video)) if 59 | f.endswith('.png')] 60 | 61 | img_eval_list = sortImg(img_eval_list) 62 | for exemplar_idx, exemplar_name in enumerate(img_eval_list): 63 | query_idx_list = getAdjacentIndex(exemplar_idx, 0, 
len(img_list), args['test_adjacent']) 64 | for query_idx in query_idx_list: 65 | exemplar = Image.open(os.path.join(root, args['input_folder'], video, exemplar_name + '.jpg')).convert('RGB') 66 | w, h = exemplar.size 67 | query = Image.open(os.path.join(root, args['input_folder'], video, img_list[query_idx] + '.jpg')).convert('RGB') 68 | exemplar_tensor = img_transform(exemplar).unsqueeze(0).cuda() 69 | query_tensor = img_transform(query).unsqueeze(0).cuda() 70 | exemplar_pre, _, _, _ = net(exemplar_tensor, query_tensor, query_tensor) 71 | res = (exemplar_pre.data > 0).to(torch.float32) 72 | prediction = np.array( 73 | transforms.Resize((h, w))(to_pil(res.squeeze(0).cpu()))) 74 | check_mkdir(os.path.join(ckpt_path, exp_name, "predict_" + args['snapshot'], video)) 75 | # save form as 00000001_1.png, 000000001_2.png 76 | save_name = f"{exemplar_name}_by{query_idx}.png" 77 | print(os.path.join(ckpt_path, exp_name, "predict_" + args['snapshot'], video, save_name)) 78 | Image.fromarray(prediction).save( 79 | os.path.join(ckpt_path, exp_name, "predict_" + args['snapshot'], video, save_name)) 80 | 81 | 82 | def sortImg(img_list): 83 | img_int_list = [int(f) for f in img_list] 84 | sort_index = [i for i, v in sorted(enumerate(img_int_list), key=lambda x: x[1])] # sort img to 001,002,003... 85 | return [img_list[i] for i in sort_index] 86 | 87 | 88 | def getAdjacentIndex(current_index, start_index, video_length, adjacent_length): 89 | if current_index + adjacent_length < start_index + video_length: 90 | query_index_list = [current_index+i+1 for i in range(adjacent_length)] 91 | else: 92 | query_index_list = [current_index-i-1 for i in range(adjacent_length)] 93 | return query_index_list 94 | 95 | if __name__ == '__main__': 96 | main() 97 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from torch.autograd import Variable 6 | try: 7 | from itertools import ifilterfalse 8 | except ImportError: # py3k 9 | from itertools import filterfalse as ifilterfalse 10 | 11 | def bce2d_new(input, target, reduction='mean'): 12 | assert(input.size() == target.size()) 13 | pos = torch.eq(target, 1).float() 14 | neg = torch.eq(target, 0).float() 15 | # ing = ((torch.gt(target, 0) & torch.lt(target, 1))).float() 16 | 17 | num_pos = torch.sum(pos) 18 | num_neg = torch.sum(neg) 19 | num_total = num_pos + num_neg 20 | 21 | alpha = num_neg / num_total 22 | beta = 1.1 * num_pos / num_total 23 | # target pixel = 1 -> weight beta 24 | # target pixel = 0 -> weight 1-beta 25 | weights = alpha * pos + beta * neg 26 | 27 | return F.binary_cross_entropy_with_logits(input, target, weights, reduction=reduction) 28 | 29 | def BCE_IOU(pred, mask): 30 | weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 31 | wbce = F.binary_cross_entropy_with_logits(pred, mask) 32 | wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3)) 33 | 34 | pred = torch.sigmoid(pred) 35 | inter = ((pred * mask)*weit).sum(dim=(2, 3)) 36 | union = ((pred + mask)*weit).sum(dim=(2, 3)) 37 | wiou = 1 - (inter + 1)/(union - inter+1) 38 | return wbce.mean(), wiou.mean() 39 | 40 | # --------------------------- BINARY Lovasz LOSSES --------------------------- 41 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 42 | """ 43 | Binary Lovasz hinge loss 44 | logits: [B, H, W] Variable, 
logits at each pixel (between -\infty and +\infty) 45 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 46 | per_image: compute the loss per image instead of per batch 47 | ignore: void class id 48 | """ 49 | if per_image: 50 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 51 | for log, lab in zip(logits, labels)) 52 | else: 53 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 54 | return loss 55 | 56 | 57 | def lovasz_hinge_flat(logits, labels): 58 | """ 59 | Binary Lovasz hinge loss 60 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 61 | labels: [P] Tensor, binary ground truth labels (0 or 1) 62 | ignore: label to ignore 63 | """ 64 | if len(labels) == 0: 65 | # only void pixels, the gradients should be 0 66 | return logits.sum() * 0. 67 | signs = 2. * labels.float() - 1. 68 | errors = (1. - logits * Variable(signs)) 69 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 70 | perm = perm.data 71 | gt_sorted = labels[perm] 72 | grad = lovasz_grad(gt_sorted) 73 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 74 | return loss 75 | 76 | 77 | def flatten_binary_scores(scores, labels, ignore=None): 78 | """ 79 | Flattens predictions in the batch (binary case) 80 | Remove labels equal to 'ignore' 81 | """ 82 | scores = scores.view(-1) 83 | labels = labels.view(-1) 84 | if ignore is None: 85 | return scores, labels 86 | valid = (labels != ignore) 87 | vscores = scores[valid] 88 | vlabels = labels[valid] 89 | return vscores, vlabels 90 | 91 | 92 | class StableBCELoss(torch.nn.modules.Module): 93 | def __init__(self): 94 | super(StableBCELoss, self).__init__() 95 | def forward(self, input, target): 96 | neg_abs = - input.abs() 97 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 98 | return loss.mean() 99 | 100 | 101 | def binary_xloss(logits, labels, ignore=None): 102 | """ 103 | Binary Cross entropy loss 104 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 105 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 106 | ignore: void class id 107 | """ 108 | logits, labels = flatten_binary_scores(logits, labels, ignore) 109 | loss = StableBCELoss()(logits, Variable(labels.float())) 110 | return loss 111 | 112 | def lovasz_grad(gt_sorted): 113 | """ 114 | Computes gradient of the Lovasz extension w.r.t sorted errors 115 | See Alg. 1 in paper 116 | """ 117 | p = len(gt_sorted) 118 | gts = gt_sorted.sum() 119 | intersection = gts - gt_sorted.float().cumsum(0) 120 | union = gts + (1 - gt_sorted).float().cumsum(0) 121 | jaccard = 1. - intersection / union 122 | if p > 1: # cover 1-pixel case 123 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 124 | return jaccard 125 | 126 | def mean(l, ignore_nan=False, empty=0): 127 | """ 128 | nanmean compatible with generators. 
129 | """ 130 | l = iter(l) 131 | if ignore_nan: 132 | l = ifilterfalse(isnan, l) 133 | try: 134 | n = 1 135 | acc = next(l) 136 | except StopIteration: 137 | if empty == 'raise': 138 | raise ValueError('Empty mean') 139 | return empty 140 | for n, v in enumerate(l, 2): 141 | acc += v 142 | if n == 1: 143 | return acc 144 | return acc / n 145 | 146 | def isnan(x): 147 | return x != x -------------------------------------------------------------------------------- /utils/IRNN_Backward_cuda.cu: -------------------------------------------------------------------------------- 1 | #define CUDA_KERNEL_LOOP(i, n) \ 2 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 3 | i < (n); \ 4 | i += blockDim.x * gridDim.x) 5 | 6 | #define INDEX(b,c,h,w,channels,height,width) ((b * channels + c) * height + h) * width+ w 7 | 8 | 9 | extern "C" __global__ void IRNNBackward( 10 | float* grad_input, 11 | 12 | float* grad_weight_up_map, 13 | float* grad_weight_right_map, 14 | float* grad_weight_down_map, 15 | float* grad_weight_left_map, 16 | 17 | float* grad_bias_up_map, 18 | float* grad_bias_right_map, 19 | float* grad_bias_down_map, 20 | float* grad_bias_left_map, 21 | 22 | const float* weight_up, 23 | const float* weight_right, 24 | const float* weight_down, 25 | const float* weight_left, 26 | 27 | const float* grad_output_up, 28 | const float* grad_output_right, 29 | const float* grad_output_down, 30 | const float* grad_output_left, 31 | 32 | const float* output_up, 33 | const float* output_right, 34 | const float* output_down, 35 | const float* output_left, 36 | 37 | const int channels, 38 | const int height, 39 | const int width, 40 | const int n) { 41 | 42 | CUDA_KERNEL_LOOP(index,n){ 43 | 44 | int w = index % width; 45 | int h = index / width % height; 46 | int c = index / width / height % channels; 47 | int b = index / width / height / channels; 48 | 49 | float diff_left = 0; 50 | float diff_right = 0; 51 | float diff_up = 0; 52 | float diff_down = 0; 53 | 54 | //left 55 | 56 | for (int i = 0; i<=w; i++) 57 | { 58 | diff_left *= weight_left[c]; 59 | diff_left += grad_output_left[INDEX(b, c, h, i, channels, height, width)]; 60 | diff_left *= (output_left[INDEX(b, c, h, i, channels, height, width)]<=0)? 0 : 1; 61 | } 62 | 63 | 64 | float temp = grad_output_left[INDEX(b, c, h, 0, channels, height, width)]; 65 | for (int i = 1; i < w +1 ; i++) 66 | { 67 | temp = (output_left[INDEX(b, c, h, i-1, channels, height, width)] >0?1:0) * temp * weight_left[c] + grad_output_left[INDEX(b, c, h, i, channels, height, width)]; 68 | } 69 | 70 | if (w != width - 1){ 71 | grad_weight_left_map[index] = temp * output_left[INDEX(b, c, h, w+1, channels, height, width)] * (output_left[index] > 0? 1:0); 72 | grad_bias_left_map[index] = diff_left; 73 | } 74 | 75 | // right 76 | 77 | for (int i = width -1; i>=w; i--) 78 | { 79 | diff_right *= weight_right[c]; 80 | diff_right += grad_output_right[INDEX(b, c, h, i, channels, height, width)]; 81 | diff_right *= (output_right[INDEX(b, c, h, i, channels, height, width)]<=0)? 0 : 1; 82 | } 83 | 84 | 85 | temp = grad_output_right[INDEX(b, c, h, width-1, channels, height, width)]; 86 | for (int i = width -2; i > w - 1 ; i--) 87 | { 88 | temp = (output_right[INDEX(b, c, h, i+1, channels, height, width)] >0?1:0) * temp * weight_right[c] + grad_output_right[INDEX(b, c, h, i, channels, height, width)]; 89 | } 90 | 91 | if (w != 0){ 92 | grad_weight_right_map[index] = temp * output_right[INDEX(b, c, h, w-1, channels, height, width)] * (output_right[index] > 0? 
1:0); 93 | grad_bias_right_map[index] = diff_right; 94 | } 95 | 96 | // up 97 | 98 | 99 | for (int i = 0; i<=h; i++) 100 | { 101 | diff_up *= weight_up[c]; 102 | diff_up += grad_output_up[INDEX(b, c, i, w, channels, height, width)]; 103 | diff_up *= (output_up[INDEX(b, c, i, w, channels, height, width)]<=0)? 0 : 1; 104 | } 105 | 106 | 107 | temp = grad_output_up[INDEX(b, c, 0, w, channels, height, width)]; 108 | for (int i = 1; i < h +1 ; i++) 109 | { 110 | temp = (output_up[INDEX(b, c, i-1, w, channels, height, width)] >0?1:0) * temp * weight_up[c] + grad_output_up[INDEX(b, c, i, w, channels, height, width)]; 111 | } 112 | 113 | if (h != height - 1){ 114 | grad_weight_up_map[index] = temp * output_up[INDEX(b, c, h+1, w, channels, height, width)] * (output_up[index] > 0? 1:0); 115 | grad_bias_up_map[index] = diff_up; 116 | } 117 | 118 | // down 119 | 120 | for (int i = height -1; i>=h; i--) 121 | { 122 | diff_down *= weight_down[c]; 123 | diff_down += grad_output_down[INDEX(b, c, i, w, channels, height, width)]; 124 | diff_down *= (output_down[INDEX(b, c, i, w, channels, height, width)]<=0)? 0 : 1; 125 | } 126 | 127 | 128 | temp = grad_output_down[INDEX(b, c, height-1, w, channels, height, width)]; 129 | for (int i = height -2; i > h - 1 ; i--) 130 | { 131 | temp = (output_down[INDEX(b, c, i+1, w, channels, height, width)] >0?1:0) * temp * weight_down[c] + grad_output_down[INDEX(b, c, i, w, channels, height, width)]; 132 | } 133 | 134 | if (h != 0){ 135 | grad_weight_down_map[index] = temp * output_down[INDEX(b, c, h-1, w, channels, height, width)] * (output_down[index] > 0? 1:0); 136 | grad_bias_down_map[index] = diff_down; 137 | } 138 | grad_input[index] = diff_down + diff_left + diff_right + diff_up; 139 | } 140 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Triple-cooperative Video Shadow Detection 2 | Code and dataset for the CVPR 2021 paper **"Triple-cooperative Video Shadow Detection"**[[arXiv link](https://arxiv.org/abs/2103.06533)] [[official link](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Triple-Cooperative_Video_Shadow_Detection_CVPR_2021_paper.pdf)]. 3 | by Zhihao Chen1, Liang Wan1, Lei Zhu2, Jia Shen1, Huazhu Fu3, Wennan Liu4, and Jing Qin5 4 | 1College of Intelligence and Computing, Tianjin University 5 | 2Department of Applied Mathematics and Theoretical Physics, University of Cambridge 6 | 3Inception Institute of Artificial Intelligence, UAE 7 | 4Academy of Medical Engineering and Translational Medicine, Tianjin University 8 | 5The Hong Kong Polytechnic University 9 | 10 | #### News: In 2021.4.7, We first release the code of TVSD and ViSha dataset. 11 | #### News: In 2022.5.7, [Lihao Liu](https://github.com/lihaoliu-cambridge/video-shadow-detection) publish a pytorch-lightning implementation for TVSD. 

***

## Citation
@inproceedings{chen21TVSD,
    author = {Chen, Zhihao and Wan, Liang and Zhu, Lei and Shen, Jia and Fu, Huazhu and Liu, Wennan and Qin, Jing},
    title = {Triple-cooperative Video Shadow Detection},
    booktitle = {CVPR},
    year = {2021}
}

## Pytorch-lightning Version
A PyTorch-Lightning version is available at [https://github.com/lihaoliu-cambridge/video-shadow-detection](https://github.com/lihaoliu-cambridge/video-shadow-detection), implemented by [Lihao Liu](https://github.com/lihaoliu-cambridge).

## Dataset
The ViSha dataset is available at the **[ViSha Homepage](https://erasernut.github.io/ViSha.html)**.

## Requirement
* Python 3.6
* PyTorch 1.3.1
* torchvision
* numpy
* tqdm
* PIL
* math
* time
* datetime
* argparse
* apex (optional, fp16 to save memory and speed up training)

## Training
1. Modify the data path in ```./config.py```
2. Modify the pretrained backbone path in ```./networks/resnext_modify/config.py```
3. Run ```python train.py```; the trained model will be saved in ```./models/TVSD```

The pretrained ResNeXt model is ported from the [official](https://github.com/facebookresearch/ResNeXt) torch version,
using the [convertor](https://github.com/clcarwin/convert_torch_to_pytorch) provided by clcarwin.
You can directly [download](https://drive.google.com/open?id=1dnH-IHwmu9xFPlyndqI6MfF4LvH6JKNQ) the pretrained model ported by us.

## Testing
1. Modify the data path in ```./config.py```
2. Make sure you have a snapshot in ```./models/TVSD``` (Tips: you can download the trained model reported in our paper from [BaiduNetdisk](https://pan.baidu.com/s/17d-wLwA5oyafMdooJlesyw) (pw: 8p5h) or [Google Drive](https://drive.google.com/file/d/14dSMN6P7fUyL_KOubaXOAUZp_Dc0tFzq/view?usp=sharing))
3. Run ```python infer.py``` to generate predicted masks
4. Run ```python evaluate.py``` to evaluate the generated results

## Results on the ViSha testing set
As mentioned in our paper, since there is no prior CNN-based method for video shadow detection, we compare against 12 state-of-the-art methods from relevant tasks, including BDRAR[1], DSD[2], MTMT[3] (single-image shadow detection), FPN[4], PSPNet[5] (single-image semantic segmentation), DSS[6], R^3 Net[7] (single-image saliency detection), PDBM[8], MAG[9] (video saliency detection), COSNet[10], FEELVOS[11], STM[12] (video object segmentation).
[1]L. Zhu, Z. Deng, X. Hu, C.-W. Fu, X. Xu, J. Qin, and P.-A. Heng. Bidirectional feature pyramid network with recurrent attention residual modules for shadow detection. In ECCV, pages 121–136, 2018.
[2]Q. Zheng, X. Qiao, Y. Cao, and R.W. Lau. Distraction-aware shadow detection. In CVPR, pages 5167–5176, 2019.
[3]Z. Chen, L. Zhu, L. Wan, S. Wang, W. Feng, and P.-A. Heng. A multi-task mean teacher for semi-supervised shadow detection. In CVPR, pages 5611–5620, 2020.
[4]T.-Y. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, and S. Belongie. Feature pyramid networks for object detection. In CVPR, pages 2117–2125, 2017.
[5]H. Zhao, J. Shi, X. Qi, X. Wang, and J. Jia. Pyramid scene parsing network. In CVPR, pages 2881–2890, 2017.
[6]Q. Hou, M. Cheng, X. Hu, A. Borji, Z. Tu, and P. Torr. Deeply supervised salient object detection with short connections. IEEE Transactions on Pattern Analysis and Machine Intelligence, 41(4):815–828, 2019.
65 | [7]Z. Deng, X. Hu, L. Zhu, X. Xu, J. Qin, G. Han, and P.-A. Heng. R3net: Recurrent residual refinement network for saliency detection. In IJCAI, pages 684–690. AAAI Press, 2018. 66 | [8]H. Song, W. Wang, S. Zhao, J. Shen, and K.-M. Lam. Pyramid dilated deeper convlstm for video salient object detection. In ECCV, pages 715–731, 2018. 67 | [9]H. Li, G. Chen, G. Li, and Y. Yu. Motion guided attention for video salient object detection. In ICCV, pages 7274–7283, 2019. 68 | [10]X. Lu, W. Wang, C. Ma, J. Shen, L. Shao, and F. Porikli. See more, know more: Unsupervised video object segmentation with co-attention siamese networks. In CVPR, pages 3623–3632, 2019. 69 | [11]P. Voigtlaender, Y. Chai, F. Schroff, H. Adam, B. Leibe, and L.-C. Chen. Feelvos: Fast end-to-end embedding learning for video object segmentation. In CVPR, June 2019. 70 | [12]S.W. Oh, J.-Y. Lee, N. Xu, and S.J. Kim. Video object segmentation using space-time memory networks. In ICCV, pages 9226–9235, 2019. 71 | 72 | We evaluate those methods and our TVSD in ViSha testing set and release all results in [BaiduNetdisk](https://pan.baidu.com/s/1t_PgW3JCrTGvf_PVyeR-iw)(pw: ritw) or [Google Drive](https://drive.google.com/drive/folders/13XgGxu9DDuuz2vS6ugFrLLmXRZzoHWhb?usp=sharing) 73 | -------------------------------------------------------------------------------- /dataset/VShadow_crosspairwise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | 4 | import torch.utils.data as data 5 | from PIL import Image 6 | import random 7 | import torch 8 | import numpy as np 9 | 10 | 11 | # return image triple pairs in video and return single image 12 | class CrossPairwiseImg(data.Dataset): 13 | def __init__(self, root, joint_transform=None, img_transform=None, target_transform=None): 14 | self.img_root, self.video_root = self.split_root(root) 15 | self.joint_transform = joint_transform 16 | self.img_transform = img_transform 17 | self.target_transform = target_transform 18 | self.input_folder = 'images' 19 | self.label_folder = 'labels' 20 | self.img_ext = '.jpg' 21 | self.label_ext = '.png' 22 | self.num_video_frame = 0 23 | # get all frames from video datasets 24 | self.videoImg_list = self.generateImgFromVideo(self.video_root) 25 | print('Total video frames is {}.'.format(self.num_video_frame)) 26 | # get all frames from image datasets 27 | if len(self.img_root) > 0: 28 | self.singleImg_list = self.generateImgFromSingle(self.img_root) 29 | print('Total single image frames is {}.'.format(len(self.singleImg_list))) 30 | 31 | 32 | def __getitem__(self, index): 33 | manual_random = random.random() # random for transformation 34 | # pair in video 35 | exemplar_path, exemplar_gt_path, videoStartIndex, videoLength = self.videoImg_list[index] # exemplar 36 | # sample from same video 37 | query_index = np.random.randint(videoStartIndex, videoStartIndex + videoLength) 38 | if query_index == index: 39 | query_index = np.random.randint(videoStartIndex, videoStartIndex + videoLength) 40 | query_path, query_gt_path, videoStartIndex2, videoLength2 = self.videoImg_list[query_index] # query 41 | if videoStartIndex != videoStartIndex2 or videoLength != videoLength2: 42 | raise TypeError('Something wrong') 43 | # sample from different video 44 | while True: 45 | other_index = np.random.randint(0, self.__len__()) 46 | if other_index < videoStartIndex or other_index > videoStartIndex + videoLength - 1: 47 | break # find image from different video 48 | other_path, other_gt_path, 
videoStartIndex3, videoLength3 = self.videoImg_list[other_index] # other 49 | if videoStartIndex == videoStartIndex3: 50 | raise TypeError('Something wrong') 51 | # single image in image dataset 52 | if len(self.img_root) > 0: 53 | single_idx = np.random.randint(0, videoLength) 54 | single_image_path, single_gt_path = self.singleImg_list[single_idx] # single image 55 | 56 | # read image and gt 57 | exemplar = Image.open(exemplar_path).convert('RGB') 58 | query = Image.open(query_path).convert('RGB') 59 | other = Image.open(other_path).convert('RGB') 60 | exemplar_gt = Image.open(exemplar_gt_path).convert('L') 61 | query_gt = Image.open(query_gt_path).convert('L') 62 | other_gt = Image.open(other_gt_path).convert('L') 63 | if len(self.img_root) > 0: 64 | single_image = Image.open(single_image_path).convert('RGB') 65 | single_gt = Image.open(single_gt_path).convert('L') 66 | 67 | # transformation 68 | if self.joint_transform is not None: 69 | exemplar, exemplar_gt = self.joint_transform(exemplar, exemplar_gt, manual_random) 70 | query, query_gt = self.joint_transform(query, query_gt, manual_random) 71 | other, other_gt = self.joint_transform(other, other_gt) 72 | if len(self.img_root) > 0: 73 | single_image, single_gt = self.joint_transform(single_image, single_gt) 74 | if self.img_transform is not None: 75 | exemplar = self.img_transform(exemplar) 76 | query = self.img_transform(query) 77 | other = self.img_transform(other) 78 | if len(self.img_root) > 0: 79 | single_image = self.img_transform(single_image) 80 | if self.target_transform is not None: 81 | exemplar_gt = self.target_transform(exemplar_gt) 82 | query_gt = self.target_transform(query_gt) 83 | other_gt = self.target_transform(other_gt) 84 | if len(self.img_root) > 0: 85 | single_gt = self.target_transform(single_gt) 86 | if len(self.img_root) > 0: 87 | sample = {'exemplar': exemplar, 'exemplar_gt': exemplar_gt, 'query': query, 'query_gt': query_gt, 88 | 'other': other, 'other_gt': other_gt, 'single_image': single_image, 'single_gt': single_gt} 89 | else: 90 | sample = {'exemplar': exemplar, 'exemplar_gt': exemplar_gt, 'query': query, 'query_gt': query_gt, 91 | 'other': other, 'other_gt': other_gt} 92 | return sample 93 | 94 | def generateImgFromVideo(self, root): 95 | imgs = [] 96 | root = root[0] # assume that only one video dataset 97 | video_list = os.listdir(os.path.join(root[0], self.input_folder)) 98 | for video in video_list: 99 | img_list = [os.path.splitext(f)[0] for f in os.listdir(os.path.join(root[0], self.input_folder, video)) if f.endswith(self.img_ext)] # no ext 100 | img_list = self.sortImg(img_list) 101 | for img in img_list: 102 | # videoImgGt: (img, gt, video start index, video length) 103 | videoImgGt = (os.path.join(root[0], self.input_folder, video, img + self.img_ext), 104 | os.path.join(root[0], self.label_folder, video, img + self.label_ext), self.num_video_frame, len(img_list)) 105 | imgs.append(videoImgGt) 106 | self.num_video_frame += len(img_list) 107 | return imgs 108 | 109 | def generateImgFromSingle(self, root): 110 | imgs = [] 111 | for sub_root in root: 112 | tmp = self.generateImagePair(sub_root[0]) 113 | imgs.extend(tmp) # deal with image case 114 | print('Image number of ImageSet {} is {}.'.format(sub_root[2], len(tmp))) 115 | 116 | return imgs 117 | 118 | def generateImagePair(self, root): 119 | img_list = [os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, self.input_folder)) if f.endswith(self.img_ext)] 120 | if len(img_list) == 0: 121 | raise IOError('make sure the dataset path 
is correct') 122 | return [(os.path.join(root, self.input_folder, img_name + self.img_ext), os.path.join(root, self.label_folder, img_name + self.label_ext)) 123 | for img_name in img_list] 124 | 125 | def sortImg(self, img_list): 126 | img_int_list = [int(f) for f in img_list] 127 | sort_index = [i for i, v in sorted(enumerate(img_int_list), key=lambda x: x[1])] # sort img to 001,002,003... 128 | return [img_list[i] for i in sort_index] 129 | 130 | def split_root(self, root): 131 | if not isinstance(root, list): 132 | raise TypeError('root should be a list') 133 | img_root_list = [] 134 | video_root_list = [] 135 | for tmp in root: 136 | if tmp[1] == 'image': 137 | img_root_list.append(tmp) 138 | elif tmp[1] == 'video': 139 | video_root_list.append(tmp) 140 | else: 141 | raise TypeError('you should input video or image') 142 | return img_root_list, video_root_list 143 | 144 | def __len__(self): 145 | return len(self.videoImg_list)//2*2 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /networks/TVSD.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from .DeepLabV3 import DeepLabV3 5 | 6 | class TVSD(nn.Module): 7 | def __init__(self, pretrained_path=None, num_classes=1, all_channel=256, all_dim=26 * 26, T=0.07): # 473./8=60 416./8=52 8 | super(TVSD, self).__init__() 9 | self.encoder = DeepLabV3() 10 | self.T = T 11 | # load pretrained model from DeepLabV3 module 12 | # in our experiments, no need to pretrain the single deeplabv3 13 | if pretrained_path is not None: 14 | checkpoint = torch.load(pretrained_path) 15 | print(f"Load checkpoint:{pretrained_path}") 16 | self.encoder.load_state_dict(checkpoint['model']) 17 | self.co_attention = CoattentionModel(num_classes=num_classes, all_channel=all_channel, all_dim=all_dim) 18 | self.project = nn.Sequential( 19 | nn.Conv2d(256, 48, 1, bias=False), 20 | nn.BatchNorm2d(48), 21 | nn.ReLU(inplace=True) 22 | ) 23 | self.final_pre = nn.Sequential( 24 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 25 | nn.BatchNorm2d(256), 26 | nn.ReLU(inplace=True), 27 | nn.Conv2d(256, num_classes, 1) 28 | ) 29 | initialize_weights(self.co_attention, self.project, self.final_pre) 30 | 31 | def forward(self, input1, input2, input3): 32 | input_size = input1.size()[2:] 33 | low_exemplar, exemplar, _ = self.encoder(input1) 34 | low_query, query, _ = self.encoder(input2) 35 | low_other, other, _ = self.encoder(input3) 36 | x1, x2 = self.co_attention(exemplar, query) 37 | x1 = F.interpolate(x1, size=low_exemplar.shape[2:], mode='bilinear', align_corners=False) 38 | x2 = F.interpolate(x2, size=low_query.shape[2:], mode='bilinear', align_corners=False) 39 | x3 = F.interpolate(other, size=low_other.shape[2:], mode='bilinear', align_corners=False) 40 | fuse_exemplar = torch.cat([x1, self.project(low_exemplar)], dim=1) 41 | fuse_query = torch.cat([x2, self.project(low_query)], dim=1) 42 | fuse_other = torch.cat([x3, self.project(low_other)], dim=1) 43 | exemplar_pre = self.final_pre(fuse_exemplar) 44 | query_pre = self.final_pre(fuse_query) 45 | other_pre = self.final_pre(fuse_other) 46 | 47 | # scene vector 48 | v1 = F.adaptive_avg_pool2d(exemplar, (1, 1)).squeeze(-1).squeeze(-1) 49 | v1 = nn.functional.normalize(v1, dim=1) 50 | v2 = F.adaptive_avg_pool2d(query, (1, 1)).squeeze(-1).squeeze(-1) 51 | v2 = nn.functional.normalize(v2, dim=1) 52 | v3 = F.adaptive_avg_pool2d(other, (1, 
1)).squeeze(-1).squeeze(-1) 53 | v3 = nn.functional.normalize(v3, dim=1) 54 | 55 | l_pos = torch.einsum('nc,nc->n', [v1, v2]).unsqueeze(-1) 56 | l_neg1 = torch.einsum('nc,nc->n', [v1, v3]).unsqueeze(-1) 57 | # l_neg2 = torch.einsum('nc,nc->n', [v2, v3]).unsqueeze(-1) 58 | # logits = torch.cat([l_pos, l_neg1, l_neg2], dim=1) 59 | logits = torch.cat([l_pos, l_neg1], dim=1) 60 | logits /= self.T 61 | exemplar_pre = F.upsample(exemplar_pre, input_size, mode='bilinear', align_corners=False) # upsample to the size of input image, scale=8 62 | query_pre = F.upsample(query_pre, input_size, mode='bilinear', align_corners=False) # upsample to the size of input image, scale=8 63 | other_pre = F.upsample(other_pre, input_size, mode='bilinear', align_corners=False) # upsample to the size of input image, scale=8 64 | return exemplar_pre, query_pre, other_pre, logits 65 | 66 | 67 | 68 | class CoattentionModel(nn.Module): # spatial and channel attention module 69 | def __init__(self, num_classes=1, all_channel=256, all_dim=26 * 26): # 473./8=60 416./8=52 70 | super(CoattentionModel, self).__init__() 71 | self.linear_e = nn.Linear(all_channel, all_channel, bias=False) 72 | self.channel = all_channel 73 | self.dim = all_dim 74 | self.gate1 = nn.Conv2d(all_channel * 2, 1, kernel_size=1, bias=False) 75 | self.gate2 = nn.Conv2d(all_channel * 2, 1, kernel_size=1, bias=False) 76 | self.gate_s = nn.Sigmoid() 77 | self.conv1 = nn.Conv2d(all_channel * 2, all_channel, kernel_size=3, padding=1, bias=False) 78 | self.conv2 = nn.Conv2d(all_channel * 2, all_channel, kernel_size=3, padding=1, bias=False) 79 | self.bn1 = nn.BatchNorm2d(all_channel) 80 | self.bn2 = nn.BatchNorm2d(all_channel) 81 | self.prelu = nn.ReLU(inplace=True) 82 | self.globalAvgPool = nn.AvgPool2d(26, stride=1) 83 | self.fc1 = nn.Linear(in_features=256*2, out_features=16) 84 | self.fc2 = nn.Linear(in_features=16, out_features=256) 85 | self.fc3 = nn.Linear(in_features=256*2, out_features=16) 86 | self.fc4 = nn.Linear(in_features=16, out_features=256) 87 | self.relu = nn.ReLU(inplace=True) 88 | self.sigmoid = nn.Sigmoid() 89 | 90 | def forward(self, exemplar, query): 91 | 92 | # spatial co-attention 93 | fea_size = query.size()[2:] 94 | all_dim = fea_size[0] * fea_size[1] 95 | exemplar_flat = exemplar.view(-1, query.size()[1], all_dim) # N,C,H*W 96 | query_flat = query.view(-1, query.size()[1], all_dim) 97 | exemplar_t = torch.transpose(exemplar_flat, 1, 2).contiguous() # batch size x dim x num 98 | exemplar_corr = self.linear_e(exemplar_t) # 99 | A = torch.bmm(exemplar_corr, query_flat) 100 | A1 = F.softmax(A.clone(), dim=1) # 101 | B = F.softmax(torch.transpose(A, 1, 2), dim=1) 102 | query_att = torch.bmm(exemplar_flat, A1).contiguous() 103 | exemplar_att = torch.bmm(query_flat, B).contiguous() 104 | input1_att = exemplar_att.view(-1, query.size()[1], fea_size[0], fea_size[1]) 105 | input2_att = query_att.view(-1, query.size()[1], fea_size[0], fea_size[1]) 106 | 107 | # spacial attention 108 | input1_mask = self.gate1(torch.cat([input1_att, input2_att], dim=1)) 109 | input2_mask = self.gate2(torch.cat([input1_att, input2_att], dim=1)) 110 | input1_mask = self.gate_s(input1_mask) 111 | input2_mask = self.gate_s(input2_mask) 112 | 113 | # channel attention 114 | out_e = self.globalAvgPool(torch.cat([input1_att, input2_att], dim=1)) 115 | out_e = out_e.view(out_e.size(0), -1) 116 | out_e = self.fc1(out_e) 117 | out_e = self.relu(out_e) 118 | out_e = self.fc2(out_e) 119 | out_e = self.sigmoid(out_e) 120 | out_e = out_e.view(out_e.size(0), out_e.size(1), 
1, 1) 121 | out_q = self.globalAvgPool(torch.cat([input1_att, input2_att], dim=1)) 122 | out_q = out_q.view(out_q.size(0), -1) 123 | out_q = self.fc3(out_q) 124 | out_q = self.relu(out_q) 125 | out_q = self.fc4(out_q) 126 | out_q = self.sigmoid(out_q) 127 | out_q = out_q.view(out_q.size(0), out_q.size(1), 1, 1) 128 | 129 | # apply dual attention masks 130 | input1_att = input1_att * input1_mask 131 | input2_att = input2_att * input2_mask 132 | input2_att = out_e * input2_att 133 | input1_att = out_q * input1_att 134 | 135 | # concate original feature 136 | input1_att = torch.cat([input1_att, exemplar], 1) 137 | input2_att = torch.cat([input2_att, query], 1) 138 | input1_att = self.conv1(input1_att) 139 | input2_att = self.conv2(input2_att) 140 | input1_att = self.bn1(input1_att) 141 | input2_att = self.bn2(input2_att) 142 | input1_att = self.prelu(input1_att) 143 | input2_att = self.prelu(input2_att) 144 | 145 | return input1_att, input2_att # shape: NxCx 146 | 147 | def initialize_weights(*models): 148 | for model in models: 149 | for module in model.modules(): 150 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 151 | nn.init.kaiming_normal_(module.weight) 152 | if module.bias is not None: 153 | module.bias.data.zero_() 154 | elif isinstance(module, nn.BatchNorm2d): 155 | module.weight.data.fill_(1) 156 | module.bias.data.zero_() 157 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | import torch 5 | from torch import nn 6 | from torch import optim 7 | from torch.backends import cudnn 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from tqdm import tqdm 11 | 12 | import joint_transforms 13 | from config import ViSha_training_root 14 | from dataset.VShadow_crosspairwise import CrossPairwiseImg 15 | from misc import AvgMeter, check_mkdir 16 | from networks.TVSD import TVSD 17 | from torch.optim.lr_scheduler import StepLR 18 | import math 19 | from losses import lovasz_hinge, binary_xloss 20 | import random 21 | import torch.nn.functional as F 22 | import numpy as np 23 | from apex import amp 24 | import time 25 | 26 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 27 | cudnn.deterministic = True 28 | cudnn.benchmark = False 29 | 30 | ckpt_path = './models' 31 | exp_name = 'TVSD' 32 | 33 | args = { 34 | 'max_epoch': 12, 35 | 'train_batch_size': 5, 36 | 'last_iter': 0, 37 | 'finetune_lr': 5e-5, 38 | 'scratch_lr': 5e-4, 39 | 'weight_decay': 5e-4, 40 | 'momentum': 0.9, 41 | 'snapshot': '', 42 | 'scale': 416, 43 | 'multi-scale': None, 44 | 'gpu': '0,1', 45 | 'multi-GPUs': False, 46 | 'fp16': True, 47 | 'warm_up_epochs': 3, 48 | 'seed': 2020 49 | } 50 | # fix random seed 51 | np.random.seed(args['seed']) 52 | torch.manual_seed(args['seed']) 53 | torch.cuda.manual_seed(args['seed']) 54 | 55 | # multi-GPUs training 56 | if args['multi-GPUs']: 57 | os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu'] 58 | batch_size = args['train_batch_size'] * len(args['gpu'].split(',')) 59 | # single-GPU training 60 | else: 61 | torch.cuda.set_device(0) 62 | batch_size = args['train_batch_size'] 63 | 64 | joint_transform = joint_transforms.Compose([ 65 | joint_transforms.RandomHorizontallyFlip(), 66 | joint_transforms.Resize((args['scale'], args['scale'])) 67 | ]) 68 | val_joint_transform = joint_transforms.Compose([ 69 | joint_transforms.Resize((args['scale'], args['scale'])) 70 | ]) 71 | 
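# Note: the joint transforms above act on (image, mask) PIL pairs so the random flip
# stays consistent between a frame and its label, while `img_transform` below only
# normalizes the RGB input with ImageNet statistics and `target_transform` converts
# the mask to a tensor.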
img_transform = transforms.Compose([ 72 | transforms.ToTensor(), 73 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 74 | ]) 75 | target_transform = transforms.ToTensor() 76 | to_pil = transforms.ToPILImage() 77 | 78 | print('=====>Dataset loading<======') 79 | training_root = [ViSha_training_root] # training_root should be a list form, like [datasetA, datasetB, datasetC], here we use only one dataset. 80 | train_set = CrossPairwiseImg(training_root, joint_transform, img_transform, target_transform) 81 | train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=8, shuffle=True) 82 | print("max epoch:{}".format(args['max_epoch'])) 83 | 84 | ce_loss = nn.CrossEntropyLoss() 85 | 86 | log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt') 87 | 88 | def main(): 89 | print('=====>Prepare Network {}<======'.format(exp_name)) 90 | # multi-GPUs training 91 | if args['multi-GPUs']: 92 | net = torch.nn.DataParallel(TVSD()).cuda().train() 93 | params = [ 94 | {"params": net.module.encoder.parameters(), "lr": args['finetune_lr']}, 95 | {"params": net.module.co_attention.parameters(), "lr": args['scratch_lr']}, 96 | {"params": net.module.encoder.final_pre.parameters(), "lr": args['scratch_lr']}, 97 | {"params": net.module.co_attention.parameters(), "lr": args['scratch_lr']}, 98 | {"params": net.module.project.parameters(), "lr": args['scratch_lr']}, 99 | {"params": net.module.final_pre.parameters(), "lr": args['scratch_lr']} 100 | ] 101 | # single-GPU training 102 | else: 103 | net = TVSD().cuda().train() 104 | params = [ 105 | {"params": net.encoder.backbone.parameters(), "lr": args['finetune_lr']}, 106 | {"params": net.encoder.aspp.parameters(), "lr": args['scratch_lr']}, 107 | {"params": net.encoder.final_pre.parameters(), "lr": args['scratch_lr']}, 108 | {"params": net.co_attention.parameters(), "lr": args['scratch_lr']}, 109 | {"params": net.project.parameters(), "lr": args['scratch_lr']}, 110 | {"params": net.final_pre.parameters(), "lr": args['scratch_lr']} 111 | ] 112 | 113 | # optimizer = optim.SGD(params, momentum=args['momentum'], weight_decay=args['weight_decay']) 114 | optimizer = optim.Adam(params, betas=(0.9, 0.99), eps=6e-8, weight_decay=args['weight_decay']) 115 | warm_up_with_cosine_lr = lambda epoch: epoch / args['warm_up_epochs'] if epoch <= args['warm_up_epochs'] else 0.5 * \ 116 | (math.cos((epoch-args['warm_up_epochs'])/(args['max_epoch']-args['warm_up_epochs'])*math.pi)+1) 117 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_cosine_lr) 118 | # scheduler = StepLR(optimizer, step_size=10, gamma=0.1) # change learning rate after 20000 iters 119 | 120 | check_mkdir(ckpt_path) 121 | check_mkdir(os.path.join(ckpt_path, exp_name)) 122 | open(log_path, 'w').write(str(args) + '\n\n') 123 | if args['fp16']: 124 | net, optimizer = amp.initialize(net, optimizer, opt_level="O1") 125 | train(net, optimizer, scheduler) 126 | 127 | 128 | def train(net, optimizer, scheduler): 129 | curr_epoch = 1 130 | curr_iter = 1 131 | start = 0 132 | print('=====>Start training<======') 133 | while True: 134 | loss_record1, loss_record2, loss_record3, loss_record4 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter() 135 | loss_record5, loss_record6, loss_record7, loss_record8 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter() 136 | 137 | for i, sample in enumerate(tqdm(train_loader, desc=f'Epoch: {curr_epoch}', ncols=100, ascii=' =', bar_format='{l_bar}{bar}|')): 138 | 139 | exemplar, exemplar_gt, query, query_gt = 
sample['exemplar'].cuda(), sample['exemplar_gt'].cuda(), sample['query'].cuda(), sample['query_gt'].cuda() 140 | other, other_gt = sample['other'].cuda(), sample['other_gt'].cuda() 141 | 142 | optimizer.zero_grad() 143 | 144 | exemplar_pre, query_pre, other_pre, scene_logits = net(exemplar, query, other) 145 | 146 | bce_loss1 = binary_xloss(exemplar_pre, exemplar_gt) 147 | bce_loss2 = binary_xloss(query_pre, query_gt) 148 | bce_loss3 = binary_xloss(other_pre, other_gt) 149 | 150 | loss_hinge1 = lovasz_hinge(exemplar_pre, exemplar_gt) 151 | loss_hinge2 = lovasz_hinge(query_pre, query_gt) 152 | loss_hinge3 = lovasz_hinge(other_pre, other_gt) 153 | 154 | loss_seg = bce_loss1 + bce_loss2 + bce_loss3 + loss_hinge1 + loss_hinge2 + loss_hinge3 155 | # classification loss 156 | scene_labels = torch.zeros(scene_logits.shape[0], dtype=torch.long).cuda() 157 | cla_loss = ce_loss(scene_logits, scene_labels) * 10 158 | loss = loss_seg + cla_loss 159 | 160 | if args['fp16']: 161 | with amp.scale_loss(loss, optimizer) as scaled_loss: 162 | scaled_loss.backward() 163 | else: 164 | loss.backward() 165 | 166 | torch.nn.utils.clip_grad_norm_(net.parameters(), 12) # gradient clip 167 | optimizer.step() # update parameters 168 | 169 | loss_record1.update(bce_loss1.item(), batch_size) 170 | loss_record2.update(bce_loss2.item(), batch_size) 171 | loss_record3.update(bce_loss3.item(), batch_size) 172 | loss_record4.update(loss_hinge1.item(), batch_size) 173 | loss_record5.update(loss_hinge2.item(), batch_size) 174 | loss_record6.update(loss_hinge3.item(), batch_size) 175 | loss_record7.update(cla_loss.item(), batch_size) 176 | 177 | curr_iter += 1 178 | 179 | log = "iter: %d, bce1: %.5f, bce2: %.5f, bce3: %.5f, hinge1: %.5f, hinge2: %.5f, hinge3: %.5f, cla: %.5f, lr: %.8f"%\ 180 | (curr_iter, loss_record1.avg, loss_record2.avg, loss_record3.avg, loss_record4.avg, loss_record5.avg, 181 | loss_record6.avg, loss_record7.avg, scheduler.get_lr()[0]) 182 | 183 | if (curr_iter-1) % 20 == 0: 184 | elapsed = (time.time() - start) 185 | start = time.time() 186 | log_time = log + ' [time {}]'.format(elapsed) 187 | print(log_time) 188 | open(log_path, 'a').write(log + '\n') 189 | 190 | if curr_epoch % 1 == 0: 191 | if args['multi-GPUs']: 192 | # torch.save(net.module.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_epoch)) 193 | checkpoint = { 194 | 'model': net.module.state_dict(), 195 | 'optimizer': optimizer.state_dict(), 196 | 'amp': amp.state_dict() 197 | } 198 | torch.save(checkpoint, os.path.join(ckpt_path, exp_name, f'{curr_epoch}.pth')) 199 | else: 200 | # torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_epoch)) 201 | checkpoint = { 202 | 'model': net.state_dict(), 203 | 'optimizer': optimizer.state_dict(), 204 | 'amp': amp.state_dict() 205 | } 206 | torch.save(checkpoint, os.path.join(ckpt_path, exp_name, f'{curr_epoch}.pth')) 207 | if curr_epoch > args['max_epoch']: 208 | # torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter)) 209 | return 210 | 211 | curr_epoch += 1 212 | scheduler.step() # update learning rate once per epoch 213 | 214 | 215 | if __name__ == '__main__': 216 | main() -------------------------------------------------------------------------------- /networks/resnext_modify/resnext_101_32x4d_.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class LambdaBase(nn.Sequential): 7 | def __init__(self, fn, *args): 8 | 
super(LambdaBase, self).__init__(*args) 9 | self.lambda_func = fn 10 | 11 | def forward_prepare(self, input): 12 | output = [] 13 | for module in self._modules.values(): 14 | output.append(module(input)) 15 | return output if output else input 16 | 17 | 18 | class Lambda(LambdaBase): 19 | def forward(self, input): 20 | return self.lambda_func(self.forward_prepare(input)) 21 | 22 | 23 | class LambdaMap(LambdaBase): 24 | def forward(self, input): 25 | return list(map(self.lambda_func, self.forward_prepare(input))) 26 | 27 | 28 | class LambdaReduce(LambdaBase): 29 | def forward(self, input): 30 | return reduce(self.lambda_func, self.forward_prepare(input)) 31 | 32 | 33 | resnext_101_32x4d = nn.Sequential( # Sequential, 34 | nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False), 35 | nn.BatchNorm2d(64), 36 | nn.ReLU(), 37 | nn.MaxPool2d((3, 3), (2, 2), (1, 1)), 38 | nn.Sequential( # Sequential, 39 | nn.Sequential( # Sequential, 40 | LambdaMap(lambda x: x, # ConcatTable, 41 | nn.Sequential( # Sequential, 42 | nn.Sequential( # Sequential, 43 | nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 44 | nn.BatchNorm2d(128), 45 | nn.ReLU(), 46 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 47 | nn.BatchNorm2d(128), 48 | nn.ReLU(), 49 | ), 50 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 51 | nn.BatchNorm2d(256), 52 | ), 53 | nn.Sequential( # Sequential, 54 | nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 55 | nn.BatchNorm2d(256), 56 | ), 57 | ), 58 | LambdaReduce(lambda x, y: x + y), # CAddTable, 59 | nn.ReLU(), 60 | ), 61 | nn.Sequential( # Sequential, 62 | LambdaMap(lambda x: x, # ConcatTable, 63 | nn.Sequential( # Sequential, 64 | nn.Sequential( # Sequential, 65 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 66 | nn.BatchNorm2d(128), 67 | nn.ReLU(), 68 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 69 | nn.BatchNorm2d(128), 70 | nn.ReLU(), 71 | ), 72 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 73 | nn.BatchNorm2d(256), 74 | ), 75 | Lambda(lambda x: x), # Identity, 76 | ), 77 | LambdaReduce(lambda x, y: x + y), # CAddTable, 78 | nn.ReLU(), 79 | ), 80 | nn.Sequential( # Sequential, 81 | LambdaMap(lambda x: x, # ConcatTable, 82 | nn.Sequential( # Sequential, 83 | nn.Sequential( # Sequential, 84 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 85 | nn.BatchNorm2d(128), 86 | nn.ReLU(), 87 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 88 | nn.BatchNorm2d(128), 89 | nn.ReLU(), 90 | ), 91 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 92 | nn.BatchNorm2d(256), 93 | ), 94 | Lambda(lambda x: x), # Identity, 95 | ), 96 | LambdaReduce(lambda x, y: x + y), # CAddTable, 97 | nn.ReLU(), 98 | ), 99 | ), 100 | nn.Sequential( # Sequential, 101 | nn.Sequential( # Sequential, 102 | LambdaMap(lambda x: x, # ConcatTable, 103 | nn.Sequential( # Sequential, 104 | nn.Sequential( # Sequential, 105 | nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 106 | nn.BatchNorm2d(256), 107 | nn.ReLU(), 108 | nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 109 | nn.BatchNorm2d(256), 110 | nn.ReLU(), 111 | ), 112 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 113 | nn.BatchNorm2d(512), 114 | ), 115 | nn.Sequential( # Sequential, 116 | nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 117 | nn.BatchNorm2d(512), 118 | ), 119 | ), 120 | LambdaReduce(lambda x, y: x + y), # CAddTable, 121 | nn.ReLU(), 122 
| ), 123 | nn.Sequential( # Sequential, 124 | LambdaMap(lambda x: x, # ConcatTable, 125 | nn.Sequential( # Sequential, 126 | nn.Sequential( # Sequential, 127 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 128 | nn.BatchNorm2d(256), 129 | nn.ReLU(), 130 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 131 | nn.BatchNorm2d(256), 132 | nn.ReLU(), 133 | ), 134 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 135 | nn.BatchNorm2d(512), 136 | ), 137 | Lambda(lambda x: x), # Identity, 138 | ), 139 | LambdaReduce(lambda x, y: x + y), # CAddTable, 140 | nn.ReLU(), 141 | ), 142 | nn.Sequential( # Sequential, 143 | LambdaMap(lambda x: x, # ConcatTable, 144 | nn.Sequential( # Sequential, 145 | nn.Sequential( # Sequential, 146 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 147 | nn.BatchNorm2d(256), 148 | nn.ReLU(), 149 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 150 | nn.BatchNorm2d(256), 151 | nn.ReLU(), 152 | ), 153 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 154 | nn.BatchNorm2d(512), 155 | ), 156 | Lambda(lambda x: x), # Identity, 157 | ), 158 | LambdaReduce(lambda x, y: x + y), # CAddTable, 159 | nn.ReLU(), 160 | ), 161 | nn.Sequential( # Sequential, 162 | LambdaMap(lambda x: x, # ConcatTable, 163 | nn.Sequential( # Sequential, 164 | nn.Sequential( # Sequential, 165 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 166 | nn.BatchNorm2d(256), 167 | nn.ReLU(), 168 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 169 | nn.BatchNorm2d(256), 170 | nn.ReLU(), 171 | ), 172 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 173 | nn.BatchNorm2d(512), 174 | ), 175 | Lambda(lambda x: x), # Identity, 176 | ), 177 | LambdaReduce(lambda x, y: x + y), # CAddTable, 178 | nn.ReLU(), 179 | ), 180 | ), 181 | nn.Sequential( # Sequential, 182 | nn.Sequential( # Sequential, 183 | LambdaMap(lambda x: x, # ConcatTable, 184 | nn.Sequential( # Sequential, 185 | nn.Sequential( # Sequential, 186 | nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 187 | nn.BatchNorm2d(512), 188 | nn.ReLU(), 189 | nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 190 | nn.BatchNorm2d(512), 191 | nn.ReLU(), 192 | ), 193 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 194 | nn.BatchNorm2d(1024), 195 | ), 196 | nn.Sequential( # Sequential, 197 | nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 198 | nn.BatchNorm2d(1024), 199 | ), 200 | ), 201 | LambdaReduce(lambda x, y: x + y), # CAddTable, 202 | nn.ReLU(), 203 | ), 204 | nn.Sequential( # Sequential, 205 | LambdaMap(lambda x: x, # ConcatTable, 206 | nn.Sequential( # Sequential, 207 | nn.Sequential( # Sequential, 208 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 209 | nn.BatchNorm2d(512), 210 | nn.ReLU(), 211 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 212 | nn.BatchNorm2d(512), 213 | nn.ReLU(), 214 | ), 215 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 216 | nn.BatchNorm2d(1024), 217 | ), 218 | Lambda(lambda x: x), # Identity, 219 | ), 220 | LambdaReduce(lambda x, y: x + y), # CAddTable, 221 | nn.ReLU(), 222 | ), 223 | nn.Sequential( # Sequential, 224 | LambdaMap(lambda x: x, # ConcatTable, 225 | nn.Sequential( # Sequential, 226 | nn.Sequential( # Sequential, 227 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 228 | nn.BatchNorm2d(512), 229 | nn.ReLU(), 230 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 
1, 32, bias=False), 231 | nn.BatchNorm2d(512), 232 | nn.ReLU(), 233 | ), 234 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 235 | nn.BatchNorm2d(1024), 236 | ), 237 | Lambda(lambda x: x), # Identity, 238 | ), 239 | LambdaReduce(lambda x, y: x + y), # CAddTable, 240 | nn.ReLU(), 241 | ), 242 | nn.Sequential( # Sequential, 243 | LambdaMap(lambda x: x, # ConcatTable, 244 | nn.Sequential( # Sequential, 245 | nn.Sequential( # Sequential, 246 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 247 | nn.BatchNorm2d(512), 248 | nn.ReLU(), 249 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 250 | nn.BatchNorm2d(512), 251 | nn.ReLU(), 252 | ), 253 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 254 | nn.BatchNorm2d(1024), 255 | ), 256 | Lambda(lambda x: x), # Identity, 257 | ), 258 | LambdaReduce(lambda x, y: x + y), # CAddTable, 259 | nn.ReLU(), 260 | ), 261 | nn.Sequential( # Sequential, 262 | LambdaMap(lambda x: x, # ConcatTable, 263 | nn.Sequential( # Sequential, 264 | nn.Sequential( # Sequential, 265 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 266 | nn.BatchNorm2d(512), 267 | nn.ReLU(), 268 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 269 | nn.BatchNorm2d(512), 270 | nn.ReLU(), 271 | ), 272 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 273 | nn.BatchNorm2d(1024), 274 | ), 275 | Lambda(lambda x: x), # Identity, 276 | ), 277 | LambdaReduce(lambda x, y: x + y), # CAddTable, 278 | nn.ReLU(), 279 | ), 280 | nn.Sequential( # Sequential, 281 | LambdaMap(lambda x: x, # ConcatTable, 282 | nn.Sequential( # Sequential, 283 | nn.Sequential( # Sequential, 284 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 285 | nn.BatchNorm2d(512), 286 | nn.ReLU(), 287 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 288 | nn.BatchNorm2d(512), 289 | nn.ReLU(), 290 | ), 291 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 292 | nn.BatchNorm2d(1024), 293 | ), 294 | Lambda(lambda x: x), # Identity, 295 | ), 296 | LambdaReduce(lambda x, y: x + y), # CAddTable, 297 | nn.ReLU(), 298 | ), 299 | nn.Sequential( # Sequential, 300 | LambdaMap(lambda x: x, # ConcatTable, 301 | nn.Sequential( # Sequential, 302 | nn.Sequential( # Sequential, 303 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 304 | nn.BatchNorm2d(512), 305 | nn.ReLU(), 306 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 307 | nn.BatchNorm2d(512), 308 | nn.ReLU(), 309 | ), 310 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 311 | nn.BatchNorm2d(1024), 312 | ), 313 | Lambda(lambda x: x), # Identity, 314 | ), 315 | LambdaReduce(lambda x, y: x + y), # CAddTable, 316 | nn.ReLU(), 317 | ), 318 | nn.Sequential( # Sequential, 319 | LambdaMap(lambda x: x, # ConcatTable, 320 | nn.Sequential( # Sequential, 321 | nn.Sequential( # Sequential, 322 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 323 | nn.BatchNorm2d(512), 324 | nn.ReLU(), 325 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 326 | nn.BatchNorm2d(512), 327 | nn.ReLU(), 328 | ), 329 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 330 | nn.BatchNorm2d(1024), 331 | ), 332 | Lambda(lambda x: x), # Identity, 333 | ), 334 | LambdaReduce(lambda x, y: x + y), # CAddTable, 335 | nn.ReLU(), 336 | ), 337 | nn.Sequential( # Sequential, 338 | LambdaMap(lambda x: x, # ConcatTable, 339 | nn.Sequential( # Sequential, 340 | nn.Sequential( # 
Sequential, 341 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 342 | nn.BatchNorm2d(512), 343 | nn.ReLU(), 344 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 345 | nn.BatchNorm2d(512), 346 | nn.ReLU(), 347 | ), 348 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 349 | nn.BatchNorm2d(1024), 350 | ), 351 | Lambda(lambda x: x), # Identity, 352 | ), 353 | LambdaReduce(lambda x, y: x + y), # CAddTable, 354 | nn.ReLU(), 355 | ), 356 | nn.Sequential( # Sequential, 357 | LambdaMap(lambda x: x, # ConcatTable, 358 | nn.Sequential( # Sequential, 359 | nn.Sequential( # Sequential, 360 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 361 | nn.BatchNorm2d(512), 362 | nn.ReLU(), 363 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 364 | nn.BatchNorm2d(512), 365 | nn.ReLU(), 366 | ), 367 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 368 | nn.BatchNorm2d(1024), 369 | ), 370 | Lambda(lambda x: x), # Identity, 371 | ), 372 | LambdaReduce(lambda x, y: x + y), # CAddTable, 373 | nn.ReLU(), 374 | ), 375 | nn.Sequential( # Sequential, 376 | LambdaMap(lambda x: x, # ConcatTable, 377 | nn.Sequential( # Sequential, 378 | nn.Sequential( # Sequential, 379 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 380 | nn.BatchNorm2d(512), 381 | nn.ReLU(), 382 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 383 | nn.BatchNorm2d(512), 384 | nn.ReLU(), 385 | ), 386 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 387 | nn.BatchNorm2d(1024), 388 | ), 389 | Lambda(lambda x: x), # Identity, 390 | ), 391 | LambdaReduce(lambda x, y: x + y), # CAddTable, 392 | nn.ReLU(), 393 | ), 394 | nn.Sequential( # Sequential, 395 | LambdaMap(lambda x: x, # ConcatTable, 396 | nn.Sequential( # Sequential, 397 | nn.Sequential( # Sequential, 398 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 399 | nn.BatchNorm2d(512), 400 | nn.ReLU(), 401 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 402 | nn.BatchNorm2d(512), 403 | nn.ReLU(), 404 | ), 405 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 406 | nn.BatchNorm2d(1024), 407 | ), 408 | Lambda(lambda x: x), # Identity, 409 | ), 410 | LambdaReduce(lambda x, y: x + y), # CAddTable, 411 | nn.ReLU(), 412 | ), 413 | nn.Sequential( # Sequential, 414 | LambdaMap(lambda x: x, # ConcatTable, 415 | nn.Sequential( # Sequential, 416 | nn.Sequential( # Sequential, 417 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 418 | nn.BatchNorm2d(512), 419 | nn.ReLU(), 420 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 421 | nn.BatchNorm2d(512), 422 | nn.ReLU(), 423 | ), 424 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 425 | nn.BatchNorm2d(1024), 426 | ), 427 | Lambda(lambda x: x), # Identity, 428 | ), 429 | LambdaReduce(lambda x, y: x + y), # CAddTable, 430 | nn.ReLU(), 431 | ), 432 | nn.Sequential( # Sequential, 433 | LambdaMap(lambda x: x, # ConcatTable, 434 | nn.Sequential( # Sequential, 435 | nn.Sequential( # Sequential, 436 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 437 | nn.BatchNorm2d(512), 438 | nn.ReLU(), 439 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 440 | nn.BatchNorm2d(512), 441 | nn.ReLU(), 442 | ), 443 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 444 | nn.BatchNorm2d(1024), 445 | ), 446 | Lambda(lambda x: x), # Identity, 447 | ), 448 | LambdaReduce(lambda x, y: x + y), # 
CAddTable, 449 | nn.ReLU(), 450 | ), 451 | nn.Sequential( # Sequential, 452 | LambdaMap(lambda x: x, # ConcatTable, 453 | nn.Sequential( # Sequential, 454 | nn.Sequential( # Sequential, 455 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 456 | nn.BatchNorm2d(512), 457 | nn.ReLU(), 458 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 459 | nn.BatchNorm2d(512), 460 | nn.ReLU(), 461 | ), 462 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 463 | nn.BatchNorm2d(1024), 464 | ), 465 | Lambda(lambda x: x), # Identity, 466 | ), 467 | LambdaReduce(lambda x, y: x + y), # CAddTable, 468 | nn.ReLU(), 469 | ), 470 | nn.Sequential( # Sequential, 471 | LambdaMap(lambda x: x, # ConcatTable, 472 | nn.Sequential( # Sequential, 473 | nn.Sequential( # Sequential, 474 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 475 | nn.BatchNorm2d(512), 476 | nn.ReLU(), 477 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 478 | nn.BatchNorm2d(512), 479 | nn.ReLU(), 480 | ), 481 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 482 | nn.BatchNorm2d(1024), 483 | ), 484 | Lambda(lambda x: x), # Identity, 485 | ), 486 | LambdaReduce(lambda x, y: x + y), # CAddTable, 487 | nn.ReLU(), 488 | ), 489 | nn.Sequential( # Sequential, 490 | LambdaMap(lambda x: x, # ConcatTable, 491 | nn.Sequential( # Sequential, 492 | nn.Sequential( # Sequential, 493 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 494 | nn.BatchNorm2d(512), 495 | nn.ReLU(), 496 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 497 | nn.BatchNorm2d(512), 498 | nn.ReLU(), 499 | ), 500 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 501 | nn.BatchNorm2d(1024), 502 | ), 503 | Lambda(lambda x: x), # Identity, 504 | ), 505 | LambdaReduce(lambda x, y: x + y), # CAddTable, 506 | nn.ReLU(), 507 | ), 508 | nn.Sequential( # Sequential, 509 | LambdaMap(lambda x: x, # ConcatTable, 510 | nn.Sequential( # Sequential, 511 | nn.Sequential( # Sequential, 512 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 513 | nn.BatchNorm2d(512), 514 | nn.ReLU(), 515 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 516 | nn.BatchNorm2d(512), 517 | nn.ReLU(), 518 | ), 519 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 520 | nn.BatchNorm2d(1024), 521 | ), 522 | Lambda(lambda x: x), # Identity, 523 | ), 524 | LambdaReduce(lambda x, y: x + y), # CAddTable, 525 | nn.ReLU(), 526 | ), 527 | nn.Sequential( # Sequential, 528 | LambdaMap(lambda x: x, # ConcatTable, 529 | nn.Sequential( # Sequential, 530 | nn.Sequential( # Sequential, 531 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 532 | nn.BatchNorm2d(512), 533 | nn.ReLU(), 534 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 535 | nn.BatchNorm2d(512), 536 | nn.ReLU(), 537 | ), 538 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 539 | nn.BatchNorm2d(1024), 540 | ), 541 | Lambda(lambda x: x), # Identity, 542 | ), 543 | LambdaReduce(lambda x, y: x + y), # CAddTable, 544 | nn.ReLU(), 545 | ), 546 | nn.Sequential( # Sequential, 547 | LambdaMap(lambda x: x, # ConcatTable, 548 | nn.Sequential( # Sequential, 549 | nn.Sequential( # Sequential, 550 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 551 | nn.BatchNorm2d(512), 552 | nn.ReLU(), 553 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 554 | nn.BatchNorm2d(512), 555 | nn.ReLU(), 556 | ), 557 | nn.Conv2d(512, 1024, (1, 
1), (1, 1), (0, 0), 1, 1, bias=False), 558 | nn.BatchNorm2d(1024), 559 | ), 560 | Lambda(lambda x: x), # Identity, 561 | ), 562 | LambdaReduce(lambda x, y: x + y), # CAddTable, 563 | nn.ReLU(), 564 | ), 565 | nn.Sequential( # Sequential, 566 | LambdaMap(lambda x: x, # ConcatTable, 567 | nn.Sequential( # Sequential, 568 | nn.Sequential( # Sequential, 569 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 570 | nn.BatchNorm2d(512), 571 | nn.ReLU(), 572 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 573 | nn.BatchNorm2d(512), 574 | nn.ReLU(), 575 | ), 576 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 577 | nn.BatchNorm2d(1024), 578 | ), 579 | Lambda(lambda x: x), # Identity, 580 | ), 581 | LambdaReduce(lambda x, y: x + y), # CAddTable, 582 | nn.ReLU(), 583 | ), 584 | nn.Sequential( # Sequential, 585 | LambdaMap(lambda x: x, # ConcatTable, 586 | nn.Sequential( # Sequential, 587 | nn.Sequential( # Sequential, 588 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 589 | nn.BatchNorm2d(512), 590 | nn.ReLU(), 591 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 592 | nn.BatchNorm2d(512), 593 | nn.ReLU(), 594 | ), 595 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 596 | nn.BatchNorm2d(1024), 597 | ), 598 | Lambda(lambda x: x), # Identity, 599 | ), 600 | LambdaReduce(lambda x, y: x + y), # CAddTable, 601 | nn.ReLU(), 602 | ), 603 | nn.Sequential( # Sequential, 604 | LambdaMap(lambda x: x, # ConcatTable, 605 | nn.Sequential( # Sequential, 606 | nn.Sequential( # Sequential, 607 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 608 | nn.BatchNorm2d(512), 609 | nn.ReLU(), 610 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 611 | nn.BatchNorm2d(512), 612 | nn.ReLU(), 613 | ), 614 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 615 | nn.BatchNorm2d(1024), 616 | ), 617 | Lambda(lambda x: x), # Identity, 618 | ), 619 | LambdaReduce(lambda x, y: x + y), # CAddTable, 620 | nn.ReLU(), 621 | ), 622 | ), 623 | nn.Sequential( # Sequential, 624 | nn.Sequential( # Sequential, 625 | LambdaMap(lambda x: x, # ConcatTable, 626 | nn.Sequential( # Sequential, 627 | nn.Sequential( # Sequential, 628 | nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 629 | nn.BatchNorm2d(1024), 630 | nn.ReLU(), 631 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (2, 2), (2, 2), 32, bias=False), 632 | nn.BatchNorm2d(1024), 633 | nn.ReLU(), 634 | ), 635 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 636 | nn.BatchNorm2d(2048), 637 | ), 638 | nn.Sequential( # Sequential, 639 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 640 | nn.BatchNorm2d(2048), 641 | ), 642 | ), 643 | LambdaReduce(lambda x, y: x + y), # CAddTable, 644 | nn.ReLU(), 645 | ), 646 | nn.Sequential( # Sequential, 647 | LambdaMap(lambda x: x, # ConcatTable, 648 | nn.Sequential( # Sequential, 649 | nn.Sequential( # Sequential, 650 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 651 | nn.BatchNorm2d(1024), 652 | nn.ReLU(), 653 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (2, 2), (2, 2), 32, bias=False), 654 | nn.BatchNorm2d(1024), 655 | nn.ReLU(), 656 | ), 657 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 658 | nn.BatchNorm2d(2048), 659 | ), 660 | Lambda(lambda x: x), # Identity, 661 | ), 662 | LambdaReduce(lambda x, y: x + y), # CAddTable, 663 | nn.ReLU(), 664 | ), 665 | nn.Sequential( # Sequential, 666 | LambdaMap(lambda x: x, # ConcatTable, 
667 | nn.Sequential( # Sequential, 668 | nn.Sequential( # Sequential, 669 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 670 | nn.BatchNorm2d(1024), 671 | nn.ReLU(), 672 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (2, 2), (2, 2), 32, bias=False), 673 | nn.BatchNorm2d(1024), 674 | nn.ReLU(), 675 | ), 676 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 677 | nn.BatchNorm2d(2048), 678 | ), 679 | Lambda(lambda x: x), # Identity, 680 | ), 681 | LambdaReduce(lambda x, y: x + y), # CAddTable, 682 | nn.ReLU(), 683 | ), 684 | ), 685 | nn.AvgPool2d((7, 7), (1, 1)), 686 | Lambda(lambda x: x.view(x.size(0), -1)), # View, 687 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(2048, 1000)), # Linear, 688 | ) 689 | --------------------------------------------------------------------------------
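Below is a minimal, hypothetical smoke test (not part of the repository) for the resnext_101_32x4d graph defined above. It assumes a local PyTorch install and that it is run from the repository root; it keeps only the convolutional stages and drops the AvgPool2d/Linear head, since the dilated last stage leaves the feature map at 1/16 of the input resolution.

import torch
from torch import nn

from networks.resnext_modify.resnext_101_32x4d_ import resnext_101_32x4d

# Keep children [:8]: the stem conv/bn/relu, the max-pool, and the four residual
# stages; the 7x7 average pool and 1000-way classifier at the end are skipped.
backbone = nn.Sequential(*list(resnext_101_32x4d.children())[:8]).eval()
with torch.no_grad():
    feat = backbone(torch.randn(1, 3, 224, 224))
print(feat.shape)  # expected: torch.Size([1, 2048, 14, 14]) for a 224x224 input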