├── README.md ├── data ├── __init__.py ├── dataset.py └── transform.py ├── lib ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── data_prefetcher.cpython-37.pyc │ ├── data_prefetcher_gray.cpython-37.pyc │ └── lr_finder.cpython-37.pyc ├── data_prefetcher.py ├── dataset.py ├── lr_finder.py └── transform.py ├── lscloss.py ├── net_agg.py ├── requirements.txt ├── test.py ├── tools.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # SCWSSOD 2 | This is the implementation of `Structure-Consistent Weakly Supervised Salient Object Detection with Local Saliency Coherence (AAAI2021)`. 3 | # Training 4 | ## Requirements 5 | 1. Clone this project and install PyTorch first. 6 | 2. pip install -r requirements.txt 7 | ## Training data 8 | The training data can be downloaded from [Scribble_Saliency](https://github.com/JingZhang617/Scribble_Saliency). 9 | ## Pretrained weights for backbone 10 | The pretrained weights for the backbone can be downloaded from [Res50](https://drive.google.com/file/d/1arzcXccUPW1QpvBrAaaBv1CapviBQAJL/view?usp=sharing). 11 | ## Training procedure 12 | 1. Download the training data and put it into the 'data' folder. 13 | 2. Run train.py 14 | # Testing 15 | ## Test Model 16 | The test model can be downloaded from [model](https://drive.google.com/file/d/1X8Y7NcnzRY8we2tgDS6KRVOde5ij7yWE/view?usp=sharing). 17 | ## Testing procedure 18 | 1. Modify the test path. 19 | 2. Run test.py 20 | ## Predicted Maps 21 | The predicted saliency maps can be downloaded from [prediction](https://drive.google.com/file/d/1a_Hrl0YhMNdNsskKLrZ7JJhZwJtxwyqs/view?usp=sharing). 22 | ## Evaluation 23 | The evaluation code is from [GCPANet](https://github.com/JosephChenHub/GCPANet). 24 | # Others 25 | The code is based on [GCPANet](https://github.com/JosephChenHub/GCPANet) and [GatedCRFLoss](https://github.com/LEONOB2014/GatedCRFLoss). 
# --------------------------------------------------------------------------------
# /data/__init__.py:
# https://raw.githubusercontent.com/siyueyu/SCWSSOD/f8650567cbbc8df5bf6edc32a633c47a885574cd/data/__init__.py
# --------------------------------------------------------------------------------
# /data/dataset.py:
#!/usr/bin/python3
#coding=utf-8

import os
import os.path as osp
import cv2
import torch
import numpy as np
try:
    from . import transform
except ImportError:
    # Fallback for running this file directly as a script.
    import transform

from torch.utils.data import Dataset, DataLoader
from lib.data_prefetcher import DataPrefetcher


# Per-dataset (name, mean, std) channel statistics on the 0-255 scale, in RGB
# order (images are flipped from OpenCV's BGR in Data.__getitem__). Entries are
# tested in this order and the first name that is a substring of
# cfg.datapath wins, so keep the order unchanged.
# NOTE(review): 'SOD' also matches any path merely containing those letters
# (e.g. a parent directory named 'SCWSSOD'); confirm datapaths are unambiguous.
_DATASET_STATS = (
    ('ECSSD',     (117.15, 112.48,  92.86), (56.36, 53.82, 54.23)),
    ('DUTS',      (124.55, 118.90, 102.94), (56.77, 55.97, 57.50)),
    ('DUT-OMRON', (120.61, 121.86, 114.92), (58.10, 57.16, 61.09)),
    ('MSRA-10K',  (115.57, 110.48, 100.00), (57.55, 54.89, 55.30)),
    ('MSRA-B',    (114.87, 110.47,  95.76), (58.12, 55.30, 55.82)),
    ('SED2',      (126.34, 133.87, 133.72), (45.88, 45.59, 48.13)),
    ('PASCAL-S',  (117.02, 112.75, 102.48), (59.81, 58.96, 60.44)),
    ('HKU-IS',    (123.58, 121.69, 104.22), (55.40, 53.55, 55.19)),
    ('SOD',       (109.91, 112.13,  93.90), (53.29, 50.45, 48.06)),
    ('THUR15K',   (122.60, 120.28, 104.46), (55.99, 55.39, 56.97)),
    ('SOC',       (120.48, 111.78, 101.27), (58.51, 56.73, 56.38)),
)

# Fallback when no known dataset name appears in datapath: ImageNet statistics.
# NOTE(review): scaled by 256 rather than 255; left as-is because changing it
# would alter normalization for any checkpoint trained with these values.
_DEFAULT_MEAN = (0.485*256, 0.456*256, 0.406*256)
_DEFAULT_STD = (0.229*256, 0.224*256, 0.225*256)


class Config(object):
    """Keyword settings bag plus dataset-specific normalization statistics.

    Every keyword passed to the constructor is readable as an attribute;
    unknown (non-dunder) attributes resolve to ``None`` instead of raising.
    ``datapath`` is required: it selects ``self.mean`` / ``self.std``, each an
    array of shape (1, 1, 3).
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        print('\nParameters...')
        for k, v in self.kwargs.items():
            print('%-10s: %s'%(k, v))

        datapath = self.kwargs['datapath']
        for name, mean, std in _DATASET_STATS:
            if name in datapath:
                break
        else:
            mean, std = _DEFAULT_MEAN, _DEFAULT_STD
        self.mean = np.array([[mean]])
        self.std = np.array([[std]])

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        # Dunder probes (__setstate__, __deepcopy__, ...) must raise so that
        # copy/pickle protocols behave normally. 'kwargs' is read through
        # __dict__ so a half-constructed instance (created by copy/unpickling
        # before __init__ runs) cannot recurse forever.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self.__dict__.get('kwargs', {}).get(name)


def _imread_rgb(path):
    """Read ``path`` with OpenCV and return it as float32 in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns
            None instead of raising).
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError('cannot read image: %s' % path)
    return img.astype(np.float32)[:, :, ::-1]


class Data(Dataset):
    """Scribble-supervised saliency dataset rooted at ``cfg.datapath``.

    ``<datapath>/<mode>.txt`` lists one sample name per line. Images are read
    from ``<datapath>/image/<name>.jpg`` and scribble masks from
    ``<datapath>/scribble/<name>.png``. ``cfg.mode`` must be 'train' or 'test'.
    """

    def __init__(self, cfg):
        with open(cfg.datapath+'/'+cfg.mode+'.txt', 'r') as lines:
            self.samples = []
            for line in lines:
                name = line.strip()
                if not name:
                    # Skip blank lines (e.g. a trailing newline) rather than
                    # producing a bogus '<datapath>/image/.jpg' sample.
                    continue
                imagepath = cfg.datapath + '/image/' + name + '.jpg'
                maskpath = cfg.datapath + '/scribble/' + name + '.png'
                self.samples.append([imagepath, maskpath])

        if cfg.mode == 'train':
            # NOTE(review): RandomCrop(320, 320) right after Resize(320, 320)
            # always selects the full image, so it is effectively a no-op.
            self.transform = transform.Compose(transform.Normalize(mean=cfg.mean, std=cfg.std),
                                               transform.Resize(320, 320),
                                               transform.RandomHorizontalFlip(),
                                               transform.RandomCrop(320, 320),
                                               transform.ToTensor())
        elif cfg.mode == 'test':
            self.transform = transform.Compose(transform.Normalize(mean=cfg.mean, std=cfg.std),
                                               transform.Resize(320, 320),
                                               transform.ToTensor())
        else:
            raise ValueError('unsupported mode: %r' % (cfg.mode,))

    def __getitem__(self, idx):
        """Return ``(image, mask, (H, W), mask_filename)``.

        ``(H, W)`` is the mask size before any transform is applied.
        """
        imagepath, maskpath = self.samples[idx]
        image = _imread_rgb(imagepath)
        mask = _imread_rgb(maskpath)
        H, W = mask.shape[:2]
        image, mask = self.transform(image, mask)
        # Remap scribble labels: 0 -> 255 first, then 2 -> 0. The order matters;
        # swapping it would turn the original 2s into 255.
        # NOTE(review): 255 presumably marks unlabeled pixels ignored by the
        # loss; and since Resize is bilinear, blended border values match
        # neither 0 nor 2. Confirm both against the training code.
        mask[mask == 0.] = 255.
        mask[mask == 2.] = 0.
        return image, mask, (H, W), os.path.basename(maskpath)

    def __len__(self):
        return len(self.samples)


if __name__=='__main__':
    import matplotlib.pyplot as plt
    plt.ion()

    cfg = Config(mode='train', datapath='./DUTS')
    data = Data(cfg)
    loader = DataLoader(data, batch_size=1, shuffle=True, num_workers=8)
    prefetcher = DataPrefetcher(loader)
    image, mask = prefetcher.next()
    image = image[0].permute(1,2,0).cpu().numpy()*cfg.std + cfg.mean
    # The batched mask is (N, 1, H, W) because ToTensor keeps a channel dim;
    # imshow needs a 2-D array.
    mask = mask[0, 0].cpu().numpy()
    plt.subplot(121)
    plt.imshow(np.uint8(image))
    plt.subplot(122)
    plt.imshow(mask)
    input()

