├── .gitignore ├── LICENSE ├── README.md ├── assets ├── network.png └── sbd_samples_weights.pkl ├── configs ├── adaptiveclick_plainvit_base448.yaml ├── adaptiveclick_plainvit_huge448.yaml └── base_configuration.yaml ├── demo.py ├── interactive_demo ├── app.py ├── canvas.py ├── controller.py └── wrappers.py ├── isegm ├── __init__.py ├── data │ ├── base.py │ ├── compose.py │ ├── datasets │ │ ├── __init__.py │ │ ├── ade20k.py │ │ ├── berkeley.py │ │ ├── brats.py │ │ ├── coco.py │ │ ├── coco_lvis.py │ │ ├── davis.py │ │ ├── grabcut.py │ │ ├── hard.py │ │ ├── images_dir.py │ │ ├── lvis.py │ │ ├── lvis_v1.py │ │ ├── oai.py │ │ ├── oai_zib.py │ │ ├── openimages.py │ │ ├── pascalvoc.py │ │ ├── sbd.py │ │ └── ssTEM.py │ ├── points_sampler.py │ ├── sample.py │ └── transforms.py ├── engine │ ├── adaptiveclick_trainer.py │ ├── optimizer.py │ └── trainer.py ├── inference │ ├── clicker.py │ ├── evaluation.py │ ├── predictors │ │ ├── __init__.py │ │ ├── base.py │ │ ├── brs.py │ │ ├── brs_functors.py │ │ └── brs_losses.py │ ├── transforms │ │ ├── __init__.py │ │ ├── base.py │ │ ├── crops.py │ │ ├── flip.py │ │ ├── limit_longest_side.py │ │ └── zoom_in.py │ └── utils.py ├── model │ ├── __init__.py │ ├── criterion.py │ ├── initializer.py │ ├── is_adaptiveclick_model.py │ ├── is_deeplab_model.py │ ├── is_hrformer_model.py │ ├── is_hrnet_model.py │ ├── is_model.py │ ├── is_plainvit_model.py │ ├── is_segformer_model.py │ ├── is_swinformer_model.py │ ├── losses.py │ ├── matcher.py │ ├── metrics.py │ ├── modeling │ │ ├── __init__.py │ │ ├── basic_blocks.py │ │ ├── deeplab_v3.py │ │ ├── hrformer.py │ │ ├── hrformer_helper │ │ │ ├── __init__.py │ │ │ ├── backbone_selector.py │ │ │ └── hrt │ │ │ │ ├── __init__.py │ │ │ │ ├── hrt_backbone.py │ │ │ │ ├── hrt_config.py │ │ │ │ ├── logger.py │ │ │ │ ├── module_helper.py │ │ │ │ └── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── bottleneck_block.py │ │ │ │ ├── ffn_block.py │ │ │ │ ├── multihead_attention.py │ │ │ │ ├── multihead_isa_attention.py │ │ │ │ ├── multihead_isa_pool_attention.py │ │ │ │ ├── spatial_ocr_block.py │ │ │ │ └── transformer_block.py │ │ ├── hrnet_ocr.py │ │ ├── mask2former_helper │ │ │ ├── __init__.py │ │ │ ├── mask2former_transformer_decoder.py │ │ │ ├── msdeformattn.py │ │ │ ├── ops │ │ │ │ ├── __init__.py │ │ │ │ ├── functions │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ ├── make.sh │ │ │ │ ├── modules │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── ms_deform_attn.py │ │ │ │ ├── setup.py │ │ │ │ ├── src │ │ │ │ │ ├── cpu │ │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ │ ├── cuda │ │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ │ ├── ms_deform_attn.h │ │ │ │ │ └── vision.cpp │ │ │ │ └── test.py │ │ │ ├── position_encoding.py │ │ │ └── transformer.py │ │ ├── models_vit.py │ │ ├── ocr.py │ │ ├── pos_embed.py │ │ ├── resnet.py │ │ ├── resnetv1b.py │ │ ├── segformer.py │ │ ├── swin_transformer.py │ │ ├── swin_transformer_helper │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── checkpoint.py │ │ │ └── logger.py │ │ ├── swin_unet.py │ │ └── transformer_helper │ │ │ ├── __init__.py │ │ │ ├── accuracy.py │ │ │ ├── base_pixel_sampler.py │ │ │ ├── builder.py │ │ │ ├── cross_entropy_loss.py │ │ │ ├── decode_head.py │ │ │ ├── embed.py │ │ │ ├── logger.py │ │ │ ├── shape_convert.py │ │ │ ├── utils.py │ │ │ └── wrappers.py │ ├── modifiers.py │ └── ops.py └── utils │ ├── cython │ ├── __init__.py │ ├── _get_dist_maps.pyx │ ├── 
_get_dist_maps.pyxbld │ └── dist_maps.py │ ├── distributed.py │ ├── exp.py │ ├── exp_imports │ └── default.py │ ├── log.py │ ├── lr_decay.py │ ├── misc.py │ ├── serialization.py │ └── vis.py ├── models └── iter_mask │ ├── adaptiveclick_base448_cocolvis_itermask.py │ ├── adaptiveclick_base448_sbd_itermask.py │ ├── adaptiveclick_huge448_cocolvis_itermask.py │ ├── adaptiveclick_huge448_sbd_itermask.py │ ├── simpleclick_base448_cocolvis_itermask.py │ ├── simpleclick_base448_sbd_itermask.py │ ├── simpleclick_huge448_cocolvis_itermask.py │ └── simpleclick_huge448_sbd_itermask.py ├── requirements.txt ├── scripts ├── analyze_image_size.py ├── annotations_conversion │ ├── ade20k.py │ ├── coco_lvis.py │ ├── common.py │ └── openimages.py ├── convert_annotations.py ├── draw_radar.py ├── draw_radar_natural.py ├── evaluate_model.py └── plot_ious_analysis.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # output dir 2 | output 3 | instant_test_output 4 | inference_test_output 5 | 6 | 7 | *.png 8 | *.json 9 | *.diff 10 | *.jpg 11 | *.gif 12 | !/assets/*.jpg 13 | !/assets/*.png 14 | 15 | # compilation and distribution 16 | __pycache__ 17 | _ext 18 | *.pyc 19 | *.pyd 20 | *.so 21 | *.dll 22 | *.egg-info/ 23 | build/ 24 | dist/ 25 | wheels/ 26 | 27 | # pytorch/python/numpy formats 28 | *.pth 29 | *.pkl 30 | !assets/sbd_samples_weights.pkl 31 | *.npy 32 | *.ts 33 | model_ts*.txt 34 | 35 | # ipython/jupyter notebooks 36 | *.ipynb 37 | **/.ipynb_checkpoints/ 38 | 39 | # Editor temporaries 40 | *.swn 41 | *.swo 42 | *.swp 43 | *~ 44 | 45 | # editor settings 46 | .idea 47 | .vscode 48 | _darcs 49 | 50 | # project dirs 51 | /datasets/* 52 | !/datasets/*.* 53 | /projects/*/datasets 54 | /snippet 55 | /logs 56 | weights/ 57 | experiments/ 58 | checkpoints/ 59 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2021 Samsung Electronics Co., Ltd. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
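The tree above shows how the runtime pieces fit together: model and training configs live under configs/, pretrained weights are expected under the gitignored weights/ directory, and demo.py and train.py are the entry points that wire them to the isegm package. Below is a minimal sketch of loading a trained interactive model programmatically, mirroring the calls demo.py (further down in this dump) makes; the config path and the checkpoint file name are illustrative assumptions, not files shipped with the repository.

# Minimal sketch: load a config and an interactive-segmentation checkpoint the way demo.py does.
# Assumptions: configs/base_configuration.yaml is loaded directly (demo.py defaults to a
# top-level config.yml), and 'adaptiveclick_base448.pth' is a hypothetical checkpoint name.
import torch
from isegm.utils import exp
from isegm.inference import utils

cfg = exp.load_config_file('configs/base_configuration.yaml', return_edict=True)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH,
                                        'adaptiveclick_base448.pth')  # hypothetical file name
# The third positional argument is the eval_ritm flag, passed exactly as demo.py passes it.
model = utils.load_is_model(checkpoint_path, device, False, cpu_dist_maps=True)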
-------------------------------------------------------------------------------- /assets/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab206/AdaptiveClick/e5e7344e70c6992f095cbd78db3ddf2c7969252c/assets/network.png -------------------------------------------------------------------------------- /assets/sbd_samples_weights.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab206/AdaptiveClick/e5e7344e70c6992f095cbd78db3ddf2c7969252c/assets/sbd_samples_weights.pkl -------------------------------------------------------------------------------- /configs/adaptiveclick_plainvit_base448.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: 'base_configuration.yaml' 2 | MODEL: 3 | IMAGENET_PRETRAINED_MODELS: "weights/mae_pretrain_vit_base.pth" 4 | BACKBONE: 5 | PATCH_SIZE: [ 16, 16 ] 6 | IN_CHANS: 3 7 | EMBED_DIM: 768 8 | DEPTH: 12 9 | NUM_HEADS: 12 10 | MLP_RATIO: 4 11 | QKV_BIAS: TRUE 12 | SIMPLE_FEATURE_PYRAMID: 13 | IN_DIM: 768 14 | OUT_DIMS: [ 256, 512, 1024, 2048 ] 15 | SEM_SEG_HEAD: 16 | IGNORE_VALUE: 255 17 | NUM_CLASSES: 1 18 | LOSS_WEIGHT: 1.0 19 | CONVS_DIM: 256 20 | MASK_DIM: 256 21 | NORM: "GN" 22 | # pixel decoder 23 | IN_FEATURES: [ "res2", "res3", "res4", "res5" ] 24 | DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ "res3", "res4", "res5" ] 25 | COMMON_STRIDE: 4 26 | TRANSFORMER_ENC_LAYERS: 4 27 | MASK_FORMER: 28 | DEEP_SUPERVISION: True 29 | NO_OBJECT_WEIGHT: 0.1 30 | CLASS_WEIGHT: 2.0 31 | MASK_WEIGHT: 5.0 32 | DICE_WEIGHT: 5.0 33 | HIDDEN_DIM: 256 34 | NUM_OBJECT_QUERIES: 10 35 | NHEADS: 8 36 | DROPOUT: 0.0 37 | DIM_FEEDFORWARD: 2048 38 | ENC_LAYERS: 0 39 | PRE_NORM: False 40 | ENFORCE_INPUT_PROJ: False 41 | SIZE_DIVISIBILITY: 32 42 | DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query 43 | TRAIN_NUM_POINTS: 12544 44 | OVERSAMPLE_RATIO: 3.0 45 | IMPORTANCE_SAMPLE_RATIO: 0.75 46 | TEST: 47 | SEMANTIC_ON: True 48 | INSTANCE_ON: True 49 | PANOPTIC_ON: True 50 | OVERLAP_THRESHOLD: 0.8 51 | OBJECT_MASK_THRESHOLD: 0.8 52 | 53 | -------------------------------------------------------------------------------- /configs/adaptiveclick_plainvit_huge448.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: 'base_configuration.yaml' 2 | MODEL: 3 | IMAGENET_PRETRAINED_MODELS: "weights/mae_pretrain_vit_huge.pth" 4 | BACKBONE: 5 | PATCH_SIZE: [ 14, 14 ] 6 | IN_CHANS: 3 7 | EMBED_DIM: 1280 8 | DEPTH: 32 9 | NUM_HEADS: 16 10 | MLP_RATIO: 4 11 | QKV_BIAS: TRUE 12 | SIMPLE_FEATURE_PYRAMID: 13 | IN_DIM: 1280 14 | OUT_DIMS: [ 512, 1024, 2048, 4096 ] 15 | SEM_SEG_HEAD: 16 | IGNORE_VALUE: 255 17 | NUM_CLASSES: 1 18 | LOSS_WEIGHT: 1.0 19 | CONVS_DIM: 256 20 | MASK_DIM: 256 21 | NORM: "GN" 22 | # pixel decoder 23 | IN_FEATURES: [ "res2", "res3", "res4", "res5" ] 24 | DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ "res3", "res4", "res5" ] 25 | COMMON_STRIDE: 4 26 | TRANSFORMER_ENC_LAYERS: 4 27 | MASK_FORMER: 28 | DEEP_SUPERVISION: True 29 | NO_OBJECT_WEIGHT: 0.1 30 | CLASS_WEIGHT: 2.0 31 | MASK_WEIGHT: 5.0 32 | DICE_WEIGHT: 5.0 33 | HIDDEN_DIM: 256 34 | NUM_OBJECT_QUERIES: 10 35 | NHEADS: 8 36 | DROPOUT: 0.0 37 | DIM_FEEDFORWARD: 2048 38 | ENC_LAYERS: 0 39 | PRE_NORM: False 40 | ENFORCE_INPUT_PROJ: False 41 | SIZE_DIVISIBILITY: 32 42 | DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query 43 | TRAIN_NUM_POINTS: 12544 44 | 
OVERSAMPLE_RATIO: 3.0 45 | IMPORTANCE_SAMPLE_RATIO: 0.75 46 | TEST: 47 | SEMANTIC_ON: True 48 | INSTANCE_ON: True 49 | PANOPTIC_ON: True 50 | OVERLAP_THRESHOLD: 0.8 51 | OBJECT_MASK_THRESHOLD: 0.8 52 | 53 | -------------------------------------------------------------------------------- /configs/base_configuration.yaml: -------------------------------------------------------------------------------- 1 | INTERACTIVE_MODELS_PATH: "./weights" 2 | EXPS_PATH: "./experiments" 3 | 4 | # Evaluation datasets 5 | GRABCUT_PATH: "./datasets/GrabCut" 6 | BERKELEY_PATH: "./datasets/Berkeley" 7 | DAVIS_PATH: "./datasets/DAVIS" 8 | 9 | COCO_MVAL_PATH: "./datasets/COCO_MVal" 10 | 11 | # Train datasets 12 | SBD_PATH: "./datasets/SBD" 13 | COCO_PATH: "./datasets/LVIS" 14 | LVIS_PATH: "./datasets/LVIS" 15 | LVIS_v1_PATH: "./datasets/LVIS" 16 | 17 | OPENIMAGES_PATH: "./datasets/OpenImages" 18 | PASCALVOC_PATH: "./datasets/VOC2012" 19 | ADE20K_PATH: "./datasets/ADE20K" 20 | 21 | BraTS_PATH: "./datasets/BraTS20" 22 | ssTEM_PATH: "./datasets/ssTEM/stack1" 23 | OAIZIB_PATH: "./datasets/OAI-ZIB" 24 | 25 | # For SimpleClick 26 | IMAGENET_PRETRAINED_MODELS: 27 | MAE_BASE: "./weights/mae_pretrain_vit_base.pth" 28 | MAE_LARGE: "./weights/mae_pretrain_vit_large.pth" 29 | MAE_HUGE: "./weights/mae_pretrain_vit_huge.pth" -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | 4 | import argparse 5 | import tkinter as tk 6 | 7 | import torch 8 | 9 | from isegm.utils import exp 10 | from isegm.inference import utils 11 | from interactive_demo.app import InteractiveDemoApp 12 | 13 | def main(): 14 | args, cfg = parse_args() 15 | 16 | torch.backends.cudnn.deterministic = True 17 | checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, args.checkpoint) 18 | model = utils.load_is_model(checkpoint_path, args.device, args.eval_ritm, cpu_dist_maps=True) 19 | 20 | root = tk.Tk() 21 | root.minsize(960, 480) 22 | app = InteractiveDemoApp(root, args, model) 23 | root.deiconify() 24 | app.mainloop() 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser() 29 | 30 | parser.add_argument('--checkpoint', type=str, required=True, 31 | help='The path to the checkpoint. ' 32 | 'This can be a relative path (relative to cfg.INTERACTIVE_MODELS_PATH) ' 33 | 'or an absolute path. 
The file extension can be omitted.') 34 | 35 | parser.add_argument('--gpu', type=int, default=0, 36 | help='Id of GPU to use.') 37 | 38 | parser.add_argument('--cpu', action='store_true', default=False, 39 | help='Use only CPU for inference.') 40 | 41 | parser.add_argument('--limit-longest-size', type=int, default=800, 42 | help='If the largest side of an image exceeds this value, ' 43 | 'it is resized so that its largest side is equal to this value.') 44 | 45 | parser.add_argument('--cfg', type=str, default="config.yml", 46 | help='The path to the config file.') 47 | 48 | parser.add_argument('--eval-ritm', action='store_true', default=False) 49 | 50 | args = parser.parse_args() 51 | if args.cpu: 52 | args.device =torch.device('cpu') 53 | else: 54 | args.device = torch.device(f'cuda:{args.gpu}') 55 | cfg = exp.load_config_file(args.cfg, return_edict=True) 56 | 57 | return args, cfg 58 | 59 | 60 | if __name__ == '__main__': 61 | main() 62 | -------------------------------------------------------------------------------- /interactive_demo/wrappers.py: -------------------------------------------------------------------------------- 1 | import tkinter as tk 2 | from tkinter import messagebox, ttk 3 | 4 | 5 | class BoundedNumericalEntry(tk.Entry): 6 | def __init__(self, master=None, min_value=None, max_value=None, variable=None, 7 | vartype=float, width=7, allow_inf=False, **kwargs): 8 | if variable is None: 9 | if vartype == float: 10 | self.var = tk.DoubleVar() 11 | elif vartype == int: 12 | self.var = tk.IntVar() 13 | else: 14 | self.var = tk.StringVar() 15 | else: 16 | self.var = variable 17 | 18 | self.fake_var = tk.StringVar(value=self.var.get()) 19 | self.vartype = vartype 20 | self.old_value = self.var.get() 21 | self.allow_inf = allow_inf 22 | 23 | self.min_value, self.max_value = min_value, max_value 24 | self.get, self.set = self.fake_var.get, self.fake_var.set 25 | 26 | self.validate_command = master.register(self._check_bounds) 27 | tk.Entry.__init__(self, master, textvariable=self.fake_var, validate="focus", width=width, 28 | vcmd=(self.validate_command, '%P', '%d'), **kwargs) 29 | 30 | def _check_bounds(self, instr, action_type): 31 | if self.allow_inf and instr == 'INF': 32 | self.fake_var.set('INF') 33 | return True 34 | 35 | if action_type == '-1': 36 | try: 37 | new_value = self.vartype(instr) 38 | except ValueError: 39 | pass 40 | else: 41 | if (self.min_value is None or new_value >= self.min_value) and \ 42 | (self.max_value is None or new_value <= self.max_value): 43 | if new_value != self.old_value: 44 | self.old_value = self.vartype(self.fake_var.get()) 45 | self.delete(0, tk.END) 46 | self.insert(0, str(self.old_value)) 47 | self.var.set(self.old_value) 48 | return True 49 | self.delete(0, tk.END) 50 | self.insert(0, str(self.old_value)) 51 | mn = '-inf' if self.min_value is None else str(self.min_value) 52 | mx = '+inf' if self.max_value is None else str(self.max_value) 53 | messagebox.showwarning("Incorrect value in input field", f"Value for {self._name} should be in " 54 | f"[{mn}; {mx}] and of type {self.vartype.__name__}") 55 | 56 | return False 57 | 58 | 59 | class FocusHorizontalScale(tk.Scale): 60 | def __init__(self, *args, highlightthickness=0, sliderrelief=tk.GROOVE, resolution=0.01, 61 | sliderlength=20, length=200, **kwargs): 62 | tk.Scale.__init__(self, *args, orient=tk.HORIZONTAL, highlightthickness=highlightthickness, 63 | sliderrelief=sliderrelief, resolution=resolution, 64 | sliderlength=sliderlength, length=length, **kwargs) 65 | self.bind("<1>", 
lambda event: self.focus_set()) 66 | 67 | 68 | class FocusCheckButton(tk.Checkbutton): 69 | def __init__(self, *args, highlightthickness=0, **kwargs): 70 | tk.Checkbutton.__init__(self, *args, highlightthickness=highlightthickness, **kwargs) 71 | self.bind("<1>", lambda event: self.focus_set()) 72 | 73 | 74 | class FocusButton(tk.Button): 75 | def __init__(self, *args, highlightthickness=0, **kwargs): 76 | tk.Button.__init__(self, *args, highlightthickness=highlightthickness, **kwargs) 77 | self.bind("<1>", lambda event: self.focus_set()) 78 | 79 | 80 | class FocusLabelFrame(ttk.LabelFrame): 81 | def __init__(self, *args, highlightthickness=0, relief=tk.RIDGE, borderwidth=2, **kwargs): 82 | tk.LabelFrame.__init__(self, *args, highlightthickness=highlightthickness, relief=relief, 83 | borderwidth=borderwidth, **kwargs) 84 | self.bind("<1>", lambda event: self.focus_set()) 85 | 86 | def set_frame_state(self, state): 87 | def set_widget_state(widget, state): 88 | if widget.winfo_children is not None: 89 | for w in widget.winfo_children(): 90 | w.configure(state=state) 91 | 92 | set_widget_state(self, state) 93 | -------------------------------------------------------------------------------- /isegm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab206/AdaptiveClick/e5e7344e70c6992f095cbd78db3ddf2c7969252c/isegm/__init__.py -------------------------------------------------------------------------------- /isegm/data/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pickle 3 | import numpy as np 4 | import torch 5 | from torchvision import transforms 6 | from .points_sampler import MultiPointSampler 7 | from .sample import DSample 8 | 9 | 10 | class ISDataset(torch.utils.data.dataset.Dataset): 11 | def __init__(self, 12 | augmentator=None, 13 | points_sampler=MultiPointSampler(max_num_points=12), 14 | min_object_area=0, 15 | keep_background_prob=0.0, 16 | with_image_info=False, 17 | samples_scores_path=None, 18 | samples_scores_gamma=1.0, 19 | epoch_len=-1): 20 | super(ISDataset, self).__init__() 21 | self.epoch_len = epoch_len 22 | self.augmentator = augmentator 23 | self.min_object_area = min_object_area 24 | self.keep_background_prob = keep_background_prob 25 | self.points_sampler = points_sampler 26 | self.with_image_info = with_image_info 27 | self.samples_precomputed_scores = self._load_samples_scores(samples_scores_path, samples_scores_gamma) 28 | self.to_tensor = transforms.ToTensor() 29 | 30 | self.dataset_samples = None 31 | 32 | def __getitem__(self, index): 33 | if self.samples_precomputed_scores is not None: 34 | index = np.random.choice(self.samples_precomputed_scores['indices'], 35 | p=self.samples_precomputed_scores['probs']) 36 | else: 37 | if self.epoch_len > 0: 38 | index = random.randrange(0, len(self.dataset_samples)) 39 | 40 | sample = self.get_sample(index) 41 | sample = self.augment_sample(sample) 42 | sample.remove_small_objects(self.min_object_area) 43 | 44 | self.points_sampler.sample_object(sample) 45 | points = np.array(self.points_sampler.sample_points()) 46 | mask = self.points_sampler.selected_mask 47 | 48 | output = { 49 | 'images': self.to_tensor(sample.image), 50 | 'points': points.astype(np.float32), 51 | 'instances': mask 52 | } 53 | 54 | if self.with_image_info: 55 | output['image_info'] = sample.sample_id 56 | 57 | return output 58 | 59 | def augment_sample(self, sample) -> DSample: 60 | if self.augmentator is 
None: 61 | return sample 62 | 63 | valid_augmentation = False 64 | while not valid_augmentation: 65 | sample.augment(self.augmentator) 66 | keep_sample = (self.keep_background_prob < 0.0 or 67 | random.random() < self.keep_background_prob) 68 | valid_augmentation = len(sample) > 0 or keep_sample 69 | 70 | return sample 71 | 72 | def get_sample(self, index) -> DSample: 73 | raise NotImplementedError 74 | 75 | def __len__(self): 76 | if self.epoch_len > 0: 77 | return self.epoch_len 78 | else: 79 | return self.get_samples_number() 80 | 81 | def get_samples_number(self): 82 | return len(self.dataset_samples) 83 | 84 | @staticmethod 85 | def _load_samples_scores(samples_scores_path, samples_scores_gamma): 86 | if samples_scores_path is None: 87 | return None 88 | 89 | with open(samples_scores_path, 'rb') as f: 90 | images_scores = pickle.load(f) 91 | 92 | probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores]) 93 | probs /= probs.sum() 94 | samples_scores = { 95 | 'indices': [x[0] for x in images_scores], 96 | 'probs': probs 97 | } 98 | print(f'Loaded {len(probs)} weights with gamma={samples_scores_gamma}') 99 | return samples_scores 100 | -------------------------------------------------------------------------------- /isegm/data/compose.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import isclose 3 | from .base import ISDataset 4 | 5 | 6 | class ComposeDataset(ISDataset): 7 | def __init__(self, datasets, **kwargs): 8 | super(ComposeDataset, self).__init__(**kwargs) 9 | 10 | self._datasets = datasets 11 | self.dataset_samples = [] 12 | for dataset_indx, dataset in enumerate(self._datasets): 13 | self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))]) 14 | 15 | def get_sample(self, index): 16 | dataset_indx, sample_indx = self.dataset_samples[index] 17 | return self._datasets[dataset_indx].get_sample(sample_indx) 18 | 19 | 20 | class ProportionalComposeDataset(ISDataset): 21 | def __init__(self, datasets, ratios, **kwargs): 22 | super().__init__(**kwargs) 23 | 24 | assert len(ratios) == len(datasets),\ 25 | "The number of datasets must match the number of ratios" 26 | assert isclose(sum(ratios), 1.0),\ 27 | "The sum of ratios must be equal to 1" 28 | 29 | self._ratios = ratios 30 | self._datasets = datasets 31 | self.dataset_samples = [] 32 | for dataset_indx, dataset in enumerate(self._datasets): 33 | self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))]) 34 | 35 | def get_sample(self, index): 36 | dataset_indx = np.random.choice(len(self._datasets), p=self._ratios) 37 | sample_indx = np.random.choice(len(self._datasets[dataset_indx])) 38 | 39 | return self._datasets[dataset_indx].get_sample(sample_indx) 40 | -------------------------------------------------------------------------------- /isegm/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from isegm.data.compose import ComposeDataset, ProportionalComposeDataset 2 | from .berkeley import BerkeleyDataset 3 | from .coco import CocoDataset 4 | from .davis import DavisDataset 5 | from .grabcut import GrabCutDataset 6 | from .coco_lvis import CocoLvisDataset 7 | from .lvis import LvisDataset 8 | from .lvis_v1 import Lvis_v1_Dataset 9 | from .openimages import OpenImagesDataset 10 | from .sbd import SBDDataset, SBDEvaluationDataset 11 | from .images_dir import ImagesDirDataset 12 | from .ade20k import ADE20kDataset 13 | from .pascalvoc 
import PascalVocDataset 14 | from .brats import BraTSDataset 15 | from .ssTEM import ssTEMDataset 16 | from .oai_zib import OAIZIBDataset 17 | from .oai import OAIDataset 18 | from .hard import HARDDataset -------------------------------------------------------------------------------- /isegm/data/datasets/ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import pickle as pkl 4 | from pathlib import Path 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | from isegm.data.base import ISDataset 10 | from isegm.data.sample import DSample 11 | from isegm.utils.misc import get_labels_with_sizes 12 | 13 | 14 | class ADE20kDataset(ISDataset): 15 | def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs): 16 | super().__init__(**kwargs) 17 | assert split in {'train', 'val'} 18 | 19 | self.dataset_path = Path(dataset_path) 20 | self.dataset_split = split 21 | self.dataset_split_folder = 'training' if split == 'train' else 'validation' 22 | self.stuff_prob = stuff_prob 23 | 24 | anno_path = self.dataset_path / f'{split}-annotations-object-segmentation.pkl' 25 | if os.path.exists(anno_path): 26 | with anno_path.open('rb') as f: 27 | annotations = pkl.load(f) 28 | else: 29 | raise RuntimeError(f"Can't find annotations at {anno_path}") 30 | self.annotations = annotations 31 | self.dataset_samples = list(annotations.keys()) 32 | 33 | def get_sample(self, index) -> DSample: 34 | image_id = self.dataset_samples[index] 35 | sample_annos = self.annotations[image_id] 36 | 37 | image_path = str(self.dataset_path / sample_annos['folder'] / f'{image_id}.jpg') 38 | image = cv2.imread(image_path) 39 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 40 | 41 | # select random mask for an image 42 | layer = random.choice(sample_annos['layers']) 43 | mask_path = str(self.dataset_path / sample_annos['folder'] / layer['mask_name']) 44 | instances_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)[:, :, 0] # the B channel holds instances 45 | instances_mask = instances_mask.astype(np.int32) 46 | object_ids, _ = get_labels_with_sizes(instances_mask) 47 | 48 | if (self.stuff_prob <= 0) or (random.random() > self.stuff_prob): 49 | # remove stuff objects 50 | for i, object_id in enumerate(object_ids): 51 | if i in layer['stuff_instances']: 52 | instances_mask[instances_mask == object_id] = 0 53 | object_ids, _ = get_labels_with_sizes(instances_mask) 54 | 55 | return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index) 56 | -------------------------------------------------------------------------------- /isegm/data/datasets/berkeley.py: -------------------------------------------------------------------------------- 1 | from .grabcut import GrabCutDataset 2 | 3 | 4 | class BerkeleyDataset(GrabCutDataset): 5 | def __init__(self, dataset_path, **kwargs): 6 | super().__init__(dataset_path, images_dir_name='images', masks_dir_name='masks', **kwargs) 7 | -------------------------------------------------------------------------------- /isegm/data/datasets/brats.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from isegm.data.base import ISDataset 7 | from isegm.data.sample import DSample 8 | 9 | 10 | class BraTSDataset(ISDataset): 11 | def __init__(self, dataset_path, 12 | images_dir_name='image', masks_dir_name='annotation', 13 | **kwargs): 14 | super(BraTSDataset, self).__init__(**kwargs) 15 | 16 | 
self.dataset_path = Path(dataset_path) 17 | self._images_path = self.dataset_path / images_dir_name 18 | self._insts_path = self.dataset_path / masks_dir_name 19 | 20 | self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.png'))] 21 | self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.png')} 22 | 23 | def get_sample(self, index) -> DSample: 24 | image_name = self.dataset_samples[index] 25 | image_path = str(self._images_path / image_name) 26 | mask_path = str(self._masks_paths[image_name.split('.')[0]]) 27 | 28 | image = cv2.imread(image_path) 29 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 30 | instances_mask = cv2.imread(mask_path)[:, :, 0].astype(np.int32) 31 | instances_mask[instances_mask > 0] = 1 32 | 33 | return DSample(image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index) 34 | -------------------------------------------------------------------------------- /isegm/data/datasets/coco.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | import random 4 | import numpy as np 5 | from pathlib import Path 6 | from isegm.data.base import ISDataset 7 | from isegm.data.sample import DSample 8 | 9 | 10 | class CocoDataset(ISDataset): 11 | def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs): 12 | super(CocoDataset, self).__init__(**kwargs) 13 | self.split = split 14 | self.dataset_path = Path(dataset_path) 15 | self.stuff_prob = stuff_prob 16 | 17 | self.load_samples() 18 | 19 | def load_samples(self): 20 | annotation_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}.json' 21 | self.labels_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}' 22 | self.images_path = self.dataset_path / self.split 23 | 24 | with open(annotation_path, 'r') as f: 25 | annotation = json.load(f) 26 | 27 | self.dataset_samples = annotation['annotations'] 28 | 29 | self._categories = annotation['categories'] 30 | self._stuff_labels = [x['id'] for x in self._categories if x['isthing'] == 0] 31 | self._things_labels = [x['id'] for x in self._categories if x['isthing'] == 1] 32 | self._things_labels_set = set(self._things_labels) 33 | self._stuff_labels_set = set(self._stuff_labels) 34 | 35 | def get_sample(self, index) -> DSample: 36 | dataset_sample = self.dataset_samples[index] 37 | 38 | image_path = self.images_path / self.get_image_name(dataset_sample['file_name']) 39 | label_path = self.labels_path / dataset_sample['file_name'] 40 | 41 | image = cv2.imread(str(image_path)) 42 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 43 | label = cv2.imread(str(label_path), cv2.IMREAD_UNCHANGED).astype(np.int32) 44 | label = 256 * 256 * label[:, :, 0] + 256 * label[:, :, 1] + label[:, :, 2] 45 | 46 | instance_map = np.full_like(label, 0) 47 | things_ids = [] 48 | stuff_ids = [] 49 | 50 | for segment in dataset_sample['segments_info']: 51 | class_id = segment['category_id'] 52 | obj_id = segment['id'] 53 | if class_id in self._things_labels_set: 54 | if segment['iscrowd'] == 1: 55 | continue 56 | things_ids.append(obj_id) 57 | else: 58 | stuff_ids.append(obj_id) 59 | 60 | instance_map[label == obj_id] = obj_id 61 | 62 | if self.stuff_prob > 0 and random.random() < self.stuff_prob: 63 | instances_ids = things_ids + stuff_ids 64 | else: 65 | instances_ids = things_ids 66 | 67 | for stuff_id in stuff_ids: 68 | instance_map[instance_map == stuff_id] = 0 69 | 70 | return DSample(image, instance_map, objects_ids=instances_ids) 71 | 72 | @classmethod 73 | 
def get_image_name(cls, panoptic_name): 74 | return panoptic_name.replace('.png', '.jpg') 75 | -------------------------------------------------------------------------------- /isegm/data/datasets/coco_lvis.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pickle 3 | import random 4 | import numpy as np 5 | import json 6 | import cv2 7 | from copy import deepcopy 8 | from isegm.data.base import ISDataset 9 | from isegm.data.sample import DSample 10 | 11 | 12 | class CocoLvisDataset(ISDataset): 13 | def __init__(self, dataset_path, split='train', stuff_prob=0.0, 14 | allow_list_name=None, anno_file='hannotation.pickle', **kwargs): 15 | super(CocoLvisDataset, self).__init__(**kwargs) 16 | dataset_path = Path(dataset_path) 17 | self._split_path = dataset_path / split 18 | self.split = split 19 | self._images_path = self._split_path / 'images' 20 | self._masks_path = self._split_path / 'masks' 21 | self.stuff_prob = stuff_prob 22 | 23 | with open(self._split_path / anno_file, 'rb') as f: 24 | self.dataset_samples = sorted(pickle.load(f).items()) 25 | 26 | if allow_list_name is not None: 27 | allow_list_path = self._split_path / allow_list_name 28 | with open(allow_list_path, 'r') as f: 29 | allow_images_ids = json.load(f) 30 | allow_images_ids = set(allow_images_ids) 31 | 32 | self.dataset_samples = [sample for sample in self.dataset_samples 33 | if sample[0] in allow_images_ids] 34 | 35 | def get_sample(self, index) -> DSample: 36 | image_id, sample = self.dataset_samples[index] 37 | image_path = self._images_path / f'{image_id}.jpg' 38 | 39 | image = cv2.imread(str(image_path)) 40 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 41 | 42 | packed_masks_path = self._masks_path / f'{image_id}.pickle' 43 | with open(packed_masks_path, 'rb') as f: 44 | encoded_layers, objs_mapping = pickle.load(f) 45 | layers = [cv2.imdecode(x, cv2.IMREAD_UNCHANGED) for x in encoded_layers] 46 | layers = np.stack(layers, axis=2) 47 | 48 | instances_info = deepcopy(sample['hierarchy']) 49 | for inst_id, inst_info in list(instances_info.items()): 50 | if inst_info is None: 51 | inst_info = {'children': [], 'parent': None, 'node_level': 0} 52 | instances_info[inst_id] = inst_info 53 | inst_info['mapping'] = objs_mapping[inst_id] 54 | 55 | if self.stuff_prob > 0 and random.random() < self.stuff_prob: 56 | for inst_id in range(sample['num_instance_masks'], len(objs_mapping)): 57 | instances_info[inst_id] = { 58 | 'mapping': objs_mapping[inst_id], 59 | 'parent': None, 60 | 'children': [] 61 | } 62 | else: 63 | for inst_id in range(sample['num_instance_masks'], len(objs_mapping)): 64 | layer_indx, mask_id = objs_mapping[inst_id] 65 | layers[:, :, layer_indx][layers[:, :, layer_indx] == mask_id] = 0 66 | 67 | return DSample(image, layers, objects=instances_info) 68 | -------------------------------------------------------------------------------- /isegm/data/datasets/davis.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from isegm.data.base import ISDataset 7 | from isegm.data.sample import DSample 8 | 9 | 10 | class DavisDataset(ISDataset): 11 | def __init__(self, dataset_path, 12 | images_dir_name='img', masks_dir_name='gt', 13 | **kwargs): 14 | super(DavisDataset, self).__init__(**kwargs) 15 | 16 | self.dataset_path = Path(dataset_path) 17 | self._images_path = self.dataset_path / images_dir_name 18 | self._insts_path = 
self.dataset_path / masks_dir_name 19 | 20 | self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))] 21 | self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')} 22 | 23 | def get_sample(self, index) -> DSample: 24 | image_name = self.dataset_samples[index] 25 | image_path = str(self._images_path / image_name) 26 | mask_path = str(self._masks_paths[image_name.split('.')[0]]) 27 | 28 | image = cv2.imread(image_path) 29 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 30 | instances_mask = np.max(cv2.imread(mask_path).astype(np.int32), axis=2) 31 | instances_mask[instances_mask > 0] = 1 32 | 33 | return DSample(image, instances_mask, objects_ids=[1], sample_id=index) 34 | -------------------------------------------------------------------------------- /isegm/data/datasets/grabcut.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from isegm.data.base import ISDataset 7 | from isegm.data.sample import DSample 8 | 9 | 10 | class GrabCutDataset(ISDataset): 11 | def __init__(self, dataset_path, 12 | images_dir_name='data_GT', masks_dir_name='boundary_GT', 13 | **kwargs): 14 | super(GrabCutDataset, self).__init__(**kwargs) 15 | 16 | self.dataset_path = Path(dataset_path) 17 | self._images_path = self.dataset_path / images_dir_name 18 | self._insts_path = self.dataset_path / masks_dir_name 19 | 20 | self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))] 21 | self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')} 22 | 23 | def get_sample(self, index) -> DSample: 24 | image_name = self.dataset_samples[index] 25 | image_path = str(self._images_path / image_name) 26 | mask_path = str(self._masks_paths[image_name.split('.')[0]]) 27 | 28 | image = cv2.imread(image_path) 29 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 30 | instances_mask = cv2.imread(mask_path)[:, :, 0].astype(np.int32) 31 | instances_mask[instances_mask == 128] = -1 32 | instances_mask[instances_mask > 128] = 1 33 | 34 | return DSample(image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index) 35 | -------------------------------------------------------------------------------- /isegm/data/datasets/hard.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from isegm.data.base import ISDataset 7 | from isegm.data.sample import DSample 8 | 9 | 10 | class HARDDataset(ISDataset): 11 | def __init__(self, dataset_path, 12 | images_dir_name='img', masks_dir_name='gt', 13 | **kwargs): 14 | super(HARDDataset, self).__init__(**kwargs) 15 | 16 | self.dataset_path = Path(dataset_path) 17 | self._images_path = self.dataset_path / images_dir_name 18 | self._insts_path = self.dataset_path / masks_dir_name 19 | 20 | self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))] 21 | self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')} 22 | 23 | def get_sample(self, index) -> DSample: 24 | image_name = self.dataset_samples[index] 25 | image_path = str(self._images_path / image_name) 26 | mask_path = str(self._masks_paths[image_name.split('.')[0]]) 27 | 28 | image = cv2.imread(image_path) 29 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 30 | instances_mask = np.max(cv2.imread(mask_path).astype(np.int32), axis=2) 31 | instances_mask[instances_mask > 0] = 1 32 | 33 | return DSample(image, instances_mask, 
objects_ids=[1], sample_id=index) 34 | -------------------------------------------------------------------------------- /isegm/data/datasets/images_dir.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from pathlib import Path 4 | 5 | from isegm.data.base import ISDataset 6 | from isegm.data.sample import DSample 7 | 8 | 9 | class ImagesDirDataset(ISDataset): 10 | def __init__(self, dataset_path, 11 | images_dir_name='images', masks_dir_name='masks', 12 | **kwargs): 13 | super(ImagesDirDataset, self).__init__(**kwargs) 14 | 15 | self.dataset_path = Path(dataset_path) 16 | self._images_path = self.dataset_path / images_dir_name 17 | self._insts_path = self.dataset_path / masks_dir_name 18 | 19 | images_list = [x for x in sorted(self._images_path.glob('*.*'))] 20 | 21 | samples = {x.stem: {'image': x, 'masks': []} for x in images_list} 22 | for mask_path in self._insts_path.glob('*.*'): 23 | mask_name = mask_path.stem 24 | if mask_name in samples: 25 | samples[mask_name]['masks'].append(mask_path) 26 | continue 27 | 28 | mask_name_split = mask_name.split('_') 29 | if mask_name_split[-1].isdigit(): 30 | mask_name = '_'.join(mask_name_split[:-1]) 31 | assert mask_name in samples 32 | samples[mask_name]['masks'].append(mask_path) 33 | 34 | for x in samples.values(): 35 | assert len(x['masks']) > 0, x['image'] 36 | 37 | self.dataset_samples = [v for k, v in sorted(samples.items())] 38 | 39 | def get_sample(self, index) -> DSample: 40 | sample = self.dataset_samples[index] 41 | image_path = str(sample['image']) 42 | 43 | objects = [] 44 | ignored_regions = [] 45 | masks = [] 46 | for indx, mask_path in enumerate(sample['masks']): 47 | gt_mask = cv2.imread(str(mask_path))[:, :, 0].astype(np.int32) 48 | instances_mask = np.zeros_like(gt_mask) 49 | instances_mask[gt_mask == 128] = 2 50 | instances_mask[gt_mask > 128] = 1 51 | masks.append(instances_mask) 52 | objects.append((indx, 1)) 53 | ignored_regions.append((indx, 2)) 54 | 55 | image = cv2.imread(image_path) 56 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 57 | 58 | return DSample(image, np.stack(masks, axis=2), 59 | objects_ids=objects, ignore_ids=ignored_regions, sample_id=index) 60 | -------------------------------------------------------------------------------- /isegm/data/datasets/lvis.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | from isegm.data.base import ISDataset 10 | from isegm.data.sample import DSample 11 | 12 | 13 | class LvisDataset(ISDataset): 14 | def __init__(self, dataset_path, split='train', 15 | max_overlap_ratio=0.5, 16 | **kwargs): 17 | super(LvisDataset, self).__init__(**kwargs) 18 | dataset_path = Path(dataset_path) 19 | train_categories_path = dataset_path / 'train_categories.json' 20 | self._train_path = dataset_path / 'train' 21 | self._val_path = dataset_path / 'val' 22 | 23 | self.split = split 24 | self.max_overlap_ratio = max_overlap_ratio 25 | 26 | with open( dataset_path / split / f'lvis_{self.split}.json', 'r') as f: 27 | json_annotation = json.loads(f.read()) 28 | 29 | self.annotations = defaultdict(list) 30 | for x in json_annotation['annotations']: 31 | self.annotations[x['image_id']].append(x) 32 | 33 | if not train_categories_path.exists(): 34 | self.generate_train_categories(dataset_path, train_categories_path) 35 | self.dataset_samples 
= [x for x in json_annotation['images'] 36 | if len(self.annotations[x['id']]) > 0] 37 | 38 | def get_sample(self, index) -> DSample: 39 | image_info = self.dataset_samples[index] 40 | image_id, image_url = image_info['id'], image_info['coco_url'] 41 | image_filename = image_url.split('/')[-1] 42 | image_annotations = self.annotations[image_id] 43 | random.shuffle(image_annotations) 44 | 45 | # LVISv1 splits do not match older LVIS splits (some images in val may come from COCO train2017) 46 | if 'train2017' in image_url: 47 | image_path = self._train_path / 'images' / image_filename 48 | else: 49 | image_path = self._val_path / 'images' / image_filename 50 | image = cv2.imread(str(image_path)) 51 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 52 | 53 | instances_mask = None 54 | instances_area = defaultdict(int) 55 | objects_ids = [] 56 | for indx, obj_annotation in enumerate(image_annotations): 57 | mask = self.get_mask_from_polygon(obj_annotation, image) 58 | object_mask = mask > 0 59 | object_area = object_mask.sum() 60 | 61 | if instances_mask is None: 62 | instances_mask = np.zeros_like(object_mask, dtype=np.int32) 63 | 64 | overlap_ids = np.bincount(instances_mask[object_mask].flatten()) 65 | overlap_areas = [overlap_area / instances_area[inst_id] for inst_id, overlap_area in enumerate(overlap_ids) 66 | if overlap_area > 0 and inst_id > 0] 67 | overlap_ratio = np.logical_and(object_mask, instances_mask > 0).sum() / object_area 68 | if overlap_areas: 69 | overlap_ratio = max(overlap_ratio, max(overlap_areas)) 70 | if overlap_ratio > self.max_overlap_ratio: 71 | continue 72 | 73 | instance_id = indx + 1 74 | instances_mask[object_mask] = instance_id 75 | instances_area[instance_id] = object_area 76 | objects_ids.append(instance_id) 77 | 78 | return DSample(image, instances_mask, objects_ids=objects_ids) 79 | 80 | 81 | @staticmethod 82 | def get_mask_from_polygon(annotation, image): 83 | mask = np.zeros(image.shape[:2], dtype=np.int32) 84 | for contour_points in annotation['segmentation']: 85 | contour_points = np.array(contour_points).reshape((-1, 2)) 86 | contour_points = np.round(contour_points).astype(np.int32)[np.newaxis, :] 87 | cv2.fillPoly(mask, contour_points, 1) 88 | 89 | return mask 90 | 91 | @staticmethod 92 | def generate_train_categories(dataset_path, train_categories_path): 93 | with open(dataset_path / 'train/lvis_train.json', 'r') as f: 94 | annotation = json.load(f) 95 | 96 | with open(train_categories_path, 'w') as f: 97 | json.dump(annotation['categories'], f, indent=1) 98 | -------------------------------------------------------------------------------- /isegm/data/datasets/lvis_v1.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | from isegm.data.base import ISDataset 10 | from isegm.data.sample import DSample 11 | 12 | 13 | class Lvis_v1_Dataset(ISDataset): 14 | def __init__(self, dataset_path, split='train', 15 | max_overlap_ratio=0.5, 16 | **kwargs): 17 | super(Lvis_v1_Dataset, self).__init__(**kwargs) 18 | dataset_path = Path(dataset_path) 19 | train_categories_path = dataset_path / 'train_categories.json' 20 | self._train_path = dataset_path / 'train' 21 | self._val_path = dataset_path / 'val' 22 | 23 | self.split = split 24 | self.max_overlap_ratio = max_overlap_ratio 25 | 26 | with open( dataset_path / split / f'lvis_v1_{self.split}.json', 'r') as f: 27 | json_annotation = 
json.loads(f.read()) 28 | 29 | self.annotations = defaultdict(list) 30 | for x in json_annotation['annotations']: 31 | self.annotations[x['image_id']].append(x) 32 | 33 | if not train_categories_path.exists(): 34 | self.generate_train_categories(dataset_path, train_categories_path) 35 | self.dataset_samples = [x for x in json_annotation['images'] 36 | if len(self.annotations[x['id']]) > 0] 37 | 38 | def get_sample(self, index) -> DSample: 39 | image_info = self.dataset_samples[index] 40 | image_id, image_url = image_info['id'], image_info['coco_url'] 41 | image_filename = image_url.split('/')[-1] 42 | image_annotations = self.annotations[image_id] 43 | random.shuffle(image_annotations) 44 | 45 | # LVISv1 splits do not match older LVIS splits (some images in val may come from COCO train2017) 46 | if 'train2017' in image_url: 47 | image_path = self._train_path / 'images' / image_filename 48 | else: 49 | image_path = self._val_path / 'images' / image_filename 50 | image = cv2.imread(str(image_path)) 51 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 52 | 53 | instances_mask = None 54 | instances_area = defaultdict(int) 55 | objects_ids = [] 56 | for indx, obj_annotation in enumerate(image_annotations): 57 | mask = self.get_mask_from_polygon(obj_annotation, image) 58 | object_mask = mask > 0 59 | object_area = object_mask.sum() 60 | 61 | if instances_mask is None: 62 | instances_mask = np.zeros_like(object_mask, dtype=np.int32) 63 | 64 | overlap_ids = np.bincount(instances_mask[object_mask].flatten()) 65 | overlap_areas = [overlap_area / instances_area[inst_id] for inst_id, overlap_area in enumerate(overlap_ids) 66 | if overlap_area > 0 and inst_id > 0] 67 | overlap_ratio = np.logical_and(object_mask, instances_mask > 0).sum() / object_area 68 | if overlap_areas: 69 | overlap_ratio = max(overlap_ratio, max(overlap_areas)) 70 | if overlap_ratio > self.max_overlap_ratio: 71 | continue 72 | 73 | instance_id = indx + 1 74 | instances_mask[object_mask] = instance_id 75 | instances_area[instance_id] = object_area 76 | objects_ids.append(instance_id) 77 | 78 | return DSample(image, instances_mask, objects_ids=objects_ids) 79 | 80 | 81 | @staticmethod 82 | def get_mask_from_polygon(annotation, image): 83 | mask = np.zeros(image.shape[:2], dtype=np.int32) 84 | for contour_points in annotation['segmentation']: 85 | contour_points = np.array(contour_points).reshape((-1, 2)) 86 | contour_points = np.round(contour_points).astype(np.int32)[np.newaxis, :] 87 | cv2.fillPoly(mask, contour_points, 1) 88 | 89 | return mask 90 | 91 | @staticmethod 92 | def generate_train_categories(dataset_path, train_categories_path): 93 | with open(dataset_path / 'train/lvis_v1_train.json', 'r') as f: 94 | annotation = json.load(f) 95 | 96 | with open(train_categories_path, 'w') as f: 97 | json.dump(annotation['categories'], f, indent=1) 98 | -------------------------------------------------------------------------------- /isegm/data/datasets/oai.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from isegm.data.base import ISDataset 7 | from isegm.data.sample import DSample 8 | 9 | 10 | class OAIDataset(ISDataset): 11 | def __init__(self, dataset_path, split='train', 12 | images_dir_name='image', masks_dir_name='annotations', 13 | **kwargs): 14 | super(OAIDataset, self).__init__(**kwargs) 15 | 16 | self.dataset_path = Path(dataset_path) 17 | self._images_path = self.dataset_path / split / images_dir_name 18 | 
self._insts_path = self.dataset_path / split / masks_dir_name 19 | 20 | self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.png'))] 21 | self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.png')} 22 | 23 | assert len(self.dataset_samples) > 0 24 | 25 | def get_sample(self, index) -> DSample: 26 | image_name = self.dataset_samples[index] 27 | image_path = str(self._images_path / image_name) 28 | mask_path = str(self._masks_paths[image_name.split('.')[0]]) 29 | 30 | image = cv2.imread(image_path) 31 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 32 | instances_mask = cv2.imread(mask_path)[:, :, 0].astype(np.uint8) 33 | 34 | objects_ids = np.unique(instances_mask) 35 | objects_ids = [x for x in objects_ids] 36 | 37 | return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[-1], sample_id=index) 38 | -------------------------------------------------------------------------------- /isegm/data/datasets/oai_zib.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from isegm.data.base import ISDataset 7 | from isegm.data.sample import DSample 8 | 9 | 10 | class OAIZIBDataset(ISDataset): 11 | def __init__(self, dataset_path, split='test', 12 | images_dir_name='image', masks_dir_name='annotations', 13 | **kwargs): 14 | super(OAIZIBDataset, self).__init__(**kwargs) 15 | 16 | self.dataset_path = Path(dataset_path) 17 | self._images_path = self.dataset_path / split / images_dir_name 18 | self._insts_path = self.dataset_path / split / masks_dir_name 19 | 20 | self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.png'))] 21 | self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.png')} 22 | 23 | assert len(self.dataset_samples) > 0 24 | 25 | def get_sample(self, index) -> DSample: 26 | image_name = self.dataset_samples[index] 27 | image_path = str(self._images_path / image_name) 28 | mask_path = str(self._masks_paths[image_name.split('.')[0]]) 29 | 30 | image = cv2.imread(image_path) 31 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 32 | instances_mask = cv2.imread(mask_path)[:, :, 0].astype(np.uint8) 33 | 34 | # FC and TC 35 | objects_ids = np.unique(instances_mask) 36 | objects_ids = [x for x in objects_ids if x != 0] 37 | 38 | return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[-1], sample_id=index) 39 | -------------------------------------------------------------------------------- /isegm/data/datasets/openimages.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import pickle as pkl 4 | from pathlib import Path 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | from isegm.data.base import ISDataset 10 | from isegm.data.sample import DSample 11 | 12 | 13 | class OpenImagesDataset(ISDataset): 14 | def __init__(self, dataset_path, split='train', **kwargs): 15 | super().__init__(**kwargs) 16 | assert split in {'train', 'val', 'test'} 17 | 18 | self.dataset_path = Path(dataset_path) 19 | self._split_path = self.dataset_path / split 20 | self._images_path = self._split_path / 'images' 21 | self._masks_path = self._split_path / 'masks' 22 | self.dataset_split = split 23 | 24 | clean_anno_path = self._split_path / f'{split}-annotations-object-segmentation_clean.pkl' 25 | if os.path.exists(clean_anno_path): 26 | with clean_anno_path.open('rb') as f: 27 | annotations = pkl.load(f) 28 | else: 29 | raise RuntimeError(f"Can't 
find annotations at {clean_anno_path}") 30 | self.image_id_to_masks = annotations['image_id_to_masks'] 31 | self.dataset_samples = annotations['dataset_samples'] 32 | 33 | def get_sample(self, index) -> DSample: 34 | image_id = self.dataset_samples[index] 35 | 36 | image_path = str(self._images_path / f'{image_id}.jpg') 37 | image = cv2.imread(image_path) 38 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 39 | 40 | mask_paths = self.image_id_to_masks[image_id] 41 | # select random mask for an image 42 | mask_path = str(self._masks_path / random.choice(mask_paths)) 43 | instances_mask = cv2.imread(mask_path) 44 | instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY) 45 | instances_mask[instances_mask > 0] = 1 46 | instances_mask = instances_mask.astype(np.int32) 47 | 48 | min_width = min(image.shape[1], instances_mask.shape[1]) 49 | min_height = min(image.shape[0], instances_mask.shape[0]) 50 | 51 | if image.shape[0] != min_height or image.shape[1] != min_width: 52 | image = cv2.resize(image, (min_width, min_height), interpolation=cv2.INTER_LINEAR) 53 | if instances_mask.shape[0] != min_height or instances_mask.shape[1] != min_width: 54 | instances_mask = cv2.resize(instances_mask, (min_width, min_height), interpolation=cv2.INTER_NEAREST) 55 | 56 | object_ids = [1] if instances_mask.sum() > 0 else [] 57 | 58 | return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index) 59 | -------------------------------------------------------------------------------- /isegm/data/datasets/pascalvoc.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | from pathlib import Path 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from isegm.data.base import ISDataset 8 | from isegm.data.sample import DSample 9 | 10 | 11 | class PascalVocDataset(ISDataset): 12 | def __init__(self, dataset_path, split='train', **kwargs): 13 | super().__init__(**kwargs) 14 | assert split in {'train', 'val', 'trainval', 'test'} 15 | 16 | self.dataset_path = Path(dataset_path) 17 | self._images_path = self.dataset_path / "JPEGImages" 18 | self._insts_path = self.dataset_path / "SegmentationObject" 19 | self.dataset_split = split 20 | 21 | if split == 'test': 22 | with open(self.dataset_path / f'ImageSets/Segmentation/test.pickle', 'rb') as f: 23 | self.dataset_samples, self.instance_ids = pkl.load(f) 24 | else: 25 | with open(self.dataset_path / f'ImageSets/Segmentation/{split}.txt', 'r') as f: 26 | self.dataset_samples = [name.strip() for name in f.readlines()] 27 | 28 | def get_sample(self, index) -> DSample: 29 | sample_id = self.dataset_samples[index] 30 | image_path = str(self._images_path / f'{sample_id}.jpg') 31 | mask_path = str(self._insts_path / f'{sample_id}.png') 32 | 33 | image = cv2.imread(image_path) 34 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 35 | instances_mask = cv2.imread(mask_path) 36 | instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) 37 | if self.dataset_split == 'test': 38 | instance_id = self.instance_ids[index] 39 | mask = np.zeros_like(instances_mask) 40 | mask[instances_mask == 220] = 220 # ignored area 41 | mask[instances_mask == instance_id] = 1 42 | objects_ids = [1] 43 | instances_mask = mask 44 | else: 45 | objects_ids = np.unique(instances_mask) 46 | objects_ids = [x for x in objects_ids if x != 0 and x != 220] 47 | 48 | return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index) 49 | 
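All of the dataset wrappers above (GrabCut, Berkeley, DAVIS, SBD, COCO-LVIS, the medical sets, and so on) share the ISDataset interface defined in isegm/data/base.py: get_sample() returns a DSample, while __getitem__ additionally samples clicks with a MultiPointSampler and converts everything to tensors. Below is a minimal sketch of driving one of the evaluation datasets through that interface, assuming GrabCut has been extracted to the default GRABCUT_PATH from configs/base_configuration.yaml.

# Minimal sketch: instantiate an evaluation dataset and pull one training-style item.
# Assumes ./datasets/GrabCut contains the default data_GT / boundary_GT folders.
from isegm.data.datasets import GrabCutDataset

dataset = GrabCutDataset('./datasets/GrabCut')
print(len(dataset))             # number of image/mask pairs found in data_GT

item = dataset[0]               # ISDataset.__getitem__: augment, sample clicks, convert to tensors
print(item['images'].shape)     # image tensor produced by torchvision ToTensor
print(item['points'].shape)     # float32 array of sampled click coordinates
print(item['instances'].shape)  # target mask selected by the points sampler

For evaluation, clicks are instead simulated against the ground-truth mask by the Clicker class in isegm/inference/clicker.py, shown further below.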
-------------------------------------------------------------------------------- /isegm/data/datasets/ssTEM.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from isegm.data.base import ISDataset 7 | from isegm.data.sample import DSample 8 | 9 | 10 | class ssTEMDataset(ISDataset): 11 | def __init__(self, dataset_path, 12 | images_dir_name='raw', masks_dir_name='mitochondria', 13 | **kwargs): 14 | super(ssTEMDataset, self).__init__(**kwargs) 15 | 16 | self.dataset_path = Path(dataset_path) 17 | self._images_path = self.dataset_path / images_dir_name 18 | self._insts_path = self.dataset_path / masks_dir_name 19 | 20 | self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.tif'))] 21 | self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.png')} 22 | 23 | def get_sample(self, index) -> DSample: 24 | image_name = self.dataset_samples[index] 25 | image_path = str(self._images_path / image_name) 26 | mask_path = str(self._masks_paths[image_name.split('.')[0]]) 27 | 28 | image = cv2.imread(image_path) 29 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 30 | instances_mask = cv2.imread(mask_path)[:, :, 0].astype(np.uint8) 31 | 32 | connectivity = 4 33 | output = cv2.connectedComponentsWithStats(instances_mask, connectivity) 34 | label_mask = output[1] 35 | objects_ids = np.unique(label_mask) 36 | objects_ids = [x for x in objects_ids if x != 0] 37 | 38 | return DSample(image, label_mask, objects_ids=objects_ids, ignore_ids=[-1], sample_id=index) 39 | -------------------------------------------------------------------------------- /isegm/engine/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from isegm.utils.log import logger 4 | import isegm.utils.lr_decay as lrd 5 | 6 | def get_optimizer(model, opt_name, opt_kwargs): 7 | params = [] 8 | base_lr = opt_kwargs['lr'] 9 | for name, param in model.named_parameters(): 10 | param_group = {'params': [param]} 11 | if not param.requires_grad: 12 | params.append(param_group) 13 | continue 14 | 15 | if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0): 16 | logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.') 17 | param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult 18 | 19 | params.append(param_group) 20 | 21 | optimizer = { 22 | 'sgd': torch.optim.SGD, 23 | 'adam': torch.optim.Adam, 24 | 'adamw': torch.optim.AdamW 25 | }[opt_name.lower()](params, **opt_kwargs) 26 | 27 | return optimizer 28 | 29 | def get_optimizer_with_layerwise_decay(model, opt_name, opt_kwargs): 30 | # build optimizer with layer-wise lr decay (lrd) 31 | lr = opt_kwargs['lr'] 32 | param_groups = lrd.param_groups_lrd(model, lr, weight_decay=0.02, 33 | no_weight_decay_list=model.backbone.no_weight_decay(), 34 | layer_decay=0.75 35 | ) 36 | optimizer = { 37 | 'sgd': torch.optim.SGD, 38 | 'adam': torch.optim.Adam, 39 | 'adamw': torch.optim.AdamW 40 | }[opt_name.lower()](param_groups, **opt_kwargs) 41 | 42 | return optimizer -------------------------------------------------------------------------------- /isegm/inference/clicker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | import cv2 4 | 5 | 6 | class Clicker(object): 7 | def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0): 8 | self.click_indx_offset = 
click_indx_offset 9 | if gt_mask is not None: 10 | self.gt_mask = gt_mask == 1 11 | self.not_ignore_mask = gt_mask != ignore_label 12 | else: 13 | self.gt_mask = None 14 | 15 | self.reset_clicks() 16 | 17 | if init_clicks is not None: 18 | for click in init_clicks: 19 | self.add_click(click) 20 | 21 | def make_next_click(self, pred_mask): 22 | assert self.gt_mask is not None 23 | click = self._get_next_click(pred_mask) 24 | self.add_click(click) 25 | 26 | def get_clicks(self, clicks_limit=None): 27 | return self.clicks_list[:clicks_limit] 28 | 29 | def _get_next_click(self, pred_mask, padding=True): 30 | fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask) 31 | fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask) 32 | 33 | if padding: 34 | fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant') 35 | fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant') 36 | 37 | fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) 38 | fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) 39 | 40 | if padding: 41 | fn_mask_dt = fn_mask_dt[1:-1, 1:-1] 42 | fp_mask_dt = fp_mask_dt[1:-1, 1:-1] 43 | 44 | fn_mask_dt = fn_mask_dt * self.not_clicked_map 45 | fp_mask_dt = fp_mask_dt * self.not_clicked_map 46 | 47 | fn_max_dist = np.max(fn_mask_dt) 48 | fp_max_dist = np.max(fp_mask_dt) 49 | 50 | is_positive = fn_max_dist > fp_max_dist 51 | if is_positive: 52 | coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x] 53 | else: 54 | coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x] 55 | 56 | return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0])) 57 | 58 | def add_click(self, click): 59 | coords = click.coords 60 | 61 | click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks 62 | if click.is_positive: 63 | self.num_pos_clicks += 1 64 | else: 65 | self.num_neg_clicks += 1 66 | 67 | self.clicks_list.append(click) 68 | if self.gt_mask is not None: 69 | self.not_clicked_map[coords[0], coords[1]] = False 70 | 71 | def _remove_last_click(self): 72 | click = self.clicks_list.pop() 73 | coords = click.coords 74 | 75 | if click.is_positive: 76 | self.num_pos_clicks -= 1 77 | else: 78 | self.num_neg_clicks -= 1 79 | 80 | if self.gt_mask is not None: 81 | self.not_clicked_map[coords[0], coords[1]] = True 82 | 83 | def reset_clicks(self): 84 | if self.gt_mask is not None: 85 | self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool) 86 | 87 | self.num_pos_clicks = 0 88 | self.num_neg_clicks = 0 89 | 90 | self.clicks_list = [] 91 | 92 | def get_state(self): 93 | return deepcopy(self.clicks_list) 94 | 95 | def set_state(self, state): 96 | self.reset_clicks() 97 | for click in state: 98 | self.add_click(click) 99 | 100 | def __len__(self): 101 | return len(self.clicks_list) 102 | 103 | 104 | class Click: 105 | def __init__(self, is_positive, coords, indx=None): 106 | self.is_positive = is_positive 107 | self.coords = coords 108 | self.indx = indx 109 | 110 | @property 111 | def coords_and_indx(self): 112 | return (*self.coords, self.indx) 113 | 114 | def copy(self, **kwargs): 115 | self_copy = deepcopy(self) 116 | for k, v in kwargs.items(): 117 | setattr(self_copy, k, v) 118 | return self_copy 119 | -------------------------------------------------------------------------------- /isegm/inference/evaluation.py: -------------------------------------------------------------------------------- 1 
| from time import time 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from isegm.inference import utils 7 | from isegm.inference.clicker import Clicker 8 | 9 | try: 10 | get_ipython() 11 | from tqdm import tqdm_notebook as tqdm 12 | except NameError: 13 | from tqdm import tqdm 14 | 15 | 16 | def evaluate_dataset(dataset, predictor, **kwargs): 17 | all_ious = [] 18 | 19 | start_time = time() 20 | for index in tqdm(range(len(dataset)), leave=False): 21 | sample = dataset.get_sample(index) 22 | 23 | for object_id in sample.objects_ids: 24 | _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask(object_id), predictor, 25 | sample_id=index, **kwargs) 26 | all_ious.append(sample_ious) 27 | end_time = time() 28 | elapsed_time = end_time - start_time 29 | 30 | return all_ious, elapsed_time 31 | 32 | 33 | def evaluate_sample(image, gt_mask, predictor, max_iou_thr, 34 | pred_thr=0.49, min_clicks=1, max_clicks=20, 35 | sample_id=None, callback=None): 36 | clicker = Clicker(gt_mask=gt_mask) 37 | pred_mask = np.zeros_like(gt_mask) 38 | ious_list = [] 39 | 40 | with torch.no_grad(): 41 | predictor.set_input_image(image) 42 | 43 | for click_indx in range(max_clicks): 44 | clicker.make_next_click(pred_mask) 45 | pred_probs = predictor.get_prediction(clicker) 46 | pred_mask = pred_probs > pred_thr 47 | 48 | if callback is not None: 49 | callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) 50 | 51 | iou = utils.get_iou(gt_mask, pred_mask) 52 | ious_list.append(iou) 53 | 54 | if iou >= max_iou_thr and click_indx + 1 >= min_clicks: 55 | break 56 | 57 | return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs 58 | -------------------------------------------------------------------------------- /isegm/inference/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BasePredictor 2 | from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor 3 | from .brs_functors import InputOptimizer, ScaleBiasOptimizer 4 | from isegm.inference.transforms import ZoomIn 5 | from isegm.model.is_hrnet_model import HRNetModel 6 | 7 | 8 | def get_predictor(net, brs_mode, device, 9 | prob_thresh=0.49, 10 | with_flip=True, 11 | zoom_in_params=dict(), 12 | predictor_params=None, 13 | brs_opt_func_params=None, 14 | lbfgs_params=None): 15 | lbfgs_params_ = { 16 | 'm': 20, 17 | 'factr': 0, 18 | 'pgtol': 1e-8, 19 | 'maxfun': 20, 20 | } 21 | 22 | predictor_params_ = { 23 | 'optimize_after_n_clicks': 1 24 | } 25 | 26 | if zoom_in_params is not None: 27 | zoom_in = ZoomIn(**zoom_in_params) 28 | else: 29 | zoom_in = None 30 | 31 | if lbfgs_params is not None: 32 | lbfgs_params_.update(lbfgs_params) 33 | lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun'] 34 | 35 | if brs_opt_func_params is None: 36 | brs_opt_func_params = dict() 37 | 38 | if isinstance(net, (list, tuple)): 39 | assert brs_mode == 'NoBRS', "Multi-stage models support only NoBRS mode." 
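    # The dispatch below selects the predictor implementation from brs_mode: 'NoBRS' uses the
    # plain BasePredictor, 'f-BRS-A/B/C' attach a feature-level BRS optimizer at different
    # insertion points, and 'RGB-BRS' / 'DistMap-BRS' optimize the network inputs directly.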
40 | 41 | if brs_mode == 'NoBRS': 42 | if predictor_params is not None: 43 | predictor_params_.update(predictor_params) 44 | predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_) 45 | elif brs_mode.startswith('f-BRS'): 46 | predictor_params_.update({ 47 | 'net_clicks_limit': 8, 48 | }) 49 | if predictor_params is not None: 50 | predictor_params_.update(predictor_params) 51 | 52 | insertion_mode = { 53 | 'f-BRS-A': 'after_c4', 54 | 'f-BRS-B': 'after_aspp', 55 | 'f-BRS-C': 'after_deeplab' 56 | }[brs_mode] 57 | 58 | opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh, 59 | with_flip=with_flip, 60 | optimizer_params=lbfgs_params_, 61 | **brs_opt_func_params) 62 | 63 | if isinstance(net, HRNetModel): 64 | FeaturePredictor = HRNetFeatureBRSPredictor 65 | insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode] 66 | else: 67 | FeaturePredictor = FeatureBRSPredictor 68 | 69 | predictor = FeaturePredictor(net, device, 70 | opt_functor=opt_functor, 71 | with_flip=with_flip, 72 | insertion_mode=insertion_mode, 73 | zoom_in=zoom_in, 74 | **predictor_params_) 75 | elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS': 76 | use_dmaps = brs_mode == 'DistMap-BRS' 77 | 78 | predictor_params_.update({ 79 | 'net_clicks_limit': 5, 80 | }) 81 | if predictor_params is not None: 82 | predictor_params_.update(predictor_params) 83 | 84 | opt_functor = InputOptimizer(prob_thresh=prob_thresh, 85 | with_flip=with_flip, 86 | optimizer_params=lbfgs_params_, 87 | **brs_opt_func_params) 88 | 89 | predictor = InputBRSPredictor(net, device, 90 | optimize_target='dmaps' if use_dmaps else 'rgb', 91 | opt_functor=opt_functor, 92 | with_flip=with_flip, 93 | zoom_in=zoom_in, 94 | **predictor_params_) 95 | else: 96 | raise NotImplementedError 97 | 98 | return predictor 99 | -------------------------------------------------------------------------------- /isegm/inference/predictors/brs_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from isegm.model.losses import SigmoidBinaryCrossEntropyLoss 4 | 5 | 6 | class BRSMaskLoss(torch.nn.Module): 7 | def __init__(self, eps=1e-5): 8 | super().__init__() 9 | self._eps = eps 10 | 11 | def forward(self, result, pos_mask, neg_mask): 12 | pos_diff = (1 - result) * pos_mask 13 | pos_target = torch.sum(pos_diff ** 2) 14 | pos_target = pos_target / (torch.sum(pos_mask) + self._eps) 15 | 16 | neg_diff = result * neg_mask 17 | neg_target = torch.sum(neg_diff ** 2) 18 | neg_target = neg_target / (torch.sum(neg_mask) + self._eps) 19 | 20 | loss = pos_target + neg_target 21 | 22 | with torch.no_grad(): 23 | f_max_pos = torch.max(torch.abs(pos_diff)).item() 24 | f_max_neg = torch.max(torch.abs(neg_diff)).item() 25 | 26 | return loss, f_max_pos, f_max_neg 27 | 28 | 29 | class OracleMaskLoss(torch.nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | self.gt_mask = None 33 | self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True) 34 | self.predictor = None 35 | self.history = [] 36 | 37 | def set_gt_mask(self, gt_mask): 38 | self.gt_mask = gt_mask 39 | self.history = [] 40 | 41 | def forward(self, result, pos_mask, neg_mask): 42 | gt_mask = self.gt_mask.to(result.device) 43 | if self.predictor.object_roi is not None: 44 | r1, r2, c1, c2 = self.predictor.object_roi[:4] 45 | gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1] 46 | gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True) 47 | 
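        # With test-time horizontal flip the batch contains the original and the flipped view,
        # so the ground-truth mask is duplicated and flipped to match the prediction batch.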
48 | if result.shape[0] == 2: 49 | gt_mask_flipped = torch.flip(gt_mask, dims=[3]) 50 | gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0) 51 | 52 | loss = self.loss(result, gt_mask) 53 | self.history.append(loss.detach().cpu().numpy()[0]) 54 | 55 | if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5: 56 | return 0, 0, 0 57 | 58 | return loss, 1.0, 1.0 59 | -------------------------------------------------------------------------------- /isegm/inference/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SigmoidForPred 2 | from .flip import AddHorizontalFlip 3 | from .zoom_in import ZoomIn 4 | from .limit_longest_side import LimitLongestSide 5 | from .crops import Crops 6 | -------------------------------------------------------------------------------- /isegm/inference/transforms/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseTransform(object): 5 | def __init__(self): 6 | self.image_changed = False 7 | 8 | def transform(self, image_nd, clicks_lists): 9 | raise NotImplementedError 10 | 11 | def inv_transform(self, prob_map): 12 | raise NotImplementedError 13 | 14 | def reset(self): 15 | raise NotImplementedError 16 | 17 | def get_state(self): 18 | raise NotImplementedError 19 | 20 | def set_state(self, state): 21 | raise NotImplementedError 22 | 23 | 24 | class SigmoidForPred(BaseTransform): 25 | def transform(self, image_nd, clicks_lists): 26 | return image_nd, clicks_lists 27 | 28 | def inv_transform(self, prob_map): 29 | return torch.sigmoid(prob_map) 30 | 31 | def reset(self): 32 | pass 33 | 34 | def get_state(self): 35 | return None 36 | 37 | def set_state(self, state): 38 | pass 39 | -------------------------------------------------------------------------------- /isegm/inference/transforms/crops.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import numpy as np 5 | from typing import List 6 | 7 | from isegm.inference.clicker import Click 8 | from .base import BaseTransform 9 | 10 | 11 | class Crops(BaseTransform): 12 | def __init__(self, crop_size=(320, 480), min_overlap=0.2): 13 | super().__init__() 14 | self.crop_height, self.crop_width = crop_size 15 | self.min_overlap = min_overlap 16 | 17 | self.x_offsets = None 18 | self.y_offsets = None 19 | self._counts = None 20 | 21 | def transform(self, image_nd, clicks_lists: List[List[Click]]): 22 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 23 | image_height, image_width = image_nd.shape[2:4] 24 | self._counts = None 25 | 26 | if image_height < self.crop_height or image_width < self.crop_width: 27 | return image_nd, clicks_lists 28 | 29 | self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap) 30 | self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap) 31 | self._counts = np.zeros((image_height, image_width)) 32 | 33 | image_crops = [] 34 | for dy in self.y_offsets: 35 | for dx in self.x_offsets: 36 | self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1 37 | image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width] 38 | image_crops.append(image_crop) 39 | image_crops = torch.cat(image_crops, dim=0) 40 | self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32) 41 | 42 | clicks_list = clicks_lists[0] 43 | clicks_lists = [] 44 | for dy in self.y_offsets: 
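            # Shift every click into the local coordinate frame of the current crop so the
            # same click list can be reused for each sliding-window position.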
45 | for dx in self.x_offsets: 46 | crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list] 47 | clicks_lists.append(crop_clicks) 48 | 49 | return image_crops, clicks_lists 50 | 51 | def inv_transform(self, prob_map): 52 | if self._counts is None: 53 | return prob_map 54 | 55 | new_prob_map = torch.zeros((1, 1, *self._counts.shape), 56 | dtype=prob_map.dtype, device=prob_map.device) 57 | 58 | crop_indx = 0 59 | for dy in self.y_offsets: 60 | for dx in self.x_offsets: 61 | new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0] 62 | crop_indx += 1 63 | new_prob_map = torch.div(new_prob_map, self._counts) 64 | 65 | return new_prob_map 66 | 67 | def get_state(self): 68 | return self.x_offsets, self.y_offsets, self._counts 69 | 70 | def set_state(self, state): 71 | self.x_offsets, self.y_offsets, self._counts = state 72 | 73 | def reset(self): 74 | self.x_offsets = None 75 | self.y_offsets = None 76 | self._counts = None 77 | 78 | 79 | def get_offsets(length, crop_size, min_overlap_ratio=0.2): 80 | if length == crop_size: 81 | return [0] 82 | 83 | N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio) 84 | N = math.ceil(N) 85 | 86 | overlap_ratio = (N - length / crop_size) / (N - 1) 87 | overlap_width = int(crop_size * overlap_ratio) 88 | 89 | offsets = [0] 90 | for i in range(1, N): 91 | new_offset = offsets[-1] + crop_size - overlap_width 92 | if new_offset + crop_size > length: 93 | new_offset = length - crop_size 94 | 95 | offsets.append(new_offset) 96 | 97 | return offsets 98 | -------------------------------------------------------------------------------- /isegm/inference/transforms/flip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from typing import List 4 | from isegm.inference.clicker import Click 5 | from .base import BaseTransform 6 | 7 | 8 | class AddHorizontalFlip(BaseTransform): 9 | def transform(self, image_nd, clicks_lists: List[List[Click]]): 10 | assert len(image_nd.shape) == 4 11 | image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0) 12 | 13 | image_width = image_nd.shape[3] 14 | clicks_lists_flipped = [] 15 | for clicks_list in clicks_lists: 16 | clicks_list_flipped = [click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1)) 17 | for click in clicks_list] 18 | clicks_lists_flipped.append(clicks_list_flipped) 19 | clicks_lists = clicks_lists + clicks_lists_flipped 20 | 21 | return image_nd, clicks_lists 22 | 23 | def inv_transform(self, prob_map): 24 | assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0 25 | num_maps = prob_map.shape[0] // 2 26 | prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:] 27 | 28 | return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3])) 29 | 30 | def get_state(self): 31 | return None 32 | 33 | def set_state(self, state): 34 | pass 35 | 36 | def reset(self): 37 | pass 38 | -------------------------------------------------------------------------------- /isegm/inference/transforms/limit_longest_side.py: -------------------------------------------------------------------------------- 1 | from .zoom_in import ZoomIn, get_roi_image_nd 2 | 3 | 4 | class LimitLongestSide(ZoomIn): 5 | def __init__(self, max_size=800): 6 | super().__init__(target_size=max_size, skip_clicks=0) 7 | 8 | def transform(self, image_nd, clicks_lists): 9 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 10 | image_max_size = max(image_nd.shape[2:4]) 
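        # If the longest image side already fits within target_size the transform is a no-op;
        # otherwise the whole image is treated as the ROI and downscaled to target_size below.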
11 | self.image_changed = False 12 | 13 | if image_max_size <= self.target_size: 14 | return image_nd, clicks_lists 15 | self._input_image = image_nd 16 | 17 | self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1) 18 | self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) 19 | self.image_changed = True 20 | 21 | tclicks_lists = [self._transform_clicks(clicks_lists[0])] 22 | return self._roi_image, tclicks_lists 23 | -------------------------------------------------------------------------------- /isegm/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab206/AdaptiveClick/e5e7344e70c6992f095cbd78db3ddf2c7969252c/isegm/model/__init__.py -------------------------------------------------------------------------------- /isegm/model/initializer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class Initializer(object): 7 | def __init__(self, local_init=True, gamma=None): 8 | self.local_init = local_init 9 | self.gamma = gamma 10 | 11 | def __call__(self, m): 12 | if getattr(m, '__initialized', False): 13 | return 14 | 15 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, 16 | nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, 17 | nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__: 18 | if m.weight is not None: 19 | self._init_gamma(m.weight.data) 20 | if m.bias is not None: 21 | self._init_beta(m.bias.data) 22 | else: 23 | if getattr(m, 'weight', None) is not None: 24 | self._init_weight(m.weight.data) 25 | if getattr(m, 'bias', None) is not None: 26 | self._init_bias(m.bias.data) 27 | 28 | if self.local_init: 29 | object.__setattr__(m, '__initialized', True) 30 | 31 | def _init_weight(self, data): 32 | nn.init.uniform_(data, -0.07, 0.07) 33 | 34 | def _init_bias(self, data): 35 | nn.init.constant_(data, 0) 36 | 37 | def _init_gamma(self, data): 38 | if self.gamma is None: 39 | nn.init.constant_(data, 1.0) 40 | else: 41 | nn.init.normal_(data, 1.0, self.gamma) 42 | 43 | def _init_beta(self, data): 44 | nn.init.constant_(data, 0) 45 | 46 | 47 | class Bilinear(Initializer): 48 | def __init__(self, scale, groups, in_channels, **kwargs): 49 | super().__init__(**kwargs) 50 | self.scale = scale 51 | self.groups = groups 52 | self.in_channels = in_channels 53 | 54 | def _init_weight(self, data): 55 | """Reset the weight and bias.""" 56 | bilinear_kernel = self.get_bilinear_kernel(self.scale) 57 | weight = torch.zeros_like(data) 58 | for i in range(self.in_channels): 59 | if self.groups == 1: 60 | j = i 61 | else: 62 | j = 0 63 | weight[i, j] = bilinear_kernel 64 | data[:] = weight 65 | 66 | @staticmethod 67 | def get_bilinear_kernel(scale): 68 | """Generate a bilinear upsampling kernel.""" 69 | kernel_size = 2 * scale - scale % 2 70 | scale = (kernel_size + 1) // 2 71 | center = scale - 0.5 * (1 + kernel_size % 2) 72 | 73 | og = np.ogrid[:kernel_size, :kernel_size] 74 | kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale) 75 | 76 | return torch.tensor(kernel, dtype=torch.float32) 77 | 78 | 79 | class XavierGluon(Initializer): 80 | def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs): 81 | super().__init__(**kwargs) 82 | 83 | self.rnd_type = rnd_type 84 | self.factor_type = factor_type 85 | self.magnitude = float(magnitude) 86 | 87 | def _init_weight(self, 
arr): 88 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr) 89 | 90 | if self.factor_type == 'avg': 91 | factor = (fan_in + fan_out) / 2.0 92 | elif self.factor_type == 'in': 93 | factor = fan_in 94 | elif self.factor_type == 'out': 95 | factor = fan_out 96 | else: 97 | raise ValueError('Incorrect factor type') 98 | scale = np.sqrt(self.magnitude / factor) 99 | 100 | if self.rnd_type == 'uniform': 101 | nn.init.uniform_(arr, -scale, scale) 102 | elif self.rnd_type == 'gaussian': 103 | nn.init.normal_(arr, 0, scale) 104 | else: 105 | raise ValueError('Unknown random type') 106 | -------------------------------------------------------------------------------- /isegm/model/is_deeplab_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from isegm.utils.serialization import serialize 4 | from .is_model import ISModel 5 | from .modeling.deeplab_v3 import DeepLabV3Plus 6 | from .modeling.basic_blocks import SepConvHead 7 | from isegm.model.modifiers import LRMult 8 | 9 | 10 | class DeeplabModel(ISModel): 11 | @serialize 12 | def __init__(self, backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5, 13 | backbone_norm_layer=None, backbone_lr_mult=0.1, norm_layer=nn.BatchNorm2d, **kwargs): 14 | super().__init__(norm_layer=norm_layer, **kwargs) 15 | 16 | self.feature_extractor = DeepLabV3Plus(backbone=backbone, ch=deeplab_ch, project_dropout=aspp_dropout, 17 | norm_layer=norm_layer, backbone_norm_layer=backbone_norm_layer) 18 | self.feature_extractor.backbone.apply(LRMult(backbone_lr_mult)) 19 | self.head = SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2, 20 | num_layers=2, norm_layer=norm_layer) 21 | 22 | def backbone_forward(self, image, coord_features=None): 23 | backbone_features = self.feature_extractor(image, coord_features) 24 | 25 | return {'instances': self.head(backbone_features[0])} 26 | -------------------------------------------------------------------------------- /isegm/model/is_hrformer_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from collections import OrderedDict 5 | 6 | from isegm.utils.serialization import serialize 7 | from .is_model import ISModel 8 | from isegm.model.modifiers import LRMult 9 | from .modeling.hrformer import HRT_B_OCR_V3 10 | 11 | class HRFormerModel(ISModel): 12 | @serialize 13 | def __init__( 14 | self, 15 | num_classes=1, 16 | in_ch=6, 17 | backbone_lr_mult=0.1, 18 | **kwargs 19 | ): 20 | 21 | super().__init__(**kwargs) 22 | 23 | self.feature_extractor = HRT_B_OCR_V3(num_classes, in_ch) 24 | self.feature_extractor.apply(LRMult(backbone_lr_mult)) 25 | 26 | def backbone_forward(self, image, coord_features=None): 27 | backbone_features = self.feature_extractor(image) 28 | return {'instances': backbone_features[0], 'instances_aux': backbone_features[1]} 29 | 30 | def init_weight(self, pretrained=None): 31 | if pretrained is not None: 32 | state_dict = torch.load(pretrained)['model'] 33 | state_dict_rename = OrderedDict() 34 | for k, v in state_dict.items(): 35 | state_dict_rename['backbone.' 
+ k] = v 36 | 37 | ori_proj_weight = state_dict_rename['backbone.conv1.weight'] 38 | state_dict_rename['backbone.conv1.weight'] = torch.cat([ori_proj_weight, ori_proj_weight], dim=1) 39 | 40 | self.feature_extractor.load_state_dict(state_dict_rename, False) 41 | print('Successfully loaded pretrained model.') 42 | -------------------------------------------------------------------------------- /isegm/model/is_hrnet_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from isegm.utils.serialization import serialize 4 | from .is_model import ISModel 5 | from .modeling.hrnet_ocr import HighResolutionNet 6 | from isegm.model.modifiers import LRMult 7 | 8 | 9 | class HRNetModel(ISModel): 10 | @serialize 11 | def __init__(self, width=48, ocr_width=256, small=False, backbone_lr_mult=0.1, 12 | norm_layer=nn.BatchNorm2d, **kwargs): 13 | super().__init__(**kwargs) 14 | 15 | self.feature_extractor = HighResolutionNet(width=width, ocr_width=ocr_width, small=small, 16 | num_classes=1, norm_layer=norm_layer) 17 | self.feature_extractor.apply(LRMult(backbone_lr_mult)) 18 | if ocr_width > 0: 19 | self.feature_extractor.ocr_distri_head.apply(LRMult(1.0)) 20 | self.feature_extractor.ocr_gather_head.apply(LRMult(1.0)) 21 | self.feature_extractor.conv3x3_ocr.apply(LRMult(1.0)) 22 | 23 | def backbone_forward(self, image, coord_features=None): 24 | net_outputs = self.feature_extractor(image, coord_features) 25 | 26 | return {'instances': net_outputs[0], 'instances_aux': net_outputs[1]} 27 | -------------------------------------------------------------------------------- /isegm/model/is_plainvit_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | from isegm.utils.serialization import serialize 4 | from .is_model import ISModel 5 | from .modeling.models_vit import VisionTransformer, PatchEmbed 6 | from .modeling.swin_transformer import SwinTransfomerSegHead 7 | 8 | 9 | class SimpleFPN(nn.Module): 10 | def __init__(self, in_dim=768, out_dims=[128, 256, 512, 1024]): 11 | super().__init__() 12 | self.down_4_chan = max(out_dims[0]*2, in_dim // 2) 13 | self.down_4 = nn.Sequential( 14 | nn.ConvTranspose2d(in_dim, self.down_4_chan, 2, stride=2), 15 | nn.GroupNorm(1, self.down_4_chan), 16 | nn.GELU(), 17 | nn.ConvTranspose2d(self.down_4_chan, self.down_4_chan // 2, 2, stride=2), 18 | nn.GroupNorm(1, self.down_4_chan // 2), 19 | nn.Conv2d(self.down_4_chan // 2, out_dims[0], 1), 20 | nn.GroupNorm(1, out_dims[0]), 21 | nn.GELU() 22 | ) 23 | self.down_8_chan = max(out_dims[1], in_dim // 2) 24 | self.down_8 = nn.Sequential( 25 | nn.ConvTranspose2d(in_dim, self.down_8_chan, 2, stride=2), 26 | nn.GroupNorm(1, self.down_8_chan), 27 | nn.Conv2d(self.down_8_chan, out_dims[1], 1), 28 | nn.GroupNorm(1, out_dims[1]), 29 | nn.GELU() 30 | ) 31 | self.down_16 = nn.Sequential( 32 | nn.Conv2d(in_dim, out_dims[2], 1), 33 | nn.GroupNorm(1, out_dims[2]), 34 | nn.GELU() 35 | ) 36 | self.down_32_chan = max(out_dims[3], in_dim * 2) 37 | self.down_32 = nn.Sequential( 38 | nn.Conv2d(in_dim, self.down_32_chan, 2, stride=2), 39 | nn.GroupNorm(1, self.down_32_chan), 40 | nn.Conv2d(self.down_32_chan, out_dims[3], 1), 41 | nn.GroupNorm(1, out_dims[3]), 42 | nn.GELU() 43 | ) 44 | 45 | self.init_weights() 46 | 47 | def init_weights(self): 48 | pass 49 | 50 | def forward(self, x): 51 | x_down_4 = self.down_4(x) 52 | x_down_8 = self.down_8(x) 53 | x_down_16 = self.down_16(x) 54 | x_down_32 = 
self.down_32(x) 55 | 56 | return [x_down_4, x_down_8, x_down_16, x_down_32] 57 | 58 | 59 | class PlainVitModel(ISModel): 60 | @serialize 61 | def __init__( 62 | self, 63 | backbone_params={}, 64 | neck_params={}, 65 | head_params={}, 66 | random_split=False, 67 | **kwargs 68 | ): 69 | 70 | super().__init__(**kwargs) 71 | self.random_split = random_split 72 | 73 | self.patch_embed_coords = PatchEmbed( 74 | img_size= backbone_params['img_size'], 75 | patch_size=backbone_params['patch_size'], 76 | in_chans=3 if self.with_prev_mask else 2, 77 | embed_dim=backbone_params['embed_dim'], 78 | ) 79 | 80 | self.backbone = VisionTransformer(**backbone_params) 81 | self.neck = SimpleFPN(**neck_params) 82 | self.head = SwinTransfomerSegHead(**head_params) 83 | 84 | def backbone_forward(self, image, coord_features=None): 85 | coord_features = self.patch_embed_coords(coord_features) 86 | backbone_features = self.backbone.forward_backbone(image, coord_features, self.random_split) 87 | 88 | # Extract 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 89 | B, N, C = backbone_features.shape 90 | grid_size = self.backbone.patch_embed.grid_size 91 | 92 | backbone_features = backbone_features.transpose(-1,-2).view(B, C, grid_size[0], grid_size[1]) 93 | multi_scale_features = self.neck(backbone_features) 94 | 95 | return {'instances': self.head(multi_scale_features), 'instances_aux': None} 96 | -------------------------------------------------------------------------------- /isegm/model/is_segformer_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from isegm.utils.serialization import serialize 4 | from .is_model import ISModel 5 | from isegm.model.modifiers import LRMult 6 | from .modeling.segformer import MixVisionTransformer, SegformerHead 7 | 8 | 9 | class SegformerModel(ISModel): 10 | @serialize 11 | def __init__( 12 | self, 13 | backbone_params=None, 14 | decode_head_params=None, 15 | backbone_lr_mult=0.1, 16 | **kwargs 17 | ): 18 | 19 | super().__init__(**kwargs) 20 | 21 | self.feature_extractor = MixVisionTransformer(**backbone_params) 22 | self.feature_extractor.apply(LRMult(backbone_lr_mult)) 23 | 24 | self.head = SegformerHead(**decode_head_params) 25 | 26 | def backbone_forward(self, image, coord_features=None): 27 | backbone_features = self.feature_extractor(image, coord_features) 28 | return {'instances': self.head(backbone_features), 'instances_aux': None} 29 | -------------------------------------------------------------------------------- /isegm/model/is_swinformer_model.py: -------------------------------------------------------------------------------- 1 | from isegm.utils.serialization import serialize 2 | from .is_model import ISModel 3 | from .modeling.swin_transformer import SwinTransformer, SwinTransfomerSegHead 4 | 5 | class SwinformerModel(ISModel): 6 | @serialize 7 | def __init__( 8 | self, 9 | backbone_params={}, 10 | head_params={}, 11 | **kwargs 12 | ): 13 | 14 | super().__init__(**kwargs) 15 | 16 | self.backbone = SwinTransformer(**backbone_params) 17 | self.head = SwinTransfomerSegHead(**head_params) 18 | 19 | def backbone_forward(self, image, coord_features=None): 20 | backbone_features = self.backbone(image, coord_features) 21 | return {'instances': self.head(backbone_features), 'instances_aux': None} -------------------------------------------------------------------------------- /isegm/model/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| import numpy as np 3 | 4 | from isegm.utils import misc 5 | 6 | 7 | class TrainMetric(object): 8 | def __init__(self, pred_outputs, gt_outputs): 9 | self.pred_outputs = pred_outputs 10 | self.gt_outputs = gt_outputs 11 | 12 | def update(self, *args, **kwargs): 13 | raise NotImplementedError 14 | 15 | def get_epoch_value(self): 16 | raise NotImplementedError 17 | 18 | def reset_epoch_stats(self): 19 | raise NotImplementedError 20 | 21 | def log_states(self, sw, tag_prefix, global_step): 22 | pass 23 | 24 | @property 25 | def name(self): 26 | return type(self).__name__ 27 | 28 | 29 | class AdaptiveIoU(TrainMetric): 30 | def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, 31 | ignore_label=-1, from_logits=True, 32 | pred_output='instances', gt_output='instances'): 33 | super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,)) 34 | self._ignore_label = ignore_label 35 | self._from_logits = from_logits 36 | self._iou_thresh = init_thresh 37 | self._thresh_step = thresh_step 38 | self._thresh_beta = thresh_beta 39 | self._iou_beta = iou_beta 40 | self._ema_iou = 0.0 41 | self._epoch_iou_sum = 0.0 42 | self._epoch_batch_count = 0 43 | 44 | def update(self, pred, gt): 45 | gt_mask = gt > 0.5 46 | if self._from_logits: 47 | pred = torch.sigmoid(pred) 48 | 49 | gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy() 50 | if np.all(gt_mask_area == 0): 51 | return 52 | 53 | ignore_mask = gt == self._ignore_label 54 | max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean() 55 | best_thresh = self._iou_thresh 56 | for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]: 57 | temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean() 58 | if temp_iou > max_iou: 59 | max_iou = temp_iou 60 | best_thresh = t 61 | 62 | self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh 63 | self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou 64 | self._epoch_iou_sum += max_iou 65 | self._epoch_batch_count += 1 66 | 67 | def get_epoch_value(self): 68 | if self._epoch_batch_count > 0: 69 | return self._epoch_iou_sum / self._epoch_batch_count 70 | else: 71 | return 0.0 72 | 73 | def reset_epoch_stats(self): 74 | self._epoch_iou_sum = 0.0 75 | self._epoch_batch_count = 0 76 | 77 | def log_states(self, sw, tag_prefix, global_step): 78 | sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step) 79 | sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step) 80 | 81 | @property 82 | def iou_thresh(self): 83 | return self._iou_thresh 84 | 85 | 86 | def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False): 87 | if ignore_mask is not None: 88 | pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask) 89 | 90 | reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0) 91 | union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() 92 | intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() 93 | nonzero = union > 0 94 | 95 | iou = intersection[nonzero] / union[nonzero] 96 | if not keep_ignore: 97 | return iou 98 | else: 99 | result = np.full_like(intersection, -1) 100 | result[nonzero] = iou 101 | return result 102 | -------------------------------------------------------------------------------- /isegm/model/modeling/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab206/AdaptiveClick/e5e7344e70c6992f095cbd78db3ddf2c7969252c/isegm/model/modeling/__init__.py -------------------------------------------------------------------------------- /isegm/model/modeling/basic_blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from isegm.model import ops 4 | 5 | 6 | class ConvHead(nn.Module): 7 | def __init__(self, out_channels, in_channels=32, num_layers=1, 8 | kernel_size=3, padding=1, 9 | norm_layer=nn.BatchNorm2d): 10 | super(ConvHead, self).__init__() 11 | convhead = [] 12 | 13 | for i in range(num_layers): 14 | convhead.extend([ 15 | nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), 16 | nn.ReLU(), 17 | norm_layer(in_channels) if norm_layer is not None else nn.Identity() 18 | ]) 19 | convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0)) 20 | 21 | self.convhead = nn.Sequential(*convhead) 22 | 23 | def forward(self, *inputs): 24 | return self.convhead(inputs[0]) 25 | 26 | 27 | class SepConvHead(nn.Module): 28 | def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1, 29 | kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0, 30 | norm_layer=nn.BatchNorm2d): 31 | super(SepConvHead, self).__init__() 32 | 33 | sepconvhead = [] 34 | 35 | for i in range(num_layers): 36 | sepconvhead.append( 37 | SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels, 38 | out_channels=mid_channels, 39 | dw_kernel=kernel_size, dw_padding=padding, 40 | norm_layer=norm_layer, activation='relu') 41 | ) 42 | if dropout_ratio > 0 and dropout_indx == i: 43 | sepconvhead.append(nn.Dropout(dropout_ratio)) 44 | 45 | sepconvhead.append( 46 | nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0) 47 | ) 48 | 49 | self.layers = nn.Sequential(*sepconvhead) 50 | 51 | def forward(self, *inputs): 52 | x = inputs[0] 53 | 54 | return self.layers(x) 55 | 56 | 57 | class SeparableConv2d(nn.Module): 58 | def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1, 59 | activation=None, use_bias=False, norm_layer=None): 60 | super(SeparableConv2d, self).__init__() 61 | _activation = ops.select_activation_function(activation) 62 | self.body = nn.Sequential( 63 | nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride, 64 | padding=dw_padding, bias=use_bias, groups=in_channels), 65 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias), 66 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(), 67 | _activation() 68 | ) 69 | 70 | def forward(self, x): 71 | return self.body(x) 72 | -------------------------------------------------------------------------------- /isegm/model/modeling/hrformer_helper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab206/AdaptiveClick/e5e7344e70c6992f095cbd78db3ddf2c7969252c/isegm/model/modeling/hrformer_helper/__init__.py -------------------------------------------------------------------------------- /isegm/model/modeling/hrformer_helper/backbone_selector.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Donny You, RainbowSecret 3 | ## Microsoft Research 4 | ## yuyua@microsoft.com 5 | ## Copyright 
(c) 2019 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | # from lib.models.backbones.resnet.resnet_backbone import ResNetBackbone 17 | # from lib.models.backbones.hrnet.hrnet_backbone import HRNetBackbone 18 | from .hrt.hrt_backbone import HRTBackbone 19 | # from lib.models.backbones.swin.swin_backbone import SwinTransformerBackbone 20 | from .hrt.logger import Logger as Log 21 | 22 | 23 | class BackboneSelector(object): 24 | def __init__(self, configer): 25 | self.configer = configer 26 | 27 | def get_backbone(self, **params): 28 | backbone = self.configer.get("network", "backbone") 29 | 30 | model = None 31 | # if ( 32 | # "resnet" in backbone or "resnext" in backbone or "resnest" in backbone 33 | # ) and "senet" not in backbone: 34 | # model = ResNetBackbone(self.configer)(**params) 35 | 36 | if "hrt" in backbone: 37 | # model = HRTBackbone(self.configer)(**params) 38 | pass 39 | 40 | # elif "hrnet" in backbone: 41 | # model = HRNetBackbone(self.configer)(**params) 42 | 43 | # elif "swin" in backbone: 44 | # model = SwinTransformerBackbone(self.configer)(**params) 45 | 46 | else: 47 | Log.error("Backbone {} is invalid.".format(backbone)) 48 | exit(1) 49 | 50 | return model 51 | 52 | class Test(): 53 | def __init__(): 54 | pass -------------------------------------------------------------------------------- /isegm/model/modeling/hrformer_helper/hrt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab206/AdaptiveClick/e5e7344e70c6992f095cbd78db3ddf2c7969252c/isegm/model/modeling/hrformer_helper/hrt/__init__.py -------------------------------------------------------------------------------- /isegm/model/modeling/hrformer_helper/hrt/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab206/AdaptiveClick/e5e7344e70c6992f095cbd78db3ddf2c7969252c/isegm/model/modeling/hrformer_helper/hrt/modules/__init__.py -------------------------------------------------------------------------------- /isegm/model/modeling/hrformer_helper/hrt/modules/bottleneck_block.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import logging 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | # from torchvision.models.utils import load_state_dict_from_url 7 | # from timm.models.registry import register_model 8 | from functools import partial 9 | 10 | BN_MOMENTUM = 0.1 11 | 12 | 13 | class Bottleneck(nn.Module): 14 | expansion = 4 15 | 16 | def __init__( 17 | self, 18 | inplanes, 19 | planes, 20 | stride=1, 21 | downsample=None, 22 | mhsa_flag=False, 23 | num_heads=1, 24 | num_halo_block=1, 25 | num_mlp_ratio=4, 26 | num_sr_ratio=1, 27 | num_resolution=None, 28 | with_rpe=False, 29 | with_ffn=True, 30 | ): 31 | super(Bottleneck, self).__init__() 32 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 33 | self.bn1 = nn.SyncBatchNorm(planes) 34 | self.conv2 = nn.Conv2d( 35 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 36 | ) 37 | self.bn2 = nn.SyncBatchNorm(planes) 38 | self.conv3 = nn.Conv2d( 39 | planes, planes * 
self.expansion, kernel_size=1, bias=False 40 | ) 41 | self.bn3 = nn.SyncBatchNorm(planes * self.expansion) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv3(out) 58 | out = self.bn3(out) 59 | 60 | if self.downsample is not None: 61 | residual = self.downsample(x) 62 | 63 | out += residual 64 | out = self.relu(out) 65 | 66 | return out 67 | 68 | 69 | class BottleneckDWP(nn.Module): 70 | expansion = 4 71 | 72 | def __init__( 73 | self, 74 | inplanes, 75 | planes, 76 | stride=1, 77 | downsample=None, 78 | mhsa_flag=False, 79 | num_heads=1, 80 | num_halo_block=1, 81 | num_mlp_ratio=4, 82 | num_sr_ratio=1, 83 | num_resolution=None, 84 | with_rpe=False, 85 | with_ffn=True, 86 | ): 87 | super(BottleneckDWP, self).__init__() 88 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 89 | self.bn1 = nn.SyncBatchNorm(planes, momentum=BN_MOMENTUM) 90 | self.conv2 = nn.Conv2d( 91 | planes, 92 | planes, 93 | kernel_size=3, 94 | stride=stride, 95 | padding=1, 96 | bias=False, 97 | groups=planes, 98 | ) 99 | self.bn2 = nn.SyncBatchNorm(planes, momentum=BN_MOMENTUM) 100 | self.conv3 = nn.Conv2d( 101 | planes, planes * self.expansion, kernel_size=1, bias=False 102 | ) 103 | self.bn3 = nn.SyncBatchNorm(planes * self.expansion, momentum=BN_MOMENTUM) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.downsample = downsample 106 | self.stride = stride 107 | 108 | def forward(self, x): 109 | residual = x 110 | 111 | out = self.conv1(x) 112 | out = self.bn1(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv2(out) 116 | out = self.bn2(out) 117 | out = self.relu(out) 118 | 119 | out = self.conv3(out) 120 | out = self.bn3(out) 121 | 122 | if self.downsample is not None: 123 | residual = self.downsample(x) 124 | 125 | out += residual 126 | out = self.relu(out) 127 | 128 | return out -------------------------------------------------------------------------------- /isegm/model/modeling/hrformer_helper/hrt/modules/multihead_isa_pool_attention.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .multihead_isa_attention import MHA_, PadBlock, LocalPermuteModule 8 | 9 | class InterlacedPoolAttention(nn.Module): 10 | r""" interlaced sparse multi-head self attention (ISA) module with relative position bias. 11 | Args: 12 | dim (int): Number of input channels. 13 | window_size (tuple[int]): Window size. 14 | num_heads (int): Number of attention heads. 15 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 16 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 17 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 18 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 19 | """ 20 | def __init__(self, embed_dim, num_heads, window_size=7, 21 | rpe=True, **kwargs): 22 | super(InterlacedPoolAttention, self).__init__() 23 | 24 | self.dim = embed_dim 25 | self.num_heads = num_heads 26 | self.window_size = window_size 27 | self.with_rpe = rpe 28 | self.attn = MHA_(embed_dim, num_heads, rpe=rpe, window_size=window_size, **kwargs) 29 | self.pad_helper = PadBlock(window_size) 30 | self.permute_helper = LocalPermuteModule(window_size) 31 | 32 | def forward(self, x, H, W, **kwargs): 33 | B, N, C = x.shape 34 | x = x.view(B, H, W, C) 35 | # attention 36 | # pad 37 | x_pad = self.pad_helper.pad_if_needed(x, x.size()) 38 | # permute 39 | x_permute = self.permute_helper.permute(x_pad, x_pad.size()) 40 | # attention 41 | out, _, _ = self.attn(x_permute, x_permute, x_permute, rpe=self.with_rpe, **kwargs) 42 | # reverse permutation 43 | out = self.permute_helper.rev_permute(out, x_pad.size()) 44 | out = self.pad_helper.depad_if_needed(out, x.size()) 45 | return out.reshape(B, N, C) -------------------------------------------------------------------------------- /isegm/model/modeling/hrformer_helper/hrt/modules/transformer_block.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import math 4 | import logging 5 | import torch 6 | import torch.nn as nn 7 | from functools import partial 8 | 9 | from .multihead_isa_pool_attention import InterlacedPoolAttention 10 | from .ffn_block import MlpDWBN 11 | 12 | 13 | BN_MOMENTUM = 0.1 14 | 15 | 16 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 17 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 18 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 19 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 20 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 21 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 22 | 'survival rate' as the argument. 23 | """ 24 | if drop_prob == 0.0 or not training: 25 | return x 26 | keep_prob = 1 - drop_prob 27 | shape = (x.shape[0],) + (1,) * ( 28 | x.ndim - 1 29 | ) # work with diff dim tensors, not just 2D ConvNets 30 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 31 | random_tensor.floor_() # binarize 32 | output = x.div(keep_prob) * random_tensor 33 | return output 34 | 35 | 36 | class DropPath(nn.Module): 37 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 38 | 39 | def __init__(self, drop_prob=None): 40 | super(DropPath, self).__init__() 41 | self.drop_prob = drop_prob 42 | 43 | def forward(self, x): 44 | return drop_path(x, self.drop_prob, self.training) 45 | 46 | def extra_repr(self): 47 | # (Optional)Set the extra information about this module. You can test 48 | # it by printing an object of this class. 
49 | return "drop_prob={}".format(self.drop_prob) 50 | 51 | 52 | class GeneralTransformerBlock(nn.Module): 53 | expansion = 1 54 | 55 | def __init__( 56 | self, 57 | inplanes, 58 | planes, 59 | num_heads, 60 | window_size=7, 61 | mlp_ratio=4.0, 62 | qkv_bias=True, 63 | qk_scale=None, 64 | drop=0.0, 65 | attn_drop=0.0, 66 | drop_path=0.0, 67 | act_layer=nn.GELU, 68 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 69 | ): 70 | super(GeneralTransformerBlock, self).__init__() 71 | self.dim = inplanes 72 | self.out_dim = planes 73 | self.num_heads = num_heads 74 | self.window_size = window_size 75 | self.mlp_ratio = mlp_ratio 76 | self.attn = InterlacedPoolAttention( 77 | self.dim, 78 | num_heads=num_heads, 79 | window_size=window_size, 80 | rpe=True, 81 | dropout=attn_drop, 82 | ) 83 | 84 | self.norm1 = norm_layer(self.dim) 85 | self.norm2 = norm_layer(self.out_dim) 86 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 87 | mlp_hidden_dim = int(self.dim * mlp_ratio) 88 | 89 | self.mlp = MlpDWBN( 90 | in_features=self.dim, 91 | hidden_features=mlp_hidden_dim, 92 | out_features=self.out_dim, 93 | act_layer=act_layer, 94 | dw_act_layer=act_layer, 95 | drop=drop, 96 | ) 97 | 98 | def forward(self, x, mask=None): 99 | B, C, H, W = x.size() 100 | # reshape 101 | x = x.view(B, C, -1).permute(0, 2, 1) 102 | # Attention 103 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 104 | # FFN 105 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 106 | # reshape 107 | x = x.permute(0, 2, 1).view(B, C, H, W) 108 | return x 109 | 110 | def extra_repr(self): 111 | # (Optional)Set the extra information about this module. You can test 112 | # it by printing an object of this class. 113 | return "num_heads={}, window_size={}, mlp_ratio={}".format( 114 | self.num_heads, self.window_size, self.mlp_ratio 115 | ) -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab206/AdaptiveClick/e5e7344e70c6992f095cbd78db3ddf2c7969252c/isegm/model/modeling/mask2former_helper/ops/__init__.py -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | # from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | # Copyright (c) Facebook, Inc. and its affiliates. 11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 12 | 13 | python setup.py build install 14 | -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | import os 13 | import glob 14 | 15 | import torch 16 | 17 | from torch.utils.cpp_extension import CUDA_HOME 18 | from torch.utils.cpp_extension import CppExtension 19 | from torch.utils.cpp_extension import CUDAExtension 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | requirements = ["torch", "torchvision"] 25 | 26 | def get_extensions(): 27 | this_dir = os.path.dirname(os.path.abspath(__file__)) 28 | extensions_dir = os.path.join(this_dir, "src") 29 | 30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 33 | 34 | sources = main_file + source_cpu 35 | extension = CppExtension 36 | extra_compile_args = {"cxx": []} 37 | define_macros = [] 38 | 39 | # Force cuda since torch ask for a device, not if cuda is in fact available. 40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ] 50 | else: 51 | if CUDA_HOME is None: 52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') 53 | else: 54 | raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') 55 | 56 | sources = [os.path.join(extensions_dir, s) for s in sources] 57 | include_dirs = [extensions_dir] 58 | ext_modules = [ 59 | extension( 60 | "MultiScaleDeformableAttention", 61 | sources, 62 | include_dirs=include_dirs, 63 | define_macros=define_macros, 64 | extra_compile_args=extra_compile_args, 65 | ) 66 | ] 67 | return ext_modules 68 | 69 | setup( 70 | name="MultiScaleDeformableAttention", 71 | version="1.0", 72 | author="Weijie Su", 73 | url="https://github.com/fundamentalvision/Deformable-DETR", 74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 75 | packages=find_packages(exclude=("configs", "tests",)), 76 | ext_modules=get_extensions(), 77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 78 | ) 79 | -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | 22 | at::Tensor 23 | ms_deform_attn_cpu_forward( 24 | const at::Tensor &value, 25 | const at::Tensor &spatial_shapes, 26 | const at::Tensor &level_start_index, 27 | const at::Tensor &sampling_loc, 28 | const at::Tensor &attn_weight, 29 | const int im2col_step) 30 | { 31 | AT_ERROR("Not implement on cpu"); 32 | } 33 | 34 | std::vector 35 | ms_deform_attn_cpu_backward( 36 | const at::Tensor &value, 37 | const at::Tensor &spatial_shapes, 38 | const at::Tensor &level_start_index, 39 | const at::Tensor &sampling_loc, 40 | const at::Tensor &attn_weight, 41 | const at::Tensor &grad_output, 42 | const int im2col_step) 43 | { 44 | AT_ERROR("Not implement on cpu"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor 20 | ms_deform_attn_cpu_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step); 27 | 28 | std::vector 29 | ms_deform_attn_cpu_backward( 30 | const at::Tensor &value, 31 | const at::Tensor &spatial_shapes, 32 | const at::Tensor &level_start_index, 33 | const at::Tensor &sampling_loc, 34 | const at::Tensor &attn_weight, 35 | const at::Tensor &grad_output, 36 | const int im2col_step); 37 | 38 | 39 | -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step); 26 | 27 | std::vector ms_deform_attn_cuda_backward( 28 | const at::Tensor &value, 29 | const at::Tensor &spatial_shapes, 30 | const at::Tensor &level_start_index, 31 | const at::Tensor &sampling_loc, 32 | const at::Tensor &attn_weight, 33 | const at::Tensor &grad_output, 34 | const int im2col_step); 35 | 36 | -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | 18 | #include "cpu/ms_deform_attn_cpu.h" 19 | 20 | #ifdef WITH_CUDA 21 | #include "cuda/ms_deform_attn_cuda.h" 22 | #endif 23 | 24 | 25 | at::Tensor 26 | ms_deform_attn_forward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const int im2col_step) 33 | { 34 | if (value.type().is_cuda()) 35 | { 36 | #ifdef WITH_CUDA 37 | return ms_deform_attn_cuda_forward( 38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector 47 | ms_deform_attn_backward( 48 | const at::Tensor &value, 49 | const at::Tensor &spatial_shapes, 50 | const at::Tensor &level_start_index, 51 | const at::Tensor &sampling_loc, 52 | const at::Tensor &attn_weight, 53 | const at::Tensor &grad_output, 54 | const int im2col_step) 55 | { 56 | if (value.type().is_cuda()) 57 | { 58 | #ifdef WITH_CUDA 59 | return ms_deform_attn_cuda_backward( 60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 61 | #else 62 | AT_ERROR("Not compiled with GPU support"); 63 | #endif 64 | } 65 | AT_ERROR("Not implemented on the CPU"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include "ms_deform_attn.h" 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import torch 17 | from torch.autograd import gradcheck 18 | 19 | from mask2former.model.modeling.pixel_decoder.ops.functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 20 | 21 | 22 | N, M, D = 1, 2, 2 23 | Lq, L, P = 2, 2, 2 24 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 25 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 26 | S = sum([(H*W).item() for H, W in shapes]) 27 | 28 | 29 | torch.manual_seed(3) 30 | 31 | 32 | @torch.no_grad() 33 | def check_forward_equal_with_pytorch_double(): 34 | value = torch.rand(N, S, M, D).cuda() * 0.01 35 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 36 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 37 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 38 | im2col_step = 2 39 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 40 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 41 | fwdok = torch.allclose(output_cuda, output_pytorch) 42 | max_abs_err = (output_cuda - output_pytorch).abs().max() 43 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 44 | 45 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 46 | 47 | 48 | @torch.no_grad() 49 | def check_forward_equal_with_pytorch_float(): 50 | value = 
torch.rand(N, S, M, D).cuda() * 0.01 51 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 52 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 53 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 54 | im2col_step = 2 55 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 56 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 57 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 58 | max_abs_err = (output_cuda - output_pytorch).abs().max() 59 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 60 | 61 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 62 | 63 | 64 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 65 | 66 | value = torch.rand(N, S, M, channels).cuda() * 0.01 67 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 68 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 69 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 70 | im2col_step = 2 71 | func = MSDeformAttnFunction.apply 72 | 73 | value.requires_grad = grad_value 74 | sampling_locations.requires_grad = grad_sampling_loc 75 | attention_weights.requires_grad = grad_attn_weight 76 | 77 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 78 | 79 | print(f'* {gradok} check_gradient_numerical(D={channels})') 80 | 81 | 82 | if __name__ == '__main__': 83 | check_forward_equal_with_pytorch_double() 84 | check_forward_equal_with_pytorch_float() 85 | 86 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 87 | check_gradient_numerical(channels, True, True, True) 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /isegm/model/modeling/mask2former_helper/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 3 | """ 4 | Various positional encodings for the transformer. 5 | """ 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 
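    For a (N, C, H, W) feature map, the per-pixel row/column indices are taken
    as cumulative sums over the unmasked positions (optionally scaled into
    [0, scale]). Each index is expanded over `num_pos_feats` channels with
    frequencies `temperature ** (2 * (i // 2) / num_pos_feats)`; even channels
    take the sine and odd channels the cosine, and the y- and x-encodings are
    concatenated into a (N, 2 * num_pos_feats, H, W) positional map.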
16 | """ 17 | 18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 19 | super().__init__() 20 | self.num_pos_feats = num_pos_feats 21 | self.temperature = temperature 22 | self.normalize = normalize 23 | if scale is not None and normalize is False: 24 | raise ValueError("normalize should be True if scale is passed") 25 | if scale is None: 26 | scale = 2 * math.pi 27 | self.scale = scale 28 | 29 | def forward(self, x, mask=None): 30 | if mask is None: 31 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 34 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack( 46 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 47 | ).flatten(3) 48 | pos_y = torch.stack( 49 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 50 | ).flatten(3) 51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 52 | return pos 53 | 54 | def __repr__(self, _repr_indent=4): 55 | head = "Positional encoding " + self.__class__.__name__ 56 | body = [ 57 | "num_pos_feats: {}".format(self.num_pos_feats), 58 | "temperature: {}".format(self.temperature), 59 | "normalize: {}".format(self.normalize), 60 | "scale: {}".format(self.scale), 61 | ] 62 | # _repr_indent = 4 63 | lines = [head] + [" " * _repr_indent + line for line in body] 64 | return "\n".join(lines) 65 | -------------------------------------------------------------------------------- /isegm/model/modeling/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s 3 | 4 | 5 | class ResNetBackbone(torch.nn.Module): 6 | def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs): 7 | super(ResNetBackbone, self).__init__() 8 | 9 | if backbone == 'resnet34': 10 | pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) 11 | elif backbone == 'resnet50': 12 | pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 13 | elif backbone == 'resnet101': 14 | pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 15 | elif backbone == 'resnet152': 16 | pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 17 | else: 18 | raise RuntimeError(f'unknown backbone: {backbone}') 19 | 20 | self.conv1 = pretrained.conv1 21 | self.bn1 = pretrained.bn1 22 | self.relu = pretrained.relu 23 | self.maxpool = pretrained.maxpool 24 | self.layer1 = pretrained.layer1 25 | self.layer2 = pretrained.layer2 26 | self.layer3 = pretrained.layer3 27 | self.layer4 = pretrained.layer4 28 | 29 | def forward(self, x, additional_features=None): 30 | x = self.conv1(x) 31 | x = self.bn1(x) 32 | x = self.relu(x) 33 | if additional_features is not None: 34 | x = x + torch.nn.functional.pad(additional_features, 35 | [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)], 36 | mode='constant', 
value=0) 37 | x = self.maxpool(x) 38 | c1 = self.layer1(x) 39 | c2 = self.layer2(c1) 40 | c3 = self.layer3(c2) 41 | c4 = self.layer4(c3) 42 | 43 | return c1, c2, c3, c4 44 | -------------------------------------------------------------------------------- /isegm/model/modeling/swin_transformer_helper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab206/AdaptiveClick/e5e7344e70c6992f095cbd78db3ddf2c7969252c/isegm/model/modeling/swin_transformer_helper/__init__.py -------------------------------------------------------------------------------- /isegm/model/modeling/swin_transformer_helper/builder.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from mmcv.utils import Registry, build_from_cfg 4 | from torch import nn 5 | 6 | BACKBONES = Registry('backbone') 7 | NECKS = Registry('neck') 8 | HEADS = Registry('head') 9 | LOSSES = Registry('loss') 10 | SEGMENTORS = Registry('segmentor') 11 | 12 | 13 | def build(cfg, registry, default_args=None): 14 | """Build a module. 15 | 16 | Args: 17 | cfg (dict, list[dict]): The config of modules, is is either a dict 18 | or a list of configs. 19 | registry (:obj:`Registry`): A registry the module belongs to. 20 | default_args (dict, optional): Default arguments to build the module. 21 | Defaults to None. 22 | 23 | Returns: 24 | nn.Module: A built nn module. 25 | """ 26 | 27 | if isinstance(cfg, list): 28 | modules = [ 29 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 30 | ] 31 | return nn.Sequential(*modules) 32 | else: 33 | return build_from_cfg(cfg, registry, default_args) 34 | 35 | 36 | def build_backbone(cfg): 37 | """Build backbone.""" 38 | return build(cfg, BACKBONES) 39 | 40 | 41 | def build_neck(cfg): 42 | """Build neck.""" 43 | return build(cfg, NECKS) 44 | 45 | 46 | def build_head(cfg): 47 | """Build head.""" 48 | return build(cfg, HEADS) 49 | 50 | 51 | def build_loss(cfg): 52 | """Build loss.""" 53 | return build(cfg, LOSSES) 54 | 55 | 56 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 57 | """Build segmentor.""" 58 | if train_cfg is not None or test_cfg is not None: 59 | warnings.warn( 60 | 'train_cfg and test_cfg is deprecated, ' 61 | 'please specify them in model', UserWarning) 62 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 63 | 'train_cfg specified in both outer field and model field ' 64 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 65 | 'test_cfg specified in both outer field and model field ' 66 | return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) -------------------------------------------------------------------------------- /isegm/model/modeling/swin_transformer_helper/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from mmcv.utils import get_logger 4 | 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO): 7 | """Get the root logger. 8 | 9 | The logger will be initialized if it has not been initialized. By default a 10 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 11 | also be added. The name of the root logger is the top-level package name, 12 | e.g., "mmseg". 13 | 14 | Args: 15 | log_file (str | None): The log filename. If specified, a FileHandler 16 | will be added to the root logger. 17 | log_level (int): The root logger level. 
Note that only the process of 18 | rank 0 is affected, while other processes will set the level to 19 | "Error" and be silent most of the time. 20 | 21 | Returns: 22 | logging.Logger: The root logger. 23 | """ 24 | 25 | logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) 26 | 27 | return logger -------------------------------------------------------------------------------- /isegm/model/modeling/transformer_helper/__init__.py: -------------------------------------------------------------------------------- 1 | from .embed import PatchEmbed 2 | from .shape_convert import nchw_to_nlc, nlc_to_nchw 3 | from .wrappers import resize, Upsample 4 | from .logger import get_root_logger 5 | from .decode_head import BaseDecodeHead 6 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, 7 | build_head, build_loss, build_segmentor) 8 | 9 | __all__ = [ 10 | 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'resize', 'Upsample', 11 | 'get_root_logger', 'BaseDecodeHead', 'BACKBONES', 'HEADS', 'LOSSES', 12 | 'SEGMENTORS', 'build_backbone', 'build_head', 'build_loss', 'build_segmentor' 13 | ] 14 | -------------------------------------------------------------------------------- /isegm/model/modeling/transformer_helper/accuracy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | 4 | 5 | def accuracy(pred, target, topk=1, thresh=None): 6 | """Calculate accuracy according to the prediction and target. 7 | 8 | Args: 9 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 10 | target (torch.Tensor): The target of each prediction, shape (N, , ...) 11 | topk (int | tuple[int], optional): If the predictions in ``topk`` 12 | matches the target, the predictions will be regarded as 13 | correct ones. Defaults to 1. 14 | thresh (float, optional): If not None, predictions with scores under 15 | this threshold are considered incorrect. Default to None. 16 | 17 | Returns: 18 | float | tuple[float]: If the input ``topk`` is a single integer, 19 | the function will return a single float as accuracy. If 20 | ``topk`` is a tuple containing multiple integers, the 21 | function will return a tuple containing accuracies of 22 | each ``topk`` number. 23 | """ 24 | assert isinstance(topk, (int, tuple)) 25 | if isinstance(topk, int): 26 | topk = (topk, ) 27 | return_single = True 28 | else: 29 | return_single = False 30 | 31 | maxk = max(topk) 32 | if pred.size(0) == 0: 33 | accu = [pred.new_tensor(0.) for i in range(len(topk))] 34 | return accu[0] if return_single else accu 35 | assert pred.ndim == target.ndim + 1 36 | assert pred.size(0) == target.size(0) 37 | assert maxk <= pred.size(1), \ 38 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 39 | pred_value, pred_label = pred.topk(maxk, dim=1) 40 | # transpose to shape (maxk, N, ...) 
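    # e.g. for pred of shape (N, num_class) with maxk=5, pred_label becomes
    # (5, N); target is broadcast over the maxk dim and correct[:k] below
    # selects the top-k rows for each requested k.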
41 | pred_label = pred_label.transpose(0, 1) 42 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) 43 | if thresh is not None: 44 | # Only prediction values larger than thresh are counted as correct 45 | correct = correct & (pred_value > thresh).t() 46 | res = [] 47 | for k in topk: 48 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 49 | res.append(correct_k.mul_(100.0 / target.numel())) 50 | return res[0] if return_single else res 51 | 52 | 53 | class Accuracy(nn.Module): 54 | """Accuracy calculation module.""" 55 | 56 | def __init__(self, topk=(1, ), thresh=None): 57 | """Module to calculate the accuracy. 58 | 59 | Args: 60 | topk (tuple, optional): The criterion used to calculate the 61 | accuracy. Defaults to (1,). 62 | thresh (float, optional): If not None, predictions with scores 63 | under this threshold are considered incorrect. Default to None. 64 | """ 65 | super().__init__() 66 | self.topk = topk 67 | self.thresh = thresh 68 | 69 | def forward(self, pred, target): 70 | """Forward function to calculate accuracy. 71 | 72 | Args: 73 | pred (torch.Tensor): Prediction of models. 74 | target (torch.Tensor): Target for each prediction. 75 | 76 | Returns: 77 | tuple[float]: The accuracies under different topk criterions. 78 | """ 79 | return accuracy(pred, target, self.topk, self.thresh) 80 | -------------------------------------------------------------------------------- /isegm/model/modeling/transformer_helper/base_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BasePixelSampler(metaclass=ABCMeta): 6 | """Base class of pixel sampler.""" 7 | 8 | def __init__(self, **kwargs): 9 | pass 10 | 11 | @abstractmethod 12 | def sample(self, seg_logit, seg_label): 13 | """Placeholder for sample function.""" 14 | -------------------------------------------------------------------------------- /isegm/model/modeling/transformer_helper/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
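# The registries below follow mmcv's config-driven pattern: a config dict whose
# 'type' key names a registered class is resolved and instantiated through
# Registry.build / build_from_cfg. BACKBONES, NECKS, HEADS, LOSSES and
# SEGMENTORS all alias the shared MODELS registry (a child of mmcv's MODELS),
# while PIXEL_SAMPLERS is kept as a separate registry.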
2 | import warnings 3 | 4 | from mmcv.cnn import MODELS as MMCV_MODELS 5 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 6 | from mmcv.utils import Registry, build_from_cfg 7 | 8 | 9 | PIXEL_SAMPLERS = Registry('pixel sampler') 10 | MODELS = Registry('models', parent=MMCV_MODELS) 11 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION) 12 | 13 | BACKBONES = MODELS 14 | NECKS = MODELS 15 | HEADS = MODELS 16 | LOSSES = MODELS 17 | SEGMENTORS = MODELS 18 | 19 | def build_pixel_sampler(cfg, **default_args): 20 | """Build pixel sampler for segmentation map.""" 21 | return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) 22 | 23 | 24 | def build_backbone(cfg): 25 | """Build backbone.""" 26 | return BACKBONES.build(cfg) 27 | 28 | 29 | def build_neck(cfg): 30 | """Build neck.""" 31 | return NECKS.build(cfg) 32 | 33 | 34 | def build_head(cfg): 35 | """Build head.""" 36 | return HEADS.build(cfg) 37 | 38 | 39 | def build_loss(cfg): 40 | """Build loss.""" 41 | return LOSSES.build(cfg) 42 | 43 | 44 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 45 | """Build segmentor.""" 46 | if train_cfg is not None or test_cfg is not None: 47 | warnings.warn( 48 | 'train_cfg and test_cfg is deprecated, ' 49 | 'please specify them in model', UserWarning) 50 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 51 | 'train_cfg specified in both outer field and model field ' 52 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 53 | 'test_cfg specified in both outer field and model field ' 54 | return SEGMENTORS.build( 55 | cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 56 | -------------------------------------------------------------------------------- /isegm/model/modeling/transformer_helper/embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn.functional as F 3 | from mmcv.cnn import build_conv_layer, build_norm_layer 4 | from mmcv.runner.base_module import BaseModule 5 | from torch.nn.modules.utils import _pair as to_2tuple 6 | 7 | 8 | # Modified from Pytorch-Image-Models 9 | class PatchEmbed(BaseModule): 10 | """Image to Patch Embedding V2. 11 | 12 | We use a conv layer to implement PatchEmbed. 13 | Args: 14 | in_channels (int): The num of input channels. Default: 3 15 | embed_dims (int): The dimensions of embedding. Default: 768 16 | conv_type (dict, optional): The config dict for conv layers type 17 | selection. Default: None. 18 | kernel_size (int): The kernel_size of embedding conv. Default: 16. 19 | stride (int): The slide stride of embedding conv. 20 | Default: None (Default to be equal with kernel_size). 21 | padding (int): The padding length of embedding conv. Default: 0. 22 | dilation (int): The dilation rate of embedding conv. Default: 1. 23 | pad_to_patch_size (bool, optional): Whether to pad feature map shape 24 | to multiple patch size. Default: True. 25 | norm_cfg (dict, optional): Config dict for normalization layer. 26 | init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. 27 | Default: None. 
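    Shape example: with the default kernel_size = stride = 16 and
    pad_to_patch_size=True, a (N, 3, 448, 448) input is projected to a
    (N, 28 * 28, embed_dims) token sequence, and the resulting patch grid
    (28, 28) is stored in (self.DH, self.DW) during forward.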
28 | """ 29 | 30 | def __init__(self, 31 | in_channels=3, 32 | embed_dims=768, 33 | conv_type=None, 34 | kernel_size=16, 35 | stride=16, 36 | padding=0, 37 | dilation=1, 38 | pad_to_patch_size=True, 39 | norm_cfg=None, 40 | init_cfg=None): 41 | super(PatchEmbed, self).__init__() 42 | 43 | self.embed_dims = embed_dims 44 | self.init_cfg = init_cfg 45 | 46 | if stride is None: 47 | stride = kernel_size 48 | 49 | self.pad_to_patch_size = pad_to_patch_size 50 | 51 | # The default setting of patch size is equal to kernel size. 52 | patch_size = kernel_size 53 | if isinstance(patch_size, int): 54 | patch_size = to_2tuple(patch_size) 55 | elif isinstance(patch_size, tuple): 56 | if len(patch_size) == 1: 57 | patch_size = to_2tuple(patch_size[0]) 58 | assert len(patch_size) == 2, \ 59 | f'The size of patch should have length 1 or 2, ' \ 60 | f'but got {len(patch_size)}' 61 | 62 | self.patch_size = patch_size 63 | 64 | # Use conv layer to embed 65 | conv_type = conv_type or 'Conv2d' 66 | self.projection = build_conv_layer( 67 | dict(type=conv_type), 68 | in_channels=in_channels, 69 | out_channels=embed_dims, 70 | kernel_size=kernel_size, 71 | stride=stride, 72 | padding=padding, 73 | dilation=dilation) 74 | 75 | if norm_cfg is not None: 76 | self.norm = build_norm_layer(norm_cfg, embed_dims)[1] 77 | else: 78 | self.norm = None 79 | 80 | def forward(self, x): 81 | H, W = x.shape[2], x.shape[3] 82 | 83 | # TODO: Process overlapping op 84 | if self.pad_to_patch_size: 85 | # Modify H, W to multiple of patch size. 86 | if H % self.patch_size[0] != 0: 87 | x = F.pad( 88 | x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) 89 | if W % self.patch_size[1] != 0: 90 | x = F.pad( 91 | x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0)) 92 | 93 | x = self.projection(x) 94 | self.DH, self.DW = x.shape[2], x.shape[3] 95 | x = x.flatten(2).transpose(1, 2) 96 | 97 | if self.norm is not None: 98 | x = self.norm(x) 99 | 100 | return x 101 | -------------------------------------------------------------------------------- /isegm/model/modeling/transformer_helper/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from mmcv.utils import get_logger 4 | 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO): 7 | """Get the root logger. 8 | 9 | The logger will be initialized if it has not been initialized. By default a 10 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 11 | also be added. The name of the root logger is the top-level package name, 12 | e.g., "mmseg". 13 | 14 | Args: 15 | log_file (str | None): The log filename. If specified, a FileHandler 16 | will be added to the root logger. 17 | log_level (int): The root logger level. Note that only the process of 18 | rank 0 is affected, while other processes will set the level to 19 | "Error" and be silent most of the time. 20 | 21 | Returns: 22 | logging.Logger: The root logger. 23 | """ 24 | 25 | logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) 26 | 27 | return logger 28 | -------------------------------------------------------------------------------- /isegm/model/modeling/transformer_helper/shape_convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def nlc_to_nchw(x, hw_shape): 3 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. 
4 | 5 | Args: 6 | x (Tensor): The input tensor of shape [N, L, C] before convertion. 7 | hw_shape (Sequence[int]): The height and width of output feature map. 8 | 9 | Returns: 10 | Tensor: The output tensor of shape [N, C, H, W] after convertion. 11 | """ 12 | H, W = hw_shape 13 | assert len(x.shape) == 3 14 | B, L, C = x.shape 15 | assert L == H * W, 'The seq_len doesn\'t match H, W' 16 | return x.transpose(1, 2).reshape(B, C, H, W) 17 | 18 | 19 | def nchw_to_nlc(x): 20 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. 21 | 22 | Args: 23 | x (Tensor): The input tensor of shape [N, C, H, W] before convertion. 24 | 25 | Returns: 26 | Tensor: The output tensor of shape [N, L, C] after convertion. 27 | """ 28 | assert len(x.shape) == 4 29 | return x.flatten(2).transpose(1, 2).contiguous() 30 | -------------------------------------------------------------------------------- /isegm/model/modeling/transformer_helper/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import functools 3 | 4 | import mmcv 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | 9 | def get_class_weight(class_weight): 10 | """Get class weight for loss function. 11 | 12 | Args: 13 | class_weight (list[float] | str | None): If class_weight is a str, 14 | take it as a file name and read from it. 15 | """ 16 | if isinstance(class_weight, str): 17 | # take it as a file path 18 | if class_weight.endswith('.npy'): 19 | class_weight = np.load(class_weight) 20 | else: 21 | # pkl, json or yaml 22 | class_weight = mmcv.load(class_weight) 23 | 24 | return class_weight 25 | 26 | 27 | def reduce_loss(loss, reduction): 28 | """Reduce loss as specified. 29 | 30 | Args: 31 | loss (Tensor): Elementwise loss tensor. 32 | reduction (str): Options are "none", "mean" and "sum". 33 | 34 | Return: 35 | Tensor: Reduced loss tensor. 36 | """ 37 | reduction_enum = F._Reduction.get_enum(reduction) 38 | # none: 0, elementwise_mean:1, sum: 2 39 | if reduction_enum == 0: 40 | return loss 41 | elif reduction_enum == 1: 42 | return loss.mean() 43 | elif reduction_enum == 2: 44 | return loss.sum() 45 | 46 | 47 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 48 | """Apply element-wise weight and reduce loss. 49 | 50 | Args: 51 | loss (Tensor): Element-wise loss. 52 | weight (Tensor): Element-wise weights. 53 | reduction (str): Same as built-in losses of PyTorch. 54 | avg_factor (float): Avarage factor when computing the mean of losses. 55 | 56 | Returns: 57 | Tensor: Processed loss values. 58 | """ 59 | # if weight is specified, apply element-wise weight 60 | if weight is not None: 61 | assert weight.dim() == loss.dim() 62 | if weight.dim() > 1: 63 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 64 | loss = loss * weight 65 | 66 | # if avg_factor is not specified, just reduce the loss 67 | if avg_factor is None: 68 | loss = reduce_loss(loss, reduction) 69 | else: 70 | # if reduction is mean, then average the loss by avg_factor 71 | if reduction == 'mean': 72 | loss = loss.sum() / avg_factor 73 | # if reduction is 'none', then do nothing, otherwise raise an error 74 | elif reduction != 'none': 75 | raise ValueError('avg_factor can not be used with reduction="sum"') 76 | return loss 77 | 78 | 79 | def weighted_loss(loss_func): 80 | """Create a weighted version of a given loss function. 
81 | 82 | To use this decorator, the loss function must have the signature like 83 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 84 | element-wise loss without any reduction. This decorator will add weight 85 | and reduction arguments to the function. The decorated function will have 86 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 87 | avg_factor=None, **kwargs)`. 88 | 89 | :Example: 90 | 91 | >>> import torch 92 | >>> @weighted_loss 93 | >>> def l1_loss(pred, target): 94 | >>> return (pred - target).abs() 95 | 96 | >>> pred = torch.Tensor([0, 2, 3]) 97 | >>> target = torch.Tensor([1, 1, 1]) 98 | >>> weight = torch.Tensor([1, 0, 1]) 99 | 100 | >>> l1_loss(pred, target) 101 | tensor(1.3333) 102 | >>> l1_loss(pred, target, weight) 103 | tensor(1.) 104 | >>> l1_loss(pred, target, reduction='none') 105 | tensor([1., 1., 2.]) 106 | >>> l1_loss(pred, target, weight, avg_factor=2) 107 | tensor(1.5000) 108 | """ 109 | 110 | @functools.wraps(loss_func) 111 | def wrapper(pred, 112 | target, 113 | weight=None, 114 | reduction='mean', 115 | avg_factor=None, 116 | **kwargs): 117 | # get element-wise loss 118 | loss = loss_func(pred, target, **kwargs) 119 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 120 | return loss 121 | 122 | return wrapper 123 | -------------------------------------------------------------------------------- /isegm/model/modeling/transformer_helper/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def resize(input, 9 | size=None, 10 | scale_factor=None, 11 | mode='nearest', 12 | align_corners=None, 13 | warning=True): 14 | if warning: 15 | if size is not None and align_corners: 16 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 17 | output_h, output_w = tuple(int(x) for x in size) 18 | if output_h > input_h or output_w > output_h: 19 | if ((output_h > 1 and output_w > 1 and input_h > 1 20 | and input_w > 1) and (output_h - 1) % (input_h - 1) 21 | and (output_w - 1) % (input_w - 1)): 22 | warnings.warn( 23 | f'When align_corners={align_corners}, ' 24 | 'the output would more aligned if ' 25 | f'input size {(input_h, input_w)} is `x+1` and ' 26 | f'out size {(output_h, output_w)} is `nx+1`') 27 | return F.interpolate(input, size, scale_factor, mode, align_corners) 28 | 29 | 30 | class Upsample(nn.Module): 31 | 32 | def __init__(self, 33 | size=None, 34 | scale_factor=None, 35 | mode='nearest', 36 | align_corners=None): 37 | super(Upsample, self).__init__() 38 | self.size = size 39 | if isinstance(scale_factor, tuple): 40 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 41 | else: 42 | self.scale_factor = float(scale_factor) if scale_factor else None 43 | self.mode = mode 44 | self.align_corners = align_corners 45 | 46 | def forward(self, x): 47 | if not self.size: 48 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 49 | else: 50 | size = self.size 51 | return resize(x, size, None, self.mode, self.align_corners) 52 | -------------------------------------------------------------------------------- /isegm/model/modifiers.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class LRMult(object): 4 | def __init__(self, lr_mult=1.): 5 | self.lr_mult = lr_mult 6 | 7 | def __call__(self, m): 8 | if getattr(m, 'weight', None) is not None: 
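            # The lr_mult attribute has no effect inside torch itself; it only
            # matters if the optimizer-building code (isegm/engine/optimizer.py
            # in this repo) reads it to scale the base learning rate per module.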
9 | m.weight.lr_mult = self.lr_mult 10 | if getattr(m, 'bias', None) is not None: 11 | m.bias.lr_mult = self.lr_mult 12 | -------------------------------------------------------------------------------- /isegm/utils/cython/__init__.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | from .dist_maps import get_dist_maps -------------------------------------------------------------------------------- /isegm/utils/cython/_get_dist_maps.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport cython 3 | cimport numpy as np 4 | from libc.stdlib cimport malloc, free 5 | 6 | ctypedef struct qnode: 7 | int row 8 | int col 9 | int layer 10 | int orig_row 11 | int orig_col 12 | 13 | @cython.infer_types(True) 14 | @cython.boundscheck(False) 15 | @cython.wraparound(False) 16 | @cython.nonecheck(False) 17 | def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points, 18 | int height, int width, float norm_delimeter): 19 | cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \ 20 | np.full((2, height, width), 1e6, dtype=np.float32, order="C") 21 | 22 | cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0] 23 | cdef int i, j, x, y, dx, dy 24 | cdef qnode v 25 | cdef qnode *q = malloc((4 * height * width + 1) * sizeof(qnode)) 26 | cdef int qhead = 0, qtail = -1 27 | cdef float ndist 28 | 29 | for i in range(points.shape[0]): 30 | x, y = round(points[i, 0]), round(points[i, 1]) 31 | if x >= 0: 32 | qtail += 1 33 | q[qtail].row = x 34 | q[qtail].col = y 35 | q[qtail].orig_row = x 36 | q[qtail].orig_col = y 37 | if i >= points.shape[0] / 2: 38 | q[qtail].layer = 1 39 | else: 40 | q[qtail].layer = 0 41 | dist_maps[q[qtail].layer, x, y] = 0 42 | 43 | while qtail - qhead + 1 > 0: 44 | v = q[qhead] 45 | qhead += 1 46 | 47 | for k in range(4): 48 | x = v.row + dxy[2 * k] 49 | y = v.col + dxy[2 * k + 1] 50 | 51 | ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2 52 | if (x >= 0 and y >= 0 and x < height and y < width and 53 | dist_maps[v.layer, x, y] > ndist): 54 | qtail += 1 55 | q[qtail].orig_col = v.orig_col 56 | q[qtail].orig_row = v.orig_row 57 | q[qtail].layer = v.layer 58 | q[qtail].row = x 59 | q[qtail].col = y 60 | dist_maps[v.layer, x, y] = ndist 61 | 62 | free(q) 63 | return dist_maps 64 | -------------------------------------------------------------------------------- /isegm/utils/cython/_get_dist_maps.pyxbld: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | def make_ext(modname, pyxfilename): 4 | from distutils.extension import Extension 5 | return Extension(modname, [pyxfilename], 6 | include_dirs=[numpy.get_include()], 7 | extra_compile_args=['-O3'], language='c++') 8 | -------------------------------------------------------------------------------- /isegm/utils/cython/dist_maps.py: -------------------------------------------------------------------------------- 1 | import pyximport; pyximport.install(pyximport=True, language_level=3) 2 | # noinspection PyUnresolvedReferences 3 | from ._get_dist_maps import get_dist_maps -------------------------------------------------------------------------------- /isegm/utils/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributed as dist 3 | from torch.utils import data 4 | 5 | 6 | def get_rank(): 7 | if not dist.is_available() or 
not dist.is_initialized(): 8 | return 0 9 | return dist.get_rank() 10 | 11 | 12 | def synchronize(): 13 | if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() == 1: 14 | return 15 | dist.barrier() 16 | 17 | 18 | def get_world_size(): 19 | if not dist.is_available() or not dist.is_initialized(): 20 | return 1 21 | 22 | return dist.get_world_size() 23 | 24 | 25 | def reduce_loss_dict(loss_dict): 26 | world_size = get_world_size() 27 | 28 | if world_size < 2: 29 | return loss_dict 30 | 31 | with torch.no_grad(): 32 | keys = [] 33 | losses = [] 34 | 35 | for k in loss_dict.keys(): 36 | keys.append(k) 37 | losses.append(loss_dict[k]) 38 | 39 | losses = torch.stack(losses, 0) 40 | dist.reduce(losses, dst=0) 41 | 42 | if dist.get_rank() == 0: 43 | losses /= world_size 44 | 45 | reduced_losses = {k: v for k, v in zip(keys, losses)} 46 | 47 | return reduced_losses 48 | 49 | 50 | def get_sampler(dataset, shuffle, distributed): 51 | if distributed: 52 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 53 | 54 | if shuffle: 55 | return data.RandomSampler(dataset) 56 | else: 57 | return data.SequentialSampler(dataset) 58 | 59 | 60 | def get_dp_wrapper(distributed): 61 | class DPWrapper(torch.nn.parallel.DistributedDataParallel if distributed else torch.nn.DataParallel): 62 | def __getattr__(self, name): 63 | try: 64 | return super().__getattr__(name) 65 | except AttributeError: 66 | return getattr(self.module, name) 67 | return DPWrapper 68 | -------------------------------------------------------------------------------- /isegm/utils/exp_imports/default.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | from easydict import EasyDict as edict 4 | from albumentations import * 5 | 6 | from isegm.data.datasets import * 7 | from isegm.model.losses import * 8 | from isegm.data.transforms import * 9 | from isegm.engine.trainer import ISTrainer 10 | from isegm.model.metrics import AdaptiveIoU 11 | from isegm.data.points_sampler import MultiPointSampler 12 | from isegm.utils.log import logger 13 | from isegm.model import initializer 14 | 15 | from isegm.model.is_hrnet_model import HRNetModel 16 | from isegm.model.is_deeplab_model import DeeplabModel 17 | from isegm.model.is_segformer_model import SegformerModel 18 | from isegm.model.is_hrformer_model import HRFormerModel 19 | from isegm.model.is_swinformer_model import SwinformerModel 20 | from isegm.model.is_plainvit_model import PlainVitModel -------------------------------------------------------------------------------- /isegm/utils/log.py: -------------------------------------------------------------------------------- 1 | import io 2 | import time 3 | import logging 4 | from datetime import datetime 5 | 6 | import numpy as np 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | LOGGER_NAME = 'root' 10 | LOGGER_DATEFMT = '%Y-%m-%d %H:%M:%S' 11 | 12 | handler = logging.StreamHandler() 13 | 14 | logger = logging.getLogger(LOGGER_NAME) 15 | logger.setLevel(logging.INFO) 16 | logger.addHandler(handler) 17 | 18 | 19 | def add_logging(logs_path, prefix): 20 | log_name = prefix + datetime.strftime(datetime.today(), '%Y-%m-%d_%H-%M-%S') + '.log' 21 | stdout_log_path = logs_path / log_name 22 | 23 | fh = logging.FileHandler(str(stdout_log_path)) 24 | formatter = logging.Formatter(fmt='(%(levelname)s) %(asctime)s: %(message)s', 25 | datefmt=LOGGER_DATEFMT) 26 | fh.setFormatter(formatter) 27 | logger.addHandler(fh) 28 | 29 
| 30 | class TqdmToLogger(io.StringIO): 31 | logger = None 32 | level = None 33 | buf = '' 34 | 35 | def __init__(self, logger, level=None, mininterval=5): 36 | super(TqdmToLogger, self).__init__() 37 | self.logger = logger 38 | self.level = level or logging.INFO 39 | self.mininterval = mininterval 40 | self.last_time = 0 41 | 42 | def write(self, buf): 43 | self.buf = buf.strip('\r\n\t ') 44 | 45 | def flush(self): 46 | if len(self.buf) > 0 and time.time() - self.last_time > self.mininterval: 47 | self.logger.log(self.level, self.buf) 48 | self.last_time = time.time() 49 | 50 | 51 | class SummaryWriterAvg(SummaryWriter): 52 | def __init__(self, *args, dump_period=20, **kwargs): 53 | super().__init__(*args, **kwargs) 54 | self._dump_period = dump_period 55 | self._avg_scalars = dict() 56 | 57 | def add_scalar(self, tag, value, global_step=None, disable_avg=False): 58 | if disable_avg or isinstance(value, (tuple, list, dict)): 59 | super().add_scalar(tag, np.array(value), global_step=global_step) 60 | else: 61 | if tag not in self._avg_scalars: 62 | self._avg_scalars[tag] = ScalarAccumulator(self._dump_period) 63 | avg_scalar = self._avg_scalars[tag] 64 | avg_scalar.add(value) 65 | 66 | if avg_scalar.is_full(): 67 | super().add_scalar(tag, avg_scalar.value, 68 | global_step=global_step) 69 | avg_scalar.reset() 70 | 71 | 72 | class ScalarAccumulator(object): 73 | def __init__(self, period): 74 | self.sum = 0 75 | self.cnt = 0 76 | self.period = period 77 | 78 | def add(self, value): 79 | self.sum += value 80 | self.cnt += 1 81 | 82 | @property 83 | def value(self): 84 | if self.cnt > 0: 85 | return self.sum / self.cnt 86 | else: 87 | return 0 88 | 89 | def reset(self): 90 | self.cnt = 0 91 | self.sum = 0 92 | 93 | def is_full(self): 94 | return self.cnt >= self.period 95 | 96 | def __len__(self): 97 | return self.cnt 98 | -------------------------------------------------------------------------------- /isegm/utils/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, lr, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | num_layers = len(model.backbone.blocks) + 1 23 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 24 | for n, p in model.backbone.named_parameters(): 25 | if not p.requires_grad: 26 | continue 27 | 28 | # no decay: all 1D parameters and model specific ones 29 | if p.ndim == 1 or n in no_weight_decay_list: 30 | g_decay = "no_decay" 31 | this_decay = 0. 
32 | else: 33 | g_decay = "decay" 34 | this_decay = weight_decay 35 | 36 | layer_id = get_layer_id_for_vit(n, num_layers) 37 | group_name = "layer_%d_%s" % (layer_id, g_decay) 38 | 39 | if group_name not in param_group_names: 40 | this_scale = layer_scales[layer_id] 41 | 42 | param_group_names[group_name] = { 43 | "lr_scale": this_scale, 44 | "lr": lr * this_scale, 45 | "weight_decay": this_decay, 46 | "params": [], 47 | } 48 | param_groups[group_name] = { 49 | "lr_scale": this_scale, 50 | "lr": lr * this_scale, 51 | "weight_decay": this_decay, 52 | "params": [], 53 | } 54 | 55 | param_group_names[group_name]["params"].append(n) 56 | param_groups[group_name]["params"].append(p) 57 | 58 | params = list(param_groups.values()) 59 | 60 | for n, p in model.neck.named_parameters(): 61 | if not p.requires_grad: 62 | continue 63 | params.append({"params": p, "weight_decay": weight_decay}) 64 | 65 | for n, p in model.head.named_parameters(): 66 | if not p.requires_grad: 67 | continue 68 | params.append({"params": p, "weight_decay": weight_decay}) 69 | 70 | return params 71 | 72 | 73 | def get_layer_id_for_vit(name, num_layers): 74 | """ 75 | Assign a parameter with its layer id 76 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 77 | """ 78 | if name in ['cls_token', 'pos_embed']: 79 | return 0 80 | elif name.startswith('patch_embed'): 81 | return 0 82 | elif name.startswith('blocks'): 83 | return int(name.split('.')[1]) + 1 84 | else: 85 | return num_layers -------------------------------------------------------------------------------- /isegm/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from copy import deepcopy 3 | import inspect 4 | import torch.nn as nn 5 | 6 | 7 | def serialize(init): 8 | parameters = list(inspect.signature(init).parameters) 9 | 10 | @wraps(init) 11 | def new_init(self, *args, **kwargs): 12 | params = deepcopy(kwargs) 13 | for pname, value in zip(parameters[1:], args): 14 | params[pname] = value 15 | 16 | config = { 17 | 'class': get_classname(self.__class__), 18 | 'params': dict() 19 | } 20 | specified_params = set(params.keys()) 21 | 22 | for pname, param in get_default_params(self.__class__).items(): 23 | if pname not in params: 24 | params[pname] = param.default 25 | 26 | for name, value in list(params.items()): 27 | param_type = 'builtin' 28 | if inspect.isclass(value): 29 | param_type = 'class' 30 | value = get_classname(value) 31 | 32 | config['params'][name] = { 33 | 'type': param_type, 34 | 'value': value, 35 | 'specified': name in specified_params 36 | } 37 | 38 | setattr(self, '_config', config) 39 | init(self, *args, **kwargs) 40 | 41 | return new_init 42 | 43 | 44 | def load_model(config, eval_ritm, **kwargs): 45 | model_class = get_class_from_str(config['class']) 46 | model_default_params = get_default_params(model_class) 47 | 48 | model_args = dict() 49 | for pname, param in config['params'].items(): 50 | value = param['value'] 51 | if param['type'] == 'class': 52 | value = get_class_from_str(value) 53 | 54 | if pname not in model_default_params and not param['specified']: 55 | continue 56 | 57 | # assert pname in model_default_params 58 | if not param['specified'] and model_default_params[pname].default == value: 59 | continue 60 | model_args[pname] = value 61 | model_args.update(kwargs) 62 | 63 | # This ugly hardcode is only to support evalution for RITM models 64 | # Ignore it if you are evaluting SimpleClick models. 
65 | if eval_ritm: 66 | model_args['use_rgb_conv'] = True 67 | 68 | return model_class(**model_args) 69 | 70 | 71 | def get_config_repr(config): 72 | config_str = f'Model: {config["class"]}\n' 73 | for pname, param in config['params'].items(): 74 | value = param["value"] 75 | if param['type'] == 'class': 76 | value = value.split('.')[-1] 77 | param_str = f'{pname:<22} = {str(value):<12}' 78 | if not param['specified']: 79 | param_str += ' (default)' 80 | config_str += param_str + '\n' 81 | return config_str 82 | 83 | 84 | def get_default_params(some_class): 85 | params = dict() 86 | for mclass in some_class.mro(): 87 | if mclass is nn.Module or mclass is object: 88 | continue 89 | 90 | mclass_params = inspect.signature(mclass.__init__).parameters 91 | for pname, param in mclass_params.items(): 92 | if param.default != param.empty and pname not in params: 93 | params[pname] = param 94 | 95 | return params 96 | 97 | 98 | def get_classname(cls): 99 | module = cls.__module__ 100 | name = cls.__qualname__ 101 | if module is not None and module != "__builtin__": 102 | name = module + "." + name 103 | return name 104 | 105 | 106 | def get_class_from_str(class_str): 107 | components = class_str.split('.') 108 | mod = __import__('.'.join(components[:-1])) 109 | for comp in components[1:]: 110 | mod = getattr(mod, comp) 111 | return mod 112 | -------------------------------------------------------------------------------- /models/iter_mask/adaptiveclick_base448_cocolvis_itermask.py: -------------------------------------------------------------------------------- 1 | from fvcore.common.config import CfgNode 2 | 3 | from isegm.utils.exp_imports.default import * 4 | from isegm.engine.adaptiveclick_trainer import AdaptiveClickTrainer 5 | from isegm.model.is_adaptiveclick_model import AdaptiveClickModel 6 | 7 | MODEL_NAME = 'adaptiveclick_base448_cocolvis' 8 | MODEL_CONFIG = 'configs/adaptiveclick_plainvit_base448.yaml' 9 | 10 | 11 | def main(cfg): 12 | model, model_cfg = init_model(cfg) 13 | train(model, cfg, model_cfg) 14 | 15 | 16 | def init_model(cfg): 17 | # fetch config for model 18 | model_cfg = CfgNode.load_yaml_with_base(MODEL_CONFIG, allow_unsafe=True) 19 | # merge args to model_config 20 | for k, v in cfg.__dict__.items(): 21 | model_cfg[k] = v 22 | 23 | model_cfg = edict(d=model_cfg) 24 | model_cfg.crop_size = (448, 448) 25 | model_cfg.num_max_points = 24 26 | 27 | model = AdaptiveClickModel(model_cfg, with_prev_mask=True, use_disks=True, norm_radius=5) 28 | model.backbone.init_weights_from_pretrained(model_cfg.MODEL.IMAGENET_PRETRAINED_MODELS) 29 | model.to(cfg.device) 30 | return model, model_cfg 31 | 32 | 33 | def train(model, cfg, model_cfg): 34 | cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size 35 | cfg.val_batch_size = cfg.batch_size 36 | crop_size = model_cfg.crop_size 37 | 38 | loss_cfg = edict() 39 | loss_cfg.instance_loss_weight = 1.0 40 | 41 | train_augmentator = Compose([ 42 | UniformRandomResize(scale_range=(0.75, 1.40)), 43 | HorizontalFlip(), 44 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 45 | RandomCrop(*crop_size), 46 | RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), 47 | RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) 48 | ], p=1.0) 49 | 50 | val_augmentator = Compose([ 51 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 52 | RandomCrop(*crop_size) 53 | ], p=1.0) 54 | 55 | points_sampler = 
MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, 56 | merge_objects_prob=0.15, 57 | max_num_merged_objects=2) 58 | 59 | trainset = CocoLvisDataset( 60 | cfg.LVIS_v1_PATH, 61 | split='train', 62 | augmentator=train_augmentator, 63 | min_object_area=1000, 64 | keep_background_prob=0.05, 65 | points_sampler=points_sampler, 66 | epoch_len=30000, 67 | stuff_prob=0.30 68 | ) 69 | 70 | valset = CocoLvisDataset( 71 | cfg.LVIS_v1_PATH, 72 | split='val', 73 | augmentator=val_augmentator, 74 | min_object_area=1000, 75 | points_sampler=points_sampler, 76 | epoch_len=2000 77 | ) 78 | 79 | optimizer_params = { 80 | 'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8 81 | } 82 | 83 | lr_scheduler = partial( 84 | torch.optim.lr_scheduler.MultiStepLR, 85 | milestones=[40, 60], gamma=0.1 86 | ) 87 | 88 | trainer = AdaptiveClickTrainer( 89 | model=model, cfg=cfg, model_cfg=model_cfg, loss_cfg=loss_cfg, 90 | trainset=trainset, valset=valset, 91 | optimizer='adam', 92 | optimizer_params=optimizer_params, 93 | lr_scheduler=lr_scheduler, 94 | checkpoint_interval=[(0, 30), (50, 1)], 95 | image_dump_interval=100, 96 | metrics=[AdaptiveIoU(from_logits=False)], 97 | max_interactive_points=model_cfg.num_max_points, 98 | max_num_next_clicks=3 99 | ) 100 | trainer.run(num_epochs=60, validation=False) 101 | -------------------------------------------------------------------------------- /models/iter_mask/adaptiveclick_base448_sbd_itermask.py: -------------------------------------------------------------------------------- 1 | from fvcore.common.config import CfgNode 2 | 3 | from isegm.utils.exp_imports.default import * 4 | from isegm.engine.adaptiveclick_trainer import AdaptiveClickTrainer 5 | from isegm.model.is_adaptiveclick_model import AdaptiveClickModel 6 | 7 | MODEL_NAME = 'adaptiveclick_base448_sbd' 8 | MODEL_CONFIG = 'configs/adaptiveclick_plainvit_base448.yaml' 9 | 10 | 11 | def main(cfg): 12 | model, model_cfg = init_model(cfg) 13 | train(model, cfg, model_cfg) 14 | 15 | 16 | def init_model(cfg): 17 | # fetch config for model 18 | model_cfg = CfgNode.load_yaml_with_base(MODEL_CONFIG, allow_unsafe=True) 19 | # merge args to model_config 20 | for k, v in cfg.__dict__.items(): 21 | model_cfg[k] = v 22 | 23 | model_cfg = edict(d=model_cfg) 24 | model_cfg.crop_size = (448, 448) 25 | model_cfg.num_max_points = 24 26 | 27 | model = AdaptiveClickModel(model_cfg, with_prev_mask=True, use_disks=True, norm_radius=5) 28 | model.backbone.init_weights_from_pretrained(model_cfg.MODEL.IMAGENET_PRETRAINED_MODELS) 29 | model.to(cfg.device) 30 | return model, model_cfg 31 | 32 | 33 | def train(model, cfg, model_cfg): 34 | cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size 35 | cfg.val_batch_size = cfg.batch_size 36 | crop_size = model_cfg.crop_size 37 | 38 | loss_cfg = edict() 39 | loss_cfg.instance_loss_weight = 1.0 40 | 41 | train_augmentator = Compose([ 42 | UniformRandomResize(scale_range=(0.75, 1.25)), 43 | Flip(), 44 | RandomRotate90(), 45 | ShiftScaleRotate(shift_limit=0.03, scale_limit=0, 46 | rotate_limit=(-3, 3), border_mode=0, p=0.75), 47 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 48 | RandomCrop(*crop_size), 49 | RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), 50 | RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) 51 | ], p=1.0) 52 | 53 | val_augmentator = Compose([ 54 | UniformRandomResize(scale_range=(0.75, 1.25)), 55 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 
56 | RandomCrop(*crop_size) 57 | ], p=1.0) 58 | 59 | points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, 60 | merge_objects_prob=0.15, 61 | max_num_merged_objects=2) 62 | 63 | trainset = SBDDataset( 64 | cfg.SBD_PATH, 65 | split='train', 66 | augmentator=train_augmentator, 67 | min_object_area=20, 68 | keep_background_prob=0.01, 69 | points_sampler=points_sampler, 70 | samples_scores_path='./assets/sbd_samples_weights.pkl', 71 | samples_scores_gamma=1.25, 72 | ) 73 | 74 | valset = SBDDataset( 75 | cfg.SBD_PATH, 76 | split='val', 77 | augmentator=val_augmentator, 78 | min_object_area=20, 79 | points_sampler=points_sampler, 80 | epoch_len=500 81 | ) 82 | 83 | optimizer_params = { 84 | 'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8 85 | } 86 | 87 | lr_scheduler = partial( 88 | torch.optim.lr_scheduler.MultiStepLR, 89 | milestones=[40, 60], gamma=0.1 90 | ) 91 | 92 | trainer = AdaptiveClickTrainer( 93 | model=model, cfg=cfg, model_cfg=model_cfg, loss_cfg=loss_cfg, 94 | trainset=trainset, valset=valset, 95 | optimizer='adam', 96 | optimizer_params=optimizer_params, 97 | lr_scheduler=lr_scheduler, 98 | checkpoint_interval=[(0, 30), (50, 1)], 99 | image_dump_interval=100, 100 | metrics=[AdaptiveIoU(from_logits=False)], 101 | max_interactive_points=model_cfg.num_max_points, 102 | max_num_next_clicks=3 103 | ) 104 | trainer.run(num_epochs=60, validation=False) 105 | -------------------------------------------------------------------------------- /models/iter_mask/adaptiveclick_huge448_cocolvis_itermask.py: -------------------------------------------------------------------------------- 1 | from fvcore.common.config import CfgNode 2 | 3 | from isegm.utils.exp_imports.default import * 4 | from isegm.engine.adaptiveclick_trainer import AdaptiveClickTrainer 5 | from isegm.model.is_adaptiveclick_model import AdaptiveClickModel 6 | 7 | MODEL_NAME = 'adaptiveclick_huge448_cocolvis' 8 | MODEL_CONFIG = 'configs/adaptiveclick_plainvit_huge448.yaml' 9 | 10 | 11 | def main(cfg): 12 | model, model_cfg = init_model(cfg) 13 | train(model, cfg, model_cfg) 14 | 15 | 16 | def init_model(cfg): 17 | # fetch config for model 18 | model_cfg = CfgNode.load_yaml_with_base(MODEL_CONFIG, allow_unsafe=True) 19 | # merge args to model_config 20 | for k, v in cfg.__dict__.items(): 21 | model_cfg[k] = v 22 | 23 | model_cfg = edict(d=model_cfg) 24 | model_cfg.crop_size = (448, 448) 25 | model_cfg.num_max_points = 24 26 | 27 | model = AdaptiveClickModel(model_cfg, with_prev_mask=True, use_disks=True, norm_radius=5) 28 | model.backbone.init_weights_from_pretrained(model_cfg.MODEL.IMAGENET_PRETRAINED_MODELS) 29 | model.to(cfg.device) 30 | return model, model_cfg 31 | 32 | 33 | def train(model, cfg, model_cfg): 34 | cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size 35 | cfg.val_batch_size = cfg.batch_size 36 | crop_size = model_cfg.crop_size 37 | 38 | 39 | loss_cfg = edict() 40 | loss_cfg.instance_loss_weight = 1.0 41 | 42 | train_augmentator = Compose([ 43 | UniformRandomResize(scale_range=(0.75, 1.40)), 44 | HorizontalFlip(), 45 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 46 | RandomCrop(*crop_size), 47 | RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), 48 | RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) 49 | ], p=1.0) 50 | 51 | val_augmentator = Compose([ 52 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 53 | RandomCrop(*crop_size) 54 | ], p=1.0) 55 | 56 | 
points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, 57 | merge_objects_prob=0.15, 58 | max_num_merged_objects=2) 59 | 60 | trainset = CocoLvisDataset( 61 | cfg.LVIS_v1_PATH, 62 | split='train', 63 | augmentator=train_augmentator, 64 | min_object_area=1000, 65 | keep_background_prob=0.05, 66 | points_sampler=points_sampler, 67 | epoch_len=30000, 68 | stuff_prob=0.30 69 | ) 70 | 71 | valset = CocoLvisDataset( 72 | cfg.LVIS_v1_PATH, 73 | split='val', 74 | augmentator=val_augmentator, 75 | min_object_area=1000, 76 | points_sampler=points_sampler, 77 | epoch_len=2000 78 | ) 79 | 80 | optimizer_params = { 81 | 'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8 82 | } 83 | 84 | lr_scheduler = partial( 85 | torch.optim.lr_scheduler.MultiStepLR, 86 | milestones=[40, 60], gamma=0.1 87 | ) 88 | 89 | trainer = AdaptiveClickTrainer( 90 | model=model, cfg=cfg, model_cfg=model_cfg, loss_cfg=loss_cfg, 91 | trainset=trainset, valset=valset, 92 | optimizer='adam', 93 | optimizer_params=optimizer_params, 94 | lr_scheduler=lr_scheduler, 95 | checkpoint_interval=[(0, 30), (50, 1)], 96 | image_dump_interval=100, 97 | metrics=[AdaptiveIoU(from_logits=False)], 98 | max_interactive_points=model_cfg.num_max_points, 99 | max_num_next_clicks=3 100 | ) 101 | trainer.run(num_epochs=60, validation=False) 102 | 103 | -------------------------------------------------------------------------------- /models/iter_mask/adaptiveclick_huge448_sbd_itermask.py: -------------------------------------------------------------------------------- 1 | from fvcore.common.config import CfgNode 2 | 3 | from isegm.utils.exp_imports.default import * 4 | from isegm.engine.adaptiveclick_trainer import AdaptiveClickTrainer 5 | from isegm.model.is_adaptiveclick_model import AdaptiveClickModel 6 | 7 | MODEL_NAME = 'adaptiveclick_huge448_sbd' 8 | MODEL_CONFIG = 'configs/adaptiveclick_plainvit_huge448.yaml' 9 | 10 | 11 | def main(cfg): 12 | model, model_cfg = init_model(cfg) 13 | train(model, cfg, model_cfg) 14 | 15 | 16 | def init_model(cfg): 17 | # fetch config for model 18 | model_cfg = CfgNode.load_yaml_with_base(MODEL_CONFIG, allow_unsafe=True) 19 | # merge args to model_config 20 | for k, v in cfg.__dict__.items(): 21 | model_cfg[k] = v 22 | 23 | model_cfg = edict(d=model_cfg) 24 | model_cfg.crop_size = (448, 448) 25 | model_cfg.num_max_points = 24 26 | 27 | model = AdaptiveClickModel(model_cfg, with_prev_mask=True, use_disks=True, norm_radius=5) 28 | model.backbone.init_weights_from_pretrained(model_cfg.MODEL.IMAGENET_PRETRAINED_MODELS) 29 | model.to(cfg.device) 30 | return model, model_cfg 31 | 32 | 33 | def train(model, cfg, model_cfg): 34 | cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size 35 | cfg.val_batch_size = cfg.batch_size 36 | crop_size = model_cfg.crop_size 37 | 38 | loss_cfg = edict() 39 | loss_cfg.instance_loss_weight = 1.0 40 | 41 | train_augmentator = Compose([ 42 | UniformRandomResize(scale_range=(0.75, 1.25)), 43 | Flip(), 44 | RandomRotate90(), 45 | ShiftScaleRotate(shift_limit=0.03, scale_limit=0, 46 | rotate_limit=(-3, 3), border_mode=0, p=0.75), 47 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 48 | RandomCrop(*crop_size), 49 | RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), 50 | RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) 51 | ], p=1.0) 52 | 53 | val_augmentator = Compose([ 54 | UniformRandomResize(scale_range=(0.75, 1.25)), 55 | PadIfNeeded(min_height=crop_size[0], 
min_width=crop_size[1], border_mode=0), 56 | RandomCrop(*crop_size) 57 | ], p=1.0) 58 | 59 | points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, 60 | merge_objects_prob=0.15, 61 | max_num_merged_objects=2) 62 | 63 | trainset = SBDDataset( 64 | cfg.SBD_PATH, 65 | split='train', 66 | augmentator=train_augmentator, 67 | min_object_area=20, 68 | keep_background_prob=0.01, 69 | points_sampler=points_sampler, 70 | samples_scores_path='./assets/sbd_samples_weights.pkl', 71 | samples_scores_gamma=1.25, 72 | ) 73 | 74 | valset = SBDDataset( 75 | cfg.SBD_PATH, 76 | split='val', 77 | augmentator=val_augmentator, 78 | min_object_area=20, 79 | points_sampler=points_sampler, 80 | epoch_len=500 81 | ) 82 | 83 | optimizer_params = { 84 | 'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8 85 | } 86 | 87 | lr_scheduler = partial( 88 | torch.optim.lr_scheduler.MultiStepLR, 89 | milestones=[40, 60], gamma=0.1 90 | ) 91 | 92 | trainer = AdaptiveClickTrainer( 93 | model=model, cfg=cfg, model_cfg=model_cfg, loss_cfg=loss_cfg, 94 | trainset=trainset, valset=valset, 95 | optimizer='adam', 96 | optimizer_params=optimizer_params, 97 | lr_scheduler=lr_scheduler, 98 | checkpoint_interval=[(0, 30), (50, 1)], 99 | image_dump_interval=100, 100 | metrics=[AdaptiveIoU(from_logits=False)], 101 | max_interactive_points=model_cfg.num_max_points, 102 | max_num_next_clicks=3 103 | ) 104 | trainer.run(num_epochs=60, validation=False) 105 | 106 | -------------------------------------------------------------------------------- /models/iter_mask/simpleclick_base448_cocolvis_itermask.py: -------------------------------------------------------------------------------- 1 | from isegm.utils.exp_imports.default import * 2 | from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss 3 | 4 | MODEL_NAME = 'cocolvis_plainvit_base448' 5 | 6 | 7 | def main(cfg): 8 | model, model_cfg = init_model(cfg) 9 | train(model, cfg, model_cfg) 10 | 11 | 12 | def init_model(cfg): 13 | model_cfg = edict() 14 | model_cfg.crop_size = (448, 448) 15 | model_cfg.num_max_points = 24 16 | 17 | backbone_params = dict( 18 | img_size=model_cfg.crop_size, 19 | patch_size=(16, 16), 20 | in_chans=3, 21 | embed_dim=768, 22 | depth=12, 23 | num_heads=12, 24 | mlp_ratio=4, 25 | qkv_bias=True, 26 | ) 27 | 28 | neck_params = dict( 29 | in_dim=768, 30 | out_dims=[128, 256, 512, 1024], 31 | ) 32 | 33 | head_params = dict( 34 | in_channels=[128, 256, 512, 1024], 35 | in_index=[0, 1, 2, 3], 36 | dropout_ratio=0.1, 37 | num_classes=1, 38 | loss_decode=CrossEntropyLoss(), 39 | align_corners=False, 40 | upsample=cfg.upsample, 41 | channels={'x1': 256, 'x2': 128, 'x4': 64}[cfg.upsample], 42 | ) 43 | 44 | model = PlainVitModel( 45 | use_disks=True, 46 | norm_radius=5, 47 | with_prev_mask=True, 48 | backbone_params=backbone_params, 49 | neck_params=neck_params, 50 | head_params=head_params, 51 | random_split=cfg.random_split, 52 | ) 53 | 54 | model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_BASE) 55 | model.to(cfg.device) 56 | 57 | return model, model_cfg 58 | 59 | 60 | def train(model, cfg, model_cfg): 61 | cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size 62 | cfg.val_batch_size = cfg.batch_size 63 | crop_size = model_cfg.crop_size 64 | 65 | loss_cfg = edict() 66 | loss_cfg.instance_loss = AdaptiveFocalLossSigmoid(gamma=2, delta=0.4) 67 | loss_cfg.instance_loss_weight = 1.0 68 | 69 | train_augmentator = Compose([ 70 | UniformRandomResize(scale_range=(0.75, 1.40)), 71 | HorizontalFlip(), 72 
| PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 73 | RandomCrop(*crop_size), 74 | RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), 75 | RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) 76 | ], p=1.0) 77 | 78 | val_augmentator = Compose([ 79 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 80 | RandomCrop(*crop_size) 81 | ], p=1.0) 82 | 83 | points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, 84 | merge_objects_prob=0.15, 85 | max_num_merged_objects=2) 86 | 87 | trainset = CocoLvisDataset( 88 | cfg.LVIS_v1_PATH, 89 | split='train', 90 | augmentator=train_augmentator, 91 | min_object_area=1000, 92 | keep_background_prob=0.05, 93 | points_sampler=points_sampler, 94 | epoch_len=30000, 95 | stuff_prob=0.30 96 | ) 97 | 98 | valset = CocoLvisDataset( 99 | cfg.LVIS_v1_PATH, 100 | split='val', 101 | augmentator=val_augmentator, 102 | min_object_area=1000, 103 | points_sampler=points_sampler, 104 | epoch_len=2000 105 | ) 106 | 107 | optimizer_params = { 108 | 'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8, 109 | } 110 | 111 | lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, 112 | milestones=[50, 55], gamma=0.1) 113 | trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, 114 | trainset, valset, 115 | optimizer='adam', 116 | optimizer_params=optimizer_params, 117 | layerwise_decay=cfg.layerwise_decay, 118 | lr_scheduler=lr_scheduler, 119 | checkpoint_interval=[(0, 20), (50, 1)], 120 | image_dump_interval=300, 121 | metrics=[AdaptiveIoU()], 122 | max_interactive_points=model_cfg.num_max_points, 123 | max_num_next_clicks=3) 124 | trainer.run(num_epochs=55, validation=False) 125 | -------------------------------------------------------------------------------- /models/iter_mask/simpleclick_huge448_cocolvis_itermask.py: -------------------------------------------------------------------------------- 1 | from isegm.utils.exp_imports.default import * 2 | from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss 3 | 4 | MODEL_NAME = 'cocolvis_vit_huge448' 5 | 6 | 7 | def main(cfg): 8 | model, model_cfg = init_model(cfg) 9 | train(model, cfg, model_cfg) 10 | 11 | 12 | def init_model(cfg): 13 | model_cfg = edict() 14 | model_cfg.crop_size = (448, 448) 15 | model_cfg.num_max_points = 24 16 | 17 | backbone_params = dict( 18 | img_size=model_cfg.crop_size, 19 | patch_size=(14,14), 20 | in_chans=3, 21 | embed_dim=1280, 22 | depth=32, 23 | num_heads=16, 24 | mlp_ratio=4, 25 | qkv_bias=True, 26 | ) 27 | 28 | neck_params = dict( 29 | in_dim = 1280, 30 | out_dims = [240, 480, 960, 1920], 31 | ) 32 | 33 | head_params = dict( 34 | in_channels=[240, 480, 960, 1920], 35 | in_index=[0, 1, 2, 3], 36 | dropout_ratio=0.1, 37 | num_classes=1, 38 | loss_decode=CrossEntropyLoss(), 39 | align_corners=False, 40 | upsample=cfg.upsample, 41 | channels={'x1': 256, 'x2': 128, 'x4': 64}[cfg.upsample], 42 | ) 43 | 44 | model = PlainVitModel( 45 | use_disks=True, 46 | norm_radius=5, 47 | with_prev_mask=True, 48 | backbone_params=backbone_params, 49 | neck_params=neck_params, 50 | head_params=head_params, 51 | random_split=cfg.random_split, 52 | ) 53 | 54 | model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_HUGE) 55 | model.to(cfg.device) 56 | 57 | return model, model_cfg 58 | 59 | 60 | def train(model, cfg, model_cfg): 61 | cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size 62 | cfg.val_batch_size = 
cfg.batch_size 63 | crop_size = model_cfg.crop_size 64 | 65 | loss_cfg = edict() 66 | loss_cfg.instance_loss = AdaptiveFocalLossSigmoid(gamma=2, delta=0.4) 67 | loss_cfg.instance_loss_weight = 1.0 68 | 69 | train_augmentator = Compose([ 70 | UniformRandomResize(scale_range=(0.75, 1.40)), 71 | HorizontalFlip(), 72 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 73 | RandomCrop(*crop_size), 74 | RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), 75 | RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) 76 | ], p=1.0) 77 | 78 | val_augmentator = Compose([ 79 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 80 | RandomCrop(*crop_size) 81 | ], p=1.0) 82 | 83 | points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, 84 | merge_objects_prob=0.15, 85 | max_num_merged_objects=2) 86 | 87 | trainset = CocoLvisDataset( 88 | cfg.LVIS_v1_PATH, 89 | split='train', 90 | augmentator=train_augmentator, 91 | min_object_area=1000, 92 | keep_background_prob=0.05, 93 | points_sampler=points_sampler, 94 | epoch_len=30000, 95 | stuff_prob=0.30 96 | ) 97 | 98 | valset = CocoLvisDataset( 99 | cfg.LVIS_v1_PATH, 100 | split='val', 101 | augmentator=val_augmentator, 102 | min_object_area=1000, 103 | points_sampler=points_sampler, 104 | epoch_len=2000 105 | ) 106 | 107 | optimizer_params = { 108 | 'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8 109 | } 110 | 111 | lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, 112 | milestones=[50, 55], gamma=0.1) 113 | trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, 114 | trainset, valset, 115 | optimizer='adam', 116 | optimizer_params=optimizer_params, 117 | layerwise_decay=cfg.layerwise_decay, 118 | lr_scheduler=lr_scheduler, 119 | checkpoint_interval=[(0, 20), (50, 1)], 120 | image_dump_interval=300, 121 | metrics=[AdaptiveIoU()], 122 | max_interactive_points=model_cfg.num_max_points, 123 | max_num_next_clicks=3) 124 | trainer.run(num_epochs=55, validation=False) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.3.4 2 | torch==1.12.1 3 | torchvision==0.13.1 4 | setuptools==52.0.0 5 | albumentations==0.5.2 6 | PyYAML==6.0 7 | easydict==1.9 8 | tensorboard==2.8.0 9 | opencv-python-headless==4.5.3.56 10 | mmcv==1.6.2 11 | timm==0.6.11 12 | Cython==0.29.32 13 | fvcore==0.1.5.post20221221 -------------------------------------------------------------------------------- /scripts/analyze_image_size.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.insert(0, '/playpen-raid2/qinliu/projects/iSegFormer') 4 | 5 | from isegm.data.datasets.grabcut import GrabCutDataset 6 | from isegm.data.compose import ComposeDataset, ProportionalComposeDataset 7 | from isegm.data.datasets.berkeley import BerkeleyDataset 8 | from isegm.data.datasets.coco import CocoDataset 9 | from isegm.data.datasets.davis import DavisDataset 10 | from isegm.data.datasets.grabcut import GrabCutDataset 11 | from isegm.data.datasets.coco_lvis import CocoLvisDataset 12 | from isegm.data.datasets.lvis import LvisDataset 13 | from isegm.data.datasets.lvis_v1 import Lvis_v1_Dataset 14 | from isegm.data.datasets.openimages import OpenImagesDataset 15 | from isegm.data.datasets.sbd import SBDDataset, SBDEvaluationDataset 16 | from
isegm.data.datasets.images_dir import ImagesDirDataset 17 | from isegm.data.datasets.ade20k import ADE20kDataset 18 | from isegm.data.datasets.pascalvoc import PascalVocDataset 19 | from isegm.data.datasets.brats import BraTSDataset 20 | from isegm.data.datasets.ssTEM import ssTEMDataset 21 | from isegm.data.datasets.oai_zib import OAIZIBDataset 22 | from isegm.data.datasets.oai import OAIDataset 23 | 24 | # Evaluation datasets 25 | GRABCUT_PATH="/playpen-raid2/qinliu/data/GrabCut" 26 | BERKELEY_PATH="/playpen-raid/qinliu/data/Berkeley" 27 | DAVIS_PATH="/playpen-raid/qinliu/data/DAVIS" 28 | COCO_MVAL_PATH="/playpen-raid/qinliu/data/COCO_MVal" 29 | 30 | BraTS_PATH="/playpen-raid/qinliu/data/BraTS20" 31 | ssTEM_PATH="/playpen-raid/qinliu/data/ssTEM" 32 | 33 | OAIZIB_PATH="/playpen-raid2/qinliu/data/OAI-ZIB" 34 | OAI_PATH="/playpen-raid2/qinliu/data/OAI" 35 | SBD_PATH="/playpen-raid/qinliu/data/SBD/dataset" 36 | PASCALVOC_PATH="/playpen-raid/qinliu/data/PascalVOC" 37 | 38 | 39 | def get_dataset(dataset_name): 40 | if dataset_name == 'GrabCut': 41 | dataset = GrabCutDataset(GRABCUT_PATH) 42 | elif dataset_name == 'Berkeley': 43 | dataset = BerkeleyDataset(BERKELEY_PATH) 44 | elif dataset_name == 'DAVIS': 45 | dataset = DavisDataset(DAVIS_PATH) 46 | elif dataset_name == 'SBD': 47 | dataset = SBDEvaluationDataset(SBD_PATH) 48 | elif dataset_name == 'SBD_Train': 49 | dataset = SBDEvaluationDataset(SBD_PATH, split='train') 50 | elif dataset_name == 'PascalVOC': 51 | dataset = PascalVocDataset(PASCALVOC_PATH, split='val') 52 | elif dataset_name == 'COCO_MVal': 53 | dataset = DavisDataset(COCO_MVAL_PATH) 54 | elif dataset_name == 'BraTS': 55 | dataset = BraTSDataset(BraTS_PATH) 56 | elif dataset_name == 'ssTEM': 57 | dataset = ssTEMDataset(ssTEM_PATH) 58 | elif dataset_name == 'OAIZIB': 59 | dataset = OAIZIBDataset(OAIZIB_PATH) 60 | else: 61 | dataset = None 62 | 63 | return dataset 64 | 65 | 66 | GrabCut = get_dataset('GrabCut') 67 | Berkeley = get_dataset('Berkeley') 68 | DAVIS = get_dataset('DAVIS') 69 | SBD = get_dataset('SBD') 70 | PascalVOC = get_dataset('PascalVOC') 71 | COCO_MVal = get_dataset('COCO_MVal') 72 | BraTS = get_dataset('BraTS') 73 | ssTEM = get_dataset('ssTEM') 74 | OAIZIB = get_dataset('OAIZIB') 75 | 76 | print('Length of each evaluation dataset.') 77 | # print('GrabCut: ', len(GrabCut)) 78 | # print('Berkeley: ', len(Berkeley)) 79 | # print('DAVIS: ', len(DAVIS)) 80 | # print('SBD: ', len(SBD)) 81 | # print('PascalVOC: ', len(PascalVOC)) 82 | # print('COCO_MVal: ', len(COCO_MVal)) 83 | # print('BraTS: ', len(BraTS)) 84 | # print('ssTEM: ', len(ssTEM)) 85 | # print('OAIZIB: ', len(OAIZIB)) 86 | 87 | dataset_names = ['GrabCut'] 88 | # dataset_names = ['GrabCut', 'Berkeley', 'DAVIS', 'SBD', 'PascalVOC', 'COCO_MVal', 'BraTS', 'ssTEM', 'OAIZIB'] 89 | xs, ys, labels = [], [], [] 90 | for dataset_name in dataset_names: 91 | dataset = get_dataset(dataset_name) 92 | print(dataset_name, len(dataset)) 93 | for i in range(len(dataset)): 94 | sample = dataset.get_sample(i) 95 | print(sample.image.shape) 96 | x, y, _ = sample.image.shape 97 | xs.append(x) 98 | ys.append(y) 99 | labels.append(dataset_name) 100 | 101 | import matplotlib.pyplot as plt 102 | import pandas as pd 103 | 104 | df = pd.DataFrame(dict(x=xs, y=ys, label=labels)) 105 | groups = df.groupby('label') 106 | 107 | fig, ax = plt.subplots() 108 | # ax.margins(0.5) 109 | for name, group in groups: 110 | ax.plot(group.x, group.y, marker='o', linestyle='', label=name) 111 | ax.legend() 112 | ax.grid() 113 | 114 | plt.show() 
-------------------------------------------------------------------------------- /scripts/annotations_conversion/ade20k.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | from pathlib import Path 3 | from scipy.io import loadmat 4 | 5 | from scripts.annotations_conversion.common import parallel_map 6 | 7 | 8 | ADE20K_STUFF_CLASSES = ['water', 'wall', 'snow', 'sky', 'sea', 'sand', 'road', 'route', 'river', 'path', 'mountain', 9 | 'mount', 'land', 'ground', 'soil', 'hill', 'grass', 'floor', 'flooring', 'field', 'earth', 10 | 'ground', 'fence', 'ceiling', 'wave', 'crosswalk', 'hay bale', 'bridge', 'span', 'building', 11 | 'edifice', 'cabinet', 'cushion', 'curtain', 'drape', 'drapery', 'mantle', 'pall', 'door', 12 | 'fencing', 'house', 'pole', 'seat', 'windowpane', 'window', 'tree', 'towel', 'table', 13 | 'stairs', 'steps', 'streetlight', 'street lamp', 'sofa', 'couch', 'lounge', 'skyscraper', 14 | 'signboard', 'sign', 'sidewalk', 'pavement', 'shrub', 'bush', 'rug', 'carpet'] 15 | 16 | 17 | def worker_annotations_loader(anno_pair, dataset_path): 18 | image_id, folder = anno_pair 19 | n_masks = len(list((dataset_path / folder).glob(f'{image_id}_*.png'))) 20 | 21 | # each image has several layers with instances, 22 | # each layer has mask name and instance_to_class mapping 23 | layers = [{ 24 | 'mask_name': f'{image_id}_{suffix}.png', 25 | 'instance_to_class': {}, 26 | 'object_instances': [], 27 | 'stuff_instances': [] 28 | } for suffix in ['seg'] + [f'parts_{i}' for i in range(1, n_masks)]] 29 | 30 | # parse txt with instance to class mappings 31 | with (dataset_path / folder / (image_id + "_atr.txt")).open('r') as f: 32 | for line in f: 33 | # instance_id layer_n is_occluded class_names class_name_raw attributes 34 | line = line.strip().split('#') 35 | inst_id, layer_n, class_names = int(line[0]), int(line[1]), line[3] 36 | 37 | # there may be more than one class name for each instance 38 | class_names = [name.strip() for name in class_names.split(',')] 39 | 40 | # check if any of classes is stuff 41 | if set(class_names) & set(ADE20K_STUFF_CLASSES): 42 | layers[layer_n]['stuff_instances'].append(inst_id) 43 | else: 44 | layers[layer_n]['object_instances'].append(inst_id) 45 | layers[layer_n]['instance_to_class'][inst_id] = class_names 46 | 47 | return layers 48 | 49 | 50 | def load_and_parse_annotations(dataset_path, dataset_split, n_jobs=1): 51 | dataset_split_folder = 'training' if dataset_split == 'train' else 'validation' 52 | 53 | orig_annotations = loadmat(dataset_path / 'index_ade20k.mat', squeeze_me=True, struct_as_record=True) 54 | image_ids = [image_id.split('.')[0] for image_id in orig_annotations['index'].item()[0] 55 | if dataset_split in image_id] 56 | folders = [Path(folder).relative_to('ADE20K_2016_07_26') for folder in orig_annotations['index'].item()[1] 57 | if dataset_split_folder in folder] 58 | 59 | # list of dictionaries with filename and instance to class mapping 60 | all_layers = parallel_map(list(zip(image_ids, folders)), worker_annotations_loader, n_jobs=n_jobs, 61 | use_kwargs=False, const_args={ 62 | 'dataset_path': dataset_path 63 | }) 64 | 65 | return image_ids, folders, all_layers 66 | 67 | 68 | def create_annotations(dataset_path, dataset_split='train', n_jobs=1): 69 | anno_path = dataset_path / f'{dataset_split}-annotations-object-segmentation.pkl' 70 | image_ids, folders, all_layers = load_and_parse_annotations(dataset_path, dataset_split, n_jobs=n_jobs) 71 | 72 | # create dictionary with 
annotations 73 | annotations = {} 74 | for index, image_id in enumerate(image_ids): 75 | annotations[image_id] = { 76 | 'folder': folders[index], 77 | 'layers': all_layers[index] 78 | } 79 | 80 | with anno_path.open('wb') as f: 81 | pkl.dump(annotations, f) 82 | 83 | return annotations 84 | -------------------------------------------------------------------------------- /scripts/annotations_conversion/openimages.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import pickle as pkl 3 | from pathlib import Path 4 | from collections import defaultdict 5 | 6 | 7 | def create_annotations(dataset_path, dataset_split='train'): 8 | dataset_path = Path(dataset_path) 9 | _split_path = dataset_path / dataset_split 10 | _images_path = _split_path / 'images' 11 | _masks_path = _split_path / 'masks' 12 | clean_anno_path = _split_path / f'{dataset_split}-annotations-object-segmentation_clean.pkl' 13 | 14 | annotations = { 15 | 'image_id_to_masks': defaultdict(list), # mapping from image_id to a list of masks 16 | 'dataset_samples': [] # list of unique image ids 17 | } 18 | 19 | with open(_split_path / f'{dataset_split}-annotations-object-segmentation.csv', 'r') as f: 20 | reader = csv.DictReader(f, delimiter=',') 21 | for row in reader: 22 | image_id = row['ImageID'] 23 | mask_path = row['MaskPath'] 24 | 25 | if (_images_path / f'{image_id}.jpg').is_file() \ 26 | and (_masks_path / mask_path).is_file(): 27 | annotations['image_id_to_masks'][image_id].append(mask_path) 28 | annotations['dataset_samples'] = list(annotations['image_id_to_masks'].keys()) 29 | 30 | with clean_anno_path.open('wb') as f: 31 | pkl.dump(annotations, f) 32 | 33 | return annotations 34 | -------------------------------------------------------------------------------- /scripts/convert_annotations.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import multiprocessing as mp 4 | from pathlib import Path 5 | 6 | sys.path.insert(0, '.') 7 | from isegm.utils.exp import load_config_file 8 | from scripts.annotations_conversion import openimages, ade20k, coco_lvis 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('dataset', choices=['openimages', 'ade20k', 'coco_lvis'], help='') 15 | parser.add_argument('--split', nargs='+', choices=['train', 'val', 'test'], type=str, default=['train', 'val'], 16 | help='') 17 | parser.add_argument('--n-jobs', type=int, default=10) 18 | parser.add_argument('--config-path', type=str, default='./config.yml', 19 | help='The path to the config file.') 20 | 21 | args = parser.parse_args() 22 | cfg = load_config_file(args.config_path, return_edict=True) 23 | return args, cfg 24 | 25 | 26 | def main(): 27 | mp.set_start_method('spawn') 28 | args, cfg = parse_args() 29 | 30 | for split in args.split: 31 | if args.dataset == 'openimages': 32 | openimages.create_annotations(Path(cfg.OPENIMAGES_PATH), dataset_split=split) 33 | elif args.dataset == 'ade20k' and split != 'test': 34 | ade20k.create_annotations(Path(cfg.ADE20K_PATH), dataset_split=split, n_jobs=args.n_jobs) 35 | elif args.dataset == 'coco_lvis': 36 | coco_lvis.create_annotations(Path(cfg.LVIS_PATH), Path(cfg.COCO_PATH), dataset_split=split) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /scripts/draw_radar.py: 
-------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | 5 | 6 | def example_data(): 7 | data = [ 8 | # ('NoC@80%', [ 9 | # [1.36, 1.30, 2.82, 1.75, 2.62, 5.52, 2.15, 13.50, 1.77], 10 | # [0 for _ in range(9)], 11 | # [1.56, 1.56, 3.19, 2.41, 3.41, 6.10, 3.35, 13.31, 2.48], 12 | # [0 for _ in range(9)], 13 | # [0 for _ in range(9)]]), 14 | ('NoC@85%', [ 15 | [1.36, 1.40, 3.79, 1.94, 3.47, 7.18, 2.80, 16.38, 2.05], 16 | [0 for _ in range(9)], 17 | [1.70, 1.83, 4.41, 2.71, 4.46, 7.76, 4.05, 16.71, 2.90], 18 | [0 for _ in range(9)], 19 | [0 for _ in range(9)]]), 20 | ('NoC@90%', [ 21 | [1.50, 2.08, 5.11, 2.25, 5.54, 10.82, 5.55, 19.17, 2.79], 22 | [0 for _ in range(9)], 23 | [1.92, 2.79, 6.08, 3.17, 6.79, 11.24, 6.20, 19.09, 3.80], 24 | [0 for _ in range(9)], 25 | [0 for _ in range(9)]]) 26 | ] 27 | return data 28 | 29 | 30 | if __name__ == '__main__': 31 | N = 9 32 | models = ('SOTA', 'ViT-B-224', 'ViT-B-448', 'ViT-L-224', 'ViT-L-448') 33 | num_models = len(models) 34 | 35 | colors = ('#1aaf6c', '#429bf4', '#d42cea', '#feea11', '#00ffff') 36 | 37 | datasets = ['GrabCut', 'Berkeley', 'DAVIS', 'Pascal', 'SBD', 'BraTS', 'ssTEM', 'OAI-ZIB', 'COCO_MVal'] 38 | num_datasets = len(datasets) 39 | 40 | angles = np.linspace(0, 2 * np.pi, num_datasets, endpoint=False).tolist() 41 | angles += angles[:1] 42 | 43 | fig, axs = plt.subplots(figsize=(12, 6), nrows=1, ncols=2, subplot_kw=dict(polar=True)) 44 | fig.subplots_adjust(wspace=0.5, hspace=0.05, top=0.85, bottom=0.05) 45 | 46 | def add_to_radar(label, values, color, alpha=0.25): 47 | values += values[:1] 48 | ax.plot(angles, values, color=color, linewidth=1, label=label) 49 | ax.fill(angles, values, color=color, alpha=alpha) 50 | 51 | metric_data = example_data() 52 | for ax, (metric, data) in zip(axs.flat, metric_data): 53 | 54 | # add each model to the chart 55 | for i, model in enumerate(models): 56 | values = data[i] 57 | add_to_radar(models[i], values, color=colors[i]) 58 | 59 | ax.set_title(metric, weight='bold', size='medium', position=(0.5, 1.1), 60 | horizontalalignment='center', verticalalignment='center') 61 | 62 | ax.set_theta_offset(np.pi / 2) 63 | ax.set_theta_direction(-1) 64 | ax.set_thetagrids(np.degrees(angles[:-1]), datasets) 65 | 66 | # Go through labels and adjust alignment based on where 67 | # it is in the circle. 68 | for label, angle in zip(ax.get_xticklabels(), angles): 69 | if angle in (0, np.pi): 70 | label.set_horizontalalignment('center') 71 | elif 0 < angle < np.pi: 72 | label.set_horizontalalignment('left') 73 | else: 74 | label.set_horizontalalignment('right') 75 | 76 | ax.set_ylim(0, 20) 77 | 78 | # Add some custom styling. 79 | # Change the color of the tick labels. 80 | ax.tick_params(colors='#222222') 81 | # Make the y-axis (0-20) labels smaller. 82 | ax.tick_params(axis='y', labelsize=8) 83 | # Change the color of the circular gridlines. 84 | ax.grid(color='#AAAAAA') 85 | # Change the color of the outermost gridline (the spine). 86 | ax.spines['polar'].set_color('#222222') 87 | # Change the background color inside the circle itself.
88 | ax.set_facecolor('#FAFAFA') 89 | 90 | axs[0].legend(loc=(1.1, 0.95)) 91 | 92 | # axs[0].legend(loc=(1.2, 0.95), labelspacing=0.1, fontsize='small') 93 | # fig.text(0.5, 0.965, 'Comparison Results', 94 | # horizontalalignment='center', color='black', weight='bold', 95 | # size='large') 96 | 97 | plt.show() -------------------------------------------------------------------------------- /scripts/draw_radar_natural.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | 5 | 6 | def example_data(): 7 | data = [ 8 | # ('NoC@80%', [ 9 | # [1.36, 1.30, 2.82, 1.75, 2.62, 5.52, 2.15, 13.50, 1.77], 10 | # [0 for _ in range(9)], 11 | # [1.56, 1.56, 3.19, 2.41, 3.41, 6.10, 3.35, 13.31, 2.48], 12 | # [0 for _ in range(9)], 13 | # [0 for _ in range(9)]]), 14 | ('NoC@85%', [ 15 | [1.36, 1.40, 3.79, 1.94, 3.47, 2.05], 16 | [0 for _ in range(6)], 17 | [1.70, 1.83, 4.41, 2.71, 4.46, 2.90], 18 | [0 for _ in range(6)], 19 | [0 for _ in range(6)]]), 20 | ('NoC@90%', [ 21 | [1.50, 2.08, 5.11, 2.25, 5.54, 2.79], 22 | [0 for _ in range(6)], 23 | [1.92, 2.79, 6.08, 3.17, 6.79, 3.80], 24 | [0 for _ in range(6)], 25 | [0 for _ in range(6)]]) 26 | ] 27 | return data 28 | 29 | 30 | if __name__ == '__main__': 31 | N = 6 32 | models = ('SOTA', 'ViT-B-224', 'ViT-B-448', 'ViT-L-224', 'ViT-L-448') 33 | num_models = len(models) 34 | 35 | colors = ('#1aaf6c', '#429bf4', '#d42cea', '#feea11', '#00ffff') 36 | 37 | datasets = ['GrabCut', 'Berkeley', 'DAVIS', 'Pascal', 'SBD', 'COCO_MVal'] 38 | num_datasets = len(datasets) 39 | 40 | angles = np.linspace(0, 2 * np.pi, num_datasets, endpoint=False).tolist() 41 | angles += angles[:1] 42 | 43 | fig, axs = plt.subplots(figsize=(12, 6), nrows=1, ncols=2, subplot_kw=dict(polar=True)) 44 | fig.subplots_adjust(wspace=0.5, hspace=0.05, top=0.85, bottom=0.05) 45 | 46 | def add_to_radar(label, values, color, alpha=0.25): 47 | values += values[:1] 48 | ax.plot(angles, values, color=color, linewidth=1, label=label) 49 | ax.fill(angles, values, color=color, alpha=alpha) 50 | 51 | metric_data = example_data() 52 | for ax, (metric, data) in zip(axs.flat, metric_data): 53 | 54 | # add each model to the chart 55 | for i, model in enumerate(models): 56 | values = data[i] 57 | add_to_radar(models[i], values, color=colors[i]) 58 | 59 | ax.set_title(metric, weight='bold', size='medium', position=(0.5, 1.1), 60 | horizontalalignment='center', verticalalignment='center') 61 | 62 | ax.set_theta_offset(np.pi / 2) 63 | ax.set_theta_direction(-1) 64 | ax.set_thetagrids(np.degrees(angles[:-1]), datasets) 65 | 66 | # Go through labels and adjust alignment based on where 67 | # it is in the circle. 68 | for label, angle in zip(ax.get_xticklabels(), angles): 69 | if angle in (0, np.pi): 70 | label.set_horizontalalignment('center') 71 | elif 0 < angle < np.pi: 72 | label.set_horizontalalignment('left') 73 | else: 74 | label.set_horizontalalignment('right') 75 | 76 | ax.set_ylim(0, 7) 77 | 78 | # Add some custom styling. 79 | # Change the color of the tick labels. 80 | ax.tick_params(colors='#222222') 81 | # Make the y-axis (0-7) labels smaller. 82 | ax.tick_params(axis='y', labelsize=8) 83 | # Change the color of the circular gridlines. 84 | ax.grid(color='#AAAAAA') 85 | # Change the color of the outermost gridline (the spine). 86 | ax.spines['polar'].set_color('#222222') 87 | # Change the background color inside the circle itself.
88 | ax.set_facecolor('#FAFAFA') 89 | 90 | axs[0].legend(loc=(1.1, 0.95)) 91 | 92 | # axs[0].legend(loc=(1.2, 0.95), labelspacing=0.1, fontsize='small') 93 | # fig.text(0.5, 0.965, 'Comparison Results', 94 | # horizontalalignment='center', color='black', weight='bold', 95 | # size='large') 96 | 97 | plt.show() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import importlib.util 4 | 5 | import torch 6 | from isegm.utils.exp import init_experiment 7 | 8 | 9 | def main(): 10 | args = parse_args() 11 | if args.temp_model_path: 12 | model_script = load_module(args.temp_model_path) 13 | else: 14 | model_script = load_module(args.model_path) 15 | 16 | model_base_name = getattr(model_script, 'MODEL_NAME', None) 17 | 18 | args.distributed = 'WORLD_SIZE' in os.environ 19 | cfg = init_experiment(args, model_base_name) 20 | 21 | torch.backends.cudnn.benchmark = True 22 | torch.multiprocessing.set_sharing_strategy('file_system') 23 | 24 | model_script.main(cfg) 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser() 29 | 30 | parser.add_argument('model_path', type=str, 31 | help='Path to the model script.') 32 | 33 | parser.add_argument('--exp-name', type=str, default='', 34 | help='Here you can specify the name of the experiment. ' 35 | 'It will be added as a suffix to the experiment folder.') 36 | 37 | parser.add_argument('--workers', type=int, default=4, 38 | metavar='N', help='Dataloader threads.') 39 | 40 | parser.add_argument('--batch-size', type=int, default=-1, 41 | help='You can override the model batch size by specifying a positive number.') 42 | 43 | parser.add_argument('--ngpus', type=int, default=1, 44 | help='Number of GPUs. ' 45 | 'If you only specify the "--gpus" argument, the ngpus value will be calculated automatically. ' 46 | 'You should use either this argument or "--gpus".') 47 | 48 | parser.add_argument('--gpus', type=str, default='', required=False, 49 | help='IDs of the GPUs to use. You should use either this argument or "--ngpus".') 50 | 51 | parser.add_argument('--resume-exp', type=str, default=None, 52 | help='The prefix of the name of the experiment to be continued. ' 53 | 'If you use this field, you must specify the "--resume-prefix" argument.') 54 | 55 | parser.add_argument('--resume-prefix', type=str, default='latest', 56 | help='The prefix of the name of the checkpoint to be loaded.') 57 | 58 | parser.add_argument('--start-epoch', type=int, default=0, 59 | help='The number of the starting epoch from which training will continue. 
' 60 | '(it is important for correct logging and learning rate)') 61 | 62 | parser.add_argument('--weights', type=str, default=None, 63 | help='Model weights will be loaded from the specified path if you use this argument.') 64 | 65 | parser.add_argument('--temp-model-path', type=str, default='', 66 | help='Do not use this argument (for internal purposes).') 67 | 68 | parser.add_argument("--local_rank", type=int, default=0) 69 | 70 | # parameters for experimenting 71 | parser.add_argument('--layerwise-decay', action='store_true', 72 | help='Layer-wise decay for transformer blocks.') 73 | 74 | parser.add_argument('--upsample', type=str, default='x1', 75 | help='Upsample the output.') 76 | 77 | parser.add_argument('--random-split', action='store_true', 78 | help='Randomly split the patches instead of using a window split.') 79 | 80 | return parser.parse_args() 81 | 82 | 83 | def load_module(script_path): 84 | spec = importlib.util.spec_from_file_location("model_script", script_path) 85 | model_script = importlib.util.module_from_spec(spec) 86 | spec.loader.exec_module(model_script) 87 | 88 | return model_script 89 | 90 | 91 | if __name__ == '__main__': 92 | main() --------------------------------------------------------------------------------
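A minimal usage sketch, based only on the arguments defined in train.py above; it assumes the dataset paths (e.g. LVIS_v1_PATH, SBD_PATH) and the ImageNet-pretrained backbone weights referenced by the model scripts are already configured as the repository expects, and the experiment name "first_try" is just a placeholder:

python train.py models/iter_mask/adaptiveclick_base448_cocolvis_itermask.py --gpus=0 --workers=4 --batch-size=16 --exp-name=first_try

The positional model_path argument selects one of the training recipes under models/iter_mask/. --batch-size overrides the default of 32 set inside each script's train() function (any value below 1 falls back to 32), --gpus (or alternatively --ngpus) selects the devices, and --exp-name is appended as a suffix to the experiment folder.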