├── .gitignore
├── LICENSE.md
├── README.md
├── config_seg.py
├── data
│   ├── BaseDataset.py
│   ├── cityscapes.py
│   ├── cityscapes
│   │   ├── cityscapes_test.txt
│   │   ├── cityscapes_val_fine.txt
│   │   └── cityscapes_val_fine_raw.txt
│   ├── gta5.py
│   ├── gta5
│   │   ├── gta5_train.txt
│   │   └── gta5_train_raw.txt
│   └── visda17.py
├── dataloader_seg.py
├── eval_seg.py
├── l2o_train.py
├── l2o_train.sh
├── l2o_train_seg.py
├── l2o_train_seg.sh
├── model
│   ├── __init__.py
│   ├── fcn8s_vgg.py
│   ├── resnet.py
│   └── vgg.py
├── reinforce
│   ├── __init__.py
│   ├── algo
│   │   ├── __init__.py
│   │   └── reinforce.py
│   ├── arguments.py
│   ├── distributions.py
│   ├── models
│   │   ├── policy.py
│   │   └── rnn_state_encoder.py
│   ├── storage.py
│   └── utils.py
├── tools
│   ├── datasets
│   │   ├── BaseDataset.py
│   │   └── cityscapes
│   │       ├── cityscapes.py
│   │       ├── cityscapes_test.txt
│   │       ├── cityscapes_train_fine.txt
│   │       └── cityscapes_val_fine.txt
│   ├── engine
│   │   ├── evaluator.py
│   │   ├── logger.py
│   │   └── tester.py
│   ├── seg_opr
│   │   └── metric.py
│   └── utils
│       ├── img_utils.py
│       ├── pyt_utils.py
│       └── visualize.py
├── train.py
├── train.sh
├── train_seg.py
├── train_seg.sh
└── utils
    ├── __init__.py
    ├── logger.py
    ├── sgd.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # vim swp files
2 | *.swp
3 | # caffe/pytorch model files
4 | runs/*
5 | pretrained/*
6 | crst_visda/runs/*
7 | *.pth
8 | *.tar
9 | *_softmax.txt
10 | *_seed*.txt
11 | *_soft*.txt
12 | *.json
13 |
14 | # Mkdocs
15 | # /docs/
16 | /mkdocs/docs/temp
17 |
18 | .DS_Store
19 | .idea
20 | .vscode
21 | .pytest_cache
22 | /experiments
23 | node_modules/
24 | history/
25 | ablation/
26 | misc/
27 | prediction/
28 | results/
29 |
30 | # resource temp folder
31 | tests/resources/temp/*
32 | !tests/resources/temp/.gitkeep
33 |
34 | # Byte-compiled / optimized / DLL files
35 | __pycache__/
36 | *.py[cod]
37 | *$py.class
38 |
39 | # C extensions
40 | *.so
41 |
42 | # Distribution / packaging
43 | .Python
44 | build/
45 | develop-eggs/
46 | dist/
47 | downloads/
48 | eggs/
49 | .eggs/
50 | lib/
51 | lib64/
52 | parts/
53 | sdist/
54 | var/
55 | wheels/
56 | *.egg-info/
57 | .installed.cfg
58 | *.egg
59 | MANIFEST
60 |
61 | # PyInstaller
62 | # Usually these files are written by a python script from a template
63 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
64 | *.manifest
65 | *.spec
66 |
67 | # Installer logs
68 | pip-log.txt
69 | pip-delete-this-directory.txt
70 |
71 | # Unit test / coverage reports
72 | htmlcov/
73 | .tox/
74 | .coverage
75 | .coverage.*
76 | .cache
77 | nosetests.xml
78 | coverage.xml
79 | *.cover
80 | .hypothesis/
81 | .pytest_cache/
82 |
83 | # Translations
84 | *.mo
85 | *.pot
86 |
87 | # Django stuff:
88 | *.log
89 | .static_storage/
90 | .media/
91 | local_settings.py
92 | local_settings.py
93 | db.sqlite3
94 |
95 | # Flask stuff:
96 | instance/
97 | .webassets-cache
98 |
99 | # Scrapy stuff:
100 | .scrapy
101 |
102 | # Sphinx documentation
103 | docs/_build/
104 |
105 | # PyBuilder
106 | target/
107 |
108 | # Jupyter Notebook
109 | .ipynb_checkpoints
110 |
111 | # pyenv
112 | .python-version
113 |
114 | # celery beat schedule file
115 | celerybeat-schedule
116 |
117 | # SageMath parsed files
118 | *.sage.py
119 |
120 | # Environments
121 | .env
122 | .venv
123 | env/
124 | venv/
125 | ENV/
126 | env.bak/
127 | venv.bak/
128 |
129 | # Spyder project settings
130 | .spyderproject
131 | .spyproject
132 |
133 | # Rope project settings
134 | .ropeproject
135 |
136 | # mkdocs documentation
137 | /site
138 |
139 | # mypy
140 | .mypy_cache/
141 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | NVIDIA Source Code
2 |
3 | License for Automated Synthetic-to-real Generalization (ASG)
4 | ---
5 |
6 | 1. Definitions
7 |
8 | "Licensor" means any person or entity that distributes its Work.
9 |
10 | "Software" means the original work of authorship made available under this License.
11 |
12 | "Work" means the Software and any additions to or derivative works of the Software that are made available under this License.
13 |
14 | The terms "reproduce," "reproduction," "derivative works," and "distribution" have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
15 |
16 | Works, including the Software, are "made available" under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.
17 |
18 | 2. License Grant
19 |
20 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
21 |
22 | 3. Limitations
23 |
24 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
25 |
26 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work ("Your Terms") only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
27 |
28 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, "non-commercially" means for research or evaluation purposes only.
29 |
30 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately.
31 |
32 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.
33 |
34 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately.
35 |
36 | 4. Disclaimer of Warranty.
37 |
38 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
39 |
40 | 5. Limitation of Liability.
41 |
42 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
43 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # ASG: Automated Synthetic-to-Real Generalization
4 |
5 |
6 | [Paper](https://arxiv.org/abs/2007.06965)
7 |
8 | Automated Synthetic-to-Real Generalization.
9 | [Wuyang Chen](https://chenwydj.github.io/), [Zhiding Yu](https://chrisding.github.io/), [Zhangyang Wang](https://www.atlaswang.com/), [Anima Anandkumar](http://tensorlab.cms.caltech.edu/users/anima/).
10 | In ICML 2020.
11 |
12 | * Visda-17 to COCO
13 | - [x] train resnet101 with only proxy guidance
14 | - [x] train resnet101 with both proxy guidance and L2O policy
15 | - [x] evaluation
16 | * GTA5 to Cityscapes
17 | - [x] train vgg16 with only proxy guidance
18 | - [x] train vgg16 with both proxy guidance and L2O policy
19 | - [x] evaluation
20 |
21 | ## Usage
22 |
23 | ### Visda-17
24 | * Download [Visda-17 Dataset](http://ai.bu.edu/visda-2017/#download)
25 |
26 | #### Evaluation
27 | * Download [pretrained ResNet101 on Visda17](https://drive.google.com/file/d/1jjihDIxU1HIRtJEZyd7eTpYfO21OrY36/view?usp=sharing)
28 | * Put the checkpoint under `./ASG/pretrained/`
29 | * Put the code below in `train.sh`
30 | ```bash
31 | python train.py \
32 | --epochs 30 \
33 | --batch-size 32 \
34 | --lr 1e-4 \
35 | --lwf 0.1 \
36 | --resume pretrained/res101_vista17_best.pth.tar \
37 | --evaluate
38 | ```
39 | * Run `CUDA_VISIBLE_DEVICES=0 bash train.sh`
40 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` as needed.
41 |
42 | #### Train with SGD
43 | * Put the code below in `train.sh`
44 | ```bash
45 | python train.py \
46 | --epochs 30 \
47 | --batch-size 32 \
48 | --lr 1e-4 \
49 | --lwf 0.1
50 | ```
51 | * Run `CUDA_VISIBLE_DEVICES=0 bash train.sh`
52 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` as needed.
53 |
54 | #### Train with L2O
55 | * Download [pretrained L2O Policy on Visda17](https://drive.google.com/file/d/1Rc2Ey-FspUagFPTjnEozeSEIdA4ir7b1/view?usp=sharing)
56 | * Put the checkpoint under `./ASG/pretrained/`
57 | * Put the code below in `l2o_train.sh`
58 | ```bash
59 | python l2o_train.py \
60 | --epochs 30 \
61 | --batch-size 32 \
62 | --lr 1e-4 \
63 | --lwf 0.1 \
64 | --agent_load_dir ./ASG/pretrained/policy_res101_vista17.pth
65 | ```
66 | * Run `CUDA_VISIBLE_DEVICES=0 bash l2o_train.sh`
67 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` as needed.
68 |
69 | ### GTA5 → Cityscapes
70 | * Download [GTA5 dataset](https://download.visinf.tu-darmstadt.de/data/from_games/).
71 | * Download [leftImg8bit_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=3) and [gtFine_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=1) from the Cityscapes website.
72 | * Prepare the annotations with [createTrainIdLabelImgs.py](https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/createTrainIdLabelImgs.py).
73 | * Put the [image list files](tools/datasets/cityscapes/) into the directory where you save the dataset.
74 | * **Remember to set `C.dataset_path` in `config_seg.py` to the directory where the datasets reside** (a sketch of the expected setup follows this list).
75 |
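A minimal sketch of the annotation step and the expected layout. The paths are placeholders for your local setup; `cityscapesscripts` locates the dataset via the `CITYSCAPES_DATASET` environment variable:
```bash
pip install cityscapesscripts
export CITYSCAPES_DATASET=/path/to/cityscapes   # placeholder path
python -m cityscapesscripts.preparation.createTrainIdLabelImgs

# Expected layout under C.dataset_path (derived in config_seg.py):
#   <dataset_path>/gta5/gta5_train.txt
#   <dataset_path>/cityscapes/cityscapes_train_fine.txt
#   <dataset_path>/cityscapes/cityscapes_val_fine.txt
#   <dataset_path>/cityscapes/cityscapes_test.txt
```
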
76 | #### Evaluation
77 | * Download [pretrained Vgg16 on GTA5](https://drive.google.com/file/d/13HcsiyL-o1A9057ezJ4qCnGztnY5deQ6/view?usp=sharing)
78 | * Put the checkpoint under `./ASG/pretrained/`
79 | * Put the code below in `train_seg.sh`
80 | ```bash
81 | python train_seg.py \
82 | --epochs 50 \
83 | --batch-size 6 \
84 | --lr 1e-3 \
85 | --num-class 19 \
86 | --gpus 0 \
87 | --factor 0.1 \
88 | --lwf 75. \
89 | --evaluate \
90 | --resume ./pretrained/vgg16_segmentation_best.pth.tar
91 | ```
92 | * Run `CUDA_VISIBLE_DEVICES=0 bash train_seg.sh`
93 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` as needed.
94 |
95 | #### Train with SGD
96 | * Put the code below in `train_seg.sh`
97 | ```bash
98 | python train_seg.py \
99 | --epochs 50 \
100 | --batch-size 6 \
101 | --lr 1e-3 \
102 | --num-class 19 \
103 | --gpus 0 \
104 | --factor 0.1 \
105 | --lwf 75.
106 | ```
107 | * Run `CUDA_VISIBLE_DEVICES=0 bash train_seg.sh`
108 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` as needed.
109 |
110 | #### Train with L2O
111 | * Download [pretrained L2O Policy on GTA5](https://drive.google.com/file/d/1RVQE0VxrtPCyUpsvNulpKKBQhYlOi1ag/view?usp=sharing)
112 | * Put the checkpoint under `./ASG/pretrained/`
113 | * Put the code below in `l2o_train_seg.sh`
114 | ```bash
115 | python l2o_train_seg.py \
116 | --epochs 50 \
117 | --batch-size 6 \
118 | --lr 1e-3 \
119 | --num-class 19 \
120 | --gpus 0 \
121 | --gamma 0 \
122 | --early-stop 2 \
123 | --lwf 75. \
124 | --algo reinforce \
125 | --agent_load_dir ./ASG/pretrained/policy_vgg16_segmentation.pth
126 | ```
127 | * Run `CUDA_VISIBLE_DEVICES=0 bash l2o_train_seg.sh`
128 | - Please update the GPU index via `CUDA_VISIBLE_DEVICES` as needed.
129 |
130 | ## Citation
131 |
132 | If you use this code for your research, please cite:
133 |
134 | ```BibTeX
135 | @inproceedings{chen2020automated,
136 | author = {Chen, Wuyang and Yu, Zhiding and Wang, Zhangyang and Anandkumar, Anima},
137 | booktitle = {Proceedings of Machine Learning and Systems 2020},
138 | pages = {8272--8282},
139 | title = {Automated Synthetic-to-Real Generalization},
140 | year = {2020}
141 | }
142 | ```
143 |
--------------------------------------------------------------------------------
/config_seg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | # encoding: utf-8
5 |
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import os
11 | import os.path as osp
12 | import sys
13 | import numpy as np
14 | from easydict import EasyDict as edict
15 |
16 | C = edict()
17 | config = C
18 | cfg = C
19 |
20 | """please config ROOT_dir and user when u first using"""
21 | C.repo_name = 'ASG'
22 | C.abs_dir = osp.realpath(".")
23 | C.this_dir = C.abs_dir.split(osp.sep)[-1]
24 | C.root_dir = C.abs_dir[:C.abs_dir.index(C.repo_name) + len(C.repo_name)]
25 |
26 |
27 | """Data Dir"""
28 | C.dataset_path = "/home/chenwy/"
29 |
30 | C.train_img_root = os.path.join(C.dataset_path, "gta5")
31 | C.train_gt_root = os.path.join(C.dataset_path, "gta5")
32 | C.val_img_root = os.path.join(C.dataset_path, "cityscapes")
33 | C.val_gt_root = os.path.join(C.dataset_path, "cityscapes")
34 | C.test_img_root = os.path.join(C.dataset_path, "cityscapes")
35 | C.test_gt_root = os.path.join(C.dataset_path, "cityscapes")
36 |
37 | C.train_source = osp.join(C.train_img_root, "gta5_train.txt")
38 | C.train_target_source = osp.join(C.train_img_root, "cityscapes_train_fine.txt")
39 | C.eval_source = osp.join(C.val_img_root, "cityscapes_val_fine.txt")
40 | C.test_source = osp.join(C.test_img_root, "cityscapes_test.txt")
41 |
42 | """Image Config"""
43 | C.num_classes = 19
44 | C.background = -1
45 | C.image_mean = np.array([0.485, 0.456, 0.406])
46 | C.image_std = np.array([0.229, 0.224, 0.225])
47 | C.down_sampling_train = [1, 1] # first down_sampling then crop
48 | C.down_sampling_val = [1, 1] # first down_sampling then crop
49 | C.gt_down_sampling = 1
50 | C.num_train_imgs = 12403
51 | C.num_eval_imgs = 500
52 |
53 | """ Settings for network, this would be different for each kind of model"""
54 | C.bn_eps = 1e-5
55 | C.bn_momentum = 0.1
56 |
57 | """Train Config"""
58 | C.lr = 0.01
59 | C.momentum = 0.9
60 | C.weight_decay = 5e-4
61 | C.nepochs = 30
62 | C.niters_per_epoch = 2000
63 | C.num_workers = 16
64 | C.train_scale_array = [0.75, 1, 1.25]
65 | # C.train_scale_array = [1]
66 |
67 | """Eval Config"""
68 | C.eval_stride_rate = 5 / 6
69 | C.eval_scale_array = [1]
70 | C.eval_flip = True
71 | C.eval_base_size = 1024
72 | C.eval_crop_size = 1024
73 | C.eval_height = 1024
74 | C.eval_width = 2048
75 |
76 | # GTA5: 1052x1914
77 | C.image_height = 512
78 | C.image_width = 512
79 | C.is_test = False # if True, prediction files for the test set will be generated
80 | C.is_eval = False # if True, train.py will only run evaluation once
81 |
--------------------------------------------------------------------------------
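Note that the derived paths above (`C.train_source`, `C.eval_source`, ...) are joined once at import time, so `C.dataset_path` must be edited in `config_seg.py` itself rather than reassigned afterwards. A minimal sketch of this pitfall (the `/data` path is a placeholder):

```python
# Minimal sketch (run from the repo root, since root_dir is parsed from the cwd).
from config_seg import config

print(config.num_classes)   # 19
print(config.train_source)  # "<dataset_path>/gta5/gta5_train.txt", already joined

# Reassigning config.dataset_path now does NOT update config.train_source:
# the os.path.join calls ran when the module was imported.
config.dataset_path = "/data"   # placeholder path; *_source stays unchanged
print(config.train_source)
```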
/data/BaseDataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | import cv2
6 | import torch
7 | import numpy as np
8 | from pdb import set_trace as bp
9 | import torch.utils.data as data
10 | cv2.setNumThreads(0)
11 |
12 |
13 | class BaseDataset(data.Dataset):
14 | def __init__(self, setting, split_name, preprocess=None, file_length=None):
15 | super(BaseDataset, self).__init__()
16 | self._split_name = split_name
17 | if split_name == 'train':
18 | self._img_path = setting['train_img_root']
19 | self._gt_path = setting['train_gt_root']
20 | elif split_name == 'val':
21 | self._img_path = setting['val_img_root']
22 | self._gt_path = setting['val_gt_root']
23 | elif split_name == 'test':
24 | self._img_path = setting['test_img_root']
25 | self._gt_path = setting['test_gt_root']
26 | self._train_source = setting['train_source']
27 | self._eval_source = setting['eval_source']
28 | self._test_source = setting['test_source'] if 'test_source' in setting else setting['eval_source']
29 | self._down_sampling = setting['down_sampling_train'] if split_name == 'train' else setting['down_sampling_val']
30 | print("using downsampling:", self._down_sampling)
31 | self._file_names = self._get_file_names(split_name)
32 | print("Found %d images"%len(self._file_names))
33 | self._file_length = file_length
34 | self.preprocess = preprocess
35 |
36 | def __len__(self):
37 | if self._file_length is not None:
38 | return self._file_length
39 | return len(self._file_names)
40 |
41 | def __getitem__(self, index):
42 | if self._file_length is not None:
43 | names = self._construct_new_file_names(self._file_length)[index]
44 | else:
45 | names = self._file_names[index]
46 | img_path = os.path.join(self._img_path, names[0])
47 | gt_path = os.path.join(self._gt_path, names[1])
48 | item_name = names[1].split("/")[-1].split(".")[0]
49 | img, gt = self._fetch_data(img_path, gt_path)
50 | img = img[:, :, ::-1]
51 | if self.preprocess is not None:
52 | img, gt, extra_dict = self.preprocess(img, gt)
53 |
54 | if self._split_name == 'train':
55 | img = torch.from_numpy(np.ascontiguousarray(img)).float()
56 | gt = torch.from_numpy(np.ascontiguousarray(gt)).long()
57 | if self.preprocess is not None and extra_dict is not None:
58 | for k, v in extra_dict.items():
59 | extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v))
60 | if 'label' in k:
61 | extra_dict[k] = extra_dict[k].long()
62 | if 'img' in k:
63 | extra_dict[k] = extra_dict[k].float()
64 |
65 | output_dict = dict(data=img, label=gt, fn=str(item_name), n=len(self._file_names))
66 | if self.preprocess is not None and extra_dict is not None:
67 | output_dict.update(**extra_dict)
68 |
69 | return output_dict
70 |
71 | def _fetch_data(self, img_path, gt_path, dtype=None):
72 | img = self._open_image(img_path, down_sampling=self._down_sampling[0])
73 | gt = self._open_image(gt_path, cv2.IMREAD_GRAYSCALE, dtype=dtype, down_sampling=self._down_sampling[1])
74 |
75 | return img, gt
76 |
77 | def _get_file_names(self, split_name):
78 | assert split_name in ['train', 'val', 'test']
79 | source = self._train_source
80 | if split_name == "val":
81 | source = self._eval_source
82 | elif split_name == 'test':
83 | source = self._test_source
84 |
85 | file_names = []
86 | with open(source) as f:
87 | files = f.readlines()
88 |
89 | for item in files:
90 | img_name, gt_name = self._process_item_names(item)
91 | file_names.append([img_name, gt_name])
92 |
93 | return file_names
94 |
95 | def _construct_new_file_names(self, length):
96 | assert isinstance(length, int)
97 | files_len = len(self._file_names)
98 | new_file_names = self._file_names * (length // files_len)
99 |
100 | rand_indices = torch.randperm(files_len).tolist()
101 | new_indices = rand_indices[:length % files_len]
102 |
103 | new_file_names += [self._file_names[i] for i in new_indices]
104 |
105 | return new_file_names
106 |
107 | @staticmethod
108 | def _process_item_names(item):
109 | item = item.strip()
110 | # item = item.split('\t')
111 | item = item.split(' ')
112 | img_name = item[0]
113 | gt_name = item[1]
114 |
115 | return img_name, gt_name
116 |
117 | def get_length(self):
118 | return self.__len__()
119 |
120 | @staticmethod
121 | def _open_image(filepath, mode=cv2.IMREAD_COLOR, dtype=None, down_sampling=1):
122 | # cv2: B G R
123 | # h w c
124 | img = np.array(cv2.imread(filepath, mode), dtype=dtype)
125 | if isinstance(down_sampling, int):
126 | try:
127 | H, W = img.shape[:2]
128 | except Exception:  # cv2.imread failed, so img has no valid 2D shape
129 | print(img.shape, filepath)
130 | exit(0)
131 | if len(img.shape) == 3:
132 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_LINEAR)
133 | else:
134 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_NEAREST)
135 | assert img.shape[0] == H // down_sampling and img.shape[1] == W // down_sampling
136 | else:
137 | assert (isinstance(down_sampling, tuple) or isinstance(down_sampling, list)) and len(down_sampling) == 2
138 | if len(img.shape) == 3:
139 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_LINEAR)
140 | else:
141 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_NEAREST)
142 | assert img.shape[0] == down_sampling[0] and img.shape[1] == down_sampling[1]
143 |
144 | return img
145 |
146 | @classmethod
147 | def get_class_colors(*args):
148 | raise NotImplementedError
149 |
150 | @classmethod
151 | def get_class_names(*args):
152 | raise NotImplementedError
153 |
154 |
155 | if __name__ == "__main__":
156 | data_setting = {'img_root': '',
157 | 'gt_root': '',
158 | 'train_source': '',
159 | 'eval_source': ''}
160 | bd = BaseDataset(data_setting, 'train', None)
161 | print(bd.get_class_names())
162 |
--------------------------------------------------------------------------------
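`_construct_new_file_names` above pads an epoch to an arbitrary length: the file list is repeated `length // len(files)` times, then topped up with `length % len(files)` randomly drawn entries. A toy illustration of the same arithmetic (not part of the repo):

```python
import torch

files = ["a", "b", "c"]                       # stand-in for self._file_names
length = 7                                    # requested epoch length

new_files = files * (length // len(files))    # two full copies -> 6 entries
rand_indices = torch.randperm(len(files)).tolist()
new_files += [files[i] for i in rand_indices[: length % len(files)]]  # +1 random
assert len(new_files) == length               # always exactly `length` samples
```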
/data/cityscapes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import numpy as np
5 | from data.BaseDataset import BaseDataset
6 |
7 |
8 | class Cityscapes(BaseDataset):
9 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
10 | 28, 31, 32, 33]
11 |
12 | @classmethod
13 | def get_class_colors(*args):
14 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70],
15 | [102, 102, 156], [190, 153, 153], [153, 153, 153],
16 | [250, 170, 30], [220, 220, 0], [107, 142, 35],
17 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
18 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
19 | [0, 0, 230], [119, 11, 32]]
20 |
21 | @classmethod
22 | def get_class_names(*args):
23 | # class counting(gtFine)
24 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832
25 | # 359 274 142 513 1646
26 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
27 | 'traffic light', 'traffic sign',
28 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
29 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle']
30 |
31 | @classmethod
32 | def transform_label(cls, pred, name):
33 | label = np.zeros(pred.shape)
34 | ids = np.unique(pred)
35 | for id in ids:
36 | label[np.where(pred == id)] = cls.trans_labels[id]
37 |
38 | new_name = (name.split('.')[0]).split('_')[:-1]
39 | new_name = '_'.join(new_name) + '.png'
40 |
41 | print('Trans', name, 'to', new_name, ' ',
42 | np.unique(np.array(pred, np.uint8)), ' ---------> ',
43 | np.unique(np.array(label, np.uint8)))
44 | return label, new_name
45 |
--------------------------------------------------------------------------------
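`transform_label` maps the 19 contiguous train ids (0-18) back to the official Cityscapes label ids through `trans_labels`, e.g. train id 0 ('road') becomes label id 7 and train id 18 ('bicycle') becomes 33. A small standalone sketch of the same lookup:

```python
import numpy as np

trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
                28, 31, 32, 33]

pred = np.array([[0, 18], [13, 0]])      # toy prediction in train-id space
label = np.zeros(pred.shape)
for i in np.unique(pred):
    label[pred == i] = trans_labels[i]
print(label)                             # [[ 7. 33.] [26.  7.]] -- official ids
```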
/data/gta5.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.)
3 |
4 | import numpy as np
5 | from data.BaseDataset import BaseDataset
6 |
7 |
8 | class GTA5(BaseDataset):
9 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
10 | 28, 31, 32, 33]
11 |
12 | @classmethod
13 | def get_class_colors(*args):
14 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70],
15 | [102, 102, 156], [190, 153, 153], [153, 153, 153],
16 | [250, 170, 30], [220, 220, 0], [107, 142, 35],
17 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
18 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
19 | [0, 0, 230], [119, 11, 32]]
20 |
21 | @classmethod
22 | def get_class_names(*args):
23 | # class counting(gtFine)
24 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832
25 | # 359 274 142 513 1646
26 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
27 | 'traffic light', 'traffic sign',
28 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
29 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle']
30 |
31 | @classmethod
32 | def transform_label(cls, pred, name):
33 | label = np.zeros(pred.shape)
34 | ids = np.unique(pred)
35 | for id in ids:
36 | label[np.where(pred == id)] = cls.trans_labels[id]
37 |
38 | new_name = (name.split('.')[0]).split('_')[:-1]
39 | new_name = '_'.join(new_name) + '.png'
40 |
41 | print('Trans', name, 'to', new_name, ' ',
42 | np.unique(np.array(pred, np.uint8)), ' ---------> ',
43 | np.unique(np.array(label, np.uint8)))
44 | return label, new_name
45 |
--------------------------------------------------------------------------------
/data/visda17.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | from PIL import Image
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import Dataset
9 | import torchvision.transforms as transforms
10 |
11 | class VisDA17(Dataset):
12 |
13 | def __init__(self, txt_file, root_dir, transform=transforms.ToTensor(), label_one_hot=False, portion=1.0):
14 | """
15 | Args:
16 | txt_file (string): Path to the txt file with annotations.
17 | root_dir (string): Directory with all the images.
18 | transform (callable, optional): Optional transform to be applied
19 | on a sample.
20 | """
21 | self.lines = open(txt_file, 'r').readlines()
22 | self.root_dir = root_dir
23 | self.transform = transform
24 | self.label_one_hot = label_one_hot
25 | self.portion = portion
26 | self.number_classes = 12
27 | assert portion != 0
28 | if self.portion > 0:
29 | self.lines = self.lines[:round(self.portion * len(self.lines))]
30 | else:
31 | self.lines = self.lines[round(self.portion * len(self.lines)):]
32 |
33 | def __len__(self):
34 | return len(self.lines)
35 |
36 | def __getitem__(self, idx):
37 | line = str.split(self.lines[idx])
38 | path_img = os.path.join(self.root_dir, line[0])
39 | image = Image.open(path_img)
40 | image = image.convert('RGB')
41 | if self.label_one_hot:
42 | label = np.zeros(12, np.float32)
43 | label[np.asarray(line[1], dtype=np.int64)] = 1  # np.int is removed in NumPy >= 1.24
44 | else:
45 | label = np.asarray(line[1], dtype=np.int64)
46 | label = torch.from_numpy(label)
47 | if self.transform:
48 | image = self.transform(image)
49 | return image, label
50 |
--------------------------------------------------------------------------------
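The `portion` argument above selects a contiguous slice of the image list: a positive value keeps the first `portion` fraction of lines, while a negative value keeps the last `|portion|` fraction, so one list file can be split into two disjoint parts. A toy illustration:

```python
lines = list(range(10))                        # stand-in for image_list.txt lines

portion = 0.7
head = lines[:round(portion * len(lines))]     # first 70% -> [0, 1, ..., 6]

portion = -0.3
tail = lines[round(portion * len(lines)):]     # last 30%  -> [7, 8, 9]

assert head + tail == lines                    # the two slices are disjoint
```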
/dataloader_seg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import cv2
5 | from torch.utils import data
6 | from tools.utils.img_utils import random_scale, random_mirror, normalize, generate_random_crop_pos, random_crop_pad_to_shape
7 | cv2.setNumThreads(0)
8 |
9 |
10 | class TrainPre(object):
11 | def __init__(self, config, img_mean, img_std):
12 | self.img_mean = img_mean
13 | self.img_std = img_std
14 | self.config = config
15 |
16 | def __call__(self, img, gt):
17 | img, gt = random_mirror(img, gt)
18 | if self.config.train_scale_array is not None:
19 | img, gt, scale = random_scale(img, gt, self.config.train_scale_array)
20 |
21 | crop_size = (self.config.image_height, self.config.image_width)
22 | crop_pos = generate_random_crop_pos(img.shape[:2], crop_size)
23 | p_img, _ = random_crop_pad_to_shape(normalize(img, self.img_mean, self.img_std), crop_pos, crop_size, 0)
24 | p_img = p_img.transpose(2, 0, 1)
25 | extra_dict = None
26 | p_gt, _ = random_crop_pad_to_shape(gt, crop_pos, crop_size, 255)
27 | p_gt = cv2.resize(p_gt, (self.config.image_width // self.config.gt_down_sampling, self.config.image_height // self.config.gt_down_sampling), interpolation=cv2.INTER_NEAREST)
28 |
29 | return p_img, p_gt, extra_dict
30 |
31 |
32 | def get_train_loader(config, dataset, worker=None, test=False):
33 | data_setting = {
34 | 'train_img_root': config.train_img_root,
35 | 'train_gt_root': config.train_gt_root,
36 | 'val_img_root': config.val_img_root,
37 | 'val_gt_root': config.val_gt_root,
38 | 'train_source': config.train_source,
39 | 'eval_source': config.eval_source,
40 | 'down_sampling_train': config.down_sampling_train
41 | }
42 | if test:
43 | data_setting = {'img_root': config.img_root,
44 | 'gt_root': config.gt_root,
45 | 'train_source': config.train_eval_source,
46 | 'eval_source': config.eval_source}
47 | train_preprocess = TrainPre(config, config.image_mean, config.image_std)
48 |
49 | train_dataset = dataset(data_setting, "train", train_preprocess, config.batch_size * config.niters_per_epoch)
50 |
51 | is_shuffle = True
52 | batch_size = config.batch_size
53 |
54 | train_loader = data.DataLoader(train_dataset,
55 | batch_size=batch_size,
56 | num_workers=config.num_workers if worker is None else worker,
57 | drop_last=True,
58 | shuffle=is_shuffle,
59 | pin_memory=True,
60 | )
61 |
62 | return train_loader
63 |
--------------------------------------------------------------------------------
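`get_train_loader` fixes the dataset length to `batch_size * niters_per_epoch` (via the `file_length` mechanism of `BaseDataset`), so one loader "epoch" is exactly `niters_per_epoch` batches regardless of the raw dataset size. A minimal usage sketch, assuming the list files are in place and that `config.batch_size` is filled in by the training script (it is not defined in `config_seg.py`):

```python
# Minimal sketch; config.batch_size is normally set from the --batch-size argument.
from config_seg import config
from data.gta5 import GTA5
from dataloader_seg import get_train_loader

config.batch_size = 6
loader = get_train_loader(config, GTA5)
# With drop_last=True and a dataset reporting batch_size * niters_per_epoch
# samples, len(loader) == config.niters_per_epoch.
print(len(loader))
```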
/eval_seg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | #!/usr/bin/env python3
5 | # encoding: utf-8
6 | import os
7 | import cv2
8 | import numpy as np
9 | from pdb import set_trace as bp
10 | from tools.utils.visualize import print_iou, show_img, show_prediction
11 | from tools.engine.evaluator import Evaluator
12 | from tools.engine.logger import get_logger
13 | from tools.seg_opr.metric import hist_info, compute_score
14 | cv2.setNumThreads(0)
15 |
16 | logger = get_logger()
17 |
18 |
19 | class SegEvaluator(Evaluator):
20 | def func_per_iteration(self, data, device, iter=None):
21 | if self.config is not None: config = self.config
22 | img = data['data']
23 | label = data['label']
24 | name = data['fn']
25 |
26 | if len(config.eval_scale_array) == 1:
27 | pred = self.whole_eval(img, label.shape, resize=config.eval_scale_array[0], device=device)
28 | pred = pred.argmax(2) # since we ignore this step in evaluator.py
29 | elif len(config.eval_scale_array) > 1:
30 | pred = self.whole_eval(img, label.shape, resize=config.eval_scale_array[0], device=device)
31 | for scale in config.eval_scale_array[1:]:
32 | pred += self.whole_eval(img, label.shape, resize=scale, device=device)
33 | pred = pred.argmax(2) # since we ignore this step in evaluator.py
34 | else:
35 | pred = self.sliding_eval(img, config.eval_crop_size, config.eval_stride_rate, device)
36 | hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes, pred, label)
37 | results_dict = {'hist': hist_tmp, 'labeled': labeled_tmp, 'correct': correct_tmp}
38 |
39 | if self.save_path is not None:
40 | fn = name + '.png'
41 | cv2.imwrite(os.path.join(self.save_path, fn), pred)
42 | logger.info('Save the image ' + fn)
43 |
44 | # tensorboard logger does not fit multiprocess
45 | if self.logger is not None and iter is not None:
46 | colors = self.dataset.get_class_colors()
47 | image = img
48 | clean = np.zeros(label.shape)
49 | comp_img = show_img(colors, config.background, image, clean, label, pred)
50 | self.logger.add_image('vis', np.swapaxes(np.swapaxes(comp_img, 0, 2), 1, 2), iter)
51 |
52 | if self.show_image or self.show_prediction:
53 | colors = self.dataset.get_class_colors()
54 | image = img
55 | clean = np.zeros(label.shape)
56 | if self.show_image:
57 | comp_img = show_img(colors, config.background, image, clean, label, pred)
58 | cv2.imwrite(os.path.join(self.save_path, name + ".png"), comp_img[:,:,::-1])
59 | if self.show_prediction:
60 | comp_img = show_prediction(colors, config.background, image, pred)
61 | cv2.imwrite(os.path.join(self.save_path, "viz_"+name+".png"), comp_img[:,:,::-1])
62 |
63 | return results_dict
64 |
65 | def compute_metric(self, results):
66 | hist = np.zeros((self.config.num_classes, self.config.num_classes))
67 | correct = 0
68 | labeled = 0
69 | count = 0
70 | for d in results:
71 | hist += d['hist']
72 | correct += d['correct']
73 | labeled += d['labeled']
74 | count += 1
75 |
76 | iu, mean_IU, mean_IU_no_back, mean_pixel_acc = compute_score(hist, correct, labeled)
77 | result_line = print_iou(iu, mean_pixel_acc, self.dataset.get_class_names(), True)
78 | return result_line, mean_IU
79 |
--------------------------------------------------------------------------------
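`compute_metric` sums the per-image confusion matrices (`hist`) and derives mean IoU from the aggregate. A toy sketch of that aggregation, assuming `compute_score` in `tools/seg_opr/metric.py` uses the standard confusion-matrix IoU:

```python
import numpy as np

# Two-class toy confusion matrix accumulated over all validation images:
# rows = ground truth, columns = prediction.
hist = np.array([[8., 2.],
                 [1., 9.]])
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
mean_IU = np.nanmean(iu)          # per-class IoU [0.727, 0.75] -> mean 0.739
```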
/l2o_train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | # [x] train resnet101 with both proxy guidance and L2O policy on visda17
5 |
6 | import os
7 | import sys
8 | import time
9 | from collections import deque
10 | import logging
11 | from random import choice
12 | from tqdm import tqdm
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from torch.utils.data import DataLoader
18 | import torchvision.transforms as transforms
19 |
20 | from data.visda17 import VisDA17
21 | from model.resnet import resnet101
22 | from utils.utils import get_params, IterNums, save_checkpoint, AverageMeter, lr_poly, accuracy
23 | from utils.logger import prepare_logger, prepare_seed
24 | from utils.sgd import SGD
25 |
26 | from reinforce.arguments import get_args
27 | from reinforce.models.policy import Policy
28 |
29 | from pdb import set_trace as bp
30 |
31 | CrossEntropyLoss = nn.CrossEntropyLoss(reduction='mean')
32 | KLDivLoss = nn.KLDivLoss(reduction='batchmean')
33 |
34 |
35 | def adjust_learning_rate(lr, optimizer):
36 | for param_group in optimizer.param_groups:
37 | param_group['lr'] = lr
38 |
39 |
40 | def get_window_sample(train_loader_iter, train_loader, window_size=1):
41 | samples = []
42 | while len(samples) < window_size:
43 | try:
44 | sample = next(train_loader_iter)
45 | except StopIteration:  # loader exhausted: restart a fresh iterator
46 | train_loader_iter = iter(train_loader)
47 | sample = next(train_loader_iter)
48 | samples.append(sample)
49 | return samples, train_loader_iter
50 |
51 |
52 | def train_step(args, _window_size, train_loader_iter, train_loader, model, optimizer, obs_avg, base_lr, pbar, step, total_steps, model_old=None):
53 | # if obs_avg: average the observation in the window
54 | losses = []
55 | losses_kl = []
56 | fc_mean = []; fc_std = []
57 | optimizee_step = []
58 | for idx in range(_window_size):
59 | optimizer.zero_grad()
60 | """Train for one sample on the training set"""
61 | samples, train_loader_iter = get_window_sample(train_loader_iter, train_loader)
62 | input, label = samples[0]
63 | label = label.cuda()
64 | input = input.cuda()
65 | # compute output
66 | output, features_new = model(input, output_features=['layer4'], task='new')
67 | # compute gradient
68 | loss = CrossEntropyLoss(output, label.long())
69 | # LWF KL div
70 | loss_kl = 0
71 | if model_old is not None:
72 | output_new = model.forward_fc(features_new['layer4'], task='old')
73 | output_old, _ = model_old(input, output_features=[], task='old')
74 | loss_kl = KLDivLoss(F.log_softmax(output_new, dim=1), F.softmax(output_old, dim=1)).sum(-1)
75 | (loss + args.lwf * loss_kl).backward()
76 | # compute gradient and do SGD step
77 | optimizer.step()
78 | fc_mean.append(model.fc_new[2].weight.mean().detach())
79 | fc_std.append(model.fc_new[2].weight.std().detach())
80 | description = "[step: %.5f][loss: %.1f][loss_kl: %.1f][fc_mean: %.3f][fc_std: %.3f]"%(1. * (step + idx) / total_steps, loss, loss_kl, fc_mean[-1]*1000, fc_std[-1]*1000)
81 | pbar.set_description("[Step %d/%d]"%(step + idx, total_steps) + description)
82 | losses.append(loss.detach())
83 | losses_kl.append(loss_kl.detach())
84 | optimizee_step.append(1. * (step + idx) / total_steps)
85 | if obs_avg:
86 | losses = [sum(losses) / len(losses)]
87 | losses_kl = [sum(losses_kl) / len(losses_kl)]
88 | fc_mean = [sum(fc_mean) / len(fc_mean)]
89 | fc_std = [sum(fc_std) / len(fc_std)]
90 | optimizee_step = [sum(optimizee_step) / len(optimizee_step)]
91 | losses = [loss for loss in losses]
92 | losses_kl = [loss_kl for loss_kl in losses_kl]
93 | optimizee_step = [torch.tensor(step).cuda() for step in optimizee_step]
94 | observation = torch.stack(losses + losses_kl + optimizee_step + fc_mean + fc_std, dim=0)
95 | LRs = torch.Tensor([ group['lr'] / base_lr for group in optimizer.param_groups ]).cuda()
96 | observation = torch.cat([observation, LRs], dim=0).unsqueeze(0) # (batch=1, feature_size=window_size)
97 | return train_loader_iter, observation, torch.mean(torch.stack(losses, dim=0)), torch.mean(torch.stack(losses_kl, dim=0)), torch.mean(torch.stack(fc_mean, dim=0)), torch.mean(torch.stack(fc_std, dim=0))
98 |
99 |
100 | def prepare_optimizee(args, sgd_in_names, obs_shape, hidden_size, actor_critic, current_optimizee_step, prev_optimizee_step):
101 | prev_optimizee_step += current_optimizee_step
102 | current_optimizee_step = 0
103 |
104 | model = resnet101(pretrained=True)
105 | num_ftrs = model.fc.in_features
106 | fc_layers = nn.Sequential(
107 | nn.Linear(num_ftrs, 512),
108 | nn.ReLU(inplace=True),
109 | nn.Linear(512, args.num_class),
110 | )
111 | model.fc_new = fc_layers
112 |
113 | train_blocks = args.train_blocks.split('.')
114 | # by default: freeze fc, train fc_new
115 | for param in model.fc.parameters():
116 | param.requires_grad = False
117 | ##### Freeze several bottom layers (Optional) #####
118 | non_train_blocks = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
119 | for name in train_blocks:
120 | try:
121 | non_train_blocks.remove(name)
122 | except Exception:
123 | print("cannot find block name %s\nAvailable blocks are: conv1, bn1, layer1, layer2, layer3, layer4, fc"%name)
124 | for name in non_train_blocks:
125 | for param in getattr(model, name).parameters():
126 | param.requires_grad = False
127 |
128 | # Setup optimizer
129 | sgd_in = []
130 | for name in train_blocks:
131 | if name != 'fc':
132 | sgd_in.append({'params': get_params(model, [name]), 'lr': args.lr})
133 | else:
134 | sgd_in.append({'params': get_params(model, ["fc_new"]), 'lr': args.lr})
135 | base_lrs = [ group['lr'] for group in sgd_in ]
136 | optimizer = SGD(sgd_in, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
137 |
138 | model = model.cuda()
139 | model.eval()
140 | return model, optimizer, current_optimizee_step, prev_optimizee_step
141 |
142 |
143 | def main():
144 | args = get_args()
145 | PID = os.getpid()
146 | print("<< ============== JOB (PID = %d) %s ============== >>"%(PID, args.save_dir))
147 | prepare_seed(args.rand_seed)
148 |
149 | if args.timestamp == 'none':
150 | args.timestamp = "{:}".format(time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time())))
151 |
152 | torch.set_num_threads(1)
153 |
154 | # Log outputs
155 | args.save_dir = args.save_dir + \
156 | "/Visda17-L2O.train.Res101-%s-train.%s-LR%.2E-epoch%d-batch%d-seed%d"%(
157 | "LWF" if args.lwf > 0 else "XE", args.train_blocks, args.lr, args.epochs, args.batch_size, args.rand_seed) + \
158 | "%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp)
159 | logger = prepare_logger(args)
160 |
161 | best_prec1 = 0
162 |
163 | #### preparation ###########################################
164 | data_transforms = {
165 | 'train': transforms.Compose([
166 | transforms.Resize(224),
167 | transforms.CenterCrop(224),
168 | transforms.ToTensor(),
169 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
170 | ]),
171 | 'val': transforms.Compose([
172 | transforms.Resize(224),
173 | transforms.CenterCrop(224),
174 | transforms.ToTensor(),
175 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
176 | ]),
177 | }
178 |
179 | kwargs = {'num_workers': 20, 'pin_memory': True}
180 | trainset = VisDA17(txt_file=os.path.join(args.data, "train/image_list.txt"), root_dir=os.path.join(args.data, "train"), transform=data_transforms['train'])
181 | valset = VisDA17(txt_file=os.path.join(args.data, "validation/image_list.txt"), root_dir=os.path.join(args.data, "validation"), transform=data_transforms['val'], label_one_hot=True)
182 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
183 | val_loader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, **kwargs)
184 | train_loader_iter = iter(train_loader)
185 | current_optimizee_step, prev_optimizee_step = 0, 0
186 |
187 | model_old = None
188 | if args.lwf > 0:
189 | # create a fixed model copy for Life-long learning
190 | model_old = resnet101(pretrained=True)
191 | for param in model_old.parameters():
192 | param.requires_grad = False
193 | model_old.eval()
194 | model_old.cuda()
195 | ############################################################
196 |
197 | ### Agent Settings ########################################
198 | RANDOM = False # False | True | 'init'
199 | action_space = np.arange(0, 1.1, 0.1)
200 | obs_avg = True
201 | _window_size = 1
202 | window_size = 1 if obs_avg else _window_size
203 | window_shrink_size = 20 # larger: controller will be updated more frequently
204 | sgd_in_names = ["conv1", "bn1", "layer1", "layer2", "layer3", "layer4", "fc_new"]
205 | coord_size = len(sgd_in_names)
206 | ob_name_lstm = ["loss", "loss_kl", "step", "fc_mean", "fc_std"]
207 | ob_name_scalar = []
208 | obs_shape = (len(ob_name_lstm) * window_size + len(ob_name_scalar) + coord_size, )
209 | _hidden_size = 20
210 | hidden_size = _hidden_size * len(ob_name_lstm)
211 | actor_critic = Policy(coord_size, input_size=(len(ob_name_lstm), len(ob_name_scalar)), action_space=len(action_space), hidden_size=_hidden_size, window_size=window_size)
212 | actor_critic.cuda()
213 | actor_critic.eval()
214 |
215 | partial = torch.load(args.agent_load_dir, map_location=lambda storage, loc: storage)
216 | state = actor_critic.state_dict()
217 | pretrained_dict = {k: v for k, v in partial.items()}
218 | state.update(pretrained_dict)
219 | actor_critic.load_state_dict(state)
220 |
221 | ################################################################
222 |
223 | _min_iter = 10
224 | # reset optimizee
225 | model, optimizer, current_optimizee_step, prev_optimizee_step = prepare_optimizee(args, sgd_in_names, obs_shape, hidden_size, actor_critic, current_optimizee_step, prev_optimizee_step)
226 | epoch_size = len(train_loader)
227 | total_steps = epoch_size*args.epochs
228 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
229 | pbar = tqdm(range(int(epoch_size*args.epochs)), file=sys.stdout, bar_format=bar_format, ncols=100)
230 | _window_size = max(_min_iter, current_optimizee_step + prev_optimizee_step // window_shrink_size)
231 | train_loader_iter, obs, loss, loss_kl, fc_mean, fc_std = train_step(args, _window_size, train_loader_iter, train_loader, model, optimizer, obs_avg, args.lr, pbar, current_optimizee_step + prev_optimizee_step, total_steps, model_old=model_old)
232 | logger.writer.add_scalar("loss/ce", loss, current_optimizee_step + prev_optimizee_step)
233 | logger.writer.add_scalar("loss/kl", loss_kl, current_optimizee_step + prev_optimizee_step)
234 | logger.writer.add_scalar("loss/total", loss + loss_kl, current_optimizee_step + prev_optimizee_step)
235 | logger.writer.add_scalar("fc/mean", fc_mean, current_optimizee_step + prev_optimizee_step)
236 | logger.writer.add_scalar("fc/std", fc_std, current_optimizee_step + prev_optimizee_step)
237 | current_optimizee_step += _window_size
238 | pbar.update(_window_size)
239 | prev_obs = obs.unsqueeze(0)
240 | prev_hidden = torch.zeros(actor_critic.net.num_recurrent_layers, 1, hidden_size).cuda()
241 | for epoch in range(args.epochs):
242 | print("\n===== Epoch %d / %d ====="%(epoch+1, args.epochs))
243 | print("<< ============== JOB (PID = %d) %s ============== >>"%(PID, args.save_dir))
244 | while current_optimizee_step < epoch_size:
245 | # Sample actions
246 | with torch.no_grad():
247 | if not RANDOM:
248 | value, action, action_log_prob, recurrent_hidden_states, distribution = actor_critic.act(prev_obs, prev_hidden, deterministic=False)
249 | action = action.squeeze()
250 | action_log_prob = action_log_prob.squeeze()
251 | value = value.squeeze()
252 | for idx in range(len(action)):
253 | logger.writer.add_scalar("action/%s"%sgd_in_names[idx], action[idx], current_optimizee_step + prev_optimizee_step)
254 | logger.writer.add_scalar("entropy/%s"%sgd_in_names[idx], distribution.distributions[idx].entropy(), current_optimizee_step + prev_optimizee_step)
255 | optimizer.param_groups[idx]['lr'] = float(action_space[action[idx]]) * args.lr
256 | logger.writer.add_scalar("LR/%s"%sgd_in_names[idx], optimizer.param_groups[idx]['lr'], current_optimizee_step + prev_optimizee_step)
257 | else:
258 | if RANDOM is True or RANDOM == 'init':
259 | for idx in range(coord_size):
260 | optimizer.param_groups[idx]['lr'] = float(choice(action_space)) * args.lr
261 | if RANDOM == 'init':
262 | RANDOM = 'done'
263 | for idx in range(coord_size):
264 | logger.writer.add_scalar("LR/%s"%sgd_in_names[idx], optimizer.param_groups[idx]['lr'], current_optimizee_step + prev_optimizee_step)
265 |
266 | # Observe reward and next obs
267 | _window_size = max(_min_iter, current_optimizee_step + prev_optimizee_step // window_shrink_size)
268 | _window_size = min(_window_size, epoch_size - current_optimizee_step)
269 | train_loader_iter, obs, loss, loss_kl, fc_mean, fc_std = train_step(args, _window_size, train_loader_iter, train_loader, model, optimizer, obs_avg, args.lr, pbar, current_optimizee_step + prev_optimizee_step, total_steps, model_old=model_old)
270 | logger.writer.add_scalar("loss/ce", loss, current_optimizee_step + prev_optimizee_step)
271 | logger.writer.add_scalar("loss/kl", loss_kl, current_optimizee_step + prev_optimizee_step)
272 | logger.writer.add_scalar("loss/total", loss + loss_kl, current_optimizee_step + prev_optimizee_step)
273 | logger.writer.add_scalar("fc/mean", fc_mean, current_optimizee_step + prev_optimizee_step)
274 | logger.writer.add_scalar("fc/std", fc_std, current_optimizee_step + prev_optimizee_step)
275 | current_optimizee_step += _window_size
276 | pbar.update(_window_size)
277 | prev_obs = obs.unsqueeze(0)
278 | if not RANDOM: prev_hidden = recurrent_hidden_states
279 | prev_optimizee_step += current_optimizee_step
280 | current_optimizee_step = 0
281 |
282 | # evaluate on validation set
283 | prec1 = validate(val_loader, model, args)
284 | logger.writer.add_scalar("prec", prec1, epoch)
285 |
286 | # remember best prec@1 and save checkpoint
287 | is_best = prec1 > best_prec1
288 | best_prec1 = max(prec1, best_prec1)
289 | save_checkpoint(args.save_dir, {
290 | 'epoch': epoch + 1,
291 | 'state_dict': model.state_dict(),
292 | 'best_prec1': best_prec1,
293 | }, is_best)
294 |
295 | logging.info('Best accuracy: {prec1:.3f}'.format(prec1=best_prec1))
296 |
297 |
298 | def validate(val_loader, model, args):
299 | """Perform validation on the validation set"""
300 | batch_time = AverageMeter()
301 | top1 = AverageMeter()
302 |
303 | model.eval()
304 |
305 | end = time.time()
306 | val_size = len(val_loader)
307 | val_loader_iter = iter(val_loader)
308 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
309 | pbar = tqdm(range(val_size), file=sys.stdout, bar_format=bar_format, ncols=140)
310 | with torch.no_grad():
311 | for idx_iter in pbar:
312 | input, label = next(val_loader_iter)
313 |
314 | input = input.cuda()
315 | label = label.cuda()
316 |
317 | # compute output
318 | output = torch.sigmoid(model(input, task='new')[0])
319 | output = (output + torch.sigmoid(model(torch.flip(input, dims=(3,)), task='new')[0])) / 2
320 |
321 | # accumulate accuracy
322 | prec1, gt_num = accuracy(output.data, label, args.num_class, topk=(1,))
323 | top1.update(prec1[0], gt_num[0])
324 |
325 | # measure elapsed time
326 | batch_time.update(time.time() - end)
327 | end = time.time()
328 |
329 | description = "[Acc@1-mean: %.2f][Acc@1-cls: %s]"%(top1.vec2sca_avg, str(top1.avg.numpy().round(1)))
330 | pbar.set_description("[Step %d/%d]"%(idx_iter + 1, val_size) + description)
331 |
332 | logging.info(' * Prec@1 {top1.vec2sca_avg:.3f}'.format(top1=top1))
333 | logging.info(' * Prec@1 {top1.avg}'.format(top1=top1))
334 |
335 | return top1.vec2sca_avg
336 |
337 |
338 | if __name__ == "__main__":
339 | main()
340 |
--------------------------------------------------------------------------------
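The observation handed to the policy concatenates, per window, the (optionally averaged) CE loss, KL loss, normalized step, and fc-layer weight mean/std, followed by the per-group learning rates divided by the base LR. With `obs_avg=True` (window collapsed to 1) and the seven parameter groups above, that is a 12-dimensional vector. A sketch of the shape arithmetic:

```python
ob_name_lstm = ["loss", "loss_kl", "step", "fc_mean", "fc_std"]
sgd_in_names = ["conv1", "bn1", "layer1", "layer2", "layer3", "layer4", "fc_new"]

window_size = 1                      # obs_avg=True averages the window down to 1
coord_size = len(sgd_in_names)       # 7 parameter groups, one LR action each
obs_dim = len(ob_name_lstm) * window_size + coord_size
assert obs_dim == 12                 # matches obs_shape in l2o_train.py
```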
/l2o_train.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | python l2o_train.py \
5 | --epochs 30 \
6 | --batch-size 32 \
7 | --lr 1e-4 \
8 | --num-class 12 \
9 | --lwf 0.1 \
10 | --agent_load_dir /raid/ASG/pretrained/policy_res101_vista17.pth
11 |
--------------------------------------------------------------------------------
/l2o_train_seg.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | python l2o_train_seg.py \
5 | --epochs 50 \
6 | --batch-size 6 \
7 | --lr 1e-3 \
8 | --num-class 19 \
9 | --gpus 0 \
10 | --gamma 0 \
11 | --early-stop 2 \
12 | --lwf 75. \
13 | --algo reinforce
14 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
--------------------------------------------------------------------------------
/model/fcn8s_vgg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import os.path as osp
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from pdb import set_trace as bp
9 |
10 |
11 | def get_upsampling_weight(in_channels, out_channels, kernel_size):
12 | """Make a 2D bilinear kernel suitable for upsampling"""
13 | factor = (kernel_size + 1) // 2
14 | if kernel_size % 2 == 1:
15 | center = factor - 1
16 | else:
17 | center = factor - 0.5
18 | og = np.ogrid[:kernel_size, :kernel_size]
19 | filt = (1 - abs(og[0] - center) / factor) * \
20 | (1 - abs(og[1] - center) / factor)
21 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
22 | dtype=np.float64)
23 | weight[range(in_channels), range(out_channels), :, :] = filt
24 | return torch.from_numpy(weight).float()
25 |
26 |
27 | class FCN8s(nn.Module):
28 |
29 | def __init__(self, n_class=21):
30 | super(FCN8s, self).__init__()
31 | # conv1
32 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
33 | self.relu1_1 = nn.ReLU(inplace=True)
34 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
35 | self.relu1_2 = nn.ReLU(inplace=True)
36 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2
37 |
38 | # conv2
39 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
40 | self.relu2_1 = nn.ReLU(inplace=True)
41 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
42 | self.relu2_2 = nn.ReLU(inplace=True)
43 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4
44 |
45 | # conv3
46 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
47 | self.relu3_1 = nn.ReLU(inplace=True)
48 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
49 | self.relu3_2 = nn.ReLU(inplace=True)
50 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
51 | self.relu3_3 = nn.ReLU(inplace=True)
52 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8
53 |
54 | # conv4
55 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
56 | self.relu4_1 = nn.ReLU(inplace=True)
57 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
58 | self.relu4_2 = nn.ReLU(inplace=True)
59 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
60 | self.relu4_3 = nn.ReLU(inplace=True)
61 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16
62 |
63 | # conv5
64 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
65 | self.relu5_1 = nn.ReLU(inplace=True)
66 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
67 | self.relu5_2 = nn.ReLU(inplace=True)
68 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
69 | self.relu5_3 = nn.ReLU(inplace=True)
70 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32
71 |
72 | # fc6
73 | self.fc6 = nn.Conv2d(512, 4096, 7)
74 | self.relu6 = nn.ReLU(inplace=True)
75 | self.drop6 = nn.Dropout2d()
76 |
77 | # fc7
78 | self.fc7 = nn.Conv2d(4096, 4096, 1)
79 | self.relu7 = nn.ReLU(inplace=True)
80 | self.drop7 = nn.Dropout2d()
81 |
82 | self.score_fr = nn.Conv2d(4096, n_class, 1)
83 | self.score_pool3 = nn.Conv2d(256, n_class, 1)
84 | self.score_pool4 = nn.Conv2d(512, n_class, 1)
85 |
86 | self.upscore2 = nn.ConvTranspose2d(
87 | n_class, n_class, 4, stride=2, bias=False)
88 | self.upscore8 = nn.ConvTranspose2d(
89 | n_class, n_class, 16, stride=8, bias=False)
90 | self.upscore_pool4 = nn.ConvTranspose2d(
91 | n_class, n_class, 4, stride=2, bias=False)
92 |
93 | self._initialize_weights()
94 |
95 | def _initialize_weights(self):
96 | for m in self.modules():
97 | if isinstance(m, nn.Conv2d):
98 | m.weight.data.zero_()
99 | if m.bias is not None:
100 | m.bias.data.zero_()
101 | if isinstance(m, nn.ConvTranspose2d):
102 | assert m.kernel_size[0] == m.kernel_size[1]
103 | initial_weight = get_upsampling_weight(
104 | m.in_channels, m.out_channels, m.kernel_size[0])
105 | m.weight.data.copy_(initial_weight)
106 |
107 | def forward_fc(self, features, task='new_seg'):
108 | h = features['layer5']
109 | h = self.score_fr(h)
110 | h = self.upscore2(h)
111 | upscore2 = h # 1/16
112 |
113 | h = self.score_pool4(features['layer3'])
114 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]
115 | score_pool4c = h # 1/16
116 |
117 | h = upscore2 + score_pool4c # 1/16
118 | h = self.upscore_pool4(h)
119 | upscore_pool4 = h # 1/8
120 |
121 | h = self.score_pool3(features['layer2'])
122 | h = h[:, :,
123 | 9:9 + upscore_pool4.size()[2],
124 | 9:9 + upscore_pool4.size()[3]]
125 | score_pool3c = h # 1/8
126 |
127 | h = upscore_pool4 + score_pool3c # 1/8
128 |
129 | h = self.upscore8(h)
130 | h = h[:, :, 31:31 + self.input_size[2], 31:31 + self.input_size[3]].contiguous()
131 | return h
132 |
133 | def forward_backbone(self, x, output_features=['layer4']):
134 | features = {}
135 | h = x
136 | h = self.relu1_1(self.conv1_1(h))
137 | h = self.relu1_2(self.conv1_2(h))
138 | h = self.pool1(h)
139 |
140 | h = self.relu2_1(self.conv2_1(h))
141 | h = self.relu2_2(self.conv2_2(h))
142 | h = self.pool2(h)
143 |
144 | h = self.relu3_1(self.conv3_1(h))
145 | h = self.relu3_2(self.conv3_2(h))
146 | h = self.relu3_3(self.conv3_3(h))
147 | h = self.pool3(h)
148 | pool3 = h # 1/8
149 | features['layer2'] = pool3
150 |
151 | h = self.relu4_1(self.conv4_1(h))
152 | h = self.relu4_2(self.conv4_2(h))
153 | h = self.relu4_3(self.conv4_3(h))
154 | h = self.pool4(h)
155 | pool4 = h # 1/16
156 | features['layer3'] = pool4
157 |
158 | h = self.relu5_1(self.conv5_1(h))
159 | h = self.relu5_2(self.conv5_2(h))
160 | h = self.relu5_3(self.conv5_3(h))
161 | h = self.pool5(h)
162 | pool5 = h # 1/32
163 | features['layer4'] = pool5
164 |
165 | h = self.relu6(self.fc6(h))
166 | h = self.drop6(h)
167 |
168 | h = self.relu7(self.fc7(h))
169 | h = self.drop7(h)
170 | features['layer5'] = h
171 | return features
172 |
173 | def forward(self, x, output_features=['layer4'], task='new_seg'):
174 | '''
175 | task: 'old' | 'new' | 'new_seg'
176 | 'old', 'new': classification tasks (ImageNet or Visda)
177 | 'new_seg': segmentation head (convs)
178 | '''
179 | self.input_size = x.size()
180 | ###### standard FCN ##################
181 | features = self.forward_backbone(x, output_features)
182 | x = self.forward_fc(features, task=task)
183 | ######################################
184 | return x, features
185 |
186 | def copy_params_from_fcn16s(self, fcn16s):
187 | for name, l1 in fcn16s.named_children():
188 | try:
189 | l2 = getattr(self, name)
190 | l2.weight # skip ReLU / Dropout
191 | except Exception:
192 | continue
193 | assert l1.weight.size() == l2.weight.size()
194 | l2.weight.data.copy_(l1.weight.data)
195 | if l1.bias is not None:
196 | assert l1.bias.size() == l2.bias.size()
197 | l2.bias.data.copy_(l1.bias.data)
198 |
199 |
200 | class FCN8sAtOnce(FCN8s):
201 |
202 | def forward_fc(self, features, task='new_seg'):
203 | h = features['layer5']
204 | h = self.score_fr(h)
205 | h = self.upscore2(h)
206 | upscore2 = h # 1/16
207 |
208 | h = self.score_pool4(features['layer3'] * 0.01) # scaling to train at once
209 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]
210 | score_pool4c = h # 1/16
211 |
212 | h = upscore2 + score_pool4c # 1/16
213 | h = self.upscore_pool4(h)
214 | upscore_pool4 = h # 1/8
215 |
216 | h = self.score_pool3(features['layer2'] * 0.0001) # scaling to train at once
217 | h = h[:, :,
218 | 9:9 + upscore_pool4.size()[2],
219 | 9:9 + upscore_pool4.size()[3]]
220 | score_pool3c = h # 1/8
221 |
222 | h = upscore_pool4 + score_pool3c # 1/8
223 |
224 | h = self.upscore8(h)
225 | h = h[:, :, 31:31 + self.input_size[2], 31:31 + self.input_size[3]].contiguous()
226 | return h
227 |
228 | def forward_backbone(self, x, output_features=['layer4']):
229 | features = {}
230 | h = x
231 | h = self.relu1_1(self.conv1_1(h))
232 | h = self.relu1_2(self.conv1_2(h))
233 | h = self.pool1(h)
234 |
235 | h = self.relu2_1(self.conv2_1(h))
236 | h = self.relu2_2(self.conv2_2(h))
237 | h = self.pool2(h)
238 |
239 | h = self.relu3_1(self.conv3_1(h))
240 | h = self.relu3_2(self.conv3_2(h))
241 | h = self.relu3_3(self.conv3_3(h))
242 | h = self.pool3(h)
243 | pool3 = h # 1/8
244 | features['layer2'] = pool3
245 |
246 | h = self.relu4_1(self.conv4_1(h))
247 | h = self.relu4_2(self.conv4_2(h))
248 | h = self.relu4_3(self.conv4_3(h))
249 | h = self.pool4(h)
250 | pool4 = h # 1/16
251 | features['layer3'] = pool4
252 |
253 | h = self.relu5_1(self.conv5_1(h))
254 | h = self.relu5_2(self.conv5_2(h))
255 | h = self.relu5_3(self.conv5_3(h))
256 | h = self.pool5(h)
257 | pool5 = h # 1/32
258 | features['layer4'] = pool5
259 |
260 | h = self.relu6(self.fc6(h))
261 | h = self.drop6(h)
262 |
263 | h = self.relu7(self.fc7(h))
264 | h = self.drop7(h)
265 | features['layer5'] = h
266 | return features
267 |
268 | def forward(self, x, output_features=['layer4'], task='new_seg'):
269 | '''
270 | task: 'old' | 'new' | 'new_seg'
271 | 'old', 'new': classification tasks (ImageNet or Visda)
272 | 'new_seg': segmentation head (convs)
273 | '''
274 | self.input_size = x.size()
275 | ###### standard FCN ##################
276 | features = self.forward_backbone(x, output_features)
277 | x = self.forward_fc(features, task=task)
278 | ######################################
279 | return x, features
280 |
281 | def copy_params_from_vgg16(self, vgg16):
282 | features = [
283 | self.conv1_1, self.relu1_1,
284 | self.conv1_2, self.relu1_2,
285 | self.pool1,
286 | self.conv2_1, self.relu2_1,
287 | self.conv2_2, self.relu2_2,
288 | self.pool2,
289 | self.conv3_1, self.relu3_1,
290 | self.conv3_2, self.relu3_2,
291 | self.conv3_3, self.relu3_3,
292 | self.pool3,
293 | self.conv4_1, self.relu4_1,
294 | self.conv4_2, self.relu4_2,
295 | self.conv4_3, self.relu4_3,
296 | self.pool4,
297 | self.conv5_1, self.relu5_1,
298 | self.conv5_2, self.relu5_2,
299 | self.conv5_3, self.relu5_3,
300 | self.pool5,
301 | ]
302 | for l1, l2 in zip(vgg16.features, features):
303 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
304 | assert l1.weight.size() == l2.weight.size()
305 | assert l1.bias.size() == l2.bias.size()
306 | l2.weight.data.copy_(l1.weight.data)
307 | l2.bias.data.copy_(l1.bias.data)
308 | for i, name in zip([0, 3], ['fc6', 'fc7']):
309 | l1 = vgg16.classifier[i]
310 | l2 = getattr(self, name)
311 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size()))
312 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size()))
313 |
--------------------------------------------------------------------------------
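
A minimal usage sketch for the FCN8s/FCN8sAtOnce heads above (not part of the repo; it assumes the n_class constructor argument implied by the score layers, and uses 19 classes to match the GTA5/Cityscapes label set):

    import torch
    from model.vgg import vgg16
    from model.fcn8s_vgg import FCN8sAtOnce

    model = FCN8sAtOnce(n_class=19)
    # pretrained=True would copy ImageNet weights into the FCN trunk;
    # fc6/fc7 are reshaped from the VGG classifier by copy_params_from_vgg16
    model.copy_params_from_vgg16(vgg16(pretrained=False))
    x = torch.randn(1, 3, 512, 512)
    logits, features = model(x, task='new_seg')
    # the 31-pixel crop in forward_fc undoes the padding=100 of conv1_1,
    # so logits comes back at input resolution: (1, 19, 512, 512)
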
/model/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | from pdb import set_trace as bp
9 |
10 |
11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
12 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
13 |
14 |
15 | model_urls = {
16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
23 | }
24 |
25 |
26 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
27 | """3x3 convolution with padding"""
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
29 | padding=dilation, groups=groups, bias=False, dilation=dilation)
30 |
31 |
32 | def conv1x1(in_planes, out_planes, stride=1):
33 | """1x1 convolution"""
34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
35 |
36 |
37 | class BasicBlock(nn.Module):
38 | expansion = 1
39 |
40 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
41 | base_width=64, dilation=1, norm_layer=None):
42 | super(BasicBlock, self).__init__()
43 | if norm_layer is None:
44 | norm_layer = nn.BatchNorm2d
45 | if groups != 1 or base_width != 64:
46 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
47 | if dilation > 1:
48 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
49 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
50 | self.conv1 = conv3x3(inplanes, planes, stride)
51 | self.bn1 = norm_layer(planes)
52 | self.relu = nn.ReLU(inplace=True)
53 | self.conv2 = conv3x3(planes, planes)
54 | self.bn2 = norm_layer(planes)
55 | self.downsample = downsample
56 | self.stride = stride
57 |
58 | def forward(self, x):
59 | identity = x
60 |
61 | out = self.conv1(x)
62 | out = self.bn1(out)
63 | out = self.relu(out)
64 |
65 | out = self.conv2(out)
66 | out = self.bn2(out)
67 |
68 | if self.downsample is not None:
69 | identity = self.downsample(x)
70 |
71 | out += identity
72 | out = self.relu(out)
73 |
74 | return out
75 |
76 |
77 | class Bottleneck(nn.Module):
78 | expansion = 4
79 |
80 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
81 | base_width=64, dilation=1, norm_layer=None):
82 | super(Bottleneck, self).__init__()
83 | if norm_layer is None:
84 | norm_layer = nn.BatchNorm2d
85 | width = int(planes * (base_width / 64.)) * groups
86 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
87 | self.conv1 = conv1x1(inplanes, width)
88 | self.bn1 = norm_layer(width)
89 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
90 | self.bn2 = norm_layer(width)
91 | self.conv3 = conv1x1(width, planes * self.expansion)
92 | self.bn3 = norm_layer(planes * self.expansion)
93 | self.relu = nn.ReLU(inplace=True)
94 | self.downsample = downsample
95 | self.stride = stride
96 |
97 | def forward(self, x):
98 | identity = x
99 |
100 | out = self.conv1(x)
101 | out = self.bn1(out)
102 | out = self.relu(out)
103 |
104 | out = self.conv2(out)
105 | out = self.bn2(out)
106 | out = self.relu(out)
107 |
108 | out = self.conv3(out)
109 | out = self.bn3(out)
110 |
111 | if self.downsample is not None:
112 | identity = self.downsample(x)
113 |
114 | out += identity
115 | out = self.relu(out)
116 |
117 | return out
118 |
119 |
120 | class ResNet(nn.Module):
121 |
122 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
123 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
124 | norm_layer=None):
125 | super(ResNet, self).__init__()
126 | if norm_layer is None:
127 | norm_layer = nn.BatchNorm2d
128 | self._norm_layer = norm_layer
129 |
130 | self.inplanes = 64
131 | self.dilation = 1
132 | if replace_stride_with_dilation is None:
133 | # each element in the tuple indicates if we should replace
134 | # the 2x2 stride with a dilated convolution instead
135 | replace_stride_with_dilation = [False, False, False]
136 | if len(replace_stride_with_dilation) != 3:
137 | raise ValueError("replace_stride_with_dilation should be None "
138 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
139 | self.groups = groups
140 | self.base_width = width_per_group
141 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
142 | bias=False)
143 | self.bn1 = norm_layer(self.inplanes)
144 | self.relu = nn.ReLU(inplace=True)
145 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
146 | self.layer1 = self._make_layer(block, 64, layers[0])
147 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
148 | dilate=replace_stride_with_dilation[0])
149 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
150 | dilate=replace_stride_with_dilation[1])
151 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
152 | dilate=replace_stride_with_dilation[2])
153 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
154 | self.fc = nn.Linear(512 * block.expansion, num_classes)
155 |
156 | for m in self.modules():
157 | if isinstance(m, nn.Conv2d):
158 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
159 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
160 | nn.init.constant_(m.weight, 1)
161 | nn.init.constant_(m.bias, 0)
162 |
163 | # Zero-initialize the last BN in each residual branch,
164 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
165 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
166 | if zero_init_residual:
167 | for m in self.modules():
168 | if isinstance(m, Bottleneck):
169 | nn.init.constant_(m.bn3.weight, 0)
170 | elif isinstance(m, BasicBlock):
171 | nn.init.constant_(m.bn2.weight, 0)
172 |
173 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
174 | norm_layer = self._norm_layer
175 | downsample = None
176 | previous_dilation = self.dilation
177 | if dilate:
178 | self.dilation *= stride
179 | stride = 1
180 | if stride != 1 or self.inplanes != planes * block.expansion:
181 | downsample = nn.Sequential(
182 | conv1x1(self.inplanes, planes * block.expansion, stride),
183 | norm_layer(planes * block.expansion),
184 | )
185 |
186 | layers = []
187 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
188 | self.base_width, previous_dilation, norm_layer))
189 | self.inplanes = planes * block.expansion
190 | for _ in range(1, blocks):
191 | layers.append(block(self.inplanes, planes, groups=self.groups,
192 | base_width=self.base_width, dilation=self.dilation,
193 | norm_layer=norm_layer))
194 |
195 | return nn.Sequential(*layers)
196 |
197 | def forward_fc(self, f4, task='old', f3=None, f2=None):
198 | x = f4
199 | if task in ['old', 'new']:
200 | x = self.avgpool(x)
201 | x = x.reshape(x.size(0), -1)
202 | if task == 'old':
203 | x = self.fc(x)
204 | else:
205 | x = self.fc_new(x)
206 | return x
207 |
208 | def forward_backbone(self, x, output_features=['layer4']):
209 | features = {}
210 | f0 = self.conv1(x)
211 | f0 = self.bn1(f0)
212 | f0 = self.relu(f0)
213 | f0 = self.maxpool(f0)
214 | if 'layer0' in output_features: features['layer0'] = f0
215 | f1 = self.layer1(f0)
216 | if 'layer1' in output_features: features['layer1'] = f1
217 | f2 = self.layer2(f1)
218 | if 'layer2' in output_features: features['layer2'] = f2
219 | f3 = self.layer3(f2)
220 | if 'layer3' in output_features: features['layer3'] = f3
221 | f4 = self.layer4(f3)
222 | if 'layer4' in output_features: features['layer4'] = f4
223 | return f4, features
224 | # return f4, f3, f2, features
225 |
226 | def forward(self, x, output_features=['layer4'], task='old'):
227 | '''
228 | task: 'old' | 'new' | 'new_seg'
229 | 'old', 'new': classification tasks (ImageNet or Visda)
230 | 'new_seg': segmentation head (convs)
231 | '''
232 | ###### standard FCN ##################
233 | f4, features = self.forward_backbone(x, output_features)
234 | x = self.forward_fc(f4, task=task)
235 | ######################################
236 | return x, features
237 |
238 |
239 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
240 | model = ResNet(block, layers, **kwargs)
241 | if pretrained:
242 |         from torchvision.models.utils import load_state_dict_from_url  # note: this helper moved to torch.hub in newer torchvision releases
243 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
244 | # model.load_state_dict(state_dict)
245 | state = model.state_dict()
246 | pretrained_dict = {k: v for k, v in state_dict.items() if k in state and state[k].size() == v.size()}
247 | state.update(pretrained_dict)
248 | model.load_state_dict(state)
249 | return model
250 |
251 |
252 | def resnet18(pretrained=False, progress=True, **kwargs):
253 | """Constructs a ResNet-18 model.
254 |
255 | Args:
256 | pretrained (bool): If True, returns a model pre-trained on ImageNet
257 | progress (bool): If True, displays a progress bar of the download to stderr
258 | """
259 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
260 | **kwargs)
261 |
262 |
263 | def resnet34(pretrained=False, progress=True, **kwargs):
264 | """Constructs a ResNet-34 model.
265 |
266 | Args:
267 | pretrained (bool): If True, returns a model pre-trained on ImageNet
268 | progress (bool): If True, displays a progress bar of the download to stderr
269 | """
270 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
271 | **kwargs)
272 |
273 |
274 | def resnet50(pretrained=False, progress=True, **kwargs):
275 | """Constructs a ResNet-50 model.
276 |
277 | Args:
278 | pretrained (bool): If True, returns a model pre-trained on ImageNet
279 | progress (bool): If True, displays a progress bar of the download to stderr
280 | """
281 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
282 | **kwargs)
283 |
284 |
285 | def resnet101(pretrained=False, progress=True, **kwargs):
286 | """Constructs a ResNet-101 model.
287 |
288 | Args:
289 | pretrained (bool): If True, returns a model pre-trained on ImageNet
290 | progress (bool): If True, displays a progress bar of the download to stderr
291 | """
292 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
293 | **kwargs)
294 |
295 |
296 | def resnet152(pretrained=False, progress=True, **kwargs):
297 | """Constructs a ResNet-152 model.
298 |
299 | Args:
300 | pretrained (bool): If True, returns a model pre-trained on ImageNet
301 | progress (bool): If True, displays a progress bar of the download to stderr
302 | """
303 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
304 | **kwargs)
305 |
306 |
307 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
308 | """Constructs a ResNeXt-50 32x4d model.
309 |
310 | Args:
311 | pretrained (bool): If True, returns a model pre-trained on ImageNet
312 | progress (bool): If True, displays a progress bar of the download to stderr
313 | """
314 | kwargs['groups'] = 32
315 | kwargs['width_per_group'] = 4
316 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
317 | pretrained, progress, **kwargs)
318 |
319 |
320 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
321 | """Constructs a ResNeXt-101 32x8d model.
322 |
323 | Args:
324 | pretrained (bool): If True, returns a model pre-trained on ImageNet
325 | progress (bool): If True, displays a progress bar of the download to stderr
326 | """
327 | kwargs['groups'] = 32
328 | kwargs['width_per_group'] = 8
329 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
330 | pretrained, progress, **kwargs)
331 |
--------------------------------------------------------------------------------
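
Note that forward_fc above routes task='new' through self.fc_new, which is never defined in this file; the training scripts are expected to attach it. A hypothetical wiring (the 12-class head matches the --num-class default in reinforce/arguments.py):

    import torch
    import torch.nn as nn
    from model.resnet import resnet101

    model = resnet101(pretrained=False)    # pretrained=True partially loads matching ImageNet tensors
    model.fc_new = nn.Linear(512 * 4, 12)  # Bottleneck expansion is 4, so layer4 outputs 2048 channels
    logits, features = model(torch.randn(2, 3, 224, 224), output_features=['layer4'], task='new')
    # logits: (2, 12); features['layer4']: (2, 2048, 7, 7)
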
/model/vgg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import torch
5 | import torch.nn as nn
6 | from pdb import set_trace as bp
7 |
8 |
9 | __all__ = [
10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
11 | 'vgg19_bn', 'vgg19',
12 | ]
13 |
14 |
15 | model_urls = {
16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
20 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
21 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
22 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
23 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
24 | }
25 |
26 |
27 | class VGG(nn.Module):
28 |
29 | def __init__(self, features, num_classes=1000, init_weights=True):
30 | super(VGG, self).__init__()
31 | self.features = features
32 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
33 | self.classifier = nn.Sequential(
34 | nn.Linear(512 * 7 * 7, 4096),
35 | nn.ReLU(True),
36 | nn.Dropout(),
37 | nn.Linear(4096, 4096),
38 | nn.ReLU(True),
39 | nn.Dropout(),
40 | nn.Linear(4096, num_classes),
41 | )
42 | if init_weights:
43 | self._initialize_weights()
44 |
45 | def forward_fc(self, f4, task='old'):
46 | x = f4
47 | if task in ['old', 'new']:
48 | x = self.avgpool(x)
49 | x = torch.flatten(x, 1)
50 | if task == 'old':
51 | x = self.classifier(x)
52 | else:
53 | x = self.fc_new(x)
54 | return x
55 |
56 | def forward_backbone(self, x, output_features=['layer4']):
57 | features = {}
58 | f4 = self.features(x)
59 | if 'layer4' in output_features: features['layer4'] = f4
60 | return f4, features
61 |
62 | def forward(self, x, output_features=['layer4'], task='old'):
63 | '''
64 | task: 'old' | 'new' | 'new_seg'
65 | 'old', 'new': classification tasks (ImageNet or Visda)
66 | 'new_seg': segmentation head (convs)
67 | '''
68 | ###### standard FCN ##################
69 | f4, features = self.forward_backbone(x, output_features)
70 | x = self.forward_fc(f4, task=task)
71 | ######################################
72 | return x, features
73 |
74 | def _initialize_weights(self):
75 | for m in self.modules():
76 | if isinstance(m, nn.Conv2d):
77 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
78 | if m.bias is not None:
79 | nn.init.constant_(m.bias, 0)
80 | elif isinstance(m, nn.BatchNorm2d):
81 | nn.init.constant_(m.weight, 1)
82 | nn.init.constant_(m.bias, 0)
83 | elif isinstance(m, nn.Linear):
84 | nn.init.normal_(m.weight, 0, 0.01)
85 | nn.init.constant_(m.bias, 0)
86 |
87 |
88 | def make_layers(cfg, batch_norm=False):
89 | layers = []
90 | in_channels = 3
91 | for idx, v in enumerate(cfg):
92 | if v == 'M':
93 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
94 | else:
95 | if idx == 0:
96 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=100)
97 | else:
98 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
99 | if batch_norm:
100 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
101 | else:
102 | layers += [conv2d, nn.ReLU(inplace=True)]
103 | in_channels = v
104 | return nn.Sequential(*layers)
105 |
106 |
107 | cfgs = {
108 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
109 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
110 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
111 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
112 | }
113 |
114 |
115 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
116 | if pretrained:
117 | kwargs['init_weights'] = False
118 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
119 | if pretrained:
120 |         from torchvision.models.utils import load_state_dict_from_url  # note: this helper moved to torch.hub in newer torchvision releases
121 | state_dict = load_state_dict_from_url(model_urls[arch],
122 | progress=progress)
123 | model.load_state_dict(state_dict)
124 | return model
125 |
126 |
127 | def vgg11(pretrained=False, progress=True, **kwargs):
128 | r"""VGG 11-layer model (configuration "A") from
129 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
130 |
131 | Args:
132 | pretrained (bool): If True, returns a model pre-trained on ImageNet
133 | progress (bool): If True, displays a progress bar of the download to stderr
134 | """
135 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
136 |
137 |
138 | def vgg11_bn(pretrained=False, progress=True, **kwargs):
139 | r"""VGG 11-layer model (configuration "A") with batch normalization
140 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
141 |
142 | Args:
143 | pretrained (bool): If True, returns a model pre-trained on ImageNet
144 | progress (bool): If True, displays a progress bar of the download to stderr
145 | """
146 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
147 |
148 |
149 | def vgg13(pretrained=False, progress=True, **kwargs):
150 | r"""VGG 13-layer model (configuration "B")
151 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
152 |
153 | Args:
154 | pretrained (bool): If True, returns a model pre-trained on ImageNet
155 | progress (bool): If True, displays a progress bar of the download to stderr
156 | """
157 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
158 |
159 |
160 | def vgg13_bn(pretrained=False, progress=True, **kwargs):
161 | r"""VGG 13-layer model (configuration "B") with batch normalization
162 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
163 |
164 | Args:
165 | pretrained (bool): If True, returns a model pre-trained on ImageNet
166 | progress (bool): If True, displays a progress bar of the download to stderr
167 | """
168 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
169 |
170 |
171 | def vgg16(pretrained=False, progress=True, **kwargs):
172 | r"""VGG 16-layer model (configuration "D")
173 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
174 |
175 | Args:
176 | pretrained (bool): If True, returns a model pre-trained on ImageNet
177 | progress (bool): If True, displays a progress bar of the download to stderr
178 | """
179 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
180 |
181 |
182 | def vgg16_bn(pretrained=False, progress=True, **kwargs):
183 | r"""VGG 16-layer model (configuration "D") with batch normalization
184 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
185 |
186 | Args:
187 | pretrained (bool): If True, returns a model pre-trained on ImageNet
188 | progress (bool): If True, displays a progress bar of the download to stderr
189 | """
190 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
191 |
192 |
193 | def vgg19(pretrained=False, progress=True, **kwargs):
194 | r"""VGG 19-layer model (configuration "E")
195 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
196 |
197 | Args:
198 | pretrained (bool): If True, returns a model pre-trained on ImageNet
199 | progress (bool): If True, displays a progress bar of the download to stderr
200 | """
201 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
202 |
203 |
204 | def vgg19_bn(pretrained=False, progress=True, **kwargs):
205 | r"""VGG 19-layer model (configuration 'E') with batch normalization
206 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
207 |
208 | Args:
209 | pretrained (bool): If True, returns a model pre-trained on ImageNet
210 | progress (bool): If True, displays a progress bar of the download to stderr
211 | """
212 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
213 |
--------------------------------------------------------------------------------
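
Unlike torchvision's VGG, make_layers above pads the very first conv by 100 pixels, the standard FCN trick that keeps the 5/9/31 crops in fcn8s_vgg.py valid for arbitrary input sizes; the backbone therefore produces larger feature maps than stock VGG. A quick shape check (sketch; shapes assume the ceil-mode pooling defined above):

    import torch
    from model.vgg import vgg16

    model = vgg16(pretrained=False)  # pretrained=True loads the torchvision weights
    f4, features = model.forward_backbone(torch.randn(1, 3, 224, 224))
    print(f4.shape)  # torch.Size([1, 512, 14, 14]) rather than 7x7, due to padding=100
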
/reinforce/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
--------------------------------------------------------------------------------
/reinforce/algo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | from .reinforce import REINFORCE
5 | # from .a2c_acktr import A2C_ACKTR
6 | # from .ppo import PPO
7 |
--------------------------------------------------------------------------------
/reinforce/algo/reinforce.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from pdb import set_trace as bp
8 |
9 |
10 | class REINFORCE():
11 | def __init__(self,
12 | actor_critic,
13 | entropy_coef,
14 | lr=None,
15 | eps=None,
16 | alpha=None,
17 | max_grad_norm=None,
18 | acktr=False):
19 |
20 | self.actor_critic = actor_critic
21 | self.acktr = acktr
22 | self.entropy_coef = entropy_coef
23 | self.max_grad_norm = max_grad_norm
24 | # self.optimizer = optim.Adam(actor_critic.parameters(), lr)#, eps=eps)
25 |         self.optimizer = optim.SGD(actor_critic.parameters(), lr, momentum=0.9)
26 |
27 | def update(self, rollouts):
28 | obs_shape = rollouts.obs.size()[2:]
29 | action_shape = rollouts.actions.size()[-1]
30 | num_steps, num_processes, _ = rollouts.rewards.size()
31 |
32 | values, action_log_probs, dist_entropy, _, distribution = self.actor_critic.evaluate_actions(
33 | # rollouts.obs[:-1].view(-1, *obs_shape),
34 | # rollouts.recurrent_hidden_states[0].view(-1, self.actor_critic.recurrent_hidden_state_size),
35 | # rollouts.actions.view(-1, action_shape)
36 | rollouts.obs[:-1],
37 | rollouts.recurrent_hidden_states[0],
38 | rollouts.actions
39 | )
40 |
41 | values = values.view(num_steps, num_processes, 1)
42 | action_log_probs = action_log_probs.view(num_steps, num_processes, 1)
43 |
44 | # advantages = rollouts.returns[:-1] - values
45 |         advantages = rollouts.returns[:-1]  # plain REINFORCE: raw returns, no learned baseline
46 |
47 | action_loss = -(advantages.detach() * action_log_probs).mean()
48 |
49 | self.optimizer.zero_grad()
50 | # (action_loss - dist_entropy * self.entropy_coef).backward()
51 | action_loss.backward()
52 |
53 | # nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
54 |
55 | self.optimizer.step()
56 |
57 |         return 0, action_loss.item(), dist_entropy.item()  # no critic loss in plain REINFORCE, so 0 is returned
58 |
--------------------------------------------------------------------------------
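
A sketch of how this update is typically driven (all classes come from this repo; sizes are illustrative, and the rollout loop is abbreviated to comments):

    import torch
    from reinforce.algo import REINFORCE
    from reinforce.models.policy import Policy
    from reinforce.storage import RolloutStorage

    policy = Policy(coord_size=3, input_size=(2, 4), action_space=5, hidden_size=8, window_size=6)
    agent = REINFORCE(policy, entropy_coef=0.01, lr=7e-4)
    obs_dim = 2 * 6 + 4 + 3  # lstm windows + scalar features + one LR entry per coordinate
    rollouts = RolloutStorage(num_steps=5, obs_shape=(obs_dim,), action_shape=3,
                              hidden_size=2 * 8, num_recurrent_layers=policy.net.num_recurrent_layers)
    # per step: value, action, logp, hidden, _ = policy.act(obs, hidden); rollouts.insert(...)
    # then:     rollouts.compute_returns(next_value, use_gae=False, gamma=0.99, gae_lambda=0.95)
    #           _, action_loss, entropy = agent.update(rollouts)
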
/reinforce/arguments.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import argparse
5 | import torch
6 |
7 |
8 | def get_args():
9 |     parser = argparse.ArgumentParser(description='RL')
10 |     parser.add_argument('--data', default='/raid/taskcv-2017-public/classification/data', help='path to dataset')
11 |     parser.add_argument('--epochs', default=300, type=int, help='number of total epochs to run')
12 |     parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
13 |     parser.add_argument('--batch-size', default=64, type=int, dest='batch_size', help='mini-batch size (default: 64)')
14 |     parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, help='initial learning rate')
15 |     parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, help='weight decay (default: 5e-4)')
16 |     parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
17 |     parser.add_argument('--early-stop', default=1, type=int, dest='early_stop', help='early stop, so that the optimizer only sees part of each optimizee epoch')
18 |     parser.add_argument('--lwf', default=0., type=float, dest='lwf', help='weight of KL loss for LwF (default: 0)')
19 |     parser.add_argument('--resume', default='none', type=str, help='path to latest checkpoint (default: none)')
20 |     parser.add_argument('--timestamp', type=str, default='none', help='timestamp for logging naming')
21 |     parser.add_argument('--save_dir', type=str, default="./runs", help='root folder to save checkpoints and log.')
22 |     parser.add_argument('--agent_load_dir', type=str, default="/raid/ASG/pretrained/policy_res101_vista17.pth", help='path to pretrained L2O policy model.')
23 |     parser.add_argument('--train_blocks', type=str, default="conv1.bn1.layer1.layer2.layer3.layer4.fc", help='blocks to train, separated by dots.')
24 |     parser.add_argument('--num-class', default=12, type=int, dest='num_class', help='the number of classes')
25 |     parser.add_argument('--rand_seed', default=0, type=int, help='random seed')
26 |     parser.add_argument('--algo', default='a2c', help='algorithm to use: a2c | ppo | acktr')
27 |     parser.add_argument('--gail', action='store_true', default=False, help='do imitation learning with gail')
28 |     parser.add_argument('--gail-experts-dir', default='./gail_experts', help='directory that contains expert demonstrations for gail')
29 |     parser.add_argument('--gail-batch-size', type=int, default=128, help='gail batch size (default: 128)')
30 |     parser.add_argument('--gail-epoch', type=int, default=5, help='gail epochs (default: 5)')
31 |     parser.add_argument('--lr-meta', type=float, default=7e-4, help='learning rate (default: 7e-4)')
32 |     parser.add_argument('--eps', type=float, default=1e-5, help='RMSprop optimizer epsilon (default: 1e-5)')
33 |     parser.add_argument('--alpha', type=float, default=0.99, help='RMSprop optimizer alpha (default: 0.99)')
34 |     parser.add_argument('--gamma', type=float, default=0.99, help='discount factor for rewards (default: 0.99)')
35 |     parser.add_argument('--use-gae', action='store_true', default=False, help='use generalized advantage estimation')
36 |     parser.add_argument('--gae-lambda', type=float, default=0.95, help='gae lambda parameter (default: 0.95)')
37 |     parser.add_argument('--entropy-coef', type=float, default=0.01, help='entropy term coefficient (default: 0.01)')
38 |     parser.add_argument('--value-loss-coef', type=float, default=0.5, help='value loss coefficient (default: 0.5)')
39 |     parser.add_argument('--max-grad-norm', type=float, default=0.5, help='max norm of gradients (default: 0.5)')
40 |     parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
41 |     parser.add_argument('--cuda-deterministic', action='store_true', default=False, help="sets flags for determinism when using CUDA (potentially slow!)")
42 |     parser.add_argument('--num-steps', type=int, default=5, help='number of forward steps in A2C (default: 5)')
43 |     parser.add_argument('--ppo-epoch', type=int, default=4, help='number of ppo epochs (default: 4)')
44 |     parser.add_argument('--num-mini-batch', type=int, default=32, help='number of batches for ppo (default: 32)')
45 |     parser.add_argument('--clip-param', type=float, default=0.2, help='ppo clip parameter (default: 0.2)')
46 |     parser.add_argument('--log-interval', type=int, default=10, help='log interval, one log per n updates (default: 10)')
47 |     parser.add_argument('--save-interval', type=int, default=100, help='save interval, one save per n updates (default: 100)')
48 |     parser.add_argument('--eval-interval', type=int, default=None, help='eval interval, one eval per n updates (default: None)')
49 |     parser.add_argument('--num-env-steps', type=int, default=int(10e6), help='number of environment steps to train (default: 10e6)')
50 |     parser.add_argument('--use-proper-time-limits', action='store_true', default=False, help='compute returns taking into account time limits')
51 |     parser.add_argument('--no-recurrent-policy', action='store_false', default=True, help='do not use a recurrent policy')
52 |     parser.add_argument('--use-linear-lr-decay', action='store_true', default=False, help='use a linear schedule on the learning rate')
53 |     parser.add_argument('--gpus', default=0, type=int, help='id of the cuda device to use')
54 | args = parser.parse_args()
55 |
56 | args.cuda = torch.cuda.is_available()
57 |
58 | # assert args.algo in ['a2c', 'ppo', 'acktr']
59 |     if args.no_recurrent_policy:  # True when the recurrent policy is enabled (the --no-recurrent-policy flag sets it False)
60 | assert args.algo in ['a2c', 'ppo'], \
61 | 'Recurrent policy is not implemented for ACKTR'
62 |
63 | return args
64 |
--------------------------------------------------------------------------------
/reinforce/distributions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | from pdb import set_trace as bp
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from reinforce.utils import init
10 |
11 |
12 | class FixedCategorical(torch.distributions.Categorical):
13 | def sample(self):
14 | return super().sample().unsqueeze(-1)
15 |
16 | def log_probs(self, actions):
17 | return (
18 | super()
19 | .log_prob(actions.squeeze(-1))
20 | .view(actions.size(0), -1)
21 | .sum(-1)
22 | .unsqueeze(-1)
23 | )
24 |
25 | def mode(self):
26 | return self.probs.argmax(dim=-1, keepdim=True)
27 |
28 |
29 | class FixedNormal(torch.distributions.Normal):
30 | def log_probs(self, actions):
31 | return super().log_prob(actions).sum(-1, keepdim=True)
32 |
33 |     def entropy(self):
34 |         return super().entropy().sum(-1)
35 |
36 | def mode(self):
37 | return self.mean
38 |
39 |
40 | class Categorical(nn.Module):
41 | def __init__(self, num_inputs, num_outputs, coord_size=1):
42 | # num_inputs: #features for each coord
43 | # num_outputs: action_space
44 | super(Categorical, self).__init__()
45 | self.num_inputs = num_inputs
46 | self.num_outputs = num_outputs
47 | self.coord_size = coord_size
48 |
49 | init_ = lambda m: init(
50 | m,
51 | nn.init.orthogonal_,
52 | lambda x: nn.init.constant_(x, 0),
53 | gain=0.01)
54 |
55 | self.linear = nn.ModuleList([
56 | init_(nn.Linear(num_inputs, num_outputs))
57 | for _ in range(coord_size)
58 | ])
59 |
60 | def forward(self, x):
61 | # x: (coord, batch, *features)
62 | # will coordinate-wisely return distributions
63 | distributions = []
64 | for coord in range(self.coord_size):
65 | dist = FixedCategorical(logits=self.linear[coord](x[coord]))
66 | distributions.append(dist)
67 | return MultiCategorical(distributions)
68 |
69 |
70 | class MultiCategorical(nn.Module):
71 | def __init__(self, distributions):
72 | super(MultiCategorical, self).__init__()
73 | # coordinate-wise distributions
74 | self.distributions = distributions
75 |
76 | def sample(self):
77 | actions = []
78 | for dist in self.distributions:
79 | actions.append(dist.sample())
80 | return torch.cat(actions, dim=1)
81 |
82 | def log_probs(self, actions, is_sum=True):
83 | # actions: (batch, coord)
84 | log_probs = []
85 | for coord in range(len(self.distributions)):
86 |             log_probs.append(self.distributions[coord].log_probs(actions[:, coord:coord+1]))
91 | log_probs = torch.cat(log_probs, dim=1)
92 | if is_sum:
93 | return log_probs.sum(-1).unsqueeze(-1)
94 | else:
95 | return log_probs
96 |
97 | def entropy(self):
98 |         # per-coordinate entropies, concatenated along dim 0
99 | entropies = []
100 | for coord in range(len(self.distributions)):
101 | entropies.append(self.distributions[coord].entropy())
102 | entropies = torch.cat(entropies, dim=0)
103 | return entropies.unsqueeze(-1)
104 |
105 | def mode(self):
106 | actions = []
107 | for dist in self.distributions:
108 | actions.append(dist.probs.argmax(dim=-1, keepdim=True))
109 | return torch.cat(actions, dim=1)
110 |
111 |
112 | class Gaussian(nn.Module):
113 | def __init__(self, num_inputs, num_outputs=1, mean_range=[0, 1], std_epsilon=0.001):
114 | super(Gaussian, self).__init__()
115 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0))
116 | self._num_inputs = num_inputs
117 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))
118 | self.fc_std = init_(nn.Linear(num_inputs, num_outputs))
119 | assert len(mean_range) == 2 and mean_range[0] < mean_range[1]
120 | self.mean_min = mean_range[0]
121 | self.mean_max = mean_range[1]
122 | self.std_epsilon = std_epsilon
123 |
124 | def forward(self, x):
125 | # x = x.view(1, self._num_inputs)
126 | action_mean = self.fc_mean(x)
127 |         action_mean = torch.sigmoid(action_mean) * (self.mean_max - self.mean_min) + self.mean_min
128 | action_std = F.softplus(self.fc_std(x)) + self.std_epsilon
129 | return FixedNormal(action_mean, action_std)
130 |
--------------------------------------------------------------------------------
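
A small shape sketch for the coordinate-wise categorical head above (sizes are arbitrary):

    import torch
    from reinforce.distributions import Categorical

    head = Categorical(num_inputs=21, num_outputs=5, coord_size=3)
    x = torch.randn(3, 4, 21)       # (coord, batch, features)
    dist = head(x)                  # MultiCategorical over 3 coordinates
    actions = dist.sample()         # (4, 3): one action index per coordinate
    logp = dist.log_probs(actions)  # (4, 1): log-probs summed over coordinates
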
/reinforce/models/policy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import abc
5 | from pdb import set_trace as bp
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 |
10 | from reinforce.distributions import Categorical
11 | from reinforce.models.rnn_state_encoder import RNNStateEncoder
12 |
13 |
14 | class Policy(nn.Module):
15 | def __init__(self, coord_size, input_size=(1, 1), action_space=1, hidden_size=1, window_size=1):
16 | # input_size: (#lstm_input, #mlp_input)
17 | super().__init__()
18 | self.net = BasicNet(coord_size, input_size=input_size, hidden_size=hidden_size, window_size=window_size)
19 | # will coordinate-wisely return distributions
20 | self.action_distribution = Categorical(input_size[0]*hidden_size+input_size[1]+1, action_space, coord_size=coord_size)
21 | self.critic = CriticHead(coord_size * (input_size[0]*hidden_size+input_size[1]+1))
22 | self.recurrent_hidden_state_size = hidden_size
23 | self.coord_size = coord_size
24 | self.input_size = input_size
25 | self.action_space = action_space
26 | self.hidden_size = hidden_size
27 | self.window_size = window_size
28 |
29 | def forward(self, *x):
30 | raise NotImplementedError
31 |
32 | def act(self, observations, rnn_hidden_states, deterministic=False):
33 | features, rnn_hidden_states = self.net(observations, rnn_hidden_states)
34 | distribution = self.action_distribution(features)
35 | # (coord, seq_len*batch, feature) ==> (seq_len*batch, coord, feature)
36 | value = self.critic(features.permute(1, 0, 2).view(features.size(1), -1))
37 |
38 | if deterministic:
39 | action = distribution.mode()
40 | else:
41 | action = distribution.sample()
42 |
43 | action_log_probs = distribution.log_probs(action)
44 | return value, action, action_log_probs, rnn_hidden_states, distribution
45 |
46 | def get_value(self, observations, rnn_hidden_states):
47 | features, _ = self.net(observations, rnn_hidden_states)
48 | # features = features.view(-1, self.batch_size * self.recurrent_hidden_state_size)
49 | return self.critic(features.permute(1, 0, 2).view(features.size(1), -1))
50 |
51 | def evaluate_actions(self, observations, rnn_hidden_states, action):
52 | features, rnn_hidden_states = self.net(observations, rnn_hidden_states)
53 | # features = features.view(-1, self.batch_size * self.recurrent_hidden_state_size)
54 | distribution = self.action_distribution(features)
55 | value = self.critic(features.permute(1, 0, 2).contiguous().view(features.size(1), -1))
56 |
57 | action_log_probs = distribution.log_probs(action)
58 | distribution_entropy = distribution.entropy().mean()
59 |
60 | return value, action_log_probs, distribution_entropy, rnn_hidden_states, distribution
61 |
62 |
63 |
64 | class CriticHead(nn.Module):
65 | def __init__(self, input_size):
66 | super().__init__()
67 | self.fc = nn.Linear(input_size, 1)
68 | nn.init.orthogonal_(self.fc.weight)
69 | nn.init.constant_(self.fc.bias, 0)
70 |
71 | def forward(self, x):
72 | return self.fc(x)
73 |
74 |
75 | class Net(nn.Module, metaclass=abc.ABCMeta):
76 | @abc.abstractmethod
77 | def forward(self, observations, rnn_hidden_states, prev_actions):
78 | pass
79 |
80 | @property
81 | @abc.abstractmethod
82 | def output_size(self):
83 | pass
84 |
85 | @property
86 | @abc.abstractmethod
87 | def num_recurrent_layers(self):
88 | pass
89 |
90 |
91 | class BasicNet(Net):
92 | def __init__(self, coord_size, input_size=(1, 1), hidden_size=1, window_size=1):
93 | super().__init__()
94 | self._coord_size = coord_size
95 | # input_size: (#lstm_input, #mlp_input)
96 | self._input_size = input_size
97 | self._hidden_size = hidden_size
98 | self._window_size = window_size
99 | self.state_encoder = nn.ModuleList([
100 | RNNStateEncoder(input_size=window_size, hidden_size=self._hidden_size)
101 | for _ in range(input_size[0])
102 | ])
103 | self.train()
104 |
105 | @property
106 | def output_size(self):
107 | return self._hidden_size
108 |
109 | @property
110 | def num_recurrent_layers(self):
111 | return self.state_encoder[0].num_recurrent_layers
112 |
113 | def forward(self, observations, rnn_hidden_states):
114 | # observation: (seq_len, batch_size, #lstm_input * window + #scalar_input + #actions * 1(LR))
115 | # rnn_hidden_states: (#lstm_input * hidden_size)
116 | outputs = []
117 | rnn_hidden_states_new = []
118 | # coordinate-wise
119 | for i in range(self._input_size[0]):
120 | # output: (seq_len, batch(1), hidden_size)
121 | output, rnn_hidden_state = self.state_encoder[i](observations[:, :, i*self._window_size:(i+1)*self._window_size], rnn_hidden_states[:, :, i*self._hidden_size:(i+1)*self._hidden_size])
122 | outputs.append(output)
123 | rnn_hidden_states_new.append(rnn_hidden_state)
124 | # outputs: (seq_len, batch(1), hidden_size * #lstm_input + #scalar_input)
125 | outputs = torch.cat(outputs + [observations[:, :, self._input_size[0]*self._window_size:self._input_size[0]*self._window_size+self._input_size[1]]], dim=2)
126 | # add LR feature for each coord
127 | outputs_LR = []
128 | for coord in range(-self._coord_size, 0):
129 | outputs_LR.append(torch.cat([outputs, observations[:, :, observations.size(2)+coord:observations.size(2)+coord+1]], dim=2))
130 | outputs_LR = torch.stack(outputs_LR, dim=0) # (coord, seq_len, 1, hidden_size * #lstm_input + #scalar_input + 1)
131 | outputs_LR = outputs_LR.view(self._coord_size, -1, outputs_LR.size(-1)) # (coord, seq_len * 1, hidden_size * #lstm_input + #scalar_input + 1)
132 |         return outputs_LR, torch.cat(rnn_hidden_states_new, dim=2)  # propagate the updated, re-packed hidden states
133 |
--------------------------------------------------------------------------------
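
A toy forward pass through Policy.act (sizes are illustrative; the observation layout follows the comment in BasicNet.forward):

    import torch
    from reinforce.models.policy import Policy

    # 2 LSTM-encoded statistics with window 6, 4 scalar features, one LR per coordinate
    policy = Policy(coord_size=3, input_size=(2, 4), action_space=5, hidden_size=8, window_size=6)
    obs = torch.randn(1, 1, 2 * 6 + 4 + 3)  # (seq_len, batch, features)
    hidden = torch.zeros(policy.net.num_recurrent_layers, 1, 2 * 8)
    value, action, logp, hidden, dist = policy.act(obs, hidden)
    # value: (1, 1); action: (1, 3); logp: (1, 1)
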
/reinforce/models/rnn_state_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import torch
5 | import torch.nn as nn
6 | from pdb import set_trace as bp
7 |
8 |
9 | class RNNStateEncoder(nn.Module):
10 | def __init__(self, input_size: int = 1, hidden_size: int = 1, num_layers: int = 1, rnn_type: str = "LSTM"):
11 | r"""An RNN for encoding the state in RL.
12 |
13 |         Supports masking the hidden state during various timesteps in the forward pass
14 |
15 | Args:
16 | input_size: The input size of the RNN
17 | hidden_size: The hidden size
18 | num_layers: The number of recurrent layers
19 | rnn_type: The RNN cell type. Must be GRU or LSTM
20 | """
21 |
22 | super().__init__()
23 | self._num_recurrent_layers = num_layers
24 | self._rnn_type = rnn_type
25 |
26 | self.rnn = getattr(nn, rnn_type)(
27 | input_size=input_size,
28 | hidden_size=hidden_size,
29 | num_layers=num_layers,
30 | )
31 |
32 | self.layer_init()
33 |
34 | def layer_init(self):
35 | for name, param in self.rnn.named_parameters():
36 | if "weight" in name:
37 | nn.init.orthogonal_(param)
38 | elif "bias" in name:
39 | nn.init.constant_(param, 0)
40 |
41 | @property
42 | def num_recurrent_layers(self):
43 | return self._num_recurrent_layers * (
44 | 2 if "LSTM" in self._rnn_type else 1
45 | )
46 |
47 | def _pack_hidden(self, hidden_states):
48 | if "LSTM" in self._rnn_type:
49 | hidden_states = torch.cat(
50 | [hidden_states[0], hidden_states[1]], dim=0
51 | )
52 | return hidden_states
53 |
54 | def _unpack_hidden(self, hidden_states):
55 | if "LSTM" in self._rnn_type:
56 | hidden_states = (
57 | hidden_states[0 : self._num_recurrent_layers],
58 | hidden_states[self._num_recurrent_layers :],
59 | )
60 | return hidden_states
61 |
62 | def single_forward(self, x, hidden_states):
63 | r"""Forward for a non-sequence input
64 | """
65 | if len(x.size()) == 2:
66 | x = x.unsqueeze(0)
67 | # input: (seq_len, batch, input_size)
68 | x, hidden_states = self.rnn(x, hidden_states)
69 | return x, hidden_states
70 |
71 | def forward(self, x, hidden_states):
72 | hidden_states = self._unpack_hidden(hidden_states)
73 | x, hidden_states = self.single_forward(x, hidden_states)
74 | hidden_states = self._pack_hidden(hidden_states)
75 | return x, hidden_states
76 |
--------------------------------------------------------------------------------
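
The encoder keeps the LSTM (h, c) pair packed into a single tensor along dim 0, so callers can slice and store hidden states uniformly. A quick sketch:

    import torch
    from reinforce.models.rnn_state_encoder import RNNStateEncoder

    enc = RNNStateEncoder(input_size=6, hidden_size=8)    # 1-layer LSTM
    x = torch.randn(4, 1, 6)                              # (seq_len, batch, input_size)
    hidden = torch.zeros(enc.num_recurrent_layers, 1, 8)  # h and c stacked: (2, 1, 8)
    out, hidden = enc(x, hidden)
    # out: (4, 1, 8); hidden stays packed as (2, 1, 8)
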
/reinforce/storage.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import torch
5 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
6 | from pdb import set_trace as bp
7 |
8 |
9 | def _flatten_helper(T, N, _tensor):
10 | return _tensor.view(T * N, *_tensor.size()[2:])
11 |
12 |
13 | class RolloutStorage(object):
14 | def __init__(self, num_steps, obs_shape, action_shape=1, hidden_size=1, num_recurrent_layers=1):
15 | # observation: (seq_len, batch_size, #lstm_input * window + #scalar_input + #actions * 1(LR))
16 | self.obs = torch.zeros(num_steps + 1, 1, *obs_shape)
17 | self.recurrent_hidden_states = torch.zeros(num_steps + 1, num_recurrent_layers, 1, hidden_size)
18 | self.rewards = torch.zeros(num_steps, 1, 1)
19 | self.value_preds = torch.zeros(num_steps + 1, 1)
20 | self.returns = torch.zeros(num_steps + 1, 1)
21 | self.action_log_probs = torch.zeros(num_steps, 1)
22 | self.actions = torch.zeros(num_steps, action_shape)
23 | self.num_steps = num_steps
24 | self.step = 0
25 |
26 | def to(self, device):
27 | self.obs = self.obs.to(device)
28 | self.recurrent_hidden_states = self.recurrent_hidden_states.to(device)
29 | self.rewards = self.rewards.to(device)
30 | self.value_preds = self.value_preds.to(device)
31 | self.returns = self.returns.to(device)
32 | self.action_log_probs = self.action_log_probs.to(device)
33 | self.actions = self.actions.to(device)
34 |
35 | def insert(self, obs, recurrent_hidden_states, actions, action_log_probs, value_preds, rewards):
36 | self.obs[self.step + 1].copy_(obs)
37 | self.recurrent_hidden_states[self.step + 1].copy_(recurrent_hidden_states)
38 | self.actions[self.step].copy_(actions)
39 | self.action_log_probs[self.step].copy_(action_log_probs)
40 | self.value_preds[self.step].copy_(value_preds)
41 | self.rewards[self.step].copy_(rewards)
42 | self.step = (self.step + 1) % self.num_steps
43 |
44 | def after_update(self):
45 | self.obs[0].copy_(self.obs[-1])
46 | self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1])
47 |
48 | def compute_returns(self, next_value, use_gae, gamma, gae_lambda):
49 | if use_gae:
50 | self.value_preds[-1] = next_value
51 | gae = 0
52 | for step in reversed(range(self.rewards.size(0))):
53 | delta = self.rewards[step] + gamma * self.value_preds[step + 1] - self.value_preds[step]
54 | gae = delta + gamma * gae_lambda * gae
55 | self.returns[step] = gae + self.value_preds[step]
56 | else:
57 | self.returns[-1] = next_value
58 | for step in reversed(range(self.rewards.size(0))):
59 | self.returns[step] = self.returns[step + 1] * gamma + self.rewards[step]
60 |
61 | def feed_forward_generator(self, advantages, num_mini_batch=None, mini_batch_size=None):
62 | num_steps, num_processes = self.rewards.size()[0:2]
63 | batch_size = num_processes * num_steps
64 |
65 | if mini_batch_size is None:
66 | assert batch_size >= num_mini_batch, (
67 | "PPO requires the number of processes ({}) "
68 | "* number of steps ({}) = {} "
69 | "to be greater than or equal to the number of PPO mini batches ({})."
70 | "".format(num_processes, num_steps, num_processes * num_steps,
71 | num_mini_batch))
72 | mini_batch_size = batch_size // num_mini_batch
73 | sampler = BatchSampler(
74 | SubsetRandomSampler(range(batch_size)),
75 | mini_batch_size,
76 | drop_last=True)
77 | for indices in sampler:
78 | obs_batch = self.obs[:-1].view(-1, *self.obs.size()[1:])[indices]
79 | recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view(-1, *self.recurrent_hidden_states.size()[1:])[indices]
80 | actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
81 | value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices]
82 | return_batch = self.returns[:-1].view(-1, 1)[indices]
83 | old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
84 | if advantages is None:
85 | adv_targ = None
86 | else:
87 | adv_targ = advantages.view(-1, 1)[indices]
88 |
89 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, value_preds_batch, return_batch, old_action_log_probs_batch, adv_targ
90 |
91 | def recurrent_generator(self, advantages, num_mini_batch):
92 | num_processes = self.rewards.size(1)
93 | assert num_processes >= num_mini_batch, (
94 | "PPO requires the number of processes ({}) "
95 | "to be greater than or equal to the number of "
96 | "PPO mini batches ({}).".format(num_processes, num_mini_batch))
97 | num_envs_per_batch = num_processes // num_mini_batch
98 | perm = torch.randperm(num_processes)
99 | for start_ind in range(0, num_processes, num_envs_per_batch):
100 | obs_batch = []
101 | recurrent_hidden_states_batch = []
102 | actions_batch = []
103 | value_preds_batch = []
104 | return_batch = []
105 | old_action_log_probs_batch = []
106 | adv_targ = []
107 |
108 | for offset in range(num_envs_per_batch):
109 | ind = perm[start_ind + offset]
110 | obs_batch.append(self.obs[:-1, ind])
111 | recurrent_hidden_states_batch.append(self.recurrent_hidden_states[0:1, ind])
112 | actions_batch.append(self.actions[:, ind])
113 | value_preds_batch.append(self.value_preds[:-1, ind])
114 | return_batch.append(self.returns[:-1, ind])
115 | old_action_log_probs_batch.append(
116 | self.action_log_probs[:, ind])
117 | adv_targ.append(advantages[:, ind])
118 |
119 | T, N = self.num_steps, num_envs_per_batch
120 | # These are all tensors of size (T, N, -1)
121 | obs_batch = torch.stack(obs_batch, 1)
122 | actions_batch = torch.stack(actions_batch, 1)
123 | value_preds_batch = torch.stack(value_preds_batch, 1)
124 | return_batch = torch.stack(return_batch, 1)
125 | old_action_log_probs_batch = torch.stack(
126 | old_action_log_probs_batch, 1)
127 | adv_targ = torch.stack(adv_targ, 1)
128 |
129 | # States is just a (N, -1) tensor
130 | recurrent_hidden_states_batch = torch.stack(recurrent_hidden_states_batch, 1).view(N, -1)
131 |
132 | # Flatten the (T, N, ...) tensors to (T * N, ...)
133 | obs_batch = _flatten_helper(T, N, obs_batch)
134 | actions_batch = _flatten_helper(T, N, actions_batch)
135 | value_preds_batch = _flatten_helper(T, N, value_preds_batch)
136 | return_batch = _flatten_helper(T, N, return_batch)
137 | old_action_log_probs_batch = _flatten_helper(T, N, \
138 | old_action_log_probs_batch)
139 | adv_targ = _flatten_helper(T, N, adv_targ)
140 |
141 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, \
142 | value_preds_batch, return_batch, old_action_log_probs_batch, adv_targ
143 |
--------------------------------------------------------------------------------
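
Without GAE, compute_returns fills self.returns with the plain discounted sum R_t = r_t + gamma * R_{t+1}, bootstrapped from next_value. A tiny numeric check (sketch):

    import torch
    from reinforce.storage import RolloutStorage

    rollouts = RolloutStorage(num_steps=3, obs_shape=(19,), action_shape=3,
                              hidden_size=16, num_recurrent_layers=2)
    rollouts.rewards.fill_(1.0)
    rollouts.compute_returns(torch.zeros(1), use_gae=False, gamma=0.5, gae_lambda=0.95)
    print(rollouts.returns.squeeze().tolist())  # [1.75, 1.5, 1.0, 0.0]
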
/reinforce/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import glob
5 | import os
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | # Get a render function
12 | def get_render_func(venv):
13 | if hasattr(venv, 'envs'):
14 | return venv.envs[0].render
15 | elif hasattr(venv, 'venv'):
16 | return get_render_func(venv.venv)
17 | elif hasattr(venv, 'env'):
18 | return get_render_func(venv.env)
19 |
20 | return None
21 |
22 |
23 | def get_vec_normalize(venv):
24 |     if isinstance(venv, VecNormalize):  # NOTE: VecNormalize must be supplied by the vec-env wrapper package in use; it is not imported here
25 | return venv
26 | elif hasattr(venv, 'venv'):
27 | return get_vec_normalize(venv.venv)
28 |
29 | return None
30 |
31 |
32 | # Necessary for my KFAC implementation.
33 | class AddBias(nn.Module):
34 | def __init__(self, bias):
35 | super(AddBias, self).__init__()
36 | self._bias = nn.Parameter(bias.unsqueeze(1))
37 |
38 | def forward(self, x):
39 | if x.dim() == 2:
40 | bias = self._bias.t().view(1, -1)
41 | else:
42 | bias = self._bias.t().view(1, -1, 1, 1)
43 |
44 | return x + bias
45 |
46 |
47 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
48 | """Decreases the learning rate linearly"""
49 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
50 | for param_group in optimizer.param_groups:
51 | param_group['lr'] = lr
52 |
53 |
54 | def init(module, weight_init, bias_init, gain=1):
55 | weight_init(module.weight.data, gain=gain)
56 | bias_init(module.bias.data)
57 | return module
58 |
59 |
60 | def cleanup_log_dir(log_dir):
61 | try:
62 | os.makedirs(log_dir)
63 | except OSError:
64 | files = glob.glob(os.path.join(log_dir, '*.monitor.csv'))
65 | for f in files:
66 | os.remove(f)
67 |
--------------------------------------------------------------------------------
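
update_linear_schedule anneals the learning rate linearly to zero over training; a quick check (sketch):

    import torch
    from reinforce.utils import update_linear_schedule

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    update_linear_schedule(opt, epoch=50, total_num_epochs=100, initial_lr=0.1)
    print(opt.param_groups[0]['lr'])  # 0.05 at the halfway point
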
/tools/datasets/BaseDataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | import cv2
6 | cv2.setNumThreads(0)
7 | import torch
8 | import numpy as np
9 | from random import shuffle
10 |
11 | import torch.utils.data as data
12 |
13 |
14 | class BaseDataset(data.Dataset):
15 | def __init__(self, setting, split_name, preprocess=None, file_length=None):
16 | super(BaseDataset, self).__init__()
17 | self._split_name = split_name
18 | self._img_path = setting['img_root']
19 | self._gt_path = setting['gt_root']
20 | self._portion = setting['portion'] if 'portion' in setting else None
21 | self._train_source = setting['train_source']
22 | self._eval_source = setting['eval_source']
23 | self._test_source = setting['test_source'] if 'test_source' in setting else setting['eval_source']
24 | self._down_sampling = setting['down_sampling']
25 | print("using downsampling:", self._down_sampling)
26 | self._file_names = self._get_file_names(split_name)
27 | print("Found %d images"%len(self._file_names))
28 | self._file_length = file_length
29 | self.preprocess = preprocess
30 |
31 | def __len__(self):
32 | if self._file_length is not None:
33 | return self._file_length
34 | return len(self._file_names)
35 |
36 | def __getitem__(self, index):
37 | if self._file_length is not None:
38 | names = self._construct_new_file_names(self._file_length)[index]
39 | else:
40 | names = self._file_names[index]
41 | img_path = os.path.join(self._img_path, names[0])
42 | gt_path = os.path.join(self._gt_path, names[1])
43 | item_name = names[1].split("/")[-1].split(".")[0]
44 |
45 | img, gt = self._fetch_data(img_path, gt_path)
46 | img = img[:, :, ::-1]
47 | if self.preprocess is not None:
48 | img, gt, extra_dict = self.preprocess(img, gt)
49 |
50 |         if self._split_name == 'train':  # '==', not 'is': compare string values, not identity
51 | img = torch.from_numpy(np.ascontiguousarray(img)).float()
52 | gt = torch.from_numpy(np.ascontiguousarray(gt)).long()
53 | if self.preprocess is not None and extra_dict is not None:
54 | for k, v in extra_dict.items():
55 | extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v))
56 | if 'label' in k:
57 | extra_dict[k] = extra_dict[k].long()
58 | if 'img' in k:
59 | extra_dict[k] = extra_dict[k].float()
60 |
61 | output_dict = dict(data=img, label=gt, fn=str(item_name),
62 | n=len(self._file_names))
63 | if self.preprocess is not None and extra_dict is not None:
64 | output_dict.update(**extra_dict)
65 |
66 | return output_dict
67 |
68 | def _fetch_data(self, img_path, gt_path, dtype=None):
69 | img = self._open_image(img_path, down_sampling=self._down_sampling)
70 | gt = self._open_image(gt_path, cv2.IMREAD_GRAYSCALE, dtype=dtype, down_sampling=self._down_sampling)
71 |
72 | return img, gt
73 |
74 | def _get_file_names(self, split_name):
75 | assert split_name in ['train', 'val', 'test']
76 | source = self._train_source
77 | if split_name == "val":
78 | source = self._eval_source
79 | elif split_name == 'test':
80 | source = self._test_source
81 |
82 | file_names = []
83 | with open(source) as f:
84 | files = f.readlines()
85 | if self._portion is not None:
86 | shuffle(files)
87 | num_files = len(files)
88 | if self._portion > 0:
89 | split = int(np.floor(self._portion * num_files))
90 | files = files[:split]
91 | elif self._portion < 0:
92 | split = int(np.floor((1 + self._portion) * num_files))
93 | files = files[split:]
94 |
95 | for item in files:
96 | img_name, gt_name = self._process_item_names(item)
97 | file_names.append([img_name, gt_name])
98 |
99 | return file_names
100 |
101 | def _construct_new_file_names(self, length):
102 | assert isinstance(length, int)
103 | files_len = len(self._file_names)
104 | new_file_names = self._file_names * (length // files_len)
105 |
106 | rand_indices = torch.randperm(files_len).tolist()
107 | new_indices = rand_indices[:length % files_len]
108 |
109 | new_file_names += [self._file_names[i] for i in new_indices]
110 |
111 | return new_file_names
112 |
113 | @staticmethod
114 | def _process_item_names(item):
115 | item = item.strip()
116 | # item = item.split('\t')
117 | item = item.split(' ')
118 | img_name = item[0]
119 | gt_name = item[1]
120 |
121 | return img_name, gt_name
122 |
123 | def get_length(self):
124 | return self.__len__()
125 |
126 | @staticmethod
127 | def _open_image(filepath, mode=cv2.IMREAD_COLOR, dtype=None, down_sampling=1):
128 | # cv2: B G R
129 | # h w c
130 | img = np.array(cv2.imread(filepath, mode), dtype=dtype)
131 |
132 | if isinstance(down_sampling, int):
133 | H, W = img.shape[:2]
134 | if len(img.shape) == 3:
135 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_LINEAR)
136 | else:
137 | img = cv2.resize(img, (W // down_sampling, H // down_sampling), interpolation=cv2.INTER_NEAREST)
138 | assert img.shape[0] == H // down_sampling and img.shape[1] == W // down_sampling
139 | else:
140 | assert (isinstance(down_sampling, tuple) or isinstance(down_sampling, list)) and len(down_sampling) == 2
141 | if len(img.shape) == 3:
142 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_LINEAR)
143 | else:
144 | img = cv2.resize(img, (down_sampling[1], down_sampling[0]), interpolation=cv2.INTER_NEAREST)
145 | assert img.shape[0] == down_sampling[0] and img.shape[1] == down_sampling[1]
146 |
147 | return img
148 |
149 | @classmethod
150 | def get_class_colors(*args):
151 | raise NotImplementedError
152 |
153 | @classmethod
154 | def get_class_names(*args):
155 | raise NotImplementedError
156 |
157 |
158 | if __name__ == "__main__":
159 | data_setting = {'img_root': '',
160 | 'gt_root': '',
161 | 'train_source': '',
162 |                     'eval_source': '', 'down_sampling': 1}  # 'down_sampling' is required by __init__
163 | bd = BaseDataset(data_setting, 'train', None)
164 | print(bd.get_class_names())
165 |
--------------------------------------------------------------------------------
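A worked example of the portion logic in _get_file_names above: a positive portion keeps the first fraction of the (shuffled) file list, a negative portion keeps the last |portion| fraction. With hypothetical values:

import numpy as np

files = list(range(100))

portion = 0.25
split = int(np.floor(portion * len(files)))        # 25
assert len(files[:split]) == 25                    # first quarter

portion = -0.25
split = int(np.floor((1 + portion) * len(files)))  # floor(0.75 * 100) = 75
assert len(files[split:]) == 25                    # last quarter
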
/tools/datasets/cityscapes/cityscapes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import numpy as np
5 |
6 | from datasets.BaseDataset import BaseDataset
7 |
8 |
9 | class Cityscapes(BaseDataset):
10 | trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
11 | 28, 31, 32, 33]
12 |
13 | @classmethod
14 | def get_class_colors(*args):
15 | return [[128, 64, 128], [244, 35, 232], [70, 70, 70],
16 | [102, 102, 156], [190, 153, 153], [153, 153, 153],
17 | [250, 170, 30], [220, 220, 0], [107, 142, 35],
18 | [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
19 | [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
20 | [0, 0, 230], [119, 11, 32]]
21 |
22 | @classmethod
23 | def get_class_names(*args):
24 | # class counting(gtFine)
25 | # 2953 2811 2934 970 1296 2949 1658 2808 2891 1654 2686 2343 1023 2832
26 | # 359 274 142 513 1646
27 | return ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
28 | 'traffic light', 'traffic sign',
29 | 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
30 | 'truck', 'bus', 'train', 'motorcycle', 'bicycle']
31 |
32 | @classmethod
33 | def transform_label(cls, pred, name):
34 | label = np.zeros(pred.shape)
35 | ids = np.unique(pred)
36 | for id in ids:
37 | label[np.where(pred == id)] = cls.trans_labels[id]
38 |
39 | new_name = (name.split('.')[0]).split('_')[:-1]
40 | new_name = '_'.join(new_name) + '.png'
41 |
42 | print('Trans', name, 'to', new_name, ' ',
43 | np.unique(np.array(pred, np.uint8)), ' ---------> ',
44 | np.unique(np.array(label, np.uint8)))
45 | return label, new_name
46 |
--------------------------------------------------------------------------------
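transform_label above converts predictions from the 19 contiguous training IDs back to the official Cityscapes label IDs via trans_labels (train ID 0 'road' becomes label ID 7, and so on), which is the format the official evaluation scripts expect. A compact numeric sketch:

import numpy as np

trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
                28, 31, 32, 33]
pred = np.array([[0, 1], [13, 18]])   # training IDs predicted by the network
label = np.zeros(pred.shape)
for i in np.unique(pred):
    label[pred == i] = trans_labels[i]
# label == [[7, 8], [26, 33]]: road, sidewalk, car, bicycle as official IDs
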
/tools/engine/evaluator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | from PIL import Image
6 | import cv2
7 | import numpy as np
8 | import time
9 | from tqdm import tqdm
10 |
11 | import torch
12 | import torch.multiprocessing as mp
13 |
14 | from tools.engine.logger import get_logger
15 | from tools.utils.pyt_utils import load_model, link_file, ensure_dir
16 | from tools.utils.img_utils import pad_image_to_shape, normalize
17 |
18 | logger = get_logger()
19 |
20 |
21 | class Evaluator(object):
22 | def __init__(self, dataset, class_num, image_mean, image_std, network,
23 | multi_scales, is_flip, devices=0, threds=3, config=None, logger=None,
24 | verbose=False, save_path=None, show_image=False, show_prediction=False):
25 | self.dataset = dataset
26 | self.ndata = self.dataset.get_length()
27 | self.class_num = class_num
28 | self.image_mean = image_mean
29 | self.image_std = image_std
30 | self.multi_scales = multi_scales
31 | self.is_flip = is_flip
32 | self.network = network
33 |         self.devices = [devices] if isinstance(devices, int) else devices
34 |         self.out_idx = 0  # val_func_process indexes multi-head outputs with this; it was previously never set
35 | self.threds = threds
36 | self.config = config
37 | self.logger = logger
38 |
39 | self.context = mp.get_context('spawn')
40 | self.val_func = None
41 | self.results_queue = self.context.Queue(self.ndata)
42 | self.features_queue = self.context.Queue(self.ndata)
43 |
44 | self.verbose = verbose
45 | self.save_path = save_path
46 | if save_path is not None:
47 | ensure_dir(save_path)
48 | self.show_image = show_image
49 | self.show_prediction = show_prediction
50 |
51 | def run(self, model_path, model_indice, log_file, log_file_link):
52 | """There are four evaluation modes:
53 | 1.only eval a .pth model: -e *.pth
54 | 2.only eval a certain epoch: -e epoch
55 | 3.eval all epochs in a given section: -e start_epoch-end_epoch
56 | 4.eval all epochs from a certain started epoch: -e start_epoch-
57 | """
58 | if '.pth' in model_indice:
59 | models = [model_indice, ]
60 | elif "-" in model_indice:
61 | start_epoch = int(model_indice.split("-")[0])
62 | end_epoch = model_indice.split("-")[1]
63 |
64 | models = os.listdir(model_path)
65 |         if "epoch-last.pth" in models: models.remove("epoch-last.pth")
66 | sorted_models = [None] * len(models)
67 | model_idx = [0] * len(models)
68 |
69 | for idx, m in enumerate(models):
70 | num = m.split(".")[0].split("-")[1]
71 | model_idx[idx] = num
72 | sorted_models[idx] = m
73 | model_idx = np.array([int(i) for i in model_idx])
74 |
75 | down_bound = model_idx >= start_epoch
76 | up_bound = [True] * len(sorted_models)
77 | if end_epoch:
78 | end_epoch = int(end_epoch)
79 | assert start_epoch < end_epoch
80 | up_bound = model_idx <= end_epoch
81 | bound = up_bound * down_bound
82 | model_slice = np.array(sorted_models)[bound]
83 | models = [os.path.join(model_path, model) for model in
84 | model_slice]
85 | else:
86 | models = [os.path.join(model_path,
87 | 'epoch-%s.pth' % model_indice), ]
88 |
89 | results = open(log_file, 'a')
90 | link_file(log_file, log_file_link)
91 |
92 | for model in models:
93 | logger.info("Load Model: %s" % model)
94 | self.val_func = load_model(self.network, model)
95 | result_line, mIoU = self.multi_process_evaluation()
96 |
97 | results.write('Model: ' + model + '\n')
98 | results.write(result_line)
99 | results.write('\n')
100 | results.flush()
101 |
102 | results.close()
103 |
104 | def run_online(self):
105 | """
106 | eval during training
107 | """
108 | self.val_func = self.network
109 | result_line, mIoU = self.single_process_evaluation()
110 | return result_line, mIoU
111 |
112 | def single_process_evaluation(self):
113 |         all_results = []
115 | with torch.no_grad():
116 | for idx in tqdm(range(self.ndata)):
117 | dd = self.dataset[idx]
118 | results_dict = self.func_per_iteration(dd, self.devices[0], iter=idx)
119 | all_results.append(results_dict)
120 | _, _mIoU = self.compute_metric([results_dict])
121 | result_line, mIoU = self.compute_metric(all_results)
122 | return result_line, mIoU
123 |
124 | def run_online_multiprocess(self):
125 | """
126 | eval during training
127 | """
128 | self.val_func = self.network
129 | result_line, mIoU = self.multi_process_single_gpu_evaluation()
130 | return result_line, mIoU
131 |
132 | def multi_process_single_gpu_evaluation(self):
133 | # start_eval_time = time.perf_counter()
134 | stride = int(np.ceil(self.ndata / self.threds))
135 |
136 | # start multi-process on single-gpu
137 | procs = []
138 | for d in range(self.threds):
139 | e_record = min((d + 1) * stride, self.ndata)
140 | shred_list = list(range(d * stride, e_record))
141 | device = self.devices[0]
142 |             logger.info('Thread %d handles %d samples.' % (d, len(shred_list)))
143 | p = self.context.Process(target=self.worker, args=(shred_list, device))
144 | procs.append(p)
145 |
146 | for p in procs:
147 | p.start()
148 |
149 | all_results = []
150 | for _ in tqdm(range(self.ndata)):
151 | t = self.results_queue.get()
152 | all_results.append(t)
153 | if self.verbose:
154 | self.compute_metric(all_results)
155 |
156 | for p in procs:
157 | p.join()
158 |
159 | result_line, mIoU = self.compute_metric(all_results)
160 | return result_line, mIoU
161 |
162 | def multi_process_evaluation(self):
163 | start_eval_time = time.perf_counter()
164 | nr_devices = len(self.devices)
165 | stride = int(np.ceil(self.ndata / nr_devices))
166 |
167 | # start multi-process on multi-gpu
168 | procs = []
169 | for d in range(nr_devices):
170 | e_record = min((d + 1) * stride, self.ndata)
171 | shred_list = list(range(d * stride, e_record))
172 | device = self.devices[d]
173 |             logger.info('GPU %s handles %d samples.' % (device, len(shred_list)))
174 | p = self.context.Process(target=self.worker, args=(shred_list, device))
175 | procs.append(p)
176 |
177 | for p in procs:
178 | p.start()
179 |
180 | all_results = []
181 | for _ in tqdm(range(self.ndata)):
182 | t = self.results_queue.get()
183 | all_results.append(t)
184 | if self.verbose:
185 | self.compute_metric(all_results)
186 |
187 | for p in procs:
188 | p.join()
189 |
190 | result_line, mIoU = self.compute_metric(all_results)
191 | logger.info('Evaluation Elapsed Time: %.2fs' % (time.perf_counter() - start_eval_time))
192 | return result_line, mIoU
193 |
194 | def worker(self, shred_list, device):
195 | # start_load_time = time.time()
196 | # logger.info('Load Model on Device %d: %.2fs' % (device, time.time() - start_load_time))
197 | for idx in shred_list:
198 | dd = self.dataset[idx]
199 | results_dict = self.func_per_iteration(dd, device, iter=idx)
200 | self.results_queue.put(results_dict)
201 |
202 |
203 | def func_per_iteration(self, data, device, iter=None):
204 | raise NotImplementedError
205 |
206 | def compute_metric(self, results):
207 | raise NotImplementedError
208 |
209 | # evaluate the whole image at once
210 | def whole_eval(self, img, output_size, resize=None, input_size=None, device=None):
211 | if input_size is not None:
212 | img, margin = self.process_image(img, resize=resize, crop_size=input_size)
213 | else:
214 | img = self.process_image(img, resize=resize, crop_size=input_size)
215 |
216 | pred = self.val_func_process(img, device)
217 | if input_size is not None:
218 | pred = pred[:, margin[0]:(pred.shape[1] - margin[1]), margin[2]:(pred.shape[2] - margin[3])]
219 | pred = pred.permute(1, 2, 0)
220 | pred = pred.cpu().numpy()
221 | if output_size is not None:
222 | pred = cv2.resize(pred,
223 | (output_size[1], output_size[0]),
224 | interpolation=cv2.INTER_LINEAR)
225 |
226 | return pred
227 |
228 | # slide the window to evaluate the image
229 | def sliding_eval(self, img, crop_size, stride_rate, device=None):
230 | ori_rows, ori_cols, c = img.shape
231 | processed_pred = np.zeros((ori_rows, ori_cols, self.class_num))
232 |
233 | for s in self.multi_scales:
234 | img_scale = cv2.resize(img, None, fx=s, fy=s,
235 | interpolation=cv2.INTER_LINEAR)
236 | new_rows, new_cols, _ = img_scale.shape
237 | processed_pred += self.scale_process(img_scale,
238 | (ori_rows, ori_cols),
239 | crop_size, stride_rate, device)
240 |
241 | pred = processed_pred.argmax(2)
242 |
243 | return pred
244 |
245 | def scale_process(self, img, ori_shape, crop_size, stride_rate,
246 | device=None):
247 | new_rows, new_cols, c = img.shape
248 | long_size = new_cols if new_cols > new_rows else new_rows
249 |
250 | if long_size <= crop_size:
251 | input_data, margin = self.process_image(img, crop_size=crop_size)
252 | score = self.val_func_process(input_data, device)
253 | score = score[:, margin[0]:(score.shape[1] - margin[1]),
254 | margin[2]:(score.shape[2] - margin[3])]
255 | else:
256 | stride = int(np.ceil(crop_size * stride_rate))
257 | img_pad, margin = pad_image_to_shape(img, crop_size,
258 | cv2.BORDER_CONSTANT, value=0)
259 |
260 | pad_rows = img_pad.shape[0]
261 | pad_cols = img_pad.shape[1]
262 | r_grid = int(np.ceil((pad_rows - crop_size) / stride)) + 1
263 | c_grid = int(np.ceil((pad_cols - crop_size) / stride)) + 1
264 | data_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda(
265 | device)
266 | count_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda(
267 | device)
268 |
269 | for grid_yidx in range(r_grid):
270 | for grid_xidx in range(c_grid):
271 | s_x = grid_xidx * stride
272 | s_y = grid_yidx * stride
273 | e_x = min(s_x + crop_size, pad_cols)
274 | e_y = min(s_y + crop_size, pad_rows)
275 | s_x = e_x - crop_size
276 | s_y = e_y - crop_size
277 | img_sub = img_pad[s_y:e_y, s_x: e_x, :]
278 | count_scale[:, s_y: e_y, s_x: e_x] += 1
279 |
280 | input_data, tmargin = self.process_image(img_sub, crop_size=crop_size)
281 | temp_score = self.val_func_process(input_data, device)
282 | temp_score = temp_score[:,
283 | tmargin[0]:(temp_score.shape[1] - tmargin[1]),
284 | tmargin[2]:(temp_score.shape[2] - tmargin[3])]
285 | data_scale[:, s_y: e_y, s_x: e_x] += temp_score
286 | # score = data_scale / count_scale
287 | score = data_scale
288 | score = score[:, margin[0]:(score.shape[1] - margin[1]),
289 | margin[2]:(score.shape[2] - margin[3])]
290 |
291 | score = score.permute(1, 2, 0)
292 | data_output = cv2.resize(score.cpu().numpy(),
293 | (ori_shape[1], ori_shape[0]),
294 | interpolation=cv2.INTER_LINEAR)
295 |
296 | return data_output
297 |
298 | def val_func_process(self, input_data, device=None):
299 | input_data = np.ascontiguousarray(input_data[None, :, :, :], dtype=np.float32)
300 | input_data = torch.FloatTensor(input_data).cuda(device)
301 |
302 | with torch.cuda.device(input_data.get_device()):
303 | self.val_func.eval()
304 | self.val_func.to(input_data.get_device())
305 | with torch.no_grad():
306 | score = self.val_func(input_data, output_features=[], task='new_seg')[0]
307 | if (isinstance(score, tuple) or isinstance(score, list)) and len(score) > 1:
308 | score = score[self.out_idx]
309 | score = score[0] # a single image pass, ignore batch dim
310 |
311 | if self.is_flip:
312 | input_data = input_data.flip(-1)
313 |                 score_flip = self.val_func(input_data, output_features=[], task='new_seg')[0]  # same arguments as the un-flipped pass
314 | score_flip = score_flip[0] # a single image pass, ignore batch dim
315 | score += score_flip.flip(-1)
316 | score = torch.exp(score)
317 | # score = score.data
318 |
319 | return score
320 |
321 | def process_image(self, img, resize=None, crop_size=None):
322 | p_img = img
323 |
324 | if img.shape[2] < 3:
325 | im_b = p_img
326 | im_g = p_img
327 | im_r = p_img
328 | p_img = np.concatenate((im_b, im_g, im_r), axis=2)
329 |
330 | if resize is not None:
331 | if isinstance(resize, float):
332 | _size = p_img.shape[:2]
333 | # p_img = np.array(Image.fromarray(p_img).resize((int(_size[0]*resize), int(_size[1]*resize)), Image.BILINEAR))
334 | p_img = np.array(Image.fromarray(p_img).resize((int(_size[1]*resize), int(_size[0]*resize)), Image.BILINEAR))
335 | elif isinstance(resize, tuple) or isinstance(resize, list):
336 | assert len(resize) == 2
337 | p_img = np.array(Image.fromarray(p_img).resize((int(resize[0]), int(resize[1])), Image.BILINEAR))
338 |
339 | p_img = normalize(p_img, self.image_mean, self.image_std)
340 |
341 | if crop_size is not None:
342 | p_img, margin = pad_image_to_shape(p_img, crop_size, cv2.BORDER_CONSTANT, value=0)
343 | p_img = p_img.transpose(2, 0, 1)
344 |
345 | return p_img, margin
346 |
347 | p_img = p_img.transpose(2, 0, 1)
348 |
349 | return p_img
350 |
--------------------------------------------------------------------------------
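To make the sliding-window arithmetic in scale_process concrete: with a padded 1024x2048 input, crop_size=800 and stride_rate=2/3 (illustrative values, not taken from this repo's configs), the window grid works out as follows. Scores of overlapping windows are summed into data_scale; since the division by count_scale is commented out above, overlapping regions simply accumulate score mass.

import numpy as np

pad_rows, pad_cols, crop_size = 1024, 2048, 800
stride = int(np.ceil(crop_size * 2 / 3))                    # 534
r_grid = int(np.ceil((pad_rows - crop_size) / stride)) + 1  # 2 window rows
c_grid = int(np.ceil((pad_cols - crop_size) / stride)) + 1  # 4 window columns
assert (r_grid, c_grid) == (2, 4)
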
/tools/engine/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | import sys
6 | import logging
7 |
8 | _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO')
9 | _default_level = logging.getLevelName(_default_level_name.upper())
10 |
11 |
12 | class LogFormatter(logging.Formatter):
13 | log_fout = None
14 | date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] '
15 | date = '%(asctime)s '
16 | msg = '%(message)s'
17 |
18 | def format(self, record):
19 | if record.levelno == logging.DEBUG:
20 | mcl, mtxt = self._color_dbg, 'DBG'
21 | elif record.levelno == logging.WARNING:
22 | mcl, mtxt = self._color_warn, 'WRN'
23 | elif record.levelno == logging.ERROR:
24 | mcl, mtxt = self._color_err, 'ERR'
25 | else:
26 | mcl, mtxt = self._color_normal, ''
27 |
28 | if mtxt:
29 | mtxt += ' '
30 |
31 | if self.log_fout:
32 | self.__set_fmt(self.date_full + mtxt + self.msg)
33 | formatted = super(LogFormatter, self).format(record)
34 | # self.log_fout.write(formatted)
35 | # self.log_fout.write('\n')
36 | # self.log_fout.flush()
37 | return formatted
38 |
39 | self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
40 | formatted = super(LogFormatter, self).format(record)
41 |
42 | return formatted
43 |
44 | if sys.version_info.major < 3:
45 | def __set_fmt(self, fmt):
46 | self._fmt = fmt
47 | else:
48 | def __set_fmt(self, fmt):
49 | self._style._fmt = fmt
50 |
51 | @staticmethod
52 | def _color_dbg(msg):
53 | return '\x1b[36m{}\x1b[0m'.format(msg)
54 |
55 | @staticmethod
56 | def _color_warn(msg):
57 | return '\x1b[1;31m{}\x1b[0m'.format(msg)
58 |
59 | @staticmethod
60 | def _color_err(msg):
61 | return '\x1b[1;4;31m{}\x1b[0m'.format(msg)
62 |
63 | @staticmethod
64 | def _color_omitted(msg):
65 | return '\x1b[35m{}\x1b[0m'.format(msg)
66 |
67 | @staticmethod
68 | def _color_normal(msg):
69 | return msg
70 |
71 | @staticmethod
72 | def _color_date(msg):
73 | return '\x1b[32m{}\x1b[0m'.format(msg)
74 |
75 |
76 | def get_logger(log_dir=None, log_file=None, formatter=LogFormatter):
77 | logger = logging.getLogger()
78 | logger.setLevel(_default_level)
79 | del logger.handlers[:]
80 |
81 | if log_dir and log_file:
82 | if not os.path.isdir(log_dir): os.makedirs(log_dir)
83 | LogFormatter.log_fout = True
84 | file_handler = logging.FileHandler(log_file, mode='a')
85 | file_handler.setLevel(logging.INFO)
86 |         file_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))  # instantiate; passing the class itself would break Formatter.format
87 | logger.addHandler(file_handler)
88 |
89 | stream_handler = logging.StreamHandler()
90 | stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))
91 | stream_handler.setLevel(0)
92 | logger.addHandler(stream_handler)
93 | return logger
94 |
--------------------------------------------------------------------------------
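A minimal usage sketch for get_logger above (paths are placeholders): with both log_dir and log_file given, records go to a file handler and a colorized console handler, and the verbose date_full prefix is used once a log file is configured; with neither, only the console handler is attached.

from tools.engine.logger import get_logger

logger = get_logger(log_dir='./runs/log', log_file='./runs/log/train.log')
logger.info('plain line')
logger.warning('rendered as a bold red WRN line on the console')
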
/tools/engine/tester.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import os
5 | import os.path as osp
6 | import cv2
7 | import numpy as np
8 | import time
9 | from tqdm import tqdm
11 | import torch
12 | import torch.nn.functional as F
13 | import torch.multiprocessing as mp
14 |
15 | from tools.engine.logger import get_logger  # was 'engine.logger'; use the same package path as the other tools modules
16 | from tools.utils.pyt_utils import load_model, link_file, ensure_dir
17 | from tools.utils.img_utils import pad_image_to_shape, normalize
18 |
19 | logger = get_logger()
20 |
21 |
22 | class Tester(object):
23 | def __init__(self, dataset, class_num, image_mean, image_std, network,
24 | multi_scales, is_flip, devices=0, out_idx=0, threds=3, config=None, logger=None,
25 | verbose=False, save_path=None, show_image=False):
26 | self.dataset = dataset
27 | self.ndata = self.dataset.get_length()
28 | self.class_num = class_num
29 | self.image_mean = image_mean
30 | self.image_std = image_std
31 | self.multi_scales = multi_scales
32 | self.is_flip = is_flip
33 | self.network = network
34 | self.devices = devices
35 | if type(self.devices) == int: self.devices = [self.devices]
36 | self.out_idx = out_idx
37 | self.threds = threds
38 | self.config = config
39 | self.logger = logger
40 |
41 | self.context = mp.get_context('spawn')
42 | self.val_func = None
43 | self.results_queue = self.context.Queue(self.ndata)
44 |
45 | self.verbose = verbose
46 | self.save_path = save_path
47 | if save_path is not None:
48 | ensure_dir(save_path)
49 | self.show_image = show_image
50 |
51 | def run(self, model_path, model_indice, log_file, log_file_link):
52 | """There are four evaluation modes:
53 | 1.only eval a .pth model: -e *.pth
54 | 2.only eval a certain epoch: -e epoch
55 | 3.eval all epochs in a given section: -e start_epoch-end_epoch
56 | 4.eval all epochs from a certain started epoch: -e start_epoch-
57 | """
58 | if '.pth' in model_indice:
59 | models = [model_indice, ]
60 | elif "-" in model_indice:
61 | start_epoch = int(model_indice.split("-")[0])
62 | end_epoch = model_indice.split("-")[1]
63 |
64 | models = os.listdir(model_path)
65 |         if "epoch-last.pth" in models: models.remove("epoch-last.pth")
66 | sorted_models = [None] * len(models)
67 | model_idx = [0] * len(models)
68 |
69 | for idx, m in enumerate(models):
70 | num = m.split(".")[0].split("-")[1]
71 | model_idx[idx] = num
72 | sorted_models[idx] = m
73 | model_idx = np.array([int(i) for i in model_idx])
74 |
75 | down_bound = model_idx >= start_epoch
76 | up_bound = [True] * len(sorted_models)
77 | if end_epoch:
78 | end_epoch = int(end_epoch)
79 | assert start_epoch < end_epoch
80 | up_bound = model_idx <= end_epoch
81 | bound = up_bound * down_bound
82 | model_slice = np.array(sorted_models)[bound]
83 | models = [os.path.join(model_path, model) for model in
84 | model_slice]
85 | else:
86 | models = [os.path.join(model_path,
87 | 'epoch-%s.pth' % model_indice), ]
88 |
89 | results = open(log_file, 'a')
90 | link_file(log_file, log_file_link)
91 |
92 | for model in models:
93 | logger.info("Load Model: %s" % model)
94 | self.val_func = load_model(self.network, model)
95 | result_line, mIoU = self.multi_process_evaluation()
96 |
97 | results.write('Model: ' + model + '\n')
98 | results.write(result_line)
99 | results.write('\n')
100 | results.flush()
101 |
102 | results.close()
103 |
104 | def run_online(self):
105 | """
106 | eval during training
107 | """
108 | self.val_func = self.network
109 | self.single_process_evaluation()
110 |
111 | def single_process_evaluation(self):
112 | with torch.no_grad():
113 | for idx in tqdm(range(self.ndata)):
114 | dd = self.dataset[idx]
115 | self.func_per_iteration(dd, self.devices[0], iter=idx)
116 |
117 | def run_online_multiprocess(self):
118 | """
119 | eval during training
120 | """
121 | self.val_func = self.network
122 | self.multi_process_single_gpu_evaluation()
123 |
124 | def multi_process_single_gpu_evaluation(self):
125 | # start_eval_time = time.perf_counter()
126 | stride = int(np.ceil(self.ndata / self.threds))
127 |
128 | # start multi-process on single-gpu
129 | procs = []
130 | for d in range(self.threds):
131 | e_record = min((d + 1) * stride, self.ndata)
132 | shred_list = list(range(d * stride, e_record))
133 | device = self.devices[0]
134 |             logger.info('Thread %d handles %d samples.' % (d, len(shred_list)))
135 | p = self.context.Process(target=self.worker, args=(shred_list, device))
136 | procs.append(p)
137 |
138 | for p in procs:
139 | p.start()
140 |
141 | for p in procs:
142 | p.join()
143 |
144 |
145 | def multi_process_evaluation(self):
146 | start_eval_time = time.perf_counter()
147 | nr_devices = len(self.devices)
148 | stride = int(np.ceil(self.ndata / nr_devices))
149 |
150 | # start multi-process on multi-gpu
151 | procs = []
152 | for d in range(nr_devices):
153 | e_record = min((d + 1) * stride, self.ndata)
154 | shred_list = list(range(d * stride, e_record))
155 | device = self.devices[d]
156 |             logger.info('GPU %s handles %d samples.' % (device, len(shred_list)))
157 | p = self.context.Process(target=self.worker, args=(shred_list, device))
158 | procs.append(p)
159 |
160 | for p in procs:
161 | p.start()
162 |
163 | for p in procs:
164 | p.join()
165 |
166 |
167 | def worker(self, shred_list, device):
168 | start_load_time = time.time()
169 | # logger.info('Load Model on Device %d: %.2fs' % (device, time.time() - start_load_time))
170 | for idx in shred_list:
171 | dd = self.dataset[idx]
172 | results_dict = self.func_per_iteration(dd, device, iter=idx)
173 | self.results_queue.put(results_dict)
174 |
175 | def func_per_iteration(self, data, device, iter=None):
176 | raise NotImplementedError
177 |
178 | def compute_metric(self, results):
179 | raise NotImplementedError
180 |
181 | # evaluate the whole image at once
182 | def whole_eval(self, img, output_size, input_size=None, device=None):
183 | if input_size is not None:
184 | img, margin = self.process_image(img, input_size)
185 | else:
186 | img = self.process_image(img, input_size)
187 |
188 | pred = self.val_func_process(img, device)
189 | if input_size is not None:
190 | pred = pred[:, margin[0]:(pred.shape[1] - margin[1]), margin[2]:(pred.shape[2] - margin[3])]
191 | pred = pred.permute(1, 2, 0)
192 | pred = pred.cpu().numpy()
193 | if output_size is not None:
194 | pred = cv2.resize(pred,
195 | (output_size[1], output_size[0]),
196 | interpolation=cv2.INTER_LINEAR)
197 |
198 | pred = pred.argmax(2)
199 |
200 | return pred
201 |
202 | # slide the window to evaluate the image
203 | def sliding_eval(self, img, crop_size, stride_rate, device=None):
204 | ori_rows, ori_cols, c = img.shape
205 | processed_pred = np.zeros((ori_rows, ori_cols, self.class_num))
206 |
207 | for s in self.multi_scales:
208 | img_scale = cv2.resize(img, None, fx=s, fy=s,
209 | interpolation=cv2.INTER_LINEAR)
210 | new_rows, new_cols, _ = img_scale.shape
211 | processed_pred += self.scale_process(img_scale,
212 | (ori_rows, ori_cols),
213 | crop_size, stride_rate, device)
214 |
215 | pred = processed_pred.argmax(2)
216 |
217 | return pred
218 |
219 | def scale_process(self, img, ori_shape, crop_size, stride_rate,
220 | device=None):
221 | new_rows, new_cols, c = img.shape
222 | long_size = new_cols if new_cols > new_rows else new_rows
223 |
224 | if long_size <= crop_size:
225 | input_data, margin = self.process_image(img, crop_size)
226 | score = self.val_func_process(input_data, device)
227 | score = score[:, margin[0]:(score.shape[1] - margin[1]),
228 | margin[2]:(score.shape[2] - margin[3])]
229 | else:
230 | stride = int(np.ceil(crop_size * stride_rate))
231 | img_pad, margin = pad_image_to_shape(img, crop_size,
232 | cv2.BORDER_CONSTANT, value=0)
233 |
234 | pad_rows = img_pad.shape[0]
235 | pad_cols = img_pad.shape[1]
236 | r_grid = int(np.ceil((pad_rows - crop_size) / stride)) + 1
237 | c_grid = int(np.ceil((pad_cols - crop_size) / stride)) + 1
238 | data_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda(
239 | device)
240 | count_scale = torch.zeros(self.class_num, pad_rows, pad_cols).cuda(
241 | device)
242 |
243 | for grid_yidx in range(r_grid):
244 | for grid_xidx in range(c_grid):
245 | s_x = grid_xidx * stride
246 | s_y = grid_yidx * stride
247 | e_x = min(s_x + crop_size, pad_cols)
248 | e_y = min(s_y + crop_size, pad_rows)
249 | s_x = e_x - crop_size
250 | s_y = e_y - crop_size
251 | img_sub = img_pad[s_y:e_y, s_x: e_x, :]
252 | count_scale[:, s_y: e_y, s_x: e_x] += 1
253 |
254 | input_data, tmargin = self.process_image(img_sub, crop_size)
255 | temp_score = self.val_func_process(input_data, device)
256 | temp_score = temp_score[:,
257 | tmargin[0]:(temp_score.shape[1] - tmargin[1]),
258 | tmargin[2]:(temp_score.shape[2] - tmargin[3])]
259 | data_scale[:, s_y: e_y, s_x: e_x] += temp_score
260 | # score = data_scale / count_scale
261 | score = data_scale
262 | score = score[:, margin[0]:(score.shape[1] - margin[1]),
263 | margin[2]:(score.shape[2] - margin[3])]
264 |
265 | score = score.permute(1, 2, 0)
266 | data_output = cv2.resize(score.cpu().numpy(),
267 | (ori_shape[1], ori_shape[0]),
268 | interpolation=cv2.INTER_LINEAR)
269 |
270 | return data_output
271 |
272 | def val_func_process(self, input_data, device=None):
273 | input_data = np.ascontiguousarray(input_data[None, :, :, :], dtype=np.float32)
274 | input_data = torch.FloatTensor(input_data).cuda(device)
275 |
276 | with torch.cuda.device(input_data.get_device()):
277 | self.val_func.eval()
278 | self.val_func.to(input_data.get_device())
279 | with torch.no_grad():
280 | score = self.val_func(input_data, task='new_seg')[0]
281 | if (isinstance(score, tuple) or isinstance(score, list)) and len(score) > 1:
282 | score = score[self.out_idx]
283 | score = score[0] # a single image pass, ignore batch dim
284 |
285 | if self.is_flip:
286 | input_data = input_data.flip(-1)
287 |                 score_flip = self.val_func(input_data, task='new_seg')[0]  # same arguments as the un-flipped pass
288 | score_flip = score_flip[0]
289 | score += score_flip.flip(-1)
290 | score = torch.exp(score)
291 | # score = score.data
292 |
293 | return score
294 |
295 | def process_image(self, img, crop_size=None):
296 | p_img = img
297 |
298 | if img.shape[2] < 3:
299 | im_b = p_img
300 | im_g = p_img
301 | im_r = p_img
302 | p_img = np.concatenate((im_b, im_g, im_r), axis=2)
303 |
304 | p_img = normalize(p_img, self.image_mean, self.image_std)
305 |
306 | if crop_size is not None:
307 | p_img, margin = pad_image_to_shape(p_img, crop_size, cv2.BORDER_CONSTANT, value=0)
308 | p_img = p_img.transpose(2, 0, 1)
309 |
310 | return p_img, margin
311 |
312 | p_img = p_img.transpose(2, 0, 1)
313 |
314 | return p_img
315 |
--------------------------------------------------------------------------------
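The epoch-range parsing shared by Evaluator.run and Tester.run reduces to the filter below; the checkpoint file names are hypothetical:

import numpy as np

models = ['epoch-5.pth', 'epoch-10.pth', 'epoch-15.pth']
start_epoch, end_epoch = 10, None           # i.e. '-e 10-'
idx = np.array([int(m.split('.')[0].split('-')[1]) for m in models])
keep = idx >= start_epoch
if end_epoch is not None:
    keep &= idx <= end_epoch
selected = np.array(models)[keep].tolist()  # ['epoch-10.pth', 'epoch-15.pth']
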
/tools/seg_opr/metric.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import numpy as np
5 |
6 | np.seterr(divide='ignore', invalid='ignore')
7 |
8 |
9 | # voc cityscapes metric
10 | def hist_info(n_cl, pred, gt):
11 | assert (pred.shape == gt.shape), "pred: " + str(pred.shape) + " v.s. gt: " + str(gt.shape)
12 | k = (gt >= 0) & (gt < n_cl)
13 | labeled = np.sum(k)
14 | correct = np.sum((pred[k] == gt[k]))
15 |
16 | return np.bincount(n_cl * gt[k].astype(int) + pred[k].astype(int),
17 | minlength=n_cl ** 2).reshape(n_cl,
18 | n_cl), labeled, correct
19 |
20 |
21 | def compute_score(hist, correct, labeled):
22 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
23 | mean_IU = np.nanmean(iu)
24 | mean_IU_no_back = np.nanmean(iu[1:])
25 | freq = hist.sum(1) / hist.sum()
26 | # freq_IU = (iu[freq > 0] * freq[freq > 0]).sum()
27 | mean_pixel_acc = correct / labeled
28 |
29 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc
30 |
31 |
32 | # ade metric
33 | def meanIoU(area_intersection, area_union):
34 | iou = 1.0 * np.sum(area_intersection, axis=1) / np.sum(area_union, axis=1)
35 | meaniou = np.nanmean(iou)
36 | meaniou_no_back = np.nanmean(iou[1:])
37 |
38 | return iou, meaniou, meaniou_no_back
39 |
40 |
41 | def intersectionAndUnion(imPred, imLab, numClass):
42 | # Remove classes from unlabeled pixels in gt image.
43 | # We should not penalize detections in unlabeled portions of the image.
44 | imPred = np.asarray(imPred).copy()
45 | imLab = np.asarray(imLab).copy()
46 |
47 | imPred += 1
48 | imLab += 1
49 | # Remove classes from unlabeled pixels in gt image.
50 | # We should not penalize detections in unlabeled portions of the image.
51 | imPred = imPred * (imLab > 0)
52 |
53 | # imPred = imPred * (imLab >= 0)
54 |
55 | # Compute area intersection:
56 | intersection = imPred * (imPred == imLab)
57 | (area_intersection, _) = np.histogram(intersection, bins=numClass,
58 | range=(1, numClass))
59 |
60 | # Compute area union:
61 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
62 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
63 | area_union = area_pred + area_lab - area_intersection
64 |
65 | return area_intersection, area_union
66 |
67 |
68 | def mean_pixel_accuracy(pixel_correct, pixel_labeled):
69 | mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (
70 | np.spacing(1) + np.sum(pixel_labeled))
71 |
72 | return mean_pixel_accuracy
73 |
74 |
75 | def pixelAccuracy(imPred, imLab):
76 | # Remove classes from unlabeled pixels in gt image.
77 | # We should not penalize detections in unlabeled portions of the image.
78 | pixel_labeled = np.sum(imLab >= 0)
79 | pixel_correct = np.sum((imPred == imLab) * (imLab >= 0))
80 | pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
81 |
82 | return pixel_accuracy, pixel_correct, pixel_labeled
83 |
84 |
85 | def accuracy(preds, label):
86 | valid = (label >= 0)
87 | acc_sum = (valid * (preds == label)).sum()
88 | valid_sum = valid.sum()
89 | acc = float(acc_sum) / (valid_sum + 1e-10)
90 | return acc, valid_sum
91 |
--------------------------------------------------------------------------------
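The bincount trick in hist_info above builds the full confusion matrix in one vectorized call: every (gt, pred) pair is encoded as the flat index n_cl * gt + pred. A worked example, from which compute_score then derives the per-class IoU:

import numpy as np

n_cl = 3
gt = np.array([0, 0, 1, 2])
pred = np.array([0, 1, 1, 2])
hist = np.bincount(n_cl * gt + pred, minlength=n_cl ** 2).reshape(n_cl, n_cl)
# hist == [[1, 1, 0],
#          [0, 1, 0],
#          [0, 0, 1]]   (rows = ground truth, cols = prediction)
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
# iu == [0.5, 0.5, 1.0]; mean IoU = 2/3
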
/tools/utils/img_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import cv2
5 | import numpy as np
6 | import numbers
7 | import random
8 | import collections.abc
9 |
10 |
11 | def get_2dshape(shape, *, zero=True):
12 |     if not isinstance(shape, collections.abc.Iterable):  # collections.Iterable was removed in Python 3.10
13 | shape = int(shape)
14 | shape = (shape, shape)
15 | else:
16 | h, w = map(int, shape)
17 | shape = (h, w)
18 | if zero:
19 | minv = 0
20 | else:
21 | minv = 1
22 |
23 | assert min(shape) >= minv, 'invalid shape: {}'.format(shape)
24 | return shape
25 |
26 |
27 | def random_crop_pad_to_shape(img, crop_pos, crop_size, pad_label_value):
28 | h, w = img.shape[:2]
29 | start_crop_h, start_crop_w = crop_pos
30 | assert ((start_crop_h < h) and (start_crop_h >= 0))
31 | assert ((start_crop_w < w) and (start_crop_w >= 0))
32 |
33 | crop_size = get_2dshape(crop_size)
34 | crop_h, crop_w = crop_size
35 |
36 | img_crop = img[start_crop_h:start_crop_h + crop_h,
37 | start_crop_w:start_crop_w + crop_w, ...]
38 |
39 | img_, margin = pad_image_to_shape(img_crop, crop_size, cv2.BORDER_CONSTANT,
40 | pad_label_value)
41 |
42 | return img_, margin
43 |
44 |
45 | def generate_random_crop_pos(ori_size, crop_size):
46 | ori_size = get_2dshape(ori_size)
47 | h, w = ori_size
48 |
49 | crop_size = get_2dshape(crop_size)
50 | crop_h, crop_w = crop_size
51 |
52 | pos_h, pos_w = 0, 0
53 |
54 | if h > crop_h:
55 |         pos_h = random.randint(0, h - crop_h)  # randint is inclusive; h - crop_h is the last in-bounds start row
56 |
57 | if w > crop_w:
58 |         pos_w = random.randint(0, w - crop_w)
59 |
60 | return pos_h, pos_w
61 |
62 |
63 | def pad_image_to_shape(img, shape, border_mode, value):
64 | margin = np.zeros(4, np.uint32)
65 | shape = get_2dshape(shape)
66 | pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0
67 | pad_width = shape[1] - img.shape[1] if shape[1] - img.shape[1] > 0 else 0
68 |
69 | margin[0] = pad_height // 2
70 | margin[1] = pad_height // 2 + pad_height % 2
71 | margin[2] = pad_width // 2
72 | margin[3] = pad_width // 2 + pad_width % 2
73 |
74 | img = cv2.copyMakeBorder(img, margin[0], margin[1], margin[2], margin[3],
75 | border_mode, value=value)
76 |
77 | return img, margin
78 |
79 |
80 | def pad_image_size_to_multiples_of(img, multiple, pad_value):
81 | h, w = img.shape[:2]
82 | d = multiple
83 |
84 | def canonicalize(s):
85 | v = s // d
86 | return (v + (v * d != s)) * d
87 |
88 | th, tw = map(canonicalize, (h, w))
89 |
90 | return pad_image_to_shape(img, (th, tw), cv2.BORDER_CONSTANT, pad_value)
91 |
92 |
93 | def resize_ensure_shortest_edge(img, edge_length,
94 | interpolation_mode=cv2.INTER_LINEAR):
95 | assert isinstance(edge_length, int) and edge_length > 0, edge_length
96 | h, w = img.shape[:2]
97 | if h < w:
98 | ratio = float(edge_length) / h
99 | th, tw = edge_length, max(1, int(ratio * w))
100 | else:
101 | ratio = float(edge_length) / w
102 | th, tw = max(1, int(ratio * h)), edge_length
103 |     img = cv2.resize(img, (tw, th), interpolation=interpolation_mode)  # the third positional arg of cv2.resize is dst, not interpolation
104 |
105 | return img
106 |
107 |
108 | def random_scale(img, gt, scales):
109 | # scale = random.choice(scales)
110 | scale = random.uniform(min(scales), max(scales))
111 | sh = int(img.shape[0] * scale)
112 | sw = int(img.shape[1] * scale)
113 | img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR)
114 | gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST)
115 |
116 | return img, gt, scale
117 |
118 |
119 | def random_scale_with_length(img, gt, length):
120 | size = random.choice(length)
121 | sh = size
122 | sw = size
123 | img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR)
124 | gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST)
125 |
126 | return img, gt, size
127 |
128 |
129 | def random_mirror(img, gt):
130 | if random.random() >= 0.5:
131 | img = cv2.flip(img, 1)
132 | gt = cv2.flip(gt, 1)
133 |
134 | return img, gt,
135 |
136 |
137 | def random_rotation(img, gt):
138 | angle = random.random() * 20 - 10
139 | h, w = img.shape[:2]
140 | rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
141 | img = cv2.warpAffine(img, rotation_matrix, (w, h), flags=cv2.INTER_LINEAR)
142 | gt = cv2.warpAffine(gt, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST)
143 |
144 | return img, gt
145 |
146 |
147 | def random_gaussian_blur(img):
148 | gauss_size = random.choice([1, 3, 5, 7])
149 | if gauss_size > 1:
150 | # do the gaussian blur
151 | img = cv2.GaussianBlur(img, (gauss_size, gauss_size), 0)
152 |
153 | return img
154 |
155 |
156 | def center_crop(img, shape):
157 | h, w = shape[0], shape[1]
158 | y = (img.shape[0] - h) // 2
159 | x = (img.shape[1] - w) // 2
160 | return img[y:y + h, x:x + w]
161 |
162 |
163 | def random_crop(img, gt, size):
164 | if isinstance(size, numbers.Number):
165 | size = (int(size), int(size))
166 |
167 | h, w = img.shape[:2]
168 | crop_h, crop_w = size[0], size[1]
169 |
170 | if h > crop_h:
171 |         x = random.randint(0, h - crop_h)  # randint is inclusive; the former +1 could yield an undersized crop
172 | img = img[x:x + crop_h, :, :]
173 | gt = gt[x:x + crop_h, :]
174 |
175 | if w > crop_w:
176 |         x = random.randint(0, w - crop_w)
177 | img = img[:, x:x + crop_w, :]
178 | gt = gt[:, x:x + crop_w]
179 |
180 | return img, gt
181 |
182 |
183 | def normalize(img, mean, std):
184 | # pytorch pretrained model need the input range: 0-1
185 | img = img.astype(np.float32) / 255.0
186 | img = img - mean
187 | img = img / std
188 |
189 | return img
190 |
--------------------------------------------------------------------------------
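A worked example of pad_image_to_shape's margin bookkeeping above: the height and width deficits are split evenly, with any odd pixel going to the bottom/right, and the returned margins are exactly what scale_process later uses to crop scores back to the unpadded region.

import numpy as np
import cv2

from tools.utils.img_utils import pad_image_to_shape

img = np.zeros((500, 600, 3), np.uint8)
out, margin = pad_image_to_shape(img, (512, 640), cv2.BORDER_CONSTANT, value=0)
# margin == [6, 6, 20, 20]  (top, bottom, left, right)
assert out.shape == (512, 640, 3)
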
/tools/utils/pyt_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | # encoding: utf-8
5 | import os
6 | import time
7 | import argparse
8 | from collections import OrderedDict
9 |
10 | import torch
11 | import torch.distributed as dist
12 |
13 | from tools.engine.logger import get_logger
14 |
15 | logger = get_logger()
16 |
17 | model_urls = {
18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
23 | }
24 |
25 |
26 | # def reduce_tensor(tensor, dst=0, op=dist.ReduceOp.SUM, world_size=1):
27 | # tensor = tensor.clone()
28 | # dist.reduce(tensor, dst, op)
29 | # if dist.get_rank() == dst:
30 | # tensor.div_(world_size)
31 | #
32 | # return tensor
33 |
34 |
35 | # def all_reduce_tensor(tensor, op=dist.ReduceOp.SUM, world_size=1):
36 | # tensor = tensor.clone()
37 | # dist.all_reduce(tensor, op)
38 | # tensor.div_(world_size)
39 | #
40 | # return tensor
41 |
42 |
43 | def load_model(model, model_file, is_restore=False):
44 | t_start = time.time()
45 | if isinstance(model_file, str):
46 | state_dict = torch.load(model_file)
47 | if 'model' in state_dict.keys():
48 | state_dict = state_dict['model']
49 | else:
50 | state_dict = model_file
51 | t_ioend = time.time()
52 |
53 | if is_restore:
54 | new_state_dict = OrderedDict()
55 | for k, v in state_dict.items():
56 | name = 'module.' + k
57 | new_state_dict[name] = v
58 | state_dict = new_state_dict
59 |
60 | model.load_state_dict(state_dict, strict=False)
61 | ckpt_keys = set(state_dict.keys())
62 | own_keys = set(model.state_dict().keys())
63 | missing_keys = own_keys - ckpt_keys
64 | unexpected_keys = ckpt_keys - own_keys
65 |
66 | if len(missing_keys) > 0:
67 | logger.warning('Missing key(s) in state_dict: {}'.format(
68 | ', '.join('{}'.format(k) for k in missing_keys)))
69 |
70 | if len(unexpected_keys) > 0:
71 | logger.warning('Unexpected key(s) in state_dict: {}'.format(
72 | ', '.join('{}'.format(k) for k in unexpected_keys)))
73 |
74 | del state_dict
75 | t_end = time.time()
76 | logger.info(
77 | "Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format(
78 | t_ioend - t_start, t_end - t_ioend))
79 |
80 | return model
81 |
82 |
83 | def parse_devices(input_devices):
84 | if input_devices.endswith('*'):
85 | devices = list(range(torch.cuda.device_count()))
86 | return devices
87 |
88 | devices = []
89 | for d in input_devices.split(','):
90 | if '-' in d:
91 | start_device, end_device = d.split('-')[0], d.split('-')[1]
92 | assert start_device != ''
93 | assert end_device != ''
94 | start_device, end_device = int(start_device), int(end_device)
95 | assert start_device < end_device
96 | assert end_device < torch.cuda.device_count()
97 | for sd in range(start_device, end_device + 1):
98 | devices.append(sd)
99 | else:
100 | device = int(d)
101 | assert device < torch.cuda.device_count()
102 | devices.append(device)
103 |
104 | logger.info('using devices {}'.format(
105 | ', '.join([str(d) for d in devices])))
106 |
107 | return devices
108 |
109 |
110 | def extant_file(x):
111 | """
112 | 'Type' for argparse - checks that file exists but does not open.
113 | """
114 | if not os.path.exists(x):
115 | # Argparse uses the ArgumentTypeError to give a rejection message like:
116 | # error: argument input: x does not exist
117 | raise argparse.ArgumentTypeError("{0} does not exist".format(x))
118 | return x
119 |
120 |
121 | def link_file(src, target):
122 | if os.path.isdir(target) or os.path.isfile(target):
123 | os.remove(target)
124 | os.system('ln -s {} {}'.format(src, target))
125 |
126 |
127 | def ensure_dir(path):
128 | if not os.path.isdir(path):
129 | os.makedirs(path)
130 |
131 |
132 | def _dbg_interactive(var, value):
133 | from IPython import embed
134 | embed()
135 |
--------------------------------------------------------------------------------
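The is_restore branch of load_model above exists to bridge checkpoints saved from a bare model into a DataParallel-wrapped one, whose parameter keys carry a 'module.' prefix. A self-contained sketch of that key rewrite:

from collections import OrderedDict

import torch.nn as nn

plain = nn.Linear(4, 2).state_dict()             # keys: 'weight', 'bias'
wrapped_model = nn.DataParallel(nn.Linear(4, 2)) # keys: 'module.weight', ...

restored = OrderedDict(('module.' + k, v) for k, v in plain.items())
wrapped_model.load_state_dict(restored)          # keys now line up
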
/tools/utils/visualize.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | import numpy as np
5 | import cv2
6 | import scipy.io as sio
7 |
8 |
9 | def set_img_color(colors, background, img, gt, show255=False, weight_foreground=0.55):
10 | origin = np.array(img)
11 | for i in range(len(colors)):
12 | if i != background:
13 | img[np.where(gt == i)] = colors[i]
14 | if show255:
15 | img[np.where(gt == 255)] = 0
16 | cv2.addWeighted(img, weight_foreground, origin, 1 - weight_foreground, 0, img)
17 | return img
18 |
19 |
20 | def show_prediction(colors, background, img, pred):
21 | im = np.array(img, np.uint8)
22 | set_img_color(colors, background, im, pred, weight_foreground=1)
23 | final = np.array(im)
24 | return final
25 |
26 |
27 | def show_img(colors, background, img, clean, gt, *pds):
28 | im1 = np.array(img, np.uint8)
29 | # set_img_color(colors, background, im1, clean)
30 | final = np.array(im1)
31 | # the pivot black bar
32 | pivot = np.zeros((im1.shape[0], 15, 3), dtype=np.uint8)
33 | for pd in pds:
34 | im = np.array(img, np.uint8)
35 | # pd[np.where(gt == 255)] = 255
36 | set_img_color(colors, background, im, pd)
37 | final = np.column_stack((final, pivot))
38 | final = np.column_stack((final, im))
39 |
40 | im = np.array(img, np.uint8)
41 | set_img_color(colors, background, im, gt, True)
42 | final = np.column_stack((final, pivot))
43 | final = np.column_stack((final, im))
44 | return final
45 |
46 |
47 | def get_colors(class_num):
48 | colors = []
49 | for i in range(class_num):
50 | colors.append((np.random.random((1, 3)) * 255).tolist()[0])
51 |
52 | return colors
53 |
54 |
55 | def get_ade_colors():
56 | colors = sio.loadmat('./color150.mat')['colors']
57 | colors = colors[:, ::-1, ]
58 | colors = np.array(colors).astype(int).tolist()
59 | colors.insert(0, [0, 0, 0])
60 |
61 | return colors
62 |
63 |
64 | def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False,
65 | no_print=False):
66 | n = iu.size
67 | lines = []
68 | for i in range(n):
69 | if class_names is None:
70 | cls = 'Class %d:' % (i + 1)
71 | else:
72 | cls = '%d %s' % (i + 1, class_names[i])
73 | lines.append('%-8s\t%.3f%%' % (cls, iu[i] * 100))
74 | mean_IU = np.nanmean(iu)
75 | # mean_IU_no_back = np.nanmean(iu[1:])
76 | mean_IU_no_back = np.nanmean(iu[:-1])
77 | if show_no_back:
78 | lines.append(
79 | '---------------------------- %-8s\t%.3f%%\t%-8s\t%.3f%%\t%-8s\t%.3f%%' % (
80 | 'mean_IU', mean_IU * 100, 'mean_IU_no_back',
81 | mean_IU_no_back * 100,
82 | 'mean_pixel_ACC', mean_pixel_acc * 100))
83 | else:
85 | lines.append(
86 | '---------------------------- %-8s\t%.3f%%\t%-8s\t%.3f%%' % (
87 | 'mean_IU', mean_IU * 100, 'mean_pixel_ACC',
88 | mean_pixel_acc * 100))
89 | line = "\n".join(lines)
90 | if not no_print:
91 | print(line)
92 | return line
93 |
--------------------------------------------------------------------------------
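The overlay in set_img_color above is a straight alpha blend: class colors are first painted onto the image, then mixed with the original at weight_foreground (0.55 by default). In isolation:

import numpy as np
import cv2

img = np.zeros((2, 2, 3), np.uint8)
origin = img.copy()
img[:] = (255, 0, 0)  # pretend every pixel was assigned a class color
cv2.addWeighted(img, 0.55, origin, 0.45, 0, img)
# img is now ~(140, 0, 0): 55% class color, 45% original pixel
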
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under an NVIDIA Open Source Non-commercial license.
3 |
4 | # [x] train resnet101 with proxy guidance on visda17
5 | # [x] evaluation on visda17
6 |
7 | import argparse
8 | import os
9 | import sys
10 | import logging
11 | import time
12 | from tqdm import tqdm
13 | import torch
14 | import torch.nn as nn
15 | from torch.nn import functional as F
16 | from torch.utils.data import DataLoader
17 | import torchvision.transforms as transforms
18 | from pdb import set_trace as bp
19 |
20 | from data.visda17 import VisDA17
21 | from model.resnet import resnet101
22 | from utils.utils import get_params, IterNums, save_checkpoint, AverageMeter, lr_poly, adjust_learning_rate, accuracy
23 | from utils.logger import prepare_logger, prepare_seed
24 | from utils.sgd import SGD
25 |
26 | CrossEntropyLoss = nn.CrossEntropyLoss(reduction='mean')
27 | KLDivLoss = nn.KLDivLoss(reduction='batchmean')
28 |
29 | parser = argparse.ArgumentParser(description='ASG Training')
30 | parser.add_argument('--data', default='/raid/taskcv-2017-public/classification/data', help='path to dataset')
31 | parser.add_argument('--epochs', default=30, type=int, help='number of total epochs to run')
32 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
33 | parser.add_argument('--batch-size', default=32, type=int, dest='batch_size', help='mini-batch size (default: 32)')
34 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, help='initial learning rate')
35 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, help='weight decay (default: 5e-4)')
36 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
37 | parser.add_argument('--lwf', default=0., type=float, dest='lwf', help='weight of KL loss for LwF (default: 0)')
38 | parser.add_argument('--resume', default='none', type=str, help='path to latest checkpoint (default: none)')
39 | parser.add_argument('--evaluate', action='store_true', help='evaluate the model on the validation set (default: False)')
40 | parser.add_argument('--timestamp', type=str, default='none', help='timestamp for logging naming')
41 | parser.add_argument('--save_dir', type=str, default="./runs", help='root folder to save checkpoints and log.')
42 | parser.add_argument('--train_blocks', type=str, default="conv1.bn1.layer1.layer2.layer3.layer4.fc", help='blocks to train, separated by dots.')
43 | parser.add_argument('--num-class', default=12, type=int, dest='num_class', help='the number of classes')
44 | parser.add_argument('--rand_seed', default=0, type=int, help='random seed (default: 0)')
45 |
46 | best_prec1 = 0
47 |
48 | def main():
49 | global args, best_prec1
50 | PID = os.getpid()
51 | args = parser.parse_args()
52 | prepare_seed(args.rand_seed)
53 |
54 | if args.timestamp == 'none':
55 | args.timestamp = "{:}".format(time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time())))
56 |
57 | # Log outputs
58 | if args.evaluate:
59 | args.save_dir = args.save_dir + "/Visda17-Res101-evaluate" + \
60 | "%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp)
61 | else:
62 | args.save_dir = args.save_dir + \
63 | "/Visda17-Res101-%s-train.%s-LR%.2E-epoch%d-batch%d-seed%d"%(
64 | "LWF%.2f"%args.lwf if args.lwf > 0 else "XE", args.train_blocks, args.lr, args.epochs, args.batch_size, args.rand_seed) + \
65 | "%s/%s"%('/'+args.resume if args.resume != 'none' else '', args.timestamp)
66 | logger = prepare_logger(args)
67 |
68 | data_transforms = {
69 | 'train': transforms.Compose([
70 | transforms.Resize(224),
71 | transforms.CenterCrop(224),
72 | transforms.ToTensor(),
73 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
74 | ]),
75 | 'val': transforms.Compose([
76 | transforms.Resize(224),
77 | transforms.CenterCrop(224),
78 | transforms.ToTensor(),
79 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
80 | ]),
81 | }
82 |
83 | kwargs = {'num_workers': 20, 'pin_memory': True}
84 | trainset = VisDA17(txt_file=os.path.join(args.data, "train/image_list.txt"), root_dir=os.path.join(args.data, "train"), transform=data_transforms['train'])
85 | valset = VisDA17(txt_file=os.path.join(args.data, "validation/image_list.txt"), root_dir=os.path.join(args.data, "validation"), transform=data_transforms['val'], label_one_hot=True)
86 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
87 | val_loader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, **kwargs)
88 |
89 | model = resnet101(pretrained=True)
90 | num_ftrs = model.fc.in_features
91 | fc_layers = nn.Sequential(
92 | nn.Linear(num_ftrs, 512),
93 | nn.ReLU(inplace=True),
94 | nn.Linear(512, args.num_class),
95 | )
96 | model.fc_new = fc_layers
97 |
98 | train_blocks = args.train_blocks.split('.')
99 | # default turn-off fc, turn-on fc_new
100 | for param in model.fc.parameters():
101 | param.requires_grad = False
102 | ##### Freeze several bottom layers (Optional) #####
103 | non_train_blocks = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
104 | for name in train_blocks:
105 | try:
106 | non_train_blocks.remove(name)
107 | except Exception:
108 | print("cannot find block name %s\nAvailable blocks are: conv1, bn1, layer1, layer2, layer3, layer4, fc"%name)
109 | for name in non_train_blocks:
110 | for param in getattr(model, name).parameters():
111 | param.requires_grad = False
112 |
113 | # Setup optimizer
114 | factor = 0.1
115 | sgd_in = []
116 | for name in train_blocks:
117 | if name != 'fc':
118 | sgd_in.append({'params': get_params(model, [name]), 'lr': factor*args.lr})
119 | else:
120 | sgd_in.append({'params': get_params(model, ["fc_new"]), 'lr': args.lr})
121 | base_lrs = [ group['lr'] for group in sgd_in ]
122 | optimizer = SGD(sgd_in, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
123 |
124 | # Optionally resume from a checkpoint
125 | if args.resume != 'none':
126 | if os.path.isfile(args.resume):
127 | print("=> loading checkpoint '{}'".format(args.resume))
128 | checkpoint = torch.load(args.resume)
129 | args.start_epoch = checkpoint['epoch']
130 | best_prec1 = checkpoint['best_prec1']
131 | model.load_state_dict(checkpoint['state_dict'])
132 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
133 | else:
134 | print("=ImageClassdata> no checkpoint found at '{}'".format(args.resume))
135 |
136 | model = model.cuda()
137 |
138 | model_old = None
139 | if args.lwf > 0:
140 | # create a fixed model copy for Life-long learning
141 | model_old = resnet101(pretrained=True)
142 | for param in model_old.parameters():
143 | param.requires_grad = False
144 | model_old.eval()
145 | model_old.cuda()
146 |
147 | if args.evaluate:
148 | prec1 = validate(val_loader, model)
149 | print(prec1)
150 | exit(0)
151 |
152 | # Main training loop
153 | iter_max = args.epochs * len(train_loader)
154 | iter_stat = IterNums(iter_max)
155 | for epoch in range(args.start_epoch, args.epochs):
156 | print("<< ============== JOB (PID = %d) %s ============== >>"%(PID, args.save_dir))
157 | logger.log("Epoch: %d"%(epoch+1))
158 | # train for one epoch
159 | train(train_loader, model, optimizer, base_lrs, iter_stat, epoch, logger.writer, model_old=model_old, adjust_lr=True)
160 |
161 | # evaluate on validation set
162 | prec1 = validate(val_loader, model)
163 | logger.writer.add_scalar("prec", prec1, epoch)
164 |
165 | # remember best prec@1 and save checkpoint
166 | is_best = prec1 > best_prec1
167 | best_prec1 = max(prec1, best_prec1)
168 | save_checkpoint(args.save_dir, {
169 | 'epoch': epoch + 1,
170 | 'state_dict': model.state_dict(),
171 | 'best_prec1': best_prec1,
172 | }, is_best)
173 |
174 | logging.info('Best accuracy: {prec1:.3f}'.format(prec1=best_prec1))
175 |
176 |
177 | def train(train_loader, model, optimizer, base_lrs, iter_stat, epoch, writer, model_old=None, adjust_lr=True):
178 | """Train for one epoch on the training set"""
179 | kl_weight = args.lwf
180 | batch_time = AverageMeter()
181 | losses = AverageMeter()
182 | losses_kl = AverageMeter()
183 |
184 | model.eval()  # keep BatchNorm in eval mode (running stats frozen); gradients still flow to unfrozen weights
185 |
186 | # start timer
187 | end = time.time()
188 |
189 | # train for one epoch
190 | optimizer.zero_grad()
191 | epoch_size = len(train_loader)
192 | train_loader_iter = iter(train_loader)
193 |
194 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
195 | pbar = tqdm(range(epoch_size), file=sys.stdout, bar_format=bar_format, ncols=80)
196 |
197 | for idx_iter in pbar:
198 |
199 | optimizer.zero_grad()
200 | if adjust_lr:
201 | lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9)
202 | writer.add_scalar("lr", lr, idx_iter + epoch * epoch_size)
203 | adjust_learning_rate(base_lrs, optimizer, iter_stat.iter_curr, iter_stat.iter_max, 0.9)
204 |
205 | input, label = next(train_loader_iter)
206 | label = label.cuda()
207 | input = input.cuda()
208 |
209 | # compute output
210 | output, features_new = model(input, output_features=['layer1', 'layer4'], task='new')
211 |
212 | # compute gradient
213 | loss = CrossEntropyLoss(output, label.long())
214 |
215 | # LWF KL div
216 | if model_old is None:
217 | loss_kl = 0
218 | else:
219 | output_new = model.forward_fc(features_new['layer4'], task='old')
220 | output_old, features_old = model_old(input, output_features=['layer1', 'layer4'], task='old')
221 | loss_kl = KLDivLoss(F.log_softmax(output_new, dim=1), F.softmax(output_old, dim=1)).sum(-1)
222 |
223 | (loss + kl_weight * loss_kl).backward()
224 |
225 | # measure accuracy and record loss
226 | losses.update(float(loss), input.size(0))  # cast to float so the meter does not retain the autograd graph
227 | losses_kl.update(float(loss_kl), input.size(0))
228 |
229 | # compute gradient and do SGD step
230 | optimizer.step()
231 | # increment iter number
232 | iter_stat.update()
233 | # measure elapsed time
234 | batch_time.update(time.time() - end)
235 | end = time.time()
236 |
237 | writer.add_scalar("loss/ce", losses.val, idx_iter + epoch * epoch_size)
238 | writer.add_scalar("loss/kl", losses_kl.val, idx_iter + epoch * epoch_size)
239 | writer.add_scalar("loss/total", losses.val + losses_kl.val, idx_iter + epoch * epoch_size)
240 | description = "[loss: %.3f][loss_kl: %.3f]"%(losses.val, losses_kl.val)
241 | pbar.set_description("[Step %d/%d]"%(idx_iter + 1, epoch_size) + description)
242 |
243 |
244 | def validate(val_loader, model):
245 | """Perform validation on the validation set"""
246 | batch_time = AverageMeter()
247 | top1 = AverageMeter()
248 |
249 | model.eval()
250 |
251 | end = time.time()
252 | val_size = len(val_loader)
253 | val_loader_iter = iter(val_loader)
254 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
255 | pbar = tqdm(range(val_size), file=sys.stdout, bar_format=bar_format, ncols=140)
256 | with torch.no_grad():
257 | for idx_iter in pbar:
258 | input, label = next(val_loader_iter)
259 |
260 | input = input.cuda()
261 | label = label.cuda()
262 |
263 | # compute output
264 | output = torch.sigmoid(model(input, task='new')[0])
265 | output = (output + torch.sigmoid(model(torch.flip(input, dims=(3,)), task='new')[0])) / 2
266 |
267 | # accumulate accuracy
268 | prec1, gt_num = accuracy(output.data, label, args.num_class, topk=(1,))
269 | top1.update(prec1[0], gt_num[0])
270 |
271 | # measure elapsed time
272 | batch_time.update(time.time() - end)
273 | end = time.time()
274 |
275 | description = "[Acc@1-mean: %.2f][Acc@1-cls: %s]"%(top1.vec2sca_avg, str(top1.avg.numpy().round(1)))
276 | pbar.set_description("[Step %d/%d]"%(idx_iter + 1, val_size) + description)
277 |
278 | logging.info(' * Prec@1 {top1.vec2sca_avg:.3f}'.format(top1=top1))
279 | logging.info(' * Prec@1 {top1.avg}'.format(top1=top1))
280 |
281 | return top1.vec2sca_avg
282 |
283 |
284 | if __name__ == "__main__":
285 | main()
286 |
--------------------------------------------------------------------------------
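
Note on the loss in train() above: the KL term implements Learning without Forgetting (LwF). The new-task features are passed through the old-task head and penalized for drifting from the frozen ImageNet-pretrained model. A minimal, self-contained sketch of that distillation term, with illustrative logits standing in for the repo's model outputs:

import torch
import torch.nn.functional as F

kl_div = torch.nn.KLDivLoss(reduction='batchmean')

def lwf_loss(new_logits, old_logits, kl_weight=0.1):
    # KLDivLoss(log q, p) computes KL(p || q): the frozen model's
    # distribution p is the target the new model must stay close to
    return kl_weight * kl_div(F.log_softmax(new_logits, dim=1),
                              F.softmax(old_logits, dim=1))

# usage: a batch of 4 over the 1000 ImageNet classes
new_logits = torch.randn(4, 1000)           # current model, old-task head
old_logits = torch.randn(4, 1000)           # frozen pretrained model
loss_kl = lwf_loss(new_logits, old_logits)  # scalar, weighted by --lwf
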
/train.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | # [x] train resnet101 with proxy guidance on visda17
5 | # [x] evaluation on visda17
6 |
7 | python train.py \
8 | --epochs 30 \
9 | --batch-size 32 \
10 | --lr 1e-4 \
11 | --lwf 0.1 \
12 | # --resume pretrained/res101_vista17_best.pth.tar \
13 | # --evaluate
14 |
--------------------------------------------------------------------------------
/train_seg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import argparse
5 | import os
6 | import sys
7 | import logging
8 | import time
9 | import numpy as np
10 | import math
11 | from tqdm import tqdm
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.parallel
15 | import torch.optim
16 | from torch.nn import functional as F
17 | from torch.utils.tensorboard import SummaryWriter
18 | from data.gta5 import GTA5
19 | from data.cityscapes import Cityscapes
20 | from model.vgg import vgg16
21 | from model.fcn8s_vgg import FCN8sAtOnce as FCN_Vgg
22 | from dataloader_seg import get_train_loader
23 | from eval_seg import SegEvaluator
24 | from utils.utils import get_params, IterNums, save_checkpoint, AverageMeter, lr_poly, adjust_learning_rate
25 | from pdb import set_trace as bp
26 | torch.backends.cudnn.enabled = True
27 |
28 | CrossEntropyLoss = nn.CrossEntropyLoss(reduction='mean', ignore_index=255)
29 | KLDivLoss = nn.KLDivLoss(reduction='batchmean')
30 | best_mIoU = 0
31 |
32 | parser = argparse.ArgumentParser(description='PyTorch FCN-8s (VGG16) segmentation training')
33 | parser.add_argument('--epochs', default=300, type=int, help='number of total epochs to run')
34 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
35 | parser.add_argument('--batch-size', default=6, type=int, dest='batch_size', help='mini-batch size (default: 6)')
36 | parser.add_argument('--iter-size', default=1, type=int, dest='iter_size', help='iteration size (default: 1)')
37 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, help='initial learning rate')
38 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, help='weight decay (default: 5e-4)')
39 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
40 | parser.add_argument('--lwf', default=0., type=float, dest='lwf', help='weight of KL loss for LwF (default: 0)')
41 | parser.add_argument('--factor', default=0.1, type=float, dest='factor', help='scale factor of backbone learning rate (default: 0.1)')
42 | parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint (default: none)')
43 | parser.add_argument('--name', default='Vgg16_GTA5', type=str, help='name of experiment')
44 | parser.add_argument('--tensorboard', help='Log progress to TensorBoard', action='store_true')
45 | parser.add_argument('--num-class', default=19, type=int, dest='num_class', help='the number of classes')
46 | parser.add_argument('--gpus', default=0, type=int, help='use gpu with cuda number')
47 | parser.add_argument('--evaluate', action='store_true', help='evaluate the model on the validation set instead of training (default: False)')
48 | parser.set_defaults(bottleneck=True)
49 | parser.set_defaults(augment=True)
50 |
51 |
52 | def main():
53 | global args, best_mIoU
54 | args = parser.parse_args()
55 | pid = os.getpid()
56 |
57 | # Log outputs
58 | args.name = "GTA5_Vgg16_batch%d_512x512_Poly_LR%.1e_1to%.1f_all_lwf.%d_epoch%d"%(args.batch_size, args.lr, args.factor, args.lwf, args.epochs)
59 | if args.resume:
60 | args.name += "_resumed"
61 | directory = "runs/%s/"%(args.name)
62 | if not os.path.exists(directory):
63 | os.makedirs(directory)
64 | filename = directory + 'train.log'
65 | for handler in logging.root.handlers[:]:
66 | logging.root.removeHandler(handler)
67 | rootLogger = logging.getLogger()
68 | logFormatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
69 | fileHandler = logging.FileHandler(filename)
70 | fileHandler.setFormatter(logFormatter)
71 | rootLogger.addHandler(fileHandler)
72 |
73 | consoleHandler = logging.StreamHandler()
74 | consoleHandler.setFormatter(logFormatter)
75 | rootLogger.addHandler(consoleHandler)
76 | rootLogger.setLevel(logging.INFO)
77 |
78 | writer = SummaryWriter(directory)
79 |
80 | from config_seg import config as data_setting
81 | data_setting.batch_size = args.batch_size
82 | train_loader = get_train_loader(data_setting, GTA5, test=False)
83 |
84 | ##### Vgg16 #####
85 | vgg = vgg16(pretrained=True)
86 | model = FCN_Vgg(n_class=args.num_class)
87 | model.copy_params_from_vgg16(vgg)
88 | ###################
89 | threds = 1
90 | evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None), args.num_class, np.array([0.485, 0.456, 0.406]),
91 | np.array([0.229, 0.224, 0.225]), model, [1, ], False, devices=args.gpus, config=data_setting, threds=threds,
92 | verbose=False, save_path=None, show_image=False)
93 |
94 | # Setup optimizer
95 | ##### Vgg16 #####
96 | sgd_in = [
97 | {'params': get_params(model, ["conv1_1", "conv1_2"]), 'lr': args.factor*args.lr},
98 | {'params': get_params(model, ["conv2_1", "conv2_2"]), 'lr': args.factor*args.lr},
99 | {'params': get_params(model, ["conv3_1", "conv3_2", "conv3_3"]), 'lr': args.factor*args.lr},
100 | {'params': get_params(model, ["conv4_1", "conv4_2", "conv4_3"]), 'lr': args.factor*args.lr},
101 | {'params': get_params(model, ["conv5_1", "conv5_2", "conv5_3"]), 'lr': args.factor*args.lr},
102 | {'params': get_params(model, ["fc6", "fc7"]), 'lr': args.factor*args.lr},
103 | {'params': get_params(model, ["score_fr", "score_pool3", "score_pool4", "upscore2", "upscore8", "upscore_pool4"]), 'lr': args.lr},
104 | ]
105 | base_lrs = [ group['lr'] for group in sgd_in ]
106 | optimizer = torch.optim.SGD(sgd_in, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
107 |
108 | # Optionally resume from a checkpoint
109 | if args.resume:
110 | if os.path.isfile(args.resume):
111 | print("=> loading checkpoint '{}'".format(args.resume))
112 | checkpoint = torch.load(args.resume)
113 | args.start_epoch = checkpoint['epoch']
114 | best_mIoU = checkpoint['best_mIoU']
115 | model.load_state_dict(checkpoint['state_dict'])
116 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
117 | else:
118 | print("=ImageClassdata> no checkpoint found at '{}'".format(args.resume))
119 |
120 | model = model.cuda()
121 | model_old = None
122 | if args.lwf > 0:
123 | # create a fixed model copy for Life-long learning
124 | model_old = vgg16(pretrained=True)
125 | ###################
126 | for param in model_old.parameters():
127 | param.requires_grad = False
128 | model_old.eval()
129 | model_old.cuda()
130 |
131 | if args.evaluate:
132 | mIoU = validate(evaluator, model)
133 | print(mIoU)
134 | exit(0)
135 | # Main training loop
136 | iter_max = args.epochs * math.ceil(len(train_loader)/args.iter_size)
137 | iter_stat = IterNums(iter_max)
138 | for epoch in range(args.start_epoch, args.epochs):
139 | logging.info("============= " + args.name + " ================")
140 | logging.info("============= PID: " + str(pid) + " ================")
141 | logging.info("Epoch: %d"%(epoch+1))
142 | # train for one epoch
143 | train(args, train_loader, model, optimizer, base_lrs, iter_stat, epoch, writer, model_old=model_old, adjust_lr=epoch < args.epochs)
144 |
145 | # evaluate on validation set
146 | mIoU = validate(evaluator, model)
147 | writer.add_scalar("mIoU", mIoU, epoch)
148 | # remember best mIoU and save checkpoint
149 | is_best = mIoU > best_mIoU
150 | best_mIoU = max(mIoU, best_mIoU)
151 | save_checkpoint(directory, {
152 | 'epoch': epoch + 1,
153 | 'state_dict': model.state_dict(),
154 | 'best_mIoU': best_mIoU,
155 | }, is_best)
156 |
157 | logging.info('Best accuracy: {mIoU:.3f}'.format(mIoU=best_mIoU))
158 |
159 |
160 | def train(args, train_loader, model, optimizer, base_lrs, iter_stat, epoch, writer, model_old=None, adjust_lr=True):
161 | """Train for one epoch on the training set"""
162 | losses = AverageMeter()
163 | losses_kl = AverageMeter()
164 |
165 | model.eval()  # eval mode: dropout disabled, any BatchNorm uses running stats; gradients still flow
166 |
167 | # train for one epoch
168 | optimizer.zero_grad()
169 | epoch_size = len(train_loader)
170 | train_loader_iter = iter(train_loader)
171 |
172 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
173 | pbar = tqdm(range(epoch_size), file=sys.stdout, bar_format=bar_format, ncols=80)
174 |
175 | for idx_iter in pbar:
176 | loss_print = 0
177 | loss_kl_print = 0
178 | avg_size = 0
179 |
180 | optimizer.zero_grad()
181 | if adjust_lr:
182 | lr = lr_poly(base_lrs[-1], iter_stat.iter_curr, iter_stat.iter_max, 0.9)
183 | writer.add_scalar("lr", lr, idx_iter + epoch * epoch_size)
184 | adjust_learning_rate(base_lrs, optimizer, iter_stat.iter_curr, iter_stat.iter_max, 0.9)
185 |
186 | sample = next(train_loader_iter)
187 | label = sample['label'].cuda()
188 | input = sample['data'].cuda()
189 |
190 | # compute output
191 | output, features_new = model(input, output_features=['layer4'], task='new_seg')
192 |
193 | # compute gradient
194 | loss = CrossEntropyLoss(output, label.long())
195 | loss_print += loss
196 |
197 | # LWF KL div
198 | if model_old is None:
199 | loss_kl = 0
200 | else:
201 | output_new = model_old.forward_fc(features_new['layer4'], task='old')
202 | output_old, features_old = model_old(input, output_features=[], task='old')
203 | loss_kl = KLDivLoss(F.log_softmax(output_new, dim=1), F.softmax(output_old, dim=1)).sum(-1)
204 | loss_kl_print += loss_kl
205 |
206 | (loss + args.lwf * loss_kl).backward()
207 |
208 | # update size
209 | avg_size += input.size(0)
210 |
211 | # measure accuracy and record loss
212 | losses.update(float(loss_print), avg_size)  # cast to float so the meter does not retain the autograd graph
213 | losses_kl.update(float(loss_kl_print), avg_size)
214 |
215 | # compute gradient and do SGD step
216 | optimizer.step()
217 | # increment iter number
218 | iter_stat.update()
219 |
220 | writer.add_scalar("loss/ce", losses.val, idx_iter + epoch * epoch_size)
221 | writer.add_scalar("loss/kl", losses_kl.val, idx_iter + epoch * epoch_size)
222 | writer.add_scalar("loss/total", losses.val + losses_kl.val, idx_iter + epoch * epoch_size)
223 | description = "[loss: %.3f][loss_kl: %.3f]"%(losses.val, losses_kl.val)
224 | pbar.set_description("[Step %d/%d]"%(idx_iter + 1, epoch_size) + description)
225 |
226 |
227 | def validate(evaluator, model):
228 | with torch.no_grad():
229 | model.eval()
230 | # _, mIoU = evaluator.run_online()
231 | _, mIoU = evaluator.run_online_multiprocess()
232 | return mIoU
233 |
234 |
235 | if __name__ == '__main__':
236 | main()
237 |
--------------------------------------------------------------------------------
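
The optimizer setup in train_seg.py gives the ImageNet-pretrained VGG blocks a learning rate scaled down by --factor, while the randomly initialized score/upsampling layers train at the full --lr; base_lrs preserves the per-group values for the poly-decay schedule. A minimal sketch of that two-speed grouping, with hypothetical backbone/head modules standing in for the real layers:

import torch
import torch.nn as nn

backbone = nn.Linear(8, 8)  # stands in for the pretrained conv1_1 ... fc7 blocks
head = nn.Linear(8, 19)     # stands in for the new score_fr / upscore layers

lr, factor = 1e-3, 0.1
sgd_in = [
    {'params': backbone.parameters(), 'lr': factor * lr},  # pretrained: 10x smaller LR
    {'params': head.parameters(), 'lr': lr},               # new layers: full LR
]
base_lrs = [group['lr'] for group in sgd_in]  # consumed by adjust_learning_rate
optimizer = torch.optim.SGD(sgd_in, lr=lr, momentum=0.9, weight_decay=5e-4)
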
/train_seg.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | # [x] train vgg16 with proxy guidance on GTA5
5 | # [x] evaluation on Cityscapes
6 |
7 | python train_seg.py \
8 | --epochs 50 \
9 | --batch-size 6 \
10 | --lr 1e-3 \
11 | --num-class 19 \
12 | --gpus 0 \
13 | --factor 0.1 \
14 | --lwf 75. \
15 | # --evaluate \
16 | # --resume ./pretrained/vgg16_segmentation_best.pth.tar
17 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | from pathlib import Path
5 | import importlib, warnings
6 | import os, sys, time, numpy as np
7 | import torch, random, PIL, copy
8 | from os import path as osp
9 | from shutil import copyfile
10 | if sys.version_info.major == 2: # Python 2.x
11 | from StringIO import StringIO as BIO
12 | else: # Python 3.x
13 | from io import BytesIO as BIO
14 | from torch.utils.tensorboard import SummaryWriter
15 | if importlib.util.find_spec('tensorflow'):
16 | import tensorflow as tf
17 |
18 |
19 | def prepare_seed(rand_seed):
20 | random.seed(rand_seed)
21 | np.random.seed(rand_seed)
22 | torch.manual_seed(rand_seed)
23 | torch.cuda.manual_seed(rand_seed)
24 | torch.cuda.manual_seed_all(rand_seed)
25 |
26 |
27 | def prepare_logger(xargs):
28 | args = copy.deepcopy( xargs )
29 | logger = Logger(args.save_dir, args.rand_seed)
30 | logger.log('Main Function with logger : {:}'.format(logger))
31 | logger.log('Arguments : -------------------------------')
32 | for name, value in args._get_kwargs():
33 | logger.log('{:16} : {:}'.format(name, value))
34 | logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
35 | logger.log("Pillow Version : {:}".format(PIL.__version__))
36 | logger.log("PyTorch Version : {:}".format(torch.__version__))
37 | logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
38 | logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
39 | logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
40 | logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
41 | return logger
42 |
43 |
44 | class PrintLogger(object):
45 |
46 | def __init__(self):
47 | """Create a summary writer logging to log_dir."""
48 | self.name = 'PrintLogger'
49 |
50 | def log(self, string):
51 | print (string)
52 |
53 | def close(self):
54 | print ('-'*30 + ' close printer ' + '-'*30)
55 |
56 |
57 | class Logger(object):
58 |
59 | def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
60 | """Create a summary writer logging to log_dir."""
61 | self.seed = int(seed)
62 | self.log_dir = Path(log_dir)
63 | self.model_dir = Path(log_dir) / 'model'
64 | self.log_dir.mkdir (parents=True, exist_ok=True)
65 | # if create_model_dir:
66 | # self.model_dir.mkdir(parents=True, exist_ok=True)
67 | #self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
68 |
69 | self.use_tf = bool(use_tf)
70 | self.tensorboard_dir = self.log_dir
71 | #self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) )))
72 | # self.logger_path = self.log_dir / 'seed-{:}-T-{:}.log'.format(self.seed, time.strftime('%d-%h-at-%H-%M-%S', time.gmtime(time.time())))
73 | self.logger_path = self.log_dir / 'seed-{:}.log'.format(self.seed)
74 | self.logger_file = open(self.logger_path, 'w')
75 |
76 | self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
77 | self.writer = SummaryWriter(str(self.tensorboard_dir))
78 |
79 | def __repr__(self):
80 | return ('{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__))
81 |
82 | def path(self, mode):
83 | valids = ('model', 'best', 'info', 'log')
84 | if mode == 'model': return self.model_dir / 'seed-{:}-basic.pth'.format(self.seed)
85 | elif mode == 'best' : return self.model_dir / 'seed-{:}-best.pth'.format(self.seed)
86 | elif mode == 'info' : return self.log_dir / 'seed-{:}-last-info.pth'.format(self.seed)
87 | elif mode == 'log' : return self.log_dir
88 | else: raise TypeError('Unknown mode = {:}, valid modes = {:}'.format(mode, valids))
89 |
90 | def extract_log(self):
91 | return self.logger_file
92 |
93 | def close(self):
94 | self.logger_file.close()
95 | if self.writer is not None:
96 | self.writer.close()
97 |
98 | def log(self, string, save=True, stdout=False):
99 | if stdout:
100 | sys.stdout.write(string); sys.stdout.flush()
101 | else:
102 | print (string)
103 | if save:
104 | self.logger_file.write('{:}\n'.format(string))
105 | self.logger_file.flush()
106 |
107 | def scalar_summary(self, tags, values, step):
108 | """Log a scalar variable."""
109 | if not self.use_tf:
110 | warnings.warn('Logger was created with use_tf=False but scalar_summary was called; skipping')
111 | else:
112 | assert isinstance(tags, list) == isinstance(values, list), 'Type : {:} vs {:}'.format(type(tags), type(values))
113 | if not isinstance(tags, list):
114 | tags, values = [tags], [values]
115 | for tag, value in zip(tags, values):
116 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
117 | self.writer.add_summary(summary, step)
118 | self.writer.flush()
119 |
120 | def image_summary(self, tag, images, step):
121 | """Log a list of images."""
122 | import scipy.misc  # scipy.misc.toimage needs an older SciPy (<1.2) built with Pillow
123 | if not self.use_tf:
124 | warnings.warn('Logger was created with use_tf=False but image_summary was called; skipping')
125 | return
126 |
127 | img_summaries = []
128 | for i, img in enumerate(images):
129 | # Write the image to an in-memory buffer as PNG.
130 | # BIO aliases BytesIO on Python 3 and StringIO on Python 2
131 | # (see the version-conditional import at the top of this file);
132 | # scipy.misc.toimage converts the array to a PIL image for saving.
133 | s = BIO()
134 | scipy.misc.toimage(img).save(s, format="png")
135 |
136 | # Create an Image object
137 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
138 | height=img.shape[0],
139 | width=img.shape[1])
140 | # Create a Summary value
141 | img_summaries.append(tf.Summary.Value(tag='{}/{}'.format(tag, i), image=img_sum))
142 |
143 | # Create and write Summary
144 | summary = tf.Summary(value=img_summaries)
145 | self.writer.add_summary(summary, step)
146 | self.writer.flush()
147 |
148 | def histo_summary(self, tag, values, step, bins=1000):
149 | """Log a histogram of the tensor of values."""
150 | if not self.use_tf: raise ValueError('TensorFlow not available: Logger was created with use_tf=False')
151 | import tensorflow as tf
152 |
153 | # Create a histogram using numpy
154 | counts, bin_edges = np.histogram(values, bins=bins)
155 |
156 | # Fill the fields of the histogram proto
157 | hist = tf.HistogramProto()
158 | hist.min = float(np.min(values))
159 | hist.max = float(np.max(values))
160 | hist.num = int(np.prod(values.shape))
161 | hist.sum = float(np.sum(values))
162 | hist.sum_squares = float(np.sum(values**2))
163 |
164 | # Drop the start of the first bin
165 | bin_edges = bin_edges[1:]
166 |
167 | # Add bin edges and counts
168 | for edge in bin_edges:
169 | hist.bucket_limit.append(edge)
170 | for c in counts:
171 | hist.bucket.append(c)
172 |
173 | # Create and write Summary
174 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
175 | self.writer.add_summary(summary, step)
176 | self.writer.flush()
177 |
--------------------------------------------------------------------------------
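
A short usage sketch for the logger utilities above; the save directory and seed are illustrative:

from utils.logger import Logger, prepare_seed

prepare_seed(0)                       # seed python/numpy/torch (incl. CUDA)
logger = Logger('runs/demo', seed=0)  # creates runs/demo/ and seed-0.log
logger.log('starting run')            # prints and appends to the log file
logger.writer.add_scalar('loss', 0.5, 0)  # TensorBoard scalar via SummaryWriter
logger.close()
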
/utils/sgd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import torch
5 | from torch.optim.optimizer import Optimizer, required
6 |
7 |
8 | # fixed SGD
9 | # See Note here: https://pytorch.org/docs/stable/optim.html#torch.optim.SGD
10 | class SGD(Optimizer):
11 | r"""Implements stochastic gradient descent (optionally with momentum).
12 |
13 | Nesterov momentum is based on the formula from
14 | `On the importance of initialization and momentum in deep learning`__.
15 |
16 | Args:
17 | params (iterable): iterable of parameters to optimize or dicts defining
18 | parameter groups
19 | lr (float): learning rate
20 | momentum (float, optional): momentum factor (default: 0)
21 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
22 | dampening (float, optional): dampening for momentum (default: 0)
23 | nesterov (bool, optional): enables Nesterov momentum (default: False)
24 |
25 | Example:
26 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
27 | >>> optimizer.zero_grad()
28 | >>> loss_fn(model(input), target).backward()
29 | >>> optimizer.step()
30 |
31 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
32 |
33 | .. note::
34 | The implementation of SGD with Momentum/Nesterov subtly differs from
35 | Sutskever et. al. and implementations in some other frameworks.
36 |
37 | Considering the specific case of Momentum, the update can be written as
38 |
39 | .. math::
40 | v = \rho * v + g \\
41 | p = p - lr * v
42 |
43 | where p, g, v and :math:`\rho` denote the parameters, gradient,
44 | velocity, and momentum respectively.
45 |
46 | This is in contrast to Sutskever et. al. and
47 | other frameworks which employ an update of the form
48 |
49 | .. math::
50 | v = \rho * v + lr * g \\
51 | p = p - v
52 |
53 | The Nesterov version is analogously modified.
54 | """
55 |
56 | def __init__(self, params, lr=required, momentum=0, dampening=0,
57 | weight_decay=0, nesterov=False):
58 | if lr is not required and lr < 0.0:
59 | raise ValueError("Invalid learning rate: {}".format(lr))
60 | if momentum < 0.0:
61 | raise ValueError("Invalid momentum value: {}".format(momentum))
62 | if weight_decay < 0.0:
63 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
64 |
65 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
66 | weight_decay=weight_decay, nesterov=nesterov)
67 | if nesterov and (momentum <= 0 or dampening != 0):
68 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
69 | super(SGD, self).__init__(params, defaults)
70 |
71 | def __setstate__(self, state):
72 | super(SGD, self).__setstate__(state)
73 | for group in self.param_groups:
74 | group.setdefault('nesterov', False)
75 |
76 | def step(self, closure=None):
77 | """Performs a single optimization step.
78 |
79 | Arguments:
80 | closure (callable, optional): A closure that reevaluates the model
81 | and returns the loss.
82 | """
83 | loss = None
84 | if closure is not None:
85 | loss = closure()
86 |
87 | for group in self.param_groups:
88 | weight_decay = group['weight_decay']
89 | momentum = group['momentum']
90 | dampening = group['dampening']
91 | nesterov = group['nesterov']
92 |
93 | for p in group['params']:
94 | if p.grad is None:
95 | continue
96 | d_p = p.grad.data
97 | if weight_decay != 0:
98 | d_p.add_(weight_decay, p.data)
99 | if momentum != 0:
100 | param_state = self.state[p]
101 | if 'momentum_buffer' not in param_state:
102 | # buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
103 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach().mul_(group['lr'])
104 | else:
105 | buf = param_state['momentum_buffer']
106 | # buf.mul_(momentum).add_(1 - dampening, d_p)
107 | buf.mul_(momentum).add_(1 - dampening, d_p.mul_(group['lr']))
108 | if nesterov:
109 | d_p = d_p.add(momentum, buf)
110 | else:
111 | d_p = buf
112 |
113 | # p.data.add_(-group['lr'], d_p)
114 | p.data.add_(-1, d_p)
115 |
116 | return loss
117 |
--------------------------------------------------------------------------------
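
The change relative to stock torch.optim.SGD (per the note linked above) is that the learning rate multiplies the gradient before it enters the momentum buffer, and the final update applies the buffer directly. Under the poly decay used in this repo, lowering the LR therefore only shrinks new gradient contributions instead of retroactively rescaling accumulated momentum. A small numeric sketch, assuming this repo's utils.sgd is importable:

import torch
from utils.sgd import SGD  # the "fixed" SGD defined above

p = torch.nn.Parameter(torch.ones(1))
opt = SGD([p], lr=0.1, momentum=0.9)

p.grad = torch.ones(1)
opt.step()  # buffer = lr * g = 0.1, so p: 1.0 -> 0.9

for group in opt.param_groups:  # decay the LR, as adjust_learning_rate does
    group['lr'] = 0.01

p.grad = torch.ones(1)
opt.step()  # buffer = 0.9 * 0.1 + 0.01 * 1.0 = 0.1, so p: 0.9 -> 0.8
print(p.item())  # old momentum keeps its original lr-scaled magnitude
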
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 NVIDIA Corporation. All rights reserved.
2 | # This work is licensed under a NVIDIA Open Source Non-commercial license.
3 |
4 | import glob
5 | import os
6 | import shutil
7 | import numpy as np
8 | import torch
9 | from sklearn import metrics  # used by ROC.auc below
10 | from pdb import set_trace as bp
11 |
12 | def get_params(model, layers=["layer4"]):
13 | """
14 | This generator returns all the parameters of the net except for
15 | the last classification layer. Note that for each batchnorm layer,
16 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
17 | any batchnorm parameter
18 | """
19 | if isinstance(layers, str):
20 | layers = [layers]
21 | b = []
22 | for layer in layers:
23 | b.append(getattr(model, layer))
24 |
25 | for i in range(len(b)):
26 | for k, v in b[i].named_parameters():
27 | if v.requires_grad:
28 | yield v
29 |
30 |
31 | def adjust_learning_rate(base_lrs, optimizer, iter_curr, iter_max, power):
32 | """Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs"""
33 | num_groups = len(optimizer.param_groups)
34 | for g in range(num_groups):
35 | optimizer.param_groups[g]['lr'] = lr_poly(base_lrs[g], iter_curr, iter_max, power)
36 |
37 |
38 | def lr_poly(base_lr, iter, max_iter, power):
39 | # return min(0.01+0.99*(float(iter)/100)**2.0, 1.0) * base_lr * ((1-float(iter)/max_iter)**power) # This is with warm up
40 | return min(0.01+0.99*(float(iter)/100)**2.0, 1.0) * base_lr * ((1-min(float(iter)/max_iter, 0.8))**power) # This is with warm up & no smaller than last 20% LR
41 | # return base_lr * ((1-float(iter)/max_iter)**power)
42 |
43 |
44 | def save_checkpoint(name, state, is_best, filename='checkpoint.pth.tar', keep_last=1):
45 | """Saves checkpoint to disk"""
46 | directory = name
47 | if not os.path.exists(directory):
48 | os.makedirs(directory)
49 | models_paths = list(filter(os.path.isfile, glob.glob(directory + "/epoch*.pth.tar")))
50 | models_paths.sort(key=os.path.getmtime, reverse=False)
51 | if len(models_paths) >= keep_last:
52 | for i in range(len(models_paths) + 1 - keep_last):
53 | os.remove(models_paths[i])
54 | # filename = directory + '/epoch_'+str(state['epoch']) + '_' + filename
55 | filename = directory + '/latest_' + filename
56 | torch.save(state, filename)
57 | if is_best:
58 | shutil.copyfile(filename, '%s/'%(name) + 'model_best.pth.tar')
59 |
60 |
61 | class IterNums(object):
62 | def __init__(self, iter_max):
63 | self.iter_max = iter_max
64 | self.iter_curr = 0
65 |
66 | def reset(self):
67 | self.iter_curr = 0
68 |
69 | def update(self):
70 | self.iter_curr += 1
71 |
72 |
73 | class AverageMeter(object):
74 | """Computes and stores the average and current value"""
75 | def __init__(self):
76 | self.reset()
77 |
78 | def reset(self):
79 | self.val = 0
80 | self.avg = 0
81 | self.sum = 0
82 | self.count = 0
83 | self.vec2sca_avg = 0
84 | self.vec2sca_val = 0
85 |
86 | def update(self, val, n=1):
87 | self.val = val
88 | self.sum += val * n
89 | self.count += n
90 | self.avg = self.sum / self.count
91 | if torch.is_tensor(self.val) and torch.numel(self.val) != 1:
92 | self.avg[self.count == 0] = 0
93 | self.vec2sca_avg = self.avg.sum() / len(self.avg)
94 | self.vec2sca_val = self.val.sum() / len(self.val)
95 |
96 |
97 | class ROC(object):
98 | def __init__(self, num_class):
99 | self.num_class = num_class
100 | self.pred = []
101 | self.label = []
102 |
103 | def update(self, pred, label):
104 | assert (self.num_class == pred.shape[0]), "num_class mismatch on input predictions!"
105 | assert (self.num_class == label.shape[0]), "num_class mismatch on input labels!"
106 | self.pred.append(pred)
107 | self.label.append(label)
108 |
109 | def roc_curve(self):
110 | pred = np.hstack(self.pred)
111 | label = np.hstack(self.label)
112 | p = label == 1
113 | n = ~p
114 | num_p = np.sum(p, axis=1)
115 | num_n = np.sum(n, axis=1)
116 | tpr = np.zeros((self.num_class, 101), np.float32)
117 | fpr = np.zeros((self.num_class, 101), np.float32)
118 | for idx in range(101):
119 | thre = 1 - idx/100.0
120 | pp = pred > thre
121 | tp = pp & p
122 | fp = pp & n
123 | num_tp = np.sum(tp, axis=1)
124 | num_fp = np.sum(fp, axis=1)
125 | tpr[:, idx] = num_tp/(num_p + (num_p == 0))
126 | fpr[:, idx] = num_fp/(num_n + (num_n == 0))
127 | return tpr, fpr
128 |
129 | def auc(self, tpr, fpr):
130 | assert(tpr.shape[0] == fpr.shape[0])
131 | auc = np.zeros(tpr.shape[0], np.float32)
132 | for idx in range(tpr.shape[0]):
133 | auc[idx] = metrics.auc(fpr[idx, :], tpr[idx, :])
134 | return auc
135 |
136 |
137 | def accuracy(output, label, num_class, topk=(1,)):
138 | """Computes the precision@k for the specified values of k, currently only k=1 is supported"""
139 | maxk = max(topk)
140 |
141 | _, pred = output.topk(maxk, 1, True, True)
142 | if len(label.size()) == 2:
143 | # one_hot label
144 | _, gt = label.topk(maxk, 1, True, True)
145 | else:
146 | gt = label
147 | pred = pred.t()
148 | pred_class_idx_list = [pred == class_idx for class_idx in range(num_class)]
149 | gt = gt.t()
150 | gt_class_number_list = [(gt == class_idx).sum() for class_idx in range(num_class)]
151 | correct = pred.eq(gt)
152 |
153 | res = []
154 | gt_num = []
155 | for k in topk:
156 | correct_k = correct[:k].float()
157 | per_class_correct_list = [correct_k[pred_class_idx].sum(0) for pred_class_idx in pred_class_idx_list]
158 | per_class_correct_array = torch.tensor(per_class_correct_list)
159 | gt_class_number_tensor = torch.tensor(gt_class_number_list).float()
160 | gt_class_zeronumber_tensor = gt_class_number_tensor == 0
161 | gt_class_number_matrix = torch.tensor(gt_class_number_list).float()
162 | gt_class_acc = per_class_correct_array.mul_(100.0 / gt_class_number_matrix)
163 | gt_class_acc[gt_class_zeronumber_tensor] = 0
164 | res.append(gt_class_acc)
165 | gt_num.append(gt_class_number_matrix)
166 | return res, gt_num
167 |
--------------------------------------------------------------------------------
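
lr_poly above combines a quadratic warm-up over the first 100 iterations with poly(power) decay that is clamped once iter reaches 80% of max_iter, so the LR never drops below its 80%-point value. A quick numeric check of the schedule (the base LR and iteration counts are illustrative):

def lr_poly(base_lr, iter, max_iter, power):
    # warm-up ramp: 1% of base_lr at iter 0, reaching base_lr at iter 100
    warmup = min(0.01 + 0.99 * (float(iter) / 100) ** 2.0, 1.0)
    # poly decay, frozen once iter/max_iter reaches 0.8
    decay = (1 - min(float(iter) / max_iter, 0.8)) ** power
    return warmup * base_lr * decay

for it in [0, 50, 100, 5000, 8000, 10000]:
    print(it, lr_poly(1e-3, it, 10000, 0.9))
# 0 -> 1.0e-05, 100 -> ~9.9e-04, 5000 -> ~5.4e-04,
# 8000 and 10000 -> ~2.3e-04 (clamped at the 80% point)
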