├── README.md
├── aff_infer.py
├── aff_prepare.py
├── aff_train.py
├── contrast_infer.py
├── contrast_train.py
├── eval.py
├── network
│   ├── resnet38_SEAM.py
│   ├── resnet38_aff.py
│   ├── resnet38_contrast.py
│   └── resnet38d.py
├── script
│   ├── script_cls.sh
│   └── script_contrast.sh
├── segmentation
│   ├── .DS_Store
│   ├── experiment
│   │   ├── EPS_deeplabv1_resnet101
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── test.py
│   │   │   └── train.py
│   │   ├── EPS_deeplabv2_resnet101
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── test.py
│   │   │   └── train.py
│   │   └── SEAM_deeplabv1_resnet38
│   │       ├── README.md
│   │       ├── __init__.py
│   │       ├── config.py
│   │       ├── test.py
│   │       └── train.py
│   └── lib
│       ├── .DS_Store
│       ├── datasets
│       │   ├── ADE20KDataset.py
│       │   ├── BaseDataset.py
│       │   ├── COCODataset.py
│       │   ├── CityscapesDataset.py
│       │   ├── ContextDataset.py
│       │   ├── VOCDataset.py
│       │   ├── __init__.py
│       │   ├── generateData.py
│       │   ├── metric.py
│       │   └── transform.py
│       ├── net
│       │   ├── __init__.py
│       │   ├── backbone
│       │   │   ├── __init__.py
│       │   │   ├── builder.py
│       │   │   ├── resnet.py
│       │   │   ├── resnet38d.py
│       │   │   └── xception.py
│       │   ├── deeplabv1.py
│       │   ├── deeplabv2.py
│       │   ├── deeplabv3.py
│       │   ├── deeplabv3plus.py
│       │   ├── generateNet.py
│       │   ├── operators
│       │   │   ├── ASPP.py
│       │   │   ├── PPM.py
│       │   │   └── __init__.py
│       │   └── sync_batchnorm
│       │       ├── __init__.py
│       │       ├── batchnorm.py
│       │       ├── comm.py
│       │       ├── replicate.py
│       │       ├── sync_batchnorm
│       │       │   ├── __init__.py
│       │       │   ├── batchnorm.py
│       │       │   ├── batchnorm_reimpl.py
│       │       │   ├── comm.py
│       │       │   ├── replicate.py
│       │       │   └── unittest.py
│       │       ├── tests
│       │       │   ├── test_numeric_batchnorm.py
│       │       │   └── test_sync_batchnorm.py
│       │       └── unittest.py
│       └── utils
│           ├── DenseCRF.py
│           ├── __init__.py
│           ├── configuration.py
│           ├── finalprocess.py
│           ├── imutils.py
│           ├── registry.py
│           ├── test_utils.py
│           └── visualization.py
├── tool
│   ├── imutils.py
│   ├── pyutils.py
│   ├── torchutils.py
│   └── visualization.py
├── utils
│   ├── __init__.py
│   └── util.py
└── voc12
    ├── __init__.py
    ├── cls_labels.npy
    ├── data.py
    ├── make_cls_labels.py
    ├── test.txt
    ├── train.txt
    ├── train_aug.txt
    ├── trainaug_val.txt
    ├── val.txt
    └── voc_saliency.py

/aff_infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from tool import imutils
4 | import argparse
5 | import importlib
6 | import numpy as np
7 | import voc12.data
8 | from torch.utils.data import DataLoader
9 | import torch.nn.functional as F
10 | import os.path
11 | import imageio
12 | from tqdm import tqdm
13 |
14 | def get_indices_in_radius(height, width, radius):
15 |
16 |     search_dist = []
17 |     for x in range(1, radius):
18 |         search_dist.append((0, x))
19 |
20 |     for y in range(1, radius):
21 |         for x in range(-radius+1, radius):
22 |             if x*x + y*y < radius*radius:
23 |                 search_dist.append((y, x))
24 |
25 |     full_indices = np.reshape(np.arange(0, height * width, dtype=np.int64),
26 |                               (height, width))
27 |     radius_floor = radius-1
28 |     cropped_height = height - radius_floor
29 |     cropped_width = width - 2 * radius_floor
30 |
31 |     indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor], [-1])
32 |
33 |     indices_from_to_list = []
34 |
35 |     for dy, dx in search_dist:
36 |
37 |         indices_to = full_indices[dy:dy + cropped_height, radius_floor + dx:radius_floor + dx + cropped_width]
38 |         indices_to = np.reshape(indices_to, [-1])
39 |
40 |         indices_from_to = np.stack((indices_from, indices_to), axis=1)
41 |
42 |         indices_from_to_list.append(indices_from_to)
43 |
44 |     concat_indices_from_to = np.concatenate(indices_from_to_list, axis=0)
45 |
46 |     return concat_indices_from_to
47 |
48 |
49
| if __name__ == '__main__': 50 | 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--weights", required=True, type=str) 53 | parser.add_argument("--network", default="network.resnet38_aff", type=str) 54 | parser.add_argument("--infer_list", default="voc12/val.txt", type=str) 55 | parser.add_argument("--num_workers", default=8, type=int) 56 | parser.add_argument("--cam_dir", required=True, type=str) 57 | parser.add_argument("--voc12_root", default='VOC2012', type=str) 58 | parser.add_argument("--alpha", default=6, type=float) 59 | parser.add_argument("--out_rw", default='out_rw', type=str) 60 | parser.add_argument("--beta", default=8, type=int) 61 | parser.add_argument("--logt", default=6, type=int) 62 | parser.add_argument("--crf", default=False, type=bool) 63 | 64 | args = parser.parse_args() 65 | 66 | if not os.path.exists(args.out_rw): 67 | os.makedirs(args.out_rw) 68 | 69 | model = getattr(importlib.import_module(args.network), 'Net')() 70 | 71 | model.load_state_dict(torch.load(args.weights), strict=False) 72 | 73 | model.eval() 74 | model.cuda() 75 | 76 | infer_dataset = voc12.data.VOC12ImageDataset(args.infer_list, voc12_root=args.voc12_root, 77 | transform=torchvision.transforms.Compose([np.asarray, 78 | model.normalize, 79 | imutils.HWC_to_CHW])) 80 | infer_data_loader = DataLoader(infer_dataset, shuffle=False, num_workers=args.num_workers, pin_memory=True) 81 | 82 | for iter, (name, img) in tqdm(enumerate(infer_data_loader), total=len(infer_data_loader)): 83 | 84 | name = name[0] 85 | # print(iter) 86 | 87 | orig_shape = img.shape 88 | padded_size = (int(np.ceil(img.shape[2]/8)*8), int(np.ceil(img.shape[3]/8)*8)) 89 | 90 | p2d = (0, padded_size[1] - img.shape[3], 0, padded_size[0] - img.shape[2]) 91 | img = F.pad(img, p2d) 92 | 93 | dheight = int(np.ceil(img.shape[2]/8)) 94 | dwidth = int(np.ceil(img.shape[3]/8)) 95 | 96 | cam = np.load(os.path.join(args.cam_dir, name + '.npy'), allow_pickle=True).item() 97 | 98 | cam_full_arr = np.zeros((21, orig_shape[2], orig_shape[3]), np.float32) 99 | for k, v in cam.items(): 100 | cam_full_arr[k+1] = v 101 | 102 | cam_full_arr[0] = 0.27 103 | cam_full_arr = np.pad(cam_full_arr, ((0, 0), (0, p2d[3]), (0, p2d[1])), mode='constant') 104 | 105 | with torch.no_grad(): 106 | aff_mat = torch.pow(model.forward(img.cuda(), True), args.beta) 107 | 108 | trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True) 109 | for _ in range(args.logt): 110 | trans_mat = torch.matmul(trans_mat, trans_mat) 111 | 112 | cam_full_arr = torch.from_numpy(cam_full_arr) 113 | cam_full_arr = F.avg_pool2d(cam_full_arr, 8, 8) 114 | 115 | cam_vec = cam_full_arr.view(21, -1) 116 | cam_rw = torch.matmul(cam_vec.cuda(), trans_mat) 117 | cam_rw = cam_rw.view(1, 21, dheight, dwidth) 118 | 119 | cam_rw = torch.nn.Upsample((img.shape[2], img.shape[3]), mode='bilinear')(cam_rw) 120 | 121 | # if args.crf: 122 | # img_8 = img[0].numpy().transpose((1,2,0)) #F.interpolate(img, (dheight,dwidth), mode='bilinear')[0].numpy().transpose((1,2,0)) 123 | # img_8 = np.ascontiguousarray(img_8) 124 | # mean = (0.485, 0.456, 0.406) 125 | # std = (0.229, 0.224, 0.225) 126 | # img_8[:,:,0] = (img_8[:,:,0]*std[0] + mean[0])*255 127 | # img_8[:,:,1] = (img_8[:,:,1]*std[1] + mean[1])*255 128 | # img_8[:,:,2] = (img_8[:,:,2]*std[2] + mean[2])*255 129 | # img_8[img_8 > 255] = 255 130 | # img_8[img_8 < 0] = 0 131 | # img_8 = img_8.astype(np.uint8) 132 | # cam_rw = cam_rw[0].cpu().numpy() 133 | # cam_rw = imutils.crf_inference(img_8, cam_rw, t=1) 134 | # cam_rw = 
torch.from_numpy(cam_rw).view(1, 21, img.shape[2], img.shape[3]).cuda() 135 | 136 | _, cam_rw_pred = torch.max(cam_rw, 1) 137 | 138 | res = np.uint8(cam_rw_pred.cpu().data[0])[:orig_shape[2], :orig_shape[3]] 139 | 140 | # scipy.misc.imsave(os.path.join(args.out_rw, name + '.png'), res) 141 | imageio.imwrite(os.path.join(args.out_rw, name + '.png'), res) 142 | -------------------------------------------------------------------------------- /aff_prepare.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import argparse 4 | from PIL import Image 5 | import pandas as pd 6 | import multiprocessing 7 | import pydensecrf.densecrf as dcrf 8 | from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian 9 | 10 | if __name__ == '__main__': 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--infer_list", default="./VOC2012/ImageSets/Segmentation/trainaug.txt", type=str) 14 | parser.add_argument("--num_workers", default=8, type=int) 15 | parser.add_argument("--voc12_root", default='VOC2012', type=str) 16 | parser.add_argument("--cam_dir", default=None, type=str) 17 | parser.add_argument("--out_crf", default=None, type=str) 18 | parser.add_argument("--crf_iters", default=10, type=float) 19 | parser.add_argument("--alpha", default=4, type=float) 20 | 21 | args = parser.parse_args() 22 | 23 | assert args.cam_dir is not None 24 | 25 | if args.out_crf: 26 | if not os.path.exists(args.out_crf): 27 | os.makedirs(args.out_crf) 28 | 29 | df = pd.read_csv(args.infer_list, names=['filename']) 30 | name_list = df['filename'].values 31 | 32 | 33 | # https://github.com/pigcv/AdvCAM/blob/fa08f0ad4c1f764f3ccaf36883c0ae43342d34c5/misc/imutils.py#L156 34 | def _crf_inference(img, labels, t=10, n_labels=21, gt_prob=0.7): 35 | h, w = img.shape[:2] 36 | d = dcrf.DenseCRF2D(w, h, n_labels) 37 | U = unary_from_labels(labels, 21, gt_prob=gt_prob, zero_unsure=False) 38 | d.setUnaryEnergy(U) 39 | feats = create_pairwise_gaussian(sdims=(3, 3), shape=img.shape[:2]) 40 | d.addPairwiseEnergy(feats, compat=3, 41 | kernel=dcrf.DIAG_KERNEL, 42 | normalization=dcrf.NORMALIZE_SYMMETRIC) 43 | feats = create_pairwise_bilateral(sdims=(80, 80), schan=(13, 13, 13), 44 | img=img, chdim=2) 45 | d.addPairwiseEnergy(feats, compat=10, 46 | kernel=dcrf.DIAG_KERNEL, 47 | normalization=dcrf.NORMALIZE_SYMMETRIC) 48 | Q = d.inference(t) 49 | 50 | return np.array(Q).reshape((n_labels, h, w)) 51 | 52 | 53 | def _infer_crf_with_alpha(start, step, alpha): 54 | for idx in range(start, len(name_list), step): 55 | name = name_list[idx] 56 | cam_file = os.path.join(args.cam_dir, '%s.npy' % name) 57 | cam_dict = np.load(cam_file, allow_pickle=True).item() 58 | h, w = list(cam_dict.values())[0].shape 59 | tensor = np.zeros((21, h, w), np.float32) 60 | for key in cam_dict.keys(): 61 | tensor[key + 1] = cam_dict[key] 62 | tensor[0, :, :] = np.power(1 - np.max(tensor, axis=0, keepdims=True), alpha) 63 | 64 | predict = np.argmax(tensor, axis=0).astype(np.uint8) 65 | img = Image.open(os.path.join('./VOC2012/JPEGImages', name + '.jpg')).convert("RGB") 66 | img = np.array(img) 67 | crf_array = _crf_inference(img, predict) 68 | 69 | crf_folder = args.out_crf + ('/%.2f' % alpha) 70 | if not os.path.exists(crf_folder): 71 | os.makedirs(crf_folder) 72 | 73 | np.save(os.path.join(crf_folder, name + '.npy'), crf_array) 74 | 75 | 76 | alpha_list = [4, 8, 16, 24, 32] 77 | 78 | for alpha in alpha_list: 79 | p_list = [] 80 | for i in range(8): 
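            # eight workers, each covering every 8th image of name_list (strided split)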
81 |             p = multiprocessing.Process(target=_infer_crf_with_alpha, args=(i, 8, alpha))
82 |             p.start()
83 |             p_list.append(p)
84 |         for p in p_list:
85 |             p.join()
86 |         print(f'Info: Alpha {alpha} done!')
87 |
--------------------------------------------------------------------------------
/aff_train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | from torch.utils.data import DataLoader
5 | from torchvision import transforms
6 | import voc12.data
7 | from tool import pyutils, imutils, torchutils
8 | import argparse
9 | import importlib
10 |
11 | if __name__ == '__main__':
12 |
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument("--batch_size", default=8, type=int)
15 |     parser.add_argument("--max_epoches", default=8, type=int)
16 |     parser.add_argument("--network", default="network.resnet38_aff", type=str)
17 |     parser.add_argument("--lr", default=0.01, type=float)
18 |     parser.add_argument("--num_workers", default=8, type=int)
19 |     parser.add_argument("--wt_dec", default=5e-4, type=float)
20 |     parser.add_argument("--train_list", default="voc12/train_aug.txt", type=str)
21 |     parser.add_argument("--val_list", default="voc12/val.txt", type=str)
22 |     parser.add_argument("--session_name", default="resnet38_aff", type=str)
23 |     parser.add_argument("--crop_size", default=448, type=int)
24 |     parser.add_argument("--weights", required=True, type=str)
25 |     parser.add_argument("--voc12_root", default='VOC2012', type=str)
26 |     parser.add_argument("--la_crf_dir", required=True, type=str)
27 |     parser.add_argument("--ha_crf_dir", required=True, type=str)
28 |     args = parser.parse_args()
29 |
30 |     pyutils.Logger(args.session_name + '.log')
31 |
32 |     print(vars(args))
33 |
34 |     model = getattr(importlib.import_module(args.network), 'Net')()
35 |
36 |     print(model)
37 |
38 |     train_dataset = voc12.data.VOC12AffDataset(args.train_list, label_la_dir=args.la_crf_dir,
39 |                                                label_ha_dir=args.ha_crf_dir,
40 |                                                voc12_root=args.voc12_root, cropsize=args.crop_size, radius=5,
41 |                                                joint_transform_list=[
42 |                                                    None,
43 |                                                    None,
44 |                                                    imutils.RandomCrop(args.crop_size),
45 |                                                    imutils.RandomHorizontalFlip()
46 |                                                ],
47 |                                                img_transform_list=[
48 |                                                    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3,
49 |                                                                           hue=0.1),
50 |                                                    np.asarray,
51 |                                                    model.normalize,
52 |                                                    imutils.HWC_to_CHW
53 |                                                ],
54 |                                                label_transform_list=[
55 |                                                    None,
56 |                                                    None,
57 |                                                    None,
58 |                                                    imutils.AvgPool2d(8)
59 |                                                ])
60 |
61 |
62 |     def worker_init_fn(worker_id):
63 |         np.random.seed(1 + worker_id)
64 |
65 |
66 |     train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
67 |                                    num_workers=args.num_workers,
68 |                                    pin_memory=True, drop_last=True, worker_init_fn=worker_init_fn)
69 |     max_step = len(train_dataset) // args.batch_size * args.max_epoches
70 |
71 |     param_groups = model.get_parameter_groups()
72 |     optimizer = torchutils.PolyOptimizer([  # 1x/2x lr for pretrained weights/biases, 10x/20x for from-scratch layers
73 |         {'params': param_groups[0], 'lr': args.lr, 'weight_decay': args.wt_dec},
74 |         {'params': param_groups[1], 'lr': 2 * args.lr, 'weight_decay': 0},
75 |         {'params': param_groups[2], 'lr': 10 * args.lr, 'weight_decay': args.wt_dec},
76 |         {'params': param_groups[3], 'lr': 20 * args.lr, 'weight_decay': 0}
77 |     ], lr=args.lr, weight_decay=args.wt_dec, max_step=max_step)
78 |
79 |     if args.weights[-7:] == '.params':
80 |         import network.resnet38d
81 |
82 |         assert args.network == "network.resnet38_aff"
83 |         weights_dict = network.resnet38d.convert_mxnet_to_torch(args.weights)
84 |     else:
85 |
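        # otherwise: a regular PyTorch checkpoint (.pth state_dict)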
weights_dict = torch.load(args.weights) 86 | 87 | model.load_state_dict(weights_dict, strict=False) 88 | model = torch.nn.DataParallel(model).cuda() 89 | model.train() 90 | 91 | avg_meter = pyutils.AverageMeter('loss', 'bg_loss', 'fg_loss', 'neg_loss', 'bg_cnt', 92 | 'fg_cnt', 'neg_cnt') 93 | 94 | timer = pyutils.Timer("Session started: ") 95 | 96 | for ep in range(args.max_epoches): 97 | 98 | for iter, pack in enumerate(train_data_loader): 99 | 100 | aff = model.forward(pack[0]) 101 | 102 | bg_label = pack[1][0].cuda(non_blocking=True) 103 | fg_label = pack[1][1].cuda(non_blocking=True) 104 | neg_label = pack[1][2].cuda(non_blocking=True) 105 | 106 | bg_count = torch.sum(bg_label) + 1e-5 107 | fg_count = torch.sum(fg_label) + 1e-5 108 | neg_count = torch.sum(neg_label) + 1e-5 109 | 110 | bg_loss = torch.sum(- bg_label * torch.log(aff + 1e-5)) / bg_count 111 | fg_loss = torch.sum(- fg_label * torch.log(aff + 1e-5)) / fg_count 112 | neg_loss = torch.sum(- neg_label * torch.log(1. + 1e-5 - aff)) / neg_count 113 | 114 | loss = bg_loss / 4 + fg_loss / 4 + neg_loss / 2 115 | 116 | optimizer.zero_grad() 117 | loss.backward() 118 | optimizer.step() 119 | 120 | avg_meter.add({ 121 | 'loss': loss.item(), 122 | 'bg_loss': bg_loss.item(), 'fg_loss': fg_loss.item(), 'neg_loss': neg_loss.item(), 123 | 'bg_cnt': bg_count.item(), 'fg_cnt': fg_count.item(), 'neg_cnt': neg_count.item() 124 | }) 125 | 126 | if (optimizer.global_step - 1) % 50 == 0: 127 | timer.update_progress(optimizer.global_step / max_step) 128 | 129 | print('Iter:%5d/%5d' % (optimizer.global_step - 1, max_step), 130 | 'loss:%.4f %.4f %.4f %.4f' % avg_meter.get('loss', 'bg_loss', 'fg_loss', 'neg_loss'), 131 | 'cnt:%.0f %.0f %.0f' % avg_meter.get('bg_cnt', 'fg_cnt', 'neg_cnt'), 132 | 'imps:%.1f' % ((iter + 1) * args.batch_size / timer.get_stage_elapsed()), 133 | 'Fin:%s' % (timer.str_est_finish()), 134 | 'lr: %.4f' % (optimizer.param_groups[0]['lr']), flush=True) 135 | 136 | avg_meter.pop() 137 | 138 | 139 | else: 140 | print('') 141 | timer.reset_stage() 142 | 143 | torch.save(model.module.state_dict(), args.session_name + '.pth') 144 | -------------------------------------------------------------------------------- /contrast_infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import voc12.data 6 | import importlib 7 | import imageio 8 | import torchvision 9 | import torch.nn.functional as F 10 | import pydensecrf.densecrf as dcrf 11 | from PIL import Image 12 | from torch.utils.data import DataLoader 13 | from tool import imutils, pyutils 14 | from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian 15 | from tqdm import tqdm 16 | 17 | if __name__ == '__main__': 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--weights", required=True, type=str) 21 | parser.add_argument("--network", default="network.resnet38_contrast", type=str) 22 | parser.add_argument("--infer_list", default="voc12/train.txt", type=str) 23 | parser.add_argument("--num_workers", default=8, type=int) 24 | parser.add_argument("--voc12_root", default='VOC2012', type=str) 25 | parser.add_argument("--out_cam", default=None, type=str) # cam_npy 26 | parser.add_argument("--out_crf", default=None, type=str) # crf_png 27 | parser.add_argument("--out_cam_pred", default=None, type=str) # cam_png 28 | parser.add_argument("--out_cam_pred_alpha", default=0.26, type=float) # cam_png_bg_score 29 | 
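    # --out_cam_pred_alpha (0.26 above) is the constant background score: a pixel is
    # predicted as background when every foreground CAM falls below it (see bg_score below)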
parser.add_argument("--crf_iters", default=10, type=float) 30 | 31 | args = parser.parse_args() 32 | model = getattr(importlib.import_module(args.network), 'Net')() 33 | model.load_state_dict(torch.load(args.weights)) 34 | 35 | model.eval() 36 | model.cuda() 37 | 38 | infer_dataset = voc12.data.VOC12ClsDatasetMSF(args.infer_list, voc12_root=args.voc12_root, 39 | scales=[0.5, 1.0, 1.5, 2.0], 40 | inter_transform=torchvision.transforms.Compose( 41 | [np.asarray, 42 | model.normalize, 43 | imutils.HWC_to_CHW])) 44 | infer_data_loader = DataLoader(infer_dataset, shuffle=False, num_workers=args.num_workers, pin_memory=True) 45 | 46 | n_gpus = torch.cuda.device_count() 47 | model_replicas = torch.nn.parallel.replicate(model, list(range(n_gpus))) 48 | 49 | for iter, (img_name, img_list, label) in tqdm(enumerate(infer_data_loader), total=len(infer_data_loader)): 50 | img_name = img_name[0] 51 | label = label[0] 52 | 53 | img_path = voc12.data.get_img_path(img_name, args.voc12_root) 54 | orig_img = np.asarray(Image.open(img_path)) 55 | orig_img_size = orig_img.shape[:2] 56 | 57 | 58 | def _work(i, img): 59 | with torch.no_grad(): 60 | with torch.cuda.device(i % n_gpus): 61 | _, cam, f, ff = model_replicas[i % n_gpus](img.cuda()) 62 | cam = F.upsample(cam[:, 1:, :, :], orig_img_size, mode='bilinear', align_corners=False)[0] 63 | cam = cam.cpu().numpy() * label.clone().view(20, 1, 1).numpy() 64 | if i % 2 == 1: 65 | cam = np.flip(cam, axis=-1) 66 | return cam 67 | 68 | 69 | thread_pool = pyutils.BatchThreader(_work, list(enumerate(img_list)), 70 | batch_size=12, prefetch_size=0, 71 | processes=args.num_workers) 72 | 73 | cam_list = thread_pool.pop_results() 74 | 75 | sum_cam = np.sum(cam_list, axis=0) 76 | sum_cam[sum_cam < 0] = 0 77 | cam_max = np.max(sum_cam, (1, 2), keepdims=True) 78 | cam_min = np.min(sum_cam, (1, 2), keepdims=True) 79 | sum_cam[sum_cam < cam_min + 1e-5] = 0 80 | norm_cam = (sum_cam - cam_min - 1e-5) / (cam_max - cam_min + 1e-5) 81 | 82 | cam_dict = {} 83 | for i in range(20): 84 | if label[i] > 1e-5: 85 | cam_dict[i] = norm_cam[i] 86 | 87 | if args.out_cam is not None: 88 | if not os.path.exists(args.out_cam): 89 | os.makedirs(args.out_cam) 90 | np.save(os.path.join(args.out_cam, img_name + '.npy'), cam_dict) 91 | 92 | if args.out_cam_pred is not None: 93 | 94 | if not os.path.exists(args.out_cam_pred): 95 | os.makedirs(args.out_cam_pred) 96 | 97 | bg_score = [np.ones_like(norm_cam[0]) * args.out_cam_pred_alpha] 98 | pred = np.argmax(np.concatenate((bg_score, norm_cam)), 0) 99 | imageio.imsave(os.path.join(args.out_cam_pred, img_name + '.png'), pred.astype(np.uint8)) 100 | 101 | 102 | def _crf(cam_dict, bg_score=0.26): 103 | 104 | h, w = list(cam_dict.values())[0].shape 105 | tensor = np.zeros((21, h, w), np.float32) 106 | for key in cam_dict.keys(): 107 | tensor[key + 1] = cam_dict[key] 108 | tensor[0, :, :] = bg_score 109 | predict = np.argmax(tensor, axis=0).astype(np.uint8) 110 | img = Image.open(os.path.join('./VOC2012/JPEGImages', img_name + '.jpg')).convert("RGB") 111 | img = np.array(img) 112 | crf_score = _crf_inference(img, predict) 113 | return np.argmax(crf_score, axis=0).astype(np.uint8) 114 | 115 | def _crf_inference(img, labels, t=10, n_labels=21, gt_prob=0.7): 116 | 117 | h, w = img.shape[:2] 118 | d = dcrf.DenseCRF2D(w, h, n_labels) 119 | unary = unary_from_labels(labels, n_labels, gt_prob=gt_prob, zero_unsure=False) 120 | d.setUnaryEnergy(unary) 121 | d.addPairwiseGaussian(sxy=3, compat=3) 122 | d.addPairwiseBilateral(sxy=50, srgb=5, 
rgbim=np.ascontiguousarray(np.copy(img)), compat=10) 123 | 124 | q = d.inference(t) 125 | 126 | return np.array(q).reshape((n_labels, h, w)) 127 | 128 | 129 | if args.out_crf is not None: 130 | crf_pred = _crf(cam_dict) 131 | folder = args.out_crf 132 | if not os.path.exists(folder): 133 | os.makedirs(folder) 134 | imageio.imsave(os.path.join(folder, img_name + '.png'), crf_pred.astype(np.uint8)) 135 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from PIL import Image 5 | import multiprocessing 6 | import argparse 7 | 8 | categories = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 9 | 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 10 | 'tvmonitor'] 11 | 12 | 13 | def do_python_eval(predict_folder, gt_folder, name_list, num_cls=21, input_type='png', threshold=1.0, printlog=False): 14 | TP = [] 15 | P = [] 16 | T = [] 17 | for i in range(num_cls): 18 | TP.append(multiprocessing.Value('i', 0, lock=True)) 19 | P.append(multiprocessing.Value('i', 0, lock=True)) 20 | T.append(multiprocessing.Value('i', 0, lock=True)) 21 | 22 | def compare(start, step, TP, P, T, input_type, threshold): 23 | for idx in range(start, len(name_list), step): 24 | name = name_list[idx] 25 | if input_type == 'png': 26 | predict_file = os.path.join(predict_folder, '%s.png' % name) 27 | predict = np.array(Image.open(predict_file)) # cv2.imread(predict_file) 28 | elif input_type == 'npy': 29 | predict_file = os.path.join(predict_folder, '%s.npy' % name) 30 | predict_dict = np.load(predict_file, allow_pickle=True).item() 31 | h, w = list(predict_dict.values())[0].shape 32 | tensor = np.zeros((21, h, w), np.float32) 33 | for key in predict_dict.keys(): 34 | tensor[key + 1] = predict_dict[key] 35 | tensor[0, :, :] = threshold 36 | predict = np.argmax(tensor, axis=0).astype(np.uint8) 37 | 38 | gt_file = os.path.join(gt_folder, '%s.png' % name) 39 | gt = np.array(Image.open(gt_file)) 40 | cal = gt < 255 41 | mask = (predict == gt) * cal 42 | 43 | for i in range(num_cls): 44 | P[i].acquire() 45 | P[i].value += np.sum((predict == i) * cal) 46 | P[i].release() 47 | T[i].acquire() 48 | T[i].value += np.sum((gt == i) * cal) 49 | T[i].release() 50 | TP[i].acquire() 51 | TP[i].value += np.sum((gt == i) * mask) 52 | TP[i].release() 53 | 54 | p_list = [] 55 | for i in range(8): 56 | p = multiprocessing.Process(target=compare, args=(i, 8, TP, P, T, input_type, threshold)) 57 | p.start() 58 | p_list.append(p) 59 | for p in p_list: 60 | p.join() 61 | IoU = [] 62 | T_TP = [] 63 | P_TP = [] 64 | FP_ALL = [] 65 | FN_ALL = [] 66 | for i in range(num_cls): 67 | IoU.append(TP[i].value / (T[i].value + P[i].value - TP[i].value + 1e-10)) 68 | T_TP.append(T[i].value / (TP[i].value + 1e-10)) 69 | P_TP.append(P[i].value / (TP[i].value + 1e-10)) 70 | FP_ALL.append((P[i].value - TP[i].value) / (T[i].value + P[i].value - TP[i].value + 1e-10)) 71 | FN_ALL.append((T[i].value - TP[i].value) / (T[i].value + P[i].value - TP[i].value + 1e-10)) 72 | loglist = {} 73 | for i in range(num_cls): 74 | loglist[categories[i]] = IoU[i] * 100 75 | 76 | miou = np.mean(np.array(IoU)) 77 | loglist['mIoU'] = miou * 100 78 | if printlog: 79 | for i in range(num_cls): 80 | if i % 2 != 1: 81 | print('%11s:%7.3f%%' % (categories[i], IoU[i] * 100), end='\t') 82 | else: 83 | 
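                    # odd-indexed categories end the printed row with a newline instead of a tab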
print('%11s:%7.3f%%' % (categories[i], IoU[i] * 100)) 84 | print('\n======================================================') 85 | print('%11s:%7.3f%%' % ('mIoU', miou * 100)) 86 | return loglist 87 | 88 | 89 | def writedict(file, dictionary): 90 | s = '' 91 | for key in dictionary.keys(): 92 | sub = '%s:%s ' % (key, dictionary[key]) 93 | s += sub 94 | s += '\n' 95 | file.write(s) 96 | 97 | 98 | def writelog(filepath, metric, comment): 99 | filepath = filepath 100 | logfile = open(filepath, 'a') 101 | import time 102 | logfile.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 103 | logfile.write('\t%s\n' % comment) 104 | writedict(logfile, metric) 105 | logfile.write('=====================================\n') 106 | logfile.close() 107 | 108 | 109 | if __name__ == '__main__': 110 | 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument("--list", default='./VOC2012/ImageSets/Segmentation/train.txt', type=str) 113 | parser.add_argument("--predict_dir", default='./out_rw', type=str) 114 | parser.add_argument("--gt_dir", default='./VOC2012/SegmentationClass', type=str) 115 | parser.add_argument('--logfile', default='./evallog.txt', type=str) 116 | parser.add_argument('--comment', required=True, type=str) 117 | parser.add_argument('--type', default='png', choices=['npy', 'png'], type=str) 118 | parser.add_argument('--t', default=None, type=float) 119 | parser.add_argument('--curve', default=False, type=bool) 120 | args = parser.parse_args() 121 | 122 | if args.type == 'npy': 123 | assert args.t is not None or args.curve 124 | df = pd.read_csv(args.list, names=['filename']) 125 | name_list = df['filename'].values 126 | if not args.curve: 127 | loglist = do_python_eval(args.predict_dir, args.gt_dir, name_list, 21, args.type, args.t, printlog=True) 128 | writelog(args.logfile, loglist, args.comment) 129 | else: 130 | l = [] 131 | for i in range(60): 132 | t = i / 100.0 133 | loglist = do_python_eval(args.predict_dir, args.gt_dir, name_list, 21, args.type, t) 134 | l.append(loglist['mIoU']) 135 | print('%d/60 background score: %.3f\tmIoU: %.3f%%' % (i, t, loglist['mIoU'])) 136 | writelog(args.logfile, {'mIoU': l}, args.comment) 137 | -------------------------------------------------------------------------------- /network/resnet38_SEAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.sparse as sparse 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | np.set_printoptions(threshold=np.inf) 8 | 9 | import network.resnet38d 10 | from tool import pyutils 11 | 12 | 13 | class Net(network.resnet38d.Net): 14 | def __init__(self): 15 | super(Net, self).__init__() 16 | self.dropout7 = torch.nn.Dropout2d(0.5) 17 | self.fc8 = nn.Conv2d(4096, 21, 1, bias=False) 18 | 19 | self.f8_3 = torch.nn.Conv2d(512, 64, 1, bias=False) 20 | self.f8_4 = torch.nn.Conv2d(1024, 128, 1, bias=False) 21 | self.f9 = torch.nn.Conv2d(192 + 3, 192, 1, bias=False) 22 | 23 | torch.nn.init.xavier_uniform_(self.fc8.weight) 24 | torch.nn.init.kaiming_normal_(self.f8_3.weight) 25 | torch.nn.init.kaiming_normal_(self.f8_4.weight) 26 | torch.nn.init.xavier_uniform_(self.f9.weight, gain=4) 27 | self.from_scratch_layers = [self.f8_3, self.f8_4, self.f9, self.fc8] 28 | self.not_training = [self.conv1a, self.b2, self.b2_1, self.b2_2] 29 | 30 | def forward(self, x): 31 | N, C, H, W = x.size() 32 | d = super().forward_as_dict(x) 33 | cam = self.fc8(self.dropout7(d['conv6'])) 34 | n, c, h, w = cam.size() 35 | 36 | with 
torch.no_grad():
37 |             cam_d = F.relu(cam.detach())
38 |             cam_d_max = torch.max(cam_d.view(n, c, -1), dim=-1)[0].view(n, c, 1, 1) + 1e-5
39 |             # max norm
40 |             cam_d_norm = F.relu(cam_d - 1e-5) / cam_d_max
41 |             cam_d_norm[:, 0, :, :] = 1 - torch.max(cam_d_norm[:, 1:, :, :], dim=1)[0]
42 |             cam_max = torch.max(cam_d_norm[:, 1:, :, :], dim=1, keepdim=True)[0]
43 |             cam_d_norm[:, 1:, :, :][cam_d_norm[:, 1:, :, :] < cam_max] = 0
44 |
45 |         f8_3 = F.relu(self.f8_3(d['conv4'].detach()), inplace=True)
46 |         f8_4 = F.relu(self.f8_4(d['conv5'].detach()), inplace=True)
47 |         x_s = F.interpolate(x, (h, w), mode='bilinear', align_corners=True)
48 |         f = torch.cat([x_s, f8_3, f8_4], dim=1)
49 |         n, c, h, w = f.size()
50 |
51 |         cam_rv = F.interpolate(self.PCM(cam_d_norm, f), (H, W), mode='bilinear', align_corners=True)
52 |         cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=True)
53 |         return cam, cam_rv
54 |
55 |     def PCM(self, cam, f):
56 |
57 |         n, c, h, w = f.size()
58 |         cam = F.interpolate(cam, (h, w), mode='bilinear', align_corners=True).view(n, -1, h * w)
59 |         f = self.f9(f)
60 |         f = f.view(n, -1, h * w)
61 |         # norm
62 |         f = f / (torch.norm(f, dim=1, keepdim=True) + 1e-5)
63 |         aff = F.relu(torch.matmul(f.transpose(1, 2), f), inplace=True)
64 |         aff = aff / (torch.sum(aff, dim=1, keepdim=True) + 1e-5)
65 |         cam_rv = torch.matmul(cam, aff).view(n, -1, h, w)
66 |
67 |         return cam_rv
68 |
69 |     def get_parameter_groups(self):
70 |         groups = ([], [], [], [])
71 |         print('======================================================')
72 |         for m in self.modules():
73 |
74 |             if (isinstance(m, nn.Conv2d) or isinstance(m, nn.modules.normalization.GroupNorm)):
75 |
76 |                 if m.weight.requires_grad:
77 |                     if m in self.from_scratch_layers:
78 |                         groups[2].append(m.weight)
79 |                     else:
80 |                         groups[0].append(m.weight)
81 |
82 |                 if m.bias is not None and m.bias.requires_grad:
83 |                     if m in self.from_scratch_layers:
84 |                         groups[3].append(m.bias)
85 |                     else:
86 |                         groups[1].append(m.bias)
87 |
88 |         return groups
89 |
--------------------------------------------------------------------------------
/network/resnet38_aff.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.sparse as sparse
4 | import torch.nn.functional as F
5 |
6 | import network.resnet38d
7 | from tool import pyutils
8 |
9 |
10 | class Net(network.resnet38d.Net):
11 |     def __init__(self):
12 |         super(Net, self).__init__()
13 |
14 |         self.f8_3 = torch.nn.Conv2d(512, 64, 1, bias=False)
15 |         self.f8_4 = torch.nn.Conv2d(1024, 128, 1, bias=False)
16 |         self.f8_5 = torch.nn.Conv2d(4096, 256, 1, bias=False)
17 |
18 |         self.f9 = torch.nn.Conv2d(448, 448, 1, bias=False)
19 |
20 |         torch.nn.init.kaiming_normal_(self.f8_3.weight)
21 |         torch.nn.init.kaiming_normal_(self.f8_4.weight)
22 |         torch.nn.init.kaiming_normal_(self.f8_5.weight)
23 |         torch.nn.init.xavier_uniform_(self.f9.weight, gain=4)
24 |
25 |         self.not_training = [self.conv1a, self.b2, self.b2_1, self.b2_2]
26 |
27 |         self.from_scratch_layers = [self.f8_3, self.f8_4, self.f8_5, self.f9]
28 |
29 |         self.predefined_featuresize = int(448//8)
30 |         self.radius = 5
31 |         self.ind_from, self.ind_to = pyutils.get_indices_of_pairs(radius=self.radius, size=(self.predefined_featuresize, self.predefined_featuresize))
32 |         self.ind_from = torch.from_numpy(self.ind_from); self.ind_to = torch.from_numpy(self.ind_to)
33 |         return
34 |
35 |     def forward(self, x, to_dense=False):
36 |
37 |         d = super().forward_as_dict(x)
38 |
39 |         f8_3 =
F.elu(self.f8_3(d['conv4'])) 40 | f8_4 = F.elu(self.f8_4(d['conv5'])) 41 | f8_5 = F.elu(self.f8_5(d['conv6'])) 42 | x = F.elu(self.f9(torch.cat([f8_3, f8_4, f8_5], dim=1))) 43 | 44 | if x.size(2) == self.predefined_featuresize and x.size(3) == self.predefined_featuresize: 45 | ind_from = self.ind_from 46 | ind_to = self.ind_to 47 | else: 48 | min_edge = min(x.size(2), x.size(3)) 49 | radius = (min_edge-1)//2 if min_edge < self.radius*2+1 else self.radius 50 | ind_from, ind_to = pyutils.get_indices_of_pairs(radius, (x.size(2), x.size(3))) 51 | ind_from = torch.from_numpy(ind_from); ind_to = torch.from_numpy(ind_to) 52 | 53 | x = x.view(x.size(0), x.size(1), -1).contiguous() 54 | ind_from = ind_from.contiguous() 55 | ind_to = ind_to.contiguous() 56 | 57 | ff = torch.index_select(x, dim=2, index=ind_from.cuda(non_blocking=True)) 58 | ft = torch.index_select(x, dim=2, index=ind_to.cuda(non_blocking=True)) 59 | 60 | ff = torch.unsqueeze(ff, dim=2) 61 | ft = ft.view(ft.size(0), ft.size(1), -1, ff.size(3)) 62 | 63 | aff = torch.exp(-torch.mean(torch.abs(ft-ff), dim=1)) 64 | 65 | if to_dense: 66 | aff = aff.view(-1).cpu() 67 | 68 | ind_from_exp = torch.unsqueeze(ind_from, dim=0).expand(ft.size(2), -1).contiguous().view(-1) 69 | indices = torch.stack([ind_from_exp, ind_to]) 70 | indices_tp = torch.stack([ind_to, ind_from_exp]) 71 | 72 | area = x.size(2) 73 | indices_id = torch.stack([torch.arange(0, area).long(), torch.arange(0, area).long()]) 74 | 75 | aff_mat = sparse.FloatTensor(torch.cat([indices, indices_id, indices_tp], dim=1), 76 | torch.cat([aff, torch.ones([area]), aff])).to_dense().cuda() 77 | 78 | return aff_mat 79 | 80 | else: 81 | return aff 82 | 83 | 84 | def get_parameter_groups(self): 85 | groups = ([], [], [], []) 86 | 87 | for m in self.modules(): 88 | 89 | if (isinstance(m, nn.Conv2d) or isinstance(m, nn.modules.normalization.GroupNorm)): 90 | 91 | if m.weight.requires_grad: 92 | if m in self.from_scratch_layers: 93 | groups[2].append(m.weight) 94 | else: 95 | groups[0].append(m.weight) 96 | 97 | if m.bias is not None and m.bias.requires_grad: 98 | 99 | if m in self.from_scratch_layers: 100 | groups[3].append(m.bias) 101 | else: 102 | groups[1].append(m.bias) 103 | 104 | return groups 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /network/resnet38_contrast.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.sparse as sparse 4 | import torch.nn.functional as F 5 | import numpy as np 6 | np.set_printoptions(threshold=np.inf) 7 | 8 | import network.resnet38d 9 | from tool import pyutils 10 | 11 | class Net(network.resnet38d.Net): 12 | def __init__(self): 13 | super(Net, self).__init__() 14 | self.dropout7 = torch.nn.Dropout2d(0.5) 15 | self.fc8 = nn.Conv2d(4096, 21, 1, bias=False) 16 | self.fc_proj = torch.nn.Conv2d(4096, 128, 1, bias=False) 17 | 18 | self.f8_3 = torch.nn.Conv2d(512, 64, 1, bias=False) 19 | self.f8_4 = torch.nn.Conv2d(1024, 128, 1, bias=False) 20 | self.f9 = torch.nn.Conv2d(192+3, 192, 1, bias=False) 21 | 22 | torch.nn.init.xavier_uniform_(self.fc8.weight) 23 | torch.nn.init.kaiming_normal_(self.f8_3.weight) 24 | torch.nn.init.kaiming_normal_(self.f8_4.weight) 25 | torch.nn.init.xavier_uniform_(self.f9.weight, gain=4) 26 | torch.nn.init.xavier_uniform_(self.fc_proj.weight) 27 | 28 | self.from_scratch_layers = [self.f8_3, self.f8_4, self.f9, self.fc8, self.fc_proj] 29 | self.not_training = [self.conv1a, self.b2, 
self.b2_1, self.b2_2] 30 | 31 | def forward(self, x): 32 | N, C, H, W = x.size() 33 | d = super().forward_as_dict(x) 34 | fea = self.dropout7(d['conv6']) 35 | 36 | f_proj = F.relu(self.fc_proj(fea), inplace=True) 37 | 38 | cam = self.fc8(fea) 39 | n,c,h,w = cam.size() 40 | 41 | with torch.no_grad(): 42 | cam_d = F.relu(cam.detach()) 43 | cam_d_max = torch.max(cam_d.view(n, c, -1), dim=-1)[0].view(n, c, 1, 1)+1e-5 44 | # max norm 45 | cam_d_norm = F.relu(cam_d - 1e-5) / cam_d_max 46 | cam_d_norm[:, 0, :, :] = 1-torch.max(cam_d_norm[:, 1:, :, :], dim=1)[0] 47 | cam_max = torch.max(cam_d_norm[:,1:,:,:], dim=1, keepdim=True)[0] 48 | cam_d_norm[:,1:,:,:][cam_d_norm[:,1:,:,:] < cam_max] = 0 49 | 50 | f8_3 = F.relu(self.f8_3(d['conv4'].detach()), inplace=True) 51 | f8_4 = F.relu(self.f8_4(d['conv5'].detach()), inplace=True) 52 | x_s = F.interpolate(x, (h, w), mode='bilinear', align_corners=True) 53 | f = torch.cat([x_s, f8_3, f8_4], dim=1) 54 | n, c, h, w = f.size() 55 | 56 | cam_rv_down = self.PCM(cam_d_norm, f) 57 | cam_rv = F.interpolate(cam_rv_down, (H,W), 58 | mode='bilinear', align_corners=True) 59 | cam = F.interpolate(cam, (H, W), mode='bilinear', align_corners=True) 60 | 61 | return cam, cam_rv, f_proj, cam_rv_down 62 | 63 | def PCM(self, cam, f): 64 | 65 | n,c,h,w = f.size() 66 | cam = F.interpolate(cam, (h,w), mode='bilinear', align_corners=True).view(n,-1,h*w) 67 | f = self.f9(f) 68 | f = f.view(n, -1, h*w) 69 | # norm 70 | f = f / (torch.norm(f, dim=1, keepdim=True) + 1e-5) 71 | aff = F.relu(torch.matmul(f.transpose(1, 2), f), inplace=True) 72 | aff = aff/(torch.sum(aff, dim=1, keepdim=True) + 1e-5) 73 | cam_rv = torch.matmul(cam, aff).view(n, -1, h, w) 74 | 75 | return cam_rv 76 | 77 | def get_parameter_groups(self): 78 | groups = ([], [], [], []) 79 | print('======================================================') 80 | for m in self.modules(): 81 | 82 | if (isinstance(m, nn.Conv2d) or isinstance(m, nn.modules.normalization.GroupNorm)): 83 | 84 | if m.weight.requires_grad: 85 | if m in self.from_scratch_layers: 86 | groups[2].append(m.weight) 87 | else: 88 | groups[0].append(m.weight) 89 | 90 | if m.bias is not None and m.bias.requires_grad: 91 | if m in self.from_scratch_layers: 92 | groups[3].append(m.bias) 93 | else: 94 | groups[1].append(m.bias) 95 | 96 | return groups 97 | 98 | -------------------------------------------------------------------------------- /network/resnet38d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | import torch.nn.functional as F 6 | class ResBlock(nn.Module): 7 | def __init__(self, in_channels, mid_channels, out_channels, stride=1, 8 | first_dilation=None, dilation=1): 9 | super(ResBlock, self).__init__() 10 | 11 | self.same_shape = (in_channels == out_channels and stride == 1) 12 | 13 | if first_dilation == None: first_dilation = dilation 14 | 15 | self.bn_branch2a = nn.BatchNorm2d(in_channels) 16 | 17 | self.conv_branch2a = nn.Conv2d(in_channels, mid_channels, 3, stride, 18 | padding=first_dilation, dilation=first_dilation, bias=False) 19 | 20 | self.bn_branch2b1 = nn.BatchNorm2d(mid_channels) 21 | 22 | self.conv_branch2b1 = nn.Conv2d(mid_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False) 23 | 24 | if not self.same_shape: 25 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False) 26 | 27 | def forward(self, x, get_x_bn_relu=False): 28 | 29 | branch2 = self.bn_branch2a(x) 30 | branch2 = 
F.relu(branch2) 31 | 32 | x_bn_relu = branch2 33 | 34 | if not self.same_shape: 35 | branch1 = self.conv_branch1(branch2) 36 | else: 37 | branch1 = x 38 | 39 | branch2 = self.conv_branch2a(branch2) 40 | branch2 = self.bn_branch2b1(branch2) 41 | branch2 = F.relu(branch2) 42 | branch2 = self.conv_branch2b1(branch2) 43 | 44 | x = branch1 + branch2 45 | 46 | if get_x_bn_relu: 47 | return x, x_bn_relu 48 | 49 | return x 50 | 51 | def __call__(self, x, get_x_bn_relu=False): 52 | return self.forward(x, get_x_bn_relu=get_x_bn_relu) 53 | 54 | class ResBlock_bot(nn.Module): 55 | def __init__(self, in_channels, out_channels, stride=1, dilation=1, dropout=0.): 56 | super(ResBlock_bot, self).__init__() 57 | 58 | self.same_shape = (in_channels == out_channels and stride == 1) 59 | 60 | self.bn_branch2a = nn.BatchNorm2d(in_channels) 61 | self.conv_branch2a = nn.Conv2d(in_channels, out_channels//4, 1, stride, bias=False) 62 | 63 | self.bn_branch2b1 = nn.BatchNorm2d(out_channels//4) 64 | self.dropout_2b1 = torch.nn.Dropout2d(dropout) 65 | self.conv_branch2b1 = nn.Conv2d(out_channels//4, out_channels//2, 3, padding=dilation, dilation=dilation, bias=False) 66 | 67 | self.bn_branch2b2 = nn.BatchNorm2d(out_channels//2) 68 | self.dropout_2b2 = torch.nn.Dropout2d(dropout) 69 | self.conv_branch2b2 = nn.Conv2d(out_channels//2, out_channels, 1, bias=False) 70 | 71 | if not self.same_shape: 72 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False) 73 | 74 | def forward(self, x, get_x_bn_relu=False): 75 | 76 | branch2 = self.bn_branch2a(x) 77 | branch2 = F.relu(branch2) 78 | x_bn_relu = branch2 79 | 80 | branch1 = self.conv_branch1(branch2) 81 | 82 | branch2 = self.conv_branch2a(branch2) 83 | 84 | branch2 = self.bn_branch2b1(branch2) 85 | branch2 = F.relu(branch2) 86 | branch2 = self.dropout_2b1(branch2) 87 | branch2 = self.conv_branch2b1(branch2) 88 | 89 | branch2 = self.bn_branch2b2(branch2) 90 | branch2 = F.relu(branch2) 91 | branch2 = self.dropout_2b2(branch2) 92 | branch2 = self.conv_branch2b2(branch2) 93 | 94 | x = branch1 + branch2 95 | 96 | if get_x_bn_relu: 97 | return x, x_bn_relu 98 | 99 | return x 100 | 101 | def __call__(self, x, get_x_bn_relu=False): 102 | return self.forward(x, get_x_bn_relu=get_x_bn_relu) 103 | 104 | class Normalize(): 105 | def __init__(self, mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)): 106 | 107 | self.mean = mean 108 | self.std = std 109 | 110 | def __call__(self, img): 111 | imgarr = np.asarray(img) 112 | proc_img = np.empty_like(imgarr, np.float32) 113 | 114 | proc_img[..., 0] = (imgarr[..., 0] / 255. - self.mean[0]) / self.std[0] 115 | proc_img[..., 1] = (imgarr[..., 1] / 255. - self.mean[1]) / self.std[1] 116 | proc_img[..., 2] = (imgarr[..., 2] / 255. 
- self.mean[2]) / self.std[2] 117 | 118 | return proc_img 119 | 120 | class Net(nn.Module): 121 | def __init__(self): 122 | super(Net, self).__init__() 123 | 124 | self.conv1a = nn.Conv2d(3, 64, 3, padding=1, bias=False) 125 | 126 | self.b2 = ResBlock(64, 128, 128, stride=2) 127 | self.b2_1 = ResBlock(128, 128, 128) 128 | self.b2_2 = ResBlock(128, 128, 128) 129 | 130 | self.b3 = ResBlock(128, 256, 256, stride=2) 131 | self.b3_1 = ResBlock(256, 256, 256) 132 | self.b3_2 = ResBlock(256, 256, 256) 133 | 134 | self.b4 = ResBlock(256, 512, 512, stride=2) 135 | self.b4_1 = ResBlock(512, 512, 512) 136 | self.b4_2 = ResBlock(512, 512, 512) 137 | self.b4_3 = ResBlock(512, 512, 512) 138 | self.b4_4 = ResBlock(512, 512, 512) 139 | self.b4_5 = ResBlock(512, 512, 512) 140 | 141 | self.b5 = ResBlock(512, 512, 1024, stride=1, first_dilation=1, dilation=2) 142 | self.b5_1 = ResBlock(1024, 512, 1024, dilation=2) 143 | self.b5_2 = ResBlock(1024, 512, 1024, dilation=2) 144 | 145 | self.b6 = ResBlock_bot(1024, 2048, stride=1, dilation=4, dropout=0.3) 146 | 147 | self.b7 = ResBlock_bot(2048, 4096, dilation=4, dropout=0.5) 148 | 149 | self.bn7 = nn.BatchNorm2d(4096) 150 | 151 | self.not_training = [self.conv1a] 152 | 153 | self.normalize = Normalize() 154 | 155 | return 156 | 157 | def forward(self, x): 158 | return self.forward_as_dict(x)['conv6'] 159 | 160 | def forward_as_dict(self, x): 161 | 162 | x = self.conv1a(x) 163 | 164 | x = self.b2(x) 165 | x = self.b2_1(x) 166 | x = self.b2_2(x) 167 | 168 | x = self.b3(x) 169 | x = self.b3_1(x) 170 | x = self.b3_2(x) 171 | 172 | #x = self.b4(x) 173 | x, conv3 = self.b4(x, get_x_bn_relu=True) 174 | x = self.b4_1(x) 175 | x = self.b4_2(x) 176 | x = self.b4_3(x) 177 | x = self.b4_4(x) 178 | x = self.b4_5(x) 179 | 180 | x, conv4 = self.b5(x, get_x_bn_relu=True) 181 | x = self.b5_1(x) 182 | x = self.b5_2(x) 183 | 184 | x, conv5 = self.b6(x, get_x_bn_relu=True) 185 | 186 | x = self.b7(x) 187 | conv6 = F.relu(self.bn7(x)) 188 | 189 | return dict({'conv3': conv3, 'conv4': conv4, 'conv5': conv5, 'conv6': conv6}) 190 | 191 | 192 | def train(self, mode=True): 193 | 194 | super().train(mode) 195 | 196 | for layer in self.not_training: 197 | 198 | if isinstance(layer, torch.nn.Conv2d): 199 | layer.weight.requires_grad = False 200 | 201 | elif isinstance(layer, torch.nn.Module): 202 | for c in layer.children(): 203 | c.weight.requires_grad = False 204 | if c.bias is not None: 205 | c.bias.requires_grad = False 206 | 207 | for layer in self.modules(): 208 | 209 | if isinstance(layer, torch.nn.BatchNorm2d): 210 | layer.eval() 211 | layer.bias.requires_grad = False 212 | layer.weight.requires_grad = False 213 | 214 | return 215 | 216 | def convert_mxnet_to_torch(filename): 217 | import mxnet 218 | 219 | save_dict = mxnet.nd.load(filename) 220 | 221 | renamed_dict = dict() 222 | 223 | bn_param_mx_pt = {'beta': 'bias', 'gamma': 'weight', 'mean': 'running_mean', 'var': 'running_var'} 224 | 225 | for k, v in save_dict.items(): 226 | 227 | v = torch.from_numpy(v.asnumpy()) 228 | toks = k.split('_') 229 | 230 | if 'conv1a' in toks[0]: 231 | renamed_dict['conv1a.weight'] = v 232 | 233 | elif 'linear1000' in toks[0]: 234 | pass 235 | 236 | elif 'branch' in toks[1]: 237 | 238 | pt_name = [] 239 | 240 | if toks[0][-1] != 'a': 241 | pt_name.append('b' + toks[0][-3] + '_' + toks[0][-1]) 242 | else: 243 | pt_name.append('b' + toks[0][-2]) 244 | 245 | if 'res' in toks[0]: 246 | layer_type = 'conv' 247 | last_name = 'weight' 248 | 249 | else: # 'bn' in toks[0]: 250 | layer_type = 'bn' 251 | 
last_name = bn_param_mx_pt[toks[-1]]
252 |
253 |             pt_name.append(layer_type + '_' + toks[1])
254 |
255 |             pt_name.append(last_name)
256 |
257 |             torch_name = '.'.join(pt_name)
258 |             renamed_dict[torch_name] = v
259 |
260 |         else:
261 |             last_name = bn_param_mx_pt[toks[-1]]
262 |             renamed_dict['bn7.' + last_name] = v
263 |
264 |     return renamed_dict
265 |
266 |
--------------------------------------------------------------------------------
/script/script_cls.sh:
--------------------------------------------------------------------------------
1 | # NEED TO SET
2 | DATASET_ROOT=PATH/TO/DATASET
3 | WEIGHT_ROOT=PATH/TO/WEIGHT
4 | GPU=0,1
5 |
6 |
7 | # Default setting
8 | IMG_ROOT=${DATASET_ROOT}/JPEGImages
9 | BACKBONE=resnet38_cls
10 | SESSION=resnet38_cls
11 | BASE_WEIGHT=${WEIGHT_ROOT}/ilsvrc-cls_rna-a1_cls1000_ep-0001.params
12 |
13 |
14 | # 1. train classification network
15 | CUDA_VISIBLE_DEVICES=${GPU} python3 train.py \
16 |     --session ${SESSION} \
17 |     --network network.${BACKBONE} \
18 |     --data_root ${IMG_ROOT} \
19 |     --weights ${BASE_WEIGHT} \
20 |     --crop_size 448 \
21 |     --max_iters 10000 \
22 |     --iter_size 2 \
23 |     --batch_size 8
24 |
25 |
26 | # 2. inference CAM
27 | DATA=train # train / train_aug
28 | TRAINED_WEIGHT=train_log/${SESSION}/checkpoint_cls.pth
29 |
30 | CUDA_VISIBLE_DEVICES=${GPU} python3 infer.py \
31 |     --infer_list data/voc12/${DATA}_id.txt \
32 |     --img_root ${IMG_ROOT} \
33 |     --network network.${BACKBONE} \
34 |     --weights ${TRAINED_WEIGHT} \
35 |     --thr 0.20 \
36 |     --n_gpus 2 \
37 |     --n_processes_per_gpu 1 1 \
38 |     --cam_png train_log/${SESSION}/result/cam_png
39 |
40 | # 3. evaluate CAM
41 | GT_ROOT=${DATASET_ROOT}/SegmentationClassAug/
42 |
43 | CUDA_VISIBLE_DEVICES=${GPU} python3 evaluate_png.py \
44 |     --datalist data/voc12/${DATA}.txt \
45 |     --gt_dir ${GT_ROOT} \
46 |     --save_path train_log/${SESSION}/result/${DATA}.txt \
47 |     --pred_dir train_log/${SESSION}/result/cam_png
--------------------------------------------------------------------------------
/script/script_contrast.sh:
--------------------------------------------------------------------------------
1 | # NEED TO SET
2 | DATASET_ROOT=./VOC2012
3 | WEIGHT_ROOT=./weights
4 | SESSION=resnet38_contrast
5 |
6 | GPU=0,1,2
7 |
8 |
9 |
10 | BASE_WEIGHT=${WEIGHT_ROOT}/ilsvrc-cls_rna-a1_cls1000_ep-0001.params
11 |
12 |
13 | # 1. train classification network with contrastive learning
14 | CUDA_VISIBLE_DEVICES=${GPU} python contrast_train.py \
15 |     --voc12_root ${DATASET_ROOT} \
16 |     --weights ${BASE_WEIGHT} \
17 |     --session_name ${SESSION} \
18 |     --batch_size 9
19 |
20 |
21 | # 2. inference CAM
22 | DATA=trainaug # train / trainaug / val
23 | TRAINED_WEIGHT=train_log/${SESSION}/checkpoint_contrast.pth
24 | CAM_NPY_DIR=store/cam_npy/${DATA}
25 | CAM_PNG_DIR=store/cam_png/${DATA}
26 | CRF_PNG_DIR=store/crf_png/${DATA}
27 |
28 | CUDA_VISIBLE_DEVICES=${GPU} python contrast_infer.py \
29 |     --weights ${TRAINED_WEIGHT} \
30 |     --infer_list voc12/${DATA}.txt \
31 |     --out_cam ${CAM_NPY_DIR} \
32 |     --out_cam_pred ${CAM_PNG_DIR} \
33 |     --out_crf ${CRF_PNG_DIR}
34 |
35 |
36 | # 3.
evaluate CAM 37 | DATA=train # train / val 38 | LIST=VOC2012/ImageSets/Segmentation/${DATA}.txt 39 | RESULT_DIR=${CAM_PNG_DIR} 40 | COMMENT=YOURCOMMENT 41 | GT_ROOT=${DATASET_ROOT}/SegmentationClass/ 42 | 43 | 44 | python eval.py \ 45 | --list ${LIST} \ 46 | --predict_dir ${RESULT_DIR} \ 47 | --gt_dir ${GT_ROOT} \ 48 | --comment ${COMMENT} \ 49 | --type png \ 50 | --curve True 51 | -------------------------------------------------------------------------------- /segmentation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usr922/wseg/e96b961038e4171c5a49a0378111b47374dc5219/segmentation/.DS_Store -------------------------------------------------------------------------------- /segmentation/experiment/EPS_deeplabv1_resnet101/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /segmentation/experiment/EPS_deeplabv1_resnet101/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usr922/wseg/e96b961038e4171c5a49a0378111b47374dc5219/segmentation/experiment/EPS_deeplabv1_resnet101/__init__.py -------------------------------------------------------------------------------- /segmentation/experiment/EPS_deeplabv1_resnet101/config.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | import torch 5 | import argparse 6 | import os 7 | import sys 8 | import cv2 9 | import time 10 | 11 | config_dict = { 12 | 'EXP_NAME': 'EPS_deeplabv1_resnet101', 13 | 'GPUS': 1, 14 | 15 | 'DATA_NAME': 'VOCDataset', 16 | 'DATA_YEAR': 2012, 17 | 'DATA_AUG': True, 18 | 'DATA_WORKERS': 4, 19 | 'DATA_MEAN': [0.485, 0.456, 0.406], 20 | 'DATA_STD': [0.229, 0.224, 0.225], 21 | 'DATA_RANDOMCROP': 448, 22 | 'DATA_RANDOMSCALE': [0.5, 1.5], 23 | 'DATA_RANDOM_H': 10, 24 | 'DATA_RANDOM_S': 10, 25 | 'DATA_RANDOM_V': 10, 26 | 'DATA_RANDOMFLIP': 0.5, 27 | 'DATA_PSEUDO_GT': 'your_pseudo_label_dir', 28 | 29 | 'MODEL_NAME': 'deeplabv1', 30 | 'MODEL_BACKBONE': 'resnet101', 31 | 'MODEL_BACKBONE_PRETRAIN': True, 32 | 'MODEL_NUM_CLASSES': 21, 33 | 'MODEL_FREEZEBN': False, 34 | 35 | # 'MODEL_BACKBONE_DILATED': True, 36 | # 'MODEL_BACKBONE_MULTIGRID': False, 37 | # 'MODEL_BACKBONE_DEEPBASE': True, 38 | 39 | 'TRAIN_LR': 0.001, 40 | 'TRAIN_MOMENTUM': 0.9, 41 | 'TRAIN_WEIGHT_DECAY': 0.0005, 42 | 'TRAIN_BN_MOM': 0.0003, 43 | 'TRAIN_POWER': 0.9, 44 | 'TRAIN_BATCHES': 10, 45 | 'TRAIN_SHUFFLE': True, 46 | 'TRAIN_MINEPOCH': 0, 47 | 'TRAIN_ITERATION': 20000, 48 | 'TRAIN_TBLOG': True, 49 | 50 | 'TEST_MULTISCALE': [0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 51 | 'TEST_FLIP': True, 52 | 'TEST_CRF': True, 53 | 'TEST_BATCHES': 1, 54 | } 55 | 56 | config_dict['ROOT_DIR'] = os.path.abspath(os.path.join(os.path.dirname("__file__"),'..','..')) 57 | config_dict['MODEL_SAVE_DIR'] = os.path.join(config_dict['ROOT_DIR'],'model',config_dict['EXP_NAME']) 58 | config_dict['TRAIN_CKPT'] = None 59 | config_dict['LOG_DIR'] = os.path.join(config_dict['ROOT_DIR'],'log',config_dict['EXP_NAME']) 60 | 61 | # for test, must be updated 62 | config_dict['TEST_CKPT'] = os.path.join(config_dict['ROOT_DIR'], 'your_ckpt.pth') 63 | 64 | sys.path.insert(0, os.path.join(config_dict['ROOT_DIR'], 'lib')) 65 | 
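# A minimal sketch of how this dict is consumed (assumed from its call sites in
# test.py/train.py below, which wrap it via lib/utils/configuration.py; the
# attribute-style access is inferred from usages like cfg.DATA_WORKERS there):
#
#     from utils.configuration import Configuration
#     cfg = Configuration(config_dict)         # as in train.py
#     cfg = Configuration(config_dict, False)  # as in test.py
#     print(cfg.MODEL_NAME, cfg.TRAIN_LR)      # keys exposed as attributes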
-------------------------------------------------------------------------------- /segmentation/experiment/EPS_deeplabv1_resnet101/test.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | import torch 6 | import numpy as np 7 | import random 8 | torch.manual_seed(1) # cpu 9 | torch.cuda.manual_seed(1) #gpu 10 | np.random.seed(1) #numpy 11 | random.seed(1) #random and transforms 12 | torch.backends.cudnn.deterministic=True # cudnn 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torchvision 16 | import cv2 17 | import time 18 | 19 | from config import config_dict 20 | from datasets.generateData import generate_dataset 21 | from net.generateNet import generate_net 22 | import torch.optim as optim 23 | from net.sync_batchnorm.replicate import patch_replication_callback 24 | from torch.utils.data import DataLoader 25 | from utils.configuration import Configuration 26 | from utils.finalprocess import writelog 27 | from utils.imutils import img_denorm 28 | from utils.DenseCRF import dense_crf, dense_crf_from_deeplabv2 29 | from utils.test_utils import single_gpu_test 30 | from utils.imutils import onehot 31 | 32 | cfg = Configuration(config_dict, False) 33 | 34 | def ClassLogSoftMax(f, category): 35 | exp = torch.exp(f) 36 | exp_norm = exp/torch.sum(exp*category, dim=1, keepdim=True) 37 | softmax = exp_norm*category 38 | logsoftmax = torch.log(exp_norm)*category 39 | return softmax, logsoftmax 40 | 41 | def test_net(): 42 | # period = 'val' 43 | period = 'test' 44 | dataset = generate_dataset(cfg, period=period, transform='none') 45 | def worker_init_fn(worker_id): 46 | np.random.seed(1 + worker_id) 47 | dataloader = DataLoader(dataset, 48 | batch_size=1, 49 | shuffle=False, 50 | num_workers=cfg.DATA_WORKERS, 51 | worker_init_fn = worker_init_fn) 52 | 53 | net = generate_net(cfg, batchnorm=nn.BatchNorm2d) 54 | # dilated=cfg.MODEL_BACKBONE_DILATED, 55 | # multi_grid=cfg.MODEL_BACKBONE_MULTIGRID, 56 | # deep_base=cfg.MODEL_BACKBONE_DEEPBASE) 57 | 58 | print('net initialize') 59 | 60 | if cfg.TEST_CKPT is None: 61 | raise ValueError('test.py: cfg.MODEL_CKPT can not be empty in test period') 62 | print('start loading model %s'%cfg.TEST_CKPT) 63 | model_dict = torch.load(cfg.TEST_CKPT) 64 | net.load_state_dict(model_dict, strict=False) 65 | 66 | print('Use %d GPU'%cfg.GPUS) 67 | assert torch.cuda.device_count() == cfg.GPUS 68 | device = torch.device('cuda') 69 | net.to(device) 70 | net.eval() 71 | 72 | def prepare_func(sample): 73 | image_msf = [] 74 | for rate in cfg.TEST_MULTISCALE: 75 | inputs_batched = sample['image_%f'%rate] 76 | image_msf.append(inputs_batched) 77 | if cfg.TEST_FLIP: 78 | image_msf.append(torch.flip(inputs_batched,[3])) 79 | return image_msf 80 | 81 | def inference_func(model, img): 82 | seg = model(img) 83 | return seg 84 | 85 | def collect_func(result_list, sample): 86 | [batch, channel, height, width] = sample['image'].size() 87 | for i in range(len(result_list)): 88 | result_seg = F.interpolate(result_list[i], (height, width), mode='bilinear', align_corners=True) 89 | if cfg.TEST_FLIP and i % 2 == 1: 90 | result_seg = torch.flip(result_seg, [3]) 91 | result_list[i] = result_seg 92 | prob_seg = torch.cat(result_list, dim=0) 93 | prob_seg = F.softmax(torch.mean(prob_seg, dim=0, keepdim=True),dim=1)[0] 94 | 95 | 96 | if cfg.TEST_CRF: 97 | prob = prob_seg.cpu().numpy() 98 | img_batched = 
img_denorm(sample['image'][0].cpu().numpy()).astype(np.uint8)
99 |             # TODO: change the CRF implementation
100 |             # prob = dense_crf(prob, img_batched, n_classes=cfg.MODEL_NUM_CLASSES, n_iters=1)
101 |             prob = dense_crf_from_deeplabv2(prob, img_batched)
102 |             prob_seg = torch.from_numpy(prob.astype(np.float32))
103 |
104 |         result = torch.argmax(prob_seg, dim=0, keepdim=False).cpu().numpy()
105 |         return result
106 |
107 |     def save_step_func(result_sample):
108 |         dataset.save_result([result_sample], cfg.MODEL_NAME)
109 |
110 |     result_list = single_gpu_test(net, dataloader, prepare_func=prepare_func, inference_func=inference_func, collect_func=collect_func, save_step_func=save_step_func)
111 |     resultlog = dataset.do_python_eval(cfg.MODEL_NAME)
112 |     print('Test finished')
113 |     writelog(cfg, period, metric=resultlog)
114 |
115 | if __name__ == '__main__':
116 |     test_net()
117 |
118 |
119 |
--------------------------------------------------------------------------------
/segmentation/experiment/EPS_deeplabv1_resnet101/train.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 |
5 | import torch
6 | import numpy as np
7 | import random
8 | torch.manual_seed(1)  # cpu
9 | torch.cuda.manual_seed_all(1)  # gpu
10 | np.random.seed(1)  # numpy
11 | random.seed(1)  # random and transforms
12 | torch.backends.cudnn.deterministic = True  # cudnn
13 | import torchvision
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torchvision.transforms as transforms
17 | import os
18 | import sys
19 | import time
20 |
21 | from config import config_dict
22 | from datasets.generateData import generate_dataset
23 | from net.generateNet import generate_net
24 | import torch.optim as optim
25 | from PIL import Image
26 | from tensorboardX import SummaryWriter
27 | from torch.utils.data import DataLoader
28 | from net.sync_batchnorm.replicate import patch_replication_callback
29 | from utils.configuration import Configuration
30 | from utils.finalprocess import writelog
31 | from utils.imutils import img_denorm
32 | from net.sync_batchnorm import SynchronizedBatchNorm2d
33 | from utils.visualization import generate_vis, max_norm
34 | from tqdm import tqdm
35 |
36 | cfg = Configuration(config_dict)
37 |
38 | def train_net():
39 |     period = 'train'
40 |     transform = 'weak'
41 |     dataset = generate_dataset(cfg, period=period, transform=transform)
42 |     def worker_init_fn(worker_id):
43 |         np.random.seed(1 + worker_id)
44 |     dataloader = DataLoader(dataset,
45 |                             batch_size=cfg.TRAIN_BATCHES,
46 |                             shuffle=cfg.TRAIN_SHUFFLE,
47 |                             num_workers=cfg.DATA_WORKERS,
48 |                             pin_memory=True,
49 |                             drop_last=True,
50 |                             worker_init_fn=worker_init_fn)
51 |
52 |     net = generate_net(cfg, batchnorm=nn.BatchNorm2d)
53 |     if cfg.TRAIN_CKPT:
54 |         net.load_state_dict(torch.load(cfg.TRAIN_CKPT), strict=True)
55 |         print('load pretrained model')
56 |     if cfg.TRAIN_TBLOG:
57 |         from tensorboardX import SummaryWriter
58 |         # Set the Tensorboard logger
59 |         tblogger = SummaryWriter(cfg.LOG_DIR)
60 |
61 |     print('Use %d GPU' % cfg.GPUS)
62 |     device = torch.device(0)
63 |     if cfg.GPUS > 1:
64 |         net = nn.DataParallel(net)
65 |         patch_replication_callback(net)
66 |         parameter_source = net.module
67 |     else:
68 |         parameter_source = net
69 |     net.to(device)
70 |     parameter_groups = parameter_source.get_parameter_groups()
71 |     optimizer = optim.SGD(
72 |         params = [
73 |             {'params': parameter_groups[0], 'lr': cfg.TRAIN_LR, 'weight_decay':
cfg.TRAIN_WEIGHT_DECAY}, 74 | {'params': parameter_groups[1], 'lr': 2*cfg.TRAIN_LR, 'weight_decay': 0}, 75 | {'params': parameter_groups[2], 'lr': 10*cfg.TRAIN_LR, 'weight_decay': cfg.TRAIN_WEIGHT_DECAY}, 76 | {'params': parameter_groups[3], 'lr': 20*cfg.TRAIN_LR, 'weight_decay': 0}, 77 | ], 78 | momentum=cfg.TRAIN_MOMENTUM, 79 | weight_decay=cfg.TRAIN_WEIGHT_DECAY 80 | ) 81 | itr = cfg.TRAIN_MINEPOCH * len(dataset)//(cfg.TRAIN_BATCHES) 82 | max_itr = cfg.TRAIN_ITERATION 83 | max_epoch = max_itr*(cfg.TRAIN_BATCHES)//len(dataset)+1 84 | tblogger = SummaryWriter(cfg.LOG_DIR) 85 | criterion = nn.CrossEntropyLoss(ignore_index=255) 86 | with tqdm(total=max_itr) as pbar: 87 | for epoch in range(cfg.TRAIN_MINEPOCH, max_epoch): 88 | for i_batch, sample in enumerate(dataloader): 89 | 90 | now_lr = adjust_lr(optimizer, itr, max_itr, cfg.TRAIN_LR, cfg.TRAIN_POWER) 91 | optimizer.zero_grad() 92 | 93 | inputs, seg_label = sample['image'], sample['segmentation'] 94 | n,c,h,w = inputs.size() 95 | 96 | pred1 = net(inputs.to(0)) 97 | loss = criterion(pred1, seg_label.to(0)) 98 | loss.backward() 99 | optimizer.step() 100 | 101 | pbar.set_description("loss=%g " % (loss.item())) 102 | pbar.update(1) 103 | time.sleep(0.001) 104 | #print('epoch:%d/%d\tbatch:%d/%d\titr:%d\tlr:%g\tloss:%g' % 105 | # (epoch, max_epoch, i_batch, len(dataset)//(cfg.TRAIN_BATCHES), 106 | # itr+1, now_lr, loss.item())) 107 | if cfg.TRAIN_TBLOG and itr%100 == 0: 108 | inputs1 = img_denorm(inputs[-1].cpu().numpy()).astype(np.uint8) 109 | label1 = sample['segmentation'][-1].cpu().numpy() 110 | label_color1 = dataset.label2colormap(label1).transpose((2,0,1)) 111 | 112 | n,c,h,w = inputs.size() 113 | seg_vis1 = torch.argmax(pred1[-1], dim=0).detach().cpu().numpy() 114 | seg_color1 = dataset.label2colormap(seg_vis1).transpose((2,0,1)) 115 | 116 | tblogger.add_scalar('loss', loss.item(), itr) 117 | tblogger.add_scalar('lr', now_lr, itr) 118 | tblogger.add_image('Input', inputs1, itr) 119 | tblogger.add_image('Label', label_color1, itr) 120 | tblogger.add_image('SEG1', seg_color1, itr) 121 | itr += 1 122 | if itr>=max_itr: 123 | break 124 | save_path = os.path.join(cfg.MODEL_SAVE_DIR,'%s_%s_%s_epoch%d.pth'%(cfg.MODEL_NAME,cfg.MODEL_BACKBONE,cfg.DATA_NAME,epoch)) 125 | torch.save(parameter_source.state_dict(), save_path) 126 | print('%s has been saved'%save_path) 127 | remove_path = os.path.join(cfg.MODEL_SAVE_DIR,'%s_%s_%s_epoch%d.pth'%(cfg.MODEL_NAME,cfg.MODEL_BACKBONE,cfg.DATA_NAME,epoch-1)) 128 | if os.path.exists(remove_path): 129 | os.remove(remove_path) 130 | 131 | save_path = os.path.join(cfg.MODEL_SAVE_DIR,'%s_%s_%s_itr%d_all.pth'%(cfg.MODEL_NAME,cfg.MODEL_BACKBONE,cfg.DATA_NAME,cfg.TRAIN_ITERATION)) 132 | torch.save(parameter_source.state_dict(),save_path) 133 | if cfg.TRAIN_TBLOG: 134 | tblogger.close() 135 | print('%s has been saved'%save_path) 136 | writelog(cfg, period) 137 | 138 | def adjust_lr(optimizer, itr, max_itr, lr_init, power): 139 | now_lr = lr_init * (1 - itr/(max_itr+1)) ** power 140 | optimizer.param_groups[0]['lr'] = now_lr 141 | optimizer.param_groups[1]['lr'] = 2*now_lr 142 | optimizer.param_groups[2]['lr'] = 10*now_lr 143 | optimizer.param_groups[3]['lr'] = 20*now_lr 144 | return now_lr 145 | 146 | def get_params(model, key): 147 | for m in model.named_modules(): 148 | if key == 'backbone': 149 | if ('backbone' in m[0]) and isinstance(m[1], (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d, nn.InstanceNorm2d)): 150 | for p in m[1].parameters(): 151 | yield p 152 | elif key == 'cls': 153 | if ('cls_conv' in 
m[0]) and isinstance(m[1], (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d, nn.InstanceNorm2d)): 154 | for p in m[1].parameters(): 155 | yield p 156 | elif key == 'others': 157 | if ('backbone' not in m[0] and 'cls_conv' not in m[0]) and isinstance(m[1], (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d, nn.InstanceNorm2d)): 158 | for p in m[1].parameters(): 159 | yield p 160 | if __name__ == '__main__': 161 | train_net() 162 | 163 | 164 | -------------------------------------------------------------------------------- /segmentation/experiment/EPS_deeplabv2_resnet101/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /segmentation/experiment/EPS_deeplabv2_resnet101/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usr922/wseg/e96b961038e4171c5a49a0378111b47374dc5219/segmentation/experiment/EPS_deeplabv2_resnet101/__init__.py -------------------------------------------------------------------------------- /segmentation/experiment/EPS_deeplabv2_resnet101/config.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | import torch 5 | import argparse 6 | import os 7 | import sys 8 | import cv2 9 | import time 10 | 11 | config_dict = { 12 | 'EXP_NAME': 'EPS_deeplabv2_resnet101', 13 | 'GPUS': 1, 14 | 15 | 'DATA_NAME': 'VOCDataset', 16 | 'DATA_YEAR': 2012, 17 | 'DATA_AUG': True, 18 | 'DATA_WORKERS': 4, 19 | 'DATA_MEAN': [0.485, 0.456, 0.406], 20 | 'DATA_STD': [0.229, 0.224, 0.225], 21 | 'DATA_RANDOMCROP': 448, 22 | 'DATA_RANDOMSCALE': [0.5, 1.5], 23 | 'DATA_RANDOM_H': 10, 24 | 'DATA_RANDOM_S': 10, 25 | 'DATA_RANDOM_V': 10, 26 | 'DATA_RANDOMFLIP': 0.5, 27 | 'DATA_PSEUDO_GT': 'your_pseudo_label_dir', 28 | 29 | 'MODEL_NAME': 'deeplabv2', 30 | 'MODEL_BACKBONE': 'resnet101', 31 | 'MODEL_BACKBONE_PRETRAIN': True, 32 | 'MODEL_NUM_CLASSES': 21, 33 | 'MODEL_FREEZEBN': False, 34 | 35 | # 'MODEL_BACKBONE_DILATED': True, 36 | # 'MODEL_BACKBONE_MULTIGRID': False, 37 | # 'MODEL_BACKBONE_DEEPBASE': True, 38 | 'MODEL_BACKBONE_DILATED': True, 39 | 'MODEL_BACKBONE_MULTIGRID': False, 40 | 'MODEL_BACKBONE_DEEPBASE': True, 41 | 'MODEL_SHORTCUT_DIM': 48, 42 | 'MODEL_OUTPUT_STRIDE': 8, 43 | 'MODEL_ASPP_OUTDIM': 256, 44 | 'MODEL_ASPP_HASGLOBAL': True, 45 | 46 | 'TRAIN_LR': 0.001, 47 | 'TRAIN_MOMENTUM': 0.9, 48 | 'TRAIN_WEIGHT_DECAY': 0.0005, 49 | 'TRAIN_BN_MOM': 0.0003, 50 | 'TRAIN_POWER': 0.9, 51 | 'TRAIN_BATCHES': 10, 52 | 'TRAIN_SHUFFLE': True, 53 | 'TRAIN_MINEPOCH': 0, 54 | 'TRAIN_ITERATION': 20000, 55 | 'TRAIN_TBLOG': True, 56 | 57 | 'TEST_MULTISCALE': [0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 58 | 'TEST_FLIP': True, 59 | 'TEST_CRF': True, 60 | 'TEST_BATCHES': 1, 61 | } 62 | 63 | config_dict['ROOT_DIR'] = os.path.abspath(os.path.join(os.path.dirname("__file__"),'..','..')) 64 | config_dict['MODEL_SAVE_DIR'] = os.path.join(config_dict['ROOT_DIR'],'model',config_dict['EXP_NAME']) 65 | config_dict['TRAIN_CKPT'] = None 66 | config_dict['LOG_DIR'] = os.path.join(config_dict['ROOT_DIR'],'log',config_dict['EXP_NAME']) 67 | 68 | # for test, must be updated 69 | config_dict['TEST_CKPT'] = os.path.join(config_dict['ROOT_DIR'], 'your_ckpt.pth') 70 | 71 | sys.path.insert(0, os.path.join(config_dict['ROOT_DIR'], 'lib')) 72 | 
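A note on the schedule used by the train scripts above: adjust_lr implements the usual "poly" decay, now_lr = lr_init * (1 - itr/(max_itr+1)) ** power, and then scales the four parameter groups returned by get_parameter_groups by 1x/2x/10x/20x (presumably pretrained weights, pretrained biases, newly initialized weights, newly initialized biases, following the common DeepLab convention of doubled learning rate and no weight decay for biases). Below is a minimal standalone sketch of the decay curve, not part of the repository; the numbers assume TRAIN_LR=0.001, TRAIN_POWER=0.9 and TRAIN_ITERATION=20000 as in the config above.

def poly_lr(itr, max_itr=20000, lr_init=0.001, power=0.9):
    # same formula as adjust_lr(), minus the optimizer bookkeeping
    return lr_init * (1 - itr / (max_itr + 1)) ** power

for itr in (0, 5000, 10000, 15000, 19999):
    print(itr, poly_lr(itr))
# roughly: 0.001, 0.00077, 0.00054, 0.00029, 2.5e-07

So the base learning rate falls smoothly from 0.001 to nearly zero over the 20k iterations, which is why the scripts recompute it at every batch rather than per epoch.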
-------------------------------------------------------------------------------- /segmentation/experiment/EPS_deeplabv2_resnet101/test.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | import torch 6 | import numpy as np 7 | import random 8 | torch.manual_seed(1) # cpu 9 | torch.cuda.manual_seed(1) #gpu 10 | np.random.seed(1) #numpy 11 | random.seed(1) #random and transforms 12 | torch.backends.cudnn.deterministic=True # cudnn 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torchvision 16 | import cv2 17 | import time 18 | 19 | from config import config_dict 20 | from datasets.generateData import generate_dataset 21 | from net.generateNet import generate_net 22 | import torch.optim as optim 23 | from net.sync_batchnorm.replicate import patch_replication_callback 24 | from torch.utils.data import DataLoader 25 | from utils.configuration import Configuration 26 | from utils.finalprocess import writelog 27 | from utils.imutils import img_denorm 28 | from utils.DenseCRF import dense_crf, dense_crf_from_deeplabv2 29 | from utils.test_utils import single_gpu_test 30 | from utils.imutils import onehot 31 | 32 | cfg = Configuration(config_dict, False) 33 | 34 | def ClassLogSoftMax(f, category): 35 | exp = torch.exp(f) 36 | exp_norm = exp/torch.sum(exp*category, dim=1, keepdim=True) 37 | softmax = exp_norm*category 38 | logsoftmax = torch.log(exp_norm)*category 39 | return softmax, logsoftmax 40 | 41 | def test_net(): 42 | # period = 'val' 43 | period = 'test' 44 | dataset = generate_dataset(cfg, period=period, transform='none') 45 | def worker_init_fn(worker_id): 46 | np.random.seed(1 + worker_id) 47 | dataloader = DataLoader(dataset, 48 | batch_size=1, 49 | shuffle=False, 50 | num_workers=cfg.DATA_WORKERS, 51 | worker_init_fn = worker_init_fn) 52 | 53 | net = generate_net(cfg, batchnorm=nn.BatchNorm2d, dilated=cfg.MODEL_BACKBONE_DILATED, 54 | multi_grid=cfg.MODEL_BACKBONE_MULTIGRID, 55 | deep_base=cfg.MODEL_BACKBONE_DEEPBASE) 56 | 57 | print('net initialize') 58 | 59 | if cfg.TEST_CKPT is None: 60 | raise ValueError('test.py: cfg.MODEL_CKPT can not be empty in test period') 61 | print('start loading model %s'%cfg.TEST_CKPT) 62 | model_dict = torch.load(cfg.TEST_CKPT) 63 | net.load_state_dict(model_dict, strict=False) 64 | 65 | print('Use %d GPU'%cfg.GPUS) 66 | assert torch.cuda.device_count() == cfg.GPUS 67 | device = torch.device('cuda') 68 | net.to(device) 69 | net.eval() 70 | 71 | def prepare_func(sample): 72 | image_msf = [] 73 | for rate in cfg.TEST_MULTISCALE: 74 | inputs_batched = sample['image_%f'%rate] 75 | image_msf.append(inputs_batched) 76 | if cfg.TEST_FLIP: 77 | image_msf.append(torch.flip(inputs_batched,[3])) 78 | return image_msf 79 | 80 | def inference_func(model, img): 81 | seg = model(img) 82 | return seg 83 | 84 | def collect_func(result_list, sample): 85 | [batch, channel, height, width] = sample['image'].size() 86 | for i in range(len(result_list)): 87 | result_seg = F.interpolate(result_list[i], (height, width), mode='bilinear', align_corners=True) 88 | if cfg.TEST_FLIP and i % 2 == 1: 89 | result_seg = torch.flip(result_seg, [3]) 90 | result_list[i] = result_seg 91 | prob_seg = torch.cat(result_list, dim=0) 92 | prob_seg = F.softmax(torch.mean(prob_seg, dim=0, keepdim=True),dim=1)[0] 93 | 94 | 95 | if cfg.TEST_CRF: 96 | prob = prob_seg.cpu().numpy() 97 | img_batched = 
img_denorm(sample['image'][0].cpu().numpy()).astype(np.uint8) 98 | # TODO: update the CRF 99 | # prob = dense_crf(prob, img_batched, n_classes=cfg.MODEL_NUM_CLASSES, n_iters=1) 100 | prob = dense_crf_from_deeplabv2(prob, img_batched) 101 | prob_seg = torch.from_numpy(prob.astype(np.float32)) 102 | 103 | result = torch.argmax(prob_seg, dim=0, keepdim=False).cpu().numpy() 104 | return result 105 | 106 | def save_step_func(result_sample): 107 | dataset.save_result([result_sample], cfg.MODEL_NAME) 108 | 109 | result_list = single_gpu_test(net, dataloader, prepare_func=prepare_func, inference_func=inference_func, collect_func=collect_func, save_step_func=save_step_func) 110 | resultlog = dataset.do_python_eval(cfg.MODEL_NAME) 111 | print('Test finished') 112 | writelog(cfg, period, metric=resultlog) 113 | 114 | if __name__ == '__main__': 115 | test_net() 116 | 117 | 118 | -------------------------------------------------------------------------------- /segmentation/experiment/EPS_deeplabv2_resnet101/train.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | import torch 6 | import numpy as np 7 | import random 8 | torch.manual_seed(1) # cpu 9 | torch.cuda.manual_seed_all(1) #gpu 10 | np.random.seed(1) #numpy 11 | random.seed(1) #random and transforms 12 | torch.backends.cudnn.deterministic=True # cudnn 13 | import torchvision 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torchvision.transforms as transforms 17 | import os 18 | import sys 19 | import time 20 | 21 | from config import config_dict 22 | from datasets.generateData import generate_dataset 23 | from net.generateNet import generate_net 24 | import torch.optim as optim 25 | from PIL import Image 26 | from tensorboardX import SummaryWriter 27 | from torch.utils.data import DataLoader 28 | from net.sync_batchnorm.replicate import patch_replication_callback 29 | from utils.configuration import Configuration 30 | from utils.finalprocess import writelog 31 | from utils.imutils import img_denorm 32 | from net.sync_batchnorm import SynchronizedBatchNorm2d 33 | from utils.visualization import generate_vis, max_norm 34 | from tqdm import tqdm 35 | 36 | cfg = Configuration(config_dict) 37 | 38 | def train_net(): 39 | period = 'train' 40 | transform = 'weak' 41 | dataset = generate_dataset(cfg, period=period, transform=transform) 42 | def worker_init_fn(worker_id): 43 | np.random.seed(1 + worker_id) 44 | dataloader = DataLoader(dataset, 45 | batch_size=cfg.TRAIN_BATCHES, 46 | shuffle=cfg.TRAIN_SHUFFLE, 47 | num_workers=cfg.DATA_WORKERS, 48 | pin_memory=True, 49 | drop_last=True, 50 | worker_init_fn=worker_init_fn) 51 | 52 | net = generate_net(cfg, batchnorm=nn.BatchNorm2d) 53 | if cfg.TRAIN_CKPT: 54 | net.load_state_dict(torch.load(cfg.TRAIN_CKPT),strict=True) 55 | print('load pretrained model') 56 | if cfg.TRAIN_TBLOG: 57 | from tensorboardX import SummaryWriter 58 | # Set the Tensorboard logger 59 | tblogger = SummaryWriter(cfg.LOG_DIR) 60 | 61 | print('Use %d GPU'%cfg.GPUS) 62 | device = torch.device(0) 63 | if cfg.GPUS > 1: 64 | net = nn.DataParallel(net) 65 | patch_replication_callback(net) 66 | parameter_source = net.module 67 | else: 68 | parameter_source = net 69 | net.to(device) 70 | parameter_groups = parameter_source.get_parameter_groups() 71 | optimizer = optim.SGD( 72 | params = [ 73 | {'params': parameter_groups[0], 'lr': cfg.TRAIN_LR, 'weight_decay': 
cfg.TRAIN_WEIGHT_DECAY}, 74 | {'params': parameter_groups[1], 'lr': 2*cfg.TRAIN_LR, 'weight_decay': 0}, 75 | {'params': parameter_groups[2], 'lr': 10*cfg.TRAIN_LR, 'weight_decay': cfg.TRAIN_WEIGHT_DECAY}, 76 | {'params': parameter_groups[3], 'lr': 20*cfg.TRAIN_LR, 'weight_decay': 0}, 77 | ], 78 | momentum=cfg.TRAIN_MOMENTUM, 79 | weight_decay=cfg.TRAIN_WEIGHT_DECAY 80 | ) 81 | itr = cfg.TRAIN_MINEPOCH * len(dataset)//(cfg.TRAIN_BATCHES) 82 | max_itr = cfg.TRAIN_ITERATION 83 | max_epoch = max_itr*(cfg.TRAIN_BATCHES)//len(dataset)+1 84 | tblogger = SummaryWriter(cfg.LOG_DIR) 85 | criterion = nn.CrossEntropyLoss(ignore_index=255) 86 | with tqdm(total=max_itr) as pbar: 87 | for epoch in range(cfg.TRAIN_MINEPOCH, max_epoch): 88 | for i_batch, sample in enumerate(dataloader): 89 | 90 | now_lr = adjust_lr(optimizer, itr, max_itr, cfg.TRAIN_LR, cfg.TRAIN_POWER) 91 | optimizer.zero_grad() 92 | 93 | inputs, seg_label = sample['image'], sample['segmentation'] 94 | n,c,h,w = inputs.size() 95 | 96 | pred1 = net(inputs.to(0)) 97 | loss = criterion(pred1, seg_label.to(0)) 98 | loss.backward() 99 | optimizer.step() 100 | 101 | pbar.set_description("loss=%g " % (loss.item())) 102 | pbar.update(1) 103 | time.sleep(0.001) 104 | #print('epoch:%d/%d\tbatch:%d/%d\titr:%d\tlr:%g\tloss:%g' % 105 | # (epoch, max_epoch, i_batch, len(dataset)//(cfg.TRAIN_BATCHES), 106 | # itr+1, now_lr, loss.item())) 107 | if cfg.TRAIN_TBLOG and itr%100 == 0: 108 | inputs1 = img_denorm(inputs[-1].cpu().numpy()).astype(np.uint8) 109 | label1 = sample['segmentation'][-1].cpu().numpy() 110 | label_color1 = dataset.label2colormap(label1).transpose((2,0,1)) 111 | 112 | n,c,h,w = inputs.size() 113 | seg_vis1 = torch.argmax(pred1[-1], dim=0).detach().cpu().numpy() 114 | seg_color1 = dataset.label2colormap(seg_vis1).transpose((2,0,1)) 115 | 116 | tblogger.add_scalar('loss', loss.item(), itr) 117 | tblogger.add_scalar('lr', now_lr, itr) 118 | tblogger.add_image('Input', inputs1, itr) 119 | tblogger.add_image('Label', label_color1, itr) 120 | tblogger.add_image('SEG1', seg_color1, itr) 121 | itr += 1 122 | if itr>=max_itr: 123 | break 124 | save_path = os.path.join(cfg.MODEL_SAVE_DIR,'%s_%s_%s_epoch%d.pth'%(cfg.MODEL_NAME,cfg.MODEL_BACKBONE,cfg.DATA_NAME,epoch)) 125 | torch.save(parameter_source.state_dict(), save_path) 126 | print('%s has been saved'%save_path) 127 | remove_path = os.path.join(cfg.MODEL_SAVE_DIR,'%s_%s_%s_epoch%d.pth'%(cfg.MODEL_NAME,cfg.MODEL_BACKBONE,cfg.DATA_NAME,epoch-1)) 128 | if os.path.exists(remove_path): 129 | os.remove(remove_path) 130 | 131 | save_path = os.path.join(cfg.MODEL_SAVE_DIR,'%s_%s_%s_itr%d_all.pth'%(cfg.MODEL_NAME,cfg.MODEL_BACKBONE,cfg.DATA_NAME,cfg.TRAIN_ITERATION)) 132 | torch.save(parameter_source.state_dict(),save_path) 133 | if cfg.TRAIN_TBLOG: 134 | tblogger.close() 135 | print('%s has been saved'%save_path) 136 | writelog(cfg, period) 137 | 138 | def adjust_lr(optimizer, itr, max_itr, lr_init, power): 139 | now_lr = lr_init * (1 - itr/(max_itr+1)) ** power 140 | optimizer.param_groups[0]['lr'] = now_lr 141 | optimizer.param_groups[1]['lr'] = 2*now_lr 142 | optimizer.param_groups[2]['lr'] = 10*now_lr 143 | optimizer.param_groups[3]['lr'] = 20*now_lr 144 | return now_lr 145 | 146 | def get_params(model, key): 147 | for m in model.named_modules(): 148 | if key == 'backbone': 149 | if ('backbone' in m[0]) and isinstance(m[1], (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d, nn.InstanceNorm2d)): 150 | for p in m[1].parameters(): 151 | yield p 152 | elif key == 'cls': 153 | if ('cls_conv' in 
m[0]) and isinstance(m[1], (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d, nn.InstanceNorm2d)): 154 | for p in m[1].parameters(): 155 | yield p 156 | elif key == 'others': 157 | if ('backbone' not in m[0] and 'cls_conv' not in m[0]) and isinstance(m[1], (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d, nn.InstanceNorm2d)): 158 | for p in m[1].parameters(): 159 | yield p 160 | if __name__ == '__main__': 161 | train_net() 162 | 163 | 164 | -------------------------------------------------------------------------------- /segmentation/experiment/SEAM_deeplabv1_resnet38/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /segmentation/experiment/SEAM_deeplabv1_resnet38/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usr922/wseg/e96b961038e4171c5a49a0378111b47374dc5219/segmentation/experiment/SEAM_deeplabv1_resnet38/__init__.py -------------------------------------------------------------------------------- /segmentation/experiment/SEAM_deeplabv1_resnet38/config.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | import torch 5 | import argparse 6 | import os 7 | import sys 8 | import cv2 9 | import time 10 | 11 | config_dict = { 12 | 'EXP_NAME': 'SEAM_deeplabv1_resnet38', 13 | 'GPUS': 1, 14 | 15 | 'DATA_NAME': 'VOCDataset', 16 | 'DATA_YEAR': 2012, 17 | 'DATA_AUG': True, 18 | 'DATA_WORKERS': 4, 19 | 'DATA_MEAN': [0.485, 0.456, 0.406], 20 | 'DATA_STD': [0.229, 0.224, 0.225], 21 | 'DATA_RANDOMCROP': 448, 22 | 'DATA_RANDOMSCALE': [0.5, 1.5], 23 | 'DATA_RANDOM_H': 10, 24 | 'DATA_RANDOM_S': 10, 25 | 'DATA_RANDOM_V': 10, 26 | 'DATA_RANDOMFLIP': 0.5, 27 | 'DATA_PSEUDO_GT': 'your_pseudo_label_dir', 28 | 29 | 'MODEL_NAME': 'deeplabv1', 30 | 'MODEL_BACKBONE': 'resnet38', 31 | 'MODEL_BACKBONE_PRETRAIN': True, 32 | 'MODEL_NUM_CLASSES': 21, 33 | 'MODEL_FREEZEBN': False, 34 | #'MODEL_BACKBONE_DILATED': True, 35 | #'MODEL_BACKBONE_MULTIGRID': False, 36 | #'MODEL_BACKBONE_DEEPBASE': True, 37 | 38 | 'TRAIN_LR': 0.001, 39 | 'TRAIN_MOMENTUM': 0.9, 40 | 'TRAIN_WEIGHT_DECAY': 0.0005, 41 | 'TRAIN_BN_MOM': 0.0003, 42 | 'TRAIN_POWER': 0.9, 43 | 'TRAIN_BATCHES': 10, 44 | 'TRAIN_SHUFFLE': True, 45 | 'TRAIN_MINEPOCH': 0, 46 | 'TRAIN_ITERATION': 20000, 47 | 'TRAIN_TBLOG': True, 48 | 49 | 'TEST_MULTISCALE': [0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 50 | 'TEST_FLIP': True, 51 | 'TEST_CRF': True, 52 | 'TEST_BATCHES': 1, 53 | } 54 | 55 | config_dict['ROOT_DIR'] = os.path.abspath(os.path.join(os.path.dirname("__file__"),'..','..')) 56 | config_dict['MODEL_SAVE_DIR'] = os.path.join(config_dict['ROOT_DIR'],'model',config_dict['EXP_NAME']) 57 | config_dict['TRAIN_CKPT'] = None 58 | config_dict['LOG_DIR'] = os.path.join(config_dict['ROOT_DIR'],'log',config_dict['EXP_NAME']) 59 | # for test, must be updated 60 | config_dict['TEST_CKPT'] = os.path.join(config_dict['ROOT_DIR'], 'your_ckpt.pth') 61 | 62 | sys.path.insert(0, os.path.join(config_dict['ROOT_DIR'], 'lib')) 63 | -------------------------------------------------------------------------------- /segmentation/experiment/SEAM_deeplabv1_resnet38/test.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # 
---------------------------------------- 4 | 5 | import torch 6 | import numpy as np 7 | import random 8 | torch.manual_seed(1) # cpu 9 | torch.cuda.manual_seed(1) #gpu 10 | np.random.seed(1) #numpy 11 | random.seed(1) #random and transforms 12 | torch.backends.cudnn.deterministic=True # cudnn 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torchvision 16 | import cv2 17 | import time 18 | 19 | from config import config_dict 20 | from datasets.generateData import generate_dataset 21 | from net.generateNet import generate_net 22 | import torch.optim as optim 23 | from net.sync_batchnorm.replicate import patch_replication_callback 24 | from torch.utils.data import DataLoader 25 | from utils.configuration import Configuration 26 | from utils.finalprocess import writelog 27 | from utils.imutils import img_denorm 28 | from utils.DenseCRF import dense_crf, dense_crf_from_deeplabv2 29 | from utils.test_utils import single_gpu_test 30 | from utils.imutils import onehot 31 | 32 | cfg = Configuration(config_dict, False) 33 | 34 | def ClassLogSoftMax(f, category): 35 | exp = torch.exp(f) 36 | exp_norm = exp/torch.sum(exp*category, dim=1, keepdim=True) 37 | softmax = exp_norm*category 38 | logsoftmax = torch.log(exp_norm)*category 39 | return softmax, logsoftmax 40 | 41 | def test_net(): 42 | # period = 'val' 43 | period = 'val' 44 | dataset = generate_dataset(cfg, period=period, transform='none') 45 | def worker_init_fn(worker_id): 46 | np.random.seed(1 + worker_id) 47 | dataloader = DataLoader(dataset, 48 | batch_size=1, 49 | shuffle=False, 50 | num_workers=cfg.DATA_WORKERS, 51 | worker_init_fn = worker_init_fn) 52 | 53 | net = generate_net(cfg, batchnorm=nn.BatchNorm2d) 54 | # dilated=cfg.MODEL_BACKBONE_DILATED, 55 | # multi_grid=cfg.MODEL_BACKBONE_MULTIGRID, 56 | # deep_base=cfg.MODEL_BACKBONE_DEEPBASE) 57 | print('net initialize') 58 | 59 | if cfg.TEST_CKPT is None: 60 | raise ValueError('test.py: cfg.TEST_CKPT cannot be empty in the test period') 61 | print('start loading model %s'%cfg.TEST_CKPT) 62 | model_dict = torch.load(cfg.TEST_CKPT) 63 | net.load_state_dict(model_dict, strict=False) 64 | 65 | print('Use %d GPU'%cfg.GPUS) 66 | assert torch.cuda.device_count() == cfg.GPUS 67 | device = torch.device('cuda') 68 | net.to(device) 69 | net.eval() 70 | 71 | def prepare_func(sample): 72 | image_msf = [] 73 | for rate in cfg.TEST_MULTISCALE: 74 | inputs_batched = sample['image_%f'%rate] 75 | image_msf.append(inputs_batched) 76 | if cfg.TEST_FLIP: 77 | image_msf.append(torch.flip(inputs_batched,[3])) 78 | return image_msf 79 | 80 | def inference_func(model, img): 81 | seg = model(img) 82 | return seg 83 | 84 | def collect_func(result_list, sample): 85 | [batch, channel, height, width] = sample['image'].size() 86 | for i in range(len(result_list)): 87 | result_seg = F.interpolate(result_list[i], (height, width), mode='bilinear', align_corners=True) 88 | if cfg.TEST_FLIP and i % 2 == 1: 89 | result_seg = torch.flip(result_seg, [3]) 90 | result_list[i] = result_seg 91 | prob_seg = torch.cat(result_list, dim=0) 92 | prob_seg = F.softmax(torch.mean(prob_seg, dim=0, keepdim=True),dim=1)[0] 93 | 94 | 95 | if cfg.TEST_CRF: 96 | prob = prob_seg.cpu().numpy() 97 | img_batched = img_denorm(sample['image'][0].cpu().numpy()).astype(np.uint8) 98 | # TODO: update the CRF 99 | # prob = dense_crf(prob, img_batched, n_classes=cfg.MODEL_NUM_CLASSES, n_iters=1) 100 | prob = dense_crf_from_deeplabv2(prob, img_batched) 101 | prob_seg = torch.from_numpy(prob.astype(np.float32)) 102 | 103 | result = 
torch.argmax(prob_seg, dim=0, keepdim=False).cpu().numpy() 104 | return result 105 | 106 | def save_step_func(result_sample): 107 | dataset.save_result([result_sample], cfg.MODEL_NAME) 108 | 109 | result_list = single_gpu_test(net, dataloader, prepare_func=prepare_func, inference_func=inference_func, collect_func=collect_func, save_step_func=save_step_func) 110 | resultlog = dataset.do_python_eval(cfg.MODEL_NAME) 111 | print('Test finished') 112 | writelog(cfg, period, metric=resultlog) 113 | 114 | if __name__ == '__main__': 115 | test_net() 116 | 117 | 118 | -------------------------------------------------------------------------------- /segmentation/experiment/SEAM_deeplabv1_resnet38/train.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | import torch 6 | import numpy as np 7 | import random 8 | torch.manual_seed(1) # cpu 9 | torch.cuda.manual_seed_all(1) #gpu 10 | np.random.seed(1) #numpy 11 | random.seed(1) #random and transforms 12 | torch.backends.cudnn.deterministic=True # cudnn 13 | import torchvision 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torchvision.transforms as transforms 17 | import os 18 | import sys 19 | import time 20 | 21 | from config import config_dict 22 | from datasets.generateData import generate_dataset 23 | from net.generateNet import generate_net 24 | import torch.optim as optim 25 | from PIL import Image 26 | from tensorboardX import SummaryWriter 27 | from torch.utils.data import DataLoader 28 | from net.sync_batchnorm.replicate import patch_replication_callback 29 | from utils.configuration import Configuration 30 | from utils.finalprocess import writelog 31 | from utils.imutils import img_denorm 32 | from net.sync_batchnorm import SynchronizedBatchNorm2d 33 | from utils.visualization import generate_vis, max_norm 34 | from tqdm import tqdm 35 | 36 | cfg = Configuration(config_dict) 37 | 38 | def train_net(): 39 | period = 'train' 40 | transform = 'weak' 41 | dataset = generate_dataset(cfg, period=period, transform=transform) 42 | def worker_init_fn(worker_id): 43 | np.random.seed(1 + worker_id) 44 | dataloader = DataLoader(dataset, 45 | batch_size=cfg.TRAIN_BATCHES, 46 | shuffle=cfg.TRAIN_SHUFFLE, 47 | num_workers=cfg.DATA_WORKERS, 48 | pin_memory=True, 49 | drop_last=True, 50 | worker_init_fn=worker_init_fn) 51 | 52 | net = generate_net(cfg, batchnorm=nn.BatchNorm2d) 53 | if cfg.TRAIN_CKPT: 54 | net.load_state_dict(torch.load(cfg.TRAIN_CKPT),strict=True) 55 | print('load pretrained model') 56 | if cfg.TRAIN_TBLOG: 57 | from tensorboardX import SummaryWriter 58 | # Set the Tensorboard logger 59 | tblogger = SummaryWriter(cfg.LOG_DIR) 60 | 61 | print('Use %d GPU'%cfg.GPUS) 62 | device = torch.device(0) 63 | if cfg.GPUS > 1: 64 | net = nn.DataParallel(net) 65 | patch_replication_callback(net) 66 | parameter_source = net.module 67 | else: 68 | parameter_source = net 69 | net.to(device) 70 | parameter_groups = parameter_source.get_parameter_groups() 71 | optimizer = optim.SGD( 72 | params = [ 73 | {'params': parameter_groups[0], 'lr': cfg.TRAIN_LR, 'weight_decay': cfg.TRAIN_WEIGHT_DECAY}, 74 | {'params': parameter_groups[1], 'lr': 2*cfg.TRAIN_LR, 'weight_decay': 0}, 75 | {'params': parameter_groups[2], 'lr': 10*cfg.TRAIN_LR, 'weight_decay': cfg.TRAIN_WEIGHT_DECAY}, 76 | {'params': parameter_groups[3], 'lr': 20*cfg.TRAIN_LR, 'weight_decay': 0}, 77 | ], 78 | 
momentum=cfg.TRAIN_MOMENTUM, 79 | weight_decay=cfg.TRAIN_WEIGHT_DECAY 80 | ) 81 | itr = cfg.TRAIN_MINEPOCH * len(dataset)//(cfg.TRAIN_BATCHES) 82 | max_itr = cfg.TRAIN_ITERATION 83 | max_epoch = max_itr*(cfg.TRAIN_BATCHES)//len(dataset)+1 84 | tblogger = SummaryWriter(cfg.LOG_DIR) 85 | criterion = nn.CrossEntropyLoss(ignore_index=255) 86 | with tqdm(total=max_itr) as pbar: 87 | for epoch in range(cfg.TRAIN_MINEPOCH, max_epoch): 88 | for i_batch, sample in enumerate(dataloader): 89 | 90 | now_lr = adjust_lr(optimizer, itr, max_itr, cfg.TRAIN_LR, cfg.TRAIN_POWER) 91 | optimizer.zero_grad() 92 | 93 | inputs, seg_label = sample['image'], sample['segmentation'] 94 | n,c,h,w = inputs.size() 95 | 96 | pred1 = net(inputs.to(0)) 97 | loss = criterion(pred1, seg_label.to(0)) 98 | loss.backward() 99 | optimizer.step() 100 | 101 | pbar.set_description("loss=%g " % (loss.item())) 102 | pbar.update(1) 103 | time.sleep(0.001) 104 | #print('epoch:%d/%d\tbatch:%d/%d\titr:%d\tlr:%g\tloss:%g' % 105 | # (epoch, max_epoch, i_batch, len(dataset)//(cfg.TRAIN_BATCHES), 106 | # itr+1, now_lr, loss.item())) 107 | if cfg.TRAIN_TBLOG and itr%100 == 0: 108 | inputs1 = img_denorm(inputs[-1].cpu().numpy()).astype(np.uint8) 109 | label1 = sample['segmentation'][-1].cpu().numpy() 110 | label_color1 = dataset.label2colormap(label1).transpose((2,0,1)) 111 | 112 | n,c,h,w = inputs.size() 113 | seg_vis1 = torch.argmax(pred1[-1], dim=0).detach().cpu().numpy() 114 | seg_color1 = dataset.label2colormap(seg_vis1).transpose((2,0,1)) 115 | 116 | tblogger.add_scalar('loss', loss.item(), itr) 117 | tblogger.add_scalar('lr', now_lr, itr) 118 | tblogger.add_image('Input', inputs1, itr) 119 | tblogger.add_image('Label', label_color1, itr) 120 | tblogger.add_image('SEG1', seg_color1, itr) 121 | itr += 1 122 | if itr>=max_itr: 123 | break 124 | save_path = os.path.join(cfg.MODEL_SAVE_DIR,'%s_%s_%s_epoch%d.pth'%(cfg.MODEL_NAME,cfg.MODEL_BACKBONE,cfg.DATA_NAME,epoch)) 125 | torch.save(parameter_source.state_dict(), save_path) 126 | print('%s has been saved'%save_path) 127 | remove_path = os.path.join(cfg.MODEL_SAVE_DIR,'%s_%s_%s_epoch%d.pth'%(cfg.MODEL_NAME,cfg.MODEL_BACKBONE,cfg.DATA_NAME,epoch-1)) 128 | if os.path.exists(remove_path): 129 | os.remove(remove_path) 130 | 131 | save_path = os.path.join(cfg.MODEL_SAVE_DIR,'%s_%s_%s_itr%d_all.pth'%(cfg.MODEL_NAME,cfg.MODEL_BACKBONE,cfg.DATA_NAME,cfg.TRAIN_ITERATION)) 132 | torch.save(parameter_source.state_dict(),save_path) 133 | if cfg.TRAIN_TBLOG: 134 | tblogger.close() 135 | print('%s has been saved'%save_path) 136 | writelog(cfg, period) 137 | 138 | def adjust_lr(optimizer, itr, max_itr, lr_init, power): 139 | now_lr = lr_init * (1 - itr/(max_itr+1)) ** power 140 | optimizer.param_groups[0]['lr'] = now_lr 141 | optimizer.param_groups[1]['lr'] = 2*now_lr 142 | optimizer.param_groups[2]['lr'] = 10*now_lr 143 | optimizer.param_groups[3]['lr'] = 20*now_lr 144 | return now_lr 145 | 146 | def get_params(model, key): 147 | for m in model.named_modules(): 148 | if key == 'backbone': 149 | if ('backbone' in m[0]) and isinstance(m[1], (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d, nn.InstanceNorm2d)): 150 | for p in m[1].parameters(): 151 | yield p 152 | elif key == 'cls': 153 | if ('cls_conv' in m[0]) and isinstance(m[1], (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d, nn.InstanceNorm2d)): 154 | for p in m[1].parameters(): 155 | yield p 156 | elif key == 'others': 157 | if ('backbone' not in m[0] and 'cls_conv' not in m[0]) and isinstance(m[1], (nn.Conv2d, SynchronizedBatchNorm2d, 
nn.BatchNorm2d, nn.InstanceNorm2d)): 158 | for p in m[1].parameters(): 159 | yield p 160 | if __name__ == '__main__': 161 | train_net() 162 | 163 | 164 | -------------------------------------------------------------------------------- /segmentation/lib/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usr922/wseg/e96b961038e4171c5a49a0378111b47374dc5219/segmentation/lib/.DS_Store -------------------------------------------------------------------------------- /segmentation/lib/datasets/ADE20KDataset.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | from __future__ import print_function, division 6 | import os 7 | import json 8 | import torch 9 | from torch.utils.data import Dataset 10 | import cv2 11 | #from scipy.misc import imread 12 | import numpy as np 13 | from datasets.transform import * 14 | from datasets.metric import AverageMeter, accuracy, intersectionAndUnion 15 | from utils.registry import DATASETS 16 | from datasets.BaseDataset import BaseDataset 17 | 18 | @DATASETS.register_module 19 | class ADE20KDataset(BaseDataset): 20 | def __init__(self, cfg, period, transform='none'): 21 | super(ADE20KDataset, self).__init__(cfg, period, transform) 22 | assert(self.period != 'test') 23 | self.root_dir = os.path.join(cfg.ROOT_DIR,'data','ADEChallengeData2016') 24 | self.dataset_dir = self.root_dir 25 | self.img_dir = os.path.join(self.dataset_dir, 'images') 26 | self.seg_dir = os.path.join(self.dataset_dir, 'annotations') 27 | self.rst_dir = os.path.join(self.dataset_dir,'result') 28 | if cfg.DATA_PSEUDO_GT: 29 | self.pseudo_gt_dir = cfg.DATA_PSEUDO_GT 30 | else: 31 | self.pseudo_gt_dir = os.path.join(self.root_dir,'pseudo_gt') 32 | self.num_categories = 151 33 | assert(self.num_categories == self.cfg.MODEL_NUM_CLASSES) 34 | if self.period == 'train': 35 | self.name_list = ['ADE_train_%08d'%(i+1) for i in range(20210)] 36 | elif self.period == 'val': 37 | self.name_list = ['ADE_val_%08d'%(i+1) for i in range(2000)] 38 | else: 39 | raise ValueError('self.period is not \'train\' or \'val\'') 40 | 41 | def __len__(self): 42 | return len(self.name_list) 43 | 44 | def load_name(self, idx): 45 | name = self.name_list[idx] 46 | return name 47 | 48 | def load_image(self, idx): 49 | name = self.name_list[idx] 50 | if self.period == 'train': 51 | img_file = os.path.join(self.img_dir,'training',name+'.jpg') 52 | elif self.period == 'val': 53 | img_file = os.path.join(self.img_dir,'validation',name+'.jpg') 54 | else: 55 | raise ValueError('self.period is not \'train\' or \'val\'') 56 | image = cv2.imread(img_file) 57 | image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 58 | return image_rgb 59 | 60 | def load_segmentation(self, idx): 61 | name = self.name_list[idx] 62 | if self.period == 'train': 63 | seg_file = os.path.join(self.seg_dir,'training',name+'.png') 64 | elif self.period == 'val': 65 | seg_file = os.path.join(self.seg_dir,'validation',name+'.png') 66 | else: 67 | raise ValueError('self.period is not \'train\' or \'val\'') 68 | segmentation = np.array(Image.open(seg_file)) 69 | assert np.min(segmentation)>=0 70 | assert np.max(segmentation)<self.num_categories 71 | #seg[seg>=self.cfg.MODEL_NUM_CLASSES] = 0 72 | #seg += 1 73 | return segmentation 74 | 75 | def load_pseudo_segmentation(self, idx): 76 | name = self.name_list[idx] 77 | seg_file = os.path.join(self.pseudo_gt_dir,name+'.png') 78 | segmentation = np.array(Image.open(seg_file)) 79 | return segmentation 80 | 81 | def save_pseudo_gt(self, result_list): 82 | """Save pseudo gt 83 | 84 | Args: 85 | result_list(list of dict): [{'name':name1, 'predict':predict_seg1},{...},...] 86 | 87 | """ 88 | folder_path = self.pseudo_gt_dir 89 | if not os.path.exists(folder_path): 90 | os.makedirs(folder_path) 91 | for sample in result_list: 92 | file_path = os.path.join(folder_path, '%s.png'%(sample['name'])) 93 | cv2.imwrite(file_path, sample['predict']) 94 | print('%s saved'%(file_path)) 95 | 96 | def label2colormap(self, label): 97 | m = label.astype(np.uint8) 98 | r,c = m.shape 99 | cmap = np.zeros((r,c,3), dtype=np.uint8) 100 | cmap[:,:,0] = (m&1)<<7 | (m&8)<<3 | (m&64)>>1 101 | cmap[:,:,1] = (m&2)<<6 | (m&16)<<2 | (m&128)>>2 102 | cmap[:,:,2] = (m&4)<<5 | (m&32)<<1 103 | return cmap 104 | 105 | def save_result(self, result_list, model_id): 106 | folder_path = os.path.join(self.rst_dir,'%s'%model_id) 107 | if not os.path.exists(folder_path): 108 | os.makedirs(folder_path) 109 | for sample in result_list: 110 | file_path = os.path.join(folder_path,'%s.png'%sample['name']) 111 | ''' 112 | 113 | ATTENTION!!! 114 | 115 | predict label start from 0 or -1 ????? 116 | 117 | DO NOT have operation here!!! 118 | 119 | 120 | ''' 121 | cv2.imwrite(file_path, sample['predict']) 122 | 123 | def do_python_eval(self, model_id): 124 | folder_path = os.path.join(self.rst_dir,'%s'%model_id) 125 | 126 | acc_meter = AverageMeter() 127 | intersection_meter = AverageMeter() 128 | union_meter = AverageMeter() 129 | for name in self.name_list: 130 | predict_path = os.path.join(folder_path,'%s.png'%name) 131 | if 'train' in name: 132 | label_path = os.path.join(self.seg_dir, 'training', name+'.png') 133 | elif 'val' in name: 134 | label_path = os.path.join(self.seg_dir, 'validation', name+'.png') 135 | else: 136 | raise ValueError('self.period is not \'train\' or \'val\'') 137 | 138 | #predict = imread(predict_path) 139 | #label = imread(label_path) 140 | predict = np.array(Image.open(predict_path)) 141 | label = np.array(Image.open(label_path)) 142 | 143 | acc, pix = accuracy(predict, label) 144 | intersection, union = intersectionAndUnion(predict, label, self.num_categories) 145 | acc_meter.update(acc, pix) 146 | intersection_meter.update(intersection) 147 | union_meter.update(union) 148 | 149 | iou = intersection_meter.sum / (union_meter.sum + 1e-10) 150 | loglist = {} 151 | for i, _iou in enumerate(iou): 152 | print('class [{}], IoU: {}'.format(i, _iou)) 153 | loglist['class[{}]'.format(i)] = _iou 154 | print('[Eval Summary]:') 155 | print('Mean IoU: {:.4}, Accuracy: {:.2f}%'.format(iou.mean(), acc_meter.average()*100)) 156 | loglist['mIoU'] = iou.mean() 157 | loglist['accuracy'] = acc_meter.average()*100 158 | return loglist
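The bit fiddling in label2colormap above is easy to misread, so here is a tiny self-contained check (an illustration, not part of the repository): each label index donates its low-order bits to the high-order bits of R, G and B, which reproduces the familiar PASCAL-style palette colors for small indices.

import numpy as np

m = np.array([1, 2, 3, 4], dtype=np.uint8)          # label indices
r = (m & 1) << 7 | (m & 8) << 3 | (m & 64) >> 1     # bits 0, 3, 6 feed R
g = (m & 2) << 6 | (m & 16) << 2 | (m & 128) >> 2   # bits 1, 4, 7 feed G
b = (m & 4) << 5 | (m & 32) << 1                    # bits 2, 5 feed B
# labels 1..4 map to (128,0,0), (0,128,0), (128,128,0), (0,0,128)

Because only two or three bits reach each channel, distinct labels up to 255 still receive distinct colors, which is all the visualization needs.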
159 | -------------------------------------------------------------------------------- /segmentation/lib/datasets/BaseDataset.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | from __future__ import print_function, division 6 | import os 7 | import torch 8 | import pandas as pd 9 | import cv2 10 | import multiprocessing 11 | from skimage import io 12 | from PIL import Image 13 | import numpy as np 14 | from torch.utils.data import Dataset 15 | from datasets.transform import * 16 | from utils.imutils import * 17 | from utils.registry import DATASETS 18 | 19 | 
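# (Annotation, not part of the original file.) BaseDataset below is a template:
# __getitem__ drives __sample_generate__, which loads the image and then either the
# real ground truth (load_segmentation) or, when cfg.DATA_PSEUDO_GT is set during
# training, the pseudo labels written to disk by the WSSS stage
# (load_pseudo_segmentation). A concrete dataset only fills in the hooks; a
# hypothetical minimal subclass would look like:
#
#   class MyDataset(BaseDataset):
#       def __len__(self): return len(self.name_list)
#       def load_name(self, idx): return self.name_list[idx]
#       def load_image(self, idx): ...          # HxWx3 RGB uint8 array
#       def load_segmentation(self, idx): ...   # HxW label map, 255 = ignore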
#@DATASETS.register_module 20 | class BaseDataset(Dataset): 21 | def __init__(self, cfg, period, transform='none'): 22 | super(BaseDataset, self).__init__() 23 | self.cfg = cfg 24 | self.period = period 25 | self.transform = transform 26 | if 'train' not in self.period: 27 | assert self.transform == 'none' 28 | self.num_categories = None 29 | self.totensor = ToTensor() 30 | self.imagenorm = ImageNorm(cfg.DATA_MEAN, cfg.DATA_STD) 31 | 32 | if self.transform != 'none': 33 | if cfg.DATA_RANDOMCROP > 0: 34 | self.randomcrop = RandomCrop(cfg.DATA_RANDOMCROP) 35 | if cfg.DATA_RANDOMSCALE != 1: 36 | self.randomscale = RandomScale(cfg.DATA_RANDOMSCALE) 37 | if cfg.DATA_RANDOMFLIP > 0: 38 | self.randomflip = RandomFlip(cfg.DATA_RANDOMFLIP) 39 | if cfg.DATA_RANDOM_H > 0 or cfg.DATA_RANDOM_S > 0 or cfg.DATA_RANDOM_V > 0: 40 | self.randomhsv = RandomHSV(cfg.DATA_RANDOM_H, cfg.DATA_RANDOM_S, cfg.DATA_RANDOM_V) 41 | else: 42 | self.multiscale = Multiscale(self.cfg.TEST_MULTISCALE) 43 | 44 | 45 | def __getitem__(self, idx): 46 | sample = self.__sample_generate__(idx) 47 | 48 | if 'segmentation' in sample.keys(): 49 | sample['mask'] = sample['segmentation'] < self.num_categories 50 | t = sample['segmentation'].copy() 51 | t[t >= self.num_categories] = 0 52 | sample['segmentation_onehot']=onehot(t,self.num_categories) 53 | return self.totensor(sample) 54 | 55 | def __sample_generate__(self, idx, split_idx=0): 56 | name = self.load_name(idx) 57 | image = self.load_image(idx) 58 | r,c,_ = image.shape 59 | sample = {'image': image, 'name': name, 'row': r, 'col': c} 60 | 61 | if 'test' in self.period: 62 | return self.__transform__(sample) 63 | elif self.cfg.DATA_PSEUDO_GT and idx>=split_idx and 'train' in self.period: 64 | segmentation = self.load_pseudo_segmentation(idx) 65 | else: 66 | segmentation = self.load_segmentation(idx) 67 | sample['segmentation'] = segmentation 68 | t = sample['segmentation'].copy() 69 | t[t >= self.num_categories] = 0 70 | sample['category'] = seg2cls(t,self.num_categories) 71 | sample['category_copypaste'] = np.zeros(sample['category'].shape) 72 | 73 | #if self.transform == 'none' and self.cfg.DATA_FEATURE_DIR: 74 | # feature = self.load_feature(idx) 75 | # sample['feature'] = feature 76 | return self.__transform__(sample) 77 | 78 | def __transform__(self, sample): 79 | if self.transform == 'weak': 80 | sample = self.__weak_augment__(sample) 81 | elif self.transform == 'strong': 82 | sample = self.__strong_augment__(sample) 83 | else: 84 | sample = self.imagenorm(sample) 85 | sample = self.multiscale(sample) 86 | return sample 87 | 88 | def __weak_augment__(self, sample): 89 | if self.cfg.DATA_RANDOM_H>0 or self.cfg.DATA_RANDOM_S>0 or self.cfg.DATA_RANDOM_V>0: 90 | sample = self.randomhsv(sample) 91 | if self.cfg.DATA_RANDOMFLIP > 0: 92 | sample = self.randomflip(sample) 93 | if self.cfg.DATA_RANDOMSCALE != 1: 94 | sample = self.randomscale(sample) 95 | sample = self.imagenorm(sample) 96 | if self.cfg.DATA_RANDOMCROP > 0: 97 | sample = self.randomcrop(sample) 98 | return sample 99 | 100 | def __strong_augment__(self, sample): 101 | raise NotImplementedError 102 | 103 | def __len__(self): 104 | raise NotImplementedError 105 | 106 | def load_name(self, idx): 107 | raise NotImplementedError 108 | 109 | def load_image(self, idx): 110 | raise NotImplementedError 111 | 112 | def load_segmentation(self, idx): 113 | raise NotImplementedError 114 | 115 | def load_pseudo_segmentation(self, idx): 116 | raise NotImplementedError 117 | 118 | def load_feature(self, idx): 119 | raise 
NotImplementedError 120 | 121 | def save_result(self, result_list, model_id): 122 | raise NotImplementedError 123 | 124 | def save_pseudo_gt(self, result_list, level=None): 125 | raise NotImplementedError 126 | 127 | def do_python_eval(self, model_id): 128 | raise NotImplementedError 129 | -------------------------------------------------------------------------------- /segmentation/lib/datasets/CityscapesDataset.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | from __future__ import print_function, division 6 | import os, glob 7 | import torch 8 | import pandas as pd 9 | import cv2 10 | import multiprocessing 11 | from skimage import io 12 | from PIL import Image 13 | import numpy as np 14 | from torch.utils.data import Dataset 15 | from datasets.transform import * 16 | from utils.imutils import * 17 | from collections import namedtuple 18 | from utils.registry import DATASETS 19 | from datasets.BaseDataset import BaseDataset 20 | 21 | Label = namedtuple( 'Label', ['name','id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval', 'color',]) 22 | @DATASETS.register_module 23 | class CityscapesDataset(BaseDataset): 24 | def __init__(self, cfg, period, transform=False): 25 | super(CityscapesDataset, self).__init__(cfg, period, transform) 26 | self.root_dir = os.path.join(cfg.ROOT_DIR,'data','cityscapes') 27 | self.dataset_dir = self.root_dir 28 | self.rst_dir = os.path.join(self.dataset_dir,'results') 29 | self.img_dir = os.path.join(self.dataset_dir, 'leftImg8bit',self.period) 30 | self.seg_dir = os.path.join(self.dataset_dir, 'gtFine',self.period) 31 | self.img_extra_dir = os.path.join(self.dataset_dir, 'leftImg8bit', 'train_extra') 32 | self.seg_extra_dir = os.path.join(self.dataset_dir, 'gtCoarse', 'train_extra') 33 | if cfg.DATA_PSEUDO_GT: 34 | self.pseudo_gt_dir = cfg.DATA_PSEUDO_GT 35 | else: 36 | self.pseudo_gt_dir = os.path.join(self.root_dir,'pseudo_gt') 37 | 38 | searchFine = os.path.join(self.img_dir,'*', '*_*_*_leftImg8bit.png' ) 39 | filesFine = glob.glob(searchFine) 40 | filesFine.sort() 41 | self.name_list = [] 42 | for file in filesFine: 43 | name = file.replace('%s/'%self.img_dir,'') 44 | name = name.replace('_leftImg8bit.png','') 45 | self.name_list.append(name) 46 | 47 | if cfg.DATA_AUG: 48 | searchCoarse = os.path.join(self.img_extra_dir, '*', '*_*_*_leftImg8bit.png') 49 | filesCoarse = glob.glob(searchCoarse) 50 | filesCoarse.sort() 51 | for file in filesCoarse: 52 | name = file.replace('%s/'%self.img_extra_dir,'') 53 | name = name.replace('_leftImg8bit.png','') 54 | self.name_list.append(name) 55 | 56 | self.categories= [ 57 | # name id trainId category catId hasInstances ignoreInEval color 58 | Label( 'unlabeled', 0 , 255 , 'void', 0, False, True, ( 0, 0, 0) ), 59 | Label( 'ego vehicle', 1 , 255 , 'void', 0, False, True, ( 0, 0, 0) ), 60 | Label( 'rectification border', 2 , 255 , 'void', 0, False, True, ( 0, 0, 0) ), 61 | Label( 'out of roi', 3 , 255 , 'void', 0, False, True, ( 0, 0, 0) ), 62 | Label( 'static', 4 , 255 , 'void', 0, False, True, ( 0, 0, 0) ), 63 | Label( 'dynamic', 5 , 255 , 'void', 0, False, True, (111, 74, 0) ), 64 | Label( 'ground', 6 , 255 , 'void', 0, False, True, ( 81, 0, 81) ), 65 | Label( 'road', 7 , 0 , 'flat', 1, False, False, (128, 64,128) ), 66 | Label( 'sidewalk', 8 , 1 , 'flat', 1, False, False, (244, 35,232) ), 67 | Label( 'parking', 9 , 255 , 'flat', 1, 
False, True, (250,170,160) ), 68 | Label( 'rail track', 10 , 255 , 'flat', 1, False, True, (230,150,140) ), 69 | Label( 'building', 11 , 2 , 'construction',2, False, False, ( 70, 70, 70) ), 70 | Label( 'wall', 12 , 3 , 'construction',2, False, False, (102,102,156) ), 71 | Label( 'fence', 13 , 4 , 'construction',2, False, False, (190,153,153) ), 72 | Label( 'guard rail', 14 , 255 , 'construction',2, False, True, (180,165,180) ), 73 | Label( 'bridge', 15 , 255 , 'construction',2, False, True, (150,100,100) ), 74 | Label( 'tunnel', 16 , 255 , 'construction',2, False, True, (150,120, 90) ), 75 | Label( 'pole', 17 , 5 , 'object', 3, False, False, (153,153,153) ), 76 | Label( 'polegroup', 18 , 255 , 'object', 3, False, True, (153,153,153) ), 77 | Label( 'traffic light', 19 , 6 , 'object', 3, False, False, (250,170, 30) ), 78 | Label( 'traffic sign', 20 , 7 , 'object', 3, False, False, (220,220, 0) ), 79 | Label( 'vegetation', 21 , 8 , 'nature', 4, False, False, (107,142, 35) ), 80 | Label( 'terrain', 22 , 9 , 'nature', 4, False, False, (152,251,152) ), 81 | Label( 'sky', 23 , 10 , 'sky', 5, False, False, ( 70,130,180) ), 82 | Label( 'person', 24 , 11 , 'human', 6, True , False, (220, 20, 60) ), 83 | Label( 'rider', 25 , 12 , 'human', 6, True , False, (255, 0, 0) ), 84 | Label( 'car', 26 , 13 , 'vehicle', 7, True , False, ( 0, 0,142) ), 85 | Label( 'truck', 27 , 14 , 'vehicle', 7, True , False, ( 0, 0, 70) ), 86 | Label( 'bus', 28 , 15 , 'vehicle', 7, True , False, ( 0, 60,100) ), 87 | Label( 'caravan', 29 , 255 , 'vehicle', 7, True , True, ( 0, 0, 90) ), 88 | Label( 'trailer', 30 , 255 , 'vehicle', 7, True , True, ( 0, 0,110) ), 89 | Label( 'train', 31 , 16 , 'vehicle', 7, True , False, ( 0, 80,100) ), 90 | Label( 'motorcycle', 32 , 17 , 'vehicle', 7, True , False, ( 0, 0,230) ), 91 | Label( 'bicycle', 33 , 18 , 'vehicle', 7, True , False, (119, 11, 32) ), 92 | Label( 'license plate', -1 , -1 , 'vehicle', 7, False, True, ( 0, 0,142) ), 93 | ] 94 | self.id2label = {label.id: label for label in self.categories} 95 | self.trainId2label = {label.trainId : label for label in reversed(self.categories)} 96 | self.num_categories = 19 97 | assert self.num_categories == self.cfg.MODEL_NUM_CLASSES 98 | 99 | def __len__(self): 100 | return len(self.name_list) 101 | 102 | def load_name(self, idx): 103 | return self.name_list[idx] 104 | 105 | def load_image(self, idx): 106 | name = self.name_list[idx] 107 | img_file = os.path.join(self.img_dir, name + '_leftImg8bit.png') 108 | if not os.path.exists(img_file): 109 | img_file = os.path.join(self.img_extra_dir, name + '_leftImg8bit.png') 110 | image = cv2.imread(img_file) 111 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 112 | return image 113 | 114 | def __id2trainid__(self, seg): 115 | for label in self.categories: 116 | seg[seg == label.id] = label.trainId 117 | #seg[seg==-1] = 34 118 | return seg 119 | 120 | def load_segmentation(self, idx): 121 | name = self.name_list[idx] 122 | seg_file = os.path.join(self.seg_dir, name + '_gtFine_labelIds.png') 123 | if not os.path.exists(seg_file): 124 | seg_file = os.path.join(self.seg_extra_dir, name + '_gtCoarse_labelIds.png') 125 | segmentation = np.array(Image.open(seg_file)) 126 | return self.__id2trainid__(segmentation) 127 | 128 | def load_pseudo_segmentation(self, idx): 129 | name = self.name_list[idx].split('/')[1] 130 | seg_file = os.path.join(self.seg_dir, name + '.png') 131 | segmentation = np.array(Image.open(seg_file)) 132 | return self.__id2trainid__(segmentation) 133 | 134 | def save_pseudo_gt(self, 
result_list, level=None): 135 | folder_path = self.pseudo_gt_dir 136 | if not os.path.exists(folder_path): 137 | os.makedirs(folder_path) 138 | for sample in result_list: 139 | name = sample['name'].split('/') 140 | file_path = os.path.join(folder_path, '%s.png'%name[1]) 141 | cv2.imwrite(file_path, sample['predict']) 142 | print('%s saved'%(file_path)) 143 | 144 | def do_python_eval(self, model_id): 145 | raise NotImplementedError 146 | 147 | 148 | def label2colormap(self, label, id_version='trainid'): 149 | m = label.astype(np.uint8) 150 | r,c = m.shape 151 | cmap = np.zeros((r,c,3), dtype=np.uint8) 152 | if id_version == 'id': 153 | for k in self.id2label.keys(): 154 | cmap[m == k] = self.id2label[k].color 155 | elif id_version == 'trainid': 156 | for k in self.trainId2label.keys(): 157 | cmap[m == k] = self.trainId2label[k].color 158 | return cmap 159 | 160 | def trainid2id(self, label): 161 | label_id = label.copy() 162 | for k in self.trainId2label.keys(): 163 | label_id[label == k] = self.trainId2label[k].id 164 | return label_id 165 | 166 | def save_result(self, result_list, model_id): 167 | """Save test results 168 | 169 | Args: 170 | result_list(list of dict): [{'name':name1, 'predict':predict_seg1},{...},...] 171 | 172 | """ 173 | i = 1 174 | folder_path = self.rst_dir 175 | if not os.path.exists(folder_path): 176 | os.makedirs(folder_path) 177 | for sample in result_list: 178 | name = sample['name'].split('/') 179 | file_path = os.path.join(folder_path, '%s.png'%name[1]) 180 | # predict_color = self.label2colormap(sample['predict']) 181 | # p = self.__coco2voc(sample['predict']) 182 | cv2.imwrite(file_path, sample['predict']) 183 | print('[%d/%d] %s saved'%(i,len(result_list),file_path)) 184 | i+=1 185 | 186 | def do_cityscapesscripts_eval(self): 187 | import subprocess 188 | path = self.root_dir 189 | cmd = 'cd {} && '.format(path) 190 | cmd += 'python cityscapesscripts/evaluation/evalPixelLevelSemanticLabeling.py' 191 | 192 | print('start subprocess for cityscapesscripts evaluation...') 193 | print(cmd) 194 | subprocess.call(cmd, shell=True) 195 | 196 | 197 | -------------------------------------------------------------------------------- /segmentation/lib/datasets/VOCDataset.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | from __future__ import print_function, division 6 | import os 7 | import torch 8 | import pandas as pd 9 | import cv2 10 | import multiprocessing 11 | from skimage import io 12 | from PIL import Image 13 | import numpy as np 14 | from torch.utils.data import Dataset 15 | from datasets.transform import * 16 | from utils.imutils import * 17 | from utils.registry import DATASETS 18 | from datasets.BaseDataset import BaseDataset 19 | 20 | @DATASETS.register_module 21 | class VOCDataset(BaseDataset): 22 | def __init__(self, cfg, period, transform='none'): 23 | super(VOCDataset, self).__init__(cfg, period, transform) 24 | self.dataset_name = 'VOC%d'%cfg.DATA_YEAR 25 | self.root_dir = os.path.join(cfg.ROOT_DIR,'data','VOCdevkit') 26 | self.dataset_dir = os.path.join(self.root_dir,self.dataset_name) 27 | self.rst_dir = os.path.join(self.root_dir,'results',self.dataset_name,'Segmentation') 28 | self.eval_dir = os.path.join(self.root_dir,'eval_result',self.dataset_name,'Segmentation') 29 | self.img_dir = os.path.join(self.dataset_dir, 'JPEGImages') 30 | self.ann_dir = os.path.join(self.dataset_dir, 
'Annotations') 31 | self.seg_dir = os.path.join(self.dataset_dir, 'SegmentationClass') 32 | self.set_dir = os.path.join(self.dataset_dir, 'ImageSets', 'Segmentation') 33 | if cfg.DATA_PSEUDO_GT: 34 | self.pseudo_gt_dir = cfg.DATA_PSEUDO_GT 35 | else: 36 | self.pseudo_gt_dir = os.path.join(self.root_dir,'pseudo_gt',self.dataset_name,'Segmentation') 37 | 38 | file_name = None 39 | if cfg.DATA_AUG and 'train' in self.period: 40 | file_name = self.set_dir+'/'+period+'aug.txt' 41 | else: 42 | file_name = self.set_dir+'/'+period+'.txt' 43 | df = pd.read_csv(file_name, names=['filename']) 44 | self.name_list = df['filename'].values 45 | if self.dataset_name == 'VOC2012': 46 | self.categories = ['aeroplane','bicycle','bird','boat','bottle','bus','car','cat','chair','cow', 47 | 'diningtable','dog','horse','motorbike','person','pottedplant','sheep','sofa','train','tvmonitor'] 48 | self.coco2voc = [[0],[5],[2],[16],[9],[44],[6],[3],[17],[62], 49 | [21],[67],[18],[19],[4],[1],[64],[20],[63],[7],[72]] 50 | 51 | self.num_categories = len(self.categories)+1 52 | self.cmap = self.__colormap(len(self.categories)+1) 53 | 54 | def __len__(self): 55 | return len(self.name_list) 56 | 57 | def load_name(self, idx): 58 | name = self.name_list[idx] 59 | return name 60 | 61 | def load_image(self, idx): 62 | name = self.name_list[idx] 63 | img_file = self.img_dir + '/' + name + '.jpg' 64 | image = cv2.imread(img_file) 65 | image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 66 | return image_rgb 67 | 68 | def load_segmentation(self, idx): 69 | name = self.name_list[idx] 70 | seg_file = self.seg_dir + '/' + name + '.png' 71 | segmentation = np.array(Image.open(seg_file)) 72 | return segmentation 73 | 74 | def load_pseudo_segmentation(self, idx): 75 | name = self.name_list[idx] 76 | seg_file = self.pseudo_gt_dir + '/' + name + '.png' 77 | segmentation = np.array(Image.open(seg_file)) 78 | return segmentation 79 | 80 | def __colormap(self, N): 81 | """Get the map from label index to color 82 | 83 | Args: 84 | N: number of class 85 | 86 | return: a Nx3 matrix 87 | 88 | """ 89 | cmap = np.zeros((N, 3), dtype = np.uint8) 90 | 91 | def uint82bin(n, count=8): 92 | """returns the binary of integer n, count refers to amount of bits""" 93 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)]) 94 | 95 | for i in range(N): 96 | r = 0 97 | g = 0 98 | b = 0 99 | idx = i 100 | for j in range(7): 101 | str_id = uint82bin(idx) 102 | r = r ^ ( np.uint8(str_id[-1]) << (7-j)) 103 | g = g ^ ( np.uint8(str_id[-2]) << (7-j)) 104 | b = b ^ ( np.uint8(str_id[-3]) << (7-j)) 105 | idx = idx >> 3 106 | cmap[i, 0] = r 107 | cmap[i, 1] = g 108 | cmap[i, 2] = b 109 | return cmap 110 | 111 | def load_ranked_namelist(self): 112 | df = self.read_rank_result() 113 | self.name_list = df['filename'].values 114 | 115 | def label2colormap(self, label): 116 | m = label.astype(np.uint8) 117 | r,c = m.shape 118 | cmap = np.zeros((r,c,3), dtype=np.uint8) 119 | cmap[:,:,0] = (m&1)<<7 | (m&8)<<3 120 | cmap[:,:,1] = (m&2)<<6 | (m&16)<<2 121 | cmap[:,:,2] = (m&4)<<5 122 | cmap[m==255] = [255,255,255] 123 | return cmap 124 | 125 | def save_result(self, result_list, model_id): 126 | """Save test results 127 | 128 | Args: 129 | result_list(list of dict): [{'name':name1, 'predict':predict_seg1},{...},...] 
130 | 131 | """ 132 | folder_path = os.path.join(self.rst_dir,'%s_%s'%(model_id,self.period)) 133 | if not os.path.exists(folder_path): 134 | os.makedirs(folder_path) 135 | 136 | for sample in result_list: 137 | file_path = os.path.join(folder_path, '%s.png'%sample['name']) 138 | cv2.imwrite(file_path, sample['predict']) 139 | 140 | def save_pseudo_gt(self, result_list, folder_path=None): 141 | """Save pseudo gt 142 | 143 | Args: 144 | result_list(list of dict): [{'name':name1, 'predict':predict_seg1},{...},...] 145 | 146 | """ 147 | i = 1 148 | folder_path = self.pseudo_gt_dir if folder_path is None else folder_path 149 | if not os.path.exists(folder_path): 150 | os.makedirs(folder_path) 151 | for sample in result_list: 152 | file_path = os.path.join(folder_path, '%s.png'%(sample['name'])) 153 | cv2.imwrite(file_path, sample['predict']) 154 | i+=1 155 | 156 | def do_matlab_eval(self, model_id): 157 | import subprocess 158 | path = os.path.join(self.root_dir, 'VOCcode') 159 | eval_filename = os.path.join(self.eval_dir,'%s_result.mat'%model_id) 160 | cmd = 'cd {} && '.format(path) 161 | cmd += 'matlab -nodisplay -nodesktop ' 162 | cmd += '-r "dbstop if error; VOCinit; ' 163 | cmd += 'VOCevalseg(VOCopts,\'{:s}\');'.format(model_id) 164 | cmd += 'accuracies,avacc,conf,rawcounts = VOCevalseg(VOCopts,\'{:s}\'); '.format(model_id) 165 | cmd += 'save(\'{:s}\',\'accuracies\',\'avacc\',\'conf\',\'rawcounts\'); '.format(eval_filename) 166 | cmd += 'quit;"' 167 | 168 | print('start subprocess for matlab evaluation...') 169 | print(cmd) 170 | subprocess.call(cmd, shell=True) 171 | 172 | def do_python_eval(self, model_id): 173 | predict_folder = os.path.join(self.rst_dir,'%s_%s'%(model_id,self.period)) 174 | gt_folder = self.seg_dir 175 | TP = [] 176 | P = [] 177 | T = [] 178 | for i in range(self.num_categories): 179 | TP.append(multiprocessing.Value('i', 0, lock=True)) 180 | P.append(multiprocessing.Value('i', 0, lock=True)) 181 | T.append(multiprocessing.Value('i', 0, lock=True)) 182 | 183 | def compare(start,step,TP,P,T): 184 | for idx in range(start,len(self.name_list),step): 185 | #print('%d/%d'%(idx,len(self.name_list))) 186 | name = self.name_list[idx] 187 | predict_file = os.path.join(predict_folder,'%s.png'%name) 188 | gt_file = os.path.join(gt_folder,'%s.png'%name) 189 | predict = np.array(Image.open(predict_file)) #cv2.imread(predict_file) 190 | gt = np.array(Image.open(gt_file)) 191 | cal = gt<255 192 | mask = (predict==gt) * cal 193 | 194 | for i in range(self.num_categories): 195 | P[i].acquire() 196 | P[i].value += np.sum((predict==i)*cal) 197 | P[i].release() 198 | T[i].acquire() 199 | T[i].value += np.sum((gt==i)*cal) 200 | T[i].release() 201 | TP[i].acquire() 202 | TP[i].value += np.sum((gt==i)*mask) 203 | TP[i].release() 204 | p_list = [] 205 | for i in range(8): 206 | p = multiprocessing.Process(target=compare, args=(i,8,TP,P,T)) 207 | p.start() 208 | p_list.append(p) 209 | for p in p_list: 210 | p.join() 211 | IoU = [] 212 | for i in range(self.num_categories): 213 | IoU.append(TP[i].value/(T[i].value+P[i].value-TP[i].value+1e-10)) 214 | loglist = {} 215 | for i in range(self.num_categories): 216 | if i == 0: 217 | print('%11s:%7.3f%%'%('background',IoU[i]*100),end='\t') 218 | loglist['background'] = IoU[i] * 100 219 | else: 220 | if i%2 != 1: 221 | print('%11s:%7.3f%%'%(self.categories[i-1],IoU[i]*100),end='\t') 222 | else: 223 | print('%11s:%7.3f%%'%(self.categories[i-1],IoU[i]*100)) 224 | loglist[self.categories[i-1]] = IoU[i] * 100 225 | 226 | miou = np.mean(np.array(IoU)) 227 
| print('\n======================================================') 228 | print('%11s:%7.3f%%'%('mIoU',miou*100)) 229 | loglist['mIoU'] = miou * 100 230 | return loglist 231 | 232 | def __coco2voc(self, m): 233 | r,c = m.shape 234 | result = np.zeros((r,c),dtype=np.uint8) 235 | for i in range(0,21): 236 | for j in self.coco2voc[i]: 237 | result[m==j] = i 238 | return result 239 | 240 | 241 | -------------------------------------------------------------------------------- /segmentation/lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .VOCDataset import * 2 | #from .COCODataset import * 3 | #from .CityscapesDataset import * 4 | #from .ADE20KDataset import * 5 | #from .ContextDataset import * 6 | -------------------------------------------------------------------------------- /segmentation/lib/datasets/generateData.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | #import torch 6 | #import torch.nn as nn 7 | #from datasets.VOCDataset import VOCDataset, Semi_VOCDataset, VOCSuperPixelDataset 8 | #from datasets.COCODataset import COCOSmtDataset, COCOInsDataset 9 | #from datasets.ADE20KDataset import ADE20KDataset 10 | #from datasets.ContextDataset import ContextDataset 11 | #from datasets.CityscapesDataset import CityscapesDataset 12 | #from datasets.CityscapesDemoDataset import CityscapesDemoDataset 13 | from utils.registry import DATASETS 14 | 15 | #def generate_dataset(dataset_name, cfg, data_period, data_aug=False, aug_period=None): 16 | # if dataset_name == 'voc2012' or dataset_name == 'VOC2012': 17 | # return VOCDataset('VOC2012', cfg, data_period, data_aug=data_aug, aug_period=aug_period) 18 | # elif dataset_name == 'semi-voc2012' or dataset_name == 'Semi-VOC2012': 19 | # return Semi_VOCDataset('VOC2012', cfg, data_period, data_aug=data_aug, aug_period=aug_period) 20 | # elif dataset_name == 'vocsp2012' or dataset_name == 'VOCSuperPixel2012': 21 | # return VOCSuperPixelDataset('VOC2012', cfg, data_period, data_aug=data_aug, aug_period=aug_period) 22 | # elif dataset_name == 'coco2017smt' or dataset_name == 'COCO2017Smt': 23 | # return COCODataset('COCO2017', cfg, data_period) 24 | # elif dataset_name == 'ade20k' or dataset_name == 'ADE20K': 25 | # return ADE20KDataset('ADE20K', cfg, data_period) 26 | # elif dataset_name == 'context' or dataset_name == 'Context': 27 | # return ContextDataset('Context', cfg, data_period) 28 | # elif dataset_name == 'cityscapes' or dataset_name == 'Cityscapes': 29 | # return CityscapesDataset('Cityscapes', cfg, data_period, data_aug=data_aug, aug_period=aug_period) 30 | # elif dataset_name == 'cityscapesdemo' or dataset_name == 'CityscapesDemo': 31 | # return CityscapesDemoDataset('CityscapesDemo', cfg, data_period) 32 | # else: 33 | # raise ValueError('generateData.py: dataset %s is not support yet'%dataset_name) 34 | 35 | def generate_dataset(cfg, **kwargs): 36 | dataset = DATASETS.get(cfg.DATA_NAME)(cfg, **kwargs) 37 | return dataset 38 | -------------------------------------------------------------------------------- /segmentation/lib/datasets/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import functools 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | def __init__(self): 8 | 
self.initialized = False
9 | self.val = None
10 | self.avg = None
11 | self.sum = None
12 | self.count = None
13 | 
14 | def initialize(self, val, weight):
15 | self.val = val
16 | self.avg = val
17 | self.sum = val * weight
18 | self.count = weight
19 | self.initialized = True
20 | 
21 | def update(self, val, weight=1):
22 | if not self.initialized:
23 | self.initialize(val, weight)
24 | else:
25 | self.add(val, weight)
26 | 
27 | def add(self, val, weight):
28 | self.val = val
29 | self.sum += val * weight
30 | self.count += weight
31 | self.avg = self.sum / self.count
32 | 
33 | def value(self):
34 | return self.val
35 | 
36 | def average(self):
37 | return self.avg
38 | 
39 | 
40 | def unique(ar, return_index=False, return_inverse=False, return_counts=False):
41 | ar = np.asanyarray(ar).flatten()
42 | 
43 | optional_indices = return_index or return_inverse
44 | optional_returns = optional_indices or return_counts
45 | 
46 | if ar.size == 0:
47 | if not optional_returns:
48 | ret = ar
49 | else:
50 | ret = (ar,)
51 | if return_index:
52 | ret += (np.empty(0, np.intp),)  # index array, so intp; np.bool was wrong here and is removed from modern NumPy
53 | if return_inverse:
54 | ret += (np.empty(0, np.intp),)
55 | if return_counts:
56 | ret += (np.empty(0, np.intp),)
57 | return ret
58 | if optional_indices:
59 | perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
60 | aux = ar[perm]
61 | else:
62 | ar.sort()
63 | aux = ar
64 | flag = np.concatenate(([True], aux[1:] != aux[:-1]))
65 | 
66 | if not optional_returns:
67 | ret = aux[flag]
68 | else:
69 | ret = (aux[flag],)
70 | if return_index:
71 | ret += (perm[flag],)
72 | if return_inverse:
73 | iflag = np.cumsum(flag) - 1
74 | inv_idx = np.empty(ar.shape, dtype=np.intp)
75 | inv_idx[perm] = iflag
76 | ret += (inv_idx,)
77 | if return_counts:
78 | idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
79 | ret += (np.diff(idx),)
80 | return ret
81 | 
82 | 
83 | def colorEncode(labelmap, colors, mode='BGR'):
84 | labelmap = labelmap.astype('int')
85 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
86 | dtype=np.uint8)
87 | for label in unique(labelmap):
88 | if label < 0:
89 | continue
90 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
91 | np.tile(colors[label],
92 | (labelmap.shape[0], labelmap.shape[1], 1))
93 | 
94 | if mode == 'BGR':
95 | return labelmap_rgb[:, :, ::-1]
96 | else:
97 | return labelmap_rgb
98 | 
99 | 
100 | def accuracy(preds, label):
101 | valid = (label >= 0)
102 | acc_sum = (valid * (preds == label)).sum()
103 | valid_sum = valid.sum()
104 | acc = float(acc_sum) / (valid_sum + 1e-10)
105 | return acc, valid_sum
106 | 
107 | 
108 | def intersectionAndUnion(imPred, imLab, numClass):
109 | imPred = np.asarray(imPred).copy()
110 | imLab = np.asarray(imLab).copy()
111 | 
112 | imPred += 1
113 | imLab += 1
114 | # Remove classes from unlabeled pixels in gt image.
115 | # We should not penalize detections in unlabeled portions of the image.
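# A tiny worked example (hypothetical values, not from any dataset) of why
# both maps are shifted by +1: unlabeled pixels become 0 and fall outside
# the histogram range below, so they never count toward any class.
#   imLab  = [[1, 2], [0, 0]]                        (after the +1 shift)
#   imPred = [[1, 1], [2, 1]] * (imLab > 0) -> [[1, 1], [0, 0]]
#   intersection = imPred * (imPred == imLab) -> [[1, 0], [0, 0]]
#   with numClass=2: the histograms give intersection=[1, 0], pred=[2, 0],
#   lab=[1, 1], hence union=[2, 1] and per-class IoU=[0.5, 0.0].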
116 | imPred = imPred * (imLab > 0) 117 | 118 | # Compute area intersection: 119 | intersection = imPred * (imPred == imLab) 120 | (area_intersection, _) = np.histogram( 121 | intersection, bins=numClass, range=(1, numClass)) 122 | 123 | # Compute area union: 124 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 125 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 126 | area_union = area_pred + area_lab - area_intersection 127 | 128 | return (area_intersection, area_union) 129 | 130 | 131 | class NotSupportedCliException(Exception): 132 | pass 133 | 134 | 135 | def process_range(xpu, inp): 136 | start, end = map(int, inp) 137 | if start > end: 138 | end, start = start, end 139 | return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1)) 140 | 141 | 142 | REGEX = [ 143 | (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]), 144 | (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]), 145 | (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'), 146 | functools.partial(process_range, 'gpu')), 147 | (re.compile(r'^(\d+)-(\d+)$'), 148 | functools.partial(process_range, 'gpu')), 149 | ] 150 | 151 | 152 | def parse_devices(input_devices): 153 | 154 | """Parse user's devices input str to standard format. 155 | e.g. [gpu0, gpu1, ...] 156 | 157 | """ 158 | ret = [] 159 | for d in input_devices.split(','): 160 | for regex, func in REGEX: 161 | m = regex.match(d.lower().strip()) 162 | if m: 163 | tmp = func(m.groups()) 164 | # prevent duplicate 165 | for x in tmp: 166 | if x not in ret: 167 | ret.append(x) 168 | break 169 | else: 170 | raise NotSupportedCliException( 171 | 'Can not recognize device: "%s"' % d) 172 | return ret 173 | -------------------------------------------------------------------------------- /segmentation/lib/datasets/transform.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import random 9 | import PIL 10 | from PIL import Image, ImageOps, ImageFilter 11 | 12 | class RandomCrop(object): 13 | """Crop randomly the image in a sample. 14 | 15 | Args: 16 | output_size (tuple or int): Desired output size. If int, square crop 17 | is made. 
18 | """ 19 | 20 | def __init__(self, output_size): 21 | assert isinstance(output_size, (int, tuple)) 22 | if isinstance(output_size, int): 23 | self.output_size = (output_size, output_size) 24 | else: 25 | assert len(output_size) == 2 26 | self.output_size = output_size 27 | 28 | def __call__(self, sample): 29 | 30 | h, w = sample['image'].shape[:2] 31 | ch = min(h, self.output_size[0]) 32 | cw = min(w, self.output_size[1]) 33 | 34 | h_space = h - self.output_size[0] 35 | w_space = w - self.output_size[1] 36 | 37 | if w_space > 0: 38 | cont_left = 0 39 | img_left = random.randrange(w_space+1) 40 | else: 41 | cont_left = random.randrange(-w_space+1) 42 | img_left = 0 43 | 44 | if h_space > 0: 45 | cont_top = 0 46 | img_top = random.randrange(h_space+1) 47 | else: 48 | cont_top = random.randrange(-h_space+1) 49 | img_top = 0 50 | 51 | key_list = sample.keys() 52 | for key in key_list: 53 | if 'image' in key: 54 | img = sample[key] 55 | img_crop = np.zeros((self.output_size[0], self.output_size[1], 3), np.float32) 56 | img_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 57 | img[img_top:img_top+ch, img_left:img_left+cw] 58 | #img_crop = img[img_top:img_top+ch, img_left:img_left+cw] 59 | sample[key] = img_crop 60 | elif 'segmentation' == key: 61 | seg = sample[key] 62 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255 63 | seg_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 64 | seg[img_top:img_top+ch, img_left:img_left+cw] 65 | #seg_crop = seg[img_top:img_top+ch, img_left:img_left+cw] 66 | sample[key] = seg_crop 67 | elif 'segmentation_pseudo' in key: 68 | seg_pseudo = sample[key] 69 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255 70 | seg_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 71 | seg_pseudo[img_top:img_top+ch, img_left:img_left+cw] 72 | #seg_crop = seg_pseudo[img_top:img_top+ch, img_left:img_left+cw] 73 | sample[key] = seg_crop 74 | return sample 75 | 76 | class RandomHSV(object): 77 | """Generate randomly the image in hsv space.""" 78 | def __init__(self, h_r, s_r, v_r): 79 | self.h_r = h_r 80 | self.s_r = s_r 81 | self.v_r = v_r 82 | 83 | def __call__(self, sample): 84 | image = sample['image'] 85 | hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) 86 | h = hsv[:,:,0].astype(np.int32) 87 | s = hsv[:,:,1].astype(np.int32) 88 | v = hsv[:,:,2].astype(np.int32) 89 | delta_h = random.randint(-self.h_r,self.h_r) 90 | delta_s = random.randint(-self.s_r,self.s_r) 91 | delta_v = random.randint(-self.v_r,self.v_r) 92 | h = (h + delta_h)%180 93 | s = s + delta_s 94 | s[s>255] = 255 95 | s[s<0] = 0 96 | v = v + delta_v 97 | v[v>255] = 255 98 | v[v<0] = 0 99 | hsv = np.stack([h,s,v], axis=-1).astype(np.uint8) 100 | image = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).astype(np.uint8) 101 | sample['image'] = image 102 | return sample 103 | 104 | class RandomFlip(object): 105 | """Randomly flip image""" 106 | def __init__(self, threshold): 107 | self.flip_t = threshold 108 | def __call__(self, sample): 109 | if random.random() < self.flip_t: 110 | key_list = sample.keys() 111 | for key in key_list: 112 | if 'image' in key: 113 | img = sample[key] 114 | img = np.flip(img, axis=1) 115 | sample[key] = img 116 | elif 'segmentation' == key: 117 | seg = sample[key] 118 | seg = np.flip(seg, axis=1) 119 | sample[key] = seg 120 | elif 'segmentation_pseudo' in key: 121 | seg_pseudo = sample[key] 122 | seg_pseudo = np.flip(seg_pseudo, axis=1) 123 | sample[key] = seg_pseudo 124 | return sample 125 | 126 | class RandomScale(object): 
127 | """Randomly scale image"""
128 | def __init__(self, scale_r, is_continuous=False):
129 | self.scale_r = scale_r
130 | self.seg_interpolation = cv2.INTER_CUBIC if is_continuous else cv2.INTER_NEAREST
131 | 
132 | def __call__(self, sample):
133 | row, col, _ = sample['image'].shape
134 | rand_scale = random.random()*(self.scale_r[1] - self.scale_r[0]) + self.scale_r[0]
135 | key_list = sample.keys()
136 | for key in key_list:
137 | if 'image' in key:
138 | img = sample[key]
139 | img = cv2.resize(img, None, fx=rand_scale, fy=rand_scale, interpolation=cv2.INTER_CUBIC)
140 | sample[key] = img
141 | elif 'segmentation' == key:
142 | seg = sample[key]
143 | seg = cv2.resize(seg, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation)
144 | sample[key] = seg
145 | elif 'segmentation_pseudo' in key:
146 | seg_pseudo = sample[key]
147 | seg_pseudo = cv2.resize(seg_pseudo, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation)
148 | sample[key] = seg_pseudo
149 | return sample
150 | 
151 | class ImageNorm(object):
152 | """Normalize the image by mean/std when given, otherwise scale it to [0, 1]"""
153 | def __init__(self, mean=None, std=None):
154 | self.mean = mean
155 | self.std = std
156 | def __call__(self, sample):
157 | key_list = sample.keys()
158 | for key in key_list:
159 | if 'image' in key:
160 | image = sample[key].astype(np.float32)
161 | if self.mean is not None and self.std is not None:
162 | image[...,0] = (image[...,0]/255 - self.mean[0]) / self.std[0]
163 | image[...,1] = (image[...,1]/255 - self.mean[1]) / self.std[1]
164 | image[...,2] = (image[...,2]/255 - self.mean[2]) / self.std[2]
165 | else:
166 | image /= 255.0
167 | sample[key] = image
168 | return sample
169 | 
170 | class Multiscale(object):
171 | def __init__(self, rate_list):
172 | self.rate_list = rate_list
173 | 
174 | def __call__(self, sample):
175 | image = sample['image']
176 | row, col, _ = image.shape
177 | 
178 | for rate in self.rate_list:
179 | rescaled_image = cv2.resize(image, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC)
180 | sample['image_%f'%rate] = rescaled_image
181 | return sample
182 | 
183 | 
184 | class ToTensor(object):
185 | """Convert ndarrays in sample to Tensors."""
186 | 
187 | def __call__(self, sample):
188 | key_list = sample.keys()
189 | for key in key_list:
190 | if 'image' in key:
191 | image = sample[key].astype(np.float32)
192 | # swap color axis because
193 | # numpy image: H x W x C
194 | # torch image: C X H X W
195 | image = image.transpose((2,0,1))
196 | sample[key] = torch.from_numpy(image)
197 | #sample[key] = torch.from_numpy(image.astype(np.float32)/128.0-1.0)
198 | elif 'edge' == key:
199 | edge = sample['edge']
200 | sample['edge'] = torch.from_numpy(edge.astype(np.float32))
201 | sample['edge'] = torch.unsqueeze(sample['edge'],0)
202 | elif 'segmentation' == key:
203 | segmentation = sample['segmentation']
204 | sample['segmentation'] = torch.from_numpy(segmentation.astype(np.int64))  # np.long was removed in NumPy 1.24; int64 matches torch.long
205 | elif 'segmentation_pseudo' in key:
206 | segmentation_pseudo = sample[key]
207 | sample[key] = torch.from_numpy(segmentation_pseudo.astype(np.float32))
208 | elif 'segmentation_onehot' == key:
209 | onehot = sample['segmentation_onehot'].transpose((2,0,1))
210 | sample['segmentation_onehot'] = torch.from_numpy(onehot.astype(np.float32))
211 | elif 'category' in key:
212 | sample[key] = torch.from_numpy(sample[key].astype(np.float32))
213 | elif 'mask' == key:
214 | mask = sample['mask']
215 | sample['mask'] = torch.from_numpy(mask.astype(np.float32))
216 | elif 'feature' 
== key: 217 | feature = sample['feature'] 218 | sample['feature'] = torch.from_numpy(feature.astype(np.float32)) 219 | return sample 220 | 221 | -------------------------------------------------------------------------------- /segmentation/lib/net/__init__.py: -------------------------------------------------------------------------------- 1 | from .deeplabv1 import * 2 | from .deeplabv2 import * 3 | from .deeplabv3 import * 4 | from .deeplabv3plus import * 5 | from .backbone import * 6 | -------------------------------------------------------------------------------- /segmentation/lib/net/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_backbone 2 | from .resnet38d import * 3 | from .resnet import * 4 | from .xception import * 5 | 6 | __all__ = ['build_backbone'] 7 | -------------------------------------------------------------------------------- /segmentation/lib/net/backbone/builder.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | from utils.registry import BACKBONES 6 | 7 | def build_backbone(backbone_name, pretrained=True, **kwargs): 8 | net = BACKBONES.get(backbone_name)(pretrained=pretrained, **kwargs) 9 | return net 10 | -------------------------------------------------------------------------------- /segmentation/lib/net/deeplabv1.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from net.backbone import build_backbone 8 | from utils.registry import NETS 9 | 10 | @NETS.register_module 11 | class deeplabv1(nn.Module): 12 | def __init__(self, cfg, batchnorm=nn.BatchNorm2d, **kwargs): 13 | super(deeplabv1, self).__init__() 14 | self.cfg = cfg 15 | self.batchnorm = batchnorm 16 | #self.backbone = build_backbone(self.cfg.MODEL_BACKBONE, os=self.cfg.MODEL_OUTPUT_STRIDE) 17 | self.backbone = build_backbone(self.cfg.MODEL_BACKBONE, pretrained=cfg.MODEL_BACKBONE_PRETRAIN, norm_layer=self.batchnorm, **kwargs) 18 | self.conv_fov = nn.Conv2d(self.backbone.OUTPUT_DIM, 512, 3, 1, padding=12, dilation=12, bias=False) 19 | self.bn_fov = batchnorm(512, momentum=cfg.TRAIN_BN_MOM, affine=True) 20 | self.conv_fov2 = nn.Conv2d(512, 512, 1, 1, padding=0, bias=False) 21 | self.bn_fov2 = batchnorm(512, momentum=cfg.TRAIN_BN_MOM, affine=True) 22 | self.dropout1 = nn.Dropout(0.5) 23 | self.cls_conv = nn.Conv2d(512, cfg.MODEL_NUM_CLASSES, 1, 1, padding=0) 24 | self.__initial__() 25 | self.not_training = []#[self.backbone.conv1a, self.backbone.b2, self.backbone.b2_1, self.backbone.b2_2] 26 | #self.from_scratch_layers = [self.cls_conv] 27 | self.from_scratch_layers = [self.conv_fov, self.conv_fov2, self.cls_conv] 28 | 29 | def __initial__(self): 30 | for m in self.modules(): 31 | if m not in self.backbone.modules(): 32 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): 33 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 34 | elif isinstance(m, self.batchnorm): 35 | nn.init.constant_(m.weight, 1) 36 | nn.init.constant_(m.bias, 0) 37 | #self.backbone = build_backbone(self.cfg.MODEL_BACKBONE, pretrained=self.cfg.MODEL_BACKBONE_PRETRAIN) 38 | 39 | def forward(self, x): 40 | n,c,h,w = x.size() 41 | x_bottom = 
self.backbone(x)[-1] 42 | feature = self.conv_fov(x_bottom) 43 | feature = self.bn_fov(feature) 44 | feature = F.relu(feature, inplace=True) 45 | feature = self.conv_fov2(feature) 46 | feature = self.bn_fov2(feature) 47 | feature = F.relu(feature, inplace=True) 48 | feature = self.dropout1(feature) 49 | result = self.cls_conv(feature) 50 | result = F.interpolate(result,(h,w),mode='bilinear', align_corners=True) 51 | return result 52 | 53 | def get_parameter_groups(self): 54 | groups = ([], [], [], []) 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d): 57 | if m.weight.requires_grad: 58 | if m in self.from_scratch_layers: 59 | groups[2].append(m.weight) 60 | else: 61 | groups[0].append(m.weight) 62 | 63 | if m.bias is not None and m.bias.requires_grad: 64 | 65 | if m in self.from_scratch_layers: 66 | groups[3].append(m.bias) 67 | else: 68 | groups[1].append(m.bias) 69 | return groups 70 | 71 | @NETS.register_module 72 | class deeplabv1_caffe(nn.Module): 73 | def __init__(self, cfg, batchnorm=nn.BatchNorm2d, **kwargs): 74 | super(deeplabv1_caffe, self).__init__() 75 | self.cfg = cfg 76 | self.batchnorm = batchnorm 77 | self.backbone = build_backbone(self.cfg.MODEL_BACKBONE, os=self.cfg.MODEL_OUTPUT_STRIDE) 78 | outdim = 4096 79 | self.maxpool = nn.MaxPool2d(3, stride=1, padding=1) 80 | #self.avgpool = nn.AvgPool2d(3, stride=1, padding=1) 81 | self.conv_fov = nn.Conv2d(self.backbone.OUTPUT_DIM, outdim, 3, 1, padding=12, dilation=12) 82 | self.dropout1 = nn.Dropout(0.5) 83 | self.conv_fov2 = nn.Conv2d(outdim, outdim, 1, 1, padding=0) 84 | self.dropout2 = nn.Dropout(0.5) 85 | self.cls_conv = nn.Conv2d(outdim, cfg.MODEL_NUM_CLASSES, 1, 1, padding=0) 86 | self.__initial__() 87 | self.not_training = [] 88 | self.from_scratch_layers = [self.cls_conv] 89 | 90 | def __initial__(self): 91 | for m in self.modules(): 92 | if m not in self.backbone.modules(): 93 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): 94 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 95 | if m.bias is not None: 96 | nn.init.constant_(m.bias, 0) 97 | elif isinstance(m, self.batchnorm): 98 | nn.init.constant_(m.weight, 1) 99 | nn.init.constant_(m.bias, 0) 100 | 101 | def forward(self, x): 102 | n,c,h,w = x.size() 103 | x_bottom = self.backbone(x)[-1] 104 | feature = self.maxpool(x_bottom) 105 | #feature = self.avgpool(feature) 106 | feature = F.relu(self.conv_fov(feature), inplace=True) 107 | feature = self.dropout1(feature) 108 | feature = F.relu(self.conv_fov2(feature), inplace=True) 109 | feature = self.dropout2(feature) 110 | result = self.cls_conv(feature) 111 | result = F.interpolate(result,(h,w),mode='bilinear', align_corners=True) 112 | return result 113 | 114 | def get_parameter_groups(self): 115 | groups = ([], [], [], []) 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | if m.weight.requires_grad: 119 | if m in self.from_scratch_layers: 120 | groups[2].append(m.weight) 121 | else: 122 | groups[0].append(m.weight) 123 | 124 | if m.bias is not None and m.bias.requires_grad: 125 | 126 | if m in self.from_scratch_layers: 127 | groups[3].append(m.bias) 128 | else: 129 | groups[1].append(m.bias) 130 | return groups 131 | -------------------------------------------------------------------------------- /segmentation/lib/net/deeplabv2.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | import numpy 
as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.nn import init
10 | from net.backbone import build_backbone
11 | from net.operators import ASPP
12 | from utils.registry import NETS
13 | class _deeplabv2(nn.Module):
14 | def __init__(self, cfg, batchnorm=nn.BatchNorm2d):
15 | super(_deeplabv2, self).__init__()
16 | self.cfg = cfg  # must be assigned before build_backbone below, which reads self.cfg
17 | self.backbone = build_backbone(self.cfg.MODEL_BACKBONE, os=self.cfg.MODEL_OUTPUT_STRIDE)
18 | input_channel = self.backbone.OUTPUT_DIM
19 | self.aspp = ASPP(dim_in=input_channel,
20 | dim_out=cfg.MODEL_ASPP_OUTDIM,
21 | rate=[6,12,18,24],
22 | bn_mom = cfg.TRAIN_BN_MOM,
23 | has_global=cfg.MODEL_ASPP_HASGLOBAL,
24 | batchnorm=batchnorm
25 | )
26 | self.batchnorm = batchnorm
27 | def __initial__(self):
28 | for m in self.modules():
29 | if m not in self.backbone.modules():
30 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
31 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
32 | elif isinstance(m, self.batchnorm):
33 | nn.init.constant_(m.weight, 1)
34 | nn.init.constant_(m.bias, 0)
35 | 
36 | def forward(self, x):
37 | raise NotImplementedError
38 | 
39 | @NETS.register_module
40 | class deeplabv2(_deeplabv2):
41 | def __init__(self, cfg, batchnorm=nn.BatchNorm2d, **kwargs):
42 | super(deeplabv2, self).__init__(cfg, batchnorm)
43 | self.dropout1 = nn.Dropout(0.5)
44 | self.cls_conv = nn.Conv2d(cfg.MODEL_ASPP_OUTDIM, cfg.MODEL_NUM_CLASSES, 1, 1, padding=0)
45 | self.__initial__()
46 | self.from_scratch_layers = [self.cls_conv]
47 | for m in self.aspp.modules():
48 | if isinstance(m, nn.Conv2d):
49 | self.from_scratch_layers.append(m)
50 | 
51 | def forward(self, x):
52 | n,c,h,w = x.size()
53 | x_bottom = self.backbone(x)[-1]
54 | feature = self.aspp(x_bottom)
55 | feature = self.dropout1(feature)
56 | result = self.cls_conv(feature)
57 | result = F.interpolate(result,(h,w),mode='bilinear', align_corners=True)
58 | 
59 | return result
60 | 
61 | def get_parameter_groups(self):
62 | groups = ([], [], [], [])
63 | for m in self.modules():
64 | if isinstance(m, nn.Conv2d):
65 | if m.weight.requires_grad:
66 | if m in self.from_scratch_layers:
67 | groups[2].append(m.weight)
68 | 
69 | else:
70 | groups[0].append(m.weight)
71 | 
72 | if m.bias is not None and m.bias.requires_grad:
73 | if m in self.from_scratch_layers:
74 | groups[3].append(m.bias)
75 | else:
76 | groups[1].append(m.bias)
77 | return groups
78 | 
79 | 
-------------------------------------------------------------------------------- /segmentation/lib/net/deeplabv3.py: --------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 | 
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.nn import init
10 | from net.backbone import build_backbone
11 | from net.operators import ASPP
12 | from utils.registry import NETS
13 | 
14 | class _deeplabv3(nn.Module):
15 | def __init__(self, cfg, batchnorm=nn.BatchNorm2d, **kwargs):
16 | super(_deeplabv3, self).__init__()
17 | self.cfg = cfg
18 | self.backbone = build_backbone(cfg.MODEL_BACKBONE, pretrained=cfg.MODEL_BACKBONE_PRETRAIN, **kwargs)
19 | self.batchnorm = batchnorm
20 | input_channel = self.backbone.OUTPUT_DIM
21 | self.aspp = ASPP(dim_in=input_channel,
22 | dim_out=cfg.MODEL_ASPP_OUTDIM,
23 | rate=[0, 6, 12, 18],
24 | bn_mom = cfg.TRAIN_BN_MOM,
25 | has_global = 
cfg.MODEL_ASPP_HASGLOBAL, 26 | batchnorm = self.batchnorm) 27 | def __initial__(self): 28 | for m in self.modules(): 29 | if m not in self.backbone.modules(): 30 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): 31 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 32 | elif isinstance(m, self.batchnorm): 33 | nn.init.constant_(m.weight, 1) 34 | nn.init.constant_(m.bias, 0) 35 | 36 | def forward(self, x): 37 | raise NotImplementedError 38 | 39 | @NETS.register_module 40 | class deeplabv3(_deeplabv3): 41 | def __init__(self, cfg, **kwargs): 42 | super(deeplabv3, self).__init__(cfg, **kwargs) 43 | self.cls_conv = nn.Conv2d(cfg.MODEL_ASPP_OUTDIM, cfg.MODEL_NUM_CLASSES, 1, 1, padding=0) 44 | self.__initial__() 45 | 46 | def forward(self, x): 47 | n,c,h,w = x.size() 48 | x_bottom = self.backbone(x)[-1] 49 | feature = self.aspp(x_bottom) 50 | result = self.cls_conv(feature) 51 | result = F.interpolate(result,(h,w),mode='bilinear', align_corners=True) 52 | return result 53 | 54 | -------------------------------------------------------------------------------- /segmentation/lib/net/generateNet.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | #from net.deeplabv3plus import deeplabv3plus 6 | #from net.deeplabv3 import deeplabv3, deeplabv3_noise, deeplabv3_feature, deeplabv3_glore 7 | #from net.deeplabv2 import deeplabv2, deeplabv2_caffe 8 | #from net.deeplabv1 import deeplabv1, deeplabv1_caffe 9 | #from net.clsnet import ClsNet 10 | #from net.fcn import FCN 11 | #from net.DFANet import DFANet 12 | from utils.registry import NETS 13 | 14 | def generate_net(cfg, **kwargs): 15 | net = NETS.get(cfg.MODEL_NAME)(cfg, **kwargs) 16 | return net 17 | #def generate_net(cfg): 18 | # if cfg.MODEL_NAME == 'deeplabv3plus' or cfg.MODEL_NAME == 'deeplabv3+': 19 | # return deeplabv3plus(cfg) 20 | # elif cfg.MODEL_NAME == 'deeplabv3': 21 | # return deeplabv3(cfg) 22 | # elif cfg.MODEL_NAME == 'deeplabv2': 23 | # return deeplabv2(cfg) 24 | # elif cfg.MODEL_NAME == 'deeplabv1': 25 | # return deeplabv1(cfg) 26 | # elif cfg.MODEL_NAME == 'deeplabv1_caffe': 27 | # return deeplabv1_caffe(cfg) 28 | # elif cfg.MODEL_NAME == 'deeplabv2_caffe': 29 | # return deeplabv2_caffe(cfg) 30 | # elif cfg.MODEL_NAME == 'clsnet' or cfg.MODEL_NAME == 'ClsNet': 31 | # return ClsNet(cfg) 32 | # elif cfg.MODEL_NAME == 'fcn' or cfg.MODEL_NAME == 'FCN': 33 | # return FCN(cfg) 34 | # elif cfg.MODEL_NAME == 'DFANet' or cfg.MODEL_NAME == 'dfanet': 35 | # return DFANet(cfg) 36 | # else: 37 | # raise ValueError('generateNet.py: network %s is not support yet'%cfg.MODEL_NAME) 38 | -------------------------------------------------------------------------------- /segmentation/lib/net/operators/ASPP.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from net.sync_batchnorm import SynchronizedBatchNorm2d 10 | 11 | class ASPP(nn.Module): 12 | 13 | def __init__(self, dim_in, dim_out, rate=[1,6,12,18], bn_mom=0.1, has_global=True, batchnorm=SynchronizedBatchNorm2d): 14 | super(ASPP, self).__init__() 15 | self.dim_in = dim_in 16 | self.dim_out = dim_out 17 | self.has_global = has_global 18 | if rate[0] == 
0: 19 | self.branch1 = nn.Sequential( 20 | nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=1,bias=False), 21 | batchnorm(dim_out, momentum=bn_mom, affine=True), 22 | nn.ReLU(inplace=True), 23 | ) 24 | else: 25 | self.branch1 = nn.Sequential( 26 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[0], dilation=rate[0],bias=False), 27 | batchnorm(dim_out, momentum=bn_mom, affine=True), 28 | nn.ReLU(inplace=True), 29 | ) 30 | self.branch2 = nn.Sequential( 31 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[1], dilation=rate[1],bias=False), 32 | batchnorm(dim_out, momentum=bn_mom, affine=True), 33 | nn.ReLU(inplace=True), 34 | ) 35 | self.branch3 = nn.Sequential( 36 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[2], dilation=rate[2],bias=False), 37 | batchnorm(dim_out, momentum=bn_mom, affine=True), 38 | nn.ReLU(inplace=True), 39 | ) 40 | self.branch4 = nn.Sequential( 41 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[3], dilation=rate[3],bias=False), 42 | batchnorm(dim_out, momentum=bn_mom, affine=True), 43 | nn.ReLU(inplace=True), 44 | ) 45 | if self.has_global: 46 | self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0,bias=False) 47 | self.branch5_bn = batchnorm(dim_out, momentum=bn_mom, affine=True) 48 | self.branch5_relu = nn.ReLU(inplace=True) 49 | self.conv_cat = nn.Sequential( 50 | nn.Conv2d(dim_out*5, dim_out, 1, 1, padding=0,bias=False), 51 | batchnorm(dim_out, momentum=bn_mom, affine=True), 52 | nn.ReLU(inplace=True), 53 | nn.Dropout(0.5) 54 | ) 55 | else: 56 | self.conv_cat = nn.Sequential( 57 | nn.Conv2d(dim_out*4, dim_out, 1, 1, padding=0), 58 | batchnorm(dim_out, momentum=bn_mom, affine=True), 59 | nn.ReLU(inplace=True), 60 | nn.Dropout(0.5) 61 | ) 62 | def forward(self, x): 63 | result = None 64 | [b,c,row,col] = x.size() 65 | conv1x1 = self.branch1(x) 66 | conv3x3_1 = self.branch2(x) 67 | conv3x3_2 = self.branch3(x) 68 | conv3x3_3 = self.branch4(x) 69 | if self.has_global: 70 | global_feature = F.adaptive_avg_pool2d(x, (1,1)) 71 | global_feature = self.branch5_conv(global_feature) 72 | global_feature = self.branch5_bn(global_feature) 73 | global_feature = self.branch5_relu(global_feature) 74 | global_feature = F.interpolate(global_feature, (row,col), None, 'bilinear', align_corners=True) 75 | 76 | feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1) 77 | else: 78 | feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3], dim=1) 79 | result = self.conv_cat(feature_cat) 80 | 81 | return result 82 | -------------------------------------------------------------------------------- /segmentation/lib/net/operators/PPM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class PPM(nn.Module): 6 | """ 7 | Reference: 8 | Zhao, Hengshuang, et al. 
*"Pyramid scene parsing network."* 9 | """ 10 | def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6), norm_layer=nn.BatchNorm2d): 11 | super(PPM, self).__init__() 12 | 13 | self.stages = [] 14 | self.stages = nn.ModuleList([self._make_stage(features, out_features, size, norm_layer) for size in sizes]) 15 | self.bottleneck = nn.Sequential( 16 | nn.Conv2d(features+len(sizes)*out_features, out_features, kernel_size=1, padding=0, dilation=1, bias=False), 17 | norm_layer(out_features), 18 | nn.ReLU(), 19 | nn.Dropout2d(0.1) 20 | ) 21 | 22 | def _make_stage(self, features, out_features, size, norm_layer): 23 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 24 | conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False) 25 | bn = norm_layer(out_features) 26 | return nn.Sequential(prior, conv, bn) 27 | 28 | def forward(self, feats): 29 | h, w = feats.size(2), feats.size(3) 30 | priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] + [feats] 31 | bottle = self.bottleneck(torch.cat(priors, 1)) 32 | return bottle 33 | -------------------------------------------------------------------------------- /segmentation/lib/net/operators/__init__.py: -------------------------------------------------------------------------------- 1 | from .ASPP import * 2 | from .PPM import * 3 | -------------------------------------------------------------------------------- /segmentation/lib/net/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /segmentation/lib/net/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 
29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register an slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually is the device id. 84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | The messages were first collected from each devices (including the master device), and then 100 | an callback will be invoked to compute the message to be sent back to each devices 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belongs to the master.' 
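# Protocol sketch: `intermediates` is [(0, master_msg), (id, msg), ...] in
# arrival order, and the callback must return one (id, result) pair per
# device. The loop below hands each slave its result through its
# FutureResult; the final queue round acts as a barrier, since
# SlavePipe.run_slave() puts True back once it has read its result.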
118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /segmentation/lib/net/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /segmentation/lib/net/sync_batchnorm/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /segmentation/lib/net/sync_batchnorm/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNormReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /segmentation/lib/net/sync_batchnorm/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 
58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /segmentation/lib/net/sync_batchnorm/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
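# How a module opts into the replication callback defined below (a
# hypothetical sketch; `MyModule` and its fields are illustrative only,
# and the usual torch.nn import is assumed):
#   class MyModule(nn.Module):
#       def __data_parallel_replicate__(self, ctx, copy_id):
#           # ctx is shared by all replicas of this submodule;
#           # copy_id 0 is the master copy and is called first.
#           self._is_master = (copy_id == 0)
#           self._shared_ctx = ctx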
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /segmentation/lib/net/sync_batchnorm/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /segmentation/lib/net/sync_batchnorm/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /segmentation/lib/net/sync_batchnorm/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 
9 | import unittest
10 | 
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 | 
15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
16 | from sync_batchnorm.unittest import TorchTestCase
17 | 
18 | 
19 | def handy_var(a, unbias=True):
20 |     n = a.size(0)
21 |     asum = a.sum(dim=0)
22 |     as_sum = (a ** 2).sum(dim=0)  # sum of a squared
23 |     sumvar = as_sum - asum * asum / n
24 |     if unbias:
25 |         return sumvar / (n - 1)
26 |     else:
27 |         return sumvar / n
28 | 
29 | 
30 | def _find_bn(module):
31 |     for m in module.modules():
32 |         if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
33 |             return m
34 | 
35 | 
36 | class SyncTestCase(TorchTestCase):
37 |     def _syncParameters(self, bn1, bn2):
38 |         bn1.reset_parameters()
39 |         bn2.reset_parameters()
40 |         if bn1.affine and bn2.affine:
41 |             bn2.weight.data.copy_(bn1.weight.data)
42 |             bn2.bias.data.copy_(bn1.bias.data)
43 | 
44 |     def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
45 |         """Check the forward and backward passes of the customized batch normalization."""
46 |         bn1.train(mode=is_train)
47 |         bn2.train(mode=is_train)
48 | 
49 |         if cuda:
50 |             input = input.cuda()
51 | 
52 |         self._syncParameters(_find_bn(bn1), _find_bn(bn2))
53 | 
54 |         input1 = Variable(input, requires_grad=True)
55 |         output1 = bn1(input1)
56 |         output1.sum().backward()
57 |         input2 = Variable(input, requires_grad=True)
58 |         output2 = bn2(input2)
59 |         output2.sum().backward()
60 | 
61 |         self.assertTensorClose(input1.data, input2.data)
62 |         self.assertTensorClose(output1.data, output2.data)
63 |         self.assertTensorClose(input1.grad, input2.grad)
64 |         self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
65 |         self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
66 | 
67 |     def testSyncBatchNormNormalTrain(self):
68 |         bn = nn.BatchNorm1d(10)
69 |         sync_bn = SynchronizedBatchNorm1d(10)
70 | 
71 |         self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
72 | 
73 |     def testSyncBatchNormNormalEval(self):
74 |         bn = nn.BatchNorm1d(10)
75 |         sync_bn = SynchronizedBatchNorm1d(10)
76 | 
77 |         self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
78 | 
79 |     def testSyncBatchNormSyncTrain(self):
80 |         bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
81 |         sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
82 |         sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
83 | 
84 |         bn.cuda()
85 |         sync_bn.cuda()
86 | 
87 |         self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
88 | 
89 |     def testSyncBatchNormSyncEval(self):
90 |         bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
91 |         sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
92 |         sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
93 | 
94 |         bn.cuda()
95 |         sync_bn.cuda()
96 | 
97 |         self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
98 | 
99 |     def testSyncBatchNorm2DSyncTrain(self):
100 |         bn = nn.BatchNorm2d(10)
101 |         sync_bn = SynchronizedBatchNorm2d(10)
102 |         sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
103 | 
104 |         bn.cuda()
105 |         sync_bn.cuda()
106 | 
107 |         self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108 | 
109 | 
110 | if __name__ == '__main__':
111 |     unittest.main()
112 | 
--------------------------------------------------------------------------------
/segmentation/lib/net/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import unittest
12 | 
13 | import numpy as np
14 | from torch.autograd import Variable
15 | 
16 | 
17 | def as_numpy(v):
18 |     if isinstance(v, Variable):
19 |         v = v.data
20 |     return v.cpu().numpy()
21 | 
22 | 
23 | class TorchTestCase(unittest.TestCase):
24 |     def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 |         npa, npb = as_numpy(a), as_numpy(b)
26 |         self.assertTrue(
27 |             np.allclose(npa, npb, atol=atol, rtol=rtol),  # honour the rtol argument as well
28 |             'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 |         )
30 | 
--------------------------------------------------------------------------------
/segmentation/lib/utils/DenseCRF.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pydensecrf.densecrf as dcrf
3 | from pydensecrf.utils import unary_from_softmax
4 | 
5 | def dense_crf(probs, img=None, n_classes=21, n_iters=1, scale_factor=1):
6 |     #probs = np.transpose(probs,(1,2,0)).copy(order='C')
7 |     c,h,w = probs.shape
8 | 
9 |     if img is not None:
10 |         assert(img.shape[1:3] == (h, w))
11 |         img = np.transpose(img,(1,2,0)).copy(order='C')
12 | 
13 |     #probs = probs.transpose(2, 0, 1).copy(order='C') # Need a contiguous array.
14 | 
15 |     d = dcrf.DenseCRF2D(w, h, n_classes) # Define DenseCRF model.
16 | 
17 |     unary = unary_from_softmax(probs)
18 |     unary = np.ascontiguousarray(unary)
19 |     d.setUnaryEnergy(unary)
20 |     d.addPairwiseGaussian(sxy=3/scale_factor, compat=3)
21 |     #d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img), compat=10)
22 |     d.addPairwiseBilateral(sxy=32/scale_factor, srgb=13, rgbim=np.copy(img), compat=10)
23 |     Q = d.inference(n_iters)
24 | 
25 |     # U = -np.log(probs) # Unary potential.
26 |     # U = U.reshape((n_classes, -1)) # Needs to be flat.
27 |     # d.setUnaryEnergy(U)
28 |     # d.addPairwiseGaussian(sxy=sxy_gaussian, compat=compat_gaussian,
29 |     #                       kernel=kernel_gaussian, normalization=normalisation_gaussian)
30 |     # if img is not None:
31 |     #     assert(img.shape[1:3] == (h, w))
32 |     #     img = np.transpose(img,(1,2,0)).copy(order='C')
33 |     #     d.addPairwiseBilateral(sxy=sxy_bilateral, compat=compat_bilateral,
34 |     #                            kernel=kernel_bilateral, normalization=normalisation_bilateral,
35 |     #                            srgb=srgb_bilateral, rgbim=img)
36 |     #     Q = d.inference(n_iters)
37 |     preds = np.array(Q, dtype=np.float32).reshape((n_classes, h, w))
38 |     #return np.expand_dims(preds, 0)
39 |     return preds
40 | 
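# Illustrative usage sketch: `dense_crf` expects softmax-normalized class
# scores of shape (C, H, W) and, optionally, a CHW uint8 image that drives the
# bilateral (appearance) kernel. The shapes and values below are made up.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    probs = rng.rand(21, 64, 64).astype(np.float32)
    probs /= probs.sum(axis=0, keepdims=True)      # make each pixel a distribution
    img = rng.randint(0, 256, (3, 64, 64)).astype(np.uint8)
    refined = dense_crf(probs, img, n_classes=21, n_iters=1)
    assert refined.shape == (21, 64, 64)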
41 | def pro_crf(p, img, itr):
42 |     C, H, W = p.shape
43 |     p_bg = 1-p
44 |     for i in range(C):
45 |         cat = np.stack([p[i,:,:], p_bg[i,:,:]], axis=0)  # (2, H, W): class-i probability and its complement (np.concatenate on 2-D slices would give (2H, W))
46 |         crf_pro = dense_crf(cat, img.astype(np.uint8), n_classes=2, n_iters=itr)  # two labels: class i / not class i
47 |         p[i,:,:] = crf_pro[0]
48 |     return p
49 | 
--------------------------------------------------------------------------------
/segmentation/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .registry import DATASETS, BACKBONES, NETS
2 | 
3 | __all__ = ['DATASETS', 'BACKBONES', 'NETS']
4 | 
--------------------------------------------------------------------------------
/segmentation/lib/utils/configuration.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 | import torch
5 | import os
6 | import sys
7 | import shutil
8 | 
9 | class Configuration():
10 |     def __init__(self, config_dict, clear=True):
11 |         self.__dict__ = config_dict
12 |         self.clear = clear
13 |         self.__check()
14 | 
15 |     def __check(self):
16 |         if not torch.cuda.is_available():
17 |             raise ValueError('config.py: cuda is not available')
18 |         if self.GPUS == 0:
19 |             raise ValueError('config.py: the number of GPUs is 0')
20 |         if self.GPUS != torch.cuda.device_count():
21 |             raise ValueError('config.py: GPU number does not match torch.cuda.device_count()')
22 |         if not os.path.isdir(self.LOG_DIR):
23 |             os.makedirs(self.LOG_DIR)
24 |         elif self.clear:
25 |             shutil.rmtree(self.LOG_DIR)
26 |             os.mkdir(self.LOG_DIR)
27 |         if not os.path.isdir(self.MODEL_SAVE_DIR):
28 |             os.makedirs(self.MODEL_SAVE_DIR)
29 | 
30 | 
31 | 
32 | 
33 | 
34 | 
--------------------------------------------------------------------------------
/segmentation/lib/utils/finalprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | def writelog(cfg, period, metric=None, commit=''):
4 |     filepath = os.path.join(cfg.ROOT_DIR,'log','logfile.txt')
5 |     logfile = open(filepath,'a')
6 |     import time
7 |     logfile.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
8 |     logfile.write('\t%s\n'%period)
9 |     para_data_dict = {}
10 |     para_model_dict = {}
11 |     para_train_dict = {}
12 |     para_test_dict = {}
13 |     para_name = dir(cfg)
14 |     for name in para_name:
15 |         if 'DATA_' in name:
16 |             v = getattr(cfg,name)
17 |             para_data_dict[name] = v
18 |         elif 'MODEL_' in name:
19 |             v = getattr(cfg,name)
20 |             para_model_dict[name] = v
21 |         elif 'TRAIN_' in name:
22 |             v = getattr(cfg,name)
23 |             para_train_dict[name] = v
24 |         elif 'TEST_' in name:
25 |             v = getattr(cfg,name)
26 |             para_test_dict[name] = v
27 |     writedict(logfile, {'EXP_NAME': cfg.EXP_NAME})
28 |     writedict(logfile, para_data_dict)
29 |     writedict(logfile, para_model_dict)
30 |     if 'train' in period:
31 |         writedict(logfile, para_train_dict)
32 |     else:
33 |         writedict(logfile, para_test_dict)
34 |         if metric is not None: writedict(logfile, metric)  # metric defaults to None
35 | 
36 |     logfile.write(commit)
37 |     logfile.write('=====================================\n')
38 |     logfile.close()
39 | 
40 | def writedict(file, dictionary):
41 |     s = ''
42 |     for key in dictionary.keys():
43 |         sub = '%s:%s '%(key, dictionary[key])
44 |         s += sub
45 |     s += '\n'
46 |     file.write(s)
47 | 
48 | 
--------------------------------------------------------------------------------
/segmentation/lib/utils/imutils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | 
4 | def pseudo_erode(label, num, t=1):
5 |     label_onehot = onehot(label, num)
6 |     k = np.ones((15,15),np.uint8)
7 |     e = cv2.erode(label_onehot, k, iterations=t)  # keyword arg: cv2.erode's third positional parameter is dst, not iterations
8 |     m = (e != label_onehot)
9 |     m = np.max(m, axis=2)
10 |     label[m] = 255
11 |     return label
12 | 
13 | 
14 | def onehot(label, num):
15 |     num = int(num)
16 |     m = label.astype(np.int32)
17 |     one_hot = np.eye(num)[m]
18 |     return one_hot
19 | 
20 | def seg2cls(label, num):
21 |     cls = np.zeros(num)
22 |     index = np.unique(label)
23 |     cls[index] = 1
24 |     #cls[0] = 0
25 |     cls = cls.reshape((num,1,1))
26 |     return cls
27 | 
28 | def gamma_correction(img):
29 |     gamma = np.mean(img)/128.0
30 |     lookUpTable = np.empty((1,256), np.uint8)
31 |     for i in range(256):
32 |         lookUpTable[0,i] = np.clip(pow(i / 255.0, gamma) * 255.0, 0, 255)
33 |     res_img = cv2.LUT(img, lookUpTable)
34 |     return res_img
35 | 
36 | def img_denorm(inputs, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), mul=True):
37 |     inputs = np.ascontiguousarray(inputs)
38 |     if inputs.ndim == 3:
39 |         inputs[0,:,:] = (inputs[0,:,:]*std[0] + mean[0])
40 |         inputs[1,:,:] = (inputs[1,:,:]*std[1] + mean[1])
41 |         inputs[2,:,:] = (inputs[2,:,:]*std[2] + mean[2])
42 |     elif inputs.ndim == 4:
43 |         n = inputs.shape[0]
44 |         for i in range(n):
45 |             inputs[i,0,:,:] = (inputs[i,0,:,:]*std[0] + mean[0])
46 |             inputs[i,1,:,:] = (inputs[i,1,:,:]*std[1] + mean[1])
47 |             inputs[i,2,:,:] = (inputs[i,2,:,:]*std[2] + mean[2])
48 | 
49 |     if mul:
50 |         inputs = inputs*255
51 |         inputs[inputs > 255] = 255
52 |         inputs[inputs < 0] = 0
53 |         inputs = inputs.astype(np.uint8)
54 |     else:
55 |         inputs[inputs > 1] = 1
56 |         inputs[inputs < 0] = 0
57 |     return inputs
58 | 
--------------------------------------------------------------------------------
/segmentation/lib/utils/registry.py:
--------------------------------------------------------------------------------
1 | 
2 | class Registry(object):
3 |     def __init__(self, name):
4 |         super(Registry, self).__init__()
5 |         self._name = name
6 |         self._module_dict = dict()
7 | 
8 |     @property
9 |     def name(self):
10 |         return self._name
11 | 
12 |     @property
13 |     def module_dict(self):
14 |         return self._module_dict
15 | 
16 |     def __len__(self):
17 |         return len(self.module_dict)
18 | 
19 |     def get(self, key):
20 |         return self._module_dict[key]
21 | 
22 |     def register_module(self, module=None):
23 |         if module is None:
24 |             raise TypeError('failed to register None in Registry {}'.format(self.name))
25 |         module_name = module.__name__
26 |         if module_name in self._module_dict:
27 |             raise KeyError('{} is already registered in Registry {}'.format(module_name, self.name))
28 |         self._module_dict[module_name] = module
29 |         return module
30 | 
31 | DATASETS = Registry('dataset')
32 | BACKBONES = Registry('backbone')
33 | NETS = Registry('nets')
34 | 
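# Illustrative usage sketch: elsewhere in this repo, datasets, backbones and
# nets register themselves by decoration and are later looked up by class name.
# `DummyDataset` is a made-up example class.
if __name__ == '__main__':
    @DATASETS.register_module
    class DummyDataset(object):
        pass

    assert DATASETS.get('DummyDataset') is DummyDataset
    assert len(DATASETS) == 1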
--------------------------------------------------------------------------------
/segmentation/lib/utils/test_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from tqdm import tqdm
4 | 
5 | def single_gpu_test(model, dataloader, prepare_func, inference_func, collect_func, save_step_func=None):
6 |     model.eval()
7 |     n_gpus = torch.cuda.device_count()
8 |     #assert n_gpus == 1
9 |     collect_list = []
10 |     total_num = len(dataloader)
11 |     with tqdm(total=total_num) as pbar:
12 |         with torch.no_grad():
13 |             for i_batch, sample in enumerate(dataloader):
14 |                 name = sample['name']
15 |                 image_msf = prepare_func(sample)
16 |                 result_list = []
17 |                 for img in image_msf:
18 |                     result = inference_func(model, img.cuda())
19 |                     result_list.append(result)
20 |                 result_item = collect_func(result_list, sample)
21 |                 result_sample = {'predict': result_item, 'name':name[0]}
22 |                 #print('%d/%d'%(i_batch,len(dataloader)))
23 |                 pbar.set_description('Processing')
24 |                 pbar.update(1)
25 |                 time.sleep(0.001)
26 | 
27 |                 if save_step_func is not None:
28 |                     save_step_func(result_sample)
29 |                 else:
30 |                     collect_list.append(result_sample)
31 |     return collect_list
32 | 
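# Usage note (callback roles inferred from the loop above; the concrete
# callables are supplied by the caller, e.g. an experiment's test script):
#   prepare_func(sample)              -> list of augmented views (e.g. multi-scale
#                                        and flipped tensors) for one sample
#   inference_func(model, img)        -> model output for a single view (on GPU)
#   collect_func(result_list, sample) -> fused prediction for the sample
#   save_step_func(result_sample)     -> optional writer; when provided, each
#                                        result is saved immediately instead of
#                                        being accumulated into the returned list.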
--------------------------------------------------------------------------------
/segmentation/lib/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import cv2
5 | from utils.DenseCRF import *
6 | #from cv2.ximgproc import l0Smooth
7 | 
8 | def color_pro(pro, img=None, mode='hwc'):
9 |     H, W = pro.shape
10 |     pro_255 = (pro*255).astype(np.uint8)
11 |     pro_255 = np.expand_dims(pro_255,axis=2)
12 |     color = cv2.applyColorMap(pro_255,cv2.COLORMAP_JET)
13 |     color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
14 |     if img is not None:
15 |         rate = 0.5
16 |         if mode == 'hwc':
17 |             assert img.shape[0] == H and img.shape[1] == W
18 |             color = cv2.addWeighted(img,rate,color,1-rate,0)
19 |         elif mode == 'chw':
20 |             assert img.shape[1] == H and img.shape[2] == W
21 |             img = np.transpose(img,(1,2,0))
22 |             color = cv2.addWeighted(img,rate,color,1-rate,0)
23 |             color = np.transpose(color,(2,0,1))
24 |     else:
25 |         if mode == 'chw':
26 |             color = np.transpose(color,(2,0,1))
27 |     return color
28 | 
29 | def generate_vis(p, gt, img, func_label2color, threshold=0.1, norm=True, crf=False):
30 |     # All inputs should be numpy arrays
31 |     # img should be 0-255 uint8
32 |     C, H, W = p.shape
33 | 
34 |     if norm:
35 |         prob = max_norm(p, 'numpy')
36 |     else:
37 |         prob = p
38 |     if gt is not None:
39 |         prob = prob * gt
40 |     prob[prob<=0] = 1e-5
41 |     if threshold is not None:
42 |         prob[0,:,:] = np.power(1-np.max(prob[1:,:,:],axis=0,keepdims=True), 4)
43 | 
44 |     CLS = ColorCLS(prob, func_label2color)
45 |     CAM = ColorCAM(prob, img)
46 |     if crf:
47 |         prob_crf = dense_crf(prob, img, n_classes=C, n_iters=1)
48 |         CLS_crf = ColorCLS(prob_crf, func_label2color)
49 |         CAM_crf = ColorCAM(prob_crf, img)
50 |         return CLS, CAM, CLS_crf, CAM_crf
51 |     else:
52 |         return CLS, CAM
53 | 
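# Usage note (conventions inferred from the code above): `p` is a (C, H, W)
# float score map, `gt` an optional broadcastable class-presence mask such as
# (C, 1, 1), and `img` a (3, H, W) uint8 image as required by `dense_crf`.
# `func_label2color` converts a label map into a color image (e.g. the VOC
# palette). With crf=True the call returns four visualizations:
#     CLS, CAM, CLS_crf, CAM_crf = generate_vis(prob, cls, img, label2color, crf=True)
# where `label2color` is a caller-supplied palette function (hypothetical name).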
54 | def max_norm(p, version='torch', e=1e-5):
55 |     if version == 'torch':  # '==' rather than 'is' for string comparison
56 |         if p.dim() == 3:
57 |             C, H, W = p.size()
58 |             p = F.relu(p, inplace=True)
59 |             max_v = torch.max(p.view(C,-1),dim=-1)[0].view(C,1,1)
60 |             min_v = torch.min(p.view(C,-1),dim=-1)[0].view(C,1,1)
61 |             p = F.relu(p-min_v-e, inplace=True)/(max_v-min_v+e)
62 |         elif p.dim() == 4:
63 |             N, C, H, W = p.size()
64 |             p = F.relu(p, inplace=True)
65 |             max_v = torch.max(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1)
66 |             min_v = torch.min(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1)
67 |             p = F.relu(p-min_v-e, inplace=True)/(max_v-min_v+e)
68 |     elif version == 'numpy' or version == 'np':
69 |         if p.ndim == 3:
70 |             C, H, W = p.shape
71 |             p[p<0] = 0
--------------------------------------------------------------------------------
/tool/imutils.py:
--------------------------------------------------------------------------------
44 |         if w_space > 0:
45 |             cont_left = 0
46 |             img_left = random.randrange(w_space+1)
47 |         else:
48 |             cont_left = random.randrange(-w_space+1)
49 |             img_left = 0
50 | 
51 |         if h_space > 0:
52 |             cont_top = 0
53 |             img_top = random.randrange(h_space+1)
54 |         else:
55 |             cont_top = random.randrange(-h_space+1)
56 |             img_top = 0
57 | 
58 |         container = np.zeros((self.cropsize, self.cropsize, imgarr.shape[-1]), np.float32)
59 |         container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \
60 |             imgarr[img_top:img_top+ch, img_left:img_left+cw]
61 |         if sal is not None:
62 |             container_sal = np.zeros((self.cropsize, self.cropsize,1), np.float32)
63 |             container_sal[cont_top:cont_top+ch, cont_left:cont_left+cw,0] = \
64 |                 sal[img_top:img_top+ch, img_left:img_left+cw]
65 |             return container, container_sal
66 | 
67 |         return container
68 | 
69 | def get_random_crop_box(imgsize, cropsize):
70 |     h, w = imgsize
71 | 
72 |     ch = min(cropsize, h)
73 |     cw = min(cropsize, w)
74 | 
75 |     w_space = w - cropsize
76 |     h_space = h - cropsize
77 | 
78 |     if w_space > 0:
79 |         cont_left = 0
80 |         img_left = random.randrange(w_space + 1)
81 |     else:
82 |         cont_left = random.randrange(-w_space + 1)
83 |         img_left = 0
84 | 
85 |     if h_space > 0:
86 |         cont_top = 0
87 |         img_top = random.randrange(h_space + 1)
88 |     else:
89 |         cont_top = random.randrange(-h_space + 1)
90 |         img_top = 0
91 | 
92 |     return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw
93 | 
94 | def crop_with_box(img, box):
95 |     if len(img.shape) == 3:
96 |         img_cont = np.zeros((max(box[1]-box[0], box[5]-box[4]), max(box[3]-box[2], box[7]-box[6]), img.shape[-1]), dtype=img.dtype)  # box[5]-box[4]: was box[4]-box[5], which is negative
97 |     else:
98 |         img_cont = np.zeros((max(box[1] - box[0], box[5] - box[4]), max(box[3] - box[2], box[7] - box[6])), dtype=img.dtype)
99 |     img_cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]]
100 |     return img_cont
101 | 
102 | 
103 | def random_crop(images, cropsize, fills):
104 |     if isinstance(images[0], PIL.Image.Image):
105 |         imgsize = images[0].size[::-1]
106 |     else:
107 |         imgsize = images[0].shape[:2]
108 |     box = get_random_crop_box(imgsize, cropsize)
109 | 
110 |     new_images = []
111 |     for img, f in zip(images, fills):
112 | 
113 |         if isinstance(img, PIL.Image.Image):
114 |             img = img.crop((box[6], box[4], box[7], box[5]))
115 |             cont = PIL.Image.new(img.mode, (cropsize, cropsize))
116 |             cont.paste(img, (box[2], box[0]))
117 |             new_images.append(cont)
118 | 
119 |         else:
120 |             if len(img.shape) == 3:
121 |                 cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f
122 |             else:
123 |                 cont = np.ones((cropsize, cropsize), img.dtype)*f
124 |             cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]]
125 |             new_images.append(cont)
126 | 
127 |     return new_images
128 | 
129 | 
130 | class AvgPool2d():
131 | 
132 |     def __init__(self, ksize):
133 |         self.ksize = ksize
134 | 
135 |     def __call__(self, img):
136 |         import skimage.measure
137 | 
138 |         return skimage.measure.block_reduce(img, (self.ksize, self.ksize, 1), np.mean)
139 | 
140 | 
141 | class RandomHorizontalFlip():
142 |     def __init__(self):
143 |         return
144 | 
145 |     def __call__(self, img, sal=None):
146 |         if bool(random.getrandbits(1)):
147 |             #img = img.transpose(PIL.Image.FLIP_LEFT_RIGHT)
148 |             img = np.fliplr(img).copy()
149 |             if sal is not None:  # 'if sal:' is ambiguous for numpy arrays
150 |                 #sal = sal.transpose(PIL.Image.FLIP_LEFT_RIGHT)
151 |                 sal = np.fliplr(sal).copy()
152 |                 return img, sal
153 |             return img
154 |         else:
155 |             if sal is not None:
156 |                 return img, sal
157 |             return img
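# Illustrative usage sketch: crop an image/label pair with identical geometry;
# the fill value is 0 for image padding and 255 (the ignore index) for labels.
if __name__ == '__main__':
    import numpy as np
    img = np.zeros((320, 500, 3), np.float32)
    label = np.zeros((320, 500), np.uint8)
    img_c, label_c = random_crop([img, label], 448, (0, 255))
    assert img_c.shape == (448, 448, 3) and label_c.shape == (448, 448)
    assert (label_c == 255).any()   # out-of-image area is marked as ignore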
158 | 
159 | 
160 | class CenterCrop():
161 | 
162 |     def __init__(self, cropsize, default_value=0):
163 |         self.cropsize = cropsize
164 |         self.default_value = default_value
165 | 
166 |     def __call__(self, npimg):
167 | 
168 |         h, w = npimg.shape[:2]
169 | 
170 |         ch = min(self.cropsize, h)
171 |         cw = min(self.cropsize, w)
172 | 
173 |         sh = h - self.cropsize
174 |         sw = w - self.cropsize
175 | 
176 |         if sw > 0:
177 |             cont_left = 0
178 |             img_left = int(round(sw / 2))
179 |         else:
180 |             cont_left = int(round(-sw / 2))
181 |             img_left = 0
182 | 
183 |         if sh > 0:
184 |             cont_top = 0
185 |             img_top = int(round(sh / 2))
186 |         else:
187 |             cont_top = int(round(-sh / 2))
188 |             img_top = 0
189 | 
190 |         if len(npimg.shape) == 2:
191 |             container = np.ones((self.cropsize, self.cropsize), npimg.dtype)*self.default_value
192 |         else:
193 |             container = np.ones((self.cropsize, self.cropsize, npimg.shape[2]), npimg.dtype)*self.default_value
194 | 
195 |         container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \
196 |             npimg[img_top:img_top+ch, img_left:img_left+cw]
197 | 
198 |         return container
199 | 
200 | 
201 | def HWC_to_CHW(tensor, sal=False):
202 |     if sal:
203 |         tensor = np.expand_dims(tensor, axis=0)
204 |     else:
205 |         tensor = np.transpose(tensor, (2, 0, 1))
206 |     return tensor
207 | 
208 | 
209 | class RescaleNearest():
210 |     def __init__(self, scale):
211 |         self.scale = scale
212 | 
213 |     def __call__(self, npimg):
214 |         import cv2
215 |         return cv2.resize(npimg, None, fx=self.scale, fy=self.scale, interpolation=cv2.INTER_NEAREST)
216 | 
217 | 
218 | 
219 | 
220 | def crf_inference(img, probs, t=10, scale_factor=1, labels=21):
221 |     import pydensecrf.densecrf as dcrf
222 |     from pydensecrf.utils import unary_from_softmax
223 | 
224 |     h, w = img.shape[:2]
225 |     n_labels = labels
226 | 
227 |     d = dcrf.DenseCRF2D(w, h, n_labels)
228 | 
229 |     unary = unary_from_softmax(probs)
230 |     unary = np.ascontiguousarray(unary)
231 | 
232 |     d.setUnaryEnergy(unary)
233 |     d.addPairwiseGaussian(sxy=3/scale_factor, compat=3)
234 |     d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img), compat=10)
235 |     Q = d.inference(t)
236 | 
237 |     return np.array(Q).reshape((n_labels, h, w))
238 | 
--------------------------------------------------------------------------------
/tool/pyutils.py:
--------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | import time
4 | import sys
5 | 
6 | class Logger(object):
7 |     def __init__(self, outfile):
8 |         self.terminal = sys.stdout
9 |         self.log = open(outfile, "w")
10 |         sys.stdout = self
11 | 
12 |     def write(self, message):
13 |         self.terminal.write(message)
14 |         self.log.write(message)
15 | 
16 |     def flush(self):
17 |         self.terminal.flush()
18 |         self.log.flush()  # keep the log file in sync with the console
19 | 
20 | class AverageMeter:
21 |     def __init__(self, *keys):
22 |         self.__data = dict()
23 |         for k in keys:
24 |             self.__data[k] = [0.0, 0]
25 | 
26 |     def add(self, d):  # parameter renamed from `dict` to avoid shadowing the builtin
27 |         for k, v in d.items():
28 |             self.__data[k][0] += v
29 |             self.__data[k][1] += 1
30 | 
31 |     def get(self, *keys):
32 |         if len(keys) == 1:
33 |             return self.__data[keys[0]][0] / self.__data[keys[0]][1]
34 |         else:
35 |             v_list = [self.__data[k][0] / self.__data[k][1] for k in keys]
36 |             return tuple(v_list)
37 | 
38 |     def pop(self, key=None):
39 |         if key is None:
40 |             for k in self.__data.keys():
41 |                 self.__data[k] = [0.0, 0]
42 |         else:
43 |             v = self.get(key)
44 |             self.__data[key] = [0.0, 0]
45 |             return v
46 | 
47 | 
48 | class Timer:
49 |     def __init__(self, starting_msg = None):
50 |         self.start = time.time()
51 | 
51 |         self.stage_start = self.start
52 | 
53 |         if starting_msg is not None:
54 |             print(starting_msg, time.ctime(time.time()))
55 | 
56 | 
57 |     def update_progress(self, progress):
58 |         self.elapsed = time.time() - self.start
59 |         self.est_total = self.elapsed / progress
60 |         self.est_remaining = self.est_total - self.elapsed
61 |         self.est_finish = int(self.start + self.est_total)
62 | 
63 | 
64 |     def str_est_finish(self):
65 |         return str(time.ctime(self.est_finish))
66 | 
67 |     def get_stage_elapsed(self):
68 |         return time.time() - self.stage_start
69 | 
70 |     def reset_stage(self):
71 |         self.stage_start = time.time()
72 | 
73 | 
74 | from multiprocessing.pool import ThreadPool
75 | 
76 | class BatchThreader:
77 | 
78 |     def __init__(self, func, args_list, batch_size, prefetch_size=4, processes=12):
79 |         self.batch_size = batch_size
80 |         self.prefetch_size = prefetch_size
81 | 
82 |         self.pool = ThreadPool(processes=processes)
83 |         self.async_result = []
84 | 
85 |         self.func = func
86 |         self.left_args_list = args_list
87 |         self.n_tasks = len(args_list)
88 | 
89 |         # initial work
90 |         self.__start_works(self.__get_n_pending_works())
91 | 
92 | 
93 |     def __start_works(self, times):
94 |         for _ in range(times):
95 |             args = self.left_args_list.pop(0)
96 |             self.async_result.append(
97 |                 self.pool.apply_async(self.func, args))
98 | 
99 | 
100 |     def __get_n_pending_works(self):
101 |         return min((self.prefetch_size + 1) * self.batch_size - len(self.async_result)
102 |                    , len(self.left_args_list))
103 | 
104 | 
105 | 
106 |     def pop_results(self):
107 | 
108 |         n_inwork = len(self.async_result)
109 | 
110 |         n_fetch = min(n_inwork, self.batch_size)
111 |         rtn = [self.async_result.pop(0).get()
112 |                for _ in range(n_fetch)]
113 | 
114 |         to_fill = self.__get_n_pending_works()
115 |         if to_fill == 0:
116 |             self.pool.close()
117 |         else:
118 |             self.__start_works(to_fill)
119 | 
120 |         return rtn
121 | 
122 | 
123 | 
124 | 
125 | def get_indices_of_pairs(radius, size):
126 | 
127 |     search_dist = []
128 | 
129 |     for x in range(1, radius):
130 |         search_dist.append((0, x))
131 | 
132 |     for y in range(1, radius):
133 |         for x in range(-radius + 1, radius):
134 |             if x * x + y * y < radius * radius:
135 |                 search_dist.append((y, x))
136 | 
137 |     radius_floor = radius - 1
138 | 
139 |     full_indices = np.reshape(np.arange(0, size[0]*size[1], dtype=np.int64),
140 |                               (size[0], size[1]))
141 | 
142 |     cropped_height = size[0] - radius_floor
143 |     cropped_width = size[1] - 2 * radius_floor
144 | 
145 |     indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor],
146 |                               [-1])
147 | 
148 |     indices_to_list = []
149 | 
150 |     for dy, dx in search_dist:
151 |         indices_to = full_indices[dy:dy + cropped_height,
152 |                                   radius_floor + dx:radius_floor + dx + cropped_width]
153 |         indices_to = np.reshape(indices_to, [-1])
154 | 
155 |         indices_to_list.append(indices_to)
156 | 
157 |     concat_indices_to = np.concatenate(indices_to_list, axis=0)
158 | 
159 |     return indices_from, concat_indices_to
160 | 
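# Illustrative sketch: for radius 2 the displacement set is (0,1), (1,-1),
# (1,0), (1,1), i.e. 4 neighbours per pixel (only "forward" displacements are
# kept, so each pair is enumerated once). On a 4x6 map the valid "from" region
# is (4-1) x (6-2) = 12 pixels, paired with 4 * 12 = 48 destination indices.
if __name__ == '__main__':
    ind_from, ind_to = get_indices_of_pairs(radius=2, size=(4, 6))
    assert ind_from.shape == (12,) and ind_to.shape == (48,)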
161 | def get_indices_of_pairs_circle(radius, size):
162 | 
163 |     search_dist = []
164 | 
165 |     for y in range(-radius + 1, radius):
166 |         for x in range(-radius + 1, radius):
167 |             if x * x + y * y < radius * radius and x*x+y*y!=0:
168 |                 search_dist.append((y, x))
169 | 
170 |     radius_floor = radius - 1
171 | 
172 |     full_indices = np.reshape(np.arange(0, size[0]*size[1], dtype=np.int64),
173 |                               (size[0], size[1]))
174 | 
175 |     cropped_height = size[0] - 2 * radius_floor
176 |     cropped_width = size[1] - 2 * radius_floor
177 | 
178 |     indices_from = np.reshape(full_indices[radius_floor:-radius_floor, radius_floor:-radius_floor],
179 |                               [-1])
180 | 
181 |     indices_to_list = []
182 | 
183 |     for dy, dx in search_dist:
184 |         indices_to = full_indices[radius_floor + dy : radius_floor + dy + cropped_height,
185 |                                   radius_floor + dx : radius_floor + dx + cropped_width]
186 |         indices_to = np.reshape(indices_to, [-1])
187 | 
188 |         indices_to_list.append(indices_to)
189 | 
190 |     concat_indices_to = np.concatenate(indices_to_list, axis=0)
191 | 
192 |     return indices_from, concat_indices_to
193 | 
--------------------------------------------------------------------------------
/tool/torchutils.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | from torch.utils.data import Dataset
4 | from PIL import Image
5 | import os.path
6 | import random
7 | import numpy as np
8 | from tool import imutils
9 | import torch.nn.functional as F
10 | 
11 | class PolyOptimizer(torch.optim.SGD):
12 | 
13 |     def __init__(self, params, lr, weight_decay, max_step, momentum=0.9):
14 |         super().__init__(params, lr, weight_decay=weight_decay)  # keyword arg: SGD's third positional parameter is momentum, not weight_decay
15 | 
16 |         self.global_step = 0
17 |         self.max_step = max_step
18 |         self.momentum = momentum  # poly decay power, not an SGD momentum
19 | 
20 |         self.__initial_lr = [group['lr'] for group in self.param_groups]
21 | 
22 | 
23 |     def step(self, closure=None):
24 | 
25 |         if self.global_step < self.max_step:
26 |             lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
27 | 
28 |             for i in range(len(self.param_groups)):
29 |                 self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult
30 | 
31 |         super().step(closure)
32 | 
33 |         self.global_step += 1
34 | 
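# Illustrative usage sketch: the schedule scales each group's initial lr by
# (1 - step / max_step) ** momentum, so with max_step=100 the first step uses
# the full lr and the second uses 0.99 ** 0.9 of it. Model and values are made up.
if __name__ == '__main__':
    import torch.nn as nn
    net = nn.Linear(4, 2)
    opt = PolyOptimizer(net.parameters(), lr=0.1, weight_decay=5e-4, max_step=100)
    net(torch.rand(8, 4)).sum().backward()
    opt.step()
    assert abs(opt.param_groups[0]['lr'] - 0.1) < 1e-12   # step 0: lr_mult = 1.0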
35 | 
36 | class PolyAdam(torch.optim.Adam):
37 | 
38 |     def __init__(self, params, lr, weight_decay, max_step, momentum=0.9, betas=(0.9, 0.999)):
39 |         super().__init__(params, lr, weight_decay=weight_decay, betas=betas)
40 | 
41 |         self.global_step = 0
42 |         self.max_step = max_step
43 |         self.momentum = momentum
44 | 
45 |         self.__initial_lr = [group['lr'] for group in self.param_groups]
46 | 
47 | 
48 |     def step(self, closure=None):
49 | 
50 |         if self.global_step < self.max_step:
51 |             lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
52 | 
53 |             for i in range(len(self.param_groups)):
54 |                 self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult
55 | 
56 |         super().step(closure)
57 | 
58 |         self.global_step += 1
59 | 
60 | 
61 | 
62 | class BatchNorm2dFixed(torch.nn.Module):
63 | 
64 |     def __init__(self, num_features, eps=1e-5):
65 |         super(BatchNorm2dFixed, self).__init__()
66 |         self.num_features = num_features
67 |         self.eps = eps
68 |         self.weight = torch.nn.Parameter(torch.Tensor(num_features))
69 |         self.bias = torch.nn.Parameter(torch.Tensor(num_features))
70 |         self.register_buffer('running_mean', torch.zeros(num_features))
71 |         self.register_buffer('running_var', torch.ones(num_features))
72 | 
73 | 
74 |     def forward(self, input):
75 | 
76 |         return F.batch_norm(
77 |             input, self.running_mean, self.running_var, self.weight, self.bias,
78 |             False, eps=self.eps)
79 | 
80 |     def __call__(self, x):
81 |         return self.forward(x)
82 | 
83 | 
84 | class SegmentationDataset(Dataset):
85 |     def __init__(self, img_name_list_path, img_dir, label_dir, rescale=None, flip=False, cropsize=None,
86 |                  img_transform=None, mask_transform=None):
87 |         self.img_name_list_path = img_name_list_path
88 |         self.img_dir = img_dir
89 |         self.label_dir = label_dir
90 | 
91 |         self.img_transform = img_transform
92 |         self.mask_transform = mask_transform
93 | 
94 |         self.img_name_list = open(self.img_name_list_path).read().splitlines()
95 | 
96 |         self.rescale = rescale
97 |         self.flip = flip
98 |         self.cropsize = cropsize
99 | 
100 |     def __len__(self):
101 |         return len(self.img_name_list)
102 | 
103 |     def __getitem__(self, idx):
104 | 
105 |         name = self.img_name_list[idx]
106 | 
107 |         img = Image.open(os.path.join(self.img_dir, name + '.jpg')).convert("RGB")
108 |         mask = Image.open(os.path.join(self.label_dir, name + '.png'))
109 | 
110 |         if self.rescale is not None:
111 |             s = self.rescale[0] + random.random() * (self.rescale[1] - self.rescale[0])
112 |             adj_size = (round(img.size[0]*s/8)*8, round(img.size[1]*s/8)*8)
113 |             img = img.resize(adj_size, resample=Image.BICUBIC)
114 |             mask = mask.resize(adj_size, resample=Image.NEAREST)  # resize the mask itself (was `img.resize`, which discarded the labels)
115 | 
116 |         if self.img_transform is not None:
117 |             img = self.img_transform(img)
118 |         if self.mask_transform is not None:
119 |             mask = self.mask_transform(mask)
120 | 
121 |         if self.cropsize is not None:
122 |             img, mask = imutils.random_crop([img, mask], self.cropsize, (0, 255))
123 | 
124 |         mask = imutils.RescaleNearest(0.125)(mask)
125 | 
126 |         if self.flip is True and bool(random.getrandbits(1)):
127 |             img = np.flip(img, 1).copy()
128 |             mask = np.flip(mask, 1).copy()
129 | 
130 |         img = np.transpose(img, (2, 0, 1))
131 | 
132 |         return name, img, mask
133 | 
134 | 
135 | class ExtractAffinityLabelInRadius():
136 | 
137 |     def __init__(self, cropsize, radius=5):
138 |         self.radius = radius
139 | 
140 |         self.search_dist = []
141 | 
142 |         for x in range(1, radius):
143 |             self.search_dist.append((0, x))
144 | 
145 |         for y in range(1, radius):
146 |             for x in range(-radius+1, radius):
147 |                 if x*x + y*y < radius*radius:
148 |                     self.search_dist.append((y, x))
149 | 
150 |         self.radius_floor = radius-1
151 | 
152 |         self.crop_height = cropsize - self.radius_floor
153 |         self.crop_width = cropsize - 2 * self.radius_floor
154 |         return
155 | 
156 |     def __call__(self, label):
157 | 
158 |         labels_from = label[:-self.radius_floor, self.radius_floor:-self.radius_floor]
159 |         labels_from = np.reshape(labels_from, [-1])
160 | 
161 |         labels_to_list = []
162 |         valid_pair_list = []
163 | 
164 |         for dy, dx in self.search_dist:
165 |             labels_to = label[dy:dy+self.crop_height, self.radius_floor+dx:self.radius_floor+dx+self.crop_width]
166 |             labels_to = np.reshape(labels_to, [-1])
167 | 
168 |             valid_pair = np.logical_and(np.less(labels_to, 255), np.less(labels_from, 255))
169 | 
170 |             labels_to_list.append(labels_to)
171 |             valid_pair_list.append(valid_pair)
172 | 
173 |         bc_labels_from = np.expand_dims(labels_from, 0)
174 |         concat_labels_to = np.stack(labels_to_list)
175 |         concat_valid_pair = np.stack(valid_pair_list)
176 | 
177 |         pos_affinity_label = np.equal(bc_labels_from, concat_labels_to)
178 | 
179 |         bg_pos_affinity_label = np.logical_and(pos_affinity_label, np.equal(bc_labels_from, 0)).astype(np.float32)
180 | 
181 |         fg_pos_affinity_label = np.logical_and(np.logical_and(pos_affinity_label, np.not_equal(bc_labels_from, 0)), concat_valid_pair).astype(np.float32)
182 | 
183 |         neg_affinity_label = np.logical_and(np.logical_not(pos_affinity_label), concat_valid_pair).astype(np.float32)
184 | 
185 |         return bg_pos_affinity_label, fg_pos_affinity_label, neg_affinity_label
186 | 
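# Illustrative usage sketch: on a small label map (0 = background,
# 1 = a foreground class, 255 = ignore), the extractor returns three binary
# maps over all pixel pairs within the search radius, each shaped
# (num_displacements, num_from_pixels).
if __name__ == '__main__':
    label = np.zeros((8, 8), np.uint8)
    label[2:6, 2:6] = 1
    extract = ExtractAffinityLabelInRadius(cropsize=8, radius=3)
    bg_pos, fg_pos, neg = extract(label)
    assert bg_pos.shape == fg_pos.shape == neg.shape == (12, 24)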
187 | class AffinityFromMaskDataset(SegmentationDataset):
188 |     def __init__(self, img_name_list_path, img_dir, label_dir, rescale=None, flip=False, cropsize=None,
189 |                  img_transform=None, mask_transform=None, radius=5):
190 |         super().__init__(img_name_list_path, img_dir, label_dir, rescale, flip, cropsize, img_transform, mask_transform)
191 | 
192 |         self.radius = radius
193 | 
194 |         self.extract_aff_lab_func = ExtractAffinityLabelInRadius(cropsize=cropsize//8, radius=radius)
195 | 
196 |     def __getitem__(self, idx):
197 |         name, img, mask = super().__getitem__(idx)
198 | 
199 |         aff_label = self.extract_aff_lab_func(mask)
200 | 
201 |         return name, img, aff_label
202 | 
--------------------------------------------------------------------------------
/tool/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import cv2
5 | import pydensecrf.densecrf as dcrf
6 | from pydensecrf.utils import unary_from_softmax
7 | 
8 | def color_pro(pro, img=None, mode='hwc'):
9 |     H, W = pro.shape
10 |     pro_255 = (pro*255).astype(np.uint8)
11 |     pro_255 = np.expand_dims(pro_255,axis=2)
12 |     color = cv2.applyColorMap(pro_255,cv2.COLORMAP_JET)
13 |     color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
14 |     if img is not None:
15 |         rate = 0.5
16 |         if mode == 'hwc':
17 |             assert img.shape[0] == H and img.shape[1] == W
18 |             color = cv2.addWeighted(img,rate,color,1-rate,0)
19 |         elif mode == 'chw':
20 |             assert img.shape[1] == H and img.shape[2] == W
21 |             img = np.transpose(img,(1,2,0))
22 |             color = cv2.addWeighted(img,rate,color,1-rate,0)
23 |             color = np.transpose(color,(2,0,1))
24 |     else:
25 |         if mode == 'chw':
26 |             color = np.transpose(color,(2,0,1))
27 |     return color
28 | 
29 | def generate_vis(p, gt, img, func_label2color, threshold=0.1, norm=True):
30 |     # All inputs should be numpy arrays
31 |     # img should be 0-255 uint8
32 |     C, H, W = p.shape
33 | 
34 |     if norm:
35 |         prob = max_norm(p, 'numpy')
36 |     else:
37 |         prob = p
38 |     if gt is not None:
39 |         prob = prob * gt
40 |     prob[prob<=0] = 1e-7
41 |     if threshold is not None:
42 |         prob[0,:,:] = np.power(1-np.max(prob[1:,:,:],axis=0,keepdims=True), 4)
43 | 
44 |     CLS = ColorCLS(prob, func_label2color)
45 |     CAM = ColorCAM(prob, img)
46 | 
47 |     prob_crf = dense_crf(prob, img, n_classes=C, n_iters=1)
48 | 
49 |     CLS_crf = ColorCLS(prob_crf, func_label2color)
50 |     CAM_crf = ColorCAM(prob_crf, img)
51 | 
52 |     return CLS, CAM, CLS_crf, CAM_crf
53 | 
54 | def max_norm(p, version='torch', e=1e-5):
55 |     if version == 'torch':  # '==' rather than 'is' for string comparison
56 |         if p.dim() == 3:
57 |             C, H, W = p.size()
58 |             p = F.relu(p)
59 |             max_v = torch.max(p.view(C,-1),dim=-1)[0].view(C,1,1)
60 |             min_v = torch.min(p.view(C,-1),dim=-1)[0].view(C,1,1)
61 |             p = F.relu(p-min_v-e)/(max_v-min_v+e)
62 |         elif p.dim() == 4:
63 |             N, C, H, W = p.size()
64 |             p = F.relu(p)
65 |             max_v = torch.max(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1)
66 |             min_v = torch.min(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1)
67 |             p = F.relu(p-min_v-e)/(max_v-min_v+e)
68 |     elif version == 'numpy' or version == 'np':
69 |         if p.ndim == 3:
70 |             C, H, W = p.shape
71 |             p[p<0] = 0
72 |             max_v = np.max(p,(1,2),keepdims=True)
73 |             min_v = np.min(p,(1,2),keepdims=True)
74 |             p[p<min_v+e] = 0