├── README.md
├── config
│   └── sparsemat.toml
├── data
│   └── generate_filelist.py
├── datasets
│   ├── __init__.py
│   ├── data_loader.py
│   └── utils.py
├── demo.py
├── figures
│   └── framework.png
├── model
│   ├── __init__.py
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── dilated_resnet_bn.py
│   │   ├── mobilenetv2.py
│   │   ├── mobilenetv3.py
│   │   ├── resnet_bn.py
│   │   ├── sparse_resnet_bn.py
│   │   └── wrapper.py
│   ├── lap_pyramid_loss.py
│   ├── loss.py
│   ├── lpn.py
│   ├── model.py
│   ├── shm.py
│   └── utils.py
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── config.py
    └── viz_utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # SparseMat
2 | Repository for *Ultrahigh Resolution Image/Video Matting with Spatio-Temporal Sparsity*, accepted to CVPR 2023.
3 |
4 |
5 |
6 | ### Overview
7 |
8 | Commodity ultrahigh definition (UHD) displays are becoming more affordable, which demands imaging in ultrahigh resolution (UHR). This paper proposes SparseMat, a computationally efficient approach for UHR image/video matting. It is infeasible to directly process UHR images at full resolution in one shot with existing matting algorithms without running out of memory on consumer-level computational platforms (e.g., an Nvidia 1080Ti with 11GB memory), while patch-based approaches can introduce unsightly artifacts due to patch partitioning. Instead, our method resorts to spatial and temporal sparsity to address general UHR matting. When processing videos, a large amount of redundant computation can be avoided by exploiting this sparsity. In this paper, we show how to effectively detect spatio-temporal sparsity, which serves as a gate to activate input pixels for the matting model. Under the guidance of such sparsity, our method with the sparse high-resolution module (SHM) avoids patch-based inference while remaining memory-efficient for full-resolution matte refinement. Extensive experiments demonstrate that SparseMat can effectively and efficiently generate high-quality alpha mattes for UHR images and videos at the original high resolution in a single pass.
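To make the gating idea concrete, here is a minimal, illustrative sketch (not the repository's implementation, which lives in `model/shm.py` and `model/model.py`): the gate activates pixels around the uncertain band of a low-resolution alpha prediction and, for videos, additionally pixels that changed since the previous frame. The function name and thresholds are assumptions; the 15-pixel dilation mirrors `dilation_kernel` in `config/sparsemat.toml`.

```
import torch
import torch.nn.functional as F

def sparsity_gate(lr_alpha, frame=None, last_frame=None, kernel=15, diff_thresh=0.05):
    """Binary (N,1,H,W) mask of pixels to activate for full-resolution refinement (illustrative)."""
    # Spatial sparsity: the uncertain (non-0/1) band of the low-res alpha, dilated via max-pooling.
    uncertain = ((lr_alpha > 0.01) & (lr_alpha < 0.99)).float()
    gate = F.max_pool2d(uncertain, kernel, stride=1, padding=kernel // 2)
    # Temporal sparsity (video): also activate pixels that changed noticeably since the last frame.
    if frame is not None and last_frame is not None:
        changed = ((frame - last_frame).abs().mean(dim=1, keepdim=True) > diff_thresh).float()
        gate = (gate + F.max_pool2d(changed, kernel, stride=1, padding=kernel // 2)).clamp(0, 1)
    return gate
```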
9 |
10 | ### Environment
11 | The recommended PyTorch and torchvision versions are v1.9.0 and v0.10.0, respectively.
12 |
13 | - torch
14 | - torchvision
15 | - easydict
16 | - toml
17 | - pillow
18 | - scikit-image
19 | - scipy
20 | - spconv. Please install the sparse convolution module by following [traveller59/spconv](https://github.com/traveller59/spconv/tree/v1.2.1). Note that we use version 1.2.1 instead of the latest version; a quick import check is shown below.
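A quick, illustrative check that the installed spconv build exposes the sparse ops used by `model/backbones/sparse_resnet_bn.py`:

```
# these classes must exist for the sparse high-resolution module to run
import spconv
print(spconv.SubMConv2d, spconv.SparseConv2d, spconv.SparseSequential)
```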
21 |
22 | ### Dataset
23 | Existing datasets suffer from limited resolution. Thus, in this paper we contribute the first UHR human matting dataset, composed of HHM50K for training and HHM2K for evaluation. HHM50K and HHM2K contain 50,000 and 2,000 unique UHR images, respectively (with an average resolution of 4K), encompassing a wide range of human poses and matting scenarios. The download links are provided below.
24 | - HHM50K: [BaiduDisk](https://pan.baidu.com/s/1txjXk7OH3vIH7yrmpfNThA), password 2tsc
25 | - HHM2K: [BaiduDisk](https://pan.baidu.com/s/1RKu3qJRRMlgfZbIN7P4j4w), password ymyr
26 |
27 | Download them and put them under the `data` directory. Then run the following command to generate the file lists; the resulting line format is sketched below.
28 | ```
29 | python3 data/generate_filelist.py
30 | ```
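Each line of the generated lists holds comma-separated paths, as written by `data/generate_filelist.py`: `image,alpha,foreground,background` for HHM50K and `image,alpha` for HHM2K. A minimal reading sketch (the loop is illustrative only):

```
with open("data/HHM50K.txt") as f:
    for line in f:
        img, pha, fg, bg = line.strip().split(",")
        # e.g. img = "data/HHM50K/images/<name>.jpg", pha = "data/HHM50K/alphas/<name>.png"
```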
31 |
32 | ### Code
33 | ###### Training
34 | Run the following commands to train the model. To train SparseMat with our self-trained low-resolution prior network (LPN), download the weights [here](https://drive.google.com/file/d/1_zDQbul-lCM-tFEWNcdw0D4jr3WaK1ir/view?usp=sharing) and put them under the `pretrained` directory (the config expects `pretrained/lpn.pth`).
35 | ```
36 | work_dir=/PATH/TO/SparseMat
37 | cd $work_dir
38 | export PYTHONPATH=$PYTHONPATH:$work_dir
39 | python3 train.py -c config/sparsemat.toml
40 | ```
41 |
42 | ###### Testing
43 | Run the following commands to evaluate the model. You can download our pretrained model [here](https://drive.google.com/file/d/19MX3USM4BK3sYi0o3AHNUxJ8bZEAGXg9/view?usp=sharing) and put it under the `pretrained` directory (the config expects `pretrained/SparseMat.pth`).
44 | ```
45 | work_dir=/PATH/TO/SparseMat
46 | cd $work_dir
47 | export PYTHONPATH=$PYTHONPATH:$work_dir
48 | python3 test.py -c config/sparsemat.toml
49 | ```
50 |
51 | ###### Inference
52 | You can use the following commands to run inference on images or videos. `--input` accepts a single image (`.jpg`/`.png`), a directory of images, or an `.mp4` video.
53 | ```
54 | work_dir=/PATH/TO/SparseMat
55 | cd $work_dir
56 | export PYTHONPATH=$PYTHONPATH:$work_dir
57 | python3 demo.py -c config/sparsemat.toml --input /PATH/TO/INPUT --save_dir /PATH/TO/SAVE_DIR
58 | ```
59 |
60 | ### Reference
61 | ```
62 | @InProceedings{Sun_2023_CVPR,
63 | author = {Sun, Yanan and Tang, Chi-Keung and Tai, Yu-Wing},
64 | title = {Ultrahigh Resolution Image/Video Matting With Spatio-Temporal Sparsity},
65 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
66 | month = {June},
67 | year = {2023},
68 | pages = {14112-14121}
69 | }
70 | ```
71 |
--------------------------------------------------------------------------------
/config/sparsemat.toml:
--------------------------------------------------------------------------------
1 | # Refer to utils/config.py for definition and options.
2 | version = "SparseMat"
3 |
4 | [model]
5 | dilation_kernel = 15
6 | max_n_pixel = 4000000
7 |
8 | [loss]
9 | alpha_loss_weights = [0.1, 0.1, 0.1, 1.0]
10 | with_composition_loss = true
11 | composition_loss_weight = 0.5
12 |
13 | [train]
14 | batch_size = 12
15 | epoch = 30
16 | epoch_decay = 10
17 | lr = 0.0001
18 | min_lr = 0.00001
19 | adaptive_lr = true
20 | beta1 = 0.9
21 | beta2 = 0.999
22 | pretrained_model = "pretrained/lpn.pth"
23 | num_workers = 16
24 |
25 | [aug]
26 | rescale_size = 560
27 | crop_size = 512
28 | patch_crop_size = [512, 640, 800]
29 | patch_load_size = 512
30 |
31 | [data]
32 | dataset = "HHM50K"
33 | filelist_train = "data/HHM50K.txt"
34 | filelist_val = "data/HHM2K.txt"
35 | filelist_test = "data/HHM2K.txt"
36 |
37 | [log]
38 | save_frq = 50
39 |
40 | [test]
41 | batch_size = 1
42 | rescale_size = 512
43 | patch_size = 512
44 | max_size = 7680
45 | save = true
46 | cascade = true
47 | checkpoint = "pretrained/SparseMat.pth"
48 | save_dir = "predictions/SparseMatte/HHM2K"
49 |
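# A minimal loading sketch (assuming the toml + easydict dependencies from the README;
# the actual parsing lives in utils/config.py, which is not reproduced here):
#
#   import toml
#   from easydict import EasyDict
#   cfg = EasyDict(toml.load("config/sparsemat.toml"))
#   cfg.model.dilation_kernel   # -> 15
#   cfg.test.checkpoint         # -> "pretrained/SparseMat.pth"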
--------------------------------------------------------------------------------
/data/generate_filelist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 |
5 |
6 | if __name__ == "__main__":
7 |
8 | root = "data/HHM50K"
9 | writer = open("data/HHM50K.txt", "w")
10 |
11 | images = sorted(glob.glob(os.path.join(root, "images/*.jpg")))
12 | alphas = sorted(glob.glob(os.path.join(root, "alphas/*.png")))
13 | fgs = sorted(glob.glob(os.path.join(root, "foregrounds/*.jpg")))
14 | bgs = sorted(glob.glob(os.path.join(root, "backgrounds/*.jpg")))
15 |
16 | assert len(images) == len(alphas)
17 | assert len(images) == len(fgs)
18 | assert len(images) == len(bgs)
19 |
20 | for img, pha, fg, bg in zip(images, alphas, fgs, bgs):
21 | img_name = img.split('/')[-1][:-4]
22 | pha_name = pha.split('/')[-1][:-4]
23 | fg_name = fg.split('/')[-1][:-4]
24 | bg_name = bg.split('/')[-1][:-4]
25 | assert img_name == pha_name
26 | assert img_name == fg_name
27 | assert img_name == bg_name
28 | writer.write(f"{img},{pha},{fg},{bg}\n")
29 |
30 |
31 | root = "data/HHM2K"
32 | writer = open("data/HHM2K.txt", "w")
33 |
34 | images = sorted(glob.glob(os.path.join(root, "images/*.jpg")))
35 | alphas = sorted(glob.glob(os.path.join(root, "alphas/*.png")))
36 |
37 | assert len(images) == len(alphas)
38 |
39 | for img, pha in zip(images, alphas):
40 | img_name = img.split('/')[-1][:-4]
41 | pha_name = pha.split('/')[-1][:-4]
42 | assert img_name == pha_name
43 | writer.write(f"{img},{pha}\n")
44 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_loader import Rescale
2 | from .data_loader import RescaleT
3 | from .data_loader import RandomFlip
4 | from .data_loader import RandomCrop
5 | from .data_loader import ToTensor
6 | from .data_loader import CustomDataset
7 |
--------------------------------------------------------------------------------
/datasets/data_loader.py:
--------------------------------------------------------------------------------
1 | # data loader
2 | from __future__ import print_function, division
3 | import os
4 | import glob
5 | import numpy as np
6 | import random
7 | import math
8 | import cv2
9 | from PIL import Image
10 | from skimage import io, transform, color
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.utils.data import Sampler, Dataset, DataLoader
15 | from torchvision import transforms, utils
16 |
17 | from .utils import convert_color_space, get_random_patch
18 |
19 |
20 | def imread(path):
21 | image = cv2.imread(path)
22 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
23 | return image
24 |
25 |
26 | class RandomFlip(object):
27 | def __init__(self, cfg):
28 | self.cfg = cfg
29 |
30 | def __call__(self, sample):
31 | # randomly flip
32 | if random.random() >= 0.5:
33 | pos = sample['pos']
34 | x1 = 1. - pos[..., 2]
35 | x2 = 1. - pos[..., 0]
36 | pos[..., 0] = x1
37 | pos[..., 2] = x2
38 | sample['pos'] = pos
39 | sample['hr_image'] = sample['hr_image'][:,::-1].copy()
40 | sample['lr_image'] = sample['lr_image'][:,::-1].copy()
41 | sample['hr_label'] = sample['hr_label'][:,::-1].copy()
42 | sample['hr_unknown'] = sample['hr_unknown'][:,::-1].copy()
43 | if 'hr_fg' in sample:
44 | sample['hr_fg'] = sample['hr_fg'][:,::-1].copy()
45 | sample['hr_bg'] = sample['hr_bg'][:,::-1].copy()
46 | return sample
47 |
48 |
49 | class Rescale(object):
50 | def __init__(self, cfg):
51 | assert isinstance(cfg.aug.rescale_size,(int,tuple))
52 | self.output_size = cfg.aug.rescale_size
53 |
54 | def __call__(self,sample):
55 | h, w = sample['hr_image'].shape[:2]
56 | sample['origin_h'] = h
57 | sample['origin_w'] = w
58 | if isinstance(self.output_size,int):
59 | ratio = self.output_size / min(h,w)
60 | new_h, new_w = ratio*h, ratio*w
61 | else:
62 | new_h, new_w = self.output_size
63 | new_h, new_w = int(new_h), int(new_w)
64 | sample['lr_image'] = cv2.resize(sample['hr_image'], (new_w, new_h), interpolation=cv2.INTER_LINEAR)
65 | return sample
66 |
67 |
68 | class RescaleT(object):
69 | def __init__(self, cfg):
70 | self.cfg = cfg
71 | self.max_size = cfg.test.max_size
72 | self.output_size = cfg.test.rescale_size
73 | assert isinstance(self.output_size,(int,tuple))
74 |
75 | def get_dst_size(self, origin_size, output_size=None, stride=32, max_size=1920):
76 | h, w = origin_size
77 | if output_size is None:
78 | ratio = max_size / max(h,w)
79 | if ratio>=1:
80 | new_h, new_w = h, w
81 | else:
82 | new_h, new_w = int(math.ceil(ratio*h)), int(math.ceil(ratio*w))
83 | elif isinstance(output_size,int):
84 | if output_size>=max_size:
85 | ratio = output_size / max(h,w)
86 | else:
87 | ratio = output_size / min(h,w)
88 | new_h, new_w = int(math.ceil(ratio*h)), int(math.ceil(ratio*w))
89 | else:
90 | new_h, new_w = output_size
91 | new_h = new_h - new_h % 32
92 | new_w = new_w - new_w % 32
93 | return (new_h, new_w)
94 |
95 | def __call__(self,sample):
96 | h, w = sample['hr_image'].shape[:2]
97 | sample['origin_h'] = h
98 | sample['origin_w'] = w
99 | new_h, new_w = self.get_dst_size((h,w), self.output_size, 32)
100 | sample['lr_image'] = cv2.resize(sample['hr_image'], (new_w, new_h), interpolation=cv2.INTER_LINEAR)
101 | return sample
102 |
103 |
104 | class RandomCrop(object):
105 |
106 | def __init__(self, cfg):
107 | # low-resolution full image
108 | output_size = cfg.aug.crop_size
109 | assert isinstance(output_size, (int, tuple))
110 | if isinstance(output_size, int):
111 | self.output_size = (output_size, output_size)
112 | else:
113 | assert len(output_size) == 2
114 | self.output_size = output_size
115 |
116 | # full-resolution patch
117 | patch_crop_size = cfg.aug.patch_crop_size
118 | assert isinstance(patch_crop_size, (tuple, list))
119 | self.patch_crop_size = patch_crop_size
120 |
121 | patch_load_size = cfg.aug.patch_load_size
122 | assert isinstance(patch_load_size, int)
123 | self.patch_load_size = patch_load_size
124 |
125 | self.cfg = cfg
126 |
127 | def random_crop(self, sample):
128 | h, w = sample['lr_image'].shape[:2]
129 | new_h, new_w = self.output_size
130 | ly1 = np.random.randint(0, h - new_h)
131 | lx1 = np.random.randint(0, w - new_w)
132 | ly2 = ly1 + new_h
133 | lx2 = lx1 + new_w
134 |
135 | oh, ow = sample['hr_image'].shape[:2]
136 | ratio_h = oh / float(h)
137 | ratio_w = ow / float(w)
138 | hx1, hy1 = int(lx1*ratio_w), int(ly1*ratio_h)
139 | hx2, hy2 = int(lx2*ratio_w), int(ly2*ratio_h)
140 | return (lx1,ly1,lx2,ly2), (hx1,hy1,hx2,hy2)
141 |
142 | def __call__(self,sample):
143 | (lx1,ly1,lx2,ly2), (hx1,hy1,hx2,hy2) = self.random_crop(sample)
144 | sample['lr_image'] = sample['lr_image'][ly1:ly2, lx1:lx2]
145 | sample['hr_image'] = sample['hr_image'][hy1:hy2, hx1:hx2]
146 | sample['hr_label'] = sample['hr_label'][hy1:hy2, hx1:hx2]
147 | sample['hr_unknown'] = sample['hr_unknown'][hy1:hy2, hx1:hx2]
148 |
149 | # random crop from high-resolution input
150 | h, w = sample['hr_label'].shape[:2]
151 | random_crop_size = random.choice(self.patch_crop_size)
152 | px1,py1,px2,py2 = get_random_patch(sample['hr_label'], random_crop_size)
153 | pos = np.array([px1/w,py1/h,px2/w,py2/h]).astype(np.float32)
154 | pos = np.clip(pos, 0, 1)
155 | sample['pos'] = pos
156 |
157 | load_size = (self.patch_load_size, self.patch_load_size)
158 | sample['hr_image'] = cv2.resize(sample['hr_image'][py1:py2, px1:px2], load_size, interpolation=cv2.INTER_LINEAR)
159 | sample['hr_label'] = cv2.resize(sample['hr_label'][py1:py2, px1:px2], load_size, interpolation=cv2.INTER_LINEAR)
160 | sample['hr_unknown'] = cv2.resize(sample['hr_unknown'][py1:py2, px1:px2], load_size, interpolation=cv2.INTER_NEAREST)
161 |
162 | if 'hr_fg' in sample:
163 | sample['hr_fg'] = sample['hr_fg'][hy1:hy2, hx1:hx2]
164 | sample['hr_fg'] = cv2.resize(sample['hr_fg'][py1:py2, px1:px2], load_size, interpolation=cv2.INTER_LINEAR)
165 | sample['hr_bg'] = sample['hr_bg'][hy1:hy2, hx1:hx2]
166 | sample['hr_bg'] = cv2.resize(sample['hr_bg'][py1:py2, px1:px2], load_size, interpolation=cv2.INTER_LINEAR)
167 | return sample
168 |
169 |
170 | class ToTensor(object):
171 | """Convert ndarrays in sample to Tensors."""
172 | def __init__(self, cfg):
173 | self.color_space = cfg.train.color_space
174 |
175 | def __call__(self, sample):
176 | sample['hr_label'] = sample['hr_label'] / 255.
177 | sample['hr_label'] = torch.from_numpy(sample['hr_label'][None].astype(np.float32))
178 | sample['hr_unknown'] = torch.from_numpy(sample['hr_unknown'][None].astype(np.float32))
179 |
180 | sample['hr_image'] = convert_color_space(sample['hr_image'], flag=self.color_space)
181 | sample['hr_image'] = torch.from_numpy(sample['hr_image'].transpose((2,0,1)).astype(np.float32))
182 | sample['lr_image'] = convert_color_space(sample['lr_image'], flag=self.color_space)
183 | sample['lr_image'] = torch.from_numpy(sample['lr_image'].transpose((2,0,1)).astype(np.float32))
184 |
185 | if 'pos' in sample:
186 | sample['pos'] = torch.from_numpy(sample['pos'].astype(np.float32))
187 | if 'hr_fg' in sample:
188 | sample['hr_fg'] = convert_color_space(sample['hr_fg'], flag=self.color_space)
189 | sample['hr_fg'] = torch.from_numpy(sample['hr_fg'].transpose((2,0,1)).astype(np.float32))
190 | sample['hr_bg'] = convert_color_space(sample['hr_bg'], flag=self.color_space)
191 | sample['hr_bg'] = torch.from_numpy(sample['hr_bg'].transpose((2,0,1)).astype(np.float32))
192 | return sample
193 |
194 |
195 | class CustomDataset(Dataset):
196 | def __init__(self,cfg, is_training, img_name_list, lbl_name_list,
197 | fg_name_list=None, bg_name_list=None, transform=None):
198 |
199 | self.cfg = cfg
200 | self.is_training = is_training
201 |
202 | self.image_name_list = img_name_list
203 | self.label_name_list = lbl_name_list
204 | self.fg_name_list = fg_name_list # for composition loss only!!!!!
205 | self.bg_name_list = bg_name_list # for composition loss only!!!!!
206 |
207 | self.transform = transform
208 |
209 | def __len__(self):
210 | return len(self.image_name_list)
211 |
212 | def __getitem__(self,idx):
213 |
214 | sample = {}
215 | sample['hr_image'] = imread(self.image_name_list[idx])
216 | sample['hr_label'] = imread(self.label_name_list[idx])[:,:,0]
217 |
218 | unknown = generate_unknown_label(sample['hr_label'], fixed=(not self.is_training))
219 | mask = (unknown==0) | (unknown==1)
220 | unknown[mask==1] = 0
221 | unknown[mask==0] = 1
222 | sample['hr_unknown'] = unknown
223 |
224 | if self.is_training and len(self.fg_name_list) == len(self.image_name_list):
225 | fg = imread(self.fg_name_list[idx])
226 | bg = imread(self.bg_name_list[idx])
227 | sample['hr_fg'] = fg
228 | sample['hr_bg'] = bg
229 |
230 | if self.transform:
231 | sample = self.transform(sample)
232 |
233 | return sample
234 |
235 |
236 | def generate_unknown_label(alpha, ksize=3, iterations=5, fixed=False):
237 | oH, oW = alpha.shape[:2]
238 | if not fixed:
239 | ksize_range=(3, 9)
240 | iter_range=(1, 15)
241 | ksize = random.randint(ksize_range[0], ksize_range[1])
242 | iterations = random.randint(iter_range[0], iter_range[1])
243 | else:
244 | ksize = 5
245 | iterations = 5
246 | ratio = 1280. / max(oH,oW)
247 | alpha = cv2.resize(alpha, None, fx=ratio, fy=ratio)
248 |
249 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
250 | dilated = cv2.dilate(alpha, kernel, iterations=iterations)
251 | eroded = cv2.erode(alpha, kernel, iterations=iterations)
252 | trimap = np.zeros(alpha.shape) + 128
253 | trimap[eroded >= 255] = 255
254 | trimap[dilated <= 0] = 0
255 | trimap = trimap.astype(np.uint8)
256 | if trimap.shape[0] != oH or trimap.shape[1] != oW:
257 | trimap = cv2.resize(trimap, (oW,oH), interpolation=cv2.INTER_NEAREST)
258 | return trimap
259 |
--------------------------------------------------------------------------------
/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 | import math
5 | import cv2
6 |
7 | import torch.nn.functional as F
8 |
9 | from torch.utils.data import Dataset, DataLoader
10 | from torchvision import transforms, utils
11 |
12 | from PIL import Image
13 | from skimage import io, transform, color
14 |
15 |
16 | def get_random_patch(mask, crop_size):
17 | new_h, new_w = mask.shape[:2]
18 | crop_size = min(crop_size, min(new_w, new_h)-1)
19 | crop_size_hf = crop_size // 2
20 | maskf = mask / 255.
21 | ys, xs = np.where(np.logical_and(maskf>0.05, maskf<0.95))[:2]
22 | if len(ys)>0:
23 | rand_ind = random.randint(0, len(ys)-1)
24 | cy = min(max(ys[rand_ind], crop_size_hf), new_h-crop_size_hf)
25 | cx = min(max(xs[rand_ind], crop_size_hf), new_w-crop_size_hf)
26 | x1, y1 = cx - crop_size_hf, cy - crop_size_hf
27 | x2, y2 = x1 + crop_size, y1 + crop_size
28 | else:
29 | x1, y1 = new_w // 2 - crop_size_hf, new_h // 2 - crop_size_hf
30 | x2, y2 = x1 + crop_size, y1 + crop_size
31 | return (x1,y1,x2,y2)
32 |
33 |
34 | def convert_color_space(image, flag=3):
35 | if flag == 3:
36 | image = image / 255.0
37 | tmpImg = np.zeros((image.shape[0],image.shape[1],3))
38 | if image.shape[2]==1:
39 | tmpImg[:] = 2 * np.tile(image[:,:,None],(1,1,3)) - 1
40 | else:
41 | tmpImg[:] = 2 * image[:] - 1
42 |
43 | elif flag == 2: # with rgb and Lab colors
44 | tmpImg = np.zeros((image.shape[0],image.shape[1],6))
45 | tmpImgt = np.zeros((image.shape[0],image.shape[1],3))
46 | if image.shape[2]==1:
47 | tmpImgt[:,:,0] = image[:,:,0]
48 | tmpImgt[:,:,1] = image[:,:,0]
49 | tmpImgt[:,:,2] = image[:,:,0]
50 | else:
51 | tmpImgt = image
52 | tmpImgtl = color.rgb2lab(tmpImgt)
53 |
54 |         # normalize image to range [0,1]
55 | tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))
56 | tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
57 | tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
58 | tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))
59 | tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
60 | tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))
61 |
62 | # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
63 |
64 | tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
65 | tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
66 | tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
67 | tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
68 | tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
69 | tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])
70 |
71 | elif flag == 1: #with Lab color
72 | tmpImg = np.zeros((image.shape[0],image.shape[1],3))
73 |
74 | if image.shape[2]==1:
75 | tmpImg[:,:,0] = image[:,:,0]
76 | tmpImg[:,:,1] = image[:,:,0]
77 | tmpImg[:,:,2] = image[:,:,0]
78 | else:
79 | tmpImg = image
80 |
81 | tmpImg = color.rgb2lab(tmpImg)
82 |
83 | # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
84 |
85 | tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
86 | tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
87 | tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))
88 |
89 | tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
90 | tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
91 | tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
92 |
93 | else: # with rgb color
94 | tmpImg = np.zeros((image.shape[0],image.shape[1],3))
95 | image = image/np.max(image)
96 | if image.shape[2]==1:
97 | tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
98 | tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
99 | tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
100 | else:
101 | tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
102 | tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
103 | tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
104 |
105 | return tmpImg
106 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import cv2
5 | import math
6 | from collections import OrderedDict
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from model import SparseMat
13 | from utils import load_config
14 |
15 |
16 | def load_checkpoint(net, pretrained_model):
17 | net_state_dict = net.state_dict()
18 | state_dict = torch.load(pretrained_model)
19 | if 'state_dict' in state_dict:
20 | state_dict = state_dict['state_dict']
21 | elif 'model_state_dict' in state_dict:
22 | state_dict = state_dict['model_state_dict']
23 |
24 | filtered_state_dict = OrderedDict()
25 | for k,v in state_dict.items():
26 | if k.startswith('module'):
27 | nk = '.'.join(k.split('.')[1:])
28 | else:
29 | nk = k
30 | filtered_state_dict[nk] = v
31 | net.load_state_dict(filtered_state_dict)
32 | print('load pretrained weight from {} successfully'.format(pretrained_model))
33 |
34 |
35 | def preprocess(image):
36 | image = (image / 255. - 0.5) / 0.5
37 | image = torch.from_numpy(image[None]).permute(0,3,1,2)
38 | h, w = image.shape[2:]
39 | nh = math.ceil(h / 8) * 8
40 | nw = math.ceil(w / 8) * 8
41 | image = F.interpolate(image, (nh, nw), mode="bilinear")
42 | return image.float().cuda()
43 |
44 |
45 | def run_single_image(net, input_path, save_dir):
46 | filename = input_path.split('/')[-1]
47 | image = cv2.imread(input_path)
48 | origin_h, origin_w = image.shape[:2]
49 | tensor = preprocess(image)
50 | with torch.no_grad():
51 | pred = net.inference(tensor)
52 | pred = F.interpolate(pred, (origin_h, origin_w), align_corners=False, mode="bilinear")
53 | pred_alpha = (pred * 255).squeeze().data.cpu().numpy().astype(np.uint8)
54 | cv2.imwrite(os.path.join(save_dir, filename), pred_alpha)
55 | return pred
56 |
57 |
58 | def run_multiple_images(net, input_path, save_dir):
59 | for item in os.listdir(input_path):
60 | run_single_image(net, os.path.join(input_path, item), save_dir)
61 |
62 |
63 | def run_video(net, input_path, save_dir):
64 | filename = input_path.split('/')[-1]
65 | cap = cv2.VideoCapture(input_path)
66 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
67 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
68 | fps = cap.get(cv2.CAP_PROP_FPS)
69 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
70 | writer = cv2.VideoWriter(os.path.join(save_dir, filename), fourcc, fps, (width, height))
71 |
72 | last_frame = None
73 | last_pred = None
74 | while True:
75 | ret, frame = cap.read()
76 | if not ret:
77 | break
78 | tensor = preprocess(frame)
79 | with torch.no_grad():
80 | pred = net.inference(tensor, last_img=last_frame, last_pred=last_pred)
81 | pred = F.interpolate(pred, (height, width), align_corners=False, mode="bilinear")
82 | pred_alpha = (pred * 255).squeeze().data.cpu().numpy().astype(np.uint8)
83 | writer.write(np.tile(pred_alpha[:,:,None], (1,1,3)))
84 | last_frame = tensor
 85 |         last_pred = pred
 86 |
 87 |     # release the capture and writer so the output video file is finalized
 88 |     cap.release()
 89 |     writer.release()
 90 |
 91 |
88 | def main():
89 | parser = argparse.ArgumentParser()
90 | parser.add_argument('-c', '--config', type=str, metavar='FILE', help='path to config file')
91 | parser.add_argument('--input', type=str, metavar='PATH', help='path to input path')
92 | parser.add_argument('--save_dir', type=str, metavar='PATH', help='path to save path')
93 |
94 | args = parser.parse_args()
95 | cfg = load_config(args.config)
96 |
97 | os.makedirs(args.save_dir, exist_ok=True)
98 |
99 | net = SparseMat(cfg)
100 |
101 | if torch.cuda.is_available():
102 | net.cuda()
103 | else:
104 |         raise SystemExit("This demo requires a CUDA-capable GPU.")
105 |
106 | load_checkpoint(net, cfg.test.checkpoint)
107 |
108 | net.eval()
109 |
110 | if args.input.endswith(".mp4"):
111 | run_video(net, args.input, args.save_dir)
112 | elif args.input.endswith(".jpg") or args.input.endswith(".png"):
113 | run_single_image(net, args.input, args.save_dir)
114 | else:
115 | run_multiple_images(net, args.input, args.save_dir)
116 |
117 |
118 | if __name__ == "__main__":
119 | main()
120 |
--------------------------------------------------------------------------------
/figures/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nowsyn/SparseMat/2678757dfb7db185f91ee54e54d1e68944febded/figures/framework.png
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import SparseMat
2 | from .loss import losses
3 |
--------------------------------------------------------------------------------
/model/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .wrapper import MobileNetV2Backbone
2 | from .wrapper import MobileNetV3LargeBackbone
3 |
--------------------------------------------------------------------------------
/model/backbones/dilated_resnet_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import functools
4 |
5 |
6 | class ResnetDilatedBN(nn.Module):
7 | def __init__(self, args, orig_resnet, dilate_scale=8):
8 | super(ResnetDilatedBN, self).__init__()
9 | from functools import partial
10 |
11 | if dilate_scale == 8:
12 | orig_resnet.layer3.apply(
13 | partial(self._nostride_dilate, dilate=2))
14 | orig_resnet.layer4.apply(
15 | partial(self._nostride_dilate, dilate=4))
16 | elif dilate_scale == 16:
17 | orig_resnet.layer4.apply(
18 | partial(self._nostride_dilate, dilate=2))
19 |
20 | # take pretrained resnet, except AvgPool and FC
21 | self.conv1 = orig_resnet.conv1
22 | self.bn1 = orig_resnet.bn1
23 | self.relu1 = orig_resnet.relu1
24 |
25 | self.conv2 = orig_resnet.conv2
26 | self.bn2 = orig_resnet.bn2
27 | self.relu2 = orig_resnet.relu2
28 | self.conv3 = orig_resnet.conv3
29 | self.bn3 = orig_resnet.bn3
30 | self.relu3 = orig_resnet.relu3
31 |
32 | self.maxpool = orig_resnet.maxpool
33 | self.layer1 = orig_resnet.layer1
34 | self.layer2 = orig_resnet.layer2
35 | self.layer3 = orig_resnet.layer3
36 | self.layer4 = orig_resnet.layer4
37 |
38 | self.enc_channels = [128, 256, 512, 1024, 2048] # 2x, 4x, 8x, 8x, 8x
39 |
40 | def _nostride_dilate(self, m, dilate):
41 | classname = m.__class__.__name__
42 | if classname.find('Conv') != -1:
43 | # the convolution with stride
44 | if m.stride == (2, 2):
45 | m.stride = (1, 1)
46 | if m.kernel_size == (3, 3):
47 | m.dilation = (dilate // 2, dilate // 2)
48 | m.padding = (dilate // 2, dilate // 2)
49 |             # other convolutions
50 | else:
51 | if m.kernel_size == (3, 3):
52 | m.dilation = (dilate, dilate)
53 | m.padding = (dilate, dilate)
54 |
55 | def forward(self, x):
56 | conv_out = []
57 | x = self.relu1(self.bn1(self.conv1(x)))
58 | x = self.relu2(self.bn2(self.conv2(x)))
59 | x = self.relu3(self.bn3(self.conv3(x)))
60 |
61 | conv_out.append(x) # 2x
62 | x, indices = self.maxpool(x)
63 | x = self.layer1(x)
64 | conv_out.append(x) # 4x
65 | x = self.layer2(x)
66 | conv_out.append(x) # 8x
67 | x = self.layer3(x)
68 |         conv_out.append(x) # 8x (dilated)
69 |         x = self.layer4(x)
70 |         conv_out.append(x) # 8x (dilated)
71 | return conv_out
72 |
--------------------------------------------------------------------------------
/model/backbones/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | """ This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch"""
2 |
3 | import math
4 | import json
5 | from functools import reduce
6 |
7 | import torch
8 | from torch import nn
9 |
10 |
11 | #------------------------------------------------------------------------------
12 | # Useful functions
13 | #------------------------------------------------------------------------------
14 |
15 | def _make_divisible(v, divisor, min_value=None):
16 | if min_value is None:
17 | min_value = divisor
18 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
19 | # Make sure that round down does not go down by more than 10%.
20 | if new_v < 0.9 * v:
21 | new_v += divisor
22 | return new_v
23 |
24 |
25 | def conv_bn(inp, oup, stride, with_norm=True):
26 | if with_norm:
27 | return nn.Sequential(
28 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
29 | nn.BatchNorm2d(oup),
30 | nn.ReLU6(inplace=True)
31 | )
32 | else:
33 | return nn.Sequential(
34 | nn.Conv2d(inp, oup, 3, stride, 1, bias=True),
35 | nn.ReLU6(inplace=True)
36 | )
37 |
38 |
39 | def conv_1x1_bn(inp, oup, with_norm=True):
40 | if with_norm:
41 | return nn.Sequential(
42 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
43 | nn.BatchNorm2d(oup),
44 | nn.ReLU6(inplace=True)
45 | )
46 | else:
47 | return nn.Sequential(
48 | nn.Conv2d(inp, oup, 1, 1, 0, bias=True),
49 | nn.ReLU6(inplace=True)
50 | )
51 |
52 |
53 | #------------------------------------------------------------------------------
54 | # Class of Inverted Residual block
55 | #------------------------------------------------------------------------------
56 |
57 | class InvertedResidual(nn.Module):
58 | def __init__(self, inp, oup, stride, expansion, dilation=1, with_norm=True):
59 | super(InvertedResidual, self).__init__()
60 | self.stride = stride
61 | assert stride in [1, 2]
62 |
63 | hidden_dim = round(inp * expansion)
64 | self.use_res_connect = self.stride == 1 and inp == oup
65 |
66 | if expansion == 1:
67 | if with_norm:
68 | self.conv = nn.Sequential(
69 | # dw
70 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
71 | nn.BatchNorm2d(hidden_dim),
72 | nn.ReLU6(inplace=True),
73 | # pw-linear
74 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
75 | nn.BatchNorm2d(oup),
76 | )
77 | else:
78 | self.conv = nn.Sequential(
79 | # dw
80 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=True),
81 | nn.ReLU6(inplace=True),
82 | # pw-linear
83 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=True),
84 | )
85 | else:
86 | if with_norm:
87 | self.conv = nn.Sequential(
88 | # pw
89 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
90 | nn.BatchNorm2d(hidden_dim),
91 | nn.ReLU6(inplace=True),
92 | # dw
93 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
94 | nn.BatchNorm2d(hidden_dim),
95 | nn.ReLU6(inplace=True),
96 | # pw-linear
97 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
98 | nn.BatchNorm2d(oup),
99 | )
100 | else:
101 | self.conv = nn.Sequential(
102 | # pw
103 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=True),
104 | nn.ReLU6(inplace=True),
105 | # dw
106 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=True),
107 | nn.ReLU6(inplace=True),
108 | # pw-linear
109 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=True),
110 | )
111 |
112 | def forward(self, x):
113 | if self.use_res_connect:
114 | return x + self.conv(x)
115 | else:
116 | return self.conv(x)
117 |
118 |
119 | #------------------------------------------------------------------------------
120 | # Class of MobileNetV2
121 | #------------------------------------------------------------------------------
122 |
123 | class MobileNetV2(nn.Module):
124 | def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000, with_norm=True):
125 | super(MobileNetV2, self).__init__()
126 | self.in_channels = in_channels
127 | self.num_classes = num_classes
128 | input_channel = 32
129 | last_channel = 1280
130 | interverted_residual_setting = [
131 | # t, c, n, s
132 | [1 , 16, 1, 1],
133 | [expansion, 24, 2, 2],
134 | [expansion, 32, 3, 2],
135 | [expansion, 64, 4, 2],
136 | [expansion, 96, 3, 1],
137 | [expansion, 160, 3, 2],
138 | [expansion, 320, 1, 1],
139 | ]
140 |
141 | # building first layer
142 | input_channel = _make_divisible(input_channel*alpha, 8)
143 | self.last_channel = _make_divisible(last_channel*alpha, 8) if alpha > 1.0 else last_channel
144 | self.features = [conv_bn(self.in_channels, input_channel, 2, with_norm=with_norm)]
145 |
146 | # building inverted residual blocks
147 | idx = 1 # [0, 2, 4, 7, 14]
148 | for t, c, n, s in interverted_residual_setting:
149 | output_channel = _make_divisible(int(c*alpha), 8)
150 | for i in range(n):
151 | if i == 0:
152 | self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t, with_norm=with_norm))
153 | else:
154 | self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t, with_norm=with_norm))
155 | idx += 1
156 | input_channel = output_channel
157 |
158 | # building last several layers
159 | self.features.append(conv_1x1_bn(input_channel, self.last_channel, with_norm=with_norm))
160 |
161 | # make it nn.Sequential
162 | self.features = nn.Sequential(*self.features)
163 |
164 | # building classifier
165 | if self.num_classes is not None:
166 | self.classifier = nn.Sequential(
167 | nn.Dropout(0.2),
168 | nn.Linear(self.last_channel, num_classes),
169 | )
170 |
171 | # Initialize weights
172 | self._init_weights()
173 |
174 | def forward(self, x, feature_names=None):
175 | # Stage1
176 | x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x)
177 | # Stage2
178 | x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x)
179 | # Stage3
180 | x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x)
181 | # Stage4
182 | x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x)
183 | # Stage5
184 | x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x)
185 |
186 | # Classification
187 | if self.num_classes is not None:
188 | x = x.mean(dim=(2,3))
189 | x = self.classifier(x)
190 |
191 | # Output
192 | return x
193 |
194 | def _load_pretrained_model(self, pretrained_file):
195 | pretrain_dict = torch.load(pretrained_file, map_location='cpu')
196 | model_dict = {}
197 | state_dict = self.state_dict()
198 | print("[MobileNetV2] Loading pretrained model...")
199 | for k, v in pretrain_dict.items():
200 | if k in state_dict:
201 | model_dict[k] = v
202 | else:
203 | print(k, "is ignored")
204 | state_dict.update(model_dict)
205 | self.load_state_dict(state_dict)
206 |
207 | def _init_weights(self):
208 | for m in self.modules():
209 | if isinstance(m, nn.Conv2d):
210 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
211 | m.weight.data.normal_(0, math.sqrt(2. / n))
212 | if m.bias is not None:
213 | m.bias.data.zero_()
214 | elif isinstance(m, nn.BatchNorm2d):
215 | m.weight.data.fill_(1)
216 | m.bias.data.zero_()
217 | elif isinstance(m, nn.Linear):
218 | n = m.weight.size(1)
219 | m.weight.data.normal_(0, 0.01)
220 | m.bias.data.zero_()
221 |
222 |
223 |
224 | if __name__ == "__main__":
225 | net = MobileNetV2(3)
226 | net.cuda()
227 | inputs = torch.ones((1,3,512,512)).cuda()
228 | outs = net(inputs)
229 |
--------------------------------------------------------------------------------
/model/backbones/mobilenetv3.py:
--------------------------------------------------------------------------------
1 | """
2 | Creates a MobileNetV3 Model as defined in:
3 | Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019).
4 | Searching for MobileNetV3
5 | arXiv preprint arXiv:1905.02244.
6 | """
7 |
8 | import math
9 | from functools import reduce
10 |
11 | import torch
12 | import torch.nn as nn
10 |
11 |
12 | __all__ = ['mobilenetv3_large', 'mobilenetv3_small']
13 |
14 |
15 | def _make_divisible(v, divisor, min_value=None):
16 | """
17 | This function is taken from the original tf repo.
18 | It ensures that all layers have a channel number that is divisible by 8
19 | It can be seen here:
20 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
21 | :param v:
22 | :param divisor:
23 | :param min_value:
24 | :return:
25 | """
26 | if min_value is None:
27 | min_value = divisor
28 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
29 | # Make sure that round down does not go down by more than 10%.
30 | if new_v < 0.9 * v:
31 | new_v += divisor
32 | return new_v
33 |
34 |
35 | class h_sigmoid(nn.Module):
36 | def __init__(self, inplace=True):
37 | super(h_sigmoid, self).__init__()
38 | self.relu = nn.ReLU6(inplace=inplace)
39 |
40 | def forward(self, x):
41 | return self.relu(x + 3) / 6
42 |
43 |
44 | class h_swish(nn.Module):
45 | def __init__(self, inplace=True):
46 | super(h_swish, self).__init__()
47 | self.sigmoid = h_sigmoid(inplace=inplace)
48 |
49 | def forward(self, x):
50 | return x * self.sigmoid(x)
51 |
52 |
53 | class SELayer(nn.Module):
54 | def __init__(self, channel, reduction=4):
55 | super(SELayer, self).__init__()
56 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
57 | self.fc = nn.Sequential(
58 | nn.Linear(channel, _make_divisible(channel // reduction, 8)),
59 | nn.ReLU(inplace=True),
60 | nn.Linear(_make_divisible(channel // reduction, 8), channel),
61 | h_sigmoid()
62 | )
63 |
64 | def forward(self, x):
65 | b, c, _, _ = x.size()
66 | y = self.avg_pool(x).view(b, c)
67 | y = self.fc(y).view(b, c, 1, 1)
68 | return x * y
69 |
70 |
71 | def conv_3x3_bn(inp, oup, stride):
72 | return nn.Sequential(
73 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
74 | nn.BatchNorm2d(oup),
75 | h_swish()
76 | )
77 |
78 |
79 | def conv_1x1_bn(inp, oup):
80 | return nn.Sequential(
81 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
82 | nn.BatchNorm2d(oup),
83 | h_swish()
84 | )
85 |
86 |
87 | class InvertedResidual(nn.Module):
88 | def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
89 | super(InvertedResidual, self).__init__()
90 | assert stride in [1, 2]
91 |
92 | self.identity = stride == 1 and inp == oup
93 |
94 | if inp == hidden_dim:
95 | self.conv = nn.Sequential(
96 | # dw
97 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False),
98 | nn.BatchNorm2d(hidden_dim),
99 | h_swish() if use_hs else nn.ReLU(inplace=True),
100 | # Squeeze-and-Excite
101 | SELayer(hidden_dim) if use_se else nn.Identity(),
102 | # pw-linear
103 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
104 | nn.BatchNorm2d(oup),
105 | )
106 | else:
107 | self.conv = nn.Sequential(
108 | # pw
109 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
110 | nn.BatchNorm2d(hidden_dim),
111 | h_swish() if use_hs else nn.ReLU(inplace=True),
112 | # dw
113 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False),
114 | nn.BatchNorm2d(hidden_dim),
115 | # Squeeze-and-Excite
116 | SELayer(hidden_dim) if use_se else nn.Identity(),
117 | h_swish() if use_hs else nn.ReLU(inplace=True),
118 | # pw-linear
119 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
120 | nn.BatchNorm2d(oup),
121 | )
122 |
123 | def forward(self, x):
124 | if self.identity:
125 | return x + self.conv(x)
126 | else:
127 | return self.conv(x)
128 |
129 |
130 | class MobileNetV3(nn.Module):
131 | def __init__(self, in_channels, mode='large', num_classes=None, width_mult=1., with_norm=True):
132 | super(MobileNetV3, self).__init__()
133 | # setting of inverted residual blocks
134 | cfgs = [
135 | # k, t, c, SE, HS, s
136 | [3, 1, 16, 0, 0, 1],
137 | [3, 4, 24, 0, 0, 2],
138 | [3, 3, 24, 0, 0, 1],
139 | [5, 3, 40, 1, 0, 2],
140 | [5, 3, 40, 1, 0, 1],
141 | [5, 3, 40, 1, 0, 1],
142 | [3, 6, 80, 0, 1, 2],
143 | [3, 2.5, 80, 0, 1, 1],
144 | [3, 2.3, 80, 0, 1, 1],
145 | [3, 2.3, 80, 0, 1, 1],
146 | [3, 6, 112, 1, 1, 1],
147 | [3, 6, 112, 1, 1, 1],
148 | [5, 6, 160, 1, 1, 2],
149 | [5, 6, 160, 1, 1, 1],
150 | [5, 6, 160, 1, 1, 1]
151 | ]
152 | self.cfgs = cfgs
153 | assert mode in ['large', 'small']
154 |
155 | # building first layer
156 | input_channel = _make_divisible(16 * width_mult, 8)
157 | layers = [conv_3x3_bn(3, input_channel, 2)]
158 | # self.features = [conv_3x3_bn(in_channels, input_channel, 2)]
159 | # building inverted residual blocks
160 | block = InvertedResidual
161 | for k, t, c, use_se, use_hs, s in self.cfgs:
162 | output_channel = _make_divisible(c * width_mult, 8)
163 | exp_size = _make_divisible(input_channel * t, 8)
164 | layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
165 | input_channel = output_channel
166 | # building last several layers
167 | # self.conv = conv_1x1_bn(input_channel, exp_size)
168 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
169 | output_channel = {'large': 1280, 'small': 1024}
170 | output_channel = _make_divisible(output_channel[mode] * width_mult, 8) if width_mult > 1.0 else output_channel[mode]
171 | layers.append(conv_1x1_bn(input_channel, output_channel))
172 | self.features = nn.Sequential(*layers)
173 |
174 | self.num_classes = num_classes
175 | if self.num_classes is not None:
176 | self.classifier = nn.Sequential(
177 | nn.Linear(exp_size, output_channel),
178 | h_swish(),
179 | nn.Dropout(0.2),
180 | nn.Linear(output_channel, num_classes),
181 | )
182 |
183 | self._initialize_weights()
184 |
185 | def forward(self, x):
186 | # x = self.features(x)
187 | # x = self.conv(x)
188 | # x = self.avgpool(x)
189 | # x = x.view(x.size(0), -1)
190 | # x = self.classifier(x)
191 |
192 | # Stage1
193 | x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x)
194 | # Stage2
195 | x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x)
196 | # Stage3
197 | x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x)
198 | # Stage4
199 | x = reduce(lambda x, n: self.features[n](x), list(range(7,13)), x)
200 | # Stage5
201 | x = reduce(lambda x, n: self.features[n](x), list(range(13,17)), x)
202 |
203 | # Classification
204 | if self.num_classes is not None:
205 | x = x.mean(dim=(2,3))
206 | x = self.classifier(x)
207 |
208 | # Output
209 | return x
210 |
211 | def _load_pretrained_model(self, pretrained_file):
212 | pretrain_dict = torch.load(pretrained_file, map_location='cpu')
213 | model_dict = {}
214 | state_dict = self.state_dict()
215 |         print("[MobileNetV3] Loading pretrained model...")
216 | for k, v in pretrain_dict.items():
217 | if k in state_dict:
218 | model_dict[k] = v
219 | else:
220 | print(k, "is ignored")
221 | state_dict.update(model_dict)
222 |         self.load_state_dict(state_dict)
224 |
225 | def _initialize_weights(self):
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
229 | m.weight.data.normal_(0, math.sqrt(2. / n))
230 | if m.bias is not None:
231 | m.bias.data.zero_()
232 | elif isinstance(m, nn.BatchNorm2d):
233 | m.weight.data.fill_(1)
234 | m.bias.data.zero_()
235 | elif isinstance(m, nn.Linear):
236 | m.weight.data.normal_(0, 0.01)
237 | m.bias.data.zero_()
238 |
239 |
240 | def mobilenetv3_large(**kwargs):
241 | """
242 | Constructs a MobileNetV3-Large model
243 | """
244 | cfgs = [
245 | # k, t, c, SE, HS, s
246 | [3, 1, 16, 0, 0, 1],
247 | [3, 4, 24, 0, 0, 2],
248 | [3, 3, 24, 0, 0, 1],
249 | [5, 3, 40, 1, 0, 2],
250 | [5, 3, 40, 1, 0, 1],
251 | [5, 3, 40, 1, 0, 1],
252 | [3, 6, 80, 0, 1, 2],
253 | [3, 2.5, 80, 0, 1, 1],
254 | [3, 2.3, 80, 0, 1, 1],
255 | [3, 2.3, 80, 0, 1, 1],
256 | [3, 6, 112, 1, 1, 1],
257 | [3, 6, 112, 1, 1, 1],
258 | [5, 6, 160, 1, 1, 2],
259 | [5, 6, 160, 1, 1, 1],
260 | [5, 6, 160, 1, 1, 1]
261 | ]
262 | return MobileNetV3(cfgs, mode='large', **kwargs)
263 |
264 |
265 | def mobilenetv3_small(**kwargs):
266 | """
267 | Constructs a MobileNetV3-Small model
268 | """
269 | cfgs = [
270 | # k, t, c, SE, HS, s
271 | [3, 1, 16, 1, 0, 2],
272 | [3, 4.5, 24, 0, 0, 2],
273 | [3, 3.67, 24, 0, 0, 1],
274 | [5, 4, 40, 1, 1, 2],
275 | [5, 6, 40, 1, 1, 1],
276 | [5, 6, 40, 1, 1, 1],
277 | [5, 3, 48, 1, 1, 1],
278 | [5, 3, 48, 1, 1, 1],
279 | [5, 6, 96, 1, 1, 2],
280 | [5, 6, 96, 1, 1, 1],
281 | [5, 6, 96, 1, 1, 1],
282 | ]
283 |
284 | return MobileNetV3(cfgs, mode='small', **kwargs)
285 |
286 |
287 | if __name__ == "__main__":
288 | model = MobileNetV3(3)
289 | print(model)
290 | print(len(model.features))
291 |
--------------------------------------------------------------------------------
/model/backbones/resnet_bn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | from torch.nn import BatchNorm2d
6 | # from modules.nn import BatchNorm2d
7 | from collections import OrderedDict
8 |
9 | try:
10 | from torch.hub import load_state_dict_from_url
11 | except ImportError:
12 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
13 |
14 | __all__ = ['ResNet']
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
27 | }
28 |
29 |
30 | def conv3x3(in_planes, out_planes, stride=1):
31 | "3x3 convolution with padding"
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
33 | padding=1, bias=False)
34 |
35 | def conv7x7(in_planes, out_planes, stride=1):
36 |     "7x7 convolution with padding"
37 | return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride,
38 | padding=3, bias=False)
39 |
40 |
41 | class BasicBlock(nn.Module):
42 | expansion = 1
43 |
44 | def __init__(self, inplanes, planes, stride=1, downsample=None):
45 | super(BasicBlock, self).__init__()
46 | self.conv1 = conv3x3(inplanes, planes, stride)
47 | self.bn1 = BatchNorm2d(planes)
48 | self.relu = nn.ReLU(inplace=True)
49 | self.conv2 = conv3x3(planes, planes)
50 | self.bn2 = BatchNorm2d(planes)
51 | self.downsample = downsample
52 | self.stride = stride
53 |
54 | def forward(self, x):
55 | residual = x
56 |
57 | out = self.conv1(x)
58 | out = self.bn1(out)
59 | out = self.relu(out)
60 |
61 | out = self.conv2(out)
62 | out = self.bn2(out)
63 |
64 | if self.downsample is not None:
65 | residual = self.downsample(x)
66 |
67 | out += residual
68 | out = self.relu(out)
69 |
70 | return out
71 |
72 |
73 | class Bottleneck(nn.Module):
74 | expansion = 4
75 |
76 | def __init__(self, inplanes, planes, stride=1, downsample=None):
77 | super(Bottleneck, self).__init__()
78 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
79 | self.bn1 = BatchNorm2d(planes)
80 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
81 | padding=1, bias=False)
82 | self.bn2 = BatchNorm2d(planes, momentum=0.01)
83 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
84 | self.bn3 = BatchNorm2d(planes * 4)
85 | self.relu = nn.ReLU(inplace=True)
86 | self.downsample = downsample
87 | self.stride = stride
88 |
89 | def forward(self, x):
90 | residual = x
91 |
92 | out = self.conv1(x)
93 | out = self.bn1(out)
94 | out = self.relu(out)
95 |
96 | out = self.conv2(out)
97 | out = self.bn2(out)
98 | out = self.relu(out)
99 |
100 | out = self.conv3(out)
101 | out = self.bn3(out)
102 |
103 | if self.downsample is not None:
104 | residual = self.downsample(x)
105 |
106 | out += residual
107 | out = self.relu(out)
108 |
109 | return out
110 |
111 |
112 | class ResNet(nn.Module):
113 |
114 | def __init__(self, block, layers, num_classes=1000, inplanes=128, conv7x7=False):
115 | self.inplanes = inplanes
116 | super(ResNet, self).__init__()
117 | self.conv7x7 = conv7x7
118 | if self.conv7x7:
119 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
120 | self.bn1 = BatchNorm2d(64)
121 | self.relu1 = nn.ReLU(inplace=True)
122 | else:
123 | self.conv1 = conv3x3(3, 64, stride=2)
124 | self.bn1 = BatchNorm2d(64)
125 | self.relu1 = nn.ReLU(inplace=True)
126 | self.conv2 = conv3x3(64, 64)
127 | self.bn2 = BatchNorm2d(64)
128 | self.relu2 = nn.ReLU(inplace=True)
129 | self.conv3 = conv3x3(64, 128)
130 | self.bn3 = BatchNorm2d(128)
131 | self.relu3 = nn.ReLU(inplace=True)
132 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
133 |
134 | self.layer1 = self._make_layer(block, 64, layers[0])
135 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
136 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
137 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
138 | self.avgpool = nn.AvgPool2d(7, stride=1)
139 | self.fc = nn.Linear(512 * block.expansion, num_classes)
140 |
141 | for m in self.modules():
142 | if isinstance(m, nn.Conv2d):
143 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
144 | m.weight.data.normal_(0, math.sqrt(2. / n))
145 | elif isinstance(m, BatchNorm2d):
146 | m.weight.data.fill_(1)
147 | m.bias.data.zero_()
148 |
149 | def _make_layer(self, block, planes, blocks, stride=1):
150 | downsample = None
151 | if stride != 1 or self.inplanes != planes * block.expansion:
152 | downsample = nn.Sequential(
153 | nn.Conv2d(self.inplanes, planes * block.expansion,
154 | kernel_size=1, stride=stride, bias=False),
155 | BatchNorm2d(planes * block.expansion),
156 | )
157 |
158 | layers = []
159 | layers.append(block(self.inplanes, planes, stride, downsample))
160 | self.inplanes = planes * block.expansion
161 | for i in range(1, blocks):
162 | layers.append(block(self.inplanes, planes))
163 |
164 | return nn.Sequential(*layers)
165 |
166 | def forward(self, x):
167 | if self.conv7x7:
168 | x = self.relu1(self.bn1(self.conv1(x)))
169 | else:
170 | x = self.relu1(self.bn1(self.conv1(x)))
171 | x = self.relu2(self.bn2(self.conv2(x)))
172 | x = self.relu3(self.bn3(self.conv3(x)))
173 | x, indices = self.maxpool(x)
174 |
175 | x = self.layer1(x)
176 | x = self.layer2(x)
177 | x = self.layer3(x)
178 | x = self.layer4(x)
179 |
180 | x = self.avgpool(x)
181 | x = x.view(x.size(0), -1)
182 | x = self.fc(x)
183 | return x
184 |
185 |
186 | def l_resnet50(pretrained=False):
187 | """Constructs a ResNet-50 model.
188 | Args:
189 | pretrained (bool): If True, returns a model pre-trained on ImageNet
190 | """
191 | model = ResNet(Bottleneck, [3, 4, 6, 3], inplanes=128, conv7x7=False)
192 | if pretrained:
193 | state_dict = torch.load('pretrained_model/resnet50_v1c.pth')
194 | model.load_state_dict(state_dict, strict=True)
195 | return model
196 |
197 |
198 | if __name__ == "__main__":
199 | model = l_resnet50(pretrained=True)
200 |
--------------------------------------------------------------------------------
/model/backbones/sparse_resnet_bn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import spconv
5 |
6 | from torch.nn import BatchNorm1d
7 | from collections import OrderedDict
8 |
9 | try:
10 | from torch.hub import load_state_dict_from_url
11 | except ImportError:
12 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
13 |
14 | __all__ = ['ResNet']
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
27 | }
28 |
29 |
30 | def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
31 | "3x3 convolution with padding"
32 | return spconv.SubMConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False, indice_key=indice_key)
33 |
34 | def conv7x7(in_planes, out_planes, stride=1, indice_key=None):
35 |     "7x7 convolution with padding"
36 | return spconv.SubMConv2d(in_planes, out_planes, kernel_size=7, stride=stride, padding=3, bias=False, indice_key=indice_key)
37 |
38 |
39 | class BasicBlock(spconv.SparseModule):
40 | expansion = 1
41 |
42 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, padding=1,
43 | first_indice_key=None, middle_indice_key=None, last_indice_key=None):
44 |
45 | super(BasicBlock, self).__init__()
46 | if stride == 2:
47 | self.conv1 = spconv.SparseConv2d(inplanes, planes, 3, stride, dilation=dilation, padding=padding, bias=False, indice_key=middle_indice_key)
48 | else:
49 | self.conv1 = spconv.SubMConv2d(inplanes, planes, 3, stride, dilation=dilation, padding=padding, bias=False, indice_key=middle_indice_key)
50 | self.bn1 = nn.BatchNorm1d(planes)
51 | self.relu1 = nn.ReLU(inplace=True)
52 | self.conv2 = spconv.SubMConv2d(planes, planes, 3, 1, padding=1, indice_key=last_indice_key)
53 | self.bn2 = nn.BatchNorm1d(planes)
54 |
55 | self.relu = nn.ReLU(inplace=True)
56 | self.downsample = downsample
57 | self.stride = stride
58 |
59 | def forward(self, x):
60 | residual = x
61 | out = self.conv1(x)
62 | out.features = self.bn1(out.features)
63 | out.features = self.relu1(out.features)
64 | out = self.conv2(out)
65 | out.features = self.bn2(out.features)
66 | if self.downsample is not None:
67 | residual = self.downsample(x)
68 | out.features = out.features + residual.features
69 | out.features = self.relu(out.features)
70 | return out
71 |
72 |
73 | class Bottleneck(spconv.SparseModule):
74 | expansion = 4
75 |
76 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, padding=1,
77 | first_indice_key=None, middle_indice_key=None, last_indice_key=None):
78 |
79 | super(Bottleneck, self).__init__()
80 | self.conv1 = spconv.SubMConv2d(inplanes, planes, kernel_size=1, bias=False, indice_key=first_indice_key)
81 | self.bn1 = nn.BatchNorm1d(planes)
82 | if stride == 2:
83 | self.conv2 = spconv.SparseConv2d(planes, planes, 3, stride=stride, dilation=dilation, padding=padding, bias=False,
84 | indice_key=middle_indice_key)
85 | else:
86 | self.conv2 = spconv.SubMConv2d(planes, planes, 3, stride=stride, dilation=dilation, padding=padding, bias=False,
87 | indice_key=middle_indice_key)
88 | self.bn2 = nn.BatchNorm1d(planes, momentum=0.01)
89 | self.conv3 = spconv.SubMConv2d(planes, planes * 4, kernel_size=1, bias=False, indice_key=last_indice_key)
90 | self.bn3 = nn.BatchNorm1d(planes * 4)
91 | self.relu = nn.ReLU(inplace=True)
92 | self.downsample = downsample
93 | self.stride = stride
94 |
95 | def forward(self, x):
96 | residual = x
97 |
98 | out = self.conv1(x)
99 | out.features = self.bn1(out.features)
100 | out.features = self.relu(out.features)
101 |
102 | out = self.conv2(out)
103 | out.features = self.bn2(out.features)
104 | out.features = self.relu(out.features)
105 |
106 | out = self.conv3(out)
107 | out.features = self.bn3(out.features)
108 |
109 | if self.downsample is not None:
110 | residual = self.downsample(x)
111 |
112 | out.features = out.features + residual.features
113 | out.features = self.relu(out.features)
114 | return out
115 |
116 |
117 | class SparseResNet18(spconv.SparseModule):
118 |
119 | def __init__(self, inc, stride, block, layers, num_classes=1000, inplanes=128, conv7x7=False):
120 | self.inplanes = inplanes
121 | super(SparseResNet18, self).__init__()
122 |
123 | self.enc_channels = [64, 64, 128, 256, 512]
124 |
125 | self.conv1 = spconv.SubMConv2d(inc, 64, kernel_size=3, padding=1, bias=False, indice_key='subm0s')
126 | self.bn1 = nn.BatchNorm1d(64)
127 | self.relu1 = nn.ReLU(inplace=True)
128 | self.conv2 = spconv.SparseConv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False, indice_key='spconv0')
129 | self.bn2 = nn.BatchNorm1d(64)
130 | self.relu2 = nn.ReLU(inplace=True)
131 | self.conv3 = spconv.SubMConv2d(64, 64, kernel_size=3, padding=1, bias=False, indice_key='subm0e')
132 | self.bn3 = nn.BatchNorm1d(64)
133 | self.relu3 = nn.ReLU(inplace=True)
134 |
135 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2, dilation=1, padding=1,
136 | first_indice_key='subm1s', middle_indice_key='spconv1', last_indice_key='subm1e')
137 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilation=1, padding=1,
138 | first_indice_key='subm2s', middle_indice_key='spconv2', last_indice_key='subm2e')
139 | self.layer3 = self._make_layer(block, 256, layers[2], stride=int(max(1,stride/8)), dilation=1, padding=1,
140 | first_indice_key='subm3s', middle_indice_key='spconv3', last_indice_key='subm3e')
141 | self.layer4 = self._make_layer(block, 512, layers[3], stride=int(max(1,stride/16)), dilation=2, padding=2,
142 | first_indice_key='subm4s', middle_indice_key='spconv4', last_indice_key='subm4e')
143 |
144 | for m in self.modules():
145 | if isinstance(m, nn.Conv2d):
146 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
147 | m.weight.data.normal_(0, math.sqrt(2. / n))
148 | elif isinstance(m, BatchNorm1d):
149 | m.weight.data.fill_(1)
150 | m.bias.data.zero_()
151 |
152 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, padding=1,
153 | first_indice_key=None, middle_indice_key=None, last_indice_key=None):
154 | downsample = None
155 | if stride != 1 or self.inplanes != planes * block.expansion:
156 | if stride == 2:
157 | downsample = spconv.SparseSequential(
158 | spconv.SparseConv2d(self.inplanes, planes * block.expansion, kernel_size=3, stride=stride, padding=1,
159 | bias=False, indice_key=middle_indice_key),
160 | nn.BatchNorm1d(planes * block.expansion),
161 | )
162 | else:
163 | downsample = spconv.SparseSequential(
164 | spconv.SubMConv2d(self.inplanes, planes * block.expansion, kernel_size=3, stride=stride, padding=1,
165 | bias=False, indice_key=middle_indice_key),
166 | nn.BatchNorm1d(planes * block.expansion),
167 | )
168 | layers = []
169 | layers.append(block(self.inplanes, planes, stride, downsample, dilation, padding,
170 | first_indice_key=first_indice_key, middle_indice_key=middle_indice_key, last_indice_key=last_indice_key))
171 | self.inplanes = planes * block.expansion
172 | for i in range(1, blocks):
173 | layers.append(block(self.inplanes, planes))
174 | return spconv.SparseSequential(*layers)
175 |
176 | def forward(self, x):
177 | outs = []
178 | x = self.conv1(x)
179 | x.features = self.relu1(self.bn1(x.features))
180 | x = self.conv2(x)
181 | x.features = self.relu2(self.bn2(x.features))
182 | x = self.conv3(x)
183 | x.features = self.relu3(self.bn3(x.features))
184 | outs.append(x)
185 |
186 | x = self.layer1(x)
187 | outs.append(x)
188 |
189 | x = self.layer2(x)
190 | outs.append(x)
191 |
192 | x = self.layer3(x)
193 | outs.append(x)
194 |
195 | x = self.layer4(x)
196 | outs.append(x)
197 | return outs
198 |
199 |
200 | class SparseResNet(spconv.SparseModule):
201 |
202 | def __init__(self, inc, stride, block, layers, num_classes=1000, inplanes=128, conv7x7=False):
203 | self.inplanes = inplanes
204 | super(SparseResNet, self).__init__()
205 |
206 | self.enc_channels = [128, 256, 512, 1024, 2048]
207 |
208 | self.conv1 = spconv.SubMConv2d(inc, 64, kernel_size=3, padding=1, bias=False, indice_key='subm0s')
209 | self.bn1 = nn.BatchNorm1d(64)
210 | self.relu1 = nn.ReLU(inplace=True)
211 | self.conv2 = spconv.SparseConv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False, indice_key='spconv0')
212 | self.bn2 = nn.BatchNorm1d(64)
213 | self.relu2 = nn.ReLU(inplace=True)
214 | self.conv3 = spconv.SubMConv2d(64, 128, kernel_size=3, padding=1, bias=False, indice_key='subm0e')
215 | self.bn3 = nn.BatchNorm1d(128)
216 | self.relu3 = nn.ReLU(inplace=True)
217 |
218 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2, dilation=1, padding=1,
219 | first_indice_key='subm1s', middle_indice_key='spconv1', last_indice_key='subm1e')
220 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilation=1, padding=1,
221 | first_indice_key='subm2s', middle_indice_key='spconv2', last_indice_key='subm2e')
222 | self.layer3 = self._make_layer(block, 256, layers[2], stride=int(max(1,stride/8)), dilation=1, padding=1,
223 | first_indice_key='subm3s', middle_indice_key='spconv3', last_indice_key='subm3e')
224 | self.layer4 = self._make_layer(block, 512, layers[3], stride=int(max(1,stride/16)), dilation=2, padding=2,
225 | first_indice_key='subm4s', middle_indice_key='spconv4', last_indice_key='subm4e')
226 |
227 | for m in self.modules():
228 | if isinstance(m, nn.Conv2d):
229 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
230 | m.weight.data.normal_(0, math.sqrt(2. / n))
231 | elif isinstance(m, BatchNorm1d):
232 | m.weight.data.fill_(1)
233 | m.bias.data.zero_()
234 |
235 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, padding=1,
236 | first_indice_key=None, middle_indice_key=None, last_indice_key=None):
237 | downsample = None
238 | if stride != 1 or self.inplanes != planes * block.expansion:
239 | if stride == 2:
240 | downsample = spconv.SparseSequential(
241 | spconv.SparseConv2d(self.inplanes, planes * block.expansion, kernel_size=3, stride=stride, padding=1,
242 | bias=False, indice_key=middle_indice_key),
243 | nn.BatchNorm1d(planes * block.expansion),
244 | )
245 | else:
246 | downsample = spconv.SparseSequential(
247 | spconv.SubMConv2d(self.inplanes, planes * block.expansion, kernel_size=3, stride=stride, padding=1,
248 | bias=False, indice_key=middle_indice_key),
249 | nn.BatchNorm1d(planes * block.expansion),
250 | )
251 | layers = []
252 | layers.append(block(self.inplanes, planes, stride, downsample, dilation, padding,
253 | first_indice_key=first_indice_key, middle_indice_key=middle_indice_key, last_indice_key=last_indice_key))
254 | self.inplanes = planes * block.expansion
255 | for i in range(1, blocks):
256 | layers.append(block(self.inplanes, planes))
257 | return spconv.SparseSequential(*layers)
258 |
259 | def forward(self, x):
260 | outs = []
261 | x = self.conv1(x)
262 | x.features = self.relu1(self.bn1(x.features))
263 | x = self.conv2(x)
264 | x.features = self.relu2(self.bn2(x.features))
265 | x = self.conv3(x)
266 | x.features = self.relu3(self.bn3(x.features))
267 | outs.append(x)
268 |
269 | x = self.layer1(x)
270 | outs.append(x)
271 |
272 | x = self.layer2(x)
273 | outs.append(x)
274 |
275 | x = self.layer3(x)
276 | outs.append(x)
277 |
278 | x = self.layer4(x)
279 | outs.append(x)
280 | return outs
281 |
282 |
283 | def l_sparse_resnet18(inc, stride=8, pretrained=False):
284 | """Constructs a ResNet-50 model.
285 | Args:
286 | pretrained (bool): If True, returns a model pre-trained on ImageNet
287 | """
288 | model = SparseResNet18(inc, stride, BasicBlock, [2, 2, 2, 2], inplanes=128, conv7x7=True)
289 | if pretrained:
290 | state_dict = torch.load('pretrained_model/resnet18.pth')
291 | model.load_state_dict(state_dict, strict=True)
292 | return model
293 |
294 |
295 | def l_sparse_resnet50(inc, stride=8, pretrained=False):
296 | """Constructs a ResNet-50 model.
297 | Args:
298 | pretrained (bool): If True, returns a model pre-trained on ImageNet
299 | """
300 | model = SparseResNet(inc, stride, Bottleneck, [3, 4, 6, 3], inplanes=128, conv7x7=False)
301 | if pretrained:
302 | state_dict = torch.load('pretrained_model/resnet50_v1c.pth')
303 | model.load_state_dict(state_dict, strict=True)
304 | return model
305 |
306 |
307 | if __name__ == "__main__":
308 | model = l_sparse_resnet50(inc=4)  # 4 input channels: RGB + coarse alpha, as used by SHM
309 | print(model)
310 |
--------------------------------------------------------------------------------
/model/backbones/wrapper.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import reduce
3 | from collections import OrderedDict
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from model.utils import load_pretrained_weight
9 | from .mobilenetv2 import MobileNetV2
10 | from .mobilenetv3 import MobileNetV3
11 |
12 |
13 | class BaseBackbone(nn.Module):
14 | """ Superclass of Replaceable Backbone Model for Semantic Estimation
15 | """
16 |
17 | def __init__(self, in_channels):
18 | super(BaseBackbone, self).__init__()
19 | self.in_channels = in_channels
20 |
21 | self.model = None
22 | self.enc_channels = []
23 |
24 | def forward(self, x):
25 | raise NotImplementedError
26 |
27 | def load_pretrained_ckpt(self):
28 | raise NotImplementedError
29 |
30 |
31 | class MobileNetV2Backbone(BaseBackbone):
32 | """ MobileNetV2 Backbone
33 | """
34 |
35 | def __init__(self, in_channels, with_norm=True):
36 | super(MobileNetV2Backbone, self).__init__(in_channels)
37 |
38 | self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None, with_norm=with_norm)
39 | self.enc_channels = [16, 24, 32, 96, 1280]
40 |
41 | def forward(self, x):
42 | x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
43 | enc2x = x
44 | x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
45 | enc4x = x
46 | x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
47 | enc8x = x
48 | x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
49 | enc16x = x
50 | x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
51 | enc32x = x
52 | return [enc2x, enc4x, enc8x, enc16x, enc32x]
53 |
54 | def load_pretrained_ckpt(self):
55 | # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
56 | ckpt_path = './pretrained_model/mobilenetv2_human_seg.ckpt'
57 | self.model = load_pretrained_weight(self.model, ckpt_path)
58 | print('load pretrained weight from {} successfully'.format(ckpt_path))
59 |
60 |
61 | class MobileNetV3LargeBackbone(BaseBackbone):
62 | """ MobileNetV2 Backbone
63 | """
64 |
65 | def __init__(self, in_channels, with_norm=True):
66 | super(MobileNetV3LargeBackbone, self).__init__(in_channels)
67 |
68 | self.model = MobileNetV3(self.in_channels, num_classes=None, with_norm=with_norm)
69 | self.enc_channels = [16, 24, 40, 112, 1280]
70 |
71 | def forward(self, x, priors=None):
72 | x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
73 | enc2x = x
74 | x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
75 | enc4x = x
76 | x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
77 | enc8x = x
78 | x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 13)), x)
79 | enc16x = x
80 | x = reduce(lambda x, n: self.model.features[n](x), list(range(13, 17)), x)
81 | enc32x = x
82 | return [enc2x, enc4x, enc8x, enc16x, enc32x]
83 |
--------------------------------------------------------------------------------
/model/lap_pyramid_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def gauss_kernel(size=5, device=torch.device('cpu'), channels=3):
4 | kernel = torch.tensor([[1., 4., 6., 4., 1],
5 | [4., 16., 24., 16., 4.],
6 | [6., 24., 36., 24., 6.],
7 | [4., 16., 24., 16., 4.],
8 | [1., 4., 6., 4., 1.]])
9 | kernel /= 256.
10 | kernel = kernel.repeat(channels, 1, 1, 1)
11 | kernel = kernel.to(device)
12 | return kernel
13 |
14 | def downsample(x):
15 | return x[:, :, ::2, ::2]
16 |
17 | def upsample(x):
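   | # Zero-insertion upsampling: interleave zero rows and zero columns to double the
   | # spatial size, then blur with the Gaussian kernel scaled by 4 so the energy lost
   | # to the inserted zeros is compensated.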
18 | cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
19 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3])
20 | cc = cc.permute(0,1,3,2)
21 | cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2, device=x.device)], dim=3)
22 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2)
23 | x_up = cc.permute(0,1,3,2)
24 | return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1], device=x.device))
25 |
26 | def conv_gauss(img, kernel):
27 | img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
28 | out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
29 | return out
30 |
31 | def laplacian_pyramid(img, kernel, max_levels=3):
32 | current = img
33 | pyr = []
34 | for level in range(max_levels):
35 | filtered = conv_gauss(current, kernel)
36 | down = downsample(filtered)
37 | up = upsample(down)
38 | diff = current-up
39 | pyr.append(diff)
40 | current = down
41 | return pyr
42 |
43 | class LapLoss(torch.nn.Module):
44 | def __init__(self, max_levels=3, channels=3, device=torch.device('cpu')):
45 | super(LapLoss, self).__init__()
46 | self.max_levels = max_levels
47 | self.gauss_kernel = gauss_kernel(channels=channels, device=device)
48 |
49 | def forward(self, input, target):
50 | pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
51 | pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
52 | return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))
53 |
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from model.lap_pyramid_loss import LapLoss
8 |
9 |
10 | def matting_loss(p, d, mask=None, with_lap=False):
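   | # Charbonnier-style L1 between prediction p and ground truth d; when a mask is
   | # given, the loss is averaged only over the masked (unknown-region) pixels, and
   | # an optional Laplacian-pyramid term adds multi-scale detail supervision.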
11 | assert p.shape == d.shape
12 |
13 | if mask is not None:
14 | loss = torch.sqrt((p - d) ** 2 + 1e-10) * mask
15 | loss = loss.sum() / (mask.sum() + 1)
16 | else:
17 | loss = torch.sqrt((p - d) ** 2 + 1e-10)
18 | loss = loss.mean()
19 |
20 | if with_lap:
21 | lap_loss = LapLoss(5, device=torch.device('cuda'))
22 | loss = loss + lap_loss(p, d)
23 | return loss
24 |
25 |
26 | def composition_loss(alpha, img, fg, bg, mask):
27 | comp = alpha * fg + (1. - alpha) * bg
28 | diff = (comp - img) * mask
29 | loss = torch.sqrt(diff ** 2 + 1e-12)
30 | loss = loss.sum() / (mask.sum() + 1.) / 3.
31 | return loss
32 |
33 |
34 | def losses(pred_list, input_dict, alpha_loss_weights=[1.0, 1.0, 1.0, 1.0], with_composition_loss=False, composition_loss_weight=1.0):
35 | label = input_dict['hr_label']
36 | mask = input_dict['hr_unknown']
37 |
38 | loss_dict = {}
39 |
40 | alpha_loss = 0.
41 | for i, pred in enumerate(pred_list):
42 | stride = label.size(2) / pred.size(2)
43 | pred = F.interpolate(pred, scale_factor=stride, mode='bilinear', align_corners=False)
44 | alpha_loss += matting_loss(pred, label, mask, with_lap=True) * alpha_loss_weights[i]
45 | loss_dict['alpha_loss'] = alpha_loss
46 |
47 | if with_composition_loss:
48 | comp_loss = composition_loss(pred_list[-1], input_dict['hr_image'],
49 | input_dict['hr_fg'], input_dict['hr_bg'], mask) * composition_loss_weight
50 | loss_dict['comp_loss'] = comp_loss
51 |
52 | loss = 0.
53 | for k, v in loss_dict.items():
54 | if k.endswith('loss'):
55 | loss += v
56 | loss_dict['loss'] = loss
57 | return loss_dict
58 |
--------------------------------------------------------------------------------
/model/lpn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from model.backbones import MobileNetV2Backbone
6 | from model.utils import upas
7 |
8 |
9 | class IBNorm(nn.Module):
10 | """ Combine Instance Norm and Batch Norm into One Layer
11 | """
12 |
13 | def __init__(self, in_channels):
14 | super(IBNorm, self).__init__()
15 | in_channels = in_channels
16 | self.bnorm_channels = int(in_channels / 2)
17 | self.inorm_channels = in_channels - self.bnorm_channels
18 |
19 | self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
20 | self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
21 |
22 | def forward(self, x):
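   | # The first half of the channels is batch-normalized, the second half is
   | # instance-normalized; a single 1x1 input is duplicated along the batch before
   | # instance normalization and only the first result is kept.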
23 | bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
24 | n, c, h, w = bn_x.shape
25 | if n==1 and h==1 and w==1:
26 | in_x = self.inorm(x[:, self.inorm_channels:, ...].contiguous().expand(n*2, c, h, w).contiguous())[0:1]
27 | else:
28 | in_x = self.inorm(x[:, self.inorm_channels:, ...].contiguous())
29 | return torch.cat((bn_x, in_x), 1)
30 |
31 |
32 | class Conv2dIBNormRelu(nn.Module):
33 | """ Convolution + IBNorm + ReLu
34 | """
35 |
36 | def __init__(self, in_channels, out_channels, kernel_size,
37 | stride=1, padding=0, dilation=1, groups=1, bias=True,
38 | with_ibn=True, with_relu=True):
39 | super(Conv2dIBNormRelu, self).__init__()
40 |
41 | layers = [
42 | nn.Conv2d(in_channels, out_channels, kernel_size,
43 | stride=stride, padding=padding, dilation=dilation,
44 | groups=groups, bias=bias)
45 | ]
46 |
47 | if with_ibn:
48 | layers.append(IBNorm(out_channels))
49 | if with_relu:
50 | layers.append(nn.ReLU(inplace=True))
51 |
52 | self.layers = nn.Sequential(*layers)
53 |
54 | def forward(self, x):
55 | return self.layers(x)
56 |
57 |
58 | class SEBlock(nn.Module):
59 | """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
60 | """
61 |
62 | def __init__(self, in_channels, out_channels, reduction=1):
63 | super(SEBlock, self).__init__()
64 | self.pool = nn.AdaptiveAvgPool2d(1)
65 | self.fc = nn.Sequential(
66 | nn.Linear(in_channels, int(in_channels // reduction), bias=False),
67 | nn.ReLU(inplace=True),
68 | nn.Linear(int(in_channels // reduction), out_channels, bias=False),
69 | nn.Sigmoid()
70 | )
71 |
72 | def forward(self, x):
73 | b, c, _, _ = x.size()
74 | w = self.pool(x).view(b, c)
75 | w = self.fc(w).view(b, c, 1, 1)
76 | return x * w.expand_as(x)
77 |
78 |
79 | class HLBranch(nn.Module):
80 | """ High Resolution Branch of MODNet
81 | """
82 |
83 | def __init__(self, hr_channels, enc_channels, with_norm=True):
84 | super(HLBranch, self).__init__()
85 |
86 | self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
87 |
88 | self.p32x = Conv2dIBNormRelu(enc_channels[4], 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
89 |
90 | self.conv_dec16x = nn.Sequential(
91 | Conv2dIBNormRelu(enc_channels[4]+enc_channels[3], 2*hr_channels, 3, stride=1, padding=1, with_ibn=with_norm),
92 | Conv2dIBNormRelu(2*hr_channels, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm),
93 | )
94 | self.p16x = Conv2dIBNormRelu(hr_channels+1, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
95 |
96 | self.conv_dec8x = nn.Sequential(
97 | Conv2dIBNormRelu(hr_channels + enc_channels[2], 2*hr_channels, 3, stride=1, padding=1, with_ibn=with_norm),
98 | Conv2dIBNormRelu(2*hr_channels, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm),
99 | )
100 | self.p8x = Conv2dIBNormRelu(hr_channels+1, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
101 |
102 | self.conv_dec4x = nn.Sequential(
103 | Conv2dIBNormRelu(hr_channels + enc_channels[1], 2*hr_channels, 3, stride=1, padding=1, with_ibn=with_norm),
104 | Conv2dIBNormRelu(2*hr_channels, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm),
105 | )
106 | self.p4x = Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
107 |
108 | self.conv_dec2x = nn.Sequential(
109 | Conv2dIBNormRelu(hr_channels+enc_channels[0], 2*hr_channels, 3, stride=1, padding=1, with_ibn=with_norm),
110 | Conv2dIBNormRelu(2*hr_channels, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm),
111 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm),
112 | )
113 | self.p2x = Conv2dIBNormRelu(hr_channels+1, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
114 |
115 | self.conv_dec1x = nn.Sequential(
116 | Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1, with_ibn=with_norm),
117 | )
118 | self.p1x = Conv2dIBNormRelu(hr_channels+1, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
119 |
120 | self.p0x = Conv2dIBNormRelu(2, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
121 |
122 | def forward(self, img, enc2x, enc4x, enc8x, enc16x, enc32x, is_training=True):
123 | enc32x = self.se_block(enc32x)
124 | p32x = self.p32x(enc32x)
125 | p32x = upas(p32x, img)
126 |
127 | dec16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
128 | dec16x = self.conv_dec16x(torch.cat((dec16x, enc16x), dim=1))
129 | p16x = self.p16x(torch.cat((dec16x, upas(p32x, dec16x)), dim=1))
130 | p16x = upas(p16x, img)
131 |
132 | dec8x = F.interpolate(dec16x, scale_factor=2, mode='bilinear', align_corners=False)
133 | dec8x = self.conv_dec8x(torch.cat((dec8x, enc8x), dim=1))
134 | p8x = self.p8x(torch.cat((dec8x, upas(p16x, dec8x)), dim=1))
135 | p8x = upas(p8x, img)
136 |
137 | dec4x = F.interpolate(dec8x, scale_factor=2, mode='bilinear', align_corners=False)
138 | dec4x = self.conv_dec4x(torch.cat((dec4x, enc4x), dim=1))
139 | p4x = self.p4x(dec4x)
140 | p4x = upas(p4x, img)
141 |
142 | dec2x = F.interpolate(dec4x, scale_factor=2, mode='bilinear', align_corners=False)
143 | dec2x = self.conv_dec2x(torch.cat((dec2x, enc2x), dim=1))
144 | p2x = self.p2x(torch.cat((dec2x, upas(p4x, dec2x)), dim=1))
145 | p2x = upas(p2x, img)
146 |
147 | dec1x = F.interpolate(dec2x, scale_factor=2, mode='bilinear', align_corners=False)
148 | dec1x = self.conv_dec1x(torch.cat((dec1x, img), dim=1))
149 | p1x = self.p1x(torch.cat((dec1x, upas(p2x, dec1x)), dim=1))
150 |
151 | p0x = self.p0x(torch.cat((p1x, upas(p8x, p1x)), dim=1))
152 |
153 | seg_out = [torch.sigmoid(p) for p in (p8x, p16x, p32x)]
154 | mat_out = [torch.sigmoid(p) for p in (p1x, p2x, p4x)]
155 | fus_out = [torch.sigmoid(p) for p in (p0x,)]
156 | return seg_out, mat_out, fus_out, [dec1x, dec2x, dec4x, dec8x, dec16x]
157 |
158 |
159 | class AuxilaryHead(nn.Module):
160 | def __init__(self, hr_channels, enc_channels):
161 | super().__init__()
162 |
163 | self.s1 = Conv2dIBNormRelu(
164 | hr_channels, 3, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
165 | self.s2 = Conv2dIBNormRelu(
166 | hr_channels, 3, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
167 | self.s4 = Conv2dIBNormRelu(
168 | hr_channels, 3, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
169 | self.s8 = Conv2dIBNormRelu(
170 | hr_channels, 3, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
171 | self.s16 = Conv2dIBNormRelu(
172 | hr_channels, 3, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False)
173 |
174 | def forward(self, dec1x, dec2x, dec4x, dec8x, dec16x, is_training=True):
175 | p1 = self.s1(dec1x)
176 |
177 | x2 = self.s2(dec2x)
178 | p2 = F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=False)
179 |
180 | x4 = self.s4(dec4x)
181 | p4 = F.interpolate(x4, scale_factor=4, mode='bilinear', align_corners=False)
182 |
183 | x8 = self.s8(dec8x)
184 | p8 = F.interpolate(x8, scale_factor=8, mode='bilinear', align_corners=False)
185 |
186 | x16 = self.s16(dec16x)
187 | p16 = F.interpolate(x16, scale_factor=16, mode='bilinear', align_corners=False)
188 |
189 | return (p1,p2,p4,p8,p16)
190 |
191 |
192 | class LPN(nn.Module):
193 | def __init__(self, in_chn=3, mid_chn=128):
194 | super().__init__()
195 | self.backbone = MobileNetV2Backbone(in_chn)
196 | self.decoder = HLBranch(mid_chn, self.backbone.enc_channels)
197 | self.aux_head = AuxilaryHead(mid_chn, self.backbone.enc_channels)
198 |
199 | def forward(self, images):
200 | enc2x, enc4x, enc8x, enc16x, enc32x = self.backbone(images)
201 | seg_outs, mat_outs, fus_outs, decoded_feats = self.decoder(images, enc2x, enc4x, enc8x, enc16x, enc32x)
202 | return fus_outs[0], decoded_feats[-1]
203 |
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from model.utils import upas, batch_slice
9 | from model.lpn import LPN
10 | from model.shm import SHM
11 |
12 |
13 | class SparseMat(nn.Module):
14 | def __init__(self, cfg):
15 | super(SparseMat, self).__init__()
16 | self.cfg = cfg
17 | in_ch = cfg.model.in_channel
18 | hr_ch = cfg.model.hr_channel
19 | self.lpn = LPN(in_ch, hr_ch)
20 | self.shm = SHM(inc=4)
21 | self.stride = cfg.model.dilation_kernel
22 | self.dilate_op = nn.MaxPool2d(self.stride, stride=1, padding=self.stride//2)
23 | self.max_n_pixel = cfg.model.max_n_pixel
24 | self.cascade = cfg.test.cascade
25 |
26 | @torch.no_grad()
27 | def generate_sparse_inputs(self, img, lr_pred, mask):
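   | # Keep only the pixels selected by `mask`: concatenate the image with the
   | # renormalized low-res prediction, gather the masked positions, and return their
   | # per-pixel features together with (batch, y, x) indices in the layout expected by spconv.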
28 | lr_pred = (lr_pred - 0.5) / 0.5
29 | x = torch.cat((img, lr_pred), dim=1)
30 | indices = torch.where(mask.squeeze(1)>0)
31 | x = x.permute(0,2,3,1)
32 | x = x[indices]
33 | indices = torch.stack(indices, dim=1)
34 | return x, indices
35 |
36 | def dilate(self, alpha, stride=15):
37 | mask = torch.logical_and(alpha>0.01, alpha<0.99).float()
38 | mask = self.dilate_op(mask)
39 | return mask
40 |
41 | def forward(self, input_dict):
42 | xlr = input_dict['lr_image']
43 | xhr = input_dict['hr_image']
44 |
45 | lr_pred, ctx = self.lpn(xlr)
46 | lr_pred = lr_pred.clone().detach()
47 | ctx = ctx.clone().detach()
48 |
49 | lr_pred = batch_slice(lr_pred, input_dict['pos'], xhr.size()[2:])
50 | lr_pred = upas(lr_pred, xhr)
51 | if 'hr_unknown' in input_dict:
52 | mask = input_dict['hr_unknown']
53 | else:
54 | mask = self.dilate(lr_pred)
55 |
56 | sparse_inputs, coords = self.generate_sparse_inputs(xhr, lr_pred, mask=mask)
57 | pred_list = self.shm(sparse_inputs, lr_pred, coords, xhr.size(0), mask.size()[2:], ctx=ctx)
58 | return pred_list
59 |
60 | def generate_sparsity_map(self, lr_pred, curr_img, last_img):
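   | # Spatial sparsity: dilated band where the low-res alpha is uncertain (0.01-0.99).
   | # Temporal sparsity: pixels whose local neighborhood changed w.r.t. the previous
   | # frame; unchanged pixels can reuse the previous prediction. The final mask is the
   | # (dilated) intersection of the two.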
61 | mask_s = self.dilate(lr_pred)
62 | if last_img is not None:
63 | diff = (curr_img - last_img).abs().mean(dim=1, keepdim=True)
64 | shared = torch.logical_and(
65 | F.conv2d(diff, torch.ones(1,1,9,9,device=diff.device), padding=4) < 0.05,
66 | F.conv2d(diff, torch.ones(1,1,1,1,device=diff.device), padding=0) < 0.001,
67 | ).float()
68 | mask_t = self.dilate_op(1 - shared)
69 | mask = mask_s * mask_t
70 | mask = self.dilate_op(mask)
71 | else:
72 | shared = torch.zeros_like(mask_s)
73 | mask_t = torch.ones_like(mask_s)
74 | mask = mask_s * mask_t
75 | return mask, mask_s, mask_t, shared
76 |
77 | def inference(self, hr_img, lr_img=None, last_img=None, last_pred=None):
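   | # Single-image / per-frame inference: run the LPN at low resolution, build the
   | # sparsity map, and refine with the SHM when the number of active pixels fits the
   | # memory budget (max_n_pixel); otherwise fall back to cascaded multi-scale
   | # refinement or to rescaling the input so the active set fits.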
78 | h, w = hr_img.shape[-2:]
79 |
80 | if lr_img is None:
81 | nh = 512. / min(h,w) * h
82 | nh = math.ceil(nh / 32) * 32
83 | nw = 512. / min(h,w) * w
84 | nw = math.ceil(nw / 32) * 32
85 | lr_img = F.interpolate(hr_img, (int(nh), int(nw)), mode="bilinear")
86 |
87 | lr_pred, ctx = self.lpn(lr_img)
88 | lr_pred_us = upas(lr_pred, hr_img)
89 | mask, mask_s, mask_t, shared = self.generate_sparsity_map(lr_pred_us, hr_img, last_img)
90 | n_pixel = mask.sum().item()
91 |
92 | if n_pixel <= self.max_n_pixel:
93 | sparse_inputs, coords = self.generate_sparse_inputs(hr_img, lr_pred_us, mask)
94 | preds = self.shm(sparse_inputs, lr_pred_us, coords, hr_img.size(0), mask.size()[2:], ctx=ctx)
95 | hr_pred_sp = preds[-1]
96 | if last_pred is not None:
97 | hr_pred = hr_pred_sp * mask + lr_pred_us * (1-mask) * (1-shared) + last_pred * (1-mask) * shared
98 | else:
99 | hr_pred = hr_pred_sp * mask + lr_pred_us * (1-mask)
100 | elif self.cascade:
101 | print("Cascading is used.")
102 | for scale in [0.25, 0.5, 1.0]:
103 | hr_img_ds = F.interpolate(hr_img, None, scale_factor=scale, mode="bilinear")
104 | lr_pred_us = upas(lr_pred, hr_img_ds)
105 | mask_s = self.dilate(lr_pred_us)
106 | if mask_s.sum() > self.max_n_pixel:
107 | break
108 | sparse_inputs, coords = self.generate_sparse_inputs(hr_img_ds, lr_pred_us, mask_s)
109 | preds = self.shm(sparse_inputs, lr_pred_us, coords, hr_img_ds.size(0), mask_s.size()[2:], ctx=ctx)
110 | hr_pred_sp = preds[-1]
111 | hr_pred = hr_pred_sp * mask_s + lr_pred_us * (1-mask_s)
112 | lr_pred = hr_pred
113 | else:
114 | print("Rescaling is used.")
115 | scale = math.sqrt(self.max_n_pixel / float(n_pixel))
116 | nh = int(scale * h)
117 | nw = int(scale * w)
118 | nh = math.ceil(nh / 8) * 8
119 | nw = math.ceil(nw / 8) * 8
120 |
121 | hr_img_ds = F.interpolate(hr_img, (nh, nw), mode="bilinear")
122 | lr_pred_us = upas(lr_pred, hr_img_ds)
123 | mask_s = self.dilate(lr_pred_us)
124 |
125 | sparse_inputs, coords = self.generate_sparse_inputs(hr_img_ds, lr_pred_us, mask_s)
126 | preds = self.shm(sparse_inputs, lr_pred_us, coords, hr_img_ds.size(0), mask_s.size()[2:], ctx=ctx)
127 | hr_pred_sp = preds[-1]
128 | hr_pred = hr_pred_sp * mask_s + lr_pred_us * (1-mask_s)
129 | return hr_pred
130 |
--------------------------------------------------------------------------------
/model/shm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import spconv
5 |
6 | from model.backbones.sparse_resnet_bn import l_sparse_resnet18
7 | from model.lpn import Conv2dIBNormRelu
8 |
9 |
10 | class SparseCAM(nn.Module):
11 | def __init__(self, local_inc, global_inc, with_norm=True):
12 | super(SparseCAM, self).__init__()
13 |
14 | self.pool_fg = nn.AdaptiveAvgPool2d(1)
15 | self.pool_bg = nn.AdaptiveAvgPool2d(1)
16 | self.conv_f = Conv2dIBNormRelu(global_inc, global_inc, kernel_size=1, with_ibn=False)
17 | self.conv_b = Conv2dIBNormRelu(global_inc, global_inc, kernel_size=1, with_ibn=False)
18 | self.conv_g = Conv2dIBNormRelu(2*global_inc, local_inc, kernel_size=1, with_relu=False, with_ibn=False)
19 |
20 | def forward(self, idx, x, ctx, mask):
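   | # Contextual attention: pool the LPN context separately over predicted foreground
   | # and background, turn the pooled vectors into a per-sample channel gate, and
   | # scale each sparse feature row by the gate of the sample it belongs to.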
21 | mask_lr = F.interpolate(mask, ctx.size()[2:], align_corners=False, mode='bilinear')
22 | fg_pool = self.pool_fg(ctx * mask_lr)
23 | fg_ctx = self.conv_f(fg_pool)
24 | bg_pool = self.pool_bg(ctx * (1-mask_lr))
25 | bg_ctx = self.conv_b(bg_pool)
26 | weight = torch.sigmoid(self.conv_g(torch.cat([fg_ctx, bg_ctx], dim=1))).squeeze(3).squeeze(2)
27 | sparse_weight = weight[x.indices[:,0].long()]
28 | x.features = x.features * sparse_weight
29 | return x
30 |
31 |
32 | class SparseDecoder3_18(spconv.SparseModule):
33 | def __init__(self, inc=512):
34 | super(SparseDecoder3_18, self).__init__()
35 |
36 | # upconv modules
37 | self.conv_up1 = spconv.SparseSequential(
38 | spconv.SparseInverseConv2d(inc, 256, kernel_size=3, bias=True, indice_key='spconv2'),
39 | nn.BatchNorm1d(256),
40 | nn.LeakyReLU(),
41 | )
42 |
43 | self.conv_up2 = spconv.SparseSequential(
44 | spconv.SparseInverseConv2d(256 + 64, 256, kernel_size=3, bias=True, indice_key='spconv1'),
45 | nn.BatchNorm1d(256),
46 | nn.LeakyReLU(),
47 | )
48 |
49 | self.conv_up3 = spconv.SparseSequential(
50 | spconv.SparseInverseConv2d(256 + 64, 64, kernel_size=3, bias=True, indice_key='spconv0'),
51 | nn.BatchNorm1d(64),
52 | nn.LeakyReLU(),
53 | )
54 |
55 | chn = 64 + 3
56 |
57 | self.conv_up4_alpha = spconv.SparseSequential(
58 | spconv.SubMConv2d(chn, 32, kernel_size=3, padding=1, bias=True, indice_key='subm0s'),
59 | nn.LeakyReLU(),
60 | spconv.SubMConv2d(32, 16, kernel_size=3, padding=1, bias=True, indice_key='subm0s'),
61 | nn.LeakyReLU(),
62 | spconv.SubMConv2d(16, 1, kernel_size=1, padding=0, bias=False, indice_key='subm0s')
63 | )
64 |
65 | self.conv_p8x = spconv.SubMConv2d(256, 1, kernel_size=1, padding=0, bias=False, indice_key='spconv2')
66 | self.conv_p4x = spconv.SubMConv2d(256, 1, kernel_size=1, padding=0, bias=False, indice_key='spconv1')
67 | self.conv_p2x = spconv.SubMConv2d(64, 1, kernel_size=1, padding=0, bias=False, indice_key='spconv0')
68 |
69 | def forward(self, img, conv_out, coarse=None, is_training=True):
70 | x1, x2, x3, x4, x5 = conv_out
71 |
72 | dec4x = self.conv_up1(x5)
73 | p4x = self.conv_p8x(dec4x)
74 |
75 | dec4x.features = torch.cat((dec4x.features, x2.features), 1)
76 | dec2x = self.conv_up2(dec4x)
77 | p2x = self.conv_p4x(dec2x)
78 |
79 | dec2x.features = torch.cat((dec2x.features, x1.features), 1)
80 | dec1x = self.conv_up3(dec2x)
81 | p1x = self.conv_p2x(dec1x)
82 |
83 | img.features = img.features[:,:3] * 0.5 + 0.5
84 | dec1x.features = torch.cat((dec1x.features, img.features),1)
85 | p0x = self.conv_up4_alpha(dec1x)
86 |
87 | raws = [p4x.dense(), p2x.dense(), p1x.dense(), p0x.dense()]
88 | p4x.features = torch.sigmoid(p4x.features)
89 | p2x.features = torch.sigmoid(p2x.features)
90 | p1x.features = torch.sigmoid(p1x.features)
91 | p0x.features = torch.sigmoid(p0x.features)
92 | outs = [p4x.dense(), p2x.dense(), p1x.dense(), p0x.dense()]
93 | return outs
94 |
95 |
96 | class SHM(nn.Module):
97 | def __init__(self, inc=4):
98 | super(SHM, self).__init__()
99 |
100 | self.ctx = SparseCAM(512, 32, with_norm=True)
101 | self.backbone = l_sparse_resnet18(inc, stride=8)
102 | self.decoder = SparseDecoder3_18()
103 |
104 | def forward(self, inputs, lr_pred, coords, batch_size, spatial_shape, ctx):
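    | # Pack the gathered pixel features and their (batch, y, x) coordinates into a
    | # full-resolution SparseConvTensor, encode them with the sparse ResNet, gate the
    | # deepest features with the global context, and decode multi-scale alpha mattes.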
105 | x = spconv.SparseConvTensor(inputs, coords.int(), spatial_shape, batch_size)
106 | encoded_feats = self.backbone(x)
107 | encoded_feats[-1] = self.ctx(coords.int(), encoded_feats[-1], ctx, lr_pred)
108 | outs = self.decoder(x, encoded_feats)
109 | return outs
110 |
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import OrderedDict
3 | from scipy.ndimage import morphology
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | def _init_conv(conv):
11 | nn.init.xavier_uniform_(conv.weight)
12 | if conv.bias is not None:
13 | nn.init.constant_(conv.bias, 0)
14 |
15 |
16 | def _init_norm(norm):
17 | if norm.weight is not None:
18 | nn.init.constant_(norm.weight, 1)
19 | nn.init.constant_(norm.bias, 0)
20 |
21 |
22 | def _generate_random_trimap(x, dist=(1,30), is_training=True):
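   | # Build a trimap from an alpha matte: foreground where alpha > 0.999, plus an
   | # unknown band obtained by dilating the transition region with a distance-transform
   | # threshold (random within `dist` during training, its midpoint at evaluation).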
23 | fg = (x>0.999).type(torch.float)
24 | un = (x>=0.001).type(torch.float) - fg
25 | un_np = (un*255).squeeze(1).data.cpu().numpy().astype(np.uint8)
26 | if is_training:
27 | thresh = np.random.randint(dist[0], dist[1])
28 | else:
29 | thresh = (dist[0] + dist[1]) // 2
30 | un_np = [(morphology.distance_transform_edt(item==0) <= thresh) for item in un_np]
31 | un_np = np.array(un_np)
32 | un = torch.from_numpy(un_np).unsqueeze(1).to(x.device)
33 | trimap = fg
34 | trimap[un>0] = 0.5
35 | return trimap
36 |
37 |
38 | def _make_divisible(v, divisor, min_value=None):
39 | """
40 | This function is taken from the original tf repo.
41 | It ensures that all layers have a channel number that is divisible by 8
42 | It can be seen here:
43 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
44 | """
45 | if min_value is None:
46 | min_value = divisor
47 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
48 | # Make sure that round down does not go down by more than 10%.
49 | if new_v < 0.9 * v:
50 | new_v += divisor
51 | return new_v
52 |
53 |
54 | ## upsample tensor 'src' to the same spatial size as tensor 'tar'
55 | def _upsample_like(src,tar,mode='bilinear'):
56 | src = F.interpolate(src,size=tar.shape[2:],mode=mode,align_corners=False if mode=='bilinear' else None)
57 | return src
58 | upas = _upsample_like
59 |
60 |
61 | def batch_slice(tensor, pos, size, mode='bilinear'):
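   | # Crop each sample of the batch with its normalized box `pos` = (x1, y1, x2, y2)
   | # in [0, 1], then resize every crop to `size` and re-stack along the batch.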
62 | n, c, h, w = tensor.shape
63 | patchs = []
64 | for i in range(n):
65 | # x1, y1, x2, y2 = torch.clamp(pos[i], 0, 1)
66 | x1, y1, x2, y2 = pos[i]
67 | x1 = int(x1.item() * w)
68 | y1 = int(y1.item() * h)
69 | x2 = int(x2.item() * w)
70 | y2 = int(y2.item() * h)
71 | patch = tensor[i:i+1, :, y1:y2, x1:x2].contiguous()
72 | patch = F.interpolate(patch, (size[0], size[1]), align_corners=False if mode=='bilinear' else None, mode=mode)
73 | patchs.append(patch)
74 | return torch.cat(patchs, dim=0)
75 |
76 |
77 | def hard_sigmoid(x, inplace: bool = False):
78 | if inplace:
79 | return x.add_(3.).clamp_(0., 6.).div_(6.)
80 | else:
81 | return F.relu6(x + 3.) / 6.
82 |
83 |
84 | ## copy the overlapping weights of source tensor ws into a new tensor shaped like wd
85 | def copy_weight(ws, wd):
86 |
87 | assert len(ws.shape)==4 or len(ws.shape)==1
88 |
89 | if len(ws.shape) == 4 and ws.shape[2]==ws.shape[3] and ws.shape[3]<=7:
90 | cout1, cin1, kh, kw = ws.shape
91 | cout2, cin2, kh, kw = wd.shape
92 | weight = torch.zeros((cout2, cin2, kh, kw)).float().to(ws.device)
93 | cout3 = min(cout1, cout2)
94 | cin3 = min(cin1, cin2)
95 | weight[:cout3, :cin3] = ws[:cout3, :cin3]
96 | elif len(ws.shape) == 4:
97 | kh, kw, cin1, cout1 = ws.shape # (3,3,4,64)
98 | kh, kw, cin2, cout2 = wd.shape
99 | print(ws.shape, wd.shape)
100 | weight = torch.zeros((kh, kw, cin2, cout2)).float().to(ws.device)
101 | cout3 = min(cout1, cout2)
102 | cin3 = min(cin1, cin2)
103 | weight[:, :, :cin3, :cout3] = ws[:, :, :cin3, :cout3]
104 | else:
105 | cout1, = ws.shape
106 | cout2, = wd.shape
107 | cout3 = min(cout1, cout2)
108 | weight = torch.zeros((cout3,)).float().to(ws.device)
109 | weight[:cout3] = ws[:cout3]
110 | return weight
111 |
112 |
113 | ## only works for models with same architecture
114 | def load_pretrained_weight(model, ckpt_path, copy=True):
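    | # Keep only checkpoint tensors whose names also exist in the model; shape-matched
    | # tensors are loaded as-is, mismatched ones are partially copied through
    | # copy_weight, and the final load is non-strict.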
115 | ckpt = torch.load(ckpt_path)
116 | filtered_ckpt = OrderedDict()
117 | model_ckpt = model.state_dict()
118 | for k,v in ckpt.items():
119 | if k in model_ckpt:
120 | if v.shape==model_ckpt[k].shape:
121 | filtered_ckpt[k] = v
122 | elif copy:
123 | filtered_ckpt[k] = copy_weight(v, model_ckpt[k])
124 | model.load_state_dict(filtered_ckpt, strict=False)
125 | return model
126 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import cv2
5 | from collections import OrderedDict
6 | from torchvision import transforms
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.utils.data import Dataset, DataLoader
12 |
13 | from model import SparseMat
14 | from utils import load_config, get_logger
15 | from datasets import RescaleT, ToTensor, CustomDataset
16 |
17 |
18 | def get_timestamp():
19 | from datetime import datetime
20 | now = datetime.now()
21 | dt_string = now.strftime("%Y-%m-%d-%H-%M-%S")
22 | return dt_string
23 |
24 |
25 | def load_checkpoint(net, pretrained_model, logger):
26 | net_state_dict = net.state_dict()
27 | state_dict = torch.load(pretrained_model)
28 | if 'state_dict' in state_dict:
29 | state_dict = state_dict['state_dict']
30 | elif 'model_state_dict' in state_dict:
31 | state_dict = state_dict['model_state_dict']
32 |
33 | filtered_state_dict = OrderedDict()
34 | for k,v in state_dict.items():
35 | if k.startswith('module'):
36 | nk = '.'.join(k.split('.')[1:])
37 | else:
38 | nk = k
39 | filtered_state_dict[nk] = v
40 | net.load_state_dict(filtered_state_dict)
41 | logger.info('load pretrained weight from {} successfully'.format(pretrained_model))
42 |
43 |
44 | def load_test_filelist(test_data_path):
45 | test_images = []
46 | test_labels = []
47 | for line in open(test_data_path).read().splitlines():
48 | splits = line.split(',')
49 | img_path, mat_path = splits
50 | test_labels.append(mat_path)
51 | test_images.append(img_path)
52 | return test_images, test_labels
53 |
54 |
55 | def compute_metrics(pred, gt):
56 | assert pred.size(0)==1 and pred.size(1)==1
57 | if pred.shape[2:] != gt.shape[2:]:
58 | pred = F.interpolate(pred, gt.shape[2:], mode='bilinear', align_corners=False)
59 | mad = (pred-gt).abs().mean()
60 | mse = ((pred-gt)**2).mean()
61 | return mse, mad
62 |
63 |
64 | def save_preds(pred, save_dir, filename):
65 | os.makedirs(save_dir, exist_ok=True)
66 | pred = pred.squeeze().data.cpu().numpy() * 255
67 | imgname = filename.split('/')[-1].split('.')[0] + '.png'
68 | cv2.imwrite(os.path.join(save_dir, imgname), pred)
69 |
70 |
71 | def test(cfg, net, dataloader, filenames, logger):
72 | net.eval()
73 |
74 | mse_list = []
75 | mad_list = []
76 |
77 | with torch.no_grad():
78 | for i, data in enumerate(dataloader):
79 | input_dict = {}
80 | for k, v in data.items():
81 | input_dict[k] = v.cuda()
82 |
83 | pred = net.inference(input_dict['hr_image'])
84 | origin_h = input_dict['origin_h']
85 | origin_w = input_dict['origin_w']
86 | pred = F.interpolate(pred, (origin_h, origin_w), align_corners=False, mode="bilinear")
87 |
88 | if cfg.test.save:
89 | save_preds(pred, cfg.test.save_dir, filenames[i])
90 |
91 | gt = input_dict['hr_label']
92 | mse, mad = compute_metrics(pred, gt)
93 | mse_list.append(mse.item())
94 | mad_list.append(mad.item())
95 |
96 | logger.info('[ith:%d/%d] mad:%.5f mse:%.5f' % (i, len(dataloader), mad.item(), mse.item()))
97 |
98 | avg_mad = np.array(mad_list).mean()
99 | avg_mse = np.array(mse_list).mean()
100 | logger.info('avg_mad:%.5f avg_mse:%.5f' % (avg_mad.item(), avg_mse.item()))
101 |
102 |
103 | def main():
104 | parser = argparse.ArgumentParser(description='HM')
105 | parser.add_argument('--local_rank', type=int, default=0)
106 | parser.add_argument('--dist', action='store_true', help='use distributed training')
107 | parser.add_argument('-c', '--config', type=str, metavar='FILE', help='path to config file')
108 | parser.add_argument('-p', '--phase', default="train", type=str, metavar='PHASE', help='train or test')
109 |
110 | args = parser.parse_args()
111 | cfg = load_config(args.config)
112 | device_ids = range(torch.cuda.device_count())
113 |
114 | dataset = cfg.data.dataset
115 | model_name = cfg.model.arch
116 | exp_name = args.config.split('/')[-1].split('.')[0]
117 | timestamp = get_timestamp()
118 |
119 | cfg.log.log_dir = os.path.join(os.getcwd(), 'log', model_name, dataset, exp_name+os.sep)
120 | cfg.log.log_path = os.path.join(cfg.log.log_dir, "log_eval.txt")
121 | os.makedirs(cfg.log.log_dir, exist_ok=True)
122 |
123 | if cfg.test.save_dir is None:
124 | cfg.test.save_dir = os.path.join(cfg.log.log_dir, 'vis')
125 | os.makedirs(cfg.test.save_dir, exist_ok=True)
126 |
127 | logger = get_logger(cfg.log.log_path)
128 | logger.info('[LogPath] {}'.format(cfg.log.log_dir))
129 |
130 | test_images, test_labels = load_test_filelist(cfg.data.filelist_test)
131 |
132 | test_transform = transforms.Compose([
133 | RescaleT(cfg),
134 | ToTensor(cfg)
135 | ])
136 |
137 | test_dataset = CustomDataset(
138 | cfg,
139 | is_training=False,
140 | img_name_list=test_images,
141 | lbl_name_list=test_labels,
142 | transform=test_transform
143 | )
144 |
145 | test_dataloader = DataLoader(
146 | test_dataset,
147 | batch_size=cfg.test.batch_size,
148 | shuffle=False,
149 | pin_memory=True,
150 | num_workers=cfg.test.num_workers
151 | )
152 |
153 | net = SparseMat(cfg)
154 |
155 | if torch.cuda.is_available():
156 | net.cuda()
157 | else:
158 | exit()
159 |
160 | load_checkpoint(net, cfg.test.checkpoint, logger)
161 | test(cfg, net, test_dataloader, test_images, logger)
162 |
163 |
164 | if __name__ == "__main__":
165 | main()
166 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import numpy as np
5 | import cv2
6 | from functools import partial
7 | from collections import OrderedDict
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | import torchvision
14 |
15 | from torch.utils.data import Dataset, DataLoader
16 | from torch.utils.tensorboard import SummaryWriter
17 | from torchvision import transforms
18 |
19 | from datasets import Rescale, RescaleT, RandomFlip, RandomCrop, ToTensor, CustomDataset
20 | from model import SparseMat, losses
21 | from utils import load_config, grid_images, get_logger
22 |
23 |
24 | def get_timestamp():
25 | from datetime import datetime
26 | now = datetime.now()
27 | dt_string = now.strftime("%Y-%m-%d-%H-%M-%S")
28 | return dt_string
29 |
30 |
31 | def adjust_learning_rate(optimizer, epoch, epoch_decay, init_lr, min_lr=1e-6):
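   | # Step decay: scale the base learning rate by 0.1 every `epoch_decay` epochs,
   | # never going below `min_lr`.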
32 | for param_group in optimizer.param_groups:
33 | lr = max(init_lr * (0.1 ** (epoch // epoch_decay)), min_lr)
34 | param_group['lr'] = lr
35 |
36 |
37 | def load_checkpoint(net, pretrained_model, logger):
38 | net_state_dict = net.state_dict()
39 | state_dict = torch.load(pretrained_model)
40 | if 'state_dict' in state_dict:
41 | state_dict = state_dict['state_dict']
42 | elif 'model_state_dict' in state_dict:
43 | state_dict = state_dict['model_state_dict']
44 |
45 | filtered_state_dict = OrderedDict()
46 | for k,v in state_dict.items():
47 | if k.startswith('module'):
48 | nk = '.'.join(k.split('.')[1:])
49 | else:
50 | nk = k
51 | filtered_state_dict[nk] = v
52 | net.load_state_dict(filtered_state_dict, strict=False)
53 | logger.info('load pretrained weight from {} successfully'.format(pretrained_model))
54 |
55 |
56 | def save_checkpoint(cfg, net, optimizer, epoch, iterations, running_loss, best_mad, is_best=False):
57 | state_dict = {
58 | 'state_dict': net.state_dict(),
59 | 'optimizer': optimizer.state_dict(),
60 | 'epoch': epoch,
61 | 'iteration': iterations + 1,
62 | 'running_loss': running_loss,
63 | 'best_mad': best_mad,
64 | }
65 | save_path = os.path.join(cfg.log.log_dir, "ckpt_e{}.pth".format(epoch))
66 | torch.save(state_dict, save_path)
67 |
68 | latest_path = os.path.join(cfg.log.log_dir, "ckpt_latest.pth")
69 | shutil.copy(save_path, latest_path)
70 |
71 | if is_best:
72 | best_path = os.path.join(cfg.log.log_dir, "ckpt_best.pth")
73 | shutil.copy(save_path, best_path)
74 |
75 |
76 | def save_preds(pred, save_dir, filename):
77 | os.makedirs(save_dir, exist_ok=True)
78 | pred = pred.squeeze().data.cpu().numpy() * 255
79 | imgname = filename.split('/')[-1].split('.')[0] + '.png'
80 | cv2.imwrite(os.path.join(save_dir, imgname), pred)
81 |
82 |
83 | def load_filelist(data_path):
84 | images = []
85 | labels = []
86 | fgs = []
87 | bgs = []
88 | for line in open(data_path).read().splitlines():
89 | splits = line.split(',')
90 | if len(splits) == 4:
91 | img_path, lbl_path, fg_path, bg_path = splits
92 | images.append(img_path)
93 | labels.append(lbl_path)
94 | fgs.append(fg_path)
95 | bgs.append(bg_path)
96 | else:
97 | img_path, lbl_path = splits
98 | images.append(img_path)
99 | labels.append(lbl_path)
100 | return images, labels, fgs, bgs
101 |
102 |
103 | def compute_metrics(pred, gt):
104 | if pred.shape[2:] != gt.shape[2:]:
105 | pred = F.interpolate(pred, gt.shape[2:], mode='bilinear', align_corners=False)
106 | mad = (pred-gt).abs().mean()
107 | mse = ((pred-gt)**2).mean()
108 | return mad, mse
109 |
110 |
111 | def train(cfg, net, optimizer, criterion, dataloader, writer, logger, epoch, iterations, best_mad):
112 | net.train()
113 | running_loss = 0.0
114 |
115 | for i, data in enumerate(dataloader):
116 | iterations += 1
117 |
118 | input_dict = {}
119 | for k, v in data.items():
120 | input_dict[k] = v.cuda()
121 |
122 | optimizer.zero_grad()
123 | pred_list = net(input_dict)
124 | loss_dict = criterion(pred_list, input_dict)
125 | loss_dict['loss'].backward()
126 | optimizer.step()
127 |
128 | running_loss += loss_dict['loss'].item()
129 |
130 | cur_lr = optimizer.param_groups[0]['lr']
131 |
132 | if iterations % cfg.log.print_frq == 0:
133 | for k,v in loss_dict.items():
134 | writer.add_scalar('loss/'+k, loss_dict[k].item(), iterations)
135 | writer.add_scalar('loss/running_loss', running_loss/(i+1), iterations)
136 | writer.add_image('train/images', torch.cat(torch.unbind(pred_list[-1], dim=0), dim=1), global_step=iterations)
137 | if 'comp_loss' in loss_dict:
138 | logger.info('[epo:%d/%d][iter:%d/%d] lr:%.5f loss:%.3f alpha_loss:%.3f comp_loss:%.3f running_loss:%.3f' % (
139 | epoch, cfg.train.epoch, (i+1), len(dataloader), cur_lr, loss_dict['loss'],
140 | loss_dict['alpha_loss'], loss_dict['comp_loss'],
141 | running_loss/(i+1)))
142 | else:
143 | logger.info('[epo:%d/%d][iter:%d/%d] lr:%.5f loss:%.3f running_loss:%.3f' % (
144 | epoch, cfg.train.epoch, (i+1), len(dataloader), cur_lr, loss_dict['loss'], running_loss/(i+1)))
145 |
146 | # comment this line if memory is sufficient
147 | torch.cuda.empty_cache()
148 |
149 | return iterations, running_loss
150 |
151 |
152 | def test(cfg, net, dataloader, writer, logger, epoch, filenames):
153 | net.eval()
154 |
155 | mse_list = []
156 | mad_list = []
157 |
158 | with torch.no_grad():
159 | for i, data in enumerate(dataloader):
160 |
161 | input_dict = {}
162 | for k, v in data.items():
163 | input_dict[k] = v.cuda()
164 |
165 | pred = net.inference(input_dict['hr_image'])
166 | origin_h = input_dict['origin_h']
167 | origin_w = input_dict['origin_w']
168 | pred = F.interpolate(pred, (origin_h, origin_w), align_corners=False, mode="bilinear")
169 |
170 | gt = input_dict['hr_label']
171 | mad, mse = compute_metrics(pred, gt)
172 | mse_list.append(mse.item())
173 | mad_list.append(mad.item())
174 |
175 | logger.info('[ith:%d/%d] mad:%.5f mse:%.5f' % (i, len(dataloader), mad.item(), mse.item()))
176 |
177 | avg_mad = np.array(mad_list).mean()
178 | avg_mse = np.array(mse_list).mean()
179 | logger.info('[epo:%d/%d] avg_mad:%.5f avg_mse:%.5f' % (epoch, cfg.train.epoch, avg_mad, avg_mse))
180 | return avg_mad
181 |
182 |
183 | def main():
184 | parser = argparse.ArgumentParser(description='HM')
185 | parser.add_argument('--local_rank', type=int, default=0)
186 | parser.add_argument('--dist', action='store_true', help='use distributed training')
187 | parser.add_argument('-e', '--evaluate', action='store_true', help='evaluate or not')
188 | parser.add_argument('-c', '--config', type=str, metavar='FILE', help='path to config file')
189 | parser.add_argument('-p', '--phase', default="train", type=str, metavar='PHASE', help='train or test')
190 |
191 | args = parser.parse_args()
192 | cfg = load_config(args.config)
193 | best_mad = 1e12
194 | device_ids = range(torch.cuda.device_count())
195 |
196 | dataset = cfg.data.dataset
197 | model_name = cfg.model.arch
198 | exp_name = args.config.split('/')[-1].split('.')[0]
199 | timestamp = get_timestamp()
200 |
201 | cfg.log.log_dir = os.path.join(os.getcwd(), 'log', model_name, dataset, exp_name+os.sep)
202 | cfg.log.viz_dir = os.path.join(cfg.log.log_dir, "tensorboardx", timestamp)
203 | cfg.log.log_path = os.path.join(cfg.log.log_dir, "log.txt")
204 | os.makedirs(cfg.log.log_dir, exist_ok=True)
205 | os.makedirs(cfg.log.viz_dir, exist_ok=True)
206 |
207 | if cfg.test.save_dir is None:
208 | cfg.test.save_dir = os.path.join(cfg.log.log_dir, 'vis')
209 | os.makedirs(cfg.test.save_dir, exist_ok=True)
210 |
211 | writer = SummaryWriter(cfg.log.viz_dir)
212 | logger = get_logger(cfg.log.log_path)
213 |
214 | logger.info('[LogPath] {}'.format(cfg.log.log_dir))
215 | logger.info('[VizPath] {}'.format(cfg.log.viz_dir))
216 |
217 | train_images, train_labels, train_fgs, train_bgs = load_filelist(cfg.data.filelist_train)
218 | test_images, test_labels, test_fgs, test_bgs = load_filelist(cfg.data.filelist_val)
219 |
220 | train_transform = transforms.Compose([
221 | Rescale(cfg),
222 | RandomCrop(cfg),
223 | RandomFlip(cfg),
224 | ToTensor(cfg)
225 | ])
226 |
227 | test_transform = transforms.Compose([
228 | RescaleT(cfg),
229 | ToTensor(cfg)
230 | ])
231 |
232 | train_dataset = CustomDataset(
233 | cfg, True,
234 | img_name_list=train_images,
235 | lbl_name_list=train_labels,
236 | fg_name_list=train_fgs,
237 | bg_name_list=train_bgs,
238 | transform=train_transform
239 | )
240 | test_dataset = CustomDataset(
241 | cfg, False,
242 | img_name_list=test_images,
243 | lbl_name_list=test_labels,
244 | fg_name_list=test_fgs,
245 | bg_name_list=test_bgs,
246 | transform=test_transform
247 | )
248 |
249 | train_dataloader = DataLoader(
250 | train_dataset,
251 | batch_size=cfg.train.batch_size,
252 | shuffle=True,
253 | pin_memory=True,
254 | drop_last=True,
255 | num_workers=cfg.train.num_workers
256 | )
257 | test_dataloader = DataLoader(
258 | test_dataset,
259 | batch_size=cfg.test.batch_size,
260 | shuffle=False,
261 | pin_memory=True,
262 | drop_last=True,
263 | num_workers=cfg.test.num_workers
264 | )
265 |
266 | net = SparseMat(cfg)
267 | criterion = partial(
268 | losses,
269 | alpha_loss_weights=cfg.loss.alpha_loss_weights,
270 | with_composition_loss=cfg.loss.with_composition_loss,
271 | composition_loss_weight=cfg.loss.composition_loss_weight,
272 | )
273 |
274 | load_checkpoint(net.lpn, cfg.train.pretrained_model, logger)
275 |
276 | if torch.cuda.is_available():
277 | net.cuda()
278 | else:
279 | exit()
280 |
281 | if len(device_ids)>0:
282 | net = torch.nn.DataParallel(net)
283 | net_without_dp = net.module
284 | else:
285 | net_without_dp = net
286 |
287 | logger.info("---define optimizer...")
288 | optimizer = optim.Adam(
289 | net.parameters(),
290 | lr=cfg.train.lr,
291 | betas=(cfg.train.beta1, cfg.train.beta2),
292 | eps=1e-08,
293 | weight_decay=0,
294 | )
295 |
296 | logger.info("---start training...")
297 | iterations = 0
298 | running_loss = 0.0
299 |
300 | resume_checkpoint = os.path.join(cfg.log.log_dir, 'ckpt_latest.pth')
301 | if (args.evaluate or cfg.train.resume) and os.path.exists(resume_checkpoint):
302 | state_dict = torch.load(resume_checkpoint)
303 | if state_dict['epoch'] < cfg.train.epoch:
304 | logger.info("Resume checkpoint from {}".format(resume_checkpoint))
305 | if 'best_mad' in state_dict:
306 | best_mad = state_dict['best_mad']
307 | if 'epoch' in state_dict:
308 | cfg.train.start_epoch = state_dict['epoch']
309 | filtered_state_dict = OrderedDict()
310 | for k,v in state_dict['state_dict'].items():
311 | if k.startswith('module'):
312 | nk = '.'.join(k.split('.')[1:])
313 | else:
314 | nk = k
315 | filtered_state_dict[nk] = v
316 | net.module.load_state_dict(filtered_state_dict, strict=True)
317 |
318 | if args.evaluate:
319 | test(cfg, net_without_dp, test_dataloader, writer, logger, cfg.train.start_epoch, test_images)
320 | exit()
321 |
322 | for epoch in range(cfg.train.start_epoch, cfg.train.epoch):
323 | iterations, running_loss = train(cfg, net, optimizer, criterion, train_dataloader, writer, logger, epoch+1, iterations, best_mad)
324 | mad = test(cfg, net_without_dp, test_dataloader, writer, logger, epoch+1, test_images)
325 | if mad < best_mad:
326 | best_mad = min(mad, best_mad)
327 | save_checkpoint(cfg, net_without_dp, optimizer, epoch+1, iterations, running_loss, best_mad, is_best=True)
328 | else:
329 | save_checkpoint(cfg, net_without_dp, optimizer, epoch+1, iterations, running_loss, best_mad, is_best=False)
330 | adjust_learning_rate(optimizer, epoch, cfg.train.epoch_decay, cfg.train.lr, min_lr=cfg.train.min_lr)
331 |
332 |
333 | if __name__ == "__main__":
334 | main()
335 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from .config import load_config
4 | from .viz_utils import grid_images
5 |
6 |
7 | def get_logger(filename):
8 | logger = logging.getLogger()
9 | logger.setLevel(logging.INFO)
10 | fh = logging.FileHandler(filename, mode='a')
11 | fh.setLevel(logging.INFO)
12 | ch = logging.StreamHandler()
13 | ch.setLevel(logging.INFO)
14 | formatter = logging.Formatter("%(asctime)s - %(message)s")
15 | fh.setFormatter(formatter)
16 | ch.setFormatter(formatter)
17 | logger.addHandler(fh)
18 | logger.addHandler(ch)
19 | return logger
20 |
--------------------------------------------------------------------------------
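Note on `get_logger` above: it attaches a file handler and a stream handler, both at INFO level, so every message is appended to the given log file and echoed to the console. A minimal usage sketch; the log path below is a placeholder, not a path from the repository:

```
import os

from utils import get_logger

# Hypothetical log location; in training it would come from cfg.log.log_dir.
log_path = os.path.join('logs', 'train.log')
os.makedirs(os.path.dirname(log_path), exist_ok=True)

logger = get_logger(log_path)
logger.info("logging to both %s and the console", log_path)
```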
/utils/config.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict
2 |
3 | CONFIG = EasyDict({})
4 | CONFIG.is_default = True
5 | CONFIG.version = "baseline"
6 | CONFIG.debug = False
7 | # phase choices: "train", "evaluate", "inference"
8 | CONFIG.phase = "train"
9 | # distributed training
10 | CONFIG.dist = False
11 | # global variables which will be assigned in the runtime
12 | CONFIG.local_rank = 0
13 | CONFIG.gpu = 0
14 | CONFIG.world_size = 1
15 | CONFIG.devices = (0,)
16 |
17 |
18 | # ===============================================================================
19 | # Model config
20 | # ===============================================================================
21 | CONFIG.model = EasyDict({})
22 | CONFIG.model.arch = 'SparseMat'
23 |
24 | # Model -> Architecture config
25 | CONFIG.model.in_channel = 3
26 | CONFIG.model.hr_channel = 32
27 | # global modules (ppm, aspp)
28 | CONFIG.model.global_module = "ppm"
29 | CONFIG.model.pool_scales = (1,2,3,6)
30 | CONFIG.model.ppm_channel = 256
31 | CONFIG.model.atrous_rates = (12, 24, 36)
32 | CONFIG.model.aspp_channel = 256
33 | CONFIG.model.with_norm = True
34 | CONFIG.model.with_aspp = True
35 | CONFIG.model.dilation_kernel = 15
36 | CONFIG.model.max_n_pixel = 4000000
37 |
38 | # ===============================================================================
39 | # Dataloader config
40 | # ===============================================================================
41 |
42 | CONFIG.aug = EasyDict({})
43 | CONFIG.aug.rescale_size = 320
44 | CONFIG.aug.crop_size = 288
45 | CONFIG.aug.patch_crop_size = (320,640)
46 | CONFIG.aug.patch_load_size = 320
47 |
48 | CONFIG.data = EasyDict({})
49 | CONFIG.data.workers = 0
50 | CONFIG.data.dataset = None
51 | CONFIG.data.composite = False
52 | CONFIG.data.filelist = None
53 | CONFIG.data.filelist_train = None
54 | CONFIG.data.filelist_val = None
55 | CONFIG.data.filelist_test = None
56 |
57 |
58 | # ===============================================================================
59 | # Loss config
60 | # ===============================================================================
61 | CONFIG.loss = EasyDict({})
62 | CONFIG.loss.alpha_loss_weights = [1.0,1.0,1.0,1.0]
63 | CONFIG.loss.with_composition_loss = False
64 | CONFIG.loss.composition_loss_weight = 1.0
65 |
66 | # ===============================================================================
67 | # Training config
68 | # ===============================================================================
69 | CONFIG.train = EasyDict({})
70 |
71 | CONFIG.train.num_workers = 4
72 | CONFIG.train.batch_size = 8
73 | # epochs
74 | CONFIG.train.start_epoch = 0
75 | CONFIG.train.epoch = 100
76 | CONFIG.train.epoch_decay = 95
77 | # basic learning rate of optimizer
78 | CONFIG.train.lr = 1e-5
79 | CONFIG.train.min_lr = 1e-8
80 | CONFIG.train.reset_lr = False
81 | CONFIG.train.adaptive_lr = False
82 | # optimizer settings: type, eps, Adam betas, momentum and weight decay
83 | CONFIG.train.optim = "Adam"
84 | CONFIG.train.eps = 1e-5
85 | CONFIG.train.beta1 = 0.9
86 | CONFIG.train.beta2 = 0.999
87 | CONFIG.train.momentum = 0.9
88 | CONFIG.train.weight_decay = 1e-4
89 | # clip large gradient
90 | CONFIG.train.clip_grad = True
91 | # pretrained weights for the low-resolution prior network (LPN) and whether to resume from the latest checkpoint
92 | CONFIG.train.pretrained_model = None
93 | CONFIG.train.resume = False
94 |
95 | CONFIG.train.rescale_size = 320
96 | CONFIG.train.crop_size = 288
97 | CONFIG.train.color_space = 3
98 |
99 |
100 | # ===============================================================================
101 | # Testing config
102 | # ===============================================================================
103 | CONFIG.test = EasyDict({})
104 | # test-time dataloader, rescaling and checkpointing options
105 | CONFIG.test.num_workers = 4
106 | CONFIG.test.batch_size = 1
107 | CONFIG.test.rescale_size = 320
108 | CONFIG.test.max_size = 1920
109 | CONFIG.test.patch_size = 320
110 | CONFIG.test.checkpoint = None
111 | CONFIG.test.save = False
112 | CONFIG.test.save_dir = None
113 | CONFIG.test.cascade = False
114 | # "best_model" or "latest_model" or other base name of the pth file.
115 |
116 |
117 | # ===============================================================================
118 | # Logging config
119 | # ===============================================================================
120 | CONFIG.log = EasyDict({})
121 | CONFIG.log.log_dir = None
122 | CONFIG.log.viz_dir = None
123 | CONFIG.log.save_frq = 2000
124 | CONFIG.log.print_frq = 20
125 | CONFIG.log.test_frq = 1
126 | CONFIG.log.viz = True
127 | CONFIG.log.show_all = True
128 |
129 |
130 | # ===============================================================================
131 | # util functions
132 | # ===============================================================================
133 | def parse_config(custom_config, default_config=CONFIG, prefix="CONFIG"):
134 | """
135 | This function will recursively overwrite the default config by a custom config
136 | :param default_config:
137 | :param custom_config: parsed from config/config.toml
138 | :param prefix: prefix for config key
139 | :return: None
140 | """
141 | if "is_default" in default_config:
142 | default_config.is_default = False
143 |
144 | for key in custom_config.keys():
145 | full_key = ".".join([prefix, key])
146 | if key not in default_config:
147 | raise NotImplementedError("Unknown config key: {}".format(full_key))
148 | elif isinstance(custom_config[key], dict):
149 | if isinstance(default_config[key], dict):
150 | parse_config(default_config=default_config[key],
151 | custom_config=custom_config[key],
152 | prefix=full_key)
153 | else:
154 | raise ValueError("{}: Expected {}, got dict instead.".format(full_key, type(custom_config[key])))
155 | else:
156 | if isinstance(default_config[key], dict):
157 | raise ValueError("{}: Expected dict, got {} instead.".format(full_key, type(custom_config[key])))
158 | else:
159 | default_config[key] = custom_config[key]
160 |
161 |
162 | def load_config(config_path):
163 | import toml
164 | with open(config_path) as fp:
165 | custom_config = EasyDict(toml.load(fp))
166 | parse_config(custom_config=custom_config)
167 | return CONFIG
168 |
169 |
170 | if __name__ == "__main__":
171 | from pprint import pprint
172 |
173 | pprint(CONFIG)
174 | load_config("../config/example.toml")
175 | pprint(CONFIG)
176 |
--------------------------------------------------------------------------------
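Note on `load_config` above: it returns the module-level `CONFIG` object after `parse_config` has recursively merged the TOML overrides into it; keys absent from the defaults raise `NotImplementedError`, and a scalar given where a table is expected (or vice versa) raises `ValueError`. A minimal sketch of an override, using a hypothetical TOML file rather than the shipped `config/sparsemat.toml`:

```
# Contents of a hypothetical my_override.toml:
#
#   [train]
#   batch_size = 4
#   lr = 2e-5
#
#   [log]
#   log_dir = "logs/sparsemat"
#
from utils.config import CONFIG, load_config

cfg = load_config('my_override.toml')
assert cfg is CONFIG                       # the global defaults are mutated in place
print(cfg.train.batch_size, cfg.train.lr)  # 4 2e-05
print(cfg.log.log_dir)                     # logs/sparsemat
```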
/utils/viz_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import torch
5 | import numpy as np
6 | import torch.nn.functional as F
7 |
8 |
9 | def grid_images(pred_dict, input_dict, show_all=False):
10 | lr_image = input_dict['lr_image'] * 0.5 + 0.5
11 | lr_label = input_dict['lr_label_mat'].expand_as(lr_image)
12 | lr_mask = input_dict['lr_label_unk'].expand_as(lr_image)
13 | lr_pred = pred_dict['coarse']
14 | if lr_pred.shape[2:] != lr_image.shape[2:]:
15 | lr_pred = F.interpolate(lr_pred, lr_image.shape[2:], mode="bilinear", align_corners=False)
16 | lr_pred = lr_pred.expand_as(lr_image)
17 |
18 | h, w = lr_image.size(2), lr_image.size(3)
19 |
20 | tmps = []
21 | if show_all:
22 | extra_keys = ['global_seg', 'global_mat', 'errormap', 'classmap']
23 | for key in extra_keys:
24 | if key in pred_dict:
25 | if key == 'errormap':
26 | tmp = pred_dict[key]
27 | if tmp.size(2) != h or tmp.size(3) != w:
28 | tmp = F.interpolate(tmp, (h,w), mode='nearest')
29 | elif key == 'classmap':
30 | tmp = torch.argmax(pred_dict[key], dim=1, keepdim=True).float() / 2.
31 | # if tmp.min() < 0:
32 | # tmp = (tmp + 1) / 2.
33 | # if tmp.size(2) != h or tmp.size(3) != w:
34 | # tmp = F.interpolate(tmp, (h,w), mode='nearest')
35 | # tmp = tmp.repeat(1,3,1,1).float()
36 | else:
37 | tmp = pred_dict[key][0]
38 | if tmp.size(2) != h or tmp.size(3) != w:
39 | tmp = F.interpolate(tmp, (h,w), mode='bilinear', align_corners=False)
40 | tmp = tmp.expand_as(lr_image)
41 | tmps.append(tmp)
42 |
43 | if 'fine' in pred_dict:
44 | hr_image = input_dict['hr_image'] * 0.5 + 0.5
45 | hr_label = input_dict['hr_label_mat']
46 | hr_pred = pred_dict['fine'].expand_as(hr_image)
47 | if hr_image.size(2) != h or hr_image.size(3) != w:
48 | hr_image = F.interpolate(hr_image, (h,w), mode='bilinear', align_corners=False)
49 | hr_label = F.interpolate(hr_label, (h,w), mode='bilinear', align_corners=False)
50 | hr_pred = F.interpolate(hr_pred, (h,w), mode='bilinear', align_corners=False)
51 | hr_label = hr_label.expand_as(hr_image)
52 | grid = torch.cat([lr_image, lr_label, lr_mask, lr_pred]+tmps+[hr_image, hr_label, hr_pred], dim=3)
53 | else:
54 | grid = torch.cat([lr_image, lr_label, lr_mask, lr_pred]+tmps, dim=3)
55 | grid = F.interpolate(grid, scale_factor=0.5, mode='bilinear', align_corners=False)
56 | n,c,h,w = grid.size()
57 | grid = grid.permute(1,0,2,3).contiguous().view(c,n*h,w)
58 | # np_img = cv2.cvtColor(np.transpose(grid.data.cpu().numpy(), (1,2,0))*255, cv2.COLOR_RGB2BGR)
59 | # cv2.imwrite('tmp/tmp.png', np_img)
60 | return grid
61 |
62 |
63 | def save_preds(viz_dir, img, lbl, res1, res2):
64 | res1 = F.interpolate(res1, (res2.size(2), res2.size(3)))
65 | img_color = np.transpose(img.data.cpu().numpy(), (0,2,3,1))[:,:,:,::-1]*255
66 | lbl_color = np.tile(lbl.squeeze().data.cpu().numpy()[:,:,:,None], (1,1,1,3)) * 255
67 | res1_color = np.tile(torch.clamp(res1,0,1).squeeze().data.cpu().numpy()[:,:,:,None], (1,1,1,3)) * 255
68 | res2_color = np.tile(torch.clamp(res2,0,1).squeeze().data.cpu().numpy()[:,:,:,None], (1,1,1,3)) * 255
69 | shows = []
70 | for i in range(img_color.shape[0]):
71 | shows.append(np.concatenate((img_color[i], lbl_color[i], res1_color[i], res2_color[i]), axis=1))
72 | shows = np.concatenate(shows, axis=0)
73 | ratio = 1200.0 / shows.shape[1]
74 | shows = cv2.resize(shows, None, fx=ratio, fy=ratio)
75 | cv2.imwrite(os.path.join(viz_dir,"viz.png"), shows)
76 |
77 |
78 | def save_labels(labels, save_dir):
79 | n, c, h, w = labels.shape
80 | labels = labels[:,0].data.cpu().numpy()
81 |
82 | template = np.zeros((h*n,w,3))
83 |
84 | for i in range(n):
85 | label_color = idx_to_colormat(labels[i])  # idx_to_colormat is assumed to be defined/imported elsewhere; it is not in this file
86 | template[h*i:h*(i+1), w*0:w*1] = label_color
87 |
88 | cv2.imwrite(os.path.join(save_dir, "viz.png"), template)
89 |
90 |
91 | def save_raw_labels(labels, save_dir):
92 | h, w = labels.shape[:2]
93 | template = np.zeros((h,w,3))
94 |
95 | label_color = idx_to_colormat(labels)
96 | cv2.imwrite(os.path.join(save_dir, "viz.png"), label_color)
97 |
--------------------------------------------------------------------------------
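Note on `grid_images` above: it returns a single CHW tensor with the panels concatenated along the width and the batch items stacked along the height; the commented-out lines inside the function show the intended RGB-to-BGR conversion for saving with OpenCV. A minimal sketch of writing such a grid to disk; the random tensor below stands in for the real return value and the output path is a placeholder:

```
import cv2
import numpy as np
import torch

# Stand-in for the (C, H, W) grid returned by grid_images, with values in [0, 1].
grid = torch.rand(3, 256, 512)

np_img = np.transpose(grid.data.cpu().numpy(), (1, 2, 0)) * 255
np_img = cv2.cvtColor(np_img.astype(np.uint8), cv2.COLOR_RGB2BGR)
cv2.imwrite('tmp_grid.png', np_img)
```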