# --------------------------------------------------------------------------------
# /data/transform.py:
#!/usr/bin/python3
#coding=utf-8

import cv2
import torch
import numpy as np


class Compose(object):
    """Apply ``(image, mask) -> (image, mask)`` ops in the given order."""

    def __init__(self, *ops):
        self.ops = ops

    def __call__(self, image, mask):
        for op in self.ops:
            image, mask = op(image, mask)
        return image, mask


class RGBDCompose(object):
    """Like Compose, but each op takes and returns ``(image, depth, mask)``."""

    def __init__(self, *ops):
        self.ops = ops

    def __call__(self, image, depth, mask):
        for op in self.ops:
            image, depth, mask = op(image, depth, mask)
        return image, depth, mask


class Normalize(object):
    """Standardize the image with per-channel mean/std; pass the mask through.

    The mask is deliberately left unscaled: Data.__getitem__ in this package
    compares raw label values (0 and 2) after the transform.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, mask):
        image = (image - self.mean)/self.std
        return image, mask


class RGBDNormalize(object):
    """Standardize image and depth; scale the mask from 0-255 to 0-1."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, depth, mask):
        image = (image - self.mean)/self.std
        # NOTE(review): depth is normalized with the RGB statistics; confirm
        # that is intended.
        depth = (depth - self.mean)/self.std
        mask /= 255  # in place: mutates the caller's array
        # Return all three values so RGBDCompose can unpack the result.
        return image, depth, mask


class Resize(object):
    """Resize image and mask to ``H x W`` with bilinear interpolation."""

    def __init__(self, H, W):
        self.H = H
        self.W = W

    def __call__(self, image, mask):
        # cv2.resize takes dsize as (width, height).
        image = cv2.resize(image, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR)
        # NOTE(review): bilinear resizing blends discrete label values at region
        # borders; INTER_NEAREST may be intended for label masks.
        mask = cv2.resize(mask, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR)
        return image, mask


class RandomCrop(object):
    """Crop the same random ``H x W`` window from image and mask."""

    def __init__(self, H, W):
        self.H = H
        self.W = W

    def __call__(self, image, mask):
        H, W, _ = image.shape
        xmin = np.random.randint(W-self.W+1)
        ymin = np.random.randint(H-self.H+1)
        image = image[ymin:ymin+self.H, xmin:xmin+self.W, :]
        mask = mask[ymin:ymin+self.H, xmin:xmin+self.W, :]
        return image, mask


class RandomHorizontalFlip(object):
    """Flip image and mask left-right together with probability 0.5."""

    def __call__(self, image, mask):
        if np.random.randint(2)==1:
            # .copy() gives contiguous arrays; torch.from_numpy rejects the
            # negative strides a reversed view would have.
            image = image[:,::-1,:].copy()
            mask = mask[:,::-1,:].copy()
        return image, mask


class ToTensor(object):
    """Convert HWC arrays to CHW tensors; average mask channels to (1, H, W)."""

    def __call__(self, image, mask):
        image = torch.from_numpy(image)
        image = image.permute(2, 0, 1)
        mask = torch.from_numpy(mask)
        mask = mask.permute(2, 0, 1)
        return image, mask.mean(dim=0, keepdim=True)

# --------------------------------------------------------------------------------
# /lib/__init__.py:
# https://raw.githubusercontent.com/siyueyu/SCWSSOD/f8650567cbbc8df5bf6edc32a633c47a885574cd/lib/__init__.py
# --------------------------------------------------------------------------------
# /lib/__pycache__/__init__.cpython-37.pyc:
# https://raw.githubusercontent.com/siyueyu/SCWSSOD/f8650567cbbc8df5bf6edc32a633c47a885574cd/lib/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/__pycache__/data_prefetcher.cpython-37.pyc:
# https://raw.githubusercontent.com/siyueyu/SCWSSOD/f8650567cbbc8df5bf6edc32a633c47a885574cd/lib/__pycache__/data_prefetcher.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/__pycache__/data_prefetcher_gray.cpython-37.pyc:
# https://raw.githubusercontent.com/siyueyu/SCWSSOD/f8650567cbbc8df5bf6edc32a633c47a885574cd/lib/__pycache__/data_prefetcher_gray.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/__pycache__/lr_finder.cpython-37.pyc:
# https://raw.githubusercontent.com/siyueyu/SCWSSOD/f8650567cbbc8df5bf6edc32a633c47a885574cd/lib/__pycache__/lr_finder.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/data_prefetcher.py:
import torch


class DataPrefetcher(object):
    """Overlap the host-to-GPU copy of the next batch with current work.

    Wraps a DataLoader whose batches unpack into four items; only the first
    two (input, target) are kept. They are moved to CUDA on a side stream and
    cast to float. Requires a CUDA device.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next batch and start its async copy; store ``None`` at the end."""
        try:
            self.next_input, self.next_target, _, _ = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return

        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True).float()
            self.next_target = self.next_target.cuda(non_blocking=True).float()

    def next(self):
        """Return the prefetched ``(input, target)``, or ``(None, None)`` once exhausted."""
        current = torch.cuda.current_stream()
        current.wait_stream(self.stream)
        batch_input = self.next_input
        batch_target = self.next_target
        # These tensors were allocated on the side stream. Record their use on
        # the consuming stream so the caching allocator does not hand their
        # memory to the next preload() while they are still in use.
        if batch_input is not None:
            batch_input.record_stream(current)
        if batch_target is not None:
            batch_target.record_stream(current)
        self.preload()
        return batch_input, batch_target

# --------------------------------------------------------------------------------
# /lib/dataset.py:
#!/usr/bin/python3
#coding=utf-8

import os
import os.path as osp
import cv2
import torch
import numpy as np
try:
    from . import transform
except ImportError:
    # Fallback for running this file directly as a script.
    import transform

from torch.utils.data import Dataset


# Per-dataset (name, mean, std) channel statistics on the 0-255 scale, in RGB
# order (images are flipped from OpenCV's BGR in Data.__getitem__). Entries are
# tested in this order and the first name that is a substring of
# cfg.datapath wins, so keep the order unchanged.
# NOTE(review): 'SOD' also matches any path merely containing those letters
# (e.g. a parent directory named 'SCWSSOD'); confirm datapaths are unambiguous.
_DATASET_STATS = (
    ('ECSSD',     (117.15, 112.48,  92.86), (56.36, 53.82, 54.23)),
    ('DUTS',      (124.55, 118.90, 102.94), (56.77, 55.97, 57.50)),
    ('DUT-OMRON', (120.61, 121.86, 114.92), (58.10, 57.16, 61.09)),
    ('MSRA-10K',  (115.57, 110.48, 100.00), (57.55, 54.89, 55.30)),
    ('MSRA-B',    (114.87, 110.47,  95.76), (58.12, 55.30, 55.82)),
    ('SED2',      (126.34, 133.87, 133.72), (45.88, 45.59, 48.13)),
    ('PASCAL-S',  (117.02, 112.75, 102.48), (59.81, 58.96, 60.44)),
    ('HKU-IS',    (123.58, 121.69, 104.22), (55.40, 53.55, 55.19)),
    ('SOD',       (109.91, 112.13,  93.90), (53.29, 50.45, 48.06)),
    ('THUR15K',   (122.60, 120.28, 104.46), (55.99, 55.39, 56.97)),
    ('SOC',       (120.48, 111.78, 101.27), (58.51, 56.73, 56.38)),
)

# Fallback when no known dataset name appears in datapath: ImageNet statistics.
# NOTE(review): scaled by 256 rather than 255; left as-is because changing it
# would alter normalization for any checkpoint trained with these values.
_DEFAULT_MEAN = (0.485*256, 0.456*256, 0.406*256)
_DEFAULT_STD = (0.229*256, 0.224*256, 0.225*256)


class Config(object):
    """Keyword settings bag plus dataset-specific normalization statistics.

    Every keyword passed to the constructor is readable as an attribute;
    unknown (non-dunder) attributes resolve to ``None`` instead of raising.
    ``datapath`` is required: it selects ``self.mean`` / ``self.std``, each an
    array of shape (1, 1, 3).
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        print('\nParameters...')
        for k, v in self.kwargs.items():
            print('%-10s: %s'%(k, v))

        datapath = self.kwargs['datapath']
        for name, mean, std in _DATASET_STATS:
            if name in datapath:
                break
        else:
            mean, std = _DEFAULT_MEAN, _DEFAULT_STD
        self.mean = np.array([[mean]])
        self.std = np.array([[std]])

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        # Dunder probes (__setstate__, __deepcopy__, ...) must raise so that
        # copy/pickle protocols behave normally. 'kwargs' is read through
        # __dict__ so a half-constructed instance (created by copy/unpickling
        # before __init__ runs) cannot recurse forever.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self.__dict__.get('kwargs', {}).get(name)


def _imread_rgb(path):
    """Read ``path`` with OpenCV and return it as float32 in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns
            None instead of raising).
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError('cannot read image: %s' % path)
    return img.astype(np.float32)[:, :, ::-1]


