├── demo
│   ├── __init__.py
│   ├── demo.py
│   └── eval.py
├── model
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── mobilenet.py
│   │   ├── resnet.py
│   │   └── xception.py
│   ├── sync_batchnorm
│   │   ├── __init__.py
│   │   ├── unittest.py
│   │   ├── replicate.py
│   │   ├── comm.py
│   │   └── batchnorm.py
│   ├── bisenet_up.py
│   └── bisenet.py
├── datasets
│   ├── __init__.py
│   ├── sun
│   │   ├── sun14.csv
│   │   ├── sun37.csv
│   │   ├── mk_dataset.py
│   │   └── SUNRGBD.py
│   ├── build_datasets.py
│   └── transforms.py
├── constants.py
├── onnx
│   ├── cat.jpg
│   ├── cat_super3_torch_onnx.png
│   ├── torch2onnx.md
│   ├── onnx_utils.py
│   └── onnx_demo.py
├── res
│   ├── res101.png
│   └── res18_pi.png
├── .gitignore
├── exps
│   ├── bise.sh
│   └── kd.sh
├── utils
│   ├── calculate_weights.py
│   ├── saver.py
│   ├── lr_scheduler.py
│   ├── metrics.py
│   ├── misc.py
│   ├── trainer.py
│   ├── loss.py
│   ├── trainer_kd.py
│   └── vis.py
├── train.py
├── train_kd.py
├── argument_parser.py
└── README.md

--------------------------------------------------------------------------------
/demo/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/model/backbone/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
from datasets.sun.SUNRGBD import SUNRGBD

--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
BG_INDEX = 255  # saved as grayscale, 255 renders as white, which makes unlabeled regions easy to spot
RUNS = "runs"

--------------------------------------------------------------------------------
/onnx/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuai-Xie/BiSeNet-wali/HEAD/onnx/cat.jpg

--------------------------------------------------------------------------------
/res/res101.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuai-Xie/BiSeNet-wali/HEAD/res/res101.png

--------------------------------------------------------------------------------
/res/res18_pi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuai-Xie/BiSeNet-wali/HEAD/res/res18_pi.png

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
.idea
*/__pycache__

img
runs

*/*.onnx

z.py

--------------------------------------------------------------------------------
/onnx/cat_super3_torch_onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shuai-Xie/BiSeNet-wali/HEAD/onnx/cat_super3_torch_onnx.png

--------------------------------------------------------------------------------
/datasets/sun/sun14.csv:
--------------------------------------------------------------------------------
name,r,g,b
bg,0,0,0
wall,148,65,137
floor,255,116,69
cabinet,86,156,137
chair,155,99,235
sofa,161,107,108
table,133,160,103
door,76,152,126
window,84,62,35
bookshelf,44,80,130
blinds,23,197,62
ceiling,155,108,249
tv,100,124,51
box,221,223,147
person,161,66,179
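Note: the class-color tables sun14.csv (above) and sun37.csv (below) are consumed by `get_label_name_colors` in utils/vis.py, whose source is not included in this section. A minimal sketch of such a reader, assuming only the `name,r,g,b` layout shown above and parallel name/color return lists (the real implementation may differ):

```py
import csv

def get_label_name_colors(csv_path):
    """Read a `name,r,g,b` class table into parallel lists of names and RGB tuples."""
    names, colors = [], []
    with open(csv_path) as f:
        for row in csv.DictReader(f):  # header row: name,r,g,b
            names.append(row['name'])
            colors.append((int(row['r']), int(row['g']), int(row['b'])))
    return names, colors
```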
--------------------------------------------------------------------------------
/model/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .replicate import DataParallelWithCallback, patch_replication_callback

--------------------------------------------------------------------------------
/datasets/sun/sun37.csv:
--------------------------------------------------------------------------------
name,r,g,b
bg,0,0,0
wall,148,65,137
floor,255,116,69
cabinet,86,156,137
bed,202,179,158
chair,155,99,235
sofa,161,107,108
table,133,160,103
door,76,152,126
window,84,62,35
bookshelf,44,80,130
picture,31,184,157
counter,101,144,77
blinds,23,197,62
desk,141,168,145
shelves,142,151,136
curtain,115,201,77
dresser,100,216,255
pillow,57,156,36
mirror,88,108,129
floor_mat,105,129,112
clothes,42,137,126
ceiling,155,108,249
books,166,148,143
fridge,81,91,87
tv,100,124,51
paper,73,131,121
towel,157,210,220
shower_curtain,134,181,60
box,221,223,147
whiteboard,123,108,131
person,161,66,179
night_stand,163,221,160
toilet,31,146,98
sink,99,121,30
lamp,49,89,240
bathtub,116,108,9
bag,161,176,169
--------------------------------------------------------------------------------
/model/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest

import numpy as np
from torch.autograd import Variable


def as_numpy(v):
    if isinstance(v, Variable):
        v = v.data
    return v.cpu().numpy()


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        npa, npb = as_numpy(a), as_numpy(b)
        self.assertTrue(
            np.allclose(npa, npb, atol=atol),
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
        )

--------------------------------------------------------------------------------
/exps/bise.sh:
--------------------------------------------------------------------------------
# no class weights
# rand 15min
# with iters_per_epoch=200, each epoch drops to 5 min -> 20000 iters in total
# todo: the total number of iterations is the main factor that affects training quality
# bs=16 takes 15897MiB, too much memory; other jobs can no longer run on that GPU

# todo: initial accuracy is poor because there are no pretrained weights
# in_planes=64 can load pretrained weights, so performance is good from the start

# res18
# rand
python train.py --dataset SUNRGBD --base-size 512 --crop-size 512 --workers 4 \
    --epochs 100 --eval-interval 5 --batch-size=8 --iters_per_epoch 400 \
    --gpu-ids 2 \
    --loss-type mce --use-balanced-weights \
    --lr 0.01 --lr-scheduler poly \
    --context_path resnet18 \
    --in_planes 32 \
    --checkname res18_inp32_deconv

# res50
# cit
#
python train.py --dataset SUNRGBD --base-size 512 --crop-size 512 --workers 4 \
    --epochs 100 --eval-interval 5 --batch-size=8 --iters_per_epoch 400 \
    --gpu-ids 0 \
    --loss-type mce --use-balanced-weights \
    --lr 0.01 --lr-scheduler poly \
    --context_path resnet50 \
    --in_planes 16 \
    --checkname res50_inp16_deconv

# res101
python train.py --dataset SUNRGBD --base-size 512 --crop-size 512 --workers 4 \
    --epochs 100 --eval-interval 5 --batch-size=8 --iters_per_epoch 300 \
    --gpu-ids 0 \
    --loss-type mce --use-balanced-weights \
    --lr 0.007 --lr-scheduler poly \
    --context_path resnet101 \
    --in_planes 64 \
    --checkname res101_inp64_deconv
--------------------------------------------------------------------------------
/utils/calculate_weights.py:
--------------------------------------------------------------------------------
from tqdm import tqdm
import numpy as np
import copy


def cal_class_freqs(dataset, num_classes):
    class_cnts = np.zeros((num_classes,))

    dataset = copy.deepcopy(dataset)  # deep copy so the transforms can be swapped out while computing class weights
    dataset.transform = None  # keep the original image/target untouched

    # count the pixels of each class
    for sample in tqdm(dataset):
        y = np.asarray(sample['target'])
        mask = np.logical_and((y >= 0), (y < num_classes))  # logical AND, the valid region
        labels = y[mask].astype(np.uint8)
        count_l = np.bincount(labels, minlength=num_classes)  # num_classes-dim vector
        class_cnts += count_l  # class cnt

    class_freqs = class_cnts / class_cnts.sum()
    return class_cnts, class_freqs


# freq and weight are not the same thing
def cal_class_weights(dataset, num_classes, save_dir=None):
    """
    weight = 1/sqrt(num), then normalized
    """
    class_cnts, class_freqs = cal_class_freqs(dataset, num_classes)
    z = np.nan_to_num(np.sqrt(1 + class_cnts))  # smoothed count, prevents the frequency denominator below from being 0
    class_weights = [1 / f for f in z]  # frequency

    ret = np.nan_to_num(np.array(class_weights))
    ret[ret > 2 * np.median(ret)] = 2 * np.median(ret)
    ret = ret / ret.sum()
    print('Class weights: ')
    print(ret)

    if save_dir:
        np.save(f'{save_dir}/class_weights.npy', ret)
        np.save(f'{save_dir}/class_cnts.npy', class_cnts)
        np.save(f'{save_dir}/class_freqs.npy', class_freqs)

    return ret
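A quick toy check of the weighting scheme in `cal_class_weights` above (inverse square root of the smoothed count, clipped at twice the median, then normalized to sum to 1); the counts are made-up values for illustration:

```py
import numpy as np

# four classes with wildly different pixel counts
class_cnts = np.array([1e6, 1e4, 1e2, 1.0])
z = np.sqrt(1 + class_cnts)               # smoothed counts
w = 1 / z                                 # rare classes get larger weights
w[w > 2 * np.median(w)] = 2 * np.median(w)  # clip extreme weights of very rare classes
w = w / w.sum()                           # normalize to a distribution
print(w)
```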
--------------------------------------------------------------------------------
/exps/kd.sh:
--------------------------------------------------------------------------------
# pi: distillation training with the pixel-wise loss only
# 10000 iters in total
# don't start with too large an lr, otherwise the seg loss gets very high

# pi acts on the score map, i.e. the classification output
# pa acts on the feature map, comparing feature similarity

python train_kd.py --dataset SUNRGBD --base-size 512 --crop-size 512 --workers 4 \
    --epochs 100 --eval-interval 5 --batch-size=8 --iters_per_epoch 100 \
    --gpu-ids 0 \
    --loss-type mce --use-balanced-weights \
    --pi \
    --lr-g 1e-4 --lr-scheduler poly \
    --checkname kd_pi

# rand
python train_kd.py --dataset SUNRGBD --base-size 512 --crop-size 512 --workers 4 \
    --epochs 100 --eval-interval 5 --batch-size=8 --iters_per_epoch 100 \
    --gpu-ids 1 \
    --loss-type mce --use-balanced-weights \
    --pi \
    --lr-g 5e-3 --lr-scheduler poly \
    --checkname kd_pi_lr5e-3

# the loss does change, but its value is too small
# pa=1000: pa loss spikes 4 -> 40
# Acc_pixel: 0.2555, mIoU: 0.0246 -- drops too much on the 1st validation
# pa=10: still drives the pa loss and the seg loss up
python train_kd.py --dataset SUNRGBD --base-size 512 --crop-size 512 --workers 4 \
    --epochs 100 --eval-interval 5 --batch-size=8 --iters_per_epoch 100 \
    --gpu-ids 1 \
    --loss-type mce --use-balanced-weights \
    --pa --lambda-pa 10. \
    --lr-g 1e-3 --lr-scheduler poly \
    --checkname kd_pa10

# do the relative magnitudes of the three losses matter?
# pa=1000, pi=10: pa loss spikes 4 -> 120, pi loss spikes 58 -> 200+
# Acc_pixel: 0.3349, mIoU: 0.0224
python train_kd.py --dataset SUNRGBD --base-size 512 --crop-size 512 --workers 4 \
    --epochs 100 --eval-interval 5 --batch-size=8 --iters_per_epoch 100 \
    --gpu-ids 2 \
    --loss-type mce --use-balanced-weights \
    --pi --pa \
    --lambda-pi 10. --lambda-pa 10. \
    --lr-g 1e-3 --lr-scheduler poly \
    --checkname kd_pi10_pa10

--------------------------------------------------------------------------------
/datasets/build_datasets.py:
--------------------------------------------------------------------------------
import os
from datasets import *
from utils.vis import get_label_name_colors

this_dir = os.path.dirname(__file__)

data_cfg = {
    'SUNRGBD': {
        'root': '/datasets/rgbd_dataset/SUNRGBD',
        'class': SUNRGBD,
        'label_name_colors': get_label_name_colors(os.path.join(this_dir, 'sun/sun37.csv'))
    }
}


def build_datasets(dataset, base_size, crop_size):
    if dataset not in data_cfg:
        raise NotImplementedError('no such dataset')

    config = data_cfg[dataset]
    cls = config['class']

    trainset = cls(config['root'], 'train', base_size, crop_size)
    validset = cls(config['root'], 'valid', base_size, crop_size)
    testset = cls(config['root'], 'test', base_size, crop_size)

    return trainset, validset, testset


if __name__ == '__main__':
    from utils.vis import plt_img_target
    from utils.misc import recover_color_img
    import numpy as np
    from utils.calculate_weights import cal_class_weights

    dataset = 'SUNRGBD'
    trainset, validset, testset = build_datasets(dataset, base_size=512, crop_size=512)
    # cal_class_weights(trainset, trainset.num_classes, save_dir='/datasets/rgbd_dataset/SUNRGBD/train')

    for idx, sample in enumerate(validset):
        img, target = sample['img'], sample['target']
        img, target = img.squeeze(0), target.squeeze(0)
        print(img.shape)

        img, target = img.numpy(), target.numpy()
        img = recover_color_img(img)
        print(np.unique(target))

        target = target.astype('uint8')
        target = trainset.remap_fn(target)

        plt_img_target(img, target, trainset.label_colors)

        if idx == 4:
            exit(0)
--------------------------------------------------------------------------------
/utils/saver.py:
--------------------------------------------------------------------------------
import os
import shutil
import torch
import json
import imageio
import numpy as np
import constants

"""
experiment_dir: tensorboard
save/load checkpoint
save_experiment_config
save_active_selections, mask of select parts
"""


class Saver:

    def __init__(self, args, suffix='', timestamp='',
                 experiment_group=None, remove_existing=False):

        self.args = args

        if experiment_group is None:
            experiment_group = args.dataset

        # runs/ tensorboard
        self.experiment_dir = os.path.join(constants.RUNS, experiment_group,
                                           f'{args.checkname}_{timestamp}', suffix)

        if remove_existing and os.path.exists(self.experiment_dir):
            shutil.rmtree(self.experiment_dir)

        if not os.path.exists(self.experiment_dir):
            print(f'Creating dir {self.experiment_dir}')
            os.makedirs(self.experiment_dir)

    def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
        filename = os.path.join(self.experiment_dir, filename)
        torch.save(state, filename)

    def load_checkpoint(self, filename='checkpoint.pth.tar', file_path=None):
        filename = os.path.join(self.experiment_dir, filename) if not file_path else file_path
        return torch.load(filename)

    def save_experiment_config(self):
        logfile = os.path.join(self.experiment_dir, 'parameters.txt')
        log_file = open(logfile, 'w')
        arg_dictionary = vars(self.args)
        log_file.write(json.dumps(arg_dictionary, indent=4, sort_keys=True))  # sorted by key
        log_file.close()

    def save_active_selections(self, paths, regional=False):
        if regional:
            Saver.save_masks(os.path.join(self.experiment_dir, "selections"), paths)
        else:
            filename = os.path.join(self.experiment_dir, 'selections.txt')
            with open(filename, 'w') as fptr:
                for p in paths:
                    fptr.write(p.decode('utf-8') + '\n')

    @staticmethod
    def save_masks(directory, paths):
        if not os.path.exists(directory):
            os.makedirs(directory)
        for p in paths:
            imageio.imwrite(os.path.join(directory, p.decode('utf-8') + '.png'),
                            (paths[p] * 255).astype(np.uint8))

--------------------------------------------------------------------------------
/demo/demo.py:
--------------------------------------------------------------------------------
import torch
import os
from datasets.build_datasets import data_cfg
from datasets.transforms import remap

from model.bisenet import BiSeNet
from utils.misc import load_state_dict, recover_color_img, mkdir
from utils.vis import color_code_target
import cv2
import numpy as np
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

img_size = (480, 640)

# trans
test_trans = transforms.Compose([
    transforms.Resize(img_size),  # test: resize directly to the crop size
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # normalize tensor image
])
remap_fn = remap(0)
label_names, label_colors = data_cfg['SUNRGBD']['label_name_colors']


@torch.no_grad()
def infer_img(img, vis=False):  # img read by Pillow
    img = test_trans(img).unsqueeze(0).cuda()
    pred = model(img)
    pred = F.interpolate(pred, size=img_size, mode='bilinear', align_corners=True)
    pred = torch.argmax(pred, dim=1).squeeze().cpu().numpy().astype('uint8')

    pred = remap_fn(pred)
    predict = color_code_target(pred, label_colors)

    if vis:
        f, ax = plt.subplots(1, 2, figsize=(10, 5))
        # recover_color_img expects a CHW numpy array (see build_datasets.py), so move the tensor back to host first
        ax[0].imshow(recover_color_img(img.squeeze(0).cpu().numpy()))
        ax[0].set_title('img')
        ax[1].imshow(predict)
        ax[1].set_title('predict')
        plt.show()

    return predict


if __name__ == '__main__':
    # arch = 'res18'
    arch = 'res101'

    # load model
    if arch == 'res18':
        model = BiSeNet(37, context_path='resnet18', in_planes=32)
        load_state_dict(model, ckpt_path='runs/SUNRGBD/kd_pi_lr1e-3_Jul28_002404/checkpoint.pth.tar')
    elif arch == 'res101':
        model = BiSeNet(37, context_path='resnet101', in_planes=64)
        load_state_dict(model, ckpt_path='runs/SUNRGBD/res101_inp64_deconv_Jul26_205859/checkpoint.pth.tar')
    else:
        raise NotImplementedError
    model.eval().cuda()

    # infer dir
    exp = 'sun'
    img_dir = f'img/{exp}/rgb'
    save_dir = f'img/{exp}/seg_{arch}'
    mkdir(save_dir)

    for img in os.listdir(img_dir):
        print(img)
        img_name = img.split('.')[0]
        img = Image.open(f'{img_dir}/{img}').convert('RGB')
        predict = infer_img(img, vis=True)
        cv2.imwrite(f'{save_dir}/{img_name}.png', predict[:, :, ::-1])
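Note: demo.py and SUNRGBD.py both rely on the paired `mapbg`/`remap` helpers from datasets/transforms.py, which is not included in this section. A hypothetical sketch of how such a pair could behave, assuming background pixels are parked at BG_INDEX=255 during training (so losses can ignore them) and restored for visualization; this is not the repository's actual code:

```py
import numpy as np

BG_INDEX = 255  # see constants.py

def mapbg(bg_idx):
    """Move the background class to BG_INDEX and shift the remaining labels down by one."""
    def fn(target):
        target = target.astype('int32') - 1        # bg_idx=0 becomes -1, classes 1..N become 0..N-1
        target[target == bg_idx - 1] = BG_INDEX    # park background at 255
        return target.astype('uint8')
    return fn

def remap(bg_idx):
    """Inverse of mapbg: shift labels back up and restore the background index."""
    def fn(target):
        target = target.astype('int32') + 1        # 0..N-1 back to 1..N
        target[target == BG_INDEX + 1] = bg_idx    # 255 wrapped to 256, restore to bg_idx
        return target.astype('uint8')
    return fn
```

This would also explain why `remap_fn` is applied to the argmax predictions in demo.py before color coding: the model predicts 0..36, while the rows of sun37.csv start with `bg` at index 0.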
--------------------------------------------------------------------------------
/onnx/torch2onnx.md:
--------------------------------------------------------------------------------
## PyTorch upsampling -> ONNX

#### ConvTranspose2d (deconvolution)
```py
# x8 upsample
self.ffm_upsample = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=8, stride=8, padding=0, output_padding=0)
```
```
%298 : Float(1, 14, 60, 80) = onnx::Add(%297, %291), scope: BiSeNet/FeatureFusionModule[FFM]
%299 : Float(1, 14, 480, 640) = onnx::ConvTranspose[dilations=[1, 1], group=1, kernel_shape=[8, 8], pads=[0, 0, 0, 0], strides=[8, 8]](%298, %187), scope: BiSeNet/DeconvBlock[ffm_upsample]/ConvTranspose2d[deconv]
```

#### Interpolation
```py
# x8 upsample
result = F.interpolate(result, size=(480, 640), mode='bilinear')
# or
result = F.interpolate(result, scale_factor=8, mode='bilinear')
```
Whether you set `size` or `scale_factor`, the same interpolation op is used under the hood, so the ONNX conversion is identical:
```
%278 : Float(1, 14, 60, 80) = onnx::Add(%277, %271), scope: BiSeNet/FeatureFusionModule[FFM]
%279 : Tensor = onnx::Constant[value= 1  1  8  8 [ CPUFloatType{4} ]](), scope: BiSeNet
%280 : Float(1, 14, 480, 640) = onnx::Upsample[mode="linear"](%278, %279), scope: BiSeNet
```
```
%278 : Float(1, 14, 60, 80) = onnx::Add(%277, %271), scope: BiSeNet/FeatureFusionModule[FFM]
%279 : Tensor = onnx::Constant[value= 1  1  8  8 [ CPUFloatType{4} ]](), scope: BiSeNet
%280 : Float(1, 14, 480, 640) = onnx::Upsample[mode="linear"](%278, %279), scope: BiSeNet
```
---
When TensorRT builds an engine from an ONNX model exported via interpolation, it fails with: **Attribute not found: height_scale**
- Cause: the `%279 Constant` defines the scale factors, but the `%280 Upsample` node never actually receives them; the first expected attribute is the height scale, hence the complaint that `height_scale` is missing
- Fix: override ONNX's upsample export


Override nearest-neighbor interpolation: upsample_nearest2d
```py
import torch.onnx.symbolic

# Override Upsample's ONNX export until new opset is supported
@torch.onnx.symbolic.parse_args('v', 'is')
def upsample_nearest2d(g, input, output_size):
    height_scale = float(output_size[-2]) / input.type().sizes()[-2]
    width_scale = float(output_size[-1]) / input.type().sizes()[-1]
    return g.op("Upsample", input,
                scales_f=(1, 1, height_scale, width_scale),
                mode_s="nearest")

# step into the original function definition to see how the override should be written
torch.onnx.symbolic.upsample_nearest2d = upsample_nearest2d
```
Override bilinear interpolation: upsample_bilinear2d
```py
@torch.onnx.symbolic.parse_args('v', 'is', 'i')
def upsample_bilinear2d(g, input, output_size, align_corners):
    height_scale = float(output_size[-2]) / input.type().sizes()[-2]  # 8
    width_scale = float(output_size[-1]) / input.type().sizes()[-1]  # 8
    return g.op("Upsample", input,
                scales_f=(1, 1, height_scale, width_scale),
                mode_s="linear")

torch.onnx.symbolic.upsample_bilinear2d = upsample_bilinear2d
```
With this in place, the upsample node receives its scales when the model is converted to ONNX:
```
%276 : Float(1, 14, 60, 80) = onnx::Add(%275, %269), scope: BiSeNet/FeatureFusionModule[FFM]
%277 : Float(1, 14, 480, 640) = onnx::Upsample[mode="nearest", scales=[1, 1, 8, 8]](%276), scope: BiSeNet
```
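For completeness, a minimal sketch of the call order implied above: the symbolic override has to be installed before `torch.onnx.export` runs. This assumes the same pre-opset-11 `torch.onnx.symbolic` API used in the document; `Up8` and `up8.onnx` are placeholder names:

```py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx.symbolic

torch.onnx.symbolic.upsample_bilinear2d = upsample_bilinear2d  # override defined above, installed first

class Up8(nn.Module):
    def forward(self, x):
        return F.interpolate(x, scale_factor=8, mode='bilinear', align_corners=False)

torch.onnx.export(Up8().eval(), torch.randn(1, 14, 60, 80), 'up8.onnx', verbose=True)
# the exported Upsample node now carries scales=[1, 1, 8, 8]
```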
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argument_parser
from pprint import pprint

args = argument_parser.parse_args()
pprint(vars(args))

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_HOME"] = "/nfs/xs/local/cuda-10.2"

if len(args.gpu_ids) > 1:
    args.sync_bn = True

from torch.utils.tensorboard import SummaryWriter
from datasets.build_datasets import build_datasets

from model.bisenet import BiSeNet

from utils.calculate_weights import cal_class_weights
from utils.saver import Saver
from utils.trainer import Trainer
from utils.misc import AccCaches, get_curtime
import numpy as np


def main():
    # dataset
    trainset, valset, testset = build_datasets(args.dataset, args.base_size, args.crop_size)

    model = BiSeNet(trainset.num_classes, args.context_path, args.in_planes)

    class_weights = None
    if args.use_balanced_weights:  # default false
        class_weights = np.array([  # med_freq
            0.382900, 0.452448, 0.637584, 0.377464, 0.585595,
            0.479574, 0.781544, 0.982534, 1.017466, 0.624581,
            2.589096, 0.980794, 0.920340, 0.667984, 1.172291,  # 15
            0.862240, 0.921714, 2.154782, 1.187832, 1.178115,  # 20
            1.848545, 1.428922, 2.849658, 0.771605, 1.656668,  # 25
            4.483506, 2.209922, 1.120280, 2.790182, 0.706519,  # 30
            3.994768, 2.220004, 0.972934, 1.481525, 5.342475,  # 35
            0.750738, 4.040773  # 37
        ])
        # class_weights = np.load('/datasets/rgbd_dataset/SUNRGBD/train/class_weights.npy')
        # class_weights = cal_class_weights(trainset, trainset.num_classes)

    saver = Saver(args, timestamp=get_curtime())
    writer = SummaryWriter(saver.experiment_dir)
    trainer = Trainer(args, model, trainset, valset, testset, class_weights, saver, writer)

    start_epoch = 0

    miou_caches = AccCaches(patience=5)  # miou
    for epoch in range(start_epoch, args.epochs):
        trainer.training(epoch)
        if epoch % args.eval_interval == (args.eval_interval - 1):
            miou, pixelAcc = trainer.validation(epoch)
            miou_caches.add(epoch, miou)
            if miou_caches.full():
                print('acc caches:', miou_caches.accs)
                print('best epoch:', trainer.best_epoch, 'best miou:', trainer.best_mIoU)
                _, max_miou = miou_caches.max_cache_acc()
                if max_miou < trainer.best_mIoU:
                    print('end training')
                    break

    print('valid')
    print('best mIoU:', trainer.best_mIoU, 'pixelAcc:', trainer.best_pixelAcc)

    # test
    epoch = trainer.load_best_checkpoint()
    test_mIoU, test_pixelAcc = trainer.validation(epoch, test=True)
    print('test')
    print('best mIoU:', test_mIoU, 'pixelAcc:', test_pixelAcc)

    writer.flush()
    writer.close()


if __name__ == '__main__':
    main()
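Note: both train.py above and train_kd.py below early-stop through `utils.misc.AccCaches`, whose source is not in this section. A sketch of its assumed behavior, inferred from the call sites (`add`, `full`, `accs`, `max_cache_acc`): a sliding window of the most recent validation mIoUs, used to stop when no recent score beats the best one:

```py
class AccCaches:
    """Fixed-size window of (epoch, acc) pairs -- assumed behavior, not the repo's actual code."""

    def __init__(self, patience):
        self.patience = patience
        self.accs = []  # [(epoch, acc), ...]

    def add(self, epoch, acc):
        self.accs.append((epoch, acc))
        self.accs = self.accs[-self.patience:]  # keep only the `patience` most recent entries

    def full(self):
        return len(self.accs) == self.patience

    def max_cache_acc(self):
        return max(self.accs, key=lambda x: x[1])  # (epoch, best acc) within the window
```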
--------------------------------------------------------------------------------
/datasets/sun/mk_dataset.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from utils.misc import write_list_to_txt, read_txt_as_list
import random

"""
10335
train: 5285  0-5284      # 1000 split off as valid
test:  5050  5285-10334  # 1000 split off as valid

"""

root = '/datasets/rgbd_dataset/SUNRGBD'


def check_list_path(plist):
    ERROR = False
    for p in plist:
        if os.path.exists(p):
            continue
        else:
            ERROR = True
            print(p, 'not exist!')
    if not ERROR:
        print('all pass!')


def save_train_test_txt():  # valid only needs randomly sampled paths
    for split in ['train', 'test']:
        img_dir = os.path.join(root, split, 'image')
        depth_dir = os.path.join(root, split, 'depth')  # png, 16 bits
        target_dir = os.path.join(root, split, 'mask')  # npy

        img_names = [p.split('.')[0] for p in os.listdir(img_dir)]

        img_paths = [os.path.join(img_dir, p + '.jpg') for p in img_names]
        depth_paths = [os.path.join(depth_dir, p + '.png') for p in img_names]
        target_paths = [os.path.join(target_dir, p + '.npy') for p in img_names]  # npy takes a lot of disk space? the format may be wrong

        print(f'{split}, img:', len(img_paths), 'depth:', len(depth_paths), 'target:', len(target_paths))

        check_list_path(img_paths)
        check_list_path(depth_paths)
        check_list_path(target_paths)

        write_list_to_txt(img_paths, txt_path=os.path.join(root, f'{split}_img_paths.txt'))
        write_list_to_txt(depth_paths, txt_path=os.path.join(root, f'{split}_depth_paths.txt'))
        write_list_to_txt(target_paths, txt_path=os.path.join(root, f'{split}_target_paths.txt'))


def generate_valid_txt(valid_num=1000):
    random.seed(100)

    valid_img_paths, valid_depth_paths, valid_target_paths = [], [], []

    for split in ['train', 'test']:
        img_paths = read_txt_as_list(os.path.join(root, f'{split}_img_paths.txt'))
        depth_paths = read_txt_as_list(os.path.join(root, f'{split}_depth_paths.txt'))
        target_paths = read_txt_as_list(os.path.join(root, f'{split}_target_paths.txt'))

        chose_idxs = random.sample(range(len(img_paths)), valid_num)

        chose_img_paths = [img_paths[i] for i in chose_idxs]
        chose_depth_paths = [depth_paths[i] for i in chose_idxs]
        chose_target_paths = [target_paths[i] for i in chose_idxs]

        valid_img_paths += chose_img_paths
        valid_depth_paths += chose_depth_paths
        valid_target_paths += chose_target_paths

    print('valid, img:', len(valid_img_paths), 'depth:', len(valid_depth_paths), 'target:', len(valid_target_paths))

    # scary.... shuffling the three lists independently would have broken their alignment
    # random.shuffle(valid_img_paths)
    # random.shuffle(valid_depth_paths)
    # random.shuffle(valid_target_paths)

    write_list_to_txt(valid_img_paths, os.path.join(root, 'valid_img_paths.txt'))
    write_list_to_txt(valid_depth_paths, os.path.join(root, 'valid_depth_paths.txt'))
    write_list_to_txt(valid_target_paths, os.path.join(root, 'valid_target_paths.txt'))


if __name__ == '__main__':
    # save_train_test_txt()
    generate_valid_txt()
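mk_dataset.py leans on two small helpers from utils/misc.py that this section does not show. Minimal sketches of their assumed behavior (one path per line in a plain text file; the real implementations may differ):

```py
def write_list_to_txt(lst, txt_path):
    """Write one entry per line."""
    with open(txt_path, 'w') as f:
        f.write('\n'.join(lst) + '\n')

def read_txt_as_list(txt_path):
    """Read non-empty lines back into a list."""
    with open(txt_path) as f:
        return [line.strip() for line in f if line.strip()]
```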
--------------------------------------------------------------------------------
/model/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by original replication.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
    Note that, as all modules are isomorphic, we assign each sub-module with a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.
    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.
    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have customized `DataParallel` implementation.
    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
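The LR_Scheduler below uses a geometric warm-up: the per-iteration factor is chosen so the learning rate climbs from `warmup_start_lr` to exactly the base lr at the end of the warm-up window. A quick numeric check of that formula (the parameter values here are illustrative, not the experiment settings):

```py
# warmup_factor = (lr / warmup_start_lr) ** (1 / warmup_iters), applied once per iter
warmup_start_lr, base_lr, warmup_iters = 1e-5, 0.01, 200
factor = (base_lr / warmup_start_lr) ** (1.0 / warmup_iters)
print(warmup_start_lr * factor ** warmup_iters)  # ~0.01, i.e. base_lr reached at iter 200
```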
--------------------------------------------------------------------------------
/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import math


class LR_Scheduler(object):
    """Learning Rate Scheduler
    Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}``
    Cosine mode: ``lr = baselr * 0.5 * (1 + cos(pi * iter/maxiter))``
    Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9``
    Args:
        args:
            :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`),
            :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs,
            :attr:`args.lr_step`
        iters_per_epoch: number of iterations per epoch
    """

    def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
                 lr_step=0, warmup_epochs=0, warmup_start_lr=1e-5):
        self.mode = mode
        print('Using {} LR Scheduler!'.format(self.mode))
        self.lr = base_lr
        if mode == 'step':
            assert lr_step
        self.lr_step = lr_step
        if isinstance(self.lr_step, list):  # add head and tail boundaries
            self.lr_step = [0] + self.lr_step + [num_epochs]

        self.iters_per_epoch = iters_per_epoch
        self.N = num_epochs * iters_per_epoch
        self.epoch = -1

        self.warmup_iters = warmup_epochs * iters_per_epoch
        if self.warmup_iters > 0:
            self.warmup_start_lr = warmup_start_lr
            self.warmup_factor = (self.lr / self.warmup_start_lr) ** (1. / self.warmup_iters)

    def __call__(self, optimizer, i, epoch, best_pred=None):
        T = epoch * self.iters_per_epoch + i

        # warm-up lr schedule: grow slowly from a tiny lr up to base_lr
        if self.warmup_iters > 0 and T <= self.warmup_iters:
            lr = self.warmup_start_lr * (self.warmup_factor ** T)
        else:
            if self.mode == 'cos':
                lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi))
            elif self.mode == 'poly':
                lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9)
            elif self.mode == 'step':
                if isinstance(self.lr_step, int):
                    lr = self.lr * (0.1 ** (epoch // self.lr_step))  # lr reduce step
                elif isinstance(self.lr_step, list):
                    for idx in range(len(self.lr_step) - 1):
                        if self.lr_step[idx] <= epoch < self.lr_step[idx + 1]:
                            lr = self.lr * (0.1 ** idx)
                            break
                else:
                    raise TypeError
            else:
                raise NotImplementedError

        self._adjust_learning_rate(optimizer, lr)

    def _adjust_learning_rate(self, optimizer, lr):
        if len(optimizer.param_groups) == 1:
            optimizer.param_groups[0]['lr'] = lr
        else:
            # enlarge the lr at the head
            optimizer.param_groups[0]['lr'] = lr
            for i in range(1, len(optimizer.param_groups)):  # later groups are treated as custom layers and get a 10x larger lr
                optimizer.param_groups[i]['lr'] = lr * 10
--------------------------------------------------------------------------------
/train_kd.py:
--------------------------------------------------------------------------------
import argument_parser
from pprint import pprint

args = argument_parser.parse_args()
pprint(vars(args))

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_HOME"] = "/nfs/xs/local/cuda-10.2"

if len(args.gpu_ids) > 1:
    args.sync_bn = True

from torch.utils.tensorboard import SummaryWriter
from datasets.build_datasets import build_datasets

from model.bisenet import BiSeNet

from utils.saver import Saver
from utils.trainer_kd import Trainer
from utils.misc import AccCaches, get_curtime, print_model_parm_nums, load_state_dict
import numpy as np


def main():
    # dataset
    trainset, valset, testset = build_datasets(args.dataset, args.base_size, args.crop_size)

    # define the student / teacher models
    student = BiSeNet(trainset.num_classes, context_path='resnet18', in_planes=32)
    teacher = BiSeNet(trainset.num_classes, context_path='resnet101', in_planes=64)
    print_model_parm_nums(student, 'student')  # student: Number of params: 5.66 M
    print_model_parm_nums(teacher, 'teacher')  # teacher: Number of params: 132.92 M

    # load the already trained student / teacher checkpoints
    device = f'cuda:{args.gpu_ids}'
    load_state_dict(student, 'runs/SUNRGBD/res18_inp32_deconv_Jul27_100319/checkpoint.pth.tar', device)
    load_state_dict(teacher, 'runs/SUNRGBD/res101_inp64_deconv_Jul26_205859/checkpoint.pth.tar', device)

    class_weights = None
    if args.use_balanced_weights:  # default false
        class_weights = np.array([  # med_freq
            0.382900, 0.452448, 0.637584, 0.377464, 0.585595,
            0.479574, 0.781544, 0.982534, 1.017466, 0.624581,
            2.589096, 0.980794, 0.920340, 0.667984, 1.172291,  # 15
            0.862240, 0.921714, 2.154782, 1.187832, 1.178115,  # 20
            1.848545, 1.428922, 2.849658, 0.771605, 1.656668,  # 25
            4.483506, 2.209922, 1.120280, 2.790182, 0.706519,  # 30
            3.994768, 2.220004, 0.972934, 1.481525, 5.342475,  # 35
            0.750738, 4.040773  # 37
        ])

    saver = Saver(args, timestamp=get_curtime())
    writer = SummaryWriter(saver.experiment_dir)

    trainer = Trainer(args, student, teacher,
                      trainset, valset, testset, class_weights, saver, writer)

    start_epoch = 0

    miou_caches = AccCaches(patience=5)  # miou
    for epoch in range(start_epoch, args.epochs):
        trainer.training(epoch)
        if epoch % args.eval_interval == (args.eval_interval - 1):
            miou, pixelAcc = trainer.validation(epoch)
            miou_caches.add(epoch, miou)
            if miou_caches.full():
                print('acc caches:', miou_caches.accs)
                print('best epoch:', trainer.best_epoch, 'best miou:', trainer.best_mIoU)
                _, max_miou = miou_caches.max_cache_acc()
                if max_miou < trainer.best_mIoU:
                    print('end training')
                    break

    print('valid')
    print('best mIoU:', trainer.best_mIoU, 'pixelAcc:', trainer.best_pixelAcc)

    # test
    epoch = trainer.load_best_checkpoint()
    test_mIoU, test_pixelAcc = trainer.validation(epoch, test=True)
    print('test')
    print('best mIoU:', test_mIoU, 'pixelAcc:', test_pixelAcc)

    writer.flush()
    writer.close()


if __name__ == '__main__':
    main()
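The `--pi` / `--pa` distillation losses driven by train_kd.py are implemented in utils/trainer_kd.py and utils/loss.py, neither of which appears in this section. For orientation, a hypothetical sketch of the pixel-wise (pi) term only, assuming the usual temperature-scaled KL divergence between per-pixel class distributions of student and teacher score maps; this is an illustration, not the repository's actual loss:

```py
import torch.nn.functional as F

def pixelwise_kd_loss(student_logits, teacher_logits, T=1.0):
    """KL(teacher || student) over the class dim of (B, C, H, W) score maps -- hypothetical sketch."""
    log_p_s = F.log_softmax(student_logits / T, dim=1)
    p_t = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(log_p_s, p_t, reduction='batchmean') * (T * T)
```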
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
import numpy as np

# https://stats.stackexchange.com/questions/179835/how-to-build-a-confusion-matrix-for-a-multiclass-classifier
epsilon = 1e-5


def calculate_miou(confusion_matrix):
    MIoU = np.divide(np.diag(confusion_matrix), (
            np.sum(confusion_matrix, axis=1) + np.sum(confusion_matrix, axis=0) -
            np.diag(confusion_matrix)))
    MIoU = np.nanmean(MIoU)
    return MIoU


def fast_hist(a, b, n):
    """
    a and b are predict and mask respectively
    n is the number of classes
    """
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)


def per_class_iou(hist):
    return (np.diag(hist) + epsilon) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)


class Evaluator(object):

    def __init__(self, num_class):
        np.seterr(divide='ignore', invalid='ignore')
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def Pixel_Accuracy(self):
        Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()  # TP / ALL pixels
        return Acc

    def Acc_of_each_class(self):
        Accs = np.divide(np.diag(self.confusion_matrix),  # TP of each class
                         self.confusion_matrix.sum(axis=1))  # (TP+FN) of each class, i.e. the ground-truth row sums
        return Accs  # vector

    def Mean_Pixel_Accuracy(self):
        Acc = np.nanmean(self.Acc_of_each_class())  # mean of vector
        return Acc

    def IOU_of_each_class(self):
        inter = np.diag(self.confusion_matrix)  # TP
        union = np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - inter  # TP+FP+FN
        IoUs = np.divide(inter, union)  # IoU of each class
        return IoUs  # vector

    def Mean_Intersection_over_Union(self):
        MIoU = np.nanmean(self.IOU_of_each_class())  # mIoU
        return MIoU

    def Mean_Intersection_over_Union_20(self):  # restricted to a 20-class subset
        MIoU = 0
        if self.num_class > 20:
            subset_20 = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 23, 27, 32, 33, 35, 38])
            confusion_matrix = self.confusion_matrix[subset_20[:, None], subset_20]  # extract the 20x20 submatrix
            MIoU = np.divide(np.diag(confusion_matrix), (
                    np.sum(confusion_matrix, axis=1) + np.sum(confusion_matrix, axis=0) -
                    np.diag(confusion_matrix)))
            MIoU = np.nanmean(MIoU)
        return MIoU

    def Mean_Dice(self):
        inter = np.diag(self.confusion_matrix)  # vector
        dices = np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0)
        dice = np.divide(2 * inter, dices)
        dice = np.nanmean(dice)
        return dice

    def Frequency_Weighted_Intersection_over_Union(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.divide(np.diag(self.confusion_matrix), (
                np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
                np.diag(self.confusion_matrix)))

        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        return FWIoU

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)  # background (mapped to 255) is excluded
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]  # i:gt, j:pre
        count = np.bincount(label, minlength=self.num_class ** 2)  # total classes on confusion_matrix
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def add_batch(self, gt_image, pre_image, return_miou=False):
        assert gt_image.shape == pre_image.shape  # np img, B,H,W
        confusion_matrix = self._generate_matrix(gt_image, pre_image)
        self.confusion_matrix += confusion_matrix
        if return_miou:
            return calculate_miou(confusion_matrix)

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def dump_matrix(self, path):
        np.save(path, self.confusion_matrix)
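A quick sanity check of the `Evaluator` above on a 3-class toy batch, with the expected values worked out by hand from the confusion-matrix definitions:

```py
import numpy as np

ev = Evaluator(num_class=3)
gt = np.array([[0, 1, 2, 2]])
pred = np.array([[0, 1, 1, 2]])
ev.add_batch(gt, pred)
print(ev.Pixel_Accuracy())                # 0.75  (3 of 4 pixels correct)
print(ev.IOU_of_each_class())             # [1.0, 0.5, 0.5]
print(ev.Mean_Intersection_over_Union())  # 2/3 ~ 0.667
```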
--------------------------------------------------------------------------------
/datasets/sun/SUNRGBD.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from torch.utils.data import Dataset
from utils.misc import read_txt_as_list
from utils.vis import get_label_name_colors
import datasets.transforms as tr
from PIL import Image
import constants

"""
SUN-RGBD
10335 = 5285 (train) + 5050 (test)
valid: 1000 images sampled from each of train/test, see mk_dataset.py
"""
this_dir = os.path.dirname(__file__)


class SUNRGBD(Dataset):

    def __init__(self, root, split, base_size=None, crop_size=None):
        super().__init__()
        self.img_paths = read_txt_as_list(os.path.join(root, f'{split}_img_paths.txt'))
        # self.depth_paths = read_txt_as_list(os.path.join(root, f'{split}_depth_paths.txt'))
        self.target_paths = read_txt_as_list(os.path.join(root, f'{split}_target_paths.txt'))

        # debug: uncomment to work on a 100-sample subset, together with iters_per_epoch
        # self.img_paths, self.target_paths = self.img_paths[:100], self.target_paths[:100]

        self.base_size = base_size  # reference size for train
        self.crop_size = crop_size  # train, valid, test

        self.transform = self.get_transform(split)

        self.bg_idx = 0
        self.num_classes = 37
        self.mapbg_fn = tr.mapbg(self.bg_idx)
        self.remap_fn = tr.remap(self.bg_idx)

        self.label_names, self.label_colors = get_label_name_colors(os.path.join(this_dir, 'sun37.csv'))

    def __getitem__(self, index):
        img = Image.open(self.img_paths[index]).convert('RGB')
        # depth = Image.open(self.depth_paths[index])
        target = np.load(self.target_paths[index]).astype(int)
        # target = self.reduce_class14(target)
        target = self.mapbg_fn(target)
        target = Image.fromarray(target)

        sample = {
            'img': img,
            # 'depth': depth,
            'target': target
        }
        if self.transform:  # set transform to None to conveniently get the raw images back
            sample = self.transform(sample)

        return sample

    def __len__(self):
        return len(self.img_paths)

    def get_transform(self, split):
        if split == 'train':
            return tr.Compose([
                tr.RandomHorizontalFlip(),
                tr.RandomScaleCrop(base_size=self.base_size,
                                   crop_size=self.crop_size,
                                   scales=(0.8, 2.0),
                                   fill=constants.BG_INDEX),
                tr.RandomGaussianBlur(),
                tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                tr.ToTensor()
            ])
        elif split == 'valid':
            return tr.Compose([
                tr.FixScaleCrop(crop_size=self.crop_size),  # valid: crop with fixed aspect ratio
                tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                tr.ToTensor()
            ])
        elif split == 'test':
            return tr.Compose([
                tr.FixedResize(size=self.crop_size),  # test: resize directly to the crop size
                tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                tr.ToTensor()
            ])
        else:
            return None

    def reduce_class14(self, label):
        # 19 classes to drop
        abort_classes = [4, 11, 17, 18, 19, 20, 21, 23, 24, 26, 27, 28, 30, 32, 33, 34, 35, 36, 37]

        # reduce classes 37 -> 14
        # 0 abort 19 classes to background
        for idx in abort_classes:
            label[np.where(label == idx)] = 0

        # whether to merge into the big classes
        # ======================
        # # 7 table(counter, desk) 12,14
        label[np.where(label == 12)] = 7
        label[np.where(label == 14)] = 7
        # 10 bookshelf(shelves) 15
        label[np.where(label == 15)] = 10
        # 13 blinds(curtain) 16
        label[np.where(label == 16)] = 13
        # total: 19 + 4 + 14 = 37
        # ======================
        # use desk, not use table, counter
        # for idx in [7, 12]:
        #     label[np.where(label == idx)] = 0

        # 0,1,2,3 no change
        label[np.where(label == 5)] = 4
        label[np.where(label == 6)] = 5
        label[np.where(label == 7)] = 6
        label[np.where(label == 8)] = 7
        label[np.where(label == 9)] = 8
        label[np.where(label == 10)] = 9
        label[np.where(label == 13)] = 10
        label[np.where(label == 22)] = 11
        label[np.where(label == 25)] = 12
        label[np.where(label == 29)] = 13
        label[np.where(label == 31)] = 14

        return label
--------------------------------------------------------------------------------
/model/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.
    - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
    call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
    and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
    passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.
        Args:
            identifier: an identifier, usually is the device id.
        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages were first collected from each device (including the master device), and then
        a callback will be invoked to compute the message to be sent back to each device
        (including the master device).
        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)

--------------------------------------------------------------------------------
/onnx/onnx_utils.py:
--------------------------------------------------------------------------------
import torch
import torch.onnx
import onnxruntime as ort  # rt: runtime
from model.bisenet import BiSeNet
from utils.misc import load_state_dict
import onnx

"""
https://pytorch.apachecn.org/docs/1.0/onnx.html
limitations, supported operators
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
"""

"""ONNX faults
ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11
ONNX export failed: Couldn't export operator aten::upsample_bilinear2d
ONNX export failed on upsample_bilinear2d because align_corners == True not supported
"""


def cvt_onnx(model, onnx_path):
    model.eval()

    # Input to the model: define the variable parameters of the ONNX model
    batch_size = 1
    input_h, input_w = 512, 512

    # x dummy input: can be random as long as it is the right type and size.
    # Note: the input size will be fixed in the exported ONNX graph for all the input's dimensions, unless specified as dynamic axes.
    # use dynamic axes (B, H, W) so the ONNX graph can accept inputs of varying shape
    x = torch.randn(batch_size, 3, 512, 512, requires_grad=True)
    out = model(x)

    torch.onnx.export(
        model,
        x,  # model input, (or a tuple for multiple inputs)
        onnx_path,  # file path
        verbose=True,  # print the conversion trace

        export_params=True,  # store the trained parameter weights inside the model file, default True
        opset_version=9,  # the ONNX version to export the model to, default 9
        do_constant_folding=True,  # whether to execute constant folding for optimization, default False

        input_names=['input'],  # the model's input names, there can be several
        output_names=['output'],  # the model's output names
        dynamic_axes={  # dynamic dimensions
            'input': {  # variable length axes
                0: 'batch_size',
                # 2: 'input_h',
                # 3: 'input_w',
            },
            'output': {
                0: 'batch_size',
                # 2: 'output_h',
                # 3: 'output_w',
            }
        }
    )

    check_onnx_model(onnx_path)


def check_onnx_model(onnx_path):
    # Load the ONNX model
    model = onnx.load(onnx_path)  # load the saved model and output an onnx.ModelProto structure
    print('load', onnx_path)

    # Check that the IR is well formed
    try:
        onnx.checker.check_model(model)  # verify the model's structure and confirm that the model has a valid schema.
        print('check pass!')
    except onnx.checker.ValidationError as e:  # was `finally`, which always ran and printed a spurious error
        print('check error!', e)

    # Print a human readable representation of the graph
    print(onnx.helper.printable_graph(model.graph))  # printable_graph returns a string; it has to be printed to show anything


def onnx_infer(onnx_path):
    ort_session = ort.InferenceSession(onnx_path)


def demo_res18():
    from torchvision.models.resnet import resnet18

    model = resnet18(pretrained=False).eval()

    batch_size = 1
    input_h, input_w = 512, 512
    x = torch.rand((batch_size, 3, input_h, input_w))
    out = model(x)
    # print(out.shape)  # 1,1000; thanks to avgpool it is independent of the input size

    onnx_path = 'onnx/res18.onnx'

    torch.onnx.export(
        model,
        x,
        onnx_path,
        verbose=True,

        export_params=True,  # store the trained parameter weights inside the model file, default True
        opset_version=9,  # the ONNX version to export the model to, default 9
        do_constant_folding=True,  # whether to execute constant folding for optimization, default False

        input_names=['input'],  # the model's input names, there can be several
        output_names=['output'],  # the model's output names
        dynamic_axes={  # dynamic dimensions
            'input': {  # variable length axes
                0: 'batch_size',
                2: 'input_h',
                3: 'input_w',
            },
            'output': {
                0: 'batch_size',
            }
        }
    )

    check_onnx_model(onnx_path)


if __name__ == '__main__':
    # model = BiSeNet(37, context_path='resnet18', in_planes=32)
    # load_state_dict(model, ckpt_path='runs/SUNRGBD/res18_inp32_deconv_Jul27_100319/checkpoint.pth.tar')
    # onnx_path = 'onnx/res18_inp32_deconv_Jul27_100319.onnx'
    # cvt_onnx(model, onnx_path)

    # todo: demo resnet
    demo_res18()

    # model = BiSeNet(37, context_path='resnet101', in_planes=64)
    # load_state_dict(model, ckpt_path='runs/SUNRGBD/res101_inp64_deconv_Jul26_205859/checkpoint.pth.tar')
    # cvt_onnx(model, onnx_path='onnx/res101_inp64_deconv_Jul26_205859.onnx')
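The `onnx_infer` stub above only builds the session. A sketch of how it could be completed with onnxruntime, assuming the `'input'` name chosen during export and the 512x512 dummy shape used above:

```py
import numpy as np
import onnxruntime as ort

def onnx_infer(onnx_path):
    ort_session = ort.InferenceSession(onnx_path)
    x = np.random.rand(1, 3, 512, 512).astype(np.float32)  # dummy input in the exported layout
    outputs = ort_session.run(None, {'input': x})          # 'input' matches input_names above
    return outputs[0]
```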
size=(512, 512), mode='bilinear', align_corners=True) 48 | pred = torch.argmax(pred, dim=1).squeeze().cpu().numpy() 49 | target = target.numpy() 50 | evaluator.add_batch(target, pred) 51 | 52 | print('PixelAcc:', evaluator.Pixel_Accuracy()) 53 | 54 | print('mAcc') # 各类的 acc 均值 55 | Accs = evaluator.Acc_of_each_class() 56 | print(np.nanmean(Accs)) # mAcc, mean of non-NaN elements 57 | approx_print(Accs) 58 | 59 | print('mIoU') 60 | IOUs = evaluator.IOU_of_each_class() 61 | print(np.nanmean(IOUs)) # mIoU 62 | approx_print(IOUs) 63 | 64 | 65 | results = { 66 | 'res18(pi)': { 67 | 'acc': [80.85, 87.48, 56.45, 63.27, 78.97, 43.23, 55.04, 43.91, 67.55, 39.17, 59.6, 37.32, 34.79, 22.28, 11.5, 55.31, 45.36, 40.39, 47.57, 0.0, 29.44, 68 | 81.86, 45.62, 28.09, 46.84, 30.37, 14.32, 0.0, 30.08, 62.07, 19.18, 22.79, 78.56, 59.2, 37.63, 52.47, 20.61], 69 | 'iou': [69.11, 80.0, 37.21, 53.74, 51.76, 35.57, 40.81, 27.64, 41.49, 23.95, 38.55, 28.27, 25.08, 16.26, 6.74, 37.46, 32.96, 25.95, 32.54, 0.0, 13.41, 70 | 58.53, 24.09, 22.6, 31.84, 17.78, 11.31, 0.0, 12.1, 46.05, 9.91, 13.28, 51.49, 39.55, 20.73, 38.68, 9.52], 71 | 'Acc': 69.58, 72 | 'mIoU': 30.43 73 | }, 74 | 'res101': { 75 | 'acc': [86.04, 93.08, 69.26, 72.1, 85.25, 58.81, 66.68, 58.51, 72.58, 47.83, 74.31, 43.0, 41.78, 32.6, 20.34, 65.79, 59.12, 56.35, 58.64, 0.02, 45.64, 76 | 82.94, 59.95, 52.79, 75.8, 50.62, 41.36, 1.61, 48.72, 68.63, 63.91, 31.98, 87.75, 70.76, 63.7, 69.97, 44.99 77 | ], 78 | 'iou': [76.74, 87.19, 46.49, 65.72, 66.73, 49.55, 50.68, 41.15, 50.87, 35.88, 50.18, 34.82, 27.59, 23.69, 11.3, 49.93, 41.08, 36.79, 38.42, 0.01, 24.88, 79 | 67.75, 35.37, 45.1, 53.87, 27.59, 30.64, 1.38, 27.39, 58.07, 45.7, 18.68, 68.92, 53.55, 34.65, 50.06, 20.75], 80 | 'Acc': 77.59, 81 | 'mIoU': 41.87 82 | } 83 | } 84 | 85 | 86 | def plt_class_evals(arch): 87 | label_names, label_colors = data_cfg['SUNRGBD']['label_name_colors'] 88 | label_names, label_colors = label_names[1:], label_colors[1:] 89 | 90 | xs = np.arange(len(label_names)) 91 | 92 | accs = results[arch]['acc'] 93 | ious = results[arch]['iou'] 94 | 95 | plt.figure(figsize=(14, 5), dpi=100) 96 | 97 | width = 0.4 98 | fontsize = 8 99 | rotation = 0 100 | 101 | for idx, (x, y) in enumerate(zip(xs, accs)): 102 | plt.bar(x - 0.2, y, width=width, align='center', # 底部 tick 对应位置 103 | linewidth=1, edgecolor=[0.7, 0.7, 0.7], 104 | color=[a / 255.0 for a in label_colors[idx]]) 105 | plt.text(x - 0.2, y + 0.2, 106 | s='%.2f' % y, 107 | rotation=rotation, 108 | ha='center', va='bottom', fontsize=fontsize) 109 | 110 | for idx, (x, y) in enumerate(zip(xs, ious)): 111 | plt.bar(x + 0.2, y, width=width, 112 | linewidth=1, edgecolor=[0.7, 0.7, 0.7], 113 | color=[a / 255.0 for a in label_colors[idx]]) 114 | plt.text(x + 0.2, y + 0.2, 115 | s='%.2f' % y, 116 | rotation=rotation, 117 | ha='center', va='bottom', fontsize=fontsize) 118 | 119 | plt.xticks(xs, label_names, size='small', rotation=60) 120 | plt.ylim([0, 100]) 121 | plt.title(f"{arch}(pi). 
Acc & IoU of SUNRGBD-37class testset (5050) | Acc: {results[arch]['Acc']}, mIoU: {results[arch]['mIoU']}") 122 | plt.show() 123 | 124 | 125 | if __name__ == '__main__': 126 | arch = 'res18' 127 | # evals(arch) 128 | plt_class_evals(arch) 129 | -------------------------------------------------------------------------------- /onnx/onnx_demo.py: -------------------------------------------------------------------------------- 1 | # Super Resolution model definition in PyTorch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch 5 | import torch.utils.model_zoo as model_zoo 6 | import onnx 7 | import onnxruntime as ort 8 | import numpy as np 9 | 10 | 11 | class SuperResolutionNet(nn.Module): 12 | def __init__(self, upscale_factor, inplace=False): 13 | super(SuperResolutionNet, self).__init__() 14 | 15 | self.relu = nn.ReLU(inplace=inplace) 16 | self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) 17 | self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 18 | self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 19 | self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1)) 20 | self.pixel_shuffle = nn.PixelShuffle(upscale_factor) 21 | 22 | self._initialize_weights() 23 | 24 | def forward(self, x): 25 | x = self.relu(self.conv1(x)) 26 | x = self.relu(self.conv2(x)) 27 | x = self.relu(self.conv3(x)) 28 | x = self.pixel_shuffle(self.conv4(x)) 29 | return x 30 | 31 | def _initialize_weights(self): 32 | init.orthogonal_(self.conv1.weight, init.calculate_gain('relu')) 33 | init.orthogonal_(self.conv2.weight, init.calculate_gain('relu')) 34 | init.orthogonal_(self.conv3.weight, init.calculate_gain('relu')) 35 | init.orthogonal_(self.conv4.weight) 36 | 37 | 38 | def load_torch_model(): 39 | torch_model = SuperResolutionNet(upscale_factor=3) 40 | 41 | # Load pretrained model weights 42 | model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth' 43 | 44 | # Initialize model with the pretrained weights 45 | map_location = lambda storage, loc: storage 46 | if torch.cuda.is_available(): 47 | map_location = None 48 | torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location)) 49 | print('load', model_url) 50 | return torch_model.eval() # set the model to inference mode 51 | 52 | 53 | def cvt_onnx(x, onnx_path): 54 | # Create the super-resolution model by using the above model definition. 
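    # note: torch.onnx.export below works by tracing: the model is run once on
    # the example input and the executed ops are recorded. dynamic_axes only
    # marks tensor dimensions as symbolic; data-dependent Python control flow
    # is frozen into the traced graph.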
55 |     torch_model = load_torch_model()
56 |     # Input to the model
57 |     torch_out = torch_model(x)
58 | 
59 |     # Export the model
60 |     torch.onnx.export(
61 |         torch_model,  # model being run
62 |         x,  # model input (or a tuple for multiple inputs)
63 |         onnx_path,  # where to save the model (can be a file or file-like object)
64 |         export_params=True,  # store the trained parameter weights inside the model file
65 |         opset_version=10,  # the ONNX version to export the model to
66 |         do_constant_folding=True,  # whether to execute constant folding for optimization
67 |         input_names=['img'],  # the model's input names
68 |         output_names=['predict'],  # the model's output names
69 |         dynamic_axes={
70 |             # keys here must match the names in input_names/output_names
71 |             'img': {
72 |                 0: 'batch_size',
73 |                 # 2: 'input_h', 3: 'input_w',  # this conv-only net tolerates other input sizes even without these
74 |             },  # refers to the dims of x declared dynamic
75 |             'predict': {
76 |                 0: 'batch_size',
77 |                 # 2: 'input_h', 3: 'input_w'
78 |             }
79 |         })
80 | 
81 |     # check
82 |     # onnx_model = onnx.load(onnx_path)
83 |     # onnx.checker.check_model(onnx_model)  # the checker may still die with exit code 139
84 |     print('convert done!')
85 | 
86 |     return torch_out
87 | 
88 | 
89 | def to_numpy(tensor):
90 |     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
91 | 
92 | 
93 | def onnx_infer(x, onnx_path):
94 |     ort_session = ort.InferenceSession(onnx_path)
95 | 
96 |     # compute ONNX Runtime output prediction
97 |     ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
98 |     ort_outs = ort_session.run(None, ort_inputs)  # run to get the onnx outputs
99 | 
100 |     # print(ort_session.get_inputs()[0].name)  # 'img', matches input_names in onnx.export
101 |     # print(ort_session.get_outputs()[0].name)  # 'predict', matches output_names in onnx.export
102 | 
103 |     return ort_outs  # list
104 | 
105 | 
106 | def model_cvt_and_check(onnx_path='onnx/super_resolution.onnx'):
107 |     batch_size = 1
108 |     input_h, input_w = 224, 224
109 |     x = torch.randn(batch_size, 1, input_h, input_w, requires_grad=False)  # batch_size is the dynamic axis
110 |     torch_out = cvt_onnx(x, onnx_path)
111 |     ort_outs = onnx_infer(x, onnx_path)  # list
112 | 
113 |     # compare ONNX Runtime and PyTorch results
114 |     np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)  # assert the outputs are close
115 |     print("Exported model has been tested with ONNXRuntime, and the result looks good!")
116 | 
117 | 
118 | def cvt_img_y_to_rgb(img_y):  # convert the Y channel back to RGB; img_cb/img_cr are read from module scope (set in __main__ below)
119 |     img_out_y = Image.fromarray(np.uint8((img_y * 255.0).clip(0, 255)), mode='L')
120 | 
121 |     # get the output image following the post-processing step of the PyTorch implementation
122 |     final_img = Image.merge(
123 |         "YCbCr", [
124 |             img_out_y,
125 |             img_cb.resize(img_out_y.size, Image.BICUBIC),  # upsample the other channels
126 |             img_cr.resize(img_out_y.size, Image.BICUBIC),
127 |         ]).convert("RGB")
128 | 
129 |     return final_img
130 | 
131 | 
132 | if __name__ == '__main__':
133 |     from PIL import Image
134 |     import torchvision.transforms as transforms
135 |     import matplotlib.pyplot as plt
136 | 
137 |     img = Image.open("img/cat.jpg")
138 | 
139 |     resize = transforms.Resize([300, 300])  # test a non-default input size
140 |     img = resize(img)
141 | 
142 |     img_ycbcr = img.convert('YCbCr')
143 |     img_y, img_cb, img_cr = img_ycbcr.split()  # split into single-channel images
144 | 
145 |     to_tensor = transforms.ToTensor()
146 |     img_y = to_tensor(img_y)
147 |     img_y.unsqueeze_(0)
148 | 
149 |     onnx_path = 'onnx/super_resolution.onnx'
150 |     torch_out = cvt_onnx(img_y, onnx_path)
151 |     onnx_out = onnx_infer(img_y, onnx_path)[0]  # returns a list of numpy arrays
152 | 
153 |     torch_out = cvt_img_y_to_rgb(to_numpy(torch_out).squeeze())
154 |     onnx_out = cvt_img_y_to_rgb(onnx_out.squeeze())  # colors come out slightly darker
155 | 
156 |     f, ax = plt.subplots(1, 2, figsize=(10, 5))
157 |     ax[0].imshow(torch_out)
158 |     ax[0].set_title('torch_out')
159 |     ax[1].imshow(onnx_out)
160 |     ax[1].set_title('onnx_out')
161 |     plt.show()
162 | 
--------------------------------------------------------------------------------
/argument_parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | 
4 | 
5 | def parse_args(params=None):
6 |     parser = argparse.ArgumentParser(description="BiSeNet")
7 | 
8 |     # model
9 |     parser.add_argument('--context_path', type=str, default='resnet50',
10 |                         choices=['resnet18', 'resnet50', 'resnet101'],
11 |                         help='backbone name (default: resnet50)')
12 |     parser.add_argument('--in_planes', type=int, default=64,
13 |                         help='resnet in planes (default: 64)')
14 |     parser.add_argument('--dataset', type=str, default='SUNRGBD',
15 |                         help='dataset name (default: SUNRGBD)')
16 |     # default size
17 |     parser.add_argument('--base-size', type=int, default=513,
18 |                         help='base image size')
19 |     parser.add_argument('--crop-size', type=int, default=513,
20 |                         help='crop image size')
21 | 
22 |     parser.add_argument('--sync-bn', type=bool, default=False,  # multi gpu
23 |                         help='whether to use sync bn (default: False)')
24 |     parser.add_argument('--loss-type', type=str, default='ce',
25 |                         help='loss func type (default: ce)')
26 |     parser.add_argument('--workers', type=int, default=4,
27 |                         metavar='N', help='dataloader threads')
28 |     # gpu
29 |     parser.add_argument('--gpu-ids', type=str, default='0',
30 |                         help='use which gpu to train, must be a \
31 |                         comma-separated list of integers only (default=0)')
32 | 
33 |     # training hyper params
34 |     parser.add_argument('--epochs', type=int, default=None, metavar='N',
35 |                         help='number of epochs to train (default: auto)')
36 |     parser.add_argument('--iters_per_epoch', type=int, default=None,
37 |                         help='iterations per epoch')
38 |     parser.add_argument('--warmup-epochs', type=int, default=0, metavar='N',
39 |                         help='number of warmup epochs (default: 0)')
40 |     parser.add_argument('--batch-size', type=int, default=4,
41 |                         metavar='N', help='input batch size for training (default: auto)')
42 |     parser.add_argument('--use-balanced-weights', action='store_true', default=False,
43 |                         help='whether to use balanced class weights (default: False)')
44 | 
45 |     # distill params
46 |     parser.add_argument("--pi", action='store_true', default=False, help="use the pixel-wise distillation loss")
47 |     parser.add_argument("--pa", action='store_true', default=False, help="use the pair-wise distillation loss")
48 |     parser.add_argument("--ho", action='store_true', default=False, help="use the holistic distillation loss")
49 |     parser.add_argument("--lr-g", type=float, default=1e-2, help="learning rate for G")
50 |     parser.add_argument("--lr-d", type=float, default=4e-4, help="learning rate for D")
51 |     parser.add_argument("--lambda-gp", type=float, default=10.0, help="lambda_gp")
52 |     parser.add_argument("--lambda-d", type=float, default=0.1, help="lambda_d")
53 |     parser.add_argument("--lambda-pi", type=float, default=10.0, help="lambda_pi")
54 |     parser.add_argument('--lambda-pa', default=1.0, type=float, help='lambda_pa')
55 | 
56 |     # optimizer params
57 |     parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'Adam'])
58 |     parser.add_argument('--lr', type=float, default=2e-3, help='learning rate (default: auto)')
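    # note: the poly/step/cos schedule selected below is only stepped when
    # --optimizer SGD is chosen (see utils/trainer.py); with Adam the base
    # --lr stays fixed for the whole run.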
59 |     parser.add_argument('--momentum', type=float, default=0.9,
60 |                         metavar='M', help='momentum (default: 0.9)')
61 |     parser.add_argument('--weight-decay', type=float, default=5e-4,  # todo: 1e-4 seemed worse on the test set
62 |                         metavar='M', help='w-decay (default: 5e-4)')
63 |     parser.add_argument('--nesterov', action='store_true', default=False,
64 |                         help='whether to use nesterov (default: False)')
65 | 
66 |     parser.add_argument('--lr-scheduler', type=str, default='poly',
67 |                         choices=['step', 'poly', 'cos'],
68 |                         help='lr scheduler mode: (default: poly)')
69 |     parser.add_argument('--lr-step', type=str, default='35', help='step size for lr-step-scheduler')
70 | 
71 |     # seed
72 |     parser.add_argument('--seed', type=int, default=-1, metavar='S',
73 |                         help='random seed (default: -1)')
74 |     # checkpoint
75 |     parser.add_argument('--checkname', type=str, default=None,
76 |                         help='set the checkpoint name')
77 |     parser.add_argument('--resume', action='store_true', default=False,
78 |                         help='whether to resume training')
79 | 
80 |     # evaluation option
81 |     parser.add_argument('--eval-interval', type=int, default=5,
82 |                         help='evaluation interval (default: 5) - run validation every Nth epoch')
83 | 
84 |     # todo: active
85 |     # parser.add_argument('--active-selection-mode', type=str, default='random',
86 |     #                     choices=['random',
87 |     #                              'entropy',
88 |     #                              'error_mask',
89 |     #                              'dropout',
90 |     #                              'coreset'])
91 |     # parser.add_argument('--max-iterations', type=int, default=9,
92 |     #                     help='max active iterations')
93 |     # parser.add_argument('--init-percent', type=int, default=None,
94 |     #                     help='init label data percent')
95 |     # parser.add_argument('--percent-step', type=int,
96 |     #                     help='incremental label data percent (default: 5)')
97 |     # parser.add_argument('--select-num', type=int,
98 |     #                     help='incremental label data percent')
99 |     # parser.add_argument('--hard-levels', type=int, default=9,
100 |     #                     help='incremental label data percent')
101 |     # parser.add_argument('--strategy', type=str, default='diff_score',
102 |     #                     choices=['diff_score', 'diff_entropy'],
103 |     #                     help='error mask strategy')
104 | 
105 |     args = parser.parse_args(params)
106 | 
107 |     # manual seeding
108 |     # if args.seed == -1:
109 |     #     args.seed = int(random.random() * 2000)
110 |     # print('Using random seed =', args.seed)
111 |     # print('ActiveSelector:', args.active_selection_mode)
112 | 
113 |     return args
--------------------------------------------------------------------------------
/model/backbone/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 | import torch.utils.model_zoo as model_zoo
6 | import constants
7 | 
8 | 
9 | def conv_bn(inp, oup, stride, BatchNorm):
10 |     return nn.Sequential(
11 |         nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
12 |         BatchNorm(oup),
13 |         nn.ReLU6(inplace=True)
14 |     )
15 | 
16 | 
17 | def fixed_padding(inputs, kernel_size, dilation):
18 |     kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
19 |     pad_total = kernel_size_effective - 1
20 |     pad_beg = pad_total // 2
21 |     pad_end = pad_total - pad_beg
22 |     padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
23 |     return padded_inputs
24 | 
25 | 
26 | class InvertedResidual(nn.Module):
27 | 
28 |     def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
29 |         super(InvertedResidual, self).__init__()
30 |         self.stride = stride
31 |         assert stride in [1,
2] 32 | 33 | hidden_dim = round(inp * expand_ratio) 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | self.kernel_size = 3 36 | self.dilation = dilation 37 | 38 | if expand_ratio == 1: 39 | self.conv = nn.Sequential( 40 | # dw 41 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 42 | BatchNorm(hidden_dim), 43 | nn.ReLU6(inplace=True), 44 | # pw-linear 45 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 46 | BatchNorm(oup), 47 | ) 48 | else: 49 | self.conv = nn.Sequential( 50 | # pw 51 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 52 | BatchNorm(hidden_dim), 53 | nn.ReLU6(inplace=True), 54 | # dw 55 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 56 | BatchNorm(hidden_dim), 57 | nn.ReLU6(inplace=True), 58 | # pw-linear 59 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 60 | BatchNorm(oup), 61 | ) 62 | 63 | def forward(self, x): 64 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 65 | if self.use_res_connect: 66 | x = x + self.conv(x_pad) 67 | else: 68 | x = self.conv(x_pad) 69 | return x 70 | 71 | 72 | class MobileNetV2(nn.Module): 73 | 74 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True, mc_dropout=False): 75 | super(MobileNetV2, self).__init__() 76 | block = InvertedResidual 77 | input_channel = 32 78 | current_stride = 1 79 | rate = 1 80 | interverted_residual_setting = [ 81 | # t, c, n, s 82 | [1, 16, 1, 1], 83 | [6, 24, 2, 2], 84 | [6, 32, 3, 2], 85 | [6, 64, 4, 2], 86 | [6, 96, 3, 1], 87 | [6, 160, 3, 2], 88 | [6, 320, 1, 1], 89 | ] 90 | 91 | # building first layer 92 | input_channel = int(input_channel * width_mult) 93 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 94 | current_stride *= 2 95 | # building inverted residual blocks 96 | for t, c, n, s in interverted_residual_setting: 97 | if current_stride == output_stride: 98 | stride = 1 99 | dilation = rate 100 | rate *= s 101 | else: 102 | stride = s 103 | dilation = 1 104 | current_stride *= s 105 | output_channel = int(c * width_mult) 106 | for i in range(n): 107 | if i == 0: 108 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 109 | else: 110 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 111 | input_channel = output_channel 112 | 113 | if mc_dropout: # last features, for MC train 114 | self.features.append(nn.Dropout2d(p=constants.MC_DROPOUT_RATE)) 115 | 116 | self.features = nn.Sequential(*self.features) 117 | self._initialize_weights() 118 | 119 | if pretrained: 120 | self._load_pretrained_model() 121 | 122 | # 直接截取到子 Sequential model 123 | self.low_level_features = self.features[0:4] 124 | self.high_level_features = self.features[4:] 125 | self.dropout = nn.Dropout2d(p=constants.MC_DROPOUT_RATE) # for MC test 126 | self.mc_dropout = mc_dropout 127 | 128 | def forward(self, x): 129 | low_level_feat = self.low_level_features(x) 130 | x = self.high_level_features(low_level_feat) 131 | 132 | if self.mc_dropout: 133 | low_level_feat = self.dropout(low_level_feat) 134 | 135 | return x, low_level_feat 136 | 137 | def _load_pretrained_model(self): 138 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 139 | model_dict = {} 140 | state_dict = self.state_dict() 141 | for k, v in pretrain_dict.items(): 142 | if k in state_dict: 143 | model_dict[k] = v 144 | state_dict.update(model_dict) 145 | self.load_state_dict(state_dict) 
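        # note: the dict filtering above keeps only the checkpoint keys that
        # exist in this backbone, then merges them into the model's own
        # state_dict, so the ImageNet weights load cleanly even though the
        # checkpoint's classifier head has no counterpart here (a strict
        # load_state_dict would raise on those keys).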
146 | 147 | def _initialize_weights(self): 148 | for m in self.modules(): 149 | if isinstance(m, nn.Conv2d): 150 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 151 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 152 | torch.nn.init.kaiming_normal_(m.weight) 153 | elif isinstance(m, SynchronizedBatchNorm2d): 154 | m.weight.data.fill_(1) 155 | m.bias.data.zero_() 156 | elif isinstance(m, nn.BatchNorm2d): 157 | m.weight.data.fill_(1) 158 | m.bias.data.zero_() 159 | 160 | 161 | if __name__ == "__main__": 162 | input = torch.rand(1, 3, 360, 480) 163 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 164 | 165 | output, low_level_feat = model(input) 166 | print(output.size()) 167 | print(low_level_feat.size()) 168 | -------------------------------------------------------------------------------- /model/bisenet_up.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import models 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from model.backbone.resnet import build_contextpath 6 | 7 | """ 8 | F.upsample / cpu_resize 9 | """ 10 | 11 | 12 | class ConvBlock(nn.Module): 13 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1): 14 | super().__init__() 15 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False) 16 | self.bn = nn.BatchNorm2d(out_channels) 17 | self.relu = nn.ReLU() 18 | 19 | def forward(self, x): 20 | return self.relu(self.bn(self.conv(x))) # CBR 21 | 22 | 23 | class Spatial_path(nn.Module): 24 | def __init__(self, out_channels): # 1/2 chans 25 | super().__init__() 26 | self.conv1 = ConvBlock(in_channels=3, out_channels=out_channels // 4) 27 | self.conv2 = ConvBlock(in_channels=out_channels // 4, out_channels=out_channels // 2) 28 | self.conv3 = ConvBlock(in_channels=out_channels // 2, out_channels=out_channels) 29 | 30 | def forward(self, x): 31 | return self.conv3(self.conv2(self.conv1(x))) 32 | 33 | 34 | # ARM, channel attention 35 | class AttentionRefinementModule(nn.Module): 36 | def __init__(self, in_channels, out_channels): 37 | super().__init__() 38 | self.gap = nn.AdaptiveAvgPool2d((1, 1)) # GAP 39 | self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) # in = out 40 | self.bn = nn.BatchNorm2d(out_channels) 41 | self.sigmoid = nn.Sigmoid() 42 | 43 | def forward(self, x): 44 | att = self.sigmoid(self.bn(self.conv1x1(self.gap(x)))) # channel attention 45 | out = torch.mul(x, att) # 没有将 input 加入 46 | return out 47 | 48 | 49 | import skimage.transform 50 | import numpy as np 51 | 52 | 53 | def cpu_resize(x, target_size): 54 | assert isinstance(x, torch.Tensor), 'x must be tensor' 55 | 56 | B, C, _, _ = x.size() 57 | y = np.zeros((B, target_size[0], target_size[1], C)) 58 | x = x.detach().permute((0, 2, 3, 1)).cpu().numpy() # B,H,W,C 59 | 60 | for i in range(B): 61 | y[i] = skimage.transform.resize(x[i], target_size, order=1, # bilinear 62 | anti_aliasing=True, preserve_range=True) # 保持原始值范围 63 | 64 | y = torch.from_numpy(y).permute((0, 3, 1, 2)).float().cuda() 65 | 66 | return y 67 | 68 | 69 | # FFM 70 | # todo: feature fusion 采用 self-attention, 直达 loss 71 | class FeatureFusionModule(nn.Module): 72 | def __init__(self, in_channels, num_classes): # out_channels = num_classes 73 | super().__init__() 74 | self.cbr = ConvBlock(in_channels, num_classes, stride=1) # todo: 特征压缩太少了? 
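        # note: the layers below form a squeeze-and-excitation style channel
        # gate over the compressed num_classes channels: GAP -> 1x1 conv ->
        # ReLU -> 1x1 conv -> sigmoid, applied to the fused feature in forward().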
75 | 76 | self.gap = nn.AdaptiveAvgPool2d((1, 1)) # GAP 77 | self.conv1x1_1 = nn.Conv2d(num_classes, num_classes, kernel_size=1, bias=False) 78 | self.bn = nn.BatchNorm2d(num_classes) 79 | self.relu = nn.ReLU() 80 | self.conv1x1_2 = nn.Conv2d(num_classes, num_classes, kernel_size=1, bias=False) 81 | self.sigmoid = nn.Sigmoid() 82 | 83 | def forward(self, sp, x1, x2): 84 | _, _, h, w = sp.shape 85 | x1 = F.interpolate(x1, size=(h, w), mode='bilinear', align_corners=True) # 最快 86 | x2 = F.interpolate(x2, size=(h, w), mode='bilinear', align_corners=True) 87 | # x1 = cpu_resize(x1, (h, w)) # 非常慢,速度比 cpu infer 稍高 88 | # x2 = cpu_resize(x2, (h, w)) 89 | 90 | x = torch.cat((sp, x1, x2), dim=1) # fusion feature 91 | x = self.cbr(x) # chans -> num_classes 92 | att = self.sigmoid(self.conv1x1_2(self.relu(self.conv1x1_1(self.gap(x))))) 93 | out = x + torch.mul(x, att) 94 | return out 95 | 96 | 97 | class BiSeNet(nn.Module): 98 | def __init__(self, num_classes, context_path, in_planes=64): 99 | super().__init__() 100 | 101 | # sp_chans = max(in_planes * 2, 64) # 最低要有 64 chan 102 | sp_chans = 128 # 维持为 128 尝试 103 | self.saptial_path = Spatial_path(sp_chans) 104 | self.context_path = build_contextpath(context_path, in_planes, pretrained=True) 105 | 106 | if context_path == 'resnet18': 107 | arm_chans = [in_planes * 4, in_planes * 8] 108 | ffm_chans = sum(arm_chans) + sp_chans 109 | elif context_path == 'resnet50' or context_path == 'resnet101': 110 | arm_chans = [in_planes * 4 * 4, in_planes * 8 * 4] # expansion=4 111 | ffm_chans = sum(arm_chans) + sp_chans 112 | else: 113 | raise NotImplementedError 114 | 115 | # middle features after attention 116 | self.arm1 = AttentionRefinementModule(arm_chans[0], arm_chans[0]) 117 | self.arm2 = AttentionRefinementModule(arm_chans[1], arm_chans[1]) 118 | 119 | # middle supervision 120 | self.mid1 = nn.Conv2d(arm_chans[0], num_classes, kernel_size=1) 121 | self.mid2 = nn.Conv2d(arm_chans[1], num_classes, kernel_size=1) 122 | 123 | self.ffm = FeatureFusionModule(ffm_chans, num_classes) 124 | self.last_conv = nn.Conv2d(num_classes, num_classes, kernel_size=1) 125 | 126 | def forward(self, x): 127 | sp = self.saptial_path(x) 128 | cx1, cx2 = self.context_path(x) 129 | cx1, cx2 = self.arm1(cx1), self.arm2(cx2) # gap 已经在 arm 中做了,没必要再乘 tail 130 | res = self.last_conv(self.ffm(sp, cx1, cx2)) # 1/8 131 | 132 | # 为了适应 teacher 模型输出结果 133 | 134 | # middle sup 应该加在 attention 之后, 使得 auxiliary loss 能辅助训练 arm 模块 135 | # 计算 loss 时,再 *8 upsample 136 | if self.training: # 使用 nn.Module 自带属性判断 training/eval 状态 137 | res1, res2 = self.mid1(cx1), self.mid2(cx2) # 1/16, 1/16 138 | return [res, res1, res2] # 1/8, 1/16, 1/16 139 | else: 140 | return res 141 | 142 | 143 | @torch.no_grad() 144 | def cmp_infer_time(test_num=20): 145 | import time 146 | import itertools 147 | 148 | # 首个 resnet50 预热 GPU 149 | archs = ['resnet50', 'resnet18', 'resnet50', 'resnet101'] 150 | inplanes = [16, 32, 64] 151 | 152 | x = torch.rand(1, 3, 512, 512) 153 | x = x.cuda() 154 | 155 | for arch, inp in itertools.product(archs, inplanes): # 笛卡儿积 156 | model = BiSeNet(37, context_path=arch, in_planes=inp) 157 | model.cuda() 158 | model.eval() 159 | 160 | torch.cuda.synchronize() 161 | t1 = time.time() 162 | for _ in range(test_num): 163 | model(x) 164 | t2 = time.time() 165 | torch.cuda.synchronize() 166 | 167 | t = (t2 - t1) / test_num 168 | fps = 1 / t 169 | 170 | print(f'{arch} - {inp} \t time: {t} \t fps: {fps}') 171 | 172 | 173 | if __name__ == '__main__': 174 | cmp_infer_time() 175 | 
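# note: in cmp_infer_time above, the second torch.cuda.synchronize() only runs
# after t2 has been read, so kernels still queued on the GPU are not counted in
# the measurement. A sketch of a stricter variant (assumes a CUDA device is
# available; the same pattern applies to cmp_infer_time in model/bisenet.py):


@torch.no_grad()
def cmp_infer_time_synced(model, x, test_num=20):
    import time
    model.eval()
    torch.cuda.synchronize()
    t1 = time.time()
    for _ in range(test_num):
        model(x)
    torch.cuda.synchronize()  # drain queued kernels before reading the clock
    t2 = time.time()
    return (t2 - t1) / test_num  # mean seconds per forward pass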
-------------------------------------------------------------------------------- /model/bisenet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | this_dir = os.path.dirname(__file__) 5 | sys.path.insert(0, os.path.join(this_dir, '..')) # 添加项目目录,python2 不支持中文 6 | 7 | import torch 8 | from torchvision import models 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from model.backbone.resnet import build_contextpath 12 | 13 | 14 | class ConvBlock(nn.Module): 15 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1): 16 | super().__init__() 17 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False) 18 | self.bn = nn.BatchNorm2d(out_channels) 19 | self.relu = nn.ReLU() 20 | 21 | def forward(self, x): 22 | return self.relu(self.bn(self.conv(x))) # CBR 23 | 24 | 25 | class DeconvBlock(nn.Module): 26 | def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1): 27 | super().__init__() 28 | self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False) 29 | self.bn = nn.BatchNorm2d(out_channels) 30 | self.relu = nn.ReLU() 31 | 32 | def forward(self, x): 33 | return self.relu(self.bn(self.deconv(x))) # CBR 34 | 35 | 36 | class Spatial_path(nn.Module): 37 | def __init__(self, out_channels): # 1/2 chans 38 | super().__init__() 39 | self.conv1 = ConvBlock(in_channels=3, out_channels=out_channels // 4) 40 | self.conv2 = ConvBlock(in_channels=out_channels // 4, out_channels=out_channels // 2) 41 | self.conv3 = ConvBlock(in_channels=out_channels // 2, out_channels=out_channels) 42 | 43 | def forward(self, x): 44 | return self.conv3(self.conv2(self.conv1(x))) 45 | 46 | 47 | # ARM, channel attention 48 | class AttentionRefinementModule(nn.Module): 49 | def __init__(self, in_channels, out_channels): 50 | super().__init__() 51 | self.gap = nn.AdaptiveAvgPool2d((1, 1)) # GAP 52 | self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) # in = out 53 | self.bn = nn.BatchNorm2d(out_channels) 54 | self.sigmoid = nn.Sigmoid() 55 | 56 | def forward(self, x): 57 | att = self.sigmoid(self.bn(self.conv1x1(self.gap(x)))) # channel attention 58 | out = torch.mul(x, att) # 没有将 input 加入 59 | return out 60 | 61 | 62 | import skimage.transform 63 | import numpy as np 64 | 65 | 66 | def cpu_resize(x, target_size): 67 | assert isinstance(x, torch.Tensor), 'x must be tensor' 68 | 69 | B, C, _, _ = x.size() 70 | y = np.zeros((B, target_size[0], target_size[1], C)) 71 | x = x.detach().cpu().permute((0, 2, 3, 1)).numpy() # B,H,W,C 72 | 73 | for i in range(B): 74 | y[i] = skimage.transform.resize(x[i], target_size, order=1, # bilinear 75 | anti_aliasing=True, preserve_range=True) # 保持原始值范围 76 | 77 | y = torch.from_numpy(y).permute((0, 3, 1, 2)).float().cuda() 78 | 79 | return y 80 | 81 | 82 | # FFM 83 | # todo: feature fusion 采用 self-attention, 直达 loss 84 | class FeatureFusionModule(nn.Module): 85 | def __init__(self, in_channels, num_classes): # out_channels = num_classes 86 | super().__init__() 87 | self.cbr = ConvBlock(in_channels, num_classes, stride=1) # todo: 特征压缩太少了? 
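        # note: forward() below fuses sp/cx1/cx2 by concatenation, compresses
        # to num_classes channels, then computes out = x + x * att, a residual
        # channel gate: the sigmoid attention rescales features rather than
        # replacing them.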
88 | 89 | self.gap = nn.AdaptiveAvgPool2d((1, 1)) # GAP 90 | self.conv1x1_1 = nn.Conv2d(num_classes, num_classes, kernel_size=1, bias=False) 91 | self.bn = nn.BatchNorm2d(num_classes) 92 | self.relu = nn.ReLU() 93 | self.conv1x1_2 = nn.Conv2d(num_classes, num_classes, kernel_size=1, bias=False) 94 | self.sigmoid = nn.Sigmoid() 95 | 96 | def forward(self, sp, x1, x2): 97 | # _, _, h, w = sp.shape 98 | # x1 = F.interpolate(x1, size=(h, w), mode='bilinear', align_corners=True) # 最快 99 | # x2 = F.interpolate(x2, size=(h, w), mode='bilinear', align_corners=True) 100 | # x1 = cpu_resize(x1, (h, w)) # 耗时 101 | # x2 = cpu_resize(x2, (h, w)) 102 | 103 | x = torch.cat((sp, x1, x2), dim=1) # fusion feature 104 | x = self.cbr(x) # chans -> num_classes 105 | att = self.sigmoid(self.conv1x1_2(self.relu(self.conv1x1_1(self.gap(x))))) 106 | out = x + torch.mul(x, att) 107 | return out 108 | 109 | 110 | class BiSeNet(nn.Module): 111 | def __init__(self, num_classes, context_path, in_planes=64): 112 | super().__init__() 113 | 114 | # sp_chans = max(in_planes * 2, 64) # 最低要有 64 chan 115 | sp_chans = 128 # 维持为 128 尝试 116 | self.saptial_path = Spatial_path(sp_chans) 117 | self.context_path = build_contextpath(context_path, in_planes, pretrained=True) 118 | 119 | if context_path == 'resnet18': 120 | arm_chans = [in_planes * 4, in_planes * 8] 121 | ffm_chans = sum(arm_chans) + sp_chans 122 | elif context_path == 'resnet50' or context_path == 'resnet101': 123 | arm_chans = [in_planes * 4 * 4, in_planes * 8 * 4] # expansion=4 124 | ffm_chans = sum(arm_chans) + sp_chans 125 | else: 126 | raise NotImplementedError 127 | 128 | # middle features after attention 129 | self.arm1 = AttentionRefinementModule(arm_chans[0], arm_chans[0]) 130 | self.arm2 = AttentionRefinementModule(arm_chans[1], arm_chans[1]) 131 | 132 | # deconv for ffm 133 | self.deconv1 = DeconvBlock(arm_chans[0], arm_chans[0], kernel_size=4, stride=2, padding=1) # x2 134 | self.deconv2 = DeconvBlock(arm_chans[1], arm_chans[1], kernel_size=4, stride=2, padding=1) # pad 减小 input size 135 | 136 | # middle supervision 137 | self.mid1 = nn.Conv2d(arm_chans[0], num_classes, kernel_size=1) 138 | self.mid2 = nn.Conv2d(arm_chans[1], num_classes, kernel_size=1) 139 | 140 | self.ffm = FeatureFusionModule(ffm_chans, num_classes) 141 | self.last_conv = nn.Conv2d(num_classes, num_classes, kernel_size=1) 142 | 143 | def forward(self, x): 144 | sp = self.saptial_path(x) 145 | cx1, cx2 = self.context_path(x) 146 | cx1, cx2 = self.arm1(cx1), self.arm2(cx2) # gap 已经在 arm 中做了,没必要再乘 tail 147 | 148 | # deconv 上采样 149 | cx1 = self.deconv1(cx1) # tx2, torch 不支持 upsample 150 | cx2 = self.deconv2(cx2) 151 | 152 | res = self.last_conv(self.ffm(sp, cx1, cx2)) # 1/8 153 | 154 | # for onnx output or infer mode 155 | # return res 156 | 157 | # for teacher/student training 158 | res1, res2 = self.mid1(cx1), self.mid2(cx2) # 1/16, 1/16 159 | return [res, res1, res2, cx1, cx2] 160 | 161 | # 单模型可用 self.training 判断状态 162 | # if self.training: # 使用 nn.Module 自带属性判断 training/eval 状态 163 | # res1, res2 = self.mid1(cx1), self.mid2(cx2) # 1/16, 1/16 164 | # return [res, res1, res2] # 1/8, 1/16, 1/16 165 | # else: 166 | # return res 167 | 168 | 169 | @torch.no_grad() 170 | def cmp_infer_time(test_num=20): 171 | import time 172 | import itertools 173 | 174 | # 首个 resnet50 预热 GPU 175 | archs = ['resnet50', 'resnet18', 'resnet50', 'resnet101'] 176 | inplanes = [16, 32, 64] 177 | 178 | x = torch.rand(1, 3, 512, 512) 179 | x = x.cuda() 180 | 181 | for arch, inp in itertools.product(archs, inplanes): 
# 笛卡儿积 182 | model = BiSeNet(37, context_path=arch, in_planes=inp) 183 | model.cuda() 184 | model.eval() 185 | 186 | torch.cuda.synchronize() # 等待当前设备上所有流中的所有核心完成, CPU 等待 cuda 所有运算执行完才退出 187 | t1 = time.time() 188 | for _ in range(test_num): 189 | model(x) 190 | t2 = time.time() 191 | torch.cuda.synchronize() 192 | 193 | t = (t2 - t1) / test_num 194 | fps = 1 / t 195 | 196 | # print(f'{arch} - {inp} \t time: {t} \t fps: {fps}') 197 | print('{} - {} \t time: {} \t fps: {}'.format(arch, inp, t, fps)) 198 | 199 | 200 | if __name__ == '__main__': 201 | # model = BiSeNet(num_classes=37, context_path='resnet50', in_planes=16) 202 | # model.eval() 203 | # model.cuda() 204 | # x = torch.rand(2, 3, 512, 512).cuda() 205 | # res = model(x) 206 | # print(type(res)) 207 | # for r in res: 208 | # print(r.size()) 209 | 210 | cmp_infer_time() 211 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | import random 6 | import sys 7 | import pickle 8 | import json 9 | import constants 10 | 11 | 12 | def load_state_dict(model, ckpt_path, device): 13 | # 默认在 cuda:0 就容易报错 14 | state_dict = torch.load(ckpt_path, map_location=device)['state_dict'] 15 | model.load_state_dict(state_dict) 16 | print('load', ckpt_path) 17 | 18 | 19 | def print_model_parm_nums(model, string): 20 | b = [] 21 | for param in model.parameters(): 22 | b.append(param.numel()) 23 | print(string + ': Number of params: %.2f M' % (sum(b) / 1e6)) 24 | 25 | 26 | def generate_target_error_mask(pred, target, class_aware=False, num_classes=0): 27 | """ 28 | :param pred: H,W 29 | :param target: H,W 30 | :param class_aware: 31 | :param num_classes: use with class_aware 32 | :return: 33 | """ 34 | 35 | if isinstance(target, torch.Tensor): 36 | pred, target = to_numpy(pred), to_numpy(target) 37 | target_error_mask = (pred != target).astype('uint8') # 0,1 38 | target_error_mask[target == constants.BG_INDEX] = 0 39 | 40 | if class_aware: 41 | # 不受类别数量影响 42 | error_mask = target_error_mask == 1 43 | target_error_mask[~error_mask] = constants.BG_INDEX # bg 44 | 45 | for c in range(num_classes): # C 46 | cls_error = error_mask & (target == c) 47 | target_error_mask[cls_error] = c 48 | 49 | return target_error_mask 50 | 51 | 52 | def to_numpy(var, toint=False): 53 | # Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead. 
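    # note: squeeze() below drops every size-1 dim, so a (1, C, H, W) tensor
    # comes back as (C, H, W); callers that need the batch dim kept should
    # index or reshape instead.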
54 | if isinstance(var, torch.Tensor): 55 | var = var.squeeze().detach().cpu().numpy() 56 | if toint: 57 | var = var.astype('uint8') 58 | return var 59 | 60 | 61 | # pickle io 62 | def dump_pickle(data, out_path): 63 | with open(out_path, 'wb') as f: 64 | pickle.dump(data, f) 65 | print('write data to', out_path) 66 | 67 | 68 | def load_pickle(in_path): 69 | with open(in_path, 'rb') as f: 70 | data = pickle.load(f) # list 71 | return data 72 | 73 | 74 | # json io 75 | def dump_json(adict, out_path): 76 | with open(out_path, 'w', encoding='UTF-8') as json_file: 77 | # 设置缩进,格式化多行保存; ascii False 保存中文 78 | json_str = json.dumps(adict, indent=2, ensure_ascii=False) 79 | json_file.write(json_str) 80 | 81 | 82 | def load_json(in_path): 83 | with open(in_path, 'rb') as f: 84 | adict = json.load(f) 85 | return adict 86 | 87 | 88 | # io: txt <-> list 89 | def write_list_to_txt(a_list, txt_path): 90 | with open(txt_path, 'w') as f: 91 | for p in a_list: 92 | f.write(p + '\n') 93 | 94 | 95 | def read_txt_as_list(f): 96 | with open(f, 'r') as f: 97 | return [p.replace('\n', '') for p in f.readlines()] 98 | 99 | 100 | def approx_print(arr): 101 | arr = np.around(arr * 100, decimals=2) 102 | print(','.join(map(str, arr))) 103 | 104 | 105 | def recover_color_img(img): 106 | """ 107 | cvt tensor image to RGB [note: not BGR] 108 | """ 109 | if isinstance(img, torch.Tensor): 110 | img = img.detach().cpu().numpy().squeeze() 111 | 112 | img = np.transpose(img, axes=[1, 2, 0]) # h,w,c 113 | img = img * (0.229, 0.224, 0.225) + (0.485, 0.456, 0.406) # 直接通道相成? 114 | img = (img * 255).astype('uint8') 115 | return img 116 | 117 | 118 | def colormap(N=256, normalized=False): 119 | """ 120 | return color 121 | """ 122 | 123 | def bitget(byteval, idx): 124 | return (byteval & (1 << idx)) != 0 125 | 126 | dtype = 'float32' if normalized else 'uint8' 127 | cmap = np.zeros((N, 3), dtype=dtype) 128 | for i in range(N): 129 | r = g = b = 0 130 | c = i 131 | for j in range(8): 132 | r = r | (bitget(c, 0) << 7 - j) 133 | g = g | (bitget(c, 1) << 7 - j) 134 | b = b | (bitget(c, 2) << 7 - j) 135 | c = c >> 3 136 | 137 | cmap[i] = np.array([r, g, b]) 138 | 139 | cmap = cmap / 255 if normalized else cmap 140 | return cmap 141 | 142 | 143 | def mkdir(path): 144 | import shutil 145 | if os.path.exists(path): 146 | shutil.rmtree(path) 147 | os.makedirs(path) 148 | 149 | 150 | # dropout 151 | def turn_on_dropout(module): 152 | if type(module) == torch.nn.Dropout: 153 | module.train() 154 | 155 | 156 | def turn_off_dropout(module): 157 | if type(module) == torch.nn.Dropout: 158 | module.eval() 159 | 160 | 161 | # topk 162 | def get_topk_idxs(a, k): 163 | if isinstance(a, list): 164 | a = np.array(a) 165 | return a.argsort()[::-1][:k] 166 | 167 | 168 | def get_group_topk_idxs(scores, groups=5, select_num=10): 169 | total_num = len(scores) 170 | base = total_num // groups 171 | remain = total_num % groups 172 | per_select = select_num // groups 173 | if remain > groups / 2: 174 | base += 1 175 | per_select += 1 # 多组多选 176 | last_select = select_num - per_select * (groups - 1) 177 | 178 | begin_idxs = [0] + [base * (i + 1) for i in range(groups - 1)] + [total_num] 179 | total_idxs = list(range(total_num)) 180 | random.shuffle(total_idxs) 181 | 182 | select_idxs = [] 183 | for i in range(groups): 184 | begin, end = begin_idxs[i], begin_idxs[i + 1] 185 | group_rand_idxs = total_idxs[begin:end] 186 | group_scores = [scores[s] for s in group_rand_idxs] 187 | if i == groups - 1: # 最后一组 188 | per_select = last_select 189 | group_select_idxs = 
get_topk_idxs(group_scores, k=per_select).tolist() 190 | group_select_idxs = [group_rand_idxs[s] for s in group_select_idxs] # 转成全局 idx 191 | 192 | select_idxs += group_select_idxs 193 | 194 | return select_idxs 195 | 196 | 197 | def get_learning_rate(optimizer): 198 | for param_group in optimizer.param_groups: 199 | return param_group['lr'] 200 | 201 | 202 | def get_curtime(): 203 | current_time = time.strftime('%b%d_%H%M%S', time.localtime()) 204 | return current_time 205 | 206 | 207 | def max_normalize_to1(a): 208 | return a / (np.max(a) + 1e-12) 209 | 210 | 211 | def minmax_normalize(a): # min/max -> [0,1] 212 | min_a, max_a = np.min(a), np.max(a) 213 | return (a - min_a) / (max_a - min_a) 214 | 215 | 216 | def cvt_mask_to_score(mask, pixel_scores): # len(pixel_scores) = num_classes 217 | if isinstance(mask, torch.Tensor): 218 | mask = mask.detach().cpu().numpy() 219 | 220 | valid = mask != constants.BG_INDEX 221 | class_cnts = np.bincount(mask[valid], minlength=len(pixel_scores)) # 0-5 222 | diver_score = np.sum(pixel_scores * class_cnts) / class_cnts.sum() 223 | return diver_score 224 | 225 | 226 | class Logger: 227 | """logger""" 228 | 229 | def __init__(self, filename='default.log', stream=sys.stdout): 230 | self.terminal = stream 231 | self.log = open(filename, 'w', encoding='UTF-8') # 打开时自动清空文件 232 | 233 | def write(self, msg): 234 | self.terminal.write(msg) # 命令行打印 235 | self.log.write(msg) 236 | 237 | def flush(self): # 必有,不然 AttributeError: 'Logger' object has no attribute 'flush' 238 | pass 239 | 240 | def close(self): 241 | self.log.close() 242 | 243 | 244 | class AverageMeter: 245 | """Computes and stores the average and current value""" 246 | 247 | def __init__(self): 248 | self.reset() 249 | 250 | def reset(self): 251 | self.val = 0 252 | self.avg = 0 253 | self.sum = 0 254 | self.count = 0 255 | 256 | def update(self, val, n=1): 257 | self.val = val 258 | self.sum += val * n 259 | self.count += n 260 | self.avg = self.sum / self.count 261 | 262 | 263 | class AccCaches: 264 | """acc cache queue""" 265 | 266 | def __init__(self, patience): 267 | self.accs = [] # [(epoch, acc), ...] 
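        # note: AccCaches is a fixed-size FIFO over the last `patience`
        # (epoch, acc) pairs: add() evicts the oldest entry once the window
        # is full, and max_cache_acc() returns the best epoch inside it.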
268 | self.patience = patience 269 | 270 | def reset(self): 271 | self.accs = [] 272 | 273 | def add(self, epoch, acc): 274 | if len(self.accs) >= self.patience: # 先满足 = 275 | self.accs = self.accs[1:] # 队头出队列 276 | self.accs.append((epoch, acc)) # 队尾添加 277 | 278 | def full(self): 279 | return len(self.accs) == self.patience 280 | 281 | def max_cache_acc(self): 282 | max_id = int(np.argmax([t[1] for t in self.accs])) # t[1]=acc 283 | max_epoch, max_acc = self.accs[max_id] 284 | return max_epoch, max_acc 285 | -------------------------------------------------------------------------------- /utils/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import constants 3 | from model.sync_batchnorm.replicate import patch_replication_callback 4 | from utils.misc import get_learning_rate 5 | from utils.loss import SegmentationLosses 6 | from torch.utils.data import DataLoader 7 | import numpy as np 8 | from utils.metrics import Evaluator 9 | from utils.misc import AverageMeter 10 | from utils.lr_scheduler import LR_Scheduler 11 | from tqdm import tqdm 12 | import torch.nn.functional as F 13 | 14 | 15 | class Trainer: 16 | 17 | def __init__(self, args, model, train_set, val_set, test_set, class_weights, saver, writer): 18 | self.args = args 19 | self.saver = saver 20 | self.saver.save_experiment_config() # save cfgs 21 | self.writer = writer 22 | 23 | self.num_classes = train_set.num_classes 24 | 25 | # dataloaders 26 | kwargs = {'num_workers': args.workers, 'pin_memory': True} 27 | self.train_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 28 | self.val_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 29 | self.test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs) 30 | 31 | self.dataset_size = {'train': len(train_set), 'val': len(val_set), 'test': len(test_set)} 32 | print('dataset size:', self.dataset_size) 33 | 34 | # 加快训练,减少每轮迭代次数;不需要从引入样本时就截断数据,这样更好 35 | self.iters_per_epoch = args.iters_per_epoch if args.iters_per_epoch else len(self.train_dataloader) 36 | 37 | if args.optimizer == 'SGD': 38 | print('Using SGD') 39 | self.optimizer = torch.optim.SGD(model.parameters(), 40 | lr=args.lr, 41 | momentum=args.momentum, 42 | weight_decay=args.weight_decay, 43 | nesterov=args.nesterov) 44 | self.lr_scheduler = LR_Scheduler(mode=args.lr_scheduler, base_lr=args.lr, 45 | lr_step=args.lr_step, 46 | num_epochs=args.epochs, 47 | warmup_epochs=args.warmup_epochs, 48 | iters_per_epoch=self.iters_per_epoch) 49 | elif args.optimizer == 'Adam': 50 | print('Using Adam') 51 | self.optimizer = torch.optim.Adam(model.parameters(), 52 | lr=args.lr, 53 | # amsgrad=True, 54 | weight_decay=args.weight_decay) 55 | else: 56 | raise NotImplementedError 57 | 58 | self.device = torch.device(f'cuda:{args.gpu_ids}') 59 | 60 | if len(args.gpu_ids) > 1: 61 | args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] 62 | model = torch.nn.DataParallel(model, device_ids=args.gpu_ids) 63 | patch_replication_callback(model) 64 | print(args.gpu_ids) 65 | 66 | self.model = model.to(self.device) 67 | 68 | # loss 69 | if args.use_balanced_weights: 70 | weight = torch.from_numpy(class_weights.astype(np.float32)).to(self.device) 71 | else: 72 | weight = None 73 | 74 | self.criterion = SegmentationLosses(mode=args.loss_type, weight=weight, ignore_index=constants.BG_INDEX) 75 | 76 | # evaluator 77 | self.evaluator = Evaluator(self.num_classes) 78 | 79 | 
self.best_epoch = 0 80 | self.best_mIoU = 0.0 81 | self.best_pixelAcc = 0.0 82 | 83 | def training(self, epoch, prefix='Train', evaluation=False): 84 | self.model.train() 85 | if evaluation: 86 | self.evaluator.reset() 87 | 88 | train_losses = AverageMeter() 89 | tbar = tqdm(self.train_dataloader, desc='\r', total=self.iters_per_epoch) # 设置最多迭代次数, 从0开始.. 90 | 91 | if self.writer: 92 | self.writer.add_scalar(f'{prefix}/learning_rate', get_learning_rate(self.optimizer), epoch) 93 | 94 | for i, sample in enumerate(tbar): 95 | image, target = sample['img'], sample['target'] 96 | image, target = image.to(self.device), target.to(self.device) 97 | if self.args.optimizer == 'SGD': 98 | self.lr_scheduler(self.optimizer, i, epoch) # each iteration 99 | 100 | output = self.model(image) 101 | loss = self.criterion(output, target) # multiple output loss 102 | self.optimizer.zero_grad() 103 | loss.backward() 104 | self.optimizer.step() 105 | 106 | train_losses.update(loss.item()) 107 | tbar.set_description('Epoch {}, Train loss: {:.3f}'.format(epoch, train_losses.avg)) 108 | 109 | if evaluation: 110 | output = F.interpolate(output[-1], size=(target.size(1), target.size(2)), mode='bilinear', align_corners=True) 111 | pred = torch.argmax(output, dim=1) 112 | self.evaluator.add_batch(target.cpu().numpy(), pred.cpu().numpy()) # B,H,W 113 | 114 | # 即便 tqdm 有 total,仍然要这样跳出 115 | if i == self.iters_per_epoch - 1: 116 | break 117 | 118 | if self.writer: 119 | self.writer.add_scalar(f'{prefix}/loss', train_losses.val, epoch) 120 | if evaluation: 121 | Acc = self.evaluator.Pixel_Accuracy() 122 | mIoU = self.evaluator.Mean_Intersection_over_Union() 123 | print('Epoch: {}, Acc_pixel:{:.3f}, mIoU:{:.3f}'.format(epoch, Acc, mIoU)) 124 | 125 | self.writer.add_scalars(f'{prefix}/IoU', { 126 | 'mIoU': mIoU, 127 | # 'mDice': mDice, 128 | }, epoch) 129 | self.writer.add_scalars(f'{prefix}/Acc', { 130 | 'acc_pixel': Acc, 131 | # 'acc_class': Acc_class 132 | }, epoch) 133 | 134 | @torch.no_grad() 135 | def validation(self, epoch, test=False): 136 | self.model.eval() 137 | self.evaluator.reset() # reset confusion matrix 138 | 139 | if test: 140 | tbar = tqdm(self.test_dataloader, desc='\r') 141 | prefix = 'Test' 142 | else: 143 | tbar = tqdm(self.val_dataloader, desc='\r') 144 | prefix = 'Valid' 145 | 146 | # loss 147 | segment_losses = AverageMeter() 148 | 149 | for i, sample in enumerate(tbar): 150 | image, target = sample['img'], sample['target'] 151 | image, target = image.to(self.device), target.to(self.device) 152 | 153 | output = self.model(image)[0] # 拿到首个输出 154 | segment_loss = self.criterion(output, target) 155 | segment_losses.update(segment_loss.item()) 156 | tbar.set_description(f'{prefix} loss: %.4f' % segment_losses.avg) 157 | 158 | output = F.interpolate(output, size=(target.size()[1:]), mode='bilinear', align_corners=True) 159 | pred = torch.argmax(output, dim=1) # pred 160 | 161 | # eval: add batch result 162 | self.evaluator.add_batch(target.cpu().numpy(), pred.cpu().numpy()) # B,H,W 163 | 164 | Acc = self.evaluator.Pixel_Accuracy() 165 | # Acc_class = self.evaluator.Pixel_Accuracy_Class() 166 | mIoU = self.evaluator.Mean_Intersection_over_Union() 167 | # mDice = self.evaluator.Mean_Dice() 168 | print('Epoch: {}, Acc_pixel:{:.4f}, mIoU:{:.4f}'.format(epoch, Acc, mIoU)) 169 | 170 | if self.writer: 171 | self.writer.add_scalar(f'{prefix}/loss', segment_losses.avg, epoch) 172 | self.writer.add_scalars(f'{prefix}/IoU', { 173 | 'mIoU': mIoU, 174 | # 'mDice': mDice, 175 | }, epoch) 176 | 
self.writer.add_scalars(f'{prefix}/Acc', { 177 | 'acc_pixel': Acc, 178 | # 'acc_class': Acc_class 179 | }, epoch) 180 | 181 | if not test: 182 | if mIoU > self.best_mIoU: 183 | print('saving model...') 184 | self.best_mIoU = mIoU 185 | self.best_pixelAcc = Acc 186 | self.best_epoch = epoch 187 | 188 | state = { 189 | 'epoch': self.best_epoch, 190 | 'state_dict': self.model.state_dict(), # 方便 test 保持同样结构? 191 | 'optimizer': self.optimizer.state_dict(), 192 | 'best_mIoU': self.best_mIoU, 193 | 'best_pixelAcc': self.best_pixelAcc 194 | } 195 | self.saver.save_checkpoint(state) 196 | print('save model at epoch', epoch) 197 | 198 | return mIoU, Acc 199 | 200 | def load_best_checkpoint(self): 201 | checkpoint = self.saver.load_checkpoint() 202 | self.model.load_state_dict(checkpoint['state_dict']) 203 | self.optimizer.load_state_dict(checkpoint['optimizer']) 204 | print(f'=> loaded checkpoint - epoch {checkpoint["epoch"]}') 205 | return checkpoint["epoch"] 206 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import random 4 | import numpy as np 5 | import cv2 6 | import constants 7 | import PIL.ImageEnhance as ImageEnhance 8 | from PIL import Image, ImageOps, ImageFilter 9 | 10 | 11 | def mapbg(bg_idx): 12 | """ 13 | image bg 转成 constants.BG_INDEX, 类别从 [0,..,C-1] 14 | """ 15 | 16 | # bg 在首部,需要调整 实际类别 前移1位 17 | def map_headbg(target): 18 | target = target.astype(int) 19 | target -= 1 # 1->0 20 | target[target == -1] = constants.BG_INDEX # 255 21 | return target.astype('uint8') 22 | 23 | # bg 在尾部,直接替换为 constant 即可 24 | def map_other(target): 25 | target = target.astype(int) 26 | target[target == bg_idx] = constants.BG_INDEX 27 | return target.astype('uint8') 28 | 29 | if bg_idx == 0: 30 | return map_headbg 31 | else: 32 | return map_other 33 | 34 | 35 | def remap(bg_idx): 36 | """ 37 | 分割结果 -> 回归原始 bg idx,方面 vis 38 | """ 39 | 40 | def remap_headbg(target): 41 | target = target.astype(int) 42 | target += 1 43 | target[target == constants.BG_INDEX + 1] = bg_idx 44 | return target.astype('uint8') 45 | 46 | def remap_other(target): 47 | target = target.astype(int) 48 | target[target == constants.BG_INDEX] = bg_idx 49 | return target.astype('uint8') 50 | 51 | if bg_idx == 0: 52 | return remap_headbg 53 | else: 54 | return remap_other 55 | 56 | 57 | class Compose: # 可以采用 默认的 58 | def __init__(self, trans_list): 59 | self.trans_list = trans_list 60 | 61 | def __call__(self, sample): 62 | for t in self.trans_list: 63 | sample = t(sample) 64 | return sample 65 | 66 | def __repr__(self): 67 | format_string = self.__class__.__name__ + '(' 68 | 69 | for t in self.trans_list: 70 | format_string += '\n' 71 | format_string += ' {0}'.format(t) 72 | format_string += '\n)' 73 | 74 | return format_string 75 | 76 | 77 | class RandomHorizontalFlip: 78 | def __call__(self, sample): 79 | img, target = sample['img'], sample['target'] 80 | if random.random() < 0.5: 81 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 82 | target = target.transpose(Image.FLIP_LEFT_RIGHT) 83 | 84 | return {'img': img, 85 | 'target': target} 86 | 87 | 88 | class RandomRotate: 89 | def __init__(self, degree): # 旋角上限 90 | self.degree = degree 91 | 92 | def __call__(self, sample): 93 | img, target = sample['img'], sample['target'] 94 | 95 | rotate_degree = random.uniform(-1 * self.degree, self.degree) 96 | img = img.rotate(rotate_degree, 
Image.BILINEAR)
97 |         target = target.rotate(rotate_degree, Image.NEAREST)
98 | 
99 |         return {'img': img,  # key must be 'img' to match the other transforms in the pipeline
100 |                 'target': target}
101 | 
102 | 
103 | class RandomGaussianBlur:
104 |     def __call__(self, sample):
105 |         img, target = sample['img'], sample['target']
106 | 
107 |         if random.random() < 0.5:
108 |             img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
109 | 
110 |         return {'img': img,
111 |                 'target': target}
112 | 
113 | 
114 | class RandomScaleCrop:
115 |     def __init__(self, base_size, crop_size, scales=(0.75, 2.0), fill=0):  # fill with bg_idx
116 |         self.base_size = base_size  # a single value; no separate (h, w) needed
117 |         self.crop_size = crop_size
118 |         self.scales = scales
119 |         self.fill = fill
120 | 
121 |     def __call__(self, sample):
122 |         img, target = sample['img'], sample['target']
123 | 
124 |         # keep the original aspect ratio; scale according to the short side
125 |         short_size = random.randint(int(self.base_size * self.scales[0]),
126 |                                     int(self.base_size * self.scales[1]))
127 |         w, h = img.size
128 |         if h > w:
129 |             ow = short_size
130 |             oh = int(1.0 * h * ow / w)
131 |         else:
132 |             oh = short_size
133 |             ow = int(1.0 * w * oh / h)
134 | 
135 |         # random scale
136 |         img = img.resize((ow, oh), Image.BILINEAR)
137 |         target = target.resize((ow, oh), Image.NEAREST)
138 | 
139 |         # if the scaled short side is smaller than the crop size, pad the image
140 |         if short_size < self.crop_size:
141 |             padh = self.crop_size - oh if oh < self.crop_size else 0
142 |             padw = self.crop_size - ow if ow < self.crop_size else 0
143 |             img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)  # img filled with 0; normalize comes later
144 |             target = ImageOps.expand(target, border=(0, 0, padw, padh), fill=self.fill)  # target filled with bg_idx
145 | 
146 |         # random crop
147 |         w, h = img.size
148 |         x1 = random.randint(0, w - self.crop_size)
149 |         y1 = random.randint(0, h - self.crop_size)
150 |         img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
151 |         target = target.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
152 | 
153 |         return {'img': img,
154 |                 'target': target}
155 | 
156 | 
157 | class FixScaleCrop:
158 |     def __init__(self, crop_size):  # for validation: keep the aspect ratio, center-crop to the model size
159 |         self.crop_size = crop_size
160 | 
161 |     def __call__(self, sample):
162 |         img, target = sample['img'], sample['target']
163 | 
164 |         w, h = img.size
165 |         if w > h:
166 |             oh = self.crop_size
167 |             ow = int(1.0 * w * oh / h)  # keep the aspect ratio: short side = crop_size, scale the long side
168 |         else:
169 |             ow = self.crop_size
170 |             oh = int(1.0 * h * ow / w)
171 | 
172 |         img = img.resize((ow, oh), Image.BILINEAR)
173 |         target = target.resize((ow, oh), Image.NEAREST)
174 | 
175 |         w, h = img.size  # size after rescaling
176 |         x1 = int(round((w - self.crop_size) / 2.))
177 |         y1 = int(round((h - self.crop_size) / 2.))
178 |         img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
179 |         target = target.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
180 | 
181 |         return {'img': img,
182 |                 'target': target}
183 | 
184 | 
185 | class FixedResize:
186 |     def __init__(self, size):
187 |         self.size = (size, size)  # for test: resize directly to crop_size
188 | 
189 |     def __call__(self, sample):
190 |         img, target = sample['img'], sample['target']
191 |         assert img.size == target.size
192 | 
193 |         img = img.resize(self.size, Image.BILINEAR)
194 |         target = target.resize(self.size, Image.NEAREST)
195 | 
196 |         return {'img': img,
197 |                 'target': target}
198 | 
199 | 
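# note: a sketch of how these transforms are typically composed for training
# (the actual composition belongs to datasets/build_datasets.py; the mean/std
# below are the ImageNet stats this repo uses elsewhere, e.g. in utils/misc.py):
#
# train_transform = Compose([
#     RandomHorizontalFlip(),
#     RandomScaleCrop(base_size=512, crop_size=512, fill=constants.BG_INDEX),
#     RandomGaussianBlur(),
#     Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#     ToTensor(),
# ])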
200 | class ColorJitter:
201 |     def __init__(self, brightness=None, contrast=None, saturation=None):
202 |         # default each range to the identity [1, 1] so __call__ never reads an unset attribute
203 |         self.brightness = [max(1 - brightness, 0), 1 + brightness] if brightness else [1, 1]
204 |         self.contrast = [max(1 - contrast, 0), 1 + contrast] if contrast else [1, 1]
205 |         self.saturation = [max(1 - saturation, 0), 1 + saturation] if saturation else [1, 1]
206 | 
207 |     def __call__(self, sample):
208 |         img, target = sample['img'], sample['target']
209 | 
210 |         r_brightness = random.uniform(self.brightness[0], self.brightness[1])
211 |         r_contrast = random.uniform(self.contrast[0], self.contrast[1])
212 |         r_saturation = random.uniform(self.saturation[0], self.saturation[1])
213 | 
214 |         if not isinstance(img, Image.Image):
215 |             img = Image.fromarray(img)  # accept np uint8 input as well as PIL
216 |         img = ImageEnhance.Brightness(img).enhance(r_brightness)
217 |         img = ImageEnhance.Contrast(img).enhance(r_contrast)
218 |         img = ImageEnhance.Color(img).enhance(r_saturation)
219 | 
220 |         return {'img': img,
221 |                 'target': target}
222 | 
223 | 
224 | class Normalize:
225 |     """Normalize a tensor image with mean and standard deviation.
226 |     Args:
227 |         mean (tuple): means for each channel.
228 |         std (tuple): standard deviations for each channel.
229 |     """
230 | 
231 |     def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
232 |         self.mean = mean
233 |         self.std = std
234 | 
235 |     def __call__(self, sample):
236 |         img, target = sample['img'], sample['target']
237 | 
238 |         img = np.array(img).astype(np.float32)
239 |         target = np.array(target).astype(np.float32)
240 |         img /= 255.0
241 |         img -= self.mean
242 |         img /= self.std
243 | 
244 |         return {'img': img,
245 |                 'target': target}
246 | 
247 | 
248 | class ToTensor:
249 |     """Convert ndarrays in sample to Tensors."""
250 | 
251 |     def __call__(self, sample):
252 |         img, target = sample['img'], sample['target']
253 | 
254 |         img = np.array(img).astype(np.float32).transpose((2, 0, 1))
255 |         target = np.array(target).astype(np.float32)
256 | 
257 |         img = torch.from_numpy(img).float()
258 |         target = torch.from_numpy(target).float()
259 | 
260 |         return {'img': img,
261 |                 'target': target}
262 | 
263 | 
264 | class ToRGBD:
265 |     def __call__(self, sample):
266 |         img, depth, target = sample['img'], sample['depth'], sample['target']
267 |         rgbd = torch.cat((img, depth), 0)
268 |         return {'rgbd': rgbd,
269 |                 'target': target}
270 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BiSeNet-wali
2 | 
3 | Compress BiSeNet with Structure Knowledge Distillation for real-time image segmentation on wali-TX2.
4 | 
5 | This repo targets RGB images. The BiSeNet-RGBD architecture can be accessed [here](https://github.com/Shuai-Xie/Wali-turtlebot).
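A quick way to build the configurations referenced in the tables below, using only flags defined in `argument_parser.py` (a sketch; actual training runs go through `train.py` / `train_kd.py` and the scripts under `exps/`):

```py
from argument_parser import parse_args

# student config: BiSeNet with a resnet18 context path, CP in_planes=32
args = parse_args(['--context_path', 'resnet18', '--in_planes', '32',
                   '--dataset', 'SUNRGBD', '--batch-size', '4'])
print(args.context_path, args.in_planes)
```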
30 | arch | CP in_planes | valid_IoU | valid_Acc | test_IoU | test_Acc | CPU | GPU | TX2 | TRT
31 | :-:| :-:| :-:|:-:|:-:|:-: | :-: | :-: | :-: | :-:
32 | res18(up) | 64 | 47.23 | 81.77 | 35.07 | 74.36 | 5.3 | 63.5 | / | /
33 | res50 | 64 | 50.26 | 83.41 | 38.64 | 76.60 | 2.1 | 27.2 | / | /
34 | res101 | 64 | 52.00 | 84.73 | 41.04 | 78.02 | 1.8 | 18.0 | / | /
35 | res18(deconv) | 64 | | | | | 5.4 | 54.9 | 7.8
36 | res50 | 64 | | | | | 1.8 | 12.1 | 1.8
37 | res101 | 64 | **58.38** | **85.16** | 41.87 | 77.59 | 1.3 | 9.7 | 1.5
38 | res18 | 32 | 38.08 | 73.94 | **28.75** | **67.26** | 9.8 | 130.3 | **24.5** |
39 | res50 | 32 | | | | | 4.5 | 35.0 | 5.9
40 | res101 | 32 | | | | | 3.4 | 26.1 | 4.7
41 | res18 | 16 | | | | | 17.2 | 202.7 | 51.2
42 | res50 | 16 | 34.70 | 71.11 | 26.49 | 64.61 | 8.6 | 80.0 | 15.1
43 | res101 | 16 | | | | | 6.3 | 61.5 | 12.1
44 | 
45 | Note: only the first 3 rows use `upsample` for upsampling, because torch on TX2 does not support `upsample`; all later experiments use `deconv`. The CPU/GPU/TX2 columns are FPS (cf. the timing logs below).
46 | - **res18, inp=32: best IoU-speed trade-off, chosen as the student model**
47 | - **res101, inp=64: chosen as the teacher model**
48 | - training iterations = 40000
49 | 
50 | 
51 | ## Reducing SP to 64 chans degrades accuracy sharply [×]
52 | 
53 | test_IoU drops by 2+ points on average and valid_IoU by 4+ points; **this way of compressing the model is not usable**
54 | 
55 | arch |SP| CP in_planes | valid_IoU | valid_Acc | test_IoU | test_Acc | CPU | GPU | TX2 |
56 | :-:| :-:| :-:| :-:|:-:|:-:|:-: | :-: | :-: | :-:
57 | res18 | 128 | 32 | 38.08 | 73.94 | **28.75** | **67.26** | 9.8 | 130.3 | 24.5 |
58 | res18 | 64 | 32 | 33.30 | 70.63 | 26.00 | 64.35 | 9.8 | 292.8 | **33.0**
59 | res50 | 128 | 16 | 34.70 | 71.11 | **26.49** | 64.61 | 8.6 | 80.0 | 15.1
60 | res50 | 64 | 16 | 30.23 | 69.27 | 24.73 | 63.79 | 8.6 | 160.4 | **18.8**
61 | 
62 | ## After Distillation [√]
63 | - pi: pixel-wise; pa: pair-wise; ho: holistic
64 | 
65 | arch | in_planes | valid_IoU | valid_Acc | test_IoU | test_Acc | note
66 | :-:| :-: | :-:| :-:|:-:|:-:|-
67 | res101 | 64 | **58.38** | **85.16** | 41.87 | 77.59 | Teacher
68 | res18 | 32 | 38.08 | 73.94 | **28.75** | 67.26 | base
69 | res18 (pi) | 32 | 39.45 | 75.43 | 29.30 | 68.52 | lr=1e-4, pi=10
70 | res18 (pi) | 32 | 39.86 | 75.99 | **30.43** | 69.58 | lr=1e-3, pi=10
71 | res18 (pi) | 32 | 38.18 | 75.27 | 30.02 | 69.48 | lr=5e-3, pi=10
72 | res18 (pa) | 32 | 38.51 | 74.25 | 28.20 | 66.87 | lr=1e-3, pa=10
73 | res18 (pi+pa) | 32 | 39.76 | 75.90 | 30.07 | 69.28 | lr=1e-3, pi=10, pa=10
74 | 
75 | - +pi: with lr=1e-4 the pi_loss stays nearly flat; with 1e-3 the pi_loss drops visibly; with 5e-3 accuracy degrades slightly
76 | - this is not purely a finetuning process: learning from soft labels also needs an appropriate lr
77 | - training iterations = 10000
78 | - **pa_loss is a cosine distance and inherently small in magnitude; forcing its weight up to the scale of seg_loss makes every loss blow up**
79 | - pa_loss is essentially a similarity measure inside the feature map, akin to self-attention; **in these experiments it did not improve distillation** (see the sketch after this list for how the terms combine)
80 | 
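The distillation objective is assembled in `utils/trainer_kd.py` (included later in this dump). As a compact restatement of how the three terms combine (the function wrapper is hypothetical; the loss classes and output layout are the repo's own):

```py
def kd_objective(preds_S, preds_T, target, criterion, criterion_pi, criterion_pa,
                 lambda_pi=10, lambda_pa=10, use_pi=True, use_pa=False):
    # preds_*: [res, res1, res2, cx1, cx2] as in utils/trainer_kd.py
    loss = criterion(preds_S[:3], target)                            # supervised multi-output CE
    if use_pi:  # pixel-wise soft-label distillation on the score maps
        loss = loss + lambda_pi * criterion_pi(preds_S[:3], preds_T[:3])
    if use_pa:  # pair-wise similarity distillation on context-path features
        loss = loss + lambda_pa * criterion_pa(preds_S[3:], preds_T[3:])
    return loss
```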
81 | ## convert to ONNX
82 | 
83 | - [PyTorch upsampling -> ONNX](onnx/torch2onnx.md)
84 | 
85 | ## Inference Speed on Various Devices
86 | 
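The benchmarking script itself is not included in this dump; a minimal harness of roughly this shape (a sketch — the function name, warm-up count, and run count are assumptions) produces the `time: ... fps: ...` lines below:

```py
import time
import torch

@torch.no_grad()
def benchmark(model, device='cpu', size=512, runs=50):
    model = model.to(device).eval()
    x = torch.rand(1, 3, size, size, device=device)
    for _ in range(5):                # warm-up
        model(x)
    if device != 'cpu':
        torch.cuda.synchronize()      # GPU kernels run async; flush before timing
    t0 = time.time()
    for _ in range(runs):
        model(x)
    if device != 'cpu':
        torch.cuda.synchronize()
    t = (time.time() - t0) / runs
    print(f'time: {t} fps: {1 / t}')
```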
87 | CPU infer
88 | 
89 | ```py
90 | # arch - inplanes
91 | resnet18 - 16 time: 0.05347979068756103 fps: 18.698652091631914
92 | resnet18 - 32 time: 0.10354204177856445 fps: 9.65791269732352
93 | resnet18 - 64 time: 0.18792753219604490 fps: 5.321200083427934
94 | 
95 | resnet50 - 16 time: 0.11204280853271484 fps: 8.925160062441801
96 | resnet50 - 32 time: 0.22045552730560303 fps: 4.53606227170601
97 | resnet50 - 64 time: 0.47310781478881836 fps: 2.113683115647479
98 | 
99 | resnet101 - 16 time: 0.14643387794494628 fps: 6.829020811536269
100 | resnet101 - 32 time: 0.25973417758941650 fps: 3.8500901547919635
101 | resnet101 - 64 time: 0.57127020359039300 fps: 1.7504851359568034
102 | 
103 | # deconv: hardly any difference, slightly slower
104 | resnet18 - 16 time: 0.05818810462951660 fps: 17.1856431201359
105 | resnet18 - 32 time: 0.10182969570159912 fps: 9.820318062526589
106 | resnet18 - 64 time: 0.18533535003662110 fps: 5.395624740786937
107 | 
108 | resnet50 - 16 time: 0.11621274948120117 fps: 8.604907847583128
109 | resnet50 - 32 time: 0.22408280372619630 fps: 4.462636058507579
110 | resnet50 - 64 time: 0.55882019996643060 fps: 1.7894843458773892
111 | 
112 | resnet101 - 16 time: 0.15813012123107910 fps: 6.323905858129821
113 | resnet101 - 32 time: 0.29029526710510256 fps: 3.4447685281687557
114 | resnet101 - 64 time: 0.74392204284667970 fps: 1.3442268711025374
115 | ```
116 | 
117 | GPU infer
118 | 
119 | ```py
120 | # arch - inplanes
121 | resnet18 - 16 time: 0.004689335823059082 fps: 213.24981569514705
122 | resnet18 - 32 time: 0.007193267345428467 fps: 139.01888418418503
123 | resnet18 - 64 time: 0.015751779079437256 fps: 63.484892402117524
124 | 
125 | resnet50 - 16 time: 0.009638953208923340 fps: 103.74570540235032
126 | resnet50 - 32 time: 0.014894998073577880 fps: 67.13663171087562
127 | resnet50 - 64 time: 0.036808860301971430 fps: 27.167371980447907
128 | 
129 | resnet101 - 16 time: 0.012390828132629395 fps: 80.70485598671564
130 | resnet101 - 32 time: 0.026305425167083740 fps: 38.014971955340634
131 | resnet101 - 64 time: 0.055841803550720215 fps: 17.90773106194029
132 | 
133 | # gpu + deconv: slightly slower than direct upsample
134 | resnet18 - 16 time: 0.0049323558807373045 fps: 202.74287261091078
135 | resnet18 - 32 time: 0.0076752901077270510 fps: 130.28823483730682
136 | resnet18 - 64 time: 0.0182011842727661140 fps: 54.941479906682225
137 | 
138 | resnet50 - 16 time: 0.012502193450927734 fps: 79.98596437697853
139 | resnet50 - 32 time: 0.028601193428039552 fps: 34.96357599608543
140 | resnet50 - 64 time: 0.082448863983154290 fps: 12.128729878004348
141 | 
142 | resnet101 - 16 time: 0.016249334812164305 fps: 61.54098069610805
143 | resnet101 - 32 time: 0.038253283500671385 fps: 26.141546776826335
144 | resnet101 - 64 time: 0.102581226825714100 fps: 9.748372396627735
145 | 
146 | # cpu upsample: moving part of the ops onto the CPU makes it even slower
147 | resnet18 - 16 time: 0.03870357275009155 fps: 25.837407994786076
148 | resnet18 - 32 time: 0.07058500051498413 fps: 14.167315898619496
149 | resnet18 - 64 time: 0.14476600885391236 fps: 6.907698899187927
150 | 
151 | resnet50 - 16 time: 0.13139193058013915 fps: 7.610817464852421
152 | resnet50 - 32 time: 0.27268123626708984 fps: 3.667285705792038
153 | resnet50 - 64 time: 0.62760386466979980 fps: 1.5933617625604144
154 | 
155 | resnet101 - 16 time: 0.1334829092025757 fps: 7.491595785362939
156 | resnet101 - 32 time: 0.2804239988327026 fps: 3.5660286001291466
157 | resnet101 - 64 time: 0.6410346508026123 fps: 1.5599780741149365
158 | ```
159 | 
160 | TX2 infer
161 | 
162 | ```py
163 | # cpu
164 | resnet18 - 16 time: 0.4392608404159546 fps: 2.276551670422198
165 | resnet18 - 32 time: 1.0016891717910767 fps: 0.9983136766986744
166 | resnet18 - 64 time: 2.234201192855835 fps: 0.4475872643867693
167 | 
168 | resnet50 - 16 time: 0.8117639303207398 fps: 1.231885234916616
169 | resnet50 - 32 time: 2.240462875366211 fps: 0.4463363401353154
170 | resnet50 - 64 time: 5.18095223903656 fps: 0.19301471116938113
171 | 
172 | resnet101 - 16 time: 1.4539900541305542 fps: 0.6877626137532091
173 | resnet101 - 32 time: 3.2897871494293214 fps: 0.3039710335586513
174 | resnet101 - 64 time: 7.119290184974671 fps: 0.14046344144118603
175 | 
176 | # gpu + deconv
177 | resnet18 - 16 time: 0.01954698562622070 fps: 51.158783206888984
178 | resnet18 - 32 time: 0.04074519872665405 fps: 24.5427689949107  # 1
179 | resnet18 - 64 time: 0.1278276562690735 fps: 7.82303320883103
180 | 
181 | resnet50 - 16 time: 0.06625185012817383 fps: 15.093918102896065  # 2
182 | resnet50 - 32 time: 0.17058074474334717 fps: 5.862326381002631
183 | resnet50 - 64 time: 0.5573523283004761 fps: 1.7941972235933437
184 | 
185 | resnet101 - 16 time: 0.08274534940719605 fps: 12.085271343515938
186 | resnet101 - 32 time: 0.21335372924804688 fps: 4.687051890418992
187 | resnet101 - 64 time: 0.6515600681304932 fps: 1.5347779106065813
188 | 
189 | 
190 | # gpu + cpu upsample
191 | # sudo nvpmodel -m 0: already at the maximum-power mode
192 | resnet18 - 16 time: 0.10670669078826904 fps: 9.3714835743921
193 | resnet18 - 32 time: 0.23019177913665773 fps: 4.344203792813691
194 | resnet18 - 64 time: 0.4999691009521484 fps: 2.0001236038298877
195 | 
196 | resnet50 - 16 time: 0.39669969081878664 fps: 2.5207985363840435
197 | resnet50 - 32 time: 0.8320304155349731 fps: 1.20187913966706
198 | resnet50 - 64 time: 2.0053189754486085 fps: 0.4986737831951601
199 | 
200 | resnet101 - 16 time: 0.4652644872665405 fps: 2.1493151258439385
201 | resnet101 - 32 time: 0.8607745170593262 fps: 1.1617444292104644
202 | resnet101 - 64 time: 2.0529680490493774 fps: 0.48709964115761467
203 | ```
204 | 
205 | ## DeepLabv3+ on TX2
206 | 
207 | ```py
208 | mobilenet time: 0.2479145646095276 fps: 4.033647646216462
209 | resnet50 time: 0.3126360774040222 fps: 3.1986071738857307
210 | resnet101 time: 0.41015373468399047 fps: 2.438110190000991
211 | ```
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import constants
5 | 
6 | 
7 | def focal_loss_sigmoid(y_pred, labels, alpha=0.25, gamma=2):
8 |     """
9 |     :param y_pred: binary classification, output after sigmoid
10 |     :param labels: gt label
11 |     :param alpha: weight on negative examples; the smaller it is, the larger the weight given to positives
12 |     :param gamma: hard-easy modulation factor
13 |     :return:
14 |     """
15 |     labels = labels.float()
16 | 
17 |     # loss = positive (label 1) term + negative (label 0) term
18 |     # the (1 - y_pred) ** gamma factor modulates hard vs. easy examples: the higher y_pred, the easier the positive
19 |     loss = -labels * (1 - alpha) * ((1 - y_pred) ** gamma) * torch.log(y_pred) - \
20 |            (1 - labels) * alpha * (y_pred ** gamma) * torch.log(1 - y_pred)
21 | 
22 |     return loss
23 | 
24 | 
25 | def flatten(tensor):
26 |     """Flattens a given tensor such that the channel axis is first.
27 |     The shapes are transformed as follows:
28 |     (N, C, D, H, W) -> (C, N * D * H * W)  i.e. the C channel is brought out front
29 |     """
30 |     C = tensor.size(1)
31 |     # new axis order
32 |     axis_order = (1, 0) + tuple(range(2, tensor.dim()))  # works for >= 4-d tensors too
33 |     # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
34 |     transposed = tensor.permute(axis_order)
35 |     # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
36 |     return transposed.contiguous().view(C, -1)
37 | 
38 | 
39 | def onehot(tensor, num_class):
40 |     B, H, W = tensor.shape
41 |     y = torch.zeros(num_class, B, H, W)
42 |     if tensor.is_cuda:
43 |         y = y.cuda()
44 | 
45 |     for i in range(num_class):
46 |         y[i][tensor == i] = 1  # bg_idx is filtered out automatically (no class channel matches it)
47 |     return y.permute(1, 0, 2, 3)  # B,C,H,W
48 | 
49 | 
50 | class SegmentationLosses:
51 | 
52 |     def __init__(self, mode='ce', weight=None, batch_average=False, ignore_index=constants.BG_INDEX, cuda=True):
53 |         self.ignore_index = ignore_index
54 |         self.weight = weight
55 |         self.batch_average = batch_average
56 |         self.cuda = cuda
57 | 
58 |         self.losses = {
59 |             'ce': self.CELoss,
60 |             'bce': self.BCELoss,
61 |             'focal': self.FocalLoss,
62 |             'dice': self.DiceLoss,
63 |             'mce': self.MultiOutput_CELoss,
64 |         }
65 | 
66 |         if mode not in self.losses:
67 |             raise NotImplementedError
68 | 
69 |         self.loss_fn = self.losses[mode]
70 | 
71 |     def __call__(self, output, target):
72 |         return self.loss_fn(output, target)
73 | 
74 |     def MultiOutput_CELoss(self, outputs, target):
75 |         _, h, w = target.shape
76 |         if not isinstance(outputs, list):
77 |             outputs = [outputs]
78 | 
79 |         loss = 0.
80 |         for out in outputs:
81 |             out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True)
82 |             loss += self.CELoss(out, target)
83 | 
84 |         return loss
85 | 
86 |     def CELoss(self, output, target):
87 |         """
88 |         @param output: [B,C,H,W] raw model output; no softmax needed, CELoss applies it internally
89 |         @param target: [B,H,W]
90 |         """
91 |         return F.cross_entropy(output, target.long(),
92 |                                weight=self.weight, ignore_index=self.ignore_index, reduction='mean')
93 | 
94 |     def BCELoss(self, output, target):
95 |         """
96 |         @param output: [B,C,H,W] raw model output; no sigmoid needed, binary_cross_entropy_with_logits applies it
97 |         @param target: [B,H,W]
98 |         """
99 |         if len(target.shape) == 3:
100 |             target = onehot(target, num_class=output.size(1))
101 | 
102 |         loss = F.binary_cross_entropy_with_logits(output, target,
103 |                                                   weight=self.weight, reduction='mean')
104 |         return loss
105 | 
106 |     def FocalLoss(self, output, target, gamma=2, alpha=None):
107 |         """
108 |         @param output: [B,C,H,W] raw model output; no softmax needed, CELoss applies it internally
109 |         @param target: [B,H,W]
110 |         @param gamma: hard-easy regulatory factor, suppresses the contribution of easy examples
111 |         @param alpha: class-imbalance regulatory factor, weights the positive samples (CE only uses the positives)
112 |         """
113 |         logpt = -F.cross_entropy(output, target.long(),  # log(pt) = -CE(pt)
114 |                                  weight=self.weight, ignore_index=self.ignore_index, reduction='none')
115 |         pt = torch.exp(logpt)  # loss -> pt; loss = 0 means pt = 1
116 | 
117 |         if alpha is not None:
118 |             logpt *= alpha
119 |         loss = -((1 - pt) ** gamma) * logpt  # element-wise
120 | 
121 |         # Online Hard Example Mining: top x% losses (pixel-wise). Refer to http://www.robots.ox.ac.uk/~tvg/publications/2017/0026.pdf
122 |         # OHEM, _ = loss.topk(k=int(OHEM_percent * [*loss.shape][0]))
123 |         loss = loss.mean()
124 | 
125 |         return loss
126 | 
127 |     def DiceLoss(self, output, target):
128 |         """
129 |         @param output: [B,C,H,W] model output
130 |         @param target: [B,C,H,W] one-hot label
131 |         """
132 |         if len(target.shape) == 3:
133 |             target = onehot(target, num_class=output.size(1))
134 |         assert output.size() == target.size(), "'input' and 'target' must have the same shape"
135 |         output = F.softmax(output, dim=1)  # convert to probs
136 |         output, target = flatten(output), flatten(target)  # C,N
137 | 
138 |         # element-wise view: dice = p / (1 + p), a continuous relaxation that keeps the loss differentiable
139 |         # sum(-1) accumulates numerator/denominator separately over the pixels of each class
140 |         inter = (output * target).sum(-1)  # C; multiplying by target (= 1) picks out p of the gt class
141 |         union = (output + target).sum(-1)
142 |         dice = inter / union  # C, a per-class IoU-like score; many entries may be 0
143 |         dice = torch.mean(dice)  # mIoU-like mean over classes
144 | 
145 |         return 1. - dice
146 | 
147 | 
148 | class OHEM_CrossEntroy_Loss(nn.Module):
149 |     def __init__(self, threshold, keep_num):
150 |         super(OHEM_CrossEntroy_Loss, self).__init__()
151 |         self.threshold = threshold
152 |         self.keep_num = keep_num
153 |         self.loss_function = nn.CrossEntropyLoss(reduction='none')
154 | 
155 |     def forward(self, output, target):
156 |         loss = self.loss_function(output, target).view(-1)
157 |         loss, loss_index = torch.sort(loss, descending=True)
158 |         threshold_in_keep_num = loss[self.keep_num]
159 |         if threshold_in_keep_num > self.threshold:
160 |             loss = loss[loss > self.threshold]
161 |         else:
162 |             loss = loss[:self.keep_num]  # keep only the hardest examples for training
163 |         return torch.mean(loss)
164 | 
165 | 
166 | class PixelWise_Loss(nn.Module):
167 |     def __init__(self, weight=None, ignore_index=255):
168 |         super().__init__()  # the teacher effectively generates soft labels, so a class weight could still be applied
169 |         self.logsoftmax = nn.LogSoftmax(dim=0)
170 |         # todo: add a class weight to the distill loss
171 | 
172 |     def forward(self, preds_S, preds_T):
173 |         """
174 |         predict + middle supervision predicts
175 |         :param preds_S: [res, res1, res2]
176 |         :param preds_T: [res, res1, res2]
177 |         :return:
178 |         """
179 |         losses = 0.
180 |         for s, t in zip(preds_S, preds_T):  # B,C,[1/8, 1/16, 1/16]
181 |             t = t.detach()  # detach the teacher's inference result
182 |             B, C, H, W = t.shape
183 | 
184 |             # -p(x) * log(q(x))
185 |             softmax_t = F.softmax(flatten(t), dim=0)  # p(x); flatten returns C,B*H*W
186 |             logsoftmax_s = F.log_softmax(flatten(s), dim=0)  # log(q(x))
187 | 
188 |             loss = torch.sum(-softmax_t * logsoftmax_s) / H / W / B  # soft cross-entropy per pixel (KL up to the teacher's entropy)
189 |             losses += loss
190 | 
191 |         return losses
192 | 
193 | 
194 | class PairWise_Loss(nn.Module):
195 |     def __init__(self):
196 |         super().__init__()
197 |         self.criterion = sim_dis_compute  # MSE between the pixel-pair feature-similarity maps of student and teacher
198 | 
199 |     def forward(self, feats_S, feats_T):
200 |         """
201 |         context path middle features
202 |         :param feats_S: [cx1, cx2]
203 |         :param feats_T: [cx1, cx2]
204 |         :return:
205 |         """
206 |         losses = 0.
207 |         for s, t in zip(feats_S, feats_T):  # B,C,1/16
208 |             t = t.detach()  # context path feature
209 |             B, C, H, W = t.shape
210 |             # patch_h, patch_w = H // 2, W // 2  # would max_pool down to a 2x2 map
211 |             patch_h, patch_w = H // 4, W // 4  # controls the size of the pooled feature map
212 |             # todo: consider an even smaller pool size
213 |             maxpool = nn.MaxPool2d(kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w),  # kernel_size is (kH, kW); the original passed (patch_w, patch_h), which only worked on square maps
214 |                                    padding=0, ceil_mode=True)
215 |             loss = self.criterion(maxpool(s), maxpool(t))  # 4x4 pooled maps
216 |             losses += loss
217 |         return losses
218 | 
219 | 
220 | def L2(f_):
221 |     # Euclidean norm of each pixel's C-dim feature vector: sqrt of the sum of squares
222 |     return (((f_ ** 2).sum(dim=1)) ** 0.5).reshape(f_.shape[0], 1, f_.shape[2], f_.shape[3]) + 1e-8  # B,1,H,W
223 | 
224 | 
225 | def similarity(feat):
226 |     feat = feat.float()
227 | 
228 |     # L2-normalize (unit vectors, so cosine similarity reduces to a dot product)
229 |     tmp = L2(feat).detach()  # B,1,H,W
230 |     feat = feat / tmp
231 |     feat = feat.reshape(feat.shape[0], feat.shape[1], -1)  # B,C,H*W, one column per pixel
232 | 
233 |     # mc @ cn -> mn, [B,H*W,H*W]: self-attention-like similarity of pixel_i with every other pixel
234 |     return torch.einsum('icm,icn->imn', [feat, feat])  # [B,H*W,H*W]
235 | 
236 | 
237 | def sim_dis_compute(f_S, f_T):
238 |     # the mean is taken per pixel pair; after L2 normalization the cosine distances are tiny, orders of magnitude below pi-loss
239 |     sim_dis = torch.mean((similarity(f_T) - similarity(f_S)) ** 2)
240 |     return sim_dis
241 | 
242 | 
243 | if __name__ == "__main__":
244 |     # smoke test; the original commented-out target was a (1, 2, 2) tensor of ones, which does not match the output shape
245 |     a = torch.rand(1, 10, 32, 32)  # output logits [B,C,H,W]
246 |     b = torch.randint(0, 10, (1, 32, 32)).float()  # target class ids [B,H,W]
247 | 
248 |     print(SegmentationLosses(mode='ce', cuda=False)(a, b))
249 |     print(SegmentationLosses(mode='bce', cuda=False)(a, b))
250 |     print(SegmentationLosses(mode='focal', cuda=False)(a, b))
251 |     print(SegmentationLosses(mode='dice', cuda=False)(a, b))
--------------------------------------------------------------------------------
/model/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | from model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 | from torchvision.models.resnet import model_urls
6 | 
7 | 
8 | class BasicBlock(nn.Module):
9 |     expansion = 1
10 | 
11 |     def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
12 |         super(BasicBlock, self).__init__()
13 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,  # the stride is applied only on this first conv
14 |                                dilation=dilation, padding=dilation, bias=False)
15 |         self.bn1 = BatchNorm(planes)
16 |         self.relu = nn.ReLU(inplace=True)
17 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
18 |                                dilation=dilation, padding=dilation, bias=False)
19 |         self.bn2 = BatchNorm(planes)
20 |         self.downsample = downsample
21 |         self.stride = stride
22 | 
23 |     def forward(self, x):
24 |         residual = x
25 | 
26 |         out = self.conv1(x)
27 |         out = self.bn1(out)
28 |         out = self.relu(out)
29 | 
30 |         out = self.conv2(out)
31 |         out = self.bn2(out)
32 | 
33 |         if self.downsample is not None:
34 |             residual = self.downsample(x)
35 | 
36 |         out += residual
37 |         out = self.relu(out)
38 | 
39 |         return out
40 | 
41 | 
42 | class Bottleneck(nn.Module):
43 |     expansion = 4
44 | 
45 |     def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
46 |         super(Bottleneck, self).__init__()
47 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
48 |         self.bn1 = BatchNorm(planes)
49 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
50 |                                dilation=dilation, padding=dilation, bias=False)
51 |         self.bn2 = BatchNorm(planes)
52 |         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
53 |         self.bn3 = BatchNorm(planes * 4)
54 |         self.relu = nn.ReLU(inplace=True)
55 |         self.downsample = downsample
56 |         self.stride = stride
57 |         self.dilation = dilation
58 | 
59 |     def forward(self, x):
60 |         residual = x
61 | 
62 |         out = self.conv1(x)
63 |         out = self.bn1(out)
64 |         out = self.relu(out)
65 | 
66 |         out = self.conv2(out)
67 |         out = self.bn2(out)
68 |         out = self.relu(out)
69 | 
70 |         out = self.conv3(out)
71 |         out = self.bn3(out)
72 | 
73 |         if self.downsample is not None:
74 |             residual = self.downsample(x)
75 | 
76 |         out += residual
77 |         out = self.relu(out)
78 | 
79 |         return out
80 | 
81 | 
82 | class ResNet(nn.Module):
83 | 
84 |     def __init__(self, block, inplanes, layers, output_stride, BatchNorm, pretrained=True):
85 |         super(ResNet, self).__init__()  # inplanes defaults to 64; lower it to get a smaller network
86 |         blocks = [1, 2, 4]  # multi grids
87 |         self.inplanes = inplanes
88 | 
89 |         # before the layers: output at 1/4
90 |         # layer3 stride 2/1 for the different output strides
91 |         if output_stride == 16:
92 |             strides = [1, 2, 2, 1]
93 |             dilations = [1, 1, 1, 2]  # 1/4, 1/8, 1/16, 1/16
94 |         elif output_stride == 8:  # one fewer stride-2; layers 3-4 double the dilation
95 |             strides = [1, 2, 1, 1]
96 |             dilations = [1, 1, 2, 4]
97 |         else:
98 |             raise NotImplementedError
99 | 
100 |         # Modules
101 |         self.conv1 = nn.Conv2d(3, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
102 |         self.bn1 = BatchNorm(inplanes)
103 |         self.relu = nn.ReLU(inplace=True)
104 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 1/4
105 | 
106 |         self.layer1 = self._make_layer(block, inplanes, layers[0], stride=strides[0], dilation=dilations[0],  # 1/4
107 |                                        BatchNorm=BatchNorm)
108 |         self.layer2 = self._make_layer(block, inplanes * 2, layers[1], stride=strides[1], dilation=dilations[1],  # 1/8
109 |                                        BatchNorm=BatchNorm)
110 |         self.layer3 = self._make_layer(block, inplanes * 4, layers[2], stride=strides[2], dilation=dilations[2],  # 1/16
111 |                                        BatchNorm=BatchNorm)
112 |         # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3],
113 |         #                                BatchNorm=BatchNorm)
114 |         self.layer4 = self._make_MG_unit(block, inplanes * 8, blocks=blocks, stride=strides[3], dilation=dilations[3],  # 1,2
115 |                                          BatchNorm=BatchNorm)
116 | 
117 |         self._init_weight()
118 | 
119 |         if pretrained:
120 |             self._load_pretrained_model(layers)
121 | 
122 |     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
123 |         """
124 |         :param block: BasicBlock, Bottleneck
125 |         :param planes: features num = planes * block.expansion
126 |         :param blocks: block repeat times
127 |         :param stride: 1st conv's stride of current layer
128 |         :param dilation:
129 |         :param BatchNorm:
130 |         :return:
131 |         """
132 |         # at the junction between layers: does the first residual connection need a downsample?
133 |         downsample = None
134 |         if stride != 1 or self.inplanes != planes * block.expansion:
135 |             downsample = nn.Sequential(
136 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
137 |                           kernel_size=1, stride=stride, bias=False),
138 |                 BatchNorm(planes * block.expansion),
139 |             )
140 | 
141 |         layers = []
142 |         # first block
143 |         layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
144 |         # inner blocks
145 |         self.inplanes = planes * block.expansion
146 |         for i in range(1, blocks):
147 |             layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
148 | 
149 |         return nn.Sequential(*layers)
150 | 
151 |     def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
152 |         """
153 |         Cascaded-dilation (multi-grid) unit, following DeepLabv3+: stays at 1/16 but collects features with a larger receptive field.
154 |         blocks: [1, 2, 4]
155 |         stride=1, dilation=2
156 |         layer dilations: 2, 4, 8
157 |         """
158 |         downsample = None
159 |         if stride != 1 or self.inplanes != planes * block.expansion:
160 |             downsample = nn.Sequential(
161 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
162 |                           kernel_size=1, stride=stride, bias=False),
163 |                 BatchNorm(planes * block.expansion),
164 |             )
165 |         layers = []
166 |         layers.append(block(self.inplanes, planes, stride, dilation=blocks[0] * dilation,
167 |                             downsample=downsample, BatchNorm=BatchNorm))
168 |         self.inplanes = planes * block.expansion
169 |         for i in range(1, len(blocks)):
170 |             layers.append(block(self.inplanes, planes, stride=1,
171 |                                 dilation=blocks[i] * dilation, BatchNorm=BatchNorm))
172 | 
173 |         return nn.Sequential(*layers)
174 | 
175 |     def forward(self, input):
176 |         x = self.conv1(input)
177 |         x = self.bn1(x)
178 |         x = self.relu(x)
179 |         x = self.maxpool(x)
180 | 
181 |         x = self.layer1(x)  # 1/4
182 |         x = self.layer2(x)  # 1/8
183 |         x3 = self.layer3(x)  # 1/16
184 |         x4 = self.layer4(x3)  # 1/16
185 |         return x3, x4
186 | 
187 |     def _init_weight(self):
188 |         for m in self.modules():
189 |             if isinstance(m, nn.Conv2d):
190 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
191 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
192 |             elif isinstance(m, SynchronizedBatchNorm2d):
193 |                 m.weight.data.fill_(1)
194 |                 m.bias.data.zero_()
195 |             elif isinstance(m, nn.BatchNorm2d):
196 |                 m.weight.data.fill_(1)
197 |                 m.bias.data.zero_()
198 | 
199 |     def _load_pretrained_model(self, layers):
200 |         if layers == [3, 4, 23, 3]:
201 |             pretrain_dict = model_zoo.load_url(model_urls['resnet101'])
202 |         elif layers == [3, 4, 6, 3]:
203 |             pretrain_dict = model_zoo.load_url(model_urls['resnet50'])
204 |         elif layers == [2, 2, 2, 2]:
205 |             pretrain_dict = model_zoo.load_url(model_urls['resnet18'])
206 |         else:
207 |             raise NotImplementedError
208 |         model_dict = {}
209 |         state_dict = self.state_dict()
210 |         for k, v in pretrain_dict.items():
211 |             # match on both parameter name and parameter size
212 |             if k in state_dict and v.size() == state_dict[k].size():
213 |                 model_dict[k] = v
214 |         state_dict.update(model_dict)
215 |         self.load_state_dict(state_dict)
216 | 
217 | 
218 | def build_contextpath(model, inplanes=64, output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=True):
219 |     if model == 'resnet18':
220 |         return ResNet(BasicBlock, inplanes, [2, 2, 2, 2], output_stride, BatchNorm, pretrained=pretrained)
221 |     elif model == 'resnet50':
222 |         return ResNet(Bottleneck, inplanes, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained)
223 |     elif model == 'resnet101':
224 |         return ResNet(Bottleneck, inplanes, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
225 | 
226 | 
227 | if __name__ == "__main__":
228 |     import torch
229 | 
230 |     model = build_contextpath('resnet50', inplanes=32, pretrained=True)
231 |     input = torch.rand(1, 3, 512, 512)
232 |     feature3, feature4 = model(input)
233 |     print(feature3.size())  # 1/16 (layer3); the old comment wrongly said 1/8
234 |     print(feature4.size())  # 1/16 (layer4, multi-grid); the old comment wrongly said 1/4
235 | 
--------------------------------------------------------------------------------
/utils/trainer_kd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import constants
3 | from model.sync_batchnorm.replicate import patch_replication_callback
4 | from utils.misc import get_learning_rate
5 | from utils.loss import SegmentationLosses, PairWise_Loss, PixelWise_Loss
6 | from torch.utils.data import DataLoader
7 | import numpy as np
8 | from utils.metrics import Evaluator
9 | from utils.misc import AverageMeter
10 | from utils.lr_scheduler import LR_Scheduler
11 | from tqdm import tqdm
12 | import torch.nn.functional as F
13 | 
14 | 
15 | class Trainer:
16 | 
17 |     def __init__(self, args, student, teacher, train_set, val_set, test_set, class_weights, saver, writer):
18 |         self.args = args
19 |         self.saver = saver
20 |         self.saver.save_experiment_config()  # save cfgs
21 |         self.writer = writer
22 | 
23 |         self.num_classes = train_set.num_classes
24 | 
25 |         # dataloaders
26 |         kwargs = {'num_workers': args.workers, 'pin_memory': True}
27 |         self.train_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
28 |         self.val_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
29 |         self.test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)
30 | 
31 |         self.dataset_size = {'train': len(train_set), 'val': len(val_set), 'test': len(test_set)}
32 |         print('dataset size:', self.dataset_size)
33 | 
34 |         # speeds up training by capping the iterations per epoch; better than truncating the data at sampling time
35 |         self.iters_per_epoch = args.iters_per_epoch if args.iters_per_epoch else len(self.train_dataloader)
36 | 
37 |         self.device = torch.device(f'cuda:{args.gpu_ids}')
38 | 
39 |         # todo: could the student and teacher live on 2 separate devices?
40 |         self.student = student.to(self.device)
41 |         self.teacher = teacher.to(self.device).eval()  # only used to generate soft training targets
42 | 
43 |         # student is the generator
44 |         self.G_optimizer = torch.optim.SGD([{
45 |             'params': filter(lambda p: p.requires_grad, self.student.parameters()),
46 |             'initial_lr': args.lr_g
47 |         }], args.lr_g, momentum=args.momentum, weight_decay=args.weight_decay)
48 |         self.G_lr_scheduler = LR_Scheduler(mode=args.lr_scheduler, base_lr=args.lr_g,
49 |                                            num_epochs=args.epochs, iters_per_epoch=self.iters_per_epoch)
50 | 
51 |         # todo: discriminator
52 |         # self.D_solver = optim.SGD([{
53 |         #     'params': filter(lambda p: p.requires_grad, D_model.parameters()),
54 |         #     'initial_lr': args.lr_d
55 |         # }], args.lr_d, momentum=args.momentum, weight_decay=args.weight_decay)
56 | 
57 |         # loss
58 |         if args.use_balanced_weights:
59 |             weight = torch.from_numpy(class_weights.astype(np.float32)).to(self.device)
60 |         else:
61 |             weight = None
62 | 
63 |         # original segmentation loss
64 |         self.criterion = SegmentationLosses(mode=args.loss_type, weight=weight, ignore_index=constants.BG_INDEX)
65 |         self.criterion_pi = PixelWise_Loss(weight=weight, ignore_index=constants.BG_INDEX)
66 |         self.criterion_pa = PairWise_Loss()
67 | 
68 |         # evaluator
69 |         self.evaluator = Evaluator(self.num_classes)
70 | 
71 |         self.best_epoch = 0
72 |         self.best_mIoU = 0.0
73 |         self.best_pixelAcc = 0.0
74 | 
75 |     def training(self, epoch, prefix='Train', evaluation=False):
76 |         self.student.train()
77 |         if evaluation:
78 |             self.evaluator.reset()
79 | 
80 |         train_losses = AverageMeter()
81 |         segment_losses = AverageMeter()
82 |         pi_losses, pa_losses = AverageMeter(), AverageMeter()
83 | 
84 |         tbar = tqdm(self.train_dataloader, desc='\r', total=self.iters_per_epoch)  # caps the iteration count; enumeration starts at 0
85 | 
86 |         if self.writer:
87 |             self.writer.add_scalar(f'{prefix}/learning_rate', get_learning_rate(self.G_optimizer), epoch)
88 | 
89 |         for i, sample in enumerate(tbar):
90 |             image, target = sample['img'], sample['target']
91 |             image, target = image.to(self.device), target.to(self.device)
92 | 
93 |             # adjust lr
94 |             self.G_lr_scheduler(self.G_optimizer, i, epoch)
95 | 
96 |             # forward
97 |             with torch.no_grad():
98 |                 preds_T = self.teacher(image)  # [res, res1, res2, cx1, cx2]
99 |             preds_S = self.student(image)
100 | 
101 |             # segmentation loss
102 |             G_loss = self.criterion(preds_S[:3], target)  # multiple output loss
103 |             segment_losses.update(G_loss.item())
104 | 
105 |             # distillation losses
106 |             if self.args.pi:  # pixel-wise loss
107 |                 loss = self.args.lambda_pi * self.criterion_pi(preds_S[:3], preds_T[:3])
108 |                 G_loss += loss
109 |                 pi_losses.update(loss.item())
110 | 
111 |             if self.args.pa:  # pair-wise loss
112 |                 loss = self.args.lambda_pa * self.criterion_pa(preds_S[3:], preds_T[3:])
113 |                 G_loss += loss
114 |                 pa_losses.update(loss.item())
115 | 
116 |             self.G_optimizer.zero_grad()
117 |             G_loss.backward()
118 |             self.G_optimizer.step()
119 | 
120 |             train_losses.update(G_loss.item())
121 |             tbar.set_description('Epoch {}, Train loss: {:.3} = seg {:.3f} + pi {:.3f} + pa {:.10f}'.format(
122 |                 epoch, train_losses.avg, segment_losses.avg, pi_losses.avg, pa_losses.avg))
123 | 
124 |             if evaluation:
125 |                 output = F.interpolate(preds_S[0], size=(target.size(1), target.size(2)), mode='bilinear', align_corners=True)
126 |                 pred = torch.argmax(output, dim=1)
127 |                 self.evaluator.add_batch(target.cpu().numpy(), pred.cpu().numpy())  # B,H,W
128 | 
129 |             # even with total passed to tqdm, we still have to break out manually
130 |             if i == self.iters_per_epoch - 1:
131 |                 break
132 | 
133 |         if self.writer:
134 |             self.writer.add_scalars(f'{prefix}/loss', {
135 |                 'train': train_losses.avg,
136 |                 'segment': segment_losses.avg,
137 |                 'pi': pi_losses.avg,
138 |                 'pa': pa_losses.avg
139 |             }, epoch)
140 |             if evaluation:
141 |                 Acc = self.evaluator.Pixel_Accuracy()
142 |                 mIoU = self.evaluator.Mean_Intersection_over_Union()
143 |                 print('Epoch: {}, Acc_pixel:{:.3f}, mIoU:{:.3f}'.format(epoch, Acc, mIoU))
144 | 
145 |                 self.writer.add_scalars(f'{prefix}/IoU', {
146 |                     'mIoU': mIoU,
147 |                     # 'mDice': mDice,
148 |                 }, epoch)
149 |                 self.writer.add_scalars(f'{prefix}/Acc', {
150 |                     'acc_pixel': Acc,
151 |                     # 'acc_class': Acc_class
152 |                 }, epoch)
153 | 
154 |     @torch.no_grad()
155 |     def validation(self, epoch, test=False):
156 |         self.student.eval()
157 |         self.evaluator.reset()  # reset confusion matrix
158 | 
159 |         if test:
160 |             tbar = tqdm(self.test_dataloader, desc='\r')
161 |             prefix = 'Test'
162 |         else:
163 |             tbar = tqdm(self.val_dataloader, desc='\r')
164 |             prefix = 'Valid'
165 | 
166 |         # loss
167 |         segment_losses = AverageMeter()
168 | 
169 |         for i, sample in enumerate(tbar):
170 |             image, target = sample['img'], sample['target']
171 |             image, target = image.to(self.device), target.to(self.device)
172 | 
173 |             output = self.student(image)[0]  # take the first output
174 |             segment_loss = self.criterion(output, target)
175 |             segment_losses.update(segment_loss.item())
176 |             tbar.set_description(f'{prefix} loss: %.4f' % segment_losses.avg)
177 | 
178 |             output = F.interpolate(output, size=(target.size()[1:]), mode='bilinear', align_corners=True)
179 |             pred = torch.argmax(output, dim=1)  # pred
180 | 
181 |             # eval: add batch result
182 |             self.evaluator.add_batch(target.cpu().numpy(), pred.cpu().numpy())  # B,H,W
183 | 
184 |         Acc = self.evaluator.Pixel_Accuracy()
185 |         # Acc_class = self.evaluator.Pixel_Accuracy_Class()
186 |         mIoU = self.evaluator.Mean_Intersection_over_Union()
187 |         # mDice = self.evaluator.Mean_Dice()
188 |         print('Epoch: {}, Acc_pixel: {:.4f}, mIoU: {:.4f}'.format(epoch, Acc, mIoU))
189 | 
190 |         if self.writer:
191 |             self.writer.add_scalar(f'{prefix}/loss', segment_losses.avg, epoch)
192 |             self.writer.add_scalars(f'{prefix}/IoU', {
193 |                 'mIoU': mIoU,
194 |                 # 'mDice': mDice,
195 |             }, epoch)
196 |             self.writer.add_scalars(f'{prefix}/Acc', {
197 |                 'acc_pixel': Acc,
198 |                 # 'acc_class': Acc_class
199 |             }, epoch)
200 | 
201 |         if not test:
202 |             if mIoU > self.best_mIoU:
203 |                 print('saving model...')
204 |                 self.best_mIoU = mIoU
205 |                 self.best_pixelAcc = Acc
206 |                 self.best_epoch = epoch
207 | 
208 |                 state = {
209 |                     'epoch': self.best_epoch,
210 |                     'state_dict': self.student.state_dict(),  # same structure keeps test-time loading simple
211 |                     'optimizer': self.G_optimizer.state_dict(),
212 |                     'best_mIoU': self.best_mIoU,
213 |                     'best_pixelAcc': self.best_pixelAcc
214 |                 }
215 |                 self.saver.save_checkpoint(state)
216 |                 print('save model at epoch', epoch)
217 | 
218 |         return mIoU, Acc
219 | 
220 |     def load_best_checkpoint(self):
221 |         checkpoint = self.saver.load_checkpoint()
222 |         self.student.load_state_dict(checkpoint['state_dict'])
223 |         # self.G_optimizer.load_state_dict(checkpoint['optimizer'])
224 |         print(f'=> loaded checkpoint - epoch {checkpoint["epoch"]}')
225 |         return checkpoint["epoch"]
226 | 
--------------------------------------------------------------------------------
/model/backbone/xception.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 | from model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
7 | 
8 | 
9 | def fixed_padding(inputs, kernel_size, dilation):
10 |     kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
11 |     pad_total = kernel_size_effective - 1
12 |     pad_beg = pad_total // 2
13 |     pad_end = pad_total - pad_beg
14 |     padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
15 |     return padded_inputs
16 | 
17 | 
18 | class SeparableConv2d(nn.Module):
19 | 
20 |     def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None):
21 |         super(SeparableConv2d, self).__init__()
22 | 
23 |         self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
24 |                                groups=inplanes, bias=bias)
25 |         self.bn = BatchNorm(inplanes)
26 |         self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
27 | 
28 |     def forward(self, x):
29 |         x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
30 |         x = self.conv1(x)
31 |         x = self.bn(x)
32 |         x = self.pointwise(x)
33 |         return x
34 | 
35 | 
36 | class Block(nn.Module):
37 | 
38 |     def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None,
39 |                  start_with_relu=True, grow_first=True, is_last=False):
40 |         super(Block, self).__init__()
41 | 
42 |         if planes != inplanes or stride != 1:
43 |             self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
44 |             self.skipbn = BatchNorm(planes)
45 |         else:
46 |             self.skip = None
47 | 
48 |         self.relu = nn.ReLU(inplace=True)
49 |         rep = []
50 | 
51 |         filters = inplanes
52 |         if grow_first:
53 |             rep.append(self.relu)
54 |             rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
55 |             rep.append(BatchNorm(planes))
56 |             filters = planes
57 | 
58 |         for i in range(reps - 1):
59 |             rep.append(self.relu)
60 |             rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm))
61 |             rep.append(BatchNorm(filters))
62 | 
63 |         if not grow_first:
64 |             rep.append(self.relu)
65 |             rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
66 |             rep.append(BatchNorm(planes))
67 | 
68 |         if stride != 1:
69 |             rep.append(self.relu)
70 |             rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm))
71 |             rep.append(BatchNorm(planes))
72 | 
73 |         if stride == 1 and is_last:
74 |             rep.append(self.relu)
75 |             rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm))
76 |             rep.append(BatchNorm(planes))
77 | 
78 |         if not start_with_relu:
79 |             rep = rep[1:]
80 | 
81 |         self.rep = nn.Sequential(*rep)
82 | 
83 |     def forward(self, inp):
84 |         x = self.rep(inp)
85 | 
86 |         if self.skip is not None:
87 |             skip = self.skip(inp)
88 |             skip = self.skipbn(skip)
89 |         else:
90 |             skip = inp
91 | 
92 |         x = x + skip
93 | 
94 |         return x
95 | 
96 | 
97 | class AlignedXception(nn.Module):
98 |     """
99 |     Modified Aligned Xception
100 |     """
101 | 
102 |     def __init__(self, output_stride, BatchNorm,
103 |                  pretrained=True):
104 |         super(AlignedXception, self).__init__()
105 | 
106 |         if output_stride == 16:
107 |             entry_block3_stride = 2
108 |             middle_block_dilation = 1
109 |             exit_block_dilations = (1, 2)
110 |         elif output_stride == 8:
111 |             entry_block3_stride = 1
112 |             middle_block_dilation = 2
113 |             exit_block_dilations = (2, 4)
114 |         else:
115 |             raise NotImplementedError
116 | 
117 |         # Entry flow
118 |         self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
119 |         self.bn1 = BatchNorm(32)
120 |         self.relu = nn.ReLU(inplace=True)
121 | 
122 |         self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
123 |         self.bn2 = BatchNorm(64)
124 | 
125 |         self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
126 |         self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
127 |                             grow_first=True)
128 |         self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
129 |                             start_with_relu=True, grow_first=True, is_last=True)
130 | 
131 |         # Middle flow
132 |         self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
133 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
134 |         self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
135 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
136 |         self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
137 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
138 |         self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
139 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
140 |         self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
141 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
142 |         self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
143 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
144 |         self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
145 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
146 |         self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
147 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
148 |         self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
149 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
150 |         self.block13 
= Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 151 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 152 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 153 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 154 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 155 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 156 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 157 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 158 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 159 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 160 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 161 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 162 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 163 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 164 | 165 | # Exit flow 166 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 167 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 168 | 169 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 170 | self.bn3 = BatchNorm(1536) 171 | 172 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 173 | self.bn4 = BatchNorm(1536) 174 | 175 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 176 | self.bn5 = BatchNorm(2048) 177 | 178 | # Init weights 179 | self._init_weight() 180 | 181 | # Load pretrained model 182 | if pretrained: 183 | self._load_pretrained_model() 184 | 185 | def forward(self, x): 186 | # Entry flow 187 | x = self.conv1(x) 188 | x = self.bn1(x) 189 | x = self.relu(x) 190 | 191 | x = self.conv2(x) 192 | x = self.bn2(x) 193 | x = self.relu(x) 194 | 195 | x = self.block1(x) 196 | # add relu here 197 | x = self.relu(x) 198 | low_level_feat = x 199 | x = self.block2(x) 200 | x = self.block3(x) 201 | 202 | # Middle flow 203 | x = self.block4(x) 204 | x = self.block5(x) 205 | x = self.block6(x) 206 | x = self.block7(x) 207 | x = self.block8(x) 208 | x = self.block9(x) 209 | x = self.block10(x) 210 | x = self.block11(x) 211 | x = self.block12(x) 212 | x = self.block13(x) 213 | x = self.block14(x) 214 | x = self.block15(x) 215 | x = self.block16(x) 216 | x = self.block17(x) 217 | x = self.block18(x) 218 | x = self.block19(x) 219 | 220 | # Exit flow 221 | x = self.block20(x) 222 | x = self.relu(x) 223 | x = self.conv3(x) 224 | x = self.bn3(x) 225 | x = self.relu(x) 226 | 227 | x = self.conv4(x) 228 | x = self.bn4(x) 229 | x = self.relu(x) 230 | 231 | x = self.conv5(x) 232 | x = self.bn5(x) 233 | x = self.relu(x) 234 | 235 | return x, low_level_feat 236 | 237 | def _init_weight(self): 238 | for m in self.modules(): 239 | if isinstance(m, nn.Conv2d): 240 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 241 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
242 |             elif isinstance(m, SynchronizedBatchNorm2d):
243 |                 m.weight.data.fill_(1)
244 |                 m.bias.data.zero_()
245 |             elif isinstance(m, nn.BatchNorm2d):
246 |                 m.weight.data.fill_(1)
247 |                 m.bias.data.zero_()
248 | 
249 |     def _load_pretrained_model(self):
250 |         pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
251 |         model_dict = {}
252 |         state_dict = self.state_dict()
253 | 
254 |         for k, v in pretrain_dict.items():
255 |             if k in state_dict:  # was `if k in model_dict:`, which is always False here (model_dict is empty), so no pretrained weight was ever loaded
256 |                 if 'pointwise' in k:
257 |                     v = v.unsqueeze(-1).unsqueeze(-1)
258 |                 if k.startswith('block11'):
259 |                     model_dict[k] = v
260 |                     model_dict[k.replace('block11', 'block12')] = v
261 |                     model_dict[k.replace('block11', 'block13')] = v
262 |                     model_dict[k.replace('block11', 'block14')] = v
263 |                     model_dict[k.replace('block11', 'block15')] = v
264 |                     model_dict[k.replace('block11', 'block16')] = v
265 |                     model_dict[k.replace('block11', 'block17')] = v
266 |                     model_dict[k.replace('block11', 'block18')] = v
267 |                     model_dict[k.replace('block11', 'block19')] = v
268 |                 elif k.startswith('block12'):
269 |                     model_dict[k.replace('block12', 'block20')] = v
270 |                 elif k.startswith('bn3'):
271 |                     model_dict[k] = v
272 |                     model_dict[k.replace('bn3', 'bn4')] = v
273 |                 elif k.startswith('conv4'):
274 |                     model_dict[k.replace('conv4', 'conv5')] = v
275 |                 elif k.startswith('bn4'):
276 |                     model_dict[k.replace('bn4', 'bn5')] = v
277 |                 else:
278 |                     model_dict[k] = v
279 |         state_dict.update(model_dict)
280 |         self.load_state_dict(state_dict)
281 | 
282 | 
283 | if __name__ == "__main__":
284 |     import torch
285 | 
286 |     model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
287 |     input = torch.rand(1, 3, 512, 512)
288 |     output, low_level_feat = model(input)
289 |     print(output.size())
290 |     print(low_level_feat.size())
291 | 
--------------------------------------------------------------------------------
/model/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File   : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email  : maojiayuan@gmail.com
5 | # Date   : 27/01/2018
6 | # 
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import collections
12 | 
13 | import torch
14 | import torch.nn.functional as F
15 | 
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 | 
19 | from .comm import SyncMaster
20 | 
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 | 
23 | 
24 | def _sum_ft(tensor):
25 |     """sum over the first and last dimension"""
26 |     return tensor.sum(dim=0).sum(dim=-1)
27 | 
28 | 
29 | def _unsqueeze_ft(tensor):
30 |     """add new dimensions at the front and the tail"""
31 |     return tensor.unsqueeze(0).unsqueeze(-1)
32 | 
33 | 
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 | 
37 | 
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 |     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 |         super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 | 
42 |         self._sync_master = SyncMaster(self._data_parallel_master)
43 | 
44 |         self._is_parallel = False
45 |         self._parallel_id = None
46 |         self._slave_pipe = None
47 | 
48 |     def forward(self, input):
49 |         # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50 |         if not (self._is_parallel and self.training):
51 |             return F.batch_norm(
52 |                 input, self.running_mean, self.running_var, self.weight, self.bias,
53 |                 self.training, self.momentum, self.eps)
54 | 
55 |         # Resize the input to (B, C, -1).
56 |         input_shape = input.size()
57 |         input = input.view(input.size(0), self.num_features, -1)
58 | 
59 |         # Compute the sum and square-sum.
60 |         sum_size = input.size(0) * input.size(2)
61 |         input_sum = _sum_ft(input)
62 |         input_ssum = _sum_ft(input ** 2)
63 | 
64 |         # Reduce-and-broadcast the statistics.
65 |         if self._parallel_id == 0:
66 |             mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67 |         else:
68 |             mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69 | 
70 |         # Compute the output.
71 |         if self.affine:
72 |             # MJY:: Fuse the multiplication for speed.
73 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74 |         else:
75 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76 | 
77 |         # Reshape it.
78 |         return output.view(input_shape)
79 | 
80 |     def __data_parallel_replicate__(self, ctx, copy_id):
81 |         self._is_parallel = True
82 |         self._parallel_id = copy_id
83 | 
84 |         # parallel_id == 0 means master device.
85 |         if self._parallel_id == 0:
86 |             ctx.sync_master = self._sync_master
87 |         else:
88 |             self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89 | 
90 |     def _data_parallel_master(self, intermediates):
91 |         """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92 | 
93 |         # Always using same "device order" makes the ReduceAdd operation faster.
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | 142 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 143 | as the built-in PyTorch implementation. 144 | The mean and standard-deviation are calculated per-dimension over 145 | the mini-batches and gamma and beta are learnable parameter vectors 146 | of size C (where C is the input size). 147 | During training, this layer keeps a running estimate of its computed mean 148 | and variance. The running sum is kept with a default momentum of 0.1. 149 | During evaluation, this running mean/variance is used for normalization. 150 | Because the BatchNorm is done over the `C` dimension, computing statistics 151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 152 | Args: 153 | num_features: num_features from an expected input of size 154 | `batch_size x num_features [x width]` 155 | eps: a value added to the denominator for numerical stability. 156 | Default: 1e-5 157 | momentum: the value used for the running_mean and running_var 158 | computation. Default: 0.1 159 | affine: a boolean value that when set to ``True``, gives the layer learnable 160 | affine parameters. 
Default: ``True`` 161 | Shape: 162 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 164 | Examples: 165 | >>> # With Learnable Parameters 166 | >>> m = SynchronizedBatchNorm1d(100) 167 | >>> # Without Learnable Parameters 168 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 169 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 170 | >>> output = m(input) 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 178 | 179 | 180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 182 | of 3d inputs 183 | .. math:: 184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 186 | standard-deviation are reduced across all devices during training. 187 | For example, when one uses `nn.DataParallel` to wrap the network during 188 | training, PyTorch's implementation normalize the tensor on each device using 189 | the statistics only on that device, which accelerated the computation and 190 | is also easy to implement, but the statistics might be inaccurate. 191 | Instead, in this synchronized version, the statistics will be computed 192 | over all training samples distributed on multiple devices. 193 | 194 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 195 | as the built-in PyTorch implementation. 196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size). 199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running sum is kept with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. 
Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalize the tensor on each device using 241 | the statistics only on that device, which accelerated the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running sum is kept with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability. 261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. 
--------------------------------------------------------------------------------
/utils/vis.py:
--------------------------------------------------------------------------------
import csv
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import cv2
import constants


def get_label_name_colors(csv_path):
    """
    Read a csv file and return the label names and label colors as lists.
    :param csv_path: csv color file path
    :return: label name list, label color list
    """
    label_names, label_colors = [], []
    with open(csv_path, 'r') as csv_file:
        reader = csv.reader(csv_file)
        for i, row in enumerate(reader):
            if i > 0:  # skip the header row
                label_names.append(row[0])
                label_colors.append([int(row[1]), int(row[2]), int(row[3])])

    return label_names, label_colors


def color_code_target(target, label_colors):
    return np.array(label_colors)[target.astype('int')]


def get_legends(class_set, label_names, label_colors):
    legend_names, legend_lines = [], []
    for i in class_set:
        legend_names.append(label_names[i])  # legend entry
        legend_lines.append(Line2D([0], [0], color=map_color(label_colors[i]), lw=2))  # colored legend line
    return legend_names, legend_lines


def map_color(rgb):
    return [v / 255 for v in rgb]


def plt_img_target_pred(img, target, pred, label_colors, vertical=False):
    # target_class_set = set(target.astype('int').flatten().tolist())
    # pred_class_set = set(pred.astype('int').flatten().tolist())
    # target_leg_names, target_leg_lines = get_legends(target_class_set, label_names, label_colors)
    # pred_leg_names, pred_leg_lines = get_legends(pred_class_set, label_names, label_colors)

    if vertical:
        f, axs = plt.subplots(nrows=3, ncols=1)
        f.set_size_inches((4, 9))
    else:
        f, axs = plt.subplots(nrows=1, ncols=3, dpi=150)
        f.set_size_inches((10, 3))

    ax1, ax2, ax3 = axs.flat[0], axs.flat[1], axs.flat[2]

    # ax1.axis('off')
    ax1.imshow(img)
    ax1.set_title('img')

    # ax2.axis('off')
    ax2.imshow(color_code_target(target, label_colors))
    ax2.set_title('target')

    # ax3.axis('off')
    ax3.imshow(color_code_target(pred, label_colors))
    ax3.set_title('predict')

    plt.show()


def plt_img_target_pred_error(img, target, pred, error_mask, label_colors, title=None):
    f, axs = plt.subplots(nrows=1, ncols=4, dpi=150)
    f.set_size_inches((18, 5))

    ax1, ax2, ax3, ax4 = axs.flat[0], axs.flat[1], axs.flat[2], axs.flat[3]

    # ax1.axis('off')
    ax1.imshow(img)
    ax1.set_title('img')

    # ax2.axis('off')
    ax2.imshow(color_code_target(target, label_colors))
    ax2.set_title('target')

    # ax3.axis('off')
    ax3.imshow(color_code_target(pred, label_colors))
    ax3.set_title('predict')

    ax4.imshow(color_code_target(error_mask, label_colors))
    ax4.set_title('error')

    if title:
        plt.title(title)  # applies to the last subplot drawn

    plt.show()
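

# Hedged usage sketch for the helpers above (the function name and sizes are
# illustrative, not part of the training pipeline). datasets/sun/sun37.csv is
# the palette shipped with this repo; any csv with name,r,g,b rows works, and
# the path assumes the script is run from the repo root.
def demo_color_code_target():
    label_names, label_colors = get_label_name_colors('datasets/sun/sun37.csv')
    h, w = 120, 160
    img = np.random.rand(h, w, 3)  # stand-in RGB image in [0, 1]
    target = np.random.randint(0, len(label_names), size=(h, w))
    pred = np.random.randint(0, len(label_names), size=(h, w))
    plt_img_target_pred(img, target, pred, label_colors)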


def plt_img_target(img, target, label_colors, title=None):
    f, axs = plt.subplots(nrows=1, ncols=2, dpi=100)
    f.set_size_inches((7, 4))
    ax1, ax2 = axs.flat[0], axs.flat[1]

    ax1.axis('off')
    ax1.imshow(img)
    ax1.set_title('img')

    ax2.axis('off')
    ax2.imshow(color_code_target(target, label_colors))
    ax2.set_title('target')

    if title:
        plt.suptitle(title)

    plt.show()


def plt_img_target_ceal(img, target, ceal, label_colors):
    f, axs = plt.subplots(nrows=1, ncols=3)
    f.set_size_inches((10, 3))
    ax1, ax2, ax3 = axs.flat[0], axs.flat[1], axs.flat[2]

    ax1.axis('off')
    ax1.imshow(img)
    ax1.set_title('img')

    ax2.axis('off')
    ax2.imshow(color_code_target(target, label_colors))
    ax2.set_title('target')

    ax3.axis('off')
    ax3.imshow(color_code_target(ceal, label_colors))
    ax3.set_title('ceal')

    plt.show()


def plt_color_label(target, label_colors, title):
    plt.figure()
    plt.axis('off')
    plt.imshow(color_code_target(target, label_colors))
    plt.title(title)
    plt.show()


def plt_img_target_gt_ceal(img, target, gt, ceal, label_colors):
    f, axs = plt.subplots(nrows=2, ncols=2)
    f.set_size_inches((8, 6))  # 800 x 600 px
    ax1, ax2, ax3, ax4 = axs.flat[0], axs.flat[1], axs.flat[2], axs.flat[3]

    ax1.axis('off')
    ax1.imshow(img)
    ax1.set_title('img')

    ax2.axis('off')
    ax2.imshow(color_code_target(target, label_colors))
    ax2.set_title('target')

    ax3.axis('off')
    ax3.imshow(color_code_target(gt, label_colors))
    ax3.set_title('gt')

    ax4.axis('off')
    ax4.imshow(color_code_target(ceal, label_colors))
    ax4.set_title('ceal')

    plt.show()


def get_plt_img_target_gt_ceal(img, target, gt, ceal, label_colors, figsize=(8, 6), title=None):
    fig, axs = plt.subplots(nrows=2, ncols=2)
    fig.set_size_inches(figsize)
    ax1, ax2, ax3, ax4 = axs.flat[0], axs.flat[1], axs.flat[2], axs.flat[3]

    ax1.axis('off')
    ax1.imshow(img)
    ax1.set_title('img')

    ax2.axis('off')
    ax2.imshow(color_code_target(target, label_colors))
    ax2.set_title('target')

    ax3.axis('off')
    ax3.imshow(color_code_target(gt, label_colors))
    ax3.set_title('gt')

    ax4.axis('off')
    ax4.imshow(color_code_target(ceal, label_colors))
    ax4.set_title('ceal')

    if title:
        plt.suptitle(title)

    # convert the plt result to a numpy image
    fig.canvas.draw()
    w, h = fig.canvas.get_width_height()
    img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img = img.reshape((h, w, 3))  # reshape to the figure's actual pixel size
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    plt.cla()
    plt.close("all")

    return img


import torch
from utils.misc import to_numpy


def save_error_mask(error_mask, save_path):
    if isinstance(error_mask, torch.Tensor):
        error_mask = to_numpy(error_mask)
    plt.axis('off')
    plt.imshow(error_mask, cmap='jet')  # savefig crops/pads below, so the figure size cannot be derived from error_mask
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0)
    plt.clf()
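

# Hedged sketch (illustrative name, path, and sizes): get_plt_img_target_gt_ceal
# rasterizes the 2x2 grid through fig.canvas and returns a BGR ndarray, so the
# result can be written with OpenCV or fed to a video writer instead of being
# shown on screen.
def demo_fig_to_ndarray(save_path='vis_grid.png'):
    label_names, label_colors = get_label_name_colors('datasets/sun/sun37.csv')
    h, w = 120, 160
    img = np.random.rand(h, w, 3)  # stand-in inputs
    target, gt, ceal = (np.random.randint(0, len(label_names), size=(h, w)) for _ in range(3))
    grid = get_plt_img_target_gt_ceal(img, target, gt, ceal, label_colors, title='demo')
    cv2.imwrite(save_path, grid)  # grid is already BGR uint8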


def plt_smooth_error_mask(image, target, label_colors,
                          target_error_mask, smooth_error_mask, title=None):
    f, axs = plt.subplots(nrows=2, ncols=2)
    f.set_size_inches((8, 6))

    ax1, ax2, ax3, ax4 = axs.flat[0], axs.flat[1], axs.flat[2], axs.flat[3]

    # semantic
    ax1.axis('off')
    ax1.imshow(image)
    ax1.set_title('image')

    ax2.axis('off')
    ax2.imshow(color_code_target(target, label_colors))
    ax2.set_title('target')

    # mask
    ax3.set_xticks([])
    ax3.set_yticks([])
    ax3.imshow(target_error_mask, cmap='gray')
    ax3.set_title('target_error_mask')

    ax4.set_xticks([])
    ax4.set_yticks([])
    ax4.imshow(smooth_error_mask, cmap='gray')
    ax4.set_title('smooth_error_mask')

    if title:
        plt.suptitle(title)

    plt.show()


def plt_smooth_error_mask_v2(image, target, predict, label_colors,
                             target_error_mask, smooth_error_mask, title=None):
    f, axs = plt.subplots(nrows=2, ncols=3)
    f.set_size_inches((10, 6))

    ax1, ax2, ax3, ax4, ax5, ax6 = axs.flat[0], axs.flat[1], axs.flat[2], axs.flat[3], axs.flat[4], axs.flat[5]

    # semantic
    ax1.axis('off')
    ax1.imshow(image)
    ax1.set_title('image')

    ax2.axis('off')
    ax2.imshow(color_code_target(target, label_colors))
    ax2.set_title('target')

    ax3.axis('off')
    ax3.imshow(color_code_target(predict, label_colors))
    ax3.set_title('predict')

    # mask
    ax4.set_xticks([])
    ax4.set_yticks([])
    ax4.imshow(target_error_mask, cmap='gray')
    ax4.set_title('target_error_mask')

    ax5.set_xticks([])
    ax5.set_yticks([])
    ax5.imshow(smooth_error_mask, cmap='gray')
    ax5.set_title('smooth_error_mask')

    if title:
        plt.suptitle(title)

    plt.show()


def plt_target_pred_masks(target, pred, label_colors,
                          target_error_mask, pred_error_mask, title=None):
    f, axs = plt.subplots(nrows=2, ncols=2)
    f.set_size_inches((8, 6))

    ax1, ax2, ax3, ax4 = axs.flat[0], axs.flat[1], axs.flat[2], axs.flat[3]

    # semantic
    ax1.axis('off')
    ax1.imshow(color_code_target(target, label_colors))
    ax1.set_title('target')

    ax3.axis('off')
    ax3.imshow(color_code_target(pred, label_colors))
    ax3.set_title('predict')

    # mask
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.imshow(color_code_target(target_error_mask, label_colors))
    ax2.set_title('target_error_mask')

    ax4.axis('off')
    ax4.imshow(pred_error_mask, cmap='jet')
    ax4.set_title('pred_error_mask')

    if title:
        plt.suptitle(title)

    plt.show()


def plt_class_superpixels(targets, label_names, label_colors, title=None, save_path=None):
    C = targets.shape[0]

    # plt.subplots_adjust(wspace=0, hspace=0)

    rows, cols = 3, 4
    f, axs = plt.subplots(nrows=rows, ncols=cols)
    f.set_size_inches((10, 6))

    for i in range(rows * cols):
        ax = axs.flat[i]
        ax.set_xticks([])
        ax.set_yticks([])
        if i < C:
            ax.imshow(color_code_target(targets[i], label_colors))
            ax.set_title(label_names[i], fontsize=10)
        else:
            ax.axis('off')

    if title:
        plt.suptitle(title)

    f.tight_layout()  # tighten the overall whitespace

    if save_path:
        plt.savefig(save_path)

    # plt.show()


def plt_class_errors(errors, label_names, title=None, save_path=None):
    # plot the C per-class error maps (backbone C features or target C maps)

    C = errors.shape[0]

    # plt.subplots_adjust(wspace=0, hspace=0)

    rows, cols = 3, 4
    f, axs = plt.subplots(nrows=rows, ncols=cols)
    f.set_size_inches((10, 6))

    for i in range(rows * cols):
        ax = axs.flat[i]
        ax.set_xticks([])
        ax.set_yticks([])
        if i < C:
            ax.imshow(errors[i], cmap='jet')
            ax.set_title(label_names[i], fontsize=10)
        else:
            ax.axis('off')

    if title:
        plt.suptitle(title)

    f.tight_layout()  # tighten the overall whitespace

    if save_path:
        plt.savefig(save_path)

    # plt.show()


def plt_superpixel_scores(targets, all_sps, label_names, label_colors, title=None, save_path=None):
    C = targets.shape[0]

    rows, cols = 3, 4
    f, axs = plt.subplots(nrows=rows, ncols=cols)
    f.set_size_inches((20, 13))

    for i in range(rows * cols):
        ax = axs.flat[i]
        ax.set_xticks([])
        ax.set_yticks([])
        if i < C:
            ax.imshow(color_code_target(targets[i], label_colors))  # class ids start from 1, so the max value is the class id
            for sps in all_sps[i]['sps']:
                ax.annotate('{:.3f}'.format(sps['iou']),
                            xy=(sps['centroid'][1], sps['centroid'][0]), fontsize=8,
                            xycoords='data', xytext=(2, -10), textcoords='offset points',
                            fontweight='bold',
                            bbox=dict(boxstyle='round, pad=0.3',  # linewidth=0 hides the box border
                                      alpha=0.5,
                                      facecolor=[c / 255 for c in [255, 255, 255]], lw=0),
                            color='b')
            ax.set_title(label_names[i] + f'({np.max(targets[i])})')
        else:
            ax.axis('off')

    if title:
        plt.suptitle(title)

    f.tight_layout()  # tighten the overall whitespace

    if save_path:
        plt.savefig(save_path)

    # plt.show()


def plt_cmp_top_errors(targets, predicts, errors, label_names, label_colors, top_idxs,
                       title=None, save_path=None):
    rows, cols = 3, len(top_idxs)
    f, axs = plt.subplots(nrows=rows, ncols=cols)
    f.set_size_inches((10, 6))

    for i in range(rows):
        for j in range(cols):
            ax = axs.flat[cols * i + j]
            ax.set_xticks([])
            ax.set_yticks([])
            idx = top_idxs[j]
            if i == 0:
                ax.imshow(color_code_target(targets[idx], label_colors))
                ax.set_title(label_names[idx], fontsize=10)
            if i == 1:
                ax.imshow(color_code_target(predicts[idx], label_colors))
                ax.set_title(label_names[idx], fontsize=10)
            if i == 2:  # error score
                ax.imshow(errors[idx], cmap='jet')
                ax.set_title(label_names[idx], fontsize=10)

    if title:
        plt.suptitle(title)

    f.tight_layout()  # tighten the overall whitespace

    if save_path:
        plt.savefig(save_path)

    plt.show()


from utils.misc import minmax_normalize


def plt_backbone_features(features, label_names):
    # plot the C backbone feature maps (or the C target maps)

    C = features.shape[0]

    rows, cols = 3, 4
    f, axs = plt.subplots(nrows=rows, ncols=cols)
    f.set_size_inches((10, 6))

    for i in range(rows * cols):
        ax = axs.flat[i]
        ax.set_xticks([])
        ax.set_yticks([])
        if i < C:
            ax.imshow(minmax_normalize(features[i]), cmap='jet')
            ax.set_title(label_names[i])
        else:
            ax.axis('off')

    plt.show()
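

# The plotting helpers above and below rely on utils.misc.minmax_normalize to
# squash each map into [0, 1] before applying the 'jet' colormap. As a hedged
# reference for readers, a stand-in with the usual semantics would look like
# this (the repo's own implementation in utils/misc.py is the one imported
# above and remains authoritative):
def _minmax_normalize_reference(x, eps=1e-8):
    """Linearly rescale an array to [0, 1]; eps guards against constant maps."""
    return (x - x.min()) / (x.max() - x.min() + eps)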


def plt_all(target, pred, label_colors,
            target_error_mask, pred_error_mask, right_mask, error_mask, title=None):
    f, axs = plt.subplots(nrows=2, ncols=3)
    f.set_size_inches((10, 5))

    ax1, ax2, ax3, ax4, ax5, ax6 = axs.flat[0], axs.flat[1], axs.flat[2], axs.flat[3], axs.flat[4], axs.flat[5]

    # semantic
    ax1.axis('off')
    ax1.imshow(color_code_target(target, label_colors))
    ax1.set_title('target')

    ax4.axis('off')
    ax4.imshow(color_code_target(pred, label_colors))
    ax4.set_title('predict')

    # mask
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.imshow(color_code_target(target_error_mask, label_colors))
    ax2.set_title('target_error_mask')

    ax3.axis('off')
    ax3.imshow(pred_error_mask, cmap='jet')
    ax3.set_title('pred_error_mask')

    ax5.set_xticks([])
    ax5.set_yticks([])
    ax5.imshow(right_mask, cmap='gray')
    ax5.set_title('right_mask')

    ax6.set_xticks([])
    ax6.set_yticks([])
    ax6.imshow(error_mask, cmap='gray')
    ax6.set_title('error_mask')

    if title:
        plt.suptitle(title)

    plt.show()


def plt_att(target, pred, label_colors, atten,
            target_error_mask, pred_error_mask, title=None):
    f, axs = plt.subplots(nrows=2, ncols=3)
    f.set_size_inches((10, 6))

    ax1, ax2, ax3 = axs.flat[0], axs.flat[1], axs.flat[2]
    ax4, ax5, ax6 = axs.flat[3], axs.flat[4], axs.flat[5]

    # semantic
    ax1.axis('off')
    ax1.imshow(color_code_target(target, label_colors))
    ax1.set_title('target')

    # mask
    ax2.axis('off')
    ax2.imshow(color_code_target(target_error_mask, label_colors))
    ax2.set_title('target_error_mask')

    # att
    if atten is not None:
        ax3.axis('off')
        ax3.imshow(atten, cmap='jet')
        ax3.set_title('attention')

    # predict
    ax4.axis('off')
    ax4.imshow(color_code_target(pred, label_colors))
    ax4.set_title('predict')

    # error mask
    ax5.axis('off')
    ax5.imshow(pred_error_mask, cmap='jet')
    ax5.set_title('pred_error_mask')

    ax6.axis('off')

    if title:
        plt.suptitle(title)

    plt.show()


def plt_all_atten(target, pred, label_colors, atten,
                  target_error_mask, pred_error_mask, right_mask, error_mask, title=None):
    f, axs = plt.subplots(nrows=2, ncols=4)
    f.set_size_inches((12, 6))

    ax1, ax2, ax3, ax4 = axs.flat[0], axs.flat[1], axs.flat[2], axs.flat[3]
    ax5, ax6, ax7, ax8 = axs.flat[4], axs.flat[5], axs.flat[6], axs.flat[7]

    # semantic
    ax1.axis('off')
    ax1.imshow(color_code_target(target, label_colors))
    ax1.set_title('target')

    # mask
    ax2.axis('off')
    ax2.imshow(color_code_target(target_error_mask, label_colors))
    ax2.set_title('target_error_mask')

    # ax2.set_xticks([])
    # ax2.set_yticks([])
    # ax2.imshow(target_error_mask, cmap='gray')
    # ax2.set_title('target_error_mask')

    ax3.axis('off')
    ax3.imshow(pred_error_mask, cmap='jet')
    ax3.set_title('pred_error_mask')

    ax4.axis('off')
    ax4.imshow(atten, cmap='jet')
    ax4.set_title('attention')

    ax5.axis('off')
    ax5.imshow(color_code_target(pred, label_colors))
    ax5.set_title('predict')

    ax6.set_xticks([])
    ax6.set_yticks([])
    ax6.imshow(right_mask, cmap='gray')
    ax6.set_title('right_mask')

    ax7.set_xticks([])
    ax7.set_yticks([])
    ax7.imshow(error_mask, cmap='gray')
    ax7.set_title('error_mask')

    ax8.axis('off')

    if title:
        plt.suptitle(title)

    plt.show()


def plt_all_v2(target, pred, label_colors,
               target_error_mask, pred_error_mask,
               thre_right_mask, thre_error_mask,
               right_mask, error_mask,
               title=None, save_path=None):
    f, axs = plt.subplots(nrows=2, ncols=4)
    f.set_size_inches((12, 6))

    ax1, ax2, ax3, ax4 = axs.flat[0], axs.flat[1], axs.flat[2], axs.flat[3]
    ax5, ax6, ax7, ax8 = axs.flat[4], axs.flat[5], axs.flat[6], axs.flat[7]

    # semantic
    ax1.axis('off')
    ax1.imshow(color_code_target(target, label_colors))
    ax1.set_title('target')

    ax5.axis('off')
    ax5.imshow(color_code_target(pred, label_colors))
    ax5.set_title('predict')

    # mask
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.imshow(target_error_mask, cmap='gray')
    ax2.set_title('target_error_mask')

    ax6.axis('off')
    ax6.imshow(pred_error_mask, cmap='jet')
    ax6.set_title('pred_error_mask')

    # thresholded masks
    ax3.set_xticks([])
    ax3.set_yticks([])
    ax3.imshow(thre_right_mask, cmap='gray')
    ax3.set_title('thre_right_mask')

    ax7.set_xticks([])
    ax7.set_yticks([])
    ax7.imshow(thre_error_mask, cmap='gray')
    ax7.set_title('thre_error_mask')

    # 0.5
    ax4.set_xticks([])
    ax4.set_yticks([])
    ax4.imshow(right_mask, cmap='gray')
    ax4.set_title('right_mask')

    ax8.set_xticks([])
    ax8.set_yticks([])
    ax8.imshow(error_mask, cmap='gray')
    ax8.set_title('error_mask')

    if title:
        plt.suptitle(title)

    f.tight_layout()  # tighten the overall whitespace

    if save_path:
        plt.savefig(save_path)

    plt.show()


def plt_error_mask(error_mask, save_path=None):
    plt.figure(figsize=(4, 3), dpi=200)

    plt.axis('off')
    plt.imshow(error_mask, cmap='jet')

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.)

    plt.show()


def plt_img_target_error(img, target, error_mask, label_colors, save_path=None, title=None):
    f, axs = plt.subplots(nrows=1, ncols=3, dpi=100)
    f.set_size_inches((8, 3))

    ax1, ax2, ax3 = axs.flat[0], axs.flat[1], axs.flat[2]
    ax1.axis('off')
    ax1.imshow(img)
    ax1.set_title('img')

    # ax2.axis('off')
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.imshow(color_code_target(target, label_colors))
    ax2.set_title('target')

    # ax3.axis('off')
    ax3.set_xticks([])
    ax3.set_yticks([])
    ax3.imshow(error_mask, cmap='gray')
    ax3.set_title('crop score')

    f.tight_layout()  # tighten the overall whitespace
    plt.subplots_adjust(wspace=0.02)  # tighten subplot spacing; the adjustment shows up in the saved figure

    if title:
        plt.suptitle(title)

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.)

    plt.show()
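

# Hedged sketch (illustrative name, path, and size): plt_error_mask combined
# with bbox_inches='tight' and pad_inches=0 yields a borderless heat-map PNG.
def demo_save_error_map(save_path='error_map.png'):
    error_mask = np.random.rand(120, 160)  # stand-in uncertainty map
    plt_error_mask(error_mask, save_path=save_path)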


def plt_cmp(img, lc, ms, entro, dropout, error_mask, save_path=None):
    f, axs = plt.subplots(nrows=1, ncols=6, dpi=200)
    f.set_size_inches((20, 3))

    maps = [img, lc, ms, entro, dropout, error_mask]
    titles = ['Image', 'LC', 'MS', 'Entropy', 'Dropout', 'Ours']

    ax = axs.flat[0]
    ax.axis('off')
    ax.imshow(maps[0])
    # ax.set_title(titles[0])

    for i in range(1, 6):
        ax = axs.flat[i]
        ax.axis('off')
        if maps[i] is not None:
            ax.imshow(maps[i], cmap='jet')
        # ax.set_title(titles[i])

    f.tight_layout()  # tighten the overall whitespace
    plt.subplots_adjust(wspace=0.04)  # tighten subplot spacing; the adjustment shows up in the saved figure

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.)

    plt.show()


def plt_cmp_v2(img, lc, ms, en, dr, pred_error_mask, save_path=None):
    f, axs = plt.subplots(nrows=2, ncols=5, dpi=200)
    f.set_size_inches((16, 5))

    # img
    ax = axs.flat[0]
    ax.axis('off')
    ax.imshow(img)

    # normalize the uncertainty maps
    lc, ms, en = minmax_normalize(lc), minmax_normalize(ms), minmax_normalize(en)
    dr = minmax_normalize(dr)

    maps = [lc, ms, en, dr]

    for i in range(len(maps)):
        ax = axs.flat[i + 1]
        ax.axis('off')
        ax.imshow(maps[i], cmap='jet')

    # pred error mask
    ax = axs.flat[5]
    ax.axis('off')
    ax.imshow(pred_error_mask, cmap='jet')

    # attention-weighted uncertainty maps, normalized again
    for i in range(len(maps)):
        ax = axs.flat[i + 6]
        ax.axis('off')
        att_uncer_map = minmax_normalize(maps[i] * pred_error_mask)
        # att_uncer_map = maps[i] * pred_error_mask
        ax.imshow(att_uncer_map, cmap='jet')

    f.tight_layout()  # tighten the overall whitespace
    plt.subplots_adjust(wspace=0.04)  # tighten subplot spacing; the adjustment shows up in the saved figure

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.)

    plt.show()


def plt_compare(target, pred, label_colors,
                target_error_mask, pred_error_mask, lc, ms, en, mc_dropout=None):
    # compare the uncertainty maps produced by different methods
    f, axs = plt.subplots(nrows=2, ncols=4)
    f.set_size_inches((12, 6))

    ax1, ax2, ax3, ax4 = axs.flat[0], axs.flat[1], axs.flat[2], axs.flat[3]
    ax5, ax6, ax7, ax8 = axs.flat[4], axs.flat[5], axs.flat[6], axs.flat[7]

    # semantic
    ax1.axis('off')
    ax1.imshow(color_code_target(target, label_colors))
    ax1.set_title('target')

    ax5.axis('off')
    ax5.imshow(color_code_target(pred, label_colors))
    ax5.set_title('predict')

    # mask
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.imshow(target_error_mask, cmap='gray')
    ax2.set_title('target_error_mask')

    ax6.axis('off')
    ax6.imshow(pred_error_mask, cmap='jet')
    ax6.set_title('pred_error_mask')

    # compare
    ax3.axis('off')
    ax3.imshow(lc, cmap='jet')
    ax3.set_title('least confidence')

    ax4.axis('off')
    ax4.imshow(ms, cmap='jet')
    ax4.set_title('margin sampling')

    ax7.axis('off')
    ax7.imshow(en, cmap='jet')
    ax7.set_title('entropy')

    ax8.axis('off')
    if mc_dropout is not None:
        ax8.imshow(mc_dropout, cmap='jet')
    ax8.set_title('mc dropout')

    plt.show()
--------------------------------------------------------------------------------