├── .gitignore
├── LICENSE
├── README.md
├── cal_inference_time.py
├── datasets
├── __init__.py
├── mydataset.py
├── split_data.py
└── threeaugment.py
├── environment.yml
├── estimate_model.py
├── models
├── __init__.py
├── blocks.py
├── build_mobilenet_v4.py
├── extra_attention_block.py
└── model_utils.py
├── onnx_export.py
├── onnx_optimise.py
├── onnx_validate.py
├── optim_AUC.py
├── prediction_probs.png
├── sample_png
└── mobilenetV4.jpg
├── train_gpu.py
├── util
├── __init__.py
├── engine.py
├── losses.py
├── optimizer.py
├── samplers.py
└── utils.py
├── visualize.py
└── weight_converter.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea/
6 |
7 | # ckpt files
8 | *.safetensors
9 | *.bin
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/#use-with-ide
115 | .pdm.toml
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | #.idea/
166 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 奔波儿灞
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
MobileNetV4
2 |
3 | # [MobileNetV4 -- Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518)
4 | ## This project is implemented in PyTorch, can be used to train your image-datasets for vision tasks.
5 | ## [official source code](https://github.com/tensorflow/models/blob/master/official/vision/modeling/backbones/mobilenet.py)
6 | ## For segmentation tasks, please refer this [github warehouse](https://github.com/jiaowoguanren0615/Segmentation_Factory/blob/main/models/backbones/mobilenetv4.py)
7 | ## For detection tasks(___Based on DETR Detector architecture___), please refer this [github warehouse](https://github.com/jiaowoguanren0615/Detection-Factory/blob/main/configs/salience_detr_mobilenetv4_medium_800_1333.py)
8 | 
9 |
10 |
11 |
12 | ## Preparation
13 |
14 | ### Create conda virtual-environment
15 | ```bash
16 | conda env create -f environment.yml
17 | ```
18 |
19 | ### Download the dataset:
20 | [flower_dataset](https://www.kaggle.com/datasets/alxmamaev/flowers-recognition).
21 |
22 | ## Project Structure
23 | ```
24 | ├── datasets: Load datasets
25 | ├── my_dataset.py: Customize reading data sets and define transforms data enhancement methods
26 | ├── split_data.py: Define the function to read the image dataset and divide the training-set and test-set
27 | ├── threeaugment.py: Additional data augmentation methods
28 | ├── models: MobileNetV4 Model
29 | ├── build_mobilenet_v4.py: Construct MobileNetV4 models
30 | ├── extra_attention_block.py: MultiScaleAttentionGate module
31 | ├── util:
32 | ├── engine.py: Function code for a training/validation process
33 | ├── losses.py: Knowledge distillation loss, combined with teacher model (if any)
34 | ├── optimizer.py: Define Sophia/MARS optimizer
35 | ├── samplers.py: Define the parameter of "sampler" in DataLoader
36 | ├── utils.py: Record various indicator information and output and distributed environment
37 | ├── estimate_model.py: Visualized evaluation indicators ROC curve, confusion matrix, classification report, etc.
38 | └── train_gpu.py: Training model startup file (including infer process)
39 | ```
40 |
41 | ## Precautions
42 | Before you use the code to train your own data set, please first enter the ___train_gpu.py___ file and modify the ___data_root___, ___batch_size___, ___num_workers___ and ___nb_classes___ parameters. If you want to draw the confusion matrix and ROC curve, you only need to set the ___predict___ parameter to __True__.
43 | If you want to add an extra MSAG(MultiScaleAttentionGate) module, set the __extra_attention_block__ parameter to True.
44 | Moreover, you can set the ___opt_auc___ parameter to True if you want to optimize your model for a better performance(maybe~).
45 |
46 | ## Use Sophia Optimizer (in util/optimizer.py)
47 | You can use anther optimizer sophia, just need to change the optimizer in ___train_gpu.py___, for this training sample, can achieve better results
48 | ```
49 | # optimizer = create_optimizer(args, model_without_ddp)
50 | optimizer = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99), rho=0.01, weight_decay=args.weight_decay)
51 | ```
52 |
53 | ## Train this model
54 |
55 | ### Parameters Meaning:
56 | ```
57 | 1. nproc_per_node:
58 | 2. CUDA_VISIBLE_DEVICES:
59 | 3. nnodes:
60 | 4. node_rank:
61 | 5. master_addr:
62 | 6. master_port:
63 | ```
64 | ### Transfer Learning:
65 | Step 1: Download the [pretrained-weights](https://huggingface.co/timm/mobilenetv4_conv_large.e500_r256_in1k#model-comparison)
66 | Step 2: Write the ___pre-training weight path___ into the ___args.finetune___ in string format. Adjust ___args.input_size___ parameter based on the model pre-trained on images of different sizes.
67 | Step 3: Modify the ___args.freeze_layers___ according to your own GPU memory. If you don't have enough memory, you can set this to True to freeze the weights of the remaining layers except the last layer of classification-head without updating the parameters. If you have enough memory, you can set this to False and not freeze the model weights.
68 |
69 | #### Here is an example for setting parameters:
70 | 
71 |
72 | ### Note:
73 | If you want to use multiple GPU for training, whether it is a single machine with multiple GPUs or multiple machines with multiple GPUs, each GPU will divide the batch_size equally. For example, batch_size=4 in my train_gpu.py. If I want to use 2 GPUs for training, it means that the batch_size on each GPU is 4. ___Do not let batch_size=1 on each GPU___, otherwise BN layer maybe report an error.
74 |
75 | ### train model with single-machine single-GPU:
76 | ```
77 | python train_gpu.py
78 | ```
79 |
80 | ### train model with single-machine multi-GPU:
81 | ```
82 | python -m torch.distributed.run --nproc_per_node=8 train_gpu.py
83 | ```
84 |
85 | ### train model with single-machine multi-GPU:
86 | (using a specified part of the GPUs: for example, I want to use the second and fourth GPUs)
87 | ```
88 | CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.run --nproc_per_node=2 train_gpu.py
89 | ```
90 |
91 | ### train model with multi-machine multi-GPU:
92 | (For the specific number of GPUs on each machine, modify the value of --nproc_per_node. If you want to specify a certain GPU, just add CUDA_VISIBLE_DEVICES= to specify the index number of the GPU before each command. The principle is the same as single-machine multi-GPU training)
93 | ```
94 | On the first machine: python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 --node_rank=0 --master_addr= --master_port= train_gpu.py
95 |
96 | On the second machine: python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 --node_rank=1 --master_addr= --master_port= train_gpu.py
97 | ```
98 |
99 | ## ONNX Deployment
100 | ### step 1: ONNX export (modify the param of ___output___, ___model___ and ___checkpoint___)
101 | ```bash
102 | python onnx_export.py --model=mobilenetv4_small --output=./mobilenetv4_small.onnx --checkpoint=./output/mobilenetv4_small_best_checkpoint.pth
103 | ```
104 |
105 | ### step2: ONNX optimise
106 | ```bash
107 | python onnx_optimise.py --model=mobilenetv4_small --output=./mobilenetv4_small_optim.onnx'
108 | ```
109 |
110 | ### step3: ONNX validate (modify the param of ___data_root___ and ___onnx-input___)
111 | ```bash
112 | python onnx_validate.py --data_root=/mnt/d/flower_data --onnx-input=./mobilenetv4_small_optim.onnx
113 | ```
114 |
115 |
116 | ## Citation
117 | ```
118 | @article{qin2024mobilenetv4,
119 | title={MobileNetV4-Universal Models for the Mobile Ecosystem},
120 | author={Qin, Danfeng and Leichner, Chas and Delakis, Manolis and Fornoni, Marco and Luo, Shixin and Yang, Fan and Wang, Weijun and Banbury, Colby and Ye, Chengxi and Akin, Berkin and others},
121 | journal={arXiv preprint arXiv:2404.10518},
122 | year={2024}
123 | }
124 | ```
125 |
126 | ## Star History
127 |
128 | [](https://star-history.com/#jiaowoguanren0615/MobileNetV4&Date)
129 |
--------------------------------------------------------------------------------
/cal_inference_time.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn as nn
4 | import argparse
5 | import models
6 | import numpy as np
7 | from timm.models import create_model
8 |
9 |
10 | parser = argparse.ArgumentParser(description='PyTorch MobileNetV4 Inference Speed Test')
11 | # Model params
12 | parser.add_argument('--model', default='mobilenetv4_conv_large', type=str, metavar='MODEL',
13 | choices=['mobilenetv4_hybrid_large', 'mobilenetv4_hybrid_medium', 'mobilenetv4_hybrid_large_075',
14 | 'mobilenetv4_conv_large', 'mobilenetv4_conv_aa_large', 'mobilenetv4_conv_medium',
15 | 'mobilenetv4_conv_aa_medium', 'mobilenetv4_conv_small', 'mobilenetv4_hybrid_medium_075',
16 | 'mobilenetv4_conv_small_035', 'mobilenetv4_conv_small_050', 'mobilenetv4_conv_blur_medium'],
17 | help='Name of model to train')
18 | parser.add_argument('--device', default='cuda', type=str)
19 | parser.add_argument('--batch-size', default=32, type=int, help='batch size (default: 32)')
20 | parser.add_argument('--img-size', default=224, type=int,
21 | metavar='N', help='Input image dimension, uses model default if empty')
22 | parser.add_argument('--nb-classes', type=int, default=5,
23 | help='Number classes in datasets')
24 |
25 |
26 | def do_pure_cpu_task():
27 | x = np.random.randn(1, 3, 512, 512).astype(np.float32)
28 | x = x * 1024 ** 0.5
29 |
30 |
31 | @torch.inference_mode()
32 | def cal_time3(model, x, args):
33 | start_event = torch.cuda.Event(enable_timing=True)
34 | end_event = torch.cuda.Event(enable_timing=True)
35 | time_list = []
36 | for _ in range(50):
37 | # do_pure_cpu_task() ## cpu warm up, not necessary
38 | start_event.record()
39 | ret = model(x)
40 | end_event.record()
41 | end_event.synchronize()
42 | time_list.append(start_event.elapsed_time(end_event) / 1000)
43 |
44 | print(f"{args.model} inference avg time: {sum(time_list[5:]) / len(time_list[5:]):.5f}") ## warm up, remove start 5 times
45 |
46 |
47 | def main(args):
48 |
49 | device = args.device
50 | model = create_model(
51 | args.model,
52 | num_classes=args.nb_classes
53 | )
54 | model.eval().to(device)
55 |
56 | x = torch.randn(size=(args.batch_size, 3, args.img_size, args.img_size), device=device)
57 | cal_time3(model, x, args)
58 |
59 |
60 | if __name__ == '__main__':
61 | args = parser.parse_args()
62 | main(args)
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .mydataset import build_dataset, build_transform, MyDataset
2 | from .split_data import read_split_data
3 | from .threeaugment import new_data_aug_generator
--------------------------------------------------------------------------------
/datasets/mydataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from torchvision import transforms
4 | from .split_data import read_split_data
5 | from torch.utils.data import Dataset
6 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform
7 |
8 |
9 | class MyDataset(Dataset):
10 | def __init__(self, image_paths, image_labels, transforms=None):
11 | self.image_paths = image_paths
12 | self.image_labels = image_labels
13 | self.transforms = transforms
14 |
15 | def __getitem__(self, item):
16 | image = Image.open(self.image_paths[item]).convert('RGB')
17 | label = self.image_labels[item]
18 | if self.transforms:
19 | image = self.transforms(image)
20 | return image, label
21 |
22 | def __len__(self):
23 | return len(self.image_paths)
24 |
25 | @staticmethod
26 | def collate_fn(batch):
27 | images, labels = tuple(zip(*batch))
28 | images = torch.stack(images, dim=0)
29 | labels = torch.as_tensor(labels)
30 | return images, labels
31 |
32 |
33 |
34 | def build_transform(is_train, args):
35 | resize_im = args.input_size > 32
36 | if is_train:
37 | # this should always dispatch to transforms_imagenet_train
38 | transform = create_transform(
39 | input_size=args.input_size,
40 | is_training=True,
41 | color_jitter=args.color_jitter,
42 | auto_augment=args.aa,
43 | interpolation=args.train_interpolation,
44 | re_prob=args.reprob,
45 | re_mode=args.remode,
46 | re_count=args.recount,
47 | )
48 | if not resize_im:
49 | # replace RandomResizedCropAndInterpolation with
50 | # RandomCrop
51 | transform.transforms[0] = transforms.RandomCrop(
52 | args.input_size, padding=4)
53 | return transform
54 |
55 | t = []
56 | if resize_im:
57 | # size = int((256 / 224) * args.input_size)
58 | size = int((1.0 / 0.96) * args.input_size)
59 | t.append(
60 | # to maintain same ratio w.r.t. 224 images
61 | transforms.Resize(size, interpolation=3),
62 | )
63 | t.append(transforms.CenterCrop(args.input_size))
64 |
65 | t.append(transforms.ToTensor())
66 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
67 | return transforms.Compose(t)
68 |
69 |
70 | def build_dataset(args):
71 | train_image_path, train_image_label, val_image_path, val_image_label, class_indices = read_split_data(args.data_root)
72 |
73 | train_transform = build_transform(True, args)
74 | valid_transform = build_transform(False, args)
75 |
76 | train_set = MyDataset(train_image_path, train_image_label, train_transform)
77 | valid_set = MyDataset(val_image_path, val_image_label, valid_transform)
78 |
79 | return train_set, valid_set
80 |
81 |
--------------------------------------------------------------------------------
/datasets/split_data.py:
--------------------------------------------------------------------------------
1 | import os, cv2, json, random
2 | import pandas as pd
3 | from tqdm import tqdm
4 | from sklearn.model_selection import train_test_split
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | def read_split_data(root, plot_image=False):
9 | filepaths = []
10 | labels = []
11 | bad_images = []
12 |
13 | random.seed(0)
14 | assert os.path.exists(root), 'Your root does not exists!!!'
15 |
16 | classes = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
17 | classes.sort()
18 | class_indices = {k: v for v, k in enumerate(classes)}
19 |
20 | json_str = json.dumps({v: k for k, v in class_indices.items()}, indent=4)
21 |
22 | with open('./classes_indices.json', 'w') as json_file:
23 | json_file.write(json_str)
24 |
25 | every_class_num = []
26 | supported = ['.jpg', '.png', '.jpeg', '.PNG', '.JPG', '.JPEG']
27 |
28 | for klass in classes:
29 | classpath = os.path.join(root, klass)
30 | images = [os.path.join(root, klass, i) for i in os.listdir(classpath) if os.path.splitext(i)[-1] in supported]
31 | every_class_num.append(len(images))
32 | flist = sorted(os.listdir(classpath))
33 | desc = f'{klass:23s}'
34 | for f in tqdm(flist, ncols=110, desc=desc, unit='file', colour='blue'):
35 | fpath = os.path.join(classpath, f)
36 | fl = f.lower()
37 | index = fl.rfind('.')
38 | ext = fl[index:]
39 | if ext in supported:
40 | try:
41 | img = cv2.imread(fpath)
42 | filepaths.append(fpath)
43 | labels.append(klass)
44 | except:
45 | bad_images.append(fpath)
46 | print('defective image file: ', fpath)
47 | else:
48 | bad_images.append(fpath)
49 |
50 | Fseries = pd.Series(filepaths, name='filepaths')
51 | Lseries = pd.Series(labels, name='labels')
52 | df = pd.concat([Fseries, Lseries], axis=1)
53 |
54 | print(f'{len(df.labels.unique())} kind of images were found in the dataset')
55 | train_df, test_df = train_test_split(df, train_size=.8, shuffle=True, random_state=123, stratify=df['labels'])
56 |
57 | train_image_path = train_df['filepaths'].tolist()
58 | val_image_path = test_df['filepaths'].tolist()
59 |
60 | train_image_label = [class_indices[i] for i in train_df['labels'].tolist()]
61 | val_image_label = [class_indices[i] for i in test_df['labels'].tolist()]
62 |
63 | sample_df = train_df.sample(n=50, replace=False)
64 | ht, wt, count = 0, 0, 0
65 | for i in range(len(sample_df)):
66 | fpath = sample_df['filepaths'].iloc[i]
67 | try:
68 | img = cv2.imread(fpath)
69 | h = img.shape[0]
70 | w = img.shape[1]
71 | ht += h
72 | wt += w
73 | count += 1
74 | except:
75 | pass
76 | have = int(ht / count)
77 | wave = int(wt / count)
78 | aspect_ratio = have / wave
79 | print('{} images were found in the dataset.\n{} for training, {} for validation'.format(
80 | sum(every_class_num), len(train_image_path), len(val_image_path)
81 | ))
82 | print('average image height= ', have, ' average image width= ', wave, ' aspect ratio h/w= ', aspect_ratio)
83 |
84 | if plot_image:
85 | plt.bar(range(len(classes)), every_class_num, align='center')
86 | plt.xticks(range(len(classes)), classes)
87 |
88 | for i, v in enumerate(every_class_num):
89 | plt.text(x=i, y=v + 5, s=str(v), ha='center')
90 |
91 | plt.xlabel('image class')
92 | plt.ylabel('number of images')
93 |
94 | plt.title('class distribution')
95 | plt.show()
96 |
97 | return train_image_path, train_image_label, val_image_path, val_image_label, class_indices
--------------------------------------------------------------------------------
/datasets/threeaugment.py:
--------------------------------------------------------------------------------
1 | """
2 | 3Augment implementation from (https://github.com/facebookresearch/deit/blob/main/augment.py)
3 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
4 | and timm DA(https://github.com/rwightman/pytorch-image-models)
5 | Can be called by adding "--ThreeAugment" to the command line
6 | """
7 | import torch
8 | from timm.data.transforms import str_to_pil_interp, RandomResizedCropAndInterpolation
9 | from torchvision import transforms
10 | import random
11 |
12 | from PIL import ImageFilter, ImageOps
13 |
14 |
15 | class GaussianBlur(object):
16 | """
17 | Apply Gaussian Blur to the PIL image.
18 | """
19 |
20 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
21 | self.prob = p
22 | self.radius_min = radius_min
23 | self.radius_max = radius_max
24 |
25 | def __call__(self, img):
26 | do_it = random.random() <= self.prob
27 | if not do_it:
28 | return img
29 |
30 | img = img.filter(
31 | ImageFilter.GaussianBlur(
32 | radius=random.uniform(self.radius_min, self.radius_max)
33 | )
34 | )
35 | return img
36 |
37 |
38 | class Solarization(object):
39 | """
40 | Apply Solarization to the PIL image.
41 | """
42 |
43 | def __init__(self, p=0.2):
44 | self.p = p
45 |
46 | def __call__(self, img):
47 | if random.random() < self.p:
48 | return ImageOps.solarize(img)
49 | else:
50 | return img
51 |
52 |
53 | class gray_scale(object):
54 | """
55 | Apply Solarization to the PIL image.
56 | """
57 |
58 | def __init__(self, p=0.2):
59 | self.p = p
60 | self.transf = transforms.Grayscale(3)
61 |
62 | def __call__(self, img):
63 | if random.random() < self.p:
64 | return self.transf(img)
65 | else:
66 | return img
67 |
68 |
69 | class horizontal_flip(object):
70 | """
71 | Apply Solarization to the PIL image.
72 | """
73 |
74 | def __init__(self, p=0.2, activate_pred=False):
75 | self.p = p
76 | self.transf = transforms.RandomHorizontalFlip(p=1.0)
77 |
78 | def __call__(self, img):
79 | if random.random() < self.p:
80 | return self.transf(img)
81 | else:
82 | return img
83 |
84 |
85 | def new_data_aug_generator(args=None):
86 | img_size = args.input_size
87 | remove_random_resized_crop = False
88 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
89 | primary_tfl = []
90 | scale = (0.08, 1.0)
91 | interpolation = 'bicubic'
92 | if remove_random_resized_crop:
93 | primary_tfl = [
94 | transforms.Resize(img_size, interpolation=3),
95 | transforms.RandomCrop(img_size, padding=4, padding_mode='reflect'),
96 | transforms.RandomHorizontalFlip()
97 | ]
98 | else:
99 | primary_tfl = [
100 | RandomResizedCropAndInterpolation(
101 | img_size, scale=scale, interpolation=interpolation),
102 | transforms.RandomHorizontalFlip()
103 | ]
104 |
105 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0),
106 | Solarization(p=1.0),
107 | GaussianBlur(p=1.0)])]
108 |
109 | if args.color_jitter is not None and not args.color_jitter == 0:
110 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))
111 | final_tfl = [
112 | transforms.ToTensor(),
113 | transforms.Normalize(
114 | mean=torch.tensor(mean),
115 | std=torch.tensor(std))
116 | ]
117 | return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: CV
2 | dependencies:
3 | - python=3.9
4 | - pip
5 | - pip:
6 | - onnx==1.13
7 | - onnxoptimizer==0.3.13
8 | - onnxruntime==1.18
9 | - matplotlib==3.5.1
10 | - numpy==1.23.0
11 | - opencv-contrib-python==4.7.0.72
12 | - opencv-python==4.7.0.72
13 | - openpyxl==3.1.2
14 | - pandas==1.5.3
15 | - pillow==9.3.0
16 | - terminaltables==3.1.10
17 | - plotly==5.14.1
18 | - scikit-learn==1.3.0
19 | - tensorboardx==2.6.2.2
20 | - timm==1.0.11
21 | - torch==2.0.1+cu118
22 | - torchaudio==2.0.2+cu118
23 | - torchinfo==1.7.2
24 | - torchvision==0.15.2+cu118
25 | - transformers==4.28.1
26 | - seaborn==0.13.2
27 | - safetensors==0.4.5
28 | - terminaltables==3.1.10
29 |
--------------------------------------------------------------------------------
/estimate_model.py:
--------------------------------------------------------------------------------
1 | import torch, json, os
2 | import seaborn as sns
3 | from sklearn.metrics import auc, f1_score, roc_curve, classification_report, confusion_matrix, roc_auc_score
4 | from itertools import cycle
5 | from numpy import interp
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from PIL import Image
9 | from torchvision import transforms
10 | from typing import Iterable
11 | from optim_AUC import OptimizeAUC
12 | from terminaltables import AsciiTable
13 |
14 |
15 | @torch.inference_mode()
16 | def Plot_ROC(net: torch.nn.Module, val_loader: Iterable, save_name: str, device: torch.device):
17 | """
18 | Plot ROC Curve
19 |
20 | Save the roc curve as an image file in the current directory
21 |
22 | Args:
23 | net (torch.nn.Module): The model to be evaluated.
24 | val_loader (Iterable): The data loader for the valid data.
25 | save_name (str): The file path of your model weights
26 | device (torch.device): The device used for training (CPU or GPU).
27 |
28 | Returns:
29 | None
30 | """
31 |
32 | try:
33 | json_file = open('./classes_indices.json', 'r')
34 | class_indict = json.load(json_file)
35 | except Exception as e:
36 | print(e)
37 | exit(-1)
38 |
39 | score_list = []
40 | label_list = []
41 |
42 | net.load_state_dict(torch.load(save_name)['model'])
43 |
44 | for i, data in enumerate(val_loader):
45 | images, labels = data
46 | images, labels = images.to(device), labels.to(device)
47 | outputs = torch.softmax(net(images), dim=1)
48 | score_tmp = outputs
49 | score_list.extend(score_tmp.detach().cpu().numpy())
50 | label_list.extend(labels.cpu().numpy())
51 |
52 | score_array = np.array(score_list)
53 | # convert label to one-hot form
54 | label_tensor = torch.tensor(label_list)
55 | label_tensor = label_tensor.reshape((label_tensor.shape[0], 1))
56 | label_onehot = torch.zeros(label_tensor.shape[0], len(class_indict.keys()))
57 | label_onehot.scatter_(dim=1, index=label_tensor, value=1)
58 | label_onehot = np.array(label_onehot)
59 |
60 | print("score_array:", score_array.shape) # (batchsize, classnum)
61 | print("label_onehot:", label_onehot.shape) # torch.Size([batchsize, classnum])
62 |
63 | # compute tpr and fpr for each label by using sklearn lib
64 | fpr_dict = dict()
65 | tpr_dict = dict()
66 | roc_auc_dict = dict()
67 | for i in range(len(class_indict.keys())):
68 | fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score_array[:, i])
69 | roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])
70 | # micro
71 | fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score_array.ravel())
72 | roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"])
73 |
74 | # macro
75 | # First aggregate all false positive rates
76 | all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(len(class_indict.keys()))]))
77 | # Then interpolate all ROC curves at this points
78 | mean_tpr = np.zeros_like(all_fpr)
79 |
80 | for i in range(len(set(label_list))):
81 | mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i])
82 |
83 | # Finally average it and compute AUC
84 | mean_tpr /= len(class_indict.keys())
85 | fpr_dict["macro"] = all_fpr
86 | tpr_dict["macro"] = mean_tpr
87 | roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"])
88 |
89 | # plot roc curve for each label
90 | plt.figure(figsize=(12, 12))
91 | lw = 2
92 |
93 | plt.plot(fpr_dict["micro"], tpr_dict["micro"],
94 | label='micro-average ROC curve (area = {0:0.2f})'
95 | ''.format(roc_auc_dict["micro"]),
96 | color='deeppink', linestyle=':', linewidth=4)
97 |
98 | plt.plot(fpr_dict["macro"], tpr_dict["macro"],
99 | label='macro-average ROC curve (area = {0:0.2f})'
100 | ''.format(roc_auc_dict["macro"]),
101 | color='navy', linestyle=':', linewidth=4)
102 |
103 | colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
104 | for i, color in zip(range(len(class_indict.keys())), colors):
105 | plt.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw,
106 | label='ROC curve of class {0} (area = {1:0.2f})'
107 | ''.format(class_indict[str(i)], roc_auc_dict[i]))
108 |
109 | plt.plot([0, 1], [0, 1], 'k--', lw=lw, label='Chance', color='red')
110 | plt.xlim([0.0, 1.0])
111 | plt.ylim([0.0, 1.05])
112 | plt.xlabel('False Positive Rate')
113 | plt.ylabel('True Positive Rate')
114 | plt.title('Receiver operating characteristic to multi-class')
115 | plt.legend(loc="lower right")
116 | plt.savefig('./multi_classes_roc.png')
117 | # plt.show()
118 |
119 |
120 | @torch.inference_mode()
121 | def predict_single_image(model: torch.nn.Module, device: torch.device, weight_path: str):
122 | """
123 | Predict Single Image.
124 |
125 | Save the prediction as an image file which including pred label and prob in the current directory
126 |
127 | Args:
128 | model (torch.nn.Module): The model to be evaluated.
129 | device (torch.device): The device used for training (CPU or GPU).
130 | weight_path (str): The model weights file
131 |
132 | Returns:
133 | None
134 | """
135 |
136 | data_transform = {
137 | 'train': transforms.Compose([transforms.RandomResizedCrop(224), transforms.ToTensor(),
138 | transforms.RandomHorizontalFlip(),
139 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
140 |
141 | 'valid': transforms.Compose([transforms.Resize((224, 224)), transforms.CenterCrop(224),
142 | transforms.ToTensor(),
143 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
144 | }
145 |
146 | img_transform = data_transform['valid']
147 |
148 | # load image
149 | img_path = "rose.jpg"
150 | assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
151 | img = Image.open(img_path)
152 | plt.imshow(img)
153 | # [N, C, H, W]
154 | img = img_transform(img)
155 | # expand batch dimension
156 | img = torch.unsqueeze(img, dim=0)
157 |
158 | # read class_indict
159 | json_path = './classes_indices.json'
160 | assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
161 |
162 | with open(json_path, "r") as f:
163 | class_indict = json.load(f)
164 |
165 | # load model weights
166 |
167 | assert os.path.exists(weight_path), "weight file dose not exist."
168 | model.load_state_dict(torch.load(weight_path, map_location=device)['model'])
169 |
170 | model.eval()
171 | # predict class
172 | output = torch.squeeze(model(img.to(device))).cpu()
173 | predict = torch.softmax(output, dim=0)
174 | predict_cla = torch.argmax(predict).numpy()
175 |
176 | print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
177 | predict[predict_cla].numpy())
178 |
179 | plt.title(print_res)
180 | for i in range(len(predict)):
181 | print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
182 | predict[i].numpy()))
183 | plt.savefig(f'./pred_{img_path}')
184 | # plt.show()
185 |
186 |
187 | @torch.inference_mode()
188 | def Predictor(net: torch.nn.Module, test_loader: Iterable, save_name: str, device: torch.device):
189 | """
190 | Evaluate the performance of the model on the given dataset.
191 |
192 | 1. This function will print the following metrics:
193 | - F1 score
194 | - Confusion matrix
195 | - Classification report
196 |
197 | 2. Save the confusion matrix as an image file in the current directory.
198 |
199 | Args:
200 | net (torch.nn.Module): The model to be evaluated.
201 | test_loader (Iterable): The data loader for the valid data.
202 | save_name (str): The file path of your model weights
203 | device (torch.device): The device used for training (CPU or GPU).
204 |
205 | Returns:
206 | None
207 | """
208 |
209 | try:
210 | json_file = open('./classes_indices.json', 'r')
211 | class_indict = json.load(json_file)
212 | except Exception as e:
213 | print(e)
214 | exit(-1)
215 |
216 | errors = 0
217 | y_pred, y_true = [], []
218 | net.load_state_dict(torch.load(save_name)['model'])
219 |
220 | net.eval()
221 |
222 | for data in test_loader:
223 | images, labels = data
224 | images, labels = images.to(device), labels.to(device)
225 | preds = torch.argmax(torch.softmax(net(images), dim=1), dim=1)
226 | for i in range(len(preds)):
227 | y_pred.append(preds[i].cpu())
228 | y_true.append(labels[i].cpu())
229 |
230 | tests = len(y_pred)
231 | for i in range(tests):
232 | pred_index = y_pred[i]
233 | true_index = y_true[i]
234 | if pred_index != true_index:
235 | errors += 1
236 |
237 | acc = (1 - errors / tests) * 100
238 | print(f'there were {errors} errors in {tests} tests for an accuracy of {acc:6.2f}%')
239 |
240 | ypred = np.array(y_pred)
241 | ytrue = np.array(y_true)
242 |
243 | f1score = f1_score(ytrue, ypred, average='weighted') * 100
244 |
245 | print(f'The F1-score was {f1score:.3f}')
246 | class_count = len(list(class_indict.values()))
247 | classes = list(class_indict.values())
248 |
249 | cm = confusion_matrix(ytrue, ypred)
250 | plt.figure(figsize=(16, 8))
251 | plt.subplot(1, 2, 1)
252 | sns.heatmap(cm, annot=True, vmin=0, fmt='g', cmap='Blues', cbar=False)
253 | plt.xticks(np.arange(class_count) + .5, classes, rotation=45, fontsize=14)
254 | plt.yticks(np.arange(class_count) + .5, classes, rotation=0, fontsize=14)
255 | plt.xlabel("Predicted", fontsize=14)
256 | plt.ylabel("True", fontsize=14)
257 | plt.title("Confusion Matrix")
258 |
259 | plt.subplot(1, 2, 2)
260 | sns.heatmap(cm / np.sum(cm), annot=True, fmt='.1%')
261 | plt.xticks(np.arange(class_count) + .5, classes, rotation=45, fontsize=14)
262 | plt.yticks(np.arange(class_count) + .5, classes, rotation=0, fontsize=14)
263 | plt.xlabel('Predicted', fontsize=14)
264 | plt.ylabel('True', fontsize=14)
265 | plt.savefig('./confusion_matrix.png')
266 | # plt.show()
267 |
268 | clr = classification_report(y_true, y_pred, target_names=classes, digits=4)
269 | print("Classification Report:\n----------------------\n", clr)
270 |
271 |
272 | @torch.inference_mode()
273 | def OptAUC(net: torch.nn.Module, val_loader: Iterable, save_name: str, device: torch.device):
274 | """
275 | Optimize model for improving AUC
276 |
277 | Print a table of initial and optimized AUC and F1-score.
278 |
279 | This function takes the initial and optimized AUC and F1-score, and generates
280 | an ASCII table to display the results. The table will have the following format:
281 |
282 | Optimize Results
283 | +----------------------+----------------------+----------------------+----------------------+
284 | | Initial AUC | Initial F1-Score | Optimize AUC | Optimize F1-Score |
285 | +----------------------+----------------------+----------------------+----------------------+
286 | | 0.654321 | 0.654321 | 0.876543 | 0.876543 |
287 | +----------------------+----------------------+----------------------+----------------------+
288 |
289 | The optimized AUC and F1-score are obtained by using the `OptimizeAUC` class (in ./optim_AUC.py), which
290 | performs optimization on the initial metrics.
291 |
292 | Args:
293 | net (torch.nn.Module): The model to be evaluated.
294 | test_loader (Iterable): The data loader for the valid data.
295 | save_name (str): The file path of your model weights
296 | device (torch.device): The device used for training (CPU or GPU).
297 |
298 | Returns:
299 | None
300 | """
301 |
302 | score_list = []
303 | label_list = []
304 |
305 | net.load_state_dict(torch.load(save_name)['model'])
306 |
307 | for i, data in enumerate(val_loader):
308 | images, labels = data
309 | images, labels = images.to(device), labels.to(device)
310 | outputs = torch.softmax(net(images), dim=1)
311 | score_tmp = outputs
312 | score_list.extend(score_tmp.detach().cpu().numpy())
313 | label_list.extend(labels.detach().cpu().numpy())
314 |
315 | score_array = np.array(score_list)
316 | label_list = np.array(label_list)
317 | y_preds = np.argmax(score_array, axis=1)
318 | f1score = f1_score(label_list, y_preds, average='weighted') * 100
319 | auc_score = roc_auc_score(label_list, score_array, average='weighted', multi_class='ovo')
320 |
321 | opt_auc = OptimizeAUC()
322 | opt_auc.fit(score_array, label_list)
323 | opt_preds = opt_auc.predict(score_array)
324 | opt_y_preds = np.argmax(opt_preds, axis=1)
325 | opt_f1score = f1_score(label_list, opt_y_preds, average='weighted') * 100
326 | opt_auc_score = roc_auc_score(label_list, opt_preds, average='weighted', multi_class='ovo')
327 |
328 | TITLE = 'Optimize Results'
329 | TABLE_DATA = (
330 | ('Initial AUC', 'Initial F1-Score', 'Optimize AUC', 'Optimize F1-Score'),
331 | ('{:.6f}'.format(auc_score),
332 | '{:.6f}'.format(f1score),
333 | '{:.6f}'.format(opt_auc_score),
334 | '{:.6f}'.format(opt_f1score)
335 | ),
336 | )
337 | table_instance = AsciiTable(TABLE_DATA, TITLE)
338 | print(table_instance.table)
339 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .blocks import *
2 | from .model_utils import *
3 | from .build_mobilenet_v4 import mobilenetv4_hybrid_large, mobilenetv4_hybrid_medium, mobilenetv4_hybrid_large_075, \
4 | mobilenetv4_conv_large, mobilenetv4_conv_aa_large, mobilenetv4_conv_medium, mobilenetv4_conv_aa_medium, \
5 | mobilenetv4_conv_small, mobilenetv4_hybrid_medium_075, mobilenetv4_conv_small_035, \
6 | mobilenetv4_conv_small_050, mobilenetv4_conv_blur_medium
--------------------------------------------------------------------------------
/models/blocks.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, Optional, Type
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | from timm.layers import create_conv2d, DropPath, create_act_layer, create_aa, to_2tuple, LayerType,\
8 | ConvNormAct, get_norm_act_layer, MultiQueryAttention2d, Attention2d
9 |
10 |
11 |
12 | __all__ = [
13 | 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
14 | 'UniversalInvertedResidual', 'MobileAttention'
15 | ]
16 |
17 | ModuleType = Type[nn.Module]
18 |
19 |
20 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
21 | min_value = min_value or divisor
22 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
23 | # Make sure that round down does not go down by more than 10%.
24 | if new_v < round_limit * v:
25 | new_v += divisor
26 | return new_v
27 |
28 |
29 | def num_groups(group_size: Optional[int], channels: int):
30 | if not group_size: # 0 or None
31 | return 1 # normal conv with 1 group
32 | else:
33 | # NOTE group_size == 1 -> depthwise conv
34 | assert channels % group_size == 0
35 | return channels // group_size
36 |
37 |
38 | class SqueezeExcite(nn.Module):
39 | """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family
40 |
41 | Args:
42 | in_chs (int): input channels to layer
43 | rd_ratio (float): ratio of squeeze reduction
44 | act_layer (nn.Module): activation layer of containing block
45 | gate_layer (Callable): attention gate function
46 | force_act_layer (nn.Module): override block's activation fn if this is set/bound
47 | rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
48 | """
49 |
50 | def __init__(
51 | self,
52 | in_chs: int,
53 | rd_ratio: float = 0.25,
54 | rd_channels: Optional[int] = None,
55 | act_layer: LayerType = nn.ReLU,
56 | gate_layer: LayerType = nn.Sigmoid,
57 | force_act_layer: Optional[LayerType] = None,
58 | rd_round_fn: Optional[Callable] = None,
59 | ):
60 | super(SqueezeExcite, self).__init__()
61 | if rd_channels is None:
62 | rd_round_fn = rd_round_fn or round
63 | rd_channels = rd_round_fn(in_chs * rd_ratio)
64 | act_layer = force_act_layer or act_layer
65 | self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
66 | self.act1 = create_act_layer(act_layer, inplace=True)
67 | self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
68 | self.gate = create_act_layer(gate_layer)
69 |
70 | def forward(self, x):
71 | x_se = x.mean((2, 3), keepdim=True)
72 | x_se = self.conv_reduce(x_se)
73 | x_se = self.act1(x_se)
74 | x_se = self.conv_expand(x_se)
75 | return x * self.gate(x_se)
76 |
77 |
78 | class ConvBnAct(nn.Module):
79 | """ Conv + Norm Layer + Activation w/ optional skip connection
80 | """
81 | def __init__(
82 | self,
83 | in_chs: int,
84 | out_chs: int,
85 | kernel_size: int,
86 | stride: int = 1,
87 | dilation: int = 1,
88 | group_size: int = 0,
89 | pad_type: str = '',
90 | skip: bool = False,
91 | act_layer: LayerType = nn.ReLU,
92 | norm_layer: LayerType = nn.BatchNorm2d,
93 | aa_layer: Optional[LayerType] = None,
94 | drop_path_rate: float = 0.,
95 | ):
96 | super(ConvBnAct, self).__init__()
97 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
98 | groups = num_groups(group_size, in_chs)
99 | self.has_skip = skip and stride == 1 and in_chs == out_chs
100 | use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
101 |
102 | self.conv = create_conv2d(
103 | in_chs, out_chs, kernel_size,
104 | stride=1 if use_aa else stride,
105 | dilation=dilation, groups=groups, padding=pad_type)
106 | self.bn1 = norm_act_layer(out_chs, inplace=True)
107 | self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa)
108 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
109 |
110 | def feature_info(self, location):
111 | if location == 'expansion': # output of conv after act, same as block coutput
112 | return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels)
113 | else: # location == 'bottleneck', block output
114 | return dict(module='', num_chs=self.conv.out_channels)
115 |
116 | def forward(self, x):
117 | shortcut = x
118 | x = self.conv(x)
119 | x = self.bn1(x)
120 | x = self.aa(x)
121 | if self.has_skip:
122 | x = self.drop_path(x) + shortcut
123 | return x
124 |
125 |
126 | class DepthwiseSeparableConv(nn.Module):
127 | """ Depthwise-separable block
128 | Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
129 | (factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
130 | """
131 | def __init__(
132 | self,
133 | in_chs: int,
134 | out_chs: int,
135 | dw_kernel_size: int = 3,
136 | stride: int = 1,
137 | dilation: int = 1,
138 | group_size: int = 1,
139 | pad_type: str = '',
140 | noskip: bool = False,
141 | pw_kernel_size: int = 1,
142 | pw_act: bool = False,
143 | s2d: int = 0,
144 | act_layer: LayerType = nn.ReLU,
145 | norm_layer: LayerType = nn.BatchNorm2d,
146 | aa_layer: Optional[LayerType] = None,
147 | se_layer: Optional[ModuleType] = None,
148 | drop_path_rate: float = 0.,
149 | ):
150 | super(DepthwiseSeparableConv, self).__init__()
151 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
152 | self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
153 | self.has_pw_act = pw_act # activation after point-wise conv
154 | use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
155 |
156 | # Space to depth
157 | if s2d == 1:
158 | sd_chs = int(in_chs * 4)
159 | self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same')
160 | self.bn_s2d = norm_act_layer(sd_chs, sd_chs)
161 | dw_kernel_size = (dw_kernel_size + 1) // 2
162 | dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
163 | in_chs = sd_chs
164 | use_aa = False # disable AA
165 | else:
166 | self.conv_s2d = None
167 | self.bn_s2d = None
168 | dw_pad_type = pad_type
169 |
170 | groups = num_groups(group_size, in_chs)
171 |
172 | self.conv_dw = create_conv2d(
173 | in_chs, in_chs, dw_kernel_size,
174 | stride=1 if use_aa else stride,
175 | dilation=dilation, padding=dw_pad_type, groups=groups)
176 | self.bn1 = norm_act_layer(in_chs, inplace=True)
177 | self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa)
178 |
179 | # Squeeze-and-excitation
180 | self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
181 |
182 | self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
183 | self.bn2 = norm_act_layer(out_chs, inplace=True, apply_act=self.has_pw_act)
184 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
185 |
186 | def feature_info(self, location):
187 | if location == 'expansion': # after SE, input to PW
188 | return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
189 | else: # location == 'bottleneck', block output
190 | return dict(module='', num_chs=self.conv_pw.out_channels)
191 |
192 | def forward(self, x):
193 | shortcut = x
194 | if self.conv_s2d is not None:
195 | x = self.conv_s2d(x)
196 | x = self.bn_s2d(x)
197 | x = self.conv_dw(x)
198 | x = self.bn1(x)
199 | x = self.aa(x)
200 | x = self.se(x)
201 | x = self.conv_pw(x)
202 | x = self.bn2(x)
203 | if self.has_skip:
204 | x = self.drop_path(x) + shortcut
205 | return x
206 |
207 |
208 | class InvertedResidual(nn.Module):
209 | """ Inverted residual block w/ optional SE
210 |
211 | Originally used in MobileNet-V2 - https://arxiv.org/abs/1801.04381v4, this layer is often
212 | referred to as 'MBConv' for (Mobile inverted bottleneck conv) and is also used in
213 | * MNasNet - https://arxiv.org/abs/1807.11626
214 | * EfficientNet - https://arxiv.org/abs/1905.11946
215 | * MobileNet-V3 - https://arxiv.org/abs/1905.02244
216 | """
217 |
218 | def __init__(
219 | self,
220 | in_chs: int,
221 | out_chs: int,
222 | dw_kernel_size: int = 3,
223 | stride: int = 1,
224 | dilation: int = 1,
225 | group_size: int = 1,
226 | pad_type: str = '',
227 | noskip: bool = False,
228 | exp_ratio: float = 1.0,
229 | exp_kernel_size: int = 1,
230 | pw_kernel_size: int = 1,
231 | s2d: int = 0,
232 | act_layer: LayerType = nn.ReLU,
233 | norm_layer: LayerType = nn.BatchNorm2d,
234 | aa_layer: Optional[LayerType] = None,
235 | se_layer: Optional[ModuleType] = None,
236 | conv_kwargs: Optional[Dict] = None,
237 | drop_path_rate: float = 0.,
238 | ):
239 | super(InvertedResidual, self).__init__()
240 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
241 | conv_kwargs = conv_kwargs or {}
242 | self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
243 | use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
244 |
245 | # Space to depth
246 | if s2d == 1:
247 | sd_chs = int(in_chs * 4)
248 | self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same')
249 | self.bn_s2d = norm_act_layer(sd_chs, sd_chs)
250 | dw_kernel_size = (dw_kernel_size + 1) // 2
251 | dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
252 | in_chs = sd_chs
253 | use_aa = False # disable AA
254 | else:
255 | self.conv_s2d = None
256 | self.bn_s2d = None
257 | dw_pad_type = pad_type
258 |
259 | mid_chs = make_divisible(in_chs * exp_ratio)
260 | groups = num_groups(group_size, mid_chs)
261 |
262 | # Point-wise expansion
263 | self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
264 | self.bn1 = norm_act_layer(mid_chs, inplace=True)
265 |
266 | # Depth-wise convolution
267 | self.conv_dw = create_conv2d(
268 | mid_chs, mid_chs, dw_kernel_size,
269 | stride=1 if use_aa else stride,
270 | dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs)
271 | self.bn2 = norm_act_layer(mid_chs, inplace=True)
272 | self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa)
273 |
274 | # Squeeze-and-excitation
275 | self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
276 |
277 | # Point-wise linear projection
278 | self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
279 | self.bn3 = norm_act_layer(out_chs, apply_act=False)
280 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
281 |
282 | def feature_info(self, location):
283 | if location == 'expansion': # after SE, input to PWL
284 | return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
285 | else: # location == 'bottleneck', block output
286 | return dict(module='', num_chs=self.conv_pwl.out_channels)
287 |
288 | def forward(self, x):
289 | shortcut = x
290 | if self.conv_s2d is not None:
291 | x = self.conv_s2d(x)
292 | x = self.bn_s2d(x)
293 | x = self.conv_pw(x)
294 | x = self.bn1(x)
295 | x = self.conv_dw(x)
296 | x = self.bn2(x)
297 | x = self.aa(x)
298 | x = self.se(x)
299 | x = self.conv_pwl(x)
300 | x = self.bn3(x)
301 | if self.has_skip:
302 | x = self.drop_path(x) + shortcut
303 | return x
304 |
305 |
306 | class LayerScale2d(nn.Module):
307 | def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
308 | super().__init__()
309 | self.inplace = inplace
310 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
311 |
312 | def forward(self, x):
313 | gamma = self.gamma.view(1, -1, 1, 1)
314 | return x.mul_(gamma) if self.inplace else x * gamma
315 |
316 |
317 | class UniversalInvertedResidual(nn.Module):
318 | """ Universal Inverted Residual Block (aka Universal Inverted Bottleneck, UIB)
319 |
320 | For MobileNetV4 - https://arxiv.org/abs/, referenced from
321 | https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778
322 | """
323 |
324 | def __init__(
325 | self,
326 | in_chs: int,
327 | out_chs: int,
328 | dw_kernel_size_start: int = 0,
329 | dw_kernel_size_mid: int = 3,
330 | dw_kernel_size_end: int = 0,
331 | stride: int = 1,
332 | dilation: int = 1,
333 | group_size: int = 1,
334 | pad_type: str = '',
335 | noskip: bool = False,
336 | exp_ratio: float = 1.0,
337 | act_layer: LayerType = nn.ReLU,
338 | norm_layer: LayerType = nn.BatchNorm2d,
339 | aa_layer: Optional[LayerType] = None,
340 | se_layer: Optional[ModuleType] = None,
341 | conv_kwargs: Optional[Dict] = None,
342 | drop_path_rate: float = 0.,
343 | layer_scale_init_value: Optional[float] = 1e-5,
344 | ):
345 | super(UniversalInvertedResidual, self).__init__()
346 | conv_kwargs = conv_kwargs or {}
347 | self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
348 | if stride > 1:
349 | assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end
350 |
351 | # FIXME dilation isn't right w/ extra ks > 1 convs
352 | if dw_kernel_size_start:
353 | dw_start_stride = stride if not dw_kernel_size_mid else 1
354 | dw_start_groups = num_groups(group_size, in_chs)
355 | self.dw_start = ConvNormAct(
356 | in_chs, in_chs, dw_kernel_size_start,
357 | stride=dw_start_stride,
358 | dilation=dilation, # FIXME
359 | groups=dw_start_groups,
360 | padding=pad_type,
361 | apply_act=False,
362 | act_layer=act_layer,
363 | norm_layer=norm_layer,
364 | aa_layer=aa_layer,
365 | **conv_kwargs,
366 | )
367 | else:
368 | self.dw_start = nn.Identity()
369 |
370 | # Point-wise expansion
371 | mid_chs = make_divisible(in_chs * exp_ratio)
372 | self.pw_exp = ConvNormAct(
373 | in_chs, mid_chs, 1,
374 | padding=pad_type,
375 | act_layer=act_layer,
376 | norm_layer=norm_layer,
377 | **conv_kwargs,
378 | )
379 |
380 | # Middle depth-wise convolution
381 | if dw_kernel_size_mid:
382 | groups = num_groups(group_size, mid_chs)
383 | self.dw_mid = ConvNormAct(
384 | mid_chs, mid_chs, dw_kernel_size_mid,
385 | stride=stride,
386 | dilation=dilation, # FIXME
387 | groups=groups,
388 | padding=pad_type,
389 | act_layer=act_layer,
390 | norm_layer=norm_layer,
391 | aa_layer=aa_layer,
392 | **conv_kwargs,
393 | )
394 | else:
395 | # keeping mid as identity so it can be hooked more easily for features
396 | self.dw_mid = nn.Identity()
397 |
398 | # Squeeze-and-excitation
399 | self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
400 |
401 | # Point-wise linear projection
402 | self.pw_proj = ConvNormAct(
403 | mid_chs, out_chs, 1,
404 | padding=pad_type,
405 | apply_act=False,
406 | act_layer=act_layer,
407 | norm_layer=norm_layer,
408 | **conv_kwargs,
409 | )
410 |
411 | if dw_kernel_size_end:
412 | dw_end_stride = stride if not dw_kernel_size_start and not dw_kernel_size_mid else 1
413 | dw_end_groups = num_groups(group_size, out_chs)
414 | if dw_end_stride > 1:
415 | assert not aa_layer
416 | self.dw_end = ConvNormAct(
417 | out_chs, out_chs, dw_kernel_size_end,
418 | stride=dw_end_stride,
419 | dilation=dilation,
420 | groups=dw_end_groups,
421 | padding=pad_type,
422 | apply_act=False,
423 | act_layer=act_layer,
424 | norm_layer=norm_layer,
425 | **conv_kwargs,
426 | )
427 | else:
428 | self.dw_end = nn.Identity()
429 |
430 | if layer_scale_init_value is not None:
431 | self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
432 | else:
433 | self.layer_scale = nn.Identity()
434 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
435 |
436 | def feature_info(self, location):
437 | if location == 'expansion': # after SE, input to PWL
438 | return dict(module='pw_proj.conv', hook_type='forward_pre', num_chs=self.pw_proj.conv.in_channels)
439 | else: # location == 'bottleneck', block output
440 | return dict(module='', num_chs=self.pw_proj.conv.out_channels)
441 |
442 | def forward(self, x):
443 | shortcut = x
444 | x = self.dw_start(x)
445 | x = self.pw_exp(x)
446 | x = self.dw_mid(x)
447 | x = self.se(x)
448 | x = self.pw_proj(x)
449 | x = self.dw_end(x)
450 | x = self.layer_scale(x)
451 | if self.has_skip:
452 | x = self.drop_path(x) + shortcut
453 | return x
454 |
455 |
456 | class MobileAttention(nn.Module):
457 | """ Mobile Attention Block
458 |
459 | For MobileNetV4 - https://arxiv.org/abs/, referenced from
460 | https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L1504
461 | """
462 | def __init__(
463 | self,
464 | in_chs: int,
465 | out_chs: int,
466 | stride: int = 1,
467 | dw_kernel_size: int = 3,
468 | dilation: int = 1,
469 | group_size: int = 1,
470 | pad_type: str = '',
471 | num_heads: int = 8,
472 | key_dim: int = 64,
473 | value_dim: int = 64,
474 | use_multi_query: bool = False,
475 | query_strides: int = (1, 1),
476 | kv_stride: int = 1,
477 | cpe_dw_kernel_size: int = 3,
478 | noskip: bool = False,
479 | act_layer: LayerType = nn.ReLU,
480 | norm_layer: LayerType = nn.BatchNorm2d,
481 | aa_layer: Optional[LayerType] = None,
482 | drop_path_rate: float = 0.,
483 | attn_drop: float = 0.0,
484 | proj_drop: float = 0.0,
485 | layer_scale_init_value: Optional[float] = 1e-5,
486 | use_bias: bool = False,
487 | use_cpe: bool = False,
488 | ):
489 | super(MobileAttention, self).__init__()
490 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
491 | self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
492 | self.query_strides = to_2tuple(query_strides)
493 | self.kv_stride = kv_stride
494 | self.has_query_stride = any([s > 1 for s in self.query_strides])
495 |
496 | # This CPE is different than the one suggested in the original paper.
497 | # https://arxiv.org/abs/2102.10882
498 | # 1. Rather than adding one CPE before the attention blocks, we add a CPE
499 | # into every attention block.
500 | # 2. We replace the expensive Conv2D by a Seperable DW Conv.
501 | if use_cpe:
502 | self.conv_cpe_dw = create_conv2d(
503 | in_chs, in_chs,
504 | kernel_size=cpe_dw_kernel_size,
505 | dilation=dilation,
506 | depthwise=True,
507 | bias=True,
508 | )
509 | else:
510 | self.conv_cpe_dw = None
511 |
512 | self.norm = norm_act_layer(in_chs, apply_act=False)
513 |
514 | if num_heads is None:
515 | assert in_chs % key_dim == 0
516 | num_heads = in_chs // key_dim
517 |
518 | if use_multi_query:
519 | self.attn = MultiQueryAttention2d(
520 | in_chs,
521 | dim_out=out_chs,
522 | num_heads=num_heads,
523 | key_dim=key_dim,
524 | value_dim=value_dim,
525 | query_strides=query_strides,
526 | kv_stride=kv_stride,
527 | dilation=dilation,
528 | padding=pad_type,
529 | dw_kernel_size=dw_kernel_size,
530 | attn_drop=attn_drop,
531 | proj_drop=proj_drop,
532 | #bias=use_bias, # why not here if used w/ mhsa?
533 | )
534 | else:
535 | self.attn = Attention2d(
536 | in_chs,
537 | dim_out=out_chs,
538 | num_heads=num_heads,
539 | attn_drop=attn_drop,
540 | proj_drop=proj_drop,
541 | bias=use_bias,
542 | )
543 |
544 | if layer_scale_init_value is not None:
545 | self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
546 | else:
547 | self.layer_scale = nn.Identity()
548 |
549 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
550 |
551 | def feature_info(self, location):
552 | if location == 'expansion': # after SE, input to PW
553 | return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
554 | else: # location == 'bottleneck', block output
555 | return dict(module='', num_chs=self.conv_pw.out_channels)
556 |
557 | def forward(self, x):
558 | if self.conv_cpe_dw is not None:
559 | x_cpe = self.conv_cpe_dw(x)
560 | x = x + x_cpe
561 |
562 | shortcut = x
563 | x = self.norm(x)
564 | x = self.attn(x)
565 | x = self.layer_scale(x)
566 | if self.has_skip:
567 | x = self.drop_path(x) + shortcut
568 |
569 | return x
570 |
571 |
572 | class CondConvResidual(InvertedResidual):
573 | """ Inverted residual block w/ CondConv routing"""
574 |
575 | def __init__(
576 | self,
577 | in_chs: int,
578 | out_chs: int,
579 | dw_kernel_size: int = 3,
580 | stride: int = 1,
581 | dilation: int = 1,
582 | group_size: int = 1,
583 | pad_type: str = '',
584 | noskip: bool = False,
585 | exp_ratio: float = 1.0,
586 | exp_kernel_size: int = 1,
587 | pw_kernel_size: int = 1,
588 | act_layer: LayerType = nn.ReLU,
589 | norm_layer: LayerType = nn.BatchNorm2d,
590 | aa_layer: Optional[LayerType] = None,
591 | se_layer: Optional[ModuleType] = None,
592 | num_experts: int = 0,
593 | drop_path_rate: float = 0.,
594 | ):
595 |
596 | self.num_experts = num_experts
597 | conv_kwargs = dict(num_experts=self.num_experts)
598 | super(CondConvResidual, self).__init__(
599 | in_chs,
600 | out_chs,
601 | dw_kernel_size=dw_kernel_size,
602 | stride=stride,
603 | dilation=dilation,
604 | group_size=group_size,
605 | pad_type=pad_type,
606 | noskip=noskip,
607 | exp_ratio=exp_ratio,
608 | exp_kernel_size=exp_kernel_size,
609 | pw_kernel_size=pw_kernel_size,
610 | act_layer=act_layer,
611 | norm_layer=norm_layer,
612 | aa_layer=aa_layer,
613 | se_layer=se_layer,
614 | conv_kwargs=conv_kwargs,
615 | drop_path_rate=drop_path_rate,
616 | )
617 | self.routing_fn = nn.Linear(in_chs, self.num_experts)
618 |
619 | def forward(self, x):
620 | shortcut = x
621 | pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) # CondConv routing
622 | routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
623 | x = self.conv_pw(x, routing_weights)
624 | x = self.bn1(x)
625 | x = self.conv_dw(x, routing_weights)
626 | x = self.bn2(x)
627 | x = self.se(x)
628 | x = self.conv_pwl(x, routing_weights)
629 | x = self.bn3(x)
630 | if self.has_skip:
631 | x = self.drop_path(x) + shortcut
632 | return x
633 |
634 |
635 | class EdgeResidual(nn.Module):
636 | """ Residual block with expansion convolution followed by pointwise-linear w/ stride
637 |
638 | Originally introduced in `EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML`
639 | - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html
640 |
641 | This layer is also called FusedMBConv in the MobileDet, EfficientNet-X, and EfficientNet-V2 papers
642 | * MobileDet - https://arxiv.org/abs/2004.14525
643 | * EfficientNet-X - https://arxiv.org/abs/2102.05610
644 | * EfficientNet-V2 - https://arxiv.org/abs/2104.00298
645 | """
646 |
647 | def __init__(
648 | self,
649 | in_chs: int,
650 | out_chs: int,
651 | exp_kernel_size: int = 3,
652 | stride: int = 1,
653 | dilation: int = 1,
654 | group_size: int = 0,
655 | pad_type: str = '',
656 | force_in_chs: int = 0,
657 | noskip: bool = False,
658 | exp_ratio: float = 1.0,
659 | pw_kernel_size: int = 1,
660 | act_layer: LayerType = nn.ReLU,
661 | norm_layer: LayerType = nn.BatchNorm2d,
662 | aa_layer: Optional[LayerType] = None,
663 | se_layer: Optional[ModuleType] = None,
664 | drop_path_rate: float = 0.,
665 | ):
666 | super(EdgeResidual, self).__init__()
667 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
668 | if force_in_chs > 0:
669 | mid_chs = make_divisible(force_in_chs * exp_ratio)
670 | else:
671 | mid_chs = make_divisible(in_chs * exp_ratio)
672 | groups = num_groups(group_size, mid_chs) # NOTE: Using out_chs of conv_exp for groups calc
673 | self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
674 | use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
675 |
676 | # Expansion convolution
677 | self.conv_exp = create_conv2d(
678 | in_chs, mid_chs, exp_kernel_size,
679 | stride=1 if use_aa else stride,
680 | dilation=dilation, groups=groups, padding=pad_type)
681 | self.bn1 = norm_act_layer(mid_chs, inplace=True)
682 |
683 | self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa)
684 |
685 | # Squeeze-and-excitation
686 | self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
687 |
688 | # Point-wise linear projection
689 | self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
690 | self.bn2 = norm_act_layer(out_chs, apply_act=False)
691 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
692 |
693 | def feature_info(self, location):
694 | if location == 'expansion': # after SE, before PWL
695 | return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
696 | else: # location == 'bottleneck', block output
697 | return dict(module='', num_chs=self.conv_pwl.out_channels)
698 |
699 | def forward(self, x):
700 | shortcut = x
701 | x = self.conv_exp(x)
702 | x = self.bn1(x)
703 | x = self.aa(x)
704 | x = self.se(x)
705 | x = self.conv_pwl(x)
706 | x = self.bn2(x)
707 | if self.has_skip:
708 | x = self.drop_path(x) + shortcut
709 | return x
--------------------------------------------------------------------------------
/models/extra_attention_block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | class MultiScaleAttentionGate(nn.Module):
6 | """
7 | Multi-scale attention gate
8 | """
9 | def __init__(self, channel):
10 | super(MultiScaleAttentionGate, self).__init__()
11 | self.channel = channel
12 | self.pointwiseConv = nn.Sequential(
13 | nn.Conv2d(self.channel, self.channel, kernel_size=1, padding=0, bias=True),
14 | nn.BatchNorm2d(self.channel),
15 | )
16 | self.ordinaryConv = nn.Sequential(
17 | nn.Conv2d(self.channel, self.channel, kernel_size=3, padding=1, stride=1, bias=True),
18 | nn.BatchNorm2d(self.channel),
19 | )
20 | self.dilationConv = nn.Sequential(
21 | nn.Conv2d(self.channel, self.channel, kernel_size=3, padding=2, stride=1, dilation=2, bias=True),
22 | nn.BatchNorm2d(self.channel),
23 | )
24 | self.voteConv = nn.Sequential(
25 | nn.Conv2d(self.channel * 3, self.channel, kernel_size=(1, 1)),
26 | nn.BatchNorm2d(self.channel),
27 | nn.GELU()
28 | )
29 | self.relu = nn.ReLU(inplace=True)
30 |
31 | def forward(self, x):
32 | x1 = self.pointwiseConv(x)
33 | x2 = self.ordinaryConv(x)
34 | x3 = self.dilationConv(x)
35 | _x = self.relu(torch.cat((x1, x2, x3), dim=1))
36 | _x = self.voteConv(_x)
37 | x = x + x * _x
38 | return x
39 |
40 | # if __name__ == '__main__':
41 | # net = MultiScaleAttentionGate(960)
42 | # X = torch.randn(1, 960, 7, 7)
43 | # y = net(X)
44 | # print(y.shape)
--------------------------------------------------------------------------------
/models/model_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional
2 |
3 | import logging
4 | import math
5 | import re
6 | from copy import deepcopy
7 | from functools import partial
8 | from typing import Any, Dict, List
9 |
10 | import torch.nn as nn
11 |
12 | from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, LayerType
13 | from models.blocks import make_divisible
14 | from models.blocks import *
15 |
16 |
17 | def named_modules(
18 | module: nn.Module,
19 | name: str = '',
20 | depth_first: bool = True,
21 | include_root: bool = False,
22 | ):
23 | if not depth_first and include_root:
24 | yield name, module
25 | for child_name, child_module in module.named_children():
26 | child_name = '.'.join((name, child_name)) if name else child_name
27 | yield from named_modules(
28 | module=child_module, name=child_name, depth_first=depth_first, include_root=True)
29 | if depth_first and include_root:
30 | yield name, module
31 |
32 |
33 |
34 | __all__ = ["EfficientNetBuilder", "BlockArgs", "decode_arch_def", "efficientnet_init_weights",
35 | 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
36 |
37 | _logger = logging.getLogger(__name__)
38 |
39 |
40 | _DEBUG_BUILDER = False
41 |
42 | # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
43 | # papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
44 | # NOTE: momentum varies btw .99 and .9997 depending on source
45 | # .99 in official TF TPU impl
46 | # .9997 (/w .999 in search space) for paper
47 | BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
48 | BN_EPS_TF_DEFAULT = 1e-3
49 | _BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
50 |
51 | BlockArgs = List[List[Dict[str, Any]]]
52 |
53 |
54 | def get_bn_args_tf():
55 | return _BN_ARGS_TF.copy()
56 |
57 |
58 | def resolve_bn_args(kwargs):
59 | bn_args = {}
60 | bn_momentum = kwargs.pop('bn_momentum', None)
61 | if bn_momentum is not None:
62 | bn_args['momentum'] = bn_momentum
63 | bn_eps = kwargs.pop('bn_eps', None)
64 | if bn_eps is not None:
65 | bn_args['eps'] = bn_eps
66 | return bn_args
67 |
68 |
69 | def resolve_act_layer(kwargs, default='relu'):
70 | return get_act_layer(kwargs.pop('act_layer', default))
71 |
72 |
73 | def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
74 | """Round number of filters based on depth multiplier."""
75 | if not multiplier:
76 | return channels
77 | return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)
78 |
79 |
80 | def _log_info_if(msg, condition):
81 | if condition:
82 | _logger.info(msg)
83 |
84 |
85 | def _parse_ksize(ss):
86 | if ss.isdigit():
87 | return int(ss)
88 | else:
89 | return [int(k) for k in ss.split('.')]
90 |
91 |
92 | def _decode_block_str(block_str):
93 | """ Decode block definition string
94 |
95 | Gets a list of block arg (dicts) through a string notation of arguments.
96 | E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
97 |
98 | All args can exist in any order with the exception of the leading string which
99 | is assumed to indicate the block type.
100 |
101 | leading string - block type (
102 | ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
103 | r - number of repeat blocks,
104 | k - kernel size,
105 | s - strides (1-9),
106 | e - expansion ratio,
107 | c - output channels,
108 | se - squeeze/excitation ratio
109 | n - activation fn ('re', 'r6', 'hs', or 'sw')
110 | Args:
111 | block_str: a string representation of block arguments.
112 | Returns:
113 | A list of block args (dicts)
114 | Raises:
115 | ValueError: if the string def not properly specified (TODO)
116 | """
117 | assert isinstance(block_str, str)
118 | ops = block_str.split('_')
119 | block_type = ops[0] # take the block type off the front
120 | ops = ops[1:]
121 | options = {}
122 | skip = None
123 | for op in ops:
124 | # string options being checked on individual basis, combine if they grow
125 | if op == 'noskip':
126 | skip = False # force no skip connection
127 | elif op == 'skip':
128 | skip = True # force a skip connection
129 | elif op.startswith('n'):
130 | # activation fn
131 | key = op[0]
132 | v = op[1:]
133 | if v == 're':
134 | value = get_act_layer('relu')
135 | elif v == 'r6':
136 | value = get_act_layer('relu6')
137 | elif v == 'hs':
138 | value = get_act_layer('hard_swish')
139 | elif v == 'sw':
140 | value = get_act_layer('swish') # aka SiLU
141 | elif v == 'mi':
142 | value = get_act_layer('mish')
143 | else:
144 | continue
145 | options[key] = value
146 | else:
147 | # all numeric options
148 | splits = re.split(r'(\d.*)', op)
149 | if len(splits) >= 2:
150 | key, value = splits[:2]
151 | options[key] = value
152 |
153 | # if act_layer is None, the model default (passed to model init) will be used
154 | act_layer = options['n'] if 'n' in options else None
155 | start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
156 | end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
157 | force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
158 | num_repeat = int(options['r'])
159 |
160 | # each type of block has different valid arguments, fill accordingly
161 | block_args = dict(
162 | block_type=block_type,
163 | out_chs=int(options['c']),
164 | stride=int(options['s']),
165 | act_layer=act_layer,
166 | )
167 | if block_type == 'ir':
168 | block_args.update(dict(
169 | dw_kernel_size=_parse_ksize(options['k']),
170 | exp_kernel_size=start_kernel_size,
171 | pw_kernel_size=end_kernel_size,
172 | exp_ratio=float(options['e']),
173 | se_ratio=float(options.get('se', 0.)),
174 | noskip=skip is False,
175 | s2d=int(options.get('d', 0)) > 0,
176 | ))
177 | if 'cc' in options:
178 | block_args['num_experts'] = int(options['cc'])
179 | elif block_type == 'ds' or block_type == 'dsa':
180 | block_args.update(dict(
181 | dw_kernel_size=_parse_ksize(options['k']),
182 | pw_kernel_size=end_kernel_size,
183 | se_ratio=float(options.get('se', 0.)),
184 | pw_act=block_type == 'dsa',
185 | noskip=block_type == 'dsa' or skip is False,
186 | s2d=int(options.get('d', 0)) > 0,
187 | ))
188 | elif block_type == 'er':
189 | block_args.update(dict(
190 | exp_kernel_size=_parse_ksize(options['k']),
191 | pw_kernel_size=end_kernel_size,
192 | exp_ratio=float(options['e']),
193 | force_in_chs=force_in_chs,
194 | se_ratio=float(options.get('se', 0.)),
195 | noskip=skip is False,
196 | ))
197 | elif block_type == 'cn':
198 | block_args.update(dict(
199 | kernel_size=int(options['k']),
200 | skip=skip is True,
201 | ))
202 | elif block_type == 'uir':
203 | # override exp / proj kernels for start/end in uir block
204 | start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 0
205 | end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 0
206 | block_args.update(dict(
207 | dw_kernel_size_start=start_kernel_size, # overload exp ks arg for dw start
208 | dw_kernel_size_mid=_parse_ksize(options['k']),
209 | dw_kernel_size_end=end_kernel_size, # overload pw ks arg for dw end
210 | exp_ratio=float(options['e']),
211 | se_ratio=float(options.get('se', 0.)),
212 | noskip=skip is False,
213 | ))
214 | elif block_type == 'mha':
215 | kv_dim = int(options['d'])
216 | block_args.update(dict(
217 | dw_kernel_size=_parse_ksize(options['k']),
218 | num_heads=int(options['h']),
219 | key_dim=kv_dim,
220 | value_dim=kv_dim,
221 | kv_stride=int(options.get('v', 1)),
222 | noskip=skip is False,
223 | ))
224 | elif block_type == 'mqa':
225 | kv_dim = int(options['d'])
226 | block_args.update(dict(
227 | dw_kernel_size=_parse_ksize(options['k']),
228 | num_heads=int(options['h']),
229 | key_dim=kv_dim,
230 | value_dim=kv_dim,
231 | kv_stride=int(options.get('v', 1)),
232 | noskip=skip is False,
233 | ))
234 | else:
235 | assert False, 'Unknown block type (%s)' % block_type
236 |
237 | if 'gs' in options:
238 | block_args['group_size'] = int(options['gs'])
239 |
240 | return block_args, num_repeat
241 |
242 |
243 | def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
244 | """ Per-stage depth scaling
245 | Scales the block repeats in each stage. This depth scaling impl maintains
246 | compatibility with the EfficientNet scaling method, while allowing sensible
247 | scaling for other models that may have multiple block arg definitions in each stage.
248 | """
249 |
250 | # We scale the total repeat count for each stage, there may be multiple
251 | # block arg defs per stage so we need to sum.
252 | num_repeat = sum(repeats)
253 | if depth_trunc == 'round':
254 | # Truncating to int by rounding allows stages with few repeats to remain
255 | # proportionally smaller for longer. This is a good choice when stage definitions
256 | # include single repeat stages that we'd prefer to keep that way as long as possible
257 | num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
258 | else:
259 | # The default for EfficientNet truncates repeats to int via 'ceil'.
260 | # Any multiplier > 1.0 will result in an increased depth for every stage.
261 | num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
262 |
263 | # Proportionally distribute repeat count scaling to each block definition in the stage.
264 | # Allocation is done in reverse as it results in the first block being less likely to be scaled.
265 | # The first block makes less sense to repeat in most of the arch definitions.
266 | repeats_scaled = []
267 | for r in repeats[::-1]:
268 | rs = max(1, round((r / num_repeat * num_repeat_scaled)))
269 | repeats_scaled.append(rs)
270 | num_repeat -= r
271 | num_repeat_scaled -= rs
272 | repeats_scaled = repeats_scaled[::-1]
273 |
274 | # Apply the calculated scaling to each block arg in the stage
275 | sa_scaled = []
276 | for ba, rep in zip(stack_args, repeats_scaled):
277 | sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
278 | return sa_scaled
279 |
280 |
281 | def decode_arch_def(
282 | arch_def,
283 | depth_multiplier=1.0,
284 | depth_trunc='ceil',
285 | experts_multiplier=1,
286 | fix_first_last=False,
287 | group_size=None,
288 | ):
289 | """ Decode block architecture definition strings -> block kwargs
290 |
291 | Args:
292 | arch_def: architecture definition strings, list of list of strings
293 | depth_multiplier: network depth multiplier
294 | depth_trunc: networ depth truncation mode when applying multiplier
295 | experts_multiplier: CondConv experts multiplier
296 | fix_first_last: fix first and last block depths when multiplier is applied
297 | group_size: group size override for all blocks that weren't explicitly set in arch string
298 |
299 | Returns:
300 | list of list of block kwargs
301 | """
302 | arch_args = []
303 | if isinstance(depth_multiplier, tuple):
304 | assert len(depth_multiplier) == len(arch_def)
305 | else:
306 | depth_multiplier = (depth_multiplier,) * len(arch_def)
307 | for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)):
308 | assert isinstance(block_strings, list)
309 | stack_args = []
310 | repeats = []
311 | for block_str in block_strings:
312 | assert isinstance(block_str, str)
313 | ba, rep = _decode_block_str(block_str)
314 | if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
315 | ba['num_experts'] *= experts_multiplier
316 | if group_size is not None:
317 | ba.setdefault('group_size', group_size)
318 | stack_args.append(ba)
319 | repeats.append(rep)
320 | if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
321 | arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
322 | else:
323 | arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc))
324 | return arch_args
325 |
326 |
327 | class EfficientNetBuilder:
328 | """ Build Trunk Blocks
329 |
330 | This ended up being somewhat of a cross between
331 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
332 | and
333 | https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
334 |
335 | """
336 | def __init__(
337 | self,
338 | output_stride: int = 32,
339 | pad_type: str = '',
340 | round_chs_fn: Callable = round_channels,
341 | se_from_exp: bool = False,
342 | act_layer: Optional[LayerType] = None,
343 | norm_layer: Optional[LayerType] = None,
344 | aa_layer: Optional[LayerType] = None,
345 | se_layer: Optional[LayerType] = None,
346 | drop_path_rate: float = 0.,
347 | layer_scale_init_value: Optional[float] = None,
348 | feature_location: str = '',
349 | ):
350 | self.output_stride = output_stride
351 | self.pad_type = pad_type
352 | self.round_chs_fn = round_chs_fn
353 | self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs
354 | self.act_layer = act_layer
355 | self.norm_layer = norm_layer
356 | self.aa_layer = aa_layer
357 | self.se_layer = get_attn(se_layer)
358 | try:
359 | self.se_layer(8, rd_ratio=1.0) # test if attn layer accepts rd_ratio arg
360 | self.se_has_ratio = True
361 | except TypeError:
362 | self.se_has_ratio = False
363 | self.drop_path_rate = drop_path_rate
364 | self.layer_scale_init_value = layer_scale_init_value
365 | if feature_location == 'depthwise':
366 | # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
367 | _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
368 | feature_location = 'expansion'
369 | self.feature_location = feature_location
370 | assert feature_location in ('bottleneck', 'expansion', '')
371 | self.verbose = _DEBUG_BUILDER
372 |
373 | # state updated during build, consumed by model
374 | self.in_chs = None
375 | self.features = []
376 |
377 | def _make_block(self, ba, block_idx, block_count):
378 | drop_path_rate = self.drop_path_rate * block_idx / block_count
379 | bt = ba.pop('block_type')
380 | ba['in_chs'] = self.in_chs
381 | ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
382 | s2d = ba.get('s2d', 0)
383 | if s2d > 0:
384 | # adjust while space2depth active
385 | ba['out_chs'] *= 4
386 | if 'force_in_chs' in ba and ba['force_in_chs']:
387 | # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
388 | ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
389 | ba['pad_type'] = self.pad_type
390 | # block act fn overrides the model default
391 | ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
392 | assert ba['act_layer'] is not None
393 | ba['norm_layer'] = self.norm_layer
394 | ba['drop_path_rate'] = drop_path_rate
395 |
396 | if self.aa_layer is not None:
397 | ba['aa_layer'] = self.aa_layer
398 |
399 | se_ratio = ba.pop('se_ratio', None)
400 | if se_ratio and self.se_layer is not None:
401 | if not self.se_from_exp:
402 | # adjust se_ratio by expansion ratio if calculating se channels from block input
403 | se_ratio /= ba.get('exp_ratio', 1.0)
404 | if s2d == 1:
405 | # adjust for start of space2depth
406 | se_ratio /= 4
407 | if self.se_has_ratio:
408 | ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
409 | else:
410 | ba['se_layer'] = self.se_layer
411 |
412 | if bt == 'ir':
413 | _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
414 | block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
415 | elif bt == 'ds' or bt == 'dsa':
416 | _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
417 | block = DepthwiseSeparableConv(**ba)
418 | elif bt == 'er':
419 | _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
420 | block = EdgeResidual(**ba)
421 | elif bt == 'cn':
422 | _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
423 | block = ConvBnAct(**ba)
424 | elif bt == 'uir':
425 | _log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
426 | block = UniversalInvertedResidual(**ba, layer_scale_init_value=self.layer_scale_init_value)
427 | elif bt == 'mqa':
428 | _log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
429 | block = MobileAttention(**ba, use_multi_query=True, layer_scale_init_value=self.layer_scale_init_value)
430 | elif bt == 'mha':
431 | _log_info_if(' MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
432 | block = MobileAttention(**ba, layer_scale_init_value=self.layer_scale_init_value)
433 | else:
434 | assert False, 'Unknown block type (%s) while building model.' % bt
435 |
436 | self.in_chs = ba['out_chs'] # update in_chs for arg of next block
437 | return block
438 |
439 | def __call__(self, in_chs, model_block_args):
440 | """ Build the blocks
441 | Args:
442 | in_chs: Number of input-channels passed to first block
443 | model_block_args: A list of lists, outer list defines stages, inner
444 | list contains strings defining block configuration(s)
445 | Return:
446 | List of block stacks (each stack wrapped in nn.Sequential)
447 | """
448 | _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
449 | self.in_chs = in_chs
450 | total_block_count = sum([len(x) for x in model_block_args])
451 | total_block_idx = 0
452 | current_stride = 2
453 | current_dilation = 1
454 | stages = []
455 | if model_block_args[0][0]['stride'] > 1:
456 | # if the first block starts with a stride, we need to extract first level feat from stem
457 | feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride)
458 | self.features.append(feature_info)
459 |
460 | # outer list of block_args defines the stacks
461 | space2depth = 0
462 | for stack_idx, stack_args in enumerate(model_block_args):
463 | last_stack = stack_idx + 1 == len(model_block_args)
464 | _log_info_if('Stack: {}'.format(stack_idx), self.verbose)
465 | assert isinstance(stack_args, list)
466 |
467 | blocks = []
468 | # each stack (stage of blocks) contains a list of block arguments
469 | for block_idx, block_args in enumerate(stack_args):
470 | last_block = block_idx + 1 == len(stack_args)
471 | _log_info_if(' Block: {}'.format(block_idx), self.verbose)
472 |
473 | assert block_args['stride'] in (1, 2)
474 | if block_idx >= 1: # only the first block in any stack can have a stride > 1
475 | block_args['stride'] = 1
476 |
477 | if not space2depth and block_args.pop('s2d', False):
478 | assert block_args['stride'] == 1
479 | space2depth = 1
480 |
481 | if space2depth > 0:
482 | # FIXME s2d is a WIP
483 | if space2depth == 2 and block_args['stride'] == 2:
484 | block_args['stride'] = 1
485 | # to end s2d region, need to correct expansion and se ratio relative to input
486 | block_args['exp_ratio'] /= 4
487 | space2depth = 0
488 | else:
489 | block_args['s2d'] = space2depth
490 |
491 | extract_features = False
492 | if last_block:
493 | next_stack_idx = stack_idx + 1
494 | extract_features = next_stack_idx >= len(model_block_args) or \
495 | model_block_args[next_stack_idx][0]['stride'] > 1
496 |
497 | next_dilation = current_dilation
498 | if block_args['stride'] > 1:
499 | next_output_stride = current_stride * block_args['stride']
500 | if next_output_stride > self.output_stride:
501 | next_dilation = current_dilation * block_args['stride']
502 | block_args['stride'] = 1
503 | _log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format(
504 | self.output_stride), self.verbose)
505 | else:
506 | current_stride = next_output_stride
507 | block_args['dilation'] = current_dilation
508 | if next_dilation != current_dilation:
509 | current_dilation = next_dilation
510 |
511 | # create the block
512 | block = self._make_block(block_args, total_block_idx, total_block_count)
513 | blocks.append(block)
514 |
515 | if space2depth == 1:
516 | space2depth = 2
517 |
518 | # stash feature module name and channel info for model feature extraction
519 | if extract_features:
520 | feature_info = dict(
521 | stage=stack_idx + 1,
522 | reduction=current_stride,
523 | **block.feature_info(self.feature_location),
524 | )
525 | leaf_name = feature_info.get('module', '')
526 | if leaf_name:
527 | feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name])
528 | else:
529 | assert last_block
530 | feature_info['module'] = f'blocks.{stack_idx}'
531 | self.features.append(feature_info)
532 |
533 | total_block_idx += 1 # incr global block idx (across all stacks)
534 | stages.append(nn.Sequential(*blocks))
535 | return stages
536 |
537 |
538 | def _init_weight_goog(m, n='', fix_group_fanout=True):
539 | """ Weight initialization as per Tensorflow official implementations.
540 |
541 | Args:
542 | m (nn.Module): module to init
543 | n (str): module name
544 | fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
545 |
546 | Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
547 | * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
548 | * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
549 | """
550 | if isinstance(m, CondConv2d):
551 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
552 | if fix_group_fanout:
553 | fan_out //= m.groups
554 | init_weight_fn = get_condconv_initializer(
555 | lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
556 | init_weight_fn(m.weight)
557 | if m.bias is not None:
558 | nn.init.zeros_(m.bias)
559 | elif isinstance(m, nn.Conv2d):
560 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
561 | if fix_group_fanout:
562 | fan_out //= m.groups
563 | nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
564 | if m.bias is not None:
565 | nn.init.zeros_(m.bias)
566 | elif isinstance(m, nn.BatchNorm2d):
567 | nn.init.ones_(m.weight)
568 | nn.init.zeros_(m.bias)
569 | elif isinstance(m, nn.Linear):
570 | fan_out = m.weight.size(0) # fan-out
571 | fan_in = 0
572 | if 'routing_fn' in n:
573 | fan_in = m.weight.size(1)
574 | init_range = 1.0 / math.sqrt(fan_in + fan_out)
575 | nn.init.uniform_(m.weight, -init_range, init_range)
576 | nn.init.zeros_(m.bias)
577 |
578 |
579 | def efficientnet_init_weights(model: nn.Module, init_fn=None):
580 | init_fn = init_fn or _init_weight_goog
581 | for n, m in model.named_modules():
582 | init_fn(m, n)
583 |
584 | # iterate and call any module.init_weights() fn, children first
585 | for n, m in named_modules(model):
586 | if hasattr(m, 'init_weights'):
587 | m.init_weights()
--------------------------------------------------------------------------------
/onnx_export.py:
--------------------------------------------------------------------------------
1 | """
2 | ONNX export script
3 | Export PyTorch models as ONNX graphs.
4 | This export script originally started as an adaptation of code snippets found at
5 | https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
6 |
7 | The default parameters work with PyTorch 2.0.1 and ONNX 1.13 and produce an optimal ONNX graph
8 | for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible
9 | """
10 |
11 | import argparse
12 | import torch
13 | import numpy as np
14 | import onnx
15 | import models
16 | from copy import deepcopy
17 | from timm.models import create_model
18 | from typing import Optional, Tuple, List
19 |
20 |
21 |
22 | ## python onnx_export.py --model mobilenetv4_small ./mobilenetv4_small.onnx
23 |
24 | parser = argparse.ArgumentParser(description='PyTorch ONNX Deployment')
25 | parser.add_argument('--output', metavar='ONNX_FILE', default=None, type=str,
26 | help='output model filename')
27 |
28 | # Model & datasets params
29 | parser.add_argument('--model', default='mobilenetv4_conv_large', type=str, metavar='MODEL',
30 | choices=['mobilenetv4_hybrid_large', 'mobilenetv4_hybrid_medium', 'mobilenetv4_hybrid_large_075',
31 | 'mobilenetv4_conv_large', 'mobilenetv4_conv_aa_large', 'mobilenetv4_conv_medium',
32 | 'mobilenetv4_conv_aa_medium', 'mobilenetv4_conv_small', 'mobilenetv4_hybrid_medium_075',
33 | 'mobilenetv4_conv_small_035', 'mobilenetv4_conv_small_050', 'mobilenetv4_conv_blur_medium'],
34 | help='Name of model to train')
35 | parser.add_argument('--extra_attention_block', default=False, type=bool, help='Add an extra attention block')
36 | parser.add_argument('--checkpoint', default='./output/mobilenetv4_conv_large_best_checkpoint.pth', type=str, metavar='PATH',
37 | help='path to checkpoint (default: none)')
38 | parser.add_argument('--batch-size', default=1, type=int,
39 | metavar='N', help='mini-batch size (default: 1)')
40 | parser.add_argument('--img-size', default=384, type=int,
41 | metavar='N', help='Input image dimension, uses model default if empty')
42 | parser.add_argument('--nb-classes', type=int, default=5,
43 | help='Number classes in datasets')
44 |
45 | parser.add_argument('--opset', type=int, default=10,
46 | help='ONNX opset to use (default: 10)')
47 | parser.add_argument('--keep-init', action='store_true', default=False,
48 | help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.')
49 | parser.add_argument('--aten-fallback', action='store_true', default=False,
50 | help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.')
51 | parser.add_argument('--dynamic-size', action='store_true', default=False,
52 | help='Export model width dynamic width/height. Not recommended for "tf" models with SAME padding.')
53 | parser.add_argument('--check-forward', action='store_true', default=False,
54 | help='Do a full check of torch vs onnx forward after export.')
55 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
56 | help='Override mean pixel value of datasets')
57 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
58 | help='Override std deviation of of datasets')
59 | parser.add_argument('--reparam', default=False, action='store_true',
60 | help='Reparameterize model')
61 | parser.add_argument('--training', default=False, action='store_true',
62 | help='Export in training mode (default is eval)')
63 | parser.add_argument('--verbose', default=False, action='store_true',
64 | help='Extra stdout output')
65 | parser.add_argument('--dynamo', default=False, action='store_true',
66 | help='Use torch dynamo export.')
67 |
68 |
69 |
70 | def reparameterize_model(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
71 | if not inplace:
72 | model = deepcopy(model)
73 |
74 | def _fuse(m):
75 | for child_name, child in m.named_children():
76 | if hasattr(child, 'fuse'):
77 | setattr(m, child_name, child.fuse())
78 | elif hasattr(child, "reparameterize"):
79 | child.reparameterize()
80 | elif hasattr(child, "switch_to_deploy"):
81 | child.switch_to_deploy()
82 | _fuse(child)
83 |
84 | _fuse(model)
85 | return model
86 |
87 |
88 | def onnx_forward(onnx_file, example_input):
89 | import onnxruntime
90 |
91 | sess_options = onnxruntime.SessionOptions()
92 | session = onnxruntime.InferenceSession(onnx_file, sess_options)
93 | input_name = session.get_inputs()[0].name
94 | output = session.run([], {input_name: example_input.numpy()})
95 | output = output[0]
96 | return output
97 |
98 |
99 | def onnx_export(
100 | model: torch.nn.Module,
101 | output_file: str,
102 | example_input: Optional[torch.Tensor] = None,
103 | training: bool = False,
104 | verbose: bool = False,
105 | check: bool = True,
106 | check_forward: bool = False,
107 | batch_size: int = 64,
108 | input_size: Tuple[int, int, int] = None,
109 | opset: Optional[int] = None,
110 | dynamic_size: bool = False,
111 | aten_fallback: bool = False,
112 | keep_initializers: Optional[bool] = None,
113 | use_dynamo: bool = False,
114 | input_names: List[str] = None,
115 | output_names: List[str] = None,
116 | ):
117 | import onnx
118 |
119 | if training:
120 | training_mode = torch.onnx.TrainingMode.TRAINING
121 | model.train()
122 | else:
123 | training_mode = torch.onnx.TrainingMode.EVAL
124 | model.eval()
125 |
126 | if example_input is None:
127 | if not input_size:
128 | assert hasattr(model, 'default_cfg')
129 | input_size = model.default_cfg.get('input_size')
130 | example_input = torch.randn((batch_size,) + input_size, requires_grad=training)
131 |
132 | # Run model once before export trace, sets padding for models with Conv2dSameExport. This means
133 | # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for
134 | # the input img_size specified in this script.
135 |
136 | # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
137 | # issues in the tracing of the dynamic padding or errors attempting to export the model after jit
138 | # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
139 | with torch.no_grad():
140 | original_out = model(example_input)
141 |
142 | print("==> Exporting model to ONNX format at '{}'".format(output_file))
143 |
144 | input_names = input_names or ["input0"]
145 | output_names = output_names or ["output0"]
146 |
147 | dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}}
148 | if dynamic_size:
149 | dynamic_axes['input0'][2] = 'height'
150 | dynamic_axes['input0'][3] = 'width'
151 |
152 | if aten_fallback:
153 | export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
154 | else:
155 | export_type = torch.onnx.OperatorExportTypes.ONNX
156 |
157 | if use_dynamo:
158 | export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_size)
159 | export_output = torch.onnx.dynamo_export(
160 | model,
161 | example_input,
162 | export_options=export_options,
163 | )
164 | export_output.save(output_file)
165 | torch_out = None
166 | else:
167 | #TODO, for torch version >= 2.5, use torch.onnx.export()
168 | torch_out = torch.onnx._export(
169 | model,
170 | example_input,
171 | output_file,
172 | training=training_mode,
173 | export_params=True,
174 | verbose=verbose,
175 | input_names=input_names,
176 | output_names=output_names,
177 | keep_initializers_as_inputs=keep_initializers,
178 | dynamic_axes=dynamic_axes,
179 | opset_version=opset,
180 | operator_export_type=export_type
181 | )
182 |
183 | if check:
184 | print("==> Loading and checking exported model from '{}'".format(output_file))
185 | onnx_model = onnx.load(output_file)
186 | onnx.checker.check_model(onnx_model, full_check=True) # assuming throw on error
187 | if check_forward and not training:
188 | import numpy as np
189 | onnx_out = onnx_forward(output_file, example_input)
190 | if torch_out is not None:
191 | np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3)
192 | np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5)
193 | else:
194 | np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)
195 |
196 |
197 | def main():
198 | args = parser.parse_args()
199 |
200 | if args.output == None:
201 | args.output = f'./{args.model}.onnx'
202 |
203 | print("==> Creating PyTorch {} model".format(args.model))
204 |
205 |
206 | model = create_model(
207 | args.model,
208 | num_classes=args.nb_classes,
209 | extra_attention_block=args.extra_attention_block,
210 | exportable=True
211 | )
212 |
213 | model.load_state_dict(torch.load(args.checkpoint)['model'])
214 | model.eval()
215 |
216 | if args.reparam:
217 | model = reparameterize_model(model)
218 |
219 | onnx_export(
220 | model=model,
221 | output_file=args.output,
222 | opset=args.opset,
223 | dynamic_size=args.dynamic_size,
224 | aten_fallback=args.aten_fallback,
225 | keep_initializers=args.keep_init,
226 | check_forward=args.check_forward,
227 | training=args.training,
228 | verbose=args.verbose,
229 | use_dynamo=args.dynamo,
230 | input_size=(3, args.img_size, args.img_size),
231 | batch_size=args.batch_size,
232 | )
233 |
234 | print("==> Passed")
235 |
236 |
237 | if __name__ == '__main__':
238 | main()
239 |
--------------------------------------------------------------------------------
/onnx_optimise.py:
--------------------------------------------------------------------------------
1 | """ ONNX optimization script
2 |
3 | Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc.
4 |
5 | NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 2.0.1 and ONNX 1.13),
6 | it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline).
7 |
8 | Copyright 2020 Ross Wightman
9 | """
10 | import argparse
11 | import warnings
12 |
13 | import onnx
14 | import onnxoptimizer as optimizer
15 |
16 |
17 | parser = argparse.ArgumentParser(description="Optimize ONNX model")
18 |
19 | parser.add_argument('--model', default='mobilenetv4_conv_large', type=str, metavar='MODEL',
20 | choices=['mobilenetv4_hybrid_large', 'mobilenetv4_hybrid_medium', 'mobilenetv4_hybrid_large_075',
21 | 'mobilenetv4_conv_large', 'mobilenetv4_conv_aa_large', 'mobilenetv4_conv_medium',
22 | 'mobilenetv4_conv_aa_medium', 'mobilenetv4_conv_small', 'mobilenetv4_hybrid_medium_075',
23 | 'mobilenetv4_conv_small_035', 'mobilenetv4_conv_small_050', 'mobilenetv4_conv_blur_medium'],
24 | help='Name of model to train')
25 | parser.add_argument("--output", default=None, help="The optimized model output filename")
26 |
27 |
28 | def traverse_graph(graph, prefix=''):
29 | content = []
30 | indent = prefix + ' '
31 | graphs = []
32 | num_nodes = 0
33 | for node in graph.node:
34 | pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True)
35 | assert isinstance(gs, list)
36 | content.append(pn)
37 | graphs.extend(gs)
38 | num_nodes += 1
39 | for g in graphs:
40 | g_count, g_str = traverse_graph(g)
41 | content.append('\n' + g_str)
42 | num_nodes += g_count
43 | return num_nodes, '\n'.join(content)
44 |
45 |
46 | def main():
47 | args = parser.parse_args()
48 |
49 | if args.output == None:
50 | args.output = f'./{args.model}_optim.onnx'
51 |
52 | args.model = f'./{args.model}.onnx'
53 |
54 | onnx_model = onnx.load(args.model)
55 | num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph)
56 |
57 | # Optimizer passes to perform
58 | passes = [
59 | #'eliminate_deadend',
60 | 'eliminate_identity',
61 | 'eliminate_nop_dropout',
62 | 'eliminate_nop_pad',
63 | 'eliminate_nop_transpose',
64 | 'eliminate_unused_initializer',
65 | 'extract_constant_to_initializer',
66 | 'fuse_add_bias_into_conv',
67 | 'fuse_bn_into_conv',
68 | 'fuse_consecutive_concats',
69 | 'fuse_consecutive_reduce_unsqueeze',
70 | 'fuse_consecutive_squeezes',
71 | 'fuse_consecutive_transposes',
72 | #'fuse_matmul_add_bias_into_gemm',
73 | 'fuse_pad_into_conv',
74 | #'fuse_transpose_into_gemm',
75 | #'lift_lexical_references',
76 | ]
77 |
78 | # Apply the optimization on the original serialized model
79 | # WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing
80 | # 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401
81 | # It may be better to rely on onnxruntime optimizations, see onnx_validate.py script.
82 | warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX."
83 | "Try onnxruntime optimization if this doesn't work.")
84 | optimized_model = optimizer.optimize(onnx_model, passes)
85 |
86 | num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph)
87 | print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str))
88 | print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes))
89 |
90 | # Save the ONNX model
91 | onnx.save(optimized_model, args.output)
92 |
93 |
94 | if __name__ == "__main__":
95 | main()
--------------------------------------------------------------------------------
/onnx_validate.py:
--------------------------------------------------------------------------------
1 | """ ONNX-runtime validation script
2 |
3 | This script was created to verify accuracy and performance of exported ONNX
4 | models running with the onnxruntime. It utilizes the PyTorch dataloader/processing
5 | pipeline for a fair comparison against the originals.
6 |
7 | Copyright 2020 Ross Wightman
8 | """
9 | import argparse
10 | import numpy as np
11 | import torch
12 | import onnxruntime
13 | from util.utils import AverageMeter
14 | import time
15 | from datasets import MyDataset, build_transform, read_split_data
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Pytorch ONNX Validation')
19 | parser.add_argument('--data_root', default='D:/flower_data', type=str,
20 | help='path to datasets')
21 | parser.add_argument('--onnx-input', default='./mobilenetv4_conv_large_optim.onnx', type=str, metavar='PATH',
22 | help='path to onnx model/weights file')
23 | parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH',
24 | help='path to output optimized onnx graph')
25 | parser.add_argument('--profile', action='store_true', default=False,
26 | help='Enable profiler output.')
27 | parser.add_argument('--workers', default=2, type=int, metavar='N',
28 | help='number of data loading workers (default: 2)')
29 | parser.add_argument('--batch-size', default=16, type=int,
30 | metavar='N', help='mini-batch size (default: 16), as same as the train_batch_size in train_gpu.py')
31 | parser.add_argument('--img-size', default=384, type=int,
32 | metavar='N', help='Input image dimension, uses model default if empty')
33 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
34 | help='Override mean pixel value of datasets')
35 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
36 | help='Override std deviation of of datasets')
37 | parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
38 | help='Override default crop pct of 0.875')
39 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
40 | help='Image resize interpolation type (overrides model)')
41 | parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
42 | help='use tensorflow mnasnet preporcessing')
43 | parser.add_argument('--print-freq', '-p', default=10, type=int,
44 | metavar='N', help='print frequency (default: 10)')
45 |
46 |
47 | def main():
48 | args = parser.parse_args()
49 | args.gpu_id = 0
50 |
51 | args.input_size = args.img_size
52 |
53 | # Set graph optimization level
54 | sess_options = onnxruntime.SessionOptions()
55 | sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
56 | if args.profile:
57 | sess_options.enable_profiling = True
58 | if args.onnx_output_opt:
59 | sess_options.optimized_model_filepath = args.onnx_output_opt
60 |
61 | session = onnxruntime.InferenceSession(args.onnx_input, sess_options)
62 |
63 | # data_config = resolve_data_config(None, args)
64 | val_set = build_dataset(args)
65 |
66 | loader = torch.utils.data.DataLoader(
67 | val_set,
68 | batch_size=args.batch_size,
69 | num_workers=args.workers,
70 | drop_last=False
71 | )
72 |
73 | input_name = session.get_inputs()[0].name
74 |
75 | batch_time = AverageMeter()
76 | top1 = AverageMeter()
77 | top5 = AverageMeter()
78 | end = time.time()
79 | for i, (input, target) in enumerate(loader):
80 | # run the net and return prediction
81 | output = session.run([], {input_name: input.data.numpy()})
82 | output = output[0]
83 |
84 | # measure accuracy and record loss
85 | prec1, prec5 = accuracy_np(output, target.numpy())
86 | top1.update(prec1.item(), input.size(0))
87 | top5.update(prec5.item(), input.size(0))
88 |
89 | # measure elapsed time
90 | batch_time.update(time.time() - end)
91 | end = time.time()
92 |
93 | if i % args.print_freq == 0:
94 | print('Test: [{0}/{1}]\t'
95 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
96 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
97 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
98 | i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
99 | ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5))
100 |
101 | print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
102 | top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
103 |
104 |
105 | def accuracy_np(output, target):
106 | max_indices = np.argsort(output, axis=1)[:, ::-1]
107 | top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
108 | top1 = 100 * np.equal(max_indices[:, 0], target).mean()
109 | return top1, top5
110 |
111 |
112 | def build_dataset(args):
113 | train_image_path, train_image_label, val_image_path, val_image_label, class_indices = read_split_data(args.data_root)
114 |
115 | valid_transform = build_transform(False, args)
116 |
117 | valid_set = MyDataset(val_image_path, val_image_label, valid_transform)
118 | return valid_set
119 |
120 |
121 | if __name__ == '__main__':
122 | main()
--------------------------------------------------------------------------------
/optim_AUC.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from functools import partial
3 | from scipy.optimize import fmin
4 | from sklearn import metrics
5 |
6 |
7 | def max_voting(preds):
8 | """
9 | Create mean predictions
10 | :param probas: 2-d array of prediction values
11 | :return: max voted predictions
12 | """
13 |
14 | '''
15 | preds: np.array([[0, 2, 2, 2], [1, 1, 0, 1]])
16 | return : [[2]
17 | [1]]
18 | '''
19 | idxs = np.argmax(preds, axis=1)
20 | return np.take_along_axis(preds, idxs[:, None], axis=1)
21 |
22 |
23 | class OptimizeAUC:
24 | def __init__(self):
25 | self.coef_ = 0.
26 |
27 | def _auc(self, coef, outputs, labels):
28 | """
29 | This functions calulates and returns AUC.
30 | :param coef: coef list, of the same length as number of models
31 | :param X: predictions, in this case a 2d array
32 | :param y: targets, in our case binary 1d array
33 | """
34 | # multiply coefficients with every column of the array
35 | # with predictions.
36 | # this means: element 1 of coef is multiplied by column 1
37 | # of the prediction array, element 2 of coef is multiplied
38 | # by column 2 of the prediction array and so on!
39 |
40 | x_coef = coef * outputs
41 |
42 | # create predictions by taking row wise sum
43 | predictions = x_coef / np.sum(x_coef, axis=1, keepdims=True)
44 |
45 | # calculate auc score
46 | auc_score = metrics.roc_auc_score(labels, predictions, average='weighted', multi_class='ovo')
47 |
48 | # return negative auc
49 | return -1.0 * auc_score
50 |
51 |
52 | def fit(self, X, y):
53 | # remember partial from hyperparameter optimization chapter?
54 | loss_partial = partial(self._auc, outputs=X, labels=y)
55 |
56 | # dirichlet distribution. you can use any distribution you want
57 | # to initialize the coefficients
58 | # we want the coefficients to sum to 1
59 | initial_coef = np.random.dirichlet(np.ones(X.shape[1]), size=1)
60 | # use scipy fmin to minimize the loss function, in our case auc
61 | self.coef_ = fmin(loss_partial, initial_coef, disp=True)
62 |
63 | def predict(self, X):
64 | # this is similar to _auc function
65 | x_coef = X * self.coef_
66 | predictions = x_coef / np.sum(x_coef, axis=1, keepdims=True)
67 | return predictions
--------------------------------------------------------------------------------
/prediction_probs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiaowoguanren0615/MobileNetV4/ce474fdb8f500fc6dd477e2538e31539ff1be2f9/prediction_probs.png
--------------------------------------------------------------------------------
/sample_png/mobilenetV4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiaowoguanren0615/MobileNetV4/ce474fdb8f500fc6dd477e2538e31539ff1be2f9/sample_png/mobilenetV4.jpg
--------------------------------------------------------------------------------
/train_gpu.py:
--------------------------------------------------------------------------------
1 | """ ImageNet Training Script
2 |
3 | This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
4 | training results with some of the latest networks and training techniques. It favours canonical PyTorch
5 | and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
6 | and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
7 |
8 | This script was started from an early version of the PyTorch ImageNet example
9 | (https://github.com/pytorch/examples/tree/master/imagenet)
10 |
11 | NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
12 | (https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
13 |
14 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
15 | """
16 | import argparse
17 | import datetime
18 | import numpy as np
19 | import time
20 | import torch
21 | import torch.backends.cudnn as cudnn
22 | from torch.utils.tensorboard import SummaryWriter
23 | import json
24 | import os
25 |
26 |
27 | from pathlib import Path
28 |
29 | import timm
30 | from timm.data import Mixup
31 | from timm.models import create_model
32 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
33 | from timm.scheduler import create_scheduler
34 | from timm.optim import create_optimizer
35 | from timm.utils import NativeScaler, get_state_dict, ModelEma
36 |
37 | from models import *
38 | from safetensors.torch import load_file
39 |
40 | from util.samplers import RASampler
41 | from util import utils as utils
42 | from util.optimizer import SophiaG, MARS
43 | from util.engine import train_one_epoch, evaluate
44 | from util.losses import DistillationLoss
45 |
46 | from datasets import build_dataset
47 | from datasets.threeaugment import new_data_aug_generator
48 |
49 | from estimate_model import Predictor, Plot_ROC, OptAUC
50 |
51 |
52 | def get_args_parser():
53 | parser = argparse.ArgumentParser(
54 | 'MobileNetV4 training and evaluation script', add_help=False)
55 | parser.add_argument('--batch-size', default=16, type=int)
56 | parser.add_argument('--epochs', default=5, type=int)
57 | parser.add_argument('--predict', default=True, type=bool, help='plot ROC curve and confusion matrix')
58 | parser.add_argument('--opt_auc', default=False, type=bool, help='Optimize AUC')
59 |
60 | # Model parameters
61 | parser.add_argument('--model', default='mobilenetv4_conv_large', type=str, metavar='MODEL',
62 | choices=['mobilenetv4_hybrid_large', 'mobilenetv4_hybrid_medium', 'mobilenetv4_hybrid_large_075',
63 | 'mobilenetv4_conv_large', 'mobilenetv4_conv_aa_large', 'mobilenetv4_conv_medium',
64 | 'mobilenetv4_conv_aa_medium', 'mobilenetv4_conv_small', 'mobilenetv4_hybrid_medium_075',
65 | 'mobilenetv4_conv_small_035', 'mobilenetv4_conv_small_050', 'mobilenetv4_conv_blur_medium'],
66 | help='Name of model to train')
67 | parser.add_argument('--extra_attention_block', default=False, type=bool, help='Add an extra attention block')
68 | parser.add_argument('--input-size', default=384, type=int, help='images input size')
69 | parser.add_argument('--model-ema', action='store_true')
70 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
71 | parser.set_defaults(model_ema=True)
72 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
73 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
74 |
75 | # Optimizer parameters
76 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
77 | help='Optimizer (default: "adamw"')
78 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
79 | help='Optimizer Epsilon (default: 1e-8)')
80 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
81 | help='Optimizer Betas (default: None, use opt default)')
82 | parser.add_argument('--clip-grad', type=float, default=0.02, metavar='NORM',
83 | help='Clip gradient norm (default: None, no clipping)')
84 | parser.add_argument('--clip-mode', type=str, default='agc',
85 | help='Gradient clipping mode. One of ("norm", "value", "agc")')
86 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
87 | help='SGD momentum (default: 0.9)')
88 | parser.add_argument('--weight-decay', type=float, default=0.025,
89 | help='weight decay (default: 0.025)')
90 |
91 | # Learning rate schedule parameters
92 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
93 | help='LR scheduler (default: "cosine"')
94 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
95 | help='learning rate (default: 1e-3)')
96 | parser.add_argument('--adamw_lr', type=float, default=3e-3, metavar='AdamWLR',
97 | help='Using MARS optimizer, learning rate for adamw(default: 3e-3)')
98 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
99 | help='learning rate noise on/off epoch percentages')
100 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
101 | help='learning rate noise limit percent (default: 0.67)')
102 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
103 | help='learning rate noise std-dev (default: 1.0)')
104 | parser.add_argument('--warmup-lr', type=float, default=1e-4, metavar='LR',
105 | help='warmup learning rate (default: 1e-4)')
106 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
107 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
108 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
109 | help='epoch interval to decay LR')
110 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
111 | help='epochs to warmup LR, if scheduler supports')
112 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
113 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
114 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
115 | help='patience epochs for Plateau LR scheduler (default: 10')
116 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
117 | help='LR decay rate (default: 0.1)')
118 |
119 | # Augmentation parameters
120 | parser.add_argument('--ThreeAugment', action='store_true')
121 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
122 | help='Color jitter factor (default: 0.4)')
123 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
124 | help='Use AutoAugment policy. "v0" or "original". " + \
125 | "(default: rand-m9-mstd0.5-inc1)'),
126 | parser.add_argument('--smoothing', type=float, default=0.1,
127 | help='Label smoothing (default: 0.1)')
128 | parser.add_argument('--train-interpolation', type=str, default='bicubic',
129 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
130 | parser.add_argument('--repeated-aug', action='store_true')
131 | parser.add_argument('--no-repeated-aug',
132 | action='store_false', dest='repeated_aug')
133 | parser.set_defaults(repeated_aug=True)
134 |
135 | # Random Erase params
136 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
137 | help='Random erase prob (default: 0.25)')
138 | parser.add_argument('--remode', type=str, default='pixel',
139 | help='Random erase mode (default: "pixel")')
140 | parser.add_argument('--recount', type=int, default=1,
141 | help='Random erase count (default: 1)')
142 | parser.add_argument('--resplit', action='store_true', default=False,
143 | help='Do not random erase first (clean) augmentation split')
144 |
145 | # Mixup params
146 | parser.add_argument('--mixup', type=float, default=0.8,
147 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
148 | parser.add_argument('--cutmix', type=float, default=1.0,
149 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
150 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
151 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
152 | parser.add_argument('--mixup-prob', type=float, default=1.0,
153 | help='Probability of performing mixup or cutmix when either/both is enabled')
154 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
155 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
156 | parser.add_argument('--mixup-mode', type=str, default='batch',
157 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
158 |
159 | # Distillation parameters
160 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
161 | help='Name of teacher model to train (default: "regnety_160"')
162 | parser.add_argument('--teacher-path', type=str,
163 | default='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth')
164 | parser.add_argument('--distillation-type', default='none',
165 | choices=['none', 'soft', 'hard'], type=str, help="")
166 | parser.add_argument('--distillation-alpha',
167 | default=0.5, type=float, help="")
168 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
169 |
170 | # Finetuning params
171 | parser.add_argument('--finetune', default='./models/model.safetensors',
172 | help='finetune from checkpoint')
173 | parser.add_argument('--freeze_layers', type=bool, default=False, help='freeze layers')
174 | parser.add_argument('--set_bn_eval', action='store_true', default=False,
175 | help='set BN layers to eval mode during finetuning.')
176 |
177 | # Dataset parameters
178 | parser.add_argument('--data_root', default='D:/flower_data', type=str,
179 | help='dataset path')
180 | parser.add_argument('--nb_classes', default=5, type=int,
181 | help='number classes of your dataset')
182 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
183 | type=str, help='Image Net dataset path')
184 | parser.add_argument('--inat-category', default='name',
185 | choices=['kingdom', 'phylum', 'class', 'order',
186 | 'supercategory', 'family', 'genus', 'name'],
187 | type=str, help='semantic granularity')
188 | parser.add_argument('--output_dir', default='./output',
189 | help='path where to save, empty for no saving')
190 | parser.add_argument('--writer_output', default='./',
191 | help='path where to save SummaryWriter, empty for no saving')
192 | parser.add_argument('--device', default='cuda',
193 | help='device to use for training / testing')
194 | parser.add_argument('--seed', default=0, type=int)
195 | parser.add_argument('--resume', default='', help='resume from checkpoint')
196 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
197 | help='start epoch')
198 | parser.add_argument('--eval', action='store_true',
199 | help='Perform evaluation only')
200 | parser.add_argument('--dist-eval', action='store_true',
201 | default=False, help='Enabling distributed evaluation')
202 | parser.add_argument('--num_workers', default=0, type=int)
203 | parser.add_argument('--pin-mem', action='store_true',
204 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
205 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
206 | help='')
207 | parser.set_defaults(pin_mem=True)
208 |
209 | # training parameters
210 | parser.add_argument('--world_size', default=1, type=int,
211 | help='number of distributed processes')
212 | parser.add_argument('--local_rank', default=0, type=int)
213 | parser.add_argument('--dist_url', default='env://',
214 | help='url used to set up distributed training')
215 | parser.add_argument('--save_freq', default=1, type=int,
216 | help='frequency of model saving')
217 | return parser
218 |
219 |
220 |
221 |
222 | def main(args):
223 | print(args)
224 | utils.init_distributed_mode(args)
225 |
226 | if args.local_rank == 0:
227 | writer = SummaryWriter(os.path.join(args.writer_output, 'runs'))
228 |
229 | if args.distillation_type != 'none' and args.finetune and not args.eval:
230 | raise NotImplementedError(
231 | "Finetuning with distillation not yet supported")
232 |
233 | device = torch.device(args.device)
234 |
235 | # fix the seed for reproducibility
236 | seed = args.seed + utils.get_rank()
237 | torch.manual_seed(seed)
238 | np.random.seed(seed)
239 | # random.seed(seed)
240 |
241 | cudnn.benchmark = True
242 |
243 | dataset_train, dataset_val = build_dataset(args=args)
244 |
245 | if args.distributed:
246 | num_tasks = utils.get_world_size()
247 | global_rank = utils.get_rank()
248 | if args.repeated_aug:
249 | sampler_train = RASampler(
250 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
251 | )
252 | else:
253 | sampler_train = torch.utils.data.DistributedSampler(
254 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
255 | )
256 | if args.dist_eval:
257 | if len(dataset_val) % num_tasks != 0:
258 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
259 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
260 | 'equal num of samples per-process.')
261 | sampler_val = torch.utils.data.DistributedSampler(
262 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
263 | else:
264 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
265 | else:
266 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
267 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
268 |
269 | data_loader_train = torch.utils.data.DataLoader(
270 | dataset_train, sampler=sampler_train,
271 | batch_size=args.batch_size,
272 | num_workers=args.num_workers,
273 | pin_memory=args.pin_mem,
274 | drop_last=True,
275 | )
276 |
277 | if args.ThreeAugment:
278 | data_loader_train.dataset.transform = new_data_aug_generator(args)
279 |
280 | data_loader_val = torch.utils.data.DataLoader(
281 | dataset_val, sampler=sampler_val,
282 | batch_size=int(1.5 * args.batch_size),
283 | num_workers=args.num_workers,
284 | pin_memory=args.pin_mem,
285 | drop_last=False
286 | )
287 |
288 | mixup_fn = None
289 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
290 | if mixup_active:
291 | mixup_fn = Mixup(
292 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
293 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
294 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
295 |
296 | print(f"Creating model: {args.model}")
297 |
298 | model = create_model(
299 | args.model,
300 | extra_attention_block=args.extra_attention_block,
301 | args=args
302 | )
303 | model.reset_classifier(num_classes=args.nb_classes)
304 |
305 | if args.finetune:
306 | if args.finetune.startswith('https'):
307 | checkpoint = torch.hub.load_state_dict_from_url(
308 | args.finetune, map_location='cpu', check_hash=True)
309 | else:
310 | checkpoint = utils.load_model(args.finetune, model)
311 |
312 | checkpoint_model = checkpoint
313 | # state_dict = model.state_dict()
314 | # new_state_dict = utils.map_safetensors(checkpoint_model, state_dict)
315 |
316 | for k in list(checkpoint_model.keys()):
317 | if 'classifier' in k:
318 | print(f"Removing key {k} from pretrained checkpoint")
319 | del checkpoint_model[k]
320 |
321 | msg = model.load_state_dict(checkpoint_model, strict=False)
322 | print(msg)
323 |
324 | if args.freeze_layers:
325 | for name, para in model.named_parameters():
326 | if 'classifier' not in name:
327 | para.requires_grad_(False)
328 | # else:
329 | # print('training {}'.format(name))
330 | if args.extra_attention_block:
331 | for name, para in model.extra_attention_block.named_parameters():
332 | para.requires_grad_(True)
333 |
334 | model.to(device)
335 |
336 | model_ema = None
337 | if args.model_ema:
338 | # Important to create EMA model after cuda(), DP wrapper, and AMP but
339 | # before SyncBN and DDP wrapper
340 | model_ema = ModelEma(
341 | model,
342 | decay=args.model_ema_decay,
343 | device='cpu' if args.model_ema_force_cpu else '',
344 | resume='')
345 |
346 | model_without_ddp = model
347 | if args.distributed:
348 | model = torch.nn.parallel.DistributedDataParallel(
349 | model, device_ids=[args.gpu])
350 | model_without_ddp = model.module
351 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
352 | print('number of params:', n_parameters)
353 |
354 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
355 | # args.lr = linear_scaled_lr
356 | #
357 | # print('*****************')
358 | # print('Initial LR is ', linear_scaled_lr)
359 | # print('*****************')
360 |
361 | # optimizer = create_optimizer(args, model_without_ddp)
362 | optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=2e-4, weight_decay=args.weight_decay) if args.finetune else create_optimizer(args, model_without_ddp)
363 |
364 | loss_scaler = NativeScaler()
365 | lr_scheduler, _ = create_scheduler(args, optimizer)
366 |
367 | criterion = LabelSmoothingCrossEntropy()
368 |
369 | if args.mixup > 0.:
370 | # smoothing is handled with mixup label transform
371 | criterion = SoftTargetCrossEntropy()
372 | elif args.smoothing:
373 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
374 | else:
375 | criterion = torch.nn.CrossEntropyLoss()
376 |
377 | teacher_model = None
378 | if args.distillation_type != 'none':
379 | assert args.teacher_path, 'need to specify teacher-path when using distillation'
380 | print(f"Creating teacher model: {args.teacher_model}")
381 | teacher_model = create_model(
382 | args.teacher_model,
383 | pretrained=False,
384 | num_classes=args.nb_classes,
385 | global_pool='avg',
386 | )
387 | if args.teacher_path.startswith('https'):
388 | checkpoint = torch.hub.load_state_dict_from_url(
389 | args.teacher_path, map_location='cpu', check_hash=True)
390 | else:
391 | checkpoint = torch.load(args.teacher_path, map_location='cpu')
392 | teacher_model.load_state_dict(checkpoint['model'])
393 | teacher_model.to(device)
394 | teacher_model.eval()
395 |
396 | # wrap the criterion in our custom DistillationLoss, which
397 | # just dispatches to the original criterion if args.distillation_type is
398 | # 'none'
399 | criterion = DistillationLoss(
400 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
401 | )
402 |
403 | max_accuracy = 0.0
404 |
405 | output_dir = Path(args.output_dir)
406 | if args.output_dir and utils.is_main_process():
407 | with (output_dir / "model.txt").open("a") as f:
408 | f.write(str(model))
409 | if args.output_dir and utils.is_main_process():
410 | with (output_dir / "args.txt").open("a") as f:
411 | f.write(json.dumps(args.__dict__, indent=2) + "\n")
412 | if args.resume or os.path.exists(f'{args.output_dir}/{args.model}_best_checkpoint.pth'):
413 | args.resume = f'{args.output_dir}/{args.model}_best_checkpoint.pth'
414 | if args.resume.startswith('https'):
415 | checkpoint = torch.hub.load_state_dict_from_url(
416 | args.resume, map_location='cpu', check_hash=True)
417 | else:
418 | print("Loading local checkpoint at {}".format(args.resume))
419 | checkpoint = torch.load(args.resume, map_location='cpu')
420 | msg = model_without_ddp.load_state_dict(checkpoint['model'])
421 | print(msg)
422 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
423 |
424 | optimizer.load_state_dict(checkpoint['optimizer'])
425 | for state in optimizer.state.values(): # load parameters to cuda
426 | for k, v in state.items():
427 | if isinstance(v, torch.Tensor):
428 | state[k] = v.cuda()
429 |
430 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
431 | max_accuracy = checkpoint['best_score']
432 | print(f'Now max accuracy is {max_accuracy}')
433 | args.start_epoch = checkpoint['epoch'] + 1
434 | if args.model_ema:
435 | utils._load_checkpoint_for_ema(
436 | model_ema, checkpoint['model_ema'])
437 | if 'scaler' in checkpoint:
438 | loss_scaler.load_state_dict(checkpoint['scaler'])
439 | if args.eval:
440 | # util.replace_batchnorm(model) # Users may choose whether to merge Conv-BN layers during eval
441 | print(f"Evaluating model: {args.model}")
442 | print(f'No Visualization')
443 | test_stats = evaluate(data_loader_val, model, device, None, None, args, visualization=False)
444 | print(
445 | f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
446 | )
447 | # print(model)
448 | print(f"Start training for {args.epochs} epochs")
449 | start_time = time.time()
450 |
451 | for epoch in range(args.start_epoch, args.epochs):
452 | if args.distributed:
453 | data_loader_train.sampler.set_epoch(epoch)
454 |
455 | train_stats = train_one_epoch(
456 | model, criterion, data_loader_train,
457 | optimizer, device, epoch, loss_scaler,
458 | args.clip_grad, args.clip_mode, model_ema, mixup_fn,
459 | # set_training_mode=args.finetune == '', # keep in eval mode during finetuning
460 | set_training_mode=True,
461 | set_bn_eval=args.set_bn_eval, # set bn to eval if finetune
462 | writer=writer,
463 | args=args
464 | )
465 |
466 | lr_scheduler.step(epoch)
467 |
468 | test_stats = evaluate(data_loader_val, model, device, epoch, writer, args, visualization=True)
469 | print(
470 | f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
471 |
472 | if max_accuracy < test_stats["acc1"]:
473 | max_accuracy = test_stats["acc1"]
474 | if args.output_dir:
475 | ckpt_path = os.path.join(output_dir, f'{args.model}_best_checkpoint.pth')
476 | checkpoint_paths = [ckpt_path]
477 | print("Saving checkpoint to {}".format(ckpt_path))
478 | for checkpoint_path in checkpoint_paths:
479 | utils.save_on_master({
480 | 'model': model_without_ddp.state_dict(),
481 | 'optimizer': optimizer.state_dict(),
482 | 'lr_scheduler': lr_scheduler.state_dict(),
483 | 'epoch': epoch,
484 | 'best_score': max_accuracy,
485 | 'model_ema': get_state_dict(model_ema),
486 | 'scaler': loss_scaler.state_dict(),
487 | 'args': args,
488 | }, checkpoint_path)
489 |
490 | print(f'Max accuracy: {max_accuracy:.2f}%')
491 |
492 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
493 | **{f'test_{k}': v for k, v in test_stats.items()},
494 | 'epoch': epoch,
495 | 'n_parameters': n_parameters}
496 |
497 | if args.output_dir and utils.is_main_process():
498 | with (output_dir / "log.txt").open("a") as f:
499 | f.write(json.dumps(log_stats) + "\n")
500 |
501 | total_time = time.time() - start_time
502 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
503 | print('Training time {}'.format(total_time_str))
504 |
505 | # plot ROC curve and confusion matrix
506 | if args.predict and utils.is_main_process():
507 | model_predict = create_model(
508 | args.model,
509 | extra_attention_block=args.extra_attention_block,
510 | args=args
511 | )
512 |
513 | model_predict.reset_classifier(num_classes=args.nb_classes)
514 | model_predict.to(device)
515 | print('*******************STARTING PREDICT*******************')
516 | Predictor(model_predict, data_loader_val, f'{args.output_dir}/{args.model}_best_checkpoint.pth', device)
517 | Plot_ROC(model_predict, data_loader_val, f'{args.output_dir}/{args.model}_best_checkpoint.pth', device)
518 |
519 | if args.opt_auc:
520 | OptAUC(model_predict, data_loader_val, f'{args.output_dir}/{args.model}_best_checkpoint.pth', device)
521 |
522 |
523 | if __name__ == '__main__':
524 | parser = argparse.ArgumentParser(
525 | 'MobileNetV4 training and evaluation script', parents=[get_args_parser()])
526 | args = parser.parse_args()
527 | if args.output_dir:
528 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
529 | main(args)
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | from .engine import train_one_epoch, evaluate
2 | from .losses import DistillationLoss
3 | from .samplers import RASampler
4 | from .optimizer import SophiaG, MARS
5 | from .utils import *
--------------------------------------------------------------------------------
/util/engine.py:
--------------------------------------------------------------------------------
1 | """
2 | Train and eval functions used in main.py
3 | """
4 | import math
5 | import sys
6 | from typing import Iterable, Optional
7 |
8 | import torch
9 |
10 | from timm.data import Mixup
11 | from timm.utils import ModelEma, accuracy
12 |
13 | from .losses import DistillationLoss
14 | from util import utils as utils
15 |
16 |
17 | def set_bn_state(model):
18 | for m in model.modules():
19 | if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
20 | m.eval()
21 |
22 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
23 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
24 | device: torch.device, epoch: int, loss_scaler,
25 | clip_grad: float = 0,
26 | clip_mode: str = 'norm',
27 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
28 | set_training_mode=True,
29 | set_bn_eval=False,
30 | writer=None,
31 | args=None):
32 | """
33 | Train the model for one epoch.
34 |
35 | Args:
36 | model (torch.nn.Module): The model to be trained.
37 | criterion (DistillationLoss): The loss function used for training.
38 | data_loader (Iterable): The data loader for the training data.
39 | optimizer (torch.optim.Optimizer): The optimizer used for training.
40 | device (torch.device): The device used for training (CPU or GPU).
41 | epoch (int): The current training epoch.
42 | loss_scaler: The object used for gradient scaling.
43 | clip_grad (float, optional): The maximum value for gradient clipping. Default is 0, which means no gradient clipping.
44 | clip_mode (str, optional): The mode for gradient clipping, can be 'norm' or 'value'. Default is 'norm'.
45 | model_ema (Optional[ModelEma], optional): The EMA (Exponential Moving Average) model for saving model weights.
46 | mixup_fn (Optional[Mixup], optional): The function used for Mixup data augmentation.
47 | set_training_mode (bool, optional): Whether to set the model to training mode. Default is True.
48 | set_bn_eval (bool, optional): Whether to set the batch normalization layers to evaluation mode. Default is False.
49 | writer (Optional[Any], optional): The object used for writing TensorBoard logs.
50 | args (Optional[Any], optional): Additional arguments.
51 |
52 | Returns:
53 | Dict[str, float]: A dictionary containing the average values of the training metrics.
54 | """
55 |
56 |
57 | model.train(set_training_mode)
58 | num_steps = len(data_loader)
59 |
60 | if set_bn_eval:
61 | set_bn_state(model)
62 | metric_logger = utils.MetricLogger(delimiter=" ")
63 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
64 | header = 'Epoch: [{}]'.format(epoch)
65 | print_freq = 50
66 |
67 | for idx, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
68 | samples = samples.to(device, non_blocking=True)
69 | targets = targets.to(device, non_blocking=True)
70 |
71 | if mixup_fn is not None:
72 | samples, targets = mixup_fn(samples, targets)
73 |
74 | with torch.cuda.amp.autocast():
75 | outputs = model(samples)
76 | loss = criterion(samples, outputs, targets)
77 |
78 | loss_value = loss.item()
79 |
80 | if not math.isfinite(loss_value):
81 | print("Loss is {}, stopping training".format(loss_value))
82 | sys.exit(1)
83 |
84 | optimizer.zero_grad()
85 |
86 | # this attribute is added by timm on one optimizer (adahessian)
87 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
88 | with torch.cuda.amp.autocast():
89 | loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
90 | parameters=model.parameters(), create_graph=is_second_order)
91 |
92 | torch.cuda.synchronize()
93 | if model_ema is not None:
94 | model_ema.update(model)
95 |
96 | learning_rate = optimizer.param_groups[0]["lr"]
97 | metric_logger.update(loss=loss_value)
98 | metric_logger.update(lr=learning_rate)
99 |
100 |
101 | if idx % print_freq == 0:
102 | if args.local_rank == 0:
103 | iter_all_count = epoch * num_steps + idx
104 | writer.add_scalar('loss', loss, iter_all_count)
105 | # writer.add_scalar('grad_norm', grad_norm, iter_all_count)
106 | writer.add_scalar('lr', learning_rate, iter_all_count)
107 |
108 | # gather the stats from all processes
109 | metric_logger.synchronize_between_processes()
110 | print("Averaged stats:", metric_logger)
111 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
112 |
113 |
114 | @torch.inference_mode()
115 | def evaluate(data_loader: Iterable, model: torch.nn.Module,
116 | device: torch.device, epoch: int,
117 | writer, args,
118 | visualization=True):
119 | """
120 | Evaluate the model for one epoch.
121 |
122 | Args:
123 | data_loader (Iterable): The data loader for the valid data.
124 | model (torch.nn.Module): The model to be evaluated.
125 | device (torch.device): The device used for training (CPU or GPU).
126 | epoch (int): The current training epoch.
127 | writer (Optional[Any], optional): The object used for writing TensorBoard logs.
128 | args (Optional[Any], optional): Additional arguments.
129 | visualization (bool, optional): Whether to use TensorBoard visualization. Default is True.
130 |
131 | Returns:
132 | Dict[str, float]: A dictionary containing the average values of the training metrics.
133 | """
134 |
135 | criterion = torch.nn.CrossEntropyLoss()
136 |
137 | metric_logger = utils.MetricLogger(delimiter=" ")
138 | header = 'Test:'
139 | # switch to evaluation mode
140 | model.eval()
141 |
142 | print_freq = 20
143 | for images, target in metric_logger.log_every(data_loader, print_freq, header):
144 | images = images.to(device, non_blocking=True)
145 | target = target.to(device, non_blocking=True)
146 |
147 | # compute output
148 | with torch.cuda.amp.autocast():
149 | output = model(images)
150 | loss = criterion(output, target)
151 |
152 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
153 |
154 | batch_size = images.shape[0]
155 | metric_logger.update(loss=loss.item())
156 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
157 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
158 |
159 |
160 | if visualization and args.local_rank == 0:
161 | writer.add_scalar('Acc@1', acc1.item(), epoch)
162 | writer.add_scalar('Acc@5', acc5.item(), epoch)
163 |
164 | # gather the stats from all processes
165 | metric_logger.synchronize_between_processes()
166 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
167 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
168 |
169 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
170 |
--------------------------------------------------------------------------------
/util/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the knowledge distillation loss, proposed in deit
3 | """
4 | import torch
5 | from torch.nn import functional as F
6 |
7 |
8 | class DistillationLoss(torch.nn.Module):
9 | """
10 | This module wraps a standard criterion and adds an extra knowledge distillation loss by
11 | taking a teacher model prediction and using it as additional supervision.
12 | """
13 |
14 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
15 | distillation_type: str, alpha: float, tau: float):
16 | super().__init__()
17 | self.base_criterion = base_criterion
18 | self.teacher_model = teacher_model
19 | assert distillation_type in ['none', 'soft', 'hard']
20 | self.distillation_type = distillation_type
21 | self.alpha = alpha
22 | self.tau = tau
23 |
24 | def forward(self, inputs, outputs, labels):
25 | """
26 | Args:
27 | inputs: The original inputs that are feed to the teacher model
28 | outputs: the outputs of the model to be trained. It is expected to be
29 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output
30 | in the first position and the distillation predictions as the second output
31 | labels: the labels for the base criterion
32 | """
33 | outputs_kd = None
34 | if not isinstance(outputs, torch.Tensor):
35 | # assume that the model outputs a tuple of [outputs, outputs_kd]
36 | outputs, outputs_kd = outputs
37 | base_loss = self.base_criterion(outputs, labels)
38 | if self.distillation_type == 'none':
39 | return base_loss
40 |
41 | if outputs_kd is None:
42 | raise ValueError("When knowledge distillation is enabled, the model is "
43 | "expected to return a Tuple[Tensor, Tensor] with the output of the "
44 | "class_token and the dist_token")
45 | # don't backprop throught the teacher
46 | with torch.no_grad():
47 | teacher_outputs = self.teacher_model(inputs)
48 |
49 | if self.distillation_type == 'soft':
50 | T = self.tau
51 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
52 | # with slight modifications
53 | distillation_loss = F.kl_div(
54 | F.log_softmax(outputs_kd / T, dim=1),
55 | F.log_softmax(teacher_outputs / T, dim=1),
56 | reduction='sum',
57 | log_target=True
58 | ) * (T * T) / outputs_kd.numel()
59 | elif self.distillation_type == 'hard':
60 | distillation_loss = F.cross_entropy(
61 | outputs_kd, teacher_outputs.argmax(dim=1))
62 |
63 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
64 | return loss
--------------------------------------------------------------------------------
/util/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.optim.optimizer import Optimizer
4 | from typing import List
5 | import math
6 |
7 |
8 | # optimizer = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99), rho=0.01, weight_decay=1e-1)
9 | # optimizer = MARS(model.parameters(), lr=args.lr, weight_decay = args.weight_decay, lr_1d=args.adamw_lr)
10 |
11 | __all__ = ['SophiaG', 'MARS']
12 |
13 |
14 | def exists(val):
15 | return val is not None
16 |
17 |
18 | def update_fn(p, grad, exp_avg, exp_avg_sq, lr, wd, beta1, beta2, last_grad, eps, amsgrad, max_exp_avg_sq, step, gamma,
19 | mars_type, is_grad_2d, optimize_1d, lr_1d_factor, betas_1d, weight_decay_1d):
20 | # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para
21 | if optimize_1d or is_grad_2d:
22 | c_t = (grad - last_grad).mul(gamma * (beta1 / (1. - beta1))).add(grad)
23 | c_t_norm = torch.norm(c_t)
24 | if c_t_norm > 1.:
25 | c_t = c_t / c_t_norm
26 | exp_avg.mul_(beta1).add_(c_t, alpha=1. - beta1)
27 | if (mars_type == "mars-adamw") or (mars_type == "mars-shampoo" and not is_grad_2d):
28 | exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
29 | bias_correction1 = 1.0 - beta1 ** step
30 | bias_correction2 = 1.0 - beta2 ** step
31 | if amsgrad:
32 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
33 | denom = max_exp_avg_sq.sqrt().mul(1 / math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
34 | else:
35 | denom = exp_avg_sq.sqrt().mul(1 / math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
36 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.div(denom))
37 | elif mars_type == "mars-lion":
38 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.sign())
39 | elif mars_type == "mars-shampoo" and is_grad_2d:
40 | factor = max(1, grad.size(0) / grad.size(1)) ** 0.5
41 | real_update_tmp = NewtonSchulz(exp_avg.mul(1. / (1. - beta1)), eps=eps).mul(factor).add(wd, p.data).mul(-lr)
42 | p.data.add_(real_update_tmp)
43 | else:
44 | beta1_1d, beta2_1d = betas_1d
45 | exp_avg.mul_(beta1_1d).add_(grad, alpha=1. - beta1_1d)
46 | exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1. - beta2_1d)
47 | bias_correction1 = 1.0 - beta1_1d ** step
48 | bias_correction2 = 1.0 - beta2_1d ** step
49 | if amsgrad:
50 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
51 | denom = max_exp_avg_sq.sqrt().mul(1 / math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
52 | else:
53 | denom = exp_avg_sq.sqrt().mul(1 / math.sqrt(bias_correction2)).add(eps).mul(bias_correction1)
54 | real_update_tmp = -lr * lr_1d_factor * torch.mul(p.data, weight_decay_1d).add(exp_avg.div(denom))
55 | p.data.add_(real_update_tmp)
56 | return exp_avg, exp_avg_sq
57 |
58 |
59 | class MARS(Optimizer):
60 | def __init__(self, params, lr=3e-3, betas=(0.95, 0.99), eps=1e-8, weight_decay=0., amsgrad=False, gamma=0.025,
61 | is_approx=True, mars_type="mars-adamw", optimize_1d=False, lr_1d=3e-3, betas_1d=(0.9, 0.95),
62 | weight_decay_1d=0.1):
63 | if not 0.0 <= lr:
64 | raise ValueError("Invalid learning rate: {}".format(lr))
65 | if not 0.0 <= eps:
66 | raise ValueError("Invalid epsilon value: {}".format(eps))
67 | if not 0.0 <= betas[0] < 1.0:
68 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
69 | if not 0.0 <= betas[1] < 1.0:
70 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
71 | assert mars_type in ["mars-adamw", "mars-lion", "mars-shampoo"], "MARS type not supported"
72 | defaults = dict(lr=lr, betas=betas, eps=eps,
73 | weight_decay=weight_decay, amsgrad=amsgrad,
74 | mars_type=mars_type, gamma=gamma,
75 | optimize_1d=optimize_1d, weight_decay_1d=weight_decay_1d)
76 | super(MARS, self).__init__(params, defaults)
77 | self.eps = eps
78 | self.update_fn = update_fn
79 | self.lr = lr
80 | self.weight_decay = weight_decay
81 | self.amsgrad = amsgrad
82 | self.step_num = 0
83 | self.is_approx = is_approx
84 | self.gamma = gamma
85 | self.mars_type = mars_type
86 | self.optimize_1d = optimize_1d
87 | self.lr_1d_factor = lr_1d / lr
88 | self.weight_decay_1d = weight_decay_1d
89 | self.betas_1d = betas_1d
90 |
91 | @torch.no_grad()
92 | def update_last_grad(self):
93 | if not self.is_approx:
94 | for group in self.param_groups:
95 | for p in group['params']:
96 | state = self.state[p]
97 | if "last_grad" not in state:
98 | state["last_grad"] = torch.zeros_like(p)
99 | state["last_grad"].zero_().add_(state["previous_grad"], alpha=1.0)
100 |
101 | @torch.no_grad()
102 | def update_previous_grad(self):
103 | if not self.is_approx:
104 | for group in self.param_groups:
105 | # print ("para name", len(group['params']), len(group['names']), group['names'])
106 | for p in group['params']:
107 | # import pdb
108 | # pdb.set_trace()
109 | if p.grad is None:
110 | print(p, "grad is none")
111 | continue
112 | state = self.state[p]
113 | if "previous_grad" not in state:
114 | state['previous_grad'] = torch.zeros_like(p)
115 | state['previous_grad'].zero_().add_(p.grad, alpha=1.0)
116 |
117 | def __setstate__(self, state):
118 | super(MARS, self).__setstate__(state)
119 | for group in self.param_groups:
120 | group.setdefault('amsgrad', False)
121 |
122 | @torch.no_grad()
123 | def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None):
124 | """Performs a single optimization step.
125 |
126 | Arguments:
127 | closure (callable, optional): A closure that reevaluates the model
128 | and returns the loss.
129 |
130 | If using exact version, the example usage is as follows:
131 | previous_X, previous_Y = None, None
132 | for epoch in range(epochs):
133 | for X, Y in data_loader:
134 | if previous_X:
135 | logits, loss = model(X, Y)
136 | loss.backward()
137 | optimizer.update_previous_grad()
138 | optimizer.zero_grad(set_to_none=True)
139 | logits, loss = model(X, Y)
140 | loss.backward()
141 | optimizer.step(bs=bs)
142 | optimizer.zero_grad(set_to_none=True)
143 | optimizer.update_last_grad()
144 | iter_num += 1
145 | previous_X, previous_Y = X.clone(), Y.clone()
146 | """
147 | if any(p is not None for p in [grads, output_params, scale, grad_norms]):
148 | raise RuntimeError(
149 | 'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.')
150 |
151 | loss = None
152 | if exists(closure):
153 | with torch.enable_grad():
154 | loss = closure()
155 | real_update = 0
156 | real_update_wo_lr = 0
157 | gamma = self.gamma
158 | # import pdb
159 | # pdb.set_trace()
160 | for group in self.param_groups:
161 | for p in filter(lambda p: exists(p.grad), group['params']):
162 | if p.grad is None:
163 | continue
164 | grad = p.grad.data
165 | if grad.is_sparse:
166 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
167 | amsgrad = group['amsgrad']
168 |
169 | state = self.state[p]
170 | # ('----- starting a parameter state', state.keys(), 'Length of state', len(state))
171 | # State initialization
172 | if len(state) <= 1:
173 | state['step'] = 0
174 | # Exponential moving average of gradient values
175 | state['exp_avg'] = torch.zeros_like(p.data)
176 | # Last Gradient
177 | state['last_grad'] = torch.zeros_like(p)
178 | # state['previous_grad'] = torch.zeros_like(p)
179 | # Exponential moving average of squared gradient values
180 | state['exp_avg_sq'] = torch.zeros_like(p.data)
181 | if amsgrad:
182 | # Maintains max of all exp. moving avg. of sq. grad. values
183 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
184 | # import pdb
185 | # pdb.set_trace()
186 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
187 | last_grad = state['last_grad']
188 | lr, wd, beta1, beta2 = group['lr'], group['weight_decay'], *group['betas']
189 | if amsgrad:
190 | max_exp_avg_sq = state['max_exp_avg_sq']
191 | else:
192 | max_exp_avg_sq = 0
193 |
194 | if 'step' in state:
195 | state['step'] += 1
196 | else:
197 | state['step'] = 1
198 | step = state['step']
199 | is_grad_2d = (len(grad.shape) == 2)
200 | exp_avg, exp_avg_sq = self.update_fn(
201 | p,
202 | grad,
203 | exp_avg,
204 | exp_avg_sq,
205 | lr,
206 | wd,
207 | beta1,
208 | beta2,
209 | last_grad,
210 | self.eps,
211 | amsgrad,
212 | max_exp_avg_sq,
213 | step,
214 | gamma,
215 | mars_type=self.mars_type,
216 | is_grad_2d=is_grad_2d,
217 | optimize_1d=self.optimize_1d,
218 | lr_1d_factor=self.lr_1d_factor,
219 | betas_1d=self.betas_1d,
220 | weight_decay_1d=self.weight_decay if self.optimize_1d else self.weight_decay_1d
221 | )
222 | if self.is_approx:
223 | state['last_grad'] = grad
224 | self.step_num = step
225 |
226 | return loss
227 |
228 |
229 | @torch.compile
230 | def NewtonSchulz(M, steps=5, eps=1e-7):
231 | a, b, c = (3.4445, -4.7750, 2.0315)
232 | X = M.bfloat16() / (M.norm() + eps)
233 | if M.size(0) > M.size(1):
234 | X = X.T
235 | for _ in range(steps):
236 | A = X @ X.T
237 | B = A @ X
238 | X = a * X + b * B + c * A @ B
239 | if M.size(0) > M.size(1):
240 | X = X.T
241 | return X.to(M.dtype)
242 |
243 |
244 | class SophiaG(Optimizer):
245 | def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho=0.04,
246 | weight_decay=1e-1, *, maximize: bool = False,
247 | capturable: bool = False):
248 | if not 0.0 <= lr:
249 | raise ValueError("Invalid learning rate: {}".format(lr))
250 | if not 0.0 <= betas[0] < 1.0:
251 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
252 | if not 0.0 <= betas[1] < 1.0:
253 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
254 | if not 0.0 <= rho:
255 | raise ValueError("Invalid rho parameter at index 1: {}".format(rho))
256 | if not 0.0 <= weight_decay:
257 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
258 | defaults = dict(lr=lr, betas=betas, rho=rho,
259 | weight_decay=weight_decay,
260 | maximize=maximize, capturable=capturable)
261 | super(SophiaG, self).__init__(params, defaults)
262 |
263 | def __setstate__(self, state):
264 | super().__setstate__(state)
265 | for group in self.param_groups:
266 | group.setdefault('maximize', False)
267 | group.setdefault('capturable', False)
268 | state_values = list(self.state.values())
269 | step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
270 | if not step_is_tensor:
271 | for s in state_values:
272 | s['step'] = torch.tensor(float(s['step']))
273 |
274 | @torch.no_grad()
275 | def update_hessian(self):
276 | for group in self.param_groups:
277 | beta1, beta2 = group['betas']
278 | for p in group['params']:
279 | if p.grad is None:
280 | continue
281 | state = self.state[p]
282 |
283 | if len(state) == 0:
284 | state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
285 | if self.defaults['capturable'] else torch.tensor(0.)
286 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
287 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
288 |
289 | if 'hessian' not in state.keys():
290 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
291 |
292 | state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)
293 |
294 | @torch.no_grad()
295 | def step(self, closure=None, bs=5120):
296 | loss = None
297 | if closure is not None:
298 | with torch.enable_grad():
299 | loss = closure()
300 |
301 | for group in self.param_groups:
302 | params_with_grad = []
303 | grads = []
304 | exp_avgs = []
305 | state_steps = []
306 | hessian = []
307 | beta1, beta2 = group['betas']
308 |
309 | for p in group['params']:
310 | if p.grad is None:
311 | continue
312 | params_with_grad.append(p)
313 |
314 | if p.grad.is_sparse:
315 | raise RuntimeError('Hero does not support sparse gradients')
316 | grads.append(p.grad)
317 | state = self.state[p]
318 | # State initialization
319 | if len(state) == 0:
320 | state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
321 | if self.defaults['capturable'] else torch.tensor(0.)
322 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
323 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
324 |
325 | if 'hessian' not in state.keys():
326 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
327 |
328 | exp_avgs.append(state['exp_avg'])
329 | state_steps.append(state['step'])
330 | hessian.append(state['hessian'])
331 |
332 | if self.defaults['capturable']:
333 | bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs
334 |
335 | sophiag(params_with_grad,
336 | grads,
337 | exp_avgs,
338 | hessian,
339 | state_steps,
340 | bs=bs,
341 | beta1=beta1,
342 | beta2=beta2,
343 | rho=group['rho'],
344 | lr=group['lr'],
345 | weight_decay=group['weight_decay'],
346 | maximize=group['maximize'],
347 | capturable=group['capturable'])
348 |
349 | return loss
350 |
351 |
352 | def sophiag(params: List[Tensor],
353 | grads: List[Tensor],
354 | exp_avgs: List[Tensor],
355 | hessian: List[Tensor],
356 | state_steps: List[Tensor],
357 | capturable: bool = False,
358 | *,
359 | bs: int,
360 | beta1: float,
361 | beta2: float,
362 | rho: float,
363 | lr: float,
364 | weight_decay: float,
365 | maximize: bool):
366 | if not all(isinstance(t, torch.Tensor) for t in state_steps):
367 | raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
368 |
369 | func = _single_tensor_sophiag
370 |
371 | func(params,
372 | grads,
373 | exp_avgs,
374 | hessian,
375 | state_steps,
376 | bs=bs,
377 | beta1=beta1,
378 | beta2=beta2,
379 | rho=rho,
380 | lr=lr,
381 | weight_decay=weight_decay,
382 | maximize=maximize,
383 | capturable=capturable)
384 |
385 |
386 | def _single_tensor_sophiag(params: List[Tensor],
387 | grads: List[Tensor],
388 | exp_avgs: List[Tensor],
389 | hessian: List[Tensor],
390 | state_steps: List[Tensor],
391 | *,
392 | bs: int,
393 | beta1: float,
394 | beta2: float,
395 | rho: float,
396 | lr: float,
397 | weight_decay: float,
398 | maximize: bool,
399 | capturable: bool):
400 | for i, param in enumerate(params):
401 | grad = grads[i] if not maximize else -grads[i]
402 | exp_avg = exp_avgs[i]
403 | hess = hessian[i]
404 | step_t = state_steps[i]
405 |
406 | if capturable:
407 | assert param.is_cuda and step_t.is_cuda and bs.is_cuda
408 |
409 | if torch.is_complex(param):
410 | grad = torch.view_as_real(grad)
411 | exp_avg = torch.view_as_real(exp_avg)
412 | hess = torch.view_as_real(hess)
413 | param = torch.view_as_real(param)
414 |
415 | # update step
416 | step_t += 1
417 |
418 | # Perform stepweight decay
419 | param.mul_(1 - lr * weight_decay)
420 |
421 | # Decay the first and second moment running average coefficient
422 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
423 |
424 | if capturable:
425 | step = step_t
426 | step_size = lr
427 | step_size_neg = step_size.neg()
428 |
429 | ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1)
430 | param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
431 | else:
432 | step = step_t.item()
433 | step_size_neg = - lr
434 |
435 | ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1)
436 | param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
--------------------------------------------------------------------------------
/util/samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import torch
4 | import torch.distributed as dist
5 | import math
6 |
7 |
8 | class RASampler(torch.utils.data.Sampler):
9 | """Sampler that restricts data loading to a subset of the dataset for distributed,
10 | with repeated augmentation.
11 | It ensures that different each augmented version of a sample will be visible to a
12 | different process (GPU)
13 | Heavily based on torch.util.data.DistributedSampler
14 | """
15 |
16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
17 | if num_replicas is None:
18 | if not dist.is_available():
19 | raise RuntimeError("Requires distributed package to be available")
20 | num_replicas = dist.get_world_size()
21 | if rank is None:
22 | if not dist.is_available():
23 | raise RuntimeError("Requires distributed package to be available")
24 | rank = dist.get_rank()
25 | if num_repeats < 1:
26 | raise ValueError("num_repeats should be greater than 0")
27 | self.dataset = dataset
28 | self.num_replicas = num_replicas
29 | self.rank = rank
30 | self.num_repeats = num_repeats
31 | self.epoch = 0
32 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
33 | self.total_size = self.num_samples * self.num_replicas
34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
36 | self.shuffle = shuffle
37 |
38 | def __iter__(self):
39 | if self.shuffle:
40 | # deterministically shuffle based on epoch
41 | g = torch.Generator()
42 | g.manual_seed(self.epoch)
43 | indices = torch.randperm(len(self.dataset), generator=g)
44 | else:
45 | indices = torch.arange(start=0, end=len(self.dataset))
46 |
47 | # add extra samples to make it evenly divisible
48 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
49 | padding_size: int = self.total_size - len(indices)
50 | if padding_size > 0:
51 | indices += indices[:padding_size]
52 | assert len(indices) == self.total_size
53 |
54 | # subsample
55 | indices = indices[self.rank:self.total_size:self.num_replicas]
56 | assert len(indices) == self.num_samples
57 |
58 | return iter(indices[:self.num_selected_samples])
59 |
60 | def __len__(self):
61 | return self.num_selected_samples
62 |
63 | def set_epoch(self, epoch):
64 | self.epoch = epoch
65 |
66 |
67 |
--------------------------------------------------------------------------------
/util/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Misc functions, including distributed helpers and model loaders
3 | Also include a model loader specified for finetuning EfficientViT
4 | """
5 | import io
6 | import os
7 | import time
8 | from collections import defaultdict, deque
9 | import datetime
10 | import logging
11 | import torch
12 | import torch.nn as nn
13 | import torch.distributed as dist
14 | import safetensors
15 |
16 |
17 |
18 | logger = logging.getLogger()
19 |
20 |
21 | class AverageMeter:
22 | """Computes and stores the average and current value"""
23 | def __init__(self):
24 | self.reset()
25 |
26 | def reset(self):
27 | self.val = 0
28 | self.avg = 0
29 | self.sum = 0
30 | self.count = 0
31 |
32 | def update(self, val, n=1):
33 | self.val = val
34 | self.sum += val * n
35 | self.count += n
36 | self.avg = self.sum / self.count
37 |
38 |
39 | class SmoothedValue(object):
40 | """Track a series of values and provide access to smoothed values over a
41 | window or the global series average.
42 | """
43 |
44 | def __init__(self, window_size=20, fmt=None):
45 | if fmt is None:
46 | fmt = "{median:.4f} ({global_avg:.4f})"
47 | self.deque = deque(maxlen=window_size)
48 | self.total = 0.0
49 | self.count = 0
50 | self.fmt = fmt
51 |
52 | def update(self, value, n=1):
53 | self.deque.append(value)
54 | self.count += n
55 | self.total += value * n
56 |
57 | def synchronize_between_processes(self):
58 | """
59 | Warning: does not synchronize the deque!
60 | """
61 | if not is_dist_avail_and_initialized():
62 | return
63 | t = torch.tensor([self.count, self.total],
64 | dtype=torch.float64, device='cuda')
65 | dist.barrier()
66 | dist.all_reduce(t)
67 | t = t.tolist()
68 | self.count = int(t[0])
69 | self.total = t[1]
70 |
71 | @property
72 | def median(self):
73 | d = torch.tensor(list(self.deque))
74 | return d.median().item()
75 |
76 | @property
77 | def avg(self):
78 | d = torch.tensor(list(self.deque), dtype=torch.float32)
79 | return d.mean().item()
80 |
81 | @property
82 | def global_avg(self):
83 | return self.total / self.count
84 |
85 | @property
86 | def max(self):
87 | return max(self.deque)
88 |
89 | @property
90 | def value(self):
91 | return self.deque[-1]
92 |
93 | def __str__(self):
94 | return self.fmt.format(
95 | median=self.median,
96 | avg=self.avg,
97 | global_avg=self.global_avg,
98 | max=self.max,
99 | value=self.value)
100 |
101 |
102 | class MetricLogger(object):
103 | def __init__(self, delimiter="\t"):
104 | self.meters = defaultdict(SmoothedValue)
105 | self.delimiter = delimiter
106 |
107 | def update(self, **kwargs):
108 | for k, v in kwargs.items():
109 | if isinstance(v, torch.Tensor):
110 | v = v.item()
111 | assert isinstance(v, (float, int))
112 | self.meters[k].update(v)
113 |
114 | def __getattr__(self, attr):
115 | if attr in self.meters:
116 | return self.meters[attr]
117 | if attr in self.__dict__:
118 | return self.__dict__[attr]
119 | raise AttributeError("'{}' object has no attribute '{}'".format(
120 | type(self).__name__, attr))
121 |
122 | def __str__(self):
123 | loss_str = []
124 | for name, meter in self.meters.items():
125 | loss_str.append(
126 | "{}: {}".format(name, str(meter))
127 | )
128 | return self.delimiter.join(loss_str)
129 |
130 | def synchronize_between_processes(self):
131 | for meter in self.meters.values():
132 | meter.synchronize_between_processes()
133 |
134 | def add_meter(self, name, meter):
135 | self.meters[name] = meter
136 |
137 | def log_every(self, iterable, print_freq, header=None):
138 | i = 0
139 | if not header:
140 | header = ''
141 | start_time = time.time()
142 | end = time.time()
143 | iter_time = SmoothedValue(fmt='{avg:.4f}')
144 | data_time = SmoothedValue(fmt='{avg:.4f}')
145 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
146 | log_msg = [
147 | header,
148 | '[{0' + space_fmt + '}/{1}]',
149 | 'eta: {eta}',
150 | '{meters}',
151 | 'time: {time}',
152 | 'data: {data}'
153 | ]
154 | if torch.cuda.is_available():
155 | log_msg.append('max mem: {memory:.0f}')
156 | log_msg = self.delimiter.join(log_msg)
157 | MB = 1024.0 * 1024.0
158 | for obj in iterable:
159 | data_time.update(time.time() - end)
160 | yield obj
161 | iter_time.update(time.time() - end)
162 | if i % print_freq == 0 or i == len(iterable) - 1:
163 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
164 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
165 | if torch.cuda.is_available():
166 | print(log_msg.format(
167 | i, len(iterable), eta=eta_string,
168 | meters=str(self),
169 | time=str(iter_time), data=str(data_time),
170 | memory=torch.cuda.max_memory_allocated() / MB))
171 | else:
172 | print(log_msg.format(
173 | i, len(iterable), eta=eta_string,
174 | meters=str(self),
175 | time=str(iter_time), data=str(data_time)))
176 | i += 1
177 | end = time.time()
178 | total_time = time.time() - start_time
179 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
180 | print('{} Total time: {} ({:.4f} s / it)'.format(
181 | header, total_time_str, total_time / len(iterable)))
182 |
183 | def _load_checkpoint_for_ema(model_ema, checkpoint):
184 | """
185 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object
186 | """
187 | mem_file = io.BytesIO()
188 | torch.save(checkpoint, mem_file)
189 | mem_file.seek(0)
190 | model_ema._load_checkpoint(mem_file)
191 |
192 | def setup_for_distributed(is_master):
193 | """
194 | This function disables printing when not in master process
195 | """
196 | import builtins as __builtin__
197 | builtin_print = __builtin__.print
198 |
199 | def print(*args, **kwargs):
200 | force = kwargs.pop('force', False)
201 | if is_master or force:
202 | builtin_print(*args, **kwargs)
203 |
204 | __builtin__.print = print
205 |
206 | def is_dist_avail_and_initialized():
207 | if not dist.is_available():
208 | return False
209 | if not dist.is_initialized():
210 | return False
211 | return True
212 |
213 | def get_world_size():
214 | if not is_dist_avail_and_initialized():
215 | return 1
216 | return dist.get_world_size()
217 |
218 | def get_rank():
219 | if not is_dist_avail_and_initialized():
220 | return 0
221 | return dist.get_rank()
222 |
223 | def is_main_process():
224 | return get_rank() == 0
225 |
226 | def save_on_master(*args, **kwargs):
227 | if is_main_process():
228 | torch.save(*args, **kwargs)
229 |
230 | def init_distributed_mode(args):
231 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
232 | args.rank = int(os.environ["RANK"])
233 | args.world_size = int(os.environ['WORLD_SIZE'])
234 | args.gpu = int(os.environ['LOCAL_RANK'])
235 | elif 'SLURM_PROCID' in os.environ:
236 | args.rank = int(os.environ['SLURM_PROCID'])
237 | args.gpu = args.rank % torch.cuda.device_count()
238 | else:
239 | print('Not using distributed mode')
240 | args.distributed = False
241 | return
242 |
243 | args.distributed = True
244 |
245 | torch.cuda.set_device(args.gpu)
246 | args.dist_backend = 'nccl'
247 | print('| distributed init (rank {}): {}'.format(
248 | args.rank, args.dist_url), flush=True)
249 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
250 | world_size=args.world_size, rank=args.rank)
251 | torch.distributed.barrier()
252 | setup_for_distributed(args.rank == 0)
253 |
254 |
255 | def replace_batchnorm(net):
256 | for child_name, child in net.named_children():
257 | if hasattr(child, 'fuse'):
258 | setattr(net, child_name, child.fuse())
259 | elif isinstance(child, torch.nn.BatchNorm2d):
260 | setattr(net, child_name, torch.nn.Identity())
261 | else:
262 | replace_batchnorm(child)
263 |
264 | def replace_layernorm(net):
265 | import apex
266 | for child_name, child in net.named_children():
267 | if isinstance(child, torch.nn.LayerNorm):
268 | setattr(net, child_name, apex.normalization.FusedLayerNorm(
269 | child.weight.size(0)))
270 | else:
271 | replace_layernorm(child)
272 |
273 | def load_model(modelpath, model: nn.Module):
274 | '''
275 | A function to load model from a checkpoint, which is used
276 | for fine-tuning on a different resolution.
277 | '''
278 | if 'safetensors' in modelpath:
279 | checkpoint = safetensors.torch.load_file(modelpath)
280 | else:
281 | checkpoint = torch.load(modelpath, map_location='cpu')
282 | return checkpoint
283 |
284 |
285 | def map_safetensors(safetensor_ckpt, model_state_dict):
286 | '''
287 | A function to load model from a safetensor file, which is used
288 | for fine-tuning on a different resolution.
289 | '''
290 | safetensors_keys = list(safetensor_ckpt.keys())
291 | key_mapping = {}
292 | mismatched_keys = []
293 |
294 | for model_key in model_state_dict.keys():
295 | # 尝试在 safetensors_keys 中找到与模型键类似的键
296 | for safetensor_key in safetensors_keys:
297 | if model_key.split('.')[-1] == safetensor_key.split('.')[-1] and \
298 | model_state_dict[model_key].shape == safetensor_ckpt[safetensor_key].shape:
299 | key_mapping[model_key] = safetensor_key
300 | # print(f"Mapping model layer '{model_key}' to safetensors layer '{safetensor_key}'")
301 | break
302 | else:
303 | # 如果没有找到匹配的层,则记录为不匹配
304 | mismatched_keys.append(model_key)
305 |
306 | # 显示所有未匹配的模型键
307 | if mismatched_keys:
308 | print("\nUnmatched model keys:")
309 | for key in mismatched_keys:
310 | print(key)
311 |
312 | # 创建一个新的 state_dict,将 safetensors 文件中的权重映射到模型中
313 | mapped_state_dict = {}
314 | for model_key, safetensor_key in key_mapping.items():
315 | mapped_state_dict[model_key] = safetensor_ckpt[safetensor_key]
316 |
317 | return mapped_state_dict
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | import json
2 | import matplotlib
3 | import numpy as np
4 |
5 | from torchvision import transforms
6 | from PIL import Image
7 |
8 | import torch
9 | from matplotlib import pyplot as plt
10 | import os
11 | from timm.models import create_model
12 |
13 | import urllib.request
14 |
15 |
16 | device = 'cuda'
17 |
18 |
19 | def download_from_url(url, path=None, root="./"):
20 | if path is None:
21 | _, filename = os.path.split(url)
22 | root = os.path.abspath(root)
23 | path = os.path.join(root, filename)
24 | urllib.request.urlretrieve(url, path)
25 | print(f"Downloaded file to {path}")
26 |
27 |
28 | def load_class_names(json_path):
29 | with open(json_path, "r") as f:
30 | return list(json.load(f).values())
31 |
32 | def preprocess_image(image_path):
33 | image = Image.open(image_path).convert("RGB")
34 | transform = transforms.Compose([
35 | transforms.Resize((224, 224), interpolation=3),
36 | transforms.ToTensor(),
37 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalized
38 | ])
39 | return transform(image)
40 |
41 |
42 | @torch.inference_mode()
43 | def predict_probs_for_image(model, image_path):
44 | image = preprocess_image(image_path).unsqueeze(0) # add batch dim
45 | model.eval()
46 | outputs = model(image.to(device))
47 | probs = torch.nn.functional.softmax(outputs, dim=1).cpu()
48 | return (probs[0] * 100).tolist()
49 |
50 |
51 | def plot_probs(texts, probs, fig_ax, lang_type=None, save_path=None):
52 | # reverse the order to plot from top to bottom
53 | sorted_indices = np.argsort(probs)
54 | texts = np.array(texts)[sorted_indices]
55 | probs = np.array(probs)[sorted_indices]
56 | if fig_ax is None:
57 | fig, ax = plt.subplots(figsize=(6, 3))
58 | else:
59 | fig, ax = fig_ax
60 |
61 | font_prop = matplotlib.font_manager.FontProperties(
62 | fname=lang_type_to_font_path(lang_type)
63 | )
64 | ax.barh(texts, probs, color="darkslateblue", height=0.3)
65 | ax.barh(texts, 100 - probs, color="silver", height=0.3, left=probs)
66 | for bar, label, val in zip(ax.patches, texts, probs):
67 | ax.text(
68 | 0,
69 | bar.get_y() - bar.get_height(),
70 | label,
71 | color="black",
72 | ha="left",
73 | va="center",
74 | fontproperties=font_prop,
75 | )
76 | ax.text(
77 | bar.get_x() + bar.get_width() + 1,
78 | bar.get_y() + bar.get_height() / 2,
79 | f"{val:.2f} %",
80 | fontweight="bold",
81 | ha="left",
82 | va="center",
83 | )
84 |
85 | ax.axis("off")
86 |
87 | if save_path:
88 | plt.savefig(save_path, bbox_inches="tight") # 保存图片并移除多余空白
89 | print(f"Figure saved to {save_path}")
90 |
91 |
92 | def predict_probs_and_plot(
93 | model, image_path, texts, plot_image=True, fig_ax=None, lang_type=None
94 | ):
95 | if plot_image:
96 | fig, (ax_1, ax_2) = plt.subplots(1, 2, figsize=(12, 6))
97 | image = Image.open(image_path).convert('RGB')
98 | ax_1.imshow(image)
99 | ax_1.axis("off")
100 | probs = predict_probs_for_image(model, image_path)
101 | plot_probs(texts, probs, (fig, ax_2), lang_type=lang_type, save_path='./prediction_probs.png')
102 |
103 |
104 | def lang_type_to_font_path(lang_type):
105 | mapping = {
106 | None: "https://cdn.jsdelivr.net/gh/notofonts/notofonts.github.io/fonts/NotoSans/hinted/ttf/NotoSans-Regular.ttf",
107 | "cjk": "https://github.com/notofonts/noto-cjk/raw/main/Sans/OTF/SimplifiedChinese/NotoSansCJKsc-Regular.otf",
108 | "devanagari": "https://cdn.jsdelivr.net/gh/notofonts/notofonts.github.io/fonts/NotoSansDevanagari/hinted/ttf/NotoSansDevanagari-Regular.ttf",
109 | "emoji": "https://github.com/MorbZ/OpenSansEmoji/raw/master/OpenSansEmoji.ttf",
110 | }
111 | return download_from_url(mapping[lang_type])
112 |
113 | if __name__ == '__main__':
114 | model = create_model(
115 | 'mobilenetv4_conv_large'
116 | )
117 | model.reset_classifier(num_classes=5)
118 | model.load_state_dict(torch.load('./output/mobilenetv4_conv_large_best_checkpoint.pth')['model'])
119 | model.to(device)
120 |
121 | texts = load_class_names('./classes_indices.json')
122 | # image_path = r'D:/flower_data/roses/1666341535_99c6f7509f_n.jpg'
123 | image_path = r'D:/flower_data/sunflowers/44079668_34dfee3da1_n.jpg'
124 | predict_probs_and_plot(model, image_path, texts)
--------------------------------------------------------------------------------
/weight_converter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def transpose_weights(weights):
4 | if len(weights.shape) <= 1:
5 | return weights
6 | if len(weights.shape) == 2:
7 | return weights.T
8 | if len(weights.shape) == 3:
9 | return np.transpose(weights, [2, 1, 0])
10 | else:
11 | raise ValueError("Unknown weights shape : {}".format(weights.shape))
12 |
13 | """ Pytorch to Tensorflow convertion """
14 |
15 | def get_pt_layers(pt_model):
16 | layers = {}
17 | state_dict = pt_model.state_dict() if not isinstance(pt_model, dict) else pt_model
18 | for k, v in state_dict.items():
19 | layer_name = '.'.join(k.split('.')[:-1])
20 | if layer_name not in layers: layers[layer_name] = []
21 | layers[layer_name].append(v.cpu().numpy())
22 | return layers
23 |
24 | def pt_convert_layer_weights(layer_weights):
25 | new_weights = []
26 | if len(layer_weights) < 4:
27 | new_weights = layer_weights
28 | elif len(layer_weights) == 4:
29 | new_weights = layer_weights[:2] + [layer_weights[2] + layer_weights[3]]
30 | elif len(layer_weights) == 5:
31 | new_weights = layer_weights[:4]
32 | elif len(layer_weights) == 8:
33 | new_weights = layer_weights[:2] + [layer_weights[2] + layer_weights[3]]
34 | new_weights += layer_weights[4:6] + [layer_weights[6] + layer_weights[7]]
35 | else:
36 | raise ValueError("Unknown weights length : {}\n Shapes : {}".format(len(layer_weights), [tuple(v.shape) for v in layer_weights]))
37 |
38 | return [transpose_weights(w) for w in new_weights]
39 |
40 | def pt_convert_model_weights(pt_model, tf_model, verbose = False):
41 | pt_layers = get_pt_layers(pt_model)
42 | converted_weights = []
43 | for layer_name, layer_variables in pt_layers.items():
44 | converted_variables = pt_convert_layer_weights(layer_variables) if 'embedding' not in layer_name else layer_variables
45 | converted_weights += converted_variables
46 |
47 | if verbose:
48 | print("Layer : {} \t {} \t {}".format(
49 | layer_name,
50 | [tuple(v.shape) for v in layer_variables],
51 | [tuple(v.shape) for v in converted_variables],
52 | ))
53 |
54 | partial_transfert_learning(tf_model, converted_weights)
55 | print("Weights converted successfully !")
56 |
57 |
58 | """ Tensorflow to Pytorch converter """
59 |
60 | def get_tf_layers(tf_model):
61 | layers = {}
62 | variables = tf_model.variables if not isinstance(tf_model, list) else tf_model
63 | for v in variables:
64 | layer_name = '/'.join(v.name.split('/')[:-1])
65 | if layer_name not in layers: layers[layer_name] = []
66 | layers[layer_name].append(v.numpy())
67 | return layers
68 |
69 | def tf_convert_layer_weights(layer_weights):
70 | new_weights = []
71 | if len(layer_weights) < 3 or len(layer_weights) == 4:
72 | new_weights = layer_weights
73 | elif len(layer_weights) == 3:
74 | new_weights = layer_weights[:2] + [layer_weights[2] / 2., layer_weights[2] / 2.]
75 | else:
76 | raise ValueError("Unknown weights length : {}\n Shapes : {}".format(len(layer_weights), [tuple(v.shape) for v in layer_weights]))
77 |
78 | return [transpose_weights(w) for w in new_weights]
79 |
80 |
81 | def tf_convert_model_weights(tf_model, pt_model, verbose = False):
82 | import torch
83 |
84 | pt_layers = pt_model.state_dict()
85 | tf_layers = get_tf_layers(tf_model)
86 | converted_weights = []
87 | for layer_name, layer_variables in tf_layers.items():
88 | converted_variables = tf_convert_layer_weights(layer_variables) if 'embedding' not in layer_name else layer_variables
89 | converted_weights += converted_variables
90 |
91 | if verbose:
92 | print("Layer : {} \t {} \t {}".format(
93 | layer_name,
94 | [tuple(v.shape) for v in layer_variables],
95 | [tuple(v.shape) for v in converted_variables],
96 | ))
97 |
98 | tf_idx = 0
99 | for i, (pt_name, pt_weights) in enumerate(pt_layers.items()):
100 | if len(pt_weights.shape) == 0: continue
101 |
102 | pt_weights.data = torch.from_numpy(converted_weights[tf_idx])
103 | tf_idx += 1
104 |
105 | pt_model.load_state_dict(pt_layers)
106 | print("Weights converted successfully !")
107 |
108 | """ Partial transfert learning """
109 |
110 | def partial_transfert_learning(target_model,
111 | pretrained_model,
112 | partial_transfert = True,
113 | partial_initializer = 'normal_conditionned'
114 | ):
115 | """
116 | Make transfert learning on model with either :
117 | - different number of layers (and same shapes for some layers)
118 | - different shapes (and same number of layers)
119 |
120 | Arguments :
121 | - target_model : tf.keras.Model instance (model where weights will be transfered to)
122 | - pretrained_model : tf.keras.Model or list of weights (pretrained)
123 | - partial_transfert : whether to do partial transfert for layers with different shapes (only relevant if 2 models have same number of layers)
124 | """
125 | assert partial_initializer in (None, 'zeros', 'ones', 'normal', 'normal_conditionned')
126 | def partial_weight_transfert(target, pretrained_v):
127 | v = target
128 | if partial_initializer == 'zeros':
129 | v = np.zeros_like(target)
130 | elif partial_initializer == 'ones':
131 | v = np.ones_like(target)
132 | elif partial_initializer == 'normal_conditionned':
133 | v = np.random.normal(loc = np.mean(pretrained_v), scale = np.std(pretrained_v), size = target.shape)
134 | elif partial_initializer == 'normal':
135 | v = np.random.normal(size = target.shape)
136 |
137 |
138 | if v.ndim == 1:
139 | max_0 = min(v.shape[0], pretrained_v.shape[0])
140 | v[:max_0] = pretrained_v[:max_0]
141 | elif v.ndim == 2:
142 | max_0 = min(v.shape[0], pretrained_v.shape[0])
143 | max_1 = min(v.shape[1], pretrained_v.shape[1])
144 | v[:max_0, :max_1] = pretrained_v[:max_0, :max_1]
145 | elif v.ndim == 3:
146 | max_0 = min(v.shape[0], pretrained_v.shape[0])
147 | max_1 = min(v.shape[1], pretrained_v.shape[1])
148 | max_2 = min(v.shape[2], pretrained_v.shape[2])
149 | v[:max_0, :max_1, :max_2] = pretrained_v[:max_0, :max_1, :max_2]
150 | elif v.ndim == 4:
151 | max_0 = min(v.shape[0], pretrained_v.shape[0])
152 | max_1 = min(v.shape[1], pretrained_v.shape[1])
153 | max_2 = min(v.shape[2], pretrained_v.shape[2])
154 | max_3 = min(v.shape[3], pretrained_v.shape[3])
155 | v[:max_0, :max_1, :max_2, :max_3] = pretrained_v[:max_0, :max_1, :max_2, :max_3]
156 | else:
157 | raise ValueError("Variable dims > 4 non géré !")
158 |
159 | return v
160 |
161 | target_variables = target_model.variables
162 | pretrained_variables = pretrained_model.variables if not isinstance(pretrained_model, list) else pretrained_model
163 |
164 | skip_layer = len(target_variables) != len(pretrained_variables)
165 | skip_from_a = None
166 | if skip_layer:
167 | skip_from_a = (len(target_variables) > len(pretrained_variables))
168 |
169 | new_weights = []
170 | idx_a, idx_b = 0, 0
171 | while idx_a < len(target_variables) and idx_b < len(pretrained_variables):
172 | v, pretrained_v = target_variables[idx_a], pretrained_variables[idx_b]
173 | v = v.numpy()
174 | if not isinstance(pretrained_v, np.ndarray) : pretrained_v = pretrained_v.numpy()
175 |
176 | if v.shape != pretrained_v.shape and skip_layer:
177 | if skip_from_a:
178 | idx_a += 1
179 | new_weights.append(v)
180 | else: idx_b += 1
181 | continue
182 |
183 | if len(v.shape) != len(pretrained_v.shape):
184 | raise ValueError("Le nombre de dimension des variables {} est différent !\n Target shape : {}\n Pretrained shape : {}".format(idx_a, v.shape, pretrained_v.shape))
185 |
186 | new_v = None
187 | if v.shape == pretrained_v.shape:
188 | new_v = pretrained_v
189 | elif not partial_transfert:
190 | print("Variables {} shapes mismatch ({} vs {}), skipping it".format(idx_a, v.shape, pretrained_v.shape))
191 |
192 | new_v = v
193 | else:
194 | print("Variables {} shapes mismatch ({} vs {}), making partial transfert".format(idx_a, v.shape, pretrained_v.shape))
195 |
196 | new_v = partial_weight_transfert(v, pretrained_v)
197 |
198 | new_weights.append(new_v)
199 | idx_a, idx_b = idx_a + 1, idx_b + 1
200 |
201 | if idx_a != len(target_variables) or idx_b != len(pretrained_variables):
202 | raise ValueError("All variables of a model have not been consummed\n Model A : length : {} - variables consummed : {}\n Model B (pretrained) : length : {} - variables consummed : {}".format(len(target_variables), idx_a, len(pretrained_variables), idx_b))
203 |
204 | target_model.set_weights(new_weights)
205 | print("Weights transfered successfully !")
206 |
--------------------------------------------------------------------------------