class Data(Dataset):
    """Saliency dataset with full ground-truth masks rooted at ``cfg.datapath``.

    ``<datapath>/<mode>.txt`` lists one sample name per line. Images are read
    from ``<datapath>/image/<name>.jpg`` and masks from
    ``<datapath>/mask/<name>.png``. ``cfg.mode`` must be 'train' or 'test'.
    """

    def __init__(self, cfg):
        with open(os.path.join(cfg.datapath, cfg.mode + '.txt'), 'r') as lines:
            self.samples = []
            for line in lines:
                name = line.strip()
                if not name:
                    # Skip blank lines (e.g. a trailing newline) rather than
                    # producing a bogus '<datapath>/image/.jpg' sample.
                    continue
                imagepath = os.path.join(cfg.datapath, 'image', name + '.jpg')
                maskpath = os.path.join(cfg.datapath, 'mask', name + '.png')
                self.samples.append([imagepath, maskpath])

        if cfg.mode == 'train':
            self.transform = transform.Compose(transform.Normalize(mean=cfg.mean, std=cfg.std),
                                               transform.Resize(320, 320),
                                               transform.RandomHorizontalFlip(),
                                               transform.RandomCrop(288, 288),
                                               transform.ToTensor())
        elif cfg.mode == 'test':
            self.transform = transform.Compose(transform.Normalize(mean=cfg.mean, std=cfg.std),
                                               transform.Resize(320, 320),
                                               transform.ToTensor())
        else:
            raise ValueError('unsupported mode: %r' % (cfg.mode,))

    def __getitem__(self, idx):
        """Return ``(image, mask, (H, W), mask_filename)``.

        ``(H, W)`` is the mask size before any transform is applied.
        """
        imagepath, maskpath = self.samples[idx]
        image = _imread_rgb(imagepath)
        mask = _imread_rgb(maskpath)
        H, W = mask.shape[:2]
        image, mask = self.transform(image, mask)
        # maskpath is built with os.path.join, so its separator is
        # platform-specific; basename handles it where split('/') would not
        # on Windows.
        return image, mask, (H, W), os.path.basename(maskpath)

    def __len__(self):
        return len(self.samples)
