├── .gitignore ├── LICENSE ├── README.md ├── args.py ├── dataset ├── __init__.py ├── augmentation.py ├── base.py ├── benchmark_dataset.py ├── coco_api_wrapper.py ├── custom_colorjitter.py ├── meta_info │ ├── ap10k │ │ ├── test_class_dict.pth │ │ ├── test_class_dict_skip_crowd.pth │ │ ├── train_class_dict.pth │ │ ├── train_class_dict_skip_crowd.pth │ │ ├── val_class_dict.pth │ │ └── val_class_dict_skip_crowd.pth │ ├── cellpose │ │ └── train_idxs_perm.pth │ ├── coco │ │ ├── coco_train_class_dict.pth │ │ ├── coco_val_class_dict.pth │ │ ├── cocostuff_train_class_dict.pth │ │ ├── cocostuff_val_class_dict.pth │ │ ├── edge_params.pth │ │ ├── idxs_perm_categorical.pth │ │ ├── idxs_perm_categorical_kp_train.pth │ │ ├── idxs_perm_categorical_kp_train_cropped.pth │ │ ├── idxs_perm_categorical_kp_val.pth │ │ ├── idxs_perm_categorical_kp_val_cropped.pth │ │ ├── idxs_perm_categorical_train.pth │ │ ├── idxs_perm_categorical_val.pth │ │ ├── idxs_perm_train.pth │ │ └── idxs_perm_val.pth │ ├── davis2017 │ │ └── davis2017_n_objects.pth │ ├── deepfashion │ │ ├── edge_params.pth │ │ ├── idxs_perm_train.pth │ │ └── idxs_perm_val.pth │ ├── freihand │ │ ├── edge_params.pth │ │ ├── idxs_perm_train.pth │ │ └── idxs_perm_val.pth │ ├── kitti │ │ └── kitti_eigen_depth_range.pth │ ├── midair │ │ ├── edge_params.pth │ │ ├── idxs_perm_all.pth │ │ ├── idxs_perm_classes.pth │ │ ├── img_files.pth │ │ ├── midair_class_dict.pth │ │ ├── midair_log_depth_range.pth │ │ └── midair_log_disparity_range.pth │ ├── mpii │ │ ├── edge_params.pth │ │ ├── idxs_perm.pth │ │ └── kp_idxs_perm.pth │ ├── nyud │ │ ├── depth_range.pth │ │ ├── log_depth_range.pth │ │ └── train_idxs.pth │ ├── openimages │ │ ├── edge_params.pth │ │ ├── train_files.pth │ │ └── validation_files.pth │ ├── sintel │ │ └── train_perm.pth │ └── taskonomy │ │ ├── class_dict_all.pth │ │ ├── depth_quantiles.pth │ │ ├── edge_params.pth │ │ ├── edge_thresholds.pth │ │ ├── idxs_perm_all.pth │ │ ├── idxs_perm_classes.pth │ │ └── img_files.pth ├── resize_buildings.py ├── unlabeled.py └── utils.py ├── davis2016-evaluation ├── .gitignore ├── README.md ├── apply_crf.py ├── crf.py ├── davis2017 │ ├── __init__.py │ ├── davis.py │ ├── evaluation.py │ ├── metrics.py │ └── utils.py ├── evaluation_codalab.py ├── evaluation_method.py ├── pytest │ └── test_evaluation.py ├── setup.cfg ├── setup.py └── test.py ├── downstream ├── __init__.py ├── ap10k │ ├── configs │ │ ├── test_config.yaml │ │ └── train_config.yaml │ ├── dataset.py │ ├── evaluator.py │ ├── learner.py │ └── utils.py ├── base_learner.py ├── cellpose │ ├── configs │ │ ├── test_config.yaml │ │ └── train_config.yaml │ ├── dataset.py │ ├── learner.py │ └── utils.py ├── davis2017 │ ├── __init__.py │ ├── configs │ │ ├── test_config.yaml │ │ └── train_config.yaml │ ├── dataset.py │ ├── learner.py │ └── utils.py ├── fsc147 │ ├── configs │ │ ├── test_config.yaml │ │ └── train_config.yaml │ ├── dataset.py │ ├── learner.py │ └── utils.py ├── isic2018 │ ├── configs │ │ ├── test_config.yaml │ │ └── train_config.yaml │ ├── dataset.py │ └── learner.py ├── learner_factory.py └── linemod │ ├── configs │ ├── test_config.yaml │ └── train_config.yaml │ ├── dataset.py │ ├── learner.py │ └── utils.py ├── get_beitv2.py ├── main.py ├── main_figure.png ├── meta_train ├── __init__.py ├── datasets │ ├── __init__.py │ ├── coco.py │ ├── coco_stereo.py │ ├── deepfashion.py │ ├── freihand.py │ ├── midair.py │ ├── midair_stereo.py │ ├── mpii.py │ └── taskonomy.py ├── evaluator.py ├── learner.py ├── train_config.yaml ├── unified.py └── utils.py ├── model ├── 
__init__.py ├── chameleon.py ├── decoder.py ├── encoder.py ├── matching.py ├── model_factory.py └── transformers │ ├── __init__.py │ ├── beit.py │ ├── custom_layers.py │ ├── factory.py │ ├── helpers.py │ ├── registry.py │ └── vision_transformer.py ├── preprocess_checkpoints.py ├── requirements.txt ├── scripts ├── ap10k │ ├── finetune.sh │ ├── finetune_all.sh │ ├── run.sh │ ├── run_all.sh │ ├── test.sh │ └── test_all.sh ├── cellpose │ ├── finetune.sh │ ├── run.sh │ └── test.sh ├── davis2017 │ ├── davis_evaluation.sh │ ├── finetune.sh │ ├── finetune_all.sh │ ├── run.sh │ ├── run_all.sh │ ├── run_all_in_one.sh │ ├── test.sh │ └── test_all.sh ├── fsc147 │ ├── finetune.sh │ ├── run.sh │ └── test.sh ├── isic2018 │ ├── finetune.sh │ ├── finetune_all.sh │ ├── run.sh │ └── test.sh ├── linemod │ ├── finetune_all.sh │ ├── finetune_pose.sh │ ├── finetune_segment.sh │ ├── finetune_segment_all.sh │ ├── run.sh │ ├── run_all_in_one.sh │ ├── run_pose.sh │ ├── run_pose_all.sh │ ├── run_segment.sh │ ├── run_segment_all.sh │ ├── test_all.sh │ ├── test_pose.sh │ ├── test_segment.sh │ └── test_segment_all.sh └── unified │ └── train.sh └── train ├── __init__.py ├── loss.py ├── miou_fss.py ├── optim.py ├── train_utils.py ├── trainer.py ├── visualize.py └── zero_to_fp32.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb_checkpoints* 2 | *__pycache__* 3 | *.ipynb 4 | *.code-workspace 5 | experiments* 6 | support_data*.pth 7 | .idea* 8 | .vscode* 9 | data_paths.yaml 10 | notebooks* 11 | lightning_logs* 12 | results* 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Donggyun Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chameleon: A Data-Efficient Generalist for Dense Visual Prediction in the Wild 2 | 3 | (Update 03/13, 2025) We have uploaded pretrained checkpoints in [this link](https://drive.google.com/drive/folders/1m71RO6-8lN6q4yv8odCjG01loFUHq3-b?usp=drive_link). 4 | 5 | This repository contains official code for [Chameleon: A Data-Efficient Generalist for Dense Visual Prediction in the Wild](https://arxiv.org/abs/2404.18459) (ECCV 2024 oral). 
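As a quick sanity check after downloading, a pretrained checkpoint can be inspected with plain PyTorch. This is a minimal sketch, not part of the repository's scripts; the file name below is a placeholder for whichever `.pth` file was downloaded from the Drive folder.

```python
import torch

# Placeholder name -- substitute the checkpoint file downloaded from the Drive folder.
ckpt = torch.load('chameleon_checkpoint.pth', map_location='cpu')

# Some trainers nest the weights under a 'state_dict' key; otherwise use the raw dict.
state_dict = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
print(f'loaded {len(state_dict)} entries')
```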
6 | 7 | The documentation for downloading and preprocessing datasets will be uploaded soon. 8 | 9 | Please check out the `gaudi-v2` branch for the Gaudi-v2 implementation of Chameleon. 10 | 11 | ![image-Chameleon](https://github.com/GitGyun/chameleon/blob/main/main_figure.png) 12 | 13 | ## Citation 14 | If you find this work useful, please consider citing: 15 | ```bib 16 | @article{kim2024chameleon, 17 | title={Chameleon: A Data-Efficient Generalist for Dense Visual Prediction in the Wild}, 18 | author={Kim, Donggyun and Cho, Seongwoong and Kim, Semin and Luo, Chong and Hong, Seunghoon}, 19 | journal={arXiv preprint arXiv:2404.18459}, 20 | year={2024} 21 | } 22 | ``` 23 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import random 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from .utils import crop_arrays, SobelEdgeDetector 7 | from .augmentation import RandomHorizontalFlip, FILTERING_AUGMENTATIONS, RandomCompose, Mixup, CustomTrivialAugmentWide 8 | from .unified_constants import TASKS_GROUP_DICT 9 | 10 | 11 | class BaseDataset(Dataset): 12 | def __init__(self, root_dir, domains, tasks, component, base_size=(256, 256), 13 | img_size=(224, 224), seed=None, precision='fp32', meta_dir='meta_info'): 14 | super().__init__() 15 | 16 | if seed is not None: 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed(seed) 21 | 22 | self.data_root = os.path.join(root_dir, f'{component}_{base_size[0]}_merged') 23 | self.domains = sorted(domains) 24 | 25 | self.subtasks = tasks 26 | 27 | self.base_size = base_size 28 | self.img_size = img_size 29 | self.precision = precision 30 | 31 | self.meta_info_path = os.path.join('dataset', meta_dir) 32 | self.edge_params = torch.load(os.path.join(self.meta_info_path, 'edge_params.pth')) 33 | self.sobel_detectors = [SobelEdgeDetector(kernel_size=k, sigma=s) for k, s in self.edge_params['params']] 34 | 35 | def load_image(self, img_path): 36 | raise NotImplementedError 37 | 38 | def load_task(self, task, img_path): 39 | raise NotImplementedError 40 | 41 | def preprocess_batch(self, task, imgs, labels, masks, channels=None, drop_background=True): 42 | raise NotImplementedError 43 | 44 | 45 | class TrainDataset(BaseDataset): 46 | def __init__(self, root_dir, domains, tasks, shot, tasks_per_batch, domains_per_batch, 47 | image_augmentation, unary_augmentation, binary_augmentation, 48 | dset_size=-1, **kwargs): 49 | super().__init__(root_dir, domains, tasks, **kwargs) 50 | 51 | assert shot > 0 52 | self.shot = shot 53 | self.tasks_per_batch = tasks_per_batch 54 | self.domains_per_batch = min(len(domains)//2, domains_per_batch) 55 | self.dset_size = dset_size 56 | 57 | if image_augmentation: 58 | self.image_augmentation = RandomHorizontalFlip() 59 | else: 60 | self.image_augmentation = None 61 | 62 | if unary_augmentation: 63 | self.unary_augmentation = RandomCompose( 64 | [augmentation(**kwargs) for augmentation, kwargs in FILTERING_AUGMENTATIONS.values()], 65 | p=0.8, 66 | ) 67 | else: 68 | self.unary_augmentation = None 69 | 70 | if 
binary_augmentation is not None: 71 | self.binary_augmentation = Mixup(order=True) 72 | else: 73 | self.binary_augmentation = None 74 | 75 | def __len__(self): 76 | if self.dset_size > 0: 77 | return self.dset_size 78 | else: 79 | return len(self.img_paths) // self.shot 80 | 81 | def __getitem__(self, idx): 82 | return self.sample(idx, self.tasks_per_batch) 83 | 84 | def sample(self, idx, n_channels): 85 | raise NotImplementedError 86 | 87 | 88 | class ContinuousDataset(BaseDataset): 89 | def __init__(self, root_dir, domains, task, channel_idx=-1, dset_size=-1, **kwargs): 90 | super().__init__(root_dir, domains, [task], **kwargs) 91 | 92 | self.task = task 93 | self.channel_idx = channel_idx 94 | self.dset_size = dset_size 95 | self.n_channels = len(TASKS_GROUP_DICT[task]) 96 | 97 | def __len__(self): 98 | if self.dset_size > 0: 99 | return self.dset_size 100 | else: 101 | return len(self.img_paths) 102 | 103 | def __getitem__(self, idx): 104 | img_path = self.img_paths[idx % len(self.img_paths)] 105 | 106 | # load image, label, and mask 107 | img, success = self.load_img(img_path) 108 | label, mask = self.load_task(self.task, img_path) 109 | if not success: 110 | mask = np.zeros_like(label) 111 | 112 | # preprocess labels 113 | imgs, labels, masks = self.preprocess_batch(self.task, 114 | img[None], 115 | None if label is None else label[None], 116 | None if mask is None else mask[None], 117 | channels=([self.channel_idx] if self.channel_idx >= 0 else None), 118 | drop_background=False) 119 | 120 | 121 | X, Y, M = imgs[0], labels[0], masks[0] 122 | if self.image_augmentation is not None: 123 | X, Y, M = self.image_augmentation(X, Y, M) 124 | 125 | # crop arrays 126 | X, Y, M = crop_arrays(X, Y, M, 127 | base_size=self.base_size, 128 | crop_size=self.img_size, 129 | random=True) 130 | 131 | return X, Y, M 132 | 133 | 134 | class SegmentationDataset(BaseDataset): 135 | def __init__(self, root_dir, domains, semseg_class=-1, dset_size=-1, **kwargs): 136 | super().__init__(root_dir, domains, ['segment_semantic'], **kwargs) 137 | 138 | self.semseg_class = semseg_class 139 | self.img_paths = torch.load(os.path.join(self.meta_info_path, 'img_files.pth')) # use global path dictionary 140 | self.n_channels = 1 141 | self.dset_size = dset_size 142 | 143 | def generate_class_idxs(self): 144 | raise NotImplementedError 145 | 146 | def __len__(self): 147 | if self.dset_size > 0: 148 | return self.dset_size 149 | else: 150 | return len(self.class_idxs) 151 | 152 | def __getitem__(self, idx): 153 | path_idx = self.class_idxs[idx % len(self.class_idxs)] 154 | img_path = self.img_paths[path_idx] 155 | 156 | # load image, label, and mask 157 | img, success = self.load_img(img_path) 158 | label, mask = self.load_task('segment_semantic', img_path) 159 | if not success: 160 | mask = np.zeros_like(mask) 161 | 162 | # preprocess labels 163 | imgs, labels, masks = self.preprocess_batch('segment_semantic', 164 | img[None], 165 | None if label is None else label[None], 166 | None if mask is None else mask[None], 167 | [self.semseg_class]) 168 | 169 | X, Y, M = imgs[0], labels[0], masks[0] 170 | 171 | # crop arrays 172 | X, Y, M = crop_arrays(X, Y, M, 173 | base_size=self.base_size, 174 | crop_size=self.img_size, 175 | random=True) 176 | 177 | return X, Y, M 178 | -------------------------------------------------------------------------------- /dataset/benchmark_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 
| 5 | class BenchmarkDataset(Dataset): 6 | def __init__(self, shot, dset_size, n_channels): 7 | self.shot = shot 8 | self.dset_size = dset_size 9 | self.n_channels = n_channels 10 | self.task_group_names = [f'proxy_{i}' for i in range(n_channels)] 11 | self.TASK_GROUP_NAMES = self.task_group_names 12 | 13 | def __len__(self): 14 | return self.dset_size 15 | 16 | def __getitem__(self, idx): 17 | X = torch.rand(self.n_channels, self.shot*2, 3, 224, 224) 18 | Y = torch.rand(self.n_channels, self.shot*2, 1, 224, 224) 19 | M = torch.ones_like(Y) 20 | t_idx = torch.arange(self.n_channels) 21 | g_idx = torch.LongTensor([0])[0] 22 | channel_mask = torch.ones(self.n_channels, self.n_channels, dtype=torch.bool) 23 | 24 | return X, Y, M, t_idx, g_idx, channel_mask 25 | -------------------------------------------------------------------------------- /dataset/custom_colorjitter.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | 3 | import torch 4 | import torchvision.transforms.functional as F 5 | 6 | 7 | class CustomColorJitter(torch.nn.Module): 8 | """Randomly change the brightness, contrast and saturation of an image. 9 | 10 | Args: 11 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 12 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 13 | or the given [min, max]. Should be non negative numbers. 14 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 15 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 16 | or the given [min, max]. Should be non negative numbers. 17 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 18 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 19 | or the given [min, max]. Should be non negative numbers. 20 | hue (float or tuple of float (min, max)): How much to jitter hue. 21 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 22 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 23 | """ 24 | 25 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 26 | super().__init__() 27 | self.brightness = self._check_input(brightness, 'brightness') 28 | self.contrast = self._check_input(contrast, 'contrast') 29 | self.saturation = self._check_input(saturation, 'saturation') 30 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 31 | clip_first_on_zero=False) 32 | 33 | @torch.jit.unused 34 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 35 | if isinstance(value, numbers.Number): 36 | if value < 0: 37 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 38 | value = [center - float(value), center + float(value)] 39 | if clip_first_on_zero: 40 | value[0] = max(value[0], 0.0) 41 | elif isinstance(value, (tuple, list)) and len(value) == 2: 42 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 43 | raise ValueError("{} values should be between {}".format(name, bound)) 44 | else: 45 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 46 | 47 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 48 | # or (0., 0.) 
for hue, do nothing 49 | if value[0] == value[1] == center: 50 | value = None 51 | return value 52 | 53 | def forward(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor): 54 | """ 55 | Args: 56 | img (PIL Image or Tensor): Input image. 57 | 58 | Returns: 59 | PIL Image or Tensor: Color jittered image. 60 | """ 61 | # fn_idx = torch.randperm(4) 62 | for fn_id in fn_idx: 63 | if fn_id == 0 and self.brightness is not None: 64 | # brightness = self.brightness 65 | # brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() 66 | img = F.adjust_brightness(img, brightness_factor) 67 | 68 | if fn_id == 1 and self.contrast is not None: 69 | # contrast = self.contrast 70 | # contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() 71 | img = F.adjust_contrast(img, contrast_factor) 72 | 73 | if fn_id == 2 and self.saturation is not None: 74 | # saturation = self.saturation 75 | # saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() 76 | img = F.adjust_saturation(img, saturation_factor) 77 | 78 | if fn_id == 3 and self.hue is not None: 79 | # hue = self.hue 80 | # hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() 81 | img = F.adjust_hue(img, hue_factor) 82 | 83 | return img 84 | 85 | def __repr__(self): 86 | format_string = self.__class__.__name__ + '(' 87 | format_string += 'brightness={0}'.format(self.brightness) 88 | format_string += ', contrast={0}'.format(self.contrast) 89 | format_string += ', saturation={0}'.format(self.saturation) 90 | format_string += ', hue={0})'.format(self.hue) 91 | return format_string 92 | -------------------------------------------------------------------------------- /dataset/meta_info/ap10k/test_class_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/ap10k/test_class_dict.pth -------------------------------------------------------------------------------- /dataset/meta_info/ap10k/test_class_dict_skip_crowd.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/ap10k/test_class_dict_skip_crowd.pth -------------------------------------------------------------------------------- /dataset/meta_info/ap10k/train_class_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/ap10k/train_class_dict.pth -------------------------------------------------------------------------------- /dataset/meta_info/ap10k/train_class_dict_skip_crowd.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/ap10k/train_class_dict_skip_crowd.pth -------------------------------------------------------------------------------- /dataset/meta_info/ap10k/val_class_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/ap10k/val_class_dict.pth -------------------------------------------------------------------------------- 
/dataset/meta_info/ap10k/val_class_dict_skip_crowd.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/ap10k/val_class_dict_skip_crowd.pth -------------------------------------------------------------------------------- /dataset/meta_info/cellpose/train_idxs_perm.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/cellpose/train_idxs_perm.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/coco_train_class_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/coco_train_class_dict.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/coco_val_class_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/coco_val_class_dict.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/cocostuff_train_class_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/cocostuff_train_class_dict.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/cocostuff_val_class_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/cocostuff_val_class_dict.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/edge_params.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/edge_params.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/idxs_perm_categorical.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/idxs_perm_categorical.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/idxs_perm_categorical_kp_train.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/idxs_perm_categorical_kp_train.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/idxs_perm_categorical_kp_train_cropped.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/idxs_perm_categorical_kp_train_cropped.pth 
-------------------------------------------------------------------------------- /dataset/meta_info/coco/idxs_perm_categorical_kp_val.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/idxs_perm_categorical_kp_val.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/idxs_perm_categorical_kp_val_cropped.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/idxs_perm_categorical_kp_val_cropped.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/idxs_perm_categorical_train.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/idxs_perm_categorical_train.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/idxs_perm_categorical_val.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/idxs_perm_categorical_val.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/idxs_perm_train.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/idxs_perm_train.pth -------------------------------------------------------------------------------- /dataset/meta_info/coco/idxs_perm_val.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/coco/idxs_perm_val.pth -------------------------------------------------------------------------------- /dataset/meta_info/davis2017/davis2017_n_objects.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/davis2017/davis2017_n_objects.pth -------------------------------------------------------------------------------- /dataset/meta_info/deepfashion/edge_params.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/deepfashion/edge_params.pth -------------------------------------------------------------------------------- /dataset/meta_info/deepfashion/idxs_perm_train.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/deepfashion/idxs_perm_train.pth -------------------------------------------------------------------------------- /dataset/meta_info/deepfashion/idxs_perm_val.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/deepfashion/idxs_perm_val.pth -------------------------------------------------------------------------------- /dataset/meta_info/freihand/edge_params.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/freihand/edge_params.pth -------------------------------------------------------------------------------- /dataset/meta_info/freihand/idxs_perm_train.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/freihand/idxs_perm_train.pth -------------------------------------------------------------------------------- /dataset/meta_info/freihand/idxs_perm_val.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/freihand/idxs_perm_val.pth -------------------------------------------------------------------------------- /dataset/meta_info/kitti/kitti_eigen_depth_range.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/kitti/kitti_eigen_depth_range.pth -------------------------------------------------------------------------------- /dataset/meta_info/midair/edge_params.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/midair/edge_params.pth -------------------------------------------------------------------------------- /dataset/meta_info/midair/idxs_perm_all.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/midair/idxs_perm_all.pth -------------------------------------------------------------------------------- /dataset/meta_info/midair/idxs_perm_classes.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/midair/idxs_perm_classes.pth -------------------------------------------------------------------------------- /dataset/meta_info/midair/img_files.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/midair/img_files.pth -------------------------------------------------------------------------------- /dataset/meta_info/midair/midair_class_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/midair/midair_class_dict.pth -------------------------------------------------------------------------------- /dataset/meta_info/midair/midair_log_depth_range.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/midair/midair_log_depth_range.pth -------------------------------------------------------------------------------- /dataset/meta_info/midair/midair_log_disparity_range.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/midair/midair_log_disparity_range.pth -------------------------------------------------------------------------------- /dataset/meta_info/mpii/edge_params.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/mpii/edge_params.pth -------------------------------------------------------------------------------- /dataset/meta_info/mpii/idxs_perm.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/mpii/idxs_perm.pth -------------------------------------------------------------------------------- /dataset/meta_info/mpii/kp_idxs_perm.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/mpii/kp_idxs_perm.pth -------------------------------------------------------------------------------- /dataset/meta_info/nyud/depth_range.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/nyud/depth_range.pth -------------------------------------------------------------------------------- /dataset/meta_info/nyud/log_depth_range.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/nyud/log_depth_range.pth -------------------------------------------------------------------------------- /dataset/meta_info/nyud/train_idxs.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/nyud/train_idxs.pth -------------------------------------------------------------------------------- /dataset/meta_info/openimages/edge_params.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/openimages/edge_params.pth -------------------------------------------------------------------------------- /dataset/meta_info/openimages/train_files.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/openimages/train_files.pth -------------------------------------------------------------------------------- /dataset/meta_info/openimages/validation_files.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/openimages/validation_files.pth -------------------------------------------------------------------------------- /dataset/meta_info/sintel/train_perm.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/sintel/train_perm.pth -------------------------------------------------------------------------------- /dataset/meta_info/taskonomy/class_dict_all.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/taskonomy/class_dict_all.pth -------------------------------------------------------------------------------- /dataset/meta_info/taskonomy/depth_quantiles.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/taskonomy/depth_quantiles.pth -------------------------------------------------------------------------------- /dataset/meta_info/taskonomy/edge_params.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/taskonomy/edge_params.pth -------------------------------------------------------------------------------- /dataset/meta_info/taskonomy/edge_thresholds.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/taskonomy/edge_thresholds.pth -------------------------------------------------------------------------------- /dataset/meta_info/taskonomy/idxs_perm_all.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/taskonomy/idxs_perm_all.pth -------------------------------------------------------------------------------- /dataset/meta_info/taskonomy/idxs_perm_classes.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/taskonomy/idxs_perm_classes.pth -------------------------------------------------------------------------------- /dataset/meta_info/taskonomy/img_files.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/dataset/meta_info/taskonomy/img_files.pth -------------------------------------------------------------------------------- /dataset/resize_buildings.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import os 3 | import torch 4 | import tqdm 5 | import yaml 6 | from PIL import Image 7 | 8 | building_list = [ 9 | 'allensville', 10 | 'beechwood', 11 | 'benevolence', 12 | 'collierville', 13 | 'coffeen', 14 | 'corozal', 15 | 'cosmos', 16 | 'darden', 17 | 'forkland', 18 | 'hanson', 19 | 'hiteman', 20 | 'ihlen', 21 | 'klickitat', 22 | 'lakeville', 23 | 'leonardo', 24 | 'lindenwood', 25 | 
'markleeville', 26 | 'marstons', 27 | 'mcdade', 28 | 'merom', 29 | 'mifflinburg', 30 | 'muleshoe', 31 | 'newfields', 32 | 'noxapater', 33 | 'onaga', 34 | 'pinesdale', 35 | 'pomaria', 36 | 'ranchester', 37 | 'shelbyville', 38 | 'stockman', 39 | 'tolstoy', 40 | 'uvalda', 41 | 'wainscott', 42 | 'wiconisco', 43 | 'woodbine', 44 | ] 45 | 46 | task_list = [ 47 | 'rgb', 48 | 'normal', 49 | 'depth_euclidean', 50 | 'depth_zbuffer', 51 | 'edge_occlusion', 52 | 'keypoints2d', 53 | 'keypoints3d', 54 | 'reshading', 55 | 'principal_curvature', 56 | 'segment_semantic' 57 | ] 58 | 59 | 60 | def resize(args): 61 | load_path, save_path, mode = args 62 | try: 63 | img = Image.open(load_path) 64 | img = img.resize(size, mode) 65 | img.save(save_path) 66 | return None 67 | except Exception as e: 68 | print(e) 69 | return load_path 70 | 71 | 72 | if __name__ == "__main__": 73 | verbose = True 74 | size = (256, 256) 75 | split = "tiny" 76 | n_threads = 20 77 | 78 | with open('data_paths.yaml', 'r') as f: 79 | path_dict = yaml.safe_load(f) 80 | load_root = save_root = path_dict['taskonomy'] 81 | 82 | load_dir = os.path.join(load_root, split) 83 | assert os.path.isdir(load_dir) 84 | ''' 85 | load_dir 86 | |--building 87 | |--task 88 | |--file 89 | ''' 90 | save_dir = os.path.join(save_root, f"{split}_{size[0]}_merged") 91 | os.makedirs(save_dir, exist_ok=True) 92 | ''' 93 | save_dir 94 | |--task 95 | |--file 96 | ''' 97 | 98 | args = [] 99 | print("creating args...") 100 | for b_idx, building in enumerate(building_list): 101 | assert os.path.isdir(os.path.join(load_dir, building)) 102 | for task in task_list: 103 | mode = Image.NEAREST if task == "segment_semantic" else Image.BILINEAR 104 | if b_idx == 0: 105 | os.makedirs(os.path.join(save_dir, task), exist_ok=True) 106 | 107 | load_names = os.listdir(os.path.join(load_dir, building, task)) 108 | load_paths = [os.path.join(load_dir, building, task, load_name) for load_name in load_names] 109 | save_paths = [os.path.join(save_dir, task, f'{building}_{load_name}') for load_name in load_names] 110 | modes = [mode]*len(load_names) 111 | args += list(zip(load_paths, save_paths, modes)) 112 | 113 | fail_list = [] 114 | pool = Pool(n_threads) 115 | total = len(args) 116 | pbar = tqdm.tqdm(total=total, bar_format="{desc:<5}{percentage:3.0f}%|{bar:10}{r_bar}") 117 | for fail_path in pool.imap(resize, args): 118 | if fail_path is not None: 119 | fail_list += [fail_path] 120 | pbar.update() 121 | pbar.close() 122 | 123 | torch.save(fail_list, "fail_list.pth") 124 | 125 | pool.close() 126 | pool.join() 127 | -------------------------------------------------------------------------------- /davis2016-evaluation/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | docs/site/ 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # pytest 105 | .pytest_cache 106 | 107 | # Pylint 108 | .pylintrc 109 | 110 | # PyCharm 111 | .idea/ 112 | .DS_Store 113 | 114 | # Generated C code 115 | _mask.c 116 | -------------------------------------------------------------------------------- /davis2016-evaluation/README.md: -------------------------------------------------------------------------------- 1 | # DAVIS 2016 evaluation 2 | 3 | This is **not** an official script. 4 | 5 | Using the [precomputed results](https://davischallenge.org/davis2016/soa_compare.html), the numbers are the same as those on the leaderboard so I think this script is correct. Note that it accepts results in the 0~255 (thresholded at 128) format, not the 0/1 pixel format. 
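If your predictions are stored as 0/1 masks, the sketch below shows one way to convert them to the 0~255 format this script expects. It is not part of the original evaluation code, and the directory names are hypothetical; it assumes the usual DAVIS result layout of one folder of PNGs per sequence.

```python
import os

import numpy as np
from PIL import Image

src_dir, dst_dir = 'results_01', 'results_0255'  # hypothetical input/output folders
for seq in sorted(os.listdir(src_dir)):
    os.makedirs(os.path.join(dst_dir, seq), exist_ok=True)
    for name in sorted(os.listdir(os.path.join(src_dir, seq))):
        mask = np.array(Image.open(os.path.join(src_dir, seq, name)))
        # Scale foreground to 255 so thresholding at 128 recovers the same binary mask.
        out = (mask > 0).astype(np.uint8) * 255
        Image.fromarray(out).save(os.path.join(dst_dir, seq, name))
```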
6 | 7 | Example: 8 | 9 | ```bash 10 | python evaluation_method.py --task semi-supervised --davis_path [path to davis 2017 trainval] --year 2016 --results_path ../mhpvos 11 | ``` 12 | 13 | 14 | See also: 15 | 16 | https://github.com/davisvideochallenge/davis2017-evaluation 17 | 18 | https://github.com/davisvideochallenge/davis2017-evaluation/issues/4 -------------------------------------------------------------------------------- /davis2016-evaluation/apply_crf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from crf import dense_crf 4 | from PIL import Image 5 | import numpy as np 6 | import torch 7 | from torchvision.transforms import Resize 8 | from torchvision.utils import save_image 9 | from tqdm import tqdm 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--image_dir', type=str, required=True) 14 | parser.add_argument('--result_dir', type=str, required=True) 15 | parser.add_argument('--save_dir', type=str, required=True) 16 | args = parser.parse_args() 17 | 18 | upsample = Resize((480, 854)) 19 | 20 | obj_dirs = [] 21 | save_dirs = [] 22 | for obj in os.listdir(args.result_dir): 23 | obj_dir = os.path.join(args.result_dir, obj) 24 | if os.path.isdir(obj_dir): 25 | obj_dirs.append(obj_dir) 26 | save_dir = os.path.join(args.save_dir, obj) 27 | os.makedirs(save_dir, exist_ok=True) 28 | save_dirs.append(save_dir) 29 | 30 | pbar = tqdm(total=len(obj_dirs), bar_format="{desc:<5}{percentage:3.0f}%|{bar:10}{r_bar}") 31 | for obj_dir, save_dir in zip(obj_dirs, save_dirs): 32 | obj = obj_dir.split('/')[-1] 33 | for file in os.listdir(obj_dir): 34 | if file.split('.')[-1] == 'pth': 35 | score_path = os.path.join(obj_dir, file) 36 | score = torch.load(score_path).float() 37 | 38 | img_path = os.path.join(args.image_dir, obj, file.replace('.pth', '.jpg')) 39 | img = np.array(Image.open(img_path)) 40 | 41 | save_path = os.path.join(save_dir, file.replace('.pth', '.png')) 42 | 43 | prob = upsample(torch.cat((torch.zeros_like(score), 44 | score)).softmax(dim=0).cpu()).numpy() 45 | pred = torch.argmax(torch.from_numpy(dense_crf(img, prob)), dim=0).float() 46 | save_image(pred, save_path) 47 | pbar.update() 48 | -------------------------------------------------------------------------------- /davis2016-evaluation/crf.py: -------------------------------------------------------------------------------- 1 | import pydensecrf.densecrf as dcrf 2 | import pydensecrf.utils as utils 3 | import numpy as np 4 | 5 | 6 | def dense_crf(img, output_probs): 7 | """ Conditional Random Field for better segmentation 8 | Refer to https://github.com/lucasb-eyer/pydensecrf for details. 
9 | """ 10 | 11 | c = output_probs.shape[0] 12 | h = output_probs.shape[1] 13 | w = output_probs.shape[2] 14 | 15 | U = utils.unary_from_softmax(output_probs) 16 | U = np.ascontiguousarray(U) 17 | 18 | img = np.ascontiguousarray(img) 19 | 20 | d = dcrf.DenseCRF2D(w, h, c) 21 | d.setUnaryEnergy(U) 22 | d.addPairwiseGaussian(sxy=1, compat=15) 23 | d.addPairwiseBilateral(sxy=67, srgb=3, rgbim=img, compat=4) 24 | 25 | Q = d.inference(10) 26 | Q = np.array(Q).reshape((c, h, w)) 27 | return Q 28 | -------------------------------------------------------------------------------- /davis2016-evaluation/davis2017/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | __version__ = '0.1.0' 4 | -------------------------------------------------------------------------------- /davis2016-evaluation/davis2017/davis.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from collections import defaultdict 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | class DAVIS(object): 9 | SUBSET_OPTIONS = ['train', 'val', 'test-dev', 'test-challenge'] 10 | TASKS = ['semi-supervised', 'unsupervised'] 11 | YEARS = ['2016', '2017', '2019'] 12 | DATASET_WEB = 'https://davischallenge.org/davis2017/code.html' 13 | VOID_LABEL = 255 14 | 15 | def __init__(self, root, task='unsupervised', subset='val', sequences='all', resolution='480p', codalab=False, year='2017'): 16 | """ 17 | Class to read the DAVIS dataset 18 | :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders. 19 | :param task: Task to load the annotations, choose between semi-supervised or unsupervised. 20 | :param subset: Set to load the annotations 21 | :param sequences: Sequences to consider, 'all' to use all the sequences in a set. 
22 | :param resolution: Specify the resolution to use the dataset, choose between '480' and 'Full-Resolution' 23 | """ 24 | if subset not in self.SUBSET_OPTIONS: 25 | raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}') 26 | if task not in self.TASKS: 27 | raise ValueError(f'The only tasks that are supported are {self.TASKS}') 28 | if year not in self.YEARS: 29 | raise ValueError(f'Year should be one of the following {self.YEARS}') 30 | 31 | self.task = task 32 | self.subset = subset 33 | self.root = root 34 | if year == '2017': 35 | self.root = os.path.join(self.root, 'trainval') 36 | self.img_path = os.path.join(self.root, 'JPEGImages', resolution) 37 | annotations_folder = 'Annotations' if task == 'semi-supervised' else 'Annotations_unsupervised' 38 | self.mask_path = os.path.join(self.root, annotations_folder, resolution) 39 | # year = '2019' if task == 'unsupervised' and (subset == 'test-dev' or subset == 'test-challenge') else '2017' 40 | # self.imagesets_path = os.path.join(self.root, 'ImageSets', year) 41 | 42 | self.year = year 43 | if self.year == '2019' and not (task == 'unsupervised' and (subset == 'test-dev' or subset == 'test-challenge')): 44 | raise ValueError("Set 'task' to 'unsupervised' and subset to 'test-dev' or 'test-challenge'") 45 | self.imagesets_path = os.path.join(self.root, 'ImageSets', self.year) 46 | if self.year == '2016': 47 | self.imagesets_path = os.path.join(self.root.replace('2016', '2017'), 'trainval', 'ImageSets', self.year) 48 | 49 | self._check_directories() 50 | 51 | if sequences == 'all': 52 | with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f: 53 | tmp = f.readlines() 54 | sequences_names = [x.strip() for x in tmp] 55 | else: 56 | sequences_names = sequences if isinstance(sequences, list) else [sequences] 57 | self.sequences = defaultdict(dict) 58 | 59 | for seq in sequences_names: 60 | images = np.sort(glob(os.path.join(self.img_path, seq, '*.jpg'))).tolist() 61 | if len(images) == 0 and not codalab: 62 | raise FileNotFoundError(f'Images for sequence {seq} not found.') 63 | self.sequences[seq]['images'] = images 64 | masks = np.sort(glob(os.path.join(self.mask_path, seq, '*.png'))).tolist() 65 | masks.extend([-1] * (len(images) - len(masks))) 66 | self.sequences[seq]['masks'] = masks 67 | 68 | def _check_directories(self): 69 | if not os.path.exists(self.root): 70 | raise FileNotFoundError(f'DAVIS not found in the specified directory in {self.root}, download it from {self.DATASET_WEB}') 71 | if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')): 72 | raise FileNotFoundError(f'Subset sequences list for {self.subset} not found in {self.imagests_path}, download the missing subset ' 73 | f'for the {self.task} task from {self.DATASET_WEB}') 74 | if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path): 75 | raise FileNotFoundError(f'Annotations folder for the {self.task} task not found in {self.mask_path}, download it from {self.DATASET_WEB}') 76 | 77 | def get_frames(self, sequence): 78 | for img, msk in zip(self.sequences[sequence]['images'], self.sequences[sequence]['masks']): 79 | image = np.array(Image.open(img)) 80 | mask = None if msk is None else np.array(Image.open(msk)) 81 | yield image, mask 82 | 83 | def _get_all_elements(self, sequence, obj_type): 84 | obj = np.array(Image.open(self.sequences[sequence][obj_type][0])) 85 | all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape)) 86 | obj_id = [] 87 | for i, obj in 
enumerate(self.sequences[sequence][obj_type]): 88 | all_objs[i, ...] = np.array(Image.open(obj)) 89 | obj_id.append(''.join(obj.split('/')[-1].split('.')[:-1])) 90 | if self.year == '2016' and obj_type != 'scores': 91 | all_objs /= 255. 92 | return all_objs, obj_id 93 | 94 | def get_all_images(self, sequence): 95 | return self._get_all_elements(sequence, 'images') 96 | 97 | def get_all_masks(self, sequence, separate_objects_masks=False): 98 | masks, masks_id = self._get_all_elements(sequence, 'masks') 99 | masks_void = np.zeros_like(masks) 100 | 101 | # Separate void and object masks 102 | for i in range(masks.shape[0]): 103 | masks_void[i, ...] = masks[i, ...] == 255 104 | masks[i, masks[i, ...] == 255] = 0 105 | 106 | if separate_objects_masks: 107 | num_objects = int(np.max(masks[0, ...])) 108 | tmp = np.ones((num_objects, *masks.shape)) 109 | tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] 110 | masks = (tmp == masks[None, ...]) 111 | masks = masks > 0 112 | else: 113 | # for single object evaluation (e.g. DAVIS2016) 114 | masks = np.expand_dims(masks, axis=0) 115 | masks = masks > 0 116 | return masks, masks_void, masks_id 117 | 118 | def get_sequences(self): 119 | for seq in self.sequences: 120 | yield seq 121 | 122 | 123 | if __name__ == '__main__': 124 | from matplotlib import pyplot as plt 125 | 126 | only_first_frame = True 127 | subsets = ['train', 'val'] 128 | 129 | for s in subsets: 130 | dataset = DAVIS(root='/home/csergi/scratch2/Databases/DAVIS2017_private', subset=s) 131 | for seq in dataset.get_sequences(): 132 | g = dataset.get_frames(seq) 133 | img, mask = next(g) 134 | plt.subplot(2, 1, 1) 135 | plt.title(seq) 136 | plt.imshow(img) 137 | plt.subplot(2, 1, 2) 138 | plt.imshow(mask) 139 | plt.show(block=True) 140 | 141 | -------------------------------------------------------------------------------- /davis2016-evaluation/davis2017/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | from PIL import Image 5 | import warnings 6 | from davis2017.davis import DAVIS 7 | 8 | 9 | def _pascal_color_map(N=256, normalized=False): 10 | """ 11 | Python implementation of the color map function for the PASCAL VOC data set. 
12 | Official Matlab version can be found in the PASCAL VOC devkit 13 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 14 | """ 15 | 16 | def bitget(byteval, idx): 17 | return (byteval & (1 << idx)) != 0 18 | 19 | dtype = 'float32' if normalized else 'uint8' 20 | cmap = np.zeros((N, 3), dtype=dtype) 21 | for i in range(N): 22 | r = g = b = 0 23 | c = i 24 | for j in range(8): 25 | r = r | (bitget(c, 0) << 7 - j) 26 | g = g | (bitget(c, 1) << 7 - j) 27 | b = b | (bitget(c, 2) << 7 - j) 28 | c = c >> 3 29 | 30 | cmap[i] = np.array([r, g, b]) 31 | 32 | cmap = cmap / 255 if normalized else cmap 33 | return cmap 34 | 35 | 36 | def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None): 37 | im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=np.int) 38 | if im.shape[:-1] != ann.shape: 39 | raise ValueError('First two dimensions of `im` and `ann` must match') 40 | if im.shape[-1] != 3: 41 | raise ValueError('im must have three channels at the 3 dimension') 42 | 43 | colors = colors or _pascal_color_map() 44 | colors = np.asarray(colors, dtype=np.uint8) 45 | 46 | mask = colors[ann] 47 | fg = im * alpha + (1 - alpha) * mask 48 | 49 | img = im.copy() 50 | img[ann > 0] = fg[ann > 0] 51 | 52 | if contour_thickness: # pragma: no cover 53 | import cv2 54 | for obj_id in np.unique(ann[ann > 0]): 55 | contours = cv2.findContours((ann == obj_id).astype( 56 | np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 57 | cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(), 58 | contour_thickness) 59 | return img 60 | 61 | 62 | def generate_obj_proposals(davis_root, subset, num_proposals, save_path): 63 | dataset = DAVIS(davis_root, subset=subset, codalab=True) 64 | for seq in dataset.get_sequences(): 65 | save_dir = os.path.join(save_path, seq) 66 | if os.path.exists(save_dir): 67 | continue 68 | all_gt_masks, all_masks_id = dataset.get_all_masks(seq, True) 69 | img_size = all_gt_masks.shape[2:] 70 | num_rows = int(np.ceil(np.sqrt(num_proposals))) 71 | proposals = np.zeros((num_proposals, len(all_masks_id), *img_size)) 72 | height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0]/num_rows)).astype(np.uint).tolist() 73 | width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1]/num_rows)).astype(np.uint).tolist() 74 | ii = 0 75 | prev_h, prev_w = 0, 0 76 | for h in height_slices[1:]: 77 | for w in width_slices[1:]: 78 | proposals[ii, :, prev_h:h, prev_w:w] = 1 79 | prev_w = w 80 | ii += 1 81 | if ii == num_proposals: 82 | break 83 | prev_h, prev_w = h, 0 84 | if ii == num_proposals: 85 | break 86 | 87 | os.makedirs(save_dir, exist_ok=True) 88 | for i, mask_id in enumerate(all_masks_id): 89 | mask = np.sum(proposals[:, i, ...] * np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0) 90 | save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) 91 | 92 | 93 | def generate_random_permutation_gt_obj_proposals(davis_root, subset, save_path): 94 | dataset = DAVIS(davis_root, subset=subset, codalab=True) 95 | for seq in dataset.get_sequences(): 96 | gt_masks, all_masks_id = dataset.get_all_masks(seq, True) 97 | obj_swap = np.random.permutation(np.arange(gt_masks.shape[0])) 98 | gt_masks = gt_masks[obj_swap, ...] 99 | save_dir = os.path.join(save_path, seq) 100 | os.makedirs(save_dir, exist_ok=True) 101 | for i, mask_id in enumerate(all_masks_id): 102 | mask = np.sum(gt_masks[:, i, ...] 
* np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0) 103 | save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) 104 | 105 | 106 | def color_map(N=256, normalized=False): 107 | def bitget(byteval, idx): 108 | return ((byteval & (1 << idx)) != 0) 109 | 110 | dtype = 'float32' if normalized else 'uint8' 111 | cmap = np.zeros((N, 3), dtype=dtype) 112 | for i in range(N): 113 | r = g = b = 0 114 | c = i 115 | for j in range(8): 116 | r = r | (bitget(c, 0) << 7-j) 117 | g = g | (bitget(c, 1) << 7-j) 118 | b = b | (bitget(c, 2) << 7-j) 119 | c = c >> 3 120 | 121 | cmap[i] = np.array([r, g, b]) 122 | 123 | cmap = cmap/255 if normalized else cmap 124 | return cmap 125 | 126 | 127 | def save_mask(mask, img_path): 128 | if np.max(mask) > 255: 129 | raise ValueError('Maximum id pixel value is 255') 130 | mask_img = Image.fromarray(mask.astype(np.uint8)) 131 | mask_img.putpalette(color_map().flatten().tolist()) 132 | mask_img.save(img_path) 133 | 134 | 135 | def db_statistics(per_frame_values): 136 | """ Compute mean,recall and decay from per-frame evaluation. 137 | Arguments: 138 | per_frame_values (ndarray): per-frame evaluation 139 | 140 | Returns: 141 | M,O,D (float,float,float): 142 | return evaluation statistics: mean,recall,decay. 143 | """ 144 | 145 | # strip off nan values 146 | with warnings.catch_warnings(): 147 | warnings.simplefilter("ignore", category=RuntimeWarning) 148 | M = np.nanmean(per_frame_values) 149 | O = np.nanmean(per_frame_values > 0.5) 150 | 151 | N_bins = 4 152 | ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1 153 | ids = ids.astype(np.uint8) 154 | 155 | D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)] 156 | 157 | with warnings.catch_warnings(): 158 | warnings.simplefilter("ignore", category=RuntimeWarning) 159 | D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3]) 160 | 161 | return M, O, D 162 | 163 | 164 | def list_files(dir, extension=".png"): 165 | return [os.path.splitext(file_)[0] for file_ in os.listdir(dir) if file_.endswith(extension)] 166 | 167 | 168 | def force_symlink(file1, file2): 169 | try: 170 | os.symlink(file1, file2) 171 | except OSError as e: 172 | if e.errno == errno.EEXIST: 173 | os.remove(file2) 174 | os.symlink(file1, file2) 175 | -------------------------------------------------------------------------------- /davis2016-evaluation/evaluation_codalab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | import os.path 4 | from time import time 5 | 6 | import numpy as np 7 | import pandas 8 | from davis2017.evaluation import DAVISEvaluation 9 | 10 | task = 'semi-supervised' 11 | gt_set = 'test-dev' 12 | 13 | time_start = time() 14 | # as per the metadata file, input and output directories are the arguments 15 | if len(sys.argv) < 3: 16 | input_dir = "input_dir" 17 | output_dir = "output_dir" 18 | debug = True 19 | else: 20 | [_, input_dir, output_dir] = sys.argv 21 | debug = False 22 | 23 | # unzipped submission data is always in the 'res' subdirectory 24 | # https://github.com/codalab/codalab-competitions/wiki/User_Building-a-Scoring-Program-for-a-Competition#directory-structure-for-submissions 25 | submission_path = os.path.join(input_dir, 'res') 26 | if not os.path.exists(submission_path): 27 | sys.exit('Could not find submission file {0}'.format(submission_path)) 28 | 29 | # unzipped reference data is always in the 'ref' subdirectory 30 | # 
https://github.com/codalab/codalab-competitions/wiki/User_Building-a-Scoring-Program-for-a-Competition#directory-structure-for-submissions 31 | gt_path = os.path.join(input_dir, 'ref') 32 | if not os.path.exists(gt_path): 33 | sys.exit('Could not find GT file {0}'.format(gt_path)) 34 | 35 | 36 | # Create dataset 37 | dataset_eval = DAVISEvaluation(davis_root=gt_path, gt_set=gt_set, task=task, codalab=True) 38 | 39 | # Check directory structure 40 | res_subfolders = os.listdir(submission_path) 41 | if len(res_subfolders) == 1: 42 | sys.stdout.write( 43 | "Incorrect folder structure, the folders of the sequences have to be placed directly inside the " 44 | "zip.\nInside every folder of the sequences there must be an indexed PNG file for every frame.\n" 45 | "The indexes have to match with the initial frame.\n") 46 | sys.exit() 47 | 48 | # Check that all sequences are there 49 | missing = False 50 | for seq in dataset_eval.dataset.get_sequences(): 51 | if seq not in res_subfolders: 52 | sys.stdout.write(seq + " sequence is missing.\n") 53 | missing = True 54 | if missing: 55 | sys.stdout.write( 56 | "Verify also the folder structure, the folders of the sequences have to be placed directly inside " 57 | "the zip.\nInside every folder of the sequences there must be an indexed PNG file for every frame.\n" 58 | "The indexes have to match with the initial frame.\n") 59 | sys.exit() 60 | 61 | metrics_res = dataset_eval.evaluate(submission_path, debug=debug) 62 | J, F = metrics_res['J'], metrics_res['F'] 63 | 64 | # Generate output to the stdout 65 | seq_names = list(J['M_per_object'].keys()) 66 | if gt_set == "val" or gt_set == "train" or gt_set == "test-dev": 67 | sys.stdout.write("----------------Global results in CSV---------------\n") 68 | g_measures = ['J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay'] 69 | final_mean = (np.mean(J["M"]) + np.mean(F["M"])) / 2. 70 | g_res = np.array([final_mean, np.mean(J["M"]), np.mean(J["R"]), np.mean(J["D"]), np.mean(F["M"]), np.mean(F["R"]), 71 | np.mean(F["D"])]) 72 | table_g = pandas.DataFrame(data=np.reshape(g_res, [1, len(g_res)]), columns=g_measures) 73 | table_g.to_csv(sys.stdout, index=False, float_format="%0.3f") 74 | 75 | sys.stdout.write("\n\n------------Per sequence results in CSV-------------\n") 76 | seq_measures = ['Sequence', 'J-Mean', 'F-Mean'] 77 | J_per_object = [J['M_per_object'][x] for x in seq_names] 78 | F_per_object = [F['M_per_object'][x] for x in seq_names] 79 | table_seq = pandas.DataFrame(data=list(zip(seq_names, J_per_object, F_per_object)), columns=seq_measures) 80 | table_seq.to_csv(sys.stdout, index=False, float_format="%0.3f") 81 | 82 | # Write scores to a file named "scores.txt" 83 | with open(os.path.join(output_dir, 'scores.txt'), 'w') as output_file: 84 | final_mean = (np.mean(J["M"]) + np.mean(F["M"])) / 2. 
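    # scores.txt is the file the CodaLab scoring harness reads back for the leaderboard (an
    # assumption based on the standard CodaLab output-directory convention linked above); each
    # statistic is written on its own "Name: value" line, and GlobalMean is the J&F mean,
    # i.e. the average of the J (region similarity) and F (boundary) means computed above.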
85 | output_file.write("GlobalMean: %f\n" % final_mean) 86 | output_file.write("JMean: %f\n" % np.mean(J["M"])) 87 | output_file.write("JRecall: %f\n" % np.mean(J["R"])) 88 | output_file.write("JDecay: %f\n" % np.mean(J["D"])) 89 | output_file.write("FMean: %f\n" % np.mean(F["M"])) 90 | output_file.write("FRecall: %f\n" % np.mean(F["R"])) 91 | output_file.write("FDecay: %f\n" % np.mean(F["D"])) 92 | total_time = time() - time_start 93 | sys.stdout.write('\nTotal time:' + str(total_time)) 94 | -------------------------------------------------------------------------------- /davis2016-evaluation/evaluation_method.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | from time import time 5 | import argparse 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from davis2017.evaluation import DAVISEvaluation 10 | import yaml 11 | 12 | 13 | with open('../data_paths.yaml') as f: 14 | path_dict = yaml.safe_load(f) 15 | default_davis_path = path_dict['davis2017'].replace('2017', '') 16 | 17 | 18 | time_start = time() 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--davis_path', type=str, help='Path to the DAVIS folder containing the JPEGImages, Annotations, ' 21 | 'ImageSets, Annotations_unsupervised folders', 22 | required=False, default=default_davis_path) 23 | parser.add_argument('--set', type=str, help='Subset to evaluate the results', default='val') 24 | parser.add_argument('--task', type=str, help='Task to evaluate the results', default='semi-supervised', 25 | choices=['semi-supervised', 'unsupervised']) 26 | parser.add_argument('--results_path', type=str, help='Path to the folder containing the sequences folders', 27 | required=True) 28 | parser.add_argument("--year", '-y', type=str, help="Davis dataset year (default: 2016)", default='2016', 29 | choices=['2016', '2017', '2019']) 30 | parser.add_argument('--reset_mode', '-reset', default=False, action='store_true') 31 | args, _ = parser.parse_known_args() 32 | csv_name_global = f'global_results-{args.set}.csv' 33 | csv_name_per_sequence = f'per-sequence_results-{args.set}.csv' 34 | args.davis_path = f'{default_davis_path}/{args.year}' 35 | 36 | # Check if the method has been evaluated before, if so read the results, otherwise compute the results 37 | csv_name_global_path = os.path.join(args.results_path, csv_name_global) 38 | csv_name_per_sequence_path = os.path.join(args.results_path, csv_name_per_sequence) 39 | if os.path.exists(csv_name_global_path) and os.path.exists(csv_name_per_sequence_path) and not args.reset_mode: 40 | print('Using precomputed results...') 41 | table_g = pd.read_csv(csv_name_global_path) 42 | table_seq = pd.read_csv(csv_name_per_sequence_path) 43 | else: 44 | print(f'Evaluating sequences for the {args.task} task...') 45 | # Create dataset and evaluate 46 | # dataset_eval = DAVISEvaluation(davis_root=args.davis_path, task=args.task, gt_set=args.set) 47 | dataset_eval = DAVISEvaluation(davis_root=args.davis_path, task=args.task, gt_set=args.set, year=args.year) 48 | metrics_res = dataset_eval.evaluate(args.results_path) 49 | J, F = metrics_res['J'], metrics_res['F'] 50 | 51 | # Generate dataframe for the general results 52 | g_measures = ['J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay'] 53 | final_mean = (np.mean(J["M"]) + np.mean(F["M"])) / 2. 
54 | g_res = np.array([final_mean, np.mean(J["M"]), np.mean(J["R"]), np.mean(J["D"]), np.mean(F["M"]), np.mean(F["R"]), 55 | np.mean(F["D"])]) 56 | g_res = np.reshape(g_res, [1, len(g_res)]) 57 | table_g = pd.DataFrame(data=g_res, columns=g_measures) 58 | with open(csv_name_global_path, 'w') as f: 59 | # table_g.to_csv(f, index=False, float_format="%.3f") 60 | table_g.to_csv(f, index=False) 61 | print(f'Global results saved in {csv_name_global_path}') 62 | 63 | # Generate a dataframe for the per sequence results 64 | seq_names = list(J['M_per_object'].keys()) 65 | seq_measures = ['Sequence', 'J-Mean', 'F-Mean'] 66 | J_per_object = [J['M_per_object'][x] for x in seq_names] 67 | F_per_object = [F['M_per_object'][x] for x in seq_names] 68 | table_seq = pd.DataFrame(data=list(zip(seq_names, J_per_object, F_per_object)), columns=seq_measures) 69 | with open(csv_name_per_sequence_path, 'w') as f: 70 | # table_seq.to_csv(f, index=False, float_format="%.3f") 71 | table_seq.to_csv(f, index=False) 72 | print(f'Per-sequence results saved in {csv_name_per_sequence_path}') 73 | 74 | # Print the results 75 | sys.stdout.write(f"--------------------------- Global results for {args.set} ---------------------------\n") 76 | print(table_g.to_string(index=False)) 77 | sys.stdout.write(f"\n---------- Per sequence results for {args.set} ----------\n") 78 | print(table_seq.to_string(index=False)) 79 | total_time = time() - time_start 80 | sys.stdout.write('\nTotal time:' + str(total_time)) 81 | -------------------------------------------------------------------------------- /davis2016-evaluation/pytest/test_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import pandas 5 | from time import time 6 | from collections import defaultdict 7 | 8 | from davis2017.evaluation import DAVISEvaluation 9 | from davis2017 import utils 10 | from davis2017.metrics import db_eval_boundary, db_eval_iou 11 | 12 | 13 | davis_root = 'input_dir/ref' 14 | methods_root = 'examples' 15 | 16 | 17 | def test_task(task, gt_set, res_path, J_target=None, F_target=None, metric=('J', 'F')): 18 | dataset_eval = DAVISEvaluation(davis_root=davis_root, gt_set=gt_set, task=task, codalab=True) 19 | metrics_res = dataset_eval.evaluate(res_path, debug=False, metric=metric) 20 | 21 | num_seq = len(list(dataset_eval.dataset.get_sequences())) 22 | J = metrics_res['J'] if 'J' in metric else {'M': np.zeros(num_seq), 'R': np.zeros(num_seq), 'D': np.zeros(num_seq)} 23 | F = metrics_res['F'] if 'F' in metric else {'M': np.zeros(num_seq), 'R': np.zeros(num_seq), 'D': np.zeros(num_seq)} 24 | 25 | if gt_set == "val" or gt_set == "train" or gt_set == "test-dev": 26 | sys.stdout.write("----------------Global results in CSV---------------\n") 27 | g_measures = ['J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay'] 28 | final_mean = (np.mean(J["M"]) + np.mean(F["M"])) / 2. 
if 'J' in metric and 'F' in metric else 0 29 | g_res = np.array([final_mean, np.mean(J["M"]), np.mean(J["R"]), np.mean(J["D"]), np.mean(F["M"]), np.mean(F["R"]), np.mean(F["D"])]) 30 | table_g = pandas.DataFrame(data=np.reshape(g_res, [1, len(g_res)]), columns=g_measures) 31 | table_g.to_csv(sys.stdout, index=False, float_format="%0.3f") 32 | if J_target is not None: 33 | assert check_results_similarity(J, J_target), f'J {print_error(J, J_target)}' 34 | if F_target is not None: 35 | assert check_results_similarity(F, F_target), f'F {print_error(F, F_target)}' 36 | return J, F 37 | 38 | 39 | def check_results_similarity(target, result): 40 | return np.isclose(np.mean(target['M']) - result[0], 0, atol=0.001) & \ 41 | np.isclose(np.mean(target['R']) - result[1], 0, atol=0.001) & \ 42 | np.isclose(np.mean(target['D']) - result[2], 0, atol=0.001) 43 | 44 | 45 | def print_error(target, result): 46 | return f'M:{np.mean(target["M"])} = {result[0]}\t' + \ 47 | f'R:{np.mean(target["R"])} = {result[1]}\t' + \ 48 | f'D:{np.mean(target["D"])} = {result[2]}' 49 | 50 | 51 | def test_semisupervised_premvos(): 52 | method_path = os.path.join(methods_root, 'premvos') 53 | print('Evaluating PREMVOS val') 54 | J_val = [0.739, 0.831, 0.162] 55 | F_val = [0.818, 0.889, 0.195] 56 | test_task('semi-supervised', 'val', method_path, J_val, F_val) 57 | print('Evaluating PREMVOS test-dev') 58 | J_test_dev = [0.675, 0.768, 0.217] 59 | F_test_dev = [0.758, 0.843, 0.206] 60 | test_task('semi-supervised', 'test-dev', method_path, J_test_dev, F_test_dev) 61 | print('\n') 62 | 63 | 64 | def test_semisupervised_onavos(): 65 | method_path = os.path.join(methods_root, 'onavos') 66 | print('Evaluating OnAVOS val') 67 | J_val = [0.616, 0.674, 0.279] 68 | F_val = [0.691, 0.754, 0.266] 69 | test_task('semi-supervised', 'val', method_path, J_val, F_val) 70 | print('Evaluating OnAVOS test-dev') 71 | J_test_dev = [0.499, 0.543, 0.230] 72 | F_test_dev = [0.557, 0.603, 0.234] 73 | test_task('semi-supervised', 'test-dev', method_path, J_test_dev, F_test_dev) 74 | print('\n') 75 | 76 | 77 | def test_semisupervised_osvos(): 78 | method_path = os.path.join(methods_root, 'osvos') 79 | print('Evaluating OSVOS val') 80 | J_val = [0.566, 0.638, 0.261] 81 | F_val = [0.639, 0.738, 0.270] 82 | test_task('semi-supervised', 'val', method_path, J_val, F_val) 83 | print('Evaluating OSVOS test-dev') 84 | J_test_dev = [0.470, 0.521, 0.192] 85 | F_test_dev = [0.548, 0.597, 0.198] 86 | test_task('semi-supervised', 'test-dev', method_path, J_test_dev, F_test_dev) 87 | print('\n') 88 | 89 | 90 | def test_unsupervised_flip_gt(): 91 | print('Evaluating Unsupervised Permute GT') 92 | method_path = os.path.join(methods_root, 'swap_gt') 93 | if not os.path.isdir(method_path): 94 | utils.generate_random_permutation_gt_obj_proposals(davis_root, 'val', method_path) 95 | # utils.generate_random_permutation_gt_obj_proposals('test-dev', method_path) 96 | J_val = [1, 1, 0] 97 | F_val= [1, 1, 0] 98 | test_task('unsupervised', 'val', method_path, J_val, F_val) 99 | # test_task('unsupervised', 'test-dev', method_path, J_val, F_val) 100 | 101 | 102 | def test_unsupervised_rvos(): 103 | print('Evaluating RVOS') 104 | method_path = os.path.join(methods_root, 'rvos') 105 | test_task('unsupervised', 'val', method_path) 106 | # test_task('unsupervised', 'test-dev', method_path) 107 | 108 | 109 | def test_unsupervsied_multiple_proposals(num_proposals=20, metric=('J', 'F')): 110 | print('Evaluating Multiple Proposals') 111 | method_path = os.path.join(methods_root, 
f'generated_proposals_{num_proposals}') 112 | utils.generate_obj_proposals(davis_root, 'val', num_proposals, method_path) 113 | # utils.generate_obj_proposals('test-dev', num_proposals, method_path) 114 | test_task('unsupervised', 'val', method_path, metric=metric) 115 | # test_task('unsupervised', 'test-dev', method_path, metric=metric) 116 | 117 | 118 | def test_void_masks(): 119 | gt = np.zeros((2, 200, 200)) 120 | mask = np.zeros((2, 200, 200)) 121 | void = np.zeros((2, 200, 200)) 122 | 123 | gt[:, 100:150, 100:150] = 1 124 | void[:, 50:100, 100:150] = 1 125 | mask[:, 50:150, 100:150] = 1 126 | 127 | assert np.mean(db_eval_iou(gt, mask, void)) == 1 128 | assert np.mean(db_eval_boundary(gt, mask, void)) == 1 129 | 130 | 131 | def benchmark_number_proposals(): 132 | number_proposals = [10, 15, 20, 30] 133 | timing_results = defaultdict(dict) 134 | for n in number_proposals: 135 | time_start = time() 136 | test_unsupervsied_multiple_proposals(n, 'J') 137 | timing_results['J'][n] = time() - time_start 138 | 139 | for n in number_proposals: 140 | time_start = time() 141 | test_unsupervsied_multiple_proposals(n) 142 | timing_results['J_F'][n] = time() - time_start 143 | 144 | print(f'Using J {timing_results["J"]}') 145 | print(f'Using J&F {timing_results["J_F"]}') 146 | 147 | # Using J {10: 156.45335865020752, 15: 217.91797709465027, 20: 282.0747673511505, 30: 427.6770250797272} 148 | # Using J & F {10: 574.3529748916626, 15: 849.7542386054993, 20: 1123.4619634151459, 30: 1663.6704666614532} 149 | # Codalab 150 | # Using J & F {10: 971.196366071701, 15: 1473.9757001399994, 20: 1918.787559747696, 30: 3007.116141319275} 151 | 152 | 153 | if __name__ == '__main__': 154 | # Test void masks 155 | test_void_masks() 156 | 157 | # Test semi-supervised methods 158 | test_semisupervised_premvos() 159 | test_semisupervised_onavos() 160 | test_semisupervised_osvos() 161 | 162 | # Test unsupervised methods 163 | test_unsupervised_flip_gt() 164 | # test_unsupervised_rvos() 165 | test_unsupervsied_multiple_proposals() 166 | -------------------------------------------------------------------------------- /davis2016-evaluation/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = davis2017 3 | version = attr: davis2017.__version__ 4 | description = Evaluation Framework for DAVIS 2017 Semi-supervised and Unsupervised used in the DAVIS Challenges 5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | keywords = segmentation 8 | license = GPL v3 9 | author = Sergi Caelles 10 | author-email = scaelles@vision.ee.ethz.ch 11 | home-page = https://github.com/davisvideochallenge/davis2017-evaluation 12 | classifiers = 13 | Development Status :: 4 - Beta 14 | Intended Audience :: Developers 15 | Intended Audience :: Education 16 | Intended Audience :: Science/Research 17 | License :: OSI Approved :: GNU General Public License v3 (GPLv3) 18 | Programming Language :: Python :: 3.6 19 | Programming Language :: Python :: 3.7 20 | Topic :: Scientific/Engineering :: Human Machine Interfaces 21 | Topic :: Software Development :: Libraries 22 | Topic :: Software Development :: Libraries :: Python Modules 23 | -------------------------------------------------------------------------------- /davis2016-evaluation/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | import sys 3 | 4 | if sys.version_info < (3, 6): 5 | sys.exit('Sorry, only Python >= 3.6 is 
supported') 6 | 7 | setup( 8 | python_requires='>=3.6, <4', 9 | install_requires=[ 10 | 'Pillow>=4.1.1', 11 | 'networkx>=2.0', 12 | 'numpy>=1.12.1', 13 | 'opencv-python>=4.0.0.21', 14 | 'pandas>=0.21.1', 15 | 'pathlib2;python_version<"3.5"', 16 | 'scikit-image>=0.13.1', 17 | 'scikit-learn>=0.18', 18 | 'scipy>=1.0.0', 19 | 'tqdm>=4.28.1' 20 | ]) 21 | -------------------------------------------------------------------------------- /davis2016-evaluation/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | import glob 5 | import yaml 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--test_dir', '-td', type=str, default='TEST') 10 | parser.add_argument('--dataset', type=str, default='davis2017') 11 | parser.add_argument('--exp_name', type=str, default=None) 12 | parser.add_argument('--exp_subname', '-sname', type=str, default=None) 13 | parser.add_argument('--result_dir', '-rdir', type=str, default=None) 14 | parser.add_argument('--shot', type=int, default=None) 15 | parser.add_argument('--test_anyway', '-ta', default=False, action='store_true') 16 | parser.add_argument('--reset_mode', '-reset', default=False, action='store_true') 17 | args = parser.parse_args() 18 | 19 | 20 | with open('data_paths.yaml') as f: 21 | path_dict = yaml.safe_load(f) 22 | 23 | year = args.dataset.strip('davis') 24 | args.gt_dir = path_dict[f'davis{year}'].replace(year, '') 25 | if year == '2016': 26 | n_classes = 20 27 | subdir = '' 28 | elif year == '2017': 29 | n_classes = 30 30 | subdir = 'trainval' 31 | 32 | 33 | def check_dir(path): 34 | if args.test_anyway or args.reset_mode: 35 | return True 36 | 37 | if len(os.listdir(path)) != n_classes: 38 | return False 39 | 40 | for class_name in os.listdir(path): 41 | n_src = len(os.listdir(os.path.join(path, class_name))) 42 | n_tgt = len(os.listdir(os.path.join(args.gt_dir, year, subdir, 'JPEGImages', '480p', class_name))) 43 | if n_src != n_tgt: 44 | return False 45 | 46 | return True 47 | 48 | 49 | def check_result_dir(result_dir): 50 | if args.shot is not None: 51 | target_name = f'{args.dataset}_vos_results_shot:{args.shot}' 52 | if result_dir == target_name: 53 | return True 54 | else: 55 | return False 56 | else: 57 | target_name = f'{args.dataset}_vos_results_shot:' 58 | if result_dir.startswith(target_name): 59 | return True 60 | else: 61 | return False 62 | 63 | 64 | if args.result_dir is None: 65 | if args.exp_name is None: 66 | exp_names = os.listdir(os.path.join('experiments', args.test_dir)) 67 | exp_names = [exp_name for exp_name in exp_names 68 | if len(glob.glob(os.path.join('experiments', args.test_dir, exp_name, '*', 'logs', f'{args.dataset}*'))) > 0] 69 | else: 70 | exp_names = [args.exp_name] 71 | 72 | for exp_name in exp_names: 73 | exp_dir = os.path.join('experiments', args.test_dir, exp_name) 74 | if args.exp_subname is not None: 75 | exp_subnames = [args.exp_subname] 76 | else: 77 | exp_subnames_all = os.listdir(exp_dir) 78 | exp_subnames = [] 79 | for exp_subname in exp_subnames_all: 80 | for result_dir in os.listdir(os.path.join(exp_dir, exp_subname, 'logs')): 81 | if check_result_dir(result_dir): 82 | if check_dir(os.path.join(exp_dir, exp_subname, 'logs', result_dir)): 83 | exp_subnames.append((exp_subname, result_dir)) 84 | 85 | os.chdir('davis2016-evaluation') 86 | for exp_subname, result_dir in exp_subnames: 87 | print(f'Processing {exp_subname}') 88 | command = f"python evaluation_method.py --results_path 
{os.path.join('..', exp_dir, exp_subname, 'logs', result_dir)} -y {year}" 89 | if args.reset_mode: 90 | command += ' -reset' 91 | print(command) 92 | subprocess.call(command.split()) 93 | os.chdir('..') 94 | else: 95 | os.chdir('davis2016-evaluation') 96 | command = f"python evaluation_method.py --results_path {os.path.join('..', args.result_dir)} -y {year}" 97 | if args.reset_mode: 98 | command += ' -reset' 99 | print(command) 100 | subprocess.call(command.split()) 101 | os.chdir('..') 102 | -------------------------------------------------------------------------------- /downstream/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/downstream/__init__.py -------------------------------------------------------------------------------- /downstream/ap10k/configs/test_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: ap10k 9 | task: animalkp 10 | class_wise: True 11 | class_name: none 12 | shot: 20 13 | img_size: 224 14 | base_size: 256 15 | vis_size: 256 16 | support_idx: 0 17 | coord_path: none 18 | autocrop: False 19 | autocrop_minoverlap: 0.5 20 | 21 | # dataloader arguments 22 | num_workers: 1 23 | eval_shot: 20 24 | eval_batch_size: 8 25 | eval_size: -1 26 | support_idx: 0 27 | chunk_size: 1 28 | channel_chunk_size: 1 29 | 30 | # logging arguments 31 | log_dir: TEST_AP10K 32 | save_dir: FINETUNE_AP10K 33 | load_dir: TRAIN 34 | load_step: 400000 35 | load_path: none 36 | result_dir: none -------------------------------------------------------------------------------- /downstream/ap10k/configs/train_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: ap10k 9 | task: animalkp 10 | class_wise: True 11 | class_name: none 12 | shot: 20 13 | img_size: 224 14 | base_size: 256 15 | vis_size: 256 16 | support_idx: 0 17 | randomflip: True 18 | randomjitter: True 19 | randomrotate: True 20 | randomblur: False 21 | coord_path: none 22 | autocrop: False 23 | autocrop_minoverlap: 0.5 24 | 25 | # dataloader arguments 26 | num_workers: 1 27 | global_batch_size: 8 28 | eval_shot: 20 29 | eval_batch_size: 10 30 | eval_size: 100 31 | chunk_size: 1 32 | channel_chunk_size: 1 33 | channel_sampling: 12 34 | 35 | # model arguments 36 | attn_dropout: 0.5 37 | n_input_images: 1 38 | separate_alpha: False 39 | 40 | # training arguments 41 | n_steps: 10000 42 | n_schedule_steps: -1 43 | optimizer: adam 44 | loss_type: ssl 45 | mask_value: -1. 46 | early_stopping_patience: 5 47 | monitor: AP_inverted 48 | 49 | # learning rate arguments 50 | lr: 0.001 51 | lr_pretrained: 0.0002 52 | lr_schedule: constant 53 | lr_warmup: 0 54 | lr_warmup_scale: 0. 55 | schedule_from: 0 56 | weight_decay: 0. 
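# lr_decay_degree is presumably the exponent of a polynomial learning-rate decay and only
# takes effect when lr_schedule is set to a decaying schedule; with the constant schedule
# configured above it should be inert.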
57 | lr_decay_degree: 0.9 58 | 59 | # parameter arguments 60 | from_scratch: False 61 | head_tuning: False 62 | label_decoder_tuning: False 63 | input_embed_tuning: False 64 | output_embed_tuning: False 65 | relpos_tuning: False 66 | 67 | # logging arguments 68 | log_dir: FINETUNE_AP10K 69 | save_dir: FINETUNE_AP10K 70 | load_dir: TRAIN 71 | val_iter: 100 72 | load_step: 400000 73 | -------------------------------------------------------------------------------- /downstream/ap10k/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import json 4 | from dataset.coco_api_wrapper import SilentXTCOCO as COCO 5 | from dataset.coco_api_wrapper import SilentXTCOCOeval as COCOeval 6 | from dataset.utils import oks_nms 7 | 8 | 9 | class AP10KEvaluator: 10 | def __init__(self, kp_json_path, ann_path, base_size=(256, 256), local_rank=0, n_devices=1): 11 | self.sigmas = [ 12 | 0.025, 0.025, 0.026, 0.035, 0.035, 0.079, 0.072, 0.062, 0.079, 0.072, 13 | 0.062, 0.107, 0.087, 0.089, 0.107, 0.087, 0.089 14 | ] 15 | self.base_size = base_size 16 | self.kp_json_path = kp_json_path 17 | self.ann_path = ann_path 18 | self.all_ret = {} 19 | self.local_rank = local_rank 20 | self.n_devices = n_devices 21 | 22 | def reset(self): 23 | self.all_ret = {} 24 | 25 | def ret_to_coco_cropped(self, gt_coco): 26 | ''' 27 | self.all_ret : dict {imgId: {annId: (17 x 3 ndarray, score)}} 28 | ''' 29 | pred_dict = {imgId: {} for imgId in self.all_ret} 30 | for imgId in self.all_ret: 31 | for annId in self.all_ret[imgId]: 32 | if annId not in pred_dict[imgId]: 33 | pred_dict[imgId][annId] = [] 34 | 35 | detection, scores = self.all_ret[imgId][annId] 36 | cat = gt_coco.anns[annId]['category_id'] 37 | 38 | for d in range(detection.shape[1]): 39 | one_det = detection[:, d, :] # 17 x 3 40 | 41 | W, H = gt_coco.imgs[imgId]['width'], gt_coco.imgs[imgId]['height'] 42 | x1, y1, w1, h1 = gt_coco.anns[annId]['bbox'] 43 | x2 = x1 + w1 44 | y2 = y1 + h1 45 | 46 | x = round(max(0, (x1 + x2)/2 - w1/2*1.2)) 47 | y = round(max(0, (y1 + y2)/2 - h1/2*1.2)) 48 | w = round(min(W - max(0, (x1 + x2) / 2 - w1/2*1.2), w1*1.2)) 49 | h = round(min(H - max(0, (y1 + y2) / 2 - h1/2*1.2), h1*1.2)) 50 | 51 | # resize and shift 52 | one_det[:, 0] = np.round(one_det[:, 0] * w / self.base_size[1]) + x 53 | one_det[:, 1] = np.round(one_det[:, 1] * h / self.base_size[0]) + y 54 | 55 | one_det = one_det.astype(int) 56 | res = { 57 | 'id': annId, 58 | 'image_id': imgId, 59 | 'category_id': cat, 60 | 'keypoints': one_det.reshape(-1).tolist(), 61 | 'score': scores[d].item() 62 | } 63 | pred_dict[imgId][annId].append(res) 64 | 65 | oks_thr = 0.9 66 | sigmas = np.array([ 67 | 0.025, 0.025, 0.026, 0.035, 0.035, 0.079, 0.072, 0.062, 0.079, 0.072, 68 | 0.062, 0.107, 0.087, 0.089, 0.107, 0.087, 0.089 69 | ]) 70 | valid_kpts = [] 71 | for image_id in pred_dict.keys(): 72 | for ann_id in pred_dict[image_id]: 73 | img_kpts = pred_dict[image_id][ann_id] 74 | for n_p in img_kpts: 75 | box_score = n_p['score'] 76 | n_p['keypoints'] = np.array(n_p['keypoints']).reshape(-1, 3) 77 | kpt_score = 0 78 | valid_num = 0 79 | x_min = np.min(n_p['keypoints'][:, 0]) 80 | x_max = np.max(n_p['keypoints'][:, 0]) 81 | y_min = np.min(n_p['keypoints'][:, 1]) 82 | y_max = np.max(n_p['keypoints'][:, 1]) 83 | area = (x_max - x_min) * (y_max - y_min) 84 | n_p['area'] = int(area) 85 | valid_num = len(n_p['keypoints']) # assume all visible 86 | kpt_score = kpt_score / valid_num 87 | # rescoring 88 | n_p['score'] = 
float(kpt_score * box_score) 89 | keep = oks_nms(list(img_kpts), thr=oks_thr, sigmas=sigmas) 90 | valid_kpts.append([img_kpts[_keep] for _keep in keep]) 91 | 92 | ret = [] 93 | for each in valid_kpts: 94 | for det in each: 95 | det['keypoints'] = det['keypoints'].reshape(-1).astype(int).tolist() 96 | ret.append(det) 97 | 98 | return ret 99 | 100 | def evaluate_keypoints(self): 101 | gt_coco = COCO(self.ann_path) 102 | ret_coco = self.ret_to_coco_cropped(gt_coco) 103 | 104 | # save coco result 105 | with open(self.kp_json_path.replace('.json', f'_{self.local_rank}.json'), 'w') as f: 106 | json.dump(ret_coco, f) 107 | 108 | # synchronize at this point 109 | if self.n_devices > 1: 110 | torch.distributed.barrier() 111 | 112 | if self.local_rank == 0: 113 | for local_rank in range(1, self.n_devices): 114 | with open(self.kp_json_path.replace('.json', f'_{local_rank}.json'), 'r') as f: 115 | ret_coco += json.load(f) 116 | with open(self.kp_json_path, 'w') as f: 117 | json.dump(ret_coco, f) 118 | 119 | imgIds = [ret['image_id'] for ret in ret_coco] 120 | gt_coco.imgs = {k: v for k, v in gt_coco.imgs.items() if k in imgIds} 121 | annIds = [ret['id'] for ret in ret_coco] 122 | gt_coco.anns = {k: v for k, v in gt_coco.anns.items() if k in annIds} 123 | 124 | # evaluate 125 | coco_det = gt_coco.loadRes(self.kp_json_path) 126 | sigmas = np.array(self.sigmas) 127 | coco_eval = COCOeval(gt_coco, coco_det, 'keypoints', sigmas=sigmas) 128 | coco_eval.params.useSegm = None 129 | coco_eval.evaluate() 130 | coco_eval.accumulate() 131 | coco_eval.summarize() 132 | else: 133 | coco_eval = None 134 | 135 | return coco_eval -------------------------------------------------------------------------------- /downstream/ap10k/learner.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange, reduce 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import os 6 | 7 | from ..base_learner import BaseLearner 8 | from .dataset import AP10KDataset 9 | from .utils import dense_to_sparse, vis_animal_keypoints, get_modes, modes_to_array 10 | from .evaluator import AP10KEvaluator 11 | from train.loss import spatial_softmax_loss 12 | 13 | 14 | class AP10KLearner(BaseLearner): 15 | BaseDataset = AP10KDataset 16 | 17 | def register_evaluator(self): 18 | self.kp_classes = AP10KDataset.CLASS_NAMES 19 | self.kp_json_path = {} 20 | self.ann_path = {} 21 | 22 | if self.config.stage == 1: 23 | keys = ['mtest_train', 'mtest_valid'] 24 | else: 25 | keys = ['mtest_test'] 26 | self.evaluator = {key: None for key in keys} 27 | 28 | for key in keys: 29 | split = key.split('_')[1] 30 | kp_json_path = os.path.join(self.result_dir, f'{key}_temp.json') 31 | ann_path = os.path.join(self.config.path_dict[self.config.dataset], 'annotations', 32 | f'ap10k-{split.replace("valid", "val")}-split1.json') 33 | self.evaluator[key] = AP10KEvaluator(kp_json_path, ann_path, self.config.base_size, self.local_rank, self.n_devices) 34 | 35 | def reset_evaluator(self): 36 | for key in self.evaluator.keys(): 37 | self.evaluator[key].reset() 38 | 39 | def compute_loss(self, Y_pred, Y, M): 40 | ''' 41 | loss function that returns loss and a dictionary of its components 42 | ''' 43 | loss = spatial_softmax_loss(Y_pred, Y, M, reduction='mean', scaled=self.config.scale_ssl) 44 | loss_values = {'loss': loss.detach(), 'loss_ssl': loss.detach()} 45 | 46 | return loss, loss_values 47 | 48 | def postprocess_logits(self, Y_pred_out): 49 | ''' 50 | post-processing function for logits 51 
| ''' 52 | # spatial softmax 53 | H, W = Y_pred_out.shape[-2:] 54 | Y_pred_out = rearrange(Y_pred_out, '1 T N C H W -> 1 T N C (H W)') 55 | Y_pred_out = F.softmax(Y_pred_out, dim=-1) 56 | Y_pred_out = rearrange(Y_pred_out, '1 T N C (H W) -> 1 T N C H W', H=H, W=W) 57 | Y_pred_out = Y_pred_out / (1e-18 + reduce(Y_pred_out, '1 T N C H W -> 1 T N C 1 1', 'max')) 58 | return Y_pred_out 59 | 60 | def postprocess_vis(self, label, img=None, aux=None): 61 | ''' 62 | post-processing function for visualization 63 | ''' 64 | if label.ndim == 3: 65 | sparse = True 66 | else: 67 | sparse = False 68 | sparse_gt, _ = aux 69 | 70 | label_vis = [] 71 | for i in range(len(label)): 72 | img_ = np.ascontiguousarray((img[i]*128).byte().permute(1, 2, 0).numpy()) 73 | if sparse: 74 | kps = label[i].transpose(0, 1).numpy() 75 | else: 76 | kps = dense_to_sparse(label[i]).transpose(0, 1).numpy() 77 | if sparse_gt is not None: 78 | kps[2] = sparse_gt[i, :, 2].float().cpu().numpy() 79 | vis = vis_animal_keypoints(img_, kps) 80 | vis = torch.from_numpy(vis).permute(2, 0, 1) / 255 81 | label_vis.append(vis) 82 | label = torch.stack(label_vis) 83 | 84 | return label 85 | 86 | def compute_metric(self, Y, Y_pred, M, aux, evaluator_key): 87 | ''' 88 | compute evaluation metric 89 | ''' 90 | metric = 0 91 | assert aux is not None 92 | assert evaluator_key is not None 93 | evaluator = self.evaluator[evaluator_key] 94 | 95 | _, (imgIds, annIds) = aux 96 | for i in range(len(Y_pred)): 97 | modes, scores = get_modes(Y_pred[i].cpu(), return_scores=True) 98 | arr, score = modes_to_array(modes, scores, max_detection=1) 99 | if imgIds[i].item() not in evaluator.all_ret: 100 | evaluator.all_ret[imgIds[i].item()] = {} 101 | evaluator.all_ret[imgIds[i].item()][annIds[i].item()] = (arr, score) 102 | return metric 103 | 104 | def log_metrics(self, loss_pred, log_dict, valid_tag): 105 | ''' 106 | log evaluation metrics 107 | ''' 108 | coco_eval = self.evaluator[valid_tag].evaluate_keypoints() 109 | ap = coco_eval.stats[0] if coco_eval is not None else 0 110 | if self.n_devices > 1: 111 | ap = self.trainer.all_gather(torch.tensor(ap, device=self.trainer.device))[0] 112 | log_dict[f'{valid_tag}/{self.vis_tag}_AP_inverted'] = 1 - ap 113 | 114 | def get_test_metrics(self, metrics_total): 115 | ''' 116 | save test metrics 117 | ''' 118 | coco_eval = self.evaluator['mtest_test'].evaluate_keypoints() 119 | stats_names = [ 120 | 'AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 121 | 'AR .75', 'AR (M)', 'AR (L)' 122 | ] 123 | metrics = list(zip(stats_names, coco_eval.stats)) 124 | return metrics 125 | 126 | -------------------------------------------------------------------------------- /downstream/ap10k/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from dataset.utils import Keypoints, _create_flip_indices, dense_to_sparse, get_modes, modes_to_array 6 | 7 | 8 | class AnimalKeypoints(Keypoints): 9 | NAMES = [ 10 | 'left_eye', 11 | 'right_eye', 12 | 'nose', 13 | 'neck', 14 | 'root_of_tail', 15 | 'left_shoulder', 16 | 'left_elbow', 17 | 'left_front_paw', 18 | 'right_shoulder', 19 | 'right_elbow', 20 | 'right_front_paw', 21 | 'left_hip', 22 | 'left_knee', 23 | 'left_back_paw', 24 | 'right_hip', 25 | 'right_knee', 26 | 'right_back_paw', 27 | ] 28 | FLIP_MAP = { 29 | 'left_eye': 'right_eye', 30 | 'left_front_paw': 'right_front_paw', 31 | 'left_shoulder': 'right_shoulder', 32 | 'left_elbow': 'right_elbow', 33 | 
'left_hip': 'right_hip', 34 | 'left_knee': 'right_knee', 35 | 'left_back_paw': 'right_back_paw', 36 | } 37 | 38 | AnimalKeypoints.FLIP_INDS = _create_flip_indices(AnimalKeypoints.NAMES, AnimalKeypoints.FLIP_MAP) 39 | 40 | 41 | def kp_connections_animal(keypoints): 42 | kp_lines = [ 43 | [keypoints.index('left_eye'), keypoints.index('right_eye')], 44 | [keypoints.index('left_eye'), keypoints.index('nose')], 45 | [keypoints.index('right_eye'), keypoints.index('nose')], 46 | [keypoints.index('nose'), keypoints.index('neck')], 47 | [keypoints.index('neck'), keypoints.index('left_shoulder')], 48 | [keypoints.index('left_shoulder'), keypoints.index('left_elbow')], 49 | [keypoints.index('left_elbow'), keypoints.index('left_front_paw')], 50 | [keypoints.index('neck'), keypoints.index('right_shoulder')], 51 | [keypoints.index('right_shoulder'), keypoints.index('right_elbow')], 52 | [keypoints.index('right_elbow'), keypoints.index('right_front_paw')], 53 | [keypoints.index('neck'), keypoints.index('root_of_tail')], 54 | [keypoints.index('root_of_tail'), keypoints.index('left_hip')], 55 | [keypoints.index('left_hip'), keypoints.index('left_knee')], 56 | [keypoints.index('left_knee'), keypoints.index('left_back_paw')], 57 | [keypoints.index('root_of_tail'), keypoints.index('right_hip')], 58 | [keypoints.index('right_hip'), keypoints.index('right_knee')], 59 | [keypoints.index('right_knee'), keypoints.index('right_back_paw')], 60 | ] 61 | return kp_lines 62 | 63 | 64 | AnimalKeypoints.CONNECTIONS = kp_connections_animal(AnimalKeypoints.NAMES) 65 | 66 | 67 | def vis_animal_keypoints(img, kps, kp_thresh=0.5, alpha=0.7, lth=1, crad=2): 68 | """Visualizes keypoints (adapted from vis_one_image). 69 | kps has shape (3, #keypoints) where 4 rows are (x, y, prob). 70 | """ 71 | dataset_keypoints = AnimalKeypoints.NAMES 72 | kp_lines = AnimalKeypoints.CONNECTIONS 73 | 74 | # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv. 75 | cmap = plt.get_cmap('rainbow') 76 | colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)] 77 | colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors] 78 | 79 | # Perform the drawing on a copy of the image, to allow for blending. 80 | kp_mask = np.copy(img) 81 | kps = kps.astype(np.int64) 82 | 83 | # Draw the keypoints. 84 | for l in range(len(kp_lines)): 85 | i1 = kp_lines[l][0] 86 | i2 = kp_lines[l][1] 87 | p1 = kps[0, i1], kps[1, i1] 88 | p2 = kps[0, i2], kps[1, i2] 89 | if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh: 90 | cv2.line( 91 | kp_mask, p1, p2, 92 | color=colors[l], thickness=lth, lineType=cv2.LINE_AA) 93 | if kps[2, i1] > kp_thresh: 94 | cv2.circle( 95 | kp_mask, p1, 96 | radius=crad, color=colors[l], thickness=-1, lineType=cv2.LINE_AA) 97 | if kps[2, i2] > kp_thresh: 98 | cv2.circle( 99 | kp_mask, p2, 100 | radius=crad, color=colors[l], thickness=-1, lineType=cv2.LINE_AA) 101 | 102 | # Blend the keypoints. 
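    # kp_mask equals the input image except where limbs and joints were drawn, so the convex
    # blend below leaves undrawn pixels unchanged and mixes drawn pixels with the underlying
    # image at weight alpha (0.7 by default), keeping the photo visible under the skeleton.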
103 | vis = (1 - alpha) * img + kp_mask * alpha 104 | return vis -------------------------------------------------------------------------------- /downstream/cellpose/configs/test_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: cellpose 9 | task: cellpose 10 | class_wise: False 11 | class_name: none 12 | shot: 50 13 | img_size: 224 14 | base_size: 256 15 | vis_size: 256 16 | support_idx: 0 17 | coord_path: none 18 | autocrop: True 19 | autocrop_minoverlap: 0.5 20 | autocrop_rescale: 1 21 | 22 | # dataloader arguments 23 | num_workers: 1 24 | eval_shot: 50 25 | eval_batch_size: 1 26 | eval_size: -1 27 | support_idx: 0 28 | chunk_size: 1 29 | channel_chunk_size: 1 30 | 31 | # logging arguments 32 | log_dir: TEST_CELLPOSE 33 | save_dir: FINETUNE_CELLPOSE 34 | load_dir: TRAIN 35 | load_step: 400000 36 | load_path: none 37 | result_dir: none -------------------------------------------------------------------------------- /downstream/cellpose/configs/train_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: cellpose 9 | task: cellpose 10 | class_wise: False 11 | class_name: none 12 | shot: 50 13 | img_size: 224 14 | base_size: 256 15 | vis_size: 256 16 | support_idx: 0 17 | randomflip: True 18 | randomjitter: False 19 | randomrotate: False 20 | randomblur: False 21 | coord_path: none 22 | autocrop: True 23 | autocrop_minoverlap: 0.5 24 | autocrop_rescale: 1 25 | 26 | # dataloader arguments 27 | num_workers: 1 28 | global_batch_size: 8 29 | eval_shot: 50 30 | eval_batch_size: 10 31 | eval_size: -1 32 | chunk_size: -1 33 | channel_chunk_size: 1 34 | channel_sampling: -1 35 | 36 | # model arguments 37 | attn_dropout: 0.5 38 | n_input_images: 2 39 | separate_alpha: False 40 | monitor: AP50_inverted 41 | 42 | # training arguments 43 | n_steps: 8000 44 | n_schedule_steps: -1 45 | optimizer: adam 46 | loss_type: cellpose 47 | mask_value: -1. 48 | early_stopping_patience: -1 49 | 50 | # learning rate arguments 51 | lr: 0.003 52 | lr_pretrained: 0.0002 53 | lr_schedule: constant 54 | lr_warmup: 0 55 | lr_warmup_scale: 0. 56 | schedule_from: 0 57 | weight_decay: 0. 
58 | lr_decay_degree: 0.9 59 | 60 | # parameter arguments 61 | from_scratch: False 62 | head_tuning: False 63 | input_embed_tuning: False 64 | output_embed_tuning: False 65 | relpos_tuning: True 66 | label_decoder_tuning: True 67 | 68 | # logging arguments 69 | log_dir: FINETUNE_CELLPOSE 70 | save_dir: FINETUNE_CELLPOSE 71 | load_dir: TRAIN 72 | val_iter: 1000 73 | load_step: 400000 -------------------------------------------------------------------------------- /downstream/cellpose/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | from einops import repeat 5 | import random 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset 10 | from torchvision.transforms import ToTensor, Resize, ColorJitter, RandomRotation 11 | 12 | from dataset.utils import crop_arrays 13 | 14 | 15 | class CELLPOSEDataset(Dataset): 16 | ''' 17 | cellpose dataset 18 | ''' 19 | def __init__(self, config, split, base_size, crop_size, eval_mode=False, resize=False, dset_size=-1): 20 | super().__init__() 21 | self.base_size = base_size 22 | data_root = config.path_dict[config.dataset] 23 | self.eval_mode = eval_mode 24 | if split == 'valid': 25 | split = 'test' 26 | self.data_dir = os.path.join(data_root, split) 27 | 28 | self.img_size = crop_size 29 | self.eval_mode = eval_mode 30 | self.resize = resize 31 | self.shot = config.shot 32 | self.support_idx = config.support_idx 33 | self.precision = config.precision 34 | 35 | if split == 'train': 36 | idxs_path = os.path.join('dataset', 'meta_info', 'cellpose', 'train_idxs_perm.pth') 37 | if not os.path.exists(idxs_path): 38 | idxs = torch.randperm(540) 39 | torch.save(idxs, idxs_path) 40 | else: 41 | idxs = torch.load(idxs_path) 42 | self.data_idxs = idxs[self.support_idx*self.shot:(self.support_idx + 1)*self.shot] 43 | else: 44 | self.data_idxs = [i for i in range(0, 68)] 45 | 46 | self.toten = ToTensor() 47 | self.base_resizer = Resize(base_size) 48 | self.resizer = Resize(self.img_size) 49 | self.resize = resize 50 | self.eval_mode = eval_mode 51 | 52 | self.randomflip = config.randomflip 53 | self.randomjitter = config.randomjitter 54 | self.randomrotate = config.randomrotate 55 | self.jitter = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2) 56 | self.rotate = RandomRotation(30) 57 | self.max_h = 576 58 | self.max_w = 720 59 | 60 | if dset_size < 0: 61 | self.dset_size = len(self.data_idxs) 62 | elif not eval_mode: 63 | self.dset_size = dset_size 64 | else: 65 | self.dset_size = min(dset_size, len(self.data_idxs)) 66 | 67 | def __len__(self): 68 | return self.dset_size 69 | 70 | def load_data(self, img_path, mask_path): 71 | image = Image.open(img_path).convert('RGB') 72 | mask = Image.open(mask_path) 73 | image = self.toten(image) 74 | image = torch.stack([image[1], image[0]]) # make order to (cytoplasm, nuclei) 75 | 76 | mask = self.toten(mask) 77 | flow = torch.from_numpy(np.load(mask_path.replace("masks.png", "flows.npy"))) 78 | return image, flow, mask 79 | 80 | def postprocess_data(self, image, flow, mask): 81 | """ 82 | image: 2 H W 83 | flow: 2 H W 84 | mask: 1 H W 85 | """ 86 | 87 | # for evaluation 88 | aux = {} 89 | aux["full_mask"] = F.pad(mask, (0, self.max_w - mask.shape[-1], 0 ,self.max_h - mask.shape[-2])) 90 | aux["full_flow"] = F.pad(flow, (0, self.max_w - mask.shape[-1], 0 ,self.max_h - mask.shape[-2])) 91 | aux["full_semmask"] = F.pad((mask > 0).float(), (0, self.max_w - 
mask.shape[-1], 0 ,self.max_h - mask.shape[-2])) 92 | aux["full_res"] = mask.shape[-2:] 93 | 94 | # normalize flow to [0, 1] 95 | flow = (flow + 1) / 2 96 | 97 | # repeat single-channel images to rgb 98 | X = repeat(image, "N H W -> (N C) H W", C=3) # 2 H W -> 6 H W 99 | 100 | # foreground mask 101 | Y_mask = (mask > 0).float() 102 | 103 | # append label channels 104 | Y = torch.cat([flow, Y_mask], dim=0) # 3 H W 105 | M = torch.ones_like(Y) 106 | 107 | # resize images to base_size or adaptive size 108 | if (X.shape[-1] < self.img_size[-1] or X.shape[-2] < self.img_size[-2]): 109 | # autocropping 110 | # adaptive_img_size = (int(self.img_size[-1] * self.img_size[-1] / 224), int(self.img_size[-2] * self.img_size[-2] / 224)) 111 | adaptive_img_size = self.img_size 112 | 113 | min_ratio = max(adaptive_img_size[-1] / X.shape[-1] , adaptive_img_size[-2] / X.shape[-2]) 114 | adaptive_resizer = Resize((1, 1)) 115 | adaptive_resizer.size = (max(adaptive_img_size[-2], int(min_ratio * X.shape[-2])), 116 | max(adaptive_img_size[-1], int(min_ratio * X.shape[-1])) ) 117 | X = adaptive_resizer(X) 118 | Y = adaptive_resizer(Y) 119 | M = adaptive_resizer(M) 120 | 121 | # image augmentation 122 | if not self.eval_mode: 123 | if self.randomflip: 124 | if random.random() > 0.5: 125 | X = torch.flip(X, dims=[-1]) 126 | Y = torch.flip(Y, dims=[-1]) 127 | Y[1] = 1-Y[1] 128 | M = torch.flip(M, dims=[-1]) 129 | if random.random() > 0.5: 130 | X = torch.flip(X, dims=[-2]) 131 | Y = torch.flip(Y, dims=[-2]) 132 | Y[0] = 1-Y[0] 133 | M = torch.flip(M, dims=[-2]) 134 | 135 | if self.randomjitter and random.random() > 0.5: 136 | X1, X2 = X.split(3, dim=0) 137 | X12 = torch.cat([X1, X2], dim=-1) 138 | X12 = self.jitter(X12) 139 | X1, X2 = X12.split(X1.size()[-1], dim=-1) 140 | X1 = repeat(X1[:1], '1 H W -> 3 H W') 141 | X2 = repeat(X2[:1], '1 H W -> 3 H W') 142 | X = torch.cat([X1, X2], dim=0) 143 | 144 | # resize or crop image to img_size 145 | if self.resize: 146 | X = self.resizer(X) 147 | Y = self.resizer(Y) 148 | M = self.resizer(M) 149 | elif not self.eval_mode: 150 | X, Y, M = crop_arrays(X, Y, M, 151 | base_size=X.size()[-2:], 152 | crop_size=self.img_size, 153 | random=(not self.eval_mode)) 154 | 155 | if self.precision == 'bf16': 156 | X = X.to(torch.bfloat16) 157 | Y = Y.to(torch.bfloat16) 158 | M = M.to(torch.bfloat16) 159 | 160 | if self.eval_mode: 161 | return X, Y, M, aux 162 | else: 163 | return X, Y, M 164 | 165 | def __getitem__(self, idx): 166 | cur_idx = self.data_idxs[idx % len(self.data_idxs)] 167 | cur_idx = str(cur_idx).zfill(3) 168 | img_path = os.path.join(self.data_dir, f'{cur_idx}_img.png') 169 | mask_path = os.path.join(self.data_dir, f'{cur_idx}_masks.png') 170 | 171 | image, flow, mask = self.load_data(img_path, mask_path) 172 | return self.postprocess_data(image, flow, mask) 173 | -------------------------------------------------------------------------------- /downstream/cellpose/learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | import flow_vis 5 | 6 | from ..base_learner import BaseLearner 7 | from .dataset import CELLPOSEDataset 8 | from .utils import compute_masks, average_precision 9 | 10 | 11 | class CELLPOSELearner(BaseLearner): 12 | BaseDataset = CELLPOSEDataset 13 | 14 | def compute_loss(self, Y_pred, Y, M): 15 | ''' 16 | loss function that returns loss and a dictionary of its components 17 | ''' 18 | loss = (M * 
F.binary_cross_entropy_with_logits(Y_pred, Y, reduction='none')).mean() 19 | loss_values = {'loss': loss.detach(), 'loss_bce': loss.detach()} 20 | return loss, loss_values 21 | 22 | def postprocess_logits(self, Y_pred_out): 23 | ''' 24 | post-processing function for logits 25 | ''' 26 | Y_pred_out = Y_pred_out.sigmoid() 27 | return Y_pred_out 28 | 29 | def postprocess_vis(self, label, img=None, aux=None): 30 | ''' 31 | post-processing function for visualization 32 | ''' 33 | vis = [] 34 | label_vis = label.clone() 35 | label_vis[:, :2] = label_vis[:, :2] * 2 - 1 36 | label_vis[:, :2] = label_vis[:, :2].clip(-1, 1) 37 | for i in range(len(label)): 38 | vis_ = flow_vis.flow_to_color(rearrange(label_vis[i, :2] * (label_vis[i, 2:3] > 0.5).float(), 'C H W -> H W C').numpy()) 39 | vis_ = torch.from_numpy(vis_/255) 40 | vis += [rearrange(vis_, 'H W C -> C H W')] 41 | 42 | label = torch.stack(vis) 43 | 44 | return label 45 | 46 | def compute_metric(self, Y, Y_pred, M, aux, evaluator_key=None): 47 | ''' 48 | compute evaluation metric 49 | ''' 50 | Y = Y.clone() 51 | Y_pred = Y_pred.clone() 52 | Y[:, :2] = Y[:, :2] * 2 - 1 53 | Y_pred[:, :2] = Y_pred[:, :2] * 2 - 1 54 | 55 | full_res = torch.stack(aux["full_res"]).permute(1, 0).long().cpu().numpy() 56 | full_mask = [mask[..., :h, :w] for mask, (h, w) in zip(aux["full_mask"], full_res)] 57 | 58 | Y_full = [full_mask_[0].data.cpu().numpy() for full_mask_ in full_mask] 59 | flow = Y_pred[:, :2].cpu().numpy() 60 | prob_mask = Y_pred[:, -1].cpu().numpy() 61 | mask_preds = [] 62 | for i in range(len(Y_full)): 63 | mask_pred, _ = compute_masks(5 * flow[i], prob_mask[i], cellprob_threshold=0.5, 64 | flow_threshold=0., resize=Y_full[i].shape[-2:], use_gpu=True) 65 | mask_preds += [mask_pred.astype('int32')] 66 | 67 | ap, tp, fp, fn = average_precision(Y_full, mask_preds, threshold=0.5) 68 | metric = torch.tensor(ap.mean()).to(Y.device) 69 | 70 | return metric 71 | 72 | def log_metrics(self, loss_pred, log_dict, valid_tag): 73 | ''' 74 | log evaluation metrics 75 | ''' 76 | log_dict[f'{valid_tag}/{self.vis_tag}_AP50_inverted'] = 1 - loss_pred 77 | -------------------------------------------------------------------------------- /downstream/davis2017/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/downstream/davis2017/__init__.py -------------------------------------------------------------------------------- /downstream/davis2017/configs/test_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: davis2017 9 | task: vos 10 | class_wise: True 11 | class_name: none 12 | shot: 1 13 | img_size: 384 14 | base_size: 448 15 | vis_size: 448 16 | support_idx: 0 17 | coord_path: none 18 | autocrop: False 19 | autocrop_minoverlap: 0.5 20 | randomscale: False 21 | 22 | # dataloader arguments 23 | num_workers: 1 24 | eval_shot: 1 25 | eval_batch_size: 8 26 | eval_size: -1 27 | support_idx: 0 28 | chunk_size: -1 29 | channel_chunk_size: -1 30 | 31 | # logging arguments 32 | log_dir: TEST_DAVIS2017 33 | save_dir: FINETUNE_DAVIS2017 34 | load_dir: TRAIN 35 | load_step: 400000 36 | load_path: none 37 | result_dir: results_davis2017 -------------------------------------------------------------------------------- /downstream/davis2017/configs/train_config.yaml: 
-------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: davis2017 9 | task: vos 10 | class_wise: True 11 | class_name: none 12 | shot: 1 13 | img_size: 384 14 | base_size: 448 15 | vis_size: 448 16 | support_idx: 0 17 | randomflip: True 18 | randomjitter: False 19 | randomrotate: False 20 | randomblur: False 21 | randomscale: False 22 | coord_path: none 23 | autocrop: False 24 | autocrop_minoverlap: 0.5 25 | 26 | # dataloader arguments 27 | num_workers: 1 28 | global_batch_size: 2 29 | eval_shot: 1 30 | eval_batch_size: 10 31 | eval_size: 100 32 | chunk_size: -1 33 | channel_chunk_size: 2 34 | channel_sampling: 2 35 | 36 | # model arguments 37 | attn_dropout: 0.5 38 | n_input_images: 1 39 | separate_alpha: False 40 | 41 | # training arguments 42 | n_steps: 5000 43 | n_schedule_steps: -1 44 | optimizer: adam 45 | loss_type: ce 46 | mask_value: -1. 47 | early_stopping_patience: -1 48 | monitor: J&F_inverted 49 | 50 | # learning rate arguments 51 | lr: 0.0005 52 | lr_pretrained: 0.0002 53 | lr_schedule: constant 54 | lr_warmup: 0 55 | lr_warmup_scale: 0. 56 | schedule_from: 0 57 | weight_decay: 0. 58 | lr_decay_degree: 0.9 59 | 60 | # parameter arguments 61 | from_scratch: False 62 | head_tuning: False 63 | label_decoder_tuning: False 64 | input_embed_tuning: False 65 | output_embed_tuning: False 66 | relpos_tuning: False 67 | 68 | # logging arguments 69 | log_dir: FINETUNE_DAVIS2017 70 | save_dir: FINETUNE_DAVIS2017 71 | load_dir: TRAIN 72 | val_iter: 100 73 | load_step: 400000 74 | -------------------------------------------------------------------------------- /downstream/davis2017/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | import random 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data import Dataset 9 | from torchvision.transforms import ToTensor, Resize 10 | 11 | from dataset.utils import crop_arrays 12 | 13 | 14 | class DAVIS2017(Dataset): 15 | ''' 16 | base class for DAVIS2017 dataset 17 | ''' 18 | CLASS_NAMES = ['bike-packing', 'blackswan', 'bmx-trees', 'breakdance', 'camel', 'car-roundabout', 'car-shadow', 'cows', 19 | 'dance-twirl', 'dog', 'dogs-jump', 'drift-chicane', 'drift-straight', 'goat', 'gold-fish', 'horsejump-high', 20 | 'india', 'judo', 'kite-surf', 'lab-coat', 'libby', 'loading', 'mbike-trick', 'motocross-jump', 'paragliding-launch', 21 | 'parkour', 'pigs', 'scooter-black', 'shooting', 'soapbox'] 22 | NUM_INSTANCES = [2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 5, 2, 3, 2, 3, 5, 1, 3, 2, 2, 3, 1, 3, 2, 3, 3] 23 | 24 | 25 | class DAVIS2017Dataset(DAVIS2017): 26 | ''' 27 | DAVIS2017 dataset 28 | ''' 29 | def __init__(self, config, split, base_size, crop_size, eval_mode=False, resize=False, dset_size=-1): 30 | super().__init__() 31 | assert config.class_name in self.CLASS_NAMES 32 | 33 | # configure paths 34 | self.base_size = base_size 35 | data_root = config.path_dict[config.dataset] 36 | data_dir = f'resized_{config.base_size[1]}' 37 | self.image_dir = os.path.join(data_root, data_dir, config.class_name, 'images') 38 | self.label_dir = os.path.join(data_root, data_dir, config.class_name, 'labels') 39 | 40 | self.img_size = crop_size 41 | self.eval_mode = eval_mode 42 | self.resize = resize 43 | self.shot = config.shot 44 | self.support_idx = config.support_idx 45 | 
self.n_channels = self.NUM_INSTANCES[self.CLASS_NAMES.index(config.class_name)] + 1 46 | self.precision = config.precision 47 | self.randomscale = config.randomscale 48 | 49 | self.toten = ToTensor() 50 | self.resizer = Resize(self.img_size) 51 | 52 | self.data_idxs = np.array(sorted(os.listdir(self.image_dir))) 53 | 54 | assert split in ['train', 'valid', 'test'] 55 | if split == 'train': 56 | train_idxs = [0] 57 | for i in range(1, self.shot): 58 | train_idxs.append((len(self.data_idxs) // self.shot) * i) 59 | assert len(train_idxs) == self.shot, f'{len(train_idxs)} != {self.shot}' 60 | self.data_idxs = self.data_idxs[train_idxs] 61 | elif split == 'valid': 62 | n_vis = 10 63 | valid_idxs_reordered = [] 64 | valid_idxs = torch.arange(len(self.data_idxs)) 65 | vis_idxs = torch.linspace(min(valid_idxs), max(valid_idxs), n_vis).round().long().tolist() 66 | vis_idxs = [min(valid_idxs, key=lambda x:abs(x-vis_idx)) for vis_idx in vis_idxs] 67 | for i in vis_idxs: 68 | valid_idxs_reordered.append(i) 69 | for i in valid_idxs: 70 | if i not in vis_idxs: 71 | valid_idxs_reordered.append(i) 72 | self.data_idxs = self.data_idxs[valid_idxs_reordered] 73 | 74 | if dset_size < 0: 75 | self.dset_size = len(self.data_idxs) 76 | elif not eval_mode: 77 | self.dset_size = dset_size 78 | else: 79 | self.dset_size = min(dset_size, len(self.data_idxs)) 80 | 81 | def __len__(self): 82 | return self.dset_size 83 | 84 | def __getitem__(self, idx): 85 | img_path = self.data_idxs[idx % len(self.data_idxs)] 86 | image, label = self.load_data(img_path) 87 | 88 | return self.postprocess_data(image, label) 89 | 90 | def load_data(self, img_path): 91 | image = Image.open(os.path.join(self.image_dir, img_path)).convert('RGB') 92 | 93 | lbl_path = img_path.replace('jpg', 'png') 94 | label = Image.open(os.path.join(self.label_dir, lbl_path)) 95 | 96 | return image, label 97 | 98 | def postprocess_data(self, image, label): 99 | X = self.toten(image) 100 | Y = self.toten(label).squeeze(0) 101 | 102 | Y = (Y*255).long() 103 | Y = F.one_hot(Y, self.n_channels).permute(2, 0, 1).float() 104 | Y = Y[1:] # remove background channel 105 | 106 | if (not self.eval_mode): 107 | if self.randomscale and random.random() > 0.5: 108 | max_scale = 1.5 109 | scale_h = random.uniform(self.img_size[0] / self.base_size[0], max_scale) 110 | if random.random() > 0.5: 111 | scale_w = scale_h 112 | else: 113 | scale_w = random.uniform(self.img_size[1] / self.base_size[1], max_scale) 114 | target_size = (max(self.img_size[0], int(self.base_size[0] * scale_h)), 115 | max(self.img_size[1], int(self.base_size[1] * scale_w))) 116 | X = F.interpolate(X.unsqueeze(0), target_size, mode='bilinear', align_corners=False).squeeze(0) 117 | Y = F.interpolate(Y.unsqueeze(0), target_size, mode='nearest').squeeze(0) 118 | 119 | if self.resize: 120 | X = self.resizer(X) 121 | Y = self.resizer(Y) 122 | else: 123 | X, Y = crop_arrays(X, Y, 124 | base_size=X.size()[-2:], 125 | crop_size=self.img_size, 126 | random=(not self.eval_mode)) 127 | 128 | if self.precision == 'bf16': 129 | X = X.to(torch.bfloat16) 130 | Y = Y.to(torch.bfloat16) 131 | 132 | M = torch.ones_like(Y) 133 | 134 | return X, Y, M -------------------------------------------------------------------------------- /downstream/davis2017/learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from skimage import color 5 | import os 6 | 7 | from ..base_learner import BaseLearner 8 | from .dataset import 
DAVIS2017Dataset 9 | from .utils import db_eval_iou, db_eval_boundary 10 | from train.visualize import postprocess_semseg 11 | 12 | 13 | class DAVIS2017Learner(BaseLearner): 14 | BaseDataset = DAVIS2017Dataset 15 | 16 | def compute_loss(self, Y_pred, Y, M): 17 | ''' 18 | cross-entropy loss with implicit background class 19 | ''' 20 | Y_bg = torch.zeros_like(Y[:, :1]) 21 | Y = torch.argmax(torch.cat((Y_bg, Y), dim=1), dim=1) 22 | 23 | Y_pred_bg = torch.zeros_like(Y_pred[:, :1]) 24 | Y_pred = torch.cat((Y_pred_bg, Y_pred), dim=1) 25 | 26 | loss = (M * F.cross_entropy(Y_pred, Y, reduction='none')).mean() 27 | loss_values = {'loss': loss.detach(), 'loss_ce': loss.detach()} 28 | 29 | return loss, loss_values 30 | 31 | def postprocess_final(self, Y_pred): 32 | ''' 33 | post-processing function for final prediction 34 | ''' 35 | if Y_pred.shape[1] == 1: 36 | Y_pred = (Y_pred.sigmoid() > 0.5).squeeze(1).to(Y_pred.dtype) 37 | else: 38 | Y_logits_bg = torch.zeros_like(Y_pred[:, :1]) 39 | Y_logits = torch.cat((Y_logits_bg, Y_pred), dim=1) 40 | Y_pred = torch.argmax(Y_logits, dim=1).to(Y_pred.dtype) 41 | 42 | return Y_pred 43 | 44 | def postprocess_vis(self, label, img=None, aux=None): 45 | ''' 46 | post-processing function for visualization 47 | ''' 48 | label = postprocess_semseg(label, img, aux) 49 | return label 50 | 51 | def compute_metric(self, Y, Y_pred, M, aux, evaluator_key=None): 52 | ''' 53 | J&F metric 54 | ''' 55 | if Y.shape[1] > 1: 56 | Y = torch.argmax(torch.cat((0.5*torch.ones_like(Y[:, :1]), Y), dim=1), dim=1) 57 | else: 58 | Y = Y.squeeze(1).long() 59 | j_metric = db_eval_iou(Y.cpu().float().numpy(), Y_pred.cpu().float().numpy()) 60 | f_metric = db_eval_boundary(Y.cpu().float().numpy(), Y_pred.cpu().float().numpy()) 61 | metric = torch.tensor((j_metric + f_metric).mean() / 2).to(Y.device) 62 | 63 | return metric 64 | 65 | def log_metrics(self, loss_pred, log_dict, valid_tag): 66 | ''' 67 | log inverted J&F metric 68 | ''' 69 | log_dict[f'{valid_tag}/{self.vis_tag}_J&F_inverted'] = 1 - loss_pred 70 | 71 | def save_test_outputs(self, Y_pred, batch_idx): 72 | if self.config.class_name == 'bike-packing': 73 | target_size = (480, 910) 74 | elif self.config.class_name == 'shooting': 75 | target_size = (480, 1152) 76 | else: 77 | target_size = (480, 854) 78 | Y_pred = F.interpolate(Y_pred[:, None], target_size, mode='nearest')[:, 0] 79 | 80 | for i in range(len(Y_pred)): 81 | img = self.topil(Y_pred[i].cpu().byte()) 82 | img.save(os.path.join(self.result_dir, f'{batch_idx*self.config.eval_batch_size+i:05d}.png')) 83 | -------------------------------------------------------------------------------- /downstream/davis2017/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def db_eval_iou(annotation, segmentation, void_pixels=None): 7 | """ Compute region similarity as the Jaccard Index. 8 | Arguments: 9 | annotation (ndarray): binary annotation map. 10 | segmentation (ndarray): binary segmentation map. 11 | void_pixels (ndarray): optional mask with void pixels 12 | 13 | Return: 14 | jaccard (float): region similarity 15 | """ 16 | assert annotation.shape == segmentation.shape, \ 17 | f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.' 
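    # J is the Jaccard index over the trailing two axes, so batched (T, H, W) inputs give one score per frame;
    # void pixels, when provided, are excluded from both the intersection and the union.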
18 | annotation = annotation.astype(bool) 19 | segmentation = segmentation.astype(bool) 20 | 21 | if void_pixels is not None: 22 | assert annotation.shape == void_pixels.shape, \ 23 | f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.' 24 | void_pixels = void_pixels.astype(bool) 25 | else: 26 | void_pixels = np.zeros_like(segmentation) 27 | 28 | # Intersection between all sets 29 | inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 30 | union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 31 | 32 | j = inters / union 33 | if j.ndim == 0: 34 | j = 1 if np.isclose(union, 0) else j 35 | else: 36 | j[np.isclose(union, 0)] = 1 37 | return j 38 | 39 | 40 | def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008): 41 | assert annotation.shape == segmentation.shape 42 | if void_pixels is not None: 43 | assert annotation.shape == void_pixels.shape 44 | if annotation.ndim == 3: 45 | n_frames = annotation.shape[0] 46 | f_res = np.zeros(n_frames) 47 | for frame_id in range(n_frames): 48 | void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :, ] 49 | f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th) 50 | elif annotation.ndim == 2: 51 | f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th) 52 | else: 53 | raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions') 54 | return f_res 55 | 56 | 57 | def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008): 58 | """ 59 | Compute mean,recall and decay from per-frame evaluation. 60 | Calculates precision/recall for boundaries between foreground_mask and 61 | gt_mask using morphological operators to speed it up. 62 | 63 | Arguments: 64 | foreground_mask (ndarray): binary segmentation image. 65 | gt_mask (ndarray): binary annotated image. 
66 | void_pixels (ndarray): optional mask with void pixels 67 | 68 | Returns: 69 | F (float): boundaries F-measure 70 | """ 71 | assert np.atleast_3d(foreground_mask).shape[2] == 1 72 | if void_pixels is not None: 73 | void_pixels = void_pixels.astype(bool) 74 | else: 75 | void_pixels = np.zeros_like(foreground_mask).astype(bool) 76 | 77 | bound_pix = bound_th if bound_th >= 1 else \ 78 | np.ceil(bound_th * np.linalg.norm(foreground_mask.shape)) 79 | 80 | # Get the pixel boundaries of both masks 81 | fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels)) 82 | gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels)) 83 | 84 | from skimage.morphology import disk 85 | 86 | # fg_dil = binary_dilation(fg_boundary, disk(bound_pix)) 87 | fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 88 | # gt_dil = binary_dilation(gt_boundary, disk(bound_pix)) 89 | gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 90 | 91 | # Get the intersection 92 | gt_match = gt_boundary * fg_dil 93 | fg_match = fg_boundary * gt_dil 94 | 95 | # Area of the intersection 96 | n_fg = np.sum(fg_boundary) 97 | n_gt = np.sum(gt_boundary) 98 | 99 | # % Compute precision and recall 100 | if n_fg == 0 and n_gt > 0: 101 | precision = 1 102 | recall = 0 103 | elif n_fg > 0 and n_gt == 0: 104 | precision = 0 105 | recall = 1 106 | elif n_fg == 0 and n_gt == 0: 107 | precision = 1 108 | recall = 1 109 | else: 110 | precision = np.sum(fg_match) / float(n_fg) 111 | recall = np.sum(gt_match) / float(n_gt) 112 | 113 | # Compute F measure 114 | if precision + recall == 0: 115 | F = 0 116 | else: 117 | F = 2 * precision * recall / (precision + recall) 118 | 119 | return F 120 | 121 | 122 | def _seg2bmap(seg, width=None, height=None): 123 | """ 124 | From a segmentation, compute a binary boundary map with 1 pixel wide 125 | boundaries. The boundary pixels are offset by 1/2 pixel towards the 126 | origin from the actual segment boundary. 127 | Arguments: 128 | seg : Segments labeled from 1..k. 129 | width : Width of desired bmap <= seg.shape[1] 130 | height : Height of desired bmap <= seg.shape[0] 131 | Returns: 132 | bmap (ndarray): Binary boundary map. 133 | David Martin 134 | January 2003 135 | """ 136 | 137 | seg = seg.astype(bool) 138 | seg[seg > 0] = 1 139 | 140 | assert np.atleast_3d(seg).shape[2] == 1 141 | 142 | width = seg.shape[1] if width is None else width 143 | height = seg.shape[0] if height is None else height 144 | 145 | h, w = seg.shape[:2] 146 | 147 | ar1 = float(width) / float(height) 148 | ar2 = float(w) / float(h) 149 | 150 | assert not ( 151 | width > w | height > h | abs(ar1 - ar2) > 0.01 152 | ), "Can" "t convert %dx%d seg to %dx%d bmap." 
% (w, h, width, height) 153 | 154 | e = np.zeros_like(seg) 155 | s = np.zeros_like(seg) 156 | se = np.zeros_like(seg) 157 | 158 | e[:, :-1] = seg[:, 1:] 159 | s[:-1, :] = seg[1:, :] 160 | se[:-1, :-1] = seg[1:, 1:] 161 | 162 | b = seg ^ e | seg ^ s | seg ^ se 163 | b[-1, :] = seg[-1, :] ^ e[-1, :] 164 | b[:, -1] = seg[:, -1] ^ s[:, -1] 165 | b[-1, -1] = 0 166 | 167 | if w == width and h == height: 168 | bmap = b 169 | else: 170 | bmap = np.zeros((height, width)) 171 | for x in range(w): 172 | for y in range(h): 173 | if b[y, x]: 174 | j = 1 + math.floor((y - 1) + height / h) 175 | i = 1 + math.floor((x - 1) + width / h) 176 | bmap[j, i] = 1 177 | 178 | return bmap 179 | -------------------------------------------------------------------------------- /downstream/fsc147/configs/test_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: fsc147 9 | task: object_counting 10 | class_wise: False 11 | class_name: none 12 | shot: 50 13 | img_size: 512 14 | base_size: 592 15 | vis_size: 592 16 | support_idx: 0 17 | coord_path: none 18 | autocrop: False 19 | autocrop_minoverlap: 0. 20 | autocrop_rescale: 3 21 | 22 | # dataloader arguments 23 | num_workers: 1 24 | eval_shot: 50 25 | eval_batch_size: 1 26 | eval_size: -1 27 | support_idx: 0 28 | chunk_size: 1 29 | channel_chunk_size: 1 30 | 31 | # logging arguments 32 | log_dir: TEST_FSC147 33 | save_dir: FINETUNE_FSC147 34 | load_dir: TRAIN 35 | load_step: 400000 36 | load_path: none 37 | result_dir: none -------------------------------------------------------------------------------- /downstream/fsc147/configs/train_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: fsc147 9 | task: object_counting 10 | class_wise: False 11 | class_name: none 12 | shot: 50 13 | img_size: 512 14 | base_size: 592 15 | vis_size: 592 16 | support_idx: 0 17 | randomflip: True 18 | randomjitter: True 19 | randomrotate: False 20 | randomblur: False 21 | coord_path: none 22 | autocrop: False 23 | autocrop_minoverlap: 0. 24 | autocrop_rescale: 3 25 | 26 | # dataloader arguments 27 | num_workers: 1 28 | global_batch_size: 8 29 | eval_shot: 20 30 | eval_batch_size: 4 31 | eval_size: -1 32 | chunk_size: -1 33 | channel_chunk_size: -1 34 | channel_sampling: -1 35 | 36 | # model arguments 37 | attn_dropout: 0.5 38 | n_input_images: 2 39 | separate_alpha: False 40 | monitor: MAE 41 | 42 | # training arguments 43 | n_steps: 30000 44 | n_schedule_steps: -1 45 | optimizer: adam 46 | loss_type: l2 47 | mask_value: -1. 48 | early_stopping_patience: -1 49 | 50 | # learning rate arguments 51 | lr: 0.001 52 | lr_pretrained: 0.0002 53 | lr_schedule: constant 54 | lr_warmup: 0 55 | lr_warmup_scale: 0. 56 | schedule_from: 0 57 | weight_decay: 0. 
58 | lr_decay_degree: 0.9 59 | 60 | # parameter arguments 61 | from_scratch: False 62 | head_tuning: False 63 | input_embed_tuning: False 64 | output_embed_tuning: False 65 | relpos_tuning: True 66 | label_decoder_tuning: False 67 | 68 | # logging arguments 69 | log_dir: FINETUNE_FSC147 70 | save_dir: FINETUNE_FSC147 71 | load_dir: TRAIN 72 | val_iter: 2000 73 | load_step: 400000 74 | -------------------------------------------------------------------------------- /downstream/fsc147/learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | 5 | from ..base_learner import BaseLearner 6 | from .dataset import FSC147Dataset 7 | from .utils import preprocess_kpmap, make_density_tensor, viridis 8 | 9 | 10 | class FSC147Learner(BaseLearner): 11 | BaseDataset = FSC147Dataset 12 | 13 | def register_evaluator(self): 14 | if self.config.stage == 1: 15 | keys = ['mtest_train', 'mtest_valid'] 16 | else: 17 | keys = ['mtest_test'] 18 | self.evaluator = {key: None for key in keys} 19 | 20 | for key in keys: 21 | self.evaluator[key] = [] 22 | 23 | def reset_evaluator(self): 24 | for key in self.evaluator.keys(): 25 | self.evaluator[key] = [] 26 | 27 | def compute_loss(self, Y_pred, Y, M): 28 | ''' 29 | loss function that returns loss and a dictionary of its components 30 | ''' 31 | loss = (M * F.mse_loss(Y_pred.sigmoid(), Y, reduction='none')).mean() 32 | loss_values = {'loss': loss.detach(), 'loss_con': loss.detach()} 33 | return loss, loss_values 34 | 35 | def postprocess_logits(self, Y_pred_out): 36 | Y_pred_out = Y_pred_out.sigmoid() 37 | return Y_pred_out 38 | 39 | def postprocess_vis(self, label, img=None, aux=None): 40 | ''' 41 | post-processing function for visualization 42 | ''' 43 | # process label by finding modes 44 | width = label.size(-1) // 64 45 | if width %2 == 1: 46 | width += 1 47 | 48 | label = label.clip(0, 1) 49 | label_vis = torch.stack([make_density_tensor(preprocess_kpmap(p[0], threshold=0.2), img_size=label.shape[-2:], width=width) for p in label]).float() 50 | label_vis = viridis(label_vis) 51 | if img is not None: 52 | label_vis = label_vis * 0.5 + img[:len(label_vis), :3] * 0.5 53 | 54 | label = label_vis 55 | 56 | return label 57 | 58 | def compute_metric(self, Y, Y_pred, M, aux, evaluator_key=None): 59 | ''' 60 | compute evaluation metric 61 | ''' 62 | weights = aux[0] 63 | evaluator = self.evaluator[evaluator_key] 64 | 65 | assert len(Y) == 1 66 | if Y_pred.size(-1) == Y.size(-1) and Y_pred.size(-2) == Y.size(-2): 67 | ninecrop_sample = torch.tensor([False]).to(Y.device) 68 | denom = 1 69 | else: 70 | ninecrop_sample = torch.tensor([True]).to(Y.device) 71 | denom = Y_pred.size(-1) * Y_pred.size(-2) / (Y.size(-1) * Y.size(-2)) 72 | # denom = 9 73 | 74 | Y = (Y.float() * rearrange(weights.float(), 'B -> B 1 1 1')).sum((1,2,3)).round().float() 75 | Y_pred_sum = ((Y_pred.float() * rearrange(weights.float(), 'B -> B 1 1 1')).sum((1,2,3)) / denom).round().float() 76 | Y_pred_mode = torch.tensor([len(preprocess_kpmap(p[0].data.cpu(), threshold=0.20)) for p in Y_pred]).float().to(Y.device) 77 | mask = (ninecrop_sample and Y_pred_sum > 3000).float() # apply mask to ninecrop samples whose sum is > 3000 78 | Y_pred = Y_pred_sum * mask + Y_pred_mode * (1 - mask) 79 | 80 | evaluator.append((Y, Y_pred)) 81 | metric = (Y - Y_pred).abs().mean() 82 | 83 | return metric 84 | 85 | def log_metrics(self, loss_pred, log_dict, valid_tag): 86 | ''' 87 | log evaluation 
metrics 88 | ''' 89 | log_dict[f'{valid_tag}/{self.vis_tag}_MAE'] = loss_pred 90 | 91 | def get_test_metrics(self, metrics_total): 92 | Ys = torch.cat([y for y, _ in self.evaluator['mtest_test']]) 93 | Y_preds = torch.cat([y_pred for _, y_pred in self.evaluator['mtest_test']]) 94 | mae = (Ys - Y_preds).abs().mean() 95 | rmse = ((Ys - Y_preds)**2).mean().sqrt() 96 | metrics = [mae, rmse] 97 | return metrics -------------------------------------------------------------------------------- /downstream/fsc147/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from skimage.morphology import extrema 3 | 4 | 5 | def preprocess_kpmap(heatmap, threshold=0.1): 6 | assert heatmap.ndim == 2 7 | maxima = extrema.h_maxima((heatmap).numpy(), threshold) 8 | grid = torch.stack(torch.meshgrid(torch.arange(heatmap.shape[0]), torch.arange(heatmap.shape[1]), indexing='ij'), dim=2) 9 | modes = grid[maxima.astype(bool)] 10 | return modes 11 | 12 | 13 | def make_density_tensor(kpmaps, img_size=(256,256), width=2): 14 | # N, 2 15 | assert width % 2 == 0 16 | density = torch.zeros(1, *img_size) 17 | for x, y in kpmaps: 18 | density[0, x - width // 2 : x + width // 2, y - width // 2 : y + width // 2] = 1 19 | 20 | return density 21 | 22 | 23 | def viridis(x): 24 | c1 = torch.tensor([68., 1., 84.]) / 255 25 | c2 = torch.tensor([33., 145., 140.]) / 255 26 | c3 = torch.tensor([253., 231., 37.]) / 255 27 | 28 | for i in range(x.ndim - 1): 29 | c1 = c1[:, None] 30 | c2 = c2[:, None] 31 | c3 = c3[:, None] 32 | if x.ndim == 4: 33 | c1 = c1.transpose(0, 1) 34 | c2 = c2.transpose(0, 1) 35 | c3 = c3.transpose(0, 1) 36 | 37 | x = torch.where( 38 | x < 0.5, 39 | c1 * (0.5 - x) * 2 + c2 * x * 2, 40 | c2 * (1 - x) * 2 + c3 * (x - 0.5) * 2 41 | ) 42 | return x -------------------------------------------------------------------------------- /downstream/isic2018/configs/test_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: isic2018 9 | task: segment_medical 10 | class_wise: False 11 | class_name: none 12 | shot: 20 13 | img_size: 384 14 | base_size: 448 15 | vis_size: 448 16 | support_idx: 0 17 | coord_path: none 18 | autocrop: False 19 | autocrop_minoverlap: 0.5 20 | 21 | # dataloader arguments 22 | num_workers: 1 23 | eval_shot: 50 24 | eval_batch_size: 8 25 | eval_size: -1 26 | support_idx: 0 27 | chunk_size: -1 28 | channel_chunk_size: -1 29 | 30 | # logging arguments 31 | log_dir: TEST_ISIC2018 32 | save_dir: FINETUNE_ISIC2018 33 | load_dir: TRAIN 34 | load_step: 400000 35 | load_path: none 36 | result_dir: none -------------------------------------------------------------------------------- /downstream/isic2018/configs/train_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: isic2018 9 | task: segment_medical 10 | class_wise: False 11 | class_name: none 12 | shot: 20 13 | img_size: 384 14 | base_size: 448 15 | vis_size: 448 16 | support_idx: 0 17 | randomflip: False 18 | randomjitter: True 19 | randomrotate: True 20 | randomblur: True 21 | coord_path: none 22 | autocrop: False 23 | autocrop_minoverlap: 0.5 24 | 25 | # dataloader arguments 26 | num_workers: 1 27 | global_batch_size: 5 28 | eval_shot: 50 29 | 
eval_batch_size: 10 30 | eval_size: 100 31 | chunk_size: -1 32 | channel_chunk_size: -1 33 | channel_sampling: -1 34 | 35 | # model arguments 36 | attn_dropout: 0.5 37 | n_input_images: 1 38 | separate_alpha: False 39 | 40 | # training arguments 41 | n_steps: 10000 42 | n_schedule_steps: -1 43 | optimizer: adam 44 | loss_type: bce 45 | mask_value: -1. 46 | early_stopping_patience: 5 47 | monitor: F1_inverted 48 | 49 | # learning rate arguments 50 | lr: 0.001 51 | lr_pretrained: 0.0002 52 | lr_schedule: constant 53 | lr_warmup: 0 54 | lr_warmup_scale: 0. 55 | schedule_from: 0 56 | weight_decay: 0. 57 | lr_decay_degree: 0.9 58 | 59 | # parameter arguments 60 | from_scratch: False 61 | head_tuning: False 62 | label_decoder_tuning: False 63 | input_embed_tuning: False 64 | output_embed_tuning: False 65 | relpos_tuning: False 66 | 67 | # logging arguments 68 | log_dir: FINETUNE_ISIC2018 69 | save_dir: FINETUNE_ISIC2018 70 | load_dir: TRAIN 71 | val_iter: 100 72 | load_step: 400000 73 | -------------------------------------------------------------------------------- /downstream/isic2018/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch 4 | from torch.utils.data import Dataset 5 | from torchvision.transforms import ToTensor, Resize 6 | 7 | from dataset.utils import crop_arrays 8 | 9 | 10 | class ISIC2018Dataset(Dataset): 11 | ''' 12 | ISIC2018 dataset 13 | ''' 14 | def __init__(self, config, split, base_size, crop_size, eval_mode=False, resize=False, dset_size=-1): 15 | super().__init__() 16 | self.base_size = base_size 17 | data_root = config.path_dict[config.dataset] 18 | self.eval_mode = eval_mode 19 | if self.eval_mode: 20 | self.image_dir = os.path.join(data_root, 'original', 'images') 21 | self.label_dir = os.path.join(data_root, 'original', 'labels') 22 | else: 23 | self.image_dir = os.path.join(data_root, f'resized_{base_size[0]}', 'images') 24 | self.label_dir = os.path.join(data_root, f'resized_{base_size[0]}', 'labels') 25 | 26 | self.img_size = crop_size 27 | self.resize = resize 28 | self.shot = config.shot 29 | self.support_idx = config.support_idx 30 | self.precision = config.precision 31 | 32 | assert split in ['train', 'valid', 'test'] 33 | self.dset_size = dset_size 34 | self.base_size = base_size 35 | self.toten = ToTensor() 36 | self.resizer = Resize(self.img_size) 37 | self.resize = resize 38 | 39 | ptf = '' if self.support_idx == 0 else f'_{self.support_idx}' 40 | if split == 'test': 41 | file_path = os.path.join(data_root, 'meta', f'test_files{ptf}.pth') 42 | else: 43 | file_path = os.path.join(data_root, 'meta', f'train_files{ptf}.pth') 44 | file_dict = torch.load(file_path) 45 | 46 | self.data_idxs = [] 47 | if split == 'test': 48 | for class_name in file_dict.keys(): 49 | self.data_idxs += file_dict[class_name] 50 | else: 51 | # get class names sorted by number of files 52 | class_names = list(file_dict.keys()) 53 | class_names.sort(key=lambda x: len(file_dict[x])) 54 | class_names = list(reversed(class_names)) 55 | 56 | # choose number of files per class 57 | shot_per_class = [self.shot // len(class_names) for _ in range(len(class_names))] 58 | for i in range(self.shot % len(class_names)): 59 | shot_per_class[i] += 1 60 | 61 | # choose files for training 62 | if split == 'train': 63 | for class_name, n_files in zip(class_names, shot_per_class): 64 | self.data_idxs += file_dict[class_name][self.support_idx*n_files:(self.support_idx+1)*n_files] 65 | if 
len(file_dict[class_name]) - self.support_idx*n_files < n_files: 66 | self.data_idxs += file_dict[class_name][:(n_files - len(file_dict[class_name]) + self.support_idx*n_files)] 67 | 68 | # choose files for validation 69 | else: 70 | files_per_class = [dset_size // len(class_names) for _ in range(len(class_names))] 71 | for i in range(dset_size % len(class_names)): 72 | files_per_class[i] += 1 73 | 74 | for class_name, n_files_train, n_files_val in zip(class_names, shot_per_class, files_per_class): 75 | valid_files = file_dict[class_name][:self.support_idx*n_files_train] + file_dict[class_name][(self.support_idx+1)*n_files_train:] 76 | self.data_idxs += valid_files[:n_files_val] 77 | 78 | if dset_size < 0: 79 | self.dset_size = len(self.data_idxs) 80 | elif not eval_mode: 81 | self.dset_size = dset_size 82 | else: 83 | self.dset_size = min(dset_size, len(self.data_idxs)) 84 | 85 | def __len__(self): 86 | return self.dset_size 87 | 88 | def __getitem__(self, idx): 89 | img_path = self.data_idxs[idx % len(self.data_idxs)] + '.jpg' 90 | image, label = self.load_data(img_path) 91 | 92 | return self.postprocess_data(image, label) 93 | 94 | def load_data(self, img_path): 95 | image = Image.open(os.path.join(self.image_dir, img_path)).convert('RGB') 96 | lbl_path = img_path.replace('.jpg', '.png') 97 | label = Image.open(os.path.join(self.label_dir, lbl_path)) 98 | 99 | return image, label 100 | 101 | def postprocess_data(self, image, label): 102 | if self.eval_mode: 103 | Y_full = self.toten(label.resize((512, 512))) 104 | image = image.resize((self.base_size[1], self.base_size[0])) 105 | label = label.resize((self.base_size[1], self.base_size[0])) 106 | 107 | X = self.toten(image) 108 | Y = self.toten(label) 109 | 110 | if self.precision == 'bf16': 111 | X = X.to(torch.bfloat16) 112 | Y = Y.to(torch.bfloat16) 113 | 114 | if self.resize: 115 | X = self.resizer(X) 116 | Y = self.resizer(Y) 117 | else: 118 | X, Y = crop_arrays(X, Y, 119 | base_size=X.size()[-2:], 120 | crop_size=self.img_size, 121 | random=(not self.eval_mode)) 122 | M = torch.ones_like(Y) 123 | 124 | if self.eval_mode: 125 | return X, Y, M, Y_full 126 | else: 127 | return X, Y, M 128 | -------------------------------------------------------------------------------- /downstream/isic2018/learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import os 4 | from skimage import color 5 | from einops import reduce 6 | import numpy as np 7 | 8 | from ..base_learner import BaseLearner 9 | from .dataset import ISIC2018Dataset 10 | from train.miou_fss import AverageMeter, Evaluator 11 | from train.visualize import postprocess_semseg 12 | 13 | 14 | class ISIC2018Learner(BaseLearner): 15 | BaseDataset = ISIC2018Dataset 16 | 17 | def register_evaluator(self): 18 | if self.config.stage == 1: 19 | keys = ['mtest_train', 'mtest_valid'] 20 | else: 21 | keys = ['mtest_test'] 22 | 23 | self.evaluator = {key: None for key in keys} 24 | for key in keys: 25 | self.evaluator[key] = AverageMeter(class_ids_interest=[0], semseg_classes=[0], device=torch.device(f'cuda:{self.local_rank}')) 26 | self.result_path = self.result_path.replace('.pth', f'_sid:{self.config.support_idx}.pth') 27 | 28 | def reset_evaluator(self): 29 | for key in self.evaluator: 30 | self.evaluator[key].reset() 31 | 32 | def compute_loss(self, Y_pred, Y, M): 33 | ''' 34 | loss function that returns loss and a dictionary of its components 35 | ''' 36 | loss = (M * 
F.binary_cross_entropy_with_logits(Y_pred, Y, reduction='none')).mean() 37 | loss_values = {'loss': loss.detach(), 'loss_bce': loss.detach()} 38 | return loss, loss_values 39 | 40 | def postprocess_final(self, Y_pred): 41 | Y_pred = (Y_pred.sigmoid() > 0.5).squeeze(1).to(Y_pred.dtype) 42 | return Y_pred 43 | 44 | def postprocess_vis(self, label, img=None, aux=None): 45 | ''' 46 | post-processing function for visualization 47 | ''' 48 | label = postprocess_semseg(label, img, aux) 49 | return label 50 | 51 | def compute_metric(self, Y, Y_pred, M, aux, evaluator_key=None): 52 | ''' 53 | compute evaluation metric 54 | ''' 55 | Y_full = aux 56 | evaluator = self.evaluator[evaluator_key] 57 | 58 | if Y_pred.ndim == 3: 59 | Y_pred = Y_pred.unsqueeze(1) 60 | assert Y_full.ndim == 4 61 | if Y_pred.shape[-2:] != (512, 512): 62 | Y_pred = F.interpolate(Y_pred.float(), (512, 512), mode='nearest') 63 | area_inter, area_union = Evaluator.classify_prediction(Y_pred.clone().float(), Y_full.float().round()) 64 | class_id = torch.tensor([0]*len(Y_pred), device=Y_full.device) # use 0 for all classes 65 | area_inter = area_inter.to(Y_full.device) 66 | area_union = area_union.to(Y_full.device) 67 | evaluator.update(area_inter, area_union, class_id) 68 | metric = 0 69 | 70 | return metric 71 | 72 | def log_metrics(self, loss_pred, log_dict, valid_tag): 73 | ''' 74 | log evaluation metrics 75 | ''' 76 | evaluator = self.evaluator[valid_tag] 77 | if self.n_devices > 1: 78 | evaluator.intersection_buf = reduce(self.trainer.all_gather(evaluator.intersection_buf), 'G ... -> ...', 'sum') 79 | evaluator.union_buf = reduce(self.trainer.all_gather(evaluator.union_buf), 'G ... -> ...', 'sum') 80 | intersection = evaluator.intersection_buf.float() 81 | union = evaluator.union_buf.float() 82 | f1 = 2*intersection / torch.max(torch.stack([union + intersection, evaluator.ones]), dim=0)[0] 83 | log_dict[f'{valid_tag}/{self.vis_tag}_F1_inverted'] = 1 - f1[1, 0] 84 | 85 | def get_test_metrics(self, metrics_total): 86 | ''' 87 | save test metrics 88 | ''' 89 | evaluator = self.evaluator['mtest_test'] 90 | if self.n_devices > 1: 91 | evaluator.intersection_buf = reduce(self.all_gather(evaluator.intersection_buf), 'G ... -> ...', 'sum') 92 | evaluator.union_buf = reduce(self.all_gather(evaluator.union_buf), 'G ... 
-> ...', 'sum') 93 | intersection = evaluator.intersection_buf.float() 94 | union = evaluator.union_buf.float() 95 | dsc = 2*intersection / torch.max(torch.stack([union + intersection, evaluator.ones]), dim=0)[0] 96 | dsc = dsc[1, 0].item() 97 | iou = evaluator.compute_iou()[0].cpu().item() 98 | metric = [dsc, iou] 99 | 100 | return metric -------------------------------------------------------------------------------- /downstream/learner_factory.py: -------------------------------------------------------------------------------- 1 | from .davis2017.learner import DAVIS2017Learner 2 | from .ap10k.learner import AP10KLearner 3 | from .linemod.learner import LINEMODLearner, LINEMODMaskLearner 4 | from .isic2018.learner import ISIC2018Learner 5 | from .cellpose.learner import CELLPOSELearner 6 | from .fsc147.learner import FSC147Learner 7 | 8 | 9 | def get_downstream_learner(config, trainer): 10 | ''' 11 | add custom learner here 12 | ''' 13 | if config.dataset == 'davis2017': 14 | return DAVIS2017Learner(config, trainer) 15 | elif config.dataset == 'ap10k': 16 | return AP10KLearner(config, trainer) 17 | elif config.dataset == 'linemod': 18 | if config.task == 'pose_6d': 19 | return LINEMODLearner(config, trainer) 20 | elif config.task == 'segment_semantic': 21 | return LINEMODMaskLearner(config, trainer) 22 | else: 23 | raise NotImplementedError 24 | elif config.dataset == 'isic2018': 25 | return ISIC2018Learner(config, trainer) 26 | elif config.dataset == 'cellpose': 27 | return CELLPOSELearner(config, trainer) 28 | elif config.dataset == 'fsc147': 29 | return FSC147Learner(config, trainer) 30 | else: 31 | raise NotImplementedError -------------------------------------------------------------------------------- /downstream/linemod/configs/test_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: linemod 9 | task: pose_6d 10 | class_wise: True 11 | class_name: none 12 | shot: 50 13 | img_size: 224 14 | base_size: 256 15 | vis_size: 256 16 | support_idx: 0 17 | coord_path: none 18 | autocrop: False 19 | autocrop_minoverlap: 0.5 20 | 21 | # dataloader arguments 22 | num_workers: 1 23 | eval_shot: 50 24 | eval_batch_size: 8 25 | eval_size: -1 26 | support_idx: 0 27 | chunk_size: -1 28 | channel_chunk_size: -1 29 | 30 | # logging arguments 31 | log_dir: TEST_LINEMOD 32 | save_dir: FINETUNE_LINEMOD 33 | load_dir: TRAIN 34 | load_step: 400000 35 | load_path: none 36 | result_dir: none -------------------------------------------------------------------------------- /downstream/linemod/configs/train_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: ddp 5 | num_nodes: 1 6 | 7 | # data arguments 8 | dataset: linemod 9 | task: pose_6d 10 | class_wise: True 11 | class_name: none 12 | shot: 50 13 | img_size: 224 14 | base_size: 256 15 | vis_size: 256 16 | support_idx: 0 17 | randomflip: False 18 | randomjitter: True 19 | randomrotate: True 20 | randomblur: True 21 | coord_path: none 22 | autocrop: False 23 | autocrop_minoverlap: 0.5 24 | 25 | # dataloader arguments 26 | num_workers: 1 27 | global_batch_size: 20 28 | eval_shot: 50 29 | eval_batch_size: 10 30 | eval_size: 100 31 | chunk_size: -1 32 | channel_chunk_size: -1 33 | channel_sampling: -1 34 | 35 | # model arguments 36 | attn_dropout: 0.5 37 | n_input_images: 1 
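# with separate_alpha, each of the four LINEMOD pose channels (mask, u, v, w) gets its own matching alpha
# instead of sharing one per task group (see model_factory.py and chameleon.py)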
38 | separate_alpha: True 39 | 40 | # training arguments 41 | n_steps: 20000 42 | n_schedule_steps: -1 43 | optimizer: adam 44 | loss_type: suvw 45 | mask_value: -1. 46 | early_stopping_patience: -1 47 | monitor: ADD0.1s_inverted 48 | 49 | # learning rate arguments 50 | lr: 0.005 51 | lr_pretrained: 0.0002 52 | lr_schedule: poly 53 | lr_warmup: 0 54 | lr_warmup_scale: 0. 55 | schedule_from: 0 56 | weight_decay: 0. 57 | lr_decay_degree: 0.9 58 | 59 | # parameter arguments 60 | from_scratch: False 61 | head_tuning: True 62 | label_decoder_tuning: False 63 | input_embed_tuning: False 64 | output_embed_tuning: True 65 | relpos_tuning: False 66 | 67 | # logging arguments 68 | log_dir: FINETUNE_LINEMOD 69 | save_dir: FINETUNE_LINEMOD 70 | load_dir: TRAIN 71 | val_iter: 100 72 | load_step: 400000 73 | -------------------------------------------------------------------------------- /get_beitv2.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import os 3 | import glob 4 | import shutil 5 | from pathlib import Path 6 | import os; import ipdb; ipdb.set_trace(context=15) if os.environ.get("LOCAL_RANK", '0') == '0' else None 7 | 8 | beitv2_large = timm.create_model("hf_hub:timm/beitv2_large_patch16_224.in1k_ft_in22k", pretrained=True) 9 | # import os; import ipdb; ipdb.set_trace(context=15) if os.environ.get("LOCAL_RANK", '0') == '0' else None 10 | file_path = glob.glob('/root/.cache/huggingface/hub/models--timm--beitv2_large_patch16_224.in1k_ft_in22k/snapshots/**/pytorch_model.bin')[0] 11 | 12 | Path('/root/.cache/torch/hub/checkpoints/').mkdir(parents=True, exist_ok=True) 13 | os.system(f'ln -s {file_path} /root/.cache/torch/hub/checkpoints/beitv2_large_patch16_224_pt1k_ft21kto1k.pth') 14 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytorch_lightning as pl 3 | import torch 4 | import warnings 5 | 6 | from args import parse_args 7 | from train.train_utils import configure_experiment, load_model, print_configs 8 | from lightning_fabric.utilities.seed import seed_everything 9 | 10 | 11 | if __name__ == "__main__": 12 | torch.multiprocessing.freeze_support() 13 | torch.set_num_threads(1) 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | warnings.filterwarnings("ignore", category=pl.utilities.warnings.PossibleUserWarning) 16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 17 | 18 | # parse args 19 | config = parse_args() 20 | seed_everything(config.seed, workers=True) 21 | 22 | if config.slurm: 23 | IS_RANK_ZERO = int(os.environ.get('SLURM_LOCALID', 0)) == 0 24 | else: 25 | IS_RANK_ZERO = int(os.environ.get('LOCAL_RANK', 0)) == 0 26 | 27 | if not config.check_mode: 28 | # load model 29 | model, config, ckpt_path, mt_config, ft_config, ts_config = load_model(config, verbose=IS_RANK_ZERO, reduced=(config.stage > 0)) 30 | 31 | # environmental settings 32 | logger, log_dir, save_dir, callbacks, profiler, precision, strategy, plugins = configure_experiment(config, model, is_rank_zero=IS_RANK_ZERO) 33 | model.config.ckpt_dir = save_dir 34 | model.config.result_dir = log_dir 35 | 36 | # print configs 37 | if IS_RANK_ZERO: 38 | print_configs(config, model, mt_config, ft_config, ts_config) 39 | 40 | # set max epochs 41 | if (not config.no_eval) and config.stage <= 1: 42 | max_epochs = config.n_steps // config.val_iter 43 | else: 44 | max_epochs = 1 45 | 46 | # create pytorch lightning trainer. 
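        # The arguments below are wired from the parsed config: logger, precision, strategy, and plugins come
        # from configure_experiment(); max_epochs equals n_steps // val_iter when validating during training
        # (one "epoch" per validation round), otherwise 1; all visible GPUs are used unless single_gpu is set.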
47 | trainer = pl.Trainer( 48 | logger=logger, 49 | default_root_dir=save_dir, 50 | accelerator='gpu', 51 | max_epochs=max_epochs, 52 | log_every_n_steps=-1, 53 | num_sanity_val_steps=(2 if config.sanity_check else 0), 54 | callbacks=callbacks, 55 | benchmark=True, 56 | devices=(1 if config.single_gpu else torch.cuda.device_count()), 57 | strategy=strategy, 58 | precision=precision, 59 | profiler=profiler, 60 | plugins=plugins, 61 | gradient_clip_val=config.gradient_clip_val, 62 | num_nodes=config.num_nodes, 63 | ) 64 | 65 | # validation at start 66 | if config.stage == 1 or (config.stage == 0 and config.no_train): 67 | trainer.validate(model, verbose=False) 68 | 69 | # start evaluation 70 | if config.stage == 2: 71 | trainer.test(model) 72 | # start training or fine-tuning 73 | elif not config.no_train: 74 | trainer.fit(model, ckpt_path=ckpt_path) 75 | -------------------------------------------------------------------------------- /main_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/main_figure.png -------------------------------------------------------------------------------- /meta_train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/meta_train/__init__.py -------------------------------------------------------------------------------- /meta_train/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/meta_train/datasets/__init__.py -------------------------------------------------------------------------------- /meta_train/train_config.yaml: -------------------------------------------------------------------------------- 1 | # environment settings 2 | seed: 0 3 | precision: bf16 4 | strategy: deepspeed 5 | gradient_clip_val: 2. 6 | slurm: False 7 | num_nodes: 1 8 | 9 | # data arguments 10 | dataset: unified 11 | taskonomy: True 12 | coco: True 13 | midair: True 14 | mpii: True 15 | deepfashion: True 16 | freihand: True 17 | midair_stereo: True 18 | coco_stereo: True 19 | task_sampling_weight: [1., 3., 3.] 20 | task: unified 21 | base_task: True 22 | cont_task: True 23 | cat_task: True 24 | task_group: None 25 | coco_cropped: True 26 | use_stereo_datasets: True 27 | no_coco_kp: False 28 | 29 | num_workers: 8 30 | global_batch_size: 8 31 | shot: 4 32 | support_idx: 0 33 | eval_shot: 16 34 | max_channels: 4 35 | domains_per_batch: 2 36 | eval_batch_size: 8 37 | eval_size: 80 38 | img_size: 224 39 | base_size: -1 40 | vis_size: -1 41 | image_augmentation: True 42 | label_augmentation: True 43 | autocrop: False 44 | chunk_size: 1 45 | channel_chunk_size: 17 46 | channel_sampling: -1 47 | 48 | # model arguments 49 | image_encoder: beitv2_large_patch16_224 50 | label_encoder: vit_large_patch16_224 51 | n_attn_heads: 16 52 | decoder_features: 256 53 | image_encoder_drop_path_rate: 0.1 54 | label_encoder_drop_path_rate: 0. 55 | n_input_images: 1 56 | separate_alpha: False 57 | matching_alpha_init: 0. 58 | matching_alpha_temp: 0.05 59 | 60 | # training arguments 61 | n_steps: 400000 62 | n_schedule_steps: -1 63 | optimizer: adam 64 | lr: 0.0001 65 | lr_pretrained: 0.00001 66 | lr_schedule: poly 67 | lr_warmup: 5000 68 | lr_warmup_scale: 0. 69 | schedule_from: 0 70 | weight_decay: 0. 
71 | lr_decay_degree: 0.9 72 | mask_value: -1. 73 | early_stopping_patience: -1 74 | loss_type: hybrid 75 | con_coef: 1.0 76 | bce_coef: 1.0 77 | ssl_coef: 1.0 78 | from_scratch: False 79 | scale_ssl: True 80 | 81 | # logging arguments 82 | log_dir: TRAIN 83 | save_dir: TRAIN 84 | load_dir: TRAIN 85 | val_iter: 10000 86 | monitor: summary/mtrain_valid_pred 87 | load_step: -1 88 | -------------------------------------------------------------------------------- /meta_train/utils.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange, repeat 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def disassemble_batch(batch, resize=None): 7 | X, Y, M, *_ = batch 8 | if resize is not None: 9 | X = F.interpolate(X, size=resize, mode='bilinear', align_corners=False) 10 | Y = F.interpolate(Y, size=resize, mode='nearest') 11 | M = F.interpolate(M, size=resize, mode='nearest') 12 | 13 | T = Y.size(1) 14 | X = repeat(X, 'N C H W -> 1 T N C H W', T=T) 15 | Y = rearrange(Y, 'N T H W -> 1 T N 1 H W') 16 | M = rearrange(M, 'N T H W -> 1 T N 1 H W') 17 | 18 | return X, Y, M 19 | 20 | 21 | def generate_task_mask(t_idx, task_idxs): 22 | ''' 23 | Generate binary mask whether the task is semantic segmentation (1) or not (0). 24 | ''' 25 | task_mask = torch.zeros_like(t_idx, dtype=bool) 26 | for task_idx in task_idxs: 27 | task_mask = torch.logical_or(task_mask, t_idx == task_idx) 28 | 29 | return task_mask 30 | 31 | 32 | def normalize_tensor(input_tensor, dim): 33 | ''' 34 | Normalize Euclidean vector. 35 | ''' 36 | norm = torch.norm(input_tensor, p='fro', dim=dim, keepdim=True) 37 | zero_mask = (norm == 0) 38 | norm[zero_mask] = 1 39 | out = input_tensor.div(norm) 40 | out[zero_mask.expand_as(out)] = 0 41 | return out -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/model/__init__.py -------------------------------------------------------------------------------- /model/chameleon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .encoder import ViTEncoder 5 | from .decoder import DPTDecoder 6 | from .matching import MatchingModule 7 | 8 | 9 | class Chameleon(nn.Module): 10 | def __init__(self, config, n_tasks, n_task_groups): 11 | super().__init__() 12 | self.n_tasks = n_tasks 13 | self.n_task_groups = n_task_groups 14 | self.separate_alpha = getattr(config, 'separate_alpha', False) 15 | 16 | self.image_encoder = ViTEncoder(config, config.image_encoder, pretrained=(config.stage == 0 and not config.continue_mode), 17 | in_chans=3, drop_path_rate=config.image_encoder_drop_path_rate, 18 | n_bias_sets=self.n_tasks, n_input_images=config.n_input_images,) 19 | self.label_encoder = ViTEncoder(config, config.label_encoder, pretrained=False, 20 | in_chans=1, drop_path_rate=config.label_encoder_drop_path_rate, 21 | n_bias_sets=0, n_input_images=1,) 22 | self.matching_module = MatchingModule(self.image_encoder.backbone.embed_dim, self.label_encoder.backbone.embed_dim, config.n_attn_heads, 23 | alpha_init=config.matching_alpha_init, alpha_temp=config.matching_alpha_temp, 24 | n_alphas=(self.n_tasks if self.separate_alpha else self.n_task_groups)) 25 | self.label_decoder = DPTDecoder(self.label_encoder.grid_size, 
self.label_encoder.backbone.embed_dim, 26 | hidden_features=[min(config.decoder_features*(2**i), 1024) for i in range(4)], 27 | out_chans=1, img_size=config.img_size) 28 | 29 | self.reset_support() 30 | 31 | def reset_support(self): 32 | # for support encoding 33 | self.has_encoded_support = False 34 | self.W_Ss = self.Z_Ss = None 35 | 36 | def bias_parameters(self): 37 | # bias parameters for similarity adaptation 38 | for p in self.image_encoder.bias_parameters(): 39 | yield p 40 | 41 | def bias_parameter_names(self): 42 | names = [f'image_encoder.{name}' for name in self.image_encoder.bias_parameter_names()] 43 | 44 | return names 45 | 46 | def pretrained_parameters(self): 47 | for p in self.image_encoder.parameters(): 48 | yield p 49 | 50 | def scratch_parameters(self): 51 | modules = [self.label_encoder, self.matching_module, self.label_decoder] 52 | for module in modules: 53 | for p in module.parameters(): 54 | yield p 55 | 56 | def forward(self, X_S, Y_S, X_Q, t_idx=None, g_idx=None): 57 | # encode query input, support input and output 58 | W_Qs = self.image_encoder(X_Q, t_idx=t_idx) 59 | W_Ss = self.image_encoder(X_S, t_idx=t_idx) 60 | Z_Ss = self.label_encoder(Y_S) 61 | 62 | # mix support output by matching 63 | a_idx = t_idx if self.separate_alpha else g_idx 64 | Z_Q_preds = self.matching_module(W_Qs, W_Ss, Z_Ss, a_idx=a_idx) 65 | 66 | # decode support output 67 | Y_Q_pred = self.label_decoder(Z_Q_preds) 68 | 69 | return Y_Q_pred 70 | 71 | @torch.no_grad() 72 | def encode_support(self, X_S, Y_S, t_idx=None, g_idx=None): 73 | self.t_idx = t_idx 74 | self.g_idx = g_idx 75 | 76 | # encode query input, support input and output 77 | W_Ss = self.image_encoder(X_S, t_idx=t_idx) 78 | Z_Ss = self.label_encoder(Y_S) 79 | self.has_encoded_support = True 80 | 81 | # append suppot data 82 | if self.W_Ss is None: 83 | self.W_Ss = W_Ss 84 | self.Z_Ss = Z_Ss 85 | else: 86 | self.W_Ss = {level: torch.cat([self.W_Ss[level], W_Ss[level]], dim=2) for level in range(len(W_Ss))} 87 | self.Z_Ss = {level: torch.cat([self.Z_Ss[level], Z_Ss[level]], dim=2) for level in range(len(Z_Ss))} 88 | 89 | def predict_query(self, X_Q, channel_idxs=None, get_attn_map=False): 90 | assert self.has_encoded_support 91 | if channel_idxs is not None: 92 | W_Ss = {level: self.W_Ss[level][:, channel_idxs] for level in range(len(self.W_Ss))} 93 | Z_Ss = {level: self.Z_Ss[level][:, channel_idxs] for level in range(len(self.Z_Ss))} 94 | t_idx = self.t_idx[:, channel_idxs] 95 | else: 96 | W_Ss = self.W_Ss 97 | Z_Ss = self.Z_Ss 98 | t_idx = self.t_idx 99 | 100 | W_Qs = self.image_encoder(X_Q, t_idx=t_idx) 101 | 102 | a_idx = t_idx if self.separate_alpha else self.g_idx 103 | 104 | if get_attn_map: 105 | Z_Q_preds, As = self.matching_module(W_Qs, W_Ss, Z_Ss, a_idx=a_idx, get_attn_map=True) 106 | return torch.stack(As, dim=-2) 107 | else: 108 | Z_Q_preds = self.matching_module(W_Qs, W_Ss, Z_Ss, a_idx=a_idx, get_attn_map=False) 109 | 110 | Y_Q_pred = self.label_decoder(Z_Q_preds) 111 | 112 | return Y_Q_pred 113 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from einops import rearrange, repeat 3 | 4 | from .transformers.factory import create_model 5 | from .transformers.custom_layers import Identity 6 | 7 | 8 | class ViTEncoder(nn.Module): 9 | ''' 10 | Vision Transformer encoder wrapper 11 | ''' 12 | def __init__(self, config, backbone, pretrained, in_chans, **kwargs): 
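        # Wraps a ViT backbone built by the local transformers factory: the final norm is replaced with
        # Identity, and four intermediate blocks (one per quarter of the depth) are tapped as multi-level features.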
13 | super().__init__() 14 | self.backbone = create_model( 15 | backbone, 16 | config=config, 17 | pretrained=pretrained, 18 | in_chans=in_chans, 19 | global_pool='', 20 | num_classes=0, 21 | **kwargs 22 | ) 23 | self.grid_size = self.backbone.patch_embed.grid_size 24 | self.backbone.norm = Identity() 25 | self.feature_idxs = [level * (len(self.backbone.blocks) // 4) - 1 26 | for level in range(1, 5)] 27 | 28 | def bias_parameters(self): 29 | for name, p in self.backbone.named_parameters(): 30 | if name.endswith('bias') and p.ndim == 2: 31 | yield p 32 | 33 | def bias_parameter_names(self): 34 | names = [] 35 | for name, p in self.backbone.named_parameters(): 36 | if name.endswith('bias') and p.ndim == 2: 37 | names.append(f'backbone.{name}') 38 | return names 39 | 40 | def relpos_parameters(self): 41 | for name, p in self.backbone.named_parameters(): 42 | if name.endswith('relative_position_bias_table'): 43 | yield p 44 | 45 | def relpos_parameter_names(self): 46 | names = [] 47 | for name, p in self.backbone.named_parameters(): 48 | if name.endswith('relative_position_bias_table'): 49 | names.append(f'backbone.{name}') 50 | return names 51 | 52 | def forward(self, x, t_idx=None): 53 | ''' 54 | [input] 55 | x: (B, T, N, C, H, W) 56 | t_idx: None or (B, T) 57 | [output] 58 | features: dict of (B, T, N, hw+1, d) 59 | ''' 60 | B, T, N = x.shape[:3] 61 | 62 | # flatten tensors 63 | x = rearrange(x, 'B T N C H W -> (B T N) C H W').contiguous() 64 | 65 | # repeat task index for shots 66 | if t_idx is not None: 67 | t_idx = repeat(t_idx, 'B T -> (B T N)', N=N) 68 | 69 | features = self.backbone.forward_features(x, feature_idxs=self.feature_idxs, b_idx=t_idx) 70 | 71 | features = [rearrange(feat, '(B T N) n d -> B T N n d', B=B, T=T, N=N) for feat in features] 72 | 73 | return features 74 | -------------------------------------------------------------------------------- /model/matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from einops import rearrange 6 | 7 | 8 | class CrossAttention(nn.Module): 9 | ''' 10 | Multi-Head Cross-Attention layer for Matching 11 | ''' 12 | def __init__(self, dim_q, dim_v, dim_o, num_heads=16, temperature=-1, dr=0.1): 13 | super().__init__() 14 | 15 | self.dim_q = dim_q 16 | 17 | # heads and temperature 18 | self.num_heads = num_heads 19 | self.dim_split_q = dim_q // num_heads 20 | self.dim_split_v = dim_o // num_heads 21 | if temperature > 0: 22 | self.temperature = temperature 23 | else: 24 | self.temperature = math.sqrt(dim_o) 25 | 26 | # projection layers 27 | self.fc_q = nn.Linear(dim_q, dim_q, bias=False) 28 | self.fc_k = nn.Linear(dim_q, dim_q, bias=False) 29 | self.fc_v = nn.Linear(dim_v, dim_o, bias=False) 30 | self.fc_o = nn.Linear(dim_o, dim_o, bias=False) 31 | 32 | # nonlinear activation and dropout 33 | self.activation = nn.GELU() 34 | self.attn_dropout = nn.Dropout(dr) 35 | 36 | # layernorm layers 37 | self.pre_ln_q = self.pre_ln_k = nn.LayerNorm(dim_q) 38 | self.ln = nn.LayerNorm(dim_o) 39 | 40 | def forward(self, Q, K, V, mask=None, get_attn_map=False): 41 | # pre-layer normalization 42 | Q = self.pre_ln_q(Q) 43 | K = self.pre_ln_k(K) 44 | 45 | # lienar projection 46 | Q = self.fc_q(Q) 47 | K = self.fc_k(K) 48 | V = self.fc_v(V) 49 | 50 | # split into multiple heads 51 | Q_ = torch.cat(Q.split(self.dim_split_q, 2), 0) 52 | K_ = torch.cat(K.split(self.dim_split_q, 2), 0) 53 | V_ = 
torch.cat(V.split(self.dim_split_v, 2), 0) 54 | 55 | # scaled dot-product attention with mask and dropout 56 | L = Q_.bmm(K_.transpose(1, 2)) / self.temperature 57 | L = L.clip(-1e4, 1e4) 58 | 59 | # mask 60 | if mask is not None: 61 | L = L.masked_fill(~mask, -float('inf')) 62 | 63 | A = L.softmax(dim=2) 64 | if mask is not None: 65 | A.masked_fill(~mask, 0) 66 | A = self.attn_dropout(A) 67 | 68 | # apply attention to values 69 | O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2) 70 | 71 | # layer normalization 72 | O = self.ln(O) 73 | 74 | # residual connection with non-linearity 75 | O = O + self.activation(self.fc_o(O)) 76 | 77 | if get_attn_map: 78 | return O, A 79 | else: 80 | return O 81 | 82 | 83 | 84 | class MatchingModule(nn.Module): 85 | def __init__(self, dim_w, dim_z, n_heads=16, alpha_init=0, alpha_temp=0.05, n_alphas=1): 86 | super().__init__() 87 | self.matching = nn.ModuleList([CrossAttention(dim_w, dim_z, dim_z, num_heads=n_heads) 88 | for _ in range(4)]) 89 | 90 | self.alpha = nn.ParameterList([nn.Parameter(alpha_init*F.one_hot(torch.tensor([level]*n_alphas), 4).float()) for level in range(4)]) 91 | self.alpha_temp = alpha_temp 92 | self.layernorm = nn.LayerNorm(dim_w) 93 | 94 | def forward(self, W_Qs, W_Ss, Z_Ss, a_idx, get_attn_map=False): 95 | B, T, N = W_Qs[0].shape[:3] 96 | 97 | W_Qs = torch.stack([self.layernorm(W_Qs[level]) for level in range(4)]) 98 | W_Ss = torch.stack([self.layernorm(W_Ss[level]) for level in range(4)]) 99 | W_Qs_mix = [] 100 | W_Ss_mix = [] 101 | for level in range(4): 102 | alpha = (self.alpha[level][a_idx] / self.alpha_temp).softmax(dim=-1) 103 | if a_idx.ndim == 2: 104 | alpha = rearrange(alpha, 'B T L-> L B T 1 1 1') 105 | else: 106 | alpha = rearrange(alpha, 'B L-> L B 1 1 1 1') 107 | W_Qs_mix.append((alpha * W_Qs).sum(dim=0)) 108 | W_Ss_mix.append((alpha * W_Ss).sum(dim=0)) 109 | 110 | Z_Qs = [] 111 | if get_attn_map: 112 | As = [] 113 | for level in range(4): 114 | # drop the cls token 115 | Q = rearrange(W_Qs_mix[level], 'B T N n d -> (B T) (N n) d') 116 | K = rearrange(W_Ss_mix[level], 'B T N n d -> (B T) (N n) d') 117 | V = rearrange(Z_Ss[level], 'B T N n d -> (B T) (N n) d') 118 | 119 | O = self.matching[level](Q, K, V, get_attn_map=get_attn_map) 120 | if get_attn_map: 121 | O, A = O 122 | A = rearrange(A, '(nh B T) ... -> B T ... 
nh', B=B, T=T) 123 | As.append(A) 124 | 125 | Z_Q = rearrange(O, '(B T) (N n) d -> B T N n d', B=B, T=T, N=N) 126 | Z_Qs.append(Z_Q) 127 | 128 | if get_attn_map: 129 | return Z_Qs, As 130 | else: 131 | return Z_Qs 132 | -------------------------------------------------------------------------------- /model/model_factory.py: -------------------------------------------------------------------------------- 1 | from .chameleon import Chameleon 2 | from meta_train.unified import Unified 3 | from downstream.davis2017.dataset import DAVIS2017 4 | 5 | 6 | def get_model(config, verbose=False): 7 | # set number of tasks for bitfit 8 | if config.stage == 0: 9 | n_tasks = len(Unified.TASKS) 10 | n_task_groups = len(Unified.TASK_GROUP_NAMES) 11 | else: 12 | if config.dataset == 'ap10k': 13 | n_tasks = 17 # number of joints 14 | elif config.dataset == 'davis2017': 15 | n_tasks = DAVIS2017.NUM_INSTANCES[DAVIS2017.CLASS_NAMES.index(config.class_name)] # number of instances 16 | elif config.dataset == 'linemod': 17 | if config.task == 'segment_semantic': 18 | n_tasks = 1 # segmentation 19 | else: 20 | n_tasks = 4 # segmentation, u, v, w 21 | elif config.dataset == 'fsc147': 22 | n_tasks = 1 # density map 23 | elif config.dataset == 'cellpose': 24 | n_tasks = 3 # u, v, segmentation 25 | else: 26 | n_tasks = 1 27 | n_task_groups = 1 28 | 29 | model = Chameleon(config, n_tasks, n_task_groups) 30 | 31 | if verbose: 32 | print(f'Registered Chameleon with {n_tasks} task-specific and {n_task_groups} group-specific parameters.') 33 | 34 | return model 35 | -------------------------------------------------------------------------------- /model/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | from .vision_transformer import * 2 | from .beit import * 3 | 4 | from .factory import create_model, parse_model_name, safe_model_name 5 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters 6 | from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ 7 | is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value -------------------------------------------------------------------------------- /model/transformers/custom_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers.helpers import to_2tuple 5 | from einops import repeat, rearrange 6 | from timm.models.layers.trace_utils import _assert 7 | 8 | 9 | class Identity(nn.Module): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | 13 | def forward(self, *args, **kwargs): 14 | return args[0] 15 | 16 | 17 | class Linear(nn.Linear): 18 | """ 19 | Bias-Switching Linear layer 20 | """ 21 | def __init__(self, n_bias_sets=0, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | if self.bias is None: 24 | n_bias_sets = 0 25 | 26 | self.n_bias_sets = n_bias_sets 27 | if self.n_bias_sets > 0: 28 | assert self.bias is not None 29 | self.bias = nn.Parameter(repeat(self.bias.data, '... 
-> T ...', T=n_bias_sets).contiguous()) 30 | 31 | def forward(self, input, b_idx=None): 32 | if self.n_bias_sets > 0: 33 | assert b_idx is not None 34 | output = F.linear(input, self.weight, None) 35 | bias = self.bias[b_idx][:, None] 36 | return output + bias 37 | else: 38 | return F.linear(input, self.weight, self.bias) 39 | 40 | 41 | class LayerNorm(nn.LayerNorm): 42 | """ 43 | Bias-Switching LayerNorm 44 | """ 45 | def __init__(self, n_bias_sets=0, *args, **kwargs): 46 | super().__init__(*args, **kwargs) 47 | if self.bias is None: 48 | n_bias_sets = 0 49 | 50 | self.n_bias_sets = n_bias_sets 51 | if self.n_bias_sets > 0: 52 | assert self.elementwise_affine 53 | self.bias = nn.Parameter(repeat(self.bias.data, '... -> T ...', T=n_bias_sets).contiguous()) 54 | 55 | def forward(self, input, b_idx=None): 56 | if self.n_bias_sets > 0: 57 | assert b_idx is not None 58 | output = F.layer_norm(input, self.normalized_shape, self.weight, None, self.eps) 59 | if b_idx.ndim == 1: 60 | bias = self.bias[b_idx] 61 | for _ in range(output.ndim - 2): 62 | bias = bias[:, None] 63 | else: 64 | assert False 65 | bias_mh = torch.stack(self.bias.split(self.bias.shape[1] // b_idx.shape[1], dim=1), 0) 66 | bias = torch.einsum('bhn,hnd->bhd', b_idx, bias_mh) 67 | bias = rearrange(bias, 'B h d -> B 1 (h d)') 68 | return output + bias 69 | else: 70 | return F.layer_norm( 71 | input, self.normalized_shape, self.weight, self.bias, self.eps) 72 | 73 | 74 | class Conv2d(nn.Conv2d): 75 | """ 76 | Bias-Switching Conv2d layer 77 | """ 78 | def __init__(self, n_bias_sets=0, *args, **kwargs): 79 | super().__init__(*args, **kwargs) 80 | if self.bias is None: 81 | n_bias_sets = 0 82 | 83 | self.n_bias_sets = n_bias_sets 84 | if self.n_bias_sets > 0: 85 | assert self.bias is not None 86 | self.bias = nn.Parameter(repeat(self.bias.data, '... 
-> T ...', T=n_bias_sets).contiguous()) 87 | 88 | def forward(self, input, b_idx=None): 89 | if self.n_bias_sets > 0: 90 | assert b_idx is not None 91 | output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups) 92 | if b_idx.ndim == 1: 93 | bias = self.bias[b_idx][:, :, None, None] 94 | else: 95 | raise NotImplementedError 96 | 97 | return output + bias 98 | else: 99 | return F.conv2d(input, self.weight, self.bias, 100 | self.stride, self.padding, self.dilation, self.groups) 101 | 102 | 103 | class Mlp(nn.Module): 104 | """ 105 | Bias-Switching MLP as used in Vision Transformer, MLP-Mixer and related networks 106 | """ 107 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, 108 | drop=0., n_bias_sets=0): 109 | super().__init__() 110 | out_features = out_features or in_features 111 | hidden_features = hidden_features or in_features 112 | bias = to_2tuple(bias) 113 | drop_probs = to_2tuple(drop) 114 | 115 | self.fc1 = Linear(n_bias_sets, in_features, hidden_features, bias=bias[0]) 116 | self.act = act_layer() 117 | self.drop1 = nn.Dropout(drop_probs[0]) 118 | self.fc2 = Linear(n_bias_sets, hidden_features, out_features, bias=bias[1]) 119 | self.drop2 = nn.Dropout(drop_probs[1]) 120 | 121 | def forward(self, x, b_idx=None): 122 | x = self.fc1(x, b_idx=b_idx) 123 | x = self.act(x) 124 | x = self.drop1(x) 125 | x = self.fc2(x, b_idx=b_idx) 126 | x = self.drop2(x) 127 | return x 128 | 129 | 130 | class Sequential(nn.Sequential): 131 | def forward(self, *inputs): 132 | for module in self: 133 | input = module(*inputs) 134 | return input 135 | 136 | 137 | class PatchEmbed(nn.Module): 138 | """ 2D Image to Patch Embedding 139 | """ 140 | def __init__( 141 | self, 142 | img_size=224, 143 | patch_size=16, 144 | in_chans=3, 145 | embed_dim=768, 146 | norm_layer=None, 147 | flatten=True, 148 | bias=True, 149 | ): 150 | super().__init__() 151 | img_size = to_2tuple(img_size) 152 | patch_size = to_2tuple(patch_size) 153 | self.img_size = img_size 154 | self.patch_size = patch_size 155 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 156 | self.num_patches = self.grid_size[0] * self.grid_size[1] 157 | self.flatten = flatten 158 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 159 | self.proj_switching = False 160 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 161 | 162 | def forward(self, x): 163 | B, C, H, W = x.shape 164 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 165 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 166 | if C != self.proj.in_channels: 167 | assert C % self.proj.in_channels == 0 168 | x = rearrange(x, 'B (N C) H W -> (B N) C H W', C=self.proj.in_channels) 169 | x = self.proj(x) 170 | if C != self.proj.in_channels: 171 | x = rearrange(x, '(B N) C H W -> B C (N H) W', N=(C // self.proj.in_channels)) 172 | if self.flatten: 173 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 174 | x = self.norm(x) 175 | return x 176 | -------------------------------------------------------------------------------- /model/transformers/factory.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlsplit, urlunsplit 2 | import os 3 | 4 | from .registry import is_model, is_model_in_modules, model_entrypoint 5 | from .helpers import 
load_checkpoint 6 | from timm.models.layers import set_layer_config 7 | from timm.models.hub import load_model_config_from_hf 8 | 9 | 10 | def parse_model_name(model_name): 11 | model_name = model_name.replace('hf_hub', 'hf-hub') # NOTE for backwards compat, to deprecate hf_hub use 12 | parsed = urlsplit(model_name) 13 | assert parsed.scheme in ('', 'timm', 'hf-hub') 14 | if parsed.scheme == 'hf-hub': 15 | # FIXME may use fragment as revision, currently `@` in URI path 16 | return parsed.scheme, parsed.path 17 | else: 18 | model_name = os.path.split(parsed.path)[-1] 19 | return 'timm', model_name 20 | 21 | 22 | def safe_model_name(model_name, remove_source=True): 23 | def make_safe(name): 24 | return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') 25 | if remove_source: 26 | model_name = parse_model_name(model_name)[-1] 27 | return make_safe(model_name) 28 | 29 | 30 | def create_model( 31 | model_name, 32 | pretrained=False, 33 | pretrained_cfg=None, 34 | checkpoint_path='', 35 | scriptable=None, 36 | exportable=None, 37 | no_jit=None, 38 | **kwargs): 39 | """Create a model 40 | Args: 41 | model_name (str): name of model to instantiate 42 | pretrained (bool): load pretrained ImageNet-1k weights if true 43 | checkpoint_path (str): path of checkpoint to load after model is initialized 44 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 45 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 46 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) 47 | Keyword Args: 48 | drop_rate (float): dropout rate for training (default: 0.0) 49 | global_pool (str): global pool type (default: 'avg') 50 | **: other kwargs are model specific 51 | """ 52 | # Parameters that aren't supported by all models or are intended to only override model defaults if set 53 | # should default to None in command line args/cfg. Remove them if they are present and not set so that 54 | # non-supporting models don't break and default args remain in effect. 55 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 56 | 57 | model_source, model_name = parse_model_name(model_name) 58 | if model_source == 'hf-hub': 59 | # FIXME hf-hub source overrides any passed in pretrained_cfg, warn? 60 | # For model names specified in the form `hf-hub:path/architecture_name@revision`, 61 | # load model weights + pretrained_cfg from Hugging Face hub. 
62 | pretrained_cfg, model_name = load_model_config_from_hf(model_name) 63 | 64 | if not is_model(model_name): 65 | raise RuntimeError('Unknown model (%s)' % model_name) 66 | 67 | create_fn = model_entrypoint(model_name) 68 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 69 | model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs) 70 | 71 | if checkpoint_path: 72 | load_checkpoint(model, checkpoint_path) 73 | 74 | return model -------------------------------------------------------------------------------- /model/transformers/registry.py: -------------------------------------------------------------------------------- 1 | """ Model Registry 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | 5 | import sys 6 | import re 7 | import fnmatch 8 | from collections import defaultdict 9 | from copy import deepcopy 10 | 11 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 12 | 'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained'] 13 | 14 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module 15 | _model_to_module = {} # mapping of model names to module names 16 | _model_entrypoints = {} # mapping of model names to entrypoint fns 17 | _model_has_pretrained = set() # set of model names that have pretrained weight url present 18 | _model_pretrained_cfgs = dict() # central repo for model default_cfgs 19 | 20 | 21 | def register_model(fn): 22 | # lookup containing module 23 | mod = sys.modules[fn.__module__] 24 | module_name_split = fn.__module__.split('.') 25 | module_name = module_name_split[-1] if len(module_name_split) else '' 26 | 27 | # add model to __all__ in module 28 | model_name = fn.__name__ 29 | if hasattr(mod, '__all__'): 30 | mod.__all__.append(model_name) 31 | else: 32 | mod.__all__ = [model_name] 33 | 34 | # add entries to registry dict/sets 35 | _model_entrypoints[model_name] = fn 36 | _model_to_module[model_name] = module_name 37 | _module_to_models[module_name].add(model_name) 38 | has_valid_pretrained = False # check if model has a pretrained url to allow filtering on this 39 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: 40 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing 41 | # entrypoints or non-matching combos 42 | cfg = mod.default_cfgs[model_name] 43 | has_valid_pretrained = ( 44 | ('url' in cfg and 'http' in cfg['url']) or 45 | ('file' in cfg and cfg['file']) or 46 | ('hf_hub_id' in cfg and cfg['hf_hub_id']) 47 | ) 48 | _model_pretrained_cfgs[model_name] = mod.default_cfgs[model_name] 49 | if has_valid_pretrained: 50 | _model_has_pretrained.add(model_name) 51 | return fn 52 | 53 | 54 | def _natural_key(string_): 55 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 56 | 57 | 58 | def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): 59 | """ Return list of available model names, sorted alphabetically 60 | Args: 61 | filter (str) - Wildcard filter string that works with fnmatch 62 | module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') 63 | pretrained (bool) - Include only models with pretrained weights if True 64 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter 65 | name_matches_cfg (bool) - Include only models w/ model_name matching 
default_cfg name (excludes some aliases) 66 | Example: 67 | model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' 68 | model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module 69 | """ 70 | if module: 71 | all_models = list(_module_to_models[module]) 72 | else: 73 | all_models = _model_entrypoints.keys() 74 | if filter: 75 | models = [] 76 | include_filters = filter if isinstance(filter, (tuple, list)) else [filter] 77 | for f in include_filters: 78 | include_models = fnmatch.filter(all_models, f) # include these models 79 | if len(include_models): 80 | models = set(models).union(include_models) 81 | else: 82 | models = all_models 83 | if exclude_filters: 84 | if not isinstance(exclude_filters, (tuple, list)): 85 | exclude_filters = [exclude_filters] 86 | for xf in exclude_filters: 87 | exclude_models = fnmatch.filter(models, xf) # exclude these models 88 | if len(exclude_models): 89 | models = set(models).difference(exclude_models) 90 | if pretrained: 91 | models = _model_has_pretrained.intersection(models) 92 | if name_matches_cfg: 93 | models = set(_model_pretrained_cfgs).intersection(models) 94 | return list(sorted(models, key=_natural_key)) 95 | 96 | 97 | def is_model(model_name): 98 | """ Check if a model name exists 99 | """ 100 | return model_name in _model_entrypoints 101 | 102 | 103 | def model_entrypoint(model_name): 104 | """Fetch a model entrypoint for specified model name 105 | """ 106 | return _model_entrypoints[model_name] 107 | 108 | 109 | def list_modules(): 110 | """ Return list of module names that contain models / model entrypoints 111 | """ 112 | modules = _module_to_models.keys() 113 | return list(sorted(modules)) 114 | 115 | 116 | def is_model_in_modules(model_name, module_names): 117 | """Check if a model exists within a subset of modules 118 | Args: 119 | model_name (str) - name of model to check 120 | module_names (tuple, list, set) - names of modules to search in 121 | """ 122 | assert isinstance(module_names, (tuple, list, set)) 123 | return any(model_name in _module_to_models[n] for n in module_names) 124 | 125 | 126 | def is_model_pretrained(model_name): 127 | return model_name in _model_has_pretrained 128 | 129 | 130 | def get_pretrained_cfg(model_name): 131 | if model_name in _model_pretrained_cfgs: 132 | return deepcopy(_model_pretrained_cfgs[model_name]) 133 | return {} 134 | 135 | 136 | def has_pretrained_cfg_key(model_name, cfg_key): 137 | """ Query model default_cfgs for existence of a specific key. 138 | """ 139 | if model_name in _model_pretrained_cfgs and cfg_key in _model_pretrained_cfgs[model_name]: 140 | return True 141 | return False 142 | 143 | 144 | def is_pretrained_cfg_key(model_name, cfg_key): 145 | """ Return truthy value for specified model default_cfg key, False if does not exist. 146 | """ 147 | if model_name in _model_pretrained_cfgs and _model_pretrained_cfgs[model_name].get(cfg_key, False): 148 | return True 149 | return False 150 | 151 | 152 | def get_pretrained_cfg_value(model_name, cfg_key): 153 | """ Get a specific model default_cfg value by key. None if it doesn't exist. 
154 | """ 155 | if model_name in _model_pretrained_cfgs: 156 | return _model_pretrained_cfgs[model_name].get(cfg_key, None) 157 | return None -------------------------------------------------------------------------------- /preprocess_checkpoints.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from train.zero_to_fp32 import _get_fp32_state_dict_from_zero_checkpoint 5 | 6 | 7 | def reduce_checkpoint(ckpt_path, reduced_path, verbose=True): 8 | # load state dict and config 9 | if os.path.isdir(ckpt_path): 10 | state_dict = _get_fp32_state_dict_from_zero_checkpoint(os.path.join(ckpt_path, 'checkpoint')) 11 | state_dict = {k.replace('_forward_module.', ''): v for k, v in state_dict.items()} 12 | for k in list(state_dict.keys()): 13 | if len(k.split('.')) > 4 and k.split('.')[1] == 'matching_module' and k.split('.')[4] == 'pre_ln_q': 14 | state_dict[k.replace('pre_ln_q', 'pre_ln_k')] = state_dict[k] 15 | ckpt = torch.load(os.path.join(ckpt_path, 'checkpoint', 'mp_rank_00_model_states.pt'), map_location='cpu') 16 | 17 | # add ema parameters 18 | for key in ckpt['module']: 19 | if 'ema' in key: 20 | state_dict[key.replace('_forward_module.', '')] = ckpt['module'][key] 21 | else: 22 | ckpt = torch.load(ckpt_path, map_location='cpu') 23 | state_dict = ckpt['state_dict'] 24 | 25 | # reduce memory 26 | ckpt_reduced = {} 27 | ckpt_reduced['state_dict'] = state_dict 28 | ckpt_reduced['config'] = ckpt['hyper_parameters']['config'] 29 | ckpt_reduced['global_step'] = ckpt['global_step'] 30 | torch.save(ckpt_reduced, reduced_path) 31 | if verbose: 32 | print(f'checkpoint converted to memory-reduced checkpoint: {ckpt_path}') 33 | 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--root_dir', type=str, default=None) 37 | parser.add_argument('--exp_name', type=str, default=None) 38 | parser.add_argument('--ckpt_name', '-cname', type=str, default=None) 39 | parser.add_argument('--load_dir', '-ld', type=str, default='TRAIN') 40 | parser.add_argument('--verbose', '-v', default=False, action='store_true') 41 | parser.add_argument('--reset_mode', '-reset', default=False, action='store_true') 42 | args = parser.parse_args() 43 | 44 | 45 | if args.root_dir is None: 46 | root_dir = 'experiments' 47 | else: 48 | root_dir = args.root_dir 49 | 50 | if args.exp_name is None: 51 | exp_names = sorted(os.listdir(os.path.join(root_dir, args.load_dir))) 52 | else: 53 | exp_names = [args.exp_name] 54 | 55 | for exp_name in exp_names: 56 | if args.ckpt_name is None: 57 | ckpt_names = sorted(os.listdir(os.path.join(root_dir, args.load_dir, exp_name, 'checkpoints'))) 58 | ckpt_names = [name for name in ckpt_names if 'best' not in name and 'last' not in name] 59 | else: 60 | ckpt_names = [args.ckpt_name] 61 | 62 | for ckpt_name in ckpt_names: 63 | ckpt_path = os.path.join(root_dir, args.load_dir, exp_name, 'checkpoints', f'{ckpt_name}') 64 | reduced_path = ckpt_path.replace('.ckpt', '.pth') 65 | if not os.path.exists(ckpt_path): 66 | if args.verbose: 67 | print(f'checkpoint not found: {ckpt_path}') 68 | continue 69 | if os.path.exists(reduced_path) and not args.reset_mode: 70 | if args.verbose: 71 | print(f'checkpoint already exists: {reduced_path}') 72 | continue 73 | 74 | reduce_checkpoint(ckpt_path, reduced_path, args.verbose) 75 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 
torch==2.3.1 2 | pytorch_lightning==2.4.0 3 | lightning_habana==1.6.0 4 | deepspeed==0.15.4 5 | timm==0.6.12 6 | numpy 7 | scikit-image 8 | easydict 9 | tensorboard 10 | tqdm 11 | einops 12 | pycocotools 13 | xtcocotools 14 | fastremap 15 | numba 16 | opencv-python 17 | flow_vis 18 | plyfile 19 | -------------------------------------------------------------------------------- /scripts/ap10k/finetune.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | 4 | cmd="python main.py --stage 1 --dataset ap10k --exp_name ${exp_name} --class_name ${class_name} ${@:3:$#}" 5 | 6 | echo $cmd 7 | eval $cmd 8 | -------------------------------------------------------------------------------- /scripts/ap10k/finetune_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "Antelope" "Cat" "Elephant" "Giraffe" "Hippo" "Horse" "Mouse" "Pig" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/ap10k/finetune.sh ${exp_name} ${class_name} ${@:2:$#} 6 | done 7 | -------------------------------------------------------------------------------- /scripts/ap10k/run.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | 4 | bash scripts/ap10k/finetune.sh $exp_name $class_name ${@:3:$#} 5 | bash scripts/ap10k/test.sh $exp_name $class_name ${@:3:$#} 6 | -------------------------------------------------------------------------------- /scripts/ap10k/run_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "Antelope" "Cat" "Elephant" "Giraffe" "Hippo" "Horse" "Mouse" "Pig" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/ap10k/finetune.sh ${exp_name} ${class_name} ${@:2:$#} 6 | bash scripts/ap10k/test.sh ${exp_name} ${class_name} ${@:2:$#} 7 | done 8 | -------------------------------------------------------------------------------- /scripts/ap10k/test.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | exp_subname="task:animalkp_class:${class_name}_shot:20_is:224_lr:0.001_sid:0" 4 | 5 | cmd="python main.py --stage 2 --dataset ap10k --exp_name ${exp_name} --class_name ${class_name} --exp_subname ${exp_subname} ${@:3:$#}" 6 | 7 | echo $cmd 8 | eval $cmd 9 | -------------------------------------------------------------------------------- /scripts/ap10k/test_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "Antelope" "Cat" "Elephant" "Giraffe" "Hippo" "Horse" "Mouse" "Pig" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/ap10k/test.sh ${exp_name} ${class_name} ${@:2:$#} 6 | done 7 | -------------------------------------------------------------------------------- /scripts/cellpose/finetune.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | 3 | cmd="python main.py --stage 1 --dataset cellpose --exp_name ${exp_name} ${@:2:$#}" 4 | 5 | echo $cmd 6 | eval $cmd 7 | -------------------------------------------------------------------------------- /scripts/cellpose/run.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | 3 | bash scripts/cellpose/finetune.sh $exp_name ${@:2:$#} 4 | bash scripts/cellpose/test.sh $exp_name
${@:2:$#} 5 | -------------------------------------------------------------------------------- /scripts/cellpose/test.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | exp_subname="task:cellpose_shot:50_is:224_lr:0.003_sid:0" 3 | 4 | cmd="python main.py --stage 2 --dataset cellpose --exp_name ${exp_name} --exp_subname ${exp_subname} ${@:2:$#}" 5 | 6 | echo $cmd 7 | eval $cmd 8 | -------------------------------------------------------------------------------- /scripts/davis2017/davis_evaluation.sh: -------------------------------------------------------------------------------- 1 | cmd="python davis2016-evaluation/test.py --dataset davis2017" 2 | echo $cmd 3 | eval $cmd 4 | -------------------------------------------------------------------------------- /scripts/davis2017/finetune.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | lr=$3 4 | 5 | cmd="python main.py --stage 1 --dataset davis2017 --exp_name ${exp_name} --class_name ${class_name} --lr ${lr} ${@:4:$#}" 6 | 7 | echo $cmd 8 | eval $cmd 9 | -------------------------------------------------------------------------------- /scripts/davis2017/finetune_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "bike-packing" "blackswan" "bmx-trees" "breakdance" "camel" "car-roundabout" "car-shadow" "cows" "dance-twirl" "dog" "dogs-jump" "drift-chicane" "drift-straight" "goat" "gold-fish" "horsejump-high" "india" "judo" "kite-surf" "lab-coat" "libby" "loading" "mbike-trick" "motocross-jump" "paragliding-launch" "parkour" "pigs" "scooter-black" "shooting" "soapbox" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | # bash scripts/davis2017/finetune.sh ${exp_name} ${class_name} ${@:2:$#} 6 | echo scripts/davis2017/finetune.sh ${exp_name} ${class_name} ${@:2:$#} 7 | done 8 | -------------------------------------------------------------------------------- /scripts/davis2017/run.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | 4 | bash scripts/davis2017/finetune.sh $exp_name $class_name ${@:3:$#} 5 | bash scripts/davis2017/test.sh $exp_name $class_name ${@:3:$#} 6 | -------------------------------------------------------------------------------- /scripts/davis2017/run_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "bike-packing" "blackswan" "bmx-trees" "breakdance" "camel" "car-roundabout" "car-shadow" "cows" "dance-twirl" "dog" "dogs-jump" "drift-chicane" "drift-straight" "goat" "gold-fish" "horsejump-high" "india" "judo" "kite-surf" "lab-coat" "libby" "loading" "mbike-trick" "motocross-jump" "paragliding-launch" "parkour" "pigs" "scooter-black" "shooting" "soapbox" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/davis2017/finetune.sh ${exp_name} ${class_name} ${@:2:$#} 6 | bash scripts/davis2017/test.sh ${exp_name} ${class_name} ${@:2:$#} 7 | done 8 | -------------------------------------------------------------------------------- /scripts/davis2017/run_all_in_one.sh: -------------------------------------------------------------------------------- 1 | bash scripts/davis2017/run_all.sh 2 | bash scripts/davis2017/davis_evaluation.sh 3 | -------------------------------------------------------------------------------- /scripts/davis2017/test.sh: 
-------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | lr=$3 4 | snptf=$4 5 | exp_subname="task:vos_class:${class_name}_shot:1_is:384_lr:${lr}_sid:0${snptf}" 6 | 7 | cmd="python main.py --stage 2 --dataset davis2017 --exp_name ${exp_name} --class_name ${class_name} --exp_subname ${exp_subname} --result_dir results_davis2017_lr:${lr}${snptf} ${@:5:$#}" 8 | 9 | echo $cmd 10 | eval $cmd 11 | -------------------------------------------------------------------------------- /scripts/davis2017/test_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "bike-packing" "blackswan" "bmx-trees" "breakdance" "camel" "car-roundabout" "car-shadow" "cows" "dance-twirl" "dog" "dogs-jump" "drift-chicane" "drift-straight" "goat" "gold-fish" "horsejump-high" "india" "judo" "kite-surf" "lab-coat" "libby" "loading" "mbike-trick" "motocross-jump" "paragliding-launch" "parkour" "pigs" "scooter-black" "shooting" "soapbox" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/davis2017/test.sh ${exp_name} ${class_name} ${@:2:$#} 6 | done 7 | bash scripts/davis2017/davis_evaluation.sh 8 | -------------------------------------------------------------------------------- /scripts/fsc147/finetune.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | 3 | cmd="python main.py --stage 1 --dataset fsc147 --exp_name ${exp_name} ${@:2:$#}" 4 | 5 | echo $cmd 6 | eval $cmd 7 | -------------------------------------------------------------------------------- /scripts/fsc147/run.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | 3 | bash scripts/fsc147/finetune.sh $exp_name ${@:2:$#} 4 | bash scripts/fsc147/test.sh $exp_name ${@:2:$#} 5 | -------------------------------------------------------------------------------- /scripts/fsc147/test.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | exp_subname="task:object_counting_shot:50_is:512_lr:0.001_sid:0" 3 | 4 | cmd="python main.py --stage 2 --dataset fsc147 --exp_name ${exp_name} --exp_subname ${exp_subname} ${@:2:$#}" 5 | 6 | echo $cmd 7 | eval $cmd 8 | -------------------------------------------------------------------------------- /scripts/isic2018/finetune.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | support_idx=$2 3 | 4 | cmd="python main.py --stage 1 --dataset isic2018 --exp_name ${exp_name} --support_idx ${support_idx} ${@:3:$#}" 5 | 6 | echo $cmd 7 | eval $cmd 8 | -------------------------------------------------------------------------------- /scripts/isic2018/finetune_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | support_idxs=( "0" "1" "2" "3" "4" ) 3 | for support_idx in ${support_idxs[@]}; 4 | do 5 | bash scripts/isic2018/finetune.sh ${exp_name} ${support_idx} ${@:2:$#} 6 | done 7 | -------------------------------------------------------------------------------- /scripts/isic2018/run.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | support_idx=$2 3 | 4 | bash scripts/isic2018/finetune.sh $exp_name $support_idx ${@:3:$#} 5 | bash scripts/isic2018/test.sh $exp_name $support_idx ${@:3:$#} 6 | -------------------------------------------------------------------------------- /scripts/isic2018/test.sh:
-------------------------------------------------------------------------------- 1 | exp_name=$1 2 | support_idx=$2 3 | exp_subname="task:segment_medical_shot:20_is:384_lr:0.001_sid:${support_idx}" 4 | 5 | cmd="python main.py --stage 2 --dataset isic2018 --exp_name ${exp_name} --support_idx ${support_idx} --exp_subname ${exp_subname} ${@:3:$#}" 6 | 7 | echo $cmd 8 | eval $cmd 9 | -------------------------------------------------------------------------------- /scripts/linemod/finetune_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "ape" "benchviseblue" "cam" "can" "cat" "driller" "duck" "eggbox" "glue" "holepuncher" "iron" "lamp" "phone" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/linemod/finetune.sh ${exp_name} ${class_name} ${@:2:$#} 6 | done 7 | -------------------------------------------------------------------------------- /scripts/linemod/finetune_pose.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | 4 | cmd="python main.py --stage 1 --dataset linemod --exp_name ${exp_name} --class_name ${class_name} ${@:3:$#}" 5 | 6 | echo $cmd 7 | eval $cmd 8 | -------------------------------------------------------------------------------- /scripts/linemod/finetune_segment.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | 4 | cmd="python main.py --stage 1 --dataset linemod --exp_name ${exp_name} --class_name ${class_name} --task segment_semantic --loss_type bce --monitor IoU_inverted ${@:3:$#}" 5 | 6 | echo $cmd 7 | eval $cmd 8 | -------------------------------------------------------------------------------- /scripts/linemod/finetune_segment_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "ape" "benchviseblue" "cam" "can" "cat" "driller" "duck" "eggbox" "glue" "holepuncher" "iron" "lamp" "phone" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/linemod/finetune_segment.sh ${exp_name} ${class_name} ${@:2:$#} 6 | echo $class_name 7 | done 8 | -------------------------------------------------------------------------------- /scripts/linemod/run.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | 4 | bash scripts/linemod/run_segment.sh $exp_name $class_name ${@:3:$#} 5 | bash scripts/linemod/run_pose.sh $exp_name $class_name ${@:3:$#} 6 | -------------------------------------------------------------------------------- /scripts/linemod/run_all_in_one.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "ape" "benchviseblue" "cam" "can" "cat" "driller" "duck" "eggbox" "glue" "holepuncher" "iron" "lamp" "phone" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/linemod/run.sh ${exp_name} ${class_name} ${@:2:$#} 6 | done 7 | -------------------------------------------------------------------------------- /scripts/linemod/run_pose.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | 4 | bash scripts/linemod/finetune_pose.sh $exp_name $class_name ${@:3:$#} 5 | bash scripts/linemod/test_pose.sh $exp_name $class_name ${@:3:$#} 6 | -------------------------------------------------------------------------------- /scripts/linemod/run_pose_all.sh: 
-------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "ape" "benchviseblue" "cam" "can" "cat" "driller" "duck" "eggbox" "glue" "holepuncher" "iron" "lamp" "phone" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/linemod/finetune.sh ${exp_name} ${class_name} ${@:2:$#} 6 | bash scripts/linemod/test.sh ${exp_name} ${class_name} ${@:2:$#} 7 | done 8 | -------------------------------------------------------------------------------- /scripts/linemod/run_segment.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | 4 | bash scripts/linemod/finetune_segment.sh $exp_name $class_name ${@:3:$#} 5 | bash scripts/linemod/test_segment.sh $exp_name $class_name ${@:3:$#} 6 | -------------------------------------------------------------------------------- /scripts/linemod/run_segment_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "ape" "benchviseblue" "cam" "can" "cat" "driller" "duck" "eggbox" "glue" "holepuncher" "iron" "lamp" "phone" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/linemod/finetune_segment.sh ${exp_name} ${class_name} ${@:2:$#} 6 | bash scripts/linemod/test_segment.sh ${exp_name} ${class_name} ${@:2:$#} 7 | done 8 | -------------------------------------------------------------------------------- /scripts/linemod/test_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "ape" "benchviseblue" "cam" "can" "cat" "driller" "duck" "eggbox" "glue" "holepuncher" "iron" "lamp" "phone" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/linemod/test.sh ${exp_name} ${class_name} ${@:2:$#} 6 | done 7 | -------------------------------------------------------------------------------- /scripts/linemod/test_pose.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | exp_subname="task:pose_6d_class:${class_name}_shot:50_is:224_lr:0.005_sid:0" 4 | exp_subname_coord="task:segment_semantic_class:${class_name}_shot:50_is:224_lr:0.005_sid:0" 5 | coord_path="${exp_name}/${exp_subname_coord}/logs/bbox_${class_name}.npy" 6 | 7 | cmd="python main.py --stage 2 --dataset linemod --exp_name ${exp_name} --class_name ${class_name} --exp_subname ${exp_subname} --coord_path ${coord_path} ${@:3:$#}" 8 | 9 | echo $cmd 10 | eval $cmd 11 | -------------------------------------------------------------------------------- /scripts/linemod/test_segment.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_name=$2 3 | exp_subname="task:segment_semantic_class:${class_name}_shot:50_is:224_lr:0.005_sid:0" 4 | 5 | cmd="python main.py --stage 2 --dataset linemod --task segment_semantic --exp_name ${exp_name} --class_name ${class_name} --exp_subname ${exp_subname} ${@:3:$#}" 6 | 7 | echo $cmd 8 | eval $cmd 9 | -------------------------------------------------------------------------------- /scripts/linemod/test_segment_all.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | class_names=( "ape" "benchviseblue" "cam" "can" "cat" "driller" "duck" "eggbox" "glue" "holepuncher" "iron" "lamp" "phone" ) 3 | for class_name in ${class_names[@]}; 4 | do 5 | bash scripts/linemod/test_segment.sh ${exp_name} ${class_name} ${@:2:$#} 6 | done 7 | 
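All of the fine-tuning and test scripts above locate a fine-tuned run through an `exp_subname` of the form `task:<task>_class:<class>_shot:<shots>_is:<image size>_lr:<lr>_sid:<support idx>` (the class part is omitted for datasets such as cellpose and fsc147, and the DAVIS scripts append an optional suffix). The sketch below shows one way to compose such a name when testing with non-default shots or learning rates; the helper `make_exp_subname` and its argument names are illustrative only and are not part of the repository.

```python
# Hypothetical helper (not in the repository): compose the exp_subname strings
# that the test scripts above hard-code, e.g.
#   task:pose_6d_class:cat_shot:50_is:224_lr:0.005_sid:0
def make_exp_subname(task, shot, img_size, lr, class_name=None, support_idx=0, suffix=''):
    parts = [f'task:{task}']
    if class_name is not None:
        parts.append(f'class:{class_name}')
    parts += [f'shot:{shot}', f'is:{img_size}', f'lr:{lr}', f'sid:{support_idx}']
    return '_'.join(parts) + suffix

# Matches the sub-name used by scripts/linemod/test_pose.sh for the "cat" object.
assert make_exp_subname('pose_6d', 50, 224, 0.005, class_name='cat') == \
    'task:pose_6d_class:cat_shot:50_is:224_lr:0.005_sid:0'
```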
-------------------------------------------------------------------------------- /scripts/unified/train.sh: -------------------------------------------------------------------------------- 1 | cmd="python main.py --stage 0" 2 | 3 | echo $cmd 4 | eval $cmd 5 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitGyun/chameleon/dbba367d57301ed22e8a693ec2335ebec728e4a4/train/__init__.py -------------------------------------------------------------------------------- /train/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | 5 | 6 | def spatial_softmax_loss(Y_pred, Y, M, reduction='mean', scaled=False): 7 | ''' 8 | Compute spatial softmax loss for AnimalKP. 9 | ''' 10 | if Y_pred.ndim == 6: 11 | M = rearrange(M, 'B T N C H W -> B (H W) T N C') 12 | Y_pred = rearrange(Y_pred, 'B T N C H W -> B (H W) T N C') 13 | Y = rearrange(Y, 'B T N C H W -> B (H W) T N C') 14 | else: 15 | M = rearrange(M, 'B N C H W -> B (H W) N C') 16 | Y_pred = rearrange(Y_pred, 'B N C H W -> B (H W) N C') 17 | Y = rearrange(Y, 'B N C H W -> B (H W) N C') 18 | loss = F.cross_entropy(Y_pred*M, Y*M, reduction='none') 19 | 20 | if reduction == 'mean': 21 | loss = loss.mean() 22 | if scaled: 23 | loss = loss / max(1, Y.sum()) 24 | 25 | return loss 26 | 27 | 28 | def spatio_channel_softmax_loss(Y_pred, Y, M, reduction='mean'): 29 | ''' 30 | Compute spatio-channel softmax loss (normalized over channels and spatial positions) for AnimalKP. 31 | ''' 32 | assert Y_pred.ndim == 6 33 | 34 | # normalize over channels 35 | Y_pred = Y_pred - torch.logsumexp(Y_pred, dim=1, keepdim=True) 36 | 37 | # normalize over spatial dimensions 38 | Y_pred = rearrange(Y_pred, 'B T N C H W -> B (H W) T N C') 39 | M = rearrange(M, 'B T N C H W -> B (H W) T N C') 40 | Y = rearrange(Y, 'B T N C H W -> B (H W) T N C') 41 | loss = F.cross_entropy(Y_pred*M, Y*M, reduction='none') 42 | 43 | if reduction == 'mean': 44 | loss = loss.mean() 45 | return loss 46 | -------------------------------------------------------------------------------- /train/miou_fss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AverageMeter: 5 | r""" Stores loss, evaluation results """ 6 | def __init__(self, class_ids_interest, semseg_classes=None, device=None): 7 | if device is None: 8 | device = torch.device('cpu') 9 | if isinstance(class_ids_interest, int): 10 | class_ids_interest = [class_ids_interest] 11 | if semseg_classes is None: 12 | semseg_classes = class_ids_interest 13 | 14 | self.device = device 15 | self.class_ids_interest = torch.tensor(class_ids_interest, device=self.device, requires_grad=False) 16 | self.semseg_classes = semseg_classes 17 | self.nclass = len(self.class_ids_interest) 18 | self.reset() 19 | 20 | def reset(self): 21 | self.intersection_buf = torch.zeros([2, self.nclass], device=self.device, requires_grad=False).float() 22 | self.union_buf = torch.zeros([2, self.nclass], device=self.device, requires_grad=False).float() 23 | self.ones = torch.ones_like(self.union_buf, requires_grad=False) 24 | self.loss_buf = [] 25 | 26 | @torch.inference_mode() 27 | def update(self, inter_b, union_b, class_id): 28 | self.intersection_buf.index_add_(1, class_id, inter_b.float()) 29 | self.union_buf.index_add_(1, class_id, union_b.float()) 30 | 31 | def class_iou(self,
class_id): 32 | iou = self.intersection_buf.float() / \ 33 | torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0] 34 | iou = iou.index_select(1, torch.tensor([class_id], device=iou.device)) 35 | miou = iou[1].mean() 36 | 37 | return miou 38 | 39 | def compute_iou(self): 40 | iou = self.intersection_buf.float() / \ 41 | torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0] 42 | iou = iou.index_select(1, self.class_ids_interest) 43 | miou = iou[1].mean() 44 | 45 | fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) / 46 | self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() 47 | 48 | return miou, fb_iou 49 | 50 | 51 | class Evaluator: 52 | r""" Computes intersection and union between prediction and ground-truth """ 53 | @classmethod 54 | def initialize(cls): 55 | pass 56 | 57 | @classmethod 58 | def classify_prediction(cls, pred_mask, gt_mask): 59 | # compute intersection and union of each episode in a batch 60 | area_inter, area_pred, area_gt = [], [], [] 61 | for _pred_mask, _gt_mask in zip(pred_mask, gt_mask): 62 | _inter = _pred_mask[_pred_mask == _gt_mask] 63 | if _inter.size(0) == 0: # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1) 64 | _area_inter = torch.tensor([0, 0], device=_pred_mask.device) 65 | else: 66 | _area_inter = torch.histc(_inter, bins=2, min=0, max=1) 67 | area_inter.append(_area_inter) 68 | area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1)) 69 | area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1)) 70 | area_inter = torch.stack(area_inter).t() 71 | area_pred = torch.stack(area_pred).t() 72 | area_gt = torch.stack(area_gt).t() 73 | area_union = area_pred + area_gt - area_inter 74 | 75 | return area_inter, area_union 76 | -------------------------------------------------------------------------------- /train/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from deepspeed.ops.adam import DeepSpeedCPUAdam 4 | 5 | 6 | optim_dict = { 7 | 'sgd': torch.optim.SGD, 8 | 'adam': torch.optim.Adam, 9 | 'adamw': torch.optim.AdamW, 10 | 'cpuadam': DeepSpeedCPUAdam, 11 | } 12 | 13 | 14 | def get_optimizer(config, model): 15 | learnable_params = [] 16 | 17 | # train all parameters for episodic training 18 | if config.stage == 0: 19 | learnable_params.append({'params': model.pretrained_parameters(), 'lr': config.lr_pretrained}) 20 | learnable_params.append({'params': model.scratch_parameters(), 'lr': config.lr}) 21 | 22 | # train only task-specific parameters for fine-tuning 23 | elif config.stage == 1: 24 | if config.from_scratch: 25 | learnable_params.append({'params': model.pretrained_parameters(), 'lr': config.lr_pretrained}) 26 | learnable_params.append({'params': model.scratch_parameters(), 'lr': config.lr}) 27 | else: 28 | learnable_params.append({'params': model.bias_parameters(), 'lr': config.lr}) 29 | learnable_params.append({'params': model.matching_module.alpha.parameters(), 'lr': config.lr}) 30 | learnable_params.append({'params': model.matching_module.layernorm.parameters(), 'lr': config.lr}) 31 | if config.head_tuning: 32 | learnable_params.append({'params': model.label_decoder.head.parameters(), 'lr': config.lr_pretrained}) 33 | if config.label_decoder_tuning: 34 | learnable_params.append({'params': model.label_decoder.parameters(), 'lr': config.lr_pretrained}) 35 | if config.relpos_tuning: 36 | learnable_params.append({'params': model.image_encoder.relpos_parameters(), 'lr': 
config.lr_pretrained}) 37 | if getattr(model.image_encoder.backbone, "pos_embed", None) is not None: 38 | learnable_params.append({'params': model.image_encoder.backbone.pos_embed, 'lr': config.lr_pretrained}) 39 | if config.input_embed_tuning: 40 | learnable_params.append({'params': model.image_encoder.backbone.patch_embed.parameters(), 'lr': config.lr_pretrained}) 41 | if config.output_embed_tuning: 42 | learnable_params.append({'params': model.label_encoder.backbone.patch_embed.parameters(), 'lr': config.lr_pretrained}) 43 | 44 | kwargs = {} 45 | if config.optimizer == 'sgd': 46 | kwargs['momentum'] = 0.9 47 | optimizer = optim_dict[config.optimizer](learnable_params, weight_decay=config.weight_decay, **kwargs) 48 | if config.lr_warmup >= 0: 49 | lr_warmup = config.lr_warmup 50 | else: 51 | assert config.lr_warmup_scale >= 0. and config.lr_warmup_scale <= 1. 52 | lr_warmup = int(config.lr_warmup_scale * config.n_schedule_steps) 53 | lr_scheduler = CustomLRScheduler(optimizer, config.lr_schedule, config.lr, config.n_schedule_steps, lr_warmup, 54 | from_iter=config.schedule_from, decay_degree=config.lr_decay_degree) 55 | 56 | return optimizer, lr_scheduler 57 | 58 | 59 | class CustomLRScheduler(object): 60 | ''' 61 | Custom learning rate scheduler for pytorch optimizer. 62 | Assumes 1 <= self.iter <= 1 + num_iters. 63 | ''' 64 | 65 | def __init__(self, optimizer, mode, base_lr, num_iters, warmup_iters=1000, 66 | from_iter=0, decay_degree=0.9, decay_steps=5000): 67 | self.optimizer = optimizer 68 | self.mode = mode 69 | self.base_lr = base_lr 70 | self.lr = base_lr 71 | self.iter = from_iter 72 | self.N = num_iters + 1 73 | self.warmup_iters = warmup_iters 74 | self.decay_degree = decay_degree 75 | self.decay_steps = decay_steps 76 | 77 | self.lr_coefs = [] 78 | for param_group in optimizer.param_groups: 79 | self.lr_coefs.append(param_group['lr'] / base_lr) 80 | 81 | if self.iter > 0: 82 | self.step(self.iter) 83 | 84 | def step(self, step=-1): 85 | # update current step 86 | if step >= 0: 87 | self.iter = step 88 | else: 89 | self.iter += 1 90 | 91 | # schedule lr 92 | if self.mode == 'cos': 93 | self.lr = 0.5 * self.base_lr * (1 + math.cos(1.0 * self.iter / self.N * math.pi)) 94 | elif self.mode == 'poly': 95 | if self.iter < self.N and self.iter >= self.warmup_iters: 96 | self.lr = self.base_lr * pow((1 - 1.0 * (self.iter - self.warmup_iters) / (self.N - self.warmup_iters)), self.decay_degree) 97 | elif self.mode == 'step': 98 | self.lr = self.base_lr * (0.1**(self.iter // self.decay_steps)) 99 | elif self.mode == 'constant': 100 | self.lr = self.base_lr 101 | elif self.mode == 'sqroot': 102 | self.lr = self.base_lr * self.warmup_iters**0.5 * min(self.iter * self.warmup_iters**-1.5, self.iter**-0.5) 103 | else: 104 | raise NotImplementedError 105 | 106 | # warm up lr schedule 107 | if self.warmup_iters > 0 and self.iter < self.warmup_iters and self.mode != 'sqroot': 108 | self.lr = self.base_lr * 1.0 * self.iter / self.warmup_iters 109 | assert self.lr >= 0 110 | 111 | # adjust lr 112 | self._adjust_learning_rate(self.optimizer, self.lr) 113 | 114 | def _adjust_learning_rate(self, optimizer, lr): 115 | for i in range(len(optimizer.param_groups)): 116 | optimizer.param_groups[i]['lr'] = lr * self.lr_coefs[i] 117 | 118 | def reset(self): 119 | self.lr = self.base_lr 120 | self.iter = 0 121 | self._adjust_learning_rate(self.optimizer, self.lr) 122 | --------------------------------------------------------------------------------
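The `CustomLRScheduler` above performs a linear warmup for `warmup_iters` steps, then applies the selected decay rule, rescaling every parameter group by its initial ratio to `base_lr` (stored in `lr_coefs`). Below is a minimal usage sketch with a toy model and optimizer; the parameter groups, learning rates, and step counts are made up for illustration, and the repository's actual training loop (in `train/trainer.py`) may drive the scheduler differently. It assumes the repository root is on `PYTHONPATH` so that `train.optim` (and its `deepspeed` dependency) is importable.

```python
import torch
from train.optim import CustomLRScheduler

# Toy setup: two parameter groups with different base learning rates,
# mirroring the pretrained / scratch split used in get_optimizer above.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
optimizer = torch.optim.SGD([
    {'params': model[0].parameters(), 'lr': 1e-4},  # e.g. pretrained parameters
    {'params': model[1].parameters(), 'lr': 1e-3},  # e.g. scratch parameters
], momentum=0.9)

# base_lr is the reference lr; each group keeps its ratio to it (here 0.1 and 1.0).
scheduler = CustomLRScheduler(optimizer, mode='poly', base_lr=1e-3,
                              num_iters=1000, warmup_iters=100)

for step in range(1, 1001):
    # ... forward, backward, and optimizer.step() would go here ...
    scheduler.step()  # linear warmup for 100 steps, then polynomial decay
```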