├── networks
│   ├── resnext_modify
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── resnext101_regular.py
│   │   ├── resnext101_5out.py
│   │   └── resnext_101_32x4d_.py
│   ├── DeepLabV3.py
│   └── TVSD.py
├── joint_transforms.py
├── config.py
├── utils
│   ├── IRNN_Forward_cuda.cu
│   └── IRNN_Backward_cuda.cu
├── misc.py
├── evaluate.py
├── infer.py
├── losses.py
├── README.md
├── dataset
│   └── VShadow_crosspairwise.py
└── train.py
/networks/resnext_modify/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnext101_regular import ResNeXt101
--------------------------------------------------------------------------------
/networks/resnext_modify/config.py:
--------------------------------------------------------------------------------
1 | resnext_101_32_path = '/home/chenzhihao/shadow_detection/shadow-MT/backbone_pth/resnext_101_32x4d.pth'
2 |
--------------------------------------------------------------------------------
/networks/resnext_modify/resnext101_regular.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from .config import resnext_101_32_path
5 | from .resnext_101_32x4d_ import resnext_101_32x4d
6 | import torch._utils
7 |
8 |
9 | class ResNeXt101(nn.Module):
10 | def __init__(self, pretrained=True):
11 | super(ResNeXt101, self).__init__()
12 | net = resnext_101_32x4d
13 | if pretrained:
14 | net.load_state_dict(torch.load(resnext_101_32_path))
15 | net = list(net.children())
16 | self.layer0 = nn.Sequential(*net[:3])
17 | self.layer1 = nn.Sequential(*net[3: 5])
18 | self.layer2 = net[5]
19 | self.layer3 = net[6]
20 | self.layer4 = net[7]
21 |
22 | def forward(self, x):
23 | layer0 = self.layer0(x)
24 | layer1 = self.layer1(layer0)
25 | layer2 = self.layer2(layer1)
26 | layer3 = self.layer3(layer2)
27 | layer4 = self.layer4(layer3)
28 | return layer4
29 |
30 |
--------------------------------------------------------------------------------
/networks/resnext_modify/resnext101_5out.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from .resnext_101_32x4d_ import resnext_101_32x4d
5 | from .config import resnext_101_32_path
6 |
7 | class ResNeXt101(nn.Module):
8 | def __init__(self):
9 | super(ResNeXt101, self).__init__()
10 | net = resnext_101_32x4d
11 | net.load_state_dict(torch.load(resnext_101_32_path))
12 | net = list(net.children())
13 | self.layer0 = nn.Sequential(*net[:3])
14 | self.layer1 = nn.Sequential(*net[3: 5])
15 | self.layer2 = net[5]
16 | self.layer3 = net[6]
17 | self.layer4 = net[7]
18 |
19 | def forward(self, x):
20 | layers = []
21 | layer0 = self.layer0(x)
22 | layers.append(layer0)
23 | layer1 = self.layer1(layer0)
24 | layers.append(layer1)
25 | layer2 = self.layer2(layer1)
26 | layers.append(layer2)
27 | layer3 = self.layer3(layer2)
28 | layers.append(layer3)
29 | layer4 = self.layer4(layer3)
30 | layers.append(layer4)
31 | return layers
32 |
--------------------------------------------------------------------------------
/joint_transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from PIL import Image
4 |
5 |
6 | class Compose(object):
7 | def __init__(self, transforms):
8 | self.transforms = transforms
9 |
10 | def __call__(self, img, mask, manual_random=None):
11 | assert img.size == mask.size
12 | for t in self.transforms:
13 | img, mask = t(img, mask, manual_random)
14 | return img, mask
15 |
16 |
17 | class RandomHorizontallyFlip(object):
18 | def __call__(self, img, mask, manual_random=None):
19 | if manual_random is None:
20 | if random.random() < 0.5:
21 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT)
22 | return img, mask
23 | else:
24 | if manual_random < 0.5:
25 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT)
26 | return img, mask
27 |
28 |
29 | class Resize(object):
30 | def __init__(self, size):
31 | self.size = tuple(reversed(size)) # size: (h, w)
32 |
33 | def __call__(self, img, mask, manual_random=None):
34 | assert img.size == mask.size
35 | return img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST)
36 |
--------------------------------------------------------------------------------
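The `manual_random` argument is what lets the dataset apply the *same* random flip to an exemplar/query pair (see `VShadow_crosspairwise.py`, which draws one `random.random()` per sample and passes it to the joint transform for both frames). A minimal usage sketch, not part of the repository:

```python
import random
from PIL import Image
import joint_transforms

t = joint_transforms.Compose([
    joint_transforms.RandomHorizontallyFlip(),
    joint_transforms.Resize((416, 416)),   # (h, w)
])

img = Image.new('RGB', (640, 480))         # stand-ins for two frames of one clip
mask = Image.new('L', (640, 480))

r = random.random()                        # one shared draw ...
img_a, mask_a = t(img, mask, r)            # ... so both calls make the same
img_b, mask_b = t(img, mask, r)            # flip decision (r < 0.5 flips both)
```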
/config.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | '''
3 | Dataset root form: ($path, $type, $name)
4 | path: dataset path
5 | type: "image" or "video"
6 | name: just for easy mark
7 | '''
8 | ### Saliency datasets
9 | # DUT_OMRON_training_root = ('/home/ext/chenzhihao/Datasets/saliency_dataset/DUT-OMRON', 'image', 'DUT-OMRON')
10 | # MSRA10K_training_root = ('/home/ext/chenzhihao/Datasets/saliency_dataset/MSRA10K', 'image', 'MSRA10K')
11 | # DAVIS_training_root = ('/home/ext/chenzhihao/Datasets/saliency_dataset/DAVIS_train', 'video', 'DAVIS_train')
12 | # DAVIS_validation_root = ('/home/ext/chenzhihao/Datasets/saliency_dataset/DAVIS_val', 'video', 'DAVIS_val')
13 |
14 | # Shadow datasets
15 | # SBU_training_root = ('/home/ext/chenzhihao/Datasets/SBU-shadow/SBUTrain4KRecoveredSmall', 'image', 'SBU_train')
16 | # SBU_testing_root = ('/home/ext/chenzhihao/Datasets/SBU-shadow/SBU-Test', 'image', 'SBU_test')
17 | ViSha_training_root = ('/home/ext/chenzhihao/Datasets/ViSha/train', 'video', 'ViSD_train')
18 | ViSha_validation_root = ('/home/ext/chenzhihao/Datasets/ViSha/test', 'video', 'ViSD_test')
19 |
20 |
21 | '''
22 | Pretrained single model path
23 | '''
24 | # PDBM_single_path = '/home/ext/chenzhihao/code/video_shadow/models_saliency/PDBM_single_256/50000.pth'
25 | # DeepLabV3_path = '/home/ext/chenzhihao/code/video_shadow/models/deeplabv3/20.pth'
26 | # FPN_path = '/home/ext/chenzhihao/code/video_shadow/models/FPN/20.pth'
--------------------------------------------------------------------------------
/utils/IRNN_Forward_cuda.cu:
--------------------------------------------------------------------------------
1 | #define CUDA_KERNEL_LOOP(i, n) \
2 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
3 | i < (n); \
4 | i += blockDim.x * gridDim.x)
5 |
6 | #define INDEX(b,c,h,w,channels,height,width) ((b * channels + c) * height + h) * width+ w
7 |
8 | extern "C" __global__ void IRNNForward(
9 | const float* input_feature,
10 |
11 | const float* weight_up,
12 | const float* weight_right,
13 | const float* weight_down,
14 | const float* weight_left,
15 |
16 | const float* bias_up,
17 | const float* bias_right,
18 | const float* bias_down,
19 | const float* bias_left,
20 |
21 | float* output_up,
22 | float* output_right,
23 | float* output_down,
24 | float* output_left,
25 |
26 | const int channels,
27 | const int height,
28 | const int width,
29 | const int n){
30 |
31 | CUDA_KERNEL_LOOP(index,n){
32 | int w = index % width;
33 | int h = index / width % height;
34 | int c = index / width / height % channels;
35 | int b = index / width / height / channels;
36 |
37 | float temp = 0;
38 |
39 | // left
40 | output_left[index] = input_feature[INDEX(b, c, h, width-1, channels, height, width)] > 0 ? input_feature[INDEX(b, c, h, width-1, channels, height, width)] : 0;
41 | for (int i = width-2; i>=w; i--)
42 | {
43 | temp = output_left[index] * weight_left[c] + bias_left[c] + input_feature[INDEX(b, c, h, i, channels, height, width)];
44 | output_left[index] = (temp > 0)? temp : 0;
45 | }
46 |
47 | // right
48 | output_right[index] = input_feature[INDEX(b, c, h, 0, channels, height, width)] > 0 ? input_feature[INDEX(b, c, h, 0, channels, height, width)] : 0;
49 | for (int i = 1; i <= w; i++)
50 | {
51 | temp = output_right[index] * weight_right[c] + bias_right[c] + input_feature[INDEX(b, c, h, i, channels, height, width)];
52 | output_right[index] = (temp > 0)? temp : 0;
53 | }
54 |
55 | // up
56 | output_up[index] = input_feature[INDEX(b,c,height-1,w,channels,height,width)] > 0 ? input_feature[INDEX(b,c,height-1,w,channels,height,width)] : 0;
57 | for (int i = height-2; i >= h; i--)
58 | {
59 | temp = output_up[index] * weight_up[c] + bias_up[c] + input_feature[INDEX(b, c, i, w, channels, height, width)];
60 | output_up[index] = (temp > 0)? temp : 0;
61 | }
62 |
63 | // down
64 | output_down[index] = input_feature[INDEX(b, c, 0, w, channels, height, width)] > 0 ? input_feature[INDEX(b, c, 0, w, channels, height, width)] : 0;
65 | for (int i = 1; i <= h; i++)
66 | {
67 | temp = output_down[index] * weight_down[c] + bias_down[c] + input_feature[INDEX(b, c, i, w, channels, height, width)];
68 | output_down[index] = (temp > 0)? temp : 0;
69 | }
70 | }
71 | }
--------------------------------------------------------------------------------
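For reference, the kernel above recomputes, for every output pixel, a per-row (or per-column) ReLU recurrence starting from the image border. Below is a NumPy sketch of what the `// left` branch computes (my own illustration, not code from the repository; the other three directions are symmetric):

```python
import numpy as np

def irnn_left_reference(x, weight_left, bias_left):
    """NumPy sketch of the "left" branch of IRNNForward.

    x: (B, C, H, W) input feature map
    weight_left, bias_left: per-channel recurrent weight and bias, shape (C,)
    out[..., w] equals the hidden state after scanning from the rightmost
    column down to column w, i.e. information flows right-to-left.
    """
    B, C, H, W = x.shape
    out = np.empty_like(x)
    h = np.maximum(x[..., W - 1], 0)              # border column: plain ReLU
    out[..., W - 1] = h
    for i in range(W - 2, -1, -1):                # same recurrence the kernel unrolls per pixel
        h = np.maximum(h * weight_left[None, :, None]
                       + bias_left[None, :, None] + x[..., i], 0)
        out[..., i] = h
    return out
```

The CUDA kernel reruns the scan independently for every output pixel, while this sketch reuses the running hidden state along each row; the results are identical.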
/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from medpy import metric
4 | import torch
5 |
6 |
7 | class AvgMeter(object):
8 | def __init__(self):
9 | self.reset()
10 |
11 | def reset(self):
12 | self.val = 0
13 | self.avg = 0
14 | self.sum = 0
15 | self.count = 0
16 |
17 | def update(self, val, n=1):
18 | self.val = val
19 | self.sum += val * n
20 | self.count += n
21 | self.avg = self.sum / self.count
22 |
23 |
24 | def check_mkdir(dir_name):
25 | if not os.path.exists(dir_name):
26 | os.makedirs(dir_name)
27 |
28 |
29 | def _sigmoid(x):
30 | return 1 / (1 + np.exp(-x))
31 |
32 | def cal_precision_recall_mae(prediction, gt):
33 | # input should be np array with data type uint8
34 | assert prediction.dtype == np.uint8
35 | assert gt.dtype == np.uint8
36 | assert prediction.shape == gt.shape
37 |
38 | eps = 1e-4
39 |
40 | prediction = prediction / 255.
41 | gt = gt / 255.
42 |
43 | mae = np.mean(np.abs(prediction - gt))
44 |
45 | hard_gt = np.zeros(prediction.shape)
46 | hard_gt[gt > 0.5] = 1
47 | t = np.sum(hard_gt)
48 |
49 | precision, recall = [], []
50 | # calculating precision and recall at 256 different binarizing thresholds
51 | for threshold in range(256):
52 | threshold = threshold / 255.
53 |
54 | hard_prediction = np.zeros(prediction.shape)
55 | hard_prediction[prediction > threshold] = 1
56 |
57 | tp = np.sum(hard_prediction * hard_gt)
58 | p = np.sum(hard_prediction)
59 |
60 | precision.append((tp + eps) / (p + eps))
61 | recall.append((tp + eps) / (t + eps))
62 |
63 | return precision, recall, mae
64 |
65 |
66 | def cal_fmeasure(precision, recall):
67 | assert len(precision) == 256
68 | assert len(recall) == 256
69 | beta_square = 0.3
70 | max_fmeasure = max([(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)])
71 |
72 | return max_fmeasure
73 |
74 | def cal_Jaccard(prediction, gt):
75 | # input should be np array with data type uint8
76 | assert prediction.dtype == np.uint8
77 | assert gt.dtype == np.uint8
78 | assert prediction.shape == gt.shape
79 |
80 | prediction = prediction / 255.
81 | gt = gt / 255.
82 |
83 | pred = (prediction > 0.5)
84 | gt = (gt > 0.5)
85 | Jaccard = metric.binary.jc(pred, gt)
86 |
87 | return Jaccard
88 |
89 | def cal_BER(prediction, label, thr = 127.5):
90 | prediction = (prediction > thr)
91 | label = (label > thr)
92 | prediction_tmp = prediction.astype(np.float64)  # np.float alias is removed in newer NumPy
93 | label_tmp = label.astype(np.float64)
94 | TP = np.sum(prediction_tmp * label_tmp)
95 | TN = np.sum((1 - prediction_tmp) * (1 - label_tmp))
96 | Np = np.sum(label_tmp)
97 | Nn = np.sum((1-label_tmp))
98 | BER = 0.5 * (2 - TP / Np - TN / Nn) * 100
99 | shadow_BER = (1 - TP / Np) * 100
100 | non_shadow_BER = (1 - TN / Nn) * 100
101 |
102 | return BER, shadow_BER, non_shadow_BER
103 |
--------------------------------------------------------------------------------
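A small usage sketch for the metric helpers above (my own example with dummy masks, not part of the repository; note that importing `misc` requires `medpy` to be installed since it is imported at module level):

```python
import numpy as np
from misc import AvgMeter, cal_BER, cal_precision_recall_mae, cal_fmeasure

# Dummy 4x4 uint8 masks standing in for a prediction and its ground truth.
pred = np.tile(np.array([0, 255, 255, 0], dtype=np.uint8), (4, 1))
gt = np.tile(np.array([0, 255, 0, 0], dtype=np.uint8), (4, 1))

precision, recall, mae = cal_precision_recall_mae(pred, gt)
fbeta = cal_fmeasure(precision, recall)

ber_meter = AvgMeter()
BER, shadow_BER, non_shadow_BER = cal_BER(pred, gt)   # binarized at 127.5
ber_meter.update(BER)
print(mae, fbeta, ber_meter.avg)
```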
/evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from PIL import Image
4 | from misc import check_mkdir, cal_precision_recall_mae, AvgMeter, cal_fmeasure, cal_Jaccard, cal_BER
5 | from tqdm import tqdm
6 | import argparse
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--models', type=str, default='TVSD', help='model name')
10 | parser.add_argument('--snapshot', type=str, default='12', help='snapshot (checkpoint) to evaluate')
11 | tmp_args = parser.parse_args()
12 |
13 | root_path = f'/home/ext/chenzhihao/code/video_shadow/models/{tmp_args.models}/predict_{tmp_args.snapshot}'
14 | save_path = f'/home/ext/chenzhihao/code/video_shadow/models/{tmp_args.models}/predict_fuse_{tmp_args.snapshot}'
15 |
16 | gt_path = '/home/ext/chenzhihao/Datasets/ViSha/test/labels'
17 | input_path = '/home/ext/chenzhihao/Datasets/ViSha/test/images'
18 |
19 | precision_record, recall_record, = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
20 | mae_record = AvgMeter()
21 | Jaccard_record = AvgMeter()
22 | BER_record = AvgMeter()
23 | shadow_BER_record = AvgMeter()
24 | non_shadow_BER_record = AvgMeter()
25 |
26 | video_list = os.listdir(root_path)
27 | for video in tqdm(video_list):
28 | gt_list = os.listdir(os.path.join(gt_path, video))
29 | img_list = [f for f in os.listdir(os.path.join(root_path, video)) if f.split('_', 1)[0]+'.png' in gt_list] # include overlap images
30 | img_set = list(set([img.split('_', 1)[0] for img in img_list])) # remove repeat
31 | for img_prefix in img_set:
32 | # jump exist images
33 | check_mkdir(os.path.join(save_path, video))
34 | save_name = os.path.join(save_path, video, '{}.png'.format(img_prefix))
35 | # if not os.path.exists(os.path.join(save_path, video, save_name)):
36 | imgs = [img for img in img_list if img.split('_', 1)[0] == img_prefix] # imgs waited for fuse
37 | fuse = []
38 | for img_path in imgs:
39 | img = np.array(Image.open(os.path.join(root_path, video, img_path)).convert('L')).astype(np.float32)
40 | # if np.max(img) > 0: # normalize prediction mask
41 | # img = (img - np.min(img)) / (np.max(img) - np.min(img)) * 255
42 | fuse.append(img)
43 | fuse = (sum(fuse) / len(imgs)).astype(np.uint8)
44 | # save image
45 | print(f'Save:{save_name}')
46 | Image.fromarray(fuse).save(save_name)
47 | # else:
48 | # print(f'Exist:{save_name}')
49 | # fuse = np.array(Image.open(save_name).convert('L')).astype(np.uint8)
50 | # calculate metric
51 | gt = np.array(Image.open(os.path.join(gt_path, video, img_prefix+'.png')))
52 | precision, recall, mae = cal_precision_recall_mae(fuse, gt)
53 | Jaccard = cal_Jaccard(fuse, gt)
54 | Jaccard_record.update(Jaccard)
55 | BER, shadow_BER, non_shadow_BER = cal_BER(fuse, gt)
56 | BER_record.update(BER)
57 | shadow_BER_record.update(shadow_BER)
58 | non_shadow_BER_record.update(non_shadow_BER)
59 | for pidx, pdata in enumerate(zip(precision, recall)):
60 | p, r = pdata
61 | precision_record[pidx].update(p)
62 | recall_record[pidx].update(r)
63 | mae_record.update(mae)
64 |
65 | fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
66 | [rrecord.avg for rrecord in recall_record])
67 | log = 'MAE:{}, F-beta:{}, Jaccard:{}, BER:{}, SBER:{}, non-SBER:{}'.format(mae_record.avg, fmeasure, Jaccard_record.avg, BER_record.avg, shadow_BER_record.avg, non_shadow_BER_record.avg)
68 | print(log)
69 |
70 |
71 |
--------------------------------------------------------------------------------
/networks/DeepLabV3.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from .resnext_modify import resnext101_regular
5 |
6 | # DeepLabV3+-style encoder: returns a low-level feature, the ASPP feature and a coarse prediction
7 | class DeepLabV3(nn.Module):
8 | def __init__(self, num_classes=1, pretrained=True):
9 | super(DeepLabV3, self).__init__()
10 | self.backbone = resnext101_regular.ResNeXt101(pretrained)
11 | aspp_dilate = [12, 24, 36]
12 | # aspp_dilate = [6, 12, 18]
13 | self.aspp = ASPP(2048, aspp_dilate)
14 | self.final_pre = nn.Conv2d(256, 1, 1)
15 | initialize_weights(self.aspp, self.final_pre)
16 |
17 | def forward(self, x):
18 | # x shape: B, C, H, W
19 | x_size = x.size()
20 | x0 = self.backbone.layer0(x)
21 | x1 = self.backbone.layer1(x0) # x1 is lower feature
22 | x2 = self.backbone.layer2(x1)
23 | x3 = self.backbone.layer3(x2)
24 | x4 = self.backbone.layer4(x3)
25 | fea = self.aspp(x4)
26 | predict = self.final_pre(fea)
27 | predict = F.interpolate(predict, size=x_size[2:], mode='bilinear', align_corners=False)
28 | return x1, fea, predict
29 |
30 | class ASPP(nn.Module):
31 | def __init__(self, in_channels, atrous_rates):
32 | super(ASPP, self).__init__()
33 | out_channels = 256
34 | modules = []
35 | modules.append(nn.Sequential(
36 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
37 | nn.BatchNorm2d(out_channels),
38 | nn.ReLU(inplace=True)))
39 |
40 | rate1, rate2, rate3 = tuple(atrous_rates)
41 | modules.append(ASPPConv(in_channels, out_channels, rate1))
42 | modules.append(ASPPConv(in_channels, out_channels, rate2))
43 | modules.append(ASPPConv(in_channels, out_channels, rate3))
44 | modules.append(ASPPPooling(in_channels, out_channels))
45 |
46 | self.convs = nn.ModuleList(modules)
47 |
48 | self.project = nn.Sequential(
49 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
50 | nn.BatchNorm2d(out_channels),
51 | nn.ReLU(inplace=True),
52 | nn.Dropout(0.1))
53 |
54 | def forward(self, x):
55 | res = []
56 | for conv in self.convs:
57 | res.append(conv(x))
58 | res = torch.cat(res, dim=1)
59 | return self.project(res)
60 |
61 | class ASPPConv(nn.Sequential):
62 | def __init__(self, in_channels, out_channels, dilation):
63 | modules = [
64 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
65 | nn.BatchNorm2d(out_channels),
66 | nn.ReLU(inplace=True)
67 | ]
68 | super(ASPPConv, self).__init__(*modules)
69 |
70 | class ASPPPooling(nn.Sequential):
71 | def __init__(self, in_channels, out_channels):
72 | super(ASPPPooling, self).__init__(
73 | nn.AdaptiveAvgPool2d(1),
74 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
75 | nn.BatchNorm2d(out_channels),
76 | nn.ReLU(inplace=True))
77 |
78 | def forward(self, x):
79 | size = x.shape[-2:]
80 | x = super(ASPPPooling, self).forward(x)
81 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
82 |
83 | def initialize_weights(*models):
84 | for model in models:
85 | for module in model.modules():
86 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
87 | nn.init.kaiming_normal_(module.weight)
88 | if module.bias is not None:
89 | module.bias.data.zero_()
90 | elif isinstance(module, nn.BatchNorm2d):
91 | module.weight.data.fill_(1)
92 | module.bias.data.zero_()
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 | import torch
5 | from PIL import Image
6 | from torchvision import transforms
7 |
8 | from config import ViSha_validation_root
9 | from misc import check_mkdir
10 | from networks.TVSD import TVSD
11 | import argparse
12 | from tqdm import tqdm
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--snapshot', type=str, default='2', help='snapshot')
16 | parser.add_argument('--models', type=str, default='TVSD', help='model name')
17 | tmp_args = parser.parse_args()
18 |
19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
20 |
21 | ckpt_path = './models'
22 | exp_name = tmp_args.models
23 | args = {
24 | 'snapshot': tmp_args.snapshot,
25 | 'scale': 416,
26 | 'test_adjacent': 5,
27 | 'input_folder': 'images',
28 | 'label_folder': 'labels'
29 | }
30 |
31 | img_transform = transforms.Compose([
32 | transforms.Resize((args['scale'], args['scale'])),
33 | transforms.ToTensor(),
34 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
35 | ])
36 | target_transform = transforms.ToTensor()
37 |
38 | root = ViSha_validation_root[0]
39 |
40 | to_pil = transforms.ToPILImage()
41 |
42 |
43 | def main():
44 | net = TVSD().cuda()
45 |
46 | if len(args['snapshot']) > 0:
47 | check_point = torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'))
48 | net.load_state_dict(check_point['model'])
49 |
50 | net.eval()
51 | with torch.no_grad():
52 | video_list = os.listdir(os.path.join(root, args['input_folder']))
53 | for video in tqdm(video_list):
54 | # all images
55 | img_list = [os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, args['input_folder'], video)) if
56 | f.endswith('.jpg')]
57 | # need evaluation images
58 | img_eval_list = [os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, args['label_folder'], video)) if
59 | f.endswith('.png')]
60 |
61 | img_eval_list = sortImg(img_eval_list)
62 | for exemplar_idx, exemplar_name in enumerate(img_eval_list):
63 | query_idx_list = getAdjacentIndex(exemplar_idx, 0, len(img_list), args['test_adjacent'])
64 | for query_idx in query_idx_list:
65 | exemplar = Image.open(os.path.join(root, args['input_folder'], video, exemplar_name + '.jpg')).convert('RGB')
66 | w, h = exemplar.size
67 | query = Image.open(os.path.join(root, args['input_folder'], video, img_list[query_idx] + '.jpg')).convert('RGB')
68 | exemplar_tensor = img_transform(exemplar).unsqueeze(0).cuda()
69 | query_tensor = img_transform(query).unsqueeze(0).cuda()
70 | exemplar_pre, _, _, _ = net(exemplar_tensor, query_tensor, query_tensor)
71 | res = (exemplar_pre.data > 0).to(torch.float32)
72 | prediction = np.array(
73 | transforms.Resize((h, w))(to_pil(res.squeeze(0).cpu())))
74 | check_mkdir(os.path.join(ckpt_path, exp_name, "predict_" + args['snapshot'], video))
75 | # saved as e.g. 00000001_by1.png, 00000001_by2.png (exemplar frame name + index of the paired query frame)
76 | save_name = f"{exemplar_name}_by{query_idx}.png"
77 | print(os.path.join(ckpt_path, exp_name, "predict_" + args['snapshot'], video, save_name))
78 | Image.fromarray(prediction).save(
79 | os.path.join(ckpt_path, exp_name, "predict_" + args['snapshot'], video, save_name))
80 |
81 |
82 | def sortImg(img_list):
83 | img_int_list = [int(f) for f in img_list]
84 | sort_index = [i for i, v in sorted(enumerate(img_int_list), key=lambda x: x[1])] # sort img to 001,002,003...
85 | return [img_list[i] for i in sort_index]
86 |
87 |
88 | def getAdjacentIndex(current_index, start_index, video_length, adjacent_length):
89 | if current_index + adjacent_length < start_index + video_length:
90 | query_index_list = [current_index+i+1 for i in range(adjacent_length)]
91 | else:
92 | query_index_list = [current_index-i-1 for i in range(adjacent_length)]
93 | return query_index_list
94 |
95 | if __name__ == '__main__':
96 | main()
97 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 | from torch.autograd import Variable
6 | try:
7 | from itertools import ifilterfalse
8 | except ImportError: # py3k
9 | from itertools import filterfalse as ifilterfalse
10 |
11 | def bce2d_new(input, target, reduction='mean'):
12 | assert(input.size() == target.size())
13 | pos = torch.eq(target, 1).float()
14 | neg = torch.eq(target, 0).float()
15 | # ing = ((torch.gt(target, 0) & torch.lt(target, 1))).float()
16 |
17 | num_pos = torch.sum(pos)
18 | num_neg = torch.sum(neg)
19 | num_total = num_pos + num_neg
20 |
21 | alpha = num_neg / num_total
22 | beta = 1.1 * num_pos / num_total
23 | # target pixel = 1 -> weight alpha (proportion of negatives)
24 | # target pixel = 0 -> weight beta (1.1 * proportion of positives)
25 | weights = alpha * pos + beta * neg
26 |
27 | return F.binary_cross_entropy_with_logits(input, target, weights, reduction=reduction)
28 |
29 | def BCE_IOU(pred, mask):
30 | weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
31 | wbce = F.binary_cross_entropy_with_logits(pred, mask)
32 | wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
33 |
34 | pred = torch.sigmoid(pred)
35 | inter = ((pred * mask)*weit).sum(dim=(2, 3))
36 | union = ((pred + mask)*weit).sum(dim=(2, 3))
37 | wiou = 1 - (inter + 1)/(union - inter+1)
38 | return wbce.mean(), wiou.mean()
39 |
40 | # --------------------------- BINARY Lovasz LOSSES ---------------------------
41 | def lovasz_hinge(logits, labels, per_image=True, ignore=None):
42 | """
43 | Binary Lovasz hinge loss
44 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
45 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
46 | per_image: compute the loss per image instead of per batch
47 | ignore: void class id
48 | """
49 | if per_image:
50 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
51 | for log, lab in zip(logits, labels))
52 | else:
53 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
54 | return loss
55 |
56 |
57 | def lovasz_hinge_flat(logits, labels):
58 | """
59 | Binary Lovasz hinge loss
60 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
61 | labels: [P] Tensor, binary ground truth labels (0 or 1)
62 | ignore: label to ignore
63 | """
64 | if len(labels) == 0:
65 | # only void pixels, the gradients should be 0
66 | return logits.sum() * 0.
67 | signs = 2. * labels.float() - 1.
68 | errors = (1. - logits * Variable(signs))
69 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
70 | perm = perm.data
71 | gt_sorted = labels[perm]
72 | grad = lovasz_grad(gt_sorted)
73 | loss = torch.dot(F.relu(errors_sorted), Variable(grad))
74 | return loss
75 |
76 |
77 | def flatten_binary_scores(scores, labels, ignore=None):
78 | """
79 | Flattens predictions in the batch (binary case)
80 | Remove labels equal to 'ignore'
81 | """
82 | scores = scores.view(-1)
83 | labels = labels.view(-1)
84 | if ignore is None:
85 | return scores, labels
86 | valid = (labels != ignore)
87 | vscores = scores[valid]
88 | vlabels = labels[valid]
89 | return vscores, vlabels
90 |
91 |
92 | class StableBCELoss(torch.nn.modules.Module):
93 | def __init__(self):
94 | super(StableBCELoss, self).__init__()
95 | def forward(self, input, target):
96 | neg_abs = - input.abs()
97 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
98 | return loss.mean()
99 |
100 |
101 | def binary_xloss(logits, labels, ignore=None):
102 | """
103 | Binary Cross entropy loss
104 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
105 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
106 | ignore: void class id
107 | """
108 | logits, labels = flatten_binary_scores(logits, labels, ignore)
109 | loss = StableBCELoss()(logits, Variable(labels.float()))
110 | return loss
111 |
112 | def lovasz_grad(gt_sorted):
113 | """
114 | Computes gradient of the Lovasz extension w.r.t sorted errors
115 | See Alg. 1 in paper
116 | """
117 | p = len(gt_sorted)
118 | gts = gt_sorted.sum()
119 | intersection = gts - gt_sorted.float().cumsum(0)
120 | union = gts + (1 - gt_sorted).float().cumsum(0)
121 | jaccard = 1. - intersection / union
122 | if p > 1: # cover 1-pixel case
123 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
124 | return jaccard
125 |
126 | def mean(l, ignore_nan=False, empty=0):
127 | """
128 | nanmean compatible with generators.
129 | """
130 | l = iter(l)
131 | if ignore_nan:
132 | l = ifilterfalse(isnan, l)
133 | try:
134 | n = 1
135 | acc = next(l)
136 | except StopIteration:
137 | if empty == 'raise':
138 | raise ValueError('Empty mean')
139 | return empty
140 | for n, v in enumerate(l, 2):
141 | acc += v
142 | if n == 1:
143 | return acc
144 | return acc / n
145 |
146 | def isnan(x):
147 | return x != x
--------------------------------------------------------------------------------
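A minimal shape/usage sketch for the two segmentation losses used in `train.py` (my own example with random tensors, not part of the repository):

```python
import torch
from losses import lovasz_hinge, binary_xloss

# Raw, unbounded logits and a binary ground-truth mask of shape [B, H, W]
# (train.py passes [B, 1, H, W] tensors, which also works since both losses
# flatten their inputs).
logits = torch.randn(2, 8, 8)
labels = (torch.rand(2, 8, 8) > 0.5).float()

seg_loss = binary_xloss(logits, labels) + lovasz_hinge(logits, labels, per_image=True)
print(seg_loss.item())
```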
/utils/IRNN_Backward_cuda.cu:
--------------------------------------------------------------------------------
1 | #define CUDA_KERNEL_LOOP(i, n) \
2 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
3 | i < (n); \
4 | i += blockDim.x * gridDim.x)
5 |
6 | #define INDEX(b,c,h,w,channels,height,width) ((b * channels + c) * height + h) * width+ w
7 |
8 |
9 | extern "C" __global__ void IRNNBackward(
10 | float* grad_input,
11 |
12 | float* grad_weight_up_map,
13 | float* grad_weight_right_map,
14 | float* grad_weight_down_map,
15 | float* grad_weight_left_map,
16 |
17 | float* grad_bias_up_map,
18 | float* grad_bias_right_map,
19 | float* grad_bias_down_map,
20 | float* grad_bias_left_map,
21 |
22 | const float* weight_up,
23 | const float* weight_right,
24 | const float* weight_down,
25 | const float* weight_left,
26 |
27 | const float* grad_output_up,
28 | const float* grad_output_right,
29 | const float* grad_output_down,
30 | const float* grad_output_left,
31 |
32 | const float* output_up,
33 | const float* output_right,
34 | const float* output_down,
35 | const float* output_left,
36 |
37 | const int channels,
38 | const int height,
39 | const int width,
40 | const int n) {
41 |
42 | CUDA_KERNEL_LOOP(index,n){
43 |
44 | int w = index % width;
45 | int h = index / width % height;
46 | int c = index / width / height % channels;
47 | int b = index / width / height / channels;
48 |
49 | float diff_left = 0;
50 | float diff_right = 0;
51 | float diff_up = 0;
52 | float diff_down = 0;
53 |
54 | //left
55 |
56 | for (int i = 0; i<=w; i++)
57 | {
58 | diff_left *= weight_left[c];
59 | diff_left += grad_output_left[INDEX(b, c, h, i, channels, height, width)];
60 | diff_left *= (output_left[INDEX(b, c, h, i, channels, height, width)]<=0)? 0 : 1;
61 | }
62 |
63 |
64 | float temp = grad_output_left[INDEX(b, c, h, 0, channels, height, width)];
65 | for (int i = 1; i < w +1 ; i++)
66 | {
67 | temp = (output_left[INDEX(b, c, h, i-1, channels, height, width)] >0?1:0) * temp * weight_left[c] + grad_output_left[INDEX(b, c, h, i, channels, height, width)];
68 | }
69 |
70 | if (w != width - 1){
71 | grad_weight_left_map[index] = temp * output_left[INDEX(b, c, h, w+1, channels, height, width)] * (output_left[index] > 0? 1:0);
72 | grad_bias_left_map[index] = diff_left;
73 | }
74 |
75 | // right
76 |
77 | for (int i = width -1; i>=w; i--)
78 | {
79 | diff_right *= weight_right[c];
80 | diff_right += grad_output_right[INDEX(b, c, h, i, channels, height, width)];
81 | diff_right *= (output_right[INDEX(b, c, h, i, channels, height, width)]<=0)? 0 : 1;
82 | }
83 |
84 |
85 | temp = grad_output_right[INDEX(b, c, h, width-1, channels, height, width)];
86 | for (int i = width -2; i > w - 1 ; i--)
87 | {
88 | temp = (output_right[INDEX(b, c, h, i+1, channels, height, width)] >0?1:0) * temp * weight_right[c] + grad_output_right[INDEX(b, c, h, i, channels, height, width)];
89 | }
90 |
91 | if (w != 0){
92 | grad_weight_right_map[index] = temp * output_right[INDEX(b, c, h, w-1, channels, height, width)] * (output_right[index] > 0? 1:0);
93 | grad_bias_right_map[index] = diff_right;
94 | }
95 |
96 | // up
97 |
98 |
99 | for (int i = 0; i<=h; i++)
100 | {
101 | diff_up *= weight_up[c];
102 | diff_up += grad_output_up[INDEX(b, c, i, w, channels, height, width)];
103 | diff_up *= (output_up[INDEX(b, c, i, w, channels, height, width)]<=0)? 0 : 1;
104 | }
105 |
106 |
107 | temp = grad_output_up[INDEX(b, c, 0, w, channels, height, width)];
108 | for (int i = 1; i < h +1 ; i++)
109 | {
110 | temp = (output_up[INDEX(b, c, i-1, w, channels, height, width)] >0?1:0) * temp * weight_up[c] + grad_output_up[INDEX(b, c, i, w, channels, height, width)];
111 | }
112 |
113 | if (h != height - 1){
114 | grad_weight_up_map[index] = temp * output_up[INDEX(b, c, h+1, w, channels, height, width)] * (output_up[index] > 0? 1:0);
115 | grad_bias_up_map[index] = diff_up;
116 | }
117 |
118 | // down
119 |
120 | for (int i = height -1; i>=h; i--)
121 | {
122 | diff_down *= weight_down[c];
123 | diff_down += grad_output_down[INDEX(b, c, i, w, channels, height, width)];
124 | diff_down *= (output_down[INDEX(b, c, i, w, channels, height, width)]<=0)? 0 : 1;
125 | }
126 |
127 |
128 | temp = grad_output_down[INDEX(b, c, height-1, w, channels, height, width)];
129 | for (int i = height -2; i > h - 1 ; i--)
130 | {
131 | temp = (output_down[INDEX(b, c, i+1, w, channels, height, width)] >0?1:0) * temp * weight_down[c] + grad_output_down[INDEX(b, c, i, w, channels, height, width)];
132 | }
133 |
134 | if (h != 0){
135 | grad_weight_down_map[index] = temp * output_down[INDEX(b, c, h-1, w, channels, height, width)] * (output_down[index] > 0? 1:0);
136 | grad_bias_down_map[index] = diff_down;
137 | }
138 | grad_input[index] = diff_down + diff_left + diff_right + diff_up;
139 | }
140 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Triple-cooperative Video Shadow Detection
2 | Code and dataset for the CVPR 2021 paper **"Triple-cooperative Video Shadow Detection"** [[arXiv link](https://arxiv.org/abs/2103.06533)] [[official link](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Triple-Cooperative_Video_Shadow_Detection_CVPR_2021_paper.pdf)].
3 | by Zhihao Chen<sup>1</sup>, Liang Wan<sup>1</sup>, Lei Zhu<sup>2</sup>, Jia Shen<sup>1</sup>, Huazhu Fu<sup>3</sup>, Wennan Liu<sup>4</sup>, and Jing Qin<sup>5</sup>
4 | <sup>1</sup>College of Intelligence and Computing, Tianjin University
5 | <sup>2</sup>Department of Applied Mathematics and Theoretical Physics, University of Cambridge
6 | <sup>3</sup>Inception Institute of Artificial Intelligence, UAE
7 | <sup>4</sup>Academy of Medical Engineering and Translational Medicine, Tianjin University
8 | <sup>5</sup>The Hong Kong Polytechnic University
9 |
10 | #### News: On 2021.4.7, we first released the code of TVSD and the ViSha dataset.
11 | #### News: On 2022.5.7, [Lihao Liu](https://github.com/lihaoliu-cambridge/video-shadow-detection) published a PyTorch Lightning implementation of TVSD.
12 |
13 | ***
14 |
15 | ## Citation
16 | @inproceedings{chen21TVSD,
17 | author = {Chen, Zhihao and Wan, Liang and Zhu, Lei and Shen, Jia and Fu, Huazhu and Liu, Wennan and Qin, Jing},
18 | title = {Triple-cooperative Video Shadow Detection},
19 | booktitle = {CVPR},
20 | year = {2021}
21 | }
22 |
23 | ## PyTorch Lightning Version
24 | A PyTorch Lightning implementation is available at [https://github.com/lihaoliu-cambridge/video-shadow-detection](https://github.com/lihaoliu-cambridge/video-shadow-detection), contributed by [Lihao Liu](https://github.com/lihaoliu-cambridge).
25 |
26 | ## Dataset
27 | The ViSha dataset is available at the **[ViSha Homepage](https://erasernut.github.io/ViSha.html)**.
28 |
29 | ## Requirements
30 | * Python 3.6
31 | * PyTorch 1.3.1
32 | * torchvision
33 | * numpy
34 | * tqdm
35 | * PIL (Pillow)
36 | * math
37 | * time
38 | * datetime
39 | * argparse
40 | * apex (optional, fp16 to save memory and speed up training)
41 |
42 | ## Training
43 | 1. Modify the data paths in ```./config.py```
44 | 2. Modify the pretrained backbone path in ```./networks/resnext_modify/config.py```
45 | 3. Run ```python train.py```; the trained model will be saved in ```./models/TVSD```
46 |
47 | The pretrained ResNeXt model is ported from the [official](https://github.com/facebookresearch/ResNeXt) Torch version,
48 | using the [converter](https://github.com/clcarwin/convert_torch_to_pytorch) provided by clcarwin.
49 | You can directly [download](https://drive.google.com/open?id=1dnH-IHwmu9xFPlyndqI6MfF4LvH6JKNQ) the pretrained model ported by us.
50 |
51 | ## Testing
52 | 1. Modify the data paths in ```./config.py```
53 | 2. Make sure you have a snapshot in ```./models/TVSD``` (Tip: you can download the trained model reported in our paper from [BaiduNetdisk](https://pan.baidu.com/s/17d-wLwA5oyafMdooJlesyw)(pw: 8p5h) or [Google Drive](https://drive.google.com/file/d/14dSMN6P7fUyL_KOubaXOAUZp_Dc0tFzq/view?usp=sharing))
54 | 3. Run ```python infer.py``` to generate predicted masks
55 | 4. Run ```python evaluate.py``` to evaluate the generated results
56 |
57 | ## Results on the ViSha testing set
58 | As mentioned in our paper, since there is no prior CNN-based method for video shadow detection, we compare against 12 state-of-the-art methods from relevant tasks: BDRAR[1], DSD[2], MTMT[3] (single-image shadow detection), FPN[4], PSPNet[5] (single-image semantic segmentation), DSS[6], R^3Net[7] (single-image saliency detection), PDBM[8], MAG[9] (video saliency detection), and COSNet[10], FEELVOS[11], STM[12] (video object segmentation).
59 | [1]L. Zhu, Z. Deng, X. Hu, C.-W. Fu, X. Xu, J. Qin, and P.-A. Heng. Bidirectional feature pyramid network with recurrent attention residual modules for shadow detection. In ECCV, pages 121–136, 2018.
60 | [2]Q. Zheng, X. Qiao, Y. Cao, and R.W. Lau. Distraction-aware shadow detection. In CVPR, pages 5167–5176, 2019.
61 | [3]Z. Chen, L. Zhu, L. Wan, S. Wang, W. Feng, and P.-A. Heng. A multi-task mean teacher for semi-supervised shadow detection. In CVPR, pages 5611–5620, 2020.
62 | [4]T.-Y. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, and S. Belongie. Feature pyramid networks for object detection. In CVPR, pages 2117–2125, 2017.
63 | [5]H. Zhao, J. Shi, X. Qi, X. Wang, and J. Jia. Pyramid scene parsing network. In CVPR, pages 2881–2890, 2017.
64 | [6]Q. Hou, M. Cheng, X. Hu, A. Borji, Z. Tu, and P. Torr. Deeply supervised salient object detection with short connections. IEEE Transactions on Pattern Analysis and Machine Intelligence, 41(4):815–828, 2019.
65 | [7]Z. Deng, X. Hu, L. Zhu, X. Xu, J. Qin, G. Han, and P.-A. Heng. R3net: Recurrent residual refinement network for saliency detection. In IJCAI, pages 684–690. AAAI Press, 2018.
66 | [8]H. Song, W. Wang, S. Zhao, J. Shen, and K.-M. Lam. Pyramid dilated deeper convlstm for video salient object detection. In ECCV, pages 715–731, 2018.
67 | [9]H. Li, G. Chen, G. Li, and Y. Yu. Motion guided attention for video salient object detection. In ICCV, pages 7274–7283, 2019.
68 | [10]X. Lu, W. Wang, C. Ma, J. Shen, L. Shao, and F. Porikli. See more, know more: Unsupervised video object segmentation with co-attention siamese networks. In CVPR, pages 3623–3632, 2019.
69 | [11]P. Voigtlaender, Y. Chai, F. Schroff, H. Adam, B. Leibe, and L.-C. Chen. Feelvos: Fast end-to-end embedding learning for video object segmentation. In CVPR, June 2019.
70 | [12]S.W. Oh, J.-Y. Lee, N. Xu, and S.J. Kim. Video object segmentation using space-time memory networks. In ICCV, pages 9226–9235, 2019.
71 |
72 | We evaluated these methods and our TVSD on the ViSha testing set and release all results at [BaiduNetdisk](https://pan.baidu.com/s/1t_PgW3JCrTGvf_PVyeR-iw)(pw: ritw) or [Google Drive](https://drive.google.com/drive/folders/13XgGxu9DDuuz2vS6ugFrLLmXRZzoHWhb?usp=sharing).
73 |
--------------------------------------------------------------------------------
/dataset/VShadow_crosspairwise.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 |
4 | import torch.utils.data as data
5 | from PIL import Image
6 | import random
7 | import torch
8 | import numpy as np
9 |
10 |
11 | # return image triple pairs in video and return single image
12 | class CrossPairwiseImg(data.Dataset):
13 | def __init__(self, root, joint_transform=None, img_transform=None, target_transform=None):
14 | self.img_root, self.video_root = self.split_root(root)
15 | self.joint_transform = joint_transform
16 | self.img_transform = img_transform
17 | self.target_transform = target_transform
18 | self.input_folder = 'images'
19 | self.label_folder = 'labels'
20 | self.img_ext = '.jpg'
21 | self.label_ext = '.png'
22 | self.num_video_frame = 0
23 | # get all frames from video datasets
24 | self.videoImg_list = self.generateImgFromVideo(self.video_root)
25 | print('Total video frames is {}.'.format(self.num_video_frame))
26 | # get all frames from image datasets
27 | if len(self.img_root) > 0:
28 | self.singleImg_list = self.generateImgFromSingle(self.img_root)
29 | print('Total single image frames is {}.'.format(len(self.singleImg_list)))
30 |
31 |
32 | def __getitem__(self, index):
33 | manual_random = random.random() # random for transformation
34 | # pair in video
35 | exemplar_path, exemplar_gt_path, videoStartIndex, videoLength = self.videoImg_list[index] # exemplar
36 | # sample from same video
37 | query_index = np.random.randint(videoStartIndex, videoStartIndex + videoLength)
38 | if query_index == index:
39 | query_index = np.random.randint(videoStartIndex, videoStartIndex + videoLength)
40 | query_path, query_gt_path, videoStartIndex2, videoLength2 = self.videoImg_list[query_index] # query
41 | if videoStartIndex != videoStartIndex2 or videoLength != videoLength2:
42 | raise TypeError('Something wrong')
43 | # sample from different video
44 | while True:
45 | other_index = np.random.randint(0, self.__len__())
46 | if other_index < videoStartIndex or other_index > videoStartIndex + videoLength - 1:
47 | break # find image from different video
48 | other_path, other_gt_path, videoStartIndex3, videoLength3 = self.videoImg_list[other_index] # other
49 | if videoStartIndex == videoStartIndex3:
50 | raise TypeError('Something wrong')
51 | # single image in image dataset
52 | if len(self.img_root) > 0:
53 | single_idx = np.random.randint(0, len(self.singleImg_list))  # sample uniformly over the whole single-image set
54 | single_image_path, single_gt_path = self.singleImg_list[single_idx] # single image
55 |
56 | # read image and gt
57 | exemplar = Image.open(exemplar_path).convert('RGB')
58 | query = Image.open(query_path).convert('RGB')
59 | other = Image.open(other_path).convert('RGB')
60 | exemplar_gt = Image.open(exemplar_gt_path).convert('L')
61 | query_gt = Image.open(query_gt_path).convert('L')
62 | other_gt = Image.open(other_gt_path).convert('L')
63 | if len(self.img_root) > 0:
64 | single_image = Image.open(single_image_path).convert('RGB')
65 | single_gt = Image.open(single_gt_path).convert('L')
66 |
67 | # transformation
68 | if self.joint_transform is not None:
69 | exemplar, exemplar_gt = self.joint_transform(exemplar, exemplar_gt, manual_random)
70 | query, query_gt = self.joint_transform(query, query_gt, manual_random)
71 | other, other_gt = self.joint_transform(other, other_gt)
72 | if len(self.img_root) > 0:
73 | single_image, single_gt = self.joint_transform(single_image, single_gt)
74 | if self.img_transform is not None:
75 | exemplar = self.img_transform(exemplar)
76 | query = self.img_transform(query)
77 | other = self.img_transform(other)
78 | if len(self.img_root) > 0:
79 | single_image = self.img_transform(single_image)
80 | if self.target_transform is not None:
81 | exemplar_gt = self.target_transform(exemplar_gt)
82 | query_gt = self.target_transform(query_gt)
83 | other_gt = self.target_transform(other_gt)
84 | if len(self.img_root) > 0:
85 | single_gt = self.target_transform(single_gt)
86 | if len(self.img_root) > 0:
87 | sample = {'exemplar': exemplar, 'exemplar_gt': exemplar_gt, 'query': query, 'query_gt': query_gt,
88 | 'other': other, 'other_gt': other_gt, 'single_image': single_image, 'single_gt': single_gt}
89 | else:
90 | sample = {'exemplar': exemplar, 'exemplar_gt': exemplar_gt, 'query': query, 'query_gt': query_gt,
91 | 'other': other, 'other_gt': other_gt}
92 | return sample
93 |
94 | def generateImgFromVideo(self, root):
95 | imgs = []
96 | root = root[0] # assume that only one video dataset
97 | video_list = os.listdir(os.path.join(root[0], self.input_folder))
98 | for video in video_list:
99 | img_list = [os.path.splitext(f)[0] for f in os.listdir(os.path.join(root[0], self.input_folder, video)) if f.endswith(self.img_ext)] # no ext
100 | img_list = self.sortImg(img_list)
101 | for img in img_list:
102 | # videoImgGt: (img, gt, video start index, video length)
103 | videoImgGt = (os.path.join(root[0], self.input_folder, video, img + self.img_ext),
104 | os.path.join(root[0], self.label_folder, video, img + self.label_ext), self.num_video_frame, len(img_list))
105 | imgs.append(videoImgGt)
106 | self.num_video_frame += len(img_list)
107 | return imgs
108 |
109 | def generateImgFromSingle(self, root):
110 | imgs = []
111 | for sub_root in root:
112 | tmp = self.generateImagePair(sub_root[0])
113 | imgs.extend(tmp) # deal with image case
114 | print('Image number of ImageSet {} is {}.'.format(sub_root[2], len(tmp)))
115 |
116 | return imgs
117 |
118 | def generateImagePair(self, root):
119 | img_list = [os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, self.input_folder)) if f.endswith(self.img_ext)]
120 | if len(img_list) == 0:
121 | raise IOError('make sure the dataset path is correct')
122 | return [(os.path.join(root, self.input_folder, img_name + self.img_ext), os.path.join(root, self.label_folder, img_name + self.label_ext))
123 | for img_name in img_list]
124 |
125 | def sortImg(self, img_list):
126 | img_int_list = [int(f) for f in img_list]
127 | sort_index = [i for i, v in sorted(enumerate(img_int_list), key=lambda x: x[1])] # sort img to 001,002,003...
128 | return [img_list[i] for i in sort_index]
129 |
130 | def split_root(self, root):
131 | if not isinstance(root, list):
132 | raise TypeError('root should be a list')
133 | img_root_list = []
134 | video_root_list = []
135 | for tmp in root:
136 | if tmp[1] == 'image':
137 | img_root_list.append(tmp)
138 | elif tmp[1] == 'video':
139 | video_root_list.append(tmp)
140 | else:
141 | raise TypeError('you should input video or image')
142 | return img_root_list, video_root_list
143 |
144 | def __len__(self):
145 | return len(self.videoImg_list)//2*2
146 |
147 |
148 |
149 |
--------------------------------------------------------------------------------
/networks/TVSD.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from .DeepLabV3 import DeepLabV3
5 |
6 | class TVSD(nn.Module):
7 | def __init__(self, pretrained_path=None, num_classes=1, all_channel=256, all_dim=26 * 26, T=0.07): # 473./8=60 416./8=52
8 | super(TVSD, self).__init__()
9 | self.encoder = DeepLabV3()
10 | self.T = T
11 | # load pretrained model from DeepLabV3 module
12 | # in our experiments, no need to pretrain the single deeplabv3
13 | if pretrained_path is not None:
14 | checkpoint = torch.load(pretrained_path)
15 | print(f"Load checkpoint:{pretrained_path}")
16 | self.encoder.load_state_dict(checkpoint['model'])
17 | self.co_attention = CoattentionModel(num_classes=num_classes, all_channel=all_channel, all_dim=all_dim)
18 | self.project = nn.Sequential(
19 | nn.Conv2d(256, 48, 1, bias=False),
20 | nn.BatchNorm2d(48),
21 | nn.ReLU(inplace=True)
22 | )
23 | self.final_pre = nn.Sequential(
24 | nn.Conv2d(304, 256, 3, padding=1, bias=False),
25 | nn.BatchNorm2d(256),
26 | nn.ReLU(inplace=True),
27 | nn.Conv2d(256, num_classes, 1)
28 | )
29 | initialize_weights(self.co_attention, self.project, self.final_pre)
30 |
31 | def forward(self, input1, input2, input3):
32 | input_size = input1.size()[2:]
33 | low_exemplar, exemplar, _ = self.encoder(input1)
34 | low_query, query, _ = self.encoder(input2)
35 | low_other, other, _ = self.encoder(input3)
36 | x1, x2 = self.co_attention(exemplar, query)
37 | x1 = F.interpolate(x1, size=low_exemplar.shape[2:], mode='bilinear', align_corners=False)
38 | x2 = F.interpolate(x2, size=low_query.shape[2:], mode='bilinear', align_corners=False)
39 | x3 = F.interpolate(other, size=low_other.shape[2:], mode='bilinear', align_corners=False)
40 | fuse_exemplar = torch.cat([x1, self.project(low_exemplar)], dim=1)
41 | fuse_query = torch.cat([x2, self.project(low_query)], dim=1)
42 | fuse_other = torch.cat([x3, self.project(low_other)], dim=1)
43 | exemplar_pre = self.final_pre(fuse_exemplar)
44 | query_pre = self.final_pre(fuse_query)
45 | other_pre = self.final_pre(fuse_other)
46 |
47 | # scene vector
48 | v1 = F.adaptive_avg_pool2d(exemplar, (1, 1)).squeeze(-1).squeeze(-1)
49 | v1 = nn.functional.normalize(v1, dim=1)
50 | v2 = F.adaptive_avg_pool2d(query, (1, 1)).squeeze(-1).squeeze(-1)
51 | v2 = nn.functional.normalize(v2, dim=1)
52 | v3 = F.adaptive_avg_pool2d(other, (1, 1)).squeeze(-1).squeeze(-1)
53 | v3 = nn.functional.normalize(v3, dim=1)
54 |
55 | l_pos = torch.einsum('nc,nc->n', [v1, v2]).unsqueeze(-1)
56 | l_neg1 = torch.einsum('nc,nc->n', [v1, v3]).unsqueeze(-1)
57 | # l_neg2 = torch.einsum('nc,nc->n', [v2, v3]).unsqueeze(-1)
58 | # logits = torch.cat([l_pos, l_neg1, l_neg2], dim=1)
59 | logits = torch.cat([l_pos, l_neg1], dim=1)
60 | logits /= self.T
61 | exemplar_pre = F.interpolate(exemplar_pre, input_size, mode='bilinear', align_corners=False) # upsample to the size of the input image, scale=8
62 | query_pre = F.interpolate(query_pre, input_size, mode='bilinear', align_corners=False) # upsample to the size of the input image, scale=8
63 | other_pre = F.interpolate(other_pre, input_size, mode='bilinear', align_corners=False) # upsample to the size of the input image, scale=8
64 | return exemplar_pre, query_pre, other_pre, logits
65 |
66 |
67 |
68 | class CoattentionModel(nn.Module): # spatial and channel attention module
69 | def __init__(self, num_classes=1, all_channel=256, all_dim=26 * 26): # 473./8=60 416./8=52
70 | super(CoattentionModel, self).__init__()
71 | self.linear_e = nn.Linear(all_channel, all_channel, bias=False)
72 | self.channel = all_channel
73 | self.dim = all_dim
74 | self.gate1 = nn.Conv2d(all_channel * 2, 1, kernel_size=1, bias=False)
75 | self.gate2 = nn.Conv2d(all_channel * 2, 1, kernel_size=1, bias=False)
76 | self.gate_s = nn.Sigmoid()
77 | self.conv1 = nn.Conv2d(all_channel * 2, all_channel, kernel_size=3, padding=1, bias=False)
78 | self.conv2 = nn.Conv2d(all_channel * 2, all_channel, kernel_size=3, padding=1, bias=False)
79 | self.bn1 = nn.BatchNorm2d(all_channel)
80 | self.bn2 = nn.BatchNorm2d(all_channel)
81 | self.prelu = nn.ReLU(inplace=True)
82 | self.globalAvgPool = nn.AvgPool2d(26, stride=1)
83 | self.fc1 = nn.Linear(in_features=256*2, out_features=16)
84 | self.fc2 = nn.Linear(in_features=16, out_features=256)
85 | self.fc3 = nn.Linear(in_features=256*2, out_features=16)
86 | self.fc4 = nn.Linear(in_features=16, out_features=256)
87 | self.relu = nn.ReLU(inplace=True)
88 | self.sigmoid = nn.Sigmoid()
89 |
90 | def forward(self, exemplar, query):
91 |
92 | # spatial co-attention
93 | fea_size = query.size()[2:]
94 | all_dim = fea_size[0] * fea_size[1]
95 | exemplar_flat = exemplar.view(-1, query.size()[1], all_dim) # N,C,H*W
96 | query_flat = query.view(-1, query.size()[1], all_dim)
97 | exemplar_t = torch.transpose(exemplar_flat, 1, 2).contiguous() # batch size x dim x num
98 | exemplar_corr = self.linear_e(exemplar_t) #
99 | A = torch.bmm(exemplar_corr, query_flat)
100 | A1 = F.softmax(A.clone(), dim=1) #
101 | B = F.softmax(torch.transpose(A, 1, 2), dim=1)
102 | query_att = torch.bmm(exemplar_flat, A1).contiguous()
103 | exemplar_att = torch.bmm(query_flat, B).contiguous()
104 | input1_att = exemplar_att.view(-1, query.size()[1], fea_size[0], fea_size[1])
105 | input2_att = query_att.view(-1, query.size()[1], fea_size[0], fea_size[1])
106 |
107 | # spatial attention
108 | input1_mask = self.gate1(torch.cat([input1_att, input2_att], dim=1))
109 | input2_mask = self.gate2(torch.cat([input1_att, input2_att], dim=1))
110 | input1_mask = self.gate_s(input1_mask)
111 | input2_mask = self.gate_s(input2_mask)
112 |
113 | # channel attention
114 | out_e = self.globalAvgPool(torch.cat([input1_att, input2_att], dim=1))
115 | out_e = out_e.view(out_e.size(0), -1)
116 | out_e = self.fc1(out_e)
117 | out_e = self.relu(out_e)
118 | out_e = self.fc2(out_e)
119 | out_e = self.sigmoid(out_e)
120 | out_e = out_e.view(out_e.size(0), out_e.size(1), 1, 1)
121 | out_q = self.globalAvgPool(torch.cat([input1_att, input2_att], dim=1))
122 | out_q = out_q.view(out_q.size(0), -1)
123 | out_q = self.fc3(out_q)
124 | out_q = self.relu(out_q)
125 | out_q = self.fc4(out_q)
126 | out_q = self.sigmoid(out_q)
127 | out_q = out_q.view(out_q.size(0), out_q.size(1), 1, 1)
128 |
129 | # apply dual attention masks
130 | input1_att = input1_att * input1_mask
131 | input2_att = input2_att * input2_mask
132 | input2_att = out_e * input2_att
133 | input1_att = out_q * input1_att
134 |
135 | # concate original feature
136 | input1_att = torch.cat([input1_att, exemplar], 1)
137 | input2_att = torch.cat([input2_att, query], 1)
138 | input1_att = self.conv1(input1_att)
139 | input2_att = self.conv2(input2_att)
140 | input1_att = self.bn1(input1_att)
141 | input2_att = self.bn2(input2_att)
142 | input1_att = self.prelu(input1_att)
143 | input2_att = self.prelu(input2_att)
144 |
145 | return input1_att, input2_att # shape: N x C x H x W
146 |
147 | def initialize_weights(*models):
148 | for model in models:
149 | for module in model.modules():
150 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
151 | nn.init.kaiming_normal_(module.weight)
152 | if module.bias is not None:
153 | module.bias.data.zero_()
154 | elif isinstance(module, nn.BatchNorm2d):
155 | module.weight.data.fill_(1)
156 | module.bias.data.zero_()
157 |
--------------------------------------------------------------------------------
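The `logits` returned by `TVSD.forward` are a 2-way classification over scene vectors: column 0 is the exemplar-query similarity (same video) and column 1 the exemplar-other similarity (different video), both divided by the temperature `T`. In `train.py` they are fed to `nn.CrossEntropyLoss` with all-zero labels, so the same-video pair is pushed to win. A minimal sketch with stand-in tensors (not part of the repository):

```python
import torch
import torch.nn as nn

T = 0.07
logits = torch.randn(4, 2) / T               # stand-in for [l_pos, l_neg1] / T
labels = torch.zeros(4, dtype=torch.long)    # class 0 = the same-video pair
cla_loss = nn.CrossEntropyLoss()(logits, labels) * 10   # weighting used in train.py
print(cla_loss.item())
```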
/train.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 |
4 | import torch
5 | from torch import nn
6 | from torch import optim
7 | from torch.backends import cudnn
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms
10 | from tqdm import tqdm
11 |
12 | import joint_transforms
13 | from config import ViSha_training_root
14 | from dataset.VShadow_crosspairwise import CrossPairwiseImg
15 | from misc import AvgMeter, check_mkdir
16 | from networks.TVSD import TVSD
17 | from torch.optim.lr_scheduler import StepLR
18 | import math
19 | from losses import lovasz_hinge, binary_xloss
20 | import random
21 | import torch.nn.functional as F
22 | import numpy as np
23 | from apex import amp
24 | import time
25 |
26 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
27 | cudnn.deterministic = True
28 | cudnn.benchmark = False
29 |
30 | ckpt_path = './models'
31 | exp_name = 'TVSD'
32 |
33 | args = {
34 | 'max_epoch': 12,
35 | 'train_batch_size': 5,
36 | 'last_iter': 0,
37 | 'finetune_lr': 5e-5,
38 | 'scratch_lr': 5e-4,
39 | 'weight_decay': 5e-4,
40 | 'momentum': 0.9,
41 | 'snapshot': '',
42 | 'scale': 416,
43 | 'multi-scale': None,
44 | 'gpu': '0,1',
45 | 'multi-GPUs': False,
46 | 'fp16': True,
47 | 'warm_up_epochs': 3,
48 | 'seed': 2020
49 | }
50 | # fix random seed
51 | np.random.seed(args['seed'])
52 | torch.manual_seed(args['seed'])
53 | torch.cuda.manual_seed(args['seed'])
54 |
55 | # multi-GPUs training
56 | if args['multi-GPUs']:
57 | os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu']
58 | batch_size = args['train_batch_size'] * len(args['gpu'].split(','))
59 | # single-GPU training
60 | else:
61 | torch.cuda.set_device(0)
62 | batch_size = args['train_batch_size']
63 |
64 | joint_transform = joint_transforms.Compose([
65 | joint_transforms.RandomHorizontallyFlip(),
66 | joint_transforms.Resize((args['scale'], args['scale']))
67 | ])
68 | val_joint_transform = joint_transforms.Compose([
69 | joint_transforms.Resize((args['scale'], args['scale']))
70 | ])
71 | img_transform = transforms.Compose([
72 | transforms.ToTensor(),
73 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
74 | ])
75 | target_transform = transforms.ToTensor()
76 | to_pil = transforms.ToPILImage()
77 |
78 | print('=====>Dataset loading<======')
79 | training_root = [ViSha_training_root] # training_root should be a list form, like [datasetA, datasetB, datasetC], here we use only one dataset.
80 | train_set = CrossPairwiseImg(training_root, joint_transform, img_transform, target_transform)
81 | train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=8, shuffle=True)
82 | print("max epoch:{}".format(args['max_epoch']))
83 |
84 | ce_loss = nn.CrossEntropyLoss()
85 |
86 | log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
87 |
88 | def main():
89 | print('=====>Prepare Network {}<======'.format(exp_name))
90 | # multi-GPUs training
91 | if args['multi-GPUs']:
92 | net = torch.nn.DataParallel(TVSD()).cuda().train()
93 | params = [
94 | {"params": net.module.encoder.parameters(), "lr": args['finetune_lr']},
95 | {"params": net.module.co_attention.parameters(), "lr": args['scratch_lr']},
96 | {"params": net.module.encoder.final_pre.parameters(), "lr": args['scratch_lr']},
97 | {"params": net.module.co_attention.parameters(), "lr": args['scratch_lr']},
98 | {"params": net.module.project.parameters(), "lr": args['scratch_lr']},
99 | {"params": net.module.final_pre.parameters(), "lr": args['scratch_lr']}
100 | ]
101 | # single-GPU training
102 | else:
103 | net = TVSD().cuda().train()
104 | params = [
105 | {"params": net.encoder.backbone.parameters(), "lr": args['finetune_lr']},
106 | {"params": net.encoder.aspp.parameters(), "lr": args['scratch_lr']},
107 | {"params": net.encoder.final_pre.parameters(), "lr": args['scratch_lr']},
108 | {"params": net.co_attention.parameters(), "lr": args['scratch_lr']},
109 | {"params": net.project.parameters(), "lr": args['scratch_lr']},
110 | {"params": net.final_pre.parameters(), "lr": args['scratch_lr']}
111 | ]
112 |
113 | # optimizer = optim.SGD(params, momentum=args['momentum'], weight_decay=args['weight_decay'])
114 | optimizer = optim.Adam(params, betas=(0.9, 0.99), eps=6e-8, weight_decay=args['weight_decay'])
115 | warm_up_with_cosine_lr = lambda epoch: (epoch + 1) / args['warm_up_epochs'] if epoch < args['warm_up_epochs'] else 0.5 * \
116 | (math.cos((epoch-args['warm_up_epochs'])/(args['max_epoch']-args['warm_up_epochs'])*math.pi)+1)
117 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_cosine_lr)  # linear warm-up, then half-cosine decay, stepped once per epoch
118 | # scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # alternative: decay the learning rate every 10 epochs
119 |
120 | check_mkdir(ckpt_path)
121 | check_mkdir(os.path.join(ckpt_path, exp_name))
122 | open(log_path, 'w').write(str(args) + '\n\n')
123 | if args['fp16']:
124 | net, optimizer = amp.initialize(net, optimizer, opt_level="O1")
125 | train(net, optimizer, scheduler)
126 |
127 |
128 | def train(net, optimizer, scheduler):
129 | curr_epoch = 1
130 | curr_iter = 1
131 | start = time.perf_counter()
132 | print('=====>Start training<======')
133 | while True:
134 | loss_record1, loss_record2, loss_record3, loss_record4 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
135 | loss_record5, loss_record6, loss_record7, loss_record8 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
136 |
137 | for i, sample in enumerate(tqdm(train_loader, desc=f'Epoch: {curr_epoch}', ncols=100, ascii=' =', bar_format='{l_bar}{bar}|')):
138 |
139 | exemplar, exemplar_gt, query, query_gt = sample['exemplar'].cuda(), sample['exemplar_gt'].cuda(), sample['query'].cuda(), sample['query_gt'].cuda()
140 | other, other_gt = sample['other'].cuda(), sample['other_gt'].cuda()
141 |
142 | optimizer.zero_grad()
143 |
144 | exemplar_pre, query_pre, other_pre, scene_logits = net(exemplar, query, other)
145 |
146 | bce_loss1 = binary_xloss(exemplar_pre, exemplar_gt)
147 | bce_loss2 = binary_xloss(query_pre, query_gt)
148 | bce_loss3 = binary_xloss(other_pre, other_gt)
149 |
150 | loss_hinge1 = lovasz_hinge(exemplar_pre, exemplar_gt)
151 | loss_hinge2 = lovasz_hinge(query_pre, query_gt)
152 | loss_hinge3 = lovasz_hinge(other_pre, other_gt)
153 |
154 | loss_seg = bce_loss1 + bce_loss2 + bce_loss3 + loss_hinge1 + loss_hinge2 + loss_hinge3
155 | # classification loss
156 | scene_labels = torch.zeros(scene_logits.shape[0], dtype=torch.long).cuda()
157 | cla_loss = ce_loss(scene_logits, scene_labels) * 10
158 | loss = loss_seg + cla_loss
159 |
160 | if args['fp16']:
161 | with amp.scale_loss(loss, optimizer) as scaled_loss:
162 | scaled_loss.backward()
163 | else:
164 | loss.backward()
165 |
166 | torch.nn.utils.clip_grad_norm_(net.parameters(), 12)  # gradient clipping; with apex fp16, clipping amp.master_params(optimizer) is the documented pattern
167 | optimizer.step()  # update parameters
168 |
169 | loss_record1.update(bce_loss1.item(), batch_size)
170 | loss_record2.update(bce_loss2.item(), batch_size)
171 | loss_record3.update(bce_loss3.item(), batch_size)
172 | loss_record4.update(loss_hinge1.item(), batch_size)
173 | loss_record5.update(loss_hinge2.item(), batch_size)
174 | loss_record6.update(loss_hinge3.item(), batch_size)
175 | loss_record7.update(cla_loss.item(), batch_size)
176 |
177 | curr_iter += 1
178 |
179 | log = "iter: %d, bce1: %.5f, bce2: %.5f, bce3: %.5f, hinge1: %.5f, hinge2: %.5f, hinge3: %.5f, cla: %.5f, lr: %.8f" % \
180 | (curr_iter, loss_record1.avg, loss_record2.avg, loss_record3.avg, loss_record4.avg, loss_record5.avg,
181 | loss_record6.avg, loss_record7.avg, scheduler.get_lr()[0])
182 |
183 | if (curr_iter-1) % 20 == 0:
184 | elapsed = (time.perf_counter() - start)  # time.clock() was removed in Python 3.8
185 | start = time.perf_counter()
186 | log_time = log + ' [time {}]'.format(elapsed)
187 | print(log_time)
188 | open(log_path, 'a').write(log + '\n')
189 |
190 | if curr_epoch % 1 == 0:  # save a checkpoint every epoch
191 | if args['multi-GPUs']:
192 | # torch.save(net.module.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_epoch))
193 | checkpoint = {
194 | 'model': net.module.state_dict(),
195 | 'optimizer': optimizer.state_dict(),
196 | 'amp': amp.state_dict()
197 | }
198 | torch.save(checkpoint, os.path.join(ckpt_path, exp_name, f'{curr_epoch}.pth'))
199 | else:
200 | # torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_epoch))
201 | checkpoint = {
202 | 'model': net.state_dict(),
203 | 'optimizer': optimizer.state_dict(),
204 | 'amp': amp.state_dict()
205 | }
206 | torch.save(checkpoint, os.path.join(ckpt_path, exp_name, f'{curr_epoch}.pth'))
207 | if curr_epoch >= args['max_epoch']:  # stop once the final epoch's checkpoint has been saved
208 | # torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
209 | return
210 |
211 | curr_epoch += 1
212 | scheduler.step() # change learning rate after epoch
213 |
214 |
215 | if __name__ == '__main__':
216 | main()
--------------------------------------------------------------------------------
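Note: the training loop above writes a checkpoint dict with 'model', 'optimizer' and 'amp' entries every epoch, but args['snapshot'] is never consumed, so train.py itself has no resume path. Below is a minimal resume sketch, assuming the same TVSD model, Adam optimizer and apex opt_level "O1" set up in main(); the helper name and the snapshot path are illustrative, not part of the repository.

# Hypothetical resume helper (not part of the repo). It follows apex's documented order:
# build the model and optimizer, call amp.initialize(), then restore all three state dicts.
import os
import torch
from apex import amp

def resume_from_snapshot(net, optimizer, snapshot_path):
    checkpoint = torch.load(snapshot_path, map_location='cuda')
    # single-GPU case; under DataParallel the dict was saved from net.module, so load into net.module instead
    net.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    amp.load_state_dict(checkpoint['amp'])
    return net, optimizer

# usage sketch, after `net, optimizer = amp.initialize(net, optimizer, opt_level="O1")`:
# net, optimizer = resume_from_snapshot(net, optimizer, os.path.join(ckpt_path, exp_name, '5.pth'))
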
/networks/resnext_modify/resnext_101_32x4d_.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 |
3 | import torch.nn as nn
4 |
5 |
6 | class LambdaBase(nn.Sequential):
7 | def __init__(self, fn, *args):
8 | super(LambdaBase, self).__init__(*args)
9 | self.lambda_func = fn
10 |
11 | def forward_prepare(self, input):
12 | output = []
13 | for module in self._modules.values():
14 | output.append(module(input))
15 | return output if output else input
16 |
17 |
18 | class Lambda(LambdaBase):
19 | def forward(self, input):
20 | return self.lambda_func(self.forward_prepare(input))
21 |
22 |
23 | class LambdaMap(LambdaBase):
24 | def forward(self, input):
25 | return list(map(self.lambda_func, self.forward_prepare(input)))
26 |
27 |
28 | class LambdaReduce(LambdaBase):
29 | def forward(self, input):
30 | return reduce(self.lambda_func, self.forward_prepare(input))
31 |
32 |
33 | resnext_101_32x4d = nn.Sequential( # Sequential,
34 | nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False),
35 | nn.BatchNorm2d(64),
36 | nn.ReLU(),
37 | nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
38 | nn.Sequential( # Sequential,
39 | nn.Sequential( # Sequential,
40 | LambdaMap(lambda x: x, # ConcatTable,
41 | nn.Sequential( # Sequential,
42 | nn.Sequential( # Sequential,
43 | nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
44 | nn.BatchNorm2d(128),
45 | nn.ReLU(),
46 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
47 | nn.BatchNorm2d(128),
48 | nn.ReLU(),
49 | ),
50 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
51 | nn.BatchNorm2d(256),
52 | ),
53 | nn.Sequential( # Sequential,
54 | nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
55 | nn.BatchNorm2d(256),
56 | ),
57 | ),
58 | LambdaReduce(lambda x, y: x + y), # CAddTable,
59 | nn.ReLU(),
60 | ),
61 | nn.Sequential( # Sequential,
62 | LambdaMap(lambda x: x, # ConcatTable,
63 | nn.Sequential( # Sequential,
64 | nn.Sequential( # Sequential,
65 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
66 | nn.BatchNorm2d(128),
67 | nn.ReLU(),
68 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
69 | nn.BatchNorm2d(128),
70 | nn.ReLU(),
71 | ),
72 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
73 | nn.BatchNorm2d(256),
74 | ),
75 | Lambda(lambda x: x), # Identity,
76 | ),
77 | LambdaReduce(lambda x, y: x + y), # CAddTable,
78 | nn.ReLU(),
79 | ),
80 | nn.Sequential( # Sequential,
81 | LambdaMap(lambda x: x, # ConcatTable,
82 | nn.Sequential( # Sequential,
83 | nn.Sequential( # Sequential,
84 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
85 | nn.BatchNorm2d(128),
86 | nn.ReLU(),
87 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
88 | nn.BatchNorm2d(128),
89 | nn.ReLU(),
90 | ),
91 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
92 | nn.BatchNorm2d(256),
93 | ),
94 | Lambda(lambda x: x), # Identity,
95 | ),
96 | LambdaReduce(lambda x, y: x + y), # CAddTable,
97 | nn.ReLU(),
98 | ),
99 | ),
100 | nn.Sequential( # Sequential,
101 | nn.Sequential( # Sequential,
102 | LambdaMap(lambda x: x, # ConcatTable,
103 | nn.Sequential( # Sequential,
104 | nn.Sequential( # Sequential,
105 | nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
106 | nn.BatchNorm2d(256),
107 | nn.ReLU(),
108 | nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 32, bias=False),
109 | nn.BatchNorm2d(256),
110 | nn.ReLU(),
111 | ),
112 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
113 | nn.BatchNorm2d(512),
114 | ),
115 | nn.Sequential( # Sequential,
116 | nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
117 | nn.BatchNorm2d(512),
118 | ),
119 | ),
120 | LambdaReduce(lambda x, y: x + y), # CAddTable,
121 | nn.ReLU(),
122 | ),
123 | nn.Sequential( # Sequential,
124 | LambdaMap(lambda x: x, # ConcatTable,
125 | nn.Sequential( # Sequential,
126 | nn.Sequential( # Sequential,
127 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
128 | nn.BatchNorm2d(256),
129 | nn.ReLU(),
130 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
131 | nn.BatchNorm2d(256),
132 | nn.ReLU(),
133 | ),
134 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
135 | nn.BatchNorm2d(512),
136 | ),
137 | Lambda(lambda x: x), # Identity,
138 | ),
139 | LambdaReduce(lambda x, y: x + y), # CAddTable,
140 | nn.ReLU(),
141 | ),
142 | nn.Sequential( # Sequential,
143 | LambdaMap(lambda x: x, # ConcatTable,
144 | nn.Sequential( # Sequential,
145 | nn.Sequential( # Sequential,
146 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
147 | nn.BatchNorm2d(256),
148 | nn.ReLU(),
149 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
150 | nn.BatchNorm2d(256),
151 | nn.ReLU(),
152 | ),
153 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
154 | nn.BatchNorm2d(512),
155 | ),
156 | Lambda(lambda x: x), # Identity,
157 | ),
158 | LambdaReduce(lambda x, y: x + y), # CAddTable,
159 | nn.ReLU(),
160 | ),
161 | nn.Sequential( # Sequential,
162 | LambdaMap(lambda x: x, # ConcatTable,
163 | nn.Sequential( # Sequential,
164 | nn.Sequential( # Sequential,
165 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
166 | nn.BatchNorm2d(256),
167 | nn.ReLU(),
168 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
169 | nn.BatchNorm2d(256),
170 | nn.ReLU(),
171 | ),
172 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
173 | nn.BatchNorm2d(512),
174 | ),
175 | Lambda(lambda x: x), # Identity,
176 | ),
177 | LambdaReduce(lambda x, y: x + y), # CAddTable,
178 | nn.ReLU(),
179 | ),
180 | ),
181 | nn.Sequential( # Sequential,
182 | nn.Sequential( # Sequential,
183 | LambdaMap(lambda x: x, # ConcatTable,
184 | nn.Sequential( # Sequential,
185 | nn.Sequential( # Sequential,
186 | nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
187 | nn.BatchNorm2d(512),
188 | nn.ReLU(),
189 | nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 32, bias=False),
190 | nn.BatchNorm2d(512),
191 | nn.ReLU(),
192 | ),
193 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
194 | nn.BatchNorm2d(1024),
195 | ),
196 | nn.Sequential( # Sequential,
197 | nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False),
198 | nn.BatchNorm2d(1024),
199 | ),
200 | ),
201 | LambdaReduce(lambda x, y: x + y), # CAddTable,
202 | nn.ReLU(),
203 | ),
204 | nn.Sequential( # Sequential,
205 | LambdaMap(lambda x: x, # ConcatTable,
206 | nn.Sequential( # Sequential,
207 | nn.Sequential( # Sequential,
208 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
209 | nn.BatchNorm2d(512),
210 | nn.ReLU(),
211 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
212 | nn.BatchNorm2d(512),
213 | nn.ReLU(),
214 | ),
215 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
216 | nn.BatchNorm2d(1024),
217 | ),
218 | Lambda(lambda x: x), # Identity,
219 | ),
220 | LambdaReduce(lambda x, y: x + y), # CAddTable,
221 | nn.ReLU(),
222 | ),
223 | nn.Sequential( # Sequential,
224 | LambdaMap(lambda x: x, # ConcatTable,
225 | nn.Sequential( # Sequential,
226 | nn.Sequential( # Sequential,
227 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
228 | nn.BatchNorm2d(512),
229 | nn.ReLU(),
230 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
231 | nn.BatchNorm2d(512),
232 | nn.ReLU(),
233 | ),
234 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
235 | nn.BatchNorm2d(1024),
236 | ),
237 | Lambda(lambda x: x), # Identity,
238 | ),
239 | LambdaReduce(lambda x, y: x + y), # CAddTable,
240 | nn.ReLU(),
241 | ),
242 | nn.Sequential( # Sequential,
243 | LambdaMap(lambda x: x, # ConcatTable,
244 | nn.Sequential( # Sequential,
245 | nn.Sequential( # Sequential,
246 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
247 | nn.BatchNorm2d(512),
248 | nn.ReLU(),
249 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
250 | nn.BatchNorm2d(512),
251 | nn.ReLU(),
252 | ),
253 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
254 | nn.BatchNorm2d(1024),
255 | ),
256 | Lambda(lambda x: x), # Identity,
257 | ),
258 | LambdaReduce(lambda x, y: x + y), # CAddTable,
259 | nn.ReLU(),
260 | ),
261 | nn.Sequential( # Sequential,
262 | LambdaMap(lambda x: x, # ConcatTable,
263 | nn.Sequential( # Sequential,
264 | nn.Sequential( # Sequential,
265 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
266 | nn.BatchNorm2d(512),
267 | nn.ReLU(),
268 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
269 | nn.BatchNorm2d(512),
270 | nn.ReLU(),
271 | ),
272 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
273 | nn.BatchNorm2d(1024),
274 | ),
275 | Lambda(lambda x: x), # Identity,
276 | ),
277 | LambdaReduce(lambda x, y: x + y), # CAddTable,
278 | nn.ReLU(),
279 | ),
280 | nn.Sequential( # Sequential,
281 | LambdaMap(lambda x: x, # ConcatTable,
282 | nn.Sequential( # Sequential,
283 | nn.Sequential( # Sequential,
284 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
285 | nn.BatchNorm2d(512),
286 | nn.ReLU(),
287 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
288 | nn.BatchNorm2d(512),
289 | nn.ReLU(),
290 | ),
291 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
292 | nn.BatchNorm2d(1024),
293 | ),
294 | Lambda(lambda x: x), # Identity,
295 | ),
296 | LambdaReduce(lambda x, y: x + y), # CAddTable,
297 | nn.ReLU(),
298 | ),
299 | nn.Sequential( # Sequential,
300 | LambdaMap(lambda x: x, # ConcatTable,
301 | nn.Sequential( # Sequential,
302 | nn.Sequential( # Sequential,
303 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
304 | nn.BatchNorm2d(512),
305 | nn.ReLU(),
306 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
307 | nn.BatchNorm2d(512),
308 | nn.ReLU(),
309 | ),
310 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
311 | nn.BatchNorm2d(1024),
312 | ),
313 | Lambda(lambda x: x), # Identity,
314 | ),
315 | LambdaReduce(lambda x, y: x + y), # CAddTable,
316 | nn.ReLU(),
317 | ),
318 | nn.Sequential( # Sequential,
319 | LambdaMap(lambda x: x, # ConcatTable,
320 | nn.Sequential( # Sequential,
321 | nn.Sequential( # Sequential,
322 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
323 | nn.BatchNorm2d(512),
324 | nn.ReLU(),
325 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
326 | nn.BatchNorm2d(512),
327 | nn.ReLU(),
328 | ),
329 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
330 | nn.BatchNorm2d(1024),
331 | ),
332 | Lambda(lambda x: x), # Identity,
333 | ),
334 | LambdaReduce(lambda x, y: x + y), # CAddTable,
335 | nn.ReLU(),
336 | ),
337 | nn.Sequential( # Sequential,
338 | LambdaMap(lambda x: x, # ConcatTable,
339 | nn.Sequential( # Sequential,
340 | nn.Sequential( # Sequential,
341 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
342 | nn.BatchNorm2d(512),
343 | nn.ReLU(),
344 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
345 | nn.BatchNorm2d(512),
346 | nn.ReLU(),
347 | ),
348 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
349 | nn.BatchNorm2d(1024),
350 | ),
351 | Lambda(lambda x: x), # Identity,
352 | ),
353 | LambdaReduce(lambda x, y: x + y), # CAddTable,
354 | nn.ReLU(),
355 | ),
356 | nn.Sequential( # Sequential,
357 | LambdaMap(lambda x: x, # ConcatTable,
358 | nn.Sequential( # Sequential,
359 | nn.Sequential( # Sequential,
360 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
361 | nn.BatchNorm2d(512),
362 | nn.ReLU(),
363 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
364 | nn.BatchNorm2d(512),
365 | nn.ReLU(),
366 | ),
367 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
368 | nn.BatchNorm2d(1024),
369 | ),
370 | Lambda(lambda x: x), # Identity,
371 | ),
372 | LambdaReduce(lambda x, y: x + y), # CAddTable,
373 | nn.ReLU(),
374 | ),
375 | nn.Sequential( # Sequential,
376 | LambdaMap(lambda x: x, # ConcatTable,
377 | nn.Sequential( # Sequential,
378 | nn.Sequential( # Sequential,
379 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
380 | nn.BatchNorm2d(512),
381 | nn.ReLU(),
382 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
383 | nn.BatchNorm2d(512),
384 | nn.ReLU(),
385 | ),
386 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
387 | nn.BatchNorm2d(1024),
388 | ),
389 | Lambda(lambda x: x), # Identity,
390 | ),
391 | LambdaReduce(lambda x, y: x + y), # CAddTable,
392 | nn.ReLU(),
393 | ),
394 | nn.Sequential( # Sequential,
395 | LambdaMap(lambda x: x, # ConcatTable,
396 | nn.Sequential( # Sequential,
397 | nn.Sequential( # Sequential,
398 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
399 | nn.BatchNorm2d(512),
400 | nn.ReLU(),
401 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
402 | nn.BatchNorm2d(512),
403 | nn.ReLU(),
404 | ),
405 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
406 | nn.BatchNorm2d(1024),
407 | ),
408 | Lambda(lambda x: x), # Identity,
409 | ),
410 | LambdaReduce(lambda x, y: x + y), # CAddTable,
411 | nn.ReLU(),
412 | ),
413 | nn.Sequential( # Sequential,
414 | LambdaMap(lambda x: x, # ConcatTable,
415 | nn.Sequential( # Sequential,
416 | nn.Sequential( # Sequential,
417 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
418 | nn.BatchNorm2d(512),
419 | nn.ReLU(),
420 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
421 | nn.BatchNorm2d(512),
422 | nn.ReLU(),
423 | ),
424 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
425 | nn.BatchNorm2d(1024),
426 | ),
427 | Lambda(lambda x: x), # Identity,
428 | ),
429 | LambdaReduce(lambda x, y: x + y), # CAddTable,
430 | nn.ReLU(),
431 | ),
432 | nn.Sequential( # Sequential,
433 | LambdaMap(lambda x: x, # ConcatTable,
434 | nn.Sequential( # Sequential,
435 | nn.Sequential( # Sequential,
436 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
437 | nn.BatchNorm2d(512),
438 | nn.ReLU(),
439 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
440 | nn.BatchNorm2d(512),
441 | nn.ReLU(),
442 | ),
443 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
444 | nn.BatchNorm2d(1024),
445 | ),
446 | Lambda(lambda x: x), # Identity,
447 | ),
448 | LambdaReduce(lambda x, y: x + y), # CAddTable,
449 | nn.ReLU(),
450 | ),
451 | nn.Sequential( # Sequential,
452 | LambdaMap(lambda x: x, # ConcatTable,
453 | nn.Sequential( # Sequential,
454 | nn.Sequential( # Sequential,
455 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
456 | nn.BatchNorm2d(512),
457 | nn.ReLU(),
458 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
459 | nn.BatchNorm2d(512),
460 | nn.ReLU(),
461 | ),
462 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
463 | nn.BatchNorm2d(1024),
464 | ),
465 | Lambda(lambda x: x), # Identity,
466 | ),
467 | LambdaReduce(lambda x, y: x + y), # CAddTable,
468 | nn.ReLU(),
469 | ),
470 | nn.Sequential( # Sequential,
471 | LambdaMap(lambda x: x, # ConcatTable,
472 | nn.Sequential( # Sequential,
473 | nn.Sequential( # Sequential,
474 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
475 | nn.BatchNorm2d(512),
476 | nn.ReLU(),
477 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
478 | nn.BatchNorm2d(512),
479 | nn.ReLU(),
480 | ),
481 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
482 | nn.BatchNorm2d(1024),
483 | ),
484 | Lambda(lambda x: x), # Identity,
485 | ),
486 | LambdaReduce(lambda x, y: x + y), # CAddTable,
487 | nn.ReLU(),
488 | ),
489 | nn.Sequential( # Sequential,
490 | LambdaMap(lambda x: x, # ConcatTable,
491 | nn.Sequential( # Sequential,
492 | nn.Sequential( # Sequential,
493 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
494 | nn.BatchNorm2d(512),
495 | nn.ReLU(),
496 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
497 | nn.BatchNorm2d(512),
498 | nn.ReLU(),
499 | ),
500 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
501 | nn.BatchNorm2d(1024),
502 | ),
503 | Lambda(lambda x: x), # Identity,
504 | ),
505 | LambdaReduce(lambda x, y: x + y), # CAddTable,
506 | nn.ReLU(),
507 | ),
508 | nn.Sequential( # Sequential,
509 | LambdaMap(lambda x: x, # ConcatTable,
510 | nn.Sequential( # Sequential,
511 | nn.Sequential( # Sequential,
512 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
513 | nn.BatchNorm2d(512),
514 | nn.ReLU(),
515 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
516 | nn.BatchNorm2d(512),
517 | nn.ReLU(),
518 | ),
519 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
520 | nn.BatchNorm2d(1024),
521 | ),
522 | Lambda(lambda x: x), # Identity,
523 | ),
524 | LambdaReduce(lambda x, y: x + y), # CAddTable,
525 | nn.ReLU(),
526 | ),
527 | nn.Sequential( # Sequential,
528 | LambdaMap(lambda x: x, # ConcatTable,
529 | nn.Sequential( # Sequential,
530 | nn.Sequential( # Sequential,
531 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
532 | nn.BatchNorm2d(512),
533 | nn.ReLU(),
534 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
535 | nn.BatchNorm2d(512),
536 | nn.ReLU(),
537 | ),
538 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
539 | nn.BatchNorm2d(1024),
540 | ),
541 | Lambda(lambda x: x), # Identity,
542 | ),
543 | LambdaReduce(lambda x, y: x + y), # CAddTable,
544 | nn.ReLU(),
545 | ),
546 | nn.Sequential( # Sequential,
547 | LambdaMap(lambda x: x, # ConcatTable,
548 | nn.Sequential( # Sequential,
549 | nn.Sequential( # Sequential,
550 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
551 | nn.BatchNorm2d(512),
552 | nn.ReLU(),
553 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
554 | nn.BatchNorm2d(512),
555 | nn.ReLU(),
556 | ),
557 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
558 | nn.BatchNorm2d(1024),
559 | ),
560 | Lambda(lambda x: x), # Identity,
561 | ),
562 | LambdaReduce(lambda x, y: x + y), # CAddTable,
563 | nn.ReLU(),
564 | ),
565 | nn.Sequential( # Sequential,
566 | LambdaMap(lambda x: x, # ConcatTable,
567 | nn.Sequential( # Sequential,
568 | nn.Sequential( # Sequential,
569 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
570 | nn.BatchNorm2d(512),
571 | nn.ReLU(),
572 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
573 | nn.BatchNorm2d(512),
574 | nn.ReLU(),
575 | ),
576 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
577 | nn.BatchNorm2d(1024),
578 | ),
579 | Lambda(lambda x: x), # Identity,
580 | ),
581 | LambdaReduce(lambda x, y: x + y), # CAddTable,
582 | nn.ReLU(),
583 | ),
584 | nn.Sequential( # Sequential,
585 | LambdaMap(lambda x: x, # ConcatTable,
586 | nn.Sequential( # Sequential,
587 | nn.Sequential( # Sequential,
588 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
589 | nn.BatchNorm2d(512),
590 | nn.ReLU(),
591 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
592 | nn.BatchNorm2d(512),
593 | nn.ReLU(),
594 | ),
595 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
596 | nn.BatchNorm2d(1024),
597 | ),
598 | Lambda(lambda x: x), # Identity,
599 | ),
600 | LambdaReduce(lambda x, y: x + y), # CAddTable,
601 | nn.ReLU(),
602 | ),
603 | nn.Sequential( # Sequential,
604 | LambdaMap(lambda x: x, # ConcatTable,
605 | nn.Sequential( # Sequential,
606 | nn.Sequential( # Sequential,
607 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
608 | nn.BatchNorm2d(512),
609 | nn.ReLU(),
610 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False),
611 | nn.BatchNorm2d(512),
612 | nn.ReLU(),
613 | ),
614 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
615 | nn.BatchNorm2d(1024),
616 | ),
617 | Lambda(lambda x: x), # Identity,
618 | ),
619 | LambdaReduce(lambda x, y: x + y), # CAddTable,
620 | nn.ReLU(),
621 | ),
622 | ),
623 | nn.Sequential( # Sequential,
624 | nn.Sequential( # Sequential,
625 | LambdaMap(lambda x: x, # ConcatTable,
626 | nn.Sequential( # Sequential,
627 | nn.Sequential( # Sequential,
628 | nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
629 | nn.BatchNorm2d(1024),
630 | nn.ReLU(),
631 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (2, 2), (2, 2), 32, bias=False),
632 | nn.BatchNorm2d(1024),
633 | nn.ReLU(),
634 | ),
635 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
636 | nn.BatchNorm2d(2048),
637 | ),
638 | nn.Sequential( # Sequential,
639 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
640 | nn.BatchNorm2d(2048),
641 | ),
642 | ),
643 | LambdaReduce(lambda x, y: x + y), # CAddTable,
644 | nn.ReLU(),
645 | ),
646 | nn.Sequential( # Sequential,
647 | LambdaMap(lambda x: x, # ConcatTable,
648 | nn.Sequential( # Sequential,
649 | nn.Sequential( # Sequential,
650 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
651 | nn.BatchNorm2d(1024),
652 | nn.ReLU(),
653 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (2, 2), (2, 2), 32, bias=False),
654 | nn.BatchNorm2d(1024),
655 | nn.ReLU(),
656 | ),
657 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
658 | nn.BatchNorm2d(2048),
659 | ),
660 | Lambda(lambda x: x), # Identity,
661 | ),
662 | LambdaReduce(lambda x, y: x + y), # CAddTable,
663 | nn.ReLU(),
664 | ),
665 | nn.Sequential( # Sequential,
666 | LambdaMap(lambda x: x, # ConcatTable,
667 | nn.Sequential( # Sequential,
668 | nn.Sequential( # Sequential,
669 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
670 | nn.BatchNorm2d(1024),
671 | nn.ReLU(),
672 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (2, 2), (2, 2), 32, bias=False),
673 | nn.BatchNorm2d(1024),
674 | nn.ReLU(),
675 | ),
676 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
677 | nn.BatchNorm2d(2048),
678 | ),
679 | Lambda(lambda x: x), # Identity,
680 | ),
681 | LambdaReduce(lambda x, y: x + y), # CAddTable,
682 | nn.ReLU(),
683 | ),
684 | ),
685 | nn.AvgPool2d((7, 7), (1, 1)),
686 | Lambda(lambda x: x.view(x.size(0), -1)), # View,
687 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(2048, 1000)), # Linear,
688 | )
689 |
--------------------------------------------------------------------------------
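A quick way to sanity-check the converted definition above: children 0-7 of resnext_101_32x4d form the convolutional trunk (stem conv/BN/ReLU, max-pool, and four residual stages), while children 8-10 are the 1000-way ImageNet head. Because the last stage keeps stride 1 and uses dilation 2 in its 3x3 group convolutions, the trunk's output stride is 16 rather than the usual 32. Below is a minimal shape check, assuming the repository root is on PYTHONPATH; it uses random weights (no resnext_101_32x4d.pth required) and a 416x416 input to match args['scale'] in train.py.

# Hypothetical sanity check (not part of the repo). It exercises only the convolutional
# trunk (children 0-7) and skips the avg-pool / flatten / Linear(2048, 1000) head.
import torch
from torch import nn
from networks.resnext_modify.resnext_101_32x4d_ import resnext_101_32x4d

trunk = nn.Sequential(*list(resnext_101_32x4d.children())[:8])
with torch.no_grad():
    feats = trunk(torch.randn(1, 3, 416, 416))
print(feats.shape)  # expected: torch.Size([1, 2048, 26, 26]), i.e. 416 / 16 = 26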