├── README.md ├── config └── sparsemat.toml ├── data └── generate_filelist.py ├── datasets ├── __init__.py ├── data_loader.py └── utils.py ├── demo.py ├── figures └── framework.png ├── model ├── __init__.py ├── backbones │ ├── __init__.py │ ├── dilated_resnet_bn.py │ ├── mobilenetv2.py │ ├── mobilenetv3.py │ ├── resnet_bn.py │ ├── sparse_resnet_bn.py │ └── wrapper.py ├── lap_pyramid_loss.py ├── loss.py ├── lpn.py ├── model.py ├── shm.py └── utils.py ├── test.py ├── train.py └── utils ├── __init__.py ├── config.py └── viz_utils.py /README.md: -------------------------------------------------------------------------------- 1 | # SparseMat 2 | Repository for *Ultrahigh Resolution Image/Video Matting with Spatio-Temporal Sparsity*, which has been accepted by CVPR2023. 3 | 4 | 5 | 6 | ### Overview 7 | 8 | Commodity ultrahigh definition (UHD) displays are becoming more affordable which demand imaging in ultrahigh resolution (UHR). This paper proposes SparseMat, a computationally efficient approach for UHR image/video matting. Note that it is infeasible to directly process UHR images at full resolution in one shot using existing matting algorithms without running out of memory on consumer-level computational platforms, e.g., Nvidia 1080Ti with 11G memory, while patch-based approaches can introduce unsightly artifacts due to patch partitioning. Instead, our method resorts to spatial and temporal sparsity for addressing general UHR matting. When processing videos, huge computation redundancy can be reduced by exploiting spatial and temporal sparsity. In this paper, we show how to effectively detect spatio-temporal sparsity, which serves as a gate to activate input pixels for the matting model. Under the guidance of such sparsity, our method with sparse high-resolution module (SHM) can avoid patch-based inference while memory efficient for full-resolution matte refinement. Extensive experiments demonstrate that SparseMat can effectively and efficiently generate high-quality alpha matte for UHR images and videos at the original high resolution in a single pass. 9 | 10 | ### Environment 11 | The recommended pytorch and torchvision version is v1.9.0 and v0.10.0. 12 | 13 | - torch 14 | - torchvision 15 | - easydict 16 | - toml 17 | - pillow 18 | - scikit-image 19 | - scipy 20 | - spconv. Please install sparse conv module refer to [traveller59/spconv](https://github.com/traveller59/spconv/tree/v1.2.1). Note that we use version 1.2.1 instead of the latest version. 21 | 22 | ### Dataset 23 | Existing datasets suffer from limited resolution. Thus, in this paper we contribute the first UHR human matting dataset, composed of HHM50K for training and HHM2K for evaluation. HHM50K and HHM2K consist of respectively 50,000 and 2,000 unique UHR images (with an average resolution of 4K) encompassing a wide range of human poses and matting scenarios. We provide the downloading link below. 24 | - HHM50K: [BaiduDisk](https://pan.baidu.com/s/1txjXk7OH3vIH7yrmpfNThA), password 2tsc 25 | - HHM2K: [BaiduDisk](https://pan.baidu.com/s/1RKu3qJRRMlgfZbIN7P4j4w), password ymyr 26 | 27 | You can download and put them under `data` directory. Then run the following command to generate file lists. 28 | ``` 29 | python3 data/generate_filelist.py 30 | ``` 31 | 32 | ### Code 33 | ###### Training 34 | Run the following command to train the model. To train SparseMat with our self-trained low-resolution prior network, please download [here](https://drive.google.com/file/d/1_zDQbul-lCM-tFEWNcdw0D4jr3WaK1ir/view?usp=sharing) and put it under the `pretrained` directory. 35 | ``` 36 | work_dir=/PATH/TO/SparseMat 37 | cd $work_dir 38 | export PYTHONPATH=$PYTHONPATH:$work_dir 39 | python3 train.py -c configs/sparsemat.toml 40 | ``` 41 | 42 | ###### Testing 43 | Run the following command to evalute the model. You can download our pretrained model [here](https://drive.google.com/file/d/19MX3USM4BK3sYi0o3AHNUxJ8bZEAGXg9/view?usp=sharing) and put it under the `pretrained` directory. 44 | ``` 45 | work_dir=/PATH/TO/SparseMat 46 | cd $work_dir 47 | export PYTHONPATH=$PYTHONPATH:$work_dir 48 | python3 test.py -c configs/sparsemat.toml 49 | ``` 50 | 51 | ###### Inference 52 | You can use the following command to inference the model on images or videos. 53 | ``` 54 | work_dir=/PATH/TO/SparseMat 55 | cd $work_dir 56 | export PYTHONPATH=$PYTHONPATH:$work_dir 57 | python3 demo.py -c configs/sparsemat.toml --input --save_dir 58 | ``` 59 | 60 | ### Reference 61 | ``` 62 | @InProceedings{Sun_2023_CVPR, 63 | author = {Sun, Yanan and Tang, Chi-Keung and Tai, Yu-Wing}, 64 | title = {Ultrahigh Resolution Image/Video Matting With Spatio-Temporal Sparsity}, 65 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 66 | month = {June}, 67 | year = {2023}, 68 | pages = {14112-14121} 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /config/sparsemat.toml: -------------------------------------------------------------------------------- 1 | # Refer to utils/config.py for definition and options. 2 | version = "SparseMat" 3 | 4 | [model] 5 | dilation_kernel = 15 6 | max_n_pixel = 4000000 7 | 8 | [loss] 9 | alpha_loss_weights = [0.1, 0.1, 0.1, 1.0] 10 | with_composition_loss = true 11 | composition_loss_weight = 0.5 12 | 13 | [train] 14 | batch_size = 12 15 | epoch = 30 16 | epoch_decay = 10 17 | lr = 0.0001 18 | min_lr = 0.00001 19 | adaptive_lr = true 20 | beta1 = 0.9 21 | beta2 = 0.999 22 | pretrained_model = "pretrained/lpn.pth" 23 | num_workers = 16 24 | 25 | [aug] 26 | rescale_size = 560 27 | crop_size = 512 28 | patch_crop_size = [512, 640, 800] 29 | patch_load_size = 512 30 | 31 | [data] 32 | dataset = "HHM50K" 33 | filelist_train = "data/HHM50K.txt" 34 | filelist_val = "data/HHM2K.txt" 35 | filelist_test = "data/HHM2K.txt" 36 | 37 | [log] 38 | save_frq = 50 39 | 40 | [test] 41 | batch_size = 1 42 | rescale_size = 512 43 | patch_size = 512 44 | max_size = 7680 45 | save = true 46 | cascade = true 47 | checkpoint = "pretrained/SparseMat.pth" 48 | save_dir = "predictions/SparseMatte/HHM2K" 49 | -------------------------------------------------------------------------------- /data/generate_filelist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | 5 | 6 | if __name__ == "__main__": 7 | 8 | root = "data/HHM50K" 9 | writer = open("data/HHM50K.txt", "w") 10 | 11 | images = sorted(glob.glob(os.path.join(root, "images/*.jpg"))) 12 | alphas = sorted(glob.glob(os.path.join(root, "alphas/*.png"))) 13 | fgs = sorted(glob.glob(os.path.join(root, "foregrounds/*.jpg"))) 14 | bgs = sorted(glob.glob(os.path.join(root, "backgrounds/*.jpg"))) 15 | 16 | assert len(images) == len(alphas) 17 | assert len(images) == len(fgs) 18 | assert len(images) == len(bgs) 19 | 20 | for img, pha, fg, bg in zip(images, alphas, fgs, bgs): 21 | img_name = img.split('/')[-1][:-4] 22 | pha_name = pha.split('/')[-1][:-4] 23 | fg_name = fg.split('/')[-1][:-4] 24 | bg_name = bg.split('/')[-1][:-4] 25 | assert img_name == pha_name 26 | assert img_name == fg_name 27 | assert img_name == bg_name 28 | writer.write(f"{img},{pha},{fg},{bg}\n") 29 | 30 | 31 | root = "data/HHM2K" 32 | writer = open("data/HHM2K.txt", "w") 33 | 34 | images = sorted(glob.glob(os.path.join(root, "images/*.jpg"))) 35 | alphas = sorted(glob.glob(os.path.join(root, "alphas/*.png"))) 36 | 37 | assert len(images) == len(alphas) 38 | 39 | for img, pha in zip(images, alphas): 40 | img_name = img.split('/')[-1][:-4] 41 | pha_name = pha.split('/')[-1][:-4] 42 | assert img_name == pha_name 43 | writer.write(f"{img},{pha}\n") 44 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_loader import Rescale 2 | from .data_loader import RescaleT 3 | from .data_loader import RandomFlip 4 | from .data_loader import RandomCrop 5 | from .data_loader import ToTensor 6 | from .data_loader import CustomDataset 7 | -------------------------------------------------------------------------------- /datasets/data_loader.py: -------------------------------------------------------------------------------- 1 | # data loader 2 | from __future__ import print_function, division 3 | import os 4 | import glob 5 | import numpy as np 6 | import random 7 | import math 8 | import cv2 9 | from PIL import Image 10 | from skimage import io, transform, color 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.utils.data import Sampler, Dataset, DataLoader 15 | from torchvision import transforms, utils 16 | 17 | from .utils import convert_color_space, get_random_patch 18 | 19 | 20 | def imread(path): 21 | image = cv2.imread(path) 22 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 23 | return image 24 | 25 | 26 | class RandomFlip(object): 27 | def __init__(self, cfg): 28 | self.cfg = cfg 29 | 30 | def __call__(self, sample): 31 | # randomly flip 32 | if random.random() >= 0.5: 33 | pos = sample['pos'] 34 | x1 = 1. - pos[..., 2] 35 | x2 = 1. - pos[..., 0] 36 | pos[..., 0] = x1 37 | pos[..., 2] = x2 38 | sample['pos'] = pos 39 | sample['hr_image'] = sample['hr_image'][:,::-1].copy() 40 | sample['lr_image'] = sample['lr_image'][:,::-1].copy() 41 | sample['hr_label'] = sample['hr_label'][:,::-1].copy() 42 | sample['hr_unknown'] = sample['hr_unknown'][:,::-1].copy() 43 | if 'hr_fg' in sample: 44 | sample['hr_fg'] = sample['hr_fg'][:,::-1].copy() 45 | sample['hr_bg'] = sample['hr_bg'][:,::-1].copy() 46 | return sample 47 | 48 | 49 | class Rescale(object): 50 | def __init__(self, cfg): 51 | assert isinstance(cfg.aug.rescale_size,(int,tuple)) 52 | self.output_size = cfg.aug.rescale_size 53 | 54 | def __call__(self,sample): 55 | h, w = sample['hr_image'].shape[:2] 56 | sample['origin_h'] = h 57 | sample['origin_w'] = w 58 | if isinstance(self.output_size,int): 59 | ratio = self.output_size / min(h,w) 60 | new_h, new_w = ratio*h, ratio*w 61 | else: 62 | new_h, new_w = self.output_size 63 | new_h, new_w = int(new_h), int(new_w) 64 | sample['lr_image'] = cv2.resize(sample['hr_image'], (new_w, new_h), interpolation=cv2.INTER_LINEAR) 65 | return sample 66 | 67 | 68 | class RescaleT(object): 69 | def __init__(self, cfg): 70 | self.cfg = cfg 71 | self.max_size = cfg.test.max_size 72 | self.output_size = cfg.test.rescale_size 73 | assert isinstance(self.output_size,(int,tuple)) 74 | 75 | def get_dst_size(self, origin_size, output_size=None, stride=32, max_size=1920): 76 | h, w = origin_size 77 | if output_size is None: 78 | ratio = max_size / max(h,w) 79 | if ratio>=1: 80 | new_h, new_w = h, w 81 | else: 82 | new_h, new_w = int(math.ceil(ratio*h)), int(math.ceil(ratio*w)) 83 | elif isinstance(output_size,int): 84 | if output_size>=max_size: 85 | ratio = output_size / max(h,w) 86 | else: 87 | ratio = output_size / min(h,w) 88 | new_h, new_w = int(math.ceil(ratio*h)), int(math.ceil(ratio*w)) 89 | else: 90 | new_h, new_w = output_size 91 | new_h = new_h - new_h % 32 92 | new_w = new_w - new_w % 32 93 | return (new_h, new_w) 94 | 95 | def __call__(self,sample): 96 | h, w = sample['hr_image'].shape[:2] 97 | sample['origin_h'] = h 98 | sample['origin_w'] = w 99 | new_h, new_w = self.get_dst_size((h,w), self.output_size, 32) 100 | sample['lr_image'] = cv2.resize(sample['hr_image'], (new_w, new_h), interpolation=cv2.INTER_LINEAR) 101 | return sample 102 | 103 | 104 | class RandomCrop(object): 105 | 106 | def __init__(self, cfg): 107 | # low-resolution full image 108 | output_size = cfg.aug.crop_size 109 | assert isinstance(output_size, (int, tuple)) 110 | if isinstance(output_size, int): 111 | self.output_size = (output_size, output_size) 112 | else: 113 | assert len(output_size) == 2 114 | self.output_size = output_size 115 | 116 | # full-resolution patch 117 | patch_crop_size = cfg.aug.patch_crop_size 118 | assert isinstance(patch_crop_size, (tuple, list)) 119 | self.patch_crop_size = patch_crop_size 120 | 121 | patch_load_size = cfg.aug.patch_load_size 122 | assert isinstance(patch_load_size, int) 123 | self.patch_load_size = patch_load_size 124 | 125 | self.cfg = cfg 126 | 127 | def random_crop(self, sample): 128 | h, w = sample['lr_image'].shape[:2] 129 | new_h, new_w = self.output_size 130 | ly1 = np.random.randint(0, h - new_h) 131 | lx1 = np.random.randint(0, w - new_w) 132 | ly2 = ly1 + new_h 133 | lx2 = lx1 + new_w 134 | 135 | oh, ow = sample['hr_image'].shape[:2] 136 | ratio_h = oh / float(h) 137 | ratio_w = ow / float(w) 138 | hx1, hy1 = int(lx1*ratio_w), int(ly1*ratio_h) 139 | hx2, hy2 = int(lx2*ratio_w), int(ly2*ratio_h) 140 | return (lx1,ly1,lx2,ly2), (hx1,hy1,hx2,hy2) 141 | 142 | def __call__(self,sample): 143 | (lx1,ly1,lx2,ly2), (hx1,hy1,hx2,hy2) = self.random_crop(sample) 144 | sample['lr_image'] = sample['lr_image'][ly1:ly2, lx1:lx2] 145 | sample['hr_image'] = sample['hr_image'][hy1:hy2, hx1:hx2] 146 | sample['hr_label'] = sample['hr_label'][hy1:hy2, hx1:hx2] 147 | sample['hr_unknown'] = sample['hr_unknown'][hy1:hy2, hx1:hx2] 148 | 149 | # random crop from high-resolution input 150 | h, w = sample['hr_label'].shape[:2] 151 | random_crop_size = random.choice(self.patch_crop_size) 152 | px1,py1,px2,py2 = get_random_patch(sample['hr_label'], random_crop_size) 153 | pos = np.array([px1/w,py1/h,px2/w,py2/h]).astype(np.float32) 154 | pos = np.clip(pos, 0, 1) 155 | sample['pos'] = pos 156 | 157 | load_size = (self.patch_load_size, self.patch_load_size) 158 | sample['hr_image'] = cv2.resize(sample['hr_image'][py1:py2, px1:px2], load_size, interpolation=cv2.INTER_LINEAR) 159 | sample['hr_label'] = cv2.resize(sample['hr_label'][py1:py2, px1:px2], load_size, interpolation=cv2.INTER_LINEAR) 160 | sample['hr_unknown'] = cv2.resize(sample['hr_unknown'][py1:py2, px1:px2], load_size, interpolation=cv2.INTER_NEAREST) 161 | 162 | if 'hr_fg' in sample: 163 | sample['hr_fg'] = sample['hr_fg'][hy1:hy2, hx1:hx2] 164 | sample['hr_fg'] = cv2.resize(sample['hr_fg'][py1:py2, px1:px2], load_size, interpolation=cv2.INTER_LINEAR) 165 | sample['hr_bg'] = sample['hr_bg'][hy1:hy2, hx1:hx2] 166 | sample['hr_bg'] = cv2.resize(sample['hr_bg'][py1:py2, px1:px2], load_size, interpolation=cv2.INTER_LINEAR) 167 | return sample 168 | 169 | 170 | class ToTensor(object): 171 | """Convert ndarrays in sample to Tensors.""" 172 | def __init__(self, cfg): 173 | self.color_space = cfg.train.color_space 174 | 175 | def __call__(self, sample): 176 | sample['hr_label'] = sample['hr_label'] / 255. 177 | sample['hr_label'] = torch.from_numpy(sample['hr_label'][None].astype(np.float32)) 178 | sample['hr_unknown'] = torch.from_numpy(sample['hr_unknown'][None].astype(np.float32)) 179 | 180 | sample['hr_image'] = convert_color_space(sample['hr_image'], flag=self.color_space) 181 | sample['hr_image'] = torch.from_numpy(sample['hr_image'].transpose((2,0,1)).astype(np.float32)) 182 | sample['lr_image'] = convert_color_space(sample['lr_image'], flag=self.color_space) 183 | sample['lr_image'] = torch.from_numpy(sample['lr_image'].transpose((2,0,1)).astype(np.float32)) 184 | 185 | if 'pos' in sample: 186 | sample['pos'] = torch.from_numpy(sample['pos'].astype(np.float32)) 187 | if 'hr_fg' in sample: 188 | sample['hr_fg'] = convert_color_space(sample['hr_fg'], flag=self.color_space) 189 | sample['hr_fg'] = torch.from_numpy(sample['hr_fg'].transpose((2,0,1)).astype(np.float32)) 190 | sample['hr_bg'] = convert_color_space(sample['hr_bg'], flag=self.color_space) 191 | sample['hr_bg'] = torch.from_numpy(sample['hr_bg'].transpose((2,0,1)).astype(np.float32)) 192 | return sample 193 | 194 | 195 | class CustomDataset(Dataset): 196 | def __init__(self,cfg, is_training, img_name_list, lbl_name_list, 197 | fg_name_list=None, bg_name_list=None, transform=None): 198 | 199 | self.cfg = cfg 200 | self.is_training = is_training 201 | 202 | self.image_name_list = img_name_list 203 | self.label_name_list = lbl_name_list 204 | self.fg_name_list = fg_name_list # for composition loss only!!!!! 205 | self.bg_name_list = bg_name_list # for composition loss only!!!!! 206 | 207 | self.transform = transform 208 | 209 | def __len__(self): 210 | return len(self.image_name_list) 211 | 212 | def __getitem__(self,idx): 213 | 214 | sample = {} 215 | sample['hr_image'] = imread(self.image_name_list[idx]) 216 | sample['hr_label'] = imread(self.label_name_list[idx])[:,:,0] 217 | 218 | unknown = generate_unknown_label(sample['hr_label'], fixed=(not self.is_training)) 219 | mask = (unknown==0) | (unknown==1) 220 | unknown[mask==1] = 0 221 | unknown[mask==0] = 1 222 | sample['hr_unknown'] = unknown 223 | 224 | if self.is_training and len(self.fg_name_list) == len(self.image_name_list): 225 | fg = imread(self.fg_name_list[idx]) 226 | bg = imread(self.bg_name_list[idx]) 227 | sample['hr_fg'] = fg 228 | sample['hr_bg'] = bg 229 | 230 | if self.transform: 231 | sample = self.transform(sample) 232 | 233 | return sample 234 | 235 | 236 | def generate_unknown_label(alpha, ksize=3, iterations=5, fixed=False): 237 | oH, oW = alpha.shape[:2] 238 | if not fixed: 239 | ksize_range=(3, 9) 240 | iter_range=(1, 15) 241 | ksize = random.randint(ksize_range[0], ksize_range[1]) 242 | iterations = random.randint(iter_range[0], iter_range[1]) 243 | else: 244 | ksize = 5 245 | iterations = 5 246 | ratio = 1280. / max(oH,oW) 247 | alpha = cv2.resize(alpha, None, fx=ratio, fy=ratio) 248 | 249 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize)) 250 | dilated = cv2.dilate(alpha, kernel, iterations=iterations) 251 | eroded = cv2.erode(alpha, kernel, iterations=iterations) 252 | trimap = np.zeros(alpha.shape) + 128 253 | trimap[eroded >= 255] = 255 254 | trimap[dilated <= 0] = 0 255 | trimap = trimap.astype(np.uint8) 256 | if trimap.shape[0] != oH or trimap.shape[1] != oW: 257 | trimap = cv2.resize(trimap, (oW,oH), interpolation=cv2.INTER_NEAREST) 258 | return trimap 259 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import math 5 | import cv2 6 | 7 | import torch.nn.functional as F 8 | 9 | from torch.utils.data import Dataset, DataLoader 10 | from torchvision import transforms, utils 11 | 12 | from PIL import Image 13 | from skimage import io, transform, color 14 | 15 | 16 | def get_random_patch(mask, crop_size): 17 | new_h, new_w = mask.shape[:2] 18 | crop_size = min(crop_size, min(new_w, new_h)-1) 19 | crop_size_hf = crop_size // 2 20 | maskf = mask / 255. 21 | ys, xs = np.where(np.logical_and(maskf>0.05, maskf<0.95))[:2] 22 | if len(ys)>0: 23 | rand_ind = random.randint(0, len(ys)-1) 24 | cy = min(max(ys[rand_ind], crop_size_hf), new_h-crop_size_hf) 25 | cx = min(max(xs[rand_ind], crop_size_hf), new_w-crop_size_hf) 26 | x1, y1 = cx - crop_size_hf, cy - crop_size_hf 27 | x2, y2 = x1 + crop_size, y1 + crop_size 28 | else: 29 | x1, y1 = new_w // 2 - crop_size_hf, new_h // 2 - crop_size_hf 30 | x2, y2 = x1 + crop_size, y1 + crop_size 31 | return (x1,y1,x2,y2) 32 | 33 | 34 | def convert_color_space(image, flag=3): 35 | if flag == 3: 36 | image = image / 255.0 37 | tmpImg = np.zeros((image.shape[0],image.shape[1],3)) 38 | if image.shape[2]==1: 39 | tmpImg[:] = 2 * np.tile(image[:,:,None],(1,1,3)) - 1 40 | else: 41 | tmpImg[:] = 2 * image[:] - 1 42 | 43 | elif flag == 2: # with rgb and Lab colors 44 | tmpImg = np.zeros((image.shape[0],image.shape[1],6)) 45 | tmpImgt = np.zeros((image.shape[0],image.shape[1],3)) 46 | if image.shape[2]==1: 47 | tmpImgt[:,:,0] = image[:,:,0] 48 | tmpImgt[:,:,1] = image[:,:,0] 49 | tmpImgt[:,:,2] = image[:,:,0] 50 | else: 51 | tmpImgt = image 52 | tmpImgtl = color.rgb2lab(tmpImgt) 53 | 54 | # nomalize image to range [0,1] 55 | tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0])) 56 | tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1])) 57 | tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2])) 58 | tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0])) 59 | tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1])) 60 | tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2])) 61 | 62 | # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg)) 63 | 64 | tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0]) 65 | tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1]) 66 | tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2]) 67 | tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3]) 68 | tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4]) 69 | tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5]) 70 | 71 | elif flag == 1: #with Lab color 72 | tmpImg = np.zeros((image.shape[0],image.shape[1],3)) 73 | 74 | if image.shape[2]==1: 75 | tmpImg[:,:,0] = image[:,:,0] 76 | tmpImg[:,:,1] = image[:,:,0] 77 | tmpImg[:,:,2] = image[:,:,0] 78 | else: 79 | tmpImg = image 80 | 81 | tmpImg = color.rgb2lab(tmpImg) 82 | 83 | # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg)) 84 | 85 | tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0])) 86 | tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1])) 87 | tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2])) 88 | 89 | tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0]) 90 | tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1]) 91 | tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2]) 92 | 93 | else: # with rgb color 94 | tmpImg = np.zeros((image.shape[0],image.shape[1],3)) 95 | image = image/np.max(image) 96 | if image.shape[2]==1: 97 | tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229 98 | tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229 99 | tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229 100 | else: 101 | tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229 102 | tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224 103 | tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225 104 | 105 | return tmpImg 106 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import cv2 5 | import math 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from model import SparseMat 13 | from utils import load_config 14 | 15 | 16 | def load_checkpoint(net, pretrained_model): 17 | net_state_dict = net.state_dict() 18 | state_dict = torch.load(pretrained_model) 19 | if 'state_dict' in state_dict: 20 | state_dict = state_dict['state_dict'] 21 | elif 'model_state_dict' in state_dict: 22 | state_dict = state_dict['model_state_dict'] 23 | 24 | filtered_state_dict = OrderedDict() 25 | for k,v in state_dict.items(): 26 | if k.startswith('module'): 27 | nk = '.'.join(k.split('.')[1:]) 28 | else: 29 | nk = k 30 | filtered_state_dict[nk] = v 31 | net.load_state_dict(filtered_state_dict) 32 | print('load pretrained weight from {} successfully'.format(pretrained_model)) 33 | 34 | 35 | def preprocess(image): 36 | image = (image / 255. - 0.5) / 0.5 37 | image = torch.from_numpy(image[None]).permute(0,3,1,2) 38 | h, w = image.shape[2:] 39 | nh = math.ceil(h / 8) * 8 40 | nw = math.ceil(w / 8) * 8 41 | image = F.interpolate(image, (nh, nw), mode="bilinear") 42 | return image.float().cuda() 43 | 44 | 45 | def run_single_image(net, input_path, save_dir): 46 | filename = input_path.split('/')[-1] 47 | image = cv2.imread(input_path) 48 | origin_h, origin_w = image.shape[:2] 49 | tensor = preprocess(image) 50 | with torch.no_grad(): 51 | pred = net.inference(tensor) 52 | pred = F.interpolate(pred, (origin_h, origin_w), align_corners=False, mode="bilinear") 53 | pred_alpha = (pred * 255).squeeze().data.cpu().numpy().astype(np.uint8) 54 | cv2.imwrite(os.path.join(save_dir, filename), pred_alpha) 55 | return pred 56 | 57 | 58 | def run_multiple_images(net, input_path, save_dir): 59 | for item in os.listdir(input_path): 60 | run_single_image(net, os.path.join(input_path, item), save_dir) 61 | 62 | 63 | def run_video(net, input_path, save_dir): 64 | filename = input_path.split('/')[-1] 65 | cap = cv2.VideoCapture(input_path) 66 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 67 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 68 | fps = cap.get(cv2.CAP_PROP_FPS) 69 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 70 | writer = cv2.VideoWriter(os.path.join(save_dir, filename), fourcc, fps, (width, height)) 71 | 72 | last_frame = None 73 | last_pred = None 74 | while True: 75 | ret, frame = cap.read() 76 | if not ret: 77 | break 78 | tensor = preprocess(frame) 79 | with torch.no_grad(): 80 | pred = net.inference(tensor, last_img=last_frame, last_pred=last_pred) 81 | pred = F.interpolate(pred, (height, width), align_corners=False, mode="bilinear") 82 | pred_alpha = (pred * 255).squeeze().data.cpu().numpy().astype(np.uint8) 83 | writer.write(np.tile(pred_alpha[:,:,None], (1,1,3))) 84 | last_frame = tensor 85 | last_pred = pred 86 | 87 | 88 | def main(): 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument('-c', '--config', type=str, metavar='FILE', help='path to config file') 91 | parser.add_argument('--input', type=str, metavar='PATH', help='path to input path') 92 | parser.add_argument('--save_dir', type=str, metavar='PATH', help='path to save path') 93 | 94 | args = parser.parse_args() 95 | cfg = load_config(args.config) 96 | 97 | os.makedirs(args.save_dir, exist_ok=True) 98 | 99 | net = SparseMat(cfg) 100 | 101 | if torch.cuda.is_available(): 102 | net.cuda() 103 | else: 104 | exit() 105 | 106 | load_checkpoint(net, cfg.test.checkpoint) 107 | 108 | net.eval() 109 | 110 | if args.input.endswith(".mp4"): 111 | run_video(net, args.input, args.save_dir) 112 | elif args.input.endswith(".jpg") or args.input.endswith(".png"): 113 | run_single_image(net, args.input, args.save_dir) 114 | else: 115 | run_multiple_images(net, args.input, args.save_dir) 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /figures/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowsyn/SparseMat/2678757dfb7db185f91ee54e54d1e68944febded/figures/framework.png -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SparseMat 2 | from .loss import losses 3 | -------------------------------------------------------------------------------- /model/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .wrapper import MobileNetV2Backbone 2 | from .wrapper import MobileNetV3LargeBackbone 3 | -------------------------------------------------------------------------------- /model/backbones/dilated_resnet_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import functools 4 | 5 | 6 | class ResnetDilatedBN(nn.Module): 7 | def __init__(self, args, orig_resnet, dilate_scale=8): 8 | super(ResnetDilatedBN, self).__init__() 9 | from functools import partial 10 | 11 | if dilate_scale == 8: 12 | orig_resnet.layer3.apply( 13 | partial(self._nostride_dilate, dilate=2)) 14 | orig_resnet.layer4.apply( 15 | partial(self._nostride_dilate, dilate=4)) 16 | elif dilate_scale == 16: 17 | orig_resnet.layer4.apply( 18 | partial(self._nostride_dilate, dilate=2)) 19 | 20 | # take pretrained resnet, except AvgPool and FC 21 | self.conv1 = orig_resnet.conv1 22 | self.bn1 = orig_resnet.bn1 23 | self.relu1 = orig_resnet.relu1 24 | 25 | self.conv2 = orig_resnet.conv2 26 | self.bn2 = orig_resnet.bn2 27 | self.relu2 = orig_resnet.relu2 28 | self.conv3 = orig_resnet.conv3 29 | self.bn3 = orig_resnet.bn3 30 | self.relu3 = orig_resnet.relu3 31 | 32 | self.maxpool = orig_resnet.maxpool 33 | self.layer1 = orig_resnet.layer1 34 | self.layer2 = orig_resnet.layer2 35 | self.layer3 = orig_resnet.layer3 36 | self.layer4 = orig_resnet.layer4 37 | 38 | self.enc_channels = [128, 256, 512, 1024, 2048] # 2x, 4x, 8x, 8x, 8x 39 | 40 | def _nostride_dilate(self, m, dilate): 41 | classname = m.__class__.__name__ 42 | if classname.find('Conv') != -1: 43 | # the convolution with stride 44 | if m.stride == (2, 2): 45 | m.stride = (1, 1) 46 | if m.kernel_size == (3, 3): 47 | m.dilation = (dilate // 2, dilate // 2) 48 | m.padding = (dilate // 2, dilate // 2) 49 | # other convoluions 50 | else: 51 | if m.kernel_size == (3, 3): 52 | m.dilation = (dilate, dilate) 53 | m.padding = (dilate, dilate) 54 | 55 | def forward(self, x): 56 | conv_out = [] 57 | x = self.relu1(self.bn1(self.conv1(x))) 58 | x = self.relu2(self.bn2(self.conv2(x))) 59 | x = self.relu3(self.bn3(self.conv3(x))) 60 | 61 | conv_out.append(x) # 2x 62 | x, indices = self.maxpool(x) 63 | x = self.layer1(x) 64 | conv_out.append(x) # 4x 65 | x = self.layer2(x) 66 | conv_out.append(x) # 8x 67 | x = self.layer3(x) 68 | conv_out.append(x) # 16x 69 | x = self.layer4(x) 70 | conv_out.append(x) # 32x 71 | return conv_out 72 | -------------------------------------------------------------------------------- /model/backbones/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch""" 2 | 3 | import math 4 | import json 5 | from functools import reduce 6 | 7 | import torch 8 | from torch import nn 9 | 10 | 11 | #------------------------------------------------------------------------------ 12 | # Useful functions 13 | #------------------------------------------------------------------------------ 14 | 15 | def _make_divisible(v, divisor, min_value=None): 16 | if min_value is None: 17 | min_value = divisor 18 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 19 | # Make sure that round down does not go down by more than 10%. 20 | if new_v < 0.9 * v: 21 | new_v += divisor 22 | return new_v 23 | 24 | 25 | def conv_bn(inp, oup, stride, with_norm=True): 26 | if with_norm: 27 | return nn.Sequential( 28 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 29 | nn.BatchNorm2d(oup), 30 | nn.ReLU6(inplace=True) 31 | ) 32 | else: 33 | return nn.Sequential( 34 | nn.Conv2d(inp, oup, 3, stride, 1, bias=True), 35 | nn.ReLU6(inplace=True) 36 | ) 37 | 38 | 39 | def conv_1x1_bn(inp, oup, with_norm=True): 40 | if with_norm: 41 | return nn.Sequential( 42 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 43 | nn.BatchNorm2d(oup), 44 | nn.ReLU6(inplace=True) 45 | ) 46 | else: 47 | return nn.Sequential( 48 | nn.Conv2d(inp, oup, 1, 1, 0, bias=True), 49 | nn.ReLU6(inplace=True) 50 | ) 51 | 52 | 53 | #------------------------------------------------------------------------------ 54 | # Class of Inverted Residual block 55 | #------------------------------------------------------------------------------ 56 | 57 | class InvertedResidual(nn.Module): 58 | def __init__(self, inp, oup, stride, expansion, dilation=1, with_norm=True): 59 | super(InvertedResidual, self).__init__() 60 | self.stride = stride 61 | assert stride in [1, 2] 62 | 63 | hidden_dim = round(inp * expansion) 64 | self.use_res_connect = self.stride == 1 and inp == oup 65 | 66 | if expansion == 1: 67 | if with_norm: 68 | self.conv = nn.Sequential( 69 | # dw 70 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), 71 | nn.BatchNorm2d(hidden_dim), 72 | nn.ReLU6(inplace=True), 73 | # pw-linear 74 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 75 | nn.BatchNorm2d(oup), 76 | ) 77 | else: 78 | self.conv = nn.Sequential( 79 | # dw 80 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=True), 81 | nn.ReLU6(inplace=True), 82 | # pw-linear 83 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=True), 84 | ) 85 | else: 86 | if with_norm: 87 | self.conv = nn.Sequential( 88 | # pw 89 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 90 | nn.BatchNorm2d(hidden_dim), 91 | nn.ReLU6(inplace=True), 92 | # dw 93 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), 94 | nn.BatchNorm2d(hidden_dim), 95 | nn.ReLU6(inplace=True), 96 | # pw-linear 97 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 98 | nn.BatchNorm2d(oup), 99 | ) 100 | else: 101 | self.conv = nn.Sequential( 102 | # pw 103 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=True), 104 | nn.ReLU6(inplace=True), 105 | # dw 106 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=True), 107 | nn.ReLU6(inplace=True), 108 | # pw-linear 109 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=True), 110 | ) 111 | 112 | def forward(self, x): 113 | if self.use_res_connect: 114 | return x + self.conv(x) 115 | else: 116 | return self.conv(x) 117 | 118 | 119 | #------------------------------------------------------------------------------ 120 | # Class of MobileNetV2 121 | #------------------------------------------------------------------------------ 122 | 123 | class MobileNetV2(nn.Module): 124 | def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000, with_norm=True): 125 | super(MobileNetV2, self).__init__() 126 | self.in_channels = in_channels 127 | self.num_classes = num_classes 128 | input_channel = 32 129 | last_channel = 1280 130 | interverted_residual_setting = [ 131 | # t, c, n, s 132 | [1 , 16, 1, 1], 133 | [expansion, 24, 2, 2], 134 | [expansion, 32, 3, 2], 135 | [expansion, 64, 4, 2], 136 | [expansion, 96, 3, 1], 137 | [expansion, 160, 3, 2], 138 | [expansion, 320, 1, 1], 139 | ] 140 | 141 | # building first layer 142 | input_channel = _make_divisible(input_channel*alpha, 8) 143 | self.last_channel = _make_divisible(last_channel*alpha, 8) if alpha > 1.0 else last_channel 144 | self.features = [conv_bn(self.in_channels, input_channel, 2, with_norm=with_norm)] 145 | 146 | # building inverted residual blocks 147 | idx = 1 # [0, 2, 4, 7, 14] 148 | for t, c, n, s in interverted_residual_setting: 149 | output_channel = _make_divisible(int(c*alpha), 8) 150 | for i in range(n): 151 | if i == 0: 152 | self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t, with_norm=with_norm)) 153 | else: 154 | self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t, with_norm=with_norm)) 155 | idx += 1 156 | input_channel = output_channel 157 | 158 | # building last several layers 159 | self.features.append(conv_1x1_bn(input_channel, self.last_channel, with_norm=with_norm)) 160 | 161 | # make it nn.Sequential 162 | self.features = nn.Sequential(*self.features) 163 | 164 | # building classifier 165 | if self.num_classes is not None: 166 | self.classifier = nn.Sequential( 167 | nn.Dropout(0.2), 168 | nn.Linear(self.last_channel, num_classes), 169 | ) 170 | 171 | # Initialize weights 172 | self._init_weights() 173 | 174 | def forward(self, x, feature_names=None): 175 | # Stage1 176 | x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x) 177 | # Stage2 178 | x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x) 179 | # Stage3 180 | x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x) 181 | # Stage4 182 | x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x) 183 | # Stage5 184 | x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x) 185 | 186 | # Classification 187 | if self.num_classes is not None: 188 | x = x.mean(dim=(2,3)) 189 | x = self.classifier(x) 190 | 191 | # Output 192 | return x 193 | 194 | def _load_pretrained_model(self, pretrained_file): 195 | pretrain_dict = torch.load(pretrained_file, map_location='cpu') 196 | model_dict = {} 197 | state_dict = self.state_dict() 198 | print("[MobileNetV2] Loading pretrained model...") 199 | for k, v in pretrain_dict.items(): 200 | if k in state_dict: 201 | model_dict[k] = v 202 | else: 203 | print(k, "is ignored") 204 | state_dict.update(model_dict) 205 | self.load_state_dict(state_dict) 206 | 207 | def _init_weights(self): 208 | for m in self.modules(): 209 | if isinstance(m, nn.Conv2d): 210 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 211 | m.weight.data.normal_(0, math.sqrt(2. / n)) 212 | if m.bias is not None: 213 | m.bias.data.zero_() 214 | elif isinstance(m, nn.BatchNorm2d): 215 | m.weight.data.fill_(1) 216 | m.bias.data.zero_() 217 | elif isinstance(m, nn.Linear): 218 | n = m.weight.size(1) 219 | m.weight.data.normal_(0, 0.01) 220 | m.bias.data.zero_() 221 | 222 | 223 | 224 | if __name__ == "__main__": 225 | net = MobileNetV2(3) 226 | net.cuda() 227 | inputs = torch.ones((1,3,512,512)).cuda() 228 | outs = net(inputs) 229 | -------------------------------------------------------------------------------- /model/backbones/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creates a MobileNetV3 Model as defined in: 3 | Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019). 4 | Searching for MobileNetV3 5 | arXiv preprint arXiv:1905.02244. 6 | """ 7 | 8 | import torch.nn as nn 9 | import math 10 | 11 | 12 | __all__ = ['mobilenetv3_large', 'mobilenetv3_small'] 13 | 14 | 15 | def _make_divisible(v, divisor, min_value=None): 16 | """ 17 | This function is taken from the original tf repo. 18 | It ensures that all layers have a channel number that is divisible by 8 19 | It can be seen here: 20 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 21 | :param v: 22 | :param divisor: 23 | :param min_value: 24 | :return: 25 | """ 26 | if min_value is None: 27 | min_value = divisor 28 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 29 | # Make sure that round down does not go down by more than 10%. 30 | if new_v < 0.9 * v: 31 | new_v += divisor 32 | return new_v 33 | 34 | 35 | class h_sigmoid(nn.Module): 36 | def __init__(self, inplace=True): 37 | super(h_sigmoid, self).__init__() 38 | self.relu = nn.ReLU6(inplace=inplace) 39 | 40 | def forward(self, x): 41 | return self.relu(x + 3) / 6 42 | 43 | 44 | class h_swish(nn.Module): 45 | def __init__(self, inplace=True): 46 | super(h_swish, self).__init__() 47 | self.sigmoid = h_sigmoid(inplace=inplace) 48 | 49 | def forward(self, x): 50 | return x * self.sigmoid(x) 51 | 52 | 53 | class SELayer(nn.Module): 54 | def __init__(self, channel, reduction=4): 55 | super(SELayer, self).__init__() 56 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 57 | self.fc = nn.Sequential( 58 | nn.Linear(channel, _make_divisible(channel // reduction, 8)), 59 | nn.ReLU(inplace=True), 60 | nn.Linear(_make_divisible(channel // reduction, 8), channel), 61 | h_sigmoid() 62 | ) 63 | 64 | def forward(self, x): 65 | b, c, _, _ = x.size() 66 | y = self.avg_pool(x).view(b, c) 67 | y = self.fc(y).view(b, c, 1, 1) 68 | return x * y 69 | 70 | 71 | def conv_3x3_bn(inp, oup, stride): 72 | return nn.Sequential( 73 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 74 | nn.BatchNorm2d(oup), 75 | h_swish() 76 | ) 77 | 78 | 79 | def conv_1x1_bn(inp, oup): 80 | return nn.Sequential( 81 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 82 | nn.BatchNorm2d(oup), 83 | h_swish() 84 | ) 85 | 86 | 87 | class InvertedResidual(nn.Module): 88 | def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs): 89 | super(InvertedResidual, self).__init__() 90 | assert stride in [1, 2] 91 | 92 | self.identity = stride == 1 and inp == oup 93 | 94 | if inp == hidden_dim: 95 | self.conv = nn.Sequential( 96 | # dw 97 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False), 98 | nn.BatchNorm2d(hidden_dim), 99 | h_swish() if use_hs else nn.ReLU(inplace=True), 100 | # Squeeze-and-Excite 101 | SELayer(hidden_dim) if use_se else nn.Identity(), 102 | # pw-linear 103 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 104 | nn.BatchNorm2d(oup), 105 | ) 106 | else: 107 | self.conv = nn.Sequential( 108 | # pw 109 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 110 | nn.BatchNorm2d(hidden_dim), 111 | h_swish() if use_hs else nn.ReLU(inplace=True), 112 | # dw 113 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False), 114 | nn.BatchNorm2d(hidden_dim), 115 | # Squeeze-and-Excite 116 | SELayer(hidden_dim) if use_se else nn.Identity(), 117 | h_swish() if use_hs else nn.ReLU(inplace=True), 118 | # pw-linear 119 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 120 | nn.BatchNorm2d(oup), 121 | ) 122 | 123 | def forward(self, x): 124 | if self.identity: 125 | return x + self.conv(x) 126 | else: 127 | return self.conv(x) 128 | 129 | 130 | class MobileNetV3(nn.Module): 131 | def __init__(self, in_channels, mode='large', num_classes=None, width_mult=1., with_norm=True): 132 | super(MobileNetV3, self).__init__() 133 | # setting of inverted residual blocks 134 | cfgs = [ 135 | # k, t, c, SE, HS, s 136 | [3, 1, 16, 0, 0, 1], 137 | [3, 4, 24, 0, 0, 2], 138 | [3, 3, 24, 0, 0, 1], 139 | [5, 3, 40, 1, 0, 2], 140 | [5, 3, 40, 1, 0, 1], 141 | [5, 3, 40, 1, 0, 1], 142 | [3, 6, 80, 0, 1, 2], 143 | [3, 2.5, 80, 0, 1, 1], 144 | [3, 2.3, 80, 0, 1, 1], 145 | [3, 2.3, 80, 0, 1, 1], 146 | [3, 6, 112, 1, 1, 1], 147 | [3, 6, 112, 1, 1, 1], 148 | [5, 6, 160, 1, 1, 2], 149 | [5, 6, 160, 1, 1, 1], 150 | [5, 6, 160, 1, 1, 1] 151 | ] 152 | self.cfgs = cfgs 153 | assert mode in ['large', 'small'] 154 | 155 | # building first layer 156 | input_channel = _make_divisible(16 * width_mult, 8) 157 | layers = [conv_3x3_bn(3, input_channel, 2)] 158 | # self.features = [conv_3x3_bn(in_channels, input_channel, 2)] 159 | # building inverted residual blocks 160 | block = InvertedResidual 161 | for k, t, c, use_se, use_hs, s in self.cfgs: 162 | output_channel = _make_divisible(c * width_mult, 8) 163 | exp_size = _make_divisible(input_channel * t, 8) 164 | layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs)) 165 | input_channel = output_channel 166 | # building last several layers 167 | # self.conv = conv_1x1_bn(input_channel, exp_size) 168 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 169 | output_channel = {'large': 1280, 'small': 1024} 170 | output_channel = _make_divisible(output_channel[mode] * width_mult, 8) if width_mult > 1.0 else output_channel[mode] 171 | layers.append(conv_1x1_bn(input_channel, output_channel)) 172 | self.features = nn.Sequential(*layers) 173 | 174 | self.num_classes = num_classes 175 | if self.num_classes is not None: 176 | self.classifier = nn.Sequential( 177 | nn.Linear(exp_size, output_channel), 178 | h_swish(), 179 | nn.Dropout(0.2), 180 | nn.Linear(output_channel, num_classes), 181 | ) 182 | 183 | self._initialize_weights() 184 | 185 | def forward(self, x): 186 | # x = self.features(x) 187 | # x = self.conv(x) 188 | # x = self.avgpool(x) 189 | # x = x.view(x.size(0), -1) 190 | # x = self.classifier(x) 191 | 192 | # Stage1 193 | x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x) 194 | # Stage2 195 | x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x) 196 | # Stage3 197 | x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x) 198 | # Stage4 199 | x = reduce(lambda x, n: self.features[n](x), list(range(7,13)), x) 200 | # Stage5 201 | x = reduce(lambda x, n: self.features[n](x), list(range(13,17)), x) 202 | 203 | # Classification 204 | if self.num_classes is not None: 205 | x = x.mean(dim=(2,3)) 206 | x = self.classifier(x) 207 | 208 | # Output 209 | return x 210 | 211 | def _load_pretrained_model(self, pretrained_file): 212 | pretrain_dict = torch.load(pretrained_file, map_location='cpu') 213 | model_dict = {} 214 | state_dict = self.state_dict() 215 | print("[MobileNetV2] Loading pretrained model...") 216 | for k, v in pretrain_dict.items(): 217 | if k in state_dict: 218 | model_dict[k] = v 219 | else: 220 | print(k, "is ignored") 221 | state_dict.update(model_dict) 222 | self.load_state_dict(state_dict) 223 | return x 224 | 225 | def _initialize_weights(self): 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 229 | m.weight.data.normal_(0, math.sqrt(2. / n)) 230 | if m.bias is not None: 231 | m.bias.data.zero_() 232 | elif isinstance(m, nn.BatchNorm2d): 233 | m.weight.data.fill_(1) 234 | m.bias.data.zero_() 235 | elif isinstance(m, nn.Linear): 236 | m.weight.data.normal_(0, 0.01) 237 | m.bias.data.zero_() 238 | 239 | 240 | def mobilenetv3_large(**kwargs): 241 | """ 242 | Constructs a MobileNetV3-Large model 243 | """ 244 | cfgs = [ 245 | # k, t, c, SE, HS, s 246 | [3, 1, 16, 0, 0, 1], 247 | [3, 4, 24, 0, 0, 2], 248 | [3, 3, 24, 0, 0, 1], 249 | [5, 3, 40, 1, 0, 2], 250 | [5, 3, 40, 1, 0, 1], 251 | [5, 3, 40, 1, 0, 1], 252 | [3, 6, 80, 0, 1, 2], 253 | [3, 2.5, 80, 0, 1, 1], 254 | [3, 2.3, 80, 0, 1, 1], 255 | [3, 2.3, 80, 0, 1, 1], 256 | [3, 6, 112, 1, 1, 1], 257 | [3, 6, 112, 1, 1, 1], 258 | [5, 6, 160, 1, 1, 2], 259 | [5, 6, 160, 1, 1, 1], 260 | [5, 6, 160, 1, 1, 1] 261 | ] 262 | return MobileNetV3(cfgs, mode='large', **kwargs) 263 | 264 | 265 | def mobilenetv3_small(**kwargs): 266 | """ 267 | Constructs a MobileNetV3-Small model 268 | """ 269 | cfgs = [ 270 | # k, t, c, SE, HS, s 271 | [3, 1, 16, 1, 0, 2], 272 | [3, 4.5, 24, 0, 0, 2], 273 | [3, 3.67, 24, 0, 0, 1], 274 | [5, 4, 40, 1, 1, 2], 275 | [5, 6, 40, 1, 1, 1], 276 | [5, 6, 40, 1, 1, 1], 277 | [5, 3, 48, 1, 1, 1], 278 | [5, 3, 48, 1, 1, 1], 279 | [5, 6, 96, 1, 1, 2], 280 | [5, 6, 96, 1, 1, 1], 281 | [5, 6, 96, 1, 1, 1], 282 | ] 283 | 284 | return MobileNetV3(cfgs, mode='small', **kwargs) 285 | 286 | 287 | if __name__ == "__main__": 288 | model = MobileNetV3(3) 289 | print(model) 290 | print(len(model.features)) 291 | -------------------------------------------------------------------------------- /model/backbones/resnet_bn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.nn import BatchNorm2d 6 | # from modules.nn import BatchNorm2d 7 | from collections import OrderedDict 8 | 9 | try: 10 | from torch.hub import load_state_dict_from_url 11 | except ImportError: 12 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 13 | 14 | __all__ = ['ResNet'] 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 27 | } 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1): 31 | "3x3 convolution with padding" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=1, bias=False) 34 | 35 | def conv7x7(in_planes, out_planes, stride=1): 36 | "3x3 convolution with padding" 37 | return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride, 38 | padding=3, bias=False) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, inplanes, planes, stride=1, downsample=None): 45 | super(BasicBlock, self).__init__() 46 | self.conv1 = conv3x3(inplanes, planes, stride) 47 | self.bn1 = BatchNorm2d(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3(planes, planes) 50 | self.bn2 = BatchNorm2d(planes) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | residual = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | if self.downsample is not None: 65 | residual = self.downsample(x) 66 | 67 | out += residual 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | expansion = 4 75 | 76 | def __init__(self, inplanes, planes, stride=1, downsample=None): 77 | super(Bottleneck, self).__init__() 78 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 79 | self.bn1 = BatchNorm2d(planes) 80 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 81 | padding=1, bias=False) 82 | self.bn2 = BatchNorm2d(planes, momentum=0.01) 83 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 84 | self.bn3 = BatchNorm2d(planes * 4) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | residual = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | residual = self.downsample(x) 105 | 106 | out += residual 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, block, layers, num_classes=1000, inplanes=128, conv7x7=False): 115 | self.inplanes = inplanes 116 | super(ResNet, self).__init__() 117 | self.conv7x7 = conv7x7 118 | if self.conv7x7: 119 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 120 | self.bn1 = BatchNorm2d(64) 121 | self.relu1 = nn.ReLU(inplace=True) 122 | else: 123 | self.conv1 = conv3x3(3, 64, stride=2) 124 | self.bn1 = BatchNorm2d(64) 125 | self.relu1 = nn.ReLU(inplace=True) 126 | self.conv2 = conv3x3(64, 64) 127 | self.bn2 = BatchNorm2d(64) 128 | self.relu2 = nn.ReLU(inplace=True) 129 | self.conv3 = conv3x3(64, 128) 130 | self.bn3 = BatchNorm2d(128) 131 | self.relu3 = nn.ReLU(inplace=True) 132 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True) 133 | 134 | self.layer1 = self._make_layer(block, 64, layers[0]) 135 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 136 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 137 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 138 | self.avgpool = nn.AvgPool2d(7, stride=1) 139 | self.fc = nn.Linear(512 * block.expansion, num_classes) 140 | 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv2d): 143 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 144 | m.weight.data.normal_(0, math.sqrt(2. / n)) 145 | elif isinstance(m, BatchNorm2d): 146 | m.weight.data.fill_(1) 147 | m.bias.data.zero_() 148 | 149 | def _make_layer(self, block, planes, blocks, stride=1): 150 | downsample = None 151 | if stride != 1 or self.inplanes != planes * block.expansion: 152 | downsample = nn.Sequential( 153 | nn.Conv2d(self.inplanes, planes * block.expansion, 154 | kernel_size=1, stride=stride, bias=False), 155 | BatchNorm2d(planes * block.expansion), 156 | ) 157 | 158 | layers = [] 159 | layers.append(block(self.inplanes, planes, stride, downsample)) 160 | self.inplanes = planes * block.expansion 161 | for i in range(1, blocks): 162 | layers.append(block(self.inplanes, planes)) 163 | 164 | return nn.Sequential(*layers) 165 | 166 | def forward(self, x): 167 | if self.conv7x7: 168 | x = self.relu1(self.bn1(self.conv1(x))) 169 | else: 170 | x = self.relu1(self.bn1(self.conv1(x))) 171 | x = self.relu2(self.bn2(self.conv2(x))) 172 | x = self.relu3(self.bn3(self.conv3(x))) 173 | x, indices = self.maxpool(x) 174 | 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | 180 | x = self.avgpool(x) 181 | x = x.view(x.size(0), -1) 182 | x = self.fc(x) 183 | return x 184 | 185 | 186 | def l_resnet50(pretrained=False): 187 | """Constructs a ResNet-50 model. 188 | Args: 189 | pretrained (bool): If True, returns a model pre-trained on ImageNet 190 | """ 191 | model = ResNet(Bottleneck, [3, 4, 6, 3], inplanes=128, conv7x7=False) 192 | if pretrained: 193 | state_dict = torch.load('pretrained_model/resnet50_v1c.pth') 194 | model.load_state_dict(state_dict, strict=True) 195 | return model 196 | 197 | 198 | if __name__ == "__main__": 199 | model = l_resnet50(pretrained=True) 200 | -------------------------------------------------------------------------------- /model/backbones/sparse_resnet_bn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import spconv 5 | 6 | from torch.nn import BatchNorm1d 7 | from collections import OrderedDict 8 | 9 | try: 10 | from torch.hub import load_state_dict_from_url 11 | except ImportError: 12 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 13 | 14 | __all__ = ['ResNet'] 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 27 | } 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1, indice_key=None): 31 | "3x3 convolution with padding" 32 | return spconv.SubMConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False, indice_key=indice_key) 33 | 34 | def conv7x7(in_planes, out_planes, stride=1, indice_key=None): 35 | "3x3 convolution with padding" 36 | return spconv.SubMConv2d(in_planes, out_planes, kernel_size=7, stride=stride, padding=3, bias=False, indice_key=indice_key) 37 | 38 | 39 | class BasicBlock(spconv.SparseModule): 40 | expansion = 1 41 | 42 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, padding=1, 43 | first_indice_key=None, middle_indice_key=None, last_indice_key=None): 44 | 45 | super(BasicBlock, self).__init__() 46 | if stride == 2: 47 | self.conv1 = spconv.SparseConv2d(inplanes, planes, 3, stride, dilation=dilation, padding=padding, bias=False, indice_key=middle_indice_key) 48 | else: 49 | self.conv1 = spconv.SubMConv2d(inplanes, planes, 3, stride, dilation=dilation, padding=padding, bias=False, indice_key=middle_indice_key) 50 | self.bn1 = nn.BatchNorm1d(planes) 51 | self.relu1 = nn.ReLU(inplace=True) 52 | self.conv2 = spconv.SubMConv2d(planes, planes, 3, 1, padding=1, indice_key=last_indice_key) 53 | self.bn2 = nn.BatchNorm1d(planes) 54 | 55 | self.relu = nn.ReLU(inplace=True) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | residual = x 61 | out = self.conv1(x) 62 | out.features = self.bn1(out.features) 63 | out.features = self.relu1(out.features) 64 | out = self.conv2(out) 65 | out.features = self.bn2(out.features) 66 | if self.downsample is not None: 67 | residual = self.downsample(x) 68 | out.features = out.features + residual.features 69 | out.features = self.relu(out.features) 70 | return out 71 | 72 | 73 | class Bottleneck(spconv.SparseModule): 74 | expansion = 4 75 | 76 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, padding=1, 77 | first_indice_key=None, middle_indice_key=None, last_indice_key=None): 78 | 79 | super(Bottleneck, self).__init__() 80 | self.conv1 = spconv.SubMConv2d(inplanes, planes, kernel_size=1, bias=False, indice_key=first_indice_key) 81 | self.bn1 = nn.BatchNorm1d(planes) 82 | if stride == 2: 83 | self.conv2 = spconv.SparseConv2d(planes, planes, 3, stride=stride, dilation=dilation, padding=padding, bias=False, 84 | indice_key=middle_indice_key) 85 | else: 86 | self.conv2 = spconv.SubMConv2d(planes, planes, 3, stride=stride, dilation=dilation, padding=padding, bias=False, 87 | indice_key=middle_indice_key) 88 | self.bn2 = nn.BatchNorm1d(planes, momentum=0.01) 89 | self.conv3 = spconv.SubMConv2d(planes, planes * 4, kernel_size=1, bias=False, indice_key=last_indice_key) 90 | self.bn3 = nn.BatchNorm1d(planes * 4) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | residual = x 97 | 98 | out = self.conv1(x) 99 | out.features = self.bn1(out.features) 100 | out.features = self.relu(out.features) 101 | 102 | out = self.conv2(out) 103 | out.features = self.bn2(out.features) 104 | out.features = self.relu(out.features) 105 | 106 | out = self.conv3(out) 107 | out.features = self.bn3(out.features) 108 | 109 | if self.downsample is not None: 110 | residual = self.downsample(x) 111 | 112 | out.features = out.features + residual.features 113 | out.features = self.relu(out.features) 114 | return out 115 | 116 | 117 | class SparseResNet18(spconv.SparseModule): 118 | 119 | def __init__(self, inc, stride, block, layers, num_classes=1000, inplanes=128, conv7x7=False): 120 | self.inplanes = inplanes 121 | super(SparseResNet18, self).__init__() 122 | 123 | self.enc_channels = [64, 64, 128, 256, 512] 124 | 125 | self.conv1 = spconv.SubMConv2d(inc, 64, kernel_size=3, padding=1, bias=False, indice_key='subm0s') 126 | self.bn1 = nn.BatchNorm1d(64) 127 | self.relu1 = nn.ReLU(inplace=True) 128 | self.conv2 = spconv.SparseConv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False, indice_key='spconv0') 129 | self.bn2 = nn.BatchNorm1d(64) 130 | self.relu2 = nn.ReLU(inplace=True) 131 | self.conv3 = spconv.SubMConv2d(64, 64, kernel_size=3, padding=1, bias=False, indice_key='subm0e') 132 | self.bn3 = nn.BatchNorm1d(64) 133 | self.relu3 = nn.ReLU(inplace=True) 134 | 135 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2, dilation=1, padding=1, 136 | first_indice_key='subm1s', middle_indice_key='spconv1', last_indice_key='subm1e') 137 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilation=1, padding=1, 138 | first_indice_key='subm2s', middle_indice_key='spconv2', last_indice_key='subm2e') 139 | self.layer3 = self._make_layer(block, 256, layers[2], stride=int(max(1,stride/8)), dilation=1, padding=1, 140 | first_indice_key='subm3s', middle_indice_key='spconv3', last_indice_key='subm3e') 141 | self.layer4 = self._make_layer(block, 512, layers[3], stride=int(max(1,stride/16)), dilation=2, padding=2, 142 | first_indice_key='subm4s', middle_indice_key='spconv4', last_indice_key='subm4e') 143 | 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 147 | m.weight.data.normal_(0, math.sqrt(2. / n)) 148 | elif isinstance(m, BatchNorm1d): 149 | m.weight.data.fill_(1) 150 | m.bias.data.zero_() 151 | 152 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, padding=1, 153 | first_indice_key=None, middle_indice_key=None, last_indice_key=None): 154 | downsample = None 155 | if stride != 1 or self.inplanes != planes * block.expansion: 156 | if stride == 2: 157 | downsample = spconv.SparseSequential( 158 | spconv.SparseConv2d(self.inplanes, planes * block.expansion, kernel_size=3, stride=stride, padding=1, 159 | bias=False, indice_key=middle_indice_key), 160 | nn.BatchNorm1d(planes * block.expansion), 161 | ) 162 | else: 163 | downsample = spconv.SparseSequential( 164 | spconv.SubMConv2d(self.inplanes, planes * block.expansion, kernel_size=3, stride=stride, padding=1, 165 | bias=False, indice_key=middle_indice_key), 166 | nn.BatchNorm1d(planes * block.expansion), 167 | ) 168 | layers = [] 169 | layers.append(block(self.inplanes, planes, stride, downsample, dilation, padding, 170 | first_indice_key=first_indice_key, middle_indice_key=middle_indice_key, last_indice_key=last_indice_key)) 171 | self.inplanes = planes * block.expansion 172 | for i in range(1, blocks): 173 | layers.append(block(self.inplanes, planes)) 174 | return spconv.SparseSequential(*layers) 175 | 176 | def forward(self, x): 177 | outs = [] 178 | x = self.conv1(x) 179 | x.features = self.relu1(self.bn1(x.features)) 180 | x = self.conv2(x) 181 | x.features = self.relu2(self.bn2(x.features)) 182 | x = self.conv3(x) 183 | x.features = self.relu3(self.bn3(x.features)) 184 | outs.append(x) 185 | 186 | x = self.layer1(x) 187 | outs.append(x) 188 | 189 | x = self.layer2(x) 190 | outs.append(x) 191 | 192 | x = self.layer3(x) 193 | outs.append(x) 194 | 195 | x = self.layer4(x) 196 | outs.append(x) 197 | return outs 198 | 199 | 200 | class SparseResNet(spconv.SparseModule): 201 | 202 | def __init__(self, inc, stride, block, layers, num_classes=1000, inplanes=128, conv7x7=False): 203 | self.inplanes = inplanes 204 | super(SparseResNet, self).__init__() 205 | 206 | self.enc_channels = [128, 256, 512, 1024, 2048] 207 | 208 | self.conv1 = spconv.SubMConv2d(inc, 64, kernel_size=3, padding=1, bias=False, indice_key='subm0s') 209 | self.bn1 = nn.BatchNorm1d(64) 210 | self.relu1 = nn.ReLU(inplace=True) 211 | self.conv2 = spconv.SparseConv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False, indice_key='spconv0') 212 | self.bn2 = nn.BatchNorm1d(64) 213 | self.relu2 = nn.ReLU(inplace=True) 214 | self.conv3 = spconv.SubMConv2d(64, 128, kernel_size=3, padding=1, bias=False, indice_key='subm0e') 215 | self.bn3 = nn.BatchNorm1d(128) 216 | self.relu3 = nn.ReLU(inplace=True) 217 | 218 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2, dilation=1, padding=1, 219 | first_indice_key='subm1s', middle_indice_key='spconv1', last_indice_key='subm1e') 220 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilation=1, padding=1, 221 | first_indice_key='subm2s', middle_indice_key='spconv2', last_indice_key='subm2e') 222 | self.layer3 = self._make_layer(block, 256, layers[2], stride=int(max(1,stride/8)), dilation=1, padding=1, 223 | first_indice_key='subm3s', middle_indice_key='spconv3', last_indice_key='subm3e') 224 | self.layer4 = self._make_layer(block, 512, layers[3], stride=int(max(1,stride/16)), dilation=2, padding=2, 225 | first_indice_key='subm4s', middle_indice_key='spconv4', last_indice_key='subm4e') 226 | 227 | for m in self.modules(): 228 | if isinstance(m, nn.Conv2d): 229 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 230 | m.weight.data.normal_(0, math.sqrt(2. / n)) 231 | elif isinstance(m, BatchNorm1d): 232 | m.weight.data.fill_(1) 233 | m.bias.data.zero_() 234 | 235 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, padding=1, 236 | first_indice_key=None, middle_indice_key=None, last_indice_key=None): 237 | downsample = None 238 | if stride != 1 or self.inplanes != planes * block.expansion: 239 | if stride == 2: 240 | downsample = spconv.SparseSequential( 241 | spconv.SparseConv2d(self.inplanes, planes * block.expansion, kernel_size=3, stride=stride, padding=1, 242 | bias=False, indice_key=middle_indice_key), 243 | nn.BatchNorm1d(planes * block.expansion), 244 | ) 245 | else: 246 | downsample = spconv.SparseSequential( 247 | spconv.SubMConv2d(self.inplanes, planes * block.expansion, kernel_size=3, stride=stride, padding=1, 248 | bias=False, indice_key=middle_indice_key), 249 | nn.BatchNorm1d(planes * block.expansion), 250 | ) 251 | layers = [] 252 | layers.append(block(self.inplanes, planes, stride, downsample, dilation, padding, 253 | first_indice_key=first_indice_key, middle_indice_key=middle_indice_key, last_indice_key=last_indice_key)) 254 | self.inplanes = planes * block.expansion 255 | for i in range(1, blocks): 256 | layers.append(block(self.inplanes, planes)) 257 | return spconv.SparseSequential(*layers) 258 | 259 | def forward(self, x): 260 | outs = [] 261 | x = self.conv1(x) 262 | x.features = self.relu1(self.bn1(x.features)) 263 | x = self.conv2(x) 264 | x.features = self.relu2(self.bn2(x.features)) 265 | x = self.conv3(x) 266 | x.features = self.relu3(self.bn3(x.features)) 267 | outs.append(x) 268 | 269 | x = self.layer1(x) 270 | outs.append(x) 271 | 272 | x = self.layer2(x) 273 | outs.append(x) 274 | 275 | x = self.layer3(x) 276 | outs.append(x) 277 | 278 | x = self.layer4(x) 279 | outs.append(x) 280 | return outs 281 | 282 | 283 | def l_sparse_resnet18(inc, stride=8, pretrained=False): 284 | """Constructs a ResNet-50 model. 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | """ 288 | model = SparseResNet18(inc, stride, BasicBlock, [2, 2, 2, 2], inplanes=128, conv7x7=True) 289 | if pretrained: 290 | state_dict = torch.load('pretrained_model/resnet18.pth') 291 | model.load_state_dict(state_dict, strict=True) 292 | return model 293 | 294 | 295 | def l_sparse_resnet50(inc, stride=8, pretrained=False): 296 | """Constructs a ResNet-50 model. 297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on ImageNet 299 | """ 300 | model = SparseResNet(inc, stride, Bottleneck, [3, 4, 6, 3], inplanes=128, conv7x7=False) 301 | if pretrained: 302 | state_dict = torch.load('pretrained_model/resnet50_v1c.pth') 303 | model.load_state_dict(state_dict, strict=True) 304 | return model 305 | 306 | 307 | if __name__ == "__main__": 308 | model = ResNet(Bottleneck, [3, 4, 6, 3], inplanes=128, conv7x7=False) 309 | print(model) 310 | -------------------------------------------------------------------------------- /model/backbones/wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import reduce 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from model.utils import load_pretrained_weight 9 | from .mobilenetv2 import MobileNetV2 10 | from .mobilenetv3 import MobileNetV3 11 | 12 | 13 | class BaseBackbone(nn.Module): 14 | """ Superclass of Replaceable Backbone Model for Semantic Estimation 15 | """ 16 | 17 | def __init__(self, in_channels): 18 | super(BaseBackbone, self).__init__() 19 | self.in_channels = in_channels 20 | 21 | self.model = None 22 | self.enc_channels = [] 23 | 24 | def forward(self, x): 25 | raise NotImplementedError 26 | 27 | def load_pretrained_ckpt(self): 28 | raise NotImplementedError 29 | 30 | 31 | class MobileNetV2Backbone(BaseBackbone): 32 | """ MobileNetV2 Backbone 33 | """ 34 | 35 | def __init__(self, in_channels, with_norm=True): 36 | super(MobileNetV2Backbone, self).__init__(in_channels) 37 | 38 | self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None, with_norm=with_norm) 39 | self.enc_channels = [16, 24, 32, 96, 1280] 40 | 41 | def forward(self, x): 42 | x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) 43 | enc2x = x 44 | x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) 45 | enc4x = x 46 | x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) 47 | enc8x = x 48 | x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x) 49 | enc16x = x 50 | x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x) 51 | enc32x = x 52 | return [enc2x, enc4x, enc8x, enc16x, enc32x] 53 | 54 | def load_pretrained_ckpt(self): 55 | # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch 56 | ckpt_path = './pretrained_model/mobilenetv2_human_seg.ckpt' 57 | self.model = load_pretrained_weight(self.model, ckpt_path) 58 | print('load pretrained weight from {} successfully'.format(ckpt_path)) 59 | 60 | 61 | class MobileNetV3LargeBackbone(BaseBackbone): 62 | """ MobileNetV2 Backbone 63 | """ 64 | 65 | def __init__(self, in_channels, with_norm=True): 66 | super(MobileNetV3LargeBackbone, self).__init__(in_channels) 67 | 68 | self.model = MobileNetV3(self.in_channels, num_classes=None, with_norm=with_norm) 69 | self.enc_channels = [16, 24, 40, 112, 1280] 70 | 71 | def forward(self, x, priors=None): 72 | x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) 73 | enc2x = x 74 | x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) 75 | enc4x = x 76 | x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) 77 | enc8x = x 78 | x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 13)), x) 79 | enc16x = x 80 | x = reduce(lambda x, n: self.model.features[n](x), list(range(13, 17)), x) 81 | enc32x = x 82 | return [enc2x, enc4x, enc8x, enc16x, enc32x] 83 | -------------------------------------------------------------------------------- /model/lap_pyramid_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def gauss_kernel(size=5, device=torch.device('cpu'), channels=3): 4 | kernel = torch.tensor([[1., 4., 6., 4., 1], 5 | [4., 16., 24., 16., 4.], 6 | [6., 24., 36., 24., 6.], 7 | [4., 16., 24., 16., 4.], 8 | [1., 4., 6., 4., 1.]]) 9 | kernel /= 256. 10 | kernel = kernel.repeat(channels, 1, 1, 1) 11 | kernel = kernel.to(device) 12 | return kernel 13 | 14 | def downsample(x): 15 | return x[:, :, ::2, ::2] 16 | 17 | def upsample(x): 18 | cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3) 19 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3]) 20 | cc = cc.permute(0,1,3,2) 21 | cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]*2, device=x.device)], dim=3) 22 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3]*2) 23 | x_up = cc.permute(0,1,3,2) 24 | return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1], device=x.device)) 25 | 26 | def conv_gauss(img, kernel): 27 | img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect') 28 | out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1]) 29 | return out 30 | 31 | def laplacian_pyramid(img, kernel, max_levels=3): 32 | current = img 33 | pyr = [] 34 | for level in range(max_levels): 35 | filtered = conv_gauss(current, kernel) 36 | down = downsample(filtered) 37 | up = upsample(down) 38 | diff = current-up 39 | pyr.append(diff) 40 | current = down 41 | return pyr 42 | 43 | class LapLoss(torch.nn.Module): 44 | def __init__(self, max_levels=3, channels=3, device=torch.device('cpu')): 45 | super(LapLoss, self).__init__() 46 | self.max_levels = max_levels 47 | self.gauss_kernel = gauss_kernel(channels=channels, device=device) 48 | 49 | def forward(self, input, target): 50 | pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels) 51 | pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels) 52 | return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target)) 53 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from model.lap_pyramid_loss import LapLoss 8 | 9 | 10 | def matting_loss(p, d, mask=None, with_lap=False): 11 | assert p.shape == d.shape 12 | 13 | if mask is not None: 14 | loss = torch.sqrt((p - d) ** 2 + 1e-10) * mask 15 | loss = loss.sum() / (mask.sum() + 1) 16 | else: 17 | loss = torch.sqrt((p - d) ** 2 + 1e-10) 18 | loss = loss.mean() 19 | 20 | if with_lap: 21 | lap_loss = LapLoss(5, device=torch.device('cuda')) 22 | loss = loss + lap_loss(p, d) 23 | return loss 24 | 25 | 26 | def composition_loss(alpha, img, fg, bg, mask): 27 | comp = alpha * fg + (1. - alpha) * bg 28 | diff = (comp - img) * mask 29 | loss = torch.sqrt(diff ** 2 + 1e-12) 30 | loss = loss.sum() / (mask.sum() + 1.) / 3. 31 | return loss 32 | 33 | 34 | def losses(pred_list, input_dict, alpha_loss_weights=[1.0, 1.0, 1.0, 1.0], with_composition_loss=False, composition_loss_weight=1.0): 35 | label = input_dict['hr_label'] 36 | mask = input_dict['hr_unknown'] 37 | 38 | loss_dict = {} 39 | 40 | alpha_loss = 0. 41 | for i, pred in enumerate(pred_list): 42 | stride = label.size(2) / pred.size(2) 43 | pred = F.interpolate(pred, scale_factor=stride, mode='bilinear', align_corners=False) 44 | alpha_loss += matting_loss(pred, label, mask, with_lap=True) * alpha_loss_weights[i] 45 | loss_dict['alpha_loss'] = alpha_loss 46 | 47 | if with_composition_loss: 48 | comp_loss = composition_loss(pred_list[-1], input_dict['hr_image'], 49 | input_dict['hr_fg'], input_dict['hr_bg'], mask) * composition_loss_weight 50 | loss_dict['comp_loss'] = comp_loss 51 | 52 | loss = 0. 53 | for k, v in loss_dict.items(): 54 | if k.endswith('loss'): 55 | loss += v 56 | loss_dict['loss'] = loss 57 | return loss_dict 58 | -------------------------------------------------------------------------------- /model/lpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from model.backbones import MobileNetV2Backbone 6 | from model.utils import upas 7 | 8 | 9 | class IBNorm(nn.Module): 10 | """ Combine Instance Norm and Batch Norm into One Layer 11 | """ 12 | 13 | def __init__(self, in_channels): 14 | super(IBNorm, self).__init__() 15 | in_channels = in_channels 16 | self.bnorm_channels = int(in_channels / 2) 17 | self.inorm_channels = in_channels - self.bnorm_channels 18 | 19 | self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) 20 | self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) 21 | 22 | def forward(self, x): 23 | bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) 24 | n, c, h, w = bn_x.shape 25 | if n==1 and h==1 and w==1: 26 | in_x = self.inorm(x[:, self.inorm_channels:, ...].contiguous().expand(n*2, c, h, w).contiguous())[0:1] 27 | else: 28 | in_x = self.inorm(x[:, self.inorm_channels:, ...].contiguous()) 29 | return torch.cat((bn_x, in_x), 1) 30 | 31 | 32 | class Conv2dIBNormRelu(nn.Module): 33 | """ Convolution + IBNorm + ReLu 34 | """ 35 | 36 | def __init__(self, in_channels, out_channels, kernel_size, 37 | stride=1, padding=0, dilation=1, groups=1, bias=True, 38 | with_ibn=True, with_relu=True): 39 | super(Conv2dIBNormRelu, self).__init__() 40 | 41 | layers = [ 42 | nn.Conv2d(in_channels, out_channels, kernel_size, 43 | stride=stride, padding=padding, dilation=dilation, 44 | groups=groups, bias=bias) 45 | ] 46 | 47 | if with_ibn: 48 | layers.append(IBNorm(out_channels)) 49 | if with_relu: 50 | layers.append(nn.ReLU(inplace=True)) 51 | 52 | self.layers = nn.Sequential(*layers) 53 | 54 | def forward(self, x): 55 | return self.layers(x) 56 | 57 | 58 | class SEBlock(nn.Module): 59 | """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf 60 | """ 61 | 62 | def __init__(self, in_channels, out_channels, reduction=1): 63 | super(SEBlock, self).__init__() 64 | self.pool = nn.AdaptiveAvgPool2d(1) 65 | self.fc = nn.Sequential( 66 | nn.Linear(in_channels, int(in_channels // reduction), bias=False), 67 | nn.ReLU(inplace=True), 68 | nn.Linear(int(in_channels // reduction), out_channels, bias=False), 69 | nn.Sigmoid() 70 | ) 71 | 72 | def forward(self, x): 73 | b, c, _, _ = x.size() 74 | w = self.pool(x).view(b, c) 75 | w = self.fc(w).view(b, c, 1, 1) 76 | return x * w.expand_as(x) 77 | 78 | 79 | class HLBranch(nn.Module): 80 | """ High Resolution Branch of MODNet 81 | """ 82 | 83 | def __init__(self, hr_channels, enc_channels, with_norm=True): 84 | super(HLBranch, self).__init__() 85 | 86 | self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) 87 | 88 | self.p32x = Conv2dIBNormRelu(enc_channels[4], 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 89 | 90 | self.conv_dec16x = nn.Sequential( 91 | Conv2dIBNormRelu(enc_channels[4]+enc_channels[3], 2*hr_channels, 3, stride=1, padding=1, with_ibn=with_norm), 92 | Conv2dIBNormRelu(2*hr_channels, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm), 93 | ) 94 | self.p16x = Conv2dIBNormRelu(hr_channels+1, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 95 | 96 | self.conv_dec8x = nn.Sequential( 97 | Conv2dIBNormRelu(hr_channels + enc_channels[2], 2*hr_channels, 3, stride=1, padding=1, with_ibn=with_norm), 98 | Conv2dIBNormRelu(2*hr_channels, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm), 99 | ) 100 | self.p8x = Conv2dIBNormRelu(hr_channels+1, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 101 | 102 | self.conv_dec4x = nn.Sequential( 103 | Conv2dIBNormRelu(hr_channels + enc_channels[1], 2*hr_channels, 3, stride=1, padding=1, with_ibn=with_norm), 104 | Conv2dIBNormRelu(2*hr_channels, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm), 105 | ) 106 | self.p4x = Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 107 | 108 | self.conv_dec2x = nn.Sequential( 109 | Conv2dIBNormRelu(hr_channels+enc_channels[0], 2*hr_channels, 3, stride=1, padding=1, with_ibn=with_norm), 110 | Conv2dIBNormRelu(2*hr_channels, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm), 111 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm), 112 | ) 113 | self.p2x = Conv2dIBNormRelu(hr_channels+1, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 114 | 115 | self.conv_dec1x = nn.Sequential( 116 | Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm), 117 | ) 118 | self.p1x = Conv2dIBNormRelu(hr_channels+1, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 119 | 120 | self.p0x = Conv2dIBNormRelu(2, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 121 | 122 | def forward(self, img, enc2x, enc4x, enc8x, enc16x, enc32x, is_training=True): 123 | enc32x = self.se_block(enc32x) 124 | p32x = self.p32x(enc32x) 125 | p32x = upas(p32x, img) 126 | 127 | dec16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False) 128 | dec16x = self.conv_dec16x(torch.cat((dec16x, enc16x), dim=1)) 129 | p16x = self.p16x(torch.cat((dec16x, upas(p32x, dec16x)), dim=1)) 130 | p16x = upas(p16x, img) 131 | 132 | dec8x = F.interpolate(dec16x, scale_factor=2, mode='bilinear', align_corners=False) 133 | dec8x = self.conv_dec8x(torch.cat((dec8x, enc8x), dim=1)) 134 | p8x = self.p8x(torch.cat((dec8x, upas(p16x, dec8x)), dim=1)) 135 | p8x = upas(p8x, img) 136 | 137 | dec4x = F.interpolate(dec8x, scale_factor=2, mode='bilinear', align_corners=False) 138 | dec4x = self.conv_dec4x(torch.cat((dec4x, enc4x), dim=1)) 139 | p4x = self.p4x(dec4x) 140 | p4x = upas(p4x, img) 141 | 142 | dec2x = F.interpolate(dec4x, scale_factor=2, mode='bilinear', align_corners=False) 143 | dec2x = self.conv_dec2x(torch.cat((dec2x, enc2x), dim=1)) 144 | p2x = self.p2x(torch.cat((dec2x, upas(p4x, dec2x)), dim=1)) 145 | p2x = upas(p2x, img) 146 | 147 | dec1x = F.interpolate(dec2x, scale_factor=2, mode='bilinear', align_corners=False) 148 | dec1x = self.conv_dec1x(torch.cat((dec1x, img), dim=1)) 149 | p1x = self.p1x(torch.cat((dec1x, upas(p2x, dec1x)), dim=1)) 150 | 151 | p0x = self.p0x(torch.cat((p1x, upas(p8x, p1x)), dim=1)) 152 | 153 | seg_out = [torch.sigmoid(p) for p in (p8x, p16x, p32x)] 154 | mat_out = [torch.sigmoid(p) for p in (p1x, p2x, p4x)] 155 | fus_out = [torch.sigmoid(p) for p in (p0x,)] 156 | return seg_out, mat_out, fus_out, [dec1x, dec2x, dec4x, dec8x, dec16x] 157 | 158 | 159 | class AuxilaryHead(nn.Module): 160 | def __init__(self, hr_channels, enc_channels): 161 | super().__init__() 162 | 163 | self.s1 = Conv2dIBNormRelu( 164 | hr_channels, 3, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 165 | self.s2 = Conv2dIBNormRelu( 166 | hr_channels, 3, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 167 | self.s4 = Conv2dIBNormRelu( 168 | hr_channels, 3, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 169 | self.s8 = Conv2dIBNormRelu( 170 | hr_channels, 3, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 171 | self.s16 = Conv2dIBNormRelu( 172 | hr_channels, 3, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False) 173 | 174 | def forward(self, dec1x, dec2x, dec4x, dec8x, dec16x, is_training=True): 175 | p1 = self.s1(dec1x) 176 | 177 | x2 = self.s2(dec2x) 178 | p2 = F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=False) 179 | 180 | x4 = self.s4(dec4x) 181 | p4 = F.interpolate(x4, scale_factor=4, mode='bilinear', align_corners=False) 182 | 183 | x8 = self.s8(dec8x) 184 | p8 = F.interpolate(x8, scale_factor=8, mode='bilinear', align_corners=False) 185 | 186 | x16 = self.s16(dec16x) 187 | p16 = F.interpolate(x16, scale_factor=16, mode='bilinear', align_corners=False) 188 | 189 | return (p1,p2,p4,p8,p16) 190 | 191 | 192 | class LPN(nn.Module): 193 | def __init__(self, in_chn=3, mid_chn=128): 194 | super().__init__() 195 | self.backbone = MobileNetV2Backbone(in_chn) 196 | self.decoder = HLBranch(mid_chn, self.backbone.enc_channels) 197 | self.aux_head = AuxilaryHead(mid_chn, self.backbone.enc_channels) 198 | 199 | def forward(self, images): 200 | enc2x, enc4x, enc8x, enc16x, enc32x = self.backbone(images) 201 | seg_outs, mat_outs, fus_outs, decoded_feats = self.decoder(images, enc2x, enc4x, enc8x, enc16x, enc32x) 202 | return fus_outs[0], decoded_feats[-1] 203 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from model.utils import upas, batch_slice 9 | from model.lpn import LPN 10 | from model.shm import SHM 11 | 12 | 13 | class SparseMat(nn.Module): 14 | def __init__(self, cfg): 15 | super(SparseMat, self).__init__() 16 | self.cfg = cfg 17 | in_ch = cfg.model.in_channel 18 | hr_ch = cfg.model.hr_channel 19 | self.lpn = LPN(in_ch, hr_ch) 20 | self.shm = SHM(inc=4) 21 | self.stride = cfg.model.dilation_kernel 22 | self.dilate_op = nn.MaxPool2d(self.stride, stride=1, padding=self.stride//2) 23 | self.max_n_pixel = cfg.model.max_n_pixel 24 | self.cascade = cfg.test.cascade 25 | 26 | @torch.no_grad() 27 | def generate_sparse_inputs(self, img, lr_pred, mask): 28 | lr_pred = (lr_pred - 0.5) / 0.5 29 | x = torch.cat((img, lr_pred), dim=1) 30 | indices = torch.where(mask.squeeze(1)>0) 31 | x = x.permute(0,2,3,1) 32 | x = x[indices] 33 | indices = torch.stack(indices, dim=1) 34 | return x, indices 35 | 36 | def dilate(self, alpha, stride=15): 37 | mask = torch.logical_and(alpha>0.01, alpha<0.99).float() 38 | mask = self.dilate_op(mask) 39 | return mask 40 | 41 | def forward(self, input_dict): 42 | xlr = input_dict['lr_image'] 43 | xhr = input_dict['hr_image'] 44 | 45 | lr_pred, ctx = self.lpn(xlr) 46 | lr_pred = lr_pred.clone().detach() 47 | ctx = ctx.clone().detach() 48 | 49 | lr_pred = batch_slice(lr_pred, input_dict['pos'], xhr.size()[2:]) 50 | lr_pred = upas(lr_pred, xhr) 51 | if 'hr_unknown' in input_dict: 52 | mask = input_dict['hr_unknown'] 53 | else: 54 | mask = self.dilate(lr_pred) 55 | 56 | sparse_inputs, coords = self.generate_sparse_inputs(xhr, lr_pred, mask=mask) 57 | pred_list = self.shm(sparse_inputs, lr_pred, coords, xhr.size(0), mask.size()[2:], ctx=ctx) 58 | return pred_list 59 | 60 | def generate_sparsity_map(self, lr_pred, curr_img, last_img): 61 | mask_s = self.dilate(lr_pred) 62 | if last_img is not None: 63 | diff = (curr_img - last_img).abs().mean(dim=1, keepdim=True) 64 | shared = torch.logical_and( 65 | F.conv2d(diff, torch.ones(1,1,9,9,device=diff.device), padding=4) < 0.05, 66 | F.conv2d(diff, torch.ones(1,1,1,1,device=diff.device), padding=0) < 0.001, 67 | ).float() 68 | mask_t = self.dilate_op(1 - shared) 69 | mask = mask_s * mask_t 70 | mask = self.dilate_op(mask) 71 | else: 72 | shared = torch.zeros_like(mask_s) 73 | mask_t = torch.ones_like(mask_s) 74 | mask = mask_s * mask_t 75 | return mask, mask_s, mask_t, shared 76 | 77 | def inference(self, hr_img, lr_img=None, last_img=None, last_pred=None): 78 | h, w = hr_img.shape[-2:] 79 | 80 | if lr_img is None: 81 | nh = 512. / min(h,w) * h 82 | nh = math.ceil(nh / 32) * 32 83 | nw = 512. / min(h,w) * w 84 | nw = math.ceil(nw / 32) * 32 85 | lr_img = F.interpolate(hr_img, (int(nh), int(nw)), mode="bilinear") 86 | 87 | lr_pred, ctx = self.lpn(lr_img) 88 | lr_pred_us = upas(lr_pred, hr_img) 89 | mask, mask_s, mask_t, shared = self.generate_sparsity_map(lr_pred_us, hr_img, last_img) 90 | n_pixel = mask.sum().item() 91 | 92 | if n_pixel <= self.max_n_pixel: 93 | sparse_inputs, coords = self.generate_sparse_inputs(hr_img, lr_pred_us, mask) 94 | preds = self.shm(sparse_inputs, lr_pred_us, coords, hr_img.size(0), mask.size()[2:], ctx=ctx) 95 | hr_pred_sp = preds[-1] 96 | if last_pred is not None: 97 | hr_pred = hr_pred_sp * mask + lr_pred_us * (1-mask) * (1-shared) + last_pred * (1-mask) * shared 98 | else: 99 | hr_pred = hr_pred_sp * mask + lr_pred_us * (1-mask) 100 | elif self.cascade: 101 | print("Cascading is used.") 102 | for scale in [0.25, 0.5, 1.0]: 103 | hr_img_ds = F.interpolate(hr_img, None, scale_factor=scale, mode="bilinear") 104 | lr_pred_us = upas(lr_pred, hr_img_ds) 105 | mask_s = self.dilate(lr_pred_us) 106 | if mask_s.sum() > self.max_n_pixel: 107 | break 108 | sparse_inputs, coords = self.generate_sparse_inputs(hr_img_ds, lr_pred_us, mask_s) 109 | preds = self.shm(sparse_inputs, lr_pred_us, coords, hr_img_ds.size(0), mask_s.size()[2:], ctx=ctx) 110 | hr_pred_sp = preds[-1] 111 | hr_pred = hr_pred_sp * mask_s + lr_pred_us * (1-mask_s) 112 | lr_pred = hr_pred 113 | else: 114 | print("Rescaling is used.") 115 | scale = math.sqrt(self.max_n_pixel / float(n_pixel)) 116 | nh = int(scale * h) 117 | nw = int(scale * w) 118 | nh = math.ceil(nh / 8) * 8 119 | nw = math.ceil(nw / 8) * 8 120 | 121 | hr_img_ds = F.interpolate(hr_img, (nh, nw), mode="bilinear") 122 | lr_pred_us = upas(lr_pred, hr_img_ds) 123 | mask_s = self.dilate(lr_pred_us) 124 | 125 | sparse_inputs, coords = self.generate_sparse_inputs(hr_img_ds, lr_pred_us, mask_s) 126 | preds = self.shm(sparse_inputs, lr_pred_us, coords, hr_img_ds.size(0), mask_s.size()[2:], ctx=ctx) 127 | hr_pred_sp = preds[-1] 128 | hr_pred = hr_pred_sp * mask_s + lr_pred_us * (1-mask_s) 129 | return hr_pred 130 | -------------------------------------------------------------------------------- /model/shm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import spconv 5 | 6 | from model.backbones.sparse_resnet_bn import l_sparse_resnet18 7 | from model.lpn import Conv2dIBNormRelu 8 | 9 | 10 | class SparseCAM(nn.Module): 11 | def __init__(self, local_inc, global_inc, with_norm=True): 12 | super(SparseCAM, self).__init__() 13 | 14 | self.pool_fg = nn.AdaptiveAvgPool2d(1) 15 | self.pool_bg = nn.AdaptiveAvgPool2d(1) 16 | self.conv_f = Conv2dIBNormRelu(global_inc, global_inc, kernel_size=1, with_ibn=False) 17 | self.conv_b = Conv2dIBNormRelu(global_inc, global_inc, kernel_size=1, with_ibn=False) 18 | self.conv_g = Conv2dIBNormRelu(2*global_inc, local_inc, kernel_size=1, with_relu=False, with_ibn=False) 19 | 20 | def forward(self, idx, x, ctx, mask): 21 | mask_lr = F.interpolate(mask, ctx.size()[2:], align_corners=False, mode='bilinear') 22 | fg_pool = self.pool_fg(ctx * mask_lr) 23 | fg_ctx = self.conv_f(fg_pool) 24 | bg_pool = self.pool_bg(ctx * (1-mask_lr)) 25 | bg_ctx = self.conv_b(bg_pool) 26 | weight = torch.sigmoid(self.conv_g(torch.cat([fg_ctx, bg_ctx], dim=1))).squeeze(3).squeeze(2) 27 | sparse_weight = weight[x.indices[:,0].long()] 28 | x.features = x.features * sparse_weight 29 | return x 30 | 31 | 32 | class SparseDecoder3_18(spconv.SparseModule): 33 | def __init__(self, inc=512): 34 | super(SparseDecoder3_18, self).__init__() 35 | 36 | # upconv modules 37 | self.conv_up1 = spconv.SparseSequential( 38 | spconv.SparseInverseConv2d(inc, 256, kernel_size=3, bias=True, indice_key='spconv2'), 39 | nn.BatchNorm1d(256), 40 | nn.LeakyReLU(), 41 | ) 42 | 43 | self.conv_up2 = spconv.SparseSequential( 44 | spconv.SparseInverseConv2d(256 + 64, 256, kernel_size=3, bias=True, indice_key='spconv1'), 45 | nn.BatchNorm1d(256), 46 | nn.LeakyReLU(), 47 | ) 48 | 49 | self.conv_up3 = spconv.SparseSequential( 50 | spconv.SparseInverseConv2d(256 + 64, 64, kernel_size=3, bias=True, indice_key='spconv0'), 51 | nn.BatchNorm1d(64), 52 | nn.LeakyReLU(), 53 | ) 54 | 55 | chn = 64 + 3 56 | 57 | self.conv_up4_alpha = spconv.SparseSequential( 58 | spconv.SubMConv2d(chn, 32, kernel_size=3, padding=1, bias=True, indice_key='subm0s'), 59 | nn.LeakyReLU(), 60 | spconv.SubMConv2d(32, 16, kernel_size=3, padding=1, bias=True, indice_key='subm0s'), 61 | nn.LeakyReLU(), 62 | spconv.SubMConv2d(16, 1, kernel_size=1, padding=0, bias=False, indice_key='subm0s') 63 | ) 64 | 65 | self.conv_p8x = spconv.SubMConv2d(256, 1, kernel_size=1, padding=0, bias=False, indice_key='spconv2') 66 | self.conv_p4x = spconv.SubMConv2d(256, 1, kernel_size=1, padding=0, bias=False, indice_key='spconv1') 67 | self.conv_p2x = spconv.SubMConv2d(64, 1, kernel_size=1, padding=0, bias=False, indice_key='spconv0') 68 | 69 | def forward(self, img, conv_out, coarse=None, is_training=True): 70 | x1, x2, x3, x4, x5 = conv_out 71 | 72 | dec4x = self.conv_up1(x5) 73 | p4x = self.conv_p8x(dec4x) 74 | 75 | dec4x.features = torch.cat((dec4x.features, x2.features), 1) 76 | dec2x = self.conv_up2(dec4x) 77 | p2x = self.conv_p4x(dec2x) 78 | 79 | dec2x.features = torch.cat((dec2x.features, x1.features), 1) 80 | dec1x = self.conv_up3(dec2x) 81 | p1x = self.conv_p2x(dec1x) 82 | 83 | img.features = img.features[:,:3] * 0.5 + 0.5 84 | dec1x.features = torch.cat((dec1x.features, img.features),1) 85 | p0x = self.conv_up4_alpha(dec1x) 86 | 87 | raws = [p4x.dense(), p2x.dense(), p1x.dense(), p0x.dense()] 88 | p4x.features = torch.sigmoid(p4x.features) 89 | p2x.features = torch.sigmoid(p2x.features) 90 | p1x.features = torch.sigmoid(p1x.features) 91 | p0x.features = torch.sigmoid(p0x.features) 92 | outs = [p4x.dense(), p2x.dense(), p1x.dense(), p0x.dense()] 93 | return outs 94 | 95 | 96 | class SHM(nn.Module): 97 | def __init__(self, inc=4): 98 | super(SHM, self).__init__() 99 | 100 | self.ctx = SparseCAM(512, 32, with_norm=True) 101 | self.backbone = l_sparse_resnet18(inc, stride=8) 102 | self.decoder = SparseDecoder3_18() 103 | 104 | def forward(self, inputs, lr_pred, coords, batch_size, spatial_shape, ctx): 105 | x = spconv.SparseConvTensor(inputs, coords.int(), spatial_shape, batch_size) 106 | encoded_feats = self.backbone(x) 107 | encoded_feats[-1] = self.ctx(coords.int(), encoded_feats[-1], ctx, lr_pred) 108 | outs = self.decoder(x, encoded_feats) 109 | return outs 110 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | from scipy.ndimage import morphology 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def _init_conv(conv): 11 | nn.init.xavier_uniform_(conv.weight) 12 | if conv.bias is not None: 13 | nn.init.constant_(conv.bias, 0) 14 | 15 | 16 | def _init_norm(norm): 17 | if norm.weight is not None: 18 | nn.init.constant_(norm.weight, 1) 19 | nn.init.constant_(norm.bias, 0) 20 | 21 | 22 | def _generate_random_trimap(x, dist=(1,30), is_training=True): 23 | fg = (x>0.999).type(torch.float) 24 | un = (x>=0.001).type(torch.float) - fg 25 | un_np = (un*255).squeeze(1).data.cpu().numpy().astype(np.uint8) 26 | if is_training: 27 | thresh = np.random.randint(dist[0], dist[1]) 28 | else: 29 | thresh = (dist[0] + dist[1]) // 2 30 | un_np = [(morphology.distance_transform_edt(item==0) <= thresh) for item in un_np] 31 | un_np = np.array(un_np) 32 | un = torch.from_numpy(un_np).unsqueeze(1).to(x.device) 33 | trimap = fg 34 | trimap[un>0] = 0.5 35 | return trimap 36 | 37 | 38 | def _make_divisible(v, divisor, min_value=None): 39 | """ 40 | This function is taken from the original tf repo. 41 | It ensures that all layers have a channel number that is divisible by 8 42 | It can be seen here: 43 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 44 | """ 45 | if min_value is None: 46 | min_value = divisor 47 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 48 | # Make sure that round down does not go down by more than 10%. 49 | if new_v < 0.9 * v: 50 | new_v += divisor 51 | return new_v 52 | 53 | 54 | ## upsample tensor 'src' to have the same spatial size with tensor 'tar' 55 | def _upsample_like(src,tar,mode='bilinear'): 56 | src = F.interpolate(src,size=tar.shape[2:],mode=mode,align_corners=False if mode=='bilinear' else None) 57 | return src 58 | upas = _upsample_like 59 | 60 | 61 | def batch_slice(tensor, pos, size, mode='bilinear'): 62 | n, c, h, w = tensor.shape 63 | patchs = [] 64 | for i in range(n): 65 | # x1, y1, x2, y2 = torch.clamp(pos[i], 0, 1) 66 | x1, y1, x2, y2 = pos[i] 67 | x1 = int(x1.item() * w) 68 | y1 = int(y1.item() * h) 69 | x2 = int(x2.item() * w) 70 | y2 = int(y2.item() * h) 71 | patch = tensor[i:i+1, :, y1:y2, x1:x2].contiguous() 72 | patch = F.interpolate(patch, (size[0], size[1]), align_corners=False if mode=='bilinear' else None, mode=mode) 73 | patchs.append(patch) 74 | return torch.cat(patchs, dim=0) 75 | 76 | 77 | def hard_sigmoid(x, inplace: bool = False): 78 | if inplace: 79 | return x.add_(3.).clamp_(0., 6.).div_(6.) 80 | else: 81 | return F.relu6(x + 3.) / 6. 82 | 83 | 84 | ## copy weight from old tensor to new tensor 85 | def copy_weight(ws, wd): 86 | 87 | assert len(ws.shape)==4 or len(ws.shape)==1 88 | 89 | if len(ws.shape) == 4 and ws.shape[2]==ws.shape[3] and ws.shape[3]<=7: 90 | cout1, cin1, kh, kw = ws.shape 91 | cout2, cin2, kh, kw = wd.shape 92 | weight = torch.zeros((cout2, cin2, kh, kw)).float().to(ws.device) 93 | cout3 = min(cout1, cout2) 94 | cin3 = min(cin1, cin2) 95 | weight[:cout3, :cin3] = ws[:cout3, :cin3] 96 | elif len(ws.shape) == 4: 97 | kh, kw, cin1, cout1 = ws.shape # (3,3,4,64) 98 | kh, kw, cin2, cout2 = wd.shape 99 | print(ws.shape, wd.shape) 100 | weight = torch.zeros((kh, kw, cin2, cout2)).float().to(ws.device) 101 | cout3 = min(cout1, cout2) 102 | cin3 = min(cin1, cin2) 103 | weight[:, :, :cin3, :cout3] = ws[:, :, :cin3, :cout3] 104 | else: 105 | cout1, = ws.shape 106 | cout2, = wd.shape 107 | cout3 = min(cout1, cout2) 108 | weight = torch.zeros((cout3,)).float().to(ws.device) 109 | weight[:cout3] = ws[:cout3] 110 | return weight 111 | 112 | 113 | ## only works for models with same architecture 114 | def load_pretrained_weight(model, ckpt_path, copy=True): 115 | ckpt = torch.load(ckpt_path) 116 | filtered_ckpt = OrderedDict() 117 | model_ckpt = model.state_dict() 118 | for k,v in ckpt.items(): 119 | if k in model_ckpt: 120 | if v.shape==model_ckpt[k].shape: 121 | filtered_ckpt[k] = v 122 | elif copy: 123 | filtered_ckpt[k] = copy_weight(v, model_ckpt[k]) 124 | model.load_state_dict(filtered_ckpt, strict=False) 125 | return model 126 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import cv2 5 | from collections import OrderedDict 6 | from torchvision import transforms 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.utils.data import Dataset, DataLoader 12 | 13 | from model import SparseMat 14 | from utils import load_config, get_logger 15 | from datasets import RescaleT, ToTensor, CustomDataset 16 | 17 | 18 | def get_timestamp(): 19 | from datetime import datetime 20 | now = datetime.now() 21 | dt_string = now.strftime("%Y-%m-%d-%H-%M-%S") 22 | return dt_string 23 | 24 | 25 | def load_checkpoint(net, pretrained_model, logger): 26 | net_state_dict = net.state_dict() 27 | state_dict = torch.load(pretrained_model) 28 | if 'state_dict' in state_dict: 29 | state_dict = state_dict['state_dict'] 30 | elif 'model_state_dict' in state_dict: 31 | state_dict = state_dict['model_state_dict'] 32 | 33 | filtered_state_dict = OrderedDict() 34 | for k,v in state_dict.items(): 35 | if k.startswith('module'): 36 | nk = '.'.join(k.split('.')[1:]) 37 | else: 38 | nk = k 39 | filtered_state_dict[nk] = v 40 | net.load_state_dict(filtered_state_dict) 41 | logger.info('load pretrained weight from {} successfully'.format(pretrained_model)) 42 | 43 | 44 | def load_test_filelist(test_data_path): 45 | test_images = [] 46 | test_labels = [] 47 | for line in open(test_data_path).read().splitlines(): 48 | splits = line.split(',') 49 | img_path, mat_path = splits 50 | test_labels.append(mat_path) 51 | test_images.append(img_path) 52 | return test_images, test_labels 53 | 54 | 55 | def compute_metrics(pred, gt): 56 | assert pred.size(0)==1 and pred.size(1)==1 57 | if pred.shape[2:] != gt.shape[2:]: 58 | pred = F.interpolate(pred, gt.shape[2:], mode='bilinear', align_corners=False) 59 | mad = (pred-gt).abs().mean() 60 | mse = ((pred-gt)**2).mean() 61 | return mse, mad 62 | 63 | 64 | def save_preds(pred, save_dir, filename): 65 | os.makedirs(save_dir, exist_ok=True) 66 | pred = pred.squeeze().data.cpu().numpy() * 255 67 | imgname = filename.split('/')[-1].split('.')[0] + '.png' 68 | cv2.imwrite(os.path.join(save_dir, imgname), pred) 69 | 70 | 71 | def test(cfg, net, dataloader, filenames, logger): 72 | net.eval() 73 | 74 | mse_list = [] 75 | mad_list = [] 76 | 77 | with torch.no_grad(): 78 | for i, data in enumerate(dataloader): 79 | input_dict = {} 80 | for k, v in data.items(): 81 | input_dict[k] = v.cuda() 82 | 83 | pred = net.inference(input_dict['hr_image']) 84 | origin_h = input_dict['origin_h'] 85 | origin_w = input_dict['origin_w'] 86 | pred = F.interpolate(pred, (origin_h, origin_w), align_corners=False, mode="bilinear") 87 | 88 | if cfg.test.save: 89 | save_preds(pred, cfg.test.save_dir, filenames[i]) 90 | 91 | gt = input_dict['hr_label'] 92 | mse, mad = compute_metrics(pred, gt) 93 | mse_list.append(mse.item()) 94 | mad_list.append(mad.item()) 95 | 96 | logger.info('[ith:%d/%d] mad:%.5f mse:%.5f' % (i, len(dataloader), mad.item(), mse.item())) 97 | 98 | avg_mad = np.array(mad_list).mean() 99 | avg_mse = np.array(mse_list).mean() 100 | logger.info('avg_mad:%.5f avg_mse:%.5f' % (avg_mad.item(), avg_mse.item())) 101 | 102 | 103 | def main(): 104 | parser = argparse.ArgumentParser(description='HM') 105 | parser.add_argument('--local_rank', type=int, default=0) 106 | parser.add_argument('--dist', action='store_true', help='use distributed training') 107 | parser.add_argument('-c', '--config', type=str, metavar='FILE', help='path to config file') 108 | parser.add_argument('-p', '--phase', default="train", type=str, metavar='PHASE', help='train or test') 109 | 110 | args = parser.parse_args() 111 | cfg = load_config(args.config) 112 | device_ids = range(torch.cuda.device_count()) 113 | 114 | dataset = cfg.data.dataset 115 | model_name = cfg.model.arch 116 | exp_name = args.config.split('/')[-1].split('.')[0] 117 | timestamp = get_timestamp() 118 | 119 | cfg.log.log_dir = os.path.join(os.getcwd(), 'log', model_name, dataset, exp_name+os.sep) 120 | cfg.log.log_path = os.path.join(cfg.log.log_dir, "log_eval.txt") 121 | os.makedirs(cfg.log.log_dir, exist_ok=True) 122 | 123 | if cfg.test.save_dir is None: 124 | cfg.test.save_dir = os.path.join(cfg.log.log_dir, 'vis') 125 | os.makedirs(cfg.test.save_dir, exist_ok=True) 126 | 127 | logger = get_logger(cfg.log.log_path) 128 | logger.info('[LogPath] {}'.format(cfg.log.log_dir)) 129 | 130 | test_images, test_labels = load_test_filelist(cfg.data.filelist_test) 131 | 132 | test_transform = transforms.Compose([ 133 | RescaleT(cfg), 134 | ToTensor(cfg) 135 | ]) 136 | 137 | test_dataset = CustomDataset( 138 | cfg, 139 | is_training=False, 140 | img_name_list=test_images, 141 | lbl_name_list=test_labels, 142 | transform=test_transform 143 | ) 144 | 145 | test_dataloader = DataLoader( 146 | test_dataset, 147 | batch_size=cfg.test.batch_size, 148 | shuffle=False, 149 | pin_memory=True, 150 | num_workers=cfg.test.num_workers 151 | ) 152 | 153 | net = SparseMat(cfg) 154 | 155 | if torch.cuda.is_available(): 156 | net.cuda() 157 | else: 158 | exit() 159 | 160 | load_checkpoint(net, cfg.test.checkpoint, logger) 161 | test(cfg, net, test_dataloader, test_images, logger) 162 | 163 | 164 | if __name__ == "__main__": 165 | main() 166 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import numpy as np 5 | import cv2 6 | from functools import partial 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | import torchvision 14 | 15 | from torch.utils.data import Dataset, DataLoader 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torchvision import transforms 18 | 19 | from datasets import Rescale, RescaleT, RandomFlip, RandomCrop, ToTensor, CustomDataset 20 | from model import SparseMat, losses 21 | from utils import load_config, grid_images, get_logger 22 | 23 | 24 | def get_timestamp(): 25 | from datetime import datetime 26 | now = datetime.now() 27 | dt_string = now.strftime("%Y-%m-%d-%H-%M-%S") 28 | return dt_string 29 | 30 | 31 | def adjust_learning_rate(optimizer, epoch, epoch_decay, init_lr, min_lr=1e-6): 32 | for param_group in optimizer.param_groups: 33 | lr = max(init_lr * (0.1 ** (epoch // epoch_decay)), min_lr) 34 | param_group['lr'] = lr 35 | 36 | 37 | def load_checkpoint(net, pretrained_model, logger): 38 | net_state_dict = net.state_dict() 39 | state_dict = torch.load(pretrained_model) 40 | if 'state_dict' in state_dict: 41 | state_dict = state_dict['state_dict'] 42 | elif 'model_state_dict' in state_dict: 43 | state_dict = state_dict['model_state_dict'] 44 | 45 | filtered_state_dict = OrderedDict() 46 | for k,v in state_dict.items(): 47 | if k.startswith('module'): 48 | nk = '.'.join(k.split('.')[1:]) 49 | else: 50 | nk = k 51 | filtered_state_dict[nk] = v 52 | net.load_state_dict(filtered_state_dict, strict=False) 53 | logger.info('load pretrained weight from {} successfully'.format(pretrained_model)) 54 | 55 | 56 | def save_checkpoint(cfg, net, optimizer, epoch, iterations, running_loss, best_mad, is_best=False): 57 | state_dict = { 58 | 'state_dict': net.state_dict(), 59 | 'optimizer': optimizer.state_dict(), 60 | 'epoch': epoch, 61 | 'iteration': iterations + 1, 62 | 'running_loss': running_loss, 63 | 'best_mad': best_mad, 64 | } 65 | save_path = os.path.join(cfg.log.log_dir, "ckpt_e{}.pth".format(epoch)) 66 | torch.save(state_dict, save_path) 67 | 68 | latest_path = os.path.join(cfg.log.log_dir, "ckpt_latest.pth") 69 | shutil.copy(save_path, latest_path) 70 | 71 | if is_best: 72 | best_path = os.path.join(cfg.log.log_dir, "ckpt_best.pth") 73 | shutil.copy(save_path, best_path) 74 | 75 | 76 | def save_preds(pred, save_dir, filename): 77 | os.makedirs(save_dir, exist_ok=True) 78 | pred = pred.squeeze().data.cpu().numpy() * 255 79 | imgname = filename.split('/')[-1].split('.')[0] + '.png' 80 | cv2.imwrite(os.path.join(save_dir, imgname), pred) 81 | 82 | 83 | def load_filelist(data_path): 84 | images = [] 85 | labels = [] 86 | fgs = [] 87 | bgs = [] 88 | for line in open(data_path).read().splitlines(): 89 | splits = line.split(',') 90 | if len(splits) == 4: 91 | img_path, lbl_path, fg_path, bg_path = splits 92 | images.append(img_path) 93 | labels.append(lbl_path) 94 | fgs.append(fg_path) 95 | bgs.append(bg_path) 96 | else: 97 | img_path, lbl_path = splits 98 | images.append(img_path) 99 | labels.append(lbl_path) 100 | return images, labels, fgs, bgs 101 | 102 | 103 | def compute_metrics(pred, gt): 104 | if pred.shape[2:] != gt.shape[2:]: 105 | pred = F.interpolate(pred, gt.shape[2:], mode='bilinear', align_corners=False) 106 | mad = (pred-gt).abs().mean() 107 | mse = ((pred-gt)**2).mean() 108 | return mad, mse 109 | 110 | 111 | def train(cfg, net, optimizer, criterion, dataloader, writer, logger, epoch, iterations, best_mad): 112 | net.train() 113 | running_loss = 0.0 114 | 115 | for i, data in enumerate(dataloader): 116 | iterations += 1 117 | 118 | input_dict = {} 119 | for k, v in data.items(): 120 | input_dict[k] = v.cuda() 121 | 122 | optimizer.zero_grad() 123 | pred_list = net(input_dict) 124 | loss_dict = criterion(pred_list, input_dict) 125 | loss_dict['loss'].backward() 126 | optimizer.step() 127 | 128 | running_loss += loss_dict['loss'].item() 129 | 130 | cur_lr = optimizer.param_groups[0]['lr'] 131 | 132 | if iterations % cfg.log.print_frq == 0: 133 | for k,v in loss_dict.items(): 134 | writer.add_scalar('loss/'+k, loss_dict[k].item(), iterations) 135 | writer.add_scalar('loss/running_loss', running_loss/(i+1), iterations) 136 | writer.add_image('train/images', torch.cat(torch.unbind(pred_list[-1], dim=0), dim=1), global_step=iterations) 137 | if 'comp_loss' in loss_dict: 138 | logger.info('[epo:%d/%d][iter:%d/%d] lr:%5f loss:%.3f alpha_loss:%.3f comp_Loss:%.3f running_loss:%.3f' % ( 139 | epoch, cfg.train.epoch, (i+1), len(dataloader), cur_lr, loss_dict['loss'], 140 | loss_dict['alpha_loss'], loss_dict['comp_loss'], 141 | running_loss/(i+1))) 142 | else: 143 | logger.info('[epo:%d/%d][iter:%d/%d] lr:%5f loss:%.3f running_loss:%.3f' % ( 144 | epoch, cfg.train.epoch, (i+1), len(dataloader), cur_lr, loss_dict['loss'], running_loss/(i+1))) 145 | 146 | # comment this line if memory is sufficient 147 | torch.cuda.empty_cache() 148 | 149 | return iterations, running_loss 150 | 151 | 152 | def test(cfg, net, dataloader, writer, logger, epoch, filenames): 153 | net.eval() 154 | 155 | mse_list = [] 156 | mad_list = [] 157 | 158 | with torch.no_grad(): 159 | for i, data in enumerate(dataloader): 160 | 161 | input_dict = {} 162 | for k, v in data.items(): 163 | input_dict[k] = v.cuda() 164 | 165 | pred = net.inference(input_dict['hr_image']) 166 | origin_h = input_dict['origin_h'] 167 | origin_w = input_dict['origin_w'] 168 | pred = F.interpolate(pred, (origin_h, origin_w), align_corners=False, mode="bilinear") 169 | 170 | gt = input_dict['hr_label'] 171 | mad, mse = compute_metrics(pred, gt) 172 | mse_list.append(mse.item()) 173 | mad_list.append(mad.item()) 174 | 175 | logger.info('[ith:%d/%d] mad:%.5f mse:%.5f' % (i, len(dataloader), mad.item(), mse.item())) 176 | 177 | avg_mad = np.array(mad_list).mean() 178 | avg_mse = np.array(mse_list).mean() 179 | logger.info('[epo:%d/%d][ith:%d/%d] mad:%.3f mse:%.5f' % (epoch, cfg.train.epoch, i, len(dataloader), mad.item(), mse.item())) 180 | return avg_mad 181 | 182 | 183 | def main(): 184 | parser = argparse.ArgumentParser(description='HM') 185 | parser.add_argument('--local_rank', type=int, default=0) 186 | parser.add_argument('--dist', action='store_true', help='use distributed training') 187 | parser.add_argument('-e', '--evaluate', action='store_true', help='evaluate or not') 188 | parser.add_argument('-c', '--config', type=str, metavar='FILE', help='path to config file') 189 | parser.add_argument('-p', '--phase', default="train", type=str, metavar='PHASE', help='train or test') 190 | 191 | args = parser.parse_args() 192 | cfg = load_config(args.config) 193 | best_mad = 1e12 194 | device_ids = range(torch.cuda.device_count()) 195 | 196 | dataset = cfg.data.dataset 197 | model_name = cfg.model.arch 198 | exp_name = args.config.split('/')[-1].split('.')[0] 199 | timestamp = get_timestamp() 200 | 201 | cfg.log.log_dir = os.path.join(os.getcwd(), 'log', model_name, dataset, exp_name+os.sep) 202 | cfg.log.viz_dir = os.path.join(cfg.log.log_dir, "tensorboardx", timestamp) 203 | cfg.log.log_path = os.path.join(cfg.log.log_dir, "log.txt") 204 | os.makedirs(cfg.log.log_dir, exist_ok=True) 205 | os.makedirs(cfg.log.viz_dir, exist_ok=True) 206 | 207 | if cfg.test.save_dir is None: 208 | cfg.test.save_dir = os.path.join(cfg.log.log_dir, 'vis') 209 | os.makedirs(cfg.test.save_dir, exist_ok=True) 210 | 211 | writer = SummaryWriter(cfg.log.viz_dir) 212 | logger = get_logger(cfg.log.log_path) 213 | 214 | logger.info('[LogPath] {}'.format(cfg.log.log_dir)) 215 | logger.info('[VizPath] {}'.format(cfg.log.viz_dir)) 216 | 217 | train_images, train_labels, train_fgs, train_bgs = load_filelist(cfg.data.filelist_train) 218 | test_images, test_labels, test_fgs, test_bgs = load_filelist(cfg.data.filelist_val) 219 | 220 | train_transform = transforms.Compose([ 221 | Rescale(cfg), 222 | RandomCrop(cfg), 223 | RandomFlip(cfg), 224 | ToTensor(cfg) 225 | ]) 226 | 227 | test_transform = transforms.Compose([ 228 | RescaleT(cfg), 229 | ToTensor(cfg) 230 | ]) 231 | 232 | train_dataset = CustomDataset( 233 | cfg, True, 234 | img_name_list=train_images, 235 | lbl_name_list=train_labels, 236 | fg_name_list=train_fgs, 237 | bg_name_list=train_bgs, 238 | transform=train_transform 239 | ) 240 | test_dataset = CustomDataset( 241 | cfg, False, 242 | img_name_list=test_images, 243 | lbl_name_list=test_labels, 244 | fg_name_list=test_fgs, 245 | bg_name_list=test_bgs, 246 | transform=test_transform 247 | ) 248 | 249 | train_dataloader = DataLoader( 250 | train_dataset, 251 | batch_size=cfg.train.batch_size, 252 | shuffle=True, 253 | pin_memory=True, 254 | drop_last=True, 255 | num_workers=cfg.train.num_workers 256 | ) 257 | test_dataloader = DataLoader( 258 | test_dataset, 259 | batch_size=cfg.test.batch_size, 260 | shuffle=False, 261 | pin_memory=True, 262 | drop_last=True, 263 | num_workers=cfg.test.num_workers 264 | ) 265 | 266 | net = SparseMat(cfg) 267 | criterion = partial( 268 | losses, 269 | alpha_loss_weights=cfg.loss.alpha_loss_weights, 270 | with_composition_loss=cfg.loss.with_composition_loss, 271 | composition_loss_weight=cfg.loss.composition_loss_weight, 272 | ) 273 | 274 | load_checkpoint(net.lpn, cfg.train.pretrained_model, logger) 275 | 276 | if torch.cuda.is_available(): 277 | net.cuda() 278 | else: 279 | exit() 280 | 281 | if len(device_ids)>0: 282 | net = torch.nn.DataParallel(net) 283 | net_without_dp = net.module 284 | else: 285 | net_without_dp = net 286 | 287 | logger.info("---define optimizer...") 288 | optimizer = optim.Adam( 289 | net.parameters(), 290 | lr=cfg.train.lr, 291 | betas=(cfg.train.beta1, cfg.train.beta2), 292 | eps=1e-08, 293 | weight_decay=0, 294 | ) 295 | 296 | logger.info("---start training...") 297 | iterations = 0 298 | running_loss = 0.0 299 | 300 | resume_checkpoint = os.path.join(cfg.log.log_dir, 'ckpt_latest.pth') 301 | if (args.evaluate or cfg.train.resume) and os.path.exists(resume_checkpoint): 302 | state_dict = torch.load(resume_checkpoint) 303 | if state_dict['epoch'] < cfg.train.epoch: 304 | logger.info("Resume checkpoint from {}".format(resume_checkpoint)) 305 | if 'best_mad' in state_dict: 306 | best_mad = state_dict['best_mad'] 307 | if 'epoch' in state_dict: 308 | cfg.train.start_epoch = state_dict['epoch'] 309 | filtered_state_dict = OrderedDict() 310 | for k,v in state_dict['state_dict'].items(): 311 | if k.startswith('module'): 312 | nk = '.'.join(k.split('.')[1:]) 313 | else: 314 | nk = k 315 | filtered_state_dict[nk] = v 316 | net.module.load_state_dict(filtered_state_dict, strict=True) 317 | 318 | if args.evaluate: 319 | test(cfg, net_without_dp, test_dataloader, writer, logger, cfg.train.start_epoch, test_images) 320 | exit() 321 | 322 | for epoch in range(cfg.train.start_epoch, cfg.train.epoch): 323 | iterations, running_loss = train(cfg, net, optimizer, criterion, train_dataloader, writer, logger, epoch+1, iterations, best_mad) 324 | mad = test(cfg, net_without_dp, test_dataloader, writer, logger, epoch+1, test_images) 325 | if mad < best_mad: 326 | best_mad = min(mad, best_mad) 327 | save_checkpoint(cfg, net_without_dp, optimizer, epoch+1, iterations, running_loss, best_mad, is_best=True) 328 | else: 329 | save_checkpoint(cfg, net_without_dp, optimizer, epoch+1, iterations, running_loss, best_mad, is_best=False) 330 | adjust_learning_rate(optimizer, epoch, cfg.train.epoch_decay, cfg.train.lr, min_lr=cfg.train.min_lr) 331 | 332 | 333 | if __name__ == "__main__": 334 | main() 335 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .config import load_config 4 | from .viz_utils import grid_images 5 | 6 | 7 | def get_logger(filename): 8 | logger = logging.getLogger() 9 | logger.setLevel(logging.INFO) 10 | fh = logging.FileHandler(filename, mode='a') 11 | fh.setLevel(logging.INFO) 12 | ch = logging.StreamHandler() 13 | ch.setLevel(logging.INFO) 14 | formatter = logging.Formatter("%(asctime)s - %(message)s") 15 | fh.setFormatter(formatter) 16 | ch.setFormatter(formatter) 17 | logger.addHandler(fh) 18 | logger.addHandler(ch) 19 | return logger 20 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | 3 | CONFIG = EasyDict({}) 4 | CONFIG.is_default = True 5 | CONFIG.version = "baseline" 6 | CONFIG.debug = False 7 | # choices from train,evaluate,inference 8 | CONFIG.phase = "train" 9 | # distributed training 10 | CONFIG.dist = False 11 | # global variables which will be assigned in the runtime 12 | CONFIG.local_rank = 0 13 | CONFIG.gpu = 0 14 | CONFIG.world_size = 1 15 | CONFIG.devices = (0,) 16 | 17 | 18 | # =============================================================================== 19 | # Model config 20 | # =============================================================================== 21 | CONFIG.model = EasyDict({}) 22 | CONFIG.model.arch = 'SparseMat' 23 | 24 | # Model -> Architecture config 25 | CONFIG.model.in_channel = 3 26 | CONFIG.model.hr_channel = 32 27 | # global modules (ppm, aspp) 28 | CONFIG.model.global_module = "ppm" 29 | CONFIG.model.pool_scales = (1,2,3,6) 30 | CONFIG.model.ppm_channel = 256 31 | CONFIG.model.atrous_rates = (12, 24, 36) 32 | CONFIG.model.aspp_channel = 256 33 | CONFIG.model.with_norm = True 34 | CONFIG.model.with_aspp = True 35 | CONFIG.model.dilation_kernel = 15 36 | CONFIG.model.max_n_pixel = 4000000 37 | 38 | # =============================================================================== 39 | # Dataloader config 40 | # =============================================================================== 41 | 42 | CONFIG.aug = EasyDict({}) 43 | CONFIG.aug.rescale_size = 320 44 | CONFIG.aug.crop_size = 288 45 | CONFIG.aug.patch_crop_size = (320,640) 46 | CONFIG.aug.patch_load_size = 320 47 | 48 | CONFIG.data = EasyDict({}) 49 | CONFIG.data.workers = 0 50 | CONFIG.data.dataset = None 51 | CONFIG.data.composite = False 52 | CONFIG.data.filelist = None 53 | CONFIG.data.filelist_train = None 54 | CONFIG.data.filelist_val = None 55 | CONFIG.data.filelist_test = None 56 | 57 | 58 | # =============================================================================== 59 | # Loss config 60 | # =============================================================================== 61 | CONFIG.loss = EasyDict({}) 62 | CONFIG.loss.alpha_loss_weights = [1.0,1.0,1.0,1.0] 63 | CONFIG.loss.with_composition_loss = False 64 | CONFIG.loss.composition_loss_weight = 1.0 65 | 66 | # =============================================================================== 67 | # Training config 68 | # =============================================================================== 69 | CONFIG.train = EasyDict({}) 70 | 71 | CONFIG.train.num_workers = 4 72 | CONFIG.train.batch_size = 8 73 | # epochs 74 | CONFIG.train.start_epoch = 0 75 | CONFIG.train.epoch = 100 76 | CONFIG.train.epoch_decay = 95 77 | # basic learning rate of optimizer 78 | CONFIG.train.lr = 1e-5 79 | CONFIG.train.min_lr = 1e-8 80 | CONFIG.train.reset_lr = False 81 | CONFIG.train.adaptive_lr = False 82 | # beta1 and beta2 for Adam 83 | CONFIG.train.optim = "Adam" 84 | CONFIG.train.eps = 1e-5 85 | CONFIG.train.beta1 = 0.9 86 | CONFIG.train.beta2 = 0.999 87 | CONFIG.train.momentum = 0.9 88 | CONFIG.train.weight_decay = 1e-4 89 | # clip large gradient 90 | CONFIG.train.clip_grad = True 91 | # reset the learning rate (this option will reset the optimizer and learning rate scheduler and ignore warmup) 92 | CONFIG.train.pretrained_model = None 93 | CONFIG.train.resume = False 94 | 95 | CONFIG.train.rescale_size = 320 96 | CONFIG.train.crop_size = 288 97 | CONFIG.train.color_space = 3 98 | 99 | 100 | # =============================================================================== 101 | # Testing config 102 | # =============================================================================== 103 | CONFIG.test = EasyDict({}) 104 | # test image scale to evaluate, "origin" or "resize" or "crop" 105 | CONFIG.test.num_workers = 4 106 | CONFIG.test.batch_size = 1 107 | CONFIG.test.rescale_size = 320 108 | CONFIG.test.max_size = 1920 109 | CONFIG.test.patch_size = 320 110 | CONFIG.test.checkpoint = None 111 | CONFIG.test.save = False 112 | CONFIG.test.save_dir = None 113 | CONFIG.test.cascade = False 114 | # "best_model" or "latest_model" or other base name of the pth file. 115 | 116 | 117 | # =============================================================================== 118 | # Logging config 119 | # =============================================================================== 120 | CONFIG.log = EasyDict({}) 121 | CONFIG.log.log_dir = None 122 | CONFIG.log.viz_dir = None 123 | CONFIG.log.save_frq = 2000 124 | CONFIG.log.print_frq = 20 125 | CONFIG.log.test_frq = 1 126 | CONFIG.log.viz = True 127 | CONFIG.log.show_all = True 128 | 129 | 130 | # =============================================================================== 131 | # util functions 132 | # =============================================================================== 133 | def parse_config(custom_config, default_config=CONFIG, prefix="CONFIG"): 134 | """ 135 | This function will recursively overwrite the default config by a custom config 136 | :param default_config: 137 | :param custom_config: parsed from config/config.toml 138 | :param prefix: prefix for config key 139 | :return: None 140 | """ 141 | if "is_default" in default_config: 142 | default_config.is_default = False 143 | 144 | for key in custom_config.keys(): 145 | full_key = ".".join([prefix, key]) 146 | if key not in default_config: 147 | raise NotImplementedError("Unknown config key: {}".format(full_key)) 148 | elif isinstance(custom_config[key], dict): 149 | if isinstance(default_config[key], dict): 150 | parse_config(default_config=default_config[key], 151 | custom_config=custom_config[key], 152 | prefix=full_key) 153 | else: 154 | raise ValueError("{}: Expected {}, got dict instead.".format(full_key, type(custom_config[key]))) 155 | else: 156 | if isinstance(default_config[key], dict): 157 | raise ValueError("{}: Expected dict, got {} instead.".format(full_key, type(custom_config[key]))) 158 | else: 159 | default_config[key] = custom_config[key] 160 | 161 | 162 | def load_config(config_path): 163 | import toml 164 | with open(config_path) as fp: 165 | custom_config = EasyDict(toml.load(fp)) 166 | parse_config(custom_config=custom_config) 167 | return CONFIG 168 | 169 | 170 | if __name__ == "__main__": 171 | from pprint import pprint 172 | 173 | pprint(CONFIG) 174 | load_config("../config/example.toml") 175 | pprint(CONFIG) 176 | -------------------------------------------------------------------------------- /utils/viz_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | 9 | def grid_images(pred_dict, input_dict, show_all=False): 10 | lr_image = input_dict['lr_image'] * 0.5 + 0.5 11 | lr_label = input_dict['lr_label_mat'].expand_as(lr_image) 12 | lr_mask = input_dict['lr_label_unk'].expand_as(lr_image) 13 | lr_pred = pred_dict['coarse'] 14 | if lr_pred.shape[2:] != lr_image.shape[2:]: 15 | lr_pred = F.interpolate(lr_pred, lr_image.shape[2:], mode="bilinear", align_corners=False) 16 | lr_pred = lr_pred.expand_as(lr_image) 17 | 18 | h, w = lr_image.size(2), lr_image.size(3) 19 | 20 | tmps = [] 21 | if show_all: 22 | extra_keys = ['global_seg', 'global_mat', 'errormap', 'classmap'] 23 | for key in extra_keys: 24 | if key in pred_dict: 25 | if key == 'errormap': 26 | tmp = pred_dict[key] 27 | if tmp.size(2) != h or tmp.size(3) != w: 28 | tmp = F.interpolate(tmp, (h,w), mode='nearest') 29 | elif key == 'classmap': 30 | tmp = torch.argmax(pred_dict[key], dim=1, keepdim=True).float() / 2. 31 | # if tmp.min() < 0: 32 | # tmp = (tmp + 1) / 2. 33 | # if tmp.size(2) != h or tmp.size(3) != w: 34 | # tmp = F.interpolate(tmp, (h,w), mode='nearest') 35 | # tmp = tmp.repeat(1,3,1,1).float() 36 | else: 37 | tmp = pred_dict[key][0] 38 | if tmp.size(2) != h or tmp.size(3) != w: 39 | tmp = F.interpolate(tmp, (h,w), mode='bilinear', align_corners=False) 40 | tmp = tmp.expand_as(lr_image) 41 | tmps.append(tmp) 42 | 43 | if 'fine' in pred_dict: 44 | hr_image = input_dict['hr_image'] * 0.5 + 0.5 45 | hr_label = input_dict['hr_label_mat'] 46 | hr_pred = pred_dict['fine'].expand_as(hr_image) 47 | if hr_image.size(2) != h or hr_image.size(3) != w: 48 | hr_image = F.interpolate(hr_image, (h,w), mode='bilinear', align_corners=False) 49 | hr_label = F.interpolate(hr_label, (h,w), mode='bilinear', align_corners=False) 50 | hr_pred = F.interpolate(hr_pred, (h,w), mode='bilinear', align_corners=False) 51 | hr_label = hr_label.expand_as(hr_image) 52 | grid = torch.cat([lr_image, lr_label, lr_mask, lr_pred]+tmps+[hr_image, hr_label, hr_pred], dim=3) 53 | else: 54 | grid = torch.cat([lr_image, lr_label, lr_mask, lr_pred]+tmps, dim=3) 55 | grid = F.interpolate(grid, scale_factor=0.5, mode='bilinear', align_corners=False) 56 | n,c,h,w = grid.size() 57 | grid = grid.permute(1,0,2,3).contiguous().view(c,n*h,w) 58 | # np_img = cv2.cvtColor(np.transpose(grid.data.cpu().numpy(), (1,2,0))*255, cv2.COLOR_RGB2BGR) 59 | # cv2.imwrite('tmp/tmp.png', np_img) 60 | return grid 61 | 62 | 63 | def save_preds(viz_dir, img, lbl, res1, res2): 64 | res1 = F.interpolate(res1, (res2.size(2), res2.size(3))) 65 | img_color = np.transpose(img.data.cpu().numpy(), (0,2,3,1))[:,:,:,::-1]*255 66 | lbl_color = np.tile(lbl.squeeze().data.cpu().numpy()[:,:,:,None], (1,1,1,3)) * 255 67 | res1_color = np.tile(torch.clamp(res1,0,1).squeeze().data.cpu().numpy()[:,:,:,None], (1,1,1,3)) * 255 68 | res2_color = np.tile(torch.clamp(res2,0,1).squeeze().data.cpu().numpy()[:,:,:,None], (1,1,1,3)) * 255 69 | shows = [] 70 | for i in range(img_color.shape[0]): 71 | shows.append(np.concatenate((img_color[i], lbl_color[i], res1_color[i], res2_color[i]), axis=1)) 72 | shows = np.concatenate(shows, axis=0) 73 | ratio = 1200.0 / shows.shape[1] 74 | shows = cv2.resize(shows, None, fx=ratio, fy=ratio) 75 | cv2.imwrite(os.path.join(viz_dir,"viz.png"), shows) 76 | 77 | 78 | def save_labels(labels, save_dir): 79 | n, c, h, w = labels.shape 80 | labels = labels[:,0].data.cpu().numpy() 81 | 82 | template = np.zeros((h*n,w,3)) 83 | 84 | for i in range(n): 85 | label_color = idx_to_colormat(labels[i]) 86 | template[h*i:h*(i+1), w*0:w*1] = label_color 87 | 88 | cv2.imwrite(os.path.join(save_dir, "viz.png"), template) 89 | 90 | 91 | def save_raw_labels(labels, save_dir): 92 | h, w = labels.shape[:2] 93 | template = np.zeros((h,w,3)) 94 | 95 | label_color = idx_to_colormat(labels) 96 | cv2.imwrite(os.path.join(save_dir, "viz.png"), label_color) 97 | --------------------------------------------------------------------------------