├── .gitignore
├── 6.jpg
├── FaceParsing.py
├── LICENSE
├── README.md
├── __init__.py
├── evaluate.py
├── face_dataset.py
├── hair.png
├── logger.py
├── loss.py
├── makeup.py
├── makeup
│   ├── 116_1.png
│   ├── 116_3.png
│   ├── 116_lip_ori.png
│   └── 116_ori.png
├── model.py
├── modules
│   ├── __init__.py
│   ├── bn.py
│   ├── deeplab.py
│   ├── dense.py
│   ├── functions.py
│   ├── misc.py
│   ├── residual.py
│   └── src
│       ├── checks.h
│       ├── inplace_abn.cpp
│       ├── inplace_abn.h
│       ├── inplace_abn_cpu.cpp
│       ├── inplace_abn_cuda.cu
│       ├── inplace_abn_cuda_half.cu
│       └── utils
│           ├── checks.h
│           ├── common.h
│           └── cuda.cuh
├── optimizer.py
├── prepropess_data.py
├── resnet.py
├── test.py
├── train.py
└── transform.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | # res
107 | res/
108 | 
109 | .idea/
110 | 
111 | 
--------------------------------------------------------------------------------
/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/6.jpg
--------------------------------------------------------------------------------
/FaceParsing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | import collections
6 | import torchvision.transforms as transforms
7 | 
8 | from face_parsing.model import BiSeNet
9 | 
10 | class FaceParsing(object):
11 |     def __init__(self, model_path=None):
12 |         if model_path is None:
13 |             # REVIEW this is a terrible default
14 |             model_path = '../../../external/data/models/face_parsing/face_parsing_79999_iter.pth'
15 | 
16 |         self.net = BiSeNet(n_classes=19)
17 |         self.net.load_state_dict(torch.load(model_path, map_location='cpu'))
18 |         self.net.eval()
19 | 
20 |         self.transform = transforms.Compose([
21 |             transforms.ToTensor(),
22 |             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
23 |         ])
24 | 
25 |         self.label_by_idx = collections.OrderedDict(
26 |             [(-1, 'unlabeled'), (0, 'background'), (1, 'skin'),
27 |              (2, 'l_brow'), (3, 'r_brow'), (4, 'l_eye'), (5, 'r_eye'),
28 |              (6, 'eye_g (eye glasses)'), (7, 'l_ear'), (8, 'r_ear'), (9, 'ear_r (ear ring)'),
29 |              (10, 'nose'), (11, 'mouth'), (12, 'u_lip'), (13, 'l_lip'),
30 |              (14, 'neck'), (15, 'neck_l (necklace)'), (16, 'cloth'),
31 |              (17, 'hair'), (18, 'hat')])
32 |         self.idx_by_label = {v: k for k, v in self.label_by_idx.items()}
33 | 
34 |     def to_tensor(self, images):
35 |         # images : a single H,W,C PIL Image (or numpy.array)
36 |         return self.transform(images)
37 | 
38 |     def label_for_idx(self, idx):
39 |         return self.label_by_idx[idx]
40 | 
41 |     def idx_for_label(self, label):
42 |         return self.idx_by_label[label]
43 | 
44 |     # TODO rename to label_image to contrast with label, which takes a tensor
45 |     def parse_face(self, images, device=0):
46 |         # images : list of square PIL Images
47 |         # device : which CUDA device to run on
48 |         #
49 |         # returns parsings : list of PIL Images
50 |         #
51 |         # 0 - background, 1 - skin, 2 - l_brow, 3 - r_brow, 4 - l_eye, 5 - r_eye, 6 - eye_g (eye glasses),
52 |         # 7 - l_ear, 8 - r_ear, 9 - ear_r (ear ring), 10 - nose, 11 - mouth, 12 - u_lip, 13 - l_lip,
53 |         # 14 - neck, 15 - neck_l (necklace), 16 - cloth, 17 - hair, 18 - hat
54 | 
55 |         # move the network to the correct device
56 |         # REVIEW this makes it impossible to run on the cpu...
57 |         self.net.to('cuda:{}'.format(device))
58 | 
59 |         assert all(im.size[0] == im.size[1] for im in images)
60 |         in_sizes = [im.size[0] for im in images]  # im is square
61 | 
62 |         pt_images = []
63 |         for img in images:
64 |             # seems to work best with images around 512
65 |             img = img.resize((512, 512), Image.BILINEAR)
66 |             img = self.to_tensor(img)
67 |             pt_images.append(img)
68 |         pt_images = torch.stack(pt_images, dim=0)
69 | 
70 |         # move the data to the device
71 |         pt_images = pt_images.to('cuda:{}'.format(device))
72 | 
73 |         out = self.label(pt_images)
74 |         parsings = out.cpu().numpy().argmax(axis=1).astype(np.uint8)
75 | 
76 |         parsings = [Image.fromarray(parsing).resize((in_size, in_size), Image.NEAREST)
77 |                     for parsing, in_size in zip(parsings, in_sizes)]
78 | 
79 |         return parsings  # list of PIL Images
80 | 
81 |     def label(self, pt_images):
82 |         # pt_images : N,C,H,W torch.tensor (as produced by to_tensor and stacking)
83 |         with torch.no_grad():
84 |             return self.net(pt_images)[0]
85 | 
86 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 zll
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # face-parsing.PyTorch 2 | 3 |

4 | *(demo images)*
5 | 
6 | 
7 | 

8 | 
9 | ### Contents
10 | - [Training](#training)
11 | - [Demo](#demo)
12 | - [References](#references)
13 | 
14 | ## Training
15 | 
16 | 1. Prepare the training data:
17 | -- download the [CelebAMask-HQ dataset](https://github.com/switchablenorms/CelebAMask-HQ)
18 | 
19 | -- change the file paths in `prepropess_data.py`, then run
20 | ```Shell
21 | python prepropess_data.py
22 | ```
23 | 
24 | 2. Train the model on the CelebAMask-HQ dataset:
25 | Just run the training script:
26 | ```
27 | $ CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
28 | ```
29 | 
30 | If you do not wish to train the model yourself, you can download [our pre-trained model](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812) and save it in `res/cp`.
31 | 
32 | 
33 | ## Demo
34 | 1. Evaluate the trained model using:
35 | ```Shell
36 | # evaluate using GPU
37 | python test.py
38 | ```
39 | 
40 | ## Face makeup using parsing maps
41 | [**face-makeup.PyTorch**](https://github.com/zllrunning/face-makeup.PyTorch)
|   | Hair | Lip |
| :---: | :---: | :---: |
| Original Input | ![Original Input](makeup/116_ori.png) | ![Original Input](makeup/116_lip_ori.png) |
| Color | ![Color](makeup/116_1.png) | ![Color](makeup/116_3.png) |
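
For programmatic use, the sketch below is one way to run the parser from Python. It is not part of the original README: it assumes the pre-trained weights above were saved as `res/cp/79999_iter.pth`, and that the repository is importable as a package named `face_parsing` (it ships an `__init__.py`, and `model.py` uses relative imports).

```python
# Hedged sketch: the weight path and package name are assumptions, see note above.
import torch
from PIL import Image
import torchvision.transforms as transforms

from face_parsing.model import BiSeNet

net = BiSeNet(n_classes=19)
net.load_state_dict(torch.load('res/cp/79999_iter.pth', map_location='cpu'))
net.eval()

# same preprocessing as evaluate.py / FaceParsing.py
to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

img = Image.open('6.jpg').convert('RGB').resize((512, 512), Image.BILINEAR)
with torch.no_grad():
    out = net(to_tensor(img).unsqueeze(0))[0]   # main output head
parsing = out.squeeze(0).argmax(0).numpy()      # 512 x 512 array of labels 0-18
```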
65 | 66 | 67 | ## References 68 | - [BiSeNet](https://github.com/CoinCheung/BiSeNet) -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/__init__.py -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from .logger import setup_logger 5 | from .model import BiSeNet 6 | from .face_dataset import FaceMask 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | import torch.nn.functional as F 12 | import torch.distributed as dist 13 | 14 | import os 15 | import os.path as osp 16 | import logging 17 | import time 18 | import numpy as np 19 | from tqdm import tqdm 20 | import math 21 | from PIL import Image 22 | import torchvision.transforms as transforms 23 | import cv2 24 | 25 | def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'): 26 | # Colors for all 20 parts 27 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], 28 | [255, 0, 85], [255, 0, 170], 29 | [0, 255, 0], [85, 255, 0], [170, 255, 0], 30 | [0, 255, 85], [0, 255, 170], 31 | [0, 0, 255], [85, 0, 255], [170, 0, 255], 32 | [0, 85, 255], [0, 170, 255], 33 | [255, 255, 0], [255, 255, 85], [255, 255, 170], 34 | [255, 0, 255], [255, 85, 255], [255, 170, 255], 35 | [0, 255, 255], [85, 255, 255], [170, 255, 255]] 36 | 37 | im = np.array(im) 38 | vis_im = im.copy().astype(np.uint8) 39 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8) 40 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) 41 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 42 | 43 | num_of_class = np.max(vis_parsing_anno) 44 | 45 | for pi in range(1, num_of_class + 1): 46 | index = np.where(vis_parsing_anno == pi) 47 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] 48 | 49 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) 50 | # print(vis_parsing_anno_color.shape, vis_im.shape) 51 | vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0) 52 | 53 | # Save result or not 54 | if save_im: 55 | cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) 56 | 57 | # return vis_im 58 | 59 | def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): 60 | 61 | if not os.path.exists(respth): 62 | os.makedirs(respth) 63 | 64 | n_classes = 19 65 | net = BiSeNet(n_classes=n_classes) 66 | net.cuda() 67 | save_pth = osp.join('res/cp', cp) 68 | net.load_state_dict(torch.load(save_pth)) 69 | net.eval() 70 | 71 | to_tensor = transforms.Compose([ 72 | transforms.ToTensor(), 73 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 74 | ]) 75 | with torch.no_grad(): 76 | for image_path in os.listdir(dspth): 77 | img = Image.open(osp.join(dspth, image_path)) 78 | image = img.resize((512, 512), Image.BILINEAR) 79 | img = to_tensor(image) 80 | img = torch.unsqueeze(img, 0) 81 | img = img.cuda() 82 | out = net(img)[0] 83 | parsing = out.squeeze(0).cpu().numpy().argmax(0) 84 | 85 | vis_parsing_maps(image, parsing, stride=1, save_im=True, 
save_path=osp.join(respth, image_path)) 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | if __name__ == "__main__": 94 | setup_logger('./res') 95 | evaluate() 96 | -------------------------------------------------------------------------------- /face_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | import torchvision.transforms as transforms 7 | 8 | import os.path as osp 9 | import os 10 | from PIL import Image 11 | import numpy as np 12 | import json 13 | import cv2 14 | 15 | from .transform import * 16 | 17 | 18 | 19 | class FaceMask(Dataset): 20 | def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs): 21 | super(FaceMask, self).__init__(*args, **kwargs) 22 | assert mode in ('train', 'val', 'test') 23 | self.mode = mode 24 | self.ignore_lb = 255 25 | self.rootpth = rootpth 26 | 27 | self.imgs = os.listdir(os.path.join(self.rootpth, 'CelebA-HQ-img')) 28 | 29 | # pre-processing 30 | self.to_tensor = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 33 | ]) 34 | self.trans_train = Compose([ 35 | ColorJitter( 36 | brightness=0.5, 37 | contrast=0.5, 38 | saturation=0.5), 39 | HorizontalFlip(), 40 | RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)), 41 | RandomCrop(cropsize) 42 | ]) 43 | 44 | def __getitem__(self, idx): 45 | impth = self.imgs[idx] 46 | img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth)) 47 | img = img.resize((512, 512), Image.BILINEAR) 48 | label = Image.open(osp.join(self.rootpth, 'mask', impth[:-3]+'png')).convert('P') 49 | # print(np.unique(np.array(label))) 50 | if self.mode == 'train': 51 | im_lb = dict(im=img, lb=label) 52 | im_lb = self.trans_train(im_lb) 53 | img, label = im_lb['im'], im_lb['lb'] 54 | img = self.to_tensor(img) 55 | label = np.array(label).astype(np.int64)[np.newaxis, :] 56 | return img, label 57 | 58 | def __len__(self): 59 | return len(self.imgs) 60 | 61 | 62 | if __name__ == "__main__": 63 | face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img' 64 | face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno' 65 | mask_path = '/home/zll/data/CelebAMask-HQ/mask' 66 | counter = 0 67 | total = 0 68 | for i in range(15): 69 | # files = os.listdir(osp.join(face_sep_mask, str(i))) 70 | 71 | atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', 72 | 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] 73 | 74 | for j in range(i*2000, (i+1)*2000): 75 | 76 | mask = np.zeros((512, 512)) 77 | 78 | for l, att in enumerate(atts, 1): 79 | total += 1 80 | file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png']) 81 | path = osp.join(face_sep_mask, str(i), file_name) 82 | 83 | if os.path.exists(path): 84 | counter += 1 85 | sep_mask = np.array(Image.open(path).convert('P')) 86 | # print(np.unique(sep_mask)) 87 | 88 | mask[sep_mask == 225] = l 89 | cv2.imwrite('{}/{}.png'.format(mask_path, j), mask) 90 | print(j) 91 | 92 | print(counter, total) 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /hair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/hair.png 
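
A usage sketch (not from the repo): the `FaceMask` dataset above plugs directly into a standard `DataLoader`. This assumes the repo is importable as the `face_parsing` package and that `rootpth` contains `CelebA-HQ-img/` plus the `mask/` folder written by `prepropess_data.py`; the 448 crop matches what `train.py` uses.

```python
# Hedged sketch -- the package name and dataset path are assumptions.
from torch.utils.data import DataLoader

from face_parsing.face_dataset import FaceMask

ds = FaceMask('/path/to/CelebAMask-HQ', cropsize=(448, 448), mode='train')
dl = DataLoader(ds, batch_size=16, shuffle=True, num_workers=4, drop_last=True)

for img, label in dl:
    # img:   float tensor, N x 3 x 448 x 448, ImageNet-normalized
    # label: int64 tensor, N x 1 x 448 x 448, one class id per pixel
    break
```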
-------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import os.path as osp 6 | import time 7 | import sys 8 | import logging 9 | 10 | import torch.distributed as dist 11 | 12 | 13 | def setup_logger(logpth): 14 | logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) 15 | logfile = osp.join(logpth, logfile) 16 | FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' 17 | log_level = logging.INFO 18 | if dist.is_initialized() and not dist.get_rank()==0: 19 | log_level = logging.ERROR 20 | logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) 21 | logging.root.addHandler(logging.StreamHandler()) 22 | 23 | 24 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import numpy as np 10 | 11 | 12 | class OhemCELoss(nn.Module): 13 | def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs): 14 | super(OhemCELoss, self).__init__() 15 | self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda() 16 | self.n_min = n_min 17 | self.ignore_lb = ignore_lb 18 | self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') 19 | 20 | def forward(self, logits, labels): 21 | N, C, H, W = logits.size() 22 | loss = self.criteria(logits, labels).view(-1) 23 | loss, _ = torch.sort(loss, descending=True) 24 | if loss[self.n_min] > self.thresh: 25 | loss = loss[loss>self.thresh] 26 | else: 27 | loss = loss[:self.n_min] 28 | return torch.mean(loss) 29 | 30 | 31 | class SoftmaxFocalLoss(nn.Module): 32 | def __init__(self, gamma, ignore_lb=255, *args, **kwargs): 33 | super(SoftmaxFocalLoss, self).__init__() 34 | self.gamma = gamma 35 | self.nll = nn.NLLLoss(ignore_index=ignore_lb) 36 | 37 | def forward(self, logits, labels): 38 | scores = F.softmax(logits, dim=1) 39 | factor = torch.pow(1.-scores, self.gamma) 40 | log_score = F.log_softmax(logits, dim=1) 41 | log_score = factor * log_score 42 | loss = self.nll(log_score, labels) 43 | return loss 44 | 45 | 46 | if __name__ == '__main__': 47 | torch.manual_seed(15) 48 | criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() 49 | criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() 50 | net1 = nn.Sequential( 51 | nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1), 52 | ) 53 | net1.cuda() 54 | net1.train() 55 | net2 = nn.Sequential( 56 | nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1), 57 | ) 58 | net2.cuda() 59 | net2.train() 60 | 61 | with torch.no_grad(): 62 | inten = torch.randn(16, 3, 20, 20).cuda() 63 | lbs = torch.randint(0, 19, [16, 20, 20]).cuda() 64 | lbs[1, :, :] = 255 65 | 66 | logits1 = net1(inten) 67 | logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear') 68 | logits2 = net2(inten) 69 | logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear') 70 | 71 | loss1 = criteria1(logits1, lbs) 72 | loss2 = criteria2(logits2, lbs) 73 | loss = loss1 + loss2 74 | print(loss.detach().cpu()) 75 | loss.backward() 76 | -------------------------------------------------------------------------------- /makeup.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 
import os 3 | import numpy as np 4 | from skimage.filters import gaussian 5 | 6 | 7 | def sharpen(img): 8 | img = img * 1.0 9 | gauss_out = gaussian(img, sigma=5, multichannel=True) 10 | 11 | alpha = 1.5 12 | img_out = (img - gauss_out) * alpha + img 13 | 14 | img_out = img_out / 255.0 15 | 16 | mask_1 = img_out < 0 17 | mask_2 = img_out > 1 18 | 19 | img_out = img_out * (1 - mask_1) 20 | img_out = img_out * (1 - mask_2) + mask_2 21 | img_out = np.clip(img_out, 0, 1) 22 | img_out = img_out * 255 23 | return np.array(img_out, dtype=np.uint8) 24 | 25 | 26 | def hair(image, parsing, part=17, color=[230, 50, 20]): 27 | b, g, r = color #[10, 50, 250] # [10, 250, 10] 28 | tar_color = np.zeros_like(image) 29 | tar_color[:, :, 0] = b 30 | tar_color[:, :, 1] = g 31 | tar_color[:, :, 2] = r 32 | 33 | image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 34 | tar_hsv = cv2.cvtColor(tar_color, cv2.COLOR_BGR2HSV) 35 | 36 | if part == 12 or part == 13: 37 | image_hsv[:, :, 0:2] = tar_hsv[:, :, 0:2] 38 | else: 39 | image_hsv[:, :, 0:1] = tar_hsv[:, :, 0:1] 40 | 41 | changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR) 42 | 43 | if part == 17: 44 | changed = sharpen(changed) 45 | 46 | changed[parsing != part] = image[parsing != part] 47 | # changed = cv2.resize(changed, (512, 512)) 48 | return changed 49 | 50 | # 51 | # def lip(image, parsing, part=17, color=[230, 50, 20]): 52 | # b, g, r = color #[10, 50, 250] # [10, 250, 10] 53 | # tar_color = np.zeros_like(image) 54 | # tar_color[:, :, 0] = b 55 | # tar_color[:, :, 1] = g 56 | # tar_color[:, :, 2] = r 57 | # 58 | # image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab) 59 | # il, ia, ib = cv2.split(image_lab) 60 | # 61 | # tar_lab = cv2.cvtColor(tar_color, cv2.COLOR_BGR2Lab) 62 | # tl, ta, tb = cv2.split(tar_lab) 63 | # 64 | # image_lab[:, :, 0] = np.clip(il - np.mean(il) + tl, 0, 100) 65 | # image_lab[:, :, 1] = np.clip(ia - np.mean(ia) + ta, -127, 128) 66 | # image_lab[:, :, 2] = np.clip(ib - np.mean(ib) + tb, -127, 128) 67 | # 68 | # 69 | # changed = cv2.cvtColor(image_lab, cv2.COLOR_Lab2BGR) 70 | # 71 | # if part == 17: 72 | # changed = sharpen(changed) 73 | # 74 | # changed[parsing != part] = image[parsing != part] 75 | # # changed = cv2.resize(changed, (512, 512)) 76 | # return changed 77 | 78 | 79 | if __name__ == '__main__': 80 | # 1 face 81 | # 10 nose 82 | # 11 teeth 83 | # 12 upper lip 84 | # 13 lower lip 85 | # 17 hair 86 | num = 116 87 | table = { 88 | 'hair': 17, 89 | 'upper_lip': 12, 90 | 'lower_lip': 13 91 | } 92 | image_path = '/home/zll/data/CelebAMask-HQ/test-img/{}.jpg'.format(num) 93 | parsing_path = 'res/test_res/{}.png'.format(num) 94 | 95 | image = cv2.imread(image_path) 96 | ori = image.copy() 97 | parsing = np.array(cv2.imread(parsing_path, 0)) 98 | parsing = cv2.resize(parsing, image.shape[0:2], interpolation=cv2.INTER_NEAREST) 99 | 100 | parts = [table['hair'], table['upper_lip'], table['lower_lip']] 101 | # colors = [[20, 20, 200], [100, 100, 230], [100, 100, 230]] 102 | colors = [[100, 200, 100]] 103 | for part, color in zip(parts, colors): 104 | image = hair(image, parsing, part, color) 105 | cv2.imwrite('res/makeup/116_ori.png', cv2.resize(ori, (512, 512))) 106 | cv2.imwrite('res/makeup/116_2.png', cv2.resize(image, (512, 512))) 107 | 108 | cv2.imshow('image', cv2.resize(ori, (512, 512))) 109 | cv2.imshow('color', cv2.resize(image, (512, 512))) 110 | 111 | # cv2.imshow('image', ori) 112 | # cv2.imshow('color', image) 113 | 114 | cv2.waitKey(0) 115 | cv2.destroyAllWindows() 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 
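# --- Usage sketch (hypothetical, not part of makeup.py) ------------------------
# hair() recolors one parsing class by overwriting the hue channel (hue and
# saturation for the lip classes 12/13) in HSV space, then copies the original
# pixels back everywhere the parsing map disagrees with `part`. A minimal call,
# assuming a BGR image and a parsing map saved by evaluate.py / test.py:
#
#     import cv2
#     image = cv2.imread('6.jpg')                # BGR, H x W x 3
#     parsing = cv2.imread('parsing.png', 0)     # one label id per pixel
#     parsing = cv2.resize(parsing, (image.shape[1], image.shape[0]),
#                          interpolation=cv2.INTER_NEAREST)  # dsize is (W, H)
#     recolored = hair(image, parsing, part=17, color=[230, 50, 20])
#     cv2.imwrite('recolored.png', recolored)
# --------------------------------------------------------------------------------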
125 | 126 | 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /makeup/116_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/makeup/116_1.png -------------------------------------------------------------------------------- /makeup/116_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/makeup/116_3.png -------------------------------------------------------------------------------- /makeup/116_lip_ori.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/makeup/116_lip_ori.png -------------------------------------------------------------------------------- /makeup/116_ori.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisionSystemsInc/face-parsing.PyTorch/041d38aa906007adb0d35ea2fcaeba0b9068af27/makeup/116_ori.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | 10 | from .resnet import Resnet18 11 | # from modules.bn import InPlaceABNSync as BatchNorm2d 12 | 13 | 14 | class ConvBNReLU(nn.Module): 15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 16 | super(ConvBNReLU, self).__init__() 17 | self.conv = nn.Conv2d(in_chan, 18 | out_chan, 19 | kernel_size = ks, 20 | stride = stride, 21 | padding = padding, 22 | bias = False) 23 | self.bn = nn.BatchNorm2d(out_chan) 24 | self.init_weight() 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = F.relu(self.bn(x)) 29 | return x 30 | 31 | def init_weight(self): 32 | for ly in self.children(): 33 | if isinstance(ly, nn.Conv2d): 34 | nn.init.kaiming_normal_(ly.weight, a=1) 35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 36 | 37 | class BiSeNetOutput(nn.Module): 38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 39 | super(BiSeNetOutput, self).__init__() 40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) 42 | self.init_weight() 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | x = self.conv_out(x) 47 | return x 48 | 49 | def init_weight(self): 50 | for ly in self.children(): 51 | if isinstance(ly, nn.Conv2d): 52 | nn.init.kaiming_normal_(ly.weight, a=1) 53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 54 | 55 | def get_params(self): 56 | wd_params, nowd_params = [], [] 57 | for name, module in self.named_modules(): 58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 59 | wd_params.append(module.weight) 60 | if not module.bias is None: 61 | nowd_params.append(module.bias) 62 | elif isinstance(module, nn.BatchNorm2d): 63 | nowd_params += list(module.parameters()) 64 | return wd_params, nowd_params 65 | 66 | 67 | class AttentionRefinementModule(nn.Module): 68 | def __init__(self, 
in_chan, out_chan, *args, **kwargs): 69 | super(AttentionRefinementModule, self).__init__() 70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) 72 | self.bn_atten = nn.BatchNorm2d(out_chan) 73 | self.sigmoid_atten = nn.Sigmoid() 74 | self.init_weight() 75 | 76 | def forward(self, x): 77 | feat = self.conv(x) 78 | atten = F.avg_pool2d(feat, feat.size()[2:]) 79 | atten = self.conv_atten(atten) 80 | atten = self.bn_atten(atten) 81 | atten = self.sigmoid_atten(atten) 82 | out = torch.mul(feat, atten) 83 | return out 84 | 85 | def init_weight(self): 86 | for ly in self.children(): 87 | if isinstance(ly, nn.Conv2d): 88 | nn.init.kaiming_normal_(ly.weight, a=1) 89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 90 | 91 | 92 | class ContextPath(nn.Module): 93 | def __init__(self, *args, **kwargs): 94 | super(ContextPath, self).__init__() 95 | self.resnet = Resnet18() 96 | self.arm16 = AttentionRefinementModule(256, 128) 97 | self.arm32 = AttentionRefinementModule(512, 128) 98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 101 | 102 | self.init_weight() 103 | 104 | def forward(self, x): 105 | H0, W0 = x.size()[2:] 106 | feat8, feat16, feat32 = self.resnet(x) 107 | H8, W8 = feat8.size()[2:] 108 | H16, W16 = feat16.size()[2:] 109 | H32, W32 = feat32.size()[2:] 110 | 111 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 112 | avg = self.conv_avg(avg) 113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 114 | 115 | feat32_arm = self.arm32(feat32) 116 | feat32_sum = feat32_arm + avg_up 117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 118 | feat32_up = self.conv_head32(feat32_up) 119 | 120 | feat16_arm = self.arm16(feat16) 121 | feat16_sum = feat16_arm + feat32_up 122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 123 | feat16_up = self.conv_head16(feat16_up) 124 | 125 | return feat8, feat16_up, feat32_up # x8, x8, x16 126 | 127 | def init_weight(self): 128 | for ly in self.children(): 129 | if isinstance(ly, nn.Conv2d): 130 | nn.init.kaiming_normal_(ly.weight, a=1) 131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 132 | 133 | def get_params(self): 134 | wd_params, nowd_params = [], [] 135 | for name, module in self.named_modules(): 136 | if isinstance(module, (nn.Linear, nn.Conv2d)): 137 | wd_params.append(module.weight) 138 | if not module.bias is None: 139 | nowd_params.append(module.bias) 140 | elif isinstance(module, nn.BatchNorm2d): 141 | nowd_params += list(module.parameters()) 142 | return wd_params, nowd_params 143 | 144 | 145 | ### This is not used, since I replace this with the resnet feature with the same size 146 | class SpatialPath(nn.Module): 147 | def __init__(self, *args, **kwargs): 148 | super(SpatialPath, self).__init__() 149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) 153 | self.init_weight() 154 | 155 | def forward(self, x): 156 | feat = self.conv1(x) 157 | feat = self.conv2(feat) 158 | feat = self.conv3(feat) 159 | feat = self.conv_out(feat) 160 | return feat 161 | 162 | def init_weight(self): 163 | for ly in self.children(): 164 | if 
isinstance(ly, nn.Conv2d): 165 | nn.init.kaiming_normal_(ly.weight, a=1) 166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 167 | 168 | def get_params(self): 169 | wd_params, nowd_params = [], [] 170 | for name, module in self.named_modules(): 171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 172 | wd_params.append(module.weight) 173 | if not module.bias is None: 174 | nowd_params.append(module.bias) 175 | elif isinstance(module, nn.BatchNorm2d): 176 | nowd_params += list(module.parameters()) 177 | return wd_params, nowd_params 178 | 179 | 180 | class FeatureFusionModule(nn.Module): 181 | def __init__(self, in_chan, out_chan, *args, **kwargs): 182 | super(FeatureFusionModule, self).__init__() 183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 184 | self.conv1 = nn.Conv2d(out_chan, 185 | out_chan//4, 186 | kernel_size = 1, 187 | stride = 1, 188 | padding = 0, 189 | bias = False) 190 | self.conv2 = nn.Conv2d(out_chan//4, 191 | out_chan, 192 | kernel_size = 1, 193 | stride = 1, 194 | padding = 0, 195 | bias = False) 196 | self.relu = nn.ReLU(inplace=True) 197 | self.sigmoid = nn.Sigmoid() 198 | self.init_weight() 199 | 200 | def forward(self, fsp, fcp): 201 | fcat = torch.cat([fsp, fcp], dim=1) 202 | feat = self.convblk(fcat) 203 | atten = F.avg_pool2d(feat, feat.size()[2:]) 204 | atten = self.conv1(atten) 205 | atten = self.relu(atten) 206 | atten = self.conv2(atten) 207 | atten = self.sigmoid(atten) 208 | feat_atten = torch.mul(feat, atten) 209 | feat_out = feat_atten + feat 210 | return feat_out 211 | 212 | def init_weight(self): 213 | for ly in self.children(): 214 | if isinstance(ly, nn.Conv2d): 215 | nn.init.kaiming_normal_(ly.weight, a=1) 216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 217 | 218 | def get_params(self): 219 | wd_params, nowd_params = [], [] 220 | for name, module in self.named_modules(): 221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 222 | wd_params.append(module.weight) 223 | if not module.bias is None: 224 | nowd_params.append(module.bias) 225 | elif isinstance(module, nn.BatchNorm2d): 226 | nowd_params += list(module.parameters()) 227 | return wd_params, nowd_params 228 | 229 | 230 | class BiSeNet(nn.Module): 231 | def __init__(self, n_classes, *args, **kwargs): 232 | super(BiSeNet, self).__init__() 233 | self.cp = ContextPath() 234 | ## here self.sp is deleted 235 | self.ffm = FeatureFusionModule(256, 256) 236 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 239 | self.init_weight() 240 | 241 | def forward(self, x): 242 | H, W = x.size()[2:] 243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature 244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 245 | feat_fuse = self.ffm(feat_sp, feat_cp8) 246 | 247 | feat_out = self.conv_out(feat_fuse) 248 | feat_out16 = self.conv_out16(feat_cp8) 249 | feat_out32 = self.conv_out32(feat_cp16) 250 | 251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) 252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) 253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) 254 | return feat_out, feat_out16, feat_out32 255 | 256 | def init_weight(self): 257 | for ly in self.children(): 258 | if isinstance(ly, nn.Conv2d): 259 | nn.init.kaiming_normal_(ly.weight, a=1) 260 | if not 
ly.bias is None: nn.init.constant_(ly.bias, 0) 261 | 262 | def get_params(self): 263 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 264 | for name, child in self.named_children(): 265 | child_wd_params, child_nowd_params = child.get_params() 266 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 267 | lr_mul_wd_params += child_wd_params 268 | lr_mul_nowd_params += child_nowd_params 269 | else: 270 | wd_params += child_wd_params 271 | nowd_params += child_nowd_params 272 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 273 | 274 | 275 | if __name__ == "__main__": 276 | net = BiSeNet(19) 277 | net.cuda() 278 | net.eval() 279 | in_ten = torch.randn(16, 3, 640, 480).cuda() 280 | out, out16, out32 = net(in_ten) 281 | print(out.shape) 282 | 283 | net.get_params() 284 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .bn import ABN, InPlaceABN, InPlaceABNSync 2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE 3 | from .misc import GlobalAvgPool2d, SingleGPU 4 | from .residual import IdentityResidualBlock 5 | from .dense import DenseModule 6 | -------------------------------------------------------------------------------- /modules/bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | 5 | try: 6 | from queue import Queue 7 | except ImportError: 8 | from Queue import Queue 9 | 10 | from .functions import * 11 | 12 | 13 | class ABN(nn.Module): 14 | """Activated Batch Normalization 15 | 16 | This gathers a `BatchNorm2d` and an activation function in a single module 17 | """ 18 | 19 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 20 | """Creates an Activated Batch Normalization module 21 | 22 | Parameters 23 | ---------- 24 | num_features : int 25 | Number of feature channels in the input and output. 26 | eps : float 27 | Small constant to prevent numerical issues. 28 | momentum : float 29 | Momentum factor applied to compute running statistics as. 30 | affine : bool 31 | If `True` apply learned scale and shift transformation after normalization. 32 | activation : str 33 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 34 | slope : float 35 | Negative slope for the `leaky_relu` activation. 
36 | """ 37 | super(ABN, self).__init__() 38 | self.num_features = num_features 39 | self.affine = affine 40 | self.eps = eps 41 | self.momentum = momentum 42 | self.activation = activation 43 | self.slope = slope 44 | if self.affine: 45 | self.weight = nn.Parameter(torch.ones(num_features)) 46 | self.bias = nn.Parameter(torch.zeros(num_features)) 47 | else: 48 | self.register_parameter('weight', None) 49 | self.register_parameter('bias', None) 50 | self.register_buffer('running_mean', torch.zeros(num_features)) 51 | self.register_buffer('running_var', torch.ones(num_features)) 52 | self.reset_parameters() 53 | 54 | def reset_parameters(self): 55 | nn.init.constant_(self.running_mean, 0) 56 | nn.init.constant_(self.running_var, 1) 57 | if self.affine: 58 | nn.init.constant_(self.weight, 1) 59 | nn.init.constant_(self.bias, 0) 60 | 61 | def forward(self, x): 62 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 63 | self.training, self.momentum, self.eps) 64 | 65 | if self.activation == ACT_RELU: 66 | return functional.relu(x, inplace=True) 67 | elif self.activation == ACT_LEAKY_RELU: 68 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True) 69 | elif self.activation == ACT_ELU: 70 | return functional.elu(x, inplace=True) 71 | else: 72 | return x 73 | 74 | def __repr__(self): 75 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 76 | ' affine={affine}, activation={activation}' 77 | if self.activation == "leaky_relu": 78 | rep += ', slope={slope})' 79 | else: 80 | rep += ')' 81 | return rep.format(name=self.__class__.__name__, **self.__dict__) 82 | 83 | 84 | class InPlaceABN(ABN): 85 | """InPlace Activated Batch Normalization""" 86 | 87 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 88 | """Creates an InPlace Activated Batch Normalization module 89 | 90 | Parameters 91 | ---------- 92 | num_features : int 93 | Number of feature channels in the input and output. 94 | eps : float 95 | Small constant to prevent numerical issues. 96 | momentum : float 97 | Momentum factor applied to compute running statistics as. 98 | affine : bool 99 | If `True` apply learned scale and shift transformation after normalization. 100 | activation : str 101 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 102 | slope : float 103 | Negative slope for the `leaky_relu` activation. 104 | """ 105 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope) 106 | 107 | def forward(self, x): 108 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var, 109 | self.training, self.momentum, self.eps, self.activation, self.slope) 110 | 111 | 112 | class InPlaceABNSync(ABN): 113 | """InPlace Activated Batch Normalization with cross-GPU synchronization 114 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`. 
115 | """ 116 | 117 | def forward(self, x): 118 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var, 119 | self.training, self.momentum, self.eps, self.activation, self.slope) 120 | 121 | def __repr__(self): 122 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 123 | ' affine={affine}, activation={activation}' 124 | if self.activation == "leaky_relu": 125 | rep += ', slope={slope})' 126 | else: 127 | rep += ')' 128 | return rep.format(name=self.__class__.__name__, **self.__dict__) 129 | 130 | 131 | -------------------------------------------------------------------------------- /modules/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | 5 | from models._util import try_index 6 | from .bn import ABN 7 | 8 | 9 | class DeeplabV3(nn.Module): 10 | def __init__(self, 11 | in_channels, 12 | out_channels, 13 | hidden_channels=256, 14 | dilations=(12, 24, 36), 15 | norm_act=ABN, 16 | pooling_size=None): 17 | super(DeeplabV3, self).__init__() 18 | self.pooling_size = pooling_size 19 | 20 | self.map_convs = nn.ModuleList([ 21 | nn.Conv2d(in_channels, hidden_channels, 1, bias=False), 22 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]), 23 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]), 24 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2]) 25 | ]) 26 | self.map_bn = norm_act(hidden_channels * 4) 27 | 28 | self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False) 29 | self.global_pooling_bn = norm_act(hidden_channels) 30 | 31 | self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False) 32 | self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False) 33 | self.red_bn = norm_act(out_channels) 34 | 35 | self.reset_parameters(self.map_bn.activation, self.map_bn.slope) 36 | 37 | def reset_parameters(self, activation, slope): 38 | gain = nn.init.calculate_gain(activation, slope) 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | nn.init.xavier_normal_(m.weight.data, gain) 42 | if hasattr(m, "bias") and m.bias is not None: 43 | nn.init.constant_(m.bias, 0) 44 | elif isinstance(m, ABN): 45 | if hasattr(m, "weight") and m.weight is not None: 46 | nn.init.constant_(m.weight, 1) 47 | if hasattr(m, "bias") and m.bias is not None: 48 | nn.init.constant_(m.bias, 0) 49 | 50 | def forward(self, x): 51 | # Map convolutions 52 | out = torch.cat([m(x) for m in self.map_convs], dim=1) 53 | out = self.map_bn(out) 54 | out = self.red_conv(out) 55 | 56 | # Global pooling 57 | pool = self._global_pooling(x) 58 | pool = self.global_pooling_conv(pool) 59 | pool = self.global_pooling_bn(pool) 60 | pool = self.pool_red_conv(pool) 61 | if self.training or self.pooling_size is None: 62 | pool = pool.repeat(1, 1, x.size(2), x.size(3)) 63 | 64 | out += pool 65 | out = self.red_bn(out) 66 | return out 67 | 68 | def _global_pooling(self, x): 69 | if self.training or self.pooling_size is None: 70 | pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1) 71 | pool = pool.view(x.size(0), x.size(1), 1, 1) 72 | else: 73 | pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]), 74 | min(try_index(self.pooling_size, 1), x.shape[3])) 75 | padding = ( 76 | (pooling_size[1] - 1) // 2, 77 | (pooling_size[1] - 1) // 2 if 
pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1, 78 | (pooling_size[0] - 1) // 2, 79 | (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1 80 | ) 81 | 82 | pool = functional.avg_pool2d(x, pooling_size, stride=1) 83 | pool = functional.pad(pool, pad=padding, mode="replicate") 84 | return pool 85 | -------------------------------------------------------------------------------- /modules/dense.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .bn import ABN 7 | 8 | 9 | class DenseModule(nn.Module): 10 | def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1): 11 | super(DenseModule, self).__init__() 12 | self.in_channels = in_channels 13 | self.growth = growth 14 | self.layers = layers 15 | 16 | self.convs1 = nn.ModuleList() 17 | self.convs3 = nn.ModuleList() 18 | for i in range(self.layers): 19 | self.convs1.append(nn.Sequential(OrderedDict([ 20 | ("bn", norm_act(in_channels)), 21 | ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False)) 22 | ]))) 23 | self.convs3.append(nn.Sequential(OrderedDict([ 24 | ("bn", norm_act(self.growth * bottleneck_factor)), 25 | ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False, 26 | dilation=dilation)) 27 | ]))) 28 | in_channels += self.growth 29 | 30 | @property 31 | def out_channels(self): 32 | return self.in_channels + self.growth * self.layers 33 | 34 | def forward(self, x): 35 | inputs = [x] 36 | for i in range(self.layers): 37 | x = torch.cat(inputs, dim=1) 38 | x = self.convs1[i](x) 39 | x = self.convs3[i](x) 40 | inputs += [x] 41 | 42 | return torch.cat(inputs, dim=1) 43 | -------------------------------------------------------------------------------- /modules/functions.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import torch 3 | import torch.distributed as dist 4 | import torch.autograd as autograd 5 | import torch.cuda.comm as comm 6 | from torch.autograd.function import once_differentiable 7 | from torch.utils.cpp_extension import load 8 | 9 | _src_path = path.join(path.dirname(path.abspath(__file__)), "src") 10 | _backend = load(name="inplace_abn", 11 | extra_cflags=["-O3"], 12 | sources=[path.join(_src_path, f) for f in [ 13 | "inplace_abn.cpp", 14 | "inplace_abn_cpu.cpp", 15 | "inplace_abn_cuda.cu", 16 | "inplace_abn_cuda_half.cu" 17 | ]], 18 | extra_cuda_cflags=["--expt-extended-lambda"]) 19 | 20 | # Activation names 21 | ACT_RELU = "relu" 22 | ACT_LEAKY_RELU = "leaky_relu" 23 | ACT_ELU = "elu" 24 | ACT_NONE = "none" 25 | 26 | 27 | def _check(fn, *args, **kwargs): 28 | success = fn(*args, **kwargs) 29 | if not success: 30 | raise RuntimeError("CUDA Error encountered in {}".format(fn)) 31 | 32 | 33 | def _broadcast_shape(x): 34 | out_size = [] 35 | for i, s in enumerate(x.size()): 36 | if i != 1: 37 | out_size.append(1) 38 | else: 39 | out_size.append(s) 40 | return out_size 41 | 42 | 43 | def _reduce(x): 44 | if len(x.size()) == 2: 45 | return x.sum(dim=0) 46 | else: 47 | n, c = x.size()[0:2] 48 | return x.contiguous().view((n, c, -1)).sum(2).sum(0) 49 | 50 | 51 | def _count_samples(x): 52 | count = 1 53 | for i, s in enumerate(x.size()): 54 | if i != 1: 55 | count *= s 56 | return count 57 | 58 | 59 | def _act_forward(ctx, x): 60 | if ctx.activation == ACT_LEAKY_RELU: 61 | 
_backend.leaky_relu_forward(x, ctx.slope) 62 | elif ctx.activation == ACT_ELU: 63 | _backend.elu_forward(x) 64 | elif ctx.activation == ACT_NONE: 65 | pass 66 | 67 | 68 | def _act_backward(ctx, x, dx): 69 | if ctx.activation == ACT_LEAKY_RELU: 70 | _backend.leaky_relu_backward(x, dx, ctx.slope) 71 | elif ctx.activation == ACT_ELU: 72 | _backend.elu_backward(x, dx) 73 | elif ctx.activation == ACT_NONE: 74 | pass 75 | 76 | 77 | class InPlaceABN(autograd.Function): 78 | @staticmethod 79 | def forward(ctx, x, weight, bias, running_mean, running_var, 80 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 81 | # Save context 82 | ctx.training = training 83 | ctx.momentum = momentum 84 | ctx.eps = eps 85 | ctx.activation = activation 86 | ctx.slope = slope 87 | ctx.affine = weight is not None and bias is not None 88 | 89 | # Prepare inputs 90 | count = _count_samples(x) 91 | x = x.contiguous() 92 | weight = weight.contiguous() if ctx.affine else x.new_empty(0) 93 | bias = bias.contiguous() if ctx.affine else x.new_empty(0) 94 | 95 | if ctx.training: 96 | mean, var = _backend.mean_var(x) 97 | 98 | # Update running stats 99 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 100 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1)) 101 | 102 | # Mark in-place modified tensors 103 | ctx.mark_dirty(x, running_mean, running_var) 104 | else: 105 | mean, var = running_mean.contiguous(), running_var.contiguous() 106 | ctx.mark_dirty(x) 107 | 108 | # BN forward + activation 109 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) 110 | _act_forward(ctx, x) 111 | 112 | # Output 113 | ctx.var = var 114 | ctx.save_for_backward(x, var, weight, bias) 115 | return x 116 | 117 | @staticmethod 118 | @once_differentiable 119 | def backward(ctx, dz): 120 | z, var, weight, bias = ctx.saved_tensors 121 | dz = dz.contiguous() 122 | 123 | # Undo activation 124 | _act_backward(ctx, z, dz) 125 | 126 | if ctx.training: 127 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) 128 | else: 129 | # TODO: implement simplified CUDA backward for inference mode 130 | edz = dz.new_zeros(dz.size(1)) 131 | eydz = dz.new_zeros(dz.size(1)) 132 | 133 | dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) 134 | dweight = eydz * weight.sign() if ctx.affine else None 135 | dbias = edz if ctx.affine else None 136 | 137 | return dx, dweight, dbias, None, None, None, None, None, None, None 138 | 139 | class InPlaceABNSync(autograd.Function): 140 | @classmethod 141 | def forward(cls, ctx, x, weight, bias, running_mean, running_var, 142 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True): 143 | # Save context 144 | ctx.training = training 145 | ctx.momentum = momentum 146 | ctx.eps = eps 147 | ctx.activation = activation 148 | ctx.slope = slope 149 | ctx.affine = weight is not None and bias is not None 150 | 151 | # Prepare inputs 152 | ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1 153 | 154 | #count = _count_samples(x) 155 | batch_size = x.new_tensor([x.shape[0]],dtype=torch.long) 156 | 157 | x = x.contiguous() 158 | weight = weight.contiguous() if ctx.affine else x.new_empty(0) 159 | bias = bias.contiguous() if ctx.affine else x.new_empty(0) 160 | 161 | if ctx.training: 162 | mean, var = _backend.mean_var(x) 163 | if ctx.world_size>1: 164 | # get global batch size 165 | if equal_batches: 166 | batch_size *= ctx.world_size 167 | 
else: 168 | dist.all_reduce(batch_size, dist.ReduceOp.SUM) 169 | 170 | ctx.factor = x.shape[0]/float(batch_size.item()) 171 | 172 | mean_all = mean.clone() * ctx.factor 173 | dist.all_reduce(mean_all, dist.ReduceOp.SUM) 174 | 175 | var_all = (var + (mean - mean_all) ** 2) * ctx.factor 176 | dist.all_reduce(var_all, dist.ReduceOp.SUM) 177 | 178 | mean = mean_all 179 | var = var_all 180 | 181 | # Update running stats 182 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 183 | count = batch_size.item() * x.view(x.shape[0],x.shape[1],-1).shape[-1] 184 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1))) 185 | 186 | # Mark in-place modified tensors 187 | ctx.mark_dirty(x, running_mean, running_var) 188 | else: 189 | mean, var = running_mean.contiguous(), running_var.contiguous() 190 | ctx.mark_dirty(x) 191 | 192 | # BN forward + activation 193 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) 194 | _act_forward(ctx, x) 195 | 196 | # Output 197 | ctx.var = var 198 | ctx.save_for_backward(x, var, weight, bias) 199 | return x 200 | 201 | @staticmethod 202 | @once_differentiable 203 | def backward(ctx, dz): 204 | z, var, weight, bias = ctx.saved_tensors 205 | dz = dz.contiguous() 206 | 207 | # Undo activation 208 | _act_backward(ctx, z, dz) 209 | 210 | if ctx.training: 211 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) 212 | edz_local = edz.clone() 213 | eydz_local = eydz.clone() 214 | 215 | if ctx.world_size>1: 216 | edz *= ctx.factor 217 | dist.all_reduce(edz, dist.ReduceOp.SUM) 218 | 219 | eydz *= ctx.factor 220 | dist.all_reduce(eydz, dist.ReduceOp.SUM) 221 | else: 222 | edz_local = edz = dz.new_zeros(dz.size(1)) 223 | eydz_local = eydz = dz.new_zeros(dz.size(1)) 224 | 225 | dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) 226 | dweight = eydz_local * weight.sign() if ctx.affine else None 227 | dbias = edz_local if ctx.affine else None 228 | 229 | return dx, dweight, dbias, None, None, None, None, None, None, None 230 | 231 | inplace_abn = InPlaceABN.apply 232 | inplace_abn_sync = InPlaceABNSync.apply 233 | 234 | __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"] 235 | -------------------------------------------------------------------------------- /modules/misc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.distributed as dist 4 | 5 | class GlobalAvgPool2d(nn.Module): 6 | def __init__(self): 7 | """Global average pooling over the input's spatial dimensions""" 8 | super(GlobalAvgPool2d, self).__init__() 9 | 10 | def forward(self, inputs): 11 | in_size = inputs.size() 12 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 13 | 14 | class SingleGPU(nn.Module): 15 | def __init__(self, module): 16 | super(SingleGPU, self).__init__() 17 | self.module=module 18 | 19 | def forward(self, input): 20 | return self.module(input.cuda(non_blocking=True)) 21 | 22 | -------------------------------------------------------------------------------- /modules/residual.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | 5 | from .bn import ABN 6 | 7 | 8 | class IdentityResidualBlock(nn.Module): 9 | def __init__(self, 10 | in_channels, 11 | channels, 12 | stride=1, 13 | dilation=1, 14 | groups=1, 15 | norm_act=ABN, 16 | 
dropout=None):
17 |         """Configurable identity-mapping residual block
18 | 
19 |         Parameters
20 |         ----------
21 |         in_channels : int
22 |             Number of input channels.
23 |         channels : list of int
24 |             Number of channels in the internal feature maps. Can either have two or three elements: if two, construct
25 |             a residual block with two `3 x 3` convolutions; if three, construct a bottleneck block with `1 x 1`, then
26 |             `3 x 3`, then `1 x 1` convolutions.
27 |         stride : int
28 |             Stride of the first `3 x 3` convolution.
29 |         dilation : int
30 |             Dilation to apply to the `3 x 3` convolutions.
31 |         groups : int
32 |             Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
33 |             bottleneck blocks.
34 |         norm_act : callable
35 |             Function to create normalization / activation Module.
36 |         dropout: callable
37 |             Function to create Dropout Module.
38 |         """
39 |         super(IdentityResidualBlock, self).__init__()
40 | 
41 |         # Check parameters for inconsistencies
42 |         if len(channels) != 2 and len(channels) != 3:
43 |             raise ValueError("channels must contain either two or three values")
44 |         if len(channels) == 2 and groups != 1:
45 |             raise ValueError("groups > 1 are only valid if len(channels) == 3")
46 | 
47 |         is_bottleneck = len(channels) == 3
48 |         need_proj_conv = stride != 1 or in_channels != channels[-1]
49 | 
50 |         self.bn1 = norm_act(in_channels)
51 |         if not is_bottleneck:
52 |             layers = [
53 |                 ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
54 |                                     dilation=dilation)),
55 |                 ("bn2", norm_act(channels[0])),
56 |                 ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
57 |                                     dilation=dilation))
58 |             ]
59 |             if dropout is not None:
60 |                 layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
61 |         else:
62 |             layers = [
63 |                 ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
64 |                 ("bn2", norm_act(channels[0])),
65 |                 ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
66 |                                     groups=groups, dilation=dilation)),
67 |                 ("bn3", norm_act(channels[1])),
68 |                 ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
69 |             ]
70 |             if dropout is not None:
71 |                 layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
72 |         self.convs = nn.Sequential(OrderedDict(layers))
73 | 
74 |         if need_proj_conv:
75 |             self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
76 | 
77 |     def forward(self, x):
78 |         if hasattr(self, "proj_conv"):
79 |             bn1 = self.bn1(x)
80 |             shortcut = self.proj_conv(bn1)
81 |         else:
82 |             shortcut = x.clone()
83 |             bn1 = self.bn1(x)
84 | 
85 |         out = self.convs(bn1)
86 |         out.add_(shortcut)
87 | 
88 |         return out
89 | 
--------------------------------------------------------------------------------
/modules/src/checks.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <ATen/ATen.h>
4 | 
5 | // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
6 | #ifndef AT_CHECK
7 | #define AT_CHECK AT_ASSERT
8 | #endif
9 | 
10 | #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
12 | #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
13 | 
14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
15 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
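
The C++/CUDA sources in this directory back the `ABN`/`InPlaceABN`/`InPlaceABNSync` modules defined in `modules/bn.py`; `modules/functions.py` JIT-compiles them on first import via `torch.utils.cpp_extension.load`. A minimal usage sketch follows (an illustration, not repo code -- it needs `nvcc` and a CUDA device for the build, and assumes the repo root is on `sys.path`):

```python
# Hedged sketch: requires a successful JIT build of the extension and a GPU.
import torch
import torch.nn as nn

from modules import InPlaceABN  # fused BatchNorm + activation, computed in place

block = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
    InPlaceABN(128, activation='leaky_relu', slope=0.01),
).cuda()

y = block(torch.randn(8, 64, 56, 56, device='cuda'))
print(y.shape)  # torch.Size([8, 128, 56, 56])
```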
--------------------------------------------------------------------------------
/modules/src/inplace_abn.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | #include <vector>
4 | 
5 | #include "inplace_abn.h"
6 | 
7 | std::vector<at::Tensor> mean_var(at::Tensor x) {
8 |   if (x.is_cuda()) {
9 |     if (x.type().scalarType() == at::ScalarType::Half) {
10 |       return mean_var_cuda_h(x);
11 |     } else {
12 |       return mean_var_cuda(x);
13 |     }
14 |   } else {
15 |     return mean_var_cpu(x);
16 |   }
17 | }
18 | 
19 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
20 |                    bool affine, float eps) {
21 |   if (x.is_cuda()) {
22 |     if (x.type().scalarType() == at::ScalarType::Half) {
23 |       return forward_cuda_h(x, mean, var, weight, bias, affine, eps);
24 |     } else {
25 |       return forward_cuda(x, mean, var, weight, bias, affine, eps);
26 |     }
27 |   } else {
28 |     return forward_cpu(x, mean, var, weight, bias, affine, eps);
29 |   }
30 | }
31 | 
32 | std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
33 |                                  bool affine, float eps) {
34 |   if (z.is_cuda()) {
35 |     if (z.type().scalarType() == at::ScalarType::Half) {
36 |       return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps);
37 |     } else {
38 |       return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
39 |     }
40 |   } else {
41 |     return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
42 |   }
43 | }
44 | 
45 | at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
46 |                     at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
47 |   if (z.is_cuda()) {
48 |     if (z.type().scalarType() == at::ScalarType::Half) {
49 |       return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps);
50 |     } else {
51 |       return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
52 |     }
53 |   } else {
54 |     return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
55 |   }
56 | }
57 | 
58 | void leaky_relu_forward(at::Tensor z, float slope) {
59 |   at::leaky_relu_(z, slope);
60 | }
61 | 
62 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
63 |   if (z.is_cuda()) {
64 |     if (z.type().scalarType() == at::ScalarType::Half) {
65 |       return leaky_relu_backward_cuda_h(z, dz, slope);
66 |     } else {
67 |       return leaky_relu_backward_cuda(z, dz, slope);
68 |     }
69 |   } else {
70 |     return leaky_relu_backward_cpu(z, dz, slope);
71 |   }
72 | }
73 | 
74 | void elu_forward(at::Tensor z) {
75 |   at::elu_(z);
76 | }
77 | 
78 | void elu_backward(at::Tensor z, at::Tensor dz) {
79 |   if (z.is_cuda()) {
80 |     return elu_backward_cuda(z, dz);
81 |   } else {
82 |     return elu_backward_cpu(z, dz);
83 |   }
84 | }
85 | 
86 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
87 |   m.def("mean_var", &mean_var, "Mean and variance computation");
88 |   m.def("forward", &forward, "In-place forward computation");
89 |   m.def("edz_eydz", &edz_eydz, "First part of backward computation");
90 |   m.def("backward", &backward, "Second part of backward computation");
91 |   m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
92 |   m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
93 |   m.def("elu_forward", &elu_forward, "Elu forward computation");
94 |   m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
95 | }
96 | 
--------------------------------------------------------------------------------
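Across the dispatcher above and the per-device kernels declared below, the forward math is y = (x - mean) * rsqrt(var + eps) * (|weight| + eps) + bias; the absolute value keeps the per-channel scale strictly positive, which is what makes the output invertible in invert_affine during backward. A quick pure-PyTorch sanity check of that formula (no compiled backend involved):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, 16, 16)
w, b, eps = torch.randn(8), torch.randn(8), 1e-5
mean = x.mean(dim=(0, 2, 3))
var = x.var(dim=(0, 2, 3), unbiased=False)

def bc(t):  # broadcast a per-channel vector over an N,C,H,W tensor
    return t[None, :, None, None]

y_abn = (x - bc(mean)) * bc(torch.rsqrt(var + eps)) * bc(w.abs() + eps) + bc(b)
y_ref = F.batch_norm(x, None, None, w.abs() + eps, b, training=True, eps=eps)
print(torch.allclose(y_abn, y_ref, atol=1e-5))  # True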
/modules/src/inplace_abn.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <ATen/ATen.h>
4 | 
5 | #include <vector>
6 | 
7 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9 | std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x);
10 | 
11 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
12 |                        bool affine, float eps);
13 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
14 |                         bool affine, float eps);
15 | at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16 |                           bool affine, float eps);
17 | 
18 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
19 |                                      bool affine, float eps);
20 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
21 |                                       bool affine, float eps);
22 | std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
23 |                                         bool affine, float eps);
24 | 
25 | at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
26 |                         at::Tensor edz, at::Tensor eydz, bool affine, float eps);
27 | at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
28 |                          at::Tensor edz, at::Tensor eydz, bool affine, float eps);
29 | at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
30 |                            at::Tensor edz, at::Tensor eydz, bool affine, float eps);
31 | 
32 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
33 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
34 | void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope);
35 | 
36 | void elu_backward_cpu(at::Tensor z, at::Tensor dz);
37 | void elu_backward_cuda(at::Tensor z, at::Tensor dz);
38 | 
39 | static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
40 |   num = x.size(0);
41 |   chn = x.size(1);
42 |   sp = 1;
43 |   for (int64_t i = 2; i < x.ndimension(); ++i)
44 |     sp *= x.size(i);
45 | }
46 | 
47 | /*
48 |  * Specialized CUDA reduction functions for BN
49 |  */
50 | #ifdef __CUDACC__
51 | 
52 | #include "utils/cuda.cuh"
53 | 
54 | template<typename T, typename Op>
55 | __device__ T reduce(Op op, int plane, int N, int S) {
56 |   T sum = (T)0;
57 |   for (int batch = 0; batch < N; ++batch) {
58 |     for (int x = threadIdx.x; x < S; x += blockDim.x) {
59 |       sum += op(batch, plane, x);
60 |     }
61 |   }
62 | 
63 |   // sum over NumThreads within a warp
64 |   sum = warpSum(sum);
65 | 
66 |   // 'transpose', and reduce within warp again
67 |   __shared__ T shared[32];
68 |   __syncthreads();
69 |   if (threadIdx.x % WARP_SIZE == 0) {
70 |     shared[threadIdx.x / WARP_SIZE] = sum;
71 |   }
72 |   if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
73 |     // zero out the other entries in shared
74 |     shared[threadIdx.x] = (T)0;
75 |   }
76 |   __syncthreads();
77 |   if (threadIdx.x / WARP_SIZE == 0) {
78 |     sum = warpSum(shared[threadIdx.x]);
79 |     if (threadIdx.x == 0) {
80 |       shared[0] = sum;
81 |     }
82 |   }
83 |   __syncthreads();
84 | 
85 |   // Everyone picks it up, should be broadcast into the whole gradInput
86 |   return shared[0];
87 | }
88 | #endif
89 | 
--------------------------------------------------------------------------------
/modules/src/inplace_abn_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | 
3 | 
#include 4 | 5 | #include "utils/checks.h" 6 | #include "inplace_abn.h" 7 | 8 | at::Tensor reduce_sum(at::Tensor x) { 9 | if (x.ndimension() == 2) { 10 | return x.sum(0); 11 | } else { 12 | auto x_view = x.view({x.size(0), x.size(1), -1}); 13 | return x_view.sum(-1).sum(0); 14 | } 15 | } 16 | 17 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) { 18 | if (x.ndimension() == 2) { 19 | return v; 20 | } else { 21 | std::vector broadcast_size = {1, -1}; 22 | for (int64_t i = 2; i < x.ndimension(); ++i) 23 | broadcast_size.push_back(1); 24 | 25 | return v.view(broadcast_size); 26 | } 27 | } 28 | 29 | int64_t count(at::Tensor x) { 30 | int64_t count = x.size(0); 31 | for (int64_t i = 2; i < x.ndimension(); ++i) 32 | count *= x.size(i); 33 | 34 | return count; 35 | } 36 | 37 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) { 38 | if (affine) { 39 | return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z); 40 | } else { 41 | return z; 42 | } 43 | } 44 | 45 | std::vector mean_var_cpu(at::Tensor x) { 46 | auto num = count(x); 47 | auto mean = reduce_sum(x) / num; 48 | auto diff = x - broadcast_to(mean, x); 49 | auto var = reduce_sum(diff.pow(2)) / num; 50 | 51 | return {mean, var}; 52 | } 53 | 54 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 55 | bool affine, float eps) { 56 | auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var); 57 | auto mul = at::rsqrt(var + eps) * gamma; 58 | 59 | x.sub_(broadcast_to(mean, x)); 60 | x.mul_(broadcast_to(mul, x)); 61 | if (affine) x.add_(broadcast_to(bias, x)); 62 | 63 | return x; 64 | } 65 | 66 | std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 67 | bool affine, float eps) { 68 | auto edz = reduce_sum(dz); 69 | auto y = invert_affine(z, weight, bias, affine, eps); 70 | auto eydz = reduce_sum(y * dz); 71 | 72 | return {edz, eydz}; 73 | } 74 | 75 | at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 76 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 77 | auto y = invert_affine(z, weight, bias, affine, eps); 78 | auto mul = affine ? 
at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps); 79 | 80 | auto num = count(z); 81 | auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz); 82 | return dx; 83 | } 84 | 85 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) { 86 | CHECK_CPU_INPUT(z); 87 | CHECK_CPU_INPUT(dz); 88 | 89 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] { 90 | int64_t count = z.numel(); 91 | auto *_z = z.data(); 92 | auto *_dz = dz.data(); 93 | 94 | for (int64_t i = 0; i < count; ++i) { 95 | if (_z[i] < 0) { 96 | _z[i] *= 1 / slope; 97 | _dz[i] *= slope; 98 | } 99 | } 100 | })); 101 | } 102 | 103 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) { 104 | CHECK_CPU_INPUT(z); 105 | CHECK_CPU_INPUT(dz); 106 | 107 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] { 108 | int64_t count = z.numel(); 109 | auto *_z = z.data(); 110 | auto *_dz = dz.data(); 111 | 112 | for (int64_t i = 0; i < count; ++i) { 113 | if (_z[i] < 0) { 114 | _z[i] = log1p(_z[i]); 115 | _dz[i] *= (_z[i] + 1.f); 116 | } 117 | } 118 | })); 119 | } 120 | -------------------------------------------------------------------------------- /modules/src/inplace_abn_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | #include "utils/checks.h" 9 | #include "utils/cuda.cuh" 10 | #include "inplace_abn.h" 11 | 12 | #include 13 | 14 | // Operations for reduce 15 | template 16 | struct SumOp { 17 | __device__ SumOp(const T *t, int c, int s) 18 | : tensor(t), chn(c), sp(s) {} 19 | __device__ __forceinline__ T operator()(int batch, int plane, int n) { 20 | return tensor[(batch * chn + plane) * sp + n]; 21 | } 22 | const T *tensor; 23 | const int chn; 24 | const int sp; 25 | }; 26 | 27 | template 28 | struct VarOp { 29 | __device__ VarOp(T m, const T *t, int c, int s) 30 | : mean(m), tensor(t), chn(c), sp(s) {} 31 | __device__ __forceinline__ T operator()(int batch, int plane, int n) { 32 | T val = tensor[(batch * chn + plane) * sp + n]; 33 | return (val - mean) * (val - mean); 34 | } 35 | const T mean; 36 | const T *tensor; 37 | const int chn; 38 | const int sp; 39 | }; 40 | 41 | template 42 | struct GradOp { 43 | __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s) 44 | : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {} 45 | __device__ __forceinline__ Pair operator()(int batch, int plane, int n) { 46 | T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight; 47 | T _dz = dz[(batch * chn + plane) * sp + n]; 48 | return Pair(_dz, _y * _dz); 49 | } 50 | const T weight; 51 | const T bias; 52 | const T *z; 53 | const T *dz; 54 | const int chn; 55 | const int sp; 56 | }; 57 | 58 | /*********** 59 | * mean_var 60 | ***********/ 61 | 62 | template 63 | __global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) { 64 | int plane = blockIdx.x; 65 | T norm = T(1) / T(num * sp); 66 | 67 | T _mean = reduce>(SumOp(x, chn, sp), plane, num, sp) * norm; 68 | __syncthreads(); 69 | T _var = reduce>(VarOp(_mean, x, chn, sp), plane, num, sp) * norm; 70 | 71 | if (threadIdx.x == 0) { 72 | mean[plane] = _mean; 73 | var[plane] = _var; 74 | } 75 | } 76 | 77 | std::vector mean_var_cuda(at::Tensor x) { 78 | CHECK_CUDA_INPUT(x); 79 | 80 | // Extract dimensions 81 | int64_t num, chn, sp; 82 | get_dims(x, num, chn, sp); 83 | 84 | // Prepare output tensors 85 | auto mean = at::empty({chn}, 
x.options()); 86 | auto var = at::empty({chn}, x.options()); 87 | 88 | // Run kernel 89 | dim3 blocks(chn); 90 | dim3 threads(getNumThreads(sp)); 91 | auto stream = at::cuda::getCurrentCUDAStream(); 92 | AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] { 93 | mean_var_kernel<<>>( 94 | x.data(), 95 | mean.data(), 96 | var.data(), 97 | num, chn, sp); 98 | })); 99 | 100 | return {mean, var}; 101 | } 102 | 103 | /********** 104 | * forward 105 | **********/ 106 | 107 | template 108 | __global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias, 109 | bool affine, float eps, int num, int chn, int sp) { 110 | int plane = blockIdx.x; 111 | 112 | T _mean = mean[plane]; 113 | T _var = var[plane]; 114 | T _weight = affine ? abs(weight[plane]) + eps : T(1); 115 | T _bias = affine ? bias[plane] : T(0); 116 | 117 | T mul = rsqrt(_var + eps) * _weight; 118 | 119 | for (int batch = 0; batch < num; ++batch) { 120 | for (int n = threadIdx.x; n < sp; n += blockDim.x) { 121 | T _x = x[(batch * chn + plane) * sp + n]; 122 | T _y = (_x - _mean) * mul + _bias; 123 | 124 | x[(batch * chn + plane) * sp + n] = _y; 125 | } 126 | } 127 | } 128 | 129 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 130 | bool affine, float eps) { 131 | CHECK_CUDA_INPUT(x); 132 | CHECK_CUDA_INPUT(mean); 133 | CHECK_CUDA_INPUT(var); 134 | CHECK_CUDA_INPUT(weight); 135 | CHECK_CUDA_INPUT(bias); 136 | 137 | // Extract dimensions 138 | int64_t num, chn, sp; 139 | get_dims(x, num, chn, sp); 140 | 141 | // Run kernel 142 | dim3 blocks(chn); 143 | dim3 threads(getNumThreads(sp)); 144 | auto stream = at::cuda::getCurrentCUDAStream(); 145 | AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] { 146 | forward_kernel<<>>( 147 | x.data(), 148 | mean.data(), 149 | var.data(), 150 | weight.data(), 151 | bias.data(), 152 | affine, eps, num, chn, sp); 153 | })); 154 | 155 | return x; 156 | } 157 | 158 | /*********** 159 | * edz_eydz 160 | ***********/ 161 | 162 | template 163 | __global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias, 164 | T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) { 165 | int plane = blockIdx.x; 166 | 167 | T _weight = affine ? abs(weight[plane]) + eps : 1.f; 168 | T _bias = affine ? 
bias[plane] : 0.f; 169 | 170 | Pair res = reduce, GradOp>(GradOp(_weight, _bias, z, dz, chn, sp), plane, num, sp); 171 | __syncthreads(); 172 | 173 | if (threadIdx.x == 0) { 174 | edz[plane] = res.v1; 175 | eydz[plane] = res.v2; 176 | } 177 | } 178 | 179 | std::vector edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 180 | bool affine, float eps) { 181 | CHECK_CUDA_INPUT(z); 182 | CHECK_CUDA_INPUT(dz); 183 | CHECK_CUDA_INPUT(weight); 184 | CHECK_CUDA_INPUT(bias); 185 | 186 | // Extract dimensions 187 | int64_t num, chn, sp; 188 | get_dims(z, num, chn, sp); 189 | 190 | auto edz = at::empty({chn}, z.options()); 191 | auto eydz = at::empty({chn}, z.options()); 192 | 193 | // Run kernel 194 | dim3 blocks(chn); 195 | dim3 threads(getNumThreads(sp)); 196 | auto stream = at::cuda::getCurrentCUDAStream(); 197 | AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] { 198 | edz_eydz_kernel<<>>( 199 | z.data(), 200 | dz.data(), 201 | weight.data(), 202 | bias.data(), 203 | edz.data(), 204 | eydz.data(), 205 | affine, eps, num, chn, sp); 206 | })); 207 | 208 | return {edz, eydz}; 209 | } 210 | 211 | /*********** 212 | * backward 213 | ***********/ 214 | 215 | template 216 | __global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz, 217 | const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) { 218 | int plane = blockIdx.x; 219 | 220 | T _weight = affine ? abs(weight[plane]) + eps : 1.f; 221 | T _bias = affine ? bias[plane] : 0.f; 222 | T _var = var[plane]; 223 | T _edz = edz[plane]; 224 | T _eydz = eydz[plane]; 225 | 226 | T _mul = _weight * rsqrt(_var + eps); 227 | T count = T(num * sp); 228 | 229 | for (int batch = 0; batch < num; ++batch) { 230 | for (int n = threadIdx.x; n < sp; n += blockDim.x) { 231 | T _dz = dz[(batch * chn + plane) * sp + n]; 232 | T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight; 233 | 234 | dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul; 235 | } 236 | } 237 | } 238 | 239 | at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 240 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 241 | CHECK_CUDA_INPUT(z); 242 | CHECK_CUDA_INPUT(dz); 243 | CHECK_CUDA_INPUT(var); 244 | CHECK_CUDA_INPUT(weight); 245 | CHECK_CUDA_INPUT(bias); 246 | CHECK_CUDA_INPUT(edz); 247 | CHECK_CUDA_INPUT(eydz); 248 | 249 | // Extract dimensions 250 | int64_t num, chn, sp; 251 | get_dims(z, num, chn, sp); 252 | 253 | auto dx = at::zeros_like(z); 254 | 255 | // Run kernel 256 | dim3 blocks(chn); 257 | dim3 threads(getNumThreads(sp)); 258 | auto stream = at::cuda::getCurrentCUDAStream(); 259 | AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] { 260 | backward_kernel<<>>( 261 | z.data(), 262 | dz.data(), 263 | var.data(), 264 | weight.data(), 265 | bias.data(), 266 | edz.data(), 267 | eydz.data(), 268 | dx.data(), 269 | affine, eps, num, chn, sp); 270 | })); 271 | 272 | return dx; 273 | } 274 | 275 | /************** 276 | * activations 277 | **************/ 278 | 279 | template 280 | inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) { 281 | // Create thrust pointers 282 | thrust::device_ptr th_z = thrust::device_pointer_cast(z); 283 | thrust::device_ptr th_dz = thrust::device_pointer_cast(dz); 284 | 285 | auto stream = at::cuda::getCurrentCUDAStream(); 286 | thrust::transform_if(thrust::cuda::par.on(stream), 287 | th_dz, th_dz + count, th_z, th_dz, 288 | 
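                         /* th_z is the stencil here: where the activation output z is negative,
                            this first transform_if scales dz by slope (the local derivative of
                            the leaky ReLU); the second call then divides z by slope, restoring
                            the pre-activation values in place, which is the trick that lets
                            InPlace-ABN avoid storing a separate pre-activation tensor. */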
[slope] __device__ (const T& dz) { return dz * slope; }, 289 | [] __device__ (const T& z) { return z < 0; }); 290 | thrust::transform_if(thrust::cuda::par.on(stream), 291 | th_z, th_z + count, th_z, 292 | [slope] __device__ (const T& z) { return z / slope; }, 293 | [] __device__ (const T& z) { return z < 0; }); 294 | } 295 | 296 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) { 297 | CHECK_CUDA_INPUT(z); 298 | CHECK_CUDA_INPUT(dz); 299 | 300 | int64_t count = z.numel(); 301 | 302 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] { 303 | leaky_relu_backward_impl(z.data(), dz.data(), slope, count); 304 | })); 305 | } 306 | 307 | template 308 | inline void elu_backward_impl(T *z, T *dz, int64_t count) { 309 | // Create thrust pointers 310 | thrust::device_ptr th_z = thrust::device_pointer_cast(z); 311 | thrust::device_ptr th_dz = thrust::device_pointer_cast(dz); 312 | 313 | auto stream = at::cuda::getCurrentCUDAStream(); 314 | thrust::transform_if(thrust::cuda::par.on(stream), 315 | th_dz, th_dz + count, th_z, th_z, th_dz, 316 | [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); }, 317 | [] __device__ (const T& z) { return z < 0; }); 318 | thrust::transform_if(thrust::cuda::par.on(stream), 319 | th_z, th_z + count, th_z, 320 | [] __device__ (const T& z) { return log1p(z); }, 321 | [] __device__ (const T& z) { return z < 0; }); 322 | } 323 | 324 | void elu_backward_cuda(at::Tensor z, at::Tensor dz) { 325 | CHECK_CUDA_INPUT(z); 326 | CHECK_CUDA_INPUT(dz); 327 | 328 | int64_t count = z.numel(); 329 | 330 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] { 331 | elu_backward_impl(z.data(), dz.data(), count); 332 | })); 333 | } 334 | -------------------------------------------------------------------------------- /modules/src/inplace_abn_cuda_half.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "utils/checks.h" 8 | #include "utils/cuda.cuh" 9 | #include "inplace_abn.h" 10 | 11 | #include 12 | 13 | // Operations for reduce 14 | struct SumOpH { 15 | __device__ SumOpH(const half *t, int c, int s) 16 | : tensor(t), chn(c), sp(s) {} 17 | __device__ __forceinline__ float operator()(int batch, int plane, int n) { 18 | return __half2float(tensor[(batch * chn + plane) * sp + n]); 19 | } 20 | const half *tensor; 21 | const int chn; 22 | const int sp; 23 | }; 24 | 25 | struct VarOpH { 26 | __device__ VarOpH(float m, const half *t, int c, int s) 27 | : mean(m), tensor(t), chn(c), sp(s) {} 28 | __device__ __forceinline__ float operator()(int batch, int plane, int n) { 29 | const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]); 30 | return (t - mean) * (t - mean); 31 | } 32 | const float mean; 33 | const half *tensor; 34 | const int chn; 35 | const int sp; 36 | }; 37 | 38 | struct GradOpH { 39 | __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s) 40 | : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {} 41 | __device__ __forceinline__ Pair operator()(int batch, int plane, int n) { 42 | float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight; 43 | float _dz = __half2float(dz[(batch * chn + plane) * sp + n]); 44 | return Pair(_dz, _y * _dz); 45 | } 46 | const float weight; 47 | const float bias; 48 | const half *z; 49 | const half *dz; 50 | const int chn; 51 | const int sp; 52 | }; 53 | 54 | /*********** 55 | * mean_var 56 | ***********/ 57 | 58 
| __global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) { 59 | int plane = blockIdx.x; 60 | float norm = 1.f / static_cast(num * sp); 61 | 62 | float _mean = reduce(SumOpH(x, chn, sp), plane, num, sp) * norm; 63 | __syncthreads(); 64 | float _var = reduce(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm; 65 | 66 | if (threadIdx.x == 0) { 67 | mean[plane] = _mean; 68 | var[plane] = _var; 69 | } 70 | } 71 | 72 | std::vector mean_var_cuda_h(at::Tensor x) { 73 | CHECK_CUDA_INPUT(x); 74 | 75 | // Extract dimensions 76 | int64_t num, chn, sp; 77 | get_dims(x, num, chn, sp); 78 | 79 | // Prepare output tensors 80 | auto mean = at::empty({chn},x.options().dtype(at::kFloat)); 81 | auto var = at::empty({chn},x.options().dtype(at::kFloat)); 82 | 83 | // Run kernel 84 | dim3 blocks(chn); 85 | dim3 threads(getNumThreads(sp)); 86 | auto stream = at::cuda::getCurrentCUDAStream(); 87 | mean_var_kernel_h<<>>( 88 | reinterpret_cast(x.data()), 89 | mean.data(), 90 | var.data(), 91 | num, chn, sp); 92 | 93 | return {mean, var}; 94 | } 95 | 96 | /********** 97 | * forward 98 | **********/ 99 | 100 | __global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias, 101 | bool affine, float eps, int num, int chn, int sp) { 102 | int plane = blockIdx.x; 103 | 104 | const float _mean = mean[plane]; 105 | const float _var = var[plane]; 106 | const float _weight = affine ? abs(weight[plane]) + eps : 1.f; 107 | const float _bias = affine ? bias[plane] : 0.f; 108 | 109 | const float mul = rsqrt(_var + eps) * _weight; 110 | 111 | for (int batch = 0; batch < num; ++batch) { 112 | for (int n = threadIdx.x; n < sp; n += blockDim.x) { 113 | half *x_ptr = x + (batch * chn + plane) * sp + n; 114 | float _x = __half2float(*x_ptr); 115 | float _y = (_x - _mean) * mul + _bias; 116 | 117 | *x_ptr = __float2half(_y); 118 | } 119 | } 120 | } 121 | 122 | at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 123 | bool affine, float eps) { 124 | CHECK_CUDA_INPUT(x); 125 | CHECK_CUDA_INPUT(mean); 126 | CHECK_CUDA_INPUT(var); 127 | CHECK_CUDA_INPUT(weight); 128 | CHECK_CUDA_INPUT(bias); 129 | 130 | // Extract dimensions 131 | int64_t num, chn, sp; 132 | get_dims(x, num, chn, sp); 133 | 134 | // Run kernel 135 | dim3 blocks(chn); 136 | dim3 threads(getNumThreads(sp)); 137 | auto stream = at::cuda::getCurrentCUDAStream(); 138 | forward_kernel_h<<>>( 139 | reinterpret_cast(x.data()), 140 | mean.data(), 141 | var.data(), 142 | weight.data(), 143 | bias.data(), 144 | affine, eps, num, chn, sp); 145 | 146 | return x; 147 | } 148 | 149 | __global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias, 150 | float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) { 151 | int plane = blockIdx.x; 152 | 153 | float _weight = affine ? abs(weight[plane]) + eps : 1.f; 154 | float _bias = affine ? 
bias[plane] : 0.f; 155 | 156 | Pair res = reduce, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp); 157 | __syncthreads(); 158 | 159 | if (threadIdx.x == 0) { 160 | edz[plane] = res.v1; 161 | eydz[plane] = res.v2; 162 | } 163 | } 164 | 165 | std::vector edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 166 | bool affine, float eps) { 167 | CHECK_CUDA_INPUT(z); 168 | CHECK_CUDA_INPUT(dz); 169 | CHECK_CUDA_INPUT(weight); 170 | CHECK_CUDA_INPUT(bias); 171 | 172 | // Extract dimensions 173 | int64_t num, chn, sp; 174 | get_dims(z, num, chn, sp); 175 | 176 | auto edz = at::empty({chn},z.options().dtype(at::kFloat)); 177 | auto eydz = at::empty({chn},z.options().dtype(at::kFloat)); 178 | 179 | // Run kernel 180 | dim3 blocks(chn); 181 | dim3 threads(getNumThreads(sp)); 182 | auto stream = at::cuda::getCurrentCUDAStream(); 183 | edz_eydz_kernel_h<<>>( 184 | reinterpret_cast(z.data()), 185 | reinterpret_cast(dz.data()), 186 | weight.data(), 187 | bias.data(), 188 | edz.data(), 189 | eydz.data(), 190 | affine, eps, num, chn, sp); 191 | 192 | return {edz, eydz}; 193 | } 194 | 195 | __global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz, 196 | const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) { 197 | int plane = blockIdx.x; 198 | 199 | float _weight = affine ? abs(weight[plane]) + eps : 1.f; 200 | float _bias = affine ? bias[plane] : 0.f; 201 | float _var = var[plane]; 202 | float _edz = edz[plane]; 203 | float _eydz = eydz[plane]; 204 | 205 | float _mul = _weight * rsqrt(_var + eps); 206 | float count = float(num * sp); 207 | 208 | for (int batch = 0; batch < num; ++batch) { 209 | for (int n = threadIdx.x; n < sp; n += blockDim.x) { 210 | float _dz = __half2float(dz[(batch * chn + plane) * sp + n]); 211 | float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight; 212 | 213 | dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul); 214 | } 215 | } 216 | } 217 | 218 | at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 219 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 220 | CHECK_CUDA_INPUT(z); 221 | CHECK_CUDA_INPUT(dz); 222 | CHECK_CUDA_INPUT(var); 223 | CHECK_CUDA_INPUT(weight); 224 | CHECK_CUDA_INPUT(bias); 225 | CHECK_CUDA_INPUT(edz); 226 | CHECK_CUDA_INPUT(eydz); 227 | 228 | // Extract dimensions 229 | int64_t num, chn, sp; 230 | get_dims(z, num, chn, sp); 231 | 232 | auto dx = at::zeros_like(z); 233 | 234 | // Run kernel 235 | dim3 blocks(chn); 236 | dim3 threads(getNumThreads(sp)); 237 | auto stream = at::cuda::getCurrentCUDAStream(); 238 | backward_kernel_h<<>>( 239 | reinterpret_cast(z.data()), 240 | reinterpret_cast(dz.data()), 241 | var.data(), 242 | weight.data(), 243 | bias.data(), 244 | edz.data(), 245 | eydz.data(), 246 | reinterpret_cast(dx.data()), 247 | affine, eps, num, chn, sp); 248 | 249 | return dx; 250 | } 251 | 252 | __global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) { 253 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){ 254 | float _z = __half2float(z[i]); 255 | if (_z < 0) { 256 | dz[i] = __float2half(__half2float(dz[i]) * slope); 257 | z[i] = __float2half(_z / slope); 258 | } 259 | } 260 | } 261 | 262 | void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) { 263 | 
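  /* Half-precision counterpart of leaky_relu_backward_cuda above: instead of
     thrust it launches a simple grid-stride kernel, routing each element
     through float (__half2float / __float2half) for the scale-and-invert
     arithmetic. */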
CHECK_CUDA_INPUT(z); 264 | CHECK_CUDA_INPUT(dz); 265 | 266 | int64_t count = z.numel(); 267 | dim3 threads(getNumThreads(count)); 268 | dim3 blocks = (count + threads.x - 1) / threads.x; 269 | auto stream = at::cuda::getCurrentCUDAStream(); 270 | leaky_relu_backward_impl_h<<>>( 271 | reinterpret_cast(z.data()), 272 | reinterpret_cast(dz.data()), 273 | slope, count); 274 | } 275 | 276 | -------------------------------------------------------------------------------- /modules/src/utils/checks.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT 6 | #ifndef AT_CHECK 7 | #define AT_CHECK AT_ASSERT 8 | #endif 9 | 10 | #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor") 12 | #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous") 13 | 14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 15 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) -------------------------------------------------------------------------------- /modules/src/utils/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | /* 6 | * Functions to share code between CPU and GPU 7 | */ 8 | 9 | #ifdef __CUDACC__ 10 | // CUDA versions 11 | 12 | #define HOST_DEVICE __host__ __device__ 13 | #define INLINE_HOST_DEVICE __host__ __device__ inline 14 | #define FLOOR(x) floor(x) 15 | 16 | #if __CUDA_ARCH__ >= 600 17 | // Recent compute capabilities have block-level atomicAdd for all data types, so we use that 18 | #define ACCUM(x,y) atomicAdd_block(&(x),(y)) 19 | #else 20 | // Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float 21 | // and use the known atomicCAS-based implementation for double 22 | template 23 | __device__ inline data_t atomic_add(data_t *address, data_t val) { 24 | return atomicAdd(address, val); 25 | } 26 | 27 | template<> 28 | __device__ inline double atomic_add(double *address, double val) { 29 | unsigned long long int* address_as_ull = (unsigned long long int*)address; 30 | unsigned long long int old = *address_as_ull, assumed; 31 | do { 32 | assumed = old; 33 | old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed))); 34 | } while (assumed != old); 35 | return __longlong_as_double(old); 36 | } 37 | 38 | #define ACCUM(x,y) atomic_add(&(x),(y)) 39 | #endif // #if __CUDA_ARCH__ >= 600 40 | 41 | #else 42 | // CPU versions 43 | 44 | #define HOST_DEVICE 45 | #define INLINE_HOST_DEVICE inline 46 | #define FLOOR(x) std::floor(x) 47 | #define ACCUM(x,y) (x) += (y) 48 | 49 | #endif // #ifdef __CUDACC__ -------------------------------------------------------------------------------- /modules/src/utils/cuda.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * General settings and functions 5 | */ 6 | const int WARP_SIZE = 32; 7 | const int MAX_BLOCK_SIZE = 1024; 8 | 9 | static int getNumThreads(int nElem) { 10 | int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE}; 11 | for (int i = 0; i < 6; ++i) { 12 | if (nElem <= threadSizes[i]) { 13 | return threadSizes[i]; 14 | } 15 | } 16 | return MAX_BLOCK_SIZE; 17 | } 18 | 19 | /* 20 | * Reduction 
utilities 21 | */ 22 | template 23 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, 24 | unsigned int mask = 0xffffffff) { 25 | #if CUDART_VERSION >= 9000 26 | return __shfl_xor_sync(mask, value, laneMask, width); 27 | #else 28 | return __shfl_xor(value, laneMask, width); 29 | #endif 30 | } 31 | 32 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 33 | 34 | template 35 | struct Pair { 36 | T v1, v2; 37 | __device__ Pair() {} 38 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {} 39 | __device__ Pair(T v) : v1(v), v2(v) {} 40 | __device__ Pair(int v) : v1(v), v2(v) {} 41 | __device__ Pair &operator+=(const Pair &a) { 42 | v1 += a.v1; 43 | v2 += a.v2; 44 | return *this; 45 | } 46 | }; 47 | 48 | template 49 | static __device__ __forceinline__ T warpSum(T val) { 50 | #if __CUDA_ARCH__ >= 300 51 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 52 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 53 | } 54 | #else 55 | __shared__ T values[MAX_BLOCK_SIZE]; 56 | values[threadIdx.x] = val; 57 | __threadfence_block(); 58 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 59 | for (int i = 1; i < WARP_SIZE; i++) { 60 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 61 | } 62 | #endif 63 | return val; 64 | } 65 | 66 | template 67 | static __device__ __forceinline__ Pair warpSum(Pair value) { 68 | value.v1 = warpSum(value.v1); 69 | value.v2 = warpSum(value.v2); 70 | return value; 71 | } -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import logging 7 | 8 | logger = logging.getLogger() 9 | 10 | class Optimizer(object): 11 | def __init__(self, 12 | model, 13 | lr0, 14 | momentum, 15 | wd, 16 | warmup_steps, 17 | warmup_start_lr, 18 | max_iter, 19 | power, 20 | *args, **kwargs): 21 | self.warmup_steps = warmup_steps 22 | self.warmup_start_lr = warmup_start_lr 23 | self.lr0 = lr0 24 | self.lr = self.lr0 25 | self.max_iter = float(max_iter) 26 | self.power = power 27 | self.it = 0 28 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params() 29 | param_list = [ 30 | {'params': wd_params}, 31 | {'params': nowd_params, 'weight_decay': 0}, 32 | {'params': lr_mul_wd_params, 'lr_mul': True}, 33 | {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True}] 34 | self.optim = torch.optim.SGD( 35 | param_list, 36 | lr = lr0, 37 | momentum = momentum, 38 | weight_decay = wd) 39 | self.warmup_factor = (self.lr0/self.warmup_start_lr)**(1./self.warmup_steps) 40 | 41 | 42 | def get_lr(self): 43 | if self.it <= self.warmup_steps: 44 | lr = self.warmup_start_lr*(self.warmup_factor**self.it) 45 | else: 46 | factor = (1-(self.it-self.warmup_steps)/(self.max_iter-self.warmup_steps))**self.power 47 | lr = self.lr0 * factor 48 | return lr 49 | 50 | 51 | def step(self): 52 | self.lr = self.get_lr() 53 | for pg in self.optim.param_groups: 54 | if pg.get('lr_mul', False): 55 | pg['lr'] = self.lr * 10 56 | else: 57 | pg['lr'] = self.lr 58 | if self.optim.defaults.get('lr_mul', False): 59 | self.optim.defaults['lr'] = self.lr * 10 60 | else: 61 | self.optim.defaults['lr'] = self.lr 62 | self.it += 1 63 | self.optim.step() 64 | if self.it == self.warmup_steps+2: 65 | logger.info('==> warmup done, start to implement poly lr strategy') 66 | 67 | def zero_grad(self): 68 | self.optim.zero_grad() 69 | 70 | 
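get_lr above ramps the learning rate exponentially from warmup_start_lr to lr0 over warmup_steps, then decays it polynomially to zero at max_iter. A standalone sketch of the same schedule for quick inspection; the defaults mirror the values train.py passes in:

def lr_at(it, lr0=1e-2, warmup_steps=1000, warmup_start_lr=1e-5,
          max_iter=80000, power=0.9):
    warmup_factor = (lr0 / warmup_start_lr) ** (1. / warmup_steps)
    if it <= warmup_steps:
        return warmup_start_lr * warmup_factor ** it  # exponential warmup ramp
    # polynomial ("poly") decay to zero at max_iter
    return lr0 * (1 - (it - warmup_steps) / (max_iter - warmup_steps)) ** power

for it in (0, 500, 1000, 10000, 40000, 79999):
    print(it, '{:.2e}'.format(lr_at(it)))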
-------------------------------------------------------------------------------- /prepropess_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import os.path as osp 5 | import os 6 | import cv2 7 | from .transform import * 8 | from PIL import Image 9 | 10 | face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img' 11 | face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno' 12 | mask_path = '/home/zll/data/CelebAMask-HQ/mask' 13 | counter = 0 14 | total = 0 15 | for i in range(15): 16 | 17 | atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', 18 | 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] 19 | 20 | for j in range(i * 2000, (i + 1) * 2000): 21 | 22 | mask = np.zeros((512, 512)) 23 | 24 | for l, att in enumerate(atts, 1): 25 | total += 1 26 | file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png']) 27 | path = osp.join(face_sep_mask, str(i), file_name) 28 | 29 | if os.path.exists(path): 30 | counter += 1 31 | sep_mask = np.array(Image.open(path).convert('P')) 32 | # print(np.unique(sep_mask)) 33 | 34 | mask[sep_mask == 225] = l 35 | cv2.imwrite('{}/{}.png'.format(mask_path, j), mask) 36 | print(j) 37 | 38 | print(counter, total) 39 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | 
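        # Feature map scales: conv1 + maxpool bring the input to 1/4 resolution;
        # layer1 keeps 1/4, layer2 -> 1/8, layer3 -> 1/16, layer4 -> 1/32.
        # forward() returns the 1/8, 1/16 and 1/32 maps, which BiSeNet's
        # context/spatial paths consume (see model.py).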
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | #self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from .logger import setup_logger 5 | from .model import BiSeNet 6 | 7 | import torch 8 | 9 | import os 10 | import os.path as osp 11 | import numpy as np 12 | from PIL import Image 13 | import torchvision.transforms as transforms 14 | import cv2 15 | 16 | def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'): 17 | # Colors for all 20 parts 18 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], 19 | [255, 0, 85], [255, 0, 170], 20 | [0, 255, 0], [85, 255, 0], [170, 255, 0], 21 | [0, 255, 85], [0, 255, 170], 22 | [0, 0, 255], [85, 0, 255], [170, 0, 255], 23 | [0, 85, 255], [0, 170, 255], 24 | [255, 255, 0], [255, 255, 85], [255, 255, 170], 25 | [255, 0, 255], [255, 85, 255], [255, 170, 255], 26 | [0, 255, 255], [85, 255, 255], [170, 255, 255]] 27 | 28 | im = np.array(im) 29 | vis_im = im.copy().astype(np.uint8) 30 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8) 31 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) 32 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 33 | 34 | num_of_class = np.max(vis_parsing_anno) 35 | 36 | for pi in range(1, num_of_class + 1): 37 | index = np.where(vis_parsing_anno == pi) 38 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] 39 | 40 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) 41 | # print(vis_parsing_anno_color.shape, vis_im.shape) 42 | vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0) 43 | 44 | # Save result or not 45 | if save_im: 46 | cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno) 47 | cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) 48 | 49 | # return vis_im 50 | 51 | def 
evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): 52 | 53 | if not os.path.exists(respth): 54 | os.makedirs(respth) 55 | 56 | n_classes = 19 57 | net = BiSeNet(n_classes=n_classes) 58 | net.cuda() 59 | save_pth = osp.join('res/cp', cp) 60 | net.load_state_dict(torch.load(save_pth)) 61 | net.eval() 62 | 63 | to_tensor = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 66 | ]) 67 | with torch.no_grad(): 68 | for image_path in os.listdir(dspth): 69 | img = Image.open(osp.join(dspth, image_path)) 70 | image = img.resize((512, 512), Image.BILINEAR) 71 | img = to_tensor(image) 72 | img = torch.unsqueeze(img, 0) 73 | img = img.cuda() 74 | out = net(img)[0] 75 | parsing = out.squeeze(0).cpu().numpy().argmax(0) 76 | # print(parsing) 77 | print(np.unique(parsing)) 78 | 79 | vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path)) 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | if __name__ == "__main__": 88 | evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='79999_iter.pth') 89 | 90 | 91 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from .logger import setup_logger 5 | from .model import BiSeNet 6 | from .face_dataset import FaceMask 7 | from .loss import OhemCELoss 8 | from .evaluate import evaluate 9 | from .optimizer import Optimizer 10 | import cv2 11 | import numpy as np 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.utils.data import DataLoader 16 | import torch.nn.functional as F 17 | import torch.distributed as dist 18 | 19 | import os 20 | import os.path as osp 21 | import logging 22 | import time 23 | import datetime 24 | import argparse 25 | 26 | 27 | respth = './res' 28 | if not osp.exists(respth): 29 | os.makedirs(respth) 30 | logger = logging.getLogger() 31 | 32 | 33 | def parse_args(): 34 | parse = argparse.ArgumentParser() 35 | parse.add_argument( 36 | '--local_rank', 37 | dest = 'local_rank', 38 | type = int, 39 | default = -1, 40 | ) 41 | return parse.parse_args() 42 | 43 | 44 | def train(): 45 | args = parse_args() 46 | torch.cuda.set_device(args.local_rank) 47 | dist.init_process_group( 48 | backend = 'nccl', 49 | init_method = 'tcp://127.0.0.1:33241', 50 | world_size = torch.cuda.device_count(), 51 | rank=args.local_rank 52 | ) 53 | setup_logger(respth) 54 | 55 | # dataset 56 | n_classes = 19 57 | n_img_per_gpu = 16 58 | n_workers = 8 59 | cropsize = [448, 448] 60 | data_root = '/home/zll/data/CelebAMask-HQ/' 61 | 62 | ds = FaceMask(data_root, cropsize=cropsize, mode='train') 63 | sampler = torch.utils.data.distributed.DistributedSampler(ds) 64 | dl = DataLoader(ds, 65 | batch_size = n_img_per_gpu, 66 | shuffle = False, 67 | sampler = sampler, 68 | num_workers = n_workers, 69 | pin_memory = True, 70 | drop_last = True) 71 | 72 | # model 73 | ignore_idx = -100 74 | net = BiSeNet(n_classes=n_classes) 75 | net.cuda() 76 | net.train() 77 | net = nn.parallel.DistributedDataParallel(net, 78 | device_ids = [args.local_rank, ], 79 | output_device = args.local_rank 80 | ) 81 | score_thres = 0.7 82 | n_min = n_img_per_gpu * cropsize[0] * cropsize[1]//16 83 | LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) 84 | Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) 85 | Loss3 = 
OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) 86 | 87 | ## optimizer 88 | momentum = 0.9 89 | weight_decay = 5e-4 90 | lr_start = 1e-2 91 | max_iter = 80000 92 | power = 0.9 93 | warmup_steps = 1000 94 | warmup_start_lr = 1e-5 95 | optim = Optimizer( 96 | model = net.module, 97 | lr0 = lr_start, 98 | momentum = momentum, 99 | wd = weight_decay, 100 | warmup_steps = warmup_steps, 101 | warmup_start_lr = warmup_start_lr, 102 | max_iter = max_iter, 103 | power = power) 104 | 105 | ## train loop 106 | msg_iter = 50 107 | loss_avg = [] 108 | st = glob_st = time.time() 109 | diter = iter(dl) 110 | epoch = 0 111 | for it in range(max_iter): 112 | try: 113 | im, lb = next(diter) 114 | if not im.size()[0] == n_img_per_gpu: 115 | raise StopIteration 116 | except StopIteration: 117 | epoch += 1 118 | sampler.set_epoch(epoch) 119 | diter = iter(dl) 120 | im, lb = next(diter) 121 | im = im.cuda() 122 | lb = lb.cuda() 123 | H, W = im.size()[2:] 124 | lb = torch.squeeze(lb, 1) 125 | 126 | optim.zero_grad() 127 | out, out16, out32 = net(im) 128 | lossp = LossP(out, lb) 129 | loss2 = Loss2(out16, lb) 130 | loss3 = Loss3(out32, lb) 131 | loss = lossp + loss2 + loss3 132 | loss.backward() 133 | optim.step() 134 | 135 | loss_avg.append(loss.item()) 136 | 137 | # print training log message 138 | if (it+1) % msg_iter == 0: 139 | loss_avg = sum(loss_avg) / len(loss_avg) 140 | lr = optim.lr 141 | ed = time.time() 142 | t_intv, glob_t_intv = ed - st, ed - glob_st 143 | eta = int((max_iter - it) * (glob_t_intv / it)) 144 | eta = str(datetime.timedelta(seconds=eta)) 145 | msg = ', '.join([ 146 | 'it: {it}/{max_it}', 147 | 'lr: {lr:4f}', 148 | 'loss: {loss:.4f}', 149 | 'eta: {eta}', 150 | 'time: {time:.4f}', 151 | ]).format( 152 | it = it+1, 153 | max_it = max_iter, 154 | lr = lr, 155 | loss = loss_avg, 156 | time = t_intv, 157 | eta = eta 158 | ) 159 | logger.info(msg) 160 | loss_avg = [] 161 | st = ed 162 | if dist.get_rank() == 0: 163 | if (it+1) % 5000 == 0: 164 | state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict() 165 | if dist.get_rank() == 0: 166 | torch.save(state, './res/cp/{}_iter.pth'.format(it)) 167 | evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it)) 168 | 169 | # dump the final model 170 | save_pth = osp.join(respth, 'model_final_diss.pth') 171 | # net.cpu() 172 | state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict() 173 | if dist.get_rank() == 0: 174 | torch.save(state, save_pth) 175 | logger.info('training done, model saved to: {}'.format(save_pth)) 176 | 177 | 178 | if __name__ == "__main__": 179 | train() 180 | -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | from PIL import Image 6 | import PIL.ImageEnhance as ImageEnhance 7 | import random 8 | import numpy as np 9 | 10 | class RandomCrop(object): 11 | def __init__(self, size, *args, **kwargs): 12 | self.size = size 13 | 14 | def __call__(self, im_lb): 15 | im = im_lb['im'] 16 | lb = im_lb['lb'] 17 | assert im.size == lb.size 18 | W, H = self.size 19 | w, h = im.size 20 | 21 | if (W, H) == (w, h): return dict(im=im, lb=lb) 22 | if w < W or h < H: 23 | scale = float(W) / w if w < h else float(H) / h 24 | w, h = int(scale * w + 1), int(scale * h + 1) 25 | im = im.resize((w, h), Image.BILINEAR) 26 | lb = lb.resize((w, h), Image.NEAREST) 27 | sw, sh = 
random.random() * (w - W), random.random() * (h - H) 28 | crop = int(sw), int(sh), int(sw) + W, int(sh) + H 29 | return dict( 30 | im = im.crop(crop), 31 | lb = lb.crop(crop) 32 | ) 33 | 34 | 35 | class HorizontalFlip(object): 36 | def __init__(self, p=0.5, *args, **kwargs): 37 | self.p = p 38 | 39 | def __call__(self, im_lb): 40 | if random.random() > self.p: 41 | return im_lb 42 | else: 43 | im = im_lb['im'] 44 | lb = im_lb['lb'] 45 | 46 | # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 47 | # 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat'] 48 | 49 | flip_lb = np.array(lb) 50 | flip_lb[lb == 2] = 3 51 | flip_lb[lb == 3] = 2 52 | flip_lb[lb == 4] = 5 53 | flip_lb[lb == 5] = 4 54 | flip_lb[lb == 7] = 8 55 | flip_lb[lb == 8] = 7 56 | flip_lb = Image.fromarray(flip_lb) 57 | return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT), 58 | lb = flip_lb.transpose(Image.FLIP_LEFT_RIGHT), 59 | ) 60 | 61 | 62 | class RandomScale(object): 63 | def __init__(self, scales=(1, ), *args, **kwargs): 64 | self.scales = scales 65 | 66 | def __call__(self, im_lb): 67 | im = im_lb['im'] 68 | lb = im_lb['lb'] 69 | W, H = im.size 70 | scale = random.choice(self.scales) 71 | w, h = int(W * scale), int(H * scale) 72 | return dict(im = im.resize((w, h), Image.BILINEAR), 73 | lb = lb.resize((w, h), Image.NEAREST), 74 | ) 75 | 76 | 77 | class ColorJitter(object): 78 | def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs): 79 | if not brightness is None and brightness>0: 80 | self.brightness = [max(1-brightness, 0), 1+brightness] 81 | if not contrast is None and contrast>0: 82 | self.contrast = [max(1-contrast, 0), 1+contrast] 83 | if not saturation is None and saturation>0: 84 | self.saturation = [max(1-saturation, 0), 1+saturation] 85 | 86 | def __call__(self, im_lb): 87 | im = im_lb['im'] 88 | lb = im_lb['lb'] 89 | r_brightness = random.uniform(self.brightness[0], self.brightness[1]) 90 | r_contrast = random.uniform(self.contrast[0], self.contrast[1]) 91 | r_saturation = random.uniform(self.saturation[0], self.saturation[1]) 92 | im = ImageEnhance.Brightness(im).enhance(r_brightness) 93 | im = ImageEnhance.Contrast(im).enhance(r_contrast) 94 | im = ImageEnhance.Color(im).enhance(r_saturation) 95 | return dict(im = im, 96 | lb = lb, 97 | ) 98 | 99 | 100 | class MultiScale(object): 101 | def __init__(self, scales): 102 | self.scales = scales 103 | 104 | def __call__(self, img): 105 | W, H = img.size 106 | sizes = [(int(W*ratio), int(H*ratio)) for ratio in self.scales] 107 | imgs = [] 108 | [imgs.append(img.resize(size, Image.BILINEAR)) for size in sizes] 109 | return imgs 110 | 111 | 112 | class Compose(object): 113 | def __init__(self, do_list): 114 | self.do_list = do_list 115 | 116 | def __call__(self, im_lb): 117 | for comp in self.do_list: 118 | im_lb = comp(im_lb) 119 | return im_lb 120 | 121 | 122 | 123 | 124 | if __name__ == '__main__': 125 | flip = HorizontalFlip(p = 1) 126 | crop = RandomCrop((321, 321)) 127 | rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0)) 128 | img = Image.open('data/img.jpg') 129 | lb = Image.open('data/label.png') 130 | --------------------------------------------------------------------------------
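The __main__ block above stops right after loading the two sample files. A short continuation sketch showing how these transforms are meant to be chained on the {'im': ..., 'lb': ...} dict; this particular Compose pipeline is illustrative (face_dataset.py assembles its own), and the file paths are the placeholders already used above:

    trans = Compose([
        ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        HorizontalFlip(p=0.5),
        RandomScale((0.75, 1.0, 1.25)),
        RandomCrop((321, 321)),
    ])
    im_lb = trans(dict(im=img, lb=lb))
    print(im_lb['im'].size, im_lb['lb'].size)  # both (321, 321)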