├── .gitignore ├── LICENSE ├── README.md ├── cal_inference_time.py ├── datasets ├── __init__.py ├── mydataset.py ├── split_data.py └── threeaugment.py ├── environment.yml ├── estimate_model.py ├── models ├── __init__.py ├── blocks.py ├── build_mobilenet_v4.py ├── extra_attention_block.py └── model_utils.py ├── onnx_export.py ├── onnx_optimise.py ├── onnx_validate.py ├── optim_AUC.py ├── prediction_probs.png ├── sample_png └── mobilenetV4.jpg ├── train_gpu.py ├── util ├── __init__.py ├── engine.py ├── losses.py ├── optimizer.py ├── samplers.py └── utils.py ├── visualize.py └── weight_converter.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | 7 | # ckpt files 8 | *.safetensors 9 | *.bin 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 奔波儿灞 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

MobileNetV4

2 | 3 | # [MobileNetV4 -- Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518) 4 | ## This project is implemented in PyTorch and can be used to train your own image datasets for vision tasks. 5 | ## [official source code](https://github.com/tensorflow/models/blob/master/official/vision/modeling/backbones/mobilenet.py) 6 | ## For segmentation tasks, please refer to this [GitHub repository](https://github.com/jiaowoguanren0615/Segmentation_Factory/blob/main/models/backbones/mobilenetv4.py) 7 | ## For detection tasks (___based on the DETR detector architecture___), please refer to this [GitHub repository](https://github.com/jiaowoguanren0615/Detection-Factory/blob/main/configs/salience_detr_mobilenetv4_medium_800_1333.py) 8 | ![image](https://github.com/jiaowoguanren0615/MobileNetV4/blob/main/sample_png/mobilenetV4.jpg) 9 | 10 | 11 | 12 | ## Preparation 13 | 14 | ### Create the conda virtual environment 15 | ```bash 16 | conda env create -f environment.yml 17 | ``` 18 | 19 | ### Download the dataset: 20 | [flower_dataset](https://www.kaggle.com/datasets/alxmamaev/flowers-recognition). 21 | 22 | ## Project Structure 23 | ``` 24 | ├── datasets: Load datasets 25 | ├── mydataset.py: Custom dataset class and the transforms (data augmentation) applied when reading images 26 | ├── split_data.py: Function that reads the image dataset and splits it into a training set and a validation set 27 | ├── threeaugment.py: Additional data augmentation methods 28 | ├── models: MobileNetV4 model 29 | ├── build_mobilenet_v4.py: Construct the MobileNetV4 models 30 | ├── extra_attention_block.py: MultiScaleAttentionGate module 31 | ├── util: 32 | ├── engine.py: Training/validation loop for one epoch 33 | ├── losses.py: Knowledge distillation loss, combined with a teacher model (if any) 34 | ├── optimizer.py: Define the Sophia/MARS optimizers 35 | ├── samplers.py: Define the "sampler" argument used by the DataLoader 36 | ├── utils.py: Metric logging, output helpers and distributed-training utilities 37 | ├── estimate_model.py: Visual evaluation metrics: ROC curve, confusion matrix, classification report, etc. 38 | └── train_gpu.py: Entry point for training the model (including the inference process) 39 | ``` 40 | 41 | ## Precautions 42 | Before using this code to train your own dataset, open the ___train_gpu.py___ file and modify the ___data_root___, ___batch_size___, ___num_workers___ and ___nb_classes___ parameters. If you want to plot the confusion matrix and ROC curve, simply set the ___predict___ parameter to __True__. 43 | If you want to add an extra MSAG (MultiScaleAttentionGate) module, set the __extra_attention_block__ parameter to True. 44 | Moreover, you can set the ___opt_auc___ parameter to True if you want to post-optimize the predicted probabilities for a (possibly) better AUC. 45 | 46 | ## Use Sophia Optimizer (in util/optimizer.py) 47 | You can also use the Sophia optimizer; just change the optimizer in ___train_gpu.py___. For this training sample it can achieve better results: 48 | ``` 49 | # optimizer = create_optimizer(args, model_without_ddp) 50 | optimizer = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99), rho=0.01, weight_decay=args.weight_decay) 51 | ``` 52 | 53 | ## Train this model 54 | 55 | ### Parameters Meaning: 56 | ``` 57 | 1. nproc_per_node: number of processes (usually one per GPU) launched on each node 58 | 2. CUDA_VISIBLE_DEVICES: indices of the GPUs visible to the training processes 59 | 3. nnodes: total number of machines (nodes) taking part in the training 60 | 4. node_rank: rank of the current machine, starting from 0 on the master node 61 | 5. master_addr: IP address of the master (rank-0) node 62 | 6. master_port: a free port on the master node used for inter-node communication 63 | ``` 64 | ### Transfer Learning: 65 | Step 1: Download the [pretrained-weights](https://huggingface.co/timm/mobilenetv4_conv_large.e500_r256_in1k#model-comparison) 66 | Step 2: Write the ___pre-training weight path___ into ___args.finetune___ as a string. Adjust the ___args.input_size___ parameter according to the image size the model was pre-trained on. 67 | Step 3: Set ___args.freeze_layers___ according to your GPU memory. If you do not have enough memory, set it to True to freeze all layers except the final classification head, so their parameters are not updated. If you have enough memory, set it to False and train the full model. 68 | 69 | #### Here is an example for setting parameters: 70 | ![image](https://github.com/jiaowoguanren0615/VisionTransformer/blob/main/sample_png/transfer_learning.jpg) 71 | 72 | ### Note: 73 | If you train on multiple GPUs, whether on a single machine or across several machines, the ___batch_size___ is applied per GPU. For example, with batch_size=4 in train_gpu.py and 2 GPUs, each GPU processes 4 samples per step (an effective batch size of 8). ___Do not let batch_size=1 on each GPU___, otherwise the BatchNorm layers may raise an error. 74 | 75 | ### train model with single-machine single-GPU: 76 | ``` 77 | python train_gpu.py 78 | ``` 79 | 80 | ### train model with single-machine multi-GPU: 81 | ``` 82 | python -m torch.distributed.run --nproc_per_node=8 train_gpu.py 83 | ``` 84 | 85 | ### train model with single-machine multi-GPU, using only some of the GPUs: 86 | (for example, to use the second and fourth GPUs) 87 | ``` 88 | CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.run --nproc_per_node=2 train_gpu.py 89 | ``` 90 | 91 | ### train model with multi-machine multi-GPU: 92 | (Set --nproc_per_node to the number of GPUs on each machine. If you want to use only specific GPUs, prefix each command with CUDA_VISIBLE_DEVICES= and the desired GPU indices. 
The principle is the same as single-machine multi-GPU training) 93 | ``` 94 | On the first machine: python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 --node_rank=0 --master_addr= --master_port= train_gpu.py 95 | 96 | On the second machine: python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 --node_rank=1 --master_addr= --master_port= train_gpu.py 97 | ``` 98 | 99 | ## ONNX Deployment 100 | ### step 1: ONNX export (modify the param of ___output___, ___model___ and ___checkpoint___) 101 | ```bash 102 | python onnx_export.py --model=mobilenetv4_small --output=./mobilenetv4_small.onnx --checkpoint=./output/mobilenetv4_small_best_checkpoint.pth 103 | ``` 104 | 105 | ### step2: ONNX optimise 106 | ```bash 107 | python onnx_optimise.py --model=mobilenetv4_small --output=./mobilenetv4_small_optim.onnx' 108 | ``` 109 | 110 | ### step3: ONNX validate (modify the param of ___data_root___ and ___onnx-input___) 111 | ```bash 112 | python onnx_validate.py --data_root=/mnt/d/flower_data --onnx-input=./mobilenetv4_small_optim.onnx 113 | ``` 114 | 115 | 116 | ## Citation 117 | ``` 118 | @article{qin2024mobilenetv4, 119 | title={MobileNetV4-Universal Models for the Mobile Ecosystem}, 120 | author={Qin, Danfeng and Leichner, Chas and Delakis, Manolis and Fornoni, Marco and Luo, Shixin and Yang, Fan and Wang, Weijun and Banbury, Colby and Ye, Chengxi and Akin, Berkin and others}, 121 | journal={arXiv preprint arXiv:2404.10518}, 122 | year={2024} 123 | } 124 | ``` 125 | 126 | ## Star History 127 | 128 | [![Star History Chart](https://api.star-history.com/svg?repos=jiaowoguanren0615/MobileNetV4&type=Date)](https://star-history.com/#jiaowoguanren0615/MobileNetV4&Date) 129 | -------------------------------------------------------------------------------- /cal_inference_time.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import argparse 5 | import models 6 | import numpy as np 7 | from timm.models import create_model 8 | 9 | 10 | parser = argparse.ArgumentParser(description='PyTorch MobileNetV4 Inference Speed Test') 11 | # Model params 12 | parser.add_argument('--model', default='mobilenetv4_conv_large', type=str, metavar='MODEL', 13 | choices=['mobilenetv4_hybrid_large', 'mobilenetv4_hybrid_medium', 'mobilenetv4_hybrid_large_075', 14 | 'mobilenetv4_conv_large', 'mobilenetv4_conv_aa_large', 'mobilenetv4_conv_medium', 15 | 'mobilenetv4_conv_aa_medium', 'mobilenetv4_conv_small', 'mobilenetv4_hybrid_medium_075', 16 | 'mobilenetv4_conv_small_035', 'mobilenetv4_conv_small_050', 'mobilenetv4_conv_blur_medium'], 17 | help='Name of model to train') 18 | parser.add_argument('--device', default='cuda', type=str) 19 | parser.add_argument('--batch-size', default=32, type=int, help='batch size (default: 32)') 20 | parser.add_argument('--img-size', default=224, type=int, 21 | metavar='N', help='Input image dimension, uses model default if empty') 22 | parser.add_argument('--nb-classes', type=int, default=5, 23 | help='Number classes in datasets') 24 | 25 | 26 | def do_pure_cpu_task(): 27 | x = np.random.randn(1, 3, 512, 512).astype(np.float32) 28 | x = x * 1024 ** 0.5 29 | 30 | 31 | @torch.inference_mode() 32 | def cal_time3(model, x, args): 33 | start_event = torch.cuda.Event(enable_timing=True) 34 | end_event = torch.cuda.Event(enable_timing=True) 35 | time_list = [] 36 | for _ in range(50): 37 | # do_pure_cpu_task() ## cpu warm up, not necessary 38 | start_event.record() 39 | ret = model(x) 40 | 
end_event.record() 41 | end_event.synchronize() 42 | time_list.append(start_event.elapsed_time(end_event) / 1000) 43 | 44 | print(f"{args.model} inference avg time: {sum(time_list[5:]) / len(time_list[5:]):.5f}") ## warm up, remove start 5 times 45 | 46 | 47 | def main(args): 48 | 49 | device = args.device 50 | model = create_model( 51 | args.model, 52 | num_classes=args.nb_classes 53 | ) 54 | model.eval().to(device) 55 | 56 | x = torch.randn(size=(args.batch_size, 3, args.img_size, args.img_size), device=device) 57 | cal_time3(model, x, args) 58 | 59 | 60 | if __name__ == '__main__': 61 | args = parser.parse_args() 62 | main(args) -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mydataset import build_dataset, build_transform, MyDataset 2 | from .split_data import read_split_data 3 | from .threeaugment import new_data_aug_generator -------------------------------------------------------------------------------- /datasets/mydataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from torchvision import transforms 4 | from .split_data import read_split_data 5 | from torch.utils.data import Dataset 6 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform 7 | 8 | 9 | class MyDataset(Dataset): 10 | def __init__(self, image_paths, image_labels, transforms=None): 11 | self.image_paths = image_paths 12 | self.image_labels = image_labels 13 | self.transforms = transforms 14 | 15 | def __getitem__(self, item): 16 | image = Image.open(self.image_paths[item]).convert('RGB') 17 | label = self.image_labels[item] 18 | if self.transforms: 19 | image = self.transforms(image) 20 | return image, label 21 | 22 | def __len__(self): 23 | return len(self.image_paths) 24 | 25 | @staticmethod 26 | def collate_fn(batch): 27 | images, labels = tuple(zip(*batch)) 28 | images = torch.stack(images, dim=0) 29 | labels = torch.as_tensor(labels) 30 | return images, labels 31 | 32 | 33 | 34 | def build_transform(is_train, args): 35 | resize_im = args.input_size > 32 36 | if is_train: 37 | # this should always dispatch to transforms_imagenet_train 38 | transform = create_transform( 39 | input_size=args.input_size, 40 | is_training=True, 41 | color_jitter=args.color_jitter, 42 | auto_augment=args.aa, 43 | interpolation=args.train_interpolation, 44 | re_prob=args.reprob, 45 | re_mode=args.remode, 46 | re_count=args.recount, 47 | ) 48 | if not resize_im: 49 | # replace RandomResizedCropAndInterpolation with 50 | # RandomCrop 51 | transform.transforms[0] = transforms.RandomCrop( 52 | args.input_size, padding=4) 53 | return transform 54 | 55 | t = [] 56 | if resize_im: 57 | # size = int((256 / 224) * args.input_size) 58 | size = int((1.0 / 0.96) * args.input_size) 59 | t.append( 60 | # to maintain same ratio w.r.t. 
224 images 61 | transforms.Resize(size, interpolation=3), 62 | ) 63 | t.append(transforms.CenterCrop(args.input_size)) 64 | 65 | t.append(transforms.ToTensor()) 66 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 67 | return transforms.Compose(t) 68 | 69 | 70 | def build_dataset(args): 71 | train_image_path, train_image_label, val_image_path, val_image_label, class_indices = read_split_data(args.data_root) 72 | 73 | train_transform = build_transform(True, args) 74 | valid_transform = build_transform(False, args) 75 | 76 | train_set = MyDataset(train_image_path, train_image_label, train_transform) 77 | valid_set = MyDataset(val_image_path, val_image_label, valid_transform) 78 | 79 | return train_set, valid_set 80 | 81 | -------------------------------------------------------------------------------- /datasets/split_data.py: -------------------------------------------------------------------------------- 1 | import os, cv2, json, random 2 | import pandas as pd 3 | from tqdm import tqdm 4 | from sklearn.model_selection import train_test_split 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def read_split_data(root, plot_image=False): 9 | filepaths = [] 10 | labels = [] 11 | bad_images = [] 12 | 13 | random.seed(0) 14 | assert os.path.exists(root), 'Your root does not exists!!!' 15 | 16 | classes = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] 17 | classes.sort() 18 | class_indices = {k: v for v, k in enumerate(classes)} 19 | 20 | json_str = json.dumps({v: k for k, v in class_indices.items()}, indent=4) 21 | 22 | with open('./classes_indices.json', 'w') as json_file: 23 | json_file.write(json_str) 24 | 25 | every_class_num = [] 26 | supported = ['.jpg', '.png', '.jpeg', '.PNG', '.JPG', '.JPEG'] 27 | 28 | for klass in classes: 29 | classpath = os.path.join(root, klass) 30 | images = [os.path.join(root, klass, i) for i in os.listdir(classpath) if os.path.splitext(i)[-1] in supported] 31 | every_class_num.append(len(images)) 32 | flist = sorted(os.listdir(classpath)) 33 | desc = f'{klass:23s}' 34 | for f in tqdm(flist, ncols=110, desc=desc, unit='file', colour='blue'): 35 | fpath = os.path.join(classpath, f) 36 | fl = f.lower() 37 | index = fl.rfind('.') 38 | ext = fl[index:] 39 | if ext in supported: 40 | try: 41 | img = cv2.imread(fpath) 42 | filepaths.append(fpath) 43 | labels.append(klass) 44 | except: 45 | bad_images.append(fpath) 46 | print('defective image file: ', fpath) 47 | else: 48 | bad_images.append(fpath) 49 | 50 | Fseries = pd.Series(filepaths, name='filepaths') 51 | Lseries = pd.Series(labels, name='labels') 52 | df = pd.concat([Fseries, Lseries], axis=1) 53 | 54 | print(f'{len(df.labels.unique())} kind of images were found in the dataset') 55 | train_df, test_df = train_test_split(df, train_size=.8, shuffle=True, random_state=123, stratify=df['labels']) 56 | 57 | train_image_path = train_df['filepaths'].tolist() 58 | val_image_path = test_df['filepaths'].tolist() 59 | 60 | train_image_label = [class_indices[i] for i in train_df['labels'].tolist()] 61 | val_image_label = [class_indices[i] for i in test_df['labels'].tolist()] 62 | 63 | sample_df = train_df.sample(n=50, replace=False) 64 | ht, wt, count = 0, 0, 0 65 | for i in range(len(sample_df)): 66 | fpath = sample_df['filepaths'].iloc[i] 67 | try: 68 | img = cv2.imread(fpath) 69 | h = img.shape[0] 70 | w = img.shape[1] 71 | ht += h 72 | wt += w 73 | count += 1 74 | except: 75 | pass 76 | have = int(ht / count) 77 | wave = int(wt / count) 78 | aspect_ratio = have / wave 
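    # The height/width averages above are computed from a random sample of 50 training
    # images, so they are only a rough estimate of the dataset's native resolution;
    # they can serve as a sanity check when choosing the input_size used by build_transform.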
79 | print('{} images were found in the dataset.\n{} for training, {} for validation'.format( 80 | sum(every_class_num), len(train_image_path), len(val_image_path) 81 | )) 82 | print('average image height= ', have, ' average image width= ', wave, ' aspect ratio h/w= ', aspect_ratio) 83 | 84 | if plot_image: 85 | plt.bar(range(len(classes)), every_class_num, align='center') 86 | plt.xticks(range(len(classes)), classes) 87 | 88 | for i, v in enumerate(every_class_num): 89 | plt.text(x=i, y=v + 5, s=str(v), ha='center') 90 | 91 | plt.xlabel('image class') 92 | plt.ylabel('number of images') 93 | 94 | plt.title('class distribution') 95 | plt.show() 96 | 97 | return train_image_path, train_image_label, val_image_path, val_image_label, class_indices -------------------------------------------------------------------------------- /datasets/threeaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3Augment implementation from (https://github.com/facebookresearch/deit/blob/main/augment.py) 3 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino) 4 | and timm DA(https://github.com/rwightman/pytorch-image-models) 5 | Can be called by adding "--ThreeAugment" to the command line 6 | """ 7 | import torch 8 | from timm.data.transforms import str_to_pil_interp, RandomResizedCropAndInterpolation 9 | from torchvision import transforms 10 | import random 11 | 12 | from PIL import ImageFilter, ImageOps 13 | 14 | 15 | class GaussianBlur(object): 16 | """ 17 | Apply Gaussian Blur to the PIL image. 18 | """ 19 | 20 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.): 21 | self.prob = p 22 | self.radius_min = radius_min 23 | self.radius_max = radius_max 24 | 25 | def __call__(self, img): 26 | do_it = random.random() <= self.prob 27 | if not do_it: 28 | return img 29 | 30 | img = img.filter( 31 | ImageFilter.GaussianBlur( 32 | radius=random.uniform(self.radius_min, self.radius_max) 33 | ) 34 | ) 35 | return img 36 | 37 | 38 | class Solarization(object): 39 | """ 40 | Apply Solarization to the PIL image. 41 | """ 42 | 43 | def __init__(self, p=0.2): 44 | self.p = p 45 | 46 | def __call__(self, img): 47 | if random.random() < self.p: 48 | return ImageOps.solarize(img) 49 | else: 50 | return img 51 | 52 | 53 | class gray_scale(object): 54 | """ 55 | Apply Solarization to the PIL image. 56 | """ 57 | 58 | def __init__(self, p=0.2): 59 | self.p = p 60 | self.transf = transforms.Grayscale(3) 61 | 62 | def __call__(self, img): 63 | if random.random() < self.p: 64 | return self.transf(img) 65 | else: 66 | return img 67 | 68 | 69 | class horizontal_flip(object): 70 | """ 71 | Apply Solarization to the PIL image. 
72 | """ 73 | 74 | def __init__(self, p=0.2, activate_pred=False): 75 | self.p = p 76 | self.transf = transforms.RandomHorizontalFlip(p=1.0) 77 | 78 | def __call__(self, img): 79 | if random.random() < self.p: 80 | return self.transf(img) 81 | else: 82 | return img 83 | 84 | 85 | def new_data_aug_generator(args=None): 86 | img_size = args.input_size 87 | remove_random_resized_crop = False 88 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 89 | primary_tfl = [] 90 | scale = (0.08, 1.0) 91 | interpolation = 'bicubic' 92 | if remove_random_resized_crop: 93 | primary_tfl = [ 94 | transforms.Resize(img_size, interpolation=3), 95 | transforms.RandomCrop(img_size, padding=4, padding_mode='reflect'), 96 | transforms.RandomHorizontalFlip() 97 | ] 98 | else: 99 | primary_tfl = [ 100 | RandomResizedCropAndInterpolation( 101 | img_size, scale=scale, interpolation=interpolation), 102 | transforms.RandomHorizontalFlip() 103 | ] 104 | 105 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0), 106 | Solarization(p=1.0), 107 | GaussianBlur(p=1.0)])] 108 | 109 | if args.color_jitter is not None and not args.color_jitter == 0: 110 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter)) 111 | final_tfl = [ 112 | transforms.ToTensor(), 113 | transforms.Normalize( 114 | mean=torch.tensor(mean), 115 | std=torch.tensor(std)) 116 | ] 117 | return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: CV 2 | dependencies: 3 | - python=3.9 4 | - pip 5 | - pip: 6 | - onnx==1.13 7 | - onnxoptimizer==0.3.13 8 | - onnxruntime==1.18 9 | - matplotlib==3.5.1 10 | - numpy==1.23.0 11 | - opencv-contrib-python==4.7.0.72 12 | - opencv-python==4.7.0.72 13 | - openpyxl==3.1.2 14 | - pandas==1.5.3 15 | - pillow==9.3.0 16 | - terminaltables==3.1.10 17 | - plotly==5.14.1 18 | - scikit-learn==1.3.0 19 | - tensorboardx==2.6.2.2 20 | - timm==1.0.11 21 | - torch==2.0.1+cu118 22 | - torchaudio==2.0.2+cu118 23 | - torchinfo==1.7.2 24 | - torchvision==0.15.2+cu118 25 | - transformers==4.28.1 26 | - seaborn==0.13.2 27 | - safetensors==0.4.5 28 | - terminaltables==3.1.10 29 | -------------------------------------------------------------------------------- /estimate_model.py: -------------------------------------------------------------------------------- 1 | import torch, json, os 2 | import seaborn as sns 3 | from sklearn.metrics import auc, f1_score, roc_curve, classification_report, confusion_matrix, roc_auc_score 4 | from itertools import cycle 5 | from numpy import interp 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from PIL import Image 9 | from torchvision import transforms 10 | from typing import Iterable 11 | from optim_AUC import OptimizeAUC 12 | from terminaltables import AsciiTable 13 | 14 | 15 | @torch.inference_mode() 16 | def Plot_ROC(net: torch.nn.Module, val_loader: Iterable, save_name: str, device: torch.device): 17 | """ 18 | Plot ROC Curve 19 | 20 | Save the roc curve as an image file in the current directory 21 | 22 | Args: 23 | net (torch.nn.Module): The model to be evaluated. 24 | val_loader (Iterable): The data loader for the valid data. 25 | save_name (str): The file path of your model weights 26 | device (torch.device): The device used for training (CPU or GPU). 
27 | 28 | Returns: 29 | None 30 | """ 31 | 32 | try: 33 | json_file = open('./classes_indices.json', 'r') 34 | class_indict = json.load(json_file) 35 | except Exception as e: 36 | print(e) 37 | exit(-1) 38 | 39 | score_list = [] 40 | label_list = [] 41 | 42 | net.load_state_dict(torch.load(save_name)['model']) 43 | 44 | for i, data in enumerate(val_loader): 45 | images, labels = data 46 | images, labels = images.to(device), labels.to(device) 47 | outputs = torch.softmax(net(images), dim=1) 48 | score_tmp = outputs 49 | score_list.extend(score_tmp.detach().cpu().numpy()) 50 | label_list.extend(labels.cpu().numpy()) 51 | 52 | score_array = np.array(score_list) 53 | # convert label to one-hot form 54 | label_tensor = torch.tensor(label_list) 55 | label_tensor = label_tensor.reshape((label_tensor.shape[0], 1)) 56 | label_onehot = torch.zeros(label_tensor.shape[0], len(class_indict.keys())) 57 | label_onehot.scatter_(dim=1, index=label_tensor, value=1) 58 | label_onehot = np.array(label_onehot) 59 | 60 | print("score_array:", score_array.shape) # (batchsize, classnum) 61 | print("label_onehot:", label_onehot.shape) # torch.Size([batchsize, classnum]) 62 | 63 | # compute tpr and fpr for each label by using sklearn lib 64 | fpr_dict = dict() 65 | tpr_dict = dict() 66 | roc_auc_dict = dict() 67 | for i in range(len(class_indict.keys())): 68 | fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score_array[:, i]) 69 | roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i]) 70 | # micro 71 | fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score_array.ravel()) 72 | roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"]) 73 | 74 | # macro 75 | # First aggregate all false positive rates 76 | all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(len(class_indict.keys()))])) 77 | # Then interpolate all ROC curves at this points 78 | mean_tpr = np.zeros_like(all_fpr) 79 | 80 | for i in range(len(set(label_list))): 81 | mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i]) 82 | 83 | # Finally average it and compute AUC 84 | mean_tpr /= len(class_indict.keys()) 85 | fpr_dict["macro"] = all_fpr 86 | tpr_dict["macro"] = mean_tpr 87 | roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"]) 88 | 89 | # plot roc curve for each label 90 | plt.figure(figsize=(12, 12)) 91 | lw = 2 92 | 93 | plt.plot(fpr_dict["micro"], tpr_dict["micro"], 94 | label='micro-average ROC curve (area = {0:0.2f})' 95 | ''.format(roc_auc_dict["micro"]), 96 | color='deeppink', linestyle=':', linewidth=4) 97 | 98 | plt.plot(fpr_dict["macro"], tpr_dict["macro"], 99 | label='macro-average ROC curve (area = {0:0.2f})' 100 | ''.format(roc_auc_dict["macro"]), 101 | color='navy', linestyle=':', linewidth=4) 102 | 103 | colors = cycle(['aqua', 'darkorange', 'cornflowerblue']) 104 | for i, color in zip(range(len(class_indict.keys())), colors): 105 | plt.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw, 106 | label='ROC curve of class {0} (area = {1:0.2f})' 107 | ''.format(class_indict[str(i)], roc_auc_dict[i])) 108 | 109 | plt.plot([0, 1], [0, 1], 'k--', lw=lw, label='Chance', color='red') 110 | plt.xlim([0.0, 1.0]) 111 | plt.ylim([0.0, 1.05]) 112 | plt.xlabel('False Positive Rate') 113 | plt.ylabel('True Positive Rate') 114 | plt.title('Receiver operating characteristic to multi-class') 115 | plt.legend(loc="lower right") 116 | plt.savefig('./multi_classes_roc.png') 117 | # plt.show() 118 | 119 | 120 | @torch.inference_mode() 121 | def predict_single_image(model: torch.nn.Module, device: 
torch.device, weight_path: str): 122 | """ 123 | Predict Single Image. 124 | 125 | Save the prediction as an image file which including pred label and prob in the current directory 126 | 127 | Args: 128 | model (torch.nn.Module): The model to be evaluated. 129 | device (torch.device): The device used for training (CPU or GPU). 130 | weight_path (str): The model weights file 131 | 132 | Returns: 133 | None 134 | """ 135 | 136 | data_transform = { 137 | 'train': transforms.Compose([transforms.RandomResizedCrop(224), transforms.ToTensor(), 138 | transforms.RandomHorizontalFlip(), 139 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), 140 | 141 | 'valid': transforms.Compose([transforms.Resize((224, 224)), transforms.CenterCrop(224), 142 | transforms.ToTensor(), 143 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 144 | } 145 | 146 | img_transform = data_transform['valid'] 147 | 148 | # load image 149 | img_path = "rose.jpg" 150 | assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) 151 | img = Image.open(img_path) 152 | plt.imshow(img) 153 | # [N, C, H, W] 154 | img = img_transform(img) 155 | # expand batch dimension 156 | img = torch.unsqueeze(img, dim=0) 157 | 158 | # read class_indict 159 | json_path = './classes_indices.json' 160 | assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path) 161 | 162 | with open(json_path, "r") as f: 163 | class_indict = json.load(f) 164 | 165 | # load model weights 166 | 167 | assert os.path.exists(weight_path), "weight file dose not exist." 168 | model.load_state_dict(torch.load(weight_path, map_location=device)['model']) 169 | 170 | model.eval() 171 | # predict class 172 | output = torch.squeeze(model(img.to(device))).cpu() 173 | predict = torch.softmax(output, dim=0) 174 | predict_cla = torch.argmax(predict).numpy() 175 | 176 | print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], 177 | predict[predict_cla].numpy()) 178 | 179 | plt.title(print_res) 180 | for i in range(len(predict)): 181 | print("class: {:10} prob: {:.3}".format(class_indict[str(i)], 182 | predict[i].numpy())) 183 | plt.savefig(f'./pred_{img_path}') 184 | # plt.show() 185 | 186 | 187 | @torch.inference_mode() 188 | def Predictor(net: torch.nn.Module, test_loader: Iterable, save_name: str, device: torch.device): 189 | """ 190 | Evaluate the performance of the model on the given dataset. 191 | 192 | 1. This function will print the following metrics: 193 | - F1 score 194 | - Confusion matrix 195 | - Classification report 196 | 197 | 2. Save the confusion matrix as an image file in the current directory. 198 | 199 | Args: 200 | net (torch.nn.Module): The model to be evaluated. 201 | test_loader (Iterable): The data loader for the valid data. 202 | save_name (str): The file path of your model weights 203 | device (torch.device): The device used for training (CPU or GPU). 
204 | 205 | Returns: 206 | None 207 | """ 208 | 209 | try: 210 | json_file = open('./classes_indices.json', 'r') 211 | class_indict = json.load(json_file) 212 | except Exception as e: 213 | print(e) 214 | exit(-1) 215 | 216 | errors = 0 217 | y_pred, y_true = [], [] 218 | net.load_state_dict(torch.load(save_name)['model']) 219 | 220 | net.eval() 221 | 222 | for data in test_loader: 223 | images, labels = data 224 | images, labels = images.to(device), labels.to(device) 225 | preds = torch.argmax(torch.softmax(net(images), dim=1), dim=1) 226 | for i in range(len(preds)): 227 | y_pred.append(preds[i].cpu()) 228 | y_true.append(labels[i].cpu()) 229 | 230 | tests = len(y_pred) 231 | for i in range(tests): 232 | pred_index = y_pred[i] 233 | true_index = y_true[i] 234 | if pred_index != true_index: 235 | errors += 1 236 | 237 | acc = (1 - errors / tests) * 100 238 | print(f'there were {errors} errors in {tests} tests for an accuracy of {acc:6.2f}%') 239 | 240 | ypred = np.array(y_pred) 241 | ytrue = np.array(y_true) 242 | 243 | f1score = f1_score(ytrue, ypred, average='weighted') * 100 244 | 245 | print(f'The F1-score was {f1score:.3f}') 246 | class_count = len(list(class_indict.values())) 247 | classes = list(class_indict.values()) 248 | 249 | cm = confusion_matrix(ytrue, ypred) 250 | plt.figure(figsize=(16, 8)) 251 | plt.subplot(1, 2, 1) 252 | sns.heatmap(cm, annot=True, vmin=0, fmt='g', cmap='Blues', cbar=False) 253 | plt.xticks(np.arange(class_count) + .5, classes, rotation=45, fontsize=14) 254 | plt.yticks(np.arange(class_count) + .5, classes, rotation=0, fontsize=14) 255 | plt.xlabel("Predicted", fontsize=14) 256 | plt.ylabel("True", fontsize=14) 257 | plt.title("Confusion Matrix") 258 | 259 | plt.subplot(1, 2, 2) 260 | sns.heatmap(cm / np.sum(cm), annot=True, fmt='.1%') 261 | plt.xticks(np.arange(class_count) + .5, classes, rotation=45, fontsize=14) 262 | plt.yticks(np.arange(class_count) + .5, classes, rotation=0, fontsize=14) 263 | plt.xlabel('Predicted', fontsize=14) 264 | plt.ylabel('True', fontsize=14) 265 | plt.savefig('./confusion_matrix.png') 266 | # plt.show() 267 | 268 | clr = classification_report(y_true, y_pred, target_names=classes, digits=4) 269 | print("Classification Report:\n----------------------\n", clr) 270 | 271 | 272 | @torch.inference_mode() 273 | def OptAUC(net: torch.nn.Module, val_loader: Iterable, save_name: str, device: torch.device): 274 | """ 275 | Optimize model for improving AUC 276 | 277 | Print a table of initial and optimized AUC and F1-score. 278 | 279 | This function takes the initial and optimized AUC and F1-score, and generates 280 | an ASCII table to display the results. The table will have the following format: 281 | 282 | Optimize Results 283 | +----------------------+----------------------+----------------------+----------------------+ 284 | | Initial AUC | Initial F1-Score | Optimize AUC | Optimize F1-Score | 285 | +----------------------+----------------------+----------------------+----------------------+ 286 | | 0.654321 | 0.654321 | 0.876543 | 0.876543 | 287 | +----------------------+----------------------+----------------------+----------------------+ 288 | 289 | The optimized AUC and F1-score are obtained by using the `OptimizeAUC` class (in ./optim_AUC.py), which 290 | performs optimization on the initial metrics. 291 | 292 | Args: 293 | net (torch.nn.Module): The model to be evaluated. 294 | test_loader (Iterable): The data loader for the valid data. 
295 | save_name (str): The file path of your model weights 296 | device (torch.device): The device used for training (CPU or GPU). 297 | 298 | Returns: 299 | None 300 | """ 301 | 302 | score_list = [] 303 | label_list = [] 304 | 305 | net.load_state_dict(torch.load(save_name)['model']) 306 | 307 | for i, data in enumerate(val_loader): 308 | images, labels = data 309 | images, labels = images.to(device), labels.to(device) 310 | outputs = torch.softmax(net(images), dim=1) 311 | score_tmp = outputs 312 | score_list.extend(score_tmp.detach().cpu().numpy()) 313 | label_list.extend(labels.detach().cpu().numpy()) 314 | 315 | score_array = np.array(score_list) 316 | label_list = np.array(label_list) 317 | y_preds = np.argmax(score_array, axis=1) 318 | f1score = f1_score(label_list, y_preds, average='weighted') * 100 319 | auc_score = roc_auc_score(label_list, score_array, average='weighted', multi_class='ovo') 320 | 321 | opt_auc = OptimizeAUC() 322 | opt_auc.fit(score_array, label_list) 323 | opt_preds = opt_auc.predict(score_array) 324 | opt_y_preds = np.argmax(opt_preds, axis=1) 325 | opt_f1score = f1_score(label_list, opt_y_preds, average='weighted') * 100 326 | opt_auc_score = roc_auc_score(label_list, opt_preds, average='weighted', multi_class='ovo') 327 | 328 | TITLE = 'Optimize Results' 329 | TABLE_DATA = ( 330 | ('Initial AUC', 'Initial F1-Score', 'Optimize AUC', 'Optimize F1-Score'), 331 | ('{:.6f}'.format(auc_score), 332 | '{:.6f}'.format(f1score), 333 | '{:.6f}'.format(opt_auc_score), 334 | '{:.6f}'.format(opt_f1score) 335 | ), 336 | ) 337 | table_instance = AsciiTable(TABLE_DATA, TITLE) 338 | print(table_instance.table) 339 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import * 2 | from .model_utils import * 3 | from .build_mobilenet_v4 import mobilenetv4_hybrid_large, mobilenetv4_hybrid_medium, mobilenetv4_hybrid_large_075, \ 4 | mobilenetv4_conv_large, mobilenetv4_conv_aa_large, mobilenetv4_conv_medium, mobilenetv4_conv_aa_medium, \ 5 | mobilenetv4_conv_small, mobilenetv4_hybrid_medium_075, mobilenetv4_conv_small_035, \ 6 | mobilenetv4_conv_small_050, mobilenetv4_conv_blur_medium -------------------------------------------------------------------------------- /models/blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Optional, Type 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | from timm.layers import create_conv2d, DropPath, create_act_layer, create_aa, to_2tuple, LayerType,\ 8 | ConvNormAct, get_norm_act_layer, MultiQueryAttention2d, Attention2d 9 | 10 | 11 | 12 | __all__ = [ 13 | 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 14 | 'UniversalInvertedResidual', 'MobileAttention' 15 | ] 16 | 17 | ModuleType = Type[nn.Module] 18 | 19 | 20 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 21 | min_value = min_value or divisor 22 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 23 | # Make sure that round down does not go down by more than 10%. 
24 | if new_v < round_limit * v: 25 | new_v += divisor 26 | return new_v 27 | 28 | 29 | def num_groups(group_size: Optional[int], channels: int): 30 | if not group_size: # 0 or None 31 | return 1 # normal conv with 1 group 32 | else: 33 | # NOTE group_size == 1 -> depthwise conv 34 | assert channels % group_size == 0 35 | return channels // group_size 36 | 37 | 38 | class SqueezeExcite(nn.Module): 39 | """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family 40 | 41 | Args: 42 | in_chs (int): input channels to layer 43 | rd_ratio (float): ratio of squeeze reduction 44 | act_layer (nn.Module): activation layer of containing block 45 | gate_layer (Callable): attention gate function 46 | force_act_layer (nn.Module): override block's activation fn if this is set/bound 47 | rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs 48 | """ 49 | 50 | def __init__( 51 | self, 52 | in_chs: int, 53 | rd_ratio: float = 0.25, 54 | rd_channels: Optional[int] = None, 55 | act_layer: LayerType = nn.ReLU, 56 | gate_layer: LayerType = nn.Sigmoid, 57 | force_act_layer: Optional[LayerType] = None, 58 | rd_round_fn: Optional[Callable] = None, 59 | ): 60 | super(SqueezeExcite, self).__init__() 61 | if rd_channels is None: 62 | rd_round_fn = rd_round_fn or round 63 | rd_channels = rd_round_fn(in_chs * rd_ratio) 64 | act_layer = force_act_layer or act_layer 65 | self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True) 66 | self.act1 = create_act_layer(act_layer, inplace=True) 67 | self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True) 68 | self.gate = create_act_layer(gate_layer) 69 | 70 | def forward(self, x): 71 | x_se = x.mean((2, 3), keepdim=True) 72 | x_se = self.conv_reduce(x_se) 73 | x_se = self.act1(x_se) 74 | x_se = self.conv_expand(x_se) 75 | return x * self.gate(x_se) 76 | 77 | 78 | class ConvBnAct(nn.Module): 79 | """ Conv + Norm Layer + Activation w/ optional skip connection 80 | """ 81 | def __init__( 82 | self, 83 | in_chs: int, 84 | out_chs: int, 85 | kernel_size: int, 86 | stride: int = 1, 87 | dilation: int = 1, 88 | group_size: int = 0, 89 | pad_type: str = '', 90 | skip: bool = False, 91 | act_layer: LayerType = nn.ReLU, 92 | norm_layer: LayerType = nn.BatchNorm2d, 93 | aa_layer: Optional[LayerType] = None, 94 | drop_path_rate: float = 0., 95 | ): 96 | super(ConvBnAct, self).__init__() 97 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 98 | groups = num_groups(group_size, in_chs) 99 | self.has_skip = skip and stride == 1 and in_chs == out_chs 100 | use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation 101 | 102 | self.conv = create_conv2d( 103 | in_chs, out_chs, kernel_size, 104 | stride=1 if use_aa else stride, 105 | dilation=dilation, groups=groups, padding=pad_type) 106 | self.bn1 = norm_act_layer(out_chs, inplace=True) 107 | self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa) 108 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() 109 | 110 | def feature_info(self, location): 111 | if location == 'expansion': # output of conv after act, same as block coutput 112 | return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels) 113 | else: # location == 'bottleneck', block output 114 | return dict(module='', num_chs=self.conv.out_channels) 115 | 116 | def forward(self, x): 117 | shortcut = x 118 | x = self.conv(x) 119 | x = self.bn1(x) 120 | x = self.aa(x) 121 | if self.has_skip: 122 | x = self.drop_path(x) + shortcut 123 | return x 124 
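# Illustrative example of the two helpers above (added comment): with the default divisor
# of 8, make_divisible(27) first rounds to 24 and then bumps the result to 32 because
# 24 < 0.9 * 27; num_groups(group_size=1, channels=32) returns 32 (a depthwise convolution),
# while group_size=0 or None returns a single group (an ordinary convolution). The blocks
# below rely on these helpers to size their expansion and depthwise layers.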
| 125 | 126 | class DepthwiseSeparableConv(nn.Module): 127 | """ Depthwise-separable block 128 | Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion 129 | (factor of 1.0). This is an alternative to having a IR with an optional first pw conv. 130 | """ 131 | def __init__( 132 | self, 133 | in_chs: int, 134 | out_chs: int, 135 | dw_kernel_size: int = 3, 136 | stride: int = 1, 137 | dilation: int = 1, 138 | group_size: int = 1, 139 | pad_type: str = '', 140 | noskip: bool = False, 141 | pw_kernel_size: int = 1, 142 | pw_act: bool = False, 143 | s2d: int = 0, 144 | act_layer: LayerType = nn.ReLU, 145 | norm_layer: LayerType = nn.BatchNorm2d, 146 | aa_layer: Optional[LayerType] = None, 147 | se_layer: Optional[ModuleType] = None, 148 | drop_path_rate: float = 0., 149 | ): 150 | super(DepthwiseSeparableConv, self).__init__() 151 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 152 | self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip 153 | self.has_pw_act = pw_act # activation after point-wise conv 154 | use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation 155 | 156 | # Space to depth 157 | if s2d == 1: 158 | sd_chs = int(in_chs * 4) 159 | self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same') 160 | self.bn_s2d = norm_act_layer(sd_chs, sd_chs) 161 | dw_kernel_size = (dw_kernel_size + 1) // 2 162 | dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type 163 | in_chs = sd_chs 164 | use_aa = False # disable AA 165 | else: 166 | self.conv_s2d = None 167 | self.bn_s2d = None 168 | dw_pad_type = pad_type 169 | 170 | groups = num_groups(group_size, in_chs) 171 | 172 | self.conv_dw = create_conv2d( 173 | in_chs, in_chs, dw_kernel_size, 174 | stride=1 if use_aa else stride, 175 | dilation=dilation, padding=dw_pad_type, groups=groups) 176 | self.bn1 = norm_act_layer(in_chs, inplace=True) 177 | self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa) 178 | 179 | # Squeeze-and-excitation 180 | self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity() 181 | 182 | self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) 183 | self.bn2 = norm_act_layer(out_chs, inplace=True, apply_act=self.has_pw_act) 184 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() 185 | 186 | def feature_info(self, location): 187 | if location == 'expansion': # after SE, input to PW 188 | return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) 189 | else: # location == 'bottleneck', block output 190 | return dict(module='', num_chs=self.conv_pw.out_channels) 191 | 192 | def forward(self, x): 193 | shortcut = x 194 | if self.conv_s2d is not None: 195 | x = self.conv_s2d(x) 196 | x = self.bn_s2d(x) 197 | x = self.conv_dw(x) 198 | x = self.bn1(x) 199 | x = self.aa(x) 200 | x = self.se(x) 201 | x = self.conv_pw(x) 202 | x = self.bn2(x) 203 | if self.has_skip: 204 | x = self.drop_path(x) + shortcut 205 | return x 206 | 207 | 208 | class InvertedResidual(nn.Module): 209 | """ Inverted residual block w/ optional SE 210 | 211 | Originally used in MobileNet-V2 - https://arxiv.org/abs/1801.04381v4, this layer is often 212 | referred to as 'MBConv' for (Mobile inverted bottleneck conv) and is also used in 213 | * MNasNet - https://arxiv.org/abs/1807.11626 214 | * EfficientNet - https://arxiv.org/abs/1905.11946 215 | * MobileNet-V3 - https://arxiv.org/abs/1905.02244 216 | """ 217 | 218 | def __init__( 219 | 
self, 220 | in_chs: int, 221 | out_chs: int, 222 | dw_kernel_size: int = 3, 223 | stride: int = 1, 224 | dilation: int = 1, 225 | group_size: int = 1, 226 | pad_type: str = '', 227 | noskip: bool = False, 228 | exp_ratio: float = 1.0, 229 | exp_kernel_size: int = 1, 230 | pw_kernel_size: int = 1, 231 | s2d: int = 0, 232 | act_layer: LayerType = nn.ReLU, 233 | norm_layer: LayerType = nn.BatchNorm2d, 234 | aa_layer: Optional[LayerType] = None, 235 | se_layer: Optional[ModuleType] = None, 236 | conv_kwargs: Optional[Dict] = None, 237 | drop_path_rate: float = 0., 238 | ): 239 | super(InvertedResidual, self).__init__() 240 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 241 | conv_kwargs = conv_kwargs or {} 242 | self.has_skip = (in_chs == out_chs and stride == 1) and not noskip 243 | use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation 244 | 245 | # Space to depth 246 | if s2d == 1: 247 | sd_chs = int(in_chs * 4) 248 | self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same') 249 | self.bn_s2d = norm_act_layer(sd_chs, sd_chs) 250 | dw_kernel_size = (dw_kernel_size + 1) // 2 251 | dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type 252 | in_chs = sd_chs 253 | use_aa = False # disable AA 254 | else: 255 | self.conv_s2d = None 256 | self.bn_s2d = None 257 | dw_pad_type = pad_type 258 | 259 | mid_chs = make_divisible(in_chs * exp_ratio) 260 | groups = num_groups(group_size, mid_chs) 261 | 262 | # Point-wise expansion 263 | self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) 264 | self.bn1 = norm_act_layer(mid_chs, inplace=True) 265 | 266 | # Depth-wise convolution 267 | self.conv_dw = create_conv2d( 268 | mid_chs, mid_chs, dw_kernel_size, 269 | stride=1 if use_aa else stride, 270 | dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs) 271 | self.bn2 = norm_act_layer(mid_chs, inplace=True) 272 | self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa) 273 | 274 | # Squeeze-and-excitation 275 | self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() 276 | 277 | # Point-wise linear projection 278 | self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) 279 | self.bn3 = norm_act_layer(out_chs, apply_act=False) 280 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() 281 | 282 | def feature_info(self, location): 283 | if location == 'expansion': # after SE, input to PWL 284 | return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) 285 | else: # location == 'bottleneck', block output 286 | return dict(module='', num_chs=self.conv_pwl.out_channels) 287 | 288 | def forward(self, x): 289 | shortcut = x 290 | if self.conv_s2d is not None: 291 | x = self.conv_s2d(x) 292 | x = self.bn_s2d(x) 293 | x = self.conv_pw(x) 294 | x = self.bn1(x) 295 | x = self.conv_dw(x) 296 | x = self.bn2(x) 297 | x = self.aa(x) 298 | x = self.se(x) 299 | x = self.conv_pwl(x) 300 | x = self.bn3(x) 301 | if self.has_skip: 302 | x = self.drop_path(x) + shortcut 303 | return x 304 | 305 | 306 | class LayerScale2d(nn.Module): 307 | def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): 308 | super().__init__() 309 | self.inplace = inplace 310 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 311 | 312 | def forward(self, x): 313 | gamma = self.gamma.view(1, -1, 1, 1) 314 | return x.mul_(gamma) if self.inplace else x * gamma 315 | 
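# Usage sketch for LayerScale2d (illustrative comment only):
#
#   ls = LayerScale2d(dim=64, init_values=1e-5)
#   x = torch.randn(2, 64, 56, 56)
#   y = ls(x)   # same shape as x; equals x * 1e-5 at initialization
#
# The learnable per-channel gamma is broadcast over the spatial dimensions; with
# inplace=True the input is scaled in place via mul_(), which saves memory but should
# be avoided when the unscaled activation is still needed in the autograd graph.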
316 | 317 | class UniversalInvertedResidual(nn.Module): 318 | """ Universal Inverted Residual Block (aka Universal Inverted Bottleneck, UIB) 319 | 320 | For MobileNetV4 - https://arxiv.org/abs/, referenced from 321 | https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778 322 | """ 323 | 324 | def __init__( 325 | self, 326 | in_chs: int, 327 | out_chs: int, 328 | dw_kernel_size_start: int = 0, 329 | dw_kernel_size_mid: int = 3, 330 | dw_kernel_size_end: int = 0, 331 | stride: int = 1, 332 | dilation: int = 1, 333 | group_size: int = 1, 334 | pad_type: str = '', 335 | noskip: bool = False, 336 | exp_ratio: float = 1.0, 337 | act_layer: LayerType = nn.ReLU, 338 | norm_layer: LayerType = nn.BatchNorm2d, 339 | aa_layer: Optional[LayerType] = None, 340 | se_layer: Optional[ModuleType] = None, 341 | conv_kwargs: Optional[Dict] = None, 342 | drop_path_rate: float = 0., 343 | layer_scale_init_value: Optional[float] = 1e-5, 344 | ): 345 | super(UniversalInvertedResidual, self).__init__() 346 | conv_kwargs = conv_kwargs or {} 347 | self.has_skip = (in_chs == out_chs and stride == 1) and not noskip 348 | if stride > 1: 349 | assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end 350 | 351 | # FIXME dilation isn't right w/ extra ks > 1 convs 352 | if dw_kernel_size_start: 353 | dw_start_stride = stride if not dw_kernel_size_mid else 1 354 | dw_start_groups = num_groups(group_size, in_chs) 355 | self.dw_start = ConvNormAct( 356 | in_chs, in_chs, dw_kernel_size_start, 357 | stride=dw_start_stride, 358 | dilation=dilation, # FIXME 359 | groups=dw_start_groups, 360 | padding=pad_type, 361 | apply_act=False, 362 | act_layer=act_layer, 363 | norm_layer=norm_layer, 364 | aa_layer=aa_layer, 365 | **conv_kwargs, 366 | ) 367 | else: 368 | self.dw_start = nn.Identity() 369 | 370 | # Point-wise expansion 371 | mid_chs = make_divisible(in_chs * exp_ratio) 372 | self.pw_exp = ConvNormAct( 373 | in_chs, mid_chs, 1, 374 | padding=pad_type, 375 | act_layer=act_layer, 376 | norm_layer=norm_layer, 377 | **conv_kwargs, 378 | ) 379 | 380 | # Middle depth-wise convolution 381 | if dw_kernel_size_mid: 382 | groups = num_groups(group_size, mid_chs) 383 | self.dw_mid = ConvNormAct( 384 | mid_chs, mid_chs, dw_kernel_size_mid, 385 | stride=stride, 386 | dilation=dilation, # FIXME 387 | groups=groups, 388 | padding=pad_type, 389 | act_layer=act_layer, 390 | norm_layer=norm_layer, 391 | aa_layer=aa_layer, 392 | **conv_kwargs, 393 | ) 394 | else: 395 | # keeping mid as identity so it can be hooked more easily for features 396 | self.dw_mid = nn.Identity() 397 | 398 | # Squeeze-and-excitation 399 | self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() 400 | 401 | # Point-wise linear projection 402 | self.pw_proj = ConvNormAct( 403 | mid_chs, out_chs, 1, 404 | padding=pad_type, 405 | apply_act=False, 406 | act_layer=act_layer, 407 | norm_layer=norm_layer, 408 | **conv_kwargs, 409 | ) 410 | 411 | if dw_kernel_size_end: 412 | dw_end_stride = stride if not dw_kernel_size_start and not dw_kernel_size_mid else 1 413 | dw_end_groups = num_groups(group_size, out_chs) 414 | if dw_end_stride > 1: 415 | assert not aa_layer 416 | self.dw_end = ConvNormAct( 417 | out_chs, out_chs, dw_kernel_size_end, 418 | stride=dw_end_stride, 419 | dilation=dilation, 420 | groups=dw_end_groups, 421 | padding=pad_type, 422 | apply_act=False, 423 | act_layer=act_layer, 424 | norm_layer=norm_layer, 425 | **conv_kwargs, 426 | ) 427 | else: 428 | 
self.dw_end = nn.Identity() 429 | 430 | if layer_scale_init_value is not None: 431 | self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value) 432 | else: 433 | self.layer_scale = nn.Identity() 434 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() 435 | 436 | def feature_info(self, location): 437 | if location == 'expansion': # after SE, input to PWL 438 | return dict(module='pw_proj.conv', hook_type='forward_pre', num_chs=self.pw_proj.conv.in_channels) 439 | else: # location == 'bottleneck', block output 440 | return dict(module='', num_chs=self.pw_proj.conv.out_channels) 441 | 442 | def forward(self, x): 443 | shortcut = x 444 | x = self.dw_start(x) 445 | x = self.pw_exp(x) 446 | x = self.dw_mid(x) 447 | x = self.se(x) 448 | x = self.pw_proj(x) 449 | x = self.dw_end(x) 450 | x = self.layer_scale(x) 451 | if self.has_skip: 452 | x = self.drop_path(x) + shortcut 453 | return x 454 | 455 | 456 | class MobileAttention(nn.Module): 457 | """ Mobile Attention Block 458 | 459 | For MobileNetV4 - https://arxiv.org/abs/, referenced from 460 | https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L1504 461 | """ 462 | def __init__( 463 | self, 464 | in_chs: int, 465 | out_chs: int, 466 | stride: int = 1, 467 | dw_kernel_size: int = 3, 468 | dilation: int = 1, 469 | group_size: int = 1, 470 | pad_type: str = '', 471 | num_heads: int = 8, 472 | key_dim: int = 64, 473 | value_dim: int = 64, 474 | use_multi_query: bool = False, 475 | query_strides: int = (1, 1), 476 | kv_stride: int = 1, 477 | cpe_dw_kernel_size: int = 3, 478 | noskip: bool = False, 479 | act_layer: LayerType = nn.ReLU, 480 | norm_layer: LayerType = nn.BatchNorm2d, 481 | aa_layer: Optional[LayerType] = None, 482 | drop_path_rate: float = 0., 483 | attn_drop: float = 0.0, 484 | proj_drop: float = 0.0, 485 | layer_scale_init_value: Optional[float] = 1e-5, 486 | use_bias: bool = False, 487 | use_cpe: bool = False, 488 | ): 489 | super(MobileAttention, self).__init__() 490 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 491 | self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip 492 | self.query_strides = to_2tuple(query_strides) 493 | self.kv_stride = kv_stride 494 | self.has_query_stride = any([s > 1 for s in self.query_strides]) 495 | 496 | # This CPE is different than the one suggested in the original paper. 497 | # https://arxiv.org/abs/2102.10882 498 | # 1. Rather than adding one CPE before the attention blocks, we add a CPE 499 | # into every attention block. 500 | # 2. We replace the expensive Conv2D by a Seperable DW Conv. 501 | if use_cpe: 502 | self.conv_cpe_dw = create_conv2d( 503 | in_chs, in_chs, 504 | kernel_size=cpe_dw_kernel_size, 505 | dilation=dilation, 506 | depthwise=True, 507 | bias=True, 508 | ) 509 | else: 510 | self.conv_cpe_dw = None 511 | 512 | self.norm = norm_act_layer(in_chs, apply_act=False) 513 | 514 | if num_heads is None: 515 | assert in_chs % key_dim == 0 516 | num_heads = in_chs // key_dim 517 | 518 | if use_multi_query: 519 | self.attn = MultiQueryAttention2d( 520 | in_chs, 521 | dim_out=out_chs, 522 | num_heads=num_heads, 523 | key_dim=key_dim, 524 | value_dim=value_dim, 525 | query_strides=query_strides, 526 | kv_stride=kv_stride, 527 | dilation=dilation, 528 | padding=pad_type, 529 | dw_kernel_size=dw_kernel_size, 530 | attn_drop=attn_drop, 531 | proj_drop=proj_drop, 532 | #bias=use_bias, # why not here if used w/ mhsa? 
533 | ) 534 | else: 535 | self.attn = Attention2d( 536 | in_chs, 537 | dim_out=out_chs, 538 | num_heads=num_heads, 539 | attn_drop=attn_drop, 540 | proj_drop=proj_drop, 541 | bias=use_bias, 542 | ) 543 | 544 | if layer_scale_init_value is not None: 545 | self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value) 546 | else: 547 | self.layer_scale = nn.Identity() 548 | 549 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() 550 | 551 | def feature_info(self, location): 552 | if location == 'expansion': # after SE, input to PW 553 | return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) 554 | else: # location == 'bottleneck', block output 555 | return dict(module='', num_chs=self.conv_pw.out_channels) 556 | 557 | def forward(self, x): 558 | if self.conv_cpe_dw is not None: 559 | x_cpe = self.conv_cpe_dw(x) 560 | x = x + x_cpe 561 | 562 | shortcut = x 563 | x = self.norm(x) 564 | x = self.attn(x) 565 | x = self.layer_scale(x) 566 | if self.has_skip: 567 | x = self.drop_path(x) + shortcut 568 | 569 | return x 570 | 571 | 572 | class CondConvResidual(InvertedResidual): 573 | """ Inverted residual block w/ CondConv routing""" 574 | 575 | def __init__( 576 | self, 577 | in_chs: int, 578 | out_chs: int, 579 | dw_kernel_size: int = 3, 580 | stride: int = 1, 581 | dilation: int = 1, 582 | group_size: int = 1, 583 | pad_type: str = '', 584 | noskip: bool = False, 585 | exp_ratio: float = 1.0, 586 | exp_kernel_size: int = 1, 587 | pw_kernel_size: int = 1, 588 | act_layer: LayerType = nn.ReLU, 589 | norm_layer: LayerType = nn.BatchNorm2d, 590 | aa_layer: Optional[LayerType] = None, 591 | se_layer: Optional[ModuleType] = None, 592 | num_experts: int = 0, 593 | drop_path_rate: float = 0., 594 | ): 595 | 596 | self.num_experts = num_experts 597 | conv_kwargs = dict(num_experts=self.num_experts) 598 | super(CondConvResidual, self).__init__( 599 | in_chs, 600 | out_chs, 601 | dw_kernel_size=dw_kernel_size, 602 | stride=stride, 603 | dilation=dilation, 604 | group_size=group_size, 605 | pad_type=pad_type, 606 | noskip=noskip, 607 | exp_ratio=exp_ratio, 608 | exp_kernel_size=exp_kernel_size, 609 | pw_kernel_size=pw_kernel_size, 610 | act_layer=act_layer, 611 | norm_layer=norm_layer, 612 | aa_layer=aa_layer, 613 | se_layer=se_layer, 614 | conv_kwargs=conv_kwargs, 615 | drop_path_rate=drop_path_rate, 616 | ) 617 | self.routing_fn = nn.Linear(in_chs, self.num_experts) 618 | 619 | def forward(self, x): 620 | shortcut = x 621 | pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) # CondConv routing 622 | routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) 623 | x = self.conv_pw(x, routing_weights) 624 | x = self.bn1(x) 625 | x = self.conv_dw(x, routing_weights) 626 | x = self.bn2(x) 627 | x = self.se(x) 628 | x = self.conv_pwl(x, routing_weights) 629 | x = self.bn3(x) 630 | if self.has_skip: 631 | x = self.drop_path(x) + shortcut 632 | return x 633 | 634 | 635 | class EdgeResidual(nn.Module): 636 | """ Residual block with expansion convolution followed by pointwise-linear w/ stride 637 | 638 | Originally introduced in `EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML` 639 | - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html 640 | 641 | This layer is also called FusedMBConv in the MobileDet, EfficientNet-X, and EfficientNet-V2 papers 642 | * MobileDet - https://arxiv.org/abs/2004.14525 643 | * EfficientNet-X - https://arxiv.org/abs/2102.05610 644 | * EfficientNet-V2 - 
https://arxiv.org/abs/2104.00298 645 | """ 646 | 647 | def __init__( 648 | self, 649 | in_chs: int, 650 | out_chs: int, 651 | exp_kernel_size: int = 3, 652 | stride: int = 1, 653 | dilation: int = 1, 654 | group_size: int = 0, 655 | pad_type: str = '', 656 | force_in_chs: int = 0, 657 | noskip: bool = False, 658 | exp_ratio: float = 1.0, 659 | pw_kernel_size: int = 1, 660 | act_layer: LayerType = nn.ReLU, 661 | norm_layer: LayerType = nn.BatchNorm2d, 662 | aa_layer: Optional[LayerType] = None, 663 | se_layer: Optional[ModuleType] = None, 664 | drop_path_rate: float = 0., 665 | ): 666 | super(EdgeResidual, self).__init__() 667 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 668 | if force_in_chs > 0: 669 | mid_chs = make_divisible(force_in_chs * exp_ratio) 670 | else: 671 | mid_chs = make_divisible(in_chs * exp_ratio) 672 | groups = num_groups(group_size, mid_chs) # NOTE: Using out_chs of conv_exp for groups calc 673 | self.has_skip = (in_chs == out_chs and stride == 1) and not noskip 674 | use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation 675 | 676 | # Expansion convolution 677 | self.conv_exp = create_conv2d( 678 | in_chs, mid_chs, exp_kernel_size, 679 | stride=1 if use_aa else stride, 680 | dilation=dilation, groups=groups, padding=pad_type) 681 | self.bn1 = norm_act_layer(mid_chs, inplace=True) 682 | 683 | self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa) 684 | 685 | # Squeeze-and-excitation 686 | self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() 687 | 688 | # Point-wise linear projection 689 | self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type) 690 | self.bn2 = norm_act_layer(out_chs, apply_act=False) 691 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() 692 | 693 | def feature_info(self, location): 694 | if location == 'expansion': # after SE, before PWL 695 | return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) 696 | else: # location == 'bottleneck', block output 697 | return dict(module='', num_chs=self.conv_pwl.out_channels) 698 | 699 | def forward(self, x): 700 | shortcut = x 701 | x = self.conv_exp(x) 702 | x = self.bn1(x) 703 | x = self.aa(x) 704 | x = self.se(x) 705 | x = self.conv_pwl(x) 706 | x = self.bn2(x) 707 | if self.has_skip: 708 | x = self.drop_path(x) + shortcut 709 | return x -------------------------------------------------------------------------------- /models/extra_attention_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class MultiScaleAttentionGate(nn.Module): 6 | """ 7 | Multi-scale attention gate 8 | """ 9 | def __init__(self, channel): 10 | super(MultiScaleAttentionGate, self).__init__() 11 | self.channel = channel 12 | self.pointwiseConv = nn.Sequential( 13 | nn.Conv2d(self.channel, self.channel, kernel_size=1, padding=0, bias=True), 14 | nn.BatchNorm2d(self.channel), 15 | ) 16 | self.ordinaryConv = nn.Sequential( 17 | nn.Conv2d(self.channel, self.channel, kernel_size=3, padding=1, stride=1, bias=True), 18 | nn.BatchNorm2d(self.channel), 19 | ) 20 | self.dilationConv = nn.Sequential( 21 | nn.Conv2d(self.channel, self.channel, kernel_size=3, padding=2, stride=1, dilation=2, bias=True), 22 | nn.BatchNorm2d(self.channel), 23 | ) 24 | self.voteConv = nn.Sequential( 25 | nn.Conv2d(self.channel * 3, self.channel, kernel_size=(1, 1)), 26 | nn.BatchNorm2d(self.channel), 27 | 
nn.GELU() 28 | ) 29 | self.relu = nn.ReLU(inplace=True) 30 | 31 | def forward(self, x): 32 | x1 = self.pointwiseConv(x) 33 | x2 = self.ordinaryConv(x) 34 | x3 = self.dilationConv(x) 35 | _x = self.relu(torch.cat((x1, x2, x3), dim=1)) 36 | _x = self.voteConv(_x) 37 | x = x + x * _x 38 | return x 39 | 40 | # if __name__ == '__main__': 41 | # net = MultiScaleAttentionGate(960) 42 | # X = torch.randn(1, 960, 7, 7) 43 | # y = net(X) 44 | # print(y.shape) -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | 3 | import logging 4 | import math 5 | import re 6 | from copy import deepcopy 7 | from functools import partial 8 | from typing import Any, Dict, List 9 | 10 | import torch.nn as nn 11 | 12 | from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, LayerType 13 | from models.blocks import make_divisible 14 | from models.blocks import * 15 | 16 | 17 | def named_modules( 18 | module: nn.Module, 19 | name: str = '', 20 | depth_first: bool = True, 21 | include_root: bool = False, 22 | ): 23 | if not depth_first and include_root: 24 | yield name, module 25 | for child_name, child_module in module.named_children(): 26 | child_name = '.'.join((name, child_name)) if name else child_name 27 | yield from named_modules( 28 | module=child_module, name=child_name, depth_first=depth_first, include_root=True) 29 | if depth_first and include_root: 30 | yield name, module 31 | 32 | 33 | 34 | __all__ = ["EfficientNetBuilder", "BlockArgs", "decode_arch_def", "efficientnet_init_weights", 35 | 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'] 36 | 37 | _logger = logging.getLogger(__name__) 38 | 39 | 40 | _DEBUG_BUILDER = False 41 | 42 | # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per 43 | # papers and TF reference implementations. 
PT momentum equiv for TF decay is (1 - TF decay) 44 | # NOTE: momentum varies btw .99 and .9997 depending on source 45 | # .99 in official TF TPU impl 46 | # .9997 (/w .999 in search space) for paper 47 | BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 48 | BN_EPS_TF_DEFAULT = 1e-3 49 | _BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) 50 | 51 | BlockArgs = List[List[Dict[str, Any]]] 52 | 53 | 54 | def get_bn_args_tf(): 55 | return _BN_ARGS_TF.copy() 56 | 57 | 58 | def resolve_bn_args(kwargs): 59 | bn_args = {} 60 | bn_momentum = kwargs.pop('bn_momentum', None) 61 | if bn_momentum is not None: 62 | bn_args['momentum'] = bn_momentum 63 | bn_eps = kwargs.pop('bn_eps', None) 64 | if bn_eps is not None: 65 | bn_args['eps'] = bn_eps 66 | return bn_args 67 | 68 | 69 | def resolve_act_layer(kwargs, default='relu'): 70 | return get_act_layer(kwargs.pop('act_layer', default)) 71 | 72 | 73 | def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9): 74 | """Round number of filters based on depth multiplier.""" 75 | if not multiplier: 76 | return channels 77 | return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit) 78 | 79 | 80 | def _log_info_if(msg, condition): 81 | if condition: 82 | _logger.info(msg) 83 | 84 | 85 | def _parse_ksize(ss): 86 | if ss.isdigit(): 87 | return int(ss) 88 | else: 89 | return [int(k) for k in ss.split('.')] 90 | 91 | 92 | def _decode_block_str(block_str): 93 | """ Decode block definition string 94 | 95 | Gets a list of block arg (dicts) through a string notation of arguments. 96 | E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip 97 | 98 | All args can exist in any order with the exception of the leading string which 99 | is assumed to indicate the block type. 100 | 101 | leading string - block type ( 102 | ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) 103 | r - number of repeat blocks, 104 | k - kernel size, 105 | s - strides (1-9), 106 | e - expansion ratio, 107 | c - output channels, 108 | se - squeeze/excitation ratio 109 | n - activation fn ('re', 'r6', 'hs', or 'sw') 110 | Args: 111 | block_str: a string representation of block arguments. 
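        MobileNetV4-specific block types decoded further below use a few extra keys:
            uir = UniversalInvertedResidual ('a' / 'p' give the optional start / end depth-wise kernel sizes)
            mha / mqa = MobileAttention without / with multi-query attention
                        ('d' = key/value dim, 'h' = num heads, 'v' = kv stride)
        An illustrative (made-up) string would be 'uir_r1_a3_k5_s2_e4_c96'; see build_mobilenet_v4.py
        for the real architecture definitions.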
112 | Returns: 113 | A list of block args (dicts) 114 | Raises: 115 | ValueError: if the string def not properly specified (TODO) 116 | """ 117 | assert isinstance(block_str, str) 118 | ops = block_str.split('_') 119 | block_type = ops[0] # take the block type off the front 120 | ops = ops[1:] 121 | options = {} 122 | skip = None 123 | for op in ops: 124 | # string options being checked on individual basis, combine if they grow 125 | if op == 'noskip': 126 | skip = False # force no skip connection 127 | elif op == 'skip': 128 | skip = True # force a skip connection 129 | elif op.startswith('n'): 130 | # activation fn 131 | key = op[0] 132 | v = op[1:] 133 | if v == 're': 134 | value = get_act_layer('relu') 135 | elif v == 'r6': 136 | value = get_act_layer('relu6') 137 | elif v == 'hs': 138 | value = get_act_layer('hard_swish') 139 | elif v == 'sw': 140 | value = get_act_layer('swish') # aka SiLU 141 | elif v == 'mi': 142 | value = get_act_layer('mish') 143 | else: 144 | continue 145 | options[key] = value 146 | else: 147 | # all numeric options 148 | splits = re.split(r'(\d.*)', op) 149 | if len(splits) >= 2: 150 | key, value = splits[:2] 151 | options[key] = value 152 | 153 | # if act_layer is None, the model default (passed to model init) will be used 154 | act_layer = options['n'] if 'n' in options else None 155 | start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 156 | end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 157 | force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def 158 | num_repeat = int(options['r']) 159 | 160 | # each type of block has different valid arguments, fill accordingly 161 | block_args = dict( 162 | block_type=block_type, 163 | out_chs=int(options['c']), 164 | stride=int(options['s']), 165 | act_layer=act_layer, 166 | ) 167 | if block_type == 'ir': 168 | block_args.update(dict( 169 | dw_kernel_size=_parse_ksize(options['k']), 170 | exp_kernel_size=start_kernel_size, 171 | pw_kernel_size=end_kernel_size, 172 | exp_ratio=float(options['e']), 173 | se_ratio=float(options.get('se', 0.)), 174 | noskip=skip is False, 175 | s2d=int(options.get('d', 0)) > 0, 176 | )) 177 | if 'cc' in options: 178 | block_args['num_experts'] = int(options['cc']) 179 | elif block_type == 'ds' or block_type == 'dsa': 180 | block_args.update(dict( 181 | dw_kernel_size=_parse_ksize(options['k']), 182 | pw_kernel_size=end_kernel_size, 183 | se_ratio=float(options.get('se', 0.)), 184 | pw_act=block_type == 'dsa', 185 | noskip=block_type == 'dsa' or skip is False, 186 | s2d=int(options.get('d', 0)) > 0, 187 | )) 188 | elif block_type == 'er': 189 | block_args.update(dict( 190 | exp_kernel_size=_parse_ksize(options['k']), 191 | pw_kernel_size=end_kernel_size, 192 | exp_ratio=float(options['e']), 193 | force_in_chs=force_in_chs, 194 | se_ratio=float(options.get('se', 0.)), 195 | noskip=skip is False, 196 | )) 197 | elif block_type == 'cn': 198 | block_args.update(dict( 199 | kernel_size=int(options['k']), 200 | skip=skip is True, 201 | )) 202 | elif block_type == 'uir': 203 | # override exp / proj kernels for start/end in uir block 204 | start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 0 205 | end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 0 206 | block_args.update(dict( 207 | dw_kernel_size_start=start_kernel_size, # overload exp ks arg for dw start 208 | dw_kernel_size_mid=_parse_ksize(options['k']), 209 | dw_kernel_size_end=end_kernel_size, # overload pw ks 
arg for dw end 210 | exp_ratio=float(options['e']), 211 | se_ratio=float(options.get('se', 0.)), 212 | noskip=skip is False, 213 | )) 214 | elif block_type == 'mha': 215 | kv_dim = int(options['d']) 216 | block_args.update(dict( 217 | dw_kernel_size=_parse_ksize(options['k']), 218 | num_heads=int(options['h']), 219 | key_dim=kv_dim, 220 | value_dim=kv_dim, 221 | kv_stride=int(options.get('v', 1)), 222 | noskip=skip is False, 223 | )) 224 | elif block_type == 'mqa': 225 | kv_dim = int(options['d']) 226 | block_args.update(dict( 227 | dw_kernel_size=_parse_ksize(options['k']), 228 | num_heads=int(options['h']), 229 | key_dim=kv_dim, 230 | value_dim=kv_dim, 231 | kv_stride=int(options.get('v', 1)), 232 | noskip=skip is False, 233 | )) 234 | else: 235 | assert False, 'Unknown block type (%s)' % block_type 236 | 237 | if 'gs' in options: 238 | block_args['group_size'] = int(options['gs']) 239 | 240 | return block_args, num_repeat 241 | 242 | 243 | def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): 244 | """ Per-stage depth scaling 245 | Scales the block repeats in each stage. This depth scaling impl maintains 246 | compatibility with the EfficientNet scaling method, while allowing sensible 247 | scaling for other models that may have multiple block arg definitions in each stage. 248 | """ 249 | 250 | # We scale the total repeat count for each stage, there may be multiple 251 | # block arg defs per stage so we need to sum. 252 | num_repeat = sum(repeats) 253 | if depth_trunc == 'round': 254 | # Truncating to int by rounding allows stages with few repeats to remain 255 | # proportionally smaller for longer. This is a good choice when stage definitions 256 | # include single repeat stages that we'd prefer to keep that way as long as possible 257 | num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) 258 | else: 259 | # The default for EfficientNet truncates repeats to int via 'ceil'. 260 | # Any multiplier > 1.0 will result in an increased depth for every stage. 261 | num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) 262 | 263 | # Proportionally distribute repeat count scaling to each block definition in the stage. 264 | # Allocation is done in reverse as it results in the first block being less likely to be scaled. 265 | # The first block makes less sense to repeat in most of the arch definitions. 
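    # Illustrative example (made-up numbers): repeats = [1, 2], depth_multiplier = 1.5, depth_trunc = 'ceil'
    #   -> num_repeat = 3, num_repeat_scaled = ceil(3 * 1.5) = 5
    #   reverse pass: r = 2 gives rs = round(2 / 3 * 5) = 3 (leaving 1 repeat and 2 scaled repeats),
    #                 r = 1 gives rs = round(1 / 1 * 2) = 2
    #   -> repeats_scaled = [2, 3], i.e. the extra depth is spread over both defs and still sums to 5.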
266 | repeats_scaled = [] 267 | for r in repeats[::-1]: 268 | rs = max(1, round((r / num_repeat * num_repeat_scaled))) 269 | repeats_scaled.append(rs) 270 | num_repeat -= r 271 | num_repeat_scaled -= rs 272 | repeats_scaled = repeats_scaled[::-1] 273 | 274 | # Apply the calculated scaling to each block arg in the stage 275 | sa_scaled = [] 276 | for ba, rep in zip(stack_args, repeats_scaled): 277 | sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) 278 | return sa_scaled 279 | 280 | 281 | def decode_arch_def( 282 | arch_def, 283 | depth_multiplier=1.0, 284 | depth_trunc='ceil', 285 | experts_multiplier=1, 286 | fix_first_last=False, 287 | group_size=None, 288 | ): 289 | """ Decode block architecture definition strings -> block kwargs 290 | 291 | Args: 292 | arch_def: architecture definition strings, list of list of strings 293 | depth_multiplier: network depth multiplier 294 | depth_trunc: networ depth truncation mode when applying multiplier 295 | experts_multiplier: CondConv experts multiplier 296 | fix_first_last: fix first and last block depths when multiplier is applied 297 | group_size: group size override for all blocks that weren't explicitly set in arch string 298 | 299 | Returns: 300 | list of list of block kwargs 301 | """ 302 | arch_args = [] 303 | if isinstance(depth_multiplier, tuple): 304 | assert len(depth_multiplier) == len(arch_def) 305 | else: 306 | depth_multiplier = (depth_multiplier,) * len(arch_def) 307 | for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)): 308 | assert isinstance(block_strings, list) 309 | stack_args = [] 310 | repeats = [] 311 | for block_str in block_strings: 312 | assert isinstance(block_str, str) 313 | ba, rep = _decode_block_str(block_str) 314 | if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: 315 | ba['num_experts'] *= experts_multiplier 316 | if group_size is not None: 317 | ba.setdefault('group_size', group_size) 318 | stack_args.append(ba) 319 | repeats.append(rep) 320 | if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): 321 | arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) 322 | else: 323 | arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc)) 324 | return arch_args 325 | 326 | 327 | class EfficientNetBuilder: 328 | """ Build Trunk Blocks 329 | 330 | This ended up being somewhat of a cross between 331 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py 332 | and 333 | https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py 334 | 335 | """ 336 | def __init__( 337 | self, 338 | output_stride: int = 32, 339 | pad_type: str = '', 340 | round_chs_fn: Callable = round_channels, 341 | se_from_exp: bool = False, 342 | act_layer: Optional[LayerType] = None, 343 | norm_layer: Optional[LayerType] = None, 344 | aa_layer: Optional[LayerType] = None, 345 | se_layer: Optional[LayerType] = None, 346 | drop_path_rate: float = 0., 347 | layer_scale_init_value: Optional[float] = None, 348 | feature_location: str = '', 349 | ): 350 | self.output_stride = output_stride 351 | self.pad_type = pad_type 352 | self.round_chs_fn = round_chs_fn 353 | self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs 354 | self.act_layer = act_layer 355 | self.norm_layer = norm_layer 356 | self.aa_layer = aa_layer 357 | self.se_layer = get_attn(se_layer) 358 | try: 359 | self.se_layer(8, rd_ratio=1.0) # test if attn 
layer accepts rd_ratio arg 360 | self.se_has_ratio = True 361 | except TypeError: 362 | self.se_has_ratio = False 363 | self.drop_path_rate = drop_path_rate 364 | self.layer_scale_init_value = layer_scale_init_value 365 | if feature_location == 'depthwise': 366 | # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense 367 | _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'") 368 | feature_location = 'expansion' 369 | self.feature_location = feature_location 370 | assert feature_location in ('bottleneck', 'expansion', '') 371 | self.verbose = _DEBUG_BUILDER 372 | 373 | # state updated during build, consumed by model 374 | self.in_chs = None 375 | self.features = [] 376 | 377 | def _make_block(self, ba, block_idx, block_count): 378 | drop_path_rate = self.drop_path_rate * block_idx / block_count 379 | bt = ba.pop('block_type') 380 | ba['in_chs'] = self.in_chs 381 | ba['out_chs'] = self.round_chs_fn(ba['out_chs']) 382 | s2d = ba.get('s2d', 0) 383 | if s2d > 0: 384 | # adjust while space2depth active 385 | ba['out_chs'] *= 4 386 | if 'force_in_chs' in ba and ba['force_in_chs']: 387 | # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl 388 | ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs']) 389 | ba['pad_type'] = self.pad_type 390 | # block act fn overrides the model default 391 | ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer 392 | assert ba['act_layer'] is not None 393 | ba['norm_layer'] = self.norm_layer 394 | ba['drop_path_rate'] = drop_path_rate 395 | 396 | if self.aa_layer is not None: 397 | ba['aa_layer'] = self.aa_layer 398 | 399 | se_ratio = ba.pop('se_ratio', None) 400 | if se_ratio and self.se_layer is not None: 401 | if not self.se_from_exp: 402 | # adjust se_ratio by expansion ratio if calculating se channels from block input 403 | se_ratio /= ba.get('exp_ratio', 1.0) 404 | if s2d == 1: 405 | # adjust for start of space2depth 406 | se_ratio /= 4 407 | if self.se_has_ratio: 408 | ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio) 409 | else: 410 | ba['se_layer'] = self.se_layer 411 | 412 | if bt == 'ir': 413 | _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) 414 | block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba) 415 | elif bt == 'ds' or bt == 'dsa': 416 | _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose) 417 | block = DepthwiseSeparableConv(**ba) 418 | elif bt == 'er': 419 | _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) 420 | block = EdgeResidual(**ba) 421 | elif bt == 'cn': 422 | _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose) 423 | block = ConvBnAct(**ba) 424 | elif bt == 'uir': 425 | _log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) 426 | block = UniversalInvertedResidual(**ba, layer_scale_init_value=self.layer_scale_init_value) 427 | elif bt == 'mqa': 428 | _log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose) 429 | block = MobileAttention(**ba, use_multi_query=True, layer_scale_init_value=self.layer_scale_init_value) 430 | elif bt == 'mha': 431 | _log_info_if(' MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose) 432 | block = MobileAttention(**ba, layer_scale_init_value=self.layer_scale_init_value) 433 | else: 434 | assert False, 
'Unknown block type (%s) while building model.' % bt 435 | 436 | self.in_chs = ba['out_chs'] # update in_chs for arg of next block 437 | return block 438 | 439 | def __call__(self, in_chs, model_block_args): 440 | """ Build the blocks 441 | Args: 442 | in_chs: Number of input-channels passed to first block 443 | model_block_args: A list of lists, outer list defines stages, inner 444 | list contains strings defining block configuration(s) 445 | Return: 446 | List of block stacks (each stack wrapped in nn.Sequential) 447 | """ 448 | _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose) 449 | self.in_chs = in_chs 450 | total_block_count = sum([len(x) for x in model_block_args]) 451 | total_block_idx = 0 452 | current_stride = 2 453 | current_dilation = 1 454 | stages = [] 455 | if model_block_args[0][0]['stride'] > 1: 456 | # if the first block starts with a stride, we need to extract first level feat from stem 457 | feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride) 458 | self.features.append(feature_info) 459 | 460 | # outer list of block_args defines the stacks 461 | space2depth = 0 462 | for stack_idx, stack_args in enumerate(model_block_args): 463 | last_stack = stack_idx + 1 == len(model_block_args) 464 | _log_info_if('Stack: {}'.format(stack_idx), self.verbose) 465 | assert isinstance(stack_args, list) 466 | 467 | blocks = [] 468 | # each stack (stage of blocks) contains a list of block arguments 469 | for block_idx, block_args in enumerate(stack_args): 470 | last_block = block_idx + 1 == len(stack_args) 471 | _log_info_if(' Block: {}'.format(block_idx), self.verbose) 472 | 473 | assert block_args['stride'] in (1, 2) 474 | if block_idx >= 1: # only the first block in any stack can have a stride > 1 475 | block_args['stride'] = 1 476 | 477 | if not space2depth and block_args.pop('s2d', False): 478 | assert block_args['stride'] == 1 479 | space2depth = 1 480 | 481 | if space2depth > 0: 482 | # FIXME s2d is a WIP 483 | if space2depth == 2 and block_args['stride'] == 2: 484 | block_args['stride'] = 1 485 | # to end s2d region, need to correct expansion and se ratio relative to input 486 | block_args['exp_ratio'] /= 4 487 | space2depth = 0 488 | else: 489 | block_args['s2d'] = space2depth 490 | 491 | extract_features = False 492 | if last_block: 493 | next_stack_idx = stack_idx + 1 494 | extract_features = next_stack_idx >= len(model_block_args) or \ 495 | model_block_args[next_stack_idx][0]['stride'] > 1 496 | 497 | next_dilation = current_dilation 498 | if block_args['stride'] > 1: 499 | next_output_stride = current_stride * block_args['stride'] 500 | if next_output_stride > self.output_stride: 501 | next_dilation = current_dilation * block_args['stride'] 502 | block_args['stride'] = 1 503 | _log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format( 504 | self.output_stride), self.verbose) 505 | else: 506 | current_stride = next_output_stride 507 | block_args['dilation'] = current_dilation 508 | if next_dilation != current_dilation: 509 | current_dilation = next_dilation 510 | 511 | # create the block 512 | block = self._make_block(block_args, total_block_idx, total_block_count) 513 | blocks.append(block) 514 | 515 | if space2depth == 1: 516 | space2depth = 2 517 | 518 | # stash feature module name and channel info for model feature extraction 519 | if extract_features: 520 | feature_info = dict( 521 | stage=stack_idx + 1, 522 | reduction=current_stride, 523 | 
**block.feature_info(self.feature_location), 524 | ) 525 | leaf_name = feature_info.get('module', '') 526 | if leaf_name: 527 | feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name]) 528 | else: 529 | assert last_block 530 | feature_info['module'] = f'blocks.{stack_idx}' 531 | self.features.append(feature_info) 532 | 533 | total_block_idx += 1 # incr global block idx (across all stacks) 534 | stages.append(nn.Sequential(*blocks)) 535 | return stages 536 | 537 | 538 | def _init_weight_goog(m, n='', fix_group_fanout=True): 539 | """ Weight initialization as per Tensorflow official implementations. 540 | 541 | Args: 542 | m (nn.Module): module to init 543 | n (str): module name 544 | fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs 545 | 546 | Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: 547 | * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py 548 | * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py 549 | """ 550 | if isinstance(m, CondConv2d): 551 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 552 | if fix_group_fanout: 553 | fan_out //= m.groups 554 | init_weight_fn = get_condconv_initializer( 555 | lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) 556 | init_weight_fn(m.weight) 557 | if m.bias is not None: 558 | nn.init.zeros_(m.bias) 559 | elif isinstance(m, nn.Conv2d): 560 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 561 | if fix_group_fanout: 562 | fan_out //= m.groups 563 | nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out)) 564 | if m.bias is not None: 565 | nn.init.zeros_(m.bias) 566 | elif isinstance(m, nn.BatchNorm2d): 567 | nn.init.ones_(m.weight) 568 | nn.init.zeros_(m.bias) 569 | elif isinstance(m, nn.Linear): 570 | fan_out = m.weight.size(0) # fan-out 571 | fan_in = 0 572 | if 'routing_fn' in n: 573 | fan_in = m.weight.size(1) 574 | init_range = 1.0 / math.sqrt(fan_in + fan_out) 575 | nn.init.uniform_(m.weight, -init_range, init_range) 576 | nn.init.zeros_(m.bias) 577 | 578 | 579 | def efficientnet_init_weights(model: nn.Module, init_fn=None): 580 | init_fn = init_fn or _init_weight_goog 581 | for n, m in model.named_modules(): 582 | init_fn(m, n) 583 | 584 | # iterate and call any module.init_weights() fn, children first 585 | for n, m in named_modules(model): 586 | if hasattr(m, 'init_weights'): 587 | m.init_weights() -------------------------------------------------------------------------------- /onnx_export.py: -------------------------------------------------------------------------------- 1 | """ 2 | ONNX export script 3 | Export PyTorch models as ONNX graphs. 4 | This export script originally started as an adaptation of code snippets found at 5 | https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html 6 | 7 | The default parameters work with PyTorch 2.0.1 and ONNX 1.13 and produce an optimal ONNX graph 8 | for hosting in the ONNX runtime (see onnx_validate.py). 
with Caffe2 in newer PyTorch/ONNX versions, the --keep-init and --aten-fallback flags below may be required.
To export an ONNX model compatible 9 | """ 10 | 11 | import argparse 12 | import torch 13 | import numpy as np 14 | import onnx 15 | import models 16 | from copy import deepcopy 17 | from timm.models import create_model 18 | from typing import Optional, Tuple, List 19 | 20 | 21 | 22 | ## python onnx_export.py --model mobilenetv4_small ./mobilenetv4_small.onnx 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch ONNX Deployment') 25 | parser.add_argument('--output', metavar='ONNX_FILE', default=None, type=str, 26 | help='output model filename') 27 | 28 | # Model & datasets params 29 | parser.add_argument('--model', default='mobilenetv4_conv_large', type=str, metavar='MODEL', 30 | choices=['mobilenetv4_hybrid_large', 'mobilenetv4_hybrid_medium', 'mobilenetv4_hybrid_large_075', 31 | 'mobilenetv4_conv_large', 'mobilenetv4_conv_aa_large', 'mobilenetv4_conv_medium', 32 | 'mobilenetv4_conv_aa_medium', 'mobilenetv4_conv_small', 'mobilenetv4_hybrid_medium_075', 33 | 'mobilenetv4_conv_small_035', 'mobilenetv4_conv_small_050', 'mobilenetv4_conv_blur_medium'], 34 | help='Name of model to train') 35 | parser.add_argument('--extra_attention_block', default=False, type=bool, help='Add an extra attention block') 36 | parser.add_argument('--checkpoint', default='./output/mobilenetv4_conv_large_best_checkpoint.pth', type=str, metavar='PATH', 37 | help='path to checkpoint (default: none)') 38 | parser.add_argument('--batch-size', default=1, type=int, 39 | metavar='N', help='mini-batch size (default: 1)') 40 | parser.add_argument('--img-size', default=384, type=int, 41 | metavar='N', help='Input image dimension, uses model default if empty') 42 | parser.add_argument('--nb-classes', type=int, default=5, 43 | help='Number classes in datasets') 44 | 45 | parser.add_argument('--opset', type=int, default=10, 46 | help='ONNX opset to use (default: 10)') 47 | parser.add_argument('--keep-init', action='store_true', default=False, 48 | help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.') 49 | parser.add_argument('--aten-fallback', action='store_true', default=False, 50 | help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.') 51 | parser.add_argument('--dynamic-size', action='store_true', default=False, 52 | help='Export model width dynamic width/height. 
Not recommended for "tf" models with SAME padding.') 53 | parser.add_argument('--check-forward', action='store_true', default=False, 54 | help='Do a full check of torch vs onnx forward after export.') 55 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 56 | help='Override mean pixel value of datasets') 57 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 58 | help='Override std deviation of of datasets') 59 | parser.add_argument('--reparam', default=False, action='store_true', 60 | help='Reparameterize model') 61 | parser.add_argument('--training', default=False, action='store_true', 62 | help='Export in training mode (default is eval)') 63 | parser.add_argument('--verbose', default=False, action='store_true', 64 | help='Extra stdout output') 65 | parser.add_argument('--dynamo', default=False, action='store_true', 66 | help='Use torch dynamo export.') 67 | 68 | 69 | 70 | def reparameterize_model(model: torch.nn.Module, inplace=False) -> torch.nn.Module: 71 | if not inplace: 72 | model = deepcopy(model) 73 | 74 | def _fuse(m): 75 | for child_name, child in m.named_children(): 76 | if hasattr(child, 'fuse'): 77 | setattr(m, child_name, child.fuse()) 78 | elif hasattr(child, "reparameterize"): 79 | child.reparameterize() 80 | elif hasattr(child, "switch_to_deploy"): 81 | child.switch_to_deploy() 82 | _fuse(child) 83 | 84 | _fuse(model) 85 | return model 86 | 87 | 88 | def onnx_forward(onnx_file, example_input): 89 | import onnxruntime 90 | 91 | sess_options = onnxruntime.SessionOptions() 92 | session = onnxruntime.InferenceSession(onnx_file, sess_options) 93 | input_name = session.get_inputs()[0].name 94 | output = session.run([], {input_name: example_input.numpy()}) 95 | output = output[0] 96 | return output 97 | 98 | 99 | def onnx_export( 100 | model: torch.nn.Module, 101 | output_file: str, 102 | example_input: Optional[torch.Tensor] = None, 103 | training: bool = False, 104 | verbose: bool = False, 105 | check: bool = True, 106 | check_forward: bool = False, 107 | batch_size: int = 64, 108 | input_size: Tuple[int, int, int] = None, 109 | opset: Optional[int] = None, 110 | dynamic_size: bool = False, 111 | aten_fallback: bool = False, 112 | keep_initializers: Optional[bool] = None, 113 | use_dynamo: bool = False, 114 | input_names: List[str] = None, 115 | output_names: List[str] = None, 116 | ): 117 | import onnx 118 | 119 | if training: 120 | training_mode = torch.onnx.TrainingMode.TRAINING 121 | model.train() 122 | else: 123 | training_mode = torch.onnx.TrainingMode.EVAL 124 | model.eval() 125 | 126 | if example_input is None: 127 | if not input_size: 128 | assert hasattr(model, 'default_cfg') 129 | input_size = model.default_cfg.get('input_size') 130 | example_input = torch.randn((batch_size,) + input_size, requires_grad=training) 131 | 132 | # Run model once before export trace, sets padding for models with Conv2dSameExport. This means 133 | # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for 134 | # the input img_size specified in this script. 135 | 136 | # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to 137 | # issues in the tracing of the dynamic padding or errors attempting to export the model after jit 138 | # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions... 
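    # Added note: this eager forward pass also records `original_out`, which the optional
    # --check-forward step at the end of this function compares against the ONNX runtime output
    # (and against the traced torch output when the legacy exporter is used).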
139 | with torch.no_grad(): 140 | original_out = model(example_input) 141 | 142 | print("==> Exporting model to ONNX format at '{}'".format(output_file)) 143 | 144 | input_names = input_names or ["input0"] 145 | output_names = output_names or ["output0"] 146 | 147 | dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}} 148 | if dynamic_size: 149 | dynamic_axes['input0'][2] = 'height' 150 | dynamic_axes['input0'][3] = 'width' 151 | 152 | if aten_fallback: 153 | export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK 154 | else: 155 | export_type = torch.onnx.OperatorExportTypes.ONNX 156 | 157 | if use_dynamo: 158 | export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_size) 159 | export_output = torch.onnx.dynamo_export( 160 | model, 161 | example_input, 162 | export_options=export_options, 163 | ) 164 | export_output.save(output_file) 165 | torch_out = None 166 | else: 167 | #TODO, for torch version >= 2.5, use torch.onnx.export() 168 | torch_out = torch.onnx._export( 169 | model, 170 | example_input, 171 | output_file, 172 | training=training_mode, 173 | export_params=True, 174 | verbose=verbose, 175 | input_names=input_names, 176 | output_names=output_names, 177 | keep_initializers_as_inputs=keep_initializers, 178 | dynamic_axes=dynamic_axes, 179 | opset_version=opset, 180 | operator_export_type=export_type 181 | ) 182 | 183 | if check: 184 | print("==> Loading and checking exported model from '{}'".format(output_file)) 185 | onnx_model = onnx.load(output_file) 186 | onnx.checker.check_model(onnx_model, full_check=True) # assuming throw on error 187 | if check_forward and not training: 188 | import numpy as np 189 | onnx_out = onnx_forward(output_file, example_input) 190 | if torch_out is not None: 191 | np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3) 192 | np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5) 193 | else: 194 | np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3) 195 | 196 | 197 | def main(): 198 | args = parser.parse_args() 199 | 200 | if args.output == None: 201 | args.output = f'./{args.model}.onnx' 202 | 203 | print("==> Creating PyTorch {} model".format(args.model)) 204 | 205 | 206 | model = create_model( 207 | args.model, 208 | num_classes=args.nb_classes, 209 | extra_attention_block=args.extra_attention_block, 210 | exportable=True 211 | ) 212 | 213 | model.load_state_dict(torch.load(args.checkpoint)['model']) 214 | model.eval() 215 | 216 | if args.reparam: 217 | model = reparameterize_model(model) 218 | 219 | onnx_export( 220 | model=model, 221 | output_file=args.output, 222 | opset=args.opset, 223 | dynamic_size=args.dynamic_size, 224 | aten_fallback=args.aten_fallback, 225 | keep_initializers=args.keep_init, 226 | check_forward=args.check_forward, 227 | training=args.training, 228 | verbose=args.verbose, 229 | use_dynamo=args.dynamo, 230 | input_size=(3, args.img_size, args.img_size), 231 | batch_size=args.batch_size, 232 | ) 233 | 234 | print("==> Passed") 235 | 236 | 237 | if __name__ == '__main__': 238 | main() 239 | -------------------------------------------------------------------------------- /onnx_optimise.py: -------------------------------------------------------------------------------- 1 | """ ONNX optimization script 2 | 3 | Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc. 
4 | 5 | NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 2.0.1 and ONNX 1.13), 6 | it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline). 7 | 8 | Copyright 2020 Ross Wightman 9 | """ 10 | import argparse 11 | import warnings 12 | 13 | import onnx 14 | import onnxoptimizer as optimizer 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Optimize ONNX model") 18 | 19 | parser.add_argument('--model', default='mobilenetv4_conv_large', type=str, metavar='MODEL', 20 | choices=['mobilenetv4_hybrid_large', 'mobilenetv4_hybrid_medium', 'mobilenetv4_hybrid_large_075', 21 | 'mobilenetv4_conv_large', 'mobilenetv4_conv_aa_large', 'mobilenetv4_conv_medium', 22 | 'mobilenetv4_conv_aa_medium', 'mobilenetv4_conv_small', 'mobilenetv4_hybrid_medium_075', 23 | 'mobilenetv4_conv_small_035', 'mobilenetv4_conv_small_050', 'mobilenetv4_conv_blur_medium'], 24 | help='Name of model to train') 25 | parser.add_argument("--output", default=None, help="The optimized model output filename") 26 | 27 | 28 | def traverse_graph(graph, prefix=''): 29 | content = [] 30 | indent = prefix + ' ' 31 | graphs = [] 32 | num_nodes = 0 33 | for node in graph.node: 34 | pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True) 35 | assert isinstance(gs, list) 36 | content.append(pn) 37 | graphs.extend(gs) 38 | num_nodes += 1 39 | for g in graphs: 40 | g_count, g_str = traverse_graph(g) 41 | content.append('\n' + g_str) 42 | num_nodes += g_count 43 | return num_nodes, '\n'.join(content) 44 | 45 | 46 | def main(): 47 | args = parser.parse_args() 48 | 49 | if args.output == None: 50 | args.output = f'./{args.model}_optim.onnx' 51 | 52 | args.model = f'./{args.model}.onnx' 53 | 54 | onnx_model = onnx.load(args.model) 55 | num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph) 56 | 57 | # Optimizer passes to perform 58 | passes = [ 59 | #'eliminate_deadend', 60 | 'eliminate_identity', 61 | 'eliminate_nop_dropout', 62 | 'eliminate_nop_pad', 63 | 'eliminate_nop_transpose', 64 | 'eliminate_unused_initializer', 65 | 'extract_constant_to_initializer', 66 | 'fuse_add_bias_into_conv', 67 | 'fuse_bn_into_conv', 68 | 'fuse_consecutive_concats', 69 | 'fuse_consecutive_reduce_unsqueeze', 70 | 'fuse_consecutive_squeezes', 71 | 'fuse_consecutive_transposes', 72 | #'fuse_matmul_add_bias_into_gemm', 73 | 'fuse_pad_into_conv', 74 | #'fuse_transpose_into_gemm', 75 | #'lift_lexical_references', 76 | ] 77 | 78 | # Apply the optimization on the original serialized model 79 | # WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing 80 | # 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401 81 | # It may be better to rely on onnxruntime optimizations, see onnx_validate.py script. 82 | warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX." 
83 | "Try onnxruntime optimization if this doesn't work.") 84 | optimized_model = optimizer.optimize(onnx_model, passes) 85 | 86 | num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph) 87 | print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str)) 88 | print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes)) 89 | 90 | # Save the ONNX model 91 | onnx.save(optimized_model, args.output) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() -------------------------------------------------------------------------------- /onnx_validate.py: -------------------------------------------------------------------------------- 1 | """ ONNX-runtime validation script 2 | 3 | This script was created to verify accuracy and performance of exported ONNX 4 | models running with the onnxruntime. It utilizes the PyTorch dataloader/processing 5 | pipeline for a fair comparison against the originals. 6 | 7 | Copyright 2020 Ross Wightman 8 | """ 9 | import argparse 10 | import numpy as np 11 | import torch 12 | import onnxruntime 13 | from util.utils import AverageMeter 14 | import time 15 | from datasets import MyDataset, build_transform, read_split_data 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Pytorch ONNX Validation') 19 | parser.add_argument('--data_root', default='D:/flower_data', type=str, 20 | help='path to datasets') 21 | parser.add_argument('--onnx-input', default='./mobilenetv4_conv_large_optim.onnx', type=str, metavar='PATH', 22 | help='path to onnx model/weights file') 23 | parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH', 24 | help='path to output optimized onnx graph') 25 | parser.add_argument('--profile', action='store_true', default=False, 26 | help='Enable profiler output.') 27 | parser.add_argument('--workers', default=2, type=int, metavar='N', 28 | help='number of data loading workers (default: 2)') 29 | parser.add_argument('--batch-size', default=16, type=int, 30 | metavar='N', help='mini-batch size (default: 16), as same as the train_batch_size in train_gpu.py') 31 | parser.add_argument('--img-size', default=384, type=int, 32 | metavar='N', help='Input image dimension, uses model default if empty') 33 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 34 | help='Override mean pixel value of datasets') 35 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 36 | help='Override std deviation of of datasets') 37 | parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', 38 | help='Override default crop pct of 0.875') 39 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 40 | help='Image resize interpolation type (overrides model)') 41 | parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', 42 | help='use tensorflow mnasnet preporcessing') 43 | parser.add_argument('--print-freq', '-p', default=10, type=int, 44 | metavar='N', help='print frequency (default: 10)') 45 | 46 | 47 | def main(): 48 | args = parser.parse_args() 49 | args.gpu_id = 0 50 | 51 | args.input_size = args.img_size 52 | 53 | # Set graph optimization level 54 | sess_options = onnxruntime.SessionOptions() 55 | sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 56 | if args.profile: 57 | sess_options.enable_profiling = True 58 | if args.onnx_output_opt: 59 | sess_options.optimized_model_filepath = 
args.onnx_output_opt 60 | 61 | session = onnxruntime.InferenceSession(args.onnx_input, sess_options) 62 | 63 | # data_config = resolve_data_config(None, args) 64 | val_set = build_dataset(args) 65 | 66 | loader = torch.utils.data.DataLoader( 67 | val_set, 68 | batch_size=args.batch_size, 69 | num_workers=args.workers, 70 | drop_last=False 71 | ) 72 | 73 | input_name = session.get_inputs()[0].name 74 | 75 | batch_time = AverageMeter() 76 | top1 = AverageMeter() 77 | top5 = AverageMeter() 78 | end = time.time() 79 | for i, (input, target) in enumerate(loader): 80 | # run the net and return prediction 81 | output = session.run([], {input_name: input.data.numpy()}) 82 | output = output[0] 83 | 84 | # measure accuracy and record loss 85 | prec1, prec5 = accuracy_np(output, target.numpy()) 86 | top1.update(prec1.item(), input.size(0)) 87 | top5.update(prec5.item(), input.size(0)) 88 | 89 | # measure elapsed time 90 | batch_time.update(time.time() - end) 91 | end = time.time() 92 | 93 | if i % args.print_freq == 0: 94 | print('Test: [{0}/{1}]\t' 95 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t' 96 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 97 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 98 | i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, 99 | ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5)) 100 | 101 | print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( 102 | top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) 103 | 104 | 105 | def accuracy_np(output, target): 106 | max_indices = np.argsort(output, axis=1)[:, ::-1] 107 | top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean() 108 | top1 = 100 * np.equal(max_indices[:, 0], target).mean() 109 | return top1, top5 110 | 111 | 112 | def build_dataset(args): 113 | train_image_path, train_image_label, val_image_path, val_image_label, class_indices = read_split_data(args.data_root) 114 | 115 | valid_transform = build_transform(False, args) 116 | 117 | valid_set = MyDataset(val_image_path, val_image_label, valid_transform) 118 | return valid_set 119 | 120 | 121 | if __name__ == '__main__': 122 | main() -------------------------------------------------------------------------------- /optim_AUC.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from functools import partial 3 | from scipy.optimize import fmin 4 | from sklearn import metrics 5 | 6 | 7 | def max_voting(preds): 8 | """ 9 | Create mean predictions 10 | :param probas: 2-d array of prediction values 11 | :return: max voted predictions 12 | """ 13 | 14 | ''' 15 | preds: np.array([[0, 2, 2, 2], [1, 1, 0, 1]]) 16 | return : [[2] 17 | [1]] 18 | ''' 19 | idxs = np.argmax(preds, axis=1) 20 | return np.take_along_axis(preds, idxs[:, None], axis=1) 21 | 22 | 23 | class OptimizeAUC: 24 | def __init__(self): 25 | self.coef_ = 0. 26 | 27 | def _auc(self, coef, outputs, labels): 28 | """ 29 | This functions calulates and returns AUC. 30 | :param coef: coef list, of the same length as number of models 31 | :param X: predictions, in this case a 2d array 32 | :param y: targets, in our case binary 1d array 33 | """ 34 | # multiply coefficients with every column of the array 35 | # with predictions. 
36 | # this means: element 1 of coef is multiplied by column 1 37 | # of the prediction array, element 2 of coef is multiplied 38 | # by column 2 of the prediction array and so on! 39 | 40 | x_coef = coef * outputs 41 | 42 | # create predictions by taking row wise sum 43 | predictions = x_coef / np.sum(x_coef, axis=1, keepdims=True) 44 | 45 | # calculate auc score 46 | auc_score = metrics.roc_auc_score(labels, predictions, average='weighted', multi_class='ovo') 47 | 48 | # return negative auc 49 | return -1.0 * auc_score 50 | 51 | 52 | def fit(self, X, y): 53 | # remember partial from hyperparameter optimization chapter? 54 | loss_partial = partial(self._auc, outputs=X, labels=y) 55 | 56 | # dirichlet distribution. you can use any distribution you want 57 | # to initialize the coefficients 58 | # we want the coefficients to sum to 1 59 | initial_coef = np.random.dirichlet(np.ones(X.shape[1]), size=1) 60 | # use scipy fmin to minimize the loss function, in our case auc 61 | self.coef_ = fmin(loss_partial, initial_coef, disp=True) 62 | 63 | def predict(self, X): 64 | # this is similar to _auc function 65 | x_coef = X * self.coef_ 66 | predictions = x_coef / np.sum(x_coef, axis=1, keepdims=True) 67 | return predictions -------------------------------------------------------------------------------- /prediction_probs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaowoguanren0615/MobileNetV4/ce474fdb8f500fc6dd477e2538e31539ff1be2f9/prediction_probs.png -------------------------------------------------------------------------------- /sample_png/mobilenetV4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaowoguanren0615/MobileNetV4/ce474fdb8f500fc6dd477e2538e31539ff1be2f9/sample_png/mobilenetV4.jpg -------------------------------------------------------------------------------- /train_gpu.py: -------------------------------------------------------------------------------- 1 | """ ImageNet Training Script 2 | 3 | This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet 4 | training results with some of the latest networks and training techniques. It favours canonical PyTorch 5 | and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed 6 | and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit. 
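Example launch (paths and hyper-parameters below are placeholders, adjust them to your dataset):
    python train_gpu.py --model mobilenetv4_conv_large --data_root /path/to/flower_data \
        --nb_classes 5 --batch-size 16 --epochs 5 --lr 1e-3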
7 | 8 | This script was started from an early version of the PyTorch ImageNet example 9 | (https://github.com/pytorch/examples/tree/master/imagenet) 10 | 11 | NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples 12 | (https://github.com/NVIDIA/apex/tree/master/examples/imagenet) 13 | 14 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 15 | """ 16 | import argparse 17 | import datetime 18 | import numpy as np 19 | import time 20 | import torch 21 | import torch.backends.cudnn as cudnn 22 | from torch.utils.tensorboard import SummaryWriter 23 | import json 24 | import os 25 | 26 | 27 | from pathlib import Path 28 | 29 | import timm 30 | from timm.data import Mixup 31 | from timm.models import create_model 32 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 33 | from timm.scheduler import create_scheduler 34 | from timm.optim import create_optimizer 35 | from timm.utils import NativeScaler, get_state_dict, ModelEma 36 | 37 | from models import * 38 | from safetensors.torch import load_file 39 | 40 | from util.samplers import RASampler 41 | from util import utils as utils 42 | from util.optimizer import SophiaG, MARS 43 | from util.engine import train_one_epoch, evaluate 44 | from util.losses import DistillationLoss 45 | 46 | from datasets import build_dataset 47 | from datasets.threeaugment import new_data_aug_generator 48 | 49 | from estimate_model import Predictor, Plot_ROC, OptAUC 50 | 51 | 52 | def get_args_parser(): 53 | parser = argparse.ArgumentParser( 54 | 'MobileNetV4 training and evaluation script', add_help=False) 55 | parser.add_argument('--batch-size', default=16, type=int) 56 | parser.add_argument('--epochs', default=5, type=int) 57 | parser.add_argument('--predict', default=True, type=bool, help='plot ROC curve and confusion matrix') 58 | parser.add_argument('--opt_auc', default=False, type=bool, help='Optimize AUC') 59 | 60 | # Model parameters 61 | parser.add_argument('--model', default='mobilenetv4_conv_large', type=str, metavar='MODEL', 62 | choices=['mobilenetv4_hybrid_large', 'mobilenetv4_hybrid_medium', 'mobilenetv4_hybrid_large_075', 63 | 'mobilenetv4_conv_large', 'mobilenetv4_conv_aa_large', 'mobilenetv4_conv_medium', 64 | 'mobilenetv4_conv_aa_medium', 'mobilenetv4_conv_small', 'mobilenetv4_hybrid_medium_075', 65 | 'mobilenetv4_conv_small_035', 'mobilenetv4_conv_small_050', 'mobilenetv4_conv_blur_medium'], 66 | help='Name of model to train') 67 | parser.add_argument('--extra_attention_block', default=False, type=bool, help='Add an extra attention block') 68 | parser.add_argument('--input-size', default=384, type=int, help='images input size') 69 | parser.add_argument('--model-ema', action='store_true') 70 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 71 | parser.set_defaults(model_ema=True) 72 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 73 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 74 | 75 | # Optimizer parameters 76 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 77 | help='Optimizer (default: "adamw"') 78 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 79 | help='Optimizer Epsilon (default: 1e-8)') 80 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 81 | help='Optimizer Betas (default: None, use opt default)') 82 | parser.add_argument('--clip-grad', type=float, default=0.02, 
metavar='NORM', 83 | help='Clip gradient norm (default: None, no clipping)') 84 | parser.add_argument('--clip-mode', type=str, default='agc', 85 | help='Gradient clipping mode. One of ("norm", "value", "agc")') 86 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 87 | help='SGD momentum (default: 0.9)') 88 | parser.add_argument('--weight-decay', type=float, default=0.025, 89 | help='weight decay (default: 0.025)') 90 | 91 | # Learning rate schedule parameters 92 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 93 | help='LR scheduler (default: "cosine"') 94 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', 95 | help='learning rate (default: 1e-3)') 96 | parser.add_argument('--adamw_lr', type=float, default=3e-3, metavar='AdamWLR', 97 | help='Using MARS optimizer, learning rate for adamw(default: 3e-3)') 98 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 99 | help='learning rate noise on/off epoch percentages') 100 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 101 | help='learning rate noise limit percent (default: 0.67)') 102 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 103 | help='learning rate noise std-dev (default: 1.0)') 104 | parser.add_argument('--warmup-lr', type=float, default=1e-4, metavar='LR', 105 | help='warmup learning rate (default: 1e-4)') 106 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 107 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 108 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 109 | help='epoch interval to decay LR') 110 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 111 | help='epochs to warmup LR, if scheduler supports') 112 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 113 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 114 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 115 | help='patience epochs for Plateau LR scheduler (default: 10') 116 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 117 | help='LR decay rate (default: 0.1)') 118 | 119 | # Augmentation parameters 120 | parser.add_argument('--ThreeAugment', action='store_true') 121 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 122 | help='Color jitter factor (default: 0.4)') 123 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 124 | help='Use AutoAugment policy. "v0" or "original". 
" + \ 125 | "(default: rand-m9-mstd0.5-inc1)'), 126 | parser.add_argument('--smoothing', type=float, default=0.1, 127 | help='Label smoothing (default: 0.1)') 128 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 129 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 130 | parser.add_argument('--repeated-aug', action='store_true') 131 | parser.add_argument('--no-repeated-aug', 132 | action='store_false', dest='repeated_aug') 133 | parser.set_defaults(repeated_aug=True) 134 | 135 | # Random Erase params 136 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 137 | help='Random erase prob (default: 0.25)') 138 | parser.add_argument('--remode', type=str, default='pixel', 139 | help='Random erase mode (default: "pixel")') 140 | parser.add_argument('--recount', type=int, default=1, 141 | help='Random erase count (default: 1)') 142 | parser.add_argument('--resplit', action='store_true', default=False, 143 | help='Do not random erase first (clean) augmentation split') 144 | 145 | # Mixup params 146 | parser.add_argument('--mixup', type=float, default=0.8, 147 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 148 | parser.add_argument('--cutmix', type=float, default=1.0, 149 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 150 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 151 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 152 | parser.add_argument('--mixup-prob', type=float, default=1.0, 153 | help='Probability of performing mixup or cutmix when either/both is enabled') 154 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 155 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 156 | parser.add_argument('--mixup-mode', type=str, default='batch', 157 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 158 | 159 | # Distillation parameters 160 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', 161 | help='Name of teacher model to train (default: "regnety_160"') 162 | parser.add_argument('--teacher-path', type=str, 163 | default='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth') 164 | parser.add_argument('--distillation-type', default='none', 165 | choices=['none', 'soft', 'hard'], type=str, help="") 166 | parser.add_argument('--distillation-alpha', 167 | default=0.5, type=float, help="") 168 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="") 169 | 170 | # Finetuning params 171 | parser.add_argument('--finetune', default='./models/model.safetensors', 172 | help='finetune from checkpoint') 173 | parser.add_argument('--freeze_layers', type=bool, default=False, help='freeze layers') 174 | parser.add_argument('--set_bn_eval', action='store_true', default=False, 175 | help='set BN layers to eval mode during finetuning.') 176 | 177 | # Dataset parameters 178 | parser.add_argument('--data_root', default='D:/flower_data', type=str, 179 | help='dataset path') 180 | parser.add_argument('--nb_classes', default=5, type=int, 181 | help='number classes of your dataset') 182 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], 183 | type=str, help='Image Net dataset path') 184 | parser.add_argument('--inat-category', default='name', 185 | choices=['kingdom', 'phylum', 'class', 'order', 186 | 'supercategory', 'family', 'genus', 'name'], 187 | type=str, help='semantic granularity') 188 | parser.add_argument('--output_dir', default='./output', 189 | help='path where to save, empty for no saving') 190 | parser.add_argument('--writer_output', default='./', 191 | help='path where to save SummaryWriter, empty for no saving') 192 | parser.add_argument('--device', default='cuda', 193 | help='device to use for training / testing') 194 | parser.add_argument('--seed', default=0, type=int) 195 | parser.add_argument('--resume', default='', help='resume from checkpoint') 196 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 197 | help='start epoch') 198 | parser.add_argument('--eval', action='store_true', 199 | help='Perform evaluation only') 200 | parser.add_argument('--dist-eval', action='store_true', 201 | default=False, help='Enabling distributed evaluation') 202 | parser.add_argument('--num_workers', default=0, type=int) 203 | parser.add_argument('--pin-mem', action='store_true', 204 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 205 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 206 | help='') 207 | parser.set_defaults(pin_mem=True) 208 | 209 | # training parameters 210 | parser.add_argument('--world_size', default=1, type=int, 211 | help='number of distributed processes') 212 | parser.add_argument('--local_rank', default=0, type=int) 213 | parser.add_argument('--dist_url', default='env://', 214 | help='url used to set up distributed training') 215 | parser.add_argument('--save_freq', default=1, type=int, 216 | help='frequency of model saving') 217 | return parser 218 | 219 | 220 | 221 | 222 | def main(args): 223 | print(args) 224 | utils.init_distributed_mode(args) 225 | 226 | if args.local_rank == 0: 227 | writer = SummaryWriter(os.path.join(args.writer_output, 'runs')) 228 | 229 | if args.distillation_type != 'none' and args.finetune and not args.eval: 230 | 
raise NotImplementedError( 231 | "Finetuning with distillation not yet supported") 232 | 233 | device = torch.device(args.device) 234 | 235 | # fix the seed for reproducibility 236 | seed = args.seed + utils.get_rank() 237 | torch.manual_seed(seed) 238 | np.random.seed(seed) 239 | # random.seed(seed) 240 | 241 | cudnn.benchmark = True 242 | 243 | dataset_train, dataset_val = build_dataset(args=args) 244 | 245 | if args.distributed: 246 | num_tasks = utils.get_world_size() 247 | global_rank = utils.get_rank() 248 | if args.repeated_aug: 249 | sampler_train = RASampler( 250 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 251 | ) 252 | else: 253 | sampler_train = torch.utils.data.DistributedSampler( 254 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 255 | ) 256 | if args.dist_eval: 257 | if len(dataset_val) % num_tasks != 0: 258 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 259 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 260 | 'equal num of samples per-process.') 261 | sampler_val = torch.utils.data.DistributedSampler( 262 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 263 | else: 264 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 265 | else: 266 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 267 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 268 | 269 | data_loader_train = torch.utils.data.DataLoader( 270 | dataset_train, sampler=sampler_train, 271 | batch_size=args.batch_size, 272 | num_workers=args.num_workers, 273 | pin_memory=args.pin_mem, 274 | drop_last=True, 275 | ) 276 | 277 | if args.ThreeAugment: 278 | data_loader_train.dataset.transform = new_data_aug_generator(args) 279 | 280 | data_loader_val = torch.utils.data.DataLoader( 281 | dataset_val, sampler=sampler_val, 282 | batch_size=int(1.5 * args.batch_size), 283 | num_workers=args.num_workers, 284 | pin_memory=args.pin_mem, 285 | drop_last=False 286 | ) 287 | 288 | mixup_fn = None 289 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 290 | if mixup_active: 291 | mixup_fn = Mixup( 292 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 293 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 294 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 295 | 296 | print(f"Creating model: {args.model}") 297 | 298 | model = create_model( 299 | args.model, 300 | extra_attention_block=args.extra_attention_block, 301 | args=args 302 | ) 303 | model.reset_classifier(num_classes=args.nb_classes) 304 | 305 | if args.finetune: 306 | if args.finetune.startswith('https'): 307 | checkpoint = torch.hub.load_state_dict_from_url( 308 | args.finetune, map_location='cpu', check_hash=True) 309 | else: 310 | checkpoint = utils.load_model(args.finetune, model) 311 | 312 | checkpoint_model = checkpoint 313 | # state_dict = model.state_dict() 314 | # new_state_dict = utils.map_safetensors(checkpoint_model, state_dict) 315 | 316 | for k in list(checkpoint_model.keys()): 317 | if 'classifier' in k: 318 | print(f"Removing key {k} from pretrained checkpoint") 319 | del checkpoint_model[k] 320 | 321 | msg = model.load_state_dict(checkpoint_model, strict=False) 322 | print(msg) 323 | 324 | if args.freeze_layers: 325 | for name, para in model.named_parameters(): 326 | if 'classifier' not in name: 327 | para.requires_grad_(False) 328 | # else: 329 | # print('training {}'.format(name)) 330 | if args.extra_attention_block: 331 | for name, para in model.extra_attention_block.named_parameters(): 332 | para.requires_grad_(True) 333 | 334 | model.to(device) 335 | 336 | model_ema = None 337 | if args.model_ema: 338 | # Important to create EMA model after cuda(), DP wrapper, and AMP but 339 | # before SyncBN and DDP wrapper 340 | model_ema = ModelEma( 341 | model, 342 | decay=args.model_ema_decay, 343 | device='cpu' if args.model_ema_force_cpu else '', 344 | resume='') 345 | 346 | model_without_ddp = model 347 | if args.distributed: 348 | model = torch.nn.parallel.DistributedDataParallel( 349 | model, device_ids=[args.gpu]) 350 | model_without_ddp = model.module 351 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 352 | print('number of params:', n_parameters) 353 | 354 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 355 | # args.lr = linear_scaled_lr 356 | # 357 | # print('*****************') 358 | # print('Initial LR is ', linear_scaled_lr) 359 | # print('*****************') 360 | 361 | # optimizer = create_optimizer(args, model_without_ddp) 362 | optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=2e-4, weight_decay=args.weight_decay) if args.finetune else create_optimizer(args, model_without_ddp) 363 | 364 | loss_scaler = NativeScaler() 365 | lr_scheduler, _ = create_scheduler(args, optimizer) 366 | 367 | criterion = LabelSmoothingCrossEntropy() 368 | 369 | if args.mixup > 0.: 370 | # smoothing is handled with mixup label transform 371 | criterion = SoftTargetCrossEntropy() 372 | elif args.smoothing: 373 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 374 | else: 375 | criterion = torch.nn.CrossEntropyLoss() 376 | 377 | teacher_model = None 378 | if args.distillation_type != 'none': 379 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 380 | print(f"Creating teacher model: {args.teacher_model}") 381 | teacher_model = create_model( 382 | args.teacher_model, 383 | pretrained=False, 384 | num_classes=args.nb_classes, 385 | 
global_pool='avg', 386 | ) 387 | if args.teacher_path.startswith('https'): 388 | checkpoint = torch.hub.load_state_dict_from_url( 389 | args.teacher_path, map_location='cpu', check_hash=True) 390 | else: 391 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 392 | teacher_model.load_state_dict(checkpoint['model']) 393 | teacher_model.to(device) 394 | teacher_model.eval() 395 | 396 | # wrap the criterion in our custom DistillationLoss, which 397 | # just dispatches to the original criterion if args.distillation_type is 398 | # 'none' 399 | criterion = DistillationLoss( 400 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 401 | ) 402 | 403 | max_accuracy = 0.0 404 | 405 | output_dir = Path(args.output_dir) 406 | if args.output_dir and utils.is_main_process(): 407 | with (output_dir / "model.txt").open("a") as f: 408 | f.write(str(model)) 409 | if args.output_dir and utils.is_main_process(): 410 | with (output_dir / "args.txt").open("a") as f: 411 | f.write(json.dumps(args.__dict__, indent=2) + "\n") 412 | if args.resume or os.path.exists(f'{args.output_dir}/{args.model}_best_checkpoint.pth'): 413 | args.resume = f'{args.output_dir}/{args.model}_best_checkpoint.pth' 414 | if args.resume.startswith('https'): 415 | checkpoint = torch.hub.load_state_dict_from_url( 416 | args.resume, map_location='cpu', check_hash=True) 417 | else: 418 | print("Loading local checkpoint at {}".format(args.resume)) 419 | checkpoint = torch.load(args.resume, map_location='cpu') 420 | msg = model_without_ddp.load_state_dict(checkpoint['model']) 421 | print(msg) 422 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 423 | 424 | optimizer.load_state_dict(checkpoint['optimizer']) 425 | for state in optimizer.state.values(): # load parameters to cuda 426 | for k, v in state.items(): 427 | if isinstance(v, torch.Tensor): 428 | state[k] = v.cuda() 429 | 430 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 431 | max_accuracy = checkpoint['best_score'] 432 | print(f'Now max accuracy is {max_accuracy}') 433 | args.start_epoch = checkpoint['epoch'] + 1 434 | if args.model_ema: 435 | utils._load_checkpoint_for_ema( 436 | model_ema, checkpoint['model_ema']) 437 | if 'scaler' in checkpoint: 438 | loss_scaler.load_state_dict(checkpoint['scaler']) 439 | if args.eval: 440 | # util.replace_batchnorm(model) # Users may choose whether to merge Conv-BN layers during eval 441 | print(f"Evaluating model: {args.model}") 442 | print(f'No Visualization') 443 | test_stats = evaluate(data_loader_val, model, device, None, None, args, visualization=False) 444 | print( 445 | f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" 446 | ) 447 | # print(model) 448 | print(f"Start training for {args.epochs} epochs") 449 | start_time = time.time() 450 | 451 | for epoch in range(args.start_epoch, args.epochs): 452 | if args.distributed: 453 | data_loader_train.sampler.set_epoch(epoch) 454 | 455 | train_stats = train_one_epoch( 456 | model, criterion, data_loader_train, 457 | optimizer, device, epoch, loss_scaler, 458 | args.clip_grad, args.clip_mode, model_ema, mixup_fn, 459 | # set_training_mode=args.finetune == '', # keep in eval mode during finetuning 460 | set_training_mode=True, 461 | set_bn_eval=args.set_bn_eval, # set bn to eval if finetune 462 | writer=writer, 463 | args=args 464 | ) 465 | 466 | lr_scheduler.step(epoch) 467 | 468 | test_stats = evaluate(data_loader_val, model, 
device, epoch, writer, args, visualization=True) 469 | print( 470 | f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 471 | 472 | if max_accuracy < test_stats["acc1"]: 473 | max_accuracy = test_stats["acc1"] 474 | if args.output_dir: 475 | ckpt_path = os.path.join(output_dir, f'{args.model}_best_checkpoint.pth') 476 | checkpoint_paths = [ckpt_path] 477 | print("Saving checkpoint to {}".format(ckpt_path)) 478 | for checkpoint_path in checkpoint_paths: 479 | utils.save_on_master({ 480 | 'model': model_without_ddp.state_dict(), 481 | 'optimizer': optimizer.state_dict(), 482 | 'lr_scheduler': lr_scheduler.state_dict(), 483 | 'epoch': epoch, 484 | 'best_score': max_accuracy, 485 | 'model_ema': get_state_dict(model_ema), 486 | 'scaler': loss_scaler.state_dict(), 487 | 'args': args, 488 | }, checkpoint_path) 489 | 490 | print(f'Max accuracy: {max_accuracy:.2f}%') 491 | 492 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 493 | **{f'test_{k}': v for k, v in test_stats.items()}, 494 | 'epoch': epoch, 495 | 'n_parameters': n_parameters} 496 | 497 | if args.output_dir and utils.is_main_process(): 498 | with (output_dir / "log.txt").open("a") as f: 499 | f.write(json.dumps(log_stats) + "\n") 500 | 501 | total_time = time.time() - start_time 502 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 503 | print('Training time {}'.format(total_time_str)) 504 | 505 | # plot ROC curve and confusion matrix 506 | if args.predict and utils.is_main_process(): 507 | model_predict = create_model( 508 | args.model, 509 | extra_attention_block=args.extra_attention_block, 510 | args=args 511 | ) 512 | 513 | model_predict.reset_classifier(num_classes=args.nb_classes) 514 | model_predict.to(device) 515 | print('*******************STARTING PREDICT*******************') 516 | Predictor(model_predict, data_loader_val, f'{args.output_dir}/{args.model}_best_checkpoint.pth', device) 517 | Plot_ROC(model_predict, data_loader_val, f'{args.output_dir}/{args.model}_best_checkpoint.pth', device) 518 | 519 | if args.opt_auc: 520 | OptAUC(model_predict, data_loader_val, f'{args.output_dir}/{args.model}_best_checkpoint.pth', device) 521 | 522 | 523 | if __name__ == '__main__': 524 | parser = argparse.ArgumentParser( 525 | 'MobileNetV4 training and evaluation script', parents=[get_args_parser()]) 526 | args = parser.parse_args() 527 | if args.output_dir: 528 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 529 | main(args) -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | from .engine import train_one_epoch, evaluate 2 | from .losses import DistillationLoss 3 | from .samplers import RASampler 4 | from .optimizer import SophiaG, MARS 5 | from .utils import * -------------------------------------------------------------------------------- /util/engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train and eval functions used in main.py 3 | """ 4 | import math 5 | import sys 6 | from typing import Iterable, Optional 7 | 8 | import torch 9 | 10 | from timm.data import Mixup 11 | from timm.utils import ModelEma, accuracy 12 | 13 | from .losses import DistillationLoss 14 | from util import utils as utils 15 | 16 | 17 | def set_bn_state(model): 18 | for m in model.modules(): 19 | if isinstance(m, torch.nn.modules.batchnorm._BatchNorm): 20 | m.eval() 21 | 22 | def 
train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 23 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 24 | device: torch.device, epoch: int, loss_scaler, 25 | clip_grad: float = 0, 26 | clip_mode: str = 'norm', 27 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 28 | set_training_mode=True, 29 | set_bn_eval=False, 30 | writer=None, 31 | args=None): 32 | """ 33 | Train the model for one epoch. 34 | 35 | Args: 36 | model (torch.nn.Module): The model to be trained. 37 | criterion (DistillationLoss): The loss function used for training. 38 | data_loader (Iterable): The data loader for the training data. 39 | optimizer (torch.optim.Optimizer): The optimizer used for training. 40 | device (torch.device): The device used for training (CPU or GPU). 41 | epoch (int): The current training epoch. 42 | loss_scaler: The object used for gradient scaling. 43 | clip_grad (float, optional): The maximum value for gradient clipping. Default is 0, which means no gradient clipping. 44 | clip_mode (str, optional): The mode for gradient clipping, can be 'norm' or 'value'. Default is 'norm'. 45 | model_ema (Optional[ModelEma], optional): The EMA (Exponential Moving Average) model for saving model weights. 46 | mixup_fn (Optional[Mixup], optional): The function used for Mixup data augmentation. 47 | set_training_mode (bool, optional): Whether to set the model to training mode. Default is True. 48 | set_bn_eval (bool, optional): Whether to set the batch normalization layers to evaluation mode. Default is False. 49 | writer (Optional[Any], optional): The object used for writing TensorBoard logs. 50 | args (Optional[Any], optional): Additional arguments. 51 | 52 | Returns: 53 | Dict[str, float]: A dictionary containing the average values of the training metrics. 
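Note: the forward and backward passes run under torch.cuda.amp.autocast, with loss_scaler handling gradient scaling and the optional clipping configured via clip_grad / clip_mode.
With the meters logged in this loop the returned keys are 'loss' and 'lr', each reported as its epoch-wide average.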
54 | """ 55 | 56 | 57 | model.train(set_training_mode) 58 | num_steps = len(data_loader) 59 | 60 | if set_bn_eval: 61 | set_bn_state(model) 62 | metric_logger = utils.MetricLogger(delimiter=" ") 63 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 64 | header = 'Epoch: [{}]'.format(epoch) 65 | print_freq = 50 66 | 67 | for idx, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 68 | samples = samples.to(device, non_blocking=True) 69 | targets = targets.to(device, non_blocking=True) 70 | 71 | if mixup_fn is not None: 72 | samples, targets = mixup_fn(samples, targets) 73 | 74 | with torch.cuda.amp.autocast(): 75 | outputs = model(samples) 76 | loss = criterion(samples, outputs, targets) 77 | 78 | loss_value = loss.item() 79 | 80 | if not math.isfinite(loss_value): 81 | print("Loss is {}, stopping training".format(loss_value)) 82 | sys.exit(1) 83 | 84 | optimizer.zero_grad() 85 | 86 | # this attribute is added by timm on one optimizer (adahessian) 87 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 88 | with torch.cuda.amp.autocast(): 89 | loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode, 90 | parameters=model.parameters(), create_graph=is_second_order) 91 | 92 | torch.cuda.synchronize() 93 | if model_ema is not None: 94 | model_ema.update(model) 95 | 96 | learning_rate = optimizer.param_groups[0]["lr"] 97 | metric_logger.update(loss=loss_value) 98 | metric_logger.update(lr=learning_rate) 99 | 100 | 101 | if idx % print_freq == 0: 102 | if args.local_rank == 0: 103 | iter_all_count = epoch * num_steps + idx 104 | writer.add_scalar('loss', loss, iter_all_count) 105 | # writer.add_scalar('grad_norm', grad_norm, iter_all_count) 106 | writer.add_scalar('lr', learning_rate, iter_all_count) 107 | 108 | # gather the stats from all processes 109 | metric_logger.synchronize_between_processes() 110 | print("Averaged stats:", metric_logger) 111 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 112 | 113 | 114 | @torch.inference_mode() 115 | def evaluate(data_loader: Iterable, model: torch.nn.Module, 116 | device: torch.device, epoch: int, 117 | writer, args, 118 | visualization=True): 119 | """ 120 | Evaluate the model for one epoch. 121 | 122 | Args: 123 | data_loader (Iterable): The data loader for the valid data. 124 | model (torch.nn.Module): The model to be evaluated. 125 | device (torch.device): The device used for training (CPU or GPU). 126 | epoch (int): The current training epoch. 127 | writer (Optional[Any], optional): The object used for writing TensorBoard logs. 128 | args (Optional[Any], optional): Additional arguments. 129 | visualization (bool, optional): Whether to use TensorBoard visualization. Default is True. 130 | 131 | Returns: 132 | Dict[str, float]: A dictionary containing the average values of the training metrics. 
133 | """ 134 | 135 | criterion = torch.nn.CrossEntropyLoss() 136 | 137 | metric_logger = utils.MetricLogger(delimiter=" ") 138 | header = 'Test:' 139 | # switch to evaluation mode 140 | model.eval() 141 | 142 | print_freq = 20 143 | for images, target in metric_logger.log_every(data_loader, print_freq, header): 144 | images = images.to(device, non_blocking=True) 145 | target = target.to(device, non_blocking=True) 146 | 147 | # compute output 148 | with torch.cuda.amp.autocast(): 149 | output = model(images) 150 | loss = criterion(output, target) 151 | 152 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 153 | 154 | batch_size = images.shape[0] 155 | metric_logger.update(loss=loss.item()) 156 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 157 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 158 | 159 | 160 | if visualization and args.local_rank == 0: 161 | writer.add_scalar('Acc@1', acc1.item(), epoch) 162 | writer.add_scalar('Acc@5', acc5.item(), epoch) 163 | 164 | # gather the stats from all processes 165 | metric_logger.synchronize_between_processes() 166 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 167 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 168 | 169 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 170 | -------------------------------------------------------------------------------- /util/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the knowledge distillation loss, proposed in deit 3 | """ 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | 8 | class DistillationLoss(torch.nn.Module): 9 | """ 10 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 11 | taking a teacher model prediction and using it as additional supervision. 12 | """ 13 | 14 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 15 | distillation_type: str, alpha: float, tau: float): 16 | super().__init__() 17 | self.base_criterion = base_criterion 18 | self.teacher_model = teacher_model 19 | assert distillation_type in ['none', 'soft', 'hard'] 20 | self.distillation_type = distillation_type 21 | self.alpha = alpha 22 | self.tau = tau 23 | 24 | def forward(self, inputs, outputs, labels): 25 | """ 26 | Args: 27 | inputs: The original inputs that are feed to the teacher model 28 | outputs: the outputs of the model to be trained. 
It is expected to be 29 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 30 | in the first position and the distillation predictions as the second output 31 | labels: the labels for the base criterion 32 | """ 33 | outputs_kd = None 34 | if not isinstance(outputs, torch.Tensor): 35 | # assume that the model outputs a tuple of [outputs, outputs_kd] 36 | outputs, outputs_kd = outputs 37 | base_loss = self.base_criterion(outputs, labels) 38 | if self.distillation_type == 'none': 39 | return base_loss 40 | 41 | if outputs_kd is None: 42 | raise ValueError("When knowledge distillation is enabled, the model is " 43 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 44 | "class_token and the dist_token") 45 | # don't backprop throught the teacher 46 | with torch.no_grad(): 47 | teacher_outputs = self.teacher_model(inputs) 48 | 49 | if self.distillation_type == 'soft': 50 | T = self.tau 51 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 52 | # with slight modifications 53 | distillation_loss = F.kl_div( 54 | F.log_softmax(outputs_kd / T, dim=1), 55 | F.log_softmax(teacher_outputs / T, dim=1), 56 | reduction='sum', 57 | log_target=True 58 | ) * (T * T) / outputs_kd.numel() 59 | elif self.distillation_type == 'hard': 60 | distillation_loss = F.cross_entropy( 61 | outputs_kd, teacher_outputs.argmax(dim=1)) 62 | 63 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 64 | return loss -------------------------------------------------------------------------------- /util/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.optim.optimizer import Optimizer 4 | from typing import List 5 | import math 6 | 7 | 8 | # optimizer = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99), rho=0.01, weight_decay=1e-1) 9 | # optimizer = MARS(model.parameters(), lr=args.lr, weight_decay = args.weight_decay, lr_1d=args.adamw_lr) 10 | 11 | __all__ = ['SophiaG', 'MARS'] 12 | 13 | 14 | def exists(val): 15 | return val is not None 16 | 17 | 18 | def update_fn(p, grad, exp_avg, exp_avg_sq, lr, wd, beta1, beta2, last_grad, eps, amsgrad, max_exp_avg_sq, step, gamma, 19 | mars_type, is_grad_2d, optimize_1d, lr_1d_factor, betas_1d, weight_decay_1d): 20 | # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para 21 | if optimize_1d or is_grad_2d: 22 | c_t = (grad - last_grad).mul(gamma * (beta1 / (1. - beta1))).add(grad) 23 | c_t_norm = torch.norm(c_t) 24 | if c_t_norm > 1.: 25 | c_t = c_t / c_t_norm 26 | exp_avg.mul_(beta1).add_(c_t, alpha=1. - beta1) 27 | if (mars_type == "mars-adamw") or (mars_type == "mars-shampoo" and not is_grad_2d): 28 | exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. 
- beta2) 29 | bias_correction1 = 1.0 - beta1 ** step 30 | bias_correction2 = 1.0 - beta2 ** step 31 | if amsgrad: 32 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 33 | denom = max_exp_avg_sq.sqrt().mul(1 / math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 34 | else: 35 | denom = exp_avg_sq.sqrt().mul(1 / math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 36 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.div(denom)) 37 | elif mars_type == "mars-lion": 38 | real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.sign()) 39 | elif mars_type == "mars-shampoo" and is_grad_2d: 40 | factor = max(1, grad.size(0) / grad.size(1)) ** 0.5 41 | real_update_tmp = NewtonSchulz(exp_avg.mul(1. / (1. - beta1)), eps=eps).mul(factor).add(wd, p.data).mul(-lr) 42 | p.data.add_(real_update_tmp) 43 | else: 44 | beta1_1d, beta2_1d = betas_1d 45 | exp_avg.mul_(beta1_1d).add_(grad, alpha=1. - beta1_1d) 46 | exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1. - beta2_1d) 47 | bias_correction1 = 1.0 - beta1_1d ** step 48 | bias_correction2 = 1.0 - beta2_1d ** step 49 | if amsgrad: 50 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 51 | denom = max_exp_avg_sq.sqrt().mul(1 / math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 52 | else: 53 | denom = exp_avg_sq.sqrt().mul(1 / math.sqrt(bias_correction2)).add(eps).mul(bias_correction1) 54 | real_update_tmp = -lr * lr_1d_factor * torch.mul(p.data, weight_decay_1d).add(exp_avg.div(denom)) 55 | p.data.add_(real_update_tmp) 56 | return exp_avg, exp_avg_sq 57 | 58 | 59 | class MARS(Optimizer): 60 | def __init__(self, params, lr=3e-3, betas=(0.95, 0.99), eps=1e-8, weight_decay=0., amsgrad=False, gamma=0.025, 61 | is_approx=True, mars_type="mars-adamw", optimize_1d=False, lr_1d=3e-3, betas_1d=(0.9, 0.95), 62 | weight_decay_1d=0.1): 63 | if not 0.0 <= lr: 64 | raise ValueError("Invalid learning rate: {}".format(lr)) 65 | if not 0.0 <= eps: 66 | raise ValueError("Invalid epsilon value: {}".format(eps)) 67 | if not 0.0 <= betas[0] < 1.0: 68 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 69 | if not 0.0 <= betas[1] < 1.0: 70 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 71 | assert mars_type in ["mars-adamw", "mars-lion", "mars-shampoo"], "MARS type not supported" 72 | defaults = dict(lr=lr, betas=betas, eps=eps, 73 | weight_decay=weight_decay, amsgrad=amsgrad, 74 | mars_type=mars_type, gamma=gamma, 75 | optimize_1d=optimize_1d, weight_decay_1d=weight_decay_1d) 76 | super(MARS, self).__init__(params, defaults) 77 | self.eps = eps 78 | self.update_fn = update_fn 79 | self.lr = lr 80 | self.weight_decay = weight_decay 81 | self.amsgrad = amsgrad 82 | self.step_num = 0 83 | self.is_approx = is_approx 84 | self.gamma = gamma 85 | self.mars_type = mars_type 86 | self.optimize_1d = optimize_1d 87 | self.lr_1d_factor = lr_1d / lr 88 | self.weight_decay_1d = weight_decay_1d 89 | self.betas_1d = betas_1d 90 | 91 | @torch.no_grad() 92 | def update_last_grad(self): 93 | if not self.is_approx: 94 | for group in self.param_groups: 95 | for p in group['params']: 96 | state = self.state[p] 97 | if "last_grad" not in state: 98 | state["last_grad"] = torch.zeros_like(p) 99 | state["last_grad"].zero_().add_(state["previous_grad"], alpha=1.0) 100 | 101 | @torch.no_grad() 102 | def update_previous_grad(self): 103 | if not self.is_approx: 104 | for group in self.param_groups: 105 | # print ("para name", len(group['params']), len(group['names']), group['names']) 106 | for p 
in group['params']: 107 | # import pdb 108 | # pdb.set_trace() 109 | if p.grad is None: 110 | print(p, "grad is none") 111 | continue 112 | state = self.state[p] 113 | if "previous_grad" not in state: 114 | state['previous_grad'] = torch.zeros_like(p) 115 | state['previous_grad'].zero_().add_(p.grad, alpha=1.0) 116 | 117 | def __setstate__(self, state): 118 | super(MARS, self).__setstate__(state) 119 | for group in self.param_groups: 120 | group.setdefault('amsgrad', False) 121 | 122 | @torch.no_grad() 123 | def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): 124 | """Performs a single optimization step. 125 | 126 | Arguments: 127 | closure (callable, optional): A closure that reevaluates the model 128 | and returns the loss. 129 | 130 | If using exact version, the example usage is as follows: 131 | previous_X, previous_Y = None, None 132 | for epoch in range(epochs): 133 | for X, Y in data_loader: 134 | if previous_X: 135 | logits, loss = model(X, Y) 136 | loss.backward() 137 | optimizer.update_previous_grad() 138 | optimizer.zero_grad(set_to_none=True) 139 | logits, loss = model(X, Y) 140 | loss.backward() 141 | optimizer.step(bs=bs) 142 | optimizer.zero_grad(set_to_none=True) 143 | optimizer.update_last_grad() 144 | iter_num += 1 145 | previous_X, previous_Y = X.clone(), Y.clone() 146 | """ 147 | if any(p is not None for p in [grads, output_params, scale, grad_norms]): 148 | raise RuntimeError( 149 | 'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.') 150 | 151 | loss = None 152 | if exists(closure): 153 | with torch.enable_grad(): 154 | loss = closure() 155 | real_update = 0 156 | real_update_wo_lr = 0 157 | gamma = self.gamma 158 | # import pdb 159 | # pdb.set_trace() 160 | for group in self.param_groups: 161 | for p in filter(lambda p: exists(p.grad), group['params']): 162 | if p.grad is None: 163 | continue 164 | grad = p.grad.data 165 | if grad.is_sparse: 166 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 167 | amsgrad = group['amsgrad'] 168 | 169 | state = self.state[p] 170 | # ('----- starting a parameter state', state.keys(), 'Length of state', len(state)) 171 | # State initialization 172 | if len(state) <= 1: 173 | state['step'] = 0 174 | # Exponential moving average of gradient values 175 | state['exp_avg'] = torch.zeros_like(p.data) 176 | # Last Gradient 177 | state['last_grad'] = torch.zeros_like(p) 178 | # state['previous_grad'] = torch.zeros_like(p) 179 | # Exponential moving average of squared gradient values 180 | state['exp_avg_sq'] = torch.zeros_like(p.data) 181 | if amsgrad: 182 | # Maintains max of all exp. moving avg. of sq. grad. 
values 183 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 184 | # import pdb 185 | # pdb.set_trace() 186 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 187 | last_grad = state['last_grad'] 188 | lr, wd, beta1, beta2 = group['lr'], group['weight_decay'], *group['betas'] 189 | if amsgrad: 190 | max_exp_avg_sq = state['max_exp_avg_sq'] 191 | else: 192 | max_exp_avg_sq = 0 193 | 194 | if 'step' in state: 195 | state['step'] += 1 196 | else: 197 | state['step'] = 1 198 | step = state['step'] 199 | is_grad_2d = (len(grad.shape) == 2) 200 | exp_avg, exp_avg_sq = self.update_fn( 201 | p, 202 | grad, 203 | exp_avg, 204 | exp_avg_sq, 205 | lr, 206 | wd, 207 | beta1, 208 | beta2, 209 | last_grad, 210 | self.eps, 211 | amsgrad, 212 | max_exp_avg_sq, 213 | step, 214 | gamma, 215 | mars_type=self.mars_type, 216 | is_grad_2d=is_grad_2d, 217 | optimize_1d=self.optimize_1d, 218 | lr_1d_factor=self.lr_1d_factor, 219 | betas_1d=self.betas_1d, 220 | weight_decay_1d=self.weight_decay if self.optimize_1d else self.weight_decay_1d 221 | ) 222 | if self.is_approx: 223 | state['last_grad'] = grad 224 | self.step_num = step 225 | 226 | return loss 227 | 228 | 229 | @torch.compile 230 | def NewtonSchulz(M, steps=5, eps=1e-7): 231 | a, b, c = (3.4445, -4.7750, 2.0315) 232 | X = M.bfloat16() / (M.norm() + eps) 233 | if M.size(0) > M.size(1): 234 | X = X.T 235 | for _ in range(steps): 236 | A = X @ X.T 237 | B = A @ X 238 | X = a * X + b * B + c * A @ B 239 | if M.size(0) > M.size(1): 240 | X = X.T 241 | return X.to(M.dtype) 242 | 243 | 244 | class SophiaG(Optimizer): 245 | def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho=0.04, 246 | weight_decay=1e-1, *, maximize: bool = False, 247 | capturable: bool = False): 248 | if not 0.0 <= lr: 249 | raise ValueError("Invalid learning rate: {}".format(lr)) 250 | if not 0.0 <= betas[0] < 1.0: 251 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 252 | if not 0.0 <= betas[1] < 1.0: 253 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 254 | if not 0.0 <= rho: 255 | raise ValueError("Invalid rho parameter at index 1: {}".format(rho)) 256 | if not 0.0 <= weight_decay: 257 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 258 | defaults = dict(lr=lr, betas=betas, rho=rho, 259 | weight_decay=weight_decay, 260 | maximize=maximize, capturable=capturable) 261 | super(SophiaG, self).__init__(params, defaults) 262 | 263 | def __setstate__(self, state): 264 | super().__setstate__(state) 265 | for group in self.param_groups: 266 | group.setdefault('maximize', False) 267 | group.setdefault('capturable', False) 268 | state_values = list(self.state.values()) 269 | step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step']) 270 | if not step_is_tensor: 271 | for s in state_values: 272 | s['step'] = torch.tensor(float(s['step'])) 273 | 274 | @torch.no_grad() 275 | def update_hessian(self): 276 | for group in self.param_groups: 277 | beta1, beta2 = group['betas'] 278 | for p in group['params']: 279 | if p.grad is None: 280 | continue 281 | state = self.state[p] 282 | 283 | if len(state) == 0: 284 | state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ 285 | if self.defaults['capturable'] else torch.tensor(0.) 
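# 'exp_avg' keeps an EMA of the gradients (updated in step()); 'hessian' keeps an EMA of the
# element-wise squared gradients and serves as the diagonal Hessian estimate used by Sophia.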
286 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 287 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) 288 | 289 | if 'hessian' not in state.keys(): 290 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) 291 | 292 | state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2) 293 | 294 | @torch.no_grad() 295 | def step(self, closure=None, bs=5120): 296 | loss = None 297 | if closure is not None: 298 | with torch.enable_grad(): 299 | loss = closure() 300 | 301 | for group in self.param_groups: 302 | params_with_grad = [] 303 | grads = [] 304 | exp_avgs = [] 305 | state_steps = [] 306 | hessian = [] 307 | beta1, beta2 = group['betas'] 308 | 309 | for p in group['params']: 310 | if p.grad is None: 311 | continue 312 | params_with_grad.append(p) 313 | 314 | if p.grad.is_sparse: 315 | raise RuntimeError('Hero does not support sparse gradients') 316 | grads.append(p.grad) 317 | state = self.state[p] 318 | # State initialization 319 | if len(state) == 0: 320 | state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ 321 | if self.defaults['capturable'] else torch.tensor(0.) 322 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 323 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) 324 | 325 | if 'hessian' not in state.keys(): 326 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) 327 | 328 | exp_avgs.append(state['exp_avg']) 329 | state_steps.append(state['step']) 330 | hessian.append(state['hessian']) 331 | 332 | if self.defaults['capturable']: 333 | bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs 334 | 335 | sophiag(params_with_grad, 336 | grads, 337 | exp_avgs, 338 | hessian, 339 | state_steps, 340 | bs=bs, 341 | beta1=beta1, 342 | beta2=beta2, 343 | rho=group['rho'], 344 | lr=group['lr'], 345 | weight_decay=group['weight_decay'], 346 | maximize=group['maximize'], 347 | capturable=group['capturable']) 348 | 349 | return loss 350 | 351 | 352 | def sophiag(params: List[Tensor], 353 | grads: List[Tensor], 354 | exp_avgs: List[Tensor], 355 | hessian: List[Tensor], 356 | state_steps: List[Tensor], 357 | capturable: bool = False, 358 | *, 359 | bs: int, 360 | beta1: float, 361 | beta2: float, 362 | rho: float, 363 | lr: float, 364 | weight_decay: float, 365 | maximize: bool): 366 | if not all(isinstance(t, torch.Tensor) for t in state_steps): 367 | raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") 368 | 369 | func = _single_tensor_sophiag 370 | 371 | func(params, 372 | grads, 373 | exp_avgs, 374 | hessian, 375 | state_steps, 376 | bs=bs, 377 | beta1=beta1, 378 | beta2=beta2, 379 | rho=rho, 380 | lr=lr, 381 | weight_decay=weight_decay, 382 | maximize=maximize, 383 | capturable=capturable) 384 | 385 | 386 | def _single_tensor_sophiag(params: List[Tensor], 387 | grads: List[Tensor], 388 | exp_avgs: List[Tensor], 389 | hessian: List[Tensor], 390 | state_steps: List[Tensor], 391 | *, 392 | bs: int, 393 | beta1: float, 394 | beta2: float, 395 | rho: float, 396 | lr: float, 397 | weight_decay: float, 398 | maximize: bool, 399 | capturable: bool): 400 | for i, param in enumerate(params): 401 | grad = grads[i] if not maximize else -grads[i] 402 | exp_avg = exp_avgs[i] 403 | hess = hessian[i] 404 | step_t = state_steps[i] 405 | 406 | if capturable: 407 | assert param.is_cuda and step_t.is_cuda and bs.is_cuda 408 | 409 | if torch.is_complex(param): 410 
| grad = torch.view_as_real(grad) 411 | exp_avg = torch.view_as_real(exp_avg) 412 | hess = torch.view_as_real(hess) 413 | param = torch.view_as_real(param) 414 | 415 | # update step 416 | step_t += 1 417 | 418 | # Perform stepweight decay 419 | param.mul_(1 - lr * weight_decay) 420 | 421 | # Decay the first and second moment running average coefficient 422 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 423 | 424 | if capturable: 425 | step = step_t 426 | step_size = lr 427 | step_size_neg = step_size.neg() 428 | 429 | ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1) 430 | param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) 431 | else: 432 | step = step_t.item() 433 | step_size_neg = - lr 434 | 435 | ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1) 436 | param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) -------------------------------------------------------------------------------- /util/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 11 | It ensures that different each augmented version of a sample will be visible to a 12 | different process (GPU) 13 | Heavily based on torch.util.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | if num_repeats < 1: 26 | raise ValueError("num_repeats should be greater than 0") 27 | self.dataset = dataset 28 | self.num_replicas = num_replicas 29 | self.rank = rank 30 | self.num_repeats = num_repeats 31 | self.epoch = 0 32 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas)) 33 | self.total_size = self.num_samples * self.num_replicas 34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 36 | self.shuffle = shuffle 37 | 38 | def __iter__(self): 39 | if self.shuffle: 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | indices = torch.randperm(len(self.dataset), generator=g) 44 | else: 45 | indices = torch.arange(start=0, end=len(self.dataset)) 46 | 47 | # add extra samples to make it evenly divisible 48 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist() 49 | padding_size: int = self.total_size - len(indices) 50 | if padding_size > 0: 51 | indices += indices[:padding_size] 52 | assert len(indices) == self.total_size 53 | 54 | # subsample 55 | indices = indices[self.rank:self.total_size:self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices[:self.num_selected_samples]) 59 | 60 | def __len__(self): 61 | return self.num_selected_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch 65 | 66 | 67 | 
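# Usage sketch (not executed here): this mirrors how train_gpu.py wires RASampler into its
# training DataLoader when --repeated-aug is enabled; `utils`, `args` and `dataset_train`
# are the objects from that script, and a distributed process group is assumed to be
# initialised already.
#
#   sampler_train = RASampler(dataset_train,
#                             num_replicas=utils.get_world_size(),
#                             rank=utils.get_rank(),
#                             shuffle=True)
#   data_loader_train = torch.utils.data.DataLoader(dataset_train,
#                                                   sampler=sampler_train,
#                                                   batch_size=args.batch_size,
#                                                   drop_last=True)
#   for epoch in range(args.start_epoch, args.epochs):
#       data_loader_train.sampler.set_epoch(epoch)  # deterministic reshuffle each epoch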
-------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc functions, including distributed helpers and model loaders 3 | Also include a model loader specified for finetuning EfficientViT 4 | """ 5 | import io 6 | import os 7 | import time 8 | from collections import defaultdict, deque 9 | import datetime 10 | import logging 11 | import torch 12 | import torch.nn as nn 13 | import torch.distributed as dist 14 | import safetensors 15 | 16 | 17 | 18 | logger = logging.getLogger() 19 | 20 | 21 | class AverageMeter: 22 | """Computes and stores the average and current value""" 23 | def __init__(self): 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | 39 | class SmoothedValue(object): 40 | """Track a series of values and provide access to smoothed values over a 41 | window or the global series average. 42 | """ 43 | 44 | def __init__(self, window_size=20, fmt=None): 45 | if fmt is None: 46 | fmt = "{median:.4f} ({global_avg:.4f})" 47 | self.deque = deque(maxlen=window_size) 48 | self.total = 0.0 49 | self.count = 0 50 | self.fmt = fmt 51 | 52 | def update(self, value, n=1): 53 | self.deque.append(value) 54 | self.count += n 55 | self.total += value * n 56 | 57 | def synchronize_between_processes(self): 58 | """ 59 | Warning: does not synchronize the deque! 60 | """ 61 | if not is_dist_avail_and_initialized(): 62 | return 63 | t = torch.tensor([self.count, self.total], 64 | dtype=torch.float64, device='cuda') 65 | dist.barrier() 66 | dist.all_reduce(t) 67 | t = t.tolist() 68 | self.count = int(t[0]) 69 | self.total = t[1] 70 | 71 | @property 72 | def median(self): 73 | d = torch.tensor(list(self.deque)) 74 | return d.median().item() 75 | 76 | @property 77 | def avg(self): 78 | d = torch.tensor(list(self.deque), dtype=torch.float32) 79 | return d.mean().item() 80 | 81 | @property 82 | def global_avg(self): 83 | return self.total / self.count 84 | 85 | @property 86 | def max(self): 87 | return max(self.deque) 88 | 89 | @property 90 | def value(self): 91 | return self.deque[-1] 92 | 93 | def __str__(self): 94 | return self.fmt.format( 95 | median=self.median, 96 | avg=self.avg, 97 | global_avg=self.global_avg, 98 | max=self.max, 99 | value=self.value) 100 | 101 | 102 | class MetricLogger(object): 103 | def __init__(self, delimiter="\t"): 104 | self.meters = defaultdict(SmoothedValue) 105 | self.delimiter = delimiter 106 | 107 | def update(self, **kwargs): 108 | for k, v in kwargs.items(): 109 | if isinstance(v, torch.Tensor): 110 | v = v.item() 111 | assert isinstance(v, (float, int)) 112 | self.meters[k].update(v) 113 | 114 | def __getattr__(self, attr): 115 | if attr in self.meters: 116 | return self.meters[attr] 117 | if attr in self.__dict__: 118 | return self.__dict__[attr] 119 | raise AttributeError("'{}' object has no attribute '{}'".format( 120 | type(self).__name__, attr)) 121 | 122 | def __str__(self): 123 | loss_str = [] 124 | for name, meter in self.meters.items(): 125 | loss_str.append( 126 | "{}: {}".format(name, str(meter)) 127 | ) 128 | return self.delimiter.join(loss_str) 129 | 130 | def synchronize_between_processes(self): 131 | for meter in self.meters.values(): 132 | meter.synchronize_between_processes() 133 | 134 | def 
add_meter(self, name, meter): 135 | self.meters[name] = meter 136 | 137 | def log_every(self, iterable, print_freq, header=None): 138 | i = 0 139 | if not header: 140 | header = '' 141 | start_time = time.time() 142 | end = time.time() 143 | iter_time = SmoothedValue(fmt='{avg:.4f}') 144 | data_time = SmoothedValue(fmt='{avg:.4f}') 145 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 146 | log_msg = [ 147 | header, 148 | '[{0' + space_fmt + '}/{1}]', 149 | 'eta: {eta}', 150 | '{meters}', 151 | 'time: {time}', 152 | 'data: {data}' 153 | ] 154 | if torch.cuda.is_available(): 155 | log_msg.append('max mem: {memory:.0f}') 156 | log_msg = self.delimiter.join(log_msg) 157 | MB = 1024.0 * 1024.0 158 | for obj in iterable: 159 | data_time.update(time.time() - end) 160 | yield obj 161 | iter_time.update(time.time() - end) 162 | if i % print_freq == 0 or i == len(iterable) - 1: 163 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 164 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 165 | if torch.cuda.is_available(): 166 | print(log_msg.format( 167 | i, len(iterable), eta=eta_string, 168 | meters=str(self), 169 | time=str(iter_time), data=str(data_time), 170 | memory=torch.cuda.max_memory_allocated() / MB)) 171 | else: 172 | print(log_msg.format( 173 | i, len(iterable), eta=eta_string, 174 | meters=str(self), 175 | time=str(iter_time), data=str(data_time))) 176 | i += 1 177 | end = time.time() 178 | total_time = time.time() - start_time 179 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 180 | print('{} Total time: {} ({:.4f} s / it)'.format( 181 | header, total_time_str, total_time / len(iterable))) 182 | 183 | def _load_checkpoint_for_ema(model_ema, checkpoint): 184 | """ 185 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 186 | """ 187 | mem_file = io.BytesIO() 188 | torch.save(checkpoint, mem_file) 189 | mem_file.seek(0) 190 | model_ema._load_checkpoint(mem_file) 191 | 192 | def setup_for_distributed(is_master): 193 | """ 194 | This function disables printing when not in master process 195 | """ 196 | import builtins as __builtin__ 197 | builtin_print = __builtin__.print 198 | 199 | def print(*args, **kwargs): 200 | force = kwargs.pop('force', False) 201 | if is_master or force: 202 | builtin_print(*args, **kwargs) 203 | 204 | __builtin__.print = print 205 | 206 | def is_dist_avail_and_initialized(): 207 | if not dist.is_available(): 208 | return False 209 | if not dist.is_initialized(): 210 | return False 211 | return True 212 | 213 | def get_world_size(): 214 | if not is_dist_avail_and_initialized(): 215 | return 1 216 | return dist.get_world_size() 217 | 218 | def get_rank(): 219 | if not is_dist_avail_and_initialized(): 220 | return 0 221 | return dist.get_rank() 222 | 223 | def is_main_process(): 224 | return get_rank() == 0 225 | 226 | def save_on_master(*args, **kwargs): 227 | if is_main_process(): 228 | torch.save(*args, **kwargs) 229 | 230 | def init_distributed_mode(args): 231 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 232 | args.rank = int(os.environ["RANK"]) 233 | args.world_size = int(os.environ['WORLD_SIZE']) 234 | args.gpu = int(os.environ['LOCAL_RANK']) 235 | elif 'SLURM_PROCID' in os.environ: 236 | args.rank = int(os.environ['SLURM_PROCID']) 237 | args.gpu = args.rank % torch.cuda.device_count() 238 | else: 239 | print('Not using distributed mode') 240 | args.distributed = False 241 | return 242 | 243 | args.distributed = True 244 | 245 | torch.cuda.set_device(args.gpu) 246 | 
args.dist_backend = 'nccl' 247 | print('| distributed init (rank {}): {}'.format( 248 | args.rank, args.dist_url), flush=True) 249 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 250 | world_size=args.world_size, rank=args.rank) 251 | torch.distributed.barrier() 252 | setup_for_distributed(args.rank == 0) 253 | 254 | 255 | def replace_batchnorm(net): 256 | for child_name, child in net.named_children(): 257 | if hasattr(child, 'fuse'): 258 | setattr(net, child_name, child.fuse()) 259 | elif isinstance(child, torch.nn.BatchNorm2d): 260 | setattr(net, child_name, torch.nn.Identity()) 261 | else: 262 | replace_batchnorm(child) 263 | 264 | def replace_layernorm(net): 265 | import apex 266 | for child_name, child in net.named_children(): 267 | if isinstance(child, torch.nn.LayerNorm): 268 | setattr(net, child_name, apex.normalization.FusedLayerNorm( 269 | child.weight.size(0))) 270 | else: 271 | replace_layernorm(child) 272 | 273 | def load_model(modelpath, model: nn.Module): 274 | ''' 275 | A function to load model from a checkpoint, which is used 276 | for fine-tuning on a different resolution. 277 | ''' 278 | if 'safetensors' in modelpath: 279 | checkpoint = safetensors.torch.load_file(modelpath) 280 | else: 281 | checkpoint = torch.load(modelpath, map_location='cpu') 282 | return checkpoint 283 | 284 | 285 | def map_safetensors(safetensor_ckpt, model_state_dict): 286 | ''' 287 | A function to load model from a safetensor file, which is used 288 | for fine-tuning on a different resolution. 289 | ''' 290 | safetensors_keys = list(safetensor_ckpt.keys()) 291 | key_mapping = {} 292 | mismatched_keys = [] 293 | 294 | for model_key in model_state_dict.keys(): 295 | # 尝试在 safetensors_keys 中找到与模型键类似的键 296 | for safetensor_key in safetensors_keys: 297 | if model_key.split('.')[-1] == safetensor_key.split('.')[-1] and \ 298 | model_state_dict[model_key].shape == safetensor_ckpt[safetensor_key].shape: 299 | key_mapping[model_key] = safetensor_key 300 | # print(f"Mapping model layer '{model_key}' to safetensors layer '{safetensor_key}'") 301 | break 302 | else: 303 | # 如果没有找到匹配的层,则记录为不匹配 304 | mismatched_keys.append(model_key) 305 | 306 | # 显示所有未匹配的模型键 307 | if mismatched_keys: 308 | print("\nUnmatched model keys:") 309 | for key in mismatched_keys: 310 | print(key) 311 | 312 | # 创建一个新的 state_dict,将 safetensors 文件中的权重映射到模型中 313 | mapped_state_dict = {} 314 | for model_key, safetensor_key in key_mapping.items(): 315 | mapped_state_dict[model_key] = safetensor_ckpt[safetensor_key] 316 | 317 | return mapped_state_dict -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import json 2 | import matplotlib 3 | import numpy as np 4 | 5 | from torchvision import transforms 6 | from PIL import Image 7 | 8 | import torch 9 | from matplotlib import pyplot as plt 10 | import os 11 | from timm.models import create_model 12 | 13 | import urllib.request 14 | 15 | 16 | device = 'cuda' 17 | 18 | 19 | def download_from_url(url, path=None, root="./"): 20 | if path is None: 21 | _, filename = os.path.split(url) 22 | root = os.path.abspath(root) 23 | path = os.path.join(root, filename) 24 | urllib.request.urlretrieve(url, path) 25 | print(f"Downloaded file to {path}") 26 | 27 | 28 | def load_class_names(json_path): 29 | with open(json_path, "r") as f: 30 | return list(json.load(f).values()) 31 | 32 | def preprocess_image(image_path): 33 | image = 
Image.open(image_path).convert("RGB") 34 | transform = transforms.Compose([ 35 | transforms.Resize((224, 224), interpolation=3), 36 | transforms.ToTensor(), 37 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalized 38 | ]) 39 | return transform(image) 40 | 41 | 42 | @torch.inference_mode() 43 | def predict_probs_for_image(model, image_path): 44 | image = preprocess_image(image_path).unsqueeze(0) # add batch dim 45 | model.eval() 46 | outputs = model(image.to(device)) 47 | probs = torch.nn.functional.softmax(outputs, dim=1).cpu() 48 | return (probs[0] * 100).tolist() 49 | 50 | 51 | def plot_probs(texts, probs, fig_ax, lang_type=None, save_path=None): 52 | # reverse the order to plot from top to bottom 53 | sorted_indices = np.argsort(probs) 54 | texts = np.array(texts)[sorted_indices] 55 | probs = np.array(probs)[sorted_indices] 56 | if fig_ax is None: 57 | fig, ax = plt.subplots(figsize=(6, 3)) 58 | else: 59 | fig, ax = fig_ax 60 | 61 | font_prop = matplotlib.font_manager.FontProperties( 62 | fname=lang_type_to_font_path(lang_type) 63 | ) 64 | ax.barh(texts, probs, color="darkslateblue", height=0.3) 65 | ax.barh(texts, 100 - probs, color="silver", height=0.3, left=probs) 66 | for bar, label, val in zip(ax.patches, texts, probs): 67 | ax.text( 68 | 0, 69 | bar.get_y() - bar.get_height(), 70 | label, 71 | color="black", 72 | ha="left", 73 | va="center", 74 | fontproperties=font_prop, 75 | ) 76 | ax.text( 77 | bar.get_x() + bar.get_width() + 1, 78 | bar.get_y() + bar.get_height() / 2, 79 | f"{val:.2f} %", 80 | fontweight="bold", 81 | ha="left", 82 | va="center", 83 | ) 84 | 85 | ax.axis("off") 86 | 87 | if save_path: 88 | plt.savefig(save_path, bbox_inches="tight") # 保存图片并移除多余空白 89 | print(f"Figure saved to {save_path}") 90 | 91 | 92 | def predict_probs_and_plot( 93 | model, image_path, texts, plot_image=True, fig_ax=None, lang_type=None 94 | ): 95 | if plot_image: 96 | fig, (ax_1, ax_2) = plt.subplots(1, 2, figsize=(12, 6)) 97 | image = Image.open(image_path).convert('RGB') 98 | ax_1.imshow(image) 99 | ax_1.axis("off") 100 | probs = predict_probs_for_image(model, image_path) 101 | plot_probs(texts, probs, (fig, ax_2), lang_type=lang_type, save_path='./prediction_probs.png') 102 | 103 | 104 | def lang_type_to_font_path(lang_type): 105 | mapping = { 106 | None: "https://cdn.jsdelivr.net/gh/notofonts/notofonts.github.io/fonts/NotoSans/hinted/ttf/NotoSans-Regular.ttf", 107 | "cjk": "https://github.com/notofonts/noto-cjk/raw/main/Sans/OTF/SimplifiedChinese/NotoSansCJKsc-Regular.otf", 108 | "devanagari": "https://cdn.jsdelivr.net/gh/notofonts/notofonts.github.io/fonts/NotoSansDevanagari/hinted/ttf/NotoSansDevanagari-Regular.ttf", 109 | "emoji": "https://github.com/MorbZ/OpenSansEmoji/raw/master/OpenSansEmoji.ttf", 110 | } 111 | return download_from_url(mapping[lang_type]) 112 | 113 | if __name__ == '__main__': 114 | model = create_model( 115 | 'mobilenetv4_conv_large' 116 | ) 117 | model.reset_classifier(num_classes=5) 118 | model.load_state_dict(torch.load('./output/mobilenetv4_conv_large_best_checkpoint.pth')['model']) 119 | model.to(device) 120 | 121 | texts = load_class_names('./classes_indices.json') 122 | # image_path = r'D:/flower_data/roses/1666341535_99c6f7509f_n.jpg' 123 | image_path = r'D:/flower_data/sunflowers/44079668_34dfee3da1_n.jpg' 124 | predict_probs_and_plot(model, image_path, texts) -------------------------------------------------------------------------------- /weight_converter.py: 
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def transpose_weights(weights):
4 |     if len(weights.shape) <= 1:
5 |         return weights
6 |     if len(weights.shape) == 2:
7 |         return weights.T
8 |     if len(weights.shape) == 3:
9 |         return np.transpose(weights, [2, 1, 0])
10 |     else:
11 |         raise ValueError("Unknown weights shape : {}".format(weights.shape))
12 | 
13 | """ PyTorch to TensorFlow conversion """
14 | 
15 | def get_pt_layers(pt_model):
16 |     layers = {}
17 |     state_dict = pt_model.state_dict() if not isinstance(pt_model, dict) else pt_model
18 |     for k, v in state_dict.items():
19 |         layer_name = '.'.join(k.split('.')[:-1])
20 |         if layer_name not in layers: layers[layer_name] = []
21 |         layers[layer_name].append(v.cpu().numpy())
22 |     return layers
23 | 
24 | def pt_convert_layer_weights(layer_weights):
25 |     new_weights = []
26 |     if len(layer_weights) < 4:
27 |         new_weights = layer_weights
28 |     elif len(layer_weights) == 4:
29 |         new_weights = layer_weights[:2] + [layer_weights[2] + layer_weights[3]]
30 |     elif len(layer_weights) == 5:
31 |         new_weights = layer_weights[:4]
32 |     elif len(layer_weights) == 8:
33 |         new_weights = layer_weights[:2] + [layer_weights[2] + layer_weights[3]]
34 |         new_weights += layer_weights[4:6] + [layer_weights[6] + layer_weights[7]]
35 |     else:
36 |         raise ValueError("Unknown weights length : {}\n  Shapes : {}".format(len(layer_weights), [tuple(v.shape) for v in layer_weights]))
37 | 
38 |     return [transpose_weights(w) for w in new_weights]
39 | 
40 | def pt_convert_model_weights(pt_model, tf_model, verbose = False):
41 |     pt_layers = get_pt_layers(pt_model)
42 |     converted_weights = []
43 |     for layer_name, layer_variables in pt_layers.items():
44 |         converted_variables = pt_convert_layer_weights(layer_variables) if 'embedding' not in layer_name else layer_variables
45 |         converted_weights += converted_variables
46 | 
47 |         if verbose:
48 |             print("Layer : {} \t {} \t {}".format(
49 |                 layer_name,
50 |                 [tuple(v.shape) for v in layer_variables],
51 |                 [tuple(v.shape) for v in converted_variables],
52 |             ))
53 | 
54 |     partial_transfert_learning(tf_model, converted_weights)
55 |     print("Weights converted successfully!")
56 | 
57 | 
58 | """ TensorFlow to PyTorch conversion """
59 | 
60 | def get_tf_layers(tf_model):
61 |     layers = {}
62 |     variables = tf_model.variables if not isinstance(tf_model, list) else tf_model
63 |     for v in variables:
64 |         layer_name = '/'.join(v.name.split('/')[:-1])
65 |         if layer_name not in layers: layers[layer_name] = []
66 |         layers[layer_name].append(v.numpy())
67 |     return layers
68 | 
69 | def tf_convert_layer_weights(layer_weights):
70 |     new_weights = []
71 |     if len(layer_weights) < 3 or len(layer_weights) == 4:
72 |         new_weights = layer_weights
73 |     elif len(layer_weights) == 3:
74 |         new_weights = layer_weights[:2] + [layer_weights[2] / 2., layer_weights[2] / 2.]
75 |     else:
76 |         raise ValueError("Unknown weights length : {}\n  Shapes : {}".format(len(layer_weights), [tuple(v.shape) for v in layer_weights]))
77 | 
78 |     return [transpose_weights(w) for w in new_weights]
79 | 
80 | 
81 | def tf_convert_model_weights(tf_model, pt_model, verbose = False):
82 |     import torch
83 | 
84 |     pt_layers = pt_model.state_dict()
85 |     tf_layers = get_tf_layers(tf_model)
86 |     converted_weights = []
87 |     for layer_name, layer_variables in tf_layers.items():
88 |         converted_variables = tf_convert_layer_weights(layer_variables) if 'embedding' not in layer_name else layer_variables
89 |         converted_weights += converted_variables
90 | 
91 |         if verbose:
92 |             print("Layer : {} \t {} \t {}".format(
93 |                 layer_name,
94 |                 [tuple(v.shape) for v in layer_variables],
95 |                 [tuple(v.shape) for v in converted_variables],
96 |             ))
97 | 
98 |     tf_idx = 0
99 |     for i, (pt_name, pt_weights) in enumerate(pt_layers.items()):
100 |         if len(pt_weights.shape) == 0: continue
101 | 
102 |         pt_weights.data = torch.from_numpy(converted_weights[tf_idx])
103 |         tf_idx += 1
104 | 
105 |     pt_model.load_state_dict(pt_layers)
106 |     print("Weights converted successfully!")
107 | 
108 | """ Partial transfer learning """
109 | 
110 | def partial_transfert_learning(target_model,
111 |                                pretrained_model,
112 |                                partial_transfert = True,
113 |                                partial_initializer = 'normal_conditionned'
114 |                               ):
115 |     """
116 |     Perform transfer learning on a model with either:
117 |         - a different number of layers (and the same shapes for some layers)
118 |         - different shapes (and the same number of layers)
119 | 
120 |     Arguments:
121 |         - target_model : tf.keras.Model instance (the model the weights are transferred to)
122 |         - pretrained_model : tf.keras.Model or list of weights (pretrained)
123 |         - partial_transfert : whether to do a partial transfer for layers with different shapes (only relevant if the 2 models have the same number of layers)
124 |     """
125 |     assert partial_initializer in (None, 'zeros', 'ones', 'normal', 'normal_conditionned')
126 |     def partial_weight_transfert(target, pretrained_v):
127 |         v = target
128 |         if partial_initializer == 'zeros':
129 |             v = np.zeros_like(target)
130 |         elif partial_initializer == 'ones':
131 |             v = np.ones_like(target)
132 |         elif partial_initializer == 'normal_conditionned':
133 |             v = np.random.normal(loc = np.mean(pretrained_v), scale = np.std(pretrained_v), size = target.shape)
134 |         elif partial_initializer == 'normal':
135 |             v = np.random.normal(size = target.shape)
136 | 
137 | 
138 |         if v.ndim == 1:
139 |             max_0 = min(v.shape[0], pretrained_v.shape[0])
140 |             v[:max_0] = pretrained_v[:max_0]
141 |         elif v.ndim == 2:
142 |             max_0 = min(v.shape[0], pretrained_v.shape[0])
143 |             max_1 = min(v.shape[1], pretrained_v.shape[1])
144 |             v[:max_0, :max_1] = pretrained_v[:max_0, :max_1]
145 |         elif v.ndim == 3:
146 |             max_0 = min(v.shape[0], pretrained_v.shape[0])
147 |             max_1 = min(v.shape[1], pretrained_v.shape[1])
148 |             max_2 = min(v.shape[2], pretrained_v.shape[2])
149 |             v[:max_0, :max_1, :max_2] = pretrained_v[:max_0, :max_1, :max_2]
150 |         elif v.ndim == 4:
151 |             max_0 = min(v.shape[0], pretrained_v.shape[0])
152 |             max_1 = min(v.shape[1], pretrained_v.shape[1])
153 |             max_2 = min(v.shape[2], pretrained_v.shape[2])
154 |             max_3 = min(v.shape[3], pretrained_v.shape[3])
155 |             v[:max_0, :max_1, :max_2, :max_3] = pretrained_v[:max_0, :max_1, :max_2, :max_3]
156 |         else:
157 |             raise ValueError("Variables with more than 4 dims are not supported!")
158 | 
159 |         return v
160 | 
161 |     target_variables = target_model.variables
162 |     pretrained_variables = pretrained_model.variables if not isinstance(pretrained_model, list) else pretrained_model
163 | 
164 |     skip_layer = len(target_variables) != len(pretrained_variables)
165 |     skip_from_a = None
166 |     if skip_layer:
167 |         skip_from_a = (len(target_variables) > len(pretrained_variables))
168 | 
169 |     new_weights = []
170 |     idx_a, idx_b = 0, 0
171 |     while idx_a < len(target_variables) and idx_b < len(pretrained_variables):
172 |         v, pretrained_v = target_variables[idx_a], pretrained_variables[idx_b]
173 |         v = v.numpy()
174 |         if not isinstance(pretrained_v, np.ndarray): pretrained_v = pretrained_v.numpy()
175 | 
176 |         if v.shape != pretrained_v.shape and skip_layer:
177 |             if skip_from_a:
178 |                 idx_a += 1
179 |                 new_weights.append(v)
180 |             else: idx_b += 1
181 |             continue
182 | 
183 |         if len(v.shape) != len(pretrained_v.shape):
184 |             raise ValueError("The variables {} do not have the same number of dimensions!\n  Target shape : {}\n  Pretrained shape : {}".format(idx_a, v.shape, pretrained_v.shape))
185 | 
186 |         new_v = None
187 |         if v.shape == pretrained_v.shape:
188 |             new_v = pretrained_v
189 |         elif not partial_transfert:
190 |             print("Variable {} shapes mismatch ({} vs {}), skipping it".format(idx_a, v.shape, pretrained_v.shape))
191 | 
192 |             new_v = v
193 |         else:
194 |             print("Variable {} shapes mismatch ({} vs {}), making partial transfer".format(idx_a, v.shape, pretrained_v.shape))
195 | 
196 |             new_v = partial_weight_transfert(v, pretrained_v)
197 | 
198 |         new_weights.append(new_v)
199 |         idx_a, idx_b = idx_a + 1, idx_b + 1
200 | 
201 |     if idx_a != len(target_variables) or idx_b != len(pretrained_variables):
202 |         raise ValueError("Not all variables of the models have been consumed\n  Model A : length : {} - variables consumed : {}\n  Model B (pretrained) : length : {} - variables consumed : {}".format(len(target_variables), idx_a, len(pretrained_variables), idx_b))
203 | 
204 |     target_model.set_weights(new_weights)
205 |     print("Weights transferred successfully!")
206 | 
--------------------------------------------------------------------------------
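Usage sketch (illustrative only, not a file in the repository): the two entry points in weight_converter.py are meant to be called with a loaded PyTorch model and a tf.keras.Model whose layer layout matches it. The snippet below assumes the MobileNetV4 checkpoint used in visualize.py and leaves the TensorFlow model as a placeholder, so it shows the call pattern rather than a runnable end-to-end script.

import torch
from timm.models import create_model
from weight_converter import pt_convert_model_weights, tf_convert_model_weights

# Build and load the PyTorch model (same pattern as visualize.py).
pt_model = create_model('mobilenetv4_conv_large')
pt_model.reset_classifier(num_classes=5)
pt_model.load_state_dict(
    torch.load('./output/mobilenetv4_conv_large_best_checkpoint.pth', map_location='cpu')['model']
)

# Placeholder: a tf.keras.Model with a layer layout matching pt_model (user-provided).
tf_model = ...

# PyTorch -> TensorFlow: converts the weights layer by layer, then copies them
# into tf_model via partial_transfert_learning (shape mismatches are transferred partially).
pt_convert_model_weights(pt_model, tf_model, verbose=True)

# TensorFlow -> PyTorch: rebuilds the state_dict from the TF variables.
tf_convert_model_weights(tf_model, pt_model, verbose=True)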