├── .gitignore
├── LICENSE
├── README.md
├── collect_env.py
├── convs
│   ├── __init__.py
│   ├── cifar_resnet.py
│   ├── linears.py
│   ├── resnet.py
│   └── vits.py
├── evaluator.py
├── exps
│   ├── slca
│   │   ├── slca_cars.json
│   │   ├── slca_cars_mocov3.json
│   │   ├── slca_cifar.json
│   │   ├── slca_cifar_mocov3.json
│   │   ├── slca_cub.json
│   │   ├── slca_cub_mocov3.json
│   │   ├── slca_imgnetr.json
│   │   └── slca_imgnetr_mocov3.json
│   └── slcapp
│       ├── slcapp_cars_lora.json
│       ├── slcapp_cars_lora_mocov3.json
│       ├── slcapp_cifar_lora.json
│       ├── slcapp_cifar_lora_mocov3.json
│       ├── slcapp_cub_lora.json
│       ├── slcapp_cub_lora_mocov3.json
│       ├── slcapp_imgnetr_lora.json
│       └── slcapp_imgnetr_lora_mocov3.json
├── main.py
├── models
│   ├── __init__.py
│   ├── base.py
│   ├── slca.py
│   └── slca_pp.py
├── slca_performance.jpg
├── split_car.py
├── split_cub.py
├── train_all.sh
├── trainer.py
└── utils
    ├── __init__.py
    ├── buffer.py
    ├── cutmix.py
    ├── data.py
    ├── data_manager.py
    ├── factory.py
    ├── inc_net.py
    ├── net_linear_wapper.py
    └── toolkit.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | logs/
3 | __pycache__/
4 | *.pyc
5 | .DS_Store
6 |
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Gengwei (David) Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SLCA++: Unleash the Power of Sequential Fine-tuning for Continual Learning with Pre-training
2 |
3 | Gengwei Zhang\*, Liyuan Wang\*, Guoliang Kang, Ling Chen, Yunchao Wei
4 |
16 | PyTorch code for the paper "[SLCA++: Unleash the Power of Sequential Fine-tuning for Continual Learning with Pre-training](https://arxiv.org/abs/2408.08295)", together with the code for our ICCV 2023 paper "[SLCA: Slow Learner with Classifier Alignment for Continual Learning on a Pre-trained Model](https://arxiv.org/abs/2303.05118)".
17 |
18 | ## What's new?
19 | [2024.08] We release SLCA++, a parameter-efficient version of SLCA with even better continual learning performance on fine-grained benchmarks!
20 |
21 | ## Introduction
22 | In our paper, we present an in-depth analysis of the progressive overfitting problem through the lens of sequential fine-tuning (Seq FT). Observing that overly fast representation learning and a biased classification layer together constitute this problem, we introduce the advanced Slow Learner with Classifier Alignment (SLCA++) framework to unleash the power of Seq FT, serving as a strong baseline approach for Continual Learning with Pre-training (CLPT). Our approach involves a Slow Learner (SL) to selectively reduce the learning rate of backbone parameters, and a Classifier Alignment (CA) to align the disjoint classification layers in a post-hoc fashion. We further enhance the efficacy of SL with a symmetric cross-entropy (SCE) loss, and employ a parameter-efficient strategy to implement Seq FT with SLCA++. Across a variety of continual learning scenarios, including class-incremental learning on general datasets such as CIFAR-100 and ImageNet-R, fine-grained datasets such as CUB-200 and Cars-196, and domain-incremental learning on DomainNet, our approach provides substantial improvements and outperforms state-of-the-art methods by a large margin.
23 |
24 |
25 |
26 | ## Requirements
27 | 1. torch==1.12.0
28 | 2. torchvision==0.13.0
29 | 3. timm==0.5.4
30 | 4. tqdm
31 | 5. numpy
32 | 6. scipy
33 | 7. quadprog
34 | 8. POT
35 |
36 | ## Pre-trained Models
37 | Please download the pre-trained ViT-Base models from [MoCo v3](https://drive.google.com/file/d/1bshDu4jEKztZZvwpTVXSAuCsDoXwCkfy/view?usp=share_link) and [ImageNet-21K](https://drive.google.com/file/d/1PcAOf0tJYs1FVDpj-7lrkSuwXTJXVmuk/view?usp=share_link), then put or link the pre-trained models under ```SLCA/pretrained```.
38 |
39 | ## Acknowledgement
40 | This repo is heavily based on [PyCIL](https://github.com/G-U-N/PyCIL); many thanks to its authors.
41 |
42 | ## Citation
43 |
44 | If you find our code or paper useful, please consider giving us a star or citing our work:
45 |
46 | ```
47 | @misc{zhang2024slcaunleashpowersequential,
48 | title={SLCA++: Unleash the Power of Sequential Fine-tuning for Continual Learning with Pre-training},
49 | author={Zhang, Gengwei and Wang, Liyuan and Kang, Guoliang and Chen, Ling and Wei, Yunchao},
50 | year={2024},
51 | eprint={2408.08295},
52 | archivePrefix={arXiv},
53 | url={https://arxiv.org/abs/2408.08295},
54 | }
55 | ```
56 |
57 | ```
58 | @inproceedings{zhang2023slca,
59 | title={SLCA: Slow Learner with Classifier Alignment for Continual Learning on a Pre-trained Model},
60 | author={Zhang, Gengwei and Wang, Liyuan and Kang, Guoliang and Chen, Ling and Wei, Yunchao},
61 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
62 | year={2023}
63 | }
64 | ```
65 |
--------------------------------------------------------------------------------
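A minimal sketch of the Slow Learner (SL) described in the README's introduction: the pre-trained backbone is updated with a selectively reduced learning rate while the classification head keeps a regular one. The `backbone`/`head` attribute names and the learning-rate values here are illustrative assumptions, not the repository's exact settings.

```
# Slow Learner sketch (illustrative; assumed attribute names and values).
import torch


def build_sl_optimizer(model, head_lr=0.01, backbone_lr_scale=0.01):
    # Backbone parameters train ~100x slower than the head in this sketch.
    param_groups = [
        {'params': model.backbone.parameters(), 'lr': head_lr * backbone_lr_scale},
        {'params': model.head.parameters(), 'lr': head_lr},
    ]
    return torch.optim.SGD(param_groups, lr=head_lr, momentum=0.9)
```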
/collect_env.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | # Unlike the rest of PyTorch, this file must be Python 2 compliant.
4 | # This script outputs relevant system environment info
5 | # Run it with `python collect_env.py`.
6 | import datetime
7 | import locale
8 | import re
9 | import subprocess
10 | import sys
11 | import os
12 | from collections import namedtuple
13 |
14 |
15 | try:
16 | import torch
17 | TORCH_AVAILABLE = True
18 | except (ImportError, NameError, AttributeError, OSError):
19 | TORCH_AVAILABLE = False
20 |
21 | # System Environment Information
22 | SystemEnv = namedtuple('SystemEnv', [
23 | 'torch_version',
24 | 'is_debug_build',
25 | 'cuda_compiled_version',
26 | 'gcc_version',
27 | 'clang_version',
28 | 'cmake_version',
29 | 'os',
30 | 'libc_version',
31 | 'python_version',
32 | 'python_platform',
33 | 'is_cuda_available',
34 | 'cuda_runtime_version',
35 | 'nvidia_driver_version',
36 | 'nvidia_gpu_models',
37 | 'cudnn_version',
38 | 'pip_version', # 'pip' or 'pip3'
39 | 'pip_packages',
40 | 'conda_packages',
41 | 'hip_compiled_version',
42 | 'hip_runtime_version',
43 | 'miopen_runtime_version',
44 | 'caching_allocator_config',
45 | 'is_xnnpack_available',
46 | ])
47 |
48 |
49 | def run(command):
50 | """Returns (return-code, stdout, stderr)"""
51 | p = subprocess.Popen(command, stdout=subprocess.PIPE,
52 | stderr=subprocess.PIPE, shell=True)
53 | raw_output, raw_err = p.communicate()
54 | rc = p.returncode
55 | if get_platform() == 'win32':
56 | enc = 'oem'
57 | else:
58 | enc = locale.getpreferredencoding()
59 | output = raw_output.decode(enc)
60 | err = raw_err.decode(enc)
61 | return rc, output.strip(), err.strip()
62 |
63 |
64 | def run_and_read_all(run_lambda, command):
65 | """Runs command using run_lambda; reads and returns entire output if rc is 0"""
66 | rc, out, _ = run_lambda(command)
67 | if rc != 0:
68 | return None
69 | return out
70 |
71 |
72 | def run_and_parse_first_match(run_lambda, command, regex):
73 | """Runs command using run_lambda, returns the first regex match if it exists"""
74 | rc, out, _ = run_lambda(command)
75 | if rc != 0:
76 | return None
77 | match = re.search(regex, out)
78 | if match is None:
79 | return None
80 | return match.group(1)
81 |
82 | def run_and_return_first_line(run_lambda, command):
83 | """Runs command using run_lambda and returns first line if output is not empty"""
84 | rc, out, _ = run_lambda(command)
85 | if rc != 0:
86 | return None
87 | return out.split('\n')[0]
88 |
89 |
90 | def get_conda_packages(run_lambda):
91 | conda = os.environ.get('CONDA_EXE', 'conda')
92 | out = run_and_read_all(run_lambda, "{} list".format(conda))
93 | if out is None:
94 | return out
95 |
96 | return "\n".join(
97 | line
98 | for line in out.splitlines()
99 | if not line.startswith("#")
100 | and any(
101 | name in line
102 | for name in {
103 | "torch",
104 | "numpy",
105 | "cudatoolkit",
106 | "soumith",
107 | "mkl",
108 | "magma",
109 | "mkl",
110 | }
111 | )
112 | )
113 |
114 | def get_gcc_version(run_lambda):
115 | return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)')
116 |
117 | def get_clang_version(run_lambda):
118 | return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)')
119 |
120 |
121 | def get_cmake_version(run_lambda):
122 | return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)')
123 |
124 |
125 | def get_nvidia_driver_version(run_lambda):
126 | if get_platform() == 'darwin':
127 | cmd = 'kextstat | grep -i cuda'
128 | return run_and_parse_first_match(run_lambda, cmd,
129 | r'com[.]nvidia[.]CUDA [(](.*?)[)]')
130 | smi = get_nvidia_smi()
131 | return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ')
132 |
133 |
134 | def get_gpu_info(run_lambda):
135 | if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None):
136 | if TORCH_AVAILABLE and torch.cuda.is_available():
137 | return torch.cuda.get_device_name(None)
138 | return None
139 | smi = get_nvidia_smi()
140 | uuid_regex = re.compile(r' \(UUID: .+?\)')
141 | rc, out, _ = run_lambda(smi + ' -L')
142 | if rc != 0:
143 | return None
144 | # Anonymize GPUs by removing their UUID
145 | return re.sub(uuid_regex, '', out)
146 |
147 |
148 | def get_running_cuda_version(run_lambda):
149 | return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)')
150 |
151 |
152 | def get_cudnn_version(run_lambda):
153 | """This will return a list of libcudnn.so; it's hard to tell which one is being used"""
154 | if get_platform() == 'win32':
155 | system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
156 | cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%")
157 | where_cmd = os.path.join(system_root, 'System32', 'where')
158 | cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
159 | elif get_platform() == 'darwin':
160 | # CUDA libraries and drivers can be found in /usr/local/cuda/. See
161 | # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install
162 | # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac
163 | # Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
164 | cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*'
165 | else:
166 | cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
167 | rc, out, _ = run_lambda(cudnn_cmd)
168 | # find will return 1 if there are permission errors or if not found
169 | if len(out) == 0 or (rc != 1 and rc != 0):
170 | l = os.environ.get('CUDNN_LIBRARY')
171 | if l is not None and os.path.isfile(l):
172 | return os.path.realpath(l)
173 | return None
174 | files_set = set()
175 | for fn in out.split('\n'):
176 | fn = os.path.realpath(fn) # eliminate symbolic links
177 | if os.path.isfile(fn):
178 | files_set.add(fn)
179 | if not files_set:
180 | return None
181 | # Alphabetize the result because the order is non-deterministic otherwise
182 | files = list(sorted(files_set))
183 | if len(files) == 1:
184 | return files[0]
185 | result = '\n'.join(files)
186 | return 'Probably one of the following:\n{}'.format(result)
187 |
188 |
189 | def get_nvidia_smi():
190 | # Note: nvidia-smi is currently available only on Windows and Linux
191 | smi = 'nvidia-smi'
192 | if get_platform() == 'win32':
193 | system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
194 | program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files')
195 | legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi)
196 | new_path = os.path.join(system_root, 'System32', smi)
197 | smis = [new_path, legacy_path]
198 | for candidate_smi in smis:
199 | if os.path.exists(candidate_smi):
200 | smi = '"{}"'.format(candidate_smi)
201 | break
202 | return smi
203 |
204 |
205 | def get_platform():
206 | if sys.platform.startswith('linux'):
207 | return 'linux'
208 | elif sys.platform.startswith('win32'):
209 | return 'win32'
210 | elif sys.platform.startswith('cygwin'):
211 | return 'cygwin'
212 | elif sys.platform.startswith('darwin'):
213 | return 'darwin'
214 | else:
215 | return sys.platform
216 |
217 |
218 | def get_mac_version(run_lambda):
219 | return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)')
220 |
221 |
222 | def get_windows_version(run_lambda):
223 | system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
224 | wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic')
225 | findstr_cmd = os.path.join(system_root, 'System32', 'findstr')
226 | return run_and_read_all(run_lambda, '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd))
227 |
228 |
229 | def get_lsb_version(run_lambda):
230 | return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)')
231 |
232 |
233 | def check_release_file(run_lambda):
234 | return run_and_parse_first_match(run_lambda, 'cat /etc/*-release',
235 | r'PRETTY_NAME="(.*)"')
236 |
237 |
238 | def get_os(run_lambda):
239 | from platform import machine
240 | platform = get_platform()
241 |
242 | if platform == 'win32' or platform == 'cygwin':
243 | return get_windows_version(run_lambda)
244 |
245 | if platform == 'darwin':
246 | version = get_mac_version(run_lambda)
247 | if version is None:
248 | return None
249 | return 'macOS {} ({})'.format(version, machine())
250 |
251 | if platform == 'linux':
252 | # Ubuntu/Debian based
253 | desc = get_lsb_version(run_lambda)
254 | if desc is not None:
255 | return '{} ({})'.format(desc, machine())
256 |
257 | # Try reading /etc/*-release
258 | desc = check_release_file(run_lambda)
259 | if desc is not None:
260 | return '{} ({})'.format(desc, machine())
261 |
262 | return '{} ({})'.format(platform, machine())
263 |
264 | # Unknown platform
265 | return platform
266 |
267 |
268 | def get_python_platform():
269 | import platform
270 | return platform.platform()
271 |
272 |
273 | def get_libc_version():
274 | import platform
275 | if get_platform() != 'linux':
276 | return 'N/A'
277 | return '-'.join(platform.libc_ver())
278 |
279 |
280 | def get_pip_packages(run_lambda):
281 | """Returns `pip list` output. Note: will also find conda-installed pytorch
282 | and numpy packages."""
283 | # People generally have `pip` as `pip` or `pip3`
284 |     # But here it is invoked as `python -mpip`
285 | def run_with_pip(pip):
286 |         out = run_and_read_all(run_lambda, "{} list --format=freeze".format(pip))
287 |         if out is None:
288 |             return ""
289 |         return "\n".join(
288 | line
289 | for line in out.splitlines()
290 | if any(
291 | name in line
292 | for name in {
293 | "torch",
294 | "numpy",
295 | "mypy",
296 | }
297 | )
298 | )
299 |
300 | pip_version = 'pip3' if sys.version[0] == '3' else 'pip'
301 | out = run_with_pip(sys.executable + ' -mpip')
302 |
303 | return pip_version, out
304 |
305 |
306 | def get_cachingallocator_config():
307 | ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
308 | return ca_config
309 |
310 | def is_xnnpack_available():
311 | if TORCH_AVAILABLE:
312 | import torch.backends.xnnpack
313 | return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
314 | else:
315 | return "N/A"
316 |
317 | def get_env_info():
318 | run_lambda = run
319 | pip_version, pip_list_output = get_pip_packages(run_lambda)
320 |
321 | if TORCH_AVAILABLE:
322 | version_str = torch.__version__
323 | debug_mode_str = str(torch.version.debug)
324 | cuda_available_str = str(torch.cuda.is_available())
325 | cuda_version_str = torch.version.cuda
326 | if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version
327 | hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
328 | else: # HIP version
329 | cfg = torch._C._show_config().split('\n')
330 | hip_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'HIP Runtime' in s][0]
331 | miopen_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'MIOpen' in s][0]
332 | cuda_version_str = 'N/A'
333 | hip_compiled_version = torch.version.hip
334 | else:
335 | version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A'
336 | hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
337 |
338 | sys_version = sys.version.replace("\n", " ")
339 |
340 | return SystemEnv(
341 | torch_version=version_str,
342 | is_debug_build=debug_mode_str,
343 | python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1),
344 | python_platform=get_python_platform(),
345 | is_cuda_available=cuda_available_str,
346 | cuda_compiled_version=cuda_version_str,
347 | cuda_runtime_version=get_running_cuda_version(run_lambda),
348 | nvidia_gpu_models=get_gpu_info(run_lambda),
349 | nvidia_driver_version=get_nvidia_driver_version(run_lambda),
350 | cudnn_version=get_cudnn_version(run_lambda),
351 | hip_compiled_version=hip_compiled_version,
352 | hip_runtime_version=hip_runtime_version,
353 | miopen_runtime_version=miopen_runtime_version,
354 | pip_version=pip_version,
355 | pip_packages=pip_list_output,
356 | conda_packages=get_conda_packages(run_lambda),
357 | os=get_os(run_lambda),
358 | libc_version=get_libc_version(),
359 | gcc_version=get_gcc_version(run_lambda),
360 | clang_version=get_clang_version(run_lambda),
361 | cmake_version=get_cmake_version(run_lambda),
362 | caching_allocator_config=get_cachingallocator_config(),
363 | is_xnnpack_available=is_xnnpack_available(),
364 | )
365 |
366 | env_info_fmt = """
367 | PyTorch version: {torch_version}
368 | Is debug build: {is_debug_build}
369 | CUDA used to build PyTorch: {cuda_compiled_version}
370 | ROCM used to build PyTorch: {hip_compiled_version}
371 |
372 | OS: {os}
373 | GCC version: {gcc_version}
374 | Clang version: {clang_version}
375 | CMake version: {cmake_version}
376 | Libc version: {libc_version}
377 |
378 | Python version: {python_version}
379 | Python platform: {python_platform}
380 | Is CUDA available: {is_cuda_available}
381 | CUDA runtime version: {cuda_runtime_version}
382 | GPU models and configuration: {nvidia_gpu_models}
383 | Nvidia driver version: {nvidia_driver_version}
384 | cuDNN version: {cudnn_version}
385 | HIP runtime version: {hip_runtime_version}
386 | MIOpen runtime version: {miopen_runtime_version}
387 | Is XNNPACK available: {is_xnnpack_available}
388 |
389 | Versions of relevant libraries:
390 | {pip_packages}
391 | {conda_packages}
392 | """.strip()
393 |
394 |
395 | def pretty_str(envinfo):
396 | def replace_nones(dct, replacement='Could not collect'):
397 | for key in dct.keys():
398 | if dct[key] is not None:
399 | continue
400 | dct[key] = replacement
401 | return dct
402 |
403 | def replace_bools(dct, true='Yes', false='No'):
404 | for key in dct.keys():
405 | if dct[key] is True:
406 | dct[key] = true
407 | elif dct[key] is False:
408 | dct[key] = false
409 | return dct
410 |
411 | def prepend(text, tag='[prepend]'):
412 | lines = text.split('\n')
413 | updated_lines = [tag + line for line in lines]
414 | return '\n'.join(updated_lines)
415 |
416 | def replace_if_empty(text, replacement='No relevant packages'):
417 | if text is not None and len(text) == 0:
418 | return replacement
419 | return text
420 |
421 | def maybe_start_on_next_line(string):
422 | # If `string` is multiline, prepend a \n to it.
423 | if string is not None and len(string.split('\n')) > 1:
424 | return '\n{}\n'.format(string)
425 | return string
426 |
427 | mutable_dict = envinfo._asdict()
428 |
429 | # If nvidia_gpu_models is multiline, start on the next line
430 | mutable_dict['nvidia_gpu_models'] = \
431 | maybe_start_on_next_line(envinfo.nvidia_gpu_models)
432 |
433 | # If the machine doesn't have CUDA, report some fields as 'No CUDA'
434 | dynamic_cuda_fields = [
435 | 'cuda_runtime_version',
436 | 'nvidia_gpu_models',
437 | 'nvidia_driver_version',
438 | ]
439 | all_cuda_fields = dynamic_cuda_fields + ['cudnn_version']
440 | all_dynamic_cuda_fields_missing = all(
441 | mutable_dict[field] is None for field in dynamic_cuda_fields)
442 | if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing:
443 | for field in all_cuda_fields:
444 | mutable_dict[field] = 'No CUDA'
445 | if envinfo.cuda_compiled_version is None:
446 | mutable_dict['cuda_compiled_version'] = 'None'
447 |
448 | # Replace True with Yes, False with No
449 | mutable_dict = replace_bools(mutable_dict)
450 |
451 | # Replace all None objects with 'Could not collect'
452 | mutable_dict = replace_nones(mutable_dict)
453 |
454 | # If either of these are '', replace with 'No relevant packages'
455 | mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages'])
456 | mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages'])
457 |
458 | # Tag conda and pip packages with a prefix
459 |     # If they were previously None, they'll show up as e.g. '[conda] Could not collect'
460 | if mutable_dict['pip_packages']:
461 | mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'],
462 | '[{}] '.format(envinfo.pip_version))
463 | if mutable_dict['conda_packages']:
464 | mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'],
465 | '[conda] ')
466 | return env_info_fmt.format(**mutable_dict)
467 |
468 |
469 | def get_pretty_env_info():
470 | return pretty_str(get_env_info())
471 |
472 |
473 | def main():
474 | print("Collecting environment information...")
475 | output = get_pretty_env_info()
476 | print(output)
477 |
478 | if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'):
479 | minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
480 | if sys.platform == "linux" and os.path.exists(minidump_dir):
481 | dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)]
482 | latest = max(dumps, key=os.path.getctime)
483 | ctime = os.path.getctime(latest)
484 | creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S')
485 | msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \
486 | "if this is related to your bug please include it when you file a report ***"
487 | print(msg, file=sys.stderr)
488 |
489 |
490 |
491 | if __name__ == '__main__':
492 | main()
493 |
--------------------------------------------------------------------------------
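The script above can also be imported and used programmatically; a minimal usage sketch:

```
# Prints the same report as running `python collect_env.py`.
from collect_env import get_pretty_env_info

print(get_pretty_env_info())
```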
/convs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GengDavid/SLCA/169d4e6b91e3ca30caa717200c1944981c4426d1/convs/__init__.py
--------------------------------------------------------------------------------
/convs/cifar_resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference:
3 | https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
4 | '''
5 | import math
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class DownsampleA(nn.Module):
13 | def __init__(self, nIn, nOut, stride):
14 | super(DownsampleA, self).__init__()
15 | assert stride == 2
16 | self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)
17 |
18 | def forward(self, x):
19 | x = self.avg(x)
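        # Option-A shortcut: concatenate an all-zero copy of x along the
        # channel dim, doubling the channels without adding parameters.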
20 | return torch.cat((x, x.mul(0)), 1)
21 |
22 |
23 | class DownsampleB(nn.Module):
24 | def __init__(self, nIn, nOut, stride):
25 | super(DownsampleB, self).__init__()
26 | self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
27 | self.bn = nn.BatchNorm2d(nOut)
28 |
29 | def forward(self, x):
30 | x = self.conv(x)
31 | x = self.bn(x)
32 | return x
33 |
34 |
35 | class DownsampleC(nn.Module):
36 | def __init__(self, nIn, nOut, stride):
37 | super(DownsampleC, self).__init__()
38 | assert stride != 1 or nIn != nOut
39 | self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
40 |
41 | def forward(self, x):
42 | x = self.conv(x)
43 | return x
44 |
45 |
46 | class DownsampleD(nn.Module):
47 | def __init__(self, nIn, nOut, stride):
48 | super(DownsampleD, self).__init__()
49 | assert stride == 2
50 | self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
51 | self.bn = nn.BatchNorm2d(nOut)
52 |
53 | def forward(self, x):
54 | x = self.conv(x)
55 | x = self.bn(x)
56 | return x
57 |
58 |
59 | class ResNetBasicblock(nn.Module):
60 | expansion = 1
61 |
62 | def __init__(self, inplanes, planes, stride=1, downsample=None):
63 | super(ResNetBasicblock, self).__init__()
64 |
65 | self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
66 | self.bn_a = nn.BatchNorm2d(planes)
67 |
68 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
69 | self.bn_b = nn.BatchNorm2d(planes)
70 |
71 | self.downsample = downsample
72 |
73 | def forward(self, x):
74 | residual = x
75 |
76 | basicblock = self.conv_a(x)
77 | basicblock = self.bn_a(basicblock)
78 | basicblock = F.relu(basicblock, inplace=True)
79 |
80 | basicblock = self.conv_b(basicblock)
81 | basicblock = self.bn_b(basicblock)
82 |
83 | if self.downsample is not None:
84 | residual = self.downsample(x)
85 |
86 | return F.relu(residual + basicblock, inplace=True)
87 |
88 |
89 | class CifarResNet(nn.Module):
90 | """
91 | ResNet optimized for the Cifar Dataset, as specified in
92 | https://arxiv.org/abs/1512.03385.pdf
93 | """
94 |
95 | def __init__(self, block, depth, channels=3):
96 | super(CifarResNet, self).__init__()
97 |
98 | # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
99 | assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
100 | layer_blocks = (depth - 2) // 6
101 |
102 | self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
103 | self.bn_1 = nn.BatchNorm2d(16)
104 |
105 | self.inplanes = 16
106 | self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
107 | self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
108 | self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
109 | self.avgpool = nn.AvgPool2d(8)
110 | self.out_dim = 64 * block.expansion
111 | self.fc = nn.Linear(64*block.expansion, 10)
112 |
113 | for m in self.modules():
114 | if isinstance(m, nn.Conv2d):
115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
116 | m.weight.data.normal_(0, math.sqrt(2. / n))
117 | # m.bias.data.zero_()
118 | elif isinstance(m, nn.BatchNorm2d):
119 | m.weight.data.fill_(1)
120 | m.bias.data.zero_()
121 | elif isinstance(m, nn.Linear):
122 | nn.init.kaiming_normal_(m.weight)
123 | m.bias.data.zero_()
124 |
125 | def _make_layer(self, block, planes, blocks, stride=1):
126 | downsample = None
127 | if stride != 1 or self.inplanes != planes * block.expansion:
128 | downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)
129 |
130 | layers = []
131 | layers.append(block(self.inplanes, planes, stride, downsample))
132 | self.inplanes = planes * block.expansion
133 | for i in range(1, blocks):
134 | layers.append(block(self.inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | x = self.conv_1_3x3(x) # [bs, 16, 32, 32]
140 | x = F.relu(self.bn_1(x), inplace=True)
141 |
142 | x_1 = self.stage_1(x) # [bs, 16, 32, 32]
143 | x_2 = self.stage_2(x_1) # [bs, 32, 16, 16]
144 | x_3 = self.stage_3(x_2) # [bs, 64, 8, 8]
145 |
146 | pooled = self.avgpool(x_3) # [bs, 64, 1, 1]
147 | features = pooled.view(pooled.size(0), -1) # [bs, 64]
148 |
149 | return {
150 | 'fmaps': [x_1, x_2, x_3],
151 | 'features': features
152 | }
153 |
154 | @property
155 | def last_conv(self):
156 | return self.stage_3[-1].conv_b
157 |
158 |
159 | def resnet20mnist():
160 | """Constructs a ResNet-20 model for MNIST."""
161 | model = CifarResNet(ResNetBasicblock, 20, 1)
162 | return model
163 |
164 |
165 | def resnet32mnist():
166 | """Constructs a ResNet-32 model for MNIST."""
167 | model = CifarResNet(ResNetBasicblock, 32, 1)
168 | return model
169 |
170 |
171 | def resnet20():
172 | """Constructs a ResNet-20 model for CIFAR-10."""
173 | model = CifarResNet(ResNetBasicblock, 20)
174 | return model
175 |
176 |
177 | def resnet32():
178 | """Constructs a ResNet-32 model for CIFAR-10."""
179 | model = CifarResNet(ResNetBasicblock, 32)
180 | return model
181 |
182 |
183 | def resnet44():
184 | """Constructs a ResNet-44 model for CIFAR-10."""
185 | model = CifarResNet(ResNetBasicblock, 44)
186 | return model
187 |
188 |
189 | def resnet56():
190 | """Constructs a ResNet-56 model for CIFAR-10."""
191 | model = CifarResNet(ResNetBasicblock, 56)
192 | return model
193 |
194 |
195 | def resnet110():
196 | """Constructs a ResNet-110 model for CIFAR-10."""
197 | model = CifarResNet(ResNetBasicblock, 110)
198 | return model
199 |
--------------------------------------------------------------------------------
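A minimal sanity-check sketch for `convs/cifar_resnet.py` above (illustrative, not part of the repo): the forward pass returns a dict with per-stage feature maps plus pooled features, which is the contract the rest of the repo consumes.

```
import torch
from convs.cifar_resnet import resnet32

model = resnet32()
out = model(torch.randn(2, 3, 32, 32))
print([f.shape for f in out['fmaps']])  # stages: 16x32x32, 32x16x16, 64x8x8
print(out['features'].shape)            # torch.Size([2, 64])
```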
/convs/linears.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference:
3 | https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py
4 | '''
5 | import math
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 | from timm.models.layers.weight_init import trunc_normal_
10 | from timm.models.layers import Mlp
11 | from copy import deepcopy
12 |
13 | class SimpleScaler(nn.Module):
14 | def __init__(self):
15 | super().__init__()
16 | self.weight = nn.Parameter(torch.ones(1), requires_grad=False)
17 |
18 | def forward(self, x):
19 | weight = self.weight
20 | x = x*weight.unsqueeze(1)
21 |
22 | return x
23 |
24 |
25 | class SimpleContinualLinear(nn.Module):
26 | def __init__(self, embed_dim, nb_classes, feat_expand=False, with_norm=False, scale_mu=-1):
27 | super().__init__()
28 |
29 | self.embed_dim = embed_dim
30 | self.feat_expand = feat_expand
31 | self.with_norm = with_norm
32 | self.scale_mu = scale_mu
33 |
34 | if self.scale_mu > 0:
35 | scales = []
36 | scales.append(SimpleScaler())
37 | self.scales = nn.ModuleList(scales)
38 | else:
39 | self.scales = None
40 |
41 | heads = []
42 | single_head = []
43 | if with_norm:
44 | single_head.append(nn.LayerNorm(embed_dim))
45 |
46 | single_head.append(nn.Linear(embed_dim, nb_classes, bias=True))
47 | head = nn.Sequential(*single_head)
48 |
49 | heads.append(head)
50 | self.heads = nn.ModuleList(heads)
51 | for m in self.modules():
52 | if isinstance(m, nn.Linear):
53 | trunc_normal_(m.weight, std=.02)
54 | if m.bias is not None:
55 | nn.init.constant_(m.bias, 0)
56 |
57 |
58 | def update_scale(self):
59 | if self.scales is None:
60 | return
61 | num_old_tasks = len(self.heads)-1
62 | for t_id in range(num_old_tasks): # update scale for old tasks
63 | new_scale = 1+self.scale_mu*(num_old_tasks-t_id)
64 | self.scales[t_id].weight.data = torch.tensor([new_scale]).to(self.scales[t_id].weight)
65 |
66 | def backup(self):
67 | self.old_state_dict = deepcopy(self.state_dict())
68 |
69 | def recall(self):
70 | self.load_state_dict(self.old_state_dict)
71 |
72 | def update(self, nb_classes, freeze_old=True):
73 | single_head = []
74 | if self.with_norm:
75 | single_head.append(nn.LayerNorm(self.embed_dim))
76 |
77 | if self.scale_mu>0:
78 | self.scales.append(SimpleScaler())
79 |
80 | _fc = nn.Linear(self.embed_dim, nb_classes, bias=True)
81 | trunc_normal_(_fc.weight, std=.02)
82 | nn.init.constant_(_fc.bias, 0)
83 | single_head.append(_fc)
84 | new_head = nn.Sequential(*single_head)
85 |
86 |
87 | if freeze_old:
88 | for p in self.heads.parameters():
89 | p.requires_grad=False
90 |
91 | self.heads.append(new_head)
92 |
93 | def forward(self, x):
94 | out = []
95 | for ti in range(len(self.heads)):
96 | fc_inp = x[ti] if self.feat_expand else x
97 | if self.scale_mu>0:
98 | fc_inp = self.scales[ti](fc_inp)
99 | out.append(self.heads[ti](fc_inp))
100 | out = {'logits': torch.cat(out, dim=1)}
101 | return out
102 |
--------------------------------------------------------------------------------
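A minimal sketch of how `SimpleContinualLinear` above grows across tasks (illustrative, not part of the repo): `update` freezes the existing heads and appends a fresh one, and `forward` concatenates the per-task logits.

```
import torch
from convs.linears import SimpleContinualLinear

head = SimpleContinualLinear(embed_dim=768, nb_classes=10)  # task 1: 10 classes
head.update(10)                       # task 2: freeze old heads, add 10 classes
out = head(torch.randn(4, 768))
print(out['logits'].shape)            # torch.Size([4, 20])
```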
/convs/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference:
3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | from torch.hub import load_state_dict_from_url  # torchvision.models.utils lost this helper in torchvision 0.13; needed by _resnet when pretrained=True
8 |
9 |
10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
11 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
12 | 'wide_resnet50_2', 'wide_resnet101_2']
13 |
14 |
15 | model_urls = {
16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
25 | }
26 |
27 |
28 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
29 | """3x3 convolution with padding"""
30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
31 | padding=dilation, groups=groups, bias=False, dilation=dilation)
32 |
33 |
34 | def conv1x1(in_planes, out_planes, stride=1):
35 | """1x1 convolution"""
36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
37 |
38 |
39 | class BasicBlock(nn.Module):
40 | expansion = 1
41 | __constants__ = ['downsample']
42 |
43 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
44 | base_width=64, dilation=1, norm_layer=None, no_last_relu=False):
45 | super(BasicBlock, self).__init__()
46 | if norm_layer is None:
47 | norm_layer = nn.BatchNorm2d
48 | if groups != 1 or base_width != 64:
49 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
50 | if dilation > 1:
51 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
52 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
53 | self.conv1 = conv3x3(inplanes, planes, stride)
54 | self.bn1 = norm_layer(planes)
55 | self.relu = nn.ReLU(inplace=True)
56 | self.conv2 = conv3x3(planes, planes)
57 | self.bn2 = norm_layer(planes)
58 | self.downsample = downsample
59 | self.stride = stride
60 | self.no_last_relu = no_last_relu
61 |
62 | def forward(self, x):
63 | identity = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 |
72 | if self.downsample is not None:
73 | identity = self.downsample(x)
74 |
75 | out += identity
76 | if not self.no_last_relu:
77 | out = self.relu(out)
78 |
79 | return out
80 |
81 |
82 | class Bottleneck(nn.Module):
83 | expansion = 4
84 | __constants__ = ['downsample']
85 |
86 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
87 | base_width=64, dilation=1, norm_layer=None, no_last_relu=False):
88 | super(Bottleneck, self).__init__()
89 | if norm_layer is None:
90 | norm_layer = nn.BatchNorm2d
91 | width = int(planes * (base_width / 64.)) * groups
92 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
93 | self.conv1 = conv1x1(inplanes, width)
94 | self.bn1 = norm_layer(width)
95 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
96 | self.bn2 = norm_layer(width)
97 | self.conv3 = conv1x1(width, planes * self.expansion)
98 | self.bn3 = norm_layer(planes * self.expansion)
99 | self.relu = nn.ReLU(inplace=True)
100 | self.downsample = downsample
101 | self.stride = stride
102 | self.no_last_relu = no_last_relu
103 |
104 | def forward(self, x):
105 | identity = x
106 |
107 | out = self.conv1(x)
108 | out = self.bn1(out)
109 | out = self.relu(out)
110 |
111 | out = self.conv2(out)
112 | out = self.bn2(out)
113 | out = self.relu(out)
114 |
115 | out = self.conv3(out)
116 | out = self.bn3(out)
117 |
118 | if self.downsample is not None:
119 | identity = self.downsample(x)
120 |
121 | out += identity
122 | if not self.no_last_relu:
123 | out = self.relu(out)
124 |
125 | return out
126 |
127 |
128 |
129 |
130 | # Modified ResNet implementation.
131 | class ResNet(nn.Module):
132 |
133 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
134 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
135 | norm_layer=None, cifar=False, no_last_relu=False):
136 | super(ResNet, self).__init__()
137 | if norm_layer is None:
138 | norm_layer = nn.BatchNorm2d
139 | self._norm_layer = norm_layer
140 | self.cifar = cifar
141 |
142 | self.inplanes = 64
143 | self.dilation = 1
144 | if replace_stride_with_dilation is None:
145 | # each element in the tuple indicates if we should replace
146 | # the 2x2 stride with a dilated convolution instead
147 | replace_stride_with_dilation = [False, False, False]
148 | if len(replace_stride_with_dilation) != 3:
149 | raise ValueError("replace_stride_with_dilation should be None "
150 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
151 | self.groups = groups
152 | self.base_width = width_per_group
153 | if self.cifar:
154 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
155 | else:
156 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
157 | self.bn1 = norm_layer(self.inplanes)
158 | self.relu = nn.ReLU(inplace=True)
159 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Removed in _forward_impl for cifar
160 | self.layer1 = self._make_layer(block, 64, layers[0])
161 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
162 | dilate=replace_stride_with_dilation[0])
163 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
164 | dilate=replace_stride_with_dilation[1])
165 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
166 | dilate=replace_stride_with_dilation[2], no_last_relu=no_last_relu)
167 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
168 | self.out_dim = 512 * block.expansion
169 | # self.fc = nn.Linear(512 * block.expansion, num_classes) # Removed in _forward_impl
170 |
171 | for m in self.modules():
172 | if isinstance(m, nn.Conv2d):
173 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
174 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
175 | nn.init.constant_(m.weight, 1)
176 | nn.init.constant_(m.bias, 0)
177 |
178 | # Zero-initialize the last BN in each residual branch,
179 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
180 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
181 | if zero_init_residual:
182 | for m in self.modules():
183 | if isinstance(m, Bottleneck):
184 | nn.init.constant_(m.bn3.weight, 0)
185 | elif isinstance(m, BasicBlock):
186 | nn.init.constant_(m.bn2.weight, 0)
187 |
188 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, no_last_relu=False):
189 | norm_layer = self._norm_layer
190 | downsample = None
191 | previous_dilation = self.dilation
192 | if dilate:
193 | self.dilation *= stride
194 | stride = 1
195 | if stride != 1 or self.inplanes != planes * block.expansion:
196 | downsample = nn.Sequential(
197 | conv1x1(self.inplanes, planes * block.expansion, stride),
198 | norm_layer(planes * block.expansion),
199 | )
200 |
201 | layers = []
202 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
203 | self.base_width, previous_dilation, norm_layer))
204 | self.inplanes = planes * block.expansion
205 | for bid in range(1, blocks):
206 | layers.append(block(self.inplanes, planes, groups=self.groups,
207 | base_width=self.base_width, dilation=self.dilation,
208 | norm_layer=norm_layer, no_last_relu=no_last_relu if bid==blocks-1 else False))
209 |
210 | return nn.Sequential(*layers)
211 |
212 | def _forward_impl(self, x):
213 | # See note [TorchScript super()]
214 | x = self.conv1(x) # [bs, 64, 32, 32]
215 | x = self.bn1(x)
216 | x = self.relu(x)
217 | if not self.cifar:
218 | x = self.maxpool(x)
219 |
220 | x_1 = self.layer1(x) # [bs, 128, 32, 32]
221 | x_2 = self.layer2(x_1) # [bs, 256, 16, 16]
222 | x_3 = self.layer3(x_2) # [bs, 512, 8, 8]
223 | x_4 = self.layer4(x_3) # [bs, 512, 4, 4]
224 |
225 | pooled = self.avgpool(x_4) # [bs, 512, 1, 1]
226 | features = torch.flatten(pooled, 1) # [bs, 512]
227 | # x = self.fc(x)
228 |
229 | return {
230 | 'fmaps': [x_1, x_2, x_3, x_4],
231 | 'features': features
232 | }
233 |
234 | def forward(self, x):
235 | return self._forward_impl(x)
236 |
237 | @property
238 | def last_conv(self):
239 | if hasattr(self.layer4[-1], 'conv3'):
240 | return self.layer4[-1].conv3
241 | else:
242 | return self.layer4[-1].conv2
243 |
244 |
245 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
246 | model = ResNet(block, layers, **kwargs)
247 | if pretrained:
248 | state_dict = load_state_dict_from_url(model_urls[arch],
249 | progress=progress)
250 |         model.load_state_dict(state_dict, strict=False)  # strict=False: this ResNet has no fc head, so the checkpoint's fc.* keys are ignored
251 | return model
252 |
253 |
254 | def resnet18(pretrained=False, progress=True, **kwargs):
255 | r"""ResNet-18 model from
256 | `"Deep Residual Learning for Image Recognition" `_
257 | Args:
258 | pretrained (bool): If True, returns a model pre-trained on ImageNet
259 | progress (bool): If True, displays a progress bar of the download to stderr
260 | """
261 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
262 | **kwargs)
263 |
264 |
265 | def resnet34(pretrained=False, progress=True, **kwargs):
266 | r"""ResNet-34 model from
267 | `"Deep Residual Learning for Image Recognition" `_
268 | Args:
269 | pretrained (bool): If True, returns a model pre-trained on ImageNet
270 | progress (bool): If True, displays a progress bar of the download to stderr
271 | """
272 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
273 | **kwargs)
274 |
275 |
276 | def resnet50(pretrained=False, progress=True, **kwargs):
277 | r"""ResNet-50 model from
278 | `"Deep Residual Learning for Image Recognition" `_
279 | Args:
280 | pretrained (bool): If True, returns a model pre-trained on ImageNet
281 | progress (bool): If True, displays a progress bar of the download to stderr
282 | """
283 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
284 | **kwargs)
285 |
286 |
287 | def resnet101(pretrained=False, progress=True, **kwargs):
288 | r"""ResNet-101 model from
289 | `"Deep Residual Learning for Image Recognition" `_
290 | Args:
291 | pretrained (bool): If True, returns a model pre-trained on ImageNet
292 | progress (bool): If True, displays a progress bar of the download to stderr
293 | """
294 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
295 | **kwargs)
296 |
297 |
298 | def resnet152(pretrained=False, progress=True, **kwargs):
299 | r"""ResNet-152 model from
300 | `"Deep Residual Learning for Image Recognition" `_
301 | Args:
302 | pretrained (bool): If True, returns a model pre-trained on ImageNet
303 | progress (bool): If True, displays a progress bar of the download to stderr
304 | """
305 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
306 | **kwargs)
307 |
308 |
309 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
310 | r"""ResNeXt-50 32x4d model from
311 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
312 | Args:
313 | pretrained (bool): If True, returns a model pre-trained on ImageNet
314 | progress (bool): If True, displays a progress bar of the download to stderr
315 | """
316 | kwargs['groups'] = 32
317 | kwargs['width_per_group'] = 4
318 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
319 | pretrained, progress, **kwargs)
320 |
321 |
322 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
323 | r"""ResNeXt-101 32x8d model from
324 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
325 | Args:
326 | pretrained (bool): If True, returns a model pre-trained on ImageNet
327 | progress (bool): If True, displays a progress bar of the download to stderr
328 | """
329 | kwargs['groups'] = 32
330 | kwargs['width_per_group'] = 8
331 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
332 | pretrained, progress, **kwargs)
333 |
334 |
335 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
336 | r"""Wide ResNet-50-2 model from
337 | `"Wide Residual Networks" `_
338 | The model is the same as ResNet except for the bottleneck number of channels
339 | which is twice larger in every block. The number of channels in outer 1x1
340 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
341 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
342 | Args:
343 | pretrained (bool): If True, returns a model pre-trained on ImageNet
344 | progress (bool): If True, displays a progress bar of the download to stderr
345 | """
346 | kwargs['width_per_group'] = 64 * 2
347 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
348 | pretrained, progress, **kwargs)
349 |
350 |
351 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
352 | r"""Wide ResNet-101-2 model from
353 | `"Wide Residual Networks" `_
354 | The model is the same as ResNet except for the bottleneck number of channels
355 | which is twice larger in every block. The number of channels in outer 1x1
356 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
357 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
358 | Args:
359 | pretrained (bool): If True, returns a model pre-trained on ImageNet
360 | progress (bool): If True, displays a progress bar of the download to stderr
361 | """
362 | kwargs['width_per_group'] = 64 * 2
363 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
364 | pretrained, progress, **kwargs)
365 |
--------------------------------------------------------------------------------
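With the `torch.hub` import and the non-strict state-dict loading above, the pretrained branch of `_resnet` can be exercised as in this minimal sketch; the checkpoint's `fc.*` weights are simply dropped, since this ResNet exposes features rather than logits.

```
from convs.resnet import resnet18

model = resnet18(pretrained=True)  # fetches the torchvision ImageNet weights
print(model.out_dim)               # 512
```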
/convs/vits.py:
--------------------------------------------------------------------------------
1 | """ Vision Transformer (ViT) in PyTorch
2 |
3 | A PyTorch implement of Vision Transformers as described in:
4 |
5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
6 | - https://arxiv.org/abs/2010.11929
7 |
8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
9 | - https://arxiv.org/abs/2106.10270
10 |
11 | The official jax code is released and available at https://github.com/google-research/vision_transformer
12 |
13 | DeiT model defs and weights from https://github.com/facebookresearch/deit,
14 | paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
15 |
16 | Acknowledgments:
17 | * The paper authors for releasing code and weights, thanks!
18 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
19 | for some einops/einsum fun
20 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
21 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
22 |
23 | Hacked together by / Copyright 2020, Ross Wightman
24 | """
25 | import math
26 | import logging
27 | from functools import partial
28 | from collections import OrderedDict
29 | from copy import deepcopy
30 |
31 | import torch
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 |
35 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
36 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv
37 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
38 | from timm.models.registry import register_model
39 |
40 | _logger = logging.getLogger(__name__)
41 |
42 |
43 | def _cfg(url='', **kwargs):
44 | return {
45 | 'url': url,
46 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
47 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
48 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
49 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
50 | **kwargs
51 | }
52 |
53 |
54 | default_cfgs = {
55 | # patch models (weights from official Google JAX impl)
56 | 'vit_tiny_patch16_224': _cfg(
57 | url='https://storage.googleapis.com/vit_models/augreg/'
58 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
59 | 'vit_tiny_patch16_384': _cfg(
60 | url='https://storage.googleapis.com/vit_models/augreg/'
61 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
62 | input_size=(3, 384, 384), crop_pct=1.0),
63 | 'vit_small_patch32_224': _cfg(
64 | url='https://storage.googleapis.com/vit_models/augreg/'
65 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
66 | 'vit_small_patch32_384': _cfg(
67 | url='https://storage.googleapis.com/vit_models/augreg/'
68 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
69 | input_size=(3, 384, 384), crop_pct=1.0),
70 | 'vit_small_patch16_224': _cfg(
71 | url='https://storage.googleapis.com/vit_models/augreg/'
72 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
73 | 'vit_small_patch16_384': _cfg(
74 | url='https://storage.googleapis.com/vit_models/augreg/'
75 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
76 | input_size=(3, 384, 384), crop_pct=1.0),
77 | 'vit_base_patch32_224': _cfg(
78 | url='https://storage.googleapis.com/vit_models/augreg/'
79 | 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
80 | 'vit_base_patch32_384': _cfg(
81 | url='https://storage.googleapis.com/vit_models/augreg/'
82 | 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
83 | input_size=(3, 384, 384), crop_pct=1.0),
84 | 'vit_base_patch16_224': _cfg(
85 | url='https://storage.googleapis.com/vit_models/augreg/'
86 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
87 | 'vit_base_patch16_384': _cfg(
88 | url='https://storage.googleapis.com/vit_models/augreg/'
89 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
90 | input_size=(3, 384, 384), crop_pct=1.0),
91 | 'vit_base_patch8_224': _cfg(
92 | url='https://storage.googleapis.com/vit_models/augreg/'
93 | 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
94 | 'vit_large_patch32_224': _cfg(
95 | url='', # no official model weights for this combo, only for in21k
96 | ),
97 | 'vit_large_patch32_384': _cfg(
98 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
99 | input_size=(3, 384, 384), crop_pct=1.0),
100 | 'vit_large_patch16_224': _cfg(
101 | url='https://storage.googleapis.com/vit_models/augreg/'
102 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
103 | 'vit_large_patch16_384': _cfg(
104 | url='https://storage.googleapis.com/vit_models/augreg/'
105 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
106 | input_size=(3, 384, 384), crop_pct=1.0),
107 |
108 | 'vit_huge_patch14_224': _cfg(url=''),
109 | 'vit_giant_patch14_224': _cfg(url=''),
110 | 'vit_gigantic_patch14_224': _cfg(url=''),
111 |
112 | # patch models, imagenet21k (weights from official Google JAX impl)
113 | 'vit_tiny_patch16_224_in21k': _cfg(
114 | url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
115 | num_classes=21843),
116 | 'vit_small_patch32_224_in21k': _cfg(
117 | url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
118 | num_classes=21843),
119 | 'vit_small_patch16_224_in21k': _cfg(
120 | url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
121 | num_classes=21843),
122 | 'vit_base_patch32_224_in21k': _cfg(
123 | url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
124 | num_classes=21843),
125 | 'vit_base_patch16_224_in21k': _cfg(
126 | url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
127 | #url='./B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
128 | num_classes=21843),
129 | 'vit_base_patch8_224_in21k': _cfg(
130 | url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
131 | num_classes=21843),
132 | 'vit_large_patch32_224_in21k': _cfg(
133 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
134 | num_classes=21843),
135 | 'vit_large_patch16_224_in21k': _cfg(
136 | url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
137 | num_classes=21843),
138 | 'vit_huge_patch14_224_in21k': _cfg(
139 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
140 | hf_hub='timm/vit_huge_patch14_224_in21k',
141 | num_classes=21843),
142 |
143 | # SAM trained models (https://arxiv.org/abs/2106.01548)
144 | 'vit_base_patch32_sam_224': _cfg(
145 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
146 | 'vit_base_patch16_sam_224': _cfg(
147 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
148 |
149 | # deit models (FB weights)
150 | 'deit_tiny_patch16_224': _cfg(
151 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
152 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
153 | 'deit_small_patch16_224': _cfg(
154 | url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
155 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
156 | 'deit_base_patch16_224': _cfg(
157 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
158 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
159 | 'deit_base_patch16_384': _cfg(
160 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
161 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
162 | 'deit_tiny_distilled_patch16_224': _cfg(
163 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
164 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
165 | 'deit_small_distilled_patch16_224': _cfg(
166 | url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
167 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
168 | 'deit_base_distilled_patch16_224': _cfg(
169 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
170 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
171 | 'deit_base_distilled_patch16_384': _cfg(
172 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
173 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
174 | classifier=('head', 'head_dist')),
175 |
176 |     # ViT ImageNet-21K-P pretraining by MIIL
177 | 'vit_base_patch16_224_miil_in21k': _cfg(
178 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
179 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
180 | ),
181 | 'vit_base_patch16_224_miil': _cfg(
182 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
183 | '/vit_base_patch16_224_1k_miil_84_4.pth',
184 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
185 | ),
186 | }
187 |
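# [Editor's note] Hedged annotation, not part of the original file: each _cfg entry above
# carries the metadata that build_model_with_cfg uses to fetch and adapt pretrained
# weights (url, num_classes, input_size, mean/std, classifier name). URLs ending in .npz
# are JAX/Flax checkpoints and are routed through the custom _load_weights() loader below,
# via pretrained_custom_load='npz' in default_cfg['url'] in _create_vision_transformer().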
188 |
189 | class LoraMlp(nn.Module):
190 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
191 | lora_rank=4, lora_scale=0.8):
192 | super().__init__()
193 | out_features = out_features or in_features
194 | hidden_features = hidden_features or in_features
195 | self.hidden_features = hidden_features
196 | self.fc1 = nn.Linear(in_features, hidden_features)
197 | self.act = act_layer()
198 | self.fc2 = nn.Linear(hidden_features, out_features)
199 | self.drop = nn.Dropout(drop)
200 |
201 | self.lora_rank = lora_rank
202 | self.lora_1_B = torch.nn.Parameter(torch.zeros(hidden_features, self.lora_rank))
203 | self.lora_1_A = torch.nn.Parameter(torch.zeros(self.lora_rank, in_features))
204 | trunc_normal_(self.lora_1_A, std=0.02)
205 |
206 | self.lora_2_B = torch.nn.Parameter(torch.zeros(in_features, self.lora_rank))
207 | self.lora_2_A = torch.nn.Parameter(torch.zeros(self.lora_rank, hidden_features))
208 | trunc_normal_(self.lora_2_A, std=0.02)
209 |
210 | self.lora_scale = lora_scale
211 |
212 | def init_lora(self):
213 | U, S, Vh = torch.linalg.svd(self.fc1.weight)
214 | self.lora_1_A.data[:] = Vh[:self.lora_rank, :]
215 |
216 | U, S, Vh = torch.linalg.svd(self.fc2.weight)
217 | self.lora_2_A.data[:] = Vh[:self.lora_rank, :]
218 |
219 | def forward(self, x):
220 | lora_w = torch.matmul(self.lora_1_B, self.lora_1_A)*self.lora_scale
221 |
222 | # x = self.fc1(x)
223 | x = F.linear(x, lora_w+self.fc1.weight, bias=self.fc1.bias)
224 | x = self.act(x)
225 | x = self.drop(x)
226 |
227 | lora_w = torch.matmul(self.lora_2_B, self.lora_2_A)*self.lora_scale
228 |
229 | # x = self.fc2(x)
230 | x = F.linear(x, lora_w+self.fc2.weight, bias=self.fc2.bias)
231 | x = self.drop(x)
232 |
233 | return x
234 |
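# [Editor's note] Hedged sketch, not part of the original file: it illustrates two
# properties of LoraMlp. The B matrices are zero-initialized, so B @ A == 0 and the
# module is an exact functional copy of a plain MLP until training moves B; and
# init_lora() seeds each A with the leading right-singular vectors of the frozen
# weight, placing the adapter in that weight's most energetic subspace.
def _lora_mlp_identity_sketch():
    torch.manual_seed(0)
    mlp = LoraMlp(in_features=8, hidden_features=16, lora_rank=2)
    x = torch.randn(4, 8)
    y_with_lora = mlp(x)
    y_base_only = mlp.fc2(mlp.act(mlp.fc1(x)))  # forward through base weights alone
    assert torch.allclose(y_with_lora, y_base_only)  # holds because B starts at zero
    mlp.init_lora()  # now each A holds the top-r right-singular vectors of fc1/fc2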
235 |
236 | class LoraAttention(nn.Module):
237 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., lora_rank=4, lora_scale=0.8):
238 | super().__init__()
239 | self.num_heads = num_heads
240 | head_dim = dim // num_heads
241 | self.scale = head_dim ** -0.5
242 |
243 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
244 | self.attn_drop = nn.Dropout(attn_drop)
245 | self.proj = nn.Linear(dim, dim)
246 | self.proj_drop = nn.Dropout(proj_drop)
247 |
248 |
249 | self.lora_rank = lora_rank
250 | self.lora_B = torch.nn.Parameter(torch.zeros(dim*3, lora_rank))
251 | self.lora_A = torch.nn.Parameter(torch.zeros(lora_rank, dim))
252 | trunc_normal_(self.lora_A, std=0.02)
253 |
254 |
255 | self.lora_scale = lora_scale
256 |
257 | def init_lora(self):
258 | U, S, Vh = torch.linalg.svd(self.qkv.weight)
259 | self.lora_A.data[:] = Vh[:self.lora_rank, :]
260 |
261 | def forward(self, x):
262 | B, N, C = x.shape
263 |
264 | lora_w = torch.matmul(self.lora_B, self.lora_A)*self.lora_scale
265 | qkv = F.linear(x, lora_w+self.qkv.weight, bias=self.qkv.bias)
266 | qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
267 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
268 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
269 |
270 | attn = (q @ k.transpose(-2, -1)) * self.scale
271 | attn = attn.softmax(dim=-1)
272 | attn = self.attn_drop(attn)
273 |
274 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
275 |
276 | x = self.proj(x)
277 | x = self.proj_drop(x)
278 | return x
279 |
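# [Editor's note] Hedged annotation, not part of the original file: LoraAttention
# adapts only the qkv projection. Merging the low-rank update into the weight gives
#   F.linear(x, W_qkv + s * B @ A) == x @ W_qkv.T + s * ((x @ A.T) @ B.T),
# so each layer adds only (3*dim + dim) * r adapter parameters (B: [3*dim, r],
# A: [r, dim]) instead of a dense [3*dim, dim] update.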
280 |
281 | class Attention(nn.Module):
282 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
283 | super().__init__()
284 | self.num_heads = num_heads
285 | head_dim = dim // num_heads
286 | self.scale = head_dim ** -0.5
287 |
288 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
289 | self.attn_drop = nn.Dropout(attn_drop)
290 | self.proj = nn.Linear(dim, dim)
291 | self.proj_drop = nn.Dropout(proj_drop)
292 |
293 | def forward(self, x):
294 | B, N, C = x.shape
295 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
296 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
297 |
298 | attn = (q @ k.transpose(-2, -1)) * self.scale
299 | attn = attn.softmax(dim=-1)
300 | attn = self.attn_drop(attn)
301 |
302 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
303 | x = self.proj(x)
304 | x = self.proj_drop(x)
305 | return x
306 |
307 |
308 | class Block(nn.Module):
309 |
310 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
311 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
312 | super().__init__()
313 | self.norm1 = norm_layer(dim)
314 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
315 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
316 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
317 | self.norm2 = norm_layer(dim)
318 | mlp_hidden_dim = int(dim * mlp_ratio)
319 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
320 |
321 | def forward(self, x):
322 | x = x + self.drop_path(self.attn(self.norm1(x)))
323 | x = x + self.drop_path(self.mlp(self.norm2(x)))
324 | return x
325 |
326 |
327 | class VisionTransformer(nn.Module):
328 | """ Vision Transformer
329 |
330 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
331 | - https://arxiv.org/abs/2010.11929
332 |
333 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
334 | - https://arxiv.org/abs/2012.12877
335 | """
336 |
337 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
338 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
339 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
340 |                  act_layer=None, weight_init='', global_pool=False, lora_rank=-1):
341 | """
342 | Args:
343 | img_size (int, tuple): input image size
344 | patch_size (int, tuple): patch size
345 | in_chans (int): number of input channels
346 | num_classes (int): number of classes for classification head
347 | embed_dim (int): embedding dimension
348 | depth (int): depth of transformer
349 | num_heads (int): number of attention heads
350 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
351 | qkv_bias (bool): enable bias for qkv if True
352 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
353 | distilled (bool): model includes a distillation token and head as in DeiT models
354 | drop_rate (float): dropout rate
355 | attn_drop_rate (float): attention dropout rate
356 | drop_path_rate (float): stochastic depth rate
357 | embed_layer (nn.Module): patch embedding layer
358 | norm_layer: (nn.Module): normalization layer
359 |             weight_init: (str): weight init scheme
360 |             lora_rank (int): rank of the LoRA adapters inserted into attn/mlp; -1 disables LoRA
361 |         """
361 | super().__init__()
362 | self.num_classes = num_classes
363 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
364 | self.out_dim = embed_dim
365 | self.num_tokens = 2 if distilled else 1
366 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
367 | act_layer = act_layer or nn.GELU
368 |
369 | self.global_pool = global_pool
370 |
371 | self.patch_embed = embed_layer(
372 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
373 | num_patches = self.patch_embed.num_patches
374 |
375 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
376 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
377 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
378 | self.pos_drop = nn.Dropout(p=drop_rate)
379 |
380 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
381 | self.blocks = nn.ModuleList([
382 | Block(
383 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
384 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
385 | for i in range(depth)])
386 | self.norm = norm_layer(embed_dim)
387 |
388 | # Representation layer
389 | if representation_size and not distilled:
390 | self.num_features = representation_size
391 | self.pre_logits = nn.Sequential(OrderedDict([
392 | ('fc', nn.Linear(embed_dim, representation_size)),
393 | ('act', nn.Tanh())
394 | ]))
395 | else:
396 | self.pre_logits = nn.Identity()
397 |
398 | # Classifier head(s)
399 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
400 | self.head_dist = None
401 | if distilled:
402 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
403 |
404 | self.init_weights(weight_init)
405 |
406 |
407 | self.lora_rank = lora_rank
408 | if self.lora_rank>0:
409 | self.with_lora = True
410 | self.lora_lp = depth
411 | self.attn_lora = True
412 | self.mlp_lora = True
413 | if self.mlp_lora:
414 | for b_idx in range(self.lora_lp):
415 | self.blocks[b_idx].mlp = LoraMlp(in_features=self.embed_dim, hidden_features=int(self.embed_dim*mlp_ratio), act_layer=act_layer, drop=drop_rate, lora_rank=lora_rank)
416 | if self.attn_lora:
417 | for b_idx in range(self.lora_lp):
418 | self.blocks[b_idx].attn = LoraAttention(self.embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop_rate, proj_drop=drop_rate, lora_rank=lora_rank)
419 | else:
420 | self.with_lora = False
421 |
422 | def get_adapter(self, embed_dim):
423 | return nn.Sequential(
424 | nn.Linear(embed_dim, embed_dim*3, bias=False),
425 | nn.LayerNorm(embed_dim*3),
426 | nn.GELU(),
427 | nn.Linear(embed_dim*3, embed_dim, bias=False),
428 | nn.LayerNorm(embed_dim),
429 | nn.GELU(),
430 | nn.Linear(embed_dim, embed_dim, bias=True),
431 | nn.Sigmoid()
432 | )
433 |
434 |
435 | def init_weights(self, mode=''):
436 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
437 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
438 | trunc_normal_(self.pos_embed, std=.02)
439 | if self.dist_token is not None:
440 | trunc_normal_(self.dist_token, std=.02)
441 | if mode.startswith('jax'):
442 | # leave cls token as zeros to match jax impl
443 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
444 | else:
445 | trunc_normal_(self.cls_token, std=.02)
446 | self.apply(_init_vit_weights)
447 |
448 | def _init_weights(self, m):
449 | # this fn left here for compat with downstream users
450 | _init_vit_weights(m)
451 |
452 | @torch.jit.ignore()
453 | def load_pretrained(self, checkpoint_path, prefix=''):
454 | _load_weights(self, checkpoint_path, prefix)
455 |
456 | @torch.jit.ignore
457 | def no_weight_decay(self):
458 | return {'pos_embed', 'cls_token', 'dist_token'}
459 |
460 | def get_classifier(self):
461 | if self.dist_token is None:
462 | return self.head
463 | else:
464 | return self.head, self.head_dist
465 |
466 | def reset_classifier(self, num_classes, global_pool=''):
467 | self.num_classes = num_classes
468 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
469 | if self.num_tokens == 2:
470 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
471 |
472 | def forward_features(self, x, layer_feat=False):
473 | img = x
474 | x = self.patch_embed(x)
475 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
476 | if self.dist_token is None:
477 | x = torch.cat((cls_token, x), dim=1)
478 | else:
479 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
480 | x = self.pos_drop(x + self.pos_embed)
481 | # x = self.blocks(x)
482 | feats = []
483 | feats_l = []
484 | for b_id, block in enumerate(self.blocks):
485 | x = block(x)
486 | if layer_feat:
487 | feats_l.append(x)
488 | if b_id == len(self.blocks)-2:
489 | penultimate_feat = x.clone()
490 |
491 | if layer_feat:
492 | return feats_l
493 |
494 | if self.global_pool:
495 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token
496 | return self.norm(x)
497 |
498 | x = self.norm(x)
499 | if self.dist_token is None:
500 | return self.pre_logits(x[:, 0])
501 | else:
502 | return x[:, 0] # , x[:, 1]
503 |
504 | def forward(self, x, layer_feat=False):
505 | x = self.forward_features(x, layer_feat)
506 | x = {'features': x}
507 | #if self.head_dist is not None:
508 | # x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
509 | # if self.training and not torch.jit.is_scripting():
510 | # # during inference, return the average of both classifier predictions
511 | # return x, x_dist
512 | # else:
513 | # return (x + x_dist) / 2
514 | #else:
515 | # x = self.head(x)
516 | return x
517 |
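# [Editor's note] Hedged sketch, not part of the original file: with lora_rank > 0 the
# constructor above swaps every block's attn/mlp for the LoRA variants. A typical
# parameter-efficient setup (assumed usage, not necessarily this repository's exact
# training code) then freezes everything except the adapters:
def _freeze_all_but_lora(model: VisionTransformer):
    for name, param in model.named_parameters():
        # adapter tensors are named lora_A/lora_B (attention) and lora_1_*/lora_2_* (MLP)
        param.requires_grad = 'lora_' in name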
518 |
519 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
520 | """ ViT weight initialization
521 | * When called without n, head_bias, jax_impl args it will behave exactly the same
522 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
523 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
524 | """
525 | if isinstance(module, nn.Linear):
526 | if name.startswith('head'):
527 | nn.init.zeros_(module.weight)
528 | nn.init.constant_(module.bias, head_bias)
529 | elif name.startswith('pre_logits'):
530 | lecun_normal_(module.weight)
531 | nn.init.zeros_(module.bias)
532 | else:
533 | if jax_impl:
534 | nn.init.xavier_uniform_(module.weight)
535 | if module.bias is not None:
536 | if 'mlp' in name:
537 | nn.init.normal_(module.bias, std=1e-6)
538 | else:
539 | nn.init.zeros_(module.bias)
540 | else:
541 | trunc_normal_(module.weight, std=.02)
542 | if module.bias is not None:
543 | nn.init.zeros_(module.bias)
544 | elif jax_impl and isinstance(module, nn.Conv2d):
545 | # NOTE conv was left to pytorch default in my original init
546 | lecun_normal_(module.weight)
547 | if module.bias is not None:
548 | nn.init.zeros_(module.bias)
549 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
550 | nn.init.zeros_(module.bias)
551 | nn.init.ones_(module.weight)
552 |
553 | @torch.no_grad()
554 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
555 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation
556 | """
557 | import numpy as np
558 |
559 | def _n2p(w, t=True):
560 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
561 | w = w.flatten()
562 | if t:
563 | if w.ndim == 4:
564 | w = w.transpose([3, 2, 0, 1])
565 | elif w.ndim == 3:
566 | w = w.transpose([2, 0, 1])
567 | elif w.ndim == 2:
568 | w = w.transpose([1, 0])
569 | return torch.from_numpy(w)
570 |
571 | w = np.load(checkpoint_path)
572 | if not prefix and 'opt/target/embedding/kernel' in w:
573 | prefix = 'opt/target/'
574 |
575 | if hasattr(model.patch_embed, 'backbone'):
576 | # hybrid
577 | backbone = model.patch_embed.backbone
578 | stem_only = not hasattr(backbone, 'stem')
579 | stem = backbone if stem_only else backbone.stem
580 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
581 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
582 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
583 | if not stem_only:
584 | for i, stage in enumerate(backbone.stages):
585 | for j, block in enumerate(stage.blocks):
586 | bp = f'{prefix}block{i + 1}/unit{j + 1}/'
587 | for r in range(3):
588 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
589 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
590 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
591 | if block.downsample is not None:
592 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
593 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
594 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
595 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
596 | else:
597 | embed_conv_w = adapt_input_conv(
598 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
599 | model.patch_embed.proj.weight.copy_(embed_conv_w)
600 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
601 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
602 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
603 | if pos_embed_w.shape != model.pos_embed.shape:
604 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
605 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
606 | model.pos_embed.copy_(pos_embed_w)
607 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
608 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
609 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
610 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
611 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
612 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
613 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
614 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
615 | for i, block in enumerate(model.blocks.children()):
616 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
617 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
618 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
619 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
620 | block.attn.qkv.weight.copy_(torch.cat([
621 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
622 | block.attn.qkv.bias.copy_(torch.cat([
623 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
624 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
625 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
626 | for r in range(2):
627 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
628 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
629 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
630 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
631 |
632 |
633 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
634 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
635 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
636 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
637 | ntok_new = posemb_new.shape[1]
638 | if num_tokens:
639 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
640 | ntok_new -= num_tokens
641 | else:
642 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
643 | gs_old = int(math.sqrt(len(posemb_grid)))
644 | if not len(gs_new): # backwards compatibility
645 | gs_new = [int(math.sqrt(ntok_new))] * 2
646 | assert len(gs_new) >= 2
647 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
648 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
649 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
650 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
651 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
652 | return posemb
653 |
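# [Editor's note] Worked example (hedged annotation, not part of the original file):
# a ViT-B/16 at 224px has a 14x14 patch grid plus one cls token, i.e. posemb of shape
# [1, 197, 768]. Loading those weights at 384px (24x24 grid, 577 tokens), the function
# above splits off the cls token, reshapes the 196 grid embeddings to [1, 768, 14, 14],
# bicubically interpolates them to [1, 768, 24, 24], flattens back to [1, 576, 768],
# and re-attaches the cls token to obtain [1, 577, 768].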
654 |
655 | def checkpoint_filter_fn(state_dict, model):
656 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
657 | out_dict = {}
658 | if 'model' in state_dict:
659 | # For deit models
660 | state_dict = state_dict['model']
661 | for k, v in state_dict.items():
662 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
663 | # For old models that I trained prior to conv based patchification
664 | O, I, H, W = model.patch_embed.proj.weight.shape
665 | v = v.reshape(O, -1, H, W)
666 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
667 | # To resize pos embedding when using model at different size from pretrained weights
668 | v = resize_pos_embed(
669 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
670 | out_dict[k] = v
671 | return out_dict
672 |
673 |
674 | def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
675 | default_cfg = default_cfg or default_cfgs[variant]
676 | if kwargs.get('features_only', None):
677 | raise RuntimeError('features_only not implemented for Vision Transformer models.')
678 |
679 | # NOTE this extra code to support handling of repr size for in21k pretrained models
680 | default_num_classes = default_cfg['num_classes']
681 | num_classes = kwargs.get('num_classes', default_num_classes)
682 | repr_size = kwargs.pop('representation_size', None)
683 | if repr_size is not None and num_classes != default_num_classes:
684 | # Remove representation layer if fine-tuning. This may not always be the desired action,
685 | # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
686 | _logger.warning("Removing representation layer for fine-tuning.")
687 | repr_size = None
688 |
689 | model = build_model_with_cfg(
690 | VisionTransformer, variant, pretrained,
691 | default_cfg=default_cfg,
692 | representation_size=repr_size,
693 | pretrained_filter_fn=checkpoint_filter_fn,
694 | pretrained_custom_load='npz' in default_cfg['url'],
695 | **kwargs)
696 | return model
697 |
698 |
699 |
700 | @register_model
701 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
702 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
703 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
704 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
705 | """
706 | model_kwargs = dict(
707 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
708 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
709 | del model.head
710 | del model.norm
711 | model.norm = nn.LayerNorm(768)
712 | return model
713 |
714 | @register_model
715 | def vit_base_patch16_224_mocov3(pretrained=False, **kwargs):
716 |     """ ViT-Base model (ViT-B/16) with MoCo v3 self-supervised weights.
717 |     Loads the local checkpoint 'mocov3-vit-base-300ep.pth' in place of the ImageNet-21k
718 |     supervised weights; the `pretrained` flag is intentionally unused here.
719 |     """
720 | model_kwargs = dict(
721 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
722 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=False, **model_kwargs)
723 | del model.head
724 | ckpt = torch.load('mocov3-vit-base-300ep.pth', map_location='cpu')['model']
725 | state_dict = model.state_dict()
726 | state_dict.update(ckpt)
727 | model.load_state_dict(state_dict)
728 | del model.norm
729 | model.norm = nn.LayerNorm(768)
730 | return model
731 |
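# [Editor's note] Hedged annotation, not part of the original file:
# 'mocov3-vit-base-300ep.pth' is expected as a local checkpoint (a MoCo v3 ViT-B
# self-supervised model whose keys match this module). Note that
# state_dict.update(ckpt) overlays only the keys present in the checkpoint; any key
# it does not contain silently keeps its randomly initialized value.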
732 |
733 |
734 | @register_model
735 | def vit_base_lora_patch16_224_in21k(pretrained=False, lora_rank=4, **kwargs):
736 |     """ ViT-Base model (ViT-B/16) with LoRA adapters of rank `lora_rank` inserted into
737 |     every block. ImageNet-21k weights @ 224x224, source
738 |     https://github.com/google-research/vision_transformer; the 21k head is discarded below.
739 |     """
740 | model_kwargs = dict(
741 | patch_size=16, embed_dim=768, depth=12, num_heads=12, lora_rank=lora_rank, **kwargs)
742 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
743 | del model.head
744 | del model.norm
745 | model.norm = nn.LayerNorm(768)
746 |
747 | return model
748 |
749 | @register_model
750 | def vit_base_lora_patch16_224_mocov3(pretrained=False, lora_rank=4, **kwargs):
751 |     """ ViT-Base model (ViT-B/16) with LoRA adapters of rank `lora_rank` and MoCo v3
752 |     self-supervised weights loaded from the local checkpoint 'mocov3-vit-base-300ep.pth';
753 |     the `pretrained` flag is intentionally unused here.
754 |     """
755 | model_kwargs = dict(
756 | patch_size=16, embed_dim=768, depth=12, num_heads=12, lora_rank=lora_rank, **kwargs)
757 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=False, **model_kwargs)
758 | del model.head
759 | ckpt = torch.load('mocov3-vit-base-300ep.pth', map_location='cpu')['model']
760 | state_dict = model.state_dict()
761 | state_dict.update(ckpt)
762 | model.load_state_dict(state_dict)
763 | del model.norm
764 | model.norm = nn.LayerNorm(768)
765 | return model
766 |
--------------------------------------------------------------------------------
/evaluator.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 | import copy
4 | import torch
5 | from utils import factory
6 | from utils.data_manager import DataManager
7 | from utils.toolkit import count_parameters
8 | import os
9 | import numpy as np
10 |
11 |
12 | def test(args):
13 | seed_list = copy.deepcopy(args['seed'])
14 | device = copy.deepcopy(args['device'])
15 |
16 | for seed in seed_list:
17 | args['seed'] = seed
18 | args['device'] = device
19 | _test(args)
20 |
21 |
22 | def _test(args):
23 | logfilename = 'logs/{}/{}_test_{}_{}_{}_{}_{}_{}'.format(args['model_name'], args['prefix'], args['seed'], args['model_name'], args['convnet_type'],
24 | args['dataset'], args['init_cls'], args['increment'])
25 | logging.basicConfig(
26 | level=logging.INFO,
27 | format='%(asctime)s [%(filename)s] => %(message)s',
28 | handlers=[
29 | logging.FileHandler(filename=logfilename + '.log'),
30 | logging.StreamHandler(sys.stdout)
31 | ]
32 | )
33 |
34 | _set_random()
35 | _set_device(args)
36 | print_args(args)
37 | data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment'])
38 | model = factory.get_model(args['model_name'], args)
39 |
40 | cnn_curve, nme_curve = {'top1': [], 'top5': []}, {'top1': [], 'top5': []}
41 | for task in range(data_manager.nb_tasks):
42 | logging.info('All params: {}'.format(count_parameters(model._network)))
43 | # logging.info('Trainable params: {}'.format(count_parameters(model._network, True)))
44 | # model.incremental_train(data_manager)
45 | model.incremental_update(data_manager)
46 | cnn_accy, nme_accy = model.eval_task()
47 | model.after_task()
48 |
49 | if nme_accy is not None:
50 | logging.info('CNN: {}'.format(cnn_accy['grouped']))
51 | logging.info('NME: {}'.format(nme_accy['grouped']))
52 |
53 | cnn_curve['top1'].append(cnn_accy['top1'])
54 | cnn_curve['top5'].append(cnn_accy['top5'])
55 |
56 | nme_curve['top1'].append(nme_accy['top1'])
57 | nme_curve['top5'].append(nme_accy['top5'])
58 |
59 | logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
60 | logging.info('CNN top1 avg: {}'.format(np.array(cnn_curve['top1']).mean()))
61 | if 'task_acc' in cnn_accy.keys():
62 | logging.info('Task: {}'.format(cnn_accy['task_acc']))
63 | logging.info('CNN top5 curve: {}'.format(cnn_curve['top5']))
64 | logging.info('NME top1 curve: {}'.format(nme_curve['top1']))
65 | logging.info('NME top5 curve: {}\n'.format(nme_curve['top5']))
66 | else:
67 | logging.info('No NME accuracy.')
68 | logging.info('CNN: {}'.format(cnn_accy['grouped']))
69 |
70 | cnn_curve['top1'].append(cnn_accy['top1'])
71 | cnn_curve['top5'].append(cnn_accy['top5'])
72 |
73 | logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
74 | logging.info('CNN top5 curve: {}\n'.format(cnn_curve['top5']))
75 |
76 |
77 | def _set_device(args):
78 | device_type = args['device']
79 | gpus = []
80 |
81 | for device in device_type:
82 |         if device == -1 or device == '-1':
83 | device = torch.device('cpu')
84 | else:
85 | device = torch.device('cuda:{}'.format(device))
86 |
87 | gpus.append(device)
88 |
89 | args['device'] = gpus
90 |
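# [Editor's note] Hedged usage note, not part of the original file: with the configs in
# exps/, args['device'] arrives as a list of strings, e.g. ["0", "1"], and is mapped to
# [torch.device('cuda:0'), torch.device('cuda:1')]; an entry of -1 (or "-1") selects the CPU.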
91 |
92 | def _set_random():
93 | torch.manual_seed(1)
94 | torch.cuda.manual_seed(1)
95 | torch.cuda.manual_seed_all(1)
96 | torch.backends.cudnn.deterministic = True
97 | torch.backends.cudnn.benchmark = False
98 |
99 |
100 | def print_args(args):
101 | for key, value in args.items():
102 | logging.info('{}: {}'.format(key, value))
103 |
104 |
--------------------------------------------------------------------------------
/exps/slca/slca_cars.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cars196_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slca_cars",
11 | "model_postfix": "50e",
12 | "convnet_type": "vit-b-p16",
13 | "device": ["0","1"],
14 | "seed": [1993, 1996, 1997],
15 | "epochs": 50,
16 | "ca_epochs": 5,
17 | "ca_with_logit_norm": 0.05,
18 | "milestones": [40]
19 | }
20 |
--------------------------------------------------------------------------------
/exps/slca/slca_cars_mocov3.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cars196_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slca_cars_mocov3",
11 | "model_postfix": "90e",
12 | "convnet_type": "vit-b-p16-mocov3",
13 | "device": ["0","1"],
14 | "seed": [1993, 1996, 1997],
15 | "epochs": 90,
16 | "ca_epochs": 5,
17 | "ca_with_logit_norm": 0.05,
18 | "milestones": [80]
19 | }
20 |
--------------------------------------------------------------------------------
/exps/slca/slca_cifar.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cifar100_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 10,
9 | "increment": 10,
10 | "model_name": "slca_cifar",
11 | "model_postfix": "20e",
12 | "convnet_type": "vit-b-p16",
13 | "device": ["0","1"],
14 | "seed": [1993, 1996, 1997],
15 | "epochs": 20,
16 | "ca_epochs": 5,
17 | "ca_with_logit_norm": 0.1,
18 | "milestones": [18]
19 | }
20 |
--------------------------------------------------------------------------------
/exps/slca/slca_cifar_mocov3.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cifar100_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 10,
9 | "increment": 10,
10 | "model_name": "slca_cifar_mocov3",
11 | "model_postfix": "90e",
12 | "convnet_type": "vit-b-p16-mocov3",
13 | "device": ["0","1"],
14 | "seed": [1993, 1996, 1997],
15 | "epochs": 90,
16 | "ca_epochs": 5,
17 | "ca_with_logit_norm": 0.1,
18 | "milestones": [80]
19 | }
20 |
--------------------------------------------------------------------------------
/exps/slca/slca_cub.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cub200_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slca_cub",
11 | "model_postfix": "50e",
12 | "convnet_type": "vit-b-p16",
13 | "device": ["0","1"],
14 | "seed": [1993, 1996, 1997],
15 | "epochs": 50,
16 | "ca_epochs": 5,
17 | "ca_with_logit_norm": 0.1,
18 | "milestones": [40]
19 | }
20 |
--------------------------------------------------------------------------------
/exps/slca/slca_cub_mocov3.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cub200_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slca_cub_mocov3",
11 | "model_postfix": "90e",
12 | "convnet_type": "vit-b-p16-mocov3",
13 | "device": ["0","1"],
14 | "seed": [1993, 1996, 1997],
15 | "epochs": 90,
16 | "ca_epochs": 5,
17 | "ca_with_logit_norm": 0.1,
18 | "milestones": [80]
19 | }
20 |
--------------------------------------------------------------------------------
/exps/slca/slca_imgnetr.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "imagenet-r",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slca_imgnetr",
11 | "model_postfix": "50e",
12 | "convnet_type": "vit-b-p16",
13 | "device": ["0","1"],
14 | "seed": [1993, 1996, 1997],
15 | "epochs": 50,
16 | "ca_epochs": 5,
17 | "ca_with_logit_norm": 0.1,
18 | "milestones": [40]
19 | }
20 |
--------------------------------------------------------------------------------
/exps/slca/slca_imgnetr_mocov3.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "imagenet-r",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slca_imgnetr_mocov3",
11 | "model_postfix": "90e",
12 | "convnet_type": "vit-b-p16-mocov3",
13 | "device": ["0","1"],
14 | "seed": [1993, 1996, 1997],
15 | "epochs": 90,
16 | "ca_epochs": 5,
17 | "ca_with_logit_norm": 0.1,
18 | "milestones": [80]
19 | }
20 |
--------------------------------------------------------------------------------
/exps/slcapp/slcapp_cars_lora.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cars196_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slcapp_cars_lora",
11 | "model_postfix": "50e",
12 | "convnet_type": "vit-b-p16-lora",
13 | "device": ["0"],
14 | "seed": [1993,1996,1997],
15 | "bcb_lrscale": 0.1,
16 | "weight_decay": 0.0,
17 | "lora_rank": 4,
18 | "sce_a": 0.5,
19 | "sce_b": 0.5,
20 | "fc_scale_mu": 0.02,
21 | "epochs": 50,
22 | "ca_epochs": 5,
23 | "ca_with_logit_norm": 0.05,
24 | "milestones": [40]
25 | }
26 |
--------------------------------------------------------------------------------
/exps/slcapp/slcapp_cars_lora_mocov3.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cars196_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slcapp_cars_lora_mocov3",
11 | "model_postfix": "90e",
12 | "convnet_type": "vit-b-p16-lora-mocov3",
13 | "device": ["0"],
14 | "seed": [1993,1996,1997],
15 | "bcb_lrscale": 0.1,
16 | "lora_rank": 4,
17 | "sce_a": 0.5,
18 | "sce_b": 0.5,
19 | "fc_scale_mu": 0.02,
20 | "weight_decay": 0.0,
21 | "epochs": 90,
22 | "ca_epochs": 5,
23 | "ca_with_logit_norm": 0.05,
24 | "milestones": [60, 80]
25 | }
26 |
--------------------------------------------------------------------------------
/exps/slcapp/slcapp_cifar_lora.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cifar100_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 10,
9 | "increment": 10,
10 | "model_name": "slcapp_cifar_lora",
11 | "model_postfix": "20e",
12 | "convnet_type": "vit-b-p16-lora",
13 | "device": ["0"],
14 | "seed": [1993,1996,1997],
15 | "epochs": 20,
16 | "ca_epochs": 5,
17 | "ca_with_logit_norm": 0.1,
18 | "bcb_lrscale": 0.1,
19 | "lora_rank": 4,
20 | "sce_a": 1.0,
21 | "sce_b": 0.1,
22 | "fc_scale_mu": 0.02,
23 | "weight_decay": 0.0005,
24 | "milestones": [18]
25 | }
26 |
--------------------------------------------------------------------------------
/exps/slcapp/slcapp_cifar_lora_mocov3.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cifar100_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 10,
9 | "increment": 10,
10 | "model_name": "slcapp_cifar_lora_mocov3",
11 | "model_postfix": "90e",
12 | "convnet_type": "vit-b-p16-lora-mocov3",
13 | "device": ["0"],
14 | "seed": [1993,1996,1997],
15 | "epochs": 90,
16 | "ca_epochs": 5,
17 | "ca_with_logit_norm": 0.1,
18 | "bcb_lrscale": 0.1,
19 | "lora_rank": 4,
20 | "sce_a": 1.0,
21 | "sce_b": 0.1,
22 | "fc_scale_mu": 0.02,
23 | "weight_decay": 0.0005,
24 | "milestones": [60, 80]
25 | }
26 |
--------------------------------------------------------------------------------
/exps/slcapp/slcapp_cub_lora.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cub200_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slcapp_cub_lora",
11 | "model_postfix": "50e",
12 | "convnet_type": "vit-b-p16-lora",
13 | "device": ["0"],
14 | "seed": [1993,1996,1997],
15 | "bcb_lrscale": 0.1,
16 | "lora_rank": 4,
17 | "sce_a": 1.0,
18 | "sce_b": 0.1,
19 | "fc_scale_mu": 0.02,
20 | "epochs": 50,
21 | "ca_epochs": 5,
22 | "ca_with_logit_norm": 0.1,
23 | "milestones": [40]
24 | }
25 |
--------------------------------------------------------------------------------
/exps/slcapp/slcapp_cub_lora_mocov3.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "cub200_224",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slcapp_cub_lora_mocov3",
11 | "model_postfix": "90e",
12 | "convnet_type": "vit-b-p16-lora-mocov3",
13 | "device": ["0"],
14 | "seed": [1993,1996,1997],
15 | "bcb_lrscale": 0.1,
16 | "lora_rank": 4,
17 | "sce_a": 1.0,
18 | "sce_b": 0.1,
19 | "fc_scale_mu": 0.02,
20 | "epochs": 90,
21 | "ca_epochs": 5,
22 | "ca_with_logit_norm": 0.1,
23 | "milestones": [60, 80]
24 | }
25 |
--------------------------------------------------------------------------------
/exps/slcapp/slcapp_imgnetr_lora.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "imagenet-r",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slcapp_imgnetr_lora",
11 | "model_postfix": "50e",
12 | "convnet_type": "vit-b-p16-lora",
13 | "bcb_lrscale": 0.1,
14 | "weight_decay": 0.0,
15 | "device": ["0"],
16 | "lora_rank": 4,
17 | "sce_a": 0.5,
18 | "sce_b": 0.5,
19 | "fc_scale_mu": 0.02,
20 | "seed": [1993,1996,1997],
21 | "epochs": 50,
22 | "ca_epochs": 5,
23 | "ca_with_logit_norm": 0.1,
24 | "milestones": [40]
25 | }
26 |
--------------------------------------------------------------------------------
/exps/slcapp/slcapp_imgnetr_lora_mocov3.json:
--------------------------------------------------------------------------------
1 | {
2 | "prefix": "reproduce",
3 | "dataset": "imagenet-r",
4 | "memory_size": 0,
5 | "memory_per_class": 0,
6 | "fixed_memory": false,
7 | "shuffle": true,
8 | "init_cls": 20,
9 | "increment": 20,
10 | "model_name": "slcapp_imgnetr_lora_mocov3",
11 | "model_postfix": "90e",
12 | "convnet_type": "vit-b-p16-lora-mocov3",
13 | "bcb_lrscale": 0.1,
14 | "weight_decay": 0.0,
15 | "device": ["0"],
16 | "lora_rank": 4,
17 | "sce_a": 0.5,
18 | "sce_b": 0.5,
19 | "fc_scale_mu": 0.02,
20 | "seed": [1993,1996,1997],
21 | "epochs": 90,
22 | "ca_epochs": 5,
23 | "ca_with_logit_norm": 0.1,
24 | "milestones": [60, 80]
25 | }
26 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from trainer import train
4 | from evaluator import test
5 |
6 | def main():
7 | args = setup_parser().parse_args()
8 | param = load_json(args.config)
9 | args = vars(args) # Converting argparse Namespace to a dict.
10 | args.update(param) # Add parameters from json
11 | if args['test_only']:
12 | test(args)
13 | else:
14 | train(args)
15 |
16 |
17 | def load_json(settings_path):
18 | with open(settings_path) as data_file:
19 | param = json.load(data_file)
20 |
21 | return param
22 |
23 |
24 | def setup_parser():
25 |     parser = argparse.ArgumentParser(description='Reproduction of multiple continual learning algorithms.')
26 | parser.add_argument('--config', type=str, default='./exps/finetune.json',
27 | help='Json file of settings.')
28 | parser.add_argument('--test_only', action='store_true')
29 | return parser
30 |
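# [Editor's note] Hedged usage example, not part of the original file:
#   python main.py --config ./exps/slca/slca_cifar.json              # train
#   python main.py --config ./exps/slca/slca_cifar.json --test_only  # evaluate only
# Because main() calls args.update(param) after parsing, values from the JSON config
# override any same-named argparse entries.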
31 |
32 | if __name__ == '__main__':
33 | main()
34 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GengDavid/SLCA/169d4e6b91e3ca30caa717200c1944981c4426d1/models/__init__.py
--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from torch.utils.data import DataLoader
7 | from utils.toolkit import tensor2numpy, accuracy
8 | from scipy.spatial.distance import cdist
9 | from collections import OrderedDict
10 |
11 | EPSILON = 1e-8
12 | batch_size = 64
13 |
14 |
15 | class BaseLearner(object):
16 | def __init__(self, args):
17 | self._cur_task = -1
18 | self._known_classes = 0
19 | self._total_classes = 0
20 | self._network = None
21 | self._old_network = None
22 | self._data_memory, self._targets_memory = np.array([]), np.array([])
23 | self.topk = 5
24 |
25 | self._memory_size = args['memory_size']
26 | self._memory_per_class = args['memory_per_class']
27 | self._fixed_memory = args['fixed_memory']
28 | self._device = args['device'][0]
29 | self._multiple_gpus = args['device']
30 |
31 | @property
32 | def exemplar_size(self):
33 | assert len(self._data_memory) == len(self._targets_memory), 'Exemplar size error.'
34 | return len(self._targets_memory)
35 |
36 | @property
37 | def samples_per_class(self):
38 | if self._fixed_memory:
39 | return self._memory_per_class
40 | else:
41 | assert self._total_classes != 0, 'Total classes is 0'
42 | return (self._memory_size // self._total_classes)
43 |
44 | @property
45 | def feature_dim(self):
46 | if isinstance(self._network, nn.DataParallel):
47 | return self._network.module.feature_dim
48 | else:
49 | return self._network.feature_dim
50 |
51 | def save_checkpoint(self, filename, head_only=False, learnable_only=False):
52 | if hasattr(self._network, 'module'):
53 | to_save = self._network.module
54 | else:
55 | to_save = self._network
56 |
57 | if head_only:
58 | to_save_dict = to_save.fc.state_dict()
59 | else:
60 | to_save_dict = to_save.state_dict()
61 |
62 | if learnable_only:
63 | new_dict = OrderedDict()
64 | filtered_keys = [n for n, p in to_save.named_parameters() if p.requires_grad]
65 | for k in filtered_keys:
66 | new_dict[k] = to_save_dict[k]
67 | to_save_dict = new_dict
68 |
69 | save_dict = {
70 | 'tasks': self._cur_task,
71 | 'model_state_dict': to_save_dict,
72 | }
73 | torch.save(save_dict, '{}_{}.pth'.format(filename, self._cur_task))
74 |
75 | def after_task(self):
76 | pass
77 |
78 | def _evaluate(self, y_pred, y_true):
79 | ret = {}
80 | grouped = accuracy(y_pred.T[0], y_true, self._known_classes)
81 | ret['grouped'] = grouped
82 | ret['top1'] = grouped['total']
83 | ret['top{}'.format(5)] = np.around((y_pred.T == np.tile(y_true, (self.topk, 1))).sum()*100/len(y_true),
84 | decimals=2)
85 |
86 | return ret
87 |
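    # [Editor's note] Hedged annotation, not part of the original file: in _evaluate above,
    # y_pred has shape [N, topk] with columns ordered by confidence. Transposing and
    # comparing against the tiled labels counts samples whose true class appears anywhere
    # in the top-k (a true label can match at most one of the k distinct predictions),
    # e.g. 3 hits out of N=4 samples yields a top-k score of 75.0.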
88 | def eval_task(self):
89 | y_pred, y_true = self._eval_cnn(self.test_loader)
90 | cnn_accy = self._evaluate(y_pred, y_true)
91 |
92 | if hasattr(self, '_class_means') and False: # TODO
93 | y_pred, y_true = self._eval_nme(self.test_loader, self._class_means)
94 | nme_accy = self._evaluate(y_pred, y_true)
95 | else:
96 | nme_accy = None
97 |
98 | return cnn_accy, nme_accy
99 |
100 | def incremental_train(self):
101 | pass
102 |
103 | def _train(self):
104 | pass
105 |
106 | def _get_memory(self):
107 | if len(self._data_memory) == 0:
108 | return None
109 | else:
110 | return (self._data_memory, self._targets_memory)
111 |
112 |
113 | def _inner_eval(self, model, loader):
114 | model.eval()
115 | y_pred, y_true = [], []
116 | for _, (_, inputs, targets) in enumerate(loader):
117 | inputs = inputs.to(self._device)
118 | with torch.no_grad():
119 | outputs = model(inputs)['logits']
120 | predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1] # [bs, topk]
121 | y_pred.append(predicts.cpu().numpy())
122 | y_true.append(targets.cpu().numpy())
123 |
124 | y_pred, y_true = np.concatenate(y_pred), np.concatenate(y_true) # [N, topk]
125 |
126 | cnn_accy = self._evaluate(y_pred, y_true)
127 | return cnn_accy
128 |
129 | def _compute_accuracy(self, model, loader):
130 | model.eval()
131 | correct, total = 0, 0
132 | for i, (_, inputs, targets) in enumerate(loader):
133 | inputs = inputs.to(self._device)
134 | with torch.no_grad():
135 | outputs = model(inputs)['logits']
136 | predicts = torch.max(outputs, dim=1)[1]
137 | correct += (predicts.cpu() == targets).sum()
138 | total += len(targets)
139 |
140 | return np.around(tensor2numpy(correct)*100 / total, decimals=2)
141 |
142 | def _eval_cnn(self, loader):
143 | self._network.eval()
144 | y_pred, y_true = [], []
145 | for _, (_, inputs, targets) in enumerate(loader):
146 | inputs = inputs.to(self._device)
147 | with torch.no_grad():
148 | outputs = self._network(inputs)['logits']
149 | predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1] # [bs, topk]
150 | y_pred.append(predicts.cpu().numpy())
151 | y_true.append(targets.cpu().numpy())
152 |
153 | return np.concatenate(y_pred), np.concatenate(y_true) # [N, topk]
154 |
155 | def _eval_nme(self, loader, class_means):
156 | self._network.eval()
157 | vectors, y_true = self._extract_vectors(loader)
158 | vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
159 |
160 | norm_means = class_means / np.linalg.norm(class_means)
161 | dists = cdist(norm_means, vectors, 'sqeuclidean') # [nb_classes, N]
162 | scores = dists.T # [N, nb_classes], choose the one with the smallest distance
163 |
164 | return np.argsort(scores, axis=1)[:, :self.topk], y_true # [N, topk]
165 |
166 | def _extract_vectors(self, loader):
167 | self._network.eval()
168 | vectors, targets = [], []
169 | for _, _inputs, _targets in loader:
170 | _targets = _targets.numpy()
171 |             net = self._network.module if isinstance(self._network, nn.DataParallel) else self._network
172 |             # feature extraction should not build autograd graphs
173 |             with torch.no_grad():
174 |                 _vectors = tensor2numpy(net.extract_vector(_inputs.to(self._device)))
175 |
176 | vectors.append(_vectors)
177 | targets.append(_targets)
178 |
179 | return np.concatenate(vectors), np.concatenate(targets)
180 |
181 | def _extract_vectors_aug(self, loader, repeat=2):
182 | self._network.eval()
183 | vectors, targets = [], []
184 | for _ in range(repeat):
185 | for _, _inputs, _targets in loader:
186 | _targets = _targets.numpy()
187 | with torch.no_grad():
188 | if isinstance(self._network, nn.DataParallel):
189 | _vectors = tensor2numpy(self._network.module.extract_vector(_inputs.to(self._device)))
190 | else:
191 | _vectors = tensor2numpy(self._network.extract_vector(_inputs.to(self._device)))
192 |
193 | vectors.append(_vectors)
194 | targets.append(_targets)
195 |
196 | return np.concatenate(vectors), np.concatenate(targets)
197 |
198 | def _compute_class_mean(self, data_manager, check_diff=False, oracle=False):
199 | if hasattr(self, '_class_means') and self._class_means is not None and not check_diff:
200 | ori_classes = self._class_means.shape[0]
201 | assert ori_classes==self._known_classes
202 | new_class_means = np.zeros((self._total_classes, self.feature_dim))
203 | new_class_means[:self._known_classes] = self._class_means
204 | self._class_means = new_class_means
205 | # new_class_cov = np.zeros((self._total_classes, self.feature_dim, self.feature_dim))
206 | new_class_cov = torch.zeros((self._total_classes, self.feature_dim, self.feature_dim))
207 | new_class_cov[:self._known_classes] = self._class_covs
208 | self._class_covs = new_class_cov
209 | elif not check_diff:
210 | self._class_means = np.zeros((self._total_classes, self.feature_dim))
211 | # self._class_covs = np.zeros((self._total_classes, self.feature_dim, self.feature_dim))
212 | self._class_covs = torch.zeros((self._total_classes, self.feature_dim, self.feature_dim))
213 |
214 | # self._class_covs = []
215 |
216 | if check_diff:
217 | for class_idx in range(0, self._known_classes):
218 | data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
219 | mode='test', ret_data=True)
220 | idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
221 | # vectors, _ = self._extract_vectors_aug(idx_loader)
222 | vectors, _ = self._extract_vectors(idx_loader)
223 | class_mean = np.mean(vectors, axis=0)
224 | # class_cov = np.cov(vectors.T)
225 | class_cov = torch.cov(torch.tensor(vectors, dtype=torch.float64).T)
226 | if check_diff:
227 | log_info = "cls {} sim: {}".format(class_idx, torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0)).item())
228 | logging.info(log_info)
229 | np.save('task_{}_cls_{}_mean.npy'.format(self._cur_task, class_idx), class_mean)
230 | # print(class_idx, torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0)))
231 |
232 | if oracle:
233 | for class_idx in range(0, self._known_classes):
234 | data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
235 | mode='test', ret_data=True)
236 | idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
237 | vectors, _ = self._extract_vectors(idx_loader)
238 |
239 | # vectors = np.concatenate([vectors_aug, vectors])
240 |
241 | class_mean = np.mean(vectors, axis=0)
242 | # class_cov = np.cov(vectors.T)
243 | class_cov = torch.cov(torch.tensor(vectors, dtype=torch.float64).T)+torch.eye(class_mean.shape[-1])*1e-5
244 | self._class_means[class_idx, :] = class_mean
245 | self._class_covs[class_idx, ...] = class_cov
246 |
247 | for class_idx in range(self._known_classes, self._total_classes):
248 | # data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
249 | # mode='train', ret_data=True)
250 | # idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
251 | # vectors_aug, _ = self._extract_vectors_aug(idx_loader)
252 |
253 | data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train',
254 | mode='test', ret_data=True)
255 | idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
256 | vectors, _ = self._extract_vectors(idx_loader)
257 |
258 | # vectors = np.concatenate([vectors_aug, vectors])
259 |
260 | class_mean = np.mean(vectors, axis=0)
261 | # class_cov = np.cov(vectors.T)
262 | class_cov = torch.cov(torch.tensor(vectors, dtype=torch.float64).T)+torch.eye(class_mean.shape[-1])*1e-4
263 | if check_diff:
264 | log_info = "cls {} sim: {}".format(class_idx, torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0)).item())
265 | logging.info(log_info)
266 | np.save('task_{}_cls_{}_mean.npy'.format(self._cur_task, class_idx), class_mean)
267 | np.save('task_{}_cls_{}_mean_beforetrain.npy'.format(self._cur_task, class_idx), self._class_means[class_idx, :])
268 | # print(class_idx, torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0)))
269 | self._class_means[class_idx, :] = class_mean
270 | self._class_covs[class_idx, ...] = class_cov
271 | # self._class_covs.append(class_cov)
272 |
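# [Editor's note] Hedged sketch, not part of the original file: the per-class means and
# covariances collected above are the statistics that SLCA's Classifier Alignment stage
# consumes; roughly, it replays features drawn from one Gaussian per class to realign the
# unified head without stored exemplars. Assumed shapes: class_means is a numpy array
# [C, D]; class_covs is a torch tensor [C, D, D] (jittered above to stay positive-definite).
def _ca_feature_replay_sketch(class_means, class_covs, n_per_class=256):
    from torch.distributions.multivariate_normal import MultivariateNormal
    feats, labels = [], []
    for c in range(class_means.shape[0]):
        dist = MultivariateNormal(torch.from_numpy(class_means[c]).float(),
                                  covariance_matrix=class_covs[c].float())
        feats.append(dist.sample((n_per_class,)))
        labels.append(torch.full((n_per_class,), c, dtype=torch.long))
    return torch.cat(feats), torch.cat(labels)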
--------------------------------------------------------------------------------
/models/slca.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch import optim
6 | from torch.nn import functional as F
7 | from torch.utils.data import DataLoader
8 | from models.base import BaseLearner
9 | from utils.inc_net import FinetuneIncrementalNet
10 | from torchvision import transforms
11 | from torch.distributions.multivariate_normal import MultivariateNormal
12 | import random
13 | from utils.toolkit import tensor2numpy, accuracy
14 | import copy
15 | import os
16 |
17 | epochs = 20
18 | lrate = 0.01
19 | milestones = [60,100,140]
20 | lrate_decay = 0.1
21 | batch_size = 128
22 | split_ratio = 0.1
23 | T = 2
24 | weight_decay = 5e-4
25 | num_workers = 8
26 | ca_epochs = 5
27 |
28 |
29 | class SLCA(BaseLearner):
30 | def __init__(self, args):
31 | super().__init__(args)
32 | self._network = FinetuneIncrementalNet(args, pretrained=True)
33 | self.log_path = "logs/{}_{}".format(args['model_name'], args['model_postfix'])
34 | self.model_prefix = args['prefix']
35 | if 'epochs' in args.keys():
36 | global epochs
37 | epochs = args['epochs']
38 | if 'milestones' in args.keys():
39 | global milestones
40 | milestones = args['milestones']
41 | if 'lr' in args.keys():
42 | global lrate
43 | lrate = args['lr']
44 | print('set lr to ', lrate)
45 | if 'bcb_lrscale' in args.keys():
46 | self.bcb_lrscale = args['bcb_lrscale']
47 | else:
48 | self.bcb_lrscale = 1.0/100
49 | if self.bcb_lrscale == 0:
50 | self.fix_bcb = True
51 | else:
52 | self.fix_bcb = False
53 |         print('fix_bcb', self.fix_bcb)
54 |
55 |
56 |
57 | if 'save_before_ca' in args.keys() and args['save_before_ca']:
58 | self.save_before_ca = True
59 | else:
60 | self.save_before_ca = False
61 |
62 | if 'ca_epochs' in args.keys():
63 | global ca_epochs
64 | ca_epochs = args['ca_epochs']
65 |
66 | if 'ca_with_logit_norm' in args.keys() and args['ca_with_logit_norm']>0:
67 | self.logit_norm = args['ca_with_logit_norm']
68 | else:
69 | self.logit_norm = None
70 |
71 | self.run_id = args['run_id']
72 | self.seed = args['seed']
73 | self.task_sizes = []
74 |
75 | def after_task(self):
76 | self._known_classes = self._total_classes
77 | logging.info('Exemplar size: {}'.format(self.exemplar_size))
78 | self.save_checkpoint(self.log_path+'/'+self.model_prefix+'_seed{}'.format(self.seed), head_only=self.fix_bcb)
79 | self._network.fc.recall()
80 |
81 | def incremental_train(self, data_manager):
82 | self._cur_task += 1
83 | task_size = data_manager.get_task_size(self._cur_task)
84 | self.task_sizes.append(task_size)
85 |         self._total_classes = self._known_classes + task_size
86 |         self.topk = self._total_classes if self._total_classes<5 else 5
87 |         self._network.update_fc(task_size)
88 | logging.info('Learning on {}-{}'.format(self._known_classes, self._total_classes))
89 |
90 | self._network.to(self._device)
91 |
92 | train_dset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes),
93 | source='train', mode='train',
94 | appendent=[], with_raw=False)
95 | test_dset = data_manager.get_dataset(np.arange(0, self._total_classes), source='test', mode='test')
96 | dset_name = data_manager.dataset_name.lower()
97 |
98 | self.train_loader = DataLoader(train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
99 | self.test_loader = DataLoader(test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
100 |
101 | self._stage1_training(self.train_loader, self.test_loader)
102 |
103 | if len(self._multiple_gpus) > 1:
104 | self._network = self._network.module
105 |
106 | # CA
107 | self._network.fc.backup()
108 | if self.save_before_ca:
109 | self.save_checkpoint(self.log_path+'/'+self.model_prefix+'_seed{}_before_ca'.format(self.seed), head_only=self.fix_bcb)
110 |
111 | self._compute_class_mean(data_manager, check_diff=False, oracle=False)
112 | if self._cur_task>0 and ca_epochs>0:
113 | self._stage2_compact_classifier(task_size)
114 | if len(self._multiple_gpus) > 1:
115 | self._network = self._network.module
116 |
117 |
118 | def _run(self, train_loader, test_loader, optimizer, scheduler):
119 | run_epochs = epochs
120 | for epoch in range(1, run_epochs+1):
121 | self._network.train()
122 | losses = 0.
123 | for i, (_, inputs, targets) in enumerate(train_loader):
124 | inputs, targets = inputs.to(self._device), targets.to(self._device)
125 |
126 | logits = self._network(inputs, bcb_no_grad=self.fix_bcb)['logits']
127 | cur_targets = torch.where(targets-self._known_classes>=0,targets-self._known_classes,-100)
128 | loss = F.cross_entropy(logits[:, self._known_classes:], cur_targets)
129 |
130 | optimizer.zero_grad()
131 | loss.backward()
132 | optimizer.step()
133 | losses += loss.item()
134 |
135 | scheduler.step()
136 | if epoch%5==0:
137 | train_acc = self._compute_accuracy(self._network, train_loader)
138 | test_acc = self._compute_accuracy(self._network, test_loader)
139 | info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.3f}, Test_accy {:.3f}'.format(
140 | self._cur_task, epoch, epochs, losses/len(train_loader), train_acc, test_acc)
141 | else:
142 | info = 'Task {}, Epoch {}/{} => Loss {:.3f}'.format(
143 | self._cur_task, epoch, epochs, losses/len(train_loader))
144 | logging.info(info)
145 |
146 | def _stage1_training(self, train_loader, test_loader):
147 | '''
148 | if self._cur_task == 0:
149 | loaded_dict = torch.load('./dict_0.pkl')
150 | self._network.load_state_dict(loaded_dict['model_state_dict'])
151 | self._network.to(self._device)
152 | return
153 | '''
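    |         # Slow Learner: the backbone trains at lrate*self.bcb_lrscale (1/100 by
    |         # default), while the classification head keeps the full learning rate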
154 | base_params = self._network.convnet.parameters()
155 | base_fc_params = [p for p in self._network.fc.parameters() if p.requires_grad==True]
156 |         head_scale = 1.
157 | if not self.fix_bcb:
158 | base_params = {'params': base_params, 'lr': lrate*self.bcb_lrscale, 'weight_decay': weight_decay}
159 | base_fc_params = {'params': base_fc_params, 'lr': lrate*head_scale, 'weight_decay': weight_decay}
160 | network_params = [base_params, base_fc_params]
161 | else:
162 | for p in base_params:
163 | p.requires_grad = False
164 | network_params = [{'params': base_fc_params, 'lr': lrate*head_scale, 'weight_decay': weight_decay}]
165 | optimizer = optim.SGD(network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay)
166 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=lrate_decay)
167 |
168 | if len(self._multiple_gpus) > 1:
169 | self._network = nn.DataParallel(self._network, self._multiple_gpus)
170 |
171 | self._run(train_loader, test_loader, optimizer, scheduler)
172 |
173 |
174 | def _stage2_compact_classifier(self, task_size):
175 | for p in self._network.fc.parameters():
176 | p.requires_grad=True
177 |
178 | run_epochs = ca_epochs
179 | crct_num = self._total_classes
180 | param_list = [p for p in self._network.fc.parameters() if p.requires_grad]
181 | network_params = [{'params': param_list, 'lr': lrate,
182 | 'weight_decay': weight_decay}]
183 | optimizer = optim.SGD(network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay)
184 | # scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[4], gamma=lrate_decay)
185 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=run_epochs)
186 |
187 | self._network.to(self._device)
188 | if len(self._multiple_gpus) > 1:
189 | self._network = nn.DataParallel(self._network, self._multiple_gpus)
190 |
191 | self._network.eval()
192 | for epoch in range(run_epochs):
193 | losses = 0.
194 |
195 | sampled_data = []
196 | sampled_label = []
197 | num_sampled_pcls = 256
198 |
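    |             # Classifier Alignment: sample pseudo-features for every class seen so
    |             # far from a class-wise Gaussian (stored mean and covariance), 256 per class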
199 | for c_id in range(crct_num):
200 | t_id = c_id//task_size
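    |                 # scale the stored mean by 0.9 + 0.1*(t_id+1)/(cur_task+1): means of
    |                 # older tasks are shrunk slightly, the current task keeps a factor of 1.0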
201 | decay = (t_id+1)/(self._cur_task+1)*0.1
202 | cls_mean = torch.tensor(self._class_means[c_id], dtype=torch.float64).to(self._device)*(0.9+decay) # torch.from_numpy(self._class_means[c_id]).to(self._device)
203 | cls_cov = self._class_covs[c_id].to(self._device)
204 |
205 | m = MultivariateNormal(cls_mean.float(), cls_cov.float())
206 |
207 | sampled_data_single = m.sample(sample_shape=(num_sampled_pcls,))
208 | sampled_data.append(sampled_data_single)
209 | sampled_label.extend([c_id]*num_sampled_pcls)
210 |
211 | sampled_data = torch.cat(sampled_data, dim=0).float().to(self._device)
212 | sampled_label = torch.tensor(sampled_label).long().to(self._device)
213 |
214 | inputs = sampled_data
215 |             targets = sampled_label
216 |
217 | sf_indexes = torch.randperm(inputs.size(0))
218 | inputs = inputs[sf_indexes]
219 | targets = targets[sf_indexes]
220 |
221 |
222 | for _iter in range(crct_num):
223 | inp = inputs[_iter*num_sampled_pcls:(_iter+1)*num_sampled_pcls]
224 | tgt = targets[_iter*num_sampled_pcls:(_iter+1)*num_sampled_pcls]
225 | outputs = self._network(inp, bcb_no_grad=True, fc_only=True)
226 | logits = outputs['logits']
227 |
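    |                 # logit normalization: divide logits by the mean of the per-task L2
    |                 # norms, then apply the temperature self.logit_norm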
228 | if self.logit_norm is not None:
229 | per_task_norm = []
230 | prev_t_size = 0
231 | cur_t_size = 0
232 | for _ti in range(self._cur_task+1):
233 | cur_t_size += self.task_sizes[_ti]
234 | temp_norm = torch.norm(logits[:, prev_t_size:cur_t_size], p=2, dim=-1, keepdim=True) + 1e-7
235 | per_task_norm.append(temp_norm)
236 | prev_t_size += self.task_sizes[_ti]
237 | per_task_norm = torch.cat(per_task_norm, dim=-1)
238 | norms = per_task_norm.mean(dim=-1, keepdim=True)
239 |
240 |                     norms_all = torch.norm(logits[:, :crct_num], p=2, dim=-1, keepdim=True) + 1e-7  # computed but never used
241 | decoupled_logits = torch.div(logits[:, :crct_num], norms) / self.logit_norm
242 | loss = F.cross_entropy(decoupled_logits, tgt)
243 |
244 | else:
245 | loss = F.cross_entropy(logits[:, :crct_num], tgt)
246 |
247 | optimizer.zero_grad()
248 | loss.backward()
249 | optimizer.step()
250 | losses += loss.item()
251 |
252 | scheduler.step()
253 | test_acc = self._compute_accuracy(self._network, self.test_loader)
254 | info = 'CA Task {} => Loss {:.3f}, Test_accy {:.3f}'.format(
255 | self._cur_task, losses/self._total_classes, test_acc)
256 | logging.info(info)
257 |
258 |
259 |
--------------------------------------------------------------------------------
/models/slca_pp.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch import optim
6 | from torch.nn import functional as F
7 | from torch.utils.data import DataLoader
8 | from models.slca import SLCA
9 | from utils.inc_net import FinetuneIncrementalNet
10 | from torchvision import transforms
11 | from torch.distributions.multivariate_normal import MultivariateNormal
12 | import random
13 | from utils.toolkit import tensor2numpy, accuracy
14 | import copy
15 | import os
16 |
17 | epochs = 20
18 | lrate = 0.01
19 | milestones = [60,100,140]
20 | lrate_decay = 0.1
21 | batch_size = 128
22 | split_ratio = 0.1
23 | T = 2
24 | weight_decay = 5e-4
25 | num_workers = 8
26 | ca_epochs = 5
27 |
28 |
29 | class SLCApp(SLCA):
30 | def __init__(self, args):
31 | super().__init__(args)
32 | self.model_name_ = args['model_name']
33 | self.sce_a, self.sce_b = args['sce_a'], args['sce_b']
34 | assert 'fc_scale_mu' in args.keys() and args['fc_scale_mu']>0
35 | self.fc_scale_mu = args['fc_scale_mu']
36 |
37 | # hybrid-SL
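    |         # parameter-efficient Seq FT: only norm, bias, cls_token and LoRA parameters
    |         # of the backbone stay trainable; all other pre-trained weights are frozen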
38 | for n, p in self._network.convnet.named_parameters():
39 | if 'norm' in n or 'bias' in n or 'cls_token' in n or 'lora' in n:
40 | p.requires_grad=True
41 | print(n)
42 | else:
43 | p.requires_grad=False
44 |
45 | # reset global values according to args
46 | if 'weight_decay' in args.keys():
47 | global weight_decay
48 | weight_decay = args['weight_decay']
49 | if 'epochs' in args.keys():
50 | global epochs
51 | epochs = args['epochs']
52 | if 'milestones' in args.keys():
53 | global milestones
54 | milestones = args['milestones']
55 | if 'lr' in args.keys():
56 | global lrate
57 | lrate = args['lr']
58 | if 'ca_epochs' in args.keys():
59 | global ca_epochs
60 | ca_epochs = args['ca_epochs']
61 |
62 | def after_task(self):
63 | self._known_classes = self._total_classes
64 | logging.info('Exemplar size: {}'.format(self.exemplar_size))
65 | self.save_checkpoint(self.log_path+'/'+self.model_prefix+'_seed{}'.format(self.seed), head_only=self.fix_bcb, learnable_only=True)
66 | self._network.fc.recall()
67 |
68 | def incremental_train(self, data_manager):
69 | self._cur_task += 1
70 | task_size = data_manager.get_task_size(self._cur_task)
71 | self.task_sizes.append(task_size)
72 | self._total_classes = self._known_classes + task_size
73 | self.topk = self._total_classes if self._total_classes<5 else 5
74 | self._network.update_fc(task_size)
75 | logging.info('Learning on {}-{}'.format(self._known_classes, self._total_classes))
76 |
77 | self._network.to(self._device)
78 |
79 | if self._cur_task==0 and self._network.convnet.lora_rank>0:
80 | for b_idx in range(self._network.convnet.lora_lp):
81 | self._network.convnet.blocks[b_idx].mlp.init_lora()
82 | self._network.convnet.blocks[b_idx].attn.init_lora()
83 |
84 | train_dset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes),
85 | source='train', mode='train',
86 | appendent=[], with_raw=False)
87 | test_dset = data_manager.get_dataset(np.arange(0, self._total_classes), source='test', mode='test')
88 | dset_name = data_manager.dataset_name.lower()
89 |
90 | self.train_loader = DataLoader(train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
91 | self.test_loader = DataLoader(test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
92 |
93 | self._stage1_training(self.train_loader, self.test_loader)
94 |
95 | if len(self._multiple_gpus) > 1:
96 | self._network = self._network.module
97 |
98 | # CA
99 | self._network.fc.backup()
100 | if self.save_before_ca:
101 | self.save_checkpoint(self.log_path+'/'+self.model_prefix+'_seed{}_before_ca'.format(self.seed), head_only=self.fix_bcb)
102 |
103 | self._compute_class_mean(data_manager, check_diff=False, oracle=False)
104 | if self._cur_task>0 and ca_epochs>0:
105 | self._stage2_compact_classifier(task_size)
106 | if len(self._multiple_gpus) > 1:
107 | self._network = self._network.module
108 | self._network.fc.update_scale()
109 |
110 |
111 | def _run(self, train_loader, test_loader, optimizer, scheduler):
112 | run_epochs = epochs
113 | for epoch in range(1, run_epochs+1):
114 | self._network.train()
115 | losses = 0.
116 | for i, (_, inputs, targets) in enumerate(train_loader):
117 | inputs, targets = inputs.to(self._device), targets.to(self._device)
118 |
119 | logits = self._network(inputs, bcb_no_grad=self.fix_bcb)['logits']
120 | cur_targets = torch.where(targets-self._known_classes>=0,targets-self._known_classes,-100)
121 | cur_logits = logits[:, self._known_classes:]
122 |
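    |                 # symmetric cross-entropy (SCE): standard CE plus a reverse-CE term in
    |                 # which prediction and (clamped) one-hot label swap roles, weighted by sce_a/sce_b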
123 | ce_loss = F.cross_entropy(cur_logits, cur_targets)
124 | pred = F.softmax(cur_logits, dim=1)
125 | pred = torch.clamp(pred, min=1e-7, max=1.0)
126 | label_one_hot = torch.nn.functional.one_hot(cur_targets, pred.size(1)).float().to(pred.device)
127 | label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0)
128 | rce_loss = (-1*torch.sum(pred * torch.log(label_one_hot), dim=1))
129 | loss = self.sce_a*ce_loss + self.sce_b*rce_loss.mean()
130 |
131 | optimizer.zero_grad()
132 | loss.backward()
133 | optimizer.step()
134 | losses += loss.item()
135 |
136 | scheduler.step()
137 | if epoch%5==0:
138 | train_acc = self._compute_accuracy(self._network, train_loader)
139 | test_acc = self._compute_accuracy(self._network, test_loader)
140 | info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.3f}, Test_accy {:.3f}'.format(
141 | self._cur_task, epoch, epochs, losses/len(train_loader), train_acc, test_acc)
142 | else:
143 | info = 'Task {}, Epoch {}/{} => Loss {:.3f}'.format(
144 | self._cur_task, epoch, epochs, losses/len(train_loader))
145 | logging.info(info)
146 |
147 | def _stage1_training(self, train_loader, test_loader):
148 | if self._network.convnet.lora_rank>0:
149 | base_loraA_params = [p for n, p in self._network.convnet.named_parameters() if p.requires_grad==True and 'lora_A' in n]
150 | base_loraB_params = [p for n, p in self._network.convnet.named_parameters() if p.requires_grad==True and 'lora_B' in n]
151 | base_others_params = [p for n, p in self._network.convnet.named_parameters() if p.requires_grad==True and 'lora' not in n]
152 | else:
153 | base_params = self._network.convnet.parameters()
154 |
155 | base_fc_params = [p for p in self._network.fc.parameters() if p.requires_grad==True]
156 | head_scale = 1.
157 | if not self.fix_bcb:
158 | base_fc_params = {'params': base_fc_params, 'lr': lrate*head_scale, 'weight_decay': weight_decay}
159 | if self._network.convnet.lora_rank>0:
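    |                 # LoRA rates: 0.5*lrate on the first task, the slow backbone rate
    |                 # afterwards; other unfrozen backbone params always use bcb_lrscale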
160 | lora_scale = 0.5 if self._cur_task==0 else self.bcb_lrscale
161 | base_loraA_params = {'params': base_loraA_params, 'lr': lrate*lora_scale, 'weight_decay': weight_decay}
162 | base_loraB_params = {'params': base_loraB_params, 'lr': lrate*lora_scale, 'weight_decay': weight_decay}
163 | base_others_params = {'params': base_others_params, 'lr': lrate*self.bcb_lrscale, 'weight_decay': weight_decay}
164 | network_params = [base_loraA_params, base_loraB_params, base_others_params, base_fc_params]
165 | else:
166 | base_params = {'params': base_params, 'lr': lrate*self.bcb_lrscale, 'weight_decay': weight_decay}
167 | network_params = [base_params, base_fc_params]
168 | else:
169 | for p in base_params:
170 | p.requires_grad = False
171 | network_params = [{'params': base_fc_params, 'lr': lrate*head_scale, 'weight_decay': weight_decay}]
172 | optimizer = optim.SGD(network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay)
173 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=epochs, eta_min=lrate*self.bcb_lrscale*lrate_decay**len(milestones))
174 |
175 | if len(self._multiple_gpus) > 1:
176 | self._network = nn.DataParallel(self._network, self._multiple_gpus)
177 |
178 | self._run(train_loader, test_loader, optimizer, scheduler)
179 |
180 |
181 | def _stage2_compact_classifier(self, task_size):
182 | for n, p in self._network.fc.named_parameters():
183 | p.requires_grad=True
184 | if 'scales' in n:
185 | p.requires_grad=False
186 | print('fixed ', n)
187 |
188 | run_epochs = ca_epochs
189 | crct_num = self._total_classes
190 | param_list = [p for p in self._network.fc.parameters() if p.requires_grad]
191 | network_params = [{'params': param_list, 'lr': lrate,
192 | 'weight_decay': weight_decay}]
193 | optimizer = optim.SGD(network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay)
194 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=run_epochs)
195 |
196 | self._network.to(self._device)
197 | if len(self._multiple_gpus) > 1:
198 | self._network = nn.DataParallel(self._network, self._multiple_gpus)
199 |
200 | self._network.eval()
201 | for epoch in range(run_epochs):
202 | losses = 0.
203 |
204 | sampled_data = []
205 | sampled_label = []
206 | num_sampled_pcls = 256
207 |
208 | for c_id in range(crct_num):
209 | t_id = c_id//task_size
210 | cls_mean = torch.tensor(self._class_means[c_id], dtype=torch.float64).to(self._device)
211 | cls_cov = self._class_covs[c_id].to(self._device)
212 |
213 | m = MultivariateNormal(cls_mean.float(), cls_cov.float())
214 |
215 | sampled_data_single = m.sample(sample_shape=(num_sampled_pcls,))
216 |
217 | # a more robust implementation of feature scaling
218 | # fixed scaling during inference, dynamic during training
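    |                 # scale = 1/(1 + mu*(task gap)) with a +-0.01 jitter around fc_scale_mu,
    |                 # so features of older classes are shrunk more than recent ones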
219 | rand_scaling = self.fc_scale_mu + 0.02*(torch.rand(sampled_data_single.size(0), device=self._device)-0.5)
220 | rand_scaling = 1/(1+rand_scaling*(self._cur_task-t_id))
221 | sampled_data_single = rand_scaling.unsqueeze(1)*sampled_data_single
222 |
223 | sampled_data.append(sampled_data_single)
224 | sampled_label.extend([c_id]*num_sampled_pcls)
225 |
226 | sampled_data = torch.cat(sampled_data, dim=0).float().to(self._device)
227 | sampled_label = torch.tensor(sampled_label).long().to(self._device)
228 |
229 | inputs = sampled_data
230 |             targets = sampled_label
231 |
232 | sf_indexes = torch.randperm(inputs.size(0))
233 | inputs = inputs[sf_indexes]
234 | targets = targets[sf_indexes]
235 |
236 |
237 | for _iter in range(crct_num):
238 | inp = inputs[_iter*num_sampled_pcls:(_iter+1)*num_sampled_pcls]
239 | tgt = targets[_iter*num_sampled_pcls:(_iter+1)*num_sampled_pcls]
240 | outputs = self._network(inp, bcb_no_grad=True, fc_only=True)
241 | logits = outputs['logits']
242 |
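    |                 # per-task logit normalization, as in SLCA._stage2_compact_classifier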
243 | if self.logit_norm is not None:
244 | per_task_norm = []
245 | prev_t_size = 0
246 | cur_t_size = 0
247 | for _ti in range(self._cur_task+1):
248 | cur_t_size += self.task_sizes[_ti]
249 | temp_norm = torch.norm(logits[:, prev_t_size:cur_t_size], p=2, dim=-1, keepdim=True) + 1e-7
250 | per_task_norm.append(temp_norm)
251 | prev_t_size += self.task_sizes[_ti]
252 | per_task_norm = torch.cat(per_task_norm, dim=-1)
253 | norms = per_task_norm.mean(dim=-1, keepdim=True)
254 |
255 |                     norms_all = torch.norm(logits[:, :crct_num], p=2, dim=-1, keepdim=True) + 1e-7  # computed but never used
256 | decoupled_logits = torch.div(logits[:, :crct_num], norms) / self.logit_norm
257 | loss = F.cross_entropy(decoupled_logits, tgt)
258 |
259 | else:
260 | loss = F.cross_entropy(logits[:, :crct_num], tgt)
261 |
262 | optimizer.zero_grad()
263 | loss.backward()
264 | optimizer.step()
265 | losses += loss.item()
266 |
267 | scheduler.step()
268 | test_acc = self._compute_accuracy(self._network, self.test_loader)
269 | info = 'CA Task {} => Loss {:.3f}, Test_accy {:.3f}'.format(
270 | self._cur_task, losses/self._total_classes, test_acc)
271 | logging.info(info)
272 |
273 |
274 |
--------------------------------------------------------------------------------
/slca_performance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GengDavid/SLCA/169d4e6b91e3ca30caa717200c1944981c4426d1/slca_performance.jpg
--------------------------------------------------------------------------------
/split_car.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import os.path as osp
4 | import shutil
5 | from tqdm import tqdm
6 |
7 | img_list = np.genfromtxt('mat2txt.txt', dtype=str) # [num, img, cls, istest]
8 | class_mappings = np.genfromtxt('label_map.txt', dtype=str)
9 | class_mappings = {a[0]: a[1] for a in class_mappings}
10 |
11 | for item in tqdm(img_list):
12 |     # copy each image into cars196/<split>/<class>/, creating folders as needed
13 |     split = 'train' if bool(int(item[-1])) else 'val'
14 |     cls_folder = osp.join('cars196', split, class_mappings[item[2]])
15 |     os.makedirs(cls_folder, exist_ok=True)
16 |     shutil.copy(item[1], osp.join(cls_folder, item[1].split('/')[-1]))
17 | 
--------------------------------------------------------------------------------
/split_cub.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import os.path as osp
4 | import shutil
5 | from tqdm import tqdm
6 |
7 | train_val_list = np.genfromtxt('train_test_split.txt', dtype='str')
8 | img_list = np.genfromtxt('images.txt', dtype='str')
9 |
10 | img_id_mapping = {a[0]: a[1] for a in img_list}
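    | # assumes CUB200/train and CUB200/val both start as full copies of the image
    | # folder; each image is then deleted from the split it does not belong to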
11 | for img, is_train in tqdm(train_val_list):
12 | if bool(int(is_train)):
13 | # print(osp.join('CUB200', 'val', img_id_mapping[img]))
14 | os.remove(osp.join('CUB200', 'val', img_id_mapping[img]))
15 | else:
16 | os.remove(osp.join('CUB200', 'train', img_id_mapping[img]))
17 |
--------------------------------------------------------------------------------
/train_all.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python3 main.py --config=exps/slcapp/slcapp_cifar_lora.json &
2 | CUDA_VISIBLE_DEVICES=1 python3 main.py --config=exps/slcapp/slcapp_imgnetr_lora.json &
3 | CUDA_VISIBLE_DEVICES=2 python3 main.py --config=exps/slcapp/slcapp_cub_lora.json &
4 | CUDA_VISIBLE_DEVICES=3 python3 main.py --config=exps/slcapp/slcapp_cars_lora.json &
5 | wait  # start the MoCo v3 runs only after the first four jobs have finished
6 | CUDA_VISIBLE_DEVICES=0 python3 main.py --config=exps/slcapp/slcapp_cifar_lora_mocov3.json &
7 | CUDA_VISIBLE_DEVICES=1 python3 main.py --config=exps/slcapp/slcapp_imgnetr_lora_mocov3.json &
8 | CUDA_VISIBLE_DEVICES=2 python3 main.py --config=exps/slcapp/slcapp_cub_lora_mocov3.json &
9 | CUDA_VISIBLE_DEVICES=3 python3 main.py --config=exps/slcapp/slcapp_cars_lora_mocov3.json &
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 | import copy
4 | import torch
5 | from utils import factory
6 | from utils.data_manager import DataManager
7 | from utils.toolkit import count_parameters
8 | import os
9 | import numpy as np
10 |
11 | def train(args):
12 | seed_list = copy.deepcopy(args['seed'])
13 | device = copy.deepcopy(args['device'])
14 |
15 | res_finals, res_avgs = [], []
16 | for run_id, seed in enumerate(seed_list):
17 | args['seed'] = seed
18 | args['run_id'] = run_id
19 | args['device'] = device
20 | res_final, res_avg = _train(args)
21 | res_finals.append(res_final)
22 | res_avgs.append(res_avg)
23 | logging.info('final accs: {}'.format(res_finals))
24 | logging.info('avg accs: {}'.format(res_avgs))
25 |
26 |
27 |
28 | def _train(args):
29 |     try:
30 |         os.makedirs("logs/{}_{}".format(args['model_name'], args['model_postfix']))
31 |     except FileExistsError:
32 |         pass
33 | logfilename = 'logs/{}_{}/{}_{}_{}_{}_{}_{}_{}'.format(args['model_name'], args['model_postfix'], args['prefix'], args['seed'], args['model_name'], args['convnet_type'],
34 | args['dataset'], args['init_cls'], args['increment'])
35 | logging.basicConfig(
36 | level=logging.INFO,
37 | format='%(asctime)s [%(filename)s] => %(message)s',
38 | handlers=[
39 | logging.FileHandler(filename=logfilename + '.log'),
40 | logging.StreamHandler(sys.stdout)
41 | ]
42 | )
43 |
44 | _set_random()
45 | _set_device(args)
46 | print_args(args)
47 | data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment'])
48 | model = factory.get_model(args['model_name'], args)
49 |
50 | cnn_curve, nme_curve = {'top1': [], 'top5': []}, {'top1': [], 'top5': []}
51 | for task in range(data_manager.nb_tasks):
52 | logging.info('All params: {}'.format(count_parameters(model._network)))
53 | logging.info('Trainable params: {}'.format(count_parameters(model._network, True)))
54 | model.incremental_train(data_manager)
55 |
56 | cnn_accy, nme_accy = model.eval_task()
57 | model.after_task()
58 |
59 |
60 | if nme_accy is not None:
61 | logging.info('CNN: {}'.format(cnn_accy['grouped']))
62 | logging.info('NME: {}'.format(nme_accy['grouped']))
63 |
64 | cnn_curve['top1'].append(cnn_accy['top1'])
65 | cnn_curve['top5'].append(cnn_accy['top5'])
66 |
67 | nme_curve['top1'].append(nme_accy['top1'])
68 | nme_curve['top5'].append(nme_accy['top5'])
69 |
70 | logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
71 | logging.info('CNN top1 avg: {}'.format(np.array(cnn_curve['top1']).mean()))
72 | if 'task_acc' in cnn_accy.keys():
73 | logging.info('Task: {}'.format(cnn_accy['task_acc']))
74 | logging.info('CNN top5 curve: {}'.format(cnn_curve['top5']))
75 | logging.info('NME top1 curve: {}'.format(nme_curve['top1']))
76 | logging.info('NME top5 curve: {}\n'.format(nme_curve['top5']))
77 | else:
78 | logging.info('No NME accuracy.')
79 | logging.info('CNN: {}'.format(cnn_accy['grouped']))
80 |
81 | cnn_curve['top1'].append(cnn_accy['top1'])
82 | cnn_curve['top5'].append(cnn_accy['top5'])
83 |
84 | logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
85 | logging.info('CNN top1 avg: {}'.format(np.array(cnn_curve['top1']).mean()))
86 | logging.info('CNN top5 curve: {}\n'.format(cnn_curve['top5']))
87 |
88 | return (cnn_curve['top1'][-1], np.array(cnn_curve['top1']).mean())
89 |
90 | def _set_device(args):
91 | device_type = args['device']
92 | gpus = []
93 |
94 | for device in device_type:
95 |         if device == -1:
96 | device = torch.device('cpu')
97 | else:
98 | device = torch.device('cuda:{}'.format(device))
99 |
100 | gpus.append(device)
101 |
102 | args['device'] = gpus
103 |
104 |
105 | def _set_random():
106 | torch.manual_seed(1)
107 | torch.cuda.manual_seed(1)
108 | torch.cuda.manual_seed_all(1)
109 | torch.backends.cudnn.deterministic = True
110 | torch.backends.cudnn.benchmark = False
111 |
112 |
113 | def print_args(args):
114 | for key, value in args.items():
115 | logging.info('{}: {}'.format(key, value))
116 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GengDavid/SLCA/169d4e6b91e3ca30caa717200c1944981c4426d1/utils/__init__.py
--------------------------------------------------------------------------------
/utils/buffer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import numpy as np
8 | from typing import Tuple
9 | from torchvision import transforms
10 | from copy import deepcopy
11 |
12 | def icarl_replay(self, dataset, val_set_split=0):
13 | """
14 | Merge the replay buffer with the current task data.
15 | Optionally split the replay buffer into a validation set.
16 |
17 | :param self: the model instance
18 | :param dataset: the dataset
19 | :param val_set_split: the fraction of the replay buffer to be used as validation set
20 | """
21 |
22 | if self.task > 0:
23 | buff_val_mask = torch.rand(len(self.buffer)) < val_set_split
24 | val_train_mask = torch.zeros(len(dataset.train_loader.dataset.data)).bool()
25 | val_train_mask[torch.randperm(len(dataset.train_loader.dataset.data))[:buff_val_mask.sum()]] = True
26 |
27 | if val_set_split > 0:
28 | self.val_loader = deepcopy(dataset.train_loader)
29 |
30 | data_concatenate = torch.cat if type(dataset.train_loader.dataset.data) == torch.Tensor else np.concatenate
31 | need_aug = hasattr(dataset.train_loader.dataset, 'not_aug_transform')
32 | if not need_aug:
33 | refold_transform = lambda x: x.cpu()
34 | else:
35 | data_shape = len(dataset.train_loader.dataset.data[0].shape)
36 | if data_shape == 3:
37 | refold_transform = lambda x: (x.cpu()*255).permute([0, 2, 3, 1]).numpy().astype(np.uint8)
38 | elif data_shape == 2:
39 | refold_transform = lambda x: (x.cpu()*255).squeeze(1).type(torch.uint8)
40 |
41 | # REDUCE AND MERGE TRAINING SET
42 | dataset.train_loader.dataset.targets = np.concatenate([
43 | dataset.train_loader.dataset.targets[~val_train_mask],
44 | self.buffer.labels.cpu().numpy()[:len(self.buffer)][~buff_val_mask]
45 | ])
46 | dataset.train_loader.dataset.data = data_concatenate([
47 | dataset.train_loader.dataset.data[~val_train_mask],
48 | refold_transform((self.buffer.examples)[:len(self.buffer)][~buff_val_mask])
49 | ])
50 |
51 | if val_set_split > 0:
52 | # REDUCE AND MERGE VALIDATION SET
53 | self.val_loader.dataset.targets = np.concatenate([
54 | self.val_loader.dataset.targets[val_train_mask],
55 | self.buffer.labels.cpu().numpy()[:len(self.buffer)][buff_val_mask]
56 | ])
57 | self.val_loader.dataset.data = data_concatenate([
58 | self.val_loader.dataset.data[val_train_mask],
59 | refold_transform((self.buffer.examples)[:len(self.buffer)][buff_val_mask])
60 | ])
61 |
62 | def reservoir(num_seen_examples: int, buffer_size: int) -> int:
63 | """
64 | Reservoir sampling algorithm.
65 | :param num_seen_examples: the number of seen examples
66 | :param buffer_size: the maximum buffer size
67 | :return: the target index if the current image is sampled, else -1
68 | """
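    |     # a new example is stored with probability buffer_size/(num_seen_examples+1),
    |     # so every example in the stream is retained with equal probability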
69 | if num_seen_examples < buffer_size:
70 | return num_seen_examples
71 |
72 | rand = np.random.randint(0, num_seen_examples + 1)
73 | if rand < buffer_size:
74 | return rand
75 | else:
76 | return -1
77 |
78 |
79 | def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int:
80 | return num_seen_examples % buffer_portion_size + task * buffer_portion_size
81 |
82 |
83 | class Buffer:
84 | """
85 | The memory buffer of rehearsal method.
86 | """
87 | def __init__(self, buffer_size, device, n_tasks=None, mode='reservoir'):
88 | assert mode in ['ring', 'reservoir']
89 | self.buffer_size = buffer_size
90 | self.device = device
91 | self.num_seen_examples = 0
92 | self.functional_index = eval(mode)
93 | if mode == 'ring':
94 | assert n_tasks is not None
95 | self.task_number = n_tasks
96 | self.buffer_portion_size = buffer_size // n_tasks
97 | self.attributes = ['examples', 'labels', 'logits', 'task_labels']
98 |
99 | def to(self, device):
100 | self.device = device
101 | for attr_str in self.attributes:
102 | if hasattr(self, attr_str):
103 | setattr(self, attr_str, getattr(self, attr_str).to(device))
104 | return self
105 |
106 | def __len__(self):
107 | return min(self.num_seen_examples, self.buffer_size)
108 |
109 |
110 | def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor,
111 | logits: torch.Tensor, task_labels: torch.Tensor) -> None:
112 | """
113 | Initializes just the required tensors.
114 | :param examples: tensor containing the images
115 | :param labels: tensor containing the labels
116 | :param logits: tensor containing the outputs of the network
117 | :param task_labels: tensor containing the task labels
118 | """
119 | for attr_str in self.attributes:
120 | attr = eval(attr_str)
121 | if attr is not None and not hasattr(self, attr_str):
122 | typ = torch.int64 if attr_str.endswith('els') else torch.float32
123 | setattr(self, attr_str, torch.zeros((self.buffer_size,
124 | *attr.shape[1:]), dtype=typ, device=self.device))
125 |
126 | def add_data(self, examples, labels=None, logits=None, task_labels=None):
127 | """
128 | Adds the data to the memory buffer according to the reservoir strategy.
129 | :param examples: tensor containing the images
130 | :param labels: tensor containing the labels
131 | :param logits: tensor containing the outputs of the network
132 | :param task_labels: tensor containing the task labels
133 | :return:
134 | """
135 | if not hasattr(self, 'examples'):
136 | self.init_tensors(examples, labels, logits, task_labels)
137 |
138 | for i in range(examples.shape[0]):
139 | index = reservoir(self.num_seen_examples, self.buffer_size)
140 | self.num_seen_examples += 1
141 | if index >= 0:
142 | self.examples[index] = examples[i].to(self.device)
143 | if labels is not None:
144 | self.labels[index] = labels[i].to(self.device)
145 | if logits is not None:
146 | self.logits[index] = logits[i].to(self.device)
147 | if task_labels is not None:
148 | self.task_labels[index] = task_labels[i].to(self.device)
149 |
150 | def get_data(self, size: int, transform: transforms=None, return_index=False) -> Tuple:
151 | """
152 | Random samples a batch of size items.
153 | :param size: the number of requested items
154 | :param transform: the transformation to be applied (data augmentation)
155 | :return:
156 | """
157 | if size > min(self.num_seen_examples, self.examples.shape[0]):
158 | size = min(self.num_seen_examples, self.examples.shape[0])
159 |
160 | choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]),
161 | size=size, replace=False)
162 | if transform is None: transform = lambda x: x
163 | ret_tuple = (torch.stack([transform(ee.cpu())
164 | for ee in self.examples[choice]]).to(self.device),)
165 | for attr_str in self.attributes[1:]:
166 | if hasattr(self, attr_str):
167 | attr = getattr(self, attr_str)
168 | ret_tuple += (attr[choice],)
169 |
170 | if not return_index:
171 | return ret_tuple
172 | else:
173 | return (torch.tensor(choice).to(self.device), ) + ret_tuple
174 |
175 | 
176 |
177 | def get_data_by_index(self, indexes, transform: transforms=None) -> Tuple:
178 | """
179 | Returns the data by the given index.
180 | :param index: the index of the item
181 | :param transform: the transformation to be applied (data augmentation)
182 | :return:
183 | """
184 | if transform is None: transform = lambda x: x
185 | ret_tuple = (torch.stack([transform(ee.cpu())
186 | for ee in self.examples[indexes]]).to(self.device),)
187 | for attr_str in self.attributes[1:]:
188 | if hasattr(self, attr_str):
189 | attr = getattr(self, attr_str).to(self.device)
190 | ret_tuple += (attr[indexes],)
191 | return ret_tuple
192 |
193 |
194 | def is_empty(self) -> bool:
195 | """
196 | Returns true if the buffer is empty, false otherwise.
197 | """
198 | if self.num_seen_examples == 0:
199 | return True
200 | else:
201 | return False
202 |
203 | def get_all_data(self, transform: transforms=None) -> Tuple:
204 | """
205 | Return all the items in the memory buffer.
206 | :param transform: the transformation to be applied (data augmentation)
207 | :return: a tuple with all the items in the memory buffer
208 | """
209 | if transform is None: transform = lambda x: x
210 | ret_tuple = (torch.stack([transform(ee.cpu())
211 | for ee in self.examples]).to(self.device),)
212 | for attr_str in self.attributes[1:]:
213 | if hasattr(self, attr_str):
214 | attr = getattr(self, attr_str)
215 | ret_tuple += (attr,)
216 | return ret_tuple
217 |
218 | def empty(self) -> None:
219 | """
220 | Set all the tensors to None.
221 | """
222 | for attr_str in self.attributes:
223 | if hasattr(self, attr_str):
224 | delattr(self, attr_str)
225 | self.num_seen_examples = 0
226 |
--------------------------------------------------------------------------------
/utils/cutmix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | def rand_bbox(size, lam):
6 | W = size[2]
7 | H = size[3]
8 | cut_rat = np.sqrt(1. - lam)
9 |     cut_w = int(W * cut_rat)  # np.int was removed in NumPy 1.24; use the builtin
10 |     cut_h = int(H * cut_rat)
11 |
12 | # uniform
13 | cx = np.random.randint(W)
14 | cy = np.random.randint(H)
15 |
16 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
17 | bby1 = np.clip(cy - cut_h // 2, 0, H)
18 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
19 | bby2 = np.clip(cy + cut_h // 2, 0, H)
20 |
21 | return bbx1, bby1, bbx2, bby2
22 |
23 | def cutmix_data(x, y, alpha=1.0, cutmix_prob=0.5):
24 | assert(alpha > 0)
25 | # generate mixed sample
26 | lam = np.random.beta(alpha, alpha)
27 |
28 | batch_size = x.size()[0]
29 | index = torch.randperm(batch_size)
30 |
31 | if torch.cuda.is_available():
32 | index = index.cuda()
33 |
34 | y_a, y_b = y, y[index]
35 | bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
36 | x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
37 |
38 | # adjust lambda to exactly match pixel ratio
39 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
40 | return x, y_a, y_b, lam
41 |
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision import datasets, transforms
3 | from utils.toolkit import split_images_labels
4 |
5 |
6 | class iData(object):
7 | train_trsf = []
8 | test_trsf = []
9 | common_trsf = []
10 | class_order = None
11 |
12 |
13 | class iCIFAR10(iData):
14 | use_path = False
15 | train_trsf = [
16 | transforms.RandomCrop(32, padding=4),
17 | transforms.RandomHorizontalFlip(p=0.5),
18 | transforms.ColorJitter(brightness=63/255)
19 | ]
20 | test_trsf = []
21 | common_trsf = [
22 | transforms.ToTensor(),
23 | transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
24 | ]
25 |
26 | class_order = np.arange(10).tolist()
27 |
28 | def download_data(self):
29 | train_dataset = datasets.cifar.CIFAR10('./data', train=True, download=True)
30 | test_dataset = datasets.cifar.CIFAR10('./data', train=False, download=True)
31 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets)
32 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets)
33 |
34 |
35 | class iCIFAR100(iData):
36 | use_path = False
37 | train_trsf = [
38 | transforms.RandomCrop(32, padding=4),
39 | transforms.RandomHorizontalFlip(),
40 | transforms.ColorJitter(brightness=63/255)
41 | ]
42 | test_trsf = []
43 | common_trsf = [
44 | transforms.ToTensor(),
45 | transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),
46 | ]
47 |
48 | class_order = np.arange(100).tolist()
49 |
50 | def download_data(self):
51 | train_dataset = datasets.cifar.CIFAR100('./data', train=True, download=True)
52 | test_dataset = datasets.cifar.CIFAR100('./data', train=False, download=True)
53 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets)
54 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets)
55 |
56 | class iCIFAR100_224(iCIFAR100):
57 | train_trsf = [
58 | transforms.RandomResizedCrop(224, interpolation=3),
59 | transforms.RandomHorizontalFlip(),
60 | transforms.ColorJitter(brightness=63/255)
61 | ]
62 | test_trsf = [
63 | transforms.Resize(256, interpolation=3),
64 | transforms.CenterCrop(224),
65 | ]
66 |
67 | class iImageNet1000(iData):
68 | use_path = True
69 | train_trsf = [
70 | transforms.RandomResizedCrop(224),
71 | transforms.RandomHorizontalFlip(),
72 | transforms.ColorJitter(brightness=63/255)
73 | ]
74 | test_trsf = [
75 | transforms.Resize(256),
76 | transforms.CenterCrop(224),
77 | ]
78 | common_trsf = [
79 | transforms.ToTensor(),
80 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
81 | ]
82 |
83 | class_order = np.arange(1000).tolist()
84 |
85 | def download_data(self):
86 | assert 0,"You should specify the folder of your dataset"
87 | train_dir = '[DATA-PATH]/train/'
88 | test_dir = '[DATA-PATH]/val/'
89 |
90 | train_dset = datasets.ImageFolder(train_dir)
91 | test_dset = datasets.ImageFolder(test_dir)
92 |
93 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
94 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
95 |
96 | class iImageNet100(iData):
97 | use_path = True
98 | train_trsf = [
99 | transforms.RandomResizedCrop(224),
100 | transforms.RandomHorizontalFlip(),
101 | ]
102 | test_trsf = [
103 | transforms.Resize(256),
104 | transforms.CenterCrop(224),
105 | ]
106 | common_trsf = [
107 | transforms.ToTensor(),
108 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
109 | ]
110 |
111 | class_order = np.arange(1000).tolist()
112 |
113 | def download_data(self):
114 | train_dir = 'data/imagenet100/train/'
115 | test_dir = 'data/imagenet100/val/'
116 |
117 | train_dset = datasets.ImageFolder(train_dir)
118 | test_dset = datasets.ImageFolder(test_dir)
119 |
120 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
121 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
122 |
123 |
124 | class iImageNetR(iData):
125 | use_path = True
126 | train_trsf = [
127 | transforms.RandomResizedCrop(224, interpolation=3),
128 | transforms.RandomHorizontalFlip(),
129 | ]
130 | test_trsf = [
131 | transforms.Resize(256, interpolation=3),
132 | transforms.CenterCrop(224),
133 | ]
134 | common_trsf = [
135 | transforms.ToTensor(),
136 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
137 | ]
138 |
139 | class_order = np.arange(1000).tolist()
140 |
141 | def download_data(self):
142 | train_dir = 'data/imagenet-r/train/'
143 | test_dir = 'data/imagenet-r/val/'
144 |
145 | train_dset = datasets.ImageFolder(train_dir)
146 | test_dset = datasets.ImageFolder(test_dir)
147 |
148 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
149 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
150 |
151 | class iCUB200_224(iData):
152 | use_path = True
153 | train_trsf = [
154 | transforms.Resize((300, 300), interpolation=3),
155 | transforms.RandomCrop((224, 224)),
156 | transforms.RandomHorizontalFlip(),
157 | ]
158 | test_trsf = [
159 | transforms.Resize(256, interpolation=3),
160 | transforms.CenterCrop(224),
161 | ]
162 | common_trsf = [
163 | transforms.ToTensor(),
164 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
165 | ]
166 | class_order = np.arange(1000).tolist()
167 |
168 | def download_data(self):
169 | train_dir = 'data/cub_200/train/'
170 | test_dir = 'data/cub_200/val/'
171 |
172 | train_dset = datasets.ImageFolder(train_dir)
173 | test_dset = datasets.ImageFolder(test_dir)
174 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
175 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
176 |
177 | class iCARS196_224(iData):
178 | use_path = True
179 | train_trsf = [
180 | transforms.Resize((300, 300), interpolation=3),
181 | transforms.RandomCrop((224, 224)),
182 | transforms.RandomHorizontalFlip(),
183 | ]
184 | test_trsf = [
185 | transforms.Resize(256, interpolation=3),
186 | transforms.CenterCrop(224),
187 | ]
188 | common_trsf = [
189 | transforms.ToTensor(),
190 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
191 | ]
192 | class_order = np.arange(1000).tolist()
193 |
194 | def download_data(self):
195 | train_dir = 'data/cars196/train/'
196 | test_dir = 'data/cars196/val/'
197 |
198 | train_dset = datasets.ImageFolder(train_dir)
199 | test_dset = datasets.ImageFolder(test_dir)
200 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
201 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
202 |
203 |
204 | class iResisc45_224(iData):
205 | use_path = True
206 | train_trsf = [
207 | transforms.Resize((300, 300), interpolation=3),
208 | transforms.RandomCrop((224, 224)),
209 | transforms.RandomHorizontalFlip(),
210 | ]
211 | test_trsf = [
212 | transforms.Resize(256, interpolation=3),
213 | transforms.CenterCrop(224),
214 | ]
215 | common_trsf = [
216 | transforms.ToTensor(),
217 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
218 | ]
219 | class_order = np.arange(1000).tolist()
220 |
221 | def download_data(self):
222 | train_dir = 'data/resisc45/train/'
223 | test_dir = 'data/resisc45/val/'
224 |
225 | train_dset = datasets.ImageFolder(train_dir)
226 | test_dset = datasets.ImageFolder(test_dir)
227 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
228 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
229 |
230 |
231 |
232 | class iSketch345_224(iData):
233 | use_path = True
234 | train_trsf = [
235 | transforms.Resize((300, 300), interpolation=3),
236 | transforms.RandomCrop((224, 224)),
237 | transforms.RandomHorizontalFlip(),
238 | ]
239 | test_trsf = [
240 | transforms.Resize(256, interpolation=3),
241 | transforms.CenterCrop(224),
242 | ]
243 | common_trsf = [
244 | transforms.ToTensor(),
245 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
246 | ]
247 | class_order = np.arange(1000).tolist()
248 |
249 | def download_data(self):
250 | train_dir = 'data/sketch345/train/'
251 | test_dir = 'data/sketch345/val/'
252 |
253 | train_dset = datasets.ImageFolder(train_dir)
254 | test_dset = datasets.ImageFolder(test_dir)
255 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
256 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
--------------------------------------------------------------------------------
/utils/data_manager.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | from utils.data import iCIFAR10, iCIFAR100, iImageNet100, iImageNet1000, iCIFAR100_224, iImageNetR, iCUB200_224, iResisc45_224, iCARS196_224, iSketch345_224
7 | from copy import deepcopy
8 | import random
9 |
10 | class DataManager(object):
11 | def __init__(self, dataset_name, shuffle, seed, init_cls, increment):
12 | self.dataset_name = dataset_name
13 | self._setup_data(dataset_name, shuffle, seed)
14 |         assert init_cls <= len(self._class_order), 'Not enough classes.'
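    |         # task sizes: the first task holds init_cls classes, each later task holds
    |         # `increment` classes, and any remaining classes form the final task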
15 | self._increments = [init_cls]
16 | while sum(self._increments) + increment < len(self._class_order):
17 | self._increments.append(increment)
18 | offset = len(self._class_order) - sum(self._increments)
19 | if offset > 0:
20 | self._increments.append(offset)
21 |
22 | @property
23 | def nb_tasks(self):
24 | return len(self._increments)
25 |
26 | def get_task_size(self, task):
27 | return self._increments[task]
28 |
29 | def get_dataset(self, indices, source, mode, appendent=None, ret_data=False, with_raw=False, with_noise=False):
30 | if source == 'train':
31 | x, y = self._train_data, self._train_targets
32 | elif source == 'test':
33 | x, y = self._test_data, self._test_targets
34 | else:
35 | raise ValueError('Unknown data source {}.'.format(source))
36 |
37 | if mode == 'train':
38 | trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
39 | elif mode == 'flip':
40 | trsf = transforms.Compose([*self._test_trsf, transforms.RandomHorizontalFlip(p=1.), *self._common_trsf])
41 | elif mode == 'test':
42 | trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
43 | else:
44 | raise ValueError('Unknown mode {}.'.format(mode))
45 |
46 | data, targets = [], []
47 | for idx in indices:
48 | class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1)
49 | data.append(class_data)
50 | targets.append(class_targets)
51 |
52 | if appendent is not None and len(appendent) != 0:
53 | appendent_data, appendent_targets = appendent
54 | data.append(appendent_data)
55 | targets.append(appendent_targets)
56 |
57 | data, targets = np.concatenate(data), np.concatenate(targets)
58 |
59 | if ret_data:
60 | return data, targets, DummyDataset(data, targets, trsf, self.use_path, with_raw, with_noise)
61 | else:
62 | return DummyDataset(data, targets, trsf, self.use_path, with_raw, with_noise)
63 |
64 | def get_dataset_with_split(self, indices, source, mode, appendent=None, val_samples_per_class=0):
65 | if source == 'train':
66 | x, y = self._train_data, self._train_targets
67 | elif source == 'test':
68 | x, y = self._test_data, self._test_targets
69 | else:
70 | raise ValueError('Unknown data source {}.'.format(source))
71 |
72 | if mode == 'train':
73 | trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
74 | elif mode == 'test':
75 | trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
76 | else:
77 | raise ValueError('Unknown mode {}.'.format(mode))
78 |
79 | train_data, train_targets = [], []
80 | val_data, val_targets = [], []
81 | for idx in indices:
82 | class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1)
83 | val_indx = np.random.choice(len(class_data), val_samples_per_class, replace=False)
84 | train_indx = list(set(np.arange(len(class_data))) - set(val_indx))
85 | val_data.append(class_data[val_indx])
86 | val_targets.append(class_targets[val_indx])
87 | train_data.append(class_data[train_indx])
88 | train_targets.append(class_targets[train_indx])
89 |
90 | if appendent is not None:
91 | appendent_data, appendent_targets = appendent
92 | for idx in range(0, int(np.max(appendent_targets))+1):
93 | append_data, append_targets = self._select(appendent_data, appendent_targets,
94 | low_range=idx, high_range=idx+1)
95 | val_indx = np.random.choice(len(append_data), val_samples_per_class, replace=False)
96 | train_indx = list(set(np.arange(len(append_data))) - set(val_indx))
97 | val_data.append(append_data[val_indx])
98 | val_targets.append(append_targets[val_indx])
99 | train_data.append(append_data[train_indx])
100 | train_targets.append(append_targets[train_indx])
101 |
102 | train_data, train_targets = np.concatenate(train_data), np.concatenate(train_targets)
103 | val_data, val_targets = np.concatenate(val_data), np.concatenate(val_targets)
104 |
105 | return DummyDataset(train_data, train_targets, trsf, self.use_path), \
106 | DummyDataset(val_data, val_targets, trsf, self.use_path)
107 |
108 | def _setup_data(self, dataset_name, shuffle, seed):
109 | idata = _get_idata(dataset_name)
110 | idata.download_data()
111 |
112 | # Data
113 | self._train_data, self._train_targets = idata.train_data, idata.train_targets
114 | self._test_data, self._test_targets = idata.test_data, idata.test_targets
115 | self.use_path = idata.use_path
116 |
117 | # Transforms
118 | self._train_trsf = idata.train_trsf
119 | self._test_trsf = idata.test_trsf
120 | self._common_trsf = idata.common_trsf
121 |
122 | # Order
123 | order = [i for i in range(len(np.unique(self._train_targets)))]
124 | if shuffle:
125 | np.random.seed(seed)
126 | order = np.random.permutation(len(order)).tolist()
127 | else:
128 | order = idata.class_order
129 | self._class_order = order
130 | logging.info(self._class_order)
131 |
132 | # Map indices
133 | self._train_targets = _map_new_class_index(self._train_targets, self._class_order)
134 | self._test_targets = _map_new_class_index(self._test_targets, self._class_order)
135 |
136 | def _select(self, x, y, low_range, high_range):
137 | idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
138 | return x[idxes], y[idxes]
139 |
140 |
141 | class DummyDataset(Dataset):
142 | def __init__(self, images, labels, trsf, use_path=False, with_raw=False, with_noise=False):
143 | assert len(images) == len(labels), 'Data size error!'
144 | self.images = images
145 | self.labels = labels
146 | self.trsf = trsf
147 | self.use_path = use_path
148 | self.with_raw = with_raw
149 | if use_path and with_raw:
150 | self.raw_trsf = transforms.Compose([transforms.Resize((500, 500)), transforms.ToTensor()])
151 | else:
152 | self.raw_trsf = transforms.Compose([transforms.ToTensor()])
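    |         # optional symmetric label noise: about 20% of each class's samples are
    |         # relabeled to a randomly chosen different class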
153 | if with_noise:
154 | class_list = np.unique(self.labels)
155 | self.ori_labels = deepcopy(labels)
156 | for cls in class_list:
157 | random_target = class_list.tolist()
158 | random_target.remove(cls)
159 | tindx = [i for i, x in enumerate(self.ori_labels) if x == cls]
160 | for i in tindx[:round(len(tindx)*0.2)]:
161 | self.labels[i] = random.choice(random_target)
162 |
163 |
164 | def __len__(self):
165 | return len(self.images)
166 |
167 | def __getitem__(self, idx):
168 | if self.use_path:
169 | load_image = pil_loader(self.images[idx])
170 | image = self.trsf(load_image)
171 | else:
172 | load_image = Image.fromarray(self.images[idx])
173 | image = self.trsf(load_image)
174 | label = self.labels[idx]
175 | if self.with_raw:
176 | return idx, image, label, self.raw_trsf(load_image)
177 | return idx, image, label
178 |
179 |
180 | def _map_new_class_index(y, order):
181 | return np.array(list(map(lambda x: order.index(x), y)))
182 |
183 |
184 | def _get_idata(dataset_name):
185 | name = dataset_name.lower()
186 | if name == 'cifar10':
187 | return iCIFAR10()
188 | elif name == 'cifar100':
189 | return iCIFAR100()
190 | elif name == 'cifar100_224':
191 | return iCIFAR100_224()
192 | elif name == 'imagenet1000':
193 | return iImageNet1000()
194 | elif name == "imagenet100":
195 | return iImageNet100()
196 | elif name == "imagenet-r":
197 | return iImageNetR()
198 | elif name == 'cub200_224':
199 | return iCUB200_224()
200 | elif name == 'resisc45':
201 | return iResisc45_224()
202 | elif name == 'cars196_224':
203 | return iCARS196_224()
204 | elif name == 'sketch345_224':
205 | return iSketch345_224()
206 | else:
207 | raise NotImplementedError('Unknown dataset {}.'.format(dataset_name))
208 |
209 |
210 | def pil_loader(path):
211 | '''
212 | Ref:
213 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
214 | '''
215 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
216 | with open(path, 'rb') as f:
217 | img = Image.open(f)
218 | return img.convert('RGB')
219 |
220 |
221 | def accimage_loader(path):
222 | '''
223 | Ref:
224 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
225 | accimage is an accelerated Image loader and preprocessor leveraging Intel IPP.
226 | accimage is available on conda-forge.
227 | '''
228 | import accimage
229 | try:
230 | return accimage.Image(path)
231 | except IOError:
232 | # Potentially a decoding problem, fall back to PIL.Image
233 | return pil_loader(path)
234 |
235 |
236 | def default_loader(path):
237 | '''
238 | Ref:
239 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
240 | '''
241 | from torchvision import get_image_backend
242 | if get_image_backend() == 'accimage':
243 | return accimage_loader(path)
244 | else:
245 | return pil_loader(path)
246 |
--------------------------------------------------------------------------------
/utils/factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.slca import SLCA
3 | from models.slca_pp import SLCApp
4 |
5 | def get_model(model_name, args):
6 | name = model_name.lower()
7 | if 'slcapp' in name:
8 | return SLCApp(args)
9 | elif 'slca' in name:
10 | return SLCA(args)
11 | else:
12 |         raise NotImplementedError('Unknown model {}.'.format(model_name))
13 |
--------------------------------------------------------------------------------
/utils/inc_net.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from torch import nn
4 | from convs.cifar_resnet import resnet32
5 | from convs.resnet import resnet18, resnet34, resnet50
6 | from convs.linears import SimpleContinualLinear
7 | from convs.vits import vit_base_patch16_224_in21k, vit_base_patch16_224_mocov3, vit_base_lora_patch16_224_in21k, vit_base_lora_patch16_224_mocov3
8 | import torch.nn.functional as F
9 |
10 | def get_convnet(cfg, pretrained=False):
11 | name = cfg['convnet_type']
12 | name = name.lower()
13 | if name == 'resnet32':
14 | return resnet32()
15 | elif name == 'resnet18':
16 | return resnet18(pretrained=pretrained)
17 | elif name == 'resnet18_cifar':
18 | return resnet18(pretrained=pretrained, cifar=True)
19 | elif name == 'resnet18_cifar_cos':
20 | return resnet18(pretrained=pretrained, cifar=True, no_last_relu=True)
21 | elif name == 'resnet34':
22 | return resnet34(pretrained=pretrained)
23 | elif name == 'resnet50':
24 | return resnet50(pretrained=pretrained)
25 | elif name == 'vit-b-p16':
26 | return vit_base_patch16_224_in21k(pretrained=True)
27 | elif name == 'vit-b-p16-mocov3':
28 | return vit_base_patch16_224_mocov3(pretrained=True)
29 | elif name == 'vit-b-p16-lora':
30 | return vit_base_lora_patch16_224_in21k(pretrained=True, lora_rank=cfg['lora_rank'])
31 | elif name == 'vit-b-p16-lora-mocov3':
32 | return vit_base_lora_patch16_224_mocov3(pretrained=True, lora_rank=cfg['lora_rank'])
33 | else:
34 | raise NotImplementedError('Unknown type {}'.format(name))
35 |
36 |
37 | class BaseNet(nn.Module):
38 |
39 | def __init__(self, cfg, pretrained):
40 | super(BaseNet, self).__init__()
41 |
42 | self.convnet = get_convnet(cfg, pretrained)
43 | self.fc = None
44 |
45 | @property
46 | def feature_dim(self):
47 | return self.convnet.out_dim
48 |
49 | def extract_vector(self, x):
50 | return self.convnet(x)['features']
51 |
52 | def forward(self, x):
53 | x = self.convnet(x)
54 | out = self.fc(x['features'])
55 | '''
56 | {
57 | 'fmaps': [x_1, x_2, ..., x_n],
58 |             'features': features,
59 | 'logits': logits
60 | }
61 | '''
62 | out.update(x)
63 |
64 | return out
65 |
66 | def update_fc(self, nb_classes):
67 | pass
68 |
69 | def generate_fc(self, in_dim, out_dim):
70 | pass
71 |
72 | def copy(self):
73 | return copy.deepcopy(self)
74 |
75 | def freeze(self):
76 | for param in self.parameters():
77 | param.requires_grad = False
78 | self.eval()
79 |
80 | return self
81 |
82 | class FinetuneIncrementalNet(BaseNet):
83 |
84 | def __init__(self, cfg, pretrained, fc_with_ln=False):
85 | super().__init__(cfg, pretrained)
86 | self.old_fc = None
87 | self.fc_with_ln = fc_with_ln
88 | if 'fc_scale_mu' in cfg.keys():
89 | self.fc_scale_mu = cfg['fc_scale_mu']
90 | else:
91 | self.fc_scale_mu = -1
92 |
93 |
94 | def extract_layerwise_vector(self, x, pool=True):
95 | with torch.no_grad():
96 | features = self.convnet(x, layer_feat=True)['features']
97 | for f_i in range(len(features)):
98 | if pool:
99 | features[f_i] = features[f_i].mean(1).cpu().numpy()
100 | else:
101 | features[f_i] = features[f_i][:, 0].cpu().numpy()
102 | return features
103 |
104 |
105 | def update_fc(self, nb_classes, freeze_old=True):
106 | if self.fc is None:
107 | self.fc = self.generate_fc(self.feature_dim, nb_classes)
108 | else:
109 | self.fc.update(nb_classes, freeze_old=freeze_old)
110 |
111 | def save_old_fc(self):
112 | if self.old_fc is None:
113 | self.old_fc = copy.deepcopy(self.fc)
114 | else:
115 | self.old_fc.heads.append(copy.deepcopy(self.fc.heads[-1]))
116 |
117 | def generate_fc(self, in_dim, out_dim):
118 | fc = SimpleContinualLinear(in_dim, out_dim, scale_mu=self.fc_scale_mu)
119 |
120 | return fc
121 |
122 | def forward(self, x, bcb_no_grad=False, fc_only=False):
123 | if fc_only:
124 | fc_out = self.fc(x)
125 | if self.old_fc is not None:
126 | old_fc_logits = self.old_fc(x)['logits']
127 | fc_out['old_logits'] = old_fc_logits
128 | return fc_out
129 | if bcb_no_grad:
130 | with torch.no_grad():
131 | x = self.convnet(x)
132 | else:
133 | x = self.convnet(x)
134 | out = self.fc(x['features'])
135 | out.update(x)
136 |
137 | return out
138 |
139 |
140 |
--------------------------------------------------------------------------------
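A sketch of the incremental-head lifecycle defined above (illustrative, not part of the repo): grow the head per task with `update_fc`, snapshot it with `save_old_fc`, and run the backbone without gradients via `bcb_no_grad`. The config dict is a hypothetical minimum, and constructing the ViT downloads pre-trained weights.

    import torch
    from utils.inc_net import FinetuneIncrementalNet

    cfg = {'convnet_type': 'vit-b-p16'}  # hypothetical minimal config
    net = FinetuneIncrementalNet(cfg, pretrained=True)
    net.update_fc(10)                    # head for the first task
    net.save_old_fc()                    # frozen copy, surfaced as 'old_logits' when fc_only=True
    net.update_fc(10, freeze_old=True)   # extend the head for the next task
    with torch.no_grad():
        out = net(torch.randn(2, 3, 224, 224), bcb_no_grad=True)
    print(out['logits'].shape, out['features'].shape)
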
/utils/net_linear_wapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | class LinearWapper(nn.Module):
6 |     # Thin linear head that returns a dict, matching the heads in convs/linears.py.
7 |     def __init__(self, in_features, out_features):
8 |         super(LinearWapper, self).__init__()
9 |         self.weight = nn.Parameter(torch.empty(out_features, in_features))
10 |         self.bias = nn.Parameter(torch.zeros(out_features))
11 |         self.reset_parameters()
12 |
13 |     def reset_parameters(self):
14 |         nn.init.kaiming_uniform_(self.weight, nonlinearity='linear')
15 |         nn.init.constant_(self.bias, 0.)
16 |
17 |     def forward(self, input):
18 |         return {'logits': F.linear(input, self.weight, self.bias)}
19 |
--------------------------------------------------------------------------------
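A quick smoke test of the wrapper (illustrative, not part of the repo); the `(in_features, out_features)` signature follows the class as written above, and the dimensions are placeholders (768 is ViT-B's feature width).

    import torch
    from utils.net_linear_wapper import LinearWapper

    head = LinearWapper(768, 100)
    out = head(torch.randn(4, 768))
    print(out['logits'].shape)  # torch.Size([4, 100])
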
/utils/toolkit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def count_parameters(model, trainable=False):
7 | if trainable:
8 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
9 | return sum(p.numel() for p in model.parameters())
10 |
11 |
12 | def tensor2numpy(x):
13 |     return x.detach().cpu().numpy()
14 |
15 |
16 | def target2onehot(targets, n_classes):
17 | onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device)
18 | onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.)
19 | return onehot
20 |
21 |
22 | def makedirs(path):
23 |     # exist_ok avoids a race between the existence check and creation
24 |     os.makedirs(path, exist_ok=True)
25 |
26 |
27 | def accuracy(y_pred, y_true, nb_old, increment=10):
28 | assert len(y_pred) == len(y_true), 'Data length error.'
29 | all_acc = {}
30 | all_acc['total'] = np.around((y_pred == y_true).sum()*100 / len(y_true), decimals=2)
31 |
32 | # Grouped accuracy
33 |     for class_id in range(0, np.max(y_true) + 1, increment):  # +1 so the last class block is included
34 | idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + increment))[0]
35 | label = '{}-{}'.format(str(class_id).rjust(2, '0'), str(class_id+increment-1).rjust(2, '0'))
36 | all_acc[label] = np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), decimals=2)
37 |
38 | # Old accuracy
39 | idxes = np.where(y_true < nb_old)[0]
40 | all_acc['old'] = 0 if len(idxes) == 0 else np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes),
41 | decimals=2)
42 |
43 | # New accuracy
44 | idxes = np.where(y_true >= nb_old)[0]
45 | all_acc['new'] = np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), decimals=2)
46 |
47 | return all_acc
48 |
49 |
50 | def split_images_labels(imgs):
51 | # split trainset.imgs in ImageFolder
52 | images = []
53 | labels = []
54 | for item in imgs:
55 | images.append(item[0])
56 | labels.append(item[1])
57 |
58 | return np.array(images), np.array(labels)
59 |
--------------------------------------------------------------------------------
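A toy run of `accuracy` (illustrative, not part of the repo): four samples, the first two classes counted as 'old', grouped in blocks of two.

    import numpy as np
    from utils.toolkit import accuracy

    y_true = np.array([0, 1, 2, 3])
    y_pred = np.array([0, 1, 2, 2])
    print(accuracy(y_pred, y_true, nb_old=2, increment=2))
    # -> total 75.0, '00-01' 100.0, '02-03' 50.0, old 100.0, new 50.0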