├── .gitignore
├── INSTALL.md
├── LICENSE
├── README.md
├── downstream_d2
│   ├── README.md
│   ├── configs
│   │   ├── Base-RCNN-FPN.yaml
│   │   └── coco_R_50_FPN_CONV_1x_moco_adam.yaml
│   ├── convert-timm-to-d2.py
│   ├── lr_decay.py
│   └── train_net.py
├── downstream_imagenet
│   ├── README.md
│   ├── arg.py
│   ├── data.py
│   ├── lr_decay.py
│   ├── main.py
│   ├── mixup.py
│   ├── models
│   │   ├── __init__.py
│   │   └── convnext_official.py
│   ├── requirements.txt
│   └── util.py
├── downstream_mmdet
│   ├── README.md
│   ├── configs
│   │   ├── _base_
│   │   │   ├── default_runtime.py
│   │   │   └── models
│   │   │       ├── cascade_mask_rcnn_convnext_fpn.py
│   │   │       └── mask_rcnn_convnext_fpn.py
│   │   └── convnext_spark
│   │       └── mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py
│   ├── mmcv_custom
│   │   ├── __init__.py
│   │   ├── customized_text.py
│   │   ├── layer_decay_optimizer_constructor.py
│   │   └── runner
│   │       └── checkpoint.py
│   └── mmdet
│       └── models
│           └── backbones
│               ├── __init__.py
│               └── convnext.py
└── pretrain
    ├── README.md
    ├── decoder.py
    ├── dist.py
    ├── encoder.py
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   ├── convnext.py
    │   ├── custom.py
    │   └── resnet.py
    ├── requirements.txt
    ├── sampler.py
    ├── spark.py
    ├── utils
    │   ├── arg_util.py
    │   ├── imagenet.py
    │   ├── lamb.py
    │   ├── lr_control.py
    │   └── misc.py
    ├── viz_imgs
    │   ├── recon.png
    │   ├── spconv1.png
    │   ├── spconv2.png
    │   └── spconv3.png
    ├── viz_reconstruction.ipynb
    └── viz_spconv.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | **/__pycache__/**
3 | .idea/*
4 | ckpt/
5 | *.pth
6 | *.log
7 | *.txt
8 |
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Preparation for pre-training & ImageNet fine-tuning
2 |
3 | ## Pip dependencies
4 |
5 | 1. Prepare a python environment, e.g.:
6 | ```shell script
7 | $ conda create -n spark python=3.8 -y
8 | $ conda activate spark
9 | ```
10 |
11 | 2. Install `PyTorch` and `timm` (better to use `torch~=1.10`, `torchvision~=0.11`, and `timm==0.5.4`) then other python packages:
12 | ```shell script
13 | $ pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
14 | $ pip install timm==0.5.4
15 | $ pip install -r requirements.txt
16 | ```
17 |
18 | It is highly recommended to install these versions to ensure a consistent environment for reproducing our results.
19 |
20 |
21 | ## ImageNet preparation
22 |
23 | Prepare the [ImageNet-1k](http://image-net.org/) dataset
24 | - assume the dataset is in `/path/to/imagenet`
25 | - it should look like this:
26 | ```
27 | /path/to/imagenet/:
28 | train/:
29 | class1:
30 | a_lot_images.jpeg
31 | class2:
32 | a_lot_images.jpeg
33 | val/:
34 | class1:
35 | a_lot_images.jpeg
36 | class2:
37 | a_lot_images.jpeg
38 | ```
39 | - the argument `--data_path=/path/to/imagenet` should be passed to the training script introduced later (a quick sanity check of this layout is sketched below)
40 |
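A quick sanity check of this layout (a minimal sketch; it only assumes `torchvision` is installed and that `/path/to/imagenet` is replaced with your local path):
```python
import torchvision

# both splits should load as ImageFolder datasets and expose the same class set
for split in ('train', 'val'):
    ds = torchvision.datasets.ImageFolder(f'/path/to/imagenet/{split}')
    print(split, len(ds), 'images,', len(ds.classes), 'classes')
```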
41 |
42 | > `PS:` In our implementation, we use PyTorch built-in operators to simulate the submanifold sparse convolution in [encoder.py](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py) for generality (roughly sketched below),
43 | because many convolution operators (e.g., grouped conv and dilated conv) do not yet have efficient sparse implementations on today's hardware.
44 | If you want to try real sparse convolutions, you may refer to the [SparseConvNet](https://github.com/facebookresearch/SparseConvNet) library or [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine).
45 |
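For intuition, this simulation can be as simple as re-applying the binary visibility mask after each dense convolution, so features at masked positions stay zero. Below is a minimal sketch of the idea (illustrative only, with made-up tensor shapes; it is not the exact code in [encoder.py](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py)):
```python
import torch
import torch.nn as nn

conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)

def masked_conv(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # mask: (B, 1, H, W) with 1 = visible position, 0 = masked-out position
    x = conv(x * mask)   # ordinary dense convolution on the masked feature map
    return x * mask      # re-mask the output, mimicking a submanifold sparse conv

x = torch.randn(2, 64, 56, 56)
mask = (torch.rand(2, 1, 56, 56) > 0.6).float()
print(masked_conv(x, mask).shape)  # torch.Size([2, 64, 56, 56])
```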
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Keyu Tian
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/downstream_d2/README.md:
--------------------------------------------------------------------------------
1 | ## About code isolation
2 |
3 | This `downstream_d2` folder is isolated from the pre-training code. One can treat it as an independent codebase 🛠️.
4 |
5 |
6 | ## Fine-tuned ResNet-50 weights, log files, and performance
7 |
8 |
9 |
10 | [[`weights (pre-trained by SparK)`](https://drive.google.com/file/d/1H8605HbxGvrsu4x4rIoNr-Wkd7JkxFPQ/view?usp=share_link)]
11 | [[`weights (fine-tuned on COCO)`](https://drive.google.com/file/d/1Ue7SiQ1E_AwgtYo56Fm-iUlQPZ8vIwYj/view?usp=share_link)]
12 | [[`metrics.json`](https://drive.google.com/file/d/1wfbUWh4svV8sPWya_0PAhsLHVayDQRCi/view?usp=share_link)]
13 | [[`log.txt`](https://drive.google.com/file/d/11zVo_87pe9DMAmfNQK9FUfyjQWHTRKxV/view?usp=share_link)]
14 | [[`tensorboard file`](https://drive.google.com/file/d/1aM1qj8c3-Uka1dZuYmKhgp1lNJpeMDMl/view?usp=share_link)]
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | ## Install [Detectron2 v0.6](https://github.com/facebookresearch/detectron2/releases/tag/v0.6) before fine-tuning ResNet on COCO
23 |
24 |
25 | 1. Prepare a Python environment, e.g.:
26 | ```shell script
27 | $ conda create -n spark python=3.8 -y
28 | $ conda activate spark
29 | ```
30 |
31 | 2. Install `detectron2==0.6` (e.g., with `torch==1.10.0` and `cuda11.3`):
32 | ```shell script
33 | $ pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
34 | ```
35 |
36 | You can also find instructions for different PyTorch/CUDA versions on [this page](https://github.com/facebookresearch/detectron2/releases/tag/v0.6).
37 |
38 |
39 | 3. Put the COCO dataset folder at `downstream_d2/datasets/coco`.
40 | The folder should follow the [directory structure](https://github.com/facebookresearch/detectron2/tree/master/datasets) required by `Detectron2`, which should look like this:
41 | ```
42 | downstream_d2/datasets/coco:
43 | annotations/:
44 | captions_train2017.json captions_val2017.json
45 | instances_train2017.json instances_val2017.json
46 | person_keypoints_train2017.json person_keypoints_val2017.json
47 | train2017/:
48 | a_lot_images.jpg
49 | val2017/:
50 | a_lot_images.jpg
51 | ```
52 |
53 |
54 | ## Training from a pre-trained checkpoint
55 |
56 | The script file for COCO fine-tuning (object detection and instance segmentation) is [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py),
57 | which is a modification of [Detectron2's tools/train_net.py](https://github.com/facebookresearch/detectron2/blob/v0.6/tools/train_net.py).
58 |
59 |
60 | Before fine-tuning a ResNet50 pre-trained by SparK, you should first convert our checkpoint file to a Detectron2-style `.pkl` file:
61 |
62 | ```shell script
63 | $ cd /path/to/SparK/downstream_d2
64 | $ python3 convert-timm-to-d2.py /some/path/to/resnet50_1kpretrained_timm_style.pth d2-style.pkl
65 | ```
66 |
67 | For a ResNet50, you should see a log reporting `len(state)==318`:
68 | ```text
69 | [convert] .pkl is generated! (from `/some/path/to/resnet50_1kpretrained_timm_style.pth`, to `d2-style.pkl`, len(state)==318)
70 | ```
71 |
72 | Then run fine-tuning on a single machine with 8 GPUs:
73 |
74 | ```shell script
75 | $ cd /path/to/SparK/downstream_d2
76 | $ python3 ./train_net.py --resume --num-gpus 8 --config-file ./configs/coco_R_50_FPN_CONV_1x_moco_adam.yaml \
77 | MODEL.WEIGHTS d2-style.pkl \
78 | OUTPUT_DIR <your_output_dir>
79 | ```
80 |
81 | For multiple machines, add these args:
82 | ```shell script
83 | --num-machines <total_machines> --machine-rank <this_machine_rank> --dist-url <url:port>
84 | ```
85 |
86 | In `<your_output_dir>` you'll see the log files generated by `Detectron2`.
87 |
88 |
89 | ## Details: how we modify the official Detectron2's [tools/train_net.py](https://github.com/facebookresearch/detectron2/blob/v0.6/tools/train_net.py) to get our [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py)
90 |
91 | 1. We add two new hyperparameters:
92 | - str `SOLVER.OPTIMIZER`: use the 'ADAM' (same as 'ADAMW') or 'SGD' optimizer
93 | - float `SOLVER.LR_DECAY`: the decay ratio (from 0. to 1.) for the layer-wise learning rate decay trick
94 |
95 | 2. We implement layer-wise lr decay in [downstream_d2/lr_decay.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/lr_decay.py) (a small sketch of the resulting per-layer multipliers is given at the end of this file).
96 |
97 | 3. We write a script to convert our timm-style pre-trained ResNet weights to Detectron2-style in [downstream_d2/convert-timm-to-d2.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/convert-timm-to-d2.py).
98 |
99 | 4. We also add a hook for logging results to `cfg.OUTPUT_DIR/d2_coco_log.txt`.
100 |
101 | All of our modifications to the original are commented with `# [modification] ...` in [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py) or other files.
102 |
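As a concrete illustration of `SOLVER.LR_DECAY`: for ResNet-50, [lr_decay.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/lr_decay.py) groups the parameters into 7 depth levels (stem = 0, ..., heads = 6) and multiplies the learning rate of level `i` by `LR_DECAY ** (6 - i)`. A tiny sketch of the multipliers implied by `LR_DECAY: 0.6` (just that formula evaluated, nothing new):
```python
decay, num_groups = 0.6, 7  # ResNet-50: stem, res2, res3, res4 (split in two), res5, heads
for layer_id in range(num_groups):
    print(f'group {layer_id}: lr multiplier = {decay ** (num_groups - 1 - layer_id):.4f}')
# group 0 (stem) trains with ~0.047x the base lr, group 6 (heads) with 1.0x
```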
--------------------------------------------------------------------------------
/downstream_d2/configs/Base-RCNN-FPN.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "GeneralizedRCNN"
3 | BACKBONE:
4 | NAME: "build_resnet_fpn_backbone"
5 | RESNETS:
6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"]
7 | FPN:
8 | IN_FEATURES: ["res2", "res3", "res4", "res5"]
9 | ANCHOR_GENERATOR:
10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
12 | RPN:
13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level
16 | # Detectron1 uses 2000 proposals per-batch,
17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
19 | POST_NMS_TOPK_TRAIN: 1000
20 | POST_NMS_TOPK_TEST: 1000
21 | ROI_HEADS:
22 | NAME: "StandardROIHeads"
23 | IN_FEATURES: ["p2", "p3", "p4", "p5"]
24 | ROI_BOX_HEAD:
25 | NAME: "FastRCNNConvFCHead"
26 | NUM_FC: 2
27 | POOLER_RESOLUTION: 7
28 | ROI_MASK_HEAD:
29 | NAME: "MaskRCNNConvUpsampleHead"
30 | NUM_CONV: 4
31 | POOLER_RESOLUTION: 14
32 | DATASETS:
33 | TRAIN: ("coco_2017_train",)
34 | TEST: ("coco_2017_val",)
35 | SOLVER:
36 | IMS_PER_BATCH: 16
37 | BASE_LR: 0.02
38 | STEPS: (60000, 80000)
39 | MAX_ITER: 90000
40 | INPUT:
41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
42 | VERSION: 2
43 |
--------------------------------------------------------------------------------
/downstream_d2/configs/coco_R_50_FPN_CONV_1x_moco_adam.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-FPN.yaml"
2 | MODEL:
3 | WEIGHTS: ""
4 | PIXEL_MEAN: [123.675, 116.280, 103.530]
5 | PIXEL_STD: [58.395, 57.120, 57.375]
6 |
7 | MASK_ON: True
8 | BACKBONE:
9 | FREEZE_AT: 0
10 | RESNETS:
11 | DEPTH: 50
12 | NORM: "SyncBN"
13 | STRIDE_IN_1X1: False
14 | FPN:
15 | NORM: "SyncBN"
16 | ROI_BOX_HEAD:
17 | NAME: "FastRCNNConvFCHead"
18 | NUM_FC: 1
19 | NUM_CONV: 4
20 | POOLER_RESOLUTION: 7
21 | NORM: "SyncBN"
22 | ROI_MASK_HEAD:
23 | NAME: "MaskRCNNConvUpsampleHead"
24 | NUM_CONV: 4
25 | POOLER_RESOLUTION: 14
26 | NORM: "SyncBN"
27 |
28 | INPUT:
29 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896)
30 | CROP:
31 | ENABLED: False
32 | TYPE: "absolute_range"
33 | SIZE: (384, 600)
34 | FORMAT: "RGB"
35 | TEST:
36 | EVAL_PERIOD: 5000
37 | PRECISE_BN:
38 | ENABLED: True
39 |
40 | SOLVER:
41 | STEPS: (60000, 80000)
42 | MAX_ITER: 90000
43 | GAMMA: 0.25
44 | BASE_LR: 0.00025
45 | WARMUP_FACTOR: 0.01
46 | WARMUP_ITERS: 1000
47 | WEIGHT_DECAY: 0.0001
48 | CHECKPOINT_PERIOD: 5000
49 | CLIP_GRADIENTS:
50 | ENABLED: False
51 | CLIP_TYPE: "value"
52 | CLIP_VALUE: 1.0
53 | NORM_TYPE: 2.0
54 |
55 | # compared to standard detectron2, we add these two new configurations:
56 | OPTIMIZER: "ADAMW"
57 | LR_DECAY: 0.6
58 |
--------------------------------------------------------------------------------
/downstream_d2/convert-timm-to-d2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 |
3 | # Copyright (c) ByteDance, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import pickle as pkl
10 |
11 | import torch
12 |
13 |
14 | # we use `timm.models.ResNet` in pre-training, so keys are timm-style
15 | def timm_resnet_to_detectron2_resnet(source_file, target_file):
16 | pretrained: dict = torch.load(source_file, map_location='cpu')
17 | for mod_k in {'state_dict', 'state', 'module', 'model'}:
18 | if mod_k in pretrained:
19 | pretrained = pretrained[mod_k]
20 | if any(k.startswith('module.encoder_q.') for k in pretrained.keys()):
21 | pretrained = {k.replace('module.encoder_q.', ''): v for k, v in pretrained.items() if k.startswith('module.encoder_q.')}
22 |
23 | pkl_state = {}
24 | for k, v in pretrained.items(): # convert resnet's keys from timm-style to d2-style
25 | if 'layer' not in k:
26 | k = 'stem.' + k
27 | for t in [1, 2, 3, 4]:
28 | k = k.replace(f'layer{t}', f'res{t+1}')
29 | for t in [1, 2, 3]:
30 | k = k.replace(f'bn{t}', f'conv{t}.norm')
31 | k = k.replace('downsample.0', 'shortcut')
32 | k = k.replace('downsample.1', 'shortcut.norm')
33 |
34 | pkl_state[k] = v.detach().numpy()
35 |
36 | with open(target_file, 'wb') as fp:
37 | print(f'[convert] .pkl is generated! (from `{source_file}`, to `{target_file}`, len(state)=={len(pkl_state)})')
38 | pkl.dump({'model': pkl_state, '__author__': 'https://github.com/keyu-tian/SparK', 'matching_heuristics': True}, fp)
39 |
40 |
41 | if __name__ == '__main__':
42 | import sys
43 | timm_resnet_to_detectron2_resnet(sys.argv[1], sys.argv[2])
44 |
--------------------------------------------------------------------------------
/downstream_d2/lr_decay.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Set, Optional, Callable, Any
2 | import torch
3 | import copy
4 |
5 | from detectron2.solver.build import reduce_param_groups
6 |
7 |
8 | def lr_factor_func(para_name: str, is_resnet50, dec: float, debug=False) -> float:
9 | if dec == 0:
10 | dec = 1.
11 |
12 | N = 5 if is_resnet50 else 11
13 | if '.stem.' in para_name:
14 | layer_id = 0
15 | elif '.res' in para_name:
16 | ls = para_name.split('.res')[1].split('.')
17 | if ls[0].isnumeric() and ls[1].isnumeric():
18 | stage_id, block_id = int(ls[0]), int(ls[1])
19 | if stage_id == 2: # res2
20 | layer_id = 1
21 | elif stage_id == 3: # res3
22 | layer_id = 2
23 | elif stage_id == 4: # res4
24 | layer_id = 3 + block_id // 3 # 3, 4 or 4, 5
25 | else: # res5
26 | layer_id = N
27 | else:
28 | assert para_name.startswith('roi_heads.res5.norm.')
29 | layer_id = N + 1 # roi_heads.res5.norm.weight and roi_heads.res5.norm.bias of C4
30 | else:
31 | layer_id = N + 1
32 |
33 | exp = N + 1 - layer_id
34 | return f'{dec:g} ** {exp}' if debug else dec ** exp
35 |
36 |
37 | # [modification] see: https://github.com/facebookresearch/detectron2/blob/v0.6/detectron2/solver/build.py#L134
38 | # add the `lr_factor_func` to implement lr decay
39 | def get_default_optimizer_params(
40 | model: torch.nn.Module,
41 | base_lr: Optional[float] = None,
42 | weight_decay: Optional[float] = None,
43 | weight_decay_norm: Optional[float] = None,
44 | bias_lr_factor: Optional[float] = 1.0,
45 | weight_decay_bias: Optional[float] = None,
46 | lr_factor_func: Optional[Callable] = None,
47 | overrides: Optional[Dict[str, Dict[str, float]]] = None,
48 | ) -> List[Dict[str, Any]]:
49 | """
50 | Get default param list for optimizer, with support for a few types of
51 | overrides. If no overrides needed, this is equivalent to `model.parameters()`.
52 |
53 | Args:
54 | base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
55 | weight_decay: weight decay for every group by default. Can be omitted to use the one
56 | in optimizer.
57 | weight_decay_norm: override weight decay for params in normalization layers
58 | bias_lr_factor: multiplier of lr for bias parameters.
59 | weight_decay_bias: override weight decay for bias parameters.
60 | lr_factor_func: function to calculate lr decay rate by mapping the parameter names to
61 | corresponding lr decay rate. Note that setting this option requires
62 | also setting ``base_lr``.
63 | overrides: if not `None`, provides values for optimizer hyperparameters
64 | (LR, weight decay) for module parameters with a given name; e.g.
65 | ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
66 | weight decay values for all module parameters named `embedding`.
67 |
68 | For common detection models, ``weight_decay_norm`` is the only option
69 | needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
70 | from Detectron1 that are not found useful.
71 |
72 | Example:
73 | ::
74 | torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
75 | lr=0.01, weight_decay=1e-4, momentum=0.9)
76 | """
77 | if overrides is None:
78 | overrides = {}
79 | defaults = {}
80 | if base_lr is not None:
81 | defaults["lr"] = base_lr
82 | if weight_decay is not None:
83 | defaults["weight_decay"] = weight_decay
84 | bias_overrides = {}
85 | if bias_lr_factor is not None and bias_lr_factor != 1.0:
86 | # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
87 | # exactly the same as regular weights.
88 | if base_lr is None:
89 | raise ValueError("bias_lr_factor requires base_lr")
90 | bias_overrides["lr"] = base_lr * bias_lr_factor
91 | if weight_decay_bias is not None:
92 | bias_overrides["weight_decay"] = weight_decay_bias
93 | if len(bias_overrides):
94 | if "bias" in overrides:
95 | raise ValueError("Conflicting overrides for 'bias'")
96 | overrides["bias"] = bias_overrides
97 | if lr_factor_func is not None:
98 | if base_lr is None:
99 | raise ValueError("lr_factor_func requires base_lr")
100 | norm_module_types = (
101 | torch.nn.BatchNorm1d,
102 | torch.nn.BatchNorm2d,
103 | torch.nn.BatchNorm3d,
104 | torch.nn.SyncBatchNorm,
105 | # NaiveSyncBatchNorm inherits from BatchNorm2d
106 | torch.nn.GroupNorm,
107 | torch.nn.InstanceNorm1d,
108 | torch.nn.InstanceNorm2d,
109 | torch.nn.InstanceNorm3d,
110 | torch.nn.LayerNorm,
111 | torch.nn.LocalResponseNorm,
112 | )
113 | params: List[Dict[str, Any]] = []
114 | memo: Set[torch.nn.parameter.Parameter] = set()
115 | for module_name, module in model.named_modules():
116 | for module_param_name, value in module.named_parameters(recurse=False):
117 | if not value.requires_grad:
118 | continue
119 | # Avoid duplicating parameters
120 | if value in memo:
121 | continue
122 | memo.add(value)
123 |
124 | hyperparams = copy.copy(defaults)
125 | if isinstance(module, norm_module_types) and weight_decay_norm is not None:
126 | hyperparams["weight_decay"] = weight_decay_norm
127 | if lr_factor_func is not None:
128 | hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")
129 |
130 | hyperparams.update(overrides.get(module_param_name, {}))
131 | params.append({"params": [value], **hyperparams})
132 | return reduce_param_groups(params)
133 |
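
# ---- illustrative usage (not part of the original file) ----
# A quick sanity check of the per-parameter lr factors produced by `lr_factor_func`;
# the parameter names below are hypothetical but follow Detectron2's ResNet-FPN naming style.
if __name__ == '__main__':
    for name in ('backbone.bottom_up.stem.conv1.weight',
                 'backbone.bottom_up.res2.0.conv1.weight',
                 'backbone.bottom_up.res4.5.conv3.weight',
                 'backbone.bottom_up.res5.2.conv3.weight',
                 'roi_heads.box_head.fc1.weight'):
        print(name, '->', lr_factor_func(name, is_resnet50=True, dec=0.6, debug=True))
    # prints '0.6 ** 6' (stem) down to '0.6 ** 0' (roi heads)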
--------------------------------------------------------------------------------
/downstream_imagenet/README.md:
--------------------------------------------------------------------------------
1 | ## About code isolation
2 |
3 | This `downstream_imagenet` folder is isolated from the pre-training code. One can treat it as an independent codebase 🛠️.
4 |
5 |
6 | ## Preparation for ImageNet-1k fine-tuning
7 |
8 | See [INSTALL.md](https://github.com/keyu-tian/SparK/blob/main/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
9 |
10 | **Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
11 |
12 |
13 | ## Fine-tuning on ImageNet-1k from pre-trained weights
14 |
15 | Run [/downstream_imagenet/main.py](/downstream_imagenet/main.py) via `torchrun`.
16 | **You are required to specify** the ImageNet data folder (`--data_path`), your experiment name & log dir (`--exp_name` and `--exp_dir`, created automatically if they do not exist), the model name (`--model`; for valid choices see the keys of `HP_DEFAULT_VALUES` in [/downstream_imagenet/arg.py line14](/downstream_imagenet/arg.py#L14)), and the pre-trained weight file (`--resume_from`) to run fine-tuning.
17 |
18 | All the other configurations have default values, listed in [/downstream_imagenet/arg.py#L13](/downstream_imagenet/arg.py#L13).
19 | You can override any of them by passing, e.g., `--bs=1024`.
20 |
21 |
22 | Here is an example of fine-tuning a ConvNeXt-Small on a single machine with 8 GPUs:
23 | ```shell script
24 | $ cd /path/to/SparK/downstream_imagenet
25 | $ torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=<port> main.py \
26 | --data_path=/path/to/imagenet --exp_name=<experiment_name> --exp_dir=/path/to/logdir \
27 | --model=convnext_small --resume_from=/some/path/to/convnextS_1kpretrained_official_style.pth
28 | ```
29 |
30 | For multiple machines, change the `--nnodes` and `--master_addr` to your configurations. E.g.:
31 | ```shell script
32 | $ torchrun --nproc_per_node=8 --nnodes=<num_nodes> --node_rank=<node_rank> --master_addr=<master_address> --master_port=<port> main.py \
33 | ...
34 | ```
35 |
36 |
37 | ## Logging
38 |
39 | See files under `--exp_dir` to track your experiment:
40 |
41 | - `<model>_1kfinetuned_last.pth`: the latest model weights
42 | - `<model>_1kfinetuned_best.pth`: model weights with the highest accuracy
43 | - `<model>_1kfinetuned_best_ema.pth`: EMA weights with the highest accuracy
44 | - `finetune_log.txt`: records some important information, such as:
45 |     - `git_commit_id`: the git version
46 |     - `cmd`: all arguments passed to the script
47 |
48 |     It also reports the training loss/acc, the best evaluation acc, and the remaining time at each epoch.
49 |
50 | - `tensorboard_log/`: a lot of tensorboard logs; you can visualize accuracies, loss values, learning rates, gradient norms, and more via `tensorboard --logdir /path/to/this/tensorboard_log/ --port 23333`.
51 |
52 | ## Resuming
53 |
54 | Use `--resume_from` again, e.g., `--resume_from=path/to/<model>_1kfinetuned_last.pth`.
55 |
--------------------------------------------------------------------------------
/downstream_imagenet/arg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import json
8 | import os
9 | import sys
10 |
11 | from tap import Tap
12 |
13 | HP_DEFAULT_NAMES = ['bs', 'ep', 'wp_ep', 'opt', 'base_lr', 'lr_scale', 'wd', 'mixup', 'rep_aug', 'drop_path', 'ema']
14 | HP_DEFAULT_VALUES = {
15 | 'convnext_small': (4096, 400, 20, 'adam', 0.0002, 0.7, 0.01, 0.8, 3, 0.3, 0.9999),
16 | 'convnext_base': (4096, 400, 20, 'adam', 0.0001, 0.7, 0.01, 0.8, 3, 0.4, 0.9999),
17 | 'convnext_large': (4096, 200, 10, 'adam', 0.0001, 0.7, 0.02, 0.8, 3, 0.5, 0.9999),
18 | 'convnext_large_384': (1024, 200, 20, 'adam', 0.00006, 0.7, 0.01, 0.8, 3, 0.5, 0.99995),
19 |
20 | 'resnet50': (2048, 300, 5, 'lamb', 0.002, 0.7, 0.02, 0.1, 0, 0.05, 0.9999),
21 | 'resnet101': (2048, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
22 | 'resnet152': (2048, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
23 | 'resnet200': (2048, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
24 | }
25 |
26 |
27 | class FineTuneArgs(Tap):
28 | # environment
29 | exp_name: str
30 | exp_dir: str
31 | data_path: str
32 | model: str
33 | resume_from: str = '' # resume from some checkpoint.pth
34 |
35 | img_size: int = 224
36 | dataloader_workers: int = 8
37 |
38 | # ImageNet classification fine-tuning hyperparameters; see `HP_DEFAULT_VALUES` above for detailed default values
39 | # - batch size, epoch
40 | bs: int = 0 # global batch size (== batch_size_per_gpu * num_gpus)
41 | ep: int = 0 # number of epochs
42 | wp_ep: int = 0 # epochs for warmup
43 |
44 | # - optimization
45 | opt: str = '' # optimizer; 'adam' or 'lamb'
46 | base_lr: float = 0. # peak lr == base_lr * glb_batch_size / 256 (linear lr scaling rule)
47 | lr_scale: float = 0. # see file `lr_decay.py` for more details
48 | clip: int = -1 # use gradient clipping if clip > 0
49 |
50 | # - regularization tricks
51 | wd: float = 0. # weight decay
52 | mixup: float = 0. # use mixup if mixup > 0
53 | rep_aug: int = 0 # use repeated augmentation if rep_aug > 0
54 | drop_path: float = 0. # drop_path ratio
55 |
56 | # - other tricks
57 | ema: float = 0. # use EMA if ema > 0
58 | sbn: bool = True # use SyncBatchNorm
59 |
60 | # NO NEED TO SPECIFY; each of these args will be updated automatically at runtime
61 | lr: float = None
62 | batch_size_per_gpu: int = 0
63 | glb_batch_size: int = 0
64 | device: str = 'cpu'
65 | world_size: int = 1
66 | global_rank: int = 0
67 | local_rank: int = 0 # we DO USE this arg
68 | is_master: bool = False
69 | is_local_master: bool = False
70 | cmd: str = ' '.join(sys.argv[1:])
71 | commit_id: str = os.popen(f'git rev-parse HEAD').read().strip()
72 | commit_msg: str = os.popen(f'git log -1').read().strip().splitlines()[-1].strip()
73 | log_txt_name: str = '{args.exp_dir}/pretrain_log.txt'
74 | tb_lg_dir: str = '' # tensorboard log directory
75 |
76 | train_loss: float = 0.
77 | train_acc: float = 0.
78 | best_val_acc: float = 0.
79 | cur_ep: str = ''
80 | remain_time: str = ''
81 | finish_time: str = ''
82 | first_logging: bool = True
83 |
84 | def log_epoch(self):
85 | if not self.is_local_master:
86 | return
87 |
88 | if self.first_logging:
89 | self.first_logging = False
90 | with open(self.log_txt_name, 'w') as fp:
91 | json.dump({
92 | 'name': self.exp_name, 'cmd': self.cmd, 'git_commit_id': self.commit_id, 'git_commit_msg': self.commit_msg,
93 | 'model': self.model,
94 | }, fp)
95 | fp.write('\n\n')
96 |
97 | with open(self.log_txt_name, 'a') as fp:
98 | json.dump({
99 | 'cur_ep': self.cur_ep,
100 | 'train_L': self.train_loss, 'train_acc': self.train_acc,
101 | 'best_val_acc': self.best_val_acc,
102 | 'rema': self.remain_time, 'fini': self.finish_time,
103 | }, fp)
104 | fp.write('\n')
105 |
106 |
107 | def get_args(world_size, global_rank, local_rank, device) -> FineTuneArgs:
108 | # parse args and prepare directories
109 | args = FineTuneArgs(explicit_bool=True).parse_args()
110 | d_name, b_name = os.path.dirname(os.path.abspath(args.exp_dir)), os.path.basename(os.path.abspath(args.exp_dir))
111 | b_name = ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in b_name)
112 | args.exp_dir = os.path.join(d_name, b_name)
113 | os.makedirs(args.exp_dir, exist_ok=True)
114 | args.log_txt_name = os.path.join(args.exp_dir, 'finetune_log.txt')
115 |
116 | args.tb_lg_dir = args.tb_lg_dir or os.path.join(args.exp_dir, 'tensorboard_log')
117 | try: os.makedirs(args.tb_lg_dir, exist_ok=True)
118 | except: pass
119 |
120 | # fill in args.bs, args.ep, etc. with their default values (if their values are not explicitly specified, i.e., if bool(they) == False)
121 | if args.model == 'convnext_large' and args.img_size == 384:
122 | default_values = HP_DEFAULT_VALUES['convnext_large_384']
123 | else:
124 | default_values = HP_DEFAULT_VALUES[args.model]
125 | for k, v in zip(HP_DEFAULT_NAMES, default_values):
126 | if not getattr(args, k):
127 | setattr(args, k, v)
128 |
129 | # update other runtime args
130 | args.world_size, args.global_rank, args.local_rank, args.device = world_size, global_rank, local_rank, device
131 | args.is_master = global_rank == 0
132 | args.is_local_master = local_rank == 0
133 | args.batch_size_per_gpu = args.bs // world_size
134 | args.glb_batch_size = args.batch_size_per_gpu * world_size
135 | args.lr = args.base_lr * args.glb_batch_size / 256
136 |
137 | return args
138 |
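
# ---- illustrative note (not part of the original file) ----
# The peak lr computed above follows a linear scaling rule; e.g., with the
# 'convnext_small' defaults (base_lr=0.0002) and a global batch size of 4096:
#     lr = 0.0002 * 4096 / 256 = 0.0032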
--------------------------------------------------------------------------------
/downstream_imagenet/data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import random
9 | import time
10 |
11 | import PIL.Image as PImage
12 | import numpy as np
13 | import torch
14 | import torchvision
15 | from timm.data import AutoAugment as TimmAutoAugment
16 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform
17 | from timm.data.distributed_sampler import RepeatAugSampler
18 | from timm.data.transforms_factory import transforms_imagenet_eval
19 | from torch.utils.data import DataLoader
20 | from torch.utils.data.sampler import Sampler
21 | from torchvision.transforms import AutoAugment as TorchAutoAugment
22 | from torchvision.transforms import transforms, TrivialAugmentWide
23 |
24 | try:
25 | from torchvision.transforms import InterpolationMode
26 | interpolation = InterpolationMode.BICUBIC
27 | except:
28 | import PIL
29 | interpolation = PIL.Image.BICUBIC
30 |
31 |
32 | def create_classification_dataset(data_path, img_size, rep_aug, workers, batch_size_per_gpu, world_size, global_rank):
33 | import warnings
34 | warnings.filterwarnings('ignore', category=UserWarning)
35 |
36 | mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
37 | trans_train = create_transform(
38 | is_training=True, input_size=img_size,
39 | auto_augment='v0', interpolation='bicubic', re_prob=0.25, re_mode='pixel', re_count=1,
40 | mean=mean, std=std,
41 | )
42 | if img_size < 384:
43 | for i, t in enumerate(trans_train.transforms):
44 | if isinstance(t, (TorchAutoAugment, TimmAutoAugment)):
45 | trans_train.transforms[i] = TrivialAugmentWide(interpolation=interpolation)
46 | break
47 | trans_val = transforms_imagenet_eval(img_size=img_size, interpolation='bicubic', crop_pct=0.95, mean=mean, std=std)
48 | else:
49 | trans_val = transforms.Compose([
50 | transforms.Resize((img_size, img_size), interpolation=interpolation),
51 | transforms.ToTensor(), transforms.Normalize(mean=mean, std=std),
52 | ])
53 | print_transform(trans_train, '[train]')
54 | print_transform(trans_val, '[val]')
55 |
56 | imagenet_folder = os.path.abspath(data_path)
57 | for postfix in ('train', 'val'):
58 | if imagenet_folder.endswith(postfix):
59 | imagenet_folder = imagenet_folder[:-len(postfix)]
60 | dataset_train = torchvision.datasets.ImageFolder(os.path.join(imagenet_folder, 'train'), trans_train)
61 | dataset_val = torchvision.datasets.ImageFolder(os.path.join(imagenet_folder, 'val'), trans_val)
62 |
63 | if rep_aug:
64 | print(f'[dataset] using repeated augmentation: count={rep_aug}')
65 | train_sp = RepeatAugSampler(dataset_train, shuffle=True, num_repeats=rep_aug)
66 | else:
67 | train_sp = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True, drop_last=True)
68 |
69 | loader_train = DataLoader(
70 | dataset=dataset_train, num_workers=workers, pin_memory=True,
71 | batch_size=batch_size_per_gpu, sampler=train_sp, persistent_workers=workers > 0,
72 | worker_init_fn=worker_init_fn,
73 | )
74 | iters_train = len(loader_train)
75 | print(f'[dataset: train] bs={world_size}x{batch_size_per_gpu}={world_size * batch_size_per_gpu}, num_iters={iters_train}')
76 |
77 | val_ratio = 2
78 | loader_val = DataLoader(
79 | dataset=dataset_val, num_workers=workers, pin_memory=True,
80 | batch_sampler=DistInfiniteBatchSampler(world_size, global_rank, len(dataset_val), glb_batch_size=val_ratio * batch_size_per_gpu, filling=False, shuffle=False),
81 | worker_init_fn=worker_init_fn,
82 | )
83 | iters_val = len(loader_val)
84 | print(f'[dataset: val] bs={world_size}x{val_ratio * batch_size_per_gpu}={val_ratio * world_size * batch_size_per_gpu}, num_iters={iters_val}')
85 |
86 | time.sleep(3)
87 | warnings.resetwarnings()
88 | return loader_train, iters_train, iter(loader_val), iters_val
89 |
90 |
91 | def worker_init_fn(worker_id):
92 | # see: https://pytorch.org/docs/stable/notes/randomness.html#dataloader
93 | worker_seed = torch.initial_seed() % 2 ** 32
94 | np.random.seed(worker_seed)
95 | random.seed(worker_seed)
96 |
97 |
98 | def print_transform(transform, s):
99 | print(f'Transform {s} = ')
100 | for t in transform.transforms:
101 | print(t)
102 | print('---------------------------\n')
103 |
104 |
105 | class DistInfiniteBatchSampler(Sampler):
106 | def __init__(self, world_size, global_rank, dataset_len, glb_batch_size, seed=0, filling=False, shuffle=True):
107 | assert glb_batch_size % world_size == 0
108 | self.world_size, self.rank = world_size, global_rank
109 | self.dataset_len = dataset_len
110 | self.glb_batch_size = glb_batch_size
111 | self.batch_size = glb_batch_size // world_size
112 |
113 | self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
114 | self.filling = filling
115 | self.shuffle = shuffle
116 | self.epoch = 0
117 | self.seed = seed
118 | self.indices = self.gener_indices()
119 |
120 | def gener_indices(self):
121 | global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0
122 | if self.shuffle:
123 | g = torch.Generator()
124 | g.manual_seed(self.epoch + self.seed)
125 | global_indices = torch.randperm(self.dataset_len, generator=g)
126 | else:
127 | global_indices = torch.arange(self.dataset_len)
128 | filling = global_max_p - global_indices.shape[0]
129 | if filling > 0 and self.filling:
130 | global_indices = torch.cat((global_indices, global_indices[:filling]))
131 | global_indices = tuple(global_indices.numpy().tolist())
132 |
133 | seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
134 | local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
135 | self.max_p = len(local_indices)
136 | return local_indices
137 |
138 | def __iter__(self):
139 | self.epoch = 0
140 | while True:
141 | self.epoch += 1
142 | p, q = 0, 0
143 | while p < self.max_p:
144 | q = p + self.batch_size
145 | yield self.indices[p:q]
146 | p = q
147 | if self.shuffle:
148 | self.indices = self.gener_indices()
149 |
150 | def __len__(self):
151 | return self.iters_per_ep
152 |
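
# ---- illustrative usage (not part of the original file) ----
# How the sampler splits a tiny dataset across ranks (hypothetical numbers:
# 10 samples, 2 ranks, global batch size 4 -> 3 iterations per epoch):
if __name__ == '__main__':
    for rank in (0, 1):
        sampler = DistInfiniteBatchSampler(world_size=2, global_rank=rank, dataset_len=10,
                                           glb_batch_size=4, shuffle=False, filling=False)
        batches = iter(sampler)
        print(f'rank {rank}:', [next(batches) for _ in range(len(sampler))])
    # rank 0: [(0, 1), (2, 3), (4,)]
    # rank 1: [(5, 6), (7, 8), (9,)]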
--------------------------------------------------------------------------------
/downstream_imagenet/lr_decay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from pprint import pformat
9 |
10 |
11 | def lr_wd_annealing(optimizer, peak_lr, wd, cur_it, wp_it, max_it):
12 | wp_it = round(wp_it)
13 | if cur_it < wp_it:
14 | cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it
15 | else:
16 | ratio = (cur_it - wp_it) / (max_it - 1 - wp_it)
17 | cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * ratio))
18 |
19 | min_lr, max_lr = cur_lr, cur_lr
20 | min_wd, max_wd = wd, wd
21 | for param_group in optimizer.param_groups:
22 | scaled_lr = param_group['lr'] = cur_lr * param_group.get('lr_scale', 1) # 'lr_scale' could be assigned
23 | min_lr, max_lr = min(min_lr, scaled_lr), max(max_lr, scaled_lr)
24 | scaled_wd = param_group['weight_decay'] = wd * param_group.get('weight_decay_scale', 1) # 'weight_decay_scale' could be assigned
25 | min_wd, max_wd = min(min_wd, scaled_wd), max(max_wd, scaled_wd)
26 | return min_lr, max_lr, min_wd, max_wd
27 |
28 |
29 | def get_param_groups(model, nowd_keys=(), lr_scale=0.0):
30 | using_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0.0 < lr_scale < 1.0
31 | print(f'[get_ft_param_groups][lr decay] using_lr_scale={using_lr_scale}, ft_lr_scale={lr_scale}')
32 | para_groups, para_groups_dbg = {}, {}
33 |
34 | for name, para in model.named_parameters():
35 | if not para.requires_grad:
36 | continue # frozen weights
37 | if len(para.shape) == 1 or name.endswith('.bias') or any(k in name for k in nowd_keys):
38 | wd_scale, group_name = 0., 'no_decay'
39 | else:
40 | wd_scale, group_name = 1., 'decay'
41 |
42 | if using_lr_scale:
43 | layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
44 | group_name = f'layer{layer_id}_' + group_name
45 | this_lr_scale = lr_scale ** scale_exp
46 | dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]'
47 | else:
48 | this_lr_scale = 1
49 | dbg = f'[no scale]'
50 |
51 | if group_name not in para_groups:
52 | para_groups[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': this_lr_scale}
53 | para_groups_dbg[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': dbg}
54 | para_groups[group_name]['params'].append(para)
55 | para_groups_dbg[group_name]['params'].append(name)
56 |
57 | for g in para_groups_dbg.values():
58 | g['params'] = pformat(', '.join(g['params']), width=200)
59 |
60 | print(f'[get_ft_param_groups] param groups = \n{pformat(para_groups_dbg, indent=2, width=250)}\n')
61 | return list(para_groups.values())
62 |
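
# ---- illustrative usage (not part of the original file) ----
# A rough sketch of how these two helpers fit together (hypothetical model and
# hyperparameters; see main.py for how this repo actually calls lr_wd_annealing):
if __name__ == '__main__':
    import torch, torchvision
    model = torchvision.models.resnet18(num_classes=10)  # no `get_layer_id_and_scale_exp` -> no per-layer lr scaling
    opt = torch.optim.AdamW(get_param_groups(model), lr=1e-3, weight_decay=0.05)
    max_it, wp_it = 1000, 100
    for cur_it in range(max_it):  # warm up for wp_it iterations, then cosine-anneal lr/wd
        lr_wd_annealing(opt, peak_lr=1e-3, wd=0.05, cur_it=cur_it, wp_it=wp_it, max_it=max_it)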
--------------------------------------------------------------------------------
/downstream_imagenet/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import datetime
8 | import time
9 |
10 | import torch
11 | import torch.distributed as tdist
12 | from timm.utils import ModelEmaV2
13 | from torch.utils.tensorboard import SummaryWriter
14 |
15 | from arg import get_args, FineTuneArgs
16 | from models import ConvNeXt, ResNet
17 | __for_timm_registration = ConvNeXt, ResNet
18 | from lr_decay import lr_wd_annealing
19 | from util import init_distributed_environ, create_model_opt, load_checkpoint, save_checkpoint
20 | from data import create_classification_dataset
21 |
22 |
23 | def main_ft():
24 | world_size, global_rank, local_rank, device = init_distributed_environ()
25 | args: FineTuneArgs = get_args(world_size, global_rank, local_rank, device)
26 | print(f'initial args:\n{str(args)}')
27 | args.log_epoch()
28 |
29 | criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer = create_model_opt(args)
30 | ep_start, performance_desc = load_checkpoint(args.resume_from, model_without_ddp, model_ema, optimizer)
31 |
32 | if ep_start >= args.ep: # load from a complete checkpoint file
33 | print(f' [*] [FT already done] Max/Last Acc: {performance_desc}')
34 | else:
35 | tb_lg = SummaryWriter(args.tb_lg_dir) if args.is_master else None
36 | loader_train, iters_train, iterator_val, iters_val = create_classification_dataset(
37 | args.data_path, args.img_size, args.rep_aug,
38 | args.dataloader_workers, args.batch_size_per_gpu, args.world_size, args.global_rank
39 | )
40 |
41 | # train & eval
42 | tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model)
43 | max_acc = last_acc
44 | max_acc_e = last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module)[-1]
45 | print(f'[fine-tune] initial acc={last_acc:.2f}, ema={last_acc_e:.2f}')
46 |
47 | ep_eval = set(range(0, args.ep//3, 5)) | set(range(args.ep//3, args.ep))
48 | print(f'[FT start] ep_eval={sorted(ep_eval)} ')
49 | print(f'[FT start] from ep{ep_start}')
50 |
51 | params_req_grad = [p for p in model.parameters() if p.requires_grad]
52 | ft_start_time = time.time()
53 | for ep in range(ep_start, args.ep):
54 | ep_start_time = time.time()
55 | if hasattr(loader_train, 'sampler') and hasattr(loader_train.sampler, 'set_epoch'):
56 | loader_train.sampler.set_epoch(ep)
57 | if 0 <= ep <= 3:
58 | print(f'[loader_train.sampler.set_epoch({ep})]')
59 |
60 | train_loss, train_acc = fine_tune_one_epoch(ep, args, tb_lg, loader_train, iters_train, criterion, mixup_fn, model, model_ema, optimizer, params_req_grad)
61 | if ep in ep_eval:
62 | eval_start_time = time.time()
63 | tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model)
64 | tot_pred_e, last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module)
65 | eval_cost = round(time.time() - eval_start_time, 2)
66 | performance_desc = f'Max (Last) Acc: {max(max_acc, last_acc):.2f} ({last_acc:.2f} o {tot_pred}) EMA: {max(max_acc_e, last_acc_e):.2f} ({last_acc_e:.2f} o {tot_pred_e})'
67 | states = model_without_ddp.state_dict(), model_ema.module.state_dict(), optimizer.state_dict()
68 | if last_acc > max_acc:
69 | max_acc = last_acc
70 | save_checkpoint(f'{args.model}_1kfinetuned_best.pth', args, ep, performance_desc, *states)
71 | if last_acc_e > max_acc_e:
72 | max_acc_e = last_acc_e
73 | save_checkpoint(f'{args.model}_1kfinetuned_best_ema.pth', args, ep, performance_desc, *states)
74 | save_checkpoint(f'{args.model}_1kfinetuned_last.pth', args, ep, performance_desc, *states)
75 | else:
76 | eval_cost = '-'
77 |
78 | ep_cost = round(time.time() - ep_start_time, 2) + 1 # +1s: approximate the following logging cost
79 | remain_secs = (args.ep-1 - ep) * ep_cost
80 | remain_time = datetime.timedelta(seconds=round(remain_secs))
81 | finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
82 | print(f'[ep{ep}/{args.ep}] {performance_desc} Ep cost: {ep_cost}s, Ev cost: {eval_cost}, Remain: {remain_time}, Finish @ {finish_time}')
83 | args.cur_ep = f'{ep + 1}/{args.ep}'
84 | args.remain_time, args.finish_time = str(remain_time), str(finish_time)
85 | args.train_loss, args.train_acc, args.best_val_acc = train_loss, train_acc, max(max_acc, max_acc_e)
86 | args.log_epoch()
87 |
88 | if args.is_master:
89 | tb_lg.add_scalar(f'ft_train/ep_loss', train_loss, ep)
90 | tb_lg.add_scalar(f'ft_eval/max_acc', max_acc, ep)
91 | tb_lg.add_scalar(f'ft_eval/last_acc', last_acc, ep)
92 | tb_lg.add_scalar(f'ft_eval/max_acc_ema', max_acc_e, ep)
93 | tb_lg.add_scalar(f'ft_eval/last_acc_ema', last_acc_e, ep)
94 | tb_lg.add_scalar(f'ft_z_burnout/rest_hours', round(remain_secs/60/60, 2), ep)
95 | tb_lg.flush()
96 |
97 | # finish fine-tuning
98 | result_acc = max(max_acc, max_acc_e)
99 | if args.is_master:
100 | tb_lg.add_scalar('ft_result/result_acc', result_acc, ep_start)
101 | tb_lg.add_scalar('ft_result/result_acc', result_acc, args.ep)
102 | tb_lg.flush()
103 | tb_lg.close()
104 | print(f'final args:\n{str(args)}')
105 | print('\n\n')
106 | print(f' [*] [FT finished] {performance_desc} Total Cost: {(time.time() - ft_start_time) / 60 / 60:.1f}h\n')
107 | print(f' [*] [FT finished] max(max_acc, max_acc_e)={result_acc} EMA better={max_acc_e>max_acc}')
108 | print('\n\n')
109 | time.sleep(10)
110 |
111 | args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
112 | args.log_epoch()
113 |
114 |
115 | def fine_tune_one_epoch(ep, args: FineTuneArgs, tb_lg: SummaryWriter, loader_train, iters_train, criterion, mixup_fn, model, model_ema: ModelEmaV2, optimizer, params_req_grad):
116 | model.train()
117 | tot_loss = tot_acc = 0.0
118 | log_freq = max(1, round(iters_train * 0.7))
119 | ep_start_time = time.time()
120 | for it, (inp, tar) in enumerate(loader_train):
121 | # adjust lr and wd
122 | cur_it = it + ep * iters_train
123 | min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, cur_it, args.wp_ep * iters_train, args.ep * iters_train)
124 |
125 | # forward
126 | inp = inp.to(args.device, non_blocking=True)
127 | raw_tar = tar = tar.to(args.device, non_blocking=True)
128 | if mixup_fn is not None:
129 | inp, tar, raw_tar = mixup_fn(inp, tar)
130 | oup = model(inp)
131 | pred = oup.data.argmax(dim=1)
132 | if mixup_fn is None:
133 | acc = pred.eq(tar).float().mean().item() * 100
134 | tot_acc += acc
135 | else:
136 | acc = (pred.eq(raw_tar) | pred.eq(raw_tar.flip(0))).float().mean().item() * 100
137 | tot_acc += acc
138 |
139 | # backward
140 | optimizer.zero_grad()
141 | loss = criterion(oup, tar)
142 | loss.backward()
143 | loss = loss.item()
144 | tot_loss += loss
145 | if args.clip > 0:
146 | orig_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
147 | else:
148 | orig_norm = None
149 | optimizer.step()
150 | model_ema.update(model)
151 | torch.cuda.synchronize()
152 |
153 | # log
154 | if args.is_master and cur_it % log_freq == 0:
155 | tb_lg.add_scalar(f'ft_train/it_loss', loss, cur_it)
156 | tb_lg.add_scalar(f'ft_train/it_acc', acc, cur_it)
157 | tb_lg.add_scalar(f'ft_hp/min_lr', min_lr, cur_it), tb_lg.add_scalar(f'ft_hp/max_lr', max_lr, cur_it)
158 | tb_lg.add_scalar(f'ft_hp/min_wd', min_wd, cur_it), tb_lg.add_scalar(f'ft_hp/max_wd', max_wd, cur_it)
159 | if orig_norm is not None:
160 | tb_lg.add_scalar(f'ft_hp/orig_norm', orig_norm, cur_it)
161 |
162 | if it in [3, iters_train//2, iters_train-1]:
163 | remain_secs = (iters_train-1 - it) * (time.time() - ep_start_time) / (it + 1)
164 | remain_time = datetime.timedelta(seconds=round(remain_secs))
165 | print(f'[ep{ep} it{it:3d}/{iters_train}] L: {loss:.4f} Acc: {acc:.2f} lr: {min_lr:.1e}~{max_lr:.1e} Remain: {remain_time}')
166 |
167 | return tot_loss / iters_train, tot_acc / iters_train
168 |
169 |
170 | @torch.no_grad()
171 | def evaluate(dev, iterator_val, iters_val, model):
172 | training = model.training
173 | model.train(False)
174 | tot_pred, tot_correct = 0., 0.
175 | for _ in range(iters_val):
176 | inp, tar = next(iterator_val)
177 | tot_pred += tar.shape[0]
178 | inp = inp.to(dev, non_blocking=True)
179 | tar = tar.to(dev, non_blocking=True)
180 | oup = model(inp)
181 | tot_correct += oup.argmax(dim=1).eq(tar).sum().item()
182 | model.train(training)
183 | t = torch.tensor([tot_pred, tot_correct]).to(dev)
184 | tdist.all_reduce(t)
185 | return t[0].item(), (t[1] / t[0]).item() * 100.
186 |
187 |
188 | if __name__ == '__main__':
189 | main_ft()
190 |
--------------------------------------------------------------------------------
/downstream_imagenet/mixup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # This file is a modified version of timm.data.Mixup
8 | # Fixed error of "Batch size should be even when using this"
9 |
10 | """ Mixup and Cutmix
11 |
12 | Papers:
13 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
14 |
15 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
16 |
17 | Code Reference:
18 | CutMix: https://github.com/clovaai/CutMix-PyTorch
19 |
20 | Hacked together by / Copyright 2019, Ross Wightman
21 | """
22 | import numpy as np
23 | import torch
24 |
25 |
26 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
27 | x = x.long().view(-1, 1)
28 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
29 |
30 |
31 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
32 | off_value = smoothing / num_classes
33 | on_value = 1. - smoothing + off_value
34 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
35 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
36 | return y1 * lam + y2 * (1. - lam)
37 |
38 |
39 | def rand_bbox(img_shape, lam, margin=0., count=None):
40 | """ Standard CutMix bounding-box
41 | Generates a random square bbox based on lambda value. This impl includes
42 | support for enforcing a border margin as percent of bbox dimensions.
43 |
44 | Args:
45 | img_shape (tuple): Image shape as tuple
46 | lam (float): Cutmix lambda value
47 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
48 | count (int): Number of bbox to generate
49 | """
50 | ratio = np.sqrt(1 - lam)
51 | img_h, img_w = img_shape[-2:]
52 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
53 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
54 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
55 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
56 | yl = np.clip(cy - cut_h // 2, 0, img_h)
57 | yh = np.clip(cy + cut_h // 2, 0, img_h)
58 | xl = np.clip(cx - cut_w // 2, 0, img_w)
59 | xh = np.clip(cx + cut_w // 2, 0, img_w)
60 | return yl, yh, xl, xh
61 |
62 |
63 | def rand_bbox_minmax(img_shape, minmax, count=None):
64 | """ Min-Max CutMix bounding-box
65 | Inspired by Darknet cutmix impl, generates a random rectangular bbox
66 | based on min/max percent values applied to each dimension of the input image.
67 |
68 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
69 |
70 | Args:
71 | img_shape (tuple): Image shape as tuple
72 | minmax (tuple or list): Min and max bbox ratios (as percent of image size)
73 | count (int): Number of bbox to generate
74 | """
75 | assert len(minmax) == 2
76 | img_h, img_w = img_shape[-2:]
77 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
78 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
79 | yl = np.random.randint(0, img_h - cut_h, size=count)
80 | xl = np.random.randint(0, img_w - cut_w, size=count)
81 | yu = yl + cut_h
82 | xu = xl + cut_w
83 | return yl, yu, xl, xu
84 |
85 |
86 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
87 | """ Generate bbox and apply lambda correction.
88 | """
89 | if ratio_minmax is not None:
90 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
91 | else:
92 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
93 | if correct_lam or ratio_minmax is not None:
94 | bbox_area = (yu - yl) * (xu - xl)
95 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
96 | return (yl, yu, xl, xu), lam
97 |
98 |
99 | class BatchMixup:
100 | """ Mixup/Cutmix applied to a whole batch (this modified version supports mode='batch' only)
101 |
102 | Args:
103 | mixup_alpha (float): mixup alpha value, mixup is active if > 0.
104 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
105 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
106 | prob (float): probability of applying mixup or cutmix per batch or element
107 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active
108 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
109 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
110 | label_smoothing (float): apply label smoothing to the mixed target tensor
111 | num_classes (int): number of classes for target
112 | """
113 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
114 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
115 | assert mode == 'batch'
116 | self.mixup_alpha = mixup_alpha
117 | self.cutmix_alpha = cutmix_alpha
118 | self.cutmix_minmax = cutmix_minmax
119 | if self.cutmix_minmax is not None:
120 | assert len(self.cutmix_minmax) == 2
121 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
122 | self.cutmix_alpha = 1.0
123 | self.mix_prob = prob
124 | self.switch_prob = switch_prob
125 | self.label_smoothing = label_smoothing
126 | self.num_classes = num_classes
127 | self.mode = mode
128 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
129 | self.mixup_enabled = True # set to false to disable mixing (intended to be set by train loop)
130 |
131 | def _params_per_batch(self):
132 | lam = 1.
133 | use_cutmix = False
134 | if self.mixup_enabled and np.random.rand() < self.mix_prob:
135 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
136 | use_cutmix = np.random.rand() < self.switch_prob
137 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
138 | np.random.beta(self.mixup_alpha, self.mixup_alpha)
139 | elif self.mixup_alpha > 0.:
140 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
141 | elif self.cutmix_alpha > 0.:
142 | use_cutmix = True
143 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
144 | else:
145 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
146 | lam = float(lam_mix)
147 | return lam, use_cutmix
148 |
149 | def _mix_batch(self, x):
150 | lam, use_cutmix = self._params_per_batch()
151 | if lam == 1.:
152 | return 1.
153 | if use_cutmix:
154 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
155 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
156 | x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
157 | else:
158 | x_flipped = x.flip(0).mul_(1. - lam)
159 | x.mul_(lam).add_(x_flipped)
160 | return lam
161 |
162 | def __call__(self, x, raw_target):
163 | if x.shape[0] % 2 == 1:
164 | x, raw_target = torch.cat((x[:1], x), dim=0), torch.cat((raw_target[:1], raw_target), dim=0)
165 | # assert len(x) % 2 == 0, 'Batch size should be even when using this'
166 | lam = self._mix_batch(x)
167 | target = mixup_target(raw_target, self.num_classes, lam, self.label_smoothing, x.device)
168 | return x, target, raw_target
169 |
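
# ---- illustrative usage (not part of the original file) ----
# The odd-batch fix in __call__ duplicates the first sample, so a batch of 5
# comes back as 6 mixed samples with soft targets (hypothetical shapes):
if __name__ == '__main__':
    mixup_fn = BatchMixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=10)
    x, y = torch.randn(5, 3, 32, 32), torch.randint(0, 10, (5,))
    mixed_x, soft_y, raw_y = mixup_fn(x, y)
    print(mixed_x.shape, soft_y.shape, raw_y.shape)
    # torch.Size([6, 3, 32, 32]) torch.Size([6, 10]) torch.Size([6])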
--------------------------------------------------------------------------------
/downstream_imagenet/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | import torch
10 | from timm.data import Mixup
11 | from timm.loss import BinaryCrossEntropy, SoftTargetCrossEntropy
12 | from timm.models.layers import drop
13 | from timm.models.resnet import ResNet
14 |
15 | from .convnext_official import ConvNeXt
16 |
17 |
18 | def convnext_get_layer_id_and_scale_exp(self: ConvNeXt, para_name: str):
19 | N = 12 if len(self.stages[-2]) > 9 else 6
20 | if para_name.startswith("downsample_layers"):
21 | stage_id = int(para_name.split('.')[1])
22 | if stage_id == 0:
23 | layer_id = 0
24 | elif stage_id == 1 or stage_id == 2:
25 | layer_id = stage_id + 1
26 | else: # stage_id == 3:
27 | layer_id = N
28 | elif para_name.startswith("stages"):
29 | stage_id = int(para_name.split('.')[1])
30 | block_id = int(para_name.split('.')[2])
31 | if stage_id == 0 or stage_id == 1:
32 | layer_id = stage_id + 1
33 | elif stage_id == 2:
34 | layer_id = 3 + block_id // 3
35 | else: # stage_id == 3:
36 | layer_id = N
37 | else:
38 | layer_id = N + 1 # after backbone
39 |
40 | return layer_id, N + 1 - layer_id
41 |
42 |
43 | def resnets_get_layer_id_and_scale_exp(self: ResNet, para_name: str):
44 | # stages:
45 | # 50 : [3, 4, 6, 3]
46 | # 101 : [3, 4, 23, 3]
47 | # 152 : [3, 8, 36, 3]
48 | # 200 : [3, 24, 36, 3]
49 | # eca269d: [3, 30, 48, 8]
50 |
51 | L2, L3 = len(self.layer2), len(self.layer3)
52 | if L2 == 4 and L3 == 6:
53 | blk2, blk3 = 2, 3
54 | elif L2 == 4 and L3 == 23:
55 | blk2, blk3 = 2, 3
56 | elif L2 == 8 and L3 == 36:
57 | blk2, blk3 = 4, 4
58 | elif L2 == 24 and L3 == 36:
59 | blk2, blk3 = 4, 4
60 | elif L2 == 30 and L3 == 48:
61 | blk2, blk3 = 5, 6
62 | else:
63 | raise NotImplementedError
64 |
65 | N2, N3 = math.ceil(L2 / blk2 - 1e-5), math.ceil(L3 / blk3 - 1e-5)
66 | N = 2 + N2 + N3
67 | if para_name.startswith('layer'): # 1, 2, 3, 4, 5
68 | stage_id, block_id = int(para_name.split('.')[0][5:]), int(para_name.split('.')[1])
69 | if stage_id == 1:
70 | layer_id = 1
71 | elif stage_id == 2:
72 | layer_id = 2 + block_id // blk2 # 2, 3
73 | elif stage_id == 3:
74 | layer_id = 2 + N2 + block_id // blk3 # r50: 4, 5 r101: 4, 5, ..., 11
75 | else: # == 4
76 | layer_id = N # r50: 6 r101: 12
77 | elif para_name.startswith('fc.'):
78 | layer_id = N + 1 # r50: 7 r101: 13
79 | else:
80 | layer_id = 0
81 |
82 | return layer_id, N + 1 - layer_id # r50: 0-7, 7-0 r101: 0-13, 13-0
83 |
84 |
85 | def _ex_repr(self):
86 | return ', '.join(
87 | f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
88 | for k, v in vars(self).items()
89 | if not k.startswith('_') and k != 'training'
90 | and not isinstance(v, (torch.nn.Module, torch.Tensor))
91 | )
92 |
93 |
94 | # IMPORTANT: update some member functions
95 | __UPDATED = False
96 | if not __UPDATED:
97 | for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, BinaryCrossEntropy, Mixup, drop.DropPath):
98 | if hasattr(clz, 'extra_repr'):
99 | clz.extra_repr = _ex_repr
100 | else:
101 | clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
102 | ResNet.get_layer_id_and_scale_exp = resnets_get_layer_id_and_scale_exp
103 | ConvNeXt.get_layer_id_and_scale_exp = convnext_get_layer_id_and_scale_exp
104 | __UPDATED = True
105 |
--------------------------------------------------------------------------------
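
The two `get_layer_id_and_scale_exp` functions above drive layer-wise learning-rate decay during fine-tuning: each parameter name is mapped to a `layer_id`, and the returned exponent `N + 1 - layer_id` shrinks the learning rate of earlier layers. A worked example for ConvNeXt-Base (depths `[3, 3, 27, 3]`, so `len(self.stages[-2]) = 27 > 9` and `N = 12`); how `lr_decay.get_param_groups` consumes the exponent is an assumption here (base LR times `lr_scale ** scale_exp`), and the hyper-parameter values are hypothetical:

```python
# Worked numbers for the layer-wise LR decay mapping above (ConvNeXt-Base, N = 12).
N = 12
base_lr, lr_scale = 4e-3, 0.7          # hypothetical fine-tuning hyper-parameters

def lr_for(layer_id: int) -> float:
    return base_lr * lr_scale ** (N + 1 - layer_id)   # scale_exp = N + 1 - layer_id

# "downsample_layers.0.*" (stem)         -> layer_id 0           (smallest LR)
# "stages.2.10.*" (block 10 of stage 3)  -> layer_id 3 + 10 // 3 = 6
# "head.*" (outside the backbone)        -> layer_id N + 1 = 13  (full LR)
for name, layer_id in [('stem', 0), ('stages.2.10', 3 + 10 // 3), ('head', N + 1)]:
    print(f'{name:12s} layer_id={layer_id:2d}  lr={lr_for(layer_id):.2e}')
```
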
/downstream_imagenet/models/convnext_official.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # This file is exactly the same as: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from timm.models.layers import trunc_normal_, DropPath
13 | from timm.models.registry import register_model
14 |
15 | class Block(nn.Module):
16 | r""" ConvNeXt Block. There are two equivalent implementations:
17 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
18 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
19 | We use (2) as we find it slightly faster in PyTorch
20 |
21 | Args:
22 | dim (int): Number of input channels.
23 | drop_path (float): Stochastic depth rate. Default: 0.0
24 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
25 | """
26 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
27 | super().__init__()
28 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
29 | self.norm = LayerNorm(dim, eps=1e-6)
30 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
31 | self.act = nn.GELU()
32 | self.pwconv2 = nn.Linear(4 * dim, dim)
33 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
34 | requires_grad=True) if layer_scale_init_value > 0 else None
35 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
36 |
37 | def forward(self, x):
38 | input = x
39 | x = self.dwconv(x)
40 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
41 | x = self.norm(x)
42 | x = self.pwconv1(x)
43 | x = self.act(x)
44 | x = self.pwconv2(x)
45 | if self.gamma is not None:
46 | x = self.gamma * x
47 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
48 |
49 | x = input + self.drop_path(x)
50 | return x
51 |
52 | class ConvNeXt(nn.Module):
53 | r""" ConvNeXt
54 | A PyTorch impl of : `A ConvNet for the 2020s` -
55 | https://arxiv.org/pdf/2201.03545.pdf
56 | Args:
57 | in_chans (int): Number of input image channels. Default: 3
58 | num_classes (int): Number of classes for classification head. Default: 1000
59 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
60 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
61 | drop_path_rate (float): Stochastic depth rate. Default: 0.
62 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
63 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
64 | """
65 | def __init__(self, in_chans=3, num_classes=1000,
66 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
67 | layer_scale_init_value=1e-6, head_init_scale=1.,
68 | ):
69 | super().__init__()
70 |
71 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
72 | stem = nn.Sequential(
73 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
74 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
75 | )
76 | self.downsample_layers.append(stem)
77 | for i in range(3):
78 | downsample_layer = nn.Sequential(
79 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
80 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
81 | )
82 | self.downsample_layers.append(downsample_layer)
83 |
84 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
85 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
86 | cur = 0
87 | for i in range(4):
88 | stage = nn.Sequential(
89 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
90 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
91 | )
92 | self.stages.append(stage)
93 | cur += depths[i]
94 |
95 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
96 | self.head = nn.Linear(dims[-1], num_classes)
97 |
98 | self.apply(self._init_weights)
99 | self.head.weight.data.mul_(head_init_scale)
100 | self.head.bias.data.mul_(head_init_scale)
101 |
102 | def _init_weights(self, m):
103 | if isinstance(m, (nn.Conv2d, nn.Linear)):
104 | trunc_normal_(m.weight, std=.02)
105 | nn.init.constant_(m.bias, 0)
106 |
107 | def forward_features(self, x):
108 | for i in range(4):
109 | x = self.downsample_layers[i](x)
110 | x = self.stages[i](x)
111 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
112 |
113 | def forward(self, x):
114 | x = self.forward_features(x)
115 | x = self.head(x)
116 | return x
117 |
118 | class LayerNorm(nn.Module):
119 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
120 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
121 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
122 | with shape (batch_size, channels, height, width).
123 | """
124 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
125 | super().__init__()
126 | self.weight = nn.Parameter(torch.ones(normalized_shape))
127 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
128 | self.eps = eps
129 | self.data_format = data_format
130 | if self.data_format not in ["channels_last", "channels_first"]:
131 | raise NotImplementedError
132 | self.normalized_shape = (normalized_shape, )
133 |
134 | def forward(self, x):
135 | if self.data_format == "channels_last":
136 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
137 | elif self.data_format == "channels_first":
138 | u = x.mean(1, keepdim=True)
139 | s = (x - u).pow(2).mean(1, keepdim=True)
140 | x = (x - u) / torch.sqrt(s + self.eps)
141 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
142 | return x
143 |
144 |
145 | model_urls = {
146 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
147 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
148 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
149 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
150 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
151 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
152 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
153 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
154 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
155 | }
156 |
157 | @register_model
158 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs):
159 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
160 | if pretrained:
161 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
162 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
163 | model.load_state_dict(checkpoint["model"])
164 | return model
165 |
166 | @register_model
167 | def convnext_small(pretrained=False,in_22k=False, **kwargs):
168 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
169 | if pretrained:
170 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
171 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
172 | model.load_state_dict(checkpoint["model"])
173 | return model
174 |
175 | @register_model
176 | def convnext_base(pretrained=False, in_22k=False, **kwargs):
177 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
178 | if pretrained:
179 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
180 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
181 | model.load_state_dict(checkpoint["model"])
182 | return model
183 |
184 | @register_model
185 | def convnext_large(pretrained=False, in_22k=False, **kwargs):
186 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
187 | if pretrained:
188 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
189 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
190 | model.load_state_dict(checkpoint["model"])
191 | return model
192 |
193 | @register_model
194 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
195 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
196 | if pretrained:
197 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
198 | url = model_urls['convnext_xlarge_22k']
199 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
200 | model.load_state_dict(checkpoint["model"])
201 | return model
202 |
--------------------------------------------------------------------------------
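
A minimal usage sketch of the model file above (assumes `timm==0.5.4` is installed per `requirements.txt`, and that the working directory is `downstream_imagenet/` so the relative import resolves; with `pretrained=False` nothing is downloaded):

```python
import torch
from models.convnext_official import convnext_base   # import path assumed relative to downstream_imagenet/

model = convnext_base(pretrained=False, drop_path_rate=0.2).eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)   # torch.Size([2, 1000]): global average pooling + final LayerNorm + linear head
```
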
/downstream_imagenet/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | numpy
3 | Pillow
4 | typed-argument-parser
5 | timm==0.5.4
6 | tensorboardx
7 |
--------------------------------------------------------------------------------
/downstream_imagenet/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import datetime
8 | import os
9 | import sys
10 | from functools import partial
11 | from typing import List, Tuple, Callable
12 |
13 | import pytz
14 | import torch
15 | import torch.distributed as tdist
16 | import torch.multiprocessing as tmp
17 | from timm import create_model
18 | from timm.loss import SoftTargetCrossEntropy, BinaryCrossEntropy
19 | from timm.optim import AdamW, Lamb
20 | from timm.utils import ModelEmaV2
21 | from torch.nn.parallel import DistributedDataParallel
22 | from torch.optim.optimizer import Optimizer
23 |
24 | from arg import FineTuneArgs
25 | from mixup import BatchMixup
26 | from lr_decay import get_param_groups
27 |
28 |
29 | def time_str(for_dirname=False):
30 | return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('%m-%d_%H-%M-%S' if for_dirname else '[%m-%d %H:%M:%S]')
31 |
32 |
33 | def init_distributed_environ():
34 | # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
35 | if tmp.get_start_method(allow_none=True) is None:
36 | tmp.set_start_method('spawn')
37 | global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count()
38 | local_rank = global_rank % num_gpus
39 | torch.cuda.set_device(local_rank)
40 |
41 | tdist.init_process_group(backend='nccl')
42 | assert tdist.is_initialized(), 'torch.distributed is not initialized!'
43 | torch.backends.cudnn.benchmark = True
44 | torch.backends.cudnn.deterministic = False
45 |
46 | # print only when local_rank == 0 or print(..., force=True)
47 | import builtins as __builtin__
48 | builtin_print = __builtin__.print
49 |
50 | def prt(msg, *args, **kwargs):
51 | force = kwargs.pop('force', False)
52 | if local_rank == 0 or force:
53 | f_back = sys._getframe().f_back
54 | file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
55 | builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}', *args, **kwargs)
56 |
57 | __builtin__.print = prt
58 | tdist.barrier()
59 | return tdist.get_world_size(), global_rank, local_rank, torch.empty(1).cuda().device
60 |
61 |
62 | def create_model_opt(args: FineTuneArgs) -> Tuple[torch.nn.Module, Callable, torch.nn.Module, DistributedDataParallel, ModelEmaV2, Optimizer]:
63 | num_classes = 1000
64 | model_without_ddp: torch.nn.Module = create_model(args.model, num_classes=num_classes, drop_path_rate=args.drop_path).to(args.device)
65 | model_para = f'{sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad) / 1e6:.1f}M'
66 | # create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
67 | model_ema = ModelEmaV2(model_without_ddp, decay=args.ema, device=args.device)
68 | if args.sbn:
69 | model_without_ddp = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_without_ddp)
70 | print(f'[model={args.model}] [#para={model_para}, drop_path={args.drop_path}, ema={args.ema}] {model_without_ddp}\n')
71 | model = DistributedDataParallel(model_without_ddp, device_ids=[args.local_rank], find_unused_parameters=False, broadcast_buffers=False)
72 | model.train()
73 | opt_cls = {
74 | 'adam': AdamW, 'adamw': AdamW,
75 | 'lamb': partial(Lamb, max_grad_norm=1e7, always_adapt=True, bias_correction=False),
76 | }
77 | param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'}, lr_scale=args.lr_scale)
78 | # param_groups[0] is like this: {'params': List[nn.Parameters], 'lr': float, 'lr_scale': float, 'weight_decay': float, 'weight_decay_scale': float}
79 | optimizer = opt_cls[args.opt](param_groups, lr=args.lr, weight_decay=0)
80 | print(f'[optimizer={type(optimizer)}]')
81 | mixup_fn = BatchMixup(
82 | mixup_alpha=args.mixup, cutmix_alpha=1.0, cutmix_minmax=None,
83 | prob=1.0, switch_prob=0.5, mode='batch',
84 | label_smoothing=0.1, num_classes=num_classes
85 | )
86 | mixup_fn.mixup_enabled = args.mixup > 0.0
87 | if 'lamb' in args.opt:
88 |         # label smoothing is already handled by BatchMixup via `label_smoothing`, so here smoothing=0
89 | criterion = BinaryCrossEntropy(smoothing=0, target_threshold=None)
90 | else:
91 | criterion = SoftTargetCrossEntropy()
92 | print(f'[loss_fn] {criterion}')
93 | print(f'[mixup_fn] {mixup_fn}')
94 | return criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer
95 |
96 |
97 | def load_checkpoint(resume_from, model_without_ddp, ema_module, optimizer):
98 | if len(resume_from) == 0 or not os.path.exists(resume_from):
99 | raise AttributeError(f'ckpt `{resume_from}` not found!')
100 | # return 0, '[no performance_desc]'
101 | print(f'[try to resume from file `{resume_from}`]')
102 | checkpoint = torch.load(resume_from, map_location='cpu')
103 | assert checkpoint.get('is_pretrain', False) == False, 'Please do not use `*_withdecoder_1kpretrained_spark_style.pth`, which is ONLY for resuming the pretraining. Use `*_1kpretrained_timm_style.pth` or `*_1kfinetuned*.pth` instead.'
104 |
105 | ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc', '[no performance_desc]')
106 | missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
107 | print(f'[load_checkpoint] missing_keys={missing}')
108 | print(f'[load_checkpoint] unexpected_keys={unexpected}')
109 | print(f'[load_checkpoint] ep_start={ep_start}, performance_desc={performance_desc}')
110 |
111 | if 'optimizer' in checkpoint:
112 | optimizer.load_state_dict(checkpoint['optimizer'])
113 | if 'ema' in checkpoint:
114 | ema_module.load_state_dict(checkpoint['ema'])
115 | return ep_start, performance_desc
116 |
117 |
118 | def save_checkpoint(save_to, args, epoch, performance_desc, model_without_ddp_state, ema_state, optimizer_state):
119 | checkpoint_path = os.path.join(args.exp_dir, save_to)
120 | if args.is_local_master:
121 | to_save = {
122 | 'args': str(args),
123 | 'arch': args.model,
124 | 'epoch': epoch,
125 | 'performance_desc': performance_desc,
126 | 'module': model_without_ddp_state,
127 | 'ema': ema_state,
128 | 'optimizer': optimizer_state,
129 | 'is_pretrain': False,
130 | }
131 | torch.save(to_save, checkpoint_path)
132 |
--------------------------------------------------------------------------------
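
The comment in `create_model_opt` above notes that each param group carries an `lr_scale` produced by `get_param_groups`. PyTorch optimizers keep such extra keys on the group, so the training loop can apply a scheduled learning rate per group. A small sketch of that pattern (the function name `apply_lr` and the scale values are illustrative, not the repository's exact code, which lives in `lr_decay.py` / `main.py`):

```python
import torch

def apply_lr(optimizer: torch.optim.Optimizer, scheduled_lr: float) -> None:
    # every group gets the scheduled LR multiplied by its own layer-wise scale
    for group in optimizer.param_groups:
        group['lr'] = scheduled_lr * group.get('lr_scale', 1.0)

p_early = torch.zeros(1, requires_grad=True)   # stands in for an early-layer parameter
p_head = torch.zeros(1, requires_grad=True)    # stands in for a head parameter
opt = torch.optim.AdamW([
    {'params': [p_early], 'lr_scale': 0.1},
    {'params': [p_head], 'lr_scale': 1.0},
], lr=0.0)

apply_lr(opt, 4e-3)
print([g['lr'] for g in opt.param_groups])   # ~[4e-04, 4e-03]
```
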
/downstream_mmdet/README.md:
--------------------------------------------------------------------------------
1 | ## About code isolation
2 |
3 | This `downstream_mmdet` is isolated from the pre-training code; one can treat it as an independent codebase 🛠️.
4 |
5 | ## Fine-tuned ConvNeXt-B weights, log files, and performance
6 |
7 |
8 |
9 |
10 | [[`weights (pre-trained by SparK)`](https://drive.google.com/file/d/1ZjWbqI1qoBcqeQijI5xX9E-YNkxpJcYV/view?usp=share_link)]
11 | [[`weights (fine-tuned on COCO)`](https://drive.google.com/file/d/1t10dmzg5KOO27o2yIglK-gQepB5gR4zR/view?usp=share_link)]
12 | [[`log.json`](https://drive.google.com/file/d/1TuNboXl1qwjf1tggZ3QOssI67uU7Jtig/view?usp=share_link)]
13 | [[`log`](https://drive.google.com/file/d/1JY5CkL_MX08zJ8P1FBIeC60OJsuIiyZc/view?usp=sharing)]
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | ## Installation: set up [MMDetection at commit 6a979e2](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d) before fine-tuning ConvNeXt on COCO
23 |
24 | We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d).
25 | Please refer to [README.md](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/6a979e2164e3fb0de0ca2546545013a4d71b2f7d/README.md) for installation and dataset preparation instructions.
26 |
27 | Note the COCO dataset folder should be at `downstream_mmdet/data/coco`.
28 | The folder should follow the directory structure required by `MMDetection`, which should look like this:
29 | ```
30 | downstream_mmdet/data/coco:
31 | annotations/:
32 | captions_train2017.json captions_val2017.json
33 | instances_train2017.json instances_val2017.json
34 | person_keypoints_train2017.json person_keypoints_val2017.json
35 | train2017/:
36 | a_lot_images.jpg
37 | val2017/:
38 | a_lot_images.jpg
39 | ```
40 |
41 |
42 | ### Training
43 |
44 | To train a detector with pre-trained models, run:
45 | ```
46 | # single-gpu training
47 | python tools/train.py <config_file> --cfg-options model.pretrained=<pretrained_weights> [other optional arguments]
48 |
49 | # multi-gpu training
50 | tools/dist_train.sh <config_file> <num_gpus> --cfg-options model.pretrained=<pretrained_weights> [other optional arguments]
51 | ```
52 | For example, to train a Mask R-CNN model with a SparK-pretrained `ConvNeXt-B` backbone on 4 GPUs, run:
53 | ```
54 | tools/dist_train.sh configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py 4 \
55 | --cfg-options model.pretrained=/some/path/to/official_convnext_base_1kpretrained.pth
56 | ```
57 |
58 | The Mask R-CNN 3x fine-tuning config file can be found at [`configs/convnext_spark`](configs/convnext_spark). This config is basically a copy of [https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py](https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py).
59 |
60 | ### Inference
61 | ```
62 | # single-gpu testing
63 | python tools/test.py <config_file> <checkpoint_file> --eval bbox segm
64 |
65 | # multi-gpu testing
66 | tools/dist_test.sh <config_file> <checkpoint_file> <num_gpus> --eval bbox segm
67 | ```
68 |
69 | ## Acknowledgment
70 |
71 | We appreciate these useful codebases:
72 |
73 | - [MMDetection](https://github.com/open-mmlab/mmdetection)
74 | - [ConvNeXt](https://github.com/facebookresearch/ConvNeXt)
75 | - [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection)
76 |
77 |
--------------------------------------------------------------------------------
/downstream_mmdet/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | checkpoint_config = dict(interval=1)
2 | # yapf:disable
3 | log_config = dict(
4 | interval=50,
5 | hooks=[
6 | dict(type='CustomizedTextLoggerHook'),
7 | # dict(type='TensorboardLoggerHook')
8 | ])
9 | # yapf:enable
10 | custom_hooks = [dict(type='NumClassCheckHook')]
11 |
12 | dist_params = dict(backend='nccl')
13 | log_level = 'INFO'
14 | load_from = None
15 | resume_from = None
16 | workflow = [('train', 1)]
17 |
--------------------------------------------------------------------------------
/downstream_mmdet/configs/_base_/models/cascade_mask_rcnn_convnext_fpn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | # model settings
10 | model = dict(
11 | type='CascadeRCNN',
12 | pretrained=None,
13 | backbone=dict(
14 | type='ConvNeXt',
15 | in_chans=3,
16 | depths=[3, 3, 9, 3],
17 | dims=[96, 192, 384, 768],
18 | drop_path_rate=0.2,
19 | layer_scale_init_value=1e-6,
20 | out_indices=[0, 1, 2, 3],
21 | ),
22 | neck=dict(
23 | type='FPN',
24 | in_channels=[128, 256, 512, 1024],
25 | out_channels=256,
26 | num_outs=5),
27 | rpn_head=dict(
28 | type='RPNHead',
29 | in_channels=256,
30 | feat_channels=256,
31 | anchor_generator=dict(
32 | type='AnchorGenerator',
33 | scales=[8],
34 | ratios=[0.5, 1.0, 2.0],
35 | strides=[4, 8, 16, 32, 64]),
36 | bbox_coder=dict(
37 | type='DeltaXYWHBBoxCoder',
38 | target_means=[.0, .0, .0, .0],
39 | target_stds=[1.0, 1.0, 1.0, 1.0]),
40 | loss_cls=dict(
41 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
42 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
43 | roi_head=dict(
44 | type='CascadeRoIHead',
45 | num_stages=3,
46 | stage_loss_weights=[1, 0.5, 0.25],
47 | bbox_roi_extractor=dict(
48 | type='SingleRoIExtractor',
49 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
50 | out_channels=256,
51 | featmap_strides=[4, 8, 16, 32]),
52 | bbox_head=[
53 | dict(
54 | type='Shared2FCBBoxHead',
55 | in_channels=256,
56 | fc_out_channels=1024,
57 | roi_feat_size=7,
58 | num_classes=80,
59 | bbox_coder=dict(
60 | type='DeltaXYWHBBoxCoder',
61 | target_means=[0., 0., 0., 0.],
62 | target_stds=[0.1, 0.1, 0.2, 0.2]),
63 | reg_class_agnostic=True,
64 | loss_cls=dict(
65 | type='CrossEntropyLoss',
66 | use_sigmoid=False,
67 | loss_weight=1.0),
68 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
69 | loss_weight=1.0)),
70 | dict(
71 | type='Shared2FCBBoxHead',
72 | in_channels=256,
73 | fc_out_channels=1024,
74 | roi_feat_size=7,
75 | num_classes=80,
76 | bbox_coder=dict(
77 | type='DeltaXYWHBBoxCoder',
78 | target_means=[0., 0., 0., 0.],
79 | target_stds=[0.05, 0.05, 0.1, 0.1]),
80 | reg_class_agnostic=True,
81 | loss_cls=dict(
82 | type='CrossEntropyLoss',
83 | use_sigmoid=False,
84 | loss_weight=1.0),
85 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
86 | loss_weight=1.0)),
87 | dict(
88 | type='Shared2FCBBoxHead',
89 | in_channels=256,
90 | fc_out_channels=1024,
91 | roi_feat_size=7,
92 | num_classes=80,
93 | bbox_coder=dict(
94 | type='DeltaXYWHBBoxCoder',
95 | target_means=[0., 0., 0., 0.],
96 | target_stds=[0.033, 0.033, 0.067, 0.067]),
97 | reg_class_agnostic=True,
98 | loss_cls=dict(
99 | type='CrossEntropyLoss',
100 | use_sigmoid=False,
101 | loss_weight=1.0),
102 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
103 | ],
104 | mask_roi_extractor=dict(
105 | type='SingleRoIExtractor',
106 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
107 | out_channels=256,
108 | featmap_strides=[4, 8, 16, 32]),
109 | mask_head=dict(
110 | type='FCNMaskHead',
111 | num_convs=4,
112 | in_channels=256,
113 | conv_out_channels=256,
114 | num_classes=80,
115 | loss_mask=dict(
116 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
117 | # model training and testing settings
118 | train_cfg = dict(
119 | rpn=dict(
120 | assigner=dict(
121 | type='MaxIoUAssigner',
122 | pos_iou_thr=0.7,
123 | neg_iou_thr=0.3,
124 | min_pos_iou=0.3,
125 | match_low_quality=True,
126 | ignore_iof_thr=-1),
127 | sampler=dict(
128 | type='RandomSampler',
129 | num=256,
130 | pos_fraction=0.5,
131 | neg_pos_ub=-1,
132 | add_gt_as_proposals=False),
133 | allowed_border=0,
134 | pos_weight=-1,
135 | debug=False),
136 | rpn_proposal=dict(
137 | nms_across_levels=False,
138 | nms_pre=2000,
139 | nms_post=2000,
140 | max_per_img=2000,
141 | nms=dict(type='nms', iou_threshold=0.7),
142 | min_bbox_size=0),
143 | rcnn=[
144 | dict(
145 | assigner=dict(
146 | type='MaxIoUAssigner',
147 | pos_iou_thr=0.5,
148 | neg_iou_thr=0.5,
149 | min_pos_iou=0.5,
150 | match_low_quality=False,
151 | ignore_iof_thr=-1),
152 | sampler=dict(
153 | type='RandomSampler',
154 | num=512,
155 | pos_fraction=0.25,
156 | neg_pos_ub=-1,
157 | add_gt_as_proposals=True),
158 | mask_size=28,
159 | pos_weight=-1,
160 | debug=False),
161 | dict(
162 | assigner=dict(
163 | type='MaxIoUAssigner',
164 | pos_iou_thr=0.6,
165 | neg_iou_thr=0.6,
166 | min_pos_iou=0.6,
167 | match_low_quality=False,
168 | ignore_iof_thr=-1),
169 | sampler=dict(
170 | type='RandomSampler',
171 | num=512,
172 | pos_fraction=0.25,
173 | neg_pos_ub=-1,
174 | add_gt_as_proposals=True),
175 | mask_size=28,
176 | pos_weight=-1,
177 | debug=False),
178 | dict(
179 | assigner=dict(
180 | type='MaxIoUAssigner',
181 | pos_iou_thr=0.7,
182 | neg_iou_thr=0.7,
183 | min_pos_iou=0.7,
184 | match_low_quality=False,
185 | ignore_iof_thr=-1),
186 | sampler=dict(
187 | type='RandomSampler',
188 | num=512,
189 | pos_fraction=0.25,
190 | neg_pos_ub=-1,
191 | add_gt_as_proposals=True),
192 | mask_size=28,
193 | pos_weight=-1,
194 | debug=False)
195 | ]),
196 | test_cfg = dict(
197 | rpn=dict(
198 | nms_across_levels=False,
199 | nms_pre=1000,
200 | nms_post=1000,
201 | max_per_img=1000,
202 | nms=dict(type='nms', iou_threshold=0.7),
203 | min_bbox_size=0),
204 | rcnn=dict(
205 | score_thr=0.05,
206 | nms=dict(type='nms', iou_threshold=0.5),
207 | max_per_img=100,
208 | mask_thr_binary=0.5)))
209 |
--------------------------------------------------------------------------------
/downstream_mmdet/configs/_base_/models/mask_rcnn_convnext_fpn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | # model settings
10 | model = dict(
11 | type='MaskRCNN',
12 | pretrained=None,
13 | backbone=dict(
14 | type='ConvNeXt',
15 | in_chans=3,
16 | depths=[3, 3, 9, 3],
17 | dims=[96, 192, 384, 768],
18 | drop_path_rate=0.2,
19 | layer_scale_init_value=1e-6,
20 | out_indices=[0, 1, 2, 3],
21 | ),
22 | neck=dict(
23 | type='FPN',
24 | in_channels=[128, 256, 512, 1024],
25 | out_channels=256,
26 | num_outs=5),
27 | rpn_head=dict(
28 | type='RPNHead',
29 | in_channels=256,
30 | feat_channels=256,
31 | anchor_generator=dict(
32 | type='AnchorGenerator',
33 | scales=[8],
34 | ratios=[0.5, 1.0, 2.0],
35 | strides=[4, 8, 16, 32, 64]),
36 | bbox_coder=dict(
37 | type='DeltaXYWHBBoxCoder',
38 | target_means=[.0, .0, .0, .0],
39 | target_stds=[1.0, 1.0, 1.0, 1.0]),
40 | loss_cls=dict(
41 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
42 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
43 | roi_head=dict(
44 | type='StandardRoIHead',
45 | bbox_roi_extractor=dict(
46 | type='SingleRoIExtractor',
47 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
48 | out_channels=256,
49 | featmap_strides=[4, 8, 16, 32]),
50 | bbox_head=dict(
51 | type='Shared2FCBBoxHead',
52 | in_channels=256,
53 | fc_out_channels=1024,
54 | roi_feat_size=7,
55 | num_classes=80,
56 | bbox_coder=dict(
57 | type='DeltaXYWHBBoxCoder',
58 | target_means=[0., 0., 0., 0.],
59 | target_stds=[0.1, 0.1, 0.2, 0.2]),
60 | reg_class_agnostic=False,
61 | loss_cls=dict(
62 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
63 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
64 | mask_roi_extractor=dict(
65 | type='SingleRoIExtractor',
66 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
67 | out_channels=256,
68 | featmap_strides=[4, 8, 16, 32]),
69 | mask_head=dict(
70 | type='FCNMaskHead',
71 | num_convs=4,
72 | in_channels=256,
73 | conv_out_channels=256,
74 | num_classes=80,
75 | loss_mask=dict(
76 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
77 | # model training and testing settings
78 | train_cfg=dict(
79 | rpn=dict(
80 | assigner=dict(
81 | type='MaxIoUAssigner',
82 | pos_iou_thr=0.7,
83 | neg_iou_thr=0.3,
84 | min_pos_iou=0.3,
85 | match_low_quality=True,
86 | ignore_iof_thr=-1),
87 | sampler=dict(
88 | type='RandomSampler',
89 | num=256,
90 | pos_fraction=0.5,
91 | neg_pos_ub=-1,
92 | add_gt_as_proposals=False),
93 | allowed_border=-1,
94 | pos_weight=-1,
95 | debug=False),
96 | rpn_proposal=dict(
97 | nms_pre=2000,
98 | max_per_img=1000,
99 | nms=dict(type='nms', iou_threshold=0.7),
100 | min_bbox_size=0),
101 | rcnn=dict(
102 | assigner=dict(
103 | type='MaxIoUAssigner',
104 | pos_iou_thr=0.5,
105 | neg_iou_thr=0.5,
106 | min_pos_iou=0.5,
107 | match_low_quality=True,
108 | ignore_iof_thr=-1),
109 | sampler=dict(
110 | type='RandomSampler',
111 | num=512,
112 | pos_fraction=0.25,
113 | neg_pos_ub=-1,
114 | add_gt_as_proposals=True),
115 | mask_size=28,
116 | pos_weight=-1,
117 | debug=False)),
118 | test_cfg=dict(
119 | rpn=dict(
120 | nms_pre=1000,
121 | max_per_img=1000,
122 | nms=dict(type='nms', iou_threshold=0.7),
123 | min_bbox_size=0),
124 | rcnn=dict(
125 | score_thr=0.05,
126 | nms=dict(type='nms', iou_threshold=0.5),
127 | max_per_img=100,
128 | mask_thr_binary=0.5)))
129 |
--------------------------------------------------------------------------------
/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py:
--------------------------------------------------------------------------------
1 | """
2 | We directly take the ConvNeXt-T+MaskRCNN 3x recipe from https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py
3 | And we modify this ConvNeXt-T+MaskRCNN 3x recipe to our ConvNeXt-B+MaskRCNN 3x recipe.
4 | The modifications (commented as [modified] below) are according to:
5 | - 1. tiny-to-base: (some configs of ConvNext-T are updated to those of ConvNext-B, referring to https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/cascade_mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco_in22k.py)
6 | - model.backbone.{depths, dims, drop_path_rate}
7 |     - model.neck
8 | - optimizer.paramwise_cfg.num_layers
9 |
10 | - 2. our paper (https://openreview.net/forum?id=NRxydtWup1S, or https://arxiv.org/abs/2301.03580):
11 | - LR layer decay (optimizer.paramwise_cfg.decay_rate): 0.65
12 | - LR scheduled ratio (lr_config.gamma): 0.2
13 | - Learning rate (optimizer.lr): 0.0002
14 | - optimizer_config.use_fp16: False (we just use fp32 by default; actually we didn't test the performance of using fp16)
15 | """
16 |
17 | _base_ = [
18 | '../_base_/models/mask_rcnn_convnext_fpn.py',
19 | '../_base_/datasets/coco_instance.py',
20 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
21 | ]
22 |
23 | model = dict(
24 | backbone=dict(
25 | in_chans=3,
26 | depths=[3, 3, 27, 3], # [modified] according to tiny-to-base
27 | dims=[128, 256, 512, 1024], # [modified] according to tiny-to-base
28 | drop_path_rate=0.5, # [modified] according to tiny-to-base
29 | layer_scale_init_value=1.0,
30 | out_indices=[0, 1, 2, 3],
31 | ),
32 | neck=dict(in_channels=[128, 256, 512, 1024])) # [modified] according to tiny-to-base
33 |
34 | img_norm_cfg = dict(
35 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
36 |
37 | # augmentation strategy originates from DETR / Sparse RCNN
38 | train_pipeline = [
39 | dict(type='LoadImageFromFile'),
40 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
41 | dict(type='RandomFlip', flip_ratio=0.5),
42 | dict(type='AutoAugment',
43 | policies=[
44 | [
45 | dict(type='Resize',
46 | img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
47 | (608, 1333), (640, 1333), (672, 1333), (704, 1333),
48 | (736, 1333), (768, 1333), (800, 1333)],
49 | multiscale_mode='value',
50 | keep_ratio=True)
51 | ],
52 | [
53 | dict(type='Resize',
54 | img_scale=[(400, 1333), (500, 1333), (600, 1333)],
55 | multiscale_mode='value',
56 | keep_ratio=True),
57 | dict(type='RandomCrop',
58 | crop_type='absolute_range',
59 | crop_size=(384, 600),
60 | allow_negative_crop=True),
61 | dict(type='Resize',
62 | img_scale=[(480, 1333), (512, 1333), (544, 1333),
63 | (576, 1333), (608, 1333), (640, 1333),
64 | (672, 1333), (704, 1333), (736, 1333),
65 | (768, 1333), (800, 1333)],
66 | multiscale_mode='value',
67 | override=True,
68 | keep_ratio=True)
69 | ]
70 | ]),
71 | dict(type='Normalize', **img_norm_cfg),
72 | dict(type='Pad', size_divisor=32),
73 | dict(type='DefaultFormatBundle'),
74 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
75 | ]
76 | data = dict(train=dict(pipeline=train_pipeline))
77 |
78 | optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW',
79 | lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05, # [modified] according to our paper
80 | paramwise_cfg={'decay_rate': 0.65, # [modified] according to our paper
81 | 'decay_type': 'layer_wise',
82 | 'num_layers': 12}) # [modified] according to tiny-to-base
83 | lr_config = dict(step=[27, 33], gamma=0.2) # [modified] according to our paper
84 | runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)
85 |
86 | # do not use mmdet version fp16
87 | fp16 = None
88 | optimizer_config = dict(
89 | type="DistOptimizerHook",
90 | update_interval=1,
91 | grad_clip=None,
92 | coalesce=True,
93 | bucket_size_mb=-1,
94 | use_fp16=False, # [modified] True => False
95 | )
--------------------------------------------------------------------------------
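
Worked numbers for the schedule configured above (ignoring any warmup inherited from the base `schedule_1x.py`): the base LR of 2e-4 is decayed by `gamma=0.2` at epochs 27 and 33, over 36 epochs in total:

```python
# Step LR schedule of the config above: lr_config = dict(step=[27, 33], gamma=0.2), max_epochs = 36.
base_lr, gamma, steps = 2e-4, 0.2, (27, 33)

def lr_at(epoch: int) -> float:
    return base_lr * gamma ** sum(epoch >= s for s in steps)

print(lr_at(0), lr_at(27), lr_at(33))   # 2e-04 for epochs 0-26, 4e-05 for 27-32, 8e-06 for 33-35
```
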
/downstream_mmdet/mmcv_custom/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | # -*- coding: utf-8 -*-
10 |
11 | from .checkpoint import load_checkpoint
12 | from .layer_decay_optimizer_constructor import LearningRateDecayOptimizerConstructor
13 | from .customized_text import CustomizedTextLoggerHook
14 |
15 | __all__ = ['load_checkpoint', 'LearningRateDecayOptimizerConstructor', 'CustomizedTextLoggerHook']
16 |
--------------------------------------------------------------------------------
/downstream_mmdet/mmcv_custom/customized_text.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import datetime
10 | from collections import OrderedDict
11 |
12 | import torch
13 |
14 | import mmcv
15 | from mmcv.runner import HOOKS
16 | from mmcv.runner import TextLoggerHook
17 |
18 |
19 | @HOOKS.register_module()
20 | class CustomizedTextLoggerHook(TextLoggerHook):
21 | """Customized Text Logger hook.
22 |
23 | This logger prints out both lr and layer_0_lr.
24 |
25 | """
26 |
27 | def _log_info(self, log_dict, runner):
28 | # print exp name for users to distinguish experiments
29 | # at every ``interval_exp_name`` iterations and the end of each epoch
30 | if runner.meta is not None and 'exp_name' in runner.meta:
31 | if (self.every_n_iters(runner, self.interval_exp_name)) or (
32 | self.by_epoch and self.end_of_epoch(runner)):
33 | exp_info = f'Exp name: {runner.meta["exp_name"]}'
34 | runner.logger.info(exp_info)
35 |
36 | if log_dict['mode'] == 'train':
37 | lr_str = {}
38 | for lr_type in ['lr', 'layer_0_lr']:
39 | if isinstance(log_dict[lr_type], dict):
40 | lr_str[lr_type] = []
41 | for k, val in log_dict[lr_type].items():
42 |                         lr_str[lr_type].append(f'{lr_type}_{k}: {val:.3e}')
43 |                     lr_str[lr_type] = ' '.join(lr_str[lr_type])
44 | else:
45 | lr_str[lr_type] = f'{lr_type}: {log_dict[lr_type]:.3e}'
46 |
47 | # by epoch: Epoch [4][100/1000]
48 | # by iter: Iter [100/100000]
49 | if self.by_epoch:
50 | log_str = f'Epoch [{log_dict["epoch"]}]' \
51 | f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
52 | else:
53 | log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
54 | log_str += f'{lr_str["lr"]}, {lr_str["layer_0_lr"]}, '
55 |
56 | if 'time' in log_dict.keys():
57 | self.time_sec_tot += (log_dict['time'] * self.interval)
58 | time_sec_avg = self.time_sec_tot / (
59 | runner.iter - self.start_iter + 1)
60 | eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
61 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
62 | log_str += f'eta: {eta_str}, '
63 | log_str += f'time: {log_dict["time"]:.3f}, ' \
64 | f'data_time: {log_dict["data_time"]:.3f}, '
65 | # statistic memory
66 | if torch.cuda.is_available():
67 | log_str += f'memory: {log_dict["memory"]}, '
68 | else:
69 | # val/test time
70 | # here 1000 is the length of the val dataloader
71 | # by epoch: Epoch[val] [4][1000]
72 | # by iter: Iter[val] [1000]
73 | if self.by_epoch:
74 | log_str = f'Epoch({log_dict["mode"]}) ' \
75 | f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
76 | else:
77 | log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
78 |
79 | log_items = []
80 | for name, val in log_dict.items():
81 | # TODO: resolve this hack
82 | # these items have been in log_str
83 | if name in [
84 | 'mode', 'Epoch', 'iter', 'lr', 'layer_0_lr', 'time', 'data_time',
85 | 'memory', 'epoch'
86 | ]:
87 | continue
88 | if isinstance(val, float):
89 | val = f'{val:.4f}'
90 | log_items.append(f'{name}: {val}')
91 | log_str += ', '.join(log_items)
92 |
93 | runner.logger.info(log_str)
94 |
95 |
96 | def log(self, runner):
97 | if 'eval_iter_num' in runner.log_buffer.output:
98 | # this doesn't modify runner.iter and is regardless of by_epoch
99 | cur_iter = runner.log_buffer.output.pop('eval_iter_num')
100 | else:
101 | cur_iter = self.get_iter(runner, inner_iter=True)
102 |
103 | log_dict = OrderedDict(
104 | mode=self.get_mode(runner),
105 | epoch=self.get_epoch(runner),
106 | iter=cur_iter)
107 |
108 | # record lr and layer_0_lr
109 | cur_lr = runner.current_lr()
110 | if isinstance(cur_lr, list):
111 | log_dict['layer_0_lr'] = min(cur_lr)
112 | log_dict['lr'] = max(cur_lr)
113 | else:
114 | assert isinstance(cur_lr, dict)
115 | log_dict['lr'], log_dict['layer_0_lr'] = {}, {}
116 | for k, lr_ in cur_lr.items():
117 | assert isinstance(lr_, list)
118 | log_dict['layer_0_lr'].update({k: min(lr_)})
119 | log_dict['lr'].update({k: max(lr_)})
120 |
121 | if 'time' in runner.log_buffer.output:
122 | # statistic memory
123 | if torch.cuda.is_available():
124 | log_dict['memory'] = self._get_max_memory(runner)
125 |
126 | log_dict = dict(log_dict, **runner.log_buffer.output)
127 |
128 | self._log_info(log_dict, runner)
129 | self._dump_log(log_dict, runner)
130 | return log_dict
131 |
--------------------------------------------------------------------------------
/downstream_mmdet/mmcv_custom/layer_decay_optimizer_constructor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import json
10 | from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
11 | from mmcv.runner import get_dist_info
12 |
13 |
14 | def get_num_layer_layer_wise(var_name, num_max_layer=12):
15 |
16 | if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"):
17 | return 0
18 | elif var_name.startswith("backbone.downsample_layers"):
19 | stage_id = int(var_name.split('.')[2])
20 | if stage_id == 0:
21 | layer_id = 0
22 | elif stage_id == 1:
23 | layer_id = 2
24 | elif stage_id == 2:
25 | layer_id = 3
26 | elif stage_id == 3:
27 | layer_id = num_max_layer
28 | return layer_id
29 | elif var_name.startswith("backbone.stages"):
30 | stage_id = int(var_name.split('.')[2])
31 | block_id = int(var_name.split('.')[3])
32 | if stage_id == 0:
33 | layer_id = 1
34 | elif stage_id == 1:
35 | layer_id = 2
36 | elif stage_id == 2:
37 | layer_id = 3 + block_id // 3
38 | elif stage_id == 3:
39 | layer_id = num_max_layer
40 | return layer_id
41 | else:
42 | return num_max_layer + 1
43 |
44 |
45 | def get_num_layer_stage_wise(var_name, num_max_layer):
46 | if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"):
47 | return 0
48 | elif var_name.startswith("backbone.downsample_layers"):
49 | return 0
50 | elif var_name.startswith("backbone.stages"):
51 | stage_id = int(var_name.split('.')[2])
52 | return stage_id + 1
53 | else:
54 | return num_max_layer - 1
55 |
56 |
57 | @OPTIMIZER_BUILDERS.register_module()
58 | class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
59 | def add_params(self, params, module, prefix='', is_dcn_module=None):
60 | """Add all parameters of module to the params list.
61 | The parameters of the given module will be added to the list of param
62 | groups, with specific rules defined by paramwise_cfg.
63 | Args:
64 | params (list[dict]): A list of param groups, it will be modified
65 | in place.
66 | module (nn.Module): The module to be added.
67 | prefix (str): The prefix of the module
68 | is_dcn_module (int|float|None): If the current module is a
69 | submodule of DCN, `is_dcn_module` will be passed to
70 | control conv_offset layer's learning rate. Defaults to None.
71 | """
72 | parameter_groups = {}
73 | print(self.paramwise_cfg)
74 | num_layers = self.paramwise_cfg.get('num_layers') + 2
75 | decay_rate = self.paramwise_cfg.get('decay_rate')
76 | decay_type = self.paramwise_cfg.get('decay_type', "layer_wise")
77 | print("Build LearningRateDecayOptimizerConstructor %s %f - %d" % (decay_type, decay_rate, num_layers))
78 | weight_decay = self.base_wd
79 |
80 | for name, param in module.named_parameters():
81 | if not param.requires_grad:
82 | continue # frozen weights
83 | if len(param.shape) == 1 or name.endswith(".bias") or name in ('pos_embed', 'cls_token'):
84 | group_name = "no_decay"
85 | this_weight_decay = 0.
86 | else:
87 | group_name = "decay"
88 | this_weight_decay = weight_decay
89 |
90 | if decay_type == "layer_wise":
91 | layer_id = get_num_layer_layer_wise(name, self.paramwise_cfg.get('num_layers'))
92 | elif decay_type == "stage_wise":
93 | layer_id = get_num_layer_stage_wise(name, num_layers)
94 |
95 | group_name = "layer_%d_%s" % (layer_id, group_name)
96 |
97 | if group_name not in parameter_groups:
98 | scale = decay_rate ** (num_layers - layer_id - 1)
99 |
100 | parameter_groups[group_name] = {
101 | "weight_decay": this_weight_decay,
102 | "params": [],
103 | "param_names": [],
104 | "lr_scale": scale,
105 | "group_name": group_name,
106 | "lr": scale * self.base_lr,
107 | }
108 |
109 | parameter_groups[group_name]["params"].append(param)
110 | parameter_groups[group_name]["param_names"].append(name)
111 | rank, _ = get_dist_info()
112 | if rank == 0:
113 | to_display = {}
114 | for key in parameter_groups:
115 | to_display[key] = {
116 | "param_names": parameter_groups[key]["param_names"],
117 | "lr_scale": parameter_groups[key]["lr_scale"],
118 | "lr": parameter_groups[key]["lr"],
119 | "weight_decay": parameter_groups[key]["weight_decay"],
120 | }
121 | print("Param groups = %s" % json.dumps(to_display, indent=2))
122 |
123 | params.extend(parameter_groups.values())
124 |
--------------------------------------------------------------------------------
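
Worked example of the layer-wise scales that `add_params` above produces for the ConvNeXt-B detection config (`paramwise_cfg`: `num_layers=12`, `decay_rate=0.65`). Inside `add_params`, `num_layers` becomes `12 + 2 = 14`, and every group gets `lr_scale = decay_rate ** (num_layers - layer_id - 1)`; the layer ids below follow `get_num_layer_layer_wise`, and the parameter names are just examples:

```python
decay_rate, num_layers = 0.65, 12 + 2

for name, layer_id in [
    ('backbone.downsample_layers.0.0.weight', 0),          # stem
    ('backbone.stages.2.10.dwconv.weight', 3 + 10 // 3),   # block 10 of stage 3 -> layer_id 6
    ('neck.lateral_convs.0.conv.weight', 12 + 1),          # not a backbone prefix -> num_max_layer + 1
]:
    scale = decay_rate ** (num_layers - layer_id - 1)
    print(f'{name:42s} layer_id={layer_id:2d}  lr_scale={scale:.4f}')
```
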
/downstream_mmdet/mmcv_custom/runner/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 | import os.path as osp
3 | import time
4 | from tempfile import TemporaryDirectory
5 |
6 | import torch
7 | from torch.optim import Optimizer
8 |
9 | import mmcv
10 | from mmcv.parallel import is_module_wrapper
11 | from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
12 |
13 | try:
14 | import apex
15 | except ImportError:
16 | print('apex is not installed')
17 |
18 |
19 | def save_checkpoint(model, filename, optimizer=None, meta=None):
20 | """Save checkpoint to file.
21 |
22 |     The checkpoint will have up to 3 fields: ``meta``, ``state_dict`` and
23 |     ``optimizer`` (saving the ``amp`` state is disabled below). By default ``meta`` will contain version
24 | and time info.
25 |
26 | Args:
27 | model (Module): Module whose params are to be saved.
28 | filename (str): Checkpoint filename.
29 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
30 | meta (dict, optional): Metadata to be saved in checkpoint.
31 | """
32 | if meta is None:
33 | meta = {}
34 | elif not isinstance(meta, dict):
35 | raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
36 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
37 |
38 | if is_module_wrapper(model):
39 | model = model.module
40 |
41 | if hasattr(model, 'CLASSES') and model.CLASSES is not None:
42 | # save class name to the meta
43 | meta.update(CLASSES=model.CLASSES)
44 |
45 | checkpoint = {
46 | 'meta': meta,
47 | 'state_dict': weights_to_cpu(get_state_dict(model))
48 | }
49 | # save optimizer state dict in the checkpoint
50 | if isinstance(optimizer, Optimizer):
51 | checkpoint['optimizer'] = optimizer.state_dict()
52 | elif isinstance(optimizer, dict):
53 | checkpoint['optimizer'] = {}
54 | for name, optim in optimizer.items():
55 | checkpoint['optimizer'][name] = optim.state_dict()
56 |
57 | # save amp state dict in the checkpoint
58 | # checkpoint['amp'] = apex.amp.state_dict()
59 |
60 | if filename.startswith('pavi://'):
61 | try:
62 | from pavi import modelcloud
63 | from pavi.exception import NodeNotFoundError
64 | except ImportError:
65 | raise ImportError(
66 | 'Please install pavi to load checkpoint from modelcloud.')
67 | model_path = filename[7:]
68 | root = modelcloud.Folder()
69 | model_dir, model_name = osp.split(model_path)
70 | try:
71 | model = modelcloud.get(model_dir)
72 | except NodeNotFoundError:
73 | model = root.create_training_model(model_dir)
74 | with TemporaryDirectory() as tmp_dir:
75 | checkpoint_file = osp.join(tmp_dir, model_name)
76 | with open(checkpoint_file, 'wb') as f:
77 | torch.save(checkpoint, f)
78 | f.flush()
79 | model.create_file(checkpoint_file, name=model_name)
80 | else:
81 | mmcv.mkdir_or_exist(osp.dirname(filename))
82 | # immediately flush buffer
83 | with open(filename, 'wb') as f:
84 | torch.save(checkpoint, f)
85 | f.flush()
86 |
--------------------------------------------------------------------------------
/downstream_mmdet/mmdet/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .darknet import Darknet
2 | from .detectors_resnet import DetectoRS_ResNet
3 | from .detectors_resnext import DetectoRS_ResNeXt
4 | from .hourglass import HourglassNet
5 | from .hrnet import HRNet
6 | from .regnet import RegNet
7 | from .res2net import Res2Net
8 | from .resnest import ResNeSt
9 | from .resnet import ResNet, ResNetV1d
10 | from .resnext import ResNeXt
11 | from .ssd_vgg import SSDVGG
12 | from .trident_resnet import TridentResNet
13 | from .swin_transformer import SwinTransformer
14 | from .convnext import ConvNeXt
15 |
16 | __all__ = [
17 | 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
18 | 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
19 | 'ResNeSt', 'TridentResNet', 'SwinTransformer', 'ConvNeXt'
20 | ]
21 |
--------------------------------------------------------------------------------
/downstream_mmdet/mmdet/models/backbones/convnext.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | from functools import partial
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from timm.models.layers import trunc_normal_, DropPath
14 |
15 | from mmcv_custom import load_checkpoint
16 | from mmdet.utils import get_root_logger
17 | from ..builder import BACKBONES
18 |
19 | class Block(nn.Module):
20 | r""" ConvNeXt Block. There are two equivalent implementations:
21 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
22 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
23 | We use (2) as we find it slightly faster in PyTorch
24 |
25 | Args:
26 | dim (int): Number of input channels.
27 | drop_path (float): Stochastic depth rate. Default: 0.0
28 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
29 | """
30 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
31 | super().__init__()
32 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
33 | self.norm = LayerNorm(dim, eps=1e-6)
34 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
35 | self.act = nn.GELU()
36 | self.pwconv2 = nn.Linear(4 * dim, dim)
37 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
38 | requires_grad=True) if layer_scale_init_value > 0 else None
39 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
40 |
41 | def forward(self, x):
42 | input = x
43 | x = self.dwconv(x)
44 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
45 | x = self.norm(x)
46 | x = self.pwconv1(x)
47 | x = self.act(x)
48 | x = self.pwconv2(x)
49 | if self.gamma is not None:
50 | x = self.gamma * x
51 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
52 |
53 | x = input + self.drop_path(x)
54 | return x
55 |
56 | @BACKBONES.register_module()
57 | class ConvNeXt(nn.Module):
58 | r""" ConvNeXt
59 | A PyTorch impl of : `A ConvNet for the 2020s` -
60 | https://arxiv.org/pdf/2201.03545.pdf
61 |
62 | Args:
63 | in_chans (int): Number of input image channels. Default: 3
64 | num_classes (int): Number of classes for classification head. Default: 1000
65 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
66 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
67 | drop_path_rate (float): Stochastic depth rate. Default: 0.
68 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
69 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
70 | """
71 | def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
72 | drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3],
73 | ):
74 | super().__init__()
75 |
76 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
77 | stem = nn.Sequential(
78 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
79 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
80 | )
81 | self.downsample_layers.append(stem)
82 | for i in range(3):
83 | downsample_layer = nn.Sequential(
84 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
85 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
86 | )
87 | self.downsample_layers.append(downsample_layer)
88 |
89 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
90 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
91 | cur = 0
92 | for i in range(4):
93 | stage = nn.Sequential(
94 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
95 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
96 | )
97 | self.stages.append(stage)
98 | cur += depths[i]
99 |
100 | self.out_indices = out_indices
101 |
102 | norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
103 | for i_layer in range(4):
104 | layer = norm_layer(dims[i_layer])
105 | layer_name = f'norm{i_layer}'
106 | self.add_module(layer_name, layer)
107 |
108 | self.apply(self._init_weights)
109 |
110 | def _init_weights(self, m):
111 | if isinstance(m, (nn.Conv2d, nn.Linear)):
112 | trunc_normal_(m.weight, std=.02)
113 | nn.init.constant_(m.bias, 0)
114 |
115 | def init_weights(self, pretrained=None):
116 | """Initialize the weights in backbone.
117 | Args:
118 | pretrained (str, optional): Path to pre-trained weights.
119 | Defaults to None.
120 | """
121 |
122 | def _init_weights(m):
123 | if isinstance(m, nn.Linear):
124 | trunc_normal_(m.weight, std=.02)
125 | if isinstance(m, nn.Linear) and m.bias is not None:
126 | nn.init.constant_(m.bias, 0)
127 | elif isinstance(m, nn.LayerNorm):
128 | nn.init.constant_(m.bias, 0)
129 | nn.init.constant_(m.weight, 1.0)
130 |
131 | if isinstance(pretrained, str):
132 | self.apply(_init_weights)
133 | logger = get_root_logger()
134 | load_checkpoint(self, pretrained, strict=False, logger=logger)
135 | elif pretrained is None:
136 | self.apply(_init_weights)
137 | else:
138 | raise TypeError('pretrained must be a str or None')
139 |
140 | def forward_features(self, x):
141 | outs = []
142 | for i in range(4):
143 | x = self.downsample_layers[i](x)
144 | x = self.stages[i](x)
145 | if i in self.out_indices:
146 | norm_layer = getattr(self, f'norm{i}')
147 | x_out = norm_layer(x)
148 | outs.append(x_out)
149 |
150 | return tuple(outs)
151 |
152 | def forward(self, x):
153 | x = self.forward_features(x)
154 | return x
155 |
156 | class LayerNorm(nn.Module):
157 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
158 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
159 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
160 | with shape (batch_size, channels, height, width).
161 | """
162 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
163 | super().__init__()
164 | self.weight = nn.Parameter(torch.ones(normalized_shape))
165 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
166 | self.eps = eps
167 | self.data_format = data_format
168 | if self.data_format not in ["channels_last", "channels_first"]:
169 | raise NotImplementedError
170 | self.normalized_shape = (normalized_shape, )
171 |
172 | def forward(self, x):
173 | if self.data_format == "channels_last":
174 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
175 | elif self.data_format == "channels_first":
176 | u = x.mean(1, keepdim=True)
177 | s = (x - u).pow(2).mean(1, keepdim=True)
178 | x = (x - u) / torch.sqrt(s + self.eps)
179 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
180 | return x
181 |
--------------------------------------------------------------------------------
/pretrain/README.md:
--------------------------------------------------------------------------------
1 | ## Preparation for ImageNet-1k pretraining
2 |
3 | See [/INSTALL.md](/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
4 |
5 | **Note: for neural network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
6 |
7 |
8 | ## Tutorial for pretraining your own CNN model
9 |
10 | See [/pretrain/models/custom.py](/pretrain/models/custom.py). Your todo list is:
11 |
12 | - implement `get_downsample_ratio` in [/pretrain/models/custom.py line20](/pretrain/models/custom.py#L20).
13 | - implement `get_feature_map_channels` in [/pretrain/models/custom.py line29](/pretrain/models/custom.py#L29).
14 | - implement `forward` in [/pretrain/models/custom.py line38](/pretrain/models/custom.py#L38).
15 | - define `your_convnet(...)` with `@register_model` in [/pretrain/models/custom.py line54](/pretrain/models/custom.py#L53-L54).
16 | - add default kwargs of `your_convnet(...)` in [/pretrain/models/\_\_init\_\_.py line34](/pretrain/models/__init__.py#L34).
17 | - **Note: see [#54](/../../issues/54) if your CNN contains SE module or global average pooling layer, and see [#56](/../../issues/56) if it contains GroupNorm**.
18 |
19 | Then run the experiment with `--model=your_convnet`.
20 |
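For reference, below is a minimal sketch of what such a model could look like. The tiny 4-stage CNN here (`TinyConvNet`) is a hypothetical example, not part of this repo; it only illustrates the three required functions and the `@register_model` entry point:

```python
# A hypothetical toy CNN, only to illustrate the interface required by
# /pretrain/models/custom.py. Replace it with your real architecture.
from typing import List

import torch
import torch.nn as nn
from timm.models.registry import register_model


class TinyConvNet(nn.Module):
    def __init__(self, dims=(64, 128, 256, 512), num_classes=0, **kwargs):
        super().__init__()
        self.dims = list(dims)
        in_chs = [3] + self.dims[:-1]
        strides = [4, 2, 2, 2]  # total downsample ratio = 4*2*2*2 = 32
        self.stages = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(ci, co, kernel_size=s, stride=s),
                nn.BatchNorm2d(co),
                nn.ReLU(inplace=True),
            )
            for ci, co, s in zip(in_chs, self.dims, strides)
        )
        self.fc = nn.Linear(self.dims[-1], num_classes) if num_classes > 0 else nn.Identity()

    def get_downsample_ratio(self) -> int:
        return 32  # the product of all strides above

    def get_feature_map_channels(self) -> List[int]:
        return self.dims  # channels of the 4 feature maps returned below

    def forward(self, x: torch.Tensor, hierarchical=False):
        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)
        if hierarchical:
            return feats  # 4 maps at 1/4, 1/8, 1/16, 1/32 of the input resolution
        return self.fc(x.mean(dim=(2, 3)))  # (B, C, H, W) => (B, num_classes)


@register_model
def your_convnet(pretrained=False, **kwargs):
    return TinyConvNet(**kwargs)
```

Since `pretrain_default_model_kwargs` in [/pretrain/models/\_\_init\_\_.py](/pretrain/models/__init__.py) already contains an (empty) `'your_convnet'` entry, `--model=your_convnet` would then build this network through `timm.create_model`.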
21 |
22 | ## Tutorial for pretraining on your own dataset
23 |
24 | See the comment of `build_dataset_to_pretrain` in [line55 of /pretrain/utils/imagenet.py](/pretrain/utils/imagenet.py#L55). Your todo list:
25 |
26 | - Define a subclass of `torch.utils.data.Dataset` for your own unlabeled dataset, to replace our `ImageNetDataset`.
27 | - Use `args.data_path` and `args.input_size` to help build your dataset, with `--data_path=... --input_size=...` to specify them.
28 | - Note that the batch size `--bs` is the total batch size across all GPUs, which may need to be adjusted based on your dataset size. FYI: we use `--bs=4096` for ImageNet, which contains 1.28 million images.
29 |
30 | **If your dataset is relatively small**, you can try `--init_weight=/path/to/res50_withdecoder_1kpretrained_spark_style.pth` to do your pretraining *from our pretrained weights*, rather than *from scratch*.
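
As a reference, here is a minimal sketch of such a dataset (a hypothetical example that assumes a flat folder of `*.jpg` images under `args.data_path`; adapt the file listing, transforms, and normalization to your own data):

```python
# A hypothetical unlabeled dataset for pretraining; replace ImageNetDataset
# with something like this inside `build_dataset_to_pretrain`.
from pathlib import Path

import PIL.Image as PImage
from torch.utils.data import Dataset
from torchvision import transforms


class MyUnlabeledDataset(Dataset):
    def __init__(self, data_path: str, input_size: int):
        self.files = sorted(Path(data_path).rglob('*.jpg'))  # adapt to your file types
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # add a Normalize(...) here if your data needs it
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = PImage.open(self.files[idx]).convert('RGB')
        return self.transform(img)  # pretraining only needs the image tensor, no label
```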
31 |
32 | ## Debug on 1 GPU (without DistributedDataParallel)
33 |
34 | Use a small batch size like `--bs=32` to avoid OOM.
35 |
36 | ```shell script
37 | python3 main.py --exp_name=debug --data_path=/path/to/imagenet --model=resnet50 --bs=32
38 | ```
39 |
40 |
41 | ## Pretraining Any Model on ImageNet-1k (224x224)
42 |
43 | For pretraining, run [/pretrain/main.py](/pretrain/main.py) with `torchrun`.
44 | **It is required to specify** the ImageNet data folder (`--data_path`), your experiment name & log dir (`--exp_name` and `--exp_dir`, automatically created if they do not exist), and the model name (`--model`; for valid choices, see the keys of `pretrain_default_model_kwargs` in [/pretrain/models/\_\_init\_\_.py line34](/pretrain/models/__init__.py#L34)).
45 |
46 | We use the **same** pretraining configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts) in 224 pretraining.
47 | Their **names** and **default values** are in [/pretrain/utils/arg_util.py line23-44](/pretrain/utils/arg_util.py#L23-L44).
48 | All these default configurations (like batch size 4096) will be used unless you override them, e.g., with `--bs=512`.
49 |
50 | **Note: the batch size `--bs` is the total batch size across all GPUs, and `--base_lr` is the base learning rate. The actual peak lr is `lr = base_lr * bs / 256`, as in [/pretrain/utils/arg_util.py line131](/pretrain/utils/arg_util.py#L131); e.g., the defaults `--base_lr=2e-4` and `--bs=4096` give `lr = 2e-4 * 4096 / 256 = 3.2e-3`. So do not use `--lr` to specify the lr (it will be ignored).**
51 |
52 | Here is an example of pretraining a ResNet-50 on a single 8-GPU machine (we use DistributedDataParallel), overriding the default batch size with 512:
53 | ```shell script
54 | $ cd /path/to/SparK/pretrain
55 | $ torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port= main.py \
56 | --data_path=/path/to/imagenet --exp_name= --exp_dir=/path/to/logdir \
57 | --model=resnet50 --bs=512
58 | ```
59 |
60 | For multiple machines, change `--nnodes`, `--node_rank`, `--master_addr`, and `--master_port` to match your cluster configuration. E.g.:
61 | ```shell script
62 | $ torchrun --nproc_per_node=8 --nnodes= --node_rank= --master_addr= --master_port= main.py \
63 | ...
64 | ```
65 |
66 | ## Pretraining ConvNeXt-Large on ImageNet-1k (384x384)
67 |
68 | For 384 pretraining we use a larger mask ratio (0.75), a halved batch size (2048), and a doubled base learning rate (4e-4):
69 |
70 | ```shell script
71 | $ cd /path/to/SparK/pretrain
72 | $ torchrun --nproc_per_node=8 --nnodes= --node_rank= --master_addr= --master_port= main.py \
73 | --data_path=/path/to/imagenet --exp_name= --exp_dir=/path/to/logdir \
74 | --model=convnext_large --input_size=384 --mask=0.75 --bs=2048 --base_lr=4e-4
75 | ```
76 |
77 | ## Logging
78 |
79 | See files in your `--exp_dir` to track your experiment:
80 |
81 | - `_withdecoder_1kpretrained_spark_style.pth`: saves model and optimizer states, current epoch, current reconstruction loss, etc.; can be used to resume pretraining; can also be used for visualization in [/pretrain/viz_reconstruction.ipynb](/pretrain/viz_reconstruction.ipynb)
82 | - `_1kpretrained_timm_style.pth`: can be used for downstream finetuning
83 | - `pretrain_log.txt`: records some important information such as:
84 | - `git_commit_id`: git version
85 | - `cmd`: the command of this experiment
86 |
87 | It also reports the loss and remaining pretraining time.
88 |
89 | - `tensorboard_log/`: saves many tensorboard logs, including loss values, learning rates, gradient norms, and more. Use `tensorboard --logdir /path/to/this/tensorboard_log/ --port 23333` to view them.
90 | - `stdout_backup.txt` and `stderr_backup.txt`: backups of stdout and stderr.
91 |
92 | ## Resuming
93 |
94 | Specify `--resume_from=path/to/_withdecoder_1kpretrained_spark_style.pth` to resume pretraining. Note this is different from `--init_weight`:
95 |
96 | - `--resume_from` will load three things: the model weights, the optimizer states, and the current epoch, so it is used to resume an interrupted experiment (it will start from that 'current epoch').
97 | - `--init_weight` ONLY loads the model weights, so it's just like a model initialization (it will start from epoch 0).
98 |
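For example, to resume a ResNet-50 run on one 8-GPU machine (hypothetical paths; the checkpoint name follows the `<model>_withdecoder_1kpretrained_spark_style.pth` pattern saved in your `--exp_dir` during pretraining):

```shell script
$ cd /path/to/SparK/pretrain
$ torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port= main.py \
  --data_path=/path/to/imagenet --exp_name= --exp_dir=/path/to/logdir --model=resnet50 \
  --resume_from=/path/to/logdir/resnet50_withdecoder_1kpretrained_spark_style.pth
```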
99 |
100 | ## Regarding sparse convolution
101 |
102 | We do not use sparse convolutions in this PyTorch implementation, due to their limited optimization on modern hardware.
103 | As can be seen in [/pretrain/encoder.py](/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution.
104 | We also define some sparse pooling and normalization layers in [/pretrain/encoder.py](/pretrain/encoder.py).
105 | All these "sparse" layers are implemented with PyTorch built-in operators.
106 |
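For intuition, here is a tiny self-contained sketch (not the repo's exact code) of the idea behind `sp_conv_forward` in [/pretrain/encoder.py](/pretrain/encoder.py): a dense convolution produces non-zero values at masked positions (its kernel overlaps unmasked neighbors), so the output is multiplied by the active mask again, keeping masked positions at zero and thus simulating a submanifold sparse convolution with ordinary dense operators:

```python
import torch
import torch.nn as nn

B, C, H, W = 2, 3, 8, 8
x = torch.randn(B, C, H, W)
active = (torch.rand(B, 1, H, W) > 0.6).float()  # 1 = unmasked (active), 0 = masked

conv = nn.Conv2d(C, 16, kernel_size=3, padding=1)
out = conv(x * active)  # dense conv over the masked input
out = out * active      # re-mask the output, as `sp_conv_forward` does
print(out.shape)        # torch.Size([2, 16, 8, 8])
```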
107 |
108 | ## Some details: how we mask images and how to set the patch size
109 |
110 | In SparK, the mask patch size **equals** the downsample ratio of the CNN model (so there is no configuration like `--patch_size=32`).
111 |
112 | Here is the reason: when we generate the mask, we:
113 |
114 | 1. first generate the binary mask for the **smallest** resolution feature map, i.e., generate the `_cur_active` or `active_b1ff` in [/pretrain/spark.py line86-87](/pretrain/spark.py#L86-L87), which is a `torch.BoolTensor` shaped as `[B, 1, fmap_h, fmap_w]`, and will be used to mask the smallest feature map.
115 | 2. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., dim=2)` and `repeat_interleave(..., dim=3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py#L16)), to mask those feature maps ([`x` in line21](/pretrain/encoder.py#L21)) with larger resolutions.
116 |
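For example, with `--input_size=224` and a ResNet-50 (downsample ratio 32), the smallest feature map is 7x7; here is a simplified sketch of how the mask is generated and then upsampled (mirroring `SparK.mask` in [/pretrain/spark.py](/pretrain/spark.py) and the `repeat_interleave` calls in [/pretrain/encoder.py](/pretrain/encoder.py)):

```python
import torch

B, fmap_h, fmap_w = 2, 7, 7        # 224 / 32 = 7
mask_ratio, downsample_ratio = 0.6, 32

len_keep = round(fmap_h * fmap_w * (1 - mask_ratio))
idx = torch.rand(B, fmap_h * fmap_w).argsort(dim=1)[:, :len_keep]  # indices of kept patches
active_b1ff = torch.zeros(B, fmap_h * fmap_w, dtype=torch.bool)
active_b1ff.scatter_(dim=1, index=idx, value=True)
active_b1ff = active_b1ff.view(B, 1, fmap_h, fmap_w)                # (B, 1, f, f): True = keep

# upsample the mask to any larger feature map (here: the input resolution),
# so each mask patch covers a (downsample_ratio x downsample_ratio) region
active_b1hw = active_b1ff.repeat_interleave(downsample_ratio, dim=2) \
                         .repeat_interleave(downsample_ratio, dim=3)  # (B, 1, 224, 224)
print(active_b1ff.shape, active_b1hw.shape)
```
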
117 | So if you want a patch size of 16 or 8, you should actually define a new CNN model with a downsample ratio of 16 or 8.
118 | See [Tutorial for pretraining your own CNN model (above)](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-cnn-model).
119 |
--------------------------------------------------------------------------------
/pretrain/decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from typing import List
9 |
10 | import torch
11 | import torch.nn as nn
12 | from timm.models.layers import trunc_normal_
13 |
14 | from utils.misc import is_pow2n
15 |
16 |
17 | class UNetBlock(nn.Module):
18 | def __init__(self, cin, cout, bn2d):
19 | """
20 | a UNet block with 2x up sampling
21 | """
22 | super().__init__()
23 | self.up_sample = nn.ConvTranspose2d(cin, cin, kernel_size=4, stride=2, padding=1, bias=True)
24 | self.conv = nn.Sequential(
25 | nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False), bn2d(cin), nn.ReLU6(inplace=True),
26 | nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False), bn2d(cout),
27 | )
28 |
29 | def forward(self, x):
30 | x = self.up_sample(x)
31 | return self.conv(x)
32 |
33 |
34 | class LightDecoder(nn.Module):
35 | def __init__(self, up_sample_ratio, width=768, sbn=True): # todo: the decoder's width follows a simple halfing rule; you can change it to any other rule
36 | super().__init__()
37 | self.width = width
38 | assert is_pow2n(up_sample_ratio)
39 | n = round(math.log2(up_sample_ratio))
40 | channels = [self.width // 2 ** i for i in range(n + 1)] # todo: the decoder's width follows a simple halfing rule; you can change it to any other rule
41 | bn2d = nn.SyncBatchNorm if sbn else nn.BatchNorm2d
42 | self.dec = nn.ModuleList([UNetBlock(cin, cout, bn2d) for (cin, cout) in zip(channels[:-1], channels[1:])])
43 | self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True)
44 |
45 | self.initialize()
46 |
47 | def forward(self, to_dec: List[torch.Tensor]):
48 | x = 0
49 | for i, d in enumerate(self.dec):
50 | if i < len(to_dec) and to_dec[i] is not None:
51 | x = x + to_dec[i]
52 | x = self.dec[i](x)
53 | return self.proj(x)
54 |
55 | def extra_repr(self) -> str:
56 | return f'width={self.width}'
57 |
58 | def initialize(self):
59 | for m in self.modules():
60 | if isinstance(m, nn.Linear):
61 | trunc_normal_(m.weight, std=.02)
62 | if m.bias is not None:
63 | nn.init.constant_(m.bias, 0)
64 | elif isinstance(m, nn.Conv2d):
65 | trunc_normal_(m.weight, std=.02)
66 | if m.bias is not None:
67 | nn.init.constant_(m.bias, 0)
68 |             elif isinstance(m, nn.ConvTranspose2d):  # nn.Conv2d is already handled above; only ConvTranspose2d reaches this branch
69 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
70 | if m.bias is not None:
71 | nn.init.constant_(m.bias, 0.)
72 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)):
73 | nn.init.constant_(m.bias, 0)
74 | nn.init.constant_(m.weight, 1.0)
75 |
--------------------------------------------------------------------------------
/pretrain/dist.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | from typing import List
9 | from typing import Union
10 |
11 | import sys
12 | import torch
13 | import torch.distributed as tdist
14 | import torch.multiprocessing as mp
15 |
16 | __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
17 | __initialized = False
18 |
19 |
20 | def initialized():
21 | return __initialized
22 |
23 |
24 | def initialize(backend='nccl'):
25 | global __device
26 | if not torch.cuda.is_available():
27 | print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
28 | return
29 | elif 'RANK' not in os.environ:
30 | __device = torch.empty(1).cuda().device
31 | print(f'[dist initialize] RANK is not set, use 1 GPU instead', file=sys.stderr)
32 | return
33 |
34 | # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
35 | if mp.get_start_method(allow_none=True) is None:
36 | mp.set_start_method('spawn')
37 | global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
38 | local_rank = global_rank % num_gpus
39 | torch.cuda.set_device(local_rank)
40 | tdist.init_process_group(backend=backend)
41 |
42 | global __rank, __local_rank, __world_size, __initialized
43 | __local_rank = local_rank
44 | __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
45 | __device = torch.empty(1).cuda().device
46 | __initialized = True
47 |
48 | assert tdist.is_initialized(), 'torch.distributed is not initialized!'
49 |
50 |
51 | def get_rank():
52 | return __rank
53 |
54 |
55 | def get_local_rank():
56 | return __local_rank
57 |
58 |
59 | def get_world_size():
60 | return __world_size
61 |
62 |
63 | def get_device():
64 | return __device
65 |
66 |
67 | def is_master():
68 | return __rank == 0
69 |
70 |
71 | def is_local_master():
72 | return __local_rank == 0
73 |
74 |
75 | def barrier():
76 | if __initialized:
77 | tdist.barrier()
78 |
79 |
80 | def parallelize(net, syncbn=False):
81 | if syncbn:
82 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
83 | net = net.cuda()
84 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
85 | return net
86 |
87 |
88 | def allreduce(t: torch.Tensor) -> None:
89 | if __initialized:
90 | if not t.is_cuda:
91 | cu = t.detach().cuda()
92 | tdist.all_reduce(cu)
93 | t.copy_(cu.cpu())
94 | else:
95 | tdist.all_reduce(t)
96 |
97 |
98 | def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
99 | if __initialized:
100 | if not t.is_cuda:
101 | t = t.cuda()
102 | ls = [torch.empty_like(t) for _ in range(__world_size)]
103 | tdist.all_gather(ls, t)
104 | else:
105 | ls = [t]
106 | if cat:
107 | ls = torch.cat(ls, dim=0)
108 | return ls
109 |
110 |
111 | def broadcast(t: torch.Tensor, src_rank) -> None:
112 | if __initialized:
113 | if not t.is_cuda:
114 | cu = t.detach().cuda()
115 | tdist.broadcast(cu, src=src_rank)
116 | t.copy_(cu.cpu())
117 | else:
118 | tdist.broadcast(t, src=src_rank)
119 |
--------------------------------------------------------------------------------
/pretrain/encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from timm.models.layers import DropPath
10 |
11 |
12 | _cur_active: torch.Tensor = None # B1ff
13 | # todo: try to use `gather` for speed?
14 | def _get_active_ex_or_ii(H, W, returning_active_ex=True):
15 | h_repeat, w_repeat = H // _cur_active.shape[-2], W // _cur_active.shape[-1]
16 | active_ex = _cur_active.repeat_interleave(h_repeat, dim=2).repeat_interleave(w_repeat, dim=3)
17 | return active_ex if returning_active_ex else active_ex.squeeze(1).nonzero(as_tuple=True) # ii: bi, hi, wi
18 |
19 |
20 | def sp_conv_forward(self, x: torch.Tensor):
21 | x = super(type(self), self).forward(x)
22 | x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True) # (BCHW) *= (B1HW), mask the output of conv
23 | return x
24 |
25 |
26 | def sp_bn_forward(self, x: torch.Tensor):
27 | ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)
28 |
29 | bhwc = x.permute(0, 2, 3, 1)
30 | nc = bhwc[ii] # select the features on non-masked positions to form a flatten feature `nc`
31 | nc = super(type(self), self).forward(nc) # use BN1d to normalize this flatten feature `nc`
32 |
33 | bchw = torch.zeros_like(bhwc)
34 | bchw[ii] = nc
35 | bchw = bchw.permute(0, 3, 1, 2)
36 | return bchw
37 |
38 |
39 | class SparseConv2d(nn.Conv2d):
40 | forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details
41 |
42 |
43 | class SparseMaxPooling(nn.MaxPool2d):
44 | forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details
45 |
46 |
47 | class SparseAvgPooling(nn.AvgPool2d):
48 | forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details
49 |
50 |
51 | class SparseBatchNorm2d(nn.BatchNorm1d):
52 | forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details
53 |
54 |
55 | class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
56 | forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details
57 |
58 |
59 | class SparseConvNeXtLayerNorm(nn.LayerNorm):
60 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
61 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
62 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
63 | with shape (batch_size, channels, height, width).
64 | """
65 |
66 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
67 | if data_format not in ["channels_last", "channels_first"]:
68 | raise NotImplementedError
69 | super().__init__(normalized_shape, eps, elementwise_affine=True)
70 | self.data_format = data_format
71 | self.sparse = sparse
72 |
73 | def forward(self, x):
74 | if x.ndim == 4: # BHWC or BCHW
75 | if self.data_format == "channels_last": # BHWC
76 | if self.sparse:
77 | ii = _get_active_ex_or_ii(H=x.shape[1], W=x.shape[2], returning_active_ex=False)
78 | nc = x[ii]
79 | nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
80 |
81 | x = torch.zeros_like(x)
82 | x[ii] = nc
83 | return x
84 | else:
85 | return super(SparseConvNeXtLayerNorm, self).forward(x)
86 | else: # channels_first, BCHW
87 | if self.sparse:
88 | ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)
89 | bhwc = x.permute(0, 2, 3, 1)
90 | nc = bhwc[ii]
91 | nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
92 |
93 | x = torch.zeros_like(bhwc)
94 | x[ii] = nc
95 | return x.permute(0, 3, 1, 2)
96 | else:
97 | u = x.mean(1, keepdim=True)
98 | s = (x - u).pow(2).mean(1, keepdim=True)
99 | x = (x - u) / torch.sqrt(s + self.eps)
100 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
101 | return x
102 | else: # BLC or BC
103 | if self.sparse:
104 | raise NotImplementedError
105 | else:
106 | return super(SparseConvNeXtLayerNorm, self).forward(x)
107 |
108 | def __repr__(self):
109 | return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'
110 |
111 |
112 | class SparseConvNeXtBlock(nn.Module):
113 | r""" ConvNeXt Block. There are two equivalent implementations:
114 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
115 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
116 | We use (2) as we find it slightly faster in PyTorch
117 |
118 | Args:
119 | dim (int): Number of input channels.
120 | drop_path (float): Stochastic depth rate. Default: 0.0
121 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
122 | """
123 |
124 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7):
125 | super().__init__()
126 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks//2, groups=dim) # depthwise conv
127 | self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse)
128 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
129 | self.act = nn.GELU()
130 | self.pwconv2 = nn.Linear(4 * dim, dim)
131 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
132 | requires_grad=True) if layer_scale_init_value > 0 else None
133 | self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. else nn.Identity()
134 | self.sparse = sparse
135 |
136 | def forward(self, x):
137 | input = x
138 | x = self.dwconv(x)
139 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
140 | x = self.norm(x)
141 | x = self.pwconv1(x)
142 | x = self.act(x) # GELU(0) == (0), so there is no need to mask x (no need to `x *= _get_active_ex_or_ii`)
143 | x = self.pwconv2(x)
144 | if self.gamma is not None:
145 | x = self.gamma * x
146 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
147 |
148 | if self.sparse:
149 | x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True)
150 |
151 | x = input + self.drop_path(x)
152 | return x
153 |
154 | def __repr__(self):
155 | return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})'
156 |
157 |
158 | class SparseEncoder(nn.Module):
159 | def __init__(self, cnn, input_size, sbn=False, verbose=False):
160 | super(SparseEncoder, self).__init__()
161 | self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn)
162 | self.input_size, self.downsample_raito, self.enc_feat_map_chs = input_size, cnn.get_downsample_ratio(), cnn.get_feature_map_channels()
163 |
164 | @staticmethod
165 | def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
166 | oup = m
167 | if isinstance(m, nn.Conv2d):
168 | m: nn.Conv2d
169 | bias = m.bias is not None
170 | oup = SparseConv2d(
171 | m.in_channels, m.out_channels,
172 | kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
173 | dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode,
174 | )
175 | oup.weight.data.copy_(m.weight.data)
176 | if bias:
177 | oup.bias.data.copy_(m.bias.data)
178 | elif isinstance(m, nn.MaxPool2d):
179 | m: nn.MaxPool2d
180 | oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation, return_indices=m.return_indices, ceil_mode=m.ceil_mode)
181 | elif isinstance(m, nn.AvgPool2d):
182 | m: nn.AvgPool2d
183 | oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode, count_include_pad=m.count_include_pad, divisor_override=m.divisor_override)
184 | elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
185 | m: nn.BatchNorm2d
186 | oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(m.weight.shape[0], eps=m.eps, momentum=m.momentum, affine=m.affine, track_running_stats=m.track_running_stats)
187 | oup.weight.data.copy_(m.weight.data)
188 | oup.bias.data.copy_(m.bias.data)
189 | oup.running_mean.data.copy_(m.running_mean.data)
190 | oup.running_var.data.copy_(m.running_var.data)
191 | oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
192 | if hasattr(m, "qconfig"):
193 | oup.qconfig = m.qconfig
194 | elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm):
195 | m: nn.LayerNorm
196 | oup = SparseConvNeXtLayerNorm(m.weight.shape[0], eps=m.eps)
197 | oup.weight.data.copy_(m.weight.data)
198 | oup.bias.data.copy_(m.bias.data)
199 | elif isinstance(m, (nn.Conv1d,)):
200 | raise NotImplementedError
201 |
202 | for name, child in m.named_children():
203 | oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn))
204 | del m
205 | return oup
206 |
207 | def forward(self, x):
208 | return self.sp_cnn(x, hierarchical=True)
209 |
--------------------------------------------------------------------------------
/pretrain/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import datetime
8 | import math
9 | import sys
10 | import time
11 | from functools import partial
12 | from typing import List
13 |
14 | import torch
15 | from torch.nn.parallel import DistributedDataParallel
16 | from torch.utils.data import DataLoader
17 |
18 | import dist
19 | import encoder
20 | from decoder import LightDecoder
21 | from models import build_sparse_encoder
22 | from sampler import DistInfiniteBatchSampler, worker_init_fn
23 | from spark import SparK
24 | from utils import arg_util, misc, lamb
25 | from utils.imagenet import build_dataset_to_pretrain
26 | from utils.lr_control import lr_wd_annealing, get_param_groups
27 |
28 |
29 | class LocalDDP(torch.nn.Module):
30 | def __init__(self, module):
31 | super(LocalDDP, self).__init__()
32 | self.module = module
33 |
34 | def forward(self, *args, **kwargs):
35 | return self.module(*args, **kwargs)
36 |
37 |
38 | def main_pt():
39 | args: arg_util.Args = arg_util.init_dist_and_get_args()
40 | print(f'initial args:\n{str(args)}')
41 | args.log_epoch()
42 |
43 | # build data
44 | print(f'[build data for pre-training] ...\n')
45 | dataset_train = build_dataset_to_pretrain(args.data_path, args.input_size)
46 | data_loader_train = DataLoader(
47 | dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
48 | batch_sampler=DistInfiniteBatchSampler(
49 | dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
50 | shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
51 | ), worker_init_fn=worker_init_fn
52 | )
53 | itrt_train, iters_train = iter(data_loader_train), len(data_loader_train)
54 | print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size_per_gpu}, iters_train={iters_train}')
55 |
56 | # build encoder and decoder
57 | enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
58 | dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
59 | model_without_ddp = SparK(
60 | sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
61 | densify_norm=args.densify_norm, sbn=args.sbn,
62 | ).to(args.device)
63 | print(f'[PT model] model = {model_without_ddp}\n')
64 |
65 | # the model has been randomly initialized in their construction time
66 | # now try to load some checkpoint as model weight initialization; this ONLY loads the model weights
67 | misc.initialize_weight(args.init_weight, model_without_ddp)
68 |
69 | if dist.initialized():
70 | model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
71 | else:
72 | model = LocalDDP(model_without_ddp)
73 |
74 | # build optimizer and lr_scheduler
75 | param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'})
76 | opt_clz = {
77 | 'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
78 | 'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
79 | 'lamb': partial(lamb.TheSameAsTimmLAMB, betas=(0.9, args.ada), max_grad_norm=5.0),
80 | }[args.opt]
81 | optimizer = opt_clz(params=param_groups, lr=args.lr, weight_decay=0.0)
82 | print(f'[optimizer] optimizer({opt_clz}) ={optimizer}\n')
83 |
84 | # try to resume the experiment from some checkpoint.pth; this will load model weights, optimizer states, and last epoch (ep_start)
85 | # if loaded, ep_start will be greater than 0
86 | ep_start, performance_desc = misc.load_checkpoint(args.resume_from, model_without_ddp, optimizer)
87 | if ep_start >= args.ep: # load from a complete checkpoint file
88 | print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
89 | else: # perform pre-training
90 | tb_lg = misc.TensorboardLogger(args.tb_lg_dir, is_master=dist.is_master(), prefix='pt')
91 | min_loss = 1e9
92 | print(f'[PT start] from ep{ep_start}')
93 |
94 | pt_start_time = time.time()
95 | for ep in range(ep_start, args.ep):
96 | ep_start_time = time.time()
97 | tb_lg.set_step(ep * iters_train)
98 | if hasattr(itrt_train, 'set_epoch'):
99 | itrt_train.set_epoch(ep)
100 |
101 | stats = pre_train_one_ep(ep, args, tb_lg, itrt_train, iters_train, model, optimizer)
102 | last_loss = stats['last_loss']
103 | min_loss = min(min_loss, last_loss)
104 | performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
105 | misc.save_checkpoint_with_meta_info_and_opt_state(f'{args.model}_withdecoder_1kpretrained_spark_style.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
106 | misc.save_checkpoint_model_weights_only(f'{args.model}_1kpretrained_timm_style.pth', args, model_without_ddp.sparse_encoder.sp_cnn.state_dict())
107 |
108 | ep_cost = round(time.time() - ep_start_time, 2) + 1 # +1s: approximate the following logging cost
109 | remain_secs = (args.ep-1 - ep) * ep_cost
110 | remain_time = datetime.timedelta(seconds=round(remain_secs))
111 | finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
112 | print(f' [*] [ep{ep}/{args.ep}] Min/Last Recon Loss: {performance_desc}, Cost: {ep_cost}s, Remain: {remain_time}, Finish @ {finish_time}')
113 |
114 | args.cur_ep = f'{ep + 1}/{args.ep}'
115 | args.remain_time, args.finish_time = str(remain_time), str(finish_time)
116 | args.last_loss = last_loss
117 | args.log_epoch()
118 |
119 | tb_lg.update(min_loss=min_loss, head='train', step=ep)
120 | tb_lg.update(rest_hours=round(remain_secs/60/60, 2), head='z_burnout', step=ep)
121 | tb_lg.flush()
122 |
123 | # finish pre-training
124 | tb_lg.update(min_loss=min_loss, head='result', step=ep_start)
125 | tb_lg.update(min_loss=min_loss, head='result', step=args.ep)
126 | tb_lg.flush()
127 | print(f'final args:\n{str(args)}')
128 | print('\n\n')
129 | print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - pt_start_time) / 60 / 60:.1f}h\n')
130 | print('\n\n')
131 | tb_lg.close()
132 | time.sleep(10)
133 |
134 | args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
135 | args.log_epoch()
136 |
137 |
138 | def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
139 | model.train()
140 | me = misc.MetricLogger(delimiter=' ')
141 | me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
142 | header = f'[PT] Epoch {ep}:'
143 |
144 | optimizer.zero_grad()
145 | early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
146 | late_clipping = hasattr(optimizer, 'global_grad_norm')
147 | if early_clipping:
148 | params_req_grad = [p for p in model.parameters() if p.requires_grad]
149 |
150 | for it, inp in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
151 | # adjust lr and wd
152 | min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, it + ep * iters_train, args.wp_ep * iters_train, args.ep * iters_train)
153 |
154 | # forward and backward
155 | inp = inp.to(args.device, non_blocking=True)
156 |         SparK.forward  # no-op reference; the `model(...)` call below dispatches to SparK.forward through the (Local)DDP wrapper
157 | loss = model(inp, active_b1ff=None, vis=False)
158 | optimizer.zero_grad()
159 | loss.backward()
160 | loss = loss.item()
161 | if not math.isfinite(loss):
162 | print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)
163 | sys.exit(-1)
164 |
165 | # optimize
166 | grad_norm = None
167 | if early_clipping: grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
168 | optimizer.step()
169 | if late_clipping: grad_norm = optimizer.global_grad_norm
170 | torch.cuda.synchronize()
171 |
172 | # log
173 | me.update(last_loss=loss)
174 | me.update(max_lr=max_lr)
175 | tb_lg.update(loss=me.meters['last_loss'].global_avg, head='train_loss')
176 | tb_lg.update(sche_lr=max_lr, head='train_hp/lr_max')
177 | tb_lg.update(sche_lr=min_lr, head='train_hp/lr_min')
178 | tb_lg.update(sche_wd=max_wd, head='train_hp/wd_max')
179 | tb_lg.update(sche_wd=min_wd, head='train_hp/wd_min')
180 |
181 | if grad_norm is not None:
182 | me.update(orig_norm=grad_norm)
183 | tb_lg.update(orig_norm=grad_norm, head='train_hp')
184 | tb_lg.set_step()
185 |
186 | me.synchronize_between_processes()
187 | return {k: meter.global_avg for k, meter in me.meters.items()}
188 |
189 |
190 | if __name__ == '__main__':
191 | main_pt()
192 |
--------------------------------------------------------------------------------
/pretrain/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from timm import create_model
9 | from timm.loss import SoftTargetCrossEntropy
10 | from timm.models.layers import drop
11 |
12 |
13 | from models.convnext import ConvNeXt
14 | from models.resnet import ResNet
15 | from models.custom import YourConvNet
16 | _import_resnets_for_timm_registration = (ResNet,)
17 |
18 |
19 | # log more
20 | def _ex_repr(self):
21 | return ', '.join(
22 | f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
23 | for k, v in vars(self).items()
24 | if not k.startswith('_') and k != 'training'
25 | and not isinstance(v, (torch.nn.Module, torch.Tensor))
26 | )
27 | for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
28 | if hasattr(clz, 'extra_repr'):
29 | clz.extra_repr = _ex_repr
30 | else:
31 | clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
32 |
33 |
34 | pretrain_default_model_kwargs = {
35 | 'your_convnet': dict(),
36 | 'resnet50': dict(drop_path_rate=0.05),
37 | 'resnet101': dict(drop_path_rate=0.08),
38 | 'resnet152': dict(drop_path_rate=0.10),
39 | 'resnet200': dict(drop_path_rate=0.15),
40 | 'convnext_small': dict(sparse=True, drop_path_rate=0.2),
41 | 'convnext_base': dict(sparse=True, drop_path_rate=0.3),
42 | 'convnext_large': dict(sparse=True, drop_path_rate=0.4),
43 | }
44 | for kw in pretrain_default_model_kwargs.values():
45 | kw['pretrained'] = False
46 | kw['num_classes'] = 0
47 | kw['global_pool'] = ''
48 |
49 |
50 | def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
51 | from encoder import SparseEncoder
52 |
53 | kwargs = pretrain_default_model_kwargs[name]
54 | if drop_path_rate != 0:
55 | kwargs['drop_path_rate'] = drop_path_rate
56 | print(f'[build_sparse_encoder] model kwargs={kwargs}')
57 | cnn = create_model(name, **kwargs)
58 |
59 | return SparseEncoder(cnn, input_size=input_size, sbn=sbn, verbose=verbose)
60 |
61 |
--------------------------------------------------------------------------------
/pretrain/models/convnext.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # This file is basically a copy of: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py
8 | from typing import List
9 |
10 | import torch
11 | import torch.nn as nn
12 | from timm.models.layers import trunc_normal_
13 | from timm.models.registry import register_model
14 |
15 | from encoder import SparseConvNeXtBlock, SparseConvNeXtLayerNorm
16 |
17 |
18 | class ConvNeXt(nn.Module):
19 | r""" ConvNeXt
20 | A PyTorch impl of : `A ConvNet for the 2020s` -
21 | https://arxiv.org/pdf/2201.03545.pdf
22 | Args:
23 | in_chans (int): Number of input image channels. Default: 3
24 | num_classes (int): Number of classes for classification head. Default: 1000
25 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
26 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
27 | drop_path_rate (float): Stochastic depth rate. Default: 0.
28 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
29 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
30 | """
31 |
32 | def __init__(self, in_chans=3, num_classes=1000,
33 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
34 | layer_scale_init_value=1e-6, head_init_scale=1., global_pool='avg',
35 | sparse=True,
36 | ):
37 | super().__init__()
38 | self.dims: List[int] = dims
39 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
40 | stem = nn.Sequential(
41 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
42 | SparseConvNeXtLayerNorm(dims[0], eps=1e-6, data_format="channels_first", sparse=sparse)
43 | )
44 | self.downsample_layers.append(stem)
45 | for i in range(3):
46 | downsample_layer = nn.Sequential(
47 | SparseConvNeXtLayerNorm(dims[i], eps=1e-6, data_format="channels_first", sparse=sparse),
48 | nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
49 | )
50 | self.downsample_layers.append(downsample_layer)
51 |
52 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
53 | self.drop_path_rate = drop_path_rate
54 | self.layer_scale_init_value = layer_scale_init_value
55 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
56 | cur = 0
57 | for i in range(4):
58 | stage = nn.Sequential(
59 | *[SparseConvNeXtBlock(dim=dims[i], drop_path=dp_rates[cur + j],
60 | layer_scale_init_value=layer_scale_init_value, sparse=sparse) for j in range(depths[i])]
61 | )
62 | self.stages.append(stage)
63 | cur += depths[i]
64 | self.depths = depths
65 |
66 | self.apply(self._init_weights)
67 | if num_classes > 0:
68 | self.norm = SparseConvNeXtLayerNorm(dims[-1], eps=1e-6, sparse=False) # final norm layer for LE/FT; should not be sparse
69 | self.fc = nn.Linear(dims[-1], num_classes)
70 | else:
71 | self.norm = nn.Identity()
72 | self.fc = nn.Identity()
73 |
74 | def _init_weights(self, m):
75 | if isinstance(m, (nn.Conv2d, nn.Linear)):
76 | trunc_normal_(m.weight, std=.02)
77 | nn.init.constant_(m.bias, 0)
78 |
79 | def get_downsample_ratio(self) -> int:
80 | return 32
81 |
82 | def get_feature_map_channels(self) -> List[int]:
83 | return self.dims
84 |
85 | def forward(self, x, hierarchical=False):
86 | if hierarchical:
87 | ls = []
88 | for i in range(4):
89 | x = self.downsample_layers[i](x)
90 | x = self.stages[i](x)
91 | ls.append(x)
92 | return ls
93 | else:
94 | return self.fc(self.norm(x.mean([-2, -1]))) # (B, C, H, W) =mean=> (B, C) =norm&fc=> (B, NumCls)
95 |
96 | def get_classifier(self):
97 | return self.fc
98 |
99 | def extra_repr(self):
100 | return f'drop_path_rate={self.drop_path_rate}, layer_scale_init_value={self.layer_scale_init_value:g}'
101 |
102 |
103 | @register_model
104 | def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
105 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
106 | return model
107 |
108 |
109 | @register_model
110 | def convnext_small(pretrained=False, in_22k=False, **kwargs):
111 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
112 | return model
113 |
114 |
115 | @register_model
116 | def convnext_base(pretrained=False, in_22k=False, **kwargs):
117 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
118 | return model
119 |
120 |
121 | @register_model
122 | def convnext_large(pretrained=False, in_22k=False, **kwargs):
123 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
124 | return model
125 |
126 |
--------------------------------------------------------------------------------
/pretrain/models/custom.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from typing import List
10 | from timm.models.registry import register_model
11 |
12 |
13 | class YourConvNet(nn.Module):
14 | """
15 | This is a template for your custom ConvNet.
16 | It is required to implement the following three functions: `get_downsample_ratio`, `get_feature_map_channels`, `forward`.
17 |     You can refer to the implementation in `pretrain/models/resnet.py` for an example.
18 | """
19 |
20 | def get_downsample_ratio(self) -> int:
21 | """
22 | This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
23 |
24 | :return: the TOTAL downsample ratio of the ConvNet.
25 | E.g., for a ResNet-50, this should return 32.
26 | """
27 | raise NotImplementedError
28 |
29 | def get_feature_map_channels(self) -> List[int]:
30 | """
31 | This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
32 |
33 | :return: a list of the number of channels of each feature map.
34 | E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
35 | """
36 | raise NotImplementedError
37 |
38 | def forward(self, inp_bchw: torch.Tensor, hierarchical=False):
39 | """
40 | The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`).
41 |
42 | :param inp_bchw: input image tensor, shape: (batch_size, channels, height, width).
43 | :param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical).
44 | :return:
45 | - hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes).
46 | - hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`.
47 | E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map].
48 | for an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)]
49 | """
50 | raise NotImplementedError
51 |
52 |
53 | @register_model
54 | def your_convnet_small(pretrained=False, **kwargs):
55 | raise NotImplementedError
56 | return YourConvNet(**kwargs)
57 |
58 |
59 | @torch.no_grad()
60 | def convnet_test():
61 | from timm.models import create_model
62 | cnn = create_model('your_convnet_small')
63 | print('get_downsample_ratio:', cnn.get_downsample_ratio())
64 | print('get_feature_map_channels:', cnn.get_feature_map_channels())
65 |
66 | downsample_ratio = cnn.get_downsample_ratio()
67 | feature_map_channels = cnn.get_feature_map_channels()
68 |
69 | # check the forward function
70 | B, C, H, W = 4, 3, 224, 224
71 | inp = torch.rand(B, C, H, W)
72 | feats = cnn(inp, hierarchical=True)
73 | assert isinstance(feats, list)
74 | assert len(feats) == len(feature_map_channels)
75 | print([tuple(t.shape) for t in feats])
76 |
77 | # check the downsample ratio
78 | feats = cnn(inp, hierarchical=True)
79 | assert feats[-1].shape[-2] == H // downsample_ratio
80 | assert feats[-1].shape[-1] == W // downsample_ratio
81 |
82 | # check the channel number
83 | for feat, ch in zip(feats, feature_map_channels):
84 | assert feat.ndim == 4
85 | assert feat.shape[1] == ch
86 |
87 |
88 | if __name__ == '__main__':
89 | convnet_test()
90 |
--------------------------------------------------------------------------------
/pretrain/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from typing import List
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from timm.models.resnet import ResNet
11 |
12 |
13 | # hack: inject the `get_downsample_ratio` function into `timm.models.resnet.ResNet`
14 | def get_downsample_ratio(self: ResNet) -> int:
15 | return 32
16 |
17 |
18 | # hack: inject the `get_feature_map_channels` function into `timm.models.resnet.ResNet`
19 | def get_feature_map_channels(self: ResNet) -> List[int]:
20 | # `self.feature_info` is maintained by `timm`
21 | return [info['num_chs'] for info in self.feature_info[1:]]
22 |
23 |
24 | # hack: override the forward function of `timm.models.resnet.ResNet`
25 | def forward(self, x, hierarchical=False):
26 | """ this forward function is a modified version of `timm.models.resnet.ResNet.forward`
27 | >>> ResNet.forward
28 | """
29 | x = self.conv1(x)
30 | x = self.bn1(x)
31 | x = self.act1(x)
32 | x = self.maxpool(x)
33 |
34 | if hierarchical:
35 | ls = []
36 | x = self.layer1(x); ls.append(x)
37 | x = self.layer2(x); ls.append(x)
38 | x = self.layer3(x); ls.append(x)
39 | x = self.layer4(x); ls.append(x)
40 | return ls
41 | else:
42 | x = self.global_pool(x)
43 | if self.drop_rate:
44 | x = F.dropout(x, p=float(self.drop_rate), training=self.training)
45 | x = self.fc(x)
46 | return x
47 |
48 |
49 | ResNet.get_downsample_ratio = get_downsample_ratio
50 | ResNet.get_feature_map_channels = get_feature_map_channels
51 | ResNet.forward = forward
52 |
53 |
54 | @torch.no_grad()
55 | def convnet_test():
56 | from timm.models import create_model
57 | cnn = create_model('resnet50')
58 | print('get_downsample_ratio:', cnn.get_downsample_ratio())
59 | print('get_feature_map_channels:', cnn.get_feature_map_channels())
60 |
61 | downsample_ratio = cnn.get_downsample_ratio()
62 | feature_map_channels = cnn.get_feature_map_channels()
63 |
64 | # check the forward function
65 | B, C, H, W = 4, 3, 224, 224
66 | inp = torch.rand(B, C, H, W)
67 | feats = cnn(inp, hierarchical=True)
68 | assert isinstance(feats, list)
69 | assert len(feats) == len(feature_map_channels)
70 | print([tuple(t.shape) for t in feats])
71 |
72 | # check the downsample ratio
73 | feats = cnn(inp, hierarchical=True)
74 | assert feats[-1].shape[-2] == H // downsample_ratio
75 | assert feats[-1].shape[-1] == W // downsample_ratio
76 |
77 | # check the channel number
78 | for feat, ch in zip(feats, feature_map_channels):
79 | assert feat.ndim == 4
80 | assert feat.shape[1] == ch
81 |
82 |
83 | if __name__ == '__main__':
84 | convnet_test()
85 |
--------------------------------------------------------------------------------
/pretrain/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | numpy
3 | Pillow
4 | typed-argument-parser
5 | timm==0.5.4
6 | tensorboardx
7 |
--------------------------------------------------------------------------------
/pretrain/sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 |
9 | import numpy as np
10 | import torch
11 | from torch.utils.data.sampler import Sampler
12 |
13 |
14 | def worker_init_fn(worker_id):
15 | # https://pytorch.org/docs/stable/notes/randomness.html#dataloader
16 | worker_seed = torch.initial_seed() % 2 ** 32
17 | np.random.seed(worker_seed)
18 | random.seed(worker_seed)
19 |
20 |
21 | class DistInfiniteBatchSampler(Sampler):
22 | def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=1, filling=False, shuffle=True):
23 | assert glb_batch_size % world_size == 0
24 | self.world_size, self.rank = world_size, rank
25 | self.dataset_len = dataset_len
26 | self.glb_batch_size = glb_batch_size
27 | self.batch_size = glb_batch_size // world_size
28 |
29 | self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
30 | self.filling = filling
31 | self.shuffle = shuffle
32 | self.epoch = 0
33 | self.seed = seed
34 | self.indices = self.gener_indices()
35 |
36 | def gener_indices(self):
37 | global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0
38 | if self.shuffle:
39 | g = torch.Generator()
40 | g.manual_seed(self.epoch + self.seed)
41 | global_indices = torch.randperm(self.dataset_len, generator=g)
42 | else:
43 | global_indices = torch.arange(self.dataset_len)
44 | filling = global_max_p - global_indices.shape[0]
45 | if filling > 0 and self.filling:
46 | global_indices = torch.cat((global_indices, global_indices[:filling]))
47 | global_indices = tuple(global_indices.numpy().tolist())
48 |
49 | seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
50 | local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
51 | self.max_p = len(local_indices)
52 | return local_indices
53 |
54 | def __iter__(self):
55 | self.epoch = 0
56 | while True:
57 | self.epoch += 1
58 | p, q = 0, 0
59 | while p < self.max_p:
60 | q = p + self.batch_size
61 | yield self.indices[p:q]
62 | p = q
63 | if self.shuffle:
64 | self.indices = self.gener_indices()
65 |
66 | def __len__(self):
67 | return self.iters_per_ep
68 |
69 |
70 | if __name__ == '__main__':
71 | W = 16
72 | for rk in range(W):
73 | ind = DistInfiniteBatchSampler(W, rk, 5024, 5024).gener_indices()
74 | print(rk, len(ind))
75 |
--------------------------------------------------------------------------------
/pretrain/spark.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from pprint import pformat
8 | from typing import List
9 |
10 | import sys
11 | import torch
12 | import torch.nn as nn
13 | from timm.models.layers import trunc_normal_
14 |
15 | import encoder
16 | from decoder import LightDecoder
17 |
18 |
19 | class SparK(nn.Module):
20 | def __init__(
21 | self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder,
22 | mask_ratio=0.6, densify_norm='bn', sbn=False,
23 | ):
24 | super().__init__()
25 | input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito
26 | self.downsample_raito = downsample_raito
27 | self.fmap_h, self.fmap_w = input_size // downsample_raito, input_size // downsample_raito
28 | self.mask_ratio = mask_ratio
29 | self.len_keep = round(self.fmap_h * self.fmap_w * (1 - mask_ratio))
30 |
31 | self.sparse_encoder = sparse_encoder
32 | self.dense_decoder = dense_decoder
33 |
34 | self.sbn = sbn
35 | self.hierarchy = len(sparse_encoder.enc_feat_map_chs)
36 | self.densify_norm_str = densify_norm.lower()
37 | self.densify_norms = nn.ModuleList()
38 | self.densify_projs = nn.ModuleList()
39 | self.mask_tokens = nn.ParameterList()
40 |
41 | # build the `densify` layers
42 | e_widths, d_width = self.sparse_encoder.enc_feat_map_chs, self.dense_decoder.width
43 | e_widths: List[int]
44 | for i in range(self.hierarchy): # from the smallest feat map to the largest; i=0: the last feat map; i=1: the second last feat map ...
45 | e_width = e_widths.pop()
46 | # create mask token
47 | p = nn.Parameter(torch.zeros(1, e_width, 1, 1))
48 | trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
49 | self.mask_tokens.append(p)
50 |
51 | # create densify norm
52 | if self.densify_norm_str == 'bn':
53 | densify_norm = (encoder.SparseSyncBatchNorm2d if self.sbn else encoder.SparseBatchNorm2d)(e_width)
54 | elif self.densify_norm_str == 'ln':
55 | densify_norm = encoder.SparseConvNeXtLayerNorm(e_width, data_format='channels_first', sparse=True)
56 | else:
57 | densify_norm = nn.Identity()
58 | self.densify_norms.append(densify_norm)
59 |
60 | # create densify proj
61 | if i == 0 and e_width == d_width:
62 | densify_proj = nn.Identity() # todo: NOTE THAT CONVNEXT-S WOULD USE THIS, because it has a width of 768 that equals to the decoder's width 768
63 | print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: use nn.Identity() as densify_proj')
64 | else:
65 | kernel_size = 1 if i <= 0 else 3
66 | densify_proj = nn.Conv2d(e_width, d_width, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, bias=True)
67 | print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: densify_proj(ksz={kernel_size}, #para={sum(x.numel() for x in densify_proj.parameters()) / 1e6:.2f}M)')
68 | self.densify_projs.append(densify_proj)
69 |
70 | # todo: the decoder's width follows a simple halfing rule; you can change it to any other rule
71 | d_width //= 2
72 |
73 | print(f'[SparK.__init__] dims of mask_tokens={tuple(p.numel() for p in self.mask_tokens)}')
74 |
75 | # these are deprecated and would never be used; can be removed.
76 | self.register_buffer('imn_m', torch.empty(1, 3, 1, 1))
77 | self.register_buffer('imn_s', torch.empty(1, 3, 1, 1))
78 | self.register_buffer('norm_black', torch.zeros(1, 3, input_size, input_size))
79 | self.vis_active = self.vis_active_ex = self.vis_inp = self.vis_inp_mask = ...
80 |
81 | def mask(self, B: int, device, generator=None):
82 | h, w = self.fmap_h, self.fmap_w
83 | idx = torch.rand(B, h * w, generator=generator).argsort(dim=1)
84 | idx = idx[:, :self.len_keep].to(device) # (B, len_keep)
85 | return torch.zeros(B, h * w, dtype=torch.bool, device=device).scatter_(dim=1, index=idx, value=True).view(B, 1, h, w)
86 |
87 | def forward(self, inp_bchw: torch.Tensor, active_b1ff=None, vis=False):
88 | # step1. Mask
89 | if active_b1ff is None: # rand mask
90 | active_b1ff: torch.BoolTensor = self.mask(inp_bchw.shape[0], inp_bchw.device) # (B, 1, f, f)
91 | encoder._cur_active = active_b1ff # (B, 1, f, f)
92 | active_b1hw = active_b1ff.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3) # (B, 1, H, W)
93 | masked_bchw = inp_bchw * active_b1hw
94 |
95 | # step2. Encode: get hierarchical encoded sparse features (a list containing 4 feature maps at 4 scales)
96 | fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw)
97 | fea_bcffs.reverse() # after reversion: from the smallest feature map to the largest
98 |
99 | # step3. Densify: get hierarchical dense features for decoding
100 | cur_active = active_b1ff # (B, 1, f, f)
101 | to_dec = []
102 | for i, bcff in enumerate(fea_bcffs): # from the smallest feature map to the largest
103 | if bcff is not None:
104 | bcff = self.densify_norms[i](bcff)
105 | mask_tokens = self.mask_tokens[i].expand_as(bcff)
106 | bcff = torch.where(cur_active.expand_as(bcff), bcff, mask_tokens) # fill in empty (non-active) positions with [mask] tokens
107 | bcff: torch.Tensor = self.densify_projs[i](bcff)
108 | to_dec.append(bcff)
109 | cur_active = cur_active.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3) # dilate the mask map, from (B, 1, f, f) to (B, 1, H, W)
110 |
111 | # step4. Decode and reconstruct
112 | rec_bchw = self.dense_decoder(to_dec)
113 | inp, rec = self.patchify(inp_bchw), self.patchify(rec_bchw) # inp and rec: (B, L = f*f, N = C*downsample_raito**2)
114 | mean = inp.mean(dim=-1, keepdim=True)
115 | var = (inp.var(dim=-1, keepdim=True) + 1e-6) ** .5
116 | inp = (inp - mean) / var
117 |         l2_loss = ((rec - inp) ** 2).mean(dim=2, keepdim=False)  # (B, L, N) ==mean==> (B, L)
118 |
119 | non_active = active_b1ff.logical_not().int().view(active_b1ff.shape[0], -1) # (B, 1, f, f) => (B, L)
120 | recon_loss = l2_loss.mul_(non_active).sum() / (non_active.sum() + 1e-8) # loss only on masked (non-active) patches
121 |
122 | if vis:
123 | masked_bchw = inp_bchw * active_b1hw
124 | rec_bchw = self.unpatchify(rec * var + mean)
125 | rec_or_inp = torch.where(active_b1hw, inp_bchw, rec_bchw)
126 | return inp_bchw, masked_bchw, rec_or_inp
127 | else:
128 | return recon_loss
129 |
130 | def patchify(self, bchw):
131 | p = self.downsample_raito
132 | h, w = self.fmap_h, self.fmap_w
133 | B, C = bchw.shape[:2]
134 | bchw = bchw.reshape(shape=(B, C, h, p, w, p))
135 | bchw = torch.einsum('bchpwq->bhwpqc', bchw)
136 | bln = bchw.reshape(shape=(B, h * w, C * p ** 2)) # (B, f*f, 3*downsample_raito**2)
137 | return bln
138 |
139 | def unpatchify(self, bln):
140 | p = self.downsample_raito
141 | h, w = self.fmap_h, self.fmap_w
142 | B, C = bln.shape[0], bln.shape[-1] // p ** 2
143 | bln = bln.reshape(shape=(B, h, w, p, p, C))
144 | bln = torch.einsum('bhwpqc->bchpwq', bln)
145 | bchw = bln.reshape(shape=(B, C, h * p, w * p))
146 | return bchw
147 |
148 | def __repr__(self):
149 | return (
150 | f'\n'
151 | f'[SparK.config]: {pformat(self.get_config(), indent=2, width=250)}\n'
152 | f'[SparK.structure]: {super(SparK, self).__repr__().replace(SparK.__name__, "")}'
153 | )
154 |
155 | def get_config(self):
156 | return {
157 | # self
158 | 'mask_ratio': self.mask_ratio,
159 | 'densify_norm_str': self.densify_norm_str,
160 | 'sbn': self.sbn, 'hierarchy': self.hierarchy,
161 |
162 | # enc
163 | 'sparse_encoder.input_size': self.sparse_encoder.input_size,
164 | # dec
165 | 'dense_decoder.width': self.dense_decoder.width,
166 | }
167 |
168 | def state_dict(self, destination=None, prefix='', keep_vars=False, with_config=False):
169 | state = super(SparK, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
170 | if with_config:
171 | state['config'] = self.get_config()
172 | return state
173 |
174 | def load_state_dict(self, state_dict, strict=True):
175 | config: dict = state_dict.pop('config', None)
176 | incompatible_keys = super(SparK, self).load_state_dict(state_dict, strict=strict)
177 | if config is not None:
178 | for k, v in self.get_config().items():
179 | ckpt_v = config.get(k, None)
180 | if ckpt_v != v:
181 | err = f'[SparK.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={ckpt_v})'
182 | if strict:
183 | raise AttributeError(err)
184 | else:
185 | print(err, file=sys.stderr)
186 | return incompatible_keys
187 |
--------------------------------------------------------------------------------
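For reference, a standalone sketch (not part of the repository) of the patchify/unpatchify round trip used by `SparK.forward` above; `p` plays the role of `downsample_raito`, and the batch/image sizes are illustrative:

```python
# Minimal round-trip check mirroring SparK.patchify / SparK.unpatchify (illustrative sizes).
import torch

B, C, H, W = 2, 3, 224, 224
p = 32                      # patch size = encoder downsample ratio
h, w = H // p, W // p       # feature-map height/width (the f x f grid)

x = torch.randn(B, C, H, W)

# patchify: (B, C, H, W) -> (B, h*w, C*p*p)
patches = torch.einsum('bchpwq->bhwpqc', x.reshape(B, C, h, p, w, p)).reshape(B, h * w, C * p * p)

# unpatchify: (B, h*w, C*p*p) -> (B, C, H, W)
x_rec = torch.einsum('bhwpqc->bchpwq', patches.reshape(B, h, w, p, p, C)).reshape(B, C, H, W)

assert torch.allclose(x, x_rec)  # the two operations are exact inverses
```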
/pretrain/utils/arg_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import json
8 | import os
9 | import sys
10 |
11 | from tap import Tap
12 |
13 | import dist
14 |
15 |
16 | class Args(Tap):
17 | # environment
18 | exp_name: str = 'your_exp_name'
19 | exp_dir: str = 'your_exp_dir' # will be created if not exists
20 | data_path: str = 'imagenet_data_path'
21 | init_weight: str = '' # use some checkpoint as model weight initialization; ONLY load model weights
22 | resume_from: str = '' # resume the experiment from some checkpoint.pth; load model weights, optimizer states, and last epoch
23 |
24 | # SparK hyperparameters
25 | mask: float = 0.6 # mask ratio, should be in (0, 1)
26 |
27 | # encoder hyperparameters
28 | model: str = 'resnet50'
29 | input_size: int = 224
30 | sbn: bool = True
31 |
32 | # data hyperparameters
33 | bs: int = 4096
34 | dataloader_workers: int = 8
35 |
36 | # pre-training hyperparameters
37 | dp: float = 0.0
38 | base_lr: float = 2e-4
39 | wd: float = 0.04
40 | wde: float = 0.2
41 | ep: int = 1600
42 | wp_ep: int = 40
43 | clip: int = 5.
44 | opt: str = 'lamb'
45 | ada: float = 0.
46 |
47 | # NO NEED TO SPECIFY; each of these args will be updated at runtime automatically
48 | lr: float = None
49 | batch_size_per_gpu: int = 0
50 | glb_batch_size: int = 0
51 | densify_norm: str = ''
52 | device: str = 'cpu'
53 | local_rank: int = 0
54 | cmd: str = ' '.join(sys.argv[1:])
55 | commit_id: str = os.popen(f'git rev-parse HEAD').read().strip() or '[unknown]'
56 | commit_msg: str = (os.popen(f'git log -1').read().strip().splitlines() or ['[unknown]'])[-1].strip()
57 | last_loss: float = 0.
58 | cur_ep: str = ''
59 | remain_time: str = ''
60 | finish_time: str = ''
61 | first_logging: bool = True
62 | log_txt_name: str = '{args.exp_dir}/pretrain_log.txt'
63 | tb_lg_dir: str = '' # tensorboard log directory
64 |
65 | @property
66 | def is_convnext(self):
67 | return 'convnext' in self.model or 'cnx' in self.model
68 |
69 | @property
70 | def is_resnet(self):
71 | return 'resnet' in self.model
72 |
73 | def log_epoch(self):
74 | if not dist.is_local_master():
75 | return
76 |
77 | if self.first_logging:
78 | self.first_logging = False
79 | with open(self.log_txt_name, 'w') as fp:
80 | json.dump({
81 | 'name': self.exp_name, 'cmd': self.cmd, 'git_commit_id': self.commit_id, 'git_commit_msg': self.commit_msg,
82 | 'model': self.model,
83 | }, fp)
84 | fp.write('\n\n')
85 |
86 | with open(self.log_txt_name, 'a') as fp:
87 | json.dump({
88 | 'cur_ep': self.cur_ep,
89 | 'last_L': self.last_loss,
90 | 'rema': self.remain_time, 'fini': self.finish_time,
91 | }, fp)
92 | fp.write('\n')
93 |
94 |
95 | def init_dist_and_get_args():
96 | from utils import misc
97 |
98 | # initialize
99 | args = Args(explicit_bool=True).parse_args()
100 | e = os.path.abspath(args.exp_dir)
101 | d, e = os.path.dirname(e), os.path.basename(e)
102 | e = ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in e)
103 | args.exp_dir = os.path.join(d, e)
104 |
105 | os.makedirs(args.exp_dir, exist_ok=True)
106 | args.log_txt_name = os.path.join(args.exp_dir, 'pretrain_log.txt')
107 | args.tb_lg_dir = args.tb_lg_dir or os.path.join(args.exp_dir, 'tensorboard_log')
108 | try:
109 | os.makedirs(args.tb_lg_dir, exist_ok=True)
110 | except:
111 | pass
112 |
113 | misc.init_distributed_environ(exp_dir=args.exp_dir)
114 |
115 | # update args
116 | if not dist.initialized():
117 | args.sbn = False
118 | args.first_logging = True
119 | args.device = dist.get_device()
120 | args.batch_size_per_gpu = args.bs // dist.get_world_size()
121 | args.glb_batch_size = args.batch_size_per_gpu * dist.get_world_size()
122 |
123 | if args.is_resnet:
124 | args.ada = args.ada or 0.95
125 | args.densify_norm = 'bn'
126 |
127 | if args.is_convnext:
128 | args.ada = args.ada or 0.999
129 | args.densify_norm = 'ln'
130 |
131 | args.opt = args.opt.lower()
132 | args.lr = args.base_lr * args.glb_batch_size / 256
133 | args.wde = args.wde or args.wd
134 |
135 | return args
136 |
--------------------------------------------------------------------------------
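For reference, a quick sanity check (not part of the repository) of the linear learning-rate scaling applied in `init_dist_and_get_args` above, using the default `base_lr` and a global batch size of 4096:

```python
# lr = base_lr * glb_batch_size / 256 (values below are the defaults; the actual
# glb_batch_size depends on the world size, since bs is split per GPU and re-multiplied).
base_lr = 2e-4
glb_batch_size = 4096
lr = base_lr * glb_batch_size / 256
print(lr)  # 0.0032
```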
/pretrain/utils/imagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | from typing import Any, Callable, Optional, Tuple
9 |
10 | import PIL.Image as PImage
11 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12 | from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
13 | from torchvision.transforms import transforms
14 | from torch.utils.data import Dataset
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | interpolation = InterpolationMode.BICUBIC
19 | except:
20 | import PIL
21 | interpolation = PIL.Image.BICUBIC
22 |
23 |
24 | def pil_loader(path):
25 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
26 | with open(path, 'rb') as f: img: PImage.Image = PImage.open(f).convert('RGB')
27 | return img
28 |
29 |
30 | class ImageNetDataset(DatasetFolder):
31 | def __init__(
32 | self,
33 | imagenet_folder: str,
34 | train: bool,
35 | transform: Callable,
36 | is_valid_file: Optional[Callable[[str], bool]] = None,
37 | ):
38 | imagenet_folder = os.path.join(imagenet_folder, 'train' if train else 'val')
39 | super(ImageNetDataset, self).__init__(
40 | imagenet_folder,
41 | loader=pil_loader,
42 | extensions=IMG_EXTENSIONS if is_valid_file is None else None,
43 | transform=transform,
44 | target_transform=None, is_valid_file=is_valid_file
45 | )
46 |
47 | self.samples = tuple(img for (img, label) in self.samples)
48 | self.targets = None # this is self-supervised learning so we don't need labels
49 |
50 | def __getitem__(self, index: int) -> Any:
51 | img_file_path = self.samples[index]
52 | return self.transform(self.loader(img_file_path))
53 |
54 |
55 | def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset:
56 | """
57 | You may need to modify this function to return your own dataset.
58 | Define a new class, a subclass of `Dataset`, to replace our ImageNetDataset.
59 | Use dataset_path to build your image file path list.
60 | Use input_size to create the transformation function for your images; you can refer to `trans_train` below.
61 |
62 | :param dataset_path: the folder of dataset
63 | :param input_size: the input size (image resolution)
64 | :return: the dataset used for pretraining
65 | """
66 | trans_train = transforms.Compose([
67 | transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
68 | transforms.RandomHorizontalFlip(),
69 | transforms.ToTensor(),
70 | transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
71 | ])
72 |
73 | dataset_path = os.path.abspath(dataset_path)
74 | for postfix in ('train', 'val'):
75 | if dataset_path.endswith(postfix):
76 | dataset_path = dataset_path[:-len(postfix)]
77 |
78 | dataset_train = ImageNetDataset(imagenet_folder=dataset_path, transform=trans_train, train=True)
79 | print_transform(trans_train, '[pre-train]')
80 | return dataset_train
81 |
82 |
83 | def print_transform(transform, s):
84 | print(f'Transform {s} = ')
85 | for t in transform.transforms:
86 | print(t)
87 | print('---------------------------\n')
88 |
--------------------------------------------------------------------------------
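As the docstring of `build_dataset_to_pretrain` suggests, a custom dataset only needs to return transformed images without labels. A minimal sketch (not part of the repository; the `*.jpg` glob pattern and folder layout are assumptions):

```python
# Hypothetical label-free dataset for pre-training on your own images.
import os
from glob import glob

import PIL.Image as PImage
from torch.utils.data import Dataset


class MyCustomDataset(Dataset):
    def __init__(self, dataset_path: str, transform):
        # collect all image files under dataset_path; no labels are needed for self-supervised pre-training
        self.files = sorted(glob(os.path.join(dataset_path, '**', '*.jpg'), recursive=True))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        with open(self.files[index], 'rb') as f:
            img = PImage.open(f).convert('RGB')
        return self.transform(img)
```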
/pretrain/utils/lamb.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # This file is basically a copy to: https://github.com/rwightman/pytorch-image-models/blob/v0.5.4/timm/optim/lamb.py
8 | # **The only modification** is adding the `global_grad_norm` member for debugging
9 |
10 |
11 | """ PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
12 | This optimizer code was adapted from the following (starting with latest)
13 | * https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
14 | * https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
15 | * https://github.com/cybertronai/pytorch-lamb
16 | Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
17 | similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
18 | In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
19 | Original copyrights for above sources are below.
20 | Modifications Copyright 2021 Ross Wightman
21 | """
22 | import math
23 |
24 | import torch
25 | from torch.optim.optimizer import Optimizer
26 |
27 |
28 | class TheSameAsTimmLAMB(Optimizer):
29 | """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
30 | reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
31 |
32 | LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
33 |
34 | Arguments:
35 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
36 | lr (float, optional): learning rate. (default: 1e-3)
37 | betas (Tuple[float, float], optional): coefficients used for computing
38 | running averages of gradient and its norm. (default: (0.9, 0.999))
39 | eps (float, optional): term added to the denominator to improve
40 | numerical stability. (default: 1e-8)
41 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
42 | grad_averaging (bool, optional): whether apply (1-beta2) to grad when
43 | calculating running averages of gradient. (default: True)
44 | max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
45 | trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
46 | always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
47 | weight decay parameter (default: False)
48 |
49 | .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
50 | https://arxiv.org/abs/1904.00962
51 | .. _On the Convergence of Adam and Beyond:
52 | https://openreview.net/forum?id=ryQu7f-RZ
53 | """
54 |
55 | def __init__(
56 | self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
57 | weight_decay=0.01, grad_averaging=True, max_grad_norm=2.0, trust_clip=False, always_adapt=False):
58 | defaults = dict(
59 | lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
60 | grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
61 | trust_clip=trust_clip, always_adapt=always_adapt)
62 | super().__init__(params, defaults)
63 | print(f'[lamb1] max_grad_norm={max_grad_norm}')
64 | self.global_grad_norm = 0
65 |
66 | @torch.no_grad()
67 | def step(self, closure=None):
68 | """Performs a single optimization step.
69 | Arguments:
70 | closure (callable, optional): A closure that reevaluates the model
71 | and returns the loss.
72 | """
73 | loss = None
74 | if closure is not None:
75 | with torch.enable_grad():
76 | loss = closure()
77 |
78 | device = self.param_groups[0]['params'][0].device
79 | one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
80 | global_grad_norm = torch.zeros(1, device=device)
81 | for group in self.param_groups:
82 | for p in group['params']:
83 | if p.grad is None:
84 | continue
85 | grad = p.grad
86 | if grad.is_sparse:
87 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
88 | global_grad_norm.add_(grad.pow(2).sum())
89 |
90 | global_grad_norm = torch.sqrt(global_grad_norm)
91 | self.global_grad_norm = global_grad_norm.item()
92 | max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
93 | clip_global_grad_norm = 1 / torch.where(
94 | global_grad_norm > max_grad_norm,
95 | global_grad_norm / max_grad_norm,
96 | one_tensor)
97 |
98 | for group in self.param_groups:
99 | bias_correction = 1 if group['bias_correction'] else 0
100 | beta1, beta2 = group['betas']
101 | grad_averaging = 1 if group['grad_averaging'] else 0
102 | beta3 = 1 - beta1 if grad_averaging else 1.0
103 |
104 | # assume same step across group now to simplify things
105 | # per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
106 | if 'step' in group:
107 | group['step'] += 1
108 | else:
109 | group['step'] = 1
110 |
111 | if bias_correction:
112 | bias_correction1 = 1 - beta1 ** group['step']
113 | bias_correction2 = 1 - beta2 ** group['step']
114 | else:
115 | bias_correction1, bias_correction2 = 1.0, 1.0
116 |
117 | for p in group['params']:
118 | if p.grad is None:
119 | continue
120 | grad = p.grad.mul_(clip_global_grad_norm)
121 | state = self.state[p]
122 |
123 | # State initialization
124 | if len(state) == 0:
125 | # Exponential moving average of gradient values
126 | state['exp_avg'] = torch.zeros_like(p)
127 | # Exponential moving average of squared gradient values
128 | state['exp_avg_sq'] = torch.zeros_like(p)
129 |
130 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
131 |
132 | # Decay the first and second moment running average coefficient
133 | exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
134 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t
135 |
136 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
137 | update = (exp_avg / bias_correction1).div_(denom)
138 |
139 | weight_decay = group['weight_decay']
140 | if weight_decay != 0:
141 | update.add_(p, alpha=weight_decay)
142 |
143 | if weight_decay != 0 or group['always_adapt']:
144 | # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
145 | # excluded from weight decay, unless always_adapt == True, then always enabled.
146 | w_norm = p.norm(2.0)
147 | g_norm = update.norm(2.0)
148 | # FIXME nested where required since logical and/or not working in PT XLA
149 | trust_ratio = torch.where(
150 | w_norm > 0,
151 | torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
152 | one_tensor,
153 | )
154 | if group['trust_clip']:
155 | # LAMBC trust clipping, upper bound fixed at one
156 | trust_ratio = torch.minimum(trust_ratio, one_tensor)
157 | update.mul_(trust_ratio)
158 |
159 | p.add_(update, alpha=-group['lr'])
160 |
161 | return loss
162 |
--------------------------------------------------------------------------------
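A toy sketch (not part of the repository) of constructing and stepping the optimizer defined above on a throwaway model; the import path assumes the working directory is `pretrain/`:

```python
import torch
from torch import nn

from utils.lamb import TheSameAsTimmLAMB  # assumes cwd is pretrain/

model = nn.Linear(8, 2)
opt = TheSameAsTimmLAMB(model.parameters(), lr=2e-3, weight_decay=0.04, max_grad_norm=5.0)

x, y = torch.randn(4, 8), torch.randn(4, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
print(f'global grad norm seen by LAMB: {opt.global_grad_norm:.4f}')  # the debugging member added in this file
opt.zero_grad()
```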
/pretrain/utils/lr_control.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from pprint import pformat
9 |
10 |
11 | def lr_wd_annealing(optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it):
12 | wp_it = round(wp_it)
13 | if cur_it < wp_it:
14 | cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it
15 | else:
16 | ratio = (cur_it - wp_it) / (max_it - 1 - wp_it)
17 | cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * ratio))
18 |
19 | ratio = cur_it / (max_it - 1)
20 | cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * ratio))
21 |
22 | min_lr, max_lr = cur_lr, cur_lr
23 | min_wd, max_wd = cur_wd, cur_wd
24 | for param_group in optimizer.param_groups:
25 | scaled_lr = param_group['lr'] = cur_lr * param_group.get('lr_scale', 1) # 'lr_scale' could be assigned
26 | min_lr, max_lr = min(min_lr, scaled_lr), max(max_lr, scaled_lr)
27 | scaled_wd = param_group['weight_decay'] = cur_wd * param_group.get('weight_decay_scale', 1) # 'weight_decay_scale' could be assigned
28 | min_wd, max_wd = min(min_wd, scaled_wd), max(max_wd, scaled_wd)
29 | return min_lr, max_lr, min_wd, max_wd
30 |
31 |
32 | def get_param_groups(model, nowd_keys=()):
33 | para_groups, para_groups_dbg = {}, {}
34 |
35 | for name, para in model.named_parameters():
36 | if not para.requires_grad:
37 | continue # frozen weights
38 | if len(para.shape) == 1 or name.endswith('.bias') or any(k in name for k in nowd_keys):
39 | wd_scale, group_name = 0., 'no_decay'
40 | else:
41 | wd_scale, group_name = 1., 'decay'
42 |
43 | if group_name not in para_groups:
44 | para_groups[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': 1.}
45 | para_groups_dbg[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': 1.}
46 | para_groups[group_name]['params'].append(para)
47 | para_groups_dbg[group_name]['params'].append(name)
48 |
49 | for g in para_groups_dbg.values():
50 | g['params'] = pformat(', '.join(g['params']), width=200)
51 |
52 | print(f'[get_param_groups] param groups = \n{pformat(para_groups_dbg, indent=2, width=250)}\n')
53 | return list(para_groups.values())
54 |
--------------------------------------------------------------------------------
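A toy sketch (not part of the repository) that prints the warmup-then-cosine schedule produced by `lr_wd_annealing` above; the optimizer and the numbers are illustrative only, and the import path assumes the working directory is `pretrain/`:

```python
import torch

from utils.lr_control import lr_wd_annealing  # assumes cwd is pretrain/

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.1)  # lr/weight_decay are overwritten on every call

max_it, wp_it = 100, 10
for it in (0, 5, 10, 50, 99):
    min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(
        opt, peak_lr=1e-3, wd=0.04, wd_end=0.2, cur_it=it, wp_it=wp_it, max_it=max_it)
    print(f'it={it:3d}  lr={max_lr:.2e}  wd={max_wd:.3f}')
```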
/pretrain/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import datetime
8 | import functools
9 | import os
10 | import subprocess
11 | import sys
12 | import time
13 | from collections import defaultdict, deque
14 | from typing import Iterator
15 |
16 | import numpy as np
17 | import pytz
18 | import torch
19 | from torch.utils.tensorboard import SummaryWriter
20 |
21 | import dist
22 |
23 | os_system = functools.partial(subprocess.call, shell=True)
24 | os_system_get_stdout = lambda cmd: subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
25 | def os_system_get_stdout_stderr(cmd):
26 | sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
27 | return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
28 |
29 |
30 | def is_pow2n(x):
31 | return x > 0 and (x & (x - 1) == 0)
32 |
33 |
34 | def time_str(for_dirname=False):
35 | return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('%m-%d_%H-%M-%S' if for_dirname else '[%m-%d %H:%M:%S]')
36 |
37 |
38 | def init_distributed_environ(exp_dir):
39 | dist.initialize()
40 | dist.barrier()
41 |
42 | import torch.backends.cudnn as cudnn
43 | cudnn.benchmark = True
44 | cudnn.deterministic = False
45 |
46 | _set_print_only_on_master_proc(is_master=dist.is_local_master())
47 | if dist.is_local_master() and len(exp_dir):
48 | sys.stdout, sys.stderr = _SyncPrintToFile(exp_dir, stdout=True), _SyncPrintToFile(exp_dir, stdout=False)
49 |
50 |
51 | def _set_print_only_on_master_proc(is_master):
52 | import builtins as __builtin__
53 |
54 | builtin_print = __builtin__.print
55 |
56 | def prt(msg, *args, **kwargs):
57 | force = kwargs.pop('force', False)
58 | clean = kwargs.pop('clean', False)
59 | deeper = kwargs.pop('deeper', False)
60 | if is_master or force:
61 | if not clean:
62 | f_back = sys._getframe().f_back
63 | if deeper and f_back.f_back is not None:
64 | f_back = f_back.f_back
65 | file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
66 | msg = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}'
67 | builtin_print(msg, *args, **kwargs)
68 |
69 | __builtin__.print = prt
70 |
71 |
72 | class _SyncPrintToFile(object):
73 | def __init__(self, exp_dir, stdout=True):
74 | self.terminal = sys.stdout if stdout else sys.stderr
75 | fname = os.path.join(exp_dir, 'stdout_backup.txt' if stdout else 'stderr_backup.txt')
76 | self.log = open(fname, 'w')
77 | self.log.flush()
78 |
79 | def write(self, message):
80 | self.terminal.write(message)
81 | self.log.write(message)
82 | self.log.flush()
83 |
84 | def flush(self):
85 | self.terminal.flush()
86 | self.log.flush()
87 |
88 |
89 | class TensorboardLogger(object):
90 | def __init__(self, log_dir, is_master, prefix='pt'):
91 | self.is_master = is_master
92 | self.writer = SummaryWriter(log_dir=log_dir) if self.is_master else None
93 | self.step = 0
94 | self.prefix = prefix
95 | self.log_freq = 300
96 |
97 | def set_step(self, step=None):
98 | if step is not None:
99 | self.step = step
100 | else:
101 | self.step += 1
102 |
103 | def get_loggable(self, step=None):
104 | if step is None: # iter wise
105 | step = self.step
106 | loggable = step % self.log_freq == 0
107 | else: # epoch wise
108 | loggable = True
109 | return step, (loggable and self.is_master)
110 |
111 | def update(self, head='scalar', step=None, **kwargs):
112 | step, loggable = self.get_loggable(step)
113 | if loggable:
114 | head = f'{self.prefix}_{head}'
115 | for k, v in kwargs.items():
116 | if v is None:
117 | continue
118 | if isinstance(v, torch.Tensor):
119 | v = v.item()
120 | assert isinstance(v, (float, int))
121 | self.writer.add_scalar(head + "/" + k, v, step)
122 |
123 | def log_distribution(self, tag, values, step=None):
124 | step, loggable = self.get_loggable(step)
125 | if loggable:
126 | if not isinstance(values, torch.Tensor):
127 | values = torch.tensor(values)
128 | self.writer.add_histogram(tag=tag, values=values, global_step=step)
129 |
130 | def log_image(self, tag, img, step=None, dataformats='NCHW'):
131 | step, loggable = self.get_loggable(step)
132 | if loggable:
133 | # img = img.cpu().numpy()
134 | self.writer.add_image(tag, img, step, dataformats=dataformats)
135 |
136 | def flush(self):
137 | if self.is_master: self.writer.flush()
138 |
139 | def close(self):
140 | if self.is_master: self.writer.close()
141 |
142 |
143 | def save_checkpoint_with_meta_info_and_opt_state(save_to, args, epoch, performance_desc, model_without_ddp_state, optimizer_state):
144 | checkpoint_path = os.path.join(args.exp_dir, save_to)
145 | if dist.is_local_master():
146 | to_save = {
147 | 'args': str(args),
148 | 'input_size': args.input_size,
149 | 'arch': args.model,
150 | 'epoch': epoch,
151 | 'performance_desc': performance_desc,
152 | 'module': model_without_ddp_state,
153 | 'optimizer': optimizer_state,
154 | 'is_pretrain': True,
155 | }
156 | torch.save(to_save, checkpoint_path)
157 |
158 |
159 | def save_checkpoint_model_weights_only(save_to, args, sp_cnn_state):
160 | checkpoint_path = os.path.join(args.exp_dir, save_to)
161 | if dist.is_local_master():
162 | torch.save(sp_cnn_state, checkpoint_path)
163 |
164 |
165 | def initialize_weight(init_weight: str, model_without_ddp):
166 | # use some checkpoint as model weight initialization; ONLY load model weights
167 | if len(init_weight):
168 | checkpoint = torch.load(init_weight, 'cpu')
169 | missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
170 | print(f'[initialize_weight] missing_keys={missing}')
171 | print(f'[initialize_weight] unexpected_keys={unexpected}')
172 |
173 |
174 | def load_checkpoint(resume_from: str, model_without_ddp, optimizer):
175 | # resume the experiment from some checkpoint.pth; load model weights, optimizer states, and last epoch
176 | if len(resume_from) == 0:
177 | return 0, '[no performance_desc]'
178 | print(f'[try to resume from file `{resume_from}`]')
179 | checkpoint = torch.load(resume_from, map_location='cpu')
180 |
181 | ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc', '[no performance_desc]')
182 | missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
183 | print(f'[load_checkpoint] missing_keys={missing}')
184 | print(f'[load_checkpoint] unexpected_keys={unexpected}')
185 | print(f'[load_checkpoint] ep_start={ep_start}, performance_desc={performance_desc}')
186 |
187 | if 'optimizer' in checkpoint:
188 | optimizer.load_state_dict(checkpoint['optimizer'])
189 | return ep_start, performance_desc
190 |
191 |
192 | class SmoothedValue(object):
193 | """Track a series of values and provide access to smoothed values over a
194 | window or the global series average.
195 | """
196 |
197 | def __init__(self, window_size=20, fmt=None):
198 | if fmt is None:
199 | fmt = "{median:.4f} ({global_avg:.4f})"
200 | self.deque = deque(maxlen=window_size)
201 | self.total = 0.0
202 | self.count = 0
203 | self.fmt = fmt
204 |
205 | def update(self, value, n=1):
206 | self.deque.append(value)
207 | self.count += n
208 | self.total += value * n
209 |
210 | def synchronize_between_processes(self):
211 | """
212 | Warning: does not synchronize the deque!
213 | """
214 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
215 | dist.barrier()
216 | dist.allreduce(t)
217 | t = t.tolist()
218 | self.count = int(t[0])
219 | self.total = t[1]
220 |
221 | @property
222 | def median(self):
223 | d = torch.tensor(list(self.deque))
224 | return d.median().item()
225 |
226 | @property
227 | def avg(self):
228 | d = torch.tensor(list(self.deque), dtype=torch.float32)
229 | return d.mean().item()
230 |
231 | @property
232 | def global_avg(self):
233 | return self.total / self.count
234 |
235 | @property
236 | def max(self):
237 | return max(self.deque)
238 |
239 | @property
240 | def value(self):
241 | return self.deque[-1]
242 |
243 | def __str__(self):
244 | return self.fmt.format(
245 | median=self.median,
246 | avg=self.avg,
247 | global_avg=self.global_avg,
248 | max=self.max,
249 | value=self.value)
250 |
251 |
252 | class MetricLogger(object):
253 | def __init__(self, delimiter="\t"):
254 | self.meters = defaultdict(SmoothedValue)
255 | self.delimiter = delimiter
256 |
257 | def update(self, **kwargs):
258 | for k, v in kwargs.items():
259 | if v is None:
260 | continue
261 | if isinstance(v, torch.Tensor):
262 | v = v.item()
263 | assert isinstance(v, (float, int))
264 | self.meters[k].update(v)
265 |
266 | def __getattr__(self, attr):
267 | if attr in self.meters:
268 | return self.meters[attr]
269 | if attr in self.__dict__:
270 | return self.__dict__[attr]
271 | raise AttributeError("'{}' object has no attribute '{}'".format(
272 | type(self).__name__, attr))
273 |
274 | def __str__(self):
275 | loss_str = []
276 | for name, meter in self.meters.items():
277 | loss_str.append(
278 | "{}: {}".format(name, str(meter))
279 | )
280 | return self.delimiter.join(loss_str)
281 |
282 | def synchronize_between_processes(self):
283 | for meter in self.meters.values():
284 | meter.synchronize_between_processes()
285 |
286 | def add_meter(self, name, meter):
287 | self.meters[name] = meter
288 |
289 | def log_every(self, max_iters, itrt, print_freq, header=None):
290 | print_iters = set(np.linspace(0, max_iters - 1, print_freq, dtype=int).tolist())
291 | if not header:
292 | header = ''
293 | start_time = time.time()
294 | end = time.time()
295 | self.iter_time = SmoothedValue(fmt='{avg:.4f}')
296 | self.data_time = SmoothedValue(fmt='{avg:.4f}')
297 | space_fmt = ':' + str(len(str(max_iters))) + 'd'
298 | log_msg = [
299 | header,
300 | '[{0' + space_fmt + '}/{1}]',
301 | 'eta: {eta}',
302 | '{meters}',
303 | 'iter: {time}s',
304 | 'data: {data}s'
305 | ]
306 | log_msg = self.delimiter.join(log_msg)
307 |
308 | if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
309 | for i in range(max_iters):
310 | obj = next(itrt)
311 | self.data_time.update(time.time() - end)
312 | yield obj
313 | self.iter_time.update(time.time() - end)
314 | if i in print_iters:
315 | eta_seconds = self.iter_time.global_avg * (max_iters - i)
316 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
317 | print(log_msg.format(
318 | i, max_iters, eta=eta_string,
319 | meters=str(self),
320 | time=str(self.iter_time), data=str(self.data_time)))
321 | end = time.time()
322 | else:
323 | for i, obj in enumerate(itrt):
324 | self.data_time.update(time.time() - end)
325 | yield obj
326 | self.iter_time.update(time.time() - end)
327 | if i in print_iters:
328 | eta_seconds = self.iter_time.global_avg * (max_iters - i)
329 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
330 | print(log_msg.format(
331 | i, max_iters, eta=eta_string,
332 | meters=str(self),
333 | time=str(self.iter_time), data=str(self.data_time)))
334 | end = time.time()
335 |
336 | total_time = time.time() - start_time
337 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
338 | print('{} Total time: {} ({:.3f} s / it)'.format(
339 | header, total_time_str, total_time / max_iters))
340 |
--------------------------------------------------------------------------------
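A toy sketch (not part of the repository) of `MetricLogger.log_every` driving a loop, with a plain Python list standing in for the real data loader; the import path assumes the working directory is `pretrain/` and the packages in requirements.txt are installed:

```python
import torch

from utils.misc import MetricLogger  # assumes cwd is pretrain/

logger = MetricLogger(delimiter='  ')
fake_loader = [torch.randn(4, 3, 32, 32) for _ in range(8)]  # stands in for a DataLoader

for batch in logger.log_every(max_iters=len(fake_loader), itrt=fake_loader, print_freq=4, header='[toy]'):
    loss = batch.float().mean()  # stand-in for a real forward pass
    logger.update(loss=loss)

print('smoothed meters:', str(logger))
```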
/pretrain/viz_imgs/recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keyu-tian/SparK/a63e386f8e5186bc07ad7fce86e06b08f48a61ea/pretrain/viz_imgs/recon.png
--------------------------------------------------------------------------------
/pretrain/viz_imgs/spconv1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keyu-tian/SparK/a63e386f8e5186bc07ad7fce86e06b08f48a61ea/pretrain/viz_imgs/spconv1.png
--------------------------------------------------------------------------------
/pretrain/viz_imgs/spconv2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keyu-tian/SparK/a63e386f8e5186bc07ad7fce86e06b08f48a61ea/pretrain/viz_imgs/spconv2.png
--------------------------------------------------------------------------------
/pretrain/viz_imgs/spconv3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keyu-tian/SparK/a63e386f8e5186bc07ad7fce86e06b08f48a61ea/pretrain/viz_imgs/spconv3.png
--------------------------------------------------------------------------------