├── .gitignore ├── INSTALL.md ├── LICENSE ├── README.md ├── downstream_d2 ├── README.md ├── configs │ ├── Base-RCNN-FPN.yaml │ └── coco_R_50_FPN_CONV_1x_moco_adam.yaml ├── convert-timm-to-d2.py ├── lr_decay.py └── train_net.py ├── downstream_imagenet ├── README.md ├── arg.py ├── data.py ├── lr_decay.py ├── main.py ├── mixup.py ├── models │ ├── __init__.py │ └── convnext_official.py ├── requirements.txt └── util.py ├── downstream_mmdet ├── README.md ├── configs │ ├── _base_ │ │ ├── default_runtime.py │ │ └── models │ │ │ ├── cascade_mask_rcnn_convnext_fpn.py │ │ │ └── mask_rcnn_convnext_fpn.py │ └── convnext_spark │ │ └── mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py ├── mmcv_custom │ ├── __init__.py │ ├── customized_text.py │ ├── layer_decay_optimizer_constructor.py │ └── runner │ │ └── checkpoint.py └── mmdet │ └── models │ └── backbones │ ├── __init__.py │ └── convnext.py └── pretrain ├── README.md ├── decoder.py ├── dist.py ├── encoder.py ├── main.py ├── models ├── __init__.py ├── convnext.py ├── custom.py └── resnet.py ├── requirements.txt ├── sampler.py ├── spark.py ├── utils ├── arg_util.py ├── imagenet.py ├── lamb.py ├── lr_control.py └── misc.py ├── viz_imgs ├── recon.png ├── spconv1.png ├── spconv2.png └── spconv3.png ├── viz_reconstruction.ipynb └── viz_spconv.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | **/__pycache__/** 3 | .idea/* 4 | ckpt/ 5 | *.pth 6 | *.log 7 | *.txt 8 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Preparation for pre-training & ImageNet fine-tuning 2 | 3 | ## Pip dependencies 4 | 5 | 1. Prepare a python environment, e.g.: 6 | ```shell script 7 | $ conda create -n spark python=3.8 -y 8 | $ conda activate spark 9 | ``` 10 | 11 | 2. Install `PyTorch` and `timm` (better to use `torch~=1.10`, `torchvision~=0.11`, and `timm==0.5.4`) then other python packages: 12 | ```shell script 13 | $ pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html 14 | $ pip install timm==0.5.4 15 | $ pip install -r requirements.txt 16 | ``` 17 | 18 | It is highly recommended to install these versions to ensure a consistent environment for re-implementation. 19 | 20 | 21 | ## ImageNet preparation 22 | 23 | Prepare the [ImageNet-1k](http://image-net.org/) dataset 24 | - assume the dataset is in `/path/to/imagenet` 25 | - it should look like this: 26 | ``` 27 | /path/to/imagenet/: 28 | train/: 29 | class1: 30 | a_lot_images.jpeg 31 | class2: 32 | a_lot_images.jpeg 33 | val/: 34 | class1: 35 | a_lot_images.jpeg 36 | class2: 37 | a_lot_images.jpeg 38 | ``` 39 | - that argument of `--data_path=/path/to/imagenet` should be passed to the training script introduced later 40 | 41 | 42 | > `PS:` In our implementation, we use pytorch built-in operators to simulate the submanifold sparse convolution in [encoder.py](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py) for generality, 43 | due to the fact that many convolution operators (e.g., grouped conv and dilated conv) do not yet have efficient sparse implementations on today's hardware. 44 | If you want to try those sparse convolution, you may refer to [this](https://github.com/facebookresearch/SparseConvNet) sparse convolution library or [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine). 
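To illustrate that idea (a minimal sketch only — not the exact code in `encoder.py`, and the class/variable names below are made up for illustration): a dense convolution can mimic a submanifold sparse convolution by re-applying the binary visibility mask to its output, so that features at masked positions stay zero throughout the encoder.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
    """Dense conv that zeroes its output at masked positions (sketch of a 'simulated' sparse conv)."""
    def forward(self, x, mask):                # mask: (B, 1, H, W), 1 = visible patch, 0 = masked patch
        out = super().forward(x)
        if mask.shape[-2:] != out.shape[-2:]:  # down-sample the mask if this conv has stride > 1
            mask = F.interpolate(mask, size=out.shape[-2:], mode='nearest')
        return out * mask                      # masked locations emit no features, like a submanifold conv

# usage sketch
x = torch.randn(2, 3, 224, 224)
mask = (torch.rand(2, 1, 224, 224) > 0.6).float()    # keep roughly 40% of the positions
conv = MaskedConv2d(3, 64, kernel_size=3, padding=1)
y = conv(x, mask)                                    # y is exactly zero wherever mask == 0
```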
45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Keyu Tian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /downstream_d2/README.md: -------------------------------------------------------------------------------- 1 | ## About code isolation 2 | 3 | This `downstream_d2` is isolated from pre-training codes. One can treat this `downstream_d2` as an independent codebase 🛠️. 4 | 5 | 6 | ## Fine-tuned ResNet-50 weights, log files, and performance 7 | 8 |
9 | 10 | [[`weights (pre-trained by SparK)`](https://drive.google.com/file/d/1H8605HbxGvrsu4x4rIoNr-Wkd7JkxFPQ/view?usp=share_link)] 11 | [[`weights (fine-tuned on COCO)`](https://drive.google.com/file/d/1Ue7SiQ1E_AwgtYo56Fm-iUlQPZ8vIwYj/view?usp=share_link)] 12 | [[`metrics.json`](https://drive.google.com/file/d/1wfbUWh4svV8sPWya_0PAhsLHVayDQRCi/view?usp=share_link)] 13 | [[`log.txt`](https://drive.google.com/file/d/11zVo_87pe9DMAmfNQK9FUfyjQWHTRKxV/view?usp=share_link)] 14 | [[`tensorboard file`](https://drive.google.com/file/d/1aM1qj8c3-Uka1dZuYmKhgp1lNJpeMDMl/view?usp=share_link)] 15 |
16 | 17 |

18 | 19 |

20 | 21 | 22 | ## Install [Detectron2 v0.6](https://github.com/facebookresearch/detectron2/releases/tag/v0.6) before fine-tuning ResNet on COCO 23 | 24 | 
25 | 1. Set up a Python environment, e.g.: 
26 | ```shell script 
27 | $ conda create -n spark python=3.8 -y 
28 | $ conda activate spark 
29 | ``` 
30 | 
31 | 2. Install `detectron2==0.6` (e.g., with `torch==1.10.0` and `cuda11.3`): 
32 | ```shell script 
33 | $ pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html 
34 | ``` 
35 | 
36 | You can also find instructions for different PyTorch/CUDA versions on [this page](https://github.com/facebookresearch/detectron2/releases/tag/v0.6). 
37 | 38 | 
39 | 3. Put the COCO dataset folder at `downstream_d2/datasets/coco`. 
40 | The folder should follow the [directory structure](https://github.com/facebookresearch/detectron2/tree/master/datasets) required by `Detectron2`, which should look like this: 
41 | ``` 
42 | downstream_d2/datasets/coco: 
43 | annotations/: 
44 | captions_train2017.json captions_val2017.json 
45 | instances_train2017.json instances_val2017.json 
46 | person_keypoints_train2017.json person_keypoints_val2017.json 
47 | train2017/: 
48 | a_lot_images.jpg 
49 | val2017/: 
50 | a_lot_images.jpg 
51 | ``` 
52 | 53 | 
54 | ## Training from a pre-trained checkpoint 
55 | 
56 | The script file for COCO fine-tuning (object detection and instance segmentation) is [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py), 
57 | which is a modification of [Detectron2's tools/train_net.py](https://github.com/facebookresearch/detectron2/blob/v0.6/tools/train_net.py). 
58 | 59 | 
60 | Before fine-tuning a ResNet50 pre-trained by SparK, you should first convert our checkpoint file to a Detectron2-style `.pkl` file: 
61 | 
62 | ```shell script 
63 | $ cd /path/to/SparK/downstream_d2 
64 | $ python3 convert-timm-to-d2.py /some/path/to/resnet50_1kpretrained_timm_style.pth d2-style.pkl 
65 | ``` 
66 | 
67 | For a ResNet50, you should see a log reporting `len(state)==318`: 
68 | ```text 
69 | [convert] .pkl is generated! (from `/some/path/to/resnet50_1kpretrained_timm_style.pth`, to `d2-style.pkl`, len(state)==318) 
70 | ``` 
71 | 
72 | Then run fine-tuning on a single machine with 8 GPUs: 
73 | 
74 | ```shell script 
75 | $ cd /path/to/SparK/downstream_d2 
76 | $ python3 ./train_net.py --resume --num-gpus 8 --config-file ./configs/coco_R_50_FPN_CONV_1x_moco_adam.yaml \ 
77 | MODEL.WEIGHTS d2-style.pkl \ 
78 | OUTPUT_DIR <your_output_dir> 
79 | ``` 
80 | 
81 | For multiple machines, add these args: 
82 | ```shell script 
83 | --num-machines <num_machines> --machine-rank <this_machine_rank> --dist-url <url> 
84 | ``` 
85 | 
86 | In `<your_output_dir>` you'll see the log files generated by `Detectron2`. 
87 | 88 | 
89 | ## Details: how we modify the official Detectron2's [tools/train_net.py](https://github.com/facebookresearch/detectron2/blob/v0.6/tools/train_net.py) to get our [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py) 
90 | 
91 | 1. We add two new hyperparameters: 
92 | - str `SOLVER.OPTIMIZER`: use the 'ADAM' (the same as 'ADAMW') or 'SGD' optimizer 
93 | - float `SOLVER.LR_DECAY`: the decay ratio (from 0. to 1.) of the layer-wise learning rate decay trick 
94 | 
95 | 2. We implement layer-wise lr decay in [downstream_d2/lr_decay.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/lr_decay.py). 
96 | 97 | 3. 
We write a script to convert our timm-style pre-trained ResNet weights to Detectron2-style in [downstream_d2/convert-timm-to-d2.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/convert-timm-to-d2.py). 98 | 99 | 4. We also add a hook for logging results to `cfg.OUTPUT_DIR/d2_coco_log.txt`. 100 | 101 | All of our modifications to the original are commented with `# [modification] ...` in [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py) or other files. 102 | -------------------------------------------------------------------------------- /downstream_d2/configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | VERSION: 2 43 | -------------------------------------------------------------------------------- /downstream_d2/configs/coco_R_50_FPN_CONV_1x_moco_adam.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "" 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | 7 | MASK_ON: True 8 | BACKBONE: 9 | FREEZE_AT: 0 10 | RESNETS: 11 | DEPTH: 50 12 | NORM: "SyncBN" 13 | STRIDE_IN_1X1: False 14 | FPN: 15 | NORM: "SyncBN" 16 | ROI_BOX_HEAD: 17 | NAME: "FastRCNNConvFCHead" 18 | NUM_FC: 1 19 | NUM_CONV: 4 20 | POOLER_RESOLUTION: 7 21 | NORM: "SyncBN" 22 | ROI_MASK_HEAD: 23 | NAME: "MaskRCNNConvUpsampleHead" 24 | NUM_CONV: 4 25 | POOLER_RESOLUTION: 14 26 | NORM: "SyncBN" 27 | 28 | INPUT: 29 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896) 30 | CROP: 31 | ENABLED: False 32 | TYPE: "absolute_range" 33 | SIZE: (384, 600) 34 | FORMAT: "RGB" 35 | TEST: 36 | EVAL_PERIOD: 5000 37 | PRECISE_BN: 38 | ENABLED: True 39 | 40 | SOLVER: 41 | STEPS: (60000, 80000) 42 | MAX_ITER: 90000 43 | GAMMA: 0.25 44 | BASE_LR: 0.00025 45 | WARMUP_FACTOR: 0.01 46 | WARMUP_ITERS: 1000 47 | WEIGHT_DECAY: 0.0001 48 | CHECKPOINT_PERIOD: 5000 49 | CLIP_GRADIENTS: 50 | ENABLED: False 51 | CLIP_TYPE: "value" 52 | CLIP_VALUE: 1.0 53 | NORM_TYPE: 2.0 54 | 55 | # compared to standard detectron2, we 
add these two new configurations: 56 | OPTIMIZER: "ADAMW" 57 | LR_DECAY: 0.6 58 | -------------------------------------------------------------------------------- /downstream_d2/convert-timm-to-d2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | # Copyright (c) ByteDance, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import pickle as pkl 10 | 11 | import torch 12 | 13 | 14 | # we use `timm.models.ResNet` in pre-training, so keys are timm-style 15 | def timm_resnet_to_detectron2_resnet(source_file, target_file): 16 | pretrained: dict = torch.load(source_file, map_location='cpu') 17 | for mod_k in {'state_dict', 'state', 'module', 'model'}: 18 | if mod_k in pretrained: 19 | pretrained = pretrained[mod_k] 20 | if any(k.startswith('module.encoder_q.') for k in pretrained.keys()): 21 | pretrained = {k.replace('module.encoder_q.', ''): v for k, v in pretrained.items() if k.startswith('module.encoder_q.')} 22 | 23 | pkl_state = {} 24 | for k, v in pretrained.items(): # convert resnet's keys from timm-style to d2-style 25 | if 'layer' not in k: 26 | k = 'stem.' + k 27 | for t in [1, 2, 3, 4]: 28 | k = k.replace(f'layer{t}', f'res{t+1}') 29 | for t in [1, 2, 3]: 30 | k = k.replace(f'bn{t}', f'conv{t}.norm') 31 | k = k.replace('downsample.0', 'shortcut') 32 | k = k.replace('downsample.1', 'shortcut.norm') 33 | 34 | pkl_state[k] = v.detach().numpy() 35 | 36 | with open(target_file, 'wb') as fp: 37 | print(f'[convert] .pkl is generated! (from `{source_file}`, to `{target_file}`, len(state)=={len(pkl_state)})') 38 | pkl.dump({'model': pkl_state, '__author__': 'https://github.com/keyu-tian/SparK', 'matching_heuristics': True}, fp) 39 | 40 | 41 | if __name__ == '__main__': 42 | import sys 43 | timm_resnet_to_detectron2_resnet(sys.argv[1], sys.argv[2]) 44 | -------------------------------------------------------------------------------- /downstream_d2/lr_decay.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Set, Optional, Callable, Any 2 | import torch 3 | import copy 4 | 5 | from detectron2.solver.build import reduce_param_groups 6 | 7 | 8 | def lr_factor_func(para_name: str, is_resnet50, dec: float, debug=False) -> float: 9 | if dec == 0: 10 | dec = 1. 11 | 12 | N = 5 if is_resnet50 else 11 13 | if '.stem.' 
in para_name: 14 | layer_id = 0 15 | elif '.res' in para_name: 16 | ls = para_name.split('.res')[1].split('.') 17 | if ls[0].isnumeric() and ls[1].isnumeric(): 18 | stage_id, block_id = int(ls[0]), int(ls[1]) 19 | if stage_id == 2: # res2 20 | layer_id = 1 21 | elif stage_id == 3: # res3 22 | layer_id = 2 23 | elif stage_id == 4: # res4 24 | layer_id = 3 + block_id // 3 # 3, 4 or 4, 5 25 | else: # res5 26 | layer_id = N 27 | else: 28 | assert para_name.startswith('roi_heads.res5.norm.') 29 | layer_id = N + 1 # roi_heads.res5.norm.weight and roi_heads.res5.norm.bias of C4 30 | else: 31 | layer_id = N + 1 32 | 33 | exp = N + 1 - layer_id 34 | return f'{dec:g} ** {exp}' if debug else dec ** exp 35 | 36 | 37 | # [modification] see: https://github.com/facebookresearch/detectron2/blob/v0.6/detectron2/solver/build.py#L134 38 | # add the `lr_factor_func` to implement lr decay 39 | def get_default_optimizer_params( 40 | model: torch.nn.Module, 41 | base_lr: Optional[float] = None, 42 | weight_decay: Optional[float] = None, 43 | weight_decay_norm: Optional[float] = None, 44 | bias_lr_factor: Optional[float] = 1.0, 45 | weight_decay_bias: Optional[float] = None, 46 | lr_factor_func: Optional[Callable] = None, 47 | overrides: Optional[Dict[str, Dict[str, float]]] = None, 48 | ) -> List[Dict[str, Any]]: 49 | """ 50 | Get default param list for optimizer, with support for a few types of 51 | overrides. If no overrides needed, this is equivalent to `model.parameters()`. 52 | 53 | Args: 54 | base_lr: lr for every group by default. Can be omitted to use the one in optimizer. 55 | weight_decay: weight decay for every group by default. Can be omitted to use the one 56 | in optimizer. 57 | weight_decay_norm: override weight decay for params in normalization layers 58 | bias_lr_factor: multiplier of lr for bias parameters. 59 | weight_decay_bias: override weight decay for bias parameters. 60 | lr_factor_func: function to calculate lr decay rate by mapping the parameter names to 61 | corresponding lr decay rate. Note that setting this option requires 62 | also setting ``base_lr``. 63 | overrides: if not `None`, provides values for optimizer hyperparameters 64 | (LR, weight decay) for module parameters with a given name; e.g. 65 | ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and 66 | weight decay values for all module parameters named `embedding`. 67 | 68 | For common detection models, ``weight_decay_norm`` is the only option 69 | needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings 70 | from Detectron1 that are not found useful. 71 | 72 | Example: 73 | :: 74 | torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0), 75 | lr=0.01, weight_decay=1e-4, momentum=0.9) 76 | """ 77 | if overrides is None: 78 | overrides = {} 79 | defaults = {} 80 | if base_lr is not None: 81 | defaults["lr"] = base_lr 82 | if weight_decay is not None: 83 | defaults["weight_decay"] = weight_decay 84 | bias_overrides = {} 85 | if bias_lr_factor is not None and bias_lr_factor != 1.0: 86 | # NOTE: unlike Detectron v1, we now by default make bias hyperparameters 87 | # exactly the same as regular weights. 
88 | if base_lr is None: 89 | raise ValueError("bias_lr_factor requires base_lr") 90 | bias_overrides["lr"] = base_lr * bias_lr_factor 91 | if weight_decay_bias is not None: 92 | bias_overrides["weight_decay"] = weight_decay_bias 93 | if len(bias_overrides): 94 | if "bias" in overrides: 95 | raise ValueError("Conflicting overrides for 'bias'") 96 | overrides["bias"] = bias_overrides 97 | if lr_factor_func is not None: 98 | if base_lr is None: 99 | raise ValueError("lr_factor_func requires base_lr") 100 | norm_module_types = ( 101 | torch.nn.BatchNorm1d, 102 | torch.nn.BatchNorm2d, 103 | torch.nn.BatchNorm3d, 104 | torch.nn.SyncBatchNorm, 105 | # NaiveSyncBatchNorm inherits from BatchNorm2d 106 | torch.nn.GroupNorm, 107 | torch.nn.InstanceNorm1d, 108 | torch.nn.InstanceNorm2d, 109 | torch.nn.InstanceNorm3d, 110 | torch.nn.LayerNorm, 111 | torch.nn.LocalResponseNorm, 112 | ) 113 | params: List[Dict[str, Any]] = [] 114 | memo: Set[torch.nn.parameter.Parameter] = set() 115 | for module_name, module in model.named_modules(): 116 | for module_param_name, value in module.named_parameters(recurse=False): 117 | if not value.requires_grad: 118 | continue 119 | # Avoid duplicating parameters 120 | if value in memo: 121 | continue 122 | memo.add(value) 123 | 124 | hyperparams = copy.copy(defaults) 125 | if isinstance(module, norm_module_types) and weight_decay_norm is not None: 126 | hyperparams["weight_decay"] = weight_decay_norm 127 | if lr_factor_func is not None: 128 | hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}") 129 | 130 | hyperparams.update(overrides.get(module_param_name, {})) 131 | params.append({"params": [value], **hyperparams}) 132 | return reduce_param_groups(params) 133 | -------------------------------------------------------------------------------- /downstream_imagenet/README.md: -------------------------------------------------------------------------------- 1 | ## About code isolation 2 | 3 | This `downstream_imagenet` is isolated from pre-training codes. One can treat this `downstream_imagenet` as an independent codebase 🛠️. 4 | 5 | 6 | ## Preparation for ImageNet-1k fine-tuning 7 | 8 | See [INSTALL.md](https://github.com/keyu-tian/SparK/blob/main/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset. 9 | 10 | **Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).** 11 | 12 | 13 | ## Fine-tuning on ImageNet-1k from pre-trained weights 14 | 15 | Run [/downstream_imagenet/main.py](/downstream_imagenet/main.py) via `torchrun`. 16 | **It is required to specify** the ImageNet data folder (`--data_path`), your experiment name & log dir (`--exp_name` and `--exp_dir`, automatically created if not exists), the model name (`--model`, valid choices see the keys of 'HP_DEFAULT_VALUES' in [/downstream_imagenet/arg.py line14](/downstream_imagenet/arg.py#L14)), and the pretrained weight file `--resume_from` to run fine-tuning. 17 | 18 | All the other configurations have their default values, listed in [/downstream_imagenet/arg.py#L13](/downstream_imagenet/arg.py#L13). 19 | You can overwrite any defaults by `--bs=1024` or something like that. 
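For example, to override a few defaults while fine-tuning (illustrative values and paths only; the full command format is shown just below):
```shell script
$ torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=<port> main.py \
  --data_path=/path/to/imagenet --exp_name=<exp_name> --exp_dir=/path/to/logdir \
  --model=resnet50 --resume_from=/path/to/resnet50_1kpretrained.pth \
  --bs=1024 --ep=200 --drop_path=0.1
```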
20 | 21 | 22 | Here is an example of fine-tuning a ConvNeXt-Small on a single machine with 8 GPUs: 
23 | ```shell script 
24 | $ cd /path/to/SparK/downstream_imagenet 
25 | $ torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=<port> main.py \ 
26 | --data_path=/path/to/imagenet --exp_name=<exp_name> --exp_dir=/path/to/logdir \ 
27 | --model=convnext_small --resume_from=/some/path/to/convnextS_1kpretrained_official_style.pth 
28 | ``` 
29 | 
30 | For multiple machines, change `--nnodes` and `--master_addr` to match your configuration. E.g.: 
31 | ```shell script 
32 | $ torchrun --nproc_per_node=8 --nnodes=<num_nodes> --node_rank=<node_rank> --master_addr=<master_addr> --master_port=<port> main.py \ 
33 | ... 
34 | ``` 
35 | 36 | 
37 | ## Logging 
38 | 
39 | See files under `--exp_dir` to track your experiment: 
40 | 
41 | - `<model>_1kfinetuned_last.pth`: the latest model weights 
42 | - `<model>_1kfinetuned_best.pth`: model weights with the highest acc 
43 | - `<model>_1kfinetuned_best_ema.pth`: EMA weights with the highest acc 
44 | - `finetune_log.txt`: records some important information such as: 
45 | - `git_commit_id`: git version 
46 | - `cmd`: all arguments passed to the script 
47 | 
48 | It also reports training loss/acc, best evaluation acc, and remaining time at each epoch. 
49 | 
50 | - `tensorboard_log/`: tensorboard logs; you can visualize accuracies, loss values, learning rates, gradient norms, and more via `tensorboard --logdir /path/to/this/tensorboard_log/ --port 23333`. 
51 | 
52 | ## Resuming 
53 | 
54 | Use `--resume_from` again, e.g., `--resume_from=path/to/<model>_1kfinetuned_last.pth`. 
55 | -------------------------------------------------------------------------------- /downstream_imagenet/arg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
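# Overview: this module defines the fine-tuning command line via `Tap` (typed-argument-parser).
# Each tuple in `HP_DEFAULT_VALUES` below follows the order of `HP_DEFAULT_NAMES`
# (bs, ep, wp_ep, opt, base_lr, lr_scale, wd, mixup, rep_aug, drop_path, ema);
# `get_args()` fills any argument left unspecified with these defaults and sets the peak
# learning rate as lr = base_lr * global_batch_size / 256.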
6 | 7 | import json 8 | import os 9 | import sys 10 | 11 | from tap import Tap 12 | 13 | HP_DEFAULT_NAMES = ['bs', 'ep', 'wp_ep', 'opt', 'base_lr', 'lr_scale', 'wd', 'mixup', 'rep_aug', 'drop_path', 'ema'] 14 | HP_DEFAULT_VALUES = { 15 | 'convnext_small': (4096, 400, 20, 'adam', 0.0002, 0.7, 0.01, 0.8, 3, 0.3, 0.9999), 16 | 'convnext_base': (4096, 400, 20, 'adam', 0.0001, 0.7, 0.01, 0.8, 3, 0.4, 0.9999), 17 | 'convnext_large': (4096, 200, 10, 'adam', 0.0001, 0.7, 0.02, 0.8, 3, 0.5, 0.9999), 18 | 'convnext_large_384': (1024, 200, 20, 'adam', 0.00006, 0.7, 0.01, 0.8, 3, 0.5, 0.99995), 19 | 20 | 'resnet50': (2048, 300, 5, 'lamb', 0.002, 0.7, 0.02, 0.1, 0, 0.05, 0.9999), 21 | 'resnet101': (2048, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999), 22 | 'resnet152': (2048, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999), 23 | 'resnet200': (2048, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999), 24 | } 25 | 26 | 27 | class FineTuneArgs(Tap): 28 | # environment 29 | exp_name: str 30 | exp_dir: str 31 | data_path: str 32 | model: str 33 | resume_from: str = '' # resume from some checkpoint.pth 34 | 35 | img_size: int = 224 36 | dataloader_workers: int = 8 37 | 38 | # ImageNet classification fine-tuning hyperparameters; see `HP_DEFAULT_VALUES` above for detailed default values 39 | # - batch size, epoch 40 | bs: int = 0 # global batch size (== batch_size_per_gpu * num_gpus) 41 | ep: int = 0 # number of epochs 42 | wp_ep: int = 0 # epochs for warmup 43 | 44 | # - optimization 45 | opt: str = '' # optimizer; 'adam' or 'lamb' 46 | base_lr: float = 0. # lr == base_lr * (bs) 47 | lr_scale: float = 0. # see file `lr_decay.py` for more details 48 | clip: int = -1 # use gradient clipping if clip > 0 49 | 50 | # - regularization tricks 51 | wd: float = 0. # weight decay 52 | mixup: float = 0. # use mixup if mixup > 0 53 | rep_aug: int = 0 # use repeated augmentation if rep_aug > 0 54 | drop_path: float = 0. # drop_path ratio 55 | 56 | # - other tricks 57 | ema: float = 0. # use EMA if ema > 0 58 | sbn: bool = True # use SyncBatchNorm 59 | 60 | # NO NEED TO SPECIFIED; each of these args would be updated in runtime automatically 61 | lr: float = None 62 | batch_size_per_gpu: int = 0 63 | glb_batch_size: int = 0 64 | device: str = 'cpu' 65 | world_size: int = 1 66 | global_rank: int = 0 67 | local_rank: int = 0 # we DO USE this arg 68 | is_master: bool = False 69 | is_local_master: bool = False 70 | cmd: str = ' '.join(sys.argv[1:]) 71 | commit_id: str = os.popen(f'git rev-parse HEAD').read().strip() 72 | commit_msg: str = os.popen(f'git log -1').read().strip().splitlines()[-1].strip() 73 | log_txt_name: str = '{args.exp_dir}/pretrain_log.txt' 74 | tb_lg_dir: str = '' # tensorboard log directory 75 | 76 | train_loss: float = 0. 77 | train_acc: float = 0. 78 | best_val_acc: float = 0. 
79 | cur_ep: str = '' 80 | remain_time: str = '' 81 | finish_time: str = '' 82 | first_logging: bool = True 83 | 84 | def log_epoch(self): 85 | if not self.is_local_master: 86 | return 87 | 88 | if self.first_logging: 89 | self.first_logging = False 90 | with open(self.log_txt_name, 'w') as fp: 91 | json.dump({ 92 | 'name': self.exp_name, 'cmd': self.cmd, 'git_commit_id': self.commit_id, 'git_commit_msg': self.commit_msg, 93 | 'model': self.model, 94 | }, fp) 95 | fp.write('\n\n') 96 | 97 | with open(self.log_txt_name, 'a') as fp: 98 | json.dump({ 99 | 'cur_ep': self.cur_ep, 100 | 'train_L': self.train_loss, 'train_acc': self.train_acc, 101 | 'best_val_acc': self.best_val_acc, 102 | 'rema': self.remain_time, 'fini': self.finish_time, 103 | }, fp) 104 | fp.write('\n') 105 | 106 | 107 | def get_args(world_size, global_rank, local_rank, device) -> FineTuneArgs: 108 | # parse args and prepare directories 109 | args = FineTuneArgs(explicit_bool=True).parse_args() 110 | d_name, b_name = os.path.dirname(os.path.abspath(args.exp_dir)), os.path.basename(os.path.abspath(args.exp_dir)) 111 | b_name = ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in b_name) 112 | args.exp_dir = os.path.join(d_name, b_name) 113 | os.makedirs(args.exp_dir, exist_ok=True) 114 | args.log_txt_name = os.path.join(args.exp_dir, 'finetune_log.txt') 115 | 116 | args.tb_lg_dir = args.tb_lg_dir or os.path.join(args.exp_dir, 'tensorboard_log') 117 | try: os.makedirs(args.tb_lg_dir, exist_ok=True) 118 | except: pass 119 | 120 | # fill in args.bs, args.ep, etc. with their default values (if their values are not explicitly specified, i.e., if bool(they) == False) 121 | if args.model == 'convnext_large' and args.img_size == 384: 122 | default_values = HP_DEFAULT_VALUES['convnext_large_384'] 123 | else: 124 | default_values = HP_DEFAULT_VALUES[args.model] 125 | for k, v in zip(HP_DEFAULT_NAMES, default_values): 126 | if bool(getattr(args, k)) == False: 127 | setattr(args, k, v) 128 | 129 | # update other runtime args 130 | args.world_size, args.global_rank, args.local_rank, args.device = world_size, global_rank, local_rank, device 131 | args.is_master = global_rank == 0 132 | args.is_local_master = local_rank == 0 133 | args.batch_size_per_gpu = args.bs // world_size 134 | args.glb_batch_size = args.batch_size_per_gpu * world_size 135 | args.lr = args.base_lr * args.glb_batch_size / 256 136 | 137 | return args 138 | -------------------------------------------------------------------------------- /downstream_imagenet/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
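# Overview: builds the ImageNet-1k train/val pipelines for fine-tuning.
# `create_classification_dataset` returns (train_loader, iters_train, val_iterator, iters_val);
# training uses timm's `create_transform` (with TrivialAugmentWide swapped in for inputs < 384px)
# and, optionally, repeated augmentation via `RepeatAugSampler`, while validation batches are
# drawn through the `DistInfiniteBatchSampler` defined at the bottom of this file.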
6 | 7 | import os 8 | import random 9 | import time 10 | 11 | import PIL.Image as PImage 12 | import numpy as np 13 | import torch 14 | import torchvision 15 | from timm.data import AutoAugment as TimmAutoAugment 16 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform 17 | from timm.data.distributed_sampler import RepeatAugSampler 18 | from timm.data.transforms_factory import transforms_imagenet_eval 19 | from torch.utils.data import DataLoader 20 | from torch.utils.data.sampler import Sampler 21 | from torchvision.transforms import AutoAugment as TorchAutoAugment 22 | from torchvision.transforms import transforms, TrivialAugmentWide 23 | 24 | try: 25 | from torchvision.transforms import InterpolationMode 26 | interpolation = InterpolationMode.BICUBIC 27 | except: 28 | import PIL 29 | interpolation = PIL.Image.BICUBIC 30 | 31 | 32 | def create_classification_dataset(data_path, img_size, rep_aug, workers, batch_size_per_gpu, world_size, global_rank): 33 | import warnings 34 | warnings.filterwarnings('ignore', category=UserWarning) 35 | 36 | mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 37 | trans_train = create_transform( 38 | is_training=True, input_size=img_size, 39 | auto_augment='v0', interpolation='bicubic', re_prob=0.25, re_mode='pixel', re_count=1, 40 | mean=mean, std=std, 41 | ) 42 | if img_size < 384: 43 | for i, t in enumerate(trans_train.transforms): 44 | if isinstance(t, (TorchAutoAugment, TimmAutoAugment)): 45 | trans_train.transforms[i] = TrivialAugmentWide(interpolation=interpolation) 46 | break 47 | trans_val = transforms_imagenet_eval(img_size=img_size, interpolation='bicubic', crop_pct=0.95, mean=mean, std=std) 48 | else: 49 | trans_val = transforms.Compose([ 50 | transforms.Resize((img_size, img_size), interpolation=interpolation), 51 | transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), 52 | ]) 53 | print_transform(trans_train, '[train]') 54 | print_transform(trans_val, '[val]') 55 | 56 | imagenet_folder = os.path.abspath(data_path) 57 | for postfix in ('train', 'val'): 58 | if imagenet_folder.endswith(postfix): 59 | imagenet_folder = imagenet_folder[:-len(postfix)] 60 | dataset_train = torchvision.datasets.ImageFolder(os.path.join(imagenet_folder, 'train'), trans_train) 61 | dataset_val = torchvision.datasets.ImageFolder(os.path.join(imagenet_folder, 'val'), trans_val) 62 | 63 | if rep_aug: 64 | print(f'[dataset] using repeated augmentation: count={rep_aug}') 65 | train_sp = RepeatAugSampler(dataset_train, shuffle=True, num_repeats=rep_aug) 66 | else: 67 | train_sp = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True, drop_last=True) 68 | 69 | loader_train = DataLoader( 70 | dataset=dataset_train, num_workers=workers, pin_memory=True, 71 | batch_size=batch_size_per_gpu, sampler=train_sp, persistent_workers=workers > 0, 72 | worker_init_fn=worker_init_fn, 73 | ) 74 | iters_train = len(loader_train) 75 | print(f'[dataset: train] bs={world_size}x{batch_size_per_gpu}={world_size * batch_size_per_gpu}, num_iters={iters_train}') 76 | 77 | val_ratio = 2 78 | loader_val = DataLoader( 79 | dataset=dataset_val, num_workers=workers, pin_memory=True, 80 | batch_sampler=DistInfiniteBatchSampler(world_size, global_rank, len(dataset_val), glb_batch_size=val_ratio * batch_size_per_gpu, filling=False, shuffle=False), 81 | worker_init_fn=worker_init_fn, 82 | ) 83 | iters_val = len(loader_val) 84 | print(f'[dataset: val] bs={world_size}x{val_ratio * batch_size_per_gpu}={val_ratio * world_size * 
batch_size_per_gpu}, num_iters={iters_val}') 85 | 86 | time.sleep(3) 87 | warnings.resetwarnings() 88 | return loader_train, iters_train, iter(loader_val), iters_val 89 | 90 | 91 | def worker_init_fn(worker_id): 92 | # see: https://pytorch.org/docs/stable/notes/randomness.html#dataloader 93 | worker_seed = torch.initial_seed() % 2 ** 32 94 | np.random.seed(worker_seed) 95 | random.seed(worker_seed) 96 | 97 | 98 | def print_transform(transform, s): 99 | print(f'Transform {s} = ') 100 | for t in transform.transforms: 101 | print(t) 102 | print('---------------------------\n') 103 | 104 | 105 | class DistInfiniteBatchSampler(Sampler): 106 | def __init__(self, world_size, global_rank, dataset_len, glb_batch_size, seed=0, filling=False, shuffle=True): 107 | assert glb_batch_size % world_size == 0 108 | self.world_size, self.rank = world_size, global_rank 109 | self.dataset_len = dataset_len 110 | self.glb_batch_size = glb_batch_size 111 | self.batch_size = glb_batch_size // world_size 112 | 113 | self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size 114 | self.filling = filling 115 | self.shuffle = shuffle 116 | self.epoch = 0 117 | self.seed = seed 118 | self.indices = self.gener_indices() 119 | 120 | def gener_indices(self): 121 | global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0 122 | if self.shuffle: 123 | g = torch.Generator() 124 | g.manual_seed(self.epoch + self.seed) 125 | global_indices = torch.randperm(self.dataset_len, generator=g) 126 | else: 127 | global_indices = torch.arange(self.dataset_len) 128 | filling = global_max_p - global_indices.shape[0] 129 | if filling > 0 and self.filling: 130 | global_indices = torch.cat((global_indices, global_indices[:filling])) 131 | global_indices = tuple(global_indices.numpy().tolist()) 132 | 133 | seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int) 134 | local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]] 135 | self.max_p = len(local_indices) 136 | return local_indices 137 | 138 | def __iter__(self): 139 | self.epoch = 0 140 | while True: 141 | self.epoch += 1 142 | p, q = 0, 0 143 | while p < self.max_p: 144 | q = p + self.batch_size 145 | yield self.indices[p:q] 146 | p = q 147 | if self.shuffle: 148 | self.indices = self.gener_indices() 149 | 150 | def __len__(self): 151 | return self.iters_per_ep 152 | -------------------------------------------------------------------------------- /downstream_imagenet/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
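# Overview: `lr_wd_annealing` applies a linear warmup over the first `wp_it` iterations and then
# a half-cosine decay of the learning rate (the base weight decay stays constant across iterations
# and is only scaled per parameter group); `get_param_groups` splits parameters into decay / no-decay
# groups and, if the model provides `get_layer_id_and_scale_exp`, assigns each group a layer-wise
# learning-rate scale of the form lr_scale ** scale_exp.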
6 | 7 | import math 8 | from pprint import pformat 9 | 10 | 11 | def lr_wd_annealing(optimizer, peak_lr, wd, cur_it, wp_it, max_it): 12 | wp_it = round(wp_it) 13 | if cur_it < wp_it: 14 | cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it 15 | else: 16 | ratio = (cur_it - wp_it) / (max_it - 1 - wp_it) 17 | cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * ratio)) 18 | 19 | min_lr, max_lr = cur_lr, cur_lr 20 | min_wd, max_wd = wd, wd 21 | for param_group in optimizer.param_groups: 22 | scaled_lr = param_group['lr'] = cur_lr * param_group.get('lr_scale', 1) # 'lr_scale' could be assigned 23 | min_lr, max_lr = min(min_lr, scaled_lr), max(max_lr, scaled_lr) 24 | scaled_wd = param_group['weight_decay'] = wd * param_group.get('weight_decay_scale', 1) # 'weight_decay_scale' could be assigned 25 | min_wd, max_wd = min(min_wd, scaled_wd), max(max_wd, scaled_wd) 26 | return min_lr, max_lr, min_wd, max_wd 27 | 28 | 29 | def get_param_groups(model, nowd_keys=(), lr_scale=0.0): 30 | using_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0.0 < lr_scale < 1.0 31 | print(f'[get_ft_param_groups][lr decay] using_lr_scale={using_lr_scale}, ft_lr_scale={lr_scale}') 32 | para_groups, para_groups_dbg = {}, {} 33 | 34 | for name, para in model.named_parameters(): 35 | if not para.requires_grad: 36 | continue # frozen weights 37 | if len(para.shape) == 1 or name.endswith('.bias') or any(k in name for k in nowd_keys): 38 | wd_scale, group_name = 0., 'no_decay' 39 | else: 40 | wd_scale, group_name = 1., 'decay' 41 | 42 | if using_lr_scale: 43 | layer_id, scale_exp = model.get_layer_id_and_scale_exp(name) 44 | group_name = f'layer{layer_id}_' + group_name 45 | this_lr_scale = lr_scale ** scale_exp 46 | dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]' 47 | else: 48 | this_lr_scale = 1 49 | dbg = f'[no scale]' 50 | 51 | if group_name not in para_groups: 52 | para_groups[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': this_lr_scale} 53 | para_groups_dbg[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': dbg} 54 | para_groups[group_name]['params'].append(para) 55 | para_groups_dbg[group_name]['params'].append(name) 56 | 57 | for g in para_groups_dbg.values(): 58 | g['params'] = pformat(', '.join(g['params']), width=200) 59 | 60 | print(f'[get_ft_param_groups] param groups = \n{pformat(para_groups_dbg, indent=2, width=250)}\n') 61 | return list(para_groups.values()) 62 | -------------------------------------------------------------------------------- /downstream_imagenet/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
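# Overview: entry point for ImageNet-1k fine-tuning. `main_ft` sets up the distributed environment,
# builds the model, EMA and optimizer, optionally resumes from `--resume_from`, then alternates
# `fine_tune_one_epoch` (per-iteration lr/wd annealing, optional mixup and gradient clipping)
# with distributed `evaluate`, saving the best / best-EMA / last checkpoints and tensorboard
# logs under `--exp_dir`.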
6 | 7 | import datetime 8 | import time 9 | 10 | import torch 11 | import torch.distributed as tdist 12 | from timm.utils import ModelEmaV2 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | from arg import get_args, FineTuneArgs 16 | from models import ConvNeXt, ResNet 17 | __for_timm_registration = ConvNeXt, ResNet 18 | from lr_decay import lr_wd_annealing 19 | from util import init_distributed_environ, create_model_opt, load_checkpoint, save_checkpoint 20 | from data import create_classification_dataset 21 | 22 | 23 | def main_ft(): 24 | world_size, global_rank, local_rank, device = init_distributed_environ() 25 | args: FineTuneArgs = get_args(world_size, global_rank, local_rank, device) 26 | print(f'initial args:\n{str(args)}') 27 | args.log_epoch() 28 | 29 | criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer = create_model_opt(args) 30 | ep_start, performance_desc = load_checkpoint(args.resume_from, model_without_ddp, model_ema, optimizer) 31 | 32 | if ep_start >= args.ep: # load from a complete checkpoint file 33 | print(f' [*] [FT already done] Max/Last Acc: {performance_desc}') 34 | else: 35 | tb_lg = SummaryWriter(args.tb_lg_dir) if args.is_master else None 36 | loader_train, iters_train, iterator_val, iters_val = create_classification_dataset( 37 | args.data_path, args.img_size, args.rep_aug, 38 | args.dataloader_workers, args.batch_size_per_gpu, args.world_size, args.global_rank 39 | ) 40 | 41 | # train & eval 42 | tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model) 43 | max_acc = last_acc 44 | max_acc_e = last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module)[-1] 45 | print(f'[fine-tune] initial acc={last_acc:.2f}, ema={last_acc_e:.2f}') 46 | 47 | ep_eval = set(range(0, args.ep//3, 5)) | set(range(args.ep//3, args.ep)) 48 | print(f'[FT start] ep_eval={sorted(ep_eval)} ') 49 | print(f'[FT start] from ep{ep_start}') 50 | 51 | params_req_grad = [p for p in model.parameters() if p.requires_grad] 52 | ft_start_time = time.time() 53 | for ep in range(ep_start, args.ep): 54 | ep_start_time = time.time() 55 | if hasattr(loader_train, 'sampler') and hasattr(loader_train.sampler, 'set_epoch'): 56 | loader_train.sampler.set_epoch(ep) 57 | if 0 <= ep <= 3: 58 | print(f'[loader_train.sampler.set_epoch({ep})]') 59 | 60 | train_loss, train_acc = fine_tune_one_epoch(ep, args, tb_lg, loader_train, iters_train, criterion, mixup_fn, model, model_ema, optimizer, params_req_grad) 61 | if ep in ep_eval: 62 | eval_start_time = time.time() 63 | tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model) 64 | tot_pred_e, last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module) 65 | eval_cost = round(time.time() - eval_start_time, 2) 66 | performance_desc = f'Max (Last) Acc: {max(max_acc, last_acc):.2f} ({last_acc:.2f} o {tot_pred}) EMA: {max(max_acc_e, last_acc_e):.2f} ({last_acc_e:.2f} o {tot_pred_e})' 67 | states = model_without_ddp.state_dict(), model_ema.module.state_dict(), optimizer.state_dict() 68 | if last_acc > max_acc: 69 | max_acc = last_acc 70 | save_checkpoint(f'{args.model}_1kfinetuned_best.pth', args, ep, performance_desc, *states) 71 | if last_acc_e > max_acc_e: 72 | max_acc_e = last_acc_e 73 | save_checkpoint(f'{args.model}_1kfinetuned_best_ema.pth', args, ep, performance_desc, *states) 74 | save_checkpoint(f'{args.model}_1kfinetuned_last.pth', args, ep, performance_desc, *states) 75 | else: 76 | eval_cost = '-' 77 | 78 | ep_cost = round(time.time() - ep_start_time, 2) + 1 # 
+1s: approximate the following logging cost 79 | remain_secs = (args.ep-1 - ep) * ep_cost 80 | remain_time = datetime.timedelta(seconds=round(remain_secs)) 81 | finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs)) 82 | print(f'[ep{ep}/{args.ep}] {performance_desc} Ep cost: {ep_cost}s, Ev cost: {eval_cost}, Remain: {remain_time}, Finish @ {finish_time}') 83 | args.cur_ep = f'{ep + 1}/{args.ep}' 84 | args.remain_time, args.finish_time = str(remain_time), str(finish_time) 85 | args.train_loss, args.train_acc, args.best_val_acc = train_loss, train_acc, max(max_acc, max_acc_e) 86 | args.log_epoch() 87 | 88 | if args.is_master: 89 | tb_lg.add_scalar(f'ft_train/ep_loss', train_loss, ep) 90 | tb_lg.add_scalar(f'ft_eval/max_acc', max_acc, ep) 91 | tb_lg.add_scalar(f'ft_eval/last_acc', last_acc, ep) 92 | tb_lg.add_scalar(f'ft_eval/max_acc_ema', max_acc_e, ep) 93 | tb_lg.add_scalar(f'ft_eval/last_acc_ema', last_acc_e, ep) 94 | tb_lg.add_scalar(f'ft_z_burnout/rest_hours', round(remain_secs/60/60, 2), ep) 95 | tb_lg.flush() 96 | 97 | # finish fine-tuning 98 | result_acc = max(max_acc, max_acc_e) 99 | if args.is_master: 100 | tb_lg.add_scalar('ft_result/result_acc', result_acc, ep_start) 101 | tb_lg.add_scalar('ft_result/result_acc', result_acc, args.ep) 102 | tb_lg.flush() 103 | tb_lg.close() 104 | print(f'final args:\n{str(args)}') 105 | print('\n\n') 106 | print(f' [*] [FT finished] {performance_desc} Total Cost: {(time.time() - ft_start_time) / 60 / 60:.1f}h\n') 107 | print(f' [*] [FT finished] max(max_acc, max_acc_e)={result_acc} EMA better={max_acc_e>max_acc}') 108 | print('\n\n') 109 | time.sleep(10) 110 | 111 | args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time())) 112 | args.log_epoch() 113 | 114 | 115 | def fine_tune_one_epoch(ep, args: FineTuneArgs, tb_lg: SummaryWriter, loader_train, iters_train, criterion, mixup_fn, model, model_ema: ModelEmaV2, optimizer, params_req_grad): 116 | model.train() 117 | tot_loss = tot_acc = 0.0 118 | log_freq = max(1, round(iters_train * 0.7)) 119 | ep_start_time = time.time() 120 | for it, (inp, tar) in enumerate(loader_train): 121 | # adjust lr and wd 122 | cur_it = it + ep * iters_train 123 | min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, cur_it, args.wp_ep * iters_train, args.ep * iters_train) 124 | 125 | # forward 126 | inp = inp.to(args.device, non_blocking=True) 127 | raw_tar = tar = tar.to(args.device, non_blocking=True) 128 | if mixup_fn is not None: 129 | inp, tar, raw_tar = mixup_fn(inp, tar) 130 | oup = model(inp) 131 | pred = oup.data.argmax(dim=1) 132 | if mixup_fn is None: 133 | acc = pred.eq(tar).float().mean().item() * 100 134 | tot_acc += acc 135 | else: 136 | acc = (pred.eq(raw_tar) | pred.eq(raw_tar.flip(0))).float().mean().item() * 100 137 | tot_acc += acc 138 | 139 | # backward 140 | optimizer.zero_grad() 141 | loss = criterion(oup, tar) 142 | loss.backward() 143 | loss = loss.item() 144 | tot_loss += loss 145 | if args.clip > 0: 146 | orig_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item() 147 | else: 148 | orig_norm = None 149 | optimizer.step() 150 | model_ema.update(model) 151 | torch.cuda.synchronize() 152 | 153 | # log 154 | if args.is_master and cur_it % log_freq == 0: 155 | tb_lg.add_scalar(f'ft_train/it_loss', loss, cur_it) 156 | tb_lg.add_scalar(f'ft_train/it_acc', acc, cur_it) 157 | tb_lg.add_scalar(f'ft_hp/min_lr', min_lr, cur_it), tb_lg.add_scalar(f'ft_hp/max_lr', max_lr, cur_it) 158 | 
tb_lg.add_scalar(f'ft_hp/min_wd', min_wd, cur_it), tb_lg.add_scalar(f'ft_hp/max_wd', max_wd, cur_it) 159 | if orig_norm is not None: 160 | tb_lg.add_scalar(f'ft_hp/orig_norm', orig_norm, cur_it) 161 | 162 | if it in [3, iters_train//2, iters_train-1]: 163 | remain_secs = (iters_train-1 - it) * (time.time() - ep_start_time) / (it + 1) 164 | remain_time = datetime.timedelta(seconds=round(remain_secs)) 165 | print(f'[ep{ep} it{it:3d}/{iters_train}] L: {loss:.4f} Acc: {acc:.2f} lr: {min_lr:.1e}~{max_lr:.1e} Remain: {remain_time}') 166 | 167 | return tot_loss / iters_train, tot_acc / iters_train 168 | 169 | 170 | @torch.no_grad() 171 | def evaluate(dev, iterator_val, iters_val, model): 172 | training = model.training 173 | model.train(False) 174 | tot_pred, tot_correct = 0., 0. 175 | for _ in range(iters_val): 176 | inp, tar = next(iterator_val) 177 | tot_pred += tar.shape[0] 178 | inp = inp.to(dev, non_blocking=True) 179 | tar = tar.to(dev, non_blocking=True) 180 | oup = model(inp) 181 | tot_correct += oup.argmax(dim=1).eq(tar).sum().item() 182 | model.train(training) 183 | t = torch.tensor([tot_pred, tot_correct]).to(dev) 184 | tdist.all_reduce(t) 185 | return t[0].item(), (t[1] / t[0]).item() * 100. 186 | 187 | 188 | if __name__ == '__main__': 189 | main_ft() 190 | -------------------------------------------------------------------------------- /downstream_imagenet/mixup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # This file is a modified version of timm.data.Mixup 8 | # Fixed error of "Batch size should be even when using this" 9 | 10 | """ Mixup and Cutmix 11 | 12 | Papers: 13 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 14 | 15 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) 16 | 17 | Code Reference: 18 | CutMix: https://github.com/clovaai/CutMix-PyTorch 19 | 20 | Hacked together by / Copyright 2019, Ross Wightman 21 | """ 22 | import numpy as np 23 | import torch 24 | 25 | 26 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 27 | x = x.long().view(-1, 1) 28 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 29 | 30 | 31 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): 32 | off_value = smoothing / num_classes 33 | on_value = 1. - smoothing + off_value 34 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 35 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 36 | return y1 * lam + y2 * (1. - lam) 37 | 38 | 39 | def rand_bbox(img_shape, lam, margin=0., count=None): 40 | """ Standard CutMix bounding-box 41 | Generates a random square bbox based on lambda value. This impl includes 42 | support for enforcing a border margin as percent of bbox dimensions. 
43 | 44 | Args: 45 | img_shape (tuple): Image shape as tuple 46 | lam (float): Cutmix lambda value 47 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 48 | count (int): Number of bbox to generate 49 | """ 50 | ratio = np.sqrt(1 - lam) 51 | img_h, img_w = img_shape[-2:] 52 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 53 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 54 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 55 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 56 | yl = np.clip(cy - cut_h // 2, 0, img_h) 57 | yh = np.clip(cy + cut_h // 2, 0, img_h) 58 | xl = np.clip(cx - cut_w // 2, 0, img_w) 59 | xh = np.clip(cx + cut_w // 2, 0, img_w) 60 | return yl, yh, xl, xh 61 | 62 | 63 | def rand_bbox_minmax(img_shape, minmax, count=None): 64 | """ Min-Max CutMix bounding-box 65 | Inspired by Darknet cutmix impl, generates a random rectangular bbox 66 | based on min/max percent values applied to each dimension of the input image. 67 | 68 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. 69 | 70 | Args: 71 | img_shape (tuple): Image shape as tuple 72 | minmax (tuple or list): Min and max bbox ratios (as percent of image size) 73 | count (int): Number of bbox to generate 74 | """ 75 | assert len(minmax) == 2 76 | img_h, img_w = img_shape[-2:] 77 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) 78 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) 79 | yl = np.random.randint(0, img_h - cut_h, size=count) 80 | xl = np.random.randint(0, img_w - cut_w, size=count) 81 | yu = yl + cut_h 82 | xu = xl + cut_w 83 | return yl, yu, xl, xu 84 | 85 | 86 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): 87 | """ Generate bbox and apply lambda correction. 88 | """ 89 | if ratio_minmax is not None: 90 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) 91 | else: 92 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 93 | if correct_lam or ratio_minmax is not None: 94 | bbox_area = (yu - yl) * (xu - xl) 95 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) 96 | return (yl, yu, xl, xu), lam 97 | 98 | 99 | class BatchMixup: 100 | """ Mixup/Cutmix that applies different params to each element or whole batch 101 | 102 | Args: 103 | mixup_alpha (float): mixup alpha value, mixup is active if > 0. 104 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. 105 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 
106 | prob (float): probability of applying mixup or cutmix per batch or element 107 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active 108 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element) 109 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders 110 | label_smoothing (float): apply label smoothing to the mixed target tensor 111 | num_classes (int): number of classes for target 112 | """ 113 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, 114 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): 115 | assert mode == 'batch' 116 | self.mixup_alpha = mixup_alpha 117 | self.cutmix_alpha = cutmix_alpha 118 | self.cutmix_minmax = cutmix_minmax 119 | if self.cutmix_minmax is not None: 120 | assert len(self.cutmix_minmax) == 2 121 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe 122 | self.cutmix_alpha = 1.0 123 | self.mix_prob = prob 124 | self.switch_prob = switch_prob 125 | self.label_smoothing = label_smoothing 126 | self.num_classes = num_classes 127 | self.mode = mode 128 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix 129 | self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop) 130 | 131 | def _params_per_batch(self): 132 | lam = 1. 133 | use_cutmix = False 134 | if self.mixup_enabled and np.random.rand() < self.mix_prob: 135 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: 136 | use_cutmix = np.random.rand() < self.switch_prob 137 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ 138 | np.random.beta(self.mixup_alpha, self.mixup_alpha) 139 | elif self.mixup_alpha > 0.: 140 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) 141 | elif self.cutmix_alpha > 0.: 142 | use_cutmix = True 143 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 144 | else: 145 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 146 | lam = float(lam_mix) 147 | return lam, use_cutmix 148 | 149 | def _mix_batch(self, x): 150 | lam, use_cutmix = self._params_per_batch() 151 | if lam == 1.: 152 | return 1. 153 | if use_cutmix: 154 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 155 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 156 | x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh] 157 | else: 158 | x_flipped = x.flip(0).mul_(1. - lam) 159 | x.mul_(lam).add_(x_flipped) 160 | return lam 161 | 162 | def __call__(self, x, raw_target): 163 | if x.shape[0] % 2 == 1: 164 | x, raw_target = torch.cat((x[:1], x), dim=0), torch.cat((raw_target[:1], raw_target), dim=0) 165 | # assert len(x) % 2 == 0, 'Batch size should be even when using this' 166 | lam = self._mix_batch(x) 167 | target = mixup_target(raw_target, self.num_classes, lam, self.label_smoothing, x.device) 168 | return x, target, raw_target 169 | -------------------------------------------------------------------------------- /downstream_imagenet/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
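# Overview: network definitions come directly from `timm.models.ResNet` and the official ConvNeXt
# implementation in `convnext_official.py`. This module attaches a `get_layer_id_and_scale_exp`
# method to both classes (used by `lr_decay.get_param_groups` for layer-wise lr decay) and
# patches `extra_repr`/`__repr__` of a few timm / torch classes for more informative logging.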
6 | 7 | import math 8 | 9 | import torch 10 | from timm.data import Mixup 11 | from timm.loss import BinaryCrossEntropy, SoftTargetCrossEntropy 12 | from timm.models.layers import drop 13 | from timm.models.resnet import ResNet 14 | 15 | from .convnext_official import ConvNeXt 16 | 17 | 18 | def convnext_get_layer_id_and_scale_exp(self: ConvNeXt, para_name: str): 19 | N = 12 if len(self.stages[-2]) > 9 else 6 20 | if para_name.startswith("downsample_layers"): 21 | stage_id = int(para_name.split('.')[1]) 22 | if stage_id == 0: 23 | layer_id = 0 24 | elif stage_id == 1 or stage_id == 2: 25 | layer_id = stage_id + 1 26 | else: # stage_id == 3: 27 | layer_id = N 28 | elif para_name.startswith("stages"): 29 | stage_id = int(para_name.split('.')[1]) 30 | block_id = int(para_name.split('.')[2]) 31 | if stage_id == 0 or stage_id == 1: 32 | layer_id = stage_id + 1 33 | elif stage_id == 2: 34 | layer_id = 3 + block_id // 3 35 | else: # stage_id == 3: 36 | layer_id = N 37 | else: 38 | layer_id = N + 1 # after backbone 39 | 40 | return layer_id, N + 1 - layer_id 41 | 42 | 43 | def resnets_get_layer_id_and_scale_exp(self: ResNet, para_name: str): 44 | # stages: 45 | # 50 : [3, 4, 6, 3] 46 | # 101 : [3, 4, 23, 3] 47 | # 152 : [3, 8, 36, 3] 48 | # 200 : [3, 24, 36, 3] 49 | # eca269d: [3, 30, 48, 8] 50 | 51 | L2, L3 = len(self.layer2), len(self.layer3) 52 | if L2 == 4 and L3 == 6: 53 | blk2, blk3 = 2, 3 54 | elif L2 == 4 and L3 == 23: 55 | blk2, blk3 = 2, 3 56 | elif L2 == 8 and L3 == 36: 57 | blk2, blk3 = 4, 4 58 | elif L2 == 24 and L3 == 36: 59 | blk2, blk3 = 4, 4 60 | elif L2 == 30 and L3 == 48: 61 | blk2, blk3 = 5, 6 62 | else: 63 | raise NotImplementedError 64 | 65 | N2, N3 = math.ceil(L2 / blk2 - 1e-5), math.ceil(L3 / blk3 - 1e-5) 66 | N = 2 + N2 + N3 67 | if para_name.startswith('layer'): # 1, 2, 3, 4, 5 68 | stage_id, block_id = int(para_name.split('.')[0][5:]), int(para_name.split('.')[1]) 69 | if stage_id == 1: 70 | layer_id = 1 71 | elif stage_id == 2: 72 | layer_id = 2 + block_id // blk2 # 2, 3 73 | elif stage_id == 3: 74 | layer_id = 2 + N2 + block_id // blk3 # r50: 4, 5 r101: 4, 5, ..., 11 75 | else: # == 4 76 | layer_id = N # r50: 6 r101: 12 77 | elif para_name.startswith('fc.'): 78 | layer_id = N + 1 # r50: 7 r101: 13 79 | else: 80 | layer_id = 0 81 | 82 | return layer_id, N + 1 - layer_id # r50: 0-7, 7-0 r101: 0-13, 13-0 83 | 84 | 85 | def _ex_repr(self): 86 | return ', '.join( 87 | f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v)) 88 | for k, v in vars(self).items() 89 | if not k.startswith('_') and k != 'training' 90 | and not isinstance(v, (torch.nn.Module, torch.Tensor)) 91 | ) 92 | 93 | 94 | # IMPORTANT: update some member functions 95 | __UPDATED = False 96 | if not __UPDATED: 97 | for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, BinaryCrossEntropy, Mixup, drop.DropPath): 98 | if hasattr(clz, 'extra_repr'): 99 | clz.extra_repr = _ex_repr 100 | else: 101 | clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})' 102 | ResNet.get_layer_id_and_scale_exp = resnets_get_layer_id_and_scale_exp 103 | ConvNeXt.get_layer_id_and_scale_exp = convnext_get_layer_id_and_scale_exp 104 | __UPDATED = True 105 | -------------------------------------------------------------------------------- /downstream_imagenet/models/convnext_official.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # This file is exactly the same as: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.models.layers import trunc_normal_, DropPath 13 | from timm.models.registry import register_model 14 | 15 | class Block(nn.Module): 16 | r""" ConvNeXt Block. There are two equivalent implementations: 17 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 18 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 19 | We use (2) as we find it slightly faster in PyTorch 20 | 21 | Args: 22 | dim (int): Number of input channels. 23 | drop_path (float): Stochastic depth rate. Default: 0.0 24 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 25 | """ 26 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 27 | super().__init__() 28 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 29 | self.norm = LayerNorm(dim, eps=1e-6) 30 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 31 | self.act = nn.GELU() 32 | self.pwconv2 = nn.Linear(4 * dim, dim) 33 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 34 | requires_grad=True) if layer_scale_init_value > 0 else None 35 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 36 | 37 | def forward(self, x): 38 | input = x 39 | x = self.dwconv(x) 40 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 41 | x = self.norm(x) 42 | x = self.pwconv1(x) 43 | x = self.act(x) 44 | x = self.pwconv2(x) 45 | if self.gamma is not None: 46 | x = self.gamma * x 47 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 48 | 49 | x = input + self.drop_path(x) 50 | return x 51 | 52 | class ConvNeXt(nn.Module): 53 | r""" ConvNeXt 54 | A PyTorch impl of : `A ConvNet for the 2020s` - 55 | https://arxiv.org/pdf/2201.03545.pdf 56 | Args: 57 | in_chans (int): Number of input image channels. Default: 3 58 | num_classes (int): Number of classes for classification head. Default: 1000 59 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 60 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 61 | drop_path_rate (float): Stochastic depth rate. Default: 0. 62 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 63 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
64 | """ 65 | def __init__(self, in_chans=3, num_classes=1000, 66 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 67 | layer_scale_init_value=1e-6, head_init_scale=1., 68 | ): 69 | super().__init__() 70 | 71 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 72 | stem = nn.Sequential( 73 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 74 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 75 | ) 76 | self.downsample_layers.append(stem) 77 | for i in range(3): 78 | downsample_layer = nn.Sequential( 79 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 80 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 81 | ) 82 | self.downsample_layers.append(downsample_layer) 83 | 84 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 85 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 86 | cur = 0 87 | for i in range(4): 88 | stage = nn.Sequential( 89 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 90 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 91 | ) 92 | self.stages.append(stage) 93 | cur += depths[i] 94 | 95 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 96 | self.head = nn.Linear(dims[-1], num_classes) 97 | 98 | self.apply(self._init_weights) 99 | self.head.weight.data.mul_(head_init_scale) 100 | self.head.bias.data.mul_(head_init_scale) 101 | 102 | def _init_weights(self, m): 103 | if isinstance(m, (nn.Conv2d, nn.Linear)): 104 | trunc_normal_(m.weight, std=.02) 105 | nn.init.constant_(m.bias, 0) 106 | 107 | def forward_features(self, x): 108 | for i in range(4): 109 | x = self.downsample_layers[i](x) 110 | x = self.stages[i](x) 111 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 112 | 113 | def forward(self, x): 114 | x = self.forward_features(x) 115 | x = self.head(x) 116 | return x 117 | 118 | class LayerNorm(nn.Module): 119 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 120 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 121 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 122 | with shape (batch_size, channels, height, width). 
123 | """ 124 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 125 | super().__init__() 126 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 127 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 128 | self.eps = eps 129 | self.data_format = data_format 130 | if self.data_format not in ["channels_last", "channels_first"]: 131 | raise NotImplementedError 132 | self.normalized_shape = (normalized_shape, ) 133 | 134 | def forward(self, x): 135 | if self.data_format == "channels_last": 136 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 137 | elif self.data_format == "channels_first": 138 | u = x.mean(1, keepdim=True) 139 | s = (x - u).pow(2).mean(1, keepdim=True) 140 | x = (x - u) / torch.sqrt(s + self.eps) 141 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 142 | return x 143 | 144 | 145 | model_urls = { 146 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 147 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 148 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 149 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 150 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 151 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 152 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 153 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 154 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 155 | } 156 | 157 | @register_model 158 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs): 159 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 160 | if pretrained: 161 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] 162 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 163 | model.load_state_dict(checkpoint["model"]) 164 | return model 165 | 166 | @register_model 167 | def convnext_small(pretrained=False,in_22k=False, **kwargs): 168 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 169 | if pretrained: 170 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] 171 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 172 | model.load_state_dict(checkpoint["model"]) 173 | return model 174 | 175 | @register_model 176 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 177 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 178 | if pretrained: 179 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 180 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 181 | model.load_state_dict(checkpoint["model"]) 182 | return model 183 | 184 | @register_model 185 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 186 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 187 | if pretrained: 188 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 189 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 190 | 
model.load_state_dict(checkpoint["model"]) 191 | return model 192 | 193 | @register_model 194 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 195 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 196 | if pretrained: 197 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" 198 | url = model_urls['convnext_xlarge_22k'] 199 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 200 | model.load_state_dict(checkpoint["model"]) 201 | return model 202 | -------------------------------------------------------------------------------- /downstream_imagenet/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | Pillow 4 | typed-argument-parser 5 | timm==0.5.4 6 | tensorboardx 7 | -------------------------------------------------------------------------------- /downstream_imagenet/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import datetime 8 | import os 9 | import sys 10 | from functools import partial 11 | from typing import List, Tuple, Callable 12 | 13 | import pytz 14 | import torch 15 | import torch.distributed as tdist 16 | import torch.multiprocessing as tmp 17 | from timm import create_model 18 | from timm.loss import SoftTargetCrossEntropy, BinaryCrossEntropy 19 | from timm.optim import AdamW, Lamb 20 | from timm.utils import ModelEmaV2 21 | from torch.nn.parallel import DistributedDataParallel 22 | from torch.optim.optimizer import Optimizer 23 | 24 | from arg import FineTuneArgs 25 | from downstream_imagenet.mixup import BatchMixup 26 | from lr_decay import get_param_groups 27 | 28 | 29 | def time_str(for_dirname=False): 30 | return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('%m-%d_%H-%M-%S' if for_dirname else '[%m-%d %H:%M:%S]') 31 | 32 | 33 | def init_distributed_environ(): 34 | # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29 35 | if tmp.get_start_method(allow_none=True) is None: 36 | tmp.set_start_method('spawn') 37 | global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count() 38 | local_rank = global_rank % num_gpus 39 | torch.cuda.set_device(local_rank) 40 | 41 | tdist.init_process_group(backend='nccl') 42 | assert tdist.is_initialized(), 'torch.distributed is not initialized!' 
43 | torch.backends.cudnn.benchmark = True 44 | torch.backends.cudnn.deterministic = False 45 | 46 | # print only when local_rank == 0 or print(..., force=True) 47 | import builtins as __builtin__ 48 | builtin_print = __builtin__.print 49 | 50 | def prt(msg, *args, **kwargs): 51 | force = kwargs.pop('force', False) 52 | if local_rank == 0 or force: 53 | f_back = sys._getframe().f_back 54 | file_desc = f'{f_back.f_code.co_filename:24s}'[-24:] 55 | builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}', *args, **kwargs) 56 | 57 | __builtin__.print = prt 58 | tdist.barrier() 59 | return tdist.get_world_size(), global_rank, local_rank, torch.empty(1).cuda().device 60 | 61 | 62 | def create_model_opt(args: FineTuneArgs) -> Tuple[torch.nn.Module, Callable, torch.nn.Module, DistributedDataParallel, ModelEmaV2, Optimizer]: 63 | num_classes = 1000 64 | model_without_ddp: torch.nn.Module = create_model(args.model, num_classes=num_classes, drop_path_rate=args.drop_path).to(args.device) 65 | model_para = f'{sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad) / 1e6:.1f}M' 66 | # create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 67 | model_ema = ModelEmaV2(model_without_ddp, decay=args.ema, device=args.device) 68 | if args.sbn: 69 | model_without_ddp = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_without_ddp) 70 | print(f'[model={args.model}] [#para={model_para}, drop_path={args.drop_path}, ema={args.ema}] {model_without_ddp}\n') 71 | model = DistributedDataParallel(model_without_ddp, device_ids=[args.local_rank], find_unused_parameters=False, broadcast_buffers=False) 72 | model.train() 73 | opt_cls = { 74 | 'adam': AdamW, 'adamw': AdamW, 75 | 'lamb': partial(Lamb, max_grad_norm=1e7, always_adapt=True, bias_correction=False), 76 | } 77 | param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'}, lr_scale=args.lr_scale) 78 | # param_groups[0] is like this: {'params': List[nn.Parameters], 'lr': float, 'lr_scale': float, 'weight_decay': float, 'weight_decay_scale': float} 79 | optimizer = opt_cls[args.opt](param_groups, lr=args.lr, weight_decay=0) 80 | print(f'[optimizer={type(optimizer)}]') 81 | mixup_fn = BatchMixup( 82 | mixup_alpha=args.mixup, cutmix_alpha=1.0, cutmix_minmax=None, 83 | prob=1.0, switch_prob=0.5, mode='batch', 84 | label_smoothing=0.1, num_classes=num_classes 85 | ) 86 | mixup_fn.mixup_enabled = args.mixup > 0.0 87 | if 'lamb' in args.opt: 88 | # label smoothing is solved in AdaptiveMixup with `label_smoothing`, so here smoothing=0 89 | criterion = BinaryCrossEntropy(smoothing=0, target_threshold=None) 90 | else: 91 | criterion = SoftTargetCrossEntropy() 92 | print(f'[loss_fn] {criterion}') 93 | print(f'[mixup_fn] {mixup_fn}') 94 | return criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer 95 | 96 | 97 | def load_checkpoint(resume_from, model_without_ddp, ema_module, optimizer): 98 | if len(resume_from) == 0 or not os.path.exists(resume_from): 99 | raise AttributeError(f'ckpt `{resume_from}` not found!') 100 | # return 0, '[no performance_desc]' 101 | print(f'[try to resume from file `{resume_from}`]') 102 | checkpoint = torch.load(resume_from, map_location='cpu') 103 | assert checkpoint.get('is_pretrain', False) == False, 'Please do not use `*_withdecoder_1kpretrained_spark_style.pth`, which is ONLY for resuming the pretraining. Use `*_1kpretrained_timm_style.pth` or `*_1kfinetuned*.pth` instead.' 
104 | 105 | ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc', '[no performance_desc]') 106 | missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False) 107 | print(f'[load_checkpoint] missing_keys={missing}') 108 | print(f'[load_checkpoint] unexpected_keys={unexpected}') 109 | print(f'[load_checkpoint] ep_start={ep_start}, performance_desc={performance_desc}') 110 | 111 | if 'optimizer' in checkpoint: 112 | optimizer.load_state_dict(checkpoint['optimizer']) 113 | if 'ema' in checkpoint: 114 | ema_module.load_state_dict(checkpoint['ema']) 115 | return ep_start, performance_desc 116 | 117 | 118 | def save_checkpoint(save_to, args, epoch, performance_desc, model_without_ddp_state, ema_state, optimizer_state): 119 | checkpoint_path = os.path.join(args.exp_dir, save_to) 120 | if args.is_local_master: 121 | to_save = { 122 | 'args': str(args), 123 | 'arch': args.model, 124 | 'epoch': epoch, 125 | 'performance_desc': performance_desc, 126 | 'module': model_without_ddp_state, 127 | 'ema': ema_state, 128 | 'optimizer': optimizer_state, 129 | 'is_pretrain': False, 130 | } 131 | torch.save(to_save, checkpoint_path) 132 | -------------------------------------------------------------------------------- /downstream_mmdet/README.md: -------------------------------------------------------------------------------- 1 | ## About code isolation 2 | 3 | This `downstream_mmdet` is isolated from pre-training codes. One can treat this `downstream_mmdet` as an independent codebase 🛠️. 4 | 5 | ## Fine-tuned ConvNeXt-B weights, log files, and performance 6 | 7 | 8 |

9 | 10 | [[`weights (pre-trained by SparK)`](https://drive.google.com/file/d/1ZjWbqI1qoBcqeQijI5xX9E-YNkxpJcYV/view?usp=share_link)] 11 | [[`weights (fine-tuned on COCO)`](https://drive.google.com/file/d/1t10dmzg5KOO27o2yIglK-gQepB5gR4zR/view?usp=share_link)] 12 | [[`log.json`](https://drive.google.com/file/d/1TuNboXl1qwjf1tggZ3QOssI67uU7Jtig/view?usp=share_link)] 13 | [[`log`](https://drive.google.com/file/d/1JY5CkL_MX08zJ8P1FBIeC60OJsuIiyZc/view?usp=sharing)] 14 |
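If you only want to sanity-check the released fine-tuned weights outside of `tools/test.py`, a minimal sketch using MMDetection's high-level API could look like the one below. It is not a script shipped with this repo, and the checkpoint/image paths are placeholders for wherever you saved the downloads above:

```python
# Minimal sketch (not an official script of this repo): run the fine-tuned Mask R-CNN on one image.
# Run it from `downstream_mmdet/` so that this repo's ConvNeXt backbone gets registered on import.
from mmdet.apis import init_detector, inference_detector

config_file = 'configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py'
checkpoint_file = '/path/to/the_coco_finetuned_checkpoint.pth'  # placeholder: the "fine-tuned on COCO" download above
img = '/path/to/some_image.jpg'                                 # placeholder image path

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, img)                         # (bbox results, mask results) for Mask R-CNN
model.show_result(img, result, score_thr=0.5, out_file='vis.jpg')
```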
20 | 21 | 22 | ## Installation [MMDetection with commit 6a979e2](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d) before fine-tuning ConvNeXt on COCO 23 | 24 | We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d). 25 | Please refer to [README.md](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/6a979e2164e3fb0de0ca2546545013a4d71b2f7d/README.md) for installation and dataset preparation instructions. 26 | 27 | Note the COCO dataset folder should be at `downstream_mmdet/data/coco`. 28 | The folder should follow the directory structure requried by `MMDetection`, which should look like this: 29 | ``` 30 | downstream_mmdet/data/coco: 31 | annotations/: 32 | captions_train2017.json captions_val2017.json 33 | instances_train2017.json instances_val2017.json 34 | person_keypoints_train2017.json person_keypoints_val2017.json 35 | train2017/: 36 | a_lot_images.jpg 37 | val2017/: 38 | a_lot_images.jpg 39 | ``` 40 | 41 | 42 | ### Training 43 | 44 | To train a detector with pre-trained models, run: 45 | ``` 46 | # single-gpu training 47 | python tools/train.py --cfg-options model.pretrained= [other optional arguments] 48 | 49 | # multi-gpu training 50 | tools/dist_train.sh --cfg-options model.pretrained= [other optional arguments] 51 | ``` 52 | For example, to train a Mask R-CNN model with a SparK pretrained `ConvNeXt-B` backbone and 4 gpus, run: 53 | ``` 54 | tools/dist_train.sh configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py 4 \ 55 | --cfg-options model.pretrained=/some/path/to/official_convnext_base_1kpretrained.pth 56 | ``` 57 | 58 | The Mask R-CNN 3x fine-tuning config file can be found at [`configs/convnext_spark`](configs/convnext_spark). This config is basically a copy of [https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py](https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py). 
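The optimizer settings in that config rely on `mmcv_custom/layer_decay_optimizer_constructor.py` to give shallower backbone layers smaller learning rates. The sketch below only mirrors the `scale = decay_rate ** (num_layers - layer_id - 1)` formula from that file with the `decay_rate=0.65, num_layers=12, lr=0.0002` values of the config; it is illustrative, and the printed numbers are not taken from any training log:

```python
# Illustrative only: reproduces the lr-scale rule used by LearningRateDecayOptimizerConstructor.
decay_rate, cfg_num_layers, base_lr = 0.65, 12, 2e-4   # values from the convnext_spark config
num_layers = cfg_num_layers + 2                        # the constructor adds 2 (stem group + head group)

for layer_id in range(num_layers):                     # 0 = stem/embeddings, 13 = everything outside the backbone stages
    scale = decay_rate ** (num_layers - layer_id - 1)
    print(f'layer {layer_id:2d}: lr_scale={scale:.4f}, lr={scale * base_lr:.2e}')

# layer 0 ends up with ~0.0037 * base_lr while the last group keeps base_lr itself; this min/max pair
# is what CustomizedTextLoggerHook later logs as `layer_0_lr` and `lr`.
```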
59 | 60 | ### Inference 61 | ``` 62 | # single-gpu testing 63 | python tools/test.py --eval bbox segm 64 | 65 | # multi-gpu testing 66 | tools/dist_test.sh --eval bbox segm 67 | ``` 68 | 69 | ## Acknowledgment 70 | 71 | We appreciate these useful codebases: 72 | 73 | - [MMDetection](https://github.com/open-mmlab/mmdetection) 74 | - [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) 75 | - [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection) 76 | 77 | -------------------------------------------------------------------------------- /downstream_mmdet/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1) 2 | # yapf:disable 3 | log_config = dict( 4 | interval=50, 5 | hooks=[ 6 | dict(type='CustomizedTextLoggerHook'), 7 | # dict(type='TensorboardLoggerHook') 8 | ]) 9 | # yapf:enable 10 | custom_hooks = [dict(type='NumClassCheckHook')] 11 | 12 | dist_params = dict(backend='nccl') 13 | log_level = 'INFO' 14 | load_from = None 15 | resume_from = None 16 | workflow = [('train', 1)] 17 | -------------------------------------------------------------------------------- /downstream_mmdet/configs/_base_/models/cascade_mask_rcnn_convnext_fpn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | # model settings 10 | model = dict( 11 | type='CascadeRCNN', 12 | pretrained=None, 13 | backbone=dict( 14 | type='ConvNeXt', 15 | in_chans=3, 16 | depths=[3, 3, 9, 3], 17 | dims=[96, 192, 384, 768], 18 | drop_path_rate=0.2, 19 | layer_scale_init_value=1e-6, 20 | out_indices=[0, 1, 2, 3], 21 | ), 22 | neck=dict( 23 | type='FPN', 24 | in_channels=[128, 256, 512, 1024], 25 | out_channels=256, 26 | num_outs=5), 27 | rpn_head=dict( 28 | type='RPNHead', 29 | in_channels=256, 30 | feat_channels=256, 31 | anchor_generator=dict( 32 | type='AnchorGenerator', 33 | scales=[8], 34 | ratios=[0.5, 1.0, 2.0], 35 | strides=[4, 8, 16, 32, 64]), 36 | bbox_coder=dict( 37 | type='DeltaXYWHBBoxCoder', 38 | target_means=[.0, .0, .0, .0], 39 | target_stds=[1.0, 1.0, 1.0, 1.0]), 40 | loss_cls=dict( 41 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 42 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), 43 | roi_head=dict( 44 | type='CascadeRoIHead', 45 | num_stages=3, 46 | stage_loss_weights=[1, 0.5, 0.25], 47 | bbox_roi_extractor=dict( 48 | type='SingleRoIExtractor', 49 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 50 | out_channels=256, 51 | featmap_strides=[4, 8, 16, 32]), 52 | bbox_head=[ 53 | dict( 54 | type='Shared2FCBBoxHead', 55 | in_channels=256, 56 | fc_out_channels=1024, 57 | roi_feat_size=7, 58 | num_classes=80, 59 | bbox_coder=dict( 60 | type='DeltaXYWHBBoxCoder', 61 | target_means=[0., 0., 0., 0.], 62 | target_stds=[0.1, 0.1, 0.2, 0.2]), 63 | reg_class_agnostic=True, 64 | loss_cls=dict( 65 | type='CrossEntropyLoss', 66 | use_sigmoid=False, 67 | loss_weight=1.0), 68 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, 69 | loss_weight=1.0)), 70 | dict( 71 | type='Shared2FCBBoxHead', 72 | in_channels=256, 73 | fc_out_channels=1024, 74 | roi_feat_size=7, 75 | num_classes=80, 76 | bbox_coder=dict( 77 | type='DeltaXYWHBBoxCoder', 78 | target_means=[0., 0., 0., 
0.], 79 | target_stds=[0.05, 0.05, 0.1, 0.1]), 80 | reg_class_agnostic=True, 81 | loss_cls=dict( 82 | type='CrossEntropyLoss', 83 | use_sigmoid=False, 84 | loss_weight=1.0), 85 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, 86 | loss_weight=1.0)), 87 | dict( 88 | type='Shared2FCBBoxHead', 89 | in_channels=256, 90 | fc_out_channels=1024, 91 | roi_feat_size=7, 92 | num_classes=80, 93 | bbox_coder=dict( 94 | type='DeltaXYWHBBoxCoder', 95 | target_means=[0., 0., 0., 0.], 96 | target_stds=[0.033, 0.033, 0.067, 0.067]), 97 | reg_class_agnostic=True, 98 | loss_cls=dict( 99 | type='CrossEntropyLoss', 100 | use_sigmoid=False, 101 | loss_weight=1.0), 102 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) 103 | ], 104 | mask_roi_extractor=dict( 105 | type='SingleRoIExtractor', 106 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 107 | out_channels=256, 108 | featmap_strides=[4, 8, 16, 32]), 109 | mask_head=dict( 110 | type='FCNMaskHead', 111 | num_convs=4, 112 | in_channels=256, 113 | conv_out_channels=256, 114 | num_classes=80, 115 | loss_mask=dict( 116 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), 117 | # model training and testing settings 118 | train_cfg = dict( 119 | rpn=dict( 120 | assigner=dict( 121 | type='MaxIoUAssigner', 122 | pos_iou_thr=0.7, 123 | neg_iou_thr=0.3, 124 | min_pos_iou=0.3, 125 | match_low_quality=True, 126 | ignore_iof_thr=-1), 127 | sampler=dict( 128 | type='RandomSampler', 129 | num=256, 130 | pos_fraction=0.5, 131 | neg_pos_ub=-1, 132 | add_gt_as_proposals=False), 133 | allowed_border=0, 134 | pos_weight=-1, 135 | debug=False), 136 | rpn_proposal=dict( 137 | nms_across_levels=False, 138 | nms_pre=2000, 139 | nms_post=2000, 140 | max_per_img=2000, 141 | nms=dict(type='nms', iou_threshold=0.7), 142 | min_bbox_size=0), 143 | rcnn=[ 144 | dict( 145 | assigner=dict( 146 | type='MaxIoUAssigner', 147 | pos_iou_thr=0.5, 148 | neg_iou_thr=0.5, 149 | min_pos_iou=0.5, 150 | match_low_quality=False, 151 | ignore_iof_thr=-1), 152 | sampler=dict( 153 | type='RandomSampler', 154 | num=512, 155 | pos_fraction=0.25, 156 | neg_pos_ub=-1, 157 | add_gt_as_proposals=True), 158 | mask_size=28, 159 | pos_weight=-1, 160 | debug=False), 161 | dict( 162 | assigner=dict( 163 | type='MaxIoUAssigner', 164 | pos_iou_thr=0.6, 165 | neg_iou_thr=0.6, 166 | min_pos_iou=0.6, 167 | match_low_quality=False, 168 | ignore_iof_thr=-1), 169 | sampler=dict( 170 | type='RandomSampler', 171 | num=512, 172 | pos_fraction=0.25, 173 | neg_pos_ub=-1, 174 | add_gt_as_proposals=True), 175 | mask_size=28, 176 | pos_weight=-1, 177 | debug=False), 178 | dict( 179 | assigner=dict( 180 | type='MaxIoUAssigner', 181 | pos_iou_thr=0.7, 182 | neg_iou_thr=0.7, 183 | min_pos_iou=0.7, 184 | match_low_quality=False, 185 | ignore_iof_thr=-1), 186 | sampler=dict( 187 | type='RandomSampler', 188 | num=512, 189 | pos_fraction=0.25, 190 | neg_pos_ub=-1, 191 | add_gt_as_proposals=True), 192 | mask_size=28, 193 | pos_weight=-1, 194 | debug=False) 195 | ]), 196 | test_cfg = dict( 197 | rpn=dict( 198 | nms_across_levels=False, 199 | nms_pre=1000, 200 | nms_post=1000, 201 | max_per_img=1000, 202 | nms=dict(type='nms', iou_threshold=0.7), 203 | min_bbox_size=0), 204 | rcnn=dict( 205 | score_thr=0.05, 206 | nms=dict(type='nms', iou_threshold=0.5), 207 | max_per_img=100, 208 | mask_thr_binary=0.5))) 209 | -------------------------------------------------------------------------------- /downstream_mmdet/configs/_base_/models/mask_rcnn_convnext_fpn.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | # model settings 10 | model = dict( 11 | type='MaskRCNN', 12 | pretrained=None, 13 | backbone=dict( 14 | type='ConvNeXt', 15 | in_chans=3, 16 | depths=[3, 3, 9, 3], 17 | dims=[96, 192, 384, 768], 18 | drop_path_rate=0.2, 19 | layer_scale_init_value=1e-6, 20 | out_indices=[0, 1, 2, 3], 21 | ), 22 | neck=dict( 23 | type='FPN', 24 | in_channels=[128, 256, 512, 1024], 25 | out_channels=256, 26 | num_outs=5), 27 | rpn_head=dict( 28 | type='RPNHead', 29 | in_channels=256, 30 | feat_channels=256, 31 | anchor_generator=dict( 32 | type='AnchorGenerator', 33 | scales=[8], 34 | ratios=[0.5, 1.0, 2.0], 35 | strides=[4, 8, 16, 32, 64]), 36 | bbox_coder=dict( 37 | type='DeltaXYWHBBoxCoder', 38 | target_means=[.0, .0, .0, .0], 39 | target_stds=[1.0, 1.0, 1.0, 1.0]), 40 | loss_cls=dict( 41 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 42 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 43 | roi_head=dict( 44 | type='StandardRoIHead', 45 | bbox_roi_extractor=dict( 46 | type='SingleRoIExtractor', 47 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 48 | out_channels=256, 49 | featmap_strides=[4, 8, 16, 32]), 50 | bbox_head=dict( 51 | type='Shared2FCBBoxHead', 52 | in_channels=256, 53 | fc_out_channels=1024, 54 | roi_feat_size=7, 55 | num_classes=80, 56 | bbox_coder=dict( 57 | type='DeltaXYWHBBoxCoder', 58 | target_means=[0., 0., 0., 0.], 59 | target_stds=[0.1, 0.1, 0.2, 0.2]), 60 | reg_class_agnostic=False, 61 | loss_cls=dict( 62 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 63 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 64 | mask_roi_extractor=dict( 65 | type='SingleRoIExtractor', 66 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 67 | out_channels=256, 68 | featmap_strides=[4, 8, 16, 32]), 69 | mask_head=dict( 70 | type='FCNMaskHead', 71 | num_convs=4, 72 | in_channels=256, 73 | conv_out_channels=256, 74 | num_classes=80, 75 | loss_mask=dict( 76 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), 77 | # model training and testing settings 78 | train_cfg=dict( 79 | rpn=dict( 80 | assigner=dict( 81 | type='MaxIoUAssigner', 82 | pos_iou_thr=0.7, 83 | neg_iou_thr=0.3, 84 | min_pos_iou=0.3, 85 | match_low_quality=True, 86 | ignore_iof_thr=-1), 87 | sampler=dict( 88 | type='RandomSampler', 89 | num=256, 90 | pos_fraction=0.5, 91 | neg_pos_ub=-1, 92 | add_gt_as_proposals=False), 93 | allowed_border=-1, 94 | pos_weight=-1, 95 | debug=False), 96 | rpn_proposal=dict( 97 | nms_pre=2000, 98 | max_per_img=1000, 99 | nms=dict(type='nms', iou_threshold=0.7), 100 | min_bbox_size=0), 101 | rcnn=dict( 102 | assigner=dict( 103 | type='MaxIoUAssigner', 104 | pos_iou_thr=0.5, 105 | neg_iou_thr=0.5, 106 | min_pos_iou=0.5, 107 | match_low_quality=True, 108 | ignore_iof_thr=-1), 109 | sampler=dict( 110 | type='RandomSampler', 111 | num=512, 112 | pos_fraction=0.25, 113 | neg_pos_ub=-1, 114 | add_gt_as_proposals=True), 115 | mask_size=28, 116 | pos_weight=-1, 117 | debug=False)), 118 | test_cfg=dict( 119 | rpn=dict( 120 | nms_pre=1000, 121 | max_per_img=1000, 122 | nms=dict(type='nms', iou_threshold=0.7), 123 | min_bbox_size=0), 124 | rcnn=dict( 125 | score_thr=0.05, 126 | nms=dict(type='nms', iou_threshold=0.5), 127 | 
max_per_img=100, 128 | mask_thr_binary=0.5))) 129 | -------------------------------------------------------------------------------- /downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py: -------------------------------------------------------------------------------- 1 | """ 2 | We directly take the ConvNeXt-T+MaskRCNN 3x recipe from https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py 3 | And we modify this ConvNeXt-T+MaskRCNN 3x recipe to our ConvNeXt-B+MaskRCNN 3x recipe. 4 | The modifications (commented as [modified] below) are according to: 5 | - 1. tiny-to-base: (some configs of ConvNext-T are updated to those of ConvNext-B, referring to https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/cascade_mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco_in22k.py) 6 | - model.backbone.{depths, dims, drop_path_rate} 7 | - models.neck 8 | - optimizer.paramwise_cfg.num_layers 9 | 10 | - 2. our paper (https://openreview.net/forum?id=NRxydtWup1S, or https://arxiv.org/abs/2301.03580): 11 | - LR layer decay (optimizer.paramwise_cfg.decay_rate): 0.65 12 | - LR scheduled ratio (lr_config.gamma): 0.2 13 | - Learning rate (optimizer.lr): 0.0002 14 | - optimizer_config.use_fp16: False (we just use fp32 by default; actually we didn't test the performance of using fp16) 15 | """ 16 | 17 | _base_ = [ 18 | '../_base_/models/mask_rcnn_convnext_fpn.py', 19 | '../_base_/datasets/coco_instance.py', 20 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' 21 | ] 22 | 23 | model = dict( 24 | backbone=dict( 25 | in_chans=3, 26 | depths=[3, 3, 27, 3], # [modified] according to tiny-to-base 27 | dims=[128, 256, 512, 1024], # [modified] according to tiny-to-base 28 | drop_path_rate=0.5, # [modified] according to tiny-to-base 29 | layer_scale_init_value=1.0, 30 | out_indices=[0, 1, 2, 3], 31 | ), 32 | neck=dict(in_channels=[128, 256, 512, 1024])) # [modified] according to tiny-to-base 33 | 34 | img_norm_cfg = dict( 35 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 36 | 37 | # augmentation strategy originates from DETR / Sparse RCNN 38 | train_pipeline = [ 39 | dict(type='LoadImageFromFile'), 40 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 41 | dict(type='RandomFlip', flip_ratio=0.5), 42 | dict(type='AutoAugment', 43 | policies=[ 44 | [ 45 | dict(type='Resize', 46 | img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), 47 | (608, 1333), (640, 1333), (672, 1333), (704, 1333), 48 | (736, 1333), (768, 1333), (800, 1333)], 49 | multiscale_mode='value', 50 | keep_ratio=True) 51 | ], 52 | [ 53 | dict(type='Resize', 54 | img_scale=[(400, 1333), (500, 1333), (600, 1333)], 55 | multiscale_mode='value', 56 | keep_ratio=True), 57 | dict(type='RandomCrop', 58 | crop_type='absolute_range', 59 | crop_size=(384, 600), 60 | allow_negative_crop=True), 61 | dict(type='Resize', 62 | img_scale=[(480, 1333), (512, 1333), (544, 1333), 63 | (576, 1333), (608, 1333), (640, 1333), 64 | (672, 1333), (704, 1333), (736, 1333), 65 | (768, 1333), (800, 1333)], 66 | multiscale_mode='value', 67 | override=True, 68 | keep_ratio=True) 69 | ] 70 | ]), 71 | dict(type='Normalize', **img_norm_cfg), 72 | dict(type='Pad', size_divisor=32), 73 | dict(type='DefaultFormatBundle'), 74 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 
'gt_masks']), 75 | ] 76 | data = dict(train=dict(pipeline=train_pipeline)) 77 | 78 | optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW', 79 | lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05, # [modified] according to our paper 80 | paramwise_cfg={'decay_rate': 0.65, # [modified] according to our paper 81 | 'decay_type': 'layer_wise', 82 | 'num_layers': 12}) # [modified] according to tiny-to-base 83 | lr_config = dict(step=[27, 33], gamma=0.2) # [modified] according to our paper 84 | runner = dict(type='EpochBasedRunnerAmp', max_epochs=36) 85 | 86 | # do not use mmdet version fp16 87 | fp16 = None 88 | optimizer_config = dict( 89 | type="DistOptimizerHook", 90 | update_interval=1, 91 | grad_clip=None, 92 | coalesce=True, 93 | bucket_size_mb=-1, 94 | use_fp16=False, # [modified] True => False 95 | ) -------------------------------------------------------------------------------- /downstream_mmdet/mmcv_custom/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | # -*- coding: utf-8 -*- 10 | 11 | from .checkpoint import load_checkpoint 12 | from .layer_decay_optimizer_constructor import LearningRateDecayOptimizerConstructor 13 | from .customized_text import CustomizedTextLoggerHook 14 | 15 | __all__ = ['load_checkpoint', 'LearningRateDecayOptimizerConstructor', 'CustomizedTextLoggerHook'] 16 | -------------------------------------------------------------------------------- /downstream_mmdet/mmcv_custom/customized_text.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import datetime 10 | from collections import OrderedDict 11 | 12 | import torch 13 | 14 | import mmcv 15 | from mmcv.runner import HOOKS 16 | from mmcv.runner import TextLoggerHook 17 | 18 | 19 | @HOOKS.register_module() 20 | class CustomizedTextLoggerHook(TextLoggerHook): 21 | """Customized Text Logger hook. 22 | 23 | This logger prints out both lr and layer_0_lr. 
24 | 25 | """ 26 | 27 | def _log_info(self, log_dict, runner): 28 | # print exp name for users to distinguish experiments 29 | # at every ``interval_exp_name`` iterations and the end of each epoch 30 | if runner.meta is not None and 'exp_name' in runner.meta: 31 | if (self.every_n_iters(runner, self.interval_exp_name)) or ( 32 | self.by_epoch and self.end_of_epoch(runner)): 33 | exp_info = f'Exp name: {runner.meta["exp_name"]}' 34 | runner.logger.info(exp_info) 35 | 36 | if log_dict['mode'] == 'train': 37 | lr_str = {} 38 | for lr_type in ['lr', 'layer_0_lr']: 39 | if isinstance(log_dict[lr_type], dict): 40 | lr_str[lr_type] = [] 41 | for k, val in log_dict[lr_type].items(): 42 | lr_str.append(f'{lr_type}_{k}: {val:.3e}') 43 | lr_str[lr_type] = ' '.join(lr_str) 44 | else: 45 | lr_str[lr_type] = f'{lr_type}: {log_dict[lr_type]:.3e}' 46 | 47 | # by epoch: Epoch [4][100/1000] 48 | # by iter: Iter [100/100000] 49 | if self.by_epoch: 50 | log_str = f'Epoch [{log_dict["epoch"]}]' \ 51 | f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t' 52 | else: 53 | log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t' 54 | log_str += f'{lr_str["lr"]}, {lr_str["layer_0_lr"]}, ' 55 | 56 | if 'time' in log_dict.keys(): 57 | self.time_sec_tot += (log_dict['time'] * self.interval) 58 | time_sec_avg = self.time_sec_tot / ( 59 | runner.iter - self.start_iter + 1) 60 | eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) 61 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 62 | log_str += f'eta: {eta_str}, ' 63 | log_str += f'time: {log_dict["time"]:.3f}, ' \ 64 | f'data_time: {log_dict["data_time"]:.3f}, ' 65 | # statistic memory 66 | if torch.cuda.is_available(): 67 | log_str += f'memory: {log_dict["memory"]}, ' 68 | else: 69 | # val/test time 70 | # here 1000 is the length of the val dataloader 71 | # by epoch: Epoch[val] [4][1000] 72 | # by iter: Iter[val] [1000] 73 | if self.by_epoch: 74 | log_str = f'Epoch({log_dict["mode"]}) ' \ 75 | f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t' 76 | else: 77 | log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t' 78 | 79 | log_items = [] 80 | for name, val in log_dict.items(): 81 | # TODO: resolve this hack 82 | # these items have been in log_str 83 | if name in [ 84 | 'mode', 'Epoch', 'iter', 'lr', 'layer_0_lr', 'time', 'data_time', 85 | 'memory', 'epoch' 86 | ]: 87 | continue 88 | if isinstance(val, float): 89 | val = f'{val:.4f}' 90 | log_items.append(f'{name}: {val}') 91 | log_str += ', '.join(log_items) 92 | 93 | runner.logger.info(log_str) 94 | 95 | 96 | def log(self, runner): 97 | if 'eval_iter_num' in runner.log_buffer.output: 98 | # this doesn't modify runner.iter and is regardless of by_epoch 99 | cur_iter = runner.log_buffer.output.pop('eval_iter_num') 100 | else: 101 | cur_iter = self.get_iter(runner, inner_iter=True) 102 | 103 | log_dict = OrderedDict( 104 | mode=self.get_mode(runner), 105 | epoch=self.get_epoch(runner), 106 | iter=cur_iter) 107 | 108 | # record lr and layer_0_lr 109 | cur_lr = runner.current_lr() 110 | if isinstance(cur_lr, list): 111 | log_dict['layer_0_lr'] = min(cur_lr) 112 | log_dict['lr'] = max(cur_lr) 113 | else: 114 | assert isinstance(cur_lr, dict) 115 | log_dict['lr'], log_dict['layer_0_lr'] = {}, {} 116 | for k, lr_ in cur_lr.items(): 117 | assert isinstance(lr_, list) 118 | log_dict['layer_0_lr'].update({k: min(lr_)}) 119 | log_dict['lr'].update({k: max(lr_)}) 120 | 121 | if 'time' in runner.log_buffer.output: 122 | # statistic memory 123 | if torch.cuda.is_available(): 124 | log_dict['memory'] = 
self._get_max_memory(runner) 125 | 126 | log_dict = dict(log_dict, **runner.log_buffer.output) 127 | 128 | self._log_info(log_dict, runner) 129 | self._dump_log(log_dict, runner) 130 | return log_dict 131 | -------------------------------------------------------------------------------- /downstream_mmdet/mmcv_custom/layer_decay_optimizer_constructor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import json 10 | from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor 11 | from mmcv.runner import get_dist_info 12 | 13 | 14 | def get_num_layer_layer_wise(var_name, num_max_layer=12): 15 | 16 | if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"): 17 | return 0 18 | elif var_name.startswith("backbone.downsample_layers"): 19 | stage_id = int(var_name.split('.')[2]) 20 | if stage_id == 0: 21 | layer_id = 0 22 | elif stage_id == 1: 23 | layer_id = 2 24 | elif stage_id == 2: 25 | layer_id = 3 26 | elif stage_id == 3: 27 | layer_id = num_max_layer 28 | return layer_id 29 | elif var_name.startswith("backbone.stages"): 30 | stage_id = int(var_name.split('.')[2]) 31 | block_id = int(var_name.split('.')[3]) 32 | if stage_id == 0: 33 | layer_id = 1 34 | elif stage_id == 1: 35 | layer_id = 2 36 | elif stage_id == 2: 37 | layer_id = 3 + block_id // 3 38 | elif stage_id == 3: 39 | layer_id = num_max_layer 40 | return layer_id 41 | else: 42 | return num_max_layer + 1 43 | 44 | 45 | def get_num_layer_stage_wise(var_name, num_max_layer): 46 | if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"): 47 | return 0 48 | elif var_name.startswith("backbone.downsample_layers"): 49 | return 0 50 | elif var_name.startswith("backbone.stages"): 51 | stage_id = int(var_name.split('.')[2]) 52 | return stage_id + 1 53 | else: 54 | return num_max_layer - 1 55 | 56 | 57 | @OPTIMIZER_BUILDERS.register_module() 58 | class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor): 59 | def add_params(self, params, module, prefix='', is_dcn_module=None): 60 | """Add all parameters of module to the params list. 61 | The parameters of the given module will be added to the list of param 62 | groups, with specific rules defined by paramwise_cfg. 63 | Args: 64 | params (list[dict]): A list of param groups, it will be modified 65 | in place. 66 | module (nn.Module): The module to be added. 67 | prefix (str): The prefix of the module 68 | is_dcn_module (int|float|None): If the current module is a 69 | submodule of DCN, `is_dcn_module` will be passed to 70 | control conv_offset layer's learning rate. Defaults to None. 71 | """ 72 | parameter_groups = {} 73 | print(self.paramwise_cfg) 74 | num_layers = self.paramwise_cfg.get('num_layers') + 2 75 | decay_rate = self.paramwise_cfg.get('decay_rate') 76 | decay_type = self.paramwise_cfg.get('decay_type', "layer_wise") 77 | print("Build LearningRateDecayOptimizerConstructor %s %f - %d" % (decay_type, decay_rate, num_layers)) 78 | weight_decay = self.base_wd 79 | 80 | for name, param in module.named_parameters(): 81 | if not param.requires_grad: 82 | continue # frozen weights 83 | if len(param.shape) == 1 or name.endswith(".bias") or name in ('pos_embed', 'cls_token'): 84 | group_name = "no_decay" 85 | this_weight_decay = 0. 
86 | else: 87 | group_name = "decay" 88 | this_weight_decay = weight_decay 89 | 90 | if decay_type == "layer_wise": 91 | layer_id = get_num_layer_layer_wise(name, self.paramwise_cfg.get('num_layers')) 92 | elif decay_type == "stage_wise": 93 | layer_id = get_num_layer_stage_wise(name, num_layers) 94 | 95 | group_name = "layer_%d_%s" % (layer_id, group_name) 96 | 97 | if group_name not in parameter_groups: 98 | scale = decay_rate ** (num_layers - layer_id - 1) 99 | 100 | parameter_groups[group_name] = { 101 | "weight_decay": this_weight_decay, 102 | "params": [], 103 | "param_names": [], 104 | "lr_scale": scale, 105 | "group_name": group_name, 106 | "lr": scale * self.base_lr, 107 | } 108 | 109 | parameter_groups[group_name]["params"].append(param) 110 | parameter_groups[group_name]["param_names"].append(name) 111 | rank, _ = get_dist_info() 112 | if rank == 0: 113 | to_display = {} 114 | for key in parameter_groups: 115 | to_display[key] = { 116 | "param_names": parameter_groups[key]["param_names"], 117 | "lr_scale": parameter_groups[key]["lr_scale"], 118 | "lr": parameter_groups[key]["lr"], 119 | "weight_decay": parameter_groups[key]["weight_decay"], 120 | } 121 | print("Param groups = %s" % json.dumps(to_display, indent=2)) 122 | 123 | params.extend(parameter_groups.values()) 124 | -------------------------------------------------------------------------------- /downstream_mmdet/mmcv_custom/runner/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | import os.path as osp 3 | import time 4 | from tempfile import TemporaryDirectory 5 | 6 | import torch 7 | from torch.optim import Optimizer 8 | 9 | import mmcv 10 | from mmcv.parallel import is_module_wrapper 11 | from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict 12 | 13 | try: 14 | import apex 15 | except: 16 | print('apex is not installed') 17 | 18 | 19 | def save_checkpoint(model, filename, optimizer=None, meta=None): 20 | """Save checkpoint to file. 21 | 22 | The checkpoint will have 4 fields: ``meta``, ``state_dict`` and 23 | ``optimizer``, ``amp``. By default ``meta`` will contain version 24 | and time info. 25 | 26 | Args: 27 | model (Module): Module whose params are to be saved. 28 | filename (str): Checkpoint filename. 29 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. 30 | meta (dict, optional): Metadata to be saved in checkpoint. 
31 | """ 32 | if meta is None: 33 | meta = {} 34 | elif not isinstance(meta, dict): 35 | raise TypeError(f'meta must be a dict or None, but got {type(meta)}') 36 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) 37 | 38 | if is_module_wrapper(model): 39 | model = model.module 40 | 41 | if hasattr(model, 'CLASSES') and model.CLASSES is not None: 42 | # save class name to the meta 43 | meta.update(CLASSES=model.CLASSES) 44 | 45 | checkpoint = { 46 | 'meta': meta, 47 | 'state_dict': weights_to_cpu(get_state_dict(model)) 48 | } 49 | # save optimizer state dict in the checkpoint 50 | if isinstance(optimizer, Optimizer): 51 | checkpoint['optimizer'] = optimizer.state_dict() 52 | elif isinstance(optimizer, dict): 53 | checkpoint['optimizer'] = {} 54 | for name, optim in optimizer.items(): 55 | checkpoint['optimizer'][name] = optim.state_dict() 56 | 57 | # save amp state dict in the checkpoint 58 | # checkpoint['amp'] = apex.amp.state_dict() 59 | 60 | if filename.startswith('pavi://'): 61 | try: 62 | from pavi import modelcloud 63 | from pavi.exception import NodeNotFoundError 64 | except ImportError: 65 | raise ImportError( 66 | 'Please install pavi to load checkpoint from modelcloud.') 67 | model_path = filename[7:] 68 | root = modelcloud.Folder() 69 | model_dir, model_name = osp.split(model_path) 70 | try: 71 | model = modelcloud.get(model_dir) 72 | except NodeNotFoundError: 73 | model = root.create_training_model(model_dir) 74 | with TemporaryDirectory() as tmp_dir: 75 | checkpoint_file = osp.join(tmp_dir, model_name) 76 | with open(checkpoint_file, 'wb') as f: 77 | torch.save(checkpoint, f) 78 | f.flush() 79 | model.create_file(checkpoint_file, name=model_name) 80 | else: 81 | mmcv.mkdir_or_exist(osp.dirname(filename)) 82 | # immediately flush buffer 83 | with open(filename, 'wb') as f: 84 | torch.save(checkpoint, f) 85 | f.flush() 86 | -------------------------------------------------------------------------------- /downstream_mmdet/mmdet/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .darknet import Darknet 2 | from .detectors_resnet import DetectoRS_ResNet 3 | from .detectors_resnext import DetectoRS_ResNeXt 4 | from .hourglass import HourglassNet 5 | from .hrnet import HRNet 6 | from .regnet import RegNet 7 | from .res2net import Res2Net 8 | from .resnest import ResNeSt 9 | from .resnet import ResNet, ResNetV1d 10 | from .resnext import ResNeXt 11 | from .ssd_vgg import SSDVGG 12 | from .trident_resnet import TridentResNet 13 | from .swin_transformer import SwinTransformer 14 | from .convnext import ConvNeXt 15 | 16 | __all__ = [ 17 | 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net', 18 | 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet', 19 | 'ResNeSt', 'TridentResNet', 'SwinTransformer', 'ConvNeXt' 20 | ] 21 | -------------------------------------------------------------------------------- /downstream_mmdet/mmdet/models/backbones/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | 9 | from functools import partial 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from timm.models.layers import trunc_normal_, DropPath 14 | 15 | from mmcv_custom import load_checkpoint 16 | from mmdet.utils import get_root_logger 17 | from ..builder import BACKBONES 18 | 19 | class Block(nn.Module): 20 | r""" ConvNeXt Block. There are two equivalent implementations: 21 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 22 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 23 | We use (2) as we find it slightly faster in PyTorch 24 | 25 | Args: 26 | dim (int): Number of input channels. 27 | drop_path (float): Stochastic depth rate. Default: 0.0 28 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 29 | """ 30 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 31 | super().__init__() 32 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 33 | self.norm = LayerNorm(dim, eps=1e-6) 34 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 35 | self.act = nn.GELU() 36 | self.pwconv2 = nn.Linear(4 * dim, dim) 37 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 38 | requires_grad=True) if layer_scale_init_value > 0 else None 39 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 40 | 41 | def forward(self, x): 42 | input = x 43 | x = self.dwconv(x) 44 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 45 | x = self.norm(x) 46 | x = self.pwconv1(x) 47 | x = self.act(x) 48 | x = self.pwconv2(x) 49 | if self.gamma is not None: 50 | x = self.gamma * x 51 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 52 | 53 | x = input + self.drop_path(x) 54 | return x 55 | 56 | @BACKBONES.register_module() 57 | class ConvNeXt(nn.Module): 58 | r""" ConvNeXt 59 | A PyTorch impl of : `A ConvNet for the 2020s` - 60 | https://arxiv.org/pdf/2201.03545.pdf 61 | 62 | Args: 63 | in_chans (int): Number of input image channels. Default: 3 64 | num_classes (int): Number of classes for classification head. Default: 1000 65 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 66 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 67 | drop_path_rate (float): Stochastic depth rate. Default: 0. 68 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 69 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
70 | """ 71 | def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 72 | drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3], 73 | ): 74 | super().__init__() 75 | 76 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 77 | stem = nn.Sequential( 78 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 79 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 80 | ) 81 | self.downsample_layers.append(stem) 82 | for i in range(3): 83 | downsample_layer = nn.Sequential( 84 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 85 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 86 | ) 87 | self.downsample_layers.append(downsample_layer) 88 | 89 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 90 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 91 | cur = 0 92 | for i in range(4): 93 | stage = nn.Sequential( 94 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 95 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 96 | ) 97 | self.stages.append(stage) 98 | cur += depths[i] 99 | 100 | self.out_indices = out_indices 101 | 102 | norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first") 103 | for i_layer in range(4): 104 | layer = norm_layer(dims[i_layer]) 105 | layer_name = f'norm{i_layer}' 106 | self.add_module(layer_name, layer) 107 | 108 | self.apply(self._init_weights) 109 | 110 | def _init_weights(self, m): 111 | if isinstance(m, (nn.Conv2d, nn.Linear)): 112 | trunc_normal_(m.weight, std=.02) 113 | nn.init.constant_(m.bias, 0) 114 | 115 | def init_weights(self, pretrained=None): 116 | """Initialize the weights in backbone. 117 | Args: 118 | pretrained (str, optional): Path to pre-trained weights. 119 | Defaults to None. 120 | """ 121 | 122 | def _init_weights(m): 123 | if isinstance(m, nn.Linear): 124 | trunc_normal_(m.weight, std=.02) 125 | if isinstance(m, nn.Linear) and m.bias is not None: 126 | nn.init.constant_(m.bias, 0) 127 | elif isinstance(m, nn.LayerNorm): 128 | nn.init.constant_(m.bias, 0) 129 | nn.init.constant_(m.weight, 1.0) 130 | 131 | if isinstance(pretrained, str): 132 | self.apply(_init_weights) 133 | logger = get_root_logger() 134 | load_checkpoint(self, pretrained, strict=False, logger=logger) 135 | elif pretrained is None: 136 | self.apply(_init_weights) 137 | else: 138 | raise TypeError('pretrained must be a str or None') 139 | 140 | def forward_features(self, x): 141 | outs = [] 142 | for i in range(4): 143 | x = self.downsample_layers[i](x) 144 | x = self.stages[i](x) 145 | if i in self.out_indices: 146 | norm_layer = getattr(self, f'norm{i}') 147 | x_out = norm_layer(x) 148 | outs.append(x_out) 149 | 150 | return tuple(outs) 151 | 152 | def forward(self, x): 153 | x = self.forward_features(x) 154 | return x 155 | 156 | class LayerNorm(nn.Module): 157 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 158 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 159 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 160 | with shape (batch_size, channels, height, width). 
161 | """ 162 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 163 | super().__init__() 164 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 165 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 166 | self.eps = eps 167 | self.data_format = data_format 168 | if self.data_format not in ["channels_last", "channels_first"]: 169 | raise NotImplementedError 170 | self.normalized_shape = (normalized_shape, ) 171 | 172 | def forward(self, x): 173 | if self.data_format == "channels_last": 174 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 175 | elif self.data_format == "channels_first": 176 | u = x.mean(1, keepdim=True) 177 | s = (x - u).pow(2).mean(1, keepdim=True) 178 | x = (x - u) / torch.sqrt(s + self.eps) 179 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 180 | return x 181 | -------------------------------------------------------------------------------- /pretrain/README.md: -------------------------------------------------------------------------------- 1 | ## Preparation for ImageNet-1k pretraining 2 | 3 | See [/INSTALL.md](/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset. 4 | 5 | **Note: for neural network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).** 6 | 7 | 8 | ## Tutorial for pretraining your own CNN model 9 | 10 | See [/pretrain/models/custom.py](/pretrain/models/custom.py). Your todo list is: 11 | 12 | - implement `get_downsample_ratio` in [/pretrain/models/custom.py line20](/pretrain/models/custom.py#L20). 13 | - implement `get_feature_map_channels` in [/pretrain/models/custom.py line29](/pretrain/models/custom.py#L29). 14 | - implement `forward` in [/pretrain/models/custom.py line38](/pretrain/models/custom.py#L38). 15 | - define `your_convnet(...)` with `@register_model` in [/pretrain/models/custom.py line54](/pretrain/models/custom.py#L53-L54). 16 | - add default kwargs of `your_convnet(...)` in [/pretrain/models/\_\_init\_\_.py line34](/pretrain/models/__init__.py#L34). 17 | - **Note: see [#54](/../../issues/54) if your CNN contains SE module or global average pooling layer, and see [#56](/../../issues/56) if it contains GroupNorm**. 18 | 19 | Then run the experiment with `--model=your_convnet`. 20 | 21 | 22 | ## Tutorial for pretraining on your own dataset 23 | 24 | See the comment of `build_dataset_to_pretrain` in [line55 of /pretrain/utils/imagenet.py](/pretrain/utils/imagenet.py#L55). Your todo list: 25 | 26 | - Define a subclass of `torch.utils.data.Dataset` for your own unlabeled dataset, to replace our `ImageNetDataset`. 27 | - Use `args.data_path` and `args.input_size` to help build your dataset, with `--data_path=... --input_size=...` to specify them. 28 | - Note the batch size `--bs` is the total batch size of all GPU, which may need to be adjusted based on your dataset size. FYI: we use `--bs=4096` for ImageNet, which contains 1.28 million images. 29 | 30 | **If your dataset is relatively small**, you can try `--init_weight=/path/to/res50_withdecoder_1kpretrained_spark_style.pth` to do your pretraining *from our pretrained weights*, rather than *form scratch*. 31 | 32 | ## Debug on 1 GPU (without DistributedDataParallel) 33 | 34 | Use a small batch size `--bs=32` for avoiding OOM. 
35 | 36 | ```shell script 37 | python3 main.py --exp_name=debug --data_path=/path/to/imagenet --model=resnet50 --bs=32 38 | ``` 39 | 40 | 41 | ## Pretraining Any Model on ImageNet-1k (224x224) 42 | 43 | For pretraining, run [/pretrain/main.py](/pretrain/main.py) with `torchrun`. 44 | **It is required to specify** the ImageNet data folder (`--data_path`), your experiment name & log dir (`--exp_name` and `--exp_dir`, automatically created if not exists), and the model name (`--model`, valid choices see the keys of 'pretrain_default_model_kwargs' in [/pretrain/models/\_\_init\_\_.py line34](/pretrain/models/__init__.py#L34)). 45 | 46 | We use the **same** pretraining configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts) in 224 pretraining. 47 | Their **names** and **default values** are in [/pretrain/utils/arg_util.py line23-44](/pretrain/utils/arg_util.py#L23-L44). 48 | All these default configurations (like batch size 4096) would be used, unless you specify some like `--bs=512`. 49 | 50 | **Note: the batch size `--bs` is the total batch size of all GPU, and the learning rate `--base_lr` is the base lr. The actual lr would be `lr = base_lr * bs / 256`, as in [/pretrain/utils/arg_util.py line131](/pretrain/utils/arg_util.py#L131). So do not use `--lr` to specify a lr (that will be ignored)** 51 | 52 | Here is an example to pretrain a ResNet50 on an 8-GPU single machine (we use DistributedDataParallel), overwriting the default batch size to 512: 53 | ```shell script 54 | $ cd /path/to/SparK/pretrain 55 | $ torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port= main.py \ 56 | --data_path=/path/to/imagenet --exp_name= --exp_dir=/path/to/logdir \ 57 | --model=resnet50 --bs=512 58 | ``` 59 | 60 | For multiple machines, change the `--nnodes`, `--node_rank`, `--master_address` and `--master_port` to your configurations. E.g.: 61 | ```shell script 62 | $ torchrun --nproc_per_node=8 --nnodes= --node_rank= --master_address= --master_port= main.py \ 63 | ... 64 | ``` 65 | 66 | ## Pretraining ConvNeXt-Large on ImageNet-1k (384x384) 67 | 68 | For 384 pretraining we use a larger mask ratio (0.75), a half batch size (2048), and a double base learning rate (4e-4): 69 | 70 | ```shell script 71 | $ cd /path/to/SparK/pretrain 72 | $ torchrun --nproc_per_node=8 --nnodes= --node_rank= --master_address= --master_port= main.py \ 73 | --data_path=/path/to/imagenet --exp_name= --exp_dir=/path/to/logdir \ 74 | --model=convnext_large --input_size=384 --mask=0.75 --bs=2048 --base_lr=4e-4 75 | ``` 76 | 77 | ## Logging 78 | 79 | See files in your `--exp_dir` to track your experiment: 80 | 81 | - `_withdecoder_1kpretrained_spark_style.pth`: saves model and optimizer states, current epoch, current reconstruction loss, etc.; can be used to resume pretraining; can also be used for visualization in [/pretrain/viz_reconstruction.ipynb](/pretrain/viz_reconstruction.ipynb) 82 | - `_1kpretrained_timm_style.pth`: can be used for downstream finetuning 83 | - `pretrain_log.txt`: records some important information such as: 84 | - `git_commit_id`: git version 85 | - `cmd`: the command of this experiment 86 | 87 | It also reports the loss and remaining pretraining time. 88 | 89 | - `tensorboard_log/`: saves a lot of tensorboard logs including loss values, learning rates, gradient norms and more things. Use `tensorboard --logdir /path/to/this/tensorboard_log/ --port 23333` for viz. 90 | - `stdout_backup.txt` and `stderr_backup.txt`: backups stdout/stderr. 
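As a rough sketch of how the `*_1kpretrained_timm_style.pth` weights mentioned above could be consumed for downstream fine-tuning (the real loading code lives in your downstream codebase; the file path and the `'module'` unwrapping below are illustrative assumptions), they can typically be loaded into a `timm` backbone with `strict=False`, since the pretrained checkpoint carries no classifier head:

```python
import torch
from timm import create_model

model = create_model('resnet50', num_classes=1000)  # new classification head, randomly initialized
state = torch.load('/path/to/resnet50_1kpretrained_timm_style.pth', map_location='cpu')
if isinstance(state, dict) and 'module' in state:   # unwrap, in case the weights were saved under a sub-key
    state = state['module']
missing, unexpected = model.load_state_dict(state, strict=False)
print('missing keys (typically just the classifier head):', missing)
print('unexpected keys:', unexpected)
```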
91 | 92 | ## Resuming 93 | 94 | Specify `--resume_from=path/to/_withdecoder_1kpretrained_spark_style.pth` to resume pretraining. Note this is different from `--init_weight`: 95 | 96 | - `--resume_from` will load three things: model weights, optimizer states, and current epoch, so it is used to resume some interrupted experiment (will start from that 'current epoch'). 97 | - `--init_weight` ONLY loads the model weights, so it's just like a model initialization (will start from epoch 0). 98 | 99 | 100 | ## Regarding sparse convolution 101 | 102 | We do not use sparse convolutions in this pytorch implementation, due to their limited optimization on modern hardware. 103 | As can be found in [/pretrain/encoder.py](/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution. 104 | We also define some sparse pooling or normalization layers in [/pretrain/encoder.py](/pretrain/encoder.py). 105 | All these "sparse" layers are implemented through pytorch built-in operators. 106 | 107 | 108 | ## Some details: how we mask images and how to set the patch size 109 | 110 | In SparK, the mask patch size **equals to** the downsample ratio of the CNN model (so there is no configuration like `--patch_size=32`). 111 | 112 | Here is the reason: when we do mask, we: 113 | 114 | 1. first generate the binary mask for the **smallest** resolution feature map, i.e., generate the `_cur_active` or `active_b1ff` in [/pretrain/spark.py line86-87](/pretrain/spark.py#L86-L87), which is a `torch.BoolTensor` shaped as `[B, 1, fmap_h, fmap_w]`, and would be used to mask the smallest feature map. 115 | 3. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., dim=2)` and `repeat_interleave(..., dim=3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py#L16)), to mask those feature maps ([`x` in line21](/pretrain/encoder.py#L21)) with larger resolutions . 116 | 117 | So if you want a patch size of 16 or 8, you should actually define a new CNN model with a downsample ratio of 16 or 8. 118 | See [Tutorial for pretraining your own CNN model (above)](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-cnn-model). 119 | -------------------------------------------------------------------------------- /pretrain/decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
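# Overview: this file defines `LightDecoder`, a lightweight UNet-style decoder made of
# log2(up_sample_ratio) blocks. Each `UNetBlock` upsamples 2x with a transposed conv, the
# channel width is halved from block to block, and a final 1x1 conv projects to 3 (RGB) channels.
# A minimal usage sketch with assumed shapes (224-px input, downsample ratio 32, width 512):
#   dec = LightDecoder(up_sample_ratio=32, width=512)
#   to_dec = [torch.rand(2, 512, 7, 7)]   # here only the smallest feature map is fed in
#   img = dec(to_dec)                     # -> (2, 3, 224, 224)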
6 | 7 | import math 8 | from typing import List 9 | 10 | import torch 11 | import torch.nn as nn 12 | from timm.models.layers import trunc_normal_ 13 | 14 | from utils.misc import is_pow2n 15 | 16 | 17 | class UNetBlock(nn.Module): 18 | def __init__(self, cin, cout, bn2d): 19 | """ 20 | a UNet block with 2x up sampling 21 | """ 22 | super().__init__() 23 | self.up_sample = nn.ConvTranspose2d(cin, cin, kernel_size=4, stride=2, padding=1, bias=True) 24 | self.conv = nn.Sequential( 25 | nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False), bn2d(cin), nn.ReLU6(inplace=True), 26 | nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False), bn2d(cout), 27 | ) 28 | 29 | def forward(self, x): 30 | x = self.up_sample(x) 31 | return self.conv(x) 32 | 33 | 34 | class LightDecoder(nn.Module): 35 | def __init__(self, up_sample_ratio, width=768, sbn=True): # todo: the decoder's width follows a simple halfing rule; you can change it to any other rule 36 | super().__init__() 37 | self.width = width 38 | assert is_pow2n(up_sample_ratio) 39 | n = round(math.log2(up_sample_ratio)) 40 | channels = [self.width // 2 ** i for i in range(n + 1)] # todo: the decoder's width follows a simple halfing rule; you can change it to any other rule 41 | bn2d = nn.SyncBatchNorm if sbn else nn.BatchNorm2d 42 | self.dec = nn.ModuleList([UNetBlock(cin, cout, bn2d) for (cin, cout) in zip(channels[:-1], channels[1:])]) 43 | self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True) 44 | 45 | self.initialize() 46 | 47 | def forward(self, to_dec: List[torch.Tensor]): 48 | x = 0 49 | for i, d in enumerate(self.dec): 50 | if i < len(to_dec) and to_dec[i] is not None: 51 | x = x + to_dec[i] 52 | x = self.dec[i](x) 53 | return self.proj(x) 54 | 55 | def extra_repr(self) -> str: 56 | return f'width={self.width}' 57 | 58 | def initialize(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.Linear): 61 | trunc_normal_(m.weight, std=.02) 62 | if m.bias is not None: 63 | nn.init.constant_(m.bias, 0) 64 | elif isinstance(m, nn.Conv2d): 65 | trunc_normal_(m.weight, std=.02) 66 | if m.bias is not None: 67 | nn.init.constant_(m.bias, 0) 68 | elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 69 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 70 | if m.bias is not None: 71 | nn.init.constant_(m.bias, 0.) 72 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)): 73 | nn.init.constant_(m.bias, 0) 74 | nn.init.constant_(m.weight, 1.0) 75 | -------------------------------------------------------------------------------- /pretrain/dist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
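# Overview: thin helpers around `torch.distributed`. `initialize()` sets up an NCCL process
# group when launched via `torchrun` (i.e., when the RANK env var is present) and quietly
# falls back to single-GPU/CPU mode otherwise; `get_rank`, `get_world_size`, `is_master`,
# `allreduce`, `allgather`, `broadcast`, etc. then work in both modes. Typical call pattern:
#   dist.initialize()
#   if dist.is_master():
#       print('only rank 0 prints this')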
6 | 7 | import os 8 | from typing import List 9 | from typing import Union 10 | 11 | import sys 12 | import torch 13 | import torch.distributed as tdist 14 | import torch.multiprocessing as mp 15 | 16 | __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu' 17 | __initialized = False 18 | 19 | 20 | def initialized(): 21 | return __initialized 22 | 23 | 24 | def initialize(backend='nccl'): 25 | global __device 26 | if not torch.cuda.is_available(): 27 | print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr) 28 | return 29 | elif 'RANK' not in os.environ: 30 | __device = torch.empty(1).cuda().device 31 | print(f'[dist initialize] RANK is not set, use 1 GPU instead', file=sys.stderr) 32 | return 33 | 34 | # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29 35 | if mp.get_start_method(allow_none=True) is None: 36 | mp.set_start_method('spawn') 37 | global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count() 38 | local_rank = global_rank % num_gpus 39 | torch.cuda.set_device(local_rank) 40 | tdist.init_process_group(backend=backend) 41 | 42 | global __rank, __local_rank, __world_size, __initialized 43 | __local_rank = local_rank 44 | __rank, __world_size = tdist.get_rank(), tdist.get_world_size() 45 | __device = torch.empty(1).cuda().device 46 | __initialized = True 47 | 48 | assert tdist.is_initialized(), 'torch.distributed is not initialized!' 49 | 50 | 51 | def get_rank(): 52 | return __rank 53 | 54 | 55 | def get_local_rank(): 56 | return __local_rank 57 | 58 | 59 | def get_world_size(): 60 | return __world_size 61 | 62 | 63 | def get_device(): 64 | return __device 65 | 66 | 67 | def is_master(): 68 | return __rank == 0 69 | 70 | 71 | def is_local_master(): 72 | return __local_rank == 0 73 | 74 | 75 | def barrier(): 76 | if __initialized: 77 | tdist.barrier() 78 | 79 | 80 | def parallelize(net, syncbn=False): 81 | if syncbn: 82 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) 83 | net = net.cuda() 84 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False) 85 | return net 86 | 87 | 88 | def allreduce(t: torch.Tensor) -> None: 89 | if __initialized: 90 | if not t.is_cuda: 91 | cu = t.detach().cuda() 92 | tdist.all_reduce(cu) 93 | t.copy_(cu.cpu()) 94 | else: 95 | tdist.all_reduce(t) 96 | 97 | 98 | def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]: 99 | if __initialized: 100 | if not t.is_cuda: 101 | t = t.cuda() 102 | ls = [torch.empty_like(t) for _ in range(__world_size)] 103 | tdist.all_gather(ls, t) 104 | else: 105 | ls = [t] 106 | if cat: 107 | ls = torch.cat(ls, dim=0) 108 | return ls 109 | 110 | 111 | def broadcast(t: torch.Tensor, src_rank) -> None: 112 | if __initialized: 113 | if not t.is_cuda: 114 | cu = t.detach().cuda() 115 | tdist.broadcast(cu, src=src_rank) 116 | t.copy_(cu.cpu()) 117 | else: 118 | tdist.broadcast(t, src=src_rank) 119 | -------------------------------------------------------------------------------- /pretrain/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
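# Overview: this file simulates submanifold sparse convolution with plain dense PyTorch ops.
# A global boolean map `_cur_active` of shape (B, 1, f, f) marks the un-masked positions of
# the smallest feature map; `_get_active_ex_or_ii` upsamples it to any (H, W) with
# `repeat_interleave`, and the Sparse* layers below either multiply their dense outputs by
# this map (conv/pooling) or normalize only the active positions (BN/LN variants).
# A tiny illustration of the mask upsampling:
#   m = torch.tensor([[True, False], [False, True]]).view(1, 1, 2, 2)
#   m.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)   # -> a (1, 1, 4, 4) block mask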
6 | 7 | import torch 8 | import torch.nn as nn 9 | from timm.models.layers import DropPath 10 | 11 | 12 | _cur_active: torch.Tensor = None # B1ff 13 | # todo: try to use `gather` for speed? 14 | def _get_active_ex_or_ii(H, W, returning_active_ex=True): 15 | h_repeat, w_repeat = H // _cur_active.shape[-2], W // _cur_active.shape[-1] 16 | active_ex = _cur_active.repeat_interleave(h_repeat, dim=2).repeat_interleave(w_repeat, dim=3) 17 | return active_ex if returning_active_ex else active_ex.squeeze(1).nonzero(as_tuple=True) # ii: bi, hi, wi 18 | 19 | 20 | def sp_conv_forward(self, x: torch.Tensor): 21 | x = super(type(self), self).forward(x) 22 | x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True) # (BCHW) *= (B1HW), mask the output of conv 23 | return x 24 | 25 | 26 | def sp_bn_forward(self, x: torch.Tensor): 27 | ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False) 28 | 29 | bhwc = x.permute(0, 2, 3, 1) 30 | nc = bhwc[ii] # select the features on non-masked positions to form a flatten feature `nc` 31 | nc = super(type(self), self).forward(nc) # use BN1d to normalize this flatten feature `nc` 32 | 33 | bchw = torch.zeros_like(bhwc) 34 | bchw[ii] = nc 35 | bchw = bchw.permute(0, 3, 1, 2) 36 | return bchw 37 | 38 | 39 | class SparseConv2d(nn.Conv2d): 40 | forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details 41 | 42 | 43 | class SparseMaxPooling(nn.MaxPool2d): 44 | forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details 45 | 46 | 47 | class SparseAvgPooling(nn.AvgPool2d): 48 | forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details 49 | 50 | 51 | class SparseBatchNorm2d(nn.BatchNorm1d): 52 | forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details 53 | 54 | 55 | class SparseSyncBatchNorm2d(nn.SyncBatchNorm): 56 | forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details 57 | 58 | 59 | class SparseConvNeXtLayerNorm(nn.LayerNorm): 60 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 61 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 62 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 63 | with shape (batch_size, channels, height, width). 
64 | """ 65 | 66 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True): 67 | if data_format not in ["channels_last", "channels_first"]: 68 | raise NotImplementedError 69 | super().__init__(normalized_shape, eps, elementwise_affine=True) 70 | self.data_format = data_format 71 | self.sparse = sparse 72 | 73 | def forward(self, x): 74 | if x.ndim == 4: # BHWC or BCHW 75 | if self.data_format == "channels_last": # BHWC 76 | if self.sparse: 77 | ii = _get_active_ex_or_ii(H=x.shape[1], W=x.shape[2], returning_active_ex=False) 78 | nc = x[ii] 79 | nc = super(SparseConvNeXtLayerNorm, self).forward(nc) 80 | 81 | x = torch.zeros_like(x) 82 | x[ii] = nc 83 | return x 84 | else: 85 | return super(SparseConvNeXtLayerNorm, self).forward(x) 86 | else: # channels_first, BCHW 87 | if self.sparse: 88 | ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False) 89 | bhwc = x.permute(0, 2, 3, 1) 90 | nc = bhwc[ii] 91 | nc = super(SparseConvNeXtLayerNorm, self).forward(nc) 92 | 93 | x = torch.zeros_like(bhwc) 94 | x[ii] = nc 95 | return x.permute(0, 3, 1, 2) 96 | else: 97 | u = x.mean(1, keepdim=True) 98 | s = (x - u).pow(2).mean(1, keepdim=True) 99 | x = (x - u) / torch.sqrt(s + self.eps) 100 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 101 | return x 102 | else: # BLC or BC 103 | if self.sparse: 104 | raise NotImplementedError 105 | else: 106 | return super(SparseConvNeXtLayerNorm, self).forward(x) 107 | 108 | def __repr__(self): 109 | return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})' 110 | 111 | 112 | class SparseConvNeXtBlock(nn.Module): 113 | r""" ConvNeXt Block. There are two equivalent implementations: 114 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 115 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 116 | We use (2) as we find it slightly faster in PyTorch 117 | 118 | Args: 119 | dim (int): Number of input channels. 120 | drop_path (float): Stochastic depth rate. Default: 0.0 121 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 122 | """ 123 | 124 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7): 125 | super().__init__() 126 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks//2, groups=dim) # depthwise conv 127 | self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse) 128 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 129 | self.act = nn.GELU() 130 | self.pwconv2 = nn.Linear(4 * dim, dim) 131 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 132 | requires_grad=True) if layer_scale_init_value > 0 else None 133 | self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 134 | self.sparse = sparse 135 | 136 | def forward(self, x): 137 | input = x 138 | x = self.dwconv(x) 139 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 140 | x = self.norm(x) 141 | x = self.pwconv1(x) 142 | x = self.act(x) # GELU(0) == (0), so there is no need to mask x (no need to `x *= _get_active_ex_or_ii`) 143 | x = self.pwconv2(x) 144 | if self.gamma is not None: 145 | x = self.gamma * x 146 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 147 | 148 | if self.sparse: 149 | x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True) 150 | 151 | x = input + self.drop_path(x) 152 | return x 153 | 154 | def __repr__(self): 155 | return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})' 156 | 157 | 158 | class SparseEncoder(nn.Module): 159 | def __init__(self, cnn, input_size, sbn=False, verbose=False): 160 | super(SparseEncoder, self).__init__() 161 | self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn) 162 | self.input_size, self.downsample_raito, self.enc_feat_map_chs = input_size, cnn.get_downsample_ratio(), cnn.get_feature_map_channels() 163 | 164 | @staticmethod 165 | def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False): 166 | oup = m 167 | if isinstance(m, nn.Conv2d): 168 | m: nn.Conv2d 169 | bias = m.bias is not None 170 | oup = SparseConv2d( 171 | m.in_channels, m.out_channels, 172 | kernel_size=m.kernel_size, stride=m.stride, padding=m.padding, 173 | dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode, 174 | ) 175 | oup.weight.data.copy_(m.weight.data) 176 | if bias: 177 | oup.bias.data.copy_(m.bias.data) 178 | elif isinstance(m, nn.MaxPool2d): 179 | m: nn.MaxPool2d 180 | oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation, return_indices=m.return_indices, ceil_mode=m.ceil_mode) 181 | elif isinstance(m, nn.AvgPool2d): 182 | m: nn.AvgPool2d 183 | oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode, count_include_pad=m.count_include_pad, divisor_override=m.divisor_override) 184 | elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)): 185 | m: nn.BatchNorm2d 186 | oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(m.weight.shape[0], eps=m.eps, momentum=m.momentum, affine=m.affine, track_running_stats=m.track_running_stats) 187 | oup.weight.data.copy_(m.weight.data) 188 | oup.bias.data.copy_(m.bias.data) 189 | oup.running_mean.data.copy_(m.running_mean.data) 190 | oup.running_var.data.copy_(m.running_var.data) 191 | oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data) 192 | if hasattr(m, "qconfig"): 193 | oup.qconfig = m.qconfig 194 | elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm): 195 | m: nn.LayerNorm 196 | oup = SparseConvNeXtLayerNorm(m.weight.shape[0], eps=m.eps) 197 | oup.weight.data.copy_(m.weight.data) 198 | oup.bias.data.copy_(m.bias.data) 199 | elif isinstance(m, (nn.Conv1d,)): 200 | raise NotImplementedError 201 | 202 | for name, child in m.named_children(): 203 | oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn)) 204 | del m 205 | return oup 206 | 207 | def forward(self, x): 208 | return self.sp_cnn(x, hierarchical=True) 209 | -------------------------------------------------------------------------------- /pretrain/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import datetime 8 | import math 9 | import sys 10 | import time 11 | from functools import partial 12 | from typing import List 13 | 14 | import torch 15 | from torch.nn.parallel import DistributedDataParallel 16 | from torch.utils.data import DataLoader 17 | 18 | import dist 19 | import encoder 20 | from decoder import LightDecoder 21 | from models import build_sparse_encoder 22 | from sampler import DistInfiniteBatchSampler, worker_init_fn 23 | from spark import SparK 24 | from utils import arg_util, misc, lamb 25 | from utils.imagenet import build_dataset_to_pretrain 26 | from utils.lr_control import lr_wd_annealing, get_param_groups 27 | 28 | 29 | class LocalDDP(torch.nn.Module): 30 | def __init__(self, module): 31 | super(LocalDDP, self).__init__() 32 | self.module = module 33 | 34 | def forward(self, *args, **kwargs): 35 | return self.module(*args, **kwargs) 36 | 37 | 38 | def main_pt(): 39 | args: arg_util.Args = arg_util.init_dist_and_get_args() 40 | print(f'initial args:\n{str(args)}') 41 | args.log_epoch() 42 | 43 | # build data 44 | print(f'[build data for pre-training] ...\n') 45 | dataset_train = build_dataset_to_pretrain(args.data_path, args.input_size) 46 | data_loader_train = DataLoader( 47 | dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True, 48 | batch_sampler=DistInfiniteBatchSampler( 49 | dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, 50 | shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(), 51 | ), worker_init_fn=worker_init_fn 52 | ) 53 | itrt_train, iters_train = iter(data_loader_train), len(data_loader_train) 54 | print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size_per_gpu}, iters_train={iters_train}') 55 | 56 | # build encoder and decoder 57 | enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False) 58 | dec = LightDecoder(enc.downsample_raito, sbn=args.sbn) 59 | model_without_ddp = SparK( 60 | sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask, 61 | densify_norm=args.densify_norm, sbn=args.sbn, 62 | ).to(args.device) 63 | print(f'[PT model] model = {model_without_ddp}\n') 64 | 65 | # the model has been randomly initialized in their construction time 66 | # now try to load some checkpoint as model weight initialization; this ONLY loads the model weights 67 | misc.initialize_weight(args.init_weight, model_without_ddp) 68 | 69 | if dist.initialized(): 70 | model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False) 71 | else: 72 | model = LocalDDP(model_without_ddp) 73 | 74 | # build optimizer and lr_scheduler 75 | param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'}) 76 | opt_clz = { 77 | 'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True), 78 | 'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)), 79 | 'lamb': partial(lamb.TheSameAsTimmLAMB, betas=(0.9, args.ada), max_grad_norm=5.0), 80 | }[args.opt] 81 | optimizer = opt_clz(params=param_groups, lr=args.lr, weight_decay=0.0) 82 | print(f'[optimizer] optimizer({opt_clz}) ={optimizer}\n') 83 | 84 | # try to resume the experiment from some checkpoint.pth; this will load model 
weights, optimizer states, and last epoch (ep_start) 85 | # if loaded, ep_start will be greater than 0 86 | ep_start, performance_desc = misc.load_checkpoint(args.resume_from, model_without_ddp, optimizer) 87 | if ep_start >= args.ep: # load from a complete checkpoint file 88 | print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}') 89 | else: # perform pre-training 90 | tb_lg = misc.TensorboardLogger(args.tb_lg_dir, is_master=dist.is_master(), prefix='pt') 91 | min_loss = 1e9 92 | print(f'[PT start] from ep{ep_start}') 93 | 94 | pt_start_time = time.time() 95 | for ep in range(ep_start, args.ep): 96 | ep_start_time = time.time() 97 | tb_lg.set_step(ep * iters_train) 98 | if hasattr(itrt_train, 'set_epoch'): 99 | itrt_train.set_epoch(ep) 100 | 101 | stats = pre_train_one_ep(ep, args, tb_lg, itrt_train, iters_train, model, optimizer) 102 | last_loss = stats['last_loss'] 103 | min_loss = min(min_loss, last_loss) 104 | performance_desc = f'{min_loss:.4f} {last_loss:.4f}' 105 | misc.save_checkpoint_with_meta_info_and_opt_state(f'{args.model}_withdecoder_1kpretrained_spark_style.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict()) 106 | misc.save_checkpoint_model_weights_only(f'{args.model}_1kpretrained_timm_style.pth', args, model_without_ddp.sparse_encoder.sp_cnn.state_dict()) 107 | 108 | ep_cost = round(time.time() - ep_start_time, 2) + 1 # +1s: approximate the following logging cost 109 | remain_secs = (args.ep-1 - ep) * ep_cost 110 | remain_time = datetime.timedelta(seconds=round(remain_secs)) 111 | finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs)) 112 | print(f' [*] [ep{ep}/{args.ep}] Min/Last Recon Loss: {performance_desc}, Cost: {ep_cost}s, Remain: {remain_time}, Finish @ {finish_time}') 113 | 114 | args.cur_ep = f'{ep + 1}/{args.ep}' 115 | args.remain_time, args.finish_time = str(remain_time), str(finish_time) 116 | args.last_loss = last_loss 117 | args.log_epoch() 118 | 119 | tb_lg.update(min_loss=min_loss, head='train', step=ep) 120 | tb_lg.update(rest_hours=round(remain_secs/60/60, 2), head='z_burnout', step=ep) 121 | tb_lg.flush() 122 | 123 | # finish pre-training 124 | tb_lg.update(min_loss=min_loss, head='result', step=ep_start) 125 | tb_lg.update(min_loss=min_loss, head='result', step=args.ep) 126 | tb_lg.flush() 127 | print(f'final args:\n{str(args)}') 128 | print('\n\n') 129 | print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - pt_start_time) / 60 / 60:.1f}h\n') 130 | print('\n\n') 131 | tb_lg.close() 132 | time.sleep(10) 133 | 134 | args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time())) 135 | args.log_epoch() 136 | 137 | 138 | def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itrt_train, iters_train, model: DistributedDataParallel, optimizer): 139 | model.train() 140 | me = misc.MetricLogger(delimiter=' ') 141 | me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}')) 142 | header = f'[PT] Epoch {ep}:' 143 | 144 | optimizer.zero_grad() 145 | early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm') 146 | late_clipping = hasattr(optimizer, 'global_grad_norm') 147 | if early_clipping: 148 | params_req_grad = [p for p in model.parameters() if p.requires_grad] 149 | 150 | for it, inp in enumerate(me.log_every(iters_train, itrt_train, 3, header)): 151 | # adjust lr and wd 152 | min_lr, max_lr, min_wd, max_wd = 
lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, it + ep * iters_train, args.wp_ep * iters_train, args.ep * iters_train) 153 | 154 | # forward and backward 155 | inp = inp.to(args.device, non_blocking=True) 156 | SparK.forward 157 | loss = model(inp, active_b1ff=None, vis=False) 158 | optimizer.zero_grad() 159 | loss.backward() 160 | loss = loss.item() 161 | if not math.isfinite(loss): 162 | print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True) 163 | sys.exit(-1) 164 | 165 | # optimize 166 | grad_norm = None 167 | if early_clipping: grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item() 168 | optimizer.step() 169 | if late_clipping: grad_norm = optimizer.global_grad_norm 170 | torch.cuda.synchronize() 171 | 172 | # log 173 | me.update(last_loss=loss) 174 | me.update(max_lr=max_lr) 175 | tb_lg.update(loss=me.meters['last_loss'].global_avg, head='train_loss') 176 | tb_lg.update(sche_lr=max_lr, head='train_hp/lr_max') 177 | tb_lg.update(sche_lr=min_lr, head='train_hp/lr_min') 178 | tb_lg.update(sche_wd=max_wd, head='train_hp/wd_max') 179 | tb_lg.update(sche_wd=min_wd, head='train_hp/wd_min') 180 | 181 | if grad_norm is not None: 182 | me.update(orig_norm=grad_norm) 183 | tb_lg.update(orig_norm=grad_norm, head='train_hp') 184 | tb_lg.set_step() 185 | 186 | me.synchronize_between_processes() 187 | return {k: meter.global_avg for k, meter in me.meters.items()} 188 | 189 | 190 | if __name__ == '__main__': 191 | main_pt() 192 | -------------------------------------------------------------------------------- /pretrain/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
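# Overview: importing this package makes the ConvNeXt variants available to `timm.create_model`,
# injects the SparK helpers (`get_downsample_ratio`, `get_feature_map_channels`, hierarchical
# `forward`) into `timm.models.ResNet`, patches a more informative repr onto a few torch/timm
# classes, and exposes `pretrain_default_model_kwargs` (the valid `--model` names) together
# with `build_sparse_encoder`. A rough usage sketch:
#   enc = build_sparse_encoder('resnet50', input_size=224, sbn=False)
#   enc.downsample_raito, enc.enc_feat_map_chs   # -> 32, [256, 512, 1024, 2048]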
6 | 7 | import torch 8 | from timm import create_model 9 | from timm.loss import SoftTargetCrossEntropy 10 | from timm.models.layers import drop 11 | 12 | 13 | from models.convnext import ConvNeXt 14 | from models.resnet import ResNet 15 | from models.custom import YourConvNet 16 | _import_resnets_for_timm_registration = (ResNet,) 17 | 18 | 19 | # log more 20 | def _ex_repr(self): 21 | return ', '.join( 22 | f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v)) 23 | for k, v in vars(self).items() 24 | if not k.startswith('_') and k != 'training' 25 | and not isinstance(v, (torch.nn.Module, torch.Tensor)) 26 | ) 27 | for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath): 28 | if hasattr(clz, 'extra_repr'): 29 | clz.extra_repr = _ex_repr 30 | else: 31 | clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})' 32 | 33 | 34 | pretrain_default_model_kwargs = { 35 | 'your_convnet': dict(), 36 | 'resnet50': dict(drop_path_rate=0.05), 37 | 'resnet101': dict(drop_path_rate=0.08), 38 | 'resnet152': dict(drop_path_rate=0.10), 39 | 'resnet200': dict(drop_path_rate=0.15), 40 | 'convnext_small': dict(sparse=True, drop_path_rate=0.2), 41 | 'convnext_base': dict(sparse=True, drop_path_rate=0.3), 42 | 'convnext_large': dict(sparse=True, drop_path_rate=0.4), 43 | } 44 | for kw in pretrain_default_model_kwargs.values(): 45 | kw['pretrained'] = False 46 | kw['num_classes'] = 0 47 | kw['global_pool'] = '' 48 | 49 | 50 | def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False): 51 | from encoder import SparseEncoder 52 | 53 | kwargs = pretrain_default_model_kwargs[name] 54 | if drop_path_rate != 0: 55 | kwargs['drop_path_rate'] = drop_path_rate 56 | print(f'[build_sparse_encoder] model kwargs={kwargs}') 57 | cnn = create_model(name, **kwargs) 58 | 59 | return SparseEncoder(cnn, input_size=input_size, sbn=sbn, verbose=verbose) 60 | 61 | -------------------------------------------------------------------------------- /pretrain/models/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # This file is basically a copy of: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py 8 | from typing import List 9 | 10 | import torch 11 | import torch.nn as nn 12 | from timm.models.layers import trunc_normal_ 13 | from timm.models.registry import register_model 14 | 15 | from encoder import SparseConvNeXtBlock, SparseConvNeXtLayerNorm 16 | 17 | 18 | class ConvNeXt(nn.Module): 19 | r""" ConvNeXt 20 | A PyTorch impl of : `A ConvNet for the 2020s` - 21 | https://arxiv.org/pdf/2201.03545.pdf 22 | Args: 23 | in_chans (int): Number of input image channels. Default: 3 24 | num_classes (int): Number of classes for classification head. Default: 1000 25 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 26 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 27 | drop_path_rate (float): Stochastic depth rate. Default: 0. 28 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 29 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
30 | """ 31 | 32 | def __init__(self, in_chans=3, num_classes=1000, 33 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 34 | layer_scale_init_value=1e-6, head_init_scale=1., global_pool='avg', 35 | sparse=True, 36 | ): 37 | super().__init__() 38 | self.dims: List[int] = dims 39 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 40 | stem = nn.Sequential( 41 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 42 | SparseConvNeXtLayerNorm(dims[0], eps=1e-6, data_format="channels_first", sparse=sparse) 43 | ) 44 | self.downsample_layers.append(stem) 45 | for i in range(3): 46 | downsample_layer = nn.Sequential( 47 | SparseConvNeXtLayerNorm(dims[i], eps=1e-6, data_format="channels_first", sparse=sparse), 48 | nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), 49 | ) 50 | self.downsample_layers.append(downsample_layer) 51 | 52 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 53 | self.drop_path_rate = drop_path_rate 54 | self.layer_scale_init_value = layer_scale_init_value 55 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 56 | cur = 0 57 | for i in range(4): 58 | stage = nn.Sequential( 59 | *[SparseConvNeXtBlock(dim=dims[i], drop_path=dp_rates[cur + j], 60 | layer_scale_init_value=layer_scale_init_value, sparse=sparse) for j in range(depths[i])] 61 | ) 62 | self.stages.append(stage) 63 | cur += depths[i] 64 | self.depths = depths 65 | 66 | self.apply(self._init_weights) 67 | if num_classes > 0: 68 | self.norm = SparseConvNeXtLayerNorm(dims[-1], eps=1e-6, sparse=False) # final norm layer for LE/FT; should not be sparse 69 | self.fc = nn.Linear(dims[-1], num_classes) 70 | else: 71 | self.norm = nn.Identity() 72 | self.fc = nn.Identity() 73 | 74 | def _init_weights(self, m): 75 | if isinstance(m, (nn.Conv2d, nn.Linear)): 76 | trunc_normal_(m.weight, std=.02) 77 | nn.init.constant_(m.bias, 0) 78 | 79 | def get_downsample_ratio(self) -> int: 80 | return 32 81 | 82 | def get_feature_map_channels(self) -> List[int]: 83 | return self.dims 84 | 85 | def forward(self, x, hierarchical=False): 86 | if hierarchical: 87 | ls = [] 88 | for i in range(4): 89 | x = self.downsample_layers[i](x) 90 | x = self.stages[i](x) 91 | ls.append(x) 92 | return ls 93 | else: 94 | return self.fc(self.norm(x.mean([-2, -1]))) # (B, C, H, W) =mean=> (B, C) =norm&fc=> (B, NumCls) 95 | 96 | def get_classifier(self): 97 | return self.fc 98 | 99 | def extra_repr(self): 100 | return f'drop_path_rate={self.drop_path_rate}, layer_scale_init_value={self.layer_scale_init_value:g}' 101 | 102 | 103 | @register_model 104 | def convnext_tiny(pretrained=False, in_22k=False, **kwargs): 105 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 106 | return model 107 | 108 | 109 | @register_model 110 | def convnext_small(pretrained=False, in_22k=False, **kwargs): 111 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 112 | return model 113 | 114 | 115 | @register_model 116 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 117 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 118 | return model 119 | 120 | 121 | @register_model 122 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 123 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 124 | return model 125 | 126 | -------------------------------------------------------------------------------- 
/pretrain/models/custom.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from typing import List 10 | from timm.models.registry import register_model 11 | 12 | 13 | class YourConvNet(nn.Module): 14 | """ 15 | This is a template for your custom ConvNet. 16 | It is required to implement the following three functions: `get_downsample_ratio`, `get_feature_map_channels`, `forward`. 17 | You can refer to the implementations in `pretrain\models\resnet.py` for an example. 18 | """ 19 | 20 | def get_downsample_ratio(self) -> int: 21 | """ 22 | This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`). 23 | 24 | :return: the TOTAL downsample ratio of the ConvNet. 25 | E.g., for a ResNet-50, this should return 32. 26 | """ 27 | raise NotImplementedError 28 | 29 | def get_feature_map_channels(self) -> List[int]: 30 | """ 31 | This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`). 32 | 33 | :return: a list of the number of channels of each feature map. 34 | E.g., for a ResNet-50, this should return [256, 512, 1024, 2048]. 35 | """ 36 | raise NotImplementedError 37 | 38 | def forward(self, inp_bchw: torch.Tensor, hierarchical=False): 39 | """ 40 | The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`). 41 | 42 | :param inp_bchw: input image tensor, shape: (batch_size, channels, height, width). 43 | :param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical). 44 | :return: 45 | - hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes). 46 | - hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`. 47 | E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map]. 
48 | for an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)] 49 | """ 50 | raise NotImplementedError 51 | 52 | 53 | @register_model 54 | def your_convnet_small(pretrained=False, **kwargs): 55 | raise NotImplementedError 56 | return YourConvNet(**kwargs) 57 | 58 | 59 | @torch.no_grad() 60 | def convnet_test(): 61 | from timm.models import create_model 62 | cnn = create_model('your_convnet_small') 63 | print('get_downsample_ratio:', cnn.get_downsample_ratio()) 64 | print('get_feature_map_channels:', cnn.get_feature_map_channels()) 65 | 66 | downsample_ratio = cnn.get_downsample_ratio() 67 | feature_map_channels = cnn.get_feature_map_channels() 68 | 69 | # check the forward function 70 | B, C, H, W = 4, 3, 224, 224 71 | inp = torch.rand(B, C, H, W) 72 | feats = cnn(inp, hierarchical=True) 73 | assert isinstance(feats, list) 74 | assert len(feats) == len(feature_map_channels) 75 | print([tuple(t.shape) for t in feats]) 76 | 77 | # check the downsample ratio 78 | feats = cnn(inp, hierarchical=True) 79 | assert feats[-1].shape[-2] == H // downsample_ratio 80 | assert feats[-1].shape[-1] == W // downsample_ratio 81 | 82 | # check the channel number 83 | for feat, ch in zip(feats, feature_map_channels): 84 | assert feat.ndim == 4 85 | assert feat.shape[1] == ch 86 | 87 | 88 | if __name__ == '__main__': 89 | convnet_test() 90 | -------------------------------------------------------------------------------- /pretrain/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
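# Overview: instead of subclassing, this file patches `timm.models.resnet.ResNet` in place:
# it attaches `get_downsample_ratio` (always 32), `get_feature_map_channels` (read from timm's
# `feature_info`), and a `forward` that returns the four stage outputs when called with
# `hierarchical=True`. A rough usage sketch:
#   cnn = create_model('resnet50', num_classes=0, global_pool='')
#   feats = cnn(torch.rand(2, 3, 224, 224), hierarchical=True)   # 4 maps at 56/28/14/7 px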
6 | from typing import List 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from timm.models.resnet import ResNet 11 | 12 | 13 | # hack: inject the `get_downsample_ratio` function into `timm.models.resnet.ResNet` 14 | def get_downsample_ratio(self: ResNet) -> int: 15 | return 32 16 | 17 | 18 | # hack: inject the `get_feature_map_channels` function into `timm.models.resnet.ResNet` 19 | def get_feature_map_channels(self: ResNet) -> List[int]: 20 | # `self.feature_info` is maintained by `timm` 21 | return [info['num_chs'] for info in self.feature_info[1:]] 22 | 23 | 24 | # hack: override the forward function of `timm.models.resnet.ResNet` 25 | def forward(self, x, hierarchical=False): 26 | """ this forward function is a modified version of `timm.models.resnet.ResNet.forward` 27 | >>> ResNet.forward 28 | """ 29 | x = self.conv1(x) 30 | x = self.bn1(x) 31 | x = self.act1(x) 32 | x = self.maxpool(x) 33 | 34 | if hierarchical: 35 | ls = [] 36 | x = self.layer1(x); ls.append(x) 37 | x = self.layer2(x); ls.append(x) 38 | x = self.layer3(x); ls.append(x) 39 | x = self.layer4(x); ls.append(x) 40 | return ls 41 | else: 42 | x = self.global_pool(x) 43 | if self.drop_rate: 44 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 45 | x = self.fc(x) 46 | return x 47 | 48 | 49 | ResNet.get_downsample_ratio = get_downsample_ratio 50 | ResNet.get_feature_map_channels = get_feature_map_channels 51 | ResNet.forward = forward 52 | 53 | 54 | @torch.no_grad() 55 | def convnet_test(): 56 | from timm.models import create_model 57 | cnn = create_model('resnet50') 58 | print('get_downsample_ratio:', cnn.get_downsample_ratio()) 59 | print('get_feature_map_channels:', cnn.get_feature_map_channels()) 60 | 61 | downsample_ratio = cnn.get_downsample_ratio() 62 | feature_map_channels = cnn.get_feature_map_channels() 63 | 64 | # check the forward function 65 | B, C, H, W = 4, 3, 224, 224 66 | inp = torch.rand(B, C, H, W) 67 | feats = cnn(inp, hierarchical=True) 68 | assert isinstance(feats, list) 69 | assert len(feats) == len(feature_map_channels) 70 | print([tuple(t.shape) for t in feats]) 71 | 72 | # check the downsample ratio 73 | feats = cnn(inp, hierarchical=True) 74 | assert feats[-1].shape[-2] == H // downsample_ratio 75 | assert feats[-1].shape[-1] == W // downsample_ratio 76 | 77 | # check the channel number 78 | for feat, ch in zip(feats, feature_map_channels): 79 | assert feat.ndim == 4 80 | assert feat.shape[1] == ch 81 | 82 | 83 | if __name__ == '__main__': 84 | convnet_test() 85 | -------------------------------------------------------------------------------- /pretrain/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | Pillow 4 | typed-argument-parser 5 | timm==0.5.4 6 | tensorboardx 7 | -------------------------------------------------------------------------------- /pretrain/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
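# Overview: `DistInfiniteBatchSampler` yields per-rank index batches forever; the training
# loop decides when an "epoch" ends via `iters_per_ep = ceil(dataset_len / glb_batch_size)`.
# It shuffles with a per-epoch seed and splits the global permutation evenly across ranks,
# and is meant to be passed as `batch_sampler=` to a DataLoader. A rough usage sketch:
#   sampler = DistInfiniteBatchSampler(world_size=1, rank=0, dataset_len=10000, glb_batch_size=256)
#   loader = DataLoader(dataset, batch_sampler=sampler, worker_init_fn=worker_init_fn)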
6 | 7 | import random 8 | 9 | import numpy as np 10 | import torch 11 | from torch.utils.data.sampler import Sampler 12 | 13 | 14 | def worker_init_fn(worker_id): 15 | # https://pytorch.org/docs/stable/notes/randomness.html#dataloader 16 | worker_seed = torch.initial_seed() % 2 ** 32 17 | np.random.seed(worker_seed) 18 | random.seed(worker_seed) 19 | 20 | 21 | class DistInfiniteBatchSampler(Sampler): 22 | def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=1, filling=False, shuffle=True): 23 | assert glb_batch_size % world_size == 0 24 | self.world_size, self.rank = world_size, rank 25 | self.dataset_len = dataset_len 26 | self.glb_batch_size = glb_batch_size 27 | self.batch_size = glb_batch_size // world_size 28 | 29 | self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size 30 | self.filling = filling 31 | self.shuffle = shuffle 32 | self.epoch = 0 33 | self.seed = seed 34 | self.indices = self.gener_indices() 35 | 36 | def gener_indices(self): 37 | global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0 38 | if self.shuffle: 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch + self.seed) 41 | global_indices = torch.randperm(self.dataset_len, generator=g) 42 | else: 43 | global_indices = torch.arange(self.dataset_len) 44 | filling = global_max_p - global_indices.shape[0] 45 | if filling > 0 and self.filling: 46 | global_indices = torch.cat((global_indices, global_indices[:filling])) 47 | global_indices = tuple(global_indices.numpy().tolist()) 48 | 49 | seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int) 50 | local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]] 51 | self.max_p = len(local_indices) 52 | return local_indices 53 | 54 | def __iter__(self): 55 | self.epoch = 0 56 | while True: 57 | self.epoch += 1 58 | p, q = 0, 0 59 | while p < self.max_p: 60 | q = p + self.batch_size 61 | yield self.indices[p:q] 62 | p = q 63 | if self.shuffle: 64 | self.indices = self.gener_indices() 65 | 66 | def __len__(self): 67 | return self.iters_per_ep 68 | 69 | 70 | if __name__ == '__main__': 71 | W = 16 72 | for rk in range(W): 73 | ind = DistInfiniteBatchSampler(W, rk, 5024, 5024).gener_indices() 74 | print(rk, len(ind)) 75 | -------------------------------------------------------------------------------- /pretrain/spark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
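# Overview: `SparK` ties the pieces together. Its forward pass (1) samples a random boolean
# mask on the smallest feature map, (2) runs the masked image through the sparse encoder to
# get hierarchical feature maps, (3) "densifies" them by filling masked positions with
# learnable mask tokens followed by a norm and a projection, (4) decodes with `LightDecoder`,
# and (5) returns an L2 loss computed only on masked patches, against per-patch normalized
# pixel targets. A rough usage sketch:
#   spark = SparK(sparse_encoder=enc, dense_decoder=dec, mask_ratio=0.6)
#   loss = spark(imgs_bchw)   # scalar reconstruction loss, ready for loss.backward()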
6 | 7 | from pprint import pformat 8 | from typing import List 9 | 10 | import sys 11 | import torch 12 | import torch.nn as nn 13 | from timm.models.layers import trunc_normal_ 14 | 15 | import encoder 16 | from decoder import LightDecoder 17 | 18 | 19 | class SparK(nn.Module): 20 | def __init__( 21 | self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder, 22 | mask_ratio=0.6, densify_norm='bn', sbn=False, 23 | ): 24 | super().__init__() 25 | input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito 26 | self.downsample_raito = downsample_raito 27 | self.fmap_h, self.fmap_w = input_size // downsample_raito, input_size // downsample_raito 28 | self.mask_ratio = mask_ratio 29 | self.len_keep = round(self.fmap_h * self.fmap_w * (1 - mask_ratio)) 30 | 31 | self.sparse_encoder = sparse_encoder 32 | self.dense_decoder = dense_decoder 33 | 34 | self.sbn = sbn 35 | self.hierarchy = len(sparse_encoder.enc_feat_map_chs) 36 | self.densify_norm_str = densify_norm.lower() 37 | self.densify_norms = nn.ModuleList() 38 | self.densify_projs = nn.ModuleList() 39 | self.mask_tokens = nn.ParameterList() 40 | 41 | # build the `densify` layers 42 | e_widths, d_width = self.sparse_encoder.enc_feat_map_chs, self.dense_decoder.width 43 | e_widths: List[int] 44 | for i in range(self.hierarchy): # from the smallest feat map to the largest; i=0: the last feat map; i=1: the second last feat map ... 45 | e_width = e_widths.pop() 46 | # create mask token 47 | p = nn.Parameter(torch.zeros(1, e_width, 1, 1)) 48 | trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02) 49 | self.mask_tokens.append(p) 50 | 51 | # create densify norm 52 | if self.densify_norm_str == 'bn': 53 | densify_norm = (encoder.SparseSyncBatchNorm2d if self.sbn else encoder.SparseBatchNorm2d)(e_width) 54 | elif self.densify_norm_str == 'ln': 55 | densify_norm = encoder.SparseConvNeXtLayerNorm(e_width, data_format='channels_first', sparse=True) 56 | else: 57 | densify_norm = nn.Identity() 58 | self.densify_norms.append(densify_norm) 59 | 60 | # create densify proj 61 | if i == 0 and e_width == d_width: 62 | densify_proj = nn.Identity() # todo: NOTE THAT CONVNEXT-S WOULD USE THIS, because it has a width of 768 that equals to the decoder's width 768 63 | print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: use nn.Identity() as densify_proj') 64 | else: 65 | kernel_size = 1 if i <= 0 else 3 66 | densify_proj = nn.Conv2d(e_width, d_width, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, bias=True) 67 | print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: densify_proj(ksz={kernel_size}, #para={sum(x.numel() for x in densify_proj.parameters()) / 1e6:.2f}M)') 68 | self.densify_projs.append(densify_proj) 69 | 70 | # todo: the decoder's width follows a simple halfing rule; you can change it to any other rule 71 | d_width //= 2 72 | 73 | print(f'[SparK.__init__] dims of mask_tokens={tuple(p.numel() for p in self.mask_tokens)}') 74 | 75 | # these are deprecated and would never be used; can be removed. 76 | self.register_buffer('imn_m', torch.empty(1, 3, 1, 1)) 77 | self.register_buffer('imn_s', torch.empty(1, 3, 1, 1)) 78 | self.register_buffer('norm_black', torch.zeros(1, 3, input_size, input_size)) 79 | self.vis_active = self.vis_active_ex = self.vis_inp = self.vis_inp_mask = ... 
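    # Note on `mask` below: it returns a boolean "active" map of shape (B, 1, fmap_h, fmap_w)
    # where True marks kept (visible) patches. With the default 224-px input and a downsample
    # ratio of 32, the feature map is 7x7, so mask_ratio=0.6 keeps len_keep = round(49 * 0.4)
    # = 20 active positions per image; the remaining 29 are masked and reconstructed.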
80 | 81 | def mask(self, B: int, device, generator=None): 82 | h, w = self.fmap_h, self.fmap_w 83 | idx = torch.rand(B, h * w, generator=generator).argsort(dim=1) 84 | idx = idx[:, :self.len_keep].to(device) # (B, len_keep) 85 | return torch.zeros(B, h * w, dtype=torch.bool, device=device).scatter_(dim=1, index=idx, value=True).view(B, 1, h, w) 86 | 87 | def forward(self, inp_bchw: torch.Tensor, active_b1ff=None, vis=False): 88 | # step1. Mask 89 | if active_b1ff is None: # rand mask 90 | active_b1ff: torch.BoolTensor = self.mask(inp_bchw.shape[0], inp_bchw.device) # (B, 1, f, f) 91 | encoder._cur_active = active_b1ff # (B, 1, f, f) 92 | active_b1hw = active_b1ff.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3) # (B, 1, H, W) 93 | masked_bchw = inp_bchw * active_b1hw 94 | 95 | # step2. Encode: get hierarchical encoded sparse features (a list containing 4 feature maps at 4 scales) 96 | fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw) 97 | fea_bcffs.reverse() # after reversion: from the smallest feature map to the largest 98 | 99 | # step3. Densify: get hierarchical dense features for decoding 100 | cur_active = active_b1ff # (B, 1, f, f) 101 | to_dec = [] 102 | for i, bcff in enumerate(fea_bcffs): # from the smallest feature map to the largest 103 | if bcff is not None: 104 | bcff = self.densify_norms[i](bcff) 105 | mask_tokens = self.mask_tokens[i].expand_as(bcff) 106 | bcff = torch.where(cur_active.expand_as(bcff), bcff, mask_tokens) # fill in empty (non-active) positions with [mask] tokens 107 | bcff: torch.Tensor = self.densify_projs[i](bcff) 108 | to_dec.append(bcff) 109 | cur_active = cur_active.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3) # dilate the mask map, from (B, 1, f, f) to (B, 1, H, W) 110 | 111 | # step4. 
Decode and reconstruct 112 | rec_bchw = self.dense_decoder(to_dec) 113 | inp, rec = self.patchify(inp_bchw), self.patchify(rec_bchw) # inp and rec: (B, L = f*f, N = C*downsample_raito**2) 114 | mean = inp.mean(dim=-1, keepdim=True) 115 | var = (inp.var(dim=-1, keepdim=True) + 1e-6) ** .5 116 | inp = (inp - mean) / var 117 | l2_loss = ((rec - inp) ** 2).mean(dim=2, keepdim=False) # (B, L, C) ==mean==> (B, L) 118 | 119 | non_active = active_b1ff.logical_not().int().view(active_b1ff.shape[0], -1) # (B, 1, f, f) => (B, L) 120 | recon_loss = l2_loss.mul_(non_active).sum() / (non_active.sum() + 1e-8) # loss only on masked (non-active) patches 121 | 122 | if vis: 123 | masked_bchw = inp_bchw * active_b1hw 124 | rec_bchw = self.unpatchify(rec * var + mean) 125 | rec_or_inp = torch.where(active_b1hw, inp_bchw, rec_bchw) 126 | return inp_bchw, masked_bchw, rec_or_inp 127 | else: 128 | return recon_loss 129 | 130 | def patchify(self, bchw): 131 | p = self.downsample_raito 132 | h, w = self.fmap_h, self.fmap_w 133 | B, C = bchw.shape[:2] 134 | bchw = bchw.reshape(shape=(B, C, h, p, w, p)) 135 | bchw = torch.einsum('bchpwq->bhwpqc', bchw) 136 | bln = bchw.reshape(shape=(B, h * w, C * p ** 2)) # (B, f*f, 3*downsample_raito**2) 137 | return bln 138 | 139 | def unpatchify(self, bln): 140 | p = self.downsample_raito 141 | h, w = self.fmap_h, self.fmap_w 142 | B, C = bln.shape[0], bln.shape[-1] // p ** 2 143 | bln = bln.reshape(shape=(B, h, w, p, p, C)) 144 | bln = torch.einsum('bhwpqc->bchpwq', bln) 145 | bchw = bln.reshape(shape=(B, C, h * p, w * p)) 146 | return bchw 147 | 148 | def __repr__(self): 149 | return ( 150 | f'\n' 151 | f'[SparK.config]: {pformat(self.get_config(), indent=2, width=250)}\n' 152 | f'[SparK.structure]: {super(SparK, self).__repr__().replace(SparK.__name__, "")}' 153 | ) 154 | 155 | def get_config(self): 156 | return { 157 | # self 158 | 'mask_ratio': self.mask_ratio, 159 | 'densify_norm_str': self.densify_norm_str, 160 | 'sbn': self.sbn, 'hierarchy': self.hierarchy, 161 | 162 | # enc 163 | 'sparse_encoder.input_size': self.sparse_encoder.input_size, 164 | # dec 165 | 'dense_decoder.width': self.dense_decoder.width, 166 | } 167 | 168 | def state_dict(self, destination=None, prefix='', keep_vars=False, with_config=False): 169 | state = super(SparK, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) 170 | if with_config: 171 | state['config'] = self.get_config() 172 | return state 173 | 174 | def load_state_dict(self, state_dict, strict=True): 175 | config: dict = state_dict.pop('config', None) 176 | incompatible_keys = super(SparK, self).load_state_dict(state_dict, strict=strict) 177 | if config is not None: 178 | for k, v in self.get_config().items(): 179 | ckpt_v = config.get(k, None) 180 | if ckpt_v != v: 181 | err = f'[SparseMIM.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={ckpt_v})' 182 | if strict: 183 | raise AttributeError(err) 184 | else: 185 | print(err, file=sys.stderr) 186 | return incompatible_keys 187 | -------------------------------------------------------------------------------- /pretrain/utils/arg_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
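# Overview: `Args` is a typed CLI parser built on the `typed-argument-parser` (tap) package;
# `init_dist_and_get_args()` parses it, prepares the experiment/log dirs and the distributed
# environment, then fills in the runtime fields. In particular the actual learning rate
# follows the linear scaling rule lr = base_lr * glb_batch_size / 256, so as a worked example
# `--bs=512 --base_lr=2e-4` gives lr = 2e-4 * 512 / 256 = 4e-4 (assuming `--bs` is divisible
# by the number of GPUs).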
6 | 7 | import json 8 | import os 9 | import sys 10 | 11 | from tap import Tap 12 | 13 | import dist 14 | 15 | 16 | class Args(Tap): 17 | # environment 18 | exp_name: str = 'your_exp_name' 19 | exp_dir: str = 'your_exp_dir' # will be created if not exists 20 | data_path: str = 'imagenet_data_path' 21 | init_weight: str = '' # use some checkpoint as model weight initialization; ONLY load model weights 22 | resume_from: str = '' # resume the experiment from some checkpoint.pth; load model weights, optimizer states, and last epoch 23 | 24 | # SparK hyperparameters 25 | mask: float = 0.6 # mask ratio, should be in (0, 1) 26 | 27 | # encoder hyperparameters 28 | model: str = 'resnet50' 29 | input_size: int = 224 30 | sbn: bool = True 31 | 32 | # data hyperparameters 33 | bs: int = 4096 34 | dataloader_workers: int = 8 35 | 36 | # pre-training hyperparameters 37 | dp: float = 0.0 38 | base_lr: float = 2e-4 39 | wd: float = 0.04 40 | wde: float = 0.2 41 | ep: int = 1600 42 | wp_ep: int = 40 43 | clip: int = 5. 44 | opt: str = 'lamb' 45 | ada: float = 0. 46 | 47 | # NO NEED TO SPECIFIED; each of these args would be updated in runtime automatically 48 | lr: float = None 49 | batch_size_per_gpu: int = 0 50 | glb_batch_size: int = 0 51 | densify_norm: str = '' 52 | device: str = 'cpu' 53 | local_rank: int = 0 54 | cmd: str = ' '.join(sys.argv[1:]) 55 | commit_id: str = os.popen(f'git rev-parse HEAD').read().strip() or '[unknown]' 56 | commit_msg: str = (os.popen(f'git log -1').read().strip().splitlines() or ['[unknown]'])[-1].strip() 57 | last_loss: float = 0. 58 | cur_ep: str = '' 59 | remain_time: str = '' 60 | finish_time: str = '' 61 | first_logging: bool = True 62 | log_txt_name: str = '{args.exp_dir}/pretrain_log.txt' 63 | tb_lg_dir: str = '' # tensorboard log directory 64 | 65 | @property 66 | def is_convnext(self): 67 | return 'convnext' in self.model or 'cnx' in self.model 68 | 69 | @property 70 | def is_resnet(self): 71 | return 'resnet' in self.model 72 | 73 | def log_epoch(self): 74 | if not dist.is_local_master(): 75 | return 76 | 77 | if self.first_logging: 78 | self.first_logging = False 79 | with open(self.log_txt_name, 'w') as fp: 80 | json.dump({ 81 | 'name': self.exp_name, 'cmd': self.cmd, 'git_commit_id': self.commit_id, 'git_commit_msg': self.commit_msg, 82 | 'model': self.model, 83 | }, fp) 84 | fp.write('\n\n') 85 | 86 | with open(self.log_txt_name, 'a') as fp: 87 | json.dump({ 88 | 'cur_ep': self.cur_ep, 89 | 'last_L': self.last_loss, 90 | 'rema': self.remain_time, 'fini': self.finish_time, 91 | }, fp) 92 | fp.write('\n') 93 | 94 | 95 | def init_dist_and_get_args(): 96 | from utils import misc 97 | 98 | # initialize 99 | args = Args(explicit_bool=True).parse_args() 100 | e = os.path.abspath(args.exp_dir) 101 | d, e = os.path.dirname(e), os.path.basename(e) 102 | e = ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in e) 103 | args.exp_dir = os.path.join(d, e) 104 | 105 | os.makedirs(args.exp_dir, exist_ok=True) 106 | args.log_txt_name = os.path.join(args.exp_dir, 'pretrain_log.txt') 107 | args.tb_lg_dir = args.tb_lg_dir or os.path.join(args.exp_dir, 'tensorboard_log') 108 | try: 109 | os.makedirs(args.tb_lg_dir, exist_ok=True) 110 | except: 111 | pass 112 | 113 | misc.init_distributed_environ(exp_dir=args.exp_dir) 114 | 115 | # update args 116 | if not dist.initialized(): 117 | args.sbn = False 118 | args.first_logging = True 119 | args.device = dist.get_device() 120 | args.batch_size_per_gpu = args.bs // dist.get_world_size() 121 | args.glb_batch_size = 
args.batch_size_per_gpu * dist.get_world_size() 122 | 123 | if args.is_resnet: 124 | args.ada = args.ada or 0.95 125 | args.densify_norm = 'bn' 126 | 127 | if args.is_convnext: 128 | args.ada = args.ada or 0.999 129 | args.densify_norm = 'ln' 130 | 131 | args.opt = args.opt.lower() 132 | args.lr = args.base_lr * args.glb_batch_size / 256 133 | args.wde = args.wde or args.wd 134 | 135 | return args 136 | -------------------------------------------------------------------------------- /pretrain/utils/imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from typing import Any, Callable, Optional, Tuple 9 | 10 | import PIL.Image as PImage 11 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS 13 | from torchvision.transforms import transforms 14 | from torch.utils.data import Dataset 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | interpolation = InterpolationMode.BICUBIC 19 | except: 20 | import PIL 21 | interpolation = PIL.Image.BICUBIC 22 | 23 | 24 | def pil_loader(path): 25 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 26 | with open(path, 'rb') as f: img: PImage.Image = PImage.open(f).convert('RGB') 27 | return img 28 | 29 | 30 | class ImageNetDataset(DatasetFolder): 31 | def __init__( 32 | self, 33 | imagenet_folder: str, 34 | train: bool, 35 | transform: Callable, 36 | is_valid_file: Optional[Callable[[str], bool]] = None, 37 | ): 38 | imagenet_folder = os.path.join(imagenet_folder, 'train' if train else 'val') 39 | super(ImageNetDataset, self).__init__( 40 | imagenet_folder, 41 | loader=pil_loader, 42 | extensions=IMG_EXTENSIONS if is_valid_file is None else None, 43 | transform=transform, 44 | target_transform=None, is_valid_file=is_valid_file 45 | ) 46 | 47 | self.samples = tuple(img for (img, label) in self.samples) 48 | self.targets = None # this is self-supervised learning so we don't need labels 49 | 50 | def __getitem__(self, index: int) -> Any: 51 | img_file_path = self.samples[index] 52 | return self.transform(self.loader(img_file_path)) 53 | 54 | 55 | def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset: 56 | """ 57 | You may need to modify this function to return your own dataset. 58 | Define a new class, a subclass of `Dataset`, to replace our ImageNetDataset. 59 | Use dataset_path to build your image file path list. 60 | Use input_size to create the transformation function for your images; you can refer to `trans_train` below. 
61 | 62 | :param dataset_path: the folder of dataset 63 | :param input_size: the input size (image resolution) 64 | :return: the dataset used for pretraining 65 | """ 66 | trans_train = transforms.Compose([ 67 | transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation), 68 | transforms.RandomHorizontalFlip(), 69 | transforms.ToTensor(), 70 | transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 71 | ]) 72 | 73 | dataset_path = os.path.abspath(dataset_path) 74 | for postfix in ('train', 'val'): 75 | if dataset_path.endswith(postfix): 76 | dataset_path = dataset_path[:-len(postfix)] 77 | 78 | dataset_train = ImageNetDataset(imagenet_folder=dataset_path, transform=trans_train, train=True) 79 | print_transform(trans_train, '[pre-train]') 80 | return dataset_train 81 | 82 | 83 | def print_transform(transform, s): 84 | print(f'Transform {s} = ') 85 | for t in transform.transforms: 86 | print(t) 87 | print('---------------------------\n') 88 | -------------------------------------------------------------------------------- /pretrain/utils/lamb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # This file is basically a copy of: https://github.com/rwightman/pytorch-image-models/blob/v0.5.4/timm/optim/lamb.py 8 | # **The only modification** is adding the `global_grad_norm` member for debugging 9 | 10 | 11 | """ PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb 12 | This optimizer code was adapted from the following (starting with latest) 13 | * https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py 14 | * https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py 15 | * https://github.com/cybertronai/pytorch-lamb 16 | Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is 17 | similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. 18 | In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. 19 | Original copyrights for above sources are below. 20 | Modifications Copyright 2021 Ross Wightman 21 | """ 22 | import math 23 | 24 | import torch 25 | from torch.optim.optimizer import Optimizer 26 | 27 | 28 | class TheSameAsTimmLAMB(Optimizer): 29 | """Implements a pure pytorch variant of FusedLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB 30 | reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py 31 | 32 | LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 33 | 34 | Arguments: 35 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups. 36 | lr (float, optional): learning rate. (default: 1e-3) 37 | betas (Tuple[float, float], optional): coefficients used for computing 38 | running averages of gradient and its norm. (default: (0.9, 0.999)) 39 | eps (float, optional): term added to the denominator to improve 40 | numerical stability. 
(default: 1e-8) 41 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 42 | grad_averaging (bool, optional): whether to apply (1-beta2) to grad when 43 | calculating running averages of gradient. (default: True) 44 | max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) 45 | trust_clip (bool): enable LAMBC trust ratio clipping (default: False) 46 | always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 47 | weight decay parameter (default: False) 48 | 49 | .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: 50 | https://arxiv.org/abs/1904.00962 51 | .. _On the Convergence of Adam and Beyond: 52 | https://openreview.net/forum?id=ryQu7f-RZ 53 | """ 54 | 55 | def __init__( 56 | self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, 57 | weight_decay=0.01, grad_averaging=True, max_grad_norm=2.0, trust_clip=False, always_adapt=False): 58 | defaults = dict( 59 | lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay, 60 | grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, 61 | trust_clip=trust_clip, always_adapt=always_adapt) 62 | super().__init__(params, defaults) 63 | print(f'[lamb1] max_grad_norm={max_grad_norm}') 64 | self.global_grad_norm = 0 65 | 66 | @torch.no_grad() 67 | def step(self, closure=None): 68 | """Performs a single optimization step. 69 | Arguments: 70 | closure (callable, optional): A closure that reevaluates the model 71 | and returns the loss. 72 | """ 73 | loss = None 74 | if closure is not None: 75 | with torch.enable_grad(): 76 | loss = closure() 77 | 78 | device = self.param_groups[0]['params'][0].device 79 | one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly 80 | global_grad_norm = torch.zeros(1, device=device) 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | if p.grad is None: 84 | continue 85 | grad = p.grad 86 | if grad.is_sparse: 87 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.') 88 | global_grad_norm.add_(grad.pow(2).sum()) 89 | 90 | global_grad_norm = torch.sqrt(global_grad_norm) 91 | self.global_grad_norm = global_grad_norm.item() 92 | max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device) 93 | clip_global_grad_norm = 1 / torch.where( 94 | global_grad_norm > max_grad_norm, 95 | global_grad_norm / max_grad_norm, 96 | one_tensor) 97 | 98 | for group in self.param_groups: 99 | bias_correction = 1 if group['bias_correction'] else 0 100 | beta1, beta2 = group['betas'] 101 | grad_averaging = 1 if group['grad_averaging'] else 0 102 | beta3 = 1 - beta1 if grad_averaging else 1.0 103 | 104 | # assume same step across group now to simplify things 105 | # per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel 106 | if 'step' in group: 107 | group['step'] += 1 108 | else: 109 | group['step'] = 1 110 | 111 | if bias_correction: 112 | bias_correction1 = 1 - beta1 ** group['step'] 113 | bias_correction2 = 1 - beta2 ** group['step'] 114 | else: 115 | bias_correction1, bias_correction2 = 1.0, 1.0 116 | 117 | for p in group['params']: 118 | if p.grad is None: 119 | continue 120 | grad = p.grad.mul_(clip_global_grad_norm) 121 | state = self.state[p] 122 | 123 | # State initialization 124 | if len(state) == 0: 125 | # Exponential moving average of gradient values 126 | state['exp_avg'] = torch.zeros_like(p) 127 | # Exponential moving average of squared gradient 
values 128 | state['exp_avg_sq'] = torch.zeros_like(p) 129 | 130 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 131 | 132 | # Decay the first and second moment running average coefficient 133 | exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t 134 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t 135 | 136 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 137 | update = (exp_avg / bias_correction1).div_(denom) 138 | 139 | weight_decay = group['weight_decay'] 140 | if weight_decay != 0: 141 | update.add_(p, alpha=weight_decay) 142 | 143 | if weight_decay != 0 or group['always_adapt']: 144 | # Layer-wise LR adaptation. By default, skip adaptation on parameters that are 145 | # excluded from weight decay, unless always_adapt == True, then always enabled. 146 | w_norm = p.norm(2.0) 147 | g_norm = update.norm(2.0) 148 | # FIXME nested where required since logical and/or not working in PT XLA 149 | trust_ratio = torch.where( 150 | w_norm > 0, 151 | torch.where(g_norm > 0, w_norm / g_norm, one_tensor), 152 | one_tensor, 153 | ) 154 | if group['trust_clip']: 155 | # LAMBC trust clipping, upper bound fixed at one 156 | trust_ratio = torch.minimum(trust_ratio, one_tensor) 157 | update.mul_(trust_ratio) 158 | 159 | p.add_(update, alpha=-group['lr']) 160 | 161 | return loss 162 | -------------------------------------------------------------------------------- /pretrain/utils/lr_control.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from pprint import pformat 9 | 10 | 11 | def lr_wd_annealing(optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it): 12 | wp_it = round(wp_it) 13 | if cur_it < wp_it: 14 | cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it 15 | else: 16 | ratio = (cur_it - wp_it) / (max_it - 1 - wp_it) 17 | cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * ratio)) 18 | 19 | ratio = cur_it / (max_it - 1) 20 | cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * ratio)) 21 | 22 | min_lr, max_lr = cur_lr, cur_lr 23 | min_wd, max_wd = cur_wd, cur_wd 24 | for param_group in optimizer.param_groups: 25 | scaled_lr = param_group['lr'] = cur_lr * param_group.get('lr_scale', 1) # 'lr_scale' could be assigned 26 | min_lr, max_lr = min(min_lr, scaled_lr), max(max_lr, scaled_lr) 27 | scaled_wd = param_group['weight_decay'] = cur_wd * param_group.get('weight_decay_scale', 1) # 'weight_decay_scale' could be assigned 28 | min_wd, max_wd = min(min_wd, scaled_wd), max(max_wd, scaled_wd) 29 | return min_lr, max_lr, min_wd, max_wd 30 | 31 | 32 | def get_param_groups(model, nowd_keys=()): 33 | para_groups, para_groups_dbg = {}, {} 34 | 35 | for name, para in model.named_parameters(): 36 | if not para.requires_grad: 37 | continue # frozen weights 38 | if len(para.shape) == 1 or name.endswith('.bias') or any(k in name for k in nowd_keys): 39 | wd_scale, group_name = 0., 'no_decay' 40 | else: 41 | wd_scale, group_name = 1., 'decay' 42 | 43 | if group_name not in para_groups: 44 | para_groups[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': 1.} 45 | para_groups_dbg[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': 1.} 46 | para_groups[group_name]['params'].append(para) 47 | 
para_groups_dbg[group_name]['params'].append(name) 48 | 49 | for g in para_groups_dbg.values(): 50 | g['params'] = pformat(', '.join(g['params']), width=200) 51 | 52 | print(f'[get_ft_param_groups] param groups = \n{pformat(para_groups_dbg, indent=2, width=250)}\n') 53 | return list(para_groups.values()) 54 | -------------------------------------------------------------------------------- /pretrain/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import datetime 8 | import functools 9 | import os 10 | import subprocess 11 | import sys 12 | import time 13 | from collections import defaultdict, deque 14 | from typing import Iterator 15 | 16 | import numpy as np 17 | import pytz 18 | import torch 19 | from torch.utils.tensorboard import SummaryWriter 20 | 21 | import dist 22 | 23 | os_system = functools.partial(subprocess.call, shell=True) 24 | os_system_get_stdout = lambda cmd: subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8') 25 | def os_system_get_stdout_stderr(cmd): 26 | sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 27 | return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8') 28 | 29 | 30 | def is_pow2n(x): 31 | return x > 0 and (x & (x - 1) == 0) 32 | 33 | 34 | def time_str(for_dirname=False): 35 | return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('%m-%d_%H-%M-%S' if for_dirname else '[%m-%d %H:%M:%S]') 36 | 37 | 38 | def init_distributed_environ(exp_dir): 39 | dist.initialize() 40 | dist.barrier() 41 | 42 | import torch.backends.cudnn as cudnn 43 | cudnn.benchmark = True 44 | cudnn.deterministic = False 45 | 46 | _set_print_only_on_master_proc(is_master=dist.is_local_master()) 47 | if dist.is_local_master() and len(exp_dir): 48 | sys.stdout, sys.stderr = _SyncPrintToFile(exp_dir, stdout=True), _SyncPrintToFile(exp_dir, stdout=False) 49 | 50 | 51 | def _set_print_only_on_master_proc(is_master): 52 | import builtins as __builtin__ 53 | 54 | builtin_print = __builtin__.print 55 | 56 | def prt(msg, *args, **kwargs): 57 | force = kwargs.pop('force', False) 58 | clean = kwargs.pop('clean', False) 59 | deeper = kwargs.pop('deeper', False) 60 | if is_master or force: 61 | if not clean: 62 | f_back = sys._getframe().f_back 63 | if deeper and f_back.f_back is not None: 64 | f_back = f_back.f_back 65 | file_desc = f'{f_back.f_code.co_filename:24s}'[-24:] 66 | msg = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}' 67 | builtin_print(msg, *args, **kwargs) 68 | 69 | __builtin__.print = prt 70 | 71 | 72 | class _SyncPrintToFile(object): 73 | def __init__(self, exp_dir, stdout=True): 74 | self.terminal = sys.stdout if stdout else sys.stderr 75 | fname = os.path.join(exp_dir, 'stdout_backup.txt' if stdout else 'stderr_backup.txt') 76 | self.log = open(fname, 'w') 77 | self.log.flush() 78 | 79 | def write(self, message): 80 | self.terminal.write(message) 81 | self.log.write(message) 82 | self.log.flush() 83 | 84 | def flush(self): 85 | self.terminal.flush() 86 | self.log.flush() 87 | 88 | 89 | class TensorboardLogger(object): 90 | def __init__(self, log_dir, is_master, prefix='pt'): 91 | self.is_master = is_master 92 | self.writer = SummaryWriter(log_dir=log_dir) if self.is_master else None 93 | self.step = 0 94 | self.prefix = 
prefix 95 | self.log_freq = 300 96 | 97 | def set_step(self, step=None): 98 | if step is not None: 99 | self.step = step 100 | else: 101 | self.step += 1 102 | 103 | def get_loggable(self, step=None): 104 | if step is None: # iter wise 105 | step = self.step 106 | loggable = step % self.log_freq == 0 107 | else: # epoch wise 108 | loggable = True 109 | return step, (loggable and self.is_master) 110 | 111 | def update(self, head='scalar', step=None, **kwargs): 112 | step, loggable = self.get_loggable(step) 113 | if loggable: 114 | head = f'{self.prefix}_{head}' 115 | for k, v in kwargs.items(): 116 | if v is None: 117 | continue 118 | if isinstance(v, torch.Tensor): 119 | v = v.item() 120 | assert isinstance(v, (float, int)) 121 | self.writer.add_scalar(head + "/" + k, v, step) 122 | 123 | def log_distribution(self, tag, values, step=None): 124 | step, loggable = self.get_loggable(step) 125 | if loggable: 126 | if not isinstance(values, torch.Tensor): 127 | values = torch.tensor(values) 128 | self.writer.add_histogram(tag=tag, values=values, global_step=step) 129 | 130 | def log_image(self, tag, img, step=None, dataformats='NCHW'): 131 | step, loggable = self.get_loggable(step) 132 | if loggable: 133 | # img = img.cpu().numpy() 134 | self.writer.add_image(tag, img, step, dataformats=dataformats) 135 | 136 | def flush(self): 137 | if self.is_master: self.writer.flush() 138 | 139 | def close(self): 140 | if self.is_master: self.writer.close() 141 | 142 | 143 | def save_checkpoint_with_meta_info_and_opt_state(save_to, args, epoch, performance_desc, model_without_ddp_state, optimizer_state): 144 | checkpoint_path = os.path.join(args.exp_dir, save_to) 145 | if dist.is_local_master(): 146 | to_save = { 147 | 'args': str(args), 148 | 'input_size': args.input_size, 149 | 'arch': args.model, 150 | 'epoch': epoch, 151 | 'performance_desc': performance_desc, 152 | 'module': model_without_ddp_state, 153 | 'optimizer': optimizer_state, 154 | 'is_pretrain': True, 155 | } 156 | torch.save(to_save, checkpoint_path) 157 | 158 | 159 | def save_checkpoint_model_weights_only(save_to, args, sp_cnn_state): 160 | checkpoint_path = os.path.join(args.exp_dir, save_to) 161 | if dist.is_local_master(): 162 | torch.save(sp_cnn_state, checkpoint_path) 163 | 164 | 165 | def initialize_weight(init_weight: str, model_without_ddp): 166 | # use some checkpoint as model weight initialization; ONLY load model weights 167 | if len(init_weight): 168 | checkpoint = torch.load(init_weight, 'cpu') 169 | missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False) 170 | print(f'[initialize_weight] missing_keys={missing}') 171 | print(f'[initialize_weight] unexpected_keys={unexpected}') 172 | 173 | 174 | def load_checkpoint(resume_from: str, model_without_ddp, optimizer): 175 | # resume the experiment from some checkpoint.pth; load model weights, optimizer states, and last epoch 176 | if len(resume_from) == 0: 177 | return 0, '[no performance_desc]' 178 | print(f'[try to resume from file `{resume_from}`]') 179 | checkpoint = torch.load(resume_from, map_location='cpu') 180 | 181 | ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc', '[no performance_desc]') 182 | missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False) 183 | print(f'[load_checkpoint] missing_keys={missing}') 184 | print(f'[load_checkpoint] unexpected_keys={unexpected}') 185 | print(f'[load_checkpoint] ep_start={ep_start}, 
performance_desc={performance_desc}') 186 | 187 | if 'optimizer' in checkpoint: 188 | optimizer.load_state_dict(checkpoint['optimizer']) 189 | return ep_start, performance_desc 190 | 191 | 192 | class SmoothedValue(object): 193 | """Track a series of values and provide access to smoothed values over a 194 | window or the global series average. 195 | """ 196 | 197 | def __init__(self, window_size=20, fmt=None): 198 | if fmt is None: 199 | fmt = "{median:.4f} ({global_avg:.4f})" 200 | self.deque = deque(maxlen=window_size) 201 | self.total = 0.0 202 | self.count = 0 203 | self.fmt = fmt 204 | 205 | def update(self, value, n=1): 206 | self.deque.append(value) 207 | self.count += n 208 | self.total += value * n 209 | 210 | def synchronize_between_processes(self): 211 | """ 212 | Warning: does not synchronize the deque! 213 | """ 214 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 215 | dist.barrier() 216 | dist.allreduce(t) 217 | t = t.tolist() 218 | self.count = int(t[0]) 219 | self.total = t[1] 220 | 221 | @property 222 | def median(self): 223 | d = torch.tensor(list(self.deque)) 224 | return d.median().item() 225 | 226 | @property 227 | def avg(self): 228 | d = torch.tensor(list(self.deque), dtype=torch.float32) 229 | return d.mean().item() 230 | 231 | @property 232 | def global_avg(self): 233 | return self.total / self.count 234 | 235 | @property 236 | def max(self): 237 | return max(self.deque) 238 | 239 | @property 240 | def value(self): 241 | return self.deque[-1] 242 | 243 | def __str__(self): 244 | return self.fmt.format( 245 | median=self.median, 246 | avg=self.avg, 247 | global_avg=self.global_avg, 248 | max=self.max, 249 | value=self.value) 250 | 251 | 252 | class MetricLogger(object): 253 | def __init__(self, delimiter="\t"): 254 | self.meters = defaultdict(SmoothedValue) 255 | self.delimiter = delimiter 256 | 257 | def update(self, **kwargs): 258 | for k, v in kwargs.items(): 259 | if v is None: 260 | continue 261 | if isinstance(v, torch.Tensor): 262 | v = v.item() 263 | assert isinstance(v, (float, int)) 264 | self.meters[k].update(v) 265 | 266 | def __getattr__(self, attr): 267 | if attr in self.meters: 268 | return self.meters[attr] 269 | if attr in self.__dict__: 270 | return self.__dict__[attr] 271 | raise AttributeError("'{}' object has no attribute '{}'".format( 272 | type(self).__name__, attr)) 273 | 274 | def __str__(self): 275 | loss_str = [] 276 | for name, meter in self.meters.items(): 277 | loss_str.append( 278 | "{}: {}".format(name, str(meter)) 279 | ) 280 | return self.delimiter.join(loss_str) 281 | 282 | def synchronize_between_processes(self): 283 | for meter in self.meters.values(): 284 | meter.synchronize_between_processes() 285 | 286 | def add_meter(self, name, meter): 287 | self.meters[name] = meter 288 | 289 | def log_every(self, max_iters, itrt, print_freq, header=None): 290 | print_iters = set(np.linspace(0, max_iters - 1, print_freq, dtype=int).tolist()) 291 | if not header: 292 | header = '' 293 | start_time = time.time() 294 | end = time.time() 295 | self.iter_time = SmoothedValue(fmt='{avg:.4f}') 296 | self.data_time = SmoothedValue(fmt='{avg:.4f}') 297 | space_fmt = ':' + str(len(str(max_iters))) + 'd' 298 | log_msg = [ 299 | header, 300 | '[{0' + space_fmt + '}/{1}]', 301 | 'eta: {eta}', 302 | '{meters}', 303 | 'iter: {time}s', 304 | 'data: {data}s' 305 | ] 306 | log_msg = self.delimiter.join(log_msg) 307 | 308 | if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'): 309 | 
for i in range(max_iters): 310 | obj = next(itrt) 311 | self.data_time.update(time.time() - end) 312 | yield obj 313 | self.iter_time.update(time.time() - end) 314 | if i in print_iters: 315 | eta_seconds = self.iter_time.global_avg * (max_iters - i) 316 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 317 | print(log_msg.format( 318 | i, max_iters, eta=eta_string, 319 | meters=str(self), 320 | time=str(self.iter_time), data=str(self.data_time))) 321 | end = time.time() 322 | else: 323 | for i, obj in enumerate(itrt): 324 | self.data_time.update(time.time() - end) 325 | yield obj 326 | self.iter_time.update(time.time() - end) 327 | if i in print_iters: 328 | eta_seconds = self.iter_time.global_avg * (max_iters - i) 329 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 330 | print(log_msg.format( 331 | i, max_iters, eta=eta_string, 332 | meters=str(self), 333 | time=str(self.iter_time), data=str(self.data_time))) 334 | end = time.time() 335 | 336 | total_time = time.time() - start_time 337 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 338 | print('{} Total time: {} ({:.3f} s / it)'.format( 339 | header, total_time_str, total_time / max_iters)) 340 | -------------------------------------------------------------------------------- /pretrain/viz_imgs/recon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keyu-tian/SparK/a63e386f8e5186bc07ad7fce86e06b08f48a61ea/pretrain/viz_imgs/recon.png -------------------------------------------------------------------------------- /pretrain/viz_imgs/spconv1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keyu-tian/SparK/a63e386f8e5186bc07ad7fce86e06b08f48a61ea/pretrain/viz_imgs/spconv1.png -------------------------------------------------------------------------------- /pretrain/viz_imgs/spconv2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keyu-tian/SparK/a63e386f8e5186bc07ad7fce86e06b08f48a61ea/pretrain/viz_imgs/spconv2.png -------------------------------------------------------------------------------- /pretrain/viz_imgs/spconv3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keyu-tian/SparK/a63e386f8e5186bc07ad7fce86e06b08f48a61ea/pretrain/viz_imgs/spconv3.png --------------------------------------------------------------------------------