├── PAM ├── models │ ├── __init__py │ └── classifier.py ├── run_PAM.sh ├── utils │ ├── avgMeter.py │ ├── my_optim.py │ ├── LoadData.py │ ├── Metrics.py │ └── transforms │ │ └── functional.py ├── README.md ├── point_extraction.py └── train.py ├── figures ├── overview.png ├── result_coco.png ├── result_voc.png ├── PAM_comparison.png ├── PAM_architecture.png └── qualitavie_result.png ├── models ├── imagenet │ └── README.md ├── hrnet_config │ ├── __init__.py │ ├── models.py │ ├── w32_384x288_adam_lr1e-3.yaml │ ├── w48_384x288_adam_lr1e-3.yaml │ └── default.py ├── __init__.py ├── resnet.py ├── panoptic_deeplab.py └── hrnet.py ├── scripts ├── run_point_labels.sh └── run_image_labels.sh ├── LICENSE ├── utils ├── crf.py ├── my_optim.py ├── loss.py ├── decode.py ├── LoadData.py └── transforms │ └── transforms.py ├── .gitignore ├── NOTICE ├── README.md └── main.py /PAM/models/__init__py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/overview.png -------------------------------------------------------------------------------- /figures/result_coco.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/result_coco.png -------------------------------------------------------------------------------- /figures/result_voc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/result_voc.png -------------------------------------------------------------------------------- /figures/PAM_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/PAM_comparison.png -------------------------------------------------------------------------------- /figures/PAM_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/PAM_architecture.png -------------------------------------------------------------------------------- /figures/qualitavie_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/qualitavie_result.png -------------------------------------------------------------------------------- /models/imagenet/README.md: -------------------------------------------------------------------------------- 1 | Download HRNet ImageNet pretrained weight files to this path ([GoogleDrive](https://drive.google.com/drive/folders/1E6j6W7RqGhW1o7UHgiQ9X4g8fVJRU9TX)) ([OneDrive](https://1drv.ms/f/s!AhIXJn_J-blW231MH2krnmLq5kkQ)). 2 | The weight file is from https://github.com/HRNet/HRNet-Human-Pose-Estimation/tree/master . 3 | 4 | - hrnet_w32-36af842e.pth 5 | - hrnet_w48-8ef0771d.pth -------------------------------------------------------------------------------- /models/hrnet_config/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from .default import _C as cfg 8 | from .default import update_config 9 | from .models import MODEL_EXTRAS 10 | -------------------------------------------------------------------------------- /PAM/run_PAM.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1 python train.py \ 2 | --root_dir=your_dataset_root_path \ 3 | --lr=0.001 \ 4 | --epoch=15 \ 5 | --decay_points='5,10' \ 6 | --alpha=0.7 \ 7 | --save_folder=checkpoints/PAM \ 8 | --show_interval=50 9 | 10 | CUDA_VISIBLE_DEVICES=0 python point_extraction.py \ 11 | --root_dir=your_dataset_root_path \ 12 | --alpha=0.7 \ 13 | --checkpoint=checkpoints/PAM/ckpt_15.pth \ 14 | --save_dir Peak_points 15 | -------------------------------------------------------------------------------- /PAM/utils/avgMeter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | def __init__(self): 4 | self.reset() 5 | 6 | def reset(self): 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | 12 | def update(self, val, n=1): 13 | self.val = val 14 | self.sum += val * n 15 | self.count += n 16 | self.avg = self.sum / self.count 17 | 18 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .panoptic_deeplab import PanopticDeepLab 2 | from .hrnet import hrnet32, hrnet48 3 | from .resnet import resnet50, resnet101 4 | 5 | def model_factory(args): 6 | 7 | if args.backbone == 'hrnet32': 8 | model = hrnet32(args) 9 | elif args.backbone == 'hrnet48': 10 | model = hrnet48(args) 11 | 12 | elif args.backbone == 'resnet50': 13 | backbone = resnet50(pretrained=True) 14 | model = PanopticDeepLab(backbone, args) 15 | 16 | elif args.backbone == 'resnet101': 17 | backbone = resnet101(pretrained=True) 18 | model = PanopticDeepLab(backbone, args) 19 | 20 | return model 21 | -------------------------------------------------------------------------------- /scripts/run_point_labels.sh: -------------------------------------------------------------------------------- 1 | # Training BESTIE with point labels. 2 | 3 | ROOT=your_dataset_root_path 4 | SUP=point 5 | REFINE_WARMUP=0 6 | SIZE=416 7 | BATCH=16 8 | WORKERS=4 9 | TRAIN_ITERS=50000 10 | BACKBONE=hrnet48 # [resnet50, resnet101, hrnet32, hrnet48] 11 | VAL_IGNORE=False 12 | 13 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nnodes=1 --nproc_per_node=4 main.py \ 14 | --root_dir ${ROOT} --sup ${SUP} --batch_size ${BATCH} --num_workers ${WORKERS} --crop_size ${SIZE} \ 15 | --train_iter ${TRAIN_ITERS} --refine True --refine_iter ${REFINE_WARMUP} \ 16 | --val_freq 1000 --val_thresh 0.1 --val_ignore ${VAL_IGNORE} --val_clean False --val_flip True \ 17 | --seg_weight 1.0 --center_weight 200.0 --offset_weight 0.01 \ 18 | --lr 5e-5 --backbone ${BACKBONE} --random_seed 3407 19 | -------------------------------------------------------------------------------- /scripts/run_image_labels.sh: -------------------------------------------------------------------------------- 1 | # Training BESTIE with image-level labels. 
2 | 3 | ROOT=your_dataset_root_path 4 | SUP=cls 5 | PSEUDO_THRESH=0.7 6 | REFINE_THRESH=0.3 7 | REFINE_WARMUP=0 8 | SIZE=416 9 | BATCH=16 10 | WORKERS=4 11 | TRAIN_ITERS=50000 12 | BACKBONE=hrnet48 # [resnet50, resnet101, hrnet32, hrnet48] 13 | VAL_IGNORE=False 14 | 15 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nnodes=1 --nproc_per_node=4 main.py \ 16 | --root_dir ${ROOT} --sup ${SUP} --batch_size ${BATCH} --num_workers ${WORKERS} --crop_size ${SIZE} --train_iter ${TRAIN_ITERS} \ 17 | --refine True --refine_iter ${REFINE_WARMUP} --pseudo_thresh ${PSEUDO_THRESH} --refine_thresh ${REFINE_THRESH} \ 18 | --val_freq 1000 --val_ignore ${VAL_IGNORE} --val_clean False --val_flip False \ 19 | --seg_weight 1.0 --center_weight 200.0 --offset_weight 0.01 \ 20 | --lr 5e-5 --backbone ${BACKBONE} --random_seed 3407 21 | 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BESTIE 2 | Copyright (c) 2022-present NAVER Corp. 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /PAM/README.md: -------------------------------------------------------------------------------- 1 | # Peak Attention Module (PAM) 2 | 3 | ## Abstract 4 | 5 | CAMs have a limitation in obtaining accurate instance cues because several instance cues may be extracted from a single instance due to noisy activation regions, as illustrated in the figure below. 6 | This disturbs the generation of pseudo instance labels. 7 | To address this limitation, we propose a peak attention module (PAM) to extract one appropriate instance cue per instance. 8 | PAM aims to strengthen the attention on peak regions, while weakening the attention on noisy activation regions. 9 | 10 | 11 | 12 | 13 | 14 | 15 | ## How to Run? 16 | 17 | ``` 18 | # change the data ROOT in the shell script 19 | bash run_PAM.sh 20 | ``` 21 | 22 | * Note that the extracted peak points are used for the image-level supervised BESTIE.
23 | * We provide the weights for the pretrained classifier with the PAM module [[download]](https://drive.google.com/file/d/1I5DocPV2Lkc59DtDrr4XoQuVlKdRi4km/view?usp=sharing) 24 | 25 | ## Acknowledgement 26 | 27 | Our implementation is based on these repositories: 28 | - (DRS) https://github.com/qjadud1994/DRS 29 | -------------------------------------------------------------------------------- /utils/crf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pydensecrf.densecrf as dcrf 3 | from pydensecrf.utils import unary_from_labels, unary_from_softmax 4 | 5 | def crf_inference_label(img, labels, t=10, n_labels=21, gt_prob=0.7): 6 | 7 | h, w = img.shape[:2] 8 | 9 | d = dcrf.DenseCRF2D(w, h, n_labels) 10 | 11 | unary = unary_from_labels(labels, n_labels, gt_prob=gt_prob, zero_unsure=False) 12 | 13 | d.setUnaryEnergy(unary) 14 | d.addPairwiseGaussian(sxy=3, compat=3) 15 | d.addPairwiseBilateral(sxy=50, srgb=5, rgbim=np.ascontiguousarray(np.copy(img)), compat=10) 16 | 17 | q = d.inference(t) 18 | 19 | return np.argmax(np.array(q).reshape((n_labels, h, w)), axis=0) 20 | 21 | 22 | class DenseCRF(object): 23 | def __init__(self, iter_max, pos_w, pos_xy_std, bi_w, bi_xy_std, bi_rgb_std): 24 | self.iter_max = iter_max 25 | self.pos_w = pos_w 26 | self.pos_xy_std = pos_xy_std 27 | self.bi_w = bi_w 28 | self.bi_xy_std = bi_xy_std 29 | self.bi_rgb_std = bi_rgb_std 30 | 31 | def __call__(self, image, probmap): 32 | C, H, W = probmap.shape 33 | 34 | U = unary_from_softmax(probmap) 35 | U = np.ascontiguousarray(U) 36 | 37 | image = np.ascontiguousarray(image) 38 | 39 | d = dcrf.DenseCRF2D(W, H, C) 40 | d.setUnaryEnergy(U) 41 | d.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w) 42 | d.addPairwiseBilateral( 43 | sxy=self.bi_xy_std, srgb=self.bi_rgb_std, rgbim=image, compat=self.bi_w 44 | ) 45 | 46 | Q = d.inference(self.iter_max) 47 | Q = np.array(Q).reshape((C, H, W)) 48 | 49 | output = np.argmax(Q, axis=0).astype(np.uint8) 50 | 51 | return output -------------------------------------------------------------------------------- /PAM/utils/my_optim.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/utils/my_optim.py 3 | # ------------------------------------------------------------------------------ 4 | 5 | import torch.optim as optim 6 | from torch.optim.lr_scheduler import LambdaLR 7 | import numpy as np 8 | 9 | 10 | class PolyOptimizer(optim.SGD): 11 | 12 | def __init__(self, params, lr, weight_decay, max_step, momentum=0.9): 13 | super().__init__(params, lr, weight_decay) 14 | self.param_groups = params 15 | self.global_step = 0 16 | self.max_step = max_step 17 | self.momentum = momentum 18 | 19 | self.__initial_lr = [group['lr'] for group in self.param_groups] 20 | 21 | 22 | def step(self, closure=None): 23 | 24 | if self.global_step < self.max_step: 25 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum 26 | 27 | for i in range(len(self.param_groups)): 28 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 29 | super().step(closure) 30 | 31 | self.global_step += 1 32 | 33 | 34 | def lr_poly(base_lr, iter,max_iter,power=0.9): 35 | return base_lr*((1-float(iter)/max_iter)**(power)) 36 | 37 | def reduce_lr_poly(args, optimizer, global_iter, max_iter): 38 | base_lr = args.lr 39 | for g in optimizer.param_groups: 40 | g['lr'] =
lr_poly(base_lr=base_lr, iter=global_iter, max_iter=max_iter, power=0.9) 41 | 42 | 43 | def reduce_lr(args, optimizer, epoch, factor=0.1): 44 | values = args.decay_points.strip().split(',') 45 | try: 46 | change_points = map(lambda x: int(x.strip()), values) 47 | except ValueError: 48 | change_points = None 49 | 50 | if change_points is not None and epoch in change_points: 51 | for g in optimizer.param_groups: 52 | g['lr'] = g['lr']*factor 53 | print("Reduce Learning Rate : ", epoch, g['lr']) 54 | return True -------------------------------------------------------------------------------- /models/hrnet_config/models.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | from yacs.config import CfgNode as CN 12 | 13 | 14 | # pose_resnet related params 15 | POSE_RESNET = CN() 16 | POSE_RESNET.NUM_LAYERS = 50 17 | POSE_RESNET.DECONV_WITH_BIAS = False 18 | POSE_RESNET.NUM_DECONV_LAYERS = 3 19 | POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256] 20 | POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4] 21 | POSE_RESNET.FINAL_CONV_KERNEL = 1 22 | POSE_RESNET.PRETRAINED_LAYERS = ['*'] 23 | 24 | # pose_multi_resoluton_net related params 25 | POSE_HIGH_RESOLUTION_NET = CN() 26 | POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*'] 27 | POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64 28 | POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1 29 | 30 | POSE_HIGH_RESOLUTION_NET.STAGE2 = CN() 31 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1 32 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2 33 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4] 34 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64] 35 | POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC' 36 | POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM' 37 | 38 | POSE_HIGH_RESOLUTION_NET.STAGE3 = CN() 39 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1 40 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3 41 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4] 42 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128] 43 | POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC' 44 | POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM' 45 | 46 | POSE_HIGH_RESOLUTION_NET.STAGE4 = CN() 47 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1 48 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4 49 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 50 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] 51 | POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC' 52 | POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM' 53 | 54 | 55 | MODEL_EXTRAS = { 56 | 'pose_resnet': POSE_RESNET, 57 | 'pose_high_resolution_net': POSE_HIGH_RESOLUTION_NET, 58 | } 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 
20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /models/hrnet_config/w32_384x288_adam_lr1e-3.yaml: -------------------------------------------------------------------------------- 1 | AUTO_RESUME: true 2 | CUDNN: 3 | BENCHMARK: true 4 | DETERMINISTIC: false 5 | ENABLED: true 6 | DATA_DIR: '' 7 | GPUS: (0,1,2,3) 8 | OUTPUT_DIR: 'output' 9 | LOG_DIR: 'log' 10 | WORKERS: 24 11 | PRINT_FREQ: 100 12 | 13 | DATASET: 14 | COLOR_RGB: true 15 | DATASET: 'coco' 16 | DATA_FORMAT: jpg 17 | FLIP: true 18 | NUM_JOINTS_HALF_BODY: 8 19 | PROB_HALF_BODY: 0.3 20 | ROOT: 'data/coco/' 21 | ROT_FACTOR: 45 22 | SCALE_FACTOR: 0.35 23 | TEST_SET: 'val2017' 24 | TRAIN_SET: 'train2017' 25 | MODEL: 26 | INIT_WEIGHTS: true 27 | NAME: pose_hrnet 28 | NUM_JOINTS: 1 29 | PRETRAINED: 'models/imagenet/hrnet_w32-36af842e.pth' 30 | TARGET_TYPE: gaussian 31 | IMAGE_SIZE: 32 | - 288 33 | - 384 34 | HEATMAP_SIZE: 35 | - 72 36 | - 96 37 | SIGMA: 3 38 | EXTRA: 39 | PRETRAINED_LAYERS: 40 | - 'conv1' 41 | - 'bn1' 42 | - 'conv2' 43 | - 'bn2' 44 | - 'layer1' 45 | - 'transition1' 46 | - 'stage2' 47 | - 'transition2' 48 | - 'stage3' 49 | - 'transition3' 50 | - 'stage4' 51 | FINAL_CONV_KERNEL: 1 52 | STAGE2: 53 | NUM_MODULES: 1 54 | NUM_BRANCHES: 2 55 | BLOCK: BASIC 56 | NUM_BLOCKS: 57 | - 4 58 | - 4 59 | NUM_CHANNELS: 60 | - 32 61 | - 64 62 | FUSE_METHOD: SUM 63 | STAGE3: 64 | NUM_MODULES: 4 
65 | NUM_BRANCHES: 3 66 | BLOCK: BASIC 67 | NUM_BLOCKS: 68 | - 4 69 | - 4 70 | - 4 71 | NUM_CHANNELS: 72 | - 32 73 | - 64 74 | - 128 75 | FUSE_METHOD: SUM 76 | STAGE4: 77 | NUM_MODULES: 3 78 | NUM_BRANCHES: 4 79 | BLOCK: BASIC 80 | NUM_BLOCKS: 81 | - 4 82 | - 4 83 | - 4 84 | - 4 85 | NUM_CHANNELS: 86 | - 32 87 | - 64 88 | - 128 89 | - 256 90 | FUSE_METHOD: SUM 91 | LOSS: 92 | USE_TARGET_WEIGHT: true 93 | TRAIN: 94 | BATCH_SIZE_PER_GPU: 32 95 | SHUFFLE: true 96 | BEGIN_EPOCH: 0 97 | END_EPOCH: 210 98 | OPTIMIZER: adam 99 | LR: 0.001 100 | LR_FACTOR: 0.1 101 | LR_STEP: 102 | - 170 103 | - 200 104 | WD: 0.0001 105 | GAMMA1: 0.99 106 | GAMMA2: 0.0 107 | MOMENTUM: 0.9 108 | NESTEROV: false 109 | TEST: 110 | BATCH_SIZE_PER_GPU: 32 111 | COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json' 112 | BBOX_THRE: 1.0 113 | IMAGE_THRE: 0.0 114 | IN_VIS_THRE: 0.2 115 | MODEL_FILE: '' 116 | NMS_THRE: 1.0 117 | OKS_THRE: 0.9 118 | USE_GT_BBOX: true 119 | FLIP_TEST: true 120 | POST_PROCESS: true 121 | SHIFT_HEATMAP: true 122 | DEBUG: 123 | DEBUG: true 124 | SAVE_BATCH_IMAGES_GT: true 125 | SAVE_BATCH_IMAGES_PRED: true 126 | SAVE_HEATMAPS_GT: true 127 | SAVE_HEATMAPS_PRED: true 128 | -------------------------------------------------------------------------------- /models/hrnet_config/w48_384x288_adam_lr1e-3.yaml: -------------------------------------------------------------------------------- 1 | AUTO_RESUME: true 2 | CUDNN: 3 | BENCHMARK: true 4 | DETERMINISTIC: false 5 | ENABLED: true 6 | DATA_DIR: '' 7 | GPUS: (0,1,2,3) 8 | OUTPUT_DIR: 'output' 9 | LOG_DIR: 'log' 10 | WORKERS: 24 11 | PRINT_FREQ: 100 12 | 13 | DATASET: 14 | COLOR_RGB: true 15 | DATASET: 'coco' 16 | DATA_FORMAT: jpg 17 | FLIP: true 18 | NUM_JOINTS_HALF_BODY: 8 19 | PROB_HALF_BODY: 0.3 20 | ROOT: 'data/coco/' 21 | ROT_FACTOR: 45 22 | SCALE_FACTOR: 0.35 23 | TEST_SET: 'val2017' 24 | TRAIN_SET: 'train2017' 25 | MODEL: 26 | INIT_WEIGHTS: true 27 | NAME: pose_hrnet 28 | NUM_JOINTS: 1 29 | PRETRAINED: 'models/imagenet/hrnet_w48-8ef0771d.pth' 30 | TARGET_TYPE: gaussian 31 | IMAGE_SIZE: 32 | - 288 33 | - 384 34 | HEATMAP_SIZE: 35 | - 72 36 | - 96 37 | SIGMA: 3 38 | EXTRA: 39 | PRETRAINED_LAYERS: 40 | - 'conv1' 41 | - 'bn1' 42 | - 'conv2' 43 | - 'bn2' 44 | - 'layer1' 45 | - 'transition1' 46 | - 'stage2' 47 | - 'transition2' 48 | - 'stage3' 49 | - 'transition3' 50 | - 'stage4' 51 | FINAL_CONV_KERNEL: 1 52 | STAGE2: 53 | NUM_MODULES: 1 54 | NUM_BRANCHES: 2 55 | BLOCK: BASIC 56 | NUM_BLOCKS: 57 | - 4 58 | - 4 59 | NUM_CHANNELS: 60 | - 48 61 | - 96 62 | FUSE_METHOD: SUM 63 | STAGE3: 64 | NUM_MODULES: 4 65 | NUM_BRANCHES: 3 66 | BLOCK: BASIC 67 | NUM_BLOCKS: 68 | - 4 69 | - 4 70 | - 4 71 | NUM_CHANNELS: 72 | - 48 73 | - 96 74 | - 192 75 | FUSE_METHOD: SUM 76 | STAGE4: 77 | NUM_MODULES: 3 78 | NUM_BRANCHES: 4 79 | BLOCK: BASIC 80 | NUM_BLOCKS: 81 | - 4 82 | - 4 83 | - 4 84 | - 4 85 | NUM_CHANNELS: 86 | - 48 87 | - 96 88 | - 192 89 | - 384 90 | FUSE_METHOD: SUM 91 | LOSS: 92 | USE_TARGET_WEIGHT: true 93 | TRAIN: 94 | BATCH_SIZE_PER_GPU: 24 95 | SHUFFLE: true 96 | BEGIN_EPOCH: 0 97 | END_EPOCH: 210 98 | OPTIMIZER: adam 99 | LR: 0.001 100 | LR_FACTOR: 0.1 101 | LR_STEP: 102 | - 170 103 | - 200 104 | WD: 0.0001 105 | GAMMA1: 0.99 106 | GAMMA2: 0.0 107 | MOMENTUM: 0.9 108 | NESTEROV: false 109 | TEST: 110 | BATCH_SIZE_PER_GPU: 24 111 | COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json' 112 | BBOX_THRE: 1.0 113 | IMAGE_THRE: 0.0 114 | IN_VIS_THRE: 0.2 115 | MODEL_FILE: '' 
116 | NMS_THRE: 1.0 117 | OKS_THRE: 0.9 118 | USE_GT_BBOX: true 119 | FLIP_TEST: true 120 | POST_PROCESS: true 121 | SHIFT_HEATMAP: true 122 | DEBUG: 123 | DEBUG: true 124 | SAVE_BATCH_IMAGES_GT: true 125 | SAVE_BATCH_IMAGES_PRED: true 126 | SAVE_HEATMAPS_GT: true 127 | SAVE_HEATMAPS_PRED: true 128 | -------------------------------------------------------------------------------- /PAM/utils/LoadData.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/utils/LoadData.py 3 | # ------------------------------------------------------------------------------ 4 | 5 | from .transforms import transforms 6 | from torch.utils.data import DataLoader 7 | import torch 8 | import numpy as np 9 | from torch.utils.data import Dataset 10 | import os 11 | from PIL import Image 12 | 13 | def train_data_loader(args): 14 | mean_vals = [0.485, 0.456, 0.406] 15 | std_vals = [0.229, 0.224, 0.225] 16 | 17 | input_size = int(args.input_size) 18 | crop_size = int(args.crop_size) 19 | tsfm_train = transforms.Compose( 20 | [ 21 | transforms.Resize(input_size), 22 | transforms.RandomHorizontalFlip(), 23 | transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), 24 | transforms.RandomCrop(crop_size), 25 | transforms.ToTensor(), 26 | transforms.Normalize(mean_vals, std_vals), 27 | ] 28 | ) 29 | 30 | train_list = os.path.join(args.root_dir, "ImageSets/Segmentation/train_cls.txt") 31 | 32 | img_train = VOCDataset(train_list, crop_size, root_dir=args.root_dir, num_classes=args.num_classes, transform=tsfm_train, mode='train') 33 | 34 | train_loader = DataLoader(img_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 35 | 36 | return train_loader 37 | 38 | 39 | def test_data_loader(args): 40 | mean_vals = [0.485, 0.456, 0.406] 41 | std_vals = [0.229, 0.224, 0.225] 42 | 43 | input_size = int(args.input_size) 44 | crop_size = int(args.crop_size) 45 | 46 | tsfm_test = transforms.Compose( 47 | [ 48 | transforms.Resize(crop_size), 49 | transforms.ToTensor(), 50 | transforms.Normalize(mean_vals, std_vals), 51 | ] 52 | ) 53 | 54 | test_list = os.path.join(args.root_dir, "ImageSets/Segmentation/train_cls.txt") 55 | 56 | img_test = VOCDataset(test_list, crop_size, root_dir=args.root_dir, num_classes=args.num_classes, transform=tsfm_test, mode='test') 57 | 58 | test_loader = DataLoader(img_test, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 59 | 60 | return test_loader 61 | 62 | 63 | 64 | class VOCDataset(Dataset): 65 | def __init__(self, datalist_file, input_size, root_dir, num_classes=20, transform=None, mode='train'): 66 | self.root_dir = root_dir 67 | self.mode = mode 68 | self.datalist_file = datalist_file 69 | self.transform = transform 70 | self.num_classes = num_classes 71 | 72 | self.image_list, self.label_list = self.read_labeled_image_list(self.root_dir, self.datalist_file) 73 | 74 | 75 | def __len__(self): 76 | return len(self.image_list) 77 | 78 | 79 | def __getitem__(self, idx): 80 | img_name = self.image_list[idx] 81 | image = Image.open(img_name).convert('RGB') 82 | 83 | meta = {"img_name": img_name, "ori_size": image.size} 84 | 85 | if self.transform is not None: 86 | image = self.transform(image) 87 | 88 | return image, self.label_list[idx], meta 89 | 90 | 91 | def read_labeled_image_list(self, data_dir, data_list): 92 | img_dir = os.path.join(data_dir, "JPEGImages") 93 
| 94 | with open(data_list, 'r') as f: 95 | lines = f.readlines() 96 | 97 | img_name_list = [] 98 | img_labels = [] 99 | 100 | for line in lines: 101 | fields = line.strip().split() 102 | image = fields[0] + '.jpg' 103 | 104 | labels = np.zeros((self.num_classes,), dtype=np.float32) 105 | for i in range(len(fields)-1): 106 | index = int(fields[i+1]) 107 | labels[index] = 1. 108 | 109 | img_name_list.append(os.path.join(img_dir, image)) 110 | img_labels.append(labels) 111 | 112 | return img_name_list, img_labels 113 | -------------------------------------------------------------------------------- /PAM/point_extraction.py: -------------------------------------------------------------------------------- 1 | """ 2 | BESTIE 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | 7 | import os 8 | import numpy as np 9 | import torch 10 | import argparse 11 | import torch.nn as nn 12 | 13 | from models.classifier import vgg16_pam 14 | from utils.LoadData import test_data_loader 15 | 16 | def get_arguments(): 17 | parser = argparse.ArgumentParser(description='PAM pytorch implement') 18 | parser.add_argument("--root_dir", type=str, default='', help='Directory of training images') 19 | parser.add_argument("--dataset", type=str, default='voc') 20 | parser.add_argument("--batch_size", type=int, default=1) 21 | parser.add_argument("--input_size", type=int, default=384) 22 | parser.add_argument("--crop_size", type=int, default=321) 23 | parser.add_argument("--num_classes", type=int, default=20) 24 | parser.add_argument("--num_workers", type=int, default=2) 25 | parser.add_argument('--checkpoint', default='checkpoints/PAM/ckpt_15.pth', help='Location to save checkpoint file') 26 | parser.add_argument('--save_dir', default='Peak_points', help='save dir for peak points') 27 | parser.add_argument("--alpha", type=float, default=0.7, help='hyperparameter for PAM (controller)') 28 | parser.add_argument("--conf_thresh", type=float, default=0.1, help='peak threshold') 29 | 30 | return parser.parse_args() 31 | 32 | 33 | def peak_extract(heat, kernel=5, K=25): 34 | B, C, H, W = heat.size() 35 | 36 | pad = (kernel - 1) // 2 37 | 38 | hmax = torch.nn.functional.max_pool2d( 39 | heat, (kernel, kernel), stride=1, padding=pad) 40 | 41 | keep = (hmax == heat).float() 42 | 43 | peak = heat * keep 44 | 45 | topk_scores, topk_inds = torch.topk(peak.view(B, C, -1), K) 46 | 47 | topk_inds = topk_inds % (H * W) 48 | topk_ys = (topk_inds / W).int().float() 49 | topk_xs = (topk_inds % W).int().float() 50 | 51 | topk_scores = topk_scores[0].float().detach().cpu().numpy() 52 | topk_ys = topk_ys[0].int().detach().cpu().numpy() 53 | topk_xs = topk_xs[0].int().detach().cpu().numpy() 54 | 55 | return topk_scores, topk_ys, topk_xs 56 | 57 | 58 | def smoothing(heat, kernel=3): 59 | pad = (kernel - 1) // 2 60 | heat = torch.nn.functional.avg_pool2d(heat, (kernel, kernel), stride=1, padding=pad) 61 | 62 | return heat 63 | 64 | 65 | if __name__ == '__main__': 66 | args = get_arguments() 67 | os.makedirs(args.save_dir, exist_ok=True) 68 | 69 | model = vgg16_pam(alpha=args.alpha) 70 | 71 | state = torch.load(args.checkpoint, map_location='cpu') 72 | model.load_state_dict(state['model'], strict=True) 73 | 74 | model.eval() 75 | model.cuda() 76 | 77 | data_loader = test_data_loader(args) 78 | 79 | with torch.no_grad(): 80 | 81 | for idx, (img, label, meta) in enumerate(data_loader): 82 | img_name = meta['img_name'][0] 83 | ori_W, ori_H = int(meta['ori_size'][0]), int(meta['ori_size'][1]) 84 | print("[%03d/%03d] %s" % (idx, 
len(data_loader), img_name), end='\r') 85 | 86 | label = label.to('cuda', non_blocking=True) 87 | img = img.to('cuda', non_blocking=True) 88 | 89 | # flip TTA 90 | _img = torch.cat( [img, img.flip(-1)] , dim=0) 91 | _label = torch.cat( [label, label] , dim=0) 92 | 93 | _, cam = model(_img, _label, (ori_H, ori_W)) 94 | 95 | cam = (cam[0:1] + cam[1:2].flip(-1)) / 2. 96 | cam = smoothing(cam) 97 | 98 | peak_conf, peak_y, peak_x = peak_extract(cam, kernel=15) 99 | 100 | ##################################################################### 101 | 102 | img_name = img_name.split("/")[-1][:-4] 103 | label = label[0].cpu().detach().numpy() 104 | valid_label = np.nonzero(label)[0] 105 | 106 | with open(os.path.join(args.save_dir, "%s.txt" % img_name), 'w') as peak_txt: 107 | for l in valid_label: 108 | for conf, x, y in zip(peak_conf[l], peak_x[l], peak_y[l]): 109 | if conf < args.conf_thresh: 110 | break 111 | 112 | peak_txt.write("%d %d %d %.3f\n" % (x, y, l, conf)) 113 | 114 | -------------------------------------------------------------------------------- /models/hrnet_config/default.py: -------------------------------------------------------------------------------- 1 | 2 | # ------------------------------------------------------------------------------ 3 | # Copyright (c) Microsoft 4 | # Licensed under the MIT License. 5 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | 14 | from yacs.config import CfgNode as CN 15 | 16 | 17 | _C = CN() 18 | 19 | _C.OUTPUT_DIR = '' 20 | _C.LOG_DIR = '' 21 | _C.DATA_DIR = '' 22 | _C.GPUS = (0,) 23 | _C.WORKERS = 4 24 | _C.PRINT_FREQ = 20 25 | _C.AUTO_RESUME = False 26 | _C.PIN_MEMORY = True 27 | _C.RANK = 0 28 | 29 | # Cudnn related params 30 | _C.CUDNN = CN() 31 | _C.CUDNN.BENCHMARK = True 32 | _C.CUDNN.DETERMINISTIC = False 33 | _C.CUDNN.ENABLED = True 34 | 35 | # common params for NETWORK 36 | _C.MODEL = CN() 37 | _C.MODEL.NAME = 'pose_hrnet' 38 | _C.MODEL.INIT_WEIGHTS = True 39 | _C.MODEL.PRETRAINED = '' 40 | _C.MODEL.NUM_JOINTS = 1 41 | _C.MODEL.TAG_PER_JOINT = True 42 | _C.MODEL.TARGET_TYPE = 'gaussian' 43 | _C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256 44 | _C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32 45 | _C.MODEL.SIGMA = 2 46 | _C.MODEL.EXTRA = CN(new_allowed=True) 47 | 48 | _C.LOSS = CN() 49 | _C.LOSS.USE_OHKM = False 50 | _C.LOSS.TOPK = 8 51 | _C.LOSS.USE_TARGET_WEIGHT = True 52 | _C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False 53 | 54 | # DATASET related params 55 | _C.DATASET = CN() 56 | _C.DATASET.ROOT = '' 57 | _C.DATASET.DATASET = 'mpii' 58 | _C.DATASET.TRAIN_SET = 'train' 59 | _C.DATASET.TEST_SET = 'valid' 60 | _C.DATASET.DATA_FORMAT = 'jpg' 61 | _C.DATASET.HYBRID_JOINTS_TYPE = '' 62 | _C.DATASET.SELECT_DATA = False 63 | 64 | # training data augmentation 65 | _C.DATASET.FLIP = True 66 | _C.DATASET.SCALE_FACTOR = 0.25 67 | _C.DATASET.ROT_FACTOR = 30 68 | _C.DATASET.PROB_HALF_BODY = 0.0 69 | _C.DATASET.NUM_JOINTS_HALF_BODY = 8 70 | _C.DATASET.COLOR_RGB = False 71 | 72 | # train 73 | _C.TRAIN = CN() 74 | 75 | _C.TRAIN.LR_FACTOR = 0.1 76 | _C.TRAIN.LR_STEP = [90, 110] 77 | _C.TRAIN.LR = 0.001 78 | 79 | _C.TRAIN.OPTIMIZER = 'adam' 80 | _C.TRAIN.MOMENTUM = 0.9 81 | _C.TRAIN.WD = 0.0001 82 | _C.TRAIN.NESTEROV = False 83 | _C.TRAIN.GAMMA1 = 0.99 84 | _C.TRAIN.GAMMA2 = 0.0 85 | 86 | 
_C.TRAIN.BEGIN_EPOCH = 0 87 | _C.TRAIN.END_EPOCH = 140 88 | 89 | _C.TRAIN.RESUME = False 90 | _C.TRAIN.CHECKPOINT = '' 91 | 92 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32 93 | _C.TRAIN.SHUFFLE = True 94 | 95 | # testing 96 | _C.TEST = CN() 97 | 98 | # size of images for each device 99 | _C.TEST.BATCH_SIZE_PER_GPU = 32 100 | # Test Model Epoch 101 | _C.TEST.FLIP_TEST = False 102 | _C.TEST.POST_PROCESS = False 103 | _C.TEST.SHIFT_HEATMAP = False 104 | 105 | _C.TEST.USE_GT_BBOX = False 106 | 107 | # nms 108 | _C.TEST.IMAGE_THRE = 0.1 109 | _C.TEST.NMS_THRE = 0.6 110 | _C.TEST.SOFT_NMS = False 111 | _C.TEST.OKS_THRE = 0.5 112 | _C.TEST.IN_VIS_THRE = 0.0 113 | _C.TEST.COCO_BBOX_FILE = '' 114 | _C.TEST.BBOX_THRE = 1.0 115 | _C.TEST.MODEL_FILE = '' 116 | 117 | # debug 118 | _C.DEBUG = CN() 119 | _C.DEBUG.DEBUG = False 120 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False 121 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False 122 | _C.DEBUG.SAVE_HEATMAPS_GT = False 123 | _C.DEBUG.SAVE_HEATMAPS_PRED = False 124 | 125 | def update_config(cfg, yaml): 126 | cfg.defrost() 127 | cfg.merge_from_file(yaml) 128 | 129 | 130 | # def update_config(cfg, args): 131 | # cfg.defrost() 132 | # cfg.merge_from_file(args.cfg) 133 | # cfg.merge_from_list(args.opts) 134 | 135 | # if args.modelDir: 136 | # cfg.OUTPUT_DIR = args.modelDir 137 | 138 | # if args.logDir: 139 | # cfg.LOG_DIR = args.logDir 140 | 141 | # if args.dataDir: 142 | # cfg.DATA_DIR = args.dataDir 143 | 144 | # cfg.DATASET.ROOT = os.path.join( 145 | # cfg.DATA_DIR, cfg.DATASET.ROOT 146 | # ) 147 | 148 | # cfg.MODEL.PRETRAINED = os.path.join( 149 | # cfg.DATA_DIR, cfg.MODEL.PRETRAINED 150 | # ) 151 | 152 | # if cfg.TEST.MODEL_FILE: 153 | # cfg.TEST.MODEL_FILE = os.path.join( 154 | # cfg.DATA_DIR, cfg.TEST.MODEL_FILE 155 | # ) 156 | 157 | 158 | if __name__ == '__main__': 159 | import sys 160 | with open(sys.argv[1], 'w') as f: 161 | print(_C, file=f) 162 | 163 | -------------------------------------------------------------------------------- /PAM/train.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/scripts/train_cls.py 3 | # ------------------------------------------------------------------------------ 4 | """ 5 | BESTIE 6 | Copyright (c) 2022-present NAVER Corp. 
7 | MIT License 8 | """ 9 | 10 | import os 11 | import numpy as np 12 | import torch 13 | import argparse 14 | 15 | import torch.optim as optim 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | from models.classifier import vgg16_pam 20 | from utils.my_optim import reduce_lr 21 | from utils.avgMeter import AverageMeter 22 | from utils.LoadData import train_data_loader 23 | from utils.Metrics import Cls_Accuracy 24 | 25 | 26 | def get_arguments(): 27 | parser = argparse.ArgumentParser(description='PAM pytorch implement') 28 | parser.add_argument("--root_dir", type=str, default='', help='Directory of training images') 29 | parser.add_argument("--dataset", type=str, default='voc') 30 | parser.add_argument("--batch_size", type=int, default=5) 31 | parser.add_argument("--input_size", type=int, default=384) 32 | parser.add_argument("--crop_size", type=int, default=321) 33 | parser.add_argument("--num_classes", type=int, default=20) 34 | parser.add_argument("--lr", type=float, default=0.001) 35 | parser.add_argument("--weight_decay", type=float, default=0.0005) 36 | parser.add_argument("--decay_points", type=str, default='5,10') 37 | parser.add_argument("--epoch", type=int, default=15) 38 | parser.add_argument("--num_workers", type=int, default=2) 39 | parser.add_argument('--show_interval', default=50, type=int, help='interval of showing training conditions') 40 | parser.add_argument('--save_interval', default=5, type=int, help='interval of save checkpoint models') 41 | parser.add_argument('--save_folder', default='checkpoints/test', help='Location to save checkpoint models') 42 | parser.add_argument("--alpha", type=float, default=0.7, help='hyperparameter for PAM (controller)') 43 | 44 | return parser.parse_args() 45 | 46 | 47 | def get_model(args): 48 | model = vgg16_pam(pretrained=True, alpha=args.alpha) 49 | 50 | model = model.cuda() 51 | model = torch.nn.DataParallel(model).cuda() 52 | param_groups = model.module.get_parameter_groups() 53 | 54 | optimizer = optim.SGD( 55 | [ 56 | {'params': param_groups[0], 'lr': args.lr}, 57 | {'params': param_groups[1], 'lr': 2*args.lr}, 58 | {'params': param_groups[2], 'lr': 10*args.lr}, 59 | {'params': param_groups[3], 'lr': 20*args.lr} 60 | ], 61 | momentum=0.9, 62 | weight_decay=args.weight_decay, 63 | nesterov=True 64 | ) 65 | 66 | return model, optimizer 67 | 68 | 69 | def train(current_epoch): 70 | global curr_iter 71 | losses = AverageMeter() 72 | cls_acc_metric = Cls_Accuracy() 73 | 74 | model.train() 75 | 76 | """ learning rate decay """ 77 | res = reduce_lr(args, optimizer, current_epoch) 78 | 79 | for img, label, _ in train_loader: 80 | label = label.to('cuda', non_blocking=True) 81 | img = img.to('cuda', non_blocking=True) 82 | 83 | logit = model(img) 84 | 85 | """ classification loss """ 86 | loss = F.multilabel_soft_margin_loss(logit, label) 87 | 88 | """ backprop """ 89 | optimizer.zero_grad() 90 | loss.backward() 91 | optimizer.step() 92 | 93 | """ average meter """ 94 | cls_acc_metric.update(logit, label) 95 | losses.update(loss.item(), img.size()[0]) 96 | 97 | curr_iter += 1 98 | 99 | """ training log """ 100 | if curr_iter % args.show_interval == 0: 101 | cls_acc = cls_acc_metric.compute_avg_acc() 102 | 103 | print('Epoch: [{}][{}/{}] ' 104 | 'LR: {:.5f} ' 105 | 'ACC: {:.5f} ' 106 | 'Loss {loss.val:.4f} ({loss.avg:.4f}) '.format( 107 | current_epoch, curr_iter%len(train_loader), len(train_loader), 108 | optimizer.param_groups[0]['lr'], cls_acc, loss=losses)) 109 | 110 | 111 | if __name__ == '__main__': 112 | args 
= get_arguments() 113 | 114 | n_gpu = torch.cuda.device_count() 115 | 116 | args.batch_size *= n_gpu 117 | args.num_workers *= n_gpu 118 | 119 | print('Running parameters:\n', args) 120 | 121 | if not os.path.exists(args.save_folder): 122 | os.makedirs(args.save_folder) 123 | 124 | train_loader = train_data_loader(args) 125 | print('# of train dataset:', len(train_loader) * args.batch_size) 126 | 127 | model, optimizer = get_model(args) 128 | 129 | curr_iter = 0 130 | for current_epoch in range(1, args.epoch+1): 131 | train(current_epoch) 132 | 133 | """ save checkpoint """ 134 | if current_epoch % args.save_interval == 0 and current_epoch > 0: 135 | print('\nSaving state, epoch : %d \n' % current_epoch) 136 | state = { 137 | 'model': model.module.state_dict(), 138 | #"optimizer": optimizer.state_dict(), 139 | 'epoch': current_epoch, 140 | 'iter': curr_iter, 141 | } 142 | model_file = args.save_folder + '/ckpt_' + repr(current_epoch) + '.pth' 143 | torch.save(state, model_file) 144 | -------------------------------------------------------------------------------- /PAM/utils/Metrics.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/utils/Metrics.py 3 | # ------------------------------------------------------------------------------ 4 | 5 | import numpy as np 6 | from sklearn.metrics import confusion_matrix 7 | import torch 8 | 9 | class Cls_Accuracy(): 10 | def __init__(self, ): 11 | self.total = 0 12 | self.correct = 0 13 | 14 | 15 | def update(self, logit, label): 16 | 17 | logit = logit.sigmoid_() 18 | logit = (logit >= 0.5) 19 | all_correct = torch.all(logit == label.byte(), dim=1).float().sum().item() 20 | 21 | self.total += logit.size(0) 22 | self.correct += all_correct 23 | 24 | def compute_avg_acc(self): 25 | return self.correct / self.total 26 | 27 | 28 | 29 | class RunningConfusionMatrix(): 30 | """Running Confusion Matrix class that enables computation of confusion matrix 31 | on the go and has methods to compute such accuracy metrics as Mean Intersection over 32 | Union MIOU. 33 | 34 | Attributes 35 | ---------- 36 | labels : list[int] 37 | List that contains int values that represent classes. 38 | overall_confusion_matrix : sklean.confusion_matrix object 39 | Container of the sum of all confusion matrices. Used to compute MIOU at the end. 40 | ignore_label : int 41 | A label representing parts that should be ignored during 42 | computation of metrics 43 | 44 | """ 45 | 46 | def __init__(self, labels, ignore_label=255): 47 | 48 | self.labels = labels 49 | self.ignore_label = ignore_label 50 | self.overall_confusion_matrix = None 51 | 52 | def update_matrix(self, ground_truth, prediction): 53 | """Updates overall confusion matrix statistics. 54 | If you are working with 2D data, just .flatten() it before running this 55 | function. 
56 | Parameters 57 | ---------- 58 | groundtruth : array, shape = [n_samples] 59 | An array with groundtruth values 60 | prediction : array, shape = [n_samples] 61 | An array with predictions 62 | """ 63 | 64 | # Mask-out value is ignored by default in the sklearn 65 | # read sources to see how that was handled 66 | # But sometimes all the elements in the groundtruth can 67 | # be equal to ignore value which will cause the crush 68 | # of scikit_learn.confusion_matrix(), this is why we check it here 69 | if (ground_truth == self.ignore_label).all(): 70 | 71 | return 72 | 73 | current_confusion_matrix = confusion_matrix(y_true=ground_truth, 74 | y_pred=prediction, 75 | labels=self.labels) 76 | 77 | if self.overall_confusion_matrix is not None: 78 | 79 | self.overall_confusion_matrix += current_confusion_matrix 80 | else: 81 | 82 | self.overall_confusion_matrix = current_confusion_matrix 83 | 84 | def compute_current_mean_intersection_over_union(self): 85 | 86 | intersection = np.diag(self.overall_confusion_matrix) 87 | ground_truth_set = self.overall_confusion_matrix.sum(axis=1) 88 | predicted_set = self.overall_confusion_matrix.sum(axis=0) 89 | union = ground_truth_set + predicted_set - intersection 90 | 91 | #intersection_over_union = intersection / (union.astype(np.float32) + 1e-4) 92 | intersection_over_union = intersection / union.astype(np.float32) 93 | 94 | mean_intersection_over_union = np.mean(intersection_over_union) 95 | 96 | return mean_intersection_over_union 97 | 98 | 99 | class IOUMetric: 100 | """ 101 | Class to calculate mean-iou using fast_hist method 102 | """ 103 | 104 | def __init__(self, num_classes): 105 | self.num_classes = num_classes 106 | self.hist = np.zeros((num_classes, num_classes)) 107 | 108 | def _fast_hist(self, label_pred, label_true): 109 | mask = (label_true >= 0) & (label_true < self.num_classes) 110 | 111 | hist = np.bincount( 112 | self.num_classes*label_true[mask] + label_pred[mask], 113 | minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes) 114 | 115 | return hist 116 | 117 | def add_batch(self, predictions, gts): 118 | for lp, lt in zip(predictions, gts): 119 | self.hist += self._fast_hist(lp.flatten(), lt.flatten()) 120 | 121 | def evaluate(self): 122 | acc = np.diag(self.hist).sum() / self.hist.sum() 123 | acc_cls = np.diag(self.hist) / self.hist.sum(axis=1) 124 | acc_cls = np.nanmean(acc_cls) 125 | iu = np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0) - np.diag(self.hist)) 126 | mean_iu = np.nanmean(iu) 127 | freq = self.hist.sum(axis=1) / self.hist.sum() 128 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 129 | cls_iu = dict(zip(range(self.num_classes), iu)) 130 | 131 | return { 132 | "Pixel_Accuracy": acc, 133 | "Mean_Accuracy": acc_cls, 134 | "Frequency_Weighted_IoU": fwavacc, 135 | "Mean_IoU": mean_iu, 136 | "Class_IoU": cls_iu, 137 | } 138 | -------------------------------------------------------------------------------- /utils/my_optim.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.optim as optim 4 | from bisect import bisect_right 5 | from typing import List 6 | 7 | 8 | class PolyOptimizer(optim.SGD): 9 | 10 | def __init__(self, params, lr, weight_decay, max_step, momentum=0.9): 11 | super().__init__(params, lr, weight_decay) 12 | self.param_groups = params 13 | self.global_step = 0 14 | self.max_step = max_step 15 | self.momentum = momentum 16 | 17 | self.__initial_lr = [group['lr'] for group in 
self.param_groups] 18 | 19 | 20 | def step(self, closure=None): 21 | 22 | if self.global_step < self.max_step: 23 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum 24 | 25 | for i in range(len(self.param_groups)): 26 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult 27 | super().step(closure) 28 | 29 | self.global_step += 1 30 | 31 | 32 | def lr_poly(base_lr, iter,max_iter,power=0.9): 33 | return base_lr*((1-float(iter)/max_iter)**(power)) 34 | 35 | def reduce_lr_poly(args, optimizer, global_iter, max_iter): 36 | base_lr = args.lr 37 | for g in optimizer.param_groups: 38 | g['lr'] = lr_poly(base_lr=base_lr, iter=global_iter, max_iter=max_iter, power=0.9) 39 | 40 | 41 | def reduce_lr(args, optimizer, epoch, factor=0.1): 42 | values = args.decay_points.strip().split(',') 43 | try: 44 | change_points = map(lambda x: int(x.strip()), values) 45 | except ValueError: 46 | change_points = None 47 | 48 | if change_points is not None and epoch in change_points: 49 | for g in optimizer.param_groups: 50 | g['lr'] = g['lr']*factor 51 | print("Reduce Learning Rate : ", epoch, g['lr']) 52 | return True 53 | 54 | 55 | 56 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler): 57 | def __init__( 58 | self, 59 | optimizer: torch.optim.Optimizer, 60 | max_iters: int, 61 | warmup_factor: float = 0.001, 62 | warmup_iters: int = 1000, 63 | warmup_method: str = "linear", 64 | last_epoch: int = -1, 65 | power: float = 0.9, 66 | constant_ending: float = 0. 67 | ): 68 | self.max_iters = max_iters 69 | self.warmup_factor = warmup_factor 70 | self.warmup_iters = warmup_iters 71 | self.warmup_method = warmup_method 72 | self.power = power 73 | self.constant_ending = constant_ending 74 | super().__init__(optimizer, last_epoch) 75 | 76 | def get_lr(self) -> List[float]: 77 | warmup_factor = _get_warmup_factor_at_iter( 78 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 79 | ) 80 | if self.constant_ending > 0 and warmup_factor == 1.: 81 | # Constant ending lr. 82 | if math.pow((1.0 - self.last_epoch / self.max_iters), self.power) < self.constant_ending: 83 | return [ 84 | base_lr 85 | * self.constant_ending 86 | for base_lr in self.base_lrs 87 | ] 88 | return [ 89 | base_lr 90 | * warmup_factor 91 | * math.pow((1.0 - self.last_epoch / self.max_iters), self.power) 92 | for base_lr in self.base_lrs 93 | ] 94 | 95 | def _compute_values(self) -> List[float]: 96 | # The new interface 97 | return self.get_lr() 98 | 99 | 100 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 101 | def __init__( 102 | self, 103 | optimizer: torch.optim.Optimizer, 104 | milestones: List[int], 105 | gamma: float = 0.1, 106 | warmup_factor: float = 0.001, 107 | warmup_iters: int = 1000, 108 | warmup_method: str = "linear", 109 | last_epoch: int = -1, 110 | ): 111 | if not list(milestones) == sorted(milestones): 112 | raise ValueError( 113 | "Milestones should be a list of" " increasing integers. 
Got {}", milestones 114 | ) 115 | self.milestones = milestones 116 | self.gamma = gamma 117 | self.warmup_factor = warmup_factor 118 | self.warmup_iters = warmup_iters 119 | self.warmup_method = warmup_method 120 | super().__init__(optimizer, last_epoch) 121 | 122 | def get_lr(self) -> List[float]: 123 | warmup_factor = _get_warmup_factor_at_iter( 124 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 125 | ) 126 | return [ 127 | base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) 128 | for base_lr in self.base_lrs 129 | ] 130 | 131 | def _compute_values(self) -> List[float]: 132 | # The new interface 133 | return self.get_lr() 134 | 135 | 136 | def _get_warmup_factor_at_iter( 137 | method: str, iter: int, warmup_iters: int, warmup_factor: float 138 | ) -> float: 139 | """ 140 | Return the learning rate warmup factor at a specific iteration. 141 | See https://arxiv.org/abs/1706.02677 for more details. 142 | Args: 143 | method (str): warmup method; either "constant" or "linear". 144 | iter (int): iteration at which to calculate the warmup factor. 145 | warmup_iters (int): the number of warmup iterations. 146 | warmup_factor (float): the base warmup factor (the meaning changes according 147 | to the method used). 148 | Returns: 149 | float: the effective warmup factor at the given iteration. 150 | """ 151 | if iter >= warmup_iters: 152 | return 1.0 153 | 154 | if method == "constant": 155 | return warmup_factor 156 | elif method == "linear": 157 | alpha = iter / warmup_iters 158 | return warmup_factor * (1 - alpha) + alpha 159 | else: 160 | raise ValueError("Unknown warmup method: {}".format(method)) -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | BESTIE 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT License 5 | """ 6 | 7 | import torch 8 | 9 | 10 | class L1_Loss(torch.nn.Module): 11 | ''' L1 loss for Offset map (without Instance-aware Guidance)''' 12 | def __init__(self): 13 | super(L1_Loss, self).__init__() 14 | self.l1_loss = torch.nn.L1Loss(reduction='mean') 15 | 16 | def forward(self, out, target, weight): 17 | loss = self.l1_loss(out, target) 18 | 19 | return loss 20 | 21 | 22 | class Weighted_L1_Loss(torch.nn.Module): 23 | ''' Weighted L1 loss for Offset map (with Instance-aware Guidance)''' 24 | def __init__(self): 25 | super(Weighted_L1_Loss, self).__init__() 26 | self.l1_loss = torch.nn.L1Loss(reduction='none') 27 | 28 | def forward(self, out, target, weight): 29 | loss = self.l1_loss(out, target) * weight 30 | 31 | if weight.sum() > 0: 32 | loss = loss.sum() / (weight > 0).float().sum() 33 | else: 34 | loss = loss.sum() * 0 35 | 36 | return loss 37 | 38 | 39 | class MSELoss(torch.nn.Module): 40 | ''' MSE loss for center map (without Instance-aware Guidance)''' 41 | def __init__(self): 42 | super(MSELoss, self).__init__() 43 | self.mse_loss = torch.nn.MSELoss(reduction='mean') 44 | 45 | def forward(self, out, target, weight): 46 | 47 | loss = self.mse_loss(out, target) 48 | 49 | return loss 50 | 51 | 52 | class Weighted_MSELoss(torch.nn.Module): 53 | ''' MSE loss for center map (with Instance-aware Guidance)''' 54 | def __init__(self): 55 | super(Weighted_MSELoss, self).__init__() 56 | self.mse_loss = torch.nn.MSELoss(reduction='none') 57 | 58 | def forward(self, out, target, weight): 59 | 60 | loss = self.mse_loss(out, target) * weight 61 | 62 | if weight.sum() > 0: 63 | loss = loss.sum() / (weight > 0).float().sum() 64 | else: 65 | loss = loss.sum() * 0 66 | 67 | return loss 68 | 69 | 70 | class DeepLabCE(torch.nn.Module): 71 | """ 72 | Hard pixel mining mining with cross entropy loss, for semantic segmentation. 73 | This is used in TensorFlow DeepLab frameworks. 74 | Reference: https://github.com/tensorflow/models/blob/bd488858d610e44df69da6f89277e9de8a03722c/research/deeplab/utils/train_utils.py#L33 75 | Arguments: 76 | ignore_label: Integer, label to ignore. 77 | top_k_percent_pixels: Float, the value lies in [0.0, 1.0]. When its value < 1.0, only compute the loss for 78 | the top k percent pixels (e.g., the top 20% pixels). This is useful for hard pixel mining. 79 | weight: Tensor, a manual rescaling weight given to each class. 80 | """ 81 | def __init__(self, ignore_label=255, top_k_percent_pixels=0.2, weight=None): 82 | super(DeepLabCE, self).__init__() 83 | self.top_k_percent_pixels = top_k_percent_pixels 84 | self.ignore_label = ignore_label 85 | self.criterion = torch.nn.CrossEntropyLoss(weight=weight, 86 | ignore_index=ignore_label, 87 | reduction='none') 88 | 89 | def forward(self, logits, labels): 90 | 91 | pixel_losses = self.criterion(logits, labels).contiguous().view(-1) 92 | 93 | if self.top_k_percent_pixels == 1.0: 94 | return pixel_losses.mean() 95 | 96 | top_k_pixels = int(self.top_k_percent_pixels * pixel_losses.numel()) 97 | pixel_losses, _ = torch.topk(pixel_losses, top_k_pixels) 98 | 99 | return pixel_losses.mean() 100 | 101 | 102 | class RegularCE(torch.nn.Module): 103 | """ 104 | Regular cross entropy loss for semantic segmentation, support pixel-wise loss weight. 105 | Arguments: 106 | ignore_label: Integer, label to ignore. 107 | weight: Tensor, a manual rescaling weight given to each class. 
108 | """ 109 | def __init__(self, ignore_label=255, weight=None): 110 | super(RegularCE, self).__init__() 111 | self.ignore_label = ignore_label 112 | self.criterion = torch.nn.CrossEntropyLoss(weight=weight, 113 | ignore_index=ignore_label, 114 | reduction='none') 115 | 116 | def forward(self, logits, labels): 117 | pixel_losses = self.criterion(logits, labels) 118 | 119 | mask = (labels != self.ignore_label) 120 | 121 | if mask.sum() > 0: 122 | pixel_losses = pixel_losses.sum() / mask.sum() 123 | else: 124 | pixel_losses = pixel_losses.sum() * 0 125 | 126 | return pixel_losses 127 | 128 | 129 | 130 | def _neg_loss(pred, gt, weight): 131 | ''' Modified focal loss. Exactly the same as CornerNet. 132 | Runs faster and costs a little bit more memory 133 | Arguments: 134 | pred (batch x c x h x w) 135 | gt_regr (batch x c x h x w) 136 | ''' 137 | pos_inds = gt.eq(1).float() 138 | neg_inds = gt.lt(1).float() 139 | 140 | neg_weights = torch.pow(1 - gt, 4) 141 | 142 | loss = 0 143 | 144 | pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds * weight 145 | neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds * weight 146 | 147 | num_pos = pos_inds.float().sum() 148 | pos_loss = pos_loss.sum() 149 | neg_loss = neg_loss.sum() 150 | 151 | if num_pos == 0: 152 | loss = loss - neg_loss 153 | else: 154 | loss = loss - (pos_loss + neg_loss) / num_pos 155 | return loss 156 | 157 | 158 | class FocalLoss(torch.nn.Module): 159 | '''nn.Module warpper for focal loss''' 160 | def __init__(self): 161 | super(FocalLoss, self).__init__() 162 | self.neg_loss = _neg_loss 163 | 164 | def forward(self, out, target, weight): 165 | return self.neg_loss(out, target, weight) 166 | 167 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | BESTIE 2 | Copyright (c) 2022-present NAVER Corp. 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. 21 | 22 | -------------------------------------------------------------------------------------- 23 | 24 | This project contains subcomponents with separate copyright notices and license terms. 25 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 
26 | 27 | ===== 28 | 29 | HRNet/HRNet-Human-Pose-Estimation 30 | https://github.com/HRNet/HRNet-Human-Pose-Estimation 31 | 32 | 33 | MIT License 34 | 35 | Copyright (c) 2019 Leo Xiao 36 | 37 | Permission is hereby granted, free of charge, to any person obtaining a copy 38 | of this software and associated documentation files (the "Software"), to deal 39 | in the Software without restriction, including without limitation the rights 40 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 41 | copies of the Software, and to permit persons to whom the Software is 42 | furnished to do so, subject to the following conditions: 43 | 44 | The above copyright notice and this permission notice shall be included in all 45 | copies or substantial portions of the Software. 46 | 47 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 48 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 49 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 50 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 51 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 52 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 53 | SOFTWARE. 54 | 55 | ===== 56 | 57 | bowenc0221/panoptic-deeplab 58 | https://github.com/bowenc0221/panoptic-deeplab 59 | 60 | 61 | Licensed under the Apache License, Version 2.0 (the "License"); 62 | you may not use this file except in compliance with the License. 63 | You may obtain a copy of the License at 64 | 65 | http://www.apache.org/licenses/LICENSE-2.0 66 | 67 | Unless required by applicable law or agreed to in writing, software 68 | distributed under the License is distributed on an "AS IS" BASIS, 69 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 70 | See the License for the specific language governing permissions and 71 | limitations under the License. 72 | 73 | ===== 74 | 75 | pytorch/vision 76 | https://github.com/pytorch/vision 77 | 78 | 79 | BSD 3-Clause License 80 | 81 | Copyright (c) Soumith Chintala 2016, 82 | All rights reserved. 83 | 84 | Redistribution and use in source and binary forms, with or without 85 | modification, are permitted provided that the following conditions are met: 86 | 87 | * Redistributions of source code must retain the above copyright notice, this 88 | list of conditions and the following disclaimer. 89 | 90 | * Redistributions in binary form must reproduce the above copyright notice, 91 | this list of conditions and the following disclaimer in the documentation 92 | and/or other materials provided with the distribution. 93 | 94 | * Neither the name of the copyright holder nor the names of its 95 | contributors may be used to endorse or promote products derived from 96 | this software without specific prior written permission. 97 | 98 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 99 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 100 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 101 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 102 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 103 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 104 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 105 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 106 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 107 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 108 | 109 | ===== 110 | 111 | tensorflow/models 112 | https://github.com/tensorflow/models/ 113 | 114 | 115 | Copyright 2016, The Authors. 116 | 117 | Licensed under the Apache License, Version 2.0 (the "License"); 118 | you may not use this file except in compliance with the License. 119 | You may obtain a copy of the License at 120 | 121 | http://www.apache.org/licenses/LICENSE-2.0 122 | 123 | Unless required by applicable law or agreed to in writing, software 124 | distributed under the License is distributed on an "AS IS" BASIS, 125 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 126 | See the License for the specific language governing permissions and 127 | limitations under the License. 128 | 129 | ===== 130 | 131 | qjadud1994/DRS 132 | https://github.com/qjadud1994/DRS 133 | 134 | 135 | MIT License 136 | 137 | Copyright (c) 2021 Beom 138 | 139 | Permission is hereby granted, free of charge, to any person obtaining a copy 140 | of this software and associated documentation files (the "Software"), to deal 141 | in the Software without restriction, including without limitation the rights 142 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 143 | copies of the Software, and to permit persons to whom the Software is 144 | furnished to do so, subject to the following conditions: 145 | 146 | The above copyright notice and this permission notice shall be included in all 147 | copies or substantial portions of the Software. 148 | 149 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 150 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 151 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 152 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 153 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 154 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 155 | SOFTWARE. 156 | 157 | ===== 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BESTIE - Official Pytorch Implementation (CVPR 2022) 2 | 3 | **Beyond Semantic to Instance Segmentation: Weakly-Supervised Instance Segmentation via Semantic Knowledge Transfer and Self-Refinement (CVPR 2022)**
4 | Beomyoung Kim¹, YoungJoon Yoo¹,², Chaeeun Rhee³, Junmo Kim⁴
5 | 
6 | ¹ NAVER CLOVA
7 | ² NAVER AI Lab
8 | ³ Inha University
9 | ⁴ KAIST
10 | 
11 | [![Paper](https://img.shields.io/badge/arXiv-2109.09477-brightgreen)](https://arxiv.org/abs/2109.09477)
12 | 
13 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-semantic-to-instance-segmentation/image-level-supervised-instance-segmentation)](https://paperswithcode.com/sota/image-level-supervised-instance-segmentation?p=beyond-semantic-to-instance-segmentation)
14 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-semantic-to-instance-segmentation/image-level-supervised-instance-segmentation-2)](https://paperswithcode.com/sota/image-level-supervised-instance-segmentation-2?p=beyond-semantic-to-instance-segmentation)
15 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-semantic-to-instance-segmentation/image-level-supervised-instance-segmentation-1)](https://paperswithcode.com/sota/image-level-supervised-instance-segmentation-1?p=beyond-semantic-to-instance-segmentation)
16 | 
17 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-semantic-to-instance-segmentation/point-supervised-instance-segmentation-on)](https://paperswithcode.com/sota/point-supervised-instance-segmentation-on?p=beyond-semantic-to-instance-segmentation)
18 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-semantic-to-instance-segmentation/point-supervised-instance-segmentation-on-2)](https://paperswithcode.com/sota/point-supervised-instance-segmentation-on-2?p=beyond-semantic-to-instance-segmentation)
19 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-semantic-to-instance-segmentation/point-supervised-instance-segmentation-on-1)](https://paperswithcode.com/sota/point-supervised-instance-segmentation-on-1?p=beyond-semantic-to-instance-segmentation)
20 | 
21 | 
22 | 
23 | 
24 | # Abstract
25 | 
26 | Weakly-supervised instance segmentation (WSIS) is considered a more challenging task than weakly-supervised semantic segmentation (WSSS). Compared to WSSS, WSIS requires instance-wise localization, which is difficult to extract from image-level labels. To tackle this problem, most WSIS approaches rely on off-the-shelf proposal techniques that require pre-training with instance- or object-level labels, deviating from the fundamental definition of the fully image-level supervised setting.
27 | In this paper, we propose a novel approach with two innovative components. First, we propose *semantic knowledge transfer*, which obtains pseudo instance labels by transferring the knowledge of WSSS to WSIS while eliminating the need for off-the-shelf proposals. Second, we propose a *self-refinement* method that refines the pseudo instance labels in a self-supervised scheme and uses the refined labels for training in an online manner. Here, we discover an erroneous phenomenon, *semantic drift*, caused by instances that are missing from the pseudo instance labels and therefore categorized as background. This *semantic drift* causes confusion between background and instances during training and consequently degrades segmentation performance. We term this the *semantic drift problem* and show that our proposed *self-refinement* method eliminates it.
28 | Extensive experiments on PASCAL VOC 2012 and MS COCO demonstrate the effectiveness of our approach, which achieves considerable performance without off-the-shelf proposal techniques. The code is available at https://github.com/clovaai/BESTIE.
29 | 
30 | # Experimental Results (VOC 2012, COCO)
31 | 
32 | 
33 | 
34 | 
35 | * BESTIE (HRNet48, image-label) : 42.6 mAP50 on VOC2012 [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/BESTIE_HRNet48_image_label.pt)
36 | * BESTIE (HRNet48, point-label) : 46.7 mAP50 on VOC2012 [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/BESTIE_HRNet48_point_label.pt)
37 | 
38 | Extra Sources
39 | * PAM: [[pretrained_weight]](https://github.com/clovaai/BESTIE/releases/download/asset/PAM.pth)
40 | * HRNet-W32: [[imagenet_pretrained_weight]](https://github.com/clovaai/BESTIE/releases/download/asset/hrnet_w32-36af842e.pth)
41 | * HRNet-W48: [[imagenet_pretrained_weight]](https://github.com/clovaai/BESTIE/releases/download/asset/hrnet_w48-8ef0771d.pth)
42 | 
43 | # Qualitative Results
44 | 
45 | 
46 | 
47 | 
48 | # News
49 | 
50 | - [x] official PyTorch code release
51 | - [x] release the code for the classifier with PAM module
52 | - [ ] update training code and dataset for COCO
53 | 
54 | # How To Run
55 | 
56 | ### Requirements
57 | - torch>=1.10.1
58 | - torchvision>=0.11.2
59 | - chainercv>=0.13.1
60 | - numpy
61 | - pillow
62 | - scikit-learn
63 | - tqdm
64 | 
65 | ### Datasets
66 | 
67 | - Download the Pascal VOC2012 dataset from the [official dataset homepage](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/).
68 | - `Center_points/` (ground-truth point labels) [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/Center_points.zip)
69 | - `Peak_points/` (point labels extracted by the PAM module and image-level labels) [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/Peak_points.zip)
70 | - `WSSS_maps/` (weakly-supervised semantic segmentation outputs) [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/WSSS_maps.zip)
71 | - `SegmentationObject/` (ground-truth mask labels) [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/SegmentationObject.zip)
72 | 
73 | ```
74 | data_root/
75 | --- VOC2012/
76 | --- Annotations/
77 | --- ImageSet/
78 | --- JPEGImages/
79 | --- SegmentationObject/
80 | --- Center_points/
81 | --- Peak_points/
82 | --- WSSS_maps/
83 | ```
84 | 
85 | ### Image-level Supervised Instance Segmentation on VOC2012
86 | ```
87 | # change the data ROOT in the shell script
88 | bash scripts/run_image_labels.sh
89 | ```
90 | 
91 | ### Point Supervised Instance Segmentation on VOC2012
92 | ```
93 | # change the data ROOT in the shell script
94 | bash scripts/run_point_labels.sh
95 | ```
96 | 
97 | ### Mask R-CNN Refinement
98 | 
99 | 1. Generate COCO-style pseudo labels using the BESTIE model.
100 | 2. Train Mask R-CNN on the pseudo labels: https://github.com/facebookresearch/maskrcnn-benchmark .
101 | 
102 | 
103 | # Acknowledgement
104 | 
105 | Our implementation is based on these repositories:
106 | - (Panoptic-DeepLab) https://github.com/bowenc0221/panoptic-deeplab
107 | - (HRNet) https://github.com/HRNet/HRNet-Human-Pose-Estimation
108 | - (DRS) https://github.com/qjadud1994/DRS
109 | 
110 | # License
111 | 
112 | ```
113 | Copyright (c) 2022-present NAVER Corp.
114 | 115 | Permission is hereby granted, free of charge, to any person obtaining a copy 116 | of this software and associated documentation files (the "Software"), to deal 117 | in the Software without restriction, including without limitation the rights 118 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 119 | copies of the Software, and to permit persons to whom the Software is 120 | furnished to do so, subject to the following conditions: 121 | 122 | The above copyright notice and this permission notice shall be included in 123 | all copies or substantial portions of the Software. 124 | 125 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 126 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 127 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 128 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 129 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 130 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 131 | THE SOFTWARE. 132 | ``` 133 | -------------------------------------------------------------------------------- /PAM/models/classifier.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/models/vgg.py 3 | # ------------------------------------------------------------------------------ 4 | """ 5 | BESTIE 6 | Copyright (c) 2022-present NAVER Corp. 7 | MIT License 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.utils.model_zoo as model_zoo 13 | import torch.nn.functional as F 14 | import math 15 | import cv2 16 | import numpy as np 17 | import os 18 | 19 | model_urls = {'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth'} 20 | 21 | 22 | class PAM(nn.Module): 23 | def __init__(self, alpha): 24 | super(PAM, self).__init__() 25 | self.relu = nn.ReLU(inplace=True) 26 | self.selector = nn.AdaptiveMaxPool2d(1) 27 | self.alpha = alpha 28 | 29 | def forward(self, x): 30 | b, c, _, _ = x.size() 31 | x = self.relu(x) 32 | 33 | """ 1: selector """ 34 | peak_region = self.selector(x).view(b, c, 1, 1) 35 | peak_region = peak_region.expand_as(x) 36 | 37 | """ 2: Controller -> self.alpha""" 38 | boundary = (x < peak_region * self.alpha) 39 | 40 | """ 3: Peak Stimulator""" 41 | x = torch.where(boundary, torch.zeros_like(x), x) 42 | 43 | return x 44 | 45 | 46 | class VGG(nn.Module): 47 | 48 | def __init__(self, features, num_classes=20, 49 | alpha=0.7, init_weights=True): 50 | 51 | super(VGG, self).__init__() 52 | 53 | self.features = features 54 | 55 | self.layer1_conv1 = features[0] 56 | self.layer1_conv2 = features[2] 57 | self.layer1_maxpool = features[4] 58 | 59 | self.layer2_conv1 = features[5] 60 | self.layer2_conv2 = features[7] 61 | self.layer2_maxpool = features[9] 62 | 63 | self.layer3_conv1 = features[10] 64 | self.layer3_conv2 = features[12] 65 | self.layer3_conv3 = features[14] 66 | self.layer3_maxpool = features[16] 67 | 68 | self.layer4_conv1 = features[17] 69 | self.layer4_conv2 = features[19] 70 | self.layer4_conv3 = features[21] 71 | self.layer4_maxpool = features[23] 72 | 73 | self.layer5_conv1 = features[24] 74 | self.layer5_conv2 = features[26] 75 | self.layer5_conv3 = features[28] 76 | 77 | self.extra_conv1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 78 | self.extra_conv2 = 
nn.Conv2d(512, 512, kernel_size=3, padding=1) 79 | self.extra_conv3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 80 | self.extra_conv4 = nn.Conv2d(512, 20, kernel_size=1) 81 | 82 | self.pam = PAM(alpha) 83 | self.relu = nn.ReLU(inplace=True) 84 | 85 | if init_weights: 86 | self._initialize_weights(self.extra_conv1) 87 | self._initialize_weights(self.extra_conv2) 88 | self._initialize_weights(self.extra_conv3) 89 | self._initialize_weights(self.extra_conv4) 90 | 91 | def forward(self, x, label=None, size=None): 92 | if size is None: 93 | size = x.size()[2:] 94 | 95 | # layer1 96 | x = self.layer1_conv1(x) 97 | x = self.relu(x) 98 | x = self.layer1_conv2(x) 99 | x = self.relu(x) 100 | x = self.layer1_maxpool(x) 101 | 102 | # layer2 103 | x = self.layer2_conv1(x) 104 | x = self.relu(x) 105 | x = self.layer2_conv2(x) 106 | x = self.relu(x) 107 | x = self.layer2_maxpool(x) 108 | 109 | # layer3 110 | x = self.layer3_conv1(x) 111 | x = self.relu(x) 112 | x = self.layer3_conv2(x) 113 | x = self.relu(x) 114 | x = self.layer3_conv3(x) 115 | x = self.relu(x) 116 | x = self.layer3_maxpool(x) 117 | 118 | # layer4 119 | x = self.layer4_conv1(x) 120 | x = self.relu(x) 121 | x = self.layer4_conv2(x) 122 | x = self.relu(x) 123 | x = self.layer4_conv3(x) 124 | x = self.relu(x) 125 | x = self.layer4_maxpool(x) 126 | 127 | # layer5 128 | x = self.layer5_conv1(x) 129 | x = self.relu(x) 130 | x = self.layer5_conv2(x) 131 | x = self.relu(x) 132 | x = self.layer5_conv3(x) 133 | x = self.relu(x) 134 | # ============================== 135 | 136 | x = self.extra_conv1(x) 137 | x = self.pam(x) 138 | x = self.extra_conv2(x) 139 | x = self.pam(x) 140 | x = self.extra_conv3(x) 141 | x = self.pam(x) 142 | x = self.extra_conv4(x) 143 | # ============================== 144 | 145 | logit = self.fc(x) 146 | 147 | if self.training: 148 | return logit 149 | 150 | else: 151 | cam = self.cam_normalize(x.detach(), size, label) 152 | return logit, cam 153 | 154 | 155 | def fc(self, x): 156 | x = F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=0) 157 | x = x.view(-1, 20) 158 | return x 159 | 160 | 161 | def cam_normalize(self, cam, size, label): 162 | B, C, H, W = cam.size() 163 | 164 | cam = F.relu(cam) 165 | cam = cam * label[:, :, None, None] 166 | 167 | cam = F.interpolate(cam, size=size, mode='bilinear', align_corners=False) 168 | cam /= F.adaptive_max_pool2d(cam, 1) + 1e-5 169 | 170 | return cam 171 | 172 | 173 | def _initialize_weights(self, layer): 174 | for m in layer.modules(): 175 | if isinstance(m, nn.Conv2d): 176 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 177 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 178 | if m.bias is not None: 179 | m.bias.data.zero_() 180 | elif isinstance(m, nn.BatchNorm2d): 181 | m.weight.data.fill_(1) 182 | m.bias.data.zero_() 183 | elif isinstance(m, nn.Linear): 184 | m.weight.data.normal_(0, 0.01) 185 | m.bias.data.zero_() 186 | 187 | def get_parameter_groups(self): 188 | groups = ([], [], [], []) 189 | 190 | for name, value in self.named_parameters(): 191 | 192 | if 'extra' in name: 193 | if 'weight' in name: 194 | groups[2].append(value) 195 | else: 196 | groups[3].append(value) 197 | else: 198 | if 'weight' in name: 199 | groups[0].append(value) 200 | else: 201 | groups[1].append(value) 202 | return groups 203 | 204 | 205 | 206 | 207 | ####################################################################################################### 208 | 209 | 210 | def make_layers(cfg, batch_norm=False): 211 | layers = [] 212 | in_channels = 3 213 | for i, v in enumerate(cfg): 214 | if v == 'M': 215 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 216 | elif v == 'N': 217 | layers += [nn.MaxPool2d(kernel_size=3, stride=1, padding=1)] 218 | else: 219 | if i > 13: 220 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, dilation=2, padding=2) 221 | else: 222 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 223 | if batch_norm: 224 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 225 | else: 226 | layers += [conv2d, nn.ReLU(inplace=True)] 227 | in_channels = v 228 | return nn.Sequential(*layers) 229 | 230 | 231 | cfg = { 232 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 233 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 234 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 235 | 'D1': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'N', 512, 512, 512], 236 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 237 | } 238 | 239 | 240 | def vgg16_pam(pretrained=True, alpha=0.7): 241 | model = VGG(make_layers(cfg['D1']), alpha=alpha) 242 | 243 | if pretrained: 244 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16']), strict=False) 245 | return model 246 | 247 | 248 | if __name__ == '__main__': 249 | import copy 250 | 251 | model = vgg16(pretrained=True) 252 | print() 253 | 254 | print(model) 255 | 256 | input = torch.randn(2, 3, 321, 321) 257 | label = np.array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]) 258 | label = torch.from_numpy(label) 259 | 260 | out = model(input, label) 261 | 262 | print(out[1].shape) 263 | 264 | -------------------------------------------------------------------------------- /utils/decode.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def colorize_offset(offset_map, offset_weight, seg_map=None, pred=True): 5 | 6 | import matplotlib.colors 7 | import math 8 | 9 | a = (np.arctan2(-offset_map[0], -offset_map[1]) / math.pi + 1) / 2 10 | 11 | r = np.sqrt(offset_map[0] ** 2 + offset_map[1] ** 2) 12 | s = r / (np.max(r) + 1e-5) 13 | 14 | hsv_color = np.stack((a, s, np.ones_like(a)), axis=-1) 15 | rgb_color = matplotlib.colors.hsv_to_rgb(hsv_color) 16 | rgb_color = np.uint8(rgb_color * 255) 17 | 18 | if seg_map is not None: 19 | rgb_color[np.where(seg_map == 0)] = [0, 0, 0] # background 20 | 21 | if not pred: 22 | rgb_color[np.where(offset_weight == 0)] = [255, 255, 255] # ignore 23 | 24 | return rgb_color 
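# ----------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file). It only illustrates the
# tensor shapes that `colorize_offset` above appears to expect, inferred from
# its indexing: a (2, H, W) offset map, an (H, W) weight mask, and an optional
# (H, W) semantic map; it returns an (H, W, 3) uint8 image in which hue encodes
# offset direction and saturation encodes offset magnitude. The helper name
# below is hypothetical and is never called at import time.
def _demo_colorize_offset():
    H, W = 64, 64
    offset_map = np.random.randn(2, H, W).astype(np.float32)  # per-pixel (dy, dx), assumed channel order
    offset_weight = np.ones((H, W), dtype=np.float32)         # 1 = valid pixel, 0 = ignore
    seg_map = np.ones((H, W), dtype=np.uint8)                  # 0 = background
    rgb = colorize_offset(offset_map, offset_weight, seg_map=seg_map, pred=True)
    assert rgb.shape == (H, W, 3) and rgb.dtype == np.uint8
    return rgb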
25 | 26 | 27 | def voc_names(): 28 | return [ 29 | "background", "aeroplane", "bicycle", "bird", 30 | "boat", "bottle", "bus", "car", "cat", "chair", 31 | "cow", "diningtable", "dog", "horse", "motorbike", 32 | "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor", 33 | ] 34 | 35 | 36 | def get_palette(): 37 | palette = [] 38 | for i in range(256): 39 | palette.extend((i,i,i)) 40 | palette[:3*21] = np.array([[0, 0, 0], 41 | [128, 0, 0], 42 | [0, 128, 0], 43 | [128, 128, 0], 44 | [0, 0, 128], 45 | [128, 0, 128], 46 | [0, 128, 128], 47 | [128, 128, 128], 48 | [64, 0, 0], 49 | [192, 0, 0], 50 | [64, 128, 0], 51 | [192, 128, 0], 52 | [64, 0, 128], 53 | [192, 0, 128], 54 | [64, 128, 128], 55 | [192, 128, 128], 56 | [0, 64, 0], 57 | [128, 64, 0], 58 | [0, 192, 0], 59 | [128, 192, 0], 60 | [0, 64, 128]], dtype='uint8').flatten() 61 | 62 | return palette 63 | 64 | def voc_colors(): 65 | colors = np.array([[0, 0, 0], 66 | [128, 0, 0], 67 | [0, 128, 0], 68 | [128, 128, 0], 69 | [0, 0, 128], 70 | [128, 0, 128], 71 | [0, 128, 128], 72 | [128, 128, 128], 73 | [64, 0, 0], 74 | [192, 0, 0], 75 | [64, 128, 0], 76 | [192, 128, 0], 77 | [64, 0, 128], 78 | [192, 0, 128], 79 | [64, 128, 128], 80 | [192, 128, 128], 81 | [0, 64, 0], 82 | [128, 64, 0], 83 | [0, 192, 0], 84 | [128, 192, 0], 85 | [0, 64, 128], 86 | [255, 255, 255], 87 | [200, 200, 200]], dtype='uint8') 88 | return colors 89 | 90 | 91 | def cam_to_seg(CAM, sal_map, palette, alpha=0.2, ignore=255): 92 | colors = voc_colors() 93 | C, H, W = CAM.shape 94 | 95 | CAM[CAM < alpha] = 0 # object cue 96 | 97 | bg = np.zeros((1, H, W), dtype=np.float32) 98 | pred_map = np.concatenate([bg, CAM], axis=0) # [21, H, W] 99 | 100 | pred_map[0, :, :] = (1. - sal_map) # backgroudn cue 101 | 102 | # conflict pixels with multiple confidence values 103 | bg = np.array(pred_map > 0.99, dtype=np.uint8) 104 | bg = np.sum(bg, axis=0) 105 | pred_map = pred_map.argmax(0).astype(np.uint8) 106 | pred_map[bg > 2] = ignore 107 | 108 | # pixels regarded as background but confidence saliency values 109 | bg = (sal_map == 1).astype(np.uint8) * (pred_map == 0).astype(np.uint8) # and operator 110 | pred_map[bg > 0] = ignore 111 | 112 | pred_map = np.uint8(pred_map) 113 | 114 | palette = get_palette() 115 | pred_map = Image.fromarray(pred_map) 116 | pred_map.putpalette(palette) 117 | 118 | return pred_map 119 | 120 | 121 | def cam_with_crf(cam, img, keys, fg_thresh=0.1, bg_thresh=0.001): 122 | cam = np.float32(cam) / 255. 
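    # NOTE: `crf_inference_label` and `colors`, used further down in this function,
    # are not defined or imported in this file; they presumably come from
    # utils/crf.py and from voc_colors() above, so both must be in scope for
    # cam_with_crf to run.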
123 | 124 | cam = cam[keys] # valid category selection 125 | 126 | valid_cat = np.pad(keys + 1, (1, 0), mode='constant') # valid category : [background, val_cat+1] 127 | 128 | fg_conf_cam = np.pad(cam, ((1, 0), (0, 0), (0, 0)), mode='constant', constant_values=fg_thresh) # [c+1, H, W] 129 | fg_conf_cam = np.argmax(fg_conf_cam, axis=0) 130 | pred = crf_inference_label(img, fg_conf_cam, n_labels=valid_cat.shape[0]) 131 | fg_conf = valid_cat[pred] # convert to whole index (0, 1, 2) -> (0 ~ 20) 132 | 133 | bg_conf_cam = np.pad(cam, ((1, 0), (0, 0), (0, 0)), mode='constant', constant_values=bg_thresh) # [c+1, H, W] 134 | bg_conf_cam = np.argmax(bg_conf_cam, axis=0) 135 | pred = crf_inference_label(img, bg_conf_cam, n_labels=valid_cat.shape[0]) 136 | bg_conf = valid_cat[pred] # convert to whole index (0, 1, 2) -> (0 ~ 20) 137 | 138 | conf = fg_conf.copy() 139 | conf[fg_conf == 0] = 21 140 | conf[bg_conf + fg_conf == 0] = 0 # both zero 141 | 142 | conf_color = colors[conf] 143 | 144 | return conf, conf_color 145 | 146 | 147 | def heatmap_colorize(score_map, exclude_zero=True, normalize=True): 148 | import matplotlib.colors 149 | 150 | VOC_color = np.array([(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), 151 | (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), 152 | (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), 153 | (0, 192, 0), (128, 192, 0), (0, 64, 128), (255, 255, 255)], np.float32) 154 | 155 | if exclude_zero: 156 | VOC_color = VOC_color[1:] 157 | 158 | test = VOC_color[np.argmax(score_map, axis=0)%22] 159 | test = np.expand_dims(np.max(score_map, axis=0), axis=-1) * test 160 | if normalize: 161 | test /= np.max(test) + 1e-5 162 | 163 | return test 164 | 165 | def decode_seg_map_sequence(label_masks): 166 | if label_masks.ndim == 2: 167 | label_masks = label_masks[None, :, :] 168 | 169 | rgb_masks = [] 170 | for label_mask in label_masks: 171 | rgb_mask = decode_segmap(label_mask) 172 | rgb_masks.append(rgb_mask) 173 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 174 | return rgb_masks 175 | 176 | 177 | def decode_segmap(label_mask, plot=False): 178 | """Decode segmentation class labels into a color image 179 | Args: 180 | label_mask (np.ndarray): an (M,N) array of integer values denoting 181 | the class label at each spatial location. 182 | plot (bool, optional): whether to show the resulting color image 183 | in a figure. 184 | Returns: 185 | (np.ndarray, optional): the resulting decoded color image. 186 | """ 187 | 188 | n_classes = 21 189 | label_colours = get_pascal_labels() 190 | 191 | r = label_mask.copy() 192 | g = label_mask.copy() 193 | b = label_mask.copy() 194 | for ll in range(0, n_classes): 195 | r[label_mask == ll] = label_colours[ll, 0] 196 | g[label_mask == ll] = label_colours[ll, 1] 197 | b[label_mask == ll] = label_colours[ll, 2] 198 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 199 | rgb[:, :, 0] = r / 255.0 200 | rgb[:, :, 1] = g / 255.0 201 | rgb[:, :, 2] = b / 255.0 202 | if plot: 203 | plt.imshow(rgb) 204 | plt.show() 205 | else: 206 | return rgb 207 | 208 | 209 | def encode_segmap(mask): 210 | """Encode segmentation label images as pascal classes 211 | Args: 212 | mask (np.ndarray): raw segmentation label image of dimension 213 | (M, N, 3), in which the Pascal classes are encoded as colours. 
214 | Returns: 215 | (np.ndarray): class map with dimensions (M,N), where the value at 216 | a given location is the integer denoting the class index. 217 | """ 218 | mask = mask.astype(int) 219 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 220 | for ii, label in enumerate(get_pascal_labels()): 221 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 222 | label_mask = label_mask.astype(int) 223 | return label_mask 224 | 225 | 226 | def get_pascal_labels(): 227 | """Load the mapping that associates pascal classes with label colors 228 | Returns: 229 | np.ndarray with dimensions (21, 3) 230 | """ 231 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 232 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 233 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 234 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 235 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 236 | [0, 64, 128]]) 237 | 238 | 239 | def get_ins_colors(): 240 | ins_colors = np.random.random((2000, 3)) 241 | ins_colors = np.uint8(ins_colors*255) 242 | ins_colors[0] = [0, 0, 0] 243 | ins_colors[1] = [192, 128, 0] 244 | ins_colors[2] = [64, 0, 128] 245 | ins_colors[3] = [192, 0, 128] 246 | ins_colors[4] = [64, 128, 128] 247 | ins_colors[5] = [192, 128, 128] 248 | ins_colors[6] = [0, 64, 0] 249 | ins_colors[7] = [128, 64, 0] 250 | ins_colors[8] = [0, 192, 0] 251 | ins_colors[9] = [128, 192, 0] 252 | ins_colors[10] = [0, 64, 128] 253 | ins_colors[11] = [128, 0, 0] 254 | ins_colors[12] = [0, 128, 0] 255 | ins_colors[13] = [128, 128, 0] 256 | ins_colors[14] = [0, 0, 128] 257 | ins_colors[15] = [128, 0, 128] 258 | ins_colors[16] = [0, 128, 128] 259 | ins_colors[17] = [128, 128, 128] 260 | ins_colors[18] = [64, 0, 0] 261 | ins_colors[19] = [192, 0, 0] 262 | ins_colors[20] = [64, 128, 0] 263 | return ins_colors -------------------------------------------------------------------------------- /utils/LoadData.py: -------------------------------------------------------------------------------- 1 | """ 2 | BESTIE 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT License 5 | """ 6 | 7 | import os 8 | from PIL import Image 9 | import numpy as np 10 | 11 | import torch 12 | import torchvision.transforms.functional as TF 13 | from .transforms import transforms as T 14 | 15 | 16 | from torch.utils.data import Dataset 17 | from .utils import gaussian, pseudo_label_generation 18 | 19 | def get_dataset(args, mode): 20 | if 'coco' in args.dataset: 21 | mean_vals = [0.471, 0.448, 0.408] 22 | std_vals = [0.234, 0.239, 0.242] 23 | else: 24 | mean_vals = [0.485, 0.456, 0.406] 25 | std_vals = [0.229, 0.224, 0.225] 26 | 27 | if mode == 'train': 28 | data_list = "data/train_cls.txt" # 10,582 images 29 | 30 | crop_size = int(args.crop_size) 31 | input_size = crop_size + 64 32 | 33 | min_resize_value = input_size 34 | max_resize_value = input_size 35 | resize_factor = 32 36 | 37 | min_scale = 0.7 38 | max_scale = 1.3 39 | scale_step_size=0.1 40 | crop_h, crop_w = crop_size, crop_size 41 | 42 | pad_value = tuple([int(v * 255) for v in mean_vals]) 43 | ignore_label = (0, 0, 0) 44 | 45 | transform = T.Compose( 46 | [ 47 | T.PhotometricDistort(), 48 | T.Resize(min_resize_value, 49 | max_resize_value, 50 | resize_factor), 51 | T.RandomScale( 52 | min_scale, 53 | max_scale, 54 | scale_step_size 55 | ), 56 | T.RandomCrop( 57 | crop_h, 58 | crop_w, 59 | pad_value, 60 | ignore_label, 61 | random_pad=True 62 | ), 63 | T.ToTensor(), 64 | T.RandomHorizontalFlip(), 65 | T.Normalize( 66 | mean_vals, 67 | std_vals 68 | ) 69 | ] 70 | ) 71 | 72 | dataset = VOCDataset(data_list, 73 | root_dir=args.root_dir, 74 | num_classes=args.num_classes, 75 | transform=transform, 76 | sup=args.sup, 77 | sigma=args.sigma, 78 | point_thresh=args.pseudo_thresh) 79 | 80 | else: 81 | #data_list = "data/train_labeled_cls.txt" 82 | data_list = "data/val_cls.txt" 83 | 84 | dataset = VOCTestDataset(data_list, 85 | root_dir=args.root_dir, 86 | num_classes=args.num_classes) 87 | 88 | return dataset 89 | 90 | 91 | class VOCDataset(Dataset): 92 | def __init__(self, datalist_file, root_dir, num_classes=20, 93 | transform=None, sup='cls', sigma=8, point_thresh=0.5): 94 | 95 | self.num_classes = num_classes 96 | self.transform = transform 97 | self.sigma = sigma 98 | self.sup = sup 99 | self.point_thresh = point_thresh 100 | 101 | self.g = gaussian(sigma) 102 | 103 | self.dat_list = self.read_labeled_image_list(root_dir, datalist_file) 104 | 105 | 106 | def __getitem__(self, idx): 107 | img_path = self.dat_list["img"][idx] 108 | seg_map_path = self.dat_list["seg_map"][idx] 109 | cls_label = self.dat_list["cls_label"][idx] 110 | points = self.dat_list["point"][idx] 111 | 112 | img = np.uint8(Image.open(img_path).convert("RGB")) 113 | seg_map = np.uint8(Image.open(seg_map_path)) 114 | 115 | if self.transform is not None: 116 | img, seg_map, points = self.transform(img, seg_map, points) 117 | 118 | center_map, offset_map, weight = pseudo_label_generation(self.sup, 119 | seg_map.numpy(), 120 | points, 121 | cls_label, 122 | self.num_classes, 123 | self.sigma, 124 | self.g) 125 | 126 | point_list = self.make_class_wise_point_list(points) 127 | 128 | seg_map = seg_map.long() 129 | center_map = torch.from_numpy(center_map) 130 | offset_map = torch.from_numpy(offset_map) 131 | weight = torch.from_numpy(weight) 132 | point_list = torch.from_numpy(point_list) 133 | 134 | return img, cls_label, seg_map, center_map, offset_map, weight, point_list 135 | 136 | 137 | def make_class_wise_point_list(self, points): 138 | 139 | MAX_NUM_POINTS = 128 140 | 141 | point_list = np.zeros((self.num_classes, MAX_NUM_POINTS, 2), 
dtype=np.int32) 142 | point_count = [0 for _ in range(self.num_classes)] 143 | 144 | for (x, y, cls, _ ) in points: 145 | point_list[cls][point_count[cls]] = [y, x] 146 | point_count[cls] += 1 147 | 148 | return point_list 149 | 150 | 151 | def read_labeled_image_list(self, root_dir, data_list): 152 | img_dir = os.path.join(root_dir, "JPEGImages") 153 | seg_map_dir = os.path.join(root_dir, "WSSS_maps") 154 | 155 | if self.sup == 'point': 156 | point_dir = os.path.join(root_dir, "Center_points") 157 | else: 158 | point_dir = os.path.join(root_dir, "Peak_points") 159 | 160 | with open(data_list, 'r') as f: 161 | lines = f.read().splitlines() 162 | 163 | img_list = [] 164 | label_list = [] 165 | seg_map_list = [] 166 | point_list = [] 167 | 168 | np.random.shuffle(lines) 169 | 170 | for line in lines: 171 | fields = line.strip().split(" ") 172 | # fields[0] : file_name 173 | # fields[1:] : cls labels 174 | 175 | image_path = os.path.join(img_dir, fields[0] + '.jpg') 176 | seg_map_path = os.path.join(seg_map_dir, fields[0] + '.png') 177 | point_txt = os.path.join(point_dir, fields[0] + '.txt') 178 | 179 | # one-hot cls label 180 | labels = np.zeros((self.num_classes,), dtype=np.float32) 181 | for i in range(len(fields)-1): 182 | index = int(fields[i+1]) 183 | labels[index] = 1. 184 | 185 | # get points 186 | with open(point_txt, 'r') as pf: 187 | points = pf.read().splitlines() 188 | points = [p.strip().split(" ") for p in points] 189 | points = [ [float(p[0]), float(p[1]), int(p[2]), float(p[3])] for p in points if float(p[3]) > self.point_thresh] 190 | # point (x_coord, y_coord, class-idx, conf) 191 | 192 | img_list.append(image_path) 193 | seg_map_list.append(seg_map_path) 194 | point_list.append(points) 195 | label_list.append(labels) 196 | 197 | return {"img": img_list, 198 | "cls_label": label_list, 199 | "seg_map": seg_map_list, 200 | "point": point_list} 201 | 202 | def __len__(self): 203 | return len(self.dat_list["img"]) 204 | 205 | 206 | 207 | 208 | class VOCTestDataset(Dataset): 209 | def __init__(self, datalist_file, root_dir, num_classes=20): 210 | 211 | self.num_classes = num_classes 212 | self.dat_list = self.read_labeled_image_list(root_dir, datalist_file) 213 | 214 | 215 | def __getitem__(self, idx): 216 | fname = self.dat_list["fname"][idx] 217 | img_path = self.dat_list["img"][idx] 218 | cls_label = self.dat_list["cls_label"][idx] 219 | points = self.dat_list["point"][idx] 220 | 221 | img = Image.open(img_path).convert("RGB") 222 | 223 | ori_w, ori_h = img.size 224 | 225 | new_h = (ori_h + 31) // 32 * 32 226 | new_w = (ori_w + 31) // 32 * 32 227 | 228 | img = img.resize((new_w, new_h), Image.BILINEAR) 229 | img = TF.to_tensor(img) 230 | img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 231 | 232 | return img, cls_label, points, fname, (ori_h, ori_w) 233 | 234 | 235 | def read_labeled_image_list(self, root_dir, data_list): 236 | img_dir = os.path.join(root_dir, "JPEGImages") 237 | mask_dir = os.path.join(root_dir, "SegmentationObjectAug") 238 | point_dir = os.path.join(root_dir, "Center_points") 239 | 240 | with open(data_list, 'r') as f: 241 | lines = f.read().splitlines() 242 | 243 | fname_list = [] 244 | img_list = [] 245 | mask_list = [] 246 | label_list = [] 247 | point_list = [] 248 | 249 | for line in lines: 250 | fields = line.strip().split(" ") 251 | # fields[0] : file_name 252 | # fields[1:] : cls labels 253 | 254 | image_path = os.path.join(img_dir, fields[0] + '.jpg') 255 | mask_path = os.path.join(mask_dir, fields[0] + '.png') 256 | 
point_txt = os.path.join(point_dir, fields[0] + '.txt') 257 | 258 | # one-hot cls label 259 | labels = np.zeros((self.num_classes,), dtype=np.float32) 260 | for i in range(len(fields)-1): 261 | index = int(fields[i+1]) 262 | labels[index] = 1. 263 | 264 | # get points 265 | with open(point_txt, 'r') as pf: 266 | points = pf.read().splitlines() 267 | points = [p.strip().split(" ") for p in points] 268 | points = [ [float(p[0]), float(p[1]), int(p[2]), float(p[3])] for p in points] 269 | # point (x_coord, y_coord, class-idx, conf) 270 | 271 | points_cls = [[] for _ in range(self.num_classes)] 272 | for (x, y, cls, _ ) in points: 273 | points_cls[cls].append((y, x)) 274 | 275 | fname_list.append(fields[0]) 276 | img_list.append(image_path) 277 | mask_list.append(mask_path) 278 | point_list.append(points_cls) 279 | label_list.append(labels) 280 | 281 | return {"img": img_list, 282 | "mask": mask_list, 283 | "cls_label": label_list, 284 | "point": point_list, 285 | "fname": fname_list} 286 | 287 | def __len__(self): 288 | return len(self.dat_list["img"]) 289 | 290 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Reference: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 3 | # Modified by Bowen Cheng (bcheng9@illinois.edu) 4 | # ------------------------------------------------------------------------------ 5 | 6 | import torch.nn as nn 7 | try: 8 | from torchvision.models.utils import load_state_dict_from_url 9 | except: 10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 11 | 12 | 13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 14 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 15 | 'wide_resnet50_2', 'wide_resnet101_2'] 16 | 17 | 18 | model_urls = { 19 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 20 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 21 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 22 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 23 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 24 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 25 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 26 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 27 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 28 | } 29 | 30 | 31 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 32 | """3x3 convolution with padding""" 33 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 34 | padding=dilation, groups=groups, bias=False, dilation=dilation) 35 | 36 | 37 | def conv1x1(in_planes, out_planes, stride=1): 38 | """1x1 convolution""" 39 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 40 | 41 | 42 | class BasicBlock(nn.Module): 43 | expansion = 1 44 | 45 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 46 | base_width=64, dilation=1, norm_layer=None): 47 | super(BasicBlock, self).__init__() 48 | if norm_layer is None: 49 | norm_layer = nn.BatchNorm2d 50 | if groups != 1 or base_width != 64: 51 | raise 
ValueError('BasicBlock only supports groups=1 and base_width=64') 52 | if dilation > 1: 53 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 54 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 55 | self.conv1 = conv3x3(inplanes, planes, stride) 56 | self.bn1 = norm_layer(planes) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.conv2 = conv3x3(planes, planes) 59 | self.bn2 = norm_layer(planes) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | identity = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | 73 | if self.downsample is not None: 74 | identity = self.downsample(x) 75 | 76 | out += identity 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class Bottleneck(nn.Module): 83 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 84 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 85 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 86 | # This variant is also known as ResNet V1.5 and improves accuracy according to 87 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 88 | 89 | expansion = 4 90 | 91 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 92 | base_width=64, dilation=1, norm_layer=None): 93 | super(Bottleneck, self).__init__() 94 | if norm_layer is None: 95 | norm_layer = nn.BatchNorm2d 96 | width = int(planes * (base_width / 64.)) * groups 97 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 98 | self.conv1 = conv1x1(inplanes, width) 99 | self.bn1 = norm_layer(width) 100 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 101 | self.bn2 = norm_layer(width) 102 | self.conv3 = conv1x1(width, planes * self.expansion) 103 | self.bn3 = norm_layer(planes * self.expansion) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.downsample = downsample 106 | self.stride = stride 107 | 108 | def forward(self, x): 109 | identity = x 110 | 111 | out = self.conv1(x) 112 | out = self.bn1(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv2(out) 116 | out = self.bn2(out) 117 | out = self.relu(out) 118 | 119 | out = self.conv3(out) 120 | out = self.bn3(out) 121 | 122 | if self.downsample is not None: 123 | identity = self.downsample(x) 124 | 125 | out += identity 126 | out = self.relu(out) 127 | 128 | return out 129 | 130 | 131 | class ResNet(nn.Module): 132 | 133 | def __init__(self, block, layers, zero_init_residual=False, 134 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 135 | norm_layer=None): 136 | super(ResNet, self).__init__() 137 | if norm_layer is None: 138 | norm_layer = nn.BatchNorm2d 139 | self._norm_layer = norm_layer 140 | 141 | self.inplanes = 64 142 | self.dilation = 1 143 | if replace_stride_with_dilation is None: 144 | # each element in the tuple indicates if we should replace 145 | # the 2x2 stride with a dilated convolution instead 146 | replace_stride_with_dilation = [False, False, False] 147 | if len(replace_stride_with_dilation) != 3: 148 | raise ValueError("replace_stride_with_dilation should be None " 149 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 150 | self.groups = groups 151 | self.base_width = width_per_group 152 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, 
stride=2, padding=3, 153 | bias=False) 154 | self.bn1 = norm_layer(self.inplanes) 155 | self.relu = nn.ReLU(inplace=True) 156 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 157 | self.layer1 = self._make_layer(block, 64, layers[0]) 158 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 159 | dilate=replace_stride_with_dilation[0]) 160 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 161 | dilate=replace_stride_with_dilation[1]) 162 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 163 | dilate=replace_stride_with_dilation[2]) 164 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 165 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 166 | 167 | for m in self.modules(): 168 | if isinstance(m, nn.Conv2d): 169 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 170 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 171 | nn.init.constant_(m.weight, 1) 172 | nn.init.constant_(m.bias, 0) 173 | 174 | # Zero-initialize the last BN in each residual branch, 175 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 176 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 177 | if zero_init_residual: 178 | for m in self.modules(): 179 | if isinstance(m, Bottleneck): 180 | nn.init.constant_(m.bn3.weight, 0) 181 | elif isinstance(m, BasicBlock): 182 | nn.init.constant_(m.bn2.weight, 0) 183 | 184 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 185 | norm_layer = self._norm_layer 186 | downsample = None 187 | previous_dilation = self.dilation 188 | if dilate: 189 | self.dilation *= stride 190 | stride = 1 191 | if stride != 1 or self.inplanes != planes * block.expansion: 192 | downsample = nn.Sequential( 193 | conv1x1(self.inplanes, planes * block.expansion, stride), 194 | norm_layer(planes * block.expansion), 195 | ) 196 | 197 | layers = [] 198 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 199 | self.base_width, previous_dilation, norm_layer)) 200 | self.inplanes = planes * block.expansion 201 | for _ in range(1, blocks): 202 | layers.append(block(self.inplanes, planes, groups=self.groups, 203 | base_width=self.base_width, dilation=self.dilation, 204 | norm_layer=norm_layer)) 205 | 206 | return nn.Sequential(*layers) 207 | 208 | def _forward_impl(self, x): 209 | outputs = {} 210 | # See note [TorchScript super()] 211 | x = self.conv1(x) 212 | x = self.bn1(x) 213 | x = self.relu(x) 214 | x = self.maxpool(x) 215 | outputs['stem'] = x 216 | 217 | x = self.layer1(x) # 1/4 218 | outputs['res2'] = x 219 | 220 | x = self.layer2(x) # 1/8 221 | outputs['res3'] = x 222 | 223 | x = self.layer3(x) # 1/16 224 | outputs['res4'] = x 225 | 226 | x = self.layer4(x) # 1/32 227 | outputs['res5'] = x 228 | 229 | return outputs 230 | 231 | def forward(self, x): 232 | return self._forward_impl(x) 233 | 234 | 235 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 236 | model = ResNet(block, layers, **kwargs) 237 | if pretrained: 238 | state_dict = load_state_dict_from_url(model_urls[arch], 239 | progress=progress) 240 | model.load_state_dict(state_dict, strict=False) 241 | return model 242 | 243 | 244 | def resnet18(pretrained=False, progress=True, **kwargs): 245 | r"""ResNet-18 model from 246 | `"Deep Residual Learning for Image Recognition" `_ 247 | Args: 248 | pretrained (bool): If True, returns a model pre-trained on ImageNet 249 | progress (bool): If True, displays a 
progress bar of the download to stderr 250 | """ 251 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 252 | **kwargs) 253 | 254 | 255 | def resnet34(pretrained=False, progress=True, **kwargs): 256 | r"""ResNet-34 model from 257 | `"Deep Residual Learning for Image Recognition" `_ 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If True, displays a progress bar of the download to stderr 261 | """ 262 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 263 | **kwargs) 264 | 265 | 266 | def resnet50(pretrained=False, progress=True, **kwargs): 267 | r"""ResNet-50 model from 268 | `"Deep Residual Learning for Image Recognition" `_ 269 | Args: 270 | pretrained (bool): If True, returns a model pre-trained on ImageNet 271 | progress (bool): If True, displays a progress bar of the download to stderr 272 | """ 273 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 274 | **kwargs) 275 | 276 | 277 | def resnet101(pretrained=False, progress=True, **kwargs): 278 | r"""ResNet-101 model from 279 | `"Deep Residual Learning for Image Recognition" `_ 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | progress (bool): If True, displays a progress bar of the download to stderr 283 | """ 284 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 285 | **kwargs) 286 | 287 | 288 | def resnet152(pretrained=False, progress=True, **kwargs): 289 | r"""ResNet-152 model from 290 | `"Deep Residual Learning for Image Recognition" `_ 291 | Args: 292 | pretrained (bool): If True, returns a model pre-trained on ImageNet 293 | progress (bool): If True, displays a progress bar of the download to stderr 294 | """ 295 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 296 | **kwargs) 297 | 298 | 299 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 300 | r"""ResNeXt-50 32x4d model from 301 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 302 | Args: 303 | pretrained (bool): If True, returns a model pre-trained on ImageNet 304 | progress (bool): If True, displays a progress bar of the download to stderr 305 | """ 306 | kwargs['groups'] = 32 307 | kwargs['width_per_group'] = 4 308 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 309 | pretrained, progress, **kwargs) 310 | 311 | 312 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 313 | r"""ResNeXt-101 32x8d model from 314 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 315 | Args: 316 | pretrained (bool): If True, returns a model pre-trained on ImageNet 317 | progress (bool): If True, displays a progress bar of the download to stderr 318 | """ 319 | kwargs['groups'] = 32 320 | kwargs['width_per_group'] = 8 321 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 322 | pretrained, progress, **kwargs) 323 | 324 | 325 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 326 | r"""Wide ResNet-50-2 model from 327 | `"Wide Residual Networks" `_ 328 | The model is the same as ResNet except for the bottleneck number of channels 329 | which is twice larger in every block. The number of channels in outer 1x1 330 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 331 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | kwargs['width_per_group'] = 64 * 2 337 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 338 | pretrained, progress, **kwargs) 339 | 340 | 341 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 342 | r"""Wide ResNet-101-2 model from 343 | `"Wide Residual Networks" `_ 344 | The model is the same as ResNet except for the bottleneck number of channels 345 | which is twice larger in every block. The number of channels in outer 1x1 346 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 347 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 348 | Args: 349 | pretrained (bool): If True, returns a model pre-trained on ImageNet 350 | progress (bool): If True, displays a progress bar of the download to stderr 351 | """ 352 | kwargs['width_per_group'] = 64 * 2 353 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 354 | pretrained, progress, **kwargs) 355 | -------------------------------------------------------------------------------- /models/panoptic_deeplab.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Panoptic-DeepLab decoder. 3 | # Written by Bowen Cheng (bcheng9@illinois.edu) 4 | # ------------------------------------------------------------------------------ 5 | 6 | from collections import OrderedDict 7 | from functools import partial 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from utils.utils import refine_label_generation, refine_label_generation_with_point 14 | 15 | ################################################################################################ 16 | 17 | def basic_conv(in_planes, out_planes, kernel_size, stride=1, padding=1, groups=1, 18 | with_bn=True, with_relu=True): 19 | """convolution with bn and relu""" 20 | module = [] 21 | has_bias = not with_bn 22 | module.append( 23 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, 24 | bias=has_bias) 25 | ) 26 | if with_bn: 27 | module.append(nn.BatchNorm2d(out_planes)) 28 | if with_relu: 29 | module.append(nn.ReLU()) 30 | return nn.Sequential(*module) 31 | 32 | 33 | def depthwise_separable_conv(in_planes, out_planes, kernel_size, stride=1, padding=1, groups=1, 34 | with_bn=True, with_relu=True): 35 | """depthwise separable convolution with bn and relu""" 36 | del groups 37 | 38 | module = [] 39 | module.extend([ 40 | basic_conv(in_planes, in_planes, kernel_size, stride, padding, groups=in_planes, 41 | with_bn=True, with_relu=True), 42 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 43 | ]) 44 | if with_bn: 45 | module.append(nn.BatchNorm2d(out_planes)) 46 | if with_relu: 47 | module.append(nn.ReLU()) 48 | return nn.Sequential(*module) 49 | 50 | 51 | def stacked_conv(in_planes, out_planes, kernel_size, num_stack, stride=1, padding=1, groups=1, 52 | with_bn=True, with_relu=True, conv_type='basic_conv'): 53 | """stacked convolution with bn and relu""" 54 | if num_stack < 1: 55 | assert ValueError('`num_stack` has to be a positive integer.') 56 | if conv_type == 'basic_conv': 57 | conv = partial(basic_conv, out_planes=out_planes, kernel_size=kernel_size, stride=stride, 58 | padding=padding, groups=groups, 
with_bn=with_bn, with_relu=with_relu) 59 | elif conv_type == 'depthwise_separable_conv': 60 | conv = partial(depthwise_separable_conv, out_planes=out_planes, kernel_size=kernel_size, stride=stride, 61 | padding=padding, groups=1, with_bn=with_bn, with_relu=with_relu) 62 | else: 63 | raise ValueError('Unknown conv_type: {}'.format(conv_type)) 64 | module = [] 65 | module.append(conv(in_planes=in_planes)) 66 | for n in range(1, num_stack): 67 | module.append(conv(in_planes=out_planes)) 68 | return nn.Sequential(*module) 69 | 70 | ################################################################################################ 71 | 72 | class ASPPConv(nn.Sequential): 73 | def __init__(self, in_channels, out_channels, dilation): 74 | modules = [ 75 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 76 | nn.BatchNorm2d(out_channels), 77 | nn.ReLU() 78 | ] 79 | super(ASPPConv, self).__init__(*modules) 80 | 81 | 82 | class ASPPPooling(nn.Module): 83 | def __init__(self, in_channels, out_channels): 84 | super(ASPPPooling, self).__init__() 85 | self.aspp_pooling = nn.Sequential( 86 | nn.AdaptiveAvgPool2d(1), 87 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 88 | nn.ReLU() 89 | ) 90 | 91 | def set_image_pooling(self, pool_size=None): 92 | if pool_size is None: 93 | self.aspp_pooling[0] = nn.AdaptiveAvgPool2d(1) 94 | else: 95 | self.aspp_pooling[0] = nn.AvgPool2d(kernel_size=pool_size, stride=1) 96 | 97 | def forward(self, x): 98 | size = x.shape[-2:] 99 | x = self.aspp_pooling(x) 100 | return F.interpolate(x, size=size, mode='bilinear', align_corners=True) 101 | 102 | 103 | class ASPP(nn.Module): 104 | def __init__(self, in_channels, out_channels, atrous_rates): 105 | super(ASPP, self).__init__() 106 | # out_channels = 256 107 | modules = [] 108 | modules.append(nn.Sequential( 109 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 110 | nn.BatchNorm2d(out_channels), 111 | nn.ReLU())) 112 | 113 | rate1, rate2, rate3 = tuple(atrous_rates) 114 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 115 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 116 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 117 | modules.append(ASPPPooling(in_channels, out_channels)) 118 | 119 | self.convs = nn.ModuleList(modules) 120 | 121 | self.project = nn.Sequential( 122 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 123 | nn.BatchNorm2d(out_channels), 124 | nn.ReLU(), 125 | nn.Dropout(0.5)) 126 | 127 | def set_image_pooling(self, pool_size): 128 | self.convs[-1].set_image_pooling(pool_size) 129 | 130 | def forward(self, x): 131 | res = [] 132 | for conv in self.convs: 133 | res.append(conv(x)) 134 | res = torch.cat(res, dim=1) 135 | return self.project(res) 136 | 137 | 138 | ################################################################################################ 139 | 140 | class SinglePanopticDeepLabDecoder(nn.Module): 141 | def __init__(self, in_channels, feature_key, low_level_channels, low_level_key, low_level_channels_project, 142 | decoder_channels, atrous_rates, aspp_channels=None): 143 | super(SinglePanopticDeepLabDecoder, self).__init__() 144 | if aspp_channels is None: 145 | aspp_channels = decoder_channels 146 | self.aspp = ASPP(in_channels, out_channels=aspp_channels, atrous_rates=atrous_rates) 147 | self.feature_key = feature_key 148 | self.decoder_stage = len(low_level_channels) 149 | assert self.decoder_stage == len(low_level_key) 150 | assert self.decoder_stage == len(low_level_channels_project) 151 | 
self.low_level_key = low_level_key 152 | fuse_conv = partial(stacked_conv, kernel_size=5, num_stack=1, padding=2, 153 | conv_type='depthwise_separable_conv') 154 | 155 | # Transform low-level feature 156 | project = [] 157 | # Fuse 158 | fuse = [] 159 | # Top-down direction, i.e. starting from largest stride 160 | for i in range(self.decoder_stage): 161 | project.append( 162 | nn.Sequential( 163 | nn.Conv2d(low_level_channels[i], low_level_channels_project[i], 1, bias=False), 164 | nn.BatchNorm2d(low_level_channels_project[i]), 165 | nn.ReLU() 166 | ) 167 | ) 168 | if i == 0: 169 | fuse_in_channels = aspp_channels + low_level_channels_project[i] 170 | else: 171 | fuse_in_channels = decoder_channels + low_level_channels_project[i] 172 | fuse.append( 173 | fuse_conv( 174 | fuse_in_channels, 175 | decoder_channels, 176 | ) 177 | ) 178 | self.project = nn.ModuleList(project) 179 | self.fuse = nn.ModuleList(fuse) 180 | 181 | def set_image_pooling(self, pool_size): 182 | self.aspp.set_image_pooling(pool_size) 183 | 184 | def forward(self, features): 185 | x = features[self.feature_key] 186 | x = self.aspp(x) 187 | 188 | # build decoder 189 | for i in range(self.decoder_stage): 190 | l = features[self.low_level_key[i]] 191 | l = self.project[i](l) 192 | x = F.interpolate(x, size=l.size()[2:], mode='bilinear', align_corners=True) 193 | x = torch.cat((x, l), dim=1) 194 | x = self.fuse[i](x) 195 | 196 | return x 197 | 198 | 199 | class SinglePanopticDeepLabHead(nn.Module): 200 | def __init__(self, decoder_channels, head_channels, num_classes, class_key): 201 | super(SinglePanopticDeepLabHead, self).__init__() 202 | fuse_conv = partial(stacked_conv, kernel_size=5, num_stack=1, padding=2, 203 | conv_type='depthwise_separable_conv') 204 | 205 | self.num_head = len(num_classes) 206 | assert self.num_head == len(class_key) 207 | 208 | classifier = {} 209 | for i in range(self.num_head): 210 | classifier[class_key[i]] = nn.Sequential( 211 | fuse_conv( 212 | decoder_channels, 213 | head_channels[i], 214 | ), 215 | nn.Conv2d(head_channels[i], num_classes[i], 1) 216 | ) 217 | self.classifier = nn.ModuleDict(classifier) 218 | self.class_key = class_key 219 | 220 | def forward(self, x): 221 | pred = OrderedDict() 222 | # build classifier 223 | for key in self.class_key: 224 | pred[key] = self.classifier[key](x) 225 | 226 | return pred 227 | 228 | 229 | class PanopticDeepLabDecoder(nn.Module): 230 | def __init__(self, in_channels, feature_key, low_level_channels, low_level_key, low_level_channels_project, 231 | decoder_channels, atrous_rates, num_classes, instance_head_kwargs, **kwargs): 232 | super(PanopticDeepLabDecoder, self).__init__() 233 | 234 | self.semantic_decoder = SinglePanopticDeepLabDecoder(in_channels, feature_key, low_level_channels, 235 | low_level_key, low_level_channels_project, 236 | decoder_channels, atrous_rates) 237 | self.semantic_head = SinglePanopticDeepLabHead(decoder_channels, [decoder_channels], [num_classes], ['seg']) 238 | 239 | 240 | # Build instance decoder 241 | instance_decoder_kwargs = dict( 242 | in_channels=in_channels, 243 | feature_key=feature_key, 244 | low_level_channels=low_level_channels, 245 | low_level_key=low_level_key, 246 | low_level_channels_project=(64, 32, 16), 247 | decoder_channels=128, 248 | atrous_rates=atrous_rates, 249 | aspp_channels=256 250 | ) 251 | self.instance_decoder = SinglePanopticDeepLabDecoder(**instance_decoder_kwargs) 252 | 253 | self.instance_head = SinglePanopticDeepLabHead(**instance_head_kwargs) 254 | 255 | def set_image_pooling(self, 
pool_size): 256 | self.semantic_decoder.set_image_pooling(pool_size) 257 | self.instance_decoder.set_image_pooling(pool_size) 258 | 259 | def forward(self, features): 260 | pred = OrderedDict() 261 | 262 | # Semantic branch 263 | semantic = self.semantic_decoder(features) 264 | semantic = self.semantic_head(semantic) 265 | for key in semantic.keys(): 266 | pred[key] = semantic[key] 267 | 268 | # Instance branch 269 | instance = self.instance_decoder(features) 270 | instance = self.instance_head(instance) 271 | for key in instance.keys(): 272 | pred[key] = instance[key] 273 | 274 | return pred 275 | 276 | ################################################################################################ 277 | 278 | class BaseSegmentationModel(nn.Module): 279 | """ 280 | Base class for segmentation models. 281 | Arguments: 282 | backbone: A nn.Module of backbone model. 283 | decoder: A nn.Module of decoder. 284 | """ 285 | def __init__(self, backbone, decoder, args): 286 | super(BaseSegmentationModel, self).__init__() 287 | self.backbone = backbone 288 | self.decoder = decoder 289 | self.args = args 290 | 291 | def _init_params(self, ): 292 | # Backbone is already initialized (either from pre-trained checkpoint or random init). 293 | for m in self.decoder.modules(): 294 | if isinstance(m, nn.Conv2d): 295 | nn.init.normal_(m.weight, mean=0.0, std=0.001) 296 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 297 | nn.init.constant_(m.weight, 1) 298 | nn.init.constant_(m.bias, 0) 299 | 300 | def set_image_pooling(self, pool_size): 301 | self.decoder.set_image_pooling(pool_size) 302 | 303 | def _upsample_predictions(self, pred, input_shape): 304 | """Upsamples final prediction. 305 | Args: 306 | pred (dict): stores all output of the segmentation model. 307 | input_shape (tuple): spatial resolution of the desired shape. 308 | Returns: 309 | result (OrderedDict): upsampled dictionary. 
310 | """ 311 | result = OrderedDict() 312 | for key in pred.keys(): 313 | out = F.interpolate(pred[key], size=input_shape, mode='bilinear', align_corners=True) 314 | result[key] = out 315 | return result 316 | 317 | def forward(self, x, seg_map=None, label=None, point_list=None, target_shape=None): 318 | if target_shape is None: 319 | target_shape = x.shape[-2:] 320 | 321 | # contract: features is a dict of tensors 322 | features = self.backbone(x) 323 | pred = self.decoder(features) 324 | results = self._upsample_predictions(pred, target_shape) 325 | 326 | if label is not None: # refined label generation 327 | 328 | if self.args.sup == 'point': # point supervision setting 329 | pseudo_label = refine_label_generation_with_point( 330 | results['seg'].clone().detach(), 331 | point_list.cpu().numpy(), 332 | results['offset'].clone().detach(), 333 | label.clone().detach(), 334 | seg_map.clone().detach(), 335 | self.args, 336 | ) 337 | 338 | else: # image-level supervision setting 339 | pseudo_label = refine_label_generation( 340 | results['seg'].clone().detach(), 341 | results['center'].clone().detach(), 342 | results['offset'].clone().detach(), 343 | label.clone().detach(), 344 | seg_map.clone().detach(), 345 | self.args, 346 | ) 347 | 348 | return results, pseudo_label 349 | 350 | return results 351 | 352 | 353 | def PanopticDeepLab(backbone, args): 354 | 355 | instance_head_kwargs = dict( 356 | decoder_channels=128, 357 | head_channels=(128, 32), 358 | num_classes=(args.num_classes, 2), 359 | class_key=["center", "offset"], 360 | ) 361 | 362 | decoder = PanopticDeepLabDecoder(in_channels=2048, 363 | feature_key="res5", 364 | low_level_channels=(1024, 512, 256), 365 | low_level_key=["res4", "res3", "res2"], 366 | low_level_channels_project=(128, 64, 32), 367 | decoder_channels=256, 368 | atrous_rates=(3, 6, 9), 369 | num_classes=args.num_classes+1, 370 | instance_head_kwargs=instance_head_kwargs 371 | ) 372 | 373 | model = BaseSegmentationModel(backbone=backbone, decoder=decoder, args=args) 374 | 375 | model._init_params() 376 | 377 | # set batchnorm momentum 378 | for module in model.modules(): 379 | if isinstance(module, torch.nn.BatchNorm2d): 380 | module.momentum = args.bn_momentum 381 | 382 | return model 383 | 384 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | BESTIE 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT License 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import argparse 10 | import os 11 | import cv2 12 | import time 13 | import random 14 | import pickle 15 | from tqdm import tqdm 16 | 17 | import torch.nn as nn 18 | import torch.optim as optim 19 | from torch.utils.data import DataLoader 20 | from torch.utils.data.distributed import DistributedSampler 21 | from torch.nn.parallel import DistributedDataParallel as DDP 22 | 23 | from models import model_factory 24 | from utils.LoadData import get_dataset 25 | from utils.my_optim import WarmupPolyLR 26 | from utils.loss import Weighted_L1_Loss, Weighted_MSELoss, DeepLabCE 27 | from utils.utils import AverageMeter, get_ins_map, get_ins_map_with_point 28 | 29 | import chainercv 30 | from chainercv.datasets import VOCInstanceSegmentationDataset 31 | from chainercv.evaluations import eval_instance_segmentation_voc 32 | 33 | def str2bool(v): 34 | return v.lower() in ("yes", "y", "true", "t", "1") 35 | 36 | def parse(): 37 | 38 | parser = argparse.ArgumentParser(description='BESTIE pytorch implementation') 39 | parser.add_argument("--root_dir", type=str, default='', help='Root dir for the project') 40 | parser.add_argument('--sup', type=str, help='supervision source', choices=["cls", "point"]) 41 | parser.add_argument("--dataset", type=str, default='voc', choices=["voc", "coco"]) 42 | parser.add_argument("--backbone", type=str, default='resnet50', choices=["resnet50", "resnet101", "hrnet32", "hrnet48"]) 43 | parser.add_argument("--batch_size", type=int, default=16) 44 | parser.add_argument("--crop_size", type=int, default=416) 45 | parser.add_argument("--num_classes", type=int, default=20) 46 | parser.add_argument("--lr", type=float, default=5e-5) 47 | parser.add_argument("--weight_decay", type=float, default=0) 48 | parser.add_argument("--train_iter", type=int, default=50000) 49 | parser.add_argument("--warm_iter", type=int, default=2000, help='warm-up iterations') 50 | parser.add_argument("--train_epoch", type=int, default=0) 51 | parser.add_argument("--num_workers", type=int, default=4) 52 | parser.add_argument('--resume', default=None, type=str, help='weight restore') 53 | 54 | parser.add_argument('--save_folder', default='checkpoints/test1', help='Location to save checkpoint models') 55 | parser.add_argument('--print_freq', default=200, type=int, help='interval of showing training conditions') 56 | parser.add_argument('--save_freq', default=10000, type=int, help='interval of save checkpoint models') 57 | parser.add_argument("--cur_iter", type=int, default=0, help='current training iterations') 58 | 59 | parser.add_argument("--gamma", type=float, default=0.9, help='learning rate decay power') 60 | parser.add_argument("--pseudo_thresh", type=float, default=0.7, help='threshold for pseudo-label generation') 61 | parser.add_argument("--refine_thresh", type=float, default=0.3, help='threshold for refined-label generation') 62 | parser.add_argument("--kernel", type=int, default=41, help='kernel size for point extraction') 63 | parser.add_argument("--sigma", type=int, default=6, help='sigma of 2D gaussian kernel') 64 | parser.add_argument("--beta", type=float, default=3.0, help='parameter for center-clustering') 65 | parser.add_argument("--bn_momentum", type=float, default=0.01) 66 | parser.add_argument('--refine', type=str2bool, default=True, help='enable self-refinement.') 67 | parser.add_argument("--refine_iter", type=int, default=0, help='self-refinement running iteration') 68 | parser.add_argument("--seg_weight",
type=float, default=1.0, help='loss weight for semantic segmentation map') 69 | parser.add_argument("--center_weight", type=float, default=200.0, help='loss weight for center map') 70 | parser.add_argument("--offset_weight", type=float, default=0.01, help='loss weight for offset map') 71 | 72 | parser.add_argument('--val_freq', default=1000, type=int, help='interval of model validation') 73 | parser.add_argument("--val_thresh", type=float, default=0.1, help='threshold for instance-grouping in validation phase') 74 | parser.add_argument("--val_kernel", type=int, default=41, help='kernel size for point extraction in validation phase') 75 | parser.add_argument('--val_flip', type=str2bool, default=True, help='enable flip test-time augmentation in validation phase') 76 | parser.add_argument('--val_clean', type=str2bool, default=False, help='cleaning pseudo-labels using image-level labels') 77 | parser.add_argument('--val_ignore', type=str2bool, default=False, help='ignore') 78 | 79 | parser.add_argument("--random_seed", type=int, default=1) 80 | parser.add_argument("--gpu", type=int, default=0) 81 | parser.add_argument("--world_size", type=int, default=1) 82 | parser.add_argument("--local_rank", type=int, default=int(os.environ["LOCAL_RANK"])) 83 | 84 | return parser.parse_args() 85 | 86 | def print_func(string): 87 | if torch.distributed.get_rank() == 0: 88 | print(string) 89 | 90 | def save_checkpoint(save_path, model): 91 | if torch.distributed.get_rank() == 0: 92 | print('\nSaving state: %s\n' % save_path) 93 | state = { 94 | 'model': model.module.state_dict(), 95 | } 96 | torch.save(state, save_path) 97 | 98 | 99 | def train(): 100 | 101 | batch_time = AverageMeter() 102 | avg_total_loss = AverageMeter() 103 | avg_seg_loss = AverageMeter() 104 | avg_pseudo_center_loss = AverageMeter() 105 | avg_refine_center_loss = AverageMeter() 106 | avg_pseudo_offset_loss = AverageMeter() 107 | avg_refine_offset_loss = AverageMeter() 108 | 109 | best_AP = -1 110 | 111 | model.train() 112 | start = time.time() 113 | end = time.time() 114 | epoch = 0 115 | 116 | for cur_iter in range(1, args.train_iter+1): 117 | 118 | try: 119 | img, label, seg_map, center_map, offset_map, weight, point_list = next(data_iter) 120 | except Exception as e: 121 | print_func(" [LOADER ERROR] " + str(e)) 122 | 123 | epoch += 1 124 | data_iter = iter(train_loader) 125 | img, label, seg_map, center_map, offset_map, weight, point_list = next(data_iter) 126 | 127 | end = time.time() 128 | batch_time.reset() 129 | avg_total_loss.reset() 130 | avg_seg_loss.reset() 131 | avg_pseudo_center_loss.reset() 132 | avg_refine_center_loss.reset() 133 | avg_pseudo_offset_loss.reset() 134 | avg_refine_offset_loss.reset() 135 | 136 | img = img.to(device, non_blocking=True) 137 | label = label.to(device, non_blocking=True) 138 | seg_map = seg_map.to(device, non_blocking=True) 139 | center_map = center_map.to(device, non_blocking=True) 140 | offset_map = offset_map.to(device, non_blocking=True) 141 | weight = weight.to(device, non_blocking=True) 142 | 143 | run_refine = args.refine and (cur_iter > args.refine_iter) 144 | 145 | if run_refine: 146 | out, c_label = model(img, seg_map, label, point_list) 147 | else: 148 | out = model(img) 149 | 150 | seg_loss = criterion['seg'](out['seg'], seg_map) * args.seg_weight 151 | center_loss_1 = criterion['center'](out['center'], center_map, weight) * args.center_weight 152 | offset_loss_1 = criterion['offset'](out['offset'], offset_map, weight) * args.offset_weight 153 | 154 | center_loss_2 = center_loss_1
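        # [Editor's note -- added comment, not in the original source] center_loss_2 (above)
        # and offset_loss_2 (below) default to the pseudo-label losses and are replaced by
        # losses against the self-refined labels only once run_refine is True (the refined
        # center loss additionally requires image-level supervision, args.sup == 'cls').
        # The total computed a few lines down is therefore always
        #     loss = seg_loss + 0.5 * (center_loss_1 + center_loss_2)
        #                     + 0.5 * (offset_loss_1 + offset_loss_2)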
155 | offset_loss_2 = offset_loss_1 156 | 157 | if run_refine and args.sup == 'cls': 158 | center_loss_2 = criterion['center'](out['center'], c_label['center'], 159 | c_label['weight']) * args.center_weight 160 | 161 | if run_refine: 162 | offset_loss_2 = criterion['offset'](out['offset'], c_label['offset'], 163 | c_label['weight']) * args.offset_weight 164 | 165 | loss = seg_loss + (center_loss_1 + center_loss_2)*0.5 + (offset_loss_1 + offset_loss_2)*0.5 166 | 167 | # compute gradient and backward 168 | optimizer.zero_grad() 169 | loss.backward() 170 | optimizer.step() 171 | lr_scheduler.step() 172 | 173 | batch_time.update((time.time() - end)) 174 | end = time.time() 175 | 176 | avg_total_loss.update(loss.item(), img.size(0)) 177 | avg_seg_loss.update(seg_loss.item(), img.size(0)) 178 | avg_pseudo_center_loss.update(center_loss_1.item(), img.size(0)) 179 | avg_refine_center_loss.update(center_loss_2.item(), img.size(0)) 180 | avg_pseudo_offset_loss.update(offset_loss_1.item(), img.size(0)) 181 | avg_refine_offset_loss.update(offset_loss_2.item(), img.size(0)) 182 | 183 | if cur_iter % args.print_freq == 0: 184 | batch_time.synch(device) 185 | avg_total_loss.synch(device) 186 | avg_seg_loss.synch(device) 187 | avg_pseudo_center_loss.synch(device) 188 | avg_refine_center_loss.synch(device) 189 | avg_pseudo_offset_loss.synch(device) 190 | avg_refine_offset_loss.synch(device) 191 | 192 | if args.local_rank == 0: 193 | print('Progress: [{0}][{1}/{2}] ({3:.1f}%, {4:.1f} min) | ' 194 | 'Time: {5:.1f} ms | ' 195 | 'Left: {6:.1f} min | ' 196 | 'TotalLoss: {7:.4f} | ' 197 | 'SegLoss: {8:.4f} | ' 198 | 'centerLoss: {9:.4f} ({10:.4f} + {11:.4f}) | ' 199 | 'OffsetLoss: {12:.4f} ({13:.4f} + {14:.4f}) '.format( 200 | epoch, cur_iter, args.train_iter, 201 | cur_iter/args.train_iter*100, (end-start) / 60, 202 | batch_time.avg * 1000, (args.train_iter - cur_iter) * batch_time.avg / 60, 203 | avg_total_loss.avg, avg_seg_loss.avg, 204 | avg_pseudo_center_loss.avg + avg_refine_center_loss.avg, 205 | avg_pseudo_center_loss.avg, avg_refine_center_loss.avg, 206 | avg_pseudo_offset_loss.avg + avg_refine_offset_loss.avg, 207 | avg_pseudo_offset_loss.avg, avg_refine_offset_loss.avg, 208 | ) 209 | ) 210 | 211 | if args.local_rank == 0 and cur_iter % args.save_freq == 0: 212 | save_path = os.path.join(args.save_folder, 'last.pt') 213 | save_checkpoint(save_path, model) 214 | 215 | if cur_iter % args.val_freq == 0: 216 | val_score = validate() 217 | 218 | if args.local_rank == 0 and val_score['map'] > best_AP: 219 | best_AP = val_score['map'] 220 | print('\n Best mAP50, iteration : %d, mAP50 : %.2f \n' % (cur_iter, best_AP)) 221 | 222 | save_path = os.path.join(args.save_folder, 'best.pt') 223 | save_checkpoint(save_path, model) 224 | 225 | end = time.time() 226 | 227 | if args.local_rank == 0: 228 | print('\n training done') 229 | save_path = os.path.join(args.save_folder, 'last.pt') 230 | save_checkpoint(save_path, model) 231 | 232 | val_score = validate() 233 | 234 | if args.local_rank == 0 and val_score['map'] > best_AP: 235 | best_AP = val_score['map'] 236 | print('\n Best mAP50, iteration : %d, mAP50 : %.2f \n' % (cur_iter, best_AP)) 237 | 238 | model_file = os.path.join(args.save_folder, 'best.pt') 239 | save_checkpoint(model_file, model) 240 | 241 | 242 | def validate(): 243 | model.eval() 244 | 245 | pred_seg_maps, pred_labels, pred_masks, pred_scores = [], [], [], [] 246 | val_dir = "val_temp_dir" 247 | if args.local_rank == 0: 248 | os.makedirs(val_dir, exist_ok=True) 249 | 250 | torch.distributed.barrier()
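    # [Editor's note -- added comment, not in the original source] Each rank writes its
    # per-image predictions into val_dir as pickle files; after the loop, rank 0 reloads
    # them in ground-truth order and scores them with chainercv's
    # eval_instance_segmentation_voc at IoU 0.5. A hypothetical per-image record
    # (values and shapes are for illustration only) looks like:
    #     {'pred_label': np.array([14]),
    #      'pred_mask': np.zeros((1, h, w), dtype=bool),
    #      'pred_score': np.array([0.87])}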
251 | for img, cls_label, points, fname, tsize in tqdm(valid_loader): 252 | target_size = int(tsize[0]), int(tsize[1]) 253 | 254 | if args.val_flip: 255 | img = torch.cat( [img, img.flip(-1)] , dim=0) 256 | 257 | out = model(img.to(device), target_shape=target_size) 258 | 259 | if args.sup == 'point' and args.val_clean: 260 | pred_seg, pred_label, pred_mask, pred_score = get_ins_map_with_point(out, 261 | cls_label, 262 | points, 263 | target_size, 264 | device, 265 | args) 266 | else: 267 | pred_seg, pred_label, pred_mask, pred_score = get_ins_map(out, 268 | cls_label, 269 | target_size, 270 | device, 271 | args) 272 | 273 | with open(f'{val_dir}/{fname[0]}.pickle', 'wb') as f: 274 | pickle.dump({ 275 | 'pred_label': pred_label, 276 | 'pred_mask': pred_mask, 277 | 'pred_score': pred_score, 278 | }, f) 279 | 280 | torch.distributed.barrier() 281 | 282 | ap_result = {"ap": None, "map": None} 283 | 284 | if args.local_rank == 0: 285 | pred_masks, pred_labels, pred_scores = [], [], [] 286 | 287 | for fname in ins_gt_ids: 288 | with open(f'{val_dir}/{fname}.pickle', 'rb') as f: 289 | dat = pickle.load(f) 290 | pred_masks.append(dat['pred_mask']) 291 | pred_labels.append(dat['pred_label']) 292 | pred_scores.append(dat['pred_score']) 293 | 294 | ap_result = eval_instance_segmentation_voc(pred_masks, 295 | pred_labels, 296 | pred_scores, 297 | ins_gt_masks, 298 | ins_gt_labels, 299 | iou_thresh=0.5) 300 | 301 | #print(ap_result) 302 | os.system(f"rm -rf {val_dir}") 303 | 304 | torch.distributed.barrier() 305 | 306 | model.train() 307 | 308 | return ap_result 309 | 310 | 311 | if __name__ == '__main__': 312 | 313 | args = parse() 314 | 315 | torch.manual_seed(args.random_seed) 316 | np.random.seed(args.random_seed) 317 | random.seed(args.random_seed) 318 | 319 | torch.backends.cudnn.benchmark = True 320 | 321 | args.gpu = args.local_rank 322 | torch.cuda.set_device(args.gpu) 323 | 324 | # Init distributed system 325 | torch.distributed.init_process_group( 326 | backend="nccl", rank=args.local_rank, world_size=torch.cuda.device_count() 327 | ) 328 | args.world_size = torch.distributed.get_world_size() 329 | device = torch.device(f"cuda:{args.gpu}") 330 | 331 | if args.local_rank == 0: 332 | os.makedirs(args.save_folder, exist_ok=True) 333 | 334 | """ load model """ 335 | model = model_factory(args) 336 | model = model.to(device) 337 | 338 | optimizer = optim.Adam(model.parameters(), 339 | lr=args.lr, 340 | weight_decay=args.weight_decay) 341 | 342 | # define loss function (criterion) and optimizer 343 | criterion = {"center" : Weighted_MSELoss(), 344 | "offset" : Weighted_L1_Loss(), 345 | "seg" : DeepLabCE() 346 | } 347 | 348 | # Optionally resume from a checkpoint 349 | if args.resume: 350 | if os.path.isfile(args.resume): 351 | print_func("=> loading checkpoint '{}'".format(args.resume)) 352 | ckpt = torch.load(args.resume, map_location='cpu')['model'] 353 | model.load_state_dict(ckpt, strict=True) 354 | else: 355 | print_func("=> no checkpoint found at '{}'".format(args.resume)) 356 | 357 | 358 | """ Get data loader """ 359 | train_dataset = get_dataset(args, mode='train') 360 | valid_dataset = get_dataset(args, mode='val') 361 | print_func("number of train set = %d | valid set = %d" % (len(train_dataset), len(valid_dataset))) 362 | 363 | n_gpus = torch.cuda.device_count() 364 | batch_per_gpu = args.batch_size // n_gpus 365 | 366 | train_sampler = DistributedSampler(train_dataset, num_replicas=n_gpus, rank=args.local_rank) 367 | train_loader = DataLoader(train_dataset, 368 |
batch_size=batch_per_gpu, 369 | num_workers=args.num_workers, 370 | pin_memory=True, 371 | sampler=train_sampler, 372 | drop_last=True) 373 | 374 | valid_sampler = DistributedSampler(valid_dataset, num_replicas=n_gpus, rank=args.local_rank) 375 | valid_loader = DataLoader(valid_dataset, 376 | batch_size=1, 377 | num_workers=4, 378 | pin_memory=True, 379 | sampler=valid_sampler, 380 | drop_last=False) 381 | 382 | if args.train_epoch != 0: 383 | args.train_iter = args.train_epoch * len(train_loader) 384 | 385 | lr_scheduler = WarmupPolyLR( 386 | optimizer, 387 | args.train_iter, 388 | warmup_iters=args.warm_iter, 389 | power=args.gamma, 390 | ) 391 | 392 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 393 | model = DDP(model, device_ids=[args.gpu]) 394 | 395 | if args.val_freq != 0 and args.local_rank == 0: 396 | print("...Preparing GT dataset for evaluation") 397 | ins_dataset = VOCInstanceSegmentationDataset(split='val', data_dir=args.root_dir) 398 | 399 | ins_gt_ids = ins_dataset.ids 400 | ins_gt_masks = [ins_dataset.get_example_by_keys(i, (1,))[0] for i in range(len(ins_dataset))] 401 | ins_gt_labels = [ins_dataset.get_example_by_keys(i, (2,))[0] for i in range(len(ins_dataset))] 402 | 403 | torch.distributed.barrier() 404 | print_func("...Training Start \n") 405 | print_func(args) 406 | 407 | #validate() 408 | train() 409 | -------------------------------------------------------------------------------- /utils/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | import cv2 4 | import math 5 | import numpy as np 6 | from torchvision.transforms import functional as F 7 | import torch 8 | 9 | class Compose(object): 10 | """ 11 | Composes a sequence of transforms. 12 | Arguments: 13 | transforms: A list of transforms. 14 | """ 15 | def __init__(self, transforms): 16 | self.transforms = transforms 17 | 18 | def __call__(self, image, seg_map, peak): 19 | for t in self.transforms: 20 | image, seg_map, peak = t(image, seg_map, peak) 21 | return image, seg_map, peak 22 | 23 | def __repr__(self): 24 | format_string = self.__class__.__name__ + "(" 25 | for t in self.transforms: 26 | format_string += "\n" 27 | format_string += " {0}".format(t) 28 | format_string += "\n)" 29 | return format_string 30 | 31 | 32 | class ToTensor(object): 33 | """ 34 | Converts image to torch Tensor. 35 | """ 36 | def __call__(self, image, seg_map, peak): 37 | 38 | image = F.to_tensor(image) 39 | 40 | seg_map = np.array(seg_map, dtype=np.uint8, copy=True) 41 | 42 | seg_map = torch.from_numpy(seg_map) 43 | 44 | return image, seg_map, peak 45 | 46 | 47 | class Normalize(object): 48 | """Normalize an tensor image with mean and standard deviation. 49 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 50 | will normalize each channel of the input ``torch.*Tensor`` i.e. 51 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 52 | 53 | Args: 54 | mean (sequence): Sequence of means for each channel. 55 | std (sequence): Sequence of standard deviations for each channel. 56 | """ 57 | 58 | def __init__(self, mean, std): 59 | self.mean = mean 60 | self.std = std 61 | 62 | def __call__(self, image, seg_map, peak): 63 | """ 64 | Args: 65 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 66 | 67 | Returns: 68 | Tensor: Normalized Tensor image. 
69 | """ 70 | image = F.normalize(image, self.mean, self.std) 71 | 72 | peak = [ [int(x), int(y), cls, conf] for x, y, cls, conf in peak] 73 | 74 | return image, seg_map, peak 75 | 76 | 77 | class RandomScale(object): 78 | """ 79 | Applies random scale augmentation. 80 | Arguments: 81 | min_scale: Minimum scale value. 82 | max_scale: Maximum scale value. 83 | scale_step_size: The step size from minimum to maximum value. 84 | """ 85 | def __init__(self, min_scale, max_scale, scale_step_size): 86 | self.min_scale = min_scale 87 | self.max_scale = max_scale 88 | self.scale_step_size = scale_step_size 89 | 90 | @staticmethod 91 | def get_random_scale(min_scale_factor, max_scale_factor, step_size): 92 | """Gets a random scale value. 93 | Args: 94 | min_scale_factor: Minimum scale value. 95 | max_scale_factor: Maximum scale value. 96 | step_size: The step size from minimum to maximum value. 97 | Returns: 98 | A random scale value selected between minimum and maximum value. 99 | Raises: 100 | ValueError: min_scale_factor has unexpected value. 101 | """ 102 | if min_scale_factor < 0 or min_scale_factor > max_scale_factor: 103 | raise ValueError('Unexpected value of min_scale_factor.') 104 | 105 | if min_scale_factor == max_scale_factor: 106 | return min_scale_factor 107 | 108 | # When step_size = 0, we sample the value uniformly from [min, max). 109 | if step_size == 0: 110 | return random.uniform(min_scale_factor, max_scale_factor) 111 | 112 | # When step_size != 0, we randomly select one discrete value from [min, max]. 113 | num_steps = int((max_scale_factor - min_scale_factor) / step_size + 1) 114 | scale_factors = np.linspace(min_scale_factor, max_scale_factor, num_steps) 115 | np.random.shuffle(scale_factors) 116 | return scale_factors[0] 117 | 118 | #def __call__(self, image, label): 119 | def __call__(self, image, seg_map, peak): 120 | f_scale = self.get_random_scale(self.min_scale, self.max_scale, self.scale_step_size) 121 | # TODO: cv2 uses align_corner=False 122 | # TODO: use fvcore (https://github.com/facebookresearch/fvcore/blob/master/fvcore/transforms/transform.py#L377) 123 | image_dtype = image.dtype 124 | seg_map_dtype = seg_map.dtype 125 | 126 | image = cv2.resize(image.astype(np.float), None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_LINEAR) 127 | 128 | seg_map = cv2.resize(seg_map.astype(np.float), None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_NEAREST) 129 | 130 | peak = [ [p[0]*f_scale, p[1]*f_scale, p[2], p[3]] for p in peak ] 131 | 132 | return image.astype(image_dtype), seg_map.astype(seg_map_dtype), peak 133 | 134 | 135 | class RandomCrop(object): 136 | """ 137 | Applies random crop augmentation. 138 | Arguments: 139 | crop_h: Integer, crop height size. 140 | crop_w: Integer, crop width size. 141 | pad_value: Tuple, pad value for image, length 3. 142 | ignore_label: Tuple, pad value for label, length could be 1 (semantic) or 3 (panoptic). 143 | random_pad: Bool, when crop size larger than image size, whether to randomly pad four boundaries, 144 | or put image to top-left and only pad bottom and right boundaries. 
145 | """ 146 | def __init__(self, crop_h, crop_w, pad_value, ignore_label, random_pad): 147 | self.crop_h = crop_h 148 | self.crop_w = crop_w 149 | self.pad_value = pad_value 150 | self.ignore_label = ignore_label 151 | self.random_pad = random_pad 152 | 153 | def __call__(self, image, seg_map, peak): 154 | img_h, img_w = image.shape[0], image.shape[1] 155 | # save dtype 156 | image_dtype = image.dtype 157 | seg_map_dtype = seg_map.dtype 158 | 159 | # padding 160 | pad_h = max(self.crop_h - img_h, 0) 161 | pad_w = max(self.crop_w - img_w, 0) 162 | if pad_h > 0 or pad_w > 0: 163 | if self.random_pad: 164 | pad_top = random.randint(0, pad_h) 165 | pad_bottom = pad_h - pad_top 166 | pad_left = random.randint(0, pad_w) 167 | pad_right = pad_w - pad_left 168 | else: 169 | pad_top, pad_bottom, pad_left, pad_right = 0, pad_h, 0, pad_w 170 | 171 | img_pad = cv2.copyMakeBorder(image, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, 172 | value=self.pad_value) 173 | seg_map_pad = cv2.copyMakeBorder(seg_map, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, 174 | value=self.ignore_label) 175 | 176 | peak = [ [px+pad_left, py+pad_top, cls, conf] for px, py, cls, conf in peak ] 177 | 178 | else: 179 | img_pad, seg_map_pad = image, seg_map 180 | 181 | img_h, img_w = img_pad.shape[0], img_pad.shape[1] 182 | h_off = random.randint(0, img_h - self.crop_h) 183 | w_off = random.randint(0, img_w - self.crop_w) 184 | 185 | image = np.asarray(img_pad[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w], np.float32) 186 | seg_map = np.asarray(seg_map_pad[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w], np.float32) 187 | 188 | peak_crop = [] 189 | for px, py, cls, conf in peak: 190 | if (h_off <= py < h_off + self.crop_h) and (w_off <= px < w_off + self.crop_w): 191 | px_crop, py_crop = px-w_off, py-h_off 192 | 193 | #if (0 <= px_crop < self.crop_w) and (0 <= py_crop < self.crop_h): 194 | peak_crop.append( [px_crop, py_crop, cls, conf] ) 195 | 196 | return image.astype(image_dtype), seg_map.astype(seg_map_dtype), peak_crop 197 | 198 | 199 | class RandomHorizontalFlip(object): 200 | """Horizontally flip the given PIL Image randomly with a probability of 0.5.""" 201 | 202 | def __call__(self, image, seg_map, peak): 203 | 204 | if random.random() < 0.5: 205 | image = torch.flip(image, [2]) 206 | seg_map = torch.flip(seg_map, [1]) 207 | 208 | _, H, W = image.shape 209 | 210 | peak = [ [W-p[0]-1, p[1], p[2], p[3]] for p in peak ] 211 | 212 | return image, seg_map, peak 213 | 214 | 215 | class Resize(object): 216 | """ 217 | Applies random scale augmentation. 218 | Reference: https://github.com/tensorflow/models/blob/master/research/deeplab/input_preprocess.py#L28 219 | Arguments: 220 | min_resize_value: Desired size of the smaller image side, no resize if set to None 221 | max_resize_value: Maximum allowed size of the larger image side, no limit if set to None 222 | resize_factor: Resized dimensions are multiple of factor plus one. 223 | keep_aspect_ratio: Boolean, keep aspect ratio or not. If True, the input 224 | will be resized while keeping the original aspect ratio. If False, the 225 | input will be resized to [max_resize_value, max_resize_value] without 226 | keeping the original aspect ratio. 227 | align_corners: If True, exactly align all 4 corners of input and output. 
228 | """ 229 | def __init__(self, min_resize_value=None, max_resize_value=None, resize_factor=None, 230 | keep_aspect_ratio=True, align_corners=False): 231 | if min_resize_value is not None and min_resize_value < 0: 232 | min_resize_value = None 233 | if max_resize_value is not None and max_resize_value < 0: 234 | max_resize_value = None 235 | if resize_factor is not None and resize_factor < 0: 236 | resize_factor = None 237 | self.min_resize_value = min_resize_value 238 | self.max_resize_value = max_resize_value 239 | self.resize_factor = resize_factor 240 | self.keep_aspect_ratio = keep_aspect_ratio 241 | self.align_corners = align_corners 242 | 243 | if self.align_corners: 244 | warnings.warn('`align_corners = True` is not supported by opencv.') 245 | 246 | if self.max_resize_value is not None: 247 | # Modify the max_size to be a multiple of factor plus 1 and make sure the max dimension after resizing 248 | # is no larger than max_size. 249 | if self.resize_factor is not None: 250 | self.max_resize_value = (self.max_resize_value - (self.max_resize_value - 1) % self.resize_factor) 251 | 252 | def __call__(self, image, seg_map, peak): 253 | if self.min_resize_value is None: 254 | return image, label 255 | [orig_height, orig_width, _] = image.shape 256 | orig_min_size = np.minimum(orig_height, orig_width) 257 | 258 | # Calculate the larger of the possible sizes 259 | large_scale_factor = self.min_resize_value / orig_min_size 260 | large_height = int(math.floor(orig_height * large_scale_factor)) 261 | large_width = int(math.floor(orig_width * large_scale_factor)) 262 | large_size = np.array([large_height, large_width]) 263 | 264 | new_size = large_size 265 | if self.max_resize_value is not None: 266 | # Calculate the smaller of the possible sizes, use that if the larger is too big. 267 | orig_max_size = np.maximum(orig_height, orig_width) 268 | small_scale_factor = self.max_resize_value / orig_max_size 269 | small_height = int(math.floor(orig_height * small_scale_factor)) 270 | small_width = int(math.floor(orig_width * small_scale_factor)) 271 | small_size = np.array([small_height, small_width]) 272 | 273 | if np.max(large_size) > self.max_resize_value: 274 | new_size = small_size 275 | 276 | # Ensure that both output sides are multiples of factor plus one. 277 | if self.resize_factor is not None: 278 | new_size += (self.resize_factor - (new_size - 1) % self.resize_factor) % self.resize_factor 279 | # If new_size exceeds largest allowed size 280 | new_size[new_size > self.max_resize_value] -= self.resize_factor 281 | 282 | if not self.keep_aspect_ratio: 283 | # If not keep the aspect ratio, we resize everything to max_size, allowing 284 | # us to do pre-processing without extra padding. 
285 | new_size = [np.max(new_size), np.max(new_size)] 286 | 287 | # TODO: cv2 uses align_corner=False 288 | # TODO: use fvcore (https://github.com/facebookresearch/fvcore/blob/master/fvcore/transforms/transform.py#L377) 289 | image_dtype = image.dtype 290 | seg_map_dtype = seg_map.dtype 291 | 292 | image = cv2.resize(image.astype(np.float), (new_size[1], new_size[0]), interpolation=cv2.INTER_LINEAR) 293 | seg_map = cv2.resize(seg_map.astype(np.float), (new_size[1], new_size[0]), interpolation=cv2.INTER_NEAREST) 294 | 295 | peak = [ [int(p[0]/orig_width*new_size[1]), int(p[1]/orig_height*new_size[0]), p[2], p[3]] for p in peak ] 296 | 297 | return image.astype(image_dtype), seg_map.astype(seg_map_dtype), peak 298 | 299 | 300 | 301 | class RandomContrast(object): 302 | def __init__(self, lower=0.5, upper=1.5): 303 | self.lower = lower 304 | self.upper = upper 305 | assert self.upper >= self.lower, "contrast upper must be >= lower." 306 | assert self.lower >= 0, "contrast lower must be non-negative." 307 | 308 | # expects float image 309 | def __call__(self, image, seg_map, peak): 310 | if random.random() < 0.5: 311 | alpha = random.uniform(self.lower, self.upper) 312 | image *= alpha 313 | image[image>255] = 255 314 | image[image<0] = 0 315 | 316 | return image, seg_map, peak 317 | 318 | class RandomBrightness(object): 319 | def __init__(self, delta=16): 320 | assert delta >= 0.0 321 | assert delta <= 255.0 322 | self.delta = delta 323 | 324 | def __call__(self, image, seg_map, peak): 325 | if random.random() < 0.5: 326 | delta = random.uniform(-self.delta, self.delta) 327 | image += delta 328 | image[image>255] = 255 329 | image[image<0] = 0 330 | 331 | return image, seg_map, peak 332 | 333 | class RandomHue(object): 334 | def __init__(self, delta=36.0): 335 | assert delta >= 0.0 and delta <= 360.0 336 | self.delta = delta 337 | 338 | def __call__(self, image, seg_map, peak): 339 | if random.random() < 0.5: 340 | image[:, :, 0] += random.uniform(-self.delta, self.delta) 341 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 342 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 343 | return image, seg_map, peak 344 | 345 | 346 | class RandomLightingNoise(object): 347 | def __init__(self): 348 | self.perms = ((0, 1, 2), (0, 2, 1), 349 | (1, 0, 2), (1, 2, 0), 350 | (2, 0, 1), (2, 1, 0)) 351 | 352 | def __call__(self, image, seg_map, peak): 353 | if random.random() < 0.5: 354 | swap = self.perms[random.randint(len(self.perms))] 355 | shuffle = SwapChannels(swap) # shuffle channels 356 | image = shuffle(image) 357 | return image, seg_map, peak 358 | 359 | 360 | class ConvertColor(object): 361 | def __init__(self, current='BGR', transform='HSV'): 362 | self.transform = transform 363 | self.current = current 364 | 365 | def __call__(self, image, seg_map, peak): 366 | if self.current == 'RGB' and self.transform == 'HSV': 367 | image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) 368 | elif self.current == 'HSV' and self.transform == 'RGB': 369 | image = cv2.cvtColor(image, cv2.COLOR_HSV2RGB) 370 | else: 371 | raise NotImplementedError 372 | return image, seg_map, peak 373 | 374 | class PhotometricDistort(object): 375 | def __init__(self): 376 | self.pd = [ 377 | RandomContrast(), 378 | ConvertColor(current='RGB', transform='HSV'), 379 | RandomHue(), 380 | ConvertColor(current='HSV', transform='RGB'), 381 | RandomContrast() 382 | ] 383 | self.rand_brightness = RandomBrightness() 384 | self.rand_light_noise = RandomLightingNoise() 385 | 386 | def __call__(self, image, seg_map, peak): 387 | #m = image.copy() 
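        # [Editor's note -- added comment, not in the original source] PhotometricDistort
        # operates on an RGB numpy image: it is cast to float32, passed through
        # RandomBrightness, then through either self.pd[:-1] or self.pd[1:] (chosen at
        # random) so only one of the two RandomContrast ops can fire in a given pass, and
        # the result is clipped back to uint8. A hypothetical call (variable names are
        # illustrative only):
        #     aug = PhotometricDistort()
        #     img_aug, seg_map_out, peak_out = aug(img_rgb_uint8, seg_map, peak)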
388 | im = image.copy().astype(np.float32) 389 | 390 | im, seg_map, peak = self.rand_brightness(im, seg_map, peak) 391 | 392 | if random.random() < 0.5: 393 | distort = Compose(self.pd[:-1]) 394 | else: 395 | distort = Compose(self.pd[1:]) 396 | 397 | im, seg_map, peak = distort(im, seg_map, peak) 398 | # im, boxes, labels = self.rand_light_noise(im, boxes, labels) 399 | 400 | im = np.clip(im, 0, 255).astype(np.uint8) 401 | 402 | return im, seg_map, peak 403 | 404 | 405 | class Keepsize(object): 406 | """Resize the input PIL Image to the given size. 407 | 408 | Args: 409 | size (sequence or int): Desired output size. If size is a sequence like 410 | (h, w), output size will be matched to this. If size is an int, 411 | smaller edge of the image will be matched to this number. 412 | i.e, if height > width, then image will be rescaled to 413 | (size * height / width, size) 414 | interpolation (int, optional): Desired interpolation. Default is 415 | ``PIL.Image.BILINEAR`` 416 | """ 417 | 418 | def __init__(self, shape=None): 419 | self.shape = shape 420 | 421 | pass 422 | 423 | def __call__(self, img, seg_map, peak): 424 | """ 425 | Args: 426 | img (PIL Image): Image to be scaled. 427 | 428 | Returns: 429 | PIL Image: Rescaled image. 430 | """ 431 | ori_h, ori_w, _ = img.shape 432 | 433 | if self.shape is None: 434 | new_h = (ori_h + 31) // 32 * 32 435 | new_w = (ori_w + 31) // 32 * 32 436 | else: 437 | new_h = self.shape[1] 438 | new_w = self.shape[0] 439 | 440 | img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR) 441 | seg_map = cv2.resize(seg_map, (new_w, new_h), interpolation=cv2.INTER_NEAREST) 442 | 443 | peak = [ [int(p[0]/ori_w*new_w), int(p[1]/ori_h*new_h), p[2], p[3]] for p in peak ] 444 | 445 | return img, seg_map, peak 446 | -------------------------------------------------------------------------------- /PAM/utils/transforms/functional.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/utils/transforms/functional.py 3 | # ------------------------------------------------------------------------------ 4 | 5 | from __future__ import division 6 | import torch 7 | import math 8 | import random 9 | from PIL import Image, ImageOps, ImageEnhance 10 | try: 11 | import accimage 12 | except ImportError: 13 | accimage = None 14 | import numpy as np 15 | import numbers 16 | import types 17 | import collections 18 | import warnings 19 | 20 | 21 | def _is_pil_image(img): 22 | if accimage is not None: 23 | return isinstance(img, (Image.Image, accimage.Image)) 24 | else: 25 | return isinstance(img, Image.Image) 26 | 27 | 28 | def _is_tensor_image(img): 29 | return torch.is_tensor(img) and img.ndimension() == 3 30 | 31 | 32 | def _is_numpy_image(img): 33 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 34 | 35 | 36 | def to_tensor(pic): 37 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 38 | 39 | See ``ToTensor`` for more details. 40 | 41 | Args: 42 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 43 | 44 | Returns: 45 | Tensor: Converted image. 46 | """ 47 | if not(_is_pil_image(pic) or _is_numpy_image(pic)): 48 | raise TypeError('pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) 49 | 50 | if isinstance(pic, np.ndarray): 51 | # handle numpy array 52 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 53 | # backward compatibility 54 | return img.float().div(255) 55 | # return img.float() 56 | 57 | if accimage is not None and isinstance(pic, accimage.Image): 58 | nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) 59 | pic.copyto(nppic) 60 | return torch.from_numpy(nppic) 61 | 62 | # handle PIL Image 63 | if pic.mode == 'I': 64 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 65 | elif pic.mode == 'I;16': 66 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 67 | else: 68 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 69 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 70 | if pic.mode == 'YCbCr': 71 | nchannel = 3 72 | elif pic.mode == 'I;16': 73 | nchannel = 1 74 | else: 75 | nchannel = len(pic.mode) 76 | img = img.view(pic.size[1], pic.size[0], nchannel) 77 | # put it from HWC to CHW format 78 | # yikes, this transpose takes 80% of the loading time/CPU 79 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 80 | if isinstance(img, torch.ByteTensor): 81 | return img.float().div(255) 82 | # return img.float() 83 | else: 84 | return img 85 | 86 | 87 | def to_pil_image(pic, mode=None): 88 | """Convert a tensor or an ndarray to PIL Image. 89 | 90 | See :class:`~torchvision.transforms.ToPIlImage` for more details. 91 | 92 | Args: 93 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 94 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 95 | 96 | .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 97 | 98 | Returns: 99 | PIL Image: Image converted to PIL Image. 100 | """ 101 | if not(_is_numpy_image(pic) or _is_tensor_image(pic)): 102 | raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) 103 | 104 | npimg = pic 105 | if isinstance(pic, torch.FloatTensor): 106 | pic = pic.mul(255).byte() 107 | if torch.is_tensor(pic): 108 | npimg = np.transpose(pic.numpy(), (1, 2, 0)) 109 | 110 | if not isinstance(npimg, np.ndarray): 111 | raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 112 | 'not {}'.format(type(npimg))) 113 | 114 | if npimg.shape[2] == 1: 115 | expected_mode = None 116 | npimg = npimg[:, :, 0] 117 | if npimg.dtype == np.uint8: 118 | expected_mode = 'L' 119 | if npimg.dtype == np.int16: 120 | expected_mode = 'I;16' 121 | if npimg.dtype == np.int32: 122 | expected_mode = 'I' 123 | elif npimg.dtype == np.float32: 124 | expected_mode = 'F' 125 | if mode is not None and mode != expected_mode: 126 | raise ValueError("Incorrect mode ({}) supplied for input type {}. 
Should be {}" 127 | .format(mode, np.dtype, expected_mode)) 128 | mode = expected_mode 129 | 130 | elif npimg.shape[2] == 4: 131 | permitted_4_channel_modes = ['RGBA', 'CMYK'] 132 | if mode is not None and mode not in permitted_4_channel_modes: 133 | raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) 134 | 135 | if mode is None and npimg.dtype == np.uint8: 136 | mode = 'RGBA' 137 | else: 138 | permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] 139 | if mode is not None and mode not in permitted_3_channel_modes: 140 | raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) 141 | if mode is None and npimg.dtype == np.uint8: 142 | mode = 'RGB' 143 | 144 | if mode is None: 145 | raise TypeError('Input type {} is not supported'.format(npimg.dtype)) 146 | 147 | return Image.fromarray(npimg, mode=mode) 148 | 149 | 150 | def normalize(tensor, mean, std): 151 | """Normalize a tensor image with mean and standard deviation. 152 | 153 | See ``Normalize`` for more details. 154 | 155 | Args: 156 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 157 | mean (sequence): Sequence of means for each channel. 158 | std (sequence): Sequence of standard deviations for each channely. 159 | 160 | Returns: 161 | Tensor: Normalized Tensor image. 162 | """ 163 | if not _is_tensor_image(tensor): 164 | raise TypeError('tensor is not a torch image.') 165 | # TODO: make efficient 166 | for t, m, s in zip(tensor, mean, std): 167 | t.sub_(m).div_(s) 168 | return tensor 169 | 170 | 171 | def resize(img, size, interpolation=Image.BILINEAR): 172 | """Resize the input PIL Image to the given size. 173 | 174 | Args: 175 | img (PIL Image): Image to be resized. 176 | size (sequence or int): Desired output size. If size is a sequence like 177 | (h, w), the output size will be matched to this. If size is an int, 178 | the smaller edge of the image will be matched to this number maintaing 179 | the aspect ratio. i.e, if height > width, then image will be rescaled to 180 | (size * height / width, size) 181 | interpolation (int, optional): Desired interpolation. Default is 182 | ``PIL.Image.BILINEAR`` 183 | 184 | Returns: 185 | PIL Image: Resized image. 186 | """ 187 | if not _is_pil_image(img): 188 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 189 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 190 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 191 | 192 | if isinstance(size, int): 193 | w, h = img.size 194 | if (w <= h and w == size) or (h <= w and h == size): 195 | return img 196 | if w < h: 197 | ow = size 198 | oh = int(size * h / w) 199 | return img.resize((ow, oh), interpolation) 200 | else: 201 | oh = size 202 | ow = int(size * w / h) 203 | return img.resize((ow, oh), interpolation) 204 | else: 205 | return img.resize(size[::-1], interpolation) 206 | 207 | 208 | def scale(*args, **kwargs): 209 | warnings.warn("The use of the transforms.Scale transform is deprecated, " + 210 | "please use transforms.Resize instead.") 211 | return resize(*args, **kwargs) 212 | 213 | 214 | def pad(img, padding, fill=0): 215 | """Pad the given PIL Image on all sides with the given "pad" value. 216 | 217 | Args: 218 | img (PIL Image): Image to be padded. 219 | padding (int or tuple): Padding on each border. If a single int is provided this 220 | is used to pad all borders. 
If tuple of length 2 is provided this is the padding 221 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 222 | this is the padding for the left, top, right and bottom borders 223 | respectively. 224 | fill: Pixel fill value. Default is 0. If a tuple of 225 | length 3, it is used to fill R, G, B channels respectively. 226 | 227 | Returns: 228 | PIL Image: Padded image. 229 | """ 230 | if not _is_pil_image(img): 231 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 232 | 233 | if not isinstance(padding, (numbers.Number, tuple)): 234 | raise TypeError('Got inappropriate padding arg') 235 | if not isinstance(fill, (numbers.Number, str, tuple)): 236 | raise TypeError('Got inappropriate fill arg') 237 | 238 | if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: 239 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 240 | "{} element tuple".format(len(padding))) 241 | 242 | return ImageOps.expand(img, border=padding, fill=fill) 243 | 244 | 245 | def crop(img, i, j, h, w): 246 | """Crop the given PIL Image. 247 | 248 | Args: 249 | img (PIL Image): Image to be cropped. 250 | i: Upper pixel coordinate. 251 | j: Left pixel coordinate. 252 | h: Height of the cropped image. 253 | w: Width of the cropped image. 254 | 255 | Returns: 256 | PIL Image: Cropped image. 257 | """ 258 | if not _is_pil_image(img): 259 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 260 | 261 | return img.crop((j, i, j + w, i + h)) 262 | 263 | 264 | def center_crop(img, output_size): 265 | if isinstance(output_size, numbers.Number): 266 | output_size = (int(output_size), int(output_size)) 267 | w, h = img.size 268 | th, tw = output_size 269 | i = int(round((h - th) / 2.)) 270 | j = int(round((w - tw) / 2.)) 271 | return crop(img, i, j, th, tw) 272 | 273 | 274 | def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): 275 | """Crop the given PIL Image and resize it to desired size. 276 | 277 | Notably used in RandomResizedCrop. 278 | 279 | Args: 280 | img (PIL Image): Image to be cropped. 281 | i: Upper pixel coordinate. 282 | j: Left pixel coordinate. 283 | h: Height of the cropped image. 284 | w: Width of the cropped image. 285 | size (sequence or int): Desired output size. Same semantics as ``scale``. 286 | interpolation (int, optional): Desired interpolation. Default is 287 | ``PIL.Image.BILINEAR``. 288 | Returns: 289 | PIL Image: Cropped image. 290 | """ 291 | assert _is_pil_image(img), 'img should be PIL Image' 292 | img = crop(img, i, j, h, w) 293 | img = resize(img, size, interpolation) 294 | return img 295 | 296 | 297 | def hflip(img): 298 | """Horizontally flip the given PIL Image. 299 | 300 | Args: 301 | img (PIL Image): Image to be flipped. 302 | 303 | Returns: 304 | PIL Image: Horizontall flipped image. 305 | """ 306 | if not _is_pil_image(img): 307 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 308 | 309 | return img.transpose(Image.FLIP_LEFT_RIGHT) 310 | 311 | 312 | def vflip(img): 313 | """Vertically flip the given PIL Image. 314 | 315 | Args: 316 | img (PIL Image): Image to be flipped. 317 | 318 | Returns: 319 | PIL Image: Vertically flipped image. 320 | """ 321 | if not _is_pil_image(img): 322 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 323 | 324 | return img.transpose(Image.FLIP_TOP_BOTTOM) 325 | 326 | 327 | def five_crop(img, size): 328 | """Crop the given PIL Image into four corners and the central crop. 329 | 330 | .. 
Note:: 331 | This transform returns a tuple of images and there may be a 332 | mismatch in the number of inputs and targets your ``Dataset`` returns. 333 | 334 | Args: 335 | size (sequence or int): Desired output size of the crop. If size is an 336 | int instead of sequence like (h, w), a square crop (size, size) is 337 | made. 338 | Returns: 339 | tuple: tuple (tl, tr, bl, br, center) corresponding top left, 340 | top right, bottom left, bottom right and center crop. 341 | """ 342 | if isinstance(size, numbers.Number): 343 | size = (int(size), int(size)) 344 | else: 345 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 346 | 347 | w, h = img.size 348 | crop_h, crop_w = size 349 | if crop_w > w or crop_h > h: 350 | raise ValueError("Requested crop size {} is bigger than input size {}".format(size, 351 | (h, w))) 352 | tl = img.crop((0, 0, crop_w, crop_h)) 353 | tr = img.crop((w - crop_w, 0, w, crop_h)) 354 | bl = img.crop((0, h - crop_h, crop_w, h)) 355 | br = img.crop((w - crop_w, h - crop_h, w, h)) 356 | center = center_crop(img, (crop_h, crop_w)) 357 | return (tl, tr, bl, br, center) 358 | 359 | 360 | def ten_crop(img, size, vertical_flip=False): 361 | """Crop the given PIL Image into four corners and the central crop plus the 362 | flipped version of these (horizontal flipping is used by default). 363 | 364 | .. Note:: 365 | This transform returns a tuple of images and there may be a 366 | mismatch in the number of inputs and targets your ``Dataset`` returns. 367 | 368 | Args: 369 | size (sequence or int): Desired output size of the crop. If size is an 370 | int instead of sequence like (h, w), a square crop (size, size) is 371 | made. 372 | vertical_flip (bool): Use vertical flipping instead of horizontal 373 | 374 | Returns: 375 | tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, 376 | br_flip, center_flip) corresponding top left, top right, 377 | bottom left, bottom right and center crop and same for the 378 | flipped image. 379 | """ 380 | if isinstance(size, numbers.Number): 381 | size = (int(size), int(size)) 382 | else: 383 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 384 | 385 | first_five = five_crop(img, size) 386 | 387 | if vertical_flip: 388 | img = vflip(img) 389 | else: 390 | img = hflip(img) 391 | 392 | second_five = five_crop(img, size) 393 | return first_five + second_five 394 | 395 | 396 | def adjust_brightness(img, brightness_factor): 397 | """Adjust brightness of an Image. 398 | 399 | Args: 400 | img (PIL Image): PIL Image to be adjusted. 401 | brightness_factor (float): How much to adjust the brightness. Can be 402 | any non negative number. 0 gives a black image, 1 gives the 403 | original image while 2 increases the brightness by a factor of 2. 404 | 405 | Returns: 406 | PIL Image: Brightness adjusted image. 407 | """ 408 | if not _is_pil_image(img): 409 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 410 | 411 | enhancer = ImageEnhance.Brightness(img) 412 | img = enhancer.enhance(brightness_factor) 413 | return img 414 | 415 | 416 | def adjust_contrast(img, contrast_factor): 417 | """Adjust contrast of an Image. 418 | 419 | Args: 420 | img (PIL Image): PIL Image to be adjusted. 421 | contrast_factor (float): How much to adjust the contrast. Can be any 422 | non negative number. 0 gives a solid gray image, 1 gives the 423 | original image while 2 increases the contrast by a factor of 2. 424 | 425 | Returns: 426 | PIL Image: Contrast adjusted image. 
427 | """ 428 | if not _is_pil_image(img): 429 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 430 | 431 | enhancer = ImageEnhance.Contrast(img) 432 | img = enhancer.enhance(contrast_factor) 433 | return img 434 | 435 | 436 | def adjust_saturation(img, saturation_factor): 437 | """Adjust color saturation of an image. 438 | 439 | Args: 440 | img (PIL Image): PIL Image to be adjusted. 441 | saturation_factor (float): How much to adjust the saturation. 0 will 442 | give a black and white image, 1 will give the original image while 443 | 2 will enhance the saturation by a factor of 2. 444 | 445 | Returns: 446 | PIL Image: Saturation adjusted image. 447 | """ 448 | if not _is_pil_image(img): 449 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 450 | 451 | enhancer = ImageEnhance.Color(img) 452 | img = enhancer.enhance(saturation_factor) 453 | return img 454 | 455 | 456 | def adjust_hue(img, hue_factor): 457 | """Adjust hue of an image. 458 | 459 | The image hue is adjusted by converting the image to HSV and 460 | cyclically shifting the intensities in the hue channel (H). 461 | The image is then converted back to original image mode. 462 | 463 | `hue_factor` is the amount of shift in H channel and must be in the 464 | interval `[-0.5, 0.5]`. 465 | 466 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 467 | 468 | Args: 469 | img (PIL Image): PIL Image to be adjusted. 470 | hue_factor (float): How much to shift the hue channel. Should be in 471 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 472 | HSV space in positive and negative direction respectively. 473 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 474 | with complementary colors while 0 gives the original image. 475 | 476 | Returns: 477 | PIL Image: Hue adjusted image. 478 | """ 479 | if not(-0.5 <= hue_factor <= 0.5): 480 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 481 | 482 | if not _is_pil_image(img): 483 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 484 | 485 | input_mode = img.mode 486 | if input_mode in {'L', '1', 'I', 'F'}: 487 | return img 488 | 489 | h, s, v = img.convert('HSV').split() 490 | 491 | np_h = np.array(h, dtype=np.uint8) 492 | # uint8 addition take cares of rotation across boundaries 493 | with np.errstate(over='ignore'): 494 | np_h += np.uint8(hue_factor * 255) 495 | h = Image.fromarray(np_h, 'L') 496 | 497 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 498 | return img 499 | 500 | 501 | def adjust_gamma(img, gamma, gain=1): 502 | """Perform gamma correction on an image. 503 | 504 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 505 | based on the following equation: 506 | 507 | I_out = 255 * gain * ((I_in / 255) ** gamma) 508 | 509 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 510 | 511 | Args: 512 | img (PIL Image): PIL Image to be adjusted. 513 | gamma (float): Non negative real number. gamma larger than 1 make the 514 | shadows darker, while gamma smaller than 1 make dark regions 515 | lighter. 516 | gain (float): The constant multiplier. 517 | """ 518 | if not _is_pil_image(img): 519 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 520 | 521 | if gamma < 0: 522 | raise ValueError('Gamma should be a non-negative real number') 523 | 524 | input_mode = img.mode 525 | img = img.convert('RGB') 526 | 527 | np_img = np.array(img, dtype=np.float32) 528 | np_img = 255 * gain * ((np_img / 255) ** gamma) 529 | np_img = np.uint8(np.clip(np_img, 0, 255)) 530 | 531 | img = Image.fromarray(np_img, 'RGB').convert(input_mode) 532 | return img 533 | 534 | 535 | def rotate(img, angle, resample=False, expand=False, center=None): 536 | """Rotate the image by angle and then (optionally) translate it by (n_columns, n_rows) 537 | 538 | 539 | Args: 540 | img (PIL Image): PIL Image to be rotated. 541 | angle ({float, int}): In degrees degrees counter clockwise order. 542 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 543 | An optional resampling filter. 544 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 545 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 546 | expand (bool, optional): Optional expansion flag. 547 | If true, expands the output image to make it large enough to hold the entire rotated image. 548 | If false or omitted, make the output image the same size as the input image. 549 | Note that the expand flag assumes rotation around the center and no translation. 550 | center (2-tuple, optional): Optional center of rotation. 551 | Origin is the upper left corner. 552 | Default is the center of the image. 553 | """ 554 | 555 | if not _is_pil_image(img): 556 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 557 | 558 | return img.rotate(angle, resample, expand, center) 559 | 560 | 561 | def to_grayscale(img, num_output_channels=1): 562 | """Convert image to grayscale version of image. 563 | 564 | Args: 565 | img (PIL Image): Image to be converted to grayscale. 566 | 567 | Returns: 568 | PIL Image: Grayscale version of the image. 569 | if num_output_channels == 1 : returned image is single channel 570 | if num_output_channels == 3 : returned image is 3 channel with r == g == b 571 | """ 572 | if not _is_pil_image(img): 573 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 574 | 575 | if num_output_channels == 1: 576 | img = img.convert('L') 577 | elif num_output_channels == 3: 578 | img = img.convert('L') 579 | np_img = np.array(img, dtype=np.uint8) 580 | np_img = np.dstack([np_img, np_img, np_img]) 581 | img = Image.fromarray(np_img, 'RGB') 582 | else: 583 | raise ValueError('num_output_channels should be either 1 or 3') 584 | 585 | return img -------------------------------------------------------------------------------- /models/hrnet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import logging 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from models.hrnet_config import cfg 18 | from models.hrnet_config import update_config 19 | 20 | from utils.utils import refine_label_generation, refine_label_generation_with_point 21 | 22 | BN_MOMENTUM = 0.1 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None): 36 | super(BasicBlock, self).__init__() 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3(planes, planes) 41 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x) 57 | 58 | out += residual 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | 67 | def __init__(self, inplanes, planes, stride=1, downsample=None): 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 70 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 71 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 72 | padding=1, bias=False) 73 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 74 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 75 | bias=False) 76 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, 77 | momentum=BN_MOMENTUM) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.downsample = downsample 80 | self.stride = stride 81 | 82 | def forward(self, x): 83 | residual = x 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv3(out) 94 | out = self.bn3(out) 95 | 96 | if self.downsample is not None: 97 | residual = self.downsample(x) 98 | 99 | out += residual 100 | out = self.relu(out) 101 | 102 | return out 103 | 104 | 105 | class HighResolutionModule(nn.Module): 106 | def __init__(self, num_branches, blocks, num_blocks, num_inchannels, 107 | num_channels, fuse_method, multi_scale_output=True): 108 | super(HighResolutionModule, self).__init__() 109 | self._check_branches( 110 | num_branches, blocks, num_blocks, num_inchannels, num_channels) 111 | 112 | self.num_inchannels = num_inchannels 113 | self.fuse_method = fuse_method 114 | self.num_branches = num_branches 115 | 116 | self.multi_scale_output = multi_scale_output 117 | 118 | self.branches = self._make_branches( 119 | num_branches, blocks, num_blocks, num_channels) 120 | self.fuse_layers = self._make_fuse_layers() 121 | self.relu = nn.ReLU(True) 122 | 123 | def 
_check_branches(self, num_branches, blocks, num_blocks, 124 | num_inchannels, num_channels): 125 | if num_branches != len(num_blocks): 126 | error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( 127 | num_branches, len(num_blocks)) 128 | logger.error(error_msg) 129 | raise ValueError(error_msg) 130 | 131 | if num_branches != len(num_channels): 132 | error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( 133 | num_branches, len(num_channels)) 134 | logger.error(error_msg) 135 | raise ValueError(error_msg) 136 | 137 | if num_branches != len(num_inchannels): 138 | error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( 139 | num_branches, len(num_inchannels)) 140 | logger.error(error_msg) 141 | raise ValueError(error_msg) 142 | 143 | def _make_one_branch(self, branch_index, block, num_blocks, num_channels, 144 | stride=1): 145 | downsample = None 146 | if stride != 1 or \ 147 | self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: 148 | downsample = nn.Sequential( 149 | nn.Conv2d( 150 | self.num_inchannels[branch_index], 151 | num_channels[branch_index] * block.expansion, 152 | kernel_size=1, stride=stride, bias=False 153 | ), 154 | nn.BatchNorm2d( 155 | num_channels[branch_index] * block.expansion, 156 | momentum=BN_MOMENTUM 157 | ), 158 | ) 159 | 160 | layers = [] 161 | layers.append( 162 | block( 163 | self.num_inchannels[branch_index], 164 | num_channels[branch_index], 165 | stride, 166 | downsample 167 | ) 168 | ) 169 | self.num_inchannels[branch_index] = \ 170 | num_channels[branch_index] * block.expansion 171 | for i in range(1, num_blocks[branch_index]): 172 | layers.append( 173 | block( 174 | self.num_inchannels[branch_index], 175 | num_channels[branch_index] 176 | ) 177 | ) 178 | 179 | return nn.Sequential(*layers) 180 | 181 | def _make_branches(self, num_branches, block, num_blocks, num_channels): 182 | branches = [] 183 | 184 | for i in range(num_branches): 185 | branches.append( 186 | self._make_one_branch(i, block, num_blocks, num_channels) 187 | ) 188 | 189 | return nn.ModuleList(branches) 190 | 191 | def _make_fuse_layers(self): 192 | if self.num_branches == 1: 193 | return None 194 | 195 | num_branches = self.num_branches 196 | num_inchannels = self.num_inchannels 197 | fuse_layers = [] 198 | for i in range(num_branches if self.multi_scale_output else 1): 199 | fuse_layer = [] 200 | for j in range(num_branches): 201 | if j > i: 202 | fuse_layer.append( 203 | nn.Sequential( 204 | nn.Conv2d( 205 | num_inchannels[j], 206 | num_inchannels[i], 207 | 1, 1, 0, bias=False 208 | ), 209 | nn.BatchNorm2d(num_inchannels[i]), 210 | nn.Upsample(scale_factor=2**(j-i), mode='nearest') 211 | ) 212 | ) 213 | elif j == i: 214 | fuse_layer.append(None) 215 | else: 216 | conv3x3s = [] 217 | for k in range(i-j): 218 | if k == i - j - 1: 219 | num_outchannels_conv3x3 = num_inchannels[i] 220 | conv3x3s.append( 221 | nn.Sequential( 222 | nn.Conv2d( 223 | num_inchannels[j], 224 | num_outchannels_conv3x3, 225 | 3, 2, 1, bias=False 226 | ), 227 | nn.BatchNorm2d(num_outchannels_conv3x3) 228 | ) 229 | ) 230 | else: 231 | num_outchannels_conv3x3 = num_inchannels[j] 232 | conv3x3s.append( 233 | nn.Sequential( 234 | nn.Conv2d( 235 | num_inchannels[j], 236 | num_outchannels_conv3x3, 237 | 3, 2, 1, bias=False 238 | ), 239 | nn.BatchNorm2d(num_outchannels_conv3x3), 240 | nn.ReLU(True) 241 | ) 242 | ) 243 | fuse_layer.append(nn.Sequential(*conv3x3s)) 244 | fuse_layers.append(nn.ModuleList(fuse_layer)) 245 | 246 | return nn.ModuleList(fuse_layers) 247 | 248 | def 
get_num_inchannels(self): 249 | return self.num_inchannels 250 | 251 | def forward(self, x): 252 | if self.num_branches == 1: 253 | return [self.branches[0](x[0])] 254 | 255 | for i in range(self.num_branches): 256 | x[i] = self.branches[i](x[i]) 257 | 258 | x_fuse = [] 259 | 260 | for i in range(len(self.fuse_layers)): 261 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) 262 | for j in range(1, self.num_branches): 263 | if i == j: 264 | y = y + x[j] 265 | else: 266 | y = y + self.fuse_layers[i][j](x[j]) 267 | x_fuse.append(self.relu(y)) 268 | 269 | return x_fuse 270 | 271 | 272 | blocks_dict = { 273 | 'BASIC': BasicBlock, 274 | 'BOTTLENECK': Bottleneck 275 | } 276 | 277 | class PoseHighResolutionNet(nn.Module): 278 | 279 | def __init__(self, cfg, heads, args, **kwargs): 280 | self.inplanes = 64 281 | extra = cfg['MODEL']['EXTRA'] 282 | super(PoseHighResolutionNet, self).__init__() 283 | 284 | self.heads = heads 285 | self.args = args 286 | 287 | # stem net 288 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, 289 | bias=False) 290 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 291 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, 292 | bias=False) 293 | self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 294 | self.relu = nn.ReLU(inplace=True) 295 | self.layer1 = self._make_layer(Bottleneck, 64, 4) 296 | 297 | self.stage2_cfg = extra['STAGE2'] 298 | num_channels = self.stage2_cfg['NUM_CHANNELS'] 299 | block = blocks_dict[self.stage2_cfg['BLOCK']] 300 | num_channels = [ 301 | num_channels[i] * block.expansion for i in range(len(num_channels)) 302 | ] 303 | self.transition1 = self._make_transition_layer([256], num_channels) 304 | self.stage2, pre_stage_channels = self._make_stage( 305 | self.stage2_cfg, num_channels) 306 | 307 | self.stage3_cfg = extra['STAGE3'] 308 | num_channels = self.stage3_cfg['NUM_CHANNELS'] 309 | block = blocks_dict[self.stage3_cfg['BLOCK']] 310 | num_channels = [ 311 | num_channels[i] * block.expansion for i in range(len(num_channels)) 312 | ] 313 | self.transition2 = self._make_transition_layer( 314 | pre_stage_channels, num_channels) 315 | self.stage3, pre_stage_channels = self._make_stage( 316 | self.stage3_cfg, num_channels) 317 | 318 | self.stage4_cfg = extra['STAGE4'] 319 | num_channels = self.stage4_cfg['NUM_CHANNELS'] 320 | block = blocks_dict[self.stage4_cfg['BLOCK']] 321 | num_channels = [ 322 | num_channels[i] * block.expansion for i in range(len(num_channels)) 323 | ] 324 | self.transition3 = self._make_transition_layer( 325 | pre_stage_channels, num_channels) 326 | self.stage4, pre_stage_channels = self._make_stage( 327 | self.stage4_cfg, num_channels, multi_scale_output=False) 328 | 329 | for head in self.heads: 330 | classes = self.heads[head] 331 | head_conv = 256 332 | 333 | fc = nn.Sequential( 334 | nn.Conv2d(pre_stage_channels[0], head_conv, kernel_size=3, padding=1, bias=True), 335 | nn.ReLU(inplace=True), 336 | nn.Conv2d(head_conv, classes, kernel_size=1, stride=1, padding=0, bias=True), 337 | ) 338 | 339 | self.__setattr__(head, fc) 340 | 341 | self.pretrained_layers = extra['PRETRAINED_LAYERS'] 342 | 343 | def _make_transition_layer( 344 | self, num_channels_pre_layer, num_channels_cur_layer): 345 | num_branches_cur = len(num_channels_cur_layer) 346 | num_branches_pre = len(num_channels_pre_layer) 347 | 348 | transition_layers = [] 349 | for i in range(num_branches_cur): 350 | if i < num_branches_pre: 351 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]: 352 | 
transition_layers.append( 353 | nn.Sequential( 354 | nn.Conv2d( 355 | num_channels_pre_layer[i], 356 | num_channels_cur_layer[i], 357 | 3, 1, 1, bias=False 358 | ), 359 | nn.BatchNorm2d(num_channels_cur_layer[i]), 360 | nn.ReLU(inplace=True) 361 | ) 362 | ) 363 | else: 364 | transition_layers.append(None) 365 | else: 366 | conv3x3s = [] 367 | for j in range(i+1-num_branches_pre): 368 | inchannels = num_channels_pre_layer[-1] 369 | outchannels = num_channels_cur_layer[i] \ 370 | if j == i-num_branches_pre else inchannels 371 | conv3x3s.append( 372 | nn.Sequential( 373 | nn.Conv2d( 374 | inchannels, outchannels, 3, 2, 1, bias=False 375 | ), 376 | nn.BatchNorm2d(outchannels), 377 | nn.ReLU(inplace=True) 378 | ) 379 | ) 380 | transition_layers.append(nn.Sequential(*conv3x3s)) 381 | 382 | return nn.ModuleList(transition_layers) 383 | 384 | def _make_layer(self, block, planes, blocks, stride=1): 385 | downsample = None 386 | if stride != 1 or self.inplanes != planes * block.expansion: 387 | downsample = nn.Sequential( 388 | nn.Conv2d( 389 | self.inplanes, planes * block.expansion, 390 | kernel_size=1, stride=stride, bias=False 391 | ), 392 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 393 | ) 394 | 395 | layers = [] 396 | layers.append(block(self.inplanes, planes, stride, downsample)) 397 | self.inplanes = planes * block.expansion 398 | for i in range(1, blocks): 399 | layers.append(block(self.inplanes, planes)) 400 | 401 | return nn.Sequential(*layers) 402 | 403 | def _make_stage(self, layer_config, num_inchannels, 404 | multi_scale_output=True): 405 | num_modules = layer_config['NUM_MODULES'] 406 | num_branches = layer_config['NUM_BRANCHES'] 407 | num_blocks = layer_config['NUM_BLOCKS'] 408 | num_channels = layer_config['NUM_CHANNELS'] 409 | block = blocks_dict[layer_config['BLOCK']] 410 | fuse_method = layer_config['FUSE_METHOD'] 411 | 412 | modules = [] 413 | for i in range(num_modules): 414 | # multi_scale_output is only used last module 415 | if not multi_scale_output and i == num_modules - 1: 416 | reset_multi_scale_output = False 417 | else: 418 | reset_multi_scale_output = True 419 | 420 | modules.append( 421 | HighResolutionModule( 422 | num_branches, 423 | block, 424 | num_blocks, 425 | num_inchannels, 426 | num_channels, 427 | fuse_method, 428 | reset_multi_scale_output 429 | ) 430 | ) 431 | num_inchannels = modules[-1].get_num_inchannels() 432 | 433 | return nn.Sequential(*modules), num_inchannels 434 | 435 | def forward(self, x, seg_map=None, label=None, point_list=None, target_shape=None): 436 | if target_shape is None: 437 | target_shape = x.shape[-2:] 438 | 439 | x = self.conv1(x) 440 | x = self.bn1(x) 441 | x = self.relu(x) 442 | x = self.conv2(x) 443 | x = self.bn2(x) 444 | x = self.relu(x) 445 | x = self.layer1(x) 446 | 447 | x_list = [] 448 | for i in range(self.stage2_cfg['NUM_BRANCHES']): 449 | if self.transition1[i] is not None: 450 | x_list.append(self.transition1[i](x)) 451 | else: 452 | x_list.append(x) 453 | y_list = self.stage2(x_list) 454 | 455 | x_list = [] 456 | for i in range(self.stage3_cfg['NUM_BRANCHES']): 457 | if self.transition2[i] is not None: 458 | x_list.append(self.transition2[i](y_list[-1])) 459 | else: 460 | x_list.append(y_list[i]) 461 | y_list = self.stage3(x_list) 462 | 463 | x_list = [] 464 | for i in range(self.stage4_cfg['NUM_BRANCHES']): 465 | if self.transition3[i] is not None: 466 | x_list.append(self.transition3[i](y_list[-1])) 467 | else: 468 | x_list.append(y_list[i]) 469 | y_list = self.stage4(x_list) 470 | 471 | 
results = {} 472 | for head in self.heads: 473 | results[head] = self.__getattr__(head)(y_list[0]) 474 | results[head] = torch.nn.functional.interpolate(results[head], 475 | size=target_shape, 476 | mode='bilinear', 477 | align_corners=False) 478 | 479 | if label is not None: # refined label generation 480 | 481 | if self.args.sup == 'point': # point supervision setting 482 | pseudo_label = refine_label_generation_with_point( 483 | results['seg'].clone().detach(), 484 | point_list.cpu().numpy(), 485 | results['offset'].clone().detach(), 486 | label.clone().detach(), 487 | seg_map.clone().detach(), 488 | self.args, 489 | ) 490 | 491 | else: # image-level supervision setting 492 | pseudo_label = refine_label_generation( 493 | results['seg'].clone().detach(), 494 | results['center'].clone().detach(), 495 | results['offset'].clone().detach(), 496 | label.clone().detach(), 497 | seg_map.clone().detach(), 498 | self.args, 499 | ) 500 | 501 | return results, pseudo_label 502 | 503 | return results 504 | 505 | 506 | def init_weights(self, pretrained=''): 507 | logger.info('=> init weights from normal distribution') 508 | for n, m in self.named_modules(): 509 | if isinstance(m, nn.Conv2d): 510 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 511 | nn.init.normal_(m.weight, std=0.001) 512 | for name, _ in m.named_parameters(): 513 | if name in ['bias']: 514 | nn.init.constant_(m.bias, 0) 515 | elif isinstance(m, nn.BatchNorm2d): 516 | nn.init.constant_(m.weight, 1) 517 | nn.init.constant_(m.bias, 0) 518 | elif isinstance(m, nn.ConvTranspose2d): 519 | nn.init.normal_(m.weight, std=0.001) 520 | for name, _ in m.named_parameters(): 521 | if name in ['bias']: 522 | nn.init.constant_(m.bias, 0) 523 | 524 | if os.path.isfile(pretrained): 525 | pretrained_state_dict = torch.load(pretrained) 526 | logger.info('=> loading pretrained model {}'.format(pretrained)) 527 | print('=> loading pretrained model {}'.format(pretrained)) 528 | 529 | need_init_state_dict = {} 530 | for name, m in pretrained_state_dict.items(): 531 | if name.split('.')[0] in self.pretrained_layers \ 532 | or self.pretrained_layers[0] == '*': 533 | need_init_state_dict[name] = m 534 | self.load_state_dict(need_init_state_dict, strict=False) 535 | else: 536 | logger.info("=> without pre-trained models") 537 | print("=> without pre-trained models") 538 | 539 | 540 | def HRNet(layer, heads, args, **kwargs): 541 | update_config(cfg, f"./models/hrnet_config/w{layer}_384x288_adam_lr1e-3.yaml") 542 | cfg.MODEL.NUM_JOINTS = 1 # num classes 543 | 544 | model = PoseHighResolutionNet(cfg, heads, args, **kwargs) 545 | 546 | model.init_weights(cfg.MODEL.PRETRAINED) 547 | 548 | return model 549 | 550 | 551 | def hrnet32(args): 552 | heads = { 553 | 'seg': args.num_classes+1, 554 | 'center': args.num_classes, 555 | 'offset': 2 556 | } 557 | model = HRNet(32, heads, args) 558 | return model 559 | 560 | def hrnet48(args): 561 | heads = { 562 | 'seg': args.num_classes+1, 563 | 'center': args.num_classes, 564 | 'offset': 2 565 | } 566 | model = HRNet(48, heads, args) 567 | return model --------------------------------------------------------------------------------
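Usage sketch (editor's addition, not a file in this repository): the snippet below shows one way the hrnet48 factory defined above could be exercised on its own. The num_classes value, the 416x416 input, and running from the repository root (so the relative HRNet config path in HRNet() resolves) are assumptions for illustration, not the project's actual entry point, which builds the model with its full argument set.

import argparse
import torch
from models.hrnet import hrnet48

# Assumed illustrative arguments: 20 foreground classes (VOC-style).
args = argparse.Namespace(num_classes=20)
model = hrnet48(args).eval()   # loads ImageNet weights if the file exists, otherwise warns and continues

x = torch.randn(1, 3, 416, 416)   # example 416x416 input
with torch.no_grad():
    out = model(x)                # label=None, so forward returns only the head outputs

print(out['seg'].shape)      # torch.Size([1, 21, 416, 416])  semantic logits (num_classes + background)
print(out['center'].shape)   # torch.Size([1, 20, 416, 416])  class-wise center heatmaps
print(out['offset'].shape)   # torch.Size([1, 2, 416, 416])   2-channel offset map

All three heads are bilinearly interpolated back to the input resolution inside forward(), so the printed spatial sizes match the input regardless of the internal HRNet strides.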