├── .DS_Store ├── .gitignore ├── DET ├── DETECTION.md ├── configs │ ├── coco.py │ ├── common.py │ └── convmae │ │ ├── convmae_base_mask_rcnn_FPN_100ep.py │ │ ├── convmae_base_mask_rcnn_FPN_25ep.py │ │ └── convmae_base_mask_rcnn_FPN_50ep.py ├── lazyconfig_train_net.py ├── models │ ├── __init__.py │ ├── convmae.py │ └── modeling │ │ ├── __init__.py │ │ ├── postprocessing.py │ │ └── rcnn.py └── utils │ ├── __init__.py │ └── pos_embed.py ├── FINETUNE.md ├── LICENSE ├── PRETRAIN.md ├── README.md ├── SEG ├── SEGMENTATION.md ├── backbone │ └── convmae.py ├── configs │ ├── _base_ │ │ ├── datasets │ │ │ ├── ade20k.py │ │ │ └── ade20k_640x640.py │ │ ├── default_runtime.py │ │ ├── models │ │ │ └── upernet.py │ │ └── schedules │ │ │ ├── schedule_160k.py │ │ │ └── schedule_320k.py │ └── convmae │ │ └── upernet_convmae_base_512_slide_160k_ade20k.py ├── mmcv_custom │ ├── __init__.py │ ├── apex_runner │ │ ├── __init__.py │ │ ├── apex_iter_based_runner.py │ │ ├── checkpoint.py │ │ └── optimizer.py │ ├── checkpoint.py │ ├── layer_decay_optimizer_constructor.py │ ├── resize_transform.py │ └── train_api.py └── tools │ ├── dist_test.sh │ ├── dist_train.sh │ ├── flops.py │ ├── test.py │ └── train.py ├── engine_finetune.py ├── engine_pretrain.py ├── figures ├── ConvMAE.png ├── Downstream.png └── feat_map.JPG ├── main_finetune.py ├── main_linprobe.py ├── main_pretrain.py ├── models_convmae.py ├── models_convvit.py ├── submitit_finetune.py ├── submitit_linprobe.py ├── submitit_pretrain.py ├── util ├── crop.py ├── datasets.py ├── lars.py ├── lr_decay.py ├── lr_sched.py ├── misc.py └── pos_embed.py └── vision_transformer.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VL/ConvMAE/4c97b4c9ec9c85724bc9594eb7302c803ae58c19/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /DET/DETECTION.md: -------------------------------------------------------------------------------- 1 | # ConvMAE: Masked Convolution Meets Masked Autoencoders 2 | 3 | This folder contains the implementation of the ConvMAE transfer learning for object detection on COCO. 4 | 5 | ## Pipeline 6 | 7 | ![tenser](../figures/Downstream.png) 8 | 9 | 10 | ## Model Zoo 11 | | Models | Pretrain | Pretrain Epochs | Finetune Epochs | #Params(M)| FLOPs(T) | box AP | mask AP | logs/weights | 12 | | :---: | :---: | :---: |:---: | :---: | :---: | :---: | :---: | :---: | 13 | | ConvMAE-B | IN1K w/o lables | 1600 | 25 | 104 | 0.9 | 53.2 | 47.1 | [log](https://drive.google.com/file/d/1vQ9ps-TxeS_8BRfSWZh-X-5Kki7mgIgR/view?usp=sharing)/[weight](https://drive.google.com/file/d/17gy2mlrRVpIlQN9ERSHh98VkHhWINn-m/view?usp=sharing) | 14 | 15 | ## Usage 16 | 17 | ### Install 18 | - Clone this repo: 19 | 20 | ```bash 21 | git clone https://github.com/Alpha-VL/ConvMAE 22 | cd ConvMAE/DET 23 | ``` 24 | 25 | - Create a conda environment and activate it: 26 | ``` 27 | conda create -n mimdet python=3.9 28 | conda activate mimdet 29 | ``` 30 | * Install `torch==1.9.0` and `torchvision==0.10.0` 31 | * Install [`Detectron2==0.6`](https://github.com/facebookresearch/detectron2), follow [d2 doc](https://detectron2.readthedocs.io/tutorials/install.html). 32 | * Install [`timm==0.4.12`](https://github.com/rwightman/pytorch-image-models), follow [timm doc](https://fastai.github.io/timmdocs/). 33 | * Install [`einops`](https://github.com/arogozhnikov/einops), follow [einops repo](https://github.com/arogozhnikov/einops#installation--). 34 | * Prepare [`COCO`](https://cocodataset.org/#home) dataset, follow [d2 doc](https://detectron2.readthedocs.io/en/latest/tutorials/builtin_datasets.html). 35 | 36 | ### Data preparation 37 | You can download the COCO-2017 [here](https://cocodataset.org) and prepare the COCO follow this format: 38 | 39 | ```tree data 40 | ├── data 41 | │ ├── coco 42 | │ │ ├── annotations 43 | │ │ ├── train2017 44 | │ │ ├── val2017 45 | │ │ ├── test2017 46 | ``` 47 | It is suggested to link the data path as: 48 | ```bash 49 | export DETECTRON2_DATASETS=/path/to/data 50 | ``` 51 | 52 | ### Evaluation 53 | Download the finetuned model [here](https://drive.google.com/file/d/17gy2mlrRVpIlQN9ERSHh98VkHhWINn-m/view?usp=sharing). 
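The generic inference command below takes a config file, a GPU count, and the path to the downloaded checkpoint. As a concrete illustration only (the config path is from this repo; the checkpoint location and GPU count are placeholders to adjust):

```bash
# illustrative single-node evaluation run; point train.init_checkpoint at the downloaded weights
python lazyconfig_train_net.py \
    --config-file configs/convmae/convmae_base_mask_rcnn_FPN_25ep.py \
    --num-gpus 8 \
    --eval-only \
    train.init_checkpoint=/path/to/convmae_base_mask_rcnn_FPN_25ep.pth
```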
54 | 55 | ``` 56 | # inference 57 | python lazyconfig_train_net.py --config-file --num-gpus --eval-only train.init_checkpoint= 58 | ``` 59 | 60 | 61 | ### Training 62 | Download the pretrained ConvMAE model [here](https://drive.google.com/file/d/1AEPivXw0A0b_m5EwEi6fg2pOAoDr8C31/view?usp=sharing). 63 | 64 | ``` 65 | # single-machine training 66 | python lazyconfig_train_net.py --config-file --num-gpus model.backbone.bottom_up.pretrained= 67 | 68 | # multi-machine training 69 | python lazyconfig_train_net.py --config-file --num-gpus --num-machines --master_addr --master_port model.backbone.bottom_up.pretrained= 70 | ``` 71 | -------------------------------------------------------------------------------- /DET/configs/coco.py: -------------------------------------------------------------------------------- 1 | import detectron2.data.transforms as T 2 | from detectron2.config import LazyCall as L 3 | from detectron2.data import ( 4 | DatasetMapper, 5 | build_detection_test_loader, 6 | build_detection_train_loader, 7 | get_detection_dataset_dicts, 8 | ) 9 | from detectron2.evaluation import COCOEvaluator 10 | from omegaconf import OmegaConf 11 | 12 | dataloader = OmegaConf.create() 13 | 14 | dataloader.train = L(build_detection_train_loader)( 15 | dataset=L(get_detection_dataset_dicts)(names="coco_2017_train"), 16 | mapper=L(DatasetMapper)( 17 | is_train=True, 18 | augmentations=[ 19 | L(T.ResizeShortestEdge)( 20 | short_edge_length=(800,), sample_style="choice", max_size=1333, 21 | ), 22 | L(T.RandomFlip)(horizontal=True), 23 | ], 24 | image_format="RGB", 25 | use_instance_mask=False, 26 | ), 27 | total_batch_size=16, 28 | num_workers=4, 29 | ) 30 | 31 | dataloader.test = L(build_detection_test_loader)( 32 | dataset=L(get_detection_dataset_dicts)(names="coco_2017_val", filter_empty=False), 33 | mapper=L(DatasetMapper)( 34 | is_train=False, 35 | augmentations=[L(T.ResizeShortestEdge)(short_edge_length=800, max_size=1333),], 36 | image_format="${...train.mapper.image_format}", 37 | ), 38 | num_workers=4, 39 | ) 40 | 41 | dataloader.evaluator = L(COCOEvaluator)(dataset_name="${..test.dataset.names}",) 42 | -------------------------------------------------------------------------------- /DET/configs/common.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | from omegaconf import DictConfig 3 | 4 | import torch 5 | from detectron2.config import LazyCall as L 6 | from detectron2.modeling.meta_arch import GeneralizedRCNN 7 | from detectron2.solver import WarmupParamScheduler 8 | from detectron2.structures import Instances 9 | from detectron2.utils.events import get_event_storage 10 | from fvcore.common.param_scheduler import MultiStepParamScheduler 11 | 12 | from models.modeling import _postprocess 13 | 14 | 15 | class GeneralizedRCNNImageListForward(GeneralizedRCNN): 16 | def __init__(self, *args, **kwargs): 17 | self.lsj_postprocess = kwargs.pop("lsj_postprocess") 18 | super().__init__(*args, **kwargs) 19 | 20 | def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): 21 | if not self.training: 22 | return self.inference(batched_inputs) 23 | 24 | images = self.preprocess_image(batched_inputs) 25 | if "instances" in batched_inputs[0]: 26 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 27 | else: 28 | gt_instances = None 29 | 30 | features = self.backbone(images) 31 | 32 | if self.proposal_generator is not None: 33 | proposals, proposal_losses = self.proposal_generator( 34 | images, 
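# the RPN uses `images` only for the per-image sizes; proposals are predicted from the FPN feature maps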
features, gt_instances 35 | ) 36 | else: 37 | assert "proposals" in batched_inputs[0] 38 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 39 | proposal_losses = {} 40 | 41 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances) 42 | if self.vis_period > 0: 43 | storage = get_event_storage() 44 | if storage.iter % self.vis_period == 0: 45 | self.visualize_training(batched_inputs, proposals) 46 | 47 | losses = {} 48 | losses.update(detector_losses) 49 | losses.update(proposal_losses) 50 | return losses 51 | 52 | def inference( 53 | self, 54 | batched_inputs: List[Dict[str, torch.Tensor]], 55 | detected_instances: Optional[List[Instances]] = None, 56 | do_postprocess: bool = True, 57 | ): 58 | assert not self.training 59 | 60 | images = self.preprocess_image(batched_inputs) 61 | features = self.backbone(images) 62 | 63 | if detected_instances is None: 64 | if self.proposal_generator is not None: 65 | proposals, _ = self.proposal_generator(images, features, None) 66 | else: 67 | assert "proposals" in batched_inputs[0] 68 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 69 | 70 | results, _ = self.roi_heads(images, features, proposals, None) 71 | else: 72 | detected_instances = [x.to(self.device) for x in detected_instances] 73 | results = self.roi_heads.forward_with_given_boxes( 74 | features, detected_instances 75 | ) 76 | 77 | if do_postprocess: 78 | assert ( 79 | not torch.jit.is_scripting() 80 | ), "Scripting is not supported for postprocess." 81 | if self.lsj_postprocess: 82 | return _postprocess(results, batched_inputs, images.image_sizes) 83 | return GeneralizedRCNN._postprocess( 84 | results, batched_inputs, images.image_sizes 85 | ) 86 | else: 87 | return results 88 | 89 | 90 | def get_fpn_model_parameters( 91 | model, 92 | weight_decay=1e-5, 93 | weight_decay_norm=0.0, 94 | base_lr=4e-5, 95 | skip_list=(), 96 | multiplier=1.5, 97 | ): 98 | parameter_group_vars = {} 99 | for name, param in model.named_parameters(): 100 | if not param.requires_grad: 101 | continue # frozen weights 102 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 103 | group_name = "no_decay" 104 | this_weight_decay = 0.0 105 | elif "norm" in name and weight_decay_norm is not None: 106 | group_name = "decay" 107 | this_weight_decay = weight_decay_norm 108 | else: 109 | group_name = "decay" 110 | this_weight_decay = weight_decay 111 | 112 | if name.startswith("backbone.bottom_up.encoder.patch_embed"): 113 | group_name = "backbone.bottom_up.encoder.patch_embed_%s" % (group_name) 114 | if group_name not in parameter_group_vars: 115 | parameter_group_vars[group_name] = { 116 | "weight_decay": this_weight_decay, 117 | "params": [], 118 | "lr": base_lr, 119 | } 120 | elif name.startswith("backbone.bottom_up.encoder"): 121 | group_name = "backbone.bottom_up.encoder_%s" % (group_name) 122 | if group_name not in parameter_group_vars: 123 | parameter_group_vars[group_name] = { 124 | "weight_decay": this_weight_decay, 125 | "params": [], 126 | "lr": base_lr / multiplier, 127 | } 128 | else: 129 | group_name = "others_%s" % (group_name) 130 | if group_name not in parameter_group_vars: 131 | parameter_group_vars[group_name] = { 132 | "weight_decay": this_weight_decay, 133 | "params": [], 134 | "lr": base_lr * multiplier, 135 | } 136 | 137 | parameter_group_vars[group_name]["params"].append(param) 138 | return list(parameter_group_vars.values()) 139 | 140 | 141 | train = dict( 142 | output_dir="", 143 | init_checkpoint="", 144 | 
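# assumption: 368750 iterations corresponds to ~100 epochs over COCO train2017 (~118k images) at a total batch size of 32 (118000 * 100 / 32)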
max_iter=368750, 145 | amp=dict(enabled=True), # options for Automatic Mixed Precision 146 | ddp=dict( # options for DistributedDataParallel 147 | broadcast_buffers=False, find_unused_parameters=False, fp16_compression=True, 148 | ), 149 | checkpointer=dict(period=5000, max_to_keep=100), # options for PeriodicCheckpointer 150 | eval_period=5000, 151 | log_period=20, 152 | device="cuda" 153 | # ... 154 | ) 155 | 156 | lr_multiplier = L(WarmupParamScheduler)( 157 | scheduler=L(MultiStepParamScheduler)( 158 | values=[1.0, 0.1, 0.01], 159 | milestones=[327778, 355029], 160 | num_updates=train["max_iter"], 161 | ), 162 | warmup_length=500 / train["max_iter"], 163 | warmup_factor=0.067, 164 | ) 165 | 166 | -------------------------------------------------------------------------------- /DET/configs/convmae/convmae_base_mask_rcnn_FPN_100ep.py: -------------------------------------------------------------------------------- 1 | import detectron2.data.transforms as T 2 | import torch 3 | from detectron2.config import LazyCall as L 4 | from detectron2.layers import ShapeSpec 5 | from detectron2.layers.batch_norm import NaiveSyncBatchNorm 6 | from detectron2.modeling.anchor_generator import DefaultAnchorGenerator 7 | from detectron2.modeling.backbone import FPN 8 | from detectron2.modeling.backbone.fpn import LastLevelMaxPool 9 | from detectron2.modeling.box_regression import Box2BoxTransform 10 | from detectron2.modeling.matcher import Matcher 11 | from detectron2.modeling.poolers import ROIPooler 12 | from detectron2.modeling.proposal_generator import RPN, StandardRPNHead 13 | from detectron2.modeling.roi_heads import ( 14 | FastRCNNConvFCHead, 15 | FastRCNNOutputLayers, 16 | MaskRCNNConvUpsampleHead, 17 | StandardROIHeads, 18 | ) 19 | from detectron2.solver import WarmupParamScheduler 20 | from detectron2.solver.build import get_default_optimizer_params 21 | from fvcore.common.param_scheduler import CosineParamScheduler 22 | 23 | from models import ConvViTDet 24 | 25 | from ..coco import dataloader 26 | from ..common import GeneralizedRCNNImageListForward 27 | 28 | model = L(GeneralizedRCNNImageListForward)( 29 | lsj_postprocess=True, 30 | backbone=L(FPN)( 31 | bottom_up=L(ConvViTDet)( 32 | window_size=16, 33 | with_cp=False, 34 | pretrained="./convmae_base.pth", 35 | stop_grad_conv1=False, 36 | sincos_pos_embed=True, 37 | zero_pos_embed=False, 38 | img_size=1024, 39 | patch_size=16, 40 | embed_dim=768, 41 | depth=12, 42 | num_heads=12, 43 | drop_path_rate=0.1, 44 | init_values=None, 45 | beit_qkv_bias=False, 46 | ), 47 | in_features=["s0", "s1", "s2", "s3"], 48 | out_channels=256, 49 | norm="SyncBN", 50 | top_block=L(LastLevelMaxPool)(), 51 | ), 52 | proposal_generator=L(RPN)( 53 | in_features=["p2", "p3", "p4", "p5", "p6"], 54 | head=L(StandardRPNHead)(in_channels=256, num_anchors=3), 55 | anchor_generator=L(DefaultAnchorGenerator)( 56 | sizes=[[32], [64], [128], [256], [512]], 57 | aspect_ratios=[0.5, 1.0, 2.0], 58 | strides=[4, 8, 16, 32, 64], 59 | offset=0.0, 60 | ), 61 | anchor_matcher=L(Matcher)( 62 | thresholds=[0.3, 0.7], labels=[0, -1, 1], allow_low_quality_matches=True 63 | ), 64 | box2box_transform=L(Box2BoxTransform)(weights=[1.0, 1.0, 1.0, 1.0]), 65 | batch_size_per_image=256, 66 | positive_fraction=0.5, 67 | pre_nms_topk=(2000, 1000), 68 | post_nms_topk=(1000, 1000), 69 | nms_thresh=0.7, 70 | ), 71 | roi_heads=L(StandardROIHeads)( 72 | num_classes=80, 73 | batch_size_per_image=512, 74 | positive_fraction=0.25, 75 | proposal_matcher=L(Matcher)( 76 | thresholds=[0.5], labels=[0, 1], 
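# proposals with IoU >= 0.5 against ground truth are labeled foreground (1), all others background (0)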
allow_low_quality_matches=False 77 | ), 78 | box_in_features=["p2", "p3", "p4", "p5"], 79 | box_pooler=L(ROIPooler)( 80 | output_size=7, 81 | scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), 82 | sampling_ratio=0, 83 | pooler_type="ROIAlignV2", 84 | ), 85 | box_head=L(FastRCNNConvFCHead)( 86 | input_shape=ShapeSpec(channels=256, height=7, width=7), 87 | conv_dims=[], 88 | fc_dims=[1024, 1024], 89 | ), 90 | box_predictor=L(FastRCNNOutputLayers)( 91 | input_shape=ShapeSpec(channels=1024), 92 | test_score_thresh=0.05, 93 | box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)), 94 | num_classes="${..num_classes}", 95 | ), 96 | mask_in_features=["p2", "p3", "p4", "p5"], 97 | mask_pooler=L(ROIPooler)( 98 | output_size=14, 99 | scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), 100 | sampling_ratio=0, 101 | pooler_type="ROIAlignV2", 102 | ), 103 | mask_head=L(MaskRCNNConvUpsampleHead)( 104 | input_shape=ShapeSpec(channels=256, width=14, height=14), 105 | num_classes="${..num_classes}", 106 | conv_dims=[256, 256, 256, 256, 256], 107 | ), 108 | ), 109 | pixel_mean=[123.675, 116.280, 103.530], 110 | pixel_std=[58.395, 57.12, 57.375], 111 | input_format="RGB", 112 | ) 113 | # Using NaiveSyncBatchNorm because heads may have empty input. That is not supported by 114 | # torch.nn.SyncBatchNorm. We can remove this after 115 | # https://github.com/pytorch/pytorch/issues/36530 is fixed. 116 | model.roi_heads.box_head.conv_norm = ( 117 | model.roi_heads.mask_head.conv_norm 118 | ) = lambda c: NaiveSyncBatchNorm(c, stats_mode="N") 119 | # fmt: on 120 | 121 | # 2conv in RPN: 122 | # https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/detection/modeling/architecture/heads.py#L95-L97 # noqa: E501, B950 123 | model.proposal_generator.head.conv_dims = [-1, -1] 124 | 125 | # 4conv1fc box head 126 | model.roi_heads.box_head.conv_dims = [256, 256, 256, 256] 127 | model.roi_heads.box_head.fc_dims = [1024] 128 | 129 | optimizer = L(torch.optim.AdamW)( 130 | params=L(get_default_optimizer_params)( 131 | # params.model is meant to be set to the model object, before instantiating 132 | # the optimizer. 133 | weight_decay_norm=0.0, 134 | overrides={ 135 | "pos_embed": {"weight_decay": 0.0}, 136 | "relative_position_bias_table": {"weight_decay": 0.0}, 137 | }, 138 | ), 139 | lr=8.0e-05, 140 | betas=(0.9, 0.999), 141 | weight_decay=0.1, 142 | ) 143 | 144 | lr_multiplier = L(WarmupParamScheduler)( 145 | scheduler=L(CosineParamScheduler)(start_value=1.0, end_value=0.0), 146 | warmup_length=0.125 / 100, 147 | warmup_factor=0.001, 148 | ) 149 | 150 | train = dict( 151 | output_dir="output/convmae_base_mask_rcnn_FPN_100ep", 152 | init_checkpoint="", 153 | max_iter=368750, 154 | amp=dict(enabled=True), # options for Automatic Mixed Precision 155 | ddp=dict( # options for DistributedDataParallel 156 | broadcast_buffers=False, find_unused_parameters=False, fp16_compression=True, 157 | ), 158 | checkpointer=dict(period=3688, max_to_keep=100), # options for PeriodicCheckpointer 159 | eval_period=3688, 160 | log_period=20, 161 | device="cuda" 162 | # ... 
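# evaluation/checkpoint period of 3688 iterations ≈ max_iter / 100, i.e. roughly once per training epoch under this 100-epoch schedule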
163 | ) 164 | 165 | # resize_and_crop_image in: 166 | # https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/detection/utils/input_utils.py#L127 # noqa: E501, B950 167 | image_size = 1024 168 | dataloader.train.total_batch_size = 32 169 | dataloader.train.mapper.augmentations = [ 170 | L(T.ResizeScale)( 171 | min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size 172 | ), 173 | L(T.FixedSizeCrop)(crop_size=(image_size, image_size)), 174 | L(T.RandomFlip)(horizontal=True), 175 | ] 176 | dataloader.train.mapper.use_instance_mask = True 177 | dataloader.train.mapper.image_format = "RGB" 178 | # recompute boxes due to cropping 179 | dataloader.train.mapper.recompute_boxes = True 180 | dataloader.test.mapper.augmentations = [ 181 | L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size), 182 | L(T.FixedSizeCrop)(crop_size=(image_size, image_size)), 183 | ] 184 | dataloader.test.mapper.image_format = "RGB" 185 | dataloader.evaluator.output_dir = "${...train.output_dir}" 186 | -------------------------------------------------------------------------------- /DET/configs/convmae/convmae_base_mask_rcnn_FPN_25ep.py: -------------------------------------------------------------------------------- 1 | from .convmae_base_mask_rcnn_FPN_100ep import ( 2 | dataloader, 3 | lr_multiplier, 4 | model, 5 | optimizer, 6 | train, 7 | ) 8 | 9 | train.max_iter = train.max_iter // 4 # 100ep -> 25ep 10 | 11 | lr_multiplier.warmup_length *= 4 12 | 13 | train.output_dir = "./convmae_base_mask_rcnn_FPN_25ep" 14 | __all__ = ["dataloader", "lr_multiplier", "model", "optimizer", "train"] 15 | -------------------------------------------------------------------------------- /DET/configs/convmae/convmae_base_mask_rcnn_FPN_50ep.py: -------------------------------------------------------------------------------- 1 | from .convmae_base_mask_rcnn_FPN_100ep import ( 2 | dataloader, 3 | lr_multiplier, 4 | model, 5 | optimizer, 6 | train, 7 | ) 8 | 9 | train.max_iter = train.max_iter // 2 # 100ep -> 50ep 10 | 11 | lr_multiplier.warmup_length *= 2 12 | 13 | train.output_dir = "./convmae_base_mask_rcnn_FPN_50ep" 14 | __all__ = ["dataloader", "lr_multiplier", "model", "optimizer", "train"] 15 | -------------------------------------------------------------------------------- /DET/lazyconfig_train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Training script using the new "LazyConfig" python config files. 3 | This scripts reads a given python config file and runs the training or evaluation. 4 | It can be used to train any models or dataset as long as they can be 5 | instantiated by the recursive construction defined in the given config file. 6 | Besides lazy construction of models, dataloader, etc., this scripts expects a 7 | few common configuration parameters currently defined in "configs/common/train.py". 8 | To add more complicated training logic, you can easily add other configs 9 | in the config file and implement a new train_net.py to handle them. 
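Example usage (illustrative config path and GPU count):
    python lazyconfig_train_net.py --config-file configs/convmae/convmae_base_mask_rcnn_FPN_25ep.py --num-gpus 8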
10 | """ 11 | import logging 12 | 13 | from detectron2.checkpoint import DetectionCheckpointer 14 | from detectron2.config import LazyConfig, instantiate 15 | from detectron2.engine import ( 16 | AMPTrainer, 17 | SimpleTrainer, 18 | default_argument_parser, 19 | default_setup, 20 | default_writers, 21 | hooks, 22 | launch, 23 | ) 24 | from detectron2.engine.defaults import create_ddp_model 25 | from detectron2.evaluation import inference_on_dataset, print_csv_format 26 | from detectron2.utils import comm 27 | 28 | logger = logging.getLogger("detectron2") 29 | 30 | 31 | def do_test(cfg, model): 32 | if "evaluator" in cfg.dataloader: 33 | ret = inference_on_dataset( 34 | model, 35 | instantiate(cfg.dataloader.test), 36 | instantiate(cfg.dataloader.evaluator), 37 | ) 38 | print_csv_format(ret) 39 | return ret 40 | 41 | 42 | def do_train(args, cfg): 43 | """ 44 | Args: 45 | cfg: an object with the following attributes: 46 | model: instantiate to a module 47 | dataloader.{train,test}: instantiate to dataloaders 48 | dataloader.evaluator: instantiate to evaluator for test set 49 | optimizer: instantaite to an optimizer 50 | lr_multiplier: instantiate to a fvcore scheduler 51 | train: other misc config defined in `configs/common/train.py`, including: 52 | output_dir (str) 53 | init_checkpoint (str) 54 | amp.enabled (bool) 55 | max_iter (int) 56 | eval_period, log_period (int) 57 | device (str) 58 | checkpointer (dict) 59 | ddp (dict) 60 | """ 61 | model = instantiate(cfg.model) 62 | logger = logging.getLogger("detectron2") 63 | logger.info("Model:\n{}".format(model)) 64 | model.to(cfg.train.device) 65 | 66 | cfg.optimizer.params.model = model 67 | optim = instantiate(cfg.optimizer) 68 | 69 | train_loader = instantiate(cfg.dataloader.train) 70 | 71 | model = create_ddp_model(model, **cfg.train.ddp) 72 | trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)( 73 | model, train_loader, optim 74 | ) 75 | checkpointer = DetectionCheckpointer(model, cfg.train.output_dir, trainer=trainer,) 76 | trainer.register_hooks( 77 | [ 78 | hooks.IterationTimer(), 79 | hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)), 80 | hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer) 81 | if comm.is_main_process() 82 | else None, 83 | hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)), 84 | hooks.PeriodicWriter( 85 | default_writers(cfg.train.output_dir, cfg.train.max_iter), 86 | period=cfg.train.log_period, 87 | ) 88 | if comm.is_main_process() 89 | else None, 90 | ] 91 | ) 92 | 93 | checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume) 94 | if args.resume and checkpointer.has_checkpoint(): 95 | # The checkpoint stores the training iteration that just finished, thus we start 96 | # at the next iteration 97 | start_iter = trainer.iter + 1 98 | else: 99 | start_iter = 0 100 | trainer.train(start_iter, cfg.train.max_iter) 101 | 102 | 103 | def main(args): 104 | cfg = LazyConfig.load(args.config_file) 105 | cfg = LazyConfig.apply_overrides(cfg, args.opts) 106 | default_setup(cfg, args) 107 | 108 | if args.eval_only: 109 | model = instantiate(cfg.model) 110 | model.to(cfg.train.device) 111 | model = create_ddp_model(model) 112 | DetectionCheckpointer(model).load(cfg.train.init_checkpoint) 113 | print(do_test(cfg, model)) 114 | else: 115 | do_train(args, cfg) 116 | 117 | 118 | if __name__ == "__main__": 119 | parser = default_argument_parser() 120 | parser.add_argument("--node_rank", type=int, default=0) 121 | parser.add_argument("--master_addr", 
default="") 122 | parser.add_argument("--master_port", default="") 123 | args = parser.parse_args() 124 | args.machine_rank = args.node_rank 125 | if args.master_addr or args.master_port: 126 | assert args.master_addr and args.master_port 127 | args.dist_url = f"tcp://{args.master_addr}:{args.master_port}" 128 | print(args) 129 | launch( 130 | main, 131 | args.num_gpus, 132 | num_machines=args.num_machines, 133 | machine_rank=args.machine_rank, 134 | dist_url=args.dist_url, 135 | args=(args,), 136 | ) 137 | -------------------------------------------------------------------------------- /DET/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .convmae import ConvViTDet 2 | from .modeling import _postprocess 3 | 4 | __all__ = [ 5 | "ConvViTDet", 6 | "_postprocess", 7 | ] 8 | -------------------------------------------------------------------------------- /DET/models/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .rcnn import _postprocess 2 | 3 | __all__ = ["_postprocess"] 4 | -------------------------------------------------------------------------------- /DET/models/modeling/postprocessing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from detectron2.structures import Instances, ROIMasks 3 | 4 | 5 | # perhaps should rename to "resize_instance" 6 | def detector_postprocess( 7 | results: Instances, 8 | output_height: int, 9 | output_width: int, 10 | mask_threshold: float = 0.5, 11 | ): 12 | """Resize the output instances. The input images are often resized when 13 | entering an object detector. As a result, we often need the outputs of the 14 | detector in a different resolution from its inputs. 15 | 16 | This function will resize the raw outputs of an R-CNN detector 17 | to produce outputs according to the desired output resolution. 18 | 19 | Args: 20 | results (Instances): the raw outputs from the detector. 21 | `results.image_size` contains the input image resolution the detector sees. 22 | This object might be modified in-place. 23 | output_height, output_width: the desired output resolution. 24 | 25 | Returns: 26 | Instances: the resized output from the model, based on the output resolution 27 | """ 28 | if isinstance(output_width, torch.Tensor): 29 | # This shape might (but not necessarily) be tensors during tracing. 30 | # Converts integer tensors to float temporaries to ensure true 31 | # division is performed when computing scale_x and scale_y. 32 | output_width_tmp = output_width.float() 33 | output_height_tmp = output_height.float() 34 | new_size = torch.stack([output_height, output_width]) 35 | else: 36 | new_size = (output_height, output_width) 37 | output_width_tmp = output_width 38 | output_height_tmp = output_height 39 | 40 | scale_x, scale_y = ( 41 | output_width_tmp / results.image_size[1], 42 | output_height_tmp / results.image_size[0], 43 | ) 44 | scale_max = max(scale_x, scale_y) 45 | results = Instances(new_size, **results.get_fields()) 46 | 47 | if results.has("pred_boxes"): 48 | output_boxes = results.pred_boxes 49 | elif results.has("proposal_boxes"): 50 | output_boxes = results.proposal_boxes 51 | else: 52 | output_boxes = None 53 | assert output_boxes is not None, "Predictions must contain boxes!" 
54 | 55 | output_boxes.scale(scale_max, scale_max) 56 | output_boxes.clip((output_height, output_width)) 57 | 58 | results = results[output_boxes.nonempty()] 59 | 60 | if results.has("pred_masks"): 61 | if isinstance(results.pred_masks, ROIMasks): 62 | roi_masks = results.pred_masks 63 | else: 64 | # pred_masks is a tensor of shape (N, 1, M, M) 65 | roi_masks = ROIMasks(results.pred_masks[:, 0, :, :]) 66 | results.pred_masks = roi_masks.to_bitmasks( 67 | results.pred_boxes, output_height, output_width, mask_threshold 68 | ).tensor # TODO return ROIMasks/BitMask object in the future 69 | 70 | if results.has("pred_keypoints"): 71 | results.pred_keypoints[:, :, 0] *= scale_x 72 | results.pred_keypoints[:, :, 1] *= scale_y 73 | 74 | return results 75 | -------------------------------------------------------------------------------- /DET/models/modeling/rcnn.py: -------------------------------------------------------------------------------- 1 | from .postprocessing import detector_postprocess 2 | 3 | 4 | def _postprocess(instances, batched_inputs, image_sizes): 5 | """Rescale the output instances to the target size.""" 6 | # note: private function; subject to changes 7 | processed_results = [] 8 | for results_per_image, input_per_image, image_size in zip( 9 | instances, batched_inputs, image_sizes 10 | ): 11 | height = input_per_image.get("height", image_size[0]) 12 | width = input_per_image.get("width", image_size[1]) 13 | r = detector_postprocess(results_per_image, height, width) 14 | processed_results.append({"instances": r}) 15 | return processed_results 16 | -------------------------------------------------------------------------------- /DET/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VL/ConvMAE/4c97b4c9ec9c85724bc9594eb7302c803ae58c19/DET/utils/__init__.py -------------------------------------------------------------------------------- /DET/utils/pos_embed.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Position embedding utils 3 | # -------------------------------------------------------- 4 | 5 | from typing import Tuple 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | # -------------------------------------------------------- 12 | # 2D sine-cosine position embedding 13 | # References: 14 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 15 | # MoCo v3: https://github.com/facebookresearch/moco-v3 16 | # -------------------------------------------------------- 17 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 18 | """ 19 | grid_size: int of the grid height and width 20 | return: 21 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 22 | """ 23 | grid_h = np.arange(grid_size, dtype=np.float32) 24 | grid_w = np.arange(grid_size, dtype=np.float32) 25 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 26 | grid = np.stack(grid, axis=0) 27 | 28 | grid = grid.reshape([2, 1, grid_size, grid_size]) 29 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 30 | if cls_token: 31 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 32 | return pos_embed 33 | 34 | 35 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 36 | assert embed_dim % 2 == 0 37 | 38 | # use half of dimensions to encode grid_h 39 | emb_h = 
get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 40 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 41 | 42 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 43 | return emb 44 | 45 | 46 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 47 | """ 48 | embed_dim: output dimension for each position 49 | pos: a list of positions to be encoded: size (M,) 50 | out: (M, D) 51 | """ 52 | assert embed_dim % 2 == 0 53 | omega = np.arange(embed_dim // 2, dtype=np.float) 54 | omega /= embed_dim / 2.0 55 | omega = 1.0 / 10000 ** omega # (D/2,) 56 | 57 | pos = pos.reshape(-1) # (M,) 58 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 59 | 60 | emb_sin = np.sin(out) # (M, D/2) 61 | emb_cos = np.cos(out) # (M, D/2) 62 | 63 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 64 | return emb 65 | 66 | 67 | # -------------------------------------------------------- 68 | # Interpolate position embeddings for high-resolution 69 | # References: 70 | # DeiT: https://github.com/facebookresearch/deit 71 | # -------------------------------------------------------- 72 | def interpolate_pos_embed(model, checkpoint_model, pos_embed_key): 73 | if pos_embed_key in checkpoint_model: 74 | pos_embed_checkpoint = checkpoint_model[pos_embed_key] 75 | embedding_size = pos_embed_checkpoint.shape[-1] 76 | num_patches = model.num_patches 77 | if pos_embed_key.startswith("decoder"): 78 | num_extra_tokens = model.decoder_pos_embed.shape[-2] - num_patches 79 | else: 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print( 88 | "Position interpolate from %dx%d to %dx%d" 89 | % (orig_size, orig_size, new_size, new_size) 90 | ) 91 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 92 | # only the position tokens are interpolated 93 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 94 | pos_tokens = pos_tokens.reshape( 95 | -1, orig_size, orig_size, embedding_size 96 | ).permute(0, 3, 1, 2) 97 | pos_tokens = torch.nn.functional.interpolate( 98 | pos_tokens, 99 | size=(new_size, new_size), 100 | mode="bicubic", 101 | align_corners=False, 102 | ) 103 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 104 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 105 | checkpoint_model[pos_embed_key] = new_pos_embed 106 | 107 | 108 | def interpolate_pos_embed_online( 109 | pos_embed, orig_size: Tuple[int], new_size: Tuple[int], num_extra_tokens: int 110 | ): 111 | extra_tokens = pos_embed[:, :num_extra_tokens] 112 | pos_tokens = pos_embed[:, num_extra_tokens:] 113 | embedding_size = pos_tokens.shape[-1] 114 | pos_tokens = pos_tokens.reshape( 115 | -1, orig_size[0], orig_size[1], embedding_size 116 | ).permute(0, 3, 1, 2) 117 | pos_tokens = torch.nn.functional.interpolate( 118 | pos_tokens, size=new_size, mode="bicubic", align_corners=False, 119 | ) 120 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 121 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 122 | return new_pos_embed 123 | -------------------------------------------------------------------------------- /FINETUNE.md: 
-------------------------------------------------------------------------------- 1 | # ConvMAE: Masked Convolution Meets Masked Autoencoders 2 | 3 | This folder contains the implementation of the ConvMAE finetuning for image classification. 4 | 5 | ## Model Zoo 6 | 7 | | Models | #Params(M) | Supervision | Encoder Ratio | Pretrain Epochs | FT acc@1(%) | FT logs/weights | 8 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 9 | | ConvMAE-B | 88 | RGB | 25% | 1600 | 85.0 | [log](https://drive.google.com/file/d/1nzAOD5UR3b9QqwD2vMMz0Bx3170sypuy/view?usp=sharing)/[weight](https://drive.google.com/file/d/19F6vQUlITpzNLvXLKi5NRxRLOmKRxqFi/view?usp=sharing) | 10 | 11 | ## Usage 12 | 13 | ### Install 14 | - Clone this repo: 15 | 16 | ```bash 17 | git clone https://github.com/Alpha-VL/ConvMAE 18 | cd ConvMAE 19 | ``` 20 | 21 | - Create a conda environment and activate it: 22 | ```bash 23 | conda create -n convmae python=3.7 24 | conda activate convmae 25 | ``` 26 | 27 | - Install `Pytorch==1.8.0` and `torchvision==0.9.0` with `CUDA==11.1` 28 | 29 | ```bash 30 | conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge 31 | ``` 32 | 33 | - Install `timm==0.3.2` 34 | 35 | ```bash 36 | pip install timm==0.3.2 37 | ``` 38 | 39 | ### Data preparation 40 | 41 | You can download the ImageNet-1K [here](https://image-net.org) and prepare the ImageNet-1K follow this format: 42 | 43 | ```tree data 44 | imagenet 45 | ├── train 46 | │ ├── class1 47 | │ │ ├── img1.jpeg 48 | │ │ ├── img2.jpeg 49 | │ │ └── ... 50 | │ ├── class2 51 | │ │ ├── img3.jpeg 52 | │ │ └── ... 53 | │ └── ... 54 | └── val 55 | ├── class1 56 | │ ├── img4.jpeg 57 | │ ├── img5.jpeg 58 | │ └── ... 59 | ├── class2 60 | │ ├── img6.jpeg 61 | │ └── ... 62 | └── ... 63 | ``` 64 | ### Evaluation 65 | 66 | Download the finetuned model from [here](https://drive.google.com/file/d/19F6vQUlITpzNLvXLKi5NRxRLOmKRxqFi/view?usp=sharing). 67 | 68 | Evaluate ConvViT-Base by running: 69 | 70 | ```bash 71 | python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py --batch_size 128 --model convvit_base_patch16 --resume ${FINETUNE_CHKPT} --dist_eval --data_path ${IMAGENET_DIR} --eval 72 | ``` 73 | 74 | This shoud give: 75 | 76 | ```bash 77 | * Acc@1 84.982 Acc@5 97.152 loss 0.695 78 | Accuracy of the network on the 50000 test images: 85.0% 79 | ``` 80 | 81 | ### Fine-tuning 82 | Download the pretrained model from [here](https://drive.google.com/file/d/1AEPivXw0A0b_m5EwEi6fg2pOAoDr8C31/view?usp=sharing). 
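Both finetuning commands below specify a base learning rate (`--blr`) rather than an absolute one. Assuming the scripts keep MAE's linear scaling rule (absolute lr = blr × effective batch size / 256; the helper function below is illustrative, the rule itself is MAE's convention and is assumed to carry over here), the effective learning rate works out as follows:

```python
# MAE-style learning-rate scaling (assumed): lr = blr * effective_batch_size / 256
def effective_lr(blr: float, batch_per_gpu: int, num_gpus: int, accum_iter: int = 1) -> float:
    eff_batch = batch_per_gpu * num_gpus * accum_iter
    return blr * eff_batch / 256

# multi-node recipe: 32 per GPU x 8 GPUs x 4 nodes = 1024 images per step
print(effective_lr(5e-4, 32, 8 * 4))   # 0.002
# single-node recipe: 128 per GPU x 8 GPUs = 1024 images per step -> same effective lr
print(effective_lr(5e-4, 128, 8))      # 0.002
```

Both recipes therefore train at the same effective learning rate, which is why they share the same `--blr`.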
83 | 84 | To finetune with multi-node distributed training, run the following on 4 nodes with 8 GPUs each: 85 | ```bash 86 | python submitit_finetune.py \ 87 | --job_dir ${JOB_DIR} \ 88 | --nodes 4 \ 89 | --batch_size 32 \ 90 | --model convvit_base_patch16 \ 91 | --finetune ${PRETRAIN_CHKPT} \ 92 | --epochs 100 \ 93 | --blr 5e-4 --layer_decay 0.65 \ 94 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 95 | --dist_eval --data_path ${IMAGENET_DIR} 96 | ``` 97 | 98 | To finetune with single-node training, run the following on single node with 8 GPUs: 99 | ```bash 100 | python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \ 101 | --batch_size 128 \ 102 | --model convvit_base_patch16 \ 103 | --finetune ${PRETRAIN_CHKPT} \ 104 | --epochs 100 \ 105 | --blr 5e-4 --layer_decay 0.65 \ 106 | --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \ 107 | --dist_eval --data_path ${IMAGENET_DIR} 108 | ``` 109 | #### Notes 110 | - There are chances that loss is nan during finetuning process, if so, just delete the [line](https://github.com/Alpha-VL/ConvMAE/blob/53d56ad2388665bf86e0e029aa3f424e709a6652/engine_finetune.py#L55) to use fp32 type to resume the finetuning from where it broke down. 111 | - How to resume: just add `--resume` into above scripts as: 112 | ```bash 113 | --resume ${CHKPT_RESUME} 114 | ``` 115 | - Also, we are still working to solve the possible gradient vanish caused by fp16 mixed-precision finetuning. Feeling free to contact us if you have any suggestions. 116 | 117 | ### Linear Probing 118 | Download the pretrained model from [here](https://drive.google.com/file/d/1AEPivXw0A0b_m5EwEi6fg2pOAoDr8C31/view?usp=sharing). 119 | 120 | To finetune with multi-node distributed training, run the following on 4 nodes with 8 GPUs each: 121 | ```bash 122 | python submitit_linprobe.py \ 123 | --job_dir ${JOB_DIR} \ 124 | --nodes 4 \ 125 | --batch_size 128 \ 126 | --model convvit_base_patch16 \ 127 | --global_pool \ 128 | --finetune ${PRETRAIN_CHKPT} \ 129 | --epochs 90 \ 130 | --blr 0.1 --weight_decay 0.0 \ 131 | --dist_eval --data_path ${IMAGENET_DIR} 132 | ``` 133 | 134 | To finetune with single-node training, run the following on single node with 8 GPUs: 135 | ```bash 136 | python -m torch.distributed.launch --nproc_per_node=8 main_linprobe.py \ 137 | --batch_size 512 \ 138 | --model convvit_base_patch16 \ 139 | --finetune ${PRETRAIN_CHKPT} \ 140 | --epochs 90 \ 141 | --blr 0.1 --weight_decay 0.0 \ 142 | --dist_eval --data_path ${IMAGENET_DIR} 143 | ``` 144 | 145 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Alpha-VL 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PRETRAIN.md: -------------------------------------------------------------------------------- 1 | ## Pretraining ConvMAE 2 | ## Usage 3 | 4 | ### Install 5 | - Clone this repo: 6 | 7 | ```bash 8 | git clone https://github.com/Alpha-VL/ConvMAE 9 | cd ConvMAE 10 | ``` 11 | 12 | - Create a conda environment and activate it: 13 | ```bash 14 | conda create -n convmae python=3.7 15 | conda activate convmae 16 | ``` 17 | 18 | - Install `Pytorch==1.8.0` and `torchvision==0.9.0` with `CUDA==11.1` 19 | 20 | ```bash 21 | conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge 22 | ``` 23 | 24 | - Install `timm==0.3.2` 25 | 26 | ```bash 27 | pip install timm==0.3.2 28 | ``` 29 | 30 | ### Data preparation 31 | 32 | You can download the ImageNet-1K [here](https://image-net.org) and prepare the ImageNet-1K follow this format: 33 | ```tree data 34 | imagenet 35 | ├── train 36 | ├── class1 37 | │ ├── img1.jpeg 38 | │ ├── img2.jpeg 39 | │ └── ... 40 | ├── class2 41 | │ ├── img3.jpeg 42 | │ └── ... 43 | └── ... 44 | ``` 45 | 46 | ### Training 47 | To pretrain ConvMAE-Base with **multi-node distributed training**, run the following on 3 nodes with 8 GPUs each: 48 | 49 | ```bash 50 | python submitit_pretrain.py \ 51 | --job_dir ${JOB_DIR} \ 52 | --nodes 3 \ 53 | --batch_size 128 \ 54 | --model convmae_convvit_base_patch16 \ 55 | --norm_pix_loss \ 56 | --mask_ratio 0.75 \ 57 | --epochs 1600 \ 58 | --warmup_epochs 40 \ 59 | --blr 1.5e-4 --weight_decay 0.05 \ 60 | --data_path ${IMAGENET_DIR} 61 | ``` 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # [NeurIPS 2022] MCMAE: Masked Convolution Meets Masked Autoencoders
3 | 4 | [Peng Gao](https://scholar.google.com/citations?user=miFIAFMAAAAJ&hl=en&oi=ao)1, [Teli Ma](https://scholar.google.com/citations?user=arny77IAAAAJ&hl=en&oi=ao)1, [Hongsheng Li](https://scholar.google.com/citations?user=BN2Ze-QAAAAJ&hl=en&oi=ao)2, [Ziyi Lin](https://scholar.google.com/citations?user=-VOnnzUAAAAJ&hl=en)2, [Jifeng Dai](https://scholar.google.com/citations?user=SH_-B_AAAAAJ&hl=en&oi=ao)3, [Yu Qiao](https://scholar.google.com/citations?user=gFtI-8QAAAAJ&hl=en&oi=ao)1, 5 | 6 | 1 [Shanghai AI Laboratory](https://www.shlab.org.cn/), 2 [MMLab, CUHK](https://mmlab.ie.cuhk.edu.hk/), 3 [Sensetime Research](https://www.sensetime.com/cn). 7 | 8 |
9 | 10 | \* We change the project name from **ConvMAE** to **MCMAE**. 11 | 12 | This repo is the official implementation of [MCMAE: Masked Convolution Meets Masked Autoencoders](https://arxiv.org/abs/2205.03892). It currently concludes codes and models for the following tasks: 13 | > **ImageNet Pretrain**: See [PRETRAIN.md](PRETRAIN.md).\ 14 | > **ImageNet Finetune**: See [FINETUNE.md](FINETUNE.md).\ 15 | > **Object Detection**: See [DETECTION.md](DET/DETECTION.md).\ 16 | > **Semantic Segmentation**: See [SEGMENTATION.md](SEG/SEGMENTATION.md). \ 17 | > **Video Classification**: See [VideoConvMAE](https://github.com/Alpha-VL/VideoConvMAE). 18 | 19 | ## Updates 20 | 21 | ***14/Mar/2023*** 22 | 23 | MR-MCMAE (a.k.a. ConvMAE-v2) paper released: [Mimic before Reconstruct: Enhancing Masked Autoencoders with Feature Mimicking](https://arxiv.org/abs/2303.05475). 24 | 25 | ***15/Sep/2022*** 26 | 27 | Paper accepted at NeurIPS 2022. 28 | 29 | ***9/Sep/2022*** 30 | 31 | ConvMAE-v2 pretrained checkpoints are released. 32 | 33 | ***21/Aug/2022*** 34 | 35 | [Official-ConvMAE-Det](https://github.com/OpenGVLab/Official-ConvMAE-Det) which follows official ViTDet codebase is released. 36 | 37 | ***08/Jun/2022*** 38 | 39 | 🚀FastConvMAE🚀: significantly accelerates the pretraining hours (4000 single GPU hours => 200 single GPU hours). The code is going to be released at [FastConvMAE](https://github.com/Alpha-VL/FastConvMAE). 40 | 41 | ***27/May/2022*** 42 | 43 | 1. The supported codes for ImageNet-1K pretraining. 44 | 2. The supported codes and models for semantic segmentation are provided. 45 | 46 | ***20/May/2022*** 47 | 48 | Update results on video classification. 49 | 50 | ***16/May/2022*** 51 | 52 | The supported codes and models for COCO object detection and instance segmentation are available. 53 | 54 | ***11/May/2022*** 55 | 56 | 1. Pretrained models on ImageNet-1K for ConvMAE. 57 | 2. The supported codes and models for ImageNet-1K finetuning and linear probing are provided. 58 | 59 | ***08/May/2022*** 60 | 61 | The preprint version is public at [arxiv](https://arxiv.org/abs/2205.03892). 62 | 63 | ## Introduction 64 | ConvMAE framework demonstrates that multi-scale hybrid convolution-transformer can learn more discriminative representations via the mask auto-encoding scheme. 65 | * We present the strong and efficient self-supervised framework ConvMAE, which is easy to implement but show outstanding performances on downstream tasks. 66 | * ConvMAE naturally generates hierarchical representations and exhibit promising performances on object detection and segmentation. 67 | * ConvMAE-Base improves the ImageNet finetuning accuracy by 1.4% compared with MAE-Base. 68 | On object detection with Mask-RCNN, ConvMAE-Base achieves 53.2 box AP and 47.1 mask AP with a 25-epoch training schedule while MAE-Base attains 50.3 box AP and 44.9 mask AP with 100 training epochs. On ADE20K with UperNet, ConvMAE-Base surpasses MAE-Base by 3.6 mIoU (48.1 vs. 51.7). 69 | 70 | 71 | ![tenser](figures/ConvMAE.png) 72 | 73 | ## Pretrain on ImageNet-1K 74 | The following table provides pretrained checkpoints and logs used in the paper. 75 | | | ConvMAE-Base| 76 | | :---: | :---: | 77 | | pretrained checkpoints| [download](https://drive.google.com/file/d/1AEPivXw0A0b_m5EwEi6fg2pOAoDr8C31/view?usp=sharing) | 78 | | logs | [download](https://drive.google.com/file/d/1Je9ClIGCQP43xC3YURVFPnaMRC0-ax1h/view?usp=sharing) | 79 | 80 | The following results are for ConvMAE-v2 (pretrained for 200 epochs on ImageNet-1k). 
81 | | model | pretrained checkpoints | ft. acc. on ImageNet-1k | 82 | | :---: | :---: | :---: | 83 | | ConvMAE-v2-Small | [download](https://drive.google.com/file/d/1LqU-0tajhxYMSTN6WVFwiIveFjETVvKb/view?usp=sharing) | 83.6 | 84 | | ConvMAE-v2-Base | [download](https://drive.google.com/file/d/1gykVKNDlRn8eiuXk5bUj1PbSnHXFzLnI/view?usp=sharing) | 85.7 | 85 | | ConvMAE-v2-Large | [download](https://drive.google.com/file/d/1RN3ZseDseWGwuUwrVTkel17_iYFvZL6m/view?usp=sharing) | 86.8 | 86 | | ConvMAE-v2-Huge | [download](https://drive.google.com/file/d/1k1OBhNTLzRI9c6ReSgK7_7vqGZr-2Cpd/view?usp=sharing) | 88.0 | 87 | 88 | ## Main Results on ImageNet-1K 89 | | Models | #Params(M) | Supervision | Encoder Ratio | Pretrain Epochs | FT acc@1(%) | LIN acc@1(%) | FT logs/weights | LIN logs/weights | 90 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 91 | | BEiT | 88 | DALLE | 100% | 300 | 83.0 | 37.6 | - | - | 92 | | MAE | 88 | RGB | 25% | 1600 | 83.6 | 67.8 | - | - | 93 | | SimMIM | 88 | RGB | 100% | 800 | 84.0 | 56.7 | - | - | 94 | | MaskFeat | 88 | HOG | 100% | 300 | 83.6 | N/A | - | - | 95 | | data2vec | 88 | RGB | 100% | 800 | 84.2 | N/A | - | - | 96 | | ConvMAE-B | 88 | RGB | 25% | 1600 | 85.0 | 70.9 | [log](https://drive.google.com/file/d/1nzAOD5UR3b9QqwD2vMMz0Bx3170sypuy/view?usp=sharing)/[weight](https://drive.google.com/file/d/19F6vQUlITpzNLvXLKi5NRxRLOmKRxqFi/view?usp=sharing) | 97 | 98 | 99 | 100 | ## Main Results on COCO 101 | ### Mask R-CNN 102 | | Models | Pretrain | Pretrain Epochs | Finetune Epochs | #Params(M)| FLOPs(T) | box AP | mask AP | logs/weights | 103 | | :---: | :---: | :---: |:---: | :---: | :---: | :---: | :---: | :---: | 104 | | Swin-B | IN21K w/ labels | 90 | 36 | 109 | 0.7 | 51.4 | 45.4 | - | 105 | | Swin-L | IN21K w/ labels | 90 | 36 | 218 | 1.1 | 52.4 | 46.2 | - | 106 | | MViTv2-B | IN21K w/ labels | 90 | 36 | 73 | 0.6 | 53.1 | 47.4 | - | 107 | | MViTv2-L | IN21K w/ labels | 90 | 36 | 239 | 1.3 | 53.6 | 47.5 | - | 108 | | Benchmarking-ViT-B | IN1K w/o labels | 1600 | 100 | 118 | 0.9 | 50.4 | 44.9 | - | 109 | | Benchmarking-ViT-L | IN1K w/o labels | 1600 | 100 | 340 | 1.9 | 53.3 | 47.2 | - | 110 | | ViTDet | IN1K w/o labels | 1600 | 100 | 111 | 0.8 | 51.2 | 45.5 | - | 111 | | MIMDet-ViT-B | IN1K w/o labels | 1600 | 36 | 127 | 1.1 | 51.5 | 46.0 | - | 112 | | MIMDet-ViT-L | IN1K w/o labels | 1600 | 36 | 345 | 2.6 | 53.3 | 47.5 | - | 113 | | ConvMAE-B | IN1K w/o lables | 1600 | 25 | 104 | 0.9 | 53.2 | 47.1 | [log](https://drive.google.com/file/d/1vQ9ps-TxeS_8BRfSWZh-X-5Kki7mgIgR/view?usp=sharing)/[weight](https://drive.google.com/file/d/17gy2mlrRVpIlQN9ERSHh98VkHhWINn-m/view?usp=sharing) | 114 | 115 | 116 | 117 | ## Main Results on ADE20K 118 | ### UperNet 119 | | Models | Pretrain | Pretrain Epochs| Finetune Iters | #Params(M)| FLOPs(T) | mIoU | logs/weights | 120 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 121 | | DeiT-B | IN1K w/ labels | 300 | 16K | 163 | 0.6 | 45.6 | - | 122 | | Swin-B | IN1K w/ labels | 300 | 16K | 121 | 0.3 | 48.1 | - | 123 | | MoCo V3 | IN1K | 300 | 16K | 163 | 0.6 | 47.3 | - | 124 | | DINO | IN1K | 400 | 16K | 163 | 0.6 | 47.2 | - | 125 | | BEiT | IN1K+DALLE | 1600 | 16K | 163 | 0.6 | 47.1 | - | 126 | | PeCo | IN1K | 300 | 16K | 163 | 0.6 | 46.7 | - | 127 | | CAE | IN1K+DALLE | 800 | 16K | 163 | 0.6 | 48.8 | - | 128 | | MAE | IN1K | 1600 | 16K | 163 | 0.6 | 48.1 | - | 129 | | ConvMAE-B | IN1K | 1600 | 16K | 153 | 0.6 | 51.7 | 
[log](https://drive.google.com/file/d/1N3LEhEd2FLx8777Kn5tVn5gxYiBTz00A/view?usp=sharing)/[weight](https://drive.google.com/file/d/1aQR_CmZBzN2eHWYgzPUDm4ulme-g9cIR/view?usp=sharing) | 130 | 131 | ## Main Results on Kinetics-400 132 | 133 | | Models | Pretrain Epochs | Finetune Epochs | #Params(M) | Top1 | Top5 | logs/weights | 134 | | :---------------------: | :-------------: | :-------------------: | :--------: | :--: | :--: | :----------: | 135 | | VideoMAE-B | 200 | 100 | 87 | 77.8 | | | 136 | | VideoMAE-B | 800 | 100 | 87 | 79.4 | | | 137 | | VideoMAE-B | 1600 | 100 | 87 | 79.8 | | | 138 | | VideoMAE-B | 1600 | 100 (w/ Repeated Aug) | 87 | 80.7 | 94.7 | | 139 | | SpatioTemporalLearner-B | 800 | 150 (w/ Repeated Aug) | 87 | 81.3 | 94.9 | | 140 | | VideoConvMAE-B | 200 | 100 | 86 | 80.1 | 94.3 | Soon | 141 | | VideoConvMAE-B | 800 | 100 | 86 | 81.7 | 95.1 | Soon | 142 | | VideoConvMAE-B-MSD | 800 | 100 | 86 | 82.7 | 95.5 | Soon | 143 | 144 | ## Main Results on Something-Something V2 145 | 146 | | Models | Pretrain Epochs | Finetune Epochs | #Params(M) | Top1 | Top5 | logs/weights | 147 | | :----------------: | :-------------: | :-------------: | :--------: | :--: | :--: | :----------: | 148 | | VideoMAE-B | 200 | 40 | 87 | 66.1 | | | 149 | | VideoMAE-B | 800 | 40 | 87 | 69.3 | | | 150 | | VideoMAE-B | 2400 | 40 | 87 | 70.3 | | | 151 | | VideoConvMAE-B | 200 | 40 | 86 | 67.7 | 91.2 | Soon | 152 | | VideoConvMAE-B | 800 | 40 | 86 | 69.9 | 92.4 | Soon | 153 | | VideoConvMAE-B-MSD | 800 | 40 | 86 | 70.7 | 93.0 | Soon | 154 | 155 | 156 | ## Getting Started 157 | ### Prerequisites 158 | * Linux 159 | * Python 3.7+ 160 | * CUDA 10.2+ 161 | * GCC 5+ 162 | 163 | ### Training and evaluation 164 | * See [PRETRAIN.md](PRETRAIN.md) for pretraining. 165 | * See [FINETUNE.md](FINETUNE.md) for pretrained model finetuning and linear probing. 166 | * See [DETECTION.md](DET/DETECTION.md) for using pretrained backbone on [Mask RCNN](https://openaccess.thecvf.com/content_iccv_2017/html/He_Mask_R-CNN_ICCV_2017_paper.html). 167 | * See [SEGMENTATION.md](SEG/SEGMENTATION.md) for using pretrained backbone on [UperNet](https://openaccess.thecvf.com/content_ECCV_2018/html/Tete_Xiao_Unified_Perceptual_Parsing_ECCV_2018_paper.html). 168 | * See [VideoConvMAE](https://github.com/Alpha-VL/VideoConvMAE) for video classification. 169 | 170 | ## Visualization 171 | ![tenser](figures/feat_map.JPG) 172 | 173 | ## Acknowledgement 174 | The pretraining and finetuning of our project are based on [DeiT](https://github.com/facebookresearch/deit) and [MAE](https://github.com/facebookresearch/mae). The object detection and semantic segmentation parts are based on [MIMDet](https://github.com/hustvl/MIMDet) and [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) respectively. Thanks for their wonderful work. 175 | 176 | ## License 177 | ConvMAE is released under the [MIT License](https://github.com/Alpha-VL/ConvMAE/blob/main/LICENSE). 
178 | 179 | ## Citation 180 | 181 | ```bash 182 | @article{gao2022convmae, 183 | title={ConvMAE: Masked Convolution Meets Masked Autoencoders}, 184 | author={Gao, Peng and Ma, Teli and Li, Hongsheng and Dai, Jifeng and Qiao, Yu}, 185 | journal={arXiv preprint arXiv:2205.03892}, 186 | year={2022} 187 | } 188 | ``` 189 | 190 | 191 | -------------------------------------------------------------------------------- /SEG/SEGMENTATION.md: -------------------------------------------------------------------------------- 1 | # ConvMAE: Masked Convolution Meets Masked Autoencoders 2 | 3 | This folder contains the implementation of the ConvMAE transfer learning for semantic segmentation on ADE-20K. It is based on [MMSegmentation](https://github.com/open-mmlab/mmsegmentation). 4 | 5 | ## Pipeline 6 | 7 | ![tenser](../figures/Downstream.png) 8 | 9 | | Models | Pretrain | Pretrain Epochs| Finetune Iters | #Params(M)| FLOPs(T) | mIoU | logs/weights | 10 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 11 | | ConvMAE-B | IN1K | 1600 | 16K | 153 | 0.6 | 51.7 | [log](https://drive.google.com/file/d/1N3LEhEd2FLx8777Kn5tVn5gxYiBTz00A/view?usp=sharing)/[weight](https://drive.google.com/file/d/1aQR_CmZBzN2eHWYgzPUDm4ulme-g9cIR/view?usp=sharing) | 12 | 13 | ## Usage 14 | 15 | ### Install 16 | - Clone this repo: 17 | 18 | ```bash 19 | git clone https://github.com/Alpha-VL/ConvMAE 20 | cd ConvMAE/SEG 21 | ``` 22 | 23 | - Create a conda environment and activate it: 24 | ```bash 25 | conda create -n upernet python=3.7 26 | conda activate upernet 27 | ``` 28 | 29 | - Install `Pytorch==1.8.0` and `torchvision==0.9.0` with `CUDA==11.1` 30 | 31 | ```bash 32 | conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge 33 | ``` 34 | 35 | - Install the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) library and some required packages. 36 | 37 | ```bash 38 | pip install mmcv-full==1.3.0 mmsegmentation==0.11.0 39 | pip install scipy timm==0.3.2 40 | ``` 41 | 42 | - Install [apex](https://github.com/NVIDIA/apex) for mixed-precision training 43 | 44 | ```bash 45 | git clone https://github.com/NVIDIA/apex 46 | cd apex 47 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 48 | ``` 49 | 50 | ### Data preparation 51 | Follow the guide in [mmseg](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#ade20k) to prepare the ADE20k dataset. 52 | 53 | ### Training 54 | Download the pretrained model [here](https://drive.google.com/file/d/1AEPivXw0A0b_m5EwEi6fg2pOAoDr8C31/view?usp=sharing). 55 | 56 | ```bash 57 | ./tools/dist_train.sh --work-dir --options model.pretrained= 58 | ``` 59 | 60 | For example: 61 | ```bash 62 | ./tools/dist_train.sh \ 63 | configs/convmae/upernet_convmae_base_512_slide_160k_ade20k.py 8 \ 64 | --work-dir /path/to/save \ 65 | --options model.pretrained=/path/to/pretrained/weights 66 | ``` 67 | 68 | ### Evaluation 69 | 70 | Download the fine-tuned checkpoint [here](https://drive.google.com/file/d/1aQR_CmZBzN2eHWYgzPUDm4ulme-g9cIR/view?usp=sharing). 
71 | ``` 72 | ./tools/dist_test.sh --eval mIoU 73 | ``` 74 | 75 | Run 76 | ``` 77 | ./tools/dist_test.sh configs/convmae/upernet_convmae_base_512_slide_160k_ade20k.py /path/to/finetuned/weights 8 --eval mIoU 78 | ``` 79 | 80 | This should give 81 | ``` 82 | +--------+-------+-------+-------+ 83 | | Scope | mIoU | mAcc | aAcc | 84 | +--------+-------+-------+-------+ 85 | | global | 51.66 | 63.88 | 84.45 | 86 | +--------+-------+-------+-------+ 87 | ``` -------------------------------------------------------------------------------- /SEG/configs/_base_/datasets/ade20k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'ADE20KDataset' 3 | data_root = '/mnt/cache/mateli/ADEChallengeData2016/' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | crop_size = (512, 512) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations', reduce_zero_label=True), 10 | dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 16 | dict(type='DefaultFormatBundle'), 17 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 18 | ] 19 | test_pipeline = [ 20 | dict(type='LoadImageFromFile'), 21 | dict( 22 | type='MultiScaleFlipAug', 23 | img_scale=(2048, 512), 24 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 25 | flip=False, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='RandomFlip'), 29 | dict(type='Normalize', **img_norm_cfg), 30 | dict(type='ImageToTensor', keys=['img']), 31 | dict(type='Collect', keys=['img']), 32 | ]) 33 | ] 34 | data = dict( 35 | samples_per_gpu=4, 36 | workers_per_gpu=4, 37 | train=dict( 38 | type=dataset_type, 39 | data_root=data_root, 40 | img_dir='images/training', 41 | ann_dir='annotations/training', 42 | pipeline=train_pipeline), 43 | val=dict( 44 | type=dataset_type, 45 | data_root=data_root, 46 | img_dir='images/validation', 47 | ann_dir='annotations/validation', 48 | pipeline=test_pipeline), 49 | test=dict( 50 | type=dataset_type, 51 | data_root=data_root, 52 | img_dir='images/validation', 53 | ann_dir='annotations/validation', 54 | pipeline=test_pipeline)) 55 | -------------------------------------------------------------------------------- /SEG/configs/_base_/datasets/ade20k_640x640.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'ADE20KDataset' 3 | data_root = 'data/ade/ADEChallengeData2016' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | crop_size = (640, 640) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations', reduce_zero_label=True), 10 | dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 16 | dict(type='DefaultFormatBundle'), 17 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 18 | ] 19 | test_pipeline = [ 20 | dict(type='LoadImageFromFile'), 21 | 
dict( 22 | type='MultiScaleFlipAug', 23 | img_scale=(2560, 640), 24 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 25 | flip=False, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='RandomFlip'), 29 | dict(type='Normalize', **img_norm_cfg), 30 | dict(type='ImageToTensor', keys=['img']), 31 | dict(type='Collect', keys=['img']), 32 | ]) 33 | ] 34 | data = dict( 35 | samples_per_gpu=4, 36 | workers_per_gpu=4, 37 | train=dict( 38 | type=dataset_type, 39 | data_root=data_root, 40 | img_dir='images/training', 41 | ann_dir='annotations/training', 42 | pipeline=train_pipeline), 43 | val=dict( 44 | type=dataset_type, 45 | data_root=data_root, 46 | img_dir='images/validation', 47 | ann_dir='annotations/validation', 48 | pipeline=test_pipeline), 49 | test=dict( 50 | type=dataset_type, 51 | data_root=data_root, 52 | img_dir='images/validation', 53 | ann_dir='annotations/validation', 54 | pipeline=test_pipeline)) 55 | -------------------------------------------------------------------------------- /SEG/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=50, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=False), 6 | # dict(type='TensorboardLoggerHook') 7 | ]) 8 | # yapf:enable 9 | dist_params = dict(backend='nccl') 10 | log_level = 'INFO' 11 | load_from = None 12 | resume_from = None 13 | workflow = [('train', 1)] 14 | cudnn_benchmark = True 15 | -------------------------------------------------------------------------------- /SEG/configs/_base_/models/upernet.py: -------------------------------------------------------------------------------- 1 | norm_cfg = dict(type='SyncBN', requires_grad=True) 2 | model = dict( 3 | type='EncoderDecoder', 4 | pretrained=None, 5 | backbone=dict( 6 | type='XCiT', 7 | patch_size=16, 8 | embed_dim=384, 9 | depth=12, 10 | num_heads=8, 11 | mlp_ratio=4, 12 | qkv_bias=True, 13 | use_abs_pos_emb=True, 14 | use_rel_pos_bias=False, 15 | ), 16 | decode_head=dict( 17 | type='UPerHead', 18 | in_channels=[384, 384, 384, 384], 19 | in_index=[0, 1, 2, 3], 20 | pool_scales=(1, 2, 3, 6), 21 | channels=512, 22 | dropout_ratio=0.1, 23 | num_classes=19, 24 | norm_cfg=norm_cfg, 25 | align_corners=False, 26 | loss_decode=dict( 27 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 28 | auxiliary_head=dict( 29 | type='FCNHead', 30 | in_channels=384, 31 | in_index=2, 32 | channels=256, 33 | num_convs=1, 34 | concat_input=False, 35 | dropout_ratio=0.1, 36 | num_classes=19, 37 | norm_cfg=norm_cfg, 38 | align_corners=False, 39 | loss_decode=dict( 40 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 41 | # model training and testing settings 42 | train_cfg=dict(), 43 | test_cfg=dict(mode='whole')) 44 | -------------------------------------------------------------------------------- /SEG/configs/_base_/schedules/schedule_160k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=160000) 8 | checkpoint_config = dict(by_epoch=False, interval=16000) 9 | evaluation = dict(interval=16000, metric='mIoU') 10 | -------------------------------------------------------------------------------- 
/SEG/configs/_base_/schedules/schedule_320k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=320000) 8 | checkpoint_config = dict(by_epoch=False, interval=32000) 9 | evaluation = dict(interval=32000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /SEG/configs/convmae/upernet_convmae_base_512_slide_160k_ade20k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/upernet.py', '../_base_/datasets/ade20k.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 4 | ] 5 | crop_size = (512, 512) 6 | 7 | model = dict( 8 | pretrained='', 9 | backbone=dict( 10 | type='ConvMAE', 11 | img_size=[512, 128, 64], 12 | patch_size=[4, 2, 2], 13 | embed_dim=[256, 384, 768], 14 | depth=[2, 2, 11], 15 | num_heads=12, 16 | mlp_ratio=[4, 4, 4], 17 | qkv_bias=True, 18 | use_abs_pos_emb=True, 19 | use_rel_pos_bias=True, 20 | init_values=1., 21 | drop_path_rate=0.2, 22 | out_indices=[3, 5, 7, 11] 23 | ), 24 | decode_head=dict( 25 | in_channels=[256, 384, 768, 768], 26 | num_classes=150, 27 | channels=768, 28 | ), 29 | auxiliary_head=dict( 30 | in_channels=768, 31 | num_classes=150 32 | ), 33 | test_cfg = dict(mode='slide', crop_size=crop_size, stride=(341, 341)) 34 | ) 35 | 36 | optimizer = dict(_delete_=True, type='AdamW', lr=1e-4, betas=(0.9, 0.999), weight_decay=0.05, 37 | constructor='LayerDecayOptimizerConstructor', 38 | paramwise_cfg=dict(num_layers=11, layer_decay_rate=0.75)) 39 | 40 | lr_config = dict(_delete_=True, policy='poly', 41 | warmup='linear', 42 | warmup_iters=1500, 43 | warmup_ratio=1e-6, 44 | power=1.0, min_lr=0.0, by_epoch=False) 45 | 46 | # By default, models are trained on 8 GPUs with 2 images per GPU 47 | data=dict(samples_per_gpu=2) 48 | 49 | runner = dict(type='IterBasedRunnerAmp') 50 | 51 | # do not use mmdet version fp16 52 | fp16 = None 53 | optimizer_config = dict( 54 | type="DistOptimizerHook", 55 | update_interval=1, 56 | grad_clip=None, 57 | coalesce=True, 58 | bucket_size_mb=-1, 59 | use_fp16=True, 60 | ) 61 | -------------------------------------------------------------------------------- /SEG/mmcv_custom/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import load_checkpoint 4 | from .layer_decay_optimizer_constructor import LayerDecayOptimizerConstructor 5 | from .resize_transform import SETR_Resize 6 | from .apex_runner.optimizer import DistOptimizerHook 7 | from .train_api import train_segmentor 8 | 9 | __all__ = ['load_checkpoint', 'LayerDecayOptimizerConstructor', 'SETR_Resize', 'DistOptimizerHook', 'train_segmentor'] 10 | -------------------------------------------------------------------------------- /SEG/mmcv_custom/apex_runner/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 
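# This subpackage collects the pieces needed for apex (AMP) mixed-precision training with mmcv:
# an AMP-aware IterBasedRunnerAmp, a save_checkpoint helper that also stores the amp state,
# and a DistOptimizerHook that scales the loss and supports gradient accumulation.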
2 | from .checkpoint import save_checkpoint 3 | from .apex_iter_based_runner import IterBasedRunnerAmp 4 | 5 | 6 | __all__ = [ 7 | 'save_checkpoint', 'IterBasedRunnerAmp', 8 | ] 9 | -------------------------------------------------------------------------------- /SEG/mmcv_custom/apex_runner/apex_iter_based_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | import os.path as osp 3 | import platform 4 | import shutil 5 | 6 | import torch 7 | from torch.optim import Optimizer 8 | 9 | import mmcv 10 | from mmcv.runner import RUNNERS, IterBasedRunner 11 | from .checkpoint import save_checkpoint 12 | 13 | try: 14 | import apex 15 | except: 16 | print('apex is not installed') 17 | 18 | 19 | @RUNNERS.register_module() 20 | class IterBasedRunnerAmp(IterBasedRunner): 21 | """Iteration-based Runner with AMP support. 22 | 23 | This runner train models iteration by iteration. 24 | """ 25 | 26 | def save_checkpoint(self, 27 | out_dir, 28 | filename_tmpl='iter_{}.pth', 29 | meta=None, 30 | save_optimizer=True, 31 | create_symlink=False): 32 | """Save checkpoint to file. 33 | 34 | Args: 35 | out_dir (str): Directory to save checkpoint files. 36 | filename_tmpl (str, optional): Checkpoint file template. 37 | Defaults to 'iter_{}.pth'. 38 | meta (dict, optional): Metadata to be saved in checkpoint. 39 | Defaults to None. 40 | save_optimizer (bool, optional): Whether save optimizer. 41 | Defaults to True. 42 | create_symlink (bool, optional): Whether create symlink to the 43 | latest checkpoint file. Defaults to True. 44 | """ 45 | if meta is None: 46 | meta = dict(iter=self.iter + 1, epoch=self.epoch + 1) 47 | elif isinstance(meta, dict): 48 | meta.update(iter=self.iter + 1, epoch=self.epoch + 1) 49 | else: 50 | raise TypeError( 51 | f'meta should be a dict or None, but got {type(meta)}') 52 | if self.meta is not None: 53 | meta.update(self.meta) 54 | 55 | filename = filename_tmpl.format(self.iter + 1) 56 | filepath = osp.join(out_dir, filename) 57 | optimizer = self.optimizer if save_optimizer else None 58 | save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) 59 | # in some environments, `os.symlink` is not supported, you may need to 60 | # set `create_symlink` to False 61 | # if create_symlink: 62 | # dst_file = osp.join(out_dir, 'latest.pth') 63 | # if platform.system() != 'Windows': 64 | # mmcv.symlink(filename, dst_file) 65 | # else: 66 | # shutil.copy(filepath, dst_file) 67 | 68 | def resume(self, 69 | checkpoint, 70 | resume_optimizer=True, 71 | map_location='default'): 72 | if map_location == 'default': 73 | if torch.cuda.is_available(): 74 | device_id = torch.cuda.current_device() 75 | checkpoint = self.load_checkpoint( 76 | checkpoint, 77 | map_location=lambda storage, loc: storage.cuda(device_id)) 78 | else: 79 | checkpoint = self.load_checkpoint(checkpoint) 80 | else: 81 | checkpoint = self.load_checkpoint( 82 | checkpoint, map_location=map_location) 83 | 84 | self._epoch = checkpoint['meta']['epoch'] 85 | self._iter = checkpoint['meta']['iter'] 86 | self._inner_iter = checkpoint['meta']['iter'] 87 | if 'optimizer' in checkpoint and resume_optimizer: 88 | if isinstance(self.optimizer, Optimizer): 89 | self.optimizer.load_state_dict(checkpoint['optimizer']) 90 | elif isinstance(self.optimizer, dict): 91 | for k in self.optimizer.keys(): 92 | self.optimizer[k].load_state_dict( 93 | checkpoint['optimizer'][k]) 94 | else: 95 | raise TypeError( 96 | 'Optimizer should be dict or 
torch.optim.Optimizer ' 97 | f'but got {type(self.optimizer)}') 98 | 99 | if 'amp' in checkpoint: 100 | apex.amp.load_state_dict(checkpoint['amp']) 101 | self.logger.info('load amp state dict') 102 | 103 | self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}') 104 | -------------------------------------------------------------------------------- /SEG/mmcv_custom/apex_runner/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | import os.path as osp 3 | import time 4 | from tempfile import TemporaryDirectory 5 | 6 | import torch 7 | from torch.optim import Optimizer 8 | 9 | import mmcv 10 | from mmcv.parallel import is_module_wrapper 11 | from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict 12 | 13 | try: 14 | import apex 15 | except: 16 | print('apex is not installed') 17 | 18 | 19 | def save_checkpoint(model, filename, optimizer=None, meta=None): 20 | """Save checkpoint to file. 21 | 22 | The checkpoint will have 4 fields: ``meta``, ``state_dict`` and 23 | ``optimizer``, ``amp``. By default ``meta`` will contain version 24 | and time info. 25 | 26 | Args: 27 | model (Module): Module whose params are to be saved. 28 | filename (str): Checkpoint filename. 29 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. 30 | meta (dict, optional): Metadata to be saved in checkpoint. 31 | """ 32 | if meta is None: 33 | meta = {} 34 | elif not isinstance(meta, dict): 35 | raise TypeError(f'meta must be a dict or None, but got {type(meta)}') 36 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) 37 | 38 | if is_module_wrapper(model): 39 | model = model.module 40 | 41 | if hasattr(model, 'CLASSES') and model.CLASSES is not None: 42 | # save class name to the meta 43 | meta.update(CLASSES=model.CLASSES) 44 | 45 | checkpoint = { 46 | 'meta': meta, 47 | 'state_dict': weights_to_cpu(get_state_dict(model)) 48 | } 49 | # save optimizer state dict in the checkpoint 50 | if isinstance(optimizer, Optimizer): 51 | checkpoint['optimizer'] = optimizer.state_dict() 52 | elif isinstance(optimizer, dict): 53 | checkpoint['optimizer'] = {} 54 | for name, optim in optimizer.items(): 55 | checkpoint['optimizer'][name] = optim.state_dict() 56 | 57 | # save amp state dict in the checkpoint 58 | checkpoint['amp'] = apex.amp.state_dict() 59 | 60 | if filename.startswith('pavi://'): 61 | try: 62 | from pavi import modelcloud 63 | from pavi.exception import NodeNotFoundError 64 | except ImportError: 65 | raise ImportError( 66 | 'Please install pavi to load checkpoint from modelcloud.') 67 | model_path = filename[7:] 68 | root = modelcloud.Folder() 69 | model_dir, model_name = osp.split(model_path) 70 | try: 71 | model = modelcloud.get(model_dir) 72 | except NodeNotFoundError: 73 | model = root.create_training_model(model_dir) 74 | with TemporaryDirectory() as tmp_dir: 75 | checkpoint_file = osp.join(tmp_dir, model_name) 76 | with open(checkpoint_file, 'wb') as f: 77 | torch.save(checkpoint, f) 78 | f.flush() 79 | model.create_file(checkpoint_file, name=model_name) 80 | else: 81 | mmcv.mkdir_or_exist(osp.dirname(filename)) 82 | # immediately flush buffer 83 | with open(filename, 'wb') as f: 84 | torch.save(checkpoint, f) 85 | f.flush() 86 | -------------------------------------------------------------------------------- /SEG/mmcv_custom/apex_runner/optimizer.py: -------------------------------------------------------------------------------- 1 | from mmcv.runner import 
OptimizerHook, HOOKS 2 | try: 3 | import apex 4 | except: 5 | print('apex is not installed') 6 | 7 | 8 | @HOOKS.register_module() 9 | class DistOptimizerHook(OptimizerHook): 10 | """Optimizer hook for distributed training.""" 11 | 12 | def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False): 13 | self.grad_clip = grad_clip 14 | self.coalesce = coalesce 15 | self.bucket_size_mb = bucket_size_mb 16 | self.update_interval = update_interval 17 | self.use_fp16 = use_fp16 18 | 19 | def before_run(self, runner): 20 | runner.optimizer.zero_grad() 21 | 22 | def after_train_iter(self, runner): 23 | runner.outputs['loss'] /= self.update_interval 24 | if self.use_fp16: 25 | with apex.amp.scale_loss(runner.outputs['loss'], runner.optimizer) as scaled_loss: 26 | scaled_loss.backward() 27 | else: 28 | runner.outputs['loss'].backward() 29 | if self.every_n_iters(runner, self.update_interval): 30 | if self.grad_clip is not None: 31 | self.clip_grads(runner.model.parameters()) 32 | runner.optimizer.step() 33 | runner.optimizer.zero_grad() 34 | -------------------------------------------------------------------------------- /SEG/mmcv_custom/layer_decay_optimizer_constructor.py: -------------------------------------------------------------------------------- 1 | import json 2 | from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor 3 | from mmcv.runner import get_dist_info 4 | 5 | 6 | def get_num_layer_for_vit(var_name, num_max_layer): 7 | if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"): 8 | return 0 9 | elif var_name.startswith("backbone.patch_embed"): 10 | return 0 11 | elif var_name.startswith('backbone.blocks1'): 12 | return 0 13 | elif var_name.startswith('backbone.blocks2'): 14 | return 0 15 | elif var_name.startswith("backbone.blocks3"): 16 | layer_id = int(var_name.split('.')[2]) 17 | return layer_id + 1 18 | else: 19 | return num_max_layer - 1 20 | 21 | 22 | @OPTIMIZER_BUILDERS.register_module() 23 | class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor): 24 | def add_params(self, params, module, prefix='', is_dcn_module=None): 25 | """Add all parameters of module to the params list. 26 | The parameters of the given module will be added to the list of param 27 | groups, with specific rules defined by paramwise_cfg. 28 | Args: 29 | params (list[dict]): A list of param groups, it will be modified 30 | in place. 31 | module (nn.Module): The module to be added. 32 | prefix (str): The prefix of the module 33 | is_dcn_module (int|float|None): If the current module is a 34 | submodule of DCN, `is_dcn_module` will be passed to 35 | control conv_offset layer's learning rate. Defaults to None. 36 | """ 37 | parameter_groups = {} 38 | print(self.paramwise_cfg) 39 | num_layers = self.paramwise_cfg.get('num_layers') + 2 40 | layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate') 41 | print("Build LayerDecayOptimizerConstructor %f - %d" % (layer_decay_rate, num_layers)) 42 | weight_decay = self.base_wd 43 | 44 | for name, param in module.named_parameters(): 45 | if not param.requires_grad: 46 | continue # frozen weights 47 | if len(param.shape) == 1 or name.endswith(".bias") or name in ('pos_embed', 'cls_token'): 48 | group_name = "no_decay" 49 | this_weight_decay = 0. 
50 | else: 51 | group_name = "decay" 52 | this_weight_decay = weight_decay 53 | 54 | layer_id = get_num_layer_for_vit(name, num_layers) 55 | group_name = "layer_%d_%s" % (layer_id, group_name) 56 | 57 | if group_name not in parameter_groups: 58 | scale = layer_decay_rate ** (num_layers - layer_id - 1) 59 | 60 | parameter_groups[group_name] = { 61 | "weight_decay": this_weight_decay, 62 | "params": [], 63 | "param_names": [], 64 | "lr_scale": scale, 65 | "group_name": group_name, 66 | "lr": scale * self.base_lr, 67 | } 68 | 69 | parameter_groups[group_name]["params"].append(param) 70 | parameter_groups[group_name]["param_names"].append(name) 71 | rank, _ = get_dist_info() 72 | if rank == 0: 73 | to_display = {} 74 | for key in parameter_groups: 75 | to_display[key] = { 76 | "param_names": parameter_groups[key]["param_names"], 77 | "lr_scale": parameter_groups[key]["lr_scale"], 78 | "lr": parameter_groups[key]["lr"], 79 | "weight_decay": parameter_groups[key]["weight_decay"], 80 | } 81 | print("Param groups = %s" % json.dumps(to_display, indent=2)) 82 | 83 | # state_dict = module.state_dict() 84 | # for group_name in parameter_groups: 85 | # group = parameter_groups[group_name] 86 | # for name in group["param_names"]: 87 | # group["params"].append(state_dict[name]) 88 | params.extend(parameter_groups.values()) 89 | -------------------------------------------------------------------------------- /SEG/mmcv_custom/resize_transform.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import numpy as np 3 | 4 | from mmseg.datasets.builder import PIPELINES 5 | 6 | 7 | @PIPELINES.register_module() 8 | class SETR_Resize(object): 9 | """Resize images & seg. 10 | 11 | This transform resizes the input image to some scale. If the input dict 12 | contains the key "scale", then the scale in the input dict is used, 13 | otherwise the specified scale in the init method is used. 14 | 15 | ``img_scale`` can either be a tuple (single-scale) or a list of tuple 16 | (multi-scale). There are 3 multiscale modes: 17 | 18 | - ``ratio_range is not None``: randomly sample a ratio from the ratio range 19 | and multiply it with the image scale. 20 | 21 | - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a 22 | scale from the a range. 23 | 24 | - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a 25 | scale from multiple scales. 26 | 27 | Args: 28 | img_scale (tuple or list[tuple]): Images scales for resizing. 29 | multiscale_mode (str): Either "range" or "value". 30 | ratio_range (tuple[float]): (min_ratio, max_ratio) 31 | keep_ratio (bool): Whether to keep the aspect ratio when resizing the 32 | image. 
33 | """ 34 | 35 | def __init__(self, 36 | img_scale=None, 37 | multiscale_mode='range', 38 | ratio_range=None, 39 | keep_ratio=True, 40 | crop_size=None, 41 | setr_multi_scale=False): 42 | 43 | if img_scale is None: 44 | self.img_scale = None 45 | else: 46 | if isinstance(img_scale, list): 47 | self.img_scale = img_scale 48 | else: 49 | self.img_scale = [img_scale] 50 | # assert mmcv.is_list_of(self.img_scale, tuple) 51 | 52 | if ratio_range is not None: 53 | # mode 1: given a scale and a range of image ratio 54 | assert len(self.img_scale) == 1 55 | else: 56 | # mode 2: given multiple scales or a range of scales 57 | assert multiscale_mode in ['value', 'range'] 58 | 59 | self.multiscale_mode = multiscale_mode 60 | self.ratio_range = ratio_range 61 | self.keep_ratio = keep_ratio 62 | self.crop_size = crop_size 63 | self.setr_multi_scale = setr_multi_scale 64 | 65 | @staticmethod 66 | def random_select(img_scales): 67 | """Randomly select an img_scale from given candidates. 68 | 69 | Args: 70 | img_scales (list[tuple]): Images scales for selection. 71 | 72 | Returns: 73 | (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, 74 | where ``img_scale`` is the selected image scale and 75 | ``scale_idx`` is the selected index in the given candidates. 76 | """ 77 | 78 | assert mmcv.is_list_of(img_scales, tuple) 79 | scale_idx = np.random.randint(len(img_scales)) 80 | img_scale = img_scales[scale_idx] 81 | return img_scale, scale_idx 82 | 83 | @staticmethod 84 | def random_sample(img_scales): 85 | """Randomly sample an img_scale when ``multiscale_mode=='range'``. 86 | 87 | Args: 88 | img_scales (list[tuple]): Images scale range for sampling. 89 | There must be two tuples in img_scales, which specify the lower 90 | and uper bound of image scales. 91 | 92 | Returns: 93 | (tuple, None): Returns a tuple ``(img_scale, None)``, where 94 | ``img_scale`` is sampled scale and None is just a placeholder 95 | to be consistent with :func:`random_select`. 96 | """ 97 | 98 | assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 99 | img_scale_long = [max(s) for s in img_scales] 100 | img_scale_short = [min(s) for s in img_scales] 101 | long_edge = np.random.randint( 102 | min(img_scale_long), 103 | max(img_scale_long) + 1) 104 | short_edge = np.random.randint( 105 | min(img_scale_short), 106 | max(img_scale_short) + 1) 107 | img_scale = (long_edge, short_edge) 108 | return img_scale, None 109 | 110 | @staticmethod 111 | def random_sample_ratio(img_scale, ratio_range): 112 | """Randomly sample an img_scale when ``ratio_range`` is specified. 113 | 114 | A ratio will be randomly sampled from the range specified by 115 | ``ratio_range``. Then it would be multiplied with ``img_scale`` to 116 | generate sampled scale. 117 | 118 | Args: 119 | img_scale (tuple): Images scale base to multiply with ratio. 120 | ratio_range (tuple[float]): The minimum and maximum ratio to scale 121 | the ``img_scale``. 122 | 123 | Returns: 124 | (tuple, None): Returns a tuple ``(scale, None)``, where 125 | ``scale`` is sampled ratio multiplied with ``img_scale`` and 126 | None is just a placeholder to be consistent with 127 | :func:`random_select`. 
128 | """ 129 | 130 | assert isinstance(img_scale, tuple) and len(img_scale) == 2 131 | min_ratio, max_ratio = ratio_range 132 | assert min_ratio <= max_ratio 133 | ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio 134 | scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) 135 | return scale, None 136 | 137 | def _random_scale(self, results): 138 | """Randomly sample an img_scale according to ``ratio_range`` and 139 | ``multiscale_mode``. 140 | 141 | If ``ratio_range`` is specified, a ratio will be sampled and be 142 | multiplied with ``img_scale``. 143 | If multiple scales are specified by ``img_scale``, a scale will be 144 | sampled according to ``multiscale_mode``. 145 | Otherwise, single scale will be used. 146 | 147 | Args: 148 | results (dict): Result dict from :obj:`dataset`. 149 | 150 | Returns: 151 | dict: Two new keys 'scale` and 'scale_idx` are added into 152 | ``results``, which would be used by subsequent pipelines. 153 | """ 154 | 155 | if self.ratio_range is not None: 156 | scale, scale_idx = self.random_sample_ratio( 157 | self.img_scale[0], self.ratio_range) 158 | elif len(self.img_scale) == 1: 159 | scale, scale_idx = self.img_scale[0], 0 160 | elif self.multiscale_mode == 'range': 161 | scale, scale_idx = self.random_sample(self.img_scale) 162 | elif self.multiscale_mode == 'value': 163 | scale, scale_idx = self.random_select(self.img_scale) 164 | else: 165 | raise NotImplementedError 166 | 167 | results['scale'] = scale 168 | results['scale_idx'] = scale_idx 169 | 170 | def _resize_img(self, results): 171 | """Resize images with ``results['scale']``.""" 172 | 173 | if self.keep_ratio: 174 | if self.setr_multi_scale: 175 | if min(results['scale']) < self.crop_size[0]: 176 | new_short = self.crop_size[0] 177 | else: 178 | new_short = min(results['scale']) 179 | 180 | h, w = results['img'].shape[:2] 181 | if h > w: 182 | new_h, new_w = new_short * h / w, new_short 183 | else: 184 | new_h, new_w = new_short, new_short * w / h 185 | results['scale'] = (new_h, new_w) 186 | 187 | img, scale_factor = mmcv.imrescale( 188 | results['img'], results['scale'], return_scale=True) 189 | # the w_scale and h_scale has minor difference 190 | # a real fix should be done in the mmcv.imrescale in the future 191 | new_h, new_w = img.shape[:2] 192 | h, w = results['img'].shape[:2] 193 | w_scale = new_w / w 194 | h_scale = new_h / h 195 | else: 196 | img, w_scale, h_scale = mmcv.imresize( 197 | results['img'], results['scale'], return_scale=True) 198 | scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], 199 | dtype=np.float32) 200 | results['img'] = img 201 | results['img_shape'] = img.shape 202 | results['pad_shape'] = img.shape # in case that there is no padding 203 | results['scale_factor'] = scale_factor 204 | results['keep_ratio'] = self.keep_ratio 205 | 206 | def _resize_seg(self, results): 207 | """Resize semantic segmentation map with ``results['scale']``.""" 208 | for key in results.get('seg_fields', []): 209 | if self.keep_ratio: 210 | gt_seg = mmcv.imrescale( 211 | results[key], results['scale'], interpolation='nearest') 212 | else: 213 | gt_seg = mmcv.imresize( 214 | results[key], results['scale'], interpolation='nearest') 215 | results['gt_semantic_seg'] = gt_seg 216 | 217 | def __call__(self, results): 218 | """Call function to resize images, bounding boxes, masks, semantic 219 | segmentation map. 220 | 221 | Args: 222 | results (dict): Result dict from loading pipeline. 
223 | 224 | Returns: 225 | dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', 226 | 'keep_ratio' keys are added into result dict. 227 | """ 228 | 229 | if 'scale' not in results: 230 | self._random_scale(results) 231 | self._resize_img(results) 232 | self._resize_seg(results) 233 | return results 234 | 235 | def __repr__(self): 236 | repr_str = self.__class__.__name__ 237 | repr_str += (f'(img_scale={self.img_scale}, ' 238 | f'multiscale_mode={self.multiscale_mode}, ' 239 | f'ratio_range={self.ratio_range}, ' 240 | f'keep_ratio={self.keep_ratio})') 241 | return repr_str 242 | -------------------------------------------------------------------------------- /SEG/mmcv_custom/train_api.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 7 | from mmcv.runner import build_optimizer, build_runner 8 | 9 | from mmseg.core import DistEvalHook, EvalHook 10 | from mmseg.datasets import build_dataloader, build_dataset 11 | from mmseg.utils import get_root_logger 12 | try: 13 | import apex 14 | except: 15 | print('apex is not installed') 16 | 17 | 18 | def set_random_seed(seed, deterministic=False): 19 | """Set random seed. 20 | 21 | Args: 22 | seed (int): Seed to be used. 23 | deterministic (bool): Whether to set the deterministic option for 24 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 25 | to True and `torch.backends.cudnn.benchmark` to False. 26 | Default: False. 27 | """ 28 | random.seed(seed) 29 | np.random.seed(seed) 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed_all(seed) 32 | if deterministic: 33 | torch.backends.cudnn.deterministic = True 34 | torch.backends.cudnn.benchmark = False 35 | 36 | 37 | def train_segmentor(model, 38 | dataset, 39 | cfg, 40 | distributed=False, 41 | validate=False, 42 | timestamp=None, 43 | meta=None): 44 | """Launch segmentor training.""" 45 | logger = get_root_logger(cfg.log_level) 46 | 47 | # prepare data loaders 48 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 49 | data_loaders = [ 50 | build_dataloader( 51 | ds, 52 | cfg.data.samples_per_gpu, 53 | cfg.data.workers_per_gpu, 54 | # cfg.gpus will be ignored if distributed 55 | len(cfg.gpu_ids), 56 | dist=distributed, 57 | seed=cfg.seed, 58 | drop_last=True) for ds in dataset 59 | ] 60 | 61 | # build optimizer 62 | optimizer = build_optimizer(model, cfg.optimizer) 63 | 64 | # use apex fp16 optimizer 65 | if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook": 66 | if cfg.optimizer_config.get("use_fp16", False): 67 | model, optimizer = apex.amp.initialize( 68 | model.cuda(), optimizer, opt_level="O1") 69 | for m in model.modules(): 70 | if hasattr(m, "fp16_enabled"): 71 | m.fp16_enabled = True 72 | 73 | # put model on gpus 74 | if distributed: 75 | find_unused_parameters = cfg.get('find_unused_parameters', False) 76 | # Sets the `find_unused_parameters` parameter in 77 | # torch.nn.parallel.DistributedDataParallel 78 | model = MMDistributedDataParallel( 79 | model.cuda(), 80 | device_ids=[torch.cuda.current_device()], 81 | broadcast_buffers=False, 82 | find_unused_parameters=find_unused_parameters) 83 | else: 84 | model = MMDataParallel( 85 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) 86 | 87 | if cfg.get('runner') is None: 88 | cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} 89 | 
warnings.warn( 90 | 'config is now expected to have a `runner` section, ' 91 | 'please set `runner` in your config.', UserWarning) 92 | 93 | runner = build_runner( 94 | cfg.runner, 95 | default_args=dict( 96 | model=model, 97 | batch_processor=None, 98 | optimizer=optimizer, 99 | work_dir=cfg.work_dir, 100 | logger=logger, 101 | meta=meta)) 102 | 103 | # register hooks 104 | runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, 105 | cfg.checkpoint_config, cfg.log_config, 106 | cfg.get('momentum_config', None)) 107 | 108 | # an ugly walkaround to make the .log and .log.json filenames the same 109 | runner.timestamp = timestamp 110 | 111 | # register eval hooks 112 | if validate: 113 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 114 | val_dataloader = build_dataloader( 115 | val_dataset, 116 | samples_per_gpu=1, 117 | workers_per_gpu=cfg.data.workers_per_gpu, 118 | dist=distributed, 119 | shuffle=False) 120 | eval_cfg = cfg.get('evaluation', {}) 121 | eval_cfg['by_epoch'] = 'IterBasedRunner' not in cfg.runner['type'] 122 | eval_hook = DistEvalHook if distributed else EvalHook 123 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) 124 | 125 | if cfg.resume_from: 126 | runner.resume(cfg.resume_from) 127 | elif cfg.load_from: 128 | runner.load_checkpoint(cfg.load_from) 129 | runner.run(data_loaders, cfg.workflow) 130 | -------------------------------------------------------------------------------- /SEG/tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 10 | -------------------------------------------------------------------------------- /SEG/tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-27501} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /SEG/tools/flops.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from mmcv import Config 4 | from mmcv.cnn import get_model_complexity_info 5 | 6 | from mmseg.models import build_segmentor 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser( 11 | description='Get the FLOPs of a segmentor') 12 | parser.add_argument('config', help='train config file path') 13 | parser.add_argument( 14 | '--shape', 15 | type=int, 16 | nargs='+', 17 | default=[512, 512], 18 | help='input image size') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def main(): 24 | 25 | args = parse_args() 26 | 27 | if len(args.shape) == 1: 28 | input_shape = (3, args.shape[0], args.shape[0]) 29 | elif len(args.shape) == 2: 30 | input_shape = (3, ) + tuple(args.shape) 31 | else: 32 | raise ValueError('invalid input shape') 33 | 34 | cfg = Config.fromfile(args.config) 35 | cfg.model.pretrained = None 36 | model = build_segmentor( 37 | cfg.model, 38 | train_cfg=cfg.get('train_cfg'), 39 | test_cfg=cfg.get('test_cfg')).cuda() 40 | model.eval() 41 | 42 | if 
hasattr(model, 'forward_dummy'): 43 | model.forward = model.forward_dummy 44 | else: 45 | raise NotImplementedError( 46 | 'FLOPs counter is currently not currently supported with {}'. 47 | format(model.__class__.__name__)) 48 | 49 | flops, params = get_model_complexity_info(model, input_shape) 50 | split_line = '=' * 30 51 | print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( 52 | split_line, input_shape, flops, params)) 53 | print('!!!Please be cautious if you use the results in papers. ' 54 | 'You may need to check if all ops are supported and verify that the ' 55 | 'flops computation is correct.') 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /SEG/tools/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import mmcv 5 | import torch 6 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 7 | from mmcv.runner import get_dist_info, init_dist, load_checkpoint 8 | from mmcv.utils import DictAction 9 | 10 | from mmseg.apis import multi_gpu_test, single_gpu_test 11 | from mmseg.datasets import build_dataloader, build_dataset 12 | from mmseg.models import build_segmentor 13 | 14 | from backbone import convmae 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser( 18 | description='mmseg test (and eval) a model') 19 | parser.add_argument('config', help='test config file path') 20 | parser.add_argument('checkpoint', help='checkpoint file') 21 | parser.add_argument( 22 | '--aug-test', action='store_true', help='Use Flip and Multi scale aug') 23 | parser.add_argument('--out', help='output result file in pickle format') 24 | parser.add_argument( 25 | '--format-only', 26 | action='store_true', 27 | help='Format the output results without perform evaluation. 
It is' 28 | 'useful when you want to format the result to a specific format and ' 29 | 'submit it to the test server') 30 | parser.add_argument( 31 | '--eval', 32 | type=str, 33 | nargs='+', 34 | help='evaluation metrics, which depends on the dataset, e.g., "mIoU"' 35 | ' for generic datasets, and "cityscapes" for Cityscapes') 36 | parser.add_argument('--show', action='store_true', help='show results') 37 | parser.add_argument( 38 | '--show-dir', help='directory where painted images will be saved') 39 | parser.add_argument( 40 | '--gpu-collect', 41 | action='store_true', 42 | help='whether to use gpu to collect results.') 43 | parser.add_argument( 44 | '--tmpdir', 45 | help='tmp directory used for collecting results from multiple ' 46 | 'workers, available when gpu_collect is not specified') 47 | parser.add_argument( 48 | '--options', nargs='+', action=DictAction, help='custom options') 49 | parser.add_argument( 50 | '--eval-options', 51 | nargs='+', 52 | action=DictAction, 53 | help='custom options for evaluation') 54 | parser.add_argument( 55 | '--launcher', 56 | choices=['none', 'pytorch', 'slurm', 'mpi'], 57 | default='none', 58 | help='job launcher') 59 | parser.add_argument('--local_rank', type=int, default=0) 60 | args = parser.parse_args() 61 | if 'LOCAL_RANK' not in os.environ: 62 | os.environ['LOCAL_RANK'] = str(args.local_rank) 63 | return args 64 | 65 | 66 | def main(): 67 | args = parse_args() 68 | 69 | assert args.out or args.eval or args.format_only or args.show \ 70 | or args.show_dir, \ 71 | ('Please specify at least one operation (save/eval/format/show the ' 72 | 'results / save the results) with the argument "--out", "--eval"' 73 | ', "--format-only", "--show" or "--show-dir"') 74 | 75 | if args.eval and args.format_only: 76 | raise ValueError('--eval and --format_only cannot be both specified') 77 | 78 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): 79 | raise ValueError('The output file must be a pkl file.') 80 | 81 | cfg = mmcv.Config.fromfile(args.config) 82 | if args.options is not None: 83 | cfg.merge_from_dict(args.options) 84 | # set cudnn_benchmark 85 | if cfg.get('cudnn_benchmark', False): 86 | torch.backends.cudnn.benchmark = True 87 | if args.aug_test: 88 | # hard code index 89 | cfg.data.test.pipeline[1].img_ratios = [ 90 | 0.5, 0.75, 1.0, 1.25, 1.5, 1.75 91 | ] 92 | cfg.data.test.pipeline[1].flip = True 93 | cfg.model.pretrained = None 94 | cfg.data.test.test_mode = True 95 | 96 | # init distributed env first, since logger depends on the dist info. 
97 | if args.launcher == 'none': 98 | distributed = False 99 | else: 100 | distributed = True 101 | init_dist(args.launcher, **cfg.dist_params) 102 | 103 | # build the dataloader 104 | # TODO: support multiple images per gpu (only minor changes are needed) 105 | dataset = build_dataset(cfg.data.test) 106 | data_loader = build_dataloader( 107 | dataset, 108 | samples_per_gpu=1, 109 | workers_per_gpu=cfg.data.workers_per_gpu, 110 | dist=distributed, 111 | shuffle=False) 112 | 113 | # build the model and load checkpoint 114 | cfg.model.train_cfg = None 115 | model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) 116 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') 117 | model.CLASSES = checkpoint['meta']['CLASSES'] 118 | model.PALETTE = checkpoint['meta']['PALETTE'] 119 | 120 | efficient_test = False 121 | if args.eval_options is not None: 122 | efficient_test = args.eval_options.get('efficient_test', False) 123 | 124 | if not distributed: 125 | model = MMDataParallel(model, device_ids=[0]) 126 | outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, 127 | efficient_test) 128 | else: 129 | model = MMDistributedDataParallel( 130 | model.cuda(), 131 | device_ids=[torch.cuda.current_device()], 132 | broadcast_buffers=False) 133 | outputs = multi_gpu_test(model, data_loader, args.tmpdir, 134 | args.gpu_collect, efficient_test) 135 | 136 | rank, _ = get_dist_info() 137 | if rank == 0: 138 | if args.out: 139 | print(f'\nwriting results to {args.out}') 140 | mmcv.dump(outputs, args.out) 141 | kwargs = {} if args.eval_options is None else args.eval_options 142 | if args.format_only: 143 | dataset.format_results(outputs, **kwargs) 144 | if args.eval: 145 | dataset.evaluate(outputs, args.eval, **kwargs) 146 | 147 | 148 | if __name__ == '__main__': 149 | main() 150 | -------------------------------------------------------------------------------- /SEG/tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import os.path as osp 5 | import time 6 | 7 | import mmcv 8 | import mmcv_custom 9 | import torch 10 | from mmcv.runner import init_dist 11 | from mmcv.utils import Config, DictAction, get_git_hash 12 | 13 | from mmseg import __version__ 14 | from mmseg.apis import set_random_seed 15 | from mmcv_custom import train_segmentor 16 | from mmseg.datasets import build_dataset 17 | from mmseg.models import build_segmentor 18 | from mmseg.utils import collect_env, get_root_logger 19 | 20 | from backbone import convmae 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description='Train a segmentor') 24 | parser.add_argument('config', help='train config file path') 25 | parser.add_argument('--work-dir', help='the dir to save logs and models') 26 | parser.add_argument( 27 | '--load-from', help='the checkpoint file to load weights from') 28 | parser.add_argument( 29 | '--resume-from', help='the checkpoint file to resume from') 30 | parser.add_argument( 31 | '--no-validate', 32 | action='store_true', 33 | help='whether not to evaluate the checkpoint during training') 34 | group_gpus = parser.add_mutually_exclusive_group() 35 | group_gpus.add_argument( 36 | '--gpus', 37 | type=int, 38 | help='number of gpus to use ' 39 | '(only applicable to non-distributed training)') 40 | group_gpus.add_argument( 41 | '--gpu-ids', 42 | type=int, 43 | nargs='+', 44 | help='ids of gpus to use ' 45 | '(only applicable to non-distributed training)') 46 | 
parser.add_argument('--seed', type=int, default=None, help='random seed') 47 | parser.add_argument( 48 | '--deterministic', 49 | action='store_true', 50 | help='whether to set deterministic options for CUDNN backend.') 51 | parser.add_argument( 52 | '--options', nargs='+', action=DictAction, help='custom options') 53 | parser.add_argument( 54 | '--launcher', 55 | choices=['none', 'pytorch', 'slurm', 'mpi'], 56 | default='none', 57 | help='job launcher') 58 | parser.add_argument('--local_rank', type=int, default=0) 59 | args = parser.parse_args() 60 | if 'LOCAL_RANK' not in os.environ: 61 | os.environ['LOCAL_RANK'] = str(args.local_rank) 62 | 63 | return args 64 | 65 | 66 | def main(): 67 | args = parse_args() 68 | 69 | cfg = Config.fromfile(args.config) 70 | if args.options is not None: 71 | cfg.merge_from_dict(args.options) 72 | # set cudnn_benchmark 73 | if cfg.get('cudnn_benchmark', False): 74 | torch.backends.cudnn.benchmark = True 75 | 76 | # work_dir is determined in this priority: CLI > segment in file > filename 77 | if args.work_dir is not None: 78 | # update configs according to CLI args if args.work_dir is not None 79 | cfg.work_dir = args.work_dir 80 | elif cfg.get('work_dir', None) is None: 81 | # use config filename as default work_dir if cfg.work_dir is None 82 | cfg.work_dir = osp.join('./work_dirs', 83 | osp.splitext(osp.basename(args.config))[0]) 84 | if args.load_from is not None: 85 | cfg.load_from = args.load_from 86 | if args.resume_from is not None: 87 | cfg.resume_from = args.resume_from 88 | if args.gpu_ids is not None: 89 | cfg.gpu_ids = args.gpu_ids 90 | else: 91 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 92 | 93 | # init distributed env first, since logger depends on the dist info. 94 | if args.launcher == 'none': 95 | distributed = False 96 | else: 97 | distributed = True 98 | init_dist(args.launcher, **cfg.dist_params) 99 | 100 | # create work_dir 101 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 102 | # dump config 103 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 104 | # init the logger before other steps 105 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 106 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 107 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 108 | 109 | # init the meta dict to record some important information such as 110 | # environment info and seed, which will be logged 111 | meta = dict() 112 | # log env info 113 | env_info_dict = collect_env() 114 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) 115 | dash_line = '-' * 60 + '\n' 116 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 117 | dash_line) 118 | meta['env_info'] = env_info 119 | 120 | # log some basic info 121 | logger.info(f'Distributed training: {distributed}') 122 | logger.info(f'Config:\n{cfg.pretty_text}') 123 | 124 | # set random seeds 125 | if args.seed is not None: 126 | logger.info(f'Set random seed to {args.seed}, deterministic: ' 127 | f'{args.deterministic}') 128 | set_random_seed(args.seed, deterministic=args.deterministic) 129 | cfg.seed = args.seed 130 | meta['seed'] = args.seed 131 | meta['exp_name'] = osp.basename(args.config) 132 | 133 | model = build_segmentor( 134 | cfg.model, 135 | train_cfg=cfg.get('train_cfg'), 136 | test_cfg=cfg.get('test_cfg')) 137 | 138 | logger.info(model) 139 | 140 | datasets = [build_dataset(cfg.data.train)] 141 | if len(cfg.workflow) == 2: 142 | val_dataset = copy.deepcopy(cfg.data.val) 143 | 
val_dataset.pipeline = cfg.data.train.pipeline 144 | datasets.append(build_dataset(val_dataset)) 145 | if cfg.checkpoint_config is not None: 146 | # save mmseg version, config file content and class names in 147 | # checkpoints as meta data 148 | cfg.checkpoint_config.meta = dict( 149 | mmseg_version=f'{__version__}+{get_git_hash()[:7]}', 150 | config=cfg.pretty_text, 151 | CLASSES=datasets[0].CLASSES, 152 | PALETTE=datasets[0].PALETTE) 153 | # add an attribute for visualization convenience 154 | model.CLASSES = datasets[0].CLASSES 155 | train_segmentor( 156 | model, 157 | datasets, 158 | cfg, 159 | distributed=distributed, 160 | validate=(not args.no_validate), 161 | timestamp=timestamp, 162 | meta=meta) 163 | 164 | 165 | if __name__ == '__main__': 166 | main() 167 | -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import torch 17 | 18 | from timm.data import Mixup 19 | from timm.utils import accuracy 20 | 21 | import util.misc as misc 22 | import util.lr_sched as lr_sched 23 | 24 | 25 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 26 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 27 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 28 | mixup_fn: Optional[Mixup] = None, log_writer=None, 29 | args=None): 30 | model.train(True) 31 | metric_logger = misc.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 20 35 | 36 | accum_iter = args.accum_iter 37 | 38 | optimizer.zero_grad() 39 | 40 | if log_writer is not None: 41 | print('log_dir: {}'.format(log_writer.log_dir)) 42 | 43 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 44 | 45 | # we use a per iteration (instead of per epoch) lr scheduler 46 | if data_iter_step % accum_iter == 0: 47 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 48 | 49 | samples = samples.to(device, non_blocking=True) 50 | targets = targets.to(device, non_blocking=True) 51 | 52 | if mixup_fn is not None: 53 | samples, targets = mixup_fn(samples, targets) 54 | 55 | with torch.cuda.amp.autocast(): 56 | outputs = model(samples) 57 | loss = criterion(outputs, targets) 58 | 59 | loss_value = loss.item() 60 | 61 | if not math.isfinite(loss_value): 62 | print("Loss is {}, stopping training".format(loss_value)) 63 | sys.exit(1) 64 | 65 | loss /= accum_iter 66 | loss_scaler(loss, optimizer, clip_grad=max_norm, 67 | parameters=model.parameters(), create_graph=False, 68 | update_grad=(data_iter_step + 1) % accum_iter == 0) 69 | if (data_iter_step + 1) % accum_iter == 0: 70 | optimizer.zero_grad() 71 | 72 | torch.cuda.synchronize() 73 | 74 | metric_logger.update(loss=loss_value) 75 | min_lr = 10. 
76 | max_lr = 0. 77 | for group in optimizer.param_groups: 78 | min_lr = min(min_lr, group["lr"]) 79 | max_lr = max(max_lr, group["lr"]) 80 | 81 | metric_logger.update(lr=max_lr) 82 | 83 | loss_value_reduce = misc.all_reduce_mean(loss_value) 84 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 85 | """ We use epoch_1000x as the x-axis in tensorboard. 86 | This calibrates different curves when batch size changes. 87 | """ 88 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 89 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 90 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 91 | 92 | # gather the stats from all processes 93 | metric_logger.synchronize_between_processes() 94 | print("Averaged stats:", metric_logger) 95 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 96 | 97 | 98 | @torch.no_grad() 99 | def evaluate(data_loader, model, device): 100 | criterion = torch.nn.CrossEntropyLoss() 101 | 102 | metric_logger = misc.MetricLogger(delimiter=" ") 103 | header = 'Test:' 104 | 105 | # switch to evaluation mode 106 | model.eval() 107 | 108 | for batch in metric_logger.log_every(data_loader, 10, header): 109 | images = batch[0] 110 | target = batch[-1] 111 | images = images.to(device, non_blocking=True) 112 | target = target.to(device, non_blocking=True) 113 | 114 | # compute output 115 | with torch.cuda.amp.autocast(): 116 | output = model(images) 117 | loss = criterion(output, target) 118 | 119 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 120 | 121 | batch_size = images.shape[0] 122 | metric_logger.update(loss=loss.item()) 123 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 124 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 125 | # gather the stats from all processes 126 | metric_logger.synchronize_between_processes() 127 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 128 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 129 | 130 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Alpha-VL 2 | # -------------------------------------------------------- 3 | # References: 4 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # DeiT: https://github.com/facebookresearch/deit 6 | # MAE: https://github.com/facebookresearch/mae 7 | # -------------------------------------------------------- 8 | import math 9 | import sys 10 | from typing import Iterable 11 | 12 | import torch 13 | 14 | import util.misc as misc 15 | import util.lr_sched as lr_sched 16 | 17 | 18 | def train_one_epoch(model: torch.nn.Module, 19 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 20 | device: torch.device, epoch: int, loss_scaler, 21 | log_writer=None, 22 | args=None): 23 | model.train(True) 24 | metric_logger = misc.MetricLogger(delimiter=" ") 25 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 26 | header = 'Epoch: [{}]'.format(epoch) 27 | print_freq = 20 28 | 29 | accum_iter = args.accum_iter 30 | 31 | optimizer.zero_grad() 32 | 33 | if log_writer is not None: 34 | print('log_dir: {}'.format(log_writer.log_dir)) 35 | 36 | for data_iter_step, (samples, _) in 
enumerate(metric_logger.log_every(data_loader, print_freq, header)): 37 | 38 | # we use a per iteration (instead of per epoch) lr scheduler 39 | if data_iter_step % accum_iter == 0: 40 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 41 | 42 | samples = samples.to(device, non_blocking=True) 43 | 44 | with torch.cuda.amp.autocast(): 45 | loss, _, _ = model(samples, mask_ratio=args.mask_ratio) 46 | 47 | loss_value = loss.item() 48 | 49 | if not math.isfinite(loss_value): 50 | print("Loss is {}, stopping training".format(loss_value)) 51 | sys.exit(1) 52 | 53 | loss /= accum_iter 54 | loss_scaler(loss, optimizer, parameters=model.parameters(), 55 | update_grad=(data_iter_step + 1) % accum_iter == 0) 56 | if (data_iter_step + 1) % accum_iter == 0: 57 | optimizer.zero_grad() 58 | 59 | torch.cuda.synchronize() 60 | 61 | metric_logger.update(loss=loss_value) 62 | 63 | lr = optimizer.param_groups[0]["lr"] 64 | metric_logger.update(lr=lr) 65 | 66 | loss_value_reduce = misc.all_reduce_mean(loss_value) 67 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 68 | """ We use epoch_1000x as the x-axis in tensorboard. 69 | This calibrates different curves when batch size changes. 70 | """ 71 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 72 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 73 | log_writer.add_scalar('lr', lr, epoch_1000x) 74 | 75 | 76 | # gather the stats from all processes 77 | metric_logger.synchronize_between_processes() 78 | print("Averaged stats:", metric_logger) 79 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /figures/ConvMAE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VL/ConvMAE/4c97b4c9ec9c85724bc9594eb7302c803ae58c19/figures/ConvMAE.png -------------------------------------------------------------------------------- /figures/Downstream.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VL/ConvMAE/4c97b4c9ec9c85724bc9594eb7302c803ae58c19/figures/Downstream.png -------------------------------------------------------------------------------- /figures/feat_map.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VL/ConvMAE/4c97b4c9ec9c85724bc9594eb7302c803ae58c19/figures/feat_map.JPG -------------------------------------------------------------------------------- /main_linprobe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Alpha-VL 2 | # -------------------------------------------------------- 3 | # References: 4 | # DeiT: https://github.com/facebookresearch/deit 5 | # MoCo v3: https://github.com/facebookresearch/moco-v3 6 | # MAE: https://github.com/facebookresearch/mae 7 | # -------------------------------------------------------- 8 | 9 | import argparse 10 | import datetime 11 | import json 12 | import numpy as np 13 | import os 14 | import time 15 | from pathlib import Path 16 | 17 | import torch 18 | import torch.backends.cudnn as cudnn 19 | from torch.utils.tensorboard import SummaryWriter 20 | import torchvision.transforms as transforms 21 | import torchvision.datasets as datasets 22 | 23 | import timm 24 | 25 | assert timm.__version__ == "0.3.2" # version check 26 | from 
timm.models.layers import trunc_normal_ 27 | 28 | import util.misc as misc 29 | from util.pos_embed import interpolate_pos_embed 30 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 31 | from util.lars import LARS 32 | from util.crop import RandomResizedCrop 33 | 34 | import models_convvit 35 | 36 | from engine_finetune import train_one_epoch, evaluate 37 | 38 | 39 | def get_args_parser(): 40 | parser = argparse.ArgumentParser('ConvMAE linear probing for image classification', add_help=False) 41 | parser.add_argument('--batch_size', default=512, type=int, 42 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 43 | parser.add_argument('--epochs', default=90, type=int) 44 | parser.add_argument('--accum_iter', default=1, type=int, 45 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 46 | 47 | # Model parameters 48 | parser.add_argument('--model', default='convvit_base_patch16', type=str, metavar='MODEL', 49 | help='Name of model to train') 50 | 51 | # Optimizer parameters 52 | parser.add_argument('--weight_decay', type=float, default=0, 53 | help='weight decay (default: 0 for linear probe following MoCo v1)') 54 | 55 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 56 | help='learning rate (absolute lr)') 57 | parser.add_argument('--blr', type=float, default=0.1, metavar='LR', 58 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 59 | 60 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 61 | help='lower lr bound for cyclic schedulers that hit 0') 62 | 63 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', 64 | help='epochs to warmup LR') 65 | 66 | # * Finetuning params 67 | parser.add_argument('--finetune', default='', 68 | help='finetune from checkpoint') 69 | parser.add_argument('--global_pool', action='store_true') 70 | 71 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 72 | help='Use class token instead of global pool for classification') 73 | 74 | # Dataset parameters 75 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 76 | help='dataset path') 77 | parser.add_argument('--nb_classes', default=1000, type=int, 78 | help='number of the classification types') 79 | 80 | parser.add_argument('--output_dir', default='./output_dir', 81 | help='path where to save, empty for no saving') 82 | parser.add_argument('--log_dir', default='./output_dir', 83 | help='path where to tensorboard log') 84 | parser.add_argument('--device', default='cuda', 85 | help='device to use for training / testing') 86 | parser.add_argument('--seed', default=0, type=int) 87 | parser.add_argument('--resume', default='', 88 | help='resume from checkpoint') 89 | 90 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 91 | help='start epoch') 92 | parser.add_argument('--eval', action='store_true', 93 | help='Perform evaluation only') 94 | parser.add_argument('--dist_eval', action='store_true', default=False, 95 | help='Enabling distributed evaluation (recommended during training for faster monitor') 96 | parser.add_argument('--num_workers', default=10, type=int) 97 | parser.add_argument('--pin_mem', action='store_true', 98 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 99 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 100 | parser.set_defaults(pin_mem=True) 101 
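    # --------------------------------------------------------
    # Illustrative note (commentary, not from the original script; torchrun and the
    # example paths below are assumptions): the absolute learning rate is resolved
    # later in main() as
    #     eff_batch_size = batch_size * accum_iter * world_size
    #     lr             = blr * eff_batch_size / 256
    # so the defaults above (--batch_size 512, --blr 0.1, --accum_iter 1) on 8 GPUs give
    #     eff_batch_size = 512 * 1 * 8 = 4096   and   lr = 0.1 * 4096 / 256 = 1.6
    # A hypothetical single-node launch using only flags defined in this parser:
    #     torchrun --nproc_per_node=8 main_linprobe.py \
    #         --model convvit_base_patch16 --finetune /path/to/pretrained.pth \
    #         --batch_size 512 --blr 0.1 --epochs 90 --global_pool \
    #         --data_path /path/to/imagenet
    # --------------------------------------------------------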
| 102 | # distributed training parameters 103 | parser.add_argument('--world_size', default=1, type=int, 104 | help='number of distributed processes') 105 | parser.add_argument('--local_rank', default=-1, type=int) 106 | parser.add_argument('--dist_on_itp', action='store_true') 107 | parser.add_argument('--dist_url', default='env://', 108 | help='url used to set up distributed training') 109 | 110 | return parser 111 | 112 | 113 | def main(args): 114 | misc.init_distributed_mode(args) 115 | 116 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 117 | print("{}".format(args).replace(', ', ',\n')) 118 | 119 | device = torch.device(args.device) 120 | 121 | # fix the seed for reproducibility 122 | seed = args.seed + misc.get_rank() 123 | torch.manual_seed(seed) 124 | np.random.seed(seed) 125 | 126 | cudnn.benchmark = True 127 | 128 | # linear probe: weak augmentation 129 | transform_train = transforms.Compose([ 130 | RandomResizedCrop(224, interpolation=3), 131 | transforms.RandomHorizontalFlip(), 132 | transforms.ToTensor(), 133 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 134 | transform_val = transforms.Compose([ 135 | transforms.Resize(256, interpolation=3), 136 | transforms.CenterCrop(224), 137 | transforms.ToTensor(), 138 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 139 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 140 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val) 141 | print(dataset_train) 142 | print(dataset_val) 143 | 144 | if True: # args.distributed: 145 | num_tasks = misc.get_world_size() 146 | global_rank = misc.get_rank() 147 | sampler_train = torch.utils.data.DistributedSampler( 148 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 149 | ) 150 | print("Sampler_train = %s" % str(sampler_train)) 151 | if args.dist_eval: 152 | if len(dataset_val) % num_tasks != 0: 153 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 154 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 155 | 'equal num of samples per-process.') 156 | sampler_val = torch.utils.data.DistributedSampler( 157 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 158 | else: 159 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 160 | else: 161 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 162 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 163 | 164 | if global_rank == 0 and args.log_dir is not None and not args.eval: 165 | os.makedirs(args.log_dir, exist_ok=True) 166 | log_writer = SummaryWriter(log_dir=args.log_dir) 167 | else: 168 | log_writer = None 169 | 170 | data_loader_train = torch.utils.data.DataLoader( 171 | dataset_train, sampler=sampler_train, 172 | batch_size=args.batch_size, 173 | num_workers=args.num_workers, 174 | pin_memory=args.pin_mem, 175 | drop_last=True, 176 | ) 177 | 178 | data_loader_val = torch.utils.data.DataLoader( 179 | dataset_val, sampler=sampler_val, 180 | batch_size=args.batch_size, 181 | num_workers=args.num_workers, 182 | pin_memory=args.pin_mem, 183 | drop_last=False 184 | ) 185 | 186 | model = models_convvit.__dict__[args.model]( 187 | num_classes=args.nb_classes, 188 | global_pool=args.global_pool, 189 | ) 190 | 191 | if args.finetune and not args.eval: 192 | checkpoint = torch.load(args.finetune, map_location='cpu') 193 | 194 | print("Load pre-trained checkpoint from: %s" % args.finetune) 195 | checkpoint_model = checkpoint['model'] 196 | state_dict = model.state_dict() 197 | for k in ['head.weight', 'head.bias']: 198 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 199 | print(f"Removing key {k} from pretrained checkpoint") 200 | del checkpoint_model[k] 201 | 202 | # interpolate position embedding 203 | # interpolate_pos_embed(model, checkpoint_model) 204 | 205 | # load pre-trained model 206 | msg = model.load_state_dict(checkpoint_model, strict=False) 207 | print(msg) 208 | 209 | # manually initialize fc layer: following MoCo v3 210 | trunc_normal_(model.head.weight, std=0.01) 211 | 212 | # for linear prob only 213 | # hack: revise model's head with BN 214 | model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head) 215 | # freeze all but the head 216 | for _, p in model.named_parameters(): 217 | p.requires_grad = False 218 | for _, p in model.head.named_parameters(): 219 | p.requires_grad = True 220 | 221 | model.to(device) 222 | 223 | model_without_ddp = model 224 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 225 | 226 | print("Model = %s" % str(model_without_ddp)) 227 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 228 | 229 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 230 | 231 | if args.lr is None: # only base_lr is specified 232 | args.lr = args.blr * eff_batch_size / 256 233 | 234 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 235 | print("actual lr: %.2e" % args.lr) 236 | 237 | print("accumulate grad iterations: %d" % args.accum_iter) 238 | print("effective batch size: %d" % eff_batch_size) 239 | 240 | if args.distributed: 241 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 242 | model_without_ddp = model.module 243 | 244 | optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay) 245 | 
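    # Illustrative sanity check (commentary only, not part of the original code): after the
    # freeze above, only the BatchNorm1d + Linear head should remain trainable. For
    # convvit_base_patch16 (768-dim features) with the default 1000 classes that is
    # 768 * 1000 + 1000 = 769,000 parameters (the affine=False BatchNorm1d adds none), so the
    # "number of params (M)" line printed earlier reports roughly 0.77, and the LARS optimizer
    # here is built over exactly those head parameters. A quick check could be:
    #     assert all(not p.requires_grad
    #                for n, p in model_without_ddp.named_parameters()
    #                if not n.startswith('head'))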
print(optimizer) 246 | loss_scaler = NativeScaler() 247 | 248 | criterion = torch.nn.CrossEntropyLoss() 249 | 250 | print("criterion = %s" % str(criterion)) 251 | 252 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 253 | 254 | if args.eval: 255 | test_stats = evaluate(data_loader_val, model, device) 256 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 257 | exit(0) 258 | 259 | print(f"Start training for {args.epochs} epochs") 260 | start_time = time.time() 261 | max_accuracy = 0.0 262 | for epoch in range(args.start_epoch, args.epochs): 263 | if args.distributed: 264 | data_loader_train.sampler.set_epoch(epoch) 265 | train_stats = train_one_epoch( 266 | model, criterion, data_loader_train, 267 | optimizer, device, epoch, loss_scaler, 268 | max_norm=None, 269 | log_writer=log_writer, 270 | args=args 271 | ) 272 | if args.output_dir: 273 | misc.save_model( 274 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 275 | loss_scaler=loss_scaler, epoch=epoch) 276 | 277 | test_stats = evaluate(data_loader_val, model, device) 278 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 279 | #max_accuracy = max(max_accuracy, test_stats["acc1"]) 280 | if max_accuracy < test_stats["acc1"]: 281 | max_accuracy = test_stats["acc1"] 282 | misc.save_best_model( 283 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 284 | loss_scaler=loss_scaler, epoch=epoch) 285 | print(f'Max accuracy: {max_accuracy:.2f}%') 286 | 287 | if log_writer is not None: 288 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 289 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 290 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 291 | 292 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 293 | **{f'test_{k}': v for k, v in test_stats.items()}, 294 | 'epoch': epoch, 295 | 'n_parameters': n_parameters} 296 | 297 | if args.output_dir and misc.is_main_process(): 298 | if log_writer is not None: 299 | log_writer.flush() 300 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 301 | f.write(json.dumps(log_stats) + "\n") 302 | 303 | total_time = time.time() - start_time 304 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 305 | print('Training time {}'.format(total_time_str)) 306 | 307 | 308 | if __name__ == '__main__': 309 | args = get_args_parser() 310 | args = args.parse_args() 311 | if args.output_dir: 312 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 313 | main(args) 314 | -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Alpha-VL 2 | # -------------------------------------------------------- 3 | # References: 4 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # DeiT: https://github.com/facebookresearch/deit 6 | # MAE: https://github.com/facebookresearch/mae 7 | # -------------------------------------------------------- 8 | import argparse 9 | import datetime 10 | import json 11 | import numpy as np 12 | import os 13 | import time 14 | from pathlib import Path 15 | 16 | import torch 17 | import torch.backends.cudnn as cudnn 18 | from torch.utils.tensorboard import SummaryWriter 19 | 
import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | 22 | import timm 23 | 24 | assert timm.__version__ == "0.3.2" # version check 25 | import timm.optim.optim_factory as optim_factory 26 | 27 | import util.misc as misc 28 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 29 | 30 | import models_convmae 31 | 32 | from engine_pretrain import train_one_epoch 33 | 34 | 35 | def get_args_parser(): 36 | parser = argparse.ArgumentParser('ConvMAE pre-training', add_help=False) 37 | parser.add_argument('--batch_size', default=64, type=int, 38 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 39 | parser.add_argument('--epochs', default=400, type=int) 40 | parser.add_argument('--accum_iter', default=1, type=int, 41 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 42 | 43 | # Model parameters 44 | parser.add_argument('--model', default='convmae_convvit_base_patch16', type=str, metavar='MODEL', 45 | help='Name of model to train') 46 | 47 | parser.add_argument('--input_size', default=224, type=int, 48 | help='images input size') 49 | 50 | parser.add_argument('--mask_ratio', default=0.75, type=float, 51 | help='Masking ratio (percentage of removed patches).') 52 | 53 | parser.add_argument('--norm_pix_loss', action='store_true', 54 | help='Use (per-patch) normalized pixels as targets for computing loss') 55 | parser.set_defaults(norm_pix_loss=False) 56 | 57 | # Optimizer parameters 58 | parser.add_argument('--weight_decay', type=float, default=0.05, 59 | help='weight decay (default: 0.05)') 60 | 61 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 62 | help='learning rate (absolute lr)') 63 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 64 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 65 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 66 | help='lower lr bound for cyclic schedulers that hit 0') 67 | 68 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 69 | help='epochs to warmup LR') 70 | 71 | # Dataset parameters 72 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 73 | help='dataset path') 74 | 75 | parser.add_argument('--output_dir', default='./output_dir', 76 | help='path where to save, empty for no saving') 77 | parser.add_argument('--log_dir', default='./output_dir', 78 | help='path where to tensorboard log') 79 | parser.add_argument('--device', default='cuda', 80 | help='device to use for training / testing') 81 | parser.add_argument('--seed', default=0, type=int) 82 | parser.add_argument('--resume', default='', 83 | help='resume from checkpoint') 84 | 85 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 86 | help='start epoch') 87 | parser.add_argument('--num_workers', default=10, type=int) 88 | parser.add_argument('--pin_mem', action='store_true', 89 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 90 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 91 | parser.set_defaults(pin_mem=True) 92 | 93 | # distributed training parameters 94 | parser.add_argument('--world_size', default=1, type=int, 95 | help='number of distributed processes') 96 | parser.add_argument('--local_rank', default=-1, type=int) 97 | parser.add_argument('--dist_on_itp', action='store_true') 98 | 
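    # Illustrative note (commentary only; torchrun and the example path are assumptions,
    # see PRETRAIN.md for the repo's own commands): as in the linear-probe entry point,
    # main() resolves the absolute lr as blr * eff_batch_size / 256, so the defaults above
    # (--batch_size 64, --blr 1e-3, --accum_iter 1) on 8 GPUs give eff_batch_size = 512
    # and lr = 1e-3 * 512 / 256 = 2e-3. A hypothetical single-node launch built only from
    # flags defined in this parser:
    #     torchrun --nproc_per_node=8 main_pretrain.py \
    #         --model convmae_convvit_base_patch16 --batch_size 64 --epochs 400 \
    #         --mask_ratio 0.75 --norm_pix_loss --data_path /path/to/imagenet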
parser.add_argument('--dist_url', default='env://', 99 | help='url used to set up distributed training') 100 | 101 | return parser 102 | 103 | 104 | def main(args): 105 | misc.init_distributed_mode(args) 106 | 107 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 108 | print("{}".format(args).replace(', ', ',\n')) 109 | 110 | device = torch.device(args.device) 111 | 112 | # fix the seed for reproducibility 113 | seed = args.seed + misc.get_rank() 114 | torch.manual_seed(seed) 115 | np.random.seed(seed) 116 | 117 | cudnn.benchmark = True 118 | 119 | # simple augmentation 120 | transform_train = transforms.Compose([ 121 | transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 122 | transforms.RandomHorizontalFlip(), 123 | transforms.ToTensor(), 124 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 125 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 126 | print(dataset_train) 127 | 128 | if True: # args.distributed: 129 | num_tasks = misc.get_world_size() 130 | global_rank = misc.get_rank() 131 | sampler_train = torch.utils.data.DistributedSampler( 132 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 133 | ) 134 | print("Sampler_train = %s" % str(sampler_train)) 135 | else: 136 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 137 | 138 | if global_rank == 0 and args.log_dir is not None: 139 | os.makedirs(args.log_dir, exist_ok=True) 140 | log_writer = SummaryWriter(log_dir=args.log_dir) 141 | else: 142 | log_writer = None 143 | 144 | data_loader_train = torch.utils.data.DataLoader( 145 | dataset_train, sampler=sampler_train, 146 | batch_size=args.batch_size, 147 | num_workers=args.num_workers, 148 | pin_memory=args.pin_mem, 149 | drop_last=True, 150 | ) 151 | 152 | # define the model 153 | model = models_convmae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss) 154 | 155 | model.to(device) 156 | 157 | model_without_ddp = model 158 | print("Model = %s" % str(model_without_ddp)) 159 | 160 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 161 | 162 | if args.lr is None: # only base_lr is specified 163 | args.lr = args.blr * eff_batch_size / 256 164 | 165 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 166 | print("actual lr: %.2e" % args.lr) 167 | 168 | print("accumulate grad iterations: %d" % args.accum_iter) 169 | print("effective batch size: %d" % eff_batch_size) 170 | 171 | if args.distributed: 172 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 173 | model_without_ddp = model.module 174 | 175 | # following timm: set wd as 0 for bias and norm layers 176 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 177 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 178 | print(optimizer) 179 | loss_scaler = NativeScaler() 180 | 181 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 182 | 183 | print(f"Start training for {args.epochs} epochs") 184 | start_time = time.time() 185 | for epoch in range(args.start_epoch, args.epochs): 186 | if args.distributed: 187 | data_loader_train.sampler.set_epoch(epoch) 188 | train_stats = train_one_epoch( 189 | model, data_loader_train, 190 | optimizer, device, epoch, loss_scaler, 191 | log_writer=log_writer, 192 | args=args 193 | ) 194 | if 
args.output_dir and (epoch % 40 == 0 or epoch + 1 == args.epochs): 195 | misc.save_model( 196 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 197 | loss_scaler=loss_scaler, epoch=epoch) 198 | 199 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 200 | 'epoch': epoch,} 201 | 202 | if args.output_dir and misc.is_main_process(): 203 | if log_writer is not None: 204 | log_writer.flush() 205 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 206 | f.write(json.dumps(log_stats) + "\n") 207 | 208 | total_time = time.time() - start_time 209 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 210 | print('Training time {}'.format(total_time_str)) 211 | 212 | 213 | if __name__ == '__main__': 214 | args = get_args_parser() 215 | args = args.parse_args() 216 | if args.output_dir: 217 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 218 | main(args) 219 | -------------------------------------------------------------------------------- /models_convmae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Alpha-VL 2 | # -------------------------------------------------------- 3 | # References: 4 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # DeiT: https://github.com/facebookresearch/deit 6 | # MAE: https://github.com/facebookresearch/mae 7 | # -------------------------------------------------------- 8 | 9 | from functools import partial 10 | import pdb 11 | import torch 12 | import torch.nn as nn 13 | 14 | from vision_transformer import PatchEmbed, Block, CBlock 15 | 16 | from util.pos_embed import get_2d_sincos_pos_embed 17 | 18 | 19 | class MaskedAutoencoderConvViT(nn.Module): 20 | """ Masked Autoencoder with VisionTransformer backbone 21 | """ 22 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 23 | embed_dim=1024, depth=24, num_heads=16, 24 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 25 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 26 | super().__init__() 27 | # -------------------------------------------------------------------------- 28 | # ConvMAE encoder specifics 29 | self.patch_embed1 = PatchEmbed( 30 | img_size=img_size[0], patch_size=patch_size[0], in_chans=in_chans, embed_dim=embed_dim[0]) 31 | self.patch_embed2 = PatchEmbed( 32 | img_size=img_size[1], patch_size=patch_size[1], in_chans=embed_dim[0], embed_dim=embed_dim[1]) 33 | self.patch_embed3 = PatchEmbed( 34 | img_size=img_size[2], patch_size=patch_size[2], in_chans=embed_dim[1], embed_dim=embed_dim[2]) 35 | 36 | self.patch_embed4 = nn.Linear(embed_dim[2], embed_dim[2]) 37 | self.stage1_output_decode = nn.Conv2d(embed_dim[0], embed_dim[2], 4, stride=4) 38 | self.stage2_output_decode = nn.Conv2d(embed_dim[1], embed_dim[2], 2, stride=2) 39 | 40 | num_patches = self.patch_embed3.num_patches 41 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[2]), requires_grad=False) 42 | self.blocks1 = nn.ModuleList([ 43 | CBlock( 44 | dim=embed_dim[0], num_heads=num_heads, mlp_ratio=mlp_ratio[0], qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 45 | for i in range(depth[0])]) 46 | self.blocks2 = nn.ModuleList([ 47 | CBlock( 48 | dim=embed_dim[1], num_heads=num_heads, mlp_ratio=mlp_ratio[1], qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 49 | for i in range(depth[1])]) 50 | self.blocks3 = nn.ModuleList([ 51 | Block( 52 | dim=embed_dim[2], num_heads=num_heads, 
mlp_ratio=mlp_ratio[2], qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 53 | for i in range(depth[2])]) 54 | self.norm = norm_layer(embed_dim[-1]) 55 | 56 | # -------------------------------------------------------------------------- 57 | # ConvMAE decoder specifics 58 | self.decoder_embed = nn.Linear(embed_dim[-1], decoder_embed_dim, bias=True) 59 | 60 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 61 | 62 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 63 | self.decoder_blocks = nn.ModuleList([ 64 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio[0], qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 65 | for i in range(decoder_depth)]) 66 | 67 | self.decoder_norm = norm_layer(decoder_embed_dim) 68 | self.decoder_pred = nn.Linear(decoder_embed_dim, (patch_size[0] * patch_size[1] * patch_size[2])**2 * in_chans, bias=True) # decoder to patch 69 | # -------------------------------------------------------------------------- 70 | 71 | self.norm_pix_loss = norm_pix_loss 72 | 73 | self.initialize_weights() 74 | 75 | def initialize_weights(self): 76 | # initialization 77 | # initialize (and freeze) pos_embed by sin-cos embedding 78 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed3.num_patches**.5), cls_token=False) 79 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 80 | 81 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed3.num_patches**.5), cls_token=False) 82 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 83 | 84 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 85 | w = self.patch_embed3.proj.weight.data 86 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 87 | 88 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 89 | # torch.nn.init.normal_(self.cls_token, std=.02) 90 | torch.nn.init.normal_(self.mask_token, std=.02) 91 | 92 | # initialize nn.Linear and nn.LayerNorm 93 | self.apply(self._init_weights) 94 | 95 | def _init_weights(self, m): 96 | if isinstance(m, nn.Linear): 97 | # we use xavier_uniform following official JAX ViT: 98 | torch.nn.init.xavier_uniform_(m.weight) 99 | if isinstance(m, nn.Linear) and m.bias is not None: 100 | nn.init.constant_(m.bias, 0) 101 | elif isinstance(m, nn.LayerNorm): 102 | nn.init.constant_(m.bias, 0) 103 | nn.init.constant_(m.weight, 1.0) 104 | 105 | def patchify(self, imgs): 106 | """ 107 | imgs: (N, 3, H, W) 108 | x: (N, L, patch_size**2 *3) 109 | """ 110 | p = 16 111 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 112 | 113 | h = w = imgs.shape[2] // p 114 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 115 | x = torch.einsum('nchpwq->nhwpqc', x) 116 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 117 | return x 118 | 119 | def unpatchify(self, x): 120 | """ 121 | x: (N, L, patch_size**2 *3) 122 | imgs: (N, 3, H, W) 123 | """ 124 | p = self.patch_embed.patch_size[0] 125 | h = w = int(x.shape[1]**.5) 126 | assert h * w == x.shape[1] 127 | 128 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 129 | x = torch.einsum('nhwpqc->nchpwq', x) 130 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 131 | return imgs 132 | 133 | def random_masking(self, x, mask_ratio): 134 | """ 135 | Perform per-sample random masking by per-sample shuffling. 
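        (Illustrative sizes, assuming the convmae_convvit_base_patch16 defaults below:
        the mask is drawn over the L = 14 * 14 = 196 stage-3 tokens, so mask_ratio=0.75
        keeps len_keep = int(196 * 0.25) = 49 tokens per sample; forward_encoder then
        upsamples the same 14x14 mask to 56x56 and 28x28 to mask the convolutional
        stage-1 and stage-2 feature maps.)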
136 | Per-sample shuffling is done by argsort random noise. 137 | x: [N, L, D], sequence 138 | """ 139 | N = x.shape[0] 140 | L = self.patch_embed3.num_patches 141 | # N, L, D = x.shape # batch, length, dim 142 | len_keep = int(L * (1 - mask_ratio)) 143 | 144 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 145 | 146 | # sort noise for each sample 147 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 148 | ids_restore = torch.argsort(ids_shuffle, dim=1) 149 | 150 | # keep the first subset 151 | ids_keep = ids_shuffle[:, :len_keep] 152 | # x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 153 | 154 | # generate the binary mask: 0 is keep, 1 is remove 155 | mask = torch.ones([N, L], device=x.device) 156 | mask[:, :len_keep] = 0 157 | # unshuffle to get the binary mask 158 | mask = torch.gather(mask, dim=1, index=ids_restore) 159 | 160 | return ids_keep, mask, ids_restore 161 | 162 | def forward_encoder(self, x, mask_ratio): 163 | # embed patches 164 | ids_keep, mask, ids_restore = self.random_masking(x, mask_ratio) 165 | mask_for_patch1 = mask.reshape(-1, 14, 14).unsqueeze(-1).repeat(1, 1, 1, 16).reshape(-1, 14, 14, 4, 4).permute(0, 1, 3, 2, 4).reshape(x.shape[0], 56, 56).unsqueeze(1) 166 | mask_for_patch2 = mask.reshape(-1, 14, 14).unsqueeze(-1).repeat(1, 1, 1, 4).reshape(-1, 14, 14, 2, 2).permute(0, 1, 3, 2, 4).reshape(x.shape[0], 28, 28).unsqueeze(1) 167 | x = self.patch_embed1(x) 168 | for blk in self.blocks1: 169 | x = blk(x, 1 - mask_for_patch1) 170 | stage1_embed = self.stage1_output_decode(x).flatten(2).permute(0, 2, 1) 171 | 172 | x = self.patch_embed2(x) 173 | for blk in self.blocks2: 174 | x = blk(x, 1 - mask_for_patch2) 175 | stage2_embed = self.stage2_output_decode(x).flatten(2).permute(0, 2, 1) 176 | x = self.patch_embed3(x) 177 | x = x.flatten(2).permute(0, 2, 1) 178 | x = self.patch_embed4(x) 179 | # add pos embed w/o cls token 180 | x = x + self.pos_embed 181 | x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1])) 182 | stage1_embed = torch.gather(stage1_embed, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, stage1_embed.shape[-1])) 183 | stage2_embed = torch.gather(stage2_embed, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, stage2_embed.shape[-1])) 184 | 185 | 186 | # apply Transformer blocks 187 | for blk in self.blocks3: 188 | x = blk(x) 189 | x = x + stage1_embed + stage2_embed 190 | x = self.norm(x) 191 | 192 | return x, mask, ids_restore 193 | 194 | def forward_decoder(self, x, ids_restore): 195 | # embed tokens 196 | x = self.decoder_embed(x) 197 | 198 | # append mask tokens to sequence 199 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) 200 | x_ = torch.cat([x, mask_tokens], dim=1) # no cls token 201 | x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 202 | 203 | # add pos embed 204 | x = x + self.decoder_pos_embed 205 | 206 | # apply Transformer blocks 207 | for blk in self.decoder_blocks: 208 | x = blk(x) 209 | x = self.decoder_norm(x) 210 | 211 | # predictor projection 212 | x = self.decoder_pred(x) 213 | 214 | return x 215 | 216 | def forward_loss(self, imgs, pred, mask): 217 | """ 218 | imgs: [N, 3, H, W] 219 | pred: [N, L, p*p*3] 220 | mask: [N, L], 0 is keep, 1 is remove, 221 | """ 222 | target = self.patchify(imgs) 223 | if self.norm_pix_loss: 224 | mean = target.mean(dim=-1, keepdim=True) 225 | var = target.var(dim=-1, keepdim=True) 226 | target = (target - 
mean) / (var + 1.e-6)**.5 227 | 228 | loss = (pred - target) ** 2 229 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 230 | 231 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 232 | return loss 233 | 234 | def forward(self, imgs, mask_ratio=0.75): 235 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 236 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 237 | loss = self.forward_loss(imgs, pred, mask) 238 | return loss, pred, mask 239 | 240 | 241 | def convmae_convvit_base_patch16_dec512d8b(**kwargs): 242 | model = MaskedAutoencoderConvViT( 243 | img_size=[224, 56, 28], patch_size=[4, 2, 2], embed_dim=[256, 384, 768], depth=[2, 2, 11], num_heads=12, 244 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 245 | mlp_ratio=[4, 4, 4], norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 246 | return model 247 | 248 | 249 | # set recommended archs 250 | convmae_convvit_base_patch16 = convmae_convvit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 251 | -------------------------------------------------------------------------------- /models_convvit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Alpha-VL 2 | # References: 3 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 4 | # DeiT: https://github.com/facebookresearch/deit 5 | # MAE: https://github.com/facebookresearch/mae 6 | # -------------------------------------------------------- 7 | 8 | from functools import partial 9 | 10 | import torch 11 | import torch.nn as nn 12 | import pdb 13 | import vision_transformer 14 | 15 | 16 | class ConvViT(vision_transformer.ConvViT): 17 | """ Vision Transformer with support for global average pooling 18 | """ 19 | def __init__(self, global_pool=False, **kwargs): 20 | super(ConvViT, self).__init__(**kwargs) 21 | self.global_pool = global_pool 22 | if self.global_pool: 23 | norm_layer = kwargs['norm_layer'] 24 | embed_dim = kwargs['embed_dim'] 25 | self.fc_norm = norm_layer(embed_dim[-1]) 26 | 27 | del self.norm # remove the original norm 28 | 29 | def forward_features(self, x): 30 | B = x.shape[0] 31 | x = self.patch_embed1(x) 32 | x = self.pos_drop(x) 33 | for blk in self.blocks1: 34 | x = blk(x) 35 | x = self.patch_embed2(x) 36 | for blk in self.blocks2: 37 | x = blk(x) 38 | x = self.patch_embed3(x) 39 | x = x.flatten(2).permute(0, 2, 1) 40 | x = self.patch_embed4(x) 41 | x = x + self.pos_embed 42 | for blk in self.blocks3: 43 | x = blk(x) 44 | if self.global_pool: 45 | x = x[:, :, :].mean(dim=1) # global pool without cls token 46 | outcome = self.fc_norm(x) 47 | else: 48 | x = self.norm(x) 49 | outcome = x[:, 0] 50 | return outcome 51 | 52 | 53 | def convvit_base_patch16(**kwargs): 54 | model = ConvViT( 55 | img_size=[224, 56, 28], patch_size=[4, 2, 2], embed_dim=[256, 384, 768], depth=[2, 2, 11], num_heads=12, mlp_ratio=[4, 4, 4], qkv_bias=True, 56 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 57 | return model 58 | 59 | -------------------------------------------------------------------------------- /submitit_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_finetune as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE finetune", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_finetune as classification 57 | 58 | self._setup_gpu_args() 59 | classification.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = 
args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /submitit_linprobe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_linprobe as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE linear probe", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 
44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_linprobe as classification 57 | 58 | self._setup_gpu_args() 59 | classification.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /submitit_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 
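# Illustrative usage (commentary, not part of the original header; the example paths are
# assumptions): this wrapper re-uses main_pretrain.get_args_parser() and adds the SLURM
# options defined in parse_args() below, so a hypothetical 2-node submission could look like
#     python submitit_pretrain.py --job_dir /path/to/jobs \
#         --nodes 2 --ngpus 8 --partition alpha_vl \
#         --model convmae_convvit_base_patch16 --batch_size 64 --mask_ratio 0.75 \
#         --norm_pix_loss --data_path /path/to/imagenet
# On preemption or timeout, Trainer.checkpoint() below requeues the job and resumes from
# <output_dir>/checkpoint.pth if it exists.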
8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_pretrain as trainer 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | trainer_parser = trainer.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE pretrain", parents=[trainer_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=20160, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="alpha_vl", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_pretrain as trainer 57 | 58 | self._setup_gpu_args() 59 | trainer.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | 
tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, # max is 60 * 72 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 23 | dataset = datasets.ImageFolder(root, transform=transform) 24 | 25 | print(dataset) 26 | 27 | return dataset 28 | 29 | 30 | def build_transform(is_train, args): 31 | mean = IMAGENET_DEFAULT_MEAN 32 | std = IMAGENET_DEFAULT_STD 33 | # train transform 34 | if is_train: 35 | # this should always dispatch to transforms_imagenet_train 36 | transform = create_transform( 37 | input_size=args.input_size, 38 | is_training=True, 39 | color_jitter=args.color_jitter, 40 | auto_augment=args.aa, 41 | interpolation='bicubic', 42 | re_prob=args.reprob, 43 | re_mode=args.remode, 44 | re_count=args.recount, 45 | mean=mean, 46 | std=std, 47 | ) 48 | return transform 49 | 50 | # eval transform 51 | t = [] 52 | if args.input_size <= 224: 53 | crop_pct = 224 / 256 54 | else: 55 | crop_pct = 1.0 56 | size = int(args.input_size / crop_pct) 57 | t.append( 58 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 59 | ) 60 | t.append(transforms.CenterCrop(args.input_size)) 61 | 62 | t.append(transforms.ToTensor()) 63 | t.append(transforms.Normalize(mean, std)) 64 | return transforms.Compose(t) 65 | -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
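    Summary of step() below (per parameter p with gradient g):
        dp = g + weight_decay * p                (applied only when p.ndim > 1)
        q  = trust_coefficient * ||p|| / ||dp||  (falls back to 1.0 if either norm is 0)
        mu = momentum * mu + q * dp              (momentum buffer kept in optimizer state)
        p  = p - lr * mu
    1-D parameters (biases, norm weights) skip both the weight decay and the trust-ratio scaling.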
17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | import pdb 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | num_layers = len(model.blocks3) + 1 23 | 24 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 25 | 26 | for n, p in model.named_parameters(): 27 | if not p.requires_grad: 28 | continue 29 | 30 | # no decay: all 1D parameters and model specific ones 31 | if p.ndim == 1 or n in no_weight_decay_list: 32 | g_decay = "no_decay" 33 | this_decay = 0. 
34 | else: 35 | g_decay = "decay" 36 | this_decay = weight_decay 37 | 38 | layer_id = get_layer_id_for_vit(n, num_layers) 39 | group_name = "layer_%d_%s" % (layer_id, g_decay) 40 | 41 | if group_name not in param_group_names: 42 | this_scale = layer_scales[layer_id] 43 | 44 | param_group_names[group_name] = { 45 | "lr_scale": this_scale, 46 | "weight_decay": this_decay, 47 | "params": [], 48 | } 49 | param_groups[group_name] = { 50 | "lr_scale": this_scale, 51 | "weight_decay": this_decay, 52 | "params": [], 53 | } 54 | 55 | param_group_names[group_name]["params"].append(n) 56 | param_groups[group_name]["params"].append(p) 57 | 58 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 59 | 60 | return list(param_groups.values()) 61 | 62 | 63 | def get_layer_id_for_vit(name, num_layers): 64 | """ 65 | Assign a parameter with its layer id 66 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 67 | """ 68 | if name in ['cls_token', 'pos_embed']: 69 | return 0 70 | elif name.startswith('patch_embed'): 71 | return 0 72 | elif name.startswith('blocks1'): 73 | return 0 74 | elif name.startswith('blocks2'): 75 | return 0 76 | elif name.startswith('blocks3'): 77 | return int(name.split('.')[1]) + 1 78 | else: 79 | return num_layers 80 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 
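    Note: update(value, n) appends `value` once to the window deque but adds value * n
    to the running total, so `global_avg` is a per-sample average when n is the batch
    size (as done for acc1/acc5 in engine_finetune.evaluate), while `avg` and `median`
    are computed over the last `window_size` updates only.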
27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | 
time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | def setup_for_distributed(is_master): 171 | """ 172 | This function disables printing when not in master process 173 | """ 174 | builtin_print = builtins.print 175 | 176 | def print(*args, **kwargs): 177 | force = kwargs.pop('force', False) 178 | force = force or (get_world_size() > 8) 179 | if is_master or force: 180 | now = datetime.datetime.now().time() 181 | builtin_print('[{}] '.format(now), end='') # print with time stamp 182 | builtin_print(*args, **kwargs) 183 | 184 | builtins.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if args.dist_on_itp: 218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 222 | os.environ['LOCAL_RANK'] = str(args.gpu) 223 | os.environ['RANK'] = str(args.rank) 224 | os.environ['WORLD_SIZE'] = str(args.world_size) 225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 227 | args.rank = int(os.environ["RANK"]) 228 | args.world_size = int(os.environ['WORLD_SIZE']) 229 | args.gpu = int(os.environ['LOCAL_RANK']) 230 | elif 'SLURM_PROCID' in os.environ: 231 | args.rank = int(os.environ['SLURM_PROCID']) 232 | args.gpu = args.rank % torch.cuda.device_count() 233 | else: 234 | print('Not using distributed mode') 235 | setup_for_distributed(is_master=True) # hack 236 | args.distributed = False 237 | return 238 | 239 | args.distributed = True 240 | 241 | torch.cuda.set_device(args.gpu) 242 | args.dist_backend = 'nccl' 243 | print('| distributed init (rank {}): {}, gpu {}'.format( 244 | args.rank, args.dist_url, args.gpu), flush=True) 245 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 246 | world_size=args.world_size, rank=args.rank) 247 | torch.distributed.barrier() 248 | setup_for_distributed(args.rank == 0) 249 | 250 | 251 | class NativeScalerWithGradNormCount: 252 | state_dict_key = "amp_scaler" 253 | 254 | def __init__(self): 255 | self._scaler = torch.cuda.amp.GradScaler() 256 | 257 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 258 | 
self._scaler.scale(loss).backward(create_graph=create_graph) 259 | if update_grad: 260 | if clip_grad is not None: 261 | assert parameters is not None 262 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 263 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 264 | else: 265 | self._scaler.unscale_(optimizer) 266 | norm = get_grad_norm_(parameters) 267 | self._scaler.step(optimizer) 268 | self._scaler.update() 269 | else: 270 | norm = None 271 | return norm 272 | 273 | def state_dict(self): 274 | return self._scaler.state_dict() 275 | 276 | def load_state_dict(self, state_dict): 277 | self._scaler.load_state_dict(state_dict) 278 | 279 | 280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 281 | if isinstance(parameters, torch.Tensor): 282 | parameters = [parameters] 283 | parameters = [p for p in parameters if p.grad is not None] 284 | norm_type = float(norm_type) 285 | if len(parameters) == 0: 286 | return torch.tensor(0.) 287 | device = parameters[0].grad.device 288 | if norm_type == inf: 289 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 290 | else: 291 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 292 | return total_norm 293 | 294 | 295 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 296 | output_dir = Path(args.output_dir) 297 | epoch_name = str(epoch) 298 | if loss_scaler is not None: 299 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 300 | for checkpoint_path in checkpoint_paths: 301 | to_save = { 302 | 'model': model_without_ddp.state_dict(), 303 | 'optimizer': optimizer.state_dict(), 304 | 'epoch': epoch, 305 | 'scaler': loss_scaler.state_dict(), 306 | 'args': args, 307 | } 308 | 309 | save_on_master(to_save, checkpoint_path) 310 | else: 311 | client_state = {'epoch': epoch} 312 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 313 | 314 | 315 | def save_best_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 316 | output_dir = Path(args.output_dir) 317 | epoch_name = str(epoch) 318 | if loss_scaler is not None: 319 | checkpoint_paths = [output_dir / 'best_checkpoint.pth'] 320 | for checkpoint_path in checkpoint_paths: 321 | to_save = { 322 | 'model': model_without_ddp.state_dict(), 323 | 'optimizer': optimizer.state_dict(), 324 | 'epoch': epoch, 325 | 'scaler': loss_scaler.state_dict(), 326 | 'args': args, 327 | } 328 | 329 | save_on_master(to_save, checkpoint_path) 330 | else: 331 | client_state = {'epoch': epoch} 332 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 333 | 334 | 335 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 336 | if args.resume: 337 | if args.resume.startswith('https'): 338 | checkpoint = torch.hub.load_state_dict_from_url( 339 | args.resume, map_location='cpu', check_hash=True) 340 | else: 341 | checkpoint = torch.load(args.resume, map_location='cpu') 342 | model_without_ddp.load_state_dict(checkpoint['model']) 343 | print("Resume checkpoint %s" % args.resume) 344 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 345 | optimizer.load_state_dict(checkpoint['optimizer']) 346 | args.start_epoch = checkpoint['epoch'] + 1 347 | if 'scaler' in checkpoint: 348 | loss_scaler.load_state_dict(checkpoint['scaler']) 349 | 
print("With optim & sched!") 350 | 351 | 352 | def all_reduce_mean(x): 353 | world_size = get_world_size() 354 | if world_size > 1: 355 | x_reduce = torch.tensor(x).cuda() 356 | dist.all_reduce(x_reduce) 357 | x_reduce /= world_size 358 | return x_reduce.item() 359 | else: 360 | return x 361 | -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Alpha-VL 2 | # -------------------------------------------------------- 3 | # References: 4 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # DeiT: https://github.com/facebookresearch/deit 6 | # -------------------------------------------------------- 7 | import torch 8 | import torch.nn as nn 9 | from functools import partial 10 | 11 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from timm.models.helpers import load_pretrained 13 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 14 | from timm.models.resnet import resnet26d, resnet50d 15 | from timm.models.registry import register_model 16 | 17 | import pdb 18 | 19 | def _cfg(url='', **kwargs): 20 | return { 21 | 'url': url, 22 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 23 | 'crop_pct': .9, 'interpolation': 'bicubic', 24 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 25 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 26 | **kwargs 27 | } 28 | 29 | 30 | default_cfgs = { 31 | # patch models 32 | 'vit_small_patch16_224': _cfg( 33 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', 34 | ), 35 | 'vit_base_patch16_224': _cfg( 36 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', 37 | mean=(0.5, 0.5, 0.5), 
std=(0.5, 0.5, 0.5), 38 | ), 39 | 'vit_base_patch16_384': _cfg( 40 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', 41 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 42 | 'vit_base_patch32_384': _cfg( 43 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', 44 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 45 | 'vit_large_patch16_224': _cfg( 46 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', 47 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 48 | 'vit_large_patch16_384': _cfg( 49 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', 50 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 51 | 'vit_large_patch32_384': _cfg( 52 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 53 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 54 | 'vit_huge_patch16_224': _cfg(), 55 | 'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)), 56 | # hybrid models 57 | 'vit_small_resnet26d_224': _cfg(), 58 | 'vit_small_resnet50d_s3_224': _cfg(), 59 | 'vit_base_resnet26d_224': _cfg(), 60 | 'vit_base_resnet50d_224': _cfg(), 61 | } 62 | 63 | 64 | class CMlp(nn.Module): 65 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 66 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 67 | super().__init__() 68 | out_features = out_features or in_features 69 | hidden_features = hidden_features or in_features 70 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1) 71 | self.act = act_layer() 72 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1) 73 | self.drop = nn.Dropout(drop) 74 | 75 | def forward(self, x): 76 | x = self.fc1(x) 77 | x = self.act(x) 78 | x = self.drop(x) 79 | x = self.fc2(x) 80 | x = self.drop(x) 81 | return x 82 | 83 | class Mlp(nn.Module): 84 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 85 | super().__init__() 86 | out_features = out_features or in_features 87 | hidden_features = hidden_features or in_features 88 | self.fc1 = nn.Linear(in_features, hidden_features) 89 | self.act = act_layer() 90 | self.fc2 = nn.Linear(hidden_features, out_features) 91 | self.drop = nn.Dropout(drop) 92 | 93 | def forward(self, x): 94 | x = self.fc1(x) 95 | x = self.act(x) 96 | x = self.drop(x) 97 | x = self.fc2(x) 98 | x = self.drop(x) 99 | return x 100 | 101 | class CBlock(nn.Module): 102 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 103 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 104 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 105 | super().__init__() 106 | self.norm1 = nn.LayerNorm(dim) 107 | self.conv1 = nn.Conv2d(dim, dim, 1) 108 | self.conv2 = nn.Conv2d(dim, dim, 1) 109 | self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) 110 | # self.attn = nn.Conv2d(dim, dim, 13, padding=6, groups=dim) 111 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 112 | self.drop_path = 
DropPath(drop_path) if drop_path > 0. else nn.Identity() 113 | self.norm2 = nn.LayerNorm(dim) 114 | mlp_hidden_dim = int(dim * mlp_ratio) 115 | self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 116 | 117 | def forward(self, x, mask=None): 118 | if mask is not None: 119 | x = x + self.drop_path(self.conv2(self.attn(mask * self.conv1(self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2))))) 120 | else: 121 | x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2))))) 122 | x = x + self.drop_path(self.mlp(self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2))) 123 | return x 124 | 125 | 126 | class Attention(nn.Module): 127 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 128 | super().__init__() 129 | self.num_heads = num_heads 130 | head_dim = dim // num_heads 131 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 132 | self.scale = qk_scale or head_dim ** -0.5 133 | 134 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 135 | self.attn_drop = nn.Dropout(attn_drop) 136 | self.proj = nn.Linear(dim, dim) 137 | self.proj_drop = nn.Dropout(proj_drop) 138 | 139 | def forward(self, x): 140 | B, N, C = x.shape 141 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 142 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 143 | 144 | attn = (q @ k.transpose(-2, -1)) * self.scale 145 | attn = attn.softmax(dim=-1) 146 | attn = self.attn_drop(attn) 147 | 148 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 149 | x = self.proj(x) 150 | x = self.proj_drop(x) 151 | return x 152 | 153 | 154 | class Block(nn.Module): 155 | 156 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 157 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 158 | super().__init__() 159 | self.norm1 = norm_layer(dim) 160 | self.attn = Attention( 161 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 162 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 163 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 164 | self.norm2 = norm_layer(dim) 165 | mlp_hidden_dim = int(dim * mlp_ratio) 166 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 167 | 168 | def forward(self, x): 169 | x = x + self.drop_path(self.attn(self.norm1(x))) 170 | x = x + self.drop_path(self.mlp(self.norm2(x))) 171 | return x 172 | 173 | 174 | class PatchEmbed(nn.Module): 175 | """ Image to Patch Embedding 176 | """ 177 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 178 | super().__init__() 179 | img_size = to_2tuple(img_size) 180 | patch_size = to_2tuple(patch_size) 181 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 182 | self.img_size = img_size 183 | self.patch_size = patch_size 184 | self.num_patches = num_patches 185 | 186 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 187 | self.norm = nn.LayerNorm(embed_dim) 188 | self.act = nn.GELU() 189 | def forward(self, x): 190 | B, C, H, W = x.shape 191 | # FIXME look at relaxing size constraints 192 | assert H == self.img_size[0] and W == self.img_size[1], \ 193 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 194 | x = self.proj(x) 195 | x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 196 | return self.act(x) 197 | 198 | 199 | class HybridEmbed(nn.Module): 200 | """ CNN Feature Map Embedding 201 | Extract feature map from CNN, flatten, project to embedding dim. 202 | """ 203 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 204 | super().__init__() 205 | assert isinstance(backbone, nn.Module) 206 | img_size = to_2tuple(img_size) 207 | self.img_size = img_size 208 | self.backbone = backbone 209 | if feature_size is None: 210 | with torch.no_grad(): 211 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 212 | # map for all networks, the feature metadata has reliable channel and stride info, but using 213 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
214 | training = backbone.training 215 | if training: 216 | backbone.eval() 217 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 218 | feature_size = o.shape[-2:] 219 | feature_dim = o.shape[1] 220 | backbone.train(training) 221 | else: 222 | feature_size = to_2tuple(feature_size) 223 | feature_dim = self.backbone.feature_info.channels()[-1] 224 | self.num_patches = feature_size[0] * feature_size[1] 225 | self.proj = nn.Linear(feature_dim, embed_dim) 226 | 227 | def forward(self, x): 228 | x = self.backbone(x)[-1] 229 | x = x.flatten(2).transpose(1, 2) 230 | x = self.proj(x) 231 | return x 232 | 233 | 234 | class ConvViT(nn.Module): 235 | """ Vision Transformer with support for patch or hybrid CNN input stage 236 | """ 237 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 238 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 239 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): 240 | super().__init__() 241 | self.num_classes = num_classes 242 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 243 | 244 | if hybrid_backbone is not None: 245 | self.patch_embed = HybridEmbed( 246 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 247 | else: 248 | self.patch_embed1 = PatchEmbed( 249 | img_size=img_size[0], patch_size=patch_size[0], in_chans=in_chans, embed_dim=embed_dim[0]) 250 | self.patch_embed2 = PatchEmbed( 251 | img_size=img_size[1], patch_size=patch_size[1], in_chans=embed_dim[0], embed_dim=embed_dim[1]) 252 | self.patch_embed3 = PatchEmbed( 253 | img_size=img_size[2], patch_size=patch_size[2], in_chans=embed_dim[1], embed_dim=embed_dim[2]) 254 | num_patches = self.patch_embed3.num_patches 255 | self.patch_embed4 = nn.Linear(embed_dim[2], embed_dim[2]) 256 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[2])) 257 | self.pos_drop = nn.Dropout(p=drop_rate) 258 | 259 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] # stochastic depth decay rule 260 | self.blocks1 = nn.ModuleList([ 261 | CBlock( 262 | dim=embed_dim[0], num_heads=num_heads, mlp_ratio=mlp_ratio[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 263 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 264 | for i in range(depth[0])]) 265 | self.blocks2 = nn.ModuleList([ 266 | CBlock( 267 | dim=embed_dim[1], num_heads=num_heads, mlp_ratio=mlp_ratio[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 268 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[depth[0] + i], norm_layer=norm_layer) 269 | for i in range(depth[1])]) 270 | self.blocks3 = nn.ModuleList([ 271 | Block( 272 | dim=embed_dim[2], num_heads=num_heads, mlp_ratio=mlp_ratio[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 273 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[depth[0] + depth[1] + i], norm_layer=norm_layer) 274 | for i in range(depth[2])]) 275 | 276 | 277 | self.norm = norm_layer(embed_dim[-1]) 278 | 279 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 280 | #self.repr = nn.Linear(embed_dim, representation_size) 281 | #self.repr_act = nn.Tanh() 282 | 283 | # Classifier head 284 | self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() 285 | 286 | trunc_normal_(self.pos_embed, std=.02) 287 | self.apply(self._init_weights) 288 | 289 | def _init_weights(self, m): 290 | if isinstance(m, 
nn.Linear): 291 | trunc_normal_(m.weight, std=.02) 292 | if isinstance(m, nn.Linear) and m.bias is not None: 293 | nn.init.constant_(m.bias, 0) 294 | elif isinstance(m, nn.LayerNorm): 295 | nn.init.constant_(m.bias, 0) 296 | nn.init.constant_(m.weight, 1.0) 297 | 298 | @torch.jit.ignore 299 | def no_weight_decay(self): 300 | return {'pos_embed', 'cls_token'} 301 | 302 | def get_classifier(self): 303 | return self.head 304 | 305 | def reset_classifier(self, num_classes, global_pool=''): 306 | self.num_classes = num_classes 307 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 308 | 309 | def forward_features(self, x): 310 | B = x.shape[0] 311 | x = self.patch_embed1(x) 312 | x = self.pos_drop(x) 313 | for blk in self.blocks1: 314 | x = blk(x) 315 | x = self.patch_embed2(x) 316 | for blk in self.blocks2: 317 | x = blk(x) 318 | x = self.patch_embed3(x) 319 | x = x.flatten(2).permute(0, 2, 1) 320 | x = x + self.pos_embed 321 | for blk in self.blocks3: 322 | x = blk(x) 323 | x = self.norm(x) 324 | return x.mean(1) 325 | 326 | def forward(self, x): 327 | x = self.forward_features(x) 328 | x = self.head(x) 329 | return x 330 | 331 | 332 | 333 | --------------------------------------------------------------------------------
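
The 2D sine-cosine table built by `util/pos_embed.py` is fixed (not learned) and, in MAE-style code, is typically copied into a model's `pos_embed` parameter at construction time. Below is a minimal usage sketch, not part of the repository itself — it assumes the repository root is on `PYTHONPATH`, `torch` is installed, and `numpy < 1.24` (the helper still uses the deprecated `np.float` alias):

```python
# Hypothetical usage sketch for util/pos_embed.py; not part of the repository.
import torch
from util.pos_embed import get_2d_sincos_pos_embed

embed_dim, grid_size = 768, 14            # e.g. 224x224 input, 16x16 patches -> 14x14 token grid
table = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)
print(table.shape)                        # (196, 768): one row per patch position

# Copy the fixed table into a frozen parameter, as MAE-style models usually do.
pos_embed = torch.nn.Parameter(
    torch.zeros(1, grid_size * grid_size, embed_dim), requires_grad=False)
pos_embed.data.copy_(torch.from_numpy(table).float().unsqueeze(0))
```

Note that the `ConvViT` finetuning model in `vision_transformer.py` instead learns its `pos_embed` (initialized with `trunc_normal_`) and relies on `interpolate_pos_embed` to resize a checkpoint's position table when the input resolution changes.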