├── .gitignore ├── LICENSE ├── README.md ├── collect_env.py ├── convs ├── __init__.py ├── cifar_resnet.py ├── linears.py ├── resnet.py └── vits.py ├── evaluator.py ├── exps ├── slca │ ├── slca_cars.json │ ├── slca_cars_mocov3.json │ ├── slca_cifar.json │ ├── slca_cifar_mocov3.json │ ├── slca_cub.json │ ├── slca_cub_mocov3.json │ ├── slca_imgnetr.json │ └── slca_imgnetr_mocov3.json └── slcapp │ ├── slcapp_cars_lora.json │ ├── slcapp_cars_lora_mocov3.json │ ├── slcapp_cifar_lora.json │ ├── slcapp_cifar_lora_mocov3.json │ ├── slcapp_cub_lora.json │ ├── slcapp_cub_lora_mocov3.json │ ├── slcapp_imgnetr_lora.json │ └── slcapp_imgnetr_lora_mocov3.json ├── main.py ├── models ├── __init__.py ├── base.py ├── slca.py └── slca_pp.py ├── slca_performance.jpg ├── split_car.py ├── split_cub.py ├── train_all.sh ├── trainer.py └── utils ├── __init__.py ├── buffer.py ├── cutmix.py ├── data.py ├── data_manager.py ├── factory.py ├── inc_net.py ├── net_linear_wapper.py └── toolkit.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | logs/ 3 | __pycache__/ 4 | *.pyc 5 | .DS_Store 6 | 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Gengwei (David) Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 |
5 |

SLCA++: Unleash the Power of Sequential Fine-tuning for Continual Learning with Pre-training

