├── .gitignore
├── continual_datasets
│   ├── __pycache__
│   │   ├── dataset_utils.cpython-310.pyc
│   │   ├── dataset_utils.cpython-38.pyc
│   │   ├── continual_datasets.cpython-310.pyc
│   │   └── continual_datasets.cpython-38.pyc
│   ├── dataset_utils.py
│   └── continual_datasets.py
├── README.md
├── models.py
├── requirements.txt
├── prompt.py
├── utils.py
├── datasets.py
├── LICENSE
├── main.py
├── engine.py
└── vision_transformer.py

/.gitignore:
--------------------------------------------------------------------------------
1 | log/
2 | output
3 | *.sh
4 | .vscode/
5 | 
--------------------------------------------------------------------------------
/continual_datasets/__pycache__/dataset_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KU-VGI/VIL/HEAD/continual_datasets/__pycache__/dataset_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/continual_datasets/__pycache__/dataset_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KU-VGI/VIL/HEAD/continual_datasets/__pycache__/dataset_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/continual_datasets/__pycache__/continual_datasets.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KU-VGI/VIL/HEAD/continual_datasets/__pycache__/continual_datasets.cpython-310.pyc
--------------------------------------------------------------------------------
/continual_datasets/__pycache__/continual_datasets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KU-VGI/VIL/HEAD/continual_datasets/__pycache__/continual_datasets.cpython-38.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ECCV 2024] Versatile Incremental Learning: Towards Class and Domain-Agnostic Incremental Learning
2 | 
3 | Official PyTorch implementation for ECCV 2024 paper:
4 | 
5 | **Versatile Incremental Learning: Towards Class and Domain-Agnostic Incremental Learning**
6 | [Min-Yeong Park](https://github.com/pmy0792)\*, [Jaeho Lee](https://github.com/JH-LEE-KR)\*, and Gyeong-Moon Park†
7 | 
8 | [![arXiv](https://img.shields.io/badge/arXiv-2409.10956-b31b1b.svg)](https://arxiv.org/abs/2409.10956)
9 | 
10 | 
11 | # Environment
12 | - Python 3.8.x
13 | - PyTorch 1.12.1
14 | - Torchvision 0.13.1
15 | - NVIDIA GeForce RTX 3090
16 | - CUDA 11.3
17 | 
18 | 
19 | # Getting Started
20 | ## Environment
21 | ```bash
22 | git clone git@github.com:KHU-AGI/VIL.git
23 | cd VIL
24 | conda create -n VIL python==3.8
25 | conda activate VIL
26 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
27 | pip install -r requirements.txt
28 | ```
29 | 
30 | ## Run ICON on VIL with iDigits dataset
31 | ```bash
32 | python main.py --dataset iDigits --num_tasks 5 --seed 42 --versatile_inc --batch-size 24 --IC --thre 0.0 --beta 0.01 --use_cast_loss --k 2 --d_threshold
33 | ```
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------
2 | 
# Copyright (c) 2015-present, Facebook, Inc. 3 | # All rights reserved. 4 | # ------------------------------------------ 5 | # Modification: 6 | # Added code for Joint-FT 7 | # -- Jaeho Lee, dlwogh9344@khu.ac.kr 8 | # ------------------------------------------ 9 | import torch.nn as nn 10 | from timm.models._registry import register_model 11 | from vision_transformer import _create_vision_transformer 12 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 13 | 14 | __all__ = [ 15 | 'vit_base_patch16_224', 16 | 'vit_base_patch16_clip_224.openai', 17 | 'vit_b16_in21k' 18 | ] 19 | 20 | def _cfg(url='', **kwargs): 21 | return { 22 | 'url': url, 23 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 24 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 25 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 26 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 27 | **kwargs 28 | } 29 | 30 | @register_model 31 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs): 32 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 33 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 34 | """ 35 | default_cfg = _cfg( 36 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz', 37 | custom_load=True,) 38 | kwargs.update(pretrained_cfg=default_cfg) 39 | 40 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 41 | 42 | model = _create_vision_transformer( 43 | 'vit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 44 | return model 45 | 46 | @register_model 47 | def vit_base_patch16_224(pretrained=False, **kwargs): 48 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 49 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 
50 | """ 51 | default_cfg = _cfg( 52 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz', 53 | custom_load=True,) 54 | kwargs.update(pretrained_cfg=default_cfg) 55 | 56 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 57 | 58 | model = _create_vision_transformer( 59 | 'vit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 60 | return model 61 | 62 | @register_model 63 | def vit_base_patch16_clip_224_openai(pretrained=False, **kwargs): 64 | """ ViT-B/16 CLIP image tower, OpenAI original weights 65 | """ 66 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) 67 | 68 | model = _create_vision_transformer( 69 | 'vit_base_patch16_clip_224.openai', pretrained=pretrained, **dict(model_args, **kwargs)) 70 | return model -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | arch==5.6.0 3 | asttokens==2.4.1 4 | astunparse==1.6.3 5 | backcall==0.2.0 6 | beautifulsoup4==4.12.3 7 | brotlipy==0.7.0 8 | cachetools==5.3.2 9 | click==8.1.7 10 | cloudpickle==3.0.0 11 | cmake==3.27.7 12 | cycler==0.11.0 13 | dask==2023.5.0 14 | decorator==5.1.1 15 | distributed==2023.5.0 16 | einops==0.7.0 17 | executing==2.0.1 18 | filelock==3.12.4 19 | finch-clust==0.2.0 20 | flatbuffers==23.5.26 21 | fonttools==4.42.1 22 | frozendict==2.3.10 23 | fsspec==2023.9.1 24 | ftfy==6.1.3 25 | gast==0.4.0 26 | gdown==5.1.0 27 | google-auth==2.25.2 28 | google-auth-oauthlib==1.0.0 29 | google-pasta==0.2.0 30 | graphviz==0.20.1 31 | grpcio==1.60.0 32 | h5py==3.10.0 33 | huggingface-hub==0.17.1 34 | hurst==0.0.5 35 | idx2numpy==1.2.3 36 | importlib-metadata==6.8.0 37 | ipdb==0.13.13 38 | ipython==8.12.3 39 | jedi==0.19.1 40 | Jinja2==3.1.2 41 | joblib==1.3.2 42 | keras==2.13.1 43 | kiwisolver==1.4.5 44 | lazy-object-proxy==1.10.0 45 | libclang==16.0.6 46 | lit==17.0.4 47 | llvmlite==0.41.1 48 | locket==1.0.0 49 | loralib==0.1.2 50 | Markdown==3.5.1 51 | MarkupSafe==2.1.3 52 | matplotlib==3.5.3 53 | matplotlib-inline==0.1.6 54 | mkl-fft==1.3.6 55 | mkl-service==2.4.0 56 | more-itertools==10.1.0 57 | msgpack==1.0.7 58 | natsort==8.4.0 59 | numba==0.58.1 60 | numpy==1.24.3 61 | oauthlib==3.2.2 62 | openai-whisper==20230918 63 | opt-einsum==3.3.0 64 | packaging==23.1 65 | pandas==2.0.3 66 | parso==0.8.3 67 | partd==1.4.1 68 | patsy==0.5.4 69 | pexpect==4.9.0 70 | pickleshare==0.7.5 71 | Pillow==9.2.0 72 | prompt-toolkit==3.0.43 73 | property-cached==1.6.4 74 | protobuf==4.25.1 75 | psutil==5.9.6 76 | ptyprocess==0.7.0 77 | pure-eval==0.2.2 78 | pyasn1==0.5.1 79 | pyasn1-modules==0.3.0 80 | Pygments==2.17.2 81 | pyparsing==3.1.1 82 | python-dateutil==2.8.2 83 | pytz==2023.3.post1 84 | PyYAML==6.0.1 85 | regex==2023.10.3 86 | requests-oauthlib==1.3.1 87 | rsa==4.9 88 | safetensors==0.3.3 89 | scikit-learn==1.3.1 90 | scipy==1.10.1 91 | seaborn==0.13.1 92 | semver==3.0.2 93 | six==1.16.0 94 | sortedcontainers==2.4.0 95 | soupsieve==2.5 96 | stack-data==0.6.3 97 | statsmodels==0.14.0 98 | stumpy==1.12.0 99 | tblib==3.0.0 100 | tensorboard==2.13.0 101 | tensorboard-data-server==0.7.2 102 | tensorflow==2.13.1 103 | tensorflow-estimator==2.13.0 104 | tensorflow-io-gcs-filesystem==0.34.0 105 | termcolor==2.4.0 106 | TFSnippet @ git+https://github.com/haowen-xu/tfsnippet.git@63adaf04d2ffff8dec299623627d55d4bacac598 107 | threadpoolctl==3.2.0 108 | 
tiktoken==0.3.3 109 | timm==0.9.7 110 | tokenizers==0.14.1 111 | tomli==2.0.1 112 | toolz==0.12.0 113 | # torch==1.12.1 114 | # torchvision==0.13.1 115 | torchviz==0.0.2 116 | tornado==6.4 117 | tqdm==4.66.1 118 | traitlets==5.14.1 119 | transformers==4.34.0 120 | triton==2.0.0 121 | tsfresh==0.20.1 122 | typing_extensions==4.5.0 123 | tzdata==2023.3 124 | wcwidth==0.2.13 125 | Werkzeug==3.0.1 126 | wrapt==1.16.0 127 | zhusuan @ git+https://github.com/thu-ml/zhusuan.git@4386b2a12ae4f4ed8e694e504e51d7dcdfd6f22a 128 | zict==3.0.0 129 | zipp==3.17.0 130 | -------------------------------------------------------------------------------- /prompt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Prompt(nn.Module): 5 | def __init__(self, length=5, embed_dim=768, embedding_key='mean', prompt_init='uniform', prompt_pool=False, 6 | prompt_key=False, pool_size=None, top_k=None, batchwise_prompt=False, prompt_key_init='uniform',): 7 | super().__init__() 8 | 9 | self.length = length 10 | self.embed_dim = embed_dim 11 | self.prompt_pool = prompt_pool 12 | self.embedding_key = embedding_key 13 | self.prompt_init = prompt_init 14 | self.prompt_key = prompt_key 15 | self.pool_size = pool_size 16 | self.top_k = top_k 17 | self.batchwise_prompt = batchwise_prompt 18 | 19 | if self.prompt_pool: 20 | prompt_pool_shape = (pool_size, length, embed_dim) 21 | if prompt_init == 'zero': 22 | self.prompt = nn.Parameter(torch.zeros(prompt_pool_shape)) 23 | elif prompt_init == 'uniform': 24 | self.prompt = nn.Parameter(torch.randn(prompt_pool_shape)) 25 | nn.init.uniform_(self.prompt, -1, 1) 26 | 27 | # if using learnable prompt keys 28 | if prompt_key: 29 | key_shape = (pool_size, embed_dim) 30 | if prompt_key_init == 'zero': 31 | self.prompt_key = nn.Parameter(torch.zeros(key_shape)) 32 | elif prompt_key_init == 'uniform': 33 | self.prompt_key = nn.Parameter(torch.randn(key_shape)) 34 | nn.init.uniform_(self.prompt_key, -1, 1) 35 | else: 36 | # else use mean of prompt as key 37 | # only compatible with prompt, not prefix 38 | prompt_mean = torch.mean(self.prompt, dim=1) 39 | self.prompt_key = prompt_mean 40 | 41 | def l2_normalize(self, x, dim=None, epsilon=1e-12): 42 | """Normalizes a given vector or matrix.""" 43 | square_sum = torch.sum(x ** 2, dim=dim, keepdim=True) 44 | x_inv_norm = torch.rsqrt(torch.maximum(square_sum, torch.tensor(epsilon, device=x.device))) 45 | return x * x_inv_norm 46 | 47 | def forward(self, x_embed, prompt_mask=None, cls_features=None): 48 | out = dict() 49 | if self.prompt_pool: 50 | if self.embedding_key == 'mean': 51 | x_embed_mean = torch.mean(x_embed, dim=1) 52 | elif self.embedding_key == 'max': 53 | x_embed_mean = torch.max(x_embed, dim=1)[0] 54 | elif self.embedding_key == 'mean_max': 55 | x_embed_mean = torch.max(x_embed, dim=1)[0] + 2 * torch.mean(x_embed, dim=1) 56 | elif self.embedding_key == 'cls': 57 | if cls_features is None: 58 | x_embed_mean = torch.max(x_embed, dim=1)[0] # B, C 59 | else: 60 | x_embed_mean = cls_features 61 | else: 62 | raise NotImplementedError("Not supported way of calculating embedding keys!") 63 | 64 | prompt_norm = self.l2_normalize(self.prompt_key, dim=1) # Pool_size, C 65 | x_embed_norm = self.l2_normalize(x_embed_mean, dim=1) # B, C 66 | 67 | similarity = torch.matmul(x_embed_norm, prompt_norm.t()) # B, Pool_size 68 | 69 | if prompt_mask is None: 70 | _, idx = torch.topk(similarity, k=self.top_k, dim=1) # B, top_k 71 | if self.batchwise_prompt: 72 | 
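# Illustrative note on batch-wise prompt selection (example values are hypothetical):
# instead of every sample keeping its own top-k indices, the ids picked across the whole
# batch are tallied and the k most frequent ids are shared by all samples. For instance,
# with pool_size=4, top_k=2 and per-sample picks idx = [[0, 1], [1, 2], [1, 3]], id 1 is
# chosen three times and ids 0, 2, 3 once each, so the batch-wide selection becomes
# something like (1, 0) and is expanded to every sample below.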
prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True) 73 | # In jnp.unique, when the 'size' is specified and there are fewer than the indicated number of elements, 74 | # the remaining elements will be filled with 'fill_value', the default is the minimum value along the specified dimension. 75 | # Unless dimension is specified, this will be flattend if it is not already 1D. 76 | if prompt_id.shape[0] < self.pool_size: 77 | prompt_id = torch.cat([prompt_id, torch.full((self.pool_size - prompt_id.shape[0],), torch.min(idx.flatten()), device=prompt_id.device)]) 78 | id_counts = torch.cat([id_counts, torch.full((self.pool_size - id_counts.shape[0],), 0, device=id_counts.device)]) 79 | _, major_idx = torch.topk(id_counts, k=self.top_k) # top_k 80 | major_prompt_id = prompt_id[major_idx] # top_k 81 | # expand to batch 82 | idx = major_prompt_id.expand(x_embed.shape[0], -1) # B, top_k 83 | else: 84 | idx = prompt_mask # B, top_k 85 | 86 | batched_prompt_raw = self.prompt[idx] # B, top_k, length, C 87 | batch_size, top_k, length, c = batched_prompt_raw.shape 88 | batched_prompt = batched_prompt_raw.reshape(batch_size, top_k * length, c) # B, top_k * length, C 89 | 90 | out['prompt_idx'] = idx 91 | 92 | # Debugging, return sim as well 93 | out['prompt_norm'] = prompt_norm 94 | out['x_embed_norm'] = x_embed_norm 95 | out['similarity'] = similarity 96 | 97 | # Put pull_constraint loss calculation inside 98 | batched_key_norm = prompt_norm[idx] # B, top_k, C 99 | out['selected_key'] = batched_key_norm 100 | x_embed_norm = x_embed_norm.unsqueeze(1) # B, 1, C 101 | sim = batched_key_norm * x_embed_norm # B, top_k, C 102 | reduce_sim = torch.sum(sim) / x_embed.shape[0] # Scalar 103 | 104 | out['reduce_sim'] = reduce_sim 105 | else: 106 | if self.prompt_init == 'zero': 107 | self.prompt = nn.Parameter(torch.zeros(self.length, self.embed_dim)) 108 | elif self.prompt_init == 'uniform': 109 | self.prompt = nn.Parameter(torch.randn(self.length, self.embed_dim)) 110 | nn.init.uniform_(self.prompt) 111 | batched_prompt = self.prompt.unsqueeze(0).expand(x_embed.shape[0], -1, -1) 112 | 113 | # The input with the prompt concatenated to the front. [B, prompt+token, C] 114 | out['total_prompt_len'] = batched_prompt.shape[1] 115 | out['prompted_embedding'] = torch.cat([batched_prompt, x_embed], dim=1) 116 | 117 | return out 118 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # Copyright (c) 2015-present, Facebook, Inc. 3 | # All rights reserved. 4 | # ------------------------------------------ 5 | # Modification: 6 | # Added code for l2p implementation 7 | # -- Jaeho Lee, dlwogh9344@khu.ac.kr 8 | # ------------------------------------------ 9 | """ 10 | Misc functions, including distributed helpers. 11 | 12 | Mostly copy-paste from torchvision references. 13 | """ 14 | import io 15 | import os 16 | import time 17 | import math 18 | from collections import defaultdict, deque 19 | import datetime 20 | 21 | import torch 22 | import torch.distributed as dist 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 
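    For example, with window_size=3 and successive update() calls of 1, 2, 3 and 4,
    the median/avg properties look only at the last three values (2, 3, 4), while
    global_avg averages all four updates.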
27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if isinstance(v, torch.Tensor): 94 | v = v.item() 95 | assert isinstance(v, (float, int)) 96 | self.meters[k].update(v) 97 | 98 | def __getattr__(self, attr): 99 | if attr in self.meters: 100 | return self.meters[attr] 101 | if attr in self.__dict__: 102 | return self.__dict__[attr] 103 | raise AttributeError("'{}' object has no attribute '{}'".format( 104 | type(self).__name__, attr)) 105 | 106 | def __str__(self): 107 | loss_str = [] 108 | for name, meter in self.meters.items(): 109 | loss_str.append( 110 | "{}: {}".format(name, str(meter)) 111 | ) 112 | return self.delimiter.join(loss_str) 113 | 114 | def synchronize_between_processes(self): 115 | for meter in self.meters.values(): 116 | meter.synchronize_between_processes() 117 | 118 | def add_meter(self, name, meter): 119 | self.meters[name] = meter 120 | 121 | def log_every(self, iterable, print_freq, header=None): 122 | i = 0 123 | if not header: 124 | header = '' 125 | start_time = time.time() 126 | end = time.time() 127 | iter_time = SmoothedValue(fmt='{avg:.4f}') 128 | data_time = SmoothedValue(fmt='{avg:.4f}') 129 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 130 | log_msg = [ 131 | header, 132 | '[{0' + space_fmt + '}/{1}]', 133 | 'eta: {eta}', 134 | '{meters}', 135 | 'time: {time}', 136 | 'data: {data}' 137 | ] 138 | if torch.cuda.is_available(): 139 | log_msg.append('max mem: {memory:.0f}') 140 | log_msg = self.delimiter.join(log_msg) 141 | MB = 1024.0 * 1024.0 142 | for obj in iterable: 143 | data_time.update(time.time() - end) 144 | yield obj 145 | iter_time.update(time.time() - end) 146 | if i % print_freq == 0 or i == len(iterable) - 1: 147 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 148 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 149 | if torch.cuda.is_available(): 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time), 
154 | memory=torch.cuda.max_memory_allocated() / MB)) 155 | else: 156 | print(log_msg.format( 157 | i, len(iterable), eta=eta_string, 158 | meters=str(self), 159 | time=str(iter_time), data=str(data_time))) 160 | i += 1 161 | end = time.time() 162 | total_time = time.time() - start_time 163 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 164 | print('{} Total time: {} ({:.4f} s / it)'.format( 165 | header, total_time_str, total_time / len(iterable))) 166 | 167 | 168 | def _load_checkpoint_for_ema(model_ema, checkpoint): 169 | """ 170 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 171 | """ 172 | mem_file = io.BytesIO() 173 | torch.save({'state_dict_ema':checkpoint}, mem_file) 174 | mem_file.seek(0) 175 | model_ema._load_checkpoint(mem_file) 176 | 177 | 178 | def setup_for_distributed(is_master): 179 | """ 180 | This function disables printing when not in master process 181 | """ 182 | import builtins as __builtin__ 183 | builtin_print = __builtin__.print 184 | 185 | def print(*args, **kwargs): 186 | force = kwargs.pop('force', False) 187 | if is_master or force: 188 | builtin_print(*args, **kwargs) 189 | 190 | __builtin__.print = print 191 | 192 | 193 | def is_dist_avail_and_initialized(): 194 | if not dist.is_available(): 195 | return False 196 | if not dist.is_initialized(): 197 | return False 198 | return True 199 | 200 | 201 | def get_world_size(): 202 | if not is_dist_avail_and_initialized(): 203 | return 1 204 | return dist.get_world_size() 205 | 206 | 207 | def get_rank(): 208 | if not is_dist_avail_and_initialized(): 209 | return 0 210 | return dist.get_rank() 211 | 212 | 213 | def is_main_process(): 214 | return get_rank() == 0 215 | 216 | 217 | def save_on_master(*args, **kwargs): 218 | if is_main_process(): 219 | torch.save(*args, **kwargs) 220 | 221 | 222 | def init_distributed_mode(args): 223 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 224 | args.rank = int(os.environ["RANK"]) 225 | args.world_size = int(os.environ['WORLD_SIZE']) 226 | args.gpu = int(os.environ['LOCAL_RANK']) 227 | elif 'SLURM_PROCID' in os.environ: 228 | args.rank = int(os.environ['SLURM_PROCID']) 229 | args.gpu = args.rank % torch.cuda.device_count() 230 | else: 231 | print('Not using distributed mode') 232 | args.distributed = False 233 | return 234 | 235 | args.distributed = True 236 | 237 | torch.cuda.set_device(args.gpu) 238 | args.dist_backend = 'nccl' 239 | print('| distributed init (rank {}): {}'.format( 240 | args.rank, args.dist_url), flush=True) 241 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 242 | world_size=args.world_size, rank=args.rank) 243 | torch.distributed.barrier() 244 | setup_for_distributed(args.rank == 0) 245 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torch.utils.data.dataset import Subset 5 | from torchvision import datasets, transforms 6 | 7 | from timm.data import create_transform 8 | 9 | from continual_datasets.continual_datasets import * 10 | 11 | import utils 12 | 13 | class Lambda(transforms.Lambda): 14 | def __init__(self, lambd, nb_classes): 15 | super().__init__(lambd) 16 | self.nb_classes = nb_classes 17 | 18 | def __call__(self, img): 19 | return self.lambd(img, self.nb_classes) 20 | 21 | def target_transform(x, nb_classes): 22 | return x + nb_classes 23 | 24 | def 
build_continual_dataloader(args): 25 | dataloader = list() 26 | class_mask = list() if args.task_inc or args.train_mask else None 27 | 28 | transform_train = build_transform(True, args) 29 | transform_val = build_transform(False, args) 30 | 31 | if args.task_inc: 32 | mode = 'til' 33 | elif args.domain_inc: 34 | mode = 'dil' 35 | elif args.versatile_inc: 36 | mode = 'vil' 37 | elif args.joint_train: 38 | mode = 'joint' 39 | else: 40 | mode = 'cil' 41 | 42 | if mode in ['til', 'cil']: 43 | if 'iDigits' in args.dataset: 44 | dataset_list = ['MNIST', 'SVHN', 'MNISTM', 'SynDigit'] 45 | train, val = list(), list() 46 | mask = list() 47 | for i, dataset in enumerate(dataset_list): 48 | dataset_train, dataset_val = get_dataset( 49 | dataset=dataset, 50 | transform_train=transform_train, 51 | transform_val=transform_val, 52 | mode=mode, 53 | args=args, 54 | ) 55 | 56 | splited_dataset, class_mask = split_single_dataset(dataset_train, dataset_val, args) 57 | mask.append(class_mask) 58 | 59 | for i in range(len(splited_dataset)): 60 | train.append(splited_dataset[i][0]) 61 | val.append(splited_dataset[i][1]) 62 | 63 | splited_dataset = list() 64 | for i in range(args.num_tasks): 65 | t = [train[i+args.num_tasks*j] for j in range(len(dataset_list))] 66 | v = [val[i+args.num_tasks*j] for j in range(len(dataset_list))] 67 | splited_dataset.append((torch.utils.data.ConcatDataset(t), torch.utils.data.ConcatDataset(v))) 68 | 69 | args.nb_classes = len(splited_dataset[0][1].datasets[0].dataset.classes) 70 | class_mask = np.unique(np.array(mask), axis=0).tolist()[0] 71 | 72 | else: 73 | dataset_train, dataset_val = get_dataset( 74 | dataset=args.dataset, 75 | transform_train=transform_train, 76 | transform_val=transform_val, 77 | mode=mode, 78 | args=args, 79 | ) 80 | 81 | splited_dataset, class_mask = split_single_dataset(dataset_train, dataset_val, args) 82 | args.nb_classes = len(dataset_val.classes) 83 | 84 | elif mode in ['dil', 'vil']: 85 | if 'iDigits' in args.dataset: 86 | dataset_list = ['MNIST', 'SVHN', 'MNISTM', 'SynDigit'] 87 | splited_dataset = list() 88 | 89 | for i in range(len(dataset_list)): 90 | dataset_train, dataset_val = get_dataset( 91 | dataset=dataset_list[i], 92 | transform_train=transform_train, 93 | transform_val=transform_val, 94 | mode=mode, 95 | args=args, 96 | ) 97 | splited_dataset.append((dataset_train, dataset_val)) 98 | 99 | args.nb_classes = len(dataset_val.classes) 100 | 101 | else: 102 | dataset_train, dataset_val = get_dataset( 103 | dataset=args.dataset, 104 | transform_train=transform_train, 105 | transform_val=transform_val, 106 | mode=mode, 107 | args=args, 108 | ) 109 | 110 | if args.dataset in ['CORe50']: 111 | splited_dataset = [(dataset_train[i], dataset_val) for i in range(len(dataset_train))] 112 | args.nb_classes = len(dataset_val.classes) 113 | else: 114 | splited_dataset = [(dataset_train[i], dataset_val[i]) for i in range(len(dataset_train))] 115 | args.nb_classes = len(dataset_val[0].classes) 116 | 117 | elif mode in ['joint']: 118 | if 'iDigits' in args.dataset: 119 | dataset_list = ['MNIST', 'SVHN', 'MNISTM', 'SynDigit'] 120 | train, val = list(), list() 121 | mask = list() 122 | for i, dataset in enumerate(dataset_list): 123 | dataset_train, dataset_val = get_dataset( 124 | dataset=dataset, 125 | transform_train=transform_train, 126 | transform_val=transform_val, 127 | mode=mode, 128 | args=args, 129 | ) 130 | train.append(dataset_train) 131 | val.append(dataset_val) 132 | args.nb_classes = len(dataset_val.classes) 133 | 134 | dataset_train = 
torch.utils.data.ConcatDataset(train) 135 | dataset_val = torch.utils.data.ConcatDataset(val) 136 | splited_dataset = [(dataset_train, dataset_val)] 137 | 138 | class_mask = None 139 | 140 | else: 141 | dataset_train, dataset_val = get_dataset( 142 | dataset=args.dataset, 143 | transform_train=transform_train, 144 | transform_val=transform_val, 145 | mode=mode, 146 | args=args, 147 | ) 148 | 149 | splited_dataset = [(dataset_train, dataset_val)] 150 | 151 | args.nb_classes = len(dataset_val.classes) 152 | class_mask = None 153 | 154 | else: 155 | raise ValueError(f'Invalid mode: {mode}') 156 | 157 | 158 | if args.versatile_inc: 159 | splited_dataset, class_mask, domain_list, args = build_vil_scenario(splited_dataset, args) 160 | for c, d in zip(class_mask, domain_list): 161 | print(c, d) 162 | for i in range(len(splited_dataset)): 163 | dataset_train, dataset_val = splited_dataset[i] 164 | 165 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 166 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 167 | 168 | data_loader_train = torch.utils.data.DataLoader( 169 | dataset_train, sampler=sampler_train, 170 | batch_size=args.batch_size, 171 | num_workers=args.num_workers, 172 | pin_memory=args.pin_mem, 173 | ) 174 | 175 | data_loader_val = torch.utils.data.DataLoader( 176 | dataset_val, sampler=sampler_val, 177 | batch_size=args.batch_size, 178 | num_workers=args.num_workers, 179 | pin_memory=args.pin_mem, 180 | ) 181 | 182 | dataloader.append({'train': data_loader_train, 'val': data_loader_val}) 183 | 184 | return dataloader, class_mask, domain_list 185 | 186 | def get_dataset(dataset, transform_train, transform_val, mode, args,): 187 | if dataset == 'MNIST': 188 | dataset_train = MNIST_RGB(args.data_path, train=True, download=True, transform=transform_train) 189 | dataset_val = MNIST_RGB(args.data_path, train=False, download=True, transform=transform_val) 190 | 191 | elif dataset == 'SVHN': 192 | dataset_train = SVHN(args.data_path, split='train', download=True, transform=transform_train) 193 | dataset_val = SVHN(args.data_path, split='test', download=True, transform=transform_val) 194 | 195 | elif dataset == 'CORe50': 196 | dataset_train = CORe50(args.data_path, train=True, download=True, transform=transform_train, mode=mode).data 197 | dataset_val = CORe50(args.data_path, train=False, download=True, transform=transform_val, mode=mode).data 198 | 199 | elif dataset == 'DomainNet': 200 | dataset_train = DomainNet(args.data_path, train=True, download=True, transform=transform_train, mode=mode).data 201 | dataset_val = DomainNet(args.data_path, train=False, download=True, transform=transform_val, mode=mode).data 202 | 203 | elif dataset == 'MNISTM': 204 | dataset_train = MNISTM(args.data_path, train=True, download=True, transform=transform_train) 205 | dataset_val = MNISTM(args.data_path, train=False, download=True, transform=transform_val) 206 | 207 | elif dataset == 'SynDigit': 208 | dataset_train = SynDigit(args.data_path, train=True, download=True, transform=transform_train) 209 | dataset_val = SynDigit(args.data_path, train=False, download=True, transform=transform_val) 210 | 211 | else: 212 | raise ValueError('Dataset {} not found.'.format(dataset)) 213 | 214 | return dataset_train, dataset_val 215 | 216 | def split_single_dataset(dataset_train, dataset_val, args): 217 | nb_classes = len(dataset_val.classes) 218 | assert nb_classes % args.num_tasks == 0 219 | classes_per_task = nb_classes // args.num_tasks 220 | 221 | labels = [i for i in range(nb_classes)] 222 
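# Illustrative note on the class-to-task split (example values taken from the README run):
# the label list is cut into args.num_tasks contiguous chunks of classes_per_task labels
# (shuffled first only when --shuffle is set), and each chunk becomes one task. With the
# iDigits command above (10 classes, --num_tasks 5, no --shuffle) the resulting class mask
# is [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], and every task keeps only the train/val
# indices whose targets fall inside its own chunk.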
| 223 | split_datasets = list() 224 | mask = list() 225 | 226 | if args.shuffle: 227 | random.shuffle(labels) 228 | 229 | for _ in range(args.num_tasks): 230 | train_split_indices = list() 231 | test_split_indices = list() 232 | 233 | scope = labels[:classes_per_task] 234 | labels = labels[classes_per_task:] 235 | 236 | mask.append(scope) 237 | 238 | for k in range(len(dataset_train.targets)): 239 | if int(dataset_train.targets[k]) in scope: 240 | train_split_indices.append(k) 241 | 242 | for h in range(len(dataset_val.targets)): 243 | if int(dataset_val.targets[h]) in scope: 244 | test_split_indices.append(h) 245 | 246 | subset_train, subset_val = Subset(dataset_train, train_split_indices), Subset(dataset_val, test_split_indices) 247 | 248 | split_datasets.append([subset_train, subset_val]) 249 | 250 | return split_datasets, mask 251 | 252 | def build_vil_scenario(splited_dataset, args): 253 | datasets = list() 254 | class_mask = list() 255 | domain_list = list() 256 | 257 | for i in range(len(splited_dataset)): 258 | dataset, mask = split_single_dataset(splited_dataset[i][0], splited_dataset[i][1], args) 259 | datasets.append(dataset) 260 | class_mask.append(mask) 261 | for _ in range(len(dataset)): 262 | domain_list.append(f'D{i}') 263 | 264 | splited_dataset = sum(datasets, []) 265 | class_mask = sum(class_mask, []) 266 | 267 | args.num_tasks = len(splited_dataset) 268 | 269 | zipped = list(zip(splited_dataset, class_mask, domain_list)) 270 | random.shuffle(zipped) 271 | splited_dataset, class_mask, domain_list = zip(*zipped) 272 | 273 | return splited_dataset, class_mask, domain_list, args 274 | 275 | def build_transform(is_train, args): 276 | if is_train: 277 | transform = transforms.Compose([ 278 | transforms.RandomResizedCrop(224), 279 | transforms.RandomHorizontalFlip(), 280 | transforms.ToTensor(), 281 | ]) 282 | else: 283 | transform = transforms.Compose([ 284 | transforms.Resize(256), 285 | transforms.CenterCrop(224), 286 | transforms.ToTensor(), 287 | ]) 288 | return transform -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import datetime 4 | import random 5 | import numpy as np 6 | import time 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | 10 | from pathlib import Path 11 | 12 | from timm.models import create_model 13 | from timm.scheduler import create_scheduler 14 | from timm.optim import create_optimizer 15 | 16 | from datasets import build_continual_dataloader 17 | from engine import Engine 18 | import models 19 | import utils 20 | import os 21 | 22 | import warnings 23 | warnings.filterwarnings('ignore', 'Argument interpolation should be of type InterpolationMode instead of int') 24 | 25 | def set_data_config(args): 26 | if args.dataset == "iDigits": 27 | args.class_num = 10 28 | args.domain_num = 4 29 | elif args.dataset == "DomainNet": 30 | args.class_num = 345 31 | args.domain_num = 6 32 | elif args.dataset == "CORe50": 33 | args.class_num = 50 34 | args.domain_num = 8 35 | return args 36 | 37 | def main(args): 38 | # utils.init_distributed_mode(args) 39 | args.distributed = False 40 | args = set_data_config(args) 41 | device = torch.device(args.device) 42 | 43 | # fix the seed for reproducibility 44 | seed = args.seed 45 | torch.manual_seed(seed) 46 | np.random.seed(seed) 47 | random.seed(seed) 48 | 49 | cudnn.benchmark = True 50 | cudnn.deterministic = True 51 | 52 | 53 | data_loader, class_mask, domain_list = build_continual_dataloader(args) 54 | 55 | 56 | model = create_model( 57 | args.model, 58 | pretrained=args.pretrained, 59 | num_classes=args.nb_classes, 60 | drop_rate=args.drop, 61 | drop_path_rate=args.drop_path, 62 | drop_block_rate=None, 63 | adapt_blocks=args.adapt_blocks, 64 | ) 65 | 66 | model.to(device) 67 | 68 | 69 | engine = Engine(model=model,device=device, class_mask=class_mask, domain_list=domain_list, args=args) 70 | 71 | for n, p in model.named_parameters(): 72 | p.requires_grad = False 73 | if 'adapter' in n: 74 | p.requires_grad = True 75 | if 'head' in n: 76 | p.requires_grad = True 77 | 78 | print(args) 79 | 80 | if args.eval: 81 | acc_matrix = np.zeros((args.num_tasks, args.num_tasks)) 82 | 83 | for task_id in range(args.num_tasks): 84 | checkpoint_path = os.path.join(args.output_dir, 'checkpoint/task{}_checkpoint.pth'.format(task_id+1)) 85 | if os.path.exists(checkpoint_path): 86 | print('Loading checkpoint from:', checkpoint_path) 87 | checkpoint = torch.load(checkpoint_path) 88 | model.load_state_dict(checkpoint['model']) 89 | else: 90 | print('No checkpoint found at:', checkpoint_path) 91 | return 92 | _ = engine.evaluate_till_now(model, data_loader, device, 93 | task_id, class_mask, acc_matrix, args,) 94 | 95 | return 96 | 97 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 98 | print('number of params:', n_parameters) 99 | 100 | optimizer = create_optimizer(args, model) 101 | 102 | if args.sched != 'constant': 103 | lr_scheduler, _ = create_scheduler(args, optimizer) 104 | elif args.sched == 'constant': 105 | lr_scheduler = None 106 | 107 | criterion = torch.nn.CrossEntropyLoss().to(device) 108 | 109 | print(f"Start training for {args.epochs} epochs") 110 | start_time = time.time() 111 | 112 | engine.train_and_evaluate(model,criterion, data_loader, optimizer, 113 | lr_scheduler, device, class_mask, args) 114 | 115 | total_time = time.time() - start_time 116 | total_time_str = 
str(datetime.timedelta(seconds=int(total_time))) 117 | print(f"Total training time: {total_time_str}") 118 | 119 | if __name__ == '__main__': 120 | parser = argparse.ArgumentParser('LAE') 121 | 122 | parser.add_argument('--batch-size', default=24, type=int, help='Batch size per device') 123 | parser.add_argument('--epochs', default=5, type=int) 124 | 125 | # Model parameters 126 | parser.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL', help='Name of model to train') 127 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 128 | parser.add_argument('--pretrained', default=True, help='Load pretrained model or not') 129 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)') 130 | parser.add_argument('--drop-path', type=float, default=0.0, metavar='PCT', help='Drop path rate (default: 0.)') 131 | 132 | # Optimizer parameters 133 | parser.add_argument('--opt', default='adam', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adam"') 134 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)') 135 | parser.add_argument('--opt-betas', default=(0.9, 0.999), type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: (0.9, 0.999), use opt default)') 136 | parser.add_argument('--clip-grad', type=float, default=0.0, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') 137 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') 138 | parser.add_argument('--weight-decay', type=float, default=0.0, help='weight decay (default: 0.0)') 139 | parser.add_argument('--reinit_optimizer', type=bool, default=True, help='reinit optimizer (default: True)') 140 | 141 | # Learning rate schedule parameters 142 | parser.add_argument('--sched', default='constant', type=str, metavar='SCHEDULER', help='LR scheduler (default: "constant"') 143 | parser.add_argument('--lr', type=float, default=0.0028125, metavar='LR', help='learning rate (default: 0.03)') 144 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages') 145 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)') 146 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)') 147 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)') 148 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 149 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', help='epoch interval to decay LR') 150 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports') 151 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 152 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10') 153 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') 154 | parser.add_argument('--unscale_lr', type=bool, 
default=True, help='scaling lr by batch size (default: True)') 155 | 156 | # Augmentation parameters 157 | parser.add_argument('--color-jitter', type=float, default=None, metavar='PCT', help='Color jitter factor (default: 0.3)') 158 | parser.add_argument('--aa', type=str, default=None, metavar='NAME', 159 | help='Use AutoAugment policy. "v0" or "original". " + \ 160 | "(default: rand-m9-mstd0.5-inc1)'), 161 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 162 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 163 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 164 | 165 | # * Random Erase params 166 | parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT', help='Random erase prob (default: 0.25)') 167 | parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode (default: "pixel")') 168 | parser.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)') 169 | 170 | # Data parameters 171 | parser.add_argument('--data-path', default='/local_datasets/', type=str, help='dataset path') 172 | parser.add_argument('--dataset', default='iDigits', type=str, help='dataset name') 173 | parser.add_argument('--shuffle', default=False, help='shuffle the data order') 174 | parser.add_argument('--output_dir', default='./output', help='path where to save, empty for no saving') 175 | parser.add_argument('--device', default='cuda', help='device to use for training / testing') 176 | parser.add_argument('--seed', default=42, type=int) 177 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 178 | parser.add_argument('--num_workers', default=4, type=int) 179 | parser.add_argument('--pin-mem', action='store_true', 180 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 181 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 182 | help='') 183 | parser.set_defaults(pin_mem=True) 184 | 185 | # distributed training parameters 186 | parser.add_argument('--world_size', default=1, type=int, 187 | help='number of distributed processes') 188 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 189 | 190 | # Continual learning parameters 191 | parser.add_argument('--num_tasks', default=10, type=int, help='number of sequential tasks') 192 | parser.add_argument('--train_mask', default=True, type=bool, help='if using the class mask at training') 193 | parser.add_argument('--task_inc', action='store_true', default=False, help='if doing task incremental') 194 | parser.add_argument('--domain_inc', action='store_true', default=False, help='if doing domain incremental') 195 | parser.add_argument('--versatile_inc', action='store_true', default=False, help='if doing versatile incremental') 196 | parser.add_argument('--joint_train', default=False, help='if doing joint training') 197 | 198 | # Prompt parameters 199 | parser.add_argument('--adapt_blocks', default=[0, 1, 2, 3, 4]) 200 | parser.add_argument('--ema_decay', default=0.9999) 201 | parser.add_argument('--num_freeze_epochs', type=int,default=3) 202 | parser.add_argument('--eval_only_emas', default=False) 203 | 204 | # Misc parameters 205 | parser.add_argument('--print_freq', type=int, default=10, help = 'The frequency of printing') 206 | parser.add_argument('--develop', action='store_true', default=False) 207 | 208 | #! 
IC 209 | parser.add_argument('--IC', action='store_true', default=False, help='if using incremental classifier') 210 | parser.add_argument('--d_threshold', action='store_true', default=False, help='if using dynamic thresholding in IC') 211 | parser.add_argument('--gamma',default=10.0, type=float, help='coefficient in dynamic thresholding') 212 | parser.add_argument('--thre',default=0, type=float, help='value of static threshold if not using dynamic thresholding') 213 | parser.add_argument('--alpha',default=1.0, type=float, help='coefficient of knowledge distillation in IC loss') 214 | 215 | #! CAST 216 | parser.add_argument('--beta',default=0.001, type=float, help='coefficient of cast loss') 217 | parser.add_argument('--k', default=2, type=int, help='the number of clusters in shift pool') 218 | parser.add_argument('--use_cast_loss', action='store_true', default=False, help='if using CAST loss') 219 | parser.add_argument('--norm_cast', action='store_true', default=False, help='if using normalization in cast') 220 | 221 | 222 | args = parser.parse_args() 223 | 224 | if args.output_dir: 225 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 226 | Path(args.data_path).mkdir(parents=True, exist_ok=True) 227 | main(args) 228 | 229 | sys.exit(0) -------------------------------------------------------------------------------- /continual_datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | # https://github.com/pytorch/vision/blob/8635be94d1216f10fb8302da89233bd86445e449/torchvision/datasets/utils.py 8 | 9 | import os 10 | import os.path 11 | import hashlib 12 | import gzip 13 | import errno 14 | import tarfile 15 | import zipfile 16 | import numpy as np 17 | import torch 18 | import codecs 19 | 20 | from torch.utils.model_zoo import tqdm 21 | 22 | 23 | def gen_bar_updater(): 24 | pbar = tqdm(total=None) 25 | 26 | def bar_update(count, block_size, total_size): 27 | if pbar.total is None and total_size: 28 | pbar.total = total_size 29 | progress_bytes = count * block_size 30 | pbar.update(progress_bytes - pbar.n) 31 | 32 | return bar_update 33 | 34 | 35 | def calculate_md5(fpath, chunk_size=1024 * 1024): 36 | md5 = hashlib.md5() 37 | with open(fpath, 'rb') as f: 38 | for chunk in iter(lambda: f.read(chunk_size), b''): 39 | md5.update(chunk) 40 | return md5.hexdigest() 41 | 42 | 43 | def check_md5(fpath, md5, **kwargs): 44 | return md5 == calculate_md5(fpath, **kwargs) 45 | 46 | 47 | def check_integrity(fpath, md5=None): 48 | if not os.path.isfile(fpath): 49 | return False 50 | if md5 is None: 51 | return True 52 | return check_md5(fpath, md5) 53 | 54 | 55 | def makedir_exist_ok(dirpath): 56 | """ 57 | Python2 support for os.makedirs(.., exist_ok=True) 58 | """ 59 | try: 60 | os.makedirs(dirpath) 61 | except OSError as e: 62 | if e.errno == errno.EEXIST: 63 | pass 64 | else: 65 | raise 66 | 67 | 68 | def download_url(url, root, filename=None, md5=None): 69 | """Download a file from a url and place it in root. 70 | Args: 71 | url (str): URL to download file from 72 | root (str): Directory to place downloaded file in 73 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 74 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 75 | """ 76 | from six.moves import urllib 77 | 78 | root = os.path.expanduser(root) 79 | if not filename: 80 | filename = os.path.basename(url) 81 | fpath = os.path.join(root, filename) 82 | 83 | makedir_exist_ok(root) 84 | 85 | # downloads file 86 | if check_integrity(fpath, md5): 87 | print('Using downloaded and verified file: ' + fpath) 88 | else: 89 | try: 90 | print('Downloading ' + url + ' to ' + fpath) 91 | urllib.request.urlretrieve( 92 | url, fpath, 93 | reporthook=gen_bar_updater() 94 | ) 95 | except (urllib.error.URLError, IOError) as e: 96 | if url[:5] == 'https': 97 | url = url.replace('https:', 'http:') 98 | print('Failed download. Trying https -> http instead.' 99 | ' Downloading ' + url + ' to ' + fpath) 100 | urllib.request.urlretrieve( 101 | url, fpath, 102 | reporthook=gen_bar_updater() 103 | ) 104 | else: 105 | raise e 106 | 107 | 108 | def list_dir(root, prefix=False): 109 | """List all directories at a given root 110 | Args: 111 | root (str): Path to directory whose folders need to be listed 112 | prefix (bool, optional): If true, prepends the path to each result, otherwise 113 | only returns the name of the directories found 114 | """ 115 | root = os.path.expanduser(root) 116 | directories = list( 117 | filter( 118 | lambda p: os.path.isdir(os.path.join(root, p)), 119 | os.listdir(root) 120 | ) 121 | ) 122 | 123 | if prefix is True: 124 | directories = [os.path.join(root, d) for d in directories] 125 | 126 | return directories 127 | 128 | 129 | def list_files(root, suffix, prefix=False): 130 | """List all files ending with a suffix at a given root 131 | Args: 132 | root (str): Path to directory whose folders need to be listed 133 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 134 | It uses the Python "str.endswith" method and is passed directly 135 | prefix (bool, optional): If true, prepends the path to each result, otherwise 136 | only returns the name of the files found 137 | """ 138 | root = os.path.expanduser(root) 139 | files = list( 140 | filter( 141 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 142 | os.listdir(root) 143 | ) 144 | ) 145 | 146 | if prefix is True: 147 | files = [os.path.join(root, d) for d in files] 148 | 149 | return files 150 | 151 | 152 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 153 | """Download a Google Drive file from and place it in root. 154 | Args: 155 | file_id (str): id of file to be downloaded 156 | root (str): Directory to place downloaded file in 157 | filename (str, optional): Name to save the file under. If None, use the id of the file. 158 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 159 | """ 160 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 161 | import requests 162 | url = "https://docs.google.com/uc?export=download" 163 | 164 | root = os.path.expanduser(root) 165 | if not filename: 166 | filename = file_id 167 | fpath = os.path.join(root, filename) 168 | 169 | makedir_exist_ok(root) 170 | 171 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 172 | print('Using downloaded and verified file: ' + fpath) 173 | else: 174 | session = requests.Session() 175 | 176 | response = session.get(url, params={'id': file_id}, stream=True) 177 | token = _get_confirm_token(response) 178 | 179 | if token: 180 | params = {'id': file_id, 'confirm': token} 181 | response = session.get(url, params=params, stream=True) 182 | 183 | _save_response_content(response, fpath) 184 | 185 | 186 | def _get_confirm_token(response): 187 | for key, value in response.cookies.items(): 188 | if key.startswith('download_warning'): 189 | return value 190 | 191 | return None 192 | 193 | 194 | def _save_response_content(response, destination, chunk_size=32768): 195 | with open(destination, "wb") as f: 196 | pbar = tqdm(total=None) 197 | progress = 0 198 | for chunk in response.iter_content(chunk_size): 199 | if chunk: # filter out keep-alive new chunks 200 | f.write(chunk) 201 | progress += len(chunk) 202 | pbar.update(progress - pbar.n) 203 | pbar.close() 204 | 205 | 206 | def _is_tar(filename): 207 | return filename.endswith(".tar") 208 | 209 | 210 | def _is_targz(filename): 211 | return filename.endswith(".tar.gz") 212 | 213 | 214 | def _is_gzip(filename): 215 | return filename.endswith(".gz") and not filename.endswith(".tar.gz") 216 | 217 | 218 | def _is_zip(filename): 219 | return filename.endswith(".zip") 220 | 221 | 222 | def extract_archive(from_path, to_path=None, remove_finished=False): 223 | if to_path is None: 224 | to_path = os.path.dirname(from_path) 225 | 226 | if _is_tar(from_path): 227 | with tarfile.open(from_path, 'r') as tar: 228 | tar.extractall(path=to_path) 229 | elif _is_targz(from_path): 230 | with tarfile.open(from_path, 'r:gz') as tar: 231 | tar.extractall(path=to_path) 232 | elif _is_gzip(from_path): 233 | to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) 234 | with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: 235 | out_f.write(zip_f.read()) 236 | elif _is_zip(from_path): 237 | with zipfile.ZipFile(from_path, 'r') as z: 238 | z.extractall(to_path) 239 | else: 240 | raise ValueError("Extraction of {} not supported".format(from_path)) 241 | 242 | if remove_finished: 243 | os.remove(from_path) 244 | 245 | 246 | def download_and_extract_archive(url, download_root, extract_root=None, filename=None, 247 | md5=None, remove_finished=False): 248 | download_root = os.path.expanduser(download_root) 249 | if extract_root is None: 250 | extract_root = download_root 251 | if not filename: 252 | filename = os.path.basename(url) 253 | 254 | download_url(url, download_root, filename, md5) 255 | 256 | archive = os.path.join(download_root, filename) 257 | print("Extracting {} to {}".format(archive, extract_root)) 258 | extract_archive(archive, extract_root, remove_finished) 259 | 260 | 261 | def iterable_to_str(iterable): 262 | return "'" + "', '".join([str(item) for item in iterable]) + "'" 263 | 264 | 265 | def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None): 266 | if not isinstance(value, torch._six.string_classes): 267 | if arg is 
None: 268 | msg = "Expected type str, but got type {type}." 269 | else: 270 | msg = "Expected type str for argument {arg}, but got type {type}." 271 | msg = msg.format(type=type(value), arg=arg) 272 | raise ValueError(msg) 273 | 274 | if valid_values is None: 275 | return value 276 | 277 | if value not in valid_values: 278 | if custom_msg is not None: 279 | msg = custom_msg 280 | else: 281 | msg = ("Unknown value '{value}' for argument {arg}. " 282 | "Valid values are {{{valid_values}}}.") 283 | msg = msg.format(value=value, arg=arg, 284 | valid_values=iterable_to_str(valid_values)) 285 | raise ValueError(msg) 286 | 287 | return value 288 | 289 | 290 | def get_int(b): 291 | return int(codecs.encode(b, 'hex'), 16) 292 | 293 | 294 | def open_maybe_compressed_file(path): 295 | """Return a file object that possibly decompresses 'path' on the fly. 296 | Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'. 297 | """ 298 | if not isinstance(path, torch._six.string_classes): 299 | return path 300 | if path.endswith('.gz'): 301 | import gzip 302 | return gzip.open(path, 'rb') 303 | if path.endswith('.xz'): 304 | import lzma 305 | return lzma.open(path, 'rb') 306 | return open(path, 'rb') 307 | 308 | 309 | def read_sn3_pascalvincent_tensor(path, strict=True): 310 | """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). 311 | Argument may be a filename, compressed filename, or file object. 312 | """ 313 | # typemap 314 | if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'): 315 | read_sn3_pascalvincent_tensor.typemap = { 316 | 8: (torch.uint8, np.uint8, np.uint8), 317 | 9: (torch.int8, np.int8, np.int8), 318 | 11: (torch.int16, np.dtype('>i2'), 'i2'), 319 | 12: (torch.int32, np.dtype('>i4'), 'i4'), 320 | 13: (torch.float32, np.dtype('>f4'), 'f4'), 321 | 14: (torch.float64, np.dtype('>f8'), 'f8')} 322 | # read 323 | with open_maybe_compressed_file(path) as f: 324 | data = f.read() 325 | # parse 326 | magic = get_int(data[0:4]) 327 | nd = magic % 256 328 | ty = magic // 256 329 | assert nd >= 1 and nd <= 3 330 | assert ty >= 8 and ty <= 14 331 | m = read_sn3_pascalvincent_tensor.typemap[ty] 332 | s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] 333 | parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) 334 | assert parsed.shape[0] == np.prod(s) or not strict 335 | return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) 336 | 337 | 338 | def read_label_file(path): 339 | with open(path, 'rb') as f: 340 | x = read_sn3_pascalvincent_tensor(f, strict=False) 341 | assert(x.dtype == torch.uint8) 342 | assert(x.ndimension() == 1) 343 | return x.long() 344 | 345 | 346 | def read_image_file(path): 347 | with open(path, 'rb') as f: 348 | x = read_sn3_pascalvincent_tensor(f, strict=False) 349 | assert(x.dtype == torch.uint8) 350 | assert(x.ndimension() == 3) 351 | return x -------------------------------------------------------------------------------- /continual_datasets/continual_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | 4 | import pathlib 5 | from pathlib import Path 6 | 7 | from typing import Any, Tuple 8 | 9 | import glob 10 | from shutil import move, rmtree 11 | 12 | import numpy as np 13 | 14 | import torch 15 | from torchvision import datasets 16 | from torchvision.datasets.utils import download_url, check_integrity, verify_str_arg, download_and_extract_archive 17 | 18 | import PIL 19 | from PIL import Image 20 | 21 | import 
tqdm 22 | import zipfile 23 | import tarfile 24 | 25 | from .dataset_utils import read_image_file, read_label_file 26 | 27 | class MNIST_RGB(datasets.MNIST): 28 | 29 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False): 30 | super(MNIST_RGB, self).__init__(root, transform=transform, target_transform=target_transform, download=download) 31 | self.train = train # training set or test set 32 | 33 | if self._check_legacy_exist(): 34 | self.data, self.targets = self._load_legacy_data() 35 | return 36 | 37 | if download: 38 | self.download() 39 | 40 | if not self._check_exists(): 41 | raise RuntimeError("Dataset not found. You can use download=True to download it") 42 | 43 | self.data, self.targets = self._load_data() 44 | self.classes = [i for i in range(10)] 45 | 46 | def _check_legacy_exist(self): 47 | processed_folder_exists = os.path.exists(self.processed_folder) 48 | if not processed_folder_exists: 49 | return False 50 | 51 | return all( 52 | check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file) 53 | ) 54 | 55 | def _load_legacy_data(self): 56 | # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data 57 | # directly. 58 | data_file = self.training_file if self.train else self.test_file 59 | return torch.load(os.path.join(self.processed_folder, data_file)) 60 | 61 | def _load_data(self): 62 | image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte" 63 | data = read_image_file(os.path.join(self.raw_folder, image_file)) 64 | 65 | label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte" 66 | targets = read_label_file(os.path.join(self.raw_folder, label_file)) 67 | 68 | return data, targets 69 | 70 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 71 | """ 72 | Args: 73 | index (int): Index 74 | 75 | Returns: 76 | tuple: (image, target) where target is index of the target class. 77 | """ 78 | img, target = self.data[index], int(self.targets[index]) 79 | 80 | # doing this so that it is consistent with all other datasets 81 | # to return a PIL Image 82 | try: 83 | img = Image.fromarray(img.numpy(), mode='L').convert('RGB') 84 | except: 85 | pass 86 | 87 | if self.transform is not None: 88 | img = self.transform(img) 89 | 90 | if self.target_transform is not None: 91 | target = self.target_transform(target) 92 | 93 | return img, target 94 | 95 | class MNISTM(torch.utils.data.Dataset): 96 | resources = [ 97 | ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_train.pt.tar.gz', 98 | '191ed53db9933bd85cc9700558847391'), 99 | ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_test.pt.tar.gz', 100 | 'e11cb4d7fff76d7ec588b1134907db59') 101 | ] 102 | 103 | training_file = "mnist_m_train.pt" 104 | test_file = "mnist_m_test.pt" 105 | 106 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False): 107 | root = os.path.join(root, 'MNIST-M') 108 | self.root = os.path.expanduser(root) 109 | self.transform = transform 110 | self.target_transform=target_transform 111 | self.train = train 112 | 113 | if download: 114 | self.download() 115 | 116 | if not self._check_exists(): 117 | raise RuntimeError("Dataset not found." 
+ 118 | " You can use download=True to download it") 119 | 120 | if self.train: 121 | data_file = self.training_file 122 | else: 123 | data_file = self.test_file 124 | 125 | self.classes = [i for i in range(10)] 126 | 127 | self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) 128 | 129 | def __getitem__(self, index): 130 | img, target = self.data[index], int(self.targets[index]) 131 | 132 | img = Image.fromarray(img.squeeze().numpy(), mode="RGB") 133 | 134 | if self.transform is not None: 135 | img = self.transform(img) 136 | 137 | if self.target_transform is not None: 138 | target = self.target_transform(target) 139 | 140 | return img, target 141 | 142 | def __len__(self): 143 | return len(self.data) 144 | 145 | @property 146 | def raw_folder(self): 147 | return os.path.join(self.root, 'raw') 148 | 149 | @property 150 | def processed_folder(self): 151 | return os.path.join(self.root, 'processed') 152 | 153 | def _check_exists(self): 154 | return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and 155 | os.path.exists(os.path.join(self.processed_folder, self.test_file))) 156 | 157 | def download(self): 158 | if self._check_exists(): 159 | return 160 | 161 | os.makedirs(self.raw_folder, exist_ok=True) 162 | os.makedirs(self.processed_folder, exist_ok=True) 163 | 164 | # download files 165 | for url, md5 in self.resources: 166 | filename = url.rpartition('/')[2] 167 | download_and_extract_archive(url, download_root=self.raw_folder, 168 | extract_root=self.processed_folder, 169 | filename=filename, md5=md5) 170 | 171 | class SynDigit(torch.utils.data.Dataset): 172 | resources = [ 173 | ('https://github.com/liyxi/synthetic-digits/releases/download/data/synth_train.pt.gz', 174 | 'd0e99daf379597e57448a89fc37ae5cf'), 175 | ('https://github.com/liyxi/synthetic-digits/releases/download/data/synth_test.pt.gz', 176 | '669d94c04d1c91552103e9aded0ee625') 177 | ] 178 | 179 | training_file = "synth_train.pt" 180 | test_file = "synth_test.pt" 181 | 182 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False): 183 | root = os.path.join(root, 'SynDigit') 184 | self.root = os.path.expanduser(root) 185 | self.transform = transform 186 | self.target_transform=target_transform 187 | self.train = train 188 | 189 | if download: 190 | self.download() 191 | 192 | if not self._check_exists(): 193 | raise RuntimeError("Dataset not found." 
+ 194 | " You can use download=True to download it") 195 | 196 | if self.train: 197 | data_file = self.training_file 198 | else: 199 | data_file = self.test_file 200 | 201 | self.classes = [i for i in range(10)] 202 | 203 | self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) 204 | 205 | def __getitem__(self, index): 206 | img, target = self.data[index], int(self.targets[index]) 207 | 208 | img = Image.fromarray(img.squeeze().numpy(), mode="RGB") 209 | 210 | if self.transform is not None: 211 | img = self.transform(img) 212 | 213 | if self.target_transform is not None: 214 | target = self.target_transform(target) 215 | 216 | return img, target 217 | 218 | def __len__(self): 219 | return len(self.data) 220 | 221 | @property 222 | def raw_folder(self): 223 | return os.path.join(self.root, 'raw') 224 | 225 | @property 226 | def processed_folder(self): 227 | return os.path.join(self.root, 'processed') 228 | 229 | def _check_exists(self): 230 | return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and 231 | os.path.exists(os.path.join(self.processed_folder, self.test_file))) 232 | 233 | def download(self): 234 | if self._check_exists(): 235 | return 236 | 237 | os.makedirs(self.raw_folder, exist_ok=True) 238 | os.makedirs(self.processed_folder, exist_ok=True) 239 | 240 | # download files 241 | for url, md5 in self.resources: 242 | filename = url.rpartition('/')[2] 243 | download_and_extract_archive(url, download_root=self.raw_folder, 244 | extract_root=self.processed_folder, 245 | filename=filename, md5=md5) 246 | 247 | class SVHN(datasets.SVHN): 248 | def __init__(self, root, split='train', transform=None, target_transform=None, download=False): 249 | super(SVHN, self).__init__(root, split=split, transform=transform, target_transform=target_transform, download=download) 250 | self.split = verify_str_arg(split, "split", tuple(self.split_list.keys())) 251 | self.url = self.split_list[split][0] 252 | self.filename = self.split_list[split][1] 253 | self.file_md5 = self.split_list[split][2] 254 | 255 | if download: 256 | self.download() 257 | 258 | if not self._check_integrity(): 259 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") 260 | 261 | # import here rather than at top of file because this is 262 | # an optional dependency for torchvision 263 | import scipy.io as sio 264 | 265 | # reading(loading) mat file as array 266 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) 267 | 268 | self.data = loaded_mat["X"] 269 | # loading from the .mat file gives an np array of type np.uint8 270 | # converting to np.int64, so that we have a LongTensor after 271 | # the conversion from the numpy array 272 | # the squeeze is needed to obtain a 1D tensor 273 | self.targets = loaded_mat["y"].astype(np.int64).squeeze() 274 | 275 | # the svhn dataset assigns the class label "10" to the digit 0 276 | # this makes it inconsistent with several loss functions 277 | # which expect the class labels to be in the range [0, C-1] 278 | np.place(self.targets, self.targets == 10, 0) 279 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 280 | self.classes = np.unique(self.targets) 281 | 282 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 283 | """ 284 | Args: 285 | index (int): Index 286 | 287 | Returns: 288 | tuple: (image, target) where target is index of the target class. 
289 | """ 290 | img, target = self.data[index], int(self.targets[index]) 291 | 292 | # doing this so that it is consistent with all other datasets 293 | # to return a PIL Image 294 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 295 | 296 | if self.transform is not None: 297 | img = self.transform(img) 298 | 299 | if self.target_transform is not None: 300 | target = self.target_transform(target) 301 | 302 | return img, target 303 | 304 | def __len__(self) -> int: 305 | return len(self.data) 306 | 307 | def _check_integrity(self) -> bool: 308 | root = self.root 309 | md5 = self.split_list[self.split][2] 310 | fpath = os.path.join(root, self.filename) 311 | return check_integrity(fpath, md5) 312 | 313 | def download(self) -> None: 314 | md5 = self.split_list[self.split][2] 315 | download_url(self.url, self.root, self.filename, md5) 316 | 317 | def extra_repr(self) -> str: 318 | return "Split: {split}".format(**self.__dict__) 319 | 320 | class CORe50(torch.utils.data.Dataset): 321 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False, mode='cil'): 322 | self.root = os.path.expanduser(root) 323 | self.transform = transform 324 | self.target_transform=target_transform 325 | self.train = train 326 | self.mode = mode 327 | 328 | self.url = 'http://bias.csr.unibo.it/maltoni/download/core50/core50_128x128.zip' 329 | self.filename = 'core50_128x128.zip' 330 | 331 | # self.fpath = os.path.join(root, 'VIL_CORe50') 332 | self.fpath = os.path.join(root, 'core50_128x128') 333 | 334 | if not os.path.isfile(self.fpath): 335 | if not download: 336 | raise RuntimeError('Dataset not found. You can use download=True to download it') 337 | else: 338 | print('Downloading from '+self.url) 339 | download_url(self.url, root, filename=self.filename) 340 | 341 | if not os.path.exists(os.path.join(root, 'core50_128x128')): 342 | with zipfile.ZipFile(os.path.join(self.root, self.filename), 'r') as zf: 343 | for member in tqdm.tqdm(zf.infolist(), desc=f'Extracting {self.filename}'): 344 | try: 345 | zf.extract(member, root) 346 | except zipfile.error as e: 347 | pass 348 | 349 | self.train_session_list = ['s1', 's2', 's4', 's5', 's6', 's8', 's9', 's11'] 350 | self.test_session_list = ['s3', 's7', 's10'] 351 | self.label = [f'o{i}' for i in range(1, 51)] 352 | 353 | if not os.path.exists(self.fpath + '/train') and not os.path.exists(self.fpath + '/test'): 354 | self.split() 355 | 356 | if self.train: 357 | fpath = self.fpath + '/train' 358 | if self.mode not in ['cil', 'joint']: 359 | self.data = [datasets.ImageFolder(f'{fpath}/{s}', transform=transform) for s in self.train_session_list] 360 | else: 361 | self.data = datasets.ImageFolder(fpath, transform=transform) 362 | else: 363 | fpath = self.fpath + '/test' 364 | self.data = datasets.ImageFolder(fpath, transform=transform) 365 | 366 | def split(self): 367 | train_folder = self.fpath + '/train' 368 | test_folder = self.fpath + '/test' 369 | 370 | if os.path.exists(train_folder): 371 | rmtree(train_folder) 372 | if os.path.exists(test_folder): 373 | rmtree(test_folder) 374 | os.mkdir(train_folder) 375 | os.mkdir(test_folder) 376 | 377 | if self.mode not in ['cil', 'joint']: 378 | for s in tqdm.tqdm(self.train_session_list, desc='Preprocessing'): 379 | src = os.path.join(self.fpath, s) 380 | if os.path.exists(os.path.join(train_folder, s)): 381 | continue 382 | move(src, train_folder) 383 | 384 | for s in tqdm.tqdm(self.test_session_list, desc='Preprocessing'): 385 | for l in self.label: 386 | dst = os.path.join(test_folder, 
l) 387 | if not os.path.exists(dst): 388 | os.mkdir(os.path.join(test_folder, l)) 389 | 390 | f = glob.glob(os.path.join(self.fpath, s, l, '*.png')) 391 | 392 | for src in f: 393 | move(src, dst) 394 | rmtree(os.path.join(self.fpath, s)) 395 | else: 396 | for s in tqdm.tqdm(self.train_session_list, desc='Preprocessing'): 397 | for l in self.label: 398 | dst = os.path.join(train_folder, l) 399 | if not os.path.exists(dst): 400 | os.mkdir(os.path.join(train_folder, l)) 401 | 402 | f = glob.glob(os.path.join(self.fpath, s, l, '*.png')) 403 | 404 | for src in f: 405 | move(src, dst) 406 | rmtree(os.path.join(self.fpath, s)) 407 | 408 | for s in tqdm.tqdm(self.test_session_list, desc='Preprocessing'): 409 | for l in self.label: 410 | dst = os.path.join(test_folder, l) 411 | if not os.path.exists(dst): 412 | os.mkdir(os.path.join(test_folder, l)) 413 | 414 | f = glob.glob(os.path.join(self.fpath, s, l, '*.png')) 415 | 416 | for src in f: 417 | move(src, dst) 418 | rmtree(os.path.join(self.fpath, s)) 419 | 420 | class DomainNet(torch.utils.data.Dataset): 421 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False, mode='cil'): 422 | root = os.path.join(root, 'VIL_DomainNet') 423 | # root = os.path.join(root, 'DomainNet') 424 | self.root = os.path.expanduser(root) 425 | self.transform = transform 426 | self.target_transform=target_transform 427 | self.train = train 428 | self.mode = mode 429 | 430 | self.url = [ 431 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip', 432 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip', 433 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip', 434 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip', 435 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip', 436 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip' 437 | ] 438 | 439 | self.filename = [ 440 | 'clipart.zip', 441 | 'infograph.zip', 442 | 'painting.zip', 443 | 'quickdraw.zip', 444 | 'real.zip', 445 | 'sketch.zip' 446 | ] 447 | 448 | self.train_url_list = [ 449 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/txt/clipart_train.txt', 450 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/txt/infograph_train.txt', 451 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/txt/painting_train.txt', 452 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/txt/quickdraw_train.txt', 453 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/txt/real_train.txt', 454 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/txt/sketch_train.txt' 455 | ] 456 | 457 | for u in self.train_url_list: 458 | filename = u.split('/')[-1] 459 | if not os.path.isfile(os.path.join(self.root, filename)): 460 | if not download: 461 | raise RuntimeError('Dataset not found. 
You can use download=True to download it') 462 | else: 463 | print('Downloading from '+filename) 464 | download_url(u, root, filename=filename) 465 | 466 | self.test_url_list = [ 467 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/txt/clipart_test.txt', 468 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/txt/infograph_test.txt', 469 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/txt/painting_test.txt', 470 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/txt/quickdraw_test.txt', 471 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/txt/real_test.txt', 472 | 'http://csr.bu.edu/ftp/visda/2019/multi-source/txt/sketch_test.txt' 473 | ] 474 | 475 | for u in self.test_url_list: 476 | filename = u.split('/')[-1] 477 | if not os.path.isfile(os.path.join(self.root, filename)): 478 | if not download: 479 | raise RuntimeError('Dataset not found. You can use download=True to download it') 480 | else: 481 | print('Downloading from '+filename) 482 | download_url(u, root, filename=filename) 483 | 484 | self.fpath = [os.path.join(self.root, f) for f in self.filename] 485 | 486 | for i in range(len(self.fpath)): 487 | if not os.path.isfile(self.fpath[i]): 488 | if not download: 489 | raise RuntimeError('Dataset not found. You can use download=True to download it') 490 | else: 491 | print('Downloading from '+self.url[i]) 492 | download_url(self.url[i], root, filename=self.filename[i]) 493 | 494 | if not os.path.exists(self.root + '/train') and not os.path.exists(self.root + '/test'): 495 | for i in range(len(self.fpath)): 496 | if not os.path.exists(os.path.join(self.root, self.filename[i][:-4])): 497 | with zipfile.ZipFile(os.path.join(self.root, self.filename[i]), 'r') as zf: 498 | for member in tqdm.tqdm(zf.infolist(), desc=f'Extracting {self.filename[i]}'): 499 | try: 500 | zf.extract(member, root) 501 | except zipfile.error as e: 502 | pass 503 | 504 | self.split() 505 | 506 | if self.train: 507 | fpath = self.root + '/train' 508 | if self.mode not in ['cil', 'joint']: 509 | self.data = [datasets.ImageFolder(f'{fpath}/{d}', transform=transform) for d in ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']] 510 | else: 511 | self.data = datasets.ImageFolder(fpath, transform=transform) 512 | else: 513 | fpath = self.root + '/test' 514 | if self.mode not in ['cil', 'joint']: 515 | self.data = [datasets.ImageFolder(f'{fpath}/{d}', transform=transform) for d in ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']] 516 | else: 517 | self.data = datasets.ImageFolder(fpath, transform=transform) 518 | 519 | def split(self): 520 | train_folder = self.root + '/train' 521 | test_folder = self.root + '/test' 522 | 523 | if os.path.exists(train_folder): 524 | rmtree(train_folder) 525 | if os.path.exists(test_folder): 526 | rmtree(test_folder) 527 | os.mkdir(train_folder) 528 | os.mkdir(test_folder) 529 | 530 | if self.mode not in ['cil', 'joint']: 531 | for i in tqdm.tqdm(range(len(self.train_url_list)), desc='Preprocessing'): 532 | train_list = self.train_url_list[i].split('/')[-1] 533 | 534 | with open(os.path.join(self.root, train_list), 'r') as f: 535 | for line in f.readlines(): 536 | line = line.replace('\n', '') 537 | path, _ = line.split(' ') 538 | dst = '/'.join(path.split('/')[:2]) 539 | 540 | if not os.path.exists(os.path.join(train_folder, dst)): 541 | os.makedirs(os.path.join(train_folder, dst)) 542 | 543 | src = os.path.join(self.root, path) 544 | dst = os.path.join(train_folder, path) 545 | 546 | move(src, dst) 547 | 548 | for i in 
tqdm.tqdm(range(len(self.test_url_list)), desc='Preprocessing'): 549 | test_list = self.test_url_list[i].split('/')[-1] 550 | 551 | with open(os.path.join(self.root, test_list), 'r') as f: 552 | for line in f.readlines(): 553 | line = line.replace('\n', '') 554 | path, _ = line.split(' ') 555 | dst = '/'.join(path.split('/')[:2]) 556 | 557 | if not os.path.exists(os.path.join(test_folder, dst)): 558 | os.makedirs(os.path.join(test_folder, dst)) 559 | 560 | src = os.path.join(self.root, path) 561 | dst = os.path.join(test_folder, path) 562 | 563 | move(src, dst) 564 | rmtree(os.path.join(self.root, test_list.split('_')[0])) 565 | else: 566 | for i in tqdm.tqdm(range(len(self.train_url_list)), desc='Preprocessing'): 567 | train_list = self.train_url_list[i].split('/')[-1] 568 | 569 | with open(os.path.join(self.root, train_list), 'r') as f: 570 | for line in f.readlines(): 571 | line = line.replace('\n', '') 572 | path, _ = line.split(' ') 573 | dst = '/'.join(path.split('/')[1:2]) 574 | 575 | if not os.path.exists(os.path.join(train_folder, dst)): 576 | os.makedirs(os.path.join(train_folder, dst)) 577 | 578 | src = os.path.join(self.root, path) 579 | dst = '/'.join(path.split('/')[1:]) 580 | dst = os.path.join(train_folder, dst) 581 | 582 | move(src, dst) 583 | 584 | for i in tqdm.tqdm(range(len(self.test_url_list)), desc='Preprocessing'): 585 | test_list = self.test_url_list[i].split('/')[-1] 586 | 587 | with open(os.path.join(self.root, test_list), 'r') as f: 588 | for line in f.readlines(): 589 | line = line.replace('\n', '') 590 | path, _ = line.split(' ') 591 | dst = '/'.join(path.split('/')[1:2]) 592 | 593 | if not os.path.exists(os.path.join(test_folder, dst)): 594 | os.makedirs(os.path.join(test_folder, dst)) 595 | 596 | src = os.path.join(self.root, path) 597 | dst = '/'.join(path.split('/')[1:]) 598 | dst = os.path.join(test_folder, dst) 599 | 600 | move(src, dst) 601 | rmtree(os.path.join(self.root, test_list.split('_')[0])) -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import os 4 | import datetime 5 | import json 6 | from turtle import undo 7 | from typing import Iterable 8 | from pathlib import Path 9 | 10 | import torch 11 | 12 | import numpy as np 13 | 14 | from timm.utils import accuracy 15 | from timm.optim import create_optimizer 16 | from timm.utils.model_ema import ModelEmaV2 17 | import copy 18 | import utils 19 | import torch.nn.functional as F 20 | from sklearn.cluster import KMeans 21 | from sklearn.manifold import TSNE 22 | import matplotlib.pyplot as plt 23 | 24 | 25 | class Engine(): 26 | def __init__(self, model=None,device=None,class_mask=[], domain_list= [], args=None): 27 | self.current_task=0 28 | self.current_classes=[] 29 | #! 
distillation 30 | self.class_group_num = 5 31 | self.classifier_pool = [None for _ in range(self.class_group_num)] 32 | self.class_group_train_count = [0 for _ in range(self.class_group_num)] 33 | 34 | self.task_num = len(class_mask) 35 | self.class_group_size = len(class_mask[0]) 36 | self.distill_head= None 37 | self.model = model 38 | 39 | self.num_classes= max([item for mask in class_mask for item in mask])+1 40 | self.labels_in_head = np.arange(self.num_classes) 41 | self.added_classes_in_cur_task = set() 42 | self.head_timestamps = np.zeros_like(self.labels_in_head) 43 | self.args=args 44 | 45 | self.class_mask=class_mask 46 | self.domain_list=domain_list 47 | 48 | self.task_type="initial" 49 | self.args=args 50 | 51 | self.adapter_vec=[] 52 | self.task_type_list=[] 53 | self.class_group_list=[] 54 | self.adapter_vec_label=[] 55 | self.device=device 56 | 57 | if self.args.d_threshold: 58 | self.acc_per_label = np.zeros((self.args.class_num, self.args.domain_num)) 59 | self.label_train_count = np.zeros((self.args.class_num)) 60 | self.tanh = torch.nn.Tanh() 61 | 62 | self.cs=torch.nn.CosineSimilarity(dim=1,eps=1e-6) 63 | 64 | def kl_div(self,p,q): 65 | p=F.softmax(p,dim=1) 66 | q=F.softmax(q,dim=1) 67 | kl = torch.mean(torch.sum(p * torch.log(p / q),dim=1)) 68 | return kl 69 | 70 | def set_new_head(self, model, labels_to_be_added,task_id): 71 | len_new_nodes = len(labels_to_be_added) 72 | self.labels_in_head = np.concatenate((self.labels_in_head, labels_to_be_added)) 73 | self.added_classes_in_cur_task.update(labels_to_be_added) 74 | self.head_timestamps = np.concatenate((self.head_timestamps, [task_id]*len_new_nodes)) 75 | prev_weight, prev_bias = model.head.weight, model.head.bias 76 | prev_shape = prev_weight.shape # (class, dim) 77 | new_head = torch.nn.Linear(prev_shape[-1], prev_shape[0] + len_new_nodes) 78 | 79 | new_head.weight[:prev_weight.shape[0]].data.copy_(prev_weight) 80 | new_head.weight[prev_weight.shape[0]:].data.copy_(prev_weight[labels_to_be_added]) 81 | new_head.bias[:prev_weight.shape[0]].data.copy_(prev_bias) 82 | new_head.bias[prev_weight.shape[0]:].data.copy_(prev_bias[labels_to_be_added]) 83 | 84 | print(f"Added {len_new_nodes} nodes with label ({labels_to_be_added})") 85 | return new_head 86 | 87 | 88 | def inference_acc(self,model,data_loader,device): 89 | print("Start detecting labels to be added...") 90 | accuracy_per_label = [] 91 | correct_pred_per_label = [0 for i in range(len(self.current_classes))] 92 | num_instance_per_label = [0 for i in range(len(self.current_classes))] 93 | 94 | with torch.no_grad(): 95 | for batch_idx, (input, target) in enumerate(data_loader): 96 | if self.args.develop: 97 | if batch_idx>200: 98 | break 99 | input = input.to(device, non_blocking=True) 100 | target = target.to(device, non_blocking=True) 101 | 102 | output = model(input) 103 | 104 | if output.shape[-1] > self.num_classes: # there are already added nodes till now 105 | output,_,_ = self.get_max_label_logits(output, self.current_classes) # there are added nodes previously, but not in current task -> get maximum value and use it 106 | mask = self.current_classes 107 | not_mask = np.setdiff1d(np.arange(self.num_classes), mask) 108 | not_mask = torch.tensor(not_mask, dtype=torch.int64).to(device) 109 | logits = output.index_fill(dim=1, index=not_mask, value=float('-inf')) 110 | _, pred = torch.max(logits, 1) 111 | 112 | correct_predictions = (pred == target) 113 | for i, label in enumerate(self.current_classes): 114 | mask = (target == label) 115 | num_correct_pred = 
torch.sum(correct_predictions[mask]) 116 | correct_pred_per_label[i] += num_correct_pred.item() 117 | num_instance_per_label[i] += sum(mask).item() 118 | for correct, num in zip (correct_pred_per_label, num_instance_per_label): 119 | accuracy_per_label.append(round(correct/num,2)) 120 | return accuracy_per_label 121 | 122 | def detect_labels_to_be_added(self,inference_acc, thresholds=[]): 123 | labels_with_low_accuracy = [] 124 | 125 | if self.args.d_threshold: 126 | for label,acc,thre in zip(self.current_classes, inference_acc,thresholds): 127 | if acc <= thre: 128 | labels_with_low_accuracy.append(label) 129 | else: # static threshold 130 | for label,acc in zip(self.current_classes, inference_acc): 131 | if acc <= self.args.thre: 132 | labels_with_low_accuracy.append(label) 133 | 134 | print(f"Labels whose node to be increased: {labels_with_low_accuracy}") 135 | return labels_with_low_accuracy 136 | 137 | def find_same_cluster_items(self,vec): 138 | if self.kmeans.n_clusters == 1: 139 | other_cluster_vecs = self.adapter_vec_array 140 | other_cluster_vecs = torch.tensor(other_cluster_vecs,dtype=torch.float32).to(self.device) 141 | same_cluster_vecs = None 142 | else: 143 | predicted_cluster = self.kmeans.predict(vec.unsqueeze(0).detach().cpu())[0] 144 | same_cluster_vecs = self.adapter_vec_array[self.cluster_assignments == predicted_cluster] 145 | other_cluster_vecs = self.adapter_vec_array[self.cluster_assignments != predicted_cluster] 146 | same_cluster_vecs = torch.tensor(same_cluster_vecs,dtype=torch.float32).to(self.device) 147 | other_cluster_vecs = torch.tensor(other_cluster_vecs,dtype=torch.float32).to(self.device) 148 | return same_cluster_vecs, other_cluster_vecs 149 | 150 | def calculate_l2_distance(self,diff_adapter, other): 151 | weights=[] 152 | for o in other: 153 | l2_distance = torch.norm(diff_adapter - o, p=2) 154 | weights.append(l2_distance.item()) 155 | weights = torch.tensor(weights) 156 | weights = weights / torch.sum(weights) # summation-> 1 157 | return weights 158 | 159 | def train_one_epoch(self,model: torch.nn.Module, 160 | criterion, data_loader: Iterable, optimizer: torch.optim.Optimizer, 161 | device: torch.device, epoch: int, max_norm: float = 0, 162 | set_training_mode=True, task_id=-1, class_mask=None, ema_model = None, args = None,): 163 | 164 | model.train(set_training_mode) 165 | 166 | metric_logger = utils.MetricLogger(delimiter=" ") 167 | metric_logger.add_meter('Lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 168 | metric_logger.add_meter('Loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 169 | header = f'Train: Epoch[{epoch+1:{int(math.log10(args.epochs))+1}}/{args.epochs}]' 170 | 171 | for batch_idx, (input, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): 172 | if self.args.develop: 173 | if batch_idx>20: 174 | break 175 | input = input.to(device, non_blocking=True) 176 | target = target.to(device, non_blocking=True) 177 | output = model(input) # (bs, class + n) 178 | distill_loss=0 179 | if self.distill_head != None: 180 | feature = model.forward_features(input)[:,0] 181 | output_distill = self.distill_head(feature) 182 | #! exclude added nodes in current task during distillation 183 | mask = torch.isin(torch.tensor(self.labels_in_head), torch.tensor(self.current_classes)) 184 | cur_class_nodes = torch.where(mask)[0]#[:-len(self.added_classes_in_cur_task)] #! 
to be fixed 185 | m=torch.isin(torch.tensor(self.labels_in_head[cur_class_nodes]), torch.tensor(list(self.added_classes_in_cur_task))) 186 | distill_node_indices = self.labels_in_head[cur_class_nodes][~m] 187 | distill_loss = self.kl_div(output[:,distill_node_indices], output_distill[:,distill_node_indices]) 188 | 189 | 190 | if output.shape[-1] > self.num_classes: # there are already added nodes till now 191 | output,_,_ = self.get_max_label_logits(output, class_mask[task_id],slice=False) 192 | if len(self.added_classes_in_cur_task) > 0: # there are added nodes in current task 193 | for added_class in self.added_classes_in_cur_task: 194 | cur_node = np.where(self.labels_in_head == added_class)[0][-1] # the latest appended node 195 | output[:, added_class] = output[:,cur_node]# replace logit value of added label 196 | 197 | output = output[:, :self.num_classes] 198 | 199 | # here is the trick to mask out classes of non-current tasks 200 | if args.train_mask and class_mask is not None: 201 | mask = class_mask[task_id] 202 | not_mask = np.setdiff1d(np.arange(args.nb_classes), mask) 203 | not_mask = torch.tensor(not_mask, dtype=torch.int64).to(device) 204 | logits = output.index_fill(dim=1, index=not_mask, value=float('-inf')) 205 | 206 | loss = criterion(logits, target) # (bs, class), (bs) 207 | 208 | 209 | if self.args.use_cast_loss: 210 | if len(self.adapter_vec)> args.k: 211 | cur_adapters = model.get_adapter() 212 | self.cur_adapters = self.flatten_parameters(cur_adapters) 213 | diff_adapter = self.cur_adapters-self.prev_adapters 214 | _, other = self.find_same_cluster_items(diff_adapter) 215 | sim = 0 216 | 217 | # if self.args.ws: 218 | weights = self.calculate_l2_distance(diff_adapter,other) 219 | for o,w in zip(other,weights): 220 | if self.args.norm_cast: 221 | sim += w * torch.matmul(diff_adapter, o) / (torch.norm(diff_adapter)*torch.norm(o)) 222 | else: 223 | sim += w * torch.matmul(diff_adapter, o) 224 | # else: 225 | # for o in other: 226 | # sim += torch.matmul(diff_adapter, o) 227 | # sim /= len(other) 228 | orth_loss = args.beta * torch.abs(sim) 229 | if self.args.use_cast_loss: 230 | if orth_loss>0: 231 | loss += orth_loss 232 | 233 | if self.args.IC: 234 | if distill_loss > 0: 235 | loss += distill_loss 236 | 237 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 238 | 239 | if not math.isfinite(loss.item()): 240 | print("Loss is {}, stopping training".format(loss.item())) 241 | sys.exit(1) 242 | 243 | optimizer.zero_grad() 244 | 245 | loss.backward(retain_graph=True) 246 | optimizer.step() 247 | torch.cuda.synchronize() 248 | metric_logger.update(Loss=loss.item()) 249 | metric_logger.update(Lr=optimizer.param_groups[0]["lr"]) 250 | metric_logger.meters['Acc@1'].update(acc1.item(), n=input.shape[0]) 251 | metric_logger.meters['Acc@5'].update(acc5.item(), n=input.shape[0]) 252 | 253 | if ema_model is not None: 254 | ema_model.update(model.get_adapter()) 255 | 256 | # gather the stats from all processes 257 | metric_logger.synchronize_between_processes() 258 | print("Averaged stats:", metric_logger) 259 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 260 | 261 | def get_max_label_logits(self,output, class_mask,task_id=None, slice=True,target=None): 262 | #! 
Get max value for each label output 263 | correct=0 264 | total=0 265 | for label in range(self.num_classes): 266 | label_nodes = np.where(self.labels_in_head == label)[0] 267 | output[:,label],max_index = torch.max(output[:,label_nodes],dim=1) 268 | if slice: 269 | output = output[:, :self.num_classes] # discard logits of added nodes 270 | 271 | return output,correct,total 272 | 273 | @torch.no_grad() 274 | def evaluate(self, model: torch.nn.Module, data_loader, 275 | device, task_id=-1, class_mask=None, ema_model=None, args=None,): 276 | criterion = torch.nn.CrossEntropyLoss() 277 | 278 | metric_logger = utils.MetricLogger(delimiter=" ") 279 | header = 'Test: [Task {}]'.format(task_id + 1) 280 | 281 | # switch to evaluation mode 282 | model.eval() 283 | 284 | correct_sum, total_sum = 0,0 285 | label_correct, label_total = np.zeros((self.class_group_size)), np.zeros((self.class_group_size)) 286 | with torch.no_grad(): 287 | for batch_idx,(input, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): 288 | if args.develop: 289 | if batch_idx>20: 290 | break 291 | 292 | input = input.to(device, non_blocking=True) 293 | target = target.to(device, non_blocking=True) 294 | 295 | # compute output 296 | output = model(input) 297 | 298 | output, correct, total = self.get_max_label_logits(output, class_mask[task_id],task_id=task_id, target=target,slice=True) 299 | output_ema = [output.softmax(dim=1)] 300 | correct_sum+=correct 301 | total_sum+=total 302 | 303 | if ema_model is not None: 304 | tmp_adapter = model.get_adapter() 305 | model.put_adapter(ema_model.module) 306 | output = model(input) 307 | output,_,_ = self.get_max_label_logits(output, class_mask[task_id],slice=True) 308 | output_ema.append(output.softmax(dim=1)) 309 | model.put_adapter(tmp_adapter) 310 | 311 | output = torch.stack(output_ema, dim=-1).max(dim=-1)[0] 312 | loss = criterion(output, target) 313 | 314 | if self.args.d_threshold and self.current_task +1 != self.args.num_tasks and self.current_task == task_id: 315 | label_correct, label_total = self.update_acc_per_label(label_correct, label_total, output, target) 316 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 317 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 318 | 319 | metric_logger.meters['Loss'].update(loss.item()) 320 | metric_logger.meters['Acc@1'].update(acc1.item(), n=input.shape[0]) 321 | metric_logger.meters['Acc@5'].update(acc5.item(), n=input.shape[0]) 322 | if total_sum>0: 323 | print(f"Max Pooling acc: {correct_sum/total_sum}") 324 | 325 | if self.args.d_threshold and task_id == self.current_task: 326 | domain_idx = int(self.label_train_count[self.current_classes][0]) 327 | self.acc_per_label[self.current_classes, domain_idx] += np.round(label_correct / label_total, decimals=3) 328 | print(self.label_train_count) 329 | print(self.acc_per_label) 330 | # gather the stats from all processes 331 | metric_logger.synchronize_between_processes() 332 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 333 | .format(top1=metric_logger.meters['Acc@1'], top5=metric_logger.meters['Acc@5'], losses=metric_logger.meters['Loss'])) 334 | 335 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 336 | 337 | 338 | @torch.no_grad() 339 | def evaluate_till_now(self,model: torch.nn.Module, data_loader, 340 | device, task_id=-1, class_mask=None, acc_matrix=None, ema_model=None, args=None,): 341 | stat_matrix = np.zeros((3, args.num_tasks)) # 3 for Acc@1, Acc@5, Loss 342 | 
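        # acc_matrix[i, j] holds Acc@1 on task i measured after training task j, so the
        # diagonal is each task's accuracy right after it was learned. The loop below fills
        # column task_id; forgetting then averages (best past accuracy - current accuracy)
        # over tasks 0..task_id-1, and backward transfer averages (current accuracy - diagonal).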
343 | for i in range(task_id+1): 344 | test_stats = self.evaluate(model=model, data_loader=data_loader[i]['val'], 345 | device=device, task_id=i, class_mask=class_mask, ema_model=ema_model, args=args) 346 | 347 | stat_matrix[0, i] = test_stats['Acc@1'] 348 | stat_matrix[1, i] = test_stats['Acc@5'] 349 | stat_matrix[2, i] = test_stats['Loss'] 350 | 351 | acc_matrix[i, task_id] = test_stats['Acc@1'] 352 | 353 | avg_stat = np.divide(np.sum(stat_matrix, axis=1), task_id+1) 354 | 355 | diagonal = np.diag(acc_matrix) 356 | 357 | result_str = "[Average accuracy till task{}]\tAcc@1: {:.4f}\tAcc@5: {:.4f}\tLoss: {:.4f}".format(task_id+1, avg_stat[0], avg_stat[1], avg_stat[2]) 358 | if task_id > 0: 359 | forgetting = np.mean((np.max(acc_matrix, axis=1) - 360 | acc_matrix[:, task_id])[:task_id]) 361 | backward = np.mean((acc_matrix[:, task_id] - diagonal)[:task_id]) 362 | 363 | result_str += "\tForgetting: {:.4f}\tBackward: {:.4f}".format(forgetting, backward) 364 | print(result_str) 365 | return test_stats 366 | 367 | def flatten_parameters(self,modules): 368 | flattened_params = [] 369 | 370 | for m in modules: 371 | params = list(m.parameters()) 372 | flattened_params.extend(params) 373 | return torch.cat([param.view(-1) for param in flattened_params]) 374 | 375 | def cluster_adapters(self): 376 | k = self.args.k 377 | if len(self.adapter_vec) > k: 378 | 379 | self.adapter_vec_array = torch.stack(self.adapter_vec).detach().cpu().numpy().astype(float) 380 | self.kmeans = KMeans(n_clusters=k,n_init=10) 381 | self.kmeans.fit(self.adapter_vec_array) 382 | self.cluster_assignments = self.kmeans.labels_ 383 | print("Cluster(shifts) Assignments:", self.cluster_assignments) 384 | 385 | 386 | def pre_train_epoch(self, model: torch.nn.Module, epoch: int = 0, task_id: int = 0, args = None,): 387 | if task_id == 0 or args.num_freeze_epochs < 1: 388 | return model 389 | 390 | if epoch == 0: 391 | for n, p in model.named_parameters(): 392 | if 'adapter' in n: 393 | p.requires_grad = False 394 | print('Freezing adapter parameters for {} epochs'.format(args.num_freeze_epochs)) 395 | 396 | if epoch == args.num_freeze_epochs: 397 | for n, p in model.named_parameters(): 398 | if 'adapter' in n: 399 | p.requires_grad = True 400 | print('Unfreezing adapter parameters') 401 | return model 402 | 403 | 404 | def pre_train_task(self, model, data_loader, device, task_id, args): 405 | self.current_task += 1 406 | self.current_class_group = int(min(self.class_mask[task_id])/self.class_group_size) 407 | self.class_group_list.append(self.current_class_group) 408 | self.current_classes = self.class_mask[task_id] 409 | 410 | print(f"\n\nTASK : {task_id}") 411 | self.added_classes_in_cur_task = set() 412 | #! 
distillation 413 | if self.class_group_train_count[self.current_class_group]==0: 414 | self.distill_head=None 415 | else: # already seen classes 416 | if self.args.IC: 417 | self.distill_head = self.classifier_pool[self.current_class_group] 418 | inf_acc = self.inference_acc(model, data_loader, device) 419 | thresholds=[] 420 | if self.args.d_threshold: 421 | count = self.class_group_train_count[self.current_class_group] 422 | if count > 0: 423 | average_accs = np.sum(self.acc_per_label[self.current_classes, :count], axis=1) / count 424 | thresholds = self.args.gamma*(average_accs - inf_acc) / average_accs 425 | thresholds = self.tanh(torch.tensor(thresholds)).tolist() 426 | thresholds = [round(t,2) if t>self.args.thre else self.args.thre for t in thresholds] 427 | print(f"Thresholds for class {self.current_classes[0]}~{self.current_classes[-1]} : {thresholds}") 428 | labels_to_be_added = self.detect_labels_to_be_added(inf_acc, thresholds) 429 | 430 | 431 | if len(labels_to_be_added) > 0: #! Add node to the classifier if needed 432 | new_head = self.set_new_head(model, labels_to_be_added,task_id).to(device) 433 | model.head = new_head 434 | optimizer = create_optimizer(args, model) 435 | 436 | with torch.no_grad(): 437 | prev_adapters = model.get_adapter() 438 | self.prev_adapters = self.flatten_parameters(prev_adapters) 439 | self.prev_adapters.requires_grad=False 440 | 441 | if task_id==0: 442 | self.task_type_list.append("Initial") 443 | return model, optimizer 444 | 445 | prev_class = self.class_mask[task_id-1] 446 | prev_domain = self.domain_list[task_id-1] 447 | cur_class = self.class_mask[task_id] 448 | self.cur_domain = self.domain_list[task_id] 449 | 450 | if prev_class == cur_class: 451 | self.task_type = "DIL" 452 | else: 453 | self.task_type = "CIL" 454 | 455 | self.task_type_list.append(self.task_type) 456 | print(f"Current task : {self.task_type}") 457 | 458 | return model, optimizer 459 | 460 | 461 | def post_train_task(self,model: torch.nn.Module,task_id=-1): 462 | #! update classifier pool 463 | self.class_group_train_count[self.current_class_group]+=1 464 | self.classifier_pool[self.current_class_group]=copy.deepcopy(model.head) 465 | for c in self.classifier_pool: 466 | if c != None: 467 | for p in c.parameters(): 468 | p.requires_grad=False 469 | 470 | cur_adapters = model.get_adapter() 471 | self.cur_adapters = self.flatten_parameters(cur_adapters) 472 | vector=self.cur_adapters - self.prev_adapters 473 | # if task_id>0: #? 
1 474 | self.adapter_vec.append(vector) 475 | self.adapter_vec_label.append(self.task_type) 476 | self.cluster_adapters() 477 | 478 | def train_and_evaluate(self, model: torch.nn.Module, criterion, data_loader: Iterable, optimizer: torch.optim.Optimizer, 479 | lr_scheduler, device: torch.device, class_mask=None, args = None,): 480 | 481 | # create matrix to save end-of-task accuracies 482 | acc_matrix = np.zeros((args.num_tasks, args.num_tasks)) 483 | 484 | ema_model = None 485 | 486 | for task_id in range(args.num_tasks): 487 | # Create new optimizer for each task to clear optimizer status 488 | if task_id > 0 and args.reinit_optimizer: 489 | optimizer = create_optimizer(args, model) 490 | 491 | if task_id == 1 and len(args.adapt_blocks) > 0: 492 | # ema_model = ModelEmaV2(model.adapter, decay=args.ema_decay).to(device) 493 | ema_model = ModelEmaV2(model.get_adapter(), decay=args.ema_decay, device=device) 494 | model, optimizer = self.pre_train_task(model, data_loader[task_id]['train'], device, task_id,args) 495 | for epoch in range(args.epochs): 496 | model = self.pre_train_epoch(model=model, epoch=epoch, task_id=task_id, args=args,) 497 | train_stats = self.train_one_epoch(model=model, criterion=criterion, 498 | data_loader=data_loader[task_id]['train'], optimizer=optimizer, 499 | device=device, epoch=epoch, max_norm=args.clip_grad, 500 | set_training_mode=True, task_id=task_id, class_mask=class_mask, ema_model=ema_model, args=args,) 501 | 502 | if lr_scheduler: 503 | lr_scheduler.step(epoch) 504 | 505 | self.post_train_task(model,task_id=task_id) 506 | if self.args.d_threshold: 507 | self.label_train_count[self.current_classes] += 1 508 | test_stats = self.evaluate_till_now(model=model, data_loader=data_loader, device=device, 509 | task_id=task_id, class_mask=class_mask, acc_matrix=acc_matrix, ema_model=ema_model, args=args) 510 | if args.output_dir and utils.is_main_process(): 511 | Path(os.path.join(args.output_dir, 'checkpoint')).mkdir(parents=True, exist_ok=True) 512 | 513 | checkpoint_path = os.path.join(args.output_dir, 'checkpoint/task{}_checkpoint.pth'.format(task_id+1)) 514 | state_dict = { 515 | 'model': model.state_dict(), 516 | 'ema_model': ema_model.state_dict() if ema_model is not None else None, 517 | 'optimizer': optimizer.state_dict(), 518 | 'epoch': epoch, 519 | 'args': args, 520 | } 521 | if args.sched is not None and args.sched != 'constant': 522 | state_dict['lr_scheduler'] = lr_scheduler.state_dict() 523 | 524 | utils.save_on_master(state_dict, checkpoint_path) 525 | 526 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 527 | **{f'test_{k}': v for k, v in test_stats.items()}, 528 | 'epoch': epoch,} 529 | 530 | if args.output_dir and utils.is_main_process(): 531 | with open(os.path.join(args.output_dir, '{}_stats.txt'.format(datetime.datetime.now().strftime('log_%Y_%m_%d_%H_%M'))), 'a') as f: 532 | f.write(json.dumps(log_stats) + '\n') -------------------------------------------------------------------------------- /vision_transformer.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of Vision Transformers as described in: 4 | 5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' 6 | - https://arxiv.org/abs/2010.11929 7 | 8 | `How to train your ViT? 
Data, Augmentation, and Regularization in Vision Transformers` 9 | - https://arxiv.org/abs/2106.10270 10 | 11 | `FlexiViT: One Model for All Patch Sizes` 12 | - https://arxiv.org/abs/2212.08013 13 | 14 | The official jax code is released and available at 15 | * https://github.com/google-research/vision_transformer 16 | * https://github.com/google-research/big_vision 17 | 18 | Acknowledgments: 19 | * The paper authors for releasing code and weights, thanks! 20 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch 21 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 22 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 23 | 24 | Hacked together by / Copyright 2020, Ross Wightman 25 | """ 26 | import logging 27 | import math 28 | from collections import OrderedDict 29 | from functools import partial 30 | from typing import Callable, List, Optional, Sequence, Tuple, Union, Dict 31 | 32 | import torch 33 | import torch.nn as nn 34 | import torch.nn.functional as F 35 | import torch.utils.checkpoint 36 | from torch.jit import Final 37 | 38 | from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \ 39 | resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked 40 | from timm.models._builder import build_model_with_cfg 41 | from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv 42 | 43 | __all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this 44 | 45 | 46 | _logger = logging.getLogger(__name__) 47 | 48 | 49 | class Attention(nn.Module): 50 | fused_attn: Final[bool] 51 | 52 | def __init__( 53 | self, 54 | dim, 55 | num_heads=8, 56 | qkv_bias=False, 57 | qk_norm=False, 58 | attn_drop=0., 59 | proj_drop=0., 60 | norm_layer=nn.LayerNorm, 61 | ): 62 | super().__init__() 63 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 64 | self.num_heads = num_heads 65 | self.head_dim = dim // num_heads 66 | self.scale = self.head_dim ** -0.5 67 | self.fused_attn = use_fused_attn() 68 | 69 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 70 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 71 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 72 | self.attn_drop = nn.Dropout(attn_drop) 73 | self.proj = nn.Linear(dim, dim) 74 | self.proj_drop = nn.Dropout(proj_drop) 75 | 76 | 77 | def forward(self, x): 78 | B, N, C = x.shape 79 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 80 | q, k, v = qkv.unbind(0) 81 | q, k = self.q_norm(q), self.k_norm(k) 82 | 83 | if self.fused_attn: 84 | x = F.scaled_dot_product_attention( 85 | q, k, v, 86 | dropout_p=self.attn_drop.p, 87 | ) 88 | else: 89 | q = q * self.scale 90 | attn = q @ k.transpose(-2, -1) 91 | attn = attn.softmax(dim=-1) 92 | attn = self.attn_drop(attn) 93 | x = attn @ v 94 | 95 | x = x.transpose(1, 2).reshape(B, N, C) 96 | x = self.proj(x) 97 | x = self.proj_drop(x) 98 | return x 99 | 100 | 101 | class LayerScale(nn.Module): 102 | def __init__(self, dim, init_values=1e-5, inplace=False): 103 | super().__init__() 104 | self.inplace = inplace 105 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 106 | 107 | def forward(self, x): 108 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 109 | 110 | class Scaler(nn.Module): 111 | def __init__(self, scale: Optional[float] = None): 112 | super().__init__() 113 
| 114 | if scale is None: 115 | self.register_parameter("scale", nn.Parameter(torch.tensor(1.0))) 116 | else: 117 | self.scale = scale 118 | 119 | def forward(self, input): 120 | return input * self.scale 121 | 122 | class Adapter(nn.Module): 123 | def __init__( 124 | self, 125 | embed_dim: int, 126 | down_sample: Union[float, int] = 5, 127 | mode: str = "parallel", # enum before, after, parallel 128 | scale: Optional[float] = None, 129 | act_layer = nn.GELU, 130 | ): 131 | super().__init__() 132 | 133 | assert mode in ["before", "after", "parallel"], f"Unknown mode {mode}" 134 | 135 | hidden_dim = down_sample 136 | if isinstance(down_sample, float): 137 | hidden_dim = int(embed_dim * down_sample) 138 | 139 | self.layer = nn.Sequential( 140 | nn.Linear(embed_dim, hidden_dim), 141 | act_layer(), 142 | nn.Linear(hidden_dim, embed_dim), 143 | Scaler(scale), 144 | ) 145 | self.mode = mode 146 | 147 | self.reset_parameters() 148 | 149 | def reset_parameters(self): 150 | nn.init.kaiming_uniform_(self.layer[0].weight, a=math.sqrt(5)) 151 | nn.init.zeros_(self.layer[0].bias) 152 | nn.init.zeros_(self.layer[2].weight) 153 | nn.init.zeros_(self.layer[2].bias) 154 | 155 | def forward(self, module, input, **kwargs): 156 | if self.mode == "before": 157 | return module(self.layer(input) + input, **kwargs) 158 | if self.mode == "after": 159 | return self.layer(module(input, **kwargs)) + input 160 | return module(input, **kwargs) + self.layer(input) 161 | 162 | class Block(nn.Module): 163 | 164 | def __init__( 165 | self, 166 | dim, 167 | num_heads, 168 | mlp_ratio=4., 169 | qkv_bias=False, 170 | qk_norm=False, 171 | proj_drop=0., 172 | attn_drop=0., 173 | init_values=None, 174 | drop_path=0., 175 | act_layer=nn.GELU, 176 | norm_layer=nn.LayerNorm, 177 | mlp_layer=Mlp, 178 | adapt_blocks=None, 179 | ): 180 | super().__init__() 181 | self.norm1 = norm_layer(dim) 182 | self.attn = Attention( 183 | dim, 184 | num_heads=num_heads, 185 | qkv_bias=qkv_bias, 186 | qk_norm=qk_norm, 187 | attn_drop=attn_drop, 188 | proj_drop=proj_drop, 189 | norm_layer=norm_layer, 190 | ) 191 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 192 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 193 | 194 | self.norm2 = norm_layer(dim) 195 | self.mlp = mlp_layer( 196 | in_features=dim, 197 | hidden_features=int(dim * mlp_ratio), 198 | act_layer=act_layer, 199 | drop=proj_drop, 200 | ) 201 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 202 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()
203 | 
204 |         # Attach a parameter-efficient adapter that wraps the attention
205 |         # branch of this block (parallel insertion, see forward() below).
206 |         if adapt_blocks:
207 |             self.adapter = Adapter(embed_dim=dim, mode="parallel")
208 | 
209 |     def forward(self, x):
210 |         if hasattr(self, "adapter"):
211 |             x = x + self.drop_path1(self.ls1(self.adapter(self.attn, self.norm1(x))))
212 |         else:
213 |             x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
214 |         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
215 |         return x
216 | 
217 | class VisionTransformer(nn.Module):
218 |     """ Vision Transformer
219 | 
220 |     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
221 |         - https://arxiv.org/abs/2010.11929
222 |     """
223 | 
224 |     def __init__(
225 |             self,
226 |             img_size: Union[int, Tuple[int, int]] = 224,
227 |             patch_size: Union[int, Tuple[int, int]] = 16,
228 |             in_chans: int = 3,
229 |             num_classes: int = 1000,
230 |             global_pool: str = 'token',
231 |             embed_dim: int = 768,
232 |             depth: int = 12,
233 |             num_heads: int = 12,
234 |             mlp_ratio: float = 4.,
235 |             qkv_bias: bool = True,
236 |             qk_norm: bool = False,
237 |             init_values: Optional[float] = None,
238 |             class_token: bool = True,
239 |             no_embed_class: bool = False,
240 |             pre_norm: bool = False,
241 |             fc_norm: Optional[bool] = None,
242 |             drop_rate: float = 0.,
243 |             pos_drop_rate: float = 0.,
244 |             patch_drop_rate: float = 0.,
245 |             proj_drop_rate: float = 0.,
246 |             attn_drop_rate: float = 0.,
247 |             drop_path_rate: float = 0.,
248 |             weight_init: str = '',
249 |             embed_layer: Callable = PatchEmbed,
250 |             norm_layer: Optional[Callable] = None,
251 |             act_layer: Optional[Callable] = None,
252 |             block_fn: Callable = Block,
253 |             mlp_layer: Callable = Mlp,
254 |             adapt_blocks: list = [],
255 |     ):
256 |         """
257 |         Args:
258 |             img_size: Input image size.
259 |             patch_size: Patch size.
260 |             in_chans: Number of image input channels.
261 |             num_classes: Number of classes for classification head.
262 |             global_pool: Type of global pooling for final sequence (default: 'token').
263 |             embed_dim: Transformer embedding dimension.
264 |             depth: Depth of transformer.
265 |             num_heads: Number of attention heads.
266 |             mlp_ratio: Ratio of mlp hidden dim to embedding dim.
267 |             qkv_bias: Enable bias for qkv projections if True.
268 |             init_values: Layer-scale init values (layer-scale enabled if not None).
269 |             class_token: Use class token.
270 |             fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
271 |             drop_rate: Head dropout rate.
272 |             pos_drop_rate: Position embedding dropout rate.
273 |             attn_drop_rate: Attention dropout rate.
274 |             drop_path_rate: Stochastic depth rate.
275 |             weight_init: Weight initialization scheme.
276 |             embed_layer: Patch embedding layer.
277 |             norm_layer: Normalization layer.
278 |             act_layer: MLP activation layer.
279 |             block_fn: Transformer block layer.
280 | """ 281 | super().__init__() 282 | assert global_pool in ('', 'avg', 'token') 283 | assert class_token or global_pool != 'token' 284 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm 285 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 286 | act_layer = act_layer or nn.GELU 287 | 288 | self.num_classes = num_classes 289 | self.global_pool = global_pool 290 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 291 | self.num_prefix_tokens = 1 if class_token else 0 292 | self.no_embed_class = no_embed_class 293 | self.grad_checkpointing = False 294 | 295 | self.patch_embed = embed_layer( 296 | img_size=img_size, 297 | patch_size=patch_size, 298 | in_chans=in_chans, 299 | embed_dim=embed_dim, 300 | bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) 301 | ) 302 | num_patches = self.patch_embed.num_patches 303 | 304 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None 305 | embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens 306 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) 307 | self.pos_drop = nn.Dropout(p=pos_drop_rate) 308 | if patch_drop_rate > 0: 309 | self.patch_drop = PatchDropout( 310 | patch_drop_rate, 311 | num_prefix_tokens=self.num_prefix_tokens, 312 | ) 313 | else: 314 | self.patch_drop = nn.Identity() 315 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() 316 | 317 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 318 | self.adapt_blocks = adapt_blocks 319 | # Adapter 320 | # if len(adapt_blocks) > 0: 321 | # self.adapt_blocks = adapt_blocks 322 | # self.adapter = nn.ModuleList([ 323 | # Adapter(embed_dim=embed_dim, mode="parallel") for _ in adapt_blocks 324 | # ]) 325 | 326 | self.blocks = nn.Sequential(*[ 327 | block_fn( 328 | dim=embed_dim, 329 | num_heads=num_heads, 330 | mlp_ratio=mlp_ratio, 331 | qkv_bias=qkv_bias, 332 | qk_norm=qk_norm, 333 | init_values=init_values, 334 | proj_drop=proj_drop_rate, 335 | attn_drop=attn_drop_rate, 336 | drop_path=dpr[i], 337 | norm_layer=norm_layer, 338 | act_layer=act_layer, 339 | mlp_layer=mlp_layer, 340 | adapt_blocks=True if i in adapt_blocks else False, 341 | ) 342 | for i in range(depth)]) 343 | 344 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() 345 | 346 | # Classifier Head 347 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() 348 | self.head_drop = nn.Dropout(drop_rate) 349 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 350 | 351 | if weight_init != 'skip': 352 | self.init_weights(weight_init) 353 | 354 | def get_adapter(self): 355 | return nn.ModuleList([self.blocks[i].adapter for i in self.adapt_blocks]) 356 | 357 | def put_adapter(self, adapter): 358 | for i in self.adapt_blocks: 359 | self.blocks[i].adapter = adapter[i] 360 | 361 | def init_weights(self, mode=''): 362 | assert mode in ('jax', 'jax_nlhb', 'moco', '') 363 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
364 | trunc_normal_(self.pos_embed, std=.02) 365 | if self.cls_token is not None: 366 | nn.init.normal_(self.cls_token, std=1e-6) 367 | named_apply(get_init_weights_vit(mode, head_bias), self) 368 | 369 | def _init_weights(self, m): 370 | # this fn left here for compat with downstream users 371 | init_weights_vit_timm(m) 372 | 373 | @torch.jit.ignore() 374 | def load_pretrained(self, checkpoint_path, prefix=''): 375 | _load_weights(self, checkpoint_path, prefix) 376 | 377 | @torch.jit.ignore 378 | def no_weight_decay(self): 379 | return {'pos_embed', 'cls_token', 'dist_token'} 380 | 381 | @torch.jit.ignore 382 | def group_matcher(self, coarse=False): 383 | return dict( 384 | stem=r'^cls_token|pos_embed|patch_embed', # stem and embed 385 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] 386 | ) 387 | 388 | @torch.jit.ignore 389 | def set_grad_checkpointing(self, enable=True): 390 | self.grad_checkpointing = enable 391 | 392 | @torch.jit.ignore 393 | def get_classifier(self): 394 | return self.head 395 | 396 | def reset_classifier(self, num_classes: int, global_pool=None): 397 | self.num_classes = num_classes 398 | if global_pool is not None: 399 | assert global_pool in ('', 'avg', 'token') 400 | self.global_pool = global_pool 401 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 402 | 403 | def _pos_embed(self, x): 404 | if self.no_embed_class: 405 | # deit-3, updated JAX (big vision) 406 | # position embedding does not overlap with class token, add then concat 407 | x = x + self.pos_embed 408 | if self.cls_token is not None: 409 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 410 | else: 411 | # original timm, JAX, and deit vit impl 412 | # pos_embed has entry for class token, concat then add 413 | if self.cls_token is not None: 414 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 415 | x = x + self.pos_embed 416 | return self.pos_drop(x) 417 | 418 | def _intermediate_layers( 419 | self, 420 | x: torch.Tensor, 421 | n: Union[int, Sequence] = 1, 422 | ): 423 | outputs, num_blocks = [], len(self.blocks) 424 | take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n) 425 | 426 | # forward pass 427 | x = self.patch_embed(x) 428 | x = self._pos_embed(x) 429 | x = self.patch_drop(x) 430 | x = self.norm_pre(x) 431 | for i, blk in enumerate(self.blocks): 432 | x = blk(x) 433 | if i in take_indices: 434 | outputs.append(x) 435 | 436 | return outputs 437 | 438 | def get_intermediate_layers( 439 | self, 440 | x: torch.Tensor, 441 | n: Union[int, Sequence] = 1, 442 | reshape: bool = False, 443 | return_class_token: bool = False, 444 | norm: bool = False, 445 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: 446 | """ Intermediate layer accessor (NOTE: This is a WIP experiment). 
447 |         Inspired by DINO / DINOv2 interface
448 |         """
449 |         # take last n blocks if n is an int, if n is a sequence, select by matching indices
450 |         outputs = self._intermediate_layers(x, n)
451 |         if norm:
452 |             outputs = [self.norm(out) for out in outputs]
453 |         class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
454 |         outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
455 | 
456 |         if reshape:
457 |             grid_size = self.patch_embed.grid_size
458 |             outputs = [
459 |                 out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
460 |                 for out in outputs
461 |             ]
462 | 
463 |         if return_class_token:
464 |             return tuple(zip(outputs, class_tokens))
465 |         return tuple(outputs)
466 | 
467 |     def forward_features(self, x):
468 |         x = self.patch_embed(x)
469 |         x = self._pos_embed(x)
470 |         x = self.patch_drop(x)
471 |         x = self.norm_pre(x)
472 |         if self.grad_checkpointing and not torch.jit.is_scripting():
473 |             x = checkpoint_seq(self.blocks, x)
474 |         else:
475 |             x = self.blocks(x)
476 |         x = self.norm(x)
477 |         return x
478 | 
479 |     def forward_head(self, x, pre_logits: bool = False):
480 |         if self.global_pool:
481 |             x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
482 |         x = self.fc_norm(x)
483 |         x = self.head_drop(x)
484 |         return x if pre_logits else self.head(x)
485 | 
486 |     def forward(self, x):
487 |         x = self.forward_features(x)
488 |         x = self.forward_head(x)
489 |         return x
490 | 
491 | 
492 | def init_weights_vit_timm(module: nn.Module, name: str = ''):
493 |     """ ViT weight initialization, original timm impl (for reproducibility) """
494 |     if isinstance(module, nn.Linear):
495 |         trunc_normal_(module.weight, std=.02)
496 |         if module.bias is not None:
497 |             nn.init.zeros_(module.bias)
498 |     elif hasattr(module, 'init_weights'):
499 |         module.init_weights()
500 | 
501 | 
502 | def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
503 |     """ ViT weight initialization, matching JAX (Flax) impl """
504 |     if isinstance(module, nn.Linear):
505 |         if name.startswith('head'):
506 |             nn.init.zeros_(module.weight)
507 |             nn.init.constant_(module.bias, head_bias)
508 |         else:
509 |             nn.init.xavier_uniform_(module.weight)
510 |             if module.bias is not None:
511 |                 nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
512 |     elif isinstance(module, nn.Conv2d):
513 |         lecun_normal_(module.weight)
514 |         if module.bias is not None:
515 |             nn.init.zeros_(module.bias)
516 |     elif hasattr(module, 'init_weights'):
517 |         module.init_weights()
518 | 
519 | 
520 | def init_weights_vit_moco(module: nn.Module, name: str = ''):
521 |     """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """
522 |     if isinstance(module, nn.Linear):
523 |         if 'qkv' in name:
524 |             # treat the weights of Q, K, V separately
525 |             val = math.sqrt(6.
/ float(module.weight.shape[0] // 3 + module.weight.shape[1])) 526 | nn.init.uniform_(module.weight, -val, val) 527 | else: 528 | nn.init.xavier_uniform_(module.weight) 529 | if module.bias is not None: 530 | nn.init.zeros_(module.bias) 531 | elif hasattr(module, 'init_weights'): 532 | module.init_weights() 533 | 534 | 535 | def get_init_weights_vit(mode='jax', head_bias: float = 0.): 536 | if 'jax' in mode: 537 | return partial(init_weights_vit_jax, head_bias=head_bias) 538 | elif 'moco' in mode: 539 | return init_weights_vit_moco 540 | else: 541 | return init_weights_vit_timm 542 | 543 | 544 | def resize_pos_embed( 545 | posemb, 546 | posemb_new, 547 | num_prefix_tokens=1, 548 | gs_new=(), 549 | interpolation='bicubic', 550 | antialias=False, 551 | ): 552 | """ Rescale the grid of position embeddings when loading from state_dict. 553 | 554 | *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed 555 | 556 | Adapted from: 557 | https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 558 | """ 559 | ntok_new = posemb_new.shape[1] 560 | if num_prefix_tokens: 561 | posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] 562 | ntok_new -= num_prefix_tokens 563 | else: 564 | posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] 565 | gs_old = int(math.sqrt(len(posemb_grid))) 566 | if not len(gs_new): # backwards compatibility 567 | gs_new = [int(math.sqrt(ntok_new))] * 2 568 | assert len(gs_new) >= 2 569 | _logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).') 570 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 571 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False) 572 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 573 | posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) 574 | return posemb 575 | 576 | 577 | @torch.no_grad() 578 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 579 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 580 | """ 581 | import numpy as np 582 | 583 | def _n2p(w, t=True): 584 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 585 | w = w.flatten() 586 | if t: 587 | if w.ndim == 4: 588 | w = w.transpose([3, 2, 0, 1]) 589 | elif w.ndim == 3: 590 | w = w.transpose([2, 0, 1]) 591 | elif w.ndim == 2: 592 | w = w.transpose([1, 0]) 593 | return torch.from_numpy(w) 594 | 595 | w = np.load(checkpoint_path) 596 | interpolation = 'bilinear' 597 | antialias = False 598 | big_vision = False 599 | if not prefix: 600 | if 'opt/target/embedding/kernel' in w: 601 | prefix = 'opt/target/' 602 | elif 'params/embedding/kernel' in w: 603 | prefix = 'params/' 604 | big_vision = True 605 | 606 | if hasattr(model.patch_embed, 'backbone'): 607 | # hybrid 608 | backbone = model.patch_embed.backbone 609 | stem_only = not hasattr(backbone, 'stem') 610 | stem = backbone if stem_only else backbone.stem 611 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 612 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 613 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 614 | if not stem_only: 615 | for i, stage in enumerate(backbone.stages): 616 | for j, block in enumerate(stage.blocks): 617 | bp = f'{prefix}block{i 
+ 1}/unit{j + 1}/' 618 | for r in range(3): 619 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 620 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 621 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 622 | if block.downsample is not None: 623 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 624 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 625 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 626 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 627 | else: 628 | embed_conv_w = adapt_input_conv( 629 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 630 | if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]: 631 | embed_conv_w = resample_patch_embed( 632 | embed_conv_w, 633 | model.patch_embed.proj.weight.shape[-2:], 634 | interpolation=interpolation, 635 | antialias=antialias, 636 | verbose=True, 637 | ) 638 | 639 | model.patch_embed.proj.weight.copy_(embed_conv_w) 640 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 641 | if model.cls_token is not None: 642 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 643 | if big_vision: 644 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) 645 | else: 646 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 647 | if pos_embed_w.shape != model.pos_embed.shape: 648 | old_shape = pos_embed_w.shape 649 | num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) 650 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights 651 | pos_embed_w, 652 | new_size=model.patch_embed.grid_size, 653 | num_prefix_tokens=num_prefix_tokens, 654 | interpolation=interpolation, 655 | antialias=antialias, 656 | verbose=True, 657 | ) 658 | model.pos_embed.copy_(pos_embed_w) 659 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 660 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 661 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 662 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 663 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 664 | # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights 665 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 666 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 667 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 668 | mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2) 669 | for i, block in enumerate(model.blocks.children()): 670 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 671 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' 672 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 673 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 674 | block.attn.qkv.weight.copy_(torch.cat([ 675 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 676 | block.attn.qkv.bias.copy_(torch.cat([ 677 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 678 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 679 | 
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 680 | for r in range(2): 681 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) 682 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) 683 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) 684 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) 685 | 686 | 687 | def _convert_openai_clip(state_dict, model): 688 | out_dict = {} 689 | swaps = [ 690 | ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'), 691 | ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'), 692 | ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'), 693 | ] 694 | for k, v in state_dict.items(): 695 | if not k.startswith('visual.'): 696 | continue 697 | for sp in swaps: 698 | k = k.replace(sp[0], sp[1]) 699 | 700 | if k == 'proj': 701 | k = 'head.weight' 702 | v = v.transpose(0, 1) 703 | out_dict['head.bias'] = torch.zeros(v.shape[0]) 704 | elif k == 'class_embedding': 705 | k = 'cls_token' 706 | v = v.unsqueeze(0).unsqueeze(1) 707 | elif k == 'pos_embed': 708 | v = v.unsqueeze(0) 709 | if v.shape[1] != model.pos_embed.shape[1]: 710 | # To resize pos embedding when using model at different size from pretrained weights 711 | v = resize_pos_embed( 712 | v, 713 | model.pos_embed, 714 | 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), 715 | model.patch_embed.grid_size 716 | ) 717 | out_dict[k] = v 718 | return out_dict 719 | 720 | 721 | def _convert_dinov2(state_dict, model): 722 | import re 723 | out_dict = {} 724 | for k, v in state_dict.items(): 725 | if k == "mask_token": 726 | continue 727 | elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k): 728 | out_dict[k.replace("w12", "fc1")] = v 729 | continue 730 | elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k): 731 | out_dict[k.replace("w3", "fc2")] = v 732 | continue 733 | out_dict[k] = v 734 | return out_dict 735 | 736 | 737 | def _convert_ijepa(state_dict, model): 738 | out_dict = {} 739 | for k, v in state_dict['encoder'].items(): 740 | if k.startswith('module.'): 741 | k = k[7:] 742 | if k.startswith('norm.'): 743 | k = 'fc_norm.' 
+ k[5:] 744 | out_dict[k] = v 745 | return out_dict 746 | 747 | 748 | def checkpoint_filter_fn( 749 | state_dict, 750 | model, 751 | adapt_layer_scale=False, 752 | interpolation='bicubic', 753 | antialias=True, 754 | ): 755 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 756 | import re 757 | out_dict = {} 758 | state_dict = state_dict.get('model', state_dict) 759 | state_dict = state_dict.get('state_dict', state_dict) 760 | 761 | if 'visual.class_embedding' in state_dict: 762 | return _convert_openai_clip(state_dict, model) 763 | 764 | if "mask_token" in state_dict: 765 | state_dict = _convert_dinov2(state_dict, model) 766 | 767 | if "encoder" in state_dict: 768 | state_dict = _convert_ijepa(state_dict, model) 769 | 770 | for k, v in state_dict.items(): 771 | if 'patch_embed.proj.weight' in k: 772 | O, I, H, W = model.patch_embed.proj.weight.shape 773 | if len(v.shape) < 4: 774 | # For old models that I trained prior to conv based patchification 775 | O, I, H, W = model.patch_embed.proj.weight.shape 776 | v = v.reshape(O, -1, H, W) 777 | if v.shape[-1] != W or v.shape[-2] != H: 778 | v = resample_patch_embed( 779 | v, 780 | (H, W), 781 | interpolation=interpolation, 782 | antialias=antialias, 783 | verbose=True, 784 | ) 785 | elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: 786 | # To resize pos embedding when using model at different size from pretrained weights 787 | num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) 788 | v = resample_abs_pos_embed( 789 | v, 790 | new_size=model.patch_embed.grid_size, 791 | num_prefix_tokens=num_prefix_tokens, 792 | interpolation=interpolation, 793 | antialias=antialias, 794 | verbose=True, 795 | ) 796 | elif adapt_layer_scale and 'gamma_' in k: 797 | # remap layer-scale gamma into sub-module (deit3 models) 798 | k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) 799 | elif 'pre_logits' in k: 800 | # NOTE representation layer removed as not used in latest 21k/1k pretrained weights 801 | continue 802 | out_dict[k] = v 803 | return out_dict 804 | 805 | def _create_vision_transformer(variant, pretrained=False, **kwargs): 806 | if kwargs.get('features_only', None): 807 | raise RuntimeError('features_only not implemented for Vision Transformer models.') 808 | 809 | if 'flexi' in variant: 810 | # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed 811 | # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. 812 | _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) 813 | else: 814 | _filter_fn = checkpoint_filter_fn 815 | 816 | return build_model_with_cfg( 817 | VisionTransformer, 818 | variant, 819 | pretrained, 820 | pretrained_filter_fn=_filter_fn, 821 | **kwargs, 822 | ) 823 | 824 | --------------------------------------------------------------------------------
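
Usage note: the `Adapter` defined in `vision_transformer.py` above is not called on a tensor directly; its `forward(module, input)` wraps an existing sub-module and composes it with the bottleneck according to `mode`. The snippet below is a minimal sketch of the three insertion modes, assuming the file is importable from the repository root; the tensor shapes and the `down_sample` value are illustrative only, not values used by the repo.

```python
import torch
import torch.nn as nn

from vision_transformer import Adapter  # assumes the repo root is on PYTHONPATH

dim = 768
mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
x = torch.randn(2, 197, dim)  # (batch, tokens, embed_dim)

for mode in ("before", "after", "parallel"):
    adapter = Adapter(embed_dim=dim, down_sample=0.25, mode=mode)
    # With bottleneck = adapter.layer:
    #   "before":   mlp(bottleneck(x) + x)
    #   "after":    bottleneck(mlp(x)) + x
    #   "parallel": mlp(x) + bottleneck(x)
    y = adapter(mlp, x)
    print(mode, y.shape)  # torch.Size([2, 197, 768]) in each case

# The up-projection is zero-initialized, so in "parallel" mode (the mode used
# inside Block) the adapter starts as an exact no-op on the wrapped module.
```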
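At the model level, `adapt_blocks` selects which transformer blocks receive an adapter, and `get_adapter()` / `put_adapter()` let a training loop detach and restore that set of adapters as a unit. The following sketch only illustrates this interface; it is not the repository's actual training recipe, and the frozen-backbone setup, learning rate, and placeholder batch/objective are assumptions.

```python
import torch

from vision_transformer import VisionTransformer  # this repo's ViT with adapter hooks

# ViT-B/16 with a parallel adapter in every block (indices 0..11).
model = VisionTransformer(num_classes=10, adapt_blocks=list(range(12)))

# Freeze the backbone; train only the adapters and the classification head.
model.requires_grad_(False)
for blk_idx in model.adapt_blocks:
    model.blocks[blk_idx].adapter.requires_grad_(True)
model.head.requires_grad_(True)

optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3)

x = torch.randn(4, 3, 224, 224)  # stand-in for one task batch
loss = model(x).mean()           # placeholder objective
loss.backward()
optimizer.step()

# Snapshot the adapters after a task, and restore them later (e.g. before
# re-evaluating that task); put_adapter() pairs them with their block indices.
task_adapters = model.get_adapter()
model.put_adapter(task_adapters)
```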