├── .gitignore
├── .idea
│   ├── DTQ.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   ├── webServers.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── data_split.py
│   └── dataloader.py
├── framework.png
├── framework.py
├── hocon_config
│   ├── mobilenetv2_caltech-256-30.hocon
│   └── resnet50_caltech-256-30.hocon
├── main.py
├── models
│   ├── __init__.py
│   ├── get_model.py
│   ├── loss_function.py
│   └── regularizer.py
├── option.py
├── quantization
│   ├── google_quantization.py
│   ├── qmobilenet.py
│   └── qresnet.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── checkpoint.py
    └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /quantized_transfer
2 |
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 |
--------------------------------------------------------------------------------
/.idea/DTQ.iml:
--------------------------------------------------------------------------------
(IntelliJ IDEA module file; its XML content was not captured in this dump.)
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
(IntelliJ IDEA deployment settings; XML content not captured in this dump.)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
(IntelliJ IDEA inspection profile; XML content not captured in this dump.)
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(IntelliJ IDEA project-settings file; XML content not captured in this dump.)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(IntelliJ IDEA project-settings file; XML content not captured in this dump.)
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(IntelliJ IDEA VCS settings; XML content not captured in this dump.)
--------------------------------------------------------------------------------
/.idea/webServers.xml:
--------------------------------------------------------------------------------
(IntelliJ IDEA web-server settings; XML content not captured in this dump.)
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
(IntelliJ IDEA workspace file; XML content not captured in this dump. Only a few stray strings survive: BackgroundGenerator, get_AuxClassifier_list, dorefa, save_checkpoint, save_model, bits_activations, reg_channel_att_fea_map_learn, and the timestamp 1595335755982.)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) xiezheng 2020,
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Transferring Quantization
2 |
3 | We provide a PyTorch implementation of ["Deep Transferring Quantization"](https://tanmingkui.github.io/files/publications/Deep_Transferring.pdf) (ECCV 2020).
4 |
5 |
6 |
7 |
8 |
9 | ## Paper
10 | * Deep Transferring Quantization
11 | * Zheng Xie *, Zhiquan Wen *, Jing Liu *, Zhiqiang Liu, Xixian Wu, and Mingkui Tan *
12 | * European Conference on Computer Vision (ECCV), 2020
13 |
14 |
15 |
16 |
17 | ## Dependencies
18 |
19 | * Python 3.6
20 | * PyTorch 1.1.0
21 | * dependencies in requirements.txt
22 |
23 |
24 | ## Getting Started
25 |
26 | ### Installation
27 |
28 | 1. Clone this repository:
29 |
30 | git clone https://github.com/xiezheng-cs/DTQ.git
31 | cd DTQ
32 |
33 | 2. Install PyTorch and other dependencies:
34 |
35 | pip install -r requirements.txt
36 |
37 |
38 | ### Training
39 |
40 | To quantize the pre-trained MobileNetV2 on Caltech 256-30 to 4-bit:
41 |
42 | python main.py hocon_config/mobilenetv2_caltech-256-30.hocon
43 |
44 | To quantize the pre-trained ResNet-50 on Caltech 256-30 to 4-bit:
45 |
46 | python main.py hocon_config/resnet50_caltech-256-30.hocon
47 |
48 |
49 |
50 |
51 | ## Experimental Results
52 |
53 | | Target Data Set | Model | W / A | DELTA-Q Top1 Acc(%) | DTQ(Ours) Top1 Acc(%) |
54 | | :-: | :-: | :-: | :-: | :-: |
55 | | Caltech 256-30 | MobileNetV2 | 4 / 4 | 74.0±0.7 | **75.9±1.2** |
56 | | Caltech 256-30 | ResNet-50 | 4 / 4 | 82.8±0.5 | **83.5±0.6** |
57 |
58 |
59 |
60 |
61 | ## Citation
62 | If this work is useful for your research, please cite our paper:
63 |
64 | @InProceedings{xie2020deep,
65 | title = {Deep Transferring Quantization},
66 | author = {Zheng, Xie and Zhiquan, Wen and Jing, Liu and Zhiqiang, Liu and Xixian, Wu and Mingkui, Tan},
67 | booktitle = {European Conference on Computer Vision (ECCV)},
68 | year = {2020}
69 | }
70 |
71 |
72 |
73 |
74 | ## Acknowledgments
75 | This work was partially supported by the Key-Area Research and Development Program of Guangdong Province 2019B010155002, National Natural Science Foundation of China (NSFC) 61836003 (key project), Program for Guangdong Introducing Innovative and Entrepreneurial Teams 2017ZT07X183, Fundamental Research Funds for the Central Universities D2191240.
76 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/9/23 19:42
4 | # @Author : xiezheng
5 | # @Site :
6 | # @File : __init__.py
7 |
8 |
--------------------------------------------------------------------------------
/data/data_split.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/10/13 16:57
4 | # @Author : xiezheng
5 | # @Site :
6 | # @File : data_split.py
7 |
8 |
9 | import os
10 | from scipy.io import loadmat
11 | import shutil
12 |
13 |
14 | def get_path_str(line):
15 | line = str(line)
16 | _, path, _ = line.split('\'')
17 | # print('line={}, path={}'.format(line, path))
18 | return path
19 |
20 |
21 | def path_replace(line):
22 | return line.replace('/', '\\')
23 |
24 |
25 | def copy_img(root, list, save_path):
26 | for i in range(list.shape[0]):
27 | print('i={}'.format(i))
28 | path = get_path_str(list[i][0])
29 | source_img_path = path_replace(os.path.join(root, 'Images', path))
30 |
31 | dir_, name = path.split('/')
32 | target_img_dir = path_replace(os.path.join(save_path, dir_))
33 | if not os.path.exists(target_img_dir):
34 | os.makedirs(target_img_dir)
35 |
36 | target_img_path = path_replace(os.path.join(target_img_dir, name))
37 | print('source_img_path={}, target_img_path={}'.format(source_img_path, target_img_path))
38 | shutil.copy(source_img_path, target_img_path)
39 |
40 |
41 | if __name__ == '__main__':
42 | print()
43 | root = r'\Stanford Dogs 120'
44 |
45 | train_list = loadmat(os.path.join(root, 'train_list.mat'))['file_list']
46 | save_train_path = r'\Stanford Dogs 120\train'
47 | copy_img(root, train_list, save_train_path)
48 |
49 |
50 | # test_list = loadmat(os.path.join(root, 'test_list.mat'))['file_list']
51 | # save_test_path = '\Stanford Dogs 120\\test'
52 | # copy_img(root, test_list, save_test_path)
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/9/23 19:45
4 | # @Author : xiezheng
5 | # @Site :
6 | # @File : dataloader.py
7 |
8 |
9 | import os
10 | import sys
11 | import numpy as np
12 |
13 | import torch
14 | import torchvision
15 | from torchvision import transforms
16 | from PIL import Image
17 | from torch.utils.data import Dataset, DataLoader
18 | from scipy.io import loadmat
19 | import torchvision.datasets as datasets
20 | from PIL import ImageFile
21 |
22 | ImageFile.LOAD_TRUNCATED_IMAGES = True
23 | # torchvision.set_image_backend('accimage')
24 |
25 |
26 | def get_target_dataloader(dataset, batch_size, n_threads, data_path='', image_size=224,
27 | data_aug='default', logger=None):
28 | """
29 | Get dataloader for target_dataset
30 | :param dataset: the name of the dataset
31 | :param batch_size: how many samples per batch to load
32 | :param n_threads: how many subprocesses to use for data loading.
33 | :param data_path: the path of dataset
34 | :param logger: logger for logging
35 | """
36 |
37 | logger.info("|===>Get dataloader for " + dataset)
38 |
39 | # setting
40 | crop_size = {299: 320, 224: 256}
41 | resize = crop_size[image_size]
42 | logger.info("image_size={}, resize={}".format(image_size, resize))
43 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
44 |
45 | if data_aug == 'default':
46 | logger.info('data_aug = {} !!!'.format(data_aug))
47 | train_transform = transforms.Compose([
48 | transforms.Resize((resize, resize)),
49 | transforms.RandomHorizontalFlip(),
50 | transforms.RandomCrop(image_size),
51 | transforms.ToTensor(),
52 | normalize])
53 | val_transform = transforms.Compose([
54 | transforms.Resize((resize, resize)),
55 | transforms.CenterCrop(image_size),
56 | transforms.ToTensor(),
57 | normalize])
58 | else:
59 | assert False, logger.info("invalid data_aug={}".format(data_aug))
60 |
61 | # data root
62 | if dataset in ['MIT_Indoors_67', 'Stanford_Dogs', 'Caltech_256-10', 'Caltech_256-20',
63 | 'Caltech_256-30', 'Caltech_256-40', 'Caltech_256-60', 'CUB-200-2011', 'Food-101']:
64 | data_root = os.path.join(data_path, dataset)
65 | else:
66 | assert False, logger.info("invalid dataset={}".format(dataset))
67 | logger.info('{} path = {}'.format(dataset, data_root))
68 |
69 | # dataset
70 | train_dataset = datasets.ImageFolder(root=os.path.join(data_root, 'train'), transform=train_transform)
71 | val_dataset = datasets.ImageFolder(root=os.path.join(data_root, 'test'), transform=val_transform)
72 | class_num = len(train_dataset.classes)
73 | train_dataset_sizes = len(train_dataset)
74 | val_dataset_sizes = len(val_dataset)
75 |
76 | # dataloader
77 | train_loader = DataLoader(dataset=train_dataset,
78 | batch_size=batch_size,
79 | shuffle=True,
80 | pin_memory=True,
81 | num_workers=n_threads)
82 |
83 | val_loader = DataLoader(dataset=val_dataset,
84 | batch_size=batch_size,
85 | shuffle=False,
86 | pin_memory=True,
87 | num_workers=n_threads)
88 |
89 | logger.info("train and val loader are ready! class_num={}".format(class_num))
90 | logger.info("train_dataset_sizes={}, val_dataset_sizes={}".format(train_dataset_sizes, val_dataset_sizes))
91 | return train_loader, val_loader, class_num, train_dataset_sizes
92 |
93 |
94 |
95 |
96 | if __name__ == '__main__':
97 | print()
98 |
--------------------------------------------------------------------------------
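Note on usage: a minimal sketch of calling the loader defined above in dataloader.py. The dataset root below is a placeholder, and a plain `logging` logger stands in for the project's `get_logger` utility from utils.util; the directory is expected to contain <dataset>/train and <dataset>/test in torchvision ImageFolder layout.

    import logging
    from data.dataloader import get_target_dataloader

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("dtq")

    train_loader, val_loader, class_num, train_size = get_target_dataloader(
        dataset="Caltech_256-30", batch_size=64, n_threads=4,
        data_path="/path/to/Fine-Grained_Recognition",   # placeholder root
        image_size=224, data_aug="default", logger=logger)
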
/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiezheng-cs/DTQ/e633b197c1b3343c00d5741229ca97f0eeaa793f/framework.png
--------------------------------------------------------------------------------
/framework.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.util import AverageMeter
3 | from utils.util import get_learning_rate, accuracy
4 | from models.regularizer import reg_channel_att_fea_map_learn
5 | from models.loss_function import loss_kl
6 |
7 |
8 | class TransferFramework:
9 | def __init__(self, args, train_loader, val_loader, target_class_num, data_aug, base_model_name,
10 | model_source, model_feature, model_source_classifier, model_target_classifier, feature_criterions,
11 | loss_fn, num_epochs, optimizer, lr_scheduler, writer, logger, print_freq=10):
12 |
13 | self.setting = args
14 | self.train_loader = train_loader
15 | self.val_loader = val_loader
16 | self.target_class_num = target_class_num
17 | self.data_aug = data_aug
18 | self.reg_type = args.reg_type
19 | self.feature_criterions = feature_criterions
20 |
21 | self.base_model_name = base_model_name
22 | self.model_source = model_source
23 |
24 | # target model
25 | self.model_feature = model_feature
26 | self.model_source_classifier = model_source_classifier
27 | self.model_target_classifier = model_target_classifier
28 |
29 | # self.criterion_mse = nn.MSELoss().cuda()
30 | self.loss_fn = loss_fn
31 | self.num_epochs = num_epochs
32 | self.optimizer = optimizer
33 |
34 | self.lambada = args.lambada
35 | self.theta = args.theta
36 |
37 | self.lr = 0.0
38 | self.lr_scheduler = lr_scheduler
39 | self.writer = writer
40 | self.logger = logger
41 | self.print_freq = print_freq
42 |
43 | # framework init
44 | self.hook_layers = []
45 | if len(self.setting.gpu_id) <= 1:
46 | self.layer_outputs_source = []
47 | self.layer_outputs_target = []
48 | else:
49 | self.layer_outputs_source = {}
50 | self.layer_outputs_target = {}
51 | self.logger.info("hook output save to type: {}".format(type(self.layer_outputs_source)))
52 | self.framework_init()
53 |
54 |
55 | def framework_init(self):
56 | self.hook_setting()
57 |
58 | # hook
59 | def _for_hook_source(self, module, input, output):
60 | if len(self.setting.gpu_id) > 1:
61 | gpu_id = str(output.get_device())
62 | if gpu_id not in self.layer_outputs_source:
63 | self.layer_outputs_source[gpu_id] = []
64 | self.layer_outputs_source[gpu_id].append(output)
65 | else:
66 | self.layer_outputs_source.append(output)
67 |
68 | def _for_hook_target(self, module, input, output):
69 | if len(self.setting.gpu_id) > 1:
70 | gpu_id = str(output.get_device())
71 | if gpu_id not in self.layer_outputs_target:
72 | self.layer_outputs_target[gpu_id] = []
73 | self.layer_outputs_target[gpu_id].append(output)
74 | else:
75 | self.layer_outputs_target.append(output)
76 |
77 | def register_hook(self, model, func):
78 | for name, layer in model.named_modules():
79 | if name in self.hook_layers:
80 | layer.register_forward_hook(func)
81 |
82 |
83 | def get_hook_layers(self):
84 | if self.setting.base_model_name in ['resnet50']:
85 | if len(self.setting.gpu_id) > 1:
86 | self.hook_layers = ['module.layer1.2.conv3', 'module.layer2.3.conv3', 'module.layer3.5.conv3', 'module.layer4.2.conv3']
87 | else:
88 | self.hook_layers = ['layer1.2.conv3', 'layer2.3.conv3', 'layer3.5.conv3', 'layer4.2.conv3']
89 |
90 | elif self.base_model_name == 'mobilenet_v2':
91 | if len(self.setting.gpu_id) > 1:
92 | self.hook_layers = ['module.features.5.conv.3', 'module.features.9.conv.3', 'module.features.13.conv.3', 'module.features.17.conv.3']
93 | else:
94 | self.hook_layers = ['features.5.conv.3', 'features.9.conv.3', 'features.13.conv.3', 'features.17.conv.3']
95 |
96 | else:
97 | assert False, self.logger.info("invalid base_model_name={}".format(self.base_model_name))
98 |
99 |
100 | def hook_setting(self):
101 | # hook
102 | self.get_hook_layers()
103 | self.register_hook(self.model_source, self._for_hook_source)
104 | self.register_hook(self.model_feature, self._for_hook_target)
105 | self.logger.info("self.hook_layers={}".format(self.hook_layers))
106 |
107 |
108 | def train(self, epoch):
109 | # train mode
110 | # target model
111 | self.model_feature.train()
112 | self.model_target_classifier.train()
113 | self.model_source_classifier.eval()
114 |
115 | # source model
116 | self.model_source.eval()
117 |
118 | clc_losses = AverageMeter()
119 | kl_losses = AverageMeter()
120 | fea_losses = AverageMeter()
121 |
122 | total_losses = AverageMeter()
123 | train_top1_accs = AverageMeter()
124 |
125 | self.lr_scheduler.step(epoch)
126 | self.lr = get_learning_rate(self.optimizer)
127 | self.logger.info('self.optimizer={}'.format(self.optimizer))
128 |
129 | self.logger.info('kl_loss weight lambada={}'.format(self.lambada))
130 | self.logger.info('fea_loss weight theta={}'.format(self.theta))
131 | self.logger.info('T={}'.format(self.setting.T))
132 | self.logger.info("reg_type: {}".format(self.reg_type))
133 |
134 | for i, (imgs, labels) in enumerate(self.train_loader):
135 |
136 | if torch.cuda.is_available():
137 | imgs = imgs.cuda()
138 | labels = labels.cuda()
139 |
140 | # target forward and loss
141 | target_outputs = self.model_feature(imgs)
142 |
143 | target_model_source_classifier_outputs = self.model_source_classifier(target_outputs)
144 | target_model_target_classifier_outputs = self.model_target_classifier(target_outputs)
145 |
146 | # source_model forward for hook
147 | with torch.no_grad():
148 | source_outputs = self.model_source(imgs)
149 |
150 | # loss
151 | clc_loss = self.loss_fn(target_model_target_classifier_outputs, labels)
152 | kl_loss = loss_kl(target_model_source_classifier_outputs, source_outputs, self.setting.T)
153 |
154 | if self.reg_type == 'channel_att_fea_map_learn':
155 | if self.theta == 0.0:
156 | fea_loss = 0.0
157 | else:
158 | fea_loss = reg_channel_att_fea_map_learn(self.layer_outputs_source, self.layer_outputs_target,
159 | self.feature_criterions, self.setting.bits_activations, self.logger)
160 | else:
161 | assert False, "Wrong reg type!!!"
162 |
163 | total_loss = clc_loss + self.lambada * kl_loss + self.theta * fea_loss
164 |
165 | self.optimizer.zero_grad()
166 | total_loss.backward()
167 | self.optimizer.step()
168 |
169 | self.layer_outputs_source.clear()
170 | self.layer_outputs_target.clear()
171 |
172 | clc_losses.update(clc_loss.item(), imgs.size(0))
173 | kl_losses.update(kl_loss.item(), imgs.size(0))
174 |
175 | if fea_loss == 0.0:
176 | fea_losses.update(fea_loss, imgs.size(0))
177 | else:
178 | fea_losses.update(fea_loss.item(), imgs.size(0))
179 | total_losses.update(total_loss.item(), imgs.size(0))
180 |
181 | # compute accuracy
182 | top1_accuracy = accuracy(target_model_target_classifier_outputs, labels, 1)
183 | train_top1_accs.update(top1_accuracy, imgs.size(0))
184 |
185 | if i % self.print_freq == 0:
186 | self.logger.info(
187 | 'Train Epoch: [{:d}/{:d}][{:d}/{:d}]\tlr={:.6f}\tclc_loss={:.4f}\t\tkl_loss={:.4f}'
188 | '\t\tfea_loss={:.4f}\t\ttotal_loss={:.4f}\ttop1_Accuracy={:.4f}'
189 | .format(epoch, self.num_epochs, i, len(self.train_loader), self.lr, clc_losses.avg,
190 | kl_losses.avg, fea_losses.avg, total_losses.avg, train_top1_accs.avg))
191 |
192 | # save tensorboard
193 | self.writer.add_scalar('lr', self.lr, epoch)
194 | self.writer.add_scalar('Train_classification_loss', clc_losses.avg, epoch)
195 | self.writer.add_scalar('Train_kl_loss', kl_losses.avg, epoch)
196 | self.writer.add_scalar('Train_fea_loss', fea_losses.avg, epoch)
197 | self.writer.add_scalar('Train_total_loss', total_losses.avg, epoch)
198 | self.writer.add_scalar('Train_top1_accuracy', train_top1_accs.avg, epoch)
199 |
200 | self.logger.info(
201 | '||==> Train Epoch: [{:d}/{:d}]\tTrain: lr={:.6f}\tclc_loss={:.4f}\t\tkl_loss={:.4f}'
202 | '\t\tfea_loss={:.4f}\ttotal_loss={:.4f}\ttop1_Accuracy={:.4f}'
203 | .format(epoch, self.num_epochs, self.lr, clc_losses.avg, kl_losses.avg,
204 | fea_losses.avg, total_losses.avg, train_top1_accs.avg))
205 |
206 | return clc_losses.avg, kl_losses.avg, fea_losses.avg, total_losses.avg, train_top1_accs.avg
207 |
208 |
209 | def val(self, epoch):
210 | # test mode
211 | self.model_feature.eval()
212 | self.model_target_classifier.eval()
213 |
214 | val_losses = AverageMeter()
215 | val_top1_accs = AverageMeter()
216 |
217 | # Batches
218 | for i, (imgs, labels) in enumerate(self.val_loader):
219 | # Move to GPU, if available
220 | if torch.cuda.is_available():
221 | imgs = imgs.cuda()
222 | labels = labels.cuda()
223 |
224 | if self.data_aug == 'improved':
225 | bs, ncrops, c, h, w = imgs.size()
226 | imgs = imgs.view(-1, c, h, w)
227 |
228 | # forward and loss
229 | with torch.no_grad():
230 | outputs = self.model_feature(imgs)
231 | outputs = self.model_target_classifier(outputs)
232 |
233 | if self.data_aug == 'improved':
234 | outputs = outputs.view(bs, ncrops, -1).mean(1)
235 |
236 | val_loss = self.loss_fn(outputs, labels)
237 |
238 | val_losses.update(val_loss.item(), imgs.size(0))
239 | # compute accuracy
240 | top1_accuracy = accuracy(outputs, labels, 1)
241 | val_top1_accs.update(top1_accuracy, imgs.size(0))
242 |
243 | # batch update
244 | self.layer_outputs_source.clear()
245 | self.layer_outputs_target.clear()
246 |
247 | # Print status
248 | if i % self.print_freq == 0:
249 | self.logger.info('Val Epoch: [{:d}/{:d}][{:d}/{:d}]\tval_loss={:.4f}\t\ttop1_accuracy={:.4f}\t'
250 | .format(epoch, self.num_epochs, i, len(self.val_loader), val_losses.avg, val_top1_accs.avg))
251 | # save tensorboard
252 | self.writer.add_scalar('Val_loss', val_losses.avg, epoch)
253 | self.writer.add_scalar('Val_top1_accuracy', val_top1_accs.avg, epoch)
254 |
255 | self.logger.info('||==> Val Epoch: [{:d}/{:d}]\tval_loss={:.4f}\t\ttop1_accuracy={:.4f}'
256 | .format(epoch, self.num_epochs, val_losses.avg, val_top1_accs.avg))
257 |
258 | return val_losses.avg, val_top1_accs.avg
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
--------------------------------------------------------------------------------
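Note on the hook mechanism: TransferFramework above captures intermediate feature maps from the source and target networks via forward hooks on the layers listed in get_hook_layers(). A stripped-down, standalone sketch of that pattern (untrained torchvision resnet50, a single layer, CPU only):

    import torch
    from torchvision.models import resnet50

    model = resnet50(pretrained=False).eval()
    captured = []

    def save_output(module, inputs, output):
        captured.append(output)

    for name, layer in model.named_modules():
        if name == 'layer1.2.conv3':              # one of the resnet50 hook layers above
            layer.register_forward_hook(save_output)

    with torch.no_grad():
        model(torch.randn(1, 3, 224, 224))
    print(captured[0].shape)                      # torch.Size([1, 256, 56, 56])
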
/hocon_config/mobilenetv2_caltech-256-30.hocon:
--------------------------------------------------------------------------------
1 | # ------------- data options ----------------------------------------
2 | target_dataset = "Caltech_256-30" # Stanford_Dogs, Caltech_256-30, Caltech_256-60, CUB-200-2011, Food-101
3 | target_data_dir = "/mnt/ssd/Datasets/Fine-Grained_Recognition/"
4 |
5 |
6 | # ------------- general options -------------------------------------------
7 | outpath = "./exp_log/DTQ_4bits/"
8 | gpu_id = "4" # single-gpu
9 | seed = 0
10 | print_freq = 10
11 | batch_size = 64
12 | num_workers = 4
13 | exp_id = "20200730"
14 |
15 |
16 | # ------------- common optimization options ----------------------------
17 | repeat = 5
18 | lr = 0.01 # 0.01 for resnet50 and mobilenet_v2
19 | max_iter = 9000 # 9000, 6000 decay
20 | momentum = 0.9
21 | weight_decay = 1e-4 # for classifier
22 | gamma = 0.1
23 |
24 |
25 | # ------------- model options ------------------------------------------
26 | base_task = "imagenet" # imagenet
27 | base_model_name = "mobilenet_v2" # resnet50, mobilenet_v2
28 | image_size = 224 # 224 for resnet50, mobilenet_v2
29 | data_aug = "default" # default
30 | bits_weights = 4 # 32, 5, 4 bit
31 | bits_activations = 4 # 32, 5, 4 bit
32 |
33 |
34 | # ------------- training options ------------------------------------------
35 | loss_type = "CrossEntropyLoss" # CrossEntropyLoss
36 | lr_scheduler = "steplr" # steplr
37 | reg_type = "channel_att_fea_map_learn" # channel_att_fea_map_learn
38 | lambada = 0.5 # kl_loss weight
39 | theta = 0.01 # AFA_loss weight
40 | T = 30.0 # parameter of the KL loss
41 |
42 |
43 | # ------------- resume or retrain options ------------------------------
44 | pretrain_path = ""
45 | resume = ""
--------------------------------------------------------------------------------
/hocon_config/resnet50_caltech-256-30.hocon:
--------------------------------------------------------------------------------
1 | # ------------- data options ----------------------------------------
2 | target_dataset = "Caltech_256-30" # Stanford_Dogs, Caltech_256-30, Caltech_256-60, CUB-200-2011, Food-101
3 | target_data_dir = "/mnt/ssd/datasets/Fine-Grained_Recognition/"
4 |
5 |
6 | # ------------- general options -------------------------------------------
7 | outpath = "./exp_log/DTQ_4bits/"
8 | gpu_id = "3,5"
9 | seed = 4
10 | print_freq = 10
11 | batch_size = 64
12 | num_workers = 4
13 | exp_id = "20200730"
14 |
15 |
16 | # ------------- common optimization options ----------------------------
17 | repeat = 5
18 | lr = 0.01 # 0.01 for resnet50 and mobilenet_v2
19 | max_iter = 9000 # 9000, 6000 decay
20 | momentum = 0.9
21 | weight_decay = 1e-4 # for classifier
22 | gamma = 0.1
23 |
24 |
25 | # ------------- model options ------------------------------------------
26 | base_task = "imagenet" # imagenet
27 | base_model_name = "resnet50" # resnet50, mobilenet_v2
28 | image_size = 224 # 224 for resnet50, mobilenet_v2
29 | data_aug = "default" # default
30 | bits_weights = 4 # 32, 5, 4 bit
31 | bits_activations = 4 # 32, 5, 4 bit
32 |
33 |
34 | # ------------- training options ------------------------------------------
35 | loss_type = "CrossEntropyLoss" # CrossEntropyLoss
36 | lr_scheduler = "steplr" # steplr
37 | reg_type = "channel_att_fea_map_learn" # channel_att_fea_map_learn
38 | lambada = 0.5 # kl_loss weight
39 | theta = 0.01 # AFA_loss weight
40 | T = 60.0 # parameter of the KL loss
41 |
42 |
43 | # ------------- resume or retrain options ------------------------------
44 | pretrain_path = ""
45 | resume = ""
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import random
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 | import torch.backends.cudnn as cudnn
9 | from tensorboardX import SummaryWriter
10 |
11 | from option import Option
12 | from framework import TransferFramework
13 | from models.loss_function import get_loss_type
14 | from data.dataloader import get_target_dataloader
15 | from models.get_model import get_model, model_split
16 | from utils.checkpoint import save_checkpoint
17 | from models.regularizer import get_reg_criterions
18 | from utils.util import get_logger, output_process, ours_record_epoch_data, \
19 | write_settings, get_optimier_and_scheduler
20 |
21 |
22 | def train_net(args, logger, seed):
23 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
24 | logger.info('seed={}'.format(seed))
25 |
26 | # init seed
27 | torch.manual_seed(seed)
28 | torch.cuda.manual_seed(seed)
29 | random.seed(seed)
30 | np.random.seed(seed)
31 | # cudnn.benchmark = True
32 | cudnn.benchmark = False
33 | cudnn.deterministic = True # cudnn
34 |
35 | writer = SummaryWriter(args.outpath)
36 | start_epoch = 0
37 | val_best_acc = 0
38 | val_best_acc_index = 0
39 |
40 | # data_loader
41 | train_loader, val_loader, target_class_num, dataset_sizes = \
42 | get_target_dataloader(args.target_dataset, args.batch_size, args.num_workers, args.target_data_dir,
43 | image_size=args.image_size, data_aug=args.data_aug, logger=logger)
44 |
45 | # model setting
46 | model_source, model_target = get_model(args.base_model_name, args.base_task, logger, args)
47 |
48 | # target_model split: (feature, classifier)
49 | model_feature, model_source_classifier, model_target_classifier = \
50 | model_split(args.base_model_name, model_target, target_class_num, logger, args)
51 |
52 | if len(args.gpu_id) > 1:
53 | model_source = nn.DataParallel(model_source)
54 | model_feature = nn.DataParallel(model_feature)
55 | model_source_classifier = nn.DataParallel(model_source_classifier)
56 | model_target_classifier = nn.DataParallel(model_target_classifier)
57 | model_source = model_source.cuda()
58 | model_feature = model_feature.cuda()
59 | model_target_classifier = model_target_classifier.cuda()
60 | model_source_classifier = model_source_classifier.cuda()
61 | logger.info("push all model to dataparallel and then gpu")
62 | else:
63 | model_source = model_source.cuda()
64 | model_feature = model_feature.cuda()
65 | model_target_classifier = model_target_classifier.cuda()
66 | model_source_classifier = model_source_classifier.cuda()
67 | logger.info("push all model to gpu")
68 |
69 | # iterations -> epochs
70 | num_epochs = int(np.round(args.max_iter * args.batch_size / dataset_sizes))
71 | step = [int(0.67 * num_epochs)]
72 | logger.info('num_epochs={}, step={}'.format(num_epochs, step))
73 |
74 | # loss
75 | loss_fn = get_loss_type(loss_type=args.loss_type, logger=logger)
76 |
77 | # get feature_criterions
78 | if args.reg_type in ['channel_att_fea_map_learn', 'fea_loss']:
79 | feature_criterions = get_reg_criterions(args, logger)
80 |
81 | # optimizer and lr_scheduler
82 | optimizer, lr_scheduler = get_optimier_and_scheduler(args, model_feature, model_target_classifier, feature_criterions,
83 | step, logger)
84 |
85 | # init framework
86 | framework = TransferFramework(args, train_loader, val_loader, target_class_num, args.data_aug, args.base_model_name,
87 | model_source, model_feature, model_source_classifier, model_target_classifier,
88 | feature_criterions, loss_fn, num_epochs, optimizer, lr_scheduler,
89 | writer, logger, print_freq=args.print_freq)
90 |
91 | # epochs
92 | for epoch in range(start_epoch, num_epochs):
93 | # train epoch
94 | clc_loss, kl_loss, fea_loss, train_total_loss, train_top1_acc = framework.train(epoch)
95 | # val epoch
96 | val_loss, val_top1_acc = framework.val(epoch)
97 | # record into txt
98 | ours_record_epoch_data(args.outpath, epoch, clc_loss, kl_loss, fea_loss, train_total_loss, train_top1_acc, val_loss, val_top1_acc)
99 |
100 | if val_top1_acc >= val_best_acc:
101 | val_best_acc = val_top1_acc
102 | val_best_acc_index = epoch
103 | # save_checkpoint
104 | save_checkpoint(args.outpath, epoch, model_feature, model_source_classifier, model_target_classifier,
105 | optimizer, lr_scheduler, val_best_acc)
106 |
107 | logger.info('||==>Val Epoch: Val_best_acc_index={}\tVal_best_acc={:.4f}\n'.format(val_best_acc_index, val_best_acc))
108 | # break
109 | return val_best_acc
110 |
111 |
112 | if __name__ == '__main__':
113 | parser = argparse.ArgumentParser(description='Transfer')
114 | parser.add_argument('conf_path', type=str, metavar='conf_path',
115 | help='path of the hocon config file for training')
116 | argparses = parser.parse_args()
117 | args = Option(argparses.conf_path)
118 | args.set_save_path()
119 |
120 | # args = parse_args()
121 | best_val_acc_list = []
122 | logger = None
123 | temp = args.outpath
124 | for i in range(1, args.repeat + 1):
125 | if args.repeat != 1:
126 | args.outpath = temp + "_{:02d}".format(i)
127 |
128 | output_process(args.outpath)
129 | write_settings(args)
130 | logger = get_logger(args.outpath, 'attention_transfer_{:02d}'.format(i))
131 | if i == 1:
132 | args.copy_code(logger, dst=os.path.join(args.outpath, 'code'))
133 |
134 | val_acc = train_net(args, logger, seed=(args.seed+i))
135 | best_val_acc_list.append(val_acc)
136 |
137 | acc_mean = np.mean(best_val_acc_list)
138 | acc_std = np.std(best_val_acc_list)
139 | for i in range(len(best_val_acc_list)):
140 | print_str = 'repeat={}\tbest_val_acc={}'.format(i + 1, best_val_acc_list[i])
141 | logger.info(print_str)
142 | logger.info('All repeat val_acc_mean={}\tval_acc_std={})'.format(acc_mean, acc_std))
143 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/9/23 19:40
4 | # @Author : xiezheng
5 | # @Site :
6 | # @File : __init__.py
--------------------------------------------------------------------------------
/models/get_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/9/23 20:51
4 | # @Author : xiezheng
5 | # @Site :
6 | # @File : get_model.py
7 |
8 |
9 | import pickle
10 | import torch
11 | import torch.nn as nn
12 | from torchvision.models import resnet50, mobilenet_v2
13 | from quantization.qmobilenet import Qmobilenet_v2
14 | from quantization.qresnet import Qresnet50
15 |
16 |
17 | def pretrained_model_imagenet(base_model): # load pre-trained model
18 | return eval(base_model)(pretrained=True)
19 |
20 |
21 | def get_base_model(base_model, model_type, logger, args): # interface for obtaining full precision or low precision model
22 | if model_type == 'source':
23 | return pretrained_model_imagenet(base_model)
24 |
25 | elif model_type == 'target':
26 |
27 | if args.bits_weights == 32 and args.bits_activations == 32: # full precision model
28 | model_target = pretrained_model_imagenet(base_model)
29 | logger.info('bits_weights and bits_activations == {}, '
30 | 'target model is full-precision!'.format(args.bits_weights))
31 | return model_target
32 |
33 | else:
34 | if base_model == "mobilenet_v2":
35 | model_target = Qmobilenet_v2(pretrained=True, bits_weights=args.bits_weights,
36 | bits_activations=args.bits_activations) # load low-precision mobilenet_v2 model
37 | elif base_model == "resnet50":
38 | model_target = Qresnet50(pretrained=True, bits_weights=args.bits_weights,
39 | bits_activations=args.bits_activations) # load low-precision ResNet-50 model
40 | else:
41 | assert False, "The model {} not allowed".format(base_model)
42 |
43 | logger.info('bits_weights and bits_activations == {}, '
44 | 'target model is low-precision!'.format(args.bits_weights))
45 | return model_target
46 | else:
47 | assert False, "Not exist this model_type {}".format(model_type)
48 |
49 |
50 | def get_model(base_model_name, base_task, logger, args): # obtain source and target model
51 | model_source = get_base_model(base_model_name, "source", logger, args)
52 | model_target = get_base_model(base_model_name, "target", logger, args)
53 |
54 | logger.info("model_source: {}".format(model_source))
55 | logger.info("model_target: {}".format(model_target))
56 |
57 | for param in model_source.parameters():
58 | param.requires_grad = False
59 |
60 | logger.info('base_task = {}, get model_source = {} and model_target ={}'
61 | .format(base_task, base_model_name, base_model_name))
62 | return model_source, model_target
63 |
64 |
65 | def model_split(base_model_name, model, target_class_num, logger, args): # split the target model into feature extractor and classifier
66 | if 'resnet' in base_model_name:
67 | model_source_classifier = model.fc
68 | logger.info('model_source_classifier:\n{}'.format(model_source_classifier))
69 |
70 | model_target_classifier = nn.Linear(model.fc.in_features, target_class_num)
71 | logger.info('model_target_classifier:\n{}'.format(model_target_classifier))
72 |
73 | model_feature = model
74 | model_feature.fc = nn.Identity()
75 | logger.info('model_feature:\n{}'.format(model_feature))
76 |
77 | elif 'mobilenet' in base_model_name:
78 | model_source_classifier = model.classifier[1]
79 | logger.info('model_source_classifier:\n{}'.format(model_source_classifier))
80 |
81 | model_target_classifier = nn.Linear(list(model.classifier.children())[1].in_features,target_class_num)
82 |
83 | logger.info('model_target_classifier:\n{}'.format(model_target_classifier))
84 |
85 | model_feature = model
86 | model_feature.classifier[1] = nn.Identity()
87 | logger.info('model_feature:\n{}'.format(model_feature))
88 |
89 | else:
90 | assert False, logger.info('unknown base_model_name={}'.format(base_model_name))
91 |
92 | return model_feature, model_source_classifier, model_target_classifier
93 |
94 |
95 | if __name__ == '__main__':
96 | model = resnet50(pretrained=False)
97 |
--------------------------------------------------------------------------------
/models/loss_function.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/10/6 10:30
4 | # @Author : xiezheng
5 | # @Site :
6 | # @File : loss_function.py
7 |
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | def loss_kl(outputs, teacher_outputs, T=1.0):
15 | kl_loss = (T * T) * nn.KLDivLoss(reduction='sum')(F.log_softmax(outputs / T, dim=1),
16 | F.softmax(teacher_outputs / T, dim=1)) / outputs.shape[0]
17 | return kl_loss
18 |
19 |
20 | def get_loss_type(loss_type, logger=None):
21 |
22 | if loss_type == 'CrossEntropyLoss':
23 | loss_fn = nn.CrossEntropyLoss().cuda()
24 | else:
25 | assert False, logger.info("invalid loss_type={}".format(loss_type))
26 |
27 | if logger is not None:
28 | logger.info("loss_type={}, {}".format(loss_type, loss_fn))
29 | return loss_fn
30 |
31 |
32 |
33 | if __name__ == '__main__':
34 | print()
--------------------------------------------------------------------------------
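Note on usage: a minimal sketch of the KL distillation loss above on dummy logits (batch of 4, 1000 classes; T = 30.0 matches the MobileNetV2 hocon config):

    import torch
    from models.loss_function import loss_kl

    student_logits = torch.randn(4, 1000)   # outputs of the target model's source classifier
    teacher_logits = torch.randn(4, 1000)   # outputs of the frozen source model
    kl = loss_kl(student_logits, teacher_logits, T=30.0)
    print(kl.item())                        # scalar KL loss, normalized by the batch size
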
/models/regularizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/10/15 11:26
4 | # @Author : xiezheng
5 | # @Site :
6 | # @File : regularizer.py
7 |
8 |
9 | import numpy as np
10 | import math
11 |
12 | import torch
13 | from torch import nn
14 | from utils.util import get_conv_num, get_fc_name, concat_gpu_data
15 | from quantization.google_quantization import quantization_on_input
16 |
17 |
18 | def flatten_outputs(fea):
19 | return torch.reshape(fea, (fea.shape[0], fea.shape[1], fea.shape[2] * fea.shape[3]))
20 |
21 |
22 | def get_reg_criterions(args, logger):
23 | if args.base_model_name in ['resnet50', 'resnet101']:
24 | in_channels_list = [256, 512, 1024, 2048]
25 | feature_size = [56, 28, 14, 7]
26 | elif args.base_model_name in ['inception_v3']:
27 | in_channels_list = [192, 288, 768, 2048]
28 | feature_size = [71, 35, 17, 8]
29 | elif args.base_model_name in ['mobilenet_v2']:
30 | in_channels_list = [32, 64, 96, 320]
31 | feature_size = [28, 14, 14, 7]
32 | else:
33 | assert False, logger.info('invalid base_model_name={}'.format(args.base_model_name))
34 |
35 | logger.info('in_channels_list={}'.format(in_channels_list))
36 | logger.info('feature_size={}'.format(feature_size))
37 |
38 | feature_criterions = get_feature_criterions(args, in_channels_list, feature_size, logger) # obtain channel attentive module
39 | return feature_criterions
40 |
41 |
42 | class channel_attention(nn.Module): # channel attentive module
43 | def __init__(self, in_channels, feature_size):
44 | super(channel_attention, self).__init__()
45 |
46 | # channel-wise attention
47 | self.fc1 = nn.Linear(feature_size * feature_size, feature_size, bias=False)
48 | self.relu1 = nn.ReLU(inplace=True)
49 | self.fc2 = nn.Linear(feature_size, 1, bias=False)
50 | self.bias = nn.Parameter(torch.zeros(in_channels))
51 | self.softmax = nn.Softmax()
52 |
53 | def forward(self, target_feature):
54 | b, c, h, w = target_feature.shape
55 | target_feature_resize = target_feature.view(b, c, h * w)
56 |
57 | # channel-wise attention
58 | c_f = self.fc1(target_feature_resize)
59 | c_f = self.relu1(c_f)
60 | c_f = self.fc2(c_f)
61 | c_f = c_f.view(b, c)
62 |
63 | # softmax
64 | channel_attention_weight = self.softmax(c_f + self.bias) # b*in_channels
65 | return channel_attention_weight
66 |
67 |
68 | # obtain channel feature alignment loss
69 | def reg_channel_att_fea_map_learn(layer_outputs_source, layer_outputs_target,
70 | feature_criterions, bits_activations, logger):
71 | if isinstance(feature_criterions, nn.DataParallel):
72 | feature_criterions_module = feature_criterions.module
73 | layer_outputs_source_processed = concat_gpu_data(layer_outputs_source)
74 | layer_outputs_target_processed = concat_gpu_data(layer_outputs_target)
75 | else:
76 | feature_criterions_module = feature_criterions
77 | layer_outputs_source_processed = layer_outputs_source
78 | layer_outputs_target_processed = layer_outputs_target
79 |
80 | fea_loss = torch.tensor(0.).cuda()
81 | for i, (fm_src, fm_tgt, feature_criterion) in \
82 | enumerate(zip(layer_outputs_source_processed, layer_outputs_target_processed, feature_criterions_module)):
83 | channel_attention_weight = feature_criterion(fm_src) # b, c
84 | b, c, h, w = fm_src.shape
85 |
86 | fm_src = flatten_outputs(fm_src) # b * c * (hw)
87 | fm_tgt = flatten_outputs(fm_tgt)
88 |
89 | diff = fm_tgt - fm_src.detach()
90 | distance = torch.norm(diff, 2, 2) # b * c
91 |
92 | distance = torch.mul(channel_attention_weight, distance ** 2) * c
93 | fea_loss += 0.5 * torch.sum(distance) / b
94 |
95 | return fea_loss
96 |
97 |
98 | def get_feature_criterions(args, in_channels_list, feature_size, logger):
99 | feature_criterions = nn.ModuleList()
100 | for i in range(len(in_channels_list)):
101 |
102 | if args.reg_type == 'channel_att_fea_map_learn':
103 | feature_criterions.append(channel_attention(in_channels_list[i], feature_size[i]))
104 | else:
105 | assert False, logger.info('invalid reg_type={}'.format(args.reg_type))
106 |
107 | if len(args.gpu_id) <= 1:
108 | feature_criterions = feature_criterions.cuda()
109 | else:
110 | feature_criterions = nn.DataParallel(feature_criterions)
111 | feature_criterions = feature_criterions.cuda()
112 | logger.info('feature_criterions={}'.format(feature_criterions))
113 | return feature_criterions
114 |
115 |
--------------------------------------------------------------------------------
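Note on usage: a quick shape check for the channel_attention module above; 256 channels at 56x56 corresponds to the first hooked ResNet-50 stage configured in get_reg_criterions (random feature map, CPU only):

    import torch
    from models.regularizer import channel_attention

    att = channel_attention(in_channels=256, feature_size=56)
    fm = torch.randn(2, 256, 56, 56)        # a batch of source feature maps
    w = att(fm)                             # per-channel attention weights
    print(w.shape, w.sum(dim=1))            # torch.Size([2, 256]); each row sums to 1
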
/option.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | from pyhocon import ConfigFactory
5 | from utils.util import is_main_process
6 |
7 |
8 | class Option(object):
9 | def __init__(self, conf_path):
10 | super(Option, self).__init__()
11 | self.conf = ConfigFactory.parse_file(conf_path)
12 |
13 |
14 | # ------------- data options -------------------------------------------
15 | # target_dataset
16 | self.target_dataset = self.conf['target_dataset'] # target dataset name
17 | self.target_data_dir = self.conf['target_data_dir'] # path for loading data set
18 |
19 |
20 | # ------------- general options ----------------------------------------
21 | self.outpath = self.conf['outpath'] # log path
22 | self.gpu_id = self.conf['gpu_id'] # GPU id to use, e.g. "0,1,2,3"
23 | self.seed = self.conf['seed'] # manually set RNG seed
24 | self.print_freq = self.conf['print_freq'] # print frequency (default: 10)
25 | self.batch_size = self.conf['batch_size'] # mini-batch size
26 | self.num_workers = self.conf['num_workers'] # num_workers
27 | self.exp_id = self.conf['exp_id'] # identifier for experiment
28 |
29 |
30 | # ------------- common optimization options ----------------------------
31 | self.repeat = self.conf['repeat']
32 | self.lr = float(self.conf['lr']) # initial learning rate
33 | self.max_iter = self.conf['max_iter'] # number of total training iterations
34 | self.momentum = self.conf['momentum'] # momentum
35 | self.weight_decay = float(self.conf['weight_decay']) # weight decay
36 | self.gamma = self.conf['gamma'] # lr decay factor
37 | self.bits_weights = self.conf['bits_weights']
38 | self.bits_activations = self.conf['bits_activations']
39 | self.lambada = self.conf['lambada'] # kl loss weight
40 | self.theta = self.conf['theta'] # AFA loss weight
41 | self.T = self.conf['T'] # parameter of the KL loss
42 |
43 | # ------------- model options ------------------------------------------
44 | self.base_task = self.conf['base_task']
45 | self.base_model_name = self.conf['base_model_name']
46 | self.image_size = self.conf['image_size']
47 | self.data_aug = self.conf['data_aug']
48 | self.reg_type = self.conf['reg_type']
49 |
50 | self.loss_type = self.conf['loss_type']
51 | self.lr_scheduler = self.conf['lr_scheduler']
52 |
53 |
54 | # ---------- resume or pretrained options ---------------------------------
55 | # path to pretrained model
56 | self.pretrain_path = None if len(self.conf['pretrain_path']) == 0 else self.conf['pretrain_path']
57 | # path to directory containing checkpoint
58 | self.resume = None if len(self.conf['resume']) == 0 else self.conf['resume']
59 |
60 |
61 | def set_save_path(self):
62 | exp_id = 'log_{}_{}_img{}_da-{}_{}_iter{}_bs{}_{}_lr{}_wd{}_W{}A{}_lambada{}_theta{}_T{}_{}' \
63 | .format(self.base_task, self.target_dataset, self.image_size, self.data_aug, self.base_model_name,
64 | self.max_iter, self.batch_size, self.lr_scheduler, self.lr, self.weight_decay,
65 | self.bits_weights, self.bits_activations, self.lambada, self.theta, self.T, self.exp_id)
66 |
67 | path = '{}_{}_da-{}_{}'.format('quantized_transfer', self.target_dataset, self.data_aug, self.base_model_name)
68 | self.outpath = os.path.join(self.outpath, path, exp_id)
69 | # self.outpath = os.path.join(self.outpath, exp_id)
70 |
71 |
72 | def copy_code(self, logger, src=os.path.abspath("./"), dst="./code/"):
73 | """
74 | copy code in current path to a folder
75 | """
76 | if is_main_process():
77 | for f in os.listdir(src):
78 | if "specific_experiments" in f or "log" in f:
79 | continue
80 | src_file = os.path.join(src, f)
81 | file_split = f.split(".")
82 | if len(file_split) >= 2 and file_split[1] == "py":
83 | if not os.path.isdir(dst):
84 | os.makedirs(dst)
85 | dst_file = os.path.join(dst, f)
86 | try:
87 | shutil.copyfile(src=src_file, dst=dst_file)
88 | except:
89 | logger.error("copy file error! src: {}, dst: {}".format(src_file, dst_file))
90 | elif os.path.isdir(src_file):
91 | deeper_dst = os.path.join(dst, f)
92 | self.copy_code(logger, src=src_file, dst=deeper_dst)
93 |
--------------------------------------------------------------------------------
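Note on the config format: Option above is a thin wrapper around pyhocon. A minimal sketch of the underlying parsing, using an inline HOCON string with a few illustrative keys instead of a .hocon file:

    from pyhocon import ConfigFactory

    conf = ConfigFactory.parse_string("""
        lr = 0.01                 # initial learning rate
        weight_decay = 1e-4
        bits_weights = 4
    """)
    print(float(conf['lr']), float(conf['weight_decay']), conf['bits_weights'])
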
/quantization/google_quantization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | from torch.nn import functional as F
5 | from torch.nn import Parameter
6 |
7 | import numpy as np
8 | import os
9 | import matplotlib.pyplot as plt
10 |
11 | # global params
12 | # QUANTIZE_NUM = 127.0
13 |
14 |
15 | def quantization_on_weights(x, k):
16 | n = 2 ** k
17 | a = torch.min(x)
18 | b = torch.max(x)
19 | s = (b - a) / (n - 1)
20 |
21 | x = torch.clamp(x, float(a), float(b))
22 | x = (x - a) / s
23 | x = RoundFunction.apply(x)
24 | x = x * s + a
25 | return x
26 |
27 |
28 | class RoundFunction(Function):
29 | @staticmethod
30 | def forward(ctx, x):
31 | return torch.round(x)
32 | @staticmethod
33 | def backward(ctx, grad_output):
34 | return grad_output
35 |
36 |
37 | def quantization_on_input(x, k):
38 | n = 2 ** k
39 | a = torch.min(x)
40 | b = torch.max(x)
41 | s = (b - a) / (n - 1)
42 |
43 | x = torch.clamp(x, float(a), float(b))
44 | x = (x - a) / s
45 | x = RoundFunction.apply(x)
46 | x = x * s + a
47 | return x
48 |
49 |
50 | class QConv2d(nn.Conv2d):
51 | """
52 | custom convolutional layers for quantization
53 | """
54 | def __init__(self, in_channels, out_channels, kernel_size,
55 | stride=1, padding=0, dilation=1, groups=1, bias=True,
56 | bits_weights=32, bits_activations=32):
57 | super(QConv2d, self).__init__(in_channels, out_channels, kernel_size,
58 | stride, padding, dilation, groups, bias)
59 |
60 | self.bits_weights = bits_weights
61 | self.bits_activations = bits_activations
62 |
63 | def forward(self, input):
64 | if self.bits_activations == 32:
65 | quantized_input = input
66 | else:
67 | quantized_input = quantization_on_input(input, self.bits_activations)
68 | quantized_weight = quantization_on_weights(self.weight, self.bits_weights)
69 | return F.conv2d(quantized_input, quantized_weight, self.bias, self.stride,
70 | self.padding, self.dilation, self.groups)
71 |
72 | def extra_repr(self):
73 | s = super().extra_repr()
74 | s += ", bits_weights={}".format(self.bits_weights)
75 | s += ", bits_activations={}".format(self.bits_activations)
76 | s += ", method={}".format("google")
77 | return s
78 |
79 |
80 | class QLinear(nn.Linear):
81 | """
82 | custom linear layers for quantization
83 | """
84 | def __init__(self, in_features, out_features, bias=False, bits_weights=32, bits_activations=32):
85 | super(QLinear, self).__init__(in_features, out_features, bias=bias)
86 | self.bits_weights = bits_weights
87 | self.bits_activations = bits_activations
88 |
89 | def forward(self, input):
90 | quantized_input = quantization_on_input(input, self.bits_activations)
91 | quantized_weight = quantization_on_weights(self.weight, self.bits_weights)
92 | return F.linear(quantized_input, quantized_weight, self.bias)
93 |
94 | def extra_repr(self):
95 | s = super().extra_repr()
96 | s += ", bits_weights={}".format(self.bits_weights)
97 | s += ", bits_activations={}".format(self.bits_activations)
98 | s += ", method={}".format("google")
99 | return s
100 |
101 |
102 | if __name__ == "__main__":
103 | x = torch.rand(2,2,2,2)
104 | print('x={}'.format(x))
105 | k = 8
106 | q_x = quantization_on_weights(x, k)
107 | print('q_x={}'.format(q_x))
--------------------------------------------------------------------------------
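Note on usage: a minimal sketch of dropping the quantized convolution above into a model, with 4-bit weights and activations as in the provided hocon configs (toy input, CPU only):

    import torch
    from quantization.google_quantization import QConv2d

    qconv = QConv2d(3, 16, kernel_size=3, padding=1, bias=False,
                    bits_weights=4, bits_activations=4)
    y = qconv(torch.rand(1, 3, 32, 32))     # weights and inputs are fake-quantized in forward()
    print(y.shape)                          # torch.Size([1, 16, 32, 32])
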
/quantization/qmobilenet.py:
--------------------------------------------------------------------------------
1 | # copy from pytorch-torchvision-models-mobilenet_v2
2 | from torch import nn
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 |
6 | from torch.hub import load_state_dict_from_url
7 | from quantization.google_quantization import QConv2d
8 |
9 |
10 | __all__ = ['MobileNetV2', 'mobilenet_v2', 'QMobileNetV2']
11 |
12 |
13 | model_urls = {
14 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
15 | }
16 |
17 |
18 | def _make_divisible(v, divisor, min_value=None):
19 | """
20 | This function is taken from the original tf repo.
21 | It ensures that all layers have a channel number that is divisible by 8
22 | It can be seen here:
23 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
24 | :param v:
25 | :param divisor:
26 | :param min_value:
27 | :return:
28 | """
29 | if min_value is None:
30 | min_value = divisor
31 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
32 | # Make sure that round down does not go down by more than 10%.
33 | if new_v < 0.9 * v:
34 | new_v += divisor
35 | return new_v
36 |
37 |
38 | class ConvBNReLU(nn.Sequential):
39 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
40 | padding = (kernel_size - 1) // 2
41 | super(ConvBNReLU, self).__init__(
42 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
43 | nn.BatchNorm2d(out_planes),
44 | nn.ReLU6(inplace=True)
45 | )
46 |
47 |
48 | class QConvBNReLU(nn.Sequential):
49 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, bits_weights=32, bits_activations=32):
50 | padding = (kernel_size - 1)//2
51 | super(QConvBNReLU, self).__init__(
52 | QConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups,
53 | bias=False, bits_weights=bits_weights, bits_activations=bits_activations),
54 | nn.BatchNorm2d(out_planes),
55 | nn.ReLU6(inplace=True)
56 | )
57 |
58 |
59 | class InvertedResidual(nn.Module):
60 | def __init__(self, inp, oup, stride, expand_ratio):
61 | super(InvertedResidual, self).__init__()
62 | self.stride = stride
63 | assert stride in [1, 2]
64 |
65 | hidden_dim = int(round(inp * expand_ratio))
66 | self.use_res_connect = self.stride == 1 and inp == oup
67 |
68 | layers = []
69 | if expand_ratio != 1:
70 | # pw
71 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
72 | layers.extend([
73 | # dw
74 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
75 | # pw-linear
76 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
77 | nn.BatchNorm2d(oup),
78 | ])
79 | self.conv = nn.Sequential(*layers)
80 |
81 | def forward(self, x):
82 | if self.use_res_connect:
83 | return x + self.conv(x)
84 | else:
85 | return self.conv(x)
86 |
87 |
88 | class QInvertedResidual(nn.Module):
89 | def __init__(self, inp, oup, stride, expand_ratio, bits_weights=32, bits_activations=32):
90 | super(QInvertedResidual, self).__init__()
91 | self.stride = stride
92 | assert stride in [1, 2]
93 |
94 | hidden_dim = int(round(inp * expand_ratio))
95 | self.use_res_connect = self.stride == 1 and inp == oup
96 |
97 | layers = []
98 | if expand_ratio != 1:
99 | # pw
100 | layers.append(QConvBNReLU(inp, hidden_dim, kernel_size=1, bits_weights=bits_weights, bits_activations=bits_activations))
101 | layers.extend([
102 | # dw
103 | QConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, bits_weights=bits_weights, bits_activations=bits_activations),
104 | # pw-linear
105 | QConv2d(hidden_dim, oup, 1, 1, 0, bias=False, bits_weights=bits_weights, bits_activations=bits_activations),
106 | nn.BatchNorm2d(oup),
107 | ])
108 | self.conv = nn.Sequential(*layers)
109 |
110 | def forward(self, x):
111 | if self.use_res_connect:
112 | return x + self.conv(x)
113 | else:
114 | return self.conv(x)
115 |
116 |
117 | class MobileNetV2(nn.Module):
118 | def __init__(self,
119 | num_classes=1000,
120 | width_mult=1.0,
121 | inverted_residual_setting=None,
122 | round_nearest=8,
123 | block=None):
124 | """
125 | MobileNet V2 main class
126 | Args:
127 | num_classes (int): Number of classes
128 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
129 | inverted_residual_setting: Network structure
130 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
131 | Set to 1 to turn off rounding
132 | block: Module specifying inverted residual building block for mobilenet
133 | """
134 | super(MobileNetV2, self).__init__()
135 |
136 | if block is None:
137 | block = InvertedResidual
138 | input_channel = 32
139 | last_channel = 1280
140 |
141 | if inverted_residual_setting is None:
142 | inverted_residual_setting = [
143 | # t, c, n, s
144 | [1, 16, 1, 1],
145 | [6, 24, 2, 2],
146 | [6, 32, 3, 2],
147 | [6, 64, 4, 2],
148 | [6, 96, 3, 1],
149 | [6, 160, 3, 2],
150 | [6, 320, 1, 1],
151 | ]
152 |
153 | # only check the first element, assuming user knows t,c,n,s are required
154 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
155 |             raise ValueError("inverted_residual_setting should be a non-empty "
156 |                              "list of 4-element lists, got {}".format(inverted_residual_setting))
157 |
158 | # building first layer
159 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
160 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
161 | features = [ConvBNReLU(3, input_channel, stride=2)]
162 | # building inverted residual blocks
163 | for t, c, n, s in inverted_residual_setting:
164 | output_channel = _make_divisible(c * width_mult, round_nearest)
165 | for i in range(n):
166 | stride = s if i == 0 else 1
167 | features.append(block(input_channel, output_channel, stride, expand_ratio=t))
168 | input_channel = output_channel
169 | # building last several layers
170 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
171 | # make it nn.Sequential
172 | self.features = nn.Sequential(*features)
173 |
174 | # building classifier
175 | self.classifier = nn.Sequential(
176 | nn.Dropout(0.2),
177 | nn.Linear(self.last_channel, num_classes),
178 | )
179 |
180 | # weight initialization
181 | for m in self.modules():
182 | if isinstance(m, nn.Conv2d):
183 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
184 | if m.bias is not None:
185 | nn.init.zeros_(m.bias)
186 | elif isinstance(m, nn.BatchNorm2d):
187 | nn.init.ones_(m.weight)
188 | nn.init.zeros_(m.bias)
189 | elif isinstance(m, nn.Linear):
190 | nn.init.normal_(m.weight, 0, 0.01)
191 | nn.init.zeros_(m.bias)
192 |
193 | def _forward_impl(self, x):
194 | # This exists since TorchScript doesn't support inheritance, so the superclass method
195 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
196 | x = self.features(x)
197 | x = x.mean([2, 3])
198 | x = self.classifier(x)
199 | return x
200 |
201 | def forward(self, x):
202 | return self._forward_impl(x)
203 |
204 |
205 | class QMobileNetV2(nn.Module):
206 | def __init__(self,
207 | num_classes=1000,
208 | width_mult=1.0,
209 | inverted_residual_setting=None,
210 | round_nearest=8,
211 | block=None,
212 | bits_weights=32,
213 | bits_activations=32):
214 | """
215 | MobileNet V2 main class
216 | Args:
217 | num_classes (int): Number of classes
218 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
219 | inverted_residual_setting: Network structure
220 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
221 | Set to 1 to turn off rounding
222 | block: Module specifying inverted residual building block for mobilenet
223 | bits_weights: quantization bits of weights
224 | bits_activations: quantization bits of activations
225 | """
226 | super(QMobileNetV2, self).__init__()
227 |
228 | if block is None:
229 | block = QInvertedResidual
230 | input_channel = 32
231 | last_channel = 1280
232 |
233 | if inverted_residual_setting is None:
234 | inverted_residual_setting = [
235 | # t, c, n, s
236 | [1, 16, 1, 1],
237 | [6, 24, 2, 2],
238 | [6, 32, 3, 2],
239 | [6, 64, 4, 2],
240 | [6, 96, 3, 1],
241 | [6, 160, 3, 2],
242 | [6, 320, 1, 1],
243 | ]
244 |
245 | # only check the first element, assuming user knows t,c,n,s are required
246 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
247 |             raise ValueError("inverted_residual_setting should be a non-empty "
248 |                              "list of 4-element lists, got {}".format(inverted_residual_setting))
249 |
250 | # building first layer
251 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
252 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
253 | features = [ConvBNReLU(3, input_channel, stride=2)]
254 | # building inverted residual blocks
255 | for t, c, n, s in inverted_residual_setting:
256 | output_channel = _make_divisible(c * width_mult, round_nearest)
257 | for i in range(n):
258 | stride = s if i == 0 else 1
259 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, bits_weights=bits_weights, bits_activations=bits_activations))
260 | input_channel = output_channel
261 | # building last several layers
262 | features.append(QConvBNReLU(input_channel, self.last_channel, kernel_size=1, bits_weights=bits_weights, bits_activations=bits_activations))
263 | # make it nn.Sequential
264 | self.features = nn.Sequential(*features)
265 |
266 | # building classifier
267 | self.classifier = nn.Sequential(
268 | nn.Dropout(0.2),
269 | nn.Linear(self.last_channel, num_classes),
270 | )
271 |
272 | # weight initialization
273 | for m in self.modules():
274 |             if isinstance(m, nn.Conv2d):  # also covers QConv2d
275 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
276 | if m.bias is not None:
277 | nn.init.zeros_(m.bias)
278 | elif isinstance(m, nn.BatchNorm2d):
279 | nn.init.ones_(m.weight)
280 | nn.init.zeros_(m.bias)
281 | elif isinstance(m, nn.Linear):
282 | nn.init.normal_(m.weight, 0, 0.01)
283 | nn.init.zeros_(m.bias)
284 |
285 | def _forward_impl(self, x):
286 | # This exists since TorchScript doesn't support inheritance, so the superclass method
287 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
288 | x = self.features(x)
289 | x = x.mean([2, 3])
290 | x = self.classifier(x)
291 | return x
292 |
293 | def forward(self, x):
294 | return self._forward_impl(x)
295 |
296 |
297 | def mobilenet_v2(pretrained=False, progress=True, **kwargs):
298 | """
299 | Constructs a MobileNetV2 architecture from
300 |     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
301 | Args:
302 | pretrained (bool): If True, returns a model pre-trained on ImageNet
303 | progress (bool): If True, displays a progress bar of the download to stderr
304 | """
305 | model = MobileNetV2(**kwargs)
306 | if pretrained:
307 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
308 | progress=progress)
309 | model.load_state_dict(state_dict)
310 | return model
311 |
312 |
313 | def Qmobilenet_v2(pretrained=False, progress=True, **kwargs):
314 | """
315 | Constructs a MobileNetV2 architecture from
316 |     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
317 | Args:
318 | pretrained (bool): If True, returns a model pre-trained on ImageNet
319 | progress (bool): If True, displays a progress bar of the download to stderr
320 | """
321 | model = QMobileNetV2(**kwargs)
322 | if pretrained:
323 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
324 | progress=progress)
325 | model.load_state_dict(state_dict)
326 | return model
327 |
328 |
329 |
330 | def main():
331 | qmobilenet = Qmobilenet_v2(pretrained=True, bits_weights=8, bits_activations=8)
332 | print(qmobilenet)
333 |
334 |
335 | if __name__ == "__main__":
336 | main()
337 |
--------------------------------------------------------------------------------
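Because QConv2d subclasses nn.Conv2d, QMobileNetV2 has the same state-dict keys as the floating-point model and the ImageNet checkpoint loads directly (as main() above does). For a transfer task such as the caltech-256-30 configs, only the classifier head then needs replacing; a hedged sketch, with the 4-bit setting chosen purely for illustration:

    from torch import nn
    from quantization.qmobilenet import Qmobilenet_v2

    model = Qmobilenet_v2(pretrained=True, bits_weights=4, bits_activations=4)
    # swap the 1000-way ImageNet head for a 30-class target task
    model.classifier[1] = nn.Linear(model.last_channel, 30)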
/quantization/qresnet.py:
--------------------------------------------------------------------------------
1 | # Copied from torchvision's resnet, with quantized (Q*) variants added.
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | from quantization.google_quantization import QConv2d
6 |
7 |
8 | __all__ = [
9 | "QResNet",
10 | "QBasicBlock",
11 | "QBottleneck",
12 |     "Qresnet101", "Qresnet50",
13 | ]
14 |
15 | model_urls = {
16 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
17 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
18 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
19 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
20 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
21 | }
22 |
23 |
24 | def conv3x3(in_planes, out_planes, stride=1):
25 | "3x3 convolution with padding"
26 | return nn.Conv2d(
27 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
28 | )
29 |
30 |
31 | def qconv3x3(in_planes, out_planes, stride=1, bits_weights=32, bits_activations=32):
32 | "3x3 convolution with padding"
33 | return QConv2d(
34 | in_planes,
35 | out_planes,
36 | kernel_size=3,
37 | stride=stride,
38 | padding=1,
39 | bias=False,
40 | bits_weights=bits_weights,
41 | bits_activations=bits_activations,
42 | )
43 |
44 |
45 | class QBasicBlock(nn.Module):
46 | expansion = 1
47 |
48 | def __init__(
49 | self,
50 | inplanes,
51 | planes,
52 | stride=1,
53 | downsample=None,
54 | bits_weights=32,
55 | bits_activations=32,
56 | ):
57 | super(QBasicBlock, self).__init__()
58 | self.name = "resnet-basic"
59 | self.conv1 = qconv3x3(
60 | inplanes,
61 | planes,
62 | stride,
63 | bits_weights=bits_weights,
64 | bits_activations=bits_activations,
65 | )
66 | self.bn1 = nn.BatchNorm2d(planes)
67 | self.relu = nn.ReLU(inplace=True)
68 | self.conv2 = qconv3x3(
69 | planes, planes, bits_weights=bits_weights, bits_activations=bits_activations
70 | )
71 | self.bn2 = nn.BatchNorm2d(planes)
72 | self.downsample = downsample
73 | self.stride = stride
74 | self.block_index = 0
75 |
76 | def forward(self, x):
77 | residual = x
78 |
79 | out = self.conv1(x)
80 | out = self.bn1(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv2(out)
84 | out = self.bn2(out)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | out += residual
90 | out = self.relu(out)
91 |
92 | return out
93 |
94 |
95 | class QBottleneck(nn.Module):
96 | expansion = 4
97 |
98 | def __init__(
99 | self,
100 | inplanes,
101 | planes,
102 | stride=1,
103 | downsample=None,
104 | bits_weights=32,
105 | bits_activations=32,
106 | ):
107 | super(QBottleneck, self).__init__()
108 | self.name = "resnet-bottleneck"
109 | self.conv1 = QConv2d(
110 | inplanes,
111 | planes,
112 | kernel_size=1,
113 | bias=False,
114 | bits_weights=bits_weights,
115 | bits_activations=bits_activations,
116 | )
117 | self.bn1 = nn.BatchNorm2d(planes)
118 | self.conv2 = QConv2d(
119 | planes,
120 | planes,
121 | kernel_size=3,
122 | stride=stride,
123 | padding=1,
124 | bias=False,
125 | bits_weights=bits_weights,
126 | bits_activations=bits_activations,
127 | )
128 | self.bn2 = nn.BatchNorm2d(planes)
129 | self.conv3 = QConv2d(
130 | planes,
131 | planes * 4,
132 | kernel_size=1,
133 | bias=False,
134 | bits_weights=bits_weights,
135 | bits_activations=bits_activations,
136 | )
137 | self.bn3 = nn.BatchNorm2d(planes * 4)
138 | self.relu = nn.ReLU(inplace=True)
139 | self.downsample = downsample
140 | self.stride = stride
141 | self.block_index = 0
142 |
143 | def forward(self, x):
144 | residual = x
145 |
146 | out = self.conv1(x)
147 | out = self.bn1(out)
148 | out = self.relu(out)
149 |
150 | out = self.conv2(out)
151 | out = self.bn2(out)
152 | out = self.relu(out)
153 |
154 | out = self.conv3(out)
155 | out = self.bn3(out)
156 |
157 | if self.downsample is not None:
158 | residual = self.downsample(x)
159 |
160 | out += residual
161 |
162 | out = self.relu(out)
163 |
164 | return out
165 |
166 |
167 | class QResNet(nn.Module):
168 | def __init__(self, depth, num_classes=1000, bits_weights=32, bits_activations=32):
169 | self.inplanes = 64
170 | super(QResNet, self).__init__()
171 | if depth < 50:
172 | block = QBasicBlock
173 | else:
174 | block = QBottleneck
175 |
176 | if depth == 18:
177 | layers = [2, 2, 2, 2]
178 | elif depth == 34:
179 | layers = [3, 4, 6, 3]
180 | elif depth == 50:
181 | layers = [3, 4, 6, 3]
182 | elif depth == 101:
183 | layers = [3, 4, 23, 3]
184 | elif depth == 152:
185 | layers = [3, 8, 36, 3]
186 |
187 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
188 | self.bn1 = nn.BatchNorm2d(64)
189 | self.relu = nn.ReLU(inplace=True)
190 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
191 | self.layer1 = self._make_layer(
192 | block,
193 | 64,
194 | layers[0],
195 | bits_weights=bits_weights,
196 | bits_activations=bits_activations,
197 | )
198 | self.layer2 = self._make_layer(
199 | block,
200 | 128,
201 | layers[1],
202 | stride=2,
203 | bits_weights=bits_weights,
204 | bits_activations=bits_activations,
205 | )
206 | self.layer3 = self._make_layer(
207 | block,
208 | 256,
209 | layers[2],
210 | stride=2,
211 | bits_weights=bits_weights,
212 | bits_activations=bits_activations,
213 | )
214 | self.layer4 = self._make_layer(
215 | block,
216 | 512,
217 | layers[3],
218 | stride=2,
219 | bits_weights=bits_weights,
220 | bits_activations=bits_activations,
221 | )
222 | self.avgpool = nn.AvgPool2d(7, stride=1)
223 | self.fc = nn.Linear(512 * block.expansion, num_classes)
224 |
225 | for m in self.modules():
226 | if isinstance(m, nn.Conv2d):
227 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
228 | elif isinstance(m, nn.BatchNorm2d):
229 | nn.init.constant_(m.weight, 1)
230 | nn.init.constant_(m.bias, 0)
231 |
232 | def _make_layer(
233 | self, block, planes, blocks, stride=1, bits_weights=32, bits_activations=32
234 | ):
235 | downsample = None
236 | if stride != 1 or self.inplanes != planes * block.expansion:
237 | downsample = nn.Sequential(
238 | QConv2d(
239 | self.inplanes,
240 | planes * block.expansion,
241 | kernel_size=1,
242 | stride=stride,
243 | bias=False,
244 | bits_weights=bits_weights,
245 | bits_activations=bits_activations,
246 | ),
247 | nn.BatchNorm2d(planes * block.expansion),
248 | )
249 |
250 | layers = []
251 | layers.append(
252 | block(
253 | self.inplanes,
254 | planes,
255 | stride,
256 | downsample,
257 | bits_weights=bits_weights,
258 | bits_activations=bits_activations,
259 | )
260 | )
261 | self.inplanes = planes * block.expansion
262 | for i in range(1, blocks):
263 | layers.append(
264 | block(
265 | self.inplanes,
266 | planes,
267 | bits_weights=bits_weights,
268 | bits_activations=bits_activations,
269 | )
270 | )
271 |
272 | return nn.Sequential(*layers)
273 |
274 | def forward(self, x):
275 | x = self.conv1(x)
276 | x = self.bn1(x)
277 | x = self.relu(x)
278 | x = self.maxpool(x)
279 |
280 | x = self.layer1(x)
281 | x = self.layer2(x)
282 | x = self.layer3(x)
283 | x = self.layer4(x)
284 |
285 | x = self.avgpool(x)
286 | x = x.view(x.size(0), -1)
287 | x = self.fc(x)
288 |
289 | return x
290 |
291 |
292 | def Qresnet101(pretrained=False, **kwargs): # bits_weights=2, bits_activations=2
293 | """Constructs a ResNet-101 model.
294 | Args:
295 | pretrained (bool): If True, returns a model pre-trained on ImageNet
296 | """
297 | model = QResNet(101, num_classes=1000, **kwargs)
298 | if pretrained:
299 | model.load_state_dict(model_zoo.load_url(model_urls["resnet101"]))
300 | return model
301 |
302 |
303 | def Qresnet50(pretrained=False, **kwargs): # bits_weights=2, bits_activations=2
304 |     """Constructs a ResNet-50 model.
305 | Args:
306 | pretrained (bool): If True, returns a model pre-trained on ImageNet
307 | """
308 | model = QResNet(50, num_classes=1000, **kwargs)
309 | if pretrained:
310 | model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]))
311 | return model
312 |
313 |
314 | if __name__ == '__main__':
315 | model = Qresnet50(pretrained=True)
316 | print(model)
--------------------------------------------------------------------------------
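The quantized ResNets follow the same pattern: the torchvision ImageNet weights load unchanged and only the fully connected layer is task-specific. A minimal sketch under the same assumptions as above:

    from torch import nn
    from quantization.qresnet import Qresnet50

    model = Qresnet50(pretrained=True, bits_weights=8, bits_activations=8)
    # Bottleneck blocks yield 512 * 4 = 2048-dim features before the classifier
    model.fc = nn.Linear(512 * 4, 30)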
/requirements.txt:
--------------------------------------------------------------------------------
1 |
2 | numpy==1.14.3
3 | pyhocon==0.3.51
4 | opencv-python==4.1.0
5 | tensorboardX==2.0
6 | torchvision==0.3.0
7 | torch==1.1.0
8 |
9 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/9/23 19:41
4 | # @Author : xiezheng
5 | # @Site :
6 | # @File : __init__.py.py
--------------------------------------------------------------------------------
/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2019/9/23 22:25
4 | # @Author : xiezheng
5 | # @Site :
6 | # @File : checkpoint.py
7 |
8 | import os
9 |
10 | import torch
11 | from torch import nn
12 | from utils.util import ensure_folder
13 |
14 |
15 | def save_checkpoint(outpath, epoch, model_feature, model_source_classifier, model_target_classifier,
16 | optimizer, lr_scheduler, val_best_acc):
17 |
18 | check_point_params = {}
19 | if isinstance(model_feature, nn.DataParallel):
20 | check_point_params["model_feature"] = model_feature.module.state_dict()
21 | else:
22 | check_point_params["model_feature"] = model_feature.state_dict()
23 |
24 | if isinstance(model_source_classifier, nn.DataParallel):
25 | check_point_params["model_source_classifier"] = model_source_classifier.module.state_dict()
26 | else:
27 | check_point_params["model_source_classifier"] = model_source_classifier.state_dict()
28 |
29 | if isinstance(model_target_classifier, nn.DataParallel):
30 | check_point_params["model_target_classifier"] = model_target_classifier.module.state_dict()
31 | else:
32 | check_point_params["model_target_classifier"] = model_target_classifier.state_dict()
33 |
34 | check_point_params["val_best_acc"] = val_best_acc
35 |     check_point_params["optimizer"] = optimizer.state_dict()
36 |     check_point_params["lr_scheduler"] = lr_scheduler.state_dict()
37 | check_point_params['epoch'] = epoch
38 |
39 |
40 | output_path = os.path.join(outpath, "check_point")
41 | ensure_folder(output_path)
42 | filename = 'checkpoint.pth'
43 | torch.save(check_point_params, os.path.join(output_path, filename))
44 |
45 |
46 | def save_model(outpath, epoch, model_feature, model_target_classifier, val_best_acc, logger):
47 | check_point_params = {}
48 |
49 | if isinstance(model_feature, nn.DataParallel):
50 | check_point_params["model"] = model_feature.module.state_dict()
51 | else:
52 | check_point_params["model"] = model_feature.state_dict()
53 |
54 | if isinstance(model_target_classifier, nn.DataParallel):
55 | check_point_params["fc"] = model_target_classifier.module.state_dict()
56 | else:
57 | check_point_params["fc"] = model_target_classifier.state_dict()
58 |
59 | output_path = os.path.join(outpath, "check_point")
60 | ensure_folder(output_path)
61 | filename = 'model_{:03d}_acc{:.4f}.pth'.format(epoch, val_best_acc)
62 | torch.save(check_point_params, os.path.join(output_path, filename))
63 |
64 |
65 |
66 | def load_checkpoint(checkpoint_path, model, optimizer, lr_scheduler, logger):
67 | check_point_params = torch.load(checkpoint_path)
68 | model_state = check_point_params["model"]
69 | start_epoch = check_point_params['epoch']
70 | optimizer.load_state_dict(check_point_params["optimizer"])
71 | lr_scheduler.load_state_dict(check_point_params["lr_scheduler"])
72 | val_best_acc = check_point_params["val_best_acc"]
73 |
74 | model = load_state(model, model_state, logger)
75 | return model, start_epoch, optimizer, lr_scheduler, val_best_acc
76 |
77 |
78 | def load_state(model, state_dict, logger):
79 | """
80 | load state_dict to model
81 | :params model:
82 | :params state_dict:
83 | :return: model
84 | """
85 |
86 | if isinstance(model, nn.DataParallel):
87 | model.module.load_state_dict(state_dict)
88 | else:
89 | model.load_state_dict(state_dict)
90 | logger.info("load model state finished !!!")
91 | return model
92 |
93 |
94 | def load_pretrain_model(pretrain_path, model, logger):
95 | if pretrain_path is not None:
96 | check_point_params = torch.load(pretrain_path)
97 |
98 | model_state = check_point_params['model']
99 | # model_state = check_point_params
100 |
101 |         model = load_state(model, model_state, logger)
102 |         logger.info("|===>load pretrain file: {}".format(pretrain_path))
103 | else:
104 | logger.info('pretrain_path is None')
105 |
106 | return model
107 |
--------------------------------------------------------------------------------
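A hedged usage sketch for the two save helpers above; every variable name here (outpath, epoch, the model and optimizer objects) is assumed to come from the surrounding training loop rather than from this file:

    from utils.checkpoint import save_checkpoint, save_model

    # after finishing one training epoch
    save_checkpoint(outpath, epoch, model_feature, model_source_classifier,
                    model_target_classifier, optimizer, lr_scheduler, val_best_acc)

    # keep a separate snapshot of the best feature extractor + target head
    save_model(outpath, epoch, model_feature, model_target_classifier, val_best_acc, logger)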
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from datetime import datetime
4 | import numpy as np
5 | import shutil
6 | import json
7 |
8 | import torch
9 | import logging
10 | import torch.nn.functional as F
11 | from torch import nn
12 | from torch import optim
13 | from torch.optim.lr_scheduler import MultiStepLR
14 | import torch.distributed as dist
15 |
16 |
17 | def get_logger(save_path, logger_name):
18 | """
19 | Initialize logger
20 | """
21 |
22 | logger = logging.getLogger(logger_name)
23 | file_formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
24 | console_formatter = logging.Formatter('%(message)s')
25 | # file log
26 | file_handler = logging.FileHandler(os.path.join(save_path, "experiment.log"))
27 | file_handler.setFormatter(file_formatter)
28 |
29 | # console log
30 | console_handler = logging.StreamHandler(sys.stdout)
31 | console_handler.setFormatter(console_formatter)
32 |
33 | logger.addHandler(file_handler)
34 | logger.addHandler(console_handler)
35 |
36 | logger.setLevel(logging.INFO)
37 | return logger
38 |
39 |
40 | def output_process(output_path):
41 | if os.path.exists(output_path):
42 |         print("{} already exists!".format(output_path))
43 | action = input("Select Action: d (delete) / q (quit):").lower().strip()
44 | act = action
45 | if act == 'd':
46 | shutil.rmtree(output_path)
47 | else:
48 |             raise OSError("Directory {} exists!".format(output_path))
49 |
50 | if not os.path.exists(output_path):
51 | os.makedirs(output_path)
52 |
53 | def is_dist_avail_and_initialized():
54 | if not dist.is_available():
55 | return False
56 | if not dist.is_initialized():
57 | return False
58 | return True
59 |
60 | def get_rank():
61 | if not is_dist_avail_and_initialized():
62 | return 0
63 | return dist.get_rank()
64 |
65 | def is_main_process():
66 | return get_rank() == 0
67 |
68 |
69 | class AverageMeter(object):
70 | """
71 | Keeps track of most recent, average, sum, and count of a metric.
72 | """
73 |
74 | def __init__(self):
75 | self.reset()
76 |
77 | def reset(self):
78 | self.val = 0
79 | self.avg = 0
80 | self.sum = 0
81 | self.count = 0
82 |
83 | def update(self, val, n=1):
84 | self.val = val
85 | self.sum += val * n
86 | self.count += n
87 | self.avg = self.sum / self.count
88 |
89 |
90 | def get_learning_rate(optimizer):
91 | lr = []
92 | for param_group in optimizer.param_groups:
93 | lr += [param_group['lr']]
94 | return lr[0]
95 |
96 |
97 | def accuracy(scores, targets, k=1):
98 | batch_size = targets.size(0)
99 | _, ind = scores.topk(k, 1, True, True)
100 | correct = ind.eq(targets.long().view(-1, 1).expand_as(ind))
101 | correct_total = correct.view(-1).float().sum() # 0D tensor
102 | return correct_total.item() * (1.0 / batch_size)
103 |
104 |
105 | def ensure_folder(folder):
106 |     """Create the folder if it does not already exist."""
107 | if not os.path.isdir(folder):
108 | os.makedirs(folder)
109 |
110 |
111 | def record_epoch_data(outpath, epoch, clc_loss, mid_clc_loss, train_total_loss, train_top1_acc, val_loss, val_top1_acc):
112 | txt_path = os.path.join(outpath, "log.txt")
113 | f = open(txt_path, 'a+')
114 |
115 | record_txt = '{}\t{}\t{}\t{}\t{}\t{}\t{}\n' \
116 | .format(epoch, clc_loss, mid_clc_loss, train_total_loss, train_top1_acc, val_loss, val_top1_acc)
117 |
118 | if epoch == 0:
119 | record_head = "epoch\tclc_loss\tmid_clc_loss\t" \
120 | "train_total_loss\ttrain_top1_acc\tval_loss\tval_top1_acc\n"
121 | f.write(record_head)
122 |
123 | f.write(record_txt)
124 | f.close()
125 |
126 |
127 | def ours_record_epoch_data(outpath, epoch, clc_loss, kl_loss, fm_mse_loss,
128 | train_total_loss, train_top1_acc, val_loss, val_top1_acc):
129 | txt_path = os.path.join(outpath, "log.txt")
130 | f = open(txt_path, 'a+')
131 |
132 | record_txt = '{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n' \
133 | .format(epoch, clc_loss, kl_loss, fm_mse_loss, train_total_loss, train_top1_acc, val_loss, val_top1_acc)
134 |
135 | if epoch == 0:
136 | record_head = "epoch\tclc_loss\tkl_loss\tfm_mse_loss\t" \
137 | "train_total_loss\ttrain_top1_acc\tval_loss\tval_top1_acc\n"
138 | f.write(record_head)
139 |
140 | f.write(record_txt)
141 | f.close()
142 |
143 |
144 | def record_epoch_learn_alpha(outpath, alpha, epoch, logger):
145 | txt_path = os.path.join(outpath, "alpha.txt")
146 | f = open(txt_path, 'a+')
147 | alpha = alpha.data.cpu().numpy()
148 | # print(alpha.shape)
149 | np.savetxt(f, alpha.reshape(1, alpha.shape[0]), fmt='%.6e')
150 | f.close()
151 | # assert False
152 | logger.info("epoch={}, alpha save!".format(epoch))
153 |
154 |
155 | def write_settings(settings):
156 | """
157 |     Save experiment settings to a file
158 | :param settings: the instance of option
159 | """
160 |
161 | with open(os.path.join(settings.outpath, "settings.log"), "w") as f:
162 | for k, v in settings.__dict__.items():
163 | f.write(str(k) + ": " + str(v) + "\n")
164 |
165 |
166 | def get_optimier_and_scheduler(args, model_feature, model_target_classifier, feature_criterions, step, logger):
167 | # model_source_classifier fixed
168 | if len(args.gpu_id) > 1:
169 | if feature_criterions:
170 | optimizer = optim.SGD([{'params': model_feature.module.parameters()},
171 | {'params': model_target_classifier.module.parameters(), 'weight_decay':args.weight_decay},
172 | {'params': feature_criterions.module.parameters()}],
173 | lr=args.lr, momentum=args.momentum)
174 | else:
175 | optimizer = optim.SGD([{'params': model_feature.module.parameters()},
176 | {'params': model_target_classifier.module.parameters(),
177 | 'weight_decay':args.weight_decay}],
178 | lr=args.lr, momentum=args.momentum)
179 | else:
180 | if feature_criterions:
181 | optimizer = optim.SGD([{'params': model_feature.parameters()},
182 | {'params': model_target_classifier.parameters(), 'weight_decay':args.weight_decay},
183 | {'params': feature_criterions.parameters()}],
184 | lr=args.lr, momentum=args.momentum)
185 | else:
186 | optimizer = optim.SGD([{'params': model_feature.parameters()},
187 | {'params': model_target_classifier.parameters(),
188 | 'weight_decay':args.weight_decay}],
189 | lr=args.lr, momentum=args.momentum)
190 |
191 | logger.info('optimizer={}'.format(optimizer))
192 |
193 | # lr_scheduler
194 | if args.lr_scheduler == 'steplr':
195 | lr_scheduler = MultiStepLR(optimizer, milestones=step, gamma=args.gamma)
196 | logger.info('lr_scheduler: SGD MultiStepLR !!!')
197 | else:
198 |         raise ValueError("invalid lr_scheduler={}".format(args.lr_scheduler))
199 |
200 | logger.info('lr_scheduler={}'.format(lr_scheduler))
201 | return optimizer, lr_scheduler
202 |
203 |
204 | def get_channel_weight(channel_weight_path, logger=None):
205 | channel_weights = []
206 | if channel_weight_path:
207 | for js in json.load(open(channel_weight_path)):
208 | js = np.array(js)
209 | js = (js - np.mean(js)) / np.std(js) # normalization
210 | cw = torch.from_numpy(js).float().cuda()
211 |         cw = F.softmax(cw / 5.0, dim=0).detach()
212 | channel_weights.append(cw)
213 | else:
214 | logger.info("channel_weight_path is None")
215 | return None
216 |
217 | return channel_weights
218 |
219 |
220 | # when data distribute to different GPU,
221 | # we concat the feature in different GPU but in the same position
222 | def concat_gpu_data(data):
223 | """
224 | Concat gpu data from different gpu.
225 | """
226 | gpu_id = list(data.keys())
227 | gpu_id.sort()
228 | main_gpu_id = gpu_id[0]
229 | data_features = []
230 | for j, i in enumerate(gpu_id):
231 | data_Cat = data[i]
232 | for k, fea in enumerate(data_Cat):
233 | if j == 0:
234 | data_features.append(fea)
235 | else:
236 | data_features[k] = torch.cat((data_features[k], fea.cuda(int(main_gpu_id))))
237 |
238 | return data_features
239 |
240 |
241 | def get_conv_num(base_model_name, model_source, fc_name, logger):
242 | model_source_weights = {}
243 | if 'resnet' in base_model_name:
244 | for name, param in model_source.named_parameters():
245 | # print('name={}'.format(name))
246 | if not name.startswith(fc_name) and ('conv' in name or 'downsample.0' in name):
247 | model_source_weights[name] = param.detach()
248 | logger.info('name={}'.format(name))
249 | # to do
250 | # elif 'inception' in base_model_name:
251 | else:
252 |         raise ValueError("invalid base_model_name={}, "
253 |                          "do not know fc_name".format(base_model_name))
254 |
255 | layer_length = len(model_source_weights)
256 | return layer_length
257 |
258 |
259 | def get_fc_name(base_model_name, logger):
260 | if 'resnet' in base_model_name:
261 | fc_name = 'fc.'
262 | elif 'inception' in base_model_name:
263 | fc_name = 'fc.'
264 | else:
265 |         raise ValueError("invalid base_model_name={}, "
266 |                          "do not know fc_name".format(base_model_name))
267 | return fc_name
268 |
269 |
270 | if __name__ == '__main__':
271 | # 241
272 | channel_weights_path = './json_result/channel_wei.Stanford_Dogs.json'
273 | channel_weights = get_channel_weight(channel_weights_path)
274 |
275 | print(len(channel_weights))
276 | print(channel_weights[0].shape)
277 |
278 |
--------------------------------------------------------------------------------
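AverageMeter and accuracy are presumably combined in the training loop to produce the per-epoch numbers that record_epoch_data writes out; a small self-contained sketch (batch size and class count are arbitrary):

    import torch
    from utils.util import AverageMeter, accuracy

    top1 = AverageMeter()
    scores = torch.randn(8, 30)            # one batch of logits
    targets = torch.randint(0, 30, (8,))   # matching labels
    top1.update(accuracy(scores, targets, k=1), n=targets.size(0))
    print(top1.avg)                        # running top-1 accuracy in [0, 1]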