├── README.md ├── Requirements ├── __pycache__ ├── engine.cpython-36.pyc ├── evaluate.cpython-36.pyc ├── evaluate.cpython-37.pyc ├── transform_LIP.cpython-36.pyc └── transform_LIP.cpython-37.pyc ├── dataset ├── .DS_Store ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── data_augmentation.cpython-36.pyc │ ├── data_augmentation.cpython-37.pyc │ ├── dataset_LIP.cpython-36.pyc │ ├── datasets.cpython-36.pyc │ ├── datasets.cpython-37.pyc │ ├── joint_transformation.cpython-36.pyc │ ├── joint_transformation.cpython-37.pyc │ ├── target_generation.cpython-36.pyc │ ├── target_generation.cpython-37.pyc │ └── voc.cpython-36.pyc ├── datasets.py ├── list │ └── .DS_Store └── target_generation.py ├── engine.py ├── evaluate.py ├── evaluate_multi.py ├── networks ├── .DS_Store ├── CDGNet.py └── __pycache__ │ ├── CDGNet.cpython-36.pyc │ ├── CE2P.cpython-36.pyc │ ├── CE2P.cpython-37.pyc │ ├── CE2P.cpython-38.pyc │ ├── CE2PHybrid.cpython-36.pyc │ └── CE2P_test.cpython-36.pyc ├── run.sh ├── run_evaluate.sh ├── run_evaluate_multiScale.sh ├── train.py └── utils ├── .DS_Store ├── ImgTransforms.py ├── __init__.py ├── __pycache__ ├── ImgTransforms.cpython-36.pyc ├── ImgTransforms.cpython-37.pyc ├── OCRAttention.cpython-36.pyc ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── attention.cpython-36.pyc ├── attention.cpython-37.pyc ├── attention.cpython-38.pyc ├── criterion.cpython-36.pyc ├── criterion.cpython-37.pyc ├── distributed.cpython-36.pyc ├── distributed.cpython-37.pyc ├── encoding.cpython-36.pyc ├── encoding.cpython-37.pyc ├── logger.cpython-36.pyc ├── loss.cpython-36.pyc ├── loss.cpython-37.pyc ├── lovasz_losses.cpython-36.pyc ├── lovasz_losses.cpython-37.pyc ├── miou.cpython-36.pyc ├── miou.cpython-37.pyc ├── model_store.cpython-36.pyc ├── pyt_utils.cpython-36.pyc ├── transforms.cpython-36.pyc ├── transforms.cpython-37.pyc ├── utils.cpython-36.pyc └── utils.cpython-37.pyc ├── attention.py ├── criterion.py ├── distributed.py ├── encoding.py ├── logger.py ├── loss.py ├── lovasz_losses.py ├── miou.py ├── pyt_utils.py ├── transforms.py ├── utils.py └── writejson.py /README.md: -------------------------------------------------------------------------------- 1 | This repo is a PyTorch implementation of our paper CDGNet: Class Distribution Guided Network for Human Parsing, accepted by CVPR 2022 (https://openaccess.thecvf.com/content/CVPR2022/html/Liu_CDGNet_Class_Distribution_Guided_Network_for_Human_Parsing_CVPR_2022_paper.html). We accumulate the original human parsing ground truth in the horizontal and vertical directions to obtain class distribution labels that guide the network to exploit the intrinsic distribution rule of each class. The generated labels act as an additional supervision signal that improves parsing performance. 2 | 3 | Requirements 4 | PyTorch 1.9.0 5 | 6 | Python 3.7 7 | 8 | Implementation 9 | 10 | Dataset 11 | Please download the LIP dataset and organize it with the following structure: 12 | ''' 13 | |-- LIP 14 | |-- images_labels 15 | |-- train_images 16 | |-- train_segmentations 17 | |-- val_images 18 | |-- val_segmentations 19 | |-- train_id.txt 20 | |-- val_id.txt 21 | ''' 22 | Please download the ImageNet-pretrained ResNet-101 weights from [baidu drive](https://pan.baidu.com/s/1NoxI_JetjSVa7uqgVSKdPw) or [Google drive](https://drive.google.com/open?id=1rzLU-wK6rEorCNJfwrmIu5hY2wRMyKTK) and put them into the dataset folder. 
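For reference, the class distribution labels mentioned above are generated on the fly from each ground-truth parsing map by `generate_hw_gt` in `dataset/target_generation.py` and are returned with every training sample in `dataset/datasets.py`. The snippet below is a minimal sketch of that accumulation step; the standalone helper name `class_distribution_labels` is only for illustration and is not part of the released code.

```python
import torch

def class_distribution_labels(parsing_anno, num_classes=20):
    """Accumulate a (H, W) parsing ground truth along rows and columns into
    per-class horizontal/vertical distribution labels (cf. generate_hw_gt)."""
    target = torch.as_tensor(parsing_anno).long().clone()
    target[target == 255] = 0                      # ignored pixels -> background
    h, w = target.shape
    onehot = torch.zeros(num_classes, h, w)
    onehot.scatter_(0, target.unsqueeze(0), 1)     # one-hot per pixel: (C, H, W)
    hgt = onehot.sum(dim=2)                        # sum over width  -> (C, H)
    wgt = onehot.sum(dim=1)                        # sum over height -> (C, W)
    hgt[0] = 0                                     # background gives no guidance
    wgt[0] = 0
    hgt = hgt / (hgt.max(dim=1, keepdim=True)[0] + 1e-5)   # normalize each class to [0, 1]
    wgt = wgt / (wgt.max(dim=1, keepdim=True)[0] + 1e-5)
    hwgt = torch.matmul(hgt.t(), wgt)              # joint (H, W) distribution map
    hwgt = hwgt / (hwgt.max() + 1e-5)
    return hgt, wgt, hwgt
```

These `hgt`/`wgt` labels (and the joint map `hwgt`) provide the extra supervision described above for the horizontal and vertical class-distribution predictions (`fea_h1`, `fea_w1`) produced by the attention module in `networks/CDGNet.py`.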
23 | 24 | ### Training and Evaluation 25 | ```bash 26 | ./run.sh 27 | ``` 28 | Please download the trained model for the LIP dataset from [baidu drive](https://pan.baidu.com/s/1WyifPOOE0SqIzCje-d1kzA?pwd=81la) and put it into the snapshots folder. 29 | 30 | Run ./run_evaluate.sh for single-scale evaluation, or ./run_evaluate_multiScale.sh for multi-scale evaluation. 31 | 32 | The provided 'LIP_epoch_149.pth' achieves a parsing result of 60.30 mIoU on the LIP validation set. 33 | 34 | Citation: 35 | 36 | If you find our work useful for your research, please cite: 37 | 38 | @InProceedings{Liu_2022_CVPR, 39 | author = {Liu, Kunliang and Choi, Ouk and Wang, Jianming and Hwang, Wonjun}, 40 | title = {CDGNet: Class Distribution Guided Network for Human Parsing}, 41 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 42 | month = {June}, 43 | year = {2022}, 44 | pages = {4473-4482} 45 | } 46 | 47 | Acknowledgement: 48 | 49 | We thank Ziwei Zhang and Tao Ruan for sharing their code. 50 | -------------------------------------------------------------------------------- /Requirements: -------------------------------------------------------------------------------- 1 | Requirements 2 | 3 | PyTorch 1.9.0 4 | torchvision 0.11.0 5 | scipy 1.5.2 6 | cudatoolkit 11.3.1 7 | tensorboardX 2.2 8 | Python 3.7 9 | -------------------------------------------------------------------------------- /__pycache__/engine.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/__pycache__/engine.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/evaluate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/__pycache__/evaluate.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/evaluate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/__pycache__/evaluate.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/transform_LIP.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/__pycache__/transform_LIP.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/transform_LIP.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/__pycache__/transform_LIP.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/.DS_Store -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/data_augmentation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/data_augmentation.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/data_augmentation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/data_augmentation.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/dataset_LIP.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/dataset_LIP.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/datasets.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/datasets.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/datasets.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/joint_transformation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/joint_transformation.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/joint_transformation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/joint_transformation.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/target_generation.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/target_generation.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/target_generation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/target_generation.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/voc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/__pycache__/voc.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import torch 5 | import cv2 6 | import json 7 | from torch.utils import data 8 | from dataset.target_generation import generate_edge, generate_hw_gt 9 | from utils.transforms import get_affine_transform 10 | from utils.ImgTransforms import AugmentationBlock, autoaug_imagenet_policies 11 | 12 | # statisticSeg=[ 30462,7026,21054,2404,1660,23165,1201,8182,2178,16224, 13 | # 455,518,634,24418,18539,20033,4763,4832,8126,8166] 14 | class LIPDataSet(data.Dataset): 15 | def __init__(self, root, dataset, crop_size=[473, 473], scale_factor=0.25, 16 | rotation_factor=30, ignore_label=255, transform=None): 17 | """ 18 | :rtype: 19 | """ 20 | self.root = root 21 | self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0] 22 | self.crop_size = np.asarray(crop_size) 23 | self.ignore_label = ignore_label 24 | self.scale_factor = scale_factor 25 | self.rotation_factor = rotation_factor 26 | self.flip_prob = 0.5 27 | self.flip_pairs = [[0, 5], [1, 4], [2, 3], [11, 14], [12, 13], [10, 15]] 28 | self.transform = transform 29 | self.dataset = dataset 30 | # self.statSeg = np.array( statisticSeg, dtype ='float') 31 | # self.statSeg = self.statSeg/30462 32 | 33 | list_path = os.path.join(self.root, self.dataset + '_id.txt') 34 | 35 | self.im_list = [i_id.strip() for i_id in open(list_path)] 36 | # if dataset != 'val': 37 | # im_list_2 = [] 38 | # for i in range(len(self.im_list)): 39 | # if i % 5 ==0: 40 | # im_list_2.append(self.im_list[i]) 41 | # self.im_list = im_list_2 42 | self.number_samples = len(self.im_list) 43 | #================================================================================ 44 | self.augBlock = AugmentationBlock( autoaug_imagenet_policies ) 45 | #================================================================================ 46 | def __len__(self): 47 | return self.number_samples 48 | 49 | def _box2cs(self, box): 50 | x, y, w, h = box[:4] 51 | return self._xywh2cs(x, y, w, h) 52 | 53 | def _xywh2cs(self, x, y, w, h): 54 | center = np.zeros((2), dtype=np.float32) 55 | center[0] = x + w * 0.5 56 | center[1] = y + h * 0.5 57 | if w > self.aspect_ratio * h: 58 | h = w * 1.0 / self.aspect_ratio 59 | elif w < self.aspect_ratio * h: 60 | w = h * self.aspect_ratio 61 | scale = np.array([w * 1.0, h * 1.0], dtype=np.float32) 62 | 63 | return center, scale 64 | 65 | def __getitem__(self, index): 66 | # Load training image 67 | im_name = self.im_list[index] 68 | 69 | im_path = os.path.join(self.root, self.dataset + '_images', im_name + 
'.jpg') 70 | parsing_anno_path = os.path.join(self.root, self.dataset + '_segmentations', im_name + '.png') 71 | 72 | im = cv2.imread(im_path, cv2.IMREAD_COLOR) 73 | #================================================= 74 | if self.dataset != 'val': 75 | im = self.augBlock( im ) 76 | #================================================= 77 | h, w, _ = im.shape 78 | parsing_anno = np.zeros((h, w), dtype=np.long) 79 | 80 | # Get center and scale 81 | center, s = self._box2cs([0, 0, w - 1, h - 1]) 82 | r = 0 83 | 84 | if self.dataset != 'test': 85 | parsing_anno = cv2.imread(parsing_anno_path, cv2.IMREAD_GRAYSCALE) 86 | 87 | if self.dataset == 'train' or self.dataset == 'trainval': 88 | 89 | sf = self.scale_factor 90 | rf = self.rotation_factor 91 | s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf) 92 | r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \ 93 | if random.random() <= 0.6 else 0 94 | 95 | if random.random() <= self.flip_prob: 96 | im = im[:, ::-1, :] 97 | parsing_anno = parsing_anno[:, ::-1] 98 | 99 | center[0] = im.shape[1] - center[0] - 1 100 | right_idx = [15, 17, 19] 101 | left_idx = [14, 16, 18] 102 | for i in range(0, 3): 103 | right_pos = np.where(parsing_anno == right_idx[i]) 104 | left_pos = np.where(parsing_anno == left_idx[i]) 105 | parsing_anno[right_pos[0], right_pos[1]] = left_idx[i] 106 | parsing_anno[left_pos[0], left_pos[1]] = right_idx[i] 107 | 108 | trans = get_affine_transform(center, s, r, self.crop_size) 109 | input = cv2.warpAffine( 110 | im, 111 | trans, 112 | (int(self.crop_size[1]), int(self.crop_size[0])), 113 | flags=cv2.INTER_LINEAR, 114 | borderMode=cv2.BORDER_CONSTANT, 115 | borderValue=(0, 0, 0)) 116 | 117 | if self.transform: 118 | input = self.transform(input) 119 | 120 | meta = { 121 | 'name': im_name, 122 | 'center': center, 123 | 'height': h, 124 | 'width': w, 125 | 'scale': s, 126 | 'rotation': r 127 | } 128 | 129 | if self.dataset != 'train': 130 | return input, meta 131 | else: 132 | label_parsing = cv2.warpAffine( 133 | parsing_anno, 134 | trans, 135 | (int(self.crop_size[1]), int(self.crop_size[0])), 136 | flags=cv2.INTER_NEAREST, 137 | borderMode=cv2.BORDER_CONSTANT, 138 | borderValue=(255)) 139 | 140 | # label_edge = generate_edge(label_parsing) 141 | hgt, wgt, hwgt = generate_hw_gt(label_parsing) 142 | label_parsing = torch.from_numpy(label_parsing) 143 | # label_edge = torch.from_numpy(label_edge) 144 | 145 | return input, label_parsing, hgt,wgt,hwgt, meta 146 | 147 | class LIPDataValSet(data.Dataset): 148 | def __init__(self, root, dataset='val', crop_size=[384, 384], transform=None, flip=False): 149 | self.root = root 150 | self.crop_size = crop_size 151 | self.transform = transform 152 | self.flip = flip 153 | self.dataset = dataset 154 | self.root = root 155 | self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0] 156 | self.crop_size = np.asarray(crop_size) 157 | 158 | list_path = os.path.join(self.root, self.dataset + '_id.txt') 159 | val_list = [i_id.strip() for i_id in open(list_path)] 160 | 161 | self.val_list = val_list 162 | self.number_samples = len(self.val_list) 163 | 164 | def __len__(self): 165 | return len(self.val_list) 166 | 167 | def _box2cs(self, box): 168 | x, y, w, h = box[:4] 169 | return self._xywh2cs(x, y, w, h) 170 | 171 | def _xywh2cs(self, x, y, w, h): 172 | center = np.zeros((2), dtype=np.float32) 173 | center[0] = x + w * 0.5 174 | center[1] = y + h * 0.5 175 | if w > self.aspect_ratio * h: 176 | h = w * 1.0 / self.aspect_ratio 177 | elif w < self.aspect_ratio * h: 178 | w = h * self.aspect_ratio 
179 | scale = np.array([w * 1.0, h * 1.0], dtype=np.float32) 180 | 181 | return center, scale 182 | 183 | def __getitem__(self, index): 184 | val_item = self.val_list[index] 185 | # Load training image 186 | im_path = os.path.join(self.root, self.dataset + '_images', val_item + '.jpg') 187 | im = cv2.imread(im_path, cv2.IMREAD_COLOR) 188 | h, w, _ = im.shape 189 | # Get person center and scale 190 | person_center, s = self._box2cs([0, 0, w - 1, h - 1]) 191 | r = 0 192 | trans = get_affine_transform(person_center, s, r, self.crop_size) 193 | input = cv2.warpAffine( 194 | im, 195 | trans, 196 | (int(self.crop_size[1]), int(self.crop_size[0])), 197 | flags=cv2.INTER_LINEAR, 198 | borderMode=cv2.BORDER_CONSTANT, 199 | borderValue=(0, 0, 0)) 200 | input = self.transform(input) 201 | flip_input = input.flip(dims=[-1]) 202 | if self.flip: 203 | batch_input_im = torch.stack([input, flip_input]) 204 | else: 205 | batch_input_im = input 206 | 207 | meta = { 208 | 'name': val_item, 209 | 'center': person_center, 210 | 'height': h, 211 | 'width': w, 212 | 'scale': s, 213 | 'rotation': r 214 | } 215 | 216 | return batch_input_im, meta 217 | -------------------------------------------------------------------------------- /dataset/list/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/dataset/list/.DS_Store -------------------------------------------------------------------------------- /dataset/target_generation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import random 5 | import cv2 6 | import torch 7 | from torch.nn import functional as F 8 | def generate_hw_gt( target, class_num = 20 ): 9 | h,w = target.shape 10 | target = torch.from_numpy(target) 11 | target_c = target.clone() 12 | target_c[target_c==255]=0 13 | target_c = target_c.long() 14 | target_c = target_c.view(h*w) 15 | target_c = target_c.unsqueeze(1) 16 | target_onehot = torch.zeros(h*w,class_num) 17 | target_onehot.scatter_( 1, target_c, 1 ) #h*w,class_num 18 | target_onehot = target_onehot.transpose(0,1) 19 | target_onehot = target_onehot.view(class_num,h,w) 20 | # h distribution ground truth 21 | hgt = torch.zeros((class_num,h)) 22 | hgt=( torch.sum( target_onehot, dim=2 ) ).float() 23 | hgt[0,:] = 0 24 | max = torch.max(hgt,dim=1)[0] #c,1 25 | min = torch.min(hgt,dim=1)[0] 26 | max = max.unsqueeze(1) 27 | min = min.unsqueeze(1) 28 | hgt = hgt / ( max + 1e-5 ) 29 | # w distribution gound truth 30 | wgt = torch.zeros((class_num,w)) 31 | wgt=( torch.sum(target_onehot, dim=1 ) ).float() 32 | wgt[0,:]=0 33 | max = torch.max(wgt,dim=1)[0] #c,1 34 | min = torch.min(wgt,dim=1)[0] 35 | max = max.unsqueeze(1) 36 | min = min.unsqueeze(1) 37 | wgt = wgt / ( max + 1e-5 ) 38 | #=========================================================== 39 | hwgt = torch.matmul( hgt.transpose(0,1), wgt ) 40 | max = torch.max( hwgt.view(-1), dim=0 )[0] 41 | # print(max) 42 | hwgt = hwgt / ( max + 1.0e-5 ) 43 | #==================================================================== 44 | return hgt, wgt, hwgt #,cch, ccw gt_hw 45 | 46 | def generate_edge(label, edge_width=3): 47 | label = label.type(torch.cuda.FloatTensor) 48 | if len(label.shape) == 2: 49 | label = label.unsqueeze(0) 50 | n, h, w = label.shape 51 | edge = torch.zeros(label.shape, dtype=torch.float).cuda() 52 | # right 53 | edge_right = edge[:, 1:h, :] 54 | edge_right[(label[:, 
1:h, :] != label[:, :h - 1, :]) & (label[:, 1:h, :] != 255) 55 | & (label[:, :h - 1, :] != 255)] = 1 56 | 57 | # up 58 | edge_up = edge[:, :, :w - 1] 59 | edge_up[(label[:, :, :w - 1] != label[:, :, 1:w]) 60 | & (label[:, :, :w - 1] != 255) 61 | & (label[:, :, 1:w] != 255)] = 1 62 | 63 | # upright 64 | edge_upright = edge[:, :h - 1, :w - 1] 65 | edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w]) 66 | & (label[:, :h - 1, :w - 1] != 255) 67 | & (label[:, 1:h, 1:w] != 255)] = 1 68 | 69 | # bottomright 70 | edge_bottomright = edge[:, :h - 1, 1:w] 71 | edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1]) 72 | & (label[:, :h - 1, 1:w] != 255) 73 | & (label[:, 1:h, :w - 1] != 255)] = 1 74 | 75 | kernel = torch.ones((1, 1, edge_width, edge_width), dtype=torch.float).cuda() 76 | with torch.no_grad(): 77 | edge = edge.unsqueeze(1) 78 | edge = F.conv2d(edge, kernel, stride=1, padding=1) 79 | edge[edge!=0] = 1 80 | edge = edge.squeeze() 81 | return edge -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import time 4 | import argparse 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | from utils.logger import get_logger 10 | from utils.pyt_utils import parse_devices, all_reduce_tensor, extant_file 11 | ''' 12 | try: 13 | from apex.parallel import DistributedDataParallel, SyncBatchNorm 14 | except ImportError: 15 | raise ImportError( 16 | "Please install apex from https://www.github.com/nvidia/apex .") 17 | ''' 18 | 19 | logger = get_logger() 20 | 21 | 22 | class Engine(object): 23 | def __init__(self, custom_parser=None): 24 | logger.info( 25 | "PyTorch Version {}".format(torch.__version__)) 26 | self.devices = None 27 | self.distributed = False 28 | 29 | if custom_parser is None: 30 | self.parser = argparse.ArgumentParser() 31 | else: 32 | assert isinstance(custom_parser, argparse.ArgumentParser) 33 | self.parser = custom_parser 34 | 35 | self.inject_default_parser() 36 | self.args = self.parser.parse_args() 37 | 38 | self.continue_state_object = self.args.continue_fpath 39 | 40 | # if not self.args.gpu == 'None': 41 | # os.environ["CUDA_VISIBLE_DEVICES"]=self.args.gpu 42 | 43 | if 'WORLD_SIZE' in os.environ: 44 | self.distributed = int(os.environ['WORLD_SIZE']) > 1 45 | 46 | if self.distributed: 47 | self.local_rank = self.args.local_rank 48 | self.world_size = int(os.environ['WORLD_SIZE']) 49 | torch.cuda.set_device(self.local_rank) 50 | dist.init_process_group(backend="nccl", init_method='env://') 51 | self.devices = [i for i in range(self.world_size)] 52 | else: 53 | gpus = os.environ["CUDA_VISIBLE_DEVICES"] 54 | self.devices = [i for i in range(len(gpus.split(',')))] 55 | 56 | def inject_default_parser(self): 57 | p = self.parser 58 | p.add_argument('-d', '--devices', default='', 59 | help='set data parallel training') 60 | p.add_argument('-c', '--continue', type=extant_file, 61 | metavar="FILE", 62 | dest="continue_fpath", 63 | help='continue from one certain checkpoint') 64 | p.add_argument('--local_rank', default=0, type=int, 65 | help='process rank on node') 66 | 67 | def data_parallel(self, model): 68 | if self.distributed: 69 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.local_rank], 70 | output_device=self.local_rank,) 71 | else: 72 | model = torch.nn.DataParallel(model) 73 | return model 74 | 75 | def get_train_loader(self, train_dataset): 76 | 
train_sampler = None 77 | is_shuffle = True 78 | batch_size = self.args.batch_size 79 | 80 | if self.distributed: 81 | train_sampler = torch.utils.data.distributed.DistributedSampler( 82 | train_dataset) 83 | batch_size = self.args.batch_size // self.world_size 84 | is_shuffle = False 85 | 86 | train_loader = torch.utils.data.DataLoader(train_dataset, 87 | batch_size=batch_size, 88 | num_workers=self.args.num_workers, 89 | drop_last=False, 90 | shuffle=is_shuffle, 91 | pin_memory=True, 92 | sampler=train_sampler) 93 | 94 | return train_loader, train_sampler 95 | 96 | def get_test_loader(self, test_dataset): 97 | test_sampler = None 98 | is_shuffle = False 99 | batch_size = self.args.batch_size 100 | 101 | if self.distributed: 102 | test_sampler = torch.utils.data.distributed.DistributedSampler( 103 | test_dataset) 104 | batch_size = self.args.batch_size // self.world_size 105 | 106 | test_loader = torch.utils.data.DataLoader(test_dataset, 107 | batch_size=batch_size, 108 | num_workers=self.args.num_workers, 109 | drop_last=False, 110 | shuffle=is_shuffle, 111 | pin_memory=True, 112 | sampler=test_sampler) 113 | 114 | return test_loader, test_sampler 115 | 116 | 117 | def all_reduce_tensor(self, tensor, norm=True): 118 | if self.distributed: 119 | return all_reduce_tensor(tensor, world_size=self.world_size, norm=norm) 120 | else: 121 | return torch.mean(tensor) 122 | 123 | 124 | def __enter__(self): 125 | return self 126 | 127 | def __exit__(self, type, value, tb): 128 | torch.cuda.empty_cache() 129 | if type is not None: 130 | logger.warning( 131 | "A exception occurred during Engine initialization, " 132 | "give up running process") 133 | return False 134 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | torch.multiprocessing.set_start_method("spawn", force=True) 5 | from torch.utils import data 6 | from networks.CDGNet import Res_Deeplab 7 | from dataset.datasets import LIPDataSet 8 | import os 9 | import torchvision.transforms as transforms 10 | from utils.miou import compute_mean_ioU 11 | from copy import deepcopy 12 | 13 | from PIL import Image as PILImage 14 | 15 | DATA_DIRECTORY = '/ssd1/liuting14/Dataset/LIP/' 16 | DATA_LIST_PATH = './dataset/list/lip/valList.txt' 17 | IGNORE_LABEL = 255 18 | NUM_CLASSES = 20 19 | SNAPSHOT_DIR = './snapshots/' 20 | INPUT_SIZE = (473,473) 21 | 22 | # colour map 23 | COLORS = [(0,0,0) 24 | # 0=background 25 | ,(128,0,0),(0,128,0),(128,128,0),(0,0,128),(128,0,128) 26 | # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle 27 | ,(0,128,128),(128,128,128),(64,0,0),(192,0,0),(64,128,0) 28 | # 6=bus, 7=car, 8=cat, 9=chair, 10=cow 29 | ,(192,128,0),(64,0,128),(192,0,128),(64,128,128),(192,128,128) 30 | # 11=diningtable, 12=dog, 13=horse, 14=motorbike, 15=person 31 | ,(0,64,0),(128,64,0),(0,192,0),(128,192,0),(0,64,128)] 32 | # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor 33 | def get_lip_palette(): 34 | palette = [0,0,0, 35 | 128,0,0, 36 | 255,0,0, 37 | 0,85,0, 38 | 170,0,51, 39 | 255,85,0, 40 | 0,0,85, 41 | 0,119,221, 42 | 85,85,0, 43 | 0,85,85, 44 | 85,51,0, 45 | 52,86,128, 46 | 0,128,0, 47 | 0,0,255, 48 | 51,170,221, 49 | 0,255,255, 50 | 85,255,170, 51 | 170,255,85, 52 | 255,255,0, 53 | 255,170,0] 54 | return palette 55 | def get_palette(num_cls): 56 | """ Returns the color map for visualizing the segmentation mask. 
57 | 58 | Inputs: 59 | =num_cls= 60 | Number of classes. 61 | 62 | Returns: 63 | The color map. 64 | """ 65 | n = num_cls 66 | palette = [0] * (n * 3) 67 | for j in range(0, n): 68 | lab = j 69 | palette[j * 3 + 0] = 0 70 | palette[j * 3 + 1] = 0 71 | palette[j * 3 + 2] = 0 72 | i = 0 73 | while lab: 74 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 75 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 76 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 77 | i += 1 78 | lab >>= 3 79 | return palette 80 | 81 | def get_arguments(): 82 | """Parse all the arguments provided from the CLI. 83 | 84 | Returns: 85 | A list of parsed arguments. 86 | """ 87 | parser = argparse.ArgumentParser(description="CE2P Network") 88 | parser.add_argument("--batch-size", type=int, default=1, 89 | help="Number of images sent to the network in one step.") 90 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 91 | help="Path to the directory containing the PASCAL VOC dataset.") 92 | parser.add_argument("--dataset", type=str, default='val', 93 | help="Path to the file listing the images in the dataset.") 94 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 95 | help="The index of the label to ignore during the training.") 96 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 97 | help="Number of classes to predict (including background).") 98 | parser.add_argument("--restore-from", type=str, 99 | help="Where restore model parameters from.") 100 | parser.add_argument("--gpu", type=str, default='0', 101 | help="choose gpu device.") 102 | parser.add_argument("--input-size", type=str, default=INPUT_SIZE, 103 | help="Comma-separated string with height and width of images.") 104 | 105 | return parser.parse_args() 106 | 107 | def valid(model, valloader, input_size, num_samples, gpus): 108 | model.eval() 109 | 110 | parsing_preds = np.zeros((num_samples, input_size[0], input_size[1]), 111 | dtype=np.uint8) 112 | 113 | scales = np.zeros((num_samples, 2), dtype=np.float32) 114 | centers = np.zeros((num_samples, 2), dtype=np.int32) 115 | 116 | idx = 0 117 | interp = torch.nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear', align_corners=True) 118 | with torch.no_grad(): 119 | for index, batch in enumerate(valloader): 120 | image, meta = batch 121 | num_images = image.size(0) 122 | if index % 10 == 0: 123 | print('%d processd' % (index * num_images)) 124 | 125 | c = meta['center'].numpy() 126 | s = meta['scale'].numpy() 127 | scales[idx:idx + num_images, :] = s[:, :] 128 | centers[idx:idx + num_images, :] = c[:, :] 129 | #==================================================================================== 130 | org_img = image.numpy() 131 | normal_img = org_img 132 | flipped_img = org_img[:,:,:,::-1] 133 | fused_img = np.concatenate( (normal_img,flipped_img), axis=0 ) 134 | outputs = model( torch.from_numpy(fused_img).cuda()) 135 | prediction = interp( outputs[0][-1].cpu()).data.numpy().transpose(0, 2, 3, 1) #N,H,W,C 136 | single_out = prediction[:num_images,:,:,:] 137 | single_out_flip = np.zeros( single_out.shape ) 138 | single_out_tmp = prediction[num_images:, :,:,:] 139 | for c in range(14): 140 | single_out_flip[:,:, :, c] = single_out_tmp[:, :, :, c] 141 | single_out_flip[:, :, :, 14] = single_out_tmp[:, :, :, 15] 142 | single_out_flip[:, :, :, 15] = single_out_tmp[:, :, :, 14] 143 | single_out_flip[:, :, :, 16] = single_out_tmp[:, :, :, 17] 144 | single_out_flip[:, :, :, 17] = single_out_tmp[:, :, :, 16] 145 | single_out_flip[:, :, 
:, 18] = single_out_tmp[:, :, :, 19] 146 | single_out_flip[:, :, :, 19] = single_out_tmp[:, :, :, 18] 147 | single_out_flip = single_out_flip[:, :, ::-1, :] 148 | # Fuse two outputs 149 | single_out = ( single_out+single_out_flip ) / 2 150 | parsing_preds[idx:idx + num_images, :, :] = np.asarray(np.argmax(single_out, axis=3), dtype=np.uint8) 151 | #==================================================================================== 152 | # outputs = model(image.cuda()) 153 | # if gpus > 1: 154 | # for output in outputs: 155 | # parsing = output[0][-1] 156 | # nums = len(parsing) 157 | # parsing = interp(parsing).data.cpu().numpy() 158 | # parsing = parsing.transpose(0, 2, 3, 1) # NCHW NHWC 159 | # parsing_preds[idx:idx + nums, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8) 160 | 161 | # idx += nums 162 | # else: 163 | # parsing = outputs[0][-1] 164 | # parsing = interp(parsing).data.cpu().numpy() 165 | # parsing = parsing.transpose(0, 2, 3, 1) # NCHW NHWC 166 | # parsing_preds[idx:idx + num_images, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8) 167 | 168 | idx += num_images 169 | 170 | parsing_preds = parsing_preds[:num_samples, :, :] 171 | 172 | 173 | return parsing_preds, scales, centers 174 | 175 | def main(): 176 | """Create the model and start the evaluation process.""" 177 | args = get_arguments() 178 | 179 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 180 | gpus = [int(i) for i in args.gpu.split(',')] 181 | 182 | h, w = map(int, args.input_size.split(',')) 183 | 184 | input_size = (h, w) 185 | 186 | model = Res_Deeplab(num_classes=args.num_classes) 187 | 188 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 189 | std=[0.229, 0.224, 0.225]) 190 | 191 | transform = transforms.Compose([ 192 | transforms.ToTensor(), 193 | normalize, 194 | ]) 195 | 196 | lip_dataset = LIPDataSet(args.data_dir, 'val', crop_size=input_size, transform=transform) 197 | num_samples = len(lip_dataset) 198 | 199 | valloader = data.DataLoader(lip_dataset, batch_size=args.batch_size * len(gpus), 200 | shuffle=False, pin_memory=True) 201 | 202 | restore_from = args.restore_from 203 | 204 | state_dict = model.state_dict().copy() 205 | state_dict_old = torch.load(restore_from) 206 | 207 | for key, nkey in zip(state_dict_old.keys(), state_dict.keys()): 208 | if key != nkey: 209 | # remove the 'module.' 
in the 'key' 210 | state_dict[key[7:]] = deepcopy(state_dict_old[key]) 211 | else: 212 | state_dict[key] = deepcopy(state_dict_old[key]) 213 | 214 | model.load_state_dict(state_dict) 215 | 216 | model.eval() 217 | model.cuda() 218 | 219 | parsing_preds, scales, centers = valid(model, valloader, input_size, num_samples, len(gpus)) 220 | 221 | #================================================================= 222 | # list_path = os.path.join(args.data_dir, args.dataset + '_id.txt') 223 | # val_id = [i_id.strip() for i_id in open(list_path)] 224 | # pred_root = os.path.join( args.data_dir, 'pred_parsing') 225 | # if not os.path.exists( pred_root ): 226 | # os.makedirs( pred_root ) 227 | # palette = get_lip_palette() 228 | # output_parsing = parsing_preds 229 | # for i in range( num_samples ): 230 | # output_image = PILImage.fromarray( output_parsing[i] ) 231 | # output_image.putpalette( palette ) 232 | # output_image.save( os.path.join( pred_root, str(val_id[i])+'.png')) 233 | #================================================================= 234 | 235 | mIoU = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size) 236 | 237 | print(mIoU) 238 | 239 | if __name__ == '__main__': 240 | main() 241 | -------------------------------------------------------------------------------- /evaluate_multi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | torch.multiprocessing.set_start_method("spawn", force=True) 5 | from torch.utils import data 6 | from networks.CDGNet import Res_Deeplab 7 | from dataset.datasets import LIPDataValSet 8 | import os 9 | import torchvision.transforms as transforms 10 | from utils.miou import compute_mean_ioU 11 | from copy import deepcopy 12 | import cv2 13 | 14 | from PIL import Image as PILImage 15 | 16 | DATA_DIRECTORY = '/ssd1/liuting14/Dataset/LIP/' 17 | DATA_LIST_PATH = './dataset/list/lip/valList.txt' 18 | IGNORE_LABEL = 255 19 | NUM_CLASSES = 20 20 | SNAPSHOT_DIR = './snapshots/' 21 | INPUT_SIZE = (473,473) 22 | 23 | # colour map 24 | COLORS = [(0,0,0) 25 | # 0=background 26 | ,(128,0,0),(0,128,0),(128,128,0),(0,0,128),(128,0,128) 27 | # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle 28 | ,(0,128,128),(128,128,128),(64,0,0),(192,0,0),(64,128,0) 29 | # 6=bus, 7=car, 8=cat, 9=chair, 10=cow 30 | ,(192,128,0),(64,0,128),(192,0,128),(64,128,128),(192,128,128) 31 | # 11=diningtable, 12=dog, 13=horse, 14=motorbike, 15=person 32 | ,(0,64,0),(128,64,0),(0,192,0),(128,192,0),(0,64,128)] 33 | # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor 34 | def get_lip_palette(): 35 | palette = [0,0,0, 36 | 128,0,0, 37 | 255,0,0, 38 | 0,85,0, 39 | 170,0,51, 40 | 255,85,0, 41 | 0,0,85, 42 | 0,119,221, 43 | 85,85,0, 44 | 0,85,85, 45 | 85,51,0, 46 | 52,86,128, 47 | 0,128,0, 48 | 0,0,255, 49 | 51,170,221, 50 | 0,255,255, 51 | 85,255,170, 52 | 170,255,85, 53 | 255,255,0, 54 | 255,170,0] 55 | return palette 56 | def get_palette(num_cls): 57 | """ Returns the color map for visualizing the segmentation mask. 58 | 59 | Inputs: 60 | =num_cls= 61 | Number of classes. 62 | 63 | Returns: 64 | The color map. 
65 | """ 66 | n = num_cls 67 | palette = [0] * (n * 3) 68 | for j in range(0, n): 69 | lab = j 70 | palette[j * 3 + 0] = 0 71 | palette[j * 3 + 1] = 0 72 | palette[j * 3 + 2] = 0 73 | i = 0 74 | while lab: 75 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 76 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 77 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 78 | i += 1 79 | lab >>= 3 80 | return palette 81 | 82 | def get_arguments(): 83 | """Parse all the arguments provided from the CLI. 84 | 85 | Returns: 86 | A list of parsed arguments. 87 | """ 88 | parser = argparse.ArgumentParser(description="CE2P Network") 89 | parser.add_argument("--batch-size", type=int, default=1, 90 | help="Number of images sent to the network in one step.") 91 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 92 | help="Path to the directory containing the PASCAL VOC dataset.") 93 | parser.add_argument("--dataset", type=str, default='val', 94 | help="Path to the file listing the images in the dataset.") 95 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 96 | help="The index of the label to ignore during the training.") 97 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 98 | help="Number of classes to predict (including background).") 99 | parser.add_argument("--restore-from", type=str, 100 | help="Where restore model parameters from.") 101 | parser.add_argument("--gpu", type=str, default='0', 102 | help="choose gpu device.") 103 | parser.add_argument("--input-size", type=str, default=INPUT_SIZE, 104 | help="Comma-separated string with height and width of images.") 105 | 106 | return parser.parse_args() 107 | 108 | # def scale_image(image, scale): 109 | # image = image[0, :, :, :] 110 | # image = image.transpose((1, 2, 0)) 111 | # image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) 112 | # image = image.transpose((2, 0, 1)) 113 | # return image 114 | 115 | def valid(model, valloader, input_size, num_samples, gpus): 116 | model.eval() 117 | 118 | parsing_preds = np.zeros((num_samples, input_size[0], input_size[1]), 119 | dtype=np.uint8) 120 | 121 | scales = np.zeros((num_samples, 2), dtype=np.float32) 122 | centers = np.zeros((num_samples, 2), dtype=np.int32) 123 | 124 | hpreds_lst = [] 125 | wpreds_lst = [] 126 | 127 | idx = 0 128 | interp = torch.nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear', align_corners=True) 129 | eval_scale=[0.75,1.0,1.25] 130 | # eval_scale=[1.0] 131 | flipped_idx = (15, 14, 17, 16, 19, 18) 132 | with torch.no_grad(): 133 | for index, batch in enumerate(valloader): 134 | image, meta = batch 135 | # num_images = image.size(0) 136 | # print( image.size() ) 137 | image = image.squeeze() 138 | if index % 10 == 0: 139 | print('%d processd' % (index * 1)) 140 | c = meta['center'].numpy()[0] 141 | s = meta['scale'].numpy()[0] 142 | scales[idx, :] = s 143 | centers[idx, :] = c 144 | #==================================================================================== 145 | mul_outputs = [] 146 | for scale in eval_scale: 147 | interp_img = torch.nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=True) 148 | scaled_img = interp_img( image ) 149 | # print( scaled_img.size() ) 150 | outputs = model( scaled_img.cuda() ) 151 | prediction = outputs[0][-1] 152 | #========================================================== 153 | hPreds = outputs[2][0] 154 | wPreds = outputs[2][1] 155 | hpreds_lst.append( hPreds[0].data.cpu().numpy() ) 156 | wpreds_lst.append( 
wPreds[0].data.cpu().numpy() ) 157 | #========================================================== 158 | single_output = prediction[0] 159 | flipped_output = prediction[1] 160 | flipped_output[14:20,:,:]=flipped_output[flipped_idx,:,:] 161 | single_output += flipped_output.flip(dims=[-1]) 162 | single_output *=0.5 163 | # print( single_output.size() ) 164 | single_output = interp( single_output.unsqueeze(0) ) 165 | mul_outputs.append( single_output[0] ) 166 | fused_prediction = torch.stack( mul_outputs ) 167 | fused_prediction = fused_prediction.mean(0) 168 | fused_prediction = fused_prediction.permute(1, 2, 0) # HWC 169 | fused_prediction = torch.argmax(fused_prediction, dim=2) 170 | fused_prediction = fused_prediction.data.cpu().numpy() 171 | parsing_preds[idx, :, :] = np.asarray(fused_prediction, dtype=np.uint8) 172 | #==================================================================================== 173 | idx += 1 174 | 175 | parsing_preds = parsing_preds[:num_samples, :, :] 176 | 177 | 178 | return parsing_preds, scales, centers, hpreds_lst, wpreds_lst 179 | 180 | def main(): 181 | """Create the model and start the evaluation process.""" 182 | args = get_arguments() 183 | 184 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 185 | gpus = [int(i) for i in args.gpu.split(',')] 186 | 187 | h, w = map(int, args.input_size.split(',')) 188 | 189 | input_size = (h, w) 190 | 191 | model = Res_Deeplab(num_classes=args.num_classes) 192 | 193 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 194 | std=[0.229, 0.224, 0.225]) 195 | 196 | transform = transforms.Compose([ 197 | transforms.ToTensor(), 198 | normalize, 199 | ]) 200 | 201 | lip_dataset = LIPDataValSet(args.data_dir, 'val', crop_size=input_size, transform=transform, flip = True ) 202 | num_samples = len(lip_dataset) 203 | 204 | valloader = data.DataLoader(lip_dataset, batch_size=args.batch_size * len(gpus), 205 | shuffle=False, pin_memory=True) 206 | 207 | restore_from = args.restore_from 208 | 209 | state_dict = model.state_dict().copy() 210 | state_dict_old = torch.load(restore_from) 211 | 212 | for key, nkey in zip(state_dict_old.keys(), state_dict.keys()): 213 | if key != nkey: 214 | # remove the 'module.' 
in the 'key' 215 | state_dict[key[7:]] = deepcopy(state_dict_old[key]) 216 | else: 217 | state_dict[key] = deepcopy(state_dict_old[key]) 218 | 219 | model.load_state_dict(state_dict) 220 | 221 | model.eval() 222 | model.cuda() 223 | 224 | parsing_preds, scales, centers, hpredLst, wpredLst = valid(model, valloader, input_size, num_samples, len(gpus)) 225 | 226 | #================================================================= 227 | # list_path = os.path.join(args.data_dir, args.dataset + '_id.txt') 228 | # val_id = [i_id.strip() for i_id in open(list_path)] 229 | # # pred_root = os.path.join( args.data_dir, 'pred_parsing') 230 | # pred_root = os.path.join( os.getcwd(), 'pred_parsing') 231 | # print( pred_root ) 232 | # if not os.path.exists( pred_root ): 233 | # os.makedirs( pred_root ) 234 | # palette = get_lip_palette() 235 | # output_parsing = parsing_preds 236 | # for i in range( num_samples ): 237 | # output_image = PILImage.fromarray( output_parsing[i] ) 238 | # output_image.putpalette( palette ) 239 | # output_image.save( os.path.join( pred_root, str(val_id[i])+'.png')) 240 | # i=0 241 | # for i in range(len( hpredLst )): 242 | # filenameh = os.path.join( pred_root, str( val_id[i] ) + "_h" ) 243 | # np.save( filenameh, hpredLst[i] ) 244 | # filenamew = os.path.join( pred_root, str( val_id[i] ) + "_w" ) 245 | # np.save( filenamew, wpredLst[i] ) 246 | #================================================================= 247 | mIoU = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size) 248 | 249 | print(mIoU) 250 | 251 | if __name__ == '__main__': 252 | main() 253 | -------------------------------------------------------------------------------- /networks/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/networks/.DS_Store -------------------------------------------------------------------------------- /networks/CDGNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import torch 6 | import numpy as np 7 | from torch.autograd import Variable 8 | affine_par = True 9 | import functools 10 | 11 | import sys, os 12 | from utils.attention import CDGAttention, C2CAttention 13 | from torch.nn import BatchNorm2d as BatchNorm2d 14 | 15 | def InPlaceABNSync(in_channel): 16 | layers = [ 17 | BatchNorm2d(in_channel), 18 | nn.ReLU(), 19 | ] 20 | return nn.Sequential(*layers) 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | "3x3 convolution with padding" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | 27 | 28 | class Bottleneck(nn.Module): 29 | expansion = 4 30 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, fist_dilation=1, multi_grid=1): 31 | super(Bottleneck, self).__init__() 32 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 33 | self.bn1 = BatchNorm2d(planes) 34 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 35 | padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False) 36 | self.bn2 = BatchNorm2d(planes) 37 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 38 | self.bn3 = BatchNorm2d(planes * 4) 39 | self.relu = nn.ReLU(inplace=False) 40 | self.relu_inplace = nn.ReLU(inplace=True) 
41 | self.downsample = downsample 42 | self.dilation = dilation 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv3(out) 57 | out = self.bn3(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | 62 | out = out + residual 63 | out = self.relu_inplace(out) 64 | 65 | return out 66 | 67 | class ASPPModule(nn.Module): 68 | """ 69 | Reference: 70 | Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."* 71 | """ 72 | def __init__(self, features, inner_features=256, out_features=512, dilations=(12, 24, 36)): 73 | super(ASPPModule, self).__init__() 74 | 75 | self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)), 76 | nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1, bias=False), 77 | InPlaceABNSync(inner_features)) 78 | self.conv2 = nn.Sequential(nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1, bias=False), 79 | InPlaceABNSync(inner_features)) 80 | self.conv3 = nn.Sequential(nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False), 81 | InPlaceABNSync(inner_features)) 82 | self.conv4 = nn.Sequential(nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False), 83 | InPlaceABNSync(inner_features)) 84 | self.conv5 = nn.Sequential(nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False), 85 | InPlaceABNSync(inner_features)) 86 | 87 | self.bottleneck = nn.Sequential( 88 | nn.Conv2d(inner_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False), 89 | InPlaceABNSync(out_features), 90 | nn.Dropout2d(0.1) 91 | ) 92 | 93 | def forward(self, x): 94 | 95 | _, _, h, w = x.size() 96 | 97 | feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True) 98 | 99 | feat2 = self.conv2(x) 100 | feat3 = self.conv3(x) 101 | feat4 = self.conv4(x) 102 | feat5 = self.conv5(x) 103 | out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1) 104 | 105 | bottle = self.bottleneck(out) 106 | return bottle 107 | 108 | class Edge_Module(nn.Module): 109 | 110 | def __init__(self,in_fea=[256,512,1024], mid_fea=256, out_fea=2): 111 | super(Edge_Module, self).__init__() 112 | 113 | self.conv1 = nn.Sequential( 114 | nn.Conv2d(in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False), 115 | InPlaceABNSync(mid_fea) 116 | ) 117 | self.conv2 = nn.Sequential( 118 | nn.Conv2d(in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False), 119 | InPlaceABNSync(mid_fea) 120 | ) 121 | self.conv3 = nn.Sequential( 122 | nn.Conv2d(in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False), 123 | InPlaceABNSync(mid_fea) 124 | ) 125 | self.conv4 = nn.Conv2d(mid_fea,out_fea, kernel_size=3, padding=1, dilation=1, bias=True) 126 | self.conv5 = nn.Conv2d(out_fea*3,out_fea, kernel_size=1, padding=0, dilation=1, bias=True) 127 | 128 | def forward(self, x1, x2, x3): 129 | _, _, h, w = x1.size() 130 | 131 | edge1_fea = self.conv1(x1) 132 | edge1 = self.conv4(edge1_fea) 133 | edge2_fea = self.conv2(x2) 134 | edge2 = self.conv4(edge2_fea) 135 | edge3_fea = self.conv3(x3) 136 | edge3 = self.conv4(edge3_fea) 137 | 138 | edge2_fea = F.interpolate(edge2_fea, size=(h, w), mode='bilinear',align_corners=True) 139 | 
edge3_fea = F.interpolate(edge3_fea, size=(h, w), mode='bilinear',align_corners=True) 140 | edge2 = F.interpolate(edge2, size=(h, w), mode='bilinear',align_corners=True) 141 | edge3 = F.interpolate(edge3, size=(h, w), mode='bilinear',align_corners=True) 142 | 143 | edge = torch.cat([edge1, edge2, edge3], dim=1) 144 | edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1) 145 | edge = self.conv5(edge) 146 | return edge, edge_fea 147 | 148 | class PSPModule(nn.Module): 149 | """ 150 | Reference: 151 | Zhao, Hengshuang, et al. *"Pyramid scene parsing network."* 152 | """ 153 | def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)): 154 | super(PSPModule, self).__init__() 155 | 156 | self.stages = [] 157 | self.stages = nn.ModuleList([self._make_stage(features, out_features, size) for size in sizes]) 158 | self.bottleneck = nn.Sequential( 159 | nn.Conv2d(features+len(sizes)*out_features, out_features, kernel_size=3, padding=1, dilation=1, bias=False), 160 | InPlaceABNSync(out_features), 161 | ) 162 | 163 | def _make_stage(self, features, out_features, size): 164 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 165 | conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False) 166 | bn = InPlaceABNSync(out_features) 167 | return nn.Sequential(prior, conv, bn) 168 | 169 | def forward(self, feats): 170 | h, w = feats.size(2), feats.size(3) 171 | priors = [ F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] + [feats] 172 | bottle = self.bottleneck(torch.cat(priors, 1)) 173 | return bottle 174 | 175 | class Decoder_Module(nn.Module): 176 | 177 | def __init__(self, num_classes): 178 | super(Decoder_Module, self).__init__() 179 | self.conv1 = nn.Sequential( 180 | nn.Conv2d(512, 256, kernel_size=1, padding=0, dilation=1, bias=False), 181 | InPlaceABNSync(256) 182 | ) 183 | self.conv2 = nn.Sequential( 184 | nn.Conv2d(256, 48, kernel_size=3, stride=1, padding=1, dilation=1, bias=False), 185 | InPlaceABNSync(48) 186 | ) 187 | self.conv3 = nn.Sequential( 188 | nn.Conv2d(304, 256, kernel_size=3, padding=1, dilation=1, bias=False), 189 | InPlaceABNSync(256), 190 | nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False), 191 | InPlaceABNSync(256) 192 | ) 193 | self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True) 194 | #========================================================================================= 195 | self.addCAM = nn.Sequential( 196 | nn.Conv2d(512, 256, kernel_size=3, padding=1, dilation=1, bias=False), 197 | InPlaceABNSync(256), 198 | ) 199 | #======================================================================================= 200 | def PCM(self, cam, f): 201 | n,c,h,w = f.size() 202 | cam = F.interpolate(cam, (h,w), mode='bilinear', align_corners=True).view(n,-1,h*w) 203 | f = f.view(n,-1,h*w) 204 | aff = torch.matmul(f.transpose(1,2), f) 205 | aff = ( c ** -0.5 ) * aff 206 | aff = F.softmax( aff, dim = -1 ) #无tanspose时是57.28,有transpose�?6.64 207 | cam_rv = torch.matmul(cam, aff).view(n,-1,h,w) 208 | return cam_rv 209 | def forward(self, xt, xl, xPCM = None ): 210 | _, _, h, w = xl.size() 211 | xt = F.interpolate(self.conv1(xt), size=(h, w), mode='bilinear', align_corners=True) 212 | xl = self.conv2(xl) 213 | x = torch.cat([xt, xl], dim=1) 214 | x = self.conv3(x) 215 | with torch.no_grad(): 216 | xM = F.relu( x.detach() ) 217 | xPCM = F.interpolate( self.PCM( xM, xPCM ), size=(h, w), mode='bilinear', align_corners=True ) 218 | x = torch.cat( [x, 
xPCM ], dim = 1 ) 219 | x = self.addCAM( x ) 220 | seg = self.conv4(x) 221 | return seg,x 222 | 223 | class ResNet(nn.Module): 224 | def __init__(self, block, layers, num_classes): 225 | self.inplanes = 128 226 | super(ResNet, self).__init__() 227 | self.conv1 = conv3x3(3, 64, stride=2) 228 | self.bn1 = BatchNorm2d(64) 229 | self.relu1 = nn.ReLU(inplace=False) 230 | self.conv2 = conv3x3(64, 64) 231 | self.bn2 = BatchNorm2d(64) 232 | self.relu2 = nn.ReLU(inplace=False) 233 | self.conv3 = conv3x3(64, 128) 234 | self.bn3 = BatchNorm2d(128) 235 | self.relu3 = nn.ReLU(inplace=False) 236 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 237 | 238 | self.layer1 = self._make_layer(block, 64, layers[0]) 239 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 240 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 241 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=2, multi_grid=(1,1,1)) 242 | 243 | self.layer5 = PSPModule(2048,512) 244 | 245 | self.edge_layer = Edge_Module() 246 | self.layer6 = Decoder_Module(num_classes) 247 | 248 | self.layer7 = nn.Sequential( 249 | nn.Conv2d(1024, 256, kernel_size=3, padding=1, dilation=1, bias=False), 250 | InPlaceABNSync(256), 251 | nn.Dropout2d(0.1), 252 | nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True) 253 | ) 254 | #=================================================================================== 255 | self.sq4 = nn.Sequential( 256 | nn.Conv2d(1024, 256, kernel_size=1, padding=0, dilation=1, bias=False), 257 | InPlaceABNSync(256) 258 | ) 259 | self.sq5 = nn.Sequential( 260 | nn.Conv2d( 256, 256, kernel_size=1, padding=0, dilation=1, bias=False ), 261 | InPlaceABNSync(256) 262 | ) 263 | self.f9 = nn.Sequential( 264 | nn.Conv2d(256+256+3, 256, kernel_size=1, padding=0, dilation=1, bias=False), 265 | InPlaceABNSync(256) 266 | ) 267 | #=============================================================== 268 | self.hwAttention = CDGAttention(512, 256, num_classes, [473//4,473//4], 7 ) 269 | self.L = nn.Conv2d(1024, num_classes, kernel_size=1, padding=0, dilation=1, bias=True) 270 | #================================================================ 271 | for m in self.modules(): 272 | if isinstance(m, nn.Conv2d): 273 | nn.init.kaiming_normal_(m.weight.data) 274 | elif isinstance(m, BatchNorm2d): 275 | m.weight.data.fill_(1) 276 | m.bias.data.zero_() 277 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1): 278 | downsample = None 279 | if stride != 1 or self.inplanes != planes * block.expansion: 280 | downsample = nn.Sequential( 281 | nn.Conv2d(self.inplanes, planes * block.expansion, 282 | kernel_size=1, stride=stride, bias=False), 283 | BatchNorm2d(planes * block.expansion)) 284 | 285 | layers = [] 286 | generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1 287 | layers.append(block(self.inplanes, planes, stride,dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid))) 288 | self.inplanes = planes * block.expansion 289 | for i in range(1, blocks): 290 | layers.append(block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid))) 291 | 292 | return nn.Sequential(*layers) 293 | 294 | def forward(self, x): 295 | x_org = x 296 | x = self.relu1(self.bn1(self.conv1(x))) 297 | x = self.relu2(self.bn2(self.conv2(x))) 298 | x = self.relu3(self.bn3(self.conv3(x))) 299 | x = self.maxpool(x) 300 | x2 = self.layer1(x) #1/4 256 301 | x3 = 
self.layer2(x2) #1/8 512 302 | x4 = self.layer3(x3) #1/16 1024 303 | seg0 = self.L(x4) 304 | x5 = self.layer4(x4) #1/16 2048 305 | x = self.layer5(x5) #1/16 512 306 | #============================================================== 307 | x,fea_h1, fea_w1 = self.hwAttention(x) 308 | #============================================================== 309 | edge,edge_fea = self.edge_layer(x2,x3,x4) 310 | #============================================================== 311 | n, c, h, w = x4.size() 312 | fr1 = self.sq5( x2 ) 313 | fr1 = F.interpolate( fr1, (h,w), mode='bilinear', align_corners= True ) 314 | fr1 = F.relu( fr1, inplace= True ) 315 | fr2 = self.sq4( x4 ) 316 | fr2 = F.interpolate( fr2, (h,w), mode='bilinear', align_corners= True ) 317 | fr2 = F.relu( fr2, inplace= True ) 318 | frOrg = F.interpolate( x_org,(h,w), mode='bilinear', align_corners=True ) 319 | fCat = torch.cat([frOrg, fr1, fr2 ], dim = 1) 320 | fCat = self.f9( fCat ) 321 | #============================================================== 322 | seg1,x = self.layer6( x,x2, fCat ) 323 | #============================================================= 324 | x = torch.cat([x, edge_fea], dim=1) 325 | seg2 = self.layer7(x) 326 | return [[seg0, seg1, seg2], [edge],[fea_h1,fea_w1]] 327 | 328 | def Res_Deeplab(num_classes=21): 329 | model = ResNet(Bottleneck,[3, 4, 23, 3], num_classes) 330 | return model 331 | -------------------------------------------------------------------------------- /networks/__pycache__/CDGNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/networks/__pycache__/CDGNet.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/CE2P.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/networks/__pycache__/CE2P.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/CE2P.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/networks/__pycache__/CE2P.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/CE2P.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/networks/__pycache__/CE2P.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/CE2PHybrid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/networks/__pycache__/CE2PHybrid.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/CE2P_test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/networks/__pycache__/CE2P_test.cpython-36.pyc -------------------------------------------------------------------------------- /run.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | uname -a 3 | #date 4 | #env 5 | date 6 | CS_PATH='/mnt/data/humanparsing/LIP' 7 | LR=3.0e-3 8 | WD=5e-4 9 | BS=8 10 | GPU_IDS=0,1,2,3 11 | RESTORE_FROM='./dataset/resnet101-imagenet.pth' 12 | INPUT_SIZE='473,473' 13 | SNAPSHOT_DIR='./snapshots' 14 | DATASET='train' 15 | NUM_CLASSES=20 16 | EPOCHS=150 17 | 18 | if [[ ! -e ${SNAPSHOT_DIR} ]]; then 19 | mkdir -p ${SNAPSHOT_DIR} 20 | fi 21 | 22 | python -m torch.distributed.launch --nproc_per_node=4 --nnode=1 \ 23 | --node_rank=0 --master_addr=222.32.33.224 --master_port 29500 train.py \ 24 | --data-dir ${CS_PATH} \ 25 | --random-mirror\ 26 | --random-scale\ 27 | --restore-from ${RESTORE_FROM}\ 28 | --gpu ${GPU_IDS}\ 29 | --learning-rate ${LR}\ 30 | --weight-decay ${WD}\ 31 | --batch-size ${BS} \ 32 | --input-size ${INPUT_SIZE}\ 33 | --snapshot-dir ${SNAPSHOT_DIR}\ 34 | --dataset ${DATASET}\ 35 | --num-classes ${NUM_CLASSES} \ 36 | --epochs ${EPOCHS} 37 | 38 | # python evaluate.py 39 | -------------------------------------------------------------------------------- /run_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CS_PATH='./dataset/LIP' 4 | CS_PATH='/mnt/data/humanparsing/LIP' 5 | BS=1 6 | GPU_IDS='1' 7 | INPUT_SIZE='473,473' 8 | SNAPSHOT_FROM='./snapshots/LIP_epoch_149.pth' 9 | DATASET='val' 10 | NUM_CLASSES=20 11 | 12 | CUDA_VISIBLE_DEVICES=1 python evaluate.py --data-dir ${CS_PATH} \ 13 | --gpu ${GPU_IDS} \ 14 | --batch-size ${BS} \ 15 | --input-size ${INPUT_SIZE}\ 16 | --restore-from ${SNAPSHOT_FROM}\ 17 | --dataset ${DATASET}\ 18 | --num-classes ${NUM_CLASSES} 19 | -------------------------------------------------------------------------------- /run_evaluate_multiScale.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CS_PATH='./dataset/LIP' 4 | CS_PATH='/mnt/data/humanparsing/LIP' 5 | # CS_PATH='/mnt/data/humanparsing/CIHP' 6 | BS=1 7 | GPU_IDS='1' 8 | INPUT_SIZE='473,473' 9 | SNAPSHOT_FROM='./snapshots/LIP_epoch_149.pth' 10 | DATASET='val' 11 | NUM_CLASSES=20 12 | 13 | CUDA_VISIBLE_DEVICES=1 python evaluate_multi.py --data-dir ${CS_PATH} \ 14 | --gpu ${GPU_IDS} \ 15 | --batch-size ${BS} \ 16 | --input-size ${INPUT_SIZE}\ 17 | --restore-from ${SNAPSHOT_FROM}\ 18 | --dataset ${DATASET}\ 19 | --num-classes ${NUM_CLASSES} 20 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | torch.multiprocessing.set_start_method("spawn", force=True) 5 | from torch.utils import data 6 | import numpy as np 7 | import torch.optim as optim 8 | import torchvision.utils as vutils 9 | import torch.backends.cudnn as cudnn 10 | import os 11 | import os.path as osp 12 | from networks.CDGNet import Res_Deeplab 13 | from dataset.datasets import LIPDataSet 14 | from dataset.target_generation import generate_edge 15 | import torchvision.transforms as transforms 16 | import timeit 17 | import torch.distributed as dist 18 | from tensorboardX import SummaryWriter 19 | from utils.utils import decode_parsing, inv_preprocess 20 | from utils.criterion import CriterionAll 21 | from torch.nn.parallel import DistributedDataParallel as DDP 22 | from torch.utils.data.distributed import DistributedSampler 23 | 24 | from utils.miou import compute_mean_ioU 25 | from evaluate import 
valid 26 | 27 | start = timeit.default_timer() 28 | 29 | BATCH_SIZE = 1 30 | DATA_DIRECTORY = './dataset/LIP' 31 | DATA_LIST_PATH = './dataset/list/cityscapes/train.lst' 32 | IGNORE_LABEL = 255 33 | INPUT_SIZE = '769,769' 34 | LEARNING_RATE = 1e-2 35 | MOMENTUM = 0.9 36 | NUM_CLASSES = 20 37 | POWER = 0.9 38 | RANDOM_SEED = 1234 39 | RESTORE_FROM='./dataset/resnet101-imagenet.pth' 40 | SAVE_NUM_IMAGES = 2 41 | SAVE_PRED_EVERY = 10000 42 | SNAPSHOT_DIR = './snapshots/' 43 | WEIGHT_DECAY = 0.0005 44 | GPU_IDS='0' 45 | 46 | def reduce_loss(tensor, rank, world_size): 47 | with torch.no_grad(): 48 | dist.reduce(tensor, dst=0) 49 | if rank == 0: 50 | tensor /= world_size 51 | 52 | def str2bool(v): 53 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 54 | return True 55 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 56 | return False 57 | else: 58 | raise argparse.ArgumentTypeError('Boolean value expected.') 59 | 60 | 61 | def get_arguments(): 62 | """Parse all the arguments provided from the CLI. 63 | 64 | Returns: 65 | A list of parsed arguments. 66 | """ 67 | parser = argparse.ArgumentParser(description="CE2P Network") 68 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, 69 | help="Number of images sent to the network in one step.") 70 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 71 | help="Path to the directory containing the dataset.") 72 | parser.add_argument("--dataset", type=str, default='train', choices=['train', 'val', 'trainval', 'test'], 73 | help="Path to the file listing the images in the dataset.") 74 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 75 | help="The index of the label to ignore during the training.") 76 | parser.add_argument("--input-size", type=str, default=INPUT_SIZE, 77 | help="Comma-separated string with height and width of images.") 78 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, 79 | help="Base learning rate for training with polynomial decay.") 80 | parser.add_argument("--momentum", type=float, default=MOMENTUM, 81 | help="Momentum component of the optimiser.") 82 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 83 | help="Number of classes to predict (including background).") 84 | parser.add_argument("--start-iters", type=int, default=0, 85 | help="Number of classes to predict (including background).") 86 | parser.add_argument("--power", type=float, default=POWER, 87 | help="Decay parameter to compute the learning rate.") 88 | parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY, 89 | help="Regularisation parameter for L2-loss.") 90 | parser.add_argument("--random-mirror", action="store_true", 91 | help="Whether to randomly mirror the inputs during the training.") 92 | parser.add_argument("--random-scale", action="store_true", 93 | help="Whether to randomly scale the inputs during the training.") 94 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, 95 | help="Random seed to have reproducible results.") 96 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, 97 | help="Where restore model parameters from.") 98 | parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES, 99 | help="How many images to save.") 100 | parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR, 101 | help="Where to save snapshots of the model.") 102 | parser.add_argument("--gpu", type=str, default=GPU_IDS, 103 | help="choose gpu device.") 104 | 
parser.add_argument("--start-epoch", type=int, default=0, 105 | help="choose the number of recurrence.") 106 | parser.add_argument("--epochs", type=int, default=150, 107 | help="choose the number of recurrence.") 108 | parser.add_argument('--local_rank', type=int, help="local gpu id") 109 | # os.environ['MASTER_ADDR'] = '202.30.29.226' 110 | # os.environ['MASTER_PORT'] = '8888' 111 | return parser.parse_args() 112 | 113 | 114 | args = get_arguments() 115 | 116 | 117 | def lr_poly(base_lr, iter, max_iter, power): 118 | return base_lr * ((1 - float(iter) / max_iter) ** (power)) 119 | 120 | 121 | def adjust_learning_rate(optimizer, i_iter, total_iters): 122 | """Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs""" 123 | lr = lr_poly(args.learning_rate, i_iter, total_iters, args.power) 124 | optimizer.param_groups[0]['lr'] = lr 125 | return lr 126 | 127 | 128 | def adjust_learning_rate_pose(optimizer, epoch): 129 | decay = 0 130 | if epoch + 1 >= 230: 131 | decay = 0.05 132 | elif epoch + 1 >= 200: 133 | decay = 0.1 134 | elif epoch + 1 >= 120: 135 | decay = 0.25 136 | elif epoch + 1 >= 90: 137 | decay = 0.5 138 | else: 139 | decay = 1 140 | 141 | lr = args.learning_rate * decay 142 | for param_group in optimizer.param_groups: 143 | param_group['lr'] = lr 144 | return lr 145 | 146 | 147 | def set_bn_eval(m): 148 | classname = m.__class__.__name__ 149 | if classname.find('BatchNorm') != -1: 150 | m.eval() 151 | 152 | 153 | def set_bn_momentum(m): 154 | classname = m.__class__.__name__ 155 | if classname.find('BatchNorm') != -1 or classname.find('InPlaceABN') != -1: 156 | m.momentum = 0.0003 157 | 158 | 159 | def main(): 160 | """Create the model and start the training.""" 161 | 162 | if not os.path.exists(args.snapshot_dir): 163 | os.makedirs(args.snapshot_dir) 164 | 165 | writer = SummaryWriter(args.snapshot_dir) 166 | gpus = [int(i) for i in args.gpu.split(',')] 167 | if not args.gpu == 'None': 168 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 169 | 170 | h, w = map(int, args.input_size.split(',')) 171 | input_size = [h, w] 172 | 173 | cudnn.enabled = True 174 | # cudnn related setting 175 | cudnn.benchmark = True 176 | torch.backends.cudnn.deterministic = False 177 | torch.backends.cudnn.enabled = True 178 | 179 | dist.init_process_group( backend='nccl', init_method='env://' ) 180 | torch.cuda.set_device( args.local_rank ) 181 | gloabl_rank = dist.get_rank() 182 | world_size = dist.get_world_size() 183 | print( world_size ) 184 | if world_size == 1: 185 | return 186 | dist.barrier() 187 | 188 | deeplab = Res_Deeplab(num_classes=args.num_classes) 189 | 190 | saved_state_dict = torch.load(args.restore_from) 191 | new_params = deeplab.state_dict().copy() 192 | for i in saved_state_dict: 193 | i_parts = i.split('.') 194 | # print(i_parts) 195 | if not i_parts[0] == 'fc': 196 | new_params['.'.join(i_parts[0:])] = saved_state_dict[i] 197 | 198 | deeplab.load_state_dict(new_params) 199 | 200 | deeplab.cuda() 201 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(deeplab) 202 | model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank ) 203 | 204 | criterion = CriterionAll() 205 | # criterion = DataParallelCriterion(criterion) 206 | criterion.cuda() 207 | 208 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 209 | std=[0.229, 0.224, 0.225]) 210 | 211 | transform = transforms.Compose([ 212 | transforms.ToTensor(), 213 | normalize, 214 | ]) 215 | lipDataset = LIPDataSet(args.data_dir, args.dataset, crop_size=input_size, 
transform=transform) 216 | sampler = DistributedSampler(lipDataset) 217 | trainloader = data.DataLoader(lipDataset, 218 | batch_size=args.batch_size, shuffle=False, 219 | sampler = sampler, 220 | num_workers=4, 221 | pin_memory=True) 222 | #lip_dataset = LIPDataSet(args.data_dir, 'val', crop_size=input_size, transform=transform) 223 | #num_samples = len(lip_dataset) 224 | 225 | #valloader = data.DataLoader(lip_dataset, batch_size=args.batch_size * len(gpus), 226 | # shuffle=False, pin_memory=True) 227 | optimizer = optim.SGD( 228 | model.parameters(), 229 | lr=args.learning_rate, 230 | momentum=args.momentum, 231 | weight_decay=args.weight_decay, 232 | ) 233 | optimizer.zero_grad() 234 | total_iters = args.epochs * len(trainloader) 235 | 236 | # path = osp.join( args.snapshot_dir, 'model_LIP'+'.pth') 237 | # if os.path.exists( path ): 238 | # checkpoint = torch.load(path) 239 | # model.load_state_dict(checkpoint['model']) 240 | # optimizer.load_state_dict(checkpoint['optimizer']) 241 | # epoch = checkpoint['epoch'] 242 | # print( epoch ) 243 | # args.start_epoch = epoch 244 | # print( 'Load model first!') 245 | # else: 246 | # print( 'No model exits from beginning!') 247 | 248 | model.train() 249 | for epoch in range(args.start_epoch, args.epochs): 250 | sampler.set_epoch(epoch) 251 | for i_iter, batch in enumerate(trainloader): 252 | i_iter += len(trainloader) * epoch 253 | lr = adjust_learning_rate(optimizer, i_iter, total_iters) 254 | 255 | images, labels, hgt,wgt,hwgt,_ = batch 256 | labels = labels.cuda(non_blocking=True) 257 | edges = generate_edge(labels) 258 | labels = labels.type(torch.cuda.LongTensor) 259 | edges = edges.type(torch.cuda.LongTensor) 260 | hgt = hgt.float().cuda(non_blocking=True) 261 | wgt = wgt.float().cuda(non_blocking=True) 262 | hwgt = hwgt.float().cuda(non_blocking=True) 263 | preds = model(images) 264 | loss = criterion(preds, [labels, edges],[hgt,wgt,hwgt]) 265 | optimizer.zero_grad() 266 | loss.backward() 267 | optimizer.step() 268 | reduce_loss( loss, gloabl_rank, world_size ) 269 | # if i_iter % 100 == 0: 270 | # writer.add_scalar('learning_rate', lr, i_iter) 271 | # writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter) 272 | 273 | # if i_iter % 500 == 0: 274 | 275 | # images_inv = inv_preprocess(images, args.save_num_images) 276 | # labels_colors = decode_parsing(labels, args.save_num_images, args.num_classes, is_pred=False) 277 | # edges_colors = decode_parsing(edges, args.save_num_images, 2, is_pred=False) 278 | 279 | # if isinstance(preds, list): 280 | # preds = preds[0] 281 | # preds_colors = decode_parsing(preds[0][-1], args.save_num_images, args.num_classes, is_pred=True) 282 | # # pred_edges = decode_parsing(preds[1][-1], args.save_num_images, 2, is_pred=True) 283 | 284 | # img = vutils.make_grid(images_inv, normalize=False, scale_each=True) 285 | # lab = vutils.make_grid(labels_colors, normalize=False, scale_each=True) 286 | # pred = vutils.make_grid(preds_colors, normalize=False, scale_each=True) 287 | # edge = vutils.make_grid(edges_colors, normalize=False, scale_each=True) 288 | # # pred_edge = vutils.make_grid(pred_edges, normalize=False, scale_each=True) 289 | 290 | # writer.add_image('Images/', img, i_iter) 291 | # writer.add_image('Labels/', lab, i_iter) 292 | # writer.add_image('Preds/', pred, i_iter) 293 | # writer.add_image('Edges/', edge, i_iter) 294 | # writer.add_image('PredEdges/', pred_edge, i_iter) 295 | if gloabl_rank == 0: 296 | print('Epoch:{} iter = {} of {} completed, loss = {}'.format(epoch, i_iter, total_iters, 
loss.data.cpu().numpy())) 297 | if epoch > 140 and gloabl_rank == 0: 298 | torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'LIP_epoch_' + str(epoch) + '.pth')) 299 | if gloabl_rank == 0: 300 | path = osp.join( args.snapshot_dir, 'model_LIP'+'.pth') 301 | state = { 'model': model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch': epoch } 302 | torch.save(state, path) 303 | #parsing_preds, scales, centers = valid(model, valloader, input_size, num_samples, len(gpus)) 304 | 305 | #mIoU = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size) 306 | 307 | #print(mIoU) 308 | #writer.add_scalars('mIoU', mIoU, epoch) 309 | 310 | end = timeit.default_timer() 311 | print(end - start, 'seconds') 312 | 313 | 314 | if __name__ == '__main__': 315 | main() 316 | -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/.DS_Store -------------------------------------------------------------------------------- /utils/ImgTransforms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from PIL import Image, ImageFilter 7 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 8 | from torchvision import transforms 9 | # from torchvision import models,datasets 10 | # import matplotlib.pyplot as plt 11 | import random 12 | import cv2 13 | 14 | RESAMPLE_MODE=Image.BICUBIC 15 | 16 | # cat=cv2.imread('d:/testpy/839_482127.jpg') 17 | 18 | random_mirror = True 19 | 20 | def ShearX(img, v): # [-0.3, 0.3] 21 | assert -0.3 <= v <= 0.3 22 | if random_mirror and random.random() > 0.5: 23 | v = -v 24 | return img.transform(img.size, Image.AFFINE, (1, v, 0, 0, 1, 0), 25 | RESAMPLE_MODE) 26 | 27 | def ShearY(img, v): # [-0.3, 0.3] 28 | assert -0.3 <= v <= 0.3 29 | if random_mirror and random.random() > 0.5: 30 | v = -v 31 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, v, 1, 0), 32 | RESAMPLE_MODE) 33 | 34 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 35 | assert -0.45 <= v <= 0.45 36 | if random_mirror and random.random() > 0.5: 37 | v = -v 38 | v = v * img.size[0] 39 | return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0), 40 | RESAMPLE_MODE) 41 | 42 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 43 | assert -0.45 <= v <= 0.45 44 | if random_mirror and random.random() > 0.5: 45 | v = -v 46 | v = v * img.size[1] 47 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v), 48 | RESAMPLE_MODE) 49 | 50 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 51 | assert 0 <= v 52 | if random.random() > 0.5: 53 | v = -v 54 | return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0), 55 | RESAMPLE_MODE) 56 | 57 | 58 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 59 | assert 0 <= v 60 | if random.random() > 0.5: 61 | v = -v 62 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v), 63 | RESAMPLE_MODE) 64 | 65 | def Rotate(img, v): # [-30, 30] 66 | assert -30 <= v <= 30 67 | if random_mirror and random.random() > 0.5: 68 | v = -v 69 | return img.rotate(v) 70 | 71 | def AutoContrast(img, _): 72 | return PIL.ImageOps.autocontrast(img,1) 73 | 74 | def Invert(img, _): 75 | return 
PIL.ImageOps.invert(img) 76 | 77 | def Equalize(img, _): 78 | return PIL.ImageOps.equalize(img) 79 | 80 | def Flip(img, _): # not from the paper 81 | return PIL.ImageOps.mirror(img) 82 | 83 | def Solarize(img, v): # [0, 256] 84 | assert 0 <= v <= 256 85 | return PIL.ImageOps.solarize(img, v) 86 | 87 | def SolarizeAdd(img, addition=0, threshold=128): 88 | img_np = np.array(img).astype(np.int) 89 | img_np = img_np + addition 90 | img_np = np.clip(img_np, 0, 255) 91 | img_np = img_np.astype(np.uint8) 92 | img = Image.fromarray(img_np) 93 | return PIL.ImageOps.solarize(img, threshold) 94 | 95 | def Posterize(img, v): # [4, 8] 96 | #assert 4 <= v <= 8 97 | v = int(v) 98 | return PIL.ImageOps.posterize(img, v) 99 | 100 | def Contrast(img, v): # [0.1,1.9] 101 | assert 0.1 <= v <= 1.9 102 | return PIL.ImageEnhance.Contrast(img).enhance(v) 103 | 104 | def Color(img, v): # [0.1,1.9] 105 | assert 0.1 <= v <= 1.9 106 | return PIL.ImageEnhance.Color(img).enhance(v) 107 | 108 | def Brightness(img, v): # [0.1,1.9] 109 | assert 0.1 <= v <= 1.9 110 | return PIL.ImageEnhance.Brightness(img).enhance(v) 111 | 112 | def Sharpness(img, v): # [0.1,1.9] 113 | assert 0.1 <= v <= 1.9 114 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 115 | 116 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 117 | # assert 0 <= v <= 20 118 | if v < 0: 119 | return img 120 | w, h = img.size 121 | x0 = np.random.uniform(w) 122 | y0 = np.random.uniform(h) 123 | 124 | x0 = int(max(0, x0 - v / 2.)) 125 | y0 = int(max(0, y0 - v / 2.)) 126 | x1 = min(w, x0 + v) 127 | y1 = min(h, y0 + v) 128 | 129 | xy = (x0, y0, x1, y1) 130 | color = (125, 123, 114) 131 | # color = (0, 0, 0) 132 | img = img.copy() 133 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 134 | return img 135 | 136 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 137 | assert 0.0 <= v <= 0.2 138 | if v <= 0.: 139 | return img 140 | 141 | v = v * img.size[0] 142 | return CutoutAbs(img, v) 143 | 144 | def TranslateYAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 145 | assert 0 <= v <= 10 146 | if random.random() > 0.5: 147 | v = -v 148 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v), 149 | resample=RESAMPLE_MODE) 150 | 151 | 152 | def TranslateXAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 153 | assert 0 <= v <= 10 154 | if random.random() > 0.5: 155 | v = -v 156 | return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0), 157 | resample=RESAMPLE_MODE) 158 | 159 | def Posterize2(img, v): # [0, 4] 160 | assert 0 <= v <= 4 161 | v = int(v) 162 | return PIL.ImageOps.posterize(img, v) 163 | 164 | def SamplePairing(imgs): # [0, 0.4] 165 | def f(img1, v): 166 | i = np.random.choice(len(imgs)) 167 | img2 = Image.fromarray(imgs[i]) 168 | return Image.blend(img1, img2, v) 169 | 170 | return f 171 | 172 | def augment_list(for_autoaug=True): # 16 oeprations and their ranges 173 | l = [ 174 | (ShearX, -0.3, 0.3), # 0 175 | (ShearY, -0.3, 0.3), # 1 176 | (TranslateX, -0.45, 0.45), # 2 177 | (TranslateY, -0.45, 0.45), # 3 178 | (Rotate, -30, 30), # 4 179 | (AutoContrast, 0, 1), # 5 180 | (Invert, 0, 1), # 6 181 | (Equalize, 0, 1), # 7 182 | (Solarize, 0, 256), # 8 183 | (Posterize, 4, 8), # 9 184 | (Contrast, 0.1, 1.9), # 10 185 | (Color, 0.1, 1.9), # 11 186 | (Brightness, 0.1, 1.9), # 12 187 | (Sharpness, 0.1, 1.9), # 13 188 | (Cutout, 0, 0.2), # 14 189 | # (SamplePairing(imgs), 0, 0.4), # 15 190 | ] 191 | if for_autoaug: 192 | l += [ 193 | (CutoutAbs, 0, 20), # compatible with auto-augment 194 | (Posterize2, 0, 4), # 9 195 | 
(TranslateXAbs, 0, 10), # 9 196 | (TranslateYAbs, 0, 10), # 9 197 | ] 198 | return l 199 | 200 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()} 201 | 202 | PARAMETER_MAX = 10 203 | 204 | 205 | def float_parameter(level, maxval): 206 | return float(level) * maxval / PARAMETER_MAX 207 | 208 | 209 | def int_parameter(level, maxval): 210 | return int(float_parameter(level, maxval)) 211 | 212 | def rand_augment_list(): # 16 oeprations and their ranges 213 | l = [ 214 | (AutoContrast, 0, 1), 215 | (Equalize, 0, 1), 216 | (Invert, 0, 1), 217 | (Rotate, 0, 30), 218 | (Posterize, 0, 4), 219 | (Solarize, 0, 256), 220 | (SolarizeAdd, 0, 110), 221 | (Color, 0.1, 1.9), 222 | (Contrast, 0.1, 1.9), 223 | (Brightness, 0.1, 1.9), 224 | (Sharpness, 0.1, 1.9), 225 | (ShearX, 0., 0.3), 226 | (ShearY, 0., 0.3), 227 | (CutoutAbs, 0, 40), 228 | (TranslateXabs, 0., 100), 229 | (TranslateYabs, 0., 100), 230 | ] 231 | 232 | return l 233 | 234 | def autoaug2fastaa(f): 235 | def autoaug(): 236 | mapper = defaultdict(lambda: lambda x: x) 237 | mapper.update({ 238 | 'ShearX': lambda x: float_parameter(x, 0.3), 239 | 'ShearY': lambda x: float_parameter(x, 0.3), 240 | 'TranslateX': lambda x: int_parameter(x, 10), 241 | 'TranslateY': lambda x: int_parameter(x, 10), 242 | 'Rotate': lambda x: int_parameter(x, 30), 243 | 'Solarize': lambda x: 256 - int_parameter(x, 256), 244 | 'Posterize2': lambda x: 4 - int_parameter(x, 4), 245 | 'Contrast': lambda x: float_parameter(x, 1.8) + .1, 246 | 'Color': lambda x: float_parameter(x, 1.8) + .1, 247 | 'Brightness': lambda x: float_parameter(x, 1.8) + .1, 248 | 'Sharpness': lambda x: float_parameter(x, 1.8) + .1, 249 | 'CutoutAbs': lambda x: int_parameter(x, 20) 250 | }) 251 | 252 | def low_high(name, prev_value): 253 | _, low, high = get_augment(name) 254 | return float(prev_value - low) / (high - low) 255 | 256 | policies = f() 257 | new_policies = [] 258 | for policy in policies: 259 | new_policies.append([(name, pr, low_high(name, mapper[name](level))) for name, pr, level in policy]) 260 | return new_policies 261 | 262 | return autoaug 263 | 264 | # @autoaug2fastaa 265 | def autoaug_imagenet_policies(): 266 | return [ 267 | # [('Posterize2', 0.4, 8), ('Rotate', 0.6, 9)], 268 | [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], 269 | [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], 270 | [('Posterize2', 0.6, 7), ('Posterize2', 0.6, 6)], 271 | [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], 272 | # [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], 273 | [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], 274 | [('Posterize2', 0.8, 5), ('Equalize', 1.0, 2)], 275 | # [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], 276 | [('Equalize', 0.6, 8), ('Posterize2', 0.4, 6)], 277 | # [('Rotate', 0.8, 8), ('Color', 0.4, 0)], 278 | # [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], 279 | [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], 280 | [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], 281 | [('Color', 0.6, 4), ('Contrast', 1.0, 8)], 282 | # [('Rotate', 0.8, 8), ('Color', 1.0, 0)], 283 | [('Color', 0.8, 8), ('Solarize', 0.8, 7)], 284 | [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)], 285 | # [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)], 286 | [('Color', 0.4, 0), ('Equalize', 0.6, 3)], 287 | [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], 288 | [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], 289 | [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], 290 | [('Color', 0.6, 4), ('Contrast', 1.0, 8)], 291 | [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], 292 | ] 293 | 294 | class ToPIL(object): 295 | """Convert image 
from ndarray format to PIL 296 | """ 297 | def __call__(self, img): 298 | x = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB)) 299 | return x 300 | 301 | class ToNDArray(object): 302 | def __call__(self, img): 303 | x = cv2.cvtColor(np.asarray(img),cv2.COLOR_RGB2BGR) 304 | return x 305 | 306 | class RandAugment(object): 307 | def __init__(self, n, m): 308 | self.n = n 309 | self.m = m 310 | self.augment_list = rand_augment_list() 311 | self.topil = ToPIL() 312 | 313 | def __call__(self, img): 314 | img = self.topil(img) 315 | ops = random.choices(self.augment_list, k=self.n) 316 | for op, minval, maxval in ops: 317 | if random.random() > random.uniform(0.2, 0.8): 318 | continue 319 | val = (float(self.m) / 30) * float(maxval - minval) + minval 320 | img = op(img, val) 321 | return img 322 | 323 | def get_augment(name): 324 | return augment_dict[name] 325 | 326 | 327 | def apply_augment(img, name, level): 328 | augment_fn, low, high = get_augment(name) 329 | return augment_fn(img.copy(), level * (high - low) + low) 330 | class PILGaussianBlur(ImageFilter.Filter): 331 | name = "GaussianBlur" 332 | def __init__(self, radius=2, bounds=None): 333 | self.radius = radius 334 | self.bounds = bounds 335 | def filter(self, image): 336 | if self.bounds: 337 | clips = image.crop(self.bounds).gaussian_blur(self.radius) 338 | image.paste(clips, self.bounds) 339 | return image 340 | else: 341 | return image.gaussian_blur(self.radius) 342 | class GaussianBlur(object): 343 | def __init__(self, radius=2 ): 344 | self.GaussianBlur=PILGaussianBlur(radius) 345 | def __call__(self, img): 346 | img = img.filter( self.GaussianBlur ) 347 | return img 348 | class AugmentationBlock(object): 349 | r""" 350 | AutoAugment Block 351 | 352 | Example 353 | ------- 354 | >>> from autogluon.utils.augment import AugmentationBlock, autoaug_imagenet_policies 355 | >>> aa_transform = AugmentationBlock(autoaug_imagenet_policies()) 356 | """ 357 | def __init__(self, policies): 358 | """ 359 | plicies : list of (name, pr, level) 360 | """ 361 | super().__init__() 362 | self.policies = policies() 363 | self.topil = ToPIL() 364 | self.tond = ToNDArray() 365 | self.Gaussian_blue = PILGaussianBlur(2) 366 | self.policy = [GaussianBlur(),transforms.ColorJitter( 0.1026, 0.0935, 0.8386, 0.1592 ), 367 | transforms.Grayscale(num_output_channels=3)] 368 | # self.colorAug = transforms.RandomApply([transforms.ColorJitter( 0.1026, 0.0935, 0.8386, 0.1592 )], p=0.5) 369 | def __call__(self, img): 370 | img = self.topil(img) 371 | trans = random.choice(self.policy) 372 | if random.random() >= 0.5: 373 | img = trans( img ) 374 | img = self.tond(img) 375 | return img 376 | 377 | 378 | # augBlock = AugmentationBlock( autoaug_imagenet_policies ) 379 | # plt.figure() 380 | # for i in range(20): 381 | # catAug = augBlock( cat ) 382 | # plt.subplot(4,5,i+1) 383 | # plt.imshow(catAug) 384 | 385 | # plt.show() 386 | # im_path = os.path.join('D:/testPy/839_482127.jpg') 387 | # img = Image.open( im_path ).convert('RGB') 388 | 389 | # factor = random.uniform(-0.4, 0.4) 390 | # imgb = T.adjust_brightness(img, 1 + factor) 391 | 392 | # imgc = transforms.ColorJitter( 0.4,0.4,0.4,0.4 )(img) 393 | 394 | # imgd = transforms.RandomHorizontalFlip()(img) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__init__.py 
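The augmentation helpers defined in utils/ImgTransforms.py above are easiest to follow from a small usage sketch. The snippet below is illustrative only, mirroring the commented-out example at the end of that file; the image path is hypothetical. Note that AugmentationBlock takes the policy *function* itself (not its result) and returns a BGR ndarray, while RandAugment returns a PIL image.

```python
# Illustrative sketch (assumed usage) of the augmentation utilities in utils/ImgTransforms.py.
import cv2
from utils.ImgTransforms import AugmentationBlock, RandAugment, autoaug_imagenet_policies

img = cv2.imread('path/to/image.jpg')  # hypothetical path; OpenCV yields a BGR ndarray

# AugmentationBlock: with probability 0.5 applies one of {GaussianBlur, ColorJitter, Grayscale}.
aug_block = AugmentationBlock(autoaug_imagenet_policies)  # pass the policy function, not its result
img_aug = aug_block(img)                                  # BGR ndarray in, BGR ndarray out

# RandAugment: applies up to n randomly chosen ops at magnitude m (on the 0-30 scale used in __call__).
rand_aug = RandAugment(n=2, m=9)
img_pil = rand_aug(img)                                   # returns a PIL.Image, not an ndarray
```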
-------------------------------------------------------------------------------- /utils/__pycache__/ImgTransforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/ImgTransforms.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/ImgTransforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/ImgTransforms.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/OCRAttention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/OCRAttention.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/attention.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/criterion.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/criterion.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/criterion.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/criterion.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/distributed.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/distributed.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/distributed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/distributed.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/encoding.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/encoding.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/encoding.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/encoding.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/lovasz_losses.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/lovasz_losses.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/lovasz_losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/lovasz_losses.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/miou.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/miou.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/miou.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/miou.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_store.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/model_store.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/pyt_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/pyt_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tjpulkl/CDGNet/9daf7ddee6045c151c90a2e300946ea5f5717591/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/attention.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import math 5 | import torch.nn as nn 6 | from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \ 7 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding 8 | from torch.nn import functional as F 9 | from torch.autograd import Variable 10 | import functools 11 | 12 | from torch.nn import BatchNorm2d as BatchNorm2d 13 | from torch.nn import BatchNorm1d as BatchNorm1d 14 | 15 | def conv2d(in_channel, out_channel, kernel_size): 16 | layers = [ 17 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=kernel_size // 2, bias=False), 18 | BatchNorm2d(out_channel), 19 | nn.ReLU(), 20 | ] 21 | 22 | return nn.Sequential(*layers) 23 | 24 | def conv1d(in_channel, out_channel, kernel_size): 25 | layers = [ 26 | nn.Conv1d(in_channel, 
out_channel, kernel_size, padding=kernel_size // 2, bias=False), 27 | BatchNorm1d(out_channel), 28 | nn.ReLU(), 29 | ] 30 | 31 | return nn.Sequential(*layers) 32 | 33 | 34 | class CDGAttention(nn.Module): 35 | def __init__(self, feat_in=512, feat_out=256, num_classes=20, size=[384//16,384//16], kernel_size =7 ): 36 | super(CDGAttention, self).__init__() 37 | h,w = size[0],size[1] 38 | kSize = kernel_size 39 | self.gamma = Parameter(torch.ones(1)) 40 | self.beta = Parameter(torch.ones(1)) 41 | self.rowpool = nn.AdaptiveAvgPool2d((h,1)) 42 | self.colpool = nn.AdaptiveAvgPool2d((1,w)) 43 | self.conv_hgt1 =conv1d(feat_in,feat_out,3) 44 | self.conv_hgt2 =conv1d(feat_in,feat_out,3) 45 | self.conv_hwPred1 = nn.Sequential( 46 | nn.Conv1d(feat_out,num_classes,3,stride=1,padding=1,bias=True), 47 | nn.Sigmoid(), 48 | ) 49 | self.conv_hwPred2 = nn.Sequential( 50 | nn.Conv1d(feat_out,num_classes,3,stride=1,padding=1,bias=True), 51 | nn.Sigmoid(), 52 | ) 53 | self.conv_upDim1 = nn.Sequential( 54 | nn.Conv1d(feat_out,feat_in,kSize,stride=1,padding=kSize//2,bias=True), 55 | nn.Sigmoid(), 56 | ) 57 | self.conv_upDim2 = nn.Sequential( 58 | nn.Conv1d(feat_out,feat_in,kSize,stride=1,padding=kSize//2,bias=True), 59 | nn.Sigmoid(), 60 | ) 61 | self.cmbFea = conv2d( feat_in*3,feat_in,3) 62 | def forward(self,fea): 63 | n,c,h,w = fea.size() 64 | fea_h = self.rowpool(fea).squeeze(3) #n,c,h 65 | fea_w = self.colpool(fea).squeeze(2) #n,c,w 66 | fea_h = self.conv_hgt1(fea_h) #n,c,h 67 | fea_w = self.conv_hgt2(fea_w) 68 | #=========================================================== 69 | fea_hp = self.conv_hwPred1(fea_h) #n,class_num,h 70 | fea_wp = self.conv_hwPred2(fea_w) #n,class_num,w 71 | #=========================================================== 72 | fea_h = self.conv_upDim1(fea_h) 73 | fea_w = self.conv_upDim2(fea_w) 74 | fea_hup = fea_h.unsqueeze(3) 75 | fea_wup = fea_w.unsqueeze(2) 76 | fea_hup = F.interpolate( fea_hup, (h,w), mode='bilinear', align_corners= True ) #n,c,h,w 77 | fea_wup = F.interpolate( fea_wup, (h,w), mode='bilinear', align_corners= True ) #n,c,h,w 78 | fea_hw = self.beta*fea_wup + self.gamma*fea_hup 79 | fea_hw_aug = fea * fea_hw 80 | #=============================================================== 81 | fea = torch.cat([fea, fea_hw_aug, fea_hw], dim = 1 ) 82 | fea = self.cmbFea( fea ) 83 | return fea, fea_hp, fea_wp 84 | 85 | class C2CAttention(nn.Module): 86 | def __init__(self, in_fea, out_fea, num_class ): 87 | super(C2CAttention, self).__init__() 88 | self.in_fea = in_fea 89 | self.out_fea = out_fea 90 | self.num_class = num_class 91 | self.gamma = Parameter(torch.ones(1)) 92 | self.beta = Parameter(torch.ones(1)) 93 | self.bias1 = Parameter( torch.FloatTensor( num_class, num_class )) 94 | self.bias2 = Parameter( torch.FloatTensor( num_class, num_class )) 95 | self.convDwn1 = conv2d( in_fea, out_fea, 1 ) 96 | self.convDwn2 = conv2d( in_fea, out_fea, 1 ) 97 | self.convUp1 = nn.Sequential( 98 | nn.AdaptiveAvgPool2d((1,1)), 99 | conv2d( num_class, out_fea, 1 ), 100 | nn.Conv2d(out_fea,in_fea,1,stride=1,padding=0,bias=True), 101 | ) 102 | self.toClass = nn.Sequential( 103 | nn.Conv2d( out_fea, num_class, 1, stride=1, padding = 0, bias = True ), 104 | ) 105 | self.convUp2 = nn.Sequential( 106 | nn.AdaptiveAvgPool2d((1,1)), 107 | conv2d( num_class, out_fea, 1 ), 108 | nn.Conv2d(out_fea,in_fea,1,stride=1,padding=0,bias=True), 109 | ) 110 | self.fea_fuse = conv2d( in_fea*2, in_fea, 1 ) 111 | self.sigmoid = nn.Sigmoid() 112 | self.reset_parameters() 113 | def reset_parameters(self): 114 | 
torch.nn.init.xavier_uniform_(self.bias1) 115 | torch.nn.init.xavier_uniform_(self.bias2) 116 | def forward(self,input_fea): 117 | n, c, h, w = input_fea.size() 118 | fea_ha = self.convDwn1( input_fea ) 119 | fea_wa = self.convDwn2( input_fea ) 120 | cls_ha = self.toClass( fea_ha ) 121 | cls_ha = F.softmax(cls_ha, dim=1) 122 | cls_wa = self.toClass( fea_wa ) 123 | cls_wa = F.softmax(cls_wa, dim=1) 124 | cls_ha = cls_ha.view( n, self.num_class, h*w ) 125 | cls_wa = cls_wa.view( n, self.num_class, h*w ) 126 | cch = F.relu(torch.matmul( cls_ha, cls_ha.transpose( 1, 2 ) )) #class*class 127 | cch = cch 128 | cch = self.sigmoid( cch ) + self.bias1 129 | ccw = F.relu(torch.matmul( cls_wa, cls_wa. transpose( 1, 2 ) )) #class*class 130 | ccw = ccw 131 | ccw = self.sigmoid( ccw )+ self.bias2 132 | cls_ha = torch.matmul( cls_ha.transpose(1,2), cch.transpose(1,2) ) 133 | cls_ha = cls_ha.transpose( 1,2).contiguous().view( n, self.num_class, h, w ) 134 | cls_wa = torch.matmul( cls_wa.transpose(1,2), ccw.transpose(1,2) ) 135 | cls_wa = cls_wa.transpose(1,2).contiguous().view( n, self.num_class, h, w ) 136 | fea_ha = self.convUp1( cls_ha ) 137 | fea_wa = self.convUp2( cls_wa ) 138 | fea_hwa = self.gamma*fea_ha + self.beta*fea_wa 139 | fea_hwa_aug = input_fea * fea_hwa #* 140 | fea_fuse = torch.cat( [fea_hwa_aug, input_fea], dim = 1 ) 141 | fea_fuse = self.fea_fuse( fea_fuse ) 142 | return fea_fuse, cch, ccw 143 | 144 | class StatisticAttention(nn.Module): 145 | def __init__(self,fea_in, fea_out, num_classes ): 146 | super(StatisticAttention, self).__init__() 147 | # self.gamma = Parameter(torch.ones(1)) 148 | self.conv_1 = conv2d( fea_in, fea_in//2, 1) #kernel size 3 149 | self.conv_2 = conv2d( fea_in//2, num_classes, 3 ) 150 | self.conv_pred = nn.Sequential( 151 | nn.Conv2d( num_classes, 1, 3, stride=1, padding=1, bias=True), #kernel size 1 152 | nn.Sigmoid() 153 | ) 154 | self.conv_fuse = conv2d( fea_in * 2, fea_out, 3 ) 155 | def forward(self,fea): 156 | fea_att = self.conv_1( fea ) 157 | fea_cls = self.conv_2( fea_att ) 158 | fea_stat = self.conv_pred( fea_cls ) 159 | fea_aug = fea * ( 1 - fea_stat ) 160 | fea_fuse = torch.cat( [fea, fea_aug], dim = 1 ) 161 | fea_res = self.conv_fuse( fea_fuse ) 162 | return fea_res, fea_stat 163 | 164 | class PSPModule(nn.Module): 165 | # (1, 2, 3, 6) 166 | def __init__(self, sizes=(1, 3, 7, 11), dimension=2): 167 | super(PSPModule, self).__init__() 168 | self.stages = nn.ModuleList([self._make_stage(size, dimension) for size in sizes]) 169 | 170 | def _make_stage(self, size, dimension=2): 171 | if dimension == 1: 172 | prior = nn.AdaptiveAvgPool1d(output_size=size) 173 | elif dimension == 2: 174 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 175 | elif dimension == 3: 176 | prior = nn.AdaptiveAvgPool3d(output_size=(size, size, size)) 177 | return prior 178 | 179 | def forward(self, feats): 180 | n, c, _, _ = feats.size() 181 | priors = [stage(feats).view(n, c, -1) for stage in self.stages] 182 | center = torch.cat(priors, -1) 183 | return center 184 | 185 | class PCM(Module): 186 | def __init__(self, feat_channels=[256,1024]): 187 | super().__init__() 188 | feat1, feat2 = feat_channels 189 | self.conv_x2 = conv2d( feat1, 256, 1 ) 190 | self.conv_x4 = conv2d( feat2, 256, 1 ) 191 | self.conv_cmb = conv2d( 256+256+3, 256, 1 ) 192 | self.softmax = Softmax(dim=-1) 193 | self.psp = PSPModule() 194 | self.addCAM = conv2d( 512, 256, 1) 195 | def forward(self, xOrg, stg2, stg4, cam ): 196 | n,c,h,w = stg2.size() 197 | stg2 = self.conv_x2( stg2 ) 198 | stg4 = 
self.conv_x4( stg4 ) 199 | stg4 = F.interpolate( stg4, (h,w), mode='bilinear', align_corners= True) 200 | stg0 = F.interpolate( xOrg, (h,w), mode='bilinear', align_corners= True) 201 | stgSum = torch.cat([stg0,stg2,stg4],dim=1) 202 | stgSum = self.conv_cmb( stgSum ) 203 | stgPool = self.psp( stgSum ) #(N,c,s) 204 | stgSum = stgSum.view( n, -1, h*w ).transpose(1,2) #(N,h*w,c) 205 | stg_aff = torch.matmul( stgSum, stgPool ) #(N,h*w,c)*(N,c,s)=(N,h*w,s) 206 | stg_aff = ( c ** -0.5 ) * stg_aff 207 | stg_aff = F.softmax( stg_aff, dim = -1 ) #(N,h*w,s) 208 | with torch.no_grad(): 209 | cam_d = F.relu( cam.detach() ) 210 | cam_d = F.interpolate( cam_d, (h,w), mode='bilinear', align_corners= True) 211 | cam_pool = self.psp( cam_d ).transpose(1,2) #(N,s,c) 212 | cam_rv = torch.matmul( stg_aff, cam_pool ).transpose(1,2) 213 | cam_rv=cam_rv.view(n, -1, h, w ) 214 | out = torch.cat([cam, cam_rv], dim=1 ) 215 | out = self.addCAM( out ) 216 | return out 217 | 218 | class GCM(Module): 219 | def __init__(self, feat_channels=512): 220 | super().__init__() 221 | 222 | chHig = feat_channels 223 | self.gamma = Parameter(torch.ones(1)) 224 | self.higC = conv2d( chHig, 256, 3 ) 225 | self.coe = nn.Sequential( 226 | conv2d( 256, 256, 3 ), 227 | nn.AdaptiveAvgPool2d((1,1)) 228 | ) 229 | 230 | def forward(self, fea ): 231 | n,_,h, w = fea.size() 232 | stgHig = self.higC( fea ) 233 | coeHig = self.coe( stgHig ) 234 | sim = stgHig - coeHig 235 | # print( sim.size() ) 236 | simDis = torch.norm( sim, 2, 1, keepdim = True ) 237 | # print( simDis.size() ) 238 | simDimMin = simDis.view( n, -1 ) 239 | simDisMin = torch.min( simDimMin, 1, keepdim = True )[0] 240 | # print( simDisMin.size() ) 241 | simDis = simDis.view( n, -1 ) 242 | weightHig = torch.exp( -( simDis - simDisMin ) / 5 ) 243 | weightHig = weightHig.view(n, -1, h, w ) 244 | upFea = F.interpolate( coeHig, (h,w), mode='bilinear', align_corners=True) 245 | upFea = upFea * weightHig 246 | stgHig = stgHig + self.gamma * upFea 247 | 248 | return weightHig, stgHig 249 | 250 | class LCM(Module): 251 | def __init__(self, feat_channels=[256, 256, 512]): 252 | super().__init__() 253 | 254 | chHig, chLow1, chLow2 = feat_channels 255 | self.beta = Parameter(torch.ones(1)) 256 | self.lowC1 = conv2d( chLow1, 48,3) 257 | self.lowC2 = conv2d( chLow2,128,3) 258 | self.cat1 = conv2d( 256+48, 256, 1 ) 259 | self.cat2 = conv2d( 256+128, 256, 1 ) 260 | 261 | def forward(self, feaHig, feaCeo, feaLow1, feaLow2 ): 262 | n,c,h,w = feaLow1.size() 263 | stgHig = F.interpolate( feaHig, (h,w), mode='bilinear', align_corners=True) 264 | weightLow = F.interpolate( feaCeo, (h,w), mode='bilinear', align_corners=True ) 265 | coeLow = 1 - weightLow 266 | stgLow1 = self.lowC1(feaLow1) 267 | stgLow2 = self.lowC2(feaLow2) 268 | stgLow2 = F.interpolate( stgLow2, (h,w), mode='bilinear', align_corners=True ) 269 | 270 | stgLow1 = self.beta * coeLow * stgLow1 271 | stgCat = torch.cat( [stgHig, stgLow1], dim = 1 ) 272 | stgCat = self.cat1( stgCat ) 273 | stgLow2 = self.beta * coeLow * stgLow2 274 | stgCat = torch.cat( [stgCat, stgLow2], dim = 1 ) 275 | stgCat = self.cat2( stgCat ) 276 | return stgCat 277 | -------------------------------------------------------------------------------- /utils/criterion.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import utils.lovasz_losses as L 5 | from torch.nn import functional as F 6 | from torch.nn import Parameter 7 | from .loss import OhemCrossEntropy2d 8 
| from dataset.target_generation import generate_edge 9 | 10 | 11 | # class ConsistencyLoss(nn.Module): 12 | # def __init__(self, ignore_index=255): 13 | # super(ConsistencyLoss, self).__init__() 14 | # self.ignore_index=ignore_index 15 | 16 | # def forward(self, parsing, edge, label): 17 | # parsing_pre = torch.argmax(parsing, dim=1) 18 | # parsing_pre[label==self.ignore_index]=self.ignore_index 19 | # generated_edge = generate_edge(parsing_pre) 20 | # edge_pre = torch.argmax(edge, dim=1) 21 | # v_generate_edge = generated_edge[label!=255] 22 | # v_edge_pre = edge_pre[label!=255] 23 | # v_edge_pre = v_edge_pre.type(torch.cuda.FloatTensor) 24 | # positive_union = (v_generate_edge==1)&(v_edge_pre==1) # only the positive values count 25 | # return F.smooth_l1_loss(v_generate_edge[positive_union].squeeze(0), v_edge_pre[positive_union].squeeze(0)) 26 | 27 | class CriterionAll(nn.Module): 28 | def __init__(self, ignore_index=255): 29 | super(CriterionAll, self).__init__() 30 | self.ignore_index = ignore_index 31 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index) 32 | # self.ConsEdge = ConsistencyLoss(ignore_index=ignore_index) 33 | self.cos_sim = torch.nn.CosineSimilarity(dim=-1) 34 | # self.l2Loss = torch.nn.MSELoss(reduction='mean') 35 | self.l2loss = torch.nn.MSELoss() 36 | 37 | def parsing_loss(self, preds, target, hwgt ): 38 | h, w = target[0].size(1), target[0].size(2) 39 | 40 | pos_num = torch.sum(target[1] == 1, dtype=torch.float) 41 | neg_num = torch.sum(target[1] == 0, dtype=torch.float) 42 | 43 | weight_pos = neg_num / (pos_num + neg_num) 44 | weight_neg = pos_num / (pos_num + neg_num) 45 | weights = torch.tensor([weight_neg, weight_pos]) 46 | loss = 0 47 | 48 | # loss for parsing 49 | pws = [0.4,1,1,1] 50 | preds_parsing = preds[0] 51 | ind = 0 52 | tmpLoss = 0 53 | if isinstance(preds_parsing, list): 54 | for pred_parsing in preds_parsing: 55 | scale_pred = F.interpolate(input=pred_parsing, size=(h, w), 56 | mode='bilinear', align_corners=True) 57 | tmpLoss = self.criterion(scale_pred, target[0]) 58 | scale_pred = F.softmax( scale_pred, dim = 1 ) 59 | tmpLoss += L.lovasz_softmax( scale_pred, target[0], ignore = self.ignore_index ) 60 | tmpLoss *= pws[ind] 61 | loss += tmpLoss 62 | ind+=1 63 | else: 64 | scale_pred = F.interpolate(input=preds_parsing, size=(h, w), 65 | mode='bilinear', align_corners=True) 66 | loss += self.criterion(scale_pred, target[0]) 67 | # scale_pred = F.softmax( scale_pred, dim = 1 ) 68 | # loss += L.lovasz_softmax( scale_pred, target[0], ignore = self.ignore_index ) 69 | 70 | # loss for edge 71 | tmpLoss = 0 72 | preds_edge = preds[1] 73 | if isinstance(preds_edge, list): 74 | for pred_edge in preds_edge: 75 | scale_pred = F.interpolate(input=pred_edge, size=(h, w), 76 | mode='bilinear', align_corners=True) 77 | tmpLoss += F.cross_entropy(scale_pred, target[1], 78 | weights.cuda(), ignore_index=self.ignore_index) 79 | else: 80 | scale_pred = F.interpolate(input=preds_edge, size=(h, w), 81 | mode='bilinear', align_corners=True) 82 | tmpLoss += F.cross_entropy(scale_pred, target[1], 83 | weights.cuda(), ignore_index=self.ignore_index) 84 | loss += tmpLoss 85 | # loss for height and width attention 86 | #loss for hwattention 87 | hwLoss = 0 88 | hgt = hwgt[0] 89 | wgt = hwgt[1] 90 | n,c,h = hgt.size() 91 | w = wgt.size()[2] 92 | hpred = preds[2][0] #fea_h... 
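# --- added commentary (not part of the original source) ---
# preds[2] carries the per-class height/width distribution predictions
# (fea_hp, fea_wp) produced by CDGAttention; each branch ends in a Sigmoid,
# so values lie in [0, 1]. The height prediction (n, num_classes, h_pred) is
# resized below to the ground-truth length h and penalised against hgt with a
# mean-squared error; the width branch is handled the same way against wgt.
# The two terms are summed and scaled by 45 before being added to the total
# loss. hgt/wgt are expected to have shapes (n, num_classes, h) and
# (n, num_classes, w), matching `n, c, h = hgt.size()` above.
# -----------------------------------------------------------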
93 | scale_hpred = hpred.unsqueeze(3) #n,c,h,1 94 | scale_hpred = F.interpolate(input=scale_hpred, size=(h,1),mode='bilinear', align_corners=True) 95 | scale_hpred = scale_hpred.squeeze(3) #n,c,h 96 | # hgt = hgt[:,1:,:] 97 | # scale_hpred=scale_hpred[:,1:,:] 98 | hloss = torch.mean( ( hgt - scale_hpred ) * ( hgt - scale_hpred ) ) 99 | wpred = preds[2][1] #fea_w... 100 | scale_wpred = wpred.unsqueeze(2) #n,c,1,w 101 | scale_wpred = F.interpolate(input=scale_wpred, size=(1,w),mode='bilinear', align_corners=True) 102 | scale_wpred = scale_wpred.squeeze(2) #n,c,w 103 | # wgt=wgt[:,1:,:] 104 | # scale_wpred = scale_wpred[:,1:,:] 105 | wloss = torch.mean( ( wgt - scale_wpred ) * ( wgt - scale_wpred ) ) 106 | hwLoss = ( hloss + wloss ) * 45 107 | loss += hwLoss 108 | return loss 109 | 110 | def forward(self, preds, target, hwgt ): 111 | 112 | loss = self.parsing_loss(preds, target, hwgt ) 113 | return loss 114 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def all_gather(data): 45 | world_size = get_world_size() 46 | 47 | if world_size == 1: 48 | return [data] 49 | 50 | buffer = pickle.dumps(data) 51 | storage = torch.ByteStorage.from_buffer(buffer) 52 | tensor = torch.ByteTensor(storage).to('cuda') 53 | 54 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 55 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 56 | dist.all_gather(size_list, local_size) 57 | size_list = [int(size.item()) for size in size_list] 58 | max_size = max(size_list) 59 | 60 | tensor_list = [] 61 | for _ in size_list: 62 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 63 | 64 | if local_size != max_size: 65 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 66 | tensor = torch.cat((tensor, padding), 0) 67 | 68 | dist.all_gather(tensor_list, tensor) 69 | 70 | data_list = [] 71 | 72 | for size, tensor in zip(size_list, tensor_list): 73 | buffer = tensor.cpu().numpy().tobytes()[:size] 74 | data_list.append(pickle.loads(buffer)) 75 | 76 | return data_list 77 | 78 | 79 | def reduce_loss_dict(loss_dict): 80 | world_size = get_world_size() 81 | 82 | if world_size < 2: 83 | return loss_dict 84 | 85 | with torch.no_grad(): 86 | keys = [] 87 | losses = [] 88 | 89 | for k in sorted(loss_dict.keys()): 90 | keys.append(k) 91 | losses.append(loss_dict[k]) 92 | 93 | losses = torch.stack(losses, 0) 94 | dist.reduce(losses, dst=0) 95 | 96 | if dist.get_rank() == 0: 97 | losses /= world_size 98 | 99 | reduced_losses = {k: v for k, v in zip(keys, losses)} 100 | 101 | return reduced_losses 102 | 103 | 104 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
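# Added note: train.py imports DistributedSampler from torch.utils.data.distributed
# directly, so the local copy below does not appear to be used by the training script.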
105 | # Code is copy-pasted exactly as in torch.utils.data.distributed. 106 | # FIXME remove this once c10d fixes the bug it has 107 | 108 | 109 | class DistributedSampler(Sampler): 110 | """Sampler that restricts data loading to a subset of the dataset. 111 | It is especially useful in conjunction with 112 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 113 | process can pass a DistributedSampler instance as a DataLoader sampler, 114 | and load a subset of the original dataset that is exclusive to it. 115 | .. note:: 116 | Dataset is assumed to be of constant size. 117 | Arguments: 118 | dataset: Dataset used for sampling. 119 | num_replicas (optional): Number of processes participating in 120 | distributed training. 121 | rank (optional): Rank of the current process within num_replicas. 122 | """ 123 | 124 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 125 | if num_replicas is None: 126 | if not dist.is_available(): 127 | raise RuntimeError("Requires distributed package to be available") 128 | num_replicas = dist.get_world_size() 129 | if rank is None: 130 | if not dist.is_available(): 131 | raise RuntimeError("Requires distributed package to be available") 132 | rank = dist.get_rank() 133 | self.dataset = dataset 134 | self.num_replicas = num_replicas 135 | self.rank = rank 136 | self.epoch = 0 137 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 138 | self.total_size = self.num_samples * self.num_replicas 139 | self.shuffle = shuffle 140 | 141 | def __iter__(self): 142 | if self.shuffle: 143 | # deterministically shuffle based on epoch 144 | g = torch.Generator() 145 | g.manual_seed(self.epoch) 146 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 147 | else: 148 | indices = torch.arange(len(self.dataset)).tolist() 149 | 150 | # add extra samples to make it evenly divisible 151 | indices += indices[: (self.total_size - len(indices))] 152 | assert len(indices) == self.total_size 153 | 154 | # subsample 155 | offset = self.num_samples * self.rank 156 | indices = indices[offset : offset + self.num_samples] 157 | assert len(indices) == self.num_samples 158 | 159 | return iter(indices) 160 | 161 | def __len__(self): 162 | return self.num_samples 163 | 164 | def set_epoch(self, epoch): 165 | self.epoch = epoch 166 | -------------------------------------------------------------------------------- /utils/encoding.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | """Encoding Data Parallel""" 12 | import threading 13 | import functools 14 | import torch 15 | from torch.autograd import Variable, Function 16 | import torch.cuda.comm as comm 17 | from torch.nn.parallel.data_parallel import DataParallel 18 | from torch.nn.parallel.parallel_apply import get_a_var 19 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 20 | 21 | torch_ver = torch.__version__[:3] 22 | 23 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 24 | 'patch_replication_callback'] 25 | 26 | def 
allreduce(*inputs): 27 | """Cross GPU all reduce autograd operation for calculate mean and 28 | variance in SyncBN. 29 | """ 30 | return AllReduce.apply(*inputs) 31 | 32 | class AllReduce(Function): 33 | @staticmethod 34 | def forward(ctx, num_inputs, *inputs): 35 | ctx.num_inputs = num_inputs 36 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] 37 | inputs = [inputs[i:i + num_inputs] 38 | for i in range(0, len(inputs), num_inputs)] 39 | # sort before reduce sum 40 | inputs = sorted(inputs, key=lambda i: i[0].get_device()) 41 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 42 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 43 | return tuple([t for tensors in outputs for t in tensors]) 44 | 45 | @staticmethod 46 | def backward(ctx, *inputs): 47 | inputs = [i.data for i in inputs] 48 | inputs = [inputs[i:i + ctx.num_inputs] 49 | for i in range(0, len(inputs), ctx.num_inputs)] 50 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 51 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 52 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) 53 | 54 | 55 | class Reduce(Function): 56 | @staticmethod 57 | def forward(ctx, *inputs): 58 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 59 | inputs = sorted(inputs, key=lambda i: i.get_device()) 60 | return comm.reduce_add(inputs) 61 | 62 | @staticmethod 63 | def backward(ctx, gradOutput): 64 | return Broadcast.apply(ctx.target_gpus, gradOutput) 65 | 66 | 67 | class DataParallelModel(DataParallel): 68 | """Implements data parallelism at the module level. 69 | 70 | This container parallelizes the application of the given module by 71 | splitting the input across the specified devices by chunking in the 72 | batch dimension. 73 | In the forward pass, the module is replicated on each device, 74 | and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module. 75 | Note that the outputs are not gathered, please use compatible 76 | :class:`encoding.parallel.DataParallelCriterion`. 77 | 78 | The batch size should be larger than the number of GPUs used. It should 79 | also be an integer multiple of the number of GPUs so that each chunk is 80 | the same size (so that each GPU processes the same number of samples). 81 | 82 | Args: 83 | module: module to be parallelized 84 | device_ids: CUDA devices (default: all devices) 85 | 86 | Reference: 87 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 88 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 89 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 90 | 91 | Example:: 92 | 93 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 94 | >>> y = net(x) 95 | """ 96 | def gather(self, outputs, output_device): 97 | return outputs 98 | 99 | def replicate(self, module, device_ids): 100 | modules = super(DataParallelModel, self).replicate(module, device_ids) 101 | execute_replication_callbacks(modules) 102 | return modules 103 | 104 | 105 | class DataParallelCriterion(DataParallel): 106 | """ 107 | Calculate loss in multiple-GPUs, which balance the memory usage for 108 | Semantic Segmentation. 109 | 110 | The targets are splitted across the specified devices by chunking in 111 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. 
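    Each replica computes its loss against the targets scattered to its own device;
    the per-GPU losses are then reduce-summed and averaged, so the full-resolution
    outputs never need to be gathered onto a single GPU.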
112 | 113 | Reference: 114 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 115 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 116 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 117 | 118 | Example:: 119 | 120 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 121 | >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2]) 122 | >>> y = net(x) 123 | >>> loss = criterion(y, target) 124 | """ 125 | def forward(self, inputs, *targets, **kwargs): 126 | # input should be already scatterd 127 | # scattering the targets instead 128 | if not self.device_ids: 129 | return self.module(inputs, *targets, **kwargs) 130 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 131 | if len(self.device_ids) == 1: 132 | return self.module(inputs, *targets[0], **kwargs[0]) 133 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 134 | outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs) 135 | return Reduce.apply(*outputs) / len(outputs) 136 | #return self.gather(outputs, self.output_device).mean() 137 | 138 | 139 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 140 | assert len(modules) == len(inputs) 141 | assert len(targets) == len(inputs) 142 | if kwargs_tup: 143 | assert len(modules) == len(kwargs_tup) 144 | else: 145 | kwargs_tup = ({},) * len(modules) 146 | if devices is not None: 147 | assert len(modules) == len(devices) 148 | else: 149 | devices = [None] * len(modules) 150 | 151 | lock = threading.Lock() 152 | results = {} 153 | if torch_ver != "0.3": 154 | grad_enabled = torch.is_grad_enabled() 155 | 156 | def _worker(i, module, input, target, kwargs, device=None): 157 | if torch_ver != "0.3": 158 | torch.set_grad_enabled(grad_enabled) 159 | if device is None: 160 | device = get_a_var(input).get_device() 161 | try: 162 | if not isinstance(input, tuple): 163 | input = (input,) 164 | with torch.cuda.device(device): 165 | output = module(*(input + target), **kwargs) 166 | with lock: 167 | results[i] = output 168 | except Exception as e: 169 | with lock: 170 | results[i] = e 171 | 172 | if len(modules) > 1: 173 | threads = [threading.Thread(target=_worker, 174 | args=(i, module, input, target, 175 | kwargs, device),) 176 | for i, (module, input, target, kwargs, device) in 177 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] 178 | 179 | for thread in threads: 180 | thread.start() 181 | for thread in threads: 182 | thread.join() 183 | else: 184 | _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) 185 | 186 | outputs = [] 187 | for i in range(len(inputs)): 188 | output = results[i] 189 | if isinstance(output, Exception): 190 | raise output 191 | outputs.append(output) 192 | return outputs 193 | 194 | 195 | ########################################################################### 196 | # Adapted from Synchronized-BatchNorm-PyTorch. 197 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 198 | # 199 | class CallbackContext(object): 200 | pass 201 | 202 | 203 | def execute_replication_callbacks(modules): 204 | """ 205 | Execute an replication callback `__data_parallel_replicate__` on each module created 206 | by original replication. 
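    The synchronized BatchNorm layers rely on this hook to learn which device copy
    they are and to obtain the context object shared with the other copies.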
207 | 208 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 209 | 210 | Note that, as all modules are isomorphism, we assign each sub-module with a context 211 | (shared among multiple copies of this module on different devices). 212 | Through this context, different copies can share some information. 213 | 214 | We guarantee that the callback on the master copy (the first copy) will be called ahead 215 | of calling the callback of any slave copies. 216 | """ 217 | master_copy = modules[0] 218 | nr_modules = len(list(master_copy.modules())) 219 | ctxs = [CallbackContext() for _ in range(nr_modules)] 220 | 221 | for i, module in enumerate(modules): 222 | for j, m in enumerate(module.modules()): 223 | if hasattr(m, '__data_parallel_replicate__'): 224 | m.__data_parallel_replicate__(ctxs[j], i) 225 | 226 | 227 | def patch_replication_callback(data_parallel): 228 | """ 229 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 230 | Useful when you have customized `DataParallel` implementation. 231 | 232 | Examples: 233 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 234 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 235 | > patch_replication_callback(sync_bn) 236 | # this is equivalent to 237 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 238 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 239 | """ 240 | 241 | assert isinstance(data_parallel, DataParallel) 242 | 243 | old_replicate = data_parallel.replicate 244 | 245 | @functools.wraps(old_replicate) 246 | def new_replicate(module, device_ids): 247 | modules = old_replicate(module, device_ids) 248 | execute_replication_callbacks(modules) 249 | return modules 250 | 251 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | 5 | # from . 
import pyt_utils 6 | # from utils.pyt_utils import ensure_dir 7 | 8 | _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO') 9 | _default_level = logging.getLevelName(_default_level_name.upper()) 10 | 11 | 12 | class LogFormatter(logging.Formatter): 13 | log_fout = None 14 | date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] ' 15 | date = '%(asctime)s ' 16 | msg = '%(message)s' 17 | 18 | def format(self, record): 19 | if record.levelno == logging.DEBUG: 20 | mcl, mtxt = self._color_dbg, 'DBG' 21 | elif record.levelno == logging.WARNING: 22 | mcl, mtxt = self._color_warn, 'WRN' 23 | elif record.levelno == logging.ERROR: 24 | mcl, mtxt = self._color_err, 'ERR' 25 | else: 26 | mcl, mtxt = self._color_normal, '' 27 | 28 | if mtxt: 29 | mtxt += ' ' 30 | 31 | if self.log_fout: 32 | self.__set_fmt(self.date_full + mtxt + self.msg) 33 | formatted = super(LogFormatter, self).format(record) 34 | # self.log_fout.write(formatted) 35 | # self.log_fout.write('\n') 36 | # self.log_fout.flush() 37 | return formatted 38 | 39 | self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg)) 40 | formatted = super(LogFormatter, self).format(record) 41 | 42 | return formatted 43 | 44 | if sys.version_info.major < 3: 45 | def __set_fmt(self, fmt): 46 | self._fmt = fmt 47 | else: 48 | def __set_fmt(self, fmt): 49 | self._style._fmt = fmt 50 | 51 | @staticmethod 52 | def _color_dbg(msg): 53 | return '\x1b[36m{}\x1b[0m'.format(msg) 54 | 55 | @staticmethod 56 | def _color_warn(msg): 57 | return '\x1b[1;31m{}\x1b[0m'.format(msg) 58 | 59 | @staticmethod 60 | def _color_err(msg): 61 | return '\x1b[1;4;31m{}\x1b[0m'.format(msg) 62 | 63 | @staticmethod 64 | def _color_omitted(msg): 65 | return '\x1b[35m{}\x1b[0m'.format(msg) 66 | 67 | @staticmethod 68 | def _color_normal(msg): 69 | return msg 70 | 71 | @staticmethod 72 | def _color_date(msg): 73 | return '\x1b[32m{}\x1b[0m'.format(msg) 74 | 75 | 76 | def get_logger(log_dir=None, log_file=None, formatter=LogFormatter): 77 | logger = logging.getLogger() 78 | logger.setLevel(_default_level) 79 | del logger.handlers[:] 80 | 81 | if log_dir and log_file: 82 | if not os.path.isdir(log_dir): 83 | os.makedirs(log_dir) 84 | LogFormatter.log_fout = True 85 | file_handler = logging.FileHandler(log_file, mode='a') 86 | file_handler.setLevel(logging.INFO) 87 | file_handler.setFormatter(formatter) 88 | logger.addHandler(file_handler) 89 | 90 | stream_handler = logging.StreamHandler() 91 | stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S')) 92 | stream_handler.setLevel(0) 93 | logger.addHandler(stream_handler) 94 | return logger 95 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import scipy.ndimage as nd 7 | 8 | 9 | class OhemCrossEntropy2d(nn.Module): 10 | 11 | def __init__(self, ignore_label=255, thresh=0.7, min_kept=100000, factor=8): 12 | super(OhemCrossEntropy2d, self).__init__() 13 | self.ignore_label = ignore_label 14 | self.thresh = float(thresh) 15 | # self.min_kept_ratio = float(min_kept_ratio) 16 | self.min_kept = int(min_kept) 17 | self.factor = factor 18 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label) 19 | 20 | def find_threshold(self, np_predict, np_target): 21 | # downsample 1/8 22 | factor = self.factor 23 | predict = nd.zoom(np_predict, 
(1.0, 1.0, 1.0/factor, 1.0/factor), order=1) 24 | target = nd.zoom(np_target, (1.0, 1.0/factor, 1.0/factor), order=0) 25 | 26 | n, c, h, w = predict.shape 27 | min_kept = self.min_kept // (factor*factor) #int(self.min_kept_ratio * n * h * w) 28 | 29 | input_label = target.ravel().astype(np.int32) 30 | input_prob = np.rollaxis(predict, 1).reshape((c, -1)) 31 | 32 | valid_flag = input_label != self.ignore_label 33 | valid_inds = np.where(valid_flag)[0] 34 | label = input_label[valid_flag] 35 | num_valid = valid_flag.sum() 36 | if min_kept >= num_valid: 37 | threshold = 1.0 38 | elif num_valid > 0: 39 | prob = input_prob[:,valid_flag] 40 | pred = prob[label, np.arange(len(label), dtype=np.int32)] 41 | threshold = self.thresh 42 | if min_kept > 0: 43 | k_th = min(len(pred), min_kept)-1 44 | new_array = np.partition(pred, k_th) 45 | new_threshold = new_array[k_th] 46 | if new_threshold > self.thresh: 47 | threshold = new_threshold 48 | return threshold 49 | 50 | 51 | def generate_new_target(self, predict, target): 52 | np_predict = predict.data.cpu().numpy() 53 | np_target = target.data.cpu().numpy() 54 | n, c, h, w = np_predict.shape 55 | 56 | threshold = self.find_threshold(np_predict, np_target) 57 | 58 | input_label = np_target.ravel().astype(np.int32) 59 | input_prob = np.rollaxis(np_predict, 1).reshape((c, -1)) 60 | 61 | valid_flag = input_label != self.ignore_label 62 | valid_inds = np.where(valid_flag)[0] 63 | label = input_label[valid_flag] 64 | num_valid = valid_flag.sum() 65 | 66 | if num_valid > 0: 67 | prob = input_prob[:,valid_flag] 68 | pred = prob[label, np.arange(len(label), dtype=np.int32)] 69 | kept_flag = pred <= threshold 70 | valid_inds = valid_inds[kept_flag] 71 | print('Labels: {} {}'.format(len(valid_inds), threshold)) 72 | 73 | label = input_label[valid_inds].copy() 74 | input_label.fill(self.ignore_label) 75 | input_label[valid_inds] = label 76 | new_target = torch.from_numpy(input_label.reshape(target.size())).long().cuda(target.get_device()) 77 | 78 | return new_target 79 | 80 | 81 | def forward(self, predict, target, weight=None): 82 | """ 83 | Args: 84 | predict:(n, c, h, w) 85 | target:(n, h, w) 86 | weight (Tensor, optional): a manual rescaling weight given to each class. 87 | If given, has to be a Tensor of size "nclasses" 88 | """ 89 | assert not target.requires_grad 90 | 91 | input_prob = F.softmax(predict, 1) 92 | target = self.generate_new_target(input_prob, target) 93 | return self.criterion(predict, target) 94 | -------------------------------------------------------------------------------- /utils/lovasz_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | """ 5 | 6 | from __future__ import print_function, division 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | import numpy as np 12 | try: 13 | from itertools import ifilterfalse 14 | except ImportError: # py3k 15 | from itertools import filterfalse as ifilterfalse 16 | 17 | 18 | def lovasz_grad(gt_sorted): 19 | """ 20 | Computes gradient of the Lovasz extension w.r.t sorted errors 21 | See Alg. 1 in paper 22 | """ 23 | p = len(gt_sorted) 24 | gts = gt_sorted.sum() 25 | intersection = gts - gt_sorted.float().cumsum(0) 26 | union = gts + (1 - gt_sorted).float().cumsum(0) 27 | jaccard = 1. 
- intersection / union 28 | if p > 1: # cover 1-pixel case 29 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 30 | return jaccard 31 | 32 | 33 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 34 | """ 35 | IoU for foreground class 36 | binary: 1 foreground, 0 background 37 | """ 38 | if not per_image: 39 | preds, labels = (preds,), (labels,) 40 | ious = [] 41 | for pred, label in zip(preds, labels): 42 | intersection = ((label == 1) & (pred == 1)).sum() 43 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 44 | if not union: 45 | iou = EMPTY 46 | else: 47 | iou = float(intersection) / float(union) 48 | ious.append(iou) 49 | iou = mean(ious) # mean accross images if per_image 50 | return 100 * iou 51 | 52 | 53 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 54 | """ 55 | Array of IoU for each (non ignored) class 56 | """ 57 | if not per_image: 58 | preds, labels = (preds,), (labels,) 59 | ious = [] 60 | for pred, label in zip(preds, labels): 61 | iou = [] 62 | for i in range(C): 63 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 64 | intersection = ((label == i) & (pred == i)).sum() 65 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 66 | if not union: 67 | iou.append(EMPTY) 68 | else: 69 | iou.append(float(intersection) / float(union)) 70 | ious.append(iou) 71 | ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image 72 | return 100 * np.array(ious) 73 | 74 | 75 | # --------------------------- BINARY LOSSES --------------------------- 76 | 77 | 78 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 79 | """ 80 | Binary Lovasz hinge loss 81 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 82 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 83 | per_image: compute the loss per image instead of per batch 84 | ignore: void class id 85 | """ 86 | if per_image: 87 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 88 | for log, lab in zip(logits, labels)) 89 | else: 90 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 91 | return loss 92 | 93 | 94 | def lovasz_hinge_flat(logits, labels): 95 | """ 96 | Binary Lovasz hinge loss 97 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 98 | labels: [P] Tensor, binary ground truth labels (0 or 1) 99 | ignore: label to ignore 100 | """ 101 | if len(labels) == 0: 102 | # only void pixels, the gradients should be 0 103 | return logits.sum() * 0. 104 | signs = 2. * labels.float() - 1. 105 | errors = (1. 
- logits * Variable(signs)) 106 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 107 | perm = perm.data 108 | gt_sorted = labels[perm] 109 | grad = lovasz_grad(gt_sorted) 110 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 111 | return loss 112 | 113 | 114 | def flatten_binary_scores(scores, labels, ignore=None): 115 | """ 116 | Flattens predictions in the batch (binary case) 117 | Remove labels equal to 'ignore' 118 | """ 119 | scores = scores.view(-1) 120 | labels = labels.view(-1) 121 | if ignore is None: 122 | return scores, labels 123 | valid = (labels != ignore) 124 | vscores = scores[valid] 125 | vlabels = labels[valid] 126 | return vscores, vlabels 127 | 128 | 129 | class StableBCELoss(torch.nn.modules.Module): 130 | def __init__(self): 131 | super(StableBCELoss, self).__init__() 132 | def forward(self, input, target): 133 | neg_abs = - input.abs() 134 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 135 | return loss.mean() 136 | 137 | 138 | def binary_xloss(logits, labels, ignore=None): 139 | """ 140 | Binary Cross entropy loss 141 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 142 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 143 | ignore: void class id 144 | """ 145 | logits, labels = flatten_binary_scores(logits, labels, ignore) 146 | loss = StableBCELoss()(logits, Variable(labels.float())) 147 | return loss 148 | 149 | 150 | # --------------------------- MULTICLASS LOSSES --------------------------- 151 | 152 | 153 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 154 | """ 155 | Multi-class Lovasz-Softmax loss 156 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 157 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 158 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 159 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 160 | per_image: compute the loss per image instead of per batch 161 | ignore: void class labels 162 | """ 163 | if per_image: 164 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 165 | for prob, lab in zip(probas, labels)) 166 | else: 167 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 168 | return loss 169 | 170 | 171 | def lovasz_softmax_flat(probas, labels, classes='present'): 172 | """ 173 | Multi-class Lovasz-Softmax loss 174 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 175 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 176 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 177 | """ 178 | if probas.numel() == 0: 179 | # only void pixels, the gradients should be 0 180 | return probas * 0. 
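    # Per-class Lovasz extension: for every class to be summed, the absolute errors
    # |fg - class_pred| are sorted in descending order and dotted with the Jaccard-loss
    # gradient computed on the correspondingly sorted ground truth; the per-class
    # losses are then averaged.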
181 | C = probas.size(1) 182 | losses = [] 183 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 184 | for c in class_to_sum: 185 | fg = (labels == c).float() # foreground for class c 186 | if (classes is 'present' and fg.sum() == 0): 187 | continue 188 | if C == 1: 189 | if len(classes) > 1: 190 | raise ValueError('Sigmoid output possible only with 1 class') 191 | class_pred = probas[:, 0] 192 | else: 193 | class_pred = probas[:, c] 194 | errors = (Variable(fg) - class_pred).abs() 195 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 196 | perm = perm.data 197 | fg_sorted = fg[perm] 198 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 199 | return mean(losses) 200 | 201 | 202 | def flatten_probas(probas, labels, ignore=None): 203 | """ 204 | Flattens predictions in the batch 205 | """ 206 | if probas.dim() == 3: 207 | # assumes output of a sigmoid layer 208 | B, H, W = probas.size() 209 | probas = probas.view(B, 1, H, W) 210 | B, C, H, W = probas.size() 211 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 212 | labels = labels.view(-1) 213 | if ignore is None: 214 | return probas, labels 215 | valid = (labels != ignore) 216 | vprobas = probas[valid.nonzero().squeeze()] 217 | vlabels = labels[valid] 218 | return vprobas, vlabels 219 | 220 | def xloss(logits, labels, ignore=None): 221 | """ 222 | Cross entropy loss 223 | """ 224 | return F.cross_entropy(logits, Variable(labels), ignore_index=255) 225 | 226 | 227 | # --------------------------- HELPER FUNCTIONS --------------------------- 228 | def isnan(x): 229 | return x != x 230 | 231 | 232 | def mean(l, ignore_nan=False, empty=0): 233 | """ 234 | nanmean compatible with generators. 235 | """ 236 | l = iter(l) 237 | if ignore_nan: 238 | l = ifilterfalse(isnan, l) 239 | try: 240 | n = 1 241 | acc = next(l) 242 | except StopIteration: 243 | if empty == 'raise': 244 | raise ValueError('Empty mean') 245 | return empty 246 | for n, v in enumerate(l, 2): 247 | acc += v 248 | if n == 1: 249 | return acc 250 | return acc / n 251 | -------------------------------------------------------------------------------- /utils/miou.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import json 5 | from collections import OrderedDict 6 | import argparse 7 | from PIL import Image as PILImage 8 | from utils.transforms import transform_parsing 9 | 10 | LABELS = ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat', \ 11 | 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm', 'Left-leg', 12 | 'Right-leg', 'Left-shoe', 'Right-shoe'] 13 | 14 | def get_palette(num_cls): 15 | """ Returns the color map for visualizing the segmentation mask. 
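    Each class index is expanded bit by bit into an (R, G, B) triple, so consecutive
    labels map to clearly distinct colors.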
16 | Args: 17 | num_cls: Number of classes 18 | Returns: 19 | The color map 20 | """ 21 | 22 | n = num_cls 23 | palette = [0] * (n * 3) 24 | for j in range(0, n): 25 | lab = j 26 | palette[j * 3 + 0] = 0 27 | palette[j * 3 + 1] = 0 28 | palette[j * 3 + 2] = 0 29 | i = 0 30 | while lab: 31 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 32 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 33 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 34 | i += 1 35 | lab >>= 3 36 | return palette 37 | 38 | def get_confusion_matrix(gt_label, pred_label, num_classes): 39 | """ 40 | Calcute the confusion matrix by given label and pred 41 | :param gt_label: the ground truth label 42 | :param pred_label: the pred label 43 | :param num_classes: the nunber of class 44 | :return: the confusion matrix 45 | """ 46 | index = (gt_label * num_classes + pred_label).astype('int32') 47 | label_count = np.bincount(index) 48 | confusion_matrix = np.zeros((num_classes, num_classes)) 49 | 50 | for i_label in range(num_classes): 51 | for i_pred_label in range(num_classes): 52 | cur_index = i_label * num_classes + i_pred_label 53 | if cur_index < len(label_count): 54 | confusion_matrix[i_label, i_pred_label] = label_count[cur_index] 55 | 56 | return confusion_matrix 57 | 58 | 59 | def compute_mean_ioU(preds, scales, centers, num_classes, datadir, input_size=[473, 473], dataset='val'): 60 | list_path = os.path.join(datadir, dataset + '_id.txt') 61 | val_id = [i_id.strip() for i_id in open(list_path)] 62 | 63 | confusion_matrix = np.zeros((num_classes, num_classes)) 64 | 65 | for i, im_name in enumerate(val_id): 66 | gt_path = os.path.join(datadir, dataset + '_segmentations', im_name + '.png') 67 | gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) 68 | h, w = gt.shape 69 | pred_out = preds[i] 70 | s = scales[i] 71 | c = centers[i] 72 | pred = transform_parsing(pred_out, c, s, w, h, input_size) 73 | 74 | gt = np.asarray(gt, dtype=np.int32) 75 | pred = np.asarray(pred, dtype=np.int32) 76 | 77 | ignore_index = gt != 255 78 | 79 | gt = gt[ignore_index] 80 | pred = pred[ignore_index] 81 | 82 | confusion_matrix += get_confusion_matrix(gt, pred, num_classes) 83 | 84 | pos = confusion_matrix.sum(1) 85 | res = confusion_matrix.sum(0) 86 | tp = np.diag(confusion_matrix) 87 | 88 | pixel_accuracy = (tp.sum() / pos.sum()) * 100 89 | mean_accuracy = ((tp / np.maximum(1.0, pos)).mean()) * 100 90 | IoU_array = (tp / np.maximum(1.0, pos + res - tp)) 91 | IoU_array = IoU_array * 100 92 | mean_IoU = IoU_array.mean() 93 | print('Pixel accuracy: %f \n' % pixel_accuracy) 94 | print('Mean accuracy: %f \n' % mean_accuracy) 95 | print('Mean IU: %f \n' % mean_IoU) 96 | name_value = [] 97 | 98 | for i, (label, iou) in enumerate(zip(LABELS, IoU_array)): 99 | name_value.append((label, iou)) 100 | 101 | name_value.append(('Pixel accuracy', pixel_accuracy)) 102 | name_value.append(('Mean accuracy', mean_accuracy)) 103 | name_value.append(('Mean IU', mean_IoU)) 104 | name_value = OrderedDict(name_value) 105 | return name_value 106 | 107 | def compute_mean_ioU_file(preds_dir, num_classes, datadir, dataset='val'): 108 | list_path = os.path.join(datadir, dataset + '_id.txt') 109 | val_id = [i_id.strip() for i_id in open(list_path)] 110 | 111 | confusion_matrix = np.zeros((num_classes, num_classes)) 112 | 113 | for i, im_name in enumerate(val_id): 114 | gt_path = os.path.join(datadir, dataset + '_segmentations', im_name + '.png') 115 | gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) 116 | 117 | pred_path = os.path.join(preds_dir, im_name + '.png') 
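        # read the saved prediction, drop pixels whose GT label is 255, and accumulate the confusion matrix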
118 | pred = np.asarray(PILImage.open(pred_path)) 119 | 120 | gt = np.asarray(gt, dtype=np.int32) 121 | pred = np.asarray(pred, dtype=np.int32) 122 | 123 | ignore_index = gt != 255 124 | 125 | gt = gt[ignore_index] 126 | pred = pred[ignore_index] 127 | 128 | confusion_matrix += get_confusion_matrix(gt, pred, num_classes) 129 | 130 | pos = confusion_matrix.sum(1) 131 | res = confusion_matrix.sum(0) 132 | tp = np.diag(confusion_matrix) 133 | 134 | pixel_accuracy = (tp.sum() / pos.sum())*100 135 | mean_accuracy = ((tp / np.maximum(1.0, pos)).mean())*100 136 | IoU_array = (tp / np.maximum(1.0, pos + res - tp)) 137 | IoU_array = IoU_array*100 138 | mean_IoU = IoU_array.mean() 139 | print('Pixel accuracy: %f \n' % pixel_accuracy) 140 | print('Mean accuracy: %f \n' % mean_accuracy) 141 | print('Mean IU: %f \n' % mean_IoU) 142 | name_value = [] 143 | 144 | for i, (label, iou) in enumerate(zip(LABELS, IoU_array)): 145 | name_value.append((label, iou)) 146 | 147 | name_value.append(('Pixel accuracy', pixel_accuracy)) 148 | name_value.append(('Mean accuracy', mean_accuracy)) 149 | name_value.append(('Mean IU', mean_IoU)) 150 | name_value = OrderedDict(name_value) 151 | return name_value 152 | 153 | def write_results(preds, scales, centers, datadir, dataset, result_dir, input_size=[473, 473]): 154 | palette = get_palette(20) 155 | if not os.path.exists(result_dir): 156 | os.makedirs(result_dir) 157 | 158 | json_file = os.path.join(datadir, 'annotations', dataset + '.json') 159 | with open(json_file) as data_file: 160 | data_list = json.load(data_file) 161 | data_list = data_list['root'] 162 | for item, pred_out, s, c in zip(data_list, preds, scales, centers): 163 | im_name = item['im_name'] 164 | w = item['img_width'] 165 | h = item['img_height'] 166 | pred = transform_parsing(pred_out, c, s, w, h, input_size) 167 | #pred = pred_out 168 | save_path = os.path.join(result_dir, im_name[:-4]+'.png') 169 | 170 | output_im = PILImage.fromarray(np.asarray(pred, dtype=np.uint8)) 171 | output_im.putpalette(palette) 172 | output_im.save(save_path) 173 | 174 | def get_arguments(): 175 | """Parse all the arguments provided from the CLI. 176 | 177 | Returns: 178 | A list of parsed arguments. 
179 | """ 180 | parser = argparse.ArgumentParser(description="DeepLabLFOV NetworkEv") 181 | parser.add_argument("--pred-path", type=str, default='', 182 | help="Path to predicted segmentation.") 183 | parser.add_argument("--gt-path", type=str, default='', 184 | help="Path to the groundtruth dir.") 185 | 186 | return parser.parse_args() 187 | 188 | 189 | if __name__ == "__main__": 190 | args = get_arguments() 191 | palette = get_palette(20) 192 | # im_path = '/ssd1/liuting14/Dataset/LIP/val_segmentations/100034_483681.png' 193 | # #compute_mean_ioU_file(args.pred_path, 20, args.gt_path, 'val') 194 | # im = cv2.imread(im_path, cv2.IMREAD_GRAYSCALE) 195 | # print(im.shape) 196 | # test = np.asarray( PILImage.open(im_path)) 197 | # print(test.shape) 198 | # if im.all()!=test.all(): 199 | # print('different') 200 | # output_im = PILImage.fromarray(np.zeros((100,100), dtype=np.uint8)) 201 | # output_im.putpalette(palette) 202 | # output_im.save('test.png') 203 | pred_dir = '/ssd1/liuting14/exps/lip/snapshots/results/epoch4/' 204 | num_classes = 20 205 | datadir = '/ssd1/liuting14/Dataset/LIP/' 206 | compute_mean_ioU_file(pred_dir, num_classes, datadir, dataset='val') -------------------------------------------------------------------------------- /utils/pyt_utils.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import os 3 | import sys 4 | import time 5 | import argparse 6 | from collections import OrderedDict, defaultdict 7 | 8 | import torch 9 | import torch.utils.model_zoo as model_zoo 10 | import torch.distributed as dist 11 | 12 | from .logger import get_logger 13 | 14 | logger = get_logger() 15 | 16 | # colour map 17 | label_colours = [(0,0,0) 18 | # 0=background 19 | ,(128,0,0),(0,128,0),(128,128,0),(0,0,128),(128,0,128) 20 | # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle 21 | ,(0,128,128),(128,128,128),(64,0,0),(192,0,0),(64,128,0) 22 | # 6=bus, 7=car, 8=cat, 9=chair, 10=cow 23 | ,(192,128,0),(64,0,128),(192,0,128),(64,128,128),(192,128,128) 24 | # 11=diningtable, 12=dog, 13=horse, 14=motorbike, 15=person 25 | ,(0,64,0),(128,64,0),(0,192,0),(128,192,0),(0,64,128)] 26 | # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor 27 | 28 | 29 | def reduce_tensor(tensor, dst=0, op=dist.ReduceOp.SUM, world_size=1): 30 | tensor = tensor.clone() 31 | dist.reduce(tensor, dst, op) 32 | if dist.get_rank() == dst: 33 | tensor.div_(world_size) 34 | 35 | return tensor 36 | 37 | 38 | def all_reduce_tensor(tensor, op=dist.ReduceOp.SUM, world_size=1, norm=True): 39 | tensor = tensor.clone() 40 | dist.all_reduce(tensor, op) 41 | if norm: 42 | tensor.div_(world_size) 43 | 44 | return tensor 45 | 46 | 47 | def load_model(model, model_file, is_restore=False): 48 | t_start = time.time() 49 | if isinstance(model_file, str): 50 | device = torch.device('cpu') 51 | state_dict = torch.load(model_file, map_location=device) 52 | if 'model' in state_dict.keys(): 53 | state_dict = state_dict['model'] 54 | else: 55 | state_dict = model_file 56 | t_ioend = time.time() 57 | 58 | if is_restore: 59 | new_state_dict = OrderedDict() 60 | for k, v in state_dict.items(): 61 | name = 'module.' 
+ k 62 | new_state_dict[name] = v 63 | state_dict = new_state_dict 64 | 65 | model.load_state_dict(state_dict, strict=False) 66 | ckpt_keys = set(state_dict.keys()) 67 | own_keys = set(model.state_dict().keys()) 68 | missing_keys = own_keys - ckpt_keys 69 | unexpected_keys = ckpt_keys - own_keys 70 | 71 | if len(missing_keys) > 0: 72 | logger.warning('Missing key(s) in state_dict: {}'.format( 73 | ', '.join('{}'.format(k) for k in missing_keys))) 74 | 75 | if len(unexpected_keys) > 0: 76 | logger.warning('Unexpected key(s) in state_dict: {}'.format( 77 | ', '.join('{}'.format(k) for k in unexpected_keys))) 78 | 79 | del state_dict 80 | t_end = time.time() 81 | logger.info( 82 | "Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format( 83 | t_ioend - t_start, t_end - t_ioend)) 84 | 85 | return model 86 | 87 | 88 | def parse_devices(input_devices): 89 | if input_devices.endswith('*'): 90 | devices = list(range(torch.cuda.device_count())) 91 | return devices 92 | 93 | devices = [] 94 | for d in input_devices.split(','): 95 | if '-' in d: 96 | start_device, end_device = d.split('-')[0], d.split('-')[1] 97 | assert start_device != '' 98 | assert end_device != '' 99 | start_device, end_device = int(start_device), int(end_device) 100 | assert start_device < end_device 101 | assert end_device < torch.cuda.device_count() 102 | for sd in range(start_device, end_device + 1): 103 | devices.append(sd) 104 | else: 105 | device = int(d) 106 | assert device < torch.cuda.device_count() 107 | devices.append(device) 108 | 109 | logger.info('using devices {}'.format( 110 | ', '.join([str(d) for d in devices]))) 111 | 112 | return devices 113 | 114 | 115 | def extant_file(x): 116 | """ 117 | 'Type' for argparse - checks that file exists but does not open. 118 | """ 119 | if not os.path.exists(x): 120 | # Argparse uses the ArgumentTypeError to give a rejection message like: 121 | # error: argument input: x does not exist 122 | raise argparse.ArgumentTypeError("{0} does not exist".format(x)) 123 | return x 124 | 125 | 126 | def link_file(src, target): 127 | if os.path.isdir(target) or os.path.isfile(target): 128 | os.remove(target) 129 | os.system('ln -s {} {}'.format(src, target)) 130 | 131 | 132 | def ensure_dir(path): 133 | if not os.path.isdir(path): 134 | os.makedirs(path) 135 | 136 | 137 | def _dbg_interactive(var, value): 138 | from IPython import embed 139 | embed() 140 | 141 | def decode_labels(mask, num_images=1, num_classes=21): 142 | """Decode batch of segmentation masks. 143 | 144 | Args: 145 | mask: result of inference after taking argmax. 146 | num_images: number of images to decode from the batch. 147 | num_classes: number of classes to predict (including background). 148 | 149 | Returns: 150 | A batch with num_images RGB images of the same size as the input. 151 | """ 152 | mask = mask.data.cpu().numpy() 153 | n, h, w = mask.shape 154 | assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images) 155 | outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8) 156 | for i in range(num_images): 157 | img = Image.new('RGB', (len(mask[i, 0]), len(mask[i]))) 158 | pixels = img.load() 159 | for j_, j in enumerate(mask[i, :, :]): 160 | for k_, k in enumerate(j): 161 | if k < num_classes: 162 | pixels[k_,j_] = label_colours[k] 163 | outputs[i] = np.array(img) 164 | return outputs 165 | 166 | def decode_predictions(preds, num_images=1, num_classes=21): 167 | """Decode batch of segmentation masks. 
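    Unlike decode_labels, this takes raw network outputs (or a list of outputs),
    applies an argmax over the channel dimension first, and then colorizes the
    resulting label maps.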
168 | 169 | Args: 170 | mask: result of inference after taking argmax. 171 | num_images: number of images to decode from the batch. 172 | num_classes: number of classes to predict (including background). 173 | 174 | Returns: 175 | A batch with num_images RGB images of the same size as the input. 176 | """ 177 | if isinstance(preds, list): 178 | preds_list = [] 179 | for pred in preds: 180 | preds_list.append(pred[-1].data.cpu().numpy()) 181 | preds = np.concatenate(preds_list, axis=0) 182 | else: 183 | preds = preds.data.cpu().numpy() 184 | 185 | preds = np.argmax(preds, axis=1) 186 | n, h, w = preds.shape 187 | assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images) 188 | outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8) 189 | for i in range(num_images): 190 | img = Image.new('RGB', (len(preds[i, 0]), len(preds[i]))) 191 | pixels = img.load() 192 | for j_, j in enumerate(preds[i, :, :]): 193 | for k_, k in enumerate(j): 194 | if k < num_classes: 195 | pixels[k_,j_] = label_colours[k] 196 | outputs[i] = np.array(img) 197 | return outputs 198 | 199 | def inv_preprocess(imgs, num_images, img_mean): 200 | """Inverse preprocessing of the batch of images. 201 | Add the mean vector and convert from BGR to RGB. 202 | 203 | Args: 204 | imgs: batch of input images. 205 | num_images: number of images to apply the inverse transformations on. 206 | img_mean: vector of mean colour values. 207 | 208 | Returns: 209 | The batch of the size num_images with the same spatial dimensions as the input. 210 | """ 211 | imgs = imgs.data.cpu().numpy() 212 | n, c, h, w = imgs.shape 213 | assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images) 214 | outputs = np.zeros((num_images, h, w, c), dtype=np.uint8) 215 | for i in range(num_images): 216 | outputs[i] = (np.transpose(imgs[i], (1,2,0)) + img_mean).astype(np.uint8) 217 | return outputs 218 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import numpy as np 12 | import cv2 13 | 14 | 15 | def flip_back(output_flipped, matched_parts): 16 | ''' 17 | ouput_flipped: numpy.ndarray(batch_size, num_joints, height, width) 18 | ''' 19 | assert output_flipped.ndim == 4,\ 20 | 'output_flipped should be [batch_size, num_joints, height, width]' 21 | 22 | output_flipped = output_flipped[:, :, :, ::-1] 23 | 24 | for pair in matched_parts: 25 | tmp = output_flipped[:, pair[0], :, :].copy() 26 | output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :] 27 | output_flipped[:, pair[1], :, :] = tmp 28 | 29 | return output_flipped 30 | 31 | 32 | def transform_parsing(pred, center, scale, width, height, input_size): 33 | 34 | trans = get_affine_transform(center, scale, 0, input_size, inv=1) 35 | target_pred = cv2.warpAffine( 36 | pred, 37 | trans, 38 | (int(width), int(height)), #(int(width), int(height)), 39 | flags=cv2.INTER_NEAREST, 40 | borderMode=cv2.BORDER_CONSTANT, 41 | borderValue=(0)) 42 | 43 | return target_pred 44 | 45 | 46 | def get_affine_transform(center, 47 | scale, 48 | rot, 49 | output_size, 50 | shift=np.array([0, 0], dtype=np.float32), 51 | inv=0): 52 | if not isinstance(scale, np.ndarray) and not isinstance(scale, list): 53 | print(scale) 54 | scale = np.array([scale, scale]) 55 | 56 | scale_tmp = scale 57 | 58 | src_w = scale_tmp[0] 59 | dst_w = output_size[1] 60 | dst_h = output_size[0] 61 | 62 | rot_rad = np.pi * rot / 180 63 | src_dir = get_dir([0, src_w * -0.5], rot_rad) 64 | dst_dir = np.array([0, dst_w * -0.5], np.float32) 65 | 66 | src = np.zeros((3, 2), dtype=np.float32) 67 | dst = np.zeros((3, 2), dtype=np.float32) 68 | src[0, :] = center + scale_tmp * shift 69 | src[1, :] = center + src_dir + scale_tmp * shift 70 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 71 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir 72 | 73 | src[2:, :] = get_3rd_point(src[0, :], src[1, :]) 74 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) 75 | 76 | if inv: 77 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 78 | else: 79 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 80 | 81 | return trans 82 | 83 | 84 | def affine_transform(pt, t): 85 | new_pt = np.array([pt[0], pt[1], 1.]).T 86 | new_pt = np.dot(t, new_pt) 87 | return new_pt[:2] 88 | 89 | 90 | def get_3rd_point(a, b): 91 | direct = a - b 92 | return b + np.array([-direct[1], direct[0]], dtype=np.float32) 93 | 94 | 95 | def get_dir(src_point, rot_rad): 96 | sn, cs = np.sin(rot_rad), np.cos(rot_rad) 97 | 98 | src_result = [0, 0] 99 | src_result[0] = src_point[0] * cs - src_point[1] * sn 100 | src_result[1] = src_point[0] * sn + src_point[1] * cs 101 | 102 | return src_result 103 | 104 | 105 | def crop(img, center, scale, output_size, rot=0): 106 | trans = get_affine_transform(center, scale, rot, output_size) 107 | 108 | dst_img = cv2.warpAffine(img, 109 | trans, 110 | (int(output_size[1]), int(output_size[0])), 111 | flags=cv2.INTER_LINEAR) 112 | 113 | return dst_img 114 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torchvision 4 | import torch 5 | 6 | # colour map 7 | 
COLORS = [(0,0,0) 8 | # 0=background 9 | ,(128,0,0),(0,128,0),(128,128,0),(0,0,128),(128,0,128) 10 | # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle 11 | ,(0,128,128),(128,128,128),(64,0,0),(192,0,0),(64,128,0) 12 | # 6=bus, 7=car, 8=cat, 9=chair, 10=cow 13 | ,(192,128,0),(64,0,128),(192,0,128),(64,128,128),(192,128,128) 14 | # 11=diningtable, 12=dog, 13=horse, 14=motorbike, 15=person 15 | ,(0,64,0),(128,64,0),(0,192,0),(128,192,0),(0,64,128)] 16 | # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor 17 | 18 | 19 | def decode_parsing(labels, num_images=1, num_classes=21, is_pred=False): 20 | """Decode batch of segmentation masks. 21 | 22 | Args: 23 | mask: result of inference after taking argmax. 24 | num_images: number of images to decode from the batch. 25 | num_classes: number of classes to predict (including background). 26 | 27 | Returns: 28 | A batch with num_images RGB images of the same size as the input. 29 | """ 30 | pred_labels = labels[:num_images].clone().cpu().data 31 | if is_pred: 32 | pred_labels = torch.argmax(pred_labels, dim=1) 33 | n, h, w = pred_labels.size() 34 | 35 | labels_color = torch.zeros([n, 3, h, w], dtype=torch.uint8) 36 | for i, c in enumerate(COLORS): 37 | c0 = labels_color[:, 0, :, :] 38 | c1 = labels_color[:, 1, :, :] 39 | c2 = labels_color[:, 2, :, :] 40 | 41 | c0[pred_labels == i] = c[0] 42 | c1[pred_labels == i] = c[1] 43 | c2[pred_labels == i] = c[2] 44 | 45 | return labels_color 46 | 47 | def inv_preprocess(imgs, num_images): 48 | """Inverse preprocessing of the batch of images. 49 | Add the mean vector and convert from BGR to RGB. 50 | 51 | Args: 52 | imgs: batch of input images. 53 | num_images: number of images to apply the inverse transformations on. 54 | img_mean: vector of mean colour values. 55 | 56 | Returns: 57 | The batch of the size num_images with the same spatial dimensions as the input. 58 | """ 59 | rev_imgs = imgs[:num_images].clone().cpu().data 60 | rev_normalize = NormalizeInverse(mean=[0.485, 0.456, 0.406], 61 | std=[0.229, 0.224, 0.225]) 62 | for i in range(num_images): 63 | rev_imgs[i] = rev_normalize(rev_imgs[i]) 64 | 65 | return rev_imgs 66 | 67 | class NormalizeInverse(torchvision.transforms.Normalize): 68 | """ 69 | Undoes the normalization and returns the reconstructed images in the input domain. 70 | """ 71 | 72 | def __init__(self, mean, std): 73 | mean = torch.as_tensor(mean) 74 | std = torch.as_tensor(std) 75 | std_inv = 1 / (std + 1e-7) 76 | mean_inv = -mean * std_inv 77 | super().__init__(mean=mean_inv, std=std_inv) 78 | -------------------------------------------------------------------------------- /utils/writejson.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import cv2 4 | 5 | json_file = os.path.join('/ssd1/liuting14/Dataset/LIP', 'annotations', 'test.json') 6 | 7 | with open(json_file) as data_file: 8 | data_json = json.load(data_file) 9 | data_list = data_json['root'] 10 | 11 | for item in data_list: 12 | name = item['im_name'] 13 | im_path = os.path.join('/ssd1/liuting14/Dataset/LIP', 'test_images', name) 14 | im = cv2.imread(im_path, cv2.IMREAD_COLOR) 15 | h, w, c = im.shape 16 | item['img_height'] = h 17 | item['img_width'] = w 18 | item['center'] = [h/2, w/2] 19 | 20 | with open(json_file, "w") as f: 21 | json.dump(data_json, f, indent=2) 22 | --------------------------------------------------------------------------------
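As a quick reference for the loss interface dumped above, here is a minimal sketch of calling CriterionAll from utils/criterion.py. The batch and tensor sizes are illustrative assumptions only (473x473 matches the default input_size used in utils/miou.py), the script is assumed to be run from the repository root so the repo's own imports resolve, and a CUDA device is required because the edge-loss class weights are moved to the GPU.

```python
import torch
from utils.criterion import CriterionAll  # utils/criterion.py, dumped above

n, num_classes, h, w = 2, 20, 473, 473            # LIP uses 20 classes
criterion = CriterionAll(ignore_index=255)

# preds[0]: list of parsing logits (weighted 0.4, 1, 1, 1 inside parsing_loss)
parsing_preds = [torch.randn(n, num_classes, h // 4, w // 4).cuda() for _ in range(2)]
# preds[1]: edge logits; preds[2]: predicted height / width class distributions
edge_pred = torch.randn(n, 2, h // 4, w // 4).cuda()
hw_preds = [torch.randn(n, num_classes, h // 4).cuda(),    # fea_h
            torch.randn(n, num_classes, w // 4).cuda()]    # fea_w

parsing_gt = torch.randint(0, num_classes, (n, h, w)).cuda()   # target[0]
edge_gt = torch.randint(0, 2, (n, h, w)).cuda()                # target[1]
hgt = torch.rand(n, num_classes, h).cuda()   # accumulated GT distribution over rows
wgt = torch.rand(n, num_classes, w).cuda()   # accumulated GT distribution over columns

loss = criterion([parsing_preds, edge_pred, hw_preds], [parsing_gt, edge_gt], [hgt, wgt])
print(loss.item())
```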