├── LICENSE ├── README.md ├── ckpt └── R3Net │ └── placeholder ├── config.py ├── datasets.py ├── infer.py ├── joint_transforms.py ├── misc.py ├── model.py ├── resnext ├── __init__.py ├── config.py ├── resnext101.py └── resnext_101_32x4d_.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Zijun Deng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # R3Net: Recurrent Residual Refinement Network for Saliency Detection 2 | 3 | by Zijun Deng, Xiaowei Hu, Lei Zhu, Xuemiao Xu, Jing Qin, Guoqiang Han, and Pheng-Ann Heng [[paper link](https://www.ijcai.org/proceedings/2018/95)] 4 | 5 | This implementation is written by Zijun Deng at the South China University of Technology. 6 | 7 | *** 8 | 9 | ## Citation 10 | @inproceedings{deng18r, 11 |      author = {Deng, Zijun and Hu, Xiaowei and Zhu, Lei and Xu, Xuemiao and Qin, Jing and Han, Guoqiang and Heng, Pheng-Ann}, 12 |      title = {R$^{3}${N}et: Recurrent Residual Refinement Network for Saliency Detection}, 13 |      booktitle = {IJCAI}, 14 |      year = {2018} 15 | } 16 | 17 | ## Saliency Map 18 | The results of salient object detection on five datasets (ECSSD, HKU-IS, PASCAL-S, SOD, DUT-OMRON) can be found 19 | at [Google Drive](https://drive.google.com/open?id=1PloaTokZEfWPy8voDm7mp3yvHnXCtn2c). 20 | 21 | ## Trained Model 22 | You can download the trained model reported in our paper at 23 | [Google Drive](https://drive.google.com/open?id=1Y50Cj5Ek-ZIsFj03_pRMSsvqXXeIJSaS). 24 | 25 | ## Requirements 26 | * Python 2.7 27 | * PyTorch 0.4.0 28 | * torchvision 29 | * numpy 30 | * Cython 31 | * pydensecrf ([here](https://github.com/Andrew-Qibin/dss_crf) to install) 32 | 33 | ## Training 34 | 1. Set the path of the pretrained ResNeXt model in resnext/config.py 35 | 2. Set the path of the MSRA10K dataset in config.py 36 | 3. Run by ```python train.py``` 37 | 38 | The pretrained ResNeXt model is ported from the [official](https://github.com/facebookresearch/ResNeXt) torch version, 39 | using the [converter](https://github.com/clcarwin/convert_torch_to_pytorch) provided by clcarwin. 40 | You can directly [download](https://drive.google.com/open?id=1dnH-IHwmu9xFPlyndqI6MfF4LvH6JKNQ) the pretrained model ported by me.
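As a quick orientation for steps 1 and 2 above, the two paths live in resnext/config.py (`resnext101_32_path`) and config.py (`datasets_root`); a minimal sketch of what to edit before launching training (the paths below are placeholders, not values shipped with the repo):

```python
# resnext/config.py -- point this at the converted ResNeXt-101 weights
resnext101_32_path = '/your/path/to/resnext_101_32x4d.pth'

# config.py -- point this at the folder that contains the msra10k/ directory
datasets_root = '/your/path/to/saliency'

# then, from the repository root:
#   python train.py
```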
41 | 42 | *Hyper-parameters* for training are gathered at the beginning of *train.py*, where you can conveniently 43 | change them as needed. 44 | 45 | Training a model on a single GTX 1080Ti GPU takes about 70 minutes. 46 | 47 | ## Testing 48 | 1. Set the paths of the five benchmark datasets in config.py 49 | 2. Put the trained model in ckpt/R3Net 50 | 3. Run by ```python infer.py``` 51 | 52 | *Settings* for testing are gathered at the beginning of *infer.py*, where you can conveniently 53 | change them as needed. 54 | 55 | ## Useful links 56 | * [MSRA10K](http://mmcheng.net/msra10k/): our training set 57 | * [ECSSD](http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html), 58 | [HKU-IS](https://sites.google.com/site/ligb86/hkuis), 59 | [PASCAL-S](http://cbi.gatech.edu/salobj/), 60 | [SOD](http://elderlab.yorku.ca/SOD/), 61 | [DUT-OMRON](http://ice.dlut.edu.cn/lu/DUT-OMRON/Homepage.htm): the five benchmark datasets 62 | -------------------------------------------------------------------------------- /ckpt/R3Net/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijundeng/R3Net/93b8a34e68445aea1b97cdda9bc7a34be99309db/ckpt/R3Net/placeholder -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import os 3 | 4 | datasets_root = '/home/b3-542/文档/DataSets/saliency' 5 | 6 | # For each dataset, I put images and masks together 7 | msra10k_path = os.path.join(datasets_root, 'msra10k') 8 | ecssd_path = os.path.join(datasets_root, 'ecssd') 9 | hkuis_path = os.path.join(datasets_root, 'hkuis') 10 | pascals_path = os.path.join(datasets_root, 'pascals') 11 | dutomron_path = os.path.join(datasets_root, 'dutomron') 12 | sod_path = os.path.join(datasets_root, 'sod') 13 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | 4 | import torch.utils.data as data 5 | from PIL import Image 6 | 7 | 8 | def make_dataset(root): 9 | img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')] 10 | return [(os.path.join(root, img_name + '.jpg'), os.path.join(root, img_name + '.png')) for img_name in img_list] 11 | 12 | 13 | class ImageFolder(data.Dataset): 14 | # image and gt should be in the same folder and share the same filename except for the extension (jpg and png respectively) 15 | def __init__(self, root, joint_transform=None, transform=None, target_transform=None): 16 | self.root = root 17 | self.imgs = make_dataset(root) 18 | self.joint_transform = joint_transform 19 | self.transform = transform 20 | self.target_transform = target_transform 21 | 22 | def __getitem__(self, index): 23 | img_path, gt_path = self.imgs[index] 24 | img = Image.open(img_path).convert('RGB') 25 | target = Image.open(gt_path).convert('L') 26 | if self.joint_transform is not None: 27 | img, target = self.joint_transform(img, target) 28 | if self.transform is not None: 29 | img = self.transform(img) 30 | if self.target_transform is not None: 31 | target = self.target_transform(target) 32 | 33 | return img, target 34 | 35 | def __len__(self): 36 | return len(self.imgs) 37 | -------------------------------------------------------------------------------- /infer.py: --------------------------------------------------------------------------------
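Before the inference script below, a small sketch of what testing assumes (filenames are illustrative): every dataset folder referenced in config.py keeps each image and its ground-truth mask side by side under the same basename, which is the convention make_dataset in datasets.py relies on, and infer.py writes its saliency maps into a per-dataset folder under the checkpoint directory:

```python
import os

# Expected layout of a benchmark dataset folder (e.g. the ecssd_path from config.py):
#   ecssd/0001.jpg   <- RGB image
#   ecssd/0001.png   <- ground-truth mask with the same basename

# Folder that infer.py creates for snapshot '6000' evaluated on ECSSD:
ckpt_path, exp_name, snapshot, name = './ckpt', 'R3Net', '6000', 'ecssd'
save_dir = os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, snapshot))
print(save_dir)  # ./ckpt/R3Net/(R3Net) ecssd_6000
```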
1 | import numpy as np 2 | import os 3 | 4 | import torch 5 | from PIL import Image 6 | from torch.autograd import Variable 7 | from torchvision import transforms 8 | 9 | from config import ecssd_path, hkuis_path, pascals_path, sod_path, dutomron_path 10 | from misc import check_mkdir, crf_refine, AvgMeter, cal_precision_recall_mae, cal_fmeasure 11 | from model import R3Net 12 | 13 | torch.manual_seed(2018) 14 | 15 | # set which gpu to use 16 | torch.cuda.set_device(0) 17 | 18 | # the following two args specify the location of the file of trained model (pth extension) 19 | # you should have the pth file in the folder './$ckpt_path$/$exp_name$' 20 | ckpt_path = './ckpt' 21 | exp_name = 'R3Net' 22 | 23 | args = { 24 | 'snapshot': '6000', # your snapshot filename (exclude extension name) 25 | 'crf_refine': True, # whether to use crf to refine results 26 | 'save_results': True # whether to save the resulting masks 27 | } 28 | 29 | img_transform = transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 32 | ]) 33 | to_pil = transforms.ToPILImage() 34 | 35 | to_test = {'ecssd': ecssd_path, 'hkuis': hkuis_path, 'pascal': pascals_path, 'sod': sod_path, 'dutomron': dutomron_path} 36 | 37 | 38 | def main(): 39 | net = R3Net().cuda() 40 | 41 | print 'load snapshot \'%s\' for testing' % args['snapshot'] 42 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'))) 43 | net.eval() 44 | 45 | results = {} 46 | 47 | with torch.no_grad(): 48 | 49 | for name, root in to_test.iteritems(): 50 | 51 | precision_record, recall_record, = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)] 52 | mae_record = AvgMeter() 53 | 54 | if args['save_results']: 55 | check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']))) 56 | 57 | img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')] 58 | for idx, img_name in enumerate(img_list): 59 | print 'predicting for %s: %d / %d' % (name, idx + 1, len(img_list)) 60 | 61 | img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB') 62 | img_var = Variable(img_transform(img).unsqueeze(0), volatile=True).cuda() 63 | prediction = net(img_var) 64 | prediction = np.array(to_pil(prediction.data.squeeze(0).cpu())) 65 | 66 | if args['crf_refine']: 67 | prediction = crf_refine(np.array(img), prediction) 68 | 69 | gt = np.array(Image.open(os.path.join(root, img_name + '.png')).convert('L')) 70 | precision, recall, mae = cal_precision_recall_mae(prediction, gt) 71 | for pidx, pdata in enumerate(zip(precision, recall)): 72 | p, r = pdata 73 | precision_record[pidx].update(p) 74 | recall_record[pidx].update(r) 75 | mae_record.update(mae) 76 | 77 | if args['save_results']: 78 | Image.fromarray(prediction).save(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % ( 79 | exp_name, name, args['snapshot']), img_name + '.png')) 80 | 81 | fmeasure = cal_fmeasure([precord.avg for precord in precision_record], 82 | [rrecord.avg for rrecord in recall_record]) 83 | 84 | results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg} 85 | 86 | print 'test results:' 87 | print results 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /joint_transforms.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | 4 | from PIL import Image, ImageOps 5 | 6 | 7 | class 
Compose(object): 8 | def __init__(self, transforms): 9 | self.transforms = transforms 10 | 11 | def __call__(self, img, mask): 12 | assert img.size == mask.size 13 | for t in self.transforms: 14 | img, mask = t(img, mask) 15 | return img, mask 16 | 17 | 18 | class RandomCrop(object): 19 | def __init__(self, size, padding=0): 20 | if isinstance(size, numbers.Number): 21 | self.size = (int(size), int(size)) 22 | else: 23 | self.size = size 24 | self.padding = padding 25 | 26 | def __call__(self, img, mask): 27 | if self.padding > 0: 28 | img = ImageOps.expand(img, border=self.padding, fill=0) 29 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 30 | 31 | assert img.size == mask.size 32 | w, h = img.size 33 | th, tw = self.size 34 | if w == tw and h == th: 35 | return img, mask 36 | if w < tw or h < th: 37 | return img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST) 38 | 39 | x1 = random.randint(0, w - tw) 40 | y1 = random.randint(0, h - th) 41 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 42 | 43 | 44 | class RandomHorizontallyFlip(object): 45 | def __call__(self, img, mask): 46 | if random.random() < 0.5: 47 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT) 48 | return img, mask 49 | 50 | 51 | class RandomRotate(object): 52 | def __init__(self, degree): 53 | self.degree = degree 54 | 55 | def __call__(self, img, mask): 56 | rotate_degree = random.random() * 2 * self.degree - self.degree 57 | return img.rotate(rotate_degree, Image.BILINEAR), mask.rotate(rotate_degree, Image.NEAREST) 58 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import pydensecrf.densecrf as dcrf 5 | 6 | 7 | class AvgMeter(object): 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.avg = self.sum / self.count 22 | 23 | 24 | def check_mkdir(dir_name): 25 | if not os.path.exists(dir_name): 26 | os.mkdir(dir_name) 27 | 28 | 29 | def cal_precision_recall_mae(prediction, gt): 30 | # input should be np array with data type uint8 31 | assert prediction.dtype == np.uint8 32 | assert gt.dtype == np.uint8 33 | assert prediction.shape == gt.shape 34 | 35 | eps = 1e-4 36 | 37 | prediction = prediction / 255. 38 | gt = gt / 255. 39 | 40 | mae = np.mean(np.abs(prediction - gt)) 41 | 42 | hard_gt = np.zeros(prediction.shape) 43 | hard_gt[gt > 0.5] = 1 44 | t = np.sum(hard_gt) 45 | 46 | precision, recall = [], [] 47 | # calculating precision and recall at 255 different binarizing thresholds 48 | for threshold in range(256): 49 | threshold = threshold / 255. 
50 | 51 | hard_prediction = np.zeros(prediction.shape) 52 | hard_prediction[prediction > threshold] = 1 53 | 54 | tp = np.sum(hard_prediction * hard_gt) 55 | p = np.sum(hard_prediction) 56 | 57 | precision.append((tp + eps) / (p + eps)) 58 | recall.append((tp + eps) / (t + eps)) 59 | 60 | return precision, recall, mae 61 | 62 | 63 | def cal_fmeasure(precision, recall): 64 | assert len(precision) == 256 65 | assert len(recall) == 256 66 | beta_square = 0.3 67 | max_fmeasure = max([(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)]) 68 | 69 | return max_fmeasure 70 | 71 | 72 | # codes of this function are borrowed from https://github.com/Andrew-Qibin/dss_crf 73 | def crf_refine(img, annos): 74 | def _sigmoid(x): 75 | return 1 / (1 + np.exp(-x)) 76 | 77 | assert img.dtype == np.uint8 78 | assert annos.dtype == np.uint8 79 | assert img.shape[:2] == annos.shape 80 | 81 | # img and annos should be np array with data type uint8 82 | 83 | EPSILON = 1e-8 84 | 85 | M = 2 # salient or not 86 | tau = 1.05 87 | # Setup the CRF model 88 | d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M) 89 | 90 | anno_norm = annos / 255. 91 | 92 | n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm)) 93 | p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm)) 94 | 95 | U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32') 96 | U[0, :] = n_energy.flatten() 97 | U[1, :] = p_energy.flatten() 98 | 99 | d.setUnaryEnergy(U) 100 | 101 | d.addPairwiseGaussian(sxy=3, compat=3) 102 | d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5) 103 | 104 | # Do the inference 105 | infer = np.array(d.inference(1)).astype('float32') 106 | res = infer[1, :] 107 | 108 | res = res * 255 109 | res = res.reshape(img.shape[:2]) 110 | return res.astype('uint8') 111 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from resnext import ResNeXt101 6 | 7 | 8 | class R3Net(nn.Module): 9 | def __init__(self): 10 | super(R3Net, self).__init__() 11 | resnext = ResNeXt101() 12 | self.layer0 = resnext.layer0 13 | self.layer1 = resnext.layer1 14 | self.layer2 = resnext.layer2 15 | self.layer3 = resnext.layer3 16 | self.layer4 = resnext.layer4 17 | 18 | self.reduce_low = nn.Sequential( 19 | nn.Conv2d(64 + 256 + 512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU(), 20 | nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU(), 21 | nn.Conv2d(256, 256, kernel_size=1), nn.BatchNorm2d(256), nn.PReLU() 22 | ) 23 | self.reduce_high = nn.Sequential( 24 | nn.Conv2d(1024 + 2048, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU(), 25 | nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU(), 26 | _ASPP(256) 27 | ) 28 | 29 | self.predict0 = nn.Conv2d(256, 1, kernel_size=1) 30 | self.predict1 = nn.Sequential( 31 | nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 32 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 33 | nn.Conv2d(128, 1, kernel_size=1) 34 | ) 35 | self.predict2 = nn.Sequential( 36 | nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 37 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 38 | nn.Conv2d(128, 1, kernel_size=1) 39 | ) 40 | self.predict3 = 
nn.Sequential( 41 | nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 42 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 43 | nn.Conv2d(128, 1, kernel_size=1) 44 | ) 45 | self.predict4 = nn.Sequential( 46 | nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 47 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 48 | nn.Conv2d(128, 1, kernel_size=1) 49 | ) 50 | self.predict5 = nn.Sequential( 51 | nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 52 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 53 | nn.Conv2d(128, 1, kernel_size=1) 54 | ) 55 | self.predict6 = nn.Sequential( 56 | nn.Conv2d(257, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 57 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 58 | nn.Conv2d(128, 1, kernel_size=1) 59 | ) 60 | for m in self.modules(): 61 | if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout): 62 | m.inplace = True 63 | 64 | def forward(self, x): 65 | layer0 = self.layer0(x) 66 | layer1 = self.layer1(layer0) 67 | layer2 = self.layer2(layer1) 68 | layer3 = self.layer3(layer2) 69 | layer4 = self.layer4(layer3) 70 | 71 | l0_size = layer0.size()[2:] 72 | reduce_low = self.reduce_low(torch.cat(( 73 | layer0, 74 | F.upsample(layer1, size=l0_size, mode='bilinear', align_corners=True), 75 | F.upsample(layer2, size=l0_size, mode='bilinear', align_corners=True)), 1)) 76 | reduce_high = self.reduce_high(torch.cat(( 77 | layer3, 78 | F.upsample(layer4, size=layer3.size()[2:], mode='bilinear', align_corners=True)), 1)) 79 | reduce_high = F.upsample(reduce_high, size=l0_size, mode='bilinear', align_corners=True) 80 | 81 | predict0 = self.predict0(reduce_high) 82 | predict1 = self.predict1(torch.cat((predict0, reduce_low), 1)) + predict0 83 | predict2 = self.predict2(torch.cat((predict1, reduce_high), 1)) + predict1 84 | predict3 = self.predict3(torch.cat((predict2, reduce_low), 1)) + predict2 85 | predict4 = self.predict4(torch.cat((predict3, reduce_high), 1)) + predict3 86 | predict5 = self.predict5(torch.cat((predict4, reduce_low), 1)) + predict4 87 | predict6 = self.predict6(torch.cat((predict5, reduce_high), 1)) + predict5 88 | 89 | predict0 = F.upsample(predict0, size=x.size()[2:], mode='bilinear', align_corners=True) 90 | predict1 = F.upsample(predict1, size=x.size()[2:], mode='bilinear', align_corners=True) 91 | predict2 = F.upsample(predict2, size=x.size()[2:], mode='bilinear', align_corners=True) 92 | predict3 = F.upsample(predict3, size=x.size()[2:], mode='bilinear', align_corners=True) 93 | predict4 = F.upsample(predict4, size=x.size()[2:], mode='bilinear', align_corners=True) 94 | predict5 = F.upsample(predict5, size=x.size()[2:], mode='bilinear', align_corners=True) 95 | predict6 = F.upsample(predict6, size=x.size()[2:], mode='bilinear', align_corners=True) 96 | 97 | if self.training: 98 | return predict0, predict1, predict2, predict3, predict4, predict5, predict6 99 | return F.sigmoid(predict6) 100 | 101 | 102 | class _ASPP(nn.Module): 103 | # this module is proposed in deeplabv3 and we use it in all of our baselines 104 | def __init__(self, in_dim): 105 | super(_ASPP, self).__init__() 106 | down_dim = in_dim / 2 107 | self.conv1 = nn.Sequential( 108 | nn.Conv2d(in_dim, down_dim, kernel_size=1), nn.BatchNorm2d(down_dim), nn.PReLU() 109 | ) 110 | self.conv2 = nn.Sequential( 111 | nn.Conv2d(in_dim, down_dim, kernel_size=3, 
dilation=2, padding=2), nn.BatchNorm2d(down_dim), nn.PReLU() 112 | ) 113 | self.conv3 = nn.Sequential( 114 | nn.Conv2d(in_dim, down_dim, kernel_size=3, dilation=4, padding=4), nn.BatchNorm2d(down_dim), nn.PReLU() 115 | ) 116 | self.conv4 = nn.Sequential( 117 | nn.Conv2d(in_dim, down_dim, kernel_size=3, dilation=6, padding=6), nn.BatchNorm2d(down_dim), nn.PReLU() 118 | ) 119 | self.conv5 = nn.Sequential( 120 | nn.Conv2d(in_dim, down_dim, kernel_size=1), nn.BatchNorm2d(down_dim), nn.PReLU() 121 | ) 122 | self.fuse = nn.Sequential( 123 | nn.Conv2d(5 * down_dim, in_dim, kernel_size=1), nn.BatchNorm2d(in_dim), nn.PReLU() 124 | ) 125 | 126 | def forward(self, x): 127 | conv1 = self.conv1(x) 128 | conv2 = self.conv2(x) 129 | conv3 = self.conv3(x) 130 | conv4 = self.conv4(x) 131 | conv5 = F.upsample(self.conv5(F.adaptive_avg_pool2d(x, 1)), size=x.size()[2:], mode='bilinear', 132 | align_corners=True) 133 | return self.fuse(torch.cat((conv1, conv2, conv3, conv4, conv5), 1)) 134 | -------------------------------------------------------------------------------- /resnext/__init__.py: -------------------------------------------------------------------------------- 1 | from resnext101 import ResNeXt101 2 | -------------------------------------------------------------------------------- /resnext/config.py: -------------------------------------------------------------------------------- 1 | resnext101_32_path = 'resnext_101_32x4d.pth' 2 | -------------------------------------------------------------------------------- /resnext/resnext101.py: -------------------------------------------------------------------------------- 1 | import resnext_101_32x4d_ 2 | import torch 3 | from torch import nn 4 | from config import resnext101_32_path 5 | 6 | 7 | class ResNeXt101(nn.Module): 8 | def __init__(self): 9 | super(ResNeXt101, self).__init__() 10 | net = resnext_101_32x4d_.resnext_101_32x4d 11 | net.load_state_dict(torch.load(resnext101_32_path)) 12 | 13 | net = list(net.children()) 14 | self.layer0 = nn.Sequential(*net[:4]) 15 | self.layer1 = net[4] 16 | self.layer2 = net[5] 17 | self.layer3 = net[6] 18 | self.layer4 = net[7] 19 | 20 | def forward(self, x): 21 | layer0 = self.layer0(x) 22 | layer1 = self.layer1(layer0) 23 | layer2 = self.layer2(layer1) 24 | layer3 = self.layer3(layer2) 25 | layer4 = self.layer4(layer3) 26 | return layer4 27 | -------------------------------------------------------------------------------- /resnext/resnext_101_32x4d_.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from functools import reduce 5 | 6 | 7 | class LambdaBase(nn.Sequential): 8 | def __init__(self, fn, *args): 9 | super(LambdaBase, self).__init__(*args) 10 | self.lambda_func = fn 11 | 12 | def forward_prepare(self, input): 13 | output = [] 14 | for module in self._modules.values(): 15 | output.append(module(input)) 16 | return output if output else input 17 | 18 | 19 | class Lambda(LambdaBase): 20 | def forward(self, input): 21 | return self.lambda_func(self.forward_prepare(input)) 22 | 23 | 24 | class LambdaMap(LambdaBase): 25 | def forward(self, input): 26 | return list(map(self.lambda_func, self.forward_prepare(input))) 27 | 28 | 29 | class LambdaReduce(LambdaBase): 30 | def forward(self, input): 31 | return reduce(self.lambda_func, self.forward_prepare(input)) 32 | 33 | 34 | resnext_101_32x4d = nn.Sequential( # Sequential, 35 | nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False), 36 | 
nn.BatchNorm2d(64), 37 | nn.ReLU(), 38 | nn.MaxPool2d((3, 3), (2, 2), (1, 1)), 39 | nn.Sequential( # Sequential, 40 | nn.Sequential( # Sequential, 41 | LambdaMap(lambda x: x, # ConcatTable, 42 | nn.Sequential( # Sequential, 43 | nn.Sequential( # Sequential, 44 | nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 45 | nn.BatchNorm2d(128), 46 | nn.ReLU(), 47 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 48 | nn.BatchNorm2d(128), 49 | nn.ReLU(), 50 | ), 51 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 52 | nn.BatchNorm2d(256), 53 | ), 54 | nn.Sequential( # Sequential, 55 | nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 56 | nn.BatchNorm2d(256), 57 | ), 58 | ), 59 | LambdaReduce(lambda x, y: x + y), # CAddTable, 60 | nn.ReLU(), 61 | ), 62 | nn.Sequential( # Sequential, 63 | LambdaMap(lambda x: x, # ConcatTable, 64 | nn.Sequential( # Sequential, 65 | nn.Sequential( # Sequential, 66 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 67 | nn.BatchNorm2d(128), 68 | nn.ReLU(), 69 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 70 | nn.BatchNorm2d(128), 71 | nn.ReLU(), 72 | ), 73 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 74 | nn.BatchNorm2d(256), 75 | ), 76 | Lambda(lambda x: x), # Identity, 77 | ), 78 | LambdaReduce(lambda x, y: x + y), # CAddTable, 79 | nn.ReLU(), 80 | ), 81 | nn.Sequential( # Sequential, 82 | LambdaMap(lambda x: x, # ConcatTable, 83 | nn.Sequential( # Sequential, 84 | nn.Sequential( # Sequential, 85 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 86 | nn.BatchNorm2d(128), 87 | nn.ReLU(), 88 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 89 | nn.BatchNorm2d(128), 90 | nn.ReLU(), 91 | ), 92 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 93 | nn.BatchNorm2d(256), 94 | ), 95 | Lambda(lambda x: x), # Identity, 96 | ), 97 | LambdaReduce(lambda x, y: x + y), # CAddTable, 98 | nn.ReLU(), 99 | ), 100 | ), 101 | nn.Sequential( # Sequential, 102 | nn.Sequential( # Sequential, 103 | LambdaMap(lambda x: x, # ConcatTable, 104 | nn.Sequential( # Sequential, 105 | nn.Sequential( # Sequential, 106 | nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 107 | nn.BatchNorm2d(256), 108 | nn.ReLU(), 109 | nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 110 | nn.BatchNorm2d(256), 111 | nn.ReLU(), 112 | ), 113 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 114 | nn.BatchNorm2d(512), 115 | ), 116 | nn.Sequential( # Sequential, 117 | nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 118 | nn.BatchNorm2d(512), 119 | ), 120 | ), 121 | LambdaReduce(lambda x, y: x + y), # CAddTable, 122 | nn.ReLU(), 123 | ), 124 | nn.Sequential( # Sequential, 125 | LambdaMap(lambda x: x, # ConcatTable, 126 | nn.Sequential( # Sequential, 127 | nn.Sequential( # Sequential, 128 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 129 | nn.BatchNorm2d(256), 130 | nn.ReLU(), 131 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 132 | nn.BatchNorm2d(256), 133 | nn.ReLU(), 134 | ), 135 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 136 | nn.BatchNorm2d(512), 137 | ), 138 | Lambda(lambda x: x), # Identity, 139 | ), 140 | LambdaReduce(lambda x, y: x + y), # CAddTable, 141 | nn.ReLU(), 142 | ), 143 | nn.Sequential( # Sequential, 144 | LambdaMap(lambda x: x, # ConcatTable, 145 | nn.Sequential( # Sequential, 146 | nn.Sequential( # Sequential, 147 | nn.Conv2d(512, 
256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 148 | nn.BatchNorm2d(256), 149 | nn.ReLU(), 150 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 151 | nn.BatchNorm2d(256), 152 | nn.ReLU(), 153 | ), 154 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 155 | nn.BatchNorm2d(512), 156 | ), 157 | Lambda(lambda x: x), # Identity, 158 | ), 159 | LambdaReduce(lambda x, y: x + y), # CAddTable, 160 | nn.ReLU(), 161 | ), 162 | nn.Sequential( # Sequential, 163 | LambdaMap(lambda x: x, # ConcatTable, 164 | nn.Sequential( # Sequential, 165 | nn.Sequential( # Sequential, 166 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 167 | nn.BatchNorm2d(256), 168 | nn.ReLU(), 169 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 170 | nn.BatchNorm2d(256), 171 | nn.ReLU(), 172 | ), 173 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 174 | nn.BatchNorm2d(512), 175 | ), 176 | Lambda(lambda x: x), # Identity, 177 | ), 178 | LambdaReduce(lambda x, y: x + y), # CAddTable, 179 | nn.ReLU(), 180 | ), 181 | ), 182 | nn.Sequential( # Sequential, 183 | nn.Sequential( # Sequential, 184 | LambdaMap(lambda x: x, # ConcatTable, 185 | nn.Sequential( # Sequential, 186 | nn.Sequential( # Sequential, 187 | nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 188 | nn.BatchNorm2d(512), 189 | nn.ReLU(), 190 | nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 191 | nn.BatchNorm2d(512), 192 | nn.ReLU(), 193 | ), 194 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 195 | nn.BatchNorm2d(1024), 196 | ), 197 | nn.Sequential( # Sequential, 198 | nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 199 | nn.BatchNorm2d(1024), 200 | ), 201 | ), 202 | LambdaReduce(lambda x, y: x + y), # CAddTable, 203 | nn.ReLU(), 204 | ), 205 | nn.Sequential( # Sequential, 206 | LambdaMap(lambda x: x, # ConcatTable, 207 | nn.Sequential( # Sequential, 208 | nn.Sequential( # Sequential, 209 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 210 | nn.BatchNorm2d(512), 211 | nn.ReLU(), 212 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 213 | nn.BatchNorm2d(512), 214 | nn.ReLU(), 215 | ), 216 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 217 | nn.BatchNorm2d(1024), 218 | ), 219 | Lambda(lambda x: x), # Identity, 220 | ), 221 | LambdaReduce(lambda x, y: x + y), # CAddTable, 222 | nn.ReLU(), 223 | ), 224 | nn.Sequential( # Sequential, 225 | LambdaMap(lambda x: x, # ConcatTable, 226 | nn.Sequential( # Sequential, 227 | nn.Sequential( # Sequential, 228 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 229 | nn.BatchNorm2d(512), 230 | nn.ReLU(), 231 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 232 | nn.BatchNorm2d(512), 233 | nn.ReLU(), 234 | ), 235 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 236 | nn.BatchNorm2d(1024), 237 | ), 238 | Lambda(lambda x: x), # Identity, 239 | ), 240 | LambdaReduce(lambda x, y: x + y), # CAddTable, 241 | nn.ReLU(), 242 | ), 243 | nn.Sequential( # Sequential, 244 | LambdaMap(lambda x: x, # ConcatTable, 245 | nn.Sequential( # Sequential, 246 | nn.Sequential( # Sequential, 247 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 248 | nn.BatchNorm2d(512), 249 | nn.ReLU(), 250 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 251 | nn.BatchNorm2d(512), 252 | nn.ReLU(), 253 | ), 254 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 255 | nn.BatchNorm2d(1024), 
256 | ), 257 | Lambda(lambda x: x), # Identity, 258 | ), 259 | LambdaReduce(lambda x, y: x + y), # CAddTable, 260 | nn.ReLU(), 261 | ), 262 | nn.Sequential( # Sequential, 263 | LambdaMap(lambda x: x, # ConcatTable, 264 | nn.Sequential( # Sequential, 265 | nn.Sequential( # Sequential, 266 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 267 | nn.BatchNorm2d(512), 268 | nn.ReLU(), 269 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 270 | nn.BatchNorm2d(512), 271 | nn.ReLU(), 272 | ), 273 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 274 | nn.BatchNorm2d(1024), 275 | ), 276 | Lambda(lambda x: x), # Identity, 277 | ), 278 | LambdaReduce(lambda x, y: x + y), # CAddTable, 279 | nn.ReLU(), 280 | ), 281 | nn.Sequential( # Sequential, 282 | LambdaMap(lambda x: x, # ConcatTable, 283 | nn.Sequential( # Sequential, 284 | nn.Sequential( # Sequential, 285 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 286 | nn.BatchNorm2d(512), 287 | nn.ReLU(), 288 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 289 | nn.BatchNorm2d(512), 290 | nn.ReLU(), 291 | ), 292 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 293 | nn.BatchNorm2d(1024), 294 | ), 295 | Lambda(lambda x: x), # Identity, 296 | ), 297 | LambdaReduce(lambda x, y: x + y), # CAddTable, 298 | nn.ReLU(), 299 | ), 300 | nn.Sequential( # Sequential, 301 | LambdaMap(lambda x: x, # ConcatTable, 302 | nn.Sequential( # Sequential, 303 | nn.Sequential( # Sequential, 304 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 305 | nn.BatchNorm2d(512), 306 | nn.ReLU(), 307 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 308 | nn.BatchNorm2d(512), 309 | nn.ReLU(), 310 | ), 311 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 312 | nn.BatchNorm2d(1024), 313 | ), 314 | Lambda(lambda x: x), # Identity, 315 | ), 316 | LambdaReduce(lambda x, y: x + y), # CAddTable, 317 | nn.ReLU(), 318 | ), 319 | nn.Sequential( # Sequential, 320 | LambdaMap(lambda x: x, # ConcatTable, 321 | nn.Sequential( # Sequential, 322 | nn.Sequential( # Sequential, 323 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 324 | nn.BatchNorm2d(512), 325 | nn.ReLU(), 326 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 327 | nn.BatchNorm2d(512), 328 | nn.ReLU(), 329 | ), 330 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 331 | nn.BatchNorm2d(1024), 332 | ), 333 | Lambda(lambda x: x), # Identity, 334 | ), 335 | LambdaReduce(lambda x, y: x + y), # CAddTable, 336 | nn.ReLU(), 337 | ), 338 | nn.Sequential( # Sequential, 339 | LambdaMap(lambda x: x, # ConcatTable, 340 | nn.Sequential( # Sequential, 341 | nn.Sequential( # Sequential, 342 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 343 | nn.BatchNorm2d(512), 344 | nn.ReLU(), 345 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 346 | nn.BatchNorm2d(512), 347 | nn.ReLU(), 348 | ), 349 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 350 | nn.BatchNorm2d(1024), 351 | ), 352 | Lambda(lambda x: x), # Identity, 353 | ), 354 | LambdaReduce(lambda x, y: x + y), # CAddTable, 355 | nn.ReLU(), 356 | ), 357 | nn.Sequential( # Sequential, 358 | LambdaMap(lambda x: x, # ConcatTable, 359 | nn.Sequential( # Sequential, 360 | nn.Sequential( # Sequential, 361 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 362 | nn.BatchNorm2d(512), 363 | nn.ReLU(), 364 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 
32, bias=False), 365 | nn.BatchNorm2d(512), 366 | nn.ReLU(), 367 | ), 368 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 369 | nn.BatchNorm2d(1024), 370 | ), 371 | Lambda(lambda x: x), # Identity, 372 | ), 373 | LambdaReduce(lambda x, y: x + y), # CAddTable, 374 | nn.ReLU(), 375 | ), 376 | nn.Sequential( # Sequential, 377 | LambdaMap(lambda x: x, # ConcatTable, 378 | nn.Sequential( # Sequential, 379 | nn.Sequential( # Sequential, 380 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 381 | nn.BatchNorm2d(512), 382 | nn.ReLU(), 383 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 384 | nn.BatchNorm2d(512), 385 | nn.ReLU(), 386 | ), 387 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 388 | nn.BatchNorm2d(1024), 389 | ), 390 | Lambda(lambda x: x), # Identity, 391 | ), 392 | LambdaReduce(lambda x, y: x + y), # CAddTable, 393 | nn.ReLU(), 394 | ), 395 | nn.Sequential( # Sequential, 396 | LambdaMap(lambda x: x, # ConcatTable, 397 | nn.Sequential( # Sequential, 398 | nn.Sequential( # Sequential, 399 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 400 | nn.BatchNorm2d(512), 401 | nn.ReLU(), 402 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 403 | nn.BatchNorm2d(512), 404 | nn.ReLU(), 405 | ), 406 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 407 | nn.BatchNorm2d(1024), 408 | ), 409 | Lambda(lambda x: x), # Identity, 410 | ), 411 | LambdaReduce(lambda x, y: x + y), # CAddTable, 412 | nn.ReLU(), 413 | ), 414 | nn.Sequential( # Sequential, 415 | LambdaMap(lambda x: x, # ConcatTable, 416 | nn.Sequential( # Sequential, 417 | nn.Sequential( # Sequential, 418 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 419 | nn.BatchNorm2d(512), 420 | nn.ReLU(), 421 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 422 | nn.BatchNorm2d(512), 423 | nn.ReLU(), 424 | ), 425 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 426 | nn.BatchNorm2d(1024), 427 | ), 428 | Lambda(lambda x: x), # Identity, 429 | ), 430 | LambdaReduce(lambda x, y: x + y), # CAddTable, 431 | nn.ReLU(), 432 | ), 433 | nn.Sequential( # Sequential, 434 | LambdaMap(lambda x: x, # ConcatTable, 435 | nn.Sequential( # Sequential, 436 | nn.Sequential( # Sequential, 437 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 438 | nn.BatchNorm2d(512), 439 | nn.ReLU(), 440 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 441 | nn.BatchNorm2d(512), 442 | nn.ReLU(), 443 | ), 444 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 445 | nn.BatchNorm2d(1024), 446 | ), 447 | Lambda(lambda x: x), # Identity, 448 | ), 449 | LambdaReduce(lambda x, y: x + y), # CAddTable, 450 | nn.ReLU(), 451 | ), 452 | nn.Sequential( # Sequential, 453 | LambdaMap(lambda x: x, # ConcatTable, 454 | nn.Sequential( # Sequential, 455 | nn.Sequential( # Sequential, 456 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 457 | nn.BatchNorm2d(512), 458 | nn.ReLU(), 459 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 460 | nn.BatchNorm2d(512), 461 | nn.ReLU(), 462 | ), 463 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 464 | nn.BatchNorm2d(1024), 465 | ), 466 | Lambda(lambda x: x), # Identity, 467 | ), 468 | LambdaReduce(lambda x, y: x + y), # CAddTable, 469 | nn.ReLU(), 470 | ), 471 | nn.Sequential( # Sequential, 472 | LambdaMap(lambda x: x, # ConcatTable, 473 | nn.Sequential( # Sequential, 474 | nn.Sequential( # Sequential, 
475 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 476 | nn.BatchNorm2d(512), 477 | nn.ReLU(), 478 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 479 | nn.BatchNorm2d(512), 480 | nn.ReLU(), 481 | ), 482 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 483 | nn.BatchNorm2d(1024), 484 | ), 485 | Lambda(lambda x: x), # Identity, 486 | ), 487 | LambdaReduce(lambda x, y: x + y), # CAddTable, 488 | nn.ReLU(), 489 | ), 490 | nn.Sequential( # Sequential, 491 | LambdaMap(lambda x: x, # ConcatTable, 492 | nn.Sequential( # Sequential, 493 | nn.Sequential( # Sequential, 494 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 495 | nn.BatchNorm2d(512), 496 | nn.ReLU(), 497 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 498 | nn.BatchNorm2d(512), 499 | nn.ReLU(), 500 | ), 501 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 502 | nn.BatchNorm2d(1024), 503 | ), 504 | Lambda(lambda x: x), # Identity, 505 | ), 506 | LambdaReduce(lambda x, y: x + y), # CAddTable, 507 | nn.ReLU(), 508 | ), 509 | nn.Sequential( # Sequential, 510 | LambdaMap(lambda x: x, # ConcatTable, 511 | nn.Sequential( # Sequential, 512 | nn.Sequential( # Sequential, 513 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 514 | nn.BatchNorm2d(512), 515 | nn.ReLU(), 516 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 517 | nn.BatchNorm2d(512), 518 | nn.ReLU(), 519 | ), 520 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 521 | nn.BatchNorm2d(1024), 522 | ), 523 | Lambda(lambda x: x), # Identity, 524 | ), 525 | LambdaReduce(lambda x, y: x + y), # CAddTable, 526 | nn.ReLU(), 527 | ), 528 | nn.Sequential( # Sequential, 529 | LambdaMap(lambda x: x, # ConcatTable, 530 | nn.Sequential( # Sequential, 531 | nn.Sequential( # Sequential, 532 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 533 | nn.BatchNorm2d(512), 534 | nn.ReLU(), 535 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 536 | nn.BatchNorm2d(512), 537 | nn.ReLU(), 538 | ), 539 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 540 | nn.BatchNorm2d(1024), 541 | ), 542 | Lambda(lambda x: x), # Identity, 543 | ), 544 | LambdaReduce(lambda x, y: x + y), # CAddTable, 545 | nn.ReLU(), 546 | ), 547 | nn.Sequential( # Sequential, 548 | LambdaMap(lambda x: x, # ConcatTable, 549 | nn.Sequential( # Sequential, 550 | nn.Sequential( # Sequential, 551 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 552 | nn.BatchNorm2d(512), 553 | nn.ReLU(), 554 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 555 | nn.BatchNorm2d(512), 556 | nn.ReLU(), 557 | ), 558 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 559 | nn.BatchNorm2d(1024), 560 | ), 561 | Lambda(lambda x: x), # Identity, 562 | ), 563 | LambdaReduce(lambda x, y: x + y), # CAddTable, 564 | nn.ReLU(), 565 | ), 566 | nn.Sequential( # Sequential, 567 | LambdaMap(lambda x: x, # ConcatTable, 568 | nn.Sequential( # Sequential, 569 | nn.Sequential( # Sequential, 570 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 571 | nn.BatchNorm2d(512), 572 | nn.ReLU(), 573 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 574 | nn.BatchNorm2d(512), 575 | nn.ReLU(), 576 | ), 577 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 578 | nn.BatchNorm2d(1024), 579 | ), 580 | Lambda(lambda x: x), # Identity, 581 | ), 582 | LambdaReduce(lambda x, y: x + y), # CAddTable, 583 | 
nn.ReLU(), 584 | ), 585 | nn.Sequential( # Sequential, 586 | LambdaMap(lambda x: x, # ConcatTable, 587 | nn.Sequential( # Sequential, 588 | nn.Sequential( # Sequential, 589 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 590 | nn.BatchNorm2d(512), 591 | nn.ReLU(), 592 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 593 | nn.BatchNorm2d(512), 594 | nn.ReLU(), 595 | ), 596 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 597 | nn.BatchNorm2d(1024), 598 | ), 599 | Lambda(lambda x: x), # Identity, 600 | ), 601 | LambdaReduce(lambda x, y: x + y), # CAddTable, 602 | nn.ReLU(), 603 | ), 604 | nn.Sequential( # Sequential, 605 | LambdaMap(lambda x: x, # ConcatTable, 606 | nn.Sequential( # Sequential, 607 | nn.Sequential( # Sequential, 608 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 609 | nn.BatchNorm2d(512), 610 | nn.ReLU(), 611 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 612 | nn.BatchNorm2d(512), 613 | nn.ReLU(), 614 | ), 615 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 616 | nn.BatchNorm2d(1024), 617 | ), 618 | Lambda(lambda x: x), # Identity, 619 | ), 620 | LambdaReduce(lambda x, y: x + y), # CAddTable, 621 | nn.ReLU(), 622 | ), 623 | ), 624 | nn.Sequential( # Sequential, 625 | nn.Sequential( # Sequential, 626 | LambdaMap(lambda x: x, # ConcatTable, 627 | nn.Sequential( # Sequential, 628 | nn.Sequential( # Sequential, 629 | nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 630 | nn.BatchNorm2d(1024), 631 | nn.ReLU(), 632 | nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 633 | nn.BatchNorm2d(1024), 634 | nn.ReLU(), 635 | ), 636 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 637 | nn.BatchNorm2d(2048), 638 | ), 639 | nn.Sequential( # Sequential, 640 | nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 641 | nn.BatchNorm2d(2048), 642 | ), 643 | ), 644 | LambdaReduce(lambda x, y: x + y), # CAddTable, 645 | nn.ReLU(), 646 | ), 647 | nn.Sequential( # Sequential, 648 | LambdaMap(lambda x: x, # ConcatTable, 649 | nn.Sequential( # Sequential, 650 | nn.Sequential( # Sequential, 651 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 652 | nn.BatchNorm2d(1024), 653 | nn.ReLU(), 654 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 655 | nn.BatchNorm2d(1024), 656 | nn.ReLU(), 657 | ), 658 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 659 | nn.BatchNorm2d(2048), 660 | ), 661 | Lambda(lambda x: x), # Identity, 662 | ), 663 | LambdaReduce(lambda x, y: x + y), # CAddTable, 664 | nn.ReLU(), 665 | ), 666 | nn.Sequential( # Sequential, 667 | LambdaMap(lambda x: x, # ConcatTable, 668 | nn.Sequential( # Sequential, 669 | nn.Sequential( # Sequential, 670 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 671 | nn.BatchNorm2d(1024), 672 | nn.ReLU(), 673 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 674 | nn.BatchNorm2d(1024), 675 | nn.ReLU(), 676 | ), 677 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 678 | nn.BatchNorm2d(2048), 679 | ), 680 | Lambda(lambda x: x), # Identity, 681 | ), 682 | LambdaReduce(lambda x, y: x + y), # CAddTable, 683 | nn.ReLU(), 684 | ), 685 | ), 686 | nn.AvgPool2d((7, 7), (1, 1)), 687 | Lambda(lambda x: x.view(x.size(0), -1)), # View, 688 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(2048, 1000)), # Linear, 689 | ) 690 | 
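The wrapper in resnext101.py uses this converted network purely as a backbone: R3Net in model.py taps layer0–layer4 and never reaches the final pooling/linear head. A minimal shape check, illustrative only, assuming the converted weights are available at `resnext101_32_path` and a 300×300 input (the training crop size used in train.py):

```python
import torch
from resnext import ResNeXt101

net = ResNeXt101().eval()
x = torch.randn(1, 3, 300, 300)
layer0 = net.layer0(x)       # (1,   64, 75, 75)  \
layer1 = net.layer1(layer0)  # (1,  256, 75, 75)   > 64 + 256 + 512 channels feed reduce_low
layer2 = net.layer2(layer1)  # (1,  512, 38, 38)  /
layer3 = net.layer3(layer2)  # (1, 1024, 19, 19)  \  1024 + 2048 channels feed reduce_high
layer4 = net.layer4(layer3)  # (1, 2048, 10, 10)  /
```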
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | import torch 5 | from torch import nn 6 | from torch import optim 7 | from torch.autograd import Variable 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | 11 | import joint_transforms 12 | from config import msra10k_path 13 | from datasets import ImageFolder 14 | from misc import AvgMeter, check_mkdir 15 | from model import R3Net 16 | from torch.backends import cudnn 17 | 18 | cudnn.benchmark = True 19 | 20 | torch.manual_seed(2018) 21 | torch.cuda.set_device(0) 22 | 23 | ckpt_path = './ckpt' 24 | exp_name = 'R3Net' 25 | 26 | args = { 27 | 'iter_num': 6000, 28 | 'train_batch_size': 14, 29 | 'last_iter': 0, 30 | 'lr': 1e-3, 31 | 'lr_decay': 0.9, 32 | 'weight_decay': 5e-4, 33 | 'momentum': 0.9, 34 | 'snapshot': '' 35 | } 36 | 37 | joint_transform = joint_transforms.Compose([ 38 | joint_transforms.RandomCrop(300), 39 | joint_transforms.RandomHorizontallyFlip(), 40 | joint_transforms.RandomRotate(10) 41 | ]) 42 | img_transform = transforms.Compose([ 43 | transforms.ToTensor(), 44 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 45 | ]) 46 | target_transform = transforms.ToTensor() 47 | 48 | train_set = ImageFolder(msra10k_path, joint_transform, img_transform, target_transform) 49 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=12, shuffle=True) 50 | 51 | criterion = nn.BCEWithLogitsLoss().cuda() 52 | log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt') 53 | 54 | 55 | def main(): 56 | net = R3Net().cuda().train() 57 | 58 | optimizer = optim.SGD([ 59 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'], 60 | 'lr': 2 * args['lr']}, 61 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'], 62 | 'lr': args['lr'], 'weight_decay': args['weight_decay']} 63 | ], momentum=args['momentum']) 64 | 65 | if len(args['snapshot']) > 0: 66 | print 'training resumes from ' + args['snapshot'] 67 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'))) 68 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth'))) 69 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] 70 | optimizer.param_groups[1]['lr'] = args['lr'] 71 | 72 | check_mkdir(ckpt_path) 73 | check_mkdir(os.path.join(ckpt_path, exp_name)) 74 | open(log_path, 'w').write(str(args) + '\n\n') 75 | train(net, optimizer) 76 | 77 | 78 | def train(net, optimizer): 79 | curr_iter = args['last_iter'] 80 | while True: 81 | total_loss_record, loss0_record, loss1_record, loss2_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter() 82 | loss3_record, loss4_record, loss5_record, loss6_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter() 83 | 84 | for i, data in enumerate(train_loader): 85 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num'] 86 | ) ** args['lr_decay'] 87 | optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num'] 88 | ) ** args['lr_decay'] 89 | 90 | inputs, labels = data 91 | batch_size = inputs.size(0) 92 | inputs = Variable(inputs).cuda() 93 | labels = Variable(labels).cuda() 94 | 95 | optimizer.zero_grad() 96 | outputs0, outputs1, outputs2, outputs3, outputs4, outputs5, outputs6 = 
net(inputs) 97 | loss0 = criterion(outputs0, labels) 98 | loss1 = criterion(outputs1, labels) 99 | loss2 = criterion(outputs2, labels) 100 | loss3 = criterion(outputs3, labels) 101 | loss4 = criterion(outputs4, labels) 102 | loss5 = criterion(outputs5, labels) 103 | loss6 = criterion(outputs6, labels) 104 | 105 | total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 106 | total_loss.backward() 107 | optimizer.step() 108 | 109 | total_loss_record.update(total_loss.data[0], batch_size) 110 | loss0_record.update(loss0.data[0], batch_size) 111 | loss1_record.update(loss1.data[0], batch_size) 112 | loss2_record.update(loss2.data[0], batch_size) 113 | loss3_record.update(loss3.data[0], batch_size) 114 | loss4_record.update(loss4.data[0], batch_size) 115 | loss5_record.update(loss5.data[0], batch_size) 116 | loss6_record.update(loss6.data[0], batch_size) 117 | 118 | curr_iter += 1 119 | 120 | log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \ 121 | '[loss4 %.5f], [loss5 %.5f], [loss6 %.5f], [lr %.13f]' % \ 122 | (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg, 123 | loss3_record.avg, loss4_record.avg, loss5_record.avg, loss6_record.avg, 124 | optimizer.param_groups[1]['lr']) 125 | print log 126 | open(log_path, 'a').write(log + '\n') 127 | 128 | if curr_iter == args['iter_num']: 129 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter)) 130 | torch.save(optimizer.state_dict(), 131 | os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter)) 132 | return 133 | 134 | 135 | if __name__ == '__main__': 136 | main() 137 | --------------------------------------------------------------------------------
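For reference, the learning-rate schedule applied inside train() is a polynomial decay of both parameter groups (biases start at twice the base rate). A tiny illustrative helper using the default values from args:

```python
def poly_lr(base_lr, curr_iter, iter_num=6000, lr_decay=0.9):
    # mirrors the per-iteration update applied to optimizer.param_groups in train()
    return base_lr * (1 - float(curr_iter) / iter_num) ** lr_decay

print(poly_lr(1e-3, 0))     # 0.001 at the first iteration
print(poly_lr(1e-3, 3000))  # ~5.4e-04 halfway through
print(poly_lr(1e-3, 5999))  # ~4.0e-07 just before the 6000-iteration budget ends
```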