├── .gitignore ├── LICENSE.md ├── README.md ├── config_seg.py ├── data ├── BaseDataset.py ├── cityscapes.py ├── cityscapes │ ├── cityscapes_val_fine.txt │ ├── cityscapes_val_fine_raw.txt │ └── train_ClsConfSet.lst ├── gta5.py ├── labels.py ├── loader_csg.py └── visda17.py ├── dataloader_seg.py ├── eval_seg.py ├── model ├── __init__.py ├── csg_builder.py ├── deeplab.py └── resnet.py ├── requirements.txt ├── tools ├── datasets │ ├── BaseDataset.py │ └── cityscapes │ │ ├── cityscapes.py │ │ └── cityscapes_val_fine.txt ├── engine │ ├── evaluator.py │ ├── logger.py │ └── tester.py ├── seg_opr │ └── metric.py └── utils │ ├── img_utils.py │ ├── pyt_utils.py │ └── visualize.py ├── train.py ├── train.sh ├── train_seg.py ├── train_seg.sh └── utils ├── __init__.py ├── augmentations.py ├── logger.py ├── sgd.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # vim swp files 2 | *.swp 3 | # caffe/pytorch model files 4 | runs/* 5 | crst_visda/runs/* 6 | *.pth 7 | *.tar 8 | *_softmax.txt 9 | *_seed*.txt 10 | *_soft*.txt 11 | *.json 12 | 13 | # Mkdocs 14 | # /docs/ 15 | /mkdocs/docs/temp 16 | 17 | .DS_Store 18 | .idea 19 | .vscode 20 | .pytest_cache 21 | /experiments 22 | node_modules/ 23 | history/ 24 | ablation/ 25 | misc/ 26 | prediction/ 27 | results/ 28 | 29 | # resource temp folder 30 | tests/resources/temp/* 31 | !tests/resources/temp/.gitkeep 32 | 33 | # Byte-compiled / optimized / DLL files 34 | __pycache__/ 35 | *.py[cod] 36 | *$py.class 37 | 38 | # C extensions 39 | *.so 40 | 41 | # Distribution / packaging 42 | .Python 43 | build/ 44 | develop-eggs/ 45 | dist/ 46 | downloads/ 47 | eggs/ 48 | .eggs/ 49 | lib/ 50 | lib64/ 51 | parts/ 52 | sdist/ 53 | var/ 54 | wheels/ 55 | *.egg-info/ 56 | .installed.cfg 57 | *.egg 58 | MANIFEST 59 | 60 | # PyInstaller 61 | # Usually these files are written by a python script from a template 62 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 63 | *.manifest 64 | *.spec 65 | 66 | # Installer logs 67 | pip-log.txt 68 | pip-delete-this-directory.txt 69 | 70 | # Unit test / coverage reports 71 | htmlcov/ 72 | .tox/ 73 | .coverage 74 | .coverage.* 75 | .cache 76 | nosetests.xml 77 | coverage.xml 78 | *.cover 79 | .hypothesis/ 80 | .pytest_cache/ 81 | 82 | # Translations 83 | *.mo 84 | *.pot 85 | 86 | # Django stuff: 87 | *.log 88 | .static_storage/ 89 | .media/ 90 | local_settings.py 91 | db.sqlite3 92 | 93 | # Flask stuff: 94 | instance/ 95 | .webassets-cache 96 | 97 | # Scrapy stuff: 98 | .scrapy 99 | 100 | # Sphinx documentation 101 | docs/_build/ 102 | 103 | # PyBuilder 104 | target/ 105 | 106 | # Jupyter Notebook 107 | .ipynb_checkpoints 108 | 109 | # pyenv 110 | .python-version 111 | 112 | # celery beat schedule file 113 | celerybeat-schedule 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | NVIDIA Source Code 2 | 3 | License for Contrastive Syn-to-Real Generalization (CSG) 4 | --- 5 | 6 | 1.
Definitions 7 | 8 | "Licensor" means any person or entity that distributes its Work. 9 | 10 | "Software" means the original work of authorship made available under this License. 11 | 12 | "Work" means the Software and any additions to or derivative works of the Software that are made available under this License. 13 | 14 | The terms "reproduce," "reproduction," "derivative works," and "distribution" have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 15 | 16 | Works, including the Software, are "made available" under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 17 | 18 | 2. License Grant 19 | 20 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 21 | 22 | 3. Limitations 23 | 24 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 25 | 26 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work ("Your Terms") only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 27 | 28 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, "non-commercially" means for research or evaluation purposes only. 29 | 30 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 31 | 32 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License. 33 | 34 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately. 35 | 36 | 4. Disclaimer of Warranty. 37 | 38 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 39 | 40 | 5. Limitation of Liability.
41 | 42 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Python 3.7](https://img.shields.io/badge/python-3.7-green.svg) 2 | 3 | # CSG: Contrastive Syn-to-Real Generalization 4 | 5 | 6 | [Paper](https://arxiv.org/abs/2104.02290) 7 | 8 | Contrastive Syn-to-Real Generalization.
9 | [Wuyang Chen](https://chenwydj.github.io/), [Zhiding Yu](https://chrisding.github.io/), [Shalini De Mello](https://research.nvidia.com/person/shalini-gupta), [Sifei Liu](https://www.sifeiliu.net/), [Jose M. Alvarez](https://rsu.data61.csiro.au/people/jalvarez/), [Zhangyang Wang](https://www.atlaswang.com/), [Anima Anandkumar](http://tensorlab.cms.caltech.edu/users/anima/).
10 | In ICLR 2021. 11 | 12 | * Visda-17 to COCO 13 | - [x] train resnet101 with CSG 14 | - [x] evaluation 15 | * GTA5 to Cityscapes 16 | - [x] train deeplabv2 (resnet50/resnet101) with CSG 17 | - [x] evaluation 18 | 19 | ## Usage 20 | 21 | ### Visda-17 22 | * Download the [Visda-17 Dataset](http://ai.bu.edu/visda-2017/#download) 23 | 24 | #### Evaluation 25 | * Download the [pretrained ResNet101 on Visda17](https://drive.google.com/file/d/1VdbrwevsYy7I5S3Wo7-S3MwrZZjj09QS/view?usp=sharing) 26 | * Put the checkpoint under `./CSG/pretrained/` 27 | * Put the code below in `train.sh` 28 | ```bash 29 | python train.py \ 30 | --epochs 30 \ 31 | --batch-size 32 \ 32 | --lr 1e-4 \ 33 | --rand_seed 0 \ 34 | --csg 0.1 \ 35 | --apool \ 36 | --augment \ 37 | --csg-stages 3.4 \ 38 | --factor 0.1 \ 39 | --resume pretrained/csg_res101_vista17_best.pth.tar \ 40 | --evaluate 41 | ``` 42 | * Run `CUDA_VISIBLE_DEVICES=0 bash train.sh` 43 | - Set `CUDA_VISIBLE_DEVICES` to the index of the GPU you want to use. 44 | 45 | #### Train with CSG 46 | * Put the code below in `train.sh` 47 | ```bash 48 | python train.py \ 49 | --epochs 30 \ 50 | --batch-size 32 \ 51 | --lr 1e-4 \ 52 | --rand_seed 0 \ 53 | --csg 0.1 \ 54 | --apool \ 55 | --augment \ 56 | --csg-stages 3.4 \ 57 | --factor 0.1 58 | ``` 59 | * Run `CUDA_VISIBLE_DEVICES=0 bash train.sh` 60 | - Set `CUDA_VISIBLE_DEVICES` to the index of the GPU you want to use. 61 | 62 | 63 | ### GTA5 → Cityscapes 64 | * Download the [GTA5 dataset](https://download.visinf.tu-darmstadt.de/data/from_games/). 65 | * Download [leftImg8bit_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=3) and [gtFine_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=1) from the Cityscapes website. 66 | * Prepare the annotations with [createTrainIdLabelImgs.py](https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/createTrainIdLabelImgs.py). 67 | * Put the [image list files](tools/datasets/cityscapes/) in the directory where you saved the dataset. 68 | * **Remember to set `C.dataset_path` in `config_seg.py` to the directory where the datasets reside.** 69 | 70 | #### Evaluation 71 | * Download the pretrained [DeepLabV2-ResNet50](https://drive.google.com/file/d/1E2CosTtGVgIe6BfLBV9vNmyj6l9aYUbk/view?usp=sharing) and [DeepLabV2-ResNet101](https://drive.google.com/file/d/17Pe86m4OCGMFLcxLl_V-1bcG5otOqdvb/view?usp=sharing) trained on GTA5 72 | * Put the checkpoint under `./CSG/pretrained/` 73 | * Put the code below in `train_seg.sh` 74 | ```bash 75 | python train_seg.py \ 76 | --epochs 50 \ 77 | --switch-model deeplab50 \ 78 | --batch-size 6 \ 79 | --lr 1e-3 \ 80 | --num-class 19 \ 81 | --gpus 0 \ 82 | --factor 0.1 \ 83 | --csg 75 \ 84 | --apool \ 85 | --csg-stages 3.4 \ 86 | --chunks 8 \ 87 | --augment \ 88 | --evaluate \ 89 | --resume pretrained/csg_res101_segmentation_best.pth.tar 90 | ``` 91 | * Change `--switch-model` (`deeplab50` or `deeplab101`) and `--resume` (path to the pretrained checkpoint) accordingly. 92 | * Run `CUDA_VISIBLE_DEVICES=0 bash train_seg.sh` 93 | - Set `CUDA_VISIBLE_DEVICES` to the index of the GPU you want to use.
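If `--resume` fails with missing or unexpected keys, it can help to inspect the checkpoint before loading it. A minimal sketch (the `'state_dict'` key and the surrounding layout are assumptions about what the training scripts save, not verified):

```python
import torch

# load on CPU so no GPU is needed just for inspection
ckpt = torch.load("pretrained/csg_res101_segmentation_best.pth.tar", map_location="cpu")
print(list(ckpt.keys()))  # e.g. ['epoch', 'state_dict', ...] -- assumed layout

# hypothetical resume: unwrap the weights and load them into a constructed model
# state = ckpt.get("state_dict", ckpt)
# model.load_state_dict(state, strict=False)
```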
94 | 95 | #### Train with CSG 96 | * Put the code below in `train_seg.sh` 97 | ```bash 98 | python train_seg.py \ 99 | --epochs 50 \ 100 | --switch-model deeplab50 \ 101 | --batch-size 6 \ 102 | --lr 1e-3 \ 103 | --num-class 19 \ 104 | --gpus 0 \ 105 | --factor 0.1 \ 106 | --csg 75 \ 107 | --apool \ 108 | --csg-stages 3.4 \ 109 | --chunks 8 \ 110 | --augment 111 | ``` 112 | * Change `--switch-model` (`deeplab50` or `deeplab101`) accordingly. 113 | * Run `CUDA_VISIBLE_DEVICES=0 bash train_seg.sh` 114 | - Set `CUDA_VISIBLE_DEVICES` to the index of the GPU you want to use. 115 | 116 | 117 | ## Citation 118 | 119 | If you use this code for your research, please cite: 120 | 121 | ```BibTeX 122 | @inproceedings{chen2021contrastive, 123 | title={Contrastive Syn-to-Real Generalization}, 124 | author={Chen, Wuyang and Yu, Zhiding and De Mello, Shalini and Liu, Sifei and Alvarez, Jose M. and Wang, Zhangyang and Anandkumar, Anima}, 125 | booktitle={International Conference on Learning Representations (ICLR)}, 126 | year={2021} 127 | } 128 | ``` 129 | -------------------------------------------------------------------------------- /config_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | # encoding: utf-8 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import os 11 | import os.path as osp 12 | import sys 13 | import numpy as np 14 | from easydict import EasyDict as edict 15 | 16 | C = edict() 17 | config = C 18 | cfg = C 19 | 20 | """please configure the root dir and C.dataset_path before first use""" 21 | C.repo_name = 'CSG' 22 | C.abs_dir = osp.realpath(".") 23 | C.this_dir = C.abs_dir.split(osp.sep)[-1] 24 | C.root_dir = C.abs_dir[:C.abs_dir.index(C.repo_name) + len(C.repo_name)] 25 | 26 | """Data Dir""" 27 | # C.dataset_path = "/raid/" 28 | C.dataset_path = "/home/chenwy/" 29 | 30 | C.train_img_root = os.path.join(C.dataset_path, "gta5") 31 | C.train_gt_root = os.path.join(C.dataset_path, "gta5") 32 | C.val_img_root = os.path.join(C.dataset_path, "cityscapes") 33 | C.val_gt_root = os.path.join(C.dataset_path, "cityscapes") 34 | C.test_img_root = os.path.join(C.dataset_path, "cityscapes") 35 | C.test_gt_root = os.path.join(C.dataset_path, "cityscapes") 36 | 37 | C.train_source = osp.join(C.train_img_root, "gta5_train.txt") 38 | C.train_target_source = osp.join(C.train_img_root, "cityscapes_train_fine.txt") 39 | C.eval_source = osp.join(C.val_img_root, "cityscapes_val_fine.txt") 40 | C.test_source = osp.join(C.test_img_root, "cityscapes_test.txt") 41 | 42 | """Image Config""" 43 | C.num_classes = 19 44 | C.background = -1 45 | C.image_mean = np.array([0.485, 0.456, 0.406]) 46 | C.image_std = np.array([0.229, 0.224, 0.225]) 47 | C.down_sampling_train = [1, 1] # first down_sampling then crop 48 | C.down_sampling_val = [1, 1] # first down_sampling then crop 49 | C.gt_down_sampling = 1 50 | C.num_train_imgs = 12403 51 | C.num_eval_imgs = 500 52 | 53 | """ Settings for network, this would be different for each kind of model""" 54 | C.bn_eps = 1e-5 55 | C.bn_momentum = 0.1 56 | 57 | """Train Config""" 58 | C.lr = 0.01 59 | C.momentum = 0.9 60 | C.weight_decay = 5e-4 61 | C.nepochs = 30 62 | C.niters_per_epoch = 2000 63 | C.num_workers = 16 64 | C.train_scale_array = [0.75, 1, 1.25] 65 | 66 | """Eval Config""" 67 | C.eval_stride_rate = 5 / 6 68 | C.eval_scale_array = [1] 69 | C.eval_flip = True 70 |
C.eval_base_size = 1024 71 | C.eval_crop_size = 1024 72 | C.eval_height = 1024 73 | C.eval_width = 2048 74 | 75 | # GTA5 raw resolution: 1052x1914 (H x W) 76 | C.image_height = 512 77 | C.image_width = 512 78 | C.is_test = False 79 | C.is_eval = False 80 | -------------------------------------------------------------------------------- /data/BaseDataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | import os 5 | import cv2 6 | import torch 7 | import numpy as np 8 | from random import shuffle 9 | from pdb import set_trace as bp 10 | import torch.utils.data as data 11 | cv2.setNumThreads(0) 12 | 13 | 14 | class BaseDataset(data.Dataset): 15 | def __init__(self, setting, split_name, preprocess=None, file_length=None): 16 | super(BaseDataset, self).__init__() 17 | self._split_name = split_name 18 | if split_name == 'train': 19 | self._img_path = setting['train_img_root'] 20 | self._gt_path = setting['train_gt_root'] 21 | elif split_name == 'val': 22 | self._img_path = setting['val_img_root'] 23 | self._gt_path = setting['val_gt_root'] 24 | elif split_name == 'test': 25 | self._img_path = setting['test_img_root'] 26 | self._gt_path = setting['test_gt_root'] 27 | self._train_source = setting['train_source'] 28 | self._eval_source = setting['eval_source'] 29 | self._test_source = setting['test_source'] if 'test_source' in setting else setting['eval_source'] 30 | self._down_sampling = setting['down_sampling_train'] if split_name == 'train' else setting['down_sampling_val'] 31 | print("using downsampling:", self._down_sampling) 32 | self._file_names = self._get_file_names(split_name) 33 | print("Found %d images"%len(self._file_names)) 34 | self._file_length = file_length 35 | self.preprocess = preprocess 36 | 37 | def __len__(self): 38 | if self._file_length is not None: 39 | return self._file_length 40 | return len(self._file_names) 41 | 42 | def __getitem__(self, index): 43 | if self._file_length is not None: 44 | names = self._construct_new_file_names(self._file_length)[index] 45 | else: 46 | names = self._file_names[index] 47 | img_path = os.path.join(self._img_path, names[0]) 48 | gt_path = os.path.join(self._gt_path, names[1]) 49 | item_name = names[1].split("/")[-1].split(".")[0] 50 | img, gt = self._fetch_data(img_path, gt_path) 51 | img = img[:, :, ::-1] # BGR (cv2) -> RGB 52 | if self.preprocess is not None: 53 | img, gt, extra_dict = self.preprocess(img, gt) 54 | 55 | if self._split_name == 'train': 56 | img = torch.from_numpy(np.ascontiguousarray(img)).float() 57 | gt = torch.from_numpy(np.ascontiguousarray(gt)).long() 58 | if self.preprocess is not None and extra_dict is not None: 59 | for k, v in extra_dict.items(): 60 | extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v)) 61 | if 'label' in k: 62 | extra_dict[k] = extra_dict[k].long() 63 | if 'img' in k: 64 | extra_dict[k] = extra_dict[k].float() 65 | 66 | output_dict = dict(data=img, label=gt, fn=str(item_name), n=len(self._file_names)) 67 | if self.preprocess is not None and extra_dict is not None: 68 | output_dict.update(**extra_dict) 69 | 70 | return output_dict 71 | 72 | def _fetch_data(self, img_path, gt_path, dtype=None): 73 | img = self._open_image(img_path, down_sampling=self._down_sampling[0]) 74 | gt = self._open_image(gt_path, cv2.IMREAD_GRAYSCALE, dtype=dtype, down_sampling=self._down_sampling[1]) 75 | 76 | return img, gt 77 | 78 | def _get_file_names(self,
split_name): 79 | assert split_name in ['train', 'val', 'test'] 80 | source = self._train_source 81 | if split_name == "val": 82 | source = self._eval_source 83 | elif split_name == 'test': 84 | source = self._test_source 85 | 86 | file_names = [] 87 | with open(source) as f: 88 | files = f.readlines() 89 | 90 | for item in files: 91 | img_name, gt_name = self._process_item_names(item) 92 | file_names.append([img_name, gt_name]) 93 | 94 | return file_names 95 | 96 | def _construct_new_file_names(self, length): 97 | assert isinstance(length, int) 98 | files_len = len(self._file_names) 99 | new_file_names = self._file_names * (length // files_len) 100 | 101 | rand_indices = torch.randperm(files_len).tolist() 102 | new_indices = rand_indices[:length % files_len] 103 | 104 | new_file_names += [self._file_names[i] for i in new_indices] 105 | 106 | return new_file_names 107 | 108 | @staticmethod 109 | def _process_item_names(item): 110 | item = item.strip() 111 | # item = item.split('\t') 112 | item = item.split(' ') 113 | img_name = item[0] 114 | gt_name = item[1] 115 | 116 | return img_name, gt_name 117 | 118 | def get_length(self): 119 | return self.__len__() 120 | 121 | @staticmethod 122 | def _open_image(filepath, mode=cv2.IMREAD_COLOR, dtype=None, down_sampling=1): 123 | # cv2: B G R 124 | # h w c 125 | img = np.array(cv2.imread(filepath, mode), dtype=dtype) 126 | if isinstance(down_sampling, int): 127 | try: 128 | H, W = img.shape[:2] 129 | except Exception: 130 | print(img.shape, filepath) 131 | exit(0) 132 | if len(img.shape) == 3: 133 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_LINEAR) 134 | else: 135 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_NEAREST) 136 | assert img.shape[0] == H // down_sampling and img.shape[1] == W // down_sampling 137 | else: 138 | assert (isinstance(down_sampling, tuple) or isinstance(down_sampling, list)) and len(down_sampling) == 2 139 | if len(img.shape) == 3: 140 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_LINEAR) 141 | else: 142 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_NEAREST) 143 | assert img.shape[0] == down_sampling[0] and img.shape[1] == down_sampling[1] 144 | 145 | return img 146 | 147 | @classmethod 148 | def get_class_colors(*args): 149 | raise NotImplementedError 150 | 151 | @classmethod 152 | def get_class_names(*args): 153 | raise NotImplementedError 154 | 155 | 156 | if __name__ == "__main__": 157 | data_setting = {'img_root': '', 158 | 'gt_root': '', 159 | 'train_source': '', 160 | 'eval_source': ''} 161 | bd = BaseDataset(data_setting, 'train', None) 162 | print(bd.get_class_names()) 163 | -------------------------------------------------------------------------------- /data/cityscapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 
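# Added note (illustrative, not part of the original file): `transform_label` below
# maps trainId predictions (0..18) back to the official Cityscapes label ids via
# `trans_labels`, e.g. trainId 0 ('road') becomes id 7, before saving files for
# submission. A hypothetical call:
#   label, new_name = Cityscapes.transform_label(pred, 'frankfurt_000000_000294_pred.png')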
3 | 4 | import numpy as np 5 | from data.BaseDataset import BaseDataset 6 | 7 | class Cityscapes(BaseDataset): 8 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 9 | 28, 31, 32, 33] 10 | 11 | @classmethod 12 | def get_class_colors(*args): 13 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70], 14 | [102, 102, 156], [190, 153, 153], [153, 153, 153], 15 | [250, 170, 30], [220, 220, 0], [107, 142, 35], 16 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], 17 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], 18 | [0, 0, 230], [119, 11, 32]] 19 | 20 | @classmethod 21 | def get_class_names(*args): 22 | # class counting(gtFine) 23 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832 24 | # 359 274 142 513 1646 25 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 26 | 'traffic light', 'traffic sign', 27 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 28 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle'] 29 | 30 | @classmethod 31 | def transform_label(cls, pred, name): 32 | label = np.zeros(pred.shape) 33 | ids = np.unique(pred) 34 | for id in ids: 35 | label[np.where(pred == id)] = cls.trans_labels[id] 36 | 37 | new_name = (name.split('.')[0]).split('_')[:-1] 38 | new_name = '_'.join(new_name) + '.png' 39 | 40 | print('Trans', name, 'to', new_name, ' ', 41 | np.unique(np.array(pred, np.uint8)), ' ---------> ', 42 | np.unique(np.array(label, np.uint8))) 43 | return label, new_name 44 | -------------------------------------------------------------------------------- /data/gta5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 3 | 4 | import numpy as np 5 | from data.BaseDataset import BaseDataset 6 | 7 | class GTA5(BaseDataset): 8 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 9 | 28, 31, 32, 33] 10 | 11 | @classmethod 12 | def get_class_colors(*args): 13 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70], 14 | [102, 102, 156], [190, 153, 153], [153, 153, 153], 15 | [250, 170, 30], [220, 220, 0], [107, 142, 35], 16 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], 17 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], 18 | [0, 0, 230], [119, 11, 32]] 19 | 20 | @classmethod 21 | def get_class_names(*args): 22 | # class counting(gtFine) 23 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832 24 | # 359 274 142 513 1646 25 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 26 | 'traffic light', 'traffic sign', 27 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 28 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle'] 29 | 30 | @classmethod 31 | def transform_label(cls, pred, name): 32 | label = np.zeros(pred.shape) 33 | ids = np.unique(pred) 34 | for id in ids: 35 | label[np.where(pred == id)] = cls.trans_labels[id] 36 | 37 | new_name = (name.split('.')[0]).split('_')[:-1] 38 | new_name = '_'.join(new_name) + '.png' 39 | 40 | print('Trans', name, 'to', new_name, ' ', 41 | np.unique(np.array(pred, np.uint8)), ' ---------> ', 42 | np.unique(np.array(label, np.uint8))) 43 | return label, new_name 44 | -------------------------------------------------------------------------------- /data/labels.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. 
All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | #!/usr/bin/python 5 | # 6 | # Cityscapes labels 7 | # 8 | 9 | from collections import namedtuple 10 | 11 | 12 | #-------------------------------------------------------------------------------- 13 | # Definitions 14 | #-------------------------------------------------------------------------------- 15 | 16 | # a label and all meta information 17 | Label = namedtuple( 'Label' , [ 18 | 19 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 20 | # We use them to uniquely name a class 21 | 22 | 'id' , # An integer ID that is associated with this label. 23 | # The IDs are used to represent the label in ground truth images 24 | # An ID of -1 means that this label does not have an ID and thus 25 | # is ignored when creating ground truth images (e.g. license plate). 26 | # Do not modify these IDs, since exactly these IDs are expected by the 27 | # evaluation server. 28 | 29 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 30 | # ground truth images with train IDs, using the tools provided in the 31 | # 'preparation' folder. However, make sure to validate or submit results 32 | # to our evaluation server using the regular IDs above! 33 | # For trainIds, multiple labels might have the same ID. Then, these labels 34 | # are mapped to the same class in the ground truth images. For the inverse 35 | # mapping, we use the label that is defined first in the list below. 36 | # For example, mapping all void-type classes to the same ID in training, 37 | # might make sense for some approaches. 38 | # Max value is 255! 39 | 40 | 'category' , # The name of the category that this label belongs to 41 | 42 | 'categoryId' , # The ID of this category. Used to create ground truth images 43 | # on category level. 44 | 45 | 'hasInstances', # Whether this label distinguishes between single instances or not 46 | 47 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 48 | # during evaluations or not 49 | 50 | 'color' , # The color of this label 51 | ] ) 52 | 53 | 54 | #-------------------------------------------------------------------------------- 55 | # A list of all labels 56 | #-------------------------------------------------------------------------------- 57 | 58 | # Please adapt the train IDs as appropriate for your approach. 59 | # Note that you might want to ignore labels with ID 255 during training. 60 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 61 | # Make sure to provide your results using the original IDs and not the training IDs. 62 | # Note that many IDs are ignored in evaluation and thus you never need to predict these!
63 | 64 | labels = [ 65 | # name id trainId category catId hasInstances ignoreInEval color 66 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 67 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 68 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 69 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 70 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 71 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 72 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 73 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 74 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 75 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 76 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 77 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 78 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 79 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 80 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 81 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 82 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 83 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 84 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 85 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 86 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 87 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 88 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 89 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 90 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 91 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 92 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 93 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 94 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 95 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 96 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 97 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 98 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 99 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 100 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 101 | ] 102 | 103 | 104 | #-------------------------------------------------------------------------------- 105 | # Create dictionaries for a fast lookup 106 | #-------------------------------------------------------------------------------- 107 | 108 | # Please refer to the main method below for example usages! 
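# Added example (a minimal sketch, not part of the original file): a flat
# id -> trainId mapping can be built straight from `labels`, e.g.
#   id2trainId = {label.id: label.trainId for label in labels}
#   id2trainId[7] == 0 ('road'); id2trainId[0] == 255 (ignored 'unlabeled')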
109 | 110 | # name to label object 111 | name2label = { label.name : label for label in labels } 112 | # id to label object 113 | id2label = { label.id : label for label in labels } 114 | # trainId to label object 115 | trainId2label = { label.trainId : label for label in reversed(labels) } 116 | # category to list of label objects 117 | category2labels = {} 118 | for label in labels: 119 | category = label.category 120 | if category in category2labels: 121 | category2labels[category].append(label) 122 | else: 123 | category2labels[category] = [label] 124 | 125 | #-------------------------------------------------------------------------------- 126 | # Assure single instance name 127 | #-------------------------------------------------------------------------------- 128 | 129 | # returns the label name that describes a single instance (if possible) 130 | # e.g. input | output 131 | # ---------------------- 132 | # car | car 133 | # cargroup | car 134 | # foo | None 135 | # foogroup | None 136 | # skygroup | None 137 | def assureSingleInstanceName( name ): 138 | # if the name is known, it is not a group 139 | if name in name2label: 140 | return name 141 | # test if the name actually denotes a group 142 | if not name.endswith("group"): 143 | return None 144 | # remove group 145 | name = name[:-len("group")] 146 | # test if the new name exists 147 | if not name in name2label: 148 | return None 149 | # test if the new name denotes a label that actually has instances 150 | if not name2label[name].hasInstances: 151 | return None 152 | # all good then 153 | return name 154 | 155 | #-------------------------------------------------------------------------------- 156 | # Main for testing 157 | #-------------------------------------------------------------------------------- 158 | 159 | # just a dummy main 160 | if __name__ == "__main__": 161 | # Print all the labels 162 | print("List of cityscapes labels:") 163 | print("") 164 | print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' )) 165 | print(" " + ('-' * 98)) 166 | for label in labels: 167 | print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval )) 168 | print("") 169 | 170 | print("Example usages:") 171 | 172 | # Map from name to label 173 | name = 'car' 174 | id = name2label[name].id 175 | print("ID of label '{name}': {id}".format( name=name, id=id )) 176 | 177 | # Map from ID to label 178 | category = id2label[id].category 179 | print("Category of label with ID '{id}': {category}".format( id=id, category=category )) 180 | 181 | # Map from trainID to label 182 | trainId = 0 183 | name = trainId2label[trainId].name 184 | print("Name of label with trainID '{id}': {name}".format( id=trainId, name=name )) 185 | -------------------------------------------------------------------------------- /data/loader_csg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 
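# Added usage sketch (illustrative, not part of the original file): the classes in
# this module are composed so that one image yields a (query, key) pair, e.g.
#   crop_two = RandomResizedCrop_two(224, scale=(0.2, 1.0))
#   img_q, img_k = crop_two(pil_image)  # two overlapping, jittered crops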
3 | 4 | import warnings 5 | from PIL import ImageFilter, Image 6 | import math 7 | import random 8 | import torch 9 | from torchvision.transforms import functional as F 10 | 11 | _pil_interpolation_to_str = { 12 | Image.NEAREST: 'PIL.Image.NEAREST', 13 | Image.BILINEAR: 'PIL.Image.BILINEAR', 14 | Image.BICUBIC: 'PIL.Image.BICUBIC', 15 | Image.LANCZOS: 'PIL.Image.LANCZOS', 16 | Image.HAMMING: 'PIL.Image.HAMMING', 17 | Image.BOX: 'PIL.Image.BOX', 18 | } 19 | 20 | 21 | def _get_image_size(img): 22 | if F._is_pil_image(img): 23 | return img.size 24 | elif isinstance(img, torch.Tensor) and img.dim() > 2: 25 | return img.shape[-2:][::-1] 26 | else: 27 | raise TypeError("Unexpected type {}".format(type(img))) 28 | 29 | 30 | class RandomResizedCrop_two(object): 31 | # generate two closely located patches 32 | """Crop the given PIL Image to a random size and aspect ratio. 33 | A crop of random size (default: 0.08 to 1.0 of the original area) and a random 34 | aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made. This crop 35 | is finally resized to the given size. 36 | This is popularly used to train the Inception networks. 37 | Args: 38 | size: expected output size of each edge 39 | scale: range of size of the origin size cropped 40 | ratio: range of aspect ratio of the origin aspect ratio cropped 41 | interpolation: Default: PIL.Image.BILINEAR 42 | """ 43 | 44 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 45 | if isinstance(size, (tuple, list)): 46 | self.size = size 47 | else: 48 | self.size = (size, size) 49 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 50 | warnings.warn("range should be of kind (min, max)") 51 | 52 | self.interpolation = interpolation 53 | self.scale = scale 54 | self.ratio = ratio 55 | 56 | @staticmethod 57 | def get_params(img, scale, ratio, augment=(0.025, 0.075)): 58 | """Get parameters for ``crop`` for a random sized crop. 59 | Args: 60 | img (PIL Image): Image to be cropped. 61 | scale (tuple): range of size of the origin size cropped 62 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 63 | Returns: 64 | tuple: params (i, j, h, w, i_a, j_a, h_a, w_a) to be passed to ``crop`` 65 | for two overlapping random sized crops.
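The second set (i_a, j_a, h_a, w_a) is a jittered copy of the first: each value is shifted by 2.5%-7.5% of the crop height/width (the ``augment`` range) in a random direction, so the two crops overlap heavily but are not identical.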
66 | """ 67 | width, height = _get_image_size(img) 68 | area = height * width 69 | 70 | for _ in range(10): 71 | target_area = random.uniform(*scale) * area 72 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 73 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 74 | 75 | w = int(round(math.sqrt(target_area * aspect_ratio))) 76 | h = int(round(math.sqrt(target_area / aspect_ratio))) 77 | 78 | if 0 < w <= width and 0 < h <= height: 79 | i = random.randint(0, height - h) 80 | j = random.randint(0, width - w) 81 | # return i, j, h, w 82 | ##### augment ##### 83 | delta = random.randint(int(h*augment[0]), int(h*augment[1])) * random.choice([-1, 1]) 84 | h_a = h + delta; h_a = min(max(1, h_a), height) 85 | delta = random.randint(int(w*augment[0]), int(w*augment[1])) * random.choice([-1, 1]) 86 | w_a = w + delta; w_a = min(max(1, w_a), width) 87 | delta = random.randint(int(h*augment[0]), int(h*augment[1])) * random.choice([-1, 1]) 88 | i_a = i + delta; i_a = min(max(0, i_a), height - h_a) 89 | delta = random.randint(int(w*augment[0]), int(w*augment[1])) * random.choice([-1, 1]) 90 | j_a = j + delta; j_a = min(max(0, j_a), width - w_a) 91 | ################### 92 | return i, j, h, w, i_a, j_a, h_a, w_a 93 | 94 | # Fallback to central crop 95 | in_ratio = float(width) / float(height) 96 | if (in_ratio < min(ratio)): 97 | w = width 98 | h = int(round(w / min(ratio))) 99 | elif (in_ratio > max(ratio)): 100 | h = height 101 | w = int(round(h * max(ratio))) 102 | else: # whole image 103 | w = width 104 | h = height 105 | i = (height - h) // 2 106 | j = (width - w) // 2 107 | ##### augment ##### 108 | delta = random.randint(int(h*augment[0]), int(h*augment[1])) * random.choice([-1, 1]) 109 | h_a = h + delta; h_a = min(max(1, h_a), height) 110 | delta = random.randint(int(w*augment[0]), int(w*augment[1])) * random.choice([-1, 1]) 111 | w_a = w + delta; w_a = min(max(1, w_a), width) 112 | delta = random.randint(int(h*augment[0]), int(h*augment[1])) * random.choice([-1, 1]) 113 | i_a = i + delta; i_a = min(max(0, i_a), height - h_a) 114 | delta = random.randint(int(w*augment[0]), int(w*augment[1])) * random.choice([-1, 1]) 115 | j_a = j + delta; j_a = min(max(0, j_a), width - w_a) 116 | ################### 117 | return i, j, h, w, i_a, j_a, h_a, w_a 118 | 119 | def __call__(self, img): 120 | """ 121 | Args: 122 | img (PIL Image): Image to be cropped and resized. 123 | Returns: 124 | PIL Image: Randomly cropped and resized image. 
125 | """ 126 | i, j, h, w, i_a, j_a, h_a, w_a = self.get_params(img, self.scale, self.ratio) 127 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), F.resized_crop(img, i_a, j_a, h_a, w_a, self.size, self.interpolation) 128 | 129 | def __repr__(self): 130 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 131 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 132 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 133 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 134 | format_string += ', interpolation={0})'.format(interpolate_str) 135 | return format_string 136 | 137 | 138 | class ImageTransform: 139 | """return both image and tensor""" 140 | 141 | def __init__(self, transform): 142 | self.base_transform = transform[0] # resize, centercrop 143 | self.totensor_norm = transform[1] # totensor, **normalize** 144 | 145 | def __call__(self, x): 146 | image = self.base_transform(x) 147 | tensor = self.totensor_norm(image) 148 | return [tensor, F.to_tensor(image)] 149 | 150 | 151 | class TwoCropsTransform: 152 | """Take two random crops of one image as the query and key.""" 153 | 154 | def __init__(self, q_transform, k_transform): 155 | self.q_transform = q_transform 156 | self.k_transform = k_transform 157 | 158 | def __call__(self, x): 159 | q = self.q_transform(x) 160 | k = self.k_transform(x) 161 | return [q, k] 162 | 163 | 164 | class GaussianBlur(object): 165 | def __init__(self, sigma=[.1, 2.]): 166 | self.sigma = sigma 167 | 168 | def __call__(self, x): 169 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 170 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 171 | return x 172 | -------------------------------------------------------------------------------- /data/visda17.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | import os 5 | from PIL import Image 6 | import numpy as np 7 | import random 8 | import torch 9 | from torch.utils.data import Dataset 10 | import torchvision.transforms as transforms 11 | from pdb import set_trace as bp 12 | 13 | 14 | class VisDA17(Dataset): 15 | 16 | def __init__(self, txt_file, root_dir, transform=transforms.ToTensor(), label_one_hot=False, portion=1.0): 17 | """ 18 | Args: 19 | txt_file (string): Path to the txt file with annotations. 20 | root_dir (string): Directory with all the images. 21 | transform (callable, optional): Optional transform to be applied 22 | on a sample. 
23 | """ 24 | self.lines = open(txt_file, 'r').readlines() 25 | self.root_dir = root_dir 26 | self.transform = transform 27 | self.label_one_hot = label_one_hot 28 | self.portion = portion 29 | self.number_classes = 12 30 | assert portion != 0 31 | if self.portion > 0: 32 | self.lines = self.lines[:round(self.portion * len(self.lines))] 33 | else: 34 | self.lines = self.lines[round(self.portion * len(self.lines)):] 35 | 36 | def __len__(self): 37 | return len(self.lines) 38 | 39 | def __getitem__(self, idx): 40 | line = str.split(self.lines[idx]) 41 | path_img = os.path.join(self.root_dir, line[0]) 42 | image = Image.open(path_img) 43 | image = image.convert('RGB') 44 | if self.transform: 45 | image = self.transform(image) 46 | if self.label_one_hot: 47 | label = np.zeros(12, np.float32) 48 | label[np.asarray(line[1], dtype=np.int)] = 1 49 | else: 50 | label = np.asarray(line[1], dtype=np.int) 51 | label = torch.from_numpy(label) 52 | return image, label 53 | -------------------------------------------------------------------------------- /dataloader_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.)) 3 | 4 | import cv2 5 | from torch.utils import data 6 | from PIL import Image 7 | import numpy as np 8 | from tools.utils.img_utils import random_scale, random_mirror, normalize, generate_random_crop_pos, random_crop_pad_to_shape 9 | import torchvision.transforms as transforms 10 | cv2.setNumThreads(0) 11 | 12 | 13 | class TrainPre(object): 14 | def __init__(self, config, img_mean, img_std, augment=None): 15 | self.img_mean = img_mean 16 | self.img_std = img_std 17 | self.config = config 18 | self.augment = augment 19 | 20 | # we have func normalize below; return npy 21 | if augment: 22 | self.data_transforms = transforms.Compose([ 23 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 24 | ]) 25 | 26 | def __call__(self, img, gt): 27 | img, gt = random_mirror(img, gt) 28 | if self.config.train_scale_array is not None: 29 | img, gt, scale = random_scale(img, gt, self.config.train_scale_array) 30 | 31 | crop_size = (self.config.image_height, self.config.image_width) 32 | crop_pos = generate_random_crop_pos(img.shape[:2], crop_size) 33 | if self.augment: 34 | p_img, _ = random_crop_pad_to_shape(normalize(img, self.img_mean, self.img_std), crop_pos, crop_size, 0) 35 | p_img_k, _ = random_crop_pad_to_shape(normalize(np.array( 36 | self.data_transforms(Image.fromarray(img)) 37 | ), self.img_mean, self.img_std), crop_pos, crop_size, 0) 38 | p_img = p_img.transpose(2, 0, 1) 39 | p_img_k = p_img_k.transpose(2, 0, 1) 40 | extra_dict = {'img_k': p_img_k} 41 | else: 42 | p_img, _ = random_crop_pad_to_shape(normalize(img, self.img_mean, self.img_std), crop_pos, crop_size, 0) 43 | p_img = p_img.transpose(2, 0, 1) 44 | extra_dict = None 45 | p_gt, _ = random_crop_pad_to_shape(gt, crop_pos, crop_size, 255) 46 | p_gt = cv2.resize(p_gt, (self.config.image_width // self.config.gt_down_sampling, self.config.image_height // self.config.gt_down_sampling), interpolation=cv2.INTER_NEAREST) 47 | 48 | return p_img, p_gt, extra_dict 49 | 50 | 51 | def get_train_loader(config, dataset, worker=None, test=False, augment=None): 52 | data_setting = { 53 | 'train_img_root': config.train_img_root, 54 | 'train_gt_root': config.train_gt_root, 55 | 'val_img_root': config.val_img_root, 56 | 'val_gt_root': config.val_gt_root, 57 | 'train_source': 
config.train_source, 58 | 'eval_source': config.eval_source, 59 | 'down_sampling_train': config.down_sampling_train 60 | } 61 | if test: 62 | data_setting = {'img_root': config.img_root, 63 | 'gt_root': config.gt_root, 64 | 'train_source': config.train_eval_source, 65 | 'eval_source': config.eval_source} 66 | train_preprocess = TrainPre(config, config.image_mean, config.image_std, augment) 67 | 68 | train_dataset = dataset(data_setting, "train", train_preprocess, config.batch_size * config.niters_per_epoch) 69 | 70 | is_shuffle = True 71 | batch_size = config.batch_size 72 | 73 | train_loader = data.DataLoader(train_dataset, 74 | batch_size=batch_size, 75 | num_workers=config.num_workers if worker is None else worker, 76 | drop_last=True, 77 | shuffle=is_shuffle, 78 | pin_memory=True, 79 | ) 80 | 81 | return train_loader 82 | -------------------------------------------------------------------------------- /eval_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | #!/usr/bin/env python3 5 | # encoding: utf-8 6 | import os 7 | import cv2 8 | import numpy as np 9 | from pdb import set_trace as bp 10 | from tools.utils.visualize import print_iou, show_img, show_prediction 11 | from tools.engine.evaluator import Evaluator 12 | from tools.engine.logger import get_logger 13 | from tools.seg_opr.metric import hist_info, compute_score 14 | 15 | cv2.setNumThreads(0) 16 | logger = get_logger() 17 | 18 | 19 | class SegEvaluator(Evaluator): 20 | def func_per_iteration(self, data, device, iter=None): 21 | if self.config is not None: config = self.config 22 | img = data['data'] 23 | label = data['label'] 24 | name = data['fn'] 25 | 26 | if len(config.eval_scale_array) == 1: 27 | pred = self.whole_eval(img, label.shape, resize=config.eval_scale_array[0], device=device) 28 | pred = pred.argmax(2) # since we ignore this step in evaluator.py 29 | elif len(config.eval_scale_array) > 1: 30 | pred = self.whole_eval(img, label.shape, resize=config.eval_scale_array[0], device=device) 31 | for scale in config.eval_scale_array[1:]: 32 | pred += self.whole_eval(img, label.shape, resize=scale, device=device) 33 | pred = pred.argmax(2) # since we ignore this step in evaluator.py 34 | else: 35 | pred = self.sliding_eval(img, config.eval_crop_size, config.eval_stride_rate, device) 36 | hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes, pred, label) 37 | results_dict = {'hist': hist_tmp, 'labeled': labeled_tmp, 'correct': correct_tmp} 38 | 39 | if self.save_path is not None: 40 | fn = name + '.png' 41 | cv2.imwrite(os.path.join(self.save_path, fn), pred) 42 | logger.info('Save the image ' + fn) 43 | 44 | # the tensorboard logger does not work with multiprocessing 45 | if self.logger is not None and iter is not None: 46 | colors = self.dataset.get_class_colors() 47 | image = img 48 | clean = np.zeros(label.shape) 49 | comp_img = show_img(colors, config.background, image, clean, label, pred) 50 | self.logger.add_image('vis', np.swapaxes(np.swapaxes(comp_img, 0, 2), 1, 2), iter) 51 | 52 | if self.show_image or self.show_prediction: 53 | colors = self.dataset.get_class_colors() 54 | image = img 55 | clean = np.zeros(label.shape) 56 | if self.show_image: 57 | comp_img = show_img(colors, config.background, image, clean, label, pred) 58 | cv2.imwrite(os.path.join(self.save_path, name + ".png"), comp_img[:,:,::-1]) 59 | if
self.show_prediction: 60 | comp_img = show_prediction(colors, config.background, image, pred) 61 | cv2.imwrite(os.path.join(self.save_path, "viz_"+name+".png"), comp_img[:,:,::-1]) 62 | 63 | return results_dict 64 | 65 | def compute_metric(self, results): 66 | hist = np.zeros((self.config.num_classes, self.config.num_classes)) 67 | correct = 0 68 | labeled = 0 69 | count = 0 70 | for d in results: 71 | hist += d['hist'] 72 | correct += d['correct'] 73 | labeled += d['labeled'] 74 | count += 1 75 | 76 | iu, mean_IU, mean_IU_no_back, mean_pixel_acc = compute_score(hist, correct, labeled) 77 | result_line = print_iou(iu, mean_pixel_acc, self.dataset.get_class_names(), True) 78 | return result_line, mean_IU 79 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | -------------------------------------------------------------------------------- /model/csg_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | import torch 5 | import torch.nn as nn 6 | from pdb import set_trace as bp 7 | 8 | 9 | def chunk_feature(feature, chunk): 10 | if chunk == 1: 11 | return feature 12 | # B x C x H x W => (B*chunk^2) x C x (H//chunk) x (W//chunk) 13 | _f_new = torch.chunk(feature, chunk, dim=2) 14 | _f_new = [torch.chunk(f, chunk, dim=3) for f in _f_new] 15 | f_new = [] 16 | for f in _f_new: 17 | f_new += f 18 | f_new = torch.cat(f_new, dim=0) 19 | return f_new 20 | 21 | 22 | class CSG(nn.Module): 23 | def __init__(self, base_encoder, get_head=None, dim=128, K=65536, m=0.999, T=0.07, mlp=True, stages=[4], num_class=12, chunks=[1], task='new', 24 | base_encoder_kwargs={}, apool=True 25 | ): 26 | """ 27 | dim: feature dimension (default: 128) 28 | K: queue size; number of negative keys (default: 65536) 29 | m: momentum of updating key encoder (default: 0.999) 30 | T: softmax temperature (default: 0.07) 31 | """ 32 | super(CSG, self).__init__() 33 | 34 | self.K = K 35 | self.m = m 36 | self.T = T 37 | self.stages = stages 38 | self.mlp = mlp 39 | self.base_encoder = base_encoder 40 | self.chunks = chunks # chunk feature (segmentation) 41 | self.task = task # new, new-seg 42 | self.attentions = [None for _ in range(len(stages))] 43 | self.apool = apool 44 | 45 | # create the encoders 46 | # num_classes is the output fc dimension 47 | self.encoder_q = base_encoder(num_classes=dim, pretrained=True, **base_encoder_kwargs) # q is for the new task 48 | self.encoder_k = base_encoder(num_classes=dim, pretrained=True, **base_encoder_kwargs) # ###### 49 | if get_head is not None: 50 | num_ftrs = self.encoder_q.fc.in_features 51 | self.encoder_q.fc_new = get_head(num_ftrs, num_class) 52 | for param in self.encoder_q.fc.parameters(): 53 | param.requires_grad = False 54 | 55 | if mlp: 56 | fc_q = {} 57 | fc_k = {} 58 | for stage in stages: 59 | if stage > 0: 60 | try: 61 | # BottleNeck 62 | dim_mlp = getattr(self.encoder_q, "layer%d"%stage)[-1].conv3.weight.size()[0] 63 | except AttributeError: 64 | # BasicBlock (has no conv3) 65 | dim_mlp = getattr(self.encoder_q, "layer%d"%stage)[-1].conv2.weight.size()[0] 66 | elif stage == 0: 67 | dim_mlp =
self.encoder_q.conv1.weight.size()[0] 68 | fc_q["stage%d"%(stage)] = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), nn.Linear(dim_mlp, dim)) 69 | fc_k["stage%d"%(stage)] = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), nn.Linear(dim_mlp, dim)) 70 | self.encoder_q.fc_csg = nn.ModuleDict(fc_q) 71 | self.encoder_k.fc_csg = nn.ModuleDict(fc_k) 72 | for param_q, param_k in zip(self.encoder_q.fc_csg.parameters(), self.encoder_k.fc_csg.parameters()): 73 | param_k.data.copy_(param_q.data) 74 | 75 | for param_k in self.encoder_k.parameters(): 76 | param_k.requires_grad = False # not update by gradient 77 | 78 | if type(self.encoder_q).__name__ == "ResNet": 79 | try: 80 | # BottleNeck 81 | dims = [self.encoder_q.conv1.weight.size()[0]] + [getattr(self.encoder_q, "layer%d"%stage)[-1].conv3.weight.size()[0] for stage in range(1, 5)] 82 | except: 83 | # BasicBlock 84 | dims = [self.encoder_q.conv1.weight.size()[0]] + [getattr(self.encoder_q, "layer%d"%stage)[-1].conv2.weight.size()[0] for stage in range(1, 5)] 85 | elif type(self.encoder_q).__name__ == "DigitNet": 86 | dims = [64 for stage in range(1, 5)] 87 | for stage in stages: 88 | self.register_buffer("queue%d"%(stage), torch.randn(dim, K)) 89 | setattr(self, "queue%d"%(stage), nn.functional.normalize(getattr(self, "queue%d"%(stage)), dim=0)) 90 | self.register_buffer("queue_ptr%d"%(stage), torch.zeros(1, dtype=torch.long)) 91 | 92 | def control_q_backbone_gradient(self, control): 93 | for name, param in self.encoder_q.named_parameters(): 94 | if 'fc_new' not in name: 95 | param.requires_grad = control 96 | return 97 | 98 | @torch.no_grad() 99 | def _momentum_update_key_encoder(self): 100 | """ 101 | Momentum update of the key encoder 102 | """ 103 | for param_q, param_k in zip(self.encoder_q.fc_csg.parameters(), self.encoder_k.fc_csg.parameters()): 104 | param_k.data = param_k.data * self.m + param_q.data * (1. 
- self.m) 105 | 106 | @torch.no_grad() 107 | def _dequeue_and_enqueue(self, keys, stage): 108 | # gather keys before updating queue 109 | keys = concat_all_gather(keys) 110 | 111 | batch_size = keys.shape[0] 112 | 113 | ptr = int(getattr(self, "queue_ptr%d"%(stage))) 114 | 115 | if ptr + batch_size <= self.K: 116 | getattr(self, "queue%d"%(stage))[:, ptr:ptr + batch_size] = keys.T 117 | else: 118 | getattr(self, "queue%d"%(stage))[:, ptr:] = keys[:(self.K - ptr)].T 119 | getattr(self, "queue%d"%(stage))[:, :ptr + batch_size - self.K] = keys[(self.K - ptr):].T # wrap the remaining keys around to the front of the queue 120 | ptr = (ptr + batch_size) % self.K # move pointer 121 | getattr(self, "queue_ptr%d"%(stage))[0] = ptr 122 | 123 | def adaptive_pool(self, features, attn_from, stage_idx): 124 | # features and attn_from are paired feature maps, of same size 125 | assert features.size() == attn_from.size() 126 | N, C, H, W = features.size() 127 | assert (attn_from >= 0).float().sum() == N*C*H*W 128 | attention = torch.einsum('nchw,nc->nhw', [attn_from, nn.functional.adaptive_avg_pool2d(attn_from, (1, 1)).view(N, C)]) 129 | attention = attention / attention.view(N, -1).sum(1).view(N, 1, 1).repeat(1, H, W) 130 | attention = attention.view(N, 1, H, W) 131 | # output size: N, C 132 | return (features * attention).view(N, C, -1).sum(2) 133 | 134 | def forward(self, im_q, im_k): 135 | """ 136 | Input: 137 | im_q: a batch of query images 138 | im_k: a batch of key images 139 | Output: 140 | logits, targets 141 | """ 142 | if im_k is None: 143 | im_k = im_q 144 | 145 | output, features_new = self.encoder_q(im_q, output_features=["layer%d"%stage for stage in self.stages], task=self.task) 146 | results = {'output': output} 147 | 148 | results['predictions_csg'] = [] 149 | results['targets_csg'] = [] 150 | # predictions: cosine b/w q and k 151 | # targets: zeros 152 | with torch.no_grad(): # no gradient to keys 153 | if self.mlp: 154 | self._momentum_update_key_encoder() # update the key encoder 155 | if self.apool: 156 | # A-Pool: prepare attention for teacher: get feature of im_k by encoder_q 157 | _, features_new_k = self.encoder_q.forward_backbone(im_k, output_features=["layer%d"%stage for stage in self.stages]) 158 | _, features_old = self.encoder_k.forward_backbone(im_k, output_features=["layer%d"%stage for stage in self.stages]) 159 | for idx, stage in enumerate(self.stages): 160 | chunk = self.chunks[idx] 161 | # compute query features 162 | 163 | q_feature = chunk_feature(features_new["layer%d"%stage], chunk) 164 | if self.apool: 165 | # A-Pool prepare attention for teacher: get feature of im_k by encoder_q 166 | q_feature_k = chunk_feature(features_new_k["layer%d"%stage], chunk) 167 | if self.mlp: 168 | if self.apool: 169 | q = self.encoder_q.fc_csg["stage%d"%(stage)](self.adaptive_pool(q_feature, q_feature, idx)) # A-Pool 170 | else: 171 | q = self.encoder_q.fc_csg["stage%d"%(stage)](self.encoder_q.avgpool(q_feature).view(features_new["layer%d"%stage].size(0)*chunk**2, -1)) 172 | else: 173 | if self.apool != 'none': 174 | q = self.adaptive_pool(q_feature, q_feature, idx) # A-Pool 175 | else: 176 | q = self.encoder_q.avgpool(q_feature).view(features_new["layer%d"%stage].size(0)*chunk**2, -1) 177 | q = nn.functional.normalize(q, dim=1) 178 | 179 | # compute key features 180 | with torch.no_grad(): # no gradient to keys 181 | k_feature = chunk_feature(features_old["layer%d"%stage], chunk) 182 | # A-Pool ############# 183 | if self.mlp: 184 | if self.apool: 185 | k = self.encoder_k.fc_csg["stage%d"%(stage)](self.adaptive_pool(k_feature,
q_feature_k, idx)) # A-Pool 186 | else: 187 | k = self.encoder_k.fc_csg["stage%d"%(stage)](self.encoder_k.avgpool(k_feature).view(features_old["layer%d"%stage].size(0)*chunk**2, -1)) 188 | else: 189 | if self.apool: 190 | k = self.adaptive_pool(k_feature, q_feature_k, idx) # A-Pool 191 | else: 192 | k = self.encoder_k.avgpool(k_feature).view(features_old["layer%d"%stage].size(0)*chunk**2, -1) 193 | # ##################### 194 | k = nn.functional.normalize(k, dim=1) 195 | 196 | # compute logits 197 | # positive logits: Nx1 198 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 199 | # negative logits: NxK 200 | l_neg = torch.einsum('nc,ck->nk', [q, getattr(self, "queue%d"%(stage)).clone().detach()]) 201 | # logits: Nx(1+K) 202 | logits = torch.cat([l_pos, l_neg], dim=1) 203 | # apply temperature 204 | logits /= self.T 205 | # labels: positive key indicators 206 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 207 | self._dequeue_and_enqueue(k, stage) 208 | 209 | results['predictions_csg'].append(logits) 210 | results['targets_csg'].append(labels) 211 | 212 | return results 213 | 214 | 215 | # utils 216 | @torch.no_grad() 217 | def concat_all_gather(tensor): 218 | """ 219 | Performs all_gather operation on the provided tensors. 220 | *** Warning ***: torch.distributed.all_gather has no gradient. 221 | """ 222 | try: 223 | tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())] 224 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 225 | except Exception: 226 | tensors_gather = [tensor] 227 | 228 | output = torch.cat(tensors_gather, dim=0) 229 | return output 230 | -------------------------------------------------------------------------------- /model/deeplab.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 
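The forward() pass that concludes csg_builder.py above reduces, at every stage, to a standard InfoNCE problem: the positive logit is the cosine similarity between the query q and its momentum-encoder key k, the negatives are read from that stage's memory queue, and the target is always index 0. A minimal standalone sketch of the logit construction and the ring-buffer enqueue (toy sizes; all names below are illustrative, not part of the repo):

import torch
import torch.nn.functional as F

N, C, K, T = 8, 128, 1024, 0.07                    # batch, feature dim, queue size, temperature
q = F.normalize(torch.randn(N, C), dim=1)          # query features (encoder_q)
k = F.normalize(torch.randn(N, C), dim=1)          # key features (encoder_k)
queue = F.normalize(torch.randn(C, K), dim=0)      # per-stage negative queue

l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # Nx1 positive logits
l_neg = torch.einsum('nc,ck->nk', [q, queue])           # NxK negative logits
logits = torch.cat([l_pos, l_neg], dim=1) / T           # Nx(1+K)
labels = torch.zeros(N, dtype=torch.long)               # positives sit at column 0
loss = F.cross_entropy(logits, labels)

# ring-buffer enqueue, as in _dequeue_and_enqueue (non-wrapping case)
ptr = 0
queue[:, ptr:ptr + N] = k.T
ptr = (ptr + N) % K

Because q and k are L2-normalized, cross-entropy against the all-zero labels pulls each query toward its own key and away from the K queued negatives.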
3 | 4 | import torch.nn as nn 5 | import torch 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from pdb import set_trace as bp 9 | affine_par = True 10 | 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | "3x3 convolution with padding" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = nn.BatchNorm2d(planes, affine = affine_par) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes, affine = affine_par) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(planes,affine = affine_par) 58 | for i in self.bn1.parameters(): 59 | i.requires_grad = False 60 | 61 | padding = dilation 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=padding, bias=False, dilation = dilation) 64 | self.bn2 = nn.BatchNorm2d(planes,affine = affine_par) 65 | for i in self.bn2.parameters(): 66 | i.requires_grad = False 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4, affine = affine_par) 69 | for i in self.bn3.parameters(): 70 | i.requires_grad = False 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class ASPPConv(nn.Sequential): 99 | def __init__(self, in_channels, out_channels, dilation): 100 | modules = [ 101 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 102 | nn.BatchNorm2d(out_channels), 103 | nn.ReLU() 104 | ] 105 | super(ASPPConv, self).__init__(*modules) 106 | 107 | 108 | class ASPPPooling(nn.Sequential): 109 | def __init__(self, in_channels, out_channels): 110 | super(ASPPPooling, self).__init__( 111 | nn.AdaptiveAvgPool2d(1), 112 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 113 | nn.BatchNorm2d(out_channels), 114 | nn.ReLU()) 115 | 116 | def forward(self, x): 117 | size = x.shape[-2:] 118 | for mod in self: 119 | x = mod(x) 120 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 121 | 122 | 123 | class ASPP(nn.Module): 124 | def __init__(self, in_channels, atrous_rates, out_channels=256): 125 | super(ASPP, self).__init__() 126 | 
modules = [] 127 | modules.append(nn.Sequential( 128 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 129 | nn.BatchNorm2d(out_channels), 130 | nn.ReLU())) 131 | 132 | rates = tuple(atrous_rates) 133 | for rate in rates: 134 | modules.append(ASPPConv(in_channels, out_channels, rate)) 135 | 136 | modules.append(ASPPPooling(in_channels, out_channels)) 137 | 138 | self.convs = nn.ModuleList(modules) 139 | 140 | self.project = nn.Sequential( 141 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 142 | nn.BatchNorm2d(out_channels), 143 | nn.ReLU(), 144 | nn.Dropout(0.5)) 145 | 146 | def forward(self, x): 147 | res = [] 148 | for conv in self.convs: 149 | res.append(conv(x)) 150 | res = torch.cat(res, dim=1) 151 | return self.project(res) 152 | 153 | 154 | class Classifier_Module(nn.Module): 155 | 156 | def __init__(self, dilation_series, padding_series, num_classes): 157 | super(Classifier_Module, self).__init__() 158 | self.conv2d_list = nn.ModuleList() 159 | 160 | self.conv2d_list = nn.ModuleList([nn.Sequential( 161 | ASPP(2048, [12, 24, 36]), 162 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 163 | nn.BatchNorm2d(256), 164 | nn.ReLU(), 165 | nn.Conv2d(256, num_classes, 1) 166 | )]) 167 | 168 | def forward(self, x): 169 | out = self.conv2d_list[0](x) 170 | for i in range(len(self.conv2d_list)-1): 171 | out += self.conv2d_list[i+1](x) 172 | return out 173 | 174 | 175 | class ResNet(nn.Module): 176 | def __init__(self, num_classes=1000, num_seg_classes=19, pretrained=False, block=Bottleneck, layers=[3, 4, 23, 3]): 177 | self.inplanes = 64 178 | super(ResNet, self).__init__() 179 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 180 | bias=False) 181 | self.bn1 = nn.BatchNorm2d(64, affine = affine_par) 182 | for i in self.bn1.parameters(): 183 | i.requires_grad = False 184 | self.relu = nn.ReLU(inplace=True) 185 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # , ceil_mode=True) # change 186 | self.layer1 = self._make_layer(block, 64, layers[0]) 187 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 188 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 189 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 190 | self.fc_new = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], num_seg_classes) 191 | self.fc = nn.Linear(512 * block.expansion, num_classes) 192 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 193 | if pretrained: 194 | model_urls = { 195 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 196 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 197 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 198 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 199 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 200 | } 201 | if layers == [3, 4, 6, 3]: 202 | saved_state_dict = torch.utils.model_zoo.load_url(model_urls['resnet50']) 203 | elif layers == [3, 4, 23, 3]: 204 | saved_state_dict = torch.utils.model_zoo.load_url(model_urls['resnet101']) 205 | new_params = self.state_dict().copy() 206 | for i in saved_state_dict: 207 | i_parts = str(i).split('.') 208 | if not i_parts[0] == 'fc': 209 | assert '.'.join(i_parts[0:]) in new_params 210 | new_params['.'.join(i_parts[0:])] = saved_state_dict[i] 211 | self.load_state_dict(new_params) 212 | else: 213 | for m in self.modules(): 214 | if isinstance(m, nn.Conv2d): 215 | n = 
m.kernel_size[0] * m.kernel_size[1] * m.out_channels
216 |                 m.weight.data.normal_(0, 0.01)
217 |             elif isinstance(m, nn.BatchNorm2d):
218 |                 m.weight.data.fill_(1)
219 |                 m.bias.data.zero_()
220 | 
221 |     def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
222 |         downsample = None
223 |         if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:
224 |             downsample = nn.Sequential(
225 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
226 |                           kernel_size=1, stride=stride, bias=False),
227 |                 nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
228 |             for i in downsample._modules['1'].parameters():
229 |                 i.requires_grad = False
230 |         layers = []
231 |         layers.append(block(self.inplanes, planes, stride, dilation=(dilation//2) if dilation > 1 else dilation, downsample=downsample))
232 |         self.inplanes = planes * block.expansion
233 |         for i in range(1, blocks):
234 |             layers.append(block(self.inplanes, planes, dilation=dilation))
235 | 
236 |         return nn.Sequential(*layers)
237 | 
238 |     def _make_pred_layer(self, block, dilation_series, padding_series, num_classes):
239 |         return block(dilation_series, padding_series, num_classes)
240 | 
241 |     def forward_fc(self, f4, task='old'):
242 |         x = f4
243 |         if task in ['old', 'new']:
244 |             x = self.avgpool(x)
245 |             x = x.reshape(x.size(0), -1)
246 |         if task == 'old':
247 |             x = self.fc(x)
248 |         else:
249 |             x = self.fc_new(x)
250 |             x = nn.functional.interpolate(x, size=self.input_size, mode='bilinear', align_corners=True)
251 |         return x
252 | 
253 |     def forward_partial(self, feature, stage):
254 |         # stage: start forwarding **from** this stage (inclusive)
255 |         if stage <= 1:
256 |             feature = self.layer1(feature)
257 |         if stage <= 2:
258 |             feature = self.layer2(feature)
259 |         if stage <= 3:
260 |             feature = self.layer3(feature)
261 |         if stage <= 4:
262 |             feature = self.layer4(feature)
263 |         return feature
264 | 
265 |     def forward_backbone(self, x, output_features=['layer4']):
266 |         features = {}
267 |         x = self.conv1(x)
268 |         x = self.bn1(x)
269 |         x = self.relu(x)
270 |         if 'layer0' in output_features: features['layer0'] = x  # stem output after relu, before maxpool
271 |         x = self.maxpool(x)
272 |         f1 = self.layer1(x)
273 |         if 'layer1' in output_features: features['layer1'] = f1
274 |         f2 = self.layer2(f1)
275 |         if 'layer2' in output_features: features['layer2'] = f2
276 |         f3 = self.layer3(f2)
277 |         if 'layer3' in output_features: features['layer3'] = f3
278 |         f4 = self.layer4(f3)
279 |         if 'layer4' in output_features: features['layer4'] = f4
280 |         if 'gap' in output_features:
281 |             features['gap'] = self.avgpool(f4).view(f4.size(0), -1)
282 |         return f4, features
283 | 
284 |     def forward(self, x, output_features=['layer4'], task='old'):
285 |         '''
286 |         task: 'old' | 'new' | 'new_seg'
287 |             'old', 'new': classification tasks (ImageNet or Visda)
288 |             'new_seg': segmentation head (convs)
289 |         '''
290 |         self.input_size = x.size()[2:]
291 |         f4, features = self.forward_backbone(x, output_features)
292 |         x = self.forward_fc(f4, task=task)
293 |         return x, features
294 | 
--------------------------------------------------------------------------------
/model/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
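The fc_new head built in deeplab.py above (Classifier_Module) is an ASPP block followed by a 3x3 conv and a 1x1 classifier. A quick shape sanity-check for the ASPP block alone, assuming the repo root is on PYTHONPATH; the spatial size below is arbitrary:

import torch
from model.deeplab import ASPP

# Five parallel branches (a 1x1 conv, three dilated 3x3 convs, global pooling)
# are concatenated to 5 * out_channels and projected back to out_channels.
aspp = ASPP(in_channels=2048, atrous_rates=[12, 24, 36], out_channels=256)
x = torch.randn(2, 2048, 33, 33)   # e.g. a layer4 feature map
print(aspp(x).shape)               # torch.Size([2, 256, 33, 33])

The concatenation is why self.project expects 5 * out_channels input channels: one branch per atrous rate plus the 1x1 and image-pooling branches.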
3 | 
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | from pdb import set_trace as bp
9 | from model.csg_builder import chunk_feature
10 | 
11 | 
12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
13 |            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
14 | 
15 | 
16 | model_urls = {
17 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
18 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
19 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
20 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
21 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
22 |     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
23 |     'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
24 | }
25 | 
26 | 
27 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
28 |     """3x3 convolution with padding"""
29 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
30 |                      padding=dilation, groups=groups, bias=False, dilation=dilation)
31 | 
32 | 
33 | def conv1x1(in_planes, out_planes, stride=1):
34 |     """1x1 convolution"""
35 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
36 | 
37 | 
38 | class BasicBlock(nn.Module):
39 |     expansion = 1
40 | 
41 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
42 |                  base_width=64, dilation=1, norm_layer=None):
43 |         super(BasicBlock, self).__init__()
44 |         if norm_layer is None:
45 |             norm_layer = nn.BatchNorm2d
46 |         if groups != 1 or base_width != 64:
47 |             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
48 |         if dilation > 1:
49 |             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
50 |         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
51 |         self.conv1 = conv3x3(inplanes, planes, stride)
52 |         self.bn1 = norm_layer(planes)
53 |         self.relu = nn.ReLU(inplace=True)
54 |         self.conv2 = conv3x3(planes, planes)
55 |         self.bn2 = norm_layer(planes)
56 |         self.downsample = downsample
57 |         self.stride = stride
58 | 
59 |     def forward(self, x):
60 |         identity = x
61 | 
62 |         out = self.conv1(x)
63 |         out = self.bn1(out)
64 |         out = self.relu(out)
65 | 
66 |         out = self.conv2(out)
67 |         out = self.bn2(out)
68 | 
69 |         if self.downsample is not None:
70 |             identity = self.downsample(x)
71 | 
72 |         out += identity
73 |         out = self.relu(out)
74 | 
75 |         return out
76 | 
77 | 
78 | class Bottleneck(nn.Module):
79 |     expansion = 4
80 | 
81 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
82 |                  base_width=64, dilation=1, norm_layer=None, last=False,
83 |                  ):
84 |         super(Bottleneck, self).__init__()
85 |         if norm_layer is None:
86 |             norm_layer = nn.BatchNorm2d
87 |         width = int(planes * (base_width / 64.)) * groups
88 |         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
89 |         self.conv1 = conv1x1(inplanes, width)
90 |         self.bn1 = norm_layer(width)
91 |         self.conv2 = conv3x3(width, width, stride, groups, dilation)
92 |         self.bn2 = norm_layer(width)
93 |         self.conv3 = conv1x1(width, planes * self.expansion)
94 |         self.bn3 = norm_layer(planes * self.expansion)
95 |         self.relu = nn.ReLU(inplace=True)
96 |         self.last = last  # stored flag for the final block of a stage; not used in forward
97 |         self.downsample = downsample
98 |         self.stride = stride
99 | 
100 |     def forward(self, x):
101 |         identity = x
102 | 
103 |         out = self.conv1(x)
104 |         out = 
self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv3(out) 112 | out = self.bn3(out) 113 | 114 | if self.downsample is not None: 115 | identity = self.downsample(x) 116 | 117 | out += identity 118 | out = self.relu(out) 119 | 120 | return out 121 | 122 | 123 | class ResNet(nn.Module): 124 | 125 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 126 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 127 | norm_layer=None): 128 | super(ResNet, self).__init__() 129 | if norm_layer is None: 130 | norm_layer = nn.BatchNorm2d 131 | self._norm_layer = norm_layer 132 | 133 | self.inplanes = 64 134 | self.dilation = 1 135 | if replace_stride_with_dilation is None: 136 | # each element in the tuple indicates if we should replace 137 | # the 2x2 stride with a dilated convolution instead 138 | replace_stride_with_dilation = [False, False, False] 139 | if len(replace_stride_with_dilation) != 3: 140 | raise ValueError("replace_stride_with_dilation should be None " 141 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 142 | self.groups = groups 143 | self.base_width = width_per_group 144 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 145 | bias=False) 146 | self.bn1 = norm_layer(self.inplanes) 147 | self.relu = nn.ReLU(inplace=True) 148 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 149 | self.layer1 = self._make_layer(block, 64, layers[0]) 150 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 151 | dilate=replace_stride_with_dilation[0]) 152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 153 | dilate=replace_stride_with_dilation[1]) 154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 155 | dilate=replace_stride_with_dilation[2]) 156 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 157 | self.fc = nn.Linear(512 * block.expansion, num_classes) 158 | 159 | for m in self.modules(): 160 | if isinstance(m, nn.Conv2d): 161 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 162 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 163 | nn.init.constant_(m.weight, 1) 164 | nn.init.constant_(m.bias, 0) 165 | 166 | # Zero-initialize the last BN in each residual branch, 167 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
168 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 169 | if zero_init_residual: 170 | for m in self.modules(): 171 | if isinstance(m, Bottleneck): 172 | nn.init.constant_(m.bn3.weight, 0) 173 | elif isinstance(m, BasicBlock): 174 | nn.init.constant_(m.bn2.weight, 0) 175 | 176 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 177 | norm_layer = self._norm_layer 178 | downsample = None 179 | previous_dilation = self.dilation 180 | if dilate: 181 | self.dilation *= stride 182 | stride = 1 183 | if stride != 1 or self.inplanes != planes * block.expansion: 184 | downsample = nn.Sequential( 185 | conv1x1(self.inplanes, planes * block.expansion, stride), 186 | norm_layer(planes * block.expansion), 187 | ) 188 | 189 | layers = [] 190 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 191 | self.base_width, previous_dilation, norm_layer)) 192 | self.inplanes = planes * block.expansion 193 | for _idx in range(1, blocks): 194 | layers.append(block(self.inplanes, planes, groups=self.groups, 195 | base_width=self.base_width, dilation=self.dilation, 196 | norm_layer=norm_layer, 197 | )) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def forward_fc(self, f4, task='old', f3=None, f2=None, return_mid_feature=False): 202 | x = f4 203 | if task in ['old', 'new']: 204 | x = self.avgpool(x) 205 | x = x.reshape(x.size(0), -1) 206 | if task == 'old': 207 | x = self.fc(x) 208 | return x 209 | else: 210 | if return_mid_feature: 211 | mid = self.fc_new[0](x) 212 | x = self.fc_new[1](mid) 213 | x = self.fc_new[2](x) 214 | return x, mid 215 | else: 216 | x = self.fc_new(x) 217 | return x 218 | 219 | def forward_partial(self, feature, stage): 220 | # stage: start forwarding **from** this stage (inclusive) 221 | # assert stage in [1, 2, 3, 4] 222 | if stage <= 1: 223 | feature = self.layer1(feature) 224 | if stage <= 2: 225 | feature = self.layer2(feature) 226 | if stage <= 3: 227 | feature = self.layer3(feature) 228 | if stage <= 4: 229 | feature = self.layer4(feature) 230 | return feature 231 | 232 | def forward_backbone(self, x, output_features=['layer4']): 233 | features = {} 234 | f0 = self.conv1(x) 235 | f0 = self.bn1(f0) 236 | f0 = self.relu(f0) 237 | if 'layer0' in output_features: features['layer0'] = f0 238 | f0 = self.maxpool(f0) 239 | f1 = self.layer1(f0) 240 | if 'layer1' in output_features: features['layer1'] = f1 241 | f2 = self.layer2(f1) 242 | if 'layer2' in output_features: features['layer2'] = f2 243 | f3 = self.layer3(f2) 244 | if 'layer3' in output_features: features['layer3'] = f3 245 | f4 = self.layer4(f3) 246 | if 'layer4' in output_features: features['layer4'] = f4 247 | if 'gap' in output_features: 248 | features['gap'] = self.avgpool(f4).view(f4.size(0), -1) 249 | return f4, features 250 | # return f4, f3, f2, features 251 | 252 | def forward(self, x, output_features=['layer4'], task='old'): 253 | ''' 254 | task: 'old' | 'new' | 'new_seg' 255 | 'old', 'new': classification tasks (ImageNet or Visda) 256 | 'new_seg': segmentation head (convs) 257 | ''' 258 | f4, features = self.forward_backbone(x, output_features) 259 | if 'fc_mid' in output_features: 260 | x, _mid = self.forward_fc(f4, task=task, return_mid_feature=True) 261 | features['fc_mid'] = _mid 262 | else: 263 | x = self.forward_fc(f4, task=task) 264 | return x, features 265 | 266 | 267 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 268 | model = ResNet(block, layers, **kwargs) 269 | if pretrained: 270 | from 
torchvision.models.utils import load_state_dict_from_url 271 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 272 | # model.load_state_dict(state_dict) 273 | state = model.state_dict() 274 | pretrained_dict = {k: v for k, v in state_dict.items() if k in state and state[k].size() == v.size()} 275 | state.update(pretrained_dict) 276 | model.load_state_dict(state) 277 | return model 278 | 279 | 280 | def resnet18(pretrained=False, progress=True, **kwargs): 281 | """Constructs a ResNet-18 model. 282 | 283 | Args: 284 | pretrained (bool): If True, returns a model pre-trained on ImageNet 285 | progress (bool): If True, displays a progress bar of the download to stderr 286 | """ 287 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 288 | **kwargs) 289 | 290 | 291 | def resnet34(pretrained=False, progress=True, **kwargs): 292 | """Constructs a ResNet-34 model. 293 | 294 | Args: 295 | pretrained (bool): If True, returns a model pre-trained on ImageNet 296 | progress (bool): If True, displays a progress bar of the download to stderr 297 | """ 298 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 299 | **kwargs) 300 | 301 | 302 | def resnet50(pretrained=False, progress=True, **kwargs): 303 | """Constructs a ResNet-50 model. 304 | 305 | Args: 306 | pretrained (bool): If True, returns a model pre-trained on ImageNet 307 | progress (bool): If True, displays a progress bar of the download to stderr 308 | """ 309 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 310 | **kwargs) 311 | 312 | 313 | def resnet101(pretrained=False, progress=True, **kwargs): 314 | """Constructs a ResNet-101 model. 315 | 316 | Args: 317 | pretrained (bool): If True, returns a model pre-trained on ImageNet 318 | progress (bool): If True, displays a progress bar of the download to stderr 319 | """ 320 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 321 | **kwargs) 322 | 323 | 324 | def resnet152(pretrained=False, progress=True, **kwargs): 325 | """Constructs a ResNet-152 model. 326 | 327 | Args: 328 | pretrained (bool): If True, returns a model pre-trained on ImageNet 329 | progress (bool): If True, displays a progress bar of the download to stderr 330 | """ 331 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 332 | **kwargs) 333 | 334 | 335 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 336 | """Constructs a ResNeXt-50 32x4d model. 337 | 338 | Args: 339 | pretrained (bool): If True, returns a model pre-trained on ImageNet 340 | progress (bool): If True, displays a progress bar of the download to stderr 341 | """ 342 | kwargs['groups'] = 32 343 | kwargs['width_per_group'] = 4 344 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 345 | pretrained, progress, **kwargs) 346 | 347 | 348 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 349 | """Constructs a ResNeXt-101 32x8d model. 
350 | 
351 |     Args:
352 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
353 |         progress (bool): If True, displays a progress bar of the download to stderr
354 |     """
355 |     kwargs['groups'] = 32
356 |     kwargs['width_per_group'] = 8
357 |     return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
358 |                    pretrained, progress, **kwargs)
359 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | easydict
2 | matplotlib==3.0.0
3 | numpy==1.16.1
4 | opencv-python==3.4.4.19
5 | Pillow==6.2.0
6 | scipy==1.1.0
7 | tensorflow
8 | tensorboard==1.9.0
9 | tensorboardX==1.6
10 | torch==1.2.0
11 | torchvision==0.3.0
12 | tqdm==4.25.0
--------------------------------------------------------------------------------
/tools/datasets/BaseDataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 | 
4 | import os
5 | import cv2
6 | cv2.setNumThreads(0)
7 | import torch
8 | import numpy as np
9 | from random import shuffle
10 | 
11 | import torch.utils.data as data
12 | 
13 | 
14 | class BaseDataset(data.Dataset):
15 |     def __init__(self, setting, split_name, preprocess=None, file_length=None):
16 |         super(BaseDataset, self).__init__()
17 |         self._split_name = split_name
18 |         self._img_path = setting['img_root']
19 |         self._gt_path = setting['gt_root']
20 |         self._portion = setting['portion'] if 'portion' in setting else None
21 |         self._train_source = setting['train_source']
22 |         self._eval_source = setting['eval_source']
23 |         self._test_source = setting['test_source'] if 'test_source' in setting else setting['eval_source']
24 |         self._down_sampling = setting['down_sampling']
25 |         print("using downsampling:", self._down_sampling)
26 |         self._file_names = self._get_file_names(split_name)
27 |         print("Found %d images"%len(self._file_names))
28 |         self._file_length = file_length
29 |         self.preprocess = preprocess
30 | 
31 |     def __len__(self):
32 |         if self._file_length is not None:
33 |             return self._file_length
34 |         return len(self._file_names)
35 | 
36 |     def __getitem__(self, index):
37 |         if self._file_length is not None:
38 |             names = self._construct_new_file_names(self._file_length)[index]
39 |         else:
40 |             names = self._file_names[index]
41 |         img_path = os.path.join(self._img_path, names[0])
42 |         gt_path = os.path.join(self._gt_path, names[1])
43 |         item_name = names[1].split("/")[-1].split(".")[0]
44 | 
45 |         img, gt = self._fetch_data(img_path, gt_path)
46 |         img = img[:, :, ::-1]
47 |         if self.preprocess is not None:
48 |             img, gt, extra_dict = self.preprocess(img, gt)
49 | 
50 |         if self._split_name == 'train':
51 |             img = torch.from_numpy(np.ascontiguousarray(img)).float()
52 |             gt = torch.from_numpy(np.ascontiguousarray(gt)).long()
53 |             if self.preprocess is not None and extra_dict is not None:
54 |                 for k, v in extra_dict.items():
55 |                     extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v))
56 |                     if 'label' in k:
57 |                         extra_dict[k] = extra_dict[k].long()
58 |                     if 'img' in k:
59 |                         extra_dict[k] = extra_dict[k].float()
60 | 
61 |         output_dict = dict(data=img, label=gt, fn=str(item_name),
62 |                            n=len(self._file_names))
63 |         if self.preprocess is not None and extra_dict is not None:
64 |             output_dict.update(**extra_dict)
65 | 
66 |         return output_dict
67 | 
68 |     def _fetch_data(self, img_path, gt_path, dtype=None):
69 
| img = self._open_image(img_path, down_sampling=self._down_sampling) 70 | gt = self._open_image(gt_path, cv2.IMREAD_GRAYSCALE, dtype=dtype, down_sampling=self._down_sampling) 71 | 72 | return img, gt 73 | 74 | def _get_file_names(self, split_name): 75 | assert split_name in ['train', 'val', 'test'] 76 | source = self._train_source 77 | if split_name == "val": 78 | source = self._eval_source 79 | elif split_name == 'test': 80 | source = self._test_source 81 | 82 | file_names = [] 83 | with open(source) as f: 84 | files = f.readlines() 85 | if self._portion is not None: 86 | shuffle(files) 87 | num_files = len(files) 88 | if self._portion > 0: 89 | split = int(np.floor(self._portion * num_files)) 90 | files = files[:split] 91 | elif self._portion < 0: 92 | split = int(np.floor((1 + self._portion) * num_files)) 93 | files = files[split:] 94 | 95 | for item in files: 96 | img_name, gt_name = self._process_item_names(item) 97 | file_names.append([img_name, gt_name]) 98 | 99 | return file_names 100 | 101 | def _construct_new_file_names(self, length): 102 | assert isinstance(length, int) 103 | files_len = len(self._file_names) 104 | new_file_names = self._file_names * (length // files_len) 105 | 106 | rand_indices = torch.randperm(files_len).tolist() 107 | new_indices = rand_indices[:length % files_len] 108 | 109 | new_file_names += [self._file_names[i] for i in new_indices] 110 | 111 | return new_file_names 112 | 113 | @staticmethod 114 | def _process_item_names(item): 115 | item = item.strip() 116 | # item = item.split('\t') 117 | item = item.split(' ') 118 | img_name = item[0] 119 | gt_name = item[1] 120 | 121 | return img_name, gt_name 122 | 123 | def get_length(self): 124 | return self.__len__() 125 | 126 | @staticmethod 127 | def _open_image(filepath, mode=cv2.IMREAD_COLOR, dtype=None, down_sampling=1): 128 | # cv2: B G R 129 | # h w c 130 | img = np.array(cv2.imread(filepath, mode), dtype=dtype) 131 | 132 | if isinstance(down_sampling, int): 133 | H, W = img.shape[:2] 134 | if len(img.shape) == 3: 135 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_LINEAR) 136 | else: 137 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_NEAREST) 138 | assert img.shape[0] == H // down_sampling and img.shape[1] == W // down_sampling 139 | else: 140 | assert (isinstance(down_sampling, tuple) or isinstance(down_sampling, list)) and len(down_sampling) == 2 141 | if len(img.shape) == 3: 142 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_LINEAR) 143 | else: 144 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_NEAREST) 145 | assert img.shape[0] == down_sampling[0] and img.shape[1] == down_sampling[1] 146 | 147 | return img 148 | 149 | @classmethod 150 | def get_class_colors(*args): 151 | raise NotImplementedError 152 | 153 | @classmethod 154 | def get_class_names(*args): 155 | raise NotImplementedError 156 | 157 | 158 | if __name__ == "__main__": 159 | data_setting = {'img_root': '', 160 | 'gt_root': '', 161 | 'train_source': '', 162 | 'eval_source': ''} 163 | bd = BaseDataset(data_setting, 'train', None) 164 | print(bd.get_class_names()) 165 | -------------------------------------------------------------------------------- /tools/datasets/cityscapes/cityscapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.)) 3 | 4 | import numpy as np 5 | 6 | from datasets.BaseDataset import BaseDataset 7 | 8 | 9 | class Cityscapes(BaseDataset): 10 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 11 | 28, 31, 32, 33] 12 | 13 | @classmethod 14 | def get_class_colors(*args): 15 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70], 16 | [102, 102, 156], [190, 153, 153], [153, 153, 153], 17 | [250, 170, 30], [220, 220, 0], [107, 142, 35], 18 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], 19 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], 20 | [0, 0, 230], [119, 11, 32]] 21 | 22 | @classmethod 23 | def get_class_names(*args): 24 | # class counting(gtFine) 25 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832 26 | # 359 274 142 513 1646 27 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 28 | 'traffic light', 'traffic sign', 29 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 30 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle'] 31 | 32 | @classmethod 33 | def transform_label(cls, pred, name): 34 | label = np.zeros(pred.shape) 35 | ids = np.unique(pred) 36 | for id in ids: 37 | label[np.where(pred == id)] = cls.trans_labels[id] 38 | 39 | new_name = (name.split('.')[0]).split('_')[:-1] 40 | new_name = '_'.join(new_name) + '.png' 41 | 42 | print('Trans', name, 'to', new_name, ' ', 43 | np.unique(np.array(pred, np.uint8)), ' ---------> ', 44 | np.unique(np.array(label, np.uint8))) 45 | return label, new_name 46 | -------------------------------------------------------------------------------- /tools/engine/evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.)) 3 | 4 | import os 5 | from PIL import Image 6 | import cv2 7 | import numpy as np 8 | import time 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import torch.multiprocessing as mp 13 | 14 | from tools.engine.logger import get_logger 15 | from tools.utils.pyt_utils import load_model, link_file, ensure_dir 16 | from tools.utils.img_utils import pad_image_to_shape, normalize 17 | 18 | logger = get_logger() 19 | 20 | 21 | class Evaluator(object): 22 | def __init__(self, dataset, class_num, image_mean, image_std, network, 23 | multi_scales, is_flip, devices=0, threds=3, config=None, logger=None, 24 | verbose=False, save_path=None, show_image=False, show_prediction=False): 25 | self.dataset = dataset 26 | self.ndata = self.dataset.get_length() 27 | self.class_num = class_num 28 | self.image_mean = image_mean 29 | self.image_std = image_std 30 | self.multi_scales = multi_scales 31 | self.is_flip = is_flip 32 | self.network = network 33 | self.devices = devices 34 | if type(self.devices) == int: self.devices = [self.devices] 35 | self.threds = threds 36 | self.config = config 37 | self.logger = logger 38 | 39 | self.context = mp.get_context('spawn') 40 | self.val_func = None 41 | self.results_queue = self.context.Queue(self.ndata) 42 | self.features_queue = self.context.Queue(self.ndata) 43 | 44 | self.verbose = verbose 45 | self.save_path = save_path 46 | if save_path is not None: 47 | ensure_dir(save_path) 48 | self.show_image = show_image 49 | self.show_prediction = show_prediction 50 | 51 | def run(self, model_path, model_indice, log_file, log_file_link): 52 | """There are four evaluation modes: 53 | 1.only eval a .pth model: -e *.pth 54 | 2.only eval a certain epoch: -e epoch 55 | 3.eval all epochs in a given section: -e start_epoch-end_epoch 56 | 4.eval all epochs from a certain started epoch: -e start_epoch- 57 | """ 58 | if '.pth' in model_indice: 59 | models = [model_indice, ] 60 | elif "-" in model_indice: 61 | start_epoch = int(model_indice.split("-")[0]) 62 | end_epoch = model_indice.split("-")[1] 63 | 64 | models = os.listdir(model_path) 65 | models.remove("epoch-last.pth") 66 | sorted_models = [None] * len(models) 67 | model_idx = [0] * len(models) 68 | 69 | for idx, m in enumerate(models): 70 | num = m.split(".")[0].split("-")[1] 71 | model_idx[idx] = num 72 | sorted_models[idx] = m 73 | model_idx = np.array([int(i) for i in model_idx]) 74 | 75 | down_bound = model_idx >= start_epoch 76 | up_bound = [True] * len(sorted_models) 77 | if end_epoch: 78 | end_epoch = int(end_epoch) 79 | assert start_epoch < end_epoch 80 | up_bound = model_idx <= end_epoch 81 | bound = up_bound * down_bound 82 | model_slice = np.array(sorted_models)[bound] 83 | models = [os.path.join(model_path, model) for model in 84 | model_slice] 85 | else: 86 | models = [os.path.join(model_path, 87 | 'epoch-%s.pth' % model_indice), ] 88 | 89 | results = open(log_file, 'a') 90 | link_file(log_file, log_file_link) 91 | 92 | for model in models: 93 | logger.info("Load Model: %s" % model) 94 | self.val_func = load_model(self.network, model) 95 | result_line, mIoU = self.multi_process_evaluation() 96 | 97 | results.write('Model: ' + model + '\n') 98 | results.write(result_line) 99 | results.write('\n') 100 | results.flush() 101 | 102 | results.close() 103 | 104 | def run_online(self): 105 | """ 106 | eval during training 107 | """ 108 | self.val_func = self.network 109 | result_line, mIoU = self.single_process_evaluation() 110 | return 
result_line, mIoU 111 | 112 | def single_process_evaluation(self): 113 | all_results = [] 114 | from pdb import set_trace as bp 115 | with torch.no_grad(): 116 | for idx in tqdm(range(self.ndata)): 117 | dd = self.dataset[idx] 118 | results_dict = self.func_per_iteration(dd, self.devices[0], iter=idx) 119 | all_results.append(results_dict) 120 | _, _mIoU = self.compute_metric([results_dict]) 121 | result_line, mIoU = self.compute_metric(all_results) 122 | return result_line, mIoU 123 | 124 | def run_online_multiprocess(self): 125 | """ 126 | eval during training 127 | """ 128 | self.val_func = self.network 129 | result_line, mIoU = self.multi_process_single_gpu_evaluation() 130 | return result_line, mIoU 131 | 132 | def multi_process_single_gpu_evaluation(self): 133 | # start_eval_time = time.perf_counter() 134 | stride = int(np.ceil(self.ndata / self.threds)) 135 | 136 | # start multi-process on single-gpu 137 | procs = [] 138 | for d in range(self.threds): 139 | e_record = min((d + 1) * stride, self.ndata) 140 | shred_list = list(range(d * stride, e_record)) 141 | device = self.devices[0] 142 | logger.info('Thread %d handle %d data.' % (d, len(shred_list))) 143 | p = self.context.Process(target=self.worker, args=(shred_list, device)) 144 | procs.append(p) 145 | 146 | for p in procs: 147 | p.start() 148 | 149 | all_results = [] 150 | for _ in tqdm(range(self.ndata)): 151 | t = self.results_queue.get() 152 | all_results.append(t) 153 | if self.verbose: 154 | self.compute_metric(all_results) 155 | 156 | for p in procs: 157 | p.join() 158 | 159 | result_line, mIoU = self.compute_metric(all_results) 160 | return result_line, mIoU 161 | 162 | def multi_process_evaluation(self): 163 | start_eval_time = time.perf_counter() 164 | nr_devices = len(self.devices) 165 | stride = int(np.ceil(self.ndata / nr_devices)) 166 | 167 | # start multi-process on multi-gpu 168 | procs = [] 169 | for d in range(nr_devices): 170 | e_record = min((d + 1) * stride, self.ndata) 171 | shred_list = list(range(d * stride, e_record)) 172 | device = self.devices[d] 173 | logger.info('GPU %s handle %d data.' 
% (device, len(shred_list))) 174 | p = self.context.Process(target=self.worker, args=(shred_list, device)) 175 | procs.append(p) 176 | 177 | for p in procs: 178 | p.start() 179 | 180 | all_results = [] 181 | for _ in tqdm(range(self.ndata)): 182 | t = self.results_queue.get() 183 | all_results.append(t) 184 | if self.verbose: 185 | self.compute_metric(all_results) 186 | 187 | for p in procs: 188 | p.join() 189 | 190 | result_line, mIoU = self.compute_metric(all_results) 191 | logger.info('Evaluation Elapsed Time: %.2fs' % (time.perf_counter() - start_eval_time)) 192 | return result_line, mIoU 193 | 194 | def worker(self, shred_list, device): 195 | # start_load_time = time.time() 196 | # logger.info('Load Model on Device %d: %.2fs' % (device, time.time() - start_load_time)) 197 | for idx in shred_list: 198 | dd = self.dataset[idx] 199 | results_dict = self.func_per_iteration(dd, device, iter=idx) 200 | self.results_queue.put(results_dict) 201 | 202 | 203 | def func_per_iteration(self, data, device, iter=None): 204 | raise NotImplementedError 205 | 206 | def compute_metric(self, results): 207 | raise NotImplementedError 208 | 209 | # evaluate the whole image at once 210 | def whole_eval(self, img, output_size, resize=None, input_size=None, device=None): 211 | if input_size is not None: 212 | img, margin = self.process_image(img, resize=resize, crop_size=input_size) 213 | else: 214 | img = self.process_image(img, resize=resize, crop_size=input_size) 215 | 216 | pred = self.val_func_process(img, device) 217 | if input_size is not None: 218 | pred = pred[:, margin[0]:(pred.shape[1] - margin[1]), margin[2]:(pred.shape[2] - margin[3])] 219 | pred = pred.permute(1, 2, 0) 220 | pred = pred.cpu().numpy() 221 | if output_size is not None: 222 | pred = cv2.resize(pred, 223 | (output_size[1], output_size[0]), 224 | interpolation=cv2.INTER_LINEAR) 225 | 226 | # pred = pred.argmax(2) 227 | 228 | return pred 229 | 230 | # slide the window to evaluate the image 231 | def sliding_eval(self, img, crop_size, stride_rate, device=None): 232 | ori_rows, ori_cols, c = img.shape 233 | processed_pred = np.zeros((ori_rows, ori_cols, self.class_num)) 234 | 235 | for s in self.multi_scales: 236 | img_scale = cv2.resize(img, None, fx=s, fy=s, 237 | interpolation=cv2.INTER_LINEAR) 238 | new_rows, new_cols, _ = img_scale.shape 239 | processed_pred += self.scale_process(img_scale, 240 | (ori_rows, ori_cols), 241 | crop_size, stride_rate, device) 242 | 243 | pred = processed_pred.argmax(2) 244 | 245 | return pred 246 | 247 | def scale_process(self, img, ori_shape, crop_size, stride_rate, 248 | device=None): 249 | new_rows, new_cols, c = img.shape 250 | long_size = new_cols if new_cols > new_rows else new_rows 251 | 252 | if long_size <= crop_size: 253 | input_data, margin = self.process_image(img, crop_size=crop_size) 254 | score = self.val_func_process(input_data, device) 255 | score = score[:, margin[0]:(score.shape[1] - margin[1]), 256 | margin[2]:(score.shape[2] - margin[3])] 257 | else: 258 | stride = int(np.ceil(crop_size * stride_rate)) 259 | img_pad, margin = pad_image_to_shape(img, crop_size, 260 | cv2.BORDER_CONSTANT, value=0) 261 | 262 | pad_rows = img_pad.shape[0] 263 | pad_cols = img_pad.shape[1] 264 | r_grid = int(np.ceil((pad_rows - crop_size) / stride)) + 1 265 | c_grid = int(np.ceil((pad_cols - crop_size) / stride)) + 1 266 | data_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda( 267 | device) 268 | count_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda( 269 | device) 270 | 271 | 
for grid_yidx in range(r_grid):
272 |                 for grid_xidx in range(c_grid):
273 |                     s_x = grid_xidx * stride
274 |                     s_y = grid_yidx * stride
275 |                     e_x = min(s_x + crop_size, pad_cols)
276 |                     e_y = min(s_y + crop_size, pad_rows)
277 |                     s_x = e_x - crop_size
278 |                     s_y = e_y - crop_size
279 |                     img_sub = img_pad[s_y:e_y, s_x: e_x, :]
280 |                     count_scale[:, s_y: e_y, s_x: e_x] += 1
281 | 
282 |                     input_data, tmargin = self.process_image(img_sub, crop_size=crop_size)
283 |                     temp_score = self.val_func_process(input_data, device)
284 |                     temp_score = temp_score[:,
285 |                                  tmargin[0]:(temp_score.shape[1] - tmargin[1]),
286 |                                  tmargin[2]:(temp_score.shape[2] - tmargin[3])]
287 |                     data_scale[:, s_y: e_y, s_x: e_x] += temp_score
288 |             # score = data_scale / count_scale
289 |             score = data_scale
290 |             score = score[:, margin[0]:(score.shape[1] - margin[1]),
291 |                     margin[2]:(score.shape[2] - margin[3])]
292 | 
293 |         score = score.permute(1, 2, 0)
294 |         data_output = cv2.resize(score.cpu().numpy(),
295 |                                  (ori_shape[1], ori_shape[0]),
296 |                                  interpolation=cv2.INTER_LINEAR)
297 | 
298 |         return data_output
299 | 
300 |     def val_func_process(self, input_data, device=None):
301 |         input_data = np.ascontiguousarray(input_data[None, :, :, :], dtype=np.float32)
302 |         input_data = torch.FloatTensor(input_data).cuda(device)
303 | 
304 |         with torch.cuda.device(input_data.get_device()):
305 |             self.val_func.eval()
306 |             self.val_func.to(input_data.get_device())
307 |             with torch.no_grad():
308 |                 score = self.val_func(input_data, output_features=[], task='new_seg')[0]
309 |                 if (isinstance(score, tuple) or isinstance(score, list)) and len(score) > 1:
310 |                     score = score[self.out_idx]
311 |                 score = score[0] # a single image pass, ignore batch dim
312 | 
313 |                 if self.is_flip:
314 |                     input_data = input_data.flip(-1)
315 |                     score_flip = self.val_func(input_data, output_features=[], task='new_seg')[0]
316 |                     score_flip = score_flip[0] # a single image pass, ignore batch dim
317 |                     score += score_flip.flip(-1)
318 |                 score = torch.exp(score)
319 |                 # score = score.data
320 | 
321 |         return score
322 | 
323 |     def process_image(self, img, resize=None, crop_size=None):
324 |         p_img = img
325 | 
326 |         if img.shape[2] < 3:
327 |             im_b = p_img
328 |             im_g = p_img
329 |             im_r = p_img
330 |             p_img = np.concatenate((im_b, im_g, im_r), axis=2)
331 | 
332 |         if resize is not None:
333 |             if isinstance(resize, float):
334 |                 _size = p_img.shape[:2]
335 |                 # p_img = np.array(Image.fromarray(p_img).resize((int(_size[0]*resize), int(_size[1]*resize)), Image.BILINEAR))
336 |                 p_img = np.array(Image.fromarray(p_img).resize((int(_size[1]*resize), int(_size[0]*resize)), Image.BILINEAR))
337 |             elif isinstance(resize, tuple) or isinstance(resize, list):
338 |                 assert len(resize) == 2
339 |                 p_img = np.array(Image.fromarray(p_img).resize((int(resize[0]), int(resize[1])), Image.BILINEAR))
340 | 
341 |         p_img = normalize(p_img, self.image_mean, self.image_std)
342 | 
343 |         if crop_size is not None:
344 |             p_img, margin = pad_image_to_shape(p_img, crop_size, cv2.BORDER_CONSTANT, value=0)
345 |             p_img = p_img.transpose(2, 0, 1)
346 | 
347 |             return p_img, margin
348 | 
349 |         p_img = p_img.transpose(2, 0, 1)
350 | 
351 |         return p_img
352 | 
--------------------------------------------------------------------------------
/tools/engine/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.)) 3 | 4 | import os 5 | import sys 6 | import logging 7 | 8 | _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO') 9 | _default_level = logging.getLevelName(_default_level_name.upper()) 10 | 11 | 12 | class LogFormatter(logging.Formatter): 13 | log_fout = None 14 | date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] ' 15 | date = '%(asctime)s ' 16 | msg = '%(message)s' 17 | 18 | def format(self, record): 19 | if record.levelno == logging.DEBUG: 20 | mcl, mtxt = self._color_dbg, 'DBG' 21 | elif record.levelno == logging.WARNING: 22 | mcl, mtxt = self._color_warn, 'WRN' 23 | elif record.levelno == logging.ERROR: 24 | mcl, mtxt = self._color_err, 'ERR' 25 | else: 26 | mcl, mtxt = self._color_normal, '' 27 | 28 | if mtxt: 29 | mtxt += ' ' 30 | 31 | if self.log_fout: 32 | self.__set_fmt(self.date_full + mtxt + self.msg) 33 | formatted = super(LogFormatter, self).format(record) 34 | # self.log_fout.write(formatted) 35 | # self.log_fout.write('\n') 36 | # self.log_fout.flush() 37 | return formatted 38 | 39 | self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg)) 40 | formatted = super(LogFormatter, self).format(record) 41 | 42 | return formatted 43 | 44 | if sys.version_info.major < 3: 45 | def __set_fmt(self, fmt): 46 | self._fmt = fmt 47 | else: 48 | def __set_fmt(self, fmt): 49 | self._style._fmt = fmt 50 | 51 | @staticmethod 52 | def _color_dbg(msg): 53 | return '\x1b[36m{}\x1b[0m'.format(msg) 54 | 55 | @staticmethod 56 | def _color_warn(msg): 57 | return '\x1b[1;31m{}\x1b[0m'.format(msg) 58 | 59 | @staticmethod 60 | def _color_err(msg): 61 | return '\x1b[1;4;31m{}\x1b[0m'.format(msg) 62 | 63 | @staticmethod 64 | def _color_omitted(msg): 65 | return '\x1b[35m{}\x1b[0m'.format(msg) 66 | 67 | @staticmethod 68 | def _color_normal(msg): 69 | return msg 70 | 71 | @staticmethod 72 | def _color_date(msg): 73 | return '\x1b[32m{}\x1b[0m'.format(msg) 74 | 75 | 76 | def get_logger(log_dir=None, log_file=None, formatter=LogFormatter): 77 | logger = logging.getLogger() 78 | logger.setLevel(_default_level) 79 | del logger.handlers[:] 80 | 81 | if log_dir and log_file: 82 | if not os.path.isdir(log_dir): os.makedirs(log_dir) 83 | LogFormatter.log_fout = True 84 | file_handler = logging.FileHandler(log_file, mode='a') 85 | file_handler.setLevel(logging.INFO) 86 | file_handler.setFormatter(formatter) 87 | logger.addHandler(file_handler) 88 | 89 | stream_handler = logging.StreamHandler() 90 | stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S')) 91 | stream_handler.setLevel(0) 92 | logger.addHandler(stream_handler) 93 | return logger 94 | -------------------------------------------------------------------------------- /tools/engine/tester.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.)) 3 | 4 | import os 5 | import cv2 6 | import numpy as np 7 | import time 8 | from tqdm import tqdm 9 | from pdb import set_trace as bp 10 | import torch 11 | import torch.multiprocessing as mp 12 | 13 | from engine.logger import get_logger 14 | from tools.utils.pyt_utils import load_model, link_file, ensure_dir 15 | from tools.utils.img_utils import pad_image_to_shape, normalize 16 | 17 | logger = get_logger() 18 | 19 | 20 | class Tester(object): 21 | def __init__(self, dataset, class_num, image_mean, image_std, network, 22 | multi_scales, is_flip, devices=0, out_idx=0, threds=3, config=None, logger=None, 23 | verbose=False, save_path=None, show_image=False): 24 | self.dataset = dataset 25 | self.ndata = self.dataset.get_length() 26 | self.class_num = class_num 27 | self.image_mean = image_mean 28 | self.image_std = image_std 29 | self.multi_scales = multi_scales 30 | self.is_flip = is_flip 31 | self.network = network 32 | self.devices = devices 33 | if type(self.devices) == int: self.devices = [self.devices] 34 | self.out_idx = out_idx 35 | self.threds = threds 36 | self.config = config 37 | self.logger = logger 38 | 39 | self.context = mp.get_context('spawn') 40 | self.val_func = None 41 | self.results_queue = self.context.Queue(self.ndata) 42 | 43 | self.verbose = verbose 44 | self.save_path = save_path 45 | if save_path is not None: 46 | ensure_dir(save_path) 47 | self.show_image = show_image 48 | 49 | def run(self, model_path, model_indice, log_file, log_file_link): 50 | """There are four evaluation modes: 51 | 1.only eval a .pth model: -e *.pth 52 | 2.only eval a certain epoch: -e epoch 53 | 3.eval all epochs in a given section: -e start_epoch-end_epoch 54 | 4.eval all epochs from a certain started epoch: -e start_epoch- 55 | """ 56 | if '.pth' in model_indice: 57 | models = [model_indice, ] 58 | elif "-" in model_indice: 59 | start_epoch = int(model_indice.split("-")[0]) 60 | end_epoch = model_indice.split("-")[1] 61 | 62 | models = os.listdir(model_path) 63 | models.remove("epoch-last.pth") 64 | sorted_models = [None] * len(models) 65 | model_idx = [0] * len(models) 66 | 67 | for idx, m in enumerate(models): 68 | num = m.split(".")[0].split("-")[1] 69 | model_idx[idx] = num 70 | sorted_models[idx] = m 71 | model_idx = np.array([int(i) for i in model_idx]) 72 | 73 | down_bound = model_idx >= start_epoch 74 | up_bound = [True] * len(sorted_models) 75 | if end_epoch: 76 | end_epoch = int(end_epoch) 77 | assert start_epoch < end_epoch 78 | up_bound = model_idx <= end_epoch 79 | bound = up_bound * down_bound 80 | model_slice = np.array(sorted_models)[bound] 81 | models = [os.path.join(model_path, model) for model in 82 | model_slice] 83 | else: 84 | models = [os.path.join(model_path, 85 | 'epoch-%s.pth' % model_indice), ] 86 | 87 | results = open(log_file, 'a') 88 | link_file(log_file, log_file_link) 89 | 90 | for model in models: 91 | logger.info("Load Model: %s" % model) 92 | self.val_func = load_model(self.network, model) 93 | result_line, mIoU = self.multi_process_evaluation() 94 | 95 | results.write('Model: ' + model + '\n') 96 | results.write(result_line) 97 | results.write('\n') 98 | results.flush() 99 | 100 | results.close() 101 | 102 | def run_online(self): 103 | """ 104 | eval during training 105 | """ 106 | self.val_func = self.network 107 | self.single_process_evaluation() 108 | 109 | def single_process_evaluation(self): 110 | with torch.no_grad(): 111 | for idx in tqdm(range(self.ndata)): 112 | 
dd = self.dataset[idx] 113 | self.func_per_iteration(dd, self.devices[0], iter=idx) 114 | 115 | def run_online_multiprocess(self): 116 | """ 117 | eval during training 118 | """ 119 | self.val_func = self.network 120 | self.multi_process_single_gpu_evaluation() 121 | 122 | def multi_process_single_gpu_evaluation(self): 123 | # start_eval_time = time.perf_counter() 124 | stride = int(np.ceil(self.ndata / self.threds)) 125 | 126 | # start multi-process on single-gpu 127 | procs = [] 128 | for d in range(self.threds): 129 | e_record = min((d + 1) * stride, self.ndata) 130 | shred_list = list(range(d * stride, e_record)) 131 | device = self.devices[0] 132 | logger.info('Thread %d handle %d data.' % (d, len(shred_list))) 133 | p = self.context.Process(target=self.worker, args=(shred_list, device)) 134 | procs.append(p) 135 | 136 | for p in procs: 137 | p.start() 138 | 139 | for p in procs: 140 | p.join() 141 | 142 | def multi_process_evaluation(self): 143 | nr_devices = len(self.devices) 144 | stride = int(np.ceil(self.ndata / nr_devices)) 145 | 146 | # start multi-process on multi-gpu 147 | procs = [] 148 | for d in range(nr_devices): 149 | e_record = min((d + 1) * stride, self.ndata) 150 | shred_list = list(range(d * stride, e_record)) 151 | device = self.devices[d] 152 | logger.info('GPU %s handle %d data.' % (device, len(shred_list))) 153 | p = self.context.Process(target=self.worker, args=(shred_list, device)) 154 | procs.append(p) 155 | 156 | for p in procs: 157 | p.start() 158 | 159 | for p in procs: 160 | p.join() 161 | 162 | def worker(self, shred_list, device): 163 | start_load_time = time.time() 164 | # logger.info('Load Model on Device %d: %.2fs' % (device, time.time() - start_load_time)) 165 | for idx in shred_list: 166 | dd = self.dataset[idx] 167 | results_dict = self.func_per_iteration(dd, device, iter=idx) 168 | self.results_queue.put(results_dict) 169 | 170 | def func_per_iteration(self, data, device, iter=None): 171 | raise NotImplementedError 172 | 173 | def compute_metric(self, results): 174 | raise NotImplementedError 175 | 176 | # evaluate the whole image at once 177 | def whole_eval(self, img, output_size, input_size=None, device=None): 178 | if input_size is not None: 179 | img, margin = self.process_image(img, input_size) 180 | else: 181 | img = self.process_image(img, input_size) 182 | 183 | pred = self.val_func_process(img, device) 184 | if input_size is not None: 185 | pred = pred[:, margin[0]:(pred.shape[1] - margin[1]), margin[2]:(pred.shape[2] - margin[3])] 186 | pred = pred.permute(1, 2, 0) 187 | pred = pred.cpu().numpy() 188 | if output_size is not None: 189 | pred = cv2.resize(pred, 190 | (output_size[1], output_size[0]), 191 | interpolation=cv2.INTER_LINEAR) 192 | 193 | pred = pred.argmax(2) 194 | 195 | return pred 196 | 197 | # slide the window to evaluate the image 198 | def sliding_eval(self, img, crop_size, stride_rate, device=None): 199 | ori_rows, ori_cols, c = img.shape 200 | processed_pred = np.zeros((ori_rows, ori_cols, self.class_num)) 201 | 202 | for s in self.multi_scales: 203 | img_scale = cv2.resize(img, None, fx=s, fy=s, 204 | interpolation=cv2.INTER_LINEAR) 205 | new_rows, new_cols, _ = img_scale.shape 206 | processed_pred += self.scale_process(img_scale, 207 | (ori_rows, ori_cols), 208 | crop_size, stride_rate, device) 209 | 210 | pred = processed_pred.argmax(2) 211 | 212 | return pred 213 | 214 | def scale_process(self, img, ori_shape, crop_size, stride_rate, 215 | device=None): 216 | new_rows, new_cols, c = img.shape 217 | long_size = 
new_cols if new_cols > new_rows else new_rows
218 | 
219 |         if long_size <= crop_size:
220 |             input_data, margin = self.process_image(img, crop_size)
221 |             score = self.val_func_process(input_data, device)
222 |             score = score[:, margin[0]:(score.shape[1] - margin[1]),
223 |                     margin[2]:(score.shape[2] - margin[3])]
224 |         else:
225 |             stride = int(np.ceil(crop_size * stride_rate))
226 |             img_pad, margin = pad_image_to_shape(img, crop_size,
227 |                                                  cv2.BORDER_CONSTANT, value=0)
228 | 
229 |             pad_rows = img_pad.shape[0]
230 |             pad_cols = img_pad.shape[1]
231 |             r_grid = int(np.ceil((pad_rows - crop_size) / stride)) + 1
232 |             c_grid = int(np.ceil((pad_cols - crop_size) / stride)) + 1
233 |             data_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda(
234 |                 device)
235 |             count_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda(
236 |                 device)
237 | 
238 |             for grid_yidx in range(r_grid):
239 |                 for grid_xidx in range(c_grid):
240 |                     s_x = grid_xidx * stride
241 |                     s_y = grid_yidx * stride
242 |                     e_x = min(s_x + crop_size, pad_cols)
243 |                     e_y = min(s_y + crop_size, pad_rows)
244 |                     s_x = e_x - crop_size
245 |                     s_y = e_y - crop_size
246 |                     img_sub = img_pad[s_y:e_y, s_x: e_x, :]
247 |                     count_scale[:, s_y: e_y, s_x: e_x] += 1
248 | 
249 |                     input_data, tmargin = self.process_image(img_sub, crop_size)
250 |                     temp_score = self.val_func_process(input_data, device)
251 |                     temp_score = temp_score[:,
252 |                                  tmargin[0]:(temp_score.shape[1] - tmargin[1]),
253 |                                  tmargin[2]:(temp_score.shape[2] - tmargin[3])]
254 |                     data_scale[:, s_y: e_y, s_x: e_x] += temp_score
255 |             # score = data_scale / count_scale
256 |             score = data_scale
257 |             score = score[:, margin[0]:(score.shape[1] - margin[1]),
258 |                     margin[2]:(score.shape[2] - margin[3])]
259 | 
260 |         score = score.permute(1, 2, 0)
261 |         data_output = cv2.resize(score.cpu().numpy(),
262 |                                  (ori_shape[1], ori_shape[0]),
263 |                                  interpolation=cv2.INTER_LINEAR)
264 | 
265 |         return data_output
266 | 
267 |     def val_func_process(self, input_data, device=None):
268 |         input_data = np.ascontiguousarray(input_data[None, :, :, :], dtype=np.float32)
269 |         input_data = torch.FloatTensor(input_data).cuda(device)
270 | 
271 |         with torch.cuda.device(input_data.get_device()):
272 |             self.val_func.eval()
273 |             self.val_func.to(input_data.get_device())
274 |             with torch.no_grad():
275 |                 score = self.val_func(input_data, task='new_seg')[0]
276 |                 if (isinstance(score, tuple) or isinstance(score, list)) and len(score) > 1:
277 |                     score = score[self.out_idx]
278 |                 score = score[0] # a single image pass, ignore batch dim
279 | 
280 |                 if self.is_flip:
281 |                     input_data = input_data.flip(-1)
282 |                     score_flip = self.val_func(input_data, task='new_seg')[0]
283 |                     score_flip = score_flip[0] # a single image pass, ignore batch dim
284 |                     score += score_flip.flip(-1)
285 |                 score = torch.exp(score)
286 |                 # score = score.data
287 | 
288 |         return score
289 | 
290 |     def process_image(self, img, crop_size=None):
291 |         p_img = img
292 | 
293 |         if img.shape[2] < 3:
294 |             im_b = p_img
295 |             im_g = p_img
296 |             im_r = p_img
297 |             p_img = np.concatenate((im_b, im_g, im_r), axis=2)
298 | 
299 |         p_img = normalize(p_img, self.image_mean, self.image_std)
300 | 
301 |         if crop_size is not None:
302 |             p_img, margin = pad_image_to_shape(p_img, crop_size, cv2.BORDER_CONSTANT, value=0)
303 |             p_img = p_img.transpose(2, 0, 1)
304 | 
305 |             return p_img, margin
306 | 
307 |         p_img = p_img.transpose(2, 0, 1)
308 | 
309 |         return p_img
310 | 
--------------------------------------------------------------------------------
/tools/seg_opr/metric.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 3 | 4 | import numpy as np 5 | 6 | np.seterr(divide='ignore', invalid='ignore') 7 | 8 | 9 | # voc cityscapes metric 10 | def hist_info(n_cl, pred, gt): 11 | assert (pred.shape == gt.shape), "pred: " + str(pred.shape) + " v.s. gt: " + str(gt.shape) 12 | k = (gt >= 0) & (gt < n_cl) 13 | labeled = np.sum(k) 14 | correct = np.sum((pred[k] == gt[k])) 15 | 16 | return np.bincount(n_cl * gt[k].astype(int) + pred[k].astype(int), 17 | minlength=n_cl ** 2).reshape(n_cl, 18 | n_cl), labeled, correct 19 | 20 | 21 | def compute_score(hist, correct, labeled): 22 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 23 | mean_IU = np.nanmean(iu) 24 | mean_IU_no_back = np.nanmean(iu[1:]) 25 | mean_pixel_acc = correct / labeled 26 | 27 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc 28 | 29 | 30 | # ade metric 31 | def meanIoU(area_intersection, area_union): 32 | iou = 1.0 * np.sum(area_intersection, axis=1) / np.sum(area_union, axis=1) 33 | meaniou = np.nanmean(iou) 34 | meaniou_no_back = np.nanmean(iou[1:]) 35 | 36 | return iou, meaniou, meaniou_no_back 37 | 38 | 39 | def intersectionAndUnion(imPred, imLab, numClass): 40 | # Remove classes from unlabeled pixels in gt image. 41 | # We should not penalize detections in unlabeled portions of the image. 42 | imPred = np.asarray(imPred).copy() 43 | imLab = np.asarray(imLab).copy() 44 | 45 | imPred += 1 46 | imLab += 1 47 | # Remove classes from unlabeled pixels in gt image. 48 | # We should not penalize detections in unlabeled portions of the image. 49 | imPred = imPred * (imLab > 0) 50 | 51 | # imPred = imPred * (imLab >= 0) 52 | 53 | # Compute area intersection: 54 | intersection = imPred * (imPred == imLab) 55 | (area_intersection, _) = np.histogram(intersection, bins=numClass, 56 | range=(1, numClass)) 57 | 58 | # Compute area union: 59 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 60 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 61 | area_union = area_pred + area_lab - area_intersection 62 | 63 | return area_intersection, area_union 64 | 65 | 66 | def mean_pixel_accuracy(pixel_correct, pixel_labeled): 67 | mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / ( 68 | np.spacing(1) + np.sum(pixel_labeled)) 69 | 70 | return mean_pixel_accuracy 71 | 72 | 73 | def pixelAccuracy(imPred, imLab): 74 | # Remove classes from unlabeled pixels in gt image. 75 | # We should not penalize detections in unlabeled portions of the image. 76 | pixel_labeled = np.sum(imLab >= 0) 77 | pixel_correct = np.sum((imPred == imLab) * (imLab >= 0)) 78 | pixel_accuracy = 1.0 * pixel_correct / pixel_labeled 79 | 80 | return pixel_accuracy, pixel_correct, pixel_labeled 81 | 82 | 83 | def accuracy(preds, label): 84 | valid = (label >= 0) 85 | acc_sum = (valid * (preds == label)).sum() 86 | valid_sum = valid.sum() 87 | acc = float(acc_sum) / (valid_sum + 1e-10) 88 | return acc, valid_sum 89 | -------------------------------------------------------------------------------- /tools/utils/img_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 
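# A minimal usage sketch of the padding helper defined below (example values
# are illustrative only, not from the original code):
#
# >>> import cv2, numpy as np
# >>> img = np.zeros((480, 600, 3), np.uint8)
# >>> padded, margin = pad_image_to_shape(img, 512, cv2.BORDER_CONSTANT, value=0)
# >>> padded.shape   # (512, 600, 3): only the short side is padded
# >>> margin         # array([16, 16, 0, 0]): (top, bottom, left, right)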
3 | 4 | import cv2 5 | import numpy as np 6 | import numbers 7 | import random 8 | import collections 9 | 10 | 11 | def get_2dshape(shape, *, zero=True): 12 | if not isinstance(shape, collections.Iterable): 13 | shape = int(shape) 14 | shape = (shape, shape) 15 | else: 16 | h, w = map(int, shape) 17 | shape = (h, w) 18 | if zero: 19 | minv = 0 20 | else: 21 | minv = 1 22 | 23 | assert min(shape) >= minv, 'invalid shape: {}'.format(shape) 24 | return shape 25 | 26 | 27 | def random_crop_pad_to_shape(img, crop_pos, crop_size, pad_label_value): 28 | h, w = img.shape[:2] 29 | start_crop_h, start_crop_w = crop_pos 30 | assert ((start_crop_h < h) and (start_crop_h >= 0)) 31 | assert ((start_crop_w < w) and (start_crop_w >= 0)) 32 | 33 | crop_size = get_2dshape(crop_size) 34 | crop_h, crop_w = crop_size 35 | 36 | img_crop = img[start_crop_h:start_crop_h + crop_h, 37 | start_crop_w:start_crop_w + crop_w, ...] 38 | 39 | img_, margin = pad_image_to_shape(img_crop, crop_size, cv2.BORDER_CONSTANT, 40 | pad_label_value) 41 | 42 | return img_, margin 43 | 44 | 45 | def generate_random_crop_pos(ori_size, crop_size): 46 | ori_size = get_2dshape(ori_size) 47 | h, w = ori_size 48 | 49 | crop_size = get_2dshape(crop_size) 50 | crop_h, crop_w = crop_size 51 | 52 | pos_h, pos_w = 0, 0 53 | 54 | if h > crop_h: 55 | pos_h = random.randint(0, h - crop_h + 1) 56 | 57 | if w > crop_w: 58 | pos_w = random.randint(0, w - crop_w + 1) 59 | 60 | return pos_h, pos_w 61 | 62 | 63 | def pad_image_to_shape(img, shape, border_mode, value): 64 | margin = np.zeros(4, np.uint32) 65 | shape = get_2dshape(shape) 66 | pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0 67 | pad_width = shape[1] - img.shape[1] if shape[1] - img.shape[1] > 0 else 0 68 | 69 | margin[0] = pad_height // 2 70 | margin[1] = pad_height // 2 + pad_height % 2 71 | margin[2] = pad_width // 2 72 | margin[3] = pad_width // 2 + pad_width % 2 73 | 74 | img = cv2.copyMakeBorder(img, margin[0], margin[1], margin[2], margin[3], 75 | border_mode, value=value) 76 | 77 | return img, margin 78 | 79 | 80 | def pad_image_size_to_multiples_of(img, multiple, pad_value): 81 | h, w = img.shape[:2] 82 | d = multiple 83 | 84 | def canonicalize(s): 85 | v = s // d 86 | return (v + (v * d != s)) * d 87 | 88 | th, tw = map(canonicalize, (h, w)) 89 | 90 | return pad_image_to_shape(img, (th, tw), cv2.BORDER_CONSTANT, pad_value) 91 | 92 | 93 | def resize_ensure_shortest_edge(img, edge_length, 94 | interpolation_mode=cv2.INTER_LINEAR): 95 | assert isinstance(edge_length, int) and edge_length > 0, edge_length 96 | h, w = img.shape[:2] 97 | if h < w: 98 | ratio = float(edge_length) / h 99 | th, tw = edge_length, max(1, int(ratio * w)) 100 | else: 101 | ratio = float(edge_length) / w 102 | th, tw = max(1, int(ratio * h)), edge_length 103 | img = cv2.resize(img, (tw, th), interpolation_mode) 104 | 105 | return img 106 | 107 | 108 | def random_scale(img, gt, scales): 109 | # scale = random.choice(scales) 110 | scale = random.uniform(min(scales), max(scales)) 111 | sh = int(img.shape[0] * scale) 112 | sw = int(img.shape[1] * scale) 113 | img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR) 114 | gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST) 115 | 116 | return img, gt, scale 117 | 118 | 119 | def random_scale_with_length(img, gt, length): 120 | size = random.choice(length) 121 | sh = size 122 | sw = size 123 | img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR) 124 | gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST) 
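# note: unlike random_scale above, this variant resizes both image and label
# to a square size x size, discarding the original aspect ratio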
125 | 126 | return img, gt, size 127 | 128 | 129 | def random_mirror(img, gt): 130 | if random.random() >= 0.5: 131 | img = cv2.flip(img, 1) 132 | gt = cv2.flip(gt, 1) 133 | 134 | return img, gt, 135 | 136 | 137 | def random_rotation(img, gt): 138 | angle = random.random() * 20 - 10 139 | h, w = img.shape[:2] 140 | rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) 141 | img = cv2.warpAffine(img, rotation_matrix, (w, h), flags=cv2.INTER_LINEAR) 142 | gt = cv2.warpAffine(gt, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST) 143 | 144 | return img, gt 145 | 146 | 147 | def random_gaussian_blur(img): 148 | gauss_size = random.choice([1, 3, 5, 7]) 149 | if gauss_size > 1: 150 | # do the gaussian blur 151 | img = cv2.GaussianBlur(img, (gauss_size, gauss_size), 0) 152 | 153 | return img 154 | 155 | 156 | def center_crop(img, shape): 157 | h, w = shape[0], shape[1] 158 | y = (img.shape[0] - h) // 2 159 | x = (img.shape[1] - w) // 2 160 | return img[y:y + h, x:x + w] 161 | 162 | 163 | def random_crop(img, gt, size): 164 | if isinstance(size, numbers.Number): 165 | size = (int(size), int(size)) 166 | 167 | h, w = img.shape[:2] 168 | crop_h, crop_w = size[0], size[1] 169 | 170 | if h > crop_h: 171 | x = random.randint(0, h - crop_h + 1) 172 | img = img[x:x + crop_h, :, :] 173 | gt = gt[x:x + crop_h, :] 174 | 175 | if w > crop_w: 176 | x = random.randint(0, w - crop_w + 1) 177 | img = img[:, x:x + crop_w, :] 178 | gt = gt[:, x:x + crop_w] 179 | 180 | return img, gt 181 | 182 | 183 | def normalize(img, mean, std): 184 | # pytorch pretrained model need the input range: 0-1 185 | img = img.astype(np.float32) / 255.0 186 | img = img - mean 187 | img = img / std 188 | 189 | return img 190 | -------------------------------------------------------------------------------- /tools/utils/pyt_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 3 | 4 | # encoding: utf-8 5 | import os 6 | import time 7 | import argparse 8 | from collections import OrderedDict 9 | 10 | import torch 11 | 12 | from tools.engine.logger import get_logger 13 | 14 | logger = get_logger() 15 | 16 | model_urls = { 17 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 18 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 19 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 20 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 21 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 22 | } 23 | 24 | 25 | def load_model(model, model_file, is_restore=False): 26 | t_start = time.time() 27 | if isinstance(model_file, str): 28 | state_dict = torch.load(model_file) 29 | if 'model' in state_dict.keys(): 30 | state_dict = state_dict['model'] 31 | else: 32 | state_dict = model_file 33 | t_ioend = time.time() 34 | 35 | if is_restore: 36 | new_state_dict = OrderedDict() 37 | for k, v in state_dict.items(): 38 | name = 'module.' 
+ k 39 | new_state_dict[name] = v 40 | state_dict = new_state_dict 41 | 42 | model.load_state_dict(state_dict, strict=False) 43 | ckpt_keys = set(state_dict.keys()) 44 | own_keys = set(model.state_dict().keys()) 45 | missing_keys = own_keys - ckpt_keys 46 | unexpected_keys = ckpt_keys - own_keys 47 | 48 | if len(missing_keys) > 0: 49 | logger.warning('Missing key(s) in state_dict: {}'.format( 50 | ', '.join('{}'.format(k) for k in missing_keys))) 51 | 52 | if len(unexpected_keys) > 0: 53 | logger.warning('Unexpected key(s) in state_dict: {}'.format( 54 | ', '.join('{}'.format(k) for k in unexpected_keys))) 55 | 56 | del state_dict 57 | t_end = time.time() 58 | logger.info( 59 | "Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format( 60 | t_ioend - t_start, t_end - t_ioend)) 61 | 62 | return model 63 | 64 | 65 | def parse_devices(input_devices): 66 | if input_devices.endswith('*'): 67 | devices = list(range(torch.cuda.device_count())) 68 | return devices 69 | 70 | devices = [] 71 | for d in input_devices.split(','): 72 | if '-' in d: 73 | start_device, end_device = d.split('-')[0], d.split('-')[1] 74 | assert start_device != '' 75 | assert end_device != '' 76 | start_device, end_device = int(start_device), int(end_device) 77 | assert start_device < end_device 78 | assert end_device < torch.cuda.device_count() 79 | for sd in range(start_device, end_device + 1): 80 | devices.append(sd) 81 | else: 82 | device = int(d) 83 | assert device < torch.cuda.device_count() 84 | devices.append(device) 85 | 86 | logger.info('using devices {}'.format( 87 | ', '.join([str(d) for d in devices]))) 88 | 89 | return devices 90 | 91 | 92 | def extant_file(x): 93 | """ 94 | 'Type' for argparse - checks that file exists but does not open. 95 | """ 96 | if not os.path.exists(x): 97 | # Argparse uses the ArgumentTypeError to give a rejection message like: 98 | # error: argument input: x does not exist 99 | raise argparse.ArgumentTypeError("{0} does not exist".format(x)) 100 | return x 101 | 102 | 103 | def link_file(src, target): 104 | if os.path.isdir(target) or os.path.isfile(target): 105 | os.remove(target) 106 | os.system('ln -s {} {}'.format(src, target)) 107 | 108 | 109 | def ensure_dir(path): 110 | if not os.path.isdir(path): 111 | os.makedirs(path) 112 | 113 | 114 | def _dbg_interactive(var, value): 115 | from IPython import embed 116 | embed() 117 | -------------------------------------------------------------------------------- /tools/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import scipy.io as sio 4 | 5 | 6 | def set_img_color(colors, background, img, gt, show255=False, weight_foreground=0.55): 7 | origin = np.array(img) 8 | for i in range(len(colors)): 9 | if i != background: 10 | img[np.where(gt == i)] = colors[i] 11 | if show255: 12 | img[np.where(gt == 255)] = 0 13 | cv2.addWeighted(img, weight_foreground, origin, 1 - weight_foreground, 0, img) 14 | return img 15 | 16 | 17 | def show_prediction(colors, background, img, pred): 18 | im = np.array(img, np.uint8) 19 | set_img_color(colors, background, im, pred, weight_foreground=1) 20 | final = np.array(im) 21 | return final 22 | 23 | 24 | def show_img(colors, background, img, clean, gt, *pds): 25 | im1 = np.array(img, np.uint8) 26 | # set_img_color(colors, background, im1, clean) 27 | final = np.array(im1) 28 | # the pivot black bar 29 | pivot = np.zeros((im1.shape[0], 15, 3), dtype=np.uint8) 30 | for pd in pds: 31 | im = np.array(img, 
np.uint8) 32 | # pd[np.where(gt == 255)] = 255 33 | set_img_color(colors, background, im, pd) 34 | final = np.column_stack((final, pivot)) 35 | final = np.column_stack((final, im)) 36 | 37 | im = np.array(img, np.uint8) 38 | set_img_color(colors, background, im, gt, True) 39 | final = np.column_stack((final, pivot)) 40 | final = np.column_stack((final, im)) 41 | return final 42 | 43 | 44 | def get_colors(class_num): 45 | colors = [] 46 | for i in range(class_num): 47 | colors.append((np.random.random((1, 3)) * 255).tolist()[0]) 48 | 49 | return colors 50 | 51 | 52 | def get_ade_colors(): 53 | colors = sio.loadmat('./color150.mat')['colors'] 54 | colors = colors[:, ::-1, ] 55 | colors = np.array(colors).astype(int).tolist() 56 | colors.insert(0, [0, 0, 0]) 57 | 58 | return colors 59 | 60 | 61 | def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False, 62 | no_print=False): 63 | n = iu.size 64 | lines = [] 65 | for i in range(n): 66 | if class_names is None: 67 | cls = 'Class %d:' % (i + 1) 68 | else: 69 | cls = '%d %s' % (i + 1, class_names[i]) 70 | lines.append('%-8s\t%.3f%%' % (cls, iu[i] * 100)) 71 | mean_IU = np.nanmean(iu) 72 | # mean_IU_no_back = np.nanmean(iu[1:]) 73 | mean_IU_no_back = np.nanmean(iu[:-1]) 74 | if show_no_back: 75 | lines.append( 76 | '---------------------------- %-8s\t%.3f%%\t%-8s\t%.3f%%\t%-8s\t%.3f%%' % ( 77 | 'mean_IU', mean_IU * 100, 'mean_IU_no_back', 78 | mean_IU_no_back * 100, 79 | 'mean_pixel_ACC', mean_pixel_acc * 100)) 80 | else: 81 | print(mean_pixel_acc) 82 | lines.append( 83 | '---------------------------- %-8s\t%.3f%%\t%-8s\t%.3f%%' % ( 84 | 'mean_IU', mean_IU * 100, 'mean_pixel_ACC', 85 | mean_pixel_acc * 100)) 86 | line = "\n".join(lines) 87 | if not no_print: 88 | print(line) 89 | return line 90 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 
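# CSG training in brief: the total objective is the source-domain cross-entropy
# plus, for every ResNet stage listed in --csg-stages, a contrastive term whose
# positive key is the same image under a second augmentation and whose negatives
# come from a queue of --csg-k keys (see csg_builder). A hedged sketch of the
# per-stage term, using the names produced by csg_builder.CSG.forward in the
# training loop below:
#
# >>> _loss = CrossEntropyLoss(results['predictions_csg'][idx],
# ...                          results['targets_csg'][idx])
# >>> loss = loss + _loss * args.csg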
3 | 4 | import argparse 5 | import os 6 | import sys 7 | import logging 8 | import time 9 | from tqdm import tqdm 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.optim 14 | from torch.utils.data import DataLoader 15 | import torchvision.transforms as transforms 16 | 17 | from data.visda17 import VisDA17 18 | from data.loader_csg import TwoCropsTransform 19 | from model.resnet import resnet101 20 | from model import csg_builder 21 | from utils.utils import get_params, IterNums, save_checkpoint, AverageMeter, lr_poly, adjust_learning_rate, accuracy 22 | from utils.logger import prepare_logger, prepare_seed 23 | from utils.sgd import SGD 24 | from utils.augmentations import RandAugment, augment_list 25 | 26 | torch.backends.cudnn.enabled = True 27 | CrossEntropyLoss = nn.CrossEntropyLoss(reduction='mean') 28 | 29 | parser = argparse.ArgumentParser(description='PyTorch ResNet Training') 30 | parser.add_argument('--data', default='/home/chenwy/taskcv-2017-public/classification/data', help='path to dataset') 31 | parser.add_argument('--epochs', default=30, type=int, help='number of total epochs to run') 32 | parser.add_argument('--start-epoch', default=0, type=int, help='manual start epoch number (useful on restarts)') 33 | parser.add_argument('--batch-size', default=32, type=int, dest='batch_size', help='mini-batch size (default: 32)') 34 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, help='initial learning rate') 35 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, help='weight decay (default: 5e-4)') 36 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)') 37 | parser.add_argument('--csg', default=0.1, type=float, dest='csg', help="weight of CSG loss (default: 0.1).") 38 | parser.add_argument('--factor', default=0.1, type=float, dest='factor', help='scale factor of backbone learning rate (default: 0.1)') 39 | parser.add_argument('--csg-stages', dest='csg_stages', default='4', help='resnet stages to involve in CSG, 0~4, separated by dot') 40 | parser.add_argument('--chunks', dest='chunks', default='1', help='stage-wise chunk to feature maps, separated by dot') 41 | parser.add_argument('--no-mlp', dest='mlp', action='store_false', default=True, help='not to use mlp during contrastive learning') 42 | parser.add_argument('--apool', default=False, action='store_true', help='use A-Pool') 43 | parser.add_argument('--augment', action='store_true', default=False, help='use augmentation') 44 | parser.add_argument('--resume', default='none', type=str, help='path to latest checkpoint (default: none)') 45 | parser.add_argument('--num-class', default=12, type=int, dest='num_classes', help='the number of classes') 46 | parser.add_argument('--evaluate', action='store_true', help='only perform evaluation without training') 47 | parser.add_argument('--save_dir', type=str, default="./runs", help='root folder to save checkpoints and log.') 48 | parser.add_argument('--rand_seed', default=0, type=int, help='random seed') 49 | parser.add_argument('--csg-k', default=65536, type=int, help='queue size; number of negative keys (default: 65536)') 50 | parser.add_argument('--timestamp', type=str, default='none', help='timestamp for logging naming') 51 | parser.set_defaults(bottleneck=True) 52 | 53 | best_prec1 = 0 54 | 55 | 56 | def main(): 57 | global args, best_prec1 58 | PID = os.getpid() 59 | args = parser.parse_args() 60 | prepare_seed(args.rand_seed) 61 | 62 | if args.timestamp ==
'none': 63 | args.timestamp = "{:}".format(time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time()))) 64 | 65 | # Log outputs 66 | if args.evaluate: 67 | args.save_dir = args.save_dir + "/Visda17-Res101-evaluate" + \ 68 | "%s/%s"%('/'+args.resume.replace('/', '+') if args.resume != 'none' else '', args.timestamp) 69 | else: 70 | args.save_dir = args.save_dir + \ 71 | "/VisDA-Res101-CSG.stg{csg_stages}.w{csg_weight}-APool.{apool}-Aug.{augment}-chunk{chunks}-mlp{mlp}.K{csg_k}-LR{lr}.bone{factor}-epoch{epochs}-batch{batch_size}-seed{seed}".format( 72 | csg_stages=args.csg_stages, 73 | mlp=args.mlp, 74 | csg_weight=args.csg, 75 | apool=args.apool, 76 | augment=args.augment, 77 | chunks=args.chunks, 78 | csg_k=args.csg_k, 79 | lr="%.2E"%args.lr, 80 | factor="%.1f"%args.factor, 81 | epochs=args.epochs, 82 | batch_size=args.batch_size, 83 | seed=args.rand_seed 84 | ) + \ 85 | "%s/%s"%('/'+args.resume.replace('/', '+') if args.resume != 'none' else '', args.timestamp) 86 | logger = prepare_logger(args) 87 | 88 | data_transforms = { 89 | 'val': transforms.Compose([ 90 | transforms.Resize(224), 91 | transforms.CenterCrop(224), 92 | transforms.ToTensor(), 93 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 94 | ]), 95 | } 96 | if args.augment: 97 | data_transforms['train'] = transforms.Compose([ 98 | RandAugment(1, 6., augment_list), 99 | transforms.Resize(224), 100 | transforms.RandomCrop(224), 101 | transforms.ToTensor(), 102 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 103 | ]) 104 | else: 105 | data_transforms['train'] = transforms.Compose([ 106 | transforms.Resize(224), 107 | transforms.CenterCrop(224), 108 | transforms.ToTensor(), 109 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 110 | ]) 111 | 112 | kwargs = {'num_workers': 20, 'pin_memory': True} 113 | if args.augment: 114 | # two source 115 | trainset = VisDA17(txt_file=os.path.join(args.data, "train/image_list.txt"), root_dir=os.path.join(args.data, "train"), 116 | transform=TwoCropsTransform(data_transforms['train'], data_transforms['train'])) 117 | else: 118 | # one source 119 | trainset = VisDA17(txt_file=os.path.join(args.data, "train/image_list.txt"), root_dir=os.path.join(args.data, "train"), transform=data_transforms['train']) 120 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs) 121 | valset = VisDA17(txt_file=os.path.join(args.data, "validation/image_list.txt"), root_dir=os.path.join(args.data, "validation"), transform=data_transforms['val'], label_one_hot=True) 122 | val_loader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, **kwargs) 123 | 124 | args.stages = [int(stage) for stage in args.csg_stages.split('.')] if len(args.csg_stages) > 0 else [] 125 | chunks = [int(chunk) for chunk in args.chunks.split('.')] if len(args.chunks) > 0 else [] 126 | assert len(chunks) == 1 or len(chunks) == len(args.stages) 127 | if len(chunks) < len(args.stages): 128 | chunks = [chunks[0]] * len(args.stages) 129 | 130 | def get_head(num_ftrs, num_classes): 131 | _dim = 512 132 | return nn.Sequential( 133 | nn.Linear(num_ftrs, _dim), 134 | nn.ReLU(inplace=False), 135 | nn.Linear(_dim, num_classes), 136 | ) 137 | model = csg_builder.CSG( 138 | resnet101, get_head=get_head, K=args.csg_k, stages=args.stages, chunks=chunks, 139 | apool=args.apool, mlp=args.mlp, 140 | ) 141 | 142 | train_blocks = "conv1.bn1.layer1.layer2.layer3.layer4.fc" 143 | train_blocks = train_blocks.split('.') 144 | 
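# Two learning rates are used below: backbone blocks (conv1 through layer4)
# run at factor*args.lr (0.1x by default), while the new classification head
# (fc_new) and the contrastive projection head (fc_csg) run at the full args.lr.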
# Setup optimizer 145 | factor = args.factor 146 | sgd_in = [] 147 | for name in train_blocks: 148 | if name != 'fc': 149 | sgd_in.append({'params': get_params(model.encoder_q, [name]), 'lr': factor*args.lr}) 150 | else: 151 | # no update to fc but to fc_new 152 | sgd_in.append({'params': get_params(model.encoder_q, ["fc_new"]), 'lr': args.lr}) 153 | if model.mlp: 154 | sgd_in.append({'params': get_params(model.encoder_q, ["fc_csg"]), 'lr': args.lr}) 155 | base_lrs = [ group['lr'] for group in sgd_in ] 156 | optimizer = SGD(sgd_in, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 157 | 158 | # Optionally resume from a checkpoint 159 | if args.resume != 'none': 160 | if os.path.isfile(args.resume): 161 | print("=> loading checkpoint '{}'".format(args.resume)) 162 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage) 163 | args.start_epoch = checkpoint['epoch'] 164 | best_prec1 = checkpoint['best_prec1'] 165 | msg = model.load_state_dict(checkpoint['state_dict'], strict=False) 166 | print("resume weights: ", msg) 167 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 168 | else: 169 | print("=> no checkpoint found at '{}'".format(args.resume)) 170 | 171 | model = model.cuda() 172 | 173 | if args.evaluate: 174 | prec1 = validate(val_loader, model, args, 0) 175 | print(prec1) 176 | exit(0) 177 | 178 | # Main training loop 179 | iter_max = args.epochs * len(train_loader) 180 | iter_stat = IterNums(iter_max) 181 | for epoch in range(args.start_epoch, args.epochs): 182 | print("<< ============== JOB (PID = %d) %s ============== >>"%(PID, args.save_dir)) 183 | logger.log("Epoch: %d"%(epoch+1)) 184 | train(train_loader, model, optimizer, base_lrs, iter_stat, epoch, logger, args, adjust_lr=epoch < args.epochs) 185 | 186 | # evaluate on validation set 187 | prec1 = validate(val_loader, model, args, epoch) 188 | 189 | # remember the best prec@1 seen so far and 190 | # save a checkpoint for this epoch 191 | is_best = prec1 > best_prec1 192 | best_prec1 = max(prec1, best_prec1) 193 | save_checkpoint(args.save_dir, { 194 | 'epoch': epoch + 1, 195 | 'state_dict': model.state_dict(), 196 | 'best_prec1': best_prec1, 197 | }, is_best, keep_last=1) 198 | 199 | logging.info('Best accuracy: {prec1:.3f}'.format(prec1=best_prec1)) 200 | 201 | 202 | def train(train_loader, model, optimizer, base_lrs, iter_stat, epoch, logger, args, adjust_lr=True): 203 | tb_interval = 50 204 | 205 | csg_weight = args.csg 206 | 207 | losses = AverageMeter() # loss on target task 208 | losses_csg = [AverageMeter() for _ in range(len(model.stages))] # [_loss] x #stages 209 | top1_csg = [AverageMeter() for _ in range(len(model.stages))] 210 | 211 | model.eval() 212 | 213 | # train for one epoch 214 | optimizer.zero_grad() 215 | epoch_size = len(train_loader) 216 | train_loader_iter = iter(train_loader) 217 | 218 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]' 219 | pbar = tqdm(range(epoch_size), file=sys.stdout, bar_format=bar_format, ncols=80) 220 | 221 | lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9) 222 | logger.writer.add_scalar("lr", lr, epoch) 223 | logger.log("lr %f"%lr) 224 | for idx_iter in pbar: 225 | optimizer.zero_grad() 226 | if adjust_lr: 227 | lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9) 228 | adjust_learning_rate(base_lrs, optimizer, iter_stat.iter_curr, iter_stat.iter_max, 0.9) 229 | 230 | input, label = next(train_loader_iter) 231 | if args.augment: 232 | input_q = input[0].cuda() 233 | input_k = input[1].cuda() 234 | else: 235 | input_q = input.cuda() 236 | input_k = None 237 | label = label.cuda() 238 | 239 | results = model(input_q, input_k) 240 | 241 | # synthetic task 242 | loss
= CrossEntropyLoss(results['output'], label.long()) 243 | # measure accuracy and record loss 244 | losses.update(loss, label.size(0)) 245 | for idx in range(len(model.stages)): 246 | _loss = 0 247 | acc1 = None 248 | # predictions: cosine b/w q and k 249 | # targets: zeros 250 | _loss = CrossEntropyLoss(results['predictions_csg'][idx], results['targets_csg'][idx]) 251 | acc1, acc5 = accuracy_ranking(results['predictions_csg'][idx].data, results['targets_csg'][idx], topk=(1, 5)) 252 | loss = loss + _loss * csg_weight 253 | # loss_csg[_type].append(_loss) 254 | if acc1 is not None: top1_csg[idx].update(acc1, label.size(0)) 255 | # measure accuracy and record loss 256 | losses_csg[idx].update(_loss, label.size(0)) 257 | 258 | loss.backward() 259 | 260 | # compute gradient and do SGD step 261 | optimizer.step() 262 | # increment iter number 263 | iter_stat.update() 264 | 265 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/ce", losses.val, idx_iter + epoch * epoch_size) 266 | description = "[XE %.3f]"%(losses.val) 267 | description += "[CSG " 268 | loss_str = "" 269 | acc_str = "" 270 | for idx, stage in enumerate(model.stages): 271 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/layer%d"%stage, losses_csg[idx].val, idx_iter + epoch * epoch_size) 272 | loss_str += "%.2f|"%losses_csg[idx].val 273 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("prec/layer%d"%stage, top1_csg[idx].val[0], idx_iter + epoch * epoch_size) 274 | acc_str += "%.1f|"%top1_csg[idx].val[0] 275 | description += "loss:%s ranking:%s]"%(loss_str[:-1], acc_str[:-1]) 276 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/total", losses.val + sum([_loss.val for _loss in losses_csg]), idx_iter + epoch * epoch_size) 277 | pbar.set_description("[Step %d/%d][%s]"%(idx_iter + 1, epoch_size, str(csg_weight)) + description) 278 | 279 | 280 | def validate(val_loader, model, args, epoch): 281 | """Perform validation on the validation set""" 282 | top1 = AverageMeter() 283 | 284 | # switch to evaluate mode 285 | model.eval() 286 | 287 | val_size = len(val_loader) 288 | val_loader_iter = iter(val_loader) 289 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]' 290 | pbar = tqdm(range(val_size), file=sys.stdout, bar_format=bar_format, ncols=140) 291 | with torch.no_grad(): 292 | for idx_iter in pbar: 293 | input, label = next(val_loader_iter) 294 | 295 | input = input.cuda() 296 | label = label.cuda() 297 | 298 | # compute output 299 | output, _ = model.encoder_q(input, task='new') 300 | output = torch.sigmoid(output) 301 | output = (output + torch.sigmoid(model.encoder_q(torch.flip(input, dims=(3,)), task='new')[0])) / 2 302 | 303 | # accumulate accuracyk 304 | prec1, gt_num = accuracy(output.data, label, args.num_classes, topk=(1,)) 305 | top1.update(prec1[0], gt_num[0]) 306 | 307 | description = "[Acc@1-mean: %.2f][Acc@1-cls: %s]"%(top1.vec2sca_avg, str(top1.avg.numpy().round(1))) 308 | pbar.set_description("[Step %d/%d]"%(idx_iter + 1, val_size) + description) 309 | 310 | logging.info(' * Prec@1 {top1.vec2sca_avg:.3f}'.format(top1=top1)) 311 | logging.info(' * Prec@1 {top1.avg}'.format(top1=top1)) 312 | 313 | return top1.vec2sca_avg 314 | 315 | 316 | def accuracy_ranking(output, target, topk=(1,)): 317 | """Computes the accuracy over the k top predictions for the specified values of k""" 318 | with torch.no_grad(): 319 | maxk = max(topk) 320 | batch_size = target.size(0) 321 | 322 | _, pred = output.topk(maxk, 1, True, True) 323 | pred = pred.t() 324 | correct = 
pred.eq(target.view(1, -1).expand_as(pred)) 325 | 326 | res = [] 327 | for k in topk: 328 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 329 | res.append(correct_k.mul_(100.0 / batch_size)) 330 | return res 331 | 332 | 333 | if __name__ == '__main__': 334 | main() 335 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | python train.py \ 5 | --epochs 30 \ 6 | --batch-size 32 \ 7 | --lr 1e-4 \ 8 | --rand_seed 0 \ 9 | --csg 0.1 \ 10 | --apool \ 11 | --augment \ 12 | --csg-stages 3.4 \ 13 | --factor 0.1 \ 14 | # --resume pretrained/csg_res101_vista17_best.pth.tar \ 15 | # --evaluate 16 | -------------------------------------------------------------------------------- /train_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | import argparse 5 | import os 6 | import sys 7 | import logging 8 | import time 9 | import numpy as np 10 | from tqdm import tqdm 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.optim 15 | from pdb import set_trace as bp 16 | from data.gta5 import GTA5 17 | from data.cityscapes import Cityscapes 18 | from model import csg_builder 19 | from model.deeplab import ResNet as deeplab 20 | from dataloader_seg import get_train_loader 21 | from eval_seg import SegEvaluator 22 | from utils.utils import get_params, IterNums, save_checkpoint, AverageMeter, lr_poly, adjust_learning_rate 23 | from utils.logger import prepare_logger, prepare_seed 24 | from utils.sgd import SGD 25 | 26 | torch.backends.cudnn.enabled = True 27 | CrossEntropyLoss = nn.CrossEntropyLoss(reduction='mean', ignore_index=255) 28 | KLDivLoss = nn.KLDivLoss(reduction='batchmean') 29 | best_mIoU = 0 30 | 31 | parser = argparse.ArgumentParser(description='PyTorch ResNet Training') 32 | parser.add_argument('--epochs', default=50, type=int, help='number of total epochs to run') 33 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 34 | parser.add_argument('--batch-size', default=6, type=int, dest='batch_size', help='mini-batch size (default: 6)') 35 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, help='initial learning rate') 36 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, help='weight decay (default: 5e-4)') 37 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)') 38 | parser.add_argument('--csg', default=75., type=float, dest='csg', help="weight of CSG loss (default: 75).")
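# e.g. "--csg-stages 3.4 --chunks 8" applies the CSG loss to ResNet stages 3
# and 4, splitting each stage's feature map into 8 chunks (cf. train_seg.sh)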
39 | parser.add_argument('--switch-model', default='deeplab50', choices=["deeplab50", "deeplab101"], help='which model to use') 40 | parser.add_argument('--factor', default=0.1, type=float, dest='factor', help='scale factor of backbone learning rate (default: 0.1)') 41 | parser.add_argument('--csg-stages', dest='csg_stages', default='4', help='resnet stages to involve in CSG, 0~4, separated by dot') 42 | parser.add_argument('--chunks', dest='chunks', default='8', help='stage-wise chunk to feature maps, separated by dot') 43 | parser.add_argument('--no-mlp', dest='mlp', action='store_false', default=True, help='not to use mlp during contrastive learning') 44 | parser.add_argument('--apool', default=False, action='store_true', help='use A-Pool') 45 | parser.add_argument('--augment', action='store_true', default=False, help='use augmentation') 46 | parser.add_argument('--resume', default='none', type=str, help='path to latest checkpoint (default: none)') 47 | parser.add_argument('--num-class', default=19, type=int, dest='num_classes', help='the number of classes') 48 | parser.add_argument('--gpus', default=0, type=int, help='gpu to use') 49 | parser.add_argument('--evaluate', action='store_true', help='only perform evaluation without training') 50 | parser.add_argument('--save_dir', type=str, default="./runs", help='root folder to save checkpoints and log.') 51 | parser.add_argument('--rand_seed', default=0, type=int, help='random seed') 52 | parser.add_argument('--csg-k', default=65536, type=int, help='queue size; number of negative keys (default: 65536)') 53 | parser.add_argument('--timestamp', type=str, default='none', help='timestamp for logging naming') 54 | parser.set_defaults(bottleneck=True) 55 | 56 | best_mIoU = 0 57 | 58 | 59 | def main(): 60 | global args, best_mIoU 61 | PID = os.getpid() 62 | args = parser.parse_args() 63 | prepare_seed(args.rand_seed) 64 | device = torch.device("cuda:"+str(args.gpus)) 65 | 66 | if args.timestamp == 'none': 67 | args.timestamp = "{:}".format(time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time()))) 68 | 69 | switch_model = args.switch_model 70 | assert switch_model in ["deeplab50", "deeplab101"] 71 | 72 | # Log outputs 73 | if args.evaluate: 74 | args.save_dir = args.save_dir + "/GTA5-%s-evaluate"%switch_model + \ 75 | "%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp) 76 | else: 77 | args.save_dir = args.save_dir + \ 78 | "/GTA5_512x512-{model}-LWF.stg{csg_stages}.w{csg_weight}-APool.{apool}-Aug.{augment}-chunk{chunks}-mlp{mlp}.K{csg_k}-LR{lr}.bone{factor}-epoch{epochs}-batch{batch_size}-seed{seed}".format( 79 | model=switch_model, 80 | csg_stages=args.csg_stages, 81 | mlp=args.mlp, 82 | csg_weight=args.csg, 83 | apool=args.apool, 84 | augment=args.augment, 85 | chunks=args.chunks, 86 | csg_k=args.csg_k, 87 | lr="%.2E"%args.lr, 88 | factor="%.1f"%args.factor, 89 | epochs=args.epochs, 90 | batch_size=args.batch_size, 91 | seed=args.rand_seed 92 | ) + \ 93 | "%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp) 94 | logger = prepare_logger(args) 95 | 96 | from config_seg import config as data_setting 97 | data_setting.batch_size = args.batch_size 98 | train_loader = get_train_loader(data_setting, GTA5, test=False, augment=args.augment) 99 | 100 | args.stages = [int(stage) for stage in args.csg_stages.split('.')] if len(args.csg_stages) > 0 else [] 101 | chunks = [int(chunk) for chunk in args.chunks.split('.')] if len(args.chunks) > 0 else []
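# a single chunk value is broadcast to every stage below, e.g. stages='3.4'
# with chunks='8' yields chunks == [8, 8]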
102 | assert len(chunks) == 1 or len(chunks) == len(args.stages) 103 | if len(chunks) < len(args.stages): 104 | chunks = [chunks[0]] * len(args.stages) 105 | 106 | if switch_model == 'deeplab50': 107 | layers = [3, 4, 6, 3] 108 | elif switch_model == 'deeplab101': 109 | layers = [3, 4, 23, 3] 110 | model = csg_builder.CSG(deeplab, get_head=None, K=args.csg_k, stages=args.stages, chunks=chunks, task='new-seg', 111 | apool=args.apool, mlp=args.mlp, 112 | base_encoder_kwargs={'num_seg_classes': args.num_classes, 'layers': layers}) 113 | 114 | threds = 3 115 | evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None), args.num_classes, np.array([0.485, 0.456, 0.406]), 116 | np.array([0.229, 0.224, 0.225]), model.encoder_q, [1, ], False, devices=args.gpus, config=data_setting, threds=threds, 117 | verbose=False, save_path=None, show_image=False) # just calculate mIoU, no prediction file is generated 118 | # verbose=False, save_path="./prediction_files", show_image=True, show_prediction=True) # generate prediction files 119 | 120 | 121 | # Setup optimizer 122 | factor = args.factor 123 | sgd_in = [ 124 | {'params': get_params(model.encoder_q, ["conv1"]), 'lr': factor*args.lr}, 125 | {'params': get_params(model.encoder_q, ["bn1"]), 'lr': factor*args.lr}, 126 | {'params': get_params(model.encoder_q, ["layer1"]), 'lr': factor*args.lr}, 127 | {'params': get_params(model.encoder_q, ["layer2"]), 'lr': factor*args.lr}, 128 | {'params': get_params(model.encoder_q, ["layer3"]), 'lr': factor*args.lr}, 129 | {'params': get_params(model.encoder_q, ["layer4"]), 'lr': factor*args.lr}, 130 | {'params': get_params(model.encoder_q, ["fc_new"]), 'lr': args.lr}, 131 | ] 132 | base_lrs = [ group['lr'] for group in sgd_in ] 133 | optimizer = SGD(sgd_in, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 134 | 135 | # Optionally resume from a checkpoint 136 | if args.resume != 'none': 137 | if os.path.isfile(args.resume): 138 | print("=> loading checkpoint '{}'".format(args.resume)) 139 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage) 140 | args.start_epoch = checkpoint['epoch'] 141 | best_mIoU = checkpoint['best_mIoU'] 142 | msg = model.load_state_dict(checkpoint['state_dict']) 143 | print("resume weights: ", msg) 144 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 145 | else: 146 | print("=> no checkpoint found at '{}'".format(args.resume)) 147 | 148 | model = model.to(device) 149 | 150 | if args.evaluate: 151 | mIoU = validate(evaluator, model, -1) 152 | print(mIoU) 153 | exit(0) 154 | 155 | # Main training loop 156 | iter_max = args.epochs * len(train_loader) 157 | iter_stat = IterNums(iter_max) 158 | for epoch in range(args.start_epoch, args.epochs): 159 | print("<< ============== JOB (PID = %d) %s ============== >>"%(PID, args.save_dir)) 160 | logger.log("Epoch: %d"%(epoch+1)) 161 | # train for one epoch 162 | train(args, train_loader, model, optimizer, base_lrs, iter_stat, epoch, logger, device, adjust_lr=epoch < args.epochs) 163 | 164 | # evaluate on validation set 165 | mIoU = validate(evaluator, model, epoch) 166 | logger.log("mIoU %f"%mIoU) 167 | logger.writer.add_scalar("mIoU", mIoU, epoch) 168 | 169 | # remember the best mIoU seen so far and 170 | # save a checkpoint for this epoch 171 | is_best = mIoU > best_mIoU 172 | best_mIoU = max(mIoU, best_mIoU) 173 | save_checkpoint(args.save_dir, { 174 | 'epoch': epoch + 1, 175 | 'state_dict': model.state_dict(), 176 | 'best_mIoU': best_mIoU, 177 | }, is_best) 178 | 179 | logging.info('Best accuracy: {mIoU:.3f}'.format(mIoU=best_mIoU)) 180 | 181 | 182 | def train(args, train_loader, model, optimizer, base_lrs, iter_stat, epoch, logger, device, adjust_lr=True): 183 | tb_interval = 50 184 | 185 | csg_weight = args.csg 186 | 187 | """Train for one epoch on
the training set""" 188 | losses = AverageMeter() 189 | losses_csg = [AverageMeter() for _ in range(len(model.stages))] # [_loss] x #stages 190 | top1_csg = [AverageMeter() for _ in range(len(model.stages))] 191 | 192 | model.eval() 193 | model.encoder_q.fc_new.train() 194 | 195 | # train for one epoch 196 | optimizer.zero_grad() 197 | epoch_size = len(train_loader) 198 | train_loader_iter = iter(train_loader) 199 | 200 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]' 201 | pbar = tqdm(range(epoch_size), file=sys.stdout, bar_format=bar_format, ncols=80) 202 | 203 | lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9) 204 | logger.log("lr %f"%lr) 205 | for idx_iter in pbar: 206 | 207 | optimizer.zero_grad() 208 | if adjust_lr: 209 | lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9) 210 | adjust_learning_rate(base_lrs, optimizer, iter_stat.iter_curr, iter_stat.iter_max, 0.9) 211 | 212 | sample = next(train_loader_iter) 213 | label = sample['label'].to(device) 214 | input = sample['data'] 215 | if args.augment: 216 | input_q = input.to(device) 217 | input_k = sample['img_k'].to(device) 218 | else: 219 | input_q = input.to(device) 220 | input_k = None 221 | 222 | # keys: output, predictions_csg, targets_csg 223 | results = model(input_q, input_k) 224 | 225 | # synthetic task 226 | loss = CrossEntropyLoss(results['output'], label.long()) 227 | # measure accuracy and record loss 228 | losses.update(loss, label.size(0)) 229 | for idx in range(len(model.stages)): 230 | _loss = 0 231 | acc1 = None 232 | # predictions: cosine b/w q and k 233 | # targets: zeros 234 | _loss = CrossEntropyLoss(results['predictions_csg'][idx], results['targets_csg'][idx]) 235 | acc1, acc5 = accuracy_ranking(results['predictions_csg'][idx].data, results['targets_csg'][idx], topk=(1, 5)) 236 | loss = loss + _loss * csg_weight 237 | if acc1 is not None: top1_csg[idx].update(acc1, label.size(0)) 238 | # measure accuracy and record loss 239 | losses_csg[idx].update(_loss, label.size(0)) 240 | 241 | loss.backward() 242 | 243 | # compute gradient and do SGD step 244 | optimizer.step() 245 | # increment iter number 246 | iter_stat.update() 247 | 248 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/ce", losses.val, idx_iter + epoch * epoch_size) 249 | description = "[XE %.3f]"%(losses.val) 250 | description += "[CSG " 251 | loss_str = "" 252 | acc_str = "" 253 | for idx, stage in enumerate(model.stages): 254 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/layer%d"%stage, losses_csg[idx].val, idx_iter + epoch * epoch_size) 255 | loss_str += "%.2f|"%losses_csg[idx].val 256 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("prec/layer%d"%stage, top1_csg[idx].val[0], idx_iter + epoch * epoch_size) 257 | acc_str += "%.1f|"%top1_csg[idx].val[0] 258 | description += "loss:%s ranking:%s]"%(loss_str[:-1], acc_str[:-1]) 259 | if idx_iter % tb_interval == 0: logger.writer.add_scalar("loss/total", losses.val + sum([_loss.val for _loss in losses_csg]), idx_iter + epoch * epoch_size) 260 | pbar.set_description("[Step %d/%d][%s]"%(idx_iter + 1, epoch_size, str(csg_weight)) + description) 261 | 262 | 263 | def validate(evaluator, model, epoch): 264 | with torch.no_grad(): 265 | model.eval() 266 | # _, mIoU = evaluator.run_online() 267 | _, mIoU = evaluator.run_online_multiprocess() 268 | return mIoU 269 | 270 | 271 | def accuracy_ranking(output, target, topk=(1,)): 272 | """Computes the accuracy over the k top predictions for the specified values of 
k""" 273 | with torch.no_grad(): 274 | maxk = max(topk) 275 | batch_size = target.size(0) 276 | 277 | _, pred = output.topk(maxk, 1, True, True) 278 | pred = pred.t() 279 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 280 | 281 | res = [] 282 | for k in topk: 283 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 284 | res.append(correct_k.mul_(100.0 / batch_size)) 285 | return res 286 | 287 | 288 | if __name__ == '__main__': 289 | main() 290 | -------------------------------------------------------------------------------- /train_seg.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 3 | 4 | python train_seg.py \ 5 | --epochs 50 \ 6 | --switch-model deeplab101 \ 7 | --batch-size 6 \ 8 | --lr 1e-3 \ 9 | --num-class 19 \ 10 | --gpus 0 \ 11 | --factor 0.1 \ 12 | --csg 75 \ 13 | --apool \ 14 | --csg-stages 3.4 \ 15 | --chunks 8 \ 16 | --augment \ 17 | --evaluate \ 18 | --resume pretrained/csg_res101_segmentation_best.pth.tar \ 19 | # --resume pretrained/csg_res50_segmentation_best.pth.tar \ 20 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | -------------------------------------------------------------------------------- /utils/augmentations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.)) 3 | 4 | import random 5 | 6 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | 11 | 12 | def ShearX(img, v): # [-0.3, 0.3] 13 | assert -0.3 <= v <= 0.3 14 | if random.random() > 0.5: 15 | v = -v 16 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 17 | 18 | 19 | def ShearY(img, v): # [-0.3, 0.3] 20 | assert -0.3 <= v <= 0.3 21 | if random.random() > 0.5: 22 | v = -v 23 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 24 | 25 | 26 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 27 | assert -0.45 <= v <= 0.45 28 | if random.random() > 0.5: 29 | v = -v 30 | v = v * img.size[0] 31 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 32 | 33 | 34 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 35 | assert 0 <= v 36 | if random.random() > 0.5: 37 | v = -v 38 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 39 | 40 | 41 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 42 | assert -0.45 <= v <= 0.45 43 | if random.random() > 0.5: 44 | v = -v 45 | v = v * img.size[1] 46 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 47 | 48 | 49 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 50 | assert 0 <= v 51 | if random.random() > 0.5: 52 | v = -v 53 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 54 | 55 | 56 | def Rotate(img, v): # [-30, 30] 57 | assert -30 <= v <= 30 58 | if random.random() > 0.5: 59 | v = -v 60 | return img.rotate(v) 61 | 62 | 63 | def AutoContrast(img, v): 64 | if random.random() <= v: 65 | return PIL.ImageOps.autocontrast(img) 66 | else: 67 | return img 68 | 69 | 70 | def Invert(img, v): 71 | if random.random() <= v: 72 | return PIL.ImageOps.invert(img) 73 | else: 74 | return img 75 | 76 | 77 | def Equalize(img, v): 78 | if random.random() <= v: 79 | return PIL.ImageOps.equalize(img) 80 | else: 81 | return img 82 | 83 | 84 | def Flip(img, _): # not from the paper 85 | return PIL.ImageOps.mirror(img) 86 | 87 | 88 | def Solarize(img, v): # [0, 256] 89 | assert 0 <= v <= 256 90 | v = 256 - v 91 | return PIL.ImageOps.solarize(img, v) 92 | 93 | 94 | def SolarizeAdd(img, addition=0, threshold=128): 95 | img_np = np.array(img).astype(np.int) 96 | img_np = img_np + addition 97 | img_np = np.clip(img_np, 0, 255) 98 | img_np = img_np.astype(np.uint8) 99 | img = Image.fromarray(img_np) 100 | return PIL.ImageOps.solarize(img, threshold) 101 | 102 | 103 | def Posterize(img, v): # [4, 8] 104 | assert 0 <= v <= 7 105 | # v = int(v) 106 | v = 8 - int(v) 107 | v = max(1, v) 108 | return PIL.ImageOps.posterize(img, v) 109 | 110 | 111 | def Contrast(img, v): # [0.,0.9] 112 | # A factor of 1.0 gives the original image. 113 | assert 0. <= v <= 0.9 114 | if random.random() > 0.5: 115 | v = -v 116 | return PIL.ImageEnhance.Contrast(img).enhance(v+1) # 0.1 to 1.9 117 | 118 | 119 | def Color(img, v): # [0.,0.9] 120 | # A factor of 1.0 gives the original image. 121 | assert 0. <= v <= 0.9 122 | if random.random() > 0.5: 123 | v = -v 124 | return PIL.ImageEnhance.Color(img).enhance(v+1) # 0.1 to 1.9 125 | 126 | 127 | def Brightness(img, v): # [0.,0.9] 128 | # A factor of 1.0 gives the original image. 129 | assert 0. 
<= v <= 0.9 130 | if random.random() > 0.5: 131 | v = -v 132 | return PIL.ImageEnhance.Brightness(img).enhance(v+1) # 0.1 to 1.9 133 | 134 | 135 | def Sharpness(img, v): # [0.,.9] 136 | assert 0. <= v <= 0.9 137 | if random.random() > 0.5: 138 | v = -v 139 | return PIL.ImageEnhance.Sharpness(img).enhance(v+1) # 0.1 to 1.9 140 | 141 | 142 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 143 | assert 0.0 <= v <= 0.2 144 | if v <= 0.: 145 | return img 146 | 147 | v = v * img.size[0] 148 | return CutoutAbs(img, v) 149 | 150 | 151 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 152 | # assert 0 <= v <= 20 153 | if v < 0: 154 | return img 155 | w, h = img.size 156 | x0 = np.random.uniform(w) 157 | y0 = np.random.uniform(h) 158 | 159 | x0 = int(max(0, x0 - v / 2.)) 160 | y0 = int(max(0, y0 - v / 2.)) 161 | x1 = min(w, x0 + v) 162 | y1 = min(h, y0 + v) 163 | 164 | xy = (x0, y0, x1, y1) 165 | color = (125, 123, 114) 166 | # color = (0, 0, 0) 167 | img = img.copy() 168 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 169 | return img 170 | 171 | 172 | def SamplePairing(imgs): # [0, 0.4] 173 | def f(img1, v): 174 | i = np.random.choice(len(imgs)) 175 | img2 = PIL.Image.fromarray(imgs[i]) 176 | return PIL.Image.blend(img1, img2, v) 177 | 178 | return f 179 | 180 | 181 | def Identity(img, v): 182 | return img 183 | 184 | 185 | augment_list = [ 186 | (Identity, 0., 1.0), 187 | (AutoContrast, 0, 1), 188 | (Equalize, 0, 1), 189 | (Rotate, 0, 30), 190 | (Posterize, 0, 7), 191 | (Solarize, 0, 256), 192 | (Color, 0., 0.9), 193 | (Contrast, 0., 0.9), 194 | (Brightness, 0., 0.9), 195 | (Sharpness, 0., 0.9), 196 | (ShearX, 0., 0.3), 197 | (ShearY, 0., 0.3), 198 | (TranslateXabs, 0., 100), 199 | (TranslateYabs, 0., 100), 200 | ] 201 | 202 | 203 | class Lighting(object): 204 | """Lighting noise(AlexNet - style PCA - based noise)""" 205 | 206 | def __init__(self, alphastd, eigval, eigvec): 207 | self.alphastd = alphastd 208 | self.eigval = torch.Tensor(eigval) 209 | self.eigvec = torch.Tensor(eigvec) 210 | 211 | def __call__(self, img): 212 | if self.alphastd == 0: 213 | return img 214 | 215 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 216 | rgb = self.eigvec.type_as(img).clone() \ 217 | .mul(alpha.view(1, 3).expand(3, 3)) \ 218 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 219 | .sum(1).squeeze() 220 | 221 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 222 | 223 | 224 | class CutoutDefault(object): 225 | def __init__(self, length): 226 | self.length = length 227 | 228 | def __call__(self, img): 229 | h, w = img.size(1), img.size(2) 230 | mask = np.ones((h, w), np.float32) 231 | y = np.random.randint(h) 232 | x = np.random.randint(w) 233 | 234 | y1 = np.clip(y - self.length // 2, 0, h) 235 | y2 = np.clip(y + self.length // 2, 0, h) 236 | x1 = np.clip(x - self.length // 2, 0, w) 237 | x2 = np.clip(x + self.length // 2, 0, w) 238 | 239 | mask[y1: y2, x1: x2] = 0. 
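# the zeroed square is centered at (y, x) and clipped at the image borders,
# so a cutout near an edge removes fewer than length*length pixels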
240 | mask = torch.from_numpy(mask) 241 | mask = mask.expand_as(img) 242 | img *= mask 243 | return img 244 | 245 | 246 | class RandAugment: 247 | def __init__(self, n, m, augment_list): 248 | self.n = n 249 | self.m = m # [0, 30] 250 | assert 0 <= m <= 30 251 | self.augment_list = augment_list 252 | 253 | def __call__(self, img): 254 | ops = random.choices(self.augment_list, k=self.n) 255 | for op, minval, maxval in ops: 256 | val = (float(self.m) / 30) * float(maxval - minval) + minval 257 | img = op(img, val) 258 | 259 | return img 260 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | from pathlib import Path 5 | import importlib, warnings 6 | import os, sys, numpy as np 7 | import torch, random, PIL, copy 8 | import glob 9 | import shutil 10 | if sys.version_info.major == 2: # Python 2.x 11 | from StringIO import StringIO as BIO 12 | else: # Python 3.x 13 | from io import BytesIO as BIO 14 | from torch.utils.tensorboard import SummaryWriter 15 | if importlib.util.find_spec('tensorflow'): 16 | import tensorflow as tf 17 | 18 | 19 | def prepare_seed(rand_seed): 20 | random.seed(rand_seed) 21 | np.random.seed(rand_seed) 22 | torch.manual_seed(rand_seed) 23 | torch.cuda.manual_seed(rand_seed) 24 | torch.cuda.manual_seed_all(rand_seed) 25 | 26 | 27 | def prepare_logger(xargs): 28 | args = copy.deepcopy( xargs ) 29 | logger = Logger(args.save_dir, args.rand_seed) 30 | logger.log('Main Function with logger : {:}'.format(logger)) 31 | logger.log('Arguments : -------------------------------') 32 | for name, value in args._get_kwargs(): 33 | logger.log('{:16} : {:}'.format(name, value)) 34 | logger.log("Python Version : {:}".format(sys.version.replace('\n', ' '))) 35 | logger.log("Pillow Version : {:}".format(PIL.__version__)) 36 | logger.log("PyTorch Version : {:}".format(torch.__version__)) 37 | logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version())) 38 | logger.log("CUDA available : {:}".format(torch.cuda.is_available())) 39 | logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count())) 40 | logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None')) 41 | return logger 42 | 43 | 44 | class PrintLogger(object): 45 | 46 | def __init__(self): 47 | """Create a summary writer logging to log_dir.""" 48 | self.name = 'PrintLogger' 49 | 50 | def log(self, string): 51 | print (string) 52 | 53 | def close(self): 54 | print ('-'*30 + ' close printer ' + '-'*30) 55 | 56 | 57 | class Logger(object): 58 | 59 | def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False): 60 | """Create a summary writer logging to log_dir.""" 61 | self.seed = int(seed) 62 | self.log_dir = Path(log_dir) 63 | self.model_dir = Path(log_dir) / 'model' 64 | self.log_dir.mkdir (parents=True, exist_ok=True) 65 | 66 | self.use_tf = bool(use_tf) 67 | self.tensorboard_dir = self.log_dir 68 | self.logger_path = self.log_dir / 'seed-{:}.log'.format(self.seed) 69 | self.logger_file = open(self.logger_path, 'w') 70 | 71 | self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True) 72 | self.writer = SummaryWriter(str(self.tensorboard_dir)) 73 | 74 | scripts_to_save=glob.glob('*.py')+glob.glob('*.sh') 75 | os.mkdir(os.path.join(log_dir, 
'scripts')) 76 | for script in scripts_to_save: 77 | dst_file = os.path.join(log_dir, 'scripts', os.path.basename(script)) 78 | shutil.copyfile(script, dst_file) 79 | shutil.make_archive(os.path.join(log_dir, "scripts"), 'zip', log_dir, "scripts") 80 | shutil.rmtree(os.path.join(log_dir, 'scripts')) 81 | 82 | def __repr__(self): 83 | return ('{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__)) 84 | 85 | def path(self, mode): 86 | valids = ('model', 'best', 'info', 'log') 87 | if mode == 'model': return self.model_dir / 'seed-{:}-basic.pth'.format(self.seed) 88 | elif mode == 'best' : return self.model_dir / 'seed-{:}-best.pth'.format(self.seed) 89 | elif mode == 'info' : return self.log_dir / 'seed-{:}-last-info.pth'.format(self.seed) 90 | elif mode == 'log' : return self.log_dir 91 | else: raise TypeError('Unknow mode = {:}, valid modes = {:}'.format(mode, valids)) 92 | 93 | def extract_log(self): 94 | return self.logger_file 95 | 96 | def close(self): 97 | self.logger_file.close() 98 | if self.writer is not None: 99 | self.writer.close() 100 | 101 | def log(self, string, save=True, stdout=False): 102 | if stdout: 103 | sys.stdout.write(string); sys.stdout.flush() 104 | else: 105 | print (string) 106 | if save: 107 | self.logger_file.write('{:}\n'.format(string)) 108 | self.logger_file.flush() 109 | 110 | def scalar_summary(self, tags, values, step): 111 | """Log a scalar variable.""" 112 | if not self.use_tf: 113 | warnings.warn('Do set use-tensorflow installed but call scalar_summary') 114 | else: 115 | assert isinstance(tags, list) == isinstance(values, list), 'Type : {:} vs {:}'.format(type(tags), type(values)) 116 | if not isinstance(tags, list): 117 | tags, values = [tags], [values] 118 | for tag, value in zip(tags, values): 119 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 120 | self.writer.add_summary(summary, step) 121 | self.writer.flush() 122 | 123 | def image_summary(self, tag, images, step): 124 | """Log a list of images.""" 125 | import scipy 126 | if not self.use_tf: 127 | warnings.warn('Do set use-tensorflow installed but call scalar_summary') 128 | return 129 | 130 | img_summaries = [] 131 | for i, img in enumerate(images): 132 | # Write the image to a string 133 | try: 134 | s = StringIO() 135 | except: 136 | s = BytesIO() 137 | scipy.misc.toimage(img).save(s, format="png") 138 | 139 | # Create an Image object 140 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 141 | height=img.shape[0], 142 | width=img.shape[1]) 143 | # Create a Summary value 144 | img_summaries.append(tf.Summary.Value(tag='{}/{}'.format(tag, i), image=img_sum)) 145 | 146 | # Create and write Summary 147 | summary = tf.Summary(value=img_summaries) 148 | self.writer.add_summary(summary, step) 149 | self.writer.flush() 150 | 151 | def histo_summary(self, tag, values, step, bins=1000): 152 | """Log a histogram of the tensor of values.""" 153 | if not self.use_tf: raise ValueError('Do not have tensorflow') 154 | import tensorflow as tf 155 | 156 | # Create a histogram using numpy 157 | counts, bin_edges = np.histogram(values, bins=bins) 158 | 159 | # Fill the fields of the histogram proto 160 | hist = tf.HistogramProto() 161 | hist.min = float(np.min(values)) 162 | hist.max = float(np.max(values)) 163 | hist.num = int(np.prod(values.shape)) 164 | hist.sum = float(np.sum(values)) 165 | hist.sum_squares = float(np.sum(values**2)) 166 | 167 | # Drop the start of the first bin 168 | bin_edges = 
--------------------------------------------------------------------------------
/utils/sgd.py:
--------------------------------------------------------------------------------
import torch
from torch.optim.optimizer import Optimizer, required


# fixed SGD
# See Note here: https://pytorch.org/docs/stable/optim.html#torch.optim.SGD
class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov in stock PyTorch subtly
        differs from Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the stock update can be written as

        .. math::
            v = \rho * v + g \\
            p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            v = \rho * v + lr * g \\
            p = p - v

        The Nesterov version is analogously modified. This subclass implements
        the second (Sutskever-style) form: the learning rate is folded into the
        momentum buffer in step() below, which is the "fix" referenced above.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # unlike stock SGD, the buffer accumulates lr-scaled
                        # gradients: v = momentum * v + lr * g
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach().mul_(group['lr'])
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p.mul_(group['lr']), alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                # the learning rate is already folded into d_p: p = p - v
                p.data.add_(d_p, alpha=-1)

        return loss
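
# One-parameter sanity check of the update above (v = momentum*v + lr*g,
# p = p - v), with lr=0.1, momentum=0.9 and a constant gradient of 1:
# step 1: v = 0.1*1 = 0.1; step 2: v = 0.9*0.1 + 0.1*1 = 0.19, so w = -0.29.
# Stock torch.optim.SGD reaches the same point at a constant lr, but diverges
# from this variant once the lr schedule changes, since it rescales the whole
# momentum buffer by the current lr at every step.
if __name__ == '__main__':
    w = torch.nn.Parameter(torch.zeros(1))
    opt = SGD([w], lr=0.1, momentum=0.9)
    for _ in range(2):
        opt.zero_grad()
        w.sum().backward()   # dL/dw = 1
        opt.step()
    print(w.item())          # ~ -0.29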
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
# Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
# This work is licensed under a NVIDIA Open Source Non-commercial license.

import glob
import os
import shutil
import numpy as np
import torch
from pdb import set_trace as bp
from PIL import Image
import matplotlib.cm as mpl_color_map
import copy


def get_params(model, layers=["layer4"]):
    """
    This generator yields all trainable parameters of the requested layers.
    Note that requires_grad is set to False for each batchnorm layer in
    deeplab_resnet.py, so this function does not yield any batchnorm parameter.
    """
    if isinstance(layers, str):
        layers = [layers]
    b = []
    for layer in layers:
        b.append(getattr(model, layer))

    for i in range(len(b)):
        for k, v in b[i].named_parameters():
            if v.requires_grad:
                yield v


def adjust_learning_rate_exp(optimizer, power=0.746):
    """Multiplies the current learning rate of every param group by `power` (exponential decay)."""
    num_groups = len(optimizer.param_groups)
    for g in range(num_groups):
        optimizer.param_groups[g]['lr'] *= power


def adjust_learning_rate(base_lrs, optimizer, iter_curr, iter_max, power):
    """Sets each param group's learning rate following the warm-up + polynomial decay schedule in lr_poly."""
    num_groups = len(optimizer.param_groups)
    for g in range(num_groups):
        optimizer.param_groups[g]['lr'] = lr_poly(base_lrs[g], iter_curr, iter_max, power)


def lr_poly(base_lr, iter, max_iter, power):
    # quadratic warm-up over the first 100 iterations, then polynomial decay
    return min(0.01 + 0.99*(float(iter)/100)**2.0, 1.0) * base_lr * ((1 - float(iter)/max_iter)**power)
    # return min(0.01+0.99*(float(iter)/100)**2.0, 1.0) * base_lr * ((1-min(float(iter)/max_iter, 0.8))**power)  # warm-up & never smaller than the last 20% LR
    # return base_lr * ((1-float(iter)/max_iter)**power)  # plain poly decay
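
# Illustrative numeric check of the schedule above (hypothetical values:
# base_lr=1.0, max_iter=1000, power=0.9); not used by the training code.
def _demo_lr_poly():
    for it in (10, 100, 900):
        print(it, round(lr_poly(1.0, it, 1000, 0.9), 4))
    # 10  -> 0.0197  (warm-up factor 0.0199 * poly factor 0.991)
    # 100 -> 0.9095  (warm-up factor has reached 1.0)
    # 900 -> 0.1259  (poly tail: 0.1 ** 0.9)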

def save_checkpoint(name, state, is_best, filename='checkpoint.pth.tar', keep_last=1):
    """Saves checkpoint to disk, keeping at most `keep_last` epoch snapshots."""
    directory = name
    if not os.path.exists(directory):
        os.makedirs(directory)
    models_paths = list(filter(os.path.isfile, glob.glob(directory + "/epoch*.pth.tar")))
    models_paths.sort(key=os.path.getmtime, reverse=False)
    if len(models_paths) >= keep_last:
        # prune the oldest snapshots so at most keep_last remain after saving
        for i in range(len(models_paths) + 1 - keep_last):
            os.remove(models_paths[i])
    torch.save(state, directory + '/epoch_' + str(state['epoch']) + '_' + filename)
    filename = directory + '/latest_' + filename
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, '%s/' % (name) + 'model_best.pth.tar')


class IterNums(object):
    def __init__(self, iter_max):
        self.iter_max = iter_max
        self.iter_curr = 0

    def reset(self):
        self.iter_curr = 0

    def update(self):
        self.iter_curr += 1


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.vec2sca_avg = 0
        self.vec2sca_val = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        if torch.is_tensor(self.val) and torch.numel(self.val) != 1:
            self.avg[self.count == 0] = 0
            self.vec2sca_avg = self.avg.sum() / len(self.avg)
            self.vec2sca_val = self.val.sum() / len(self.val)


def accuracy(output, label, num_class, topk=(1,)):
    """Computes the per-class accuracy@k for the specified values of k; currently only k=1 is supported"""
    maxk = max(topk)

    _, pred = output.topk(maxk, 1, True, True)
    if len(label.size()) == 2:
        # one_hot label
        _, gt = label.topk(maxk, 1, True, True)
    else:
        gt = label
    pred = pred.t()
    pred_class_idx_list = [pred == class_idx for class_idx in range(num_class)]
    gt = gt.t()
    gt_class_number_list = [(gt == class_idx).sum() for class_idx in range(num_class)]
    correct = pred.eq(gt)

    res = []
    gt_num = []
    for k in topk:
        correct_k = correct[:k].float()
        per_class_correct_list = [correct_k[pred_class_idx].sum(0) for pred_class_idx in pred_class_idx_list]
        per_class_correct_array = torch.tensor(per_class_correct_list)
        gt_class_number_tensor = torch.tensor(gt_class_number_list).float()
        gt_class_zeronumber_tensor = gt_class_number_tensor == 0
        gt_class_number_matrix = torch.tensor(gt_class_number_list).float()
        gt_class_acc = per_class_correct_array.mul_(100.0 / gt_class_number_matrix)
        gt_class_acc[gt_class_zeronumber_tensor] = 0
        res.append(gt_class_acc)
        gt_num.append(gt_class_number_matrix)
    return res, gt_num
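
# Illustrative call of accuracy() with made-up logits/labels; the returned
# lists hold one per-class tensor per requested k. Not used by the training code.
def _demo_accuracy():
    logits = torch.tensor([[2., 0., 0., 0.],
                           [0., 3., 0., 0.],
                           [1., 0., 0., 0.]])
    labels = torch.tensor([0, 1, 2])
    (per_class_acc,), (per_class_cnt,) = accuracy(logits, labels, num_class=4)
    # per_class_acc -> tensor([100., 100., 0., 0.])  (the class-2 sample was
    #                  predicted as class 0; class 3 has no samples, so zeroed)
    # per_class_cnt -> tensor([1., 1., 1., 0.])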

def apply_colormap_on_image(org_im, activation, colormap_name='hsv'):
    """
    Apply heatmap on image
    Args:
        org_im (PIL img): Original image
        activation (numpy arr): Activation map (grayscale) 0-255
        colormap_name (str): Name of the colormap
    """
    # Get colormap
    color_map = mpl_color_map.get_cmap(colormap_name)
    no_trans_heatmap = color_map(activation)
    # Change alpha channel in colormap to make sure original image is displayed
    heatmap = copy.copy(no_trans_heatmap)
    heatmap[:, :, 3] = 0.5
    heatmap = Image.fromarray((heatmap*255).astype(np.uint8))
    no_trans_heatmap = Image.fromarray((no_trans_heatmap*255).astype(np.uint8))

    # Apply heatmap on image
    heatmap_on_image = Image.new("RGBA", org_im.size)
    heatmap_on_image = Image.alpha_composite(heatmap_on_image, org_im.convert('RGBA'))
    heatmap_on_image = Image.alpha_composite(heatmap_on_image, heatmap)
    return no_trans_heatmap, heatmap_on_image


class UnNormalize(object):
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Normalized tensor image of size (C, H, W).
        Returns:
            Tensor: Un-normalized image.
        """
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
            # inverse of the normalize op: t.sub_(m).div_(s)
        return tensor


class AvgrageMeter(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


class Cutout(object):
    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1: y2, x1: x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img


def count_parameters_in_MB(model):
    return sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6


def save(model, model_path):
    torch.save(model.state_dict(), model_path)


def load(model, model_path):
    model.load_state_dict(torch.load(model_path))


def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.makedirs(path)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        os.mkdir(os.path.join(path, 'scripts'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)
--------------------------------------------------------------------------------