├── .gitignore
├── LICENSE.md
├── README.md
├── config_seg.py
├── data
│   ├── BaseDataset.py
│   ├── cityscapes.py
│   ├── cityscapes
│   │   ├── cityscapes_val_fine.txt
│   │   ├── cityscapes_val_fine_raw.txt
│   │   └── train_ClsConfSet.lst
│   ├── gta5.py
│   ├── labels.py
│   ├── loader_csg.py
│   └── visda17.py
├── dataloader_seg.py
├── eval_seg.py
├── model
│   ├── __init__.py
│   ├── csg_builder.py
│   ├── deeplab.py
│   └── resnet.py
├── requirements.txt
├── tools
│   ├── datasets
│   │   ├── BaseDataset.py
│   │   └── cityscapes
│   │       ├── cityscapes.py
│   │       └── cityscapes_val_fine.txt
│   ├── engine
│   │   ├── evaluator.py
│   │   ├── logger.py
│   │   └── tester.py
│   ├── seg_opr
│   │   └── metric.py
│   └── utils
│       ├── img_utils.py
│       ├── pyt_utils.py
│       └── visualize.py
├── train.py
├── train.sh
├── train_seg.py
├── train_seg.sh
└── utils
    ├── __init__.py
    ├── augmentations.py
    ├── logger.py
    ├── sgd.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # vim swp files
2 | *.swp
3 | # caffe/pytorch model files
4 | runs/*
5 | crst_visda/runs/*
6 | *.pth
7 | *.tar
8 | *_softmax.txt
9 | *_seed*.txt
10 | *_soft*.txt
11 | *.json
12 |
13 | # Mkdocs
14 | # /docs/
15 | /mkdocs/docs/temp
16 |
17 | .DS_Store
18 | .idea
19 | .vscode
20 | .pytest_cache
21 | /experiments
22 | node_modules/
23 | history/
24 | ablation/
25 | misc/
26 | prediction/
27 | results/
28 |
29 | # resource temp folder
30 | tests/resources/temp/*
31 | !tests/resources/temp/.gitkeep
32 |
33 | # Byte-compiled / optimized / DLL files
34 | __pycache__/
35 | *.py[cod]
36 | *$py.class
37 |
38 | # C extensions
39 | *.so
40 |
41 | # Distribution / packaging
42 | .Python
43 | build/
44 | develop-eggs/
45 | dist/
46 | downloads/
47 | eggs/
48 | .eggs/
49 | lib/
50 | lib64/
51 | parts/
52 | sdist/
53 | var/
54 | wheels/
55 | *.egg-info/
56 | .installed.cfg
57 | *.egg
58 | MANIFEST
59 |
60 | # PyInstaller
61 | # Usually these files are written by a python script from a template
62 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
63 | *.manifest
64 | *.spec
65 |
66 | # Installer logs
67 | pip-log.txt
68 | pip-delete-this-directory.txt
69 |
70 | # Unit test / coverage reports
71 | htmlcov/
72 | .tox/
73 | .coverage
74 | .coverage.*
75 | .cache
76 | nosetests.xml
77 | coverage.xml
78 | *.cover
79 | .hypothesis/
80 | .pytest_cache/
81 |
82 | # Translations
83 | *.mo
84 | *.pot
85 |
86 | # Django stuff:
87 | *.log
88 | .static_storage/
89 | .media/
90 | local_settings.py
91 | local_settings.py
92 | db.sqlite3
93 |
94 | # Flask stuff:
95 | instance/
96 | .webassets-cache
97 |
98 | # Scrapy stuff:
99 | .scrapy
100 |
101 | # Sphinx documentation
102 | docs/_build/
103 |
104 | # PyBuilder
105 | target/
106 |
107 | # Jupyter Notebook
108 | .ipynb_checkpoints
109 |
110 | # pyenv
111 | .python-version
112 |
113 | # celery beat schedule file
114 | celerybeat-schedule
115 |
116 | # SageMath parsed files
117 | *.sage.py
118 |
119 | # Environments
120 | .env
121 | .venv
122 | env/
123 | venv/
124 | ENV/
125 | env.bak/
126 | venv.bak/
127 |
128 | # Spyder project settings
129 | .spyderproject
130 | .spyproject
131 |
132 | # Rope project settings
133 | .ropeproject
134 |
135 | # mkdocs documentation
136 | /site
137 |
138 | # mypy
139 | .mypy_cache/
140 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | NVIDIA Source Code
2 |
3 | License for Contrastive Syn-to-Real Generalization (CSG)
4 | ---
5 |
6 | 1. Definitions
7 |
8 | "Licensor" means any person or entity that distributes its Work.
9 |
10 | "Software" means the original work of authorship made available under this License.
11 |
12 | "Work" means the Software and any additions to or derivative works of the Software that are made available under this License.
13 |
14 | The terms "reproduce," "reproduction," "derivative works," and "distribution" have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
15 |
16 | Works, including the Software, are "made available" under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.
17 |
18 | 2. License Grant
19 |
20 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
21 |
22 | 3. Limitations
23 |
24 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
25 |
26 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work ("Your Terms") only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
27 |
28 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, "non-commercially" means for research or evaluation purposes only.
29 |
30 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately.
31 |
32 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.
33 |
34 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately.
35 |
36 | 4. Disclaimer of Warranty.
37 |
38 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
39 |
40 | 5. Limitation of Liability.
41 |
42 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
43 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # CSG: Contrastive Syn-to-Real Generalization
4 |
5 |
6 | [Paper](https://arxiv.org/abs/2104.02290)
7 |
8 | Contrastive Syn-to-Real Generalization.
9 | [Wuyang Chen](https://chenwydj.github.io/), [Zhiding Yu](https://chrisding.github.io/), [Shalini De Mello](https://research.nvidia.com/person/shalini-gupta), [Sifei Liu](https://www.sifeiliu.net/), [Jose M. Alvarez](https://rsu.data61.csiro.au/people/jalvarez/), [Zhangyang Wang](https://www.atlaswang.com/), [Anima Anandkumar](http://tensorlab.cms.caltech.edu/users/anima/).
10 | In ICLR 2021.
11 |
12 | * Visda-17 to COCO
13 | - [x] train resnet101 with CSG
14 | - [x] evaluation
15 | * GTA5 to Cityscapes
16 | - [x] train deeplabv2 (resnet50/resnet101) with CSG
17 | - [x] evaluation
18 |
19 | ## Usage
20 |
21 | ### Visda-17
22 | * Download [Visda-17 Dataset](http://ai.bu.edu/visda-2017/#download)
23 |
24 | #### Evaluation
25 | * Download [pretrained ResNet101 on Visda17](https://drive.google.com/file/d/1VdbrwevsYy7I5S3Wo7-S3MwrZZjj09QS/view?usp=sharing)
26 | * Put the checkpoint under `./CSG/pretrained/`
27 | * Put the code below in `train.sh`
28 | ```bash
29 | python train.py \
30 | --epochs 30 \
31 | --batch-size 32 \
32 | --lr 1e-4 \
33 | --rand_seed 0 \
34 | --csg 0.1 \
35 | --apool \
36 | --augment \
37 | --csg-stages 3.4 \
38 | --factor 0.1 \
39 | --resume pretrained/csg_res101_vista17_best.pth.tar \
40 | --evaluate
41 | ```
42 | * Run `CUDA_VISIBLE_DEVICES=0 bash train.sh`
43 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` as needed.
44 |
45 | #### Train with CSG
46 | * Put the code below in `train.sh`
47 | ```bash
48 | python train.py \
49 | --epochs 30 \
50 | --batch-size 32 \
51 | --lr 1e-4 \
52 | --rand_seed 0 \
53 | --csg 0.1 \
54 | --apool \
55 | --augment \
56 | --csg-stages 3.4 \
57 | --factor 0.1
58 | ```
59 | * Run `CUDA_VISIBLE_DEVICES=0 bash train.sh`
60 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` as needed.
61 |
62 |
63 | ### GTA5 → Cityscapes
64 | * Download [GTA5 dataset](https://download.visinf.tu-darmstadt.de/data/from_games/).
65 | * Download [leftImg8bit_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=3) and [gtFine_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=1) from the Cityscapes website.
66 | * Prepare the annotations with [createTrainIdLabelImgs.py](https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/createTrainIdLabelImgs.py).
67 | * Put the [image list files](tools/datasets/cityscapes/) in the directory where you saved the dataset.
68 | * **Remember to set `C.dataset_path` in `config_seg.py` to the directory where the datasets reside.**
69 |
70 | #### Evaluation
71 | * Download pretrained [DeepLabV2-ResNet50](https://drive.google.com/file/d/1E2CosTtGVgIe6BfLBV9vNmyj6l9aYUbk/view?usp=sharing) and [DeepLabV2-ResNet101](https://drive.google.com/file/d/17Pe86m4OCGMFLcxLl_V-1bcG5otOqdvb/view?usp=sharing) on GTA5
72 | * Put the checkpoint under `./CSG/pretrained/`
73 | * Put the code below in `train_seg.sh`
74 | ```bash
75 | python train_seg.py \
76 | --epochs 50 \
77 | --switch-model deeplab50 \
78 | --batch-size 6 \
79 | --lr 1e-3 \
80 | --num-class 19 \
81 | --gpus 0 \
82 | --factor 0.1 \
83 | --csg 75 \
84 | --apool \
85 | --csg-stages 3.4 \
86 | --chunks 8 \
87 | --augment \
88 | --evaluate \
89 | --resume pretrained/csg_res101_segmentation_best.pth.tar
90 | ```
91 | * Change `--switch-model` (`deeplab50` or `deeplab101`) and `--resume` (path to pretrained checkpoints) accordingly.
92 | * Run `CUDA_VISIBLE_DEVICES=0 bash train_seg.sh`
93 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` as needed.
94 |
95 | #### Train with CSG
96 | * Put the code below in `train_seg.sh`
97 | ```bash
98 | python train_seg.py \
99 | --epochs 50 \
100 | --switch-model deeplab50 \
101 | --batch-size 6 \
102 | --lr 1e-3 \
103 | --num-class 19 \
104 | --gpus 0 \
105 | --factor 0.1 \
106 | --csg 75 \
107 | --apool \
108 | --csg-stages 3.4 \
109 | --chunks 8 \
110 | --augment
111 | ```
112 | * Change `--switch-model` (`deeplab50` or `deeplab101`) accordingly.
113 | * Run `CUDA_VISIBLE_DEVICES=0 bash train_seg.sh`
114 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` as needed.
115 |
116 |
117 | ## Citation
118 |
119 | If you use this code for your research, please cite:
120 |
121 | ```BibTeX
122 | @inproceedings{chen2021contrastive,
123 | title={Contrastive Syn-to-Real Generalization},
124 | author={Chen, Wuyang and Yu, Zhiding and De Mello, Shalini and Liu, Sifei and Alvarez, Jose M and Wang, Zhangyang and Anandkumar, Anima},
125 | booktitle={International Conference on Learning Representations (ICLR)},
126 | year={2021}
127 | }
128 | ```
129 |
--------------------------------------------------------------------------------
/config_seg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | # encoding: utf-8
5 |
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import os
11 | import os.path as osp
12 | import sys
13 | import numpy as np
14 | from easydict import EasyDict as edict
15 |
16 | C = edict()
17 | config = C
18 | cfg = C
19 |
20 | """please config ROOT_dir and user when u first using"""
21 | C.repo_name = 'CSG'
22 | C.abs_dir = osp.realpath(".")
23 | C.this_dir = C.abs_dir.split(osp.sep)[-1]
24 | C.root_dir = C.abs_dir[:C.abs_dir.index(C.repo_name) + len(C.repo_name)]
25 |
26 | """Data Dir"""
27 | # C.dataset_path = "/raid/"
28 | C.dataset_path = "/home/chenwy/"
29 |
30 | C.train_img_root = os.path.join(C.dataset_path, "gta5")
31 | C.train_gt_root = os.path.join(C.dataset_path, "gta5")
32 | C.val_img_root = os.path.join(C.dataset_path, "cityscapes")
33 | C.val_gt_root = os.path.join(C.dataset_path, "cityscapes")
34 | C.test_img_root = os.path.join(C.dataset_path, "cityscapes")
35 | C.test_gt_root = os.path.join(C.dataset_path, "cityscapes")
36 |
37 | C.train_source = osp.join(C.train_img_root, "gta5_train.txt")
38 | C.train_target_source = osp.join(C.train_img_root, "cityscapes_train_fine.txt")
39 | C.eval_source = osp.join(C.val_img_root, "cityscapes_val_fine.txt")
40 | C.test_source = osp.join(C.test_img_root, "cityscapes_test.txt")
41 |
42 | """Image Config"""
43 | C.num_classes = 19
44 | C.background = -1
45 | C.image_mean = np.array([0.485, 0.456, 0.406])
46 | C.image_std = np.array([0.229, 0.224, 0.225])
47 | C.down_sampling_train = [1, 1] # first down_sampling then crop
48 | C.down_sampling_val = [1, 1] # first down_sampling then crop
49 | C.gt_down_sampling = 1
50 | C.num_train_imgs = 12403
51 | C.num_eval_imgs = 500
52 |
53 | """ Settings for network, this would be different for each kind of model"""
54 | C.bn_eps = 1e-5
55 | C.bn_momentum = 0.1
56 |
57 | """Train Config"""
58 | C.lr = 0.01
59 | C.momentum = 0.9
60 | C.weight_decay = 5e-4
61 | C.nepochs = 30
62 | C.niters_per_epoch = 2000
63 | C.num_workers = 16
64 | C.train_scale_array = [0.75, 1, 1.25]
65 |
66 | """Eval Config"""
67 | C.eval_stride_rate = 5 / 6
68 | C.eval_scale_array = [1]
69 | C.eval_flip = True
70 | C.eval_base_size = 1024
71 | C.eval_crop_size = 1024
72 | C.eval_height = 1024
73 | C.eval_width = 2048
74 |
75 | # GTA5: 1052x1914
76 | C.image_height = 512
77 | C.image_width = 512
78 | C.is_test = False
79 | C.is_eval = False
80 |
--------------------------------------------------------------------------------
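
A minimal sketch (not part of the repo) of how this shared config object is consumed downstream; `train_seg.py` and `dataloader_seg.py` import it the same way. It must run from inside the repo, since `root_dir` is derived from the working directory, and the printed paths depend on your `C.dataset_path`.

```python
# Hypothetical usage sketch: read fields off the shared EasyDict config.
from config_seg import config  # C, config, and cfg all alias one EasyDict

print(config.num_classes)                       # 19
print(config.train_source)                      # <dataset_path>/gta5/gta5_train.txt
print(config.image_height, config.image_width)  # 512 512
```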
/data/BaseDataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | import cv2
6 | import torch
7 | import numpy as np
8 | from random import shuffle
9 | from pdb import set_trace as bp
10 | import torch.utils.data as data
11 | cv2.setNumThreads(0)
12 |
13 |
14 | class BaseDataset(data.Dataset):
15 | def __init__(self, setting, split_name, preprocess=None, file_length=None):
16 | super(BaseDataset, self).__init__()
17 | self._split_name = split_name
18 | if split_name == 'train':
19 | self._img_path = setting['train_img_root']
20 | self._gt_path = setting['train_gt_root']
21 | elif split_name == 'val':
22 | self._img_path = setting['val_img_root']
23 | self._gt_path = setting['val_gt_root']
24 | elif split_name == 'test':
25 | self._img_path = setting['test_img_root']
26 | self._gt_path = setting['test_gt_root']
27 | self._train_source = setting['train_source']
28 | self._eval_source = setting['eval_source']
29 | self._test_source = setting['test_source'] if 'test_source' in setting else setting['eval_source']
30 | self._down_sampling = setting['down_sampling_train'] if split_name == 'train' else setting['down_sampling_val']
31 | print("using downsampling:", self._down_sampling)
32 | self._file_names = self._get_file_names(split_name)
33 | print("Found %d images"%len(self._file_names))
34 | self._file_length = file_length
35 | self.preprocess = preprocess
36 |
37 | def __len__(self):
38 | if self._file_length is not None:
39 | return self._file_length
40 | return len(self._file_names)
41 |
42 | def __getitem__(self, index):
43 | if self._file_length is not None:
44 | names = self._construct_new_file_names(self._file_length)[index]
45 | else:
46 | names = self._file_names[index]
47 | img_path = os.path.join(self._img_path, names[0])
48 | gt_path = os.path.join(self._gt_path, names[1])
49 | item_name = names[1].split("/")[-1].split(".")[0]
50 | img, gt = self._fetch_data(img_path, gt_path)
51 | img = img[:, :, ::-1]
52 | if self.preprocess is not None:
53 | img, gt, extra_dict = self.preprocess(img, gt)
54 |
55 | if self._split_name == 'train':
56 | img = torch.from_numpy(np.ascontiguousarray(img)).float()
57 | gt = torch.from_numpy(np.ascontiguousarray(gt)).long()
58 | if self.preprocess is not None and extra_dict is not None:
59 | for k, v in extra_dict.items():
60 | extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v))
61 | if 'label' in k:
62 | extra_dict[k] = extra_dict[k].long()
63 | if 'img' in k:
64 | extra_dict[k] = extra_dict[k].float()
65 |
66 | output_dict = dict(data=img, label=gt, fn=str(item_name), n=len(self._file_names))
67 | if self.preprocess is not None and extra_dict is not None:
68 | output_dict.update(**extra_dict)
69 |
70 | return output_dict
71 |
72 | def _fetch_data(self, img_path, gt_path, dtype=None):
73 | img = self._open_image(img_path, down_sampling=self._down_sampling[0])
74 | gt = self._open_image(gt_path, cv2.IMREAD_GRAYSCALE, dtype=dtype, down_sampling=self._down_sampling[1])
75 |
76 | return img, gt
77 |
78 | def _get_file_names(self, split_name):
79 | assert split_name in ['train', 'val', 'test']
80 | source = self._train_source
81 | if split_name == "val":
82 | source = self._eval_source
83 | elif split_name == 'test':
84 | source = self._test_source
85 |
86 | file_names = []
87 | with open(source) as f:
88 | files = f.readlines()
89 |
90 | for item in files:
91 | img_name, gt_name = self._process_item_names(item)
92 | file_names.append([img_name, gt_name])
93 |
94 | return file_names
95 |
96 | def _construct_new_file_names(self, length):
97 | assert isinstance(length, int)
98 | files_len = len(self._file_names)
99 | new_file_names = self._file_names * (length // files_len)
100 |
101 | rand_indices = torch.randperm(files_len).tolist()
102 | new_indices = rand_indices[:length % files_len]
103 |
104 | new_file_names += [self._file_names[i] for i in new_indices]
105 |
106 | return new_file_names
107 |
108 | @staticmethod
109 | def _process_item_names(item):
110 | item = item.strip()
111 | # item = item.split('\t')
112 | item = item.split(' ')
113 | img_name = item[0]
114 | gt_name = item[1]
115 |
116 | return img_name, gt_name
117 |
118 | def get_length(self):
119 | return self.__len__()
120 |
121 | @staticmethod
122 | def _open_image(filepath, mode=cv2.IMREAD_COLOR, dtype=None, down_sampling=1):
123 | # cv2: B G R
124 | # h w c
125 | img = np.array(cv2.imread(filepath, mode), dtype=dtype)
126 | if isinstance(down_sampling, int):
127 | try:
128 | H, W = img.shape[:2]
129 | except Exception:
130 | print(img.shape, filepath)
131 | exit(0)
132 | if len(img.shape) == 3:
133 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_LINEAR)
134 | else:
135 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_NEAREST)
136 | assert img.shape[0] == H // down_sampling and img.shape[1] == W // down_sampling
137 | else:
138 | assert (isinstance(down_sampling, tuple) or isinstance(down_sampling, list)) and len(down_sampling) == 2
139 | if len(img.shape) == 3:
140 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_LINEAR)
141 | else:
142 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_NEAREST)
143 | assert img.shape[0] == down_sampling[0] and img.shape[1] == down_sampling[1]
144 |
145 | return img
146 |
147 | @classmethod
148 | def get_class_colors(*args):
149 | raise NotImplementedError
150 |
151 | @classmethod
152 | def get_class_names(*args):
153 | raise NotImplementedError
154 |
155 |
156 | if __name__ == "__main__":
157 | data_setting = {'train_img_root': '',
158 | 'train_gt_root': '',
159 | 'train_source': '',
160 | 'eval_source': '', 'down_sampling_train': [1, 1]}
161 | bd = BaseDataset(data_setting, 'train', None)
162 | print(bd.get_length())
163 |
--------------------------------------------------------------------------------
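
A minimal usage sketch (assumed, not from the repo) showing the `data_setting` keys that `BaseDataset.__init__` actually reads, driven through a concrete subclass such as `data.cityscapes.Cityscapes`; all paths are placeholders for your local copies.

```python
# Hypothetical wiring; paths are placeholders.
from data.cityscapes import Cityscapes

data_setting = {
    'val_img_root': '/path/to/cityscapes',
    'val_gt_root': '/path/to/cityscapes',
    'train_source': '/path/to/gta5/gta5_train.txt',
    'eval_source': '/path/to/cityscapes/cityscapes_val_fine.txt',
    'down_sampling_val': [1, 1],  # read because split_name == 'val'
}
dataset = Cityscapes(data_setting, 'val', preprocess=None)
sample = dataset[0]  # dict with 'data' (HxWxC RGB), 'label', 'fn', 'n'
```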
/data/cityscapes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import numpy as np
5 | from data.BaseDataset import BaseDataset
6 |
7 | class Cityscapes(BaseDataset):
8 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
9 | 28, 31, 32, 33]
10 |
11 | @classmethod
12 | def get_class_colors(*args):
13 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70],
14 | [102, 102, 156], [190, 153, 153], [153, 153, 153],
15 | [250, 170, 30], [220, 220, 0], [107, 142, 35],
16 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
17 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
18 | [0, 0, 230], [119, 11, 32]]
19 |
20 | @classmethod
21 | def get_class_names(*args):
22 | # class counting(gtFine)
23 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832
24 | # 359 274 142 513 1646
25 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
26 | 'traffic light', 'traffic sign',
27 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
28 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle']
29 |
30 | @classmethod
31 | def transform_label(cls, pred, name):
32 | label = np.zeros(pred.shape)
33 | ids = np.unique(pred)
34 | for id in ids:
35 | label[np.where(pred == id)] = cls.trans_labels[id]
36 |
37 | new_name = (name.split('.')[0]).split('_')[:-1]
38 | new_name = '_'.join(new_name) + '.png'
39 |
40 | print('Trans', name, 'to', new_name, ' ',
41 | np.unique(np.array(pred, np.uint8)), ' ---------> ',
42 | np.unique(np.array(label, np.uint8)))
43 | return label, new_name
44 |
--------------------------------------------------------------------------------
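
A toy illustration (mine, not repo code) of `transform_label`: it maps the contiguous train IDs (0-18) predicted by the network back to the official Cityscapes label IDs via `trans_labels`, and strips the trailing suffix from the file name.

```python
import numpy as np
from data.cityscapes import Cityscapes

pred = np.array([[0, 1], [13, 18]])  # trainIds: road, sidewalk, car, bicycle
label, new_name = Cityscapes.transform_label(pred, 'frankfurt_000000_000294_pred.png')
print(label)     # [[ 7.  8.] [26. 33.]] -- official label IDs
print(new_name)  # frankfurt_000000_000294.png
```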
/data/gta5.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import numpy as np
5 | from data.BaseDataset import BaseDataset
6 |
7 | class GTA5(BaseDataset):
8 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
9 | 28, 31, 32, 33]
10 |
11 | @classmethod
12 | def get_class_colors(*args):
13 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70],
14 | [102, 102, 156], [190, 153, 153], [153, 153, 153],
15 | [250, 170, 30], [220, 220, 0], [107, 142, 35],
16 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
17 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
18 | [0, 0, 230], [119, 11, 32]]
19 |
20 | @classmethod
21 | def get_class_names(*args):
22 | # class counting(gtFine)
23 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832
24 | # 359 274 142 513 1646
25 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
26 | 'traffic light', 'traffic sign',
27 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
28 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle']
29 |
30 | @classmethod
31 | def transform_label(cls, pred, name):
32 | label = np.zeros(pred.shape)
33 | ids = np.unique(pred)
34 | for id in ids:
35 | label[np.where(pred == id)] = cls.trans_labels[id]
36 |
37 | new_name = (name.split('.')[0]).split('_')[:-1]
38 | new_name = '_'.join(new_name) + '.png'
39 |
40 | print('Trans', name, 'to', new_name, ' ',
41 | np.unique(np.array(pred, np.uint8)), ' ---------> ',
42 | np.unique(np.array(label, np.uint8)))
43 | return label, new_name
44 |
--------------------------------------------------------------------------------
/data/labels.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | #!/usr/bin/python
5 | #
6 | # Cityscapes labels
7 | #
8 |
9 | from collections import namedtuple
10 |
11 |
12 | #--------------------------------------------------------------------------------
13 | # Definitions
14 | #--------------------------------------------------------------------------------
15 |
16 | # a label and all meta information
17 | Label = namedtuple( 'Label' , [
18 |
19 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
20 | # We use them to uniquely name a class
21 |
22 | 'id' , # An integer ID that is associated with this label.
23 | # The IDs are used to represent the label in ground truth images
24 | # An ID of -1 means that this label does not have an ID and thus
25 | # is ignored when creating ground truth images (e.g. license plate).
26 | # Do not modify these IDs, since exactly these IDs are expected by the
27 | # evaluation server.
28 |
29 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
30 | # ground truth images with train IDs, using the tools provided in the
31 | # 'preparation' folder. However, make sure to validate or submit results
32 | # to our evaluation server using the regular IDs above!
33 | # For trainIds, multiple labels might have the same ID. Then, these labels
34 | # are mapped to the same class in the ground truth images. For the inverse
35 | # mapping, we use the label that is defined first in the list below.
36 | # For example, mapping all void-type classes to the same ID in training,
37 | # might make sense for some approaches.
38 | # Max value is 255!
39 |
40 | 'category' , # The name of the category that this label belongs to
41 |
42 | 'categoryId' , # The ID of this category. Used to create ground truth images
43 | # on category level.
44 |
45 | 'hasInstances', # Whether this label distinguishes between single instances or not
46 |
47 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
48 | # during evaluations or not
49 |
50 | 'color' , # The color of this label
51 | ] )
52 |
53 |
54 | #--------------------------------------------------------------------------------
55 | # A list of all labels
56 | #--------------------------------------------------------------------------------
57 |
58 | # Please adapt the train IDs as appropriate for your approach.
59 | # Note that you might want to ignore labels with ID 255 during training.
60 | # Further note that the current train IDs are only a suggestion. You can use whatever you like.
61 | # Make sure to provide your results using the original IDs and not the training IDs.
62 | # Note that many IDs are ignored in evaluation and thus you never need to predict these!
63 |
64 | labels = [
65 | # name id trainId category catId hasInstances ignoreInEval color
66 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
67 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
68 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
69 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
70 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
71 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
72 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
73 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
74 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
75 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
76 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
77 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
78 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
79 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
80 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
81 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
82 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
83 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
84 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
85 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
86 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
87 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
88 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
89 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
90 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
91 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
92 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
93 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
94 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
95 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
96 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
97 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
98 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
99 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
100 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
101 | ]
102 |
103 |
104 | #--------------------------------------------------------------------------------
105 | # Create dictionaries for a fast lookup
106 | #--------------------------------------------------------------------------------
107 |
108 | # Please refer to the main method below for example usages!
109 |
110 | # name to label object
111 | name2label = { label.name : label for label in labels }
112 | # id to label object
113 | id2label = { label.id : label for label in labels }
114 | # trainId to label object
115 | trainId2label = { label.trainId : label for label in reversed(labels) }
116 | # category to list of label objects
117 | category2labels = {}
118 | for label in labels:
119 | category = label.category
120 | if category in category2labels:
121 | category2labels[category].append(label)
122 | else:
123 | category2labels[category] = [label]
124 |
125 | #--------------------------------------------------------------------------------
126 | # Assure single instance name
127 | #--------------------------------------------------------------------------------
128 |
129 | # returns the label name that describes a single instance (if possible)
130 | # e.g. input | output
131 | # ----------------------
132 | # car | car
133 | # cargroup | car
134 | # foo | None
135 | # foogroup | None
136 | # skygroup | None
137 | def assureSingleInstanceName( name ):
138 | # if the name is known, it is not a group
139 | if name in name2label:
140 | return name
141 | # test if the name actually denotes a group
142 | if not name.endswith("group"):
143 | return None
144 | # remove group
145 | name = name[:-len("group")]
146 | # test if the new name exists
147 | if not name in name2label:
148 | return None
149 | # test if the new name denotes a label that actually has instances
150 | if not name2label[name].hasInstances:
151 | return None
152 | # all good then
153 | return name
154 |
155 | #--------------------------------------------------------------------------------
156 | # Main for testing
157 | #--------------------------------------------------------------------------------
158 |
159 | # just a dummy main
160 | if __name__ == "__main__":
161 | # Print all the labels
162 | print("List of cityscapes labels:")
163 | print("")
164 | print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' ))
165 | print(" " + ('-' * 98))
166 | for label in labels:
167 | print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval ))
168 | print("")
169 |
170 | print("Example usages:")
171 |
172 | # Map from name to label
173 | name = 'car'
174 | id = name2label[name].id
175 | print("ID of label '{name}': {id}".format( name=name, id=id ))
176 |
177 | # Map from ID to label
178 | category = id2label[id].category
179 | print("Category of label with ID '{id}': {category}".format( id=id, category=category ))
180 |
181 | # Map from trainID to label
182 | trainId = 0
183 | name = trainId2label[trainId].name
184 | print("Name of label with trainID '{id}': {name}".format( id=trainId, name=name ))
185 |
--------------------------------------------------------------------------------
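
A common companion helper (a sketch; `labels.py` itself does not define it) built from the lookup tables above: remap a ground-truth image from official label IDs to the 19-class train IDs, sending everything else to the ignore index 255.

```python
import numpy as np
from data.labels import labels

# Official ID -> trainId (255 marks classes ignored in evaluation).
id_to_trainid = {label.id: label.trainId for label in labels}

def encode_trainids(gt):
    """gt: HxW int array of official label IDs -> HxW array of trainIds."""
    out = np.full(gt.shape, 255, dtype=np.int64)
    for label_id, train_id in id_to_trainid.items():
        out[gt == label_id] = train_id
    return out
```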
/data/loader_csg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import warnings
5 | from PIL import ImageFilter, Image
6 | import math
7 | import random
8 | import torch
9 | from torchvision.transforms import functional as F
10 |
11 | _pil_interpolation_to_str = {
12 | Image.NEAREST: 'PIL.Image.NEAREST',
13 | Image.BILINEAR: 'PIL.Image.BILINEAR',
14 | Image.BICUBIC: 'PIL.Image.BICUBIC',
15 | Image.LANCZOS: 'PIL.Image.LANCZOS',
16 | Image.HAMMING: 'PIL.Image.HAMMING',
17 | Image.BOX: 'PIL.Image.BOX',
18 | }
19 |
20 |
21 | def _get_image_size(img):
22 | if F._is_pil_image(img):
23 | return img.size
24 | elif isinstance(img, torch.Tensor) and img.dim() > 2:
25 | return img.shape[-2:][::-1]
26 | else:
27 | raise TypeError("Unexpected type {}".format(type(img)))
28 |
29 |
30 | class RandomResizedCrop_two(object):
31 | # generate two closely located patches
32 | """Crop the given PIL Image to random size and aspect ratio.
33 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random
34 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
35 | is finally resized to given size.
36 | This is popularly used to train the Inception networks.
37 | Args:
38 | size: expected output size of each edge
39 | scale: range of size of the origin size cropped
40 | ratio: range of aspect ratio of the origin aspect ratio cropped
41 | interpolation: Default: PIL.Image.BILINEAR
42 | """
43 |
44 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
45 | if isinstance(size, (tuple, list)):
46 | self.size = size
47 | else:
48 | self.size = (size, size)
49 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
50 | warnings.warn("range should be of kind (min, max)")
51 |
52 | self.interpolation = interpolation
53 | self.scale = scale
54 | self.ratio = ratio
55 |
56 | @staticmethod
57 | def get_params(img, scale, ratio, augment=(0.025, 0.075)):
58 | """Get parameters for ``crop`` for a random sized crop.
59 | Args:
60 | img (PIL Image): Image to be cropped.
61 | scale (tuple): range of size of the origin size cropped
62 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
63 | Returns:
64 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random
65 | sized crop.
66 | """
67 | width, height = _get_image_size(img)
68 | area = height * width
69 |
70 | for _ in range(10):
71 | target_area = random.uniform(*scale) * area
72 | log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
73 | aspect_ratio = math.exp(random.uniform(*log_ratio))
74 |
75 | w = int(round(math.sqrt(target_area * aspect_ratio)))
76 | h = int(round(math.sqrt(target_area / aspect_ratio)))
77 |
78 | if 0 < w <= width and 0 < h <= height:
79 | i = random.randint(0, height - h)
80 | j = random.randint(0, width - w)
81 | # return i, j, h, w
82 | ##### augment #####
83 | delta = random.randint(int(h*augment[0]), int(h*augment[1])) * random.choice([-1, 1])
84 | h_a = h + delta; h_a = min(max(1, h_a), height)
85 | delta = random.randint(int(w*augment[0]), int(w*augment[1])) * random.choice([-1, 1])
86 | w_a = w + delta; w_a = min(max(1, w_a), width)
87 | delta = random.randint(int(h*augment[0]), int(h*augment[1])) * random.choice([-1, 1])
88 | i_a = i + delta; i_a = min(max(0, i_a), height - h_a)
89 | delta = random.randint(int(w*augment[0]), int(w*augment[1])) * random.choice([-1, 1])
90 | j_a = j + delta; j_a = min(max(0, j_a), width - w_a)
91 | ###################
92 | return i, j, h, w, i_a, j_a, h_a, w_a
93 |
94 | # Fallback to central crop
95 | in_ratio = float(width) / float(height)
96 | if (in_ratio < min(ratio)):
97 | w = width
98 | h = int(round(w / min(ratio)))
99 | elif (in_ratio > max(ratio)):
100 | h = height
101 | w = int(round(h * max(ratio)))
102 | else: # whole image
103 | w = width
104 | h = height
105 | i = (height - h) // 2
106 | j = (width - w) // 2
107 | ##### augment #####
108 | delta = random.randint(int(h*augment[0]), int(h*augment[1])) * random.choice([-1, 1])
109 | h_a = h + delta; h_a = min(max(1, h_a), height)
110 | delta = random.randint(int(w*augment[0]), int(w*augment[1])) * random.choice([-1, 1])
111 | w_a = w + delta; w_a = min(max(1, w_a), width)
112 | delta = random.randint(int(h*augment[0]), int(h*augment[1])) * random.choice([-1, 1])
113 | i_a = i + delta; i_a = min(max(0, i_a), height - h_a)
114 | delta = random.randint(int(w*augment[0]), int(w*augment[1])) * random.choice([-1, 1])
115 | j_a = j + delta; j_a = min(max(0, j_a), width - w_a)
116 | ###################
117 | return i, j, h, w, i_a, j_a, h_a, w_a
118 |
119 | def __call__(self, img):
120 | """
121 | Args:
122 | img (PIL Image): Image to be cropped and resized.
123 | Returns:
124 | PIL Image: Randomly cropped and resized image.
125 | """
126 | i, j, h, w, i_a, j_a, h_a, w_a = self.get_params(img, self.scale, self.ratio)
127 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), F.resized_crop(img, i_a, j_a, h_a, w_a, self.size, self.interpolation)
128 |
129 | def __repr__(self):
130 | interpolate_str = _pil_interpolation_to_str[self.interpolation]
131 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
132 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
133 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
134 | format_string += ', interpolation={0})'.format(interpolate_str)
135 | return format_string
136 |
137 |
138 | class ImageTransform:
139 | """return both image and tensor"""
140 |
141 | def __init__(self, transform):
142 | self.base_transform = transform[0] # resize, centercrop
143 | self.totensor_norm = transform[1] # totensor, **normalize**
144 |
145 | def __call__(self, x):
146 | image = self.base_transform(x)
147 | tensor = self.totensor_norm(image)
148 | return [tensor, F.to_tensor(image)]
149 |
150 |
151 | class TwoCropsTransform:
152 | """Take two random crops of one image as the query and key."""
153 |
154 | def __init__(self, q_transform, k_transform):
155 | self.q_transform = q_transform
156 | self.k_transform = k_transform
157 |
158 | def __call__(self, x):
159 | q = self.q_transform(x)
160 | k = self.k_transform(x)
161 | return [q, k]
162 |
163 |
164 | class GaussianBlur(object):
165 | def __init__(self, sigma=[.1, 2.]):
166 | self.sigma = sigma
167 |
168 | def __call__(self, x):
169 | sigma = random.uniform(self.sigma[0], self.sigma[1])
170 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
171 | return x
172 |
--------------------------------------------------------------------------------
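
For orientation, a sketch of assembling a MoCo-v2-style two-view pipeline from the pieces in this file. The actual transforms used for training live in `train.py` (not shown in this section), so the parameters here are illustrative only.

```python
import torchvision.transforms as transforms
from data.loader_csg import TwoCropsTransform, GaussianBlur

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
aug = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Independent query/key views of one image, as the CSG encoder pair expects.
two_crop = TwoCropsTransform(q_transform=aug, k_transform=aug)
# q, k = two_crop(pil_image)
```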
/data/visda17.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | from PIL import Image
6 | import numpy as np
7 | import random
8 | import torch
9 | from torch.utils.data import Dataset
10 | import torchvision.transforms as transforms
11 | from pdb import set_trace as bp
12 |
13 |
14 | class VisDA17(Dataset):
15 |
16 | def __init__(self, txt_file, root_dir, transform=transforms.ToTensor(), label_one_hot=False, portion=1.0):
17 | """
18 | Args:
19 | txt_file (string): Path to the txt file with annotations.
20 | root_dir (string): Directory with all the images.
21 | transform (callable, optional): Optional transform to be applied
22 | on a sample.
23 | """
24 | self.lines = open(txt_file, 'r').readlines()
25 | self.root_dir = root_dir
26 | self.transform = transform
27 | self.label_one_hot = label_one_hot
28 | self.portion = portion
29 | self.number_classes = 12
30 | assert portion != 0
31 | if self.portion > 0:
32 | self.lines = self.lines[:round(self.portion * len(self.lines))]
33 | else:
34 | self.lines = self.lines[round(self.portion * len(self.lines)):]
35 |
36 | def __len__(self):
37 | return len(self.lines)
38 |
39 | def __getitem__(self, idx):
40 | line = str.split(self.lines[idx])
41 | path_img = os.path.join(self.root_dir, line[0])
42 | image = Image.open(path_img)
43 | image = image.convert('RGB')
44 | if self.transform:
45 | image = self.transform(image)
46 | if self.label_one_hot:
47 | label = np.zeros(12, np.float32)
48 | label[np.asarray(line[1], dtype=int)] = 1
49 | else:
50 | label = np.asarray(line[1], dtype=int)
51 | label = torch.from_numpy(label)
52 | return image, label
53 |
--------------------------------------------------------------------------------
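
A minimal DataLoader sketch (the list-file name and paths are placeholders, not the dataset's actual layout):

```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from data.visda17 import VisDA17

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
val_set = VisDA17('/path/to/visda17/validation/image_list.txt',
                  '/path/to/visda17/validation', transform=transform)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4)
images, labels = next(iter(val_loader))  # images: 32x3x224x224, labels: 32
```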
/dataloader_seg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import cv2
5 | from torch.utils import data
6 | from PIL import Image
7 | import numpy as np
8 | from tools.utils.img_utils import random_scale, random_mirror, normalize, generate_random_crop_pos, random_crop_pad_to_shape
9 | import torchvision.transforms as transforms
10 | cv2.setNumThreads(0)
11 |
12 |
13 | class TrainPre(object):
14 | def __init__(self, config, img_mean, img_std, augment=None):
15 | self.img_mean = img_mean
16 | self.img_std = img_std
17 | self.config = config
18 | self.augment = augment
19 |
20 | # we have func normalize below; return npy
21 | if augment:
22 | self.data_transforms = transforms.Compose([
23 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
24 | ])
25 |
26 | def __call__(self, img, gt):
27 | img, gt = random_mirror(img, gt)
28 | if self.config.train_scale_array is not None:
29 | img, gt, scale = random_scale(img, gt, self.config.train_scale_array)
30 |
31 | crop_size = (self.config.image_height, self.config.image_width)
32 | crop_pos = generate_random_crop_pos(img.shape[:2], crop_size)
33 | if self.augment:
34 | p_img, _ = random_crop_pad_to_shape(normalize(img, self.img_mean, self.img_std), crop_pos, crop_size, 0)
35 | p_img_k, _ = random_crop_pad_to_shape(normalize(np.array(
36 | self.data_transforms(Image.fromarray(img))
37 | ), self.img_mean, self.img_std), crop_pos, crop_size, 0)
38 | p_img = p_img.transpose(2, 0, 1)
39 | p_img_k = p_img_k.transpose(2, 0, 1)
40 | extra_dict = {'img_k': p_img_k}
41 | else:
42 | p_img, _ = random_crop_pad_to_shape(normalize(img, self.img_mean, self.img_std), crop_pos, crop_size, 0)
43 | p_img = p_img.transpose(2, 0, 1)
44 | extra_dict = None
45 | p_gt, _ = random_crop_pad_to_shape(gt, crop_pos, crop_size, 255)
46 | p_gt = cv2.resize(p_gt, (self.config.image_width // self.config.gt_down_sampling, self.config.image_height // self.config.gt_down_sampling), interpolation=cv2.INTER_NEAREST)
47 |
48 | return p_img, p_gt, extra_dict
49 |
50 |
51 | def get_train_loader(config, dataset, worker=None, test=False, augment=None):
52 | data_setting = {
53 | 'train_img_root': config.train_img_root,
54 | 'train_gt_root': config.train_gt_root,
55 | 'val_img_root': config.val_img_root,
56 | 'val_gt_root': config.val_gt_root,
57 | 'train_source': config.train_source,
58 | 'eval_source': config.eval_source,
59 | 'down_sampling_train': config.down_sampling_train
60 | }
61 | if test:
62 | data_setting = {'img_root': config.img_root,
63 | 'gt_root': config.gt_root,
64 | 'train_source': config.train_eval_source,
65 | 'eval_source': config.eval_source}
66 | train_preprocess = TrainPre(config, config.image_mean, config.image_std, augment)
67 |
68 | train_dataset = dataset(data_setting, "train", train_preprocess, config.batch_size * config.niters_per_epoch)
69 |
70 | is_shuffle = True
71 | batch_size = config.batch_size
72 |
73 | train_loader = data.DataLoader(train_dataset,
74 | batch_size=batch_size,
75 | num_workers=config.num_workers if worker is None else worker,
76 | drop_last=True,
77 | shuffle=is_shuffle,
78 | pin_memory=True,
79 | )
80 |
81 | return train_loader
82 |
--------------------------------------------------------------------------------
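
How this is typically invoked (a sketch following `train_seg.py`'s pattern; note that `batch_size` is not defined in `config_seg.py` and is normally attached from the CLI arguments before the call):

```python
from config_seg import config
from data.gta5 import GTA5
from dataloader_seg import get_train_loader

config.batch_size = 6  # normally injected from parsed args
train_loader = get_train_loader(config, GTA5, augment=True)
batch = next(iter(train_loader))
# batch['data']: Bx3xHxW source image; batch['label']: BxHxW trainId map;
# with augment=True a color-jittered key view arrives as batch['img_k'].
```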
/eval_seg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | #!/usr/bin/env python3
5 | # encoding: utf-8
6 | import os
7 | import cv2
8 | import numpy as np
9 | from pdb import set_trace as bp
10 | from tools.utils.visualize import print_iou, show_img, show_prediction
11 | from tools.engine.evaluator import Evaluator
12 | from tools.engine.logger import get_logger
13 | from tools.seg_opr.metric import hist_info, compute_score
14 |
15 | cv2.setNumThreads(0)
16 | logger = get_logger()
17 |
18 |
19 | class SegEvaluator(Evaluator):
20 | def func_per_iteration(self, data, device, iter=None):
21 | config = self.config
22 | img = data['data']
23 | label = data['label']
24 | name = data['fn']
25 |
26 | if len(config.eval_scale_array) == 1:
27 | pred = self.whole_eval(img, label.shape, resize=config.eval_scale_array[0], device=device)
28 | pred = pred.argmax(2) # since we ignore this step in evaluator.py
29 | elif len(config.eval_scale_array) > 1:
30 | pred = self.whole_eval(img, label.shape, resize=config.eval_scale_array[0], device=device)
31 | for scale in config.eval_scale_array[1:]:
32 | pred += self.whole_eval(img, label.shape, resize=scale, device=device)
33 | pred = pred.argmax(2) # since we ignore this step in evaluator.py
34 | else:
35 | pred = self.sliding_eval(img, config.eval_crop_size, config.eval_stride_rate, device)
36 | hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes, pred, label)
37 | results_dict = {'hist': hist_tmp, 'labeled': labeled_tmp, 'correct': correct_tmp}
38 |
39 | if self.save_path is not None:
40 | fn = name + '.png'
41 | cv2.imwrite(os.path.join(self.save_path, fn), pred)
42 | logger.info('Save the image ' + fn)
43 |
44 | # tensorboard logger does not fit multiprocess
45 | if self.logger is not None and iter is not None:
46 | colors = self.dataset.get_class_colors()
47 | image = img
48 | clean = np.zeros(label.shape)
49 | comp_img = show_img(colors, config.background, image, clean, label, pred)
50 | self.logger.add_image('vis', np.swapaxes(np.swapaxes(comp_img, 0, 2), 1, 2), iter)
51 |
52 | if self.show_image or self.show_prediction:
53 | colors = self.dataset.get_class_colors()
54 | image = img
55 | clean = np.zeros(label.shape)
56 | if self.show_image:
57 | comp_img = show_img(colors, config.background, image, clean, label, pred)
58 | cv2.imwrite(os.path.join(self.save_path, name + ".png"), comp_img[:,:,::-1])
59 | if self.show_prediction:
60 | comp_img = show_prediction(colors, config.background, image, pred)
61 | cv2.imwrite(os.path.join(self.save_path, "viz_"+name+".png"), comp_img[:,:,::-1])
62 |
63 | return results_dict
64 |
65 | def compute_metric(self, results):
66 | hist = np.zeros((self.config.num_classes, self.config.num_classes))
67 | correct = 0
68 | labeled = 0
69 | count = 0
70 | for d in results:
71 | hist += d['hist']
72 | correct += d['correct']
73 | labeled += d['labeled']
74 | count += 1
75 |
76 | iu, mean_IU, mean_IU_no_back, mean_pixel_acc = compute_score(hist, correct, labeled)
77 | result_line = print_iou(iu, mean_pixel_acc, self.dataset.get_class_names(), True)
78 | return result_line, mean_IU
79 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
--------------------------------------------------------------------------------
/model/csg_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import torch
5 | import torch.nn as nn
6 | from pdb import set_trace as bp
7 |
8 |
9 | def chunk_feature(feature, chunk):
10 | if chunk == 1:
11 | return feature
12 | # B x C x H x W => (B*chunk^2) x C x (H//chunk) x (W//chunk)
13 | _f_new = torch.chunk(feature, chunk, dim=2)
14 | _f_new = [torch.chunk(f, chunk, dim=3) for f in _f_new]
15 | f_new = []
16 | for f in _f_new:
17 | f_new += f
18 | f_new = torch.cat(f_new, dim=0)
19 | return f_new
20 |
21 |
22 | class CSG(nn.Module):
23 | def __init__(self, base_encoder, get_head=None, dim=128, K=65536, m=0.999, T=0.07, mlp=True, stages=[4], num_class=12, chunks=[1], task='new',
24 | base_encoder_kwargs={}, apool=True
25 | ):
26 | """
27 | dim: feature dimension (default: 128)
28 | K: queue size; number of negative keys (default: 65536)
29 | m: momentum of updating key encoder (default: 0.999)
30 | T: softmax temperature (default: 0.07)
31 | """
32 | super(CSG, self).__init__()
33 |
34 | self.K = K
35 | self.m = m
36 | self.T = T
37 | self.stages = stages
38 | self.mlp = mlp
39 | self.base_encoder = base_encoder
40 | self.chunks = chunks # chunk feature (segmentation)
41 | self.task = task # new, new-seg
42 | self.attentions = [None for _ in range(len(stages))]
43 | self.apool = apool
44 |
45 | # create the encoders
46 | # num_classes is the output fc dimension
47 | self.encoder_q = base_encoder(num_classes=dim, pretrained=True, **base_encoder_kwargs) # q is for new task
48 | self.encoder_k = base_encoder(num_classes=dim, pretrained=True, **base_encoder_kwargs) # ######
49 | if get_head is not None:
50 | num_ftrs = self.encoder_q.fc.in_features
51 | self.encoder_q.fc_new = get_head(num_ftrs, num_class)
52 | for param in self.encoder_q.fc.parameters():
53 | param.requires_grad = False
54 |
55 | if mlp:
56 | fc_q = {}
57 | fc_k = {}
58 | for stage in stages:
59 | if stage > 0:
60 | try:
61 | # BottleNeck
62 | dim_mlp = getattr(self.encoder_q, "layer%d"%stage)[-1].conv3.weight.size()[0]
63 | except AttributeError:  # ModuleAttributeError was removed in newer PyTorch
64 | # BasicBlock
65 | dim_mlp = getattr(self.encoder_q, "layer%d"%stage)[-1].conv2.weight.size()[0]
66 | elif stage == 0:
67 | dim_mlp = self.encoder_q.conv1.weight.size()[0]
68 | fc_q["stage%d"%(stage)] = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), nn.Linear(dim_mlp, dim))
69 | fc_k["stage%d"%(stage)] = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), nn.Linear(dim_mlp, dim))
70 | self.encoder_q.fc_csg = nn.ModuleDict(fc_q)
71 | self.encoder_k.fc_csg = nn.ModuleDict(fc_k)
72 | for param_q, param_k in zip(self.encoder_q.fc_csg.parameters(), self.encoder_k.fc_csg.parameters()):
73 | param_k.data.copy_(param_q.data)
74 |
75 | for param_k in self.encoder_k.parameters():
76 | param_k.requires_grad = False # not update by gradient
77 |
78 | if type(self.encoder_q).__name__ == "ResNet":
79 | try:
80 | # BottleNeck
81 | dims = [self.encoder_q.conv1.weight.size()[0]] + [getattr(self.encoder_q, "layer%d"%stage)[-1].conv3.weight.size()[0] for stage in range(1, 5)]
82 | except AttributeError:
83 | # BasicBlock
84 | dims = [self.encoder_q.conv1.weight.size()[0]] + [getattr(self.encoder_q, "layer%d"%stage)[-1].conv2.weight.size()[0] for stage in range(1, 5)]
85 | elif type(self.encoder_q).__name__ == "DigitNet":
86 | dims = [64 for stage in range(1, 5)]
87 | for stage in stages:
88 | self.register_buffer("queue%d"%(stage), torch.randn(dim, K))
89 | setattr(self, "queue%d"%(stage), nn.functional.normalize(getattr(self, "queue%d"%(stage)), dim=0))
90 | self.register_buffer("queue_ptr%d"%(stage), torch.zeros(1, dtype=torch.long))
91 |
92 | def control_q_backbone_gradient(self, control):
93 | for name, param in self.encoder_q.named_parameters():
94 | if 'fc_new' not in name:
95 | param.requires_grad = control
96 | return
97 |
98 | @torch.no_grad()
99 | def _momentum_update_key_encoder(self):
100 | """
101 | Momentum update of the key encoder
102 | """
103 | for param_q, param_k in zip(self.encoder_q.fc_csg.parameters(), self.encoder_k.fc_csg.parameters()):
104 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
105 |
106 | @torch.no_grad()
107 | def _dequeue_and_enqueue(self, keys, stage):
108 | # gather keys before updating queue
109 | keys = concat_all_gather(keys)
110 |
111 | batch_size = keys.shape[0]
112 |
113 | ptr = int(getattr(self, "queue_ptr%d"%(stage)))
114 |
115 | if ptr + batch_size <= self.K:
116 | getattr(self, "queue%d"%(stage))[:, ptr:ptr + batch_size] = keys.T
117 | else:
118 | getattr(self, "queue%d"%(stage))[:, ptr:] = keys[:(self.K - ptr)].T
119 | getattr(self, "queue%d"%(stage))[:, :ptr + batch_size - self.K] = keys[:(ptr + batch_size - self.K)].T
120 | ptr = (ptr + batch_size) % self.K # move pointer
121 | getattr(self, "queue_ptr%d"%(stage))[0] = ptr
122 |
123 | def adaptive_pool(self, features, attn_from, stage_idx):
124 | # features and attn_from are paired feature maps, of same size
125 | assert features.size() == attn_from.size()
126 | N, C, H, W = features.size()
127 | assert (attn_from >= 0).float().sum() == N*C*H*W
128 | attention = torch.einsum('nchw,nc->nhw', [attn_from, nn.functional.adaptive_avg_pool2d(attn_from, (1, 1)).view(N, C)])
129 | attention = attention / attention.view(N, -1).sum(1).view(N, 1, 1).repeat(1, H, W)
130 | attention = attention.view(N, 1, H, W)
131 | # output size: N, C
132 | return (features * attention).view(N, C, -1).sum(2)
133 |
134 | def forward(self, im_q, im_k):
135 | """
136 | Input:
137 | im_q: a batch of query images
138 | im_k: a batch of key images
139 | Output:
140 | logits, targets
141 | """
142 | if im_k is None:
143 | im_k = im_q
144 |
145 | output, features_new = self.encoder_q(im_q, output_features=["layer%d"%stage for stage in self.stages], task=self.task)
146 | results = {'output': output}
147 |
148 | results['predictions_csg'] = []
149 | results['targets_csg'] = []
150 | # predictions: cosine b/w q and k
151 | # targets: zeros
152 | with torch.no_grad(): # no gradient to keys
153 | if self.mlp:
154 | self._momentum_update_key_encoder() # update the key encoder
155 | if self.apool:
156 | # A-Pool: prepare attention for teacher: get feature of im_k by encoder_q
157 | _, features_new_k = self.encoder_q.forward_backbone(im_k, output_features=["layer%d"%stage for stage in self.stages])
158 | _, features_old = self.encoder_k.forward_backbone(im_k, output_features=["layer%d"%stage for stage in self.stages])
159 | for idx, stage in enumerate(self.stages):
160 | chunk = self.chunks[idx]
161 | # compute query features
162 |
163 | q_feature = chunk_feature(features_new["layer%d"%stage], chunk)
164 | if self.apool:
165 | # A-Pool prepare attention for teacher: get feature of im_k by encoder_q
166 | q_feature_k = chunk_feature(features_new_k["layer%d"%stage], chunk)
167 | if self.mlp:
168 | if self.apool:
169 | q = self.encoder_q.fc_csg["stage%d"%(stage)](self.adaptive_pool(q_feature, q_feature, idx)) # A-Pool
170 | else:
171 | q = self.encoder_q.fc_csg["stage%d"%(stage)](self.encoder_q.avgpool(q_feature).view(features_new["layer%d"%stage].size(0)*chunk**2, -1))
172 | else:
173 | if self.apool:
174 | q = self.adaptive_pool(q_feature, q_feature, idx) # A-Pool
175 | else:
176 | q = self.encoder_q.avgpool(q_feature).view(features_new["layer%d"%stage].size(0)*chunk**2, -1)
177 | q = nn.functional.normalize(q, dim=1)
178 |
179 | # compute key features
180 | with torch.no_grad(): # no gradient to keys
181 | k_feature = chunk_feature(features_old["layer%d"%stage], chunk)
182 | # A-Pool #############
183 | if self.mlp:
184 | if self.apool:
185 | k = self.encoder_k.fc_csg["stage%d"%(stage)](self.adaptive_pool(k_feature, q_feature_k, idx)) # A-Pool
186 | else:
187 | k = self.encoder_k.fc_csg["stage%d"%(stage)](self.encoder_k.avgpool(k_feature).view(features_old["layer%d"%stage].size(0)*chunk**2, -1))
188 | else:
189 | if self.apool:
190 | k = self.adaptive_pool(k_feature, q_feature_k, idx) # A-Pool
191 | else:
192 | k = self.encoder_k.avgpool(k_feature).view(features_old["layer%d"%stage].size(0)*chunk**2, -1)
193 | # #####################
194 | k = nn.functional.normalize(k, dim=1)
195 |
196 | # compute logits
197 | # positive logits: Nx1
198 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
199 | # negative logits: NxK
200 | l_neg = torch.einsum('nc,ck->nk', [q, getattr(self, "queue%d"%(stage)).clone().detach()])
201 | # logits: Nx(1+K)
202 | logits = torch.cat([l_pos, l_neg], dim=1)
203 | # apply temperature
204 | logits /= self.T
205 | # labels: positive key indicators
206 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
207 | self._dequeue_and_enqueue(k, stage)
208 |
209 | results['predictions_csg'].append(logits)
210 | results['targets_csg'].append(labels)
211 |
212 | return results
213 |
214 |
215 | # utils
216 | @torch.no_grad()
217 | def concat_all_gather(tensor):
218 | """
219 | Performs all_gather operation on the provided tensors.
220 | *** Warning ***: torch.distributed.all_gather has no gradient.
221 | """
222 | try:
223 | tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
224 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
225 | except Exception:
226 | tensors_gather = [tensor]
227 |
228 | output = torch.cat(tensors_gather, dim=0)
229 | return output
230 |
--------------------------------------------------------------------------------
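A note on the memory bank above: `_dequeue_and_enqueue` treats each per-stage queue as a fixed-size ring buffer of negative keys, and the `else` branch handles the wrap-around when a batch runs past the end. A minimal standalone sketch of that pointer arithmetic (the names `queue`, `ptr`, `K` mirror the module's buffers, but the sizes here are toy values):

```python
import torch

K, dim = 8, 4                # toy queue length and feature dimension
queue = torch.randn(dim, K)  # negatives stored column-wise, as in the module
ptr = 6                      # current write position

keys = torch.randn(5, dim)   # a new batch of 5 keys (one per row)
bs = keys.shape[0]

if ptr + bs <= K:
    queue[:, ptr:ptr + bs] = keys.T
else:
    head = K - ptr                        # free slots before the end
    queue[:, ptr:] = keys[:head].T        # fill up to the end of the buffer
    queue[:, :bs - head] = keys[head:].T  # wrap the remaining keys to the front
ptr = (ptr + bs) % K                      # advance the pointer modulo K
print(ptr)  # 3
```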
/model/deeplab.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import torch.nn as nn
5 | import torch
6 | import torch.nn.functional as F
7 | import numpy as np
8 | from pdb import set_trace as bp
9 | affine_par = True
10 |
11 |
12 |
13 | def conv3x3(in_planes, out_planes, stride=1):
14 | "3x3 convolution with padding"
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=1, bias=False)
17 |
18 |
19 | class BasicBlock(nn.Module):
20 | expansion = 1
21 |
22 | def __init__(self, inplanes, planes, stride=1, downsample=None):
23 | super(BasicBlock, self).__init__()
24 | self.conv1 = conv3x3(inplanes, planes, stride)
25 | self.bn1 = nn.BatchNorm2d(planes, affine = affine_par)
26 | self.relu = nn.ReLU(inplace=True)
27 | self.conv2 = conv3x3(planes, planes)
28 | self.bn2 = nn.BatchNorm2d(planes, affine = affine_par)
29 | self.downsample = downsample
30 | self.stride = stride
31 |
32 | def forward(self, x):
33 | residual = x
34 |
35 | out = self.conv1(x)
36 | out = self.bn1(out)
37 | out = self.relu(out)
38 |
39 | out = self.conv2(out)
40 | out = self.bn2(out)
41 |
42 | if self.downsample is not None:
43 | residual = self.downsample(x)
44 |
45 | out += residual
46 | out = self.relu(out)
47 |
48 | return out
49 |
50 |
51 | class Bottleneck(nn.Module):
52 | expansion = 4
53 |
54 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
55 | super(Bottleneck, self).__init__()
56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False)
57 | self.bn1 = nn.BatchNorm2d(planes,affine = affine_par)
58 | for i in self.bn1.parameters():
59 | i.requires_grad = False
60 |
61 | padding = dilation
62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63 | padding=padding, bias=False, dilation = dilation)
64 | self.bn2 = nn.BatchNorm2d(planes,affine = affine_par)
65 | for i in self.bn2.parameters():
66 | i.requires_grad = False
67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
68 | self.bn3 = nn.BatchNorm2d(planes * 4, affine = affine_par)
69 | for i in self.bn3.parameters():
70 | i.requires_grad = False
71 | self.relu = nn.ReLU(inplace=True)
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | residual = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 |
92 | out += residual
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 |
98 | class ASPPConv(nn.Sequential):
99 | def __init__(self, in_channels, out_channels, dilation):
100 | modules = [
101 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
102 | nn.BatchNorm2d(out_channels),
103 | nn.ReLU()
104 | ]
105 | super(ASPPConv, self).__init__(*modules)
106 |
107 |
108 | class ASPPPooling(nn.Sequential):
109 | def __init__(self, in_channels, out_channels):
110 | super(ASPPPooling, self).__init__(
111 | nn.AdaptiveAvgPool2d(1),
112 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
113 | nn.BatchNorm2d(out_channels),
114 | nn.ReLU())
115 |
116 | def forward(self, x):
117 | size = x.shape[-2:]
118 | for mod in self:
119 | x = mod(x)
120 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
121 |
122 |
123 | class ASPP(nn.Module):
124 | def __init__(self, in_channels, atrous_rates, out_channels=256):
125 | super(ASPP, self).__init__()
126 | modules = []
127 | modules.append(nn.Sequential(
128 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
129 | nn.BatchNorm2d(out_channels),
130 | nn.ReLU()))
131 |
132 | rates = tuple(atrous_rates)
133 | for rate in rates:
134 | modules.append(ASPPConv(in_channels, out_channels, rate))
135 |
136 | modules.append(ASPPPooling(in_channels, out_channels))
137 |
138 | self.convs = nn.ModuleList(modules)
139 |
140 | self.project = nn.Sequential(
141 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
142 | nn.BatchNorm2d(out_channels),
143 | nn.ReLU(),
144 | nn.Dropout(0.5))
145 |
146 | def forward(self, x):
147 | res = []
148 | for conv in self.convs:
149 | res.append(conv(x))
150 | res = torch.cat(res, dim=1)
151 | return self.project(res)
152 |
153 |
154 | class Classifier_Module(nn.Module):
155 |
156 | def __init__(self, dilation_series, padding_series, num_classes):
157 | super(Classifier_Module, self).__init__()
158 | self.conv2d_list = nn.ModuleList()
159 |
160 | self.conv2d_list = nn.ModuleList([nn.Sequential(
161 | ASPP(2048, [12, 24, 36]),
162 | nn.Conv2d(256, 256, 3, padding=1, bias=False),
163 | nn.BatchNorm2d(256),
164 | nn.ReLU(),
165 | nn.Conv2d(256, num_classes, 1)
166 | )])
167 |
168 | def forward(self, x):
169 | out = self.conv2d_list[0](x)
170 | for i in range(len(self.conv2d_list)-1):
171 | out += self.conv2d_list[i+1](x)
172 | return out
173 |
174 |
175 | class ResNet(nn.Module):
176 | def __init__(self, num_classes=1000, num_seg_classes=19, pretrained=False, block=Bottleneck, layers=[3, 4, 23, 3]):
177 | self.inplanes = 64
178 | super(ResNet, self).__init__()
179 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
180 | bias=False)
181 | self.bn1 = nn.BatchNorm2d(64, affine = affine_par)
182 | for i in self.bn1.parameters():
183 | i.requires_grad = False
184 | self.relu = nn.ReLU(inplace=True)
185 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # , ceil_mode=True) # change
186 | self.layer1 = self._make_layer(block, 64, layers[0])
187 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
188 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
189 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
190 | self.fc_new = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], num_seg_classes)
191 | self.fc = nn.Linear(512 * block.expansion, num_classes)
192 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
193 | if pretrained:
194 | model_urls = {
195 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
196 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
197 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
198 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
199 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
200 | }
201 | if layers == [3, 4, 6, 3]:
202 | saved_state_dict = torch.utils.model_zoo.load_url(model_urls['resnet50'])
203 | elif layers == [3, 4, 23, 3]:
204 | saved_state_dict = torch.utils.model_zoo.load_url(model_urls['resnet101'])
205 | new_params = self.state_dict().copy()
206 | for i in saved_state_dict:
207 | i_parts = str(i).split('.')
208 | if not i_parts[0] == 'fc':
209 | assert '.'.join(i_parts[0:]) in new_params
210 | new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
211 | self.load_state_dict(new_params)
212 | else:
213 | for m in self.modules():
214 | if isinstance(m, nn.Conv2d):
215 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
216 | m.weight.data.normal_(0, 0.01)
217 | elif isinstance(m, nn.BatchNorm2d):
218 | m.weight.data.fill_(1)
219 | m.bias.data.zero_()
220 |
221 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
222 | downsample = None
223 | if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:
224 | downsample = nn.Sequential(
225 | nn.Conv2d(self.inplanes, planes * block.expansion,
226 | kernel_size=1, stride=stride, bias=False),
227 | nn.BatchNorm2d(planes * block.expansion,affine = affine_par))
228 | for i in downsample._modules['1'].parameters():
229 | i.requires_grad = False
230 | layers = []
231 | layers.append(block(self.inplanes, planes, stride, dilation=(dilation//2) if dilation > 1 else dilation, downsample=downsample))
232 | self.inplanes = planes * block.expansion
233 | for i in range(1, blocks):
234 | layers.append(block(self.inplanes, planes, dilation=dilation))
235 |
236 | return nn.Sequential(*layers)
237 |
238 | def _make_pred_layer(self, block, dilation_series, padding_series, num_classes):
239 | return block(dilation_series, padding_series, num_classes)
240 |
241 | def forward_fc(self, f4, task='old'):
242 | x = f4
243 | if task in ['old', 'new']:
244 | x = self.avgpool(x)
245 | x = x.reshape(x.size(0), -1)
246 | if task == 'old':
247 | x = self.fc(x)
248 | else:
249 | x = self.fc_new(x)
250 | x = nn.functional.interpolate(x, size=self.input_size, mode='bilinear', align_corners=True)
251 | return x
252 |
253 | def forward_partial(self, feature, stage):
254 | # stage: start forwarding **from** this stage (inclusive)
255 | if stage <= 1:
256 | feature = self.layer1(feature)
257 | if stage <= 2:
258 | feature = self.layer2(feature)
259 | if stage <= 3:
260 | feature = self.layer3(feature)
261 | if stage <= 4:
262 | feature = self.layer4(feature)
263 | return feature
264 |
265 | def forward_backbone(self, x, output_features=['layer4']):
266 | features = {}
267 | x = self.conv1(x)
268 | x = self.bn1(x)
269 | x = self.relu(x)
270 | if 'layer0' in output_features: features['layer0'] = x
271 | x = self.maxpool(x)
272 | f1 = self.layer1(x)
273 | if 'layer1' in output_features: features['layer1'] = f1
274 | f2 = self.layer2(f1)
275 | if 'layer2' in output_features: features['layer2'] = f2
276 | f3 = self.layer3(f2)
277 | if 'layer3' in output_features: features['layer3'] = f3
278 | f4 = self.layer4(f3)
279 | if 'layer4' in output_features: features['layer4'] = f4
280 | if 'gap' in output_features:
281 | features['gap'] = self.avgpool(f4).view(f4.size(0), -1)
282 | return f4, features
283 |
284 | def forward(self, x, output_features=['layer4'], task='old'):
285 | '''
286 | task: 'old' | 'new' | 'new_seg'
287 | 'old', 'new': classification tasks (ImageNet or Visda)
288 | 'new_seg': segmentation head (convs)
289 | '''
290 | self.input_size = x.size()[2:]
291 | f4, features = self.forward_backbone(x, output_features)
292 | x = self.forward_fc(f4, task=task)
293 | return x, features
294 |
--------------------------------------------------------------------------------
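A quick sanity check for the file above: `layer3`/`layer4` run at stride 1 with dilations 2/4, so the backbone keeps features at 1/8 input resolution, and `forward_fc` bilinearly upsamples the ASPP logits back to the input size when `task='new_seg'`. A hedged usage sketch (assumes this file is importable as `model.deeplab`; random init, no pretrained download):

```python
import torch
from model.deeplab import ResNet, Bottleneck

# ResNet-101 backbone with the ASPP segmentation head (random init)
net = ResNet(num_seg_classes=19, pretrained=False,
             block=Bottleneck, layers=[3, 4, 23, 3])
net.eval()

x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    logits, feats = net(x, output_features=['layer4'], task='new_seg')

print(logits.shape)           # torch.Size([1, 19, 512, 512]) -- upsampled to input size
print(feats['layer4'].shape)  # torch.Size([1, 2048, 64, 64]) -- 1/8 resolution
```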
/model/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | from pdb import set_trace as bp
9 | from model.csg_builder import chunk_feature
10 |
11 |
12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
13 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
14 |
15 |
16 | model_urls = {
17 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
18 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
19 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
20 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
21 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
22 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
23 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
24 | }
25 |
26 |
27 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
28 | """3x3 convolution with padding"""
29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
30 | padding=dilation, groups=groups, bias=False, dilation=dilation)
31 |
32 |
33 | def conv1x1(in_planes, out_planes, stride=1):
34 | """1x1 convolution"""
35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
36 |
37 |
38 | class BasicBlock(nn.Module):
39 | expansion = 1
40 |
41 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
42 | base_width=64, dilation=1, norm_layer=None):
43 | super(BasicBlock, self).__init__()
44 | if norm_layer is None:
45 | norm_layer = nn.BatchNorm2d
46 | if groups != 1 or base_width != 64:
47 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
48 | if dilation > 1:
49 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
50 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
51 | self.conv1 = conv3x3(inplanes, planes, stride)
52 | self.bn1 = norm_layer(planes)
53 | self.relu = nn.ReLU(inplace=True)
54 | self.conv2 = conv3x3(planes, planes)
55 | self.bn2 = norm_layer(planes)
56 | self.downsample = downsample
57 | self.stride = stride
58 |
59 | def forward(self, x):
60 | identity = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 |
69 | if self.downsample is not None:
70 | identity = self.downsample(x)
71 |
72 | out += identity
73 | out = self.relu(out)
74 |
75 | return out
76 |
77 |
78 | class Bottleneck(nn.Module):
79 | expansion = 4
80 |
81 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
82 | base_width=64, dilation=1, norm_layer=None, last=False,
83 | ):
84 | super(Bottleneck, self).__init__()
85 | if norm_layer is None:
86 | norm_layer = nn.BatchNorm2d
87 | width = int(planes * (base_width / 64.)) * groups
88 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
89 | self.conv1 = conv1x1(inplanes, width)
90 | self.bn1 = norm_layer(width)
91 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
92 | self.bn2 = norm_layer(width)
93 | self.conv3 = conv1x1(width, planes * self.expansion)
94 | self.bn3 = norm_layer(planes * self.expansion)
95 | self.relu = nn.ReLU(inplace=True)
96 | self.last = last
97 | self.downsample = downsample
98 | self.stride = stride
99 |
100 | def forward(self, x):
101 | identity = x
102 |
103 | out = self.conv1(x)
104 | out = self.bn1(out)
105 | out = self.relu(out)
106 |
107 | out = self.conv2(out)
108 | out = self.bn2(out)
109 | out = self.relu(out)
110 |
111 | out = self.conv3(out)
112 | out = self.bn3(out)
113 |
114 | if self.downsample is not None:
115 | identity = self.downsample(x)
116 |
117 | out += identity
118 | out = self.relu(out)
119 |
120 | return out
121 |
122 |
123 | class ResNet(nn.Module):
124 |
125 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
126 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
127 | norm_layer=None):
128 | super(ResNet, self).__init__()
129 | if norm_layer is None:
130 | norm_layer = nn.BatchNorm2d
131 | self._norm_layer = norm_layer
132 |
133 | self.inplanes = 64
134 | self.dilation = 1
135 | if replace_stride_with_dilation is None:
136 | # each element in the tuple indicates if we should replace
137 | # the 2x2 stride with a dilated convolution instead
138 | replace_stride_with_dilation = [False, False, False]
139 | if len(replace_stride_with_dilation) != 3:
140 | raise ValueError("replace_stride_with_dilation should be None "
141 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
142 | self.groups = groups
143 | self.base_width = width_per_group
144 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
145 | bias=False)
146 | self.bn1 = norm_layer(self.inplanes)
147 | self.relu = nn.ReLU(inplace=True)
148 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
149 | self.layer1 = self._make_layer(block, 64, layers[0])
150 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
151 | dilate=replace_stride_with_dilation[0])
152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
153 | dilate=replace_stride_with_dilation[1])
154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
155 | dilate=replace_stride_with_dilation[2])
156 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
157 | self.fc = nn.Linear(512 * block.expansion, num_classes)
158 |
159 | for m in self.modules():
160 | if isinstance(m, nn.Conv2d):
161 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
162 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
163 | nn.init.constant_(m.weight, 1)
164 | nn.init.constant_(m.bias, 0)
165 |
166 | # Zero-initialize the last BN in each residual branch,
167 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
168 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
169 | if zero_init_residual:
170 | for m in self.modules():
171 | if isinstance(m, Bottleneck):
172 | nn.init.constant_(m.bn3.weight, 0)
173 | elif isinstance(m, BasicBlock):
174 | nn.init.constant_(m.bn2.weight, 0)
175 |
176 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
177 | norm_layer = self._norm_layer
178 | downsample = None
179 | previous_dilation = self.dilation
180 | if dilate:
181 | self.dilation *= stride
182 | stride = 1
183 | if stride != 1 or self.inplanes != planes * block.expansion:
184 | downsample = nn.Sequential(
185 | conv1x1(self.inplanes, planes * block.expansion, stride),
186 | norm_layer(planes * block.expansion),
187 | )
188 |
189 | layers = []
190 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
191 | self.base_width, previous_dilation, norm_layer))
192 | self.inplanes = planes * block.expansion
193 | for _idx in range(1, blocks):
194 | layers.append(block(self.inplanes, planes, groups=self.groups,
195 | base_width=self.base_width, dilation=self.dilation,
196 | norm_layer=norm_layer,
197 | ))
198 |
199 | return nn.Sequential(*layers)
200 |
201 | def forward_fc(self, f4, task='old', f3=None, f2=None, return_mid_feature=False):
202 | x = f4
203 | if task in ['old', 'new']:
204 | x = self.avgpool(x)
205 | x = x.reshape(x.size(0), -1)
206 | if task == 'old':
207 | x = self.fc(x)
208 | return x
209 | else:
210 | if return_mid_feature:
211 | mid = self.fc_new[0](x)
212 | x = self.fc_new[1](mid)
213 | x = self.fc_new[2](x)
214 | return x, mid
215 | else:
216 | x = self.fc_new(x)
217 | return x
218 |
219 | def forward_partial(self, feature, stage):
220 | # stage: start forwarding **from** this stage (inclusive)
221 | # assert stage in [1, 2, 3, 4]
222 | if stage <= 1:
223 | feature = self.layer1(feature)
224 | if stage <= 2:
225 | feature = self.layer2(feature)
226 | if stage <= 3:
227 | feature = self.layer3(feature)
228 | if stage <= 4:
229 | feature = self.layer4(feature)
230 | return feature
231 |
232 | def forward_backbone(self, x, output_features=['layer4']):
233 | features = {}
234 | f0 = self.conv1(x)
235 | f0 = self.bn1(f0)
236 | f0 = self.relu(f0)
237 | if 'layer0' in output_features: features['layer0'] = f0
238 | f0 = self.maxpool(f0)
239 | f1 = self.layer1(f0)
240 | if 'layer1' in output_features: features['layer1'] = f1
241 | f2 = self.layer2(f1)
242 | if 'layer2' in output_features: features['layer2'] = f2
243 | f3 = self.layer3(f2)
244 | if 'layer3' in output_features: features['layer3'] = f3
245 | f4 = self.layer4(f3)
246 | if 'layer4' in output_features: features['layer4'] = f4
247 | if 'gap' in output_features:
248 | features['gap'] = self.avgpool(f4).view(f4.size(0), -1)
249 | return f4, features
250 | # return f4, f3, f2, features
251 |
252 | def forward(self, x, output_features=['layer4'], task='old'):
253 | '''
254 | task: 'old' | 'new' | 'new_seg'
255 | 'old', 'new': classification tasks (ImageNet or Visda)
256 | 'new_seg': segmentation head (convs)
257 | '''
258 | f4, features = self.forward_backbone(x, output_features)
259 | if 'fc_mid' in output_features:
260 | x, _mid = self.forward_fc(f4, task=task, return_mid_feature=True)
261 | features['fc_mid'] = _mid
262 | else:
263 | x = self.forward_fc(f4, task=task)
264 | return x, features
265 |
266 |
267 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
268 | model = ResNet(block, layers, **kwargs)
269 | if pretrained:
270 | from torchvision.models.utils import load_state_dict_from_url
271 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
272 | # model.load_state_dict(state_dict)
273 | state = model.state_dict()
274 | pretrained_dict = {k: v for k, v in state_dict.items() if k in state and state[k].size() == v.size()}
275 | state.update(pretrained_dict)
276 | model.load_state_dict(state)
277 | return model
278 |
279 |
280 | def resnet18(pretrained=False, progress=True, **kwargs):
281 | """Constructs a ResNet-18 model.
282 |
283 | Args:
284 | pretrained (bool): If True, returns a model pre-trained on ImageNet
285 | progress (bool): If True, displays a progress bar of the download to stderr
286 | """
287 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
288 | **kwargs)
289 |
290 |
291 | def resnet34(pretrained=False, progress=True, **kwargs):
292 | """Constructs a ResNet-34 model.
293 |
294 | Args:
295 | pretrained (bool): If True, returns a model pre-trained on ImageNet
296 | progress (bool): If True, displays a progress bar of the download to stderr
297 | """
298 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
299 | **kwargs)
300 |
301 |
302 | def resnet50(pretrained=False, progress=True, **kwargs):
303 | """Constructs a ResNet-50 model.
304 |
305 | Args:
306 | pretrained (bool): If True, returns a model pre-trained on ImageNet
307 | progress (bool): If True, displays a progress bar of the download to stderr
308 | """
309 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
310 | **kwargs)
311 |
312 |
313 | def resnet101(pretrained=False, progress=True, **kwargs):
314 | """Constructs a ResNet-101 model.
315 |
316 | Args:
317 | pretrained (bool): If True, returns a model pre-trained on ImageNet
318 | progress (bool): If True, displays a progress bar of the download to stderr
319 | """
320 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
321 | **kwargs)
322 |
323 |
324 | def resnet152(pretrained=False, progress=True, **kwargs):
325 | """Constructs a ResNet-152 model.
326 |
327 | Args:
328 | pretrained (bool): If True, returns a model pre-trained on ImageNet
329 | progress (bool): If True, displays a progress bar of the download to stderr
330 | """
331 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
332 | **kwargs)
333 |
334 |
335 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
336 | """Constructs a ResNeXt-50 32x4d model.
337 |
338 | Args:
339 | pretrained (bool): If True, returns a model pre-trained on ImageNet
340 | progress (bool): If True, displays a progress bar of the download to stderr
341 | """
342 | kwargs['groups'] = 32
343 | kwargs['width_per_group'] = 4
344 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
345 | pretrained, progress, **kwargs)
346 |
347 |
348 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
349 | """Constructs a ResNeXt-101 32x8d model.
350 |
351 | Args:
352 | pretrained (bool): If True, returns a model pre-trained on ImageNet
353 | progress (bool): If True, displays a progress bar of the download to stderr
354 | """
355 | kwargs['groups'] = 32
356 | kwargs['width_per_group'] = 8
357 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
358 | pretrained, progress, **kwargs)
359 |
--------------------------------------------------------------------------------
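`_resnet` above deliberately avoids a strict `load_state_dict` on the downloaded weights: it keeps only pretrained entries whose key and shape both match the current model, so task-specific heads stay randomly initialized. A minimal illustration of that size-safe loading pattern on a toy module (the `Net` class here is hypothetical, not part of the repo):

```python
import torch.nn as nn

class Net(nn.Module):
    """Toy stand-in for the real backbone (hypothetical)."""
    def __init__(self, out_features):
        super(Net, self).__init__()
        self.backbone = nn.Linear(10, 10)
        self.head = nn.Linear(10, out_features)

pretrained = Net(out_features=1000).state_dict()  # e.g. an ImageNet head
model = Net(out_features=19)                      # new task, different head

state = model.state_dict()
# keep only entries present in the target model with an identical shape
compat = {k: v for k, v in pretrained.items()
          if k in state and state[k].size() == v.size()}
state.update(compat)
model.load_state_dict(state)

print(sorted(compat))  # ['backbone.bias', 'backbone.weight'] -- head.* skipped
```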
/requirements.txt:
--------------------------------------------------------------------------------
1 | easydict
2 | matplotlib==3.0.0
3 | numpy==1.16.1
4 | opencv-python==3.4.4.19
5 | Pillow==6.2.0
6 | scipy==1.1.0
7 | tensorflow
8 | tensorboard==1.9.0
9 | tensorboardX==1.6
10 | torch==1.2.0
11 | torchvision==0.3.0
12 | tqdm==4.25.0
--------------------------------------------------------------------------------
/tools/datasets/BaseDataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | import cv2
6 | cv2.setNumThreads(0)
7 | import torch
8 | import numpy as np
9 | from random import shuffle
10 |
11 | import torch.utils.data as data
12 |
13 |
14 | class BaseDataset(data.Dataset):
15 | def __init__(self, setting, split_name, preprocess=None, file_length=None):
16 | super(BaseDataset, self).__init__()
17 | self._split_name = split_name
18 | self._img_path = setting['img_root']
19 | self._gt_path = setting['gt_root']
20 | self._portion = setting['portion'] if 'portion' in setting else None
21 | self._train_source = setting['train_source']
22 | self._eval_source = setting['eval_source']
23 | self._test_source = setting['test_source'] if 'test_source' in setting else setting['eval_source']
24 | self._down_sampling = setting['down_sampling']
25 | print("using downsampling:", self._down_sampling)
26 | self._file_names = self._get_file_names(split_name)
27 | print("Found %d images"%len(self._file_names))
28 | self._file_length = file_length
29 | self.preprocess = preprocess
30 |
31 | def __len__(self):
32 | if self._file_length is not None:
33 | return self._file_length
34 | return len(self._file_names)
35 |
36 | def __getitem__(self, index):
37 | if self._file_length is not None:
38 | names = self._construct_new_file_names(self._file_length)[index]
39 | else:
40 | names = self._file_names[index]
41 | img_path = os.path.join(self._img_path, names[0])
42 | gt_path = os.path.join(self._gt_path, names[1])
43 | item_name = names[1].split("/")[-1].split(".")[0]
44 |
45 | img, gt = self._fetch_data(img_path, gt_path)
46 | img = img[:, :, ::-1]
47 | if self.preprocess is not None:
48 | img, gt, extra_dict = self.preprocess(img, gt)
49 |
50 | if self._split_name == 'train':
51 | img = torch.from_numpy(np.ascontiguousarray(img)).float()
52 | gt = torch.from_numpy(np.ascontiguousarray(gt)).long()
53 | if self.preprocess is not None and extra_dict is not None:
54 | for k, v in extra_dict.items():
55 | extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v))
56 | if 'label' in k:
57 | extra_dict[k] = extra_dict[k].long()
58 | if 'img' in k:
59 | extra_dict[k] = extra_dict[k].float()
60 |
61 | output_dict = dict(data=img, label=gt, fn=str(item_name),
62 | n=len(self._file_names))
63 | if self.preprocess is not None and extra_dict is not None:
64 | output_dict.update(**extra_dict)
65 |
66 | return output_dict
67 |
68 | def _fetch_data(self, img_path, gt_path, dtype=None):
69 | img = self._open_image(img_path, down_sampling=self._down_sampling)
70 | gt = self._open_image(gt_path, cv2.IMREAD_GRAYSCALE, dtype=dtype, down_sampling=self._down_sampling)
71 |
72 | return img, gt
73 |
74 | def _get_file_names(self, split_name):
75 | assert split_name in ['train', 'val', 'test']
76 | source = self._train_source
77 | if split_name == "val":
78 | source = self._eval_source
79 | elif split_name == 'test':
80 | source = self._test_source
81 |
82 | file_names = []
83 | with open(source) as f:
84 | files = f.readlines()
85 | if self._portion is not None:
86 | shuffle(files)
87 | num_files = len(files)
88 | if self._portion > 0:
89 | split = int(np.floor(self._portion * num_files))
90 | files = files[:split]
91 | elif self._portion < 0:
92 | split = int(np.floor((1 + self._portion) * num_files))
93 | files = files[split:]
94 |
95 | for item in files:
96 | img_name, gt_name = self._process_item_names(item)
97 | file_names.append([img_name, gt_name])
98 |
99 | return file_names
100 |
101 | def _construct_new_file_names(self, length):
102 | assert isinstance(length, int)
103 | files_len = len(self._file_names)
104 | new_file_names = self._file_names * (length // files_len)
105 |
106 | rand_indices = torch.randperm(files_len).tolist()
107 | new_indices = rand_indices[:length % files_len]
108 |
109 | new_file_names += [self._file_names[i] for i in new_indices]
110 |
111 | return new_file_names
112 |
113 | @staticmethod
114 | def _process_item_names(item):
115 | item = item.strip()
116 | # item = item.split('\t')
117 | item = item.split(' ')
118 | img_name = item[0]
119 | gt_name = item[1]
120 |
121 | return img_name, gt_name
122 |
123 | def get_length(self):
124 | return self.__len__()
125 |
126 | @staticmethod
127 | def _open_image(filepath, mode=cv2.IMREAD_COLOR, dtype=None, down_sampling=1):
128 | # cv2: B G R
129 | # h w c
130 | img = np.array(cv2.imread(filepath, mode), dtype=dtype)
131 |
132 | if isinstance(down_sampling, int):
133 | H, W = img.shape[:2]
134 | if len(img.shape) == 3:
135 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_LINEAR)
136 | else:
137 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_NEAREST)
138 | assert img.shape[0] == H // down_sampling and img.shape[1] == W // down_sampling
139 | else:
140 | assert (isinstance(down_sampling, tuple) or isinstance(down_sampling, list)) and len(down_sampling) == 2
141 | if len(img.shape) == 3:
142 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_LINEAR)
143 | else:
144 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_NEAREST)
145 | assert img.shape[0] == down_sampling[0] and img.shape[1] == down_sampling[1]
146 |
147 | return img
148 |
149 | @classmethod
150 | def get_class_colors(*args):
151 | raise NotImplementedError
152 |
153 | @classmethod
154 | def get_class_names(*args):
155 | raise NotImplementedError
156 |
157 |
158 | if __name__ == "__main__":
159 | data_setting = {'img_root': '',
160 | 'gt_root': '',
161 | 'train_source': '',
162 | 'eval_source': ''}
163 | bd = BaseDataset(data_setting, 'train', None)
164 | print(bd.get_class_names())
165 |
--------------------------------------------------------------------------------
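The `portion` setting consumed by `_get_file_names` above splits the (shuffled) file list by sign: a positive value keeps the leading fraction, a negative value keeps the trailing fraction. A small standalone sketch of that rule (`take_portion` is an illustrative helper, not part of the repo):

```python
import numpy as np

def take_portion(files, portion):
    # mirrors BaseDataset._get_file_names: positive keeps the leading
    # fraction, negative keeps the trailing fraction of the list
    num_files = len(files)
    if portion > 0:
        split = int(np.floor(portion * num_files))
        return files[:split]
    elif portion < 0:
        split = int(np.floor((1 + portion) * num_files))
        return files[split:]
    return files

files = ["img_%d.png" % i for i in range(10)]
print(take_portion(files, 0.3))   # first 30%: img_0 .. img_2
print(take_portion(files, -0.3))  # last 30%:  img_7 .. img_9
```

Opposite signs of the same magnitude thus yield complementary splits, which is convenient for carving a train/holdout pair out of one source list.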
/tools/datasets/cityscapes/cityscapes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import numpy as np
5 |
6 | from datasets.BaseDataset import BaseDataset
7 |
8 |
9 | class Cityscapes(BaseDataset):
10 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
11 | 28, 31, 32, 33]
12 |
13 | @classmethod
14 | def get_class_colors(*args):
15 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70],
16 | [102, 102, 156], [190, 153, 153], [153, 153, 153],
17 | [250, 170, 30], [220, 220, 0], [107, 142, 35],
18 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
19 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
20 | [0, 0, 230], [119, 11, 32]]
21 |
22 | @classmethod
23 | def get_class_names(*args):
24 | # class counting(gtFine)
25 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832
26 | # 359 274 142 513 1646
27 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
28 | 'traffic light', 'traffic sign',
29 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
30 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle']
31 |
32 | @classmethod
33 | def transform_label(cls, pred, name):
34 | label = np.zeros(pred.shape)
35 | ids = np.unique(pred)
36 | for id in ids:
37 | label[np.where(pred == id)] = cls.trans_labels[id]
38 |
39 | new_name = (name.split('.')[0]).split('_')[:-1]
40 | new_name = '_'.join(new_name) + '.png'
41 |
42 | print('Trans', name, 'to', new_name, ' ',
43 | np.unique(np.array(pred, np.uint8)), ' ---------> ',
44 | np.unique(np.array(label, np.uint8)))
45 | return label, new_name
46 |
--------------------------------------------------------------------------------
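`transform_label` above converts the 19 contiguous train ids back into the official Cityscapes label ids through `trans_labels` (train id 0 'road' becomes label id 7, and so on), which is the encoding the Cityscapes evaluation tooling expects. A tiny demonstration on a dummy prediction:

```python
import numpy as np

trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
                28, 31, 32, 33]

pred = np.array([[0, 0, 13],
                 [10, 10, 13]])  # train ids: road (0), sky (10), car (13)

label = np.zeros(pred.shape)
for train_id in np.unique(pred):
    label[pred == train_id] = trans_labels[train_id]

print(label.astype(np.uint8))
# [[ 7  7 26]
#  [23 23 26]]  -- road -> 7, sky -> 23, car -> 26 in the official id space
```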
/tools/engine/evaluator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | from PIL import Image
6 | import cv2
7 | import numpy as np
8 | import time
9 | from tqdm import tqdm
10 |
11 | import torch
12 | import torch.multiprocessing as mp
13 |
14 | from tools.engine.logger import get_logger
15 | from tools.utils.pyt_utils import load_model, link_file, ensure_dir
16 | from tools.utils.img_utils import pad_image_to_shape, normalize
17 |
18 | logger = get_logger()
19 |
20 |
21 | class Evaluator(object):
22 | def __init__(self, dataset, class_num, image_mean, image_std, network,
23 | multi_scales, is_flip, devices=0, threds=3, config=None, logger=None,
24 | verbose=False, save_path=None, show_image=False, show_prediction=False):
25 | self.dataset = dataset
26 | self.ndata = self.dataset.get_length()
27 | self.class_num = class_num
28 | self.image_mean = image_mean
29 | self.image_std = image_std
30 | self.multi_scales = multi_scales
31 | self.is_flip = is_flip
32 | self.network = network
33 | self.devices = devices
34 | if type(self.devices) == int: self.devices = [self.devices]
35 | self.threds = threds
36 | self.config = config
37 | self.logger = logger
38 |
39 | self.context = mp.get_context('spawn')
40 | self.val_func = None
41 | self.results_queue = self.context.Queue(self.ndata)
42 | self.features_queue = self.context.Queue(self.ndata)
43 |
44 | self.verbose = verbose
45 | self.save_path = save_path
46 | if save_path is not None:
47 | ensure_dir(save_path)
48 | self.show_image = show_image
49 | self.show_prediction = show_prediction
50 |
51 | def run(self, model_path, model_indice, log_file, log_file_link):
52 | """There are four evaluation modes:
53 | 1.only eval a .pth model: -e *.pth
54 | 2.only eval a certain epoch: -e epoch
55 | 3.eval all epochs in a given section: -e start_epoch-end_epoch
56 | 4.eval all epochs from a certain started epoch: -e start_epoch-
57 | """
58 | if '.pth' in model_indice:
59 | models = [model_indice, ]
60 | elif "-" in model_indice:
61 | start_epoch = int(model_indice.split("-")[0])
62 | end_epoch = model_indice.split("-")[1]
63 |
64 | models = os.listdir(model_path)
65 | models.remove("epoch-last.pth")
66 | sorted_models = [None] * len(models)
67 | model_idx = [0] * len(models)
68 |
69 | for idx, m in enumerate(models):
70 | num = m.split(".")[0].split("-")[1]
71 | model_idx[idx] = num
72 | sorted_models[idx] = m
73 | model_idx = np.array([int(i) for i in model_idx])
74 |
75 | down_bound = model_idx >= start_epoch
76 | up_bound = [True] * len(sorted_models)
77 | if end_epoch:
78 | end_epoch = int(end_epoch)
79 | assert start_epoch < end_epoch
80 | up_bound = model_idx <= end_epoch
81 | bound = up_bound * down_bound
82 | model_slice = np.array(sorted_models)[bound]
83 | models = [os.path.join(model_path, model) for model in
84 | model_slice]
85 | else:
86 | models = [os.path.join(model_path,
87 | 'epoch-%s.pth' % model_indice), ]
88 |
89 | results = open(log_file, 'a')
90 | link_file(log_file, log_file_link)
91 |
92 | for model in models:
93 | logger.info("Load Model: %s" % model)
94 | self.val_func = load_model(self.network, model)
95 | result_line, mIoU = self.multi_process_evaluation()
96 |
97 | results.write('Model: ' + model + '\n')
98 | results.write(result_line)
99 | results.write('\n')
100 | results.flush()
101 |
102 | results.close()
103 |
104 | def run_online(self):
105 | """
106 | eval during training
107 | """
108 | self.val_func = self.network
109 | result_line, mIoU = self.single_process_evaluation()
110 | return result_line, mIoU
111 |
112 | def single_process_evaluation(self):
113 | all_results = []
114 |
115 | with torch.no_grad():
116 | for idx in tqdm(range(self.ndata)):
117 | dd = self.dataset[idx]
118 | results_dict = self.func_per_iteration(dd, self.devices[0], iter=idx)
119 | all_results.append(results_dict)
120 | _, _mIoU = self.compute_metric([results_dict])
121 | result_line, mIoU = self.compute_metric(all_results)
122 | return result_line, mIoU
123 |
124 | def run_online_multiprocess(self):
125 | """
126 | eval during training
127 | """
128 | self.val_func = self.network
129 | result_line, mIoU = self.multi_process_single_gpu_evaluation()
130 | return result_line, mIoU
131 |
132 | def multi_process_single_gpu_evaluation(self):
133 | # start_eval_time = time.perf_counter()
134 | stride = int(np.ceil(self.ndata / self.threds))
135 |
136 | # start multi-process on single-gpu
137 | procs = []
138 | for d in range(self.threds):
139 | e_record = min((d + 1) * stride, self.ndata)
140 | shred_list = list(range(d * stride, e_record))
141 | device = self.devices[0]
142 | logger.info('Thread %d handle %d data.' % (d, len(shred_list)))
143 | p = self.context.Process(target=self.worker, args=(shred_list, device))
144 | procs.append(p)
145 |
146 | for p in procs:
147 | p.start()
148 |
149 | all_results = []
150 | for _ in tqdm(range(self.ndata)):
151 | t = self.results_queue.get()
152 | all_results.append(t)
153 | if self.verbose:
154 | self.compute_metric(all_results)
155 |
156 | for p in procs:
157 | p.join()
158 |
159 | result_line, mIoU = self.compute_metric(all_results)
160 | return result_line, mIoU
161 |
162 | def multi_process_evaluation(self):
163 | start_eval_time = time.perf_counter()
164 | nr_devices = len(self.devices)
165 | stride = int(np.ceil(self.ndata / nr_devices))
166 |
167 | # start multi-process on multi-gpu
168 | procs = []
169 | for d in range(nr_devices):
170 | e_record = min((d + 1) * stride, self.ndata)
171 | shred_list = list(range(d * stride, e_record))
172 | device = self.devices[d]
173 | logger.info('GPU %s handle %d data.' % (device, len(shred_list)))
174 | p = self.context.Process(target=self.worker, args=(shred_list, device))
175 | procs.append(p)
176 |
177 | for p in procs:
178 | p.start()
179 |
180 | all_results = []
181 | for _ in tqdm(range(self.ndata)):
182 | t = self.results_queue.get()
183 | all_results.append(t)
184 | if self.verbose:
185 | self.compute_metric(all_results)
186 |
187 | for p in procs:
188 | p.join()
189 |
190 | result_line, mIoU = self.compute_metric(all_results)
191 | logger.info('Evaluation Elapsed Time: %.2fs' % (time.perf_counter() - start_eval_time))
192 | return result_line, mIoU
193 |
194 | def worker(self, shred_list, device):
195 | # start_load_time = time.time()
196 | # logger.info('Load Model on Device %d: %.2fs' % (device, time.time() - start_load_time))
197 | for idx in shred_list:
198 | dd = self.dataset[idx]
199 | results_dict = self.func_per_iteration(dd, device, iter=idx)
200 | self.results_queue.put(results_dict)
201 |
202 |
203 | def func_per_iteration(self, data, device, iter=None):
204 | raise NotImplementedError
205 |
206 | def compute_metric(self, results):
207 | raise NotImplementedError
208 |
209 | # evaluate the whole image at once
210 | def whole_eval(self, img, output_size, resize=None, input_size=None, device=None):
211 | if input_size is not None:
212 | img, margin = self.process_image(img, resize=resize, crop_size=input_size)
213 | else:
214 | img = self.process_image(img, resize=resize, crop_size=input_size)
215 |
216 | pred = self.val_func_process(img, device)
217 | if input_size is not None:
218 | pred = pred[:, margin[0]:(pred.shape[1] - margin[1]), margin[2]:(pred.shape[2] - margin[3])]
219 | pred = pred.permute(1, 2, 0)
220 | pred = pred.cpu().numpy()
221 | if output_size is not None:
222 | pred = cv2.resize(pred,
223 | (output_size[1], output_size[0]),
224 | interpolation=cv2.INTER_LINEAR)
225 |
226 | # pred = pred.argmax(2)
227 |
228 | return pred
229 |
230 | # slide the window to evaluate the image
231 | def sliding_eval(self, img, crop_size, stride_rate, device=None):
232 | ori_rows, ori_cols, c = img.shape
233 | processed_pred = np.zeros((ori_rows, ori_cols, self.class_num))
234 |
235 | for s in self.multi_scales:
236 | img_scale = cv2.resize(img, None, fx=s, fy=s,
237 | interpolation=cv2.INTER_LINEAR)
238 | new_rows, new_cols, _ = img_scale.shape
239 | processed_pred += self.scale_process(img_scale,
240 | (ori_rows, ori_cols),
241 | crop_size, stride_rate, device)
242 |
243 | pred = processed_pred.argmax(2)
244 |
245 | return pred
246 |
247 | def scale_process(self, img, ori_shape, crop_size, stride_rate,
248 | device=None):
249 | new_rows, new_cols, c = img.shape
250 | long_size = new_cols if new_cols > new_rows else new_rows
251 |
252 | if long_size <= crop_size:
253 | input_data, margin = self.process_image(img, crop_size=crop_size)
254 | score = self.val_func_process(input_data, device)
255 | score = score[:, margin[0]:(score.shape[1] - margin[1]),
256 | margin[2]:(score.shape[2] - margin[3])]
257 | else:
258 | stride = int(np.ceil(crop_size * stride_rate))
259 | img_pad, margin = pad_image_to_shape(img, crop_size,
260 | cv2.BORDER_CONSTANT, value=0)
261 |
262 | pad_rows = img_pad.shape[0]
263 | pad_cols = img_pad.shape[1]
264 | r_grid = int(np.ceil((pad_rows - crop_size) / stride)) + 1
265 | c_grid = int(np.ceil((pad_cols - crop_size) / stride)) + 1
266 | data_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda(
267 | device)
268 | count_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda(
269 | device)
270 |
271 | for grid_yidx in range(r_grid):
272 | for grid_xidx in range(c_grid):
273 | s_x = grid_xidx * stride
274 | s_y = grid_yidx * stride
275 | e_x = min(s_x + crop_size, pad_cols)
276 | e_y = min(s_y + crop_size, pad_rows)
277 | s_x = e_x - crop_size
278 | s_y = e_y - crop_size
279 | img_sub = img_pad[s_y:e_y, s_x: e_x, :]
280 | count_scale[:, s_y: e_y, s_x: e_x] += 1
281 |
282 | input_data, tmargin = self.process_image(img_sub, crop_size=crop_size)
283 | temp_score = self.val_func_process(input_data, device)
284 | temp_score = temp_score[:,
285 | tmargin[0]:(temp_score.shape[1] - tmargin[1]),
286 | tmargin[2]:(temp_score.shape[2] - tmargin[3])]
287 | data_scale[:, s_y: e_y, s_x: e_x] += temp_score
288 | # score = data_scale / count_scale
289 | score = data_scale
290 | score = score[:, margin[0]:(score.shape[1] - margin[1]),
291 | margin[2]:(score.shape[2] - margin[3])]
292 |
293 | score = score.permute(1, 2, 0)
294 | data_output = cv2.resize(score.cpu().numpy(),
295 | (ori_shape[1], ori_shape[0]),
296 | interpolation=cv2.INTER_LINEAR)
297 |
298 | return data_output
299 |
300 | def val_func_process(self, input_data, device=None):
301 | input_data = np.ascontiguousarray(input_data[None, :, :, :], dtype=np.float32)
302 | input_data = torch.FloatTensor(input_data).cuda(device)
303 |
304 | with torch.cuda.device(input_data.get_device()):
305 | self.val_func.eval()
306 | self.val_func.to(input_data.get_device())
307 | with torch.no_grad():
308 | score = self.val_func(input_data, output_features=[], task='new_seg')[0]
309 | if (isinstance(score, tuple) or isinstance(score, list)) and len(score) > 1:
310 | score = score[getattr(self, 'out_idx', 0)]  # subclasses may set out_idx; default to the first output
311 | score = score[0] # a single image pass, ignore batch dim
312 |
313 | if self.is_flip:
314 | input_data = input_data.flip(-1)
315 | score_flip = self.val_func(input_data, output_features=[], task='new_seg')[0]
316 | score_flip = score_flip[0] # a single image pass, ignore batch dim
317 | score += score_flip.flip(-1)
318 | score = torch.exp(score)
319 | # score = score.data
320 |
321 | return score
322 |
323 | def process_image(self, img, resize=None, crop_size=None):
324 | p_img = img
325 |
326 | if img.shape[2] < 3:
327 | im_b = p_img
328 | im_g = p_img
329 | im_r = p_img
330 | p_img = np.concatenate((im_b, im_g, im_r), axis=2)
331 |
332 | if resize is not None:
333 | if isinstance(resize, float):
334 | _size = p_img.shape[:2]
335 | # p_img = np.array(Image.fromarray(p_img).resize((int(_size[0]*resize), int(_size[1]*resize)), Image.BILINEAR))
336 | p_img = np.array(Image.fromarray(p_img).resize((int(_size[1]*resize), int(_size[0]*resize)), Image.BILINEAR))
337 | elif isinstance(resize, tuple) or isinstance(resize, list):
338 | assert len(resize) == 2
339 | p_img = np.array(Image.fromarray(p_img).resize((int(resize[0]), int(resize[1])), Image.BILINEAR))
340 |
341 | p_img = normalize(p_img, self.image_mean, self.image_std)
342 |
343 | if crop_size is not None:
344 | p_img, margin = pad_image_to_shape(p_img, crop_size, cv2.BORDER_CONSTANT, value=0)
345 | p_img = p_img.transpose(2, 0, 1)
346 |
347 | return p_img, margin
348 |
349 | p_img = p_img.transpose(2, 0, 1)
350 |
351 | return p_img
352 |
--------------------------------------------------------------------------------
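In `scale_process` above, when the image is larger than the crop, the window grid steps by `ceil(crop_size * stride_rate)` and clamps the last row/column of windows to the border, so edge tiles overlap their neighbours rather than shrink. A standalone sketch of the same grid arithmetic (`window_grid` is an illustrative helper, not part of the repo):

```python
import numpy as np

def window_grid(pad_rows, pad_cols, crop_size, stride_rate):
    # mirrors the tiling loop in Evaluator.scale_process
    stride = int(np.ceil(crop_size * stride_rate))
    r_grid = int(np.ceil((pad_rows - crop_size) / stride)) + 1
    c_grid = int(np.ceil((pad_cols - crop_size) / stride)) + 1
    windows = []
    for grid_yidx in range(r_grid):
        for grid_xidx in range(c_grid):
            e_x = min(grid_xidx * stride + crop_size, pad_cols)
            e_y = min(grid_yidx * stride + crop_size, pad_rows)
            windows.append((e_y - crop_size, e_x - crop_size, e_y, e_x))
    return windows

# a 1024x2048 Cityscapes frame, 769-pixel crops, stride_rate 2/3
for w in window_grid(1024, 2048, 769, 2. / 3):
    print(w)  # (s_y, s_x, e_y, e_x); the last row/column snaps to the border
```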
/tools/engine/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | import sys
6 | import logging
7 |
8 | _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO')
9 | _default_level = logging.getLevelName(_default_level_name.upper())
10 |
11 |
12 | class LogFormatter(logging.Formatter):
13 | log_fout = None
14 | date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] '
15 | date = '%(asctime)s '
16 | msg = '%(message)s'
17 |
18 | def format(self, record):
19 | if record.levelno == logging.DEBUG:
20 | mcl, mtxt = self._color_dbg, 'DBG'
21 | elif record.levelno == logging.WARNING:
22 | mcl, mtxt = self._color_warn, 'WRN'
23 | elif record.levelno == logging.ERROR:
24 | mcl, mtxt = self._color_err, 'ERR'
25 | else:
26 | mcl, mtxt = self._color_normal, ''
27 |
28 | if mtxt:
29 | mtxt += ' '
30 |
31 | if self.log_fout:
32 | self.__set_fmt(self.date_full + mtxt + self.msg)
33 | formatted = super(LogFormatter, self).format(record)
34 | # self.log_fout.write(formatted)
35 | # self.log_fout.write('\n')
36 | # self.log_fout.flush()
37 | return formatted
38 |
39 | self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
40 | formatted = super(LogFormatter, self).format(record)
41 |
42 | return formatted
43 |
44 | if sys.version_info.major < 3:
45 | def __set_fmt(self, fmt):
46 | self._fmt = fmt
47 | else:
48 | def __set_fmt(self, fmt):
49 | self._style._fmt = fmt
50 |
51 | @staticmethod
52 | def _color_dbg(msg):
53 | return '\x1b[36m{}\x1b[0m'.format(msg)
54 |
55 | @staticmethod
56 | def _color_warn(msg):
57 | return '\x1b[1;31m{}\x1b[0m'.format(msg)
58 |
59 | @staticmethod
60 | def _color_err(msg):
61 | return '\x1b[1;4;31m{}\x1b[0m'.format(msg)
62 |
63 | @staticmethod
64 | def _color_omitted(msg):
65 | return '\x1b[35m{}\x1b[0m'.format(msg)
66 |
67 | @staticmethod
68 | def _color_normal(msg):
69 | return msg
70 |
71 | @staticmethod
72 | def _color_date(msg):
73 | return '\x1b[32m{}\x1b[0m'.format(msg)
74 |
75 |
76 | def get_logger(log_dir=None, log_file=None, formatter=LogFormatter):
77 | logger = logging.getLogger()
78 | logger.setLevel(_default_level)
79 | del logger.handlers[:]
80 |
81 | if log_dir and log_file:
82 | if not os.path.isdir(log_dir): os.makedirs(log_dir)
83 | LogFormatter.log_fout = True
84 | file_handler = logging.FileHandler(log_file, mode='a')
85 | file_handler.setLevel(logging.INFO)
86 | file_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))
87 | logger.addHandler(file_handler)
88 |
89 | stream_handler = logging.StreamHandler()
90 | stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))
91 | stream_handler.setLevel(0)
92 | logger.addHandler(stream_handler)
93 | return logger
94 |
--------------------------------------------------------------------------------
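Usage note for the logger above: `get_logger()` with no arguments installs only the colourized stream handler; passing both `log_dir` and `log_file` additionally attaches a plain-format file handler. A short, hedged example (paths are illustrative):

```python
from tools.engine.logger import get_logger

# console only: colourized stream handler at the default level
logger = get_logger()
logger.info("evaluation started")

# console + file: also writes plain-format records to logs/eval.log
logger = get_logger(log_dir="logs", log_file="logs/eval.log")
logger.warning("mIoU dropped below the previous best")
```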
/tools/engine/tester.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | import cv2
6 | import numpy as np
7 | import time
8 | from tqdm import tqdm
9 | from pdb import set_trace as bp
10 | import torch
11 | import torch.multiprocessing as mp
12 |
13 | from engine.logger import get_logger
14 | from tools.utils.pyt_utils import load_model, link_file, ensure_dir
15 | from tools.utils.img_utils import pad_image_to_shape, normalize
16 |
17 | logger = get_logger()
18 |
19 |
20 | class Tester(object):
21 | def __init__(self, dataset, class_num, image_mean, image_std, network,
22 | multi_scales, is_flip, devices=0, out_idx=0, threds=3, config=None, logger=None,
23 | verbose=False, save_path=None, show_image=False):
24 | self.dataset = dataset
25 | self.ndata = self.dataset.get_length()
26 | self.class_num = class_num
27 | self.image_mean = image_mean
28 | self.image_std = image_std
29 | self.multi_scales = multi_scales
30 | self.is_flip = is_flip
31 | self.network = network
32 | self.devices = devices
33 | if type(self.devices) == int: self.devices = [self.devices]
34 | self.out_idx = out_idx
35 | self.threds = threds
36 | self.config = config
37 | self.logger = logger
38 |
39 | self.context = mp.get_context('spawn')
40 | self.val_func = None
41 | self.results_queue = self.context.Queue(self.ndata)
42 |
43 | self.verbose = verbose
44 | self.save_path = save_path
45 | if save_path is not None:
46 | ensure_dir(save_path)
47 | self.show_image = show_image
48 |
49 | def run(self, model_path, model_indice, log_file, log_file_link):
50 | """There are four evaluation modes:
51 | 1.only eval a .pth model: -e *.pth
52 | 2.only eval a certain epoch: -e epoch
53 | 3.eval all epochs in a given section: -e start_epoch-end_epoch
54 | 4.eval all epochs from a certain started epoch: -e start_epoch-
55 | """
56 | if '.pth' in model_indice:
57 | models = [model_indice, ]
58 | elif "-" in model_indice:
59 | start_epoch = int(model_indice.split("-")[0])
60 | end_epoch = model_indice.split("-")[1]
61 |
62 | models = os.listdir(model_path)
63 | models.remove("epoch-last.pth")
64 | sorted_models = [None] * len(models)
65 | model_idx = [0] * len(models)
66 |
67 | for idx, m in enumerate(models):
68 | num = m.split(".")[0].split("-")[1]
69 | model_idx[idx] = num
70 | sorted_models[idx] = m
71 | model_idx = np.array([int(i) for i in model_idx])
72 |
73 | down_bound = model_idx >= start_epoch
74 | up_bound = [True] * len(sorted_models)
75 | if end_epoch:
76 | end_epoch = int(end_epoch)
77 | assert start_epoch < end_epoch
78 | up_bound = model_idx <= end_epoch
79 | bound = up_bound * down_bound
80 | model_slice = np.array(sorted_models)[bound]
81 | models = [os.path.join(model_path, model) for model in
82 | model_slice]
83 | else:
84 | models = [os.path.join(model_path,
85 | 'epoch-%s.pth' % model_indice), ]
86 |
87 | results = open(log_file, 'a')
88 | link_file(log_file, log_file_link)
89 |
90 | for model in models:
91 | logger.info("Load Model: %s" % model)
92 | self.val_func = load_model(self.network, model)
93 | result_line, mIoU = self.multi_process_evaluation()
94 |
95 | results.write('Model: ' + model + '\n')
96 | results.write(result_line)
97 | results.write('\n')
98 | results.flush()
99 |
100 | results.close()
101 |
102 | def run_online(self):
103 | """
104 | eval during training
105 | """
106 | self.val_func = self.network
107 | self.single_process_evaluation()
108 |
109 | def single_process_evaluation(self):
110 | with torch.no_grad():
111 | for idx in tqdm(range(self.ndata)):
112 | dd = self.dataset[idx]
113 | self.func_per_iteration(dd, self.devices[0], iter=idx)
114 |
115 | def run_online_multiprocess(self):
116 | """
117 | eval during training
118 | """
119 | self.val_func = self.network
120 | self.multi_process_single_gpu_evaluation()
121 |
122 | def multi_process_single_gpu_evaluation(self):
123 | # start_eval_time = time.perf_counter()
124 | stride = int(np.ceil(self.ndata / self.threds))
125 |
126 | # start multi-process on single-gpu
127 | procs = []
128 | for d in range(self.threds):
129 | e_record = min((d + 1) * stride, self.ndata)
130 | shred_list = list(range(d * stride, e_record))
131 | device = self.devices[0]
132 | logger.info('Thread %d handle %d data.' % (d, len(shred_list)))
133 | p = self.context.Process(target=self.worker, args=(shred_list, device))
134 | procs.append(p)
135 |
136 | for p in procs:
137 | p.start()
138 |
139 | for p in procs:
140 | p.join()
141 |
142 | def multi_process_evaluation(self):
143 | nr_devices = len(self.devices)
144 | stride = int(np.ceil(self.ndata / nr_devices))
145 |
146 | # start multi-process on multi-gpu
147 | procs = []
148 | for d in range(nr_devices):
149 | e_record = min((d + 1) * stride, self.ndata)
150 | shred_list = list(range(d * stride, e_record))
151 | device = self.devices[d]
152 | logger.info('GPU %s handle %d data.' % (device, len(shred_list)))
153 | p = self.context.Process(target=self.worker, args=(shred_list, device))
154 | procs.append(p)
155 |
156 | for p in procs:
157 | p.start()
158 |
159 | for p in procs:
160 | p.join()
161 |
162 | def worker(self, shred_list, device):
163 | # start_load_time = time.time()
164 | # logger.info('Load Model on Device %d: %.2fs' % (device, time.time() - start_load_time))
165 | for idx in shred_list:
166 | dd = self.dataset[idx]
167 | results_dict = self.func_per_iteration(dd, device, iter=idx)
168 | self.results_queue.put(results_dict)
169 |
170 | def func_per_iteration(self, data, device, iter=None):
171 | raise NotImplementedError
172 |
173 | def compute_metric(self, results):
174 | raise NotImplementedError
175 |
176 | # evaluate the whole image at once
177 | def whole_eval(self, img, output_size, input_size=None, device=None):
178 | if input_size is not None:
179 | img, margin = self.process_image(img, input_size)
180 | else:
181 | img = self.process_image(img, input_size)
182 |
183 | pred = self.val_func_process(img, device)
184 | if input_size is not None:
185 | pred = pred[:, margin[0]:(pred.shape[1] - margin[1]), margin[2]:(pred.shape[2] - margin[3])]
186 | pred = pred.permute(1, 2, 0)
187 | pred = pred.cpu().numpy()
188 | if output_size is not None:
189 | pred = cv2.resize(pred,
190 | (output_size[1], output_size[0]),
191 | interpolation=cv2.INTER_LINEAR)
192 |
193 | pred = pred.argmax(2)
194 |
195 | return pred
196 |
197 | # slide the window to evaluate the image
198 | def sliding_eval(self, img, crop_size, stride_rate, device=None):
199 | ori_rows, ori_cols, c = img.shape
200 | processed_pred = np.zeros((ori_rows, ori_cols, self.class_num))
201 |
202 | for s in self.multi_scales:
203 | img_scale = cv2.resize(img, None, fx=s, fy=s,
204 | interpolation=cv2.INTER_LINEAR)
205 | new_rows, new_cols, _ = img_scale.shape
206 | processed_pred += self.scale_process(img_scale,
207 | (ori_rows, ori_cols),
208 | crop_size, stride_rate, device)
209 |
210 | pred = processed_pred.argmax(2)
211 |
212 | return pred
213 |
214 | def scale_process(self, img, ori_shape, crop_size, stride_rate,
215 | device=None):
216 | new_rows, new_cols, c = img.shape
217 | long_size = new_cols if new_cols > new_rows else new_rows
218 |
219 | if long_size <= crop_size:
220 | input_data, margin = self.process_image(img, crop_size)
221 | score = self.val_func_process(input_data, device)
222 | score = score[:, margin[0]:(score.shape[1] - margin[1]),
223 | margin[2]:(score.shape[2] - margin[3])]
224 | else:
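225 |             # slide an overlapping crop_size window (stride = crop_size *
226 |             # stride_rate); per-window scores accumulate into data_scale and
227 |             # count_scale tracks coverage (the averaging step is disabled below)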
225 | stride = int(np.ceil(crop_size * stride_rate))
226 | img_pad, margin = pad_image_to_shape(img, crop_size,
227 | cv2.BORDER_CONSTANT, value=0)
228 |
229 | pad_rows = img_pad.shape[0]
230 | pad_cols = img_pad.shape[1]
231 | r_grid = int(np.ceil((pad_rows - crop_size) / stride)) + 1
232 | c_grid = int(np.ceil((pad_cols - crop_size) / stride)) + 1
233 | data_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda(
234 | device)
235 | count_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda(
236 | device)
237 |
238 | for grid_yidx in range(r_grid):
239 | for grid_xidx in range(c_grid):
240 | s_x = grid_xidx * stride
241 | s_y = grid_yidx * stride
242 | e_x = min(s_x + crop_size, pad_cols)
243 | e_y = min(s_y + crop_size, pad_rows)
244 | s_x = e_x - crop_size
245 | s_y = e_y - crop_size
246 | img_sub = img_pad[s_y:e_y, s_x: e_x, :]
247 | count_scale[:, s_y: e_y, s_x: e_x] += 1
248 |
249 | input_data, tmargin = self.process_image(img_sub, crop_size)
250 | temp_score = self.val_func_process(input_data, device)
251 | temp_score = temp_score[:,
252 | tmargin[0]:(temp_score.shape[1] - tmargin[1]),
253 | tmargin[2]:(temp_score.shape[2] - tmargin[3])]
254 | data_scale[:, s_y: e_y, s_x: e_x] += temp_score
255 | # score = data_scale / count_scale
256 | score = data_scale
257 | score = score[:, margin[0]:(score.shape[1] - margin[1]),
258 | margin[2]:(score.shape[2] - margin[3])]
259 |
260 | score = score.permute(1, 2, 0)
261 | data_output = cv2.resize(score.cpu().numpy(),
262 | (ori_shape[1], ori_shape[0]),
263 | interpolation=cv2.INTER_LINEAR)
264 |
265 | return data_output
266 |
267 | def val_func_process(self, input_data, device=None):
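268 |         # One padded crop -> per-class scores. Optionally adds a horizontally
269 |         # flipped pass (test-time augmentation); the final torch.exp assumes
270 |         # the head emits log-probabilities.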
268 | input_data = np.ascontiguousarray(input_data[None, :, :, :], dtype=np.float32)
269 | input_data = torch.FloatTensor(input_data).cuda(device)
270 |
271 | with torch.cuda.device(input_data.get_device()):
272 | self.val_func.eval()
273 | self.val_func.to(input_data.get_device())
274 | with torch.no_grad():
275 | score = self.val_func(input_data, task='new_seg')[0]
276 |                 if isinstance(score, (tuple, list)) and len(score) > 1:
277 | score = score[self.out_idx]
278 | score = score[0] # a single image pass, ignore batch dim
279 |
280 | if self.is_flip:
281 | input_data = input_data.flip(-1)
282 |                     score_flip = self.val_func(input_data, task='new_seg')[0]
283 |                     if isinstance(score_flip, (tuple, list)) and len(score_flip) > 1:
284 |                         score_flip = score_flip[self.out_idx]
285 |                     score_flip = score_flip[0]
284 | score += score_flip.flip(-1)
285 | score = torch.exp(score)
286 | # score = score.data
287 |
288 | return score
289 |
290 | def process_image(self, img, crop_size=None):
291 | p_img = img
292 |
293 | if img.shape[2] < 3:
294 | im_b = p_img
295 | im_g = p_img
296 | im_r = p_img
297 | p_img = np.concatenate((im_b, im_g, im_r), axis=2)
298 |
299 | p_img = normalize(p_img, self.image_mean, self.image_std)
300 |
301 | if crop_size is not None:
302 | p_img, margin = pad_image_to_shape(p_img, crop_size, cv2.BORDER_CONSTANT, value=0)
303 | p_img = p_img.transpose(2, 0, 1)
304 |
305 | return p_img, margin
306 |
307 | p_img = p_img.transpose(2, 0, 1)
308 |
309 | return p_img
310 |
--------------------------------------------------------------------------------
/tools/seg_opr/metric.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import numpy as np
5 |
6 | np.seterr(divide='ignore', invalid='ignore')
7 |
8 |
9 | # voc cityscapes metric
10 | def hist_info(n_cl, pred, gt):
11 | assert (pred.shape == gt.shape), "pred: " + str(pred.shape) + " v.s. gt: " + str(gt.shape)
12 | k = (gt >= 0) & (gt < n_cl)
13 | labeled = np.sum(k)
14 | correct = np.sum((pred[k] == gt[k]))
15 |
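16 |     # Flatten each (gt, pred) pair to the single index gt * n_cl + pred so
17 |     # bincount builds the row-major confusion matrix: hist[i, j] counts
18 |     # pixels of ground-truth class i predicted as class j.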
16 | return np.bincount(n_cl * gt[k].astype(int) + pred[k].astype(int),
17 | minlength=n_cl ** 2).reshape(n_cl,
18 | n_cl), labeled, correct
19 |
20 |
21 | def compute_score(hist, correct, labeled):
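22 |     # per-class IoU = TP / (TP + FP + FN): diag(hist) is TP; row sum plus
23 |     # column sum minus the diagonal gives the union without double-counting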
22 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
23 | mean_IU = np.nanmean(iu)
24 | mean_IU_no_back = np.nanmean(iu[1:])
25 | mean_pixel_acc = correct / labeled
26 |
27 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc
28 |
29 |
30 | # ade metric
31 | def meanIoU(area_intersection, area_union):
32 | iou = 1.0 * np.sum(area_intersection, axis=1) / np.sum(area_union, axis=1)
33 | meaniou = np.nanmean(iou)
34 | meaniou_no_back = np.nanmean(iou[1:])
35 |
36 | return iou, meaniou, meaniou_no_back
37 |
38 |
39 | def intersectionAndUnion(imPred, imLab, numClass):
40 | # Remove classes from unlabeled pixels in gt image.
41 | # We should not penalize detections in unlabeled portions of the image.
42 | imPred = np.asarray(imPred).copy()
43 | imLab = np.asarray(imLab).copy()
44 |
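45 |     # shift ids by +1 so class 0 stays distinguishable from "unlabeled",
46 |     # then zero out predictions wherever the label is unlabeled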
45 | imPred += 1
46 | imLab += 1
47 | # Remove classes from unlabeled pixels in gt image.
48 | # We should not penalize detections in unlabeled portions of the image.
49 | imPred = imPred * (imLab > 0)
50 |
51 | # imPred = imPred * (imLab >= 0)
52 |
53 | # Compute area intersection:
54 | intersection = imPred * (imPred == imLab)
55 | (area_intersection, _) = np.histogram(intersection, bins=numClass,
56 | range=(1, numClass))
57 |
58 | # Compute area union:
59 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
60 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
61 | area_union = area_pred + area_lab - area_intersection
62 |
63 | return area_intersection, area_union
64 |
65 |
66 | def mean_pixel_accuracy(pixel_correct, pixel_labeled):
67 | mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (
68 | np.spacing(1) + np.sum(pixel_labeled))
69 |
70 | return mean_pixel_accuracy
71 |
72 |
73 | def pixelAccuracy(imPred, imLab):
74 | # Remove classes from unlabeled pixels in gt image.
75 | # We should not penalize detections in unlabeled portions of the image.
76 | pixel_labeled = np.sum(imLab >= 0)
77 | pixel_correct = np.sum((imPred == imLab) * (imLab >= 0))
78 | pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
79 |
80 | return pixel_accuracy, pixel_correct, pixel_labeled
81 |
82 |
83 | def accuracy(preds, label):
84 | valid = (label >= 0)
85 | acc_sum = (valid * (preds == label)).sum()
86 | valid_sum = valid.sum()
87 | acc = float(acc_sum) / (valid_sum + 1e-10)
88 | return acc, valid_sum
89 |
--------------------------------------------------------------------------------
/tools/utils/img_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import cv2
5 | import numpy as np
6 | import numbers
7 | import random
8 | import collections.abc
9 |
10 |
11 | def get_2dshape(shape, *, zero=True):
12 |     if not isinstance(shape, collections.abc.Iterable):
13 | shape = int(shape)
14 | shape = (shape, shape)
15 | else:
16 | h, w = map(int, shape)
17 | shape = (h, w)
18 | if zero:
19 | minv = 0
20 | else:
21 | minv = 1
22 |
23 | assert min(shape) >= minv, 'invalid shape: {}'.format(shape)
24 | return shape
25 |
26 |
27 | def random_crop_pad_to_shape(img, crop_pos, crop_size, pad_label_value):
28 | h, w = img.shape[:2]
29 | start_crop_h, start_crop_w = crop_pos
30 | assert ((start_crop_h < h) and (start_crop_h >= 0))
31 | assert ((start_crop_w < w) and (start_crop_w >= 0))
32 |
33 | crop_size = get_2dshape(crop_size)
34 | crop_h, crop_w = crop_size
35 |
36 | img_crop = img[start_crop_h:start_crop_h + crop_h,
37 | start_crop_w:start_crop_w + crop_w, ...]
38 |
39 | img_, margin = pad_image_to_shape(img_crop, crop_size, cv2.BORDER_CONSTANT,
40 | pad_label_value)
41 |
42 | return img_, margin
43 |
44 |
45 | def generate_random_crop_pos(ori_size, crop_size):
46 | ori_size = get_2dshape(ori_size)
47 | h, w = ori_size
48 |
49 | crop_size = get_2dshape(crop_size)
50 | crop_h, crop_w = crop_size
51 |
52 | pos_h, pos_w = 0, 0
53 |
54 |     if h > crop_h:
55 |         pos_h = random.randint(0, h - crop_h)
56 | 
57 |     if w > crop_w:
58 |         pos_w = random.randint(0, w - crop_w)
59 |
60 | return pos_h, pos_w
61 |
62 |
63 | def pad_image_to_shape(img, shape, border_mode, value):
64 | margin = np.zeros(4, np.uint32)
65 | shape = get_2dshape(shape)
66 | pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0
67 | pad_width = shape[1] - img.shape[1] if shape[1] - img.shape[1] > 0 else 0
68 |
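69 |     # margin layout: (top, bottom, left, right); any odd remainder of the
70 |     # padding goes to the bottom/right edge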
69 | margin[0] = pad_height // 2
70 | margin[1] = pad_height // 2 + pad_height % 2
71 | margin[2] = pad_width // 2
72 | margin[3] = pad_width // 2 + pad_width % 2
73 |
74 | img = cv2.copyMakeBorder(img, margin[0], margin[1], margin[2], margin[3],
75 | border_mode, value=value)
76 |
77 | return img, margin
78 |
79 |
80 | def pad_image_size_to_multiples_of(img, multiple, pad_value):
81 | h, w = img.shape[:2]
82 | d = multiple
83 |
84 | def canonicalize(s):
85 | v = s // d
86 | return (v + (v * d != s)) * d
87 |
88 | th, tw = map(canonicalize, (h, w))
89 |
90 | return pad_image_to_shape(img, (th, tw), cv2.BORDER_CONSTANT, pad_value)
91 |
92 |
93 | def resize_ensure_shortest_edge(img, edge_length,
94 | interpolation_mode=cv2.INTER_LINEAR):
95 | assert isinstance(edge_length, int) and edge_length > 0, edge_length
96 | h, w = img.shape[:2]
97 | if h < w:
98 | ratio = float(edge_length) / h
99 | th, tw = edge_length, max(1, int(ratio * w))
100 | else:
101 | ratio = float(edge_length) / w
102 | th, tw = max(1, int(ratio * h)), edge_length
103 |     img = cv2.resize(img, (tw, th), interpolation=interpolation_mode)
104 |
105 | return img
106 |
107 |
108 | def random_scale(img, gt, scales):
109 | # scale = random.choice(scales)
110 | scale = random.uniform(min(scales), max(scales))
111 | sh = int(img.shape[0] * scale)
112 | sw = int(img.shape[1] * scale)
113 | img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR)
114 | gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST)
115 |
116 | return img, gt, scale
117 |
118 |
119 | def random_scale_with_length(img, gt, length):
120 | size = random.choice(length)
121 | sh = size
122 | sw = size
123 | img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR)
124 | gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST)
125 |
126 | return img, gt, size
127 |
128 |
129 | def random_mirror(img, gt):
130 | if random.random() >= 0.5:
131 | img = cv2.flip(img, 1)
132 | gt = cv2.flip(gt, 1)
133 |
134 |     return img, gt
135 |
136 |
137 | def random_rotation(img, gt):
138 | angle = random.random() * 20 - 10
139 | h, w = img.shape[:2]
140 | rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
141 | img = cv2.warpAffine(img, rotation_matrix, (w, h), flags=cv2.INTER_LINEAR)
142 | gt = cv2.warpAffine(gt, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST)
143 |
144 | return img, gt
145 |
146 |
147 | def random_gaussian_blur(img):
148 | gauss_size = random.choice([1, 3, 5, 7])
149 | if gauss_size > 1:
150 | # do the gaussian blur
151 | img = cv2.GaussianBlur(img, (gauss_size, gauss_size), 0)
152 |
153 | return img
154 |
155 |
156 | def center_crop(img, shape):
157 | h, w = shape[0], shape[1]
158 | y = (img.shape[0] - h) // 2
159 | x = (img.shape[1] - w) // 2
160 | return img[y:y + h, x:x + w]
161 |
162 |
163 | def random_crop(img, gt, size):
164 | if isinstance(size, numbers.Number):
165 | size = (int(size), int(size))
166 |
167 | h, w = img.shape[:2]
168 | crop_h, crop_w = size[0], size[1]
169 |
170 |     if h > crop_h:
171 |         y = random.randint(0, h - crop_h)
172 |         img = img[y:y + crop_h, :, :]
173 |         gt = gt[y:y + crop_h, :]
174 | 
175 |     if w > crop_w:
176 |         x = random.randint(0, w - crop_w)
177 |         img = img[:, x:x + crop_w, :]
178 |         gt = gt[:, x:x + crop_w]
179 |
180 | return img, gt
181 |
182 |
183 | def normalize(img, mean, std):
184 | # pytorch pretrained model need the input range: 0-1
185 | img = img.astype(np.float32) / 255.0
186 | img = img - mean
187 | img = img / std
188 |
189 | return img
190 |
--------------------------------------------------------------------------------
/tools/utils/pyt_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | # encoding: utf-8
5 | import os
6 | import time
7 | import argparse
8 | from collections import OrderedDict
9 |
10 | import torch
11 |
12 | from tools.engine.logger import get_logger
13 |
14 | logger = get_logger()
15 |
16 | model_urls = {
17 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
18 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
19 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
20 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
21 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
22 | }
23 |
24 |
25 | def load_model(model, model_file, is_restore=False):
26 | t_start = time.time()
27 | if isinstance(model_file, str):
28 | state_dict = torch.load(model_file)
29 | if 'model' in state_dict.keys():
30 | state_dict = state_dict['model']
31 | else:
32 | state_dict = model_file
33 | t_ioend = time.time()
34 |
35 | if is_restore:
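36 |         # prepend 'module.' so a checkpoint saved from a bare model can be
37 |         # restored into a DataParallel-wrapped one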
36 | new_state_dict = OrderedDict()
37 | for k, v in state_dict.items():
38 | name = 'module.' + k
39 | new_state_dict[name] = v
40 | state_dict = new_state_dict
41 |
42 | model.load_state_dict(state_dict, strict=False)
43 | ckpt_keys = set(state_dict.keys())
44 | own_keys = set(model.state_dict().keys())
45 | missing_keys = own_keys - ckpt_keys
46 | unexpected_keys = ckpt_keys - own_keys
47 |
48 | if len(missing_keys) > 0:
49 | logger.warning('Missing key(s) in state_dict: {}'.format(
50 | ', '.join('{}'.format(k) for k in missing_keys)))
51 |
52 | if len(unexpected_keys) > 0:
53 | logger.warning('Unexpected key(s) in state_dict: {}'.format(
54 | ', '.join('{}'.format(k) for k in unexpected_keys)))
55 |
56 | del state_dict
57 | t_end = time.time()
58 | logger.info(
59 | "Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format(
60 | t_ioend - t_start, t_end - t_ioend))
61 |
62 | return model
63 |
64 |
65 | def parse_devices(input_devices):
66 | if input_devices.endswith('*'):
67 | devices = list(range(torch.cuda.device_count()))
68 | return devices
69 |
70 | devices = []
71 | for d in input_devices.split(','):
72 | if '-' in d:
73 | start_device, end_device = d.split('-')[0], d.split('-')[1]
74 | assert start_device != ''
75 | assert end_device != ''
76 | start_device, end_device = int(start_device), int(end_device)
77 | assert start_device < end_device
78 | assert end_device < torch.cuda.device_count()
79 | for sd in range(start_device, end_device + 1):
80 | devices.append(sd)
81 | else:
82 | device = int(d)
83 | assert device < torch.cuda.device_count()
84 | devices.append(device)
85 |
86 | logger.info('using devices {}'.format(
87 | ', '.join([str(d) for d in devices])))
88 |
89 | return devices
90 |
91 |
92 | def extant_file(x):
93 | """
94 | 'Type' for argparse - checks that file exists but does not open.
95 | """
96 | if not os.path.exists(x):
97 | # Argparse uses the ArgumentTypeError to give a rejection message like:
98 | # error: argument input: x does not exist
99 | raise argparse.ArgumentTypeError("{0} does not exist".format(x))
100 | return x
101 |
102 |
103 | def link_file(src, target):
104 |     # lexists also catches dangling symlinks left by a previous run
105 |     if os.path.lexists(target):
106 |         os.remove(target)
107 |     os.symlink(src, target)
107 |
108 |
109 | def ensure_dir(path):
110 | if not os.path.isdir(path):
111 | os.makedirs(path)
112 |
113 |
114 | def _dbg_interactive(var, value):
115 | from IPython import embed
116 | embed()
117 |
--------------------------------------------------------------------------------
/tools/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import scipy.io as sio
4 |
5 |
6 | def set_img_color(colors, background, img, gt, show255=False, weight_foreground=0.55):
7 | origin = np.array(img)
8 | for i in range(len(colors)):
9 | if i != background:
10 | img[np.where(gt == i)] = colors[i]
11 | if show255:
12 | img[np.where(gt == 255)] = 0
13 | cv2.addWeighted(img, weight_foreground, origin, 1 - weight_foreground, 0, img)
14 | return img
15 |
16 |
17 | def show_prediction(colors, background, img, pred):
18 | im = np.array(img, np.uint8)
19 | set_img_color(colors, background, im, pred, weight_foreground=1)
20 | final = np.array(im)
21 | return final
22 |
23 |
24 | def show_img(colors, background, img, clean, gt, *pds):
25 | im1 = np.array(img, np.uint8)
26 | # set_img_color(colors, background, im1, clean)
27 | final = np.array(im1)
28 | # the pivot black bar
29 | pivot = np.zeros((im1.shape[0], 15, 3), dtype=np.uint8)
30 | for pd in pds:
31 | im = np.array(img, np.uint8)
32 | # pd[np.where(gt == 255)] = 255
33 | set_img_color(colors, background, im, pd)
34 | final = np.column_stack((final, pivot))
35 | final = np.column_stack((final, im))
36 |
37 | im = np.array(img, np.uint8)
38 | set_img_color(colors, background, im, gt, True)
39 | final = np.column_stack((final, pivot))
40 | final = np.column_stack((final, im))
41 | return final
42 |
43 |
44 | def get_colors(class_num):
45 | colors = []
46 | for i in range(class_num):
47 | colors.append((np.random.random((1, 3)) * 255).tolist()[0])
48 |
49 | return colors
50 |
51 |
52 | def get_ade_colors():
53 | colors = sio.loadmat('./color150.mat')['colors']
54 | colors = colors[:, ::-1, ]
55 | colors = np.array(colors).astype(int).tolist()
56 | colors.insert(0, [0, 0, 0])
57 |
58 | return colors
59 |
60 |
61 | def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False,
62 | no_print=False):
63 | n = iu.size
64 | lines = []
65 | for i in range(n):
66 | if class_names is None:
67 | cls = 'Class %d:' % (i + 1)
68 | else:
69 | cls = '%d %s' % (i + 1, class_names[i])
70 | lines.append('%-8s\t%.3f%%' % (cls, iu[i] * 100))
71 | mean_IU = np.nanmean(iu)
72 | # mean_IU_no_back = np.nanmean(iu[1:])
73 | mean_IU_no_back = np.nanmean(iu[:-1])
74 | if show_no_back:
75 | lines.append(
76 | '---------------------------- %-8s\t%.3f%%\t%-8s\t%.3f%%\t%-8s\t%.3f%%' % (
77 | 'mean_IU', mean_IU * 100, 'mean_IU_no_back',
78 | mean_IU_no_back * 100,
79 | 'mean_pixel_ACC', mean_pixel_acc * 100))
80 | else:
82 | lines.append(
83 | '---------------------------- %-8s\t%.3f%%\t%-8s\t%.3f%%' % (
84 | 'mean_IU', mean_IU * 100, 'mean_pixel_ACC',
85 | mean_pixel_acc * 100))
86 | line = "\n".join(lines)
87 | if not no_print:
88 | print(line)
89 | return line
90 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import argparse
5 | import os
6 | import sys
7 | import logging
8 | import time
9 | from tqdm import tqdm
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel
13 | import torch.optim
14 | from torch.utils.data import DataLoader
15 | import torchvision.transforms as transforms
16 |
17 | from data.visda17 import VisDA17
18 | from data.loader_csg import TwoCropsTransform
19 | from model.resnet import resnet101
20 | from model import csg_builder
21 | from utils.utils import get_params, IterNums, save_checkpoint, AverageMeter, lr_poly, adjust_learning_rate, accuracy
22 | from utils.logger import prepare_logger, prepare_seed
23 | from utils.sgd import SGD
24 | from utils.augmentations import RandAugment, augment_list
25 |
26 | torch.backends.cudnn.enabled = True
27 | CrossEntropyLoss = nn.CrossEntropyLoss(reduction='mean')
28 |
29 | parser = argparse.ArgumentParser(description='PyTorch ResNet Training')
30 | parser.add_argument('--data', default='/home/chenwy/taskcv-2017-public/classification/data', help='path to dataset')
31 | parser.add_argument('--epochs', default=30, type=int, help='number of total epochs to run')
32 | parser.add_argument('--start-epoch', default=0, type=int, help='manual start epoch number (useful on restarts)')
33 | parser.add_argument('--batch-size', default=32, type=int, dest='batch_size', help='mini-batch size (default: 32)')
34 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, help='initial learning rate')
35 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, help='weight decay (default: 5e-4)')
36 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
37 | parser.add_argument('--csg', default=0.1, type=float, dest='csg', help="weight of CSG loss (default: 0.1).")
38 | parser.add_argument('--factor', default=0.1, type=float, dest='factor', help='scale factor of backbone learning rate (default: 0.1)')
39 | parser.add_argument('--csg-stages', dest='csg_stages', default='4', help='resnet stages to involve in CSG, 0~4, separated by dot')
40 | parser.add_argument('--chunks', dest='chunks', default='1', help='stage-wise chunks of feature maps, separated by dot')
41 | parser.add_argument('--no-mlp', dest='mlp', action='store_false', default=True, help='not to use mlp during contrastive learning')
42 | parser.add_argument('--apool', default=False, action='store_true', help='use A-Pool')
43 | parser.add_argument('--augment', action='store_true', default=False, help='use augmentation')
44 | parser.add_argument('--resume', default='none', type=str, help='path to latest checkpoint (default: none)')
45 | parser.add_argument('--num-class', default=12, type=int, dest='num_classes', help='the number of classes')
46 | parser.add_argument('--evaluate', action='store_true', help='only perform evaluation without training')
47 | parser.add_argument('--save_dir', type=str, default="./runs", help='root folder to save checkpoints and log.')
48 | parser.add_argument('--rand_seed', default=0, type=int, help='random seed')
49 | parser.add_argument('--csg-k', default=65536, type=int, help='queue size; number of negative keys (default: 65536)')
50 | parser.add_argument('--timestamp', type=str, default='none', help='timestamp for logging naming')
51 | parser.set_defaults(bottleneck=True)
52 |
53 | best_prec1 = 0
54 |
55 |
56 | def main():
57 | global args, best_prec1
58 | PID = os.getpid()
59 | args = parser.parse_args()
60 | prepare_seed(args.rand_seed)
61 |
62 | if args.timestamp == 'none':
63 | args.timestamp = "{:}".format(time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time())))
64 |
65 | # Log outputs
66 | if args.evaluate:
67 | args.save_dir = args.save_dir + "/Visda17-Res101-evaluate" + \
68 | "%s/%s"%('/'+args.resume.replace('/', '+') if args.resume != 'none' else '', args.timestamp)
69 | else:
70 | args.save_dir = args.save_dir + \
71 | "/VisDA-Res101-CSG.stg{csg_stages}.w{csg_weight}-APool.{apool}-Aug.{augment}-chunk{chunks}-mlp{mlp}.K{csg_k}-LR{lr}.bone{factor}-epoch{epochs}-batch{batch_size}-seed{seed}".format(
72 | csg_stages=args.csg_stages,
73 | mlp=args.mlp,
74 | csg_weight=args.csg,
75 | apool=args.apool,
76 | augment=args.augment,
77 | chunks=args.chunks,
78 | csg_k=args.csg_k,
79 | lr="%.2E"%args.lr,
80 | factor="%.1f"%args.factor,
81 | epochs=args.epochs,
82 | batch_size=args.batch_size,
83 | seed=args.rand_seed
84 | ) + \
85 | "%s/%s"%('/'+args.resume.replace('/', '+') if args.resume != 'none' else '', args.timestamp)
86 | logger = prepare_logger(args)
87 |
88 | data_transforms = {
89 | 'val': transforms.Compose([
90 | transforms.Resize(224),
91 | transforms.CenterCrop(224),
92 | transforms.ToTensor(),
93 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
94 | ]),
95 | }
96 | if args.augment:
97 | data_transforms['train'] = transforms.Compose([
98 | RandAugment(1, 6., augment_list),
99 | transforms.Resize(224),
100 | transforms.RandomCrop(224),
101 | transforms.ToTensor(),
102 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
103 | ])
104 | else:
105 | data_transforms['train'] = transforms.Compose([
106 | transforms.Resize(224),
107 | transforms.CenterCrop(224),
108 | transforms.ToTensor(),
109 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
110 | ])
111 |
112 | kwargs = {'num_workers': 20, 'pin_memory': True}
113 | if args.augment:
114 | # two source
115 | trainset = VisDA17(txt_file=os.path.join(args.data, "train/image_list.txt"), root_dir=os.path.join(args.data, "train"),
116 | transform=TwoCropsTransform(data_transforms['train'], data_transforms['train']))
117 | else:
118 | # one source
119 | trainset = VisDA17(txt_file=os.path.join(args.data, "train/image_list.txt"), root_dir=os.path.join(args.data, "train"), transform=data_transforms['train'])
120 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
121 | valset = VisDA17(txt_file=os.path.join(args.data, "validation/image_list.txt"), root_dir=os.path.join(args.data, "validation"), transform=data_transforms['val'], label_one_hot=True)
122 | val_loader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, **kwargs)
123 |
124 | args.stages = [int(stage) for stage in args.csg_stages.split('.')] if len(args.csg_stages) > 0 else []
125 | chunks = [int(chunk) for chunk in args.chunks.split('.')] if len(args.chunks) > 0 else []
126 | assert len(chunks) == 1 or len(chunks) == len(args.stages)
127 | if len(chunks) < len(args.stages):
128 | chunks = [chunks[0]] * len(args.stages)
129 |
130 | def get_head(num_ftrs, num_classes):
131 | _dim = 512
132 | return nn.Sequential(
133 | nn.Linear(num_ftrs, _dim),
134 | nn.ReLU(inplace=False),
135 | nn.Linear(_dim, num_classes),
136 | )
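137 |     # CSG couples the trainable query encoder with a second key encoder
138 |     # (MoCo-style): features from the selected stages (args.stages) are
139 |     # projected, optionally chunked and A-Pooled, then contrasted against a
140 |     # queue of args.csg_k negatives; see model/csg_builder.py.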
137 | model = csg_builder.CSG(
138 | resnet101, get_head=get_head, K=args.csg_k, stages=args.stages, chunks=chunks,
139 | apool=args.apool, mlp=args.mlp,
140 | )
141 |
142 | train_blocks = "conv1.bn1.layer1.layer2.layer3.layer4.fc"
143 | train_blocks = train_blocks.split('.')
144 | # Setup optimizer
145 | factor = args.factor
146 | sgd_in = []
147 | for name in train_blocks:
148 | if name != 'fc':
149 | sgd_in.append({'params': get_params(model.encoder_q, [name]), 'lr': factor*args.lr})
150 | else:
151 | # no update to fc but to fc_new
152 | sgd_in.append({'params': get_params(model.encoder_q, ["fc_new"]), 'lr': args.lr})
153 | if model.mlp:
154 | sgd_in.append({'params': get_params(model.encoder_q, ["fc_csg"]), 'lr': args.lr})
155 | base_lrs = [ group['lr'] for group in sgd_in ]
156 | optimizer = SGD(sgd_in, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
157 |
158 | # Optionally resume from a checkpoint
159 | if args.resume != 'none':
160 | if os.path.isfile(args.resume):
161 | print("=> loading checkpoint '{}'".format(args.resume))
162 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
163 | args.start_epoch = checkpoint['epoch']
164 | best_prec1 = checkpoint['best_prec1']
165 | msg = model.load_state_dict(checkpoint['state_dict'], strict=False)
166 | print("resume weights: ", msg)
167 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
168 | else:
169 |         print("=> no checkpoint found at '{}'".format(args.resume))
170 |
171 | model = model.cuda()
172 |
173 | if args.evaluate:
174 | prec1 = validate(val_loader, model, args, 0)
175 | print(prec1)
176 | exit(0)
177 |
178 | # Main training loop
179 | iter_max = args.epochs * len(train_loader)
180 | iter_stat = IterNums(iter_max)
181 | for epoch in range(args.start_epoch, args.epochs):
182 | print("<< ============== JOB (PID = %d) %s ============== >>"%(PID, args.save_dir))
183 | logger.log("Epoch: %d"%(epoch+1))
184 |         train(train_loader, model, optimizer, base_lrs, iter_stat, epoch, logger, args, adjust_lr=epoch < args.epochs)
185 | 
186 |         # evaluate on validation set
187 |         prec1 = validate(val_loader, model, args, epoch)
188 |         logger.writer.add_scalar("prec1", prec1, epoch)
189 | 
190 |         # remember the best prec@1 and save a checkpoint
191 |         is_best = prec1 > best_prec1
192 |         best_prec1 = max(prec1, best_prec1)
193 | save_checkpoint(args.save_dir, {
194 | 'epoch': epoch + 1,
195 | 'state_dict': model.state_dict(),
196 | 'best_prec1': best_prec1,
197 | }, is_best, keep_last=1)
198 |
199 | logging.info('Best accuracy: {prec1:.3f}'.format(prec1=best_prec1))
200 |
201 |
202 | def train(train_loader, model, optimizer, base_lrs, iter_stat, epoch, logger, args, adjust_lr=True):
203 | tb_interval = 50
204 |
205 | csg_weight = args.csg
206 |
207 | losses = AverageMeter() # loss on target task
208 | losses_csg = [AverageMeter() for _ in range(len(model.stages))] # [_loss] x #stages
209 | top1_csg = [AverageMeter() for _ in range(len(model.stages))]
210 |
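211 |     # eval mode keeps the pretrained backbone's BN statistics frozen;
212 |     # gradients still flow during CSG fine-tuning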
211 | model.eval()
212 |
213 | # train for one epoch
214 | optimizer.zero_grad()
215 | epoch_size = len(train_loader)
216 | train_loader_iter = iter(train_loader)
217 |
218 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
219 | pbar = tqdm(range(epoch_size), file=sys.stdout, bar_format=bar_format, ncols=80)
220 |
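221 |     # poly schedule: lr = base_lr * (1 - iter / iter_max) ** 0.9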
221 | lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9)
222 | logger.writer.add_scalar("lr", lr, epoch)
223 | logger.log("lr %f"%lr)
224 | for idx_iter in pbar:
225 | optimizer.zero_grad()
226 | if adjust_lr:
227 | lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9)
228 | adjust_learning_rate(base_lrs, optimizer, iter_stat.iter_curr, iter_stat.iter_max, 0.9)
229 |
230 | input, label = next(train_loader_iter)
231 | if args.augment:
232 | input_q = input[0].cuda()
233 | input_k = input[1].cuda()
234 | else:
235 | input_q = input.cuda()
236 | input_k = None
237 | label = label.cuda()
238 |
239 | results = model(input_q, input_k)
240 |
241 | # synthetic task
242 | loss = CrossEntropyLoss(results['output'], label.long())
243 | # measure accuracy and record loss
244 | losses.update(loss, label.size(0))
245 | for idx in range(len(model.stages)):
246 | _loss = 0
247 | acc1 = None
248 | # predictions: cosine b/w q and k
249 | # targets: zeros
250 | _loss = CrossEntropyLoss(results['predictions_csg'][idx], results['targets_csg'][idx])
251 | acc1, acc5 = accuracy_ranking(results['predictions_csg'][idx].data, results['targets_csg'][idx], topk=(1, 5))
252 | loss = loss + _loss * csg_weight
253 | # loss_csg[_type].append(_loss)
254 | if acc1 is not None: top1_csg[idx].update(acc1, label.size(0))
255 | # measure accuracy and record loss
256 | losses_csg[idx].update(_loss, label.size(0))
257 |
258 | loss.backward()
259 |
260 | # compute gradient and do SGD step
261 | optimizer.step()
262 | # increment iter number
263 | iter_stat.update()
264 |
265 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/ce", losses.val, idx_iter + epoch * epoch_size)
266 | description = "[XE %.3f]"%(losses.val)
267 | description += "[CSG "
268 | loss_str = ""
269 | acc_str = ""
270 | for idx, stage in enumerate(model.stages):
271 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/layer%d"%stage, losses_csg[idx].val, idx_iter + epoch * epoch_size)
272 | loss_str += "%.2f|"%losses_csg[idx].val
273 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("prec/layer%d"%stage, top1_csg[idx].val[0], idx_iter + epoch * epoch_size)
274 | acc_str += "%.1f|"%top1_csg[idx].val[0]
275 | description += "loss:%s ranking:%s]"%(loss_str[:-1], acc_str[:-1])
276 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/total", losses.val + sum([_loss.val for _loss in losses_csg]), idx_iter + epoch * epoch_size)
277 | pbar.set_description("[Step %d/%d][%s]"%(idx_iter + 1, epoch_size, str(csg_weight)) + description)
278 |
279 |
280 | def validate(val_loader, model, args, epoch):
281 | """Perform validation on the validation set"""
282 | top1 = AverageMeter()
283 |
284 | # switch to evaluate mode
285 | model.eval()
286 |
287 | val_size = len(val_loader)
288 | val_loader_iter = iter(val_loader)
289 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
290 | pbar = tqdm(range(val_size), file=sys.stdout, bar_format=bar_format, ncols=140)
291 | with torch.no_grad():
292 | for idx_iter in pbar:
293 | input, label = next(val_loader_iter)
294 |
295 | input = input.cuda()
296 | label = label.cuda()
297 |
298 | # compute output
299 | output, _ = model.encoder_q(input, task='new')
300 | output = torch.sigmoid(output)
301 | output = (output + torch.sigmoid(model.encoder_q(torch.flip(input, dims=(3,)), task='new')[0])) / 2
302 |
303 |             # accumulate accuracy
304 | prec1, gt_num = accuracy(output.data, label, args.num_classes, topk=(1,))
305 | top1.update(prec1[0], gt_num[0])
306 |
307 | description = "[Acc@1-mean: %.2f][Acc@1-cls: %s]"%(top1.vec2sca_avg, str(top1.avg.numpy().round(1)))
308 | pbar.set_description("[Step %d/%d]"%(idx_iter + 1, val_size) + description)
309 |
310 | logging.info(' * Prec@1 {top1.vec2sca_avg:.3f}'.format(top1=top1))
311 | logging.info(' * Prec@1 {top1.avg}'.format(top1=top1))
312 |
313 | return top1.vec2sca_avg
314 |
315 |
316 | def accuracy_ranking(output, target, topk=(1,)):
317 | """Computes the accuracy over the k top predictions for the specified values of k"""
318 | with torch.no_grad():
319 | maxk = max(topk)
320 | batch_size = target.size(0)
321 |
322 | _, pred = output.topk(maxk, 1, True, True)
323 | pred = pred.t()
324 | correct = pred.eq(target.view(1, -1).expand_as(pred))
325 |
326 | res = []
327 | for k in topk:
328 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
329 | res.append(correct_k.mul_(100.0 / batch_size))
330 | return res
331 |
332 |
333 | if __name__ == '__main__':
334 | main()
335 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | python train.py \
5 | --epochs 30 \
6 | --batch-size 32 \
7 | --lr 1e-4 \
8 | --rand_seed 0 \
9 | --csg 0.1 \
10 | --apool \
11 | --augment \
12 | --csg-stages 3.4 \
13 | --factor 0.1 \
14 | # --resume pretrained/csg_res101_vista17_best.pth.tar \
15 | # --evaluate
16 |
--------------------------------------------------------------------------------
/train_seg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import argparse
5 | import os
6 | import sys
7 | import logging
8 | import time
9 | import numpy as np
10 | from tqdm import tqdm
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.parallel
14 | import torch.optim
15 | from pdb import set_trace as bp
16 | from data.gta5 import GTA5
17 | from data.cityscapes import Cityscapes
18 | from model import csg_builder
19 | from model.deeplab import ResNet as deeplab
20 | from dataloader_seg import get_train_loader
21 | from eval_seg import SegEvaluator
22 | from utils.utils import get_params, IterNums, save_checkpoint, AverageMeter, lr_poly, adjust_learning_rate
23 | from utils.logger import prepare_logger, prepare_seed
24 | from utils.sgd import SGD
25 |
26 | torch.backends.cudnn.enabled = True
27 | CrossEntropyLoss = nn.CrossEntropyLoss(reduction='mean', ignore_index=255)
28 | KLDivLoss = nn.KLDivLoss(reduction='batchmean')
29 | best_mIoU = 0
30 |
31 | parser = argparse.ArgumentParser(description='PyTorch ResNet Training')
32 | parser.add_argument('--epochs', default=50, type=int, help='number of total epochs to run')
33 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
34 | parser.add_argument('--batch-size', default=6, type=int, dest='batch_size', help='mini-batch size (default: 6)')
35 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, help='initial learning rate')
36 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, help='weight decay (default: 5e-4)')
37 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
38 | parser.add_argument('--csg', default=75., type=float, dest='csg', help="weight of CSG loss (default: 75)")
39 | parser.add_argument('--switch-model', default='deeplab50', choices=["deeplab50", "deeplab101"], help='which model to use')
40 | parser.add_argument('--factor', default=0.1, type=float, dest='factor', help='scale factor of backbone learning rate (default: 0.1)')
41 | parser.add_argument('--csg-stages', dest='csg_stages', default='4', help='resnet stages to involve in CSG, 0~4, separated by dot')
42 | parser.add_argument('--chunks', dest='chunks', default='8', help='stage-wise chunks of feature maps, separated by dot')
43 | parser.add_argument('--no-mlp', dest='mlp', action='store_false', default=True, help='not to use mlp during contrastive learning')
44 | parser.add_argument('--apool', default=False, action='store_true', help='use A-Pool')
45 | parser.add_argument('--augment', action='store_true', default=False, help='use augmentation')
46 | parser.add_argument('--resume', default='none', type=str, help='path to latest checkpoint (default: none)')
47 | parser.add_argument('--num-class', default=19, type=int, dest='num_classes', help='the number of classes')
48 | parser.add_argument('--gpus', default=0, type=int, help='gpu to use')
49 | parser.add_argument('--evaluate', action='store_true', help='only perform evaluation without training')
50 | parser.add_argument('--save_dir', type=str, default="./runs", help='root folder to save checkpoints and log.')
51 | parser.add_argument('--rand_seed', default=0, type=int, help='random seed')
52 | parser.add_argument('--csg-k', default=65536, type=int, help='queue size; number of negative keys (default: 65536)')
53 | parser.add_argument('--timestamp', type=str, default='none', help='timestamp for logging naming')
54 | parser.set_defaults(bottleneck=True)
55 |
56 | best_mIoU = 0
57 |
58 |
59 | def main():
60 | global args, best_mIoU
61 | PID = os.getpid()
62 | args = parser.parse_args()
63 | prepare_seed(args.rand_seed)
64 | device = torch.device("cuda:"+str(args.gpus))
65 |
66 | if args.timestamp == 'none':
67 | args.timestamp = "{:}".format(time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time())))
68 |
69 | switch_model = args.switch_model
70 | assert switch_model in ["deeplab50", "deeplab101"]
71 |
72 | # Log outputs
73 | if args.evaluate:
74 | args.save_dir = args.save_dir + "/GTA5-%s-evaluate"%switch_model + \
75 | "%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp)
76 | else:
77 | args.save_dir = args.save_dir + \
78 | "/GTA5_512x512-{model}-LWF.stg{csg_stages}.w{csg_weight}-APool.{apool}-Aug.{augment}-chunk{chunks}-mlp{mlp}.K{csg_k}-LR{lr}.bone{factor}-epoch{epochs}-batch{batch_size}-seed{seed}".format(
79 | model=switch_model,
80 | csg_stages=args.csg_stages,
81 | mlp=args.mlp,
82 | csg_weight=args.csg,
83 | apool=args.apool,
84 | augment=args.augment,
85 | chunks=args.chunks,
86 | csg_k=args.csg_k,
87 | lr="%.2E"%args.lr,
88 | factor="%.1f"%args.factor,
89 | epochs=args.epochs,
90 | batch_size=args.batch_size,
91 | seed=args.rand_seed
92 | ) + \
93 | "%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp)
94 | logger = prepare_logger(args)
95 |
96 | from config_seg import config as data_setting
97 | data_setting.batch_size = args.batch_size
98 | train_loader = get_train_loader(data_setting, GTA5, test=False, augment=args.augment)
99 |
100 | args.stages = [int(stage) for stage in args.csg_stages.split('.')] if len(args.csg_stages) > 0 else []
101 | chunks = [int(chunk) for chunk in args.chunks.split('.')] if len(args.chunks) > 0 else []
102 | assert len(chunks) == 1 or len(chunks) == len(args.stages)
103 | if len(chunks) < len(args.stages):
104 | chunks = [chunks[0]] * len(args.stages)
105 |
106 | if switch_model == 'deeplab50':
107 | layers = [3, 4, 6, 3]
108 | elif switch_model == 'deeplab101':
109 | layers = [3, 4, 23, 3]
110 | model = csg_builder.CSG(deeplab, get_head=None, K=args.csg_k, stages=args.stages, chunks=chunks, task='new-seg',
111 | apool=args.apool, mlp=args.mlp,
112 | base_encoder_kwargs={'num_seg_classes': args.num_classes, 'layers': layers})
113 |
114 | threds = 3
115 | evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None), args.num_classes, np.array([0.485, 0.456, 0.406]),
116 | np.array([0.229, 0.224, 0.225]), model.encoder_q, [1, ], False, devices=args.gpus, config=data_setting, threds=threds,
117 | verbose=False, save_path=None, show_image=False) # just calculate mIoU, no prediction file is generated
118 | # verbose=False, save_path="./prediction_files", show_image=True, show_prediction=True) # generate prediction files
119 |
120 |
121 | # Setup optimizer
122 | factor = args.factor
123 | sgd_in = [
124 | {'params': get_params(model.encoder_q, ["conv1"]), 'lr': factor*args.lr},
125 | {'params': get_params(model.encoder_q, ["bn1"]), 'lr': factor*args.lr},
126 | {'params': get_params(model.encoder_q, ["layer1"]), 'lr': factor*args.lr},
127 | {'params': get_params(model.encoder_q, ["layer2"]), 'lr': factor*args.lr},
128 | {'params': get_params(model.encoder_q, ["layer3"]), 'lr': factor*args.lr},
129 | {'params': get_params(model.encoder_q, ["layer4"]), 'lr': factor*args.lr},
130 | {'params': get_params(model.encoder_q, ["fc_new"]), 'lr': args.lr},
131 | ]
132 | base_lrs = [ group['lr'] for group in sgd_in ]
133 | optimizer = SGD(sgd_in, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
134 |
135 | # Optionally resume from a checkpoint
136 | if args.resume != 'none':
137 | if os.path.isfile(args.resume):
138 | print("=> loading checkpoint '{}'".format(args.resume))
139 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
140 | args.start_epoch = checkpoint['epoch']
141 | best_mIoU = checkpoint['best_mIoU']
142 | msg = model.load_state_dict(checkpoint['state_dict'])
143 | print("resume weights: ", msg)
144 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
145 | else:
146 |             print("=> no checkpoint found at '{}'".format(args.resume))
147 |
148 | model = model.to(device)
149 |
150 | if args.evaluate:
151 | mIoU = validate(evaluator, model, -1)
152 | print(mIoU)
153 | exit(0)
154 |
155 | # Main training loop
156 | iter_max = args.epochs * len(train_loader)
157 | iter_stat = IterNums(iter_max)
158 | for epoch in range(args.start_epoch, args.epochs):
159 | print("<< ============== JOB (PID = %d) %s ============== >>"%(PID, args.save_dir))
160 | logger.log("Epoch: %d"%(epoch+1))
161 | # train for one epoch
162 |         train(args, train_loader, model, optimizer, base_lrs, iter_stat, epoch, logger, device, adjust_lr=epoch < args.epochs)
163 | 
164 |         # evaluate on validation set
165 |         mIoU = validate(evaluator, model, epoch)
166 |         logger.writer.add_scalar("mIoU", mIoU, epoch)
167 |         logger.log("mIoU %f"%mIoU)
168 | 
169 |         # remember the best mIoU and save a checkpoint (is_best lets
170 |         # save_checkpoint mark this epoch's weights as the best so far)
171 |         is_best = mIoU > best_mIoU
172 |         best_mIoU = max(mIoU, best_mIoU)
173 | save_checkpoint(args.save_dir, {
174 | 'epoch': epoch + 1,
175 | 'state_dict': model.state_dict(),
176 | 'best_mIoU': best_mIoU,
177 | }, is_best)
178 |
179 |     logging.info('Best mIoU: {mIoU:.3f}'.format(mIoU=best_mIoU))
180 |
181 |
182 | def train(args, train_loader, model, optimizer, base_lrs, iter_stat, epoch, logger, device, adjust_lr=True):
183 | tb_interval = 50
184 |
185 | csg_weight = args.csg
186 |
187 | """Train for one epoch on the training set"""
188 | losses = AverageMeter()
189 | losses_csg = [AverageMeter() for _ in range(len(model.stages))] # [_loss] x #stages
190 | top1_csg = [AverageMeter() for _ in range(len(model.stages))]
191 |
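192 |     # backbone stays in eval mode (frozen BN statistics); only the new
193 |     # segmentation head fc_new uses train-mode behavior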
192 | model.eval()
193 | model.encoder_q.fc_new.train()
194 |
195 | # train for one epoch
196 | optimizer.zero_grad()
197 | epoch_size = len(train_loader)
198 | train_loader_iter = iter(train_loader)
199 |
200 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
201 | pbar = tqdm(range(epoch_size), file=sys.stdout, bar_format=bar_format, ncols=80)
202 |
203 | lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9)
204 | logger.log("lr %f"%lr)
205 | for idx_iter in pbar:
206 |
207 | optimizer.zero_grad()
208 | if adjust_lr:
209 | lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9)
210 | adjust_learning_rate(base_lrs, optimizer, iter_stat.iter_curr, iter_stat.iter_max, 0.9)
211 |
212 | sample = next(train_loader_iter)
213 | label = sample['label'].to(device)
214 | input = sample['data']
215 | if args.augment:
216 | input_q = input.to(device)
217 | input_k = sample['img_k'].to(device)
218 | else:
219 | input_q = input.to(device)
220 | input_k = None
221 |
222 | # keys: output, predictions_csg, targets_csg
223 | results = model(input_q, input_k)
224 |
225 | # synthetic task
226 | loss = CrossEntropyLoss(results['output'], label.long())
227 | # measure accuracy and record loss
228 | losses.update(loss, label.size(0))
229 | for idx in range(len(model.stages)):
230 | _loss = 0
231 | acc1 = None
232 | # predictions: cosine b/w q and k
233 | # targets: zeros
234 | _loss = CrossEntropyLoss(results['predictions_csg'][idx], results['targets_csg'][idx])
235 | acc1, acc5 = accuracy_ranking(results['predictions_csg'][idx].data, results['targets_csg'][idx], topk=(1, 5))
236 | loss = loss + _loss * csg_weight
237 | if acc1 is not None: top1_csg[idx].update(acc1, label.size(0))
238 | # measure accuracy and record loss
239 | losses_csg[idx].update(_loss, label.size(0))
240 |
241 | loss.backward()
242 |
243 | # compute gradient and do SGD step
244 | optimizer.step()
245 | # increment iter number
246 | iter_stat.update()
247 |
248 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/ce", losses.val, idx_iter + epoch * epoch_size)
249 | description = "[XE %.3f]"%(losses.val)
250 | description += "[CSG "
251 | loss_str = ""
252 | acc_str = ""
253 | for idx, stage in enumerate(model.stages):
254 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/layer%d"%stage, losses_csg[idx].val, idx_iter + epoch * epoch_size)
255 | loss_str += "%.2f|"%losses_csg[idx].val
256 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("prec/layer%d"%stage, top1_csg[idx].val[0], idx_iter + epoch * epoch_size)
257 | acc_str += "%.1f|"%top1_csg[idx].val[0]
258 | description += "loss:%s ranking:%s]"%(loss_str[:-1], acc_str[:-1])
259 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/total", losses.val + sum([_loss.val for _loss in losses_csg]), idx_iter + epoch * epoch_size)
260 | pbar.set_description("[Step %d/%d][%s]"%(idx_iter + 1, epoch_size, str(csg_weight)) + description)
261 |
262 |
263 | def validate(evaluator, model, epoch):
264 | with torch.no_grad():
265 | model.eval()
266 | # _, mIoU = evaluator.run_online()
267 | _, mIoU = evaluator.run_online_multiprocess()
268 | return mIoU
269 |
270 |
271 | def accuracy_ranking(output, target, topk=(1,)):
272 | """Computes the accuracy over the k top predictions for the specified values of k"""
273 | with torch.no_grad():
274 | maxk = max(topk)
275 | batch_size = target.size(0)
276 |
277 | _, pred = output.topk(maxk, 1, True, True)
278 | pred = pred.t()
279 | correct = pred.eq(target.view(1, -1).expand_as(pred))
280 |
281 | res = []
282 | for k in topk:
283 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
284 | res.append(correct_k.mul_(100.0 / batch_size))
285 | return res
286 |
287 |
288 | if __name__ == '__main__':
289 | main()
290 |
--------------------------------------------------------------------------------
/train_seg.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | python train_seg.py \
5 | --epochs 50 \
6 | --switch-model deeplab101 \
7 | --batch-size 6 \
8 | --lr 1e-3 \
9 | --num-class 19 \
10 | --gpus 0 \
11 | --factor 0.1 \
12 | --csg 75 \
13 | --apool \
14 | --csg-stages 3.4 \
15 | --chunks 8 \
16 | --augment \
17 | --evaluate \
18 | --resume pretrained/csg_res101_segmentation_best.pth.tar \
19 | # --resume pretrained/csg_res50_segmentation_best.pth.tar \
20 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
--------------------------------------------------------------------------------
/utils/augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import random
5 |
6 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 |
11 |
12 | def ShearX(img, v): # [-0.3, 0.3]
13 | assert -0.3 <= v <= 0.3
14 | if random.random() > 0.5:
15 | v = -v
16 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
17 |
18 |
19 | def ShearY(img, v): # [-0.3, 0.3]
20 | assert -0.3 <= v <= 0.3
21 | if random.random() > 0.5:
22 | v = -v
23 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
24 |
25 |
26 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
27 | assert -0.45 <= v <= 0.45
28 | if random.random() > 0.5:
29 | v = -v
30 | v = v * img.size[0]
31 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
32 |
33 |
34 | def TranslateXabs(img, v): # absolute pixels; augment_list uses [0, 100]
35 | assert 0 <= v
36 | if random.random() > 0.5:
37 | v = -v
38 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
39 |
40 |
41 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
42 | assert -0.45 <= v <= 0.45
43 | if random.random() > 0.5:
44 | v = -v
45 | v = v * img.size[1]
46 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
47 |
48 |
49 | def TranslateYabs(img, v): # absolute pixels; augment_list uses [0, 100]
50 | assert 0 <= v
51 | if random.random() > 0.5:
52 | v = -v
53 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
54 |
55 |
56 | def Rotate(img, v): # [-30, 30]
57 | assert -30 <= v <= 30
58 | if random.random() > 0.5:
59 | v = -v
60 | return img.rotate(v)
61 |
62 |
63 | def AutoContrast(img, v):
64 | if random.random() <= v:
65 | return PIL.ImageOps.autocontrast(img)
66 | else:
67 | return img
68 |
69 |
70 | def Invert(img, v):
71 | if random.random() <= v:
72 | return PIL.ImageOps.invert(img)
73 | else:
74 | return img
75 |
76 |
77 | def Equalize(img, v):
78 | if random.random() <= v:
79 | return PIL.ImageOps.equalize(img)
80 | else:
81 | return img
82 |
83 |
84 | def Flip(img, _): # not from the paper
85 | return PIL.ImageOps.mirror(img)
86 |
87 |
88 | def Solarize(img, v): # [0, 256]
89 | assert 0 <= v <= 256
90 | v = 256 - v
91 | return PIL.ImageOps.solarize(img, v)
92 |
93 |
94 | def SolarizeAdd(img, addition=0, threshold=128):
95 |     img_np = np.array(img).astype(int)
96 | img_np = img_np + addition
97 | img_np = np.clip(img_np, 0, 255)
98 | img_np = img_np.astype(np.uint8)
99 | img = Image.fromarray(img_np)
100 | return PIL.ImageOps.solarize(img, threshold)
101 |
102 |
103 | def Posterize(img, v): # v in [0, 7]: keeps 8 - int(v) bits (min 1)
104 | assert 0 <= v <= 7
105 | # v = int(v)
106 | v = 8 - int(v)
107 | v = max(1, v)
108 | return PIL.ImageOps.posterize(img, v)
109 |
110 |
111 | def Contrast(img, v): # [0.,0.9]
112 | # A factor of 1.0 gives the original image.
113 | assert 0. <= v <= 0.9
114 | if random.random() > 0.5:
115 | v = -v
116 | return PIL.ImageEnhance.Contrast(img).enhance(v+1) # 0.1 to 1.9
117 |
118 |
119 | def Color(img, v): # [0.,0.9]
120 | # A factor of 1.0 gives the original image.
121 | assert 0. <= v <= 0.9
122 | if random.random() > 0.5:
123 | v = -v
124 | return PIL.ImageEnhance.Color(img).enhance(v+1) # 0.1 to 1.9
125 |
126 |
127 | def Brightness(img, v): # [0.,0.9]
128 | # A factor of 1.0 gives the original image.
129 | assert 0. <= v <= 0.9
130 | if random.random() > 0.5:
131 | v = -v
132 | return PIL.ImageEnhance.Brightness(img).enhance(v+1) # 0.1 to 1.9
133 |
134 |
135 | def Sharpness(img, v): # [0.,.9]
136 | assert 0. <= v <= 0.9
137 | if random.random() > 0.5:
138 | v = -v
139 | return PIL.ImageEnhance.Sharpness(img).enhance(v+1) # 0.1 to 1.9
140 |
141 |
142 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
143 | assert 0.0 <= v <= 0.2
144 | if v <= 0.:
145 | return img
146 |
147 | v = v * img.size[0]
148 | return CutoutAbs(img, v)
149 |
150 |
151 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
152 | # assert 0 <= v <= 20
153 | if v < 0:
154 | return img
155 | w, h = img.size
156 | x0 = np.random.uniform(w)
157 | y0 = np.random.uniform(h)
158 |
159 | x0 = int(max(0, x0 - v / 2.))
160 | y0 = int(max(0, y0 - v / 2.))
161 | x1 = min(w, x0 + v)
162 | y1 = min(h, y0 + v)
163 |
164 | xy = (x0, y0, x1, y1)
165 | color = (125, 123, 114)
166 | # color = (0, 0, 0)
167 | img = img.copy()
168 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
169 | return img
170 |
171 |
172 | def SamplePairing(imgs): # [0, 0.4]
173 | def f(img1, v):
174 | i = np.random.choice(len(imgs))
175 | img2 = PIL.Image.fromarray(imgs[i])
176 | return PIL.Image.blend(img1, img2, v)
177 |
178 | return f
179 |
180 |
181 | def Identity(img, v):
182 | return img
183 |
184 |
185 | augment_list = [
186 | (Identity, 0., 1.0),
187 | (AutoContrast, 0, 1),
188 | (Equalize, 0, 1),
189 | (Rotate, 0, 30),
190 | (Posterize, 0, 7),
191 | (Solarize, 0, 256),
192 | (Color, 0., 0.9),
193 | (Contrast, 0., 0.9),
194 | (Brightness, 0., 0.9),
195 | (Sharpness, 0., 0.9),
196 | (ShearX, 0., 0.3),
197 | (ShearY, 0., 0.3),
198 | (TranslateXabs, 0., 100),
199 | (TranslateYabs, 0., 100),
200 | ]
201 |
202 |
203 | class Lighting(object):
204 | """Lighting noise(AlexNet - style PCA - based noise)"""
205 |
206 | def __init__(self, alphastd, eigval, eigvec):
207 | self.alphastd = alphastd
208 | self.eigval = torch.Tensor(eigval)
209 | self.eigvec = torch.Tensor(eigvec)
210 |
211 | def __call__(self, img):
212 | if self.alphastd == 0:
213 | return img
214 |
215 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
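216 |         # weight each PCA eigenvector by its eigenvalue and a random alpha,
217 |         # then add the resulting per-channel RGB shift to the whole image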
216 | rgb = self.eigvec.type_as(img).clone() \
217 | .mul(alpha.view(1, 3).expand(3, 3)) \
218 | .mul(self.eigval.view(1, 3).expand(3, 3)) \
219 | .sum(1).squeeze()
220 |
221 | return img.add(rgb.view(3, 1, 1).expand_as(img))
222 |
223 |
224 | class CutoutDefault(object):
225 | def __init__(self, length):
226 | self.length = length
227 |
228 | def __call__(self, img):
229 | h, w = img.size(1), img.size(2)
230 | mask = np.ones((h, w), np.float32)
231 | y = np.random.randint(h)
232 | x = np.random.randint(w)
233 |
234 | y1 = np.clip(y - self.length // 2, 0, h)
235 | y2 = np.clip(y + self.length // 2, 0, h)
236 | x1 = np.clip(x - self.length // 2, 0, w)
237 | x2 = np.clip(x + self.length // 2, 0, w)
238 |
239 | mask[y1: y2, x1: x2] = 0.
240 | mask = torch.from_numpy(mask)
241 | mask = mask.expand_as(img)
242 | img *= mask
243 | return img
244 |
245 |
246 | class RandAugment:
247 | def __init__(self, n, m, augment_list):
248 | self.n = n
249 | self.m = m # [0, 30]
250 | assert 0 <= m <= 30
251 | self.augment_list = augment_list
252 |
253 | def __call__(self, img):
254 | ops = random.choices(self.augment_list, k=self.n)
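255 |         # map the global magnitude m in [0, 30] linearly into each op's own
256 |         # [minval, maxval] range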
255 | for op, minval, maxval in ops:
256 | val = (float(self.m) / 30) * float(maxval - minval) + minval
257 | img = op(img, val)
258 |
259 | return img
260 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | from pathlib import Path
5 | import importlib, warnings
6 | import os, sys, numpy as np
7 | import torch, random, PIL, copy
8 | import glob
9 | import shutil
10 | if sys.version_info.major == 2: # Python 2.x
11 | from StringIO import StringIO as BIO
12 | else: # Python 3.x
13 | from io import BytesIO as BIO
14 | from torch.utils.tensorboard import SummaryWriter
15 | if importlib.util.find_spec('tensorflow'):
16 | import tensorflow as tf
17 |
18 |
19 | def prepare_seed(rand_seed):
20 | random.seed(rand_seed)
21 | np.random.seed(rand_seed)
22 | torch.manual_seed(rand_seed)
23 | torch.cuda.manual_seed(rand_seed)
24 | torch.cuda.manual_seed_all(rand_seed)
25 |
26 |
27 | def prepare_logger(xargs):
28 | args = copy.deepcopy( xargs )
29 | logger = Logger(args.save_dir, args.rand_seed)
30 | logger.log('Main Function with logger : {:}'.format(logger))
31 | logger.log('Arguments : -------------------------------')
32 | for name, value in args._get_kwargs():
33 | logger.log('{:16} : {:}'.format(name, value))
34 | logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
35 | logger.log("Pillow Version : {:}".format(PIL.__version__))
36 | logger.log("PyTorch Version : {:}".format(torch.__version__))
37 | logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
38 | logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
39 | logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
40 | logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
41 | return logger
42 |
43 |
44 | class PrintLogger(object):
45 |
46 | def __init__(self):
47 | """Create a summary writer logging to log_dir."""
48 | self.name = 'PrintLogger'
49 |
50 | def log(self, string):
51 | print (string)
52 |
53 | def close(self):
54 | print ('-'*30 + ' close printer ' + '-'*30)
55 |
56 |
57 | class Logger(object):
58 |
59 | def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
60 | """Create a summary writer logging to log_dir."""
61 | self.seed = int(seed)
62 | self.log_dir = Path(log_dir)
63 | self.model_dir = Path(log_dir) / 'model'
64 | self.log_dir.mkdir (parents=True, exist_ok=True)
65 |
66 | self.use_tf = bool(use_tf)
67 | self.tensorboard_dir = self.log_dir
68 | self.logger_path = self.log_dir / 'seed-{:}.log'.format(self.seed)
69 | self.logger_file = open(self.logger_path, 'w')
70 |
71 | self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
72 | self.writer = SummaryWriter(str(self.tensorboard_dir))
73 |
74 | scripts_to_save=glob.glob('*.py')+glob.glob('*.sh')
75 | os.mkdir(os.path.join(log_dir, 'scripts'))
76 | for script in scripts_to_save:
77 | dst_file = os.path.join(log_dir, 'scripts', os.path.basename(script))
78 | shutil.copyfile(script, dst_file)
79 | shutil.make_archive(os.path.join(log_dir, "scripts"), 'zip', log_dir, "scripts")
80 | shutil.rmtree(os.path.join(log_dir, 'scripts'))
81 |
82 | def __repr__(self):
83 | return ('{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__))
84 |
85 | def path(self, mode):
86 | valids = ('model', 'best', 'info', 'log')
87 | if mode == 'model': return self.model_dir / 'seed-{:}-basic.pth'.format(self.seed)
88 | elif mode == 'best' : return self.model_dir / 'seed-{:}-best.pth'.format(self.seed)
89 | elif mode == 'info' : return self.log_dir / 'seed-{:}-last-info.pth'.format(self.seed)
90 | elif mode == 'log' : return self.log_dir
91 |     else: raise TypeError('Unknown mode = {:}, valid modes = {:}'.format(mode, valids))
92 |
93 | def extract_log(self):
94 | return self.logger_file
95 |
96 | def close(self):
97 | self.logger_file.close()
98 | if self.writer is not None:
99 | self.writer.close()
100 |
101 | def log(self, string, save=True, stdout=False):
102 | if stdout:
103 | sys.stdout.write(string); sys.stdout.flush()
104 | else:
105 | print (string)
106 | if save:
107 | self.logger_file.write('{:}\n'.format(string))
108 | self.logger_file.flush()
109 |
110 | def scalar_summary(self, tags, values, step):
111 | """Log a scalar variable."""
112 | if not self.use_tf:
113 |       warnings.warn('use_tf is disabled but scalar_summary was called')
114 | else:
115 | assert isinstance(tags, list) == isinstance(values, list), 'Type : {:} vs {:}'.format(type(tags), type(values))
116 | if not isinstance(tags, list):
117 | tags, values = [tags], [values]
118 | for tag, value in zip(tags, values):
119 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
120 | self.writer.add_summary(summary, step)
121 | self.writer.flush()
122 |
123 | def image_summary(self, tag, images, step):
124 | """Log a list of images."""
125 | import scipy
126 | if not self.use_tf:
127 |       warnings.warn('use_tf is disabled but image_summary was called')
128 | return
129 |
130 | img_summaries = []
131 | for i, img in enumerate(images):
132 | # Write the image to a string
133 |       # BIO is StringIO on Python 2 and BytesIO on Python 3
134 |       # (imported at the top of this file); scipy.misc.toimage
135 |       # below requires scipy < 1.2
136 |       s = BIO()
137 | scipy.misc.toimage(img).save(s, format="png")
138 |
139 | # Create an Image object
140 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
141 | height=img.shape[0],
142 | width=img.shape[1])
143 | # Create a Summary value
144 | img_summaries.append(tf.Summary.Value(tag='{}/{}'.format(tag, i), image=img_sum))
145 |
146 | # Create and write Summary
147 | summary = tf.Summary(value=img_summaries)
148 | self.writer.add_summary(summary, step)
149 | self.writer.flush()
150 |
151 | def histo_summary(self, tag, values, step, bins=1000):
152 | """Log a histogram of the tensor of values."""
153 |     if not self.use_tf: raise ValueError('use_tf is disabled; histo_summary requires tensorflow')
154 | import tensorflow as tf
155 |
156 | # Create a histogram using numpy
157 | counts, bin_edges = np.histogram(values, bins=bins)
158 |
159 | # Fill the fields of the histogram proto
160 | hist = tf.HistogramProto()
161 | hist.min = float(np.min(values))
162 | hist.max = float(np.max(values))
163 | hist.num = int(np.prod(values.shape))
164 | hist.sum = float(np.sum(values))
165 | hist.sum_squares = float(np.sum(values**2))
166 |
167 | # Drop the start of the first bin
168 | bin_edges = bin_edges[1:]
169 |
170 | # Add bin edges and counts
171 | for edge in bin_edges:
172 | hist.bucket_limit.append(edge)
173 | for c in counts:
174 | hist.bucket.append(c)
175 |
176 | # Create and write Summary
177 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
178 | self.writer.add_summary(summary, step)
179 | self.writer.flush()
180 |
--------------------------------------------------------------------------------
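
A sketch of how prepare_seed/prepare_logger above might be driven (illustrative, not part of the repo; save_dir and rand_seed are the only attributes prepare_logger reads from args, and the values here are placeholders):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', default='runs/demo')
    parser.add_argument('--rand_seed', type=int, default=0)
    args = parser.parse_args()

    prepare_seed(args.rand_seed)     # seeds python, numpy and torch RNGs
    logger = prepare_logger(args)    # logs all parsed args plus environment info
    logger.log('training started')
    # scalars can go straight through the TensorBoard SummaryWriter
    logger.writer.add_scalar('train/loss', 0.5, 1)
    logger.close()
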
/utils/sgd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer, required
3 |
4 |
5 | # fixed SGD
6 | # See Note here: https://pytorch.org/docs/stable/optim.html#torch.optim.SGD
7 | class SGD(Optimizer):
8 | r"""Implements stochastic gradient descent (optionally with momentum).
9 |
10 | Nesterov momentum is based on the formula from
11 | `On the importance of initialization and momentum in deep learning`__.
12 |
13 | Args:
14 | params (iterable): iterable of parameters to optimize or dicts defining
15 | parameter groups
16 | lr (float): learning rate
17 | momentum (float, optional): momentum factor (default: 0)
18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
19 | dampening (float, optional): dampening for momentum (default: 0)
20 | nesterov (bool, optional): enables Nesterov momentum (default: False)
21 |
22 | Example:
23 |         >>> optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
24 | >>> optimizer.zero_grad()
25 | >>> loss_fn(model(input), target).backward()
26 | >>> optimizer.step()
27 |
28 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
29 |
30 | .. note::
31 | The implementation of SGD with Momentum/Nesterov subtly differs from
32 | Sutskever et. al. and implementations in some other frameworks.
33 |
34 | Considering the specific case of Momentum, the update can be written as
35 |
36 | .. math::
37 | v = \rho * v + g \\
38 | p = p - lr * v
39 |
40 | where p, g, v and :math:`\rho` denote the parameters, gradient,
41 | velocity, and momentum respectively.
42 |
43 | This is in contrast to Sutskever et. al. and
44 | other frameworks which employ an update of the form
45 |
46 | .. math::
47 | v = \rho * v + lr * g \\
48 | p = p - v
49 |
50 | The Nesterov version is analogously modified.
51 | """
52 |
53 | def __init__(self, params, lr=required, momentum=0, dampening=0,
54 | weight_decay=0, nesterov=False):
55 | if lr is not required and lr < 0.0:
56 | raise ValueError("Invalid learning rate: {}".format(lr))
57 | if momentum < 0.0:
58 | raise ValueError("Invalid momentum value: {}".format(momentum))
59 | if weight_decay < 0.0:
60 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
61 |
62 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
63 | weight_decay=weight_decay, nesterov=nesterov)
64 | if nesterov and (momentum <= 0 or dampening != 0):
65 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
66 | super(SGD, self).__init__(params, defaults)
67 |
68 | def __setstate__(self, state):
69 | super(SGD, self).__setstate__(state)
70 | for group in self.param_groups:
71 | group.setdefault('nesterov', False)
72 |
73 | def step(self, closure=None):
74 | """Performs a single optimization step.
75 |
76 | Arguments:
77 | closure (callable, optional): A closure that reevaluates the model
78 | and returns the loss.
79 | """
80 | loss = None
81 | if closure is not None:
82 | loss = closure()
83 |
84 | for group in self.param_groups:
85 | weight_decay = group['weight_decay']
86 | momentum = group['momentum']
87 | dampening = group['dampening']
88 | nesterov = group['nesterov']
89 |
90 | for p in group['params']:
91 | if p.grad is None:
92 | continue
93 | d_p = p.grad.data
94 | if weight_decay != 0:
95 | d_p.add_(weight_decay, p.data)
96 | if momentum != 0:
97 | param_state = self.state[p]
98 | if 'momentum_buffer' not in param_state:
99 | # buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
100 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach().mul_(group['lr'])
101 | else:
102 | buf = param_state['momentum_buffer']
103 | # buf.mul_(momentum).add_(1 - dampening, d_p)
104 | buf.mul_(momentum).add_(1 - dampening, d_p.mul_(group['lr']))
105 | if nesterov:
106 | d_p = d_p.add(momentum, buf)
107 | else:
108 | d_p = buf
109 |
110 | # p.data.add_(-group['lr'], d_p)
111 | p.data.add_(-1, d_p)
112 |
113 | return loss
114 |
--------------------------------------------------------------------------------
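
The practical effect of this "fixed" SGD is that it folds the learning rate into the momentum buffer (v = momentum * v + lr * g; p = p - v, the Sutskever form described in the note above), so decaying the learning rate mid-training only shrinks new gradient contributions instead of rescaling the accumulated velocity. A drop-in usage sketch (illustrative; the model and data are placeholders):

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    inputs, targets = torch.randn(4, 10), torch.randn(4, 2)
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
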
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import glob
5 | import os
6 | import shutil
7 | import numpy as np
8 | import torch
9 | from pdb import set_trace as bp
10 | from PIL import Image
11 | import matplotlib.cm as mpl_color_map
12 | import copy
13 |
14 |
15 | def get_params(model, layers=["layer4"]):
16 | """
17 | This generator returns all the parameters of the net except for
18 | the last classification layer. Note that for each batchnorm layer,
19 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
20 | any batchnorm parameter
21 | """
22 | if isinstance(layers, str):
23 | layers = [layers]
24 | b = []
25 | for layer in layers:
26 | b.append(getattr(model, layer))
27 |
28 | for i in range(len(b)):
29 | for k, v in b[i].named_parameters():
30 | if v.requires_grad:
31 | yield v
32 |
33 |
34 | def adjust_learning_rate_exp(optimizer, power=0.746):
35 | """Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs"""
36 | num_groups = len(optimizer.param_groups)
37 | for g in range(num_groups):
38 | optimizer.param_groups[g]['lr'] *= power
39 |
40 |
41 | def adjust_learning_rate(base_lrs, optimizer, iter_curr, iter_max, power):
42 | """Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs"""
43 | num_groups = len(optimizer.param_groups)
44 | for g in range(num_groups):
45 | optimizer.param_groups[g]['lr'] = lr_poly(base_lrs[g], iter_curr, iter_max, power)
46 |
47 |
48 | def lr_poly(base_lr, iter, max_iter, power):
49 | return min(0.01+0.99*(float(iter)/100)**2.0, 1.0) * base_lr * ((1-float(iter)/max_iter)**power) # This is with warm up
50 | # return min(0.01+0.99*(float(iter)/100)**2.0, 1.0) * base_lr * ((1-min(float(iter)/max_iter, 0.8))**power) # This is with warm up & no smaller than last 20% LR
51 | # return base_lr * ((1-float(iter)/max_iter)**power)
52 |
53 |
54 | def save_checkpoint(name, state, is_best, filename='checkpoint.pth.tar', keep_last=1):
55 | """Saves checkpoint to disk"""
56 | directory = name
57 | if not os.path.exists(directory):
58 | os.makedirs(directory)
59 | models_paths = list(filter(os.path.isfile, glob.glob(directory + "/epoch*.pth.tar")))
60 | models_paths.sort(key=os.path.getmtime, reverse=False)
61 |     if len(models_paths) >= keep_last:  # prune oldest so at most keep_last checkpoints remain
62 | for i in range(len(models_paths) + 1 - keep_last):
63 | os.remove(models_paths[i])
64 | torch.save(state, directory + '/epoch_'+str(state['epoch']) + '_' + filename)
65 | filename = directory + '/latest_' + filename
66 | torch.save(state, filename)
67 | if is_best:
68 | shutil.copyfile(filename, '%s/'%(name) + 'model_best.pth.tar')
69 |
70 |
71 | class IterNums(object):
72 | def __init__(self, iter_max):
73 | self.iter_max = iter_max
74 | self.iter_curr = 0
75 |
76 | def reset(self):
77 | self.iter_curr = 0
78 |
79 | def update(self):
80 | self.iter_curr += 1
81 |
82 |
83 | class AverageMeter(object):
84 | """Computes and stores the average and current value"""
85 | def __init__(self):
86 | self.reset()
87 |
88 | def reset(self):
89 | self.val = 0
90 | self.avg = 0
91 | self.sum = 0
92 | self.count = 0
93 | self.vec2sca_avg = 0
94 | self.vec2sca_val = 0
95 |
96 | def update(self, val, n=1):
97 | self.val = val
98 | self.sum += val * n
99 | self.count += n
100 | self.avg = self.sum / self.count
101 | if torch.is_tensor(self.val) and torch.numel(self.val) != 1:
102 | self.avg[self.count == 0] = 0
103 | self.vec2sca_avg = self.avg.sum() / len(self.avg)
104 | self.vec2sca_val = self.val.sum() / len(self.val)
105 |
106 |
107 | def accuracy(output, label, num_class, topk=(1,)):
108 | """Computes the precision@k for the specified values of k, currently only k=1 is supported"""
109 | maxk = max(topk)
110 |
111 | _, pred = output.topk(maxk, 1, True, True)
112 | if len(label.size()) == 2:
113 | # one_hot label
114 | _, gt = label.topk(maxk, 1, True, True)
115 | else:
116 | gt = label
117 | pred = pred.t()
118 | pred_class_idx_list = [pred == class_idx for class_idx in range(num_class)]
119 | gt = gt.t()
120 | gt_class_number_list = [(gt == class_idx).sum() for class_idx in range(num_class)]
121 | correct = pred.eq(gt)
122 |
123 | res = []
124 | gt_num = []
125 | for k in topk:
126 | correct_k = correct[:k].float()
127 | per_class_correct_list = [correct_k[pred_class_idx].sum(0) for pred_class_idx in pred_class_idx_list]
128 | per_class_correct_array = torch.tensor(per_class_correct_list)
129 | gt_class_number_tensor = torch.tensor(gt_class_number_list).float()
130 | gt_class_zeronumber_tensor = gt_class_number_tensor == 0
131 | gt_class_number_matrix = torch.tensor(gt_class_number_list).float()
132 | gt_class_acc = per_class_correct_array.mul_(100.0 / gt_class_number_matrix)
133 | gt_class_acc[gt_class_zeronumber_tensor] = 0
134 | res.append(gt_class_acc)
135 | gt_num.append(gt_class_number_matrix)
136 | return res, gt_num
137 |
138 |
139 | def apply_colormap_on_image(org_im, activation, colormap_name='hsv'):
140 | """
141 | Apply heatmap on image
142 | Args:
143 |         org_im (PIL img): Original image
144 |         activation (numpy arr): Activation map (grayscale) 0-255
145 | colormap_name (str): Name of the colormap
146 | """
147 | # Get colormap
148 | color_map = mpl_color_map.get_cmap(colormap_name)
149 | no_trans_heatmap = color_map(activation)
150 | # Change alpha channel in colormap to make sure original image is displayed
151 | heatmap = copy.copy(no_trans_heatmap)
152 | heatmap[:, :, 3] = 0.5
153 | heatmap = Image.fromarray((heatmap*255).astype(np.uint8))
154 | no_trans_heatmap = Image.fromarray((no_trans_heatmap*255).astype(np.uint8))
155 |
156 |     # Apply heatmap on image
157 | heatmap_on_image = Image.new("RGBA", org_im.size)
158 | heatmap_on_image = Image.alpha_composite(heatmap_on_image, org_im.convert('RGBA'))
159 | heatmap_on_image = Image.alpha_composite(heatmap_on_image, heatmap)
160 | return no_trans_heatmap, heatmap_on_image
161 |
162 |
163 | class UnNormalize(object):
164 | def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
165 | self.mean = mean
166 | self.std = std
167 |
168 | def __call__(self, tensor):
169 | """
170 | Args:
171 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
172 | Returns:
173 | Tensor: Normalized image.
174 | """
175 | for t, m, s in zip(tensor, self.mean, self.std):
176 | t.mul_(s).add_(m)
177 | # The normalize code -> t.sub_(m).div_(s)
178 | return tensor
179 |
180 |
181 | class AvgrageMeter(object):
182 |
183 | def __init__(self):
184 | self.reset()
185 |
186 | def reset(self):
187 | self.avg = 0
188 | self.sum = 0
189 | self.cnt = 0
190 |
191 | def update(self, val, n=1):
192 | self.sum += val * n
193 | self.cnt += n
194 | self.avg = self.sum / self.cnt
195 |
196 |
197 | class Cutout(object):
198 | def __init__(self, length):
199 | self.length = length
200 |
201 | def __call__(self, img):
202 | h, w = img.size(1), img.size(2)
203 | mask = np.ones((h, w), np.float32)
204 | y = np.random.randint(h)
205 | x = np.random.randint(w)
206 |
207 | y1 = np.clip(y - self.length // 2, 0, h)
208 | y2 = np.clip(y + self.length // 2, 0, h)
209 | x1 = np.clip(x - self.length // 2, 0, w)
210 | x2 = np.clip(x + self.length // 2, 0, w)
211 |
212 | mask[y1: y2, x1: x2] = 0.
213 | mask = torch.from_numpy(mask)
214 | mask = mask.expand_as(img)
215 | img *= mask
216 | return img
217 |
218 |
219 | def count_parameters_in_MB(model):
220 |     return sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6  # builtin sum: np.sum over a generator is deprecated
221 |
222 |
223 | def save(model, model_path):
224 | torch.save(model.state_dict(), model_path)
225 |
226 |
227 | def load(model, model_path):
228 | model.load_state_dict(torch.load(model_path))
229 |
230 |
231 | def create_exp_dir(path, scripts_to_save=None):
232 | if not os.path.exists(path):
233 | os.makedirs(path)
234 | print('Experiment dir : {}'.format(path))
235 |
236 | if scripts_to_save is not None:
237 | os.mkdir(os.path.join(path, 'scripts'))
238 | for script in scripts_to_save:
239 | dst_file = os.path.join(path, 'scripts', os.path.basename(script))
240 | shutil.copyfile(script, dst_file)
241 |
242 |
--------------------------------------------------------------------------------
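
A sketch of how the schedule helpers above are typically wired into a training loop (illustrative; the optimizer, iteration budget, and power are placeholder values):

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    base_lrs = [g['lr'] for g in optimizer.param_groups]  # snapshot initial LRs
    iter_nums = IterNums(iter_max=40000)

    for _ in range(3):
        # lr_poly applies a quadratic warm-up over the first 100 iterations,
        # then (1 - iter/iter_max)**power polynomial decay
        adjust_learning_rate(base_lrs, optimizer, iter_nums.iter_curr,
                             iter_nums.iter_max, power=0.9)
        iter_nums.update()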