├── model
│   ├── __init__.py
│   ├── beit
│   │   ├── __init__.py
│   │   ├── beit_factory.py
│   │   ├── beit_registry.py
│   │   ├── beit.py
│   │   └── beit_custom.py
│   ├── reshape.py
│   ├── model_factory.py
│   ├── dpt
│   │   ├── dpt.py
│   │   ├── vit.py
│   │   └── dpt_blocks.py
│   ├── vtm.py
│   └── attention.py
├── train
│   ├── __init__.py
│   ├── miou_fss.py
│   ├── optim.py
│   ├── loss.py
│   ├── visualize.py
│   ├── train_utils.py
│   └── trainer.py
├── dataset
│   ├── __init__.py
│   ├── meta_info
│   │   ├── class_dict.pth
│   │   ├── edge_params.pth
│   │   ├── idxs_perm_all.pth
│   │   ├── depth_quantiles.pth
│   │   ├── edge_thresholds.pth
│   │   ├── idxs_perm_finetune.pth
│   │   ├── class_perm_finetune_10.pth
│   │   ├── class_perm_finetune_12.pth
│   │   ├── class_perm_finetune_13.pth
│   │   ├── class_perm_finetune_15.pth
│   │   ├── class_perm_finetune_16.pth
│   │   ├── class_perm_finetune_2.pth
│   │   ├── class_perm_finetune_3.pth
│   │   ├── class_perm_finetune_4.pth
│   │   ├── class_perm_finetune_5.pth
│   │   ├── class_perm_finetune_6.pth
│   │   ├── class_perm_finetune_8.pth
│   │   └── class_perm_finetune_9.pth
│   ├── taskonomy_constants.py
│   ├── resize_buildings.py
│   ├── augmentation.py
│   ├── utils.py
│   ├── dataloader_factory.py
│   └── taskonomy.py
├── VTM Overview.png
├── .gitignore
├── requirements.txt
├── configs
│   ├── test_config.yaml
│   ├── finetune_config.yaml
│   └── train_config.yaml
├── LICENSE
├── main.py
├── print_results.py
├── README.md
└── args.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/model/beit/__init__.py:
--------------------------------------------------------------------------------
1 | from .beit_custom import *
--------------------------------------------------------------------------------
/VTM Overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/VTM Overview.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__*
2 | experiments*
3 | model/pretrained_checkpoints*
4 | support_data.pth
5 | data_paths.yaml
6 | *ipynb*
--------------------------------------------------------------------------------
/dataset/meta_info/class_dict.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_dict.pth
--------------------------------------------------------------------------------
/dataset/meta_info/edge_params.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/edge_params.pth
--------------------------------------------------------------------------------
/dataset/meta_info/idxs_perm_all.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/idxs_perm_all.pth
--------------------------------------------------------------------------------
/dataset/meta_info/depth_quantiles.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/depth_quantiles.pth -------------------------------------------------------------------------------- /dataset/meta_info/edge_thresholds.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/edge_thresholds.pth -------------------------------------------------------------------------------- /dataset/meta_info/idxs_perm_finetune.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/idxs_perm_finetune.pth -------------------------------------------------------------------------------- /dataset/meta_info/class_perm_finetune_10.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_10.pth -------------------------------------------------------------------------------- /dataset/meta_info/class_perm_finetune_12.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_12.pth -------------------------------------------------------------------------------- /dataset/meta_info/class_perm_finetune_13.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_13.pth -------------------------------------------------------------------------------- /dataset/meta_info/class_perm_finetune_15.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_15.pth -------------------------------------------------------------------------------- /dataset/meta_info/class_perm_finetune_16.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_16.pth -------------------------------------------------------------------------------- /dataset/meta_info/class_perm_finetune_2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_2.pth -------------------------------------------------------------------------------- /dataset/meta_info/class_perm_finetune_3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_3.pth -------------------------------------------------------------------------------- /dataset/meta_info/class_perm_finetune_4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_4.pth -------------------------------------------------------------------------------- 
/dataset/meta_info/class_perm_finetune_5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_5.pth -------------------------------------------------------------------------------- /dataset/meta_info/class_perm_finetune_6.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_6.pth -------------------------------------------------------------------------------- /dataset/meta_info/class_perm_finetune_8.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_8.pth -------------------------------------------------------------------------------- /dataset/meta_info/class_perm_finetune_9.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/visual_token_matching/HEAD/dataset/meta_info/class_perm_finetune_9.pth -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.12.1 2 | torchvision==0.13.1 3 | pytorch_lightning 4 | timm==0.5.4 5 | numpy 6 | scikit-image 7 | tensorboard 8 | tqdm 9 | einops 10 | -------------------------------------------------------------------------------- /configs/test_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | 6 | # data arguments 7 | dataset: taskonomy 8 | test_split: muleshoe 9 | num_workers: 4 10 | shot: 10 11 | eval_batch_size: 8 12 | n_eval_batches: -1 13 | img_size: 224 14 | support_idx: 0 15 | channel_idx: -1 16 | 17 | # model arguments 18 | model: VTM 19 | semseg_threshold: 0.2 20 | 21 | # logging arguments 22 | log_dir: TEST 23 | save_dir: FINETUNE 24 | load_dir: TRAIN 25 | load_step: 0 -------------------------------------------------------------------------------- /configs/finetune_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | 6 | # data arguments 7 | dataset: taskonomy 8 | num_workers: 4 9 | global_batch_size: 1 10 | shot: 10 11 | eval_batch_size: 5 12 | n_eval_batches: 2 13 | img_size: 224 14 | support_idx: 0 15 | channel_idx: -1 16 | 17 | # model arguments 18 | model: VTM 19 | semseg_threshold: 0.2 20 | attn_dropout: 0.5 21 | 22 | # training arguments 23 | n_steps: 20000 24 | n_schedule_steps: 20000 25 | optimizer: adam 26 | lr: 0.005 27 | lr_schedule: constant 28 | lr_warmup: 0 29 | lr_warmup_scale: 0. 30 | schedule_from: 0 31 | weight_decay: 0. 32 | lr_decay_degree: 0.9 33 | mask_value: -1. 
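# (note) mask_value is the fill value written into masked-out pixels of the support labels
# before they are passed to the model (see compute_loss in train/loss.py)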
34 | early_stopping_patience: 5 35 | 36 | # logging arguments 37 | log_dir: FINETUNE 38 | save_dir: FINETUNE 39 | load_dir: TRAIN 40 | log_iter: 100 41 | val_iter: 100 42 | save_iter: 100 43 | load_step: 0 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Donggyun Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/train_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | 6 | # data arguments 7 | dataset: taskonomy 8 | task_fold: 0 9 | num_workers: 4 10 | global_batch_size: 8 11 | max_channels: 5 12 | shot: 4 13 | n_buildings: -1 14 | domains_per_batch: 2 15 | eval_batch_size: 8 16 | n_eval_batches: 10 17 | img_size: 224 18 | image_augmentation: True 19 | unary_augmentation: True 20 | binary_augmentation: True 21 | mixed_augmentation: True 22 | channel_idx: -1 23 | 24 | # model arguments 25 | model: VTM 26 | image_backbone: beit_base_patch16_224_in22k 27 | label_backbone: vit_base_patch16_224 28 | image_encoder_weights: imagenet 29 | drop_rate: 0. 30 | drop_path_rate: 0.1 31 | attn_drop_rate: 0. 32 | n_attn_heads: 4 33 | semseg_threshold: 0.2 34 | channel_idx: -1 35 | n_levels: 4 36 | bitfit: True 37 | 38 | # training arguments 39 | n_steps: 300000 40 | optimizer: adam 41 | lr: 0.0001 42 | lr_pretrained: 0.00001 43 | lr_schedule: poly 44 | lr_warmup: 5000 45 | lr_warmup_scale: 0. 46 | schedule_from: 0 47 | weight_decay: 0. 48 | lr_decay_degree: 0.9 49 | mask_value: -1. 
50 | early_stopping_patience: -1 51 | 52 | # logging arguments 53 | log_dir: TRAIN 54 | save_dir: TRAIN 55 | load_dir: TRAIN 56 | log_iter: 100 57 | val_iter: 20000 58 | save_iter: 20000 59 | load_step: -1 -------------------------------------------------------------------------------- /model/reshape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, repeat 3 | 4 | 5 | def get_reshaper(pattern): 6 | def reshaper(x, contiguous=False, **kwargs): 7 | if isinstance(x, torch.Tensor): 8 | x = rearrange(x, pattern, **kwargs) 9 | if contiguous: 10 | x = x.contiguous() 11 | return x 12 | elif isinstance(x, dict): 13 | return {key: reshaper(x[key], contiguous=contiguous, **kwargs) for key in x} 14 | elif isinstance(x, tuple): 15 | return tuple(reshaper(x_, contiguous=contiguous, **kwargs) for x_ in x) 16 | elif isinstance(x, list): 17 | return [reshaper(x_, contiguous=contiguous, **kwargs) for x_ in x] 18 | else: 19 | return x 20 | 21 | return reshaper 22 | 23 | 24 | from_6d_to_4d = get_reshaper('B T N C H W -> (B T N) C H W') 25 | from_4d_to_6d = get_reshaper('(B T N) C H W -> B T N C H W') 26 | 27 | from_6d_to_3d = get_reshaper('B T N C H W -> (B T) (N H W) C') 28 | from_3d_to_6d = get_reshaper('(B T) (N H W) C -> B T N C H W') 29 | 30 | 31 | def parse_BTN(x): 32 | if isinstance(x, torch.Tensor): 33 | B, T, N = x.size()[:3] 34 | elif isinstance(x, (tuple, list)): 35 | B, T, N = x[0].size()[:3] 36 | elif isinstance(x, dict): 37 | B, T, N = x[list(x.keys())[0]].size()[:3] 38 | else: 39 | raise ValueError(f'unsupported type: {type(x)}') 40 | 41 | return B, T, N 42 | 43 | 44 | def forward_6d_as_4d(func, x, t_idx=None, **kwargs): 45 | B, T, N = parse_BTN(x) 46 | 47 | x = from_6d_to_4d(x, contiguous=True) 48 | 49 | if t_idx is not None: 50 | t_idx = repeat(t_idx, 'B T -> (B T N)', N=N) 51 | x = func(x, t_idx=t_idx, **kwargs) 52 | else: 53 | x = func(x, **kwargs) 54 | 55 | x = from_4d_to_6d(x, B=B, T=T) 56 | 57 | return x 58 | -------------------------------------------------------------------------------- /model/model_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dataset.taskonomy_constants import TASKS, TASKS_GROUP_DICT 4 | 5 | from .beit.beit import BEiTEncoder, load_beit_ckpt 6 | from .dpt.dpt import DPT 7 | from .vtm import VTM, VTMImageBackbone, VTMLabelBackbone, VTMMatchingModule 8 | 9 | 10 | def get_model(config, device=None, verbose=True, load_pretrained=True): 11 | image_backbone = create_image_backbone(config, verbose=verbose, load_pretrained=load_pretrained) 12 | label_backbone = create_label_backbone(config) 13 | 14 | dim_w = image_backbone.dim_hidden 15 | dim_z = label_backbone.dim_hidden 16 | 17 | image_backbone = VTMImageBackbone(image_backbone) 18 | label_backbone = VTMLabelBackbone(label_backbone) 19 | matching_module = VTMMatchingModule(dim_w, dim_z, config) 20 | 21 | model = VTM(image_backbone, label_backbone, matching_module) 22 | 23 | if device is not None: 24 | model = model.to(device) 25 | 26 | return model 27 | 28 | 29 | def create_image_backbone(config, verbose=True, load_pretrained=True): 30 | if config.stage == 0: 31 | n_tasks = len(TASKS) 32 | else: 33 | if config.task == 'segment_semantic': 34 | n_tasks = 1 35 | else: 36 | n_tasks = len(TASKS_GROUP_DICT[config.task]) 37 | 38 | backbone = BEiTEncoder( 39 | config.image_backbone, 40 | drop_rate=config.drop_rate, 41 | drop_path_rate=config.drop_path_rate, 42 | 
attn_drop_rate=config.attn_drop_rate, 43 | n_tasks=n_tasks, 44 | n_levels=config.n_levels, 45 | bitfit=config.bitfit, 46 | ) 47 | backbone.dim_hidden = backbone.embed_dim 48 | 49 | if load_pretrained and config.image_encoder_weights == 'imagenet': 50 | ckpt_path = os.path.join('model/pretrained_checkpoints', 51 | f'{config.image_backbone.replace("in22k", "pt22k")}.pth') 52 | 53 | if getattr(config, 'bitfit', False): 54 | n_bitfit_tasks = n_tasks 55 | else: 56 | n_bitfit_tasks = 0 57 | 58 | load_beit_ckpt(backbone.beit, ckpt_path, n_bitfit_tasks=n_bitfit_tasks, verbose=verbose) 59 | 60 | return backbone 61 | 62 | 63 | def create_label_backbone(config): 64 | backbone = DPT(config.label_backbone, pretrained=False) 65 | backbone.dim_hidden = backbone.embed_dim 66 | 67 | return backbone 68 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytorch_lightning as pl 3 | import torch 4 | import warnings 5 | 6 | from train.train_utils import configure_experiment, load_model 7 | from dataset.taskonomy_constants import SEMSEG_CLASSES 8 | 9 | 10 | def run(config): 11 | # set monitor name and postfix 12 | if config.stage == 0: 13 | config.monitor = 'summary/mtrain_valid_pred' 14 | else: 15 | if config.task == 'segment_semantic': 16 | config.monitor = f'mtest_support/segment_semantic_{config.channel_idx}_pred' 17 | if config.save_postfix == '': 18 | config.save_postfix = f'_task:segment_semantic_{config.channel_idx}' 19 | else: 20 | config.monitor = f'mtest_support/{config.task}_pred' 21 | if config.save_postfix == '': 22 | config.save_postfix = f'_task:{config.task}{config.save_postfix}' 23 | 24 | # load model 25 | model, ckpt_path = load_model(config, verbose=IS_RANK_ZERO) 26 | 27 | # environmental settings 28 | logger, log_dir, callbacks, precision, strategy, plugins = configure_experiment(config, model) 29 | if config.stage == 2: 30 | model.config.result_dir = log_dir 31 | 32 | # create pytorch lightning trainer. 
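# (note) max_epochs is n_steps // val_iter when training or fine-tuning with evaluation enabled
# (stage <= 1 and not no_eval), so PyTorch Lightning runs validation after each chunk of training
# steps; evaluation-only and no-eval runs use a single epoch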
33 | trainer = pl.Trainer( 34 | logger=logger, 35 | default_root_dir=log_dir, 36 | accelerator='gpu', 37 | max_epochs=((config.n_steps // config.val_iter) if (not config.no_eval) and config.stage <= 1 else 1), 38 | log_every_n_steps=-1, 39 | num_sanity_val_steps=0, 40 | callbacks=callbacks, 41 | benchmark=True, 42 | devices=-1, 43 | strategy=strategy, 44 | precision=precision, 45 | plugins=plugins, 46 | ) 47 | 48 | # validation at start 49 | if config.stage == 1: 50 | trainer.validate(model, verbose=False) 51 | # start training or fine-tuning 52 | if config.stage <= 1: 53 | trainer.fit(model, ckpt_path=ckpt_path) 54 | # start evaluation 55 | else: 56 | trainer.test(model) 57 | 58 | 59 | if __name__ == "__main__": 60 | torch.multiprocessing.freeze_support() 61 | torch.set_num_threads(1) 62 | warnings.filterwarnings("ignore", category=UserWarning) 63 | warnings.filterwarnings("ignore", category=pl.utilities.warnings.PossibleUserWarning) 64 | IS_RANK_ZERO = int(os.environ.get('LOCAL_RANK', 0)) == 0 65 | 66 | from args import config # parse arguments 67 | 68 | if config.stage >= 1 and config.task == 'segment_semantic' and config.channel_idx < 0: 69 | save_postfix = config.save_postfix 70 | for channel_idx in SEMSEG_CLASSES: 71 | config.save_postfix = save_postfix 72 | config.channel_idx = channel_idx 73 | run(config) 74 | else: 75 | run(config) 76 | 77 | -------------------------------------------------------------------------------- /dataset/taskonomy_constants.py: -------------------------------------------------------------------------------- 1 | # Building Splits 2 | BUILDINGS_TRAIN = ['allensville', 'beechwood', 'benevolence', 'coffeen', 'cosmos', 3 | 'forkland', 'hanson', 'hiteman', 'klickitat', 'lakeville', 4 | 'leonardo', 'lindenwood', 'marstons', 'merom', 'mifflinburg', 5 | 'newfields', 'onaga', 'pinesdale', 'pomaria', 'ranchester', 6 | 'shelbyville', 'stockman', 'tolstoy', 'wainscott', 'woodbine'] 7 | BUILDINGS_VALID = ['collierville', 'corozal', 'darden', 'markleeville', 'wiconisco'] 8 | BUILDINGS_TEST = ['ihlen', 'mcdade', 'muleshoe', 'noxapater', 'uvalda'] 9 | BUILDINGS = BUILDINGS_TRAIN + BUILDINGS_VALID + BUILDINGS_TEST 10 | 11 | # Class Splits for Semantic Segmentation 12 | SEMSEG_CLASSES = [2, 3, 4, 5, 6, 8, 9, 10, 12, 13, 15, 16] 13 | SEMSEG_CLASS_RANGE = range(1, 17) 14 | 15 | # Task Type Grouping 16 | TASKS_SEMSEG = [f'segment_semantic_{c}' for c in SEMSEG_CLASSES] 17 | TASKS_DEPTHE = [f'depth_euclidean_{c}' for c in range(5)] 18 | TASKS_DEPTHZ = ['depth_zbuffer_0'] 19 | TASKS_EDGE2D = [f'edge_texture_{c}' for c in range(3)] 20 | TASKS_EDGE3D = [f'edge_occlusion_{c}' for c in range(5)] 21 | TASKS_KEYPOINTS2D = ['keypoints2d_0'] 22 | TASKS_KEYPOINTS3D = ['keypoints3d_0'] 23 | TASKS_NORMAL = [f'normal_{c}' for c in range(3)] 24 | TASKS_RESHADING = ['reshading_0'] 25 | TASKS_CURVATURE = [f'principal_curvature_{c}' for c in range(2)] 26 | 27 | # All Tasks 28 | TASKS = TASKS_SEMSEG + TASKS_DEPTHE + TASKS_DEPTHZ + \ 29 | TASKS_EDGE2D + TASKS_EDGE3D + TASKS_KEYPOINTS2D + TASKS_KEYPOINTS3D + \ 30 | TASKS_NORMAL + TASKS_RESHADING + TASKS_CURVATURE 31 | 32 | # Train and Test Tasks - can be splitted in other ways 33 | N_SPLITS = 5 34 | TASKS_GROUP_NAMES = ["segment_semantic", "normal", "depth_euclidean", "depth_zbuffer", "edge_texture", "edge_occlusion", "keypoints2d", "keypoints3d", "reshading", "principal_curvature"] 35 | TASKS_GROUP_LIST = [TASKS_SEMSEG, TASKS_NORMAL, TASKS_DEPTHE, TASKS_DEPTHZ, TASKS_EDGE2D, TASKS_EDGE3D, TASKS_KEYPOINTS2D, TASKS_KEYPOINTS3D, TASKS_RESHADING, 
TASKS_CURVATURE] 36 | TASKS_GROUP_DICT = {name: group for name, group in zip(TASKS_GROUP_NAMES, TASKS_GROUP_LIST)} 37 | 38 | N_TASK_GROUPS = len(TASKS_GROUP_NAMES) 39 | GROUP_UNIT = N_TASK_GROUPS // N_SPLITS 40 | 41 | TASKS_GROUP_TRAIN = {} 42 | TASKS_GROUP_TEST = {} 43 | for split_idx in range(N_SPLITS): 44 | TASKS_GROUP_TRAIN[split_idx] = TASKS_GROUP_NAMES[:-GROUP_UNIT*(split_idx+1)] + (TASKS_GROUP_NAMES[-GROUP_UNIT*split_idx:] if split_idx > 0 else []) 45 | TASKS_GROUP_TEST[split_idx] = TASKS_GROUP_NAMES[-GROUP_UNIT*(split_idx+1):-GROUP_UNIT*split_idx] if split_idx > 0 else TASKS_GROUP_NAMES[-GROUP_UNIT*(split_idx+1):] 46 | 47 | N_TASKS = len(TASKS) 48 | 49 | 50 | CLASS_NAMES = ['bottle', 'chair', 'couch', 'plant', 51 | 'bed', 'd.table', 'toilet', 'tv', 'microw', 52 | 'oven', 'toaster', 'sink', 'fridge', 'book', 53 | 'clock', 'vase'] -------------------------------------------------------------------------------- /train/miou_fss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataset.taskonomy_constants import SEMSEG_CLASSES 3 | 4 | 5 | class AverageMeter: 6 | r""" Stores loss, evaluation results """ 7 | def __init__(self, class_ids_interest, device=None): 8 | if device is None: 9 | device = torch.device('cpu') 10 | if isinstance(class_ids_interest, int): 11 | self.class_ids_interest = torch.tensor([class_ids_interest], device=device) 12 | else: 13 | self.class_ids_interest = torch.tensor(class_ids_interest, device=device) 14 | 15 | self.nclass = len(SEMSEG_CLASSES) 16 | 17 | self.intersection_buf = torch.zeros([2, self.nclass], device=device).float() 18 | self.union_buf = torch.zeros([2, self.nclass], device=device).float() 19 | self.ones = torch.ones_like(self.union_buf) 20 | self.loss_buf = [] 21 | 22 | def update(self, inter_b, union_b, class_id): 23 | self.intersection_buf.index_add_(1, class_id, inter_b.float()) 24 | self.union_buf.index_add_(1, class_id, union_b.float()) 25 | 26 | def class_iou(self, class_id): 27 | iou = self.intersection_buf.float() / \ 28 | torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0] 29 | iou = iou.index_select(1, torch.tensor([class_id], device=iou.device)) 30 | miou = iou[1].mean() 31 | 32 | return miou 33 | 34 | def compute_iou(self): 35 | iou = self.intersection_buf.float() / \ 36 | torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0] 37 | iou = iou.index_select(1, self.class_ids_interest) 38 | miou = iou[1].mean() 39 | 40 | fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) / 41 | self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() 42 | 43 | return miou, fb_iou 44 | 45 | 46 | class Evaluator: 47 | r""" Computes intersection and union between prediction and ground-truth """ 48 | @classmethod 49 | def initialize(cls): 50 | pass 51 | 52 | @classmethod 53 | def classify_prediction(cls, pred_mask, gt_mask): 54 | # compute intersection and union of each episode in a batch 55 | area_inter, area_pred, area_gt = [], [], [] 56 | for _pred_mask, _gt_mask in zip(pred_mask, gt_mask): 57 | _inter = _pred_mask[_pred_mask == _gt_mask] 58 | if _inter.size(0) == 0: # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1) 59 | _area_inter = torch.tensor([0, 0], device=_pred_mask.device) 60 | else: 61 | _area_inter = torch.histc(_inter, bins=2, min=0, max=1) 62 | area_inter.append(_area_inter) 63 | area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1)) 64 | area_gt.append(torch.histc(_gt_mask, bins=2, min=0, 
max=1)) 65 | area_inter = torch.stack(area_inter).t() 66 | area_pred = torch.stack(area_pred).t() 67 | area_gt = torch.stack(area_gt).t() 68 | area_union = area_pred + area_gt - area_inter 69 | 70 | return area_inter, area_union 71 | -------------------------------------------------------------------------------- /dataset/resize_buildings.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import os 3 | import torch 4 | import tqdm 5 | import yaml 6 | from PIL import Image 7 | 8 | building_list = [ 9 | 'allensville', 10 | 'beechwood', 11 | 'benevolence', 12 | 'collierville', 13 | 'coffeen', 14 | 'corozal', 15 | 'cosmos', 16 | 'darden', 17 | 'forkland', 18 | 'hanson', 19 | 'hiteman', 20 | 'ihlen', 21 | 'klickitat', 22 | 'lakeville', 23 | 'leonardo', 24 | 'lindenwood', 25 | 'markleeville', 26 | 'marstons', 27 | 'mcdade', 28 | 'merom', 29 | 'mifflinburg', 30 | 'muleshoe', 31 | 'newfields', 32 | 'noxapater', 33 | 'onaga', 34 | 'pinesdale', 35 | 'pomaria', 36 | 'ranchester', 37 | 'shelbyville', 38 | 'stockman', 39 | 'tolstoy', 40 | 'uvalda', 41 | 'wainscott', 42 | 'wiconisco', 43 | 'woodbine', 44 | ] 45 | 46 | task_list = [ 47 | 'rgb', 48 | 'normal', 49 | 'depth_euclidean', 50 | 'depth_zbuffer', 51 | 'edge_occlusion', 52 | 'keypoints2d', 53 | 'keypoints3d', 54 | 'reshading', 55 | 'principal_curvature', 56 | 'segment_semantic' 57 | ] 58 | 59 | 60 | def resize(args): 61 | load_path, save_path, mode = args 62 | try: 63 | img = Image.open(load_path) 64 | img = img.resize(size, mode) 65 | img.save(save_path) 66 | return None 67 | except Exception as e: 68 | print(e) 69 | return load_path 70 | 71 | 72 | if __name__ == "__main__": 73 | verbose = True 74 | size = (256, 256) 75 | split = "tiny" 76 | n_threads = 20 77 | 78 | with open('data_paths.yaml', 'r') as f: 79 | path_dict = yaml.safe_load(f) 80 | load_root = save_root = path_dict['taskonomy'] 81 | 82 | load_dir = os.path.join(load_root, split) 83 | assert os.path.isdir(load_dir) 84 | ''' 85 | load_dir 86 | |--building 87 | |--task 88 | |--file 89 | ''' 90 | save_dir = os.path.join(save_root, f"{split}_{size[0]}_merged") 91 | os.makedirs(save_dir, exist_ok=True) 92 | ''' 93 | save_dir 94 | |--task 95 | |--file 96 | ''' 97 | 98 | args = [] 99 | print("creating args...") 100 | for b_idx, building in enumerate(building_list): 101 | assert os.path.isdir(os.path.join(load_dir, building)) 102 | for task in task_list: 103 | mode = Image.NEAREST if task == "segment_semantic" else Image.BILINEAR 104 | if b_idx == 0: 105 | os.makedirs(os.path.join(save_dir, task), exist_ok=True) 106 | 107 | load_names = os.listdir(os.path.join(load_dir, building, task)) 108 | load_paths = [os.path.join(load_dir, building, task, load_name) for load_name in load_names] 109 | save_paths = [os.path.join(save_dir, task, f'{building}_{load_name}') for load_name in load_names] 110 | modes = [mode]*len(load_names) 111 | args += list(zip(load_paths, save_paths, modes)) 112 | 113 | fail_list = [] 114 | pool = Pool(n_threads) 115 | total = len(args) 116 | pbar = tqdm.tqdm(total=total, bar_format="{desc:<5}{percentage:3.0f}%|{bar:10}{r_bar}") 117 | for fail_path in pool.imap(resize, args): 118 | if fail_path is not None: 119 | fail_list += [fail_path] 120 | pbar.update() 121 | pbar.close() 122 | 123 | torch.save(fail_list, "fail_list.pth") 124 | 125 | pool.close() 126 | pool.join() 127 | -------------------------------------------------------------------------------- /model/beit/beit_factory.py: 
-------------------------------------------------------------------------------- 1 | from .beit_registry import is_model, model_entrypoint 2 | from timm.models.helpers import load_checkpoint 3 | from timm.models.layers import set_layer_config 4 | from timm.models.hub import load_model_config_from_hf 5 | 6 | 7 | def split_model_name(model_name): 8 | model_split = model_name.split(':', 1) 9 | if len(model_split) == 1: 10 | return '', model_split[0] 11 | else: 12 | source_name, model_name = model_split 13 | assert source_name in ('timm', 'hf_hub') 14 | return source_name, model_name 15 | 16 | 17 | def safe_model_name(model_name, remove_source=True): 18 | def make_safe(name): 19 | return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') 20 | if remove_source: 21 | model_name = split_model_name(model_name)[-1] 22 | return make_safe(model_name) 23 | 24 | 25 | def create_model( 26 | model_name, 27 | pretrained=False, 28 | checkpoint_path='', 29 | scriptable=None, 30 | exportable=None, 31 | no_jit=None, 32 | **kwargs): 33 | """Create a model 34 | Args: 35 | model_name (str): name of model to instantiate 36 | pretrained (bool): load pretrained ImageNet-1k weights if true 37 | checkpoint_path (str): path of checkpoint to load after model is initialized 38 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 39 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 40 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) 41 | Keyword Args: 42 | drop_rate (float): dropout rate for training (default: 0.0) 43 | global_pool (str): global pool type (default: 'avg') 44 | **: other kwargs are model specific 45 | """ 46 | source_name, model_name = split_model_name(model_name) 47 | 48 | # handle backwards compat with drop_connect -> drop_path change 49 | drop_connect_rate = kwargs.pop('drop_connect_rate', None) 50 | if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: 51 | print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'." 52 | " Setting drop_path to %f." % drop_connect_rate) 53 | kwargs['drop_path_rate'] = drop_connect_rate 54 | 55 | # Parameters that aren't supported by all models or are intended to only override model defaults if set 56 | # should default to None in command line args/cfg. Remove them if they are present and not set so that 57 | # non-supporting models don't break and default args remain in effect. 58 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 59 | 60 | if source_name == 'hf_hub': 61 | # For model names specified in the form `hf_hub:path/architecture_name#revision`, 62 | # load model weights + default_cfg from Hugging Face hub. 
63 | hf_default_cfg, model_name = load_model_config_from_hf(model_name) 64 | kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday 65 | 66 | if is_model(model_name): 67 | create_fn = model_entrypoint(model_name) 68 | else: 69 | raise RuntimeError('Unknown model (%s)' % model_name) 70 | 71 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 72 | model = create_fn(pretrained=pretrained, **kwargs) 73 | 74 | if checkpoint_path: 75 | load_checkpoint(model, checkpoint_path) 76 | 77 | return model -------------------------------------------------------------------------------- /print_results.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import math 4 | import os 5 | from dataset.taskonomy_constants import TASKS_GROUP_TEST, SEMSEG_CLASSES 6 | from train.miou_fss import AverageMeter 7 | import argparse 8 | 9 | 10 | def create_table(model, tasks, ptf, print_failure=False): 11 | result_root = os.path.join('experiments', args.result_dir) 12 | 13 | df = pd.DataFrame(index=[model], columns=[task_tags[task] for task in tasks]) 14 | for task in tasks: 15 | task_tag = task_tags[task] 16 | exp_name = f'{model}_fold:{fold_dict[task]}{ptf}' 17 | exp_dir = os.path.join(result_root, exp_name) 18 | if not os.path.exists(exp_dir): 19 | continue 20 | if task == 'segment_semantic': 21 | success = True 22 | average_meter = AverageMeter(range(len(SEMSEG_CLASSES))) 23 | for i, c in enumerate(SEMSEG_CLASSES): 24 | result_name = f'result_task:{task}_{c}_split:{args.test_split}.pth' 25 | result_path = os.path.join(result_root, exp_name, 'logs', result_name) 26 | try: 27 | average_meter_c = torch.load(result_path, map_location='cpu') 28 | assert isinstance(average_meter_c, AverageMeter) 29 | except: 30 | success = False 31 | break 32 | 33 | average_meter.intersection_buf[:, i] += average_meter_c.intersection_buf[:, 0].cpu() 34 | average_meter.union_buf[:, i] += average_meter_c.union_buf[:, 0].cpu() 35 | 36 | if success: 37 | df.loc[model][task_tag] = average_meter.compute_iou()[0].cpu().item() 38 | elif print_failure: 39 | print(result_path) 40 | else: 41 | result_name = f'result_task:{task}_split:{args.test_split}.pth' 42 | result_path = os.path.join(result_root, exp_name, 'logs', result_name) 43 | if os.path.exists(result_path): 44 | result = torch.load(result_path) 45 | df.loc[model][task_tag] = result 46 | elif print_failure: 47 | print(result_path) 48 | 49 | return df 50 | 51 | 52 | if __name__ == '__main__': 53 | from dataset.taskonomy_constants import TASKS_GROUP_NAMES 54 | 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--result_dir', type=str, default='TEST') 57 | parser.add_argument('--test_split', type=str, default='muleshoe') 58 | parser.add_argument('--model', type=str, default='VTM') 59 | parser.add_argument('--name_postfix', '-ptf', type=str, default='') 60 | parser.add_argument('--task', type=str, default='all', choices=['all'] + TASKS_GROUP_NAMES) 61 | args = parser.parse_args() 62 | 63 | task_tags = { 64 | 'segment_semantic': 'Semantic Segmentation (mIoU ↑)', 65 | 'normal': 'Surface Normal (mErr ↓)', 66 | 'depth_euclidean': 'Euclidean Distance (RMSE ↓)', 67 | 'depth_zbuffer': 'Zbuffer Depth (RMSE ↓)', 68 | 'edge_texture': 'Texture Edge (RMSE ↓)', 69 | 'edge_occlusion': 'Occlusion Edge (RMSE ↓)', 70 | 'keypoints2d': '2D Keypoints (RMSE ↓)', 71 | 'keypoints3d': '3D Keypoints (RMSE ↓)', 72 | 'reshading': 'Reshading (RMSE ↓)', 73 | 
'principal_curvature': 'Principal Curvature (RMSE ↓)', 74 | } 75 | fold_dict = {} 76 | for fold in TASKS_GROUP_TEST: 77 | for task in TASKS_GROUP_TEST[fold]: 78 | fold_dict[task] = fold 79 | 80 | if args.task == 'all': 81 | tasks = TASKS_GROUP_NAMES 82 | else: 83 | tasks = [args.task] 84 | 85 | pd.set_option('max_columns', None) 86 | df = create_table(args.model, tasks, args.name_postfix, print_failure=False) 87 | print(df) 88 | -------------------------------------------------------------------------------- /train/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | optim_dict = { 6 | 'sgd': torch.optim.SGD, 7 | 'adam': torch.optim.Adam, 8 | 'adamw': torch.optim.AdamW, 9 | } 10 | 11 | 12 | def get_optimizer(config, model): 13 | learnable_params = [] 14 | 15 | # train all parameters for episodic training 16 | if config.stage == 0: 17 | learnable_params.append({'params': model.pretrained_parameters(), 'lr': config.lr_pretrained}) 18 | learnable_params.append({'params': model.scratch_parameters(), 'lr': config.lr}) 19 | 20 | # train only task-specific parameters for fine-tuning 21 | elif config.stage == 1: 22 | learnable_params.append({'params': model.bias_parameters(), 'lr': config.lr}) 23 | 24 | kwargs = {} 25 | if config.optimizer == 'sgd': 26 | kwargs['momentum'] = 0.9 27 | optimizer = optim_dict[config.optimizer](learnable_params, weight_decay=config.weight_decay, **kwargs) 28 | if config.lr_warmup >= 0: 29 | lr_warmup = config.lr_warmup 30 | else: 31 | assert config.lr_warmup_scale >= 0. and config.lr_warmup_scale <= 1. 32 | lr_warmup = int(config.lr_warmup_scale * config.n_steps) 33 | lr_scheduler = CustomLRScheduler(optimizer, config.lr_schedule, config.lr, config.n_steps, lr_warmup, 34 | decay_degree=config.lr_decay_degree) 35 | 36 | return optimizer, lr_scheduler 37 | 38 | 39 | class CustomLRScheduler(object): 40 | ''' 41 | Custom learning rate scheduler for pytorch optimizer. 42 | Assumes 1 <= self.iter <= 1 + num_iters. 
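    Example (illustrative sketch; mirrors how get_optimizer() builds the scheduler from train_config.yaml):
        scheduler = CustomLRScheduler(optimizer, 'poly', base_lr=1e-4, num_iters=300000,
                                      warmup_iters=5000, decay_degree=0.9)
        for step in range(1, 300001):
            ...                   # forward / backward / optimizer.step()
            scheduler.step()      # linear warmup for the first 5000 steps, then polynomial decay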
43 | ''' 44 | 45 | def __init__(self, optimizer, mode, base_lr, num_iters, warmup_iters=1000, 46 | from_iter=0, decay_degree=0.9, decay_steps=5000): 47 | self.optimizer = optimizer 48 | self.mode = mode 49 | self.base_lr = base_lr 50 | self.lr = base_lr 51 | self.iter = from_iter 52 | self.N = num_iters + 1 53 | self.warmup_iters = warmup_iters 54 | self.decay_degree = decay_degree 55 | self.decay_steps = decay_steps 56 | 57 | self.lr_coefs = [] 58 | for param_group in optimizer.param_groups: 59 | self.lr_coefs.append(param_group['lr'] / base_lr) 60 | 61 | def step(self, step=-1): 62 | # updatae current step 63 | if step >= 0: 64 | self.iter = step 65 | else: 66 | self.iter += 1 67 | 68 | # schedule lr 69 | if self.mode == 'cos': 70 | self.lr = 0.5 * self.base_lr * (1 + math.cos(1.0 * self.iter / self.N * math.pi)) 71 | elif self.mode == 'poly': 72 | if self.iter < self.N: 73 | self.lr = self.base_lr * pow((1 - 1.0 * self.iter / self.N), self.decay_degree) 74 | elif self.mode == 'step': 75 | self.lr = self.base_lr * (0.1**(self.decay_steps // self.iter)) 76 | elif self.mode == 'constant': 77 | self.lr = self.base_lr 78 | elif self.mode == 'sqroot': 79 | self.lr = self.base_lr * self.warmup_iters**0.5 * min(self.iter * self.warmup_iters**-1.5, self.iter**-0.5) 80 | else: 81 | raise NotImplementedError 82 | 83 | # warm up lr schedule 84 | if self.warmup_iters > 0 and self.iter < self.warmup_iters and self.mode != 'sqroot': 85 | self.lr = self.base_lr * 1.0 * self.iter / self.warmup_iters 86 | assert self.lr >= 0 87 | 88 | # adjust lr 89 | self._adjust_learning_rate(self.optimizer, self.lr) 90 | 91 | def _adjust_learning_rate(self, optimizer, lr): 92 | for i in range(len(optimizer.param_groups)): 93 | optimizer.param_groups[i]['lr'] = lr * self.lr_coefs[i] 94 | 95 | def reset(self): 96 | self.lr = self.base_lr 97 | self.iter = 0 98 | self._adjust_learning_rate(self.optimizer, self.lr) 99 | -------------------------------------------------------------------------------- /model/dpt/dpt.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from einops import rearrange 3 | from .dpt_blocks import _make_fusion_block, Interpolate, _make_encoder 4 | 5 | 6 | backbone_dict = { 7 | 'vit_base_patch16_224': 'vitb16_224', 8 | 'vit_base_patch16_384': 'vitb16_384', 9 | 'vit_large_patch16_384': 'vitl16_384', 10 | } 11 | 12 | 13 | class DPT(nn.Module): 14 | def __init__(self, 15 | model_name='vit_base_patch16_224', 16 | features=256, 17 | use_bn=False, 18 | pretrained=True, 19 | in_chans=1, 20 | out_chans=1 21 | ): 22 | super().__init__() 23 | # Instantiate backbone and reassemble blocks 24 | self.pretrained, self.scratch = _make_encoder( 25 | backbone_dict[model_name], 26 | features, 27 | pretrained, # Set to true of you want to train from scratch, uses ImageNet weights 28 | groups=1, 29 | expand=False, 30 | in_chans=in_chans, 31 | ) 32 | 33 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 34 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 35 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 36 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 37 | 38 | self.scratch.output_conv = nn.Sequential( 39 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 40 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 41 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 42 | nn.ReLU(True), 43 | nn.Conv2d(32, out_chans, kernel_size=1, stride=1, 
padding=0), 44 | ) 45 | 46 | size = [int(model_name.split('_')[-1])]*2 47 | self.embed_dim = self.pretrained.model.embed_dim 48 | self.patch_size = (size[0] // 16, size[1] // 16) 49 | self.n_levels = 4 50 | self.feature_blocks = [level * (len(self.pretrained.model.blocks) // self.n_levels) - 1 for level in range(1, self.n_levels+1)] 51 | 52 | def pretrained_parameters(self): 53 | return self.pretrained.parameters() 54 | 55 | def scratch_parameters(self): 56 | return self.scratch.parameters() 57 | 58 | def encode(self, x, t_idx=None): 59 | glob = self.pretrained.model.forward_flex(x, t_idx=t_idx) 60 | 61 | layer_1 = rearrange(self.pretrained.activations["1"][:, 1:], 'B (H W) C -> B C H W', H=self.patch_size[0]) 62 | layer_2 = rearrange(self.pretrained.activations["2"][:, 1:], 'B (H W) C -> B C H W', H=self.patch_size[0]) 63 | layer_3 = rearrange(self.pretrained.activations["3"][:, 1:], 'B (H W) C -> B C H W', H=self.patch_size[0]) 64 | layer_4 = rearrange(self.pretrained.activations["4"][:, 1:], 'B (H W) C -> B C H W', H=self.patch_size[0]) 65 | 66 | return layer_1, layer_2, layer_3, layer_4 67 | 68 | def decode(self, features, t_idx=None): 69 | layer_1, layer_2, layer_3, layer_4 = features 70 | 71 | layer_1 = self.pretrained.act_postprocess1(layer_1) 72 | layer_2 = self.pretrained.act_postprocess2(layer_2) 73 | layer_3 = self.pretrained.act_postprocess3(layer_3) 74 | layer_4 = self.pretrained.act_postprocess4(layer_4) 75 | 76 | layer_1_rn = self.scratch.layer1_rn(layer_1) 77 | layer_2_rn = self.scratch.layer2_rn(layer_2) 78 | layer_3_rn = self.scratch.layer3_rn(layer_3) 79 | layer_4_rn = self.scratch.layer4_rn(layer_4) 80 | 81 | path_4 = self.scratch.refinenet4(layer_4_rn) 82 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 83 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 84 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 85 | 86 | out = self.scratch.output_conv(path_1) 87 | 88 | return out 89 | 90 | def forward(self, x, t_idx=None): 91 | x = self.encode(x, t_idx=t_idx) 92 | x = self.decode(x, t_idx=t_idx) 93 | 94 | return x -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visual Token Matching 2 | 3 | (Update 08/23, 2023) 4 | We have uploaded pretrained checkpoints in [this link](https://drive.google.com/drive/folders/1FY_zM_x_isBP_I80gJWI9i5Sl6-IR1CC?usp=sharing). 5 | Meta-trained checkpoints for each fold are included in `TRAIN` directory, and fine-tuned checkpoints for each task (and each channel) are included in `FINETUNE` directory. 6 | 7 | **(News) Our paper received the [Outstanding Paper Award](https://blog.iclr.cc/2023/03/21/announcing-the-iclr-2023-outstanding-paper-award-recipients/) in ICLR 2023!** 8 | 9 | This repository contains official code for [Universal Few-shot Learning of Dense Prediction Tasks with Visual Token Matching](https://openreview.net/forum?id=88nT0j5jAn) (ICLR 2023 oral). 10 | 11 | ![image-VTM](https://github.com/GitGyun/visual_token_matching/blob/5c1ddd730dac9e82601e5032c973a9ee0c5bdf4b/VTM%20Overview.png) 12 | 13 | ## Setup 14 | 1. Download Taskonomy Dataset (tiny split) from the official github page https://github.com/StanfordVL/taskonomy/tree/master/data. 15 | * You may download data of `depth_euclidean`, `depth_zbuffer`, `edge_occlusion`, `keypoints2d`, `keypoints3d`, `normal`, `principal_curvature`, `reshading`, `segment_semantic`, and `rgb`. 
16 | * (Optional) Resize the images and labels into (256, 256) resolution.
17 | * To reduce the I/O bottleneck of the dataloader, we stored data from all buildings in a single directory. The directory structure looks like:
18 | ```
19 | <root_dir>
20 | |--<task_1>
21 | |  |--<building_1>_<file_name_1>
22 | |  | ...
23 | |  |--<building_N>_<file_name_M>
24 | |  |...
25 | |
26 | |--<task_2>
27 | |  |--<building_1>_<file_name_1>
28 | |  | ...
29 | |  |--<building_N>_<file_name_M>
30 | |  |...
31 | |
32 | |...
33 | ```
34 | 
35 | 2. Create a `data_paths.yaml` file and write the root directory path (`<root_dir>` in the above structure) as `taskonomy: PATH_TO_YOUR_TASKONOMY`.
36 | 
37 | 3. Install the prerequisites by `pip install -r requirements.txt`.
38 | 
39 | 4. Create a `model/pretrained_checkpoints` directory and download [BEiT pre-trained checkpoints](https://github.com/microsoft/unilm/tree/master/beit) to the directory.
40 | * We used the `beit_base_patch16_224_pt22k` checkpoint for our experiments.
41 | 
42 | ## Usage
43 | 
44 | ### Training
45 | ```
46 | python main.py --stage 0 --task_fold [0/1/2/3/4]
47 | ```
48 | 
49 | ### Fine-tuning
50 | 
51 | ```
52 | python main.py --stage 1 --task [segment_semantic/normal/depth_euclidean/depth_zbuffer/edge_texture/edge_occlusion/keypoints2d/keypoints3d/reshading/principal_curvature]
53 | ```
54 | 
55 | ### Evaluation
56 | 
57 | ```
58 | python main.py --stage 2 --task [segment_semantic/normal/depth_euclidean/depth_zbuffer/edge_texture/edge_occlusion/keypoints2d/keypoints3d/reshading/principal_curvature]
59 | ```
60 | After the evaluation, you can print the test results by running `python print_results.py`.
61 | 
62 | ## References
63 | Our code refers to the following repositories:
64 | * [Taskonomy](https://github.com/StanfordVL/taskonomy)
65 | * [timm](https://github.com/huggingface/pytorch-image-models/tree/0.5.x)
66 | * [BEiT: BERT Pre-Training of Image Transformers](https://github.com/microsoft/unilm/tree/master/beit)
67 | * [Vision Transformers for Dense Prediction](https://github.com/isl-org/DPT)
68 | * [Inverted Pyramid Multi-task Transformer for Dense Scene Understanding](https://github.com/prismformore/Multi-Task-Transformer/tree/main/InvPT)
69 | * [Hypercorrelation Squeeze for Few-Shot Segmentation](https://github.com/juhongm999/hsnet)
70 | * [Cost Aggregation with 4D Convolutional Swin Transformer for Few-Shot Segmentation](https://github.com/Seokju-Cho/Volumetric-Aggregation-Transformer)
71 | 
72 | ## Citation
73 | If you find this work useful, please consider citing:
74 | ```bib
75 | @inproceedings{kim2023universal,
76 | title={Universal Few-shot Learning of Dense Prediction Tasks with Visual Token Matching},
77 | author={Donggyun Kim and Jinwoo Kim and Seongwoong Cho and Chong Luo and Seunghoon Hong},
78 | booktitle={International Conference on Learning Representations},
79 | year={2023},
80 | url={https://openreview.net/forum?id=88nT0j5jAn}
81 | }
82 | ```
83 | 
84 | ## Acknowledgements
85 | The development of this open-sourced code was supported in part by the National Research Foundation of Korea (NRF) (No. 2021R1A4A3032834).
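As a supplement to the Usage section above, a minimal evaluate-then-report sequence looks like the following (the `print_results.py` flags shown are its defaults, matching `configs/test_config.yaml`; only `--task` needs to be set for a single-task report):
```
python main.py --stage 2 --task normal
python print_results.py --result_dir TEST --test_split muleshoe --task normal
```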
86 | -------------------------------------------------------------------------------- /model/vtm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .reshape import forward_6d_as_4d, from_6d_to_3d, from_3d_to_6d 3 | from .attention import CrossAttention 4 | 5 | 6 | class VTMImageBackbone(nn.Module): 7 | def __init__(self, backbone): 8 | super().__init__() 9 | self.backbone = backbone 10 | 11 | def forward(self, x, t_idx=None, **kwargs): 12 | return forward_6d_as_4d(self.backbone, x, t_idx=t_idx, get_features=True, **kwargs) 13 | 14 | def bias_parameters(self): 15 | # bias parameters for similarity adaptation 16 | for p in self.backbone.bias_parameters(): 17 | yield p 18 | 19 | def bias_parameter_names(self): 20 | return [f'backbone.{name}' for name in self.backbone.bias_parameter_names()] 21 | 22 | 23 | class VTMLabelBackbone(nn.Module): 24 | def __init__(self, backbone): 25 | super().__init__() 26 | self.backbone = backbone 27 | 28 | def encode(self, x, t_idx=None, **kwargs): 29 | return forward_6d_as_4d(self.backbone.encode, x, t_idx=t_idx, **kwargs) 30 | 31 | def decode(self, x, t_idx=None, **kwargs): 32 | return forward_6d_as_4d(self.backbone.decode, x, t_idx=t_idx, **kwargs) 33 | 34 | def forward(self, x, t_idx=None, encode_only=False, decode_only=False, **kwargs): 35 | assert not (encode_only and decode_only) 36 | if not decode_only: 37 | x = self.encode(x, t_idx=t_idx, **kwargs) 38 | if not encode_only: 39 | x = self.decode(x, t_idx=t_idx, **kwargs) 40 | 41 | return x 42 | 43 | 44 | class VTMMatchingModule(nn.Module): 45 | def __init__(self, dim_w, dim_z, config): 46 | super().__init__() 47 | self.matching = nn.ModuleList([CrossAttention(dim_w, dim_z, dim_z, num_heads=config.n_attn_heads) 48 | for i in range(config.n_levels)]) 49 | self.n_levels = config.n_levels 50 | 51 | def forward(self, W_Qs, W_Ss, Z_Ss, attn_mask=None): 52 | B, T, N, _, H, W = W_Ss[-1].size() 53 | 54 | assert len(W_Qs) == self.n_levels 55 | 56 | if attn_mask is not None: 57 | attn_mask = from_6d_to_3d(attn_mask) 58 | 59 | Z_Qs = [] 60 | for level in range(self.n_levels): 61 | Q = from_6d_to_3d(W_Qs[level]) 62 | K = from_6d_to_3d(W_Ss[level]) 63 | V = from_6d_to_3d(Z_Ss[level]) 64 | 65 | O = self.matching[level](Q, K, V, N=N, H=H, mask=attn_mask) 66 | Z_Q = from_3d_to_6d(O, B=B, T=T, H=H, W=W) 67 | Z_Qs.append(Z_Q) 68 | 69 | return Z_Qs 70 | 71 | 72 | class VTM(nn.Module): 73 | def __init__(self, image_backbone, label_backbone, matching_module): 74 | super().__init__() 75 | self.image_backbone = image_backbone 76 | self.label_backbone = label_backbone 77 | self.matching_module = matching_module 78 | 79 | self.n_levels = self.image_backbone.backbone.n_levels 80 | 81 | def bias_parameters(self): 82 | # bias parameters for similarity adaptation 83 | for p in self.image_backbone.bias_parameters(): 84 | yield p 85 | 86 | def bias_parameter_names(self): 87 | return [f'image_backbone.{name}' for name in self.image_backbone.bias_parameter_names()] 88 | 89 | def pretrained_parameters(self): 90 | return self.image_backbone.parameters() 91 | 92 | def scratch_parameters(self): 93 | modules = [self.label_backbone, self.matching_module] 94 | for module in modules: 95 | for p in module.parameters(): 96 | yield p 97 | 98 | def forward(self, X_S, Y_S, X_Q, t_idx=None, sigmoid=True): 99 | # encode support input, query input, and support output 100 | W_Ss = self.image_backbone(X_S, t_idx) 101 | W_Qs = self.image_backbone(X_Q, t_idx) 102 | Z_Ss = 
self.label_backbone.encode(Y_S, t_idx) 103 | 104 | # mix support output by matching 105 | Z_Q_preds = self.matching_module(W_Qs, W_Ss, Z_Ss) 106 | 107 | # decode support output 108 | Y_Q_pred = self.label_backbone.decode(Z_Q_preds, t_idx) 109 | 110 | if sigmoid: 111 | Y_Q_pred = Y_Q_pred.sigmoid() 112 | 113 | return Y_Q_pred 114 | -------------------------------------------------------------------------------- /train/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | import math 5 | 6 | from dataset.taskonomy_constants import TASKS, TASKS_SEMSEG, SEMSEG_CLASSES 7 | from .miou_fss import Evaluator 8 | 9 | 10 | SEMSEG_IDXS = [TASKS.index(task) for task in TASKS_SEMSEG] 11 | 12 | 13 | def generate_semseg_mask(t_idx): 14 | ''' 15 | Generate binary mask whether the task is semantic segmentation (1) or not (0). 16 | ''' 17 | semseg_mask = torch.zeros_like(t_idx, dtype=bool) 18 | for semseg_idx in SEMSEG_IDXS: 19 | semseg_mask = torch.logical_or(semseg_mask, t_idx == semseg_idx) 20 | 21 | return semseg_mask 22 | 23 | 24 | def hybrid_loss(Y_src, Y_tgt, M, t_idx): 25 | ''' 26 | Compute l1 loss for continuous tasks and bce loss for semantic segmentation. 27 | [loss_args] 28 | Y_src: unnormalized prediction of shape (B, T, N, 1, H, W) 29 | Y_tgt: normalized GT of shape (B, T, N, 1, H, W) 30 | M : mask for loss computation of shape (B, T, N, 1, H, W) 31 | t_idx: task index of shape (B, T) 32 | ''' 33 | # prediction loss 34 | loss_seg = F.binary_cross_entropy_with_logits(Y_src, Y_tgt, reduction='none') 35 | loss_con = F.l1_loss(Y_src.sigmoid(), Y_tgt, reduction='none') 36 | 37 | # loss masking 38 | loss_seg = rearrange((M * loss_seg), 'B T ... -> (B T) ...') 39 | loss_con = rearrange((M * loss_con), 'B T ... -> (B T) ...') 40 | t_idx = rearrange(t_idx, 'B T -> (B T)') 41 | 42 | # loss switching 43 | semseg_mask = generate_semseg_mask(t_idx) 44 | semseg_mask = rearrange(semseg_mask, 'B -> B 1 1 1 1').float() 45 | loss = (semseg_mask * loss_seg + (1 - semseg_mask) * loss_con).mean() 46 | 47 | return loss 48 | 49 | 50 | def compute_loss(model, train_data, config): 51 | ''' 52 | Compute episodic training loss for VTM. 53 | [train_data] 54 | X : input image of shape (B, T, N, 3, H, W) 55 | Y : output label of shape (B, T, N, 1, H, W) 56 | M : output mask of shape (B, T, N, 1, H, W) 57 | t_idx: task index of shape (B, T) 58 | ''' 59 | X, Y, M, t_idx = train_data 60 | 61 | # split the batches into support and query 62 | X_S, X_Q = X.split(math.ceil(X.size(2) / 2), dim=2) 63 | Y_S, Y_Q = Y.split(math.ceil(Y.size(2) / 2), dim=2) 64 | M_S, M_Q = M.split(math.ceil(M.size(2) / 2), dim=2) 65 | 66 | # ignore masked region in support label 67 | Y_S_in = torch.where(M_S.bool(), Y_S, torch.ones_like(Y_S) * config.mask_value) 68 | 69 | # compute loss for query images 70 | Y_Q_pred = model(X_S, Y_S_in, X_Q, t_idx=t_idx, sigmoid=False) 71 | loss = hybrid_loss(Y_Q_pred, Y_Q, M_Q, t_idx) 72 | 73 | return loss 74 | 75 | 76 | def normalize_tensor(input_tensor, dim): 77 | ''' 78 | Normalize Euclidean vector. 79 | ''' 80 | norm = torch.norm(input_tensor, p='fro', dim=dim, keepdim=True) 81 | zero_mask = (norm == 0) 82 | norm[zero_mask] = 1 83 | out = input_tensor.div(norm) 84 | out[zero_mask.expand_as(out)] = 0 85 | return out 86 | 87 | 88 | def compute_metric(Y, Y_pred, M, task, miou_evaluator=None, stage=0): 89 | ''' 90 | Compute evaluation metric for each task. 
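    Surface normals are scored with the masked mean angle error, semantic segmentation
    accumulates intersection/union counts into `miou_evaluator` (the returned metric is 0
    and mIoU is read from the evaluator afterwards), and the remaining continuous tasks
    use the masked mean of the element-wise root squared error (reported as RMSE in print_results.py).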
91 | ''' 92 | # Mean Angle Error 93 | if task == 'normal': 94 | pred = normalize_tensor(Y_pred, dim=1) 95 | gt = normalize_tensor(Y, dim=1) 96 | deg_diff = torch.rad2deg(2 * torch.atan2(torch.norm(pred - gt, dim=1), torch.norm(pred + gt, dim=1))) 97 | metric = (M[:, 0] * deg_diff).mean() 98 | 99 | # Mean IoU 100 | elif 'segment_semantic' in task: 101 | assert miou_evaluator is not None 102 | 103 | area_inter, area_union = Evaluator.classify_prediction(Y_pred.clone().float(), Y.float()) 104 | if stage == 0: 105 | assert 'segment_semantic' in task 106 | semseg_class = int(task.split('_')[-1]) 107 | class_id = torch.tensor([SEMSEG_CLASSES.index(semseg_class)]*len(Y_pred), device=Y.device) 108 | else: 109 | class_id = torch.tensor([0]*len(Y_pred), device=Y.device) 110 | 111 | area_inter = area_inter.to(Y.device) 112 | area_union = area_union.to(Y.device) 113 | miou_evaluator.update(area_inter, area_union, class_id) 114 | 115 | metric = 0 116 | 117 | # Mean Squared Error 118 | else: 119 | metric = (M * F.mse_loss(Y, Y_pred, reduction='none').pow(0.5)).mean() 120 | 121 | return metric -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class CrossAttention(nn.Module): 7 | def __init__(self, dim_q, dim_v, dim_o, num_heads=4, act_fn=nn.GELU, 8 | dr=0.1, pre_ln=True, ln=True, residual=True, dim_k=None): 9 | super().__init__() 10 | 11 | if dim_k is None: 12 | dim_k = dim_q 13 | 14 | # heads and temperature 15 | self.num_heads = num_heads 16 | self.dim_split_q = dim_q // num_heads 17 | self.dim_split_v = dim_o // num_heads 18 | self.temperature = math.sqrt(dim_o) 19 | self.residual = residual 20 | 21 | # projection layers 22 | self.fc_q = nn.Linear(dim_q, dim_q, bias=False) 23 | self.fc_k = nn.Linear(dim_k, dim_q, bias=False) 24 | self.fc_v = nn.Linear(dim_v, dim_o, bias=False) 25 | self.fc_o = nn.Linear(dim_o, dim_o, bias=False) 26 | 27 | # nonlinear activation and dropout 28 | self.activation = act_fn() 29 | self.attn_dropout = nn.Dropout(dr) 30 | 31 | # layernorm layers 32 | if pre_ln: 33 | if dim_q == dim_k: 34 | self.pre_ln_q = self.pre_ln_k = nn.LayerNorm(dim_q) 35 | else: 36 | self.pre_ln_q = nn.LayerNorm(dim_q) 37 | self.pre_ln_k = nn.LayerNorm(dim_k) 38 | else: 39 | self.pre_ln_q = self.pre_ln_k = nn.Identity() 40 | self.ln = nn.LayerNorm(dim_o) if ln else nn.Identity() 41 | 42 | def compute_attention_scores(self, Q, K, mask=None, **kwargs): 43 | # pre-layer normalization 44 | Q = self.pre_ln_q(Q) 45 | K = self.pre_ln_k(K) 46 | 47 | # lienar projection 48 | Q = self.fc_q(Q) 49 | K = self.fc_k(K) 50 | 51 | # split into multiple heads 52 | Q_ = torch.cat(Q.split(self.dim_split_q, 2), 0) 53 | K_ = torch.cat(K.split(self.dim_split_q, 2), 0) 54 | 55 | # scaled dot-product attention with mask and dropout 56 | A = Q_.bmm(K_.transpose(1, 2)) / self.temperature 57 | A = A.clip(-1e4, 1e4) 58 | if mask is not None: 59 | A.masked_fill(mask, -1e38) 60 | A = A.softmax(dim=2) 61 | if mask is not None: 62 | A.masked_fill(mask, 0) 63 | A = self.attn_dropout(A) 64 | 65 | return A 66 | 67 | def project_values(self, V): 68 | # linear projection 69 | O = self.fc_v(V) 70 | 71 | # residual connection with non-linearity 72 | if self.residual: 73 | O = O + self.activation(self.fc_o(O)) 74 | else: 75 | O = self.fc_o(O) 76 | 77 | return O 78 | 79 | def forward(self, Q, K, V, mask=None, get_attn_map=False, 
                disconnect_self_image=False, H=None, W=None, **kwargs):
80 |         # pre-layer normalization
81 |         Q = self.pre_ln_q(Q)
82 |         K = self.pre_ln_k(K)
83 | 
84 |         # linear projection
85 |         Q = self.fc_q(Q)
86 |         K = self.fc_k(K)
87 |         V = self.fc_v(V)
88 | 
89 |         # split into multiple heads
90 |         Q_ = torch.cat(Q.split(self.dim_split_q, 2), 0)
91 |         K_ = torch.cat(K.split(self.dim_split_q, 2), 0)
92 |         V_ = torch.cat(V.split(self.dim_split_v, 2), 0)
93 | 
94 |         # scaled dot-product attention with mask and dropout
95 |         L = Q_.bmm(K_.transpose(1, 2)) / self.temperature
96 |         L = L.clip(-1e4, 1e4)
97 | 
98 |         # mask
99 |         if mask is not None:
100 |             mask = mask.transpose(1, 2).expand_as(L)
101 |         elif disconnect_self_image:
102 |             assert Q_.size(1) == K_.size(1)
103 |             assert H is not None and W is not None
104 |             N = Q_.size(1) // (H*W)
105 |             mask = torch.block_diag(*[torch.ones(H*W, H*W, device=Q.device) for _ in range(N)]).bool()
106 | 
107 |         if mask is not None:
108 |             L.masked_fill_(mask, -1e38)
109 | 
110 |         A = L.softmax(dim=2)
111 |         if mask is not None:
112 |             A.masked_fill_(mask, 0)
113 |         A = self.attn_dropout(A)
114 | 
115 |         # apply attention to values
116 |         O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
117 | 
118 |         # layer normalization
119 |         O = self.ln(O)
120 | 
121 |         # residual connection with non-linearity
122 |         if self.residual:
123 |             O = O + self.activation(self.fc_o(O))
124 |         else:
125 |             O = self.fc_o(O)
126 | 
127 |         if get_attn_map:
128 |             return O, A
129 |         else:
130 |             return O
--------------------------------------------------------------------------------
/train/visualize.py:
--------------------------------------------------------------------------------
1 | from skimage import color
2 | import numpy as np
3 | 
4 | import torch
5 | import torch.nn.functional as F
6 | from torchvision.utils import make_grid
7 | 
8 | from dataset.taskonomy_constants import *
9 | 
10 | 
11 | def visualize_batch(X=None, Y=None, M=None, Y_preds=None, channels=None, size=None, postprocess_fn=None, **kwargs):
12 |     '''
13 |     Visualize a global batch consisting of N-shot images and labels for T channels.
14 |     Images are assumed to be shared across all channels, so label channels are converted into RGB and visualized at once.
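    Masked label regions are filled with random noise for display, postprocess_fn (if given) is applied
    before the RGB conversion, and the result is returned as a single image grid with one row of input
    images (when X is given) followed by one row per label or prediction set.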
15 |     '''
16 | 
17 |     vis = []
18 | 
19 |     # input check
20 |     assert X is not None or Y is not None or Y_preds is not None
21 | 
22 |     # visualize image
23 |     if X is not None:
24 |         img = X.cpu().float()
25 |         vis.append(img)
26 |     else:
27 |         img = None
28 | 
29 |     # flatten labels and masks
30 |     Ys = []
31 |     Ms = []
32 |     if Y is not None:
33 |         Ys.append((Y, None))
34 |         Ms.append(M)
35 |     if Y_preds is not None:
36 |         if isinstance(Y_preds, torch.Tensor):
37 |             Ys.append((Y_preds, Y))
38 |             Ms.append(None)
39 |         elif isinstance(Y_preds, (tuple, list)):
40 |             if Y is not None:
41 |                 for Y_pred in Y_preds:
42 |                     Ys.append((Y_pred, Y))
43 |                     Ms.append(None)
44 |             else:
45 |                 for Y_pred in Y_preds:
46 |                     Ys.append((Y_pred, None))
47 |                     Ms.append(None)
48 |         else:
49 |             raise ValueError(f'unsupported predictions type: {type(Y_preds)}')
50 | 
51 |     # visualize labels
52 |     if len(Ys) > 0:
53 |         for Y, Y_gt in Ys:
54 |             Y = Y.cpu().float()
55 |             if Y_gt is not None:
56 |                 Y_gt = Y_gt.cpu().float()
57 | 
58 |             if channels is None:
59 |                 channels = list(range(Y.size(1)))
60 | 
61 |             label = Y[:, channels].clip(0, 1)
62 |             if Y_gt is not None:
63 |                 label_gt = Y_gt[:, channels].clip(0, 1)
64 |             else:
65 |                 label_gt = None
66 | 
67 |             # fill masked region with random noise
68 |             if M is not None:
69 |                 assert Y.shape == M.shape
70 |                 M = M.cpu().float()
71 |                 label = torch.where(M[:, channels].bool(),
72 |                                     label,
73 |                                     torch.rand_like(label))
74 |                 if Y_gt is not None:
75 |                     label_gt = Y_gt[:, channels].clip(0, 1)
76 |                     label_gt = torch.where(M[:, channels].bool(),
77 |                                            label_gt,
78 |                                            torch.rand_like(label_gt))
79 | 
80 |             if postprocess_fn is not None:
81 |                 label = postprocess_fn(label, img, label_gt=label_gt)
82 | 
83 |             label = visualize_label_as_rgb(label)
84 |             vis.append(label)
85 | 
86 |     nrow = len(vis[0])
87 |     vis = torch.cat(vis)
88 |     if size is not None:
89 |         vis = F.interpolate(vis, size)
90 |     vis = make_grid(vis, nrow=nrow, **kwargs)
91 |     vis = vis.float()
92 | 
93 |     return vis
94 | 
95 | 
96 | def postprocess_depth(label, img=None, **kwargs):
97 |     label = 0.6*label + 0.4
98 |     label = torch.exp(label * np.log(2.0**16.0)) - 1.0
99 |     label = torch.log(label) / 11.09
100 |     label = (label - 0.64) / 0.18
101 |     label = (label + 1.) / 2
102 |     label = (label*255).byte().float() / 255.
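    # annotation: the chain above maps a [0, 1] depth prediction back to [0.4, 1.0], exponentiates it
    # to the raw 16-bit depth range, re-applies a log-depth normalization (note log(2**16) ~= 11.09;
    # the mean 0.64 and std 0.18 are presumably Taskonomy depth statistics), rescales to [0, 1],
    # and finally quantizes to 8-bit levels for display.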
103 |     return label
104 | 
105 | 
106 | def postprocess_semseg(label, img=None, **kwargs):
107 |     COLORS = ('red', 'blue', 'yellow', 'magenta',
108 |               'green', 'indigo', 'darkorange', 'cyan', 'pink',
109 |               'yellowgreen', 'black', 'darkgreen', 'brown', 'gray',
110 |               'purple', 'darkviolet')
111 | 
112 |     if label.ndim == 4:
113 |         label = label.squeeze(1)
114 | 
115 |     label_vis = []
116 |     if img is not None:
117 |         for img_, label_ in zip(img, label):
118 |             for c in range(len(COLORS)+1):
119 |                 label_[0, c] = c
120 | 
121 |             label_vis.append(torch.from_numpy(color.label2rgb(label_.numpy(),
122 |                                                               image=img_.permute(1, 2, 0).numpy(),
123 |                                                               colors=COLORS,
124 |                                                               kind='overlay')).permute(2, 0, 1))
125 |     else:
126 |         for label_ in label:
127 |             for c in range(len(COLORS)+1):
128 |                 label_[0, c] = c
129 | 
130 |             label_vis.append(torch.from_numpy(color.label2rgb(label_.numpy(),
131 |                                                               colors=COLORS,
132 |                                                               kind='overlay')).permute(2, 0, 1))
133 | 
134 |     label = torch.stack(label_vis)
135 | 
136 |     return label
137 | 
138 | 
139 | def visualize_label_as_rgb(label):
140 |     if label.size(1) == 1:
141 |         label = label.repeat(1, 3, 1, 1)
142 |     elif label.size(1) == 2:
143 |         label = torch.cat((label, torch.zeros_like(label[:, :1])), 1)
144 |     elif label.size(1) == 5:
145 |         label = torch.stack((
146 |             label[:, :2].mean(1),
147 |             label[:, 2:4].mean(1),
148 |             label[:, 4]
149 |         ), 1)
150 |     elif label.size(1) != 3:
151 |         raise NotImplementedError
152 | 
153 |     return label
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | from easydict import EasyDict
4 | 
5 | from dataset.taskonomy_constants import TASKS_GROUP_NAMES, TASKS_GROUP_TEST
6 | 
7 | 
8 | def str2bool(v):
9 |     if v == 'True' or v == 'true':
10 |         return True
11 |     elif v == 'False' or v == 'false':
12 |         return False
13 |     else:
14 |         raise argparse.ArgumentTypeError('Boolean value expected.')
15 | 
16 | 
17 | # argument parser
18 | parser = argparse.ArgumentParser()
19 | 
20 | # necessary arguments
21 | parser.add_argument('--debug_mode', '-debug', default=False, action='store_true')
22 | parser.add_argument('--continue_mode', '-cont', default=False, action='store_true')
23 | parser.add_argument('--skip_mode', '-skip', default=False, action='store_true')
24 | parser.add_argument('--no_eval', '-ne', default=False, action='store_true')
25 | parser.add_argument('--no_save', '-ns', default=False, action='store_true')
26 | parser.add_argument('--reset_mode', '-reset', default=False, action='store_true')
27 | 
28 | parser.add_argument('--stage', type=int, default=0, choices=[0, 1, 2])
29 | parser.add_argument('--task', type=str, default='', choices=['', 'all'] + TASKS_GROUP_NAMES)
30 | parser.add_argument('--task_fold', '-fold', type=int, default=None, choices=[0, 1, 2, 3, 4])
31 | parser.add_argument('--exp_name', type=str, default='')
32 | parser.add_argument('--exp_subname', type=str, default='')
33 | parser.add_argument('--name_postfix', '-ptf', type=str, default='')
34 | parser.add_argument('--save_postfix', '-sptf', type=str, default='')
35 | parser.add_argument('--result_postfix', '-rptf', type=str, default='')
36 | parser.add_argument('--load_step', '-ls', type=int, default=-1)
37 | 
38 | # optional arguments
39 | parser.add_argument('--model', type=str, default=None, choices=['VTM'])
40 | parser.add_argument('--seed', type=int, default=None)
41 | parser.add_argument('--strategy', '-str', type=str, default=None)
42 | 
parser.add_argument('--num_workers', '-nw', type=int, default=None) 43 | parser.add_argument('--global_batch_size', '-gbs', type=int, default=None) 44 | parser.add_argument('--eval_batch_size', '-ebs', type=int, default=None) 45 | parser.add_argument('--n_eval_batches', '-neb', type=int, default=None) 46 | parser.add_argument('--shot', type=int, default=None) 47 | parser.add_argument('--max_channels', '-mc', type=int, default=None) 48 | parser.add_argument('--support_idx', '-sid', type=int, default=None) 49 | parser.add_argument('--channel_idx', '-cid', type=int, default=None) 50 | parser.add_argument('--test_split', '-split', type=str, default=None) 51 | parser.add_argument('--semseg_threshold', '-sth', type=float, default=None) 52 | 53 | parser.add_argument('--image_augmentation', '-ia', type=str2bool, default=None) 54 | parser.add_argument('--unary_augmentation', '-ua', type=str2bool, default=None) 55 | parser.add_argument('--binary_augmentation', '-ba', type=str2bool, default=None) 56 | parser.add_argument('--mixed_augmentation', '-ma', type=str2bool, default=None) 57 | parser.add_argument('--image_backbone', '-ib', type=str, default=None) 58 | parser.add_argument('--label_backbone', '-lb', type=str, default=None) 59 | parser.add_argument('--n_attn_heads', '-nah', type=int, default=None) 60 | 61 | parser.add_argument('--n_steps', '-nst', type=int, default=None) 62 | parser.add_argument('--optimizer', '-opt', type=str, default=None, choices=['sgd', 'adam', 'adamw']) 63 | parser.add_argument('--lr', type=float, default=None) 64 | parser.add_argument('--lr_pretrained', '-lrp', type=float, default=None) 65 | parser.add_argument('--lr_schedule', '-lrs', type=str, default=None, choices=['constant', 'sqroot', 'cos', 'poly']) 66 | parser.add_argument('--early_stopping_patience', '-esp', type=int, default=None) 67 | 68 | parser.add_argument('--log_dir', type=str, default=None) 69 | parser.add_argument('--save_dir', type=str, default=None) 70 | parser.add_argument('--load_dir', type=str, default=None) 71 | parser.add_argument('--val_iter', '-viter', type=int, default=None) 72 | parser.add_argument('--save_iter', '-siter', type=int, default=None) 73 | 74 | args = parser.parse_args() 75 | 76 | 77 | # load config file 78 | if args.stage == 0: 79 | config_path = 'configs/train_config.yaml' 80 | elif args.stage == 1: 81 | config_path = 'configs/finetune_config.yaml' 82 | elif args.stage == 2: 83 | config_path = 'configs/test_config.yaml' 84 | 85 | with open(config_path, 'r') as f: 86 | config = yaml.safe_load(f) 87 | config = EasyDict(config) 88 | 89 | # copy parsed arguments 90 | for key in args.__dir__(): 91 | if key[:2] != '__' and getattr(args, key) is not None: 92 | setattr(config, key, getattr(args, key)) 93 | 94 | # retrieve data root 95 | with open('data_paths.yaml', 'r') as f: 96 | path_dict = yaml.safe_load(f) 97 | config.root_dir = path_dict[config.dataset] 98 | 99 | # for debugging 100 | if config.debug_mode: 101 | config.n_steps = 10 102 | config.log_iter = 1 103 | config.val_iter = 5 104 | config.save_iter = 5 105 | if config.stage == 2: 106 | config.n_eval_batches = 2 107 | config.log_dir += '_debugging' 108 | if config.stage == 0: 109 | config.load_dir += '_debugging' 110 | if config.stage <= 1: 111 | config.save_dir += '_debugging' 112 | 113 | # create experiment name 114 | if config.exp_name == '': 115 | if config.stage == 0: 116 | if config.task == '': 117 | config.exp_name = f'{config.model}_fold:{config.task_fold}{config.name_postfix}' 118 | else: 119 | config.exp_name = 
f'{config.model}_task:{config.task}{config.name_postfix}' 120 | else: 121 | fold_dict = {} 122 | for fold in TASKS_GROUP_TEST: 123 | for task in TASKS_GROUP_TEST[fold]: 124 | fold_dict[task] = fold 125 | task_fold = fold_dict[config.task] 126 | config.exp_name = f'{config.model}_fold:{task_fold}{config.name_postfix}' -------------------------------------------------------------------------------- /model/beit/beit_registry.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | import fnmatch 4 | from collections import defaultdict 5 | from copy import deepcopy 6 | 7 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 8 | 'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained'] 9 | 10 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module 11 | _model_to_module = {} # mapping of model names to module names 12 | _model_entrypoints = {} # mapping of model names to entrypoint fns 13 | _model_has_pretrained = set() # set of model names that have pretrained weight url present 14 | _model_default_cfgs = dict() # central repo for model default_cfgs 15 | 16 | 17 | def register_model(fn): 18 | # lookup containing module 19 | mod = sys.modules[fn.__module__] 20 | module_name_split = fn.__module__.split('.') 21 | module_name = module_name_split[-1] if len(module_name_split) else '' 22 | 23 | # add model to __all__ in module 24 | model_name = fn.__name__ 25 | if hasattr(mod, '__all__'): 26 | mod.__all__.append(model_name) 27 | else: 28 | mod.__all__ = [model_name] 29 | 30 | # add entries to registry dict/sets 31 | _model_entrypoints[model_name] = fn 32 | _model_to_module[model_name] = module_name 33 | _module_to_models[module_name].add(model_name) 34 | has_pretrained = False # check if model has a pretrained url to allow filtering on this 35 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: 36 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing 37 | # entrypoints or non-matching combos 38 | has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] 39 | _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name]) 40 | if has_pretrained: 41 | _model_has_pretrained.add(model_name) 42 | return fn 43 | 44 | 45 | def _natural_key(string_): 46 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 47 | 48 | 49 | def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): 50 | """ Return list of available model names, sorted alphabetically 51 | Args: 52 | filter (str) - Wildcard filter string that works with fnmatch 53 | module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') 54 | pretrained (bool) - Include only models with pretrained weights if True 55 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter 56 | name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) 57 | Example: 58 | model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' 59 | model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module 60 | """ 61 | if module: 62 | all_models = list(_module_to_models[module]) 63 | else: 64 | all_models = _model_entrypoints.keys() 65 | if filter: 66 | 
models = [] 67 | include_filters = filter if isinstance(filter, (tuple, list)) else [filter] 68 | for f in include_filters: 69 | include_models = fnmatch.filter(all_models, f) # include these models 70 | if len(include_models): 71 | models = set(models).union(include_models) 72 | else: 73 | models = all_models 74 | if exclude_filters: 75 | if not isinstance(exclude_filters, (tuple, list)): 76 | exclude_filters = [exclude_filters] 77 | for xf in exclude_filters: 78 | exclude_models = fnmatch.filter(models, xf) # exclude these models 79 | if len(exclude_models): 80 | models = set(models).difference(exclude_models) 81 | if pretrained: 82 | models = _model_has_pretrained.intersection(models) 83 | if name_matches_cfg: 84 | models = set(_model_default_cfgs).intersection(models) 85 | return list(sorted(models, key=_natural_key)) 86 | 87 | 88 | def is_model(model_name): 89 | """ Check if a model name exists 90 | """ 91 | return model_name in _model_entrypoints 92 | 93 | 94 | def model_entrypoint(model_name): 95 | """Fetch a model entrypoint for specified model name 96 | """ 97 | return _model_entrypoints[model_name] 98 | 99 | 100 | def list_modules(): 101 | """ Return list of module names that contain models / model entrypoints 102 | """ 103 | modules = _module_to_models.keys() 104 | return list(sorted(modules)) 105 | 106 | 107 | def is_model_in_modules(model_name, module_names): 108 | """Check if a model exists within a subset of modules 109 | Args: 110 | model_name (str) - name of model to check 111 | module_names (tuple, list, set) - names of modules to search in 112 | """ 113 | assert isinstance(module_names, (tuple, list, set)) 114 | return any(model_name in _module_to_models[n] for n in module_names) 115 | 116 | 117 | def has_model_default_key(model_name, cfg_key): 118 | """ Query model default_cfgs for existence of a specific key. 119 | """ 120 | if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]: 121 | return True 122 | return False 123 | 124 | 125 | def is_model_default_key(model_name, cfg_key): 126 | """ Return truthy value for specified model default_cfg key, False if does not exist. 127 | """ 128 | if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False): 129 | return True 130 | return False 131 | 132 | 133 | def get_model_default_value(model_name, cfg_key): 134 | """ Get a specific model default_cfg value by key. None if it doesn't exist. 
135 | """ 136 | if model_name in _model_default_cfgs: 137 | return _model_default_cfgs[model_name].get(cfg_key, None) 138 | else: 139 | return None 140 | 141 | 142 | def is_model_pretrained(model_name): 143 | return model_name in _model_has_pretrained -------------------------------------------------------------------------------- /model/dpt/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import types 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | activations = {} 10 | 11 | 12 | def get_activation(name): 13 | def hook(model, input, output): 14 | activations[name] = output 15 | 16 | return hook 17 | 18 | 19 | def _resize_pos_embed(self, posemb, gs_h, gs_w): 20 | posemb_tok, posemb_grid = ( 21 | posemb[:, : self.start_index], 22 | posemb[0, self.start_index :], 23 | ) 24 | 25 | gs_old = int(math.sqrt(len(posemb_grid))) 26 | 27 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 28 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") 29 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 30 | 31 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 32 | 33 | return posemb 34 | 35 | 36 | def forward_flex(self, x, t_idx=None): 37 | b, c, h, w = x.shape 38 | 39 | if self.pos_embed is not None: 40 | pos_embed = self._resize_pos_embed( 41 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] 42 | ) 43 | else: 44 | pos_embed = None 45 | 46 | B = x.shape[0] 47 | 48 | if hasattr(self.patch_embed, "backbone"): 49 | x = self.patch_embed.backbone(x) 50 | if isinstance(x, (list, tuple)): 51 | x = x[-1] # last feature if backbone outputs list/tuple of features 52 | 53 | if isinstance(self.patch_embed, nn.ModuleList): 54 | assert t_idx is not None 55 | x = torch.cat([self.patch_embed[t_idx_].proj(x_[None]) for t_idx_, x_ in zip(t_idx, x)]) 56 | else: 57 | x = self.patch_embed.proj(x) 58 | x = x.flatten(2).transpose(1, 2) 59 | 60 | if getattr(self, "dist_token", None) is not None: 61 | cls_tokens = self.cls_token.expand( 62 | B, -1, -1 63 | ) # stole cls_tokens impl from Phil Wang, thanks 64 | dist_token = self.dist_token.expand(B, -1, -1) 65 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 66 | else: 67 | cls_tokens = self.cls_token.expand( 68 | B, -1, -1 69 | ) # stole cls_tokens impl from Phil Wang, thanks 70 | x = torch.cat((cls_tokens, x), dim=1) 71 | 72 | if pos_embed is not None: 73 | x = x + pos_embed 74 | x = self.pos_drop(x) 75 | 76 | for blk in self.blocks: 77 | x = blk(x) 78 | 79 | x = self.norm(x) 80 | 81 | return x 82 | 83 | 84 | def _make_vit_backbone( 85 | model, 86 | features=[96, 192, 384, 768], 87 | hooks=[2, 5, 8, 11], 88 | vit_features=768, 89 | start_index=1, 90 | ): 91 | pretrained = nn.Module() 92 | 93 | pretrained.model = model 94 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 95 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 96 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 97 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 98 | 99 | pretrained.activations = activations 100 | 101 | # 32, 48, 136, 384 102 | pretrained.act_postprocess1 = nn.Sequential( 103 | nn.Conv2d( 104 | in_channels=vit_features, 105 | out_channels=features[0], 106 | kernel_size=1, 107 | stride=1, 108 | padding=0, 109 | ), 110 | nn.ConvTranspose2d( 111 | 
in_channels=features[0], 112 | out_channels=features[0], 113 | kernel_size=4, 114 | stride=4, 115 | padding=0, 116 | bias=True, 117 | dilation=1, 118 | groups=1, 119 | ), 120 | ) 121 | 122 | pretrained.act_postprocess2 = nn.Sequential( 123 | nn.Conv2d( 124 | in_channels=vit_features, 125 | out_channels=features[1], 126 | kernel_size=1, 127 | stride=1, 128 | padding=0, 129 | ), 130 | nn.ConvTranspose2d( 131 | in_channels=features[1], 132 | out_channels=features[1], 133 | kernel_size=2, 134 | stride=2, 135 | padding=0, 136 | bias=True, 137 | dilation=1, 138 | groups=1, 139 | ), 140 | ) 141 | 142 | pretrained.act_postprocess3 = nn.Sequential( 143 | nn.Conv2d( 144 | in_channels=vit_features, 145 | out_channels=features[2], 146 | kernel_size=1, 147 | stride=1, 148 | padding=0, 149 | ), 150 | ) 151 | 152 | pretrained.act_postprocess4 = nn.Sequential( 153 | nn.Conv2d( 154 | in_channels=vit_features, 155 | out_channels=features[3], 156 | kernel_size=1, 157 | stride=1, 158 | padding=0, 159 | ), 160 | nn.Conv2d( 161 | in_channels=features[3], 162 | out_channels=features[3], 163 | kernel_size=3, 164 | stride=2, 165 | padding=1, 166 | ), 167 | ) 168 | 169 | pretrained.model.start_index = start_index 170 | pretrained.model.patch_size = [16, 16] 171 | 172 | # We inject this function into the VisionTransformer instances so that 173 | # we can use it with interpolated position embeddings without modifying the library source. 174 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 175 | pretrained.model._resize_pos_embed = types.MethodType( 176 | _resize_pos_embed, pretrained.model 177 | ) 178 | 179 | return pretrained 180 | 181 | 182 | def _make_pretrained_vitl16_384(pretrained, in_chans=3): 183 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained, in_chans=in_chans) 184 | 185 | hooks = [5, 11, 17, 23] 186 | return _make_vit_backbone( 187 | model, 188 | features=[256, 512, 1024, 1024], 189 | hooks=hooks, 190 | vit_features=1024, 191 | ) 192 | 193 | 194 | def _make_pretrained_vitb16_384(pretrained, in_chans=3): 195 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained, in_chans=in_chans) 196 | 197 | hooks = [2, 5, 8, 11] 198 | return _make_vit_backbone( 199 | model, 200 | features=[96, 192, 384, 768], 201 | hooks=hooks, 202 | ) 203 | 204 | 205 | def _make_pretrained_vitb16_224(pretrained, in_chans=3): 206 | model = timm.create_model("vit_base_patch16_224", pretrained=pretrained, in_chans=in_chans) 207 | 208 | hooks = [2, 5, 8, 11] 209 | return _make_vit_backbone( 210 | model, 211 | features=[96, 192, 384, 768], 212 | hooks=hooks, 213 | ) -------------------------------------------------------------------------------- /model/dpt/dpt_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitl16_384, 6 | _make_pretrained_vitb16_384, 7 | _make_pretrained_vitb16_224, 8 | ) 9 | 10 | 11 | def _make_encoder( 12 | backbone, 13 | features, 14 | use_pretrained, 15 | groups=1, 16 | expand=False, 17 | model=None, 18 | in_chans=3, 19 | ): 20 | if backbone == "vitl16_384": 21 | pretrained = _make_pretrained_vitl16_384( 22 | use_pretrained, 23 | in_chans=in_chans, 24 | ) 25 | scratch = _make_scratch( 26 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 27 | ) # ViT-L/16 - 85.0% Top1 (backbone) 28 | elif backbone == "vitb16_384": 29 | pretrained = _make_pretrained_vitb16_384( 30 | use_pretrained, 31 | 
in_chans=in_chans, 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "vitb16_224": 37 | pretrained = _make_pretrained_vitb16_224( 38 | use_pretrained, 39 | in_chans=in_chans, 40 | ) 41 | scratch = _make_scratch( 42 | [96, 192, 384, 768], features, groups=groups, expand=expand 43 | ) 44 | else: 45 | print(f"Backbone '{backbone}' not implemented") 46 | assert False 47 | 48 | return pretrained, scratch 49 | 50 | 51 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 52 | scratch = nn.Module() 53 | 54 | out_shape1 = out_shape 55 | out_shape2 = out_shape 56 | out_shape3 = out_shape 57 | out_shape4 = out_shape 58 | if expand == True: 59 | out_shape1 = out_shape 60 | out_shape2 = out_shape * 2 61 | out_shape3 = out_shape * 4 62 | out_shape4 = out_shape * 8 63 | 64 | scratch.layer1_rn = nn.Conv2d( 65 | in_shape[0], 66 | out_shape1, 67 | kernel_size=3, 68 | stride=1, 69 | padding=1, 70 | bias=False, 71 | groups=groups, 72 | ) 73 | scratch.layer2_rn = nn.Conv2d( 74 | in_shape[1], 75 | out_shape2, 76 | kernel_size=3, 77 | stride=1, 78 | padding=1, 79 | bias=False, 80 | groups=groups, 81 | ) 82 | scratch.layer3_rn = nn.Conv2d( 83 | in_shape[2], 84 | out_shape3, 85 | kernel_size=3, 86 | stride=1, 87 | padding=1, 88 | bias=False, 89 | groups=groups, 90 | ) 91 | scratch.layer4_rn = nn.Conv2d( 92 | in_shape[3], 93 | out_shape4, 94 | kernel_size=3, 95 | stride=1, 96 | padding=1, 97 | bias=False, 98 | groups=groups, 99 | ) 100 | 101 | return scratch 102 | 103 | 104 | def _make_fusion_block(features, use_bn): 105 | return FeatureFusionBlock( 106 | features, 107 | nn.ReLU(False), 108 | deconv=False, 109 | bn=use_bn, 110 | expand=False, 111 | align_corners=True, 112 | ) 113 | 114 | 115 | class Interpolate(nn.Module): 116 | """Interpolation module.""" 117 | 118 | def __init__(self, scale_factor, mode, align_corners=False): 119 | """Init. 120 | 121 | Args: 122 | scale_factor (float): scaling 123 | mode (str): interpolation mode 124 | """ 125 | super(Interpolate, self).__init__() 126 | 127 | self.interp = nn.functional.interpolate 128 | self.scale_factor = scale_factor 129 | self.mode = mode 130 | self.align_corners = align_corners 131 | 132 | def forward(self, x): 133 | """Forward pass. 134 | 135 | Args: 136 | x (tensor): input 137 | 138 | Returns: 139 | tensor: interpolated data 140 | """ 141 | 142 | cast = False 143 | if x.dtype != torch.float32: 144 | cast = True 145 | dtype = x.dtype 146 | x = x.float() 147 | 148 | x = self.interp( 149 | x, 150 | scale_factor=self.scale_factor, 151 | mode=self.mode, 152 | align_corners=self.align_corners, 153 | ) 154 | 155 | if cast: 156 | x = x.to(dtype) 157 | 158 | return x 159 | 160 | 161 | class ResidualConvUnit(nn.Module): 162 | """Residual convolution module.""" 163 | 164 | def __init__(self, features, activation, bn): 165 | """Init. 
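        Residual unit: two 3x3 convolutions with pre-activation (and optional BatchNorm), plus an identity skip connection.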
166 | 167 | Args: 168 | features (int): number of features 169 | """ 170 | super().__init__() 171 | 172 | self.bn = bn 173 | 174 | self.groups = 1 175 | 176 | self.conv1 = nn.Conv2d( 177 | features, 178 | features, 179 | kernel_size=3, 180 | stride=1, 181 | padding=1, 182 | bias=not self.bn, 183 | groups=self.groups, 184 | ) 185 | 186 | self.conv2 = nn.Conv2d( 187 | features, 188 | features, 189 | kernel_size=3, 190 | stride=1, 191 | padding=1, 192 | bias=not self.bn, 193 | groups=self.groups, 194 | ) 195 | 196 | if self.bn == True: 197 | self.bn1 = nn.BatchNorm2d(features) 198 | self.bn2 = nn.BatchNorm2d(features) 199 | 200 | self.activation = activation 201 | 202 | self.skip_add = nn.quantized.FloatFunctional() 203 | 204 | def forward(self, x): 205 | """Forward pass. 206 | 207 | Args: 208 | x (tensor): input 209 | 210 | Returns: 211 | tensor: output 212 | """ 213 | 214 | out = self.activation(x) 215 | out = self.conv1(out) 216 | if self.bn == True: 217 | out = self.bn1(out) 218 | 219 | out = self.activation(out) 220 | out = self.conv2(out) 221 | if self.bn == True: 222 | out = self.bn2(out) 223 | 224 | if self.groups > 1: 225 | out = self.conv_merge(out) 226 | 227 | return self.skip_add.add(out, x) 228 | 229 | 230 | class FeatureFusionBlock(nn.Module): 231 | """Feature fusion block.""" 232 | 233 | def __init__( 234 | self, 235 | features, 236 | activation, 237 | deconv=False, 238 | bn=False, 239 | expand=False, 240 | align_corners=True, 241 | ): 242 | """Init. 243 | 244 | Args: 245 | features (int): number of features 246 | """ 247 | super(FeatureFusionBlock, self).__init__() 248 | 249 | self.deconv = deconv 250 | self.align_corners = align_corners 251 | 252 | self.groups = 1 253 | 254 | self.expand = expand 255 | out_features = features 256 | if self.expand == True: 257 | out_features = features // 2 258 | 259 | self.out_conv = nn.Conv2d( 260 | features, 261 | out_features, 262 | kernel_size=1, 263 | stride=1, 264 | padding=0, 265 | bias=True, 266 | groups=1, 267 | ) 268 | 269 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn) 270 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn) 271 | 272 | self.skip_add = nn.quantized.FloatFunctional() 273 | 274 | def forward(self, *xs): 275 | """Forward pass. 
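        Fuses one or two feature maps: the optional second input goes through resConfUnit1 and is added to the first,
        the sum is refined by resConfUnit2, upsampled by a factor of 2 (bilinear), and projected by the 1x1 out_conv.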
276 | 277 | Returns: 278 | tensor: output 279 | """ 280 | output = xs[0] 281 | 282 | if len(xs) == 2: 283 | res = self.resConfUnit1(xs[1]) 284 | output = self.skip_add.add(output, res) 285 | # output += res 286 | 287 | output = self.resConfUnit2(output) 288 | 289 | cast = False 290 | if output.dtype != torch.float32: 291 | cast = True 292 | dtype = output.dtype 293 | output = output.float() 294 | output = nn.functional.interpolate( 295 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 296 | ) 297 | if cast: 298 | output = output.to(dtype) 299 | 300 | 301 | output = self.out_conv(output) 302 | 303 | return output 304 | -------------------------------------------------------------------------------- /dataset/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import numpy as np 4 | import torch 5 | from torchvision.transforms.functional import gaussian_blur 6 | 7 | 8 | def normalize(x): 9 | if x.max() == x.min(): 10 | return x - x.min() 11 | else: 12 | return (x - x.min()) / (x.max() - x.min()) 13 | 14 | 15 | def linear_sample(p_range): 16 | if isinstance(p_range, float): 17 | return p_range 18 | else: 19 | return p_range[0] + random.random()*(p_range[1] - p_range[0]) 20 | 21 | 22 | def log_sample(p_range): 23 | if isinstance(p_range, float): 24 | return p_range 25 | else: 26 | return math.exp(math.log(p_range[0]) + random.random()*(math.log(p_range[1]) - math.log(p_range[0]))) 27 | 28 | 29 | def categorical_sample(p_range): 30 | if isinstance(p_range, (float, int)): 31 | return p_range 32 | else: 33 | return p_range[np.random.randint(len(p_range))] 34 | 35 | 36 | def rand_bbox(size, lam): 37 | H, W = size 38 | cut_rat = np.sqrt(1. - lam) 39 | cut_w = int(W * cut_rat) 40 | cut_h = int(H * cut_rat) 41 | 42 | # uniform 43 | cx = np.random.randint(W) 44 | cy = np.random.randint(H) 45 | 46 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 47 | bby1 = np.clip(cy - cut_h // 2, 0, H) 48 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 49 | bby2 = np.clip(cy + cut_h // 2, 0, H) 50 | 51 | return bbx1, bby1, bbx2, bby2 52 | 53 | 54 | class Augmentation: 55 | pass 56 | 57 | 58 | class RandomHorizontalFlip(Augmentation): 59 | def __init__(self): 60 | self.augmentation = lambda x: torch.flip(x, dims=[-1]) 61 | 62 | def __str__(self): 63 | return 'RandomHorizontalFlip Augmentation' 64 | 65 | def __call__(self, *arrays, get_augs=False): 66 | if random.random() < 0.5: 67 | if len(arrays) == 1: 68 | if get_augs: 69 | return self.augmentation(arrays[0]), self.augmentation 70 | else: 71 | return self.augmentation(arrays[0]) 72 | else: 73 | arrays_flipped = [] 74 | for array in arrays: 75 | arrays_flipped.append(self.augmentation(array)) 76 | if get_augs: 77 | return arrays_flipped, self.augmentation 78 | else: 79 | return arrays_flipped 80 | else: 81 | if len(arrays) == 1: 82 | if get_augs: 83 | return arrays[0], lambda x: x 84 | else: 85 | return arrays[0] 86 | else: 87 | if get_augs: 88 | return arrays, lambda x: x 89 | else: 90 | return arrays 91 | 92 | 93 | class RandomCompose(Augmentation): 94 | def __init__(self, augmentations, n_aug=2, p=0.5, verbose=False): 95 | assert len(augmentations) >= n_aug 96 | self.augmentations = augmentations 97 | self.n_aug = n_aug 98 | self.p = p 99 | self.verbose = verbose # for debugging 100 | 101 | def __call__(self, label, mask, get_augs=False): 102 | augmentations = [ 103 | self.augmentations[i] 104 | for i in np.random.choice(len(self.augmentations), size=self.n_aug, 
replace=False) 105 | ] 106 | 107 | for augmentation in augmentations: 108 | if random.random() < self.p: 109 | label, mask = augmentation(label, mask) 110 | if self.verbose: 111 | print(augmentation) 112 | elif self.verbose: 113 | print('skipped') 114 | 115 | if get_augs: 116 | return label, mask, augmentations 117 | else: 118 | return label, mask 119 | 120 | 121 | class RandomJitter(Augmentation): 122 | def __init__(self, brightness, contrast): 123 | self.brightness = brightness 124 | self.contrast = contrast 125 | 126 | def __str__(self): 127 | return f'RandomJitter Augmentation (brightness = {self.brightness}, contrast = {self.contrast})' 128 | 129 | def __call__(self, label, mask): 130 | brightness = linear_sample(self.brightness) 131 | contrast = linear_sample(self.contrast) 132 | 133 | alpha = 1 + contrast 134 | beta = brightness 135 | 136 | label = alpha * label + beta 137 | label = torch.clip(label, 0, 1) 138 | label = normalize(label) 139 | 140 | return label, mask 141 | 142 | 143 | class RandomPolynomialTransform(Augmentation): 144 | def __init__(self, degree): 145 | self.degree = degree 146 | 147 | def __str__(self): 148 | return f'RandomPolynomialTransform Augmentation (degree = {self.degree})' 149 | 150 | def __call__(self, label, mask): 151 | degree = log_sample(self.degree) 152 | 153 | label = label.pow(degree) 154 | label = normalize(label) 155 | return label, mask 156 | 157 | 158 | class RandomSigmoidTransform(Augmentation): 159 | def __init__(self, temperature): 160 | self.temperature = temperature 161 | 162 | def __str__(self): 163 | return f'RandomSigmoidTransform Augmentation (temperature = {self.temperature})' 164 | 165 | def __call__(self, label, mask): 166 | cast = False 167 | if label.dtype != torch.float32: 168 | dtype = label.dtype 169 | cast = True 170 | label = label.float() 171 | 172 | temperature = categorical_sample(self.temperature) 173 | 174 | label = torch.sigmoid(label / temperature) 175 | label = normalize(label) 176 | 177 | if cast: 178 | label = label.to(dtype) 179 | 180 | return label, mask 181 | 182 | 183 | class RandomGaussianBlur(Augmentation): 184 | def __init__(self, kernel_size, sigma): 185 | self.kernel_size = kernel_size 186 | self.sigma = sigma 187 | 188 | def __str__(self): 189 | return f'RandomGaussianBlur Augmentation (kernel_size = {self.kernel_size}, sigma = {self.sigma})' 190 | 191 | def __call__(self, label, mask): 192 | cast = False 193 | if label.dtype != torch.float32: 194 | dtype = label.dtype 195 | cast = True 196 | label = label.float() 197 | 198 | kernel_size = [categorical_sample(self.kernel_size)]*2 199 | sigma = categorical_sample(self.sigma) 200 | 201 | label = gaussian_blur(label, kernel_size, sigma) 202 | label = normalize(label) 203 | 204 | if cast: 205 | label = label.to(dtype) 206 | 207 | return label, mask 208 | 209 | 210 | class BinaryAugmentation(Augmentation): 211 | pass 212 | 213 | 214 | class Mixup(BinaryAugmentation): 215 | def __init__(self, alpha=1.0): 216 | self.alpha = alpha 217 | 218 | def __call__(self, label_1, label_2, mask_1, mask_2): 219 | lam = np.random.beta(self.alpha, self.alpha) 220 | label_mix = lam*label_1 + (1 - lam)*label_2 221 | mask_mix = torch.logical_and(mask_1, mask_2) 222 | 223 | return label_mix, mask_mix 224 | 225 | 226 | class Cutmix(BinaryAugmentation): 227 | def __init__(self, alpha=1.0): 228 | self.alpha = alpha 229 | 230 | def __call__(self, label_1, label_2, mask_1, mask_2): 231 | assert label_1.size() == label_2.size() 232 | 233 | lam = np.random.beta(self.alpha, self.alpha) 234 
| bbx1, bby1, bbx2, bby2 = rand_bbox(label_1.size()[-2:], lam) 235 | 236 | label_mix = label_1.clone() 237 | label_mix[:, :, bbx1:bbx2, bby1:bby2] = label_2[:, :, bbx1:bbx2, bby1:bby2] 238 | mask_mix = mask_1.clone() 239 | mask_mix[:, :, bbx1:bbx2, bby1:bby2] = mask_2[:, :, bbx1:bbx2, bby1:bby2] 240 | 241 | return label_mix, mask_mix 242 | 243 | 244 | FILTERING_AUGMENTATIONS = { 245 | 'jitter': (RandomJitter, {"brightness": (-0.5, 0.5), 246 | "contrast": (-0.5, 0.5)}), 247 | 'polynomial': (RandomPolynomialTransform, {"degree": (1.0/3, 3.0)}), 248 | 'sigmoid': (RandomSigmoidTransform, {"temperature": [0.1, 0.2, 0.5, 2e5, 5e5, 1e6, 2e6]}), 249 | 'gaussianblur': (RandomGaussianBlur, {"kernel_size": [9, 17, 33], 250 | "sigma": [0.5, 1.0, 2.0, 5.0, 10.0]}), 251 | } -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchvision.transforms as T 4 | import numpy as np 5 | 6 | 7 | def crop_arrays(*arrays, base_size=(256, 256), img_size=(224, 224), random=True, get_offsets=False, offset_cuts=None): 8 | ''' 9 | Crop arrays from base_size to img_size. 10 | Apply center crop if not random. 11 | ''' 12 | if base_size[0] == img_size[0] and base_size[1] == img_size[1]: 13 | if get_offsets: 14 | return arrays, (0, 0) 15 | else: 16 | return arrays 17 | 18 | if random: 19 | if offset_cuts is not None: 20 | off_H = np.random.randint(min(base_size[0] - img_size[0], offset_cuts[0])) 21 | off_W = np.random.randint(min(base_size[1] - img_size[1], offset_cuts[1])) 22 | else: 23 | off_H = np.random.randint(base_size[0] - img_size[0]) 24 | off_W = np.random.randint(base_size[1] - img_size[1]) 25 | else: 26 | if offset_cuts is not None: 27 | off_H = min( 28 | max(0, offset_cuts[0] - (base_size[0] - img_size[0]) // 2), 29 | (base_size[0] - img_size[0]) // 2 30 | ) 31 | off_W = min( 32 | max(0, offset_cuts[1] - (base_size[1] - img_size[1]) // 2), 33 | (base_size[1] - img_size[1]) // 2 34 | ) 35 | else: 36 | off_H = (base_size[0] - img_size[0]) // 2 37 | off_W = (base_size[1] - img_size[1]) // 2 38 | 39 | slice_H = slice(off_H, off_H + img_size[0]) 40 | slice_W = slice(off_W, off_W + img_size[1]) 41 | 42 | arrays_cropped = [] 43 | for array in arrays: 44 | if array is not None: 45 | assert array.ndim >= 2 46 | array_cropped = array[..., slice_H, slice_W] 47 | arrays_cropped.append(array_cropped) 48 | else: 49 | arrays_cropped.append(array) 50 | 51 | if get_offsets: 52 | return arrays_cropped, (off_H, off_W) 53 | else: 54 | return arrays_cropped 55 | 56 | 57 | def mix_fivecrop(x_crop, base_size=256, crop_size=224): 58 | margin = base_size - crop_size 59 | submargin = margin // 2 60 | 61 | ### Five-pad each crops 62 | pads = [ 63 | T.Pad((0, 0, margin, margin)), 64 | T.Pad((margin, 0, 0, margin)), 65 | T.Pad((0, margin, margin, 0)), 66 | T.Pad((margin, margin, 0, 0)), 67 | T.Pad((submargin, submargin, submargin, submargin)), 68 | ] 69 | x_pad = [] 70 | for x_, pad in zip(x_crop, pads): 71 | x_pad.append(pad(x_)) 72 | x_pad = torch.stack(x_pad) 73 | 74 | x_avg = torch.zeros_like(x_pad[0]) 75 | 76 | ### Mix padded crops 77 | # top-left corner 78 | x_avg[:, :, :submargin, :margin] = x_pad[0][:, :, :submargin, :margin] 79 | x_avg[:, :, submargin:margin, :submargin] = x_pad[0][:, :, submargin:margin, :submargin] 80 | x_avg[:, :, submargin:margin, submargin:margin] = (x_pad[0][:, :, submargin:margin, submargin:margin] + \ 81 | 
x_pad[4][:, :, submargin:margin, submargin:margin]) / 2 82 | 83 | # top-right corner 84 | x_avg[:, :, :submargin, -margin:] = x_pad[1][:, :, :submargin, -margin:] 85 | x_avg[:, :, submargin:margin, -submargin:] = x_pad[1][:, :, submargin:margin, -submargin:] 86 | x_avg[:, :, submargin:margin, -margin:-submargin] = (x_pad[1][:, :, submargin:margin, -margin:-submargin] + \ 87 | x_pad[4][:, :, submargin:margin, -margin:-submargin]) / 2 88 | 89 | # bottom-left corner 90 | x_avg[:, :, -submargin:, :margin] = x_pad[2][:, :, -submargin:, :margin] 91 | x_avg[:, :, -margin:-submargin:, :submargin] = x_pad[2][:, :, -margin:-submargin, :submargin] 92 | x_avg[:, :, -margin:-submargin, submargin:margin] = (x_pad[2][:, :, -margin:-submargin, submargin:margin] + \ 93 | x_pad[4][:, :, -margin:-submargin, submargin:margin]) / 2 94 | 95 | # bottom-left corner 96 | x_avg[:, :, -submargin:, -margin:] = x_pad[3][:, :, -submargin:, -margin:] 97 | x_avg[:, :, -margin:-submargin, -submargin:] = x_pad[3][:, :, -margin:-submargin, -submargin:] 98 | x_avg[:, :, -margin:-submargin, -margin:-submargin] = (x_pad[3][:, :, -margin:-submargin, -margin:-submargin] + \ 99 | x_pad[4][:, :, -margin:-submargin, -margin:-submargin]) / 2 100 | 101 | # top side 102 | x_avg[:, :, :submargin, margin:-margin] = (x_pad[0][:, :, :submargin, margin:-margin] + \ 103 | x_pad[1][:, :, :submargin, margin:-margin]) / 2 104 | x_avg[:, :, submargin:margin, margin:-margin] = (x_pad[0][:, :, submargin:margin, margin:-margin] + \ 105 | x_pad[1][:, :, submargin:margin, margin:-margin] + \ 106 | x_pad[4][:, :, submargin:margin, margin:-margin]) / 3 107 | 108 | # right side 109 | x_avg[:, :, margin:-margin, -submargin:] = (x_pad[1][:, :, margin:-margin, -submargin:] + \ 110 | x_pad[3][:, :, margin:-margin, -submargin:]) / 2 111 | x_avg[:, :, margin:-margin, -margin:-submargin] = (x_pad[1][:, :, margin:-margin, -margin:-submargin] + \ 112 | x_pad[3][:, :, margin:-margin, -margin:-submargin] + \ 113 | x_pad[4][:, :, margin:-margin, -margin:-submargin]) / 3 114 | 115 | # bottom side 116 | x_avg[:, :, -submargin:, margin:-margin] = (x_pad[2][:, :, -submargin:, margin:-margin] + \ 117 | x_pad[3][:, :, -submargin:, margin:-margin]) / 2 118 | x_avg[:, :, -margin:-submargin:, margin:-margin] = (x_pad[2][:, :, -margin:-submargin, margin:-margin] + \ 119 | x_pad[3][:, :, -margin:-submargin, margin:-margin] + \ 120 | x_pad[4][:, :, -margin:-submargin, margin:-margin]) / 3 121 | 122 | # left side 123 | x_avg[:, :, margin:-margin, :submargin] = (x_pad[0][:, :, margin:-margin, :submargin] + \ 124 | x_pad[2][:, :, margin:-margin, :submargin]) / 2 125 | x_avg[:, :, margin:-margin, submargin:margin] = (x_pad[0][:, :, margin:-margin, submargin:margin] + \ 126 | x_pad[2][:, :, margin:-margin, submargin:margin] + \ 127 | x_pad[4][:, :, margin:-margin, submargin:margin]) / 3 128 | 129 | # center 130 | x_avg[:, :, margin:-margin, margin:-margin] = (x_pad[0][:, :, margin:-margin, margin:-margin] + \ 131 | x_pad[1][:, :, margin:-margin, margin:-margin] + \ 132 | x_pad[2][:, :, margin:-margin, margin:-margin] + \ 133 | x_pad[3][:, :, margin:-margin, margin:-margin] + \ 134 | x_pad[4][:, :, margin:-margin, margin:-margin]) / 5 135 | 136 | return x_avg 137 | 138 | 139 | def to_device(data, device=None, dtype=None): 140 | ''' 141 | Load data with arbitrary structure on device. 
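    Tensors nested inside tuples, lists, and dicts are moved recursively to the given device/dtype;
    non-tensor leaves are returned unchanged.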
142 | ''' 143 | def to_device_wrapper(data): 144 | if isinstance(data, torch.Tensor): 145 | return data.to(device=device, dtype=dtype) 146 | elif isinstance(data, tuple): 147 | return tuple(map(to_device_wrapper, data)) 148 | elif isinstance(data, list): 149 | return list(map(to_device_wrapper, data)) 150 | elif isinstance(data, dict): 151 | return {key: to_device_wrapper(data[key]) for key in data} 152 | else: 153 | return data 154 | 155 | return to_device_wrapper(data) 156 | 157 | 158 | def pad_by_reflect(x, padding=1): 159 | x = torch.cat((x[..., :padding], x, x[..., -padding:]), dim=-1) 160 | x = torch.cat((x[..., :padding, :], x, x[..., -padding:, :]), dim=-2) 161 | return x 162 | 163 | 164 | class SobelEdgeDetector: 165 | def __init__(self, kernel_size=5, sigma=1): 166 | self.kernel_size = kernel_size 167 | self.sigma = sigma 168 | 169 | # compute gaussian kernel 170 | size = kernel_size // 2 171 | x, y = np.mgrid[-size:size+1, -size:size+1] 172 | normal = 1 / (2.0 * np.pi * sigma**2) 173 | g = np.exp(-((x**2 + y**2) / (2.0*sigma**2))) * normal 174 | 175 | self.gaussian_kernel = torch.from_numpy(g)[None, None, :, :].float() 176 | self.Kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float)[None, None, :, :] 177 | self.Ky = -torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float)[None, None, :, :] 178 | 179 | def detect(self, img, normalize=True): 180 | squeeze = False 181 | if len(img.shape) == 3: 182 | img = img[None, ...] 183 | squeeze = True 184 | 185 | img = pad_by_reflect(img, padding=self.kernel_size//2) 186 | img = F.conv2d(img, self.gaussian_kernel.repeat(1, img.size(1), 1, 1)) 187 | 188 | img = pad_by_reflect(img, padding=1) 189 | Gx = F.conv2d(img, self.Kx) 190 | Gy = F.conv2d(img, self.Ky) 191 | 192 | G = (Gx.pow(2) + Gy.pow(2)).pow(0.5) 193 | if normalize: 194 | G = G / G.max() 195 | if squeeze: 196 | G = G[0] 197 | 198 | return G 199 | -------------------------------------------------------------------------------- /dataset/dataloader_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import os 4 | from einops import rearrange, repeat 5 | 6 | from .taskonomy import TaskonomyHybridDataset, TaskonomyContinuousDataset, TaskonomySegmentationDataset, TaskonomyFinetuneDataset 7 | from .taskonomy_constants import TASKS, TASKS_GROUP_DICT, TASKS_GROUP_TRAIN, TASKS_GROUP_NAMES, \ 8 | BUILDINGS, BUILDINGS_TRAIN, BUILDINGS_VALID, BUILDINGS_TEST, SEMSEG_CLASSES 9 | from .utils import crop_arrays 10 | 11 | 12 | base_sizes = { 13 | 224: (256, 256) 14 | } 15 | 16 | 17 | def get_train_loader(config, pin_memory=True, verbose=True, get_support_data=False): 18 | ''' 19 | Load training dataloader. 
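    For stage 0, this builds a TaskonomyHybridDataset for episodic training over the training task group;
    otherwise it builds a TaskonomyFinetuneDataset for the single fine-tuning task.
    If get_support_data is True, a single support batch is returned instead of a DataLoader.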
20 | ''' 21 | # set dataset size 22 | if get_support_data: 23 | dset_size = config.shot 24 | elif config.no_eval: 25 | dset_size = config.n_steps*config.global_batch_size 26 | else: 27 | dset_size = config.val_iter*config.global_batch_size 28 | 29 | # compute common arguments 30 | common_kwargs = { 31 | 'base_size': base_sizes[config.img_size], 32 | 'img_size': (config.img_size, config.img_size), 33 | 'dset_size': dset_size, 34 | 'seed': config.seed + int(os.environ.get('LOCAL_RANK', 0)), 35 | 'precision': config.precision, 36 | 'root_dir': config.root_dir, 37 | 'buildings': BUILDINGS_TRAIN, 38 | } 39 | 40 | # create dataset for episodic training 41 | if config.stage == 0: 42 | tasks = TASKS if config.task == 'all' else TASKS_GROUP_TRAIN[config.task_fold] 43 | if verbose: 44 | print(f'Loading tasks {", ".join(tasks)} in train split.') 45 | 46 | # create training dataset. 47 | train_data = TaskonomyHybridDataset( 48 | tasks=tasks, 49 | shot=config.shot, 50 | tasks_per_batch=config.max_channels, 51 | domains_per_batch=config.domains_per_batch, 52 | image_augmentation=config.image_augmentation, 53 | unary_augmentation=config.unary_augmentation, 54 | binary_augmentation=config.binary_augmentation, 55 | mixed_augmentation=config.mixed_augmentation, 56 | **common_kwargs, 57 | ) 58 | # create dataset for fine-tuning or testing 59 | else: 60 | if config.task in ['', 'all']: 61 | raise ValueError("task should be specified for fine-tuning") 62 | 63 | train_data = TaskonomyFinetuneDataset( 64 | task=config.task, 65 | shot=config.shot, 66 | support_idx=config.support_idx, 67 | channel_idx=config.channel_idx, 68 | image_augmentation=(config.image_augmentation and config.task != 'normal'), 69 | **common_kwargs 70 | ) 71 | if get_support_data: 72 | train_data.fix_seed = True 73 | if config.stage == 1: 74 | train_data.img_size = base_sizes[config.img_size] 75 | support_loader = DataLoader(train_data, batch_size=1, shuffle=False, drop_last=False) 76 | for support_data in support_loader: 77 | break 78 | 79 | return support_data 80 | 81 | 82 | # create training loader. 83 | train_loader = DataLoader(train_data, batch_size=(config.global_batch_size // torch.cuda.device_count()), 84 | shuffle=False, pin_memory=pin_memory, 85 | drop_last=True, num_workers=config.num_workers) 86 | 87 | return train_loader 88 | 89 | 90 | def get_eval_loader(config, task, split='valid', channel_idx=-1, pin_memory=True, verbose=True): 91 | ''' 92 | Load evaluation dataloader. 93 | ''' 94 | # no crop for evaluation. 95 | img_size = base_size = base_sizes[config.img_size] 96 | 97 | # choose appropriate split. 98 | if split == 'train': 99 | buildings = BUILDINGS_TRAIN 100 | elif split == 'valid': 101 | buildings = BUILDINGS_VALID 102 | elif split == 'test': 103 | buildings = BUILDINGS_TEST 104 | elif split in BUILDINGS: 105 | buildings = [split] 106 | 107 | # evaluate some subset or the whole data. 108 | if config.n_eval_batches > 0: 109 | dset_size = config.n_eval_batches * config.eval_batch_size 110 | else: 111 | dset_size = -1 112 | 113 | # common arguments for both continuous and segmentation datasets. 
114 | common_kwargs = { 115 | 'root_dir': config.root_dir, 116 | 'buildings': buildings, 117 | 'dset_size': dset_size, 118 | 'base_size': base_size, 119 | 'img_size': img_size, 120 | 'seed': int(os.environ.get('LOCAL_RANK', 0)), 121 | 'precision': config.precision, 122 | } 123 | if verbose: 124 | if channel_idx < 0: 125 | print(f'Loading task {task} in {split} split.') 126 | else: 127 | print(f'Loading task {task}_{channel_idx} in {split} split.') 128 | 129 | # create appropriate dataset. 130 | if task == 'segment_semantic': 131 | assert channel_idx in SEMSEG_CLASSES 132 | eval_data = TaskonomySegmentationDataset( 133 | semseg_class=channel_idx, 134 | **common_kwargs 135 | ) 136 | else: 137 | eval_data = TaskonomyContinuousDataset( 138 | task=task, 139 | channel_idx=channel_idx, 140 | **common_kwargs 141 | ) 142 | 143 | # create dataloader. 144 | eval_loader = DataLoader(eval_data, batch_size=(config.eval_batch_size // torch.cuda.device_count()), 145 | shuffle=False, pin_memory=pin_memory, 146 | drop_last=False, num_workers=1) 147 | 148 | return eval_loader 149 | 150 | 151 | def get_validation_loaders(config, verbose=True): 152 | ''' 153 | Load validation loaders (of unseen images) for training tasks. 154 | ''' 155 | if config.stage == 0: 156 | if config.task == 'all': 157 | train_tasks = TASKS_GROUP_DICT 158 | else: 159 | train_tasks = TASKS_GROUP_TRAIN[config.task_fold] 160 | loader_tag = 'mtrain_valid' 161 | else: 162 | if config.task in ['', 'all']: 163 | raise ValueError("task should be specified for fine-tuning") 164 | train_tasks = [config.task] 165 | loader_tag = 'mtest_valid' 166 | 167 | valid_loaders = {} 168 | for task in train_tasks: 169 | if task == 'segment_semantic': 170 | if config.channel_idx < 0: 171 | channels = SEMSEG_CLASSES 172 | else: 173 | channels = [config.channel_idx] 174 | for c in channels: 175 | valid_loaders[f'segment_semantic_{c}'] = get_eval_loader(config, task, 'valid', c, verbose=verbose) 176 | else: 177 | valid_loaders[task] = get_eval_loader(config, task, 'valid', verbose=verbose) 178 | 179 | return valid_loaders, loader_tag 180 | 181 | 182 | def generate_support_data(config, data_path, split='train', support_idx=0, verbose=True): 183 | ''' 184 | Generate support data for all tasks. 
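    Support tuples (X, Y, M, t_idx) are built per task (one per class for segment_semantic) using a
    deterministic center crop, cached to data_path with torch.save, and reloaded on later calls if the file exists.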
185 | ''' 186 | if os.path.exists(data_path): 187 | support_data = torch.load(data_path) 188 | else: 189 | support_data = {} 190 | 191 | modified = False 192 | base_size = img_size = base_sizes[config.img_size] 193 | 194 | if split == 'train': 195 | buildings = BUILDINGS_TRAIN 196 | elif split == 'valid': 197 | buildings = BUILDINGS_VALID 198 | elif split == 'test': 199 | buildings = BUILDINGS_TEST 200 | else: 201 | raise ValueError(split) 202 | 203 | common_kwargs = { 204 | 'root_dir': config.root_dir, 205 | 'buildings': buildings, 206 | 'base_size': base_size, 207 | 'img_size': img_size, 208 | 'seed': int(os.environ.get('LOCAL_RANK', 0)), 209 | 'precision': 'fp32', 210 | } 211 | 212 | for task in TASKS_GROUP_NAMES: 213 | if task == 'segment_semantic': 214 | for c in SEMSEG_CLASSES: 215 | if f'segment_semantic_{c}' in support_data: 216 | continue 217 | 218 | dset = TaskonomySegmentationDataset( 219 | semseg_class=c, 220 | **common_kwargs 221 | ) 222 | dloader = DataLoader(dset, batch_size=config.shot, shuffle=False, num_workers=0) 223 | for idx, batch in enumerate(dloader): 224 | if idx == support_idx: 225 | break 226 | 227 | X, Y, M = batch 228 | 229 | T = Y.size(1) 230 | X = repeat(X, 'N C H W -> 1 T N C H W', T=T) 231 | Y = rearrange(Y, 'N T H W -> 1 T N 1 H W') 232 | M = rearrange(M, 'N T H W -> 1 T N 1 H W') 233 | X, Y, M = crop_arrays(X, Y, M, base_size=base_size, img_size=(config.img_size, config.img_size), 234 | random=False) 235 | 236 | t_idx = torch.tensor([[TASKS.index(f'segment_semantic_{c}')]]) 237 | 238 | support_data[f'segment_semantic_{c}'] = (X, Y, M, t_idx) 239 | if verbose: 240 | print(f'generated support data for task segment_semantic_{c}') 241 | modified = True 242 | else: 243 | if task in support_data: 244 | continue 245 | 246 | dset = TaskonomyContinuousDataset( 247 | task=task, 248 | **common_kwargs 249 | ) 250 | 251 | dloader = DataLoader(dset, batch_size=config.shot, shuffle=False, num_workers=0) 252 | for idx, batch in enumerate(dloader): 253 | if idx == support_idx: 254 | break 255 | 256 | X, Y, M = batch 257 | T = Y.size(1) 258 | X = repeat(X, 'N C H W -> 1 T N C H W', T=T) 259 | Y = rearrange(Y, 'N T H W -> 1 T N 1 H W') 260 | M = rearrange(M, 'N T H W -> 1 T N 1 H W') 261 | X, Y, M = crop_arrays(X, Y, M, base_size=base_size, img_size=(config.img_size, config.img_size), 262 | random=False) 263 | 264 | t_idx = torch.tensor([[TASKS.index(f'{task}_{c}') for c in range(len(TASKS_GROUP_DICT[task]))]]) 265 | 266 | support_data[task] = (X, Y, M, t_idx) 267 | if verbose: 268 | print(f'generated support data for task {task}') 269 | modified = True 270 | 271 | if modified: 272 | torch.save(support_data, data_path) 273 | 274 | return support_data -------------------------------------------------------------------------------- /model/beit/beit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | import numpy as np 5 | from scipy import interpolate 6 | 7 | from .beit_factory import create_model as create_custom_model 8 | 9 | 10 | class BEiTEncoder(nn.Module): 11 | def __init__(self, model_name='beit_base_patch16_224_in22k', 12 | drop_rate=0.0, drop_path_rate=0.1, attn_drop_rate=0.0, 13 | n_tasks=0, bitfit=True, n_levels=1): 14 | super().__init__() 15 | self.beit = create_custom_model( 16 | model_name, 17 | pretrained=False, 18 | num_classes=0, 19 | drop_rate=drop_rate, 20 | drop_path_rate=drop_path_rate, 21 | attn_drop_rate=attn_drop_rate, 22 | init_scale=0.001, 
23 | n_tasks=(n_tasks if bitfit else 0), 24 | ) 25 | 26 | self.model_name = model_name 27 | self.img_size = self.beit.patch_embed.img_size 28 | self.grid_size = self.beit.patch_embed.grid_size 29 | self.patch_size = self.beit.patch_embed.patch_size 30 | self.embed_dim = self.beit.embed_dim 31 | self.n_tasks = n_tasks 32 | self.n_levels = n_levels 33 | self.feature_blocks = [level * (len(self.beit.blocks) // self.n_levels) - 1 for level in range(1, self.n_levels+1)] 34 | 35 | def bias_parameters(self): 36 | for key, param in self.beit.named_parameters(): 37 | if key.split('.')[0] == 'blocks' and key.split('.')[-1] == 'bias' and key.split('.')[-3] != 'patch_embed': 38 | yield param 39 | 40 | def bias_parameter_names(self): 41 | names = [] 42 | for key, _ in self.beit.named_parameters(): 43 | if key.split('.')[0] == 'blocks' and key.split('.')[-1] == 'bias' and key.split('.')[-3] != 'patch_embed': 44 | names.append(f'beit.{key}') 45 | return names 46 | 47 | def tokenize(self, x): 48 | # project image patches to tokens 49 | x = self.beit.patch_embed(x) 50 | 51 | # add CLS token 52 | x = torch.cat((self.beit.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 53 | 54 | # positional embedding 55 | if self.beit.pos_embed is not None: 56 | x = x + self.beit.pos_embed 57 | x = self.beit.pos_drop(x) 58 | 59 | return x 60 | 61 | def forward(self, x, t_idx=None, get_features=False): 62 | x = self.tokenize(x) 63 | rel_pos_bias = self.beit.rel_pos_bias() if self.beit.rel_pos_bias is not None else None 64 | 65 | if get_features: 66 | features = [] 67 | 68 | # transformer blocks 69 | for blk_idx in range(len(self.beit.blocks)): 70 | x = self.beit.blocks[blk_idx](x, rel_pos_bias=rel_pos_bias, t_idx=t_idx) 71 | 72 | if get_features and blk_idx in self.feature_blocks: 73 | feature = x[:, 1:] 74 | feature = rearrange(feature, 'B (H W) C -> B C H W', H=self.grid_size[0]).contiguous() 75 | features.append(feature) 76 | 77 | if get_features: 78 | return features 79 | else: 80 | # cut off CLS token, then rearrange into spatial maps 81 | x = x[:, 1:] 82 | x = rearrange(x, 'B (H W) C -> B C H W', H=self.grid_size[0], W=self.grid_size[1]).contiguous() 83 | 84 | return x 85 | 86 | 87 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 88 | missing_keys = [] 89 | unexpected_keys = [] 90 | error_msgs = [] 91 | # copy state_dict so _load_from_state_dict can modify it 92 | metadata = getattr(state_dict, '_metadata', None) 93 | state_dict = state_dict.copy() 94 | if metadata is not None: 95 | state_dict._metadata = metadata 96 | 97 | def load(module, prefix=''): 98 | local_metadata = {} if metadata is None else metadata.get( 99 | prefix[:-1], {}) 100 | module._load_from_state_dict( 101 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 102 | for name, child in module._modules.items(): 103 | if child is not None: 104 | load(child, prefix + name + '.') 105 | 106 | load(model, prefix=prefix) 107 | 108 | warn_missing_keys = [] 109 | ignore_missing_keys = [] 110 | for key in missing_keys: 111 | keep_flag = True 112 | for ignore_key in ignore_missing.split('|'): 113 | if ignore_key in key: 114 | keep_flag = False 115 | break 116 | if keep_flag: 117 | warn_missing_keys.append(key) 118 | else: 119 | ignore_missing_keys.append(key) 120 | 121 | missing_keys = warn_missing_keys 122 | 123 | if len(error_msgs) > 0: 124 | print('\n'.join(error_msgs)) 125 | 126 | 127 | def load_beit_ckpt(model, ckpt_path, n_bitfit_tasks=0, verbose=True): 128 | model_key = 
'model|module' 129 | checkpoint = torch.load(ckpt_path, map_location='cpu') 130 | 131 | if verbose: 132 | print("Load ckpt from %s" % ckpt_path) 133 | checkpoint_model = None 134 | for model_key in model_key.split('|'): 135 | if model_key in checkpoint: 136 | checkpoint_model = checkpoint[model_key] 137 | if verbose: 138 | print("Load state_dict by model_key = %s" % model_key) 139 | break 140 | if checkpoint_model is None: 141 | checkpoint_model = checkpoint 142 | state_dict = model.state_dict() 143 | for k in ['head.weight', 'head.bias']: 144 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 145 | if verbose: 146 | print(f"Removing key {k} from pretrained checkpoint") 147 | del checkpoint_model[k] 148 | 149 | if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model: 150 | if verbose: 151 | print("Expand the shared relative position embedding to each transformer block. ") 152 | num_layers = model.get_num_layers() 153 | rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"] 154 | for i in range(num_layers): 155 | checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone() 156 | 157 | checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") 158 | 159 | all_keys = list(checkpoint_model.keys()) 160 | for key in all_keys: 161 | if "relative_position_index" in key: 162 | checkpoint_model.pop(key) 163 | 164 | if "relative_position_bias_table" in key: 165 | rel_pos_bias = checkpoint_model[key] 166 | src_num_pos, num_attn_heads = rel_pos_bias.size() 167 | dst_num_pos, _ = model.state_dict()[key].size() 168 | try: 169 | dst_patch_shape = model.patch_embed.patch_size 170 | except AttributeError: 171 | dst_patch_shape = model.patch_embed[0].patch_size 172 | if dst_patch_shape[0] != dst_patch_shape[1]: 173 | raise NotImplementedError() 174 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) 175 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5) 176 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) 177 | if src_size != dst_size: 178 | if verbose: 179 | print("Position interpolate for %s from %dx%d to %dx%d" % ( 180 | key, src_size, src_size, dst_size, dst_size)) 181 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :] 182 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] 183 | 184 | def geometric_progression(a, r, n): 185 | return a * (1.0 - r ** n) / (1.0 - r) 186 | 187 | left, right = 1.01, 1.5 188 | while right - left > 1e-6: 189 | q = (left + right) / 2.0 190 | gp = geometric_progression(1, q, src_size // 2) 191 | if gp > dst_size // 2: 192 | right = q 193 | else: 194 | left = q 195 | 196 | # if q > 1.090307: 197 | # q = 1.090307 198 | 199 | dis = [] 200 | cur = 1 201 | for i in range(src_size // 2): 202 | dis.append(cur) 203 | cur += q ** (i + 1) 204 | 205 | r_ids = [-_ for _ in reversed(dis)] 206 | 207 | x = r_ids + [0] + dis 208 | y = r_ids + [0] + dis 209 | 210 | t = dst_size // 2.0 211 | dx = np.arange(-t, t + 0.1, 1.0) 212 | dy = np.arange(-t, t + 0.1, 1.0) 213 | 214 | if verbose: 215 | print("Original positions = %s" % str(x)) 216 | print("Target positions = %s" % str(dx)) 217 | 218 | all_rel_pos_bias = [] 219 | 220 | for i in range(num_attn_heads): 221 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() 222 | f = interpolate.interp2d(x, y, z, kind='cubic') 223 | all_rel_pos_bias.append( 224 | torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) 225 | 226 | rel_pos_bias 
= torch.cat(all_rel_pos_bias, dim=-1) 227 | 228 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) 229 | checkpoint_model[key] = new_rel_pos_bias 230 | 231 | # interpolate position embedding 232 | if 'pos_embed' in checkpoint_model: 233 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 234 | embedding_size = pos_embed_checkpoint.shape[-1] 235 | num_patches = model.patch_embed.num_patches 236 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 237 | # height (== width) for the checkpoint position embedding 238 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 239 | # height (== width) for the new position embedding 240 | new_size = int(num_patches ** 0.5) 241 | # class_token and dist_token are kept unchanged 242 | if orig_size != new_size: 243 | if verbose: 244 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 245 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 246 | # only the position tokens are interpolated 247 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 248 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 249 | pos_tokens = torch.nn.functional.interpolate( 250 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 251 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 252 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 253 | checkpoint_model['pos_embed'] = new_pos_embed 254 | 255 | 256 | if n_bitfit_tasks > 0: 257 | for key in checkpoint_model: 258 | if key.split('.')[0] == 'blocks' and key.split('.')[-1] == 'bias' and key.split('.')[-3] != 'patch_embed': 259 | checkpoint_model[key] = torch.stack([checkpoint_model[key] for _ in range(n_bitfit_tasks)]).contiguous() 260 | 261 | load_state_dict(model, checkpoint_model) 262 | -------------------------------------------------------------------------------- /train/train_utils.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar 3 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger 4 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 5 | from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO 6 | 7 | import os 8 | import sys 9 | import shutil 10 | import random 11 | import tqdm 12 | 13 | import numpy as np 14 | import torch 15 | 16 | from .trainer import LightningTrainWrapper 17 | from dataset.taskonomy_constants import TASKS, TASKS_GROUP_DICT 18 | 19 | 20 | def configure_experiment(config, model): 21 | # set seeds 22 | set_seeds(config.seed) 23 | 24 | # set directories 25 | log_dir, save_dir = set_directories(config, 26 | exp_name=config.exp_name, 27 | exp_subname=(config.exp_subname if config.stage >= 1 else ''), 28 | create_save_dir=(config.stage <= 1)) 29 | 30 | # create lightning callbacks, logger, and checkpoint plugin 31 | if config.stage <= 1: 32 | callbacks = set_callbacks(config, save_dir, config.monitor, ptf=config.save_postfix) 33 | logger = CustomTBLogger(log_dir, name='', version='', default_hp_metric=False) 34 | else: 35 | callbacks = set_callbacks(config, save_dir) 36 | logger = None 37 | 38 | # parse precision 39 | precision = int(config.precision.strip('fp')) if config.precision in ['fp16', 'fp32'] else config.precision 40 | 41 | # choose accelerator 42 | strategy = set_strategy(config.strategy) 43 | 44 
| # choose plugins 45 | if config.stage == 1: 46 | plugins = [CustomCheckpointIO([f'model.{name}' for name in model.model.bias_parameter_names()])] 47 | else: 48 | plugins = None 49 | 50 | return logger, log_dir, callbacks, precision, strategy, plugins 51 | 52 | 53 | def set_seeds(seed): 54 | torch.manual_seed(seed) 55 | torch.cuda.manual_seed(seed) 56 | random.seed(seed) 57 | np.random.seed(seed) 58 | 59 | 60 | def set_directories(config, root_dir='experiments', exp_name='', log_dir='logs', save_dir='checkpoints', 61 | create_log_dir=True, create_save_dir=True, dir_postfix='', exp_subname=''): 62 | # make an experiment name 63 | if exp_name == '': 64 | if config.task == '': 65 | exp_name = config.exp_name = f'{config.model}_fold:{config.task_fold}{config.name_postfix}' 66 | else: 67 | exp_name = config.exp_name = f'{config.model}_task:{config.task}{config.name_postfix}' 68 | 69 | # create the root directory 70 | os.makedirs(root_dir, exist_ok=True) 71 | 72 | # set logging directory 73 | if create_log_dir: 74 | os.makedirs(os.path.join(root_dir, config.log_dir), exist_ok=True) 75 | log_root = os.path.join(root_dir, config.log_dir, exp_name + dir_postfix) 76 | os.makedirs(log_root, exist_ok=True) 77 | if exp_subname != '': 78 | log_root = os.path.join(log_root, exp_subname) 79 | os.makedirs(log_root, exist_ok=True) 80 | log_dir = os.path.join(log_root, log_dir) 81 | 82 | # reset the logging directory if exists 83 | if config.stage == 0 and os.path.exists(log_dir) and not (config.continue_mode or config.skip_mode): 84 | shutil.rmtree(log_dir) 85 | os.makedirs(log_dir, exist_ok=True) 86 | else: 87 | log_dir = None 88 | 89 | # set saving directory 90 | if create_save_dir: 91 | save_root = os.path.join(root_dir, config.save_dir, exp_name + dir_postfix) 92 | if exp_subname != '': 93 | save_root = os.path.join(save_root, exp_subname) 94 | save_dir = os.path.join(save_root, save_dir) 95 | 96 | # create the saving directory if checkpoint doesn't exist or in skipping mode, 97 | # otherwise ask user to reset it 98 | if config.stage == 0 and os.path.exists(save_dir) and int(os.environ.get('LOCAL_RANK', 0)) == 0: 99 | if config.continue_mode: 100 | print(f'resume from checkpoint ({exp_name})') 101 | elif config.skip_mode: 102 | print(f'skip the existing checkpoint ({exp_name})') 103 | sys.exit() 104 | elif config.debug_mode or config.reset_mode: 105 | print(f'remove existing checkpoint ({exp_name})') 106 | shutil.rmtree(save_dir) 107 | else: 108 | while True: 109 | print(f'redundant experiment name! ({exp_name}) remove existing checkpoints? 
(y/n)') 110 | inp = input() 111 | if inp == 'y': 112 | shutil.rmtree(save_dir) 113 | break 114 | elif inp == 'n': 115 | print('quit') 116 | sys.exit() 117 | else: 118 | print('invalid input') 119 | os.makedirs(save_dir, exist_ok=True) 120 | else: 121 | save_dir = None 122 | 123 | return log_dir, save_dir 124 | 125 | 126 | def set_strategy(strategy): 127 | if strategy == 'ddp': 128 | strategy = pl.strategies.DDPStrategy() 129 | else: 130 | strategy = None 131 | 132 | return strategy 133 | 134 | 135 | def set_callbacks(config, save_dir, monitor=None, ptf=''): 136 | callbacks = [ 137 | CustomProgressBar(), 138 | ] 139 | if ((not config.no_eval) and 140 | monitor is not None and 141 | config.early_stopping_patience > 0): 142 | callbacks.append(CustomEarlyStopping(monitor=monitor, mode="min", patience=config.early_stopping_patience)) 143 | 144 | if not config.no_save and save_dir is not None: 145 | # last checkpointing 146 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 147 | dirpath=save_dir, 148 | filename=f'last{ptf}', 149 | auto_insert_metric_name=False, 150 | every_n_epochs=1, 151 | save_top_k=1, 152 | save_last=False, 153 | monitor='epoch', 154 | mode='max', 155 | ) 156 | checkpoint_callback.CHECKPOINT_JOIN_CHAR = "_" 157 | callbacks.append(checkpoint_callback) 158 | 159 | # best checkpointing 160 | if not (config.no_eval or monitor is None): 161 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 162 | dirpath=save_dir, 163 | filename=f'best{ptf}', 164 | auto_insert_metric_name=False, 165 | every_n_epochs=1, 166 | save_top_k=1, 167 | save_last=False, 168 | monitor=monitor, 169 | mode='min', 170 | ) 171 | checkpoint_callback.CHECKPOINT_JOIN_CHAR = "_" 172 | callbacks.append(checkpoint_callback) 173 | 174 | return callbacks 175 | 176 | 177 | def get_ckpt_path(load_dir, exp_name, load_step, exp_subname='', save_postfix=''): 178 | if load_step == 0: 179 | ckpt_name = f'best{save_postfix}.ckpt' 180 | elif load_step < 0: 181 | ckpt_name = f'last{save_postfix}.ckpt' 182 | else: 183 | ckpt_name = f'step_{load_step:06d}.ckpt' 184 | 185 | load_path = os.path.join('experiments', load_dir, exp_name, exp_subname, 'checkpoints', ckpt_name) 186 | if not os.path.exists(load_path): 187 | raise FileNotFoundError(f"checkpoint ({load_path}) does not exists!") 188 | 189 | return load_path 190 | 191 | 192 | def copy_values(config_new, config_old): 193 | for key in config_new.__dir__(): 194 | if key[:2] != '__': 195 | setattr(config_old, key, getattr(config_new, key)) 196 | 197 | 198 | def load_ckpt(ckpt_path, config_new=None): 199 | ckpt = torch.load(ckpt_path) 200 | state_dict = ckpt['state_dict'] 201 | config = ckpt['hyper_parameters']['config'] 202 | 203 | # merge config 204 | if config_new is not None: 205 | copy_values(config_new, config) 206 | 207 | return state_dict, config 208 | 209 | 210 | def select_task_specific_parameters(config, model, state_dict): 211 | if config.channel_idx < 0: 212 | t_idx = torch.tensor([TASKS.index(task) for task in TASKS_GROUP_DICT[config.task]]) 213 | else: 214 | t_idx = torch.tensor([TASKS.index(f'{config.task}_{config.channel_idx}')]) 215 | 216 | # for fine-tuning 217 | bias_parameters = [f'model.{name}' for name in model.model.bias_parameter_names()] 218 | for key in state_dict.keys(): 219 | if key in bias_parameters: 220 | state_dict[key] = state_dict[key][t_idx] 221 | 222 | 223 | def load_model(config, verbose=True): 224 | load_path = None 225 | 226 | # create trainer for episodic training 227 | if config.stage == 0: 228 | model = LightningTrainWrapper(config, 
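# stage 0 builds the meta-training wrapper from scratch; when continue_mode is set,
# load_path resolves to the 'last' checkpoint via get_ckpt_path(..., load_step=-1).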
verbose=verbose) 229 | if config.continue_mode: 230 | load_path = get_ckpt_path(config.load_dir, config.exp_name, -1, save_postfix=config.save_postfix) 231 | 232 | # create trainer for fine-tuning or evaluation 233 | else: 234 | # load meta-trained checkpoint 235 | ckpt_path = get_ckpt_path(config.load_dir, config.exp_name, 0) 236 | state_dict, config = load_ckpt(ckpt_path, config) 237 | 238 | model = LightningTrainWrapper(config=config, verbose=verbose) 239 | # select task-specific parameters for test task 240 | if config.stage == 1: 241 | select_task_specific_parameters(config, model, state_dict) 242 | # load fine-tuned checkpoint 243 | else: 244 | ft_ckpt_path = get_ckpt_path(config.save_dir, config.exp_name, 0, config.exp_subname, config.save_postfix) 245 | ft_state_dict, _ = load_ckpt(ft_ckpt_path) 246 | for key in ft_state_dict: 247 | state_dict[key] = ft_state_dict[key] 248 | 249 | print(model.load_state_dict(state_dict)) 250 | if verbose: 251 | print(f'meta-trained checkpoint loaded from {ckpt_path}') 252 | if config.stage == 2: 253 | print(f'fine-tuned checkpoint loaded from {ft_ckpt_path}') 254 | 255 | return model, load_path 256 | 257 | 258 | class CustomProgressBar(TQDMProgressBar): 259 | def __init__(self, rescale_validation_batches=1): 260 | super().__init__() 261 | self.rescale_validation_batches = rescale_validation_batches 262 | 263 | def init_train_tqdm(self): 264 | """Override this to customize the tqdm bar for training.""" 265 | bar = tqdm.tqdm( 266 | desc="Training", 267 | bar_format="{desc:<5}{percentage:3.0f}%|{bar:10}{r_bar}", 268 | initial=self.train_batch_idx, 269 | position=(2 * self.process_position), 270 | disable=self.is_disabled, 271 | leave=True, 272 | dynamic_ncols=True, 273 | file=sys.stdout, 274 | smoothing=0, 275 | ) 276 | return bar 277 | 278 | def init_validation_tqdm(self): 279 | """Override this to customize the tqdm bar for validation.""" 280 | # The main progress bar doesn't exist in `trainer.validate()` 281 | has_main_bar = self.trainer.state.fn != "validate" 282 | bar = tqdm.tqdm( 283 | desc="Validation", 284 | bar_format="{desc:<5}{percentage:3.0f}%|{bar:10}{r_bar}", 285 | position=(2 * self.process_position + has_main_bar), 286 | disable=self.is_disabled, 287 | leave=not has_main_bar, 288 | dynamic_ncols=True, 289 | file=sys.stdout, 290 | ) 291 | return bar 292 | 293 | def init_test_tqdm(self): 294 | """Override this to customize the tqdm bar for testing.""" 295 | bar = tqdm.tqdm( 296 | desc="Testing", 297 | bar_format="{desc:<5}{percentage:3.0f}%|{bar:10}{r_bar}", 298 | position=(2 * self.process_position), 299 | disable=self.is_disabled, 300 | leave=True, 301 | dynamic_ncols=True, 302 | file=sys.stdout, 303 | ) 304 | return bar 305 | 306 | 307 | class CustomTBLogger(TensorBoardLogger): 308 | @pl.utilities.rank_zero_only 309 | def log_metrics(self, metrics, step): 310 | metrics.pop('epoch', None) 311 | return super().log_metrics(metrics, step) 312 | 313 | 314 | class CustomEarlyStopping(EarlyStopping): 315 | def _run_early_stopping_check(self, trainer): 316 | """Checks whether the early stopping condition is met and if so tells the trainer to stop the training.""" 317 | logs = trainer.callback_metrics 318 | if self.monitor not in logs: 319 | should_stop = False 320 | reason = None 321 | else: 322 | current = logs[self.monitor].squeeze() 323 | should_stop, reason = self._evaluate_stopping_criteria(current) 324 | 325 | # stop every ddp process if any world process decides to stop 326 | should_stop = 
trainer.strategy.reduce_boolean_decision(should_stop) 327 | trainer.should_stop = trainer.should_stop or should_stop 328 | if should_stop: 329 | self.stopped_epoch = trainer.current_epoch 330 | if reason and self.verbose: 331 | self._log_info(trainer, reason, self.log_rank_zero_only) 332 | 333 | 334 | class CustomCheckpointIO(TorchCheckpointIO): 335 | def __init__(self, save_parameter_names): 336 | self.save_parameter_names = save_parameter_names 337 | 338 | def save_checkpoint(self, checkpoint, path, storage_options=None): 339 | # store only task-specific parameters 340 | state_dict = checkpoint['state_dict'] 341 | state_dict = {key: value for key, value in state_dict.items() if key in self.save_parameter_names} 342 | checkpoint['state_dict'] = state_dict 343 | 344 | super().save_checkpoint(checkpoint, path, storage_options) 345 | -------------------------------------------------------------------------------- /train/trainer.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torchvision.transforms as T 4 | from einops import rearrange, repeat, reduce 5 | import os 6 | 7 | from model.model_factory import get_model 8 | 9 | from dataset.dataloader_factory import get_train_loader, get_validation_loaders, generate_support_data, get_eval_loader, base_sizes 10 | from dataset.taskonomy_constants import SEMSEG_CLASSES, TASKS_SEMSEG 11 | from dataset.utils import to_device, mix_fivecrop, crop_arrays 12 | 13 | from .optim import get_optimizer 14 | from .loss import compute_loss, compute_metric 15 | from .visualize import visualize_batch, postprocess_depth, postprocess_semseg 16 | from .miou_fss import AverageMeter 17 | 18 | 19 | class LightningTrainWrapper(pl.LightningModule): 20 | def __init__(self, config, verbose=True, load_pretrained=True): 21 | ''' 22 | PyTorch Lightning wrapper for Visual Token Matching. 23 | ''' 24 | super().__init__() 25 | 26 | # load model. 27 | self.model = get_model(config, verbose=verbose, load_pretrained=load_pretrained) 28 | self.config = config 29 | self.verbose = verbose 30 | 31 | # tools for validation. 32 | self.miou_evaluator = AverageMeter(range(len(SEMSEG_CLASSES))) 33 | self.crop = T.Compose([ 34 | T.FiveCrop(config.img_size), 35 | T.Lambda(lambda crops: torch.stack([crop for crop in crops])) 36 | ]) 37 | self.support_data = self.load_support_data() 38 | 39 | if self.config.stage == 1: 40 | for attn in self.model.matching_module.matching: 41 | attn.attn_dropout.p = self.config.attn_dropout 42 | 43 | # save hyper-parameters 44 | self.save_hyperparameters() 45 | 46 | def load_support_data(self, data_path='support_data.pth'): 47 | ''' 48 | Load support data for validation.
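In stage 0, the support data is a dict mapping each task name to a tuple
(X, Y, M, t_idx) built by generate_support_data, with X of shape (1, T, N, C, H, W)
and Y, M of shape (1, T, N, 1, H, W) for N support shots and T task channels;
in later stages a single-task support batch is drawn from the training loader instead.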
49 | ''' 50 | if self.config.stage == 0: 51 | # generate support data if not exists. 52 | support_data = generate_support_data(self.config, data_path=data_path, verbose=self.verbose) 53 | else: 54 | task = f'{self.config.task}_{self.config.channel_idx}' if self.config.task == 'segment_semantic' else self.config.task 55 | support_data = {task: get_train_loader(self.config, verbose=False, get_support_data=True)} 56 | 57 | if self.verbose: 58 | print('loaded support data') 59 | 60 | # convert to proper precision 61 | if self.config.precision == 'fp16': 62 | support_data = to_device(support_data, dtype=torch.half) 63 | elif self.config.precision == 'bf16': 64 | support_data = to_device(support_data, dtype=torch.bfloat16) 65 | 66 | return support_data 67 | 68 | def configure_optimizers(self): 69 | ''' 70 | Prepare optimizer and lr scheduler. 71 | ''' 72 | optimizer, self.lr_scheduler = get_optimizer(self.config, self.model) 73 | return optimizer 74 | 75 | def train_dataloader(self, verbose=True): 76 | ''' 77 | Prepare training loader. 78 | ''' 79 | return get_train_loader(self.config, verbose=(self.verbose and verbose)) 80 | 81 | def val_dataloader(self, verbose=True): 82 | ''' 83 | Prepare validation loaders. 84 | ''' 85 | if not self.config.no_eval: 86 | # use external data from validation split 87 | if self.config.stage == 0: 88 | val_loaders, loader_tag = get_validation_loaders(self.config, verbose=(self.verbose and verbose)) 89 | self.valid_tasks = list(val_loaders.keys()) 90 | self.valid_tag = loader_tag 91 | 92 | return list(val_loaders.values()) 93 | 94 | # use second half of support data as validation query 95 | else: 96 | assert self.config.shot > 1 97 | class SubQueryDataset: 98 | def __init__(self, data): 99 | self.data = data 100 | self.n_query = self.data[0].shape[2] // 2 101 | 102 | def __len__(self): 103 | return self.n_query 104 | 105 | def __getitem__(self, idx): 106 | return (self.data[0][0, 0, self.n_query+idx], 107 | self.data[1][0, :, self.n_query+idx, 0], 108 | self.data[2][0, :, self.n_query+idx, 0]) 109 | 110 | valid_task = list(self.support_data.keys())[0] 111 | dset = SubQueryDataset(self.support_data[valid_task][:3]) 112 | self.valid_tasks = [valid_task] 113 | self.valid_tag = 'mtest_support' 114 | 115 | return torch.utils.data.DataLoader(dset, shuffle=False, batch_size=len(dset)) 116 | 117 | def test_dataloader(self, verbose=True): 118 | ''' 119 | Prepare test loaders. 120 | ''' 121 | test_loader = get_eval_loader(self.config, self.config.task, split=self.config.test_split, 122 | channel_idx=self.config.channel_idx, verbose=(self.verbose and verbose)) 123 | 124 | return test_loader 125 | 126 | def forward(self, *args, **kwargs): 127 | ''' 128 | Forward data to model. 129 | ''' 130 | return self.model(*args, **kwargs) 131 | 132 | def training_step(self, batch, batch_idx): 133 | ''' 134 | A single training iteration. 135 | ''' 136 | # forward model and compute loss. 137 | loss = compute_loss(self.model, batch, self.config) 138 | 139 | # schedule learning rate. 140 | self.lr_scheduler.step(self.global_step) 141 | 142 | if self.config.stage == 0: 143 | tag = '' 144 | elif self.config.stage == 1: 145 | if self.config.task == 'segment_semantic': 146 | tag = f'_segment_semantic_{self.config.channel_idx}' 147 | else: 148 | tag = f'_{self.config.task}' 149 | 150 | # log losses and learning rate. 
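# e.g. stage-1 fine-tuning on depth_zbuffer logs 'training/loss_depth_zbuffer' and
# 'training/lr_depth_zbuffer'; in stage 0 (meta-training) the tag is empty and the
# keys are simply 'training/loss' and 'training/lr'.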
151 | log_dict = { 152 | f'training/loss{tag}': loss.detach(), 153 | f'training/lr{tag}': self.lr_scheduler.lr, 154 | 'step': self.global_step, 155 | } 156 | self.log_dict( 157 | log_dict, 158 | logger=True, 159 | on_step=True, 160 | sync_dist=True, 161 | ) 162 | 163 | return loss 164 | 165 | @torch.autocast(device_type='cuda', dtype=torch.float32) 166 | def inference(self, X, task): 167 | # support data 168 | X_S, Y_S, M_S, t_idx = to_device(self.support_data[task], X.device) 169 | 170 | # use first half of support data as validation support 171 | if self.config.stage == 1: 172 | n_support = X_S.shape[2] // 2 173 | base_size = base_sizes[self.config.img_size] 174 | img_size = (self.config.img_size, self.config.img_size) 175 | X_S, Y_S, M_S = crop_arrays(X_S[:, :, :n_support], 176 | Y_S[:, :, :n_support], 177 | M_S[:, :, :n_support], 178 | base_size=base_size, 179 | img_size=img_size, 180 | random=False) 181 | 182 | t_idx = t_idx.long() 183 | T = Y_S.size(1) 184 | 185 | # five-crop query images to 224 x 224 and reshape for matching 186 | X_crop = repeat(self.crop(X), 'F B C H W -> 1 T (F B) C H W', T=T) 187 | 188 | # predict labels on each crop 189 | Y_S_in = torch.where(M_S.bool(), Y_S, torch.ones_like(Y_S) * self.config.mask_value) 190 | Y_pred_crop = self.model(X_S, Y_S_in, X_crop, t_idx=t_idx, sigmoid=('segment_semantic' not in task)) 191 | 192 | # remix the cropped predictions into a whole prediction 193 | Y_pred_crop = rearrange(Y_pred_crop, '1 T (F B) 1 H W -> F B T H W', F=5) 194 | Y_pred = mix_fivecrop(Y_pred_crop, base_size=X.size(-1), crop_size=X_crop.size(-1)) 195 | 196 | return Y_pred 197 | 198 | def on_validation_start(self) -> None: 199 | if self.config.stage == 0: 200 | self.miou_evaluator = AverageMeter(range(len(SEMSEG_CLASSES)), device=self.device) 201 | else: 202 | self.miou_evaluator = AverageMeter(0, device=self.device) 203 | return super().on_validation_start() 204 | 205 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 206 | ''' 207 | Evaluate few-shot performance on validation dataset. 208 | ''' 209 | task = self.valid_tasks[dataloader_idx] 210 | 211 | # get query data 212 | X, Y, M = batch 213 | 214 | # few-shot inference based on support data 215 | Y_pred = self.inference(X, task) 216 | 217 | # discretization for semantic segmentation 218 | if 'segment_semantic' in task: 219 | Y_pred = (Y_pred.sigmoid() > self.config.semseg_threshold).float() 220 | 221 | # compute evaluation metric 222 | metric = compute_metric(Y, Y_pred, M, task, self.miou_evaluator, self.config.stage) 223 | metric *= len(X) 224 | 225 | # visualize first batch 226 | if batch_idx == 0: 227 | X_vis = rearrange(self.all_gather(X), 'G B ... -> (B G) ...') 228 | Y_vis = rearrange(self.all_gather(Y), 'G B ... -> (B G) ...') 229 | M_vis = rearrange(self.all_gather(M), 'G B ... -> (B G) ...') 230 | Y_pred_vis = rearrange(self.all_gather(Y_pred), 'G B ... -> (B G) ...') 231 | vis_batch = (X_vis, Y_vis, M_vis, Y_pred_vis) 232 | self.vis_images(vis_batch, task) 233 | 234 | return metric, torch.tensor(len(X), device=self.device) 235 | 236 | def validation_epoch_end(self, validation_step_outputs): 237 | ''' 238 | Aggregate losses of all validation datasets and log them into tensorboard. 
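Each element of validation_step_outputs corresponds to one validation dataloader and
contains the (metric * batch_size, batch_size) pairs returned by validation_step;
the sums are all-gathered across processes before averaging, and semantic
segmentation is scored as 1 - mIoU from the gathered evaluator buffers.
'''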
239 | ''' 240 | if len(self.valid_tasks) == 1: 241 | validation_step_outputs = (validation_step_outputs,) 242 | avg_loss = [] 243 | log_dict = {'step': self.global_step} 244 | 245 | for task, losses_batch in zip(self.valid_tasks, validation_step_outputs): 246 | N_total = sum([losses[1] for losses in losses_batch]) 247 | loss_pred = sum([losses[0] for losses in losses_batch]) 248 | N_total = self.all_gather(N_total).sum() 249 | loss_pred = self.all_gather(loss_pred).sum() 250 | 251 | loss_pred = loss_pred / N_total 252 | 253 | # log task-specific errors 254 | if 'segment_semantic' in task: 255 | if self.config.stage > 0 or TASKS_SEMSEG.index(task) == 0: 256 | self.miou_evaluator.intersection_buf = reduce(self.all_gather(self.miou_evaluator.intersection_buf), 257 | 'G ... -> ...', 'sum') 258 | self.miou_evaluator.union_buf = reduce(self.all_gather(self.miou_evaluator.union_buf), 259 | 'G ... -> ...', 'sum') 260 | 261 | loss_pred = 1 - self.miou_evaluator.compute_iou()[0] 262 | 263 | if self.config.stage == 0: 264 | tag = f'{self.valid_tag}/segment_semantic_pred' 265 | else: 266 | tag = f'{self.valid_tag}/segment_semantic_{self.config.channel_idx}_pred' 267 | 268 | log_dict[tag] = loss_pred 269 | avg_loss.append(loss_pred) 270 | else: 271 | log_dict[f'{self.valid_tag}/{task}_pred'] = loss_pred 272 | avg_loss.append(loss_pred) 273 | 274 | # log task-averaged error 275 | if self.config.stage == 0: 276 | avg_loss = sum(avg_loss) / len(avg_loss) 277 | log_dict[f'summary/{self.valid_tag}_pred'] = avg_loss 278 | 279 | self.log_dict( 280 | log_dict, 281 | logger=True, 282 | rank_zero_only=True 283 | ) 284 | 285 | def on_test_start(self) -> None: 286 | if self.config.stage == 0: 287 | self.miou_evaluator = AverageMeter(range(len(SEMSEG_CLASSES)), device=self.device) 288 | else: 289 | self.miou_evaluator = AverageMeter(0, device=self.device) 290 | return super().on_test_start() 291 | 292 | def test_step(self, batch, batch_idx): 293 | ''' 294 | Evaluate few-shot performance on test dataset. 295 | ''' 296 | if self.config.task == 'segment_semantic': 297 | task = f'segment_semantic_{self.config.channel_idx}' 298 | else: 299 | task = self.config.task 300 | 301 | # query data 302 | X, Y, M = batch 303 | 304 | # support data 305 | Y_pred = self.inference(X, task) 306 | 307 | # discretization for semantic segmentation 308 | if 'segment_semantic' in task: 309 | Y_pred = (Y_pred.sigmoid() > self.config.semseg_threshold).float() 310 | 311 | # compute evaluation metric 312 | metric = compute_metric(Y, Y_pred, M, task, self.miou_evaluator, self.config.stage) 313 | metric *= len(X) 314 | 315 | return metric, torch.tensor(len(X), device=self.device) 316 | 317 | def test_epoch_end(self, test_step_outputs): 318 | # append test split to save_postfix 319 | log_name = f'result{self.config.save_postfix}_split:{self.config.test_split}.pth' 320 | log_path = os.path.join(self.config.result_dir, log_name) 321 | 322 | if self.config.task == 'segment_semantic': 323 | torch.save(self.miou_evaluator, log_path) 324 | else: 325 | N_total = sum([losses[1] for losses in test_step_outputs]) 326 | metric = sum([losses[0] for losses in test_step_outputs]) / N_total 327 | metric = metric.cpu().item() 328 | torch.save(metric, log_path) 329 | 330 | @pl.utilities.rank_zero_only 331 | def vis_images(self, batch, task, vis_shot=-1, **kwargs): 332 | ''' 333 | Visualize query prediction into tensorboard. 334 | ''' 335 | X, Y, M, Y_pred = batch 336 | 337 | # choose proper subset. 
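# vis_shot < 0 (the default) keeps every gathered query; a positive value visualizes
# only the first vis_shot examples.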
338 | if vis_shot > 0: 339 | X = X[:vis_shot] 340 | Y = Y[:vis_shot] 341 | M = M[:vis_shot] 342 | Y_pred = Y_pred[:vis_shot] 343 | 344 | # set task-specific post-processing function for visualization 345 | if task == 'depth_zbuffer': 346 | postprocess_fn = postprocess_depth 347 | elif 'segment_semantic' in task: 348 | postprocess_fn = postprocess_semseg 349 | else: 350 | postprocess_fn = None 351 | 352 | # visualize batch 353 | vis = visualize_batch(X, Y, M, Y_pred, postprocess_fn=postprocess_fn, **kwargs) 354 | self.logger.experiment.add_image(f'{self.valid_tag}/{task}', vis, self.global_step) -------------------------------------------------------------------------------- /model/beit/beit_custom.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from timm.models.helpers import build_model_with_cfg 10 | from timm.models.layers import PatchEmbed, DropPath, trunc_normal_ 11 | from .beit_registry import register_model 12 | from timm.models.vision_transformer import checkpoint_filter_fn 13 | 14 | from einops import repeat 15 | 16 | 17 | def _cfg(url='', **kwargs): 18 | return { 19 | 'url': url, 20 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 21 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 22 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 23 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 24 | **kwargs 25 | } 26 | 27 | 28 | default_cfgs = { 29 | 'beit_base_patch16_224': _cfg( 30 | url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), 31 | 'beit_base_patch16_384': _cfg( 32 | url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', 33 | input_size=(3, 384, 384), crop_pct=1.0, 34 | ), 35 | 'beit_base_patch16_224_in22k': _cfg( 36 | url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth', 37 | num_classes=21841, 38 | ), 39 | 'beit_large_patch16_224': _cfg( 40 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), 41 | 'beit_large_patch16_384': _cfg( 42 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', 43 | input_size=(3, 384, 384), crop_pct=1.0, 44 | ), 45 | 'beit_large_patch16_512': _cfg( 46 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', 47 | input_size=(3, 512, 512), crop_pct=1.0, 48 | ), 49 | 'beit_large_patch16_224_in22k': _cfg( 50 | url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth', 51 | num_classes=21841, 52 | ), 53 | } 54 | 55 | 56 | class CustomMlp(nn.Module): 57 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 58 | """ 59 | def __init__(self, n_tasks, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): 60 | super().__init__() 61 | out_features = out_features or in_features 62 | hidden_features = hidden_features or in_features 63 | bias = (bias, bias) 64 | drop_probs = (drop, drop) 65 | 66 | self.fc1 = CustomLinear(n_tasks, in_features, hidden_features, bias=bias[0]) 67 | self.act = act_layer() 68 | self.drop1 = nn.Dropout(drop_probs[0]) 69 | self.fc2 = CustomLinear(n_tasks, hidden_features, out_features, bias=bias[1]) 70 | self.drop2 = nn.Dropout(drop_probs[1]) 71 | 72 | def forward(self, x, 
t_idx=None): 73 | x = self.fc1(x, t_idx) 74 | x = self.act(x) 75 | x = self.drop1(x) 76 | x = self.fc2(x, t_idx) 77 | x = self.drop2(x) 78 | return x 79 | 80 | 81 | class CustomLinear(nn.Linear): 82 | def __init__(self, n_tasks=0, *args, **kwargs): 83 | super().__init__(*args, **kwargs) 84 | 85 | self.n_tasks = n_tasks 86 | if self.n_tasks > 0: 87 | assert self.bias is not None 88 | self.bias = nn.Parameter(repeat(self.bias.data, '... -> T ...', T=n_tasks).contiguous()) 89 | 90 | def forward(self, input, t_idx=None): 91 | if self.n_tasks > 0: 92 | assert t_idx is not None 93 | output = F.linear(input, self.weight, None) 94 | return output + self.bias[t_idx][:, None] 95 | else: 96 | return F.linear(input, self.weight, self.bias) 97 | 98 | 99 | class CustomLayerNorm(nn.LayerNorm): 100 | def __init__(self, n_tasks=0, *args, **kwargs): 101 | super().__init__(*args, **kwargs) 102 | 103 | self.n_tasks = n_tasks 104 | if self.n_tasks > 0: 105 | assert self.elementwise_affine 106 | self.bias = nn.Parameter(repeat(self.bias.data, '... -> T ...', T=n_tasks).contiguous()) 107 | 108 | def forward(self, input, t_idx=None): 109 | if self.n_tasks > 0: 110 | assert t_idx is not None 111 | output = F.layer_norm(input, self.normalized_shape, self.weight, None, self.eps) 112 | return output + self.bias[t_idx][:, None] 113 | else: 114 | return F.layer_norm( 115 | input, self.normalized_shape, self.weight, self.bias, self.eps) 116 | 117 | 118 | class CustomAttention(nn.Module): 119 | def __init__( 120 | self, dim, num_heads=8, qkv_bias=False, attn_drop=0., 121 | proj_drop=0., window_size=None, attn_head_dim=None, n_prompts=0, n_tasks=0): 122 | super().__init__() 123 | self.num_heads = num_heads 124 | head_dim = dim // num_heads 125 | if attn_head_dim is not None: 126 | head_dim = attn_head_dim 127 | all_head_dim = head_dim * self.num_heads 128 | self.scale = head_dim ** -0.5 129 | self.n_prompts = n_prompts 130 | 131 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 132 | if qkv_bias: 133 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 134 | self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False) 135 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 136 | else: 137 | self.q_bias = None 138 | self.k_bias = None 139 | self.v_bias = None 140 | 141 | if window_size: 142 | self.window_size = window_size 143 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 144 | self.relative_position_bias_table = nn.Parameter( 145 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 146 | # cls to token & token 2 cls & cls to cls 147 | 148 | # get pair-wise relative position index for each token inside the window 149 | coords_h = torch.arange(window_size[0]) 150 | coords_w = torch.arange(window_size[1]) 151 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 152 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 153 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 154 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 155 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 156 | relative_coords[:, :, 1] += window_size[1] - 1 157 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 158 | relative_position_index = \ 159 | torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) 160 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 161 | 
relative_position_index[0, 0:] = self.num_relative_distance - 3 162 | relative_position_index[0:, 0] = self.num_relative_distance - 2 163 | relative_position_index[0, 0] = self.num_relative_distance - 1 164 | 165 | self.register_buffer("relative_position_index", relative_position_index) 166 | else: 167 | self.window_size = None 168 | self.relative_position_bias_table = None 169 | self.relative_position_index = None 170 | 171 | self.attn_drop = nn.Dropout(attn_drop) 172 | self.proj = CustomLinear(n_tasks, all_head_dim, dim) 173 | self.proj_drop = nn.Dropout(proj_drop) 174 | 175 | def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None, t_idx=None): 176 | B, N, C = x.shape 177 | qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None 178 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 179 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 180 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 181 | 182 | q = q * self.scale 183 | attn = (q @ k.transpose(-2, -1)) 184 | 185 | if self.relative_position_bias_table is not None: 186 | relative_position_bias = \ 187 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 188 | self.window_size[0] * self.window_size[1] + 1, 189 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH 190 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 191 | 192 | _C, _H, _W = relative_position_bias.shape 193 | 194 | if self.n_prompts > 0: 195 | relative_position_bias = torch.cat(( 196 | torch.zeros(_C, self.n_prompts, _W, dtype=attn.dtype, device=attn.device), 197 | relative_position_bias 198 | ), dim=-2) 199 | 200 | relative_position_bias = torch.cat(( 201 | torch.zeros(_C, _H + self.n_prompts, self.n_prompts, dtype=attn.dtype, device=attn.device), 202 | relative_position_bias 203 | ), dim=-1) 204 | 205 | attn = attn + relative_position_bias.unsqueeze(0) 206 | 207 | if rel_pos_bias is not None: 208 | attn = attn + rel_pos_bias 209 | 210 | attn = attn.softmax(dim=-1) 211 | attn = self.attn_drop(attn) 212 | 213 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 214 | x = self.proj(x, t_idx) 215 | x = self.proj_drop(x) 216 | return x 217 | 218 | 219 | class CustomBlock(nn.Module): 220 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 221 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=CustomLayerNorm, 222 | window_size=None, attn_head_dim=None, n_prompts=0, n_tasks=0): 223 | super().__init__() 224 | self.norm1 = norm_layer(n_tasks, dim) 225 | self.attn = CustomAttention( 226 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, 227 | window_size=window_size, attn_head_dim=attn_head_dim, n_prompts=n_prompts, n_tasks=n_tasks) 228 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 229 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 230 | self.norm2 = norm_layer(n_tasks, dim) 231 | mlp_hidden_dim = int(dim * mlp_ratio) 232 | self.mlp = CustomMlp(n_tasks=n_tasks, in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 233 | 234 | if init_values: 235 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 236 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 237 | else: 238 | self.gamma_1, self.gamma_2 = None, None 239 | 240 | def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None, t_idx=None): 241 | if self.gamma_1 is None: 242 | x = x + self.drop_path(self.attn(self.norm1(x, t_idx), rel_pos_bias=rel_pos_bias, t_idx=t_idx)) 243 | x = x + self.drop_path(self.mlp(self.norm2(x, t_idx), t_idx)) 244 | else: 245 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x, t_idx), rel_pos_bias=rel_pos_bias, t_idx=t_idx)) 246 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x, t_idx), t_idx)) 247 | return x 248 | 249 | 250 | class RelativePositionBias(nn.Module): 251 | 252 | def __init__(self, window_size, num_heads): 253 | super().__init__() 254 | self.window_size = window_size 255 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 256 | self.relative_position_bias_table = nn.Parameter( 257 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 258 | # cls to token & token 2 cls & cls to cls 259 | 260 | # get pair-wise relative position index for each token inside the window 261 | coords_h = torch.arange(window_size[0]) 262 | coords_w = torch.arange(window_size[1]) 263 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 264 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 265 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 266 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 267 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 268 | relative_coords[:, :, 1] += window_size[1] - 1 269 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 270 | relative_position_index = \ 271 | torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) 272 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 273 | relative_position_index[0, 0:] = self.num_relative_distance - 3 274 | relative_position_index[0:, 0] = self.num_relative_distance - 2 275 | relative_position_index[0, 0] = self.num_relative_distance - 1 276 | 277 | self.register_buffer("relative_position_index", relative_position_index) 278 | 279 | # trunc_normal_(self.relative_position_bias_table, std=.02) 280 | 281 | def forward(self): 282 | relative_position_bias = \ 283 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 284 | self.window_size[0] * self.window_size[1] + 1, 285 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH 286 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 287 | 288 | 289 | class CustomBeit(nn.Module): 290 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 291 | num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., 292 | drop_path_rate=0., norm_layer=partial(CustomLayerNorm, eps=1e-6), init_values=None, 293 | use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, 294 | use_mean_pooling=True, init_scale=0.001, n_prompts=0, 
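# n_prompts reserves zero rows/columns in each block's relative position bias for
# prepended prompt tokens; n_tasks > 0 turns the block biases (CustomLinear and
# CustomLayerNorm) into per-task BitFit parameters.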
n_tasks=0): 295 | super().__init__() 296 | self.num_classes = num_classes 297 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 298 | 299 | self.patch_embed = PatchEmbed( 300 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 301 | num_patches = self.patch_embed.num_patches 302 | 303 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 304 | # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 305 | if use_abs_pos_emb: 306 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 307 | else: 308 | self.pos_embed = None 309 | self.pos_drop = nn.Dropout(p=drop_rate) 310 | 311 | if use_shared_rel_pos_bias: 312 | self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.grid_size, num_heads=num_heads) 313 | else: 314 | self.rel_pos_bias = None 315 | 316 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 317 | self.use_rel_pos_bias = use_rel_pos_bias 318 | self.blocks = nn.ModuleList([ 319 | CustomBlock( 320 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 321 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 322 | init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None, 323 | n_prompts=n_prompts, n_tasks=n_tasks, 324 | ) 325 | for i in range(depth)]) 326 | 327 | self.apply(self._init_weights) 328 | if self.pos_embed is not None: 329 | trunc_normal_(self.pos_embed, std=.02) 330 | trunc_normal_(self.cls_token, std=.02) 331 | # trunc_normal_(self.mask_token, std=.02) 332 | self.fix_init_weight() 333 | 334 | def fix_init_weight(self): 335 | def rescale(param, layer_id): 336 | param.div_(math.sqrt(2.0 * layer_id)) 337 | 338 | for layer_id, layer in enumerate(self.blocks): 339 | rescale(layer.attn.proj.weight.data, layer_id + 1) 340 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 341 | 342 | def _init_weights(self, m): 343 | if isinstance(m, nn.Linear): 344 | trunc_normal_(m.weight, std=.02) 345 | if isinstance(m, nn.Linear) and m.bias is not None: 346 | nn.init.constant_(m.bias, 0) 347 | elif isinstance(m, CustomLayerNorm): 348 | nn.init.constant_(m.bias, 0) 349 | nn.init.constant_(m.weight, 1.0) 350 | 351 | def get_num_layers(self): 352 | return len(self.blocks) 353 | 354 | @torch.jit.ignore 355 | def no_weight_decay(self): 356 | return {'pos_embed', 'cls_token'} 357 | 358 | 359 | def _create_beit(variant, pretrained=False, default_cfg=None, **kwargs): 360 | default_cfg = default_cfg or default_cfgs[variant] 361 | if kwargs.get('features_only', None): 362 | raise RuntimeError('features_only not implemented for Beit models.') 363 | 364 | model = build_model_with_cfg( 365 | CustomBeit, variant, pretrained, 366 | default_cfg=default_cfg, 367 | # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes 368 | pretrained_filter_fn=checkpoint_filter_fn, 369 | **kwargs) 370 | return model 371 | 372 | 373 | @register_model 374 | def beit_base_patch16_224(pretrained=False, **kwargs): 375 | model_kwargs = dict( 376 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 377 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) 378 | model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **model_kwargs) 379 | return model 380 | 381 | 382 | @register_model 383 | def beit_base_patch16_384(pretrained=False, **kwargs): 384 | model_kwargs = dict( 385 
| img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 386 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) 387 | model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs) 388 | return model 389 | 390 | 391 | @register_model 392 | def beit_base_patch16_224_in22k(pretrained=False, **kwargs): 393 | model_kwargs = dict( 394 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 395 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) 396 | model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs) 397 | return model 398 | 399 | 400 | @register_model 401 | def beit_large_patch16_224(pretrained=False, **kwargs): 402 | model_kwargs = dict( 403 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 404 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 405 | model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs) 406 | return model 407 | 408 | 409 | @register_model 410 | def beit_large_patch16_384(pretrained=False, **kwargs): 411 | model_kwargs = dict( 412 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 413 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 414 | model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs) 415 | return model 416 | 417 | 418 | @register_model 419 | def beit_large_patch16_512(pretrained=False, **kwargs): 420 | model_kwargs = dict( 421 | img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 422 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 423 | model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs) 424 | return model 425 | 426 | 427 | @register_model 428 | def beit_large_patch16_224_in22k(pretrained=False, **kwargs): 429 | model_kwargs = dict( 430 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 431 | use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) 432 | model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs) 433 | return model 434 | 435 | 436 | def convert_beit_bitfit(beit, n_tasks): 437 | for blk in beit.blocks: 438 | for name, module in blk.named_children(): 439 | if isinstance(module, CustomLayerNorm): 440 | norm_bitfit = CustomLayerNorm( 441 | n_tasks, 442 | module.normalized_shape, 443 | module.eps, 444 | module.elementwise_affine 445 | ).to(module.bias.device) 446 | norm_bitfit.bias.data = repeat(module.bias.data, '... -> T ...', T=n_tasks) 447 | setattr(blk, name, norm_bitfit) 448 | 449 | elif isinstance(module, CustomMlp): 450 | for subname, submodule in module.named_children(): 451 | if isinstance(submodule, CustomLinear): 452 | fc_bitfit = CustomLinear( 453 | n_tasks, 454 | submodule.in_features, 455 | submodule.out_features 456 | ).to(submodule.bias.device) 457 | fc_bitfit.bias.data = repeat(submodule.bias.data, '... 
-> T ...', T=n_tasks) 458 | setattr(getattr(blk, name), subname, fc_bitfit) -------------------------------------------------------------------------------- /dataset/taskonomy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import PIL 5 | from PIL import Image 6 | from einops import rearrange, repeat 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import Dataset 11 | 12 | from .taskonomy_constants import SEMSEG_CLASSES, SEMSEG_CLASS_RANGE, TASKS_GROUP_DICT, TASKS, BUILDINGS 13 | from .augmentation import RandomHorizontalFlip, FILTERING_AUGMENTATIONS, RandomCompose, Mixup 14 | from .utils import crop_arrays, SobelEdgeDetector 15 | 16 | 17 | class TaskonomyBaseDataset(Dataset): 18 | def __init__(self, root_dir, buildings, tasks, base_size=(256, 256), img_size=(224, 224), seed=None, precision='fp32'): 19 | super().__init__() 20 | 21 | if seed is not None: 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | 27 | self.data_root = root_dir 28 | self.buildings = sorted(buildings) 29 | 30 | self.tasks = tasks 31 | self.subtasks = [] 32 | for task in tasks: 33 | if task in TASKS_GROUP_DICT: 34 | self.subtasks += TASKS_GROUP_DICT[task] 35 | else: 36 | self.subtasks += [task] 37 | 38 | self.support_classes = SEMSEG_CLASSES 39 | 40 | self.base_size = base_size 41 | self.img_size = img_size 42 | self.precision = precision 43 | 44 | self.img_paths = [img_path for img_path in sorted(os.listdir(os.path.join(self.data_root, 'rgb'))) 45 | if img_path.split('_')[0] in self.buildings] 46 | self.path_dict = {building: [i for i, img_path in enumerate(self.img_paths) 47 | if img_path.split('_')[0] == building] 48 | for building in self.buildings} 49 | 50 | # register euclidean depth and occlusion edge statistics, sobel edge detectors, and class dictionary 51 | self.meta_info_path = 'dataset/meta_info' 52 | self.depth_quantiles = torch.load(os.path.join(self.meta_info_path, 'depth_quantiles.pth')) 53 | self.edge_params = torch.load(os.path.join(self.meta_info_path, 'edge_params.pth')) 54 | self.sobel_detectors = [SobelEdgeDetector(kernel_size=k, sigma=s) for k, s in self.edge_params['params']] 55 | self.edge_thresholds = torch.load(os.path.join(self.meta_info_path, 'edge_thresholds.pth')) 56 | 57 | def load_img(self, img_path): 58 | img_path = os.path.join(self.data_root, 'rgb', img_path) 59 | try: 60 | # open image file 61 | img = Image.open(img_path) 62 | img = np.asarray(img) 63 | 64 | # type conversion 65 | img = img.astype('float32') / 255 66 | 67 | # shape conversion 68 | img = np.transpose(img, (2, 0, 1)) 69 | 70 | success = True 71 | 72 | except PIL.UnidentifiedImageError: 73 | print(f'PIL Error on {img_path}') 74 | img = -np.ones(3, 80, 80).astype('float32') 75 | success = False 76 | 77 | return img, success 78 | 79 | def load_label(self, task, img_path): 80 | task_root = os.path.join(self.data_root, task) 81 | if task == 'segment_semantic': 82 | label_name = img_path.replace('rgb', 'segmentsemantic') 83 | else: 84 | label_name = img_path.replace('rgb', task) 85 | label_path = os.path.join(task_root, label_name) 86 | 87 | # open label file 88 | label = Image.open(label_path) 89 | label = np.asarray(label) 90 | 91 | # type conversion 92 | if label.dtype == 'uint8': 93 | label = label.astype('float32') / 255 94 | else: 95 | label = label.astype('float32') 96 | 97 | # shape conversion 98 | if label.ndim == 2: 99 | 
label = label[np.newaxis, ...] 100 | elif label.ndim == 3: 101 | label = np.transpose(label, (2, 0, 1)) 102 | 103 | 104 | return label 105 | 106 | def load_task(self, task, img_path): 107 | if task == 'segment_semantic': 108 | label = self.load_label(task, img_path) 109 | label = (255*label).astype("long") 110 | label[label == 0] = 1 111 | label = label - 1 112 | mask = np.ones_like(label) 113 | 114 | elif task == 'normal': 115 | label = self.load_label(task, img_path) 116 | label = np.clip(label, 0, 1) 117 | 118 | mask = np.ones_like(label) 119 | 120 | elif task in ['depth_euclidean', 'depth_zbuffer']: 121 | label = self.load_label(task, img_path) 122 | label = np.log((1 + label)) / np.log(2 ** 16) 123 | 124 | depth_label = self.load_label('depth_euclidean', img_path) 125 | mask = (depth_label < 64500) 126 | 127 | elif task == 'edge_texture': 128 | label = mask = None 129 | 130 | elif task == 'edge_occlusion': 131 | label = self.load_label(task, img_path) 132 | label = label / (2 ** 16) 133 | 134 | depth_label = self.load_label('depth_euclidean', img_path) 135 | mask = (depth_label < 64500) 136 | 137 | elif task == 'keypoints2d': 138 | label = self.load_label(task, img_path) 139 | label = label / (2 ** 16) 140 | label = np.clip(label, 0, 0.005) / 0.005 141 | 142 | mask = np.ones_like(label) 143 | 144 | elif task == 'keypoints3d': 145 | label = self.load_label(task, img_path) 146 | label = label / (2 ** 16) 147 | 148 | depth_label = self.load_label('depth_euclidean', img_path) 149 | mask = (depth_label < 64500) 150 | 151 | elif task == 'reshading': 152 | label = self.load_label(task, img_path) 153 | label = label[:1] 154 | label = np.clip(label, 0, 1) 155 | 156 | mask = np.ones_like(label) 157 | 158 | elif task == 'principal_curvature': 159 | label = self.load_label(task, img_path) 160 | label = label[:2] 161 | label = np.clip(label, 0, 1) 162 | 163 | depth_label = self.load_label('depth_euclidean', img_path) 164 | mask = (depth_label < 64500) 165 | 166 | else: 167 | raise ValueError(task) 168 | 169 | return label, mask 170 | 171 | def preprocess_segment_semantic(self, labels, channels, drop_background=True): 172 | # regard non-support classes as background 173 | for c in SEMSEG_CLASS_RANGE: 174 | if c not in channels: 175 | labels = np.where(labels == c, 176 | np.zeros_like(labels), 177 | labels) 178 | 179 | # re-label support classes 180 | for i, c in enumerate(sorted(channels)): 181 | labels = np.where(labels == c, 182 | (i + 1)*np.ones_like(labels), 183 | labels) 184 | 185 | # one-hot encoding 186 | labels = torch.from_numpy(labels).long().squeeze(1) 187 | labels = F.one_hot(labels, len(channels) + 1).permute(0, 3, 1, 2).float() 188 | if drop_background: 189 | labels = labels[:, 1:] 190 | masks = torch.ones_like(labels) 191 | 192 | return labels, masks 193 | 194 | def preprocess_depth(self, labels, masks, channels, task): 195 | labels = torch.from_numpy(labels).float() 196 | masks = torch.from_numpy(masks).float() 197 | 198 | labels_th = [] 199 | for c in channels: 200 | assert c < len(self.depth_quantiles[task]) - 1 201 | 202 | # get boundary values for the depth segment 203 | t_min = self.depth_quantiles[task][c] 204 | if task == 'depth_euclidean': 205 | t_max = self.depth_quantiles[task][c+1] 206 | else: 207 | t_max = self.depth_quantiles[task][5] 208 | 209 | # thresholding and re-normalizing 210 | labels_ = torch.where(masks.bool(), labels, t_min*torch.ones_like(labels)) 211 | labels_ = torch.clip(labels_, t_min, t_max) 212 | labels_ = (labels_ - t_min) / (t_max - t_min) 213 | 
labels_th.append(labels_) 214 | 215 | labels = torch.cat(labels_th, 1) 216 | masks = masks.expand_as(labels) 217 | 218 | return labels, masks 219 | 220 | def preprocess_edge_texture(self, imgs, channels): 221 | labels = [] 222 | # detect sobel edge with different set of pre-defined parameters 223 | for c in channels: 224 | labels_ = self.sobel_detectors[c].detect(imgs) 225 | labels.append(labels_) 226 | labels = torch.cat(labels, 1) 227 | 228 | # thresholding and re-normalizing 229 | labels = torch.clip(labels, 0, self.edge_params['threshold']) 230 | labels = labels / self.edge_params['threshold'] 231 | 232 | masks = torch.ones_like(labels) 233 | 234 | return labels, masks 235 | 236 | def preprocess_edge_occlusion(self, labels, masks, channels): 237 | labels = torch.from_numpy(labels).float() 238 | masks = torch.from_numpy(masks).float() 239 | 240 | labels_th = [] 241 | labels = torch.where(masks.bool(), labels, torch.zeros_like(labels)) 242 | for c in channels: 243 | assert c < len(self.edge_thresholds) 244 | t_max = self.edge_thresholds[c] 245 | 246 | # thresholding and re-normalizing 247 | labels_ = torch.clip(labels, 0, t_max) 248 | labels_ = labels_ / t_max 249 | labels_th.append(labels_) 250 | 251 | labels = torch.cat(labels_th, 1) 252 | masks = masks.expand_as(labels) 253 | 254 | return labels, masks 255 | 256 | def preprocess_default(self, labels, masks, channels): 257 | labels = torch.from_numpy(labels).float() 258 | 259 | if masks is not None: 260 | masks = torch.from_numpy(masks).float().expand_as(labels) 261 | else: 262 | masks = torch.ones_like(labels) 263 | 264 | labels = labels[:, channels] 265 | masks = masks[:, channels] 266 | 267 | return labels, masks 268 | 269 | def preprocess_batch(self, task, imgs, labels, masks, channels=None, drop_background=True): 270 | imgs = torch.from_numpy(imgs).float() 271 | 272 | # process all channels if not given 273 | if channels is None: 274 | if task == 'segment_semantic': 275 | channels = SEMSEG_CLASSES 276 | elif task in TASKS_GROUP_DICT: 277 | channels = range(len(TASKS_GROUP_DICT[task])) 278 | else: 279 | raise ValueError(task) 280 | 281 | # task-specific preprocessing 282 | if task == 'segment_semantic': 283 | labels, masks = self.preprocess_segment_semantic(labels, channels, drop_background) 284 | 285 | elif task in ['depth_euclidean', 'depth_zbuffer']: 286 | labels, masks = self.preprocess_depth(labels, masks, channels, task) 287 | 288 | elif task == 'edge_texture': 289 | labels, masks = self.preprocess_edge_texture(imgs, channels) 290 | 291 | elif task == 'edge_occlusion': 292 | labels, masks = self.preprocess_edge_occlusion(labels, masks, channels) 293 | 294 | else: 295 | labels, masks = self.preprocess_default(labels, masks, channels) 296 | 297 | # ensure label values to be in [0, 1] 298 | labels = labels.clip(0, 1) 299 | 300 | # precision conversion 301 | if self.precision == 'fp16': 302 | imgs = imgs.half() 303 | labels = labels.half() 304 | masks = masks.half() 305 | elif self.precision == 'bf16': 306 | imgs = imgs.bfloat16() 307 | labels = labels.bfloat16() 308 | masks = masks.bfloat16() 309 | 310 | return imgs, labels, masks 311 | 312 | 313 | class TaskonomyHybridDataset(TaskonomyBaseDataset): 314 | def __init__(self, root_dir, buildings, tasks, shot, tasks_per_batch, domains_per_batch, 315 | image_augmentation, unary_augmentation, binary_augmentation, mixed_augmentation, dset_size=-1, **kwargs): 316 | super().__init__(root_dir, buildings, tasks, **kwargs) 317 | 318 | assert shot > 0 319 | self.shot = shot 320 | 
self.tasks_per_batch = tasks_per_batch 321 | self.domains_per_batch = domains_per_batch 322 | self.dset_size = dset_size 323 | 324 | if image_augmentation: 325 | self.image_augmentation = RandomHorizontalFlip() 326 | else: 327 | self.image_augmentation = None 328 | 329 | if unary_augmentation: 330 | self.unary_augmentation = RandomCompose( 331 | [augmentation(**kwargs) for augmentation, kwargs in FILTERING_AUGMENTATIONS.values()], 332 | p=0.8, 333 | ) 334 | else: 335 | self.unary_augmentation = None 336 | 337 | if binary_augmentation is not None: 338 | self.binary_augmentation = Mixup() 339 | else: 340 | self.binary_augmentation = None 341 | 342 | self.mixed_augmentation = mixed_augmentation 343 | 344 | def __len__(self): 345 | if self.dset_size > 0: 346 | return self.dset_size 347 | else: 348 | return len(self.img_paths) // self.shot 349 | 350 | def sample_batch(self, task, channel, path_idxs=None): 351 | # sample data paths 352 | if path_idxs is None: 353 | # sample buildings for support and query 354 | buildings = np.random.choice(self.buildings, 2*self.domains_per_batch, replace=False) 355 | 356 | # sample image path indices in each building 357 | path_idxs = np.array([], dtype=np.int64) 358 | for building in buildings: 359 | path_idxs = np.concatenate((path_idxs, 360 | np.random.choice(self.path_dict[building], 361 | self.shot // self.domains_per_batch, replace=False))) 362 | 363 | # load images and labels 364 | imgs = [] 365 | labels = [] 366 | masks = [] 367 | for path_idx in path_idxs: 368 | # index image path 369 | img_path = self.img_paths[path_idx] 370 | 371 | # load image, label, and mask 372 | img, success = self.load_img(img_path) 373 | label, mask = self.load_task(task, img_path) 374 | if not success: 375 | mask = np.zeros_like(label) 376 | 377 | imgs.append(img) 378 | labels.append(label) 379 | masks.append(mask) 380 | 381 | # form a batch 382 | imgs = np.stack(imgs) 383 | labels = np.stack(labels) if labels[0] is not None else None 384 | masks = np.stack(masks) if masks[0] is not None else None 385 | 386 | # preprocess and make numpy arrays to torch tensors 387 | imgs, labels, masks = self.preprocess_batch(task, imgs, labels, masks, [channel]) 388 | 389 | return imgs, labels, masks, path_idxs 390 | 391 | def sample_tasks(self): 392 | # sample subtasks 393 | replace = len(self.subtasks) < self.tasks_per_batch 394 | subtasks = np.random.choice(self.subtasks, self.tasks_per_batch, replace=replace) 395 | 396 | # parse task and channel from the subtasks 397 | tasks = [] 398 | channels = [] 399 | for subtask in subtasks: 400 | # subtask format: "{task}_{channel}" 401 | task = '_'.join(subtask.split('_')[:-1]) 402 | channel = int(subtask.split('_')[-1]) 403 | 404 | tasks.append(task) 405 | channels.append(channel) 406 | 407 | return tasks, channels 408 | 409 | def __getitem__(self, idx): 410 | # sample tasks 411 | tasks, channels = self.sample_tasks() 412 | if self.binary_augmentation is not None: 413 | tasks_aux, channels_aux = self.sample_tasks() 414 | 415 | X = [] 416 | Y = [] 417 | M = [] 418 | t_idx = [] 419 | 420 | path_idxs = None # generated at the first task and then shared for the remaining tasks 421 | for i in range(self.tasks_per_batch): 422 | # sample a batch of images, labels, and masks for each task 423 | X_, Y_, M_, path_idxs = self.sample_batch(tasks[i], channels[i], path_idxs) 424 | 425 | # apply image augmentation 426 | if self.image_augmentation is not None: 427 | (X_, Y_, M_), image_aug = self.image_augmentation(X_, Y_, M_, get_augs=True) 428 | else: 429 | 
image_aug = lambda x: x 430 | 431 | # apply unary task augmentation 432 | if self.unary_augmentation is not None: 433 | Y_, M_ = self.unary_augmentation(Y_, M_) 434 | 435 | # apply binary task augmentation 436 | if self.binary_augmentation is not None: 437 | _, Y_aux_, M_aux_, _, = self.sample_batch(tasks_aux[i], channels_aux[i], path_idxs) 438 | if self.mixed_augmentation and self.image_augmentation is not None: 439 | (Y_aux_, M_aux_) = self.image_augmentation(Y_, M_) 440 | else: 441 | Y_aux_ = image_aug(Y_aux_) 442 | M_aux_ = image_aug(M_aux_) 443 | Y_, M_ = self.binary_augmentation(Y_, Y_aux_, M_, M_aux_) 444 | 445 | X.append(X_) 446 | Y.append(Y_) 447 | M.append(M_) 448 | 449 | t_idx.append(TASKS.index(f'{tasks[i]}_{channels[i]}')) 450 | 451 | # form a global batch 452 | X = torch.stack(X) 453 | Y = torch.stack(Y) 454 | M = torch.stack(M) 455 | 456 | # random-crop arrays 457 | X, Y, M = crop_arrays(X, Y, M, 458 | base_size=self.base_size, 459 | img_size=self.img_size, 460 | random=True) 461 | 462 | # task and task-group index 463 | t_idx = torch.tensor(t_idx) 464 | 465 | return X, Y, M, t_idx 466 | 467 | 468 | class TaskonomyContinuousDataset(TaskonomyBaseDataset): 469 | def __init__(self, root_dir, buildings, task, channel_idx=-1, dset_size=-1, image_augmentation=False, **kwargs): 470 | super().__init__(root_dir, buildings, [task], **kwargs) 471 | 472 | self.task = task 473 | self.channel_idx = channel_idx 474 | self.dset_size = dset_size 475 | self.n_channels = len(TASKS_GROUP_DICT[task]) 476 | 477 | if image_augmentation: 478 | self.image_augmentation = RandomHorizontalFlip() 479 | else: 480 | self.image_augmentation = None 481 | 482 | def __len__(self): 483 | if self.dset_size > 0: 484 | return self.dset_size 485 | else: 486 | return len(self.img_paths) 487 | 488 | def __getitem__(self, idx): 489 | img_path = self.img_paths[idx % len(self.img_paths)] 490 | 491 | # load image, label, and mask 492 | img, success = self.load_img(img_path) 493 | label, mask = self.load_task(self.task, img_path) 494 | if not success: 495 | mask = np.zeros_like(label) 496 | 497 | # preprocess labels 498 | imgs, labels, masks = self.preprocess_batch(self.task, 499 | img[None], 500 | None if label is None else label[None], 501 | None if mask is None else mask[None], 502 | channels=([self.channel_idx] if self.channel_idx >= 0 else None), 503 | drop_background=False) 504 | 505 | 506 | X, Y, M = imgs[0], labels[0], masks[0] 507 | if self.image_augmentation is not None: 508 | X, Y, M = self.image_augmentation(X, Y, M) 509 | 510 | # crop arrays 511 | X, Y, M = crop_arrays(X, Y, M, 512 | base_size=self.base_size, 513 | img_size=self.img_size, 514 | random=True) 515 | 516 | return X, Y, M 517 | 518 | 519 | class TaskonomySegmentationDataset(TaskonomyBaseDataset): 520 | def __init__(self, root_dir, buildings, semseg_class, dset_size=-1, **kwargs): 521 | super().__init__(root_dir, buildings, ['segment_semantic'], **kwargs) 522 | 523 | self.semseg_class = semseg_class 524 | self.img_paths = sorted(os.listdir(os.path.join(self.data_root, 'rgb'))) # use global path dictionary 525 | self.class_dict = torch.load(os.path.join(self.meta_info_path, 'class_dict.pth')) 526 | 527 | self.n_channels = 1 528 | 529 | perm_path = os.path.join(self.meta_info_path, 'idxs_perm_all.pth') 530 | if os.path.exists(perm_path): 531 | idxs_perm = torch.load(perm_path) 532 | else: 533 | idxs_perm = {} 534 | for c in SEMSEG_CLASSES: 535 | n_imgs = 0 536 | for building in BUILDINGS: 537 | n_imgs += len(self.class_dict[building][c]) 538 | 
idxs_perm[c] = torch.randperm(n_imgs) 539 | torch.save(idxs_perm, perm_path) 540 | 541 | # collect all images of class c 542 | class_idxs = [] 543 | for building in BUILDINGS: 544 | class_idxs += self.class_dict[building][self.semseg_class] 545 | 546 | # permute the image indices 547 | class_idxs = (torch.tensor(class_idxs)[idxs_perm[self.semseg_class]]).numpy() 548 | 549 | # filter images in given buildings 550 | self.class_idxs = [class_idx for class_idx in class_idxs if self.img_paths[class_idx].split('_')[0] in buildings] 551 | 552 | self.dset_size = dset_size 553 | 554 | def __len__(self): 555 | if self.dset_size > 0: 556 | return self.dset_size 557 | else: 558 | return len(self.class_idxs) 559 | 560 | def __getitem__(self, idx): 561 | path_idx = self.class_idxs[idx % len(self.class_idxs)] 562 | img_path = self.img_paths[path_idx] 563 | 564 | # load image, label, and mask 565 | img, success = self.load_img(img_path) 566 | label, mask = self.load_task('segment_semantic', img_path) 567 | if not success: 568 | mask = np.zeros_like(mask) 569 | 570 | # preprocess labels 571 | imgs, labels, masks = self.preprocess_batch('segment_semantic', 572 | img[None], 573 | None if label is None else label[None], 574 | None if mask is None else mask[None], 575 | channels=[self.semseg_class]) 576 | 577 | X, Y, M = imgs[0], labels[0], masks[0] 578 | 579 | # crop arrays 580 | X, Y, M = crop_arrays(X, Y, M, 581 | base_size=self.base_size, 582 | img_size=self.img_size, 583 | random=True) 584 | 585 | return X, Y, M 586 | 587 | 588 | class TaskonomyFinetuneDataset(TaskonomyBaseDataset): 589 | def __init__(self, root_dir, buildings, task, support_idx, channel_idx, shot, 590 | dset_size=1, image_augmentation=False, fix_seed=False, shuffle_idxs=True, **kwargs): 591 | super().__init__(root_dir, buildings, [task], **kwargs) 592 | 593 | self.task = task 594 | self.support_idx = support_idx 595 | self.shot = shot 596 | self.dset_size = dset_size 597 | self.channel_idx = channel_idx 598 | self.fix_seed = fix_seed 599 | self.shuffle_idxs = shuffle_idxs 600 | self.offset = support_idx*shot 601 | 602 | if image_augmentation: 603 | self.image_augmentation = RandomHorizontalFlip() 604 | else: 605 | self.image_augmentation = None 606 | 607 | if task == 'segment_semantic': 608 | self.img_paths = sorted(os.listdir(os.path.join(self.data_root, 'rgb'))) 609 | self.class_dict = torch.load(os.path.join(self.meta_info_path, 'class_dict.pth')) 610 | 611 | class_idxs = [] 612 | for building in buildings: 613 | class_idxs += self.class_dict[building][self.channel_idx] 614 | self.class_idxs = class_idxs 615 | 616 | perm_path = os.path.join(self.meta_info_path, f'class_perm_finetune_{self.channel_idx}.pth') 617 | if not os.path.exists(perm_path): 618 | class_perm = torch.randperm(len(self.class_idxs)) 619 | torch.save(class_perm, perm_path) 620 | else: 621 | class_perm = torch.load(perm_path) 622 | self.class_perm = class_perm 623 | 624 | else: 625 | perm_path = os.path.join(self.meta_info_path, 'idxs_perm_finetune.pth') 626 | if not os.path.exists(perm_path): 627 | idxs_perm = torch.randperm(len(self.img_paths)) 628 | torch.save(idxs_perm, perm_path) 629 | else: 630 | idxs_perm = torch.load(perm_path) 631 | self.idxs_perm = idxs_perm 632 | 633 | def __len__(self): 634 | return self.dset_size 635 | 636 | def __getitem__(self, idx): 637 | idxs = [(idx % self.shot) + self.offset + i for i in range(self.shot)] 638 | random.shuffle(idxs) 639 | imgs = [] 640 | labels = [] 641 | masks = [] 642 | for idx_ in idxs: 643 | if self.task == 
'segment_semantic': 644 | if self.shuffle_idxs: 645 | idx_ = self.class_perm[idx_ % len(self.class_idxs)] 646 | path_idx = self.class_idxs[idx_] 647 | img_path = self.img_paths[path_idx] 648 | else: 649 | if self.shuffle_idxs: 650 | idx_ = self.idxs_perm[idx_ % len(self.img_paths)] 651 | img_path = self.img_paths[idx_] 652 | 653 | # load image, label, and mask 654 | img, success = self.load_img(img_path) 655 | label, mask = self.load_task(self.task, img_path) 656 | if not success: 657 | mask = np.zeros_like(label) 658 | 659 | imgs.append(img) 660 | labels.append(label) 661 | masks.append(mask) 662 | 663 | imgs = np.stack(imgs) 664 | labels = np.stack(labels) if labels[0] is not None else None 665 | masks = np.stack(masks) if masks[0] is not None else None 666 | 667 | # preprocess labels 668 | imgs, labels, masks = self.preprocess_batch(self.task, imgs, labels, masks, 669 | channels=([self.channel_idx] if self.channel_idx >= 0 else None), 670 | drop_background=True) 671 | 672 | X = repeat(imgs, 'N C H W -> T N C H W', T=labels.size(1)) 673 | Y = rearrange(labels, 'N T H W -> T N 1 H W') 674 | M = rearrange(masks, 'N T H W -> T N 1 H W') 675 | t_idx = torch.arange(len(Y)) 676 | 677 | if self.image_augmentation is not None and not self.fix_seed: 678 | X, Y, M = self.image_augmentation(X, Y, M) 679 | 680 | # crop arrays 681 | X, Y, M = crop_arrays(X, Y, M, 682 | base_size=self.base_size, 683 | img_size=self.img_size, 684 | random=(not self.fix_seed)) 685 | 686 | return X, Y, M, t_idx 687 | --------------------------------------------------------------------------------
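A minimal usage sketch of the TaskonomyHybridDataset defined above (illustrative only, not part of the repository): the root directory, building names, and hyperparameter values are placeholder assumptions, and it assumes the repository root is on PYTHONPATH and that root_dir contains the per-task sub-directories ('rgb', 'depth_euclidean', 'normal', ...).

from dataset.taskonomy import TaskonomyHybridDataset

# Placeholder arguments; real splits and hyperparameters come from the repository configs.
dataset = TaskonomyHybridDataset(
    root_dir='/path/to/taskonomy',                                    # hypothetical data root
    buildings=['allensville', 'beechwood', 'benevolence', 'coffeen'], # example building names
    tasks=['depth_euclidean', 'normal'],
    shot=4,                 # 2*shot images are drawn per episode (support and query halves)
    tasks_per_batch=2,      # number of (task, channel) pairs sampled per episode
    domains_per_batch=2,    # buildings per half; shot must be divisible by this
    image_augmentation=True,
    unary_augmentation=True,
    binary_augmentation=True,
    mixed_augmentation=False,
    img_size=(224, 224),
)

X, Y, M, t_idx = dataset[0]
# X: (tasks_per_batch, 2*shot, 3, 224, 224) images, shared across the sampled tasks
# Y: (tasks_per_batch, 2*shot, 1, 224, 224) single-channel task labels clipped to [0, 1]
# M: same shape as Y, marking valid pixels (e.g. depth_euclidean < 64500 for depth tasks)
# t_idx: (tasks_per_batch,) indices of the sampled subtasks in TASKS

Within an episode the image paths are sampled once at the first task and then reused for the remaining tasks, so every label tensor in the episode is aligned with the same support and query images.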