├── .gitignore
├── 6.jpg
├── FaceParsing.py
├── LICENSE
├── README.md
├── __init__.py
├── evaluate.py
├── face_dataset.py
├── hair.png
├── logger.py
├── loss.py
├── makeup.py
├── makeup
│   ├── 116_1.png
│   ├── 116_3.png
│   ├── 116_lip_ori.png
│   └── 116_ori.png
├── model.py
├── modules
│   ├── __init__.py
│   ├── bn.py
│   ├── deeplab.py
│   ├── dense.py
│   ├── functions.py
│   ├── misc.py
│   ├── residual.py
│   └── src
│       ├── checks.h
│       ├── inplace_abn.cpp
│       ├── inplace_abn.h
│       ├── inplace_abn_cpu.cpp
│       ├── inplace_abn_cuda.cu
│       ├── inplace_abn_cuda_half.cu
│       └── utils
│           ├── checks.h
│           ├── common.h
│           └── cuda.cuh
├── optimizer.py
├── prepropess_data.py
├── resnet.py
├── test.py
├── train.py
└── transform.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # res
107 | res/
108 |
109 | .idea/
110 |
111 |
--------------------------------------------------------------------------------
/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/6.jpg
--------------------------------------------------------------------------------
/FaceParsing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | import collections
6 | import torchvision.transforms as transforms
7 |
8 | from face_parsing.model import BiSeNet
9 |
10 | class FaceParsing(object):
11 | def __init__(self, model_path=None):
12 | if model_path is None:
13 | # REVIEW this is a terrible default
14 | model_path = '../../../external/data/models/face_parsing/face_parsing_79999_iter.pth'
15 |
16 | self.net = BiSeNet(n_classes=19)
17 | self.net.load_state_dict(torch.load(model_path, map_location='cpu'))
18 | self.net.eval()
19 |
20 | self.transform = transforms.Compose([
21 | transforms.ToTensor(),
22 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
23 | ])
24 |
25 | self.label_by_idx = collections.OrderedDict(
26 | [(-1, 'unlabeled'), (0, 'background'), (1, 'skin'),
27 | (2, 'l_brow'), (3, 'r_brow'), (4, 'l_eye'), (5, 'r_eye'),
28 | (6, 'eye_g (eye glasses)'), (7, 'l_ear'), (8, 'r_ear'), (9, 'ear_r (ear ring)'),
29 | (10, 'nose'), (11, 'mouth'), (12, 'u_lip'), (13, 'l_lip'),
30 | (14, 'neck'), (15, 'neck_l (necklace)'), (16, 'cloth'),
31 | (17, 'hair'), (18, 'hat')])
32 | self.idx_by_label = {v: k for k,v in self.label_by_idx.items()}
33 |
34 | def to_tensor(self, images):
35 | # images : a single PIL Image (the plural name is misleading); returns a C,H,W tensor
36 | return self.transform(images)
37 |
38 | def label_for_idx(self, idx):
39 | return self.label_by_idx[idx]
40 |
41 | def idx_for_label(self, label):
42 | return self.idx_by_label[label]
43 |
44 | # TODO rename to label_image to contrast with label, which takes a tensor
45 | def parse_face(self, images, device=0):
46 | # images : list of square PIL Images
47 | # device : which CUDA device to run on
48 | #
49 | # returns parsings : list of PIL Images
50 | #
51 | # 0 - background, 1 - skin, 2 - l_brow, 3 - r_brow, 4 - l_eye, 5 - r_eye, 6 - eye_g (eye glasses),
52 | # 7 - l_ear, 8 - r_ear, 9 - ear_r (ear ring), 10 - nose, 11 - mouth, 12 - u_lip, 13 - l_lip,
53 | # 14 - neck, 15 - neck_l (necklace), 16 - cloth, 17 - hair, 18 - hat
54 |
55 | # move the network to the correct device
56 | # REVIEW this makes it impossible to run on the CPU...
57 | self.net.to('cuda:{}'.format(device))
58 |
59 | assert all(im.size[0] == im.size[1] for im in images)
60 | in_sizes = [im.size[0] for im in images] # im is square
61 |
62 | pt_images = []
63 | for img in images:
64 | # seems to work best with images around 512
65 | img = img.resize((512, 512), Image.BILINEAR)
66 | img = self.to_tensor(img)
67 | pt_images.append(img)
68 | pt_images = torch.stack(pt_images, dim=0)
69 |
70 | # move the data to the device
71 | pt_images = pt_images.to('cuda:{}'.format(device))
72 |
73 | out = self.label(pt_images)
74 | parsings = out.cpu().numpy().argmax(axis=1).astype(np.uint8)
75 |
76 | parsings = [Image.fromarray(parsing).resize((in_size, in_size), Image.NEAREST)
77 | for parsing, in_size in zip(parsings, in_sizes)]
78 |
79 | return parsings # list of PIL Images
80 |
81 | def label(self, pt_images):
82 | # pt_images : N,C,H,W torch.Tensor
83 | with torch.no_grad():
84 | return self.net(pt_images)[0]
85 |
86 |
--------------------------------------------------------------------------------
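A minimal usage sketch for the `FaceParsing` wrapper above. Assumptions: the repository is importable as a `face_parsing` package (FaceParsing.py's own model import expects that layout), a CUDA device is available (`parse_face` hard-codes CUDA), and the weights path is hypothetical:

```python
import numpy as np
from PIL import Image

from face_parsing.FaceParsing import FaceParsing  # assumed package layout

# hypothetical local path to the pre-trained weights
fp = FaceParsing(model_path='face_parsing_79999_iter.pth')

img = Image.open('6.jpg').convert('RGB')
side = min(img.size)
img = img.crop((0, 0, side, side))  # parse_face asserts its inputs are square

parsing = fp.parse_face([img], device=0)[0]          # PIL image of class indices
hair = np.array(parsing) == fp.idx_by_label['hair']  # boolean hair mask (class 17)
```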
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 zll
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # face-parsing.PyTorch
2 |
3 | ![demo](6.jpg)
4 |
5 |
6 |
7 |
8 |
9 | ### Contents
10 | - [Training](#training)
11 | - [Demo](#demo)
12 | - [References](#references)
13 |
14 | ## Training
15 |
16 | 1. Prepare the training data:
17 | -- download the [CelebAMask-HQ dataset](https://github.com/switchablenorms/CelebAMask-HQ)
18 |
19 | -- change the data paths in `prepropess_data.py` and run
20 | ```Shell
21 | python prepropess_data.py
22 | ```
23 |
24 | 2. Train the model on the CelebAMask-HQ dataset by running the training script:
25 |
26 | ```
27 | $ CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
28 | ```
29 |
30 | If you do not wish to train the model, you can download [our pre-trained model](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812) and save it in `res/cp`.
31 |
32 |
33 | ## Demo
34 | 1. Evaluate the trained model using:
35 | ```Shell
36 | # evaluate using GPU
37 | python test.py
38 | ```
39 |
40 | ## Face makeup using parsing maps
41 | [**face-makeup.PyTorch**](https://github.com/zllrunning/face-makeup.PyTorch)
42 |
43 |
44 |
45 | |                | Hair                    | Lip                         |
46 | | -------------- | ----------------------- | --------------------------- |
47 | | Original Input | ![](makeup/116_ori.png) | ![](makeup/116_lip_ori.png) |
48 | | Color          | ![](makeup/116_1.png)   | ![](makeup/116_3.png)       |
49 |
50 |
51 | ## References
52 | - [BiSeNet](https://github.com/CoinCheung/BiSeNet)
--------------------------------------------------------------------------------
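For reference, a minimal sketch of loading the pre-trained checkpoint mentioned in the README above; the package name follows FaceParsing.py, and the checkpoint filename is an assumption (use whatever name the download produces):

```python
import torch

from face_parsing.model import BiSeNet  # the sources use package-relative imports

net = BiSeNet(n_classes=19)
state = torch.load('res/cp/79999_iter.pth', map_location='cpu')  # assumed filename
net.load_state_dict(state)
net.eval()
```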
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/__init__.py
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | from .logger import setup_logger
5 | from .model import BiSeNet
6 | from .face_dataset import FaceMask
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | import torch.nn.functional as F
12 | import torch.distributed as dist
13 |
14 | import os
15 | import os.path as osp
16 | import logging
17 | import time
18 | import numpy as np
19 | from tqdm import tqdm
20 | import math
21 | from PIL import Image
22 | import torchvision.transforms as transforms
23 | import cv2
24 |
25 | def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
26 | # Colors for the parsing classes (24 entries; only the first 19 are used)
27 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
28 | [255, 0, 85], [255, 0, 170],
29 | [0, 255, 0], [85, 255, 0], [170, 255, 0],
30 | [0, 255, 85], [0, 255, 170],
31 | [0, 0, 255], [85, 0, 255], [170, 0, 255],
32 | [0, 85, 255], [0, 170, 255],
33 | [255, 255, 0], [255, 255, 85], [255, 255, 170],
34 | [255, 0, 255], [255, 85, 255], [255, 170, 255],
35 | [0, 255, 255], [85, 255, 255], [170, 255, 255]]
36 |
37 | im = np.array(im)
38 | vis_im = im.copy().astype(np.uint8)
39 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
40 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
41 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
42 |
43 | num_of_class = np.max(vis_parsing_anno)
44 |
45 | for pi in range(1, num_of_class + 1):
46 | index = np.where(vis_parsing_anno == pi)
47 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
48 |
49 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
50 | # print(vis_parsing_anno_color.shape, vis_im.shape)
51 | vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
52 |
53 | # Save result or not
54 | if save_im:
55 | cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
56 |
57 | # return vis_im
58 |
59 | def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
60 |
61 | if not os.path.exists(respth):
62 | os.makedirs(respth)
63 |
64 | n_classes = 19
65 | net = BiSeNet(n_classes=n_classes)
66 | net.cuda()
67 | save_pth = osp.join('res/cp', cp)
68 | net.load_state_dict(torch.load(save_pth))
69 | net.eval()
70 |
71 | to_tensor = transforms.Compose([
72 | transforms.ToTensor(),
73 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
74 | ])
75 | with torch.no_grad():
76 | for image_path in os.listdir(dspth):
77 | img = Image.open(osp.join(dspth, image_path)).convert('RGB')  # ensure 3 channels for Normalize
78 | image = img.resize((512, 512), Image.BILINEAR)
79 | img = to_tensor(image)
80 | img = torch.unsqueeze(img, 0)
81 | img = img.cuda()
82 | out = net(img)[0]
83 | parsing = out.squeeze(0).cpu().numpy().argmax(0)
84 |
85 | vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 | if __name__ == "__main__":
94 | setup_logger('./res')
95 | evaluate()
96 |
--------------------------------------------------------------------------------
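`evaluate()` above hard-codes CUDA. A minimal CPU-only sketch of the same per-image loop, using the checkpoint name that `evaluate()` defaults to and the assumed `face_parsing` package name:

```python
import torch
import torchvision.transforms as transforms
from PIL import Image

from face_parsing.model import BiSeNet  # assumed package name, as in FaceParsing.py

net = BiSeNet(n_classes=19)
net.load_state_dict(torch.load('res/cp/model_final_diss.pth', map_location='cpu'))
net.eval()

to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

img = Image.open('6.jpg').convert('RGB').resize((512, 512), Image.BILINEAR)
with torch.no_grad():
    out = net(to_tensor(img).unsqueeze(0))[0]   # main head of the three outputs
parsing = out.squeeze(0).argmax(0).numpy()      # (512, 512) array of class indices
```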
/face_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | from torch.utils.data import Dataset
6 | import torchvision.transforms as transforms
7 |
8 | import os.path as osp
9 | import os
10 | from PIL import Image
11 | import numpy as np
12 | import json
13 | import cv2
14 |
15 | from .transform import *
16 |
17 |
18 |
19 | class FaceMask(Dataset):
20 | def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs):
21 | super(FaceMask, self).__init__(*args, **kwargs)
22 | assert mode in ('train', 'val', 'test')
23 | self.mode = mode
24 | self.ignore_lb = 255
25 | self.rootpth = rootpth
26 |
27 | self.imgs = os.listdir(os.path.join(self.rootpth, 'CelebA-HQ-img'))
28 |
29 | # pre-processing
30 | self.to_tensor = transforms.Compose([
31 | transforms.ToTensor(),
32 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
33 | ])
34 | self.trans_train = Compose([
35 | ColorJitter(
36 | brightness=0.5,
37 | contrast=0.5,
38 | saturation=0.5),
39 | HorizontalFlip(),
40 | RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
41 | RandomCrop(cropsize)
42 | ])
43 |
44 | def __getitem__(self, idx):
45 | impth = self.imgs[idx]
46 | img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth))
47 | img = img.resize((512, 512), Image.BILINEAR)
48 | label = Image.open(osp.join(self.rootpth, 'mask', impth[:-3]+'png')).convert('P')
49 | # print(np.unique(np.array(label)))
50 | if self.mode == 'train':
51 | im_lb = dict(im=img, lb=label)
52 | im_lb = self.trans_train(im_lb)
53 | img, label = im_lb['im'], im_lb['lb']
54 | img = self.to_tensor(img)
55 | label = np.array(label).astype(np.int64)[np.newaxis, :]
56 | return img, label
57 |
58 | def __len__(self):
59 | return len(self.imgs)
60 |
61 |
62 | if __name__ == "__main__":
63 | face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
64 | face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
65 | mask_path = '/home/zll/data/CelebAMask-HQ/mask'
66 | counter = 0
67 | total = 0
68 | for i in range(15):
69 | # files = os.listdir(osp.join(face_sep_mask, str(i)))
70 |
71 | atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
72 | 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
73 |
74 | for j in range(i*2000, (i+1)*2000):
75 |
76 | mask = np.zeros((512, 512))
77 |
78 | for l, att in enumerate(atts, 1):
79 | total += 1
80 | file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
81 | path = osp.join(face_sep_mask, str(i), file_name)
82 |
83 | if os.path.exists(path):
84 | counter += 1
85 | sep_mask = np.array(Image.open(path).convert('P'))
86 | # print(np.unique(sep_mask))
87 |
88 | mask[sep_mask == 225] = l
89 | cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
90 | print(j)
91 |
92 | print(counter, total)
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
--------------------------------------------------------------------------------
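A sketch of wiring `FaceMask` into a `DataLoader`; the dataset root is hypothetical and is assumed to contain `CelebA-HQ-img/` plus the `mask/` folder produced by `prepropess_data.py`:

```python
from torch.utils.data import DataLoader

from face_parsing.face_dataset import FaceMask  # assumed package name

ds = FaceMask(rootpth='/data/CelebAMask-HQ', cropsize=(448, 448), mode='train')
dl = DataLoader(ds, batch_size=16, shuffle=True, num_workers=4, drop_last=True)

imgs, labels = next(iter(dl))
print(imgs.shape, labels.shape)  # torch.Size([16, 3, 448, 448]) torch.Size([16, 1, 448, 448])
```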
/hair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/hair.png
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 |
5 | import os.path as osp
6 | import time
7 | import sys
8 | import logging
9 |
10 | import torch.distributed as dist
11 |
12 |
13 | def setup_logger(logpth):
14 | logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
15 | logfile = osp.join(logpth, logfile)
16 | FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
17 | log_level = logging.INFO
18 | if dist.is_initialized() and dist.get_rank() != 0:
19 | log_level = logging.ERROR
20 | logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
21 | logging.root.addHandler(logging.StreamHandler())
22 |
23 |
24 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import numpy as np
10 |
11 |
12 | class OhemCELoss(nn.Module):
13 | def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
14 | super(OhemCELoss, self).__init__()
15 | self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
16 | self.n_min = n_min
17 | self.ignore_lb = ignore_lb
18 | self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
19 |
20 | def forward(self, logits, labels):
21 | N, C, H, W = logits.size()
22 | loss = self.criteria(logits, labels).view(-1)
23 | loss, _ = torch.sort(loss, descending=True)
24 | if loss[self.n_min] > self.thresh:
25 | loss = loss[loss>self.thresh]
26 | else:
27 | loss = loss[:self.n_min]
28 | return torch.mean(loss)
29 |
30 |
31 | class SoftmaxFocalLoss(nn.Module):
32 | def __init__(self, gamma, ignore_lb=255, *args, **kwargs):
33 | super(SoftmaxFocalLoss, self).__init__()
34 | self.gamma = gamma
35 | self.nll = nn.NLLLoss(ignore_index=ignore_lb)
36 |
37 | def forward(self, logits, labels):
38 | scores = F.softmax(logits, dim=1)
39 | factor = torch.pow(1.-scores, self.gamma)
40 | log_score = F.log_softmax(logits, dim=1)
41 | log_score = factor * log_score
42 | loss = self.nll(log_score, labels)
43 | return loss
44 |
45 |
46 | if __name__ == '__main__':
47 | torch.manual_seed(15)
48 | criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
49 | criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
50 | net1 = nn.Sequential(
51 | nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
52 | )
53 | net1.cuda()
54 | net1.train()
55 | net2 = nn.Sequential(
56 | nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
57 | )
58 | net2.cuda()
59 | net2.train()
60 |
61 | with torch.no_grad():
62 | inten = torch.randn(16, 3, 20, 20).cuda()
63 | lbs = torch.randint(0, 19, [16, 20, 20]).cuda()
64 | lbs[1, :, :] = 255
65 |
66 | logits1 = net1(inten)
67 | logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear')
68 | logits2 = net2(inten)
69 | logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear')
70 |
71 | loss1 = criteria1(logits1, lbs)
72 | loss2 = criteria2(logits2, lbs)
73 | loss = loss1 + loss2
74 | print(loss.detach().cpu())
75 | loss.backward()
76 |
--------------------------------------------------------------------------------
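The OHEM rule in `OhemCELoss` above keeps every pixel loss above `-log(thresh)`, but never fewer than the `n_min` hardest pixels. A small CPU illustration of that selection, mirroring the module:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 19, 8, 8)            # N, C, H, W
labels = torch.randint(0, 19, (2, 8, 8))

thresh = -torch.log(torch.tensor(0.7))       # keep losses above this value...
n_min = 2 * 8 * 8 // 16                      # ...but always keep at least N*H*W // 16 pixels

loss = F.cross_entropy(logits, labels, reduction='none').view(-1)
loss, _ = loss.sort(descending=True)
loss = loss[loss > thresh] if loss[n_min] > thresh else loss[:n_min]
print(loss.mean())
```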
/makeup.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import numpy as np
4 | from skimage.filters import gaussian
5 |
6 |
7 | def sharpen(img):
8 | img = img * 1.0
9 | gauss_out = gaussian(img, sigma=5, multichannel=True)  # NOTE: newer scikit-image uses channel_axis=-1 instead of multichannel
10 |
11 | alpha = 1.5
12 | img_out = (img - gauss_out) * alpha + img
13 |
14 | img_out = img_out / 255.0
15 |
16 | mask_1 = img_out < 0
17 | mask_2 = img_out > 1
18 |
19 | img_out = img_out * (1 - mask_1)
20 | img_out = img_out * (1 - mask_2) + mask_2
21 | img_out = np.clip(img_out, 0, 1)
22 | img_out = img_out * 255
23 | return np.array(img_out, dtype=np.uint8)
24 |
25 |
26 | def hair(image, parsing, part=17, color=[230, 50, 20]):
27 | b, g, r = color #[10, 50, 250] # [10, 250, 10]
28 | tar_color = np.zeros_like(image)
29 | tar_color[:, :, 0] = b
30 | tar_color[:, :, 1] = g
31 | tar_color[:, :, 2] = r
32 |
33 | image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
34 | tar_hsv = cv2.cvtColor(tar_color, cv2.COLOR_BGR2HSV)
35 |
36 | if part == 12 or part == 13:
37 | image_hsv[:, :, 0:2] = tar_hsv[:, :, 0:2]
38 | else:
39 | image_hsv[:, :, 0:1] = tar_hsv[:, :, 0:1]
40 |
41 | changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR)
42 |
43 | if part == 17:
44 | changed = sharpen(changed)
45 |
46 | changed[parsing != part] = image[parsing != part]
47 | # changed = cv2.resize(changed, (512, 512))
48 | return changed
49 |
50 | #
51 | # def lip(image, parsing, part=17, color=[230, 50, 20]):
52 | # b, g, r = color #[10, 50, 250] # [10, 250, 10]
53 | # tar_color = np.zeros_like(image)
54 | # tar_color[:, :, 0] = b
55 | # tar_color[:, :, 1] = g
56 | # tar_color[:, :, 2] = r
57 | #
58 | # image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
59 | # il, ia, ib = cv2.split(image_lab)
60 | #
61 | # tar_lab = cv2.cvtColor(tar_color, cv2.COLOR_BGR2Lab)
62 | # tl, ta, tb = cv2.split(tar_lab)
63 | #
64 | # image_lab[:, :, 0] = np.clip(il - np.mean(il) + tl, 0, 100)
65 | # image_lab[:, :, 1] = np.clip(ia - np.mean(ia) + ta, -127, 128)
66 | # image_lab[:, :, 2] = np.clip(ib - np.mean(ib) + tb, -127, 128)
67 | #
68 | #
69 | # changed = cv2.cvtColor(image_lab, cv2.COLOR_Lab2BGR)
70 | #
71 | # if part == 17:
72 | # changed = sharpen(changed)
73 | #
74 | # changed[parsing != part] = image[parsing != part]
75 | # # changed = cv2.resize(changed, (512, 512))
76 | # return changed
77 |
78 |
79 | if __name__ == '__main__':
80 | # 1 face
81 | # 10 nose
82 | # 11 teeth
83 | # 12 upper lip
84 | # 13 lower lip
85 | # 17 hair
86 | num = 116
87 | table = {
88 | 'hair': 17,
89 | 'upper_lip': 12,
90 | 'lower_lip': 13
91 | }
92 | image_path = '/home/zll/data/CelebAMask-HQ/test-img/{}.jpg'.format(num)
93 | parsing_path = 'res/test_res/{}.png'.format(num)
94 |
95 | image = cv2.imread(image_path)
96 | ori = image.copy()
97 | parsing = np.array(cv2.imread(parsing_path, 0))
98 | parsing = cv2.resize(parsing, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)  # cv2.resize takes (width, height)
99 |
100 | parts = [table['hair'], table['upper_lip'], table['lower_lip']]
101 | # colors = [[20, 20, 200], [100, 100, 230], [100, 100, 230]]
102 | colors = [[100, 200, 100]]  # NOTE: zip below stops at the shortest list, so only the hair part is recolored
103 | for part, color in zip(parts, colors):
104 | image = hair(image, parsing, part, color)
105 | cv2.imwrite('res/makeup/116_ori.png', cv2.resize(ori, (512, 512)))
106 | cv2.imwrite('res/makeup/116_2.png', cv2.resize(image, (512, 512)))
107 |
108 | cv2.imshow('image', cv2.resize(ori, (512, 512)))
109 | cv2.imshow('color', cv2.resize(image, (512, 512)))
110 |
111 | # cv2.imshow('image', ori)
112 | # cv2.imshow('color', image)
113 |
114 | cv2.waitKey(0)
115 | cv2.destroyAllWindows()
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
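A hedged sketch of applying the `hair()` helper above outside the hard-coded `__main__` paths; the file names are hypothetical, and the parsing map is assumed to be the class-index PNG written by evaluate.py/test.py:

```python
import cv2

from face_parsing.makeup import hair  # assumed package name

image = cv2.imread('face.jpg')                 # BGR, as hair() expects
parsing = cv2.imread('face_parsing.png', 0)    # single-channel class-index map
parsing = cv2.resize(parsing, (image.shape[1], image.shape[0]),
                     interpolation=cv2.INTER_NEAREST)  # cv2.resize takes (w, h)

recolored = hair(image, parsing, part=17, color=[230, 50, 20])  # BGR hair tint
cv2.imwrite('face_recolored.jpg', recolored)
```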
/makeup/116_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/makeup/116_1.png
--------------------------------------------------------------------------------
/makeup/116_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/makeup/116_3.png
--------------------------------------------------------------------------------
/makeup/116_lip_ori.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/makeup/116_lip_ori.png
--------------------------------------------------------------------------------
/makeup/116_ori.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/makeup/116_ori.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision
9 |
10 | from .resnet import Resnet18
11 | # from modules.bn import InPlaceABNSync as BatchNorm2d
12 |
13 |
14 | class ConvBNReLU(nn.Module):
15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16 | super(ConvBNReLU, self).__init__()
17 | self.conv = nn.Conv2d(in_chan,
18 | out_chan,
19 | kernel_size = ks,
20 | stride = stride,
21 | padding = padding,
22 | bias = False)
23 | self.bn = nn.BatchNorm2d(out_chan)
24 | self.init_weight()
25 |
26 | def forward(self, x):
27 | x = self.conv(x)
28 | x = F.relu(self.bn(x))
29 | return x
30 |
31 | def init_weight(self):
32 | for ly in self.children():
33 | if isinstance(ly, nn.Conv2d):
34 | nn.init.kaiming_normal_(ly.weight, a=1)
35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36 |
37 | class BiSeNetOutput(nn.Module):
38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39 | super(BiSeNetOutput, self).__init__()
40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42 | self.init_weight()
43 |
44 | def forward(self, x):
45 | x = self.conv(x)
46 | x = self.conv_out(x)
47 | return x
48 |
49 | def init_weight(self):
50 | for ly in self.children():
51 | if isinstance(ly, nn.Conv2d):
52 | nn.init.kaiming_normal_(ly.weight, a=1)
53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54 |
55 | def get_params(self):
56 | wd_params, nowd_params = [], []
57 | for name, module in self.named_modules():
58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59 | wd_params.append(module.weight)
60 | if not module.bias is None:
61 | nowd_params.append(module.bias)
62 | elif isinstance(module, nn.BatchNorm2d):
63 | nowd_params += list(module.parameters())
64 | return wd_params, nowd_params
65 |
66 |
67 | class AttentionRefinementModule(nn.Module):
68 | def __init__(self, in_chan, out_chan, *args, **kwargs):
69 | super(AttentionRefinementModule, self).__init__()
70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72 | self.bn_atten = nn.BatchNorm2d(out_chan)
73 | self.sigmoid_atten = nn.Sigmoid()
74 | self.init_weight()
75 |
76 | def forward(self, x):
77 | feat = self.conv(x)
78 | atten = F.avg_pool2d(feat, feat.size()[2:])
79 | atten = self.conv_atten(atten)
80 | atten = self.bn_atten(atten)
81 | atten = self.sigmoid_atten(atten)
82 | out = torch.mul(feat, atten)
83 | return out
84 |
85 | def init_weight(self):
86 | for ly in self.children():
87 | if isinstance(ly, nn.Conv2d):
88 | nn.init.kaiming_normal_(ly.weight, a=1)
89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90 |
91 |
92 | class ContextPath(nn.Module):
93 | def __init__(self, *args, **kwargs):
94 | super(ContextPath, self).__init__()
95 | self.resnet = Resnet18()
96 | self.arm16 = AttentionRefinementModule(256, 128)
97 | self.arm32 = AttentionRefinementModule(512, 128)
98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101 |
102 | self.init_weight()
103 |
104 | def forward(self, x):
105 | H0, W0 = x.size()[2:]
106 | feat8, feat16, feat32 = self.resnet(x)
107 | H8, W8 = feat8.size()[2:]
108 | H16, W16 = feat16.size()[2:]
109 | H32, W32 = feat32.size()[2:]
110 |
111 | avg = F.avg_pool2d(feat32, feat32.size()[2:])
112 | avg = self.conv_avg(avg)
113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114 |
115 | feat32_arm = self.arm32(feat32)
116 | feat32_sum = feat32_arm + avg_up
117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118 | feat32_up = self.conv_head32(feat32_up)
119 |
120 | feat16_arm = self.arm16(feat16)
121 | feat16_sum = feat16_arm + feat32_up
122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123 | feat16_up = self.conv_head16(feat16_up)
124 |
125 | return feat8, feat16_up, feat32_up # x8, x8, x16
126 |
127 | def init_weight(self):
128 | for ly in self.children():
129 | if isinstance(ly, nn.Conv2d):
130 | nn.init.kaiming_normal_(ly.weight, a=1)
131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132 |
133 | def get_params(self):
134 | wd_params, nowd_params = [], []
135 | for name, module in self.named_modules():
136 | if isinstance(module, (nn.Linear, nn.Conv2d)):
137 | wd_params.append(module.weight)
138 | if not module.bias is None:
139 | nowd_params.append(module.bias)
140 | elif isinstance(module, nn.BatchNorm2d):
141 | nowd_params += list(module.parameters())
142 | return wd_params, nowd_params
143 |
144 |
145 | ### This is not used, since I replace this with the resnet feature with the same size
146 | class SpatialPath(nn.Module):
147 | def __init__(self, *args, **kwargs):
148 | super(SpatialPath, self).__init__()
149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153 | self.init_weight()
154 |
155 | def forward(self, x):
156 | feat = self.conv1(x)
157 | feat = self.conv2(feat)
158 | feat = self.conv3(feat)
159 | feat = self.conv_out(feat)
160 | return feat
161 |
162 | def init_weight(self):
163 | for ly in self.children():
164 | if isinstance(ly, nn.Conv2d):
165 | nn.init.kaiming_normal_(ly.weight, a=1)
166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167 |
168 | def get_params(self):
169 | wd_params, nowd_params = [], []
170 | for name, module in self.named_modules():
171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172 | wd_params.append(module.weight)
173 | if not module.bias is None:
174 | nowd_params.append(module.bias)
175 | elif isinstance(module, nn.BatchNorm2d):
176 | nowd_params += list(module.parameters())
177 | return wd_params, nowd_params
178 |
179 |
180 | class FeatureFusionModule(nn.Module):
181 | def __init__(self, in_chan, out_chan, *args, **kwargs):
182 | super(FeatureFusionModule, self).__init__()
183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184 | self.conv1 = nn.Conv2d(out_chan,
185 | out_chan//4,
186 | kernel_size = 1,
187 | stride = 1,
188 | padding = 0,
189 | bias = False)
190 | self.conv2 = nn.Conv2d(out_chan//4,
191 | out_chan,
192 | kernel_size = 1,
193 | stride = 1,
194 | padding = 0,
195 | bias = False)
196 | self.relu = nn.ReLU(inplace=True)
197 | self.sigmoid = nn.Sigmoid()
198 | self.init_weight()
199 |
200 | def forward(self, fsp, fcp):
201 | fcat = torch.cat([fsp, fcp], dim=1)
202 | feat = self.convblk(fcat)
203 | atten = F.avg_pool2d(feat, feat.size()[2:])
204 | atten = self.conv1(atten)
205 | atten = self.relu(atten)
206 | atten = self.conv2(atten)
207 | atten = self.sigmoid(atten)
208 | feat_atten = torch.mul(feat, atten)
209 | feat_out = feat_atten + feat
210 | return feat_out
211 |
212 | def init_weight(self):
213 | for ly in self.children():
214 | if isinstance(ly, nn.Conv2d):
215 | nn.init.kaiming_normal_(ly.weight, a=1)
216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217 |
218 | def get_params(self):
219 | wd_params, nowd_params = [], []
220 | for name, module in self.named_modules():
221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222 | wd_params.append(module.weight)
223 | if not module.bias is None:
224 | nowd_params.append(module.bias)
225 | elif isinstance(module, nn.BatchNorm2d):
226 | nowd_params += list(module.parameters())
227 | return wd_params, nowd_params
228 |
229 |
230 | class BiSeNet(nn.Module):
231 | def __init__(self, n_classes, *args, **kwargs):
232 | super(BiSeNet, self).__init__()
233 | self.cp = ContextPath()
234 | ## here self.sp is deleted
235 | self.ffm = FeatureFusionModule(256, 256)
236 | self.conv_out = BiSeNetOutput(256, 256, n_classes)
237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239 | self.init_weight()
240 |
241 | def forward(self, x):
242 | H, W = x.size()[2:]
243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245 | feat_fuse = self.ffm(feat_sp, feat_cp8)
246 |
247 | feat_out = self.conv_out(feat_fuse)
248 | feat_out16 = self.conv_out16(feat_cp8)
249 | feat_out32 = self.conv_out32(feat_cp16)
250 |
251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254 | return feat_out, feat_out16, feat_out32
255 |
256 | def init_weight(self):
257 | for ly in self.children():
258 | if isinstance(ly, nn.Conv2d):
259 | nn.init.kaiming_normal_(ly.weight, a=1)
260 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
261 |
262 | def get_params(self):
263 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
264 | for name, child in self.named_children():
265 | child_wd_params, child_nowd_params = child.get_params()
266 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
267 | lr_mul_wd_params += child_wd_params
268 | lr_mul_nowd_params += child_nowd_params
269 | else:
270 | wd_params += child_wd_params
271 | nowd_params += child_nowd_params
272 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273 |
274 |
275 | if __name__ == "__main__":
276 | net = BiSeNet(19)
277 | net.cuda()
278 | net.eval()
279 | in_ten = torch.randn(16, 3, 640, 480).cuda()
280 | out, out16, out32 = net(in_ten)
281 | print(out.shape)
282 |
283 | net.get_params()
284 |
--------------------------------------------------------------------------------
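The `__main__` block above requires CUDA; a CPU smoke test showing that all three output heads are upsampled back to the input resolution (package name assumed, as before):

```python
import torch

from face_parsing.model import BiSeNet  # assumed package name

net = BiSeNet(n_classes=19).eval()
x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    out, out16, out32 = net(x)
print(out.shape, out16.shape, out32.shape)  # each torch.Size([1, 19, 512, 512])
```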
/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .bn import ABN, InPlaceABN, InPlaceABNSync
2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
3 | from .misc import GlobalAvgPool2d, SingleGPU
4 | from .residual import IdentityResidualBlock
5 | from .dense import DenseModule
6 |
--------------------------------------------------------------------------------
/modules/bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as functional
4 |
5 | try:
6 | from queue import Queue
7 | except ImportError:
8 | from Queue import Queue
9 |
10 | from .functions import *
11 |
12 |
13 | class ABN(nn.Module):
14 | """Activated Batch Normalization
15 |
16 | This gathers a `BatchNorm2d` and an activation function in a single module
17 | """
18 |
19 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
20 | """Creates an Activated Batch Normalization module
21 |
22 | Parameters
23 | ----------
24 | num_features : int
25 | Number of feature channels in the input and output.
26 | eps : float
27 | Small constant to prevent numerical issues.
28 | momentum : float
29 | Momentum factor used when updating the running statistics.
30 | affine : bool
31 | If `True` apply learned scale and shift transformation after normalization.
32 | activation : str
33 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
34 | slope : float
35 | Negative slope for the `leaky_relu` activation.
36 | """
37 | super(ABN, self).__init__()
38 | self.num_features = num_features
39 | self.affine = affine
40 | self.eps = eps
41 | self.momentum = momentum
42 | self.activation = activation
43 | self.slope = slope
44 | if self.affine:
45 | self.weight = nn.Parameter(torch.ones(num_features))
46 | self.bias = nn.Parameter(torch.zeros(num_features))
47 | else:
48 | self.register_parameter('weight', None)
49 | self.register_parameter('bias', None)
50 | self.register_buffer('running_mean', torch.zeros(num_features))
51 | self.register_buffer('running_var', torch.ones(num_features))
52 | self.reset_parameters()
53 |
54 | def reset_parameters(self):
55 | nn.init.constant_(self.running_mean, 0)
56 | nn.init.constant_(self.running_var, 1)
57 | if self.affine:
58 | nn.init.constant_(self.weight, 1)
59 | nn.init.constant_(self.bias, 0)
60 |
61 | def forward(self, x):
62 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
63 | self.training, self.momentum, self.eps)
64 |
65 | if self.activation == ACT_RELU:
66 | return functional.relu(x, inplace=True)
67 | elif self.activation == ACT_LEAKY_RELU:
68 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
69 | elif self.activation == ACT_ELU:
70 | return functional.elu(x, inplace=True)
71 | else:
72 | return x
73 |
74 | def __repr__(self):
75 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
76 | ' affine={affine}, activation={activation}'
77 | if self.activation == "leaky_relu":
78 | rep += ', slope={slope})'
79 | else:
80 | rep += ')'
81 | return rep.format(name=self.__class__.__name__, **self.__dict__)
82 |
83 |
84 | class InPlaceABN(ABN):
85 | """InPlace Activated Batch Normalization"""
86 |
87 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
88 | """Creates an InPlace Activated Batch Normalization module
89 |
90 | Parameters
91 | ----------
92 | num_features : int
93 | Number of feature channels in the input and output.
94 | eps : float
95 | Small constant to prevent numerical issues.
96 | momentum : float
97 | Momentum factor used when updating the running statistics.
98 | affine : bool
99 | If `True` apply learned scale and shift transformation after normalization.
100 | activation : str
101 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
102 | slope : float
103 | Negative slope for the `leaky_relu` activation.
104 | """
105 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
106 |
107 | def forward(self, x):
108 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
109 | self.training, self.momentum, self.eps, self.activation, self.slope)
110 |
111 |
112 | class InPlaceABNSync(ABN):
113 | """InPlace Activated Batch Normalization with cross-GPU synchronization
114 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`.
115 | """
116 |
117 | def forward(self, x):
118 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
119 | self.training, self.momentum, self.eps, self.activation, self.slope)
120 |
121 | def __repr__(self):
122 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
123 | ' affine={affine}, activation={activation}'
124 | if self.activation == "leaky_relu":
125 | rep += ', slope={slope})'
126 | else:
127 | rep += ')'
128 | return rep.format(name=self.__class__.__name__, **self.__dict__)
129 |
130 |
131 |
--------------------------------------------------------------------------------
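`ABN` itself is plain PyTorch (a fused `BatchNorm2d` + activation), but note that importing it through the `modules` package triggers the JIT build of the CUDA extension in modules/functions.py, so a working CUDA toolchain is needed even for this class. A minimal sketch under that assumption:

```python
import torch

from modules.bn import ABN  # importing the package JIT-compiles modules/src

bn = ABN(64, activation='leaky_relu', slope=0.01)
x = torch.randn(4, 64, 32, 32)
y = bn(x)          # batch norm followed by leaky_relu, in one module
print(y.shape)     # torch.Size([4, 64, 32, 32])
```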
/modules/deeplab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as functional
4 |
5 | from .bn import ABN
6 |
7 | # NOTE: 'models._util' is not part of this repository; a minimal fallback with the same behavior:
8 | def try_index(x, i): return x[i] if isinstance(x, (list, tuple)) else x
9 | class DeeplabV3(nn.Module):
10 | def __init__(self,
11 | in_channels,
12 | out_channels,
13 | hidden_channels=256,
14 | dilations=(12, 24, 36),
15 | norm_act=ABN,
16 | pooling_size=None):
17 | super(DeeplabV3, self).__init__()
18 | self.pooling_size = pooling_size
19 |
20 | self.map_convs = nn.ModuleList([
21 | nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
22 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]),
23 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]),
24 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2])
25 | ])
26 | self.map_bn = norm_act(hidden_channels * 4)
27 |
28 | self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
29 | self.global_pooling_bn = norm_act(hidden_channels)
30 |
31 | self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
32 | self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
33 | self.red_bn = norm_act(out_channels)
34 |
35 | self.reset_parameters(self.map_bn.activation, self.map_bn.slope)
36 |
37 | def reset_parameters(self, activation, slope):
38 | gain = nn.init.calculate_gain(activation, slope)
39 | for m in self.modules():
40 | if isinstance(m, nn.Conv2d):
41 | nn.init.xavier_normal_(m.weight.data, gain)
42 | if hasattr(m, "bias") and m.bias is not None:
43 | nn.init.constant_(m.bias, 0)
44 | elif isinstance(m, ABN):
45 | if hasattr(m, "weight") and m.weight is not None:
46 | nn.init.constant_(m.weight, 1)
47 | if hasattr(m, "bias") and m.bias is not None:
48 | nn.init.constant_(m.bias, 0)
49 |
50 | def forward(self, x):
51 | # Map convolutions
52 | out = torch.cat([m(x) for m in self.map_convs], dim=1)
53 | out = self.map_bn(out)
54 | out = self.red_conv(out)
55 |
56 | # Global pooling
57 | pool = self._global_pooling(x)
58 | pool = self.global_pooling_conv(pool)
59 | pool = self.global_pooling_bn(pool)
60 | pool = self.pool_red_conv(pool)
61 | if self.training or self.pooling_size is None:
62 | pool = pool.repeat(1, 1, x.size(2), x.size(3))
63 |
64 | out += pool
65 | out = self.red_bn(out)
66 | return out
67 |
68 | def _global_pooling(self, x):
69 | if self.training or self.pooling_size is None:
70 | pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
71 | pool = pool.view(x.size(0), x.size(1), 1, 1)
72 | else:
73 | pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
74 | min(try_index(self.pooling_size, 1), x.shape[3]))
75 | padding = (
76 | (pooling_size[1] - 1) // 2,
77 | (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
78 | (pooling_size[0] - 1) // 2,
79 | (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
80 | )
81 |
82 | pool = functional.avg_pool2d(x, pooling_size, stride=1)
83 | pool = functional.pad(pool, pad=padding, mode="replicate")
84 | return pool
85 |
--------------------------------------------------------------------------------
/modules/dense.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .bn import ABN
7 |
8 |
9 | class DenseModule(nn.Module):
10 | def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
11 | super(DenseModule, self).__init__()
12 | self.in_channels = in_channels
13 | self.growth = growth
14 | self.layers = layers
15 |
16 | self.convs1 = nn.ModuleList()
17 | self.convs3 = nn.ModuleList()
18 | for i in range(self.layers):
19 | self.convs1.append(nn.Sequential(OrderedDict([
20 | ("bn", norm_act(in_channels)),
21 | ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False))
22 | ])))
23 | self.convs3.append(nn.Sequential(OrderedDict([
24 | ("bn", norm_act(self.growth * bottleneck_factor)),
25 | ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False,
26 | dilation=dilation))
27 | ])))
28 | in_channels += self.growth
29 |
30 | @property
31 | def out_channels(self):
32 | return self.in_channels + self.growth * self.layers
33 |
34 | def forward(self, x):
35 | inputs = [x]
36 | for i in range(self.layers):
37 | x = torch.cat(inputs, dim=1)
38 | x = self.convs1[i](x)
39 | x = self.convs3[i](x)
40 | inputs += [x]
41 |
42 | return torch.cat(inputs, dim=1)
43 |
--------------------------------------------------------------------------------
/modules/functions.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | import torch
3 | import torch.distributed as dist
4 | import torch.autograd as autograd
5 | import torch.cuda.comm as comm
6 | from torch.autograd.function import once_differentiable
7 | from torch.utils.cpp_extension import load
8 |
9 | _src_path = path.join(path.dirname(path.abspath(__file__)), "src")
10 | _backend = load(name="inplace_abn",
11 | extra_cflags=["-O3"],
12 | sources=[path.join(_src_path, f) for f in [
13 | "inplace_abn.cpp",
14 | "inplace_abn_cpu.cpp",
15 | "inplace_abn_cuda.cu",
16 | "inplace_abn_cuda_half.cu"
17 | ]],
18 | extra_cuda_cflags=["--expt-extended-lambda"])
19 |
20 | # Activation names
21 | ACT_RELU = "relu"
22 | ACT_LEAKY_RELU = "leaky_relu"
23 | ACT_ELU = "elu"
24 | ACT_NONE = "none"
25 |
26 |
27 | def _check(fn, *args, **kwargs):
28 | success = fn(*args, **kwargs)
29 | if not success:
30 | raise RuntimeError("CUDA Error encountered in {}".format(fn))
31 |
32 |
33 | def _broadcast_shape(x):
34 | out_size = []
35 | for i, s in enumerate(x.size()):
36 | if i != 1:
37 | out_size.append(1)
38 | else:
39 | out_size.append(s)
40 | return out_size
41 |
42 |
43 | def _reduce(x):
44 | if len(x.size()) == 2:
45 | return x.sum(dim=0)
46 | else:
47 | n, c = x.size()[0:2]
48 | return x.contiguous().view((n, c, -1)).sum(2).sum(0)
49 |
50 |
51 | def _count_samples(x):
52 | count = 1
53 | for i, s in enumerate(x.size()):
54 | if i != 1:
55 | count *= s
56 | return count
57 |
58 |
59 | def _act_forward(ctx, x):
60 | if ctx.activation == ACT_LEAKY_RELU:
61 | _backend.leaky_relu_forward(x, ctx.slope)
62 | elif ctx.activation == ACT_ELU:
63 | _backend.elu_forward(x)
64 | elif ctx.activation == ACT_NONE:
65 | pass
66 |
67 |
68 | def _act_backward(ctx, x, dx):
69 | if ctx.activation == ACT_LEAKY_RELU:
70 | _backend.leaky_relu_backward(x, dx, ctx.slope)
71 | elif ctx.activation == ACT_ELU:
72 | _backend.elu_backward(x, dx)
73 | elif ctx.activation == ACT_NONE:
74 | pass
75 |
76 |
77 | class InPlaceABN(autograd.Function):
78 | @staticmethod
79 | def forward(ctx, x, weight, bias, running_mean, running_var,
80 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
81 | # Save context
82 | ctx.training = training
83 | ctx.momentum = momentum
84 | ctx.eps = eps
85 | ctx.activation = activation
86 | ctx.slope = slope
87 | ctx.affine = weight is not None and bias is not None
88 |
89 | # Prepare inputs
90 | count = _count_samples(x)
91 | x = x.contiguous()
92 | weight = weight.contiguous() if ctx.affine else x.new_empty(0)
93 | bias = bias.contiguous() if ctx.affine else x.new_empty(0)
94 |
95 | if ctx.training:
96 | mean, var = _backend.mean_var(x)
97 |
98 | # Update running stats
99 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
100 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
101 |
102 | # Mark in-place modified tensors
103 | ctx.mark_dirty(x, running_mean, running_var)
104 | else:
105 | mean, var = running_mean.contiguous(), running_var.contiguous()
106 | ctx.mark_dirty(x)
107 |
108 | # BN forward + activation
109 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
110 | _act_forward(ctx, x)
111 |
112 | # Output
113 | ctx.var = var
114 | ctx.save_for_backward(x, var, weight, bias)
115 | return x
116 |
117 | @staticmethod
118 | @once_differentiable
119 | def backward(ctx, dz):
120 | z, var, weight, bias = ctx.saved_tensors
121 | dz = dz.contiguous()
122 |
123 | # Undo activation
124 | _act_backward(ctx, z, dz)
125 |
126 | if ctx.training:
127 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
128 | else:
129 | # TODO: implement simplified CUDA backward for inference mode
130 | edz = dz.new_zeros(dz.size(1))
131 | eydz = dz.new_zeros(dz.size(1))
132 |
133 | dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
134 | dweight = eydz * weight.sign() if ctx.affine else None
135 | dbias = edz if ctx.affine else None
136 |
137 | return dx, dweight, dbias, None, None, None, None, None, None, None
138 |
139 | class InPlaceABNSync(autograd.Function):
140 | @classmethod
141 | def forward(cls, ctx, x, weight, bias, running_mean, running_var,
142 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True):
143 | # Save context
144 | ctx.training = training
145 | ctx.momentum = momentum
146 | ctx.eps = eps
147 | ctx.activation = activation
148 | ctx.slope = slope
149 | ctx.affine = weight is not None and bias is not None
150 |
151 | # Prepare inputs
152 | ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1
153 |
154 | #count = _count_samples(x)
155 | batch_size = x.new_tensor([x.shape[0]],dtype=torch.long)
156 |
157 | x = x.contiguous()
158 | weight = weight.contiguous() if ctx.affine else x.new_empty(0)
159 | bias = bias.contiguous() if ctx.affine else x.new_empty(0)
160 |
161 | if ctx.training:
162 | mean, var = _backend.mean_var(x)
163 | if ctx.world_size>1:
164 | # get global batch size
165 | if equal_batches:
166 | batch_size *= ctx.world_size
167 | else:
168 | dist.all_reduce(batch_size, dist.ReduceOp.SUM)
169 |
170 | ctx.factor = x.shape[0]/float(batch_size.item())
171 |
172 | mean_all = mean.clone() * ctx.factor
173 | dist.all_reduce(mean_all, dist.ReduceOp.SUM)
174 |
175 | var_all = (var + (mean - mean_all) ** 2) * ctx.factor
176 | dist.all_reduce(var_all, dist.ReduceOp.SUM)
177 |
178 | mean = mean_all
179 | var = var_all
180 |
181 | # Update running stats
182 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
183 | count = batch_size.item() * x.view(x.shape[0],x.shape[1],-1).shape[-1]
184 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1)))
185 |
186 | # Mark in-place modified tensors
187 | ctx.mark_dirty(x, running_mean, running_var)
188 | else:
189 | mean, var = running_mean.contiguous(), running_var.contiguous()
190 | ctx.mark_dirty(x)
191 |
192 | # BN forward + activation
193 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
194 | _act_forward(ctx, x)
195 |
196 | # Output
197 | ctx.var = var
198 | ctx.save_for_backward(x, var, weight, bias)
199 | return x
200 |
201 | @staticmethod
202 | @once_differentiable
203 | def backward(ctx, dz):
204 | z, var, weight, bias = ctx.saved_tensors
205 | dz = dz.contiguous()
206 |
207 | # Undo activation
208 | _act_backward(ctx, z, dz)
209 |
210 | if ctx.training:
211 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
212 | edz_local = edz.clone()
213 | eydz_local = eydz.clone()
214 |
215 | if ctx.world_size>1:
216 | edz *= ctx.factor
217 | dist.all_reduce(edz, dist.ReduceOp.SUM)
218 |
219 | eydz *= ctx.factor
220 | dist.all_reduce(eydz, dist.ReduceOp.SUM)
221 | else:
222 | edz_local = edz = dz.new_zeros(dz.size(1))
223 | eydz_local = eydz = dz.new_zeros(dz.size(1))
224 |
225 | dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
226 | dweight = eydz_local * weight.sign() if ctx.affine else None
227 | dbias = edz_local if ctx.affine else None
228 |
229 | return dx, dweight, dbias, None, None, None, None, None, None, None
230 |
231 | inplace_abn = InPlaceABN.apply
232 | inplace_abn_sync = InPlaceABNSync.apply
233 |
234 | __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
235 |
--------------------------------------------------------------------------------
/modules/misc.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.distributed as dist
4 |
5 | class GlobalAvgPool2d(nn.Module):
6 | def __init__(self):
7 | """Global average pooling over the input's spatial dimensions"""
8 | super(GlobalAvgPool2d, self).__init__()
9 |
10 | def forward(self, inputs):
11 | in_size = inputs.size()
12 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
13 |
14 | class SingleGPU(nn.Module):
15 | def __init__(self, module):
16 | super(SingleGPU, self).__init__()
17 | self.module=module
18 |
19 | def forward(self, input):
20 | return self.module(input.cuda(non_blocking=True))
21 |
22 |
--------------------------------------------------------------------------------
/modules/residual.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch.nn as nn
4 |
5 | from .bn import ABN
6 |
7 |
8 | class IdentityResidualBlock(nn.Module):
9 | def __init__(self,
10 | in_channels,
11 | channels,
12 | stride=1,
13 | dilation=1,
14 | groups=1,
15 | norm_act=ABN,
16 | dropout=None):
17 | """Configurable identity-mapping residual block
18 |
19 | Parameters
20 | ----------
21 | in_channels : int
22 | Number of input channels.
23 | channels : list of int
24 | Number of channels in the internal feature maps. Can either have two or three elements: if three construct
25 | a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then
26 | `3 x 3` then `1 x 1` convolutions.
27 | stride : int
28 | Stride of the first `3 x 3` convolution
29 | dilation : int
30 | Dilation to apply to the `3 x 3` convolutions.
31 | groups : int
32 | Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
33 | bottleneck blocks.
34 | norm_act : callable
35 | Function to create normalization / activation Module.
36 | dropout: callable
37 | Function to create Dropout Module.
38 | """
39 | super(IdentityResidualBlock, self).__init__()
40 |
41 | # Check parameters for inconsistencies
42 | if len(channels) != 2 and len(channels) != 3:
43 | raise ValueError("channels must contain either two or three values")
44 | if len(channels) == 2 and groups != 1:
45 | raise ValueError("groups > 1 are only valid if len(channels) == 3")
46 |
47 | is_bottleneck = len(channels) == 3
48 | need_proj_conv = stride != 1 or in_channels != channels[-1]
49 |
50 | self.bn1 = norm_act(in_channels)
51 | if not is_bottleneck:
52 | layers = [
53 | ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
54 | dilation=dilation)),
55 | ("bn2", norm_act(channels[0])),
56 | ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
57 | dilation=dilation))
58 | ]
59 | if dropout is not None:
60 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
61 | else:
62 | layers = [
63 | ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
64 | ("bn2", norm_act(channels[0])),
65 | ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
66 | groups=groups, dilation=dilation)),
67 | ("bn3", norm_act(channels[1])),
68 | ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
69 | ]
70 | if dropout is not None:
71 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
72 | self.convs = nn.Sequential(OrderedDict(layers))
73 |
74 | if need_proj_conv:
75 | self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
76 |
77 | def forward(self, x):
78 | if hasattr(self, "proj_conv"):
79 | bn1 = self.bn1(x)
80 | shortcut = self.proj_conv(bn1)
81 | else:
82 | shortcut = x.clone()
83 | bn1 = self.bn1(x)
84 |
85 | out = self.convs(bn1)
86 | out.add_(shortcut)
87 |
88 | return out
89 |
--------------------------------------------------------------------------------
/modules/src/checks.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
6 | #ifndef AT_CHECK
7 | #define AT_CHECK AT_ASSERT
8 | #endif
9 |
10 | #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
12 | #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
13 |
14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
15 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
--------------------------------------------------------------------------------
/modules/src/inplace_abn.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include <vector>
4 |
5 | #include "inplace_abn.h"
6 |
7 | std::vector<at::Tensor> mean_var(at::Tensor x) {
8 | if (x.is_cuda()) {
9 | if (x.type().scalarType() == at::ScalarType::Half) {
10 | return mean_var_cuda_h(x);
11 | } else {
12 | return mean_var_cuda(x);
13 | }
14 | } else {
15 | return mean_var_cpu(x);
16 | }
17 | }
18 |
19 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
20 | bool affine, float eps) {
21 | if (x.is_cuda()) {
22 | if (x.type().scalarType() == at::ScalarType::Half) {
23 | return forward_cuda_h(x, mean, var, weight, bias, affine, eps);
24 | } else {
25 | return forward_cuda(x, mean, var, weight, bias, affine, eps);
26 | }
27 | } else {
28 | return forward_cpu(x, mean, var, weight, bias, affine, eps);
29 | }
30 | }
31 |
32 | std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
33 |                                  bool affine, float eps) {
34 | if (z.is_cuda()) {
35 | if (z.type().scalarType() == at::ScalarType::Half) {
36 | return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps);
37 | } else {
38 | return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
39 | }
40 | } else {
41 | return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
42 | }
43 | }
44 |
45 | at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
46 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
47 | if (z.is_cuda()) {
48 | if (z.type().scalarType() == at::ScalarType::Half) {
49 | return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps);
50 | } else {
51 | return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
52 | }
53 | } else {
54 | return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
55 | }
56 | }
57 |
58 | void leaky_relu_forward(at::Tensor z, float slope) {
59 | at::leaky_relu_(z, slope);
60 | }
61 |
62 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
63 | if (z.is_cuda()) {
64 | if (z.type().scalarType() == at::ScalarType::Half) {
65 | return leaky_relu_backward_cuda_h(z, dz, slope);
66 | } else {
67 | return leaky_relu_backward_cuda(z, dz, slope);
68 | }
69 | } else {
70 | return leaky_relu_backward_cpu(z, dz, slope);
71 | }
72 | }
73 |
74 | void elu_forward(at::Tensor z) {
75 | at::elu_(z);
76 | }
77 |
78 | void elu_backward(at::Tensor z, at::Tensor dz) {
79 | if (z.is_cuda()) {
80 | return elu_backward_cuda(z, dz);
81 | } else {
82 | return elu_backward_cpu(z, dz);
83 | }
84 | }
85 |
86 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
87 | m.def("mean_var", &mean_var, "Mean and variance computation");
88 | m.def("forward", &forward, "In-place forward computation");
89 | m.def("edz_eydz", &edz_eydz, "First part of backward computation");
90 | m.def("backward", &backward, "Second part of backward computation");
91 | m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
92 | m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
93 | m.def("elu_forward", &elu_forward, "Elu forward computation");
94 | m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
95 | }
96 |
--------------------------------------------------------------------------------
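Taken together, the bindings above split one batch-norm step into reusable pieces: mean_var computes the per-channel statistics, forward normalizes the tensor in place, and edz_eydz/backward implement the two reduction passes of the gradient. The real Python glue lives in modules/functions.py; the sketch below is illustrative only — `_backend` stands in for the compiled extension module — and shows how the exported functions compose in training mode:

import torch

def inplace_abn_step(_backend, x, weight, bias, running_mean, running_var,
                     training=True, momentum=0.1, eps=1e-5, slope=0.01):
    if training:
        mean, var = _backend.mean_var(x)            # per-channel batch statistics
        running_mean.mul_(1 - momentum).add_(mean * momentum)
        running_var.mul_(1 - momentum).add_(var * momentum)
    else:
        mean, var = running_mean, running_var
    # normalizes x in place and applies |weight| * x_hat + bias (affine=True here)
    y = _backend.forward(x, mean, var, weight, bias, True, eps)
    _backend.leaky_relu_forward(y, slope)           # activation, also in place
    return y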
/modules/src/inplace_abn.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | #include <vector>
6 |
7 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9 | std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x);
10 |
11 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
12 | bool affine, float eps);
13 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
14 | bool affine, float eps);
15 | at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16 | bool affine, float eps);
17 |
18 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
19 |                                      bool affine, float eps);
20 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
21 |                                       bool affine, float eps);
22 | std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
23 |                                         bool affine, float eps);
24 |
25 | at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
26 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
27 | at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
28 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
29 | at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
30 | at::Tensor edz, at::Tensor eydz, bool affine, float eps);
31 |
32 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
33 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
34 | void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope);
35 |
36 | void elu_backward_cpu(at::Tensor z, at::Tensor dz);
37 | void elu_backward_cuda(at::Tensor z, at::Tensor dz);
38 |
39 | static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
40 | num = x.size(0);
41 | chn = x.size(1);
42 | sp = 1;
43 | for (int64_t i = 2; i < x.ndimension(); ++i)
44 | sp *= x.size(i);
45 | }
46 |
47 | /*
48 | * Specialized CUDA reduction functions for BN
49 | */
50 | #ifdef __CUDACC__
51 |
52 | #include "utils/cuda.cuh"
53 |
54 | template <typename T, typename Op>
55 | __device__ T reduce(Op op, int plane, int N, int S) {
56 | T sum = (T)0;
57 | for (int batch = 0; batch < N; ++batch) {
58 | for (int x = threadIdx.x; x < S; x += blockDim.x) {
59 | sum += op(batch, plane, x);
60 | }
61 | }
62 |
63 | // sum over NumThreads within a warp
64 | sum = warpSum(sum);
65 |
66 | // 'transpose', and reduce within warp again
67 | __shared__ T shared[32];
68 | __syncthreads();
69 | if (threadIdx.x % WARP_SIZE == 0) {
70 | shared[threadIdx.x / WARP_SIZE] = sum;
71 | }
72 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
73 | // zero out the other entries in shared
74 | shared[threadIdx.x] = (T)0;
75 | }
76 | __syncthreads();
77 | if (threadIdx.x / WARP_SIZE == 0) {
78 | sum = warpSum(shared[threadIdx.x]);
79 | if (threadIdx.x == 0) {
80 | shared[0] = sum;
81 | }
82 | }
83 | __syncthreads();
84 |
85 | // Everyone picks it up, should be broadcast into the whole gradInput
86 | return shared[0];
87 | }
88 | #endif
89 |
--------------------------------------------------------------------------------
/modules/src/inplace_abn_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <vector>
4 |
5 | #include "utils/checks.h"
6 | #include "inplace_abn.h"
7 |
8 | at::Tensor reduce_sum(at::Tensor x) {
9 | if (x.ndimension() == 2) {
10 | return x.sum(0);
11 | } else {
12 | auto x_view = x.view({x.size(0), x.size(1), -1});
13 | return x_view.sum(-1).sum(0);
14 | }
15 | }
16 |
17 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
18 | if (x.ndimension() == 2) {
19 | return v;
20 | } else {
21 |     std::vector<int64_t> broadcast_size = {1, -1};
22 | for (int64_t i = 2; i < x.ndimension(); ++i)
23 | broadcast_size.push_back(1);
24 |
25 | return v.view(broadcast_size);
26 | }
27 | }
28 |
29 | int64_t count(at::Tensor x) {
30 | int64_t count = x.size(0);
31 | for (int64_t i = 2; i < x.ndimension(); ++i)
32 | count *= x.size(i);
33 |
34 | return count;
35 | }
36 |
37 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
38 | if (affine) {
39 | return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
40 | } else {
41 | return z;
42 | }
43 | }
44 |
45 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
46 | auto num = count(x);
47 | auto mean = reduce_sum(x) / num;
48 | auto diff = x - broadcast_to(mean, x);
49 | auto var = reduce_sum(diff.pow(2)) / num;
50 |
51 | return {mean, var};
52 | }
53 |
54 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
55 | bool affine, float eps) {
56 | auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
57 | auto mul = at::rsqrt(var + eps) * gamma;
58 |
59 | x.sub_(broadcast_to(mean, x));
60 | x.mul_(broadcast_to(mul, x));
61 | if (affine) x.add_(broadcast_to(bias, x));
62 |
63 | return x;
64 | }
65 |
66 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
67 |                                      bool affine, float eps) {
68 | auto edz = reduce_sum(dz);
69 | auto y = invert_affine(z, weight, bias, affine, eps);
70 | auto eydz = reduce_sum(y * dz);
71 |
72 | return {edz, eydz};
73 | }
74 |
75 | at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
76 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
77 | auto y = invert_affine(z, weight, bias, affine, eps);
78 | auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
79 |
80 | auto num = count(z);
81 | auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
82 | return dx;
83 | }
84 |
85 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
86 | CHECK_CPU_INPUT(z);
87 | CHECK_CPU_INPUT(dz);
88 |
89 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
90 | int64_t count = z.numel();
91 |     auto *_z = z.data<scalar_t>();
92 |     auto *_dz = dz.data<scalar_t>();
93 |
94 | for (int64_t i = 0; i < count; ++i) {
95 | if (_z[i] < 0) {
96 | _z[i] *= 1 / slope;
97 | _dz[i] *= slope;
98 | }
99 | }
100 | }));
101 | }
102 |
103 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
104 | CHECK_CPU_INPUT(z);
105 | CHECK_CPU_INPUT(dz);
106 |
107 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
108 | int64_t count = z.numel();
109 |     auto *_z = z.data<scalar_t>();
110 |     auto *_dz = dz.data<scalar_t>();
111 |
112 | for (int64_t i = 0; i < count; ++i) {
113 | if (_z[i] < 0) {
114 | _z[i] = log1p(_z[i]);
115 | _dz[i] *= (_z[i] + 1.f);
116 | }
117 | }
118 | }));
119 | }
120 |
--------------------------------------------------------------------------------
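The CPU implementation makes the central trick of in-place ABN easy to see: because forward_cpu scales by |weight| + eps (strictly positive), invert_affine can recover the normalized activation from the stored output, so the layer never keeps its input around for the backward pass. A small standalone NumPy check of that inversion and of the gradient formula in backward_cpu (illustration only, not repo code):

import numpy as np

np.random.seed(0)
N, C, eps = 8, 4, 1e-5
x = np.random.randn(N, C)
w, b = np.random.randn(C), np.random.randn(C)

mean, var = x.mean(0), x.var(0)
x_hat = (x - mean) / np.sqrt(var + eps)
z = x_hat * (np.abs(w) + eps) + b          # what forward_cpu leaves in the buffer
dz = np.random.randn(N, C)

y = (z - b) / (np.abs(w) + eps)            # invert_affine: recovers x_hat exactly
edz, eydz = dz.sum(0), (y * dz).sum(0)
# same expression backward_cpu evaluates, with mul = (|w|+eps) / sqrt(var+eps)
dx = (dz - edz / N - y * eydz / N) * (np.abs(w) + eps) / np.sqrt(var + eps)
print(np.allclose(y, x_hat))               # True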
/modules/src/inplace_abn_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <thrust/device_ptr.h>
4 | #include <thrust/transform.h>
5 |
6 | #include <vector>
7 |
8 | #include "utils/checks.h"
9 | #include "utils/cuda.cuh"
10 | #include "inplace_abn.h"
11 |
12 | #include <ATen/cuda/CUDAContext.h>
13 |
14 | // Operations for reduce
15 | template <typename T>
16 | struct SumOp {
17 | __device__ SumOp(const T *t, int c, int s)
18 | : tensor(t), chn(c), sp(s) {}
19 | __device__ __forceinline__ T operator()(int batch, int plane, int n) {
20 | return tensor[(batch * chn + plane) * sp + n];
21 | }
22 | const T *tensor;
23 | const int chn;
24 | const int sp;
25 | };
26 |
27 | template <typename T>
28 | struct VarOp {
29 | __device__ VarOp(T m, const T *t, int c, int s)
30 | : mean(m), tensor(t), chn(c), sp(s) {}
31 | __device__ __forceinline__ T operator()(int batch, int plane, int n) {
32 | T val = tensor[(batch * chn + plane) * sp + n];
33 | return (val - mean) * (val - mean);
34 | }
35 | const T mean;
36 | const T *tensor;
37 | const int chn;
38 | const int sp;
39 | };
40 |
41 | template <typename T>
42 | struct GradOp {
43 | __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
44 | : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
45 |   __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
46 | T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
47 | T _dz = dz[(batch * chn + plane) * sp + n];
48 |     return Pair<T>(_dz, _y * _dz);
49 | }
50 | const T weight;
51 | const T bias;
52 | const T *z;
53 | const T *dz;
54 | const int chn;
55 | const int sp;
56 | };
57 |
58 | /***********
59 | * mean_var
60 | ***********/
61 |
62 | template <typename T>
63 | __global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
64 | int plane = blockIdx.x;
65 | T norm = T(1) / T(num * sp);
66 |
67 |   T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, sp) * norm;
68 |   __syncthreads();
69 |   T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, sp) * norm;
70 |
71 | if (threadIdx.x == 0) {
72 | mean[plane] = _mean;
73 | var[plane] = _var;
74 | }
75 | }
76 |
77 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
78 | CHECK_CUDA_INPUT(x);
79 |
80 | // Extract dimensions
81 | int64_t num, chn, sp;
82 | get_dims(x, num, chn, sp);
83 |
84 | // Prepare output tensors
85 | auto mean = at::empty({chn}, x.options());
86 | auto var = at::empty({chn}, x.options());
87 |
88 | // Run kernel
89 | dim3 blocks(chn);
90 | dim3 threads(getNumThreads(sp));
91 | auto stream = at::cuda::getCurrentCUDAStream();
92 | AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
93 |     mean_var_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
94 |         x.data<scalar_t>(),
95 |         mean.data<scalar_t>(),
96 |         var.data<scalar_t>(),
97 |         num, chn, sp);
98 | }));
99 |
100 | return {mean, var};
101 | }
102 |
103 | /**********
104 | * forward
105 | **********/
106 |
107 | template <typename T>
108 | __global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
109 | bool affine, float eps, int num, int chn, int sp) {
110 | int plane = blockIdx.x;
111 |
112 | T _mean = mean[plane];
113 | T _var = var[plane];
114 | T _weight = affine ? abs(weight[plane]) + eps : T(1);
115 | T _bias = affine ? bias[plane] : T(0);
116 |
117 | T mul = rsqrt(_var + eps) * _weight;
118 |
119 | for (int batch = 0; batch < num; ++batch) {
120 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
121 | T _x = x[(batch * chn + plane) * sp + n];
122 | T _y = (_x - _mean) * mul + _bias;
123 |
124 | x[(batch * chn + plane) * sp + n] = _y;
125 | }
126 | }
127 | }
128 |
129 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
130 | bool affine, float eps) {
131 | CHECK_CUDA_INPUT(x);
132 | CHECK_CUDA_INPUT(mean);
133 | CHECK_CUDA_INPUT(var);
134 | CHECK_CUDA_INPUT(weight);
135 | CHECK_CUDA_INPUT(bias);
136 |
137 | // Extract dimensions
138 | int64_t num, chn, sp;
139 | get_dims(x, num, chn, sp);
140 |
141 | // Run kernel
142 | dim3 blocks(chn);
143 | dim3 threads(getNumThreads(sp));
144 | auto stream = at::cuda::getCurrentCUDAStream();
145 | AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
146 |     forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
147 |         x.data<scalar_t>(),
148 |         mean.data<scalar_t>(),
149 |         var.data<scalar_t>(),
150 |         weight.data<scalar_t>(),
151 |         bias.data<scalar_t>(),
152 |         affine, eps, num, chn, sp);
153 | }));
154 |
155 | return x;
156 | }
157 |
158 | /***********
159 | * edz_eydz
160 | ***********/
161 |
162 | template <typename T>
163 | __global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
164 | T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
165 | int plane = blockIdx.x;
166 |
167 | T _weight = affine ? abs(weight[plane]) + eps : 1.f;
168 | T _bias = affine ? bias[plane] : 0.f;
169 |
170 |   Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, sp);
171 | __syncthreads();
172 |
173 | if (threadIdx.x == 0) {
174 | edz[plane] = res.v1;
175 | eydz[plane] = res.v2;
176 | }
177 | }
178 |
179 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
180 |                                       bool affine, float eps) {
181 | CHECK_CUDA_INPUT(z);
182 | CHECK_CUDA_INPUT(dz);
183 | CHECK_CUDA_INPUT(weight);
184 | CHECK_CUDA_INPUT(bias);
185 |
186 | // Extract dimensions
187 | int64_t num, chn, sp;
188 | get_dims(z, num, chn, sp);
189 |
190 | auto edz = at::empty({chn}, z.options());
191 | auto eydz = at::empty({chn}, z.options());
192 |
193 | // Run kernel
194 | dim3 blocks(chn);
195 | dim3 threads(getNumThreads(sp));
196 | auto stream = at::cuda::getCurrentCUDAStream();
197 | AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
198 |     edz_eydz_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
199 |         z.data<scalar_t>(),
200 |         dz.data<scalar_t>(),
201 |         weight.data<scalar_t>(),
202 |         bias.data<scalar_t>(),
203 |         edz.data<scalar_t>(),
204 |         eydz.data<scalar_t>(),
205 |         affine, eps, num, chn, sp);
206 | }));
207 |
208 | return {edz, eydz};
209 | }
210 |
211 | /***********
212 | * backward
213 | ***********/
214 |
215 | template <typename T>
216 | __global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
217 | const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) {
218 | int plane = blockIdx.x;
219 |
220 | T _weight = affine ? abs(weight[plane]) + eps : 1.f;
221 | T _bias = affine ? bias[plane] : 0.f;
222 | T _var = var[plane];
223 | T _edz = edz[plane];
224 | T _eydz = eydz[plane];
225 |
226 | T _mul = _weight * rsqrt(_var + eps);
227 | T count = T(num * sp);
228 |
229 | for (int batch = 0; batch < num; ++batch) {
230 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
231 | T _dz = dz[(batch * chn + plane) * sp + n];
232 | T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
233 |
234 | dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
235 | }
236 | }
237 | }
238 |
239 | at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
240 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
241 | CHECK_CUDA_INPUT(z);
242 | CHECK_CUDA_INPUT(dz);
243 | CHECK_CUDA_INPUT(var);
244 | CHECK_CUDA_INPUT(weight);
245 | CHECK_CUDA_INPUT(bias);
246 | CHECK_CUDA_INPUT(edz);
247 | CHECK_CUDA_INPUT(eydz);
248 |
249 | // Extract dimensions
250 | int64_t num, chn, sp;
251 | get_dims(z, num, chn, sp);
252 |
253 | auto dx = at::zeros_like(z);
254 |
255 | // Run kernel
256 | dim3 blocks(chn);
257 | dim3 threads(getNumThreads(sp));
258 | auto stream = at::cuda::getCurrentCUDAStream();
259 | AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
260 |     backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
261 |         z.data<scalar_t>(),
262 |         dz.data<scalar_t>(),
263 |         var.data<scalar_t>(),
264 |         weight.data<scalar_t>(),
265 |         bias.data<scalar_t>(),
266 |         edz.data<scalar_t>(),
267 |         eydz.data<scalar_t>(),
268 |         dx.data<scalar_t>(),
269 |         affine, eps, num, chn, sp);
270 | }));
271 |
272 | return dx;
273 | }
274 |
275 | /**************
276 | * activations
277 | **************/
278 |
279 | template <typename T>
280 | inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
281 | // Create thrust pointers
282 |   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
283 |   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
284 |
285 | auto stream = at::cuda::getCurrentCUDAStream();
286 | thrust::transform_if(thrust::cuda::par.on(stream),
287 | th_dz, th_dz + count, th_z, th_dz,
288 | [slope] __device__ (const T& dz) { return dz * slope; },
289 | [] __device__ (const T& z) { return z < 0; });
290 | thrust::transform_if(thrust::cuda::par.on(stream),
291 | th_z, th_z + count, th_z,
292 | [slope] __device__ (const T& z) { return z / slope; },
293 | [] __device__ (const T& z) { return z < 0; });
294 | }
295 |
296 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
297 | CHECK_CUDA_INPUT(z);
298 | CHECK_CUDA_INPUT(dz);
299 |
300 | int64_t count = z.numel();
301 |
302 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
303 |     leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
304 | }));
305 | }
306 |
307 | template <typename T>
308 | inline void elu_backward_impl(T *z, T *dz, int64_t count) {
309 | // Create thrust pointers
310 |   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
311 |   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
312 |
313 | auto stream = at::cuda::getCurrentCUDAStream();
314 | thrust::transform_if(thrust::cuda::par.on(stream),
315 | th_dz, th_dz + count, th_z, th_z, th_dz,
316 | [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
317 | [] __device__ (const T& z) { return z < 0; });
318 | thrust::transform_if(thrust::cuda::par.on(stream),
319 | th_z, th_z + count, th_z,
320 | [] __device__ (const T& z) { return log1p(z); },
321 | [] __device__ (const T& z) { return z < 0; });
322 | }
323 |
324 | void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
325 | CHECK_CUDA_INPUT(z);
326 | CHECK_CUDA_INPUT(dz);
327 |
328 | int64_t count = z.numel();
329 |
330 |   AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cuda", ([&] {
331 |     elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
332 | }));
333 | }
334 |
--------------------------------------------------------------------------------
/modules/src/inplace_abn_cuda_half.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <cuda_fp16.h>
4 |
5 | #include <vector>
6 |
7 | #include "utils/checks.h"
8 | #include "utils/cuda.cuh"
9 | #include "inplace_abn.h"
10 |
11 | #include <ATen/cuda/CUDAContext.h>
12 |
13 | // Operations for reduce
14 | struct SumOpH {
15 | __device__ SumOpH(const half *t, int c, int s)
16 | : tensor(t), chn(c), sp(s) {}
17 | __device__ __forceinline__ float operator()(int batch, int plane, int n) {
18 | return __half2float(tensor[(batch * chn + plane) * sp + n]);
19 | }
20 | const half *tensor;
21 | const int chn;
22 | const int sp;
23 | };
24 |
25 | struct VarOpH {
26 | __device__ VarOpH(float m, const half *t, int c, int s)
27 | : mean(m), tensor(t), chn(c), sp(s) {}
28 | __device__ __forceinline__ float operator()(int batch, int plane, int n) {
29 | const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]);
30 | return (t - mean) * (t - mean);
31 | }
32 | const float mean;
33 | const half *tensor;
34 | const int chn;
35 | const int sp;
36 | };
37 |
38 | struct GradOpH {
39 | __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s)
40 | : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
41 |   __device__ __forceinline__ Pair<float> operator()(int batch, int plane, int n) {
42 | float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight;
43 | float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
44 |     return Pair<float>(_dz, _y * _dz);
45 | }
46 | const float weight;
47 | const float bias;
48 | const half *z;
49 | const half *dz;
50 | const int chn;
51 | const int sp;
52 | };
53 |
54 | /***********
55 | * mean_var
56 | ***********/
57 |
58 | __global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) {
59 | int plane = blockIdx.x;
60 |   float norm = 1.f / static_cast<float>(num * sp);
61 |
62 |   float _mean = reduce<float, SumOpH>(SumOpH(x, chn, sp), plane, num, sp) * norm;
63 |   __syncthreads();
64 |   float _var = reduce<float, VarOpH>(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm;
65 |
66 | if (threadIdx.x == 0) {
67 | mean[plane] = _mean;
68 | var[plane] = _var;
69 | }
70 | }
71 |
72 | std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x) {
73 | CHECK_CUDA_INPUT(x);
74 |
75 | // Extract dimensions
76 | int64_t num, chn, sp;
77 | get_dims(x, num, chn, sp);
78 |
79 | // Prepare output tensors
80 | auto mean = at::empty({chn},x.options().dtype(at::kFloat));
81 | auto var = at::empty({chn},x.options().dtype(at::kFloat));
82 |
83 | // Run kernel
84 | dim3 blocks(chn);
85 | dim3 threads(getNumThreads(sp));
86 | auto stream = at::cuda::getCurrentCUDAStream();
87 |   mean_var_kernel_h<<<blocks, threads, 0, stream>>>(
88 |     reinterpret_cast<half*>(x.data<at::Half>()),
89 |     mean.data<float>(),
90 |     var.data<float>(),
91 |     num, chn, sp);
92 |
93 | return {mean, var};
94 | }
95 |
96 | /**********
97 | * forward
98 | **********/
99 |
100 | __global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias,
101 | bool affine, float eps, int num, int chn, int sp) {
102 | int plane = blockIdx.x;
103 |
104 | const float _mean = mean[plane];
105 | const float _var = var[plane];
106 | const float _weight = affine ? abs(weight[plane]) + eps : 1.f;
107 | const float _bias = affine ? bias[plane] : 0.f;
108 |
109 | const float mul = rsqrt(_var + eps) * _weight;
110 |
111 | for (int batch = 0; batch < num; ++batch) {
112 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
113 | half *x_ptr = x + (batch * chn + plane) * sp + n;
114 | float _x = __half2float(*x_ptr);
115 | float _y = (_x - _mean) * mul + _bias;
116 |
117 | *x_ptr = __float2half(_y);
118 | }
119 | }
120 | }
121 |
122 | at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
123 | bool affine, float eps) {
124 | CHECK_CUDA_INPUT(x);
125 | CHECK_CUDA_INPUT(mean);
126 | CHECK_CUDA_INPUT(var);
127 | CHECK_CUDA_INPUT(weight);
128 | CHECK_CUDA_INPUT(bias);
129 |
130 | // Extract dimensions
131 | int64_t num, chn, sp;
132 | get_dims(x, num, chn, sp);
133 |
134 | // Run kernel
135 | dim3 blocks(chn);
136 | dim3 threads(getNumThreads(sp));
137 | auto stream = at::cuda::getCurrentCUDAStream();
138 |   forward_kernel_h<<<blocks, threads, 0, stream>>>(
139 |     reinterpret_cast<half*>(x.data<at::Half>()),
140 |     mean.data<float>(),
141 |     var.data<float>(),
142 |     weight.data<float>(),
143 |     bias.data<float>(),
144 |     affine, eps, num, chn, sp);
145 |
146 | return x;
147 | }
148 |
149 | __global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias,
150 | float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) {
151 | int plane = blockIdx.x;
152 |
153 | float _weight = affine ? abs(weight[plane]) + eps : 1.f;
154 | float _bias = affine ? bias[plane] : 0.f;
155 |
156 |   Pair<float> res = reduce<Pair<float>, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp);
157 | __syncthreads();
158 |
159 | if (threadIdx.x == 0) {
160 | edz[plane] = res.v1;
161 | eydz[plane] = res.v2;
162 | }
163 | }
164 |
165 | std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
166 |                                         bool affine, float eps) {
167 | CHECK_CUDA_INPUT(z);
168 | CHECK_CUDA_INPUT(dz);
169 | CHECK_CUDA_INPUT(weight);
170 | CHECK_CUDA_INPUT(bias);
171 |
172 | // Extract dimensions
173 | int64_t num, chn, sp;
174 | get_dims(z, num, chn, sp);
175 |
176 | auto edz = at::empty({chn},z.options().dtype(at::kFloat));
177 | auto eydz = at::empty({chn},z.options().dtype(at::kFloat));
178 |
179 | // Run kernel
180 | dim3 blocks(chn);
181 | dim3 threads(getNumThreads(sp));
182 | auto stream = at::cuda::getCurrentCUDAStream();
183 |   edz_eydz_kernel_h<<<blocks, threads, 0, stream>>>(
184 |     reinterpret_cast<half*>(z.data<at::Half>()),
185 |     reinterpret_cast<half*>(dz.data<at::Half>()),
186 |     weight.data<float>(),
187 |     bias.data<float>(),
188 |     edz.data<float>(),
189 |     eydz.data<float>(),
190 |     affine, eps, num, chn, sp);
191 |
192 | return {edz, eydz};
193 | }
194 |
195 | __global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz,
196 | const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) {
197 | int plane = blockIdx.x;
198 |
199 | float _weight = affine ? abs(weight[plane]) + eps : 1.f;
200 | float _bias = affine ? bias[plane] : 0.f;
201 | float _var = var[plane];
202 | float _edz = edz[plane];
203 | float _eydz = eydz[plane];
204 |
205 | float _mul = _weight * rsqrt(_var + eps);
206 | float count = float(num * sp);
207 |
208 | for (int batch = 0; batch < num; ++batch) {
209 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
210 | float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
211 | float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight;
212 |
213 | dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul);
214 | }
215 | }
216 | }
217 |
218 | at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
219 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
220 | CHECK_CUDA_INPUT(z);
221 | CHECK_CUDA_INPUT(dz);
222 | CHECK_CUDA_INPUT(var);
223 | CHECK_CUDA_INPUT(weight);
224 | CHECK_CUDA_INPUT(bias);
225 | CHECK_CUDA_INPUT(edz);
226 | CHECK_CUDA_INPUT(eydz);
227 |
228 | // Extract dimensions
229 | int64_t num, chn, sp;
230 | get_dims(z, num, chn, sp);
231 |
232 | auto dx = at::zeros_like(z);
233 |
234 | // Run kernel
235 | dim3 blocks(chn);
236 | dim3 threads(getNumThreads(sp));
237 | auto stream = at::cuda::getCurrentCUDAStream();
238 |   backward_kernel_h<<<blocks, threads, 0, stream>>>(
239 |     reinterpret_cast<half*>(z.data<at::Half>()),
240 |     reinterpret_cast<half*>(dz.data<at::Half>()),
241 |     var.data<float>(),
242 |     weight.data<float>(),
243 |     bias.data<float>(),
244 |     edz.data<float>(),
245 |     eydz.data<float>(),
246 |     reinterpret_cast<half*>(dx.data<at::Half>()),
247 |     affine, eps, num, chn, sp);
248 |
249 | return dx;
250 | }
251 |
252 | __global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) {
253 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){
254 | float _z = __half2float(z[i]);
255 | if (_z < 0) {
256 | dz[i] = __float2half(__half2float(dz[i]) * slope);
257 | z[i] = __float2half(_z / slope);
258 | }
259 | }
260 | }
261 |
262 | void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) {
263 | CHECK_CUDA_INPUT(z);
264 | CHECK_CUDA_INPUT(dz);
265 |
266 | int64_t count = z.numel();
267 | dim3 threads(getNumThreads(count));
268 | dim3 blocks = (count + threads.x - 1) / threads.x;
269 | auto stream = at::cuda::getCurrentCUDAStream();
270 |   leaky_relu_backward_impl_h<<<blocks, threads, 0, stream>>>(
271 |     reinterpret_cast<half*>(z.data<at::Half>()),
272 |     reinterpret_cast<half*>(dz.data<at::Half>()),
273 |     slope, count);
274 | }
275 |
276 |
--------------------------------------------------------------------------------
/modules/src/utils/checks.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
6 | #ifndef AT_CHECK
7 | #define AT_CHECK AT_ASSERT
8 | #endif
9 |
10 | #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
12 | #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
13 |
14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
15 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
--------------------------------------------------------------------------------
/modules/src/utils/common.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | /*
6 | * Functions to share code between CPU and GPU
7 | */
8 |
9 | #ifdef __CUDACC__
10 | // CUDA versions
11 |
12 | #define HOST_DEVICE __host__ __device__
13 | #define INLINE_HOST_DEVICE __host__ __device__ inline
14 | #define FLOOR(x) floor(x)
15 |
16 | #if __CUDA_ARCH__ >= 600
17 | // Recent compute capabilities have block-level atomicAdd for all data types, so we use that
18 | #define ACCUM(x,y) atomicAdd_block(&(x),(y))
19 | #else
20 | // Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float
21 | // and use the known atomicCAS-based implementation for double
22 | template<typename data_t>
23 | __device__ inline data_t atomic_add(data_t *address, data_t val) {
24 | return atomicAdd(address, val);
25 | }
26 |
27 | template<>
28 | __device__ inline double atomic_add(double *address, double val) {
29 | unsigned long long int* address_as_ull = (unsigned long long int*)address;
30 | unsigned long long int old = *address_as_ull, assumed;
31 | do {
32 | assumed = old;
33 | old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
34 | } while (assumed != old);
35 | return __longlong_as_double(old);
36 | }
37 |
38 | #define ACCUM(x,y) atomic_add(&(x),(y))
39 | #endif // #if __CUDA_ARCH__ >= 600
40 |
41 | #else
42 | // CPU versions
43 |
44 | #define HOST_DEVICE
45 | #define INLINE_HOST_DEVICE inline
46 | #define FLOOR(x) std::floor(x)
47 | #define ACCUM(x,y) (x) += (y)
48 |
49 | #endif // #ifdef __CUDACC__
--------------------------------------------------------------------------------
/modules/src/utils/cuda.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | /*
4 | * General settings and functions
5 | */
6 | const int WARP_SIZE = 32;
7 | const int MAX_BLOCK_SIZE = 1024;
8 |
9 | static int getNumThreads(int nElem) {
10 | int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE};
11 | for (int i = 0; i < 6; ++i) {
12 | if (nElem <= threadSizes[i]) {
13 | return threadSizes[i];
14 | }
15 | }
16 | return MAX_BLOCK_SIZE;
17 | }
18 |
19 | /*
20 | * Reduction utilities
21 | */
22 | template<typename T>
23 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
24 | unsigned int mask = 0xffffffff) {
25 | #if CUDART_VERSION >= 9000
26 | return __shfl_xor_sync(mask, value, laneMask, width);
27 | #else
28 | return __shfl_xor(value, laneMask, width);
29 | #endif
30 | }
31 |
32 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
33 |
34 | template<typename T>
35 | struct Pair {
36 | T v1, v2;
37 | __device__ Pair() {}
38 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
39 | __device__ Pair(T v) : v1(v), v2(v) {}
40 | __device__ Pair(int v) : v1(v), v2(v) {}
41 | __device__ Pair &operator+=(const Pair &a) {
42 | v1 += a.v1;
43 | v2 += a.v2;
44 | return *this;
45 | }
46 | };
47 |
48 | template<typename T>
49 | static __device__ __forceinline__ T warpSum(T val) {
50 | #if __CUDA_ARCH__ >= 300
51 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
52 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
53 | }
54 | #else
55 | __shared__ T values[MAX_BLOCK_SIZE];
56 | values[threadIdx.x] = val;
57 | __threadfence_block();
58 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
59 | for (int i = 1; i < WARP_SIZE; i++) {
60 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
61 | }
62 | #endif
63 | return val;
64 | }
65 |
66 | template<typename T>
67 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
68 | value.v1 = warpSum(value.v1);
69 | value.v2 = warpSum(value.v2);
70 | return value;
71 | }
--------------------------------------------------------------------------------
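warpSum is an XOR-butterfly reduction: in each of the getMSB(WARP_SIZE) = 5 steps, every lane adds the value held by the lane whose index differs in exactly one bit, so after five exchanges all 32 lanes hold the full warp sum. A pure-Python model of that access pattern (illustration only):

WARP_SIZE = 32
vals = list(range(WARP_SIZE))                 # one value per lane

for i in range(5):                            # getMSB(32) == 5 exchange steps
    step = 1 << i
    vals = [v + vals[lane ^ step] for lane, v in enumerate(vals)]

assert all(v == sum(range(WARP_SIZE)) for v in vals)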
/optimizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 |
5 | import torch
6 | import logging
7 |
8 | logger = logging.getLogger()
9 |
10 | class Optimizer(object):
11 | def __init__(self,
12 | model,
13 | lr0,
14 | momentum,
15 | wd,
16 | warmup_steps,
17 | warmup_start_lr,
18 | max_iter,
19 | power,
20 | *args, **kwargs):
21 | self.warmup_steps = warmup_steps
22 | self.warmup_start_lr = warmup_start_lr
23 | self.lr0 = lr0
24 | self.lr = self.lr0
25 | self.max_iter = float(max_iter)
26 | self.power = power
27 | self.it = 0
28 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params()
29 | param_list = [
30 | {'params': wd_params},
31 | {'params': nowd_params, 'weight_decay': 0},
32 | {'params': lr_mul_wd_params, 'lr_mul': True},
33 | {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True}]
34 | self.optim = torch.optim.SGD(
35 | param_list,
36 | lr = lr0,
37 | momentum = momentum,
38 | weight_decay = wd)
39 | self.warmup_factor = (self.lr0/self.warmup_start_lr)**(1./self.warmup_steps)
40 |
41 |
42 | def get_lr(self):
43 | if self.it <= self.warmup_steps:
44 | lr = self.warmup_start_lr*(self.warmup_factor**self.it)
45 | else:
46 | factor = (1-(self.it-self.warmup_steps)/(self.max_iter-self.warmup_steps))**self.power
47 | lr = self.lr0 * factor
48 | return lr
49 |
50 |
51 | def step(self):
52 | self.lr = self.get_lr()
53 | for pg in self.optim.param_groups:
54 | if pg.get('lr_mul', False):
55 | pg['lr'] = self.lr * 10
56 | else:
57 | pg['lr'] = self.lr
58 | if self.optim.defaults.get('lr_mul', False):
59 | self.optim.defaults['lr'] = self.lr * 10
60 | else:
61 | self.optim.defaults['lr'] = self.lr
62 | self.it += 1
63 | self.optim.step()
64 | if self.it == self.warmup_steps+2:
65 | logger.info('==> warmup done, start to implement poly lr strategy')
66 |
67 | def zero_grad(self):
68 | self.optim.zero_grad()
69 |
70 |
--------------------------------------------------------------------------------
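The schedule implemented by get_lr() is an exponential warmup from warmup_start_lr up to lr0 over warmup_steps iterations, followed by polynomial ("poly") decay down to zero at max_iter. A standalone restatement using the values train.py passes in (illustration only):

def poly_lr(it, lr0=1e-2, warmup_steps=1000, warmup_start_lr=1e-5,
            max_iter=80000, power=0.9):
    if it <= warmup_steps:
        factor = (lr0 / warmup_start_lr) ** (1.0 / warmup_steps)
        return warmup_start_lr * factor ** it
    return lr0 * (1 - (it - warmup_steps) / (max_iter - warmup_steps)) ** power

for it in (0, 500, 1000, 40000, 79999):
    print(it, poly_lr(it))    # 1e-5 at it=0, 1e-2 at it=1000, then decaying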
/prepropess_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import os.path as osp
5 | import os
6 | import cv2
7 | from .transform import *
8 | from PIL import Image
9 |
10 | face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
11 | face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
12 | mask_path = '/home/zll/data/CelebAMask-HQ/mask'
13 | counter = 0
14 | total = 0
15 | for i in range(15):
16 |
17 | atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
18 | 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
19 |
20 | for j in range(i * 2000, (i + 1) * 2000):
21 |
22 | mask = np.zeros((512, 512))
23 |
24 | for l, att in enumerate(atts, 1):
25 | total += 1
26 | file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
27 | path = osp.join(face_sep_mask, str(i), file_name)
28 |
29 | if os.path.exists(path):
30 | counter += 1
31 | sep_mask = np.array(Image.open(path).convert('P'))
32 | # print(np.unique(sep_mask))
33 |
34 | mask[sep_mask == 225] = l
35 | cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
36 | print(j)
37 |
38 | print(counter, total)
39 |
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.model_zoo as modelzoo
8 |
9 | # from modules.bn import InPlaceABNSync as BatchNorm2d
10 |
11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | def __init__(self, in_chan, out_chan, stride=1):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(in_chan, out_chan, stride)
24 | self.bn1 = nn.BatchNorm2d(out_chan)
25 | self.conv2 = conv3x3(out_chan, out_chan)
26 | self.bn2 = nn.BatchNorm2d(out_chan)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = None
29 | if in_chan != out_chan or stride != 1:
30 | self.downsample = nn.Sequential(
31 | nn.Conv2d(in_chan, out_chan,
32 | kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(out_chan),
34 | )
35 |
36 | def forward(self, x):
37 | residual = self.conv1(x)
38 | residual = F.relu(self.bn1(residual))
39 | residual = self.conv2(residual)
40 | residual = self.bn2(residual)
41 |
42 | shortcut = x
43 | if self.downsample is not None:
44 | shortcut = self.downsample(x)
45 |
46 | out = shortcut + residual
47 | out = self.relu(out)
48 | return out
49 |
50 |
51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53 | for i in range(bnum-1):
54 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
55 | return nn.Sequential(*layers)
56 |
57 |
58 | class Resnet18(nn.Module):
59 | def __init__(self):
60 | super(Resnet18, self).__init__()
61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = nn.BatchNorm2d(64)
64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69 | #self.init_weight()
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = F.relu(self.bn1(x))
74 | x = self.maxpool(x)
75 |
76 | x = self.layer1(x)
77 | feat8 = self.layer2(x) # 1/8
78 | feat16 = self.layer3(feat8) # 1/16
79 | feat32 = self.layer4(feat16) # 1/32
80 | return feat8, feat16, feat32
81 |
82 | def init_weight(self):
83 | state_dict = modelzoo.load_url(resnet18_url)
84 | self_state_dict = self.state_dict()
85 | for k, v in state_dict.items():
86 | if 'fc' in k: continue
87 | self_state_dict.update({k: v})
88 | self.load_state_dict(self_state_dict)
89 |
90 | def get_params(self):
91 | wd_params, nowd_params = [], []
92 | for name, module in self.named_modules():
93 | if isinstance(module, (nn.Linear, nn.Conv2d)):
94 | wd_params.append(module.weight)
95 |             if module.bias is not None:
96 | nowd_params.append(module.bias)
97 | elif isinstance(module, nn.BatchNorm2d):
98 | nowd_params += list(module.parameters())
99 | return wd_params, nowd_params
100 |
101 |
102 | if __name__ == "__main__":
103 | net = Resnet18()
104 | x = torch.randn(16, 3, 224, 224)
105 | out = net(x)
106 | print(out[0].size())
107 | print(out[1].size())
108 | print(out[2].size())
109 | net.get_params()
110 |
--------------------------------------------------------------------------------
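For the 16 x 3 x 224 x 224 input in the __main__ check, the three feature maps come out at strides 8, 16 and 32: conv1 and the max-pool each halve the resolution (224 -> 56), and layer2 through layer4 halve it once more apiece, so the expected printout is:

# torch.Size([16, 128, 28, 28])   # feat8,  1/8 resolution
# torch.Size([16, 256, 14, 14])   # feat16, 1/16 resolution
# torch.Size([16, 512, 7, 7])     # feat32, 1/32 resolution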
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | from .logger import setup_logger
5 | from .model import BiSeNet
6 |
7 | import torch
8 |
9 | import os
10 | import os.path as osp
11 | import numpy as np
12 | from PIL import Image
13 | import torchvision.transforms as transforms
14 | import cv2
15 |
16 | def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
17 |     # Colors for the parsing classes (24 entries; only the 19 classes used here are indexed)
18 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
19 | [255, 0, 85], [255, 0, 170],
20 | [0, 255, 0], [85, 255, 0], [170, 255, 0],
21 | [0, 255, 85], [0, 255, 170],
22 | [0, 0, 255], [85, 0, 255], [170, 0, 255],
23 | [0, 85, 255], [0, 170, 255],
24 | [255, 255, 0], [255, 255, 85], [255, 255, 170],
25 | [255, 0, 255], [255, 85, 255], [255, 170, 255],
26 | [0, 255, 255], [85, 255, 255], [170, 255, 255]]
27 |
28 | im = np.array(im)
29 | vis_im = im.copy().astype(np.uint8)
30 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
31 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
32 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
33 |
34 | num_of_class = np.max(vis_parsing_anno)
35 |
36 | for pi in range(1, num_of_class + 1):
37 | index = np.where(vis_parsing_anno == pi)
38 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
39 |
40 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
41 | # print(vis_parsing_anno_color.shape, vis_im.shape)
42 | vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
43 |
44 | # Save result or not
45 | if save_im:
46 | cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno)
47 | cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
48 |
49 | # return vis_im
50 |
51 | def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
52 |
53 | if not os.path.exists(respth):
54 | os.makedirs(respth)
55 |
56 | n_classes = 19
57 | net = BiSeNet(n_classes=n_classes)
58 | net.cuda()
59 | save_pth = osp.join('res/cp', cp)
60 | net.load_state_dict(torch.load(save_pth))
61 | net.eval()
62 |
63 | to_tensor = transforms.Compose([
64 | transforms.ToTensor(),
65 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
66 | ])
67 | with torch.no_grad():
68 | for image_path in os.listdir(dspth):
69 | img = Image.open(osp.join(dspth, image_path))
70 | image = img.resize((512, 512), Image.BILINEAR)
71 | img = to_tensor(image)
72 | img = torch.unsqueeze(img, 0)
73 | img = img.cuda()
74 | out = net(img)[0]
75 | parsing = out.squeeze(0).cpu().numpy().argmax(0)
76 | # print(parsing)
77 | print(np.unique(parsing))
78 |
79 | vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 | if __name__ == "__main__":
88 | evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='79999_iter.pth')
89 |
90 |
91 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | from .logger import setup_logger
5 | from .model import BiSeNet
6 | from .face_dataset import FaceMask
7 | from .loss import OhemCELoss
8 | from .evaluate import evaluate
9 | from .optimizer import Optimizer
10 | import cv2
11 | import numpy as np
12 |
13 | import torch
14 | import torch.nn as nn
15 | from torch.utils.data import DataLoader
16 | import torch.nn.functional as F
17 | import torch.distributed as dist
18 |
19 | import os
20 | import os.path as osp
21 | import logging
22 | import time
23 | import datetime
24 | import argparse
25 |
26 |
27 | respth = './res'
28 | if not osp.exists(respth):
29 | os.makedirs(respth)
30 | logger = logging.getLogger()
31 |
32 |
33 | def parse_args():
34 | parse = argparse.ArgumentParser()
35 | parse.add_argument(
36 | '--local_rank',
37 | dest = 'local_rank',
38 | type = int,
39 | default = -1,
40 | )
41 | return parse.parse_args()
42 |
43 |
44 | def train():
45 | args = parse_args()
46 | torch.cuda.set_device(args.local_rank)
47 | dist.init_process_group(
48 | backend = 'nccl',
49 | init_method = 'tcp://127.0.0.1:33241',
50 | world_size = torch.cuda.device_count(),
51 | rank=args.local_rank
52 | )
53 | setup_logger(respth)
54 |
55 | # dataset
56 | n_classes = 19
57 | n_img_per_gpu = 16
58 | n_workers = 8
59 | cropsize = [448, 448]
60 | data_root = '/home/zll/data/CelebAMask-HQ/'
61 |
62 | ds = FaceMask(data_root, cropsize=cropsize, mode='train')
63 | sampler = torch.utils.data.distributed.DistributedSampler(ds)
64 | dl = DataLoader(ds,
65 | batch_size = n_img_per_gpu,
66 | shuffle = False,
67 | sampler = sampler,
68 | num_workers = n_workers,
69 | pin_memory = True,
70 | drop_last = True)
71 |
72 | # model
73 | ignore_idx = -100
74 | net = BiSeNet(n_classes=n_classes)
75 | net.cuda()
76 | net.train()
77 | net = nn.parallel.DistributedDataParallel(net,
78 | device_ids = [args.local_rank, ],
79 | output_device = args.local_rank
80 | )
81 | score_thres = 0.7
82 | n_min = n_img_per_gpu * cropsize[0] * cropsize[1]//16
83 | LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
84 | Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
85 | Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
86 |
87 | ## optimizer
88 | momentum = 0.9
89 | weight_decay = 5e-4
90 | lr_start = 1e-2
91 | max_iter = 80000
92 | power = 0.9
93 | warmup_steps = 1000
94 | warmup_start_lr = 1e-5
95 | optim = Optimizer(
96 | model = net.module,
97 | lr0 = lr_start,
98 | momentum = momentum,
99 | wd = weight_decay,
100 | warmup_steps = warmup_steps,
101 | warmup_start_lr = warmup_start_lr,
102 | max_iter = max_iter,
103 | power = power)
104 |
105 | ## train loop
106 | msg_iter = 50
107 | loss_avg = []
108 | st = glob_st = time.time()
109 | diter = iter(dl)
110 | epoch = 0
111 | for it in range(max_iter):
112 | try:
113 | im, lb = next(diter)
114 | if not im.size()[0] == n_img_per_gpu:
115 | raise StopIteration
116 | except StopIteration:
117 | epoch += 1
118 | sampler.set_epoch(epoch)
119 | diter = iter(dl)
120 | im, lb = next(diter)
121 | im = im.cuda()
122 | lb = lb.cuda()
123 | H, W = im.size()[2:]
124 | lb = torch.squeeze(lb, 1)
125 |
126 | optim.zero_grad()
127 | out, out16, out32 = net(im)
128 | lossp = LossP(out, lb)
129 | loss2 = Loss2(out16, lb)
130 | loss3 = Loss3(out32, lb)
131 | loss = lossp + loss2 + loss3
132 | loss.backward()
133 | optim.step()
134 |
135 | loss_avg.append(loss.item())
136 |
137 | # print training log message
138 | if (it+1) % msg_iter == 0:
139 | loss_avg = sum(loss_avg) / len(loss_avg)
140 | lr = optim.lr
141 | ed = time.time()
142 | t_intv, glob_t_intv = ed - st, ed - glob_st
143 | eta = int((max_iter - it) * (glob_t_intv / it))
144 | eta = str(datetime.timedelta(seconds=eta))
145 | msg = ', '.join([
146 | 'it: {it}/{max_it}',
147 |             'lr: {lr:.4f}',
148 | 'loss: {loss:.4f}',
149 | 'eta: {eta}',
150 | 'time: {time:.4f}',
151 | ]).format(
152 | it = it+1,
153 | max_it = max_iter,
154 | lr = lr,
155 | loss = loss_avg,
156 | time = t_intv,
157 | eta = eta
158 | )
159 | logger.info(msg)
160 | loss_avg = []
161 | st = ed
162 |         if dist.get_rank() == 0:
163 |             if (it+1) % 5000 == 0:
164 |                 state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
165 |                 torch.save(state, './res/cp/{}_iter.pth'.format(it))
166 |                 evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it))
167 |
168 |
169 | # dump the final model
170 | save_pth = osp.join(respth, 'model_final_diss.pth')
171 | # net.cpu()
172 | state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
173 | if dist.get_rank() == 0:
174 | torch.save(state, save_pth)
175 | logger.info('training done, model saved to: {}'.format(save_pth))
176 |
177 |
178 | if __name__ == "__main__":
179 | train()
180 |
--------------------------------------------------------------------------------
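Since train() reads --local_rank and calls dist.init_process_group with world_size = torch.cuda.device_count(), the script expects to be started once per GPU by PyTorch's (pre-torchrun) launcher, which injects --local_rank into each process. An illustrative invocation on a 2-GPU machine (adjust the module path if the relative imports require running it as part of a package):

python -m torch.distributed.launch --nproc_per_node=2 train.py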
/transform.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 |
5 | from PIL import Image
6 | import PIL.ImageEnhance as ImageEnhance
7 | import random
8 | import numpy as np
9 |
10 | class RandomCrop(object):
11 | def __init__(self, size, *args, **kwargs):
12 | self.size = size
13 |
14 | def __call__(self, im_lb):
15 | im = im_lb['im']
16 | lb = im_lb['lb']
17 | assert im.size == lb.size
18 | W, H = self.size
19 | w, h = im.size
20 |
21 | if (W, H) == (w, h): return dict(im=im, lb=lb)
22 | if w < W or h < H:
23 | scale = float(W) / w if w < h else float(H) / h
24 | w, h = int(scale * w + 1), int(scale * h + 1)
25 | im = im.resize((w, h), Image.BILINEAR)
26 | lb = lb.resize((w, h), Image.NEAREST)
27 | sw, sh = random.random() * (w - W), random.random() * (h - H)
28 | crop = int(sw), int(sh), int(sw) + W, int(sh) + H
29 | return dict(
30 | im = im.crop(crop),
31 | lb = lb.crop(crop)
32 | )
33 |
34 |
35 | class HorizontalFlip(object):
36 | def __init__(self, p=0.5, *args, **kwargs):
37 | self.p = p
38 |
39 | def __call__(self, im_lb):
40 | if random.random() > self.p:
41 | return im_lb
42 | else:
43 | im = im_lb['im']
44 | lb = im_lb['lb']
45 |
46 | # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r',
47 | # 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
48 |
49 | flip_lb = np.array(lb)
50 | flip_lb[lb == 2] = 3
51 | flip_lb[lb == 3] = 2
52 | flip_lb[lb == 4] = 5
53 | flip_lb[lb == 5] = 4
54 | flip_lb[lb == 7] = 8
55 | flip_lb[lb == 8] = 7
56 | flip_lb = Image.fromarray(flip_lb)
57 | return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT),
58 | lb = flip_lb.transpose(Image.FLIP_LEFT_RIGHT),
59 | )
60 |
61 |
62 | class RandomScale(object):
63 | def __init__(self, scales=(1, ), *args, **kwargs):
64 | self.scales = scales
65 |
66 | def __call__(self, im_lb):
67 | im = im_lb['im']
68 | lb = im_lb['lb']
69 | W, H = im.size
70 | scale = random.choice(self.scales)
71 | w, h = int(W * scale), int(H * scale)
72 | return dict(im = im.resize((w, h), Image.BILINEAR),
73 | lb = lb.resize((w, h), Image.NEAREST),
74 | )
75 |
76 |
77 | class ColorJitter(object):
78 |     def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs):
79 |         def to_range(v):
80 |             # None (or a non-positive value) disables that jitter
81 |             return [max(1 - v, 0), 1 + v] if v is not None and v > 0 else None
82 |         self.brightness = to_range(brightness)
83 |         self.contrast = to_range(contrast)
84 |         self.saturation = to_range(saturation)
85 |
86 | def __call__(self, im_lb):
87 | im = im_lb['im']
88 | lb = im_lb['lb']
89 |         if self.brightness is not None:
90 |             im = ImageEnhance.Brightness(im).enhance(random.uniform(*self.brightness))
91 |         if self.contrast is not None:
92 |             im = ImageEnhance.Contrast(im).enhance(random.uniform(*self.contrast))
93 |         if self.saturation is not None:
94 |             im = ImageEnhance.Color(im).enhance(random.uniform(*self.saturation))
95 | return dict(im = im,
96 | lb = lb,
97 | )
98 |
99 |
100 | class MultiScale(object):
101 | def __init__(self, scales):
102 | self.scales = scales
103 |
104 | def __call__(self, img):
105 | W, H = img.size
106 | sizes = [(int(W*ratio), int(H*ratio)) for ratio in self.scales]
107 |         imgs = [img.resize(size, Image.BILINEAR)
108 |                 for size in sizes]
109 |         return imgs
110 |
111 |
112 | class Compose(object):
113 | def __init__(self, do_list):
114 | self.do_list = do_list
115 |
116 | def __call__(self, im_lb):
117 | for comp in self.do_list:
118 | im_lb = comp(im_lb)
119 | return im_lb
120 |
121 |
122 |
123 |
124 | if __name__ == '__main__':
125 | flip = HorizontalFlip(p = 1)
126 | crop = RandomCrop((321, 321))
127 | rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0))
128 | img = Image.open('data/img.jpg')
129 | lb = Image.open('data/label.png')
130 |
--------------------------------------------------------------------------------
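All of these transforms operate on the dict {'im': image, 'lb': label} so that the image and its mask receive identical geometry. A minimal composition sketch (illustrative; face_dataset.py is expected to wire them up in a similar way, and the parameter values here are made up):

from PIL import Image

trans = Compose([
    ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    HorizontalFlip(),
    RandomScale((0.75, 1.0, 1.25)),
    RandomCrop((448, 448)),
])
im_lb = dict(im=Image.open('data/img.jpg'), lb=Image.open('data/label.png'))
im_lb = trans(im_lb)    # im_lb['im'] and im_lb['lb'] stay spatially aligned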