├── PAM
│   ├── models
│   │   ├── __init__py
│   │   └── classifier.py
│   ├── run_PAM.sh
│   ├── utils
│   │   ├── avgMeter.py
│   │   ├── my_optim.py
│   │   ├── LoadData.py
│   │   ├── Metrics.py
│   │   └── transforms
│   │       └── functional.py
│   ├── README.md
│   ├── point_extraction.py
│   └── train.py
├── figures
│   ├── overview.png
│   ├── result_coco.png
│   ├── result_voc.png
│   ├── PAM_comparison.png
│   ├── PAM_architecture.png
│   └── qualitavie_result.png
├── models
│   ├── imagenet
│   │   └── README.md
│   ├── hrnet_config
│   │   ├── __init__.py
│   │   ├── models.py
│   │   ├── w32_384x288_adam_lr1e-3.yaml
│   │   ├── w48_384x288_adam_lr1e-3.yaml
│   │   └── default.py
│   ├── __init__.py
│   ├── resnet.py
│   ├── panoptic_deeplab.py
│   └── hrnet.py
├── scripts
│   ├── run_point_labels.sh
│   └── run_image_labels.sh
├── LICENSE
├── utils
│   ├── crf.py
│   ├── my_optim.py
│   ├── loss.py
│   ├── decode.py
│   ├── LoadData.py
│   └── transforms
│       └── transforms.py
├── .gitignore
├── NOTICE
├── README.md
└── main.py
/PAM/models/__init__py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/figures/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/overview.png
--------------------------------------------------------------------------------
/figures/result_coco.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/result_coco.png
--------------------------------------------------------------------------------
/figures/result_voc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/result_voc.png
--------------------------------------------------------------------------------
/figures/PAM_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/PAM_comparison.png
--------------------------------------------------------------------------------
/figures/PAM_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/PAM_architecture.png
--------------------------------------------------------------------------------
/figures/qualitavie_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/BESTIE/HEAD/figures/qualitavie_result.png
--------------------------------------------------------------------------------
/models/imagenet/README.md:
--------------------------------------------------------------------------------
1 | Download HRNet ImageNet pretrained weight files to this path ([GoogleDrive](https://drive.google.com/drive/folders/1E6j6W7RqGhW1o7UHgiQ9X4g8fVJRU9TX)) ([OneDrive](https://1drv.ms/f/s!AhIXJn_J-blW231MH2krnmLq5kkQ)).
2 | The weight files are from https://github.com/HRNet/HRNet-Human-Pose-Estimation/tree/master .
3 |
4 | - hrnet_w32-36af842e.pth
5 | - hrnet_w48-8ef0771d.pth
--------------------------------------------------------------------------------
/models/hrnet_config/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # ------------------------------------------------------------------------------
6 |
7 | from .default import _C as cfg
8 | from .default import update_config
9 | from .models import MODEL_EXTRAS
10 |
--------------------------------------------------------------------------------
/PAM/run_PAM.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1 python train.py \
2 | --root_dir=your_dataset_root_path \
3 | --lr=0.001 \
4 | --epoch=15 \
5 | --decay_points='5,10' \
6 | --alpha=0.7 \
7 | --save_folder=checkpoints/PAM \
8 | --show_interval=50
9 |
10 | CUDA_VISIBLE_DEVICES=0 python point_extraction.py \
11 | --root_dir=your_dataset_root_path \
12 | --alpha=0.7 \
13 | --checkpoint=checkpoints/PAM/ckpt_15.pth \
14 | --save_dir Peak_points
15 |
--------------------------------------------------------------------------------
/PAM/utils/avgMeter.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 | """Computes and stores the average and current value"""
3 | def __init__(self):
4 | self.reset()
5 |
6 | def reset(self):
7 | self.val = 0
8 | self.avg = 0
9 | self.sum = 0
10 | self.count = 0
11 |
12 | def update(self, val, n=1):
13 | self.val = val
14 | self.sum += val * n
15 | self.count += n
16 | self.avg = self.sum / self.count
17 |
18 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .panoptic_deeplab import PanopticDeepLab
2 | from .hrnet import hrnet32, hrnet48
3 | from .resnet import resnet50, resnet101
4 |
5 | def model_factory(args):
6 |
7 | if args.backbone == 'hrnet32':
8 | model = hrnet32(args)
9 | elif args.backbone == 'hrnet48':
10 | model = hrnet48(args)
11 |
12 | elif args.backbone == 'resnet50':
13 | backbone = resnet50(pretrained=True)
14 | model = PanopticDeepLab(backbone, args)
15 |
16 | elif args.backbone == 'resnet101':
17 | backbone = resnet101(pretrained=True)
18 | model = PanopticDeepLab(backbone, args)
19 | 
20 |     else:
21 |         raise ValueError("Unsupported backbone: {}".format(args.backbone))
22 | 
23 |     return model
24 | 
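25 | # Illustrative usage (assumes `args` is the argparse namespace built by the training entry point):
26 | #   model = model_factory(args)   # args.backbone in {'hrnet32', 'hrnet48', 'resnet50', 'resnet101'}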
--------------------------------------------------------------------------------
/scripts/run_point_labels.sh:
--------------------------------------------------------------------------------
1 | # Training BESTIE with point labels.
2 |
3 | ROOT=your_dataset_root_path
4 | SUP=point
5 | REFINE_WARMUP=0
6 | SIZE=416
7 | BATCH=16
8 | WORKERS=4
9 | TRAIN_ITERS=50000
10 | BACKBONE=hrnet48 # [resnet50, resnet101, hrnet32, hrnet48]
11 | VAL_IGNORE=False
12 |
13 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nnodes=1 --nproc_per_node=4 main.py \
14 | --root_dir ${ROOT} --sup ${SUP} --batch_size ${BATCH} --num_workers ${WORKERS} --crop_size ${SIZE} \
15 | --train_iter ${TRAIN_ITERS} --refine True --refine_iter ${REFINE_WARMUP} \
16 | --val_freq 1000 --val_thresh 0.1 --val_ignore ${VAL_IGNORE} --val_clean False --val_flip True \
17 | --seg_weight 1.0 --center_weight 200.0 --offset_weight 0.01 \
18 | --lr 5e-5 --backbone ${BACKBONE} --random_seed 3407
19 |
--------------------------------------------------------------------------------
/scripts/run_image_labels.sh:
--------------------------------------------------------------------------------
1 | # Training BESTIE with image-level labels.
2 |
3 | ROOT=your_dataset_root_path
4 | SUP=cls
5 | PSEUDO_THRESH=0.7
6 | REFINE_THRESH=0.3
7 | REFINE_WARMUP=0
8 | SIZE=416
9 | BATCH=16
10 | WORKERS=4
11 | TRAIN_ITERS=50000
12 | BACKBONE=hrnet48 # [resnet50, resnet101, hrnet32, hrnet48]
13 | VAL_IGNORE=False
14 |
15 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nnodes=1 --nproc_per_node=4 main.py \
16 | --root_dir ${ROOT} --sup ${SUP} --batch_size ${BATCH} --num_workers ${WORKERS} --crop_size ${SIZE} --train_iter ${TRAIN_ITERS} \
17 | --refine True --refine_iter ${REFINE_WARMUP} --pseudo_thresh ${PSEUDO_THRESH} --refine_thresh ${REFINE_THRESH} \
18 | --val_freq 1000 --val_ignore ${VAL_IGNORE} --val_clean False --val_flip False \
19 | --seg_weight 1.0 --center_weight 200.0 --offset_weight 0.01 \
20 | --lr 5e-5 --backbone ${BACKBONE} --random_seed 3407
21 |
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BESTIE
2 | Copyright (c) 2022-present NAVER Corp.
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy
5 | of this software and associated documentation files (the "Software"), to deal
6 | in the Software without restriction, including without limitation the rights
7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the Software is
9 | furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in
12 | all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 | THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/PAM/README.md:
--------------------------------------------------------------------------------
1 | # Peak Attention Module (PAM)
2 |
3 | ## Abstract
4 | 
5 | CAMs have a limitation in obtaining accurate instance cues because several instance cues may be extracted from a single instance due to noisy activation regions, as illustrated in the figure below.
6 | This disturbs the generation of pseudo instance labels.
7 | To address this limitation, we propose a peak attention module (PAM) that extracts one appropriate instance cue per instance.
8 | PAM aims to strengthen the attention on peak regions while weakening the attention on noisy activation regions.
9 |
10 |
11 |
12 |
13 |
14 |
15 | ## How to Run?
16 |
17 | ```
18 | # change the data ROOT in the shell script
19 | bash run_PAM.sh
20 | ```
21 |
22 | * Note that the extracted peak points are used for the image-level supervised BESTIE (the output format is described below).
23 | * We provide the weights for the classifier pretrained with the PAM module [[download]](https://drive.google.com/file/d/1I5DocPV2Lkc59DtDrr4XoQuVlKdRi4km/view?usp=sharing)
24 |
25 | ## Acknowledgement
26 |
27 | Our implementation is based on these repositories:
28 | - (DRS) https://github.com/qjadud1994/DRS
29 |
--------------------------------------------------------------------------------
/utils/crf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pydensecrf.densecrf as dcrf
3 | from pydensecrf.utils import unary_from_labels, unary_from_softmax
4 |
5 | def crf_inference_label(img, labels, t=10, n_labels=21, gt_prob=0.7):
6 |
7 | h, w = img.shape[:2]
8 |
9 | d = dcrf.DenseCRF2D(w, h, n_labels)
10 |
11 | unary = unary_from_labels(labels, n_labels, gt_prob=gt_prob, zero_unsure=False)
12 |
13 | d.setUnaryEnergy(unary)
14 | d.addPairwiseGaussian(sxy=3, compat=3)
15 | d.addPairwiseBilateral(sxy=50, srgb=5, rgbim=np.ascontiguousarray(np.copy(img)), compat=10)
16 |
17 | q = d.inference(t)
18 |
19 | return np.argmax(np.array(q).reshape((n_labels, h, w)), axis=0)
20 |
21 |
22 | class DenseCRF(object):
23 | def __init__(self, iter_max, pos_w, pos_xy_std, bi_w, bi_xy_std, bi_rgb_std):
24 | self.iter_max = iter_max
25 | self.pos_w = pos_w
26 | self.pos_xy_std = pos_xy_std
27 | self.bi_w = bi_w
28 | self.bi_xy_std = bi_xy_std
29 | self.bi_rgb_std = bi_rgb_std
30 |
31 | def __call__(self, image, probmap):
32 | C, H, W = probmap.shape
33 |
34 | U = unary_from_softmax(probmap)
35 | U = np.ascontiguousarray(U)
36 |
37 | image = np.ascontiguousarray(image)
38 |
39 | d = dcrf.DenseCRF2D(W, H, C)
40 | d.setUnaryEnergy(U)
41 | d.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w)
42 | d.addPairwiseBilateral(
43 | sxy=self.bi_xy_std, srgb=self.bi_rgb_std, rgbim=image, compat=self.bi_w
44 | )
45 |
46 | Q = d.inference(self.iter_max)
47 | Q = np.array(Q).reshape((C, H, W))
48 |
49 | output = np.argmax(Q, axis=0).astype(np.uint8)
50 |
51 | return output
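52 | 
53 | 
54 | # Illustrative usage (a sketch; the parameter values below are common DeepLab-style CRF settings,
55 | # not values taken from this repository):
56 | #
57 | #   postprocessor = DenseCRF(iter_max=10, pos_w=3, pos_xy_std=1,
58 | #                            bi_w=4, bi_xy_std=67, bi_rgb_std=3)
59 | #   # image: (H, W, 3) uint8 array, probmap: (C, H, W) softmax probabilities
60 | #   refined = postprocessor(image, probmap)   # (H, W) uint8 label map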
--------------------------------------------------------------------------------
/PAM/utils/my_optim.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/utils/my_optim.py
3 | # ------------------------------------------------------------------------------
4 |
5 | import torch.optim as optim
6 | from torch.optim.lr_scheduler import LambdaLR
7 | import numpy as np
8 |
9 |
10 | class PolyOptimizer(optim.SGD):
11 |
12 | def __init__(self, params, lr, weight_decay, max_step, momentum=0.9):
13 | super().__init__(params, lr, weight_decay)
14 | self.param_groups = params
15 | self.global_step = 0
16 | self.max_step = max_step
17 | self.momentum = momentum
18 |
19 | self.__initial_lr = [group['lr'] for group in self.param_groups]
20 |
21 |
22 | def step(self, closure=None):
23 |
24 | if self.global_step < self.max_step:
25 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
26 |
27 | for i in range(len(self.param_groups)):
28 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult
29 | super().step(closure)
30 |
31 | self.global_step += 1
32 |
33 |
34 | def lr_poly(base_lr, iter,max_iter,power=0.9):
35 | return base_lr*((1-float(iter)/max_iter)**(power))
36 |
37 | def reduce_lr_poly(args, optimizer, global_iter, max_iter):
38 | base_lr = args.lr
39 | for g in optimizer.param_groups:
40 | g['lr'] = lr_poly(base_lr=base_lr, iter=global_iter, max_iter=max_iter, power=0.9)
41 |
42 |
43 | def reduce_lr(args, optimizer, epoch, factor=0.1):
44 | values = args.decay_points.strip().split(',')
45 | try:
46 | change_points = map(lambda x: int(x.strip()), values)
47 | except ValueError:
48 | change_points = None
49 |
50 | if change_points is not None and epoch in change_points:
51 | for g in optimizer.param_groups:
52 | g['lr'] = g['lr']*factor
53 | print("Reduce Learning Rate : ", epoch, g['lr'])
54 | return True
--------------------------------------------------------------------------------
/models/hrnet_config/models.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | from yacs.config import CfgNode as CN
12 |
13 |
14 | # pose_resnet related params
15 | POSE_RESNET = CN()
16 | POSE_RESNET.NUM_LAYERS = 50
17 | POSE_RESNET.DECONV_WITH_BIAS = False
18 | POSE_RESNET.NUM_DECONV_LAYERS = 3
19 | POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256]
20 | POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4]
21 | POSE_RESNET.FINAL_CONV_KERNEL = 1
22 | POSE_RESNET.PRETRAINED_LAYERS = ['*']
23 |
24 | # pose_multi_resolution_net related params
25 | POSE_HIGH_RESOLUTION_NET = CN()
26 | POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
27 | POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64
28 | POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
29 |
30 | POSE_HIGH_RESOLUTION_NET.STAGE2 = CN()
31 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
32 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
33 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
34 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
35 | POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
36 | POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'
37 |
38 | POSE_HIGH_RESOLUTION_NET.STAGE3 = CN()
39 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
40 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
41 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
42 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
43 | POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
44 | POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'
45 |
46 | POSE_HIGH_RESOLUTION_NET.STAGE4 = CN()
47 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
48 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
49 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
50 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
51 | POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
52 | POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'
53 |
54 |
55 | MODEL_EXTRAS = {
56 | 'pose_resnet': POSE_RESNET,
57 | 'pose_high_resolution_net': POSE_HIGH_RESOLUTION_NET,
58 | }
59 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/models/hrnet_config/w32_384x288_adam_lr1e-3.yaml:
--------------------------------------------------------------------------------
1 | AUTO_RESUME: true
2 | CUDNN:
3 | BENCHMARK: true
4 | DETERMINISTIC: false
5 | ENABLED: true
6 | DATA_DIR: ''
7 | GPUS: (0,1,2,3)
8 | OUTPUT_DIR: 'output'
9 | LOG_DIR: 'log'
10 | WORKERS: 24
11 | PRINT_FREQ: 100
12 |
13 | DATASET:
14 | COLOR_RGB: true
15 | DATASET: 'coco'
16 | DATA_FORMAT: jpg
17 | FLIP: true
18 | NUM_JOINTS_HALF_BODY: 8
19 | PROB_HALF_BODY: 0.3
20 | ROOT: 'data/coco/'
21 | ROT_FACTOR: 45
22 | SCALE_FACTOR: 0.35
23 | TEST_SET: 'val2017'
24 | TRAIN_SET: 'train2017'
25 | MODEL:
26 | INIT_WEIGHTS: true
27 | NAME: pose_hrnet
28 | NUM_JOINTS: 1
29 | PRETRAINED: 'models/imagenet/hrnet_w32-36af842e.pth'
30 | TARGET_TYPE: gaussian
31 | IMAGE_SIZE:
32 | - 288
33 | - 384
34 | HEATMAP_SIZE:
35 | - 72
36 | - 96
37 | SIGMA: 3
38 | EXTRA:
39 | PRETRAINED_LAYERS:
40 | - 'conv1'
41 | - 'bn1'
42 | - 'conv2'
43 | - 'bn2'
44 | - 'layer1'
45 | - 'transition1'
46 | - 'stage2'
47 | - 'transition2'
48 | - 'stage3'
49 | - 'transition3'
50 | - 'stage4'
51 | FINAL_CONV_KERNEL: 1
52 | STAGE2:
53 | NUM_MODULES: 1
54 | NUM_BRANCHES: 2
55 | BLOCK: BASIC
56 | NUM_BLOCKS:
57 | - 4
58 | - 4
59 | NUM_CHANNELS:
60 | - 32
61 | - 64
62 | FUSE_METHOD: SUM
63 | STAGE3:
64 | NUM_MODULES: 4
65 | NUM_BRANCHES: 3
66 | BLOCK: BASIC
67 | NUM_BLOCKS:
68 | - 4
69 | - 4
70 | - 4
71 | NUM_CHANNELS:
72 | - 32
73 | - 64
74 | - 128
75 | FUSE_METHOD: SUM
76 | STAGE4:
77 | NUM_MODULES: 3
78 | NUM_BRANCHES: 4
79 | BLOCK: BASIC
80 | NUM_BLOCKS:
81 | - 4
82 | - 4
83 | - 4
84 | - 4
85 | NUM_CHANNELS:
86 | - 32
87 | - 64
88 | - 128
89 | - 256
90 | FUSE_METHOD: SUM
91 | LOSS:
92 | USE_TARGET_WEIGHT: true
93 | TRAIN:
94 | BATCH_SIZE_PER_GPU: 32
95 | SHUFFLE: true
96 | BEGIN_EPOCH: 0
97 | END_EPOCH: 210
98 | OPTIMIZER: adam
99 | LR: 0.001
100 | LR_FACTOR: 0.1
101 | LR_STEP:
102 | - 170
103 | - 200
104 | WD: 0.0001
105 | GAMMA1: 0.99
106 | GAMMA2: 0.0
107 | MOMENTUM: 0.9
108 | NESTEROV: false
109 | TEST:
110 | BATCH_SIZE_PER_GPU: 32
111 | COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
112 | BBOX_THRE: 1.0
113 | IMAGE_THRE: 0.0
114 | IN_VIS_THRE: 0.2
115 | MODEL_FILE: ''
116 | NMS_THRE: 1.0
117 | OKS_THRE: 0.9
118 | USE_GT_BBOX: true
119 | FLIP_TEST: true
120 | POST_PROCESS: true
121 | SHIFT_HEATMAP: true
122 | DEBUG:
123 | DEBUG: true
124 | SAVE_BATCH_IMAGES_GT: true
125 | SAVE_BATCH_IMAGES_PRED: true
126 | SAVE_HEATMAPS_GT: true
127 | SAVE_HEATMAPS_PRED: true
128 |
--------------------------------------------------------------------------------
/models/hrnet_config/w48_384x288_adam_lr1e-3.yaml:
--------------------------------------------------------------------------------
1 | AUTO_RESUME: true
2 | CUDNN:
3 | BENCHMARK: true
4 | DETERMINISTIC: false
5 | ENABLED: true
6 | DATA_DIR: ''
7 | GPUS: (0,1,2,3)
8 | OUTPUT_DIR: 'output'
9 | LOG_DIR: 'log'
10 | WORKERS: 24
11 | PRINT_FREQ: 100
12 |
13 | DATASET:
14 | COLOR_RGB: true
15 | DATASET: 'coco'
16 | DATA_FORMAT: jpg
17 | FLIP: true
18 | NUM_JOINTS_HALF_BODY: 8
19 | PROB_HALF_BODY: 0.3
20 | ROOT: 'data/coco/'
21 | ROT_FACTOR: 45
22 | SCALE_FACTOR: 0.35
23 | TEST_SET: 'val2017'
24 | TRAIN_SET: 'train2017'
25 | MODEL:
26 | INIT_WEIGHTS: true
27 | NAME: pose_hrnet
28 | NUM_JOINTS: 1
29 | PRETRAINED: 'models/imagenet/hrnet_w48-8ef0771d.pth'
30 | TARGET_TYPE: gaussian
31 | IMAGE_SIZE:
32 | - 288
33 | - 384
34 | HEATMAP_SIZE:
35 | - 72
36 | - 96
37 | SIGMA: 3
38 | EXTRA:
39 | PRETRAINED_LAYERS:
40 | - 'conv1'
41 | - 'bn1'
42 | - 'conv2'
43 | - 'bn2'
44 | - 'layer1'
45 | - 'transition1'
46 | - 'stage2'
47 | - 'transition2'
48 | - 'stage3'
49 | - 'transition3'
50 | - 'stage4'
51 | FINAL_CONV_KERNEL: 1
52 | STAGE2:
53 | NUM_MODULES: 1
54 | NUM_BRANCHES: 2
55 | BLOCK: BASIC
56 | NUM_BLOCKS:
57 | - 4
58 | - 4
59 | NUM_CHANNELS:
60 | - 48
61 | - 96
62 | FUSE_METHOD: SUM
63 | STAGE3:
64 | NUM_MODULES: 4
65 | NUM_BRANCHES: 3
66 | BLOCK: BASIC
67 | NUM_BLOCKS:
68 | - 4
69 | - 4
70 | - 4
71 | NUM_CHANNELS:
72 | - 48
73 | - 96
74 | - 192
75 | FUSE_METHOD: SUM
76 | STAGE4:
77 | NUM_MODULES: 3
78 | NUM_BRANCHES: 4
79 | BLOCK: BASIC
80 | NUM_BLOCKS:
81 | - 4
82 | - 4
83 | - 4
84 | - 4
85 | NUM_CHANNELS:
86 | - 48
87 | - 96
88 | - 192
89 | - 384
90 | FUSE_METHOD: SUM
91 | LOSS:
92 | USE_TARGET_WEIGHT: true
93 | TRAIN:
94 | BATCH_SIZE_PER_GPU: 24
95 | SHUFFLE: true
96 | BEGIN_EPOCH: 0
97 | END_EPOCH: 210
98 | OPTIMIZER: adam
99 | LR: 0.001
100 | LR_FACTOR: 0.1
101 | LR_STEP:
102 | - 170
103 | - 200
104 | WD: 0.0001
105 | GAMMA1: 0.99
106 | GAMMA2: 0.0
107 | MOMENTUM: 0.9
108 | NESTEROV: false
109 | TEST:
110 | BATCH_SIZE_PER_GPU: 24
111 | COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
112 | BBOX_THRE: 1.0
113 | IMAGE_THRE: 0.0
114 | IN_VIS_THRE: 0.2
115 | MODEL_FILE: ''
116 | NMS_THRE: 1.0
117 | OKS_THRE: 0.9
118 | USE_GT_BBOX: true
119 | FLIP_TEST: true
120 | POST_PROCESS: true
121 | SHIFT_HEATMAP: true
122 | DEBUG:
123 | DEBUG: true
124 | SAVE_BATCH_IMAGES_GT: true
125 | SAVE_BATCH_IMAGES_PRED: true
126 | SAVE_HEATMAPS_GT: true
127 | SAVE_HEATMAPS_PRED: true
128 |
--------------------------------------------------------------------------------
/PAM/utils/LoadData.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/utils/LoadData.py
3 | # ------------------------------------------------------------------------------
4 |
5 | from .transforms import transforms
6 | from torch.utils.data import DataLoader
7 | import torch
8 | import numpy as np
9 | from torch.utils.data import Dataset
10 | import os
11 | from PIL import Image
12 |
13 | def train_data_loader(args):
14 | mean_vals = [0.485, 0.456, 0.406]
15 | std_vals = [0.229, 0.224, 0.225]
16 |
17 | input_size = int(args.input_size)
18 | crop_size = int(args.crop_size)
19 | tsfm_train = transforms.Compose(
20 | [
21 | transforms.Resize(input_size),
22 | transforms.RandomHorizontalFlip(),
23 | transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
24 | transforms.RandomCrop(crop_size),
25 | transforms.ToTensor(),
26 | transforms.Normalize(mean_vals, std_vals),
27 | ]
28 | )
29 |
30 | train_list = os.path.join(args.root_dir, "ImageSets/Segmentation/train_cls.txt")
31 |
32 | img_train = VOCDataset(train_list, crop_size, root_dir=args.root_dir, num_classes=args.num_classes, transform=tsfm_train, mode='train')
33 |
34 | train_loader = DataLoader(img_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
35 |
36 | return train_loader
37 |
38 |
39 | def test_data_loader(args):
40 | mean_vals = [0.485, 0.456, 0.406]
41 | std_vals = [0.229, 0.224, 0.225]
42 |
43 | input_size = int(args.input_size)
44 | crop_size = int(args.crop_size)
45 |
46 | tsfm_test = transforms.Compose(
47 | [
48 | transforms.Resize(crop_size),
49 | transforms.ToTensor(),
50 | transforms.Normalize(mean_vals, std_vals),
51 | ]
52 | )
53 |
54 | test_list = os.path.join(args.root_dir, "ImageSets/Segmentation/train_cls.txt")
55 |
56 | img_test = VOCDataset(test_list, crop_size, root_dir=args.root_dir, num_classes=args.num_classes, transform=tsfm_test, mode='test')
57 |
58 | test_loader = DataLoader(img_test, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
59 |
60 | return test_loader
61 |
62 |
63 |
64 | class VOCDataset(Dataset):
65 | def __init__(self, datalist_file, input_size, root_dir, num_classes=20, transform=None, mode='train'):
66 | self.root_dir = root_dir
67 | self.mode = mode
68 | self.datalist_file = datalist_file
69 | self.transform = transform
70 | self.num_classes = num_classes
71 |
72 | self.image_list, self.label_list = self.read_labeled_image_list(self.root_dir, self.datalist_file)
73 |
74 |
75 | def __len__(self):
76 | return len(self.image_list)
77 |
78 |
79 | def __getitem__(self, idx):
80 | img_name = self.image_list[idx]
81 | image = Image.open(img_name).convert('RGB')
82 |
83 | meta = {"img_name": img_name, "ori_size": image.size}
84 |
85 | if self.transform is not None:
86 | image = self.transform(image)
87 |
88 | return image, self.label_list[idx], meta
89 |
90 |
91 | def read_labeled_image_list(self, data_dir, data_list):
92 | img_dir = os.path.join(data_dir, "JPEGImages")
93 |
94 | with open(data_list, 'r') as f:
95 | lines = f.readlines()
96 |
97 | img_name_list = []
98 | img_labels = []
99 |
100 | for line in lines:
101 | fields = line.strip().split()
102 | image = fields[0] + '.jpg'
103 |
104 | labels = np.zeros((self.num_classes,), dtype=np.float32)
105 | for i in range(len(fields)-1):
106 | index = int(fields[i+1])
107 | labels[index] = 1.
108 |
109 | img_name_list.append(os.path.join(img_dir, image))
110 | img_labels.append(labels)
111 |
112 | return img_name_list, img_labels
113 |
--------------------------------------------------------------------------------
/PAM/point_extraction.py:
--------------------------------------------------------------------------------
1 | """
2 | BESTIE
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT License
5 | """
6 |
7 | import os
8 | import numpy as np
9 | import torch
10 | import argparse
11 | import torch.nn as nn
12 |
13 | from models.classifier import vgg16_pam
14 | from utils.LoadData import test_data_loader
15 |
16 | def get_arguments():
17 |     parser = argparse.ArgumentParser(description='PAM PyTorch implementation')
18 | parser.add_argument("--root_dir", type=str, default='', help='Directory of training images')
19 | parser.add_argument("--dataset", type=str, default='voc')
20 | parser.add_argument("--batch_size", type=int, default=1)
21 | parser.add_argument("--input_size", type=int, default=384)
22 | parser.add_argument("--crop_size", type=int, default=321)
23 | parser.add_argument("--num_classes", type=int, default=20)
24 | parser.add_argument("--num_workers", type=int, default=2)
25 |     parser.add_argument('--checkpoint', default='checkpoints/PAM/ckpt_15.pth', help='Path to the trained PAM checkpoint to load')
26 | parser.add_argument('--save_dir', default='Peak_points', help='save dir for peak points')
27 | parser.add_argument("--alpha", type=float, default=0.7, help='hyperparameter for PAM (controller)')
28 | parser.add_argument("--conf_thresh", type=float, default=0.1, help='peak threshold')
29 |
30 | return parser.parse_args()
31 |
32 |
33 | def peak_extract(heat, kernel=5, K=25):
34 | B, C, H, W = heat.size()
35 |
36 | pad = (kernel - 1) // 2
37 |
38 | hmax = torch.nn.functional.max_pool2d(
39 | heat, (kernel, kernel), stride=1, padding=pad)
40 |
41 | keep = (hmax == heat).float()
42 |
43 | peak = heat * keep
44 |
45 | topk_scores, topk_inds = torch.topk(peak.view(B, C, -1), K)
46 |
47 | topk_inds = topk_inds % (H * W)
48 | topk_ys = (topk_inds / W).int().float()
49 | topk_xs = (topk_inds % W).int().float()
50 |
51 | topk_scores = topk_scores[0].float().detach().cpu().numpy()
52 | topk_ys = topk_ys[0].int().detach().cpu().numpy()
53 | topk_xs = topk_xs[0].int().detach().cpu().numpy()
54 |
55 | return topk_scores, topk_ys, topk_xs
56 |
57 |
58 | def smoothing(heat, kernel=3):
59 | pad = (kernel - 1) // 2
60 | heat = torch.nn.functional.avg_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
61 |
62 | return heat
63 |
64 |
65 | if __name__ == '__main__':
66 | args = get_arguments()
67 | os.makedirs(args.save_dir, exist_ok=True)
68 |
69 | model = vgg16_pam(alpha=args.alpha)
70 |
71 | state = torch.load(args.checkpoint, map_location='cpu')
72 | model.load_state_dict(state['model'], strict=True)
73 |
74 | model.eval()
75 | model.cuda()
76 |
77 | data_loader = test_data_loader(args)
78 |
79 | with torch.no_grad():
80 |
81 | for idx, (img, label, meta) in enumerate(data_loader):
82 | img_name = meta['img_name'][0]
83 | ori_W, ori_H = int(meta['ori_size'][0]), int(meta['ori_size'][1])
84 | print("[%03d/%03d] %s" % (idx, len(data_loader), img_name), end='\r')
85 |
86 | label = label.to('cuda', non_blocking=True)
87 | img = img.to('cuda', non_blocking=True)
88 |
89 | # flip TTA
90 | _img = torch.cat( [img, img.flip(-1)] , dim=0)
91 | _label = torch.cat( [label, label] , dim=0)
92 |
93 | _, cam = model(_img, _label, (ori_H, ori_W))
94 |
95 | cam = (cam[0:1] + cam[1:2].flip(-1)) / 2.
96 | cam = smoothing(cam)
97 |
98 | peak_conf, peak_y, peak_x = peak_extract(cam, kernel=15)
99 |
100 | #####################################################################
101 |
102 | img_name = img_name.split("/")[-1][:-4]
103 | label = label[0].cpu().detach().numpy()
104 | valid_label = np.nonzero(label)[0]
105 |
106 | with open(os.path.join(args.save_dir, "%s.txt" % img_name), 'w') as peak_txt:
107 | for l in valid_label:
108 | for conf, x, y in zip(peak_conf[l], peak_x[l], peak_y[l]):
109 | if conf < args.conf_thresh:
110 | break
111 |
112 | peak_txt.write("%d %d %d %.3f\n" % (x, y, l, conf))
113 |
114 |
--------------------------------------------------------------------------------
/models/hrnet_config/default.py:
--------------------------------------------------------------------------------
1 |
2 | # ------------------------------------------------------------------------------
3 | # Copyright (c) Microsoft
4 | # Licensed under the MIT License.
5 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
6 | # ------------------------------------------------------------------------------
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import os
13 |
14 | from yacs.config import CfgNode as CN
15 |
16 |
17 | _C = CN()
18 |
19 | _C.OUTPUT_DIR = ''
20 | _C.LOG_DIR = ''
21 | _C.DATA_DIR = ''
22 | _C.GPUS = (0,)
23 | _C.WORKERS = 4
24 | _C.PRINT_FREQ = 20
25 | _C.AUTO_RESUME = False
26 | _C.PIN_MEMORY = True
27 | _C.RANK = 0
28 |
29 | # Cudnn related params
30 | _C.CUDNN = CN()
31 | _C.CUDNN.BENCHMARK = True
32 | _C.CUDNN.DETERMINISTIC = False
33 | _C.CUDNN.ENABLED = True
34 |
35 | # common params for NETWORK
36 | _C.MODEL = CN()
37 | _C.MODEL.NAME = 'pose_hrnet'
38 | _C.MODEL.INIT_WEIGHTS = True
39 | _C.MODEL.PRETRAINED = ''
40 | _C.MODEL.NUM_JOINTS = 1
41 | _C.MODEL.TAG_PER_JOINT = True
42 | _C.MODEL.TARGET_TYPE = 'gaussian'
43 | _C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256
44 | _C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32
45 | _C.MODEL.SIGMA = 2
46 | _C.MODEL.EXTRA = CN(new_allowed=True)
47 |
48 | _C.LOSS = CN()
49 | _C.LOSS.USE_OHKM = False
50 | _C.LOSS.TOPK = 8
51 | _C.LOSS.USE_TARGET_WEIGHT = True
52 | _C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False
53 |
54 | # DATASET related params
55 | _C.DATASET = CN()
56 | _C.DATASET.ROOT = ''
57 | _C.DATASET.DATASET = 'mpii'
58 | _C.DATASET.TRAIN_SET = 'train'
59 | _C.DATASET.TEST_SET = 'valid'
60 | _C.DATASET.DATA_FORMAT = 'jpg'
61 | _C.DATASET.HYBRID_JOINTS_TYPE = ''
62 | _C.DATASET.SELECT_DATA = False
63 |
64 | # training data augmentation
65 | _C.DATASET.FLIP = True
66 | _C.DATASET.SCALE_FACTOR = 0.25
67 | _C.DATASET.ROT_FACTOR = 30
68 | _C.DATASET.PROB_HALF_BODY = 0.0
69 | _C.DATASET.NUM_JOINTS_HALF_BODY = 8
70 | _C.DATASET.COLOR_RGB = False
71 |
72 | # train
73 | _C.TRAIN = CN()
74 |
75 | _C.TRAIN.LR_FACTOR = 0.1
76 | _C.TRAIN.LR_STEP = [90, 110]
77 | _C.TRAIN.LR = 0.001
78 |
79 | _C.TRAIN.OPTIMIZER = 'adam'
80 | _C.TRAIN.MOMENTUM = 0.9
81 | _C.TRAIN.WD = 0.0001
82 | _C.TRAIN.NESTEROV = False
83 | _C.TRAIN.GAMMA1 = 0.99
84 | _C.TRAIN.GAMMA2 = 0.0
85 |
86 | _C.TRAIN.BEGIN_EPOCH = 0
87 | _C.TRAIN.END_EPOCH = 140
88 |
89 | _C.TRAIN.RESUME = False
90 | _C.TRAIN.CHECKPOINT = ''
91 |
92 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32
93 | _C.TRAIN.SHUFFLE = True
94 |
95 | # testing
96 | _C.TEST = CN()
97 |
98 | # size of images for each device
99 | _C.TEST.BATCH_SIZE_PER_GPU = 32
100 | # Test Model Epoch
101 | _C.TEST.FLIP_TEST = False
102 | _C.TEST.POST_PROCESS = False
103 | _C.TEST.SHIFT_HEATMAP = False
104 |
105 | _C.TEST.USE_GT_BBOX = False
106 |
107 | # nms
108 | _C.TEST.IMAGE_THRE = 0.1
109 | _C.TEST.NMS_THRE = 0.6
110 | _C.TEST.SOFT_NMS = False
111 | _C.TEST.OKS_THRE = 0.5
112 | _C.TEST.IN_VIS_THRE = 0.0
113 | _C.TEST.COCO_BBOX_FILE = ''
114 | _C.TEST.BBOX_THRE = 1.0
115 | _C.TEST.MODEL_FILE = ''
116 |
117 | # debug
118 | _C.DEBUG = CN()
119 | _C.DEBUG.DEBUG = False
120 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False
121 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False
122 | _C.DEBUG.SAVE_HEATMAPS_GT = False
123 | _C.DEBUG.SAVE_HEATMAPS_PRED = False
124 |
125 | def update_config(cfg, yaml):
126 | cfg.defrost()
127 | cfg.merge_from_file(yaml)
128 |
129 |
130 | # def update_config(cfg, args):
131 | # cfg.defrost()
132 | # cfg.merge_from_file(args.cfg)
133 | # cfg.merge_from_list(args.opts)
134 |
135 | # if args.modelDir:
136 | # cfg.OUTPUT_DIR = args.modelDir
137 |
138 | # if args.logDir:
139 | # cfg.LOG_DIR = args.logDir
140 |
141 | # if args.dataDir:
142 | # cfg.DATA_DIR = args.dataDir
143 |
144 | # cfg.DATASET.ROOT = os.path.join(
145 | # cfg.DATA_DIR, cfg.DATASET.ROOT
146 | # )
147 |
148 | # cfg.MODEL.PRETRAINED = os.path.join(
149 | # cfg.DATA_DIR, cfg.MODEL.PRETRAINED
150 | # )
151 |
152 | # if cfg.TEST.MODEL_FILE:
153 | # cfg.TEST.MODEL_FILE = os.path.join(
154 | # cfg.DATA_DIR, cfg.TEST.MODEL_FILE
155 | # )
156 |
157 |
158 | if __name__ == '__main__':
159 | import sys
160 | with open(sys.argv[1], 'w') as f:
161 | print(_C, file=f)
162 |
163 |
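164 | 
165 | # Illustrative usage (a sketch based on the exports in models/hrnet_config/__init__.py;
166 | # the yaml path is one of the files shipped in this directory):
167 | #   from models.hrnet_config import cfg, update_config
168 | #   update_config(cfg, 'models/hrnet_config/w48_384x288_adam_lr1e-3.yaml')
169 | #   print(cfg.MODEL.PRETRAINED)   # -> 'models/imagenet/hrnet_w48-8ef0771d.pth'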
--------------------------------------------------------------------------------
/PAM/train.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/scripts/train_cls.py
3 | # ------------------------------------------------------------------------------
4 | """
5 | BESTIE
6 | Copyright (c) 2022-present NAVER Corp.
7 | MIT License
8 | """
9 |
10 | import os
11 | import numpy as np
12 | import torch
13 | import argparse
14 |
15 | import torch.optim as optim
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 |
19 | from models.classifier import vgg16_pam
20 | from utils.my_optim import reduce_lr
21 | from utils.avgMeter import AverageMeter
22 | from utils.LoadData import train_data_loader
23 | from utils.Metrics import Cls_Accuracy
24 |
25 |
26 | def get_arguments():
27 |     parser = argparse.ArgumentParser(description='PAM PyTorch implementation')
28 | parser.add_argument("--root_dir", type=str, default='', help='Directory of training images')
29 | parser.add_argument("--dataset", type=str, default='voc')
30 | parser.add_argument("--batch_size", type=int, default=5)
31 | parser.add_argument("--input_size", type=int, default=384)
32 | parser.add_argument("--crop_size", type=int, default=321)
33 | parser.add_argument("--num_classes", type=int, default=20)
34 | parser.add_argument("--lr", type=float, default=0.001)
35 | parser.add_argument("--weight_decay", type=float, default=0.0005)
36 | parser.add_argument("--decay_points", type=str, default='5,10')
37 | parser.add_argument("--epoch", type=int, default=15)
38 | parser.add_argument("--num_workers", type=int, default=2)
39 |     parser.add_argument('--show_interval', default=50, type=int, help='iteration interval for printing training status')
40 |     parser.add_argument('--save_interval', default=5, type=int, help='epoch interval for saving checkpoints')
41 | parser.add_argument('--save_folder', default='checkpoints/test', help='Location to save checkpoint models')
42 | parser.add_argument("--alpha", type=float, default=0.7, help='hyperparameter for PAM (controller)')
43 |
44 | return parser.parse_args()
45 |
46 |
47 | def get_model(args):
48 | model = vgg16_pam(pretrained=True, alpha=args.alpha)
49 |
50 | model = model.cuda()
51 | model = torch.nn.DataParallel(model).cuda()
52 | param_groups = model.module.get_parameter_groups()
53 |
54 | optimizer = optim.SGD(
55 | [
56 | {'params': param_groups[0], 'lr': args.lr},
57 | {'params': param_groups[1], 'lr': 2*args.lr},
58 | {'params': param_groups[2], 'lr': 10*args.lr},
59 | {'params': param_groups[3], 'lr': 20*args.lr}
60 | ],
61 | momentum=0.9,
62 | weight_decay=args.weight_decay,
63 | nesterov=True
64 | )
65 |
66 | return model, optimizer
67 |
68 |
69 | def train(current_epoch):
70 | global curr_iter
71 | losses = AverageMeter()
72 | cls_acc_metric = Cls_Accuracy()
73 |
74 | model.train()
75 |
76 | """ learning rate decay """
77 | res = reduce_lr(args, optimizer, current_epoch)
78 |
79 | for img, label, _ in train_loader:
80 | label = label.to('cuda', non_blocking=True)
81 | img = img.to('cuda', non_blocking=True)
82 |
83 | logit = model(img)
84 |
85 | """ classification loss """
86 | loss = F.multilabel_soft_margin_loss(logit, label)
87 |
88 | """ backprop """
89 | optimizer.zero_grad()
90 | loss.backward()
91 | optimizer.step()
92 |
93 | """ average meter """
94 | cls_acc_metric.update(logit, label)
95 | losses.update(loss.item(), img.size()[0])
96 |
97 | curr_iter += 1
98 |
99 | """ training log """
100 | if curr_iter % args.show_interval == 0:
101 | cls_acc = cls_acc_metric.compute_avg_acc()
102 |
103 | print('Epoch: [{}][{}/{}] '
104 | 'LR: {:.5f} '
105 | 'ACC: {:.5f} '
106 | 'Loss {loss.val:.4f} ({loss.avg:.4f}) '.format(
107 | current_epoch, curr_iter%len(train_loader), len(train_loader),
108 | optimizer.param_groups[0]['lr'], cls_acc, loss=losses))
109 |
110 |
111 | if __name__ == '__main__':
112 | args = get_arguments()
113 |
114 | n_gpu = torch.cuda.device_count()
115 |
116 | args.batch_size *= n_gpu
117 | args.num_workers *= n_gpu
118 |
119 | print('Running parameters:\n', args)
120 |
121 | if not os.path.exists(args.save_folder):
122 | os.makedirs(args.save_folder)
123 |
124 | train_loader = train_data_loader(args)
125 | print('# of train dataset:', len(train_loader) * args.batch_size)
126 |
127 | model, optimizer = get_model(args)
128 |
129 | curr_iter = 0
130 | for current_epoch in range(1, args.epoch+1):
131 | train(current_epoch)
132 |
133 | """ save checkpoint """
134 | if current_epoch % args.save_interval == 0 and current_epoch > 0:
135 | print('\nSaving state, epoch : %d \n' % current_epoch)
136 | state = {
137 | 'model': model.module.state_dict(),
138 | #"optimizer": optimizer.state_dict(),
139 | 'epoch': current_epoch,
140 | 'iter': curr_iter,
141 | }
142 | model_file = args.save_folder + '/ckpt_' + repr(current_epoch) + '.pth'
143 | torch.save(state, model_file)
144 |
--------------------------------------------------------------------------------
/PAM/utils/Metrics.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/utils/Metrics.py
3 | # ------------------------------------------------------------------------------
4 |
5 | import numpy as np
6 | from sklearn.metrics import confusion_matrix
7 | import torch
8 |
9 | class Cls_Accuracy():
10 | def __init__(self, ):
11 | self.total = 0
12 | self.correct = 0
13 |
14 |
15 | def update(self, logit, label):
16 |
17 | logit = logit.sigmoid_()
18 | logit = (logit >= 0.5)
19 | all_correct = torch.all(logit == label.byte(), dim=1).float().sum().item()
20 |
21 | self.total += logit.size(0)
22 | self.correct += all_correct
23 |
24 | def compute_avg_acc(self):
25 | return self.correct / self.total
26 |
27 |
28 |
29 | class RunningConfusionMatrix():
30 | """Running Confusion Matrix class that enables computation of confusion matrix
31 | on the go and has methods to compute such accuracy metrics as Mean Intersection over
32 |     Union (mIoU).
33 |
34 | Attributes
35 | ----------
36 | labels : list[int]
37 | List that contains int values that represent classes.
38 |     overall_confusion_matrix : sklearn confusion_matrix object
39 | Container of the sum of all confusion matrices. Used to compute MIOU at the end.
40 | ignore_label : int
41 | A label representing parts that should be ignored during
42 | computation of metrics
43 |
44 | """
45 |
46 | def __init__(self, labels, ignore_label=255):
47 |
48 | self.labels = labels
49 | self.ignore_label = ignore_label
50 | self.overall_confusion_matrix = None
51 |
52 | def update_matrix(self, ground_truth, prediction):
53 | """Updates overall confusion matrix statistics.
54 | If you are working with 2D data, just .flatten() it before running this
55 | function.
56 | Parameters
57 | ----------
58 |         ground_truth : array, shape = [n_samples]
59 |             An array with ground-truth values
60 | prediction : array, shape = [n_samples]
61 | An array with predictions
62 | """
63 |
64 |         # The mask-out value is ignored by default in sklearn
65 |         # (read the sources to see how that is handled).
66 |         # But sometimes all the elements in the ground truth can
67 |         # be equal to the ignore value, which would crash
68 |         # sklearn.metrics.confusion_matrix(); this is why we check it here.
69 | if (ground_truth == self.ignore_label).all():
70 |
71 | return
72 |
73 | current_confusion_matrix = confusion_matrix(y_true=ground_truth,
74 | y_pred=prediction,
75 | labels=self.labels)
76 |
77 | if self.overall_confusion_matrix is not None:
78 |
79 | self.overall_confusion_matrix += current_confusion_matrix
80 | else:
81 |
82 | self.overall_confusion_matrix = current_confusion_matrix
83 |
84 | def compute_current_mean_intersection_over_union(self):
85 |
86 | intersection = np.diag(self.overall_confusion_matrix)
87 | ground_truth_set = self.overall_confusion_matrix.sum(axis=1)
88 | predicted_set = self.overall_confusion_matrix.sum(axis=0)
89 | union = ground_truth_set + predicted_set - intersection
90 |
91 | #intersection_over_union = intersection / (union.astype(np.float32) + 1e-4)
92 | intersection_over_union = intersection / union.astype(np.float32)
93 |
94 | mean_intersection_over_union = np.mean(intersection_over_union)
95 |
96 | return mean_intersection_over_union
97 |
98 |
99 | class IOUMetric:
100 | """
101 | Class to calculate mean-iou using fast_hist method
102 | """
103 |
104 | def __init__(self, num_classes):
105 | self.num_classes = num_classes
106 | self.hist = np.zeros((num_classes, num_classes))
107 |
108 | def _fast_hist(self, label_pred, label_true):
109 | mask = (label_true >= 0) & (label_true < self.num_classes)
110 |
111 | hist = np.bincount(
112 | self.num_classes*label_true[mask] + label_pred[mask],
113 | minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
114 |
115 | return hist
116 |
117 | def add_batch(self, predictions, gts):
118 | for lp, lt in zip(predictions, gts):
119 | self.hist += self._fast_hist(lp.flatten(), lt.flatten())
120 |
121 | def evaluate(self):
122 | acc = np.diag(self.hist).sum() / self.hist.sum()
123 | acc_cls = np.diag(self.hist) / self.hist.sum(axis=1)
124 | acc_cls = np.nanmean(acc_cls)
125 | iu = np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0) - np.diag(self.hist))
126 | mean_iu = np.nanmean(iu)
127 | freq = self.hist.sum(axis=1) / self.hist.sum()
128 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
129 | cls_iu = dict(zip(range(self.num_classes), iu))
130 |
131 | return {
132 | "Pixel_Accuracy": acc,
133 | "Mean_Accuracy": acc_cls,
134 | "Frequency_Weighted_IoU": fwavacc,
135 | "Mean_IoU": mean_iu,
136 | "Class_IoU": cls_iu,
137 | }
138 |
--------------------------------------------------------------------------------
/utils/my_optim.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.optim as optim
4 | from bisect import bisect_right
5 | from typing import List
6 |
7 |
8 | class PolyOptimizer(optim.SGD):
9 |
10 | def __init__(self, params, lr, weight_decay, max_step, momentum=0.9):
11 | super().__init__(params, lr, weight_decay)
12 | self.param_groups = params
13 | self.global_step = 0
14 | self.max_step = max_step
15 | self.momentum = momentum
16 |
17 | self.__initial_lr = [group['lr'] for group in self.param_groups]
18 |
19 |
20 | def step(self, closure=None):
21 |
22 | if self.global_step < self.max_step:
23 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
24 |
25 | for i in range(len(self.param_groups)):
26 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult
27 | super().step(closure)
28 |
29 | self.global_step += 1
30 |
31 |
32 | def lr_poly(base_lr, iter,max_iter,power=0.9):
33 | return base_lr*((1-float(iter)/max_iter)**(power))
34 |
35 | def reduce_lr_poly(args, optimizer, global_iter, max_iter):
36 | base_lr = args.lr
37 | for g in optimizer.param_groups:
38 | g['lr'] = lr_poly(base_lr=base_lr, iter=global_iter, max_iter=max_iter, power=0.9)
39 |
40 |
41 | def reduce_lr(args, optimizer, epoch, factor=0.1):
42 | values = args.decay_points.strip().split(',')
43 | try:
44 | change_points = map(lambda x: int(x.strip()), values)
45 | except ValueError:
46 | change_points = None
47 |
48 | if change_points is not None and epoch in change_points:
49 | for g in optimizer.param_groups:
50 | g['lr'] = g['lr']*factor
51 | print("Reduce Learning Rate : ", epoch, g['lr'])
52 | return True
53 |
54 |
55 |
56 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler):
57 | def __init__(
58 | self,
59 | optimizer: torch.optim.Optimizer,
60 | max_iters: int,
61 | warmup_factor: float = 0.001,
62 | warmup_iters: int = 1000,
63 | warmup_method: str = "linear",
64 | last_epoch: int = -1,
65 | power: float = 0.9,
66 | constant_ending: float = 0.
67 | ):
68 | self.max_iters = max_iters
69 | self.warmup_factor = warmup_factor
70 | self.warmup_iters = warmup_iters
71 | self.warmup_method = warmup_method
72 | self.power = power
73 | self.constant_ending = constant_ending
74 | super().__init__(optimizer, last_epoch)
75 |
76 | def get_lr(self) -> List[float]:
77 | warmup_factor = _get_warmup_factor_at_iter(
78 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
79 | )
80 | if self.constant_ending > 0 and warmup_factor == 1.:
81 | # Constant ending lr.
82 | if math.pow((1.0 - self.last_epoch / self.max_iters), self.power) < self.constant_ending:
83 | return [
84 | base_lr
85 | * self.constant_ending
86 | for base_lr in self.base_lrs
87 | ]
88 | return [
89 | base_lr
90 | * warmup_factor
91 | * math.pow((1.0 - self.last_epoch / self.max_iters), self.power)
92 | for base_lr in self.base_lrs
93 | ]
94 |
95 | def _compute_values(self) -> List[float]:
96 | # The new interface
97 | return self.get_lr()
98 |
99 |
100 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
101 | def __init__(
102 | self,
103 | optimizer: torch.optim.Optimizer,
104 | milestones: List[int],
105 | gamma: float = 0.1,
106 | warmup_factor: float = 0.001,
107 | warmup_iters: int = 1000,
108 | warmup_method: str = "linear",
109 | last_epoch: int = -1,
110 | ):
111 | if not list(milestones) == sorted(milestones):
112 | raise ValueError(
113 | "Milestones should be a list of" " increasing integers. Got {}", milestones
114 | )
115 | self.milestones = milestones
116 | self.gamma = gamma
117 | self.warmup_factor = warmup_factor
118 | self.warmup_iters = warmup_iters
119 | self.warmup_method = warmup_method
120 | super().__init__(optimizer, last_epoch)
121 |
122 | def get_lr(self) -> List[float]:
123 | warmup_factor = _get_warmup_factor_at_iter(
124 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
125 | )
126 | return [
127 | base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
128 | for base_lr in self.base_lrs
129 | ]
130 |
131 | def _compute_values(self) -> List[float]:
132 | # The new interface
133 | return self.get_lr()
134 |
135 |
136 | def _get_warmup_factor_at_iter(
137 | method: str, iter: int, warmup_iters: int, warmup_factor: float
138 | ) -> float:
139 | """
140 | Return the learning rate warmup factor at a specific iteration.
141 | See https://arxiv.org/abs/1706.02677 for more details.
142 | Args:
143 | method (str): warmup method; either "constant" or "linear".
144 | iter (int): iteration at which to calculate the warmup factor.
145 | warmup_iters (int): the number of warmup iterations.
146 | warmup_factor (float): the base warmup factor (the meaning changes according
147 | to the method used).
148 | Returns:
149 | float: the effective warmup factor at the given iteration.
150 | """
151 | if iter >= warmup_iters:
152 | return 1.0
153 |
154 | if method == "constant":
155 | return warmup_factor
156 | elif method == "linear":
157 | alpha = iter / warmup_iters
158 | return warmup_factor * (1 - alpha) + alpha
159 | else:
160 | raise ValueError("Unknown warmup method: {}".format(method))
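161 | 
162 | 
163 | # Illustrative usage (a sketch; the argument values are assumptions, not the repository's settings):
164 | #   scheduler = WarmupPolyLR(optimizer, max_iters=50000, warmup_iters=1000, power=0.9)
165 | #   for it in range(50000):
166 | #       ...  # forward/backward
167 | #       optimizer.step()
168 | #       scheduler.step()   # stepped once per iteration, since max_iters counts iterations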
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | BESTIE
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT License
5 | """
6 |
7 | import torch
8 |
9 |
10 | class L1_Loss(torch.nn.Module):
11 | ''' L1 loss for Offset map (without Instance-aware Guidance)'''
12 | def __init__(self):
13 | super(L1_Loss, self).__init__()
14 | self.l1_loss = torch.nn.L1Loss(reduction='mean')
15 |
16 | def forward(self, out, target, weight):
17 | loss = self.l1_loss(out, target)
18 |
19 | return loss
20 |
21 |
22 | class Weighted_L1_Loss(torch.nn.Module):
23 | ''' Weighted L1 loss for Offset map (with Instance-aware Guidance)'''
24 | def __init__(self):
25 | super(Weighted_L1_Loss, self).__init__()
26 | self.l1_loss = torch.nn.L1Loss(reduction='none')
27 |
28 | def forward(self, out, target, weight):
29 | loss = self.l1_loss(out, target) * weight
30 |
31 | if weight.sum() > 0:
32 | loss = loss.sum() / (weight > 0).float().sum()
33 | else:
34 | loss = loss.sum() * 0
35 |
36 | return loss
37 |
38 |
39 | class MSELoss(torch.nn.Module):
40 | ''' MSE loss for center map (without Instance-aware Guidance)'''
41 | def __init__(self):
42 | super(MSELoss, self).__init__()
43 | self.mse_loss = torch.nn.MSELoss(reduction='mean')
44 |
45 | def forward(self, out, target, weight):
46 |
47 | loss = self.mse_loss(out, target)
48 |
49 | return loss
50 |
51 |
52 | class Weighted_MSELoss(torch.nn.Module):
53 | ''' MSE loss for center map (with Instance-aware Guidance)'''
54 | def __init__(self):
55 | super(Weighted_MSELoss, self).__init__()
56 | self.mse_loss = torch.nn.MSELoss(reduction='none')
57 |
58 | def forward(self, out, target, weight):
59 |
60 | loss = self.mse_loss(out, target) * weight
61 |
62 | if weight.sum() > 0:
63 | loss = loss.sum() / (weight > 0).float().sum()
64 | else:
65 | loss = loss.sum() * 0
66 |
67 | return loss
68 |
69 |
70 | class DeepLabCE(torch.nn.Module):
71 | """
72 |     Hard pixel mining with cross entropy loss, for semantic segmentation.
73 | This is used in TensorFlow DeepLab frameworks.
74 | Reference: https://github.com/tensorflow/models/blob/bd488858d610e44df69da6f89277e9de8a03722c/research/deeplab/utils/train_utils.py#L33
75 | Arguments:
76 | ignore_label: Integer, label to ignore.
77 | top_k_percent_pixels: Float, the value lies in [0.0, 1.0]. When its value < 1.0, only compute the loss for
78 | the top k percent pixels (e.g., the top 20% pixels). This is useful for hard pixel mining.
79 | weight: Tensor, a manual rescaling weight given to each class.
80 | """
81 | def __init__(self, ignore_label=255, top_k_percent_pixels=0.2, weight=None):
82 | super(DeepLabCE, self).__init__()
83 | self.top_k_percent_pixels = top_k_percent_pixels
84 | self.ignore_label = ignore_label
85 | self.criterion = torch.nn.CrossEntropyLoss(weight=weight,
86 | ignore_index=ignore_label,
87 | reduction='none')
88 |
89 | def forward(self, logits, labels):
90 |
91 | pixel_losses = self.criterion(logits, labels).contiguous().view(-1)
92 |
93 | if self.top_k_percent_pixels == 1.0:
94 | return pixel_losses.mean()
95 |
96 | top_k_pixels = int(self.top_k_percent_pixels * pixel_losses.numel())
97 | pixel_losses, _ = torch.topk(pixel_losses, top_k_pixels)
98 |
99 | return pixel_losses.mean()
100 |
101 |
102 | class RegularCE(torch.nn.Module):
103 | """
104 | Regular cross entropy loss for semantic segmentation, support pixel-wise loss weight.
105 | Arguments:
106 | ignore_label: Integer, label to ignore.
107 | weight: Tensor, a manual rescaling weight given to each class.
108 | """
109 | def __init__(self, ignore_label=255, weight=None):
110 | super(RegularCE, self).__init__()
111 | self.ignore_label = ignore_label
112 | self.criterion = torch.nn.CrossEntropyLoss(weight=weight,
113 | ignore_index=ignore_label,
114 | reduction='none')
115 |
116 | def forward(self, logits, labels):
117 | pixel_losses = self.criterion(logits, labels)
118 |
119 | mask = (labels != self.ignore_label)
120 |
121 | if mask.sum() > 0:
122 | pixel_losses = pixel_losses.sum() / mask.sum()
123 | else:
124 | pixel_losses = pixel_losses.sum() * 0
125 |
126 | return pixel_losses
127 |
128 |
129 |
130 | def _neg_loss(pred, gt, weight):
131 | ''' Modified focal loss. Exactly the same as CornerNet.
132 | Runs faster and costs a little bit more memory
133 | Arguments:
134 | pred (batch x c x h x w)
135 |         gt (batch x c x h x w)
136 | '''
137 | pos_inds = gt.eq(1).float()
138 | neg_inds = gt.lt(1).float()
139 |
140 | neg_weights = torch.pow(1 - gt, 4)
141 |
142 | loss = 0
143 |
144 | pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds * weight
145 | neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds * weight
146 |
147 | num_pos = pos_inds.float().sum()
148 | pos_loss = pos_loss.sum()
149 | neg_loss = neg_loss.sum()
150 |
151 | if num_pos == 0:
152 | loss = loss - neg_loss
153 | else:
154 | loss = loss - (pos_loss + neg_loss) / num_pos
155 | return loss
156 |
157 |
158 | class FocalLoss(torch.nn.Module):
159 |     '''nn.Module wrapper for focal loss'''
160 | def __init__(self):
161 | super(FocalLoss, self).__init__()
162 | self.neg_loss = _neg_loss
163 |
164 | def forward(self, out, target, weight):
165 | return self.neg_loss(out, target, weight)
166 |
167 |
--------------------------------------------------------------------------------
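A note on the losses above: they share a small, uniform calling convention (the heatmap losses take `(out, target, weight)`, the segmentation losses take `(logits, labels)`). Below is a minimal, self-contained sketch of combining `DeepLabCE` and `Weighted_MSELoss` in one step; the tensor shapes, the 0/1 guidance weight, and the unweighted sum are illustrative assumptions, not the configuration used by the repository's training script.

```python
import torch
from utils.loss import DeepLabCE, Weighted_MSELoss  # assumed import path

# Illustrative shapes: 21-way semantic logits and a 20-class center heatmap.
seg_logits  = torch.randn(2, 21, 128, 128, requires_grad=True)
seg_labels  = torch.randint(0, 21, (2, 128, 128))          # pseudo semantic labels (255 = ignore)
center_pred = torch.rand(2, 20, 128, 128, requires_grad=True)
center_gt   = torch.rand(2, 20, 128, 128)                  # pseudo center map
center_wgt  = (center_gt > 0).float()                      # instance-aware guidance weight (assumption)

seg_loss    = DeepLabCE(ignore_label=255, top_k_percent_pixels=0.2)(seg_logits, seg_labels)
center_loss = Weighted_MSELoss()(center_pred, center_gt, center_wgt)

total = seg_loss + center_loss   # unweighted sum, purely for illustration
total.backward()
```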
/NOTICE:
--------------------------------------------------------------------------------
1 | BESTIE
2 | Copyright (c) 2022-present NAVER Corp.
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy
5 | of this software and associated documentation files (the "Software"), to deal
6 | in the Software without restriction, including without limitation the rights
7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the Software is
9 | furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in
12 | all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 | THE SOFTWARE.
21 |
22 | --------------------------------------------------------------------------------------
23 |
24 | This project contains subcomponents with separate copyright notices and license terms.
25 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
26 |
27 | =====
28 |
29 | HRNet/HRNet-Human-Pose-Estimation
30 | https://github.com/HRNet/HRNet-Human-Pose-Estimation
31 |
32 |
33 | MIT License
34 |
35 | Copyright (c) 2019 Leo Xiao
36 |
37 | Permission is hereby granted, free of charge, to any person obtaining a copy
38 | of this software and associated documentation files (the "Software"), to deal
39 | in the Software without restriction, including without limitation the rights
40 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
41 | copies of the Software, and to permit persons to whom the Software is
42 | furnished to do so, subject to the following conditions:
43 |
44 | The above copyright notice and this permission notice shall be included in all
45 | copies or substantial portions of the Software.
46 |
47 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
48 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
49 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
50 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
51 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
52 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
53 | SOFTWARE.
54 |
55 | =====
56 |
57 | bowenc0221/panoptic-deeplab
58 | https://github.com/bowenc0221/panoptic-deeplab
59 |
60 |
61 | Licensed under the Apache License, Version 2.0 (the "License");
62 | you may not use this file except in compliance with the License.
63 | You may obtain a copy of the License at
64 |
65 | http://www.apache.org/licenses/LICENSE-2.0
66 |
67 | Unless required by applicable law or agreed to in writing, software
68 | distributed under the License is distributed on an "AS IS" BASIS,
69 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70 | See the License for the specific language governing permissions and
71 | limitations under the License.
72 |
73 | =====
74 |
75 | pytorch/vision
76 | https://github.com/pytorch/vision
77 |
78 |
79 | BSD 3-Clause License
80 |
81 | Copyright (c) Soumith Chintala 2016,
82 | All rights reserved.
83 |
84 | Redistribution and use in source and binary forms, with or without
85 | modification, are permitted provided that the following conditions are met:
86 |
87 | * Redistributions of source code must retain the above copyright notice, this
88 | list of conditions and the following disclaimer.
89 |
90 | * Redistributions in binary form must reproduce the above copyright notice,
91 | this list of conditions and the following disclaimer in the documentation
92 | and/or other materials provided with the distribution.
93 |
94 | * Neither the name of the copyright holder nor the names of its
95 | contributors may be used to endorse or promote products derived from
96 | this software without specific prior written permission.
97 |
98 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
99 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
100 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
101 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
102 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
103 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
104 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
105 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
106 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
107 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
108 |
109 | =====
110 |
111 | tensorflow/models
112 | https://github.com/tensorflow/models/
113 |
114 |
115 | Copyright 2016, The Authors.
116 |
117 | Licensed under the Apache License, Version 2.0 (the "License");
118 | you may not use this file except in compliance with the License.
119 | You may obtain a copy of the License at
120 |
121 | http://www.apache.org/licenses/LICENSE-2.0
122 |
123 | Unless required by applicable law or agreed to in writing, software
124 | distributed under the License is distributed on an "AS IS" BASIS,
125 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
126 | See the License for the specific language governing permissions and
127 | limitations under the License.
128 |
129 | =====
130 |
131 | qjadud1994/DRS
132 | https://github.com/qjadud1994/DRS
133 |
134 |
135 | MIT License
136 |
137 | Copyright (c) 2021 Beom
138 |
139 | Permission is hereby granted, free of charge, to any person obtaining a copy
140 | of this software and associated documentation files (the "Software"), to deal
141 | in the Software without restriction, including without limitation the rights
142 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
143 | copies of the Software, and to permit persons to whom the Software is
144 | furnished to do so, subject to the following conditions:
145 |
146 | The above copyright notice and this permission notice shall be included in all
147 | copies or substantial portions of the Software.
148 |
149 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
150 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
151 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
152 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
153 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
154 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
155 | SOFTWARE.
156 |
157 | =====
158 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BESTIE - Official Pytorch Implementation (CVPR 2022)
2 |
3 | **Beyond Semantic to Instance Segmentation: Weakly-Supervised Instance Segmentation via Semantic Knowledge Transfer and Self-Refinement (CVPR 2022)**
4 | Beomyoung Kim<sup>1</sup>, YoungJoon Yoo<sup>1,2</sup>, Chaeeun Rhee<sup>3</sup>, Junmo Kim<sup>4</sup>
5 |
6 | <sup>1</sup> NAVER CLOVA
7 | <sup>2</sup> NAVER AI Lab
8 | <sup>3</sup> Inha University
9 | <sup>4</sup> KAIST
10 |
11 | [arXiv paper](https://arxiv.org/abs/2109.09477)
12 | 
13 | [Papers with Code: image-level supervised instance segmentation](https://paperswithcode.com/sota/image-level-supervised-instance-segmentation?p=beyond-semantic-to-instance-segmentation)
14 | [Papers with Code: image-level supervised instance segmentation (2)](https://paperswithcode.com/sota/image-level-supervised-instance-segmentation-2?p=beyond-semantic-to-instance-segmentation)
15 | [Papers with Code: image-level supervised instance segmentation (1)](https://paperswithcode.com/sota/image-level-supervised-instance-segmentation-1?p=beyond-semantic-to-instance-segmentation)
16 | 
17 | [Papers with Code: point-supervised instance segmentation](https://paperswithcode.com/sota/point-supervised-instance-segmentation-on?p=beyond-semantic-to-instance-segmentation)
18 | [Papers with Code: point-supervised instance segmentation (2)](https://paperswithcode.com/sota/point-supervised-instance-segmentation-on-2?p=beyond-semantic-to-instance-segmentation)
19 | [Papers with Code: point-supervised instance segmentation (1)](https://paperswithcode.com/sota/point-supervised-instance-segmentation-on-1?p=beyond-semantic-to-instance-segmentation)
20 |
21 |
22 |
23 |
24 | # Abstract
25 |
26 | Weakly-supervised instance segmentation (WSIS) is considered a more challenging task than weakly-supervised semantic segmentation (WSSS). Compared to WSSS, WSIS requires instance-wise localization, which is difficult to extract from image-level labels. To tackle this problem, most WSIS approaches use off-the-shelf proposal techniques that require pre-training with instance- or object-level labels, deviating from the fundamental definition of the fully image-level supervised setting.
27 | In this paper, we propose a novel approach with two innovative components. First, we propose *semantic knowledge transfer* to obtain pseudo instance labels by transferring the knowledge of WSSS to WSIS while eliminating the need for off-the-shelf proposals. Second, we propose a *self-refinement* method that refines the pseudo instance labels in a self-supervised scheme and uses the refined labels for training in an online manner. Here, we identify an erroneous phenomenon, *semantic drift*, caused by missing instances in the pseudo instance labels that are categorized as background. This *semantic drift* causes confusion between background and instances during training and consequently degrades segmentation performance. We term this the *semantic drift problem* and show that our proposed *self-refinement* method eliminates it.
28 | Extensive experiments on PASCAL VOC 2012 and MS COCO demonstrate the effectiveness of our approach, which achieves considerable performance without off-the-shelf proposal techniques. The code is available at https://github.com/clovaai/BESTIE.
29 |
30 | # Experimental Results (VOC 2012, COCO)
31 |
32 |
33 |
34 |
35 | * BESTIE (HRNet48, Image-label) : 42.6 mAP50 on VOC2012 [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/BESTIE_HRNet48_image_label.pt)
36 | * BESTIE (HRNet48, point-label) : 46.7 mAP50 on VOC2012 [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/BESTIE_HRNet48_point_label.pt)
37 |
38 | Extra Sources
39 | * PAM: [[pretrained_weight]](https://github.com/clovaai/BESTIE/releases/download/asset/PAM.pth)
40 | * HRNet-W32: [[imagenet_pretrained_weight]](https://github.com/clovaai/BESTIE/releases/download/asset/hrnet_w32-36af842e.pth)
41 | * HRNet-W48: [[imagenet_pretrained_weight]](https://github.com/clovaai/BESTIE/releases/download/asset/hrnet_w48-8ef0771d.pth)
42 |
43 | # Qualitative Results
44 |
45 |
46 |
47 |
48 | # News
49 |
50 | - [x] official pytorch code release
51 | - [x] release the code for the classifier with PAM module
52 | - [ ] update training code and dataset for COCO
53 |
54 | # How To Run
55 |
56 | ### Requirements
57 | - torch>=1.10.1
58 | - torchvision>=0.11.2
59 | - chainercv>=0.13.1
60 | - numpy
61 | - pillow
62 | - scikit-learn
63 | - tqdm
64 |
65 | ### Datasets
66 |
67 | - Download Pascal VOC2012 dataset from the [official dataset homepage](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/).
68 | - `Center_points/` (ground-truth point labels) [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/Center_points.zip)
69 | - `Peak_points/` (point labels extracted by PAM module and image-level labels) [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/Peak_points.zip)
70 | - `WSSS_maps/` (weakly-supervised semantic segmentation outputs) [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/WSSS_maps.zip)
71 | - `SegmentationObject/` (ground-truth mask labels) [[download]](https://github.com/clovaai/BESTIE/releases/download/asset/SegmentationObject.zip)
72 |
73 | ```
74 | data_root/
75 | --- VOC2012/
76 | --- Annotations/
77 | --- ImageSets/
78 | --- JPEGImages/
79 | --- SegmentationObject/
80 | --- Center_points/
81 | --- Peak_points/
82 | --- WSSS_maps/
83 | ```
84 |
85 | ### Image-level Supervised Instance Segmentation on VOC2012
86 | ```
87 | # change the data ROOT in the shell script
88 | bash scripts/run_image_labels.sh
89 | ```
90 |
91 | ### Point Supervised Instance Segmentation on VOC2012
92 | ```
93 | # change the data ROOT in the shell script
94 | bash scripts/run_point_labels.sh
95 | ```
96 |
97 | ### Mask R-CNN Refinement
98 |
99 | 1. Generate COCO-style pseudo labels using the BESTIE model.
100 | 2. Train the Mask R-CNN using the pseudo-labels: https://github.com/facebookresearch/maskrcnn-benchmark .
101 |
102 |
103 | # Acknowledgement
104 |
105 | Our implementation is based on these repositories:
106 | - (Panoptic-DeepLab) https://github.com/bowenc0221/panoptic-deeplab
107 | - (HRNet) https://github.com/HRNet/HRNet-Human-Pose-Estimation
108 | - (DRS) https://github.com/qjadud1994/DRS
109 |
110 | # License
111 |
112 | ```
113 | Copyright (c) 2022-present NAVER Corp.
114 |
115 | Permission is hereby granted, free of charge, to any person obtaining a copy
116 | of this software and associated documentation files (the "Software"), to deal
117 | in the Software without restriction, including without limitation the rights
118 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
119 | copies of the Software, and to permit persons to whom the Software is
120 | furnished to do so, subject to the following conditions:
121 |
122 | The above copyright notice and this permission notice shall be included in
123 | all copies or substantial portions of the Software.
124 |
125 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
126 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
127 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
128 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
129 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
130 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
131 | THE SOFTWARE.
132 | ```
133 |
--------------------------------------------------------------------------------
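The README's "Mask R-CNN Refinement" section says to generate COCO-style pseudo labels from the BESTIE outputs. As a hedged illustration (not the repository's own converter), per-image instance masks could be written into a COCO-format annotation file roughly as follows; the `predictions` structure and all names here are assumptions, and `pycocotools` is used only for RLE encoding.

```python
import json
import numpy as np
from pycocotools import mask as mask_utils

def to_coco(predictions, categories):
    """predictions: list of dicts with 'image_id', 'height', 'width',
    'masks' (list of HxW bool arrays) and 'labels' (list of 1-based category ids)."""
    images, annotations, ann_id = [], [], 1
    for p in predictions:
        images.append({"id": p["image_id"], "height": p["height"], "width": p["width"]})
        for binary_mask, cat_id in zip(p["masks"], p["labels"]):
            rle = mask_utils.encode(np.asfortranarray(binary_mask.astype(np.uint8)))
            rle["counts"] = rle["counts"].decode("utf-8")   # make the RLE JSON-serializable
            annotations.append({
                "id": ann_id, "image_id": p["image_id"], "category_id": int(cat_id),
                "segmentation": rle, "area": float(mask_utils.area(rle)),
                "bbox": [float(v) for v in mask_utils.toBbox(rle)], "iscrowd": 0,
            })
            ann_id += 1
    return {"images": images, "annotations": annotations,
            "categories": [{"id": i + 1, "name": n} for i, n in enumerate(categories)]}

# e.g. json.dump(to_coco(preds, voc_class_names), open("voc_pseudo_coco.json", "w"))
```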
/PAM/models/classifier.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/models/vgg.py
3 | # ------------------------------------------------------------------------------
4 | """
5 | BESTIE
6 | Copyright (c) 2022-present NAVER Corp.
7 | MIT License
8 | """
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.utils.model_zoo as model_zoo
13 | import torch.nn.functional as F
14 | import math
15 | import cv2
16 | import numpy as np
17 | import os
18 |
19 | model_urls = {'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth'}
20 |
21 |
22 | class PAM(nn.Module):
23 | def __init__(self, alpha):
24 | super(PAM, self).__init__()
25 | self.relu = nn.ReLU(inplace=True)
26 | self.selector = nn.AdaptiveMaxPool2d(1)
27 | self.alpha = alpha
28 |
29 | def forward(self, x):
30 | b, c, _, _ = x.size()
31 | x = self.relu(x)
32 |
33 | """ 1: selector """
34 | peak_region = self.selector(x).view(b, c, 1, 1)
35 | peak_region = peak_region.expand_as(x)
36 |
37 | """ 2: Controller -> self.alpha"""
38 | boundary = (x < peak_region * self.alpha)
39 |
40 | """ 3: Peak Stimulator"""
41 | x = torch.where(boundary, torch.zeros_like(x), x)
42 |
43 | return x
44 |
45 |
46 | class VGG(nn.Module):
47 |
48 | def __init__(self, features, num_classes=20,
49 | alpha=0.7, init_weights=True):
50 |
51 | super(VGG, self).__init__()
52 |
53 | self.features = features
54 |
55 | self.layer1_conv1 = features[0]
56 | self.layer1_conv2 = features[2]
57 | self.layer1_maxpool = features[4]
58 |
59 | self.layer2_conv1 = features[5]
60 | self.layer2_conv2 = features[7]
61 | self.layer2_maxpool = features[9]
62 |
63 | self.layer3_conv1 = features[10]
64 | self.layer3_conv2 = features[12]
65 | self.layer3_conv3 = features[14]
66 | self.layer3_maxpool = features[16]
67 |
68 | self.layer4_conv1 = features[17]
69 | self.layer4_conv2 = features[19]
70 | self.layer4_conv3 = features[21]
71 | self.layer4_maxpool = features[23]
72 |
73 | self.layer5_conv1 = features[24]
74 | self.layer5_conv2 = features[26]
75 | self.layer5_conv3 = features[28]
76 |
77 | self.extra_conv1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
78 | self.extra_conv2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
79 | self.extra_conv3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
80 | self.extra_conv4 = nn.Conv2d(512, 20, kernel_size=1)
81 |
82 | self.pam = PAM(alpha)
83 | self.relu = nn.ReLU(inplace=True)
84 |
85 | if init_weights:
86 | self._initialize_weights(self.extra_conv1)
87 | self._initialize_weights(self.extra_conv2)
88 | self._initialize_weights(self.extra_conv3)
89 | self._initialize_weights(self.extra_conv4)
90 |
91 | def forward(self, x, label=None, size=None):
92 | if size is None:
93 | size = x.size()[2:]
94 |
95 | # layer1
96 | x = self.layer1_conv1(x)
97 | x = self.relu(x)
98 | x = self.layer1_conv2(x)
99 | x = self.relu(x)
100 | x = self.layer1_maxpool(x)
101 |
102 | # layer2
103 | x = self.layer2_conv1(x)
104 | x = self.relu(x)
105 | x = self.layer2_conv2(x)
106 | x = self.relu(x)
107 | x = self.layer2_maxpool(x)
108 |
109 | # layer3
110 | x = self.layer3_conv1(x)
111 | x = self.relu(x)
112 | x = self.layer3_conv2(x)
113 | x = self.relu(x)
114 | x = self.layer3_conv3(x)
115 | x = self.relu(x)
116 | x = self.layer3_maxpool(x)
117 |
118 | # layer4
119 | x = self.layer4_conv1(x)
120 | x = self.relu(x)
121 | x = self.layer4_conv2(x)
122 | x = self.relu(x)
123 | x = self.layer4_conv3(x)
124 | x = self.relu(x)
125 | x = self.layer4_maxpool(x)
126 |
127 | # layer5
128 | x = self.layer5_conv1(x)
129 | x = self.relu(x)
130 | x = self.layer5_conv2(x)
131 | x = self.relu(x)
132 | x = self.layer5_conv3(x)
133 | x = self.relu(x)
134 | # ==============================
135 |
136 | x = self.extra_conv1(x)
137 | x = self.pam(x)
138 | x = self.extra_conv2(x)
139 | x = self.pam(x)
140 | x = self.extra_conv3(x)
141 | x = self.pam(x)
142 | x = self.extra_conv4(x)
143 | # ==============================
144 |
145 | logit = self.fc(x)
146 |
147 | if self.training:
148 | return logit
149 |
150 | else:
151 | cam = self.cam_normalize(x.detach(), size, label)
152 | return logit, cam
153 |
154 |
155 | def fc(self, x):
156 | x = F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=0)
157 | x = x.view(-1, 20)
158 | return x
159 |
160 |
161 | def cam_normalize(self, cam, size, label):
162 | B, C, H, W = cam.size()
163 |
164 | cam = F.relu(cam)
165 | cam = cam * label[:, :, None, None]
166 |
167 | cam = F.interpolate(cam, size=size, mode='bilinear', align_corners=False)
168 | cam /= F.adaptive_max_pool2d(cam, 1) + 1e-5
169 |
170 | return cam
171 |
172 |
173 | def _initialize_weights(self, layer):
174 | for m in layer.modules():
175 | if isinstance(m, nn.Conv2d):
176 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
177 | m.weight.data.normal_(0, math.sqrt(2. / n))
178 | if m.bias is not None:
179 | m.bias.data.zero_()
180 | elif isinstance(m, nn.BatchNorm2d):
181 | m.weight.data.fill_(1)
182 | m.bias.data.zero_()
183 | elif isinstance(m, nn.Linear):
184 | m.weight.data.normal_(0, 0.01)
185 | m.bias.data.zero_()
186 |
187 | def get_parameter_groups(self):
188 | groups = ([], [], [], [])
189 |
190 | for name, value in self.named_parameters():
191 |
192 | if 'extra' in name:
193 | if 'weight' in name:
194 | groups[2].append(value)
195 | else:
196 | groups[3].append(value)
197 | else:
198 | if 'weight' in name:
199 | groups[0].append(value)
200 | else:
201 | groups[1].append(value)
202 | return groups
203 |
204 |
205 |
206 |
207 | #######################################################################################################
208 |
209 |
210 | def make_layers(cfg, batch_norm=False):
211 | layers = []
212 | in_channels = 3
213 | for i, v in enumerate(cfg):
214 | if v == 'M':
215 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
216 | elif v == 'N':
217 | layers += [nn.MaxPool2d(kernel_size=3, stride=1, padding=1)]
218 | else:
219 | if i > 13:
220 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, dilation=2, padding=2)
221 | else:
222 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
223 | if batch_norm:
224 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
225 | else:
226 | layers += [conv2d, nn.ReLU(inplace=True)]
227 | in_channels = v
228 | return nn.Sequential(*layers)
229 |
230 |
231 | cfg = {
232 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
233 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
234 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
235 | 'D1': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'N', 512, 512, 512],
236 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
237 | }
238 |
239 |
240 | def vgg16_pam(pretrained=True, alpha=0.7):
241 | model = VGG(make_layers(cfg['D1']), alpha=alpha)
242 |
243 | if pretrained:
244 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16']), strict=False)
245 | return model
246 |
247 |
248 | if __name__ == '__main__':
249 | import copy
250 |
251 |     model = vgg16_pam(pretrained=True)
252 |     model.eval()  # eval mode returns (logit, cam); training mode returns only logit
253 |
254 | print(model)
255 |
256 | input = torch.randn(2, 3, 321, 321)
257 | label = np.array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])
258 | label = torch.from_numpy(label)
259 |
260 | out = model(input, label)
261 |
262 | print(out[1].shape)
263 |
264 |
--------------------------------------------------------------------------------
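The classifier above returns `(logit, cam)` in eval mode, with the PAM blocks suppressing activations below `alpha` times the per-channel peak. Below is a hedged sketch of building the model and reading peak coordinates off the normalized CAM with a generic local-maximum search; the pooling window and the 0.7 confidence threshold are illustrative assumptions and may differ from what `point_extraction.py` actually does.

```python
import torch
import torch.nn.functional as F
from PAM.models.classifier import vgg16_pam  # assumed import path

model = vgg16_pam(pretrained=False, alpha=0.7).eval()   # set pretrained=True to load ImageNet VGG16 weights

image = torch.randn(1, 3, 321, 321)                      # dummy input
label = torch.zeros(1, 20); label[0, 14] = 1.            # e.g. class 14 ("person") present

with torch.no_grad():
    logit, cam = model(image, label)                     # eval mode returns (logit, cam)

# Peaks = local maxima of the class activation map above a confidence threshold.
pooled = F.max_pool2d(cam, kernel_size=11, stride=1, padding=5)
peaks = (cam == pooled) & (cam > 0.7)                    # 0.7 is an illustrative threshold
ys, xs = torch.nonzero(peaks[0, 14], as_tuple=True)      # peak coordinates for class 14
```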
/utils/decode.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image  # Image.fromarray is used in cam_to_seg below
4 | def colorize_offset(offset_map, offset_weight, seg_map=None, pred=True):
5 |
6 | import matplotlib.colors
7 | import math
8 |
9 | a = (np.arctan2(-offset_map[0], -offset_map[1]) / math.pi + 1) / 2
10 |
11 | r = np.sqrt(offset_map[0] ** 2 + offset_map[1] ** 2)
12 | s = r / (np.max(r) + 1e-5)
13 |
14 | hsv_color = np.stack((a, s, np.ones_like(a)), axis=-1)
15 | rgb_color = matplotlib.colors.hsv_to_rgb(hsv_color)
16 | rgb_color = np.uint8(rgb_color * 255)
17 |
18 | if seg_map is not None:
19 | rgb_color[np.where(seg_map == 0)] = [0, 0, 0] # background
20 |
21 | if not pred:
22 | rgb_color[np.where(offset_weight == 0)] = [255, 255, 255] # ignore
23 |
24 | return rgb_color
25 |
26 |
27 | def voc_names():
28 | return [
29 | "background", "aeroplane", "bicycle", "bird",
30 | "boat", "bottle", "bus", "car", "cat", "chair",
31 | "cow", "diningtable", "dog", "horse", "motorbike",
32 | "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor",
33 | ]
34 |
35 |
36 | def get_palette():
37 | palette = []
38 | for i in range(256):
39 | palette.extend((i,i,i))
40 | palette[:3*21] = np.array([[0, 0, 0],
41 | [128, 0, 0],
42 | [0, 128, 0],
43 | [128, 128, 0],
44 | [0, 0, 128],
45 | [128, 0, 128],
46 | [0, 128, 128],
47 | [128, 128, 128],
48 | [64, 0, 0],
49 | [192, 0, 0],
50 | [64, 128, 0],
51 | [192, 128, 0],
52 | [64, 0, 128],
53 | [192, 0, 128],
54 | [64, 128, 128],
55 | [192, 128, 128],
56 | [0, 64, 0],
57 | [128, 64, 0],
58 | [0, 192, 0],
59 | [128, 192, 0],
60 | [0, 64, 128]], dtype='uint8').flatten()
61 |
62 | return palette
63 |
64 | def voc_colors():
65 | colors = np.array([[0, 0, 0],
66 | [128, 0, 0],
67 | [0, 128, 0],
68 | [128, 128, 0],
69 | [0, 0, 128],
70 | [128, 0, 128],
71 | [0, 128, 128],
72 | [128, 128, 128],
73 | [64, 0, 0],
74 | [192, 0, 0],
75 | [64, 128, 0],
76 | [192, 128, 0],
77 | [64, 0, 128],
78 | [192, 0, 128],
79 | [64, 128, 128],
80 | [192, 128, 128],
81 | [0, 64, 0],
82 | [128, 64, 0],
83 | [0, 192, 0],
84 | [128, 192, 0],
85 | [0, 64, 128],
86 | [255, 255, 255],
87 | [200, 200, 200]], dtype='uint8')
88 | return colors
89 |
90 |
91 | def cam_to_seg(CAM, sal_map, palette, alpha=0.2, ignore=255):
92 | colors = voc_colors()
93 | C, H, W = CAM.shape
94 |
95 | CAM[CAM < alpha] = 0 # object cue
96 |
97 | bg = np.zeros((1, H, W), dtype=np.float32)
98 | pred_map = np.concatenate([bg, CAM], axis=0) # [21, H, W]
99 |
100 |     pred_map[0, :, :] = (1. - sal_map) # background cue
101 |
102 | # conflict pixels with multiple confidence values
103 | bg = np.array(pred_map > 0.99, dtype=np.uint8)
104 | bg = np.sum(bg, axis=0)
105 | pred_map = pred_map.argmax(0).astype(np.uint8)
106 | pred_map[bg > 2] = ignore
107 |
108 |     # pixels regarded as background but with confident saliency values
109 | bg = (sal_map == 1).astype(np.uint8) * (pred_map == 0).astype(np.uint8) # and operator
110 | pred_map[bg > 0] = ignore
111 |
112 | pred_map = np.uint8(pred_map)
113 |
114 | palette = get_palette()
115 | pred_map = Image.fromarray(pred_map)
116 | pred_map.putpalette(palette)
117 |
118 | return pred_map
119 |
120 |
121 | def cam_with_crf(cam, img, keys, fg_thresh=0.1, bg_thresh=0.001):
122 | cam = np.float32(cam) / 255.
123 |
124 | cam = cam[keys] # valid category selection
125 |
126 | valid_cat = np.pad(keys + 1, (1, 0), mode='constant') # valid category : [background, val_cat+1]
127 |
128 | fg_conf_cam = np.pad(cam, ((1, 0), (0, 0), (0, 0)), mode='constant', constant_values=fg_thresh) # [c+1, H, W]
129 | fg_conf_cam = np.argmax(fg_conf_cam, axis=0)
130 | pred = crf_inference_label(img, fg_conf_cam, n_labels=valid_cat.shape[0])
131 | fg_conf = valid_cat[pred] # convert to whole index (0, 1, 2) -> (0 ~ 20)
132 |
133 | bg_conf_cam = np.pad(cam, ((1, 0), (0, 0), (0, 0)), mode='constant', constant_values=bg_thresh) # [c+1, H, W]
134 | bg_conf_cam = np.argmax(bg_conf_cam, axis=0)
135 | pred = crf_inference_label(img, bg_conf_cam, n_labels=valid_cat.shape[0])
136 | bg_conf = valid_cat[pred] # convert to whole index (0, 1, 2) -> (0 ~ 20)
137 |
138 | conf = fg_conf.copy()
139 | conf[fg_conf == 0] = 21
140 | conf[bg_conf + fg_conf == 0] = 0 # both zero
141 |
142 |     conf_color = voc_colors()[conf]  # 'colors' was undefined here; use the VOC color table
143 |
144 | return conf, conf_color
145 |
146 |
147 | def heatmap_colorize(score_map, exclude_zero=True, normalize=True):
148 | import matplotlib.colors
149 |
150 | VOC_color = np.array([(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
151 | (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0),
152 | (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0),
153 | (0, 192, 0), (128, 192, 0), (0, 64, 128), (255, 255, 255)], np.float32)
154 |
155 | if exclude_zero:
156 | VOC_color = VOC_color[1:]
157 |
158 | test = VOC_color[np.argmax(score_map, axis=0)%22]
159 | test = np.expand_dims(np.max(score_map, axis=0), axis=-1) * test
160 | if normalize:
161 | test /= np.max(test) + 1e-5
162 |
163 | return test
164 |
165 | def decode_seg_map_sequence(label_masks):
166 | if label_masks.ndim == 2:
167 | label_masks = label_masks[None, :, :]
168 |
169 | rgb_masks = []
170 | for label_mask in label_masks:
171 | rgb_mask = decode_segmap(label_mask)
172 | rgb_masks.append(rgb_mask)
173 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
174 | return rgb_masks
175 |
176 |
177 | def decode_segmap(label_mask, plot=False):
178 | """Decode segmentation class labels into a color image
179 | Args:
180 | label_mask (np.ndarray): an (M,N) array of integer values denoting
181 | the class label at each spatial location.
182 | plot (bool, optional): whether to show the resulting color image
183 | in a figure.
184 | Returns:
185 | (np.ndarray, optional): the resulting decoded color image.
186 | """
187 |
188 | n_classes = 21
189 | label_colours = get_pascal_labels()
190 |
191 | r = label_mask.copy()
192 | g = label_mask.copy()
193 | b = label_mask.copy()
194 | for ll in range(0, n_classes):
195 | r[label_mask == ll] = label_colours[ll, 0]
196 | g[label_mask == ll] = label_colours[ll, 1]
197 | b[label_mask == ll] = label_colours[ll, 2]
198 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
199 | rgb[:, :, 0] = r / 255.0
200 | rgb[:, :, 1] = g / 255.0
201 | rgb[:, :, 2] = b / 255.0
202 | if plot:
203 |         import matplotlib.pyplot as plt
204 |         plt.imshow(rgb); plt.show()
205 | else:
206 | return rgb
207 |
208 |
209 | def encode_segmap(mask):
210 | """Encode segmentation label images as pascal classes
211 | Args:
212 | mask (np.ndarray): raw segmentation label image of dimension
213 | (M, N, 3), in which the Pascal classes are encoded as colours.
214 | Returns:
215 | (np.ndarray): class map with dimensions (M,N), where the value at
216 | a given location is the integer denoting the class index.
217 | """
218 | mask = mask.astype(int)
219 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
220 | for ii, label in enumerate(get_pascal_labels()):
221 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
222 | label_mask = label_mask.astype(int)
223 | return label_mask
224 |
225 |
226 | def get_pascal_labels():
227 | """Load the mapping that associates pascal classes with label colors
228 | Returns:
229 | np.ndarray with dimensions (21, 3)
230 | """
231 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
232 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
233 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
234 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
235 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
236 | [0, 64, 128]])
237 |
238 |
239 | def get_ins_colors():
240 | ins_colors = np.random.random((2000, 3))
241 | ins_colors = np.uint8(ins_colors*255)
242 | ins_colors[0] = [0, 0, 0]
243 | ins_colors[1] = [192, 128, 0]
244 | ins_colors[2] = [64, 0, 128]
245 | ins_colors[3] = [192, 0, 128]
246 | ins_colors[4] = [64, 128, 128]
247 | ins_colors[5] = [192, 128, 128]
248 | ins_colors[6] = [0, 64, 0]
249 | ins_colors[7] = [128, 64, 0]
250 | ins_colors[8] = [0, 192, 0]
251 | ins_colors[9] = [128, 192, 0]
252 | ins_colors[10] = [0, 64, 128]
253 | ins_colors[11] = [128, 0, 0]
254 | ins_colors[12] = [0, 128, 0]
255 | ins_colors[13] = [128, 128, 0]
256 | ins_colors[14] = [0, 0, 128]
257 | ins_colors[15] = [128, 0, 128]
258 | ins_colors[16] = [0, 128, 128]
259 | ins_colors[17] = [128, 128, 128]
260 | ins_colors[18] = [64, 0, 0]
261 | ins_colors[19] = [192, 0, 0]
262 | ins_colors[20] = [64, 128, 0]
263 | return ins_colors
--------------------------------------------------------------------------------
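Two notes on `utils/decode.py` above: `cam_with_crf` calls `crf_inference_label`, which is assumed to live in `utils/crf.py`, and the paletted output of `cam_to_seg` relies on the VOC palette returned by `get_palette()`. As a small, hedged usage sketch, a class-index prediction can be saved as a color-indexed PNG like this (the random array stands in for a real prediction):

```python
import numpy as np
from PIL import Image
from utils.decode import get_palette  # assumed import path

pred = np.random.randint(0, 21, size=(256, 256), dtype=np.uint8)  # fake 21-class prediction

png = Image.fromarray(pred, mode="P")   # treat class indices as palette indices
png.putpalette(get_palette())           # VOC colors for classes 0-20, gray ramp elsewhere
png.save("pred_color.png")
```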
/utils/LoadData.py:
--------------------------------------------------------------------------------
1 | """
2 | BESTIE
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT License
5 | """
6 |
7 | import os
8 | from PIL import Image
9 | import numpy as np
10 |
11 | import torch
12 | import torchvision.transforms.functional as TF
13 | from .transforms import transforms as T
14 |
15 |
16 | from torch.utils.data import Dataset
17 | from .utils import gaussian, pseudo_label_generation
18 |
19 | def get_dataset(args, mode):
20 | if 'coco' in args.dataset:
21 | mean_vals = [0.471, 0.448, 0.408]
22 | std_vals = [0.234, 0.239, 0.242]
23 | else:
24 | mean_vals = [0.485, 0.456, 0.406]
25 | std_vals = [0.229, 0.224, 0.225]
26 |
27 | if mode == 'train':
28 | data_list = "data/train_cls.txt" # 10,582 images
29 |
30 | crop_size = int(args.crop_size)
31 | input_size = crop_size + 64
32 |
33 | min_resize_value = input_size
34 | max_resize_value = input_size
35 | resize_factor = 32
36 |
37 | min_scale = 0.7
38 | max_scale = 1.3
39 | scale_step_size=0.1
40 | crop_h, crop_w = crop_size, crop_size
41 |
42 | pad_value = tuple([int(v * 255) for v in mean_vals])
43 | ignore_label = (0, 0, 0)
44 |
45 | transform = T.Compose(
46 | [
47 | T.PhotometricDistort(),
48 | T.Resize(min_resize_value,
49 | max_resize_value,
50 | resize_factor),
51 | T.RandomScale(
52 | min_scale,
53 | max_scale,
54 | scale_step_size
55 | ),
56 | T.RandomCrop(
57 | crop_h,
58 | crop_w,
59 | pad_value,
60 | ignore_label,
61 | random_pad=True
62 | ),
63 | T.ToTensor(),
64 | T.RandomHorizontalFlip(),
65 | T.Normalize(
66 | mean_vals,
67 | std_vals
68 | )
69 | ]
70 | )
71 |
72 | dataset = VOCDataset(data_list,
73 | root_dir=args.root_dir,
74 | num_classes=args.num_classes,
75 | transform=transform,
76 | sup=args.sup,
77 | sigma=args.sigma,
78 | point_thresh=args.pseudo_thresh)
79 |
80 | else:
81 | #data_list = "data/train_labeled_cls.txt"
82 | data_list = "data/val_cls.txt"
83 |
84 | dataset = VOCTestDataset(data_list,
85 | root_dir=args.root_dir,
86 | num_classes=args.num_classes)
87 |
88 | return dataset
89 |
90 |
91 | class VOCDataset(Dataset):
92 | def __init__(self, datalist_file, root_dir, num_classes=20,
93 | transform=None, sup='cls', sigma=8, point_thresh=0.5):
94 |
95 | self.num_classes = num_classes
96 | self.transform = transform
97 | self.sigma = sigma
98 | self.sup = sup
99 | self.point_thresh = point_thresh
100 |
101 | self.g = gaussian(sigma)
102 |
103 | self.dat_list = self.read_labeled_image_list(root_dir, datalist_file)
104 |
105 |
106 | def __getitem__(self, idx):
107 | img_path = self.dat_list["img"][idx]
108 | seg_map_path = self.dat_list["seg_map"][idx]
109 | cls_label = self.dat_list["cls_label"][idx]
110 | points = self.dat_list["point"][idx]
111 |
112 | img = np.uint8(Image.open(img_path).convert("RGB"))
113 | seg_map = np.uint8(Image.open(seg_map_path))
114 |
115 | if self.transform is not None:
116 | img, seg_map, points = self.transform(img, seg_map, points)
117 |
118 | center_map, offset_map, weight = pseudo_label_generation(self.sup,
119 | seg_map.numpy(),
120 | points,
121 | cls_label,
122 | self.num_classes,
123 | self.sigma,
124 | self.g)
125 |
126 | point_list = self.make_class_wise_point_list(points)
127 |
128 | seg_map = seg_map.long()
129 | center_map = torch.from_numpy(center_map)
130 | offset_map = torch.from_numpy(offset_map)
131 | weight = torch.from_numpy(weight)
132 | point_list = torch.from_numpy(point_list)
133 |
134 | return img, cls_label, seg_map, center_map, offset_map, weight, point_list
135 |
136 |
137 | def make_class_wise_point_list(self, points):
138 |
139 | MAX_NUM_POINTS = 128
140 |
141 | point_list = np.zeros((self.num_classes, MAX_NUM_POINTS, 2), dtype=np.int32)
142 | point_count = [0 for _ in range(self.num_classes)]
143 |
144 | for (x, y, cls, _ ) in points:
145 | point_list[cls][point_count[cls]] = [y, x]
146 | point_count[cls] += 1
147 |
148 | return point_list
149 |
150 |
151 | def read_labeled_image_list(self, root_dir, data_list):
152 | img_dir = os.path.join(root_dir, "JPEGImages")
153 | seg_map_dir = os.path.join(root_dir, "WSSS_maps")
154 |
155 | if self.sup == 'point':
156 | point_dir = os.path.join(root_dir, "Center_points")
157 | else:
158 | point_dir = os.path.join(root_dir, "Peak_points")
159 |
160 | with open(data_list, 'r') as f:
161 | lines = f.read().splitlines()
162 |
163 | img_list = []
164 | label_list = []
165 | seg_map_list = []
166 | point_list = []
167 |
168 | np.random.shuffle(lines)
169 |
170 | for line in lines:
171 | fields = line.strip().split(" ")
172 | # fields[0] : file_name
173 | # fields[1:] : cls labels
174 |
175 | image_path = os.path.join(img_dir, fields[0] + '.jpg')
176 | seg_map_path = os.path.join(seg_map_dir, fields[0] + '.png')
177 | point_txt = os.path.join(point_dir, fields[0] + '.txt')
178 |
179 | # one-hot cls label
180 | labels = np.zeros((self.num_classes,), dtype=np.float32)
181 | for i in range(len(fields)-1):
182 | index = int(fields[i+1])
183 | labels[index] = 1.
184 |
185 | # get points
186 | with open(point_txt, 'r') as pf:
187 | points = pf.read().splitlines()
188 | points = [p.strip().split(" ") for p in points]
189 | points = [ [float(p[0]), float(p[1]), int(p[2]), float(p[3])] for p in points if float(p[3]) > self.point_thresh]
190 | # point (x_coord, y_coord, class-idx, conf)
191 |
192 | img_list.append(image_path)
193 | seg_map_list.append(seg_map_path)
194 | point_list.append(points)
195 | label_list.append(labels)
196 |
197 | return {"img": img_list,
198 | "cls_label": label_list,
199 | "seg_map": seg_map_list,
200 | "point": point_list}
201 |
202 | def __len__(self):
203 | return len(self.dat_list["img"])
204 |
205 |
206 |
207 |
208 | class VOCTestDataset(Dataset):
209 | def __init__(self, datalist_file, root_dir, num_classes=20):
210 |
211 | self.num_classes = num_classes
212 | self.dat_list = self.read_labeled_image_list(root_dir, datalist_file)
213 |
214 |
215 | def __getitem__(self, idx):
216 | fname = self.dat_list["fname"][idx]
217 | img_path = self.dat_list["img"][idx]
218 | cls_label = self.dat_list["cls_label"][idx]
219 | points = self.dat_list["point"][idx]
220 |
221 | img = Image.open(img_path).convert("RGB")
222 |
223 | ori_w, ori_h = img.size
224 |
225 | new_h = (ori_h + 31) // 32 * 32
226 | new_w = (ori_w + 31) // 32 * 32
227 |
228 | img = img.resize((new_w, new_h), Image.BILINEAR)
229 | img = TF.to_tensor(img)
230 | img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
231 |
232 | return img, cls_label, points, fname, (ori_h, ori_w)
233 |
234 |
235 | def read_labeled_image_list(self, root_dir, data_list):
236 | img_dir = os.path.join(root_dir, "JPEGImages")
237 | mask_dir = os.path.join(root_dir, "SegmentationObjectAug")
238 | point_dir = os.path.join(root_dir, "Center_points")
239 |
240 | with open(data_list, 'r') as f:
241 | lines = f.read().splitlines()
242 |
243 | fname_list = []
244 | img_list = []
245 | mask_list = []
246 | label_list = []
247 | point_list = []
248 |
249 | for line in lines:
250 | fields = line.strip().split(" ")
251 | # fields[0] : file_name
252 | # fields[1:] : cls labels
253 |
254 | image_path = os.path.join(img_dir, fields[0] + '.jpg')
255 | mask_path = os.path.join(mask_dir, fields[0] + '.png')
256 | point_txt = os.path.join(point_dir, fields[0] + '.txt')
257 |
258 | # one-hot cls label
259 | labels = np.zeros((self.num_classes,), dtype=np.float32)
260 | for i in range(len(fields)-1):
261 | index = int(fields[i+1])
262 | labels[index] = 1.
263 |
264 | # get points
265 | with open(point_txt, 'r') as pf:
266 | points = pf.read().splitlines()
267 | points = [p.strip().split(" ") for p in points]
268 | points = [ [float(p[0]), float(p[1]), int(p[2]), float(p[3])] for p in points]
269 | # point (x_coord, y_coord, class-idx, conf)
270 |
271 | points_cls = [[] for _ in range(self.num_classes)]
272 | for (x, y, cls, _ ) in points:
273 | points_cls[cls].append((y, x))
274 |
275 | fname_list.append(fields[0])
276 | img_list.append(image_path)
277 | mask_list.append(mask_path)
278 | point_list.append(points_cls)
279 | label_list.append(labels)
280 |
281 | return {"img": img_list,
282 | "mask": mask_list,
283 | "cls_label": label_list,
284 | "point": point_list,
285 | "fname": fname_list}
286 |
287 | def __len__(self):
288 | return len(self.dat_list["img"])
289 |
290 |
--------------------------------------------------------------------------------
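For orientation, here is a hedged sketch of building the training dataset and loader with `get_dataset` above. The argument names mirror those the code reads from `args`; the values are illustrative, and running it requires the VOC2012 data laid out as described in the README plus `data/train_cls.txt`.

```python
from types import SimpleNamespace
from torch.utils.data import DataLoader
from utils.LoadData import get_dataset  # assumed import path

# Argument names follow get_dataset/VOCDataset; values are illustrative only.
args = SimpleNamespace(
    dataset="voc", root_dir="data_root/VOC2012", num_classes=20,
    crop_size=416, sup="point", sigma=8, pseudo_thresh=0.5,
)

train_set = get_dataset(args, mode="train")
train_loader = DataLoader(train_set, batch_size=8, shuffle=True,
                          num_workers=4, drop_last=True)

img, cls_label, seg_map, center_map, offset_map, weight, point_list = next(iter(train_loader))
```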
/models/resnet.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Reference: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
3 | # Modified by Bowen Cheng (bcheng9@illinois.edu)
4 | # ------------------------------------------------------------------------------
5 |
6 | import torch.nn as nn
7 | try:
8 | from torchvision.models.utils import load_state_dict_from_url
9 | except:
10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
11 |
12 |
13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
14 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
15 | 'wide_resnet50_2', 'wide_resnet101_2']
16 |
17 |
18 | model_urls = {
19 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
20 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
21 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
22 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
23 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
24 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
25 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
26 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
27 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
28 | }
29 |
30 |
31 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
32 | """3x3 convolution with padding"""
33 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
34 | padding=dilation, groups=groups, bias=False, dilation=dilation)
35 |
36 |
37 | def conv1x1(in_planes, out_planes, stride=1):
38 | """1x1 convolution"""
39 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
40 |
41 |
42 | class BasicBlock(nn.Module):
43 | expansion = 1
44 |
45 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
46 | base_width=64, dilation=1, norm_layer=None):
47 | super(BasicBlock, self).__init__()
48 | if norm_layer is None:
49 | norm_layer = nn.BatchNorm2d
50 | if groups != 1 or base_width != 64:
51 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
52 | if dilation > 1:
53 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
54 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
55 | self.conv1 = conv3x3(inplanes, planes, stride)
56 | self.bn1 = norm_layer(planes)
57 | self.relu = nn.ReLU(inplace=True)
58 | self.conv2 = conv3x3(planes, planes)
59 | self.bn2 = norm_layer(planes)
60 | self.downsample = downsample
61 | self.stride = stride
62 |
63 | def forward(self, x):
64 | identity = x
65 |
66 | out = self.conv1(x)
67 | out = self.bn1(out)
68 | out = self.relu(out)
69 |
70 | out = self.conv2(out)
71 | out = self.bn2(out)
72 |
73 | if self.downsample is not None:
74 | identity = self.downsample(x)
75 |
76 | out += identity
77 | out = self.relu(out)
78 |
79 | return out
80 |
81 |
82 | class Bottleneck(nn.Module):
83 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
84 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
85 |     # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
86 | # This variant is also known as ResNet V1.5 and improves accuracy according to
87 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
88 |
89 | expansion = 4
90 |
91 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
92 | base_width=64, dilation=1, norm_layer=None):
93 | super(Bottleneck, self).__init__()
94 | if norm_layer is None:
95 | norm_layer = nn.BatchNorm2d
96 | width = int(planes * (base_width / 64.)) * groups
97 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
98 | self.conv1 = conv1x1(inplanes, width)
99 | self.bn1 = norm_layer(width)
100 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
101 | self.bn2 = norm_layer(width)
102 | self.conv3 = conv1x1(width, planes * self.expansion)
103 | self.bn3 = norm_layer(planes * self.expansion)
104 | self.relu = nn.ReLU(inplace=True)
105 | self.downsample = downsample
106 | self.stride = stride
107 |
108 | def forward(self, x):
109 | identity = x
110 |
111 | out = self.conv1(x)
112 | out = self.bn1(out)
113 | out = self.relu(out)
114 |
115 | out = self.conv2(out)
116 | out = self.bn2(out)
117 | out = self.relu(out)
118 |
119 | out = self.conv3(out)
120 | out = self.bn3(out)
121 |
122 | if self.downsample is not None:
123 | identity = self.downsample(x)
124 |
125 | out += identity
126 | out = self.relu(out)
127 |
128 | return out
129 |
130 |
131 | class ResNet(nn.Module):
132 |
133 | def __init__(self, block, layers, zero_init_residual=False,
134 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
135 | norm_layer=None):
136 | super(ResNet, self).__init__()
137 | if norm_layer is None:
138 | norm_layer = nn.BatchNorm2d
139 | self._norm_layer = norm_layer
140 |
141 | self.inplanes = 64
142 | self.dilation = 1
143 | if replace_stride_with_dilation is None:
144 | # each element in the tuple indicates if we should replace
145 | # the 2x2 stride with a dilated convolution instead
146 | replace_stride_with_dilation = [False, False, False]
147 | if len(replace_stride_with_dilation) != 3:
148 | raise ValueError("replace_stride_with_dilation should be None "
149 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
150 | self.groups = groups
151 | self.base_width = width_per_group
152 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
153 | bias=False)
154 | self.bn1 = norm_layer(self.inplanes)
155 | self.relu = nn.ReLU(inplace=True)
156 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
157 | self.layer1 = self._make_layer(block, 64, layers[0])
158 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
159 | dilate=replace_stride_with_dilation[0])
160 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
161 | dilate=replace_stride_with_dilation[1])
162 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
163 | dilate=replace_stride_with_dilation[2])
164 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
165 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
166 |
167 | for m in self.modules():
168 | if isinstance(m, nn.Conv2d):
169 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
170 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
171 | nn.init.constant_(m.weight, 1)
172 | nn.init.constant_(m.bias, 0)
173 |
174 | # Zero-initialize the last BN in each residual branch,
175 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
176 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
177 | if zero_init_residual:
178 | for m in self.modules():
179 | if isinstance(m, Bottleneck):
180 | nn.init.constant_(m.bn3.weight, 0)
181 | elif isinstance(m, BasicBlock):
182 | nn.init.constant_(m.bn2.weight, 0)
183 |
184 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
185 | norm_layer = self._norm_layer
186 | downsample = None
187 | previous_dilation = self.dilation
188 | if dilate:
189 | self.dilation *= stride
190 | stride = 1
191 | if stride != 1 or self.inplanes != planes * block.expansion:
192 | downsample = nn.Sequential(
193 | conv1x1(self.inplanes, planes * block.expansion, stride),
194 | norm_layer(planes * block.expansion),
195 | )
196 |
197 | layers = []
198 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
199 | self.base_width, previous_dilation, norm_layer))
200 | self.inplanes = planes * block.expansion
201 | for _ in range(1, blocks):
202 | layers.append(block(self.inplanes, planes, groups=self.groups,
203 | base_width=self.base_width, dilation=self.dilation,
204 | norm_layer=norm_layer))
205 |
206 | return nn.Sequential(*layers)
207 |
208 | def _forward_impl(self, x):
209 | outputs = {}
210 | # See note [TorchScript super()]
211 | x = self.conv1(x)
212 | x = self.bn1(x)
213 | x = self.relu(x)
214 | x = self.maxpool(x)
215 | outputs['stem'] = x
216 |
217 | x = self.layer1(x) # 1/4
218 | outputs['res2'] = x
219 |
220 | x = self.layer2(x) # 1/8
221 | outputs['res3'] = x
222 |
223 | x = self.layer3(x) # 1/16
224 | outputs['res4'] = x
225 |
226 | x = self.layer4(x) # 1/32
227 | outputs['res5'] = x
228 |
229 | return outputs
230 |
231 | def forward(self, x):
232 | return self._forward_impl(x)
233 |
234 |
235 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
236 | model = ResNet(block, layers, **kwargs)
237 | if pretrained:
238 | state_dict = load_state_dict_from_url(model_urls[arch],
239 | progress=progress)
240 | model.load_state_dict(state_dict, strict=False)
241 | return model
242 |
243 |
244 | def resnet18(pretrained=False, progress=True, **kwargs):
245 | r"""ResNet-18 model from
246 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
247 | Args:
248 | pretrained (bool): If True, returns a model pre-trained on ImageNet
249 | progress (bool): If True, displays a progress bar of the download to stderr
250 | """
251 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
252 | **kwargs)
253 |
254 |
255 | def resnet34(pretrained=False, progress=True, **kwargs):
256 | r"""ResNet-34 model from
257 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
258 | Args:
259 | pretrained (bool): If True, returns a model pre-trained on ImageNet
260 | progress (bool): If True, displays a progress bar of the download to stderr
261 | """
262 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
263 | **kwargs)
264 |
265 |
266 | def resnet50(pretrained=False, progress=True, **kwargs):
267 | r"""ResNet-50 model from
268 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
269 | Args:
270 | pretrained (bool): If True, returns a model pre-trained on ImageNet
271 | progress (bool): If True, displays a progress bar of the download to stderr
272 | """
273 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
274 | **kwargs)
275 |
276 |
277 | def resnet101(pretrained=False, progress=True, **kwargs):
278 | r"""ResNet-101 model from
279 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
280 | Args:
281 | pretrained (bool): If True, returns a model pre-trained on ImageNet
282 | progress (bool): If True, displays a progress bar of the download to stderr
283 | """
284 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
285 | **kwargs)
286 |
287 |
288 | def resnet152(pretrained=False, progress=True, **kwargs):
289 | r"""ResNet-152 model from
290 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
291 | Args:
292 | pretrained (bool): If True, returns a model pre-trained on ImageNet
293 | progress (bool): If True, displays a progress bar of the download to stderr
294 | """
295 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
296 | **kwargs)
297 |
298 |
299 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
300 | r"""ResNeXt-50 32x4d model from
301 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
302 | Args:
303 | pretrained (bool): If True, returns a model pre-trained on ImageNet
304 | progress (bool): If True, displays a progress bar of the download to stderr
305 | """
306 | kwargs['groups'] = 32
307 | kwargs['width_per_group'] = 4
308 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
309 | pretrained, progress, **kwargs)
310 |
311 |
312 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
313 | r"""ResNeXt-101 32x8d model from
314 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
315 | Args:
316 | pretrained (bool): If True, returns a model pre-trained on ImageNet
317 | progress (bool): If True, displays a progress bar of the download to stderr
318 | """
319 | kwargs['groups'] = 32
320 | kwargs['width_per_group'] = 8
321 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
322 | pretrained, progress, **kwargs)
323 |
324 |
325 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
326 | r"""Wide ResNet-50-2 model from
327 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
328 | The model is the same as ResNet except for the bottleneck number of channels
329 | which is twice larger in every block. The number of channels in outer 1x1
330 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
331 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
332 | Args:
333 | pretrained (bool): If True, returns a model pre-trained on ImageNet
334 | progress (bool): If True, displays a progress bar of the download to stderr
335 | """
336 | kwargs['width_per_group'] = 64 * 2
337 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
338 | pretrained, progress, **kwargs)
339 |
340 |
341 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
342 | r"""Wide ResNet-101-2 model from
343 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
344 | The model is the same as ResNet except for the bottleneck number of channels
345 | which is twice larger in every block. The number of channels in outer 1x1
346 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
347 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
348 | Args:
349 | pretrained (bool): If True, returns a model pre-trained on ImageNet
350 | progress (bool): If True, displays a progress bar of the download to stderr
351 | """
352 | kwargs['width_per_group'] = 64 * 2
353 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
354 | pretrained, progress, **kwargs)
355 |
--------------------------------------------------------------------------------
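Unlike the torchvision classifier, the `ResNet` above has its pooling/FC head removed and its forward pass returns a dictionary of multi-scale features for the decoder. A minimal sketch of probing the feature strides (the input size is arbitrary):

```python
import torch
from models.resnet import resnet50  # assumed import path

backbone = resnet50(pretrained=False)
features = backbone(torch.randn(1, 3, 512, 512))

for name, feat in features.items():
    print(name, tuple(feat.shape))
# stem/res2 are at 1/4 resolution, res3 at 1/8, res4 at 1/16, res5 at 1/32
```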
/models/panoptic_deeplab.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Panoptic-DeepLab decoder.
3 | # Written by Bowen Cheng (bcheng9@illinois.edu)
4 | # ------------------------------------------------------------------------------
5 |
6 | from collections import OrderedDict
7 | from functools import partial
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 |
13 | from utils.utils import refine_label_generation, refine_label_generation_with_point
14 |
15 | ################################################################################################
16 |
17 | def basic_conv(in_planes, out_planes, kernel_size, stride=1, padding=1, groups=1,
18 | with_bn=True, with_relu=True):
19 | """convolution with bn and relu"""
20 | module = []
21 | has_bias = not with_bn
22 | module.append(
23 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups,
24 | bias=has_bias)
25 | )
26 | if with_bn:
27 | module.append(nn.BatchNorm2d(out_planes))
28 | if with_relu:
29 | module.append(nn.ReLU())
30 | return nn.Sequential(*module)
31 |
32 |
33 | def depthwise_separable_conv(in_planes, out_planes, kernel_size, stride=1, padding=1, groups=1,
34 | with_bn=True, with_relu=True):
35 | """depthwise separable convolution with bn and relu"""
36 | del groups
37 |
38 | module = []
39 | module.extend([
40 | basic_conv(in_planes, in_planes, kernel_size, stride, padding, groups=in_planes,
41 | with_bn=True, with_relu=True),
42 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
43 | ])
44 | if with_bn:
45 | module.append(nn.BatchNorm2d(out_planes))
46 | if with_relu:
47 | module.append(nn.ReLU())
48 | return nn.Sequential(*module)
49 |
50 |
51 | def stacked_conv(in_planes, out_planes, kernel_size, num_stack, stride=1, padding=1, groups=1,
52 | with_bn=True, with_relu=True, conv_type='basic_conv'):
53 | """stacked convolution with bn and relu"""
54 | if num_stack < 1:
55 |         raise ValueError('`num_stack` has to be a positive integer.')
56 | if conv_type == 'basic_conv':
57 | conv = partial(basic_conv, out_planes=out_planes, kernel_size=kernel_size, stride=stride,
58 | padding=padding, groups=groups, with_bn=with_bn, with_relu=with_relu)
59 | elif conv_type == 'depthwise_separable_conv':
60 | conv = partial(depthwise_separable_conv, out_planes=out_planes, kernel_size=kernel_size, stride=stride,
61 | padding=padding, groups=1, with_bn=with_bn, with_relu=with_relu)
62 | else:
63 | raise ValueError('Unknown conv_type: {}'.format(conv_type))
64 | module = []
65 | module.append(conv(in_planes=in_planes))
66 | for n in range(1, num_stack):
67 | module.append(conv(in_planes=out_planes))
68 | return nn.Sequential(*module)
69 |
70 | ################################################################################################
71 |
72 | class ASPPConv(nn.Sequential):
73 | def __init__(self, in_channels, out_channels, dilation):
74 | modules = [
75 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
76 | nn.BatchNorm2d(out_channels),
77 | nn.ReLU()
78 | ]
79 | super(ASPPConv, self).__init__(*modules)
80 |
81 |
82 | class ASPPPooling(nn.Module):
83 | def __init__(self, in_channels, out_channels):
84 | super(ASPPPooling, self).__init__()
85 | self.aspp_pooling = nn.Sequential(
86 | nn.AdaptiveAvgPool2d(1),
87 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
88 | nn.ReLU()
89 | )
90 |
91 | def set_image_pooling(self, pool_size=None):
92 | if pool_size is None:
93 | self.aspp_pooling[0] = nn.AdaptiveAvgPool2d(1)
94 | else:
95 | self.aspp_pooling[0] = nn.AvgPool2d(kernel_size=pool_size, stride=1)
96 |
97 | def forward(self, x):
98 | size = x.shape[-2:]
99 | x = self.aspp_pooling(x)
100 | return F.interpolate(x, size=size, mode='bilinear', align_corners=True)
101 |
102 |
103 | class ASPP(nn.Module):
104 | def __init__(self, in_channels, out_channels, atrous_rates):
105 | super(ASPP, self).__init__()
106 | # out_channels = 256
107 | modules = []
108 | modules.append(nn.Sequential(
109 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
110 | nn.BatchNorm2d(out_channels),
111 | nn.ReLU()))
112 |
113 | rate1, rate2, rate3 = tuple(atrous_rates)
114 | modules.append(ASPPConv(in_channels, out_channels, rate1))
115 | modules.append(ASPPConv(in_channels, out_channels, rate2))
116 | modules.append(ASPPConv(in_channels, out_channels, rate3))
117 | modules.append(ASPPPooling(in_channels, out_channels))
118 |
119 | self.convs = nn.ModuleList(modules)
120 |
121 | self.project = nn.Sequential(
122 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
123 | nn.BatchNorm2d(out_channels),
124 | nn.ReLU(),
125 | nn.Dropout(0.5))
126 |
127 | def set_image_pooling(self, pool_size):
128 | self.convs[-1].set_image_pooling(pool_size)
129 |
130 | def forward(self, x):
131 | res = []
132 | for conv in self.convs:
133 | res.append(conv(x))
134 | res = torch.cat(res, dim=1)
135 | return self.project(res)
136 |
137 |
138 | ################################################################################################
139 |
140 | class SinglePanopticDeepLabDecoder(nn.Module):
141 | def __init__(self, in_channels, feature_key, low_level_channels, low_level_key, low_level_channels_project,
142 | decoder_channels, atrous_rates, aspp_channels=None):
143 | super(SinglePanopticDeepLabDecoder, self).__init__()
144 | if aspp_channels is None:
145 | aspp_channels = decoder_channels
146 | self.aspp = ASPP(in_channels, out_channels=aspp_channels, atrous_rates=atrous_rates)
147 | self.feature_key = feature_key
148 | self.decoder_stage = len(low_level_channels)
149 | assert self.decoder_stage == len(low_level_key)
150 | assert self.decoder_stage == len(low_level_channels_project)
151 | self.low_level_key = low_level_key
152 | fuse_conv = partial(stacked_conv, kernel_size=5, num_stack=1, padding=2,
153 | conv_type='depthwise_separable_conv')
154 |
155 | # Transform low-level feature
156 | project = []
157 | # Fuse
158 | fuse = []
159 | # Top-down direction, i.e. starting from largest stride
160 | for i in range(self.decoder_stage):
161 | project.append(
162 | nn.Sequential(
163 | nn.Conv2d(low_level_channels[i], low_level_channels_project[i], 1, bias=False),
164 | nn.BatchNorm2d(low_level_channels_project[i]),
165 | nn.ReLU()
166 | )
167 | )
168 | if i == 0:
169 | fuse_in_channels = aspp_channels + low_level_channels_project[i]
170 | else:
171 | fuse_in_channels = decoder_channels + low_level_channels_project[i]
172 | fuse.append(
173 | fuse_conv(
174 | fuse_in_channels,
175 | decoder_channels,
176 | )
177 | )
178 | self.project = nn.ModuleList(project)
179 | self.fuse = nn.ModuleList(fuse)
180 |
181 | def set_image_pooling(self, pool_size):
182 | self.aspp.set_image_pooling(pool_size)
183 |
184 | def forward(self, features):
185 | x = features[self.feature_key]
186 | x = self.aspp(x)
187 |
188 | # build decoder
189 | for i in range(self.decoder_stage):
190 | l = features[self.low_level_key[i]]
191 | l = self.project[i](l)
192 | x = F.interpolate(x, size=l.size()[2:], mode='bilinear', align_corners=True)
193 | x = torch.cat((x, l), dim=1)
194 | x = self.fuse[i](x)
195 |
196 | return x
197 |
198 |
199 | class SinglePanopticDeepLabHead(nn.Module):
200 | def __init__(self, decoder_channels, head_channels, num_classes, class_key):
201 | super(SinglePanopticDeepLabHead, self).__init__()
202 | fuse_conv = partial(stacked_conv, kernel_size=5, num_stack=1, padding=2,
203 | conv_type='depthwise_separable_conv')
204 |
205 | self.num_head = len(num_classes)
206 | assert self.num_head == len(class_key)
207 |
208 | classifier = {}
209 | for i in range(self.num_head):
210 | classifier[class_key[i]] = nn.Sequential(
211 | fuse_conv(
212 | decoder_channels,
213 | head_channels[i],
214 | ),
215 | nn.Conv2d(head_channels[i], num_classes[i], 1)
216 | )
217 | self.classifier = nn.ModuleDict(classifier)
218 | self.class_key = class_key
219 |
220 | def forward(self, x):
221 | pred = OrderedDict()
222 | # build classifier
223 | for key in self.class_key:
224 | pred[key] = self.classifier[key](x)
225 |
226 | return pred
227 |
228 |
229 | class PanopticDeepLabDecoder(nn.Module):
230 | def __init__(self, in_channels, feature_key, low_level_channels, low_level_key, low_level_channels_project,
231 | decoder_channels, atrous_rates, num_classes, instance_head_kwargs, **kwargs):
232 | super(PanopticDeepLabDecoder, self).__init__()
233 |
234 | self.semantic_decoder = SinglePanopticDeepLabDecoder(in_channels, feature_key, low_level_channels,
235 | low_level_key, low_level_channels_project,
236 | decoder_channels, atrous_rates)
237 | self.semantic_head = SinglePanopticDeepLabHead(decoder_channels, [decoder_channels], [num_classes], ['seg'])
238 |
239 |
240 | # Build instance decoder
241 | instance_decoder_kwargs = dict(
242 | in_channels=in_channels,
243 | feature_key=feature_key,
244 | low_level_channels=low_level_channels,
245 | low_level_key=low_level_key,
246 | low_level_channels_project=(64, 32, 16),
247 | decoder_channels=128,
248 | atrous_rates=atrous_rates,
249 | aspp_channels=256
250 | )
251 | self.instance_decoder = SinglePanopticDeepLabDecoder(**instance_decoder_kwargs)
252 |
253 | self.instance_head = SinglePanopticDeepLabHead(**instance_head_kwargs)
254 |
255 | def set_image_pooling(self, pool_size):
256 | self.semantic_decoder.set_image_pooling(pool_size)
257 | self.instance_decoder.set_image_pooling(pool_size)
258 |
259 | def forward(self, features):
260 | pred = OrderedDict()
261 |
262 | # Semantic branch
263 | semantic = self.semantic_decoder(features)
264 | semantic = self.semantic_head(semantic)
265 | for key in semantic.keys():
266 | pred[key] = semantic[key]
267 |
268 | # Instance branch
269 | instance = self.instance_decoder(features)
270 | instance = self.instance_head(instance)
271 | for key in instance.keys():
272 | pred[key] = instance[key]
273 |
274 | return pred
275 |
276 | ################################################################################################
277 |
278 | class BaseSegmentationModel(nn.Module):
279 | """
280 | Base class for segmentation models.
281 | Arguments:
282 | backbone: A nn.Module of backbone model.
283 | decoder: A nn.Module of decoder.
284 | """
285 | def __init__(self, backbone, decoder, args):
286 | super(BaseSegmentationModel, self).__init__()
287 | self.backbone = backbone
288 | self.decoder = decoder
289 | self.args = args
290 |
291 | def _init_params(self, ):
292 | # Backbone is already initialized (either from pre-trained checkpoint or random init).
293 | for m in self.decoder.modules():
294 | if isinstance(m, nn.Conv2d):
295 | nn.init.normal_(m.weight, mean=0.0, std=0.001)
296 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
297 | nn.init.constant_(m.weight, 1)
298 | nn.init.constant_(m.bias, 0)
299 |
300 | def set_image_pooling(self, pool_size):
301 | self.decoder.set_image_pooling(pool_size)
302 |
303 | def _upsample_predictions(self, pred, input_shape):
304 | """Upsamples final prediction.
305 | Args:
306 | pred (dict): stores all output of the segmentation model.
307 | input_shape (tuple): spatial resolution of the desired shape.
308 | Returns:
309 | result (OrderedDict): upsampled dictionary.
310 | """
311 | result = OrderedDict()
312 | for key in pred.keys():
313 | out = F.interpolate(pred[key], size=input_shape, mode='bilinear', align_corners=True)
314 | result[key] = out
315 | return result
316 |
317 | def forward(self, x, seg_map=None, label=None, point_list=None, target_shape=None):
318 | if target_shape is None:
319 | target_shape = x.shape[-2:]
320 |
321 | # contract: features is a dict of tensors
322 | features = self.backbone(x)
323 | pred = self.decoder(features)
324 | results = self._upsample_predictions(pred, target_shape)
325 |
326 | if label is not None: # refined label generation
327 |
328 | if self.args.sup == 'point': # point supervision setting
329 | pseudo_label = refine_label_generation_with_point(
330 | results['seg'].clone().detach(),
331 | point_list.cpu().numpy(),
332 | results['offset'].clone().detach(),
333 | label.clone().detach(),
334 | seg_map.clone().detach(),
335 | self.args,
336 | )
337 |
338 | else: # image-level supervision setting
339 | pseudo_label = refine_label_generation(
340 | results['seg'].clone().detach(),
341 | results['center'].clone().detach(),
342 | results['offset'].clone().detach(),
343 | label.clone().detach(),
344 | seg_map.clone().detach(),
345 | self.args,
346 | )
347 |
348 | return results, pseudo_label
349 |
350 | return results
351 |
352 |
353 | def PanopticDeepLab(backbone, args):
354 |
355 | instance_head_kwargs = dict(
356 | decoder_channels=128,
357 | head_channels=(128, 32),
358 | num_classes=(args.num_classes, 2),
359 | class_key=["center", "offset"],
360 | )
361 |
362 | decoder = PanopticDeepLabDecoder(in_channels=2048,
363 | feature_key="res5",
364 | low_level_channels=(1024, 512, 256),
365 | low_level_key=["res4", "res3", "res2"],
366 | low_level_channels_project=(128, 64, 32),
367 | decoder_channels=256,
368 | atrous_rates=(3, 6, 9),
369 | num_classes=args.num_classes+1,
370 | instance_head_kwargs=instance_head_kwargs
371 | )
372 |
373 | model = BaseSegmentationModel(backbone=backbone, decoder=decoder, args=args)
374 |
375 | model._init_params()
376 |
377 | # set batchnorm momentum
378 | for module in model.modules():
379 | if isinstance(module, torch.nn.BatchNorm2d):
380 | module.momentum = args.bn_momentum
381 |
382 | return model
383 |
384 |
--------------------------------------------------------------------------------
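(Illustrative usage sketch, not part of the repo: how the decoder outputs are keyed, assuming a backbone whose forward returns a dict of feature maps with keys "res2".."res5" as PanopticDeepLabDecoder expects.)

# model = PanopticDeepLab(backbone, args)          # args.num_classes = 20 for VOC
# out = model(torch.randn(1, 3, 416, 416))         # inference path (no labels given)
# out['seg'].shape     -> (1, args.num_classes + 1, 416, 416)   semantic logits
# out['center'].shape  -> (1, args.num_classes, 416, 416)       class-wise center heatmaps
# out['offset'].shape  -> (1, 2, 416, 416)                      per-pixel center offsets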
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | BESTIE
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT License
5 | """
6 |
7 | import numpy as np
8 | import torch
9 | import argparse
10 | import os
11 | import cv2
12 | import time
13 | import random
14 | import pickle
15 | from tqdm import tqdm
16 |
17 | import torch.nn as nn
18 | import torch.optim as optim
19 | from torch.utils.data import DataLoader
20 | from torch.utils.data.distributed import DistributedSampler
21 | from torch.nn.parallel import DistributedDataParallel as DDP
22 |
23 | from models import model_factory
24 | from utils.LoadData import get_dataset
25 | from utils.my_optim import WarmupPolyLR
26 | from utils.loss import Weighted_L1_Loss, Weighted_MSELoss, DeepLabCE
27 | from utils.utils import AverageMeter, get_ins_map, get_ins_map_with_point
28 |
29 | import chainercv
30 | from chainercv.datasets import VOCInstanceSegmentationDataset
31 | from chainercv.evaluations import eval_instance_segmentation_voc
32 |
33 | def str2bool(v):
34 | return v.lower() in ("yes", "y", "true", "t", "1")
35 |
36 | def parse():
37 |
38 | parser = argparse.ArgumentParser(description='BESTIE pytorch implementation')
39 | parser.add_argument("--root_dir", type=str, default='', help='Root dir for the project')
40 | parser.add_argument('--sup', type=str, help='supervision source', choices=["cls", "point"])
41 | parser.add_argument("--dataset", type=str, default='voc', choices=["voc", "coco"])
42 | parser.add_argument("--backbone", type=str, default='resnet50', choices=["resnet50", "resnet101", "hrnet34", "hrnet48"])
43 | parser.add_argument("--batch_size", type=int, default=16)
44 | parser.add_argument("--crop_size", type=int, default=416)
45 | parser.add_argument("--num_classes", type=int, default=20)
46 | parser.add_argument("--lr", type=float, default=5e-5)
47 | parser.add_argument("--weight_decay", type=float, default=0)
48 | parser.add_argument("--train_iter", type=int, default=50000)
49 | parser.add_argument("--warm_iter", type=int, default=2000, help='warm-up iterations')
50 | parser.add_argument("--train_epoch", type=int, default=0)
51 | parser.add_argument("--num_workers", type=int, default=4)
52 | parser.add_argument('--resume', default=None, type=str, help='weight restore')
53 |
54 | parser.add_argument('--save_folder', default='checkpoints/test1', help='Location to save checkpoint models')
55 |     parser.add_argument('--print_freq', default=200, type=int, help='interval of printing training status')
56 |     parser.add_argument('--save_freq', default=10000, type=int, help='interval of saving checkpoint models')
57 |     parser.add_argument("--cur_iter", type=int, default=0, help='current training iteration')
58 |
59 | parser.add_argument("--gamma", type=float, default=0.9, help='learning rate decay power')
60 | parser.add_argument("--pseudo_thresh", type=float, default=0.7, help='threshold for pseudo-label generation')
61 | parser.add_argument("--refine_thresh", type=float, default=0.3, help='threshold for refined-label generation')
62 | parser.add_argument("--kernel", type=int, default=41, help='kernel size for point extraction')
63 | parser.add_argument("--sigma", type=int, default=6, help='sigma of 2D gaussian kernel')
64 | parser.add_argument("--beta", type=float, default=3.0, help='parameter for center-clustering')
65 | parser.add_argument("--bn_momentum", type=float, default=0.01)
66 | parser.add_argument('--refine', type=str2bool, default=True, help='enable self-refinement.')
67 | parser.add_argument("--refine_iter", type=int, default=0, help='self-refinement running iteration')
68 |     parser.add_argument("--seg_weight", type=float, default=1.0, help='loss weight for semantic segmentation map')
69 | parser.add_argument("--center_weight", type=float, default=200.0, help='loss weight for center map')
70 | parser.add_argument("--offset_weight", type=float, default=0.01, help='loss weight for offset map')
71 |
72 | parser.add_argument('--val_freq', default=1000, type=int, help='interval of model validation')
73 |     parser.add_argument("--val_thresh", type=float, default=0.1, help='threshold for instance grouping in validation phase')
74 |     parser.add_argument("--val_kernel", type=int, default=41, help='kernel size for point extraction in validation phase')
75 |     parser.add_argument('--val_flip', type=str2bool, default=True, help='enable flip test-time augmentation in validation phase')
76 | parser.add_argument('--val_clean', type=str2bool, default=False, help='cleaning pseudo-labels using image-level labels')
77 | parser.add_argument('--val_ignore', type=str2bool, default=False, help='ignore')
78 |
79 | parser.add_argument("--random_seed", type=int, default=1)
80 | parser.add_argument("--gpu", type=int, default=0)
81 | parser.add_argument("--world_size", type=int, default=1)
82 | parser.add_argument("--local_rank", type=int, default=int(os.environ["LOCAL_RANK"]))
83 |
84 | return parser.parse_args()
85 |
86 | def print_func(string):
87 | if torch.distributed.get_rank() == 0:
88 | print(string)
89 |
90 | def save_checkpoint(save_path, model):
91 | if torch.distributed.get_rank() == 0:
92 | print('\nSaving state: %s\n' % save_path)
93 | state = {
94 | 'model': model.module.state_dict(),
95 | }
96 | torch.save(state, save_path)
97 |
98 |
99 | def train():
100 |
101 | batch_time = AverageMeter()
102 | avg_total_loss = AverageMeter()
103 | avg_seg_loss = AverageMeter()
104 | avg_pseudo_center_loss = AverageMeter()
105 | avg_refine_center_loss = AverageMeter()
106 | avg_pseudo_offset_loss = AverageMeter()
107 | avg_refine_offset_loss = AverageMeter()
108 |
109 | best_AP = -1
110 |
111 | model.train()
112 | start = time.time()
113 | end = time.time()
114 | epoch = 0
115 |
116 | for cur_iter in range(1, args.train_iter+1):
117 |
118 | try:
119 | img, label, seg_map, center_map, offset_map, weight, point_list = next(data_iter)
120 | except Exception as e:
121 | print_func(" [LOADER ERROR] " + str(e))
122 |
123 | epoch += 1
124 | data_iter = iter(train_loader)
125 | img, label, seg_map, center_map, offset_map, weight, point_list = next(data_iter)
126 |
127 | end = time.time()
128 | batch_time.reset()
129 | avg_total_loss.reset()
130 | avg_seg_loss.reset()
131 | avg_pseudo_center_loss.reset()
132 | avg_refine_center_loss.reset()
133 | avg_pseudo_offset_loss.reset()
134 | avg_refine_offset_loss.reset()
135 |
136 | img = img.to(device, non_blocking=True)
137 | label = label.to(device, non_blocking=True)
138 | seg_map = seg_map.to(device, non_blocking=True)
139 | center_map = center_map.to(device, non_blocking=True)
140 | offset_map = offset_map.to(device, non_blocking=True)
141 | weight = weight.to(device, non_blocking=True)
142 |
143 | run_refine = args.refine and (cur_iter > args.refine_iter)
144 |
145 | if run_refine:
146 | out, c_label = model(img, seg_map, label, point_list)
147 | else:
148 | out = model(img)
149 |
150 | seg_loss = criterion['seg'](out['seg'], seg_map) * args.seg_weight
151 | center_loss_1 = criterion['center'](out['center'], center_map, weight) * args.center_weight
152 | offset_loss_1 = criterion['offset'](out['offset'], offset_map, weight) * args.offset_weight
153 |
154 | center_loss_2 = center_loss_1
155 | offset_loss_2 = offset_loss_1
156 |
157 | if run_refine and args.sup == 'cls':
158 | center_loss_2 = criterion['center'](out['center'], c_label['center'],
159 | c_label['weight']) * args.center_weight
160 |
161 | if run_refine:
162 | offset_loss_2 = criterion['offset'](out['offset'], c_label['offset'],
163 | c_label['weight']) * args.offset_weight
164 |
165 | loss = seg_loss + (center_loss_1 + center_loss_2)*0.5 + (offset_loss_1 + offset_loss_2)*0.5
166 |
167 | # compute gradient and backward
168 | optimizer.zero_grad()
169 | loss.backward()
170 | optimizer.step()
171 | lr_scheduler.step()
172 |
173 | batch_time.update((time.time() - end))
174 | end = time.time()
175 |
176 | avg_total_loss.update(loss.item(), img.size(0))
177 | avg_seg_loss.update(seg_loss.item(), img.size(0))
178 | avg_pseudo_center_loss.update(center_loss_1.item(), img.size(0))
179 | avg_refine_center_loss.update(center_loss_2.item(), img.size(0))
180 | avg_pseudo_offset_loss.update(offset_loss_1.item(), img.size(0))
181 | avg_refine_offset_loss.update(offset_loss_2.item(), img.size(0))
182 |
183 | if cur_iter % args.print_freq == 0:
184 | batch_time.synch(device)
185 | avg_total_loss.synch(device)
186 | avg_seg_loss.synch(device)
187 | avg_pseudo_center_loss.synch(device)
188 | avg_refine_center_loss.synch(device)
189 | avg_pseudo_offset_loss.synch(device)
190 | avg_refine_offset_loss.synch(device)
191 |
192 | if args.local_rank == 0:
193 | print('Progress: [{0}][{1}/{2}] ({3:.1f}%, {4:.1f} min) | '
194 | 'Time: {5:.1f} ms | '
195 | 'Left: {6:.1f} min | '
196 | 'TotalLoss: {7:.4f} | '
197 | 'SegLoss: {8:.4f} | '
198 | 'centerLoss: {9:.4f} ({10:.4f} + {11:.4f}) | '
199 | 'OffsetLoss: {12:.4f} ({13:.4f} + {14:.4f}) '.format(
200 | epoch, cur_iter, args.train_iter,
201 | cur_iter/args.train_iter*100, (end-start) / 60,
202 | batch_time.avg * 1000, (args.train_iter - cur_iter) * batch_time.avg / 60,
203 | avg_total_loss.avg, avg_seg_loss.avg,
204 | avg_pseudo_center_loss.avg + avg_refine_center_loss.avg,
205 | avg_pseudo_center_loss.avg, avg_refine_center_loss.avg,
206 | avg_pseudo_offset_loss.avg + avg_refine_offset_loss.avg,
207 | avg_pseudo_offset_loss.avg, avg_refine_offset_loss.avg,
208 | )
209 | )
210 |
211 | if args.local_rank == 0 and cur_iter % args.save_freq == 0:
212 | save_path = os.path.join(args.save_folder, 'last.pt')
213 | save_checkpoint(save_path, model)
214 |
215 | if cur_iter % args.val_freq == 0:
216 | val_score = validate()
217 |
218 | if args.local_rank == 0 and val_score['map'] > best_AP:
219 | best_AP = val_score['map']
220 | print('\n Best mAP50, iteration : %d, mAP50 : %.2f \n' % (cur_iter, best_AP))
221 |
222 | save_path = os.path.join(args.save_folder, 'best.pt')
223 | save_checkpoint(save_path, model)
224 |
225 | end = time.time()
226 |
227 | if args.local_rank == 0:
228 | print('\n training done')
229 | save_path = os.path.join(args.save_folder, 'last.pt')
230 | save_checkpoint(save_path, model)
231 |
232 | val_score = validate()
233 |
234 | if args.local_rank == 0 and val_score['map'] > best_AP:
235 | best_AP = val_score['map']
236 | print('\n Best mAP50, iteration : %d, mAP50 : %.2f \n' % (cur_iter, best_AP))
237 |
238 |         save_path = os.path.join(args.save_folder, 'best.pt')
239 | save_checkpoint(save_path, model)
240 |
241 |
242 | def validate():
243 | model.eval()
244 |
245 | pred_seg_maps, pred_labels, pred_masks, pred_scores = [], [], [], []
246 | val_dir = "val_temp_dir"
247 | if args.local_rank == 0:
248 | os.makedirs(val_dir, exist_ok=True)
249 |
250 | torch.distributed.barrier()
251 | for img, cls_label, points, fname, tsize in tqdm(valid_loader):
252 | target_size = int(tsize[0]), int(tsize[1])
253 |
254 | if args.val_flip:
255 | img = torch.cat( [img, img.flip(-1)] , dim=0)
256 |
257 | out = model(img.to(device), target_shape=target_size)
258 |
259 | if args.sup == 'point' and args.val_clean:
260 | pred_seg, pred_label, pred_mask, pred_score = get_ins_map_with_point(out,
261 | cls_label,
262 | points,
263 | target_size,
264 | device,
265 | args)
266 | else:
267 | pred_seg, pred_label, pred_mask, pred_score = get_ins_map(out,
268 | cls_label,
269 | target_size,
270 | device,
271 | args)
272 |
273 | with open(f'{val_dir}/{fname[0]}.pickle', 'wb') as f:
274 | pickle.dump({
275 | 'pred_label': pred_label,
276 | 'pred_mask': pred_mask,
277 | 'pred_score': pred_score,
278 | }, f)
279 |
280 | torch.distributed.barrier()
281 |
282 | ap_result = {"ap": None, "map": None}
283 |
284 | if args.local_rank == 0:
285 | pred_masks, pred_labels, pred_scores = [], [], []
286 |
287 | for fname in ins_gt_ids:
288 | with open(f'{val_dir}/{fname}.pickle', 'rb') as f:
289 | dat = pickle.load(f)
290 | pred_masks.append(dat['pred_mask'])
291 | pred_labels.append(dat['pred_label'])
292 | pred_scores.append(dat['pred_score'])
293 |
294 | ap_result = eval_instance_segmentation_voc(pred_masks,
295 | pred_labels,
296 | pred_scores,
297 | ins_gt_masks,
298 | ins_gt_labels,
299 | iou_thresh=0.5)
300 |
301 | #print(ap_result)
302 | os.system(f"rm -rf {val_dir}")
303 |
304 | torch.distributed.barrier()
305 |
306 | model.train()
307 |
308 | return ap_result
309 |
310 |
311 | if __name__ == '__main__':
312 |
313 | args = parse()
314 |
315 | torch.manual_seed(args.random_seed)
316 | np.random.seed(args.random_seed)
317 | random.seed(args.random_seed)
318 |
319 | torch.backends.cudnn.benchmark = True
320 |
321 | args.gpu = args.local_rank
322 | torch.cuda.set_device(args.gpu)
323 |
324 |     # Init distributed system
325 | torch.distributed.init_process_group(
326 | backend="nccl", rank=args.local_rank, world_size=torch.cuda.device_count()
327 | )
328 | args.world_size = torch.distributed.get_world_size()
329 | device = torch.device(f"cuda:{args.gpu}")
330 |
331 | if args.local_rank == 0:
332 | os.makedirs(args.save_folder, exist_ok=True)
333 |
334 | """ load model """
335 | model = model_factory(args)
336 | model = model.to(device)
337 |
338 | optimizer = optim.Adam(model.parameters(),
339 | lr=args.lr,
340 | weight_decay=args.weight_decay)
341 |
342 | # define loss function (criterion) and optimizer
343 | criterion = {"center" : Weighted_MSELoss(),
344 | "offset" : Weighted_L1_Loss(),
345 | "seg" : DeepLabCE()
346 | }
347 |
348 | # Optionally resume from a checkpoint
349 | if args.resume:
350 | if os.path.isfile(args.resume):
351 | print_func("=> loading checkpoint '{}'".format(args.resume))
352 | ckpt = torch.load(args.resume, map_location='cpu')['model']
353 |             model.load_state_dict(ckpt, strict=True)
354 | else:
355 | print_func("=> no checkpoint found at '{}'".format(args.resume))
356 |
357 |
358 | """ Get data loader """
359 | train_dataset = get_dataset(args, mode='train')
360 | valid_dataset = get_dataset(args, mode='val')
361 | print_func("number of train set = %d | valid set = %d" % (len(train_dataset), len(valid_dataset)))
362 |
363 | n_gpus = torch.cuda.device_count()
364 | batch_per_gpu = args.batch_size // n_gpus
365 |
366 | train_sampler = DistributedSampler(train_dataset, num_replicas=n_gpus, rank=args.local_rank)
367 | train_loader = DataLoader(train_dataset,
368 | batch_size=batch_per_gpu,
369 | num_workers=args.num_workers,
370 | pin_memory=True,
371 | sampler=train_sampler,
372 | drop_last=True)
373 |
374 | valid_sampler = DistributedSampler(valid_dataset, num_replicas=n_gpus, rank=args.local_rank)
375 | valid_loader = DataLoader(valid_dataset,
376 | batch_size=1,
377 | num_workers=4,
378 | pin_memory=True,
379 | sampler=valid_sampler,
380 | drop_last=False)
381 |
382 | if args.train_epoch != 0:
383 | args.train_iter = args.train_epoch * len(train_loader)
384 |
385 | lr_scheduler = WarmupPolyLR(
386 | optimizer,
387 | args.train_iter,
388 | warmup_iters=args.warm_iter,
389 | power=args.gamma,
390 | )
391 |
392 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
393 | model = DDP(model, device_ids=[args.gpu])
394 |
395 | if args.val_freq != 0 and args.local_rank == 0:
396 | print("...Preparing GT dataset for evaluation")
397 | ins_dataset = VOCInstanceSegmentationDataset(split='val', data_dir=args.root_dir)
398 |
399 | ins_gt_ids = ins_dataset.ids
400 | ins_gt_masks = [ins_dataset.get_example_by_keys(i, (1,))[0] for i in range(len(ins_dataset))]
401 | ins_gt_labels = [ins_dataset.get_example_by_keys(i, (2,))[0] for i in range(len(ins_dataset))]
402 |
403 | torch.distributed.barrier()
404 | print_func("...Training Start \n")
405 | print_func(args)
406 |
407 | #validate()
408 | train()
409 |
--------------------------------------------------------------------------------
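(Illustrative launch, not taken from the repo's run scripts under scripts/: main.py expects LOCAL_RANK in the environment and one process per GPU, which torchrun provides. Paths and flag values below are placeholders.)

# torchrun --nproc_per_node=4 main.py \
#     --root_dir /path/to/VOC \
#     --sup point --dataset voc --backbone resnet50 \
#     --batch_size 16 --crop_size 416 --train_iter 50000 --lr 5e-5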
/utils/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import warnings
3 | import cv2
4 | import math
5 | import numpy as np
6 | from torchvision.transforms import functional as F
7 | import torch
8 |
9 | class Compose(object):
10 | """
11 | Composes a sequence of transforms.
12 | Arguments:
13 | transforms: A list of transforms.
14 | """
15 | def __init__(self, transforms):
16 | self.transforms = transforms
17 |
18 | def __call__(self, image, seg_map, peak):
19 | for t in self.transforms:
20 | image, seg_map, peak = t(image, seg_map, peak)
21 | return image, seg_map, peak
22 |
23 | def __repr__(self):
24 | format_string = self.__class__.__name__ + "("
25 | for t in self.transforms:
26 | format_string += "\n"
27 | format_string += " {0}".format(t)
28 | format_string += "\n)"
29 | return format_string
30 |
31 |
32 | class ToTensor(object):
33 | """
34 | Converts image to torch Tensor.
35 | """
36 | def __call__(self, image, seg_map, peak):
37 |
38 | image = F.to_tensor(image)
39 |
40 | seg_map = np.array(seg_map, dtype=np.uint8, copy=True)
41 |
42 | seg_map = torch.from_numpy(seg_map)
43 |
44 | return image, seg_map, peak
45 |
46 |
47 | class Normalize(object):
48 |     """Normalize a tensor image with mean and standard deviation.
49 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
50 | will normalize each channel of the input ``torch.*Tensor`` i.e.
51 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
52 |
53 | Args:
54 | mean (sequence): Sequence of means for each channel.
55 | std (sequence): Sequence of standard deviations for each channel.
56 | """
57 |
58 | def __init__(self, mean, std):
59 | self.mean = mean
60 | self.std = std
61 |
62 | def __call__(self, image, seg_map, peak):
63 | """
64 | Args:
65 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
66 |
67 | Returns:
68 | Tensor: Normalized Tensor image.
69 | """
70 | image = F.normalize(image, self.mean, self.std)
71 |
72 | peak = [ [int(x), int(y), cls, conf] for x, y, cls, conf in peak]
73 |
74 | return image, seg_map, peak
75 |
76 |
77 | class RandomScale(object):
78 | """
79 | Applies random scale augmentation.
80 | Arguments:
81 | min_scale: Minimum scale value.
82 | max_scale: Maximum scale value.
83 | scale_step_size: The step size from minimum to maximum value.
84 | """
85 | def __init__(self, min_scale, max_scale, scale_step_size):
86 | self.min_scale = min_scale
87 | self.max_scale = max_scale
88 | self.scale_step_size = scale_step_size
89 |
90 | @staticmethod
91 | def get_random_scale(min_scale_factor, max_scale_factor, step_size):
92 | """Gets a random scale value.
93 | Args:
94 | min_scale_factor: Minimum scale value.
95 | max_scale_factor: Maximum scale value.
96 | step_size: The step size from minimum to maximum value.
97 | Returns:
98 | A random scale value selected between minimum and maximum value.
99 | Raises:
100 | ValueError: min_scale_factor has unexpected value.
101 | """
102 | if min_scale_factor < 0 or min_scale_factor > max_scale_factor:
103 | raise ValueError('Unexpected value of min_scale_factor.')
104 |
105 | if min_scale_factor == max_scale_factor:
106 | return min_scale_factor
107 |
108 | # When step_size = 0, we sample the value uniformly from [min, max).
109 | if step_size == 0:
110 | return random.uniform(min_scale_factor, max_scale_factor)
111 |
112 | # When step_size != 0, we randomly select one discrete value from [min, max].
113 | num_steps = int((max_scale_factor - min_scale_factor) / step_size + 1)
114 | scale_factors = np.linspace(min_scale_factor, max_scale_factor, num_steps)
115 | np.random.shuffle(scale_factors)
116 | return scale_factors[0]
117 |
118 | #def __call__(self, image, label):
119 | def __call__(self, image, seg_map, peak):
120 | f_scale = self.get_random_scale(self.min_scale, self.max_scale, self.scale_step_size)
121 | # TODO: cv2 uses align_corner=False
122 | # TODO: use fvcore (https://github.com/facebookresearch/fvcore/blob/master/fvcore/transforms/transform.py#L377)
123 | image_dtype = image.dtype
124 | seg_map_dtype = seg_map.dtype
125 |
126 |         image = cv2.resize(image.astype(float), None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_LINEAR)
127 | 
128 |         seg_map = cv2.resize(seg_map.astype(float), None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_NEAREST)
129 |
130 | peak = [ [p[0]*f_scale, p[1]*f_scale, p[2], p[3]] for p in peak ]
131 |
132 | return image.astype(image_dtype), seg_map.astype(seg_map_dtype), peak
133 |
134 |
135 | class RandomCrop(object):
136 | """
137 | Applies random crop augmentation.
138 | Arguments:
139 | crop_h: Integer, crop height size.
140 | crop_w: Integer, crop width size.
141 | pad_value: Tuple, pad value for image, length 3.
142 | ignore_label: Tuple, pad value for label, length could be 1 (semantic) or 3 (panoptic).
143 | random_pad: Bool, when crop size larger than image size, whether to randomly pad four boundaries,
144 | or put image to top-left and only pad bottom and right boundaries.
145 | """
146 | def __init__(self, crop_h, crop_w, pad_value, ignore_label, random_pad):
147 | self.crop_h = crop_h
148 | self.crop_w = crop_w
149 | self.pad_value = pad_value
150 | self.ignore_label = ignore_label
151 | self.random_pad = random_pad
152 |
153 | def __call__(self, image, seg_map, peak):
154 | img_h, img_w = image.shape[0], image.shape[1]
155 | # save dtype
156 | image_dtype = image.dtype
157 | seg_map_dtype = seg_map.dtype
158 |
159 | # padding
160 | pad_h = max(self.crop_h - img_h, 0)
161 | pad_w = max(self.crop_w - img_w, 0)
162 | if pad_h > 0 or pad_w > 0:
163 | if self.random_pad:
164 | pad_top = random.randint(0, pad_h)
165 | pad_bottom = pad_h - pad_top
166 | pad_left = random.randint(0, pad_w)
167 | pad_right = pad_w - pad_left
168 | else:
169 | pad_top, pad_bottom, pad_left, pad_right = 0, pad_h, 0, pad_w
170 |
171 | img_pad = cv2.copyMakeBorder(image, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT,
172 | value=self.pad_value)
173 | seg_map_pad = cv2.copyMakeBorder(seg_map, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT,
174 | value=self.ignore_label)
175 |
176 | peak = [ [px+pad_left, py+pad_top, cls, conf] for px, py, cls, conf in peak ]
177 |
178 | else:
179 | img_pad, seg_map_pad = image, seg_map
180 |
181 | img_h, img_w = img_pad.shape[0], img_pad.shape[1]
182 | h_off = random.randint(0, img_h - self.crop_h)
183 | w_off = random.randint(0, img_w - self.crop_w)
184 |
185 | image = np.asarray(img_pad[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w], np.float32)
186 | seg_map = np.asarray(seg_map_pad[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w], np.float32)
187 |
188 | peak_crop = []
189 | for px, py, cls, conf in peak:
190 | if (h_off <= py < h_off + self.crop_h) and (w_off <= px < w_off + self.crop_w):
191 | px_crop, py_crop = px-w_off, py-h_off
192 |
193 | #if (0 <= px_crop < self.crop_w) and (0 <= py_crop < self.crop_h):
194 | peak_crop.append( [px_crop, py_crop, cls, conf] )
195 |
196 | return image.astype(image_dtype), seg_map.astype(seg_map_dtype), peak_crop
197 |
198 |
199 | class RandomHorizontalFlip(object):
200 | """Horizontally flip the given PIL Image randomly with a probability of 0.5."""
201 |
202 | def __call__(self, image, seg_map, peak):
203 |
204 | if random.random() < 0.5:
205 | image = torch.flip(image, [2])
206 | seg_map = torch.flip(seg_map, [1])
207 |
208 | _, H, W = image.shape
209 |
210 | peak = [ [W-p[0]-1, p[1], p[2], p[3]] for p in peak ]
211 |
212 | return image, seg_map, peak
213 |
214 |
215 | class Resize(object):
216 | """
217 |     Resizes the input so that it satisfies the given min/max size constraints.
218 | Reference: https://github.com/tensorflow/models/blob/master/research/deeplab/input_preprocess.py#L28
219 | Arguments:
220 | min_resize_value: Desired size of the smaller image side, no resize if set to None
221 | max_resize_value: Maximum allowed size of the larger image side, no limit if set to None
222 | resize_factor: Resized dimensions are multiple of factor plus one.
223 | keep_aspect_ratio: Boolean, keep aspect ratio or not. If True, the input
224 | will be resized while keeping the original aspect ratio. If False, the
225 | input will be resized to [max_resize_value, max_resize_value] without
226 | keeping the original aspect ratio.
227 | align_corners: If True, exactly align all 4 corners of input and output.
228 | """
229 | def __init__(self, min_resize_value=None, max_resize_value=None, resize_factor=None,
230 | keep_aspect_ratio=True, align_corners=False):
231 | if min_resize_value is not None and min_resize_value < 0:
232 | min_resize_value = None
233 | if max_resize_value is not None and max_resize_value < 0:
234 | max_resize_value = None
235 | if resize_factor is not None and resize_factor < 0:
236 | resize_factor = None
237 | self.min_resize_value = min_resize_value
238 | self.max_resize_value = max_resize_value
239 | self.resize_factor = resize_factor
240 | self.keep_aspect_ratio = keep_aspect_ratio
241 | self.align_corners = align_corners
242 |
243 | if self.align_corners:
244 | warnings.warn('`align_corners = True` is not supported by opencv.')
245 |
246 | if self.max_resize_value is not None:
247 | # Modify the max_size to be a multiple of factor plus 1 and make sure the max dimension after resizing
248 | # is no larger than max_size.
249 | if self.resize_factor is not None:
250 | self.max_resize_value = (self.max_resize_value - (self.max_resize_value - 1) % self.resize_factor)
251 |
252 | def __call__(self, image, seg_map, peak):
253 | if self.min_resize_value is None:
254 |             return image, seg_map, peak
255 | [orig_height, orig_width, _] = image.shape
256 | orig_min_size = np.minimum(orig_height, orig_width)
257 |
258 | # Calculate the larger of the possible sizes
259 | large_scale_factor = self.min_resize_value / orig_min_size
260 | large_height = int(math.floor(orig_height * large_scale_factor))
261 | large_width = int(math.floor(orig_width * large_scale_factor))
262 | large_size = np.array([large_height, large_width])
263 |
264 | new_size = large_size
265 | if self.max_resize_value is not None:
266 | # Calculate the smaller of the possible sizes, use that if the larger is too big.
267 | orig_max_size = np.maximum(orig_height, orig_width)
268 | small_scale_factor = self.max_resize_value / orig_max_size
269 | small_height = int(math.floor(orig_height * small_scale_factor))
270 | small_width = int(math.floor(orig_width * small_scale_factor))
271 | small_size = np.array([small_height, small_width])
272 |
273 | if np.max(large_size) > self.max_resize_value:
274 | new_size = small_size
275 |
276 | # Ensure that both output sides are multiples of factor plus one.
277 | if self.resize_factor is not None:
278 | new_size += (self.resize_factor - (new_size - 1) % self.resize_factor) % self.resize_factor
279 | # If new_size exceeds largest allowed size
280 | new_size[new_size > self.max_resize_value] -= self.resize_factor
281 |
282 | if not self.keep_aspect_ratio:
283 | # If not keep the aspect ratio, we resize everything to max_size, allowing
284 | # us to do pre-processing without extra padding.
285 | new_size = [np.max(new_size), np.max(new_size)]
286 |
287 | # TODO: cv2 uses align_corner=False
288 | # TODO: use fvcore (https://github.com/facebookresearch/fvcore/blob/master/fvcore/transforms/transform.py#L377)
289 | image_dtype = image.dtype
290 | seg_map_dtype = seg_map.dtype
291 |
292 |         image = cv2.resize(image.astype(float), (new_size[1], new_size[0]), interpolation=cv2.INTER_LINEAR)
293 |         seg_map = cv2.resize(seg_map.astype(float), (new_size[1], new_size[0]), interpolation=cv2.INTER_NEAREST)
294 |
295 | peak = [ [int(p[0]/orig_width*new_size[1]), int(p[1]/orig_height*new_size[0]), p[2], p[3]] for p in peak ]
296 |
297 | return image.astype(image_dtype), seg_map.astype(seg_map_dtype), peak
298 |
299 |
300 |
301 | class RandomContrast(object):
302 | def __init__(self, lower=0.5, upper=1.5):
303 | self.lower = lower
304 | self.upper = upper
305 | assert self.upper >= self.lower, "contrast upper must be >= lower."
306 | assert self.lower >= 0, "contrast lower must be non-negative."
307 |
308 | # expects float image
309 | def __call__(self, image, seg_map, peak):
310 | if random.random() < 0.5:
311 | alpha = random.uniform(self.lower, self.upper)
312 | image *= alpha
313 | image[image>255] = 255
314 | image[image<0] = 0
315 |
316 | return image, seg_map, peak
317 |
318 | class RandomBrightness(object):
319 | def __init__(self, delta=16):
320 | assert delta >= 0.0
321 | assert delta <= 255.0
322 | self.delta = delta
323 |
324 | def __call__(self, image, seg_map, peak):
325 | if random.random() < 0.5:
326 | delta = random.uniform(-self.delta, self.delta)
327 | image += delta
328 | image[image>255] = 255
329 | image[image<0] = 0
330 |
331 | return image, seg_map, peak
332 |
333 | class RandomHue(object):
334 | def __init__(self, delta=36.0):
335 | assert delta >= 0.0 and delta <= 360.0
336 | self.delta = delta
337 |
338 | def __call__(self, image, seg_map, peak):
339 | if random.random() < 0.5:
340 | image[:, :, 0] += random.uniform(-self.delta, self.delta)
341 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0
342 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0
343 | return image, seg_map, peak
344 |
345 |
346 | class RandomLightingNoise(object):
347 | def __init__(self):
348 | self.perms = ((0, 1, 2), (0, 2, 1),
349 | (1, 0, 2), (1, 2, 0),
350 | (2, 0, 1), (2, 1, 0))
351 |
352 | def __call__(self, image, seg_map, peak):
353 | if random.random() < 0.5:
354 |             swap = self.perms[random.randint(0, len(self.perms) - 1)]
355 | shuffle = SwapChannels(swap) # shuffle channels
356 | image = shuffle(image)
357 | return image, seg_map, peak
358 |
359 |
360 | class ConvertColor(object):
361 | def __init__(self, current='BGR', transform='HSV'):
362 | self.transform = transform
363 | self.current = current
364 |
365 | def __call__(self, image, seg_map, peak):
366 | if self.current == 'RGB' and self.transform == 'HSV':
367 | image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
368 | elif self.current == 'HSV' and self.transform == 'RGB':
369 | image = cv2.cvtColor(image, cv2.COLOR_HSV2RGB)
370 | else:
371 | raise NotImplementedError
372 | return image, seg_map, peak
373 |
374 | class PhotometricDistort(object):
375 | def __init__(self):
376 | self.pd = [
377 | RandomContrast(),
378 | ConvertColor(current='RGB', transform='HSV'),
379 | RandomHue(),
380 | ConvertColor(current='HSV', transform='RGB'),
381 | RandomContrast()
382 | ]
383 | self.rand_brightness = RandomBrightness()
384 | self.rand_light_noise = RandomLightingNoise()
385 |
386 | def __call__(self, image, seg_map, peak):
387 | #m = image.copy()
388 | im = image.copy().astype(np.float32)
389 |
390 | im, seg_map, peak = self.rand_brightness(im, seg_map, peak)
391 |
392 | if random.random() < 0.5:
393 | distort = Compose(self.pd[:-1])
394 | else:
395 | distort = Compose(self.pd[1:])
396 |
397 | im, seg_map, peak = distort(im, seg_map, peak)
398 | # im, boxes, labels = self.rand_light_noise(im, boxes, labels)
399 |
400 | im = np.clip(im, 0, 255).astype(np.uint8)
401 |
402 | return im, seg_map, peak
403 |
404 |
405 | class Keepsize(object):
406 |     """Resize the input image so that both spatial sides are multiples of 32,
407 |     or to a fixed shape if one is given.
408 | 
409 |     The segmentation map is resized with nearest-neighbour interpolation, and
410 |     peak coordinates are rescaled to the new resolution.
411 | 
412 |     Args:
413 |         shape (sequence or None): Desired output size as (w, h). If ``None``,
414 |             each output side is the input side rounded up to the nearest
415 |             multiple of 32.
416 |     """
417 |
418 | def __init__(self, shape=None):
419 | self.shape = shape
420 |
421 | pass
422 |
423 | def __call__(self, img, seg_map, peak):
424 | """
425 | Args:
426 |             img (numpy.ndarray): Image to be resized.
427 | 
428 |         Returns:
429 |             numpy.ndarray: Resized image.
430 | """
431 | ori_h, ori_w, _ = img.shape
432 |
433 | if self.shape is None:
434 | new_h = (ori_h + 31) // 32 * 32
435 | new_w = (ori_w + 31) // 32 * 32
436 | else:
437 | new_h = self.shape[1]
438 | new_w = self.shape[0]
439 |
440 | img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
441 | seg_map = cv2.resize(seg_map, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
442 |
443 | peak = [ [int(p[0]/ori_w*new_w), int(p[1]/ori_h*new_h), p[2], p[3]] for p in peak ]
444 |
445 | return img, seg_map, peak
446 |
--------------------------------------------------------------------------------
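(Illustrative composition, not part of the repo — the actual training pipeline is built in utils/LoadData.py, and the padding/normalization constants below are assumptions. Every transform takes and returns an (image, seg_map, peak) triple, where peak is a list of [x, y, class, confidence] point annotations.)

# transform = Compose([
#     PhotometricDistort(),                                   # operates on the uint8 numpy image
#     RandomScale(min_scale=0.5, max_scale=1.5, scale_step_size=0.25),
#     RandomCrop(416, 416, pad_value=(123, 116, 103), ignore_label=255, random_pad=True),
#     ToTensor(),                                             # converts image/seg_map to torch tensors
#     Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#     RandomHorizontalFlip(),                                 # tensor-based flip, keeps peaks in sync
# ])
# image, seg_map, peak = transform(image, seg_map, peak)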
/PAM/utils/transforms/functional.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Reference: https://github.com/qjadud1994/DRS/blob/main/utils/transforms/functional.py
3 | # ------------------------------------------------------------------------------
4 |
5 | from __future__ import division
6 | import torch
7 | import math
8 | import random
9 | from PIL import Image, ImageOps, ImageEnhance
10 | try:
11 | import accimage
12 | except ImportError:
13 | accimage = None
14 | import numpy as np
15 | import numbers
16 | import types
17 | import collections.abc
18 | import warnings
19 |
20 |
21 | def _is_pil_image(img):
22 | if accimage is not None:
23 | return isinstance(img, (Image.Image, accimage.Image))
24 | else:
25 | return isinstance(img, Image.Image)
26 |
27 |
28 | def _is_tensor_image(img):
29 | return torch.is_tensor(img) and img.ndimension() == 3
30 |
31 |
32 | def _is_numpy_image(img):
33 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
34 |
35 |
36 | def to_tensor(pic):
37 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
38 |
39 | See ``ToTensor`` for more details.
40 |
41 | Args:
42 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
43 |
44 | Returns:
45 | Tensor: Converted image.
46 | """
47 | if not(_is_pil_image(pic) or _is_numpy_image(pic)):
48 | raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
49 |
50 | if isinstance(pic, np.ndarray):
51 | # handle numpy array
52 | img = torch.from_numpy(pic.transpose((2, 0, 1)))
53 | # backward compatibility
54 | return img.float().div(255)
55 | # return img.float()
56 |
57 | if accimage is not None and isinstance(pic, accimage.Image):
58 | nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
59 | pic.copyto(nppic)
60 | return torch.from_numpy(nppic)
61 |
62 | # handle PIL Image
63 | if pic.mode == 'I':
64 | img = torch.from_numpy(np.array(pic, np.int32, copy=False))
65 | elif pic.mode == 'I;16':
66 | img = torch.from_numpy(np.array(pic, np.int16, copy=False))
67 | else:
68 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
69 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
70 | if pic.mode == 'YCbCr':
71 | nchannel = 3
72 | elif pic.mode == 'I;16':
73 | nchannel = 1
74 | else:
75 | nchannel = len(pic.mode)
76 | img = img.view(pic.size[1], pic.size[0], nchannel)
77 | # put it from HWC to CHW format
78 | # yikes, this transpose takes 80% of the loading time/CPU
79 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
80 | if isinstance(img, torch.ByteTensor):
81 | return img.float().div(255)
82 | # return img.float()
83 | else:
84 | return img
85 |
86 |
87 | def to_pil_image(pic, mode=None):
88 | """Convert a tensor or an ndarray to PIL Image.
89 |
90 |     See :class:`~torchvision.transforms.ToPILImage` for more details.
91 |
92 | Args:
93 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
94 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
95 |
96 | .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
97 |
98 | Returns:
99 | PIL Image: Image converted to PIL Image.
100 | """
101 | if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
102 | raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
103 |
104 | npimg = pic
105 | if isinstance(pic, torch.FloatTensor):
106 | pic = pic.mul(255).byte()
107 | if torch.is_tensor(pic):
108 | npimg = np.transpose(pic.numpy(), (1, 2, 0))
109 |
110 | if not isinstance(npimg, np.ndarray):
111 | raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
112 | 'not {}'.format(type(npimg)))
113 |
114 | if npimg.shape[2] == 1:
115 | expected_mode = None
116 | npimg = npimg[:, :, 0]
117 | if npimg.dtype == np.uint8:
118 | expected_mode = 'L'
119 | if npimg.dtype == np.int16:
120 | expected_mode = 'I;16'
121 | if npimg.dtype == np.int32:
122 | expected_mode = 'I'
123 | elif npimg.dtype == np.float32:
124 | expected_mode = 'F'
125 | if mode is not None and mode != expected_mode:
126 | raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
127 |                              .format(mode, npimg.dtype, expected_mode))
128 | mode = expected_mode
129 |
130 | elif npimg.shape[2] == 4:
131 | permitted_4_channel_modes = ['RGBA', 'CMYK']
132 | if mode is not None and mode not in permitted_4_channel_modes:
133 | raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
134 |
135 | if mode is None and npimg.dtype == np.uint8:
136 | mode = 'RGBA'
137 | else:
138 | permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
139 | if mode is not None and mode not in permitted_3_channel_modes:
140 | raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
141 | if mode is None and npimg.dtype == np.uint8:
142 | mode = 'RGB'
143 |
144 | if mode is None:
145 | raise TypeError('Input type {} is not supported'.format(npimg.dtype))
146 |
147 | return Image.fromarray(npimg, mode=mode)
148 |
149 |
150 | def normalize(tensor, mean, std):
151 | """Normalize a tensor image with mean and standard deviation.
152 |
153 | See ``Normalize`` for more details.
154 |
155 | Args:
156 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
157 | mean (sequence): Sequence of means for each channel.
158 |         std (sequence): Sequence of standard deviations for each channel.
159 |
160 | Returns:
161 | Tensor: Normalized Tensor image.
162 | """
163 | if not _is_tensor_image(tensor):
164 | raise TypeError('tensor is not a torch image.')
165 | # TODO: make efficient
166 | for t, m, s in zip(tensor, mean, std):
167 | t.sub_(m).div_(s)
168 | return tensor
169 |
170 |
171 | def resize(img, size, interpolation=Image.BILINEAR):
172 | """Resize the input PIL Image to the given size.
173 |
174 | Args:
175 | img (PIL Image): Image to be resized.
176 | size (sequence or int): Desired output size. If size is a sequence like
177 | (h, w), the output size will be matched to this. If size is an int,
178 |             the smaller edge of the image will be matched to this number maintaining
179 | the aspect ratio. i.e, if height > width, then image will be rescaled to
180 | (size * height / width, size)
181 | interpolation (int, optional): Desired interpolation. Default is
182 | ``PIL.Image.BILINEAR``
183 |
184 | Returns:
185 | PIL Image: Resized image.
186 | """
187 | if not _is_pil_image(img):
188 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
189 |     if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
190 | raise TypeError('Got inappropriate size arg: {}'.format(size))
191 |
192 | if isinstance(size, int):
193 | w, h = img.size
194 | if (w <= h and w == size) or (h <= w and h == size):
195 | return img
196 | if w < h:
197 | ow = size
198 | oh = int(size * h / w)
199 | return img.resize((ow, oh), interpolation)
200 | else:
201 | oh = size
202 | ow = int(size * w / h)
203 | return img.resize((ow, oh), interpolation)
204 | else:
205 | return img.resize(size[::-1], interpolation)
206 |
207 |
208 | def scale(*args, **kwargs):
209 | warnings.warn("The use of the transforms.Scale transform is deprecated, " +
210 | "please use transforms.Resize instead.")
211 | return resize(*args, **kwargs)
212 |
213 |
214 | def pad(img, padding, fill=0):
215 | """Pad the given PIL Image on all sides with the given "pad" value.
216 |
217 | Args:
218 | img (PIL Image): Image to be padded.
219 | padding (int or tuple): Padding on each border. If a single int is provided this
220 | is used to pad all borders. If tuple of length 2 is provided this is the padding
221 | on left/right and top/bottom respectively. If a tuple of length 4 is provided
222 | this is the padding for the left, top, right and bottom borders
223 | respectively.
224 | fill: Pixel fill value. Default is 0. If a tuple of
225 | length 3, it is used to fill R, G, B channels respectively.
226 |
227 | Returns:
228 | PIL Image: Padded image.
229 | """
230 | if not _is_pil_image(img):
231 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
232 |
233 | if not isinstance(padding, (numbers.Number, tuple)):
234 | raise TypeError('Got inappropriate padding arg')
235 | if not isinstance(fill, (numbers.Number, str, tuple)):
236 | raise TypeError('Got inappropriate fill arg')
237 |
238 |     if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
239 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
240 | "{} element tuple".format(len(padding)))
241 |
242 | return ImageOps.expand(img, border=padding, fill=fill)
243 |
244 |
245 | def crop(img, i, j, h, w):
246 | """Crop the given PIL Image.
247 |
248 | Args:
249 | img (PIL Image): Image to be cropped.
250 | i: Upper pixel coordinate.
251 | j: Left pixel coordinate.
252 | h: Height of the cropped image.
253 | w: Width of the cropped image.
254 |
255 | Returns:
256 | PIL Image: Cropped image.
257 | """
258 | if not _is_pil_image(img):
259 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
260 |
261 | return img.crop((j, i, j + w, i + h))
262 |
263 |
264 | def center_crop(img, output_size):
265 | if isinstance(output_size, numbers.Number):
266 | output_size = (int(output_size), int(output_size))
267 | w, h = img.size
268 | th, tw = output_size
269 | i = int(round((h - th) / 2.))
270 | j = int(round((w - tw) / 2.))
271 | return crop(img, i, j, th, tw)
272 |
273 |
274 | def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
275 | """Crop the given PIL Image and resize it to desired size.
276 |
277 | Notably used in RandomResizedCrop.
278 |
279 | Args:
280 | img (PIL Image): Image to be cropped.
281 | i: Upper pixel coordinate.
282 | j: Left pixel coordinate.
283 | h: Height of the cropped image.
284 | w: Width of the cropped image.
285 | size (sequence or int): Desired output size. Same semantics as ``scale``.
286 | interpolation (int, optional): Desired interpolation. Default is
287 | ``PIL.Image.BILINEAR``.
288 | Returns:
289 | PIL Image: Cropped image.
290 | """
291 | assert _is_pil_image(img), 'img should be PIL Image'
292 | img = crop(img, i, j, h, w)
293 | img = resize(img, size, interpolation)
294 | return img
295 |
296 |
297 | def hflip(img):
298 | """Horizontally flip the given PIL Image.
299 |
300 | Args:
301 | img (PIL Image): Image to be flipped.
302 |
303 | Returns:
304 |         PIL Image: Horizontally flipped image.
305 | """
306 | if not _is_pil_image(img):
307 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
308 |
309 | return img.transpose(Image.FLIP_LEFT_RIGHT)
310 |
311 |
312 | def vflip(img):
313 | """Vertically flip the given PIL Image.
314 |
315 | Args:
316 | img (PIL Image): Image to be flipped.
317 |
318 | Returns:
319 | PIL Image: Vertically flipped image.
320 | """
321 | if not _is_pil_image(img):
322 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
323 |
324 | return img.transpose(Image.FLIP_TOP_BOTTOM)
325 |
326 |
327 | def five_crop(img, size):
328 | """Crop the given PIL Image into four corners and the central crop.
329 |
330 | .. Note::
331 | This transform returns a tuple of images and there may be a
332 | mismatch in the number of inputs and targets your ``Dataset`` returns.
333 |
334 | Args:
335 | size (sequence or int): Desired output size of the crop. If size is an
336 | int instead of sequence like (h, w), a square crop (size, size) is
337 | made.
338 | Returns:
339 | tuple: tuple (tl, tr, bl, br, center) corresponding top left,
340 | top right, bottom left, bottom right and center crop.
341 | """
342 | if isinstance(size, numbers.Number):
343 | size = (int(size), int(size))
344 | else:
345 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
346 |
347 | w, h = img.size
348 | crop_h, crop_w = size
349 | if crop_w > w or crop_h > h:
350 | raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
351 | (h, w)))
352 | tl = img.crop((0, 0, crop_w, crop_h))
353 | tr = img.crop((w - crop_w, 0, w, crop_h))
354 | bl = img.crop((0, h - crop_h, crop_w, h))
355 | br = img.crop((w - crop_w, h - crop_h, w, h))
356 | center = center_crop(img, (crop_h, crop_w))
357 | return (tl, tr, bl, br, center)
358 |
359 |
360 | def ten_crop(img, size, vertical_flip=False):
361 | """Crop the given PIL Image into four corners and the central crop plus the
362 | flipped version of these (horizontal flipping is used by default).
363 |
364 | .. Note::
365 | This transform returns a tuple of images and there may be a
366 | mismatch in the number of inputs and targets your ``Dataset`` returns.
367 |
368 | Args:
369 | size (sequence or int): Desired output size of the crop. If size is an
370 | int instead of sequence like (h, w), a square crop (size, size) is
371 | made.
372 | vertical_flip (bool): Use vertical flipping instead of horizontal
373 |
374 | Returns:
375 | tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
376 | br_flip, center_flip) corresponding top left, top right,
377 | bottom left, bottom right and center crop and same for the
378 | flipped image.
379 | """
380 | if isinstance(size, numbers.Number):
381 | size = (int(size), int(size))
382 | else:
383 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
384 |
385 | first_five = five_crop(img, size)
386 |
387 | if vertical_flip:
388 | img = vflip(img)
389 | else:
390 | img = hflip(img)
391 |
392 | second_five = five_crop(img, size)
393 | return first_five + second_five
394 |
395 |
396 | def adjust_brightness(img, brightness_factor):
397 | """Adjust brightness of an Image.
398 |
399 | Args:
400 | img (PIL Image): PIL Image to be adjusted.
401 | brightness_factor (float): How much to adjust the brightness. Can be
402 | any non negative number. 0 gives a black image, 1 gives the
403 | original image while 2 increases the brightness by a factor of 2.
404 |
405 | Returns:
406 | PIL Image: Brightness adjusted image.
407 | """
408 | if not _is_pil_image(img):
409 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
410 |
411 | enhancer = ImageEnhance.Brightness(img)
412 | img = enhancer.enhance(brightness_factor)
413 | return img
414 |
415 |
416 | def adjust_contrast(img, contrast_factor):
417 | """Adjust contrast of an Image.
418 |
419 | Args:
420 | img (PIL Image): PIL Image to be adjusted.
421 | contrast_factor (float): How much to adjust the contrast. Can be any
422 | non negative number. 0 gives a solid gray image, 1 gives the
423 | original image while 2 increases the contrast by a factor of 2.
424 |
425 | Returns:
426 | PIL Image: Contrast adjusted image.
427 | """
428 | if not _is_pil_image(img):
429 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
430 |
431 | enhancer = ImageEnhance.Contrast(img)
432 | img = enhancer.enhance(contrast_factor)
433 | return img
434 |
435 |
436 | def adjust_saturation(img, saturation_factor):
437 | """Adjust color saturation of an image.
438 |
439 | Args:
440 | img (PIL Image): PIL Image to be adjusted.
441 | saturation_factor (float): How much to adjust the saturation. 0 will
442 | give a black and white image, 1 will give the original image while
443 | 2 will enhance the saturation by a factor of 2.
444 |
445 | Returns:
446 | PIL Image: Saturation adjusted image.
447 | """
448 | if not _is_pil_image(img):
449 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
450 |
451 | enhancer = ImageEnhance.Color(img)
452 | img = enhancer.enhance(saturation_factor)
453 | return img
454 |
455 |
456 | def adjust_hue(img, hue_factor):
457 | """Adjust hue of an image.
458 |
459 | The image hue is adjusted by converting the image to HSV and
460 | cyclically shifting the intensities in the hue channel (H).
461 | The image is then converted back to the original image mode.
462 |
463 | `hue_factor` is the amount of shift in H channel and must be in the
464 | interval `[-0.5, 0.5]`.
465 |
466 | See https://en.wikipedia.org/wiki/Hue for more details on Hue.
467 |
468 | Args:
469 | img (PIL Image): PIL Image to be adjusted.
470 | hue_factor (float): How much to shift the hue channel. Should be in
471 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
472 | HSV space in positive and negative direction respectively.
473 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
474 | with complementary colors while 0 gives the original image.
475 |
476 | Returns:
477 | PIL Image: Hue adjusted image.
478 | """
479 | if not(-0.5 <= hue_factor <= 0.5):
480 | raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
481 |
482 | if not _is_pil_image(img):
483 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
484 |
485 | input_mode = img.mode
486 | if input_mode in {'L', '1', 'I', 'F'}:
487 | return img
488 |
489 | h, s, v = img.convert('HSV').split()
490 |
491 | np_h = np.array(h, dtype=np.uint8)
492 | # uint8 addition takes care of rotation across boundaries
493 | with np.errstate(over='ignore'):
494 | np_h += np.uint8(hue_factor * 255)
495 | h = Image.fromarray(np_h, 'L')
496 |
497 | img = Image.merge('HSV', (h, s, v)).convert(input_mode)
498 | return img
499 |
500 |
501 | def adjust_gamma(img, gamma, gain=1):
502 | """Perform gamma correction on an image.
503 |
504 | Also known as Power Law Transform. Intensities in RGB mode are adjusted
505 | based on the following equation:
506 |
507 | I_out = 255 * gain * ((I_in / 255) ** gamma)
508 |
509 | See https://en.wikipedia.org/wiki/Gamma_correction for more details.
510 |
511 | Args:
512 | img (PIL Image): PIL Image to be adjusted.
513 | gamma (float): Non-negative real number. A gamma larger than 1 makes the
514 | shadows darker, while a gamma smaller than 1 makes dark regions
515 | lighter.
516 | gain (float): The constant multiplier.
517 | """
518 | if not _is_pil_image(img):
519 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
520 |
521 | if gamma < 0:
522 | raise ValueError('Gamma should be a non-negative real number')
523 |
524 | input_mode = img.mode
525 | img = img.convert('RGB')
526 |
527 | np_img = np.array(img, dtype=np.float32)
528 | np_img = 255 * gain * ((np_img / 255) ** gamma)
529 | np_img = np.uint8(np.clip(np_img, 0, 255))
530 |
531 | img = Image.fromarray(np_img, 'RGB').convert(input_mode)
532 | return img
533 |
534 |
535 | def rotate(img, angle, resample=False, expand=False, center=None):
536 | """Rotate the image by angle and then (optionally) translate it by (n_columns, n_rows)
537 |
538 |
539 | Args:
540 | img (PIL Image): PIL Image to be rotated.
541 | angle ({float, int}): Rotation angle in degrees, counter clockwise.
542 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
543 | An optional resampling filter.
544 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
545 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
546 | expand (bool, optional): Optional expansion flag.
547 | If true, expands the output image to make it large enough to hold the entire rotated image.
548 | If false or omitted, make the output image the same size as the input image.
549 | Note that the expand flag assumes rotation around the center and no translation.
550 | center (2-tuple, optional): Optional center of rotation.
551 | Origin is the upper left corner.
552 | Default is the center of the image.
553 | """
554 |
555 | if not _is_pil_image(img):
556 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
557 |
558 | return img.rotate(angle, resample, expand, center)
559 |
560 |
561 | def to_grayscale(img, num_output_channels=1):
562 | """Convert image to grayscale version of image.
563 |
564 | Args:
565 | img (PIL Image): Image to be converted to grayscale.
566 |
567 | Returns:
568 | PIL Image: Grayscale version of the image.
569 | if num_output_channels == 1 : returned image is single channel
570 | if num_output_channels == 3 : returned image is 3 channel with r == g == b
571 | """
572 | if not _is_pil_image(img):
573 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
574 |
575 | if num_output_channels == 1:
576 | img = img.convert('L')
577 | elif num_output_channels == 3:
578 | img = img.convert('L')
579 | np_img = np.array(img, dtype=np.uint8)
580 | np_img = np.dstack([np_img, np_img, np_img])
581 | img = Image.fromarray(np_img, 'RGB')
582 | else:
583 | raise ValueError('num_output_channels should be either 1 or 3')
584 |
585 | return img
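586 | 
587 | 
588 | # --- Illustrative usage sketch (added note, not part of the original file) ---
589 | # The helpers above all take and return PIL Images, so they can be chained for
590 | # simple photometric augmentation. The file name below is only an example.
591 | #
592 | # from PIL import Image
593 | # img = Image.open('example.jpg').convert('RGB')
594 | # img = adjust_brightness(img, 1.2)    # 20% brighter
595 | # img = adjust_contrast(img, 0.8)      # slightly lower contrast
596 | # img = adjust_hue(img, 0.05)          # small hue shift; factor must lie in [-0.5, 0.5]
597 | # crops = five_crop(img, 224)          # (tl, tr, bl, br, center), each 224x224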
--------------------------------------------------------------------------------
/models/hrnet.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import os
12 | import logging
13 |
14 | import torch
15 | import torch.nn as nn
16 |
17 | from models.hrnet_config import cfg
18 | from models.hrnet_config import update_config
19 |
20 | from utils.utils import refine_label_generation, refine_label_generation_with_point
21 |
22 | BN_MOMENTUM = 0.1
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | def conv3x3(in_planes, out_planes, stride=1):
27 | """3x3 convolution with padding"""
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
29 | padding=1, bias=False)
30 |
31 |
32 | class BasicBlock(nn.Module):
33 | expansion = 1
34 |
35 | def __init__(self, inplanes, planes, stride=1, downsample=None):
36 | super(BasicBlock, self).__init__()
37 | self.conv1 = conv3x3(inplanes, planes, stride)
38 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
39 | self.relu = nn.ReLU(inplace=True)
40 | self.conv2 = conv3x3(planes, planes)
41 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
42 | self.downsample = downsample
43 | self.stride = stride
44 |
45 | def forward(self, x):
46 | residual = x
47 |
48 | out = self.conv1(x)
49 | out = self.bn1(out)
50 | out = self.relu(out)
51 |
52 | out = self.conv2(out)
53 | out = self.bn2(out)
54 |
55 | if self.downsample is not None:
56 | residual = self.downsample(x)
57 |
58 | out += residual
59 | out = self.relu(out)
60 |
61 | return out
62 |
63 |
64 | class Bottleneck(nn.Module):
65 | expansion = 4
66 |
67 | def __init__(self, inplanes, planes, stride=1, downsample=None):
68 | super(Bottleneck, self).__init__()
69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
70 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
71 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
72 | padding=1, bias=False)
73 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
74 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
75 | bias=False)
76 | self.bn3 = nn.BatchNorm2d(planes * self.expansion,
77 | momentum=BN_MOMENTUM)
78 | self.relu = nn.ReLU(inplace=True)
79 | self.downsample = downsample
80 | self.stride = stride
81 |
82 | def forward(self, x):
83 | residual = x
84 |
85 | out = self.conv1(x)
86 | out = self.bn1(out)
87 | out = self.relu(out)
88 |
89 | out = self.conv2(out)
90 | out = self.bn2(out)
91 | out = self.relu(out)
92 |
93 | out = self.conv3(out)
94 | out = self.bn3(out)
95 |
96 | if self.downsample is not None:
97 | residual = self.downsample(x)
98 |
99 | out += residual
100 | out = self.relu(out)
101 |
102 | return out
103 |
104 |
105 | class HighResolutionModule(nn.Module):
106 | def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
107 | num_channels, fuse_method, multi_scale_output=True):
108 | super(HighResolutionModule, self).__init__()
109 | self._check_branches(
110 | num_branches, blocks, num_blocks, num_inchannels, num_channels)
111 |
112 | self.num_inchannels = num_inchannels
113 | self.fuse_method = fuse_method
114 | self.num_branches = num_branches
115 |
116 | self.multi_scale_output = multi_scale_output
117 |
118 | self.branches = self._make_branches(
119 | num_branches, blocks, num_blocks, num_channels)
120 | self.fuse_layers = self._make_fuse_layers()
121 | self.relu = nn.ReLU(True)
122 |
123 | def _check_branches(self, num_branches, blocks, num_blocks,
124 | num_inchannels, num_channels):
125 | if num_branches != len(num_blocks):
126 | error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
127 | num_branches, len(num_blocks))
128 | logger.error(error_msg)
129 | raise ValueError(error_msg)
130 |
131 | if num_branches != len(num_channels):
132 | error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
133 | num_branches, len(num_channels))
134 | logger.error(error_msg)
135 | raise ValueError(error_msg)
136 |
137 | if num_branches != len(num_inchannels):
138 | error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
139 | num_branches, len(num_inchannels))
140 | logger.error(error_msg)
141 | raise ValueError(error_msg)
142 |
143 | def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
144 | stride=1):
145 | downsample = None
146 | if stride != 1 or \
147 | self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
148 | downsample = nn.Sequential(
149 | nn.Conv2d(
150 | self.num_inchannels[branch_index],
151 | num_channels[branch_index] * block.expansion,
152 | kernel_size=1, stride=stride, bias=False
153 | ),
154 | nn.BatchNorm2d(
155 | num_channels[branch_index] * block.expansion,
156 | momentum=BN_MOMENTUM
157 | ),
158 | )
159 |
160 | layers = []
161 | layers.append(
162 | block(
163 | self.num_inchannels[branch_index],
164 | num_channels[branch_index],
165 | stride,
166 | downsample
167 | )
168 | )
169 | self.num_inchannels[branch_index] = \
170 | num_channels[branch_index] * block.expansion
171 | for i in range(1, num_blocks[branch_index]):
172 | layers.append(
173 | block(
174 | self.num_inchannels[branch_index],
175 | num_channels[branch_index]
176 | )
177 | )
178 |
179 | return nn.Sequential(*layers)
180 |
181 | def _make_branches(self, num_branches, block, num_blocks, num_channels):
182 | branches = []
183 |
184 | for i in range(num_branches):
185 | branches.append(
186 | self._make_one_branch(i, block, num_blocks, num_channels)
187 | )
188 |
189 | return nn.ModuleList(branches)
190 |
191 | def _make_fuse_layers(self):
192 | if self.num_branches == 1:
193 | return None
194 |
195 | num_branches = self.num_branches
196 | num_inchannels = self.num_inchannels
197 | fuse_layers = []
198 | for i in range(num_branches if self.multi_scale_output else 1):
199 | fuse_layer = []
200 | for j in range(num_branches):
201 | if j > i:
202 | fuse_layer.append(
203 | nn.Sequential(
204 | nn.Conv2d(
205 | num_inchannels[j],
206 | num_inchannels[i],
207 | 1, 1, 0, bias=False
208 | ),
209 | nn.BatchNorm2d(num_inchannels[i]),
210 | nn.Upsample(scale_factor=2**(j-i), mode='nearest')
211 | )
212 | )
213 | elif j == i:
214 | fuse_layer.append(None)
215 | else:
216 | conv3x3s = []
217 | for k in range(i-j):
218 | if k == i - j - 1:
219 | num_outchannels_conv3x3 = num_inchannels[i]
220 | conv3x3s.append(
221 | nn.Sequential(
222 | nn.Conv2d(
223 | num_inchannels[j],
224 | num_outchannels_conv3x3,
225 | 3, 2, 1, bias=False
226 | ),
227 | nn.BatchNorm2d(num_outchannels_conv3x3)
228 | )
229 | )
230 | else:
231 | num_outchannels_conv3x3 = num_inchannels[j]
232 | conv3x3s.append(
233 | nn.Sequential(
234 | nn.Conv2d(
235 | num_inchannels[j],
236 | num_outchannels_conv3x3,
237 | 3, 2, 1, bias=False
238 | ),
239 | nn.BatchNorm2d(num_outchannels_conv3x3),
240 | nn.ReLU(True)
241 | )
242 | )
243 | fuse_layer.append(nn.Sequential(*conv3x3s))
244 | fuse_layers.append(nn.ModuleList(fuse_layer))
245 |
246 | return nn.ModuleList(fuse_layers)
247 |
248 | def get_num_inchannels(self):
249 | return self.num_inchannels
250 |
251 | def forward(self, x):
252 | if self.num_branches == 1:
253 | return [self.branches[0](x[0])]
254 |
255 | for i in range(self.num_branches):
256 | x[i] = self.branches[i](x[i])
257 |
258 | x_fuse = []
259 |
260 | for i in range(len(self.fuse_layers)):
261 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
262 | for j in range(1, self.num_branches):
263 | if i == j:
264 | y = y + x[j]
265 | else:
266 | y = y + self.fuse_layers[i][j](x[j])
267 | x_fuse.append(self.relu(y))
268 |
269 | return x_fuse
270 |
271 |
272 | blocks_dict = {
273 | 'BASIC': BasicBlock,
274 | 'BOTTLENECK': Bottleneck
275 | }
276 |
277 | class PoseHighResolutionNet(nn.Module):
278 |
279 | def __init__(self, cfg, heads, args, **kwargs):
280 | self.inplanes = 64
281 | extra = cfg['MODEL']['EXTRA']
282 | super(PoseHighResolutionNet, self).__init__()
283 |
284 | self.heads = heads
285 | self.args = args
286 |
287 | # stem net
288 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
289 | bias=False)
290 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
291 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
292 | bias=False)
293 | self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
294 | self.relu = nn.ReLU(inplace=True)
295 | self.layer1 = self._make_layer(Bottleneck, 64, 4)
296 |
297 | self.stage2_cfg = extra['STAGE2']
298 | num_channels = self.stage2_cfg['NUM_CHANNELS']
299 | block = blocks_dict[self.stage2_cfg['BLOCK']]
300 | num_channels = [
301 | num_channels[i] * block.expansion for i in range(len(num_channels))
302 | ]
303 | self.transition1 = self._make_transition_layer([256], num_channels)
304 | self.stage2, pre_stage_channels = self._make_stage(
305 | self.stage2_cfg, num_channels)
306 |
307 | self.stage3_cfg = extra['STAGE3']
308 | num_channels = self.stage3_cfg['NUM_CHANNELS']
309 | block = blocks_dict[self.stage3_cfg['BLOCK']]
310 | num_channels = [
311 | num_channels[i] * block.expansion for i in range(len(num_channels))
312 | ]
313 | self.transition2 = self._make_transition_layer(
314 | pre_stage_channels, num_channels)
315 | self.stage3, pre_stage_channels = self._make_stage(
316 | self.stage3_cfg, num_channels)
317 |
318 | self.stage4_cfg = extra['STAGE4']
319 | num_channels = self.stage4_cfg['NUM_CHANNELS']
320 | block = blocks_dict[self.stage4_cfg['BLOCK']]
321 | num_channels = [
322 | num_channels[i] * block.expansion for i in range(len(num_channels))
323 | ]
324 | self.transition3 = self._make_transition_layer(
325 | pre_stage_channels, num_channels)
326 | self.stage4, pre_stage_channels = self._make_stage(
327 | self.stage4_cfg, num_channels, multi_scale_output=False)
328 |
329 | for head in self.heads:
330 | classes = self.heads[head]
331 | head_conv = 256
332 |
333 | fc = nn.Sequential(
334 | nn.Conv2d(pre_stage_channels[0], head_conv, kernel_size=3, padding=1, bias=True),
335 | nn.ReLU(inplace=True),
336 | nn.Conv2d(head_conv, classes, kernel_size=1, stride=1, padding=0, bias=True),
337 | )
338 |
339 | self.__setattr__(head, fc)
340 |
341 | self.pretrained_layers = extra['PRETRAINED_LAYERS']
342 |
343 | def _make_transition_layer(
344 | self, num_channels_pre_layer, num_channels_cur_layer):
345 | num_branches_cur = len(num_channels_cur_layer)
346 | num_branches_pre = len(num_channels_pre_layer)
347 |
348 | transition_layers = []
349 | for i in range(num_branches_cur):
350 | if i < num_branches_pre:
351 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
352 | transition_layers.append(
353 | nn.Sequential(
354 | nn.Conv2d(
355 | num_channels_pre_layer[i],
356 | num_channels_cur_layer[i],
357 | 3, 1, 1, bias=False
358 | ),
359 | nn.BatchNorm2d(num_channels_cur_layer[i]),
360 | nn.ReLU(inplace=True)
361 | )
362 | )
363 | else:
364 | transition_layers.append(None)
365 | else:
366 | conv3x3s = []
367 | for j in range(i+1-num_branches_pre):
368 | inchannels = num_channels_pre_layer[-1]
369 | outchannels = num_channels_cur_layer[i] \
370 | if j == i-num_branches_pre else inchannels
371 | conv3x3s.append(
372 | nn.Sequential(
373 | nn.Conv2d(
374 | inchannels, outchannels, 3, 2, 1, bias=False
375 | ),
376 | nn.BatchNorm2d(outchannels),
377 | nn.ReLU(inplace=True)
378 | )
379 | )
380 | transition_layers.append(nn.Sequential(*conv3x3s))
381 |
382 | return nn.ModuleList(transition_layers)
383 |
384 | def _make_layer(self, block, planes, blocks, stride=1):
385 | downsample = None
386 | if stride != 1 or self.inplanes != planes * block.expansion:
387 | downsample = nn.Sequential(
388 | nn.Conv2d(
389 | self.inplanes, planes * block.expansion,
390 | kernel_size=1, stride=stride, bias=False
391 | ),
392 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
393 | )
394 |
395 | layers = []
396 | layers.append(block(self.inplanes, planes, stride, downsample))
397 | self.inplanes = planes * block.expansion
398 | for i in range(1, blocks):
399 | layers.append(block(self.inplanes, planes))
400 |
401 | return nn.Sequential(*layers)
402 |
403 | def _make_stage(self, layer_config, num_inchannels,
404 | multi_scale_output=True):
405 | num_modules = layer_config['NUM_MODULES']
406 | num_branches = layer_config['NUM_BRANCHES']
407 | num_blocks = layer_config['NUM_BLOCKS']
408 | num_channels = layer_config['NUM_CHANNELS']
409 | block = blocks_dict[layer_config['BLOCK']]
410 | fuse_method = layer_config['FUSE_METHOD']
411 |
412 | modules = []
413 | for i in range(num_modules):
414 | # multi_scale_output is only used in the last module
415 | if not multi_scale_output and i == num_modules - 1:
416 | reset_multi_scale_output = False
417 | else:
418 | reset_multi_scale_output = True
419 |
420 | modules.append(
421 | HighResolutionModule(
422 | num_branches,
423 | block,
424 | num_blocks,
425 | num_inchannels,
426 | num_channels,
427 | fuse_method,
428 | reset_multi_scale_output
429 | )
430 | )
431 | num_inchannels = modules[-1].get_num_inchannels()
432 |
433 | return nn.Sequential(*modules), num_inchannels
434 |
435 | def forward(self, x, seg_map=None, label=None, point_list=None, target_shape=None):
436 | if target_shape is None:
437 | target_shape = x.shape[-2:]
438 |
439 | x = self.conv1(x)
440 | x = self.bn1(x)
441 | x = self.relu(x)
442 | x = self.conv2(x)
443 | x = self.bn2(x)
444 | x = self.relu(x)
445 | x = self.layer1(x)
446 |
447 | x_list = []
448 | for i in range(self.stage2_cfg['NUM_BRANCHES']):
449 | if self.transition1[i] is not None:
450 | x_list.append(self.transition1[i](x))
451 | else:
452 | x_list.append(x)
453 | y_list = self.stage2(x_list)
454 |
455 | x_list = []
456 | for i in range(self.stage3_cfg['NUM_BRANCHES']):
457 | if self.transition2[i] is not None:
458 | x_list.append(self.transition2[i](y_list[-1]))
459 | else:
460 | x_list.append(y_list[i])
461 | y_list = self.stage3(x_list)
462 |
463 | x_list = []
464 | for i in range(self.stage4_cfg['NUM_BRANCHES']):
465 | if self.transition3[i] is not None:
466 | x_list.append(self.transition3[i](y_list[-1]))
467 | else:
468 | x_list.append(y_list[i])
469 | y_list = self.stage4(x_list)
470 |
471 | results = {}
472 | for head in self.heads:
473 | results[head] = self.__getattr__(head)(y_list[0])
474 | results[head] = torch.nn.functional.interpolate(results[head],
475 | size=target_shape,
476 | mode='bilinear',
477 | align_corners=False)
478 |
479 | if label is not None: # refined label generation
480 |
481 | if self.args.sup == 'point': # point supervision setting
482 | pseudo_label = refine_label_generation_with_point(
483 | results['seg'].clone().detach(),
484 | point_list.cpu().numpy(),
485 | results['offset'].clone().detach(),
486 | label.clone().detach(),
487 | seg_map.clone().detach(),
488 | self.args,
489 | )
490 |
491 | else: # image-level supervision setting
492 | pseudo_label = refine_label_generation(
493 | results['seg'].clone().detach(),
494 | results['center'].clone().detach(),
495 | results['offset'].clone().detach(),
496 | label.clone().detach(),
497 | seg_map.clone().detach(),
498 | self.args,
499 | )
500 |
501 | return results, pseudo_label
502 |
503 | return results
504 |
505 |
506 | def init_weights(self, pretrained=''):
507 | logger.info('=> init weights from normal distribution')
508 | for n, m in self.named_modules():
509 | if isinstance(m, nn.Conv2d):
510 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
511 | nn.init.normal_(m.weight, std=0.001)
512 | for name, _ in m.named_parameters():
513 | if name in ['bias']:
514 | nn.init.constant_(m.bias, 0)
515 | elif isinstance(m, nn.BatchNorm2d):
516 | nn.init.constant_(m.weight, 1)
517 | nn.init.constant_(m.bias, 0)
518 | elif isinstance(m, nn.ConvTranspose2d):
519 | nn.init.normal_(m.weight, std=0.001)
520 | for name, _ in m.named_parameters():
521 | if name in ['bias']:
522 | nn.init.constant_(m.bias, 0)
523 |
524 | if os.path.isfile(pretrained):
525 | pretrained_state_dict = torch.load(pretrained)
526 | logger.info('=> loading pretrained model {}'.format(pretrained))
527 | print('=> loading pretrained model {}'.format(pretrained))
528 |
529 | need_init_state_dict = {}
530 | for name, m in pretrained_state_dict.items():
531 | if name.split('.')[0] in self.pretrained_layers \
532 | or self.pretrained_layers[0] == '*':  # '*' means load all pretrained layers
533 | need_init_state_dict[name] = m
534 | self.load_state_dict(need_init_state_dict, strict=False)
535 | else:
536 | logger.info("=> without pre-trained models")
537 | print("=> without pre-trained models")
538 |
539 |
540 | def HRNet(layer, heads, args, **kwargs):
541 | update_config(cfg, f"./models/hrnet_config/w{layer}_384x288_adam_lr1e-3.yaml")
542 | cfg.MODEL.NUM_JOINTS = 1 # num classes
543 |
544 | model = PoseHighResolutionNet(cfg, heads, args, **kwargs)
545 |
546 | model.init_weights(cfg.MODEL.PRETRAINED)
547 |
548 | return model
549 |
550 |
551 | def hrnet32(args):
552 | heads = {
553 | 'seg': args.num_classes+1,
554 | 'center': args.num_classes,
555 | 'offset': 2
556 | }
557 | model = HRNet(32, heads, args)
558 | return model
559 |
560 | def hrnet48(args):
561 | heads = {
562 | 'seg': args.num_classes+1,
563 | 'center': args.num_classes,
564 | 'offset': 2
565 | }
566 | model = HRNet(48, heads, args)
567 | return model
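568 | 
569 | 
570 | # --- Illustrative usage sketch (added note, not part of the original file) ---
571 | # Only the attributes that this module itself reads from `args` are assumed here
572 | # (num_classes for the head sizes, sup for the supervision mode); the real training
573 | # script builds `args` with argparse and passes more options.
574 | #
575 | # from types import SimpleNamespace
576 | # args = SimpleNamespace(num_classes=20, sup='point')
577 | # model = hrnet32(args)                        # heads: seg (21), center (20), offset (2)
578 | # out = model(torch.randn(1, 3, 512, 512))     # inference path, no pseudo-label refinement
579 | # # out['seg'], out['center'], out['offset'] are upsampled to the 512x512 input size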
--------------------------------------------------------------------------------