├── .gitignore
├── LICENSE.md
├── dataset.py
├── utils.py
├── README.md
├── test.py
├── models.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
*.pyc
*.pth
*.mat
data/
results/
--------------------------------------------------------------------------------

/LICENSE.md:
--------------------------------------------------------------------------------
The code is released under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for NonCommercial use only. For any commercial use, please obtain formal permission first (Email: VAGRANTLYUN@GMAIL.COM).
--------------------------------------------------------------------------------

/dataset.py:
--------------------------------------------------------------------------------
import torch
import cv2
import numpy as np
import os.path as osp


class BSDS_Dataset(torch.utils.data.Dataset):
    def __init__(self, root='data/HED-BSDS', split='test', transform=False):
        super(BSDS_Dataset, self).__init__()
        self.root = root
        self.split = split
        self.transform = transform  # accepted for API compatibility; currently unused
        if self.split == 'train':
            self.file_list = osp.join(self.root, 'bsds_pascal_train_pair.lst')
        elif self.split == 'test':
            self.file_list = osp.join(self.root, 'test.lst')
        else:
            raise ValueError('Invalid split type!')
        with open(self.file_list, 'r') as f:
            self.file_list = f.readlines()
        # Per-channel means of the training data in BGR order (OpenCV loads BGR)
        self.mean = np.array([104.00698793, 116.66876762, 122.67891434], dtype=np.float32)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        if self.split == 'train':
            img_file, label_file = self.file_list[index].split()
            label = cv2.imread(osp.join(self.root, label_file), 0)
            label = np.array(label, dtype=np.float32)
            label = label[np.newaxis, :, :]
            # Three-way label encoding: 0 = non-edge, 1 = edge, and
            # 2 = controversial pixels that the loss will ignore
            label[label == 0] = 0
            label[np.logical_and(label > 0, label < 127.5)] = 2
            label[label >= 127.5] = 1
        else:
            img_file = self.file_list[index].rstrip()

        img = cv2.imread(osp.join(self.root, img_file))
        img = np.array(img, dtype=np.float32)
        img = (img - self.mean).transpose((2, 0, 1))

        if self.split == 'train':
            return img, label
        else:
            return img
--------------------------------------------------------------------------------
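
For orientation, loading one training pair through `dataset.py` might look like the following sketch. This is illustrative usage, not repository code, and it assumes the datasets have already been extracted to `data/` as described in the README below.

```python
import torch
from torch.utils.data import DataLoader
from dataset import BSDS_Dataset

# Hypothetical usage: inspect one mean-subtracted image and its label map
train_set = BSDS_Dataset(root='data', split='train')
loader = DataLoader(train_set, batch_size=1, shuffle=True)
img, label = next(iter(loader))
print(img.shape)    # (1, 3, H, W), float32, BGR with the dataset mean removed
print(label.shape)  # (1, 1, H, W), values in {0, 1, 2}
```
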
/utils.py:
--------------------------------------------------------------------------------
import os
import logging
import numpy as np
import torch
import torch.nn.functional as F


class Logger(object):
    def __init__(self, path='log.txt'):
        self.logger = logging.getLogger('Logger')
        self.file_handler = logging.FileHandler(path, 'w')
        self.stdout_handler = logging.StreamHandler()
        self.logger.addHandler(self.file_handler)
        self.logger.addHandler(self.stdout_handler)
        self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
        self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
        self.logger.setLevel(logging.INFO)

    def info(self, txt):
        self.logger.info(txt)

    def close(self):
        self.file_handler.close()
        self.stdout_handler.close()


class AverageValue(object):
    """Computes and stores the average and current value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def Cross_entropy_loss(prediction, label):
    mask = label.clone()
    num_positive = torch.sum(mask == 1).float()
    num_negative = torch.sum(mask == 0).float()

    # Class-balanced weighting: positive pixels are weighted by the fraction
    # of negatives and negative pixels by the fraction of positives (with an
    # extra factor of 1.1), so the abundant non-edge pixels do not dominate
    mask[mask == 1] = 1.0 * num_negative / (num_positive + num_negative)
    mask[mask == 0] = 1.1 * num_positive / (num_positive + num_negative)
    # pixels labeled 2 (controversial) are excluded from the loss
    selected_idx = mask != 2
    prediction = prediction[selected_idx]
    label = label[selected_idx]
    mask = mask[selected_idx]
    cost = F.binary_cross_entropy(prediction, label, weight=mask, reduction='none')
    return torch.sum(cost)
--------------------------------------------------------------------------------
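
To make the weighting concrete, here is a hypothetical toy example (not repository code): with 2 positive, 9 negative, and 1 ignored pixel, positives receive weight 9/11 (about 0.82), negatives receive 1.1 * 2/11 = 0.2, and the pixel labeled 2 contributes nothing.

```python
import torch
from utils import Cross_entropy_loss

# Hypothetical demo of the class-balanced loss above
label = torch.tensor([[0., 0., 0., 1.],
                      [0., 0., 0., 1.],
                      [2., 0., 0., 0.]]).view(1, 1, 3, 4)
prediction = torch.full_like(label, 0.5)  # stand-in for sigmoid outputs in (0, 1)
print(Cross_entropy_loss(prediction, label))
```
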
/README.md:
--------------------------------------------------------------------------------
## [Richer Convolutional Features for Edge Detection](http://mmcheng.net/rcfedge/)

This is the PyTorch implementation of our edge detection method, RCF.

### Citations

If you use the code, models, or data provided here in a publication, please consider citing:

    @article{liu2019richer,
      title={Richer Convolutional Features for Edge Detection},
      author={Liu, Yun and Cheng, Ming-Ming and Hu, Xiaowei and Bian, Jia-Wang and Zhang, Le and Bai, Xiang and Tang, Jinhui},
      journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
      volume={41},
      number={8},
      pages={1939--1946},
      year={2019},
      publisher={IEEE}
    }

    @article{liu2022semantic,
      title={Semantic edge detection with diverse deep supervision},
      author={Liu, Yun and Cheng, Ming-Ming and Fan, Deng-Ping and Zhang, Le and Bian, JiaWang and Tao, Dacheng},
      journal={International Journal of Computer Vision},
      volume={130},
      pages={179--198},
      year={2022},
      publisher={Springer}
    }

### Training

1. Clone the RCF repository (its root is referred to as `$ROOT_DIR` below):
```
git clone https://github.com/yun-liu/RCF-PyTorch.git
```

2. Download the ImageNet-pretrained model ([GitHub](https://github.com/yun-liu/RCF-PyTorch/releases/download/v1.0/vgg16convs.mat) or [Baidu Yun](https://pan.baidu.com/s/1vfntX-cTKnk58atNW5T1lA?pwd=g5af)), and put it into the `$ROOT_DIR` folder.

3. Download the datasets below, and extract them to the `$ROOT_DIR/data/` folder:

```
wget http://mftp.mmcheng.net/liuyun/rcf/data/bsds_pascal_train_pair.lst
wget http://mftp.mmcheng.net/liuyun/rcf/data/HED-BSDS.tar.gz
wget http://mftp.mmcheng.net/liuyun/rcf/data/PASCAL.tar.gz
```

4. Run the following command to start training:
```
python train.py --save-dir /path/to/output/directory/
```

### Testing

1. Download the pretrained model (trained on BSDS500+PASCAL: [GitHub](https://github.com/yun-liu/RCF-PyTorch/releases/download/v1.0/bsds500_pascal_model.pth) or [Baidu Yun](https://pan.baidu.com/s/1Tpf_-dIxHmKwH5IeClt0Ng?pwd=03ad)), and put it into the `$ROOT_DIR` folder.

2. Run the following command to start testing:
```
python test.py --checkpoint bsds500_pascal_model.pth --save-dir /path/to/output/directory/
```
This pretrained model should achieve an ODS F-measure of 0.812.

For more information about RCF and edge quality evaluation, please refer to this page: [yun-liu/RCF](https://github.com/yun-liu/RCF)

### Edge PR Curves

We have released the code and data for plotting the edge PR curves of many existing edge detectors [here](https://github.com/yun-liu/plot-edge-pr-curves).

### RCF in other frameworks

Caffe-based RCF: [yun-liu/RCF](https://github.com/yun-liu/RCF)

Jittor-based RCF: [yun-liu/RCF-Jittor](https://github.com/yun-liu/RCF-Jittor)

### Acknowledgements

[1] [balajiselvaraj1601/RCF_Pytorch_Updated](https://github.com/balajiselvaraj1601/RCF_Pytorch_Updated)

[2] [meteorshowers/RCF-pytorch](https://github.com/meteorshowers/RCF-pytorch)
--------------------------------------------------------------------------------
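
Before the scripts themselves, a single-image inference pass can be sketched as follows. This is hypothetical usage rather than repository code; it assumes a CUDA device, the pretrained checkpoint from the Testing section above, and a placeholder image file, and the BGR mean matches the one in `dataset.py`.

```python
import cv2
import numpy as np
import torch
from models import RCF

model = RCF().cuda()
model.load_state_dict(torch.load('bsds500_pascal_model.pth'))
model.eval()

img = cv2.imread('example.jpg').astype(np.float32)  # OpenCV loads BGR
img -= np.array([104.00698793, 116.66876762, 122.67891434], dtype=np.float32)
inp = torch.from_numpy(img.transpose((2, 0, 1))).unsqueeze(0).cuda()

with torch.no_grad():
    results = model(inp)  # five side outputs + the fused output, all sigmoids
edge = results[-1].squeeze().cpu().numpy()
cv2.imwrite('example_edge.png', ((1 - edge) * 255).astype(np.uint8))
```
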
/test.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import os.path as osp
import cv2
import argparse
import torch
from torch.utils.data import DataLoader
import torchvision
from dataset import BSDS_Dataset
from models import RCF


def single_scale_test(model, test_loader, test_list, save_dir):
    model.eval()
    if not osp.isdir(save_dir):
        os.makedirs(save_dir)
    for idx, image in enumerate(test_loader):
        image = image.cuda()
        _, _, H, W = image.shape
        with torch.no_grad():
            results = model(image)
        all_res = torch.zeros((len(results), 1, H, W))
        for i in range(len(results)):
            all_res[i, 0, :, :] = results[i]
        filename = osp.splitext(test_list[idx])[0]
        torchvision.utils.save_image(1 - all_res, osp.join(save_dir, '%s.jpg' % filename))
        fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
        fuse_res = ((1 - fuse_res) * 255).astype(np.uint8)
        cv2.imwrite(osp.join(save_dir, '%s_ss.png' % filename), fuse_res)
        #print('\rRunning single-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
    print('Running single-scale test done')


def multi_scale_test(model, test_loader, test_list, save_dir):
    model.eval()
    if not osp.isdir(save_dir):
        os.makedirs(save_dir)
    scale = [0.5, 1, 1.5]
    for idx, image in enumerate(test_loader):
        in_ = image[0].numpy().transpose((1, 2, 0))
        _, _, H, W = image.shape
        ms_fuse = np.zeros((H, W), np.float32)
        for k in range(len(scale)):
            im_ = cv2.resize(in_, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR)
            im_ = im_.transpose((2, 0, 1))
            with torch.no_grad():
                results = model(torch.unsqueeze(torch.from_numpy(im_).cuda(), 0))
            fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
            fuse_res = cv2.resize(fuse_res, (W, H), interpolation=cv2.INTER_LINEAR)
            ms_fuse += fuse_res
        # average the fused outputs over the three scales
        ms_fuse = ms_fuse / len(scale)
        ### rescale trick
        # ms_fuse = (ms_fuse - ms_fuse.min()) / (ms_fuse.max() - ms_fuse.min())
        filename = osp.splitext(test_list[idx])[0]
        ms_fuse = ((1 - ms_fuse) * 255).astype(np.uint8)
        cv2.imwrite(osp.join(save_dir, '%s_ms.png' % filename), ms_fuse)
        #print('\rRunning multi-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
    print('Running multi-scale test done')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Testing')
    parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
    parser.add_argument('--checkpoint', default=None, type=str, help='path to latest checkpoint')
    parser.add_argument('--save-dir', help='output folder', default='results/RCF')
    parser.add_argument('--dataset', help='root folder of dataset', default='data/HED-BSDS')
    args = parser.parse_args()

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not osp.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    test_dataset = BSDS_Dataset(root=args.dataset, split='test')
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, drop_last=False, shuffle=False)
    test_list = [osp.split(i.rstrip())[1] for i in test_dataset.file_list]
    assert len(test_list) == len(test_loader)

    model = RCF().cuda()

    if args.checkpoint is not None and osp.isfile(args.checkpoint):
        print("=> loading checkpoint from '{}'".format(args.checkpoint))
        checkpoint = torch.load(args.checkpoint)
        model.load_state_dict(checkpoint)
        print("=> checkpoint loaded")
    else:
        # testing with random weights is meaningless, so fail loudly
        raise FileNotFoundError("no checkpoint found at '{}'".format(args.checkpoint))

    print('Performing the testing...')
    single_scale_test(model, test_loader, test_list, args.save_dir)
    multi_scale_test(model, test_loader, test_list, args.save_dir)
--------------------------------------------------------------------------------
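
Note that every map is saved inverted (`1 - prediction`), so edges appear dark on a white background. A downstream consumer can recover edge probabilities from a saved map with a sketch like this (the filename is a placeholder):

```python
import cv2

edge_map = cv2.imread('example_ss.png', 0).astype('float32')  # grayscale, uint8
prob = 1.0 - edge_map / 255.0  # back to edge probabilities in [0, 1]
```
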
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import scipy.io as sio
import torch.nn.functional as F


class RCF(nn.Module):
    def __init__(self, pretrained=None):
        super(RCF, self).__init__()
        # VGG16 backbone; conv5_* use dilation 2 and pool4 below uses stride 1,
        # so stage 5 keeps the spatial resolution of stage 4
        self.conv1_1 = nn.Conv2d(  3,  64, 3, padding=1, dilation=1)
        self.conv1_2 = nn.Conv2d( 64,  64, 3, padding=1, dilation=1)
        self.conv2_1 = nn.Conv2d( 64, 128, 3, padding=1, dilation=1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1, dilation=1)
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1, dilation=1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1, dilation=1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.pool4 = nn.MaxPool2d(2, stride=1, ceil_mode=True)
        self.act = nn.ReLU(inplace=True)

        # 1x1 convolutions that reduce every backbone feature map to 21 channels
        self.conv1_1_down = nn.Conv2d( 64, 21, 1)
        self.conv1_2_down = nn.Conv2d( 64, 21, 1)
        self.conv2_1_down = nn.Conv2d(128, 21, 1)
        self.conv2_2_down = nn.Conv2d(128, 21, 1)
        self.conv3_1_down = nn.Conv2d(256, 21, 1)
        self.conv3_2_down = nn.Conv2d(256, 21, 1)
        self.conv3_3_down = nn.Conv2d(256, 21, 1)
        self.conv4_1_down = nn.Conv2d(512, 21, 1)
        self.conv4_2_down = nn.Conv2d(512, 21, 1)
        self.conv4_3_down = nn.Conv2d(512, 21, 1)
        self.conv5_1_down = nn.Conv2d(512, 21, 1)
        self.conv5_2_down = nn.Conv2d(512, 21, 1)
        self.conv5_3_down = nn.Conv2d(512, 21, 1)

        # 1x1 convolutions producing one side output per stage, plus a 1x1
        # convolution that fuses the five side outputs
        self.score_dsn1 = nn.Conv2d(21, 1, 1)
        self.score_dsn2 = nn.Conv2d(21, 1, 1)
        self.score_dsn3 = nn.Conv2d(21, 1, 1)
        self.score_dsn4 = nn.Conv2d(21, 1, 1)
        self.score_dsn5 = nn.Conv2d(21, 1, 1)
        self.score_fuse = nn.Conv2d(5, 1, 1)

        # Fixed (non-learned) bilinear upsampling kernels; building them with
        # .cuda() here assumes a CUDA device is available
        self.weight_deconv2 = self._make_bilinear_weights( 4, 1).cuda()
        self.weight_deconv3 = self._make_bilinear_weights( 8, 1).cuda()
        self.weight_deconv4 = self._make_bilinear_weights(16, 1).cuda()
        self.weight_deconv5 = self._make_bilinear_weights(16, 1).cuda()

        # init weights
        self.apply(self._init_weights)
        if pretrained is not None:
            vgg16 = sio.loadmat(pretrained)
            torch_params = self.state_dict()

            # Keys in the .mat file are named like 'conv1_1-weight'; map them
            # to the matching 'conv1_1.weight' entries of the state dict
            for k in vgg16.keys():
                name_par = k.split('-')
                size = len(name_par)
                if size == 2:
                    name_space = name_par[0] + '.' + name_par[1]
                    data = np.squeeze(vgg16[k])
                    torch_params[name_space] = torch.from_numpy(data)
            self.load_state_dict(torch_params)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.01)
            # The fusion layer is the only conv with weight shape [1, 5, 1, 1];
            # it starts as a uniform average of the five side outputs
            if m.weight.data.shape == torch.Size([1, 5, 1, 1]):
                nn.init.constant_(m.weight, 0.2)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    # Based on HED implementation @ https://github.com/xwjabc/hed
    def _make_bilinear_weights(self, size, num_channels):
        factor = (size + 1) // 2
        if size % 2 == 1:
            center = factor - 1
        else:
            center = factor - 0.5
        og = np.ogrid[:size, :size]
        filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
        filt = torch.from_numpy(filt)
        w = torch.zeros(num_channels, num_channels, size, size)
        w.requires_grad = False
        for i in range(num_channels):
            for j in range(num_channels):
                if i == j:
                    w[i, j] = filt
        return w
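
    # Size bookkeeping for forward() below (assuming an input height H that is
    # divisible by 8): with no padding, F.conv_transpose2d produces
    # out = (in - 1) * stride + kernel_size. Stage 2 (in = H/2, stride 2,
    # kernel 4) therefore yields H + 2 and is cropped with offset 1; stage 3
    # (stride 4, kernel 8) yields H + 4, cropped with offset 2; stage 4
    # (stride 8, kernel 16) yields H + 8, cropped with offset 4. Stage 5 shares
    # stage 4's resolution (pool4 has stride 1) and is cropped from offset 0.
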
    # Based on BDCN implementation @ https://github.com/pkuCactus/BDCN
    def _crop(self, data, img_h, img_w, crop_h, crop_w):
        _, _, h, w = data.size()
        assert(img_h <= h and img_w <= w)
        data = data[:, :, crop_h:crop_h + img_h, crop_w:crop_w + img_w]
        return data

    def forward(self, x):
        img_h, img_w = x.shape[2], x.shape[3]
        conv1_1 = self.act(self.conv1_1(x))
        conv1_2 = self.act(self.conv1_2(conv1_1))
        pool1 = self.pool1(conv1_2)
        conv2_1 = self.act(self.conv2_1(pool1))
        conv2_2 = self.act(self.conv2_2(conv2_1))
        pool2 = self.pool2(conv2_2)
        conv3_1 = self.act(self.conv3_1(pool2))
        conv3_2 = self.act(self.conv3_2(conv3_1))
        conv3_3 = self.act(self.conv3_3(conv3_2))
        pool3 = self.pool3(conv3_3)
        conv4_1 = self.act(self.conv4_1(pool3))
        conv4_2 = self.act(self.conv4_2(conv4_1))
        conv4_3 = self.act(self.conv4_3(conv4_2))
        pool4 = self.pool4(conv4_3)
        conv5_1 = self.act(self.conv5_1(pool4))
        conv5_2 = self.act(self.conv5_2(conv5_1))
        conv5_3 = self.act(self.conv5_3(conv5_2))

        conv1_1_down = self.conv1_1_down(conv1_1)
        conv1_2_down = self.conv1_2_down(conv1_2)
        conv2_1_down = self.conv2_1_down(conv2_1)
        conv2_2_down = self.conv2_2_down(conv2_2)
        conv3_1_down = self.conv3_1_down(conv3_1)
        conv3_2_down = self.conv3_2_down(conv3_2)
        conv3_3_down = self.conv3_3_down(conv3_3)
        conv4_1_down = self.conv4_1_down(conv4_1)
        conv4_2_down = self.conv4_2_down(conv4_2)
        conv4_3_down = self.conv4_3_down(conv4_3)
        conv5_1_down = self.conv5_1_down(conv5_1)
        conv5_2_down = self.conv5_2_down(conv5_2)
        conv5_3_down = self.conv5_3_down(conv5_3)

        out1 = self.score_dsn1(conv1_1_down + conv1_2_down)
        out2 = self.score_dsn2(conv2_1_down + conv2_2_down)
        out3 = self.score_dsn3(conv3_1_down + conv3_2_down + conv3_3_down)
        out4 = self.score_dsn4(conv4_1_down + conv4_2_down + conv4_3_down)
        out5 = self.score_dsn5(conv5_1_down + conv5_2_down + conv5_3_down)

        # Stage 1 is already at the input resolution; the deeper side outputs
        # are upsampled with the fixed bilinear kernels (out5 uses stride 8,
        # like out4, because pool4 has stride 1)
        out2 = F.conv_transpose2d(out2, self.weight_deconv2, stride=2)
        out3 = F.conv_transpose2d(out3, self.weight_deconv3, stride=4)
        out4 = F.conv_transpose2d(out4, self.weight_deconv4, stride=8)
        out5 = F.conv_transpose2d(out5, self.weight_deconv5, stride=8)

        # Crop the upsampled maps back to the input size (see the size
        # bookkeeping note above)
        out2 = self._crop(out2, img_h, img_w, 1, 1)
        out3 = self._crop(out3, img_h, img_w, 2, 2)
        out4 = self._crop(out4, img_h, img_w, 4, 4)
        out5 = self._crop(out5, img_h, img_w, 0, 0)

        fuse = torch.cat((out1, out2, out3, out4, out5), dim=1)
        fuse = self.score_fuse(fuse)
        results = [out1, out2, out3, out4, out5, fuse]
        results = [torch.sigmoid(r) for r in results]
        return results
--------------------------------------------------------------------------------
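
As a quick sanity check on the shapes, a hypothetical snippet (not repository code; it assumes a CUDA device because `RCF.__init__` builds its upsampling kernels with `.cuda()`, and uses an input whose sides are divisible by 8 to keep the crop arithmetic exact) would show all six outputs at the input resolution:

```python
import torch
from models import RCF

model = RCF().cuda().eval()
with torch.no_grad():
    outputs = model(torch.randn(1, 3, 320, 480).cuda())
print(len(outputs))                       # 6: five side outputs + the fused map
print([tuple(o.shape) for o in outputs])  # each is (1, 1, 320, 480)
```
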
/train.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import os.path as osp
import cv2
import argparse
import time
import torch
from torch.utils.data import DataLoader
import torchvision
from dataset import BSDS_Dataset
from models import RCF
from utils import Logger, AverageValue, Cross_entropy_loss


def train(args, model, train_loader, optimizer, epoch, logger):
    batch_time = AverageValue()
    losses = AverageValue()
    model.train()
    end = time.time()
    counter = 0
    for i, (image, label) in enumerate(train_loader):
        image, label = image.cuda(), label.cuda()
        outputs = model(image)
        # Deep supervision: the loss is applied to all five side outputs and
        # to the fused output
        loss = torch.zeros(1).cuda()
        for o in outputs:
            loss = loss + Cross_entropy_loss(o, label)
        counter += 1
        loss = loss / args.iter_size
        loss.backward()
        # Gradient accumulation: step the optimizer only every iter_size batches
        if counter == args.iter_size:
            optimizer.step()
            optimizer.zero_grad()
            counter = 0
        # measure elapsed time and record the loss
        losses.update(loss.item(), image.size(0))
        batch_time.update(time.time() - end)
        if i % args.print_freq == 0:
            logger.info('Epoch: [{0}/{1}][{2}/{3}] '.format(epoch + 1, args.max_epoch, i, len(train_loader)) + \
                'Time {batch_time.val:.3f} (avg: {batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                'Loss {loss.val:f} (avg: {loss.avg:f}) '.format(loss=losses))
        end = time.time()
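

# Note on the accumulation in train(): with the defaults (batch_size=1,
# iter_size=10), gradients from 10 images are summed before each optimizer
# step, so the effective batch size is batch_size * iter_size = 10; dividing
# each loss by iter_size keeps the accumulated gradient on the scale of an
# averaged mini-batch.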


def single_scale_test(model, test_loader, test_list, save_dir):
    # `logger` is the module-level Logger created in the main block below
    model.eval()
    if not osp.isdir(save_dir):
        os.makedirs(save_dir)
    for idx, image in enumerate(test_loader):
        image = image.cuda()
        _, _, H, W = image.shape
        with torch.no_grad():
            results = model(image)
        all_res = torch.zeros((len(results), 1, H, W))
        for i in range(len(results)):
            all_res[i, 0, :, :] = results[i]
        filename = osp.splitext(test_list[idx])[0]
        torchvision.utils.save_image(1 - all_res, osp.join(save_dir, '%s.jpg' % filename))
        fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
        fuse_res = ((1 - fuse_res) * 255).astype(np.uint8)
        cv2.imwrite(osp.join(save_dir, '%s_ss.png' % filename), fuse_res)
        #print('\rRunning single-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
    logger.info('Running single-scale test done')


def multi_scale_test(model, test_loader, test_list, save_dir):
    model.eval()
    if not osp.isdir(save_dir):
        os.makedirs(save_dir)
    scale = [0.5, 1, 1.5]
    for idx, image in enumerate(test_loader):
        in_ = image[0].numpy().transpose((1, 2, 0))
        _, _, H, W = image.shape
        ms_fuse = np.zeros((H, W), np.float32)
        for k in range(len(scale)):
            im_ = cv2.resize(in_, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR)
            im_ = im_.transpose((2, 0, 1))
            with torch.no_grad():
                results = model(torch.unsqueeze(torch.from_numpy(im_).cuda(), 0))
            fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
            fuse_res = cv2.resize(fuse_res, (W, H), interpolation=cv2.INTER_LINEAR)
            ms_fuse += fuse_res
        # average the fused outputs over the three scales
        ms_fuse = ms_fuse / len(scale)
        ### rescale trick
        # ms_fuse = (ms_fuse - ms_fuse.min()) / (ms_fuse.max() - ms_fuse.min())
        filename = osp.splitext(test_list[idx])[0]
        ms_fuse = ((1 - ms_fuse) * 255).astype(np.uint8)
        cv2.imwrite(osp.join(save_dir, '%s_ms.png' % filename), ms_fuse)
        #print('\rRunning multi-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
    logger.info('Running multi-scale test done')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Training')
    parser.add_argument('--batch-size', default=1, type=int, help='batch size')
    parser.add_argument('--lr', default=1e-6, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight-decay', default=2e-4, type=float, help='weight decay')
    parser.add_argument('--stepsize', default=3, type=int, help='learning rate step size')
    parser.add_argument('--gamma', default=0.1, type=float, help='learning rate decay rate')
    parser.add_argument('--max-epoch', default=10, type=int, help='the number of training epochs')
    parser.add_argument('--iter-size', default=10, type=int, help='iter size')
    parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number')
    parser.add_argument('--print-freq', default=200, type=int, help='print frequency')
    parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
    parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint')
    parser.add_argument('--save-dir', help='output folder', default='results/RCF')
    parser.add_argument('--dataset', help='root folder of dataset', default='data')
    args = parser.parse_args()

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not osp.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    logger = Logger(osp.join(args.save_dir, 'log.txt'))
    logger.info('Called with args:')
    for (key, value) in vars(args).items():
        logger.info('{0:15} | {1}'.format(key, value))

    train_dataset = BSDS_Dataset(root=args.dataset, split='train')
    test_dataset = BSDS_Dataset(root=osp.join(args.dataset, 'HED-BSDS'), split='test')
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, drop_last=True, shuffle=True)
    # the test functions assume one image per batch
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=4, drop_last=False, shuffle=False)
    test_list = [osp.split(i.rstrip())[1] for i in test_dataset.file_list]
    assert len(test_list) == len(test_loader)

    model = RCF(pretrained='vgg16convs.mat').cuda()

    # Group the parameters so that each group can receive its own learning
    # rate and weight decay in the optimizer below
    parameters = {'conv1-4.weight': [], 'conv1-4.bias': [], 'conv5.weight': [], 'conv5.bias': [],
                  'conv_down_1-5.weight': [], 'conv_down_1-5.bias': [], 'score_dsn_1-5.weight': [],
                  'score_dsn_1-5.bias': [], 'score_fuse.weight': [], 'score_fuse.bias': []}
    for pname, p in model.named_parameters():
        if pname in ['conv1_1.weight', 'conv1_2.weight',
                     'conv2_1.weight', 'conv2_2.weight',
                     'conv3_1.weight', 'conv3_2.weight', 'conv3_3.weight',
                     'conv4_1.weight', 'conv4_2.weight', 'conv4_3.weight']:
            parameters['conv1-4.weight'].append(p)
        elif pname in ['conv1_1.bias', 'conv1_2.bias',
                       'conv2_1.bias', 'conv2_2.bias',
                       'conv3_1.bias', 'conv3_2.bias', 'conv3_3.bias',
                       'conv4_1.bias', 'conv4_2.bias', 'conv4_3.bias']:
            parameters['conv1-4.bias'].append(p)
        elif pname in ['conv5_1.weight', 'conv5_2.weight', 'conv5_3.weight']:
            parameters['conv5.weight'].append(p)
        elif pname in ['conv5_1.bias', 'conv5_2.bias', 'conv5_3.bias']:
            parameters['conv5.bias'].append(p)
        elif pname in ['conv1_1_down.weight', 'conv1_2_down.weight',
                       'conv2_1_down.weight', 'conv2_2_down.weight',
                       'conv3_1_down.weight', 'conv3_2_down.weight', 'conv3_3_down.weight',
                       'conv4_1_down.weight', 'conv4_2_down.weight', 'conv4_3_down.weight',
                       'conv5_1_down.weight', 'conv5_2_down.weight', 'conv5_3_down.weight']:
            parameters['conv_down_1-5.weight'].append(p)
        elif pname in ['conv1_1_down.bias', 'conv1_2_down.bias',
                       'conv2_1_down.bias', 'conv2_2_down.bias',
                       'conv3_1_down.bias', 'conv3_2_down.bias', 'conv3_3_down.bias',
                       'conv4_1_down.bias', 'conv4_2_down.bias', 'conv4_3_down.bias',
                       'conv5_1_down.bias', 'conv5_2_down.bias', 'conv5_3_down.bias']:
            parameters['conv_down_1-5.bias'].append(p)
        elif pname in ['score_dsn1.weight', 'score_dsn2.weight', 'score_dsn3.weight',
                       'score_dsn4.weight', 'score_dsn5.weight']:
            parameters['score_dsn_1-5.weight'].append(p)
        elif pname in ['score_dsn1.bias', 'score_dsn2.bias', 'score_dsn3.bias',
                       'score_dsn4.bias', 'score_dsn5.bias']:
            parameters['score_dsn_1-5.bias'].append(p)
        elif pname in ['score_fuse.weight']:
            parameters['score_fuse.weight'].append(p)
        elif pname in ['score_fuse.bias']:
            parameters['score_fuse.bias'].append(p)
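
    # Per-group learning rates, apparently mirroring the lr_mult settings of
    # the original Caffe RCF: each bias trains at twice the rate of its
    # weight, the fine-tuned backbone (conv1-4) uses the base rate, the
    # randomly initialized conv5 stage uses a 100x multiplier, and the score
    # layers use strongly reduced multipliers.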
    optimizer = torch.optim.SGD([
        {'params': parameters['conv1-4.weight'], 'lr': args.lr*1, 'weight_decay': args.weight_decay},
        {'params': parameters['conv1-4.bias'], 'lr': args.lr*2, 'weight_decay': 0.},
        {'params': parameters['conv5.weight'], 'lr': args.lr*100, 'weight_decay': args.weight_decay},
        {'params': parameters['conv5.bias'], 'lr': args.lr*200, 'weight_decay': 0.},
        {'params': parameters['conv_down_1-5.weight'], 'lr': args.lr*0.1, 'weight_decay': args.weight_decay},
        {'params': parameters['conv_down_1-5.bias'], 'lr': args.lr*0.2, 'weight_decay': 0.},
        {'params': parameters['score_dsn_1-5.weight'], 'lr': args.lr*0.01, 'weight_decay': args.weight_decay},
        {'params': parameters['score_dsn_1-5.bias'], 'lr': args.lr*0.02, 'weight_decay': 0.},
        {'params': parameters['score_fuse.weight'], 'lr': args.lr*0.001, 'weight_decay': args.weight_decay},
        {'params': parameters['score_fuse.bias'], 'lr': args.lr*0.002, 'weight_decay': 0.},
    ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)

    if args.resume is not None:
        if osp.isfile(args.resume):
            logger.info("=> loading checkpoint from '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            logger.info("=> checkpoint loaded")
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    for epoch in range(args.start_epoch, args.max_epoch):
        logger.info('Starting epoch %d/%d...' % (epoch + 1, args.max_epoch))
        train(args, model, train_loader, optimizer, epoch, logger)
        save_dir = osp.join(args.save_dir, 'epoch%d-test' % (epoch + 1))
        single_scale_test(model, test_loader, test_list, save_dir)
        multi_scale_test(model, test_loader, test_list, save_dir)
        # Save checkpoint
        save_file = osp.join(args.save_dir, 'checkpoint_epoch{}.pth'.format(epoch + 1))
        torch.save({
            'epoch': epoch,
            'args': args,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
        }, save_file)
        lr_scheduler.step()  # decay the learning rate every args.stepsize epochs

    logger.close()
--------------------------------------------------------------------------------