6 |
7 | 8 |
9 | Gengwei Zhang*  Liyuan Wang*  Guoliang Kang  Ling Chen  Yunchao Wei 10 |
11 |
12 | 13 |
14 | 15 | 16 | PyTorch code for paper "[SLCA++: Unleash the Power of Sequential Fine-tuning for Continual Learning with Pre-training](https://arxiv.org/abs/2408.08295)", together with the code for our ICCV 2023 paper "[SLCA: Slow Learner with Classifier Alignment for Continual Learning on a Pre-trained Model](https://arxiv.org/abs/2303.05118)". 17 | 18 | ## What's new? 19 | [2024.08] We release SLCA++, a parameter-efficient version of SLCA with even better continual performance on fine-grained benchmarks! 20 | 21 | ## Introduction 22 | In our paper, we present an in-depth analysis of the progressive overfitting problem from the lens of Seq FT. Considering that the overly fast representation learning and the biased classification layer constitute this particular problem, we introduce the advanced Slow Learner with Classifier Alignment (SLCA++) framework to unleash the power of Seq FT, serving as a strong baseline approach for Continual Learning with Pre-training (CLPT). Our approach involves a Slow Learner (SL) to selectively reduce the learning rate of backbone parameters, and a Classifier Alignment (CA) to align the disjoint classification layers in a post-hoc fashion. We further enhance the efficacy of SL with a symmetric cross-entropy loss (SCE), as well as employ a parameter-efficient strategy to implement Seq FT with SLCA++. Across a variety of continual learning scenarios, including class-incremental learning on general datasets like CIFAR-100 and ImageNet-R, fine-grained datasets like CUB-200 and Cars-196, and domain-incremental learning on DomainNet, our approach provides substantial improvements and outperforms state-of-the-art methods by a large margin. 23 | 24 | 25 | 26 | ## Requirement 27 | 1. torch==1.12.0 28 | 2. torchvision==0.13.0 29 | 3. timm==0.5.4 30 | 4. tqdm 31 | 5. numpy 32 | 6. scipy 33 | 7. quadprog 34 | 8. 
POT 35 | 36 | ## Pre-trained Models 37 | Please download pre-trained ViT-Base models from [MoCo v3](https://drive.google.com/file/d/1bshDu4jEKztZZvwpTVXSAuCsDoXwCkfy/view?usp=share_link) and [ImageNet-21K](https://drive.google.com/file/d/1PcAOf0tJYs1FVDpj-7lrkSuwXTJXVmuk/view?usp=share_link) and then put or link the pre-trained models to ```SLCA/pretrained``` 38 | 39 | ## Acknowledgement 40 | This repo is heavily based on [PyCIL](https://github.com/G-U-N/PyCIL), many thanks. 41 | 42 | ## Citation 43 | 44 | If you find our code or paper useful, please consider giving us a star or citing our work: 45 | 46 | ``` 47 | @misc{zhang2024slcaunleashpowersequential, 48 | title={SLCA++: Unleash the Power of Sequential Fine-tuning for Continual Learning with Pre-training}, 49 | author={Zhang, Gengwei and Wang, Liyuan and Kang, Guoliang and Chen, Ling and Wei, Yunchao}, 50 | year={2024}, 51 | eprint={2408.08295}, 52 | archivePrefix={arXiv}, 53 | url={https://arxiv.org/abs/2408.08295}, 54 | } 55 | ``` 56 | 57 | ``` 58 | @inproceedings{zhang2023slca, 59 | title={SLCA: Slow Learner with Classifier Alignment for Continual Learning on a Pre-trained Model}, 60 | author={Zhang, Gengwei and Wang, Liyuan and Kang, Guoliang and Chen, Ling and Wei, Yunchao}, 61 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 62 | year={2023} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /collect_env.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | # Unlike the rest of the PyTorch this file must be python2 compliant. 4 | # This script outputs relevant system environment info 5 | # Run it with `python collect_env.py`.
6 | import datetime 7 | import locale 8 | import re 9 | import subprocess 10 | import sys 11 | import os 12 | from collections import namedtuple 13 | 14 | 15 | try: 16 | import torch 17 | TORCH_AVAILABLE = True 18 | except (ImportError, NameError, AttributeError, OSError): 19 | TORCH_AVAILABLE = False 20 | 21 | # System Environment Information 22 | SystemEnv = namedtuple('SystemEnv', [ 23 | 'torch_version', 24 | 'is_debug_build', 25 | 'cuda_compiled_version', 26 | 'gcc_version', 27 | 'clang_version', 28 | 'cmake_version', 29 | 'os', 30 | 'libc_version', 31 | 'python_version', 32 | 'python_platform', 33 | 'is_cuda_available', 34 | 'cuda_runtime_version', 35 | 'nvidia_driver_version', 36 | 'nvidia_gpu_models', 37 | 'cudnn_version', 38 | 'pip_version', # 'pip' or 'pip3' 39 | 'pip_packages', 40 | 'conda_packages', 41 | 'hip_compiled_version', 42 | 'hip_runtime_version', 43 | 'miopen_runtime_version', 44 | 'caching_allocator_config', 45 | 'is_xnnpack_available', 46 | ]) 47 | 48 | 49 | def run(command): 50 | """Returns (return-code, stdout, stderr)""" 51 | p = subprocess.Popen(command, stdout=subprocess.PIPE, 52 | stderr=subprocess.PIPE, shell=True) 53 | raw_output, raw_err = p.communicate() 54 | rc = p.returncode 55 | if get_platform() == 'win32': 56 | enc = 'oem' 57 | else: 58 | enc = locale.getpreferredencoding() 59 | output = raw_output.decode(enc) 60 | err = raw_err.decode(enc) 61 | return rc, output.strip(), err.strip() 62 | 63 | 64 | def run_and_read_all(run_lambda, command): 65 | """Runs command using run_lambda; reads and returns entire output if rc is 0""" 66 | rc, out, _ = run_lambda(command) 67 | if rc != 0: 68 | return None 69 | return out 70 | 71 | 72 | def run_and_parse_first_match(run_lambda, command, regex): 73 | """Runs command using run_lambda, returns the first regex match if it exists""" 74 | rc, out, _ = run_lambda(command) 75 | if rc != 0: 76 | return None 77 | match = re.search(regex, out) 78 | if match is None: 79 | return None 80 | return 
match.group(1) 81 | 82 | def run_and_return_first_line(run_lambda, command): 83 | """Runs command using run_lambda and returns first line if output is not empty""" 84 | rc, out, _ = run_lambda(command) 85 | if rc != 0: 86 | return None 87 | return out.split('\n')[0] 88 | 89 | 90 | def get_conda_packages(run_lambda): 91 | conda = os.environ.get('CONDA_EXE', 'conda') 92 | out = run_and_read_all(run_lambda, "{} list".format(conda)) 93 | if out is None: 94 | return out 95 | 96 | return "\n".join( 97 | line 98 | for line in out.splitlines() 99 | if not line.startswith("#") 100 | and any( 101 | name in line 102 | for name in { 103 | "torch", 104 | "numpy", 105 | "cudatoolkit", 106 | "soumith", 107 | "mkl", 108 | "magma", 109 | "mkl", 110 | } 111 | ) 112 | ) 113 | 114 | def get_gcc_version(run_lambda): 115 | return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') 116 | 117 | def get_clang_version(run_lambda): 118 | return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') 119 | 120 | 121 | def get_cmake_version(run_lambda): 122 | return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') 123 | 124 | 125 | def get_nvidia_driver_version(run_lambda): 126 | if get_platform() == 'darwin': 127 | cmd = 'kextstat | grep -i cuda' 128 | return run_and_parse_first_match(run_lambda, cmd, 129 | r'com[.]nvidia[.]CUDA [(](.*?)[)]') 130 | smi = get_nvidia_smi() 131 | return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) 
') 132 | 133 | 134 | def get_gpu_info(run_lambda): 135 | if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): 136 | if TORCH_AVAILABLE and torch.cuda.is_available(): 137 | return torch.cuda.get_device_name(None) 138 | return None 139 | smi = get_nvidia_smi() 140 | uuid_regex = re.compile(r' \(UUID: .+?\)') 141 | rc, out, _ = run_lambda(smi + ' -L') 142 | if rc != 0: 143 | return None 144 | # Anonymize GPUs by removing their UUID 145 | return re.sub(uuid_regex, '', out) 146 | 147 | 148 | def get_running_cuda_version(run_lambda): 149 | return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') 150 | 151 | 152 | def get_cudnn_version(run_lambda): 153 | """This will return a list of libcudnn.so; it's hard to tell which one is being used""" 154 | if get_platform() == 'win32': 155 | system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') 156 | cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") 157 | where_cmd = os.path.join(system_root, 'System32', 'where') 158 | cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) 159 | elif get_platform() == 'darwin': 160 | # CUDA libraries and drivers can be found in /usr/local/cuda/. See 161 | # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install 162 | # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac 163 | # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
164 | cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' 165 | else: 166 | cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' 167 | rc, out, _ = run_lambda(cudnn_cmd) 168 | # find will return 1 if there are permission errors or if not found 169 | if len(out) == 0 or (rc != 1 and rc != 0): 170 | l = os.environ.get('CUDNN_LIBRARY') 171 | if l is not None and os.path.isfile(l): 172 | return os.path.realpath(l) 173 | return None 174 | files_set = set() 175 | for fn in out.split('\n'): 176 | fn = os.path.realpath(fn) # eliminate symbolic links 177 | if os.path.isfile(fn): 178 | files_set.add(fn) 179 | if not files_set: 180 | return None 181 | # Alphabetize the result because the order is non-deterministic otherwise 182 | files = list(sorted(files_set)) 183 | if len(files) == 1: 184 | return files[0] 185 | result = '\n'.join(files) 186 | return 'Probably one of the following:\n{}'.format(result) 187 | 188 | 189 | def get_nvidia_smi(): 190 | # Note: nvidia-smi is currently available only on Windows and Linux 191 | smi = 'nvidia-smi' 192 | if get_platform() == 'win32': 193 | system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') 194 | program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') 195 | legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) 196 | new_path = os.path.join(system_root, 'System32', smi) 197 | smis = [new_path, legacy_path] 198 | for candidate_smi in smis: 199 | if os.path.exists(candidate_smi): 200 | smi = '"{}"'.format(candidate_smi) 201 | break 202 | return smi 203 | 204 | 205 | def get_platform(): 206 | if sys.platform.startswith('linux'): 207 | return 'linux' 208 | elif sys.platform.startswith('win32'): 209 | return 'win32' 210 | elif sys.platform.startswith('cygwin'): 211 | return 'cygwin' 212 | elif sys.platform.startswith('darwin'): 213 | return 'darwin' 214 | else: 215 | return sys.platform 216 | 217 | 218 | def get_mac_version(run_lambda): 219 | return 
run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') 220 | 221 | 222 | def get_windows_version(run_lambda): 223 | system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') 224 | wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') 225 | findstr_cmd = os.path.join(system_root, 'System32', 'findstr') 226 | return run_and_read_all(run_lambda, '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) 227 | 228 | 229 | def get_lsb_version(run_lambda): 230 | return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') 231 | 232 | 233 | def check_release_file(run_lambda): 234 | return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', 235 | r'PRETTY_NAME="(.*)"') 236 | 237 | 238 | def get_os(run_lambda): 239 | from platform import machine 240 | platform = get_platform() 241 | 242 | if platform == 'win32' or platform == 'cygwin': 243 | return get_windows_version(run_lambda) 244 | 245 | if platform == 'darwin': 246 | version = get_mac_version(run_lambda) 247 | if version is None: 248 | return None 249 | return 'macOS {} ({})'.format(version, machine()) 250 | 251 | if platform == 'linux': 252 | # Ubuntu/Debian based 253 | desc = get_lsb_version(run_lambda) 254 | if desc is not None: 255 | return '{} ({})'.format(desc, machine()) 256 | 257 | # Try reading /etc/*-release 258 | desc = check_release_file(run_lambda) 259 | if desc is not None: 260 | return '{} ({})'.format(desc, machine()) 261 | 262 | return '{} ({})'.format(platform, machine()) 263 | 264 | # Unknown platform 265 | return platform 266 | 267 | 268 | def get_python_platform(): 269 | import platform 270 | return platform.platform() 271 | 272 | 273 | def get_libc_version(): 274 | import platform 275 | if get_platform() != 'linux': 276 | return 'N/A' 277 | return '-'.join(platform.libc_ver()) 278 | 279 | 280 | def get_pip_packages(run_lambda): 281 | """Returns `pip list` output. 
Note: will also find conda-installed pytorch 282 | and numpy packages.""" 283 | # People generally have `pip` as `pip` or `pip3` 284 | # But here it is incoved as `python -mpip` 285 | def run_with_pip(pip): 286 | out = run_and_read_all(run_lambda, "{} list --format=freeze".format(pip)) 287 | return "\n".join( 288 | line 289 | for line in out.splitlines() 290 | if any( 291 | name in line 292 | for name in { 293 | "torch", 294 | "numpy", 295 | "mypy", 296 | } 297 | ) 298 | ) 299 | 300 | pip_version = 'pip3' if sys.version[0] == '3' else 'pip' 301 | out = run_with_pip(sys.executable + ' -mpip') 302 | 303 | return pip_version, out 304 | 305 | 306 | def get_cachingallocator_config(): 307 | ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') 308 | return ca_config 309 | 310 | def is_xnnpack_available(): 311 | if TORCH_AVAILABLE: 312 | import torch.backends.xnnpack 313 | return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] 314 | else: 315 | return "N/A" 316 | 317 | def get_env_info(): 318 | run_lambda = run 319 | pip_version, pip_list_output = get_pip_packages(run_lambda) 320 | 321 | if TORCH_AVAILABLE: 322 | version_str = torch.__version__ 323 | debug_mode_str = str(torch.version.debug) 324 | cuda_available_str = str(torch.cuda.is_available()) 325 | cuda_version_str = torch.version.cuda 326 | if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version 327 | hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' 328 | else: # HIP version 329 | cfg = torch._C._show_config().split('\n') 330 | hip_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'HIP Runtime' in s][0] 331 | miopen_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'MIOpen' in s][0] 332 | cuda_version_str = 'N/A' 333 | hip_compiled_version = torch.version.hip 334 | else: 335 | version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' 336 | hip_compiled_version = hip_runtime_version = miopen_runtime_version = 
'N/A' 337 | 338 | sys_version = sys.version.replace("\n", " ") 339 | 340 | return SystemEnv( 341 | torch_version=version_str, 342 | is_debug_build=debug_mode_str, 343 | python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1), 344 | python_platform=get_python_platform(), 345 | is_cuda_available=cuda_available_str, 346 | cuda_compiled_version=cuda_version_str, 347 | cuda_runtime_version=get_running_cuda_version(run_lambda), 348 | nvidia_gpu_models=get_gpu_info(run_lambda), 349 | nvidia_driver_version=get_nvidia_driver_version(run_lambda), 350 | cudnn_version=get_cudnn_version(run_lambda), 351 | hip_compiled_version=hip_compiled_version, 352 | hip_runtime_version=hip_runtime_version, 353 | miopen_runtime_version=miopen_runtime_version, 354 | pip_version=pip_version, 355 | pip_packages=pip_list_output, 356 | conda_packages=get_conda_packages(run_lambda), 357 | os=get_os(run_lambda), 358 | libc_version=get_libc_version(), 359 | gcc_version=get_gcc_version(run_lambda), 360 | clang_version=get_clang_version(run_lambda), 361 | cmake_version=get_cmake_version(run_lambda), 362 | caching_allocator_config=get_cachingallocator_config(), 363 | is_xnnpack_available=is_xnnpack_available(), 364 | ) 365 | 366 | env_info_fmt = """ 367 | PyTorch version: {torch_version} 368 | Is debug build: {is_debug_build} 369 | CUDA used to build PyTorch: {cuda_compiled_version} 370 | ROCM used to build PyTorch: {hip_compiled_version} 371 | 372 | OS: {os} 373 | GCC version: {gcc_version} 374 | Clang version: {clang_version} 375 | CMake version: {cmake_version} 376 | Libc version: {libc_version} 377 | 378 | Python version: {python_version} 379 | Python platform: {python_platform} 380 | Is CUDA available: {is_cuda_available} 381 | CUDA runtime version: {cuda_runtime_version} 382 | GPU models and configuration: {nvidia_gpu_models} 383 | Nvidia driver version: {nvidia_driver_version} 384 | cuDNN version: {cudnn_version} 385 | HIP runtime version: {hip_runtime_version} 386 
| MIOpen runtime version: {miopen_runtime_version} 387 | Is XNNPACK available: {is_xnnpack_available} 388 | 389 | Versions of relevant libraries: 390 | {pip_packages} 391 | {conda_packages} 392 | """.strip() 393 | 394 | 395 | def pretty_str(envinfo): 396 | def replace_nones(dct, replacement='Could not collect'): 397 | for key in dct.keys(): 398 | if dct[key] is not None: 399 | continue 400 | dct[key] = replacement 401 | return dct 402 | 403 | def replace_bools(dct, true='Yes', false='No'): 404 | for key in dct.keys(): 405 | if dct[key] is True: 406 | dct[key] = true 407 | elif dct[key] is False: 408 | dct[key] = false 409 | return dct 410 | 411 | def prepend(text, tag='[prepend]'): 412 | lines = text.split('\n') 413 | updated_lines = [tag + line for line in lines] 414 | return '\n'.join(updated_lines) 415 | 416 | def replace_if_empty(text, replacement='No relevant packages'): 417 | if text is not None and len(text) == 0: 418 | return replacement 419 | return text 420 | 421 | def maybe_start_on_next_line(string): 422 | # If `string` is multiline, prepend a \n to it. 
423 | if string is not None and len(string.split('\n')) > 1: 424 | return '\n{}\n'.format(string) 425 | return string 426 | 427 | mutable_dict = envinfo._asdict() 428 | 429 | # If nvidia_gpu_models is multiline, start on the next line 430 | mutable_dict['nvidia_gpu_models'] = \ 431 | maybe_start_on_next_line(envinfo.nvidia_gpu_models) 432 | 433 | # If the machine doesn't have CUDA, report some fields as 'No CUDA' 434 | dynamic_cuda_fields = [ 435 | 'cuda_runtime_version', 436 | 'nvidia_gpu_models', 437 | 'nvidia_driver_version', 438 | ] 439 | all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] 440 | all_dynamic_cuda_fields_missing = all( 441 | mutable_dict[field] is None for field in dynamic_cuda_fields) 442 | if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: 443 | for field in all_cuda_fields: 444 | mutable_dict[field] = 'No CUDA' 445 | if envinfo.cuda_compiled_version is None: 446 | mutable_dict['cuda_compiled_version'] = 'None' 447 | 448 | # Replace True with Yes, False with No 449 | mutable_dict = replace_bools(mutable_dict) 450 | 451 | # Replace all None objects with 'Could not collect' 452 | mutable_dict = replace_nones(mutable_dict) 453 | 454 | # If either of these are '', replace with 'No relevant packages' 455 | mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) 456 | mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) 457 | 458 | # Tag conda and pip packages with a prefix 459 | # If they were previously None, they'll show up as ie '[conda] Could not collect' 460 | if mutable_dict['pip_packages']: 461 | mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], 462 | '[{}] '.format(envinfo.pip_version)) 463 | if mutable_dict['conda_packages']: 464 | mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], 465 | '[conda] ') 466 | return env_info_fmt.format(**mutable_dict) 467 | 468 | 469 | def get_pretty_env_info(): 470 | return 
pretty_str(get_env_info()) 471 | 472 | 473 | def main(): 474 | print("Collecting environment information...") 475 | output = get_pretty_env_info() 476 | print(output) 477 | 478 | if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'): 479 | minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR 480 | if sys.platform == "linux" and os.path.exists(minidump_dir): 481 | dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] 482 | latest = max(dumps, key=os.path.getctime) 483 | ctime = os.path.getctime(latest) 484 | creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S') 485 | msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ 486 | "if this is related to your bug please include it when you file a report ***" 487 | print(msg, file=sys.stderr) 488 | 489 | 490 | 491 | if __name__ == '__main__': 492 | main() 493 | -------------------------------------------------------------------------------- /convs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GengDavid/SLCA/169d4e6b91e3ca30caa717200c1944981c4426d1/convs/__init__.py -------------------------------------------------------------------------------- /convs/cifar_resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Reference: 3 | https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py 4 | ''' 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class DownsampleA(nn.Module): 13 | def __init__(self, nIn, nOut, stride): 14 | super(DownsampleA, self).__init__() 15 | assert stride == 2 16 | self.avg = nn.AvgPool2d(kernel_size=1, stride=stride) 17 | 18 | def forward(self, x): 19 | x = self.avg(x) 20 | return torch.cat((x, x.mul(0)), 1) 21 | 22 | 23 | class 
DownsampleB(nn.Module): 24 | def __init__(self, nIn, nOut, stride): 25 | super(DownsampleB, self).__init__() 26 | self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False) 27 | self.bn = nn.BatchNorm2d(nOut) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | x = self.bn(x) 32 | return x 33 | 34 | 35 | class DownsampleC(nn.Module): 36 | def __init__(self, nIn, nOut, stride): 37 | super(DownsampleC, self).__init__() 38 | assert stride != 1 or nIn != nOut 39 | self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False) 40 | 41 | def forward(self, x): 42 | x = self.conv(x) 43 | return x 44 | 45 | 46 | class DownsampleD(nn.Module): 47 | def __init__(self, nIn, nOut, stride): 48 | super(DownsampleD, self).__init__() 49 | assert stride == 2 50 | self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False) 51 | self.bn = nn.BatchNorm2d(nOut) 52 | 53 | def forward(self, x): 54 | x = self.conv(x) 55 | x = self.bn(x) 56 | return x 57 | 58 | 59 | class ResNetBasicblock(nn.Module): 60 | expansion = 1 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(ResNetBasicblock, self).__init__() 64 | 65 | self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 66 | self.bn_a = nn.BatchNorm2d(planes) 67 | 68 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 69 | self.bn_b = nn.BatchNorm2d(planes) 70 | 71 | self.downsample = downsample 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | basicblock = self.conv_a(x) 77 | basicblock = self.bn_a(basicblock) 78 | basicblock = F.relu(basicblock, inplace=True) 79 | 80 | basicblock = self.conv_b(basicblock) 81 | basicblock = self.bn_b(basicblock) 82 | 83 | if self.downsample is not None: 84 | residual = self.downsample(x) 85 | 86 | return F.relu(residual + basicblock, inplace=True) 87 | 88 | 89 | class CifarResNet(nn.Module): 90 | """ 91 | 
ResNet optimized for the Cifar Dataset, as specified in 92 | https://arxiv.org/abs/1512.03385.pdf 93 | """ 94 | 95 | def __init__(self, block, depth, channels=3): 96 | super(CifarResNet, self).__init__() 97 | 98 | # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model 99 | assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' 100 | layer_blocks = (depth - 2) // 6 101 | 102 | self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) 103 | self.bn_1 = nn.BatchNorm2d(16) 104 | 105 | self.inplanes = 16 106 | self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) 107 | self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) 108 | self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) 109 | self.avgpool = nn.AvgPool2d(8) 110 | self.out_dim = 64 * block.expansion 111 | self.fc = nn.Linear(64*block.expansion, 10) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 117 | # m.bias.data.zero_() 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | elif isinstance(m, nn.Linear): 122 | nn.init.kaiming_normal_(m.weight) 123 | m.bias.data.zero_() 124 | 125 | def _make_layer(self, block, planes, blocks, stride=1): 126 | downsample = None 127 | if stride != 1 or self.inplanes != planes * block.expansion: 128 | downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv_1_3x3(x) # [bs, 16, 32, 32] 140 | x = F.relu(self.bn_1(x), inplace=True) 141 | 142 | x_1 = self.stage_1(x) # [bs, 16, 32, 32] 143 | x_2 = self.stage_2(x_1) # [bs, 32, 16, 16] 144 | x_3 = self.stage_3(x_2) # [bs, 64, 8, 8] 145 | 146 | pooled = self.avgpool(x_3) # [bs, 64, 1, 1] 147 | features = pooled.view(pooled.size(0), -1) # [bs, 64] 148 | 149 | return { 150 | 'fmaps': [x_1, x_2, x_3], 151 | 'features': features 152 | } 153 | 154 | @property 155 | def last_conv(self): 156 | return self.stage_3[-1].conv_b 157 | 158 | 159 | def resnet20mnist(): 160 | """Constructs a ResNet-20 model for MNIST.""" 161 | model = CifarResNet(ResNetBasicblock, 20, 1) 162 | return model 163 | 164 | 165 | def resnet32mnist(): 166 | """Constructs a ResNet-32 model for MNIST.""" 167 | model = CifarResNet(ResNetBasicblock, 32, 1) 168 | return model 169 | 170 | 171 | def resnet20(): 172 | """Constructs a ResNet-20 model for CIFAR-10.""" 173 | model = CifarResNet(ResNetBasicblock, 20) 174 | return model 175 | 176 | 177 | def resnet32(): 178 | """Constructs a ResNet-32 model for CIFAR-10.""" 179 | model = CifarResNet(ResNetBasicblock, 32) 180 | return model 181 | 182 | 183 | def resnet44(): 184 | """Constructs a 
ResNet-44 model for CIFAR-10.""" 185 | model = CifarResNet(ResNetBasicblock, 44) 186 | return model 187 | 188 | 189 | def resnet56(): 190 | """Constructs a ResNet-56 model for CIFAR-10.""" 191 | model = CifarResNet(ResNetBasicblock, 56) 192 | return model 193 | 194 | 195 | def resnet110(): 196 | """Constructs a ResNet-110 model for CIFAR-10.""" 197 | model = CifarResNet(ResNetBasicblock, 110) 198 | return model 199 | -------------------------------------------------------------------------------- /convs/linears.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Reference: 3 | https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py 4 | ''' 5 | import math 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from timm.models.layers.weight_init import trunc_normal_ 10 | from timm.models.layers import Mlp 11 | from copy import deepcopy 12 | 13 | class SimpleScaler(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | self.weight = nn.Parameter(torch.ones(1), requires_grad=False) 17 | 18 | def forward(self, x): 19 | weight = self.weight 20 | x = x*weight.unsqueeze(1) 21 | 22 | return x 23 | 24 | 25 | class SimpleContinualLinear(nn.Module): 26 | def __init__(self, embed_dim, nb_classes, feat_expand=False, with_norm=False, scale_mu=-1): 27 | super().__init__() 28 | 29 | self.embed_dim = embed_dim 30 | self.feat_expand = feat_expand 31 | self.with_norm = with_norm 32 | self.scale_mu = scale_mu 33 | 34 | if self.scale_mu > 0: 35 | scales = [] 36 | scales.append(SimpleScaler()) 37 | self.scales = nn.ModuleList(scales) 38 | else: 39 | self.scales = None 40 | 41 | heads = [] 42 | single_head = [] 43 | if with_norm: 44 | single_head.append(nn.LayerNorm(embed_dim)) 45 | 46 | single_head.append(nn.Linear(embed_dim, nb_classes, bias=True)) 47 | head = nn.Sequential(*single_head) 48 | 49 | heads.append(head) 50 | self.heads = 
nn.ModuleList(heads) 51 | for m in self.modules(): 52 | if isinstance(m, nn.Linear): 53 | trunc_normal_(m.weight, std=.02) 54 | if m.bias is not None: 55 | nn.init.constant_(m.bias, 0) 56 | 57 | 58 | def update_scale(self): 59 | if self.scales is None: 60 | return 61 | num_old_tasks = len(self.heads)-1 62 | for t_id in range(num_old_tasks): # update scale for old tasks 63 | new_scale = 1+self.scale_mu*(num_old_tasks-t_id) 64 | self.scales[t_id].weight.data = torch.tensor([new_scale]).to(self.scales[t_id].weight) 65 | 66 | def backup(self): 67 | self.old_state_dict = deepcopy(self.state_dict()) 68 | 69 | def recall(self): 70 | self.load_state_dict(self.old_state_dict) 71 | 72 | def update(self, nb_classes, freeze_old=True): 73 | single_head = [] 74 | if self.with_norm: 75 | single_head.append(nn.LayerNorm(self.embed_dim)) 76 | 77 | if self.scale_mu>0: 78 | self.scales.append(SimpleScaler()) 79 | 80 | _fc = nn.Linear(self.embed_dim, nb_classes, bias=True) 81 | trunc_normal_(_fc.weight, std=.02) 82 | nn.init.constant_(_fc.bias, 0) 83 | single_head.append(_fc) 84 | new_head = nn.Sequential(*single_head) 85 | 86 | 87 | if freeze_old: 88 | for p in self.heads.parameters(): 89 | p.requires_grad=False 90 | 91 | self.heads.append(new_head) 92 | 93 | def forward(self, x): 94 | out = [] 95 | for ti in range(len(self.heads)): 96 | fc_inp = x[ti] if self.feat_expand else x 97 | if self.scale_mu>0: 98 | fc_inp = self.scales[ti](fc_inp) 99 | out.append(self.heads[ti](fc_inp)) 100 | out = {'logits': torch.cat(out, dim=1)} 101 | return out 102 | -------------------------------------------------------------------------------- /convs/resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Reference: 3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | # from torchvision.models.utils import load_state_dict_from_url 8 | 9 | 10 | __all__ = ['ResNet', 
'resnet18', 'resnet34', 'resnet50', 'resnet101', 11 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 12 | 'wide_resnet50_2', 'wide_resnet101_2'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 25 | } 26 | 27 | 28 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 29 | """3x3 convolution with padding""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 31 | padding=dilation, groups=groups, bias=False, dilation=dilation) 32 | 33 | 34 | def conv1x1(in_planes, out_planes, stride=1): 35 | """1x1 convolution""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | expansion = 1 41 | __constants__ = ['downsample'] 42 | 43 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 44 | base_width=64, dilation=1, norm_layer=None, no_last_relu=False): 45 | super(BasicBlock, self).__init__() 46 | if norm_layer is None: 47 | norm_layer = nn.BatchNorm2d 48 | if groups != 1 or base_width != 64: 49 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 50 | if dilation > 1: 51 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 52 | # Both self.conv1 and 
self.downsample layers downsample the input when stride != 1 53 | self.conv1 = conv3x3(inplanes, planes, stride) 54 | self.bn1 = norm_layer(planes) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.conv2 = conv3x3(planes, planes) 57 | self.bn2 = norm_layer(planes) 58 | self.downsample = downsample 59 | self.stride = stride 60 | self.no_last_relu = no_last_relu 61 | 62 | def forward(self, x): 63 | identity = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | if not self.no_last_relu: 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class Bottleneck(nn.Module): 83 | expansion = 4 84 | __constants__ = ['downsample'] 85 | 86 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 87 | base_width=64, dilation=1, norm_layer=None, no_last_relu=False): 88 | super(Bottleneck, self).__init__() 89 | if norm_layer is None: 90 | norm_layer = nn.BatchNorm2d 91 | width = int(planes * (base_width / 64.)) * groups 92 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 93 | self.conv1 = conv1x1(inplanes, width) 94 | self.bn1 = norm_layer(width) 95 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 96 | self.bn2 = norm_layer(width) 97 | self.conv3 = conv1x1(width, planes * self.expansion) 98 | self.bn3 = norm_layer(planes * self.expansion) 99 | self.relu = nn.ReLU(inplace=True) 100 | self.downsample = downsample 101 | self.stride = stride 102 | self.no_last_relu = no_last_relu 103 | 104 | def forward(self, x): 105 | identity = x 106 | 107 | out = self.conv1(x) 108 | out = self.bn1(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv2(out) 112 | out = self.bn2(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | identity 
= self.downsample(x) 120 | 121 | out += identity 122 | if not self.no_last_relu: 123 | out = self.relu(out) 124 | 125 | return out 126 | 127 | 128 | 129 | 130 | # 修改Resnet的实现。 131 | class ResNet(nn.Module): 132 | 133 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 134 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 135 | norm_layer=None, cifar=False, no_last_relu=False): 136 | super(ResNet, self).__init__() 137 | if norm_layer is None: 138 | norm_layer = nn.BatchNorm2d 139 | self._norm_layer = norm_layer 140 | self.cifar = cifar 141 | 142 | self.inplanes = 64 143 | self.dilation = 1 144 | if replace_stride_with_dilation is None: 145 | # each element in the tuple indicates if we should replace 146 | # the 2x2 stride with a dilated convolution instead 147 | replace_stride_with_dilation = [False, False, False] 148 | if len(replace_stride_with_dilation) != 3: 149 | raise ValueError("replace_stride_with_dilation should be None " 150 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 151 | self.groups = groups 152 | self.base_width = width_per_group 153 | if self.cifar: 154 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 155 | else: 156 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 157 | self.bn1 = norm_layer(self.inplanes) 158 | self.relu = nn.ReLU(inplace=True) 159 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Removed in _forward_impl for cifar 160 | self.layer1 = self._make_layer(block, 64, layers[0]) 161 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 162 | dilate=replace_stride_with_dilation[0]) 163 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 164 | dilate=replace_stride_with_dilation[1]) 165 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 166 | dilate=replace_stride_with_dilation[2], no_last_relu=no_last_relu) 167 | self.avgpool = 
nn.AdaptiveAvgPool2d((1, 1)) 168 | self.out_dim = 512 * block.expansion 169 | # self.fc = nn.Linear(512 * block.expansion, num_classes) # Removed in _forward_impl 170 | 171 | for m in self.modules(): 172 | if isinstance(m, nn.Conv2d): 173 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 174 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 175 | nn.init.constant_(m.weight, 1) 176 | nn.init.constant_(m.bias, 0) 177 | 178 | # Zero-initialize the last BN in each residual branch, 179 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 180 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 181 | if zero_init_residual: 182 | for m in self.modules(): 183 | if isinstance(m, Bottleneck): 184 | nn.init.constant_(m.bn3.weight, 0) 185 | elif isinstance(m, BasicBlock): 186 | nn.init.constant_(m.bn2.weight, 0) 187 | 188 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, no_last_relu=False): 189 | norm_layer = self._norm_layer 190 | downsample = None 191 | previous_dilation = self.dilation 192 | if dilate: 193 | self.dilation *= stride 194 | stride = 1 195 | if stride != 1 or self.inplanes != planes * block.expansion: 196 | downsample = nn.Sequential( 197 | conv1x1(self.inplanes, planes * block.expansion, stride), 198 | norm_layer(planes * block.expansion), 199 | ) 200 | 201 | layers = [] 202 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 203 | self.base_width, previous_dilation, norm_layer)) 204 | self.inplanes = planes * block.expansion 205 | for bid in range(1, blocks): 206 | layers.append(block(self.inplanes, planes, groups=self.groups, 207 | base_width=self.base_width, dilation=self.dilation, 208 | norm_layer=norm_layer, no_last_relu=no_last_relu if bid==blocks-1 else False)) 209 | 210 | return nn.Sequential(*layers) 211 | 212 | def _forward_impl(self, x): 213 | # See note [TorchScript super()] 214 | x = 
self.conv1(x) # [bs, 64, 32, 32] 215 | x = self.bn1(x) 216 | x = self.relu(x) 217 | if not self.cifar: 218 | x = self.maxpool(x) 219 | 220 | x_1 = self.layer1(x) # [bs, 128, 32, 32] 221 | x_2 = self.layer2(x_1) # [bs, 256, 16, 16] 222 | x_3 = self.layer3(x_2) # [bs, 512, 8, 8] 223 | x_4 = self.layer4(x_3) # [bs, 512, 4, 4] 224 | 225 | pooled = self.avgpool(x_4) # [bs, 512, 1, 1] 226 | features = torch.flatten(pooled, 1) # [bs, 512] 227 | # x = self.fc(x) 228 | 229 | return { 230 | 'fmaps': [x_1, x_2, x_3, x_4], 231 | 'features': features 232 | } 233 | 234 | def forward(self, x): 235 | return self._forward_impl(x) 236 | 237 | @property 238 | def last_conv(self): 239 | if hasattr(self.layer4[-1], 'conv3'): 240 | return self.layer4[-1].conv3 241 | else: 242 | return self.layer4[-1].conv2 243 | 244 | 245 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 246 | model = ResNet(block, layers, **kwargs) 247 | if pretrained: 248 | state_dict = load_state_dict_from_url(model_urls[arch], 249 | progress=progress) 250 | model.load_state_dict(state_dict) 251 | return model 252 | 253 | 254 | def resnet18(pretrained=False, progress=True, **kwargs): 255 | r"""ResNet-18 model from 256 | `"Deep Residual Learning for Image Recognition" `_ 257 | Args: 258 | pretrained (bool): If True, returns a model pre-trained on ImageNet 259 | progress (bool): If True, displays a progress bar of the download to stderr 260 | """ 261 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 262 | **kwargs) 263 | 264 | 265 | def resnet34(pretrained=False, progress=True, **kwargs): 266 | r"""ResNet-34 model from 267 | `"Deep Residual Learning for Image Recognition" `_ 268 | Args: 269 | pretrained (bool): If True, returns a model pre-trained on ImageNet 270 | progress (bool): If True, displays a progress bar of the download to stderr 271 | """ 272 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 273 | **kwargs) 274 | 275 | 276 | def 
resnet50(pretrained=False, progress=True, **kwargs): 277 | r"""ResNet-50 model from 278 | `"Deep Residual Learning for Image Recognition" `_ 279 | Args: 280 | pretrained (bool): If True, returns a model pre-trained on ImageNet 281 | progress (bool): If True, displays a progress bar of the download to stderr 282 | """ 283 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 284 | **kwargs) 285 | 286 | 287 | def resnet101(pretrained=False, progress=True, **kwargs): 288 | r"""ResNet-101 model from 289 | `"Deep Residual Learning for Image Recognition" `_ 290 | Args: 291 | pretrained (bool): If True, returns a model pre-trained on ImageNet 292 | progress (bool): If True, displays a progress bar of the download to stderr 293 | """ 294 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 295 | **kwargs) 296 | 297 | 298 | def resnet152(pretrained=False, progress=True, **kwargs): 299 | r"""ResNet-152 model from 300 | `"Deep Residual Learning for Image Recognition" `_ 301 | Args: 302 | pretrained (bool): If True, returns a model pre-trained on ImageNet 303 | progress (bool): If True, displays a progress bar of the download to stderr 304 | """ 305 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 306 | **kwargs) 307 | 308 | 309 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 310 | r"""ResNeXt-50 32x4d model from 311 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 312 | Args: 313 | pretrained (bool): If True, returns a model pre-trained on ImageNet 314 | progress (bool): If True, displays a progress bar of the download to stderr 315 | """ 316 | kwargs['groups'] = 32 317 | kwargs['width_per_group'] = 4 318 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 319 | pretrained, progress, **kwargs) 320 | 321 | 322 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 323 | r"""ResNeXt-101 32x8d model from 324 | `"Aggregated Residual 
Transformation for Deep Neural Networks" `_ 325 | Args: 326 | pretrained (bool): If True, returns a model pre-trained on ImageNet 327 | progress (bool): If True, displays a progress bar of the download to stderr 328 | """ 329 | kwargs['groups'] = 32 330 | kwargs['width_per_group'] = 8 331 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 332 | pretrained, progress, **kwargs) 333 | 334 | 335 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 336 | r"""Wide ResNet-50-2 model from 337 | `"Wide Residual Networks" `_ 338 | The model is the same as ResNet except for the bottleneck number of channels 339 | which is twice larger in every block. The number of channels in outer 1x1 340 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 341 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 342 | Args: 343 | pretrained (bool): If True, returns a model pre-trained on ImageNet 344 | progress (bool): If True, displays a progress bar of the download to stderr 345 | """ 346 | kwargs['width_per_group'] = 64 * 2 347 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 348 | pretrained, progress, **kwargs) 349 | 350 | 351 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 352 | r"""Wide ResNet-101-2 model from 353 | `"Wide Residual Networks" `_ 354 | The model is the same as ResNet except for the bottleneck number of channels 355 | which is twice larger in every block. The number of channels in outer 1x1 356 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 357 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
358 | Args: 359 | pretrained (bool): If True, returns a model pre-trained on ImageNet 360 | progress (bool): If True, displays a progress bar of the download to stderr 361 | """ 362 | kwargs['width_per_group'] = 64 * 2 363 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 364 | pretrained, progress, **kwargs) 365 | -------------------------------------------------------------------------------- /convs/vits.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of Vision Transformers as described in: 4 | 5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' 6 | - https://arxiv.org/abs/2010.11929 7 | 8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` 9 | - https://arxiv.org/abs/2106.10270 10 | 11 | The official jax code is released and available at https://github.com/google-research/vision_transformer 12 | 13 | DeiT model defs and weights from https://github.com/facebookresearch/deit, 14 | paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 15 | 16 | Acknowledgments: 17 | * The paper authors for releasing code and weights, thanks! 18 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out 19 | for some einops/einsum fun 20 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 21 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 22 | 23 | Hacked together by / Copyright 2020, Ross Wightman 24 | """ 25 | import math 26 | import logging 27 | from functools import partial 28 | from collections import OrderedDict 29 | from copy import deepcopy 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | 35 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 36 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 37 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 38 | from timm.models.registry import register_model 39 | 40 | _logger = logging.getLogger(__name__) 41 | 42 | 43 | def _cfg(url='', **kwargs): 44 | return { 45 | 'url': url, 46 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 47 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 48 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 49 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 50 | **kwargs 51 | } 52 | 53 | 54 | default_cfgs = { 55 | # patch models (weights from official Google JAX impl) 56 | 'vit_tiny_patch16_224': _cfg( 57 | url='https://storage.googleapis.com/vit_models/augreg/' 58 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 59 | 'vit_tiny_patch16_384': _cfg( 60 | url='https://storage.googleapis.com/vit_models/augreg/' 61 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 62 | input_size=(3, 384, 384), crop_pct=1.0), 63 | 'vit_small_patch32_224': _cfg( 64 | url='https://storage.googleapis.com/vit_models/augreg/' 65 | 
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 66 | 'vit_small_patch32_384': _cfg( 67 | url='https://storage.googleapis.com/vit_models/augreg/' 68 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 69 | input_size=(3, 384, 384), crop_pct=1.0), 70 | 'vit_small_patch16_224': _cfg( 71 | url='https://storage.googleapis.com/vit_models/augreg/' 72 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 73 | 'vit_small_patch16_384': _cfg( 74 | url='https://storage.googleapis.com/vit_models/augreg/' 75 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 76 | input_size=(3, 384, 384), crop_pct=1.0), 77 | 'vit_base_patch32_224': _cfg( 78 | url='https://storage.googleapis.com/vit_models/augreg/' 79 | 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 80 | 'vit_base_patch32_384': _cfg( 81 | url='https://storage.googleapis.com/vit_models/augreg/' 82 | 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 83 | input_size=(3, 384, 384), crop_pct=1.0), 84 | 'vit_base_patch16_224': _cfg( 85 | url='https://storage.googleapis.com/vit_models/augreg/' 86 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 87 | 'vit_base_patch16_384': _cfg( 88 | url='https://storage.googleapis.com/vit_models/augreg/' 89 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', 90 | input_size=(3, 384, 384), crop_pct=1.0), 91 | 'vit_base_patch8_224': _cfg( 92 | url='https://storage.googleapis.com/vit_models/augreg/' 93 | 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 94 | 'vit_large_patch32_224': _cfg( 95 | url='', # no 
official model weights for this combo, only for in21k 96 | ), 97 | 'vit_large_patch32_384': _cfg( 98 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 99 | input_size=(3, 384, 384), crop_pct=1.0), 100 | 'vit_large_patch16_224': _cfg( 101 | url='https://storage.googleapis.com/vit_models/augreg/' 102 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 103 | 'vit_large_patch16_384': _cfg( 104 | url='https://storage.googleapis.com/vit_models/augreg/' 105 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', 106 | input_size=(3, 384, 384), crop_pct=1.0), 107 | 108 | 'vit_huge_patch14_224': _cfg(url=''), 109 | 'vit_giant_patch14_224': _cfg(url=''), 110 | 'vit_gigantic_patch14_224': _cfg(url=''), 111 | 112 | # patch models, imagenet21k (weights from official Google JAX impl) 113 | 'vit_tiny_patch16_224_in21k': _cfg( 114 | url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', 115 | num_classes=21843), 116 | 'vit_small_patch32_224_in21k': _cfg( 117 | url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 118 | num_classes=21843), 119 | 'vit_small_patch16_224_in21k': _cfg( 120 | url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 121 | num_classes=21843), 122 | 'vit_base_patch32_224_in21k': _cfg( 123 | url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', 124 | num_classes=21843), 125 | 'vit_base_patch16_224_in21k': _cfg( 126 | url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 127 | #url='./B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 128 | num_classes=21843), 
129 | 'vit_base_patch8_224_in21k': _cfg( 130 | url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 131 | num_classes=21843), 132 | 'vit_large_patch32_224_in21k': _cfg( 133 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', 134 | num_classes=21843), 135 | 'vit_large_patch16_224_in21k': _cfg( 136 | url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', 137 | num_classes=21843), 138 | 'vit_huge_patch14_224_in21k': _cfg( 139 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', 140 | hf_hub='timm/vit_huge_patch14_224_in21k', 141 | num_classes=21843), 142 | 143 | # SAM trained models (https://arxiv.org/abs/2106.01548) 144 | 'vit_base_patch32_sam_224': _cfg( 145 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), 146 | 'vit_base_patch16_sam_224': _cfg( 147 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), 148 | 149 | # deit models (FB weights) 150 | 'deit_tiny_patch16_224': _cfg( 151 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', 152 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 153 | 'deit_small_patch16_224': _cfg( 154 | url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', 155 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 156 | 'deit_base_patch16_224': _cfg( 157 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', 158 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 159 | 'deit_base_patch16_384': _cfg( 160 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', 161 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), 162 | 'deit_tiny_distilled_patch16_224': _cfg( 163 | 
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', 164 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 165 | 'deit_small_distilled_patch16_224': _cfg( 166 | url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', 167 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 168 | 'deit_base_distilled_patch16_224': _cfg( 169 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', 170 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 171 | 'deit_base_distilled_patch16_384': _cfg( 172 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', 173 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, 174 | classifier=('head', 'head_dist')), 175 | 176 | # ViT ImageNet-21K-P pretraining by MILL 177 | 'vit_base_patch16_224_miil_in21k': _cfg( 178 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', 179 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, 180 | ), 181 | 'vit_base_patch16_224_miil': _cfg( 182 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' 183 | '/vit_base_patch16_224_1k_miil_84_4.pth', 184 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', 185 | ), 186 | } 187 | 188 | 189 | class LoraMlp(nn.Module): 190 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., 191 | lora_rank=4, lora_scale=0.8): 192 | super().__init__() 193 | out_features = out_features or in_features 194 | hidden_features = hidden_features or in_features 195 | self.hidden_features = hidden_features 196 | self.fc1 = nn.Linear(in_features, hidden_features) 197 | 
self.act = act_layer() 198 | self.fc2 = nn.Linear(hidden_features, out_features) 199 | self.drop = nn.Dropout(drop) 200 | 201 | self.lora_rank = lora_rank 202 | self.lora_1_B = torch.nn.Parameter(torch.zeros(hidden_features, self.lora_rank)) 203 | self.lora_1_A = torch.nn.Parameter(torch.zeros(self.lora_rank, in_features)) 204 | trunc_normal_(self.lora_1_A, std=0.02) 205 | 206 | self.lora_2_B = torch.nn.Parameter(torch.zeros(in_features, self.lora_rank)) 207 | self.lora_2_A = torch.nn.Parameter(torch.zeros(self.lora_rank, hidden_features)) 208 | trunc_normal_(self.lora_2_A, std=0.02) 209 | 210 | self.lora_scale = lora_scale 211 | 212 | def init_lora(self): 213 | U, S, Vh = torch.linalg.svd(self.fc1.weight) 214 | self.lora_1_A.data[:] = Vh[:self.lora_rank, :] 215 | 216 | U, S, Vh = torch.linalg.svd(self.fc2.weight) 217 | self.lora_2_A.data[:] = Vh[:self.lora_rank, :] 218 | 219 | def forward(self, x): 220 | lora_w = torch.matmul(self.lora_1_B, self.lora_1_A)*self.lora_scale 221 | 222 | # x = self.fc1(x) 223 | x = F.linear(x, lora_w+self.fc1.weight, bias=self.fc1.bias) 224 | x = self.act(x) 225 | x = self.drop(x) 226 | 227 | lora_w = torch.matmul(self.lora_2_B, self.lora_2_A)*self.lora_scale 228 | 229 | # x = self.fc2(x) 230 | x = F.linear(x, lora_w+self.fc2.weight, bias=self.fc2.bias) 231 | x = self.drop(x) 232 | 233 | return x 234 | 235 | 236 | class LoraAttention(nn.Module): 237 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., lora_rank=4, lora_scale=0.8): 238 | super().__init__() 239 | self.num_heads = num_heads 240 | head_dim = dim // num_heads 241 | self.scale = head_dim ** -0.5 242 | 243 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 244 | self.attn_drop = nn.Dropout(attn_drop) 245 | self.proj = nn.Linear(dim, dim) 246 | self.proj_drop = nn.Dropout(proj_drop) 247 | 248 | 249 | self.lora_rank = lora_rank 250 | self.lora_B = torch.nn.Parameter(torch.zeros(dim*3, lora_rank)) 251 | self.lora_A = 
torch.nn.Parameter(torch.zeros(lora_rank, dim)) 252 | trunc_normal_(self.lora_A, std=0.02) 253 | 254 | 255 | self.lora_scale = lora_scale 256 | 257 | def init_lora(self): 258 | U, S, Vh = torch.linalg.svd(self.qkv.weight) 259 | self.lora_A.data[:] = Vh[:self.lora_rank, :] 260 | 261 | def forward(self, x): 262 | B, N, C = x.shape 263 | 264 | lora_w = torch.matmul(self.lora_B, self.lora_A)*self.lora_scale 265 | qkv = F.linear(x, lora_w+self.qkv.weight, bias=self.qkv.bias) 266 | qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 267 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 268 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 269 | 270 | attn = (q @ k.transpose(-2, -1)) * self.scale 271 | attn = attn.softmax(dim=-1) 272 | attn = self.attn_drop(attn) 273 | 274 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 275 | 276 | x = self.proj(x) 277 | x = self.proj_drop(x) 278 | return x 279 | 280 | 281 | class Attention(nn.Module): 282 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 283 | super().__init__() 284 | self.num_heads = num_heads 285 | head_dim = dim // num_heads 286 | self.scale = head_dim ** -0.5 287 | 288 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 289 | self.attn_drop = nn.Dropout(attn_drop) 290 | self.proj = nn.Linear(dim, dim) 291 | self.proj_drop = nn.Dropout(proj_drop) 292 | 293 | def forward(self, x): 294 | B, N, C = x.shape 295 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 296 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 297 | 298 | attn = (q @ k.transpose(-2, -1)) * self.scale 299 | attn = attn.softmax(dim=-1) 300 | attn = self.attn_drop(attn) 301 | 302 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 303 | x = self.proj(x) 304 | x = self.proj_drop(x) 305 | return x 306 | 307 | 308 | class 
Block(nn.Module): 309 | 310 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 311 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 312 | super().__init__() 313 | self.norm1 = norm_layer(dim) 314 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 315 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 316 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 317 | self.norm2 = norm_layer(dim) 318 | mlp_hidden_dim = int(dim * mlp_ratio) 319 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 320 | 321 | def forward(self, x): 322 | x = x + self.drop_path(self.attn(self.norm1(x))) 323 | x = x + self.drop_path(self.mlp(self.norm2(x))) 324 | return x 325 | 326 | 327 | class VisionTransformer(nn.Module): 328 | """ Vision Transformer 329 | 330 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 331 | - https://arxiv.org/abs/2010.11929 332 | 333 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 334 | - https://arxiv.org/abs/2012.12877 335 | """ 336 | 337 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 338 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 339 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 340 | act_layer=None, weight_init='', global_pool=False,lora_rank=-1): 341 | """ 342 | Args: 343 | img_size (int, tuple): input image size 344 | patch_size (int, tuple): patch size 345 | in_chans (int): number of input channels 346 | num_classes (int): number of classes for classification head 347 | embed_dim (int): embedding dimension 348 | depth (int): depth of transformer 349 | num_heads (int): number of attention heads 350 | 
mlp_ratio (int): ratio of mlp hidden dim to embedding dim 351 | qkv_bias (bool): enable bias for qkv if True 352 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 353 | distilled (bool): model includes a distillation token and head as in DeiT models 354 | drop_rate (float): dropout rate 355 | attn_drop_rate (float): attention dropout rate 356 | drop_path_rate (float): stochastic depth rate 357 | embed_layer (nn.Module): patch embedding layer 358 | norm_layer: (nn.Module): normalization layer 359 | weight_init: (str): weight init scheme 360 | """ 361 | super().__init__() 362 | self.num_classes = num_classes 363 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 364 | self.out_dim = embed_dim 365 | self.num_tokens = 2 if distilled else 1 366 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 367 | act_layer = act_layer or nn.GELU 368 | 369 | self.global_pool = global_pool 370 | 371 | self.patch_embed = embed_layer( 372 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 373 | num_patches = self.patch_embed.num_patches 374 | 375 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 376 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 377 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 378 | self.pos_drop = nn.Dropout(p=drop_rate) 379 | 380 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 381 | self.blocks = nn.ModuleList([ 382 | Block( 383 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 384 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 385 | for i in range(depth)]) 386 | self.norm = norm_layer(embed_dim) 387 | 388 | # Representation layer 389 | if representation_size and not distilled: 390 | 
self.num_features = representation_size 391 | self.pre_logits = nn.Sequential(OrderedDict([ 392 | ('fc', nn.Linear(embed_dim, representation_size)), 393 | ('act', nn.Tanh()) 394 | ])) 395 | else: 396 | self.pre_logits = nn.Identity() 397 | 398 | # Classifier head(s) 399 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 400 | self.head_dist = None 401 | if distilled: 402 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 403 | 404 | self.init_weights(weight_init) 405 | 406 | 407 | self.lora_rank = lora_rank 408 | if self.lora_rank>0: 409 | self.with_lora = True 410 | self.lora_lp = depth 411 | self.attn_lora = True 412 | self.mlp_lora = True 413 | if self.mlp_lora: 414 | for b_idx in range(self.lora_lp): 415 | self.blocks[b_idx].mlp = LoraMlp(in_features=self.embed_dim, hidden_features=int(self.embed_dim*mlp_ratio), act_layer=act_layer, drop=drop_rate, lora_rank=lora_rank) 416 | if self.attn_lora: 417 | for b_idx in range(self.lora_lp): 418 | self.blocks[b_idx].attn = LoraAttention(self.embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop_rate, proj_drop=drop_rate, lora_rank=lora_rank) 419 | else: 420 | self.with_lora = False 421 | 422 | def get_adapter(self, embed_dim): 423 | return nn.Sequential( 424 | nn.Linear(embed_dim, embed_dim*3, bias=False), 425 | nn.LayerNorm(embed_dim*3), 426 | nn.GELU(), 427 | nn.Linear(embed_dim*3, embed_dim, bias=False), 428 | nn.LayerNorm(embed_dim), 429 | nn.GELU(), 430 | nn.Linear(embed_dim, embed_dim, bias=True), 431 | nn.Sigmoid() 432 | ) 433 | 434 | 435 | def init_weights(self, mode=''): 436 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 437 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
438 | trunc_normal_(self.pos_embed, std=.02) 439 | if self.dist_token is not None: 440 | trunc_normal_(self.dist_token, std=.02) 441 | if mode.startswith('jax'): 442 | # leave cls token as zeros to match jax impl 443 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) 444 | else: 445 | trunc_normal_(self.cls_token, std=.02) 446 | self.apply(_init_vit_weights) 447 | 448 | def _init_weights(self, m): 449 | # this fn left here for compat with downstream users 450 | _init_vit_weights(m) 451 | 452 | @torch.jit.ignore() 453 | def load_pretrained(self, checkpoint_path, prefix=''): 454 | _load_weights(self, checkpoint_path, prefix) 455 | 456 | @torch.jit.ignore 457 | def no_weight_decay(self): 458 | return {'pos_embed', 'cls_token', 'dist_token'} 459 | 460 | def get_classifier(self): 461 | if self.dist_token is None: 462 | return self.head 463 | else: 464 | return self.head, self.head_dist 465 | 466 | def reset_classifier(self, num_classes, global_pool=''): 467 | self.num_classes = num_classes 468 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 469 | if self.num_tokens == 2: 470 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 471 | 472 | def forward_features(self, x, layer_feat=False): 473 | img = x 474 | x = self.patch_embed(x) 475 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 476 | if self.dist_token is None: 477 | x = torch.cat((cls_token, x), dim=1) 478 | else: 479 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 480 | x = self.pos_drop(x + self.pos_embed) 481 | # x = self.blocks(x) 482 | feats = [] 483 | feats_l = [] 484 | for b_id, block in enumerate(self.blocks): 485 | x = block(x) 486 | if layer_feat: 487 | feats_l.append(x) 488 | if b_id == len(self.blocks)-2: 489 | penultimate_feat = x.clone() 490 | 491 | if layer_feat: 492 | return feats_l 493 | 
494 | if self.global_pool: 495 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 496 | return self.norm(x) 497 | 498 | x = self.norm(x) 499 | if self.dist_token is None: 500 | return self.pre_logits(x[:, 0]) 501 | else: 502 | return x[:, 0] # , x[:, 1] 503 | 504 | def forward(self, x, layer_feat=False): 505 | x = self.forward_features(x, layer_feat) 506 | x = {'features': x} 507 | #if self.head_dist is not None: 508 | # x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 509 | # if self.training and not torch.jit.is_scripting(): 510 | # # during inference, return the average of both classifier predictions 511 | # return x, x_dist 512 | # else: 513 | # return (x + x_dist) / 2 514 | #else: 515 | # x = self.head(x) 516 | return x 517 | 518 | 519 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): 520 | """ ViT weight initialization 521 | * When called without n, head_bias, jax_impl args it will behave exactly the same 522 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
523 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl 524 | """ 525 | if isinstance(module, nn.Linear): 526 | if name.startswith('head'): 527 | nn.init.zeros_(module.weight) 528 | nn.init.constant_(module.bias, head_bias) 529 | elif name.startswith('pre_logits'): 530 | lecun_normal_(module.weight) 531 | nn.init.zeros_(module.bias) 532 | else: 533 | if jax_impl: 534 | nn.init.xavier_uniform_(module.weight) 535 | if module.bias is not None: 536 | if 'mlp' in name: 537 | nn.init.normal_(module.bias, std=1e-6) 538 | else: 539 | nn.init.zeros_(module.bias) 540 | else: 541 | trunc_normal_(module.weight, std=.02) 542 | if module.bias is not None: 543 | nn.init.zeros_(module.bias) 544 | elif jax_impl and isinstance(module, nn.Conv2d): 545 | # NOTE conv was left to pytorch default in my original init 546 | lecun_normal_(module.weight) 547 | if module.bias is not None: 548 | nn.init.zeros_(module.bias) 549 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): 550 | nn.init.zeros_(module.bias) 551 | nn.init.ones_(module.weight) 552 | 553 | @torch.no_grad() 554 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 555 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 556 | """ 557 | import numpy as np 558 | 559 | def _n2p(w, t=True): 560 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 561 | w = w.flatten() 562 | if t: 563 | if w.ndim == 4: 564 | w = w.transpose([3, 2, 0, 1]) 565 | elif w.ndim == 3: 566 | w = w.transpose([2, 0, 1]) 567 | elif w.ndim == 2: 568 | w = w.transpose([1, 0]) 569 | return torch.from_numpy(w) 570 | 571 | w = np.load(checkpoint_path) 572 | if not prefix and 'opt/target/embedding/kernel' in w: 573 | prefix = 'opt/target/' 574 | 575 | if hasattr(model.patch_embed, 'backbone'): 576 | # hybrid 577 | backbone = model.patch_embed.backbone 578 | stem_only = not hasattr(backbone, 'stem') 579 | stem = backbone if 
stem_only else backbone.stem 580 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 581 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 582 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 583 | if not stem_only: 584 | for i, stage in enumerate(backbone.stages): 585 | for j, block in enumerate(stage.blocks): 586 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 587 | for r in range(3): 588 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 589 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 590 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 591 | if block.downsample is not None: 592 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 593 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 594 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 595 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 596 | else: 597 | embed_conv_w = adapt_input_conv( 598 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 599 | model.patch_embed.proj.weight.copy_(embed_conv_w) 600 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 601 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 602 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 603 | if pos_embed_w.shape != model.pos_embed.shape: 604 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 605 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 606 | model.pos_embed.copy_(pos_embed_w) 607 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 608 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 609 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == 
w[f'{prefix}head/bias'].shape[-1]: 610 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 611 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 612 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 613 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 614 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 615 | for i, block in enumerate(model.blocks.children()): 616 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 617 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 618 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 619 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 620 | block.attn.qkv.weight.copy_(torch.cat([ 621 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 622 | block.attn.qkv.bias.copy_(torch.cat([ 623 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 624 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 625 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 626 | for r in range(2): 627 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 628 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 629 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 630 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 631 | 632 | 633 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 634 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 635 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 636 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 637 | ntok_new = posemb_new.shape[1] 638 | if num_tokens: 639 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 640 | ntok_new -= num_tokens 641 | else: 642 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 643 | gs_old = int(math.sqrt(len(posemb_grid))) 644 | if not len(gs_new): # backwards compatibility 645 | gs_new = [int(math.sqrt(ntok_new))] * 2 646 | assert len(gs_new) >= 2 647 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) 648 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 649 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) 650 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 651 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 652 | return posemb 653 | 654 | 655 | def checkpoint_filter_fn(state_dict, model): 656 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 657 | out_dict = {} 658 | if 'model' in state_dict: 659 | # For deit models 660 | state_dict = state_dict['model'] 661 | for k, v in state_dict.items(): 662 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 663 | # For old models that I trained prior to conv based patchification 664 | O, I, H, W = model.patch_embed.proj.weight.shape 665 | v = v.reshape(O, -1, H, W) 666 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 667 | # To resize pos embedding when using model at different size from pretrained weights 668 | v = resize_pos_embed( 669 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 670 | out_dict[k] = v 671 | return out_dict 672 | 673 | 674 | def _create_vision_transformer(variant, 
pretrained=False, default_cfg=None, **kwargs): 675 | default_cfg = default_cfg or default_cfgs[variant] 676 | if kwargs.get('features_only', None): 677 | raise RuntimeError('features_only not implemented for Vision Transformer models.') 678 | 679 | # NOTE this extra code to support handling of repr size for in21k pretrained models 680 | default_num_classes = default_cfg['num_classes'] 681 | num_classes = kwargs.get('num_classes', default_num_classes) 682 | repr_size = kwargs.pop('representation_size', None) 683 | if repr_size is not None and num_classes != default_num_classes: 684 | # Remove representation layer if fine-tuning. This may not always be the desired action, 685 | # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? 686 | _logger.warning("Removing representation layer for fine-tuning.") 687 | repr_size = None 688 | 689 | model = build_model_with_cfg( 690 | VisionTransformer, variant, pretrained, 691 | default_cfg=default_cfg, 692 | representation_size=repr_size, 693 | pretrained_filter_fn=checkpoint_filter_fn, 694 | pretrained_custom_load='npz' in default_cfg['url'], 695 | **kwargs) 696 | return model 697 | 698 | 699 | 700 | @register_model 701 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs): 702 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 703 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
704 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 705 | """ 706 | model_kwargs = dict( 707 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 708 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 709 | del model.head 710 | del model.norm 711 | model.norm = nn.LayerNorm(768) 712 | return model 713 | 714 | @register_model 715 | def vit_base_patch16_224_mocov3(pretrained=False, **kwargs): 716 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 717 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 718 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 719 | """ 720 | model_kwargs = dict( 721 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 722 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=False, **model_kwargs) 723 | del model.head 724 | ckpt = torch.load('mocov3-vit-base-300ep.pth', map_location='cpu')['model'] 725 | state_dict = model.state_dict() 726 | state_dict.update(ckpt) 727 | model.load_state_dict(state_dict) 728 | del model.norm 729 | model.norm = nn.LayerNorm(768) 730 | return model 731 | 732 | 733 | 734 | @register_model 735 | def vit_base_lora_patch16_224_in21k(pretrained=False, lora_rank=4, **kwargs): 736 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 737 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
738 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 739 | """ 740 | model_kwargs = dict( 741 | patch_size=16, embed_dim=768, depth=12, num_heads=12, lora_rank=lora_rank, **kwargs) 742 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 743 | del model.head 744 | del model.norm 745 | model.norm = nn.LayerNorm(768) 746 | 747 | return model 748 | 749 | @register_model 750 | def vit_base_lora_patch16_224_mocov3(pretrained=False, lora_rank=4, **kwargs): 751 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 752 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 753 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 754 | """ 755 | model_kwargs = dict( 756 | patch_size=16, embed_dim=768, depth=12, num_heads=12, lora_rank=lora_rank, **kwargs) 757 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=False, **model_kwargs) 758 | del model.head 759 | ckpt = torch.load('mocov3-vit-base-300ep.pth', map_location='cpu')['model'] 760 | state_dict = model.state_dict() 761 | state_dict.update(ckpt) 762 | model.load_state_dict(state_dict) 763 | del model.norm 764 | model.norm = nn.LayerNorm(768) 765 | return model 766 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import copy 4 | import torch 5 | from utils import factory 6 | from utils.data_manager import DataManager 7 | from utils.toolkit import count_parameters 8 | import os 9 | import numpy as np 10 | 11 | 12 | def test(args): 13 | seed_list = copy.deepcopy(args['seed']) 14 | device = copy.deepcopy(args['device']) 15 | 16 | for seed in seed_list: 17 | args['seed'] = seed 18 | args['device'] = device 19 | _test(args) 20 | 21 
| 22 | def _test(args): 23 | logfilename = 'logs/{}/{}_test_{}_{}_{}_{}_{}_{}'.format(args['model_name'], args['prefix'], args['seed'], args['model_name'], args['convnet_type'], 24 | args['dataset'], args['init_cls'], args['increment']) 25 | logging.basicConfig( 26 | level=logging.INFO, 27 | format='%(asctime)s [%(filename)s] => %(message)s', 28 | handlers=[ 29 | logging.FileHandler(filename=logfilename + '.log'), 30 | logging.StreamHandler(sys.stdout) 31 | ] 32 | ) 33 | 34 | _set_random() 35 | _set_device(args) 36 | print_args(args) 37 | data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment']) 38 | model = factory.get_model(args['model_name'], args) 39 | 40 | cnn_curve, nme_curve = {'top1': [], 'top5': []}, {'top1': [], 'top5': []} 41 | for task in range(data_manager.nb_tasks): 42 | logging.info('All params: {}'.format(count_parameters(model._network))) 43 | # logging.info('Trainable params: {}'.format(count_parameters(model._network, True))) 44 | # model.incremental_train(data_manager) 45 | model.incremental_update(data_manager) 46 | cnn_accy, nme_accy = model.eval_task() 47 | model.after_task() 48 | 49 | if nme_accy is not None: 50 | logging.info('CNN: {}'.format(cnn_accy['grouped'])) 51 | logging.info('NME: {}'.format(nme_accy['grouped'])) 52 | 53 | cnn_curve['top1'].append(cnn_accy['top1']) 54 | cnn_curve['top5'].append(cnn_accy['top5']) 55 | 56 | nme_curve['top1'].append(nme_accy['top1']) 57 | nme_curve['top5'].append(nme_accy['top5']) 58 | 59 | logging.info('CNN top1 curve: {}'.format(cnn_curve['top1'])) 60 | logging.info('CNN top1 avg: {}'.format(np.array(cnn_curve['top1']).mean())) 61 | if 'task_acc' in cnn_accy.keys(): 62 | logging.info('Task: {}'.format(cnn_accy['task_acc'])) 63 | logging.info('CNN top5 curve: {}'.format(cnn_curve['top5'])) 64 | logging.info('NME top1 curve: {}'.format(nme_curve['top1'])) 65 | logging.info('NME top5 curve: {}\n'.format(nme_curve['top5'])) 66 | else: 67 | 
logging.info('No NME accuracy.') 68 | logging.info('CNN: {}'.format(cnn_accy['grouped'])) 69 | 70 | cnn_curve['top1'].append(cnn_accy['top1']) 71 | cnn_curve['top5'].append(cnn_accy['top5']) 72 | 73 | logging.info('CNN top1 curve: {}'.format(cnn_curve['top1'])) 74 | logging.info('CNN top5 curve: {}\n'.format(cnn_curve['top5'])) 75 | 76 | 77 | def _set_device(args): 78 | device_type = args['device'] 79 | gpus = [] 80 | 81 | for device in device_type: 82 | if device_type == -1: 83 | device = torch.device('cpu') 84 | else: 85 | device = torch.device('cuda:{}'.format(device)) 86 | 87 | gpus.append(device) 88 | 89 | args['device'] = gpus 90 | 91 | 92 | def _set_random(): 93 | torch.manual_seed(1) 94 | torch.cuda.manual_seed(1) 95 | torch.cuda.manual_seed_all(1) 96 | torch.backends.cudnn.deterministic = True 97 | torch.backends.cudnn.benchmark = False 98 | 99 | 100 | def print_args(args): 101 | for key, value in args.items(): 102 | logging.info('{}: {}'.format(key, value)) 103 | 104 | -------------------------------------------------------------------------------- /exps/slca/slca_cars.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cars196_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slca_cars", 11 | "model_postfix": "50e", 12 | "convnet_type": "vit-b-p16", 13 | "device": ["0","1"], 14 | "seed": [1993, 1996, 1997], 15 | "epochs": 50, 16 | "ca_epochs": 5, 17 | "ca_with_logit_norm": 0.05, 18 | "milestones": [40] 19 | } 20 | -------------------------------------------------------------------------------- /exps/slca/slca_cars_mocov3.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cars196_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": 
true, 8 | "init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slca_cars_mocov3", 11 | "model_postfix": "90e", 12 | "convnet_type": "vit-b-p16-mocov3", 13 | "device": ["0","1"], 14 | "seed": [1993, 1996, 1997], 15 | "epochs": 90, 16 | "ca_epochs": 5, 17 | "ca_with_logit_norm": 0.05, 18 | "milestones": [80] 19 | } 20 | -------------------------------------------------------------------------------- /exps/slca/slca_cifar.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "slca_cifar", 11 | "model_postfix": "20e", 12 | "convnet_type": "vit-b-p16", 13 | "device": ["0","1"], 14 | "seed": [1993, 1996, 1997], 15 | "epochs": 20, 16 | "ca_epochs": 5, 17 | "ca_with_logit_norm": 0.1, 18 | "milestones": [18] 19 | } 20 | -------------------------------------------------------------------------------- /exps/slca/slca_cifar_mocov3.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "slca_cifar_mocov3", 11 | "model_postfix": "90e", 12 | "convnet_type": "vit-b-p16-mocov3", 13 | "device": ["0","1"], 14 | "seed": [1993, 1996, 1997], 15 | "epochs": 90, 16 | "ca_epochs": 5, 17 | "ca_with_logit_norm": 0.1, 18 | "milestones": [80] 19 | } 20 | -------------------------------------------------------------------------------- /exps/slca/slca_cub.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cub200_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": 
true, 8 | "init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slca_cub", 11 | "model_postfix": "50e", 12 | "convnet_type": "vit-b-p16", 13 | "device": ["0","1"], 14 | "seed": [1993, 1996, 1997], 15 | "epochs": 50, 16 | "ca_epochs": 5, 17 | "ca_with_logit_norm": 0.1, 18 | "milestones": [40] 19 | } 20 | -------------------------------------------------------------------------------- /exps/slca/slca_cub_mocov3.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cub200_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slca_cub_mocov3", 11 | "model_postfix": "90e", 12 | "convnet_type": "vit-b-p16-mocov3", 13 | "device": ["0","1"], 14 | "seed": [1993, 1996, 1997], 15 | "epochs": 90, 16 | "ca_epochs": 5, 17 | "ca_with_logit_norm": 0.1, 18 | "milestones": [80] 19 | } 20 | -------------------------------------------------------------------------------- /exps/slca/slca_imgnetr.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "imagenet-r", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slca_imgnetr", 11 | "model_postfix": "50e", 12 | "convnet_type": "vit-b-p16", 13 | "device": ["0","1"], 14 | "seed": [1993, 1996, 1997], 15 | "epochs": 50, 16 | "ca_epochs": 5, 17 | "ca_with_logit_norm": 0.1, 18 | "milestones": [40] 19 | } 20 | -------------------------------------------------------------------------------- /exps/slca/slca_imgnetr_mocov3.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "imagenet-r", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | 
"init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slca_imgnetr_mocov3", 11 | "model_postfix": "90e", 12 | "convnet_type": "vit-b-p16-mocov3", 13 | "device": ["0","1"], 14 | "seed": [1993, 1996, 1997], 15 | "epochs": 90, 16 | "ca_epochs": 5, 17 | "ca_with_logit_norm": 0.1, 18 | "milestones": [80] 19 | } 20 | -------------------------------------------------------------------------------- /exps/slcapp/slcapp_cars_lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cars196_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slcapp_cars_lora", 11 | "model_postfix": "50e", 12 | "convnet_type": "vit-b-p16-lora", 13 | "device": ["0"], 14 | "seed": [1993,1996,1997], 15 | "bcb_lrscale": 0.1, 16 | "weight_decay": 0.0, 17 | "lora_rank": 4, 18 | "sce_a": 0.5, 19 | "sce_b": 0.5, 20 | "fc_scale_mu": 0.02, 21 | "epochs": 50, 22 | "ca_epochs": 5, 23 | "ca_with_logit_norm": 0.05, 24 | "milestones": [40] 25 | } 26 | -------------------------------------------------------------------------------- /exps/slcapp/slcapp_cars_lora_mocov3.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cars196_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slcapp_cars_lora_mocov3", 11 | "model_postfix": "90e", 12 | "convnet_type": "vit-b-p16-lora-mocov3", 13 | "device": ["0"], 14 | "seed": [1993,1996,1997], 15 | "bcb_lrscale": 0.1, 16 | "lora_rank": 4, 17 | "sce_a": 0.5, 18 | "sce_b": 0.5, 19 | "fc_scale_mu": 0.02, 20 | "weight_decay": 0.0, 21 | "epochs": 90, 22 | "ca_epochs": 5, 23 | "ca_with_logit_norm": 0.05, 24 | "milestones": [60, 80] 25 | } 26 | 
-------------------------------------------------------------------------------- /exps/slcapp/slcapp_cifar_lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "slcapp_cifar_lora", 11 | "model_postfix": "20e", 12 | "convnet_type": "vit-b-p16-lora", 13 | "device": ["0"], 14 | "seed": [1993,1996,1997], 15 | "epochs": 20, 16 | "ca_epochs": 5, 17 | "ca_with_logit_norm": 0.1, 18 | "bcb_lrscale": 0.1, 19 | "lora_rank": 4, 20 | "sce_a": 1.0, 21 | "sce_b": 0.1, 22 | "fc_scale_mu": 0.02, 23 | "weight_decay": 0.0005, 24 | "milestones": [18] 25 | } 26 | -------------------------------------------------------------------------------- /exps/slcapp/slcapp_cifar_lora_mocov3.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cifar100_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 10, 9 | "increment": 10, 10 | "model_name": "slcapp_cifar_lora_mocov3", 11 | "model_postfix": "90e", 12 | "convnet_type": "vit-b-p16-lora-mocov3", 13 | "device": ["0"], 14 | "seed": [1993,1996,1997], 15 | "epochs": 90, 16 | "ca_epochs": 5, 17 | "ca_with_logit_norm": 0.1, 18 | "bcb_lrscale": 0.1, 19 | "lora_rank": 4, 20 | "sce_a": 1.0, 21 | "sce_b": 0.1, 22 | "fc_scale_mu": 0.02, 23 | "weight_decay": 0.0005, 24 | "milestones": [60, 80] 25 | } 26 | -------------------------------------------------------------------------------- /exps/slcapp/slcapp_cub_lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cub200_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 
8 | "init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slcapp_cub_lora", 11 | "model_postfix": "50e", 12 | "convnet_type": "vit-b-p16-lora", 13 | "device": ["0"], 14 | "seed": [1993,1996,1997], 15 | "bcb_lrscale": 0.1, 16 | "lora_rank": 4, 17 | "sce_a": 1.0, 18 | "sce_b": 0.1, 19 | "fc_scale_mu": 0.02, 20 | "epochs": 50, 21 | "ca_epochs": 5, 22 | "ca_with_logit_norm": 0.1, 23 | "milestones": [40] 24 | } 25 | -------------------------------------------------------------------------------- /exps/slcapp/slcapp_cub_lora_mocov3.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "cub200_224", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slcapp_cub_lora_mocov3", 11 | "model_postfix": "90e", 12 | "convnet_type": "vit-b-p16-lora-mocov3", 13 | "device": ["0"], 14 | "seed": [1993,1996,1997], 15 | "bcb_lrscale": 0.1, 16 | "lora_rank": 4, 17 | "sce_a": 1.0, 18 | "sce_b": 0.1, 19 | "fc_scale_mu": 0.02, 20 | "epochs": 90, 21 | "ca_epochs": 5, 22 | "ca_with_logit_norm": 0.1, 23 | "milestones": [60, 80] 24 | } 25 | -------------------------------------------------------------------------------- /exps/slcapp/slcapp_imgnetr_lora.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "imagenet-r", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slcapp_imgnetr_lora", 11 | "model_postfix": "50e", 12 | "convnet_type": "vit-b-p16-lora", 13 | "bcb_lrscale": 0.1, 14 | "weight_decay": 0.0, 15 | "device": ["0"], 16 | "lora_rank": 4, 17 | "sce_a": 0.5, 18 | "sce_b": 0.5, 19 | "fc_scale_mu": 0.02, 20 | "seed": [1993,1996,1997], 21 | "epochs": 50, 22 | "ca_epochs": 5, 23 | "ca_with_logit_norm": 0.1, 24 | 
"milestones": [40] 25 | } 26 | -------------------------------------------------------------------------------- /exps/slcapp/slcapp_imgnetr_lora_mocov3.json: -------------------------------------------------------------------------------- 1 | { 2 | "prefix": "reproduce", 3 | "dataset": "imagenet-r", 4 | "memory_size": 0, 5 | "memory_per_class": 0, 6 | "fixed_memory": false, 7 | "shuffle": true, 8 | "init_cls": 20, 9 | "increment": 20, 10 | "model_name": "slcapp_imgnetr_lora_mocov3", 11 | "model_postfix": "90e", 12 | "convnet_type": "vit-b-p16-lora-mocov3", 13 | "bcb_lrscale": 0.1, 14 | "weight_decay": 0.0, 15 | "device": ["0"], 16 | "lora_rank": 4, 17 | "sce_a": 0.5, 18 | "sce_b": 0.5, 19 | "fc_scale_mu": 0.02, 20 | "seed": [1993,1996,1997], 21 | "epochs": 90, 22 | "ca_epochs": 5, 23 | "ca_with_logit_norm": 0.1, 24 | "milestones": [60, 80] 25 | } 26 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from trainer import train 4 | from evaluator import test 5 | 6 | def main(): 7 | args = setup_parser().parse_args() 8 | param = load_json(args.config) 9 | args = vars(args) # Converting argparse Namespace to a dict. 
10 | args.update(param) # Add parameters from json 11 | if args['test_only']: 12 | test(args) 13 | else: 14 | train(args) 15 | 16 | 17 | def load_json(settings_path): 18 | with open(settings_path) as data_file: 19 | param = json.load(data_file) 20 | 21 | return param 22 | 23 | 24 | def setup_parser(): 25 | parser = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorthms.') 26 | parser.add_argument('--config', type=str, default='./exps/finetune.json', 27 | help='Json file of settings.') 28 | parser.add_argument('--test_only', action='store_true') 29 | return parser 30 | 31 | 32 | if __name__ == '__main__': 33 | main() 34 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GengDavid/SLCA/169d4e6b91e3ca30caa717200c1944981c4426d1/models/__init__.py -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import DataLoader 7 | from utils.toolkit import tensor2numpy, accuracy 8 | from scipy.spatial.distance import cdist 9 | from collections import OrderedDict 10 | 11 | EPSILON = 1e-8 12 | batch_size = 64 13 | 14 | 15 | class BaseLearner(object): 16 | def __init__(self, args): 17 | self._cur_task = -1 18 | self._known_classes = 0 19 | self._total_classes = 0 20 | self._network = None 21 | self._old_network = None 22 | self._data_memory, self._targets_memory = np.array([]), np.array([]) 23 | self.topk = 5 24 | 25 | self._memory_size = args['memory_size'] 26 | self._memory_per_class = args['memory_per_class'] 27 | self._fixed_memory = args['fixed_memory'] 28 | self._device = args['device'][0] 29 | self._multiple_gpus = args['device'] 
30 | 31 | @property 32 | def exemplar_size(self): 33 | assert len(self._data_memory) == len(self._targets_memory), 'Exemplar size error.' 34 | return len(self._targets_memory) 35 | 36 | @property 37 | def samples_per_class(self): 38 | if self._fixed_memory: 39 | return self._memory_per_class 40 | else: 41 | assert self._total_classes != 0, 'Total classes is 0' 42 | return (self._memory_size // self._total_classes) 43 | 44 | @property 45 | def feature_dim(self): 46 | if isinstance(self._network, nn.DataParallel): 47 | return self._network.module.feature_dim 48 | else: 49 | return self._network.feature_dim 50 | 51 | def save_checkpoint(self, filename, head_only=False, learnable_only=False): 52 | if hasattr(self._network, 'module'): 53 | to_save = self._network.module 54 | else: 55 | to_save = self._network 56 | 57 | if head_only: 58 | to_save_dict = to_save.fc.state_dict() 59 | else: 60 | to_save_dict = to_save.state_dict() 61 | 62 | if learnable_only: 63 | new_dict = OrderedDict() 64 | filtered_keys = [n for n, p in to_save.named_parameters() if p.requires_grad] 65 | for k in filtered_keys: 66 | new_dict[k] = to_save_dict[k] 67 | to_save_dict = new_dict 68 | 69 | save_dict = { 70 | 'tasks': self._cur_task, 71 | 'model_state_dict': to_save_dict, 72 | } 73 | torch.save(save_dict, '{}_{}.pth'.format(filename, self._cur_task)) 74 | 75 | def after_task(self): 76 | pass 77 | 78 | def _evaluate(self, y_pred, y_true): 79 | ret = {} 80 | grouped = accuracy(y_pred.T[0], y_true, self._known_classes) 81 | ret['grouped'] = grouped 82 | ret['top1'] = grouped['total'] 83 | ret['top{}'.format(5)] = np.around((y_pred.T == np.tile(y_true, (self.topk, 1))).sum()*100/len(y_true), 84 | decimals=2) 85 | 86 | return ret 87 | 88 | def eval_task(self): 89 | y_pred, y_true = self._eval_cnn(self.test_loader) 90 | cnn_accy = self._evaluate(y_pred, y_true) 91 | 92 | if hasattr(self, '_class_means') and False: # TODO 93 | y_pred, y_true = self._eval_nme(self.test_loader, self._class_means) 94 | 
nme_accy = self._evaluate(y_pred, y_true) 95 | else: 96 | nme_accy = None 97 | 98 | return cnn_accy, nme_accy 99 | 100 | def incremental_train(self): 101 | pass 102 | 103 | def _train(self): 104 | pass 105 | 106 | def _get_memory(self): 107 | if len(self._data_memory) == 0: 108 | return None 109 | else: 110 | return (self._data_memory, self._targets_memory) 111 | 112 | 113 | def _inner_eval(self, model, loader): 114 | model.eval() 115 | y_pred, y_true = [], [] 116 | for _, (_, inputs, targets) in enumerate(loader): 117 | inputs = inputs.to(self._device) 118 | with torch.no_grad(): 119 | outputs = model(inputs)['logits'] 120 | predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1] # [bs, topk] 121 | y_pred.append(predicts.cpu().numpy()) 122 | y_true.append(targets.cpu().numpy()) 123 | 124 | y_pred, y_true = np.concatenate(y_pred), np.concatenate(y_true) # [N, topk] 125 | 126 | cnn_accy = self._evaluate(y_pred, y_true) 127 | return cnn_accy 128 | 129 | def _compute_accuracy(self, model, loader): 130 | model.eval() 131 | correct, total = 0, 0 132 | for i, (_, inputs, targets) in enumerate(loader): 133 | inputs = inputs.to(self._device) 134 | with torch.no_grad(): 135 | outputs = model(inputs)['logits'] 136 | predicts = torch.max(outputs, dim=1)[1] 137 | correct += (predicts.cpu() == targets).sum() 138 | total += len(targets) 139 | 140 | return np.around(tensor2numpy(correct)*100 / total, decimals=2) 141 | 142 | def _eval_cnn(self, loader): 143 | self._network.eval() 144 | y_pred, y_true = [], [] 145 | for _, (_, inputs, targets) in enumerate(loader): 146 | inputs = inputs.to(self._device) 147 | with torch.no_grad(): 148 | outputs = self._network(inputs)['logits'] 149 | predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1] # [bs, topk] 150 | y_pred.append(predicts.cpu().numpy()) 151 | y_true.append(targets.cpu().numpy()) 152 | 153 | return np.concatenate(y_pred), np.concatenate(y_true) # [N, topk] 154 | 155 | def 
_eval_nme(self, loader, class_means): 156 | self._network.eval() 157 | vectors, y_true = self._extract_vectors(loader) 158 | vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T 159 | 160 | norm_means = class_means / np.linalg.norm(class_means) 161 | dists = cdist(norm_means, vectors, 'sqeuclidean') # [nb_classes, N] 162 | scores = dists.T # [N, nb_classes], choose the one with the smallest distance 163 | 164 | return np.argsort(scores, axis=1)[:, :self.topk], y_true # [N, topk] 165 | 166 | def _extract_vectors(self, loader): 167 | self._network.eval() 168 | vectors, targets = [], [] 169 | for _, _inputs, _targets in loader: 170 | _targets = _targets.numpy() 171 | if isinstance(self._network, nn.DataParallel): 172 | _vectors = tensor2numpy(self._network.module.extract_vector(_inputs.to(self._device))) 173 | else: 174 | _vectors = tensor2numpy(self._network.extract_vector(_inputs.to(self._device))) 175 | 176 | vectors.append(_vectors) 177 | targets.append(_targets) 178 | 179 | return np.concatenate(vectors), np.concatenate(targets) 180 | 181 | def _extract_vectors_aug(self, loader, repeat=2): 182 | self._network.eval() 183 | vectors, targets = [], [] 184 | for _ in range(repeat): 185 | for _, _inputs, _targets in loader: 186 | _targets = _targets.numpy() 187 | with torch.no_grad(): 188 | if isinstance(self._network, nn.DataParallel): 189 | _vectors = tensor2numpy(self._network.module.extract_vector(_inputs.to(self._device))) 190 | else: 191 | _vectors = tensor2numpy(self._network.extract_vector(_inputs.to(self._device))) 192 | 193 | vectors.append(_vectors) 194 | targets.append(_targets) 195 | 196 | return np.concatenate(vectors), np.concatenate(targets) 197 | 198 | def _compute_class_mean(self, data_manager, check_diff=False, oracle=False): 199 | if hasattr(self, '_class_means') and self._class_means is not None and not check_diff: 200 | ori_classes = self._class_means.shape[0] 201 | assert ori_classes==self._known_classes 202 | new_class_means = 
np.zeros((self._total_classes, self.feature_dim)) 203 | new_class_means[:self._known_classes] = self._class_means 204 | self._class_means = new_class_means 205 | # new_class_cov = np.zeros((self._total_classes, self.feature_dim, self.feature_dim)) 206 | new_class_cov = torch.zeros((self._total_classes, self.feature_dim, self.feature_dim)) 207 | new_class_cov[:self._known_classes] = self._class_covs 208 | self._class_covs = new_class_cov 209 | elif not check_diff: 210 | self._class_means = np.zeros((self._total_classes, self.feature_dim)) 211 | # self._class_covs = np.zeros((self._total_classes, self.feature_dim, self.feature_dim)) 212 | self._class_covs = torch.zeros((self._total_classes, self.feature_dim, self.feature_dim)) 213 | 214 | # self._class_covs = [] 215 | 216 | if check_diff: 217 | for class_idx in range(0, self._known_classes): 218 | data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', 219 | mode='test', ret_data=True) 220 | idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) 221 | # vectors, _ = self._extract_vectors_aug(idx_loader) 222 | vectors, _ = self._extract_vectors(idx_loader) 223 | class_mean = np.mean(vectors, axis=0) 224 | # class_cov = np.cov(vectors.T) 225 | class_cov = torch.cov(torch.tensor(vectors, dtype=torch.float64).T) 226 | if check_diff: 227 | log_info = "cls {} sim: {}".format(class_idx, torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0)).item()) 228 | logging.info(log_info) 229 | np.save('task_{}_cls_{}_mean.npy'.format(self._cur_task, class_idx), class_mean) 230 | # print(class_idx, torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0))) 231 | 232 | if oracle: 233 | for class_idx in range(0, self._known_classes): 234 | data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, 
class_idx+1), source='train', 235 | mode='test', ret_data=True) 236 | idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) 237 | vectors, _ = self._extract_vectors(idx_loader) 238 | 239 | # vectors = np.concatenate([vectors_aug, vectors]) 240 | 241 | class_mean = np.mean(vectors, axis=0) 242 | # class_cov = np.cov(vectors.T) 243 | class_cov = torch.cov(torch.tensor(vectors, dtype=torch.float64).T)+torch.eye(class_mean.shape[-1])*1e-5 244 | self._class_means[class_idx, :] = class_mean 245 | self._class_covs[class_idx, ...] = class_cov 246 | 247 | for class_idx in range(self._known_classes, self._total_classes): 248 | # data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', 249 | # mode='train', ret_data=True) 250 | # idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) 251 | # vectors_aug, _ = self._extract_vectors_aug(idx_loader) 252 | 253 | data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', 254 | mode='test', ret_data=True) 255 | idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) 256 | vectors, _ = self._extract_vectors(idx_loader) 257 | 258 | # vectors = np.concatenate([vectors_aug, vectors]) 259 | 260 | class_mean = np.mean(vectors, axis=0) 261 | # class_cov = np.cov(vectors.T) 262 | class_cov = torch.cov(torch.tensor(vectors, dtype=torch.float64).T)+torch.eye(class_mean.shape[-1])*1e-4 263 | if check_diff: 264 | log_info = "cls {} sim: {}".format(class_idx, torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0)).item()) 265 | logging.info(log_info) 266 | np.save('task_{}_cls_{}_mean.npy'.format(self._cur_task, class_idx), class_mean) 267 | np.save('task_{}_cls_{}_mean_beforetrain.npy'.format(self._cur_task, class_idx), self._class_means[class_idx, :]) 268 | # print(class_idx, 
torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0))) 269 | self._class_means[class_idx, :] = class_mean 270 | self._class_covs[class_idx, ...] = class_cov 271 | # self._class_covs.append(class_cov) 272 | -------------------------------------------------------------------------------- /models/slca.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch import optim 6 | from torch.nn import functional as F 7 | from torch.utils.data import DataLoader 8 | from models.base import BaseLearner 9 | from utils.inc_net import FinetuneIncrementalNet 10 | from torchvision import transforms 11 | from torch.distributions.multivariate_normal import MultivariateNormal 12 | import random 13 | from utils.toolkit import tensor2numpy, accuracy 14 | import copy 15 | import os 16 | 17 | epochs = 20 18 | lrate = 0.01 19 | milestones = [60,100,140] 20 | lrate_decay = 0.1 21 | batch_size = 128 22 | split_ratio = 0.1 23 | T = 2 24 | weight_decay = 5e-4 25 | num_workers = 8 26 | ca_epochs = 5 27 | 28 | 29 | class SLCA(BaseLearner): 30 | def __init__(self, args): 31 | super().__init__(args) 32 | self._network = FinetuneIncrementalNet(args, pretrained=True) 33 | self.log_path = "logs/{}_{}".format(args['model_name'], args['model_postfix']) 34 | self.model_prefix = args['prefix'] 35 | if 'epochs' in args.keys(): 36 | global epochs 37 | epochs = args['epochs'] 38 | if 'milestones' in args.keys(): 39 | global milestones 40 | milestones = args['milestones'] 41 | if 'lr' in args.keys(): 42 | global lrate 43 | lrate = args['lr'] 44 | print('set lr to ', lrate) 45 | if 'bcb_lrscale' in args.keys(): 46 | self.bcb_lrscale = args['bcb_lrscale'] 47 | else: 48 | self.bcb_lrscale = 1.0/100 49 | if self.bcb_lrscale == 0: 50 | self.fix_bcb = True 51 | else: 52 | self.fix_bcb = False 53 | print('fic_bcb', 
self.fix_bcb) 54 | 55 | 56 | 57 | if 'save_before_ca' in args.keys() and args['save_before_ca']: 58 | self.save_before_ca = True 59 | else: 60 | self.save_before_ca = False 61 | 62 | if 'ca_epochs' in args.keys(): 63 | global ca_epochs 64 | ca_epochs = args['ca_epochs'] 65 | 66 | if 'ca_with_logit_norm' in args.keys() and args['ca_with_logit_norm']>0: 67 | self.logit_norm = args['ca_with_logit_norm'] 68 | else: 69 | self.logit_norm = None 70 | 71 | self.run_id = args['run_id'] 72 | self.seed = args['seed'] 73 | self.task_sizes = [] 74 | 75 | def after_task(self): 76 | self._known_classes = self._total_classes 77 | logging.info('Exemplar size: {}'.format(self.exemplar_size)) 78 | self.save_checkpoint(self.log_path+'/'+self.model_prefix+'_seed{}'.format(self.seed), head_only=self.fix_bcb) 79 | self._network.fc.recall() 80 | 81 | def incremental_train(self, data_manager): 82 | self._cur_task += 1 83 | task_size = data_manager.get_task_size(self._cur_task) 84 | self.task_sizes.append(task_size) 85 | self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task) 86 | self.topk = self._total_classes if self._total_classes<5 else 5 87 | self._network.update_fc(data_manager.get_task_size(self._cur_task)) 88 | logging.info('Learning on {}-{}'.format(self._known_classes, self._total_classes)) 89 | 90 | self._network.to(self._device) 91 | 92 | train_dset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), 93 | source='train', mode='train', 94 | appendent=[], with_raw=False) 95 | test_dset = data_manager.get_dataset(np.arange(0, self._total_classes), source='test', mode='test') 96 | dset_name = data_manager.dataset_name.lower() 97 | 98 | self.train_loader = DataLoader(train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 99 | self.test_loader = DataLoader(test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 100 | 101 | self._stage1_training(self.train_loader, self.test_loader) 102 | 
103 | if len(self._multiple_gpus) > 1: 104 | self._network = self._network.module 105 | 106 | # CA 107 | self._network.fc.backup() 108 | if self.save_before_ca: 109 | self.save_checkpoint(self.log_path+'/'+self.model_prefix+'_seed{}_before_ca'.format(self.seed), head_only=self.fix_bcb) 110 | 111 | self._compute_class_mean(data_manager, check_diff=False, oracle=False) 112 | if self._cur_task>0 and ca_epochs>0: 113 | self._stage2_compact_classifier(task_size) 114 | if len(self._multiple_gpus) > 1: 115 | self._network = self._network.module 116 | 117 | 118 | def _run(self, train_loader, test_loader, optimizer, scheduler): 119 | run_epochs = epochs 120 | for epoch in range(1, run_epochs+1): 121 | self._network.train() 122 | losses = 0. 123 | for i, (_, inputs, targets) in enumerate(train_loader): 124 | inputs, targets = inputs.to(self._device), targets.to(self._device) 125 | 126 | logits = self._network(inputs, bcb_no_grad=self.fix_bcb)['logits'] 127 | cur_targets = torch.where(targets-self._known_classes>=0,targets-self._known_classes,-100) 128 | loss = F.cross_entropy(logits[:, self._known_classes:], cur_targets) 129 | 130 | optimizer.zero_grad() 131 | loss.backward() 132 | optimizer.step() 133 | losses += loss.item() 134 | 135 | scheduler.step() 136 | if epoch%5==0: 137 | train_acc = self._compute_accuracy(self._network, train_loader) 138 | test_acc = self._compute_accuracy(self._network, test_loader) 139 | info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.3f}, Test_accy {:.3f}'.format( 140 | self._cur_task, epoch, epochs, losses/len(train_loader), train_acc, test_acc) 141 | else: 142 | info = 'Task {}, Epoch {}/{} => Loss {:.3f}'.format( 143 | self._cur_task, epoch, epochs, losses/len(train_loader)) 144 | logging.info(info) 145 | 146 | def _stage1_training(self, train_loader, test_loader): 147 | ''' 148 | if self._cur_task == 0: 149 | loaded_dict = torch.load('./dict_0.pkl') 150 | self._network.load_state_dict(loaded_dict['model_state_dict']) 151 | 
self._network.to(self._device) 152 | return 153 | ''' 154 | base_params = self._network.convnet.parameters() 155 | base_fc_params = [p for p in self._network.fc.parameters() if p.requires_grad==True] 156 | head_scale = 1. if 'moco' in self.log_path else 1. 157 | if not self.fix_bcb: 158 | base_params = {'params': base_params, 'lr': lrate*self.bcb_lrscale, 'weight_decay': weight_decay} 159 | base_fc_params = {'params': base_fc_params, 'lr': lrate*head_scale, 'weight_decay': weight_decay} 160 | network_params = [base_params, base_fc_params] 161 | else: 162 | for p in base_params: 163 | p.requires_grad = False 164 | network_params = [{'params': base_fc_params, 'lr': lrate*head_scale, 'weight_decay': weight_decay}] 165 | optimizer = optim.SGD(network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay) 166 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=lrate_decay) 167 | 168 | if len(self._multiple_gpus) > 1: 169 | self._network = nn.DataParallel(self._network, self._multiple_gpus) 170 | 171 | self._run(train_loader, test_loader, optimizer, scheduler) 172 | 173 | 174 | def _stage2_compact_classifier(self, task_size): 175 | for p in self._network.fc.parameters(): 176 | p.requires_grad=True 177 | 178 | run_epochs = ca_epochs 179 | crct_num = self._total_classes 180 | param_list = [p for p in self._network.fc.parameters() if p.requires_grad] 181 | network_params = [{'params': param_list, 'lr': lrate, 182 | 'weight_decay': weight_decay}] 183 | optimizer = optim.SGD(network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay) 184 | # scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[4], gamma=lrate_decay) 185 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=run_epochs) 186 | 187 | self._network.to(self._device) 188 | if len(self._multiple_gpus) > 1: 189 | self._network = nn.DataParallel(self._network, self._multiple_gpus) 190 | 191 | self._network.eval() 192 | 
for epoch in range(run_epochs): 193 | losses = 0. 194 | 195 | sampled_data = [] 196 | sampled_label = [] 197 | num_sampled_pcls = 256 198 | 199 | for c_id in range(crct_num): 200 | t_id = c_id//task_size 201 | decay = (t_id+1)/(self._cur_task+1)*0.1 202 | cls_mean = torch.tensor(self._class_means[c_id], dtype=torch.float64).to(self._device)*(0.9+decay) # torch.from_numpy(self._class_means[c_id]).to(self._device) 203 | cls_cov = self._class_covs[c_id].to(self._device) 204 | 205 | m = MultivariateNormal(cls_mean.float(), cls_cov.float()) 206 | 207 | sampled_data_single = m.sample(sample_shape=(num_sampled_pcls,)) 208 | sampled_data.append(sampled_data_single) 209 | sampled_label.extend([c_id]*num_sampled_pcls) 210 | 211 | sampled_data = torch.cat(sampled_data, dim=0).float().to(self._device) 212 | sampled_label = torch.tensor(sampled_label).long().to(self._device) 213 | 214 | inputs = sampled_data 215 | targets= sampled_label 216 | 217 | sf_indexes = torch.randperm(inputs.size(0)) 218 | inputs = inputs[sf_indexes] 219 | targets = targets[sf_indexes] 220 | 221 | 222 | for _iter in range(crct_num): 223 | inp = inputs[_iter*num_sampled_pcls:(_iter+1)*num_sampled_pcls] 224 | tgt = targets[_iter*num_sampled_pcls:(_iter+1)*num_sampled_pcls] 225 | outputs = self._network(inp, bcb_no_grad=True, fc_only=True) 226 | logits = outputs['logits'] 227 | 228 | if self.logit_norm is not None: 229 | per_task_norm = [] 230 | prev_t_size = 0 231 | cur_t_size = 0 232 | for _ti in range(self._cur_task+1): 233 | cur_t_size += self.task_sizes[_ti] 234 | temp_norm = torch.norm(logits[:, prev_t_size:cur_t_size], p=2, dim=-1, keepdim=True) + 1e-7 235 | per_task_norm.append(temp_norm) 236 | prev_t_size += self.task_sizes[_ti] 237 | per_task_norm = torch.cat(per_task_norm, dim=-1) 238 | norms = per_task_norm.mean(dim=-1, keepdim=True) 239 | 240 | norms_all = torch.norm(logits[:, :crct_num], p=2, dim=-1, keepdim=True) + 1e-7 241 | decoupled_logits = torch.div(logits[:, :crct_num], norms) / 
self.logit_norm 242 | loss = F.cross_entropy(decoupled_logits, tgt) 243 | 244 | else: 245 | loss = F.cross_entropy(logits[:, :crct_num], tgt) 246 | 247 | optimizer.zero_grad() 248 | loss.backward() 249 | optimizer.step() 250 | losses += loss.item() 251 | 252 | scheduler.step() 253 | test_acc = self._compute_accuracy(self._network, self.test_loader) 254 | info = 'CA Task {} => Loss {:.3f}, Test_accy {:.3f}'.format( 255 | self._cur_task, losses/self._total_classes, test_acc) 256 | logging.info(info) 257 | 258 | 259 | -------------------------------------------------------------------------------- /models/slca_pp.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch import optim 6 | from torch.nn import functional as F 7 | from torch.utils.data import DataLoader 8 | from models.slca import SLCA 9 | from utils.inc_net import FinetuneIncrementalNet 10 | from torchvision import transforms 11 | from torch.distributions.multivariate_normal import MultivariateNormal 12 | import random 13 | from utils.toolkit import tensor2numpy, accuracy 14 | import copy 15 | import os 16 | 17 | epochs = 20 18 | lrate = 0.01 19 | milestones = [60,100,140] 20 | lrate_decay = 0.1 21 | batch_size = 128 22 | split_ratio = 0.1 23 | T = 2 24 | weight_decay = 5e-4 25 | num_workers = 8 26 | ca_epochs = 5 27 | 28 | 29 | class SLCApp(SLCA): 30 | def __init__(self, args): 31 | super().__init__(args) 32 | self.model_name_ = args['model_name'] 33 | self.sce_a, self.sce_b = args['sce_a'], args['sce_b'] 34 | assert 'fc_scale_mu' in args.keys() and args['fc_scale_mu']>0 35 | self.fc_scale_mu = args['fc_scale_mu'] 36 | 37 | # hybrid-SL 38 | for n, p in self._network.convnet.named_parameters(): 39 | if 'norm' in n or 'bias' in n or 'cls_token' in n or 'lora' in n: 40 | p.requires_grad=True 41 | print(n) 42 | else: 43 | p.requires_grad=False 44 | 45 | # reset global values 
according to args 46 | if 'weight_decay' in args.keys(): 47 | global weight_decay 48 | weight_decay = args['weight_decay'] 49 | if 'epochs' in args.keys(): 50 | global epochs 51 | epochs = args['epochs'] 52 | if 'milestones' in args.keys(): 53 | global milestones 54 | milestones = args['milestones'] 55 | if 'lr' in args.keys(): 56 | global lrate 57 | lrate = args['lr'] 58 | if 'ca_epochs' in args.keys(): 59 | global ca_epochs 60 | ca_epochs = args['ca_epochs'] 61 | 62 | def after_task(self): 63 | self._known_classes = self._total_classes 64 | logging.info('Exemplar size: {}'.format(self.exemplar_size)) 65 | self.save_checkpoint(self.log_path+'/'+self.model_prefix+'_seed{}'.format(self.seed), head_only=self.fix_bcb, learnable_only=True) 66 | self._network.fc.recall() 67 | 68 | def incremental_train(self, data_manager): 69 | self._cur_task += 1 70 | task_size = data_manager.get_task_size(self._cur_task) 71 | self.task_sizes.append(task_size) 72 | self._total_classes = self._known_classes + task_size 73 | self.topk = self._total_classes if self._total_classes<5 else 5 74 | self._network.update_fc(task_size) 75 | logging.info('Learning on {}-{}'.format(self._known_classes, self._total_classes)) 76 | 77 | self._network.to(self._device) 78 | 79 | if self._cur_task==0 and self._network.convnet.lora_rank>0: 80 | for b_idx in range(self._network.convnet.lora_lp): 81 | self._network.convnet.blocks[b_idx].mlp.init_lora() 82 | self._network.convnet.blocks[b_idx].attn.init_lora() 83 | 84 | train_dset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), 85 | source='train', mode='train', 86 | appendent=[], with_raw=False) 87 | test_dset = data_manager.get_dataset(np.arange(0, self._total_classes), source='test', mode='test') 88 | dset_name = data_manager.dataset_name.lower() 89 | 90 | self.train_loader = DataLoader(train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 91 | self.test_loader = DataLoader(test_dset, 
batch_size=batch_size, shuffle=False, num_workers=num_workers) 92 | 93 | self._stage1_training(self.train_loader, self.test_loader) 94 | 95 | if len(self._multiple_gpus) > 1: 96 | self._network = self._network.module 97 | 98 | # CA 99 | self._network.fc.backup() 100 | if self.save_before_ca: 101 | self.save_checkpoint(self.log_path+'/'+self.model_prefix+'_seed{}_before_ca'.format(self.seed), head_only=self.fix_bcb) 102 | 103 | self._compute_class_mean(data_manager, check_diff=False, oracle=False) 104 | if self._cur_task>0 and ca_epochs>0: 105 | self._stage2_compact_classifier(task_size) 106 | if len(self._multiple_gpus) > 1: 107 | self._network = self._network.module 108 | self._network.fc.update_scale() 109 | 110 | 111 | def _run(self, train_loader, test_loader, optimizer, scheduler): 112 | run_epochs = epochs 113 | for epoch in range(1, run_epochs+1): 114 | self._network.train() 115 | losses = 0. 116 | for i, (_, inputs, targets) in enumerate(train_loader): 117 | inputs, targets = inputs.to(self._device), targets.to(self._device) 118 | 119 | logits = self._network(inputs, bcb_no_grad=self.fix_bcb)['logits'] 120 | cur_targets = torch.where(targets-self._known_classes>=0,targets-self._known_classes,-100) 121 | cur_logits = logits[:, self._known_classes:] 122 | 123 | ce_loss = F.cross_entropy(cur_logits, cur_targets) 124 | pred = F.softmax(cur_logits, dim=1) 125 | pred = torch.clamp(pred, min=1e-7, max=1.0) 126 | label_one_hot = torch.nn.functional.one_hot(cur_targets, pred.size(1)).float().to(pred.device) 127 | label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0) 128 | rce_loss = (-1*torch.sum(pred * torch.log(label_one_hot), dim=1)) 129 | loss = self.sce_a*ce_loss + self.sce_b*rce_loss.mean() 130 | 131 | optimizer.zero_grad() 132 | loss.backward() 133 | optimizer.step() 134 | losses += loss.item() 135 | 136 | scheduler.step() 137 | if epoch%5==0: 138 | train_acc = self._compute_accuracy(self._network, train_loader) 139 | test_acc = 
self._compute_accuracy(self._network, test_loader) 140 | info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.3f}, Test_accy {:.3f}'.format( 141 | self._cur_task, epoch, epochs, losses/len(train_loader), train_acc, test_acc) 142 | else: 143 | info = 'Task {}, Epoch {}/{} => Loss {:.3f}'.format( 144 | self._cur_task, epoch, epochs, losses/len(train_loader)) 145 | logging.info(info) 146 | 147 | def _stage1_training(self, train_loader, test_loader): 148 | if self._network.convnet.lora_rank>0: 149 | base_loraA_params = [p for n, p in self._network.convnet.named_parameters() if p.requires_grad==True and 'lora_A' in n] 150 | base_loraB_params = [p for n, p in self._network.convnet.named_parameters() if p.requires_grad==True and 'lora_B' in n] 151 | base_others_params = [p for n, p in self._network.convnet.named_parameters() if p.requires_grad==True and 'lora' not in n] 152 | else: 153 | base_params = self._network.convnet.parameters() 154 | 155 | base_fc_params = [p for p in self._network.fc.parameters() if p.requires_grad==True] 156 | head_scale = 1. 
157 | if not self.fix_bcb: 158 | base_fc_params = {'params': base_fc_params, 'lr': lrate*head_scale, 'weight_decay': weight_decay} 159 | if self._network.convnet.lora_rank>0: 160 | lora_scale = 0.5 if self._cur_task==0 else self.bcb_lrscale 161 | base_loraA_params = {'params': base_loraA_params, 'lr': lrate*lora_scale, 'weight_decay': weight_decay} 162 | base_loraB_params = {'params': base_loraB_params, 'lr': lrate*lora_scale, 'weight_decay': weight_decay} 163 | base_others_params = {'params': base_others_params, 'lr': lrate*self.bcb_lrscale, 'weight_decay': weight_decay} 164 | network_params = [base_loraA_params, base_loraB_params, base_others_params, base_fc_params] 165 | else: 166 | base_params = {'params': base_params, 'lr': lrate*self.bcb_lrscale, 'weight_decay': weight_decay} 167 | network_params = [base_params, base_fc_params] 168 | else: 169 | for p in base_params: 170 | p.requires_grad = False 171 | network_params = [{'params': base_fc_params, 'lr': lrate*head_scale, 'weight_decay': weight_decay}] 172 | optimizer = optim.SGD(network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay) 173 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=epochs, eta_min=lrate*self.bcb_lrscale*lrate_decay**len(milestones)) 174 | 175 | if len(self._multiple_gpus) > 1: 176 | self._network = nn.DataParallel(self._network, self._multiple_gpus) 177 | 178 | self._run(train_loader, test_loader, optimizer, scheduler) 179 | 180 | 181 | def _stage2_compact_classifier(self, task_size): 182 | for n, p in self._network.fc.named_parameters(): 183 | p.requires_grad=True 184 | if 'scales' in n: 185 | p.requires_grad=False 186 | print('fixed ', n) 187 | 188 | run_epochs = ca_epochs 189 | crct_num = self._total_classes 190 | param_list = [p for p in self._network.fc.parameters() if p.requires_grad] 191 | network_params = [{'params': param_list, 'lr': lrate, 192 | 'weight_decay': weight_decay}] 193 | optimizer = optim.SGD(network_params, lr=lrate, 
momentum=0.9, weight_decay=weight_decay) 194 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=run_epochs) 195 | 196 | self._network.to(self._device) 197 | if len(self._multiple_gpus) > 1: 198 | self._network = nn.DataParallel(self._network, self._multiple_gpus) 199 | 200 | self._network.eval() 201 | for epoch in range(run_epochs): 202 | losses = 0. 203 | 204 | sampled_data = [] 205 | sampled_label = [] 206 | num_sampled_pcls = 256 207 | 208 | for c_id in range(crct_num): 209 | t_id = c_id//task_size 210 | cls_mean = torch.tensor(self._class_means[c_id], dtype=torch.float64).to(self._device) 211 | cls_cov = self._class_covs[c_id].to(self._device) 212 | 213 | m = MultivariateNormal(cls_mean.float(), cls_cov.float()) 214 | 215 | sampled_data_single = m.sample(sample_shape=(num_sampled_pcls,)) 216 | 217 | # a more robust implementation of feature scaling 218 | # fixed scaling during inference, dynamic during training 219 | rand_scaling = self.fc_scale_mu + 0.02*(torch.rand(sampled_data_single.size(0), device=self._device)-0.5) 220 | rand_scaling = 1/(1+rand_scaling*(self._cur_task-t_id)) 221 | sampled_data_single = rand_scaling.unsqueeze(1)*sampled_data_single 222 | 223 | sampled_data.append(sampled_data_single) 224 | sampled_label.extend([c_id]*num_sampled_pcls) 225 | 226 | sampled_data = torch.cat(sampled_data, dim=0).float().to(self._device) 227 | sampled_label = torch.tensor(sampled_label).long().to(self._device) 228 | 229 | inputs = sampled_data 230 | targets= sampled_label 231 | 232 | sf_indexes = torch.randperm(inputs.size(0)) 233 | inputs = inputs[sf_indexes] 234 | targets = targets[sf_indexes] 235 | 236 | 237 | for _iter in range(crct_num): 238 | inp = inputs[_iter*num_sampled_pcls:(_iter+1)*num_sampled_pcls] 239 | tgt = targets[_iter*num_sampled_pcls:(_iter+1)*num_sampled_pcls] 240 | outputs = self._network(inp, bcb_no_grad=True, fc_only=True) 241 | logits = outputs['logits'] 242 | 243 | if self.logit_norm is not None: 244 | 
per_task_norm = [] 245 | prev_t_size = 0 246 | cur_t_size = 0 247 | for _ti in range(self._cur_task+1): 248 | cur_t_size += self.task_sizes[_ti] 249 | temp_norm = torch.norm(logits[:, prev_t_size:cur_t_size], p=2, dim=-1, keepdim=True) + 1e-7 250 | per_task_norm.append(temp_norm) 251 | prev_t_size += self.task_sizes[_ti] 252 | per_task_norm = torch.cat(per_task_norm, dim=-1) 253 | norms = per_task_norm.mean(dim=-1, keepdim=True) 254 | 255 | norms_all = torch.norm(logits[:, :crct_num], p=2, dim=-1, keepdim=True) + 1e-7 256 | decoupled_logits = torch.div(logits[:, :crct_num], norms) / self.logit_norm 257 | loss = F.cross_entropy(decoupled_logits, tgt) 258 | 259 | else: 260 | loss = F.cross_entropy(logits[:, :crct_num], tgt) 261 | 262 | optimizer.zero_grad() 263 | loss.backward() 264 | optimizer.step() 265 | losses += loss.item() 266 | 267 | scheduler.step() 268 | test_acc = self._compute_accuracy(self._network, self.test_loader) 269 | info = 'CA Task {} => Loss {:.3f}, Test_accy {:.3f}'.format( 270 | self._cur_task, losses/self._total_classes, test_acc) 271 | logging.info(info) 272 | 273 | 274 | -------------------------------------------------------------------------------- /slca_performance.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GengDavid/SLCA/169d4e6b91e3ca30caa717200c1944981c4426d1/slca_performance.jpg -------------------------------------------------------------------------------- /split_car.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import os.path as osp 4 | import shutil 5 | from tqdm import tqdm 6 | 7 | img_list = np.genfromtxt('mat2txt.txt', dtype=str) # [num, img, cls, istest] 8 | class_mappings = np.genfromtxt('label_map.txt', dtype=str) 9 | class_mappings = {a[0]: a[1] for a in class_mappings} 10 | 11 | for item in tqdm(img_list): 12 | if bool(int(item[-1])): 13 | cls_folder = 
osp.join('cars196', 'train', class_mappings[item[2]]) 14 | if not os.path.exists(cls_folder): 15 | os.mkdir(cls_folder) 16 | shutil.copy(item[1], osp.join(cls_folder, item[1].split('/')[-1])) 17 | else: 18 | cls_folder = osp.join('cars196', 'val', class_mappings[item[2]]) 19 | if not os.path.exists(cls_folder): 20 | os.mkdir(cls_folder) 21 | shutil.copy(item[1], osp.join(cls_folder, item[1].split('/')[-1])) 22 | -------------------------------------------------------------------------------- /split_cub.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import os.path as osp 4 | import shutil 5 | from tqdm import tqdm 6 | 7 | train_val_list = np.genfromtxt('train_test_split.txt', dtype='str') 8 | img_list = np.genfromtxt('images.txt', dtype='str') 9 | 10 | img_id_mapping = {a[0]: a[1] for a in img_list} 11 | for img, is_train in tqdm(train_val_list): 12 | if bool(int(is_train)): 13 | # print(osp.join('CUB200', 'val', img_id_mapping[img])) 14 | os.remove(osp.join('CUB200', 'val', img_id_mapping[img])) 15 | else: 16 | os.remove(osp.join('CUB200', 'train', img_id_mapping[img])) 17 | -------------------------------------------------------------------------------- /train_all.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 main.py --config=exps/slcapp/slcapp_cifar_lora.json & 2 | CUDA_VISIBLE_DEVICES=1 python3 main.py --config=exps/slcapp/slcapp_imgnetr_lora.json & 3 | CUDA_VISIBLE_DEVICES=2 python3 main.py --config=exps/slcapp/slcapp_cub_lora.json & 4 | CUDA_VISIBLE_DEVICES=3 python3 main.py --config=exps/slcapp/slcapp_cars_lora.json & 5 | 6 | CUDA_VISIBLE_DEVICES=0 python3 main.py --config=exps/slcapp/slcapp_cifar_lora_mocov3.json & 7 | CUDA_VISIBLE_DEVICES=1 python3 main.py --config=exps/slcapp/slcapp_imgnetr_lora_mocov3.json & 8 | CUDA_VISIBLE_DEVICES=2 python3 main.py --config=exps/slcapp/slcapp_cub_lora_mocov3.json & 9 | 
CUDA_VISIBLE_DEVICES=3 python3 main.py --config=exps/slcapp/slcapp_cars_lora_mocov3.json & -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import copy 4 | import torch 5 | from utils import factory 6 | from utils.data_manager import DataManager 7 | from utils.toolkit import count_parameters 8 | import os 9 | import numpy as np 10 | 11 | def train(args): 12 | seed_list = copy.deepcopy(args['seed']) 13 | device = copy.deepcopy(args['device']) 14 | 15 | res_finals, res_avgs = [], [] 16 | for run_id, seed in enumerate(seed_list): 17 | args['seed'] = seed 18 | args['run_id'] = run_id 19 | args['device'] = device 20 | res_final, res_avg = _train(args) 21 | res_finals.append(res_final) 22 | res_avgs.append(res_avg) 23 | logging.info('final accs: {}'.format(res_finals)) 24 | logging.info('avg accs: {}'.format(res_avgs)) 25 | 26 | 27 | 28 | def _train(args): 29 | try: 30 | os.mkdir("logs/{}_{}".format(args['model_name'], args['model_postfix'])) 31 | except: 32 | pass 33 | logfilename = 'logs/{}_{}/{}_{}_{}_{}_{}_{}_{}'.format(args['model_name'], args['model_postfix'], args['prefix'], args['seed'], args['model_name'], args['convnet_type'], 34 | args['dataset'], args['init_cls'], args['increment']) 35 | logging.basicConfig( 36 | level=logging.INFO, 37 | format='%(asctime)s [%(filename)s] => %(message)s', 38 | handlers=[ 39 | logging.FileHandler(filename=logfilename + '.log'), 40 | logging.StreamHandler(sys.stdout) 41 | ] 42 | ) 43 | 44 | _set_random() 45 | _set_device(args) 46 | print_args(args) 47 | data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment']) 48 | model = factory.get_model(args['model_name'], args) 49 | 50 | cnn_curve, nme_curve = {'top1': [], 'top5': []}, {'top1': [], 'top5': []} 51 | for task in range(data_manager.nb_tasks): 52 | 
logging.info('All params: {}'.format(count_parameters(model._network))) 53 | logging.info('Trainable params: {}'.format(count_parameters(model._network, True))) 54 | model.incremental_train(data_manager) 55 | 56 | cnn_accy, nme_accy = model.eval_task() 57 | model.after_task() 58 | 59 | 60 | if nme_accy is not None: 61 | logging.info('CNN: {}'.format(cnn_accy['grouped'])) 62 | logging.info('NME: {}'.format(nme_accy['grouped'])) 63 | 64 | cnn_curve['top1'].append(cnn_accy['top1']) 65 | cnn_curve['top5'].append(cnn_accy['top5']) 66 | 67 | nme_curve['top1'].append(nme_accy['top1']) 68 | nme_curve['top5'].append(nme_accy['top5']) 69 | 70 | logging.info('CNN top1 curve: {}'.format(cnn_curve['top1'])) 71 | logging.info('CNN top1 avg: {}'.format(np.array(cnn_curve['top1']).mean())) 72 | if 'task_acc' in cnn_accy.keys(): 73 | logging.info('Task: {}'.format(cnn_accy['task_acc'])) 74 | logging.info('CNN top5 curve: {}'.format(cnn_curve['top5'])) 75 | logging.info('NME top1 curve: {}'.format(nme_curve['top1'])) 76 | logging.info('NME top5 curve: {}\n'.format(nme_curve['top5'])) 77 | else: 78 | logging.info('No NME accuracy.') 79 | logging.info('CNN: {}'.format(cnn_accy['grouped'])) 80 | 81 | cnn_curve['top1'].append(cnn_accy['top1']) 82 | cnn_curve['top5'].append(cnn_accy['top5']) 83 | 84 | logging.info('CNN top1 curve: {}'.format(cnn_curve['top1'])) 85 | logging.info('CNN top1 avg: {}'.format(np.array(cnn_curve['top1']).mean())) 86 | logging.info('CNN top5 curve: {}\n'.format(cnn_curve['top5'])) 87 | 88 | return (cnn_curve['top1'][-1], np.array(cnn_curve['top1']).mean()) 89 | 90 | def _set_device(args): 91 | device_type = args['device'] 92 | gpus = [] 93 | 94 | for device in device_type: 95 | if device_type == -1: 96 | device = torch.device('cpu') 97 | else: 98 | device = torch.device('cuda:{}'.format(device)) 99 | 100 | gpus.append(device) 101 | 102 | args['device'] = gpus 103 | 104 | 105 | def _set_random(): 106 | torch.manual_seed(1) 107 | torch.cuda.manual_seed(1) 108 | 
torch.cuda.manual_seed_all(1) 109 | torch.backends.cudnn.deterministic = True 110 | torch.backends.cudnn.benchmark = False 111 | 112 | 113 | def print_args(args): 114 | for key, value in args.items(): 115 | logging.info('{}: {}'.format(key, value)) 116 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GengDavid/SLCA/169d4e6b91e3ca30caa717200c1944981c4426d1/utils/__init__.py -------------------------------------------------------------------------------- /utils/buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import numpy as np 8 | from typing import Tuple 9 | from torchvision import transforms 10 | from copy import deepcopy 11 | 12 | def icarl_replay(self, dataset, val_set_split=0): 13 | """ 14 | Merge the replay buffer with the current task data. 15 | Optionally split the replay buffer into a validation set. 
16 | 17 | :param self: the model instance 18 | :param dataset: the dataset 19 | :param val_set_split: the fraction of the replay buffer to be used as validation set 20 | """ 21 | 22 | if self.task > 0: 23 | buff_val_mask = torch.rand(len(self.buffer)) < val_set_split 24 | val_train_mask = torch.zeros(len(dataset.train_loader.dataset.data)).bool() 25 | val_train_mask[torch.randperm(len(dataset.train_loader.dataset.data))[:buff_val_mask.sum()]] = True 26 | 27 | if val_set_split > 0: 28 | self.val_loader = deepcopy(dataset.train_loader) 29 | 30 | data_concatenate = torch.cat if type(dataset.train_loader.dataset.data) == torch.Tensor else np.concatenate 31 | need_aug = hasattr(dataset.train_loader.dataset, 'not_aug_transform') 32 | if not need_aug: 33 | refold_transform = lambda x: x.cpu() 34 | else: 35 | data_shape = len(dataset.train_loader.dataset.data[0].shape) 36 | if data_shape == 3: 37 | refold_transform = lambda x: (x.cpu()*255).permute([0, 2, 3, 1]).numpy().astype(np.uint8) 38 | elif data_shape == 2: 39 | refold_transform = lambda x: (x.cpu()*255).squeeze(1).type(torch.uint8) 40 | 41 | # REDUCE AND MERGE TRAINING SET 42 | dataset.train_loader.dataset.targets = np.concatenate([ 43 | dataset.train_loader.dataset.targets[~val_train_mask], 44 | self.buffer.labels.cpu().numpy()[:len(self.buffer)][~buff_val_mask] 45 | ]) 46 | dataset.train_loader.dataset.data = data_concatenate([ 47 | dataset.train_loader.dataset.data[~val_train_mask], 48 | refold_transform((self.buffer.examples)[:len(self.buffer)][~buff_val_mask]) 49 | ]) 50 | 51 | if val_set_split > 0: 52 | # REDUCE AND MERGE VALIDATION SET 53 | self.val_loader.dataset.targets = np.concatenate([ 54 | self.val_loader.dataset.targets[val_train_mask], 55 | self.buffer.labels.cpu().numpy()[:len(self.buffer)][buff_val_mask] 56 | ]) 57 | self.val_loader.dataset.data = data_concatenate([ 58 | self.val_loader.dataset.data[val_train_mask], 59 | refold_transform((self.buffer.examples)[:len(self.buffer)][buff_val_mask]) 60 | 
]) 61 | 62 | def reservoir(num_seen_examples: int, buffer_size: int) -> int: 63 | """ 64 | Reservoir sampling algorithm. 65 | :param num_seen_examples: the number of seen examples 66 | :param buffer_size: the maximum buffer size 67 | :return: the target index if the current image is sampled, else -1 68 | """ 69 | if num_seen_examples < buffer_size: 70 | return num_seen_examples 71 | 72 | rand = np.random.randint(0, num_seen_examples + 1) 73 | if rand < buffer_size: 74 | return rand 75 | else: 76 | return -1 77 | 78 | 79 | def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int: 80 | return num_seen_examples % buffer_portion_size + task * buffer_portion_size 81 | 82 | 83 | class Buffer: 84 | """ 85 | The memory buffer of rehearsal method. 86 | """ 87 | def __init__(self, buffer_size, device, n_tasks=None, mode='reservoir'): 88 | assert mode in ['ring', 'reservoir'] 89 | self.buffer_size = buffer_size 90 | self.device = device 91 | self.num_seen_examples = 0 92 | self.functional_index = eval(mode) 93 | if mode == 'ring': 94 | assert n_tasks is not None 95 | self.task_number = n_tasks 96 | self.buffer_portion_size = buffer_size // n_tasks 97 | self.attributes = ['examples', 'labels', 'logits', 'task_labels'] 98 | 99 | def to(self, device): 100 | self.device = device 101 | for attr_str in self.attributes: 102 | if hasattr(self, attr_str): 103 | setattr(self, attr_str, getattr(self, attr_str).to(device)) 104 | return self 105 | 106 | def __len__(self): 107 | return min(self.num_seen_examples, self.buffer_size) 108 | 109 | 110 | def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor, 111 | logits: torch.Tensor, task_labels: torch.Tensor) -> None: 112 | """ 113 | Initializes just the required tensors. 
114 | :param examples: tensor containing the images 115 | :param labels: tensor containing the labels 116 | :param logits: tensor containing the outputs of the network 117 | :param task_labels: tensor containing the task labels 118 | """ 119 | for attr_str in self.attributes: 120 | attr = eval(attr_str) 121 | if attr is not None and not hasattr(self, attr_str): 122 | typ = torch.int64 if attr_str.endswith('els') else torch.float32 123 | setattr(self, attr_str, torch.zeros((self.buffer_size, 124 | *attr.shape[1:]), dtype=typ, device=self.device)) 125 | 126 | def add_data(self, examples, labels=None, logits=None, task_labels=None): 127 | """ 128 | Adds the data to the memory buffer according to the reservoir strategy. 129 | :param examples: tensor containing the images 130 | :param labels: tensor containing the labels 131 | :param logits: tensor containing the outputs of the network 132 | :param task_labels: tensor containing the task labels 133 | :return: 134 | """ 135 | if not hasattr(self, 'examples'): 136 | self.init_tensors(examples, labels, logits, task_labels) 137 | 138 | for i in range(examples.shape[0]): 139 | index = reservoir(self.num_seen_examples, self.buffer_size) 140 | self.num_seen_examples += 1 141 | if index >= 0: 142 | self.examples[index] = examples[i].to(self.device) 143 | if labels is not None: 144 | self.labels[index] = labels[i].to(self.device) 145 | if logits is not None: 146 | self.logits[index] = logits[i].to(self.device) 147 | if task_labels is not None: 148 | self.task_labels[index] = task_labels[i].to(self.device) 149 | 150 | def get_data(self, size: int, transform: transforms=None, return_index=False) -> Tuple: 151 | """ 152 | Random samples a batch of size items. 
153 | :param size: the number of requested items 154 | :param transform: the transformation to be applied (data augmentation) 155 | :return: 156 | """ 157 | if size > min(self.num_seen_examples, self.examples.shape[0]): 158 | size = min(self.num_seen_examples, self.examples.shape[0]) 159 | 160 | choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]), 161 | size=size, replace=False) 162 | if transform is None: transform = lambda x: x 163 | ret_tuple = (torch.stack([transform(ee.cpu()) 164 | for ee in self.examples[choice]]).to(self.device),) 165 | for attr_str in self.attributes[1:]: 166 | if hasattr(self, attr_str): 167 | attr = getattr(self, attr_str) 168 | ret_tuple += (attr[choice],) 169 | 170 | if not return_index: 171 | return ret_tuple 172 | else: 173 | return (torch.tensor(choice).to(self.device), ) + ret_tuple 174 | 175 | return ret_tuple 176 | 177 | def get_data_by_index(self, indexes, transform: transforms=None) -> Tuple: 178 | """ 179 | Returns the data by the given index. 180 | :param index: the index of the item 181 | :param transform: the transformation to be applied (data augmentation) 182 | :return: 183 | """ 184 | if transform is None: transform = lambda x: x 185 | ret_tuple = (torch.stack([transform(ee.cpu()) 186 | for ee in self.examples[indexes]]).to(self.device),) 187 | for attr_str in self.attributes[1:]: 188 | if hasattr(self, attr_str): 189 | attr = getattr(self, attr_str).to(self.device) 190 | ret_tuple += (attr[indexes],) 191 | return ret_tuple 192 | 193 | 194 | def is_empty(self) -> bool: 195 | """ 196 | Returns true if the buffer is empty, false otherwise. 197 | """ 198 | if self.num_seen_examples == 0: 199 | return True 200 | else: 201 | return False 202 | 203 | def get_all_data(self, transform: transforms=None) -> Tuple: 204 | """ 205 | Return all the items in the memory buffer. 
206 | :param transform: the transformation to be applied (data augmentation) 207 | :return: a tuple with all the items in the memory buffer 208 | """ 209 | if transform is None: transform = lambda x: x 210 | ret_tuple = (torch.stack([transform(ee.cpu()) 211 | for ee in self.examples]).to(self.device),) 212 | for attr_str in self.attributes[1:]: 213 | if hasattr(self, attr_str): 214 | attr = getattr(self, attr_str) 215 | ret_tuple += (attr,) 216 | return ret_tuple 217 | 218 | def empty(self) -> None: 219 | """ 220 | Set all the tensors to None. 221 | """ 222 | for attr_str in self.attributes: 223 | if hasattr(self, attr_str): 224 | delattr(self, attr_str) 225 | self.num_seen_examples = 0 226 | -------------------------------------------------------------------------------- /utils/cutmix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | def rand_bbox(size, lam): 6 | W = size[2] 7 | H = size[3] 8 | cut_rat = np.sqrt(1. 
- lam) 9 | cut_w = np.int(W * cut_rat) 10 | cut_h = np.int(H * cut_rat) 11 | 12 | # uniform 13 | cx = np.random.randint(W) 14 | cy = np.random.randint(H) 15 | 16 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 17 | bby1 = np.clip(cy - cut_h // 2, 0, H) 18 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 19 | bby2 = np.clip(cy + cut_h // 2, 0, H) 20 | 21 | return bbx1, bby1, bbx2, bby2 22 | 23 | def cutmix_data(x, y, alpha=1.0, cutmix_prob=0.5): 24 | assert(alpha > 0) 25 | # generate mixed sample 26 | lam = np.random.beta(alpha, alpha) 27 | 28 | batch_size = x.size()[0] 29 | index = torch.randperm(batch_size) 30 | 31 | if torch.cuda.is_available(): 32 | index = index.cuda() 33 | 34 | y_a, y_b = y, y[index] 35 | bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam) 36 | x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2] 37 | 38 | # adjust lambda to exactly match pixel ratio 39 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2])) 40 | return x, y_a, y_b, lam 41 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets, transforms 3 | from utils.toolkit import split_images_labels 4 | 5 | 6 | class iData(object): 7 | train_trsf = [] 8 | test_trsf = [] 9 | common_trsf = [] 10 | class_order = None 11 | 12 | 13 | class iCIFAR10(iData): 14 | use_path = False 15 | train_trsf = [ 16 | transforms.RandomCrop(32, padding=4), 17 | transforms.RandomHorizontalFlip(p=0.5), 18 | transforms.ColorJitter(brightness=63/255) 19 | ] 20 | test_trsf = [] 21 | common_trsf = [ 22 | transforms.ToTensor(), 23 | transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)), 24 | ] 25 | 26 | class_order = np.arange(10).tolist() 27 | 28 | def download_data(self): 29 | train_dataset = datasets.cifar.CIFAR10('./data', train=True, download=True) 30 | test_dataset = 
datasets.cifar.CIFAR10('./data', train=False, download=True) 31 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets) 32 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets) 33 | 34 | 35 | class iCIFAR100(iData): 36 | use_path = False 37 | train_trsf = [ 38 | transforms.RandomCrop(32, padding=4), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ColorJitter(brightness=63/255) 41 | ] 42 | test_trsf = [] 43 | common_trsf = [ 44 | transforms.ToTensor(), 45 | transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)), 46 | ] 47 | 48 | class_order = np.arange(100).tolist() 49 | 50 | def download_data(self): 51 | train_dataset = datasets.cifar.CIFAR100('./data', train=True, download=True) 52 | test_dataset = datasets.cifar.CIFAR100('./data', train=False, download=True) 53 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets) 54 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets) 55 | 56 | class iCIFAR100_224(iCIFAR100): 57 | train_trsf = [ 58 | transforms.RandomResizedCrop(224, interpolation=3), 59 | transforms.RandomHorizontalFlip(), 60 | transforms.ColorJitter(brightness=63/255) 61 | ] 62 | test_trsf = [ 63 | transforms.Resize(256, interpolation=3), 64 | transforms.CenterCrop(224), 65 | ] 66 | 67 | class iImageNet1000(iData): 68 | use_path = True 69 | train_trsf = [ 70 | transforms.RandomResizedCrop(224), 71 | transforms.RandomHorizontalFlip(), 72 | transforms.ColorJitter(brightness=63/255) 73 | ] 74 | test_trsf = [ 75 | transforms.Resize(256), 76 | transforms.CenterCrop(224), 77 | ] 78 | common_trsf = [ 79 | transforms.ToTensor(), 80 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 81 | ] 82 | 83 | class_order = np.arange(1000).tolist() 84 | 85 | def download_data(self): 86 | assert 0,"You should specify the folder of your dataset" 87 | train_dir = 
'[DATA-PATH]/train/' 88 | test_dir = '[DATA-PATH]/val/' 89 | 90 | train_dset = datasets.ImageFolder(train_dir) 91 | test_dset = datasets.ImageFolder(test_dir) 92 | 93 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs) 94 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs) 95 | 96 | class iImageNet100(iData): 97 | use_path = True 98 | train_trsf = [ 99 | transforms.RandomResizedCrop(224), 100 | transforms.RandomHorizontalFlip(), 101 | ] 102 | test_trsf = [ 103 | transforms.Resize(256), 104 | transforms.CenterCrop(224), 105 | ] 106 | common_trsf = [ 107 | transforms.ToTensor(), 108 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 109 | ] 110 | 111 | class_order = np.arange(1000).tolist() 112 | 113 | def download_data(self): 114 | train_dir = 'data/imagenet100/train/' 115 | test_dir = 'data/imagenet100/val/' 116 | 117 | train_dset = datasets.ImageFolder(train_dir) 118 | test_dset = datasets.ImageFolder(test_dir) 119 | 120 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs) 121 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs) 122 | 123 | 124 | class iImageNetR(iData): 125 | use_path = True 126 | train_trsf = [ 127 | transforms.RandomResizedCrop(224, interpolation=3), 128 | transforms.RandomHorizontalFlip(), 129 | ] 130 | test_trsf = [ 131 | transforms.Resize(256, interpolation=3), 132 | transforms.CenterCrop(224), 133 | ] 134 | common_trsf = [ 135 | transforms.ToTensor(), 136 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 137 | ] 138 | 139 | class_order = np.arange(1000).tolist() 140 | 141 | def download_data(self): 142 | train_dir = 'data/imagenet-r/train/' 143 | test_dir = 'data/imagenet-r/val/' 144 | 145 | train_dset = datasets.ImageFolder(train_dir) 146 | test_dset = datasets.ImageFolder(test_dir) 147 | 148 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs) 149 | self.test_data, 
self.test_targets = split_images_labels(test_dset.imgs) 150 | 151 | class iCUB200_224(iData): 152 | use_path = True 153 | train_trsf = [ 154 | transforms.Resize((300, 300), interpolation=3), 155 | transforms.RandomCrop((224, 224)), 156 | transforms.RandomHorizontalFlip(), 157 | ] 158 | test_trsf = [ 159 | transforms.Resize(256, interpolation=3), 160 | transforms.CenterCrop(224), 161 | ] 162 | common_trsf = [ 163 | transforms.ToTensor(), 164 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 165 | ] 166 | class_order = np.arange(1000).tolist() 167 | 168 | def download_data(self): 169 | train_dir = 'data/cub_200/train/' 170 | test_dir = 'data/cub_200/val/' 171 | 172 | train_dset = datasets.ImageFolder(train_dir) 173 | test_dset = datasets.ImageFolder(test_dir) 174 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs) 175 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs) 176 | 177 | class iCARS196_224(iData): 178 | use_path = True 179 | train_trsf = [ 180 | transforms.Resize((300, 300), interpolation=3), 181 | transforms.RandomCrop((224, 224)), 182 | transforms.RandomHorizontalFlip(), 183 | ] 184 | test_trsf = [ 185 | transforms.Resize(256, interpolation=3), 186 | transforms.CenterCrop(224), 187 | ] 188 | common_trsf = [ 189 | transforms.ToTensor(), 190 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 191 | ] 192 | class_order = np.arange(1000).tolist() 193 | 194 | def download_data(self): 195 | train_dir = 'data/cars196/train/' 196 | test_dir = 'data/cars196/val/' 197 | 198 | train_dset = datasets.ImageFolder(train_dir) 199 | test_dset = datasets.ImageFolder(test_dir) 200 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs) 201 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs) 202 | 203 | 204 | class iResisc45_224(iData): 205 | use_path = True 206 | train_trsf = [ 207 | transforms.Resize((300, 300), interpolation=3), 
208 | transforms.RandomCrop((224, 224)), 209 | transforms.RandomHorizontalFlip(), 210 | ] 211 | test_trsf = [ 212 | transforms.Resize(256, interpolation=3), 213 | transforms.CenterCrop(224), 214 | ] 215 | common_trsf = [ 216 | transforms.ToTensor(), 217 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 218 | ] 219 | class_order = np.arange(1000).tolist() 220 | 221 | def download_data(self): 222 | train_dir = 'data/resisc45/train/' 223 | test_dir = 'data/resisc45/val/' 224 | 225 | train_dset = datasets.ImageFolder(train_dir) 226 | test_dset = datasets.ImageFolder(test_dir) 227 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs) 228 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs) 229 | 230 | 231 | 232 | class iSketch345_224(iData): 233 | use_path = True 234 | train_trsf = [ 235 | transforms.Resize((300, 300), interpolation=3), 236 | transforms.RandomCrop((224, 224)), 237 | transforms.RandomHorizontalFlip(), 238 | ] 239 | test_trsf = [ 240 | transforms.Resize(256, interpolation=3), 241 | transforms.CenterCrop(224), 242 | ] 243 | common_trsf = [ 244 | transforms.ToTensor(), 245 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 246 | ] 247 | class_order = np.arange(1000).tolist() 248 | 249 | def download_data(self): 250 | train_dir = 'data/sketch345/train/' 251 | test_dir = 'data/sketch345/val/' 252 | 253 | train_dset = datasets.ImageFolder(train_dir) 254 | test_dset = datasets.ImageFolder(test_dir) 255 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs) 256 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs) -------------------------------------------------------------------------------- /utils/data_manager.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | from 
torchvision import transforms 6 | from utils.data import iCIFAR10, iCIFAR100, iImageNet100, iImageNet1000, iCIFAR100_224, iImageNetR, iCUB200_224, iResisc45_224, iCARS196_224, iSketch345_224 7 | from copy import deepcopy 8 | import random 9 | 10 | class DataManager(object): 11 | def __init__(self, dataset_name, shuffle, seed, init_cls, increment): 12 | self.dataset_name = dataset_name 13 | self._setup_data(dataset_name, shuffle, seed) 14 | assert init_cls <= len(self._class_order), 'No enough classes.' 15 | self._increments = [init_cls] 16 | while sum(self._increments) + increment < len(self._class_order): 17 | self._increments.append(increment) 18 | offset = len(self._class_order) - sum(self._increments) 19 | if offset > 0: 20 | self._increments.append(offset) 21 | 22 | @property 23 | def nb_tasks(self): 24 | return len(self._increments) 25 | 26 | def get_task_size(self, task): 27 | return self._increments[task] 28 | 29 | def get_dataset(self, indices, source, mode, appendent=None, ret_data=False, with_raw=False, with_noise=False): 30 | if source == 'train': 31 | x, y = self._train_data, self._train_targets 32 | elif source == 'test': 33 | x, y = self._test_data, self._test_targets 34 | else: 35 | raise ValueError('Unknown data source {}.'.format(source)) 36 | 37 | if mode == 'train': 38 | trsf = transforms.Compose([*self._train_trsf, *self._common_trsf]) 39 | elif mode == 'flip': 40 | trsf = transforms.Compose([*self._test_trsf, transforms.RandomHorizontalFlip(p=1.), *self._common_trsf]) 41 | elif mode == 'test': 42 | trsf = transforms.Compose([*self._test_trsf, *self._common_trsf]) 43 | else: 44 | raise ValueError('Unknown mode {}.'.format(mode)) 45 | 46 | data, targets = [], [] 47 | for idx in indices: 48 | class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1) 49 | data.append(class_data) 50 | targets.append(class_targets) 51 | 52 | if appendent is not None and len(appendent) != 0: 53 | appendent_data, appendent_targets = appendent 54 
| data.append(appendent_data) 55 | targets.append(appendent_targets) 56 | 57 | data, targets = np.concatenate(data), np.concatenate(targets) 58 | 59 | if ret_data: 60 | return data, targets, DummyDataset(data, targets, trsf, self.use_path, with_raw, with_noise) 61 | else: 62 | return DummyDataset(data, targets, trsf, self.use_path, with_raw, with_noise) 63 | 64 | def get_dataset_with_split(self, indices, source, mode, appendent=None, val_samples_per_class=0): 65 | if source == 'train': 66 | x, y = self._train_data, self._train_targets 67 | elif source == 'test': 68 | x, y = self._test_data, self._test_targets 69 | else: 70 | raise ValueError('Unknown data source {}.'.format(source)) 71 | 72 | if mode == 'train': 73 | trsf = transforms.Compose([*self._train_trsf, *self._common_trsf]) 74 | elif mode == 'test': 75 | trsf = transforms.Compose([*self._test_trsf, *self._common_trsf]) 76 | else: 77 | raise ValueError('Unknown mode {}.'.format(mode)) 78 | 79 | train_data, train_targets = [], [] 80 | val_data, val_targets = [], [] 81 | for idx in indices: 82 | class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1) 83 | val_indx = np.random.choice(len(class_data), val_samples_per_class, replace=False) 84 | train_indx = list(set(np.arange(len(class_data))) - set(val_indx)) 85 | val_data.append(class_data[val_indx]) 86 | val_targets.append(class_targets[val_indx]) 87 | train_data.append(class_data[train_indx]) 88 | train_targets.append(class_targets[train_indx]) 89 | 90 | if appendent is not None: 91 | appendent_data, appendent_targets = appendent 92 | for idx in range(0, int(np.max(appendent_targets))+1): 93 | append_data, append_targets = self._select(appendent_data, appendent_targets, 94 | low_range=idx, high_range=idx+1) 95 | val_indx = np.random.choice(len(append_data), val_samples_per_class, replace=False) 96 | train_indx = list(set(np.arange(len(append_data))) - set(val_indx)) 97 | val_data.append(append_data[val_indx]) 98 | 
val_targets.append(append_targets[val_indx]) 99 | train_data.append(append_data[train_indx]) 100 | train_targets.append(append_targets[train_indx]) 101 | 102 | train_data, train_targets = np.concatenate(train_data), np.concatenate(train_targets) 103 | val_data, val_targets = np.concatenate(val_data), np.concatenate(val_targets) 104 | 105 | return DummyDataset(train_data, train_targets, trsf, self.use_path), \ 106 | DummyDataset(val_data, val_targets, trsf, self.use_path) 107 | 108 | def _setup_data(self, dataset_name, shuffle, seed): 109 | idata = _get_idata(dataset_name) 110 | idata.download_data() 111 | 112 | # Data 113 | self._train_data, self._train_targets = idata.train_data, idata.train_targets 114 | self._test_data, self._test_targets = idata.test_data, idata.test_targets 115 | self.use_path = idata.use_path 116 | 117 | # Transforms 118 | self._train_trsf = idata.train_trsf 119 | self._test_trsf = idata.test_trsf 120 | self._common_trsf = idata.common_trsf 121 | 122 | # Order 123 | order = [i for i in range(len(np.unique(self._train_targets)))] 124 | if shuffle: 125 | np.random.seed(seed) 126 | order = np.random.permutation(len(order)).tolist() 127 | else: 128 | order = idata.class_order 129 | self._class_order = order 130 | logging.info(self._class_order) 131 | 132 | # Map indices 133 | self._train_targets = _map_new_class_index(self._train_targets, self._class_order) 134 | self._test_targets = _map_new_class_index(self._test_targets, self._class_order) 135 | 136 | def _select(self, x, y, low_range, high_range): 137 | idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0] 138 | return x[idxes], y[idxes] 139 | 140 | 141 | class DummyDataset(Dataset): 142 | def __init__(self, images, labels, trsf, use_path=False, with_raw=False, with_noise=False): 143 | assert len(images) == len(labels), 'Data size error!' 
144 | self.images = images 145 | self.labels = labels 146 | self.trsf = trsf 147 | self.use_path = use_path 148 | self.with_raw = with_raw 149 | if use_path and with_raw: 150 | self.raw_trsf = transforms.Compose([transforms.Resize((500, 500)), transforms.ToTensor()]) 151 | else: 152 | self.raw_trsf = transforms.Compose([transforms.ToTensor()]) 153 | if with_noise: 154 | class_list = np.unique(self.labels) 155 | self.ori_labels = deepcopy(labels) 156 | for cls in class_list: 157 | random_target = class_list.tolist() 158 | random_target.remove(cls) 159 | tindx = [i for i, x in enumerate(self.ori_labels) if x == cls] 160 | for i in tindx[:round(len(tindx)*0.2)]: 161 | self.labels[i] = random.choice(random_target) 162 | 163 | 164 | def __len__(self): 165 | return len(self.images) 166 | 167 | def __getitem__(self, idx): 168 | if self.use_path: 169 | load_image = pil_loader(self.images[idx]) 170 | image = self.trsf(load_image) 171 | else: 172 | load_image = Image.fromarray(self.images[idx]) 173 | image = self.trsf(load_image) 174 | label = self.labels[idx] 175 | if self.with_raw: 176 | return idx, image, label, self.raw_trsf(load_image) 177 | return idx, image, label 178 | 179 | 180 | def _map_new_class_index(y, order): 181 | return np.array(list(map(lambda x: order.index(x), y))) 182 | 183 | 184 | def _get_idata(dataset_name): 185 | name = dataset_name.lower() 186 | if name == 'cifar10': 187 | return iCIFAR10() 188 | elif name == 'cifar100': 189 | return iCIFAR100() 190 | elif name == 'cifar100_224': 191 | return iCIFAR100_224() 192 | elif name == 'imagenet1000': 193 | return iImageNet1000() 194 | elif name == "imagenet100": 195 | return iImageNet100() 196 | elif name == "imagenet-r": 197 | return iImageNetR() 198 | elif name == 'cub200_224': 199 | return iCUB200_224() 200 | elif name == 'resisc45': 201 | return iResisc45_224() 202 | elif name == 'cars196_224': 203 | return iCARS196_224() 204 | elif name == 'sketch345_224': 205 | return iSketch345_224() 206 | else: 207 
| raise NotImplementedError('Unknown dataset {}.'.format(dataset_name)) 208 | 209 | 210 | def pil_loader(path): 211 | ''' 212 | Ref: 213 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder 214 | ''' 215 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 216 | with open(path, 'rb') as f: 217 | img = Image.open(f) 218 | return img.convert('RGB') 219 | 220 | 221 | def accimage_loader(path): 222 | ''' 223 | Ref: 224 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder 225 | accimage is an accelerated Image loader and preprocessor leveraging Intel IPP. 226 | accimage is available on conda-forge. 227 | ''' 228 | import accimage 229 | try: 230 | return accimage.Image(path) 231 | except IOError: 232 | # Potentially a decoding problem, fall back to PIL.Image 233 | return pil_loader(path) 234 | 235 | 236 | def default_loader(path): 237 | ''' 238 | Ref: 239 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder 240 | ''' 241 | from torchvision import get_image_backend 242 | if get_image_backend() == 'accimage': 243 | return accimage_loader(path) 244 | else: 245 | return pil_loader(path) 246 | -------------------------------------------------------------------------------- /utils/factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.slca import SLCA 3 | from models.slca_pp import SLCApp 4 | 5 | def get_model(model_name, args): 6 | name = model_name.lower() 7 | if 'slcapp' in name: 8 | return SLCApp(args) 9 | elif 'slca' in name: 10 | return SLCA(args) 11 | else: 12 | assert 0 13 | -------------------------------------------------------------------------------- /utils/inc_net.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch import nn 4 | from convs.cifar_resnet import 
resnet32 5 | from convs.resnet import resnet18, resnet34, resnet50 6 | from convs.linears import SimpleContinualLinear 7 | from convs.vits import vit_base_patch16_224_in21k, vit_base_patch16_224_mocov3, vit_base_lora_patch16_224_in21k, vit_base_lora_patch16_224_mocov3 8 | import torch.nn.functional as F 9 | 10 | def get_convnet(cfg, pretrained=False): 11 | name = cfg['convnet_type'] 12 | name = name.lower() 13 | if name == 'resnet32': 14 | return resnet32() 15 | elif name == 'resnet18': 16 | return resnet18(pretrained=pretrained) 17 | elif name == 'resnet18_cifar': 18 | return resnet18(pretrained=pretrained, cifar=True) 19 | elif name == 'resnet18_cifar_cos': 20 | return resnet18(pretrained=pretrained, cifar=True, no_last_relu=True) 21 | elif name == 'resnet34': 22 | return resnet34(pretrained=pretrained) 23 | elif name == 'resnet50': 24 | return resnet50(pretrained=pretrained) 25 | elif name == 'vit-b-p16': 26 | return vit_base_patch16_224_in21k(pretrained=True) 27 | elif name == 'vit-b-p16-mocov3': 28 | return vit_base_patch16_224_mocov3(pretrained=True) 29 | elif name == 'vit-b-p16-lora': 30 | return vit_base_lora_patch16_224_in21k(pretrained=True, lora_rank=cfg['lora_rank']) 31 | elif name == 'vit-b-p16-lora-mocov3': 32 | return vit_base_lora_patch16_224_mocov3(pretrained=True, lora_rank=cfg['lora_rank']) 33 | else: 34 | raise NotImplementedError('Unknown type {}'.format(name)) 35 | 36 | 37 | class BaseNet(nn.Module): 38 | 39 | def __init__(self, cfg, pretrained): 40 | super(BaseNet, self).__init__() 41 | 42 | self.convnet = get_convnet(cfg, pretrained) 43 | self.fc = None 44 | 45 | @property 46 | def feature_dim(self): 47 | return self.convnet.out_dim 48 | 49 | def extract_vector(self, x): 50 | return self.convnet(x)['features'] 51 | 52 | def forward(self, x): 53 | x = self.convnet(x) 54 | out = self.fc(x['features']) 55 | ''' 56 | { 57 | 'fmaps': [x_1, x_2, ..., x_n], 58 | 'features': features 59 | 'logits': logits 60 | } 61 | ''' 62 | out.update(x) 63 | 64 | 
return out 65 | 66 | def update_fc(self, nb_classes): 67 | pass 68 | 69 | def generate_fc(self, in_dim, out_dim): 70 | pass 71 | 72 | def copy(self): 73 | return copy.deepcopy(self) 74 | 75 | def freeze(self): 76 | for param in self.parameters(): 77 | param.requires_grad = False 78 | self.eval() 79 | 80 | return self 81 | 82 | class FinetuneIncrementalNet(BaseNet): 83 | 84 | def __init__(self, cfg, pretrained, fc_with_ln=False): 85 | super().__init__(cfg, pretrained) 86 | self.old_fc = None 87 | self.fc_with_ln = fc_with_ln 88 | if 'fc_scale_mu' in cfg.keys(): 89 | self.fc_scale_mu = cfg['fc_scale_mu'] 90 | else: 91 | self.fc_scale_mu = -1 92 | 93 | 94 | def extract_layerwise_vector(self, x, pool=True): 95 | with torch.no_grad(): 96 | features = self.convnet(x, layer_feat=True)['features'] 97 | for f_i in range(len(features)): 98 | if pool: 99 | features[f_i] = features[f_i].mean(1).cpu().numpy() 100 | else: 101 | features[f_i] = features[f_i][:, 0].cpu().numpy() 102 | return features 103 | 104 | 105 | def update_fc(self, nb_classes, freeze_old=True): 106 | if self.fc is None: 107 | self.fc = self.generate_fc(self.feature_dim, nb_classes) 108 | else: 109 | self.fc.update(nb_classes, freeze_old=freeze_old) 110 | 111 | def save_old_fc(self): 112 | if self.old_fc is None: 113 | self.old_fc = copy.deepcopy(self.fc) 114 | else: 115 | self.old_fc.heads.append(copy.deepcopy(self.fc.heads[-1])) 116 | 117 | def generate_fc(self, in_dim, out_dim): 118 | fc = SimpleContinualLinear(in_dim, out_dim, scale_mu=self.fc_scale_mu) 119 | 120 | return fc 121 | 122 | def forward(self, x, bcb_no_grad=False, fc_only=False): 123 | if fc_only: 124 | fc_out = self.fc(x) 125 | if self.old_fc is not None: 126 | old_fc_logits = self.old_fc(x)['logits'] 127 | fc_out['old_logits'] = old_fc_logits 128 | return fc_out 129 | if bcb_no_grad: 130 | with torch.no_grad(): 131 | x = self.convnet(x) 132 | else: 133 | x = self.convnet(x) 134 | out = self.fc(x['features']) 135 | out.update(x) 136 | 137 | 
return out 138 | 139 | 140 | -------------------------------------------------------------------------------- /utils/net_linear_wapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LinearWapper(nn.Module): 4 | def __init__(self, model): 5 | super(LinearWapper, self).__init__() 6 | self.reset_parameters() 7 | 8 | def reset_parameters(self): 9 | for m in self.modules(): 10 | nn.init.kaiming_uniform_(m.weight, nonlinearity='linear') 11 | nn.init.constant_(m.bias, 0) 12 | 13 | def forward(self, input): 14 | return {'logits': F.linear(input, self.weight, self.bias)} 15 | -------------------------------------------------------------------------------- /utils/toolkit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def count_parameters(model, trainable=False): 7 | if trainable: 8 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 9 | return sum(p.numel() for p in model.parameters()) 10 | 11 | 12 | def tensor2numpy(x): 13 | return x.cpu().data.numpy() if x.is_cuda else x.data.numpy() 14 | 15 | 16 | def target2onehot(targets, n_classes): 17 | onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device) 18 | onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.) 19 | return onehot 20 | 21 | 22 | def makedirs(path): 23 | if not os.path.exists(path): 24 | os.makedirs(path) 25 | 26 | 27 | def accuracy(y_pred, y_true, nb_old, increment=10): 28 | assert len(y_pred) == len(y_true), 'Data length error.' 
29 | all_acc = {} 30 | all_acc['total'] = np.around((y_pred == y_true).sum()*100 / len(y_true), decimals=2) 31 | 32 | # Grouped accuracy 33 | for class_id in range(0, np.max(y_true), increment): 34 | idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + increment))[0] 35 | label = '{}-{}'.format(str(class_id).rjust(2, '0'), str(class_id+increment-1).rjust(2, '0')) 36 | all_acc[label] = np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), decimals=2) 37 | 38 | # Old accuracy 39 | idxes = np.where(y_true < nb_old)[0] 40 | all_acc['old'] = 0 if len(idxes) == 0 else np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), 41 | decimals=2) 42 | 43 | # New accuracy 44 | idxes = np.where(y_true >= nb_old)[0] 45 | all_acc['new'] = np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), decimals=2) 46 | 47 | return all_acc 48 | 49 | 50 | def split_images_labels(imgs): 51 | # split trainset.imgs in ImageFolder 52 | images = [] 53 | labels = [] 54 | for item in imgs: 55 | images.append(item[0]) 56 | labels.append(item[1]) 57 | 58 | return np.array(images), np.array(labels) 59 | --------------------------------------------------------------------------------