├── .gitignore ├── LICENSE.md ├── README.md ├── config_seg.py ├── data ├── BaseDataset.py ├── cityscapes.py ├── cityscapes │ ├── cityscapes_test.txt │ ├── cityscapes_val_fine.txt │ └── cityscapes_val_fine_raw.txt ├── gta5.py ├── gta5 │ ├── gta5_train.txt │ └── gta5_train_raw.txt └── visda17.py ├── dataloader_seg.py ├── eval_seg.py ├── l2o_train.py ├── l2o_train.sh ├── l2o_train_seg.py ├── l2o_train_seg.sh ├── model ├── __init__.py ├── fcn8s_vgg.py ├── resnet.py └── vgg.py ├── reinforce ├── __init__.py ├── algo │ ├── __init__.py │ └── reinforce.py ├── arguments.py ├── distributions.py ├── models │ ├── policy.py │ └── rnn_state_encoder.py ├── storage.py └── utils.py ├── tools ├── datasets │ ├── BaseDataset.py │ └── cityscapes │ │ ├── cityscapes.py │ │ ├── cityscapes_test.txt │ │ ├── cityscapes_train_fine.txt │ │ └── cityscapes_val_fine.txt ├── engine │ ├── evaluator.py │ ├── logger.py │ └── tester.py ├── seg_opr │ └── metric.py └── utils │ ├── img_utils.py │ ├── pyt_utils.py │ └── visualize.py ├── train.py ├── train.sh ├── train_seg.py ├── train_seg.sh └── utils ├── __init__.py ├── logger.py ├── sgd.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # vim swp files 2 | *.swp 3 | # caffe/pytorch model files 4 | runs/* 5 | pretrained/* 6 | crst_visda/runs/* 7 | *.pth 8 | *.tar 9 | *_softmax.txt 10 | *_seed*.txt 11 | *_soft*.txt 12 | *.json 13 | 14 | # Mkdocs 15 | # /docs/ 16 | /mkdocs/docs/temp 17 | 18 | .DS_Store 19 | .idea 20 | .vscode 21 | .pytest_cache 22 | /experiments 23 | node_modules/ 24 | history/ 25 | ablation/ 26 | misc/ 27 | prediction/ 28 | results/ 29 | 30 | # resource temp folder 31 | tests/resources/temp/* 32 | !tests/resources/temp/.gitkeep 33 | 34 | # Byte-compiled / optimized / DLL files 35 | __pycache__/ 36 | *.py[cod] 37 | *$py.class 38 | 39 | # C extensions 40 | *.so 41 | 42 | # Distribution / packaging 43 | .Python 44 | build/ 45 | develop-eggs/ 46 | dist/ 47 | downloads/ 48 | eggs/ 49 | .eggs/ 50 | lib/ 51 | lib64/ 52 | parts/ 53 | sdist/ 54 | var/ 55 | wheels/ 56 | *.egg-info/ 57 | .installed.cfg 58 | *.egg 59 | MANIFEST 60 | 61 | # PyInstaller 62 | # Usually these files are written by a python script from a template 63 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
64 | *.manifest 65 | *.spec 66 | 67 | # Installer logs 68 | pip-log.txt 69 | pip-delete-this-directory.txt 70 | 71 | # Unit test / coverage reports 72 | htmlcov/ 73 | .tox/ 74 | .coverage 75 | .coverage.* 76 | .cache 77 | nosetests.xml 78 | coverage.xml 79 | *.cover 80 | .hypothesis/ 81 | .pytest_cache/ 82 | 83 | # Translations 84 | *.mo 85 | *.pot 86 | 87 | # Django stuff: 88 | *.log 89 | .static_storage/ 90 | .media/ 91 | local_settings.py 92 | local_settings.py 93 | db.sqlite3 94 | 95 | # Flask stuff: 96 | instance/ 97 | .webassets-cache 98 | 99 | # Scrapy stuff: 100 | .scrapy 101 | 102 | # Sphinx documentation 103 | docs/_build/ 104 | 105 | # PyBuilder 106 | target/ 107 | 108 | # Jupyter Notebook 109 | .ipynb_checkpoints 110 | 111 | # pyenv 112 | .python-version 113 | 114 | # celery beat schedule file 115 | celerybeat-schedule 116 | 117 | # SageMath parsed files 118 | *.sage.py 119 | 120 | # Environments 121 | .env 122 | .venv 123 | env/ 124 | venv/ 125 | ENV/ 126 | env.bak/ 127 | venv.bak/ 128 | 129 | # Spyder project settings 130 | .spyderproject 131 | .spyproject 132 | 133 | # Rope project settings 134 | .ropeproject 135 | 136 | # mkdocs documentation 137 | /site 138 | 139 | # mypy 140 | .mypy_cache/ 141 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | NVIDIA Source Code 2 | 3 | License for Automated Synthetic-to-real Generalization (ASG) 4 | --- 5 | 6 | 1. Definitions 7 | 8 | "Licensor" means any person or entity that distributes its Work. 9 | 10 | "Software" means the original work of authorship made available under this License. 11 | 12 | "Work" means the Software and any additions to or derivative works of the Software that are made available under this License. 13 | 14 | The terms "reproduce," "reproduction," "derivative works," and "distribution" have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 15 | 16 | Works, including the Software, are "made available" under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 17 | 18 | 2. License Grant 19 | 20 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 21 | 22 | 3. Limitations 23 | 24 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 25 | 26 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work ("Your Terms") only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. 
Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 27 | 28 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, "non-commercially" means for research or evaluation purposes only. 29 | 30 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 31 | 32 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License. 33 | 34 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately. 35 | 36 | 4. Disclaimer of Warranty. 37 | 38 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 39 | 40 | 5. Limitation of Liability. 41 | 42 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Python 3.7](https://img.shields.io/badge/python-3.7-green.svg) 2 | 3 | # ASG: Automated Synthetic-to-Real Generalization 4 | 5 | 6 | [Paper](https://arxiv.org/abs/2007.06965) 7 | 8 | Automated Synthetic-to-Real Generalization.
9 | [Wuyang Chen](https://chenwydj.github.io/), [Zhiding Yu](https://chrisding.github.io/), [Zhangyang Wang](https://www.atlaswang.com/), [Anima Anandkumar](http://tensorlab.cms.caltech.edu/users/anima/).
10 | In ICML 2020. 11 | 12 | * Visda-17 to COCO 13 | - [x] train resnet101 with only proxy guidance 14 | - [x] train resnet101 with both proxy guidance and L2O policy 15 | - [x] evaluation 16 | * GTA5 to Cityscapes 17 | - [x] train vgg16 with only proxy guidance 18 | - [x] train vgg16 with both proxy guidance and L2O policy 19 | - [x] evaluation 20 | 21 | ## Usage 22 | 23 | ### Visda-17 24 | * Download [Visda-17 Dataset](http://ai.bu.edu/visda-2017/#download) 25 | 26 | #### Evaluation 27 | * Download [pretrained ResNet101 on Visda17](https://drive.google.com/file/d/1jjihDIxU1HIRtJEZyd7eTpYfO21OrY36/view?usp=sharing) 28 | * Put the checkpoint under `./ASG/pretrained/` 29 | * Put the code below in `train.sh` 30 | ```bash 31 | python train.py \ 32 | --epochs 30 \ 33 | --batch-size 32 \ 34 | --lr 1e-4 \ 35 | --lwf 0.1 \ 36 | --resume pretrained/res101_vista17_best.pth.tar \ 37 | --evaluate 38 | ``` 39 | * Run `CUDA_VISIBLE_DEVICES=0 bash train.sh` 40 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` based on your need. 41 | 42 | #### Train with SGD 43 | * Put the code below in `train.sh` 44 | ```bash 45 | python train.py \ 46 | --epochs 30 \ 47 | --batch-size 32 \ 48 | --lr 1e-4 \ 49 | --lwf 0.1 50 | ``` 51 | * Run `CUDA_VISIBLE_DEVICES=0 bash train.sh` 52 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` based on your need. 53 | 54 | #### Train with L2O 55 | * Download [pretrained L2O Policy on Visda17](https://drive.google.com/file/d/1Rc2Ey-FspUagFPTjnEozeSEIdA4ir7b1/view?usp=sharing) 56 | * Put the checkpoint under `./ASG/pretrained/` 57 | * Put the code below in `l2o_train.sh` 58 | ```bash 59 | python l2o_train.py \ 60 | --epochs 30 \ 61 | --batch-size 32 \ 62 | --lr 1e-4 \ 63 | --lwf 0.1 \ 64 | --agent_load_dir ./ASG/pretrained/policy_res101_vista17.pth 65 | ``` 66 | * Run `CUDA_VISIBLE_DEVICES=0 bash l2o_train.sh` 67 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` based on your need. 68 | 69 | ### GTA5 → Cityscapes 70 | * Download [GTA5 dataset](https://download.visinf.tu-darmstadt.de/data/from_games/). 71 | * Download the [leftImg8bit_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=3) and [gtFine_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=1) from the Cityscapes website. 72 | * Prepare the annotations using [createTrainIdLabelImgs.py](https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/createTrainIdLabelImgs.py). 73 | * Put the [image list files](tools/datasets/cityscapes/) into the directory where you saved the dataset. 74 | * **Remember to set `C.dataset_path` in `config_seg.py` to the path where the datasets reside (see the sketch below).** 75 | 76 | #### Evaluation 77 | * Download [pretrained VGG16 on GTA5](https://drive.google.com/file/d/13HcsiyL-o1A9057ezJ4qCnGztnY5deQ6/view?usp=sharing) 78 | * Put the checkpoint under `./ASG/pretrained/` 79 | * Put the code below in `train_seg.sh` 80 | ```bash 81 | python train_seg.py \ 82 | --epochs 50 \ 83 | --batch-size 6 \ 84 | --lr 1e-3 \ 85 | --num-class 19 \ 86 | --gpus 0 \ 87 | --factor 0.1 \ 88 | --lwf 75. \ 89 | --evaluate \ 90 | --resume ./pretrained/vgg16_segmentation_best.pth.tar 91 | ``` 92 | * Run `CUDA_VISIBLE_DEVICES=0 bash train_seg.sh` 93 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` based on your need.
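For the `C.dataset_path` step above, the only edit needed is near the top of `config_seg.py`. A minimal sketch, assuming your GTA5 and Cityscapes folders live side by side (the `/path/to/datasets/` location is a placeholder for your own setup):

```python
# config_seg.py -- point C.dataset_path at the directory that contains
# both the gta5/ and cityscapes/ folders (placeholder path below).
C.dataset_path = "/path/to/datasets/"
# The image/label roots and list files are then derived from it, e.g.:
#   C.train_img_root = os.path.join(C.dataset_path, "gta5")
#   C.train_source = osp.join(C.train_img_root, "gta5_train.txt")
```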
94 | 95 | #### Train with SGD 96 | * Put the code below in `train_seg.sh` 97 | ```bash 98 | python train_seg.py \ 99 | --epochs 50 \ 100 | --batch-size 6 \ 101 | --lr 1e-3 \ 102 | --num-class 19 \ 103 | --gpus 0 \ 104 | --factor 0.1 \ 105 | --lwf 75. 106 | ``` 107 | * Run `CUDA_VISIBLE_DEVICES=0 bash train_seg.sh` 108 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` based on your need. 109 | 110 | #### Train with L2O 111 | * Download [pretrained L2O Policy on GTA5](https://drive.google.com/file/d/1RVQE0VxrtPCyUpsvNulpKKBQhYlOi1ag/view?usp=sharing) 112 | * Put the checkpoint under `./ASG/pretrained/` 113 | * Put the code below in `l2o_train_seg.sh` 114 | ```bash 115 | python l2o_train_seg.py \ 116 | --epochs 50 \ 117 | --batch-size 6 \ 118 | --lr 1e-3 \ 119 | --num-class 19 \ 120 | --gpus 0 \ 121 | --gamma 0 \ 122 | --early-stop 2 \ 123 | --lwf 75. \ 124 | --algo reinforce \ 125 | --agent_load_dir ./ASG/pretrained/policy_vgg16_segmentation.pth 126 | ``` 127 | * Run `CUDA_VISIBLE_DEVICES=0 bash l2o_train_seg.sh` 128 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` based on your need. 129 | 130 | ## Citation 131 | 132 | If you use this code for your research, please cite: 133 | 134 | ```BibTeX 135 | @inproceedings{chen2020automated, 136 | author = {Chen, Wuyang and Yu, Zhiding and Wang, Zhangyang and Anandkumar, Anima}, 137 | booktitle = {Proceedings of Machine Learning and Systems 2020}, 138 | pages = {8272--8282}, 139 | title = {Automated Synthetic-to-Real Generalization}, 140 | year = {2020} 141 | } 142 | ``` 143 | -------------------------------------------------------------------------------- /config_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 | 4 | # encoding: utf-8 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import os 11 | import os.path as osp 12 | import sys 13 | import numpy as np 14 | from easydict import EasyDict as edict 15 | 16 | C = edict() 17 | config = C 18 | cfg = C 19 | 20 | """Please configure ROOT_dir and user before first use""" 21 | C.repo_name = 'ASG' 22 | C.abs_dir = osp.realpath(".") 23 | C.this_dir = C.abs_dir.split(osp.sep)[-1] 24 | C.root_dir = C.abs_dir[:C.abs_dir.index(C.repo_name) + len(C.repo_name)] 25 | 26 | 27 | """Data Dir""" 28 | C.dataset_path = "/home/chenwy/" 29 | 30 | C.train_img_root = os.path.join(C.dataset_path, "gta5") 31 | C.train_gt_root = os.path.join(C.dataset_path, "gta5") 32 | C.val_img_root = os.path.join(C.dataset_path, "cityscapes") 33 | C.val_gt_root = os.path.join(C.dataset_path, "cityscapes") 34 | C.test_img_root = os.path.join(C.dataset_path, "cityscapes") 35 | C.test_gt_root = os.path.join(C.dataset_path, "cityscapes") 36 | 37 | C.train_source = osp.join(C.train_img_root, "gta5_train.txt") 38 | C.train_target_source = osp.join(C.train_img_root, "cityscapes_train_fine.txt") 39 | C.eval_source = osp.join(C.val_img_root, "cityscapes_val_fine.txt") 40 | C.test_source = osp.join(C.test_img_root, "cityscapes_test.txt") 41 | 42 | """Image Config""" 43 | C.num_classes = 19 44 | C.background = -1 45 | C.image_mean = np.array([0.485, 0.456, 0.406]) 46 | C.image_std = np.array([0.229, 0.224, 0.225]) 47 | C.down_sampling_train = [1, 1] # first down_sampling then crop 48 | C.down_sampling_val = [1, 1] # first down_sampling then crop 49 | C.gt_down_sampling = 1 50 | C.num_train_imgs = 12403 51 | C.num_eval_imgs = 500 52 | 53 | """ Settings for network, this would be different for each kind of model""" 54 | C.bn_eps = 1e-5 55 | C.bn_momentum = 0.1 56 | 57 | """Train Config""" 58 | C.lr = 0.01 59 | C.momentum = 0.9 60 | C.weight_decay = 5e-4 61 | C.nepochs = 30 62 | C.niters_per_epoch = 2000 63 | C.num_workers = 16 64 | C.train_scale_array = [0.75, 1, 1.25] 65 | # C.train_scale_array = [1] 66 | 67 | """Eval Config""" 68 | C.eval_stride_rate = 5 / 6 69 | C.eval_scale_array = [1] 70 | C.eval_flip = True 71 | C.eval_base_size = 1024 72 | C.eval_crop_size = 1024 73 | C.eval_height = 1024 74 | C.eval_width = 2048 75 | 76 | # GTA5: 1052x1914 77 | C.image_height = 512 78 | C.image_width = 512 79 | C.is_test = False # if True, prediction files for the test set will be generated 80 | C.is_eval = False # if True, train.py will only run evaluation once 81 | -------------------------------------------------------------------------------- /data/BaseDataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
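# NOTE: the `setting` dict passed into BaseDataset below is expected to carry
# 'train_img_root', 'train_gt_root', 'val_img_root', 'val_gt_root',
# 'train_source', 'eval_source' (optionally 'test_source'), plus the
# 'down_sampling_train' / 'down_sampling_val' factors; see get_train_loader in
# dataloader_seg.py for how this dict is assembled from config_seg.py.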
3 | 4 | import os 5 | import cv2 6 | import torch 7 | import numpy as np 8 | from pdb import set_trace as bp 9 | import torch.utils.data as data 10 | cv2.setNumThreads(0) 11 | 12 | 13 | class BaseDataset(data.Dataset): 14 | def __init__(self, setting, split_name, preprocess=None, file_length=None): 15 | super(BaseDataset, self).__init__() 16 | self._split_name = split_name 17 | if split_name == 'train': 18 | self._img_path = setting['train_img_root'] 19 | self._gt_path = setting['train_gt_root'] 20 | elif split_name == 'val': 21 | self._img_path = setting['val_img_root'] 22 | self._gt_path = setting['val_gt_root'] 23 | elif split_name == 'test': 24 | self._img_path = setting['test_img_root'] 25 | self._gt_path = setting['test_gt_root'] 26 | self._train_source = setting['train_source'] 27 | self._eval_source = setting['eval_source'] 28 | self._test_source = setting['test_source'] if 'test_source' in setting else setting['eval_source'] 29 | self._down_sampling = setting['down_sampling_train'] if split_name == 'train' else setting['down_sampling_val'] 30 | print("using downsampling:", self._down_sampling) 31 | self._file_names = self._get_file_names(split_name) 32 | print("Found %d images"%len(self._file_names)) 33 | self._file_length = file_length 34 | self.preprocess = preprocess 35 | 36 | def __len__(self): 37 | if self._file_length is not None: 38 | return self._file_length 39 | return len(self._file_names) 40 | 41 | def __getitem__(self, index): 42 | if self._file_length is not None: 43 | names = self._construct_new_file_names(self._file_length)[index] 44 | else: 45 | names = self._file_names[index] 46 | img_path = os.path.join(self._img_path, names[0]) 47 | gt_path = os.path.join(self._gt_path, names[1]) 48 | item_name = names[1].split("/")[-1].split(".")[0] 49 | img, gt = self._fetch_data(img_path, gt_path) 50 | img = img[:, :, ::-1] 51 | if self.preprocess is not None: 52 | img, gt, extra_dict = self.preprocess(img, gt) 53 | 54 | if self._split_name == 'train': 55 | img = torch.from_numpy(np.ascontiguousarray(img)).float() 56 | gt = torch.from_numpy(np.ascontiguousarray(gt)).long() 57 | if self.preprocess is not None and extra_dict is not None: 58 | for k, v in extra_dict.items(): 59 | extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v)) 60 | if 'label' in k: 61 | extra_dict[k] = extra_dict[k].long() 62 | if 'img' in k: 63 | extra_dict[k] = extra_dict[k].float() 64 | 65 | output_dict = dict(data=img, label=gt, fn=str(item_name), n=len(self._file_names)) 66 | if self.preprocess is not None and extra_dict is not None: 67 | output_dict.update(**extra_dict) 68 | 69 | return output_dict 70 | 71 | def _fetch_data(self, img_path, gt_path, dtype=None): 72 | img = self._open_image(img_path, down_sampling=self._down_sampling[0]) 73 | gt = self._open_image(gt_path, cv2.IMREAD_GRAYSCALE, dtype=dtype, down_sampling=self._down_sampling[1]) 74 | 75 | return img, gt 76 | 77 | def _get_file_names(self, split_name): 78 | assert split_name in ['train', 'val', 'test'] 79 | source = self._train_source 80 | if split_name == "val": 81 | source = self._eval_source 82 | elif split_name == 'test': 83 | source = self._test_source 84 | 85 | file_names = [] 86 | with open(source) as f: 87 | files = f.readlines() 88 | 89 | for item in files: 90 | img_name, gt_name = self._process_item_names(item) 91 | file_names.append([img_name, gt_name]) 92 | 93 | return file_names 94 | 95 | def _construct_new_file_names(self, length): 96 | assert isinstance(length, int) 97 | files_len = len(self._file_names) 98 | 
new_file_names = self._file_names * (length // files_len) 99 | 100 | rand_indices = torch.randperm(files_len).tolist() 101 | new_indices = rand_indices[:length % files_len] 102 | 103 | new_file_names += [self._file_names[i] for i in new_indices] 104 | 105 | return new_file_names 106 | 107 | @staticmethod 108 | def _process_item_names(item): 109 | item = item.strip() 110 | # item = item.split('\t') 111 | item = item.split(' ') 112 | img_name = item[0] 113 | gt_name = item[1] 114 | 115 | return img_name, gt_name 116 | 117 | def get_length(self): 118 | return self.__len__() 119 | 120 | @staticmethod 121 | def _open_image(filepath, mode=cv2.IMREAD_COLOR, dtype=None, down_sampling=1): 122 | # cv2: B G R 123 | # h w c 124 | img = np.array(cv2.imread(filepath, mode), dtype=dtype) 125 | if isinstance(down_sampling, int): 126 | try: 127 | H, W = img.shape[:2] 128 | except Exception: # e.g. cv2.imread returned None for a missing/corrupt file 129 | print(img.shape, filepath) 130 | exit(0) 131 | if len(img.shape) == 3: 132 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_LINEAR) 133 | else: 134 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_NEAREST) 135 | assert img.shape[0] == H // down_sampling and img.shape[1] == W // down_sampling 136 | else: 137 | assert (isinstance(down_sampling, tuple) or isinstance(down_sampling, list)) and len(down_sampling) == 2 138 | if len(img.shape) == 3: 139 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_LINEAR) 140 | else: 141 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_NEAREST) 142 | assert img.shape[0] == down_sampling[0] and img.shape[1] == down_sampling[1] 143 | 144 | return img 145 | 146 | @classmethod 147 | def get_class_colors(*args): 148 | raise NotImplementedError 149 | 150 | @classmethod 151 | def get_class_names(*args): 152 | raise NotImplementedError 153 | 154 | 155 | if __name__ == "__main__": 156 | data_setting = {'train_img_root': '', 157 | 'train_gt_root': '', 158 | 'train_source': '', 159 | 'eval_source': '', 160 | 'down_sampling_train': [1, 1]} 161 | bd = BaseDataset(data_setting, 'train', None) 162 | print(bd.get_class_names()) 163 | -------------------------------------------------------------------------------- /data/cityscapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
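# NOTE: `trans_labels` below maps the 19 training class indices (0-18) back to
# the original Cityscapes label IDs (7, 8, 11, ...); `transform_label` applies
# this mapping when exporting predictions in the official label-ID format.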
3 | 4 | import numpy as np 5 | from data.BaseDataset import BaseDataset 6 | 7 | 8 | class Cityscapes(BaseDataset): 9 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 10 | 28, 31, 32, 33] 11 | 12 | @classmethod 13 | def get_class_colors(*args): 14 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70], 15 | [102, 102, 156], [190, 153, 153], [153, 153, 153], 16 | [250, 170, 30], [220, 220, 0], [107, 142, 35], 17 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], 18 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], 19 | [0, 0, 230], [119, 11, 32]] 20 | 21 | @classmethod 22 | def get_class_names(*args): 23 | # class counting(gtFine) 24 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832 25 | # 359 274 142 513 1646 26 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 27 | 'traffic light', 'traffic sign', 28 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 29 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle'] 30 | 31 | @classmethod 32 | def transform_label(cls, pred, name): 33 | label = np.zeros(pred.shape) 34 | ids = np.unique(pred) 35 | for id in ids: 36 | label[np.where(pred == id)] = cls.trans_labels[id] 37 | 38 | new_name = (name.split('.')[0]).split('_')[:-1] 39 | new_name = '_'.join(new_name) + '.png' 40 | 41 | print('Trans', name, 'to', new_name, ' ', 42 | np.unique(np.array(pred, np.uint8)), ' ---------> ', 43 | np.unique(np.array(label, np.uint8))) 44 | return label, new_name 45 | -------------------------------------------------------------------------------- /data/gta5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 3 | 4 | import numpy as np 5 | from data.BaseDataset import BaseDataset 6 | 7 | 8 | class GTA5(BaseDataset): 9 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 10 | 28, 31, 32, 33] 11 | 12 | @classmethod 13 | def get_class_colors(*args): 14 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70], 15 | [102, 102, 156], [190, 153, 153], [153, 153, 153], 16 | [250, 170, 30], [220, 220, 0], [107, 142, 35], 17 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], 18 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], 19 | [0, 0, 230], [119, 11, 32]] 20 | 21 | @classmethod 22 | def get_class_names(*args): 23 | # class counting(gtFine) 24 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832 25 | # 359 274 142 513 1646 26 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 27 | 'traffic light', 'traffic sign', 28 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 29 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle'] 30 | 31 | @classmethod 32 | def transform_label(cls, pred, name): 33 | label = np.zeros(pred.shape) 34 | ids = np.unique(pred) 35 | for id in ids: 36 | label[np.where(pred == id)] = cls.trans_labels[id] 37 | 38 | new_name = (name.split('.')[0]).split('_')[:-1] 39 | new_name = '_'.join(new_name) + '.png' 40 | 41 | print('Trans', name, 'to', new_name, ' ', 42 | np.unique(np.array(pred, np.uint8)), ' ---------> ', 43 | np.unique(np.array(label, np.uint8))) 44 | return label, new_name 45 | -------------------------------------------------------------------------------- /data/visda17.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. 
All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | import os 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | import torchvision.transforms as transforms 10 | 11 | class VisDA17(Dataset): 12 | 13 | def __init__(self, txt_file, root_dir, transform=transforms.ToTensor(), label_one_hot=False, portion=1.0): 14 | """ 15 | Args: 16 | txt_file (string): Path to the txt file with annotations. 17 | root_dir (string): Directory with all the images. 18 | transform (callable, optional): Optional transform to be applied 19 | on a sample. 20 | """ 21 | self.lines = open(txt_file, 'r').readlines() 22 | self.root_dir = root_dir 23 | self.transform = transform 24 | self.label_one_hot = label_one_hot 25 | self.portion = portion 26 | self.number_classes = 12 27 | assert portion != 0 28 | if self.portion > 0: 29 | self.lines = self.lines[:round(self.portion * len(self.lines))] 30 | else: 31 | self.lines = self.lines[round(self.portion * len(self.lines)):] 32 | 33 | def __len__(self): 34 | return len(self.lines) 35 | 36 | def __getitem__(self, idx): 37 | line = str.split(self.lines[idx]) 38 | path_img = os.path.join(self.root_dir, line[0]) 39 | image = Image.open(path_img) 40 | image = image.convert('RGB') 41 | if self.label_one_hot: 42 | label = np.zeros(12, np.float32) 43 | label[np.asarray(line[1], dtype=np.int64)] = 1 44 | else: 45 | label = np.asarray(line[1], dtype=np.int64) # np.int was removed in newer NumPy releases 46 | label = torch.from_numpy(label) 47 | if self.transform: 48 | image = self.transform(image) 49 | return image, label 50 | -------------------------------------------------------------------------------- /dataloader_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
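# NOTE: TrainPre below applies, in order: random horizontal mirror, random
# rescaling by a factor drawn from config.train_scale_array, mean/std
# normalization, a random crop padded to image_height x image_width, and
# nearest-neighbor downsampling of the label map by config.gt_down_sampling.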
3 | 4 | import cv2 5 | from torch.utils import data 6 | from tools.utils.img_utils import random_scale, random_mirror, normalize, generate_random_crop_pos, random_crop_pad_to_shape 7 | cv2.setNumThreads(0) 8 | 9 | 10 | class TrainPre(object): 11 | def __init__(self, config, img_mean, img_std): 12 | self.img_mean = img_mean 13 | self.img_std = img_std 14 | self.config = config 15 | 16 | def __call__(self, img, gt): 17 | img, gt = random_mirror(img, gt) 18 | if self.config.train_scale_array is not None: 19 | img, gt, scale = random_scale(img, gt, self.config.train_scale_array) 20 | 21 | crop_size = (self.config.image_height, self.config.image_width) 22 | crop_pos = generate_random_crop_pos(img.shape[:2], crop_size) 23 | p_img, _ = random_crop_pad_to_shape(normalize(img, self.img_mean, self.img_std), crop_pos, crop_size, 0) 24 | p_img = p_img.transpose(2, 0, 1) 25 | extra_dict = None 26 | p_gt, _ = random_crop_pad_to_shape(gt, crop_pos, crop_size, 255) 27 | p_gt = cv2.resize(p_gt, (self.config.image_width // self.config.gt_down_sampling, self.config.image_height // self.config.gt_down_sampling), interpolation=cv2.INTER_NEAREST) 28 | 29 | return p_img, p_gt, extra_dict 30 | 31 | 32 | def get_train_loader(config, dataset, worker=None, test=False): 33 | data_setting = { 34 | 'train_img_root': config.train_img_root, 35 | 'train_gt_root': config.train_gt_root, 36 | 'val_img_root': config.val_img_root, 37 | 'val_gt_root': config.val_gt_root, 38 | 'train_source': config.train_source, 39 | 'eval_source': config.eval_source, 40 | 'down_sampling_train': config.down_sampling_train 41 | } 42 | if test: 43 | data_setting = {'img_root': config.img_root, 44 | 'gt_root': config.gt_root, 45 | 'train_source': config.train_eval_source, 46 | 'eval_source': config.eval_source} 47 | train_preprocess = TrainPre(config, config.image_mean, config.image_std) 48 | 49 | train_dataset = dataset(data_setting, "train", train_preprocess, config.batch_size * config.niters_per_epoch) 50 | 51 | is_shuffle = True 52 | batch_size = config.batch_size 53 | 54 | train_loader = data.DataLoader(train_dataset, 55 | batch_size=batch_size, 56 | num_workers=config.num_workers if worker is None else worker, 57 | drop_last=True, 58 | shuffle=is_shuffle, 59 | pin_memory=True, 60 | ) 61 | 62 | return train_loader 63 | -------------------------------------------------------------------------------- /eval_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
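# NOTE: with several entries in config.eval_scale_array, whole_eval below is
# run once per scale and the per-class score maps are summed before the final
# argmax (a simple multi-scale ensemble); a single entry skips the summation.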
3 | 4 | #!/usr/bin/env python3 5 | # encoding: utf-8 6 | import os 7 | import cv2 8 | import numpy as np 9 | from pdb import set_trace as bp 10 | from tools.utils.visualize import print_iou, show_img, show_prediction 11 | from tools.engine.evaluator import Evaluator 12 | from tools.engine.logger import get_logger 13 | from tools.seg_opr.metric import hist_info, compute_score 14 | cv2.setNumThreads(0) 15 | 16 | logger = get_logger() 17 | 18 | 19 | class SegEvaluator(Evaluator): 20 | def func_per_iteration(self, data, device, iter=None): 21 | if self.config is not None: config = self.config 22 | img = data['data'] 23 | label = data['label'] 24 | name = data['fn'] 25 | 26 | if len(config.eval_scale_array) == 1: 27 | pred = self.whole_eval(img, label.shape, resize=config.eval_scale_array[0], device=device) 28 | pred = pred.argmax(2) # since we ignore this step in evaluator.py 29 | elif len(config.eval_scale_array) > 1: 30 | pred = self.whole_eval(img, label.shape, resize=config.eval_scale_array[0], device=device) 31 | for scale in config.eval_scale_array[1:]: 32 | pred += self.whole_eval(img, label.shape, resize=scale, device=device) 33 | pred = pred.argmax(2) # since we ignore this step in evaluator.py 34 | else: 35 | pred = self.sliding_eval(img, config.eval_crop_size, config.eval_stride_rate, device) 36 | hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes, pred, label) 37 | results_dict = {'hist': hist_tmp, 'labeled': labeled_tmp, 'correct': correct_tmp} 38 | 39 | if self.save_path is not None: 40 | fn = name + '.png' 41 | cv2.imwrite(os.path.join(self.save_path, fn), pred) 42 | logger.info('Save the image ' + fn) 43 | 44 | # tensorboard logger does not fit multiprocess 45 | if self.logger is not None and iter is not None: 46 | colors = self.dataset.get_class_colors() 47 | image = img 48 | clean = np.zeros(label.shape) 49 | comp_img = show_img(colors, config.background, image, clean, label, pred) 50 | self.logger.add_image('vis', np.swapaxes(np.swapaxes(comp_img, 0, 2), 1, 2), iter) 51 | 52 | if self.show_image or self.show_prediction: 53 | colors = self.dataset.get_class_colors() 54 | image = img 55 | clean = np.zeros(label.shape) 56 | if self.show_image: 57 | comp_img = show_img(colors, config.background, image, clean, label, pred) 58 | cv2.imwrite(os.path.join(self.save_path, name + ".png"), comp_img[:,:,::-1]) 59 | if self.show_prediction: 60 | comp_img = show_prediction(colors, config.background, image, pred) 61 | cv2.imwrite(os.path.join(self.save_path, "viz_"+name+".png"), comp_img[:,:,::-1]) 62 | 63 | return results_dict 64 | 65 | def compute_metric(self, results): 66 | hist = np.zeros((self.config.num_classes, self.config.num_classes)) 67 | correct = 0 68 | labeled = 0 69 | count = 0 70 | for d in results: 71 | hist += d['hist'] 72 | correct += d['correct'] 73 | labeled += d['labeled'] 74 | count += 1 75 | 76 | iu, mean_IU, mean_IU_no_back, mean_pixel_acc = compute_score(hist, correct, labeled) 77 | result_line = print_iou(iu, mean_pixel_acc, self.dataset.get_class_names(), True) 78 | return result_line, mean_IU 79 | -------------------------------------------------------------------------------- /l2o_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 
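# NOTE: high-level flow of this script -- a frozen, pretrained RL policy
# (reinforce.models.policy.Policy) observes training statistics of the
# optimizee (cross-entropy loss, LwF KL loss, normalized step, fc-layer weight
# mean/std, and per-group learning rates) and, once per observation window,
# picks a scaling action in {0.0, 0.1, ..., 1.0} for each parameter group's
# learning rate on top of the base --lr.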
3 | 4 | # [x] train resnet101 with both proxy guidance and L2O policy on visda17 5 | 6 | import os 7 | import sys 8 | import time 9 | from collections import deque 10 | import logging 11 | from random import choice 12 | from tqdm import tqdm 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from torch.utils.data import DataLoader 18 | import torchvision.transforms as transforms 19 | 20 | from data.visda17 import VisDA17 21 | from model.resnet import resnet101 22 | from utils.utils import get_params, IterNums, save_checkpoint, AverageMeter, lr_poly, accuracy 23 | from utils.logger import prepare_logger, prepare_seed 24 | from utils.sgd import SGD 25 | 26 | from reinforce.arguments import get_args 27 | from reinforce.models.policy import Policy 28 | 29 | from pdb import set_trace as bp 30 | 31 | CrossEntropyLoss = nn.CrossEntropyLoss(reduction='mean') 32 | KLDivLoss = nn.KLDivLoss(reduction='batchmean') 33 | 34 | 35 | def adjust_learning_rate(lr, optimizer): 36 | for param_group in optimizer.param_groups: 37 | param_group['lr'] = lr 38 | 39 | 40 | def get_window_sample(train_loader_iter, train_loader, window_size=1): 41 | samples = [] 42 | while len(samples) < window_size: 43 | try: 44 | sample = next(train_loader_iter) 45 | except StopIteration: # iterator exhausted: restart from a fresh epoch 46 | train_loader_iter = iter(train_loader) 47 | sample = next(train_loader_iter) 48 | samples.append(sample) 49 | return samples, train_loader_iter 50 | 51 | 52 | def train_step(args, _window_size, train_loader_iter, train_loader, model, optimizer, obs_avg, base_lr, pbar, step, total_steps, model_old=None): 53 | # if obs_avg: average the observation in the window 54 | losses = [] 55 | losses_kl = [] 56 | fc_mean = []; fc_std = [] 57 | optimizee_step = [] 58 | for idx in range(_window_size): 59 | optimizer.zero_grad() 60 | """Train for one sample on the training set""" 61 | samples, train_loader_iter = get_window_sample(train_loader_iter, train_loader) 62 | input, label = samples[0] 63 | label = label.cuda() 64 | input = input.cuda() 65 | # compute output 66 | output, features_new = model(input, output_features=['layer4'], task='new') 67 | # compute gradient 68 | loss = CrossEntropyLoss(output, label.long()) 69 | # LWF KL div 70 | loss_kl = 0 71 | if model_old is not None: 72 | output_new = model.forward_fc(features_new['layer4'], task='old') 73 | output_old, _ = model_old(input, output_features=[], task='old') 74 | loss_kl = KLDivLoss(F.log_softmax(output_new, dim=1), F.softmax(output_old, dim=1)).sum(-1) 75 | (loss + args.lwf * loss_kl).backward() 76 | # compute gradient and do SGD step 77 | optimizer.step() 78 | fc_mean.append(model.fc_new[2].weight.mean().detach()) 79 | fc_std.append(model.fc_new[2].weight.std().detach()) 80 | description = "[step: %.5f][loss: %.1f][loss_kl: %.1f][fc_mean: %.3f][fc_std: %.3f]"%(1. * (step + idx) / total_steps, loss, loss_kl, fc_mean[-1]*1000, fc_std[-1]*1000) 81 | pbar.set_description("[Step %d/%d]"%(step + idx, total_steps) + description) 82 | losses.append(loss.detach()) 83 | losses_kl.append(loss_kl.detach()) 84 | optimizee_step.append(1.
* (step + idx) / total_steps) 85 | if obs_avg: 86 | losses = [sum(losses) / len(losses)] 87 | losses_kl = [sum(losses_kl) / len(losses_kl)] 88 | fc_mean = [sum(fc_mean) / len(fc_mean)] 89 | fc_std = [sum(fc_std) / len(fc_std)] 90 | optimizee_step = [sum(optimizee_step) / len(optimizee_step)] 91 | losses = [loss for loss in losses] 92 | losses_kl = [loss_kl for loss_kl in losses_kl] 93 | optimizee_step = [torch.tensor(step).cuda() for step in optimizee_step] 94 | observation = torch.stack(losses + losses_kl + optimizee_step + fc_mean + fc_std, dim=0) 95 | LRs = torch.Tensor([ group['lr'] / base_lr for group in optimizer.param_groups ]).cuda() 96 | observation = torch.cat([observation, LRs], dim=0).unsqueeze(0) # (batch=1, feature_size=window_size) 97 | return train_loader_iter, observation, torch.mean(torch.stack(losses, dim=0)), torch.mean(torch.stack(losses_kl, dim=0)), torch.mean(torch.stack(fc_mean, dim=0)), torch.mean(torch.stack(fc_std, dim=0)) 98 | 99 | 100 | def prepare_optimizee(args, sgd_in_names, obs_shape, hidden_size, actor_critic, current_optimizee_step, prev_optimizee_step): 101 | prev_optimizee_step += current_optimizee_step 102 | current_optimizee_step = 0 103 | 104 | model = resnet101(pretrained=True) 105 | num_ftrs = model.fc.in_features 106 | fc_layers = nn.Sequential( 107 | nn.Linear(num_ftrs, 512), 108 | nn.ReLU(inplace=True), 109 | nn.Linear(512, args.num_class), 110 | ) 111 | model.fc_new = fc_layers 112 | 113 | train_blocks = args.train_blocks.split('.') 114 | # default turn-off fc, turn-on fc_new 115 | for param in model.fc.parameters(): 116 | param.requires_grad = False 117 | ##### Freeze several bottom layers (Optional) ##### 118 | non_train_blocks = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc'] 119 | for name in train_blocks: 120 | try: 121 | non_train_blocks.remove(name) 122 | except Exception: 123 | print("cannot find block name %s\nAvailable blocks are: conv1, bn1, layer1, layer2, layer3, layer4, fc"%name) 124 | for name in non_train_blocks: 125 | for param in getattr(model, name).parameters(): 126 | param.requires_grad = False 127 | 128 | # Setup optimizer 129 | sgd_in = [] 130 | for name in train_blocks: 131 | if name != 'fc': 132 | sgd_in.append({'params': get_params(model, [name]), 'lr': args.lr}) 133 | else: 134 | sgd_in.append({'params': get_params(model, ["fc_new"]), 'lr': args.lr}) 135 | base_lrs = [ group['lr'] for group in sgd_in ] 136 | optimizer = SGD(sgd_in, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 137 | 138 | model = model.cuda() 139 | model.eval() 140 | return model, optimizer, current_optimizee_step, prev_optimizee_step 141 | 142 | 143 | def main(): 144 | args = get_args() 145 | PID = os.getpid() 146 | print("<< ============== JOB (PID = %d) %s ============== >>"%(PID, args.save_dir)) 147 | prepare_seed(args.rand_seed) 148 | 149 | if args.timestamp == 'none': 150 | args.timestamp = "{:}".format(time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time()))) 151 | 152 | torch.set_num_threads(1) 153 | 154 | # Log outputs 155 | args.save_dir = args.save_dir + \ 156 | "/Visda17-L2O.train.Res101-%s-train.%s-LR%.2E-epoch%d-batch%d-seed%d"%( 157 | "LWF" if args.lwf > 0 else "XE", args.train_blocks, args.lr, args.epochs, args.batch_size, args.rand_seed) + \ 158 | "%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp) 159 | logger = prepare_logger(args) 160 | 161 | best_prec1 = 0 162 | 163 | #### preparation ########################################### 164 | data_transforms = { 165 | 'train': 
transforms.Compose([ 166 | transforms.Resize(224), 167 | transforms.CenterCrop(224), 168 | transforms.ToTensor(), 169 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 170 | ]), 171 | 'val': transforms.Compose([ 172 | transforms.Resize(224), 173 | transforms.CenterCrop(224), 174 | transforms.ToTensor(), 175 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 176 | ]), 177 | } 178 | 179 | kwargs = {'num_workers': 20, 'pin_memory': True} 180 | trainset = VisDA17(txt_file=os.path.join(args.data, "train/image_list.txt"), root_dir=os.path.join(args.data, "train"), transform=data_transforms['train']) 181 | valset = VisDA17(txt_file=os.path.join(args.data, "validation/image_list.txt"), root_dir=os.path.join(args.data, "validation"), transform=data_transforms['val'], label_one_hot=True) 182 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs) 183 | val_loader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, **kwargs) 184 | train_loader_iter = iter(train_loader) 185 | current_optimizee_step, prev_optimizee_step = 0, 0 186 | 187 | model_old = None 188 | if args.lwf > 0: 189 | # create a fixed model copy for Life-long learning 190 | model_old = resnet101(pretrained=True) 191 | for param in model_old.parameters(): 192 | param.requires_grad = False 193 | model_old.eval() 194 | model_old.cuda() 195 | ############################################################ 196 | 197 | ### Agent Settings ######################################## 198 | RANDOM = False # False | True | 'init' 199 | action_space = np.arange(0, 1.1, 0.1) 200 | obs_avg = True 201 | _window_size = 1 202 | window_size = 1 if obs_avg else _window_size 203 | window_shrink_size = 20 # larger: controller will be updated more frequently 204 | sgd_in_names = ["conv1", "bn1", "layer1", "layer2", "layer3", "layer4", "fc_new"] 205 | coord_size = len(sgd_in_names) 206 | ob_name_lstm = ["loss", "loss_kl", "step", "fc_mean", "fc_std"] 207 | ob_name_scalar = [] 208 | obs_shape = (len(ob_name_lstm) * window_size + len(ob_name_scalar) + coord_size, ) 209 | _hidden_size = 20 210 | hidden_size = _hidden_size * len(ob_name_lstm) 211 | actor_critic = Policy(coord_size, input_size=(len(ob_name_lstm), len(ob_name_scalar)), action_space=len(action_space), hidden_size=_hidden_size, window_size=window_size) 212 | actor_critic.cuda() 213 | actor_critic.eval() 214 | 215 | partial = torch.load(args.agent_load_dir, map_location=lambda storage, loc: storage) 216 | state = actor_critic.state_dict() 217 | pretrained_dict = {k: v for k, v in partial.items()} 218 | state.update(pretrained_dict) 219 | actor_critic.load_state_dict(state) 220 | 221 | ################################################################ 222 | 223 | _min_iter = 10 224 | # reset optimizee 225 | model, optimizer, current_optimizee_step, prev_optimizee_step = prepare_optimizee(args, sgd_in_names, obs_shape, hidden_size, actor_critic, current_optimizee_step, prev_optimizee_step) 226 | epoch_size = len(train_loader) 227 | total_steps = epoch_size*args.epochs 228 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]' 229 | pbar = tqdm(range(int(epoch_size*args.epochs)), file=sys.stdout, bar_format=bar_format, ncols=100) 230 | _window_size = max(_min_iter, current_optimizee_step + prev_optimizee_step // window_shrink_size) 231 | train_loader_iter, obs, loss, loss_kl, fc_mean, fc_std = train_step(args, _window_size, train_loader_iter, train_loader, model, optimizer, obs_avg, args.lr, pbar,
current_optimizee_step + prev_optimizee_step, total_steps, model_old=model_old) 232 | logger.writer.add_scalar("loss/ce", loss, current_optimizee_step + prev_optimizee_step) 233 | logger.writer.add_scalar("loss/kl", loss_kl, current_optimizee_step + prev_optimizee_step) 234 | logger.writer.add_scalar("loss/total", loss + loss_kl, current_optimizee_step + prev_optimizee_step) 235 | logger.writer.add_scalar("fc/mean", fc_mean, current_optimizee_step + prev_optimizee_step) 236 | logger.writer.add_scalar("fc/std", fc_std, current_optimizee_step + prev_optimizee_step) 237 | current_optimizee_step += _window_size 238 | pbar.update(_window_size) 239 | prev_obs = obs.unsqueeze(0) 240 | prev_hidden = torch.zeros(actor_critic.net.num_recurrent_layers, 1, hidden_size).cuda() 241 | for epoch in range(args.epochs): 242 | print("\n===== Epoch %d / %d ====="%(epoch+1, args.epochs)) 243 | print("<< ============== JOB (PID = %d) %s ============== >>"%(PID, args.save_dir)) 244 | while current_optimizee_step < epoch_size: 245 | # Sample actions 246 | with torch.no_grad(): 247 | if not RANDOM: 248 | value, action, action_log_prob, recurrent_hidden_states, distribution = actor_critic.act(prev_obs, prev_hidden, deterministic=False) 249 | action = action.squeeze() 250 | action_log_prob = action_log_prob.squeeze() 251 | value = value.squeeze() 252 | for idx in range(len(action)): 253 | logger.writer.add_scalar("action/%s"%sgd_in_names[idx], action[idx], current_optimizee_step + prev_optimizee_step) 254 | logger.writer.add_scalar("entropy/%s"%sgd_in_names[idx], distribution.distributions[idx].entropy(), current_optimizee_step + prev_optimizee_step) 255 | optimizer.param_groups[idx]['lr'] = float(action_space[action[idx]]) * args.lr 256 | logger.writer.add_scalar("LR/%s"%sgd_in_names[idx], optimizer.param_groups[idx]['lr'], current_optimizee_step + prev_optimizee_step) 257 | else: 258 | if RANDOM is True or RANDOM == 'init': 259 | for idx in range(coord_size): 260 | optimizer.param_groups[idx]['lr'] = float(choice(action_space)) * args.lr 261 | if RANDOM == 'init': 262 | RANDOM = 'done' 263 | for idx in range(coord_size): 264 | logger.writer.add_scalar("LR/%s"%sgd_in_names[idx], optimizer.param_groups[idx]['lr'], current_optimizee_step + prev_optimizee_step) 265 | 266 | # Observe reward and next obs 267 | _window_size = max(_min_iter, current_optimizee_step + prev_optimizee_step // window_shrink_size) 268 | _window_size = min(_window_size, epoch_size - current_optimizee_step) 269 | train_loader_iter, obs, loss, loss_kl, fc_mean, fc_std = train_step(args, _window_size, train_loader_iter, train_loader, model, optimizer, obs_avg, args.lr, pbar, current_optimizee_step + prev_optimizee_step, total_steps, model_old=model_old) 270 | logger.writer.add_scalar("loss/ce", loss, current_optimizee_step + prev_optimizee_step) 271 | logger.writer.add_scalar("loss/kl", loss_kl, current_optimizee_step + prev_optimizee_step) 272 | logger.writer.add_scalar("loss/total", loss + loss_kl, current_optimizee_step + prev_optimizee_step) 273 | logger.writer.add_scalar("fc/mean", fc_mean, current_optimizee_step + prev_optimizee_step) 274 | logger.writer.add_scalar("fc/std", fc_std, current_optimizee_step + prev_optimizee_step) 275 | current_optimizee_step += _window_size 276 | pbar.update(_window_size) 277 | prev_obs = obs.unsqueeze(0) 278 | if not RANDOM: prev_hidden = recurrent_hidden_states 279 | prev_optimizee_step += current_optimizee_step 280 | current_optimizee_step = 0 281 | 282 | # evaluate on validation set 283 | prec1 =
validate(val_loader, model, args) 284 | logger.writer.add_scalar("prec", prec1, epoch) 285 | 286 | # remember best prec@1 and save checkpoint 287 | is_best = prec1 > best_prec1 288 | best_prec1 = max(prec1, best_prec1) 289 | save_checkpoint(args.save_dir, { 290 | 'epoch': epoch + 1, 291 | 'state_dict': model.state_dict(), 292 | 'best_prec1': best_prec1, 293 | }, is_best) 294 | 295 | logging.info('Best accuracy: {prec1:.3f}'.format(prec1=best_prec1)) 296 | 297 | 298 | def validate(val_loader, model, args): 299 | """Perform validation on the validation set""" 300 | batch_time = AverageMeter() 301 | top1 = AverageMeter() 302 | 303 | model.eval() 304 | 305 | end = time.time() 306 | val_size = len(val_loader) 307 | val_loader_iter = iter(val_loader) 308 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]' 309 | pbar = tqdm(range(val_size), file=sys.stdout, bar_format=bar_format, ncols=140) 310 | with torch.no_grad(): 311 | for idx_iter in pbar: 312 | input, label = next(val_loader_iter) 313 | 314 | input = input.cuda() 315 | label = label.cuda() 316 | 317 | # compute output 318 | output = torch.sigmoid(model(input, task='new')[0]) 319 | output = (output + torch.sigmoid(model(torch.flip(input, dims=(3,)), task='new')[0])) / 2 320 | 321 | # accumulate accuracy 322 | prec1, gt_num = accuracy(output.data, label, args.num_class, topk=(1,)) 323 | top1.update(prec1[0], gt_num[0]) 324 | 325 | # measure elapsed time 326 | batch_time.update(time.time() - end) 327 | end = time.time() 328 | 329 | description = "[Acc@1-mean: %.2f][Acc@1-cls: %s]"%(top1.vec2sca_avg, str(top1.avg.numpy().round(1))) 330 | pbar.set_description("[Step %d/%d]"%(idx_iter + 1, val_size) + description) 331 | 332 | logging.info(' * Prec@1 {top1.vec2sca_avg:.3f}'.format(top1=top1)) 333 | logging.info(' * Prec@1 {top1.avg}'.format(top1=top1)) 334 | 335 | return top1.vec2sca_avg 336 | 337 | 338 | if __name__ == "__main__": 339 | main() 340 | -------------------------------------------------------------------------------- /l2o_train.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | python l2o_train.py \ 5 | --epochs 30 \ 6 | --batch-size 32 \ 7 | --lr 1e-4 \ 8 | --num-class 12 \ 9 | --lwf 0.1 \ 10 | --agent_load_dir /raid/ASG/pretrained/policy_res101_vista17.pth 11 | -------------------------------------------------------------------------------- /l2o_train_seg.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | python l2o_train_seg.py \ 5 | --epochs 50 \ 6 | --batch-size 6 \ 7 | --lr 1e-3 \ 8 | --num-class 19 \ 9 | --gpus 0 \ 10 | --gamma 0 \ 11 | --early-stop 2 \ 12 | --lwf 75. \ 13 | --algo reinforce 14 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | -------------------------------------------------------------------------------- /model/fcn8s_vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation.
All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | import os.path as osp 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from pdb import set_trace as bp 9 | 10 | 11 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 12 | """Make a 2D bilinear kernel suitable for upsampling""" 13 | factor = (kernel_size + 1) // 2 14 | if kernel_size % 2 == 1: 15 | center = factor - 1 16 | else: 17 | center = factor - 0.5 18 | og = np.ogrid[:kernel_size, :kernel_size] 19 | filt = (1 - abs(og[0] - center) / factor) * \ 20 | (1 - abs(og[1] - center) / factor) 21 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 22 | dtype=np.float64) 23 | weight[range(in_channels), range(out_channels), :, :] = filt 24 | return torch.from_numpy(weight).float() 25 | 26 | 27 | class FCN8s(nn.Module): 28 | 29 | def __init__(self, n_class=21): 30 | super(FCN8s, self).__init__() 31 | # conv1 32 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) 33 | self.relu1_1 = nn.ReLU(inplace=True) 34 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 35 | self.relu1_2 = nn.ReLU(inplace=True) 36 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 37 | 38 | # conv2 39 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 40 | self.relu2_1 = nn.ReLU(inplace=True) 41 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 42 | self.relu2_2 = nn.ReLU(inplace=True) 43 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 44 | 45 | # conv3 46 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 47 | self.relu3_1 = nn.ReLU(inplace=True) 48 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 49 | self.relu3_2 = nn.ReLU(inplace=True) 50 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 51 | self.relu3_3 = nn.ReLU(inplace=True) 52 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 53 | 54 | # conv4 55 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 56 | self.relu4_1 = nn.ReLU(inplace=True) 57 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 58 | self.relu4_2 = nn.ReLU(inplace=True) 59 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 60 | self.relu4_3 = nn.ReLU(inplace=True) 61 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 62 | 63 | # conv5 64 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 65 | self.relu5_1 = nn.ReLU(inplace=True) 66 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 67 | self.relu5_2 = nn.ReLU(inplace=True) 68 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 69 | self.relu5_3 = nn.ReLU(inplace=True) 70 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 71 | 72 | # fc6 73 | self.fc6 = nn.Conv2d(512, 4096, 7) 74 | self.relu6 = nn.ReLU(inplace=True) 75 | self.drop6 = nn.Dropout2d() 76 | 77 | # fc7 78 | self.fc7 = nn.Conv2d(4096, 4096, 1) 79 | self.relu7 = nn.ReLU(inplace=True) 80 | self.drop7 = nn.Dropout2d() 81 | 82 | self.score_fr = nn.Conv2d(4096, n_class, 1) 83 | self.score_pool3 = nn.Conv2d(256, n_class, 1) 84 | self.score_pool4 = nn.Conv2d(512, n_class, 1) 85 | 86 | self.upscore2 = nn.ConvTranspose2d( 87 | n_class, n_class, 4, stride=2, bias=False) 88 | self.upscore8 = nn.ConvTranspose2d( 89 | n_class, n_class, 16, stride=8, bias=False) 90 | self.upscore_pool4 = nn.ConvTranspose2d( 91 | n_class, n_class, 4, stride=2, bias=False) 92 | 93 | self._initialize_weights() 94 | 95 | def _initialize_weights(self): 96 | # convs start from zeros (FCN convention); pretrained weights are copied 97 | # in afterwards via copy_params_from_vgg16 / copy_params_from_fcn16s 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | m.weight.data.zero_() 101 | if m.bias is not None: 102 | m.bias.data.zero_() 103 |
if isinstance(m, nn.ConvTranspose2d): 102 | assert m.kernel_size[0] == m.kernel_size[1] 103 | initial_weight = get_upsampling_weight( 104 | m.in_channels, m.out_channels, m.kernel_size[0]) 105 | m.weight.data.copy_(initial_weight) 106 | 107 | def forward_fc(self, features, task='new_seg'): 108 | h = features['layer5'] 109 | h = self.score_fr(h) 110 | h = self.upscore2(h) 111 | upscore2 = h # 1/16 112 | 113 | h = self.score_pool4(features['layer3']) 114 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 115 | score_pool4c = h # 1/16 116 | 117 | h = upscore2 + score_pool4c # 1/16 118 | h = self.upscore_pool4(h) 119 | upscore_pool4 = h # 1/8 120 | 121 | h = self.score_pool3(features['layer2']) 122 | h = h[:, :, 123 | 9:9 + upscore_pool4.size()[2], 124 | 9:9 + upscore_pool4.size()[3]] 125 | score_pool3c = h # 1/8 126 | 127 | h = upscore_pool4 + score_pool3c # 1/8 128 | 129 | h = self.upscore8(h) 130 | h = h[:, :, 31:31 + self.input_size[2], 31:31 + self.input_size[3]].contiguous() 131 | return h 132 | 133 | def forward_backbone(self, x, output_features=['layer4']): 134 | features = {} 135 | h = x 136 | h = self.relu1_1(self.conv1_1(h)) 137 | h = self.relu1_2(self.conv1_2(h)) 138 | h = self.pool1(h) 139 | 140 | h = self.relu2_1(self.conv2_1(h)) 141 | h = self.relu2_2(self.conv2_2(h)) 142 | h = self.pool2(h) 143 | 144 | h = self.relu3_1(self.conv3_1(h)) 145 | h = self.relu3_2(self.conv3_2(h)) 146 | h = self.relu3_3(self.conv3_3(h)) 147 | h = self.pool3(h) 148 | pool3 = h # 1/8 149 | features['layer2'] = pool3 150 | 151 | h = self.relu4_1(self.conv4_1(h)) 152 | h = self.relu4_2(self.conv4_2(h)) 153 | h = self.relu4_3(self.conv4_3(h)) 154 | h = self.pool4(h) 155 | pool4 = h # 1/16 156 | features['layer3'] = pool4 157 | 158 | h = self.relu5_1(self.conv5_1(h)) 159 | h = self.relu5_2(self.conv5_2(h)) 160 | h = self.relu5_3(self.conv5_3(h)) 161 | h = self.pool5(h) 162 | pool5 = h # 1/32 163 | features['layer4'] = pool5 164 | 165 | h = self.relu6(self.fc6(h)) 166 | h = self.drop6(h) 167 | 168 | h = self.relu7(self.fc7(h)) 169 | h = self.drop7(h) 170 | features['layer5'] = h 171 | return features 172 | 173 | def forward(self, x, output_features=['layer4'], task='new_seg'): 174 | ''' 175 | task: 'old' | 'new' | 'new_seg' 176 | 'old', 'new': classification tasks (ImageNet or Visda) 177 | 'new_seg': segmentation head (convs) 178 | ''' 179 | self.input_size = x.size() 180 | ###### standard FCN ################## 181 | features = self.forward_backbone(x, output_features) 182 | x = self.forward_fc(features, task=task) 183 | ###################################### 184 | return x, features 185 | 186 | def copy_params_from_fcn16s(self, fcn16s): 187 | for name, l1 in fcn16s.named_children(): 188 | try: 189 | l2 = getattr(self, name) 190 | l2.weight # skip ReLU / Dropout 191 | except Exception: 192 | continue 193 | assert l1.weight.size() == l2.weight.size() 194 | l2.weight.data.copy_(l1.weight.data) 195 | if l1.bias is not None: 196 | assert l1.bias.size() == l2.bias.size() 197 | l2.bias.data.copy_(l1.bias.data) 198 | 199 | 200 | class FCN8sAtOnce(FCN8s): 201 | 202 | def forward_fc(self, features, task='new_seg'): 203 | h = features['layer5'] 204 | h = self.score_fr(h) 205 | h = self.upscore2(h) 206 | upscore2 = h # 1/16 207 | 208 | h = self.score_pool4(features['layer3'] * 0.01) # scaling to train at once 209 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 210 | score_pool4c = h # 1/16 211 | 212 | h = upscore2 + score_pool4c # 1/16 213 | h = self.upscore_pool4(h) 214 | 
upscore_pool4 = h # 1/8 215 | 216 | h = self.score_pool3(features['layer2'] * 0.0001) # scaling to train at once 217 | h = h[:, :, 218 | 9:9 + upscore_pool4.size()[2], 219 | 9:9 + upscore_pool4.size()[3]] 220 | score_pool3c = h # 1/8 221 | 222 | h = upscore_pool4 + score_pool3c # 1/8 223 | 224 | h = self.upscore8(h) 225 | h = h[:, :, 31:31 + self.input_size[2], 31:31 + self.input_size[3]].contiguous() 226 | return h 227 | 228 | def forward_backbone(self, x, output_features=['layer4']): 229 | features = {} 230 | h = x 231 | h = self.relu1_1(self.conv1_1(h)) 232 | h = self.relu1_2(self.conv1_2(h)) 233 | h = self.pool1(h) 234 | 235 | h = self.relu2_1(self.conv2_1(h)) 236 | h = self.relu2_2(self.conv2_2(h)) 237 | h = self.pool2(h) 238 | 239 | h = self.relu3_1(self.conv3_1(h)) 240 | h = self.relu3_2(self.conv3_2(h)) 241 | h = self.relu3_3(self.conv3_3(h)) 242 | h = self.pool3(h) 243 | pool3 = h # 1/8 244 | features['layer2'] = pool3 245 | 246 | h = self.relu4_1(self.conv4_1(h)) 247 | h = self.relu4_2(self.conv4_2(h)) 248 | h = self.relu4_3(self.conv4_3(h)) 249 | h = self.pool4(h) 250 | pool4 = h # 1/16 251 | features['layer3'] = pool4 252 | 253 | h = self.relu5_1(self.conv5_1(h)) 254 | h = self.relu5_2(self.conv5_2(h)) 255 | h = self.relu5_3(self.conv5_3(h)) 256 | h = self.pool5(h) 257 | pool5 = h # 1/32 258 | features['layer4'] = pool5 259 | 260 | h = self.relu6(self.fc6(h)) 261 | h = self.drop6(h) 262 | 263 | h = self.relu7(self.fc7(h)) 264 | h = self.drop7(h) 265 | features['layer5'] = h 266 | return features 267 | 268 | def forward(self, x, output_features=['layer4'], task='new_seg'): 269 | ''' 270 | task: 'old' | 'new' | 'new_seg' 271 | 'old', 'new': classification tasks (ImageNet or Visda) 272 | 'new_seg': segmentation head (convs) 273 | ''' 274 | self.input_size = x.size() 275 | ###### standard FCN ################## 276 | features = self.forward_backbone(x, output_features) 277 | x = self.forward_fc(features, task=task) 278 | ###################################### 279 | return x, features 280 | 281 | def copy_params_from_vgg16(self, vgg16): 282 | features = [ 283 | self.conv1_1, self.relu1_1, 284 | self.conv1_2, self.relu1_2, 285 | self.pool1, 286 | self.conv2_1, self.relu2_1, 287 | self.conv2_2, self.relu2_2, 288 | self.pool2, 289 | self.conv3_1, self.relu3_1, 290 | self.conv3_2, self.relu3_2, 291 | self.conv3_3, self.relu3_3, 292 | self.pool3, 293 | self.conv4_1, self.relu4_1, 294 | self.conv4_2, self.relu4_2, 295 | self.conv4_3, self.relu4_3, 296 | self.pool4, 297 | self.conv5_1, self.relu5_1, 298 | self.conv5_2, self.relu5_2, 299 | self.conv5_3, self.relu5_3, 300 | self.pool5, 301 | ] 302 | for l1, l2 in zip(vgg16.features, features): 303 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 304 | assert l1.weight.size() == l2.weight.size() 305 | assert l1.bias.size() == l2.bias.size() 306 | l2.weight.data.copy_(l1.weight.data) 307 | l2.bias.data.copy_(l1.bias.data) 308 | for i, name in zip([0, 3], ['fc6', 'fc7']): 309 | l1 = vgg16.classifier[i] 310 | l2 = getattr(self, name) 311 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size())) 312 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size())) 313 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 
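# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original file): what the
# get_upsampling_weight() helper in model/fcn8s_vgg.py above actually builds.
# The 2-D kernel is the outer product of a 1-D triangular ("tent") filter, so
# a ConvTranspose2d initialized with it (e.g. upscore2, stride 2) performs
# bilinear upsampling. bilinear_filter_1d is an illustrative helper name.
import numpy as np

def bilinear_filter_1d(kernel_size):
    # Same arithmetic as get_upsampling_weight(), reduced to one dimension.
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    taps = np.arange(kernel_size)
    return 1 - np.abs(taps - center) / factor

filt = bilinear_filter_1d(4)     # kernel size used by the stride-2 upscore layers
print(filt)                      # -> [0.25 0.75 0.75 0.25]
# Total 2-D weight is (sum of taps)**2 = 4, matching the 2x2 output pixels
# that each input pixel feeds under stride 2.
assert np.allclose(np.outer(filt, filt).sum(), 4.0)
# ---------------------------------------------------------------------------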
3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | from pdb import set_trace as bp 9 | 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=dilation, groups=groups, bias=False, dilation=dilation) 30 | 31 | 32 | def conv1x1(in_planes, out_planes, stride=1): 33 | """1x1 convolution""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion = 1 39 | 40 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 41 | base_width=64, dilation=1, norm_layer=None): 42 | super(BasicBlock, self).__init__() 43 | if norm_layer is None: 44 | norm_layer = nn.BatchNorm2d 45 | if groups != 1 or base_width != 64: 46 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 47 | if dilation > 1: 48 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 49 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 50 | self.conv1 = conv3x3(inplanes, planes, stride) 51 | self.bn1 = norm_layer(planes) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.conv2 = conv3x3(planes, planes) 54 | self.bn2 = norm_layer(planes) 55 | self.downsample = downsample 56 | self.stride = stride 57 | 58 | def forward(self, x): 59 | identity = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv2(out) 66 | out = self.bn2(out) 67 | 68 | if self.downsample is not None: 69 | identity = self.downsample(x) 70 | 71 | out += identity 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class Bottleneck(nn.Module): 78 | expansion = 4 79 | 80 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 81 | base_width=64, dilation=1, norm_layer=None): 82 | super(Bottleneck, self).__init__() 83 | if norm_layer is None: 84 | norm_layer = nn.BatchNorm2d 85 | width = int(planes * (base_width / 64.)) * groups 86 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 87 | self.conv1 = conv1x1(inplanes, width) 88 | self.bn1 = norm_layer(width) 89 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 90 | self.bn2 = norm_layer(width) 91 | self.conv3 = conv1x1(width, planes * self.expansion) 92 | self.bn3 = norm_layer(planes * self.expansion) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.downsample = downsample 95 | self.stride = stride 96 | 97 | def forward(self, x): 98 | identity = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out 
= self.bn2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | 111 | if self.downsample is not None: 112 | identity = self.downsample(x) 113 | 114 | out += identity 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | 120 | class ResNet(nn.Module): 121 | 122 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 123 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 124 | norm_layer=None): 125 | super(ResNet, self).__init__() 126 | if norm_layer is None: 127 | norm_layer = nn.BatchNorm2d 128 | self._norm_layer = norm_layer 129 | 130 | self.inplanes = 64 131 | self.dilation = 1 132 | if replace_stride_with_dilation is None: 133 | # each element in the tuple indicates if we should replace 134 | # the 2x2 stride with a dilated convolution instead 135 | replace_stride_with_dilation = [False, False, False] 136 | if len(replace_stride_with_dilation) != 3: 137 | raise ValueError("replace_stride_with_dilation should be None " 138 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 139 | self.groups = groups 140 | self.base_width = width_per_group 141 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 142 | bias=False) 143 | self.bn1 = norm_layer(self.inplanes) 144 | self.relu = nn.ReLU(inplace=True) 145 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 146 | self.layer1 = self._make_layer(block, 64, layers[0]) 147 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 148 | dilate=replace_stride_with_dilation[0]) 149 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 150 | dilate=replace_stride_with_dilation[1]) 151 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 152 | dilate=replace_stride_with_dilation[2]) 153 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 154 | self.fc = nn.Linear(512 * block.expansion, num_classes) 155 | 156 | for m in self.modules(): 157 | if isinstance(m, nn.Conv2d): 158 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 159 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 160 | nn.init.constant_(m.weight, 1) 161 | nn.init.constant_(m.bias, 0) 162 | 163 | # Zero-initialize the last BN in each residual branch, 164 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
165 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 166 | if zero_init_residual: 167 | for m in self.modules(): 168 | if isinstance(m, Bottleneck): 169 | nn.init.constant_(m.bn3.weight, 0) 170 | elif isinstance(m, BasicBlock): 171 | nn.init.constant_(m.bn2.weight, 0) 172 | 173 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 174 | norm_layer = self._norm_layer 175 | downsample = None 176 | previous_dilation = self.dilation 177 | if dilate: 178 | self.dilation *= stride 179 | stride = 1 180 | if stride != 1 or self.inplanes != planes * block.expansion: 181 | downsample = nn.Sequential( 182 | conv1x1(self.inplanes, planes * block.expansion, stride), 183 | norm_layer(planes * block.expansion), 184 | ) 185 | 186 | layers = [] 187 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 188 | self.base_width, previous_dilation, norm_layer)) 189 | self.inplanes = planes * block.expansion 190 | for _ in range(1, blocks): 191 | layers.append(block(self.inplanes, planes, groups=self.groups, 192 | base_width=self.base_width, dilation=self.dilation, 193 | norm_layer=norm_layer)) 194 | 195 | return nn.Sequential(*layers) 196 | 197 | def forward_fc(self, f4, task='old', f3=None, f2=None): 198 | x = f4 199 | if task in ['old', 'new']: 200 | x = self.avgpool(x) 201 | x = x.reshape(x.size(0), -1) 202 | if task == 'old': 203 | x = self.fc(x) 204 | else: 205 | x = self.fc_new(x) 206 | return x 207 | 208 | def forward_backbone(self, x, output_features=['layer4']): 209 | features = {} 210 | f0 = self.conv1(x) 211 | f0 = self.bn1(f0) 212 | f0 = self.relu(f0) 213 | f0 = self.maxpool(f0) 214 | if 'layer0' in output_features: features['layer0'] = f0 215 | f1 = self.layer1(f0) 216 | if 'layer1' in output_features: features['layer1'] = f1 217 | f2 = self.layer2(f1) 218 | if 'layer2' in output_features: features['layer2'] = f2 219 | f3 = self.layer3(f2) 220 | if 'layer3' in output_features: features['layer3'] = f3 221 | f4 = self.layer4(f3) 222 | if 'layer4' in output_features: features['layer4'] = f4 223 | return f4, features 224 | # return f4, f3, f2, features 225 | 226 | def forward(self, x, output_features=['layer4'], task='old'): 227 | ''' 228 | task: 'old' | 'new' | 'new_seg' 229 | 'old', 'new': classification tasks (ImageNet or Visda) 230 | 'new_seg': segmentation head (convs) 231 | ''' 232 | ###### standard FCN ################## 233 | f4, features = self.forward_backbone(x, output_features) 234 | x = self.forward_fc(f4, task=task) 235 | ###################################### 236 | return x, features 237 | 238 | 239 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 240 | model = ResNet(block, layers, **kwargs) 241 | if pretrained: 242 | from torchvision.models.utils import load_state_dict_from_url 243 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 244 | # model.load_state_dict(state_dict) 245 | state = model.state_dict() 246 | pretrained_dict = {k: v for k, v in state_dict.items() if k in state and state[k].size() == v.size()} 247 | state.update(pretrained_dict) 248 | model.load_state_dict(state) 249 | return model 250 | 251 | 252 | def resnet18(pretrained=False, progress=True, **kwargs): 253 | """Constructs a ResNet-18 model. 
254 | 255 | Args: 256 | pretrained (bool): If True, returns a model pre-trained on ImageNet 257 | progress (bool): If True, displays a progress bar of the download to stderr 258 | """ 259 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 260 | **kwargs) 261 | 262 | 263 | def resnet34(pretrained=False, progress=True, **kwargs): 264 | """Constructs a ResNet-34 model. 265 | 266 | Args: 267 | pretrained (bool): If True, returns a model pre-trained on ImageNet 268 | progress (bool): If True, displays a progress bar of the download to stderr 269 | """ 270 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 271 | **kwargs) 272 | 273 | 274 | def resnet50(pretrained=False, progress=True, **kwargs): 275 | """Constructs a ResNet-50 model. 276 | 277 | Args: 278 | pretrained (bool): If True, returns a model pre-trained on ImageNet 279 | progress (bool): If True, displays a progress bar of the download to stderr 280 | """ 281 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 282 | **kwargs) 283 | 284 | 285 | def resnet101(pretrained=False, progress=True, **kwargs): 286 | """Constructs a ResNet-101 model. 287 | 288 | Args: 289 | pretrained (bool): If True, returns a model pre-trained on ImageNet 290 | progress (bool): If True, displays a progress bar of the download to stderr 291 | """ 292 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 293 | **kwargs) 294 | 295 | 296 | def resnet152(pretrained=False, progress=True, **kwargs): 297 | """Constructs a ResNet-152 model. 298 | 299 | Args: 300 | pretrained (bool): If True, returns a model pre-trained on ImageNet 301 | progress (bool): If True, displays a progress bar of the download to stderr 302 | """ 303 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 304 | **kwargs) 305 | 306 | 307 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 308 | """Constructs a ResNeXt-50 32x4d model. 309 | 310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | progress (bool): If True, displays a progress bar of the download to stderr 313 | """ 314 | kwargs['groups'] = 32 315 | kwargs['width_per_group'] = 4 316 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 317 | pretrained, progress, **kwargs) 318 | 319 | 320 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 321 | """Constructs a ResNeXt-101 32x8d model. 322 | 323 | Args: 324 | pretrained (bool): If True, returns a model pre-trained on ImageNet 325 | progress (bool): If True, displays a progress bar of the download to stderr 326 | """ 327 | kwargs['groups'] = 32 328 | kwargs['width_per_group'] = 8 329 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 330 | pretrained, progress, **kwargs) 331 | -------------------------------------------------------------------------------- /model/vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 
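# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original file): the two checkpoint-loading
# styles in this repo. _resnet() above keeps only pretrained tensors whose key
# AND shape match the target model, so models with extra heads (e.g. fc_new)
# still load cleanly; _vgg() below calls load_state_dict() strictly instead.
# The checkpoint/module names here are illustrative only.
import torch
import torch.nn as nn

ckpt = {'fc.weight': torch.randn(2, 4), 'fc.bias': torch.randn(2),
        'old_head.weight': torch.randn(9, 9)}   # key absent in target: skipped
target = nn.Sequential()
target.add_module('fc', nn.Linear(4, 2))

state = target.state_dict()
filtered = {k: v for k, v in ckpt.items()
            if k in state and state[k].size() == v.size()}
state.update(filtered)
target.load_state_dict(state)    # loads cleanly despite the unmatched key
# ---------------------------------------------------------------------------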
3 | 4 | import torch 5 | import torch.nn as nn 6 | from pdb import set_trace as bp 7 | 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | model_urls = { 16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 20 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 21 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 22 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 23 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 24 | } 25 | 26 | 27 | class VGG(nn.Module): 28 | 29 | def __init__(self, features, num_classes=1000, init_weights=True): 30 | super(VGG, self).__init__() 31 | self.features = features 32 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 33 | self.classifier = nn.Sequential( 34 | nn.Linear(512 * 7 * 7, 4096), 35 | nn.ReLU(True), 36 | nn.Dropout(), 37 | nn.Linear(4096, 4096), 38 | nn.ReLU(True), 39 | nn.Dropout(), 40 | nn.Linear(4096, num_classes), 41 | ) 42 | if init_weights: 43 | self._initialize_weights() 44 | 45 | def forward_fc(self, f4, task='old'): 46 | x = f4 47 | if task in ['old', 'new']: 48 | x = self.avgpool(x) 49 | x = torch.flatten(x, 1) 50 | if task == 'old': 51 | x = self.classifier(x) 52 | else: 53 | x = self.fc_new(x) 54 | return x 55 | 56 | def forward_backbone(self, x, output_features=['layer4']): 57 | features = {} 58 | f4 = self.features(x) 59 | if 'layer4' in output_features: features['layer4'] = f4 60 | return f4, features 61 | 62 | def forward(self, x, output_features=['layer4'], task='old'): 63 | ''' 64 | task: 'old' | 'new' | 'new_seg' 65 | 'old', 'new': classification tasks (ImageNet or Visda) 66 | 'new_seg': segmentation head (convs) 67 | ''' 68 | ###### standard FCN ################## 69 | f4, features = self.forward_backbone(x, output_features) 70 | x = self.forward_fc(f4, task=task) 71 | ###################################### 72 | return x, features 73 | 74 | def _initialize_weights(self): 75 | for m in self.modules(): 76 | if isinstance(m, nn.Conv2d): 77 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 78 | if m.bias is not None: 79 | nn.init.constant_(m.bias, 0) 80 | elif isinstance(m, nn.BatchNorm2d): 81 | nn.init.constant_(m.weight, 1) 82 | nn.init.constant_(m.bias, 0) 83 | elif isinstance(m, nn.Linear): 84 | nn.init.normal_(m.weight, 0, 0.01) 85 | nn.init.constant_(m.bias, 0) 86 | 87 | 88 | def make_layers(cfg, batch_norm=False): 89 | layers = [] 90 | in_channels = 3 91 | for idx, v in enumerate(cfg): 92 | if v == 'M': 93 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 94 | else: 95 | if idx == 0: 96 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=100) 97 | else: 98 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 99 | if batch_norm: 100 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 101 | else: 102 | layers += [conv2d, nn.ReLU(inplace=True)] 103 | in_channels = v 104 | return nn.Sequential(*layers) 105 | 106 | 107 | cfgs = { 108 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 109 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 110 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 
512, 512, 512, 'M', 512, 512, 512, 'M'], 111 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 112 | } 113 | 114 | 115 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): 116 | if pretrained: 117 | kwargs['init_weights'] = False 118 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 119 | if pretrained: 120 | from torchvision.models.utils import load_state_dict_from_url 121 | state_dict = load_state_dict_from_url(model_urls[arch], 122 | progress=progress) 123 | model.load_state_dict(state_dict) 124 | return model 125 | 126 | 127 | def vgg11(pretrained=False, progress=True, **kwargs): 128 | r"""VGG 11-layer model (configuration "A") from 129 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 130 | 131 | Args: 132 | pretrained (bool): If True, returns a model pre-trained on ImageNet 133 | progress (bool): If True, displays a progress bar of the download to stderr 134 | """ 135 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 136 | 137 | 138 | def vgg11_bn(pretrained=False, progress=True, **kwargs): 139 | r"""VGG 11-layer model (configuration "A") with batch normalization 140 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 141 | 142 | Args: 143 | pretrained (bool): If True, returns a model pre-trained on ImageNet 144 | progress (bool): If True, displays a progress bar of the download to stderr 145 | """ 146 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 147 | 148 | 149 | def vgg13(pretrained=False, progress=True, **kwargs): 150 | r"""VGG 13-layer model (configuration "B") 151 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 152 | 153 | Args: 154 | pretrained (bool): If True, returns a model pre-trained on ImageNet 155 | progress (bool): If True, displays a progress bar of the download to stderr 156 | """ 157 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 158 | 159 | 160 | def vgg13_bn(pretrained=False, progress=True, **kwargs): 161 | r"""VGG 13-layer model (configuration "B") with batch normalization 162 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 163 | 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | progress (bool): If True, displays a progress bar of the download to stderr 167 | """ 168 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 169 | 170 | 171 | def vgg16(pretrained=False, progress=True, **kwargs): 172 | r"""VGG 16-layer model (configuration "D") 173 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 174 | 175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | progress (bool): If True, displays a progress bar of the download to stderr 178 | """ 179 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 180 | 181 | 182 | def vgg16_bn(pretrained=False, progress=True, **kwargs): 183 | r"""VGG 16-layer model (configuration "D") with batch normalization 184 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 185 | 186 | Args: 187 | pretrained (bool): If True, returns a model pre-trained on ImageNet 188 | progress (bool): If True, displays a progress bar of the download to stderr 189 | """ 190 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 191 | 192 | 193 | def vgg19(pretrained=False, progress=True, **kwargs): 194 | r"""VGG 19-layer model 
(configuration "E") 195 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 196 | 197 | Args: 198 | pretrained (bool): If True, returns a model pre-trained on ImageNet 199 | progress (bool): If True, displays a progress bar of the download to stderr 200 | """ 201 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 202 | 203 | 204 | def vgg19_bn(pretrained=False, progress=True, **kwargs): 205 | r"""VGG 19-layer model (configuration 'E') with batch normalization 206 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | progress (bool): If True, displays a progress bar of the download to stderr 211 | """ 212 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 213 | -------------------------------------------------------------------------------- /reinforce/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | -------------------------------------------------------------------------------- /reinforce/algo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | from .reinforce import REINFORCE 5 | # from .a2c_acktr import A2C_ACKTR 6 | # from .ppo import PPO 7 | -------------------------------------------------------------------------------- /reinforce/algo/reinforce.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 
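# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original file) of the policy-gradient step
# in REINFORCE.update() below: the loss is the negative action log-probability
# weighted by the detached returns, so an SGD step ascends expected return.
# Shapes follow the (num_steps, num_processes, 1) layout used in this repo;
# the tensor values themselves are made up.
import torch

action_log_probs = torch.randn(5, 1, 1, requires_grad=True)
returns = torch.rand(5, 1, 1)               # used directly as advantages below
action_loss = -(returns.detach() * action_log_probs).mean()
action_loss.backward()                      # gradient flows only to the policy
# ---------------------------------------------------------------------------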
3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from pdb import set_trace as bp 8 | 9 | 10 | class REINFORCE(): 11 | def __init__(self, 12 | actor_critic, 13 | entropy_coef, 14 | lr=None, 15 | eps=None, 16 | alpha=None, 17 | max_grad_norm=None, 18 | acktr=False): 19 | 20 | self.actor_critic = actor_critic 21 | self.acktr = acktr 22 | self.entropy_coef = entropy_coef 23 | self.max_grad_norm = max_grad_norm 24 | # self.optimizer = optim.Adam(actor_critic.parameters(), lr)#, eps=eps) 25 | self.optimizer = optim.SGD(actor_critic.parameters(), lr, momentum=0.9)#, eps=eps) 26 | 27 | def update(self, rollouts): 28 | obs_shape = rollouts.obs.size()[2:] 29 | action_shape = rollouts.actions.size()[-1] 30 | num_steps, num_processes, _ = rollouts.rewards.size() 31 | 32 | values, action_log_probs, dist_entropy, _, distribution = self.actor_critic.evaluate_actions( 33 | # rollouts.obs[:-1].view(-1, *obs_shape), 34 | # rollouts.recurrent_hidden_states[0].view(-1, self.actor_critic.recurrent_hidden_state_size), 35 | # rollouts.actions.view(-1, action_shape) 36 | rollouts.obs[:-1], 37 | rollouts.recurrent_hidden_states[0], 38 | rollouts.actions 39 | ) 40 | 41 | values = values.view(num_steps, num_processes, 1) 42 | action_log_probs = action_log_probs.view(num_steps, num_processes, 1) 43 | 44 | # advantages = rollouts.returns[:-1] - values 45 | advantages = rollouts.returns[:-1] 46 | 47 | action_loss = -(advantages.detach() * action_log_probs).mean() 48 | 49 | self.optimizer.zero_grad() 50 | # (action_loss - dist_entropy * self.entropy_coef).backward() 51 | action_loss.backward() 52 | 53 | # nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm) 54 | 55 | self.optimizer.step() 56 | 57 | return 0, action_loss.item(), dist_entropy.item() 58 | -------------------------------------------------------------------------------- /reinforce/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 
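# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original file): how the store_false flag
# --no-recurrent-policy defined below behaves. The attribute defaults to True
# and flips to False only when the flag is passed on the command line, which
# is why the check at the bottom of this file reads
# `if not args.no_recurrent_policy`.
import argparse

p = argparse.ArgumentParser()
p.add_argument('--no-recurrent-policy', action='store_false', default=True)
assert p.parse_args([]).no_recurrent_policy is True
assert p.parse_args(['--no-recurrent-policy']).no_recurrent_policy is False
# ---------------------------------------------------------------------------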
3 | 4 | import argparse 5 | import torch 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser(description='RL') 10 | parser.add_argument('--data', default='/raid/taskcv-2017-public/classification/data', help='path to dataset') 11 | parser.add_argument('--epochs', default=300, type=int, help='number of total epochs to run') 12 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 13 | parser.add_argument('--batch-size', default=64, type=int, dest='batch_size', help='mini-batch size (default: 64)') 14 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, help='initial learning rate') 15 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, help='weight decay (default: 5e-4)') 16 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)') 17 | parser.add_argument('--early-stop', default=1, type=int, dest='early_stop', help='let the optimizer see only part of each optimizee epoch') 18 | parser.add_argument('--lwf', default=0., type=float, dest='lwf', help='weight of KL loss for LwF (default: 0)') 19 | parser.add_argument('--resume', default='none', type=str, help='path to latest checkpoint (default: none)') 20 | parser.add_argument('--timestamp', type=str, default='none', help='timestamp for logging naming') 21 | parser.add_argument('--save_dir', type=str, default="./runs", help='root folder to save checkpoints and log.') 22 | parser.add_argument('--agent_load_dir', type=str, default="/raid/ASG/pretrained/policy_res101_vista17.pth", help='path to pretrained L2O policy model.') 23 | parser.add_argument('--train_blocks', type=str, default="conv1.bn1.layer1.layer2.layer3.layer4.fc", help='blocks to train, separated by dot.') 24 | parser.add_argument('--num-class', default=12, type=int, dest='num_class', help='the number of classes') 25 | parser.add_argument('--rand_seed', default=0, type=int, help='random seed (default: 0)') 26 | parser.add_argument('--algo', default='a2c', help='algorithm to use: a2c | ppo | acktr') 27 | parser.add_argument('--gail', action='store_true', default=False, help='do imitation learning with gail') 28 | parser.add_argument('--gail-experts-dir', default='./gail_experts', help='directory that contains expert demonstrations for gail') 29 | parser.add_argument('--gail-batch-size', type=int, default=128, help='gail batch size (default: 128)') 30 | parser.add_argument('--gail-epoch', type=int, default=5, help='gail epochs (default: 5)') 31 | parser.add_argument('--lr-meta', type=float, default=7e-4, help='learning rate (default: 7e-4)') 32 | parser.add_argument('--eps', type=float, default=1e-5, help='RMSprop optimizer epsilon (default: 1e-5)') 33 | parser.add_argument('--alpha', type=float, default=0.99, help='RMSprop optimizer alpha (default: 0.99)') 34 | parser.add_argument('--gamma', type=float, default=0.99, help='discount factor for rewards (default: 0.99)') 35 | parser.add_argument('--use-gae', action='store_true', default=False, help='use generalized advantage estimation') 36 | parser.add_argument('--gae-lambda', type=float, default=0.95, help='gae lambda parameter (default: 0.95)') 37 | parser.add_argument('--entropy-coef', type=float, default=0.01, help='entropy term coefficient (default: 0.01)') 38 | parser.add_argument('--value-loss-coef', type=float, default=0.5, help='value loss coefficient (default: 0.5)') 39 | parser.add_argument('--max-grad-norm', type=float, default=0.5, help='max norm of gradients (default: 0.5)') 40 |
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') 41 | parser.add_argument('--cuda-deterministic', action='store_true', default=False, help="sets flags for determinism when using CUDA (potentially slow!)") 42 | parser.add_argument('--num-steps', type=int, default=5, help='number of forward steps in A2C (default: 5)') 43 | parser.add_argument('--ppo-epoch', type=int, default=4, help='number of ppo epochs (default: 4)') 44 | parser.add_argument('--num-mini-batch', type=int, default=32, help='number of batches for ppo (default: 32)') 45 | parser.add_argument('--clip-param', type=float, default=0.2, help='ppo clip parameter (default: 0.2)') 46 | parser.add_argument('--log-interval', type=int, default=10, help='log interval, one log per n updates (default: 10)') 47 | parser.add_argument('--save-interval', type=int, default=100, help='save interval, one save per n updates (default: 100)') 48 | parser.add_argument('--eval-interval', type=int, default=None, help='eval interval, one eval per n updates (default: None)') 49 | parser.add_argument('--num-env-steps', type=int, default=10e6, help='number of environment steps to train (default: 10e6)') 50 | parser.add_argument('--use-proper-time-limits', action='store_true', default=False, help='compute returns taking into account time limits') 51 | parser.add_argument('--no-recurrent-policy', action='store_false', default=True, help='do not use a recurrent policy') 52 | parser.add_argument('--use-linear-lr-decay', action='store_true', default=False, help='use a linear schedule on the learning rate') 53 | parser.add_argument('--gpus', default=0, type=int, help='use gpu with cuda number') 54 | args = parser.parse_args() 55 | 56 | args.cuda = torch.cuda.is_available() 57 | 58 | # assert args.algo in ['a2c', 'ppo', 'acktr'] 59 | if not args.no_recurrent_policy: 60 | assert args.algo in ['a2c', 'ppo'], \ 61 | 'Recurrent policy is not implemented for ACKTR' 62 | 63 | return args 64 | -------------------------------------------------------------------------------- /reinforce/distributions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 
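# ---------------------------------------------------------------------------
# Editorial usage sketch (not part of the original file) for the
# coordinate-wise heads defined below, assuming this repo is on PYTHONPATH.
# Categorical keeps one Linear per coordinate; the MultiCategorical it returns
# samples one action per coordinate and sums per-coordinate log-probabilities.
import torch
from reinforce.distributions import Categorical

head = Categorical(num_inputs=8, num_outputs=5, coord_size=3)
x = torch.randn(3, 2, 8)        # (coord, batch, features)
dist = head(x)                  # MultiCategorical over 3 coordinates
a = dist.sample()               # (batch=2, coord=3) integer actions
lp = dist.log_probs(a)          # (2, 1): log-probs summed over coordinates
assert a.shape == (2, 3) and lp.shape == (2, 1)
# ---------------------------------------------------------------------------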
3 | 4 | from pdb import set_trace as bp 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from reinforce.utils import init 10 | 11 | 12 | class FixedCategorical(torch.distributions.Categorical): 13 | def sample(self): 14 | return super().sample().unsqueeze(-1) 15 | 16 | def log_probs(self, actions): 17 | return ( 18 | super() 19 | .log_prob(actions.squeeze(-1)) 20 | .view(actions.size(0), -1) 21 | .sum(-1) 22 | .unsqueeze(-1) 23 | ) 24 | 25 | def mode(self): 26 | return self.probs.argmax(dim=-1, keepdim=True) 27 | 28 | 29 | class FixedNormal(torch.distributions.Normal): 30 | def log_probs(self, actions): 31 | return super().log_prob(actions).sum(-1, keepdim=True) 32 | 33 | def entropy(self): 34 | return super().entropy().sum(-1) 35 | 36 | def mode(self): 37 | return self.mean 38 | 39 | 40 | class Categorical(nn.Module): 41 | def __init__(self, num_inputs, num_outputs, coord_size=1): 42 | # num_inputs: #features for each coord 43 | # num_outputs: action_space 44 | super(Categorical, self).__init__() 45 | self.num_inputs = num_inputs 46 | self.num_outputs = num_outputs 47 | self.coord_size = coord_size 48 | 49 | init_ = lambda m: init( 50 | m, 51 | nn.init.orthogonal_, 52 | lambda x: nn.init.constant_(x, 0), 53 | gain=0.01) 54 | 55 | self.linear = nn.ModuleList([ 56 | init_(nn.Linear(num_inputs, num_outputs)) 57 | for _ in range(coord_size) 58 | ]) 59 | 60 | def forward(self, x): 61 | # x: (coord, batch, *features) 62 | # will coordinate-wisely return distributions 63 | distributions = [] 64 | for coord in range(self.coord_size): 65 | dist = FixedCategorical(logits=self.linear[coord](x[coord])) 66 | distributions.append(dist) 67 | return MultiCategorical(distributions) 68 | 69 | 70 | class MultiCategorical(nn.Module): 71 | def __init__(self, distributions): 72 | super(MultiCategorical, self).__init__() 73 | # coordinate-wise distributions 74 | self.distributions = distributions 75 | 76 | def sample(self): 77 | actions = [] 78 | for dist in self.distributions: 79 | actions.append(dist.sample()) 80 | return torch.cat(actions, dim=1) 81 | 82 | def log_probs(self, actions, is_sum=True): 83 | # actions: (batch, coord) 84 | log_probs = [] 85 | for coord in range(len(self.distributions)): 86 | try: 87 | log_probs.append(self.distributions[coord].log_probs(actions[:, coord:coord+1])) 88 | except Exception: 89 | bp() 90 | log_probs.append(self.distributions[coord].log_probs(actions[:, coord:coord+1])) 91 | log_probs = torch.cat(log_probs, dim=1) 92 | if is_sum: 93 | return log_probs.sum(-1).unsqueeze(-1) 94 | else: 95 | return log_probs 96 | 97 | def entropy(self): 98 | # actions: (batch, coord) 99 | entropies = [] 100 | for coord in range(len(self.distributions)): 101 | entropies.append(self.distributions[coord].entropy()) 102 | entropies = torch.cat(entropies, dim=0) 103 | return entropies.unsqueeze(-1) 104 | 105 | def mode(self): 106 | actions = [] 107 | for dist in self.distributions: 108 | actions.append(dist.probs.argmax(dim=-1, keepdim=True)) 109 | return torch.cat(actions, dim=1) 110 | 111 | 112 | class Gaussian(nn.Module): 113 | def __init__(self, num_inputs, num_outputs=1, mean_range=[0, 1], std_epsilon=0.001): 114 | super(Gaussian, self).__init__() 115 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0)) 116 | self._num_inputs = num_inputs 117 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 118 | self.fc_std = init_(nn.Linear(num_inputs, num_outputs)) 119 | assert len(mean_range) == 2 and mean_range[0] < mean_range[1] 120 |
self.mean_min = mean_range[0] 121 | self.mean_max = mean_range[1] 122 | self.std_epsilon = std_epsilon 123 | 124 | def forward(self, x): 125 | # x = x.view(1, self._num_inputs) 126 | action_mean = self.fc_mean(x) 127 | action_mean = F.sigmoid(action_mean) * (self.mean_max - self.mean_min) + self.mean_min 128 | action_std = F.softplus(self.fc_std(x)) + self.std_epsilon 129 | return FixedNormal(action_mean, action_std) 130 | -------------------------------------------------------------------------------- /reinforce/models/policy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | import abc 5 | from pdb import set_trace as bp 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | from reinforce.distributions import Categorical 11 | from reinforce.models.rnn_state_encoder import RNNStateEncoder 12 | 13 | 14 | class Policy(nn.Module): 15 | def __init__(self, coord_size, input_size=(1, 1), action_space=1, hidden_size=1, window_size=1): 16 | # input_size: (#lstm_input, #mlp_input) 17 | super().__init__() 18 | self.net = BasicNet(coord_size, input_size=input_size, hidden_size=hidden_size, window_size=window_size) 19 | # will coordinate-wisely return distributions 20 | self.action_distribution = Categorical(input_size[0]*hidden_size+input_size[1]+1, action_space, coord_size=coord_size) 21 | self.critic = CriticHead(coord_size * (input_size[0]*hidden_size+input_size[1]+1)) 22 | self.recurrent_hidden_state_size = hidden_size 23 | self.coord_size = coord_size 24 | self.input_size = input_size 25 | self.action_space = action_space 26 | self.hidden_size = hidden_size 27 | self.window_size = window_size 28 | 29 | def forward(self, *x): 30 | raise NotImplementedError 31 | 32 | def act(self, observations, rnn_hidden_states, deterministic=False): 33 | features, rnn_hidden_states = self.net(observations, rnn_hidden_states) 34 | distribution = self.action_distribution(features) 35 | # (coord, seq_len*batch, feature) ==> (seq_len*batch, coord, feature) 36 | value = self.critic(features.permute(1, 0, 2).view(features.size(1), -1)) 37 | 38 | if deterministic: 39 | action = distribution.mode() 40 | else: 41 | action = distribution.sample() 42 | 43 | action_log_probs = distribution.log_probs(action) 44 | return value, action, action_log_probs, rnn_hidden_states, distribution 45 | 46 | def get_value(self, observations, rnn_hidden_states): 47 | features, _ = self.net(observations, rnn_hidden_states) 48 | # features = features.view(-1, self.batch_size * self.recurrent_hidden_state_size) 49 | return self.critic(features.permute(1, 0, 2).view(features.size(1), -1)) 50 | 51 | def evaluate_actions(self, observations, rnn_hidden_states, action): 52 | features, rnn_hidden_states = self.net(observations, rnn_hidden_states) 53 | # features = features.view(-1, self.batch_size * self.recurrent_hidden_state_size) 54 | distribution = self.action_distribution(features) 55 | value = self.critic(features.permute(1, 0, 2).contiguous().view(features.size(1), -1)) 56 | 57 | action_log_probs = distribution.log_probs(action) 58 | distribution_entropy = distribution.entropy().mean() 59 | 60 | return value, action_log_probs, distribution_entropy, rnn_hidden_states, distribution 61 | 62 | 63 | 64 | class CriticHead(nn.Module): 65 | def __init__(self, input_size): 66 | super().__init__() 67 | self.fc = nn.Linear(input_size, 1) 68 | 
nn.init.orthogonal_(self.fc.weight) 69 | nn.init.constant_(self.fc.bias, 0) 70 | 71 | def forward(self, x): 72 | return self.fc(x) 73 | 74 | 75 | class Net(nn.Module, metaclass=abc.ABCMeta): 76 | @abc.abstractmethod 77 | def forward(self, observations, rnn_hidden_states, prev_actions): 78 | pass 79 | 80 | @property 81 | @abc.abstractmethod 82 | def output_size(self): 83 | pass 84 | 85 | @property 86 | @abc.abstractmethod 87 | def num_recurrent_layers(self): 88 | pass 89 | 90 | 91 | class BasicNet(Net): 92 | def __init__(self, coord_size, input_size=(1, 1), hidden_size=1, window_size=1): 93 | super().__init__() 94 | self._coord_size = coord_size 95 | # input_size: (#lstm_input, #mlp_input) 96 | self._input_size = input_size 97 | self._hidden_size = hidden_size 98 | self._window_size = window_size 99 | self.state_encoder = nn.ModuleList([ 100 | RNNStateEncoder(input_size=window_size, hidden_size=self._hidden_size) 101 | for _ in range(input_size[0]) 102 | ]) 103 | self.train() 104 | 105 | @property 106 | def output_size(self): 107 | return self._hidden_size 108 | 109 | @property 110 | def num_recurrent_layers(self): 111 | return self.state_encoder[0].num_recurrent_layers 112 | 113 | def forward(self, observations, rnn_hidden_states): 114 | # observation: (seq_len, batch_size, #lstm_input * window + #scalar_input + #actions * 1(LR)) 115 | # rnn_hidden_states: (#lstm_input * hidden_size) 116 | outputs = [] 117 | rnn_hidden_states_new = [] 118 | # coordinate-wise 119 | for i in range(self._input_size[0]): 120 | # output: (seq_len, batch(1), hidden_size) 121 | output, rnn_hidden_state = self.state_encoder[i](observations[:, :, i*self._window_size:(i+1)*self._window_size], rnn_hidden_states[:, :, i*self._hidden_size:(i+1)*self._hidden_size]) 122 | outputs.append(output) 123 | rnn_hidden_states_new.append(rnn_hidden_state) 124 | # outputs: (seq_len, batch(1), hidden_size * #lstm_input + #scalar_input) 125 | outputs = torch.cat(outputs + [observations[:, :, self._input_size[0]*self._window_size:self._input_size[0]*self._window_size+self._input_size[1]]], dim=2) 126 | # add LR feature for each coord 127 | outputs_LR = [] 128 | for coord in range(-self._coord_size, 0): 129 | outputs_LR.append(torch.cat([outputs, observations[:, :, observations.size(2)+coord:observations.size(2)+coord+1]], dim=2)) 130 | outputs_LR = torch.stack(outputs_LR, dim=0) # (coord, seq_len, 1, hidden_size * #lstm_input + #scalar_input + 1) 131 | outputs_LR = outputs_LR.view(self._coord_size, -1, outputs_LR.size(-1)) # (coord, seq_len * 1, hidden_size * #lstm_input + #scalar_input + 1) 132 | return outputs_LR, rnn_hidden_states 133 | -------------------------------------------------------------------------------- /reinforce/models/rnn_state_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | import torch 5 | import torch.nn as nn 6 | from pdb import set_trace as bp 7 | 8 | 9 | class RNNStateEncoder(nn.Module): 10 | def __init__(self, input_size: int = 1, hidden_size: int = 1, num_layers: int = 1, rnn_type: str = "LSTM"): 11 | r"""An RNN for encoding the state in RL. 12 | 13 | Supports masking the hidden state during various timesteps in the forward pass 14 | 15 | Args: 16 | input_size: The input size of the RNN 17 | hidden_size: The hidden size 18 | num_layers: The number of recurrent layers 19 | rnn_type: The RNN cell type.
Must be GRU or LSTM 20 | """ 21 | 22 | super().__init__() 23 | self._num_recurrent_layers = num_layers 24 | self._rnn_type = rnn_type 25 | 26 | self.rnn = getattr(nn, rnn_type)( 27 | input_size=input_size, 28 | hidden_size=hidden_size, 29 | num_layers=num_layers, 30 | ) 31 | 32 | self.layer_init() 33 | 34 | def layer_init(self): 35 | for name, param in self.rnn.named_parameters(): 36 | if "weight" in name: 37 | nn.init.orthogonal_(param) 38 | elif "bias" in name: 39 | nn.init.constant_(param, 0) 40 | 41 | @property 42 | def num_recurrent_layers(self): 43 | return self._num_recurrent_layers * ( 44 | 2 if "LSTM" in self._rnn_type else 1 45 | ) 46 | 47 | def _pack_hidden(self, hidden_states): 48 | if "LSTM" in self._rnn_type: 49 | hidden_states = torch.cat( 50 | [hidden_states[0], hidden_states[1]], dim=0 51 | ) 52 | return hidden_states 53 | 54 | def _unpack_hidden(self, hidden_states): 55 | if "LSTM" in self._rnn_type: 56 | hidden_states = ( 57 | hidden_states[0 : self._num_recurrent_layers], 58 | hidden_states[self._num_recurrent_layers :], 59 | ) 60 | return hidden_states 61 | 62 | def single_forward(self, x, hidden_states): 63 | r"""Forward for a non-sequence input 64 | """ 65 | if len(x.size()) == 2: 66 | x = x.unsqueeze(0) 67 | # input: (seq_len, batch, input_size) 68 | x, hidden_states = self.rnn(x, hidden_states) 69 | return x, hidden_states 70 | 71 | def forward(self, x, hidden_states): 72 | hidden_states = self._unpack_hidden(hidden_states) 73 | x, hidden_states = self.single_forward(x, hidden_states) 74 | hidden_states = self._pack_hidden(hidden_states) 75 | return x, hidden_states 76 | -------------------------------------------------------------------------------- /reinforce/storage.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
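# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original file) of the non-GAE branch of
# RolloutStorage.compute_returns() below: R_t = r_t + gamma * R_{t+1}, seeded
# with a bootstrap value for the final state. The numbers are made up.
import torch

rewards = torch.tensor([1.0, 0.0, 2.0])
gamma, next_value = 0.99, 0.5
returns = torch.zeros(len(rewards) + 1)
returns[-1] = next_value
for step in reversed(range(len(rewards))):
    returns[step] = returns[step + 1] * gamma + rewards[step]
print(returns[:-1])   # approx tensor([3.4453, 2.4700, 2.4950]), up to rounding
# ---------------------------------------------------------------------------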
3 | 4 | import torch 5 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 6 | from pdb import set_trace as bp 7 | 8 | 9 | def _flatten_helper(T, N, _tensor): 10 | return _tensor.view(T * N, *_tensor.size()[2:]) 11 | 12 | 13 | class RolloutStorage(object): 14 | def __init__(self, num_steps, obs_shape, action_shape=1, hidden_size=1, num_recurrent_layers=1): 15 | # observation: (seq_len, batch_size, #lstm_input * window + #scalar_input + #actions * 1(LR)) 16 | self.obs = torch.zeros(num_steps + 1, 1, *obs_shape) 17 | self.recurrent_hidden_states = torch.zeros(num_steps + 1, num_recurrent_layers, 1, hidden_size) 18 | self.rewards = torch.zeros(num_steps, 1, 1) 19 | self.value_preds = torch.zeros(num_steps + 1, 1) 20 | self.returns = torch.zeros(num_steps + 1, 1) 21 | self.action_log_probs = torch.zeros(num_steps, 1) 22 | self.actions = torch.zeros(num_steps, action_shape) 23 | self.num_steps = num_steps 24 | self.step = 0 25 | 26 | def to(self, device): 27 | self.obs = self.obs.to(device) 28 | self.recurrent_hidden_states = self.recurrent_hidden_states.to(device) 29 | self.rewards = self.rewards.to(device) 30 | self.value_preds = self.value_preds.to(device) 31 | self.returns = self.returns.to(device) 32 | self.action_log_probs = self.action_log_probs.to(device) 33 | self.actions = self.actions.to(device) 34 | 35 | def insert(self, obs, recurrent_hidden_states, actions, action_log_probs, value_preds, rewards): 36 | self.obs[self.step + 1].copy_(obs) 37 | self.recurrent_hidden_states[self.step + 1].copy_(recurrent_hidden_states) 38 | self.actions[self.step].copy_(actions) 39 | self.action_log_probs[self.step].copy_(action_log_probs) 40 | self.value_preds[self.step].copy_(value_preds) 41 | self.rewards[self.step].copy_(rewards) 42 | self.step = (self.step + 1) % self.num_steps 43 | 44 | def after_update(self): 45 | self.obs[0].copy_(self.obs[-1]) 46 | self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1]) 47 | 48 | def compute_returns(self, next_value, use_gae, gamma, gae_lambda): 49 | if use_gae: 50 | self.value_preds[-1] = next_value 51 | gae = 0 52 | for step in reversed(range(self.rewards.size(0))): 53 | delta = self.rewards[step] + gamma * self.value_preds[step + 1] - self.value_preds[step] 54 | gae = delta + gamma * gae_lambda * gae 55 | self.returns[step] = gae + self.value_preds[step] 56 | else: 57 | self.returns[-1] = next_value 58 | for step in reversed(range(self.rewards.size(0))): 59 | self.returns[step] = self.returns[step + 1] * gamma + self.rewards[step] 60 | 61 | def feed_forward_generator(self, advantages, num_mini_batch=None, mini_batch_size=None): 62 | num_steps, num_processes = self.rewards.size()[0:2] 63 | batch_size = num_processes * num_steps 64 | 65 | if mini_batch_size is None: 66 | assert batch_size >= num_mini_batch, ( 67 | "PPO requires the number of processes ({}) " 68 | "* number of steps ({}) = {} " 69 | "to be greater than or equal to the number of PPO mini batches ({})." 
70 | "".format(num_processes, num_steps, num_processes * num_steps, 71 | num_mini_batch)) 72 | mini_batch_size = batch_size // num_mini_batch 73 | sampler = BatchSampler( 74 | SubsetRandomSampler(range(batch_size)), 75 | mini_batch_size, 76 | drop_last=True) 77 | for indices in sampler: 78 | obs_batch = self.obs[:-1].view(-1, *self.obs.size()[1:])[indices] 79 | recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view(-1, *self.recurrent_hidden_states.size()[1:])[indices] 80 | actions_batch = self.actions.view(-1, self.actions.size(-1))[indices] 81 | value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices] 82 | return_batch = self.returns[:-1].view(-1, 1)[indices] 83 | old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices] 84 | if advantages is None: 85 | adv_targ = None 86 | else: 87 | adv_targ = advantages.view(-1, 1)[indices] 88 | 89 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, value_preds_batch, return_batch, old_action_log_probs_batch, adv_targ 90 | 91 | def recurrent_generator(self, advantages, num_mini_batch): 92 | num_processes = self.rewards.size(1) 93 | assert num_processes >= num_mini_batch, ( 94 | "PPO requires the number of processes ({}) " 95 | "to be greater than or equal to the number of " 96 | "PPO mini batches ({}).".format(num_processes, num_mini_batch)) 97 | num_envs_per_batch = num_processes // num_mini_batch 98 | perm = torch.randperm(num_processes) 99 | for start_ind in range(0, num_processes, num_envs_per_batch): 100 | obs_batch = [] 101 | recurrent_hidden_states_batch = [] 102 | actions_batch = [] 103 | value_preds_batch = [] 104 | return_batch = [] 105 | old_action_log_probs_batch = [] 106 | adv_targ = [] 107 | 108 | for offset in range(num_envs_per_batch): 109 | ind = perm[start_ind + offset] 110 | obs_batch.append(self.obs[:-1, ind]) 111 | recurrent_hidden_states_batch.append(self.recurrent_hidden_states[0:1, ind]) 112 | actions_batch.append(self.actions[:, ind]) 113 | value_preds_batch.append(self.value_preds[:-1, ind]) 114 | return_batch.append(self.returns[:-1, ind]) 115 | old_action_log_probs_batch.append( 116 | self.action_log_probs[:, ind]) 117 | adv_targ.append(advantages[:, ind]) 118 | 119 | T, N = self.num_steps, num_envs_per_batch 120 | # These are all tensors of size (T, N, -1) 121 | obs_batch = torch.stack(obs_batch, 1) 122 | actions_batch = torch.stack(actions_batch, 1) 123 | value_preds_batch = torch.stack(value_preds_batch, 1) 124 | return_batch = torch.stack(return_batch, 1) 125 | old_action_log_probs_batch = torch.stack( 126 | old_action_log_probs_batch, 1) 127 | adv_targ = torch.stack(adv_targ, 1) 128 | 129 | # States is just a (N, -1) tensor 130 | recurrent_hidden_states_batch = torch.stack(recurrent_hidden_states_batch, 1).view(N, -1) 131 | 132 | # Flatten the (T, N, ...) tensors to (T * N, ...) 
133 | obs_batch = _flatten_helper(T, N, obs_batch) 134 | actions_batch = _flatten_helper(T, N, actions_batch) 135 | value_preds_batch = _flatten_helper(T, N, value_preds_batch) 136 | return_batch = _flatten_helper(T, N, return_batch) 137 | old_action_log_probs_batch = _flatten_helper(T, N, \ 138 | old_action_log_probs_batch) 139 | adv_targ = _flatten_helper(T, N, adv_targ) 140 | 141 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, \ 142 | value_preds_batch, return_batch, old_action_log_probs_batch, adv_targ 143 | -------------------------------------------------------------------------------- /reinforce/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 3 | 4 | import glob 5 | import os 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | # Get a render function 12 | def get_render_func(venv): 13 | if hasattr(venv, 'envs'): 14 | return venv.envs[0].render 15 | elif hasattr(venv, 'venv'): 16 | return get_render_func(venv.venv) 17 | elif hasattr(venv, 'env'): 18 | return get_render_func(venv.env) 19 | 20 | return None 21 | 22 | 23 | def get_vec_normalize(venv): 24 | if isinstance(venv, VecNormalize): 25 | return venv 26 | elif hasattr(venv, 'venv'): 27 | return get_vec_normalize(venv.venv) 28 | 29 | return None 30 | 31 | 32 | # Necessary for my KFAC implementation. 33 | class AddBias(nn.Module): 34 | def __init__(self, bias): 35 | super(AddBias, self).__init__() 36 | self._bias = nn.Parameter(bias.unsqueeze(1)) 37 | 38 | def forward(self, x): 39 | if x.dim() == 2: 40 | bias = self._bias.t().view(1, -1) 41 | else: 42 | bias = self._bias.t().view(1, -1, 1, 1) 43 | 44 | return x + bias 45 | 46 | 47 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): 48 | """Decreases the learning rate linearly""" 49 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) 50 | for param_group in optimizer.param_groups: 51 | param_group['lr'] = lr 52 | 53 | 54 | def init(module, weight_init, bias_init, gain=1): 55 | weight_init(module.weight.data, gain=gain) 56 | bias_init(module.bias.data) 57 | return module 58 | 59 | 60 | def cleanup_log_dir(log_dir): 61 | try: 62 | os.makedirs(log_dir) 63 | except OSError: 64 | files = glob.glob(os.path.join(log_dir, '*.monitor.csv')) 65 | for f in files: 66 | os.remove(f) 67 | -------------------------------------------------------------------------------- /tools/datasets/BaseDataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 
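# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original file) of the `portion` logic in
# BaseDataset._get_file_names() below: after shuffling, a positive portion p
# keeps the first floor(p*N) files, and a negative portion keeps the
# complementary tail, so p and -(1 - p) split one list into two disjoint sets.
import numpy as np

files = [f"img_{i}.png" for i in range(10)]
head = files[:int(np.floor(0.7 * len(files)))]            # portion = 0.7
tail = files[int(np.floor((1 + (-0.3)) * len(files))):]   # portion = -0.3
assert head + tail == files and len(head) == 7
# ---------------------------------------------------------------------------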
3 | 4 | import os 5 | import cv2 6 | cv2.setNumThreads(0) 7 | import torch 8 | import numpy as np 9 | from random import shuffle 10 | 11 | import torch.utils.data as data 12 | 13 | 14 | class BaseDataset(data.Dataset): 15 | def __init__(self, setting, split_name, preprocess=None, file_length=None): 16 | super(BaseDataset, self).__init__() 17 | self._split_name = split_name 18 | self._img_path = setting['img_root'] 19 | self._gt_path = setting['gt_root'] 20 | self._portion = setting['portion'] if 'portion' in setting else None 21 | self._train_source = setting['train_source'] 22 | self._eval_source = setting['eval_source'] 23 | self._test_source = setting['test_source'] if 'test_source' in setting else setting['eval_source'] 24 | self._down_sampling = setting['down_sampling'] 25 | print("using downsampling:", self._down_sampling) 26 | self._file_names = self._get_file_names(split_name) 27 | print("Found %d images"%len(self._file_names)) 28 | self._file_length = file_length 29 | self.preprocess = preprocess 30 | 31 | def __len__(self): 32 | if self._file_length is not None: 33 | return self._file_length 34 | return len(self._file_names) 35 | 36 | def __getitem__(self, index): 37 | if self._file_length is not None: 38 | names = self._construct_new_file_names(self._file_length)[index] 39 | else: 40 | names = self._file_names[index] 41 | img_path = os.path.join(self._img_path, names[0]) 42 | gt_path = os.path.join(self._gt_path, names[1]) 43 | item_name = names[1].split("/")[-1].split(".")[0] 44 | 45 | img, gt = self._fetch_data(img_path, gt_path) 46 | img = img[:, :, ::-1] 47 | if self.preprocess is not None: 48 | img, gt, extra_dict = self.preprocess(img, gt) 49 | 50 | if self._split_name == 'train': 51 | img = torch.from_numpy(np.ascontiguousarray(img)).float() 52 | gt = torch.from_numpy(np.ascontiguousarray(gt)).long() 53 | if self.preprocess is not None and extra_dict is not None: 54 | for k, v in extra_dict.items(): 55 | extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v)) 56 | if 'label' in k: 57 | extra_dict[k] = extra_dict[k].long() 58 | if 'img' in k: 59 | extra_dict[k] = extra_dict[k].float() 60 | 61 | output_dict = dict(data=img, label=gt, fn=str(item_name), 62 | n=len(self._file_names)) 63 | if self.preprocess is not None and extra_dict is not None: 64 | output_dict.update(**extra_dict) 65 | 66 | return output_dict 67 | 68 | def _fetch_data(self, img_path, gt_path, dtype=None): 69 | img = self._open_image(img_path, down_sampling=self._down_sampling) 70 | gt = self._open_image(gt_path, cv2.IMREAD_GRAYSCALE, dtype=dtype, down_sampling=self._down_sampling) 71 | 72 | return img, gt 73 | 74 | def _get_file_names(self, split_name): 75 | assert split_name in ['train', 'val', 'test'] 76 | source = self._train_source 77 | if split_name == "val": 78 | source = self._eval_source 79 | elif split_name == 'test': 80 | source = self._test_source 81 | 82 | file_names = [] 83 | with open(source) as f: 84 | files = f.readlines() 85 | if self._portion is not None: 86 | shuffle(files) 87 | num_files = len(files) 88 | if self._portion > 0: 89 | split = int(np.floor(self._portion * num_files)) 90 | files = files[:split] 91 | elif self._portion < 0: 92 | split = int(np.floor((1 + self._portion) * num_files)) 93 | files = files[split:] 94 | 95 | for item in files: 96 | img_name, gt_name = self._process_item_names(item) 97 | file_names.append([img_name, gt_name]) 98 | 99 | return file_names 100 | 101 | def _construct_new_file_names(self, length): 102 | assert isinstance(length, int) 103 | files_len =
len(self._file_names) 104 | new_file_names = self._file_names * (length // files_len) 105 | 106 | rand_indices = torch.randperm(files_len).tolist() 107 | new_indices = rand_indices[:length % files_len] 108 | 109 | new_file_names += [self._file_names[i] for i in new_indices] 110 | 111 | return new_file_names 112 | 113 | @staticmethod 114 | def _process_item_names(item): 115 | item = item.strip() 116 | # item = item.split('\t') 117 | item = item.split(' ') 118 | img_name = item[0] 119 | gt_name = item[1] 120 | 121 | return img_name, gt_name 122 | 123 | def get_length(self): 124 | return self.__len__() 125 | 126 | @staticmethod 127 | def _open_image(filepath, mode=cv2.IMREAD_COLOR, dtype=None, down_sampling=1): 128 | # cv2: B G R 129 | # h w c 130 | img = np.array(cv2.imread(filepath, mode), dtype=dtype) 131 | 132 | if isinstance(down_sampling, int): 133 | H, W = img.shape[:2] 134 | if len(img.shape) == 3: 135 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_LINEAR) 136 | else: 137 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_NEAREST) 138 | assert img.shape[0] == H // down_sampling and img.shape[1] == W // down_sampling 139 | else: 140 | assert (isinstance(down_sampling, tuple) or isinstance(down_sampling, list)) and len(down_sampling) == 2 141 | if len(img.shape) == 3: 142 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_LINEAR) 143 | else: 144 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_NEAREST) 145 | assert img.shape[0] == down_sampling[0] and img.shape[1] == down_sampling[1] 146 | 147 | return img 148 | 149 | @classmethod 150 | def get_class_colors(*args): 151 | raise NotImplementedError 152 | 153 | @classmethod 154 | def get_class_names(*args): 155 | raise NotImplementedError 156 | 157 | 158 | if __name__ == "__main__": 159 | data_setting = {'img_root': '', 160 | 'gt_root': '', 161 | 'train_source': '', 162 | 'eval_source': ''} 163 | bd = BaseDataset(data_setting, 'train', None) 164 | print(bd.get_class_names()) 165 | -------------------------------------------------------------------------------- /tools/datasets/cityscapes/cityscapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 
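A quick orientation for the file that follows: `trans_labels` maps the 19 contiguous training IDs (0-18) back to the official Cityscapes label IDs expected by the benchmark tooling. The per-id loop in `transform_label` can be reproduced with a single vectorized NumPy lookup; a minimal sketch (the toy `pred` array is invented for illustration and is not part of the repo):

import numpy as np

trans_labels = np.array([7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25,
                         26, 27, 28, 31, 32, 33])
pred = np.array([[0, 1], [13, 18]])   # toy prediction in train-ID space
label = trans_labels[pred]            # -> [[7, 8], [26, 33]], same result as the loop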
3 | 4 | import numpy as np 5 | 6 | from datasets.BaseDataset import BaseDataset 7 | 8 | 9 | class Cityscapes(BaseDataset): 10 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 11 | 28, 31, 32, 33] 12 | 13 | @classmethod 14 | def get_class_colors(*args): 15 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70], 16 | [102, 102, 156], [190, 153, 153], [153, 153, 153], 17 | [250, 170, 30], [220, 220, 0], [107, 142, 35], 18 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], 19 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], 20 | [0, 0, 230], [119, 11, 32]] 21 | 22 | @classmethod 23 | def get_class_names(*args): 24 | # class counting(gtFine) 25 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832 26 | # 359 274 142 513 1646 27 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 28 | 'traffic light', 'traffic sign', 29 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 30 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle'] 31 | 32 | @classmethod 33 | def transform_label(cls, pred, name): 34 | label = np.zeros(pred.shape) 35 | ids = np.unique(pred) 36 | for id in ids: 37 | label[np.where(pred == id)] = cls.trans_labels[id] 38 | 39 | new_name = (name.split('.')[0]).split('_')[:-1] 40 | new_name = '_'.join(new_name) + '.png' 41 | 42 | print('Trans', name, 'to', new_name, ' ', 43 | np.unique(np.array(pred, np.uint8)), ' ---------> ', 44 | np.unique(np.array(label, np.uint8))) 45 | return label, new_name 46 | -------------------------------------------------------------------------------- /tools/engine/evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 
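Before the Evaluator implementation below: its multi-process paths split the dataset index range into contiguous shards of size ceil(ndata / workers), one shard per GPU (or per thread on a single GPU), and collect per-image result dicts through a multiprocessing queue. A small self-contained sketch of the sharding arithmetic, with invented numbers (10 images, 3 workers):

import numpy as np

ndata, n_workers = 10, 3
stride = int(np.ceil(ndata / n_workers))
shards = [list(range(d * stride, min((d + 1) * stride, ndata)))
          for d in range(n_workers)]
# -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]; the last shard absorbs the remainder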
3 | 4 | import os 5 | from PIL import Image 6 | import cv2 7 | import numpy as np 8 | import time 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import torch.multiprocessing as mp 13 | 14 | from tools.engine.logger import get_logger 15 | from tools.utils.pyt_utils import load_model, link_file, ensure_dir 16 | from tools.utils.img_utils import pad_image_to_shape, normalize 17 | 18 | logger = get_logger() 19 | 20 | 21 | class Evaluator(object): 22 | def __init__(self, dataset, class_num, image_mean, image_std, network, 23 | multi_scales, is_flip, devices=0, threds=3, config=None, logger=None, 24 | verbose=False, save_path=None, show_image=False, show_prediction=False): 25 | self.dataset = dataset 26 | self.ndata = self.dataset.get_length() 27 | self.class_num = class_num 28 | self.image_mean = image_mean 29 | self.image_std = image_std 30 | self.multi_scales = multi_scales 31 | self.is_flip = is_flip 32 | self.network = network 33 | self.devices = devices 34 | if type(self.devices) == int: self.devices = [self.devices] 35 | self.threds = threds 36 | self.config = config 37 | self.logger = logger 38 | 39 | self.context = mp.get_context('spawn') 40 | self.val_func = None 41 | self.results_queue = self.context.Queue(self.ndata) 42 | self.features_queue = self.context.Queue(self.ndata) 43 | 44 | self.verbose = verbose 45 | self.save_path = save_path 46 | if save_path is not None: 47 | ensure_dir(save_path) 48 | self.show_image = show_image 49 | self.show_prediction = show_prediction 50 | 51 | def run(self, model_path, model_indice, log_file, log_file_link): 52 | """There are four evaluation modes: 53 | 1.only eval a .pth model: -e *.pth 54 | 2.only eval a certain epoch: -e epoch 55 | 3.eval all epochs in a given section: -e start_epoch-end_epoch 56 | 4.eval all epochs from a certain started epoch: -e start_epoch- 57 | """ 58 | if '.pth' in model_indice: 59 | models = [model_indice, ] 60 | elif "-" in model_indice: 61 | start_epoch = int(model_indice.split("-")[0]) 62 | end_epoch = model_indice.split("-")[1] 63 | 64 | models = os.listdir(model_path) 65 | models.remove("epoch-last.pth") 66 | sorted_models = [None] * len(models) 67 | model_idx = [0] * len(models) 68 | 69 | for idx, m in enumerate(models): 70 | num = m.split(".")[0].split("-")[1] 71 | model_idx[idx] = num 72 | sorted_models[idx] = m 73 | model_idx = np.array([int(i) for i in model_idx]) 74 | 75 | down_bound = model_idx >= start_epoch 76 | up_bound = [True] * len(sorted_models) 77 | if end_epoch: 78 | end_epoch = int(end_epoch) 79 | assert start_epoch < end_epoch 80 | up_bound = model_idx <= end_epoch 81 | bound = up_bound * down_bound 82 | model_slice = np.array(sorted_models)[bound] 83 | models = [os.path.join(model_path, model) for model in 84 | model_slice] 85 | else: 86 | models = [os.path.join(model_path, 87 | 'epoch-%s.pth' % model_indice), ] 88 | 89 | results = open(log_file, 'a') 90 | link_file(log_file, log_file_link) 91 | 92 | for model in models: 93 | logger.info("Load Model: %s" % model) 94 | self.val_func = load_model(self.network, model) 95 | result_line, mIoU = self.multi_process_evaluation() 96 | 97 | results.write('Model: ' + model + '\n') 98 | results.write(result_line) 99 | results.write('\n') 100 | results.flush() 101 | 102 | results.close() 103 | 104 | def run_online(self): 105 | """ 106 | eval during training 107 | """ 108 | self.val_func = self.network 109 | result_line, mIoU = self.single_process_evaluation() 110 | return result_line, mIoU 111 | 112 | def single_process_evaluation(self): 113 | 
        all_results = []
114 | 
115 |         with torch.no_grad():
116 |             for idx in tqdm(range(self.ndata)):
117 |                 dd = self.dataset[idx]
118 |                 results_dict = self.func_per_iteration(dd, self.devices[0], iter=idx)
119 |                 all_results.append(results_dict)
120 |                 _, _mIoU = self.compute_metric([results_dict])  # per-image metric; the return value is unused
121 |         result_line, mIoU = self.compute_metric(all_results)
122 |         return result_line, mIoU
123 | 
124 |     def run_online_multiprocess(self):
125 |         """
126 |         eval during training
127 |         """
128 |         self.val_func = self.network
129 |         result_line, mIoU = self.multi_process_single_gpu_evaluation()
130 |         return result_line, mIoU
131 | 
132 |     def multi_process_single_gpu_evaluation(self):
133 |         # start_eval_time = time.perf_counter()
134 |         stride = int(np.ceil(self.ndata / self.threds))
135 | 
136 |         # start multi-process on single-gpu
137 |         procs = []
138 |         for d in range(self.threds):
139 |             e_record = min((d + 1) * stride, self.ndata)
140 |             shred_list = list(range(d * stride, e_record))
141 |             device = self.devices[0]
142 |             logger.info('Thread %d handle %d data.' % (d, len(shred_list)))
143 |             p = self.context.Process(target=self.worker, args=(shred_list, device))
144 |             procs.append(p)
145 | 
146 |         for p in procs:
147 |             p.start()
148 | 
149 |         all_results = []
150 |         for _ in tqdm(range(self.ndata)):
151 |             t = self.results_queue.get()
152 |             all_results.append(t)
153 |             if self.verbose:
154 |                 self.compute_metric(all_results)
155 | 
156 |         for p in procs:
157 |             p.join()
158 | 
159 |         result_line, mIoU = self.compute_metric(all_results)
160 |         return result_line, mIoU
161 | 
162 |     def multi_process_evaluation(self):
163 |         start_eval_time = time.perf_counter()
164 |         nr_devices = len(self.devices)
165 |         stride = int(np.ceil(self.ndata / nr_devices))
166 | 
167 |         # start multi-process on multi-gpu
168 |         procs = []
169 |         for d in range(nr_devices):
170 |             e_record = min((d + 1) * stride, self.ndata)
171 |             shred_list = list(range(d * stride, e_record))
172 |             device = self.devices[d]
173 |             logger.info('GPU %s handle %d data.' 
% (device, len(shred_list))) 174 | p = self.context.Process(target=self.worker, args=(shred_list, device)) 175 | procs.append(p) 176 | 177 | for p in procs: 178 | p.start() 179 | 180 | all_results = [] 181 | for _ in tqdm(range(self.ndata)): 182 | t = self.results_queue.get() 183 | all_results.append(t) 184 | if self.verbose: 185 | self.compute_metric(all_results) 186 | 187 | for p in procs: 188 | p.join() 189 | 190 | result_line, mIoU = self.compute_metric(all_results) 191 | logger.info('Evaluation Elapsed Time: %.2fs' % (time.perf_counter() - start_eval_time)) 192 | return result_line, mIoU 193 | 194 | def worker(self, shred_list, device): 195 | # start_load_time = time.time() 196 | # logger.info('Load Model on Device %d: %.2fs' % (device, time.time() - start_load_time)) 197 | for idx in shred_list: 198 | dd = self.dataset[idx] 199 | results_dict = self.func_per_iteration(dd, device, iter=idx) 200 | self.results_queue.put(results_dict) 201 | 202 | 203 | def func_per_iteration(self, data, device, iter=None): 204 | raise NotImplementedError 205 | 206 | def compute_metric(self, results): 207 | raise NotImplementedError 208 | 209 | # evaluate the whole image at once 210 | def whole_eval(self, img, output_size, resize=None, input_size=None, device=None): 211 | if input_size is not None: 212 | img, margin = self.process_image(img, resize=resize, crop_size=input_size) 213 | else: 214 | img = self.process_image(img, resize=resize, crop_size=input_size) 215 | 216 | pred = self.val_func_process(img, device) 217 | if input_size is not None: 218 | pred = pred[:, margin[0]:(pred.shape[1] - margin[1]), margin[2]:(pred.shape[2] - margin[3])] 219 | pred = pred.permute(1, 2, 0) 220 | pred = pred.cpu().numpy() 221 | if output_size is not None: 222 | pred = cv2.resize(pred, 223 | (output_size[1], output_size[0]), 224 | interpolation=cv2.INTER_LINEAR) 225 | 226 | return pred 227 | 228 | # slide the window to evaluate the image 229 | def sliding_eval(self, img, crop_size, stride_rate, device=None): 230 | ori_rows, ori_cols, c = img.shape 231 | processed_pred = np.zeros((ori_rows, ori_cols, self.class_num)) 232 | 233 | for s in self.multi_scales: 234 | img_scale = cv2.resize(img, None, fx=s, fy=s, 235 | interpolation=cv2.INTER_LINEAR) 236 | new_rows, new_cols, _ = img_scale.shape 237 | processed_pred += self.scale_process(img_scale, 238 | (ori_rows, ori_cols), 239 | crop_size, stride_rate, device) 240 | 241 | pred = processed_pred.argmax(2) 242 | 243 | return pred 244 | 245 | def scale_process(self, img, ori_shape, crop_size, stride_rate, 246 | device=None): 247 | new_rows, new_cols, c = img.shape 248 | long_size = new_cols if new_cols > new_rows else new_rows 249 | 250 | if long_size <= crop_size: 251 | input_data, margin = self.process_image(img, crop_size=crop_size) 252 | score = self.val_func_process(input_data, device) 253 | score = score[:, margin[0]:(score.shape[1] - margin[1]), 254 | margin[2]:(score.shape[2] - margin[3])] 255 | else: 256 | stride = int(np.ceil(crop_size * stride_rate)) 257 | img_pad, margin = pad_image_to_shape(img, crop_size, 258 | cv2.BORDER_CONSTANT, value=0) 259 | 260 | pad_rows = img_pad.shape[0] 261 | pad_cols = img_pad.shape[1] 262 | r_grid = int(np.ceil((pad_rows - crop_size) / stride)) + 1 263 | c_grid = int(np.ceil((pad_cols - crop_size) / stride)) + 1 264 | data_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda( 265 | device) 266 | count_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda( 267 | device) 268 | 269 | for grid_yidx in range(r_grid): 270 | 
                for grid_xidx in range(c_grid):
271 |                     s_x = grid_xidx * stride
272 |                     s_y = grid_yidx * stride
273 |                     e_x = min(s_x + crop_size, pad_cols)
274 |                     e_y = min(s_y + crop_size, pad_rows)
275 |                     s_x = e_x - crop_size
276 |                     s_y = e_y - crop_size
277 |                     img_sub = img_pad[s_y:e_y, s_x: e_x, :]
278 |                     count_scale[:, s_y: e_y, s_x: e_x] += 1
279 | 
280 |                     input_data, tmargin = self.process_image(img_sub, crop_size=crop_size)
281 |                     temp_score = self.val_func_process(input_data, device)
282 |                     temp_score = temp_score[:,
283 |                                  tmargin[0]:(temp_score.shape[1] - tmargin[1]),
284 |                                  tmargin[2]:(temp_score.shape[2] - tmargin[3])]
285 |                     data_scale[:, s_y: e_y, s_x: e_x] += temp_score
286 |             # score = data_scale / count_scale
287 |             score = data_scale
288 |             score = score[:, margin[0]:(score.shape[1] - margin[1]),
289 |                     margin[2]:(score.shape[2] - margin[3])]
290 | 
291 |         score = score.permute(1, 2, 0)
292 |         data_output = cv2.resize(score.cpu().numpy(),
293 |                                  (ori_shape[1], ori_shape[0]),
294 |                                  interpolation=cv2.INTER_LINEAR)
295 | 
296 |         return data_output
297 | 
298 |     def val_func_process(self, input_data, device=None):
299 |         input_data = np.ascontiguousarray(input_data[None, :, :, :], dtype=np.float32)
300 |         input_data = torch.FloatTensor(input_data).cuda(device)
301 | 
302 |         with torch.cuda.device(input_data.get_device()):
303 |             self.val_func.eval()
304 |             self.val_func.to(input_data.get_device())
305 |             with torch.no_grad():
306 |                 score = self.val_func(input_data, output_features=[], task='new_seg')[0]
307 |                 if (isinstance(score, tuple) or isinstance(score, list)) and len(score) > 1:
308 |                     score = score[getattr(self, 'out_idx', 0)]  # Evaluator itself sets no out_idx; default to the first output head
309 |                 score = score[0]  # a single image pass, ignore batch dim
310 | 
311 |                 if self.is_flip:
312 |                     input_data = input_data.flip(-1)
313 |                     score_flip = self.val_func(input_data)[0]
314 |                     score_flip = score_flip[0]  # a single image pass, ignore batch dim
315 |                     score += score_flip.flip(-1)
316 |                 score = torch.exp(score)
317 |                 # score = score.data
318 | 
319 |         return score
320 | 
321 |     def process_image(self, img, resize=None, crop_size=None):
322 |         p_img = img
323 | 
324 |         if img.shape[2] < 3:
325 |             im_b = p_img
326 |             im_g = p_img
327 |             im_r = p_img
328 |             p_img = np.concatenate((im_b, im_g, im_r), axis=2)
329 | 
330 |         if resize is not None:
331 |             if isinstance(resize, float):
332 |                 _size = p_img.shape[:2]
333 |                 # p_img = np.array(Image.fromarray(p_img).resize((int(_size[0]*resize), int(_size[1]*resize)), Image.BILINEAR))
334 |                 p_img = np.array(Image.fromarray(p_img).resize((int(_size[1]*resize), int(_size[0]*resize)), Image.BILINEAR))  # PIL expects (width, height)
335 |             elif isinstance(resize, tuple) or isinstance(resize, list):
336 |                 assert len(resize) == 2
337 |                 p_img = np.array(Image.fromarray(p_img).resize((int(resize[0]), int(resize[1])), Image.BILINEAR))
338 | 
339 |         p_img = normalize(p_img, self.image_mean, self.image_std)
340 | 
341 |         if crop_size is not None:
342 |             p_img, margin = pad_image_to_shape(p_img, crop_size, cv2.BORDER_CONSTANT, value=0)
343 |             p_img = p_img.transpose(2, 0, 1)
344 | 
345 |             return p_img, margin
346 | 
347 |         p_img = p_img.transpose(2, 0, 1)
348 | 
349 |         return p_img
350 | 
-------------------------------------------------------------------------------- /tools/engine/logger.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.)
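A usage note for the logger module below: `get_logger()` with no arguments installs only the colored stream handler; passing both `log_dir` and `log_file` additionally appends plain-formatted records to a file. A minimal sketch (the directory and file names are placeholders, not paths from this repo):

from tools.engine.logger import get_logger

logger = get_logger(log_dir='runs/example', log_file='runs/example/train.log')
logger.info('evaluation started')     # written to both stdout and the log file
logger.warning('val set is empty?!')  # WRN-prefixed, highlighted on a terminal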
3 | 
4 | import os
5 | import sys
6 | import logging
7 | 
8 | _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO')
9 | _default_level = logging.getLevelName(_default_level_name.upper())
10 | 
11 | 
12 | class LogFormatter(logging.Formatter):
13 |     log_fout = None
14 |     date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] '
15 |     date = '%(asctime)s '
16 |     msg = '%(message)s'
17 | 
18 |     def format(self, record):
19 |         if record.levelno == logging.DEBUG:
20 |             mcl, mtxt = self._color_dbg, 'DBG'
21 |         elif record.levelno == logging.WARNING:
22 |             mcl, mtxt = self._color_warn, 'WRN'
23 |         elif record.levelno == logging.ERROR:
24 |             mcl, mtxt = self._color_err, 'ERR'
25 |         else:
26 |             mcl, mtxt = self._color_normal, ''
27 | 
28 |         if mtxt:
29 |             mtxt += ' '
30 | 
31 |         if self.log_fout:  # file logging enabled: use the full, uncolored format
32 |             self.__set_fmt(self.date_full + mtxt + self.msg)
33 |             formatted = super(LogFormatter, self).format(record)
34 |             # self.log_fout.write(formatted)
35 |             # self.log_fout.write('\n')
36 |             # self.log_fout.flush()
37 |             return formatted
38 | 
39 |         self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
40 |         formatted = super(LogFormatter, self).format(record)
41 | 
42 |         return formatted
43 | 
44 |     if sys.version_info.major < 3:
45 |         def __set_fmt(self, fmt):
46 |             self._fmt = fmt
47 |     else:
48 |         def __set_fmt(self, fmt):
49 |             self._style._fmt = fmt
50 | 
51 |     @staticmethod
52 |     def _color_dbg(msg):
53 |         return '\x1b[36m{}\x1b[0m'.format(msg)
54 | 
55 |     @staticmethod
56 |     def _color_warn(msg):
57 |         return '\x1b[1;31m{}\x1b[0m'.format(msg)
58 | 
59 |     @staticmethod
60 |     def _color_err(msg):
61 |         return '\x1b[1;4;31m{}\x1b[0m'.format(msg)
62 | 
63 |     @staticmethod
64 |     def _color_omitted(msg):
65 |         return '\x1b[35m{}\x1b[0m'.format(msg)
66 | 
67 |     @staticmethod
68 |     def _color_normal(msg):
69 |         return msg
70 | 
71 |     @staticmethod
72 |     def _color_date(msg):
73 |         return '\x1b[32m{}\x1b[0m'.format(msg)
74 | 
75 | 
76 | def get_logger(log_dir=None, log_file=None, formatter=LogFormatter):
77 |     logger = logging.getLogger()
78 |     logger.setLevel(_default_level)
79 |     del logger.handlers[:]
80 | 
81 |     if log_dir and log_file:
82 |         if not os.path.isdir(log_dir): os.makedirs(log_dir)
83 |         LogFormatter.log_fout = True
84 |         file_handler = logging.FileHandler(log_file, mode='a')
85 |         file_handler.setLevel(logging.INFO)
86 |         file_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))  # instantiate the formatter; a bare class cannot format records
87 |         logger.addHandler(file_handler)
88 | 
89 |     stream_handler = logging.StreamHandler()
90 |     stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))
91 |     stream_handler.setLevel(0)
92 |     logger.addHandler(stream_handler)
93 |     return logger
94 | 
-------------------------------------------------------------------------------- /tools/engine/tester.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.)
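Context for the Tester below: it mirrors Evaluator but adds `out_idx` to select one head of a multi-output network, and its `func_per_iteration` hook appears geared toward saving predictions rather than aggregating metrics. The sliding-window logic in `scale_process` places crop-sized windows every `stride` pixels and clamps the final window to the padded border so every pixel is covered; a self-contained sketch with invented sizes:

import numpy as np

crop, stride_rate, pad_cols = 512, 2 / 3., 1024
stride = int(np.ceil(crop * stride_rate))              # 342
c_grid = int(np.ceil((pad_cols - crop) / stride)) + 1  # 3 windows
starts = []
for gx in range(c_grid):
    e_x = min(gx * stride + crop, pad_cols)
    starts.append(e_x - crop)                          # the last window is clamped to the border
# starts -> [0, 342, 512]: windows [0,512), [342,854), [512,1024) cover all 1024 columns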
3 | 
4 | import os
5 | import os.path as osp
6 | import cv2
7 | import numpy as np
8 | import time
9 | from tqdm import tqdm
10 | 
11 | import torch
12 | import torch.nn.functional as F
13 | import torch.multiprocessing as mp
14 | 
15 | from tools.engine.logger import get_logger
16 | from tools.utils.pyt_utils import load_model, link_file, ensure_dir
17 | from tools.utils.img_utils import pad_image_to_shape, normalize
18 | 
19 | logger = get_logger()
20 | 
21 | 
22 | class Tester(object):
23 |     def __init__(self, dataset, class_num, image_mean, image_std, network,
24 |                  multi_scales, is_flip, devices=0, out_idx=0, threds=3, config=None, logger=None,
25 |                  verbose=False, save_path=None, show_image=False):
26 |         self.dataset = dataset
27 |         self.ndata = self.dataset.get_length()
28 |         self.class_num = class_num
29 |         self.image_mean = image_mean
30 |         self.image_std = image_std
31 |         self.multi_scales = multi_scales
32 |         self.is_flip = is_flip
33 |         self.network = network
34 |         self.devices = devices
35 |         if type(self.devices) == int: self.devices = [self.devices]
36 |         self.out_idx = out_idx
37 |         self.threds = threds
38 |         self.config = config
39 |         self.logger = logger
40 | 
41 |         self.context = mp.get_context('spawn')
42 |         self.val_func = None
43 |         self.results_queue = self.context.Queue(self.ndata)
44 | 
45 |         self.verbose = verbose
46 |         self.save_path = save_path
47 |         if save_path is not None:
48 |             ensure_dir(save_path)
49 |         self.show_image = show_image
50 | 
51 |     def run(self, model_path, model_indice, log_file, log_file_link):
52 |         """There are four evaluation modes:
53 |             1.only eval a .pth model: -e *.pth
54 |             2.only eval a certain epoch: -e epoch
55 |             3.eval all epochs in a given section: -e start_epoch-end_epoch
56 |             4.eval all epochs from a certain started epoch: -e start_epoch-
57 |         """
58 |         if '.pth' in model_indice:
59 |             models = [model_indice, ]
60 |         elif "-" in model_indice:
61 |             start_epoch = int(model_indice.split("-")[0])
62 |             end_epoch = model_indice.split("-")[1]
63 | 
64 |             models = os.listdir(model_path)
65 |             models.remove("epoch-last.pth")
66 |             sorted_models = [None] * len(models)
67 |             model_idx = [0] * len(models)
68 | 
69 |             for idx, m in enumerate(models):
70 |                 num = m.split(".")[0].split("-")[1]
71 |                 model_idx[idx] = num
72 |                 sorted_models[idx] = m
73 |             model_idx = np.array([int(i) for i in model_idx])
74 | 
75 |             down_bound = model_idx >= start_epoch
76 |             up_bound = [True] * len(sorted_models)
77 |             if end_epoch:
78 |                 end_epoch = int(end_epoch)
79 |                 assert start_epoch < end_epoch
80 |                 up_bound = model_idx <= end_epoch
81 |             bound = up_bound * down_bound
82 |             model_slice = np.array(sorted_models)[bound]
83 |             models = [os.path.join(model_path, model) for model in
84 |                       model_slice]
85 |         else:
86 |             models = [os.path.join(model_path,
87 |                                    'epoch-%s.pth' % model_indice), ]
88 | 
89 |         results = open(log_file, 'a')
90 |         link_file(log_file, log_file_link)
91 | 
92 |         for model in models:
93 |             logger.info("Load Model: %s" % model)
94 |             self.val_func = load_model(self.network, model)
95 |             result_line, mIoU = self.multi_process_evaluation()
96 | 
97 |             results.write('Model: ' + model + '\n')
98 |             results.write(result_line)
99 |             results.write('\n')
100 |             results.flush()
101 | 
102 |         results.close()
103 | 
104 |     def run_online(self):
105 |         """
106 |         eval during training
107 |         """
108 |         self.val_func = self.network
109 |         self.single_process_evaluation()
110 | 
111 |     def single_process_evaluation(self):
112 |         with torch.no_grad():
113 |             for idx in tqdm(range(self.ndata)):
114 |                 dd = 
self.dataset[idx] 115 | self.func_per_iteration(dd, self.devices[0], iter=idx) 116 | 117 | def run_online_multiprocess(self): 118 | """ 119 | eval during training 120 | """ 121 | self.val_func = self.network 122 | self.multi_process_single_gpu_evaluation() 123 | 124 | def multi_process_single_gpu_evaluation(self): 125 | # start_eval_time = time.perf_counter() 126 | stride = int(np.ceil(self.ndata / self.threds)) 127 | 128 | # start multi-process on single-gpu 129 | procs = [] 130 | for d in range(self.threds): 131 | e_record = min((d + 1) * stride, self.ndata) 132 | shred_list = list(range(d * stride, e_record)) 133 | device = self.devices[0] 134 | logger.info('Thread %d handle %d data.' % (d, len(shred_list))) 135 | p = self.context.Process(target=self.worker, args=(shred_list, device)) 136 | procs.append(p) 137 | 138 | for p in procs: 139 | p.start() 140 | 141 | for p in procs: 142 | p.join() 143 | 144 | 145 | def multi_process_evaluation(self): 146 | start_eval_time = time.perf_counter() 147 | nr_devices = len(self.devices) 148 | stride = int(np.ceil(self.ndata / nr_devices)) 149 | 150 | # start multi-process on multi-gpu 151 | procs = [] 152 | for d in range(nr_devices): 153 | e_record = min((d + 1) * stride, self.ndata) 154 | shred_list = list(range(d * stride, e_record)) 155 | device = self.devices[d] 156 | logger.info('GPU %s handle %d data.' % (device, len(shred_list))) 157 | p = self.context.Process(target=self.worker, args=(shred_list, device)) 158 | procs.append(p) 159 | 160 | for p in procs: 161 | p.start() 162 | 163 | for p in procs: 164 | p.join() 165 | 166 | 167 | def worker(self, shred_list, device): 168 | start_load_time = time.time() 169 | # logger.info('Load Model on Device %d: %.2fs' % (device, time.time() - start_load_time)) 170 | for idx in shred_list: 171 | dd = self.dataset[idx] 172 | results_dict = self.func_per_iteration(dd, device, iter=idx) 173 | self.results_queue.put(results_dict) 174 | 175 | def func_per_iteration(self, data, device, iter=None): 176 | raise NotImplementedError 177 | 178 | def compute_metric(self, results): 179 | raise NotImplementedError 180 | 181 | # evaluate the whole image at once 182 | def whole_eval(self, img, output_size, input_size=None, device=None): 183 | if input_size is not None: 184 | img, margin = self.process_image(img, input_size) 185 | else: 186 | img = self.process_image(img, input_size) 187 | 188 | pred = self.val_func_process(img, device) 189 | if input_size is not None: 190 | pred = pred[:, margin[0]:(pred.shape[1] - margin[1]), margin[2]:(pred.shape[2] - margin[3])] 191 | pred = pred.permute(1, 2, 0) 192 | pred = pred.cpu().numpy() 193 | if output_size is not None: 194 | pred = cv2.resize(pred, 195 | (output_size[1], output_size[0]), 196 | interpolation=cv2.INTER_LINEAR) 197 | 198 | pred = pred.argmax(2) 199 | 200 | return pred 201 | 202 | # slide the window to evaluate the image 203 | def sliding_eval(self, img, crop_size, stride_rate, device=None): 204 | ori_rows, ori_cols, c = img.shape 205 | processed_pred = np.zeros((ori_rows, ori_cols, self.class_num)) 206 | 207 | for s in self.multi_scales: 208 | img_scale = cv2.resize(img, None, fx=s, fy=s, 209 | interpolation=cv2.INTER_LINEAR) 210 | new_rows, new_cols, _ = img_scale.shape 211 | processed_pred += self.scale_process(img_scale, 212 | (ori_rows, ori_cols), 213 | crop_size, stride_rate, device) 214 | 215 | pred = processed_pred.argmax(2) 216 | 217 | return pred 218 | 219 | def scale_process(self, img, ori_shape, crop_size, stride_rate, 220 | device=None): 221 | 
new_rows, new_cols, c = img.shape 222 | long_size = new_cols if new_cols > new_rows else new_rows 223 | 224 | if long_size <= crop_size: 225 | input_data, margin = self.process_image(img, crop_size) 226 | score = self.val_func_process(input_data, device) 227 | score = score[:, margin[0]:(score.shape[1] - margin[1]), 228 | margin[2]:(score.shape[2] - margin[3])] 229 | else: 230 | stride = int(np.ceil(crop_size * stride_rate)) 231 | img_pad, margin = pad_image_to_shape(img, crop_size, 232 | cv2.BORDER_CONSTANT, value=0) 233 | 234 | pad_rows = img_pad.shape[0] 235 | pad_cols = img_pad.shape[1] 236 | r_grid = int(np.ceil((pad_rows - crop_size) / stride)) + 1 237 | c_grid = int(np.ceil((pad_cols - crop_size) / stride)) + 1 238 | data_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda( 239 | device) 240 | count_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda( 241 | device) 242 | 243 | for grid_yidx in range(r_grid): 244 | for grid_xidx in range(c_grid): 245 | s_x = grid_xidx * stride 246 | s_y = grid_yidx * stride 247 | e_x = min(s_x + crop_size, pad_cols) 248 | e_y = min(s_y + crop_size, pad_rows) 249 | s_x = e_x - crop_size 250 | s_y = e_y - crop_size 251 | img_sub = img_pad[s_y:e_y, s_x: e_x, :] 252 | count_scale[:, s_y: e_y, s_x: e_x] += 1 253 | 254 | input_data, tmargin = self.process_image(img_sub, crop_size) 255 | temp_score = self.val_func_process(input_data, device) 256 | temp_score = temp_score[:, 257 | tmargin[0]:(temp_score.shape[1] - tmargin[1]), 258 | tmargin[2]:(temp_score.shape[2] - tmargin[3])] 259 | data_scale[:, s_y: e_y, s_x: e_x] += temp_score 260 | # score = data_scale / count_scale 261 | score = data_scale 262 | score = score[:, margin[0]:(score.shape[1] - margin[1]), 263 | margin[2]:(score.shape[2] - margin[3])] 264 | 265 | score = score.permute(1, 2, 0) 266 | data_output = cv2.resize(score.cpu().numpy(), 267 | (ori_shape[1], ori_shape[0]), 268 | interpolation=cv2.INTER_LINEAR) 269 | 270 | return data_output 271 | 272 | def val_func_process(self, input_data, device=None): 273 | input_data = np.ascontiguousarray(input_data[None, :, :, :], dtype=np.float32) 274 | input_data = torch.FloatTensor(input_data).cuda(device) 275 | 276 | with torch.cuda.device(input_data.get_device()): 277 | self.val_func.eval() 278 | self.val_func.to(input_data.get_device()) 279 | with torch.no_grad(): 280 | score = self.val_func(input_data, task='new_seg')[0] 281 | if (isinstance(score, tuple) or isinstance(score, list)) and len(score) > 1: 282 | score = score[self.out_idx] 283 | score = score[0] # a single image pass, ignore batch dim 284 | 285 | if self.is_flip: 286 | input_data = input_data.flip(-1) 287 | score_flip = self.val_func(input_data) 288 | score_flip = score_flip[0] 289 | score += score_flip.flip(-1) 290 | score = torch.exp(score) 291 | # score = score.data 292 | 293 | return score 294 | 295 | def process_image(self, img, crop_size=None): 296 | p_img = img 297 | 298 | if img.shape[2] < 3: 299 | im_b = p_img 300 | im_g = p_img 301 | im_r = p_img 302 | p_img = np.concatenate((im_b, im_g, im_r), axis=2) 303 | 304 | p_img = normalize(p_img, self.image_mean, self.image_std) 305 | 306 | if crop_size is not None: 307 | p_img, margin = pad_image_to_shape(p_img, crop_size, cv2.BORDER_CONSTANT, value=0) 308 | p_img = p_img.transpose(2, 0, 1) 309 | 310 | return p_img, margin 311 | 312 | p_img = p_img.transpose(2, 0, 1) 313 | 314 | return p_img 315 | -------------------------------------------------------------------------------- /tools/seg_opr/metric.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 3 | 4 | import numpy as np 5 | 6 | np.seterr(divide='ignore', invalid='ignore') 7 | 8 | 9 | # voc cityscapes metric 10 | def hist_info(n_cl, pred, gt): 11 | assert (pred.shape == gt.shape), "pred: " + str(pred.shape) + " v.s. gt: " + str(gt.shape) 12 | k = (gt >= 0) & (gt < n_cl) 13 | labeled = np.sum(k) 14 | correct = np.sum((pred[k] == gt[k])) 15 | 16 | return np.bincount(n_cl * gt[k].astype(int) + pred[k].astype(int), 17 | minlength=n_cl ** 2).reshape(n_cl, 18 | n_cl), labeled, correct 19 | 20 | 21 | def compute_score(hist, correct, labeled): 22 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 23 | mean_IU = np.nanmean(iu) 24 | mean_IU_no_back = np.nanmean(iu[1:]) 25 | freq = hist.sum(1) / hist.sum() 26 | # freq_IU = (iu[freq > 0] * freq[freq > 0]).sum() 27 | mean_pixel_acc = correct / labeled 28 | 29 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc 30 | 31 | 32 | # ade metric 33 | def meanIoU(area_intersection, area_union): 34 | iou = 1.0 * np.sum(area_intersection, axis=1) / np.sum(area_union, axis=1) 35 | meaniou = np.nanmean(iou) 36 | meaniou_no_back = np.nanmean(iou[1:]) 37 | 38 | return iou, meaniou, meaniou_no_back 39 | 40 | 41 | def intersectionAndUnion(imPred, imLab, numClass): 42 | # Remove classes from unlabeled pixels in gt image. 43 | # We should not penalize detections in unlabeled portions of the image. 44 | imPred = np.asarray(imPred).copy() 45 | imLab = np.asarray(imLab).copy() 46 | 47 | imPred += 1 48 | imLab += 1 49 | # Remove classes from unlabeled pixels in gt image. 50 | # We should not penalize detections in unlabeled portions of the image. 51 | imPred = imPred * (imLab > 0) 52 | 53 | # imPred = imPred * (imLab >= 0) 54 | 55 | # Compute area intersection: 56 | intersection = imPred * (imPred == imLab) 57 | (area_intersection, _) = np.histogram(intersection, bins=numClass, 58 | range=(1, numClass)) 59 | 60 | # Compute area union: 61 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 62 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 63 | area_union = area_pred + area_lab - area_intersection 64 | 65 | return area_intersection, area_union 66 | 67 | 68 | def mean_pixel_accuracy(pixel_correct, pixel_labeled): 69 | mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / ( 70 | np.spacing(1) + np.sum(pixel_labeled)) 71 | 72 | return mean_pixel_accuracy 73 | 74 | 75 | def pixelAccuracy(imPred, imLab): 76 | # Remove classes from unlabeled pixels in gt image. 77 | # We should not penalize detections in unlabeled portions of the image. 78 | pixel_labeled = np.sum(imLab >= 0) 79 | pixel_correct = np.sum((imPred == imLab) * (imLab >= 0)) 80 | pixel_accuracy = 1.0 * pixel_correct / pixel_labeled 81 | 82 | return pixel_accuracy, pixel_correct, pixel_labeled 83 | 84 | 85 | def accuracy(preds, label): 86 | valid = (label >= 0) 87 | acc_sum = (valid * (preds == label)).sum() 88 | valid_sum = valid.sum() 89 | acc = float(acc_sum) / (valid_sum + 1e-10) 90 | return acc, valid_sum 91 | -------------------------------------------------------------------------------- /tools/utils/img_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.)
3 | 
4 | import cv2
5 | import numpy as np
6 | import numbers
7 | import random
8 | import collections.abc
9 | 
10 | 
11 | def get_2dshape(shape, *, zero=True):
12 |     if not isinstance(shape, collections.abc.Iterable):  # collections.Iterable was removed in Python 3.10
13 |         shape = int(shape)
14 |         shape = (shape, shape)
15 |     else:
16 |         h, w = map(int, shape)
17 |         shape = (h, w)
18 |     if zero:
19 |         minv = 0
20 |     else:
21 |         minv = 1
22 | 
23 |     assert min(shape) >= minv, 'invalid shape: {}'.format(shape)
24 |     return shape
25 | 
26 | 
27 | def random_crop_pad_to_shape(img, crop_pos, crop_size, pad_label_value):
28 |     h, w = img.shape[:2]
29 |     start_crop_h, start_crop_w = crop_pos
30 |     assert ((start_crop_h < h) and (start_crop_h >= 0))
31 |     assert ((start_crop_w < w) and (start_crop_w >= 0))
32 | 
33 |     crop_size = get_2dshape(crop_size)
34 |     crop_h, crop_w = crop_size
35 | 
36 |     img_crop = img[start_crop_h:start_crop_h + crop_h,
37 |                start_crop_w:start_crop_w + crop_w, ...]
38 | 
39 |     img_, margin = pad_image_to_shape(img_crop, crop_size, cv2.BORDER_CONSTANT,
40 |                                       pad_label_value)
41 | 
42 |     return img_, margin
43 | 
44 | 
45 | def generate_random_crop_pos(ori_size, crop_size):
46 |     ori_size = get_2dshape(ori_size)
47 |     h, w = ori_size
48 | 
49 |     crop_size = get_2dshape(crop_size)
50 |     crop_h, crop_w = crop_size
51 | 
52 |     pos_h, pos_w = 0, 0
53 | 
54 |     if h > crop_h:
55 |         pos_h = random.randint(0, h - crop_h + 1)
56 | 
57 |     if w > crop_w:
58 |         pos_w = random.randint(0, w - crop_w + 1)
59 | 
60 |     return pos_h, pos_w
61 | 
62 | 
63 | def pad_image_to_shape(img, shape, border_mode, value):
64 |     margin = np.zeros(4, np.uint32)
65 |     shape = get_2dshape(shape)
66 |     pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0
67 |     pad_width = shape[1] - img.shape[1] if shape[1] - img.shape[1] > 0 else 0
68 | 
69 |     margin[0] = pad_height // 2
70 |     margin[1] = pad_height // 2 + pad_height % 2
71 |     margin[2] = pad_width // 2
72 |     margin[3] = pad_width // 2 + pad_width % 2
73 | 
74 |     img = cv2.copyMakeBorder(img, margin[0], margin[1], margin[2], margin[3],
75 |                              border_mode, value=value)
76 | 
77 |     return img, margin
78 | 
79 | 
80 | def pad_image_size_to_multiples_of(img, multiple, pad_value):
81 |     h, w = img.shape[:2]
82 |     d = multiple
83 | 
84 |     def canonicalize(s):
85 |         v = s // d
86 |         return (v + (v * d != s)) * d  # round up to the next multiple of d
87 | 
88 |     th, tw = map(canonicalize, (h, w))
89 | 
90 |     return pad_image_to_shape(img, (th, tw), cv2.BORDER_CONSTANT, pad_value)
91 | 
92 | 
93 | def resize_ensure_shortest_edge(img, edge_length,
94 |                                 interpolation_mode=cv2.INTER_LINEAR):
95 |     assert isinstance(edge_length, int) and edge_length > 0, edge_length
96 |     h, w = img.shape[:2]
97 |     if h < w:
98 |         ratio = float(edge_length) / h
99 |         th, tw = edge_length, max(1, int(ratio * w))
100 |     else:
101 |         ratio = float(edge_length) / w
102 |         th, tw = max(1, int(ratio * h)), edge_length
103 |     img = cv2.resize(img, (tw, th), interpolation=interpolation_mode)  # pass as keyword: the third positional slot of cv2.resize is dst
104 | 
105 |     return img
106 | 
107 | 
108 | def random_scale(img, gt, scales):
109 |     # scale = random.choice(scales)
110 |     scale = random.uniform(min(scales), max(scales))
111 |     sh = int(img.shape[0] * scale)
112 |     sw = int(img.shape[1] * scale)
113 |     img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR)
114 |     gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST)
115 | 
116 |     return img, gt, scale
117 | 
118 | 
119 | def random_scale_with_length(img, gt, length):
120 |     size = random.choice(length)
121 |     sh = size
122 |     sw = size
123 |     img = cv2.resize(img, (sw, sh), 
interpolation=cv2.INTER_LINEAR) 124 | gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST) 125 | 126 | return img, gt, size 127 | 128 | 129 | def random_mirror(img, gt): 130 | if random.random() >= 0.5: 131 | img = cv2.flip(img, 1) 132 | gt = cv2.flip(gt, 1) 133 | 134 | return img, gt, 135 | 136 | 137 | def random_rotation(img, gt): 138 | angle = random.random() * 20 - 10 139 | h, w = img.shape[:2] 140 | rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) 141 | img = cv2.warpAffine(img, rotation_matrix, (w, h), flags=cv2.INTER_LINEAR) 142 | gt = cv2.warpAffine(gt, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST) 143 | 144 | return img, gt 145 | 146 | 147 | def random_gaussian_blur(img): 148 | gauss_size = random.choice([1, 3, 5, 7]) 149 | if gauss_size > 1: 150 | # do the gaussian blur 151 | img = cv2.GaussianBlur(img, (gauss_size, gauss_size), 0) 152 | 153 | return img 154 | 155 | 156 | def center_crop(img, shape): 157 | h, w = shape[0], shape[1] 158 | y = (img.shape[0] - h) // 2 159 | x = (img.shape[1] - w) // 2 160 | return img[y:y + h, x:x + w] 161 | 162 | 163 | def random_crop(img, gt, size): 164 | if isinstance(size, numbers.Number): 165 | size = (int(size), int(size)) 166 | 167 | h, w = img.shape[:2] 168 | crop_h, crop_w = size[0], size[1] 169 | 170 | if h > crop_h: 171 | x = random.randint(0, h - crop_h + 1) 172 | img = img[x:x + crop_h, :, :] 173 | gt = gt[x:x + crop_h, :] 174 | 175 | if w > crop_w: 176 | x = random.randint(0, w - crop_w + 1) 177 | img = img[:, x:x + crop_w, :] 178 | gt = gt[:, x:x + crop_w] 179 | 180 | return img, gt 181 | 182 | 183 | def normalize(img, mean, std): 184 | # pytorch pretrained model need the input range: 0-1 185 | img = img.astype(np.float32) / 255.0 186 | img = img - mean 187 | img = img / std 188 | 189 | return img 190 | -------------------------------------------------------------------------------- /tools/utils/pyt_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 
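On the device parsing in the file below: `parse_devices` accepts a comma-separated list of GPU ids and inclusive `lo-hi` ranges, or a trailing `*` for all visible GPUs. A hypothetical helper mirroring only the string handling (no torch dependency; `parse_devices` itself additionally validates ids against `torch.cuda.device_count()`):

def expand_device_spec(spec):
    # illustrative re-implementation, not part of the repo
    devices = []
    for d in spec.split(','):
        if '-' in d:
            lo, hi = (int(x) for x in d.split('-'))
            devices.extend(range(lo, hi + 1))
        else:
            devices.append(int(d))
    return devices

assert expand_device_spec('0,2-3') == [0, 2, 3]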
3 | 4 | # encoding: utf-8 5 | import os 6 | import time 7 | import argparse 8 | from collections import OrderedDict 9 | 10 | import torch 11 | import torch.distributed as dist 12 | 13 | from tools.engine.logger import get_logger 14 | 15 | logger = get_logger() 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | } 24 | 25 | 26 | # def reduce_tensor(tensor, dst=0, op=dist.ReduceOp.SUM, world_size=1): 27 | # tensor = tensor.clone() 28 | # dist.reduce(tensor, dst, op) 29 | # if dist.get_rank() == dst: 30 | # tensor.div_(world_size) 31 | # 32 | # return tensor 33 | 34 | 35 | # def all_reduce_tensor(tensor, op=dist.ReduceOp.SUM, world_size=1): 36 | # tensor = tensor.clone() 37 | # dist.all_reduce(tensor, op) 38 | # tensor.div_(world_size) 39 | # 40 | # return tensor 41 | 42 | 43 | def load_model(model, model_file, is_restore=False): 44 | t_start = time.time() 45 | if isinstance(model_file, str): 46 | state_dict = torch.load(model_file) 47 | if 'model' in state_dict.keys(): 48 | state_dict = state_dict['model'] 49 | else: 50 | state_dict = model_file 51 | t_ioend = time.time() 52 | 53 | if is_restore: 54 | new_state_dict = OrderedDict() 55 | for k, v in state_dict.items(): 56 | name = 'module.' + k 57 | new_state_dict[name] = v 58 | state_dict = new_state_dict 59 | 60 | model.load_state_dict(state_dict, strict=False) 61 | ckpt_keys = set(state_dict.keys()) 62 | own_keys = set(model.state_dict().keys()) 63 | missing_keys = own_keys - ckpt_keys 64 | unexpected_keys = ckpt_keys - own_keys 65 | 66 | if len(missing_keys) > 0: 67 | logger.warning('Missing key(s) in state_dict: {}'.format( 68 | ', '.join('{}'.format(k) for k in missing_keys))) 69 | 70 | if len(unexpected_keys) > 0: 71 | logger.warning('Unexpected key(s) in state_dict: {}'.format( 72 | ', '.join('{}'.format(k) for k in unexpected_keys))) 73 | 74 | del state_dict 75 | t_end = time.time() 76 | logger.info( 77 | "Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format( 78 | t_ioend - t_start, t_end - t_ioend)) 79 | 80 | return model 81 | 82 | 83 | def parse_devices(input_devices): 84 | if input_devices.endswith('*'): 85 | devices = list(range(torch.cuda.device_count())) 86 | return devices 87 | 88 | devices = [] 89 | for d in input_devices.split(','): 90 | if '-' in d: 91 | start_device, end_device = d.split('-')[0], d.split('-')[1] 92 | assert start_device != '' 93 | assert end_device != '' 94 | start_device, end_device = int(start_device), int(end_device) 95 | assert start_device < end_device 96 | assert end_device < torch.cuda.device_count() 97 | for sd in range(start_device, end_device + 1): 98 | devices.append(sd) 99 | else: 100 | device = int(d) 101 | assert device < torch.cuda.device_count() 102 | devices.append(device) 103 | 104 | logger.info('using devices {}'.format( 105 | ', '.join([str(d) for d in devices]))) 106 | 107 | return devices 108 | 109 | 110 | def extant_file(x): 111 | """ 112 | 'Type' for argparse - checks that file exists but does not open. 
113 | """ 114 | if not os.path.exists(x): 115 | # Argparse uses the ArgumentTypeError to give a rejection message like: 116 | # error: argument input: x does not exist 117 | raise argparse.ArgumentTypeError("{0} does not exist".format(x)) 118 | return x 119 | 120 | 121 | def link_file(src, target): 122 | if os.path.isdir(target) or os.path.isfile(target): 123 | os.remove(target) 124 | os.system('ln -s {} {}'.format(src, target)) 125 | 126 | 127 | def ensure_dir(path): 128 | if not os.path.isdir(path): 129 | os.makedirs(path) 130 | 131 | 132 | def _dbg_interactive(var, value): 133 | from IPython import embed 134 | embed() 135 | -------------------------------------------------------------------------------- /tools/utils/visualize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.) 3 | 4 | import numpy as np 5 | import cv2 6 | import scipy.io as sio 7 | 8 | 9 | def set_img_color(colors, background, img, gt, show255=False, weight_foreground=0.55): 10 | origin = np.array(img) 11 | for i in range(len(colors)): 12 | if i != background: 13 | img[np.where(gt == i)] = colors[i] 14 | if show255: 15 | img[np.where(gt == 255)] = 0 16 | cv2.addWeighted(img, weight_foreground, origin, 1 - weight_foreground, 0, img) 17 | return img 18 | 19 | 20 | def show_prediction(colors, background, img, pred): 21 | im = np.array(img, np.uint8) 22 | set_img_color(colors, background, im, pred, weight_foreground=1) 23 | final = np.array(im) 24 | return final 25 | 26 | 27 | def show_img(colors, background, img, clean, gt, *pds): 28 | im1 = np.array(img, np.uint8) 29 | # set_img_color(colors, background, im1, clean) 30 | final = np.array(im1) 31 | # the pivot black bar 32 | pivot = np.zeros((im1.shape[0], 15, 3), dtype=np.uint8) 33 | for pd in pds: 34 | im = np.array(img, np.uint8) 35 | # pd[np.where(gt == 255)] = 255 36 | set_img_color(colors, background, im, pd) 37 | final = np.column_stack((final, pivot)) 38 | final = np.column_stack((final, im)) 39 | 40 | im = np.array(img, np.uint8) 41 | set_img_color(colors, background, im, gt, True) 42 | final = np.column_stack((final, pivot)) 43 | final = np.column_stack((final, im)) 44 | return final 45 | 46 | 47 | def get_colors(class_num): 48 | colors = [] 49 | for i in range(class_num): 50 | colors.append((np.random.random((1, 3)) * 255).tolist()[0]) 51 | 52 | return colors 53 | 54 | 55 | def get_ade_colors(): 56 | colors = sio.loadmat('./color150.mat')['colors'] 57 | colors = colors[:, ::-1, ] 58 | colors = np.array(colors).astype(int).tolist() 59 | colors.insert(0, [0, 0, 0]) 60 | 61 | return colors 62 | 63 | 64 | def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False, 65 | no_print=False): 66 | n = iu.size 67 | lines = [] 68 | for i in range(n): 69 | if class_names is None: 70 | cls = 'Class %d:' % (i + 1) 71 | else: 72 | cls = '%d %s' % (i + 1, class_names[i]) 73 | lines.append('%-8s\t%.3f%%' % (cls, iu[i] * 100)) 74 | mean_IU = np.nanmean(iu) 75 | # mean_IU_no_back = np.nanmean(iu[1:]) 76 | mean_IU_no_back = np.nanmean(iu[:-1]) 77 | if show_no_back: 78 | lines.append( 79 | '---------------------------- %-8s\t%.3f%%\t%-8s\t%.3f%%\t%-8s\t%.3f%%' % ( 80 | 'mean_IU', mean_IU * 100, 'mean_IU_no_back', 81 | mean_IU_no_back * 100, 82 | 'mean_pixel_ACC', mean_pixel_acc * 100)) 83 | else: 84 | print(mean_pixel_acc) 85 | lines.append( 86 | '---------------------------- 
%-8s\t%.3f%%\t%-8s\t%.3f%%' % (
87 |                 'mean_IU', mean_IU * 100, 'mean_pixel_ACC',
88 |                 mean_pixel_acc * 100))
89 |     line = "\n".join(lines)
90 |     if not no_print:
91 |         print(line)
92 |     return line
93 | 
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 | 
4 | # [x] train resnet101 with proxy guidance on visda17
5 | # [x] evaluation on visda17
6 | 
7 | import argparse
8 | import os
9 | import sys
10 | import logging
11 | import time
12 | from tqdm import tqdm
13 | import torch
14 | import torch.nn as nn
15 | from torch.nn import functional as F
16 | from torch.utils.data import DataLoader
17 | import torchvision.transforms as transforms
18 | 
19 | 
20 | from data.visda17 import VisDA17
21 | from model.resnet import resnet101
22 | from utils.utils import get_params, IterNums, save_checkpoint, AverageMeter, lr_poly, adjust_learning_rate, accuracy
23 | from utils.logger import prepare_logger, prepare_seed
24 | from utils.sgd import SGD
25 | 
26 | CrossEntropyLoss = nn.CrossEntropyLoss(reduction='mean')
27 | KLDivLoss = nn.KLDivLoss(reduction='batchmean')
28 | 
29 | parser = argparse.ArgumentParser(description='ASG Training')
30 | parser.add_argument('--data', default='/raid/taskcv-2017-public/classification/data', help='path to dataset')
31 | parser.add_argument('--epochs', default=30, type=int, help='number of total epochs to run')
32 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
33 | parser.add_argument('--batch-size', default=32, type=int, dest='batch_size', help='mini-batch size (default: 32)')
34 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, help='initial learning rate')
35 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, help='weight decay (default: 5e-4)')
36 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
37 | parser.add_argument('--lwf', default=0., type=float, dest='lwf', help='weight of KL loss for LwF (default: 0)')
38 | parser.add_argument('--resume', default='none', type=str, help='path to latest checkpoint (default: none)')
39 | parser.add_argument('--evaluate', action='store_true', help='evaluate the model on the validation set (default: False)')
40 | parser.add_argument('--timestamp', type=str, default='none', help='timestamp for logging naming')
41 | parser.add_argument('--save_dir', type=str, default="./runs", help='root folder to save checkpoints and log.')
42 | parser.add_argument('--train_blocks', type=str, default="conv1.bn1.layer1.layer2.layer3.layer4.fc", help='blocks to train, separated by dot.')
43 | parser.add_argument('--num-class', default=12, type=int, dest='num_class', help='the number of classes')
44 | parser.add_argument('--rand_seed', default=0, type=int, help='random seed for reproducibility')
45 | 
46 | best_prec1 = 0
47 | 
48 | def main():
49 |     global args, best_prec1
50 |     PID = os.getpid()
51 |     args = parser.parse_args()
52 |     prepare_seed(args.rand_seed)
53 | 
54 |     if args.timestamp == 'none':
55 |         args.timestamp = "{:}".format(time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time())))
56 | 
57 |     # Log outputs
58 |     if args.evaluate:
59 |         args.save_dir = args.save_dir + "/Visda17-Res101-evaluate" + \
"%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp) 61 | else: 62 | args.save_dir = args.save_dir + \ 63 | "/Visda17-Res101-%s-train.%s-LR%.2E-epoch%d-batch%d-seed%d"%( 64 | "LWF%.2f"%args.lwf if args.lwf > 0 else "XE", args.train_blocks, args.lr, args.epochs, args.batch_size, args.rand_seed) + \ 65 | "%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp) 66 | logger = prepare_logger(args) 67 | 68 | data_transforms = { 69 | 'train': transforms.Compose([ 70 | transforms.Resize(224), 71 | transforms.CenterCrop(224), 72 | transforms.ToTensor(), 73 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 74 | ]), 75 | 'val': transforms.Compose([ 76 | transforms.Resize(224), 77 | transforms.CenterCrop(224), 78 | transforms.ToTensor(), 79 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 80 | ]), 81 | } 82 | 83 | kwargs = {'num_workers': 20, 'pin_memory': True} 84 | trainset = VisDA17(txt_file=os.path.join(args.data, "train/image_list.txt"), root_dir=os.path.join(args.data, "train"), transform=data_transforms['train']) 85 | valset = VisDA17(txt_file=os.path.join(args.data, "validation/image_list.txt"), root_dir=os.path.join(args.data, "validation"), transform=data_transforms['val'], label_one_hot=True) 86 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs) 87 | val_loader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, **kwargs) 88 | 89 | model = resnet101(pretrained=True) 90 | num_ftrs = model.fc.in_features 91 | fc_layers = nn.Sequential( 92 | nn.Linear(num_ftrs, 512), 93 | nn.ReLU(inplace=True), 94 | nn.Linear(512, args.num_class), 95 | ) 96 | model.fc_new = fc_layers 97 | 98 | train_blocks = args.train_blocks.split('.') 99 | # default turn-off fc, turn-on fc_new 100 | for param in model.fc.parameters(): 101 | param.requires_grad = False 102 | ##### Freeze several bottom layers (Optional) ##### 103 | non_train_blocks = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc'] 104 | for name in train_blocks: 105 | try: 106 | non_train_blocks.remove(name) 107 | except Exception: 108 | print("cannot find block name %s\nAvailable blocks are: conv1, bn1, layer1, layer2, layer3, layer4, fc"%name) 109 | for name in non_train_blocks: 110 | for param in getattr(model, name).parameters(): 111 | param.requires_grad = False 112 | 113 | # Setup optimizer 114 | factor = 0.1 115 | sgd_in = [] 116 | for name in train_blocks: 117 | if name != 'fc': 118 | sgd_in.append({'params': get_params(model, [name]), 'lr': factor*args.lr}) 119 | else: 120 | sgd_in.append({'params': get_params(model, ["fc_new"]), 'lr': args.lr}) 121 | base_lrs = [ group['lr'] for group in sgd_in ] 122 | optimizer = SGD(sgd_in, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 123 | 124 | # Optionally resume from a checkpoint 125 | if args.resume != 'none': 126 | if os.path.isfile(args.resume): 127 | print("=> loading checkpoint '{}'".format(args.resume)) 128 | checkpoint = torch.load(args.resume) 129 | args.start_epoch = checkpoint['epoch'] 130 | best_prec1 = checkpoint['best_prec1'] 131 | model.load_state_dict(checkpoint['state_dict']) 132 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 133 | else: 134 | print("=ImageClassdata> no checkpoint found at '{}'".format(args.resume)) 135 | 136 | model = model.cuda() 137 | 138 | model_old = None 139 | if args.lwf > 0: 140 | # create a fixed model copy for Life-long learning 141 | model_old 
142 |         for param in model_old.parameters():
143 |             param.requires_grad = False
144 |         model_old.eval()
145 |         model_old.cuda()
146 | 
147 |     if args.evaluate:
148 |         prec1 = validate(val_loader, model)
149 |         print(prec1)
150 |         exit(0)
151 | 
152 |     # Main training loop
153 |     iter_max = args.epochs * len(train_loader)
154 |     iter_stat = IterNums(iter_max)
155 |     for epoch in range(args.start_epoch, args.epochs):
156 |         print("<< ============== JOB (PID = %d) %s ============== >>"%(PID, args.save_dir))
157 |         logger.log("Epoch: %d"%(epoch+1))
158 |         # train for one epoch
159 |         train(train_loader, model, optimizer, base_lrs, iter_stat, epoch, logger.writer, model_old=model_old, adjust_lr=True)
160 | 
161 |         # evaluate on validation set
162 |         prec1 = validate(val_loader, model)
163 |         logger.writer.add_scalar("prec", prec1, epoch)
164 | 
165 |         # remember best prec@1 and save checkpoint
166 |         is_best = prec1 > best_prec1
167 |         best_prec1 = max(prec1, best_prec1)
168 |         save_checkpoint(args.save_dir, {
169 |             'epoch': epoch + 1,
170 |             'state_dict': model.state_dict(),
171 |             'best_prec1': best_prec1,
172 |         }, is_best)
173 | 
174 |     logging.info('Best accuracy: {prec1:.3f}'.format(prec1=best_prec1))
175 | 
176 | 
177 | def train(train_loader, model, optimizer, base_lrs, iter_stat, epoch, writer, model_old=None, adjust_lr=True):
178 |     """Train for one epoch on the training set"""
179 |     kl_weight = args.lwf
180 |     batch_time = AverageMeter()
181 |     losses = AverageMeter()
182 |     losses_kl = AverageMeter()
183 | 
184 |     model.eval()  # keep BatchNorm/Dropout in eval mode while training; weights still receive gradients
185 | 
186 |     # start timer
187 |     end = time.time()
188 | 
189 |     # train for one epoch
190 |     optimizer.zero_grad()
191 |     epoch_size = len(train_loader)
192 |     train_loader_iter = iter(train_loader)
193 | 
194 |     bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
195 |     pbar = tqdm(range(epoch_size), file=sys.stdout, bar_format=bar_format, ncols=80)
196 | 
197 |     for idx_iter in pbar:
198 | 
199 |         optimizer.zero_grad()
200 |         if adjust_lr:
201 |             lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9)
202 |             writer.add_scalar("lr", lr, idx_iter + epoch * epoch_size)
203 |             adjust_learning_rate(base_lrs, optimizer, iter_stat.iter_curr, iter_stat.iter_max, 0.9)
204 | 
205 |         input, label = next(train_loader_iter)
206 |         label = label.cuda()
207 |         input = input.cuda()
208 | 
209 |         # compute output
210 |         output, features_new = model(input, output_features=['layer1', 'layer4'], task='new')
211 | 
212 |         # classification loss
213 |         loss = CrossEntropyLoss(output, label.long())
214 | 
215 |         # LwF KL divergence between the old head on new features and the frozen pretrained model
216 |         if model_old is None:
217 |             loss_kl = 0
218 |         else:
219 |             output_new = model.forward_fc(features_new['layer4'], task='old')
220 |             output_old, features_old = model_old(input, output_features=['layer1', 'layer4'], task='old')
221 |             loss_kl = KLDivLoss(F.log_softmax(output_new, dim=1), F.softmax(output_old, dim=1)).sum(-1)
222 | 
223 |         (loss + kl_weight * loss_kl).backward()
224 | 
225 |         # measure accuracy and record loss
226 |         losses.update(loss, input.size(0))
227 |         losses_kl.update(loss_kl, input.size(0))
228 | 
229 |         # do the SGD step (gradients were computed by backward above)
230 |         optimizer.step()
231 |         # increment iter number
232 |         iter_stat.update()
233 |         # measure elapsed time
234 |         batch_time.update(time.time() - end)
235 |         end = time.time()
236 | 
237 |         writer.add_scalar("loss/ce", losses.val, idx_iter + epoch * epoch_size)
238 |         writer.add_scalar("loss/kl", losses_kl.val, idx_iter + epoch * epoch_size)
239 |         writer.add_scalar("loss/total", losses.val + losses_kl.val, idx_iter + epoch * epoch_size)
244 | def validate(val_loader, model):
245 |     """Perform validation on the validation set"""
246 |     batch_time = AverageMeter()
247 |     top1 = AverageMeter()
248 | 
249 |     model.eval()
250 | 
251 |     end = time.time()
252 |     val_size = len(val_loader)
253 |     val_loader_iter = iter(val_loader)
254 |     bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
255 |     pbar = tqdm(range(val_size), file=sys.stdout, bar_format=bar_format, ncols=140)
256 |     with torch.no_grad():
257 |         for idx_iter in pbar:
258 |             input, label = next(val_loader_iter)
259 | 
260 |             input = input.cuda()
261 |             label = label.cuda()
262 | 
263 |             # compute output, averaged over the image and its horizontal flip (test-time augmentation)
264 |             output = torch.sigmoid(model(input, task='new')[0])
265 |             output = (output + torch.sigmoid(model(torch.flip(input, dims=(3,)), task='new')[0])) / 2
266 | 
267 |             # accumulate accuracy
268 |             prec1, gt_num = accuracy(output.data, label, args.num_class, topk=(1,))
269 |             top1.update(prec1[0], gt_num[0])
270 | 
271 |             # measure elapsed time
272 |             batch_time.update(time.time() - end)
273 |             end = time.time()
274 | 
275 |             description = "[Acc@1-mean: %.2f][Acc@1-cls: %s]"%(top1.vec2sca_avg, str(top1.avg.numpy().round(1)))
276 |             pbar.set_description("[Step %d/%d]"%(idx_iter + 1, val_size) + description)
277 | 
278 |     logging.info(' * Prec@1 {top1.vec2sca_avg:.3f}'.format(top1=top1))
279 |     logging.info(' * Prec@1 {top1.avg}'.format(top1=top1))
280 | 
281 |     return top1.vec2sca_avg
282 | 
283 | 
284 | if __name__ == "__main__":
285 |     main()
286 | 
-------------------------------------------------------------------------------- /train.sh: --------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 | 
4 | # [x] train resnet101 with proxy guidance on visda17
5 | # [x] evaluation on visda17
6 | 
7 | python train.py \
8 | --epochs 30 \
9 | --batch-size 32 \
10 | --lr 1e-4 \
11 | --lwf 0.1 \
12 | # --resume pretrained/res101_vista17_best.pth.tar \
13 | # --evaluate
-------------------------------------------------------------------------------- /train_seg.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
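# --- Added commentary (hedged summary, not part of the original source): this
# script optimizes the same LwF-style objective as train.py, with an FCN-VGG16
# student trained on GTA5 and a frozen ImageNet-pretrained VGG16 as the proxy
# teacher:
#
#     L_total = CE(output, label) + lwf * KL( softmax(output_old) || softmax(output_new) )
#
# where output_new passes the student's features through the frozen old-task
# head. The --lwf flag is the KL weight; --factor scales down the backbone
# learning rate relative to the randomly initialized segmentation heads.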
3 | 
4 | import argparse
5 | import os
6 | import sys
7 | import logging
8 | import time
9 | import numpy as np
10 | import math
11 | from tqdm import tqdm
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.parallel
15 | import torch.optim
16 | from torch.nn import functional as F
17 | from torch.utils.tensorboard import SummaryWriter
18 | from data.gta5 import GTA5
19 | from data.cityscapes import Cityscapes
20 | from model.vgg import vgg16
21 | from model.fcn8s_vgg import FCN8sAtOnce as FCN_Vgg
22 | from dataloader_seg import get_train_loader
23 | from eval_seg import SegEvaluator
24 | from utils.utils import get_params, IterNums, save_checkpoint, AverageMeter, lr_poly, adjust_learning_rate
25 | from pdb import set_trace as bp
26 | torch.backends.cudnn.enabled = True
27 | 
28 | CrossEntropyLoss = nn.CrossEntropyLoss(reduction='mean', ignore_index=255)
29 | KLDivLoss = nn.KLDivLoss(reduction='batchmean')
30 | best_mIoU = 0
31 | 
32 | parser = argparse.ArgumentParser(description='PyTorch FCN-VGG16 Segmentation Training')
33 | parser.add_argument('--epochs', default=300, type=int, help='number of total epochs to run')
34 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
35 | parser.add_argument('--batch-size', default=6, type=int, dest='batch_size', help='mini-batch size (default: 6)')
36 | parser.add_argument('--iter-size', default=1, type=int, dest='iter_size', help='iteration size (default: 1)')
37 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, help='initial learning rate')
38 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, help='weight decay (default: 5e-4)')
39 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
40 | parser.add_argument('--lwf', default=0., type=float, dest='lwf', help='weight of KL loss for LwF (default: 0)')
41 | parser.add_argument('--factor', default=0.1, type=float, dest='factor', help='scale factor of backbone learning rate (default: 0.1)')
42 | parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint (default: none)')
43 | parser.add_argument('--name', default='Vgg16_GTA5', type=str, help='name of experiment')
44 | parser.add_argument('--tensorboard', help='Log progress to TensorBoard', action='store_true')
45 | parser.add_argument('--num-class', default=19, type=int, dest='num_class', help='the number of classes')
46 | parser.add_argument('--gpus', default=0, type=int, help='id of the CUDA device to use')
47 | parser.add_argument('--evaluate', action='store_true', help='evaluate on the validation set only (default: False)')
48 | parser.set_defaults(bottleneck=True)
49 | parser.set_defaults(augment=True)
50 | 
51 | 
52 | def main():
53 |     global args, best_mIoU
54 |     args = parser.parse_args()
55 |     pid = os.getpid()
56 | 
57 |     # Log outputs
58 |     args.name = "GTA5_Vgg16_batch%d_512x512_Poly_LR%.1e_1to%.1f_all_lwf.%d_epoch%d"%(args.batch_size, args.lr, args.factor, args.lwf, args.epochs)
59 |     if args.resume:
60 |         args.name += "_resumed"
61 |     directory = "runs/%s/"%(args.name)
62 |     if not os.path.exists(directory):
63 |         os.makedirs(directory)
64 |     filename = directory + 'train.log'
65 |     for handler in logging.root.handlers[:]:
66 |         logging.root.removeHandler(handler)
67 |     rootLogger = logging.getLogger()
68 |     logFormatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
69 |     fileHandler = logging.FileHandler(filename)
70 |     fileHandler.setFormatter(logFormatter)
71 |     rootLogger.addHandler(fileHandler)
72 | 
73 |     consoleHandler = logging.StreamHandler()
74 |     consoleHandler.setFormatter(logFormatter)
75 |     rootLogger.addHandler(consoleHandler)
76 |     rootLogger.setLevel(logging.INFO)
77 | 
78 |     writer = SummaryWriter(directory)
79 | 
80 |     from config_seg import config as data_setting
81 |     data_setting.batch_size = args.batch_size
82 |     train_loader = get_train_loader(data_setting, GTA5, test=False)
83 | 
84 |     ##### Vgg16 #####
85 |     vgg = vgg16(pretrained=True)
86 |     model = FCN_Vgg(n_class=args.num_class)
87 |     model.copy_params_from_vgg16(vgg)
88 |     ###################
89 |     threds = 1
90 |     evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None), args.num_class, np.array([0.485, 0.456, 0.406]),
91 |                              np.array([0.229, 0.224, 0.225]), model, [1, ], False, devices=args.gpus, config=data_setting, threds=threds,
92 |                              verbose=False, save_path=None, show_image=False)
93 | 
94 |     # Setup optimizer: pretrained backbone layers train at factor*lr, new heads at the full lr
95 |     ##### Vgg16 #####
96 |     sgd_in = [
97 |         {'params': get_params(model, ["conv1_1", "conv1_2"]), 'lr': args.factor*args.lr},
98 |         {'params': get_params(model, ["conv2_1", "conv2_2"]), 'lr': args.factor*args.lr},
99 |         {'params': get_params(model, ["conv3_1", "conv3_2", "conv3_3"]), 'lr': args.factor*args.lr},
100 |         {'params': get_params(model, ["conv4_1", "conv4_2", "conv4_3"]), 'lr': args.factor*args.lr},
101 |         {'params': get_params(model, ["conv5_1", "conv5_2", "conv5_3"]), 'lr': args.factor*args.lr},
102 |         {'params': get_params(model, ["fc6", "fc7"]), 'lr': args.factor*args.lr},
103 |         {'params': get_params(model, ["score_fr", "score_pool3", "score_pool4", "upscore2", "upscore8", "upscore_pool4"]), 'lr': args.lr},
104 |     ]
105 |     base_lrs = [ group['lr'] for group in sgd_in ]
106 |     optimizer = torch.optim.SGD(sgd_in, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
107 | 
108 |     # Optionally resume from a checkpoint
109 |     if args.resume:
110 |         if os.path.isfile(args.resume):
111 |             print("=> loading checkpoint '{}'".format(args.resume))
112 |             checkpoint = torch.load(args.resume)
113 |             args.start_epoch = checkpoint['epoch']
114 |             best_mIoU = checkpoint['best_mIoU']
115 |             model.load_state_dict(checkpoint['state_dict'])
116 |             print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
117 |         else:
118 |             print("=> no checkpoint found at '{}'".format(args.resume))
119 | 
120 |     model = model.cuda()
121 |     model_old = None
122 |     if args.lwf > 0:
123 |         # create a frozen copy of the pretrained model as the teacher for LwF distillation
124 |         model_old = vgg16(pretrained=True)
125 |         ###################
126 |         for param in model_old.parameters():
127 |             param.requires_grad = False
128 |         model_old.eval()
129 |         model_old.cuda()
130 | 
131 |     if args.evaluate:
132 |         mIoU = validate(evaluator, model)
133 |         print(mIoU); exit(0)  # stop after evaluation, mirroring train.py
134 | 
135 |     # Main training loop
136 |     iter_max = args.epochs * math.ceil(len(train_loader)/args.iter_size)
137 |     iter_stat = IterNums(iter_max)
138 |     for epoch in range(args.start_epoch, args.epochs):
139 |         logging.info("============= " + args.name + " ================")
140 |         logging.info("============= PID: " + str(pid) + " ================")
141 |         logging.info("Epoch: %d"%(epoch+1))
142 |         # train for one epoch
143 |         train(args, train_loader, model, optimizer, base_lrs, iter_stat, epoch, writer, model_old=model_old, adjust_lr=epoch < args.epochs)
144 | 
145 |         # evaluate on validation set
146 |         mIoU = validate(evaluator, model)
147 |         writer.add_scalar("mIoU", mIoU, epoch)
148 |         # remember best mIoU and save checkpoint
149 |         is_best = mIoU > best_mIoU
150 |         best_mIoU = max(mIoU, best_mIoU)
151 |         save_checkpoint(directory, {
152 |             'epoch': epoch + 1,
153 |             'state_dict': model.state_dict(),
154 |             'best_mIoU': best_mIoU,
155 |         }, is_best)
156 | 
157 |     logging.info('Best mIoU: {mIoU:.3f}'.format(mIoU=best_mIoU))
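# --- Added commentary (hedged sketch, not part of the original script): the
# schedule that adjust_learning_rate()/lr_poly() in utils/utils.py apply in
# train() below is a quadratic warm-up over the first 100 iterations followed
# by polynomial decay, clamped once iter/max_iter reaches 0.8. The function
# name and values here are illustrative only.
def _poly_lr_example(base_lr=1e-3, it=50, max_it=10000, power=0.9):
    warmup = min(0.01 + 0.99 * (it / 100.0) ** 2, 1.0)      # 0.2575 at it=50
    decay = (1.0 - min(float(it) / max_it, 0.8)) ** power   # ~0.9955
    return warmup * base_lr * decay                         # ~2.56e-4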
158 | 
159 | 
160 | def train(args, train_loader, model, optimizer, base_lrs, iter_stat, epoch, writer, model_old=None, adjust_lr=True):
161 |     """Train for one epoch on the training set"""
162 |     losses = AverageMeter()
163 |     losses_kl = AverageMeter()
164 | 
165 |     model.eval()  # eval mode keeps BatchNorm statistics frozen during training (intentional)
166 | 
167 |     # train for one epoch
168 |     optimizer.zero_grad()
169 |     epoch_size = len(train_loader)
170 |     train_loader_iter = iter(train_loader)
171 | 
172 |     bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
173 |     pbar = tqdm(range(epoch_size), file=sys.stdout, bar_format=bar_format, ncols=80)
174 | 
175 |     for idx_iter in pbar:
176 |         loss_print = 0
177 |         loss_kl_print = 0
178 |         avg_size = 0
179 | 
180 |         optimizer.zero_grad()
181 |         if adjust_lr:
182 |             lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9)
183 |             writer.add_scalar("lr", lr, idx_iter + epoch * epoch_size)
184 |             adjust_learning_rate(base_lrs, optimizer, iter_stat.iter_curr, iter_stat.iter_max, 0.9)
185 | 
186 |         sample = next(train_loader_iter)
187 |         label = sample['label'].cuda()
188 |         input = sample['data'].cuda()
189 | 
190 |         # compute output
191 |         output, features_new = model(input, output_features=['layer4'], task='new_seg')
192 | 
193 |         # task loss: per-pixel cross-entropy
194 |         loss = CrossEntropyLoss(output, label.long())
195 |         loss_print += loss
196 | 
197 |         # LwF KL-divergence loss against the frozen pretrained model
198 |         if model_old is None:
199 |             loss_kl = 0
200 |         else:
201 |             output_new = model_old.forward_fc(features_new['layer4'], task='old')  # old-task head on the new model's features
202 |             output_old, features_old = model_old(input, output_features=[], task='old')
203 |             loss_kl = KLDivLoss(F.log_softmax(output_new, dim=1), F.softmax(output_old, dim=1)).sum(-1)
204 |             loss_kl_print += loss_kl
205 | 
206 |         (loss + args.lwf * loss_kl).backward()
207 | 
208 |         # update size
209 |         avg_size += input.size(0)
210 | 
211 |         # record losses
212 |         losses.update(loss_print, avg_size)
213 |         losses_kl.update(loss_kl_print, avg_size)
214 | 
215 |         # take an SGD step
216 |         optimizer.step()
217 |         # increment iter number
218 |         iter_stat.update()
219 | 
220 |         writer.add_scalar("loss/ce", losses.val, idx_iter + epoch * epoch_size)
221 |         writer.add_scalar("loss/kl", losses_kl.val, idx_iter + epoch * epoch_size)
222 |         writer.add_scalar("loss/total", losses.val + losses_kl.val, idx_iter + epoch * epoch_size)
223 |         description = "[loss: %.3f][loss_kl: %.3f]"%(losses.val, losses_kl.val)
224 |         pbar.set_description("[Step %d/%d]"%(idx_iter + 1, epoch_size) + description)
225 | 
226 | 
227 | def validate(evaluator, model):
228 |     with torch.no_grad():
229 |         model.eval()
230 |         # _, mIoU = evaluator.run_online()
231 |         _, mIoU = evaluator.run_online_multiprocess()
232 |     return mIoU
233 | 
234 | 
235 | if __name__ == '__main__':
236 |     main()
237 | 
-------------------------------------------------------------------------------- /train_seg.sh: --------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 | 
4 | # [x] train vgg16 with proxy guidance on GTA5
5 | # [x] evaluation on Cityscapes
6 | 
7 | python train_seg.py \
8 | --epochs 50 \
9 | --batch-size 6 \
10 | --lr 1e-3 \
11 | --num-class 19 \
12 | --gpus 0 \
13 | --factor 0.1 \
14 | --lwf 75. \
15 | # --evaluate \
16 | # --resume ./pretrained/vgg16_segmentation_best.pth.tar
17 | 
-------------------------------------------------------------------------------- /utils/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 | 
-------------------------------------------------------------------------------- /utils/logger.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 | 
4 | from pathlib import Path
5 | import importlib.util, warnings  # importlib.util is required for find_spec below
6 | import os, sys, time, numpy as np
7 | import torch, random, PIL, copy
8 | from os import path as osp
9 | from shutil import copyfile
10 | if sys.version_info.major == 2:  # Python 2.x
11 |     from StringIO import StringIO as BIO
12 | else:  # Python 3.x
13 |     from io import BytesIO as BIO
14 | from torch.utils.tensorboard import SummaryWriter
15 | if importlib.util.find_spec('tensorflow'):
16 |     import tensorflow as tf
17 | 
18 | 
19 | def prepare_seed(rand_seed):
20 |     random.seed(rand_seed)
21 |     np.random.seed(rand_seed)
22 |     torch.manual_seed(rand_seed)
23 |     torch.cuda.manual_seed(rand_seed)
24 |     torch.cuda.manual_seed_all(rand_seed)
25 | 
26 | 
27 | def prepare_logger(xargs):
28 |     args = copy.deepcopy( xargs )
29 |     logger = Logger(args.save_dir, args.rand_seed)
30 |     logger.log('Main Function with logger : {:}'.format(logger))
31 |     logger.log('Arguments : -------------------------------')
32 |     for name, value in args._get_kwargs():
33 |         logger.log('{:16} : {:}'.format(name, value))
34 |     logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
35 |     logger.log("Pillow Version : {:}".format(PIL.__version__))
36 |     logger.log("PyTorch Version : {:}".format(torch.__version__))
37 |     logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
38 |     logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
39 |     logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
40 |     logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
41 |     return logger
42 | 
43 | 
44 | class PrintLogger(object):
45 | 
46 |     def __init__(self):
47 |         """A minimal logger that only prints to stdout."""
48 |         self.name = 'PrintLogger'
49 | 
50 |     def log(self, string):
51 |         print (string)
52 | 
53 |     def close(self):
54 |         print ('-'*30 + ' close printer ' + '-'*30)
55 | 
56 | 
57 | class Logger(object):
58 | 
59 |     def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
60 |         """Create a summary writer logging to log_dir."""
61 |         self.seed = int(seed)
62 |         self.log_dir = Path(log_dir)
63 |         self.model_dir = Path(log_dir) / 'model'
64 |         self.log_dir.mkdir (parents=True, exist_ok=True)
65 |         # if create_model_dir:
66 |         #     self.model_dir.mkdir(parents=True, exist_ok=True)
67 |         #self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
68 | 
69 |         self.use_tf = bool(use_tf)
70 |         self.tensorboard_dir = self.log_dir
71 |         #self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) )))
72 |         # self.logger_path = self.log_dir / 'seed-{:}-T-{:}.log'.format(self.seed, time.strftime('%d-%h-at-%H-%M-%S', time.gmtime(time.time())))
73 |         self.logger_path = self.log_dir / 'seed-{:}.log'.format(self.seed)
74 |         self.logger_file = open(self.logger_path, 'w')
75 | 
76 |         self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
77 |         self.writer = SummaryWriter(str(self.tensorboard_dir))
78 | 
79 |     def __repr__(self):
80 |         return ('{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__))
81 | 
82 |     def path(self, mode):
83 |         valids = ('model', 'best', 'info', 'log')
84 |         if mode == 'model': return self.model_dir / 'seed-{:}-basic.pth'.format(self.seed)
85 |         elif mode == 'best' : return self.model_dir / 'seed-{:}-best.pth'.format(self.seed)
86 |         elif mode == 'info' : return self.log_dir / 'seed-{:}-last-info.pth'.format(self.seed)
87 |         elif mode == 'log'  : return self.log_dir
88 |         else: raise TypeError('Unknown mode = {:}, valid modes = {:}'.format(mode, valids))
89 | 
90 |     def extract_log(self):
91 |         return self.logger_file
92 | 
93 |     def close(self):
94 |         self.logger_file.close()
95 |         if self.writer is not None:
96 |             self.writer.close()
97 | 
98 |     def log(self, string, save=True, stdout=False):
99 |         if stdout:
100 |             sys.stdout.write(string); sys.stdout.flush()
101 |         else:
102 |             print (string)
103 |         if save:
104 |             self.logger_file.write('{:}\n'.format(string))
105 |             self.logger_file.flush()
106 | 
107 |     def scalar_summary(self, tags, values, step):
108 |         """Log a scalar variable."""
109 |         if not self.use_tf:
110 |             warnings.warn('use_tf is False but scalar_summary was called')
111 |         else:
112 |             assert isinstance(tags, list) == isinstance(values, list), 'Type : {:} vs {:}'.format(type(tags), type(values))
113 |             if not isinstance(tags, list):
114 |                 tags, values = [tags], [values]
115 |             for tag, value in zip(tags, values):
116 |                 summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
117 |                 self.writer.add_summary(summary, step)
118 |                 self.writer.flush()
119 | 
120 |     def image_summary(self, tag, images, step):
121 |         """Log a list of images."""
122 |         import scipy.misc
123 |         if not self.use_tf:
124 |             warnings.warn('use_tf is False but image_summary was called')
125 |             return
126 | 
127 |         img_summaries = []
128 |         for i, img in enumerate(images):
129 |             # Write the image into an in-memory buffer
130 |             # (BIO is StringIO on Python 2 and BytesIO on Python 3; see the imports above)
131 |             s = BIO()
132 |             # note: scipy.misc.toimage was removed in SciPy 1.2, so this path
133 |             # requires an older SciPy release
134 |             scipy.misc.toimage(img).save(s, format="png")
135 | 
136 |             # Create an Image object
137 |             img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
138 |                                        height=img.shape[0],
139 |                                        width=img.shape[1])
140 |             # Create a Summary value
141 |             img_summaries.append(tf.Summary.Value(tag='{}/{}'.format(tag, i), image=img_sum))
142 | 
143 |         # Create and write Summary
144 |         summary = tf.Summary(value=img_summaries)
145 |         self.writer.add_summary(summary, step)
146 |         self.writer.flush()
147 | 
148 |     def histo_summary(self, tag, values, step, bins=1000):
149 |         """Log a histogram of the tensor of values."""
150 |         if not self.use_tf: raise ValueError('Do not have tensorflow')
151 |         import tensorflow as tf
152 | 
153 |         # Create a histogram using numpy
154 |         counts, bin_edges = np.histogram(values, bins=bins)
155 | 
156 |         # Fill the fields of the histogram proto
157 |         hist = tf.HistogramProto()
158 |         hist.min = float(np.min(values))
159 |         hist.max = float(np.max(values))
160 |         hist.num = int(np.prod(values.shape))
161 |         hist.sum = float(np.sum(values))
162 |         hist.sum_squares = float(np.sum(values**2))
163 | 
164 |         # Drop the start of the first bin
165 |         bin_edges = bin_edges[1:]
166 | 
167 |         # Add bin edges and counts
168 |         for edge in bin_edges:
169 |             hist.bucket_limit.append(edge)
170 |         for c in counts:
171 |             hist.bucket.append(c)
172 | 
173 |         # Create and write Summary
174 |         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
175 |         self.writer.add_summary(summary, step)
176 |         self.writer.flush()
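# --- Added commentary (hedged usage sketch, not part of the original module):
# how Logger is typically driven; the function name, path, and seed are
# illustrative only.
def _logger_usage_example():
    logger = Logger('runs/example', seed=0)         # creates runs/example/seed-0.log
    logger.log('hello')                             # stdout + log file
    logger.writer.add_scalar('loss/ce', 0.5, 1)     # TensorBoard scalar via SummaryWriter
    logger.close()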
-------------------------------------------------------------------------------- /utils/sgd.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 | 
4 | import torch
5 | from torch.optim.optimizer import Optimizer, required
6 | 
7 | 
8 | # fixed SGD
9 | # See Note here: https://pytorch.org/docs/stable/optim.html#torch.optim.SGD
10 | class SGD(Optimizer):
11 |     r"""Implements stochastic gradient descent (optionally with momentum).
12 | 
13 |     Nesterov momentum is based on the formula from
14 |     `On the importance of initialization and momentum in deep learning`__.
15 | 
16 |     Args:
17 |         params (iterable): iterable of parameters to optimize or dicts defining
18 |             parameter groups
19 |         lr (float): learning rate
20 |         momentum (float, optional): momentum factor (default: 0)
21 |         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
22 |         dampening (float, optional): dampening for momentum (default: 0)
23 |         nesterov (bool, optional): enables Nesterov momentum (default: False)
24 | 
25 |     Example:
26 |         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
27 |         >>> optimizer.zero_grad()
28 |         >>> loss_fn(model(input), target).backward()
29 |         >>> optimizer.step()
30 | 
31 |     __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
32 | 
33 |     .. note::
34 |         The implementation of SGD with Momentum/Nesterov subtly differs from
35 |         Sutskever et. al. and implementations in some other frameworks.
36 | 
37 |         Considering the specific case of Momentum, the update can be written as
38 | 
39 |         .. math::
40 |                   v = \rho * v + g \\
41 |                   p = p - lr * v
42 | 
43 |         where p, g, v and :math:`\rho` denote the parameters, gradient,
44 |         velocity, and momentum respectively.
45 | 
46 |         This is in contrast to Sutskever et. al. and
47 |         other frameworks which employ an update of the form
48 | 
49 |         .. math::
50 |                   v = \rho * v + lr * g \\
51 |                   p = p - v
52 | 
53 |         The Nesterov version is analogously modified.
54 |     """
55 | 
56 |     def __init__(self, params, lr=required, momentum=0, dampening=0,
57 |                  weight_decay=0, nesterov=False):
58 |         if lr is not required and lr < 0.0:
59 |             raise ValueError("Invalid learning rate: {}".format(lr))
60 |         if momentum < 0.0:
61 |             raise ValueError("Invalid momentum value: {}".format(momentum))
62 |         if weight_decay < 0.0:
63 |             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
64 | 
65 |         defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
66 |                         weight_decay=weight_decay, nesterov=nesterov)
67 |         if nesterov and (momentum <= 0 or dampening != 0):
68 |             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
69 |         super(SGD, self).__init__(params, defaults)
70 | 
71 |     def __setstate__(self, state):
72 |         super(SGD, self).__setstate__(state)
73 |         for group in self.param_groups:
74 |             group.setdefault('nesterov', False)
75 | 
76 |     def step(self, closure=None):
77 |         """Performs a single optimization step.
78 | 
79 |         Arguments:
80 |             closure (callable, optional): A closure that reevaluates the model
81 |                 and returns the loss.
82 |         """
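        # --- Added commentary (not part of the original source): unlike
        # torch.optim.SGD, which buffers a raw-gradient velocity
        # (v = rho*v + g; p -= lr*v), this "fixed" variant below folds the
        # current learning rate into the buffered velocity
        # (v = rho*v + (1-dampening)*lr*g; p -= v), matching the Sutskever
        # et al. form cited in the class docstring. With the per-iteration
        # poly LR schedule used in this repo, the decayed lr then only scales
        # new gradients instead of rescaling the accumulated momentum.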
82 | """ 83 | loss = None 84 | if closure is not None: 85 | loss = closure() 86 | 87 | for group in self.param_groups: 88 | weight_decay = group['weight_decay'] 89 | momentum = group['momentum'] 90 | dampening = group['dampening'] 91 | nesterov = group['nesterov'] 92 | 93 | for p in group['params']: 94 | if p.grad is None: 95 | continue 96 | d_p = p.grad.data 97 | if weight_decay != 0: 98 | d_p.add_(weight_decay, p.data) 99 | if momentum != 0: 100 | param_state = self.state[p] 101 | if 'momentum_buffer' not in param_state: 102 | # buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 103 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach().mul_(group['lr']) 104 | else: 105 | buf = param_state['momentum_buffer'] 106 | # buf.mul_(momentum).add_(1 - dampening, d_p) 107 | buf.mul_(momentum).add_(1 - dampening, d_p.mul_(group['lr'])) 108 | if nesterov: 109 | d_p = d_p.add(momentum, buf) 110 | else: 111 | d_p = buf 112 | 113 | # p.data.add_(-group['lr'], d_p) 114 | p.data.add_(-1, d_p) 115 | 116 | return loss 117 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved. 2 | # This work is licensed under a NVIDIA Open Source Non-commercial license. 3 | 4 | import glob 5 | import os 6 | import shutil 7 | import numpy as np 8 | import torch 9 | from pdb import set_trace as bp 10 | 11 | 12 | def get_params(model, layers=["layer4"]): 13 | """ 14 | This generator returns all the parameters of the net except for 15 | the last classification layer. Note that for each batchnorm layer, 16 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 17 | any batchnorm parameter 18 | """ 19 | if isinstance(layers, str): 20 | layers = [layers] 21 | b = [] 22 | for layer in layers: 23 | b.append(getattr(model, layer)) 24 | 25 | for i in range(len(b)): 26 | for k, v in b[i].named_parameters(): 27 | if v.requires_grad: 28 | yield v 29 | 30 | 31 | def adjust_learning_rate(base_lrs, optimizer, iter_curr, iter_max, power): 32 | """Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs""" 33 | num_groups = len(optimizer.param_groups) 34 | for g in range(num_groups): 35 | optimizer.param_groups[g]['lr'] = lr_poly(base_lrs[g], iter_curr, iter_max, power) 36 | 37 | 38 | def lr_poly(base_lr, iter, max_iter, power): 39 | # return min(0.01+0.99*(float(iter)/100)**2.0, 1.0) * base_lr * ((1-float(iter)/max_iter)**power) # This is with warm up 40 | return min(0.01+0.99*(float(iter)/100)**2.0, 1.0) * base_lr * ((1-min(float(iter)/max_iter, 0.8))**power) # This is with warm up & no smaller than last 20% LR 41 | # return base_lr * ((1-float(iter)/max_iter)**power) 42 | 43 | 44 | def save_checkpoint(name, state, is_best, filename='checkpoint.pth.tar', keep_last=1): 45 | """Saves checkpoint to disk""" 46 | directory = name 47 | if not os.path.exists(directory): 48 | os.makedirs(directory) 49 | models_paths = list(filter(os.path.isfile, glob.glob(directory + "/epoch*.pth.tar"))) 50 | models_paths.sort(key=os.path.getmtime, reverse=False) 51 | if len(models_paths) == keep_last: 52 | for i in range(len(models_paths) + 1 - keep_last): 53 | os.remove(models_paths[i]) 54 | # filename = directory + '/epoch_'+str(state['epoch']) + '_' + filename 55 | filename = directory + '/latest_' + filename 56 | torch.save(state, filename) 57 | if is_best: 58 | 
58 |         shutil.copyfile(filename, '%s/'%(name) + 'model_best.pth.tar')
59 | 
60 | 
61 | class IterNums(object):
62 |     def __init__(self, iter_max):
63 |         self.iter_max = iter_max
64 |         self.iter_curr = 0
65 | 
66 |     def reset(self):
67 |         self.iter_curr = 0
68 | 
69 |     def update(self):
70 |         self.iter_curr += 1
71 | 
72 | 
73 | class AverageMeter(object):
74 |     """Computes and stores the average and current value"""
75 |     def __init__(self):
76 |         self.reset()
77 | 
78 |     def reset(self):
79 |         self.val = 0
80 |         self.avg = 0
81 |         self.sum = 0
82 |         self.count = 0
83 |         self.vec2sca_avg = 0
84 |         self.vec2sca_val = 0
85 | 
86 |     def update(self, val, n=1):
87 |         self.val = val
88 |         self.sum += val * n
89 |         self.count += n
90 |         self.avg = self.sum / self.count
91 |         if torch.is_tensor(self.val) and torch.numel(self.val) != 1:
92 |             self.avg[self.count == 0] = 0
93 |             self.vec2sca_avg = self.avg.sum() / len(self.avg)
94 |             self.vec2sca_val = self.val.sum() / len(self.val)
95 | 
96 | 
97 | class ROC(object):
98 |     def __init__(self, num_class):
99 |         self.num_class = num_class
100 |         self.pred = []
101 |         self.label = []
102 | 
103 |     def update(self, pred, label):
104 |         assert (self.num_class == pred.shape[0]), "num_class mismatch on input predictions!"
105 |         assert (self.num_class == label.shape[0]), "num_class mismatch on input labels!"
106 |         self.pred.append(pred)
107 |         self.label.append(label)
108 | 
109 |     def roc_curve(self):
110 |         pred = np.hstack(self.pred)
111 |         label = np.hstack(self.label)
112 |         p = label == 1
113 |         n = ~p
114 |         num_p = np.sum(p, axis=1)
115 |         num_n = np.sum(n, axis=1)
116 |         tpr = np.zeros((self.num_class, 101), np.float32)
117 |         fpr = np.zeros((self.num_class, 101), np.float32)
118 |         for idx in range(101):
119 |             thre = 1 - idx/100.0
120 |             pp = pred > thre
121 |             tp = pp & p
122 |             fp = pp & n
123 |             num_tp = np.sum(tp, axis=1)
124 |             num_fp = np.sum(fp, axis=1)
125 |             tpr[:, idx] = num_tp/(num_p + (num_p == 0))
126 |             fpr[:, idx] = num_fp/(num_n + (num_n == 0))
127 |         return tpr, fpr
128 | 
129 |     def auc(self, tpr, fpr):
130 |         assert(tpr.shape[0] == fpr.shape[0])
131 |         auc = np.zeros(tpr.shape[0], np.float32)
132 |         for idx in range(tpr.shape[0]):
133 |             auc[idx] = metrics.auc(fpr[idx, :], tpr[idx, :])
134 |         return auc
135 | 
136 | 
137 | def accuracy(output, label, num_class, topk=(1,)):
138 |     """Computes the precision@k for the specified values of k, currently only k=1 is supported"""
139 |     maxk = max(topk)
140 | 
141 |     _, pred = output.topk(maxk, 1, True, True)
142 |     if len(label.size()) == 2:
143 |         # one_hot label
144 |         _, gt = label.topk(maxk, 1, True, True)
145 |     else:
146 |         gt = label
147 |     pred = pred.t()
148 |     pred_class_idx_list = [pred == class_idx for class_idx in range(num_class)]
149 |     gt = gt.t()
150 |     gt_class_number_list = [(gt == class_idx).sum() for class_idx in range(num_class)]
151 |     correct = pred.eq(gt)
152 | 
153 |     res = []
154 |     gt_num = []
155 |     for k in topk:
156 |         correct_k = correct[:k].float()
157 |         per_class_correct_list = [correct_k[pred_class_idx].sum(0) for pred_class_idx in pred_class_idx_list]
158 |         per_class_correct_array = torch.tensor(per_class_correct_list)
159 |         gt_class_number_tensor = torch.tensor(gt_class_number_list).float()
160 |         gt_class_zeronumber_tensor = gt_class_number_tensor == 0
161 |         gt_class_number_matrix = torch.tensor(gt_class_number_list).float()
162 |         gt_class_acc = per_class_correct_array.mul_(100.0 / gt_class_number_matrix)
163 |         gt_class_acc[gt_class_zeronumber_tensor] = 0
164 |         res.append(gt_class_acc)
165 |         gt_num.append(gt_class_number_matrix)
166 |     return res, gt_num
167 | 
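# --- Added commentary (hedged usage sketch, not part of the original module):
# how validate() in train.py combines accuracy() with a vector-valued
# AverageMeter to track per-class and mean accuracy; the function name,
# shapes, and values here are illustrative only.
def _per_class_accuracy_example():
    output = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  # 3 samples, 2 classes
    label = torch.tensor([0, 1, 1])
    meter = AverageMeter()
    prec1, gt_num = accuracy(output, label, num_class=2, topk=(1,))
    meter.update(prec1[0], gt_num[0])       # vector update: per-class accuracies weighted by class counts
    return meter.avg, meter.vec2sca_avg     # tensor([100., 50.]) and 75.0 for this toy input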
--------------------------------------------------------------------------------