| 102 | 103 | if __name__=='__main__': 104 | import matplotlib.pyplot as plt 105 | plt.ion() 106 | 107 | cfg = Config(mode='train', datapath='./data/DUTS') 108 | data = Data(cfg) 109 | for i in range(100): 110 | image, depth, mask = data[i] 111 | image = image.permute(1,2,0).numpy()*cfg.std + cfg.mean 112 | mask = mask.numpy() 113 | plt.subplot(121) 114 | plt.imshow(np.uint8(image)) 115 | plt.subplot(122) 116 | plt.imshow(mask) 117 | input() 118 | -------------------------------------------------------------------------------- /lib/lr_finder.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, with_statement, division 2 | import copy 3 | import os 4 | import torch 5 | from tqdm import tqdm 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | import matplotlib.pyplot as plt 8 | from lib.data_prefetcher import DataPrefetcher 9 | import torch.nn.functional as F 10 | 11 | 12 | class LRFinder(object): 13 | """Learning rate range test. 14 | 15 | The learning rate range test increases the learning rate in a pre-training run 16 | between two boundaries in a linear or exponential manner. It provides valuable 17 | information on how well the network can be trained over a range of learning rates 18 | and what is the optimal learning rate. 19 | 20 | Arguments: 21 | model (torch.nn.Module): wrapped model. 22 | optimizer (torch.optim.Optimizer): wrapped optimizer where the defined learning 23 | is assumed to be the lower boundary of the range test. 24 | criterion (torch.nn.Module): wrapped loss function. 25 | device (str or torch.device, optional): a string ("cpu" or "cuda") with an 26 | optional ordinal for the device type (e.g. "cuda:X", where is the ordinal). 27 | Alternatively, can be an object representing the device on which the 28 | computation will take place. Default: None, uses the same device as `model`. 
29 | memory_cache (boolean): if this flag is set to True, `state_dict` of model and 30 | optimizer will be cached in memory. Otherwise, they will be saved to files 31 | under the `cache_dir`. 32 | cache_dir (string): path for storing temporary files. If no path is specified, 33 | system-wide temporary directory is used. 34 | Notice that this parameter will be ignored if `memory_cache` is True. 35 | 36 | Example: 37 | >>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda") 38 | >>> lr_finder.range_test(dataloader, end_lr=100, num_iter=100) 39 | 40 | Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 41 | fastai/lr_find: https://github.com/fastai/fastai 42 | 43 | """ 44 | 45 | def __init__(self, model, optimizer, criterion, memory_cache=True, cache_dir=None): 46 | self.model = model 47 | self.optimizer = optimizer 48 | self.criterion = criterion 49 | self.history = {"lr": [], "loss": []} 50 | self.best_loss = None 51 | self.memory_cache = memory_cache 52 | self.cache_dir = cache_dir 53 | 54 | # Save the original state of the model and optimizer so they can be restored if 55 | # needed 56 | self.state_cacher = StateCacher(memory_cache, cache_dir=cache_dir) 57 | self.state_cacher.store('model', self.model.state_dict()) 58 | self.state_cacher.store('optimizer', self.optimizer.state_dict()) 59 | 60 | 61 | def reset(self): 62 | """Restores the model and optimizer to their initial states.""" 63 | self.model.load_state_dict(self.state_cacher.retrieve('model')) 64 | self.optimizer.load_state_dict(self.state_cacher.retrieve('optimizer')) 65 | 66 | def range_test( 67 | self, 68 | train_loader, 69 | val_loader=None, 70 | end_lr=10, 71 | num_iter=100, 72 | step_mode="exp", 73 | smooth_f=0.05, 74 | diverge_th=5, 75 | ): 76 | """Performs the learning rate range test. 77 | 78 | Arguments: 79 | train_loader (torch.utils.data.DataLoader): the training set data laoder. 
80 | val_loader (torch.utils.data.DataLoader, optional): if `None` the range test 81 | will only use the training loss. When given a data loader, the model is 82 | evaluated after each iteration on that dataset and the evaluation loss 83 | is used. Note that in this mode the test takes significantly longer but 84 | generally produces more precise results. Default: None. 85 | end_lr (float, optional): the maximum learning rate to test. Default: 10. 86 | num_iter (int, optional): the number of iterations over which the test 87 | occurs. Default: 100. 88 | step_mode (str, optional): one of the available learning rate policies, 89 | linear or exponential ("linear", "exp"). Default: "exp". 90 | smooth_f (float, optional): the loss smoothing factor within the [0, 1[ 91 | interval. Disabled if set to 0, otherwise the loss is smoothed using 92 | exponential smoothing. Default: 0.05. 93 | diverge_th (int, optional): the test is stopped when the loss surpasses the 94 | threshold: diverge_th * best_loss. Default: 5. 
95 | 96 | """ 97 | # Reset test results 98 | self.history = {"lr": [], "loss": []} 99 | self.best_loss = None 100 | 101 | # Initialize the proper learning rate policy 102 | if step_mode.lower() == "exp": 103 | lr_schedule = ExponentialLR(self.optimizer, end_lr, num_iter) 104 | elif step_mode.lower() == "linear": 105 | lr_schedule = LinearLR(self.optimizer, end_lr, num_iter) 106 | else: 107 | raise ValueError("expected one of (exp, linear), got {}".format(step_mode)) 108 | 109 | if smooth_f < 0 or smooth_f >= 1: 110 | raise ValueError("smooth_f is outside the range [0, 1[") 111 | 112 | # Create an iterator to get data batch by batch 113 | prefetcher = DataPrefetcher(train_loader) 114 | for iteration in tqdm(range(num_iter)): 115 | # Get a new set of inputs and labels 116 | try: 117 | inputs, labels = prefetcher.next() 118 | except StopIteration: 119 | prefetcher = DataPrefetcher(train_loader) 120 | inputs, labels = prefetcher.next() 121 | 122 | # Train on batch and retrieve loss 123 | loss = self._train_batch(inputs, labels) 124 | if val_loader: 125 | loss = self._validate(val_loader) 126 | 127 | # Update the learning rate 128 | lr_schedule.step() 129 | self.history["lr"].append(lr_schedule.get_lr()[0]) 130 | 131 | # Track the best loss and smooth it if smooth_f is specified 132 | if iteration == 0: 133 | self.best_loss = loss 134 | else: 135 | if smooth_f > 0: 136 | loss = smooth_f * loss + (1 - smooth_f) * self.history["loss"][-1] 137 | if loss < self.best_loss: 138 | self.best_loss = loss 139 | 140 | # Check if the loss has diverged; if it has, stop the test 141 | self.history["loss"].append(loss) 142 | if loss > diverge_th * self.best_loss: 143 | print("Stopping early, the loss has diverged") 144 | break 145 | 146 | print("Learning rate search finished. 
See the graph with {finder_name}.plot()") 147 | 148 | def _train_batch(self, inputs, labels): 149 | # Set model to training mode 150 | self.model.train() 151 | 152 | # Forward pass 153 | self.optimizer.zero_grad() 154 | #outputs = self.model.forward(inputs) 155 | #loss = self.criterion(outputs, labels) 156 | out2, out3, out4, out5 = self.model.forward(inputs) 157 | loss2 = F.binary_cross_entropy_with_logits(out2, labels) 158 | loss3 = F.binary_cross_entropy_with_logits(out3, labels) 159 | loss4 = F.binary_cross_entropy_with_logits(out4, labels) 160 | loss5 = F.binary_cross_entropy_with_logits(out5, labels) 161 | loss = loss2*1 + loss3*0.8 + loss4*0.6 + loss5*0.4 162 | 163 | # Backward pass 164 | loss.backward() 165 | self.optimizer.step() 166 | 167 | return loss.item() 168 | 169 | def _validate(self, dataloader): 170 | # Set model to evaluation mode and disable gradient computation 171 | running_loss = 0 172 | self.model.eval() 173 | with torch.no_grad(): 174 | for inputs, labels in dataloader: 175 | # Move data to the correct device 176 | inputs = inputs.cuda() 177 | labels = labels.cuda() 178 | 179 | # Forward pass and loss computation 180 | outputs = self.model(inputs) 181 | loss = self.criterion(outputs, labels) 182 | running_loss += loss.item() * inputs.size(0) 183 | 184 | return running_loss / len(dataloader.dataset) 185 | 186 | def plot(self, skip_start=10, skip_end=5, log_lr=True): 187 | """Plots the learning rate range test. 188 | 189 | Arguments: 190 | skip_start (int, optional): number of batches to trim from the start. 191 | Default: 10. 192 | skip_end (int, optional): number of batches to trim from the start. 193 | Default: 5. 194 | log_lr (bool, optional): True to plot the learning rate in a logarithmic 195 | scale; otherwise, plotted in a linear scale. Default: True. 
196 | 197 | """ 198 | 199 | if skip_start < 0: 200 | raise ValueError("skip_start cannot be negative") 201 | if skip_end < 0: 202 | raise ValueError("skip_end cannot be negative") 203 | 204 | # Get the data to plot from the history dictionary. Also, handle skip_end=0 205 | # properly so the behaviour is the expected 206 | lrs = self.history["lr"] 207 | losses = self.history["loss"] 208 | if skip_end == 0: 209 | lrs = lrs[skip_start:] 210 | losses = losses[skip_start:] 211 | else: 212 | lrs = lrs[skip_start:-skip_end] 213 | losses = losses[skip_start:-skip_end] 214 | 215 | # Plot loss as a function of the learning rate 216 | plt.plot(lrs, losses) 217 | if log_lr: 218 | plt.xscale("log") 219 | plt.xlabel("Learning rate") 220 | plt.ylabel("Loss") 221 | plt.show() 222 | 223 | 224 | class LinearLR(_LRScheduler): 225 | """Linearly increases the learning rate between two boundaries over a number of 226 | iterations. 227 | 228 | Arguments: 229 | optimizer (torch.optim.Optimizer): wrapped optimizer. 230 | end_lr (float, optional): the initial learning rate which is the lower 231 | boundary of the test. Default: 10. 232 | num_iter (int, optional): the number of iterations over which the test 233 | occurs. Default: 100. 234 | last_epoch (int): the index of last epoch. Default: -1. 235 | 236 | """ 237 | 238 | def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1): 239 | self.end_lr = end_lr 240 | self.num_iter = num_iter 241 | super(LinearLR, self).__init__(optimizer, last_epoch) 242 | 243 | def get_lr(self): 244 | curr_iter = self.last_epoch + 1 245 | r = curr_iter / self.num_iter 246 | return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] 247 | 248 | 249 | class ExponentialLR(_LRScheduler): 250 | """Exponentially increases the learning rate between two boundaries over a number of 251 | iterations. 252 | 253 | Arguments: 254 | optimizer (torch.optim.Optimizer): wrapped optimizer. 
255 | end_lr (float, optional): the initial learning rate which is the lower 256 | boundary of the test. Default: 10. 257 | num_iter (int, optional): the number of iterations over which the test 258 | occurs. Default: 100. 259 | last_epoch (int): the index of last epoch. Default: -1. 260 | 261 | """ 262 | 263 | def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1): 264 | self.end_lr = end_lr 265 | self.num_iter = num_iter 266 | super(ExponentialLR, self).__init__(optimizer, last_epoch) 267 | 268 | def get_lr(self): 269 | curr_iter = self.last_epoch + 1 270 | r = curr_iter / self.num_iter 271 | return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] 272 | 273 | 274 | class StateCacher(object): 275 | def __init__(self, in_memory, cache_dir=None): 276 | self.in_memory = in_memory 277 | self.cache_dir = cache_dir 278 | 279 | if self.cache_dir is None: 280 | import tempfile 281 | self.cache_dir = tempfile.gettempdir() 282 | else: 283 | if not os.path.isdir(self.cache_dir): 284 | raise ValueError('Given `cache_dir` is not a valid directory.') 285 | 286 | self.cached = {} 287 | 288 | def store(self, key, state_dict): 289 | if self.in_memory: 290 | self.cached.update({key: copy.deepcopy(state_dict)}) 291 | else: 292 | fn = os.path.join(self.cache_dir, 'state_{}_{}.pt'.format(key, id(self))) 293 | self.cached.update({key: fn}) 294 | torch.save(state_dict, fn) 295 | 296 | def retrieve(self, key): 297 | if key not in self.cached: 298 | raise KeyError('Target {} was not cached.'.format(key)) 299 | 300 | if self.in_memory: 301 | return self.cached.get(key) 302 | else: 303 | fn = self.cached.get(key) 304 | if not os.path.exists(fn): 305 | raise RuntimeError('Failed to load state in {}. 
File does not exist anymore.'.format(fn)) 306 | state_dict = torch.load(fn, map_location=lambda storage, location: storage) 307 | return state_dict 308 | 309 | def __del__(self): 310 | """Check whether there are unused cached files existing in `cache_dir` before 311 | this instance being destroyed.""" 312 | if self.in_memory: 313 | return 314 | 315 | for k in self.cached: 316 | if os.path.exists(self.cached[k]): 317 | os.remove(self.cached[k]) 318 | -------------------------------------------------------------------------------- /lib/transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | #coding=utf-8 3 | 4 | import cv2 5 | import torch 6 | import numpy as np 7 | 8 | class Compose(object): 9 | def __init__(self, *ops): 10 | self.ops = ops 11 | 12 | def __call__(self, image, mask): 13 | for op in self.ops: 14 | image, mask = op(image, mask) 15 | return image, mask 16 | 17 | 18 | 19 | class Normalize(object): 20 | def __init__(self, mean, std): 21 | self.mean = mean 22 | self.std = std 23 | 24 | def __call__(self, image, mask): 25 | image = (image - self.mean)/self.std 26 | mask /= 255 27 | return image, mask 28 | 29 | 30 | class Resize(object): 31 | def __init__(self, H, W): 32 | self.H = H 33 | self.W = W 34 | 35 | def __call__(self, image, mask): 36 | image = cv2.resize(image, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 37 | mask = cv2.resize( mask, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 38 | return image, mask 39 | 40 | class RandomCrop(object): 41 | def __init__(self, H, W): 42 | self.H = H 43 | self.W = W 44 | 45 | def __call__(self, image, mask): 46 | H,W,_ = image.shape 47 | xmin = np.random.randint(W-self.W+1) 48 | ymin = np.random.randint(H-self.H+1) 49 | image = image[ymin:ymin+self.H, xmin:xmin+self.W, :] 50 | mask = mask[ymin:ymin+self.H, xmin:xmin+self.W, :] 51 | return image, mask 52 | 53 | class RandomHorizontalFlip(object): 54 | def __call__(self, image, 
mask): 55 | if np.random.randint(2)==1: 56 | image = image[:,::-1,:].copy() 57 | mask = mask[:,::-1,:].copy() 58 | return image, mask 59 | 60 | class ToTensor(object): 61 | def __call__(self, image, mask): 62 | image = torch.from_numpy(image) 63 | image = image.permute(2, 0, 1) 64 | mask = torch.from_numpy(mask) 65 | mask = mask.permute(2, 0, 1) 66 | return image, mask.mean(dim=0, keepdim=True) 67 | 68 | -------------------------------------------------------------------------------- /lscloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class LocalSaliencyCoherence(torch.nn.Module): 6 | """ 7 | This loss function based on the following paper. 8 | Please consider using the following bibtex for citation: 9 | @article{obukhov2019gated, 10 | author={Anton Obukhov and Stamatios Georgoulis and Dengxin Dai and Luc {Van Gool}}, 11 | title={Gated {CRF} Loss for Weakly Supervised Semantic Image Segmentation}, 12 | journal={CoRR}, 13 | volume={abs/1906.04651}, 14 | year={2019}, 15 | url={http://arxiv.org/abs/1906.04651}, 16 | } 17 | """ 18 | def forward( 19 | self, y_hat_softmax, kernels_desc, kernels_radius, sample, height_input, width_input, 20 | mask_src=None, mask_dst=None, compatibility=None, custom_modality_downsamplers=None, out_kernels_vis=False 21 | ): 22 | """ 23 | Performs the forward pass of the loss. 24 | :param y_hat_softmax: A tensor of predicted per-pixel class probabilities of size NxCxHxW 25 | :param kernels_desc: A list of dictionaries, each describing one Gaussian kernel composition from modalities. 26 | The final kernel is a weighted sum of individual kernels. 
Following example is a composition of 27 | RGBXY and XY kernels: 28 | kernels_desc: [{ 29 | 'weight': 0.9, # Weight of RGBXY kernel 30 | 'xy': 6, # Sigma for XY 31 | 'rgb': 0.1, # Sigma for RGB 32 | },{ 33 | 'weight': 0.1, # Weight of XY kernel 34 | 'xy': 6, # Sigma for XY 35 | }] 36 | :param kernels_radius: Defines size of bounding box region around each pixel in which the kernel is constructed. 37 | :param sample: A dictionary with modalities (except 'xy') used in kernels_desc parameter. Each of the provided 38 | modalities is allowed to be larger than the shape of y_hat_softmax, in such case downsampling will be 39 | invoked. Default downsampling method is area resize; this can be overriden by setting. 40 | custom_modality_downsamplers parameter. 41 | :param width_input, height_input: Dimensions of the full scale resolution of modalities 42 | :param mask_src: (optional) Source mask. 43 | :param mask_dst: (optional) Destination mask. 44 | :param compatibility: (optional) Classes compatibility matrix, defaults to Potts model. 45 | :param custom_modality_downsamplers: A dictionary of modality downsampling functions. 46 | :param out_kernels_vis: Whether to return a tensor with kernels visualized with some step. 47 | :return: Loss function value. 
48 | """ 49 | assert y_hat_softmax.dim() == 4, 'Prediction must be a NCHW batch' 50 | N, C, height_pred, width_pred = y_hat_softmax.shape 51 | 52 | device = y_hat_softmax.device 53 | 54 | assert width_input % width_pred == 0 and height_input % height_pred == 0 and \ 55 | width_input * height_pred == height_input * width_pred, \ 56 | f'[{width_input}x{height_input}] !~= [{width_pred}x{height_pred}]' 57 | 58 | kernels = self._create_kernels( 59 | kernels_desc, kernels_radius, sample, N, height_pred, width_pred, device, custom_modality_downsamplers 60 | ) 61 | 62 | y_hat_unfolded = self._unfold(y_hat_softmax, kernels_radius) 63 | y_hat_unfolded = torch.abs(y_hat_unfolded[:, :, kernels_radius, kernels_radius, :, :].view(N, C, 1, 1, height_pred, width_pred) - y_hat_unfolded) 64 | 65 | loss = torch.mean((kernels * y_hat_unfolded).view(N, C, (kernels_radius * 2 + 1) ** 2, height_pred, width_pred).sum(dim=2, keepdim=True)) 66 | 67 | 68 | out = { 69 | 'loss': loss.mean(), 70 | } 71 | 72 | if out_kernels_vis: 73 | out['kernels_vis'] = self._visualize_kernels( 74 | kernels, kernels_radius, height_input, width_input, height_pred, width_pred 75 | ) 76 | 77 | return out 78 | 79 | @staticmethod 80 | def _downsample(img, modality, height_dst, width_dst, custom_modality_downsamplers): 81 | if custom_modality_downsamplers is not None and modality in custom_modality_downsamplers: 82 | f_down = custom_modality_downsamplers[modality] 83 | else: 84 | f_down = F.adaptive_avg_pool2d 85 | return f_down(img, (height_dst, width_dst)) 86 | 87 | @staticmethod 88 | def _create_kernels( 89 | kernels_desc, kernels_radius, sample, N, height_pred, width_pred, device, custom_modality_downsamplers 90 | ): 91 | kernels = None 92 | for i, desc in enumerate(kernels_desc): 93 | weight = desc['weight'] 94 | features = [] 95 | for modality, sigma in desc.items(): 96 | if modality == 'weight': 97 | continue 98 | if modality == 'xy': 99 | feature = LocalSaliencyCoherence._get_mesh(N, height_pred, width_pred, 
device) 100 | else: 101 | assert modality in sample, \ 102 | f'Modality {modality} is listed in {i}-th kernel descriptor, but not present in the sample' 103 | feature = sample[modality] 104 | # feature = LocalSaliencyCoherence._downsample( 105 | # feature, modality, height_pred, width_pred, custom_modality_downsamplers 106 | # ) 107 | feature /= sigma 108 | features.append(feature) 109 | features = torch.cat(features, dim=1) 110 | kernel = weight * LocalSaliencyCoherence._create_kernels_from_features(features, kernels_radius) 111 | kernels = kernel if kernels is None else kernel + kernels 112 | return kernels 113 | 114 | @staticmethod 115 | def _create_kernels_from_features(features, radius): 116 | assert features.dim() == 4, 'Features must be a NCHW batch' 117 | N, C, H, W = features.shape 118 | kernels = LocalSaliencyCoherence._unfold(features, radius) 119 | kernels = kernels - kernels[:, :, radius, radius, :, :].view(N, C, 1, 1, H, W) 120 | kernels = (-0.5 * kernels ** 2).sum(dim=1, keepdim=True).exp() 121 | # kernels[:, :, radius, radius, :, :] = 0 122 | return kernels 123 | 124 | @staticmethod 125 | def _get_mesh(N, H, W, device): 126 | return torch.cat(( 127 | torch.arange(0, W, 1, dtype=torch.float32, device=device).view(1, 1, 1, W).repeat(N, 1, H, 1), 128 | torch.arange(0, H, 1, dtype=torch.float32, device=device).view(1, 1, H, 1).repeat(N, 1, 1, W) 129 | ), 1) 130 | 131 | @staticmethod 132 | def _unfold(img, radius): 133 | assert img.dim() == 4, 'Unfolding requires NCHW batch' 134 | N, C, H, W = img.shape 135 | diameter = 2 * radius + 1 136 | return F.unfold(img, diameter, 1, radius).view(N, C, diameter, diameter, H, W) 137 | 138 | @staticmethod 139 | def _visualize_kernels(kernels, radius, height_input, width_input, height_pred, width_pred): 140 | diameter = 2 * radius + 1 141 | vis = kernels[:, :, :, :, radius::diameter, radius::diameter] 142 | vis_nh, vis_nw = vis.shape[-2:] 143 | vis = vis.permute(0, 1, 4, 2, 5, 3).contiguous().view(kernels.shape[0], 
1, diameter * vis_nh, diameter * vis_nw) 144 | if vis.shape[2] > height_pred: 145 | vis = vis[:, :, :height_pred, :] 146 | if vis.shape[3] > width_pred: 147 | vis = vis[:, :, :, :width_pred] 148 | if vis.shape[2:] != (height_pred, width_pred): 149 | vis = F.pad(vis, [0, width_pred-vis.shape[3], 0, height_pred-vis.shape[2]]) 150 | vis = F.interpolate(vis, (height_input, width_input), mode='nearest') 151 | return vis 152 | -------------------------------------------------------------------------------- /net_agg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | #coding=utf-8 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def weight_init(module): 12 | for n, m in module.named_children(): 13 | print('initialize: '+n) 14 | if isinstance(m, nn.Conv2d): 15 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 16 | if m.bias is not None: 17 | nn.init.zeros_(m.bias) 18 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 19 | nn.init.ones_(m.weight) 20 | if m.bias is not None: 21 | nn.init.zeros_(m.bias) 22 | elif isinstance(m, nn.Linear): 23 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 24 | if m.bias is not None: 25 | nn.init.zeros_(m.bias) 26 | else: 27 | m.initialize() 28 | 29 | class Bottleneck(nn.Module): 30 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 31 | super(Bottleneck, self).__init__() 32 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=(3*dilation-1)//2, bias=False, dilation=dilation) 35 | self.bn2 = nn.BatchNorm2d(planes) 36 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 37 | self.bn3 = nn.BatchNorm2d(planes*4) 38 | self.downsample = downsample 39 | 40 | 
def forward(self, x): 41 | residual = x 42 | out = F.relu(self.bn1(self.conv1(x)), inplace=True) 43 | out = F.relu(self.bn2(self.conv2(out)), inplace=True) 44 | out = self.bn3(self.conv3(out)) 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | return F.relu(out+residual, inplace=True) 48 | 49 | 50 | class ResNet(nn.Module): 51 | def __init__(self): 52 | super(ResNet, self).__init__() 53 | self.inplanes = 64 54 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 55 | self.bn1 = nn.BatchNorm2d(64) 56 | self.layer1 = self.make_layer( 64, 3, stride=1, dilation=1) 57 | self.layer2 = self.make_layer(128, 4, stride=2, dilation=1) 58 | self.layer3 = self.make_layer(256, 6, stride=2, dilation=1) 59 | self.layer4 = self.make_layer(512, 3, stride=2, dilation=1) 60 | self.initialize() 61 | 62 | def make_layer(self, planes, blocks, stride, dilation): 63 | downsample = None 64 | if stride != 1 or self.inplanes != planes*4: 65 | downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes*4, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes*4)) 66 | 67 | layers = [Bottleneck(self.inplanes, planes, stride, downsample, dilation=dilation)] 68 | self.inplanes = planes*4 69 | for _ in range(1, blocks): 70 | layers.append(Bottleneck(self.inplanes, planes, dilation=dilation)) 71 | return nn.Sequential(*layers) 72 | 73 | def forward(self, x): 74 | out1 = F.relu(self.bn1(self.conv1(x)), inplace=True) 75 | out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1) 76 | out2 = self.layer1(out1) 77 | out3 = self.layer2(out2) 78 | out4 = self.layer3(out3) 79 | out5 = self.layer4(out4) 80 | return out1, out2, out3, out4, out5 81 | 82 | def initialize(self): 83 | self.load_state_dict(torch.load('resnet50-19c8e357.pth'), strict=False) 84 | 85 | class CA(nn.Module): 86 | def __init__(self, in_channel_left, in_channel_down): 87 | super(CA, self).__init__() 88 | self.conv0 = nn.Conv2d(in_channel_left, 256, kernel_size=1, 
stride=1, padding=0) 89 | self.bn0 = nn.BatchNorm2d(256) 90 | self.conv1 = nn.Conv2d(in_channel_down, 256, kernel_size=1, stride=1, padding=0) 91 | self.conv2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) 92 | 93 | def forward(self, left, down): 94 | left = F.relu(self.bn0(self.conv0(left)), inplace=True) #256 95 | down = down.mean(dim=(2,3), keepdim=True) 96 | down = F.relu(self.conv1(down), inplace=True) 97 | down = torch.sigmoid(self.conv2(down)) 98 | return left * down 99 | 100 | def initialize(self): 101 | weight_init(self) 102 | 103 | """ Self Refinement Module """ 104 | class SRM(nn.Module): 105 | def __init__(self, in_channel): 106 | super(SRM, self).__init__() 107 | self.conv1 = nn.Conv2d(in_channel, 256, kernel_size=3, stride=1, padding=1) 108 | self.bn1 = nn.BatchNorm2d(256) 109 | self.conv2 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 110 | 111 | def forward(self, x): 112 | out1 = F.relu(self.bn1(self.conv1(x)), inplace=True) #256 113 | out2 = self.conv2(out1) 114 | w, b = out2[:, :256, :, :], out2[:, 256:, :, :] 115 | return F.relu(w * out1 + b, inplace=True) 116 | 117 | 118 | def initialize(self): 119 | weight_init(self) 120 | 121 | 122 | """ Feature Interweaved Aggregation Module """ 123 | class FAM(nn.Module): 124 | def __init__(self, in_channel_left, in_channel_down, in_channel_right): 125 | super(FAM, self).__init__() 126 | #self.conv0 = nn.Conv2d(in_channel_left, 256, kernel_size=1, stride=1, padding=0) 127 | self.conv0 = nn.Conv2d(in_channel_left, 256, kernel_size=3, stride=1, padding=1) 128 | self.bn0 = nn.BatchNorm2d(256) 129 | self.conv1 = nn.Conv2d(in_channel_down, 256, kernel_size=3, stride=1, padding=1) 130 | self.bn1 = nn.BatchNorm2d(256) 131 | self.conv2 = nn.Conv2d(in_channel_right, 256, kernel_size=3, stride=1, padding=1) 132 | self.bn2 = nn.BatchNorm2d(256) 133 | 134 | self.conv_d1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 135 | self.conv_d2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, 
padding=1) 136 | self.conv_l = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 137 | self.conv3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 138 | self.bn3 = nn.BatchNorm2d(256) 139 | self.conv_att1 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1) 140 | self.conv_att2 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1) 141 | self.conv_att3 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1) 142 | 143 | 144 | def forward(self, left, down, right): 145 | left = F.relu(self.bn0(self.conv0(left)), inplace=True) #256 channels 146 | down = F.relu(self.bn1(self.conv1(down)), inplace=True) #256 channels 147 | right = F.relu(self.bn2(self.conv2(right)), inplace=True) #256 148 | 149 | down_1 = self.conv_d1(down) 150 | 151 | w1 = self.conv_l(left) 152 | if down.size()[2:] != left.size()[2:]: 153 | down_ = F.interpolate(down, size=left.size()[2:], mode='bilinear', align_corners=False) 154 | z1 = F.relu(w1 * down_, inplace=True) 155 | else: 156 | z1 = F.relu(w1 * down, inplace=True) 157 | z1_att = F.adaptive_avg_pool2d(self.conv_att1(z1), (1,1)) 158 | z1 = z1_att * z1 159 | 160 | if down_1.size()[2:] != left.size()[2:]: 161 | down_1 = F.interpolate(down_1, size=left.size()[2:], mode='bilinear', align_corners=False) 162 | 163 | z2 = F.relu(down_1 * left, inplace=True) 164 | z2_att = F.adaptive_avg_pool2d(self.conv_att2(z2), (1,1)) 165 | z2 = z2_att * z2 166 | 167 | # z3 168 | down_2 = self.conv_d2(right) 169 | if down_2.size()[2:] != left.size()[2:]: 170 | down_2 = F.interpolate(down_2, size=left.size()[2:], mode='bilinear', align_corners=False) 171 | z3 = F.relu(down_2 * left, inplace=True) 172 | z3_att = F.adaptive_avg_pool2d(self.conv_att3(z3), (1,1)) 173 | z3 = z3_att * z3 174 | out = (z1 + z2 + z3) / (z1_att + z2_att + z3_att) 175 | # out = torch.cat((z1, z2, z3), dim=1) 176 | return F.relu(self.bn3(self.conv3(out)), inplace=True) 177 | 178 | 179 | def initialize(self): 180 | weight_init(self) 181 | 182 | 183 | class SA(nn.Module): 184 | 
def __init__(self, in_channel_left, in_channel_down): 185 | super(SA, self).__init__() 186 | self.conv0 = nn.Conv2d(in_channel_left, 256, kernel_size=3, stride=1, padding=1) 187 | self.bn0 = nn.BatchNorm2d(256) 188 | self.conv2 = nn.Conv2d(in_channel_down, 512, kernel_size=3, stride=1, padding=1) 189 | 190 | def forward(self, left, down): 191 | left = F.relu(self.bn0(self.conv0(left)), inplace=True) #256 channels 192 | down_1 = self.conv2(down) #wb 193 | if down_1.size()[2:] != left.size()[2:]: 194 | down_1 = F.interpolate(down_1, size=left.size()[2:], mode='bilinear', align_corners=False) 195 | w,b = down_1[:,:256,:,:], down_1[:,256:,:,:] 196 | 197 | return F.relu(w*left+b, inplace=True) 198 | 199 | def initialize(self): 200 | weight_init(self) 201 | 202 | 203 | class SCWSSOD(nn.Module): 204 | def __init__(self, cfg): 205 | super(SCWSSOD, self).__init__() 206 | self.cfg = cfg 207 | self.bkbone = ResNet() 208 | 209 | self.ca45 = CA(2048, 2048) 210 | self.ca35 = CA(2048, 2048) 211 | self.ca25 = CA(2048, 2048) 212 | self.ca55 = CA(256, 2048) 213 | self.sa55 = SA(2048, 2048) 214 | 215 | self.fam45 = FAM(1024, 256, 256) 216 | self.fam34 = FAM( 512, 256, 256) 217 | self.fam23 = FAM( 256, 256, 256) 218 | 219 | self.srm5 = SRM(256) 220 | self.srm4 = SRM(256) 221 | self.srm3 = SRM(256) 222 | self.srm2 = SRM(256) 223 | 224 | self.linear5 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1) 225 | self.linear4 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1) 226 | self.linear3 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1) 227 | self.linear2 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1) 228 | self.initialize() 229 | 230 | def forward(self, x, mode=None): 231 | out1, out2, out3, out4, out5_ = self.bkbone(x) 232 | # GCF 233 | out4_a = self.ca45(out5_, out5_) 234 | out3_a = self.ca35(out5_, out5_) 235 | out2_a = self.ca25(out5_, out5_) 236 | # HA 237 | out5_a = self.sa55(out5_, out5_) 238 | out5 = self.ca55(out5_a, out5_) 239 | # out 240 | out5 = 
self.srm5(out5) 241 | out4 = self.srm4(self.fam45(out4, out5, out4_a)) 242 | out3 = self.srm3(self.fam34(out3, out4, out3_a)) 243 | out2 = self.srm2(self.fam23(out2, out3, out2_a)) 244 | # we use bilinear interpolation instead of transpose convolution 245 | if mode == 'Test': 246 | # ------------------------------------------------------ TEST ---------------------------------------------------- 247 | out5 = F.interpolate(self.linear5(out5), size=x.size()[2:], mode='bilinear', align_corners=False) 248 | out4 = F.interpolate(self.linear4(out4), size=x.size()[2:], mode='bilinear', align_corners=False) 249 | out3 = F.interpolate(self.linear3(out3), size=x.size()[2:], mode='bilinear', align_corners=False) 250 | out2 = F.interpolate(self.linear2(out2), size=x.size()[2:], mode='bilinear', align_corners=False) 251 | return out2, out3, out4, out5 252 | else: 253 | # ------------------------------------------------------ TRAIN ---------------------------------------------------- 254 | out5 = torch.sigmoid(F.interpolate(self.linear5(out5), size=x.size()[2:], mode='bilinear', align_corners=False)) 255 | out4 = torch.sigmoid(F.interpolate(self.linear4(out4), size=x.size()[2:], mode='bilinear', align_corners=False)) 256 | out3 = torch.sigmoid(F.interpolate(self.linear3(out3), size=x.size()[2:], mode='bilinear', align_corners=False)) 257 | out2 = torch.sigmoid(F.interpolate(self.linear2(out2), size=x.size()[2:], mode='bilinear', align_corners=False)) 258 | out5 = torch.cat((1 - out5, out5), 1) 259 | out4 = torch.cat((1 - out4, out4), 1) 260 | out3 = torch.cat((1 - out3, out3), 1) 261 | out2 = torch.cat((1 - out2, out2), 1) 262 | return out2, out3, out4, out5 263 | def initialize(self): 264 | weight_init(self) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | astunparse==1.6.3 3 | blessings @ 
file:///home/conda/feedstock_root/build_artifacts/blessings_1591051596117/work 4 | cachetools==4.1.0 5 | certifi==2020.12.5 6 | cffi==1.14.0 7 | chardet==3.0.4 8 | cycler==0.10.0 9 | Cython @ file:///home/conda/feedstock_root/build_artifacts/cython_1591799499719/work 10 | decorator==4.4.2 11 | gast==0.3.3 12 | google-auth==1.15.0 13 | google-auth-oauthlib==0.4.1 14 | google-pasta==0.2.0 15 | gpustat==0.6.0 16 | grpcio==1.29.0 17 | h5py==2.10.0 18 | idna==2.9 19 | imageio==2.8.0 20 | importlib-metadata==1.6.0 21 | joblib==0.14.1 22 | Keras-Preprocessing==1.1.2 23 | kiwisolver==1.2.0 24 | kmeans-pytorch==0.3 25 | Markdown==3.2.2 26 | matplotlib==3.2.1 27 | networkx==2.4 28 | ninja==1.9.0.post1 29 | numpy==1.18.3 30 | nvidia-ml-py3==7.352.0 31 | oauthlib==3.1.0 32 | olefile==0.46 33 | opencv-python==4.2.0.34 34 | opt-einsum==3.2.1 35 | Pillow==7.1.2 36 | protobuf==3.11.4 37 | psutil==5.7.0 38 | pyasn1==0.4.8 39 | pyasn1-modules==0.2.8 40 | pycparser==2.20 41 | pyparsing==2.4.7 42 | python-dateutil==2.8.1 43 | PyWavelets==1.1.1 44 | PyYAML==5.3.1 45 | requests==2.23.0 46 | requests-oauthlib==1.3.0 47 | rsa==4.0 48 | scikit-image==0.16.2 49 | scikit-learn==0.22.2.post1 50 | scipy==1.4.1 51 | six==1.14.0 52 | sklearn==0.0 53 | tensorboard==2.2.1 54 | tensorboard-plugin-wit==1.6.0.post3 55 | tensorboardX==2.0 56 | tensorflow==2.2.0 57 | tensorflow-estimator==2.2.0 58 | termcolor==1.1.0 59 | torch==1.0.1 60 | torchvision==0.2.1 61 | tqdm==4.45.0 62 | urllib3==1.25.9 63 | Werkzeug==1.0.1 64 | wrapt==1.12.1 65 | yacs==0.1.7 66 | zipp==3.1.0 67 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | #coding=utf-8 3 | 4 | import os 5 | import sys 6 | #sys.path.insert(0, '../') 7 | sys.dont_write_bytecode = True 8 | 9 | import cv2 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | plt.ion() 13 | from skimage import 
img_as_ubyte 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from torch.utils.data import DataLoader 18 | #from tensorboardX import SummaryWriter 19 | from lib import dataset 20 | from net_agg import SCWSSOD 21 | import time 22 | import logging as logger 23 | 24 | TAG = "scwssod" 25 | SAVE_PATH = TAG 26 | GPU_ID=0 27 | os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID) 28 | 29 | logger.basicConfig(level=logger.INFO, format='%(levelname)s %(asctime)s %(filename)s: %(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', \ 30 | filename="test_%s.log"%(TAG), filemode="w") 31 | 32 | 33 | DATASETS = ['./data/ECSSD', './data/DUT-OMRON', './data/PASCAL-S', './data/HKU-IS', './data/THUR15K', './data/DUTS', ] 34 | # DATASETS = ['./data/DUTS',] 35 | 36 | 37 | class Test(object): 38 | def __init__(self, Dataset, datapath, Network): 39 | ## dataset 40 | self.datapath = datapath.split("/")[-1] 41 | print("Testing on %s"%self.datapath) 42 | self.cfg = Dataset.Config(datapath=datapath, mode='test') 43 | self.data = Dataset.Data(self.cfg) 44 | self.loader = DataLoader(self.data, batch_size=1, shuffle=True, num_workers=8) 45 | ## network 46 | self.net = Network(self.cfg) 47 | # self.net = nn.DataParallel(self.net) 48 | path = './ours/aggsclsc/model-36.pt' 49 | state_dict = torch.load(path) 50 | print('complete loading: {}'.format(path)) 51 | self.net.load_state_dict(state_dict) 52 | print('model has {} parameters in total'.format(sum(x.numel() for x in self.net.parameters()))) 53 | self.net.train(False) 54 | self.net.cuda() 55 | self.net.eval() 56 | 57 | def accuracy(self): 58 | with torch.no_grad(): 59 | mae, fscore, cnt, number = 0, 0, 0, 256 60 | mean_pr, mean_re, threshod = 0, 0, np.linspace(0, 1, number, endpoint=False) 61 | cost_time = 0 62 | for image, mask, (H, W), maskpath in self.loader: 63 | image, mask = image.cuda().float(), mask.cuda().float() 64 | start_time = time.time() 65 | out2, out3, out4, out5 = self.net(image, 'Test') 66 | 
pred = torch.sigmoid(out2) 67 | torch.cuda.synchronize() 68 | end_time = time.time() 69 | cost_time += end_time - start_time 70 | 71 | ## MAE 72 | cnt += 1 73 | mae += (pred-mask).abs().mean() 74 | ## F-Score 75 | precision = torch.zeros(number) 76 | recall = torch.zeros(number) 77 | for i in range(number): 78 | temp = (pred >= threshod[i]).float() 79 | precision[i] = (temp*mask).sum()/(temp.sum()+1e-12) 80 | recall[i] = (temp*mask).sum()/(mask.sum()+1e-12) 81 | mean_pr += precision 82 | mean_re += recall 83 | fscore = mean_pr*mean_re*(1+0.3)/(0.3*mean_pr+mean_re+1e-12) 84 | if cnt % 20 == 0: 85 | fps = image.shape[0] / (end_time - start_time) 86 | print('MAE=%.6f, F-score=%.6f, fps=%.4f'%(mae/cnt, fscore.max()/cnt, fps)) 87 | fps = len(self.loader.dataset) / cost_time 88 | msg = '%s MAE=%.6f, F-score=%.6f, len(imgs)=%s, fps=%.4f'%(self.datapath, mae/cnt, fscore.max()/cnt, len(self.loader.dataset), fps) 89 | print(msg) 90 | logger.info(msg) 91 | 92 | def save(self): 93 | with torch.no_grad(): 94 | for image, mask, (H, W), name in self.loader: 95 | out2, out3, out4, out5 = self.net(image.cuda().float(), 'Test') 96 | out2 = F.interpolate(out2, size=(H, W), mode='bilinear', align_corners=False) 97 | pred = (torch.sigmoid(out2[0, 0])).cpu().numpy() 98 | pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8) 99 | head = './pred_maps/{}/'.format(TAG) + self.cfg.datapath.split('/')[-1] 100 | if not os.path.exists(head): 101 | os.makedirs(head) 102 | cv2.imwrite(head + '/' + name[0], img_as_ubyte(pred)) 103 | 104 | 105 | if __name__=='__main__': 106 | for e in DATASETS: 107 | t =Test(dataset, e, SCWSSOD) 108 | t.accuracy() 109 | t.save() -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | 5 | 6 | def ToLabel(E): 7 | fgs = np.argmax(E, axis=1).astype(np.float32) 8 | return 
fgs.astype(np.uint8) 9 | 10 | 11 | def SSIM(x, y): 12 | C1 = 0.01 ** 2 13 | C2 = 0.03 ** 2 14 | 15 | mu_x = nn.AvgPool2d(3, 1, 1)(x) 16 | mu_y = nn.AvgPool2d(3, 1, 1)(y) 17 | mu_x_mu_y = mu_x * mu_y 18 | mu_x_sq = mu_x.pow(2) 19 | mu_y_sq = mu_y.pow(2) 20 | 21 | sigma_x = nn.AvgPool2d(3, 1, 1)(x * x) - mu_x_sq 22 | sigma_y = nn.AvgPool2d(3, 1, 1)(y * y) - mu_y_sq 23 | sigma_xy = nn.AvgPool2d(3, 1, 1)(x * y) - mu_x_mu_y 24 | 25 | SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2) 26 | SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2) 27 | SSIM = SSIM_n / SSIM_d 28 | 29 | return torch.clamp((1 - SSIM) / 2, 0, 1) 30 | 31 | 32 | def SaliencyStructureConsistency(x, y, alpha): 33 | ssim = torch.mean(SSIM(x,y)) 34 | l1_loss = torch.mean(torch.abs(x-y)) 35 | loss_ssc = alpha*ssim + (1-alpha)*l1_loss 36 | return loss_ssc 37 | 38 | 39 | def SaliencyStructureConsistencynossim(x, y): 40 | l1_loss = torch.mean(torch.abs(x-y)) 41 | return l1_loss 42 | 43 | 44 | def set_seed(seed): 45 | torch.manual_seed(seed) 46 | torch.cuda.manual_seed_all(seed) 47 | np.random.seed(seed) 48 | random.seed(seed) 49 | torch.backends.cudnn.deterministic = True 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | #coding=utf-8 3 | 4 | import sys 5 | import datetime 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import cv2 11 | from torch.utils.data import DataLoader 12 | from tensorboardX import SummaryWriter 13 | from data import dataset 14 | from net_agg import SCWSSOD 15 | import logging as logger 16 | from lib.data_prefetcher import DataPrefetcher 17 | from lscloss import * 18 | import numpy as np 19 | from tools import * 20 | import matplotlib.pyplot as plt 21 | 22 | TAG = "scwssod" 23 | SAVE_PATH = "scwssod" 24 | logger.basicConfig(level=logger.INFO, 
format='%(levelname)s %(asctime)s %(filename)s: %(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', \ 25 | filename="train_%s.log"%(TAG), filemode="w") 26 | 27 | 28 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 29 | 30 | """ set lr """ 31 | def get_triangle_lr(base_lr, max_lr, total_steps, cur, ratio=1., \ 32 | annealing_decay=1e-2, momentums=[0.95, 0.85]): 33 | first = int(total_steps*ratio) 34 | last = total_steps - first 35 | min_lr = base_lr * annealing_decay 36 | 37 | cycle = np.floor(1 + cur/total_steps) 38 | x = np.abs(cur*2.0/total_steps - 2.0*cycle + 1) 39 | if cur < first: 40 | lr = base_lr + (max_lr - base_lr) * np.maximum(0., 1.0 - x) 41 | else: 42 | lr = ((base_lr - min_lr)*cur + min_lr*first - base_lr*total_steps)/(first - total_steps) 43 | if isinstance(momentums, int): 44 | momentum = momentums 45 | else: 46 | if cur < first: 47 | momentum = momentums[0] + (momentums[1] - momentums[0]) * np.maximum(0., 1.-x) 48 | else: 49 | momentum = momentums[0] 50 | 51 | return lr, momentum 52 | 53 | 54 | def get_polylr(base_lr, last_epoch, num_steps, power): 55 | return base_lr * (1.0 - min(last_epoch, num_steps-1) / num_steps) **power 56 | 57 | 58 | BASE_LR = 1e-5 59 | MAX_LR = 1e-2 60 | loss_lsc_kernels_desc_defaults = [{"weight": 1, "xy": 6, "rgb": 0.1}] 61 | loss_lsc_radius = 5 62 | batch = 16 63 | l = 0.3 64 | 65 | 66 | def train(Dataset, Network): 67 | ## dataset 68 | cfg = Dataset.Config(datapath='./data/DUTS', savepath=SAVE_PATH, mode='train', batch=batch, lr=1e-3, momen=0.9, decay=5e-4, epoch=40) 69 | data = Dataset.Data(cfg) 70 | loader = DataLoader(data, batch_size=cfg.batch, shuffle=True, num_workers=8) 71 | ## network 72 | net = Network(cfg) 73 | # print('model has {} parameters in total'.format(sum(x.numel() for x in net.parameters()))) 74 | criterion = torch.nn.CrossEntropyLoss(weight=None, ignore_index=255, reduction='mean') 75 | loss_lsc = LocalSaliencyCoherence().cuda() 76 | net.train(True) 77 | net.cuda() 78 | criterion.cuda() 79 | ## 
parameter 80 | base, head = [], [] 81 | for name, param in net.named_parameters(): 82 | if 'bkbone' in name: 83 | base.append(param) 84 | else: 85 | head.append(param) 86 | optimizer = torch.optim.SGD([{'params':base}, {'params':head}], lr=cfg.lr, momentum=cfg.momen, weight_decay=cfg.decay, nesterov=True) 87 | sw = SummaryWriter(cfg.savepath) 88 | global_step = 0 89 | 90 | db_size = len(loader) 91 | 92 | # -------------------------- training ------------------------------------ 93 | for epoch in range(cfg.epoch): 94 | prefetcher = DataPrefetcher(loader) 95 | batch_idx = -1 96 | image, mask = prefetcher.next() 97 | while image is not None: 98 | niter = epoch * db_size + batch_idx 99 | lr, momentum = get_triangle_lr(BASE_LR, MAX_LR, cfg.epoch*db_size, niter, ratio=1.) 100 | optimizer.param_groups[0]['lr'] = 0.1 * lr # for backbone 101 | optimizer.param_groups[1]['lr'] = lr 102 | optimizer.momentum = momentum 103 | batch_idx += 1 104 | global_step += 1 105 | 106 | ###### saliency structure consistency loss ###### 107 | image_scale = F.interpolate(image, scale_factor=0.3, mode='bilinear', align_corners=True) 108 | out2, out3, out4, out5 = net(image, 'Train') 109 | out2_s, out3_s, out4_s, out5_s = net(image_scale, 'Train') 110 | out2_scale = F.interpolate(out2[:, 1:2], scale_factor=0.3, mode='bilinear', align_corners=True) 111 | loss_ssc = SaliencyStructureConsistency(out2_s[:, 1:2], out2_scale, 0.85) 112 | 113 | ###### label for partial cross-entropy loss ###### 114 | gt = mask.squeeze(1).long() 115 | bg_label = gt.clone() 116 | fg_label = gt.clone() 117 | bg_label[gt != 0] = 255 118 | fg_label[gt == 0] = 255 119 | 120 | ###### local saliency coherence loss (scale to realize large batchsize) ###### 121 | image_ = F.interpolate(image, scale_factor=0.25, mode='bilinear', align_corners=True) 122 | sample = {'rgb': image_} 123 | out2_ = F.interpolate(out2[:, 1:2], scale_factor=0.25, mode='bilinear', align_corners=True) 124 | loss2_lsc = loss_lsc(out2_, 
loss_lsc_kernels_desc_defaults, loss_lsc_radius, sample, image_.shape[2], image_.shape[3])['loss'] 125 | loss2 = loss_ssc + criterion(out2, fg_label) + criterion(out2, bg_label) + l * loss2_lsc ## dominant loss 126 | 127 | ###### auxiliary losses ###### 128 | out3_ = F.interpolate(out3[:, 1:2], scale_factor=0.25, mode='bilinear', align_corners=True) 129 | loss3_lsc = loss_lsc(out3_, loss_lsc_kernels_desc_defaults, loss_lsc_radius, sample, image_.shape[2], image_.shape[3])['loss'] 130 | loss3 = criterion(out3, fg_label) + criterion(out3, bg_label) + l * loss3_lsc 131 | out4_ = F.interpolate(out4[:, 1:2], scale_factor=0.25, mode='bilinear', align_corners=True) 132 | loss4_lsc = loss_lsc(out4_, loss_lsc_kernels_desc_defaults, loss_lsc_radius, sample, image_.shape[2], image_.shape[3])['loss'] 133 | loss4 = criterion(out4, fg_label) + criterion(out4, bg_label) + l * loss4_lsc 134 | out5_ = F.interpolate(out5[:, 1:2], scale_factor=0.25, mode='bilinear', align_corners=True) 135 | loss5_lsc = loss_lsc(out5_, loss_lsc_kernels_desc_defaults, loss_lsc_radius, sample, image_.shape[2], image_.shape[3])['loss'] 136 | loss5 = criterion(out5, fg_label) + criterion(out5, bg_label) + l * loss5_lsc 137 | 138 | ###### objective function ###### 139 | loss = loss2*1 + loss3*0.8 + loss4*0.6 + loss5*0.4 140 | optimizer.zero_grad() 141 | loss.backward() 142 | optimizer.step() 143 | sw.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step=global_step) 144 | if batch_idx % 10 == 0: 145 | msg = '%s| %s | step:%d/%d/%d | lr=%.6f | loss=%.6f | loss2=%.6f | loss3=%.6f | loss4=%.6f | loss5=%.6f' % (SAVE_PATH, datetime.datetime.now(), global_step, epoch+1, cfg.epoch, optimizer.param_groups[0]['lr'], loss.item(), loss2.item(), loss3.item(), loss4.item(), loss5.item()) 146 | print(msg) 147 | logger.info(msg) 148 | image, mask = prefetcher.next() 149 | if epoch > 28: 150 | if (epoch+1) % 1 == 0 or (epoch+1) == cfg.epoch: 151 | torch.save(net.state_dict(), 
cfg.savepath+'/model-'+str(epoch+1)+'.pt') 152 | 153 | 154 | if __name__=='__main__': 155 | train(dataset, SCWSSOD) 156 | --------------------------------------------------------------------------------