├── README.md ├── __pycache__ └── modelsize_estimate.cpython-36.pyc ├── ckpt └── placeholder.md ├── crf_refine.py ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-39.pyc │ ├── sbu_dataset.cpython-39.pyc │ ├── sbu_dataset_new.cpython-36.pyc │ ├── sbu_dataset_new.cpython-39.pyc │ ├── transforms.cpython-36.pyc │ └── transforms.cpython-39.pyc ├── sbu_dataset.py ├── sbu_dataset_new.py ├── shadow_dataset.py └── transforms.py ├── env.yaml ├── figures ├── SDDNet.png ├── qualitative_results.png └── quantitative_results.png ├── logs ├── args.txt ├── summary │ ├── events.out.tfevents.1683270473.user-NULL.24389.0 │ ├── events.out.tfevents.1684641896.user-NULL.34088.0 │ ├── events.out.tfevents.1684649636.user-NULL.19888.0 │ ├── events.out.tfevents.1684649792.user-NULL.22683.0 │ ├── events.out.tfevents.1684650039.user-NULL.27206.0 │ ├── events.out.tfevents.1684650615.user-NULL.37139.0 │ ├── events.out.tfevents.1684674591.user-NULL.39936.0 │ ├── events.out.tfevents.1684674671.user-NULL.726.0 │ └── events.out.tfevents.1684674825.user-NULL.3845.0 └── train.log ├── modelsize_estimate.py ├── networks ├── __pycache__ │ ├── fdrnet.cpython-36.pyc │ ├── fdrnet.cpython-39.pyc │ ├── fdrnet_6.cpython-39.pyc │ ├── fdrnet_66.cpython-39.pyc │ ├── fdrnet_666.cpython-39.pyc │ ├── fdrnet_7.cpython-36.pyc │ ├── fdrnet_7.cpython-39.pyc │ ├── fdrnet_77.cpython-36.pyc │ ├── fdrnet_77.cpython-39.pyc │ ├── fdrnet_777.cpython-36.pyc │ ├── fdrnet_basic.cpython-36.pyc │ ├── fdrnet_basic_fdr.cpython-36.pyc │ ├── fdrnet_basic_fdr_ssf.cpython-36.pyc │ ├── fdrnet_basic_ml.cpython-36.pyc │ ├── fdrnet_basic_ml_fdr.cpython-36.pyc │ ├── fdrnet_basic_newml.cpython-36.pyc │ ├── fdrnet_gyc.cpython-36.pyc │ ├── loss.cpython-36.pyc │ └── loss.cpython-39.pyc ├── loss.py └── sddnet.py ├── resnext ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-39.pyc │ ├── config.cpython-36.pyc │ ├── config.cpython-39.pyc │ ├── resnext101_regular.cpython-36.pyc │ ├── resnext101_regular.cpython-39.pyc │ ├── resnext_101_32x4d_.cpython-36.pyc │ └── resnext_101_32x4d_.cpython-39.pyc ├── config.py ├── resnext101_regular.py ├── resnext101_regular_scratch.py ├── resnext101_regular_sep.py └── resnext_101_32x4d_.py ├── test.py ├── train.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-39.pyc ├── evaluation.cpython-36.pyc ├── evaluation.cpython-39.pyc ├── misc.cpython-36.pyc ├── misc.cpython-39.pyc ├── transforms.cpython-39.pyc ├── visualization.cpython-36.pyc └── visualization.cpython-39.pyc ├── cfg_parser.py ├── evaluation.py ├── misc.py ├── transforms.py └── visualization.py /README.md: -------------------------------------------------------------------------------- 1 | # SDDNet_ACMMM23 2 | 3 | Runmin Cong, Yuchen Guan, Jinpeng Chen, Wei Zhang, Yao Zhao, and Sam Kwong, SDDNet: Style-guided dual-layer disentanglement network for shadow detection, ACM Multimedia (ACM MM), 2023. In Press. 4 | 5 | ## Network 6 | 7 | ### Our overall framework: 8 | 9 | ![image](figures/SDDNet.png) 10 | 11 | 12 | ## Requirement 13 | 14 | Pleasure configure the environment according to the given version: 15 | 16 | - python 3.6.10 17 | - pytorch 1.10.1 18 | - cudatoolkit 11.1 19 | - torchvision 0.11.2 20 | - tensorboard 2.3.0 21 | - opencv-python 3.4.2 22 | - PIL 7.2.0 23 | - pydensecrf 1.0rc3 24 | - numpy 1.18.5 25 | 26 | We also provide ".yaml" files for conda environment configuration, you can use `conda env create -f env.yaml` to create a required environment. 
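After running `conda env create -f env.yaml`, a quick sanity check along the following lines can confirm that the key packages listed above are importable and that CUDA is visible. This is a minimal sketch added for convenience and is not part of the repository; the printed versions should roughly match the list above.

```python
# Minimal environment check (illustrative only, not shipped with the repo).
import torch
import torchvision
import cv2
import numpy as np
import PIL
import pydensecrf.densecrf  # imported only to verify the pydensecrf install

print('pytorch        :', torch.__version__)
print('torchvision    :', torchvision.__version__)
print('opencv-python  :', cv2.__version__)
print('numpy          :', np.__version__)
print('PIL            :', PIL.__version__)
print('CUDA available :', torch.cuda.is_available())
```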
27 | 28 | ResNeXt101 is adopted as the backbone; please put `resnext_101_32x4d.pth` in the `SDDNet/resnext` directory. You can download the model from [[Link](https://pan.baidu.com/s/12aR793_GeohinDlFbqGlzQ)], code: ```mvpl```. 29 | 30 | 31 | ## Preparation 32 | 33 | Please organize the datasets and code according to the following structure: 34 | 35 | 36 | ```python 37 | ├── ISTD_Dataset 38 | ├── test 39 | ├── train 40 | ├── SBU-shadow 41 | ├── SBU-Test_rename 42 | ├── SBUTrain4KRecoveredSmall 43 | ├── UCF 44 | ├── train_A 45 | ├── train_B 46 | ├── SDDNet 47 | ├── ckpt 48 | ├── datasets 49 | ├── logs 50 | ├── networks 51 | ├── resnext 52 | ├── test 53 | ├── utils 54 | ├── crf_refine.py 55 | ├── modelsize_estimate.py 56 | ├── test.py 57 | ├── train.py 58 | ``` 59 | 60 | 61 | ## Training and Testing 62 | 63 | **Please Note**: 64 | The input image folder is always named 'train_A' and the ground-truth folder is always named 'train_B' for uniform processing (see the folder-linking sketch further below). 65 | 66 | **Training command**: 67 | ```python 68 | python train.py 69 | ``` 70 | 71 | **Testing command**: 72 | The trained model for SDDNet can be downloaded here: [[Baidu Netdisk Link](https://pan.baidu.com/s/1OyFuHeWtfiueOUan9GxQrg)], code: ```mvpl```, or [[Google Drive Link](https://drive.google.com/drive/folders/1Qz7zPT1A4u1OO6t8v6AiySBMBK9rBZLb?usp=sharing)]. 73 | ```python 74 | python test.py 75 | python crf_refine.py 76 | ``` 77 | 78 | 80 | ## Results 81 | 82 | 1. **Qualitative results**: we provide the predicted shadow maps; you can download them from [[Baidu Netdisk Link](https://pan.baidu.com/s/1-wvG-LVGIu4HEiP1izs_ZQ)], code: ```mvpl```, or [[Google Drive Link](https://drive.google.com/drive/folders/1Qz7zPT1A4u1OO6t8v6AiySBMBK9rBZLb?usp=sharing)]. 83 | 2. **Quantitative results**: 84 | 85 | ![image](figures/quantitative_results.png) 86 | 87 | 88 | 89 | 100 | ## Contact Us 101 | If you have any questions, please contact Runmin Cong at [rmcong@sdu.edu.cn](mailto:rmcong@sdu.edu.cn) or Yuchen Guan at [yuchenguan@bjtu.edu.cn](mailto:19281155@bjtu.edu.cn). 102 | 103 | -------------------------------------------------------------------------------- /__pycache__/modelsize_estimate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/__pycache__/modelsize_estimate.cpython-36.pyc -------------------------------------------------------------------------------- /ckpt/placeholder.md: -------------------------------------------------------------------------------- 1 | Put the checkpoint here.
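Because `train.py` and `test.py` always read inputs from a 'train_A' folder and ground truth from a 'train_B' folder (see the note above), the original SBU folder names have to be mapped to this convention. Below is a minimal sketch of one way to do this with symlinks; the source folder names ('ShadowImages', 'ShadowMasks') follow the defaults in `datasets/sbu_dataset.py`, and `link_uniform_names` is a hypothetical helper written for illustration, not a script shipped with the repository.

```python
import os

def link_uniform_names(split_dir):
    """Link the original SBU folder names to the uniform train_A/train_B names."""
    mapping = {'ShadowImages': 'train_A', 'ShadowMasks': 'train_B'}
    for src, dst in mapping.items():
        src_path = os.path.abspath(os.path.join(split_dir, src))
        dst_path = os.path.join(split_dir, dst)
        if os.path.isdir(src_path) and not os.path.exists(dst_path):
            os.symlink(src_path, dst_path)

if __name__ == '__main__':
    # Split directories as listed in the "Preparation" structure above.
    link_uniform_names('SBU-shadow/SBUTrain4KRecoveredSmall')
    link_uniform_names('SBU-shadow/SBU-Test_rename')
```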
-------------------------------------------------------------------------------- /crf_refine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from tqdm import tqdm 4 | from utils.evaluation import evaluate, crf_refine 5 | from utils.misc import split_np_imgrid, get_np_imgrid 6 | import numpy as np 7 | 8 | 9 | # raw_dir = 'test/raw_modify3_new' 10 | # crf_dir = 'test/crf_basic' 11 | raw_dir = 'test/demo' 12 | crf_dir = 'test/demo' 13 | 14 | im_names = os.listdir(raw_dir) 15 | os.makedirs(crf_dir, exist_ok=True) 16 | 17 | for im_name in tqdm(im_names): 18 | im_grid_path = os.path.join(raw_dir, im_name) 19 | im_grid = cv2.imread(im_grid_path) 20 | # ims = split_np_imgrid(im_grid, 3, 3) 21 | ims = split_np_imgrid(im_grid, 3, 3) 22 | input_im = ims[0].copy(order='C') 23 | prob = ims[1][:, :, 0].copy(order='C') 24 | refined = crf_refine(input_im, prob) 25 | ims[1] = np.stack((refined,)*3, axis=2) 26 | im_grid_new = get_np_imgrid(np.stack(ims, axis=0), nrow=3, padding=0) 27 | cv2.imwrite(os.path.join(crf_dir, im_name), im_grid_new) 28 | 29 | 30 | from utils.evaluation import evaluate 31 | 32 | im_grid_dir = crf_dir 33 | # pos_err, neg_err, ber, acc, df = evaluate(im_grid_dir, pred_id=1, gt_id=2, nimg=3, nrow=3) 34 | pos_err, neg_err, ber, acc, df = evaluate(im_grid_dir, pred_id=1, gt_id=2, nimg=3, nrow=3) 35 | print(f'\t BER: {ber:.2f}, pErr: {pos_err:.2f}, nErr: {neg_err:.2f}, acc:{acc:.4f}') -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import importlib 14 | import torch.utils.data 15 | 16 | 17 | def find_dataset_using_name(dataset_name): 18 | """Import the module "datasets/[dataset_name]_dataset.py". 19 | 20 | In the file, the class called DatasetNameDataset() will 21 | be instantiated. It has to be a subclass of BaseDataset, 22 | and it is case-insensitive. 23 | """ 24 | dataset_filename = "datasets." + dataset_name + "_dataset" 25 | datasetlib = importlib.import_module(dataset_filename) 26 | 27 | dataset = None 28 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 29 | for name, CLS in datasetlib.__dict__.items(): 30 | if name.lower() == target_dataset_name.lower() \ 31 | and issubclass(CLS, torch.utils.data.Dataset): 32 | dataset = CLS 33 | 34 | if dataset is None: 35 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 36 | 37 | return dataset 38 | 39 | 40 | # def create_dataset(config): 41 | # """Create a dataset given the option. 42 | 43 | # This function wraps the class CustomDatasetDataLoader. 
44 | # This is the main interface between this package and 'train.py'/'test.py' 45 | 46 | # Example: 47 | # >>> from data import create_dataset 48 | # >>> dataset = create_dataset(opt) 49 | # """ 50 | # dataset_class = find_dataset_using_name(config['dataset_mode']) 51 | # dataset = dataset_class(config) 52 | # dataloader = torch.utils.data.DataLoader( 53 | # dataset, 54 | # batch_size=config['batch_size'], 55 | # shuffle=not config['serial_batches'], 56 | # num_workers=int(config['num_threads'])) 57 | # return dataloader 58 | 59 | -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sbu_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/datasets/__pycache__/sbu_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sbu_dataset_new.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/datasets/__pycache__/sbu_dataset_new.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sbu_dataset_new.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/datasets/__pycache__/sbu_dataset_new.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/datasets/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/transforms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/datasets/__pycache__/transforms.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/sbu_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from torch.utils.data import Dataset 3 | import random 4 | import os 5 | import numpy as np 6 | import torch 7 | from collections import OrderedDict 8 | from torchvision import transforms 9 | from utils.transforms import JointRandHrzFlip, JointResize, \ 10 | JointNormalize, JointToTensor, \ 11 | 
JointRandVertFlip 12 | 13 | 14 | class SBUDataset(Dataset): 15 | def __init__(self, 16 | data_root, 17 | phase=None, 18 | img_dirs=['ShadowImages'], 19 | mask_dir='ShadowMasks', 20 | augmentation=False, 21 | im_size=400, 22 | max_dataset_size=None, 23 | normalize=True): 24 | 25 | self.root_dir = data_root 26 | self.img_dirs = img_dirs 27 | self.img_names = sorted(os.listdir(os.path.join(self.root_dir, img_dirs[0]))) 28 | self.mask_dir = mask_dir 29 | self.augmentation = augmentation 30 | 31 | self.size = len(self.img_names) 32 | # None means doesn't change the size of dataset to be loaded 33 | if max_dataset_size is not None: 34 | assert isinstance(max_dataset_size, int) and max_dataset_size > 0 35 | self.size = min(max_dataset_size, self.size) 36 | self.img_names = self.img_names[:self.size] 37 | 38 | assert phase in ['train', 'val', 'test', None] 39 | if phase == 'train': 40 | self.joint_transform = transforms.Compose([JointRandHrzFlip(), 41 | # JointRandVertFlip(), 42 | JointResize(im_size)]) 43 | img_transform = [ JointToTensor() ] 44 | if normalize: 45 | img_transform.append( JointNormalize([0.485, 0.456, 0.406], 46 | [0.229, 0.224, 0.225]) ) 47 | self.img_transform = transforms.Compose(img_transform) 48 | self.target_transform = transforms.ToTensor() 49 | 50 | elif phase in ['val', 'test']: 51 | self.joint_transform = None 52 | 53 | img_transform = [ JointToTensor() ] 54 | if normalize: 55 | img_transform.append( JointNormalize([0.485, 0.456, 0.406], 56 | [0.229, 0.224, 0.225]) ) 57 | self.img_transform = transforms.Compose(img_transform) 58 | 59 | self.target_transform = transforms.ToTensor() 60 | 61 | else: # pahse is None 62 | self.joint_transform = None 63 | self.img_transform = None 64 | self.target_transform = None 65 | 66 | def __getitem__(self, index): 67 | sample = OrderedDict() 68 | img_name = self.img_names[index] 69 | 70 | if self.augmentation: 71 | ret_key = ['ShadowImages_input'] 72 | img_dir = random.choice(self.img_dirs) 73 | img_path = os.path.join(self.root_dir, img_dir, img_name) 74 | img = cv2.imread(img_path) 75 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 76 | ret_val= [ img ] 77 | 78 | 79 | else: 80 | ret_key = [] 81 | ret_val = [] 82 | for img_dir in self.img_dirs: 83 | img_path = os.path.join(self.root_dir, img_dir, img_name) 84 | img = cv2.imread(img_path) 85 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 86 | ret_key.append(img_dir+'_input') 87 | ret_val.append(img) 88 | 89 | mask_name = os.path.splitext(img_name)[0]+'.png' 90 | mask_path = os.path.join(self.root_dir, self.mask_dir, mask_name) 91 | # print(mask_path) 92 | mask = ((cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 125)*255).astype(np.uint8) 93 | ret_key.append('gt') 94 | ret_val.append(mask) 95 | 96 | if self.joint_transform: 97 | ret_val = self.joint_transform(ret_val) 98 | 99 | if self.img_transform: 100 | ret_val[:-1] = self.img_transform(ret_val[:-1]) 101 | 102 | if self.target_transform: 103 | ret_val[-1] = self.target_transform(ret_val[-1]) 104 | 105 | ret_key.append('im_name') 106 | ret_val.append(img_name) 107 | 108 | # print(ret_key) 109 | # print(ret_val) 110 | return OrderedDict(zip(ret_key, ret_val)) 111 | 112 | 113 | def __len__(self): 114 | return self.size 115 | 116 | 117 | -------------------------------------------------------------------------------- /datasets/sbu_dataset_new.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from torch.utils.data import Dataset 3 | import random 4 | import os 5 | import numpy as np 6 | import 
torch 7 | from collections import OrderedDict 8 | from torchvision import transforms 9 | from .transforms import JointRandHrzFlip, JointResize, \ 10 | JointNormalize, JointToTensor, \ 11 | JointRandVertFlip 12 | 13 | 14 | class SBUDataset(Dataset): 15 | def __init__(self, 16 | data_root, 17 | phase=None, 18 | img_dirs=['train_A'], 19 | mask_dir='train_B', 20 | # noshad_dir='train_C', 21 | augmentation=False, 22 | im_size=400, 23 | max_dataset_size=None, 24 | normalize=True): 25 | 26 | self.root_dir = data_root 27 | self.img_dirs = img_dirs 28 | # self.noshad_dir = noshad_dir # disabled together with the commented-out noshad_dir argument above 29 | self.img_names = sorted(os.listdir(os.path.join(self.root_dir, img_dirs[0]))) 30 | self.mask_dir = mask_dir 31 | self.augmentation = augmentation 32 | self.phase = phase 33 | 34 | self.size = len(self.img_names) 35 | # None means the size of the loaded dataset is not limited 36 | if max_dataset_size is not None: 37 | assert isinstance(max_dataset_size, int) and max_dataset_size > 0 38 | self.size = min(max_dataset_size, self.size) 39 | self.img_names = self.img_names[:self.size] 40 | 41 | assert phase in ['train', 'val', 'test', None] 42 | if phase == 'train': 43 | self.joint_transform = transforms.Compose([JointRandHrzFlip(), 44 | # JointRandVertFlip(), 45 | JointResize(im_size)]) 46 | # self.joint_transform = transforms.Compose([JointRandHrzFlip()]) 47 | img_transform = [ JointToTensor() ] 48 | if normalize: 49 | img_transform.append( JointNormalize([0.485, 0.456, 0.406], 50 | [0.229, 0.224, 0.225]) ) 51 | self.img_transform = transforms.Compose(img_transform) 52 | self.target_transform = transforms.ToTensor() 53 | 54 | elif phase in ['val', 'test']: 55 | self.joint_transform = None 56 | 57 | img_transform = [ JointResize(im_size), JointToTensor() ] 58 | if normalize: 59 | img_transform.append( JointNormalize([0.485, 0.456, 0.406], 60 | [0.229, 0.224, 0.225]) ) 61 | self.img_transform = transforms.Compose(img_transform) 62 | 63 | self.target_transform = transforms.Compose([JointResize(im_size), transforms.ToTensor()]) 64 | 65 | else: # phase is None 66 | self.joint_transform = None 67 | self.img_transform = None 68 | self.target_transform = None 69 | 70 | def _load_sample_pairs(self): 71 | pass 72 | 73 | def __getitem__(self, index): 74 | sample = OrderedDict() 75 | img_name = self.img_names[index] 76 | 77 | if self.augmentation: 78 | ret_key = ['train_A_input'] 79 | img_dir = random.choice(self.img_dirs) 80 | img_path = os.path.join(self.root_dir, img_dir, img_name) 81 | img = cv2.imread(img_path) 82 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 83 | ret_val = [ img ] 84 | 85 | 86 | else: 87 | ret_key = [] 88 | ret_val = [] 89 | for img_dir in self.img_dirs: 90 | img_path = os.path.join(self.root_dir, img_dir, img_name) 91 | img = cv2.imread(img_path) 92 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 93 | ret_key.append(img_dir+'_input') 94 | ret_val.append(img) 95 | 96 | mask_name = os.path.splitext(img_name)[0]+'.png' 97 | mask_path = os.path.join(self.root_dir, self.mask_dir, mask_name) 98 | # print(mask_path) 99 | mask = ((cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 125)*255).astype(np.uint8) 100 | ret_key.append('gt') 101 | ret_val.append(mask) 102 | 103 | # noshad_name = os.path.splitext(img_name)[0]+'.png' 104 | # noshad_path = os.path.join(self.root_dir, self.noshad_dir, noshad_name) 105 | # # print(noshad_path) 106 | # noshad = cv2.imread(noshad_path) 107 | # noshad = cv2.cvtColor(noshad, cv2.COLOR_BGR2RGB) 108 | # ret_key.append('noshad') 109 | # ret_val.append(noshad) 110 | 111 | # if self.phase ==
'train': 112 | # noshad_name = os.path.splitext(img_name)[0]+'.png' 113 | # noshad_path = os.path.join(self.root_dir, self.noshad_dir, noshad_name) 114 | # noshad = cv2.imread(noshad_path) 115 | # noshad = cv2.cvtColor(noshad, cv2.COLOR_BGR2RGB) 116 | # ret_key.append('noshad') 117 | # ret_val.append(noshad) 118 | 119 | if self.joint_transform: 120 | ret_val = self.joint_transform(ret_val) 121 | 122 | if self.img_transform: 123 | ret_val[:-1] = self.img_transform(ret_val[:-1]) 124 | 125 | if self.target_transform: 126 | ret_val[-1] = self.target_transform(ret_val[-1]) 127 | 128 | ret_key.append('im_name') 129 | ret_val.append(img_name) 130 | 131 | # print(ret_key) 132 | # print(ret_val) 133 | return OrderedDict(zip(ret_key, ret_val)) 134 | 135 | 136 | def __len__(self): 137 | return self.size 138 | 139 | 140 | -------------------------------------------------------------------------------- /datasets/shadow_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import cv2 4 | from torch.utils.data import Dataset 5 | import random 6 | import os 7 | import numpy as np 8 | import torch 9 | from collections import OrderedDict 10 | from torchvision import transforms 11 | from .transforms import JointRandHrzFlip, JointResize, \ 12 | JointNormalize, JointToTensor, \ 13 | JointRandVertFlip 14 | 15 | 16 | class ShadowDataset(Dataset): 17 | image_dir = 'ShadowImages' 18 | mask_dir = 'ShadowMasks' 19 | def __init__(self, data_root, im_size=512, phase='train', max_dataset_size=None): 20 | self.root_dir = data_root 21 | self.img_dirs = img_dirs 22 | self.img_names = sorted(os.listdir(os.path.join(self.root_dir, img_dirs[0]))) 23 | self.mask_dir = mask_dir 24 | self.augmentation = augmentation 25 | 26 | self.size = len(self.img_names) 27 | # None means doesn't change the size of dataset to be loaded 28 | if max_dataset_size is not None: 29 | assert isinstance(max_dataset_size, int) and max_dataset_size > 0 30 | self.size = min(max_dataset_size, self.size) 31 | self.img_names = self.img_names[:self.size] 32 | 33 | assert phase in ['train', 'val', 'test', None] 34 | if phase == 'train': 35 | self.joint_transform = transforms.Compose([JointRandHrzFlip(), 36 | # JointRandVertFlip(), 37 | JointResize(im_size)]) 38 | # self.joint_transform = transforms.Compose([JointRandHrzFlip()]) 39 | img_transform = [ JointToTensor() ] 40 | if normalize: 41 | img_transform.append( JointNormalize([0.485, 0.456, 0.406], 42 | [0.229, 0.224, 0.225]) ) 43 | self.img_transform = transforms.Compose(img_transform) 44 | self.target_transform = transforms.ToTensor() 45 | 46 | elif phase in ['val', 'test']: 47 | self.joint_transform = None 48 | 49 | img_transform = [ JointResize(im_size), JointToTensor() ] 50 | if normalize: 51 | img_transform.append( JointNormalize([0.485, 0.456, 0.406], 52 | [0.229, 0.224, 0.225]) ) 53 | self.img_transform = transforms.Compose(img_transform) 54 | 55 | self.target_transform = transforms.ToTensor() 56 | 57 | else: # pahse is None 58 | self.joint_transform = None 59 | self.img_transform = None 60 | self.target_transform = None 61 | 62 | def _load_sample_pairs(self): 63 | pass 64 | 65 | def __getitem__(self, index): 66 | sample = OrderedDict() 67 | img_name = self.img_names[index] 68 | 69 | if self.augmentation: 70 | ret_key = ['ShadowImages_input'] 71 | img_dir = random.choice(self.img_dirs) 72 | img_path = os.path.join(self.root_dir, img_dir, img_name) 73 | img = cv2.imread(img_path) 74 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 75 | ret_val= [ img ] 
76 | 77 | 78 | else: 79 | ret_key = [] 80 | ret_val = [] 81 | for img_dir in self.img_dirs: 82 | img_path = os.path.join(self.root_dir, img_dir, img_name) 83 | img = cv2.imread(img_path) 84 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 85 | ret_key.append(img_dir+'_input') 86 | ret_val.append(img) 87 | 88 | mask_name = os.path.splitext(img_name)[0]+'.png' 89 | mask_path = os.path.join(self.root_dir, self.mask_dir, mask_name) 90 | # print(mask_path) 91 | mask = ((cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 125)*255).astype(np.uint8) 92 | ret_key.append('gt') 93 | ret_val.append(mask) 94 | 95 | if self.joint_transform: 96 | ret_val = self.joint_transform(ret_val) 97 | 98 | if self.img_transform: 99 | ret_val[:-1] = self.img_transform(ret_val[:-1]) 100 | 101 | if self.target_transform: 102 | ret_val[-1] = self.target_transform(ret_val[-1]) 103 | 104 | ret_key.append('im_name') 105 | ret_val.append(img_name) 106 | 107 | # print(ret_key) 108 | # print(ret_val) 109 | return OrderedDict(zip(ret_key, ret_val)) 110 | 111 | 112 | def __len__(self): 113 | return self.size 114 | 115 | 116 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from torchvision import transforms 4 | import random 5 | 6 | 7 | class ToCV2Image(object): 8 | ''' 9 | convert a CHW, range [0., 1.] tensor to a cv2 image 10 | ''' 11 | def __init__(self, in_color='rgb', 12 | out_color='bgr'): 13 | assert in_color in ['rgb', 'bgr'] 14 | assert out_color in ['rgb', 'bgr'] 15 | self.in_color = in_color 16 | self.out_color = out_color 17 | 18 | def __call__(self, im_tensor): 19 | cv2_img = (im_tensor.cpu().numpy() * 255).astype(np.uint8).transpose(1, 2, 0) 20 | if self.in_color != self.out_color: 21 | cv2_img = cv2.img[:, :, ::-1] 22 | return cv2_img 23 | 24 | 25 | class JointRandHrzFlip(object): 26 | def __init__(self, p=0.5): 27 | self.p = p 28 | 29 | def flip_single(self, image): 30 | return cv2.flip(image, 1) 31 | 32 | def __call__(self, img): 33 | assert isinstance(img, (np.ndarray, list, tuple)) 34 | if random.random() < self.p: 35 | if isinstance(img, np.ndarray): 36 | flipped = self.flip_single(img) 37 | else: 38 | flipped = [] 39 | for each in img: 40 | flipped.append(self.flip_single(each)) 41 | return flipped 42 | else: 43 | return img 44 | 45 | 46 | class JointRandVertFlip(object): 47 | def __init__(self, p=0.5): 48 | self.p = p 49 | 50 | def flip_single(self, image): 51 | return cv2.flip(image, 0) 52 | 53 | def __call__(self, img): 54 | assert isinstance(img, (np.ndarray, list, tuple)) 55 | if random.random() < self.p: 56 | if isinstance(img, np.ndarray): 57 | flipped = self.flip_single(img) 58 | else: 59 | flipped = [] 60 | for each in img: 61 | flipped.append(self.flip_single(each)) 62 | return flipped 63 | else: 64 | return img 65 | 66 | 67 | class JointResize(object): 68 | def __init__(self, size, interpolation='bilinear'): 69 | assert isinstance(size, (tuple, int)) 70 | if isinstance(size, int): 71 | self.size = (size, size) 72 | else: 73 | self.size = size 74 | map_dict = {'bilinear': cv2.INTER_LINEAR, 75 | 'bicubic': cv2.INTER_CUBIC, 76 | 'nearest': cv2.INTER_NEAREST 77 | } 78 | assert interpolation in map_dict.keys() 79 | self.inter_flag = map_dict[interpolation] 80 | 81 | def resize_single(self, image): 82 | return cv2.resize(image, self.size, interpolation=self.inter_flag) 83 | 84 | 85 | def __call__(self, img): 86 | assert 
isinstance(img, (np.ndarray, list, tuple)) 87 | if isinstance(img, np.ndarray): 88 | resized = self.resize_single(img) 89 | else: 90 | resized = [] 91 | for image in img: 92 | resized.append(self.resize_single(image)) 93 | return resized 94 | 95 | 96 | class JointToTensor(object): 97 | def __init__(self): 98 | self.to_tensor_single = transforms.ToTensor() 99 | 100 | def __call__(self, img): 101 | assert isinstance(img, (np.ndarray, list, tuple)) 102 | if isinstance(img, np.ndarray): 103 | im_tensor = self.to_tensor_single(img) 104 | else: 105 | im_tensor = [] 106 | for image in img: 107 | im_tensor.append(self.to_tensor_single(image)) 108 | return im_tensor 109 | 110 | 111 | class JointNormalize(object): 112 | def __init__(self, mean, std, inplace=False): 113 | self.normalize_single = transforms.Normalize(mean, std, inplace) 114 | 115 | def __call__(self, img): 116 | assert isinstance(img, (np.ndarray, list, tuple)) 117 | if isinstance(img, np.ndarray): 118 | normalized = self.normalize_single(img) 119 | else: 120 | normalized = [] 121 | for image in img: 122 | normalized.append(self.normalize_single(image)) 123 | return normalized 124 | 125 | 126 | class JointRandCrop(object): 127 | def __init__(self, size): 128 | pass 129 | def __call__(self, img): 130 | pass 131 | 132 | 133 | # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/2 134 | # class Denormalize(object): 135 | # def __init__(self, mean, std): 136 | # self.mean = mean 137 | # self.std = std 138 | 139 | # def __call__(self, tensor): 140 | # """ 141 | # Args: 142 | # tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 143 | # Returns: 144 | # Tensor: Normalized image. 145 | # """ 146 | # for t, m, s in zip(tensor, self.mean, self.std): 147 | # t.mul_(s).add_(m) 148 | # # The normalize code -> t.sub_(m).div_(s) 149 | # return tensor 150 | 151 | 152 | class Denormalize(object): 153 | """ 154 | reverse operation of Normalize 155 | denormalize a image tensor(float) for visualization 156 | """ 157 | def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): 158 | # self.mean = torch.Tensor(mean).view(3, 1, 1) 159 | # self.std = torch.Tensor(std).view(3, 1, 1) 160 | inv_std = [1/x for x in std ] 161 | inv_mean = [-m*s for (m, s) in zip(mean, inv_std)] 162 | self.denorm = transforms.Normalize(mean=inv_mean, std=inv_std) 163 | 164 | def __call__(self, x): 165 | # x : Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized. 
166 | # return x * self.std + self.mean 167 | return self.denorm(x) 168 | 169 | 170 | class Binarize(object): 171 | def __init__(self, threshold=125): 172 | assert isinstance(threshold, (int, float)) 173 | self.threshold = threshold 174 | 175 | def __call__(self, img): 176 | assert isinstance(img, np.ndarray) 177 | return (img > self.threshold).astype(img.dtype) 178 | 179 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: gyc_sdd 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _tflow_select=2.1.0=gpu 9 | - absl-py=0.9.0=py36_0 10 | - astor=0.8.0=py36_0 11 | - astroid=2.3.3=py36_0 12 | - attrs=19.3.0=py_0 13 | - backcall=0.2.0=py_0 14 | - bayesian-optimization=1.1.0=py_0 15 | - blas=1.0=mkl 16 | - bleach=3.1.5=py_0 17 | - blinker=1.4=py36_0 18 | - brotlipy=0.7.0=py36h7b6447c_1000 19 | - bzip2=1.0.8=h7b6447c_0 20 | - c-ares=1.15.0=h7b6447c_1001 21 | - ca-certificates=2021.10.26=h06a4308_2 22 | - cachetools=4.1.0=py_1 23 | - cairo=1.14.12=h8948797_3 24 | - certifi=2020.12.5=py36h06a4308_0 25 | - cffi=1.14.0=py36h2e261b9_0 26 | - chardet=3.0.4=py36_1003 27 | - click=7.1.2=py_0 28 | - conda=4.10.3=py36h06a4308_0 29 | - conda-package-handling=1.6.1=py36h7b6447c_0 30 | - cryptography=2.9.2=py36h1ba5d50_0 31 | - cudatoolkit=10.1.243=h6bb024c_0 32 | - cudnn=7.6.5=cuda10.1_0 33 | - cupti=10.1.168=0 34 | - cython=0.29.17=py36he6710b0_0 35 | - dbus=1.13.16=hb2f20db_0 36 | - decorator=4.4.2=py_0 37 | - defusedxml=0.6.0=py_0 38 | - entrypoints=0.3=py36_0 39 | - expat=2.2.9=he6710b0_2 40 | - ffmpeg=4.0=hcdf2ecd_0 41 | - fontconfig=2.13.0=h9420a91_0 42 | - freeglut=3.0.0=hf484d3e_5 43 | - freetype=2.10.2=h5ab3b9f_0 44 | - fribidi=1.0.9=h7b6447c_0 45 | - gast=0.2.2=py36_0 46 | - glib=2.63.1=h5a9c865_0 47 | - google-auth=1.17.2=py_0 48 | - google-auth-oauthlib=0.4.1=py_2 49 | - google-pasta=0.2.0=py_0 50 | - graphite2=1.3.14=h23475e2_0 51 | - graphviz=2.40.1=h21bd128_2 52 | - grpcio=1.27.2=py36hf8bcb03_0 53 | - gst-plugins-base=1.14.0=hbbd80ab_1 54 | - gstreamer=1.14.0=hb453b48_1 55 | - h5py=2.8.0=py36h989c5e5_3 56 | - harfbuzz=1.8.8=hffaf4a1_0 57 | - hdf5=1.10.2=hba1933b_1 58 | - icu=58.2=he6710b0_3 59 | - idna=2.10=py_0 60 | - importlib-metadata=1.7.0=py36_0 61 | - importlib_metadata=1.7.0=0 62 | - intel-openmp=2020.1=217 63 | - ipykernel=5.3.0=py36h5ca1d4c_0 64 | - ipython=7.16.1=py36h5ca1d4c_0 65 | - ipython_genutils=0.2.0=py36_0 66 | - ipywidgets=7.5.1=py_0 67 | - isort=4.3.21=py36_0 68 | - jasper=2.0.14=h07fcdf6_1 69 | - jedi=0.17.1=py36_0 70 | - jinja2=2.11.2=py_0 71 | - joblib=0.16.0=py_0 72 | - jpeg=9b=h024ee3a_2 73 | - jsonschema=3.2.0=py36_0 74 | - jupyter=1.0.0=py36_7 75 | - jupyter_client=6.1.5=py_0 76 | - jupyter_console=6.1.0=py_0 77 | - jupyter_core=4.6.3=py36_0 78 | - keras-applications=1.0.8=py_1 79 | - keras-preprocessing=1.1.0=py_1 80 | - lazy-object-proxy=1.5.0=py36h7b6447c_0 81 | - lcms2=2.11=h396b838_0 82 | - ld_impl_linux-64=2.33.1=h53a641e_7 83 | - libedit=3.1.20191231=h14c3975_1 84 | - libffi=3.2.1=hd88cf55_4 85 | - libgcc-ng=9.1.0=hdf63c60_0 86 | - libgfortran-ng=7.3.0=hdf63c60_0 87 | - libglu=9.0.0=hf484d3e_1 88 | - libopencv=3.4.2=hb342d67_1 89 | - libopus=1.3.1=h7b6447c_0 90 | - libpng=1.6.37=hbc83047_0 91 | - libprotobuf=3.12.3=hd408876_0 92 | - libsodium=1.0.18=h7b6447c_0 93 | - libstdcxx-ng=9.1.0=hdf63c60_0 94 | - libtiff=4.1.0=h2733197_1 95 | - libuuid=1.0.3=h1bed415_2 96 | - 
libvpx=1.7.0=h439df22_0 97 | - libxcb=1.14=h7b6447c_0 98 | - libxml2=2.9.10=he19cac6_1 99 | - lz4-c=1.9.2=he6710b0_0 100 | - markdown=3.1.1=py36_0 101 | - markupsafe=1.1.1=py36h7b6447c_0 102 | - mccabe=0.6.1=py36_1 103 | - mistune=0.8.4=py36h7b6447c_0 104 | - mkl=2020.1=217 105 | - mkl-service=2.3.0=py36he904b0f_0 106 | - mkl_fft=1.1.0=py36h23d657b_0 107 | - mkl_random=1.1.1=py36h0573a6f_0 108 | - nb_conda=2.2.1=py36_0 109 | - nb_conda_kernels=2.2.3=py36_0 110 | - nbconvert=5.6.1=py36_0 111 | - nbformat=5.0.7=py_0 112 | - ncurses=6.2=he6710b0_1 113 | - ninja=1.9.0=py36hfd86e86_0 114 | - notebook=6.0.3=py36_0 115 | - numpy=1.18.5=py36ha1c710e_0 116 | - numpy-base=1.18.5=py36hde5b4d6_0 117 | - oauthlib=3.1.0=py_0 118 | - olefile=0.46=py36_0 119 | - opencv=3.4.2=py36h6fd60c2_1 120 | - openssl=1.1.1l=h7f8727e_0 121 | - opt_einsum=3.1.0=py_0 122 | - pandas=1.0.3=py36h0573a6f_0 123 | - pandoc=2.9.2.1=0 124 | - pandocfilters=1.4.2=py36_1 125 | - pango=1.42.4=h049681c_0 126 | - parso=0.7.0=py_0 127 | - pcre=8.44=he6710b0_0 128 | - pexpect=4.8.0=py36_0 129 | - pickleshare=0.7.5=py36_0 130 | - pillow=7.2.0=py36hb39fc2d_0 131 | - pip=20.1.1=py36_1 132 | - pixman=0.40.0=h7b6447c_0 133 | - prometheus_client=0.8.0=py_0 134 | - prompt-toolkit=3.0.5=py_0 135 | - prompt_toolkit=3.0.5=0 136 | - protobuf=3.12.3=py36he6710b0_0 137 | - ptyprocess=0.6.0=py36_0 138 | - py-opencv=3.4.2=py36hb342d67_1 139 | - pyasn1=0.4.8=py_0 140 | - pyasn1-modules=0.2.7=py_0 141 | - pycosat=0.6.3=py36h7b6447c_0 142 | - pycparser=2.20=py_2 143 | - pydot=1.4.1=py36_0 144 | - pygments=2.6.1=py_0 145 | - pyjwt=1.7.1=py36_0 146 | - pylint=2.4.4=py36_0 147 | - pyopenssl=19.1.0=py_1 148 | - pyparsing=2.4.7=py_0 149 | - pyqt=5.9.2=py36h05f1152_2 150 | - pyrsistent=0.16.0=py36h7b6447c_0 151 | - pysocks=1.7.1=py36_0 152 | - python=3.6.10=h0371630_0 153 | - python-dateutil=2.8.1=py_0 154 | - python_abi=3.6=1_cp36m 155 | - pytorch=1.5.0=py3.6_cuda10.1.243_cudnn7.6.3_0 156 | - pytz=2020.1=py_0 157 | - pyzmq=19.0.1=py36he6710b0_1 158 | - qt=5.9.7=h5867ecd_1 159 | - qtconsole=4.7.5=py_0 160 | - qtpy=1.9.0=py_0 161 | - readline=7.0=h7b6447c_5 162 | - requests=2.24.0=py_0 163 | - requests-oauthlib=1.3.0=py_0 164 | - rsa=4.0=py_0 165 | - ruamel_yaml=0.15.87=py36h7b6447c_1 166 | - scikit-learn=0.22.1=py36hd81dba3_0 167 | - scipy=1.5.0=py36h0b6359f_0 168 | - send2trash=1.5.0=py36_0 169 | - setuptools=49.2.0=py36_0 170 | - sip=4.19.8=py36hf484d3e_0 171 | - six=1.15.0=py_0 172 | - sqlite=3.32.3=h62c20be_0 173 | - tensorboard=2.3.0=pyh4dce500_0 174 | - tensorboard-plugin-wit=1.6.0=py_0 175 | - tensorboardx=2.1=py_0 176 | - tensorflow=2.1.0=gpu_py36h2e5cdaa_0 177 | - tensorflow-base=2.1.0=gpu_py36h6c5654b_0 178 | - tensorflow-estimator=2.1.0=pyhd54b08b_0 179 | - tensorflow-gpu=2.1.0=h0d30ee6_0 180 | - termcolor=1.1.0=py36_1 181 | - terminado=0.8.3=py36_0 182 | - testpath=0.4.4=py_0 183 | - tk=8.6.10=hbc83047_0 184 | - torchvision=0.6.0=py36_cu101 185 | - tornado=6.0.4=py36h7b6447c_1 186 | - tqdm=4.62.3=pyhd3eb1b0_1 187 | - traitlets=4.3.3=py36_0 188 | - typed-ast=1.4.1=py36h7b6447c_0 189 | - urllib3=1.25.9=py_0 190 | - wcwidth=0.2.5=py_0 191 | - webencodings=0.5.1=py36_1 192 | - werkzeug=1.0.1=py_0 193 | - wheel=0.34.2=py36_0 194 | - widgetsnbextension=3.5.1=py36_0 195 | - wrapt=1.12.1=py36h7b6447c_1 196 | - xz=5.2.5=h7b6447c_0 197 | - yaml=0.2.5=h7b6447c_0 198 | - zeromq=4.3.2=he6710b0_2 199 | - zipp=3.1.0=py_0 200 | - zlib=1.2.11=h7b6447c_3 201 | - zstd=1.4.4=h0b5b093_3 202 | - pip: 203 | - appdirs==1.4.3 204 | - cycler==0.10.0 205 | - 
efficientnet==1.1.0 206 | - efficientnet-pytorch==0.6.3 207 | - filelock==3.0.12 208 | - future==0.18.2 209 | - gdown==3.12.2 210 | - imagecodecs==2020.2.18 211 | - imageio==2.8.0 212 | - keras==2.3.1 213 | - kerassurgeon==0.1.3 214 | - kiwisolver==1.2.0 215 | - matplotlib==3.2.1 216 | - networkx==2.4 217 | - packaging==20.3 218 | - pooch==1.1.0 219 | - progress==1.5 220 | - progressbar==2.5 221 | # - pydensecrf==0.1 222 | - pydensecrf 223 | - pywavelets==1.1.1 224 | - pyyaml==5.3.1 225 | - scienceplots==1.0.6 226 | - scikit-image==0.17.1 227 | - tfkerassurgeon==0.2.1 228 | - thop==0.0.31-2005241907 229 | - tifffile==2020.5.7 230 | - torchinfo==0.0.5 231 | prefix: /home/lzhu68/miniconda3 232 | -------------------------------------------------------------------------------- /figures/SDDNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/figures/SDDNet.png -------------------------------------------------------------------------------- /figures/qualitative_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/figures/qualitative_results.png -------------------------------------------------------------------------------- /figures/quantitative_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/figures/quantitative_results.png -------------------------------------------------------------------------------- /logs/args.txt: -------------------------------------------------------------------------------- 1 | acc_step = 1 2 | action = train 3 | ckpt = None 4 | config = None 5 | eval_batch = 1 6 | eval_data = SBU_test+UCF_test 7 | eval_size = 512 8 | i_print = 10 9 | logdir = logs 10 | loglevel = info 11 | loss = bbce 12 | lr = 0.0005 13 | lr_gamma = 0.7 14 | lr_step = 1 15 | model = BANet.efficientnet-b3 16 | nworker = 4 17 | prob_th = 0.5 18 | save_ckpt = 1 19 | seed = 4 20 | total_ep = 20 21 | train_batch = 4 22 | train_data = SBU_train 23 | train_size = 512 24 | wd = 0.0001 25 | -------------------------------------------------------------------------------- /logs/summary/events.out.tfevents.1683270473.user-NULL.24389.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/logs/summary/events.out.tfevents.1683270473.user-NULL.24389.0 -------------------------------------------------------------------------------- /logs/summary/events.out.tfevents.1684641896.user-NULL.34088.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/logs/summary/events.out.tfevents.1684641896.user-NULL.34088.0 -------------------------------------------------------------------------------- /logs/summary/events.out.tfevents.1684649636.user-NULL.19888.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/logs/summary/events.out.tfevents.1684649636.user-NULL.19888.0 -------------------------------------------------------------------------------- 
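`logs/args.txt` records one hyperparameter per line in `key = value` form (the same settings are echoed at the top of `logs/train.log`). If you want to reuse a previous run's settings, a small standalone reader like the one below is enough; `load_args_txt` is a hypothetical helper added here for illustration and is not part of `utils/cfg_parser.py`.

```python
def load_args_txt(path='logs/args.txt'):
    """Parse the 'key = value' lines written by the training script."""
    args = {}
    with open(path) as f:
        for line in f:
            if '=' not in line:
                continue
            key, value = line.split('=', 1)
            args[key.strip()] = value.strip()
    return args

if __name__ == '__main__':
    args = load_args_txt()
    print(args['lr'], args['train_batch'], args['total_ep'])  # e.g. 0.0005 4 20
```

The event files under `logs/summary` are standard TensorBoard logs and can be inspected with `tensorboard --logdir logs/summary`.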
/logs/summary/events.out.tfevents.1684649792.user-NULL.22683.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/logs/summary/events.out.tfevents.1684649792.user-NULL.22683.0 -------------------------------------------------------------------------------- /logs/summary/events.out.tfevents.1684650039.user-NULL.27206.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/logs/summary/events.out.tfevents.1684650039.user-NULL.27206.0 -------------------------------------------------------------------------------- /logs/summary/events.out.tfevents.1684650615.user-NULL.37139.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/logs/summary/events.out.tfevents.1684650615.user-NULL.37139.0 -------------------------------------------------------------------------------- /logs/summary/events.out.tfevents.1684674591.user-NULL.39936.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/logs/summary/events.out.tfevents.1684674591.user-NULL.39936.0 -------------------------------------------------------------------------------- /logs/summary/events.out.tfevents.1684674671.user-NULL.726.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/logs/summary/events.out.tfevents.1684674671.user-NULL.726.0 -------------------------------------------------------------------------------- /logs/summary/events.out.tfevents.1684674825.user-NULL.3845.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/logs/summary/events.out.tfevents.1684674825.user-NULL.3845.0 -------------------------------------------------------------------------------- /logs/train.log: -------------------------------------------------------------------------------- 1 | 2023-05-21 09:13:45,248 - ShadowDet - INFO - Experiment arguments: 2 | ============begin================== 3 | acc_step = 1 4 | action = train 5 | ckpt = None 6 | config = None 7 | eval_batch = 1 8 | eval_data = SBU_test+UCF_test 9 | eval_size = 512 10 | i_print = 10 11 | logdir = logs 12 | loglevel = info 13 | loss = bbce 14 | lr = 0.0005 15 | lr_gamma = 0.7 16 | lr_step = 1 17 | model = BANet.efficientnet-b3 18 | nworker = 4 19 | prob_th = 0.5 20 | save_ckpt = 1 21 | seed = 4 22 | total_ep = 20 23 | train_batch = 4 24 | train_data = SBU_train 25 | train_size = 512 26 | wd = 0.0001 27 | =============end================ 28 | 2023-05-21 09:13:49,937 - ShadowDet - INFO - model BANet.efficientnet-b3 is created! 29 | 2023-05-21 09:13:49,943 - ShadowDet - INFO - Dataloaders are prepared! 
30 | 2023-05-21 09:14:17,068 - ShadowDet - INFO - Scores: 31 | =============================================== 32 | SBU_test.pos_err:100.0 33 | SBU_test.neg_err:0.0 34 | SBU_test.ber:50.0 35 | SBU_test.acc:0.8141027688980103 36 | =============================================== 37 | 2023-05-21 09:14:21,998 - ShadowDet - INFO - Scores: 38 | =============================================== 39 | UCF_test.pos_err:100.0 40 | UCF_test.neg_err:0.0 41 | UCF_test.ber:50.0 42 | UCF_test.acc:0.8303857445716858 43 | =============================================== 44 | 2023-05-21 09:14:25,377 - ShadowDet - INFO - [batch 10/1021, epoch 1/20]: loss_mask:0.12509822845458984 loss_noshad:0.07160969078540802 loss_shadimg:0.08985886722803116 loss_filter:1.0134570402442478e-05 loss_total:0.28657692670822144 45 | 2023-05-21 09:14:28,150 - ShadowDet - INFO - [batch 20/1021, epoch 1/20]: loss_mask:0.113631471991539 loss_noshad:0.050021421164274216 loss_shadimg:0.06579504907131195 loss_filter:4.761375294037862e-06 loss_total:0.2294527143239975 46 | 2023-05-21 09:14:30,919 - ShadowDet - INFO - [batch 30/1021, epoch 1/20]: loss_mask:0.0707603171467781 loss_noshad:0.03882474824786186 loss_shadimg:0.05074380710721016 loss_filter:3.494694055916625e-06 loss_total:0.16033238172531128 47 | -------------------------------------------------------------------------------- /modelsize_estimate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def modelsize(model, input, type_size=4): 7 | para = sum([np.prod(list(p.size())) for p in model.parameters()]) 8 | # print('Model {} : Number of params: {}'.format(model._get_name(), para)) 9 | print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000)) 10 | 11 | input_ = input.clone() 12 | input_.requires_grad_(requires_grad=False) 13 | 14 | mods = list(model.modules()) 15 | out_sizes = [] 16 | 17 | for i in range(1, len(mods)): 18 | m = mods[i] 19 | if isinstance(m, nn.ReLU): 20 | if m.inplace: 21 | continue 22 | out = m(input_) 23 | out_sizes.append(np.array(out.size())) 24 | input_ = out 25 | 26 | total_nums = 0 27 | for i in range(len(out_sizes)): 28 | s = out_sizes[i] 29 | nums = np.prod(np.array(s)) 30 | total_nums += nums 31 | 32 | # print('Model {} : Number of intermedite variables without backward: {}'.format(model._get_name(), total_nums)) 33 | # print('Model {} : Number of intermedite variables with backward: {}'.format(model._get_name(), total_nums*2)) 34 | print('Model {} : intermedite variables: {:3f} M (without backward)' 35 | .format(model._get_name(), total_nums * type_size / 1000 / 1000)) 36 | print('Model {} : intermedite variables: {:3f} M (with backward)' 37 | .format(model._get_name(), total_nums * type_size*2 / 1000 / 1000)) 38 | 39 | -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet.cpython-39.pyc 
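The `pos_err`, `neg_err`, `ber`, and `acc` values reported in `logs/train.log` (and printed by `crf_refine.py`) are consistent with the balanced error rate (BER) convention that is standard in shadow detection. The sketch below shows that arithmetic under this assumption; it is a standalone illustration, not the repository's `utils/evaluation.py` implementation. It also explains the initial scores above: an untrained model that predicts no shadow at all misses every shadow pixel (pos_err 100), raises no false alarms (neg_err 0), and therefore starts at a BER of 50 even though the plain pixel accuracy already looks high.

```python
import numpy as np

def ber_scores(pred, gt):
    """Balanced error rate from binary prediction/ground-truth masks (values in {0, 1})."""
    tp = np.sum((pred == 1) & (gt == 1))   # shadow pixels detected
    tn = np.sum((pred == 0) & (gt == 0))   # non-shadow pixels kept clean
    n_pos = np.sum(gt == 1)                # all shadow pixels
    n_neg = np.sum(gt == 0)                # all non-shadow pixels
    pos_err = 100.0 * (1.0 - tp / n_pos)   # percentage of shadow pixels missed
    neg_err = 100.0 * (1.0 - tn / n_neg)   # percentage of non-shadow pixels mislabelled
    ber = 0.5 * (pos_err + neg_err)
    acc = (tp + tn) / (n_pos + n_neg)
    return pos_err, neg_err, ber, acc

# Predicting "no shadow" everywhere on a mask that is roughly 19% shadow:
gt = (np.random.rand(512, 512) < 0.19).astype(np.uint8)
print(ber_scores(np.zeros_like(gt), gt))   # ~ (100.0, 0.0, 50.0, ~0.81)
```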
-------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_6.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_6.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_66.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_66.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_666.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_666.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_7.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_7.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_7.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_7.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_77.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_77.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_77.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_77.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_777.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_777.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_basic.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_basic.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_basic_fdr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_basic_fdr.cpython-36.pyc 
-------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_basic_fdr_ssf.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_basic_fdr_ssf.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_basic_ml.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_basic_ml.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_basic_ml_fdr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_basic_ml_fdr.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_basic_newml.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_basic_newml.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/fdrnet_gyc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/fdrnet_gyc.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/networks/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /networks/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BBCEWithLogitLoss(nn.Module): 7 | ''' 8 | Balanced BCEWithLogitLoss 9 | ''' 10 | def __init__(self): 11 | super(BBCEWithLogitLoss, self).__init__() 12 | 13 | def forward(self, pred, gt): 14 | eps = 1e-10 15 | count_pos = torch.sum(gt) + eps 16 | count_neg = torch.sum(1. 
- gt) 17 | ratio = count_neg / count_pos 18 | w_neg = count_pos / (count_pos + count_neg) 19 | 20 | bce1 = nn.BCEWithLogitsLoss(pos_weight=ratio) 21 | loss = w_neg * bce1(pred, gt) 22 | 23 | return loss 24 | 25 | 26 | class DiceLoss(nn.Module): 27 | def __init__(self, apply_sigmoid=True, smooth=1): 28 | super(DiceLoss, self).__init__() 29 | self.smooth = smooth 30 | self.apply_sigmoid = apply_sigmoid 31 | 32 | def forward(self, pred, target): 33 | if self.apply_sigmoid: 34 | pred = F.sigmoid(pred) 35 | 36 | numerator = 2 * torch.sum(pred * target) + self.smooth 37 | denominator = torch.sum(pred + target) + self.smooth 38 | return 1 - numerator / denominator 39 | 40 | 41 | class EdgeLoss(nn.Module): 42 | def __init__(self, apply_sigmoid=True): 43 | super(EdgeLoss, self).__init__() 44 | self.apply_sigmoid = apply_sigmoid 45 | 46 | def forward(self, edge_shadmask, edge_noshad): 47 | if self.apply_sigmoid: 48 | edge_shadmask = F.sigmoid(edge_shadmask) 49 | edge_noshad = F.sigmoid(edge_noshad) 50 | # can't for backward gradient computation 51 | # edge_shadmask[edge_shadmask >= 0.5] = 1 52 | # edge_shadmask[edge_shadmask < 0.5] =0 53 | # edge_noshad[edge_noshad >= 0.5] = 1 54 | # edge_noshad[edge_noshad < 0.5] = 0 55 | edge = edge_shadmask + edge_noshad 56 | # print(edge) 57 | edge[edge<1] = 0 58 | edge[edge>=1] = 1 59 | numerator = torch.sum(edge) 60 | denominator = torch.sum(1. - edge) 61 | loss = numerator / denominator 62 | return loss 63 | 64 | class OrthoLoss(nn.Module): 65 | def __init__(self): 66 | super(OrthoLoss, self).__init__() 67 | 68 | def forward(self, pred, target): 69 | batch_size = pred.size(0) 70 | pred = pred.view(batch_size, -1) 71 | target = target.view(batch_size, -1) 72 | 73 | pred_ = pred 74 | target_ = target 75 | ortho_loss = 0 76 | dim = pred.shape[1] 77 | for i in range(pred.shape[0]): 78 | # ortho_loss += torch.mean(torch.abs(pred_[i:i+1,:].mm(target_[i:i+1,:].t()))/dim) 79 | ortho_loss += torch.mean((pred_[i:i+1,:].mm(target_[i:i+1,:].t())).pow(2)/dim) 80 | 81 | ortho_loss /= pred.shape[0] 82 | return ortho_loss 83 | 84 | class DiffLoss(nn.Module): 85 | def __init__(self): 86 | super(DiffLoss, self).__init__() 87 | self.l1 = nn.L1Loss() 88 | self.ortho = OrthoLoss() 89 | 90 | def forward(self, img, noshad, mask): 91 | # batch_size = noshad.size(0) 92 | # mask = (mask > 0.5).type(torch.int64) 93 | img_noshad = img * (1-mask) 94 | noshad_noshad = noshad * (1-mask) 95 | img_shad = img * mask 96 | noshad_shad = noshad * mask 97 | 98 | loss = self.l1(img_noshad, noshad_noshad) 99 | # loss += 0.01*self.ortho(img_shad, noshad_shad) 100 | return loss 101 | 102 | class DiffLoss_2(nn.Module): 103 | def __init__(self): 104 | super(DiffLoss_2, self).__init__() 105 | self.l1 = nn.L1Loss() 106 | self.ortho = OrthoLoss() 107 | 108 | def forward(self, img, noshad, label, mask): 109 | # batch_size = noshad.size(0) 110 | mask = (mask > 0.5).type(torch.int64) 111 | # mask = torch.sigmoid(mask) 112 | img_noshad = img * (1-mask) 113 | noshad_noshad = noshad * (1-label) 114 | 115 | loss = self.l1(img_noshad, noshad_noshad) 116 | # loss += 0.01*self.ortho(img_shad, noshad_shad) 117 | return loss 118 | 119 | class StyleLoss(nn.Module): 120 | def __init__(self): 121 | super(StyleLoss, self).__init__() 122 | self.l1 = nn.L1Loss() 123 | 124 | # def forward(self, img, noshad, mask): 125 | # # batch_size = noshad.size(0) 126 | # hmask = (mask > 0.5).type(torch.int64) 127 | # smask = (mask > 0.25).type(torch.int64) 128 | # img_noshad = img * (smask-hmask) 129 | # noshad_shad = noshad * hmask 
130 | 131 | # count_noshad = torch.sum(1.-hmask) 132 | 133 | # lns = torch.mean(img_noshad) 134 | # ls = torch.mean(noshad_shad) 135 | 136 | # loss = torch.abs(lns-ls)*count_noshad 137 | # return loss 138 | 139 | def gram_matrix(self, y): 140 | """ Returns the gram matrix of y (used to compute style loss) """ 141 | (b, c, h, w) = y.size() 142 | features = y.view(b, c, w * h) 143 | features_t = features.transpose(1, 2) #C和w*h转置 144 | gram = features.bmm(features_t) / (c * h * w) #bmm 将features与features_t相乘 145 | return gram 146 | 147 | def forward(self, img, noshad): 148 | # mask = (mask > 0.5).type(torch.int64) 149 | # img_noshad = img * (1-mask) 150 | # # noshad_shad = noshad * mask 151 | # noshad_shad = noshad 152 | 153 | # g_noshad = self.gram_matrix(img_noshad) 154 | # g_shad = self.gram_matrix(noshad_shad) 155 | g_noshad = self.gram_matrix(img) 156 | g_shad = self.gram_matrix(noshad) 157 | 158 | loss = 0 159 | for i in range(g_shad.shape[0]): 160 | loss += torch.mean(torch.abs(g_noshad[i:i+1,:] - g_shad[i:i+1,:])) 161 | return loss 162 | 163 | class ZeroLoss(nn.Module): 164 | def __init__(self): 165 | super(ZeroLoss, self).__init__() 166 | 167 | def forward(self, target): 168 | zero_loss = torch.mean(torch.abs(target)) 169 | return zero_loss -------------------------------------------------------------------------------- /networks/sddnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from torchvision import models as tvmodels 5 | import torchvision 6 | from efficientnet_pytorch import EfficientNet 7 | import math 8 | import numpy as np 9 | from torchvision import transforms 10 | 11 | 12 | class Sobel(nn.Module): 13 | def __init__(self): 14 | super(Sobel, self).__init__() 15 | self.sobel_conv = nn.Conv2d(1, 1, 3, padding=1, bias=False) 16 | sobel_kernel = torch.tensor([[-1,-1,-1], [-1,8,-1], [-1,-1,-1]], dtype=torch.float32) 17 | sobel_kernel = sobel_kernel.reshape((1, 1, 3, 3)) 18 | self.sobel_conv.weight.data = sobel_kernel 19 | self.transform = transforms.Compose([ 20 | transforms.Grayscale(num_output_channels=1), 21 | # transforms.ToTensor() 22 | ]) 23 | 24 | def forward(self, x): 25 | if x.shape[1] !=1: 26 | x = self.transform(x) 27 | x = torch.sigmoid(x) 28 | edge = self.sobel_conv(x) 29 | return edge 30 | 31 | 32 | class SELayer(nn.Module): 33 | def __init__(self, channel, reduction=16): 34 | super(SELayer, self).__init__() 35 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 36 | self.fc = nn.Sequential( 37 | nn.Linear(channel, channel // reduction, bias=False), 38 | nn.LeakyReLU(inplace=True), 39 | nn.Linear(channel // reduction, channel, bias=False), 40 | nn.Sigmoid() 41 | ) 42 | 43 | def forward(self, x): 44 | b, c, _, _ = x.size() 45 | y = self.avg_pool(x).view(b, c) 46 | y = self.fc(y).view(b, c, 1, 1) 47 | return x * y.expand_as(x) 48 | 49 | 50 | class ConstantNormalize(nn.Module): 51 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 52 | super(ConstantNormalize, self).__init__() 53 | mean = torch.Tensor(mean).view([1, 3, 1, 1]) 54 | std = torch.Tensor(std).view([1, 3, 1, 1]) 55 | # https://discuss.pytorch.org/t/keeping-constant-value-in-module-on-correct-device/10129 56 | self.register_buffer('mean', mean) 57 | self.register_buffer('std', std) 58 | 59 | def forward(self, x): 60 | return (x - self.mean) / (self.std + 1e-5) 61 | 62 | 63 | class Conv1x1(nn.Sequential): 64 | def __init__(self, in_planes, out_planes=16, has_se=False, 
se_reduction=None): 65 | if has_se: 66 | if se_reduction is None: 67 | # se_reduction= int(math.sqrt(in_planes)) 68 | se_reduction = 2 69 | super(Conv1x1, self).__init__(SELayer(in_planes, se_reduction), 70 | nn.Conv2d(in_planes, out_planes, 1, bias=False), 71 | nn.BatchNorm2d(out_planes), 72 | nn.LeakyReLU() 73 | ) 74 | else: 75 | super(Conv1x1, self).__init__(nn.Conv2d(in_planes, out_planes, 1, bias=False), 76 | nn.BatchNorm2d(out_planes), 77 | nn.LeakyReLU() 78 | ) 79 | 80 | # https://pytorch.org/docs/stable/_modules/torchvision/models/resnet.html#resnext50_32x4d 81 | class ResBlock(nn.Module): 82 | def __init__(self, in_planes, out_planes): 83 | super(ResBlock, self).__init__() 84 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 85 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1, bias=False) 86 | self.bn1 = nn.BatchNorm2d(out_planes) 87 | self.relu = nn.ReLU(inplace=True) 88 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, padding=1, bias=False) 89 | self.bn2 = nn.BatchNorm2d(out_planes) 90 | self.in_planes = in_planes 91 | self.out_planes = out_planes 92 | if self.in_planes != self.out_planes: 93 | self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1, bias=False) 94 | self.bn3 = nn.BatchNorm2d(out_planes) 95 | 96 | def forward(self, x): 97 | identity = x 98 | if self.in_planes != self.out_planes: 99 | identity = self.conv3(identity) 100 | identity = self.bn3(identity) 101 | 102 | out = self.conv1(x) 103 | out = self.bn1(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv2(out) 107 | out = self.bn2(out) 108 | 109 | out += identity 110 | out = self.relu(out) 111 | 112 | return out 113 | 114 | 115 | class FRUnit(nn.Module): 116 | """ 117 | Factorisation and Reweighting unit 118 | """ 119 | def __init__(self, channels=32, mu_init=0.5, reweight_mode='manual', normalize=True): 120 | super(FRUnit, self).__init__() 121 | assert reweight_mode in ['manual', 'constant', 'learnable', 'nn'] 122 | self.mu = mu_init 123 | self.reweight_mode = reweight_mode 124 | self.normalize = normalize 125 | self.inv_conv = ResBlock(channels, channels) 126 | self.var_conv = ResBlock(channels, channels) 127 | # self.mix_conv = ResBlock(2*channels, channels) 128 | # self.mix_conv = Conv1x1(2*channels, channels) 129 | if reweight_mode == 'learnable': 130 | self.mu = nn.Parameter(torch.tensor(mu_init)) 131 | # self.x = nn.Parameter(torch.tensor(0.)) 132 | # self.mu = torch.sigmoid(self.x) 133 | elif reweight_mode == 'nn': 134 | self.fc = nn.Sequential(nn.Linear(channels, 1), 135 | nn.Sigmoid() 136 | ) 137 | else: 138 | self.mu = mu_init 139 | 140 | 141 | def forward(self, feat): 142 | shad_feat = self.inv_conv(feat) 143 | noshad_feat = self.var_conv(feat) 144 | # noshad_feat = feat - shad_feat 145 | 146 | if self.normalize: 147 | shad_feat = F.normalize(shad_feat) 148 | noshad_feat = F.normalize(noshad_feat) 149 | # noshad_feat = shad_feat - (shad_feat * noshad_feat).sum(keepdim=True, dim=1) * shad_feat 150 | # noshad_feat = F.normalize(noshad_feat) 151 | 152 | if self.reweight_mode == 'nn': 153 | gap = feat.mean([2, 3]) 154 | self.mu = self.fc(gap).view(-1, 1, 1, 1) 155 | 156 | # mix_feat = self.mu * noshad_feat + (1 - self.mu) * shad_feat 157 | mix_feat = noshad_feat + shad_feat 158 | # mix_feat = self.mix_conv(torch.cat([shad_feat, noshad_feat], dim=1)) 159 | # print(self.mu) 160 | return shad_feat, noshad_feat, mix_feat 161 | 162 | 163 | def set_mu(self, val): 164 | assert self.reweight_mode == 'manual' 165 | 
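# In 'manual' mode the mixing coefficient can be overridden from outside, e.g.
# model.fr_lower.set_mu(0.4) at test time. Note that in the current forward()
# the reweighted sum "mu * noshad_feat + (1 - mu) * shad_feat" is commented out
# and mix_feat is simply noshad_feat + shad_feat, so this value is unused there.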
self.mu = val 166 | 167 | 168 | ml_features = [] 169 | 170 | def feature_hook(module, fea_in, fea_out): 171 | # print("hooker working") 172 | # module_name.append(module.__class__) 173 | # features_in_hook.append(fea_in) 174 | global ml_features 175 | ml_features.append(fea_out) 176 | return None 177 | 178 | 179 | class ShadFilter(nn.Module): 180 | def __init__(self, in_channels): 181 | super(ShadFilter, self).__init__() 182 | # self.layer1 = nn.Linear(in_channels, in_channels) 183 | self.layer1 = nn.Linear(in_channels**2, in_channels**2) 184 | # self.layer1 = nn.Conv2d(in_channels, in_channels//2, kernel_size=3, padding=1, bias=False) 185 | # self.layer1 = nn.Conv2d(1, in_channels//2, kernel_size=3, stride=1, padding=1, bias=False) 186 | # self.bn1 = nn.BatchNorm2d(in_channels//2) 187 | self.relu = nn.LeakyReLU() 188 | self.layer2 = nn.Linear(in_channels**2, in_channels**2) 189 | # self.layer2 = nn.Conv2d(in_channels//2, in_channels//2, kernel_size=3, padding=1, bias=False) 190 | # self.layer2 = nn.Conv2d(in_channels//2, in_channels//2, stride=1, kernel_size=3, padding=1, bias=False) 191 | # self.bn2 = nn.BatchNorm2d(1) 192 | # self.bn3 = nn.BatchNorm2d(in_channels//2) 193 | 194 | def gram_matrix(self, y): 195 | """ Returns the gram matrix of y (used to compute style loss) """ 196 | (b, c, h, w) = y.size() 197 | features = y.view(b, c, w * h) 198 | features_t = features.transpose(1, 2) #C和w*h转置 199 | gram = features.bmm(features_t) / (c * h * w) #bmm 将features与features_t相乘 200 | return gram 201 | 202 | def forward(self, x): 203 | # x = torch.triu(self.gram_matrix(x), diagonal=0) 204 | x = self.gram_matrix(x) 205 | (b, h, w) = x.size() 206 | # x = x.view(b, 1, h, w) 207 | x = x.view(b, h*w) 208 | x = self.layer1(x) 209 | # x = self.bn1(x) 210 | x = F.normalize(x) 211 | x = self.relu(x) 212 | x = self.layer2(x) 213 | # x = self.bn2(x) 214 | # x = self.relu(x) 215 | # x = F.normalize(x) 216 | # x = self.bn3(x) 217 | # print(x[0]) 218 | 219 | # print(x.size()) 220 | 221 | return x 222 | 223 | 224 | class ShadFilter2(nn.Module): 225 | def __init__(self, inp_size): 226 | super(ShadFilter2, self).__init__() 227 | self.layer1 = nn.Linear(inp_size, inp_size//2) 228 | self.layer2 = nn.Linear(inp_size//2, inp_size//4) 229 | self.layer3 = nn.Linear(inp_size//4, 64) 230 | self.relu = nn.LeakyReLU() 231 | 232 | def gram_matrix(self, y): 233 | (b, c, h, w) = y.size() 234 | features = y.view(b, c, w * h) 235 | features_t = features.transpose(1, 2) 236 | gram = features.bmm(features_t) 237 | return gram 238 | 239 | def forward(self, x): 240 | gm = self.gram_matrix(x) 241 | b, _, _ = gm.size() 242 | tgm = gm[torch.triu(torch.ones(gm.size()[0], gm.size()[1], gm.size()[2]))==1].view(b, -1) 243 | out = self.layer1(tgm) 244 | out = self.relu(out) 245 | out = self.layer2(out) 246 | out = self.relu(out) 247 | out = self.layer3(out) 248 | return out 249 | 250 | 251 | class LEModule(nn.Module): 252 | def __init__(self, lf_ch, hf_ch, out_ch): 253 | super(LEModule, self).__init__() 254 | self.conv0 = nn.Conv2d(hf_ch, hf_ch, kernel_size=1, padding=0) 255 | self.batch0 = nn.BatchNorm2d(hf_ch) 256 | self.relu0 = nn.LeakyReLU(inplace=True) 257 | self.conv1 = nn.Conv2d(hf_ch, hf_ch, kernel_size=3, padding=1) 258 | self.batch1 = nn.BatchNorm2d(hf_ch) 259 | self.relu1 = nn.LeakyReLU(inplace=True) 260 | self.conv2 = nn.Conv2d(hf_ch, hf_ch, kernel_size=5, padding=2) 261 | self.batch2 = nn.BatchNorm2d(hf_ch) 262 | self.relu2 = nn.LeakyReLU(inplace=True) 263 | self.conv3 = Conv1x1(3*hf_ch, hf_ch) 264 | self.conv4 = 
Conv1x1(lf_ch+hf_ch, out_ch) 265 | 266 | def forward(self, low_feat, high_feat): 267 | x0 = self.conv0(low_feat) 268 | x0 = self.batch0(x0) 269 | x0 = self.relu0(x0) 270 | x1 = self.conv1(low_feat) 271 | x1 = self.batch1(x1) 272 | x1 = self.relu1(x1) 273 | x2 = self.conv2(low_feat) 274 | x2 = self.batch2(x2) 275 | x2 = self.relu2(x2) 276 | x3 = self.conv3(torch.cat([x0,x1,x2], dim=1)) 277 | x4 = self.conv4(torch.cat([high_feat,x3], dim=1)) 278 | return x4 279 | 280 | 281 | class FDUnit(nn.Module): 282 | def __init__(self, in_ch, out_ch): 283 | super(FDUnit, self).__init__() 284 | self.shad_conv = ResBlock(in_ch, out_ch) 285 | self.shad_bn = nn.BatchNorm2d(out_ch) 286 | self.shad_fc = nn.Sequential(nn.Linear(out_ch, 1), 287 | nn.Sigmoid()) 288 | self.noshad_conv = ResBlock(in_ch, out_ch) 289 | self.noshad_bn = nn.BatchNorm2d(out_ch) 290 | self.noshad_fc = nn.Sequential(nn.Linear(out_ch, 1), 291 | nn.Sigmoid()) 292 | 293 | def forward(self, x): 294 | shad_feat = self.shad_conv(x) 295 | shad_feat = self.shad_bn(shad_feat) 296 | alpha = shad_feat.mean([2,3]) 297 | alpha = self.shad_fc(alpha).view(-1, 1, 1, 1) 298 | noshad_feat = self.noshad_conv(x) 299 | noshad_feat = self.noshad_bn(noshad_feat) 300 | beta = noshad_feat.mean([2,3]) 301 | beta = self.noshad_fc(beta).view(-1, 1, 1, 1) 302 | mix_feat = alpha * shad_feat + beta * noshad_feat 303 | return shad_feat, noshad_feat, mix_feat 304 | 305 | 306 | class SDDNet(nn.Module): 307 | # decompose net 308 | def __init__(self, 309 | backbone='efficientnet-b0', 310 | proj_planes=16, 311 | pred_planes=32, 312 | use_pretrained=True, 313 | fix_backbone=False, 314 | has_se=False, 315 | dropout_2d=0, 316 | normalize=False, 317 | mu_init=0.5, 318 | reweight_mode='constant'): 319 | 320 | super(SDDNet, self).__init__() 321 | 322 | self.mu_init = mu_init 323 | # self.reweight_mode = reweight_mode 324 | self.reweight_mode = 'nn' 325 | 326 | # load backbone 327 | if use_pretrained: 328 | self.feat_net = EfficientNet.from_pretrained(backbone) 329 | # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py 330 | else: 331 | self.feat_net = EfficientNet.from_name(backbone) 332 | 333 | # load vgg 334 | # vgg = torchvision.models.vgg16(pretrained=True) 335 | 336 | # remove classification head to get correct param count 337 | self.feat_net._avg_pooling=None 338 | self.feat_net._dropout=None 339 | self.feat_net._fc=None 340 | 341 | # register hook to extract multi-level features 342 | in_planes = [] 343 | feat_layer_ids = list(range(0, len(self.feat_net._blocks), 2)) 344 | for idx in feat_layer_ids: 345 | self.feat_net._blocks[idx].register_forward_hook(hook=feature_hook) 346 | in_planes.append(self.feat_net._blocks[idx]._bn2.num_features) 347 | 348 | if fix_backbone: 349 | for param in self.feat_net.parameters(): 350 | param.requires_grad = False 351 | 352 | self.norm = ConstantNormalize() 353 | 354 | # 1*1 projection conv 355 | proj_convs = [ Conv1x1(ip, proj_planes, has_se=has_se) for ip in in_planes ] 356 | self.proj_convs = nn.ModuleList(proj_convs) 357 | 358 | # two stream feature 359 | # self.stem_conv = Conv1x1(proj_planes*len(in_planes), pred_planes, has_se=has_se) 360 | self.lower_conv = Conv1x1(proj_planes*3, pred_planes, has_se=has_se) 361 | self.higher_conv = Conv1x1(proj_planes*10, pred_planes, has_se=has_se) 362 | # self.lower_conv = ResBlock(proj_planes*3, pred_planes) 363 | # self.higher_conv = ResBlock(proj_planes*10, pred_planes) 364 | self.fr_lower = FRUnit(pred_planes, mu_init=self.mu_init, 
reweight_mode=self.reweight_mode, normalize=normalize) 365 | self.fr_higher = FRUnit(pred_planes, mu_init=self.mu_init, reweight_mode=self.reweight_mode, normalize=normalize) 366 | # self.fr_lower = FDUnit(pred_planes, pred_planes) 367 | # self.fr_higher = FDUnit(pred_planes, pred_planes) 368 | # self.fr_lower = FRUnit(pred_planes, mu_init=0.7, reweight_mode=self.reweight_mode, normalize=normalize) 369 | # self.fr_higher = FRUnit(pred_planes, mu_init=0.6, reweight_mode=self.reweight_mode, normalize=normalize) 370 | self.mix_conv = Conv1x1(pred_planes*2, pred_planes, has_se=True) 371 | self.noshad_conv = Conv1x1(pred_planes*2, pred_planes, has_se=True) 372 | self.shad_conv = Conv1x1(pred_planes*2, pred_planes, has_se=True) 373 | # self.mix_conv = LEModule(pred_planes, pred_planes, pred_planes) 374 | # self.noshad_conv = LEModule(pred_planes, pred_planes, pred_planes) 375 | # self.shad_conv = LEModule(pred_planes, pred_planes, pred_planes) 376 | # self.fc = nn.Linear(pred_planes, 1) 377 | 378 | # prediction 379 | pred_layers_shadimg = [] 380 | pred_layers_shadmask = [] 381 | pred_layers_maskimg = [] 382 | pred_layers_noshad = [] 383 | if dropout_2d > 1e-6: 384 | pred_layers_shadimg.append(nn.Dropout2d(p=dropout_2d)) 385 | pred_layers_shadmask.append(nn.Dropout2d(p=dropout_2d)) 386 | pred_layers_maskimg.append(nn.Dropout2d(p=dropout_2d)) 387 | pred_layers_noshad.append(nn.Dropout2d(p=dropout_2d)) 388 | pred_layers_shadimg.append(nn.Conv2d(pred_planes, 3, 1)) 389 | # pred_layers_shadimg.append(nn.BatchNorm2d(3)) 390 | pred_layers_shadmask.append(nn.Conv2d(pred_planes, 1, 1)) 391 | # pred_layers_shadmask.append(nn.BatchNorm2d(1)) 392 | pred_layers_maskimg.append(nn.Conv2d(pred_planes, 3, 1)) 393 | # pred_layers_maskimg.append(nn.BatchNorm2d(3)) 394 | pred_layers_noshad.append(nn.Conv2d(pred_planes, 3, 1)) 395 | # pred_layers_noshad.append(nn.BatchNorm2d(3)) 396 | self.pred_conv_shadimg = nn.Sequential(*pred_layers_shadimg) 397 | self.pred_conv_shadmask = nn.Sequential(*pred_layers_shadmask) 398 | self.pred_conv_maskimg = nn.Sequential(*pred_layers_maskimg) 399 | self.pred_conv_noshad = nn.Sequential(*pred_layers_noshad) 400 | 401 | # self.conv1x1 = nn.Conv2d(pred_planes*2, 3, 1) 402 | self.pre_conv1x1 = nn.Conv2d(4, 3, 1) 403 | 404 | for m in self.modules(): 405 | if isinstance(m, nn.ReLU): 406 | m.inplace = True 407 | 408 | # self.sobel_conv = Sobel() 409 | self.shad_filter_low= ShadFilter(pred_planes) 410 | self.shad_filter_high = ShadFilter(pred_planes) 411 | # self.shad_filter_low= ShadFilter2(528) 412 | # self.shad_filter_high = ShadFilter2(528) 413 | 414 | def forward(self, x): 415 | global ml_features 416 | 417 | b, c, h, w = x.size() 418 | ml_features = [] 419 | 420 | if c == 4: 421 | x = self.pre_conv1x1(x) 422 | 423 | _ = self.feat_net.extract_features(self.norm(x)) 424 | 425 | h_f, w_f = ml_features[0].size()[2:] 426 | # h_f, w_f = ml_features[2].size()[2:] 427 | # print(h_f, w_f) 428 | proj_features = [] 429 | 430 | for i in range(3): 431 | cur_proj_feature = self.proj_convs[i](ml_features[i]) 432 | cur_proj_feature_up = F.interpolate(cur_proj_feature, size=(h_f, w_f), mode='bilinear') 433 | proj_features.append(cur_proj_feature_up) 434 | cat_feature_1 = torch.cat(proj_features, dim=1) 435 | stem_feat_low = self.lower_conv(cat_feature_1) 436 | 437 | 438 | proj_features.clear() 439 | for i in range(3, len(ml_features)): 440 | cur_proj_feature = self.proj_convs[i](ml_features[i]) 441 | cur_proj_feature_up = F.interpolate(cur_proj_feature, size=(h_f, w_f), mode='bilinear') 442 | 
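# The deeper backbone features (levels 3 onward) are projected, upsampled to the
# resolution of the shallowest hooked feature map, then concatenated and fused by
# higher_conv, mirroring what lower_conv did for the first three levels above.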
proj_features.append(cur_proj_feature_up) 443 | cat_feature_2 = torch.cat(proj_features, dim=1) 444 | stem_feat_high = self.higher_conv(cat_feature_2) 445 | 446 | # stem_feat = self.stem_conv(cat_feature) 447 | 448 | # factorised feature 449 | # shad_feat, noshad_feat, mix_feat = self.fr(stem_feat) 450 | low_shad, low_noshad, low_feat = self.fr_lower(stem_feat_low) 451 | high_shad, high_noshad, high_feat = self.fr_higher(stem_feat_high) 452 | shad_feat = self.shad_conv(torch.cat([low_shad, high_shad], dim=1)) 453 | noshad_feat = self.noshad_conv(torch.cat([low_noshad, high_noshad], dim=1)) 454 | mix_feat = self.mix_conv(torch.cat([low_feat, high_feat], dim=1)) 455 | # shad_feat = self.shad_conv(low_shad, high_shad) 456 | # noshad_feat = self.noshad_conv(low_noshad, high_noshad) 457 | # mix_feat = self.mix_conv(low_feat, high_feat) 458 | 459 | f_low_shad = self.shad_filter_low(low_shad) 460 | f_low_noshad = self.shad_filter_low(low_noshad) 461 | f_low_feat = self.shad_filter_low(low_feat) 462 | f_high_shad = self.shad_filter_high(high_shad) 463 | f_high_noshad = self.shad_filter_high(high_noshad) 464 | f_high_feat = self.shad_filter_high(high_feat) 465 | 466 | if self.training: 467 | # logits = F.interpolate(self.pred_conv(mix_feat), size=(h, w), mode='bilinear') 468 | # g = self.fc(var_feat.mean([2, 3])) 469 | # return logits, inv_feat, g 470 | # mix_feat = F.interpolate(mix_feat, size=(h, w), mode='bilinear') 471 | # logits_shadimg = self.pred_conv_shadimg(mix_feat) 472 | logits_shadimg = F.interpolate(self.pred_conv_shadimg(mix_feat), size=(h, w), mode='bilinear') 473 | # train_mask = self.pred_conv_shadmask(shad_feat) 474 | # shad_feat = F.interpolate(shad_feat, size=(h, w), mode='bilinear') 475 | # logits_shadmask = self.pred_conv_shadmask(shad_feat) 476 | logits_shadmask = F.interpolate(self.pred_conv_shadmask(shad_feat), size=(h, w), mode='bilinear') 477 | # train_mask = (F.sigmoid(train_mask) > 0.4).type(torch.int64) 478 | # f_noshad_region = train_mask * noshad_feat 479 | # f_mask_region = train_mask * shad_feat 480 | # logits_maskimg = self.pred_conv_shadimg(shad_feat) 481 | logits_maskimg = F.interpolate(self.pred_conv_shadimg(shad_feat), size=(h, w), mode='bilinear') 482 | # logits_maskimg = F.interpolate(self.pred_conv_shadimg(shad_feat), size=(h, w), mode='bilinear') 483 | # noshad_feat = F.interpolate(noshad_feat, size=(h, w), mode='bilinear') 484 | # logits_noshad = self.pred_conv_noshad(noshad_feat) 485 | logits_noshad = F.interpolate(self.pred_conv_noshad(noshad_feat), size=(h, w), mode='bilinear') 486 | # logits_noshad = F.interpolate(self.pred_conv_shadimg(noshad_feat), size=(h, w), mode='bilinear') 487 | # selfsup_noshad = F.interpolate(self.conv1x1(torch.cat([shad_feat, mix_feat], dim=1)), size=(h, w), mode='bilinear') 488 | 489 | # logits_shadimg = F.sigmoid(logits_shadimg) 490 | # logits_shadmask = F.sigmoid(logits_shadmask) 491 | # logits_noshad = F.sigmoid(logits_noshad) 492 | 493 | # sobel_shadmask = self.sobel_conv(logits_shadmask) 494 | # sobel_noshad = self.sobel_conv(logits_noshad) 495 | # sobel_noshad = self.sobel_conv(logits_shadimg) 496 | # print(logits_shadmask) 497 | return logits_shadimg, logits_shadmask, logits_noshad, \ 498 | f_low_shad, f_high_shad, f_low_noshad, f_high_noshad, f_low_feat, f_high_feat, mix_feat,shad_feat,noshad_feat 499 | else: 500 | # if self.reweight_mode != 'learnable': 501 | # mix_feat = self.mu_init * noshad_feat + (1 - self.mu_init) * shad_feat 502 | # # logits = F.interpolate(self.pred_conv_shad(mix_feat), size=(h, w), 
mode='bilinear') 503 | # logits_shadmask = F.interpolate(self.pred_conv_shadmask(shad_feat), size=(h, w), mode='bilinear') 504 | # # logits_maskimg = F.interpolate(self.pred_conv_maskimg(shad_feat), size=(h, w), mode='bilinear') 505 | # logits_maskimg = F.interpolate(self.pred_conv_shadimg(shad_feat), size=(h, w), mode='bilinear') 506 | # # logits_noshad = F.interpolate(self.pred_conv_noshad(noshad_feat), size=(h, w), mode='bilinear') 507 | # logits_noshad = F.interpolate(self.pred_conv_shadimg(noshad_feat), size=(h, w), mode='bilinear') 508 | # # selfsup_noshad = F.interpolate(self.conv1x1(torch.cat([shad_feat, mix_feat], dim=1)), size=(h, w), mode='bilinear') 509 | # logits_shadimg = F.interpolate(self.pred_conv_shadimg(mix_feat), size=(h, w), mode='bilinear') 510 | # else: 511 | 512 | # seems no need for if, i can strictly use this else branch 513 | 514 | # logits = F.interpolate(self.pred_conv(mix_feat), size=(h, w), mode='bilinear') 515 | logits_shadmask = F.interpolate(self.pred_conv_shadmask(shad_feat), size=(h, w), mode='bilinear') 516 | # logits_maskimg = F.interpolate(self.pred_conv_maskimg(shad_feat), size=(h, w), mode='bilinear') 517 | # logits_maskimg = F.interpolate(self.pred_conv_shadimg(shad_feat), size=(h, w), mode='bilinear') 518 | # logits_noshad = F.interpolate(self.pred_conv_noshad(noshad_feat), size=(h, w), mode='bilinear') 519 | # logits_noshad = F.interpolate(self.pred_conv_shadimg(noshad_feat), size=(h, w), mode='bilinear') 520 | # selfsup_noshad = F.interpolate(self.conv1x1(torch.cat([shad_feat, mix_feat], dim=1)), size=(h, w), mode='bilinear') 521 | # logits_shadimg = F.interpolate(self.pred_conv_shadimg(mix_feat), size=(h, w), mode='bilinear') 522 | 523 | # sobel_shadmask = self.sobel_conv(logits_shadmask) 524 | # sobel_noshad = self.sobel_conv(logits_noshad) 525 | 526 | # logits_shadmask = torch.sigmoid(logits_shadmask) 527 | # logits_shadimg = torch.sigmoid(logits_shadimg) 528 | # logits_noshad = torch.sigmoid(logits_noshad) 529 | # logits_maskimg = torch.sigmoid(logits_maskimg) 530 | 531 | # sobel_noshad = self.sobel_conv(logits_shadimg) 532 | # return {'logit': logits_shadmask, 'noshad': logits_noshad, 'shad':logits_shadimg, 'maskimg': logits_maskimg} 533 | return {'logit': logits_shadmask} 534 | # return {'logit': logits_shadmask, 'noshad': logits_noshad, 'supervised': selfsup_noshad, 'shad':logits_shadimg} 535 | 536 | 537 | -------------------------------------------------------------------------------- /resnext/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'quanlong' 2 | # from resnext101_regular import ResNeXt101 3 | -------------------------------------------------------------------------------- /resnext/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/resnext/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /resnext/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/resnext/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /resnext/__pycache__/config.cpython-36.pyc: 
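# A minimal sketch of the forward-hook pattern sddnet.py uses to harvest
# multi-level backbone features (toy 3-layer backbone here; the real model hooks
# every second EfficientNet block and reads each block's _bn2.num_features):
import torch
from torch import nn

collected = []

def hook(module, inputs, output):
    collected.append(output)              # same role as feature_hook / ml_features

backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                         nn.Conv2d(8, 16, 3, stride=2, padding=1),
                         nn.Conv2d(16, 32, 3, stride=2, padding=1))
for layer in backbone:
    layer.register_forward_hook(hook)

collected.clear()
_ = backbone(torch.randn(1, 3, 64, 64))
print([f.shape for f in collected])       # one entry per hooked layer, shallow to deep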
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/resnext/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /resnext/__pycache__/config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/resnext/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- /resnext/__pycache__/resnext101_regular.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/resnext/__pycache__/resnext101_regular.cpython-36.pyc -------------------------------------------------------------------------------- /resnext/__pycache__/resnext101_regular.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/resnext/__pycache__/resnext101_regular.cpython-39.pyc -------------------------------------------------------------------------------- /resnext/__pycache__/resnext_101_32x4d_.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/resnext/__pycache__/resnext_101_32x4d_.cpython-36.pyc -------------------------------------------------------------------------------- /resnext/__pycache__/resnext_101_32x4d_.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/resnext/__pycache__/resnext_101_32x4d_.cpython-39.pyc -------------------------------------------------------------------------------- /resnext/config.py: -------------------------------------------------------------------------------- 1 | resnext_101_32_path = 'resnext/resnext_101_32x4d.pth' 2 | -------------------------------------------------------------------------------- /resnext/resnext101_regular.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from resnext import resnext_101_32x4d_ 5 | from resnext.config import resnext_101_32_path 6 | 7 | 8 | class ResNeXt101(nn.Module): 9 | def __init__(self): 10 | super(ResNeXt101, self).__init__() 11 | net = resnext_101_32x4d_.resnext_101_32x4d 12 | net.load_state_dict(torch.load(resnext_101_32_path)) 13 | 14 | net = list(net.children()) 15 | self.layer0 = nn.Sequential(*net[:3]) 16 | self.layer1 = nn.Sequential(*net[3: 5]) 17 | self.layer2 = net[5] 18 | self.layer3 = net[6] 19 | self.layer4 = net[7] 20 | 21 | def forward(self, x): 22 | layer0 = self.layer0(x) 23 | layer1 = self.layer1(layer0) 24 | layer2 = self.layer2(layer1) 25 | layer3 = self.layer3(layer2) 26 | layer4 = self.layer4(layer3) 27 | return layer4 28 | -------------------------------------------------------------------------------- /resnext/resnext101_regular_scratch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from resnext import resnext_101_32x4d_ 5 | from 
resnext.config import resnext_101_32_path 6 | 7 | 8 | class ResNeXt101(nn.Module): 9 | def __init__(self): 10 | super(ResNeXt101, self).__init__() 11 | net = resnext_101_32x4d_.resnext_101_32x4d 12 | # net.load_state_dict(torch.load(resnext_101_32_path)) 13 | 14 | net = list(net.children()) 15 | self.layer0 = nn.Sequential(*net[:3]) 16 | self.layer1 = nn.Sequential(*net[3: 5]) 17 | self.layer2 = net[5] 18 | self.layer3 = net[6] 19 | self.layer4 = net[7] 20 | 21 | def forward(self, x): 22 | layer0 = self.layer0(x) 23 | layer1 = self.layer1(layer0) 24 | layer2 = self.layer2(layer1) 25 | layer3 = self.layer3(layer2) 26 | layer4 = self.layer4(layer3) 27 | return layer4 28 | -------------------------------------------------------------------------------- /resnext/resnext101_regular_sep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from resnext import resnext_101_32x4d_ 5 | from resnext.config import resnext_101_32_path 6 | 7 | 8 | class ResNeXt101(nn.Module): 9 | def __init__(self): 10 | super(ResNeXt101, self).__init__() 11 | net = resnext_101_32x4d_.resnext_101_32x4d 12 | net.load_state_dict(torch.load(resnext_101_32_path)) 13 | 14 | net = list(net.children()) 15 | self.layer0 = nn.Sequential(*net[:3]) 16 | self.layer1 = nn.Sequential(*net[3: 5]) 17 | self.layer2 = net[5] 18 | self.layer3 = net[6] 19 | self.layer4 = net[7] 20 | 21 | def forward(self, x): 22 | layer0 = self.layer0(x) 23 | layer1 = self.layer1(layer0) 24 | layer2 = self.layer2(layer1) 25 | layer3 = self.layer3(layer2) 26 | layer4 = self.layer4(layer3) 27 | return layer4 28 | -------------------------------------------------------------------------------- /resnext/resnext_101_32x4d_.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class LambdaBase(nn.Sequential): 7 | def __init__(self, fn, *args): 8 | super(LambdaBase, self).__init__(*args) 9 | self.lambda_func = fn 10 | 11 | def forward_prepare(self, input): 12 | output = [] 13 | for module in self._modules.values(): 14 | output.append(module(input)) 15 | return output if output else input 16 | 17 | 18 | class Lambda(LambdaBase): 19 | def forward(self, input): 20 | return self.lambda_func(self.forward_prepare(input)) 21 | 22 | 23 | class LambdaMap(LambdaBase): 24 | def forward(self, input): 25 | return list(map(self.lambda_func, self.forward_prepare(input))) 26 | 27 | 28 | class LambdaReduce(LambdaBase): 29 | def forward(self, input): 30 | return reduce(self.lambda_func, self.forward_prepare(input)) 31 | 32 | 33 | resnext_101_32x4d = nn.Sequential( # Sequential, 34 | nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False), 35 | nn.BatchNorm2d(64), 36 | nn.ReLU(), 37 | nn.MaxPool2d((3, 3), (2, 2), (1, 1)), 38 | nn.Sequential( # Sequential, 39 | nn.Sequential( # Sequential, 40 | LambdaMap(lambda x: x, # ConcatTable, 41 | nn.Sequential( # Sequential, 42 | nn.Sequential( # Sequential, 43 | nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 44 | nn.BatchNorm2d(128), 45 | nn.ReLU(), 46 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 47 | nn.BatchNorm2d(128), 48 | nn.ReLU(), 49 | ), 50 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 51 | nn.BatchNorm2d(256), 52 | ), 53 | nn.Sequential( # Sequential, 54 | nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 55 | nn.BatchNorm2d(256), 56 | ), 57 | ), 58 | LambdaReduce(lambda x, y: x + 
y), # CAddTable, 59 | nn.ReLU(), 60 | ), 61 | nn.Sequential( # Sequential, 62 | LambdaMap(lambda x: x, # ConcatTable, 63 | nn.Sequential( # Sequential, 64 | nn.Sequential( # Sequential, 65 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 66 | nn.BatchNorm2d(128), 67 | nn.ReLU(), 68 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 69 | nn.BatchNorm2d(128), 70 | nn.ReLU(), 71 | ), 72 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 73 | nn.BatchNorm2d(256), 74 | ), 75 | Lambda(lambda x: x), # Identity, 76 | ), 77 | LambdaReduce(lambda x, y: x + y), # CAddTable, 78 | nn.ReLU(), 79 | ), 80 | nn.Sequential( # Sequential, 81 | LambdaMap(lambda x: x, # ConcatTable, 82 | nn.Sequential( # Sequential, 83 | nn.Sequential( # Sequential, 84 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 85 | nn.BatchNorm2d(128), 86 | nn.ReLU(), 87 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 88 | nn.BatchNorm2d(128), 89 | nn.ReLU(), 90 | ), 91 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 92 | nn.BatchNorm2d(256), 93 | ), 94 | Lambda(lambda x: x), # Identity, 95 | ), 96 | LambdaReduce(lambda x, y: x + y), # CAddTable, 97 | nn.ReLU(), 98 | ), 99 | ), 100 | nn.Sequential( # Sequential, 101 | nn.Sequential( # Sequential, 102 | LambdaMap(lambda x: x, # ConcatTable, 103 | nn.Sequential( # Sequential, 104 | nn.Sequential( # Sequential, 105 | nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 106 | nn.BatchNorm2d(256), 107 | nn.ReLU(), 108 | nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 109 | nn.BatchNorm2d(256), 110 | nn.ReLU(), 111 | ), 112 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 113 | nn.BatchNorm2d(512), 114 | ), 115 | nn.Sequential( # Sequential, 116 | nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 117 | nn.BatchNorm2d(512), 118 | ), 119 | ), 120 | LambdaReduce(lambda x, y: x + y), # CAddTable, 121 | nn.ReLU(), 122 | ), 123 | nn.Sequential( # Sequential, 124 | LambdaMap(lambda x: x, # ConcatTable, 125 | nn.Sequential( # Sequential, 126 | nn.Sequential( # Sequential, 127 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 128 | nn.BatchNorm2d(256), 129 | nn.ReLU(), 130 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 131 | nn.BatchNorm2d(256), 132 | nn.ReLU(), 133 | ), 134 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 135 | nn.BatchNorm2d(512), 136 | ), 137 | Lambda(lambda x: x), # Identity, 138 | ), 139 | LambdaReduce(lambda x, y: x + y), # CAddTable, 140 | nn.ReLU(), 141 | ), 142 | nn.Sequential( # Sequential, 143 | LambdaMap(lambda x: x, # ConcatTable, 144 | nn.Sequential( # Sequential, 145 | nn.Sequential( # Sequential, 146 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 147 | nn.BatchNorm2d(256), 148 | nn.ReLU(), 149 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 150 | nn.BatchNorm2d(256), 151 | nn.ReLU(), 152 | ), 153 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 154 | nn.BatchNorm2d(512), 155 | ), 156 | Lambda(lambda x: x), # Identity, 157 | ), 158 | LambdaReduce(lambda x, y: x + y), # CAddTable, 159 | nn.ReLU(), 160 | ), 161 | nn.Sequential( # Sequential, 162 | LambdaMap(lambda x: x, # ConcatTable, 163 | nn.Sequential( # Sequential, 164 | nn.Sequential( # Sequential, 165 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 166 | nn.BatchNorm2d(256), 167 | nn.ReLU(), 168 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, 
bias=False), 169 | nn.BatchNorm2d(256), 170 | nn.ReLU(), 171 | ), 172 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 173 | nn.BatchNorm2d(512), 174 | ), 175 | Lambda(lambda x: x), # Identity, 176 | ), 177 | LambdaReduce(lambda x, y: x + y), # CAddTable, 178 | nn.ReLU(), 179 | ), 180 | ), 181 | nn.Sequential( # Sequential, 182 | nn.Sequential( # Sequential, 183 | LambdaMap(lambda x: x, # ConcatTable, 184 | nn.Sequential( # Sequential, 185 | nn.Sequential( # Sequential, 186 | nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 187 | nn.BatchNorm2d(512), 188 | nn.ReLU(), 189 | nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 190 | nn.BatchNorm2d(512), 191 | nn.ReLU(), 192 | ), 193 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 194 | nn.BatchNorm2d(1024), 195 | ), 196 | nn.Sequential( # Sequential, 197 | nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 198 | nn.BatchNorm2d(1024), 199 | ), 200 | ), 201 | LambdaReduce(lambda x, y: x + y), # CAddTable, 202 | nn.ReLU(), 203 | ), 204 | nn.Sequential( # Sequential, 205 | LambdaMap(lambda x: x, # ConcatTable, 206 | nn.Sequential( # Sequential, 207 | nn.Sequential( # Sequential, 208 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 209 | nn.BatchNorm2d(512), 210 | nn.ReLU(), 211 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 212 | nn.BatchNorm2d(512), 213 | nn.ReLU(), 214 | ), 215 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 216 | nn.BatchNorm2d(1024), 217 | ), 218 | Lambda(lambda x: x), # Identity, 219 | ), 220 | LambdaReduce(lambda x, y: x + y), # CAddTable, 221 | nn.ReLU(), 222 | ), 223 | nn.Sequential( # Sequential, 224 | LambdaMap(lambda x: x, # ConcatTable, 225 | nn.Sequential( # Sequential, 226 | nn.Sequential( # Sequential, 227 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 228 | nn.BatchNorm2d(512), 229 | nn.ReLU(), 230 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 231 | nn.BatchNorm2d(512), 232 | nn.ReLU(), 233 | ), 234 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 235 | nn.BatchNorm2d(1024), 236 | ), 237 | Lambda(lambda x: x), # Identity, 238 | ), 239 | LambdaReduce(lambda x, y: x + y), # CAddTable, 240 | nn.ReLU(), 241 | ), 242 | nn.Sequential( # Sequential, 243 | LambdaMap(lambda x: x, # ConcatTable, 244 | nn.Sequential( # Sequential, 245 | nn.Sequential( # Sequential, 246 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 247 | nn.BatchNorm2d(512), 248 | nn.ReLU(), 249 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 250 | nn.BatchNorm2d(512), 251 | nn.ReLU(), 252 | ), 253 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 254 | nn.BatchNorm2d(1024), 255 | ), 256 | Lambda(lambda x: x), # Identity, 257 | ), 258 | LambdaReduce(lambda x, y: x + y), # CAddTable, 259 | nn.ReLU(), 260 | ), 261 | nn.Sequential( # Sequential, 262 | LambdaMap(lambda x: x, # ConcatTable, 263 | nn.Sequential( # Sequential, 264 | nn.Sequential( # Sequential, 265 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 266 | nn.BatchNorm2d(512), 267 | nn.ReLU(), 268 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 269 | nn.BatchNorm2d(512), 270 | nn.ReLU(), 271 | ), 272 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 273 | nn.BatchNorm2d(1024), 274 | ), 275 | Lambda(lambda x: x), # Identity, 276 | ), 277 | LambdaReduce(lambda x, y: x + y), # CAddTable, 278 | nn.ReLU(), 279 | ), 280 | 
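# Each repeated Sequential above is a Torch7 conversion: LambdaMap plays the role
# of ConcatTable (run every branch on the same input) and LambdaReduce the role of
# CAddTable (sum the branch outputs). A rough idiomatic-PyTorch sketch of one of
# the identity blocks repeated in this 1024-channel stage:
import torch
from torch import nn

class ResNeXtIdentityBlock(nn.Module):
    def __init__(self, channels=1024, width=512, groups=32):
        super().__init__()
        self.branch = nn.Sequential(
            nn.Conv2d(channels, width, 1, bias=False), nn.BatchNorm2d(width), nn.ReLU(),
            nn.Conv2d(width, width, 3, padding=1, groups=groups, bias=False),
            nn.BatchNorm2d(width), nn.ReLU(),
            nn.Conv2d(width, channels, 1, bias=False), nn.BatchNorm2d(channels))

    def forward(self, x):
        return torch.relu(self.branch(x) + x)    # CAddTable followed by ReLU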
nn.Sequential( # Sequential, 281 | LambdaMap(lambda x: x, # ConcatTable, 282 | nn.Sequential( # Sequential, 283 | nn.Sequential( # Sequential, 284 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 285 | nn.BatchNorm2d(512), 286 | nn.ReLU(), 287 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 288 | nn.BatchNorm2d(512), 289 | nn.ReLU(), 290 | ), 291 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 292 | nn.BatchNorm2d(1024), 293 | ), 294 | Lambda(lambda x: x), # Identity, 295 | ), 296 | LambdaReduce(lambda x, y: x + y), # CAddTable, 297 | nn.ReLU(), 298 | ), 299 | nn.Sequential( # Sequential, 300 | LambdaMap(lambda x: x, # ConcatTable, 301 | nn.Sequential( # Sequential, 302 | nn.Sequential( # Sequential, 303 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 304 | nn.BatchNorm2d(512), 305 | nn.ReLU(), 306 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 307 | nn.BatchNorm2d(512), 308 | nn.ReLU(), 309 | ), 310 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 311 | nn.BatchNorm2d(1024), 312 | ), 313 | Lambda(lambda x: x), # Identity, 314 | ), 315 | LambdaReduce(lambda x, y: x + y), # CAddTable, 316 | nn.ReLU(), 317 | ), 318 | nn.Sequential( # Sequential, 319 | LambdaMap(lambda x: x, # ConcatTable, 320 | nn.Sequential( # Sequential, 321 | nn.Sequential( # Sequential, 322 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 323 | nn.BatchNorm2d(512), 324 | nn.ReLU(), 325 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 326 | nn.BatchNorm2d(512), 327 | nn.ReLU(), 328 | ), 329 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 330 | nn.BatchNorm2d(1024), 331 | ), 332 | Lambda(lambda x: x), # Identity, 333 | ), 334 | LambdaReduce(lambda x, y: x + y), # CAddTable, 335 | nn.ReLU(), 336 | ), 337 | nn.Sequential( # Sequential, 338 | LambdaMap(lambda x: x, # ConcatTable, 339 | nn.Sequential( # Sequential, 340 | nn.Sequential( # Sequential, 341 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 342 | nn.BatchNorm2d(512), 343 | nn.ReLU(), 344 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 345 | nn.BatchNorm2d(512), 346 | nn.ReLU(), 347 | ), 348 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 349 | nn.BatchNorm2d(1024), 350 | ), 351 | Lambda(lambda x: x), # Identity, 352 | ), 353 | LambdaReduce(lambda x, y: x + y), # CAddTable, 354 | nn.ReLU(), 355 | ), 356 | nn.Sequential( # Sequential, 357 | LambdaMap(lambda x: x, # ConcatTable, 358 | nn.Sequential( # Sequential, 359 | nn.Sequential( # Sequential, 360 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 361 | nn.BatchNorm2d(512), 362 | nn.ReLU(), 363 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 364 | nn.BatchNorm2d(512), 365 | nn.ReLU(), 366 | ), 367 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 368 | nn.BatchNorm2d(1024), 369 | ), 370 | Lambda(lambda x: x), # Identity, 371 | ), 372 | LambdaReduce(lambda x, y: x + y), # CAddTable, 373 | nn.ReLU(), 374 | ), 375 | nn.Sequential( # Sequential, 376 | LambdaMap(lambda x: x, # ConcatTable, 377 | nn.Sequential( # Sequential, 378 | nn.Sequential( # Sequential, 379 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 380 | nn.BatchNorm2d(512), 381 | nn.ReLU(), 382 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 383 | nn.BatchNorm2d(512), 384 | nn.ReLU(), 385 | ), 386 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 387 
| nn.BatchNorm2d(1024), 388 | ), 389 | Lambda(lambda x: x), # Identity, 390 | ), 391 | LambdaReduce(lambda x, y: x + y), # CAddTable, 392 | nn.ReLU(), 393 | ), 394 | nn.Sequential( # Sequential, 395 | LambdaMap(lambda x: x, # ConcatTable, 396 | nn.Sequential( # Sequential, 397 | nn.Sequential( # Sequential, 398 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 399 | nn.BatchNorm2d(512), 400 | nn.ReLU(), 401 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 402 | nn.BatchNorm2d(512), 403 | nn.ReLU(), 404 | ), 405 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 406 | nn.BatchNorm2d(1024), 407 | ), 408 | Lambda(lambda x: x), # Identity, 409 | ), 410 | LambdaReduce(lambda x, y: x + y), # CAddTable, 411 | nn.ReLU(), 412 | ), 413 | nn.Sequential( # Sequential, 414 | LambdaMap(lambda x: x, # ConcatTable, 415 | nn.Sequential( # Sequential, 416 | nn.Sequential( # Sequential, 417 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 418 | nn.BatchNorm2d(512), 419 | nn.ReLU(), 420 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 421 | nn.BatchNorm2d(512), 422 | nn.ReLU(), 423 | ), 424 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 425 | nn.BatchNorm2d(1024), 426 | ), 427 | Lambda(lambda x: x), # Identity, 428 | ), 429 | LambdaReduce(lambda x, y: x + y), # CAddTable, 430 | nn.ReLU(), 431 | ), 432 | nn.Sequential( # Sequential, 433 | LambdaMap(lambda x: x, # ConcatTable, 434 | nn.Sequential( # Sequential, 435 | nn.Sequential( # Sequential, 436 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 437 | nn.BatchNorm2d(512), 438 | nn.ReLU(), 439 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 440 | nn.BatchNorm2d(512), 441 | nn.ReLU(), 442 | ), 443 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 444 | nn.BatchNorm2d(1024), 445 | ), 446 | Lambda(lambda x: x), # Identity, 447 | ), 448 | LambdaReduce(lambda x, y: x + y), # CAddTable, 449 | nn.ReLU(), 450 | ), 451 | nn.Sequential( # Sequential, 452 | LambdaMap(lambda x: x, # ConcatTable, 453 | nn.Sequential( # Sequential, 454 | nn.Sequential( # Sequential, 455 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 456 | nn.BatchNorm2d(512), 457 | nn.ReLU(), 458 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 459 | nn.BatchNorm2d(512), 460 | nn.ReLU(), 461 | ), 462 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 463 | nn.BatchNorm2d(1024), 464 | ), 465 | Lambda(lambda x: x), # Identity, 466 | ), 467 | LambdaReduce(lambda x, y: x + y), # CAddTable, 468 | nn.ReLU(), 469 | ), 470 | nn.Sequential( # Sequential, 471 | LambdaMap(lambda x: x, # ConcatTable, 472 | nn.Sequential( # Sequential, 473 | nn.Sequential( # Sequential, 474 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 475 | nn.BatchNorm2d(512), 476 | nn.ReLU(), 477 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 478 | nn.BatchNorm2d(512), 479 | nn.ReLU(), 480 | ), 481 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 482 | nn.BatchNorm2d(1024), 483 | ), 484 | Lambda(lambda x: x), # Identity, 485 | ), 486 | LambdaReduce(lambda x, y: x + y), # CAddTable, 487 | nn.ReLU(), 488 | ), 489 | nn.Sequential( # Sequential, 490 | LambdaMap(lambda x: x, # ConcatTable, 491 | nn.Sequential( # Sequential, 492 | nn.Sequential( # Sequential, 493 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 494 | nn.BatchNorm2d(512), 495 | nn.ReLU(), 496 | nn.Conv2d(512, 512, (3, 
3), (1, 1), (1, 1), 1, 32, bias=False), 497 | nn.BatchNorm2d(512), 498 | nn.ReLU(), 499 | ), 500 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 501 | nn.BatchNorm2d(1024), 502 | ), 503 | Lambda(lambda x: x), # Identity, 504 | ), 505 | LambdaReduce(lambda x, y: x + y), # CAddTable, 506 | nn.ReLU(), 507 | ), 508 | nn.Sequential( # Sequential, 509 | LambdaMap(lambda x: x, # ConcatTable, 510 | nn.Sequential( # Sequential, 511 | nn.Sequential( # Sequential, 512 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 513 | nn.BatchNorm2d(512), 514 | nn.ReLU(), 515 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 516 | nn.BatchNorm2d(512), 517 | nn.ReLU(), 518 | ), 519 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 520 | nn.BatchNorm2d(1024), 521 | ), 522 | Lambda(lambda x: x), # Identity, 523 | ), 524 | LambdaReduce(lambda x, y: x + y), # CAddTable, 525 | nn.ReLU(), 526 | ), 527 | nn.Sequential( # Sequential, 528 | LambdaMap(lambda x: x, # ConcatTable, 529 | nn.Sequential( # Sequential, 530 | nn.Sequential( # Sequential, 531 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 532 | nn.BatchNorm2d(512), 533 | nn.ReLU(), 534 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 535 | nn.BatchNorm2d(512), 536 | nn.ReLU(), 537 | ), 538 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 539 | nn.BatchNorm2d(1024), 540 | ), 541 | Lambda(lambda x: x), # Identity, 542 | ), 543 | LambdaReduce(lambda x, y: x + y), # CAddTable, 544 | nn.ReLU(), 545 | ), 546 | nn.Sequential( # Sequential, 547 | LambdaMap(lambda x: x, # ConcatTable, 548 | nn.Sequential( # Sequential, 549 | nn.Sequential( # Sequential, 550 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 551 | nn.BatchNorm2d(512), 552 | nn.ReLU(), 553 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 554 | nn.BatchNorm2d(512), 555 | nn.ReLU(), 556 | ), 557 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 558 | nn.BatchNorm2d(1024), 559 | ), 560 | Lambda(lambda x: x), # Identity, 561 | ), 562 | LambdaReduce(lambda x, y: x + y), # CAddTable, 563 | nn.ReLU(), 564 | ), 565 | nn.Sequential( # Sequential, 566 | LambdaMap(lambda x: x, # ConcatTable, 567 | nn.Sequential( # Sequential, 568 | nn.Sequential( # Sequential, 569 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 570 | nn.BatchNorm2d(512), 571 | nn.ReLU(), 572 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 573 | nn.BatchNorm2d(512), 574 | nn.ReLU(), 575 | ), 576 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 577 | nn.BatchNorm2d(1024), 578 | ), 579 | Lambda(lambda x: x), # Identity, 580 | ), 581 | LambdaReduce(lambda x, y: x + y), # CAddTable, 582 | nn.ReLU(), 583 | ), 584 | nn.Sequential( # Sequential, 585 | LambdaMap(lambda x: x, # ConcatTable, 586 | nn.Sequential( # Sequential, 587 | nn.Sequential( # Sequential, 588 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 589 | nn.BatchNorm2d(512), 590 | nn.ReLU(), 591 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 592 | nn.BatchNorm2d(512), 593 | nn.ReLU(), 594 | ), 595 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 596 | nn.BatchNorm2d(1024), 597 | ), 598 | Lambda(lambda x: x), # Identity, 599 | ), 600 | LambdaReduce(lambda x, y: x + y), # CAddTable, 601 | nn.ReLU(), 602 | ), 603 | nn.Sequential( # Sequential, 604 | LambdaMap(lambda x: x, # ConcatTable, 605 | nn.Sequential( # Sequential, 606 | 
nn.Sequential( # Sequential, 607 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 608 | nn.BatchNorm2d(512), 609 | nn.ReLU(), 610 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 611 | nn.BatchNorm2d(512), 612 | nn.ReLU(), 613 | ), 614 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 615 | nn.BatchNorm2d(1024), 616 | ), 617 | Lambda(lambda x: x), # Identity, 618 | ), 619 | LambdaReduce(lambda x, y: x + y), # CAddTable, 620 | nn.ReLU(), 621 | ), 622 | ), 623 | nn.Sequential( # Sequential, 624 | nn.Sequential( # Sequential, 625 | LambdaMap(lambda x: x, # ConcatTable, 626 | nn.Sequential( # Sequential, 627 | nn.Sequential( # Sequential, 628 | nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 629 | nn.BatchNorm2d(1024), 630 | nn.ReLU(), 631 | nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 632 | nn.BatchNorm2d(1024), 633 | nn.ReLU(), 634 | ), 635 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 636 | nn.BatchNorm2d(2048), 637 | ), 638 | nn.Sequential( # Sequential, 639 | nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 640 | nn.BatchNorm2d(2048), 641 | ), 642 | ), 643 | LambdaReduce(lambda x, y: x + y), # CAddTable, 644 | nn.ReLU(), 645 | ), 646 | nn.Sequential( # Sequential, 647 | LambdaMap(lambda x: x, # ConcatTable, 648 | nn.Sequential( # Sequential, 649 | nn.Sequential( # Sequential, 650 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 651 | nn.BatchNorm2d(1024), 652 | nn.ReLU(), 653 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 654 | nn.BatchNorm2d(1024), 655 | nn.ReLU(), 656 | ), 657 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 658 | nn.BatchNorm2d(2048), 659 | ), 660 | Lambda(lambda x: x), # Identity, 661 | ), 662 | LambdaReduce(lambda x, y: x + y), # CAddTable, 663 | nn.ReLU(), 664 | ), 665 | nn.Sequential( # Sequential, 666 | LambdaMap(lambda x: x, # ConcatTable, 667 | nn.Sequential( # Sequential, 668 | nn.Sequential( # Sequential, 669 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 670 | nn.BatchNorm2d(1024), 671 | nn.ReLU(), 672 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 673 | nn.BatchNorm2d(1024), 674 | nn.ReLU(), 675 | ), 676 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 677 | nn.BatchNorm2d(2048), 678 | ), 679 | Lambda(lambda x: x), # Identity, 680 | ), 681 | LambdaReduce(lambda x, y: x + y), # CAddTable, 682 | nn.ReLU(), 683 | ), 684 | ), 685 | nn.AvgPool2d((7, 7), (1, 1)), 686 | Lambda(lambda x: x.view(x.size(0), -1)), # View, 687 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(2048, 1000)), # Linear, 688 | ) 689 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from networks.sddnet import SDDNet 3 | from datasets.sbu_dataset_new import SBUDataset 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | import os 8 | import torch.nn.functional as F 9 | 10 | from PIL import Image 11 | import numpy as np 12 | from torchvision import transforms 13 | import cv2 14 | import time 15 | cv2.setNumThreads(0) 16 | 17 | # ckpt_path = 'ckpt/istd_epoch_010.ckpt' 18 | # ckpt_path = './ckpt/ucf.ckpt' 19 | ckpt_path = './ckpt/sbu.ckpt' 20 | # ckpt_path = './ckpt/istd.ckpt' 21 | # ckpt_path = 
'/data/gyc/new_codes/logs/ckpt_7/ep_019.ckpt' 22 | # data_root = '/data/gyc/ISTD_Dataset/test' 23 | data_root = '/data/gyc/SBU-shadow/SBU-Test_rename' 24 | # data_root = '/data/gyc/UCF' 25 | # save_dir = 'test/raw_modify3_new' 26 | save_dir = 'test/demo' 27 | torch.cuda.set_device(7) 28 | 29 | os.makedirs(save_dir, exist_ok=True) 30 | 31 | model = SDDNet(backbone='efficientnet-b3', 32 | proj_planes=16, 33 | pred_planes=32, 34 | use_pretrained=True, 35 | fix_backbone=False, 36 | has_se=False, 37 | dropout_2d=0, 38 | normalize=True, 39 | mu_init=0.4, 40 | reweight_mode='manual') 41 | # ckpt = torch.load(ckpt_path) 42 | ckpt = torch.load(ckpt_path, map_location={'cuda:0': 'cuda:7'}) 43 | for i in ckpt: 44 | print(i) 45 | model.load_state_dict(ckpt['model']) 46 | # model.fr.set_mu(0.4) 47 | model.cuda() 48 | model.eval() 49 | 50 | 51 | test_dataset = SBUDataset(data_root=data_root, 52 | augmentation=True, 53 | phase='test', 54 | normalize=False, 55 | im_size=512) 56 | # print(test_dataset[0]) 57 | test_loader = DataLoader(test_dataset, batch_size=1, num_workers=4) 58 | 59 | with torch.no_grad(): 60 | # img_list = [img_name for img_name in os.listdir(os.path.join(data_root, 'train_A')) if 61 | # img_name.endswith('.png')] 62 | # for idx, im_name in enumerate(img_list): 63 | # image = Image.open(os.path.join(data_root, 'train_A', im_name)) 64 | # w, h = image.size 65 | # img_var = Variable(img_transform(img).unsqueeze(0)).cuda() 66 | # gt = Image.open(os.path.join(data_root, 'train_B', im_name)) 67 | time_all = [] 68 | for data in tqdm(test_loader): 69 | image = data['train_A_input'].cuda() 70 | im_name = data['im_name'][0] 71 | save_path = os.path.join(save_dir, im_name) 72 | gt = data['gt'][0] 73 | start = time.time() 74 | ans = model(image) 75 | end = time.time() 76 | time_all.append(end-start) 77 | 78 | img = Image.open(os.path.join(data_root, 'train_A', im_name)) 79 | w,h = img.size 80 | image = transforms.Resize((h, w))(image) 81 | pred = transforms.Resize((h, w))(torch.sigmoid(ans['logit'].cpu())[0]) 82 | gt = transforms.Resize((h, w))(gt) 83 | imgrid = torchvision.utils.save_image([image.cpu()[0], pred.expand_as(image[0]),gt.expand_as(image[0])], fp=save_path, nrow=3, padding=0) 84 | 85 | 86 | # pred = torch.sigmoid(ans['logit'].cpu())[0] 87 | # pred = (pred > 0.5).type(torch.int64) 88 | 89 | # noshad = ans['noshad'].cpu()[0] 90 | # shad = ans['shad'].cpu()[0] 91 | # maskimg = ans['maskimg'].cpu()[0] 92 | 93 | # imgrid = torchvision.utils.save_image([image.cpu()[0], pred.expand_as(image[0]), 94 | # gt.expand_as(image[0]), noshad, shad, maskimg], fp=save_path, nrow=3, padding=0) 95 | print('average time:', np.mean(time_all) / 1) 96 | print('average fps:',1 / np.mean(time_all)) 97 | 98 | print('fastest time:', min(time_all) / 1) 99 | print('fastest fps:',1 / min(time_all)) 100 | 101 | print('slowest time:', max(time_all) / 1) 102 | print('slowest fps:',1 / max(time_all)) 103 | 104 | 105 | 106 | from utils.evaluation import evaluate 107 | 108 | # im_grid_dir = 'test/raw' 109 | im_grid_dir = save_dir 110 | pos_err, neg_err, ber, acc, df = evaluate(im_grid_dir, pred_id=1, gt_id=2, nimg=3, nrow=3) 111 | print(f'\t BER: {ber:.2f}, pErr: {pos_err:.2f}, nErr: {neg_err:.2f}, acc:{acc:.4f}') -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | from sys import prefix 3 | import torch 4 | from torch.optim import lr_scheduler, optimizer 5 | import torch.nn as nn 6 | import 
torchvision 7 | from torchvision import transforms 8 | from tqdm import tqdm 9 | import torchvision.utils as vutils 10 | from networks.sddnet import SDDNet 11 | from networks.loss import BBCEWithLogitLoss, DiceLoss, EdgeLoss, OrthoLoss, \ 12 | DiffLoss, StyleLoss, DiffLoss_2, ZeroLoss 13 | from utils.evaluation import MyConfuseMatrixMeter 14 | import numpy as np 15 | import random 16 | 17 | from torch import Tensor 18 | import torch.nn.functional as F 19 | 20 | from datasets.sbu_dataset_new import SBUDataset 21 | from datasets.transforms import Denormalize 22 | 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | import configargparse 26 | import os 27 | import logging 28 | 29 | from utils.visualization import colorize_classid_array 30 | 31 | from ptflops import get_model_complexity_info 32 | import cv2 33 | from modelsize_estimate import modelsize 34 | cv2.setNumThreads(0) 35 | 36 | logger = logging.getLogger('ShadowDet') 37 | logger.setLevel(logging.DEBUG) 38 | 39 | 40 | def seed_all(seed=10): 41 | """ 42 | https://discuss.pytorch.org/t/reproducibility-with-all-the-bells-and-whistles/81097 43 | """ 44 | logger.info(f"[ Using Seed : {seed} ]") 45 | 46 | torch.manual_seed(seed) 47 | torch.cuda.manual_seed_all(seed) 48 | torch.cuda.manual_seed(seed) 49 | np.random.seed(seed) 50 | random.seed(seed) 51 | # torch.backends.cudnn.deterministic = True 52 | # torch.backends.cudnn.benchmark = False 53 | 54 | 55 | def create_logdir_and_save_config(args): 56 | paths = {} 57 | paths['sw_dir'] = os.path.join(args.logdir, 'summary') 58 | paths['ckpt_dir'] = os.path.join(args.logdir, 'ckpt_7') 59 | paths['val_dir'] = os.path.join(args.logdir, 'val') 60 | paths['test_dir'] = os.path.join(args.logdir, 'test') 61 | paths['log_file'] = os.path.join(args.logdir, 'train.log') 62 | paths['config_file'] = os.path.join(args.logdir, 'config.txt') 63 | paths['arg_file'] = os.path.join(args.logdir, 'args.txt') 64 | 65 | #### create directories ##### 66 | for k, v in paths.items(): 67 | if k.endswith('dir'): 68 | os.makedirs(v, exist_ok=True) 69 | 70 | #### create summary writer ##### 71 | sw = SummaryWriter(log_dir=paths['sw_dir']) 72 | 73 | #### log to both console and file #### 74 | str2loglevel = {'info': logging.INFO, 'debug': logging.DEBUG, 'error': logging.ERROR} 75 | level = str2loglevel[args.loglevel] 76 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 77 | # create file handler which logs even debug messages 78 | fh = logging.FileHandler(paths['log_file'], 'w+') 79 | fh.setLevel(level) 80 | fh.setFormatter(formatter) 81 | logger.addHandler(fh) 82 | # create console handler with a higher log level 83 | ch = logging.StreamHandler() 84 | ch.setLevel(level) 85 | ch.setFormatter(formatter) 86 | logger.addHandler(ch) 87 | 88 | #### print and save configs ##### 89 | msg = 'Experiment arguments:\n ============begin==================\n' 90 | for arg in sorted(vars(args)): 91 | attr = getattr(args, arg) 92 | msg += '{} = {}\n'.format(arg, attr) 93 | msg += '=============end================' 94 | logger.info(msg) 95 | 96 | with open(paths['arg_file'], 'w') as file: 97 | for arg in sorted(vars(args)): 98 | attr = getattr(args, arg) 99 | file.write('{} = {}\n'.format(arg, attr)) 100 | if args.config is not None: 101 | with open(paths['config_file'], 'w') as file: 102 | file.write(open(args.config, 'r').read()) 103 | 104 | return sw, paths 105 | 106 | 107 | def create_model_and_optimizer(args): 108 | """ 109 | return model, optimizer, lr_schedule, start_epoch 110 | """ 
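# (Builds the SDDNet from networks/sddnet.py with an efficientnet-b3 backbone,
#  pairs it with AdamW (lr and weight decay taken from args) and a StepLR
#  schedule, and resumes model/optimizer/schedule state from args.ckpt when a
#  checkpoint path is given.)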
111 | # create model 112 | model = SDDNet(backbone='efficientnet-b3', 113 | proj_planes=16, 114 | pred_planes=32, 115 | use_pretrained=True, 116 | fix_backbone=False, 117 | has_se=False, 118 | dropout_2d=0, 119 | normalize=True, 120 | mu_init=0.5, 121 | reweight_mode='manual') 122 | # macs, params = get_model_complexity_info(model, (3, 512, 512), as_strings=True, 123 | # print_per_layer_stat=True) 124 | # print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 125 | # print('{:<30} {:<8}'.format('Number of parameters: ', params)) 126 | 127 | model.cuda() 128 | 129 | # create optimizer 130 | optimizer = torch.optim.AdamW(model.parameters(), weight_decay=args.wd, lr=args.lr) 131 | # lr schedule 132 | lr_schedule = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma) 133 | start_epoch = 0 134 | 135 | logger.info(f'model {args.model} is created!') 136 | 137 | if args.ckpt is not None: 138 | model, optimizer, lr_schedule, start_epoch = load_ckpt(model, optimizer, lr_schedule, args.ckpt) 139 | 140 | return model, optimizer, lr_schedule, start_epoch 141 | 142 | 143 | def create_loss_function(args): 144 | if args.loss == 'bce': 145 | loss_function = nn.BCEWithLogitsLoss() 146 | elif args.loss == 'dice': 147 | loss_function = DiceLoss() 148 | elif args.loss == 'bbce': 149 | loss_function = BBCEWithLogitLoss() 150 | else: 151 | raise ValueError(f'{args.loss} is not supported!') 152 | 153 | return loss_function 154 | 155 | 156 | def create_dataloaders(args): 157 | data_roots = {'SBU_train': '/data/gyc/SBU-shadow/SBUTrain4KRecoveredSmall', 158 | 'SBU_test': '/data/gyc/SBU-shadow/SBU-Test_rename', 159 | 'UCF_test': '/data/gyc/UCF', 160 | 'ISTD_train': '/data/gyc/ISTD_Dataset/train', 161 | 'ISTD_test': '/data/gyc/ISTD_Dataset/test'} 162 | 163 | train_dataset = SBUDataset(data_root=data_roots[args.train_data], 164 | phase='train', augmentation=False, im_size=args.train_size, normalize=False) 165 | 166 | ## set drop_last True to avoid error induced by BatchNormalization 167 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch, 168 | shuffle=True, num_workers=args.nworker, 169 | pin_memory=True, drop_last=True) 170 | 171 | # name, split = args.eval_data.split('_') 172 | # val_dataset = get_datasets(name=name, 173 | # root=os.path.join(args.data_root, name), 174 | # split=split, 175 | # transform=val_tf 176 | # ) 177 | eval_data = args.eval_data.split('+') 178 | eval_loaders = {} 179 | for name in eval_data: 180 | dataset = SBUDataset(data_root=data_roots[name], phase='test', augmentation=False, 181 | im_size=args.eval_size, normalize=False) 182 | eval_loaders[name] = torch.utils.data.DataLoader(dataset, batch_size=args.eval_batch, 183 | shuffle=False, num_workers=args.nworker, 184 | pin_memory=True) 185 | 186 | # msg = "Dataloaders are prepared!\n=============================\n" 187 | # msg += f"train_loader: dataset={args.train_data}, num_samples={len(train_loader.dataset)}, batch_size={train_loader.batch_size}\n" 188 | # msg += f"val_loader: dataset={args.eval_data}, num_samples={len(val_loader.dataset)}, batch_size={val_loader.batch_size}\n" 189 | # # msg += f"test_loader: dataset={args.test_data}, num_samples={len(test_loader.dataset)}, batch_size={test_loader.batch_size}\n" 190 | # # msg += "------------------------------\n" 191 | # # msg += f"load_size={args.load_size}\n" 192 | # msg += "=============================" 193 | msg = "Dataloaders are prepared!" 
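`create_loss_function` selects between plain BCE, Dice, and the balanced BCE used by default (`--loss bbce`). The actual `BBCEWithLogitLoss` lives in `networks/loss.py` and is not reproduced here; purely as an illustration of what a class-balanced BCE typically does (the repository's implementation may differ), a hedged sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BalancedBCEWithLogits(nn.Module):
    """Illustrative class-balanced BCE; NOT the repository's BBCEWithLogitLoss."""
    def forward(self, logits, target):
        pos = target.sum()
        neg = target.numel() - pos
        # Up-weight positives by the negative/positive ratio so sparse shadow
        # pixels are not drowned out by the background.
        pos_weight = neg / pos.clamp(min=1.0)
        return F.binary_cross_entropy_with_logits(logits, target, pos_weight=pos_weight)
```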
194 | logger.info(msg) 195 | 196 | return train_loader, eval_loaders 197 | 198 | 199 | def save_ckpt(model, optimizer, lr_schedule, epoch, path): 200 | ckpt = {'model': model.state_dict(), 201 | 'optimizer': optimizer.state_dict(), 202 | 'lr_schedule': lr_schedule.state_dict(), 203 | 'epoch': epoch 204 | } 205 | torch.save(ckpt, path) 206 | logger.info(f'checkpoint has been saved to {path}!') 207 | 208 | 209 | def load_ckpt(model, optimizer, lr_schedule, path): 210 | ckpt = torch.load(path) 211 | model.load_state_dict(ckpt['model']) 212 | optimizer.load_state_dict(ckpt['optimizer']) 213 | lr_schedule.load_state_dict(ckpt['lr_schedule']) 214 | start_epoch = ckpt['epoch'] + 1 215 | logger.info(f'model is loaded from {path}!') 216 | return model, optimizer, lr_schedule, start_epoch 217 | 218 | 219 | def visualize_sample(images: Tensor, gt: Tensor, pred: Tensor, bi_th=0.5): 220 | """ 221 | visualize single sample 222 | Args: 223 | images: [2, 3, h, w] tensor 224 | gt: [1, h, w] tensor, int binary mask 225 | pred: [1, h, w] tensor, float soft mask 226 | Return: 227 | grid: visual grid 228 | """ 229 | # mean=[0.485, 0.456, 0.406] 230 | # std=[0.229, 0.224, 0.225] 231 | # # mean=[0.5, 0.5, 0.5] 232 | # # std=[0.5, 0.5, 0.5] 233 | # denorm_fn = Denormalize(mean=mean, std=std) 234 | # images_vis = (denorm_fn(images)*255).type(torch.uint8).cpu() 235 | images_vis = (images*255).type(torch.uint8).cpu() 236 | gt_vis = (torch.cat([gt*255]*3, dim=0)).type(torch.uint8).cpu() 237 | pred_vis = (torch.cat([pred*255]*3, dim=0)).type(torch.uint8).cpu() 238 | pred_bi_vis = (torch.cat([(pred>bi_th).float()*255]*3, dim=0)).type(torch.uint8).cpu() 239 | 240 | # -1: false negative, 0: correct, 1: false positive 241 | diff = (pred > bi_th).type(torch.int8) - gt.type(torch.int8) 242 | logger.debug(f"unique_ids in pred_bi_vis: {torch.unique(pred_bi_vis)}") 243 | logger.debug(f"unique_ids in diff: {torch.unique(diff)}") 244 | diff_vis, _ = colorize_classid_array(diff, alpha=1., image=None, 245 | colors={-1: 'green', 0:'black', 1:'red'}) 246 | diff_vis = diff_vis.cpu() 247 | grid = vutils.make_grid([images_vis, gt_vis, pred_vis, pred_bi_vis, diff_vis], 248 | nrow=3, padding=0) 249 | return grid 250 | 251 | 252 | @torch.no_grad() 253 | def evaluate(model, eval_loader, bi_class_th=0.5, save_dir=None, sw=None, epoch=None, prefix=''): 254 | """ 255 | run inference with given eval_loader; 256 | save_dir: if not None, create the folder to save test results 257 | """ 258 | # logger.info('====start of evaluation====') 259 | if save_dir is not None: 260 | os.makedirs(save_dir, exist_ok=True) 261 | model.eval() 262 | 263 | # cmm = ConfuseMatrixMeter(n_class=2) 264 | cmm = MyConfuseMatrixMeter(n_class=2) 265 | 266 | for i_batch, data in enumerate(tqdm(eval_loader)): 267 | inp = data['train_A_input'].cuda() # (n, 2, c, h, w) 268 | gt = data['gt'].cuda() # (n, 1, h, w) 269 | # requires threshold, TODO: AUROC 270 | pred_soft = torch.sigmoid(F.interpolate(model(inp)['logit'], size=gt.size()[-2:], mode='bilinear')) 271 | # pred_soft = F.interpolate(model(inp)['logit'], size=gt.size()[-2:], mode='bilinear') 272 | pred = (pred_soft > bi_class_th).type(torch.int64) 273 | cmm.update_cm(y_pred=pred.cpu(), y_label=gt.cpu()) 274 | # save_dir = './gyc_eval' 275 | # vutils.save_image(pred_soft, os.path.join(save_dir, str(i_batch)+".png")) 276 | 277 | if (save_dir is not None): 278 | inp = F.interpolate(inp, size=gt.size()[-2:], mode='bilinear') 279 | for i_image, (x, y_gt, y_pred) in enumerate(zip(inp, gt, pred_soft)): 280 | im_grid = 
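For reference, `save_ckpt` above writes a plain dictionary, so a checkpoint can be inspected or resumed from directly; a quick sketch (the path is a placeholder):

```python
import torch

ckpt = torch.load('logs/ckpt_7/ep_019.ckpt', map_location='cpu')   # placeholder path
print(list(ckpt.keys()))   # ['model', 'optimizer', 'lr_schedule', 'epoch']
print(ckpt['epoch'])       # last completed epoch; load_ckpt resumes at epoch + 1
```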
visualize_sample(images=x, gt=y_gt, pred=y_pred, bi_th=bi_class_th) 281 | save_name = f'{i_batch*eval_loader.batch_size + i_image:05d}.png' 282 | save_path = os.path.join(save_dir, save_name) 283 | vutils.save_image(im_grid/255., save_path) # save_image() takes float [0, 1.] input 284 | 285 | # score_dict = cmm.get_scores() 286 | score_dict = cmm.get_scores_binary() 287 | msg = 'Scores:\n===============================================' 288 | for k, v in score_dict.items(): 289 | msg += f'\n\t{prefix}.{k}:{v}' 290 | if sw is not None: 291 | sw.add_scalar(f'eval/{prefix}.{k}', v, global_step=epoch) 292 | msg += '\n===============================================' 293 | logger.info(msg) 294 | # logger.info('====end of evaluation====') 295 | return score_dict 296 | 297 | 298 | def train(model, train_loader, loss_fn, optimizer, lr_schedule, epoch, sw, args): 299 | """ 300 | one epoch scan 301 | """ 302 | # logger.info(f'====start of training epoch {epoch+1}/{args.total_ep}====') 303 | global_step = epoch * len(train_loader) 304 | model.train() 305 | # vgg_low = nn.Sequential(*list(torchvision.models.vgg16(pretrained=True).features.children())[:7]) 306 | # vgg_low.cuda() 307 | # loss added 308 | loss_edge = EdgeLoss() 309 | loss_l1 = nn.L1Loss() 310 | loss_l2 = nn.MSELoss() 311 | loss_ortho = OrthoLoss() 312 | loss_diff = DiffLoss_2() 313 | loss_style = StyleLoss() 314 | loss_bce = nn.BCEWithLogitsLoss() 315 | loss_zero = ZeroLoss() 316 | loss_cos = nn.CosineEmbeddingLoss() 317 | 318 | loss_sum = 0 319 | for i_batch, sample in enumerate(train_loader): # mini-batch update 320 | global_step += 1 321 | image, label = sample['train_A_input'].cuda(), sample['gt'].cuda() 322 | # modelsize(model=model, input=image) 323 | logits_shadimg, logits_shadmask, logits_noshad, \ 324 | f_low_shad, f_high_shad, f_low_noshad, f_high_noshad, f_low_feat, f_high_feat,m,s,n = model(image) 325 | 326 | loss1 = loss_l1(logits_shadimg, image) 327 | loss2 = loss_fn(logits_shadmask, label) 328 | # loss3 = loss_shadimg(logits_noshad, noshad_label) 329 | # loss3 = loss_edge(sobel_shadmask, sobel_shadmask) 330 | # loss3 = loss_diff(image, logits_noshad, label) 331 | loss3 = loss_diff(image, logits_noshad, label, logits_shadmask) 332 | # loss4 = loss_style(fea_label, fea_noshad) 333 | loss4 = loss_cos(f_low_shad, f_high_shad, torch.ones([f_low_shad.size()[0]]).cuda()) 334 | loss5 = loss_cos(f_low_shad, f_low_feat, torch.ones([f_low_shad.size()[0]]).cuda()) 335 | loss6 = loss_cos(f_high_shad, f_high_feat, torch.ones([f_low_shad.size()[0]]).cuda()) 336 | # loss7 = loss_cos(f_low_noshad, f_low_shad, torch.ones([f_low_shad.size()[0]]).cuda()) 337 | # loss8 = loss_cos(f_high_noshad, f_high_shad, torch.ones([f_low_shad.size()[0]]).cuda()) 338 | loss7 = loss_ortho(f_low_noshad, f_low_shad) 339 | loss8 = loss_ortho(f_high_noshad, f_high_shad) 340 | loss_mask = loss2 341 | loss_shadimg = 0.2*loss1 342 | loss_noshad = 0.2*loss3 343 | loss_filter = 0.2*(0.5*loss5 + 0.5*loss6 +0.01*loss7+0.01*loss8)# 344 | loss_total = loss_mask+ loss_shadimg + loss_noshad + loss_filter 345 | # sw.add_image('img', torchvision.utils.make_grid(image[0].detach().cpu(), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 346 | # sw.add_image('shadimg', torchvision.utils.make_grid(s[0].detach().cpu().unsqueeze(dim=1), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 347 | # sw.add_image('shadmask', torchvision.utils.make_grid(m[0].detach().cpu().unsqueeze(dim=1), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 348 | # 
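The alignment terms `loss4`–`loss6` above call `nn.CosineEmbeddingLoss` with a target of all ones, i.e. they pull the paired feature vectors together rather than apart. A toy example of that behaviour (random tensors, for illustration only):

```python
import torch
import torch.nn as nn

loss_cos = nn.CosineEmbeddingLoss()
a = torch.randn(4, 128)
b = torch.randn(4, 128)
target = torch.ones(4)          # +1 = "these pairs should be similar"
print(loss_cos(a, b, target))   # mean of 1 - cos(a_i, b_i) over the batch
```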
sw.add_image('noshad', torchvision.utils.make_grid(n[0].detach().cpu().unsqueeze(dim=1), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 349 | # sw.add_image('f_noshad_region', torchvision.utils.make_grid(f_noshad_region[0].detach().cpu().unsqueeze(dim=1), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 350 | # sw.add_image('f_mask_region', torchvision.utils.make_grid(f_mask_region[0].detach().cpu().unsqueeze(dim=1), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 351 | # sw.add_image('f_low_shad', torchvision.utils.make_grid(f_low_shad[0].detach().cpu().unsqueeze(dim=1), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 352 | # sw.add_image('f_low_noshad', torchvision.utils.make_grid(f_low_noshad[0].detach().cpu().unsqueeze(dim=1), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 353 | # sw.add_image('f_low_feat', torchvision.utils.make_grid(f_low_feat[0].detach().cpu().unsqueeze(dim=1), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 354 | # sw.add_image('f_high_shad', torchvision.utils.make_grid(f_high_shad[0].detach().cpu().unsqueeze(dim=1), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 355 | # sw.add_image('f_high_noshad', torchvision.utils.make_grid(f_high_noshad[0].detach().cpu().unsqueeze(dim=1), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 356 | # sw.add_image('f_high_feat', torchvision.utils.make_grid(f_high_feat[0].detach().cpu().unsqueeze(dim=1), nrow=5, padding=20, normalize=False, pad_value=1), global_step) 357 | 358 | # 1-step SGD 359 | if args.acc_step == 1: 360 | optimizer.zero_grad() 361 | loss_total.backward() 362 | optimizer.step() 363 | else: 364 | loss_total.backward() 365 | if global_step % args.acc_step == 0: 366 | optimizer.step() 367 | optimizer.zero_grad() 368 | loss_sum += loss_mask.item() 369 | 370 | 371 | # logging 372 | # for name, parms in model.named_parameters(): 373 | # if parms.grad is None: 374 | # print(name) 375 | # for name, parms in model.named_parameters(): 376 | # if parms.grad is not None: 377 | # print('-->name:', name, '-->grad_requirs:', parms.requires_grad, '--weight', torch.mean(parms.data), ' -->grad_value:', torch.mean(parms.grad)) 378 | if global_step % args.i_print == 0: 379 | # info_dict = {'loss_det': loss_det.item(), 'loss_inv': loss_inv.item(), 'loss_var': loss_var.item()} 380 | # info_dict = {'loss_img': loss1.item(), 'loss_mask': loss2.item(), 'loss_noshad': loss3.item()} 381 | info_dict = {'loss_mask': loss_mask.item(), 'loss_noshad': loss_noshad.item(), 'loss_shadimg': loss_shadimg.item(), \ 382 | 'loss_filter': loss_filter.item(), 'loss_total': loss_total} 383 | # info_dict = {'loss_mask': loss2.item()} 384 | msg = f'[batch {i_batch+1}/{len(train_loader)}, epoch {epoch+1}/{args.total_ep}]: ' 385 | for k, v in info_dict.items(): 386 | msg += f'{k}:{v} ' 387 | sw.add_scalar(f'train/{k}', v, global_step=global_step) 388 | logger.info(msg) 389 | # for name, parms in model.named_parameters(): 390 | # if parms.grad is not None: 391 | # sw.add_scalar(f'data/{name}', torch.mean(parms.data), global_step=global_step) 392 | # sw.add_scalar(f'grad/{name}', torch.norm(parms.grad), global_step=global_step) 393 | # TODO: visualization during training 394 | # if global_step % args.i_vis == 0: 395 | # for k, v in vis_dict.items(): 396 | # sw.add_image(f'train/{k}', v, global_step=global_step) 397 | lr_schedule.step() 398 | 399 | # cumulative learning 400 | # model.fr.set_mu(1 - (epoch/args.total_ep)**2) 401 | # 
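The `args.acc_step` branch above implements plain gradient accumulation: gradients from several mini-batches are summed before a single optimizer step. Isolated from the surrounding bookkeeping, and with a dummy model standing in for SDDNet, the pattern looks like this (note that `train()` itself does not rescale the loss, so the accumulated gradient grows with `--acc_step`):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
accum_steps = 4                                    # plays the role of args.acc_step
optimizer.zero_grad()
for step in range(1, 13):                          # 12 toy mini-batches
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = nn.functional.mse_loss(model(x), y) / accum_steps   # scale each partial loss
    loss.backward()                                # gradients accumulate in .grad
    if step % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```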
logger.info(f'====end of training epoch {epoch+1}/{args.total_ep}====') 402 | return loss_sum 403 | 404 | 405 | 406 | def main(args): 407 | torch.cuda.set_device(7) 408 | seed_all(args.seed) 409 | sw, paths = create_logdir_and_save_config(args) 410 | model, optimizer, lr_schedule, start_epoch = create_model_and_optimizer(args) 411 | loss_fn = create_loss_function(args) 412 | train_loader, val_loader_dict = create_dataloaders(args) 413 | 414 | if args.action == 'test': 415 | for name, val_loader in val_loader_dict.items(): 416 | _ = evaluate(model, val_loader, bi_class_th=args.prob_th, 417 | save_dir=os.path.join(paths['test_dir'], name), 418 | prefix=name) 419 | 420 | elif args.action == 'train': 421 | # best_score = 0 # use it to track best model over training 422 | # best_epoch = -1 423 | for name, val_loader in val_loader_dict.items(): 424 | score_dict = evaluate(model, val_loader, bi_class_th=args.prob_th, save_dir=None, 425 | sw=sw, epoch=-1, prefix=name) 426 | for epoch in range(start_epoch, args.total_ep): 427 | loss = train(model, train_loader=train_loader, loss_fn=loss_fn, 428 | optimizer=optimizer, lr_schedule=lr_schedule, epoch=epoch, 429 | sw=sw, args=args) 430 | # evaluate(model, val_loader, bi_class_th=args.prob_th, save_dir=paths['val_dir'], sw=sw, epoch=epoch) 431 | for name, val_loader in val_loader_dict.items(): 432 | score_dict = evaluate(model, val_loader, bi_class_th=args.prob_th, save_dir=None, 433 | sw=sw, epoch=epoch, prefix=name) 434 | # if score_dict['iou'] > best_score: 435 | # best_score = score_dict['iou'] 436 | # best_epoch = epoch 437 | # save checkpoint 438 | if args.save_ckpt > 0: 439 | ckpt_path = os.path.join(paths['ckpt_dir'], f'ep_{epoch:03d}.ckpt') 440 | save_ckpt(model, optimizer, lr_schedule, epoch, path=ckpt_path) 441 | sw.add_scalar('train/loss', loss, global_step=epoch) 442 | # if epoch == best_epoch: 443 | # ckpt_path = os.path.join(paths['ckpt_dir'], 'best.ckpt') 444 | # save_ckpt(model, optimizer, lr_schedule, epoch, path=ckpt_path) 445 | # after training, record best result and run inference if the save_ckpt 446 | # sw.add_hparams(hparam_dict={'best_epoch': best_epoch}, 447 | # metric_dict={'best_iou': best_score}) 448 | # if args.save_ckpt > 0: 449 | # # run inference 450 | # model, _, _, _ = load_ckpt(model, optimizer, lr_schedule, 451 | # path=os.path.join(paths['ckpt_dir'], 'best.ckpt')) 452 | # evaluate and save visual results 453 | # for name, val_loader in val_loader_dict.items(): 454 | # evaluate(model, val_loader, bi_class_th=args.prob_th, 455 | # save_dir=os.path.join(paths['val_dir'], name), prefix=name) 456 | 457 | else: 458 | raise ValueError(f'invalid action {args.action}') 459 | 460 | # close logger 461 | sw.close() 462 | hdls = logger.handlers[:] 463 | for handler in hdls: 464 | handler.close() 465 | logger.removeHandler(handler) 466 | 467 | 468 | if __name__ == '__main__': 469 | parser = configargparse.ArgumentParser() 470 | parser.add_argument('--config', is_config_file=True, 471 | help='config file path') 472 | 473 | parser.add_argument('--action', type=str, default='train', choices=['train', 'test'], 474 | # parser.add_argument('--action', type=str, default='test', choices=['train', 'test'], 475 | help='action, train or test') 476 | 477 | ## model 478 | parser.add_argument('--model', type=str, default='BANet.efficientnet-b3', 479 | help='architecture to be used') 480 | parser.add_argument('--ckpt', type=str, default=None, help='ckpt to load') 481 | # parser.add_argument('--ckpt', type=str, default='./ckpt/gyc_m3.ckpt', 
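`main()` keeps the best-model bookkeeping commented out. If it were re-enabled, a natural criterion with `get_scores_binary` (see utils/evaluation.py below) is the balanced error rate, where lower is better; a hedged sketch of the selection logic (names are illustrative, not part of the shipped code):

```python
import os

def update_best(score_dict, epoch, ckpt_dir, best_ber, best_epoch):
    """Track the lowest-BER epoch; returns the updated best values and, when this
    epoch is the new best, the path a 'best.ckpt' could be written to (else None)."""
    if score_dict['ber'] < best_ber:
        return score_dict['ber'], epoch, os.path.join(ckpt_dir, 'best.ckpt')
    return best_ber, best_epoch, None
```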
help='ckpt to load') 482 | # /data/gyc/new_codes/logs/ckpt_7/ep_008.ckpt 483 | 484 | ## optimization 485 | parser.add_argument('--seed', type=int, default=4, help='random seed.') 486 | parser.add_argument('--total_ep', type=int, default=20, help='number of epochs for training.') 487 | parser.add_argument('--lr', type=float, default=5e-4, help='initial learning rate.') 488 | parser.add_argument('--lr_step', type=int, default=1, help='learning rate decay frequency (in epochs).') 489 | parser.add_argument('--lr_gamma', type=float, default=0.7, help='learning rate decay factor.') 490 | parser.add_argument('--wd', type=float, default=1e-4, help='weight decay.') 491 | parser.add_argument('--loss', type=str, default='bbce', help='loss function') 492 | parser.add_argument('--save_ckpt', type=int, default=1, help='>0 means save ckpt during training.') 493 | 494 | ## data 495 | parser.add_argument('--train_data', type=str, default='SBU_train', help='training dataset') 496 | parser.add_argument('--eval_data', type=str, default='SBU_test+UCF_test', help='training dataset') 497 | parser.add_argument('--train_batch', type=int, default=4, help='batch_size for train and val dataloader.') 498 | parser.add_argument('--eval_batch', type=int, default=1, help='batch_size for train and val dataloader.') 499 | parser.add_argument('--train_size', type=int, default=512, help='scale images to this size for training') 500 | parser.add_argument('--eval_size', type=int, default=512, help='scale images to this size for evaluation') 501 | parser.add_argument('--nworker', type=int, default=4, help='num_workers for train and val dataloader.') 502 | 503 | ## evaluation 504 | parser.add_argument('--prob_th', type=float, default=0.5, help='threshold for binary classification.') 505 | 506 | ## logging 507 | parser.add_argument('--logdir', type=str, default='logs', help='directory to save logs, args, etc.') 508 | parser.add_argument('--loglevel', type=str, default='info', help='logging level.') 509 | parser.add_argument('--i_print', type=int, default=10, help='training loss display frequency in mini-batchs.') 510 | parser.add_argument('--acc_step', type=int, default=1, help='mini-batch step.') 511 | # parser.add_argument('--i_vis', type=int, default=100, help='training loss display frequency in steps.') 512 | 513 | # ## test 514 | # parser.add_argument('--test_out', type=str, default='local_test_out', help='directory to save test visuals.') 515 | 516 | args = parser.parse() 517 | main(args) 518 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/utils/__pycache__/__init__.cpython-39.pyc 
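With the training defaults above (`--lr 5e-4`, `--lr_step 1`, `--lr_gamma 0.7`, `--total_ep 20`), the `StepLR` schedule created in `create_model_and_optimizer` decays the learning rate by a factor of 0.7 every epoch, ending near 5.7e-7. A quick stand-alone check:

```python
import torch

opt = torch.optim.AdamW([torch.zeros(1, requires_grad=True)], lr=5e-4)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.7)
for epoch in range(20):
    if epoch in (0, 5, 10, 19):
        print(epoch, f'{sched.get_last_lr()[0]:.2e}')
    opt.step()          # step the optimizer before the scheduler to avoid a warning
    sched.step()
# 0 5.00e-04, 5 8.40e-05, 10 1.41e-05, 19 5.70e-07
```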
-------------------------------------------------------------------------------- /utils/__pycache__/evaluation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/utils/__pycache__/evaluation.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/evaluation.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/utils/__pycache__/evaluation.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/utils/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/utils/__pycache__/misc.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/transforms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/utils/__pycache__/transforms.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visualization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/utils/__pycache__/visualization.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visualization.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmcong/SDDNet_ACMMM23/dddcdd4673dc24a411addcc9d4cbc63a6b30da73/utils/__pycache__/visualization.cpython-39.pyc -------------------------------------------------------------------------------- /utils/cfg_parser.py: -------------------------------------------------------------------------------- 1 | from configparser import * 2 | from collections import OrderedDict 3 | import time 4 | import ast 5 | 6 | # see configparser souce code here 7 | # https://github.com/python/cpython/blob/master/Lib/configparser.py 8 | 9 | class WithTimeInterpolation(ExtendedInterpolation): 10 | def _interpolate_some(self, parser, option, accum, rest, section, map, 11 | depth): 12 | rawval = parser.get(section, option, raw=True, fallback=rest) 13 | if depth > MAX_INTERPOLATION_DEPTH: 14 | raise InterpolationDepthError(option, section, rawval) 15 | while rest: 16 | p = rest.find("$") 17 | if p < 0: 18 | accum.append(rest) 19 | return 20 | if p > 0: 21 | accum.append(rest[:p]) 22 | rest = rest[p:] 23 | # p is no longer used 24 | c = rest[1:2] 25 | if c == "$": 26 | accum.append("$") 27 | rest = rest[2:] 28 | elif c == "{": 29 | m = self._KEYCRE.match(rest) 30 | if m is None: 31 | raise InterpolationSyntaxError(option, section, 32 | "bad 
interpolation variable reference %r" % rest) 33 | path = m.group(1).split(':') 34 | rest = rest[m.end():] 35 | sect = section 36 | opt = option 37 | try: 38 | if len(path) == 1: 39 | opt = parser.optionxform(path[0]) 40 | v = map[opt] 41 | elif len(path) == 2: 42 | sect = path[0] 43 | opt = parser.optionxform(path[1]) 44 | ###################the part I modified################ 45 | if sect == '_TIME': 46 | if parser.has_section(sect): 47 | raise ValueError("'{:s}' is kept for time interpolation, not allowed \ 48 | to be used as section title".format(sect) ) 49 | else: 50 | v = time.strftime(path[1], time.localtime()) 51 | else: 52 | opt = parser.optionxform(path[1]) 53 | v = parser.get(sect, opt, raw=True) 54 | ####################################################### 55 | else: 56 | raise InterpolationSyntaxError( 57 | option, section, 58 | "More than one ':' found: %r" % (rest,)) 59 | except (KeyError, NoSectionError, NoOptionError): 60 | raise InterpolationMissingOptionError( 61 | option, section, rawval, ":".join(path)) from None 62 | if "$" in v: 63 | self._interpolate_some(parser, opt, accum, v, sect, 64 | dict(parser.items(sect, raw=True)), 65 | depth + 1) 66 | else: 67 | accum.append(v) 68 | else: 69 | raise InterpolationSyntaxError( 70 | option, section, 71 | "'$' must be followed by '$' or '{', " 72 | "found: %r" % (rest,)) 73 | 74 | 75 | 76 | class CfgParser(object): 77 | def __init__(self, path): 78 | # https://docs.python.org/3/library/configparser.html 79 | self.config = ConfigParser(interpolation=WithTimeInterpolation(), 80 | comment_prefixes=('#',';'), 81 | inline_comment_prefixes=(';', '#')) 82 | self.config.read(path) 83 | 84 | def parse(self): 85 | config_dict = OrderedDict() 86 | for section in self.config.sections(): 87 | section_dict = OrderedDict() 88 | for option in self.config[section]: 89 | # https://stackoverflow.com/a/3513475 90 | # print(f'no conversion: {self.config[section][option]}, type {type(self.config[section][option])}\n') 91 | section_dict[option] = ast.literal_eval(self.config[section][option]) 92 | # print(f'converted: {section_dict[option]}, type {type(section_dict[option])}\n') 93 | config_dict[section] = section_dict 94 | return config_dict 95 | 96 | def set(self, section, option, value): 97 | self.config.set(section, option, value) 98 | 99 | 100 | def save(self, path): 101 | with open(path, 'w') as configfile: 102 | self.config.write(configfile) 103 | -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from collections import OrderedDict 4 | import pandas as pd 5 | import os 6 | from tqdm import tqdm 7 | import cv2 8 | from utils.misc import split_np_imgrid, get_np_imgrid 9 | import pydensecrf.densecrf as dcrf 10 | 11 | 12 | def cal_ber(tn, tp, fn, fp): 13 | return 0.5*(fp/(tn+fp) + fn/(fn+tp)) 14 | 15 | def cal_acc(tn, tp, fn, fp): 16 | return (tp + tn) / (tp + tn + fp + fn) 17 | 18 | 19 | def get_binary_classification_metrics(pred, gt, threshold=None): 20 | if threshold is not None: 21 | gt = (gt > threshold) 22 | pred = (pred > threshold) 23 | TP = np.logical_and(gt, pred).sum() 24 | TN = np.logical_and(np.logical_not(gt), np.logical_not(pred)).sum() 25 | FN = np.logical_and(gt, np.logical_not(pred)).sum() 26 | FP = np.logical_and(np.logical_not(gt), pred).sum() 27 | BER = cal_ber(TN, TP, FN, FP) 28 | ACC = cal_acc(TN, TP, FN, FP) 29 | return OrderedDict( [('TP', TP), 30 | 
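`WithTimeInterpolation` adds a pseudo-section `_TIME` whose "option" part is treated as a `strftime` pattern, and `CfgParser.parse()` passes every value through `ast.literal_eval`, so string values need quotes in the config file. A hedged usage sketch (file name and contents are made up for illustration):

```python
# demo.ini (illustrative):
#   [log]
#   dir = 'runs/${_TIME:%Y%m%d-%H%M%S}'
from utils.cfg_parser import CfgParser

cfg = CfgParser('demo.ini')
print(cfg.parse()['log']['dir'])   # e.g. runs/20230507-153000
```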
('TN', TN), 31 | ('FP', FP), 32 | ('FN', FN), 33 | ('BER', BER), 34 | ('ACC', ACC)] 35 | ) 36 | 37 | 38 | def evaluate(res_root, pred_id, gt_id, nimg, nrow): 39 | img_names = os.listdir(res_root) 40 | score_dict = OrderedDict() 41 | 42 | for img_name in tqdm(img_names, disable=False): 43 | im_grid_path = os.path.join(res_root, img_name) 44 | im_grid = cv2.imread(im_grid_path) 45 | ims = split_np_imgrid(im_grid, nimg, nrow) 46 | pred = ims[pred_id] 47 | gt = ims[gt_id] 48 | score_dict[img_name] = get_binary_classification_metrics(pred, 49 | gt, 50 | 125) 51 | 52 | df = pd.DataFrame(score_dict) 53 | df['ave'] = df.mean(axis=1) 54 | 55 | tn = df['ave']['TN'] 56 | tp = df['ave']['TP'] 57 | fn = df['ave']['FN'] 58 | fp = df['ave']['FP'] 59 | 60 | pos_err = (1 - tp / (tp + fn)) * 100 61 | neg_err = (1 - tn / (tn + fp)) * 100 62 | ber = (pos_err + neg_err) / 2 63 | acc = (tn + tp) / (tn + tp + fn + fp) 64 | 65 | return pos_err, neg_err, ber, acc, df 66 | 67 | 68 | 69 | def _sigmoid(x): 70 | return 1 / (1 + np.exp(-x)) 71 | 72 | 73 | def crf_refine(img, annos): 74 | assert img.dtype == np.uint8 75 | assert annos.dtype == np.uint8 76 | assert img.shape[:2] == annos.shape 77 | 78 | # img and annos should be np array with data type uint8 79 | 80 | EPSILON = 1e-8 81 | 82 | M = 2 # salient or not 83 | tau = 1.05 84 | # Setup the CRF model 85 | d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M) 86 | 87 | anno_norm = annos / 255. 88 | 89 | n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm)) 90 | p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm)) 91 | 92 | U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32') 93 | U[0, :] = n_energy.flatten() 94 | U[1, :] = p_energy.flatten() 95 | 96 | d.setUnaryEnergy(U) 97 | 98 | d.addPairwiseGaussian(sxy=3, compat=3) 99 | d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5) 100 | 101 | # Do the inference 102 | infer = np.array(d.inference(1)).astype('float32') 103 | res = infer[1, :] 104 | 105 | res = res * 255 106 | res = res.reshape(img.shape[:2]) 107 | return res.astype('uint8') 108 | 109 | 110 | 111 | 112 | ############################################### 113 | 114 | class AverageMeter(object): 115 | """Computes and stores the average and current value""" 116 | def __init__(self): 117 | self.sum = 0 118 | self.count = 0 119 | 120 | def update(self, val, weight=1): 121 | self.sum += val * weight 122 | self.count += weight 123 | 124 | def average(self): 125 | if self.count == 0: 126 | return 0 127 | else: 128 | return self.sum / self.count 129 | 130 | def clear(self): 131 | self.sum = 0 132 | self.count = 0 133 | 134 | def compute_cm_torch(y_pred, y_label, n_class): 135 | mask = (y_label >= 0) & (y_label < n_class) 136 | hist = torch.bincount(n_class * y_label[mask] + y_pred[mask], 137 | minlength=n_class**2).reshape(n_class, n_class) 138 | return hist 139 | 140 | class MyConfuseMatrixMeter(AverageMeter): 141 | """More Clear Confusion Matrix Meter""" 142 | def __init__(self, n_class): 143 | super(MyConfuseMatrixMeter, self).__init__() 144 | self.n_class = n_class 145 | 146 | def update_cm(self, y_pred, y_label, weight=1): 147 | y_label = y_label.type(torch.int64) 148 | val = compute_cm_torch(y_pred=y_pred.flatten(), y_label=y_label.flatten(), 149 | n_class=self.n_class) 150 | self.update(val, weight) 151 | 152 | # def get_scores_binary(self): 153 | # assert self.n_class == 2, "this function can only be called for binary calssification problem" 154 | # tn, fp, fn, tp = self.sum.flatten() 155 | # eps = 
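`crf_refine` expects a `uint8` image and a `uint8` soft mask of the same spatial size and returns a refined `uint8` mask (the repository also ships a standalone `crf_refine.py` at the root). A hedged usage sketch with placeholder file names:

```python
import cv2
from utils.evaluation import crf_refine

img = cv2.imread('sample.jpg')                               # HWC uint8
mask = cv2.imread('sample_pred.png', cv2.IMREAD_GRAYSCALE)   # HW uint8, 0-255 soft mask
mask = cv2.resize(mask, (img.shape[1], img.shape[0]))        # cv2.resize takes (w, h)
refined = crf_refine(img, mask)                              # HW uint8, 0-255
cv2.imwrite('sample_refined.png', refined)
```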
torch.finfo(torch.float32).eps 156 | # precision = tp / (tp + fp + eps) 157 | # recall = tp / (tp + fn + eps) 158 | # f1 = 2*recall*precision / (recall + precision + eps) 159 | # iou = tp / (tp + fn + fp + eps) 160 | # oa = (tp + tn) / (tp + tn + fn + fp + eps) 161 | # score_dict = {} 162 | # score_dict['precision'] = precision.item() 163 | # score_dict['recall'] = recall.item() 164 | # score_dict['f1'] = f1.item() 165 | # score_dict['iou'] = iou.item() 166 | # score_dict['oa'] = oa.item() 167 | # return score_dict 168 | def get_scores_binary(self): 169 | assert self.n_class == 2, "this function can only be called for binary calssification problem" 170 | tn, fp, fn, tp = self.sum.flatten() 171 | eps = torch.finfo(torch.float32).eps 172 | pos_err = (1 - tp / (tp + fn + eps)) * 100 173 | neg_err = (1 - tn / (tn + fp + eps)) * 100 174 | ber = (pos_err + neg_err) / 2 175 | acc = (tn + tp) / (tn + tp + fn + fp + eps) 176 | score_dict = {} 177 | score_dict['pos_err'] = pos_err 178 | score_dict['neg_err'] = neg_err 179 | score_dict['ber'] = ber 180 | score_dict['acc'] = acc 181 | return score_dict 182 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchvision.utils as vutils 3 | 4 | 5 | def get_np_imgrid(array, nrow=3, padding=0, pad_value=0): 6 | ''' 7 | achieves the same function of torchvision.utils.make_grid for 8 | numpy array 9 | ''' 10 | # assume every image has smae size 11 | n, h, w, c = array.shape 12 | row_num = n // nrow + (n % nrow != 0) 13 | gh, gw = row_num*h + padding*(row_num-1), nrow*w + padding*(nrow - 1) 14 | grid = np.ones((gh, gw, c), dtype=array.dtype) * pad_value 15 | for i in range(n): 16 | grow, gcol = i // nrow, i % nrow 17 | off_y, off_x = grow * (h + padding), gcol * (w + padding) 18 | grid[off_y : off_y + h, off_x : off_x + w] = array[i] 19 | return grid 20 | 21 | 22 | def split_np_imgrid(imgrid, nimg, nrow, padding=0): 23 | ''' 24 | reverse operation of make_grid. 25 | args: 26 | imgrid: HWC image grid 27 | nimg: number of images in the grid 28 | nrow: number of columns in image grid 29 | return: 30 | images: list, contains splitted images 31 | ''' 32 | row_num = nimg // nrow + (nimg % nrow != 0) 33 | gh, gw, _ = imgrid.shape 34 | h, w = (gh - (row_num-1)*padding)//row_num, (gw - (nrow-1)*padding)//nrow 35 | images = [] 36 | for gid in range(nimg): 37 | grow, gcol = gid // nrow, gid % nrow 38 | off_i, off_j = grow * (h + padding), gcol * (w + padding) 39 | images.append(imgrid[off_i:off_i+h, off_j:off_j+w]) 40 | return images 41 | 42 | 43 | class MDTableConvertor: 44 | 45 | def __init__(self, col_num): 46 | self.col_num = col_num 47 | 48 | def _get_table_row(self, items): 49 | row = '' 50 | for item in items: 51 | row += '| {:s} '.format(item) 52 | row += '|\n' 53 | return row 54 | 55 | def convert(self, item_list, title=None): 56 | ''' 57 | args: 58 | item_list: a list of items (str or can be converted to str) 59 | that want to be presented in table. 60 | 61 | title: None, or a list of strings. When set to None, empty title 62 | row is used and column number is determined by col_num; Otherwise, 63 | it will be used as title row, its length will override col_num. 64 | 65 | return: 66 | table: markdown table string. 
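`get_scores_binary` above unpacks the accumulated confusion matrix as `tn, fp, fn, tp = self.sum.flatten()`. That ordering follows from how `compute_cm_torch` builds the histogram (`n_class * label + pred`); a quick sanity check:

```python
import torch

y_label = torch.tensor([0, 0, 1, 1])
y_pred  = torch.tensor([0, 1, 0, 1])
hist = torch.bincount(2 * y_label + y_pred, minlength=4).reshape(2, 2)
print(hist.flatten().tolist())   # [1, 1, 1, 1] -> one tn, fp, fn, tp each
```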
67 | ''' 68 | table = '' 69 | if title: # not None or not [] both equal to true 70 | col_num = len(title) 71 | table += self._get_table_row(title) 72 | else: 73 | col_num=self.col_num 74 | table += self._get_table_row([' ']*col_num) # empty title row 75 | table += self._get_table_row(['-'] * col_num) # header spliter 76 | for i in range(0, len(item_list), col_num): 77 | table += self._get_table_row(item_list[i:i+col_num]) 78 | return table 79 | 80 | 81 | def visual_dict_to_imgrid(visual_dict, col_num=4, padding=0): 82 | ''' 83 | args: 84 | visual_dict: a dictionary of images of the same size 85 | col_num: number of columns in image grid 86 | padding: number of padding pixels to seperate images 87 | ''' 88 | im_names = [] 89 | im_tensors = [] 90 | for name, visual in visual_dict.items(): 91 | im_names.append(name) 92 | im_tensors.append(visual) 93 | im_grid = vutils.make_grid(im_tensors, 94 | nrow=col_num , 95 | padding=0, 96 | pad_value=1.0) 97 | layout = MDTableConvertor(col_num).convert(im_names) 98 | 99 | return im_grid, layout 100 | 101 | 102 | def count_parameters(model, trainable_only=False): 103 | return sum(p.numel() for p in model.parameters()) 104 | 105 | 106 | 107 | class WarmupExpLRScheduler(object): 108 | def __init__(self, lr_start=1e-4, lr_max=4e-4, lr_min=5e-6, rampup_epochs=4, sustain_epochs=0, exp_decay=0.75): 109 | self.lr_start = lr_start 110 | self.lr_max = lr_max 111 | self.lr_min = lr_min 112 | self.rampup_epochs = rampup_epochs 113 | self.sustain_epochs = sustain_epochs 114 | self.exp_decay = exp_decay 115 | 116 | def __call__(self, epoch): 117 | if epoch < self.rampup_epochs: 118 | lr = (self.lr_max - self.lr_start) / self.rampup_epochs * epoch + self.lr_start 119 | elif epoch < self.rampup_epochs + self.sustain_epochs: 120 | lr = self.lr_max 121 | else: 122 | lr = (self.lr_max - self.lr_min) * self.exp_decay**(epoch - self.rampup_epochs - self.sustain_epochs) + self.lr_min 123 | # print(lr) 124 | return lr -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from torchvision import transforms 4 | import random 5 | 6 | 7 | class ToCV2Image(object): 8 | ''' 9 | convert a CHW, range [0., 1.] 
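Two small notes on the helpers above: `count_parameters` currently ignores its `trainable_only` flag, and `WarmupExpLRScheduler` returns an absolute learning rate per epoch, so it pairs naturally with `LambdaLR` when the optimizer's base lr is 1.0. A hedged sketch of both (neither variant is used by train.py as shipped):

```python
import torch
from utils.misc import WarmupExpLRScheduler

def count_parameters_fixed(model, trainable_only=False):
    """Corrected sketch that honours the trainable_only flag."""
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)

model = torch.nn.Linear(4, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1.0)   # base lr 1.0: the lambda IS the lr
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=WarmupExpLRScheduler())
```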
tensor to a cv2 image 10 | ''' 11 | def __init__(self, in_color='rgb', 12 | out_color='bgr'): 13 | assert in_color in ['rgb', 'bgr'] 14 | assert out_color in ['rgb', 'bgr'] 15 | self.in_color = in_color 16 | self.out_color = out_color 17 | 18 | def __call__(self, im_tensor): 19 | cv2_img = (im_tensor.cpu().numpy() * 255).astype(np.uint8).transpose(1, 2, 0) 20 | if self.in_color != self.out_color: 21 | cv2_img = cv2.img[:, :, ::-1] 22 | return cv2_img 23 | 24 | 25 | class JointRandHrzFlip(object): 26 | def __init__(self, p=0.5): 27 | self.p = p 28 | 29 | def flip_single(self, image): 30 | return cv2.flip(image, 1) 31 | 32 | def __call__(self, img): 33 | assert isinstance(img, (np.ndarray, list, tuple)) 34 | if random.random() < self.p: 35 | if isinstance(img, np.ndarray): 36 | flipped = self.flip_single(img) 37 | else: 38 | flipped = [] 39 | for each in img: 40 | flipped.append(self.flip_single(each)) 41 | return flipped 42 | else: 43 | return img 44 | 45 | 46 | class JointRandVertFlip(object): 47 | def __init__(self, p=0.5): 48 | self.p = p 49 | 50 | def flip_single(self, image): 51 | return cv2.flip(image, 0) 52 | 53 | def __call__(self, img): 54 | assert isinstance(img, (np.ndarray, list, tuple)) 55 | if random.random() < self.p: 56 | if isinstance(img, np.ndarray): 57 | flipped = self.flip_single(img) 58 | else: 59 | flipped = [] 60 | for each in img: 61 | flipped.append(self.flip_single(each)) 62 | return flipped 63 | else: 64 | return img 65 | 66 | 67 | class JointResize(object): 68 | def __init__(self, size, interpolation='bilinear'): 69 | assert isinstance(size, (tuple, int)) 70 | if isinstance(size, int): 71 | self.size = (size, size) 72 | else: 73 | self.size = size 74 | map_dict = {'bilinear': cv2.INTER_LINEAR, 75 | 'bicubic': cv2.INTER_CUBIC, 76 | 'nearest': cv2.INTER_NEAREST 77 | } 78 | assert interpolation in map_dict.keys() 79 | self.inter_flag = map_dict[interpolation] 80 | 81 | def resize_single(self, image): 82 | return cv2.resize(image, self.size, interpolation=self.inter_flag) 83 | 84 | 85 | def __call__(self, img): 86 | assert isinstance(img, (np.ndarray, list, tuple)) 87 | if isinstance(img, np.ndarray): 88 | resized = self.resize_single(img) 89 | else: 90 | resized = [] 91 | for image in img: 92 | resized.append(self.resize_single(image)) 93 | return resized 94 | 95 | 96 | class JointToTensor(object): 97 | def __init__(self): 98 | self.to_tensor_single = transforms.ToTensor() 99 | 100 | def __call__(self, img): 101 | assert isinstance(img, (np.ndarray, list, tuple)) 102 | if isinstance(img, np.ndarray): 103 | im_tensor = self.to_tensor_single(img) 104 | else: 105 | im_tensor = [] 106 | for image in img: 107 | im_tensor.append(self.to_tensor_single(image)) 108 | return im_tensor 109 | 110 | 111 | class JointNormalize(object): 112 | def __init__(self, mean, std, inplace=False): 113 | self.normalize_single = transforms.Normalize(mean, std, inplace) 114 | 115 | def __call__(self, img): 116 | assert isinstance(img, (np.ndarray, list, tuple)) 117 | if isinstance(img, np.ndarray): 118 | normalized = self.normalize_single(img) 119 | else: 120 | normalized = [] 121 | for image in img: 122 | normalized.append(self.normalize_single(image)) 123 | return normalized 124 | 125 | 126 | class JointRandCrop(object): 127 | def __init__(self, size): 128 | pass 129 | def __call__(self, img): 130 | pass 131 | 132 | 133 | # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/2 134 | class Denormalize(object): 135 | def __init__(self, mean, std): 136 | self.mean 
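One caveat in `ToCV2Image.__call__` above: the channel flip is written as `cv2.img[:, :, ::-1]`, but `cv2` has no `img` attribute; the converted array itself should be reindexed. A corrected stand-alone sketch:

```python
import numpy as np

def to_cv2_image(im_tensor, in_color='rgb', out_color='bgr'):
    """Convert a CHW float tensor in [0, 1] to an HWC uint8 array (corrected sketch)."""
    cv2_img = (im_tensor.cpu().numpy() * 255).astype(np.uint8).transpose(1, 2, 0)
    if in_color != out_color:
        cv2_img = cv2_img[:, :, ::-1]   # swap RGB <-> BGR channel order
    return cv2_img
```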
= mean 137 | self.std = std 138 | 139 | def __call__(self, tensor): 140 | """ 141 | Args: 142 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 143 | Returns: 144 | Tensor: Normalized image. 145 | """ 146 | for t, m, s in zip(tensor, self.mean, self.std): 147 | t.mul_(s).add_(m) 148 | # The normalize code -> t.sub_(m).div_(s) 149 | return tensor 150 | 151 | 152 | class Binarize(object): 153 | def __init__(self, threshold=125): 154 | assert isinstance(threshold, (int, float)) 155 | self.threshold = threshold 156 | 157 | def __call__(self, img): 158 | assert isinstance(img, np.ndarray) 159 | return (img > self.threshold).astype(img.dtype) 160 | 161 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.utils import draw_segmentation_masks 3 | 4 | def colorize_classid_array(classid_array, image=None, alpha=0.8, colors=None): 5 | """ 6 | Args: 7 | classidx_array: torch.LongTensor, (H, W) tensor 8 | num_cls: int, number of classes 9 | image: if None, overlay colored label on it, otherwise a pure black image is created 10 | colors: list/dict/array provdes class id to color mapping 11 | """ 12 | if image is None: 13 | image = torch.zeros(size=(3, classid_array.size(-2), classid_array.size(-1)), 14 | dtype=torch.uint8) 15 | # if colors is not None: 16 | # assert len(colors) == num_cls, 'size of colormap should be consistent with num_cls' 17 | # all_class_masks = (classid_array == torch.arange(num_cls)[:, None, None]) 18 | # im_label_overlay = draw_segmentation_masks(image, all_class_masks, alpha=alpha, colors=colors) 19 | unique_idx = torch.unique(classid_array) 20 | colors_use = [colors[idx.item()] for idx in unique_idx] 21 | all_class_masks = (classid_array == unique_idx[:, None, None]) 22 | im_label_overlay = draw_segmentation_masks(image, all_class_masks, alpha=alpha, colors=colors_use) 23 | 24 | return im_label_overlay, unique_idx --------------------------------------------------------------------------------
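`colorize_classid_array` maps each distinct class id in a 2-D tensor to a colour via `draw_segmentation_masks`; `visualize_sample` in train.py uses it to paint a prediction/ground-truth difference map. A small self-contained example with the same colour convention:

```python
import torch
from utils.visualization import colorize_classid_array

# -1 = false negative, 0 = correct, 1 = false positive (convention from visualize_sample)
diff = torch.tensor([[-1, 0],
                     [ 1, 0]])
vis, ids = colorize_classid_array(diff, alpha=1.0,
                                  colors={-1: 'green', 0: 'black', 1: 'red'})
print(vis.shape, ids.tolist())   # torch.Size([3, 2, 2]) [-1, 0, 1]
```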