├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── evaluate.py ├── fire ├── __init__.py ├── data.py ├── dataaug_user.py ├── datatools.py ├── loss.py ├── metrics.py ├── model.py ├── models │ ├── convnext.py │ ├── mobilenetv3.py │ ├── myefficientnet_pytorch │ │ ├── __init__.py │ │ ├── model.py │ │ └── utils.py │ └── swin │ │ ├── __init__.py │ │ ├── build.py │ │ ├── config.py │ │ ├── configs │ │ ├── swin_base_patch4_window12_384_22kto1k_finetune.yaml │ │ ├── swin_base_patch4_window12_384_finetune.yaml │ │ ├── swin_base_patch4_window7_224.yaml │ │ ├── swin_base_patch4_window7_224_22k.yaml │ │ ├── swin_base_patch4_window7_224_22kto1k_finetune.yaml │ │ ├── swin_large_patch4_window12_384_22kto1k_finetune.yaml │ │ ├── swin_large_patch4_window7_224_22k.yaml │ │ ├── swin_large_patch4_window7_224_22kto1k_finetune.yaml │ │ ├── swin_mlp_base_patch4_window7_224.yaml │ │ ├── swin_mlp_tiny_c12_patch4_window8_256.yaml │ │ ├── swin_mlp_tiny_c24_patch4_window8_256.yaml │ │ ├── swin_mlp_tiny_c6_patch4_window8_256.yaml │ │ ├── swin_small_patch4_window7_224.yaml │ │ ├── swin_tiny_c24_patch4_window8_256.yaml │ │ └── swin_tiny_patch4_window7_224.yaml │ │ ├── swin_mlp.py │ │ └── swin_transformer.py ├── runner.py ├── runnertools.py ├── scheduler.py └── utils.py ├── predict.py ├── requirements.txt ├── scripts ├── cleanData.py ├── convert_onnx.py ├── heatmap.py ├── make_fashionmnist.py ├── nohup_train.sh └── predictTTA.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.jpg 3 | *.png 4 | *.zip 5 | *.rar 6 | *.log 7 | *.onnx 8 | *.bak 9 | *.pth 10 | 11 | /fire/__pycache__ 12 | /data 13 | /pretrained 14 | /output 15 | 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Mr.Fire 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FireClassification: Deep Learning Image Classification for lazy humans 2 | 3 | [![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/fire717/Fire/blob/main/LICENSE) 4 | ## 1. Introduction 5 | FireClassification is a deep learning framework written in Python for image classification tasks, running on top of the machine learning platform PyTorch. 6 | 7 | Read the source code as documentation. 8 | 9 | ## 2. Usage examples 10 | 11 | First, git clone this project. 12 | 13 | ### 2.1 Training 14 | 1. Download the four [fashion mnist](https://github.com/zalandoresearch/fashion-mnist) dataset archives into the ./data directory, then run `python scripts/make_fashionmnist.py` to extract the images and split them into class folders and a validation set 15 | 2. Run `python train.py` to train 16 | 3. Run `python evaluate.py` to evaluate (set the trained model path in config.py) 17 | 18 | ### 2.2 Tuning 19 | * Transfer learning: download the pretrained weights for the chosen model and fill the path into config.py 20 | * Try different models, input sizes, optimizers and so on 21 | 22 | ### 2.3 Custom network structure 23 | Simply modify the corresponding code in fire/model.py. 24 | 25 | ## 3. Features 26 | ### 3.1 Data loading 27 | * Folder-per-class layout 28 | * CSV label files 29 | * Other custom formats require manual code changes 30 | 31 | ### 3.2 Supported networks 32 | 33 | * The ResNet, DenseNet and VGG families, plus every other network supported by [pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch) 34 | * [Mobilenetv2](https://pytorch.org/docs/stable/torchvision/models.html?highlight=mobilenet#torchvision.models.mobilenet_v2), [Mobilenetv3](https://github.com/kuan-wang/pytorch-mobilenet-v3), ShuffleNetV2 35 | * [EfficientNet](https://github.com/lukemelas/EfficientNet-PyTorch) 36 | * [Swin Transformer](https://github.com/microsoft/Swin-Transformer) 37 | * [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) 38 | * [All models in the TIMM library](https://github.com/huggingface/pytorch-image-models) 39 | 40 | 41 | 42 | ### 3.3 Optimizers 43 | * Adam 44 | * SGD 45 | * AdaBelief 46 | * AdamW 47 | 48 | ### 3.4 Learning-rate schedules 49 | * ReduceLROnPlateau 50 | * StepLR 51 | * MultiStepLR 52 | * SGDR 53 | 54 | ### 3.5 Loss functions 55 | * Cross-entropy 56 | * Focal loss 57 | 58 | ### 3.6 Other 59 | * Metrics (acc, F1) 60 | * Training log saving 61 | * Cross-validation 62 | * Gradient clipping 63 | * Early stopping 64 | * Weight decay 65 | * Class labels from folder names or from a CSV file 66 | * Freeze/unfreeze the feature layers (everything except the final fully connected layer) 67 | * Label smoothing 68 | 69 | 70 | 71 | ## 4. Updates 72 | * 2023.9 [v1.1] Cleaned up the code, removed some unused features, replaced some third-party dependencies with in-house implementations, fixed bugs, simplified the code, and changed the save paths 73 | * 2022.7 [v1.0] (Based on half a year of competition experience, added some things and removed rarely used ones.) Added ConvNeXt, Swin Transformer and half-precision training; removed MobileFormer; removed logging and TensorBoard (I prefer keeping notes in documents); improved the README 74 | * 2021.8 [v0.9] Added MicroNet with test results, and added RK3399 speed benchmarks 75 | * 2021.8 [v0.8] Added MobileFormer, added a fashion mnist demo to make it easy to test different models, and added training results for some networks 76 | 77 | ## 5. To Do 78 | * Improve the README 79 | * Add usage documentation 80 | * Fully separate the user-customizable parts of the code 81 | 82 | ## 6. References 83 | 1. [albumentations](https://github.com/albumentations-team/albumentations) 84 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # @https://github.com/fire717/Fire 2 | 3 | cfg = { 4 | ### Global Set 5 | "model_name": "resnet50", 6 | 'GPU_ID': '0', 7 | "class_number": 10, 8 | "class_names": [], #list of class-name strings, or [] to read labels from folder names 9 | 10 | "random_seed":42, 11 | "cfg_verbose":True, 12 | "num_workers":4, 13 | 14 | 15 | ### Train Setting 16 | 'train_path':"./data/train", 17 | 'val_path':"./data/val", #if '' then use k-fold split (k_flod) 18 | 'pretrained':'', #path or '' 19 | 20 | 21 | 'try_to_train_items': 0, # 0 means all, or run part(200 e.g.)
for bug test 22 | 'save_best_only': True, #only save model if better than before 23 | 'save_one_only':True, #only save one best model (will del model before) 24 | "save_dir": "output/", 25 | 'metrics': ['acc'], # default acc, can add F1 ... 26 | "loss": 'CE', 27 | 28 | 'show_heatmap':False, 29 | 'show_data':False, 30 | 31 | 32 | ### Train Hyperparameters 33 | "img_size": [224, 224], # [h, w] 34 | 'learning_rate':0.001, 35 | 'batch_size':64, 36 | 'epochs':100, 37 | 'optimizer':'Adam', #Adam SGD AdaBelief Ranger 38 | 'scheduler':'default-0.1-3', #default SGDR-5-2 step-4-0.8 39 | 40 | 'warmup_epoch':0, # 41 | 'weight_decay' : 0,#0.0001, 42 | "k_flod":5, 43 | 'val_fold':0, 44 | 'early_stop_patient':7, 45 | 46 | 'use_distill':0, 47 | 'label_smooth':0, 48 | # 'checkpoint':None, 49 | 'class_weight': None,#s[1.4, 0.78], # None [1, 1] 50 | 'clip_gradient': 0,#1, # 0 51 | 'freeze_nonlinear_epoch':0, 52 | 53 | 54 | 'mixup':False, 55 | 'cutmix':False, 56 | 'sample_weights':None, 57 | 58 | 59 | ### Test 60 | 'model_path':'output/exp2/best.pt',#test model 61 | 62 | 'eval_path':"./data/test",#test with label,get eval result 63 | 'test_path':"./data/test",#test without label, just show img result 64 | 65 | 'TTA':False, 66 | 'merge':False, 67 | 'test_batch_size': 1, 68 | 69 | 70 | } 71 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os,argparse 2 | import random 3 | 4 | from fire import initFire, FireModel, FireRunner, FireData 5 | 6 | from config import cfg 7 | 8 | 9 | 10 | 11 | def main(cfg): 12 | 13 | 14 | initFire(cfg) 15 | 16 | 17 | model = FireModel(cfg) 18 | 19 | 20 | 21 | data = FireData(cfg) 22 | # data.showTrainData() 23 | # b 24 | 25 | _, val_loader = data.getTrainValDataloader() 26 | 27 | 28 | runner = FireRunner(cfg, model) 29 | 30 | 31 | runner.modelLoad(cfg['model_path']) 32 | 33 | 34 | runner.evaluate(val_loader) 35 | 36 | 37 | 38 | if __name__ == '__main__': 39 | main(cfg) -------------------------------------------------------------------------------- /fire/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from fire.utils import initFire 3 | from fire.model import FireModel 4 | from fire.runner import FireRunner 5 | from fire.data import FireData 6 | 7 | 8 | -------------------------------------------------------------------------------- /fire/data.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import numpy as np 5 | 6 | import cv2 7 | from torchvision import transforms 8 | 9 | from fire.utils import firelog 10 | from fire.datatools import getDataLoader, getFileNames 11 | from fire.dataaug_user import TrainDataAug 12 | 13 | 14 | 15 | class FireData(): 16 | def __init__(self, cfg): 17 | self.cfg = cfg 18 | 19 | 20 | def getTrainValDataloader(self): 21 | 22 | class_names = self.cfg['class_names'] 23 | if len(class_names)==0: 24 | class_names = os.listdir(self.cfg['train_path']) 25 | class_names.sort() 26 | firelog("i", class_names) 27 | 28 | train_data = [] 29 | for i,class_name in enumerate(class_names): 30 | sub_dir = os.path.join(self.cfg['train_path'],class_name) 31 | img_path_list = getFileNames(sub_dir) 32 | img_path_list.sort() 33 | train_data += [[p,i] for p in img_path_list] 34 | random.shuffle(train_data) 35 | 36 | if self.cfg['val_path'] != '': 37 | firelog('i',"val_path is not none, not use kflod to split 
train-val data ...") 38 | 39 | val_data = [] 40 | for i,class_name in enumerate(class_names): 41 | sub_dir = os.path.join(self.cfg['val_path'],class_name) 42 | img_path_list = getFileNames(sub_dir) 43 | img_path_list.sort() 44 | val_data += [[p,i] for p in img_path_list] 45 | 46 | else: 47 | firelog('i',"val_path is none, use kflod to split data: k=%d val_fold=%d" % (self.cfg['k_flod'],self.cfg['val_fold'])) 48 | all_data = train_data 49 | 50 | fold_count = int(len(all_data)/self.cfg['k_flod']) 51 | if self.cfg['val_fold']==self.cfg['k_flod']: 52 | train_data = all_data 53 | val_data = all_data[:10] 54 | else: 55 | val_data = all_data[fold_count*self.cfg['val_fold']:fold_count*(self.cfg['val_fold']+1)] 56 | train_data = all_data[:fold_count*self.cfg['val_fold']]+all_data[fold_count*(self.cfg['val_fold']+1):] 57 | 58 | if self.cfg['try_to_train_items'] > 0: 59 | train_data = train_data[:self.cfg['try_to_train_items']] 60 | val_data = val_data[:self.cfg['try_to_train_items']] 61 | 62 | firelog('i',"Train: %d Val: %d " % (len(train_data),len(val_data))) 63 | input_data = [train_data, val_data] 64 | 65 | train_loader, val_loader = getDataLoader("trainval", 66 | input_data, 67 | self.cfg) 68 | return train_loader, val_loader 69 | 70 | 71 | def getTestDataloader(self): 72 | data_names = getFileNames(self.cfg['test_path']) 73 | print("total ",len(data_names)) 74 | input_data = [data_names] 75 | data_loader = getDataLoader("test", 76 | input_data, 77 | self.cfg) 78 | return data_loader 79 | 80 | 81 | def showTrainData(self, show_num = 200): 82 | #show train data finally to exam 83 | 84 | show_dir = "show_img" 85 | show_path = os.path.join(self.cfg['save_dir'], show_dir) 86 | firelog('i',"Showing traing data in ",show_path) 87 | if not os.path.exists(show_path): 88 | os.makedirs(show_path) 89 | 90 | 91 | img_path_list = getFileNames(self.cfg['train_path'])[:show_num] 92 | transform = transforms.Compose([TrainDataAug(self.cfg['img_size'])]) 93 | 94 | 95 | for i,img_path in enumerate(img_path_list): 96 | #print(i) 97 | img = cv2.imread(img_path) 98 | img = transform(img) 99 | img.save(os.path.join(show_path,os.path.basename(img_path)), quality=100) 100 | 101 | -------------------------------------------------------------------------------- /fire/dataaug_user.py: -------------------------------------------------------------------------------- 1 | 2 | from PIL import Image 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | import torch 7 | from torch.utils.data.dataset import Dataset 8 | import torchvision.transforms as transforms 9 | import torchvision.transforms.functional as F 10 | import random 11 | import cv2 12 | import albumentations as A 13 | import json 14 | import platform 15 | 16 | 17 | 18 | 19 | ###### 1.Data aug 20 | class TrainDataAug: 21 | def __init__(self, img_size): 22 | self.h = img_size[0] 23 | self.w = img_size[1] 24 | 25 | def __call__(self, img): 26 | # opencv img, BGR 27 | # new_width, new_height = self.size[0], self.size[1] 28 | img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 29 | # raw_h, raw_w = img.shape[:2] 30 | # min_size = max(img.shape[:2]) 31 | 32 | 33 | 34 | # img = A.OneOf([A.ShiftScaleRotate( 35 | # shift_limit=0.1, 36 | # scale_limit=0.1, 37 | # rotate_limit=30, 38 | # interpolation=cv2.INTER_LINEAR, 39 | # border_mode=cv2.BORDER_CONSTANT, 40 | # value=0, mask_value=0, 41 | # p=0.5), 42 | # A.GridDistortion(num_steps=5, distort_limit=0.2, 43 | # interpolation=1, border_mode=4, p=0.4), 44 | # A.RandomGridShuffle(grid=(3, 3), p=0.3)], 45 | # 
p=0.5)(image=img)['image'] 46 | 47 | # img = A.HorizontalFlip(p=0.5)(image=img)['image'] 48 | # img = A.VerticalFlip(p=0.4)(image=img)['image'] 49 | 50 | # # img = A.OneOf([A.RandomBrightnessContrast(brightness_limit=0.05, 51 | # # contrast_limit=0.05, p=0.5), 52 | # # A.HueSaturationValue(hue_shift_limit=10, 53 | # # sat_shift_limit=10, val_shift_limit=10, p=0.5)], 54 | # # p=0.4)(image=img)['image'] 55 | 56 | 57 | # # img = A.GaussNoise(var_limit=(5.0, 10.0), mean=0, p=0.2)(image=img)['image'] 58 | 59 | 60 | # img = A.RGBShift(r_shift_limit=5, 61 | # g_shift_limit=5, 62 | # b_shift_limit=5, 63 | # p=0.5)(image=img)['image'] 64 | 65 | 66 | # img = A.Resize(self.h,self.w,cv2.INTER_LANCZOS4,p=1)(image=img)['image'] 67 | # img = A.OneOf([A.GaussianBlur(blur_limit=3, p=0.1), 68 | # A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.5), 69 | # A.GaussNoise(var_limit=(10.0, 50.0), mean=0, p=0.4)], 70 | # p=0.4)(image=img)['image'] 71 | 72 | # img = A.CoarseDropout(max_holes=3, max_height=20, max_width=20, 73 | # p=0.8)(image=img)['image'] 74 | 75 | 76 | 77 | 78 | #img = Image.fromarray(img) 79 | return img 80 | 81 | 82 | class TestDataAug: 83 | def __init__(self, img_size): 84 | self.h = img_size[0] 85 | self.w = img_size[1] 86 | 87 | def __call__(self, img): 88 | # opencv img, BGR 89 | # new_width, new_height = self.size[0], self.size[1] 90 | 91 | img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 92 | 93 | 94 | # img = A.Resize(self.h,self.w,cv2.INTER_LANCZOS4,p=1)(image=img)['image'] 95 | #img = Image.fromarray(img) 96 | return img 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /fire/datatools.py: -------------------------------------------------------------------------------- 1 | 2 | from PIL import Image 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | import torch 7 | from torch.utils.data.dataset import Dataset 8 | import torchvision.transforms as transforms 9 | import torchvision.transforms.functional as F 10 | import random 11 | import cv2 12 | import albumentations as A 13 | import json 14 | import platform 15 | 16 | from fire.utils import firelog 17 | from fire.dataaug_user import TrainDataAug, TestDataAug 18 | 19 | 20 | ##### Common 21 | def getFileNames(file_dir, tail_list=['.png','.jpg','.JPG','.PNG']): 22 | L=[] 23 | for root, dirs, files in os.walk(file_dir): 24 | for file in files: 25 | if os.path.splitext(file)[1] in tail_list: 26 | L.append(os.path.join(root, file)) 27 | return L 28 | 29 | 30 | 31 | 32 | ######## dataloader 33 | 34 | class TensorDatasetTrainClassify(Dataset): 35 | def __init__(self, data, cfg, transform=None): 36 | self.data = data 37 | self.cfg = cfg 38 | self.transform = transform 39 | 40 | 41 | def __getitem__(self, index): 42 | 43 | img = cv2.imread(self.data[index][0]) 44 | img = cv2.resize(img, self.cfg['img_size']) 45 | 46 | if self.transform is not None: 47 | img = self.transform(img) 48 | 49 | y = self.data[index][1] 50 | 51 | # y_onehot = [0,0] 52 | # y_onehot[y] = 1 53 | 54 | return img, y, self.data[index] 55 | 56 | def __len__(self): 57 | return len(self.data) 58 | 59 | 60 | class TensorDatasetTestClassify(Dataset): 61 | 62 | def __init__(self, data, cfg, transform=None): 63 | self.data = data 64 | self.cfg = cfg 65 | self.transform = transform 66 | 67 | def __getitem__(self, index): 68 | 69 | img = cv2.imread(self.data[index]) 70 | img = cv2.resize(img, self.cfg['img_size']) 71 | #img = imgPaddingWrap(img) 72 | #b 73 | if self.transform is not None: 74 | img = 
self.transform(img) 75 | 76 | # path_dir = '/'.join(self.data[index].split('/')[:-1]) 77 | # y = 0 78 | # if 'true' in path_dir: 79 | # y = 1 80 | 81 | return img, self.data[index] 82 | 83 | def __len__(self): 84 | return len(self.data) 85 | 86 | 87 | ###### 3. get data loader 88 | 89 | 90 | def getNormorlize(model_name): 91 | if model_name in ['mobilenetv2','mobilenetv3']: 92 | my_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 93 | elif model_name == 'xception': 94 | my_normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 95 | elif "adv-eff" in model_name: 96 | my_normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0) 97 | elif "resnex" in model_name or 'eff' in model_name or 'RegNet' in model_name: 98 | my_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 99 | #my_normalize = transforms.Normalize([0.4783, 0.4559, 0.4570], [0.2566, 0.2544, 0.2522]) 100 | elif "EN-B" in model_name: 101 | my_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 102 | else: 103 | firelog("i","Not set normalize type, Use defalut imagenet normalization.") 104 | my_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 105 | return my_normalize 106 | 107 | 108 | def getDataLoader(mode, input_data, cfg): 109 | 110 | my_normalize = getNormorlize(cfg['model_name']) 111 | 112 | 113 | 114 | data_aug_train = TrainDataAug(cfg['img_size']) 115 | data_aug_test = TestDataAug(cfg['img_size']) 116 | 117 | 118 | if mode=="test": 119 | my_dataloader = TensorDatasetTestClassify 120 | 121 | test_loader = torch.utils.data.DataLoader( 122 | my_dataloader(input_data[0], 123 | cfg, 124 | transforms.Compose([ 125 | data_aug_test, 126 | transforms.ToTensor(), 127 | my_normalize 128 | ]) 129 | ), batch_size=cfg['test_batch_size'], shuffle=False, 130 | num_workers=cfg['num_workers'], pin_memory=True 131 | ) 132 | 133 | return test_loader 134 | 135 | 136 | elif mode=="trainval": 137 | my_dataloader = TensorDatasetTrainClassify 138 | 139 | train_loader = torch.utils.data.DataLoader( 140 | my_dataloader(input_data[0], 141 | cfg, 142 | transforms.Compose([ 143 | data_aug_train, 144 | #ImageNetPolicy(), #autoaug 145 | #Augmentation(fa_resnet50_rimagenet()), #fastaa 146 | transforms.ToTensor(), 147 | my_normalize, 148 | ])), 149 | batch_size=cfg['batch_size'], 150 | shuffle=True, 151 | num_workers=cfg['num_workers'], 152 | pin_memory=True) 153 | 154 | 155 | val_loader = torch.utils.data.DataLoader( 156 | my_dataloader(input_data[1], 157 | cfg, 158 | transforms.Compose([ 159 | data_aug_test, 160 | transforms.ToTensor(), 161 | my_normalize 162 | ])), 163 | batch_size=cfg['batch_size'], 164 | shuffle=False, 165 | num_workers=cfg['num_workers'], 166 | pin_memory=True) 167 | return train_loader, val_loader 168 | -------------------------------------------------------------------------------- /fire/loss.py: -------------------------------------------------------------------------------- 1 | 2 | # import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | ########################### loss 8 | 9 | def labelSmooth(one_hot, label_smooth): 10 | return one_hot*(1-label_smooth)+label_smooth/one_hot.shape[1] 11 | 12 | 13 | class CrossEntropyLossOneHot(nn.Module): 14 | def __init__(self): 15 | super(CrossEntropyLossOneHot, self).__init__() 16 | self.log_softmax = nn.LogSoftmax(dim=-1) 17 | 18 | def forward(self, preds, labels): 19 | return torch.mean(torch.sum(-labels * 
self.log_softmax(preds), -1)) 20 | 21 | 22 | class CrossEntropyLossV2(nn.Module): 23 | def __init__(self, label_smooth=0, class_weight=None): 24 | super().__init__() 25 | self.class_weight = class_weight 26 | self.label_smooth = label_smooth 27 | self.epsilon = 1e-7 28 | 29 | def forward(self, x, y, label_smooth=0, gamma=0, sample_weights=None, sample_weight_img_names=None): 30 | 31 | #one_hot_label = F.one_hot(y, x.shape[1]) 32 | one_hot_label = y 33 | if label_smooth: 34 | one_hot_label = labelSmooth(one_hot_label, label_smooth) 35 | 36 | #y_pred = F.log_softmax(x, dim=1) 37 | # equal below two lines 38 | y_softmax = F.softmax(x, 1) 39 | #print(y_softmax) 40 | y_softmax = torch.clamp(y_softmax, self.epsilon, 1.0-self.epsilon)# avoid nan 41 | y_softmaxlog = torch.log(y_softmax) 42 | 43 | # original CE loss 44 | loss = -one_hot_label * y_softmaxlog 45 | 46 | if class_weight: 47 | loss = loss*self.class_weight 48 | 49 | #focal loss gamma 50 | if gamma: 51 | loss = loss*((1-y_softmax)**gamma) 52 | 53 | loss = torch.mean(torch.sum(loss, -1)) 54 | 55 | return loss 56 | 57 | 58 | class CrossEntropyLoss(nn.Module): 59 | def __init__(self, label_smooth=0, class_weight=None, gamma=0): 60 | super().__init__() 61 | self.class_weight = class_weight #means alpha 62 | self.label_smooth = label_smooth 63 | self.gamma = gamma 64 | self.epsilon = 1e-7 65 | 66 | def forward(self, x, y, sample_weights=0, sample_weight_img_names=None): 67 | 68 | one_hot_label = F.one_hot(y, x.shape[1]) 69 | 70 | if self.label_smooth: 71 | one_hot_label = labelSmooth(one_hot_label, self.label_smooth) 72 | 73 | #y_pred = F.log_softmax(x, dim=1) 74 | # equal below two lines 75 | y_softmax = F.softmax(x, 1) 76 | #print(y_softmax) 77 | y_softmax = torch.clamp(y_softmax, self.epsilon, 1.0-self.epsilon)# avoid nan 78 | y_softmaxlog = torch.log(y_softmax) 79 | 80 | # original CE loss 81 | loss = -one_hot_label * y_softmaxlog 82 | 83 | if self.class_weight: 84 | loss = loss*self.class_weight 85 | 86 | if self.gamma: 87 | loss = loss*((1-y_softmax)**self.gamma) 88 | 89 | loss = torch.mean(torch.sum(loss, -1)) 90 | return loss 91 | 92 | 93 | class FocalLoss(nn.Module): 94 | def __init__(self, label_smooth=0, gamma = 0., weight=None): 95 | super().__init__() 96 | self.gamma = gamma 97 | self.weight = weight # means alpha 98 | self.epsilon = 1e-7 99 | self.label_smooth = label_smooth 100 | 101 | 102 | def forward(self, x, y, sample_weights=0, sample_weight_img_names=None): 103 | 104 | if len(y.shape) == 1: 105 | # 106 | one_hot_label = F.one_hot(y, x.shape[1]) 107 | 108 | if self.label_smooth: 109 | one_hot_label = labelSmooth(one_hot_label, self.label_smooth) 110 | 111 | if sample_weights>0 and sample_weights is not None: 112 | #print(sample_weight_img_names) 113 | weigths = [sample_weights if 'yxboard' in img_name else 1 for img_name in sample_weight_img_names] 114 | weigths = torch.DoubleTensor(weigths).reshape((len(weigths),1)).to(x.device) 115 | #print(weigths, weigths.shape) 116 | #print(one_hot_label, one_hot_label.shape) 117 | one_hot_label = one_hot_label*weigths 118 | #print(one_hot_label) 119 | #b 120 | else: 121 | one_hot_label = y 122 | 123 | 124 | #y_pred = F.log_softmax(x, dim=1) 125 | # equal below two lines 126 | y_softmax = F.softmax(x, 1) 127 | #print(y_softmax) 128 | y_softmax = torch.clamp(y_softmax, self.epsilon, 1.0-self.epsilon)# avoid nan 129 | y_softmaxlog = torch.log(y_softmax) 130 | 131 | #print(y_softmaxlog) 132 | # original CE loss 133 | loss = -one_hot_label * y_softmaxlog 134 | #loss = 1 * 
torch.abs(one_hot_label-y_softmax)#my new CE..ok its L1... 135 | 136 | # print(one_hot_label) 137 | # print(y_softmax) 138 | # print(one_hot_label-y_softmax) 139 | # print(torch.abs(y-y_softmax)) 140 | #print(loss) 141 | 142 | # gamma 143 | loss = loss*((torch.abs(one_hot_label-y_softmax))**self.gamma) 144 | # print(loss) 145 | 146 | # alpha 147 | if self.weight is not None: 148 | loss = self.weight*loss 149 | 150 | loss = torch.mean(torch.sum(loss, -1)) 151 | return loss 152 | 153 | 154 | 155 | 156 | 157 | if __name__ == '__main__': 158 | 159 | 160 | 161 | device = torch.device("cpu") 162 | 163 | #x = torch.randn(2,2) 164 | x = torch.tensor([[0.1,0.7,0.2]]) 165 | y = torch.tensor([1]) 166 | print(x) 167 | 168 | loss_func = torch.nn.CrossEntropyLoss().to(device) 169 | loss = loss_func(x,y) 170 | print("loss1: ",loss) 171 | 172 | # loss_func = Focalloss().to(device) 173 | # loss = loss_func(x,y) 174 | # print("loss2: ",loss) 175 | 176 | 177 | weight_loss = torch.DoubleTensor([1,1,1]).to(device) 178 | loss_func = FocalLoss(gamma=0, weight=weight_loss).to(device) 179 | loss = loss_func(x,y) 180 | print("loss3: ",loss) 181 | 182 | 183 | # weight_loss = torch.DoubleTensor([2,1]).to(device) 184 | # loss_func = Focalloss(gamma=0.2, weight=weight_loss).to(device) 185 | # loss = loss_func(x,y) 186 | # print("loss4: ",loss) -------------------------------------------------------------------------------- /fire/metrics.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import numpy as np 5 | 6 | ### acc 7 | 8 | ### F1 9 | def getF1(pres, labels): 10 | 11 | count_all = len(labels) 12 | 13 | tp = 0 14 | fp = 0 15 | fn = 0 16 | 17 | for i in range(count_all): 18 | #print(pres[i][0], labels[i]) 19 | if pres[i][0] > 0.5: 20 | if labels[i] == 0: 21 | tp += 1 22 | else: 23 | fp += 1 24 | else: 25 | if labels[i] != 1: 26 | fn += 1 27 | 28 | 29 | # print(pres.shape, labels.shape) 30 | # print(pres[0]) 31 | # print(labels[0]) 32 | precision = tp/(tp+fp+1e-7) 33 | recall = tp/(tp+fn+1e-7) 34 | 35 | f1_score = 2*recall*precision / (recall+precision+1e-7) 36 | return precision, recall, f1_score 37 | 38 | 39 | def getMF1(pres, labels): 40 | count_all,class_num = pres.shape 41 | 42 | tp_list = [0 for _ in range(class_num)] 43 | fp_list = [0 for _ in range(class_num)] 44 | fn_list = [0 for _ in range(class_num)] 45 | 46 | for i in range(count_all): 47 | pre_id = np.argmax(pres[i]) 48 | gt_id = labels[i] 49 | if pre_id == gt_id: 50 | tp_list[gt_id] += 1 51 | else: 52 | fp_list[pre_id] += 1 53 | fn_list[gt_id] += 1 54 | 55 | f1_list = [] 56 | p_list = [] 57 | r_list = [] 58 | for i in range(class_num): 59 | tp = tp_list[i] 60 | fp = fp_list[i] 61 | fn = fn_list[i] 62 | precision = tp/(tp+fp+1e-7) 63 | recall = tp/(tp+fn+1e-7) 64 | 65 | f1_score = 2*recall*precision / (recall+precision+1e-7) 66 | 67 | p_list.append(precision) 68 | r_list.append(recall) 69 | f1_list.append(f1_score) 70 | 71 | precision = np.mean(p_list) 72 | recall = np.mean(r_list) 73 | f1_score = np.mean(f1_list) 74 | 75 | return precision, recall, f1_score 76 | 77 | ### mAP 78 | def vocAP(rec, prec, use_07_metric=False): 79 | """ ap = vocAP(rec, prec, [use_07_metric]) 80 | Compute VOC AP given precision and recall. 81 | If use_07_metric is true, uses the 82 | VOC 07 11 point method (default:False). 83 | """ 84 | if use_07_metric: #VOC changed its evaluation method after 2010; this flag selects the 2007 version 85 | # 11 point metric 86 | ap = 0.
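# Worked example of the 11-point method: with cumulative tp = [1, 1, 2], fp = [0, 1, 1]
# and npos = 2, we get rec = [0.5, 0.5, 1.0] and prec = [1.0, 0.5, 0.67]; the maximum
# precision at recall >= t is 1.0 for t in {0.0, ..., 0.5} and 0.67 for t in {0.6, ..., 1.0},
# so ap = (6 * 1.0 + 5 * 0.67) / 11 ≈ 0.85.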
87 | for t in np.arange(0., 1.1, 0.1): # 07年的采用11个点平分recall来计算 88 | if np.sum(rec >= t) == 0: 89 | p = 0 90 | else: 91 | p = np.max(prec[rec >= t]) # 取一个recall阈值之后最大的precision 92 | ap = ap + p / 11. # 将11个precision加和平均 93 | else: # 这里是用2010年后的方法,取所有不同的recall对应的点处的精度值做平均,不再是固定的11个点 94 | # correct AP calculation 95 | # first append sentinel values at the end 96 | mrec = np.concatenate(([0.], rec, [1.])) #recall和precision前后分别加了一个值,因为recall最后是1,所以 97 | mpre = np.concatenate(([0.], prec, [0.])) # 右边加了1,precision加的是0 98 | 99 | # compute the precision envelope 100 | for i in range(mpre.size - 1, 0, -1): 101 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) #从后往前,排除之前局部增加的precison情况 102 | 103 | # to calculate area under PR curve, look for points 104 | # where X axis (recall) changes value 105 | i = np.where(mrec[1:] != mrec[:-1])[0] # 这里巧妙的错位,返回刚好TP的位置, 106 | # 可以看后面辅助的例子 107 | 108 | # and sum (\Delta recall) * prec 用recall的间隔对精度作加权平均 109 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 110 | return ap 111 | 112 | 113 | # 计算每个类别对应的AP,mAP是所有类别AP的平均值 114 | def vocEval(result_json_path,classname,use_07_metric=False): 115 | 116 | with open(result_json_path,'r') as f: 117 | result_json = json.loads(f.readlines()[0]) 118 | 119 | result_json = sorted(result_json, key=lambda x:x['score'], reverse=True) 120 | #print(result_json[:10]) 121 | 122 | count = len(result_json) 123 | tp = np.zeros(count) # 用于标记每个检测结果是tp还是fp 124 | fp = np.zeros(count) 125 | npos = 0 126 | 127 | for i,item in enumerate(result_json): 128 | #print(item) 129 | 130 | if classname in item['path']: 131 | npos += 1 132 | 133 | if item['category'] == classname: 134 | if classname in item['path']: 135 | tp[i] = 1 136 | else: 137 | fp[i] = 1 138 | 139 | 140 | # compute precision recall 141 | fp = np.cumsum(fp) # 累加函数np.cumsum([1, 2, 3, 4]) -> [1, 3, 6, 10] 142 | tp = np.cumsum(tp) 143 | rec = tp / float(npos) 144 | # avoid divide by zero in case the first detection matches a difficult 145 | # ground truth 146 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 147 | ap = vocAP(rec, prec, use_07_metric) 148 | 149 | return rec, prec, ap 150 | 151 | # 计算每个类别对应的AP,mAP是所有类别AP的平均值 152 | def vocTest(result_json_path,classname, label_json_path): 153 | 154 | with open(result_json_path,'r') as f: 155 | result_json = json.loads(f.readlines()[0]) 156 | 157 | result_json = sorted(result_json, key=lambda x:x['score'], reverse=True) 158 | #print(result_json[:10]) 159 | 160 | with open(label_json_path,'r') as f: 161 | label_json = json.loads(f.readlines()[0]) 162 | label_imgs = label_json[classname]#testA_v3_clean 163 | npos = len(label_imgs) 164 | print("len label:", npos) 165 | 166 | 167 | count = len(result_json) 168 | tp = np.zeros(count) # 用于标记每个检测结果是tp还是fp 169 | fp = np.zeros(count) 170 | 171 | 172 | for i,item in enumerate(result_json): 173 | #print(item) 174 | 175 | if item['category'] == classname: 176 | if os.path.basename(item['image_name']) in label_imgs: 177 | tp[i] = 1 178 | else: 179 | fp[i] = 1 180 | 181 | 182 | # compute precision recall 183 | fp = np.cumsum(fp) # 累加函数np.cumsum([1, 2, 3, 4]) -> [1, 3, 6, 10] 184 | tp = np.cumsum(tp) 185 | rec = tp / float(npos) 186 | # avoid divide by zero in case the first detection matches a difficult 187 | # ground truth 188 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 189 | ap = vocAP(rec, prec) 190 | 191 | return rec, prec, ap 192 | 193 | 194 | 195 | def getTestmAP(result_json_path, classname_list, label_json_path): 196 | AP_list = [] 197 | for classname in classname_list: 198 
| 199 | rec, prec, ap = vocTest(result_json_path, classname, label_json_path) 200 | print("AP %s: %f" % (classname,ap)) 201 | AP_list.append(ap) 202 | return np.mean(AP_list) 203 | 204 | 205 | 206 | # 计算每个类别对应的AP,mAP是所有类别AP的平均值 207 | def vocOnline(pres, labels, cate_id): 208 | # print(pres, labels, cate_id) 209 | # b 210 | 211 | count = len(pres) 212 | tp = np.zeros(count) # 用于标记每个检测结果是tp还是fp 213 | fp = np.zeros(count) 214 | npos = 0 215 | 216 | for i,item in enumerate(pres): 217 | #print(item) 218 | if labels[i]==cate_id: 219 | npos += 1 220 | 221 | if item>0.33: 222 | if labels[i]==cate_id: 223 | tp[i] = 1 224 | else: 225 | fp[i] = 1 226 | 227 | #print(npos) 228 | # compute precision recall 229 | fp = np.cumsum(fp) # 累加函数np.cumsum([1, 2, 3, 4]) -> [1, 3, 6, 10] 230 | tp = np.cumsum(tp) 231 | rec = tp / (float(npos)+0.000001) 232 | # avoid divide by zero in case the first detection matches a difficult 233 | # ground truth 234 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 235 | ap = vocAP(rec, prec) 236 | 237 | return rec, prec, ap 238 | 239 | def getValmAP(pres, labels): 240 | 241 | class_name = ['calling', 'normal', 'smoking','smoking_calling'] 242 | 243 | AP_list = [] 244 | print() 245 | for idx in range(len(class_name)): 246 | rec, prec, ap = vocOnline(pres[:,idx], labels, idx) 247 | #print(class_name[idx], ap) 248 | AP_list.append(ap) 249 | return np.mean(AP_list) 250 | -------------------------------------------------------------------------------- /fire/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import pretrainedmodels 5 | 6 | from fire.models.mobilenetv3 import MobileNetV3 7 | 8 | from fire.models.myefficientnet_pytorch import EfficientNet 9 | from fire.models.convnext import convnext_tiny,convnext_small,convnext_base,convnext_large 10 | from fire.models.swin import build_model,get_config 11 | 12 | import timm 13 | import torchvision 14 | 15 | class FireModel(nn.Module): 16 | def __init__(self, cfg): 17 | super(FireModel, self).__init__() 18 | 19 | self.cfg = cfg 20 | 21 | 22 | self.pretrainedModel() 23 | 24 | self.changeModelStructure() 25 | 26 | 27 | 28 | def pretrainedModel(self): 29 | 30 | 31 | ### Create model 32 | if "efficientnetv2" in self.cfg['model_name']: 33 | #model = EfficientNet.from_name(model_name) 34 | if "v2-s" in self.cfg['model_name']: 35 | self.pretrain_model = timm.create_model('tf_efficientnetv2_s.in21k_ft_in1k', pretrained=False) 36 | elif "v2-b0" in self.cfg['model_name']: 37 | self.pretrain_model = timm.create_model('tf_efficientnetv2_b0.in1k', pretrained=False) 38 | 39 | if self.cfg['pretrained']: 40 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained']),strict=True) 41 | 42 | 43 | elif "eca_nfnet_l0" in self.cfg['model_name']: 44 | self.pretrain_model = timm.create_model('eca_nfnet_l0.ra2_in1k', pretrained=False) 45 | if self.cfg['pretrained']: 46 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained']),strict=True) 47 | 48 | 49 | elif "convnextv2" in self.cfg['model_name']: 50 | 51 | if "tiny" in self.cfg['model_name']: 52 | self.pretrain_model = timm.create_model('convnextv2_tiny.fcmae_ft_in22k_in1k_384', pretrained=True) 53 | # if self.cfg['pretrained']: 54 | # self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained']),strict=True) 55 | 56 | elif "resnest" in self.cfg['model_name']: 57 | if "50d" in self.cfg['model_name']: 58 | self.pretrain_model = timm.create_model('resnest50d', 59 | 
pretrained=False) 60 | 61 | 62 | elif self.cfg['model_name']=="mobilenetv2": 63 | #model.cpu() 64 | self.pretrain_model = torchvision.models.mobilenet_v2(pretrained=False, progress=True, width_mult=1.0) 65 | 66 | if self.cfg['pretrained']: 67 | state_dict = torch.load(self.cfg['pretrained']) 68 | self.pretrain_model.load_state_dict(state_dict, strict=True) 69 | 70 | 71 | elif self.cfg['model_name']=="mobilenetv3": 72 | #model.cpu() 73 | self.pretrain_model = MobileNetV3() 74 | if self.cfg['pretrained']: 75 | state_dict = torch.load(self.cfg['pretrained']) 76 | self.pretrain_model.load_state_dict(state_dict, strict=True) 77 | 78 | 79 | elif "shufflenetv2" in self.cfg['model_name']: 80 | self.pretrain_model = torchvision.models.shufflenet_v2_x1_0() 81 | if self.cfg['pretrained']: 82 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained']),strict=True) 83 | 84 | 85 | elif "efficientnet" in self.cfg['model_name']: 86 | #model = EfficientNet.from_name(model_name) 87 | self.pretrain_model = EfficientNet.from_name(self.cfg['model_name'].replace('adv-','')) 88 | if self.cfg['pretrained']: 89 | ckpt = torch.load(self.cfg['pretrained']) 90 | # del ckpt["_fc.weight"] 91 | # del ckpt["_fc.bias"] 92 | self.pretrain_model.load_state_dict(ckpt,strict=True) 93 | 94 | 95 | elif 'resnet' in self.cfg['model_name'] or \ 96 | 'resnext' in self.cfg['model_name'] or \ 97 | 'xception' in self.cfg['model_name']: 98 | #model_name = 'resnext50' # se_resnext50_32x4d xception 99 | self.pretrain_model = pretrainedmodels.__dict__[self.cfg['model_name']](num_classes=1000, pretrained=None) 100 | print(pretrainedmodels.pretrained_settings[self.cfg['model_name']]) 101 | 102 | if self.cfg['pretrained']: 103 | if self.cfg['model_name']=="resnet50": 104 | #model.cpu() 105 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained']),strict=False) 106 | #fc_features = self.pretrain_model.last_linear.in_features 107 | elif self.cfg['model_name']=="xception": 108 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained']),strict=False) 109 | #fc_features = self.pretrain_model.last_linear.in_features 110 | elif self.cfg['model_name'] == "se_resnext50_32x4d": 111 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained']),strict=False) 112 | self.pretrain_model.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 113 | #fc_features = self.pretrain_model.last_linear.in_features 114 | elif self.cfg['model_name'] == "se_resnext101_32x4d": 115 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained']),strict=False) 116 | self.pretrain_model.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 117 | #fc_features = self.pretrain_model.last_linear.in_features 118 | elif self.cfg['model_name'] == "resnext101_32x8d_wsl": 119 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained']),strict=False) 120 | #fc_features = self.pretrain_model.fc.in_features 121 | elif self.cfg['model_name'] == "resnext101_32x16d_wsl": 122 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained']),strict=False) 123 | #fc_features = self.pretrain_model.fc.in_features 124 | else: 125 | raise Exception("[ERROR] Not load pretrained model!") 126 | 127 | elif "swin" in self.cfg['model_name']: 128 | if 'base' in self.cfg['model_name']: 129 | cfg = "fire/models/swin/configs/swin_base_patch4_window12_384_finetune.yaml" 130 | config = get_config(cfg) 131 | 132 | self.pretrain_model = build_model(config) 133 | 134 | if self.cfg['pretrained']: 135 | 
self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained'])['model'],strict=False) 136 | 137 | elif 'small' in self.cfg['model_name']: 138 | cfg = "fire/models/swin/configs/swin_small_patch4_window7_224.yaml" 139 | config = get_config(cfg) 140 | 141 | self.pretrain_model = build_model(config) 142 | 143 | if self.cfg['pretrained']: 144 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained'])['model'],strict=False) 145 | 146 | elif 'large' in self.cfg['model_name']: 147 | cfg = "fire/models/swin/configs/swin_large_patch4_window12_384_22kto1k_finetune.yaml" 148 | config = get_config(cfg) 149 | 150 | self.pretrain_model = build_model(config) 151 | 152 | if self.cfg['pretrained']: 153 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained'])['model'],strict=False) 154 | 155 | elif "convnext" in self.cfg['model_name']: 156 | if "base" in self.cfg['model_name']: 157 | self.pretrain_model = convnext_base() 158 | if self.cfg['pretrained']: 159 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained'])['model'],strict=False) 160 | # print(self.pretrain_model) 161 | # b 162 | elif "tiny" in self.cfg['model_name']: 163 | self.pretrain_model = convnext_tiny() 164 | if self.cfg['pretrained']: 165 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained'])['model'],strict=False) 166 | # print(self.pretrain_model) 167 | # b 168 | elif "small" in self.cfg['model_name']: 169 | self.pretrain_model = convnext_small() 170 | if self.cfg['pretrained']: 171 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained'])['model'],strict=False) 172 | elif "large" in self.cfg['model_name']: 173 | self.pretrain_model = convnext_large() 174 | if self.cfg['pretrained']: 175 | self.pretrain_model.load_state_dict(torch.load(self.cfg['pretrained'])['model'],strict=False) 176 | 177 | 178 | 179 | 180 | 181 | # [Add new model here] 182 | # elif self.cfg['model_name']=="xxx": 183 | # pass 184 | 185 | 186 | else: 187 | raise Exception("[ERROR] Unknown model_name: ",self.cfg['model_name']) 188 | 189 | 190 | def changeModelStructure(self): 191 | ### Change model 192 | if "efficientnetv2" in self.cfg['model_name']: 193 | self.backbone = nn.Sequential(*list(self.pretrain_model.children())[:-1]) 194 | num_features = self.pretrain_model.classifier.in_features 195 | self.head1 = nn.Linear(num_features,self.cfg['class_number']) 196 | 197 | elif "convnextv2" in self.cfg['model_name']: 198 | self.backbone = self.pretrain_model 199 | num_features = self.backbone.head.fc.in_features 200 | self.backbone.head.fc = nn.Linear(num_features,self.cfg['class_number']) 201 | #print(self.backbone) 202 | 203 | elif "eca_nfnet_l0" in self.cfg['model_name']: 204 | self.backbone = self.pretrain_model 205 | #print(self.backbone) 206 | num_features = self.backbone.head.fc.in_features 207 | self.backbone.head.fc = nn.Linear(num_features,self.cfg['class_number']) 208 | #bb 209 | 210 | elif "resnest" in self.cfg['model_name']: 211 | self.backbone = self.pretrain_model 212 | num_features = self.backbone.fc.in_features 213 | self.backbone.fc = nn.Linear(num_features,self.cfg['class_number']) 214 | 215 | elif 'mobilenetv2' in self.cfg['model_name']: 216 | 217 | in_features = self.pretrain_model.classifier[1].in_features 218 | self.features = self.pretrain_model.features 219 | 220 | self.head1 = nn.Sequential( 221 | nn.Dropout(p=0.2), # refer to paper section 6 222 | nn.Linear(in_features, self.cfg['class_number']), 223 | ) 224 | 225 | elif "mobilenetv3" in self.cfg['model_name']: 226 
| # self.backbone = self.pretrain_model 227 | self.backbone = nn.Sequential(*list(self.pretrain_model.children())[:-1]) 228 | # print(self.backbone) 229 | # b 230 | num_features = 1280 231 | self.head1 = nn.Sequential( 232 | # nn.Linear(num_features, 64), 233 | nn.Dropout(0.8), 234 | # nn.AdaptiveAvgPool2d(1), 235 | nn.Linear(num_features, self.cfg['class_number'])) 236 | 237 | elif "shufflenetv2" in self.cfg['model_name']: 238 | # self.backbone = self.pretrain_model 239 | self.backbone = nn.Sequential(*list(self.pretrain_model.children())[:-1]) 240 | # print(self.backbone) 241 | # b 242 | num_features = 1024 243 | # self.head1 = nn.Sequential( 244 | # # nn.Linear(num_features, 64), 245 | # nn.Dropout(0.8), 246 | # # nn.AdaptiveAvgPool2d(1), 247 | # nn.Linear(num_features,4)) 248 | self.avgpool = nn.AdaptiveAvgPool2d(1) 249 | self.head1 = nn.Linear(num_features,self.cfg['class_number']) 250 | 251 | 252 | elif "efficientnet" in self.cfg['model_name']: 253 | #self.pretrain_model._dropout = nn.Dropout(0.5) 254 | self.backbone = self.pretrain_model 255 | num_features = self.backbone._bn1.num_features 256 | self.head1 = nn.Linear(num_features,self.cfg['class_number']) 257 | 258 | 259 | elif "convnext" in self.cfg['model_name']: 260 | 261 | self.backbone = self.pretrain_model 262 | #print(self.backbone) 263 | num_features = 1024 264 | if "large" in self.cfg['model_name']: 265 | num_features = 1536 266 | elif "tiny" in self.cfg['model_name']: 267 | num_features = 768 268 | elif "small" in self.cfg['model_name']: 269 | num_features = 768 270 | 271 | self.head1 = nn.Sequential( 272 | # nn.Dropout(0.5), 273 | nn.Linear(num_features,self.cfg['class_number'])) 274 | 275 | 276 | elif "swin" in self.cfg['model_name']: 277 | self.backbone = self.pretrain_model 278 | #print(self.backbone) 279 | num_features = self.backbone.norm.weight.size()[0] 280 | 281 | self.head1 = nn.Sequential( 282 | # nn.Dropout(0.5), 283 | nn.Linear(num_features,self.cfg['class_number'])) 284 | 285 | 286 | elif 'resnet' in self.cfg['model_name'] or \ 287 | 'resnext' in self.cfg['model_name'] or \ 288 | 'xception' in self.cfg['model_name']: 289 | #self.avgpool = nn.AdaptiveAvgPool2d(1) 290 | #print(self.pretrain_model) 291 | fc_features = self.pretrain_model.last_linear.in_features 292 | 293 | self.pretrain_model = nn.Sequential(*list(self.pretrain_model.children())[:-2]) 294 | # self.dp_linear = nn.Linear(fc_features, 8) 295 | # self.dp = nn.Dropout(0.50) 296 | self.avgpool = nn.AdaptiveAvgPool2d(1) 297 | self.head1 = nn.Linear(fc_features, self.cfg['class_number']) 298 | 299 | else: 300 | raise Exception("[ERROR] Unknown model_name: ",self.cfg['model_name']) 301 | 302 | 303 | def forward(self, img): 304 | 305 | if self.cfg['model_name'] in ['mobilenetv2']: 306 | 307 | out = self.features(img) 308 | 309 | out = nn.functional.adaptive_avg_pool2d(out, 1).reshape(out.shape[0], -1) 310 | #nn.AdaptiveAvgPool2d(1) 311 | out1 = self.head1(out) 312 | out = [out1] 313 | 314 | elif "shuffle" in self.cfg['model_name']: 315 | out = self.backbone(img) 316 | out = self.avgpool(out) 317 | out = out.view(out.size(0), -1) 318 | out1 = self.head1(out) 319 | 320 | out = [out1] 321 | 322 | elif "mobilenetv3" in self.cfg['model_name']: 323 | out = self.backbone(img) 324 | # out = self.avgpool(out) 325 | # out = out.view(out.size(0), -1) 326 | out = out.mean(3).mean(2) 327 | out1 = self.head1(out) 328 | 329 | out = [out1] 330 | 331 | 332 | elif "efficientnet" in self.cfg['model_name']: 333 | out = self.backbone(img) 334 | out = out.view(out.size(0), 
-1) 335 | out1 = self.head1(out) 336 | 337 | out = [out1] 338 | 339 | elif "swin" in self.cfg['model_name']: 340 | out = self.backbone(img) 341 | out = out.view(out.size(0), -1) 342 | out1 = self.head1(out) 343 | 344 | out = [out1] 345 | 346 | elif "resnest" in self.cfg['model_name']: 347 | out1 = self.backbone(img) 348 | #print(out1.shape) 349 | out = [out1] 350 | 351 | 352 | elif "convnextv2" in self.cfg['model_name'] or "eca_nfnet_l0" in self.cfg['model_name']: 353 | out1 = self.backbone(img) 354 | #print(out1.shape) 355 | out = [out1] 356 | 357 | elif "convnext" in self.cfg['model_name']: 358 | 359 | out = self.backbone(img) 360 | out = out.view(out.size(0), -1) 361 | #print(out.shape) 362 | out1 = self.head1(out) 363 | 364 | out = [out1] 365 | 366 | elif 'resnet' in self.cfg['model_name'] or \ 367 | 'resnext' in self.cfg['model_name'] or \ 368 | 'xception' in self.cfg['model_name']: 369 | out = self.pretrain_model(img) 370 | out = self.avgpool(out) 371 | out = out.view(out.size(0), -1) 372 | out1 = self.head1(out) 373 | out = [out1] 374 | # [Add new model here] 375 | # elif self.cfg['model_name']=="xxx": 376 | # pass 377 | 378 | else: 379 | raise Exception("[ERROR] Unknown model_name: ",self.cfg['model_name']) 380 | 381 | return out 382 | 383 | 384 | -------------------------------------------------------------------------------- /fire/models/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.models.layers import trunc_normal_, DropPath 13 | from timm.models.registry import register_model 14 | 15 | class Block(nn.Module): 16 | r""" ConvNeXt Block. There are two equivalent implementations: 17 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 18 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 19 | We use (2) as we find it slightly faster in PyTorch 20 | 21 | Args: 22 | dim (int): Number of input channels. 23 | drop_path (float): Stochastic depth rate. Default: 0.0 24 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 25 | """ 26 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 27 | super().__init__() 28 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 29 | self.norm = LayerNorm(dim, eps=1e-6) 30 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 31 | self.act = nn.GELU() 32 | self.pwconv2 = nn.Linear(4 * dim, dim) 33 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 34 | requires_grad=True) if layer_scale_init_value > 0 else None 35 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 36 | 37 | def forward(self, x): 38 | input = x 39 | x = self.dwconv(x) 40 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 41 | x = self.norm(x) 42 | x = self.pwconv1(x) 43 | x = self.act(x) 44 | x = self.pwconv2(x) 45 | if self.gamma is not None: 46 | x = self.gamma * x 47 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 48 | 49 | x = input + self.drop_path(x) 50 | return x 51 | 52 | class ConvNeXt(nn.Module): 53 | r""" ConvNeXt 54 | A PyTorch impl of : `A ConvNet for the 2020s` - 55 | https://arxiv.org/pdf/2201.03545.pdf 56 | 57 | Args: 58 | in_chans (int): Number of input image channels. Default: 3 59 | num_classes (int): Number of classes for classification head. Default: 1000 60 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 61 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 62 | drop_path_rate (float): Stochastic depth rate. Default: 0. 63 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 64 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 65 | """ 66 | def __init__(self, in_chans=3, num_classes=1000, 67 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 68 | layer_scale_init_value=1e-6, head_init_scale=1., 69 | ): 70 | super().__init__() 71 | 72 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 73 | stem = nn.Sequential( 74 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 75 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 76 | ) 77 | self.downsample_layers.append(stem) 78 | for i in range(3): 79 | downsample_layer = nn.Sequential( 80 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 81 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 82 | ) 83 | self.downsample_layers.append(downsample_layer) 84 | 85 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 86 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 87 | cur = 0 88 | for i in range(4): 89 | stage = nn.Sequential( 90 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 91 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 92 | ) 93 | self.stages.append(stage) 94 | cur += depths[i] 95 | 96 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 97 | # self.head = nn.Linear(dims[-1], num_classes) 98 | 99 | self.apply(self._init_weights) 100 | # self.head.weight.data.mul_(head_init_scale) 101 | # self.head.bias.data.mul_(head_init_scale) 102 | 103 | def _init_weights(self, m): 104 | if isinstance(m, (nn.Conv2d, nn.Linear)): 105 | trunc_normal_(m.weight, std=.02) 106 | nn.init.constant_(m.bias, 0) 107 | 108 | def forward_features(self, x): 109 | for i in range(4): 110 | x = self.downsample_layers[i](x) 111 | x = self.stages[i](x) 112 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 113 | 114 | def forward(self, x): 115 | x = self.forward_features(x) 116 | # x = self.head(x) 117 | return x 118 | 119 | class LayerNorm(nn.Module): 120 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 121 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 122 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 123 | with shape (batch_size, channels, height, width). 
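For example, the stem and downsampling layers above apply LayerNorm with data_format="channels_first" directly to (N, C, H, W) feature maps, while each Block permutes to (N, H, W, C) and uses the default channels_last.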
124 | """ 125 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 126 | super().__init__() 127 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 128 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 129 | self.eps = eps 130 | self.data_format = data_format 131 | if self.data_format not in ["channels_last", "channels_first"]: 132 | raise NotImplementedError 133 | self.normalized_shape = (normalized_shape, ) 134 | 135 | def forward(self, x): 136 | if self.data_format == "channels_last": 137 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 138 | elif self.data_format == "channels_first": 139 | u = x.mean(1, keepdim=True) 140 | s = (x - u).pow(2).mean(1, keepdim=True) 141 | x = (x - u) / torch.sqrt(s + self.eps) 142 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 143 | return x 144 | 145 | 146 | model_urls = { 147 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 148 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 149 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 150 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 151 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 152 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 153 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 154 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 155 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 156 | } 157 | 158 | @register_model 159 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs): 160 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 161 | if pretrained: 162 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] 163 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 164 | model.load_state_dict(checkpoint["model"]) 165 | return model 166 | 167 | @register_model 168 | def convnext_small(pretrained=False,in_22k=False, **kwargs): 169 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 170 | if pretrained: 171 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] 172 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 173 | model.load_state_dict(checkpoint["model"]) 174 | return model 175 | 176 | @register_model 177 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 178 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 179 | if pretrained: 180 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 181 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 182 | model.load_state_dict(checkpoint["model"]) 183 | return model 184 | 185 | @register_model 186 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 187 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 188 | if pretrained: 189 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 190 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 191 | 
model.load_state_dict(checkpoint["model"]) 192 | return model 193 | 194 | @register_model 195 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 196 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 197 | if pretrained: 198 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" 199 | url = model_urls['convnext_xlarge_22k'] 200 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 201 | model.load_state_dict(checkpoint["model"]) 202 | return model 203 | -------------------------------------------------------------------------------- /fire/models/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | __all__ = ['MobileNetV3', 'mobilenetv3'] 7 | 8 | 9 | def conv_bn(inp, oup, stride, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU): 10 | return nn.Sequential( 11 | conv_layer(inp, oup, 3, stride, 1, bias=False), 12 | norm_layer(oup), 13 | nlin_layer(inplace=True) 14 | ) 15 | 16 | 17 | def conv_1x1_bn(inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU): 18 | return nn.Sequential( 19 | conv_layer(inp, oup, 1, 1, 0, bias=False), 20 | norm_layer(oup), 21 | nlin_layer(inplace=True) 22 | ) 23 | 24 | 25 | class Hswish(nn.Module): 26 | def __init__(self, inplace=True): 27 | super(Hswish, self).__init__() 28 | self.inplace = inplace 29 | 30 | def forward(self, x): 31 | return x * F.relu6(x + 3., inplace=self.inplace) / 6. 32 | 33 | 34 | class Hsigmoid(nn.Module): 35 | def __init__(self, inplace=True): 36 | super(Hsigmoid, self).__init__() 37 | self.inplace = inplace 38 | 39 | def forward(self, x): 40 | return F.relu6(x + 3., inplace=self.inplace) / 6. 41 | 42 | 43 | class SEModule(nn.Module): 44 | def __init__(self, channel, reduction=4): 45 | super(SEModule, self).__init__() 46 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 47 | self.fc = nn.Sequential( 48 | nn.Linear(channel, channel // reduction, bias=False), 49 | nn.ReLU(inplace=True), 50 | nn.Linear(channel // reduction, channel, bias=False), 51 | Hsigmoid() 52 | # nn.Sigmoid() 53 | ) 54 | 55 | def forward(self, x): 56 | b, c, _, _ = x.size() 57 | y = self.avg_pool(x).view(b, c) 58 | y = self.fc(y).view(b, c, 1, 1) 59 | return x * y.expand_as(x) 60 | 61 | 62 | class Identity(nn.Module): 63 | def __init__(self, channel): 64 | super(Identity, self).__init__() 65 | 66 | def forward(self, x): 67 | return x 68 | 69 | 70 | def make_divisible(x, divisible_by=8): 71 | import numpy as np 72 | return int(np.ceil(x * 1. 
/ divisible_by) * divisible_by) 73 | 74 | 75 | class MobileBottleneck(nn.Module): 76 | def __init__(self, inp, oup, kernel, stride, exp, se=False, nl='RE'): 77 | super(MobileBottleneck, self).__init__() 78 | assert stride in [1, 2] 79 | assert kernel in [3, 5] 80 | padding = (kernel - 1) // 2 81 | self.use_res_connect = stride == 1 and inp == oup 82 | 83 | conv_layer = nn.Conv2d 84 | norm_layer = nn.BatchNorm2d 85 | if nl == 'RE': 86 | nlin_layer = nn.ReLU # or ReLU6 87 | elif nl == 'HS': 88 | nlin_layer = Hswish 89 | else: 90 | raise NotImplementedError 91 | if se: 92 | SELayer = SEModule 93 | else: 94 | SELayer = Identity 95 | 96 | self.conv = nn.Sequential( 97 | # pw 98 | conv_layer(inp, exp, 1, 1, 0, bias=False), 99 | norm_layer(exp), 100 | nlin_layer(inplace=True), 101 | # dw 102 | conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False), 103 | norm_layer(exp), 104 | SELayer(exp), 105 | nlin_layer(inplace=True), 106 | # pw-linear 107 | conv_layer(exp, oup, 1, 1, 0, bias=False), 108 | norm_layer(oup), 109 | ) 110 | 111 | def forward(self, x): 112 | if self.use_res_connect: 113 | return x + self.conv(x) 114 | else: 115 | return self.conv(x) 116 | 117 | 118 | class MobileNetV3(nn.Module): 119 | def __init__(self, n_class=1000, input_size=224, dropout=0.8, mode='small', width_mult=1.0): 120 | super(MobileNetV3, self).__init__() 121 | input_channel = 16 122 | last_channel = 1280 123 | if mode == 'large': 124 | # refer to Table 1 in paper 125 | mobile_setting = [ 126 | # k, exp, c, se, nl, s, 127 | [3, 16, 16, False, 'RE', 1], 128 | [3, 64, 24, False, 'RE', 2], 129 | [3, 72, 24, False, 'RE', 1], 130 | [5, 72, 40, True, 'RE', 2], 131 | [5, 120, 40, True, 'RE', 1], 132 | [5, 120, 40, True, 'RE', 1], 133 | [3, 240, 80, False, 'HS', 2], 134 | [3, 200, 80, False, 'HS', 1], 135 | [3, 184, 80, False, 'HS', 1], 136 | [3, 184, 80, False, 'HS', 1], 137 | [3, 480, 112, True, 'HS', 1], 138 | [3, 672, 112, True, 'HS', 1], 139 | [5, 672, 160, True, 'HS', 2], 140 | [5, 960, 160, True, 'HS', 1], 141 | [5, 960, 160, True, 'HS', 1], 142 | ] 143 | elif mode == 'small': 144 | # refer to Table 2 in paper 145 | mobile_setting = [ 146 | # k, exp, c, se, nl, s, 147 | [3, 16, 16, True, 'RE', 2], 148 | [3, 72, 24, False, 'RE', 2], 149 | [3, 88, 24, False, 'RE', 1], 150 | [5, 96, 40, True, 'HS', 2], 151 | [5, 240, 40, True, 'HS', 1], 152 | [5, 240, 40, True, 'HS', 1], 153 | [5, 120, 48, True, 'HS', 1], 154 | [5, 144, 48, True, 'HS', 1], 155 | [5, 288, 96, True, 'HS', 2], 156 | [5, 576, 96, True, 'HS', 1], 157 | [5, 576, 96, True, 'HS', 1], 158 | ] 159 | else: 160 | raise NotImplementedError 161 | 162 | # building first layer 163 | assert input_size % 32 == 0 164 | last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel 165 | self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)] 166 | self.classifier = [] 167 | 168 | # building mobile blocks 169 | for k, exp, c, se, nl, s in mobile_setting: 170 | output_channel = make_divisible(c * width_mult) 171 | exp_channel = make_divisible(exp * width_mult) 172 | self.features.append(MobileBottleneck(input_channel, output_channel, k, s, exp_channel, se, nl)) 173 | input_channel = output_channel 174 | 175 | # building last several layers 176 | if mode == 'large': 177 | last_conv = make_divisible(960 * width_mult) 178 | self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)) 179 | self.features.append(nn.AdaptiveAvgPool2d(1)) 180 | self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) 
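# note: the 1x1 conv + h-swish after global pooling act as the first fully-connected layer of the MobileNetV3 head (the 'small' branch below repeats the same pattern)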
181 | self.features.append(Hswish(inplace=True)) 182 | elif mode == 'small': 183 | last_conv = make_divisible(576 * width_mult) 184 | self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)) 185 | # self.features.append(SEModule(last_conv)) # refer to paper Table2, but I think this is a mistake 186 | self.features.append(nn.AdaptiveAvgPool2d(1)) 187 | self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) 188 | self.features.append(Hswish(inplace=True)) 189 | else: 190 | raise NotImplementedError 191 | 192 | # make it nn.Sequential 193 | self.features = nn.Sequential(*self.features) 194 | 195 | # building classifier 196 | self.classifier = nn.Sequential( 197 | nn.Dropout(p=dropout), # refer to paper section 6 198 | nn.Linear(last_channel, n_class), 199 | ) 200 | 201 | self._initialize_weights() 202 | 203 | def forward(self, x): 204 | x = self.features(x) 205 | x = x.mean(3).mean(2) 206 | x = self.classifier(x) 207 | return x 208 | 209 | def _initialize_weights(self): 210 | # weight initialization 211 | for m in self.modules(): 212 | if isinstance(m, nn.Conv2d): 213 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 214 | if m.bias is not None: 215 | nn.init.zeros_(m.bias) 216 | elif isinstance(m, nn.BatchNorm2d): 217 | nn.init.ones_(m.weight) 218 | nn.init.zeros_(m.bias) 219 | elif isinstance(m, nn.Linear): 220 | nn.init.normal_(m.weight, 0, 0.01) 221 | if m.bias is not None: 222 | nn.init.zeros_(m.bias) 223 | 224 | 225 | def mobilenetv3(pretrained=False, **kwargs): 226 | model = MobileNetV3(**kwargs) 227 | if pretrained: 228 | state_dict = torch.load('mobilenetv3_small_67.4.pth.tar') 229 | model.load_state_dict(state_dict, strict=True) 230 | # raise NotImplementedError 231 | return model 232 | 233 | 234 | if __name__ == '__main__': 235 | net = mobilenetv3() 236 | print('mobilenetv3:\n', net) 237 | print('Total params: %.2fM' % (sum(p.numel() for p in net.parameters())/1000000.0)) 238 | input_size=(1, 3, 224, 224) 239 | # pip install --upgrade git+https://github.com/kuan-wang/pytorch-OpCounter.git 240 | from thop import profile 241 | flops, params = profile(net, input_size=input_size) 242 | # print(flops) 243 | # print(params) 244 | print('Total params: %.2fM' % (params/1000000.0)) 245 | print('Total flops: %.2fM' % (flops/1000000.0)) 246 | x = torch.randn(input_size) 247 | out = net(x) 248 | 249 | 250 | 251 | -------------------------------------------------------------------------------- /fire/models/myefficientnet_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.7.1" 2 | from .model import EfficientNet, VALID_MODELS 3 | from .utils import ( 4 | GlobalParams, 5 | BlockArgs, 6 | BlockDecoder, 7 | efficientnet, 8 | get_model_params, 9 | ) 10 | -------------------------------------------------------------------------------- /fire/models/myefficientnet_pytorch/model.py: -------------------------------------------------------------------------------- 1 | """model.py - Model and module class for EfficientNet. 2 | They are built to mirror those in the official TensorFlow implementation. 3 | """ 4 | 5 | # Author: lukemelas (github username) 6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch 7 | # With adjustments and added comments by workingcoder (github username). 
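# Note on this copy: forward() near the end of this file returns pooled features only (the dropout/fc top is commented out), so the final classification layer is presumably supplied by the surrounding framework.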
8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from .utils import ( 13 | round_filters, 14 | round_repeats, 15 | drop_connect, 16 | get_same_padding_conv2d, 17 | get_model_params, 18 | efficientnet_params, 19 | load_pretrained_weights, 20 | Swish, 21 | MemoryEfficientSwish, 22 | calculate_output_image_size 23 | ) 24 | 25 | 26 | VALID_MODELS = ( 27 | 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', 28 | 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7', 29 | 'efficientnet-b8', 30 | 31 | # Support the construction of 'efficientnet-l2' without pretrained weights 32 | 'efficientnet-l2' 33 | ) 34 | 35 | 36 | class MBConvBlock(nn.Module): 37 | """Mobile Inverted Residual Bottleneck Block. 38 | 39 | Args: 40 | block_args (namedtuple): BlockArgs, defined in utils.py. 41 | global_params (namedtuple): GlobalParam, defined in utils.py. 42 | image_size (tuple or list): [image_height, image_width]. 43 | 44 | References: 45 | [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) 46 | [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) 47 | [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) 48 | """ 49 | 50 | def __init__(self, block_args, global_params, image_size=None): 51 | super().__init__() 52 | self._block_args = block_args 53 | self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow 54 | self._bn_eps = global_params.batch_norm_epsilon 55 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 56 | self.id_skip = block_args.id_skip # whether to use skip connection and drop connect 57 | 58 | # Expansion phase (Inverted Bottleneck) 59 | inp = self._block_args.input_filters # number of input channels 60 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 61 | if self._block_args.expand_ratio != 1: 62 | Conv2d = get_same_padding_conv2d(image_size=image_size) 63 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 64 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 65 | # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size 66 | 67 | # Depthwise convolution phase 68 | k = self._block_args.kernel_size 69 | s = self._block_args.stride 70 | Conv2d = get_same_padding_conv2d(image_size=image_size) 71 | self._depthwise_conv = Conv2d( 72 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 73 | kernel_size=k, stride=s, bias=False) 74 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 75 | image_size = calculate_output_image_size(image_size, s) 76 | 77 | # Squeeze and Excitation layer, if desired 78 | if self.has_se: 79 | Conv2d = get_same_padding_conv2d(image_size=(1, 1)) 80 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 81 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 82 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) 83 | 84 | # Pointwise convolution phase 85 | final_oup = self._block_args.output_filters 86 | Conv2d = get_same_padding_conv2d(image_size=image_size) 87 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) 88 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) 89 | 
self._swish = MemoryEfficientSwish() 90 | 91 | def forward(self, inputs, drop_connect_rate=None): 92 | """MBConvBlock's forward function. 93 | 94 | Args: 95 | inputs (tensor): Input tensor. 96 | drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). 97 | 98 | Returns: 99 | Output of this block after processing. 100 | """ 101 | 102 | # Expansion and Depthwise Convolution 103 | x = inputs 104 | if self._block_args.expand_ratio != 1: 105 | x = self._expand_conv(inputs) 106 | x = self._bn0(x) 107 | x = self._swish(x) 108 | 109 | x = self._depthwise_conv(x) 110 | x = self._bn1(x) 111 | x = self._swish(x) 112 | 113 | # Squeeze and Excitation 114 | if self.has_se: 115 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 116 | x_squeezed = self._se_reduce(x_squeezed) 117 | x_squeezed = self._swish(x_squeezed) 118 | x_squeezed = self._se_expand(x_squeezed) 119 | x = torch.sigmoid(x_squeezed) * x 120 | 121 | # Pointwise Convolution 122 | x = self._project_conv(x) 123 | x = self._bn2(x) 124 | 125 | # Skip connection and drop connect 126 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 127 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 128 | # The combination of skip connection and drop connect brings about stochastic depth. 129 | if drop_connect_rate: 130 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 131 | x = x + inputs # skip connection 132 | return x 133 | 134 | def set_swish(self, memory_efficient=True): 135 | """Sets swish function as memory efficient (for training) or standard (for export). 136 | 137 | Args: 138 | memory_efficient (bool): Whether to use memory-efficient version of swish. 139 | """ 140 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 141 | 142 | 143 | class EfficientNet(nn.Module): 144 | """EfficientNet model. 145 | Most easily loaded with the .from_name or .from_pretrained methods. 146 | 147 | Args: 148 | blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. 149 | global_params (namedtuple): A set of GlobalParams shared between blocks. 
150 | 151 | References: 152 | [1] https://arxiv.org/abs/1905.11946 (EfficientNet) 153 | 154 | Example: 155 | >>> import torch 156 | >>> from efficientnet.model import EfficientNet 157 | >>> inputs = torch.rand(1, 3, 224, 224) 158 | >>> model = EfficientNet.from_pretrained('efficientnet-b0') 159 | >>> model.eval() 160 | >>> outputs = model(inputs) 161 | """ 162 | 163 | def __init__(self, blocks_args=None, global_params=None): 164 | super().__init__() 165 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 166 | assert len(blocks_args) > 0, 'block args must be greater than 0' 167 | self._global_params = global_params 168 | self._blocks_args = blocks_args 169 | 170 | # Batch norm parameters 171 | bn_mom = 1 - self._global_params.batch_norm_momentum 172 | bn_eps = self._global_params.batch_norm_epsilon 173 | 174 | # Get stem static or dynamic convolution depending on image size 175 | image_size = global_params.image_size 176 | Conv2d = get_same_padding_conv2d(image_size=image_size) 177 | 178 | # Stem 179 | in_channels = 3 # rgb 180 | out_channels = round_filters(32, self._global_params) # number of output channels 181 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 182 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 183 | image_size = calculate_output_image_size(image_size, 2) 184 | 185 | # Build blocks 186 | self._blocks = nn.ModuleList([]) 187 | for block_args in self._blocks_args: 188 | 189 | # Update block input and output filters based on depth multiplier. 190 | block_args = block_args._replace( 191 | input_filters=round_filters(block_args.input_filters, self._global_params), 192 | output_filters=round_filters(block_args.output_filters, self._global_params), 193 | num_repeat=round_repeats(block_args.num_repeat, self._global_params) 194 | ) 195 | 196 | # The first block needs to take care of stride and filter size increase. 
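# The remaining repeats (appended in the inner loop below) reuse output_filters as their input filters and run with stride 1.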
197 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 198 | image_size = calculate_output_image_size(image_size, block_args.stride) 199 | if block_args.num_repeat > 1: # modify block_args to keep same output size 200 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) 201 | for _ in range(block_args.num_repeat - 1): 202 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 203 | # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 204 | 205 | # Head 206 | in_channels = block_args.output_filters # output of final block 207 | out_channels = round_filters(1280, self._global_params) 208 | Conv2d = get_same_padding_conv2d(image_size=image_size) 209 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 210 | # self._conv_head2 = Conv2d(80, 80, kernel_size=1, bias=False) 211 | # self._conv_head3 = Conv2d(192, 192, kernel_size=1, bias=False) 212 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 213 | # self._bn2 = nn.BatchNorm2d(num_features=80, momentum=bn_mom, eps=bn_eps) 214 | # self._bn3 = nn.BatchNorm2d(num_features=192, momentum=bn_mom, eps=bn_eps) 215 | 216 | # Final linear layer 217 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 218 | self._max_pooling = nn.AdaptiveMaxPool2d(1) 219 | if self._global_params.include_top: 220 | self._dropout = nn.Dropout(self._global_params.dropout_rate) 221 | self._fc = nn.Linear(out_channels, self._global_params.num_classes) 222 | 223 | # set activation to memory efficient swish by default 224 | self._swish = MemoryEfficientSwish() 225 | 226 | def set_swish(self, memory_efficient=True): 227 | """Sets swish function as memory efficient (for training) or standard (for export). 228 | 229 | Args: 230 | memory_efficient (bool): Whether to use memory-efficient version of swish. 231 | """ 232 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 233 | for block in self._blocks: 234 | block.set_swish(memory_efficient) 235 | 236 | def extract_endpoints(self, inputs): 237 | """Use convolution layer to extract features 238 | from reduction levels i in [1, 2, 3, 4, 5]. 239 | 240 | Args: 241 | inputs (tensor): Input tensor. 242 | 243 | Returns: 244 | Dictionary of last intermediate features 245 | with reduction levels i in [1, 2, 3, 4, 5]. 
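A final 'reduction_6' endpoint (the output of the conv head) is appended as well; see the example below.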
246 | Example: 247 | >>> import torch 248 | >>> from efficientnet.model import EfficientNet 249 | >>> inputs = torch.rand(1, 3, 224, 224) 250 | >>> model = EfficientNet.from_pretrained('efficientnet-b0') 251 | >>> endpoints = model.extract_endpoints(inputs) 252 | >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112]) 253 | >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56]) 254 | >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28]) 255 | >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14]) 256 | >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7]) 257 | >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7]) 258 | """ 259 | endpoints = dict() 260 | 261 | # Stem 262 | x = self._swish(self._bn0(self._conv_stem(inputs))) 263 | prev_x = x 264 | 265 | # Blocks 266 | for idx, block in enumerate(self._blocks): 267 | drop_connect_rate = self._global_params.drop_connect_rate 268 | if drop_connect_rate: 269 | drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate 270 | x = block(x, drop_connect_rate=drop_connect_rate) 271 | if prev_x.size(2) > x.size(2): 272 | endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x 273 | elif idx == len(self._blocks) - 1: 274 | endpoints['reduction_{}'.format(len(endpoints) + 1)] = x 275 | prev_x = x 276 | 277 | # Head 278 | x = self._swish(self._bn1(self._conv_head(x))) 279 | endpoints['reduction_{}'.format(len(endpoints) + 1)] = x 280 | 281 | return endpoints 282 | 283 | def extract_features(self, inputs): 284 | """use convolution layer to extract feature . 285 | 286 | Args: 287 | inputs (tensor): Input tensor. 288 | 289 | Returns: 290 | Output of the final convolution 291 | layer in the efficientnet model. 292 | """ 293 | # Stem 294 | x = self._swish(self._bn0(self._conv_stem(inputs))) 295 | 296 | 297 | # Blocks 298 | for idx, block in enumerate(self._blocks): 299 | drop_connect_rate = self._global_params.drop_connect_rate 300 | if drop_connect_rate: 301 | drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate 302 | x = block(x, drop_connect_rate=drop_connect_rate) 303 | 304 | # if idx==10: 305 | # x2 = x 306 | # if idx==20: 307 | # x3 = x 308 | 309 | # Head 310 | # print(len(self._blocks))#0:16 1:23 2: 311 | # print(x.shape,x2.shape,x3.shape) 312 | # b 313 | x = self._swish(self._bn1(self._conv_head(x))) 314 | # x2 = self._swish(self._bn2(self._conv_head2(x2))) 315 | # x3 = self._swish(self._bn3(self._conv_head3(x3))) 316 | return x#,x2,x3 317 | 318 | def forward(self, inputs): 319 | """EfficientNet's forward function. 320 | Calls extract_features to extract features, applies final linear layer, and returns logits. 321 | 322 | Args: 323 | inputs (tensor): Input tensor. 324 | 325 | Returns: 326 | Output of this model after processing. 327 | """ 328 | # Convolution layers 329 | x = self.extract_features(inputs)#[64, 1280, 7, 7] 330 | # x,x2,x3 = self.extract_features(inputs) 331 | # Pooling and final linear layer 332 | x = self._max_pooling(x) 333 | # x2 = self._avg_pooling(x2) 334 | # x3 = self._max_pooling(x3) 335 | # print(x.shape,x2.shape,x3.shape) 336 | # b 337 | # print(x.shape) 338 | # b 339 | # if self._global_params.include_top: 340 | # x = x.flatten(start_dim=1) 341 | # x = self._dropout(x) 342 | # x = self._fc(x) 343 | return x#,x2,x3 344 | 345 | @classmethod 346 | def from_name(cls, model_name, in_channels=3, **override_params): 347 | """Create an efficientnet model according to name. 
348 | 349 | Args: 350 | model_name (str): Name for efficientnet. 351 | in_channels (int): Input data's channel number. 352 | override_params (other key word params): 353 | Params to override model's global_params. 354 | Optional key: 355 | 'width_coefficient', 'depth_coefficient', 356 | 'image_size', 'dropout_rate', 357 | 'num_classes', 'batch_norm_momentum', 358 | 'batch_norm_epsilon', 'drop_connect_rate', 359 | 'depth_divisor', 'min_depth' 360 | 361 | Returns: 362 | An efficientnet model. 363 | """ 364 | cls._check_model_name_is_valid(model_name) 365 | blocks_args, global_params = get_model_params(model_name, override_params) 366 | model = cls(blocks_args, global_params) 367 | model._change_in_channels(in_channels) 368 | return model 369 | 370 | @classmethod 371 | def from_pretrained(cls, model_name, weights_path=None, advprop=False, 372 | in_channels=3, num_classes=1000, **override_params): 373 | """Create an efficientnet model according to name. 374 | 375 | Args: 376 | model_name (str): Name for efficientnet. 377 | weights_path (None or str): 378 | str: path to pretrained weights file on the local disk. 379 | None: use pretrained weights downloaded from the Internet. 380 | advprop (bool): 381 | Whether to load pretrained weights 382 | trained with advprop (valid when weights_path is None). 383 | in_channels (int): Input data's channel number. 384 | num_classes (int): 385 | Number of categories for classification. 386 | It controls the output size for final linear layer. 387 | override_params (other key word params): 388 | Params to override model's global_params. 389 | Optional key: 390 | 'width_coefficient', 'depth_coefficient', 391 | 'image_size', 'dropout_rate', 392 | 'batch_norm_momentum', 393 | 'batch_norm_epsilon', 'drop_connect_rate', 394 | 'depth_divisor', 'min_depth' 395 | 396 | Returns: 397 | A pretrained efficientnet model. 398 | """ 399 | model = cls.from_name(model_name, num_classes=num_classes, **override_params) 400 | load_pretrained_weights(model, model_name, weights_path=weights_path, 401 | load_fc=(num_classes == 1000), advprop=advprop) 402 | model._change_in_channels(in_channels) 403 | return model 404 | 405 | @classmethod 406 | def get_image_size(cls, model_name): 407 | """Get the input image size for a given efficientnet model. 408 | 409 | Args: 410 | model_name (str): Name for efficientnet. 411 | 412 | Returns: 413 | Input image size (resolution). 414 | """ 415 | cls._check_model_name_is_valid(model_name) 416 | _, _, res, _ = efficientnet_params(model_name) 417 | return res 418 | 419 | @classmethod 420 | def _check_model_name_is_valid(cls, model_name): 421 | """Validates model name. 422 | 423 | Args: 424 | model_name (str): Name for efficientnet. 425 | 426 | Returns: 427 | bool: Is a valid name or not. 428 | """ 429 | if model_name not in VALID_MODELS: 430 | raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS)) 431 | 432 | def _change_in_channels(self, in_channels): 433 | """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. 434 | 435 | Args: 436 | in_channels (int): Input data's channel number. 
437 | """ 438 | if in_channels != 3: 439 | Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size) 440 | out_channels = round_filters(32, self._global_params) 441 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 442 | -------------------------------------------------------------------------------- /fire/models/myefficientnet_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | """utils.py - Helper functions for building the model and for loading model parameters. 2 | These helper functions are built to mirror those in the official TensorFlow implementation. 3 | """ 4 | 5 | # Author: lukemelas (github username) 6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch 7 | # With adjustments and added comments by workingcoder (github username). 8 | 9 | import re 10 | import math 11 | import collections 12 | from functools import partial 13 | import torch 14 | from torch import nn 15 | from torch.nn import functional as F 16 | from torch.utils import model_zoo 17 | 18 | 19 | ################################################################################ 20 | # Help functions for model architecture 21 | ################################################################################ 22 | 23 | # GlobalParams and BlockArgs: Two namedtuples 24 | # Swish and MemoryEfficientSwish: Two implementations of the method 25 | # round_filters and round_repeats: 26 | # Functions to calculate params for scaling model width and depth ! ! ! 27 | # get_width_and_height_from_size and calculate_output_image_size 28 | # drop_connect: A structural design 29 | # get_same_padding_conv2d: 30 | # Conv2dDynamicSamePadding 31 | # Conv2dStaticSamePadding 32 | # get_same_padding_maxPool2d: 33 | # MaxPool2dDynamicSamePadding 34 | # MaxPool2dStaticSamePadding 35 | # It's an additional function, not used in EfficientNet, 36 | # but can be used in other model (such as EfficientDet). 
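# Typical entry point used by model.py (illustrative sketch; the override dict is optional):
#   blocks_args, global_params = get_model_params('efficientnet-b0', {'num_classes': 10})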
37 | 38 | # Parameters for the entire model (stem, all blocks, and head) 39 | GlobalParams = collections.namedtuple('GlobalParams', [ 40 | 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate', 41 | 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon', 42 | 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top']) 43 | 44 | # Parameters for an individual model block 45 | BlockArgs = collections.namedtuple('BlockArgs', [ 46 | 'num_repeat', 'kernel_size', 'stride', 'expand_ratio', 47 | 'input_filters', 'output_filters', 'se_ratio', 'id_skip']) 48 | 49 | # Set GlobalParams and BlockArgs's defaults 50 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) 51 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 52 | 53 | # Swish activation function 54 | if hasattr(nn, 'SiLU'): 55 | Swish = nn.SiLU 56 | else: 57 | # For compatibility with old PyTorch versions 58 | class Swish(nn.Module): 59 | def forward(self, x): 60 | return x * torch.sigmoid(x) 61 | 62 | 63 | # A memory-efficient implementation of Swish function 64 | class SwishImplementation(torch.autograd.Function): 65 | @staticmethod 66 | def forward(ctx, i): 67 | result = i * torch.sigmoid(i) 68 | ctx.save_for_backward(i) 69 | return result 70 | 71 | @staticmethod 72 | def backward(ctx, grad_output): 73 | i = ctx.saved_tensors[0] 74 | sigmoid_i = torch.sigmoid(i) 75 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 76 | 77 | 78 | class MemoryEfficientSwish(nn.Module): 79 | def forward(self, x): 80 | return SwishImplementation.apply(x) 81 | 82 | 83 | def round_filters(filters, global_params): 84 | """Calculate and round number of filters based on width multiplier. 85 | Use width_coefficient, depth_divisor and min_depth of global_params. 86 | 87 | Args: 88 | filters (int): Filters number to be calculated. 89 | global_params (namedtuple): Global params of the model. 90 | 91 | Returns: 92 | new_filters: New filters number after calculating. 93 | """ 94 | multiplier = global_params.width_coefficient 95 | if not multiplier: 96 | return filters 97 | # TODO: modify the params names. 98 | # maybe the names (width_divisor,min_width) 99 | # are more suitable than (depth_divisor,min_depth). 100 | divisor = global_params.depth_divisor 101 | min_depth = global_params.min_depth 102 | filters *= multiplier 103 | min_depth = min_depth or divisor # pay attention to this line when using min_depth 104 | # follow the formula transferred from official TensorFlow implementation 105 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) 106 | if new_filters < 0.9 * filters: # prevent rounding by more than 10% 107 | new_filters += divisor 108 | return int(new_filters) 109 | 110 | 111 | def round_repeats(repeats, global_params): 112 | """Calculate module's repeat number of a block based on depth multiplier. 113 | Use depth_coefficient of global_params. 114 | 115 | Args: 116 | repeats (int): num_repeat to be calculated. 117 | global_params (namedtuple): Global params of the model. 118 | 119 | Returns: 120 | new repeat: New repeat number after calculating. 121 | """ 122 | multiplier = global_params.depth_coefficient 123 | if not multiplier: 124 | return repeats 125 | # follow the formula transferred from official TensorFlow implementation 126 | return int(math.ceil(multiplier * repeats)) 127 | 128 | 129 | def drop_connect(inputs, p, training): 130 | """Drop connect. 131 | 132 | Args: 133 | input (tensor: BCWH): Input of this structure. 
134 | p (float: 0.0~1.0): Probability of drop connection. 135 | training (bool): The running mode. 136 | 137 | Returns: 138 | output: Output after drop connection. 139 | """ 140 | assert 0 <= p <= 1, 'p must be in range of [0,1]' 141 | 142 | if not training: 143 | return inputs 144 | 145 | batch_size = inputs.shape[0] 146 | keep_prob = 1 - p 147 | 148 | # generate binary_tensor mask according to probability (p for 0, 1-p for 1) 149 | random_tensor = keep_prob 150 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) 151 | binary_tensor = torch.floor(random_tensor) 152 | 153 | output = inputs / keep_prob * binary_tensor 154 | return output 155 | 156 | 157 | def get_width_and_height_from_size(x): 158 | """Obtain height and width from x. 159 | 160 | Args: 161 | x (int, tuple or list): Data size. 162 | 163 | Returns: 164 | size: A tuple or list (H,W). 165 | """ 166 | if isinstance(x, int): 167 | return x, x 168 | if isinstance(x, list) or isinstance(x, tuple): 169 | return x 170 | else: 171 | raise TypeError() 172 | 173 | 174 | def calculate_output_image_size(input_image_size, stride): 175 | """Calculates the output image size when using Conv2dSamePadding with a stride. 176 | Necessary for static padding. Thanks to mannatsingh for pointing this out. 177 | 178 | Args: 179 | input_image_size (int, tuple or list): Size of input image. 180 | stride (int, tuple or list): Conv2d operation's stride. 181 | 182 | Returns: 183 | output_image_size: A list [H,W]. 184 | """ 185 | if input_image_size is None: 186 | return None 187 | image_height, image_width = get_width_and_height_from_size(input_image_size) 188 | stride = stride if isinstance(stride, int) else stride[0] 189 | image_height = int(math.ceil(image_height / stride)) 190 | image_width = int(math.ceil(image_width / stride)) 191 | return [image_height, image_width] 192 | 193 | 194 | # Note: 195 | # The following 'SamePadding' functions make output size equal ceil(input size/stride). 196 | # Only when stride equals 1, can the output size be the same as input size. 197 | # Don't be confused by their function names ! ! ! 198 | 199 | def get_same_padding_conv2d(image_size=None): 200 | """Chooses static padding if you have specified an image size, and dynamic padding otherwise. 201 | Static padding is necessary for ONNX exporting of models. 202 | 203 | Args: 204 | image_size (int or tuple): Size of the image. 205 | 206 | Returns: 207 | Conv2dDynamicSamePadding or Conv2dStaticSamePadding. 208 | """ 209 | if image_size is None: 210 | return Conv2dDynamicSamePadding 211 | else: 212 | return partial(Conv2dStaticSamePadding, image_size=image_size) 213 | 214 | 215 | class Conv2dDynamicSamePadding(nn.Conv2d): 216 | """2D Convolutions like TensorFlow, for a dynamic image size. 217 | The padding is operated in forward function by calculating dynamically. 218 | """ 219 | 220 | # Tips for 'SAME' mode padding. 
221 | # Given the following: 222 | # i: width or height 223 | # s: stride 224 | # k: kernel size 225 | # d: dilation 226 | # p: padding 227 | # Output after Conv2d: 228 | # o = floor((i+p-((k-1)*d+1))/s+1) 229 | # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1), 230 | # => p = (i-1)*s+((k-1)*d+1)-i 231 | 232 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): 233 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 234 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 235 | 236 | def forward(self, x): 237 | ih, iw = x.size()[-2:] 238 | kh, kw = self.weight.size()[-2:] 239 | sh, sw = self.stride 240 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! ! 241 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 242 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 243 | if pad_h > 0 or pad_w > 0: 244 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 245 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 246 | 247 | 248 | class Conv2dStaticSamePadding(nn.Conv2d): 249 | """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. 250 | The padding mudule is calculated in construction function, then used in forward. 251 | """ 252 | 253 | # With the same calculation as Conv2dDynamicSamePadding 254 | 255 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs): 256 | super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs) 257 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 258 | 259 | # Calculate padding based on image size and save it 260 | assert image_size is not None 261 | ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size 262 | kh, kw = self.weight.size()[-2:] 263 | sh, sw = self.stride 264 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 265 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 266 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 267 | if pad_h > 0 or pad_w > 0: 268 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, 269 | pad_h // 2, pad_h - pad_h // 2)) 270 | else: 271 | self.static_padding = nn.Identity() 272 | 273 | def forward(self, x): 274 | x = self.static_padding(x) 275 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 276 | return x 277 | 278 | 279 | def get_same_padding_maxPool2d(image_size=None): 280 | """Chooses static padding if you have specified an image size, and dynamic padding otherwise. 281 | Static padding is necessary for ONNX exporting of models. 282 | 283 | Args: 284 | image_size (int or tuple): Size of the image. 285 | 286 | Returns: 287 | MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. 288 | """ 289 | if image_size is None: 290 | return MaxPool2dDynamicSamePadding 291 | else: 292 | return partial(MaxPool2dStaticSamePadding, image_size=image_size) 293 | 294 | 295 | class MaxPool2dDynamicSamePadding(nn.MaxPool2d): 296 | """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. 297 | The padding is operated in forward function by calculating dynamically. 
298 | """ 299 | 300 | def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False): 301 | super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) 302 | self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride 303 | self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size 304 | self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation 305 | 306 | def forward(self, x): 307 | ih, iw = x.size()[-2:] 308 | kh, kw = self.kernel_size 309 | sh, sw = self.stride 310 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 311 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 312 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 313 | if pad_h > 0 or pad_w > 0: 314 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 315 | return F.max_pool2d(x, self.kernel_size, self.stride, self.padding, 316 | self.dilation, self.ceil_mode, self.return_indices) 317 | 318 | 319 | class MaxPool2dStaticSamePadding(nn.MaxPool2d): 320 | """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. 321 | The padding mudule is calculated in construction function, then used in forward. 322 | """ 323 | 324 | def __init__(self, kernel_size, stride, image_size=None, **kwargs): 325 | super().__init__(kernel_size, stride, **kwargs) 326 | self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride 327 | self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size 328 | self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation 329 | 330 | # Calculate padding based on image size and save it 331 | assert image_size is not None 332 | ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size 333 | kh, kw = self.kernel_size 334 | sh, sw = self.stride 335 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 336 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 337 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 338 | if pad_h > 0 or pad_w > 0: 339 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) 340 | else: 341 | self.static_padding = nn.Identity() 342 | 343 | def forward(self, x): 344 | x = self.static_padding(x) 345 | x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, 346 | self.dilation, self.ceil_mode, self.return_indices) 347 | return x 348 | 349 | 350 | ################################################################################ 351 | # Helper functions for loading model params 352 | ################################################################################ 353 | 354 | # BlockDecoder: A Class for encoding and decoding BlockArgs 355 | # efficientnet_params: A function to query compound coefficient 356 | # get_model_params and efficientnet: 357 | # Functions to get BlockArgs and GlobalParams for efficientnet 358 | # url_map and url_map_advprop: Dicts of url_map for pretrained weights 359 | # load_pretrained_weights: A function to load pretrained weights 360 | 361 | class BlockDecoder(object): 362 | """Block Decoder for readability, 363 | straight from the official TensorFlow repository. 
364 | """ 365 | 366 | @staticmethod 367 | def _decode_block_string(block_string): 368 | """Get a block through a string notation of arguments. 369 | 370 | Args: 371 | block_string (str): A string notation of arguments. 372 | Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. 373 | 374 | Returns: 375 | BlockArgs: The namedtuple defined at the top of this file. 376 | """ 377 | assert isinstance(block_string, str) 378 | 379 | ops = block_string.split('_') 380 | options = {} 381 | for op in ops: 382 | splits = re.split(r'(\d.*)', op) 383 | if len(splits) >= 2: 384 | key, value = splits[:2] 385 | options[key] = value 386 | 387 | # Check stride 388 | assert (('s' in options and len(options['s']) == 1) or 389 | (len(options['s']) == 2 and options['s'][0] == options['s'][1])) 390 | 391 | return BlockArgs( 392 | num_repeat=int(options['r']), 393 | kernel_size=int(options['k']), 394 | stride=[int(options['s'][0])], 395 | expand_ratio=int(options['e']), 396 | input_filters=int(options['i']), 397 | output_filters=int(options['o']), 398 | se_ratio=float(options['se']) if 'se' in options else None, 399 | id_skip=('noskip' not in block_string)) 400 | 401 | @staticmethod 402 | def _encode_block_string(block): 403 | """Encode a block to a string. 404 | 405 | Args: 406 | block (namedtuple): A BlockArgs type argument. 407 | 408 | Returns: 409 | block_string: A String form of BlockArgs. 410 | """ 411 | args = [ 412 | 'r%d' % block.num_repeat, 413 | 'k%d' % block.kernel_size, 414 | 's%d%d' % (block.strides[0], block.strides[1]), 415 | 'e%s' % block.expand_ratio, 416 | 'i%d' % block.input_filters, 417 | 'o%d' % block.output_filters 418 | ] 419 | if 0 < block.se_ratio <= 1: 420 | args.append('se%s' % block.se_ratio) 421 | if block.id_skip is False: 422 | args.append('noskip') 423 | return '_'.join(args) 424 | 425 | @staticmethod 426 | def decode(string_list): 427 | """Decode a list of string notations to specify blocks inside the network. 428 | 429 | Args: 430 | string_list (list[str]): A list of strings, each string is a notation of block. 431 | 432 | Returns: 433 | blocks_args: A list of BlockArgs namedtuples of block args. 434 | """ 435 | assert isinstance(string_list, list) 436 | blocks_args = [] 437 | for block_string in string_list: 438 | blocks_args.append(BlockDecoder._decode_block_string(block_string)) 439 | return blocks_args 440 | 441 | @staticmethod 442 | def encode(blocks_args): 443 | """Encode a list of BlockArgs to a list of strings. 444 | 445 | Args: 446 | blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. 447 | 448 | Returns: 449 | block_strings: A list of strings, each string is a notation of block. 450 | """ 451 | block_strings = [] 452 | for block in blocks_args: 453 | block_strings.append(BlockDecoder._encode_block_string(block)) 454 | return block_strings 455 | 456 | 457 | def efficientnet_params(model_name): 458 | """Map EfficientNet model name to parameter coefficients. 459 | 460 | Args: 461 | model_name (str): Model name to be queried. 462 | 463 | Returns: 464 | params_dict[model_name]: A (width,depth,res,dropout) tuple. 
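Example: efficientnet_params('efficientnet-b0') returns (1.0, 1.0, 224, 0.2).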
465 | """ 466 | params_dict = { 467 | # Coefficients: width,depth,res,dropout 468 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 469 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 470 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 471 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 472 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 473 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 474 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 475 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5), 476 | 'efficientnet-b8': (2.2, 3.6, 672, 0.5), 477 | 'efficientnet-l2': (4.3, 5.3, 800, 0.5), 478 | } 479 | return params_dict[model_name] 480 | 481 | 482 | def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None, 483 | dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True): 484 | """Create BlockArgs and GlobalParams for efficientnet model. 485 | 486 | Args: 487 | width_coefficient (float) 488 | depth_coefficient (float) 489 | image_size (int) 490 | dropout_rate (float) 491 | drop_connect_rate (float) 492 | num_classes (int) 493 | 494 | Meaning as the name suggests. 495 | 496 | Returns: 497 | blocks_args, global_params. 498 | """ 499 | 500 | # Blocks args for the whole model(efficientnet-b0 by default) 501 | # It will be modified in the construction of EfficientNet Class according to model 502 | blocks_args = [ 503 | 'r1_k3_s11_e1_i32_o16_se0.25', 504 | 'r2_k3_s22_e6_i16_o24_se0.25', 505 | 'r2_k5_s22_e6_i24_o40_se0.25', 506 | 'r3_k3_s22_e6_i40_o80_se0.25', 507 | 'r3_k5_s11_e6_i80_o112_se0.25', 508 | 'r4_k5_s22_e6_i112_o192_se0.25', 509 | 'r1_k3_s11_e6_i192_o320_se0.25', 510 | ] 511 | blocks_args = BlockDecoder.decode(blocks_args) 512 | 513 | global_params = GlobalParams( 514 | width_coefficient=width_coefficient, 515 | depth_coefficient=depth_coefficient, 516 | image_size=image_size, 517 | dropout_rate=dropout_rate, 518 | 519 | num_classes=num_classes, 520 | batch_norm_momentum=0.99, 521 | batch_norm_epsilon=1e-3, 522 | drop_connect_rate=drop_connect_rate, 523 | depth_divisor=8, 524 | min_depth=None, 525 | include_top=include_top, 526 | ) 527 | 528 | return blocks_args, global_params 529 | 530 | 531 | def get_model_params(model_name, override_params): 532 | """Get the block args and global params for a given model name. 533 | 534 | Args: 535 | model_name (str): Model's name. 536 | override_params (dict): A dict to modify global_params. 537 | 538 | Returns: 539 | blocks_args, global_params 540 | """ 541 | if model_name.startswith('efficientnet'): 542 | w, d, s, p = efficientnet_params(model_name) 543 | # note: all models have drop connect rate = 0.2 544 | blocks_args, global_params = efficientnet( 545 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) 546 | else: 547 | raise NotImplementedError('model name is not pre-defined: {}'.format(model_name)) 548 | if override_params: 549 | # ValueError will be raised here if override_params has fields not included in global_params. 
550 | global_params = global_params._replace(**override_params) 551 | return blocks_args, global_params 552 | 553 | 554 | # train with Standard methods 555 | # check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks) 556 | url_map = { 557 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth', 558 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth', 559 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth', 560 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth', 561 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth', 562 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth', 563 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth', 564 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth', 565 | } 566 | 567 | # train with Adversarial Examples(AdvProp) 568 | # check more details in paper(Adversarial Examples Improve Image Recognition) 569 | url_map_advprop = { 570 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth', 571 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth', 572 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth', 573 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth', 574 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth', 575 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth', 576 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth', 577 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth', 578 | 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth', 579 | } 580 | 581 | # TODO: add the petrained weights url map of 'efficientnet-l2' 582 | 583 | 584 | def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True): 585 | """Loads pretrained weights from weights path or download using url. 586 | 587 | Args: 588 | model (Module): The whole model of efficientnet. 589 | model_name (str): Model name of efficientnet. 590 | weights_path (None or str): 591 | str: path to pretrained weights file on the local disk. 592 | None: use pretrained weights downloaded from the Internet. 593 | load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. 594 | advprop (bool): Whether to load pretrained weights 595 | trained with advprop (valid when weights_path is None). 
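verbose (bool): Whether to print a confirmation message after the weights are loaded.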
596 | """ 597 | if isinstance(weights_path, str): 598 | state_dict = torch.load(weights_path) 599 | else: 600 | # AutoAugment or Advprop (different preprocessing) 601 | url_map_ = url_map_advprop if advprop else url_map 602 | state_dict = model_zoo.load_url(url_map_[model_name]) 603 | 604 | if load_fc: 605 | ret = model.load_state_dict(state_dict, strict=False) 606 | assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) 607 | else: 608 | state_dict.pop('_fc.weight') 609 | state_dict.pop('_fc.bias') 610 | ret = model.load_state_dict(state_dict, strict=False) 611 | assert set(ret.missing_keys) == set( 612 | ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) 613 | assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) 614 | 615 | if verbose: 616 | print('Loaded pretrained weights for {}'.format(model_name)) 617 | -------------------------------------------------------------------------------- /fire/models/swin/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model 2 | from .config import get_config -------------------------------------------------------------------------------- /fire/models/swin/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .swin_transformer import SwinTransformer 9 | from .swin_mlp import SwinMLP 10 | 11 | 12 | def build_model(config): 13 | model_type = config.MODEL.TYPE 14 | if model_type == 'swin': 15 | model = SwinTransformer(img_size=config.DATA.IMG_SIZE, 16 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 17 | in_chans=config.MODEL.SWIN.IN_CHANS, 18 | num_classes=config.MODEL.NUM_CLASSES, 19 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 20 | depths=config.MODEL.SWIN.DEPTHS, 21 | num_heads=config.MODEL.SWIN.NUM_HEADS, 22 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 23 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 24 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 25 | qk_scale=config.MODEL.SWIN.QK_SCALE, 26 | drop_rate=config.MODEL.DROP_RATE, 27 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 28 | ape=config.MODEL.SWIN.APE, 29 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 30 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 31 | elif model_type == 'swin_mlp': 32 | model = SwinMLP(img_size=config.DATA.IMG_SIZE, 33 | patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE, 34 | in_chans=config.MODEL.SWIN_MLP.IN_CHANS, 35 | num_classes=config.MODEL.NUM_CLASSES, 36 | embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM, 37 | depths=config.MODEL.SWIN_MLP.DEPTHS, 38 | num_heads=config.MODEL.SWIN_MLP.NUM_HEADS, 39 | window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE, 40 | mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO, 41 | drop_rate=config.MODEL.DROP_RATE, 42 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 43 | ape=config.MODEL.SWIN_MLP.APE, 44 | patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM, 45 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 46 | else: 47 | raise NotImplementedError(f"Unkown model: {model_type}") 48 | 49 | return model 50 | -------------------------------------------------------------------------------- /fire/models/swin/config.py: 
-------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 128 23 | # Path to dataset, could be overwritten by command line argument 24 | _C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 224 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | _C.DATA.ZIP_MODE = False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Model settings 43 | # ----------------------------------------------------------------------------- 44 | _C.MODEL = CN() 45 | # Model type 46 | _C.MODEL.TYPE = 'swin' 47 | # Model name 48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 49 | # Pretrained weight from checkpoint, could be imagenet22k pretrained weight 50 | # could be overwritten by command line argument 51 | _C.MODEL.PRETRAINED = '' 52 | # Checkpoint to resume, could be overwritten by command line argument 53 | _C.MODEL.RESUME = '' 54 | # Number of classes, overwritten in data preparation 55 | _C.MODEL.NUM_CLASSES = 1000 56 | # Dropout rate 57 | _C.MODEL.DROP_RATE = 0.0 58 | # Drop path rate 59 | _C.MODEL.DROP_PATH_RATE = 0.1 60 | # Label Smoothing 61 | _C.MODEL.LABEL_SMOOTHING = 0.1 62 | 63 | # Swin Transformer parameters 64 | _C.MODEL.SWIN = CN() 65 | _C.MODEL.SWIN.PATCH_SIZE = 4 66 | _C.MODEL.SWIN.IN_CHANS = 3 67 | _C.MODEL.SWIN.EMBED_DIM = 96 68 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 69 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 70 | _C.MODEL.SWIN.WINDOW_SIZE = 7 71 | _C.MODEL.SWIN.MLP_RATIO = 4. 72 | _C.MODEL.SWIN.QKV_BIAS = True 73 | _C.MODEL.SWIN.QK_SCALE = None 74 | _C.MODEL.SWIN.APE = False 75 | _C.MODEL.SWIN.PATCH_NORM = True 76 | 77 | # Swin MLP parameters 78 | _C.MODEL.SWIN_MLP = CN() 79 | _C.MODEL.SWIN_MLP.PATCH_SIZE = 4 80 | _C.MODEL.SWIN_MLP.IN_CHANS = 3 81 | _C.MODEL.SWIN_MLP.EMBED_DIM = 96 82 | _C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2] 83 | _C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24] 84 | _C.MODEL.SWIN_MLP.WINDOW_SIZE = 7 85 | _C.MODEL.SWIN_MLP.MLP_RATIO = 4. 
86 | _C.MODEL.SWIN_MLP.APE = False 87 | _C.MODEL.SWIN_MLP.PATCH_NORM = True 88 | 89 | # ----------------------------------------------------------------------------- 90 | # Training settings 91 | # ----------------------------------------------------------------------------- 92 | _C.TRAIN = CN() 93 | _C.TRAIN.START_EPOCH = 0 94 | _C.TRAIN.EPOCHS = 300 95 | _C.TRAIN.WARMUP_EPOCHS = 20 96 | _C.TRAIN.WEIGHT_DECAY = 0.05 97 | _C.TRAIN.BASE_LR = 5e-4 98 | _C.TRAIN.WARMUP_LR = 5e-7 99 | _C.TRAIN.MIN_LR = 5e-6 100 | # Clip gradient norm 101 | _C.TRAIN.CLIP_GRAD = 5.0 102 | # Auto resume from latest checkpoint 103 | _C.TRAIN.AUTO_RESUME = True 104 | # Gradient accumulation steps 105 | # could be overwritten by command line argument 106 | _C.TRAIN.ACCUMULATION_STEPS = 0 107 | # Whether to use gradient checkpointing to save memory 108 | # could be overwritten by command line argument 109 | _C.TRAIN.USE_CHECKPOINT = False 110 | 111 | # LR scheduler 112 | _C.TRAIN.LR_SCHEDULER = CN() 113 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 114 | # Epoch interval to decay LR, used in StepLRScheduler 115 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 116 | # LR decay rate, used in StepLRScheduler 117 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 118 | 119 | # Optimizer 120 | _C.TRAIN.OPTIMIZER = CN() 121 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 122 | # Optimizer Epsilon 123 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 124 | # Optimizer Betas 125 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 126 | # SGD momentum 127 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 128 | 129 | # ----------------------------------------------------------------------------- 130 | # Augmentation settings 131 | # ----------------------------------------------------------------------------- 132 | _C.AUG = CN() 133 | # Color jitter factor 134 | _C.AUG.COLOR_JITTER = 0.4 135 | # Use AutoAugment policy. "v0" or "original" 136 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 137 | # Random erase prob 138 | _C.AUG.REPROB = 0.25 139 | # Random erase mode 140 | _C.AUG.REMODE = 'pixel' 141 | # Random erase count 142 | _C.AUG.RECOUNT = 1 143 | # Mixup alpha, mixup enabled if > 0 144 | _C.AUG.MIXUP = 0.8 145 | # Cutmix alpha, cutmix enabled if > 0 146 | _C.AUG.CUTMIX = 1.0 147 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 148 | _C.AUG.CUTMIX_MINMAX = None 149 | # Probability of performing mixup or cutmix when either/both is enabled 150 | _C.AUG.MIXUP_PROB = 1.0 151 | # Probability of switching to cutmix when both mixup and cutmix enabled 152 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 153 | # How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" 154 | _C.AUG.MIXUP_MODE = 'batch' 155 | 156 | # ----------------------------------------------------------------------------- 157 | # Testing settings 158 | # ----------------------------------------------------------------------------- 159 | _C.TEST = CN() 160 | # Whether to use center crop when testing 161 | _C.TEST.CROP = True 162 | # Whether to use SequentialSampler as validation sampler 163 | _C.TEST.SEQUENTIAL = False 164 | 165 | # ----------------------------------------------------------------------------- 166 | # Misc 167 | # ----------------------------------------------------------------------------- 168 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 169 | # overwritten by command line argument 170 | _C.AMP_OPT_LEVEL = '' 171 | # Path to output folder, overwritten by command line argument 172 | _C.OUTPUT = '' 173 | # Tag of experiment, overwritten by command line argument 174 | _C.TAG = 'default' 175 | # Frequency to save checkpoint 176 | _C.SAVE_FREQ = 1 177 | # Frequency to logging info 178 | _C.PRINT_FREQ = 10 179 | # Fixed random seed 180 | _C.SEED = 0 181 | # Perform evaluation only, overwritten by command line argument 182 | _C.EVAL_MODE = False 183 | # Test throughput only, overwritten by command line argument 184 | _C.THROUGHPUT_MODE = False 185 | # local rank for DistributedDataParallel, given by command line argument 186 | _C.LOCAL_RANK = 0 187 | 188 | 189 | def _update_config_from_file(config, cfg_file): 190 | config.defrost() 191 | with open(cfg_file, 'r') as f: 192 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 193 | 194 | for cfg in yaml_cfg.setdefault('BASE', ['']): 195 | if cfg: 196 | _update_config_from_file( 197 | config, os.path.join(os.path.dirname(cfg_file), cfg) 198 | ) 199 | print('=> merge config from {}'.format(cfg_file)) 200 | config.merge_from_file(cfg_file) 201 | config.freeze() 202 | 203 | 204 | def update_config(config, args): 205 | _update_config_from_file(config, args.cfg) 206 | 207 | config.defrost() 208 | if args.opts: 209 | config.merge_from_list(args.opts) 210 | 211 | # merge from specific arguments 212 | if args.batch_size: 213 | config.DATA.BATCH_SIZE = args.batch_size 214 | if args.data_path: 215 | config.DATA.DATA_PATH = args.data_path 216 | if args.zip: 217 | config.DATA.ZIP_MODE = True 218 | if args.cache_mode: 219 | config.DATA.CACHE_MODE = args.cache_mode 220 | if args.pretrained: 221 | config.MODEL.PRETRAINED = args.pretrained 222 | if args.resume: 223 | config.MODEL.RESUME = args.resume 224 | if args.accumulation_steps: 225 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 226 | if args.use_checkpoint: 227 | config.TRAIN.USE_CHECKPOINT = True 228 | if args.amp_opt_level: 229 | config.AMP_OPT_LEVEL = args.amp_opt_level 230 | if args.output: 231 | config.OUTPUT = args.output 232 | if args.tag: 233 | config.TAG = args.tag 234 | if args.eval: 235 | config.EVAL_MODE = True 236 | if args.throughput: 237 | config.THROUGHPUT_MODE = True 238 | 239 | # set local rank for distributed training 240 | config.LOCAL_RANK = args.local_rank 241 | 242 | # output folder 243 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) 244 | 245 | config.freeze() 246 | 247 | def update_config_cfg(config, cfg): 248 | _update_config_from_file(config, cfg) 249 | config.freeze() 250 | 251 | # def get_config(args): 252 | # """Get a yacs CfgNode object with default values.""" 253 | # # Return a clone so that the defaults will not be altered 254 | # # This is for the "local 
variable" use pattern 255 | # config = _C.clone() 256 | # update_config(config, args) 257 | 258 | # return config 259 | def get_config(cfg): 260 | """Get a yacs CfgNode object with default values.""" 261 | # Return a clone so that the defaults will not be altered 262 | # This is for the "local variable" use pattern 263 | config = _C.clone() 264 | update_config_cfg(config, cfg) 265 | 266 | return config -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_base_patch4_window12_384_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window12_384_22kto1k_finetune 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_base_patch4_window12_384_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window12_384_finetune 6 | DROP_PATH_RATE: 0.5 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_base_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_base_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_large_patch4_window12_384_22kto1k_finetune.yaml: 
-------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_large_patch4_window12_384_22kto1k_finetune 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_large_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_large_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_large_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_large_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 192 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 6, 12, 24, 48 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_mlp_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin_mlp 3 | NAME: swin_mlp_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN_MLP: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_mlp_tiny_c12_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c12_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 8, 16, 32, 64 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_mlp_tiny_c24_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c24_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_mlp_tiny_c6_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c6_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 16, 32, 64, 128 ] 11 | 
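  # Editor's note (assumption): the "c6"/"c12"/"c24" suffix appears to denote channels
  # per head in the grouped spatial MLP (here EMBED_DIM 96 / 16 heads = 6).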
WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_small_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_small_patch4_window7_224 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_tiny_c24_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_tiny_c24_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /fire/models/swin/configs/swin_tiny_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /fire/models/swin/swin_mlp.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | 14 | 15 | class Mlp(nn.Module): 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | def window_partition(x, window_size): 35 | """ 36 | Args: 37 | x: (B, H, W, C) 38 | window_size (int): window size 39 | 40 | Returns: 41 | windows: (num_windows*B, window_size, window_size, C) 42 | """ 43 | B, H, W, C = x.shape 44 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 45 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 46 | return windows 47 | 48 | 49 | def window_reverse(windows, window_size, H, W): 50 | """ 51 | Args: 52 | windows: (num_windows*B, window_size, window_size, C) 53 | window_size (int): Window size 54 | H (int): Height of image 55 | W (int): Width of image 56 | 57 | Returns: 58 | x: (B, H, W, C) 59 | """ 60 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 61 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 62 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 63 | return x 64 | 65 | 
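# Editor's sketch (not in the original file): window_partition and window_reverse are
# exact inverses whenever H and W are multiples of window_size, e.g.
#
#   x = torch.randn(2, 56, 56, 96)           # (B, H, W, C)
#   windows = window_partition(x, 7)         # (2 * 8 * 8, 7, 7, 96)
#   assert torch.equal(window_reverse(windows, 7, 56, 56), x)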
66 | class SwinMLPBlock(nn.Module): 67 | r""" Swin MLP Block. 68 | 69 | Args: 70 | dim (int): Number of input channels. 71 | input_resolution (tuple[int]): Input resulotion. 72 | num_heads (int): Number of attention heads. 73 | window_size (int): Window size. 74 | shift_size (int): Shift size for SW-MSA. 75 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 76 | drop (float, optional): Dropout rate. Default: 0.0 77 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 78 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 79 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 80 | """ 81 | 82 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 83 | mlp_ratio=4., drop=0., drop_path=0., 84 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 85 | super().__init__() 86 | self.dim = dim 87 | self.input_resolution = input_resolution 88 | self.num_heads = num_heads 89 | self.window_size = window_size 90 | self.shift_size = shift_size 91 | self.mlp_ratio = mlp_ratio 92 | if min(self.input_resolution) <= self.window_size: 93 | # if window size is larger than input resolution, we don't partition windows 94 | self.shift_size = 0 95 | self.window_size = min(self.input_resolution) 96 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 97 | 98 | self.padding = [self.window_size - self.shift_size, self.shift_size, 99 | self.window_size - self.shift_size, self.shift_size] # P_l,P_r,P_t,P_b 100 | 101 | self.norm1 = norm_layer(dim) 102 | # use group convolution to implement multi-head MLP 103 | self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2, 104 | self.num_heads * self.window_size ** 2, 105 | kernel_size=1, 106 | groups=self.num_heads) 107 | 108 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 109 | self.norm2 = norm_layer(dim) 110 | mlp_hidden_dim = int(dim * mlp_ratio) 111 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 112 | 113 | def forward(self, x): 114 | H, W = self.input_resolution 115 | B, L, C = x.shape 116 | assert L == H * W, "input feature has wrong size" 117 | 118 | shortcut = x 119 | x = self.norm1(x) 120 | x = x.view(B, H, W, C) 121 | 122 | # shift 123 | if self.shift_size > 0: 124 | P_l, P_r, P_t, P_b = self.padding 125 | shifted_x = F.pad(x, [0, 0, P_l, P_r, P_t, P_b], "constant", 0) 126 | else: 127 | shifted_x = x 128 | _, _H, _W, _ = shifted_x.shape 129 | 130 | # partition windows 131 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 132 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 133 | 134 | # Window/Shifted-Window Spatial MLP 135 | x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C // self.num_heads) 136 | x_windows_heads = x_windows_heads.transpose(1, 2) # nW*B, nH, window_size*window_size, C//nH 137 | x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size, 138 | C // self.num_heads) 139 | spatial_mlp_windows = self.spatial_mlp(x_windows_heads) # nW*B, nH*window_size*window_size, C//nH 140 | spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size, 141 | C // self.num_heads).transpose(1, 2) 142 | spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size * self.window_size, C) 143 | 144 | # merge windows 145 | spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C) 146 | shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W) # B H' W' C 147 | 148 | # reverse shift 149 | if self.shift_size > 0: 150 | P_l, P_r, P_t, P_b = self.padding 151 | x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous() 152 | else: 153 | x = shifted_x 154 | x = x.view(B, H * W, C) 155 | 156 | # FFN 157 | x = shortcut + self.drop_path(x) 158 | x = x + self.drop_path(self.mlp(self.norm2(x))) 159 | 160 | return x 161 | 162 | def extra_repr(self) -> str: 163 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 164 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 165 | 166 | def flops(self): 167 | flops = 0 168 | H, W = self.input_resolution 169 | # norm1 170 | flops += self.dim * H * W 171 | 172 | # Window/Shifted-Window Spatial MLP 173 | if self.shift_size > 0: 174 | nW = (H / self.window_size + 1) * (W / self.window_size + 1) 175 | else: 176 | nW = H * W / self.window_size / self.window_size 177 | flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size) 178 | # mlp 179 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 180 | # norm2 181 | flops += self.dim * H * W 182 | return flops 183 | 184 | 185 | class PatchMerging(nn.Module): 186 | r""" Patch Merging Layer. 187 | 188 | Args: 189 | input_resolution (tuple[int]): Resolution of input feature. 190 | dim (int): Number of input channels. 191 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 192 | """ 193 | 194 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 195 | super().__init__() 196 | self.input_resolution = input_resolution 197 | self.dim = dim 198 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 199 | self.norm = norm_layer(4 * dim) 200 | 201 | def forward(self, x): 202 | """ 203 | x: B, H*W, C 204 | """ 205 | H, W = self.input_resolution 206 | B, L, C = x.shape 207 | assert L == H * W, "input feature has wrong size" 208 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 209 | 210 | x = x.view(B, H, W, C) 211 | 212 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 213 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 214 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 215 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 216 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 217 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 218 | 219 | x = self.norm(x) 220 | x = self.reduction(x) 221 | 222 | return x 223 | 224 | def extra_repr(self) -> str: 225 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 226 | 227 | def flops(self): 228 | H, W = self.input_resolution 229 | flops = H * W * self.dim 230 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 231 | return flops 232 | 233 | 234 | class BasicLayer(nn.Module): 235 | """ A basic Swin MLP layer for one stage. 236 | 237 | Args: 238 | dim (int): Number of input channels. 239 | input_resolution (tuple[int]): Input resolution. 240 | depth (int): Number of blocks. 241 | num_heads (int): Number of attention heads. 242 | window_size (int): Local window size. 243 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 244 | drop (float, optional): Dropout rate. Default: 0.0 245 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 246 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 247 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 248 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
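        Example (illustrative, added by the editor):
            >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2,
            ...                    num_heads=3, window_size=7)
            >>> layer(torch.randn(1, 56 * 56, 96)).shape   # no downsample -> shape kept
            torch.Size([1, 3136, 96])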
249 | """ 250 | 251 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 252 | mlp_ratio=4., drop=0., drop_path=0., 253 | norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 254 | 255 | super().__init__() 256 | self.dim = dim 257 | self.input_resolution = input_resolution 258 | self.depth = depth 259 | self.use_checkpoint = use_checkpoint 260 | 261 | # build blocks 262 | self.blocks = nn.ModuleList([ 263 | SwinMLPBlock(dim=dim, input_resolution=input_resolution, 264 | num_heads=num_heads, window_size=window_size, 265 | shift_size=0 if (i % 2 == 0) else window_size // 2, 266 | mlp_ratio=mlp_ratio, 267 | drop=drop, 268 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 269 | norm_layer=norm_layer) 270 | for i in range(depth)]) 271 | 272 | # patch merging layer 273 | if downsample is not None: 274 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 275 | else: 276 | self.downsample = None 277 | 278 | def forward(self, x): 279 | for blk in self.blocks: 280 | if self.use_checkpoint: 281 | x = checkpoint.checkpoint(blk, x) 282 | else: 283 | x = blk(x) 284 | if self.downsample is not None: 285 | x = self.downsample(x) 286 | return x 287 | 288 | def extra_repr(self) -> str: 289 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 290 | 291 | def flops(self): 292 | flops = 0 293 | for blk in self.blocks: 294 | flops += blk.flops() 295 | if self.downsample is not None: 296 | flops += self.downsample.flops() 297 | return flops 298 | 299 | 300 | class PatchEmbed(nn.Module): 301 | r""" Image to Patch Embedding 302 | 303 | Args: 304 | img_size (int): Image size. Default: 224. 305 | patch_size (int): Patch token size. Default: 4. 306 | in_chans (int): Number of input image channels. Default: 3. 307 | embed_dim (int): Number of linear projection output channels. Default: 96. 308 | norm_layer (nn.Module, optional): Normalization layer. Default: None 309 | """ 310 | 311 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 312 | super().__init__() 313 | img_size = to_2tuple(img_size) 314 | patch_size = to_2tuple(patch_size) 315 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 316 | self.img_size = img_size 317 | self.patch_size = patch_size 318 | self.patches_resolution = patches_resolution 319 | self.num_patches = patches_resolution[0] * patches_resolution[1] 320 | 321 | self.in_chans = in_chans 322 | self.embed_dim = embed_dim 323 | 324 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 325 | if norm_layer is not None: 326 | self.norm = norm_layer(embed_dim) 327 | else: 328 | self.norm = None 329 | 330 | def forward(self, x): 331 | B, C, H, W = x.shape 332 | # FIXME look at relaxing size constraints 333 | assert H == self.img_size[0] and W == self.img_size[1], \ 334 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 335 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 336 | if self.norm is not None: 337 | x = self.norm(x) 338 | return x 339 | 340 | def flops(self): 341 | Ho, Wo = self.patches_resolution 342 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 343 | if self.norm is not None: 344 | flops += Ho * Wo * self.embed_dim 345 | return flops 346 | 347 | 348 | class SwinMLP(nn.Module): 349 | r""" Swin MLP 350 | 351 | Args: 352 | img_size (int | tuple(int)): Input image size. 
Default 224 353 | patch_size (int | tuple(int)): Patch size. Default: 4 354 | in_chans (int): Number of input image channels. Default: 3 355 | num_classes (int): Number of classes for classification head. Default: 1000 356 | embed_dim (int): Patch embedding dimension. Default: 96 357 | depths (tuple(int)): Depth of each Swin MLP layer. 358 | num_heads (tuple(int)): Number of attention heads in different layers. 359 | window_size (int): Window size. Default: 7 360 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 361 | drop_rate (float): Dropout rate. Default: 0 362 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 363 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 364 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 365 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 366 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 367 | """ 368 | 369 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 370 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 371 | window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1, 372 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 373 | use_checkpoint=False, **kwargs): 374 | super().__init__() 375 | 376 | self.num_classes = num_classes 377 | self.num_layers = len(depths) 378 | self.embed_dim = embed_dim 379 | self.ape = ape 380 | self.patch_norm = patch_norm 381 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 382 | self.mlp_ratio = mlp_ratio 383 | 384 | # split image into non-overlapping patches 385 | self.patch_embed = PatchEmbed( 386 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 387 | norm_layer=norm_layer if self.patch_norm else None) 388 | num_patches = self.patch_embed.num_patches 389 | patches_resolution = self.patch_embed.patches_resolution 390 | self.patches_resolution = patches_resolution 391 | 392 | # absolute position embedding 393 | if self.ape: 394 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 395 | trunc_normal_(self.absolute_pos_embed, std=.02) 396 | 397 | self.pos_drop = nn.Dropout(p=drop_rate) 398 | 399 | # stochastic depth 400 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 401 | 402 | # build layers 403 | self.layers = nn.ModuleList() 404 | for i_layer in range(self.num_layers): 405 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 406 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 407 | patches_resolution[1] // (2 ** i_layer)), 408 | depth=depths[i_layer], 409 | num_heads=num_heads[i_layer], 410 | window_size=window_size, 411 | mlp_ratio=self.mlp_ratio, 412 | drop=drop_rate, 413 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 414 | norm_layer=norm_layer, 415 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 416 | use_checkpoint=use_checkpoint) 417 | self.layers.append(layer) 418 | 419 | self.norm = norm_layer(self.num_features) 420 | self.avgpool = nn.AdaptiveAvgPool1d(1) 421 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 422 | 423 | self.apply(self._init_weights) 424 | 425 | def _init_weights(self, m): 426 | if isinstance(m, (nn.Linear, nn.Conv1d)): 427 | trunc_normal_(m.weight, std=.02) 428 | if m.bias is not None: 429 | nn.init.constant_(m.bias, 0) 430 | elif 
isinstance(m, nn.LayerNorm): 431 | nn.init.constant_(m.bias, 0) 432 | nn.init.constant_(m.weight, 1.0) 433 | 434 | @torch.jit.ignore 435 | def no_weight_decay(self): 436 | return {'absolute_pos_embed'} 437 | 438 | @torch.jit.ignore 439 | def no_weight_decay_keywords(self): 440 | return {'relative_position_bias_table'} 441 | 442 | def forward_features(self, x): 443 | x = self.patch_embed(x) 444 | if self.ape: 445 | x = x + self.absolute_pos_embed 446 | x = self.pos_drop(x) 447 | 448 | for layer in self.layers: 449 | x = layer(x) 450 | 451 | x = self.norm(x) # B L C 452 | x = self.avgpool(x.transpose(1, 2)) # B C 1 453 | x = torch.flatten(x, 1) 454 | return x 455 | 456 | def forward(self, x): 457 | x = self.forward_features(x) 458 | x = self.head(x) 459 | return x 460 | 461 | def flops(self): 462 | flops = 0 463 | flops += self.patch_embed.flops() 464 | for i, layer in enumerate(self.layers): 465 | flops += layer.flops() 466 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 467 | flops += self.num_features * self.num_classes 468 | return flops 469 | -------------------------------------------------------------------------------- /fire/models/swin/swin_transformer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.checkpoint as checkpoint 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | 13 | 14 | class Mlp(nn.Module): 15 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 16 | super().__init__() 17 | out_features = out_features or in_features 18 | hidden_features = hidden_features or in_features 19 | self.fc1 = nn.Linear(in_features, hidden_features) 20 | self.act = act_layer() 21 | self.fc2 = nn.Linear(hidden_features, out_features) 22 | self.drop = nn.Dropout(drop) 23 | 24 | def forward(self, x): 25 | x = self.fc1(x) 26 | x = self.act(x) 27 | x = self.drop(x) 28 | x = self.fc2(x) 29 | x = self.drop(x) 30 | return x 31 | 32 | 33 | def window_partition(x, window_size): 34 | """ 35 | Args: 36 | x: (B, H, W, C) 37 | window_size (int): window size 38 | 39 | Returns: 40 | windows: (num_windows*B, window_size, window_size, C) 41 | """ 42 | B, H, W, C = x.shape 43 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 44 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 45 | return windows 46 | 47 | 48 | def window_reverse(windows, window_size, H, W): 49 | """ 50 | Args: 51 | windows: (num_windows*B, window_size, window_size, C) 52 | window_size (int): Window size 53 | H (int): Height of image 54 | W (int): Width of image 55 | 56 | Returns: 57 | x: (B, H, W, C) 58 | """ 59 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 60 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 61 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 62 | return x 63 | 64 | 65 | class WindowAttention(nn.Module): 66 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 67 | It supports both of shifted and non-shifted window. 
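    Example (illustrative, added by the editor):
        >>> attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
        >>> x = torch.randn(128, 49, 96)    # (num_windows*B, Wh*Ww, C)
        >>> attn(x).shape                   # mask=None -> plain (non-shifted) W-MSA
        torch.Size([128, 49, 96])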
68 | 69 | Args: 70 | dim (int): Number of input channels. 71 | window_size (tuple[int]): The height and width of the window. 72 | num_heads (int): Number of attention heads. 73 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 74 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 75 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 76 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 77 | """ 78 | 79 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 80 | 81 | super().__init__() 82 | self.dim = dim 83 | self.window_size = window_size # Wh, Ww 84 | self.num_heads = num_heads 85 | head_dim = dim // num_heads 86 | self.scale = qk_scale or head_dim ** -0.5 87 | 88 | # define a parameter table of relative position bias 89 | self.relative_position_bias_table = nn.Parameter( 90 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 91 | 92 | # get pair-wise relative position index for each token inside the window 93 | coords_h = torch.arange(self.window_size[0]) 94 | coords_w = torch.arange(self.window_size[1]) 95 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 96 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 97 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 98 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 99 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 100 | relative_coords[:, :, 1] += self.window_size[1] - 1 101 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 102 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 103 | self.register_buffer("relative_position_index", relative_position_index) 104 | 105 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 106 | self.attn_drop = nn.Dropout(attn_drop) 107 | self.proj = nn.Linear(dim, dim) 108 | self.proj_drop = nn.Dropout(proj_drop) 109 | 110 | trunc_normal_(self.relative_position_bias_table, std=.02) 111 | self.softmax = nn.Softmax(dim=-1) 112 | 113 | def forward(self, x, mask=None): 114 | """ 115 | Args: 116 | x: input features with shape of (num_windows*B, N, C) 117 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 118 | """ 119 | B_, N, C = x.shape 120 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 121 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 122 | 123 | q = q * self.scale 124 | attn = (q @ k.transpose(-2, -1)) 125 | 126 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 127 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 128 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 129 | attn = attn + relative_position_bias.unsqueeze(0) 130 | 131 | if mask is not None: 132 | nW = mask.shape[0] 133 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 134 | attn = attn.view(-1, self.num_heads, N, N) 135 | attn = self.softmax(attn) 136 | else: 137 | attn = self.softmax(attn) 138 | 139 | attn = self.attn_drop(attn) 140 | 141 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 142 | x = self.proj(x) 143 | x = self.proj_drop(x) 144 | 
return x 145 | 146 | def extra_repr(self) -> str: 147 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 148 | 149 | def flops(self, N): 150 | # calculate flops for 1 window with token length of N 151 | flops = 0 152 | # qkv = self.qkv(x) 153 | flops += N * self.dim * 3 * self.dim 154 | # attn = (q @ k.transpose(-2, -1)) 155 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 156 | # x = (attn @ v) 157 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 158 | # x = self.proj(x) 159 | flops += N * self.dim * self.dim 160 | return flops 161 | 162 | 163 | class SwinTransformerBlock(nn.Module): 164 | r""" Swin Transformer Block. 165 | 166 | Args: 167 | dim (int): Number of input channels. 168 | input_resolution (tuple[int]): Input resulotion. 169 | num_heads (int): Number of attention heads. 170 | window_size (int): Window size. 171 | shift_size (int): Shift size for SW-MSA. 172 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 173 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 174 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 175 | drop (float, optional): Dropout rate. Default: 0.0 176 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 177 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 178 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 179 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 180 | """ 181 | 182 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 183 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 184 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 185 | super().__init__() 186 | self.dim = dim 187 | self.input_resolution = input_resolution 188 | self.num_heads = num_heads 189 | self.window_size = window_size 190 | self.shift_size = shift_size 191 | self.mlp_ratio = mlp_ratio 192 | if min(self.input_resolution) <= self.window_size: 193 | # if window size is larger than input resolution, we don't partition windows 194 | self.shift_size = 0 195 | self.window_size = min(self.input_resolution) 196 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 197 | 198 | self.norm1 = norm_layer(dim) 199 | self.attn = WindowAttention( 200 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 201 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 202 | 203 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 204 | self.norm2 = norm_layer(dim) 205 | mlp_hidden_dim = int(dim * mlp_ratio) 206 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 207 | 208 | if self.shift_size > 0: 209 | # calculate attention mask for SW-MSA 210 | H, W = self.input_resolution 211 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 212 | h_slices = (slice(0, -self.window_size), 213 | slice(-self.window_size, -self.shift_size), 214 | slice(-self.shift_size, None)) 215 | w_slices = (slice(0, -self.window_size), 216 | slice(-self.window_size, -self.shift_size), 217 | slice(-self.shift_size, None)) 218 | cnt = 0 219 | for h in h_slices: 220 | for w in w_slices: 221 | img_mask[:, h, w, :] = cnt 222 | cnt += 1 223 | 224 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 225 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 226 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 227 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 228 | else: 229 | attn_mask = None 230 | 231 | self.register_buffer("attn_mask", attn_mask) 232 | 233 | def forward(self, x): 234 | H, W = self.input_resolution 235 | B, L, C = x.shape 236 | assert L == H * W, "input feature has wrong size" 237 | 238 | shortcut = x 239 | x = self.norm1(x) 240 | x = x.view(B, H, W, C) 241 | 242 | # cyclic shift 243 | if self.shift_size > 0: 244 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 245 | else: 246 | shifted_x = x 247 | 248 | # partition windows 249 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 250 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 251 | 252 | # W-MSA/SW-MSA 253 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 254 | 255 | # merge windows 256 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 257 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 258 | 259 | # reverse cyclic shift 260 | if self.shift_size > 0: 261 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 262 | else: 263 | x = shifted_x 264 | x = x.view(B, H * W, C) 265 | 266 | # FFN 267 | x = shortcut + self.drop_path(x) 268 | x = x + self.drop_path(self.mlp(self.norm2(x))) 269 | 270 | return x 271 | 272 | def extra_repr(self) -> str: 273 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 274 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 275 | 276 | def flops(self): 277 | flops = 0 278 | H, W = self.input_resolution 279 | # norm1 280 | flops += self.dim * H * W 281 | # W-MSA/SW-MSA 282 | nW = H * W / self.window_size / self.window_size 283 | flops += nW * self.attn.flops(self.window_size * self.window_size) 284 | # mlp 285 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 286 | # norm2 287 | flops += self.dim * H * W 288 | return flops 289 | 290 | 291 | class PatchMerging(nn.Module): 292 | r""" Patch Merging Layer. 293 | 294 | Args: 295 | input_resolution (tuple[int]): Resolution of input feature. 296 | dim (int): Number of input channels. 297 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 298 | """ 299 | 300 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 301 | super().__init__() 302 | self.input_resolution = input_resolution 303 | self.dim = dim 304 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 305 | self.norm = norm_layer(4 * dim) 306 | 307 | def forward(self, x): 308 | """ 309 | x: B, H*W, C 310 | """ 311 | H, W = self.input_resolution 312 | B, L, C = x.shape 313 | assert L == H * W, "input feature has wrong size" 314 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 315 | 316 | x = x.view(B, H, W, C) 317 | 318 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 319 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 320 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 321 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 322 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 323 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 324 | 325 | x = self.norm(x) 326 | x = self.reduction(x) 327 | 328 | return x 329 | 330 | def extra_repr(self) -> str: 331 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 332 | 333 | def flops(self): 334 | H, W = self.input_resolution 335 | flops = H * W * self.dim 336 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 337 | return flops 338 | 339 | 340 | class BasicLayer(nn.Module): 341 | """ A basic Swin Transformer layer for one stage. 342 | 343 | Args: 344 | dim (int): Number of input channels. 345 | input_resolution (tuple[int]): Input resolution. 346 | depth (int): Number of blocks. 347 | num_heads (int): Number of attention heads. 348 | window_size (int): Local window size. 349 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 350 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 351 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 352 | drop (float, optional): Dropout rate. Default: 0.0 353 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 354 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 355 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 356 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 357 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
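        Example (illustrative, added by the editor):
            >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2,
            ...                    num_heads=3, window_size=7, drop_path=[0.0, 0.1],
            ...                    downsample=PatchMerging)
            >>> layer(torch.randn(1, 56 * 56, 96)).shape   # PatchMerging halves H, W and doubles C
            torch.Size([1, 784, 192])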
358 | """ 359 | 360 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 361 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 362 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 363 | 364 | super().__init__() 365 | self.dim = dim 366 | self.input_resolution = input_resolution 367 | self.depth = depth 368 | self.use_checkpoint = use_checkpoint 369 | 370 | # build blocks 371 | self.blocks = nn.ModuleList([ 372 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 373 | num_heads=num_heads, window_size=window_size, 374 | shift_size=0 if (i % 2 == 0) else window_size // 2, 375 | mlp_ratio=mlp_ratio, 376 | qkv_bias=qkv_bias, qk_scale=qk_scale, 377 | drop=drop, attn_drop=attn_drop, 378 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 379 | norm_layer=norm_layer) 380 | for i in range(depth)]) 381 | 382 | # patch merging layer 383 | if downsample is not None: 384 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 385 | else: 386 | self.downsample = None 387 | 388 | def forward(self, x): 389 | for blk in self.blocks: 390 | if self.use_checkpoint: 391 | x = checkpoint.checkpoint(blk, x) 392 | else: 393 | x = blk(x) 394 | if self.downsample is not None: 395 | x = self.downsample(x) 396 | return x 397 | 398 | def extra_repr(self) -> str: 399 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 400 | 401 | def flops(self): 402 | flops = 0 403 | for blk in self.blocks: 404 | flops += blk.flops() 405 | if self.downsample is not None: 406 | flops += self.downsample.flops() 407 | return flops 408 | 409 | 410 | class PatchEmbed(nn.Module): 411 | r""" Image to Patch Embedding 412 | 413 | Args: 414 | img_size (int): Image size. Default: 224. 415 | patch_size (int): Patch token size. Default: 4. 416 | in_chans (int): Number of input image channels. Default: 3. 417 | embed_dim (int): Number of linear projection output channels. Default: 96. 418 | norm_layer (nn.Module, optional): Normalization layer. Default: None 419 | """ 420 | 421 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 422 | super().__init__() 423 | img_size = to_2tuple(img_size) 424 | patch_size = to_2tuple(patch_size) 425 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 426 | self.img_size = img_size 427 | self.patch_size = patch_size 428 | self.patches_resolution = patches_resolution 429 | self.num_patches = patches_resolution[0] * patches_resolution[1] 430 | 431 | self.in_chans = in_chans 432 | self.embed_dim = embed_dim 433 | 434 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 435 | if norm_layer is not None: 436 | self.norm = norm_layer(embed_dim) 437 | else: 438 | self.norm = None 439 | 440 | def forward(self, x): 441 | B, C, H, W = x.shape 442 | # FIXME look at relaxing size constraints 443 | assert H == self.img_size[0] and W == self.img_size[1], \ 444 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
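        # Editor's note: the Conv2d projection below turns the image into non-overlapping
        # patch tokens; with the default 4x4 patches a 224x224 input yields 56*56 = 3136
        # tokens of dimension embed_dim.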
445 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 446 | if self.norm is not None: 447 | x = self.norm(x) 448 | return x 449 | 450 | def flops(self): 451 | Ho, Wo = self.patches_resolution 452 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 453 | if self.norm is not None: 454 | flops += Ho * Wo * self.embed_dim 455 | return flops 456 | 457 | 458 | class SwinTransformer(nn.Module): 459 | r""" Swin Transformer 460 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 461 | https://arxiv.org/pdf/2103.14030 462 | 463 | Args: 464 | img_size (int | tuple(int)): Input image size. Default 224 465 | patch_size (int | tuple(int)): Patch size. Default: 4 466 | in_chans (int): Number of input image channels. Default: 3 467 | num_classes (int): Number of classes for classification head. Default: 1000 468 | embed_dim (int): Patch embedding dimension. Default: 96 469 | depths (tuple(int)): Depth of each Swin Transformer layer. 470 | num_heads (tuple(int)): Number of attention heads in different layers. 471 | window_size (int): Window size. Default: 7 472 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 473 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 474 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 475 | drop_rate (float): Dropout rate. Default: 0 476 | attn_drop_rate (float): Attention dropout rate. Default: 0 477 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 478 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 479 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 480 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 481 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 482 | """ 483 | 484 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 485 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 486 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 487 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 488 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 489 | use_checkpoint=False, **kwargs): 490 | super().__init__() 491 | 492 | self.num_classes = num_classes 493 | self.num_layers = len(depths) 494 | self.embed_dim = embed_dim 495 | self.ape = ape 496 | self.patch_norm = patch_norm 497 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 498 | self.mlp_ratio = mlp_ratio 499 | 500 | # split image into non-overlapping patches 501 | self.patch_embed = PatchEmbed( 502 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 503 | norm_layer=norm_layer if self.patch_norm else None) 504 | num_patches = self.patch_embed.num_patches 505 | patches_resolution = self.patch_embed.patches_resolution 506 | self.patches_resolution = patches_resolution 507 | 508 | # absolute position embedding 509 | if self.ape: 510 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 511 | trunc_normal_(self.absolute_pos_embed, std=.02) 512 | 513 | self.pos_drop = nn.Dropout(p=drop_rate) 514 | 515 | # stochastic depth 516 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 517 | 518 | # build layers 519 | self.layers = nn.ModuleList() 520 | for i_layer in range(self.num_layers): 521 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 522 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 523 | patches_resolution[1] // (2 ** i_layer)), 524 | depth=depths[i_layer], 525 | num_heads=num_heads[i_layer], 526 | window_size=window_size, 527 | mlp_ratio=self.mlp_ratio, 528 | qkv_bias=qkv_bias, qk_scale=qk_scale, 529 | drop=drop_rate, attn_drop=attn_drop_rate, 530 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 531 | norm_layer=norm_layer, 532 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 533 | use_checkpoint=use_checkpoint) 534 | self.layers.append(layer) 535 | 536 | self.norm = norm_layer(self.num_features) 537 | self.avgpool = nn.AdaptiveAvgPool1d(1) 538 | # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 539 | 540 | self.apply(self._init_weights) 541 | 542 | def _init_weights(self, m): 543 | if isinstance(m, nn.Linear): 544 | trunc_normal_(m.weight, std=.02) 545 | if isinstance(m, nn.Linear) and m.bias is not None: 546 | nn.init.constant_(m.bias, 0) 547 | elif isinstance(m, nn.LayerNorm): 548 | nn.init.constant_(m.bias, 0) 549 | nn.init.constant_(m.weight, 1.0) 550 | 551 | @torch.jit.ignore 552 | def no_weight_decay(self): 553 | return {'absolute_pos_embed'} 554 | 555 | @torch.jit.ignore 556 | def no_weight_decay_keywords(self): 557 | return {'relative_position_bias_table'} 558 | 559 | def forward_features(self, x): 560 | x = self.patch_embed(x) 561 | if self.ape: 562 | x = x + self.absolute_pos_embed 563 | x = self.pos_drop(x) 564 | 565 | for layer in self.layers: 566 | x = layer(x) 567 | 568 | x = self.norm(x) # B L C 569 | x = self.avgpool(x.transpose(1, 2)) # B C 1 570 | x = torch.flatten(x, 1) 571 | return x 572 | 573 | def forward(self, x): 574 | x = self.forward_features(x) 575 | # x = self.head(x) 576 | return x 577 | 578 | def flops(self): 579 | flops = 0 580 | flops += 
self.patch_embed.flops() 581 | for i, layer in enumerate(self.layers): 582 | flops += layer.flops() 583 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 584 | flops += self.num_features * self.num_classes 585 | return flops 586 | -------------------------------------------------------------------------------- /fire/runner.py: -------------------------------------------------------------------------------- 1 | import time 2 | import gc 3 | import os 4 | import datetime 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import cv2 9 | 10 | import torch.nn.functional as F 11 | 12 | from fire.runnertools import getSchedu, getOptimizer, getLossFunc 13 | from fire.runnertools import clipGradient 14 | from fire.metrics import getF1 15 | from fire.scheduler import GradualWarmupScheduler 16 | from fire.utils import printDash,firelog,delete_all_pycache_folders 17 | 18 | 19 | 20 | 21 | class FireRunner(): 22 | def __init__(self, cfg, model): 23 | 24 | self.cfg = cfg 25 | 26 | 27 | if self.cfg['GPU_ID'] != '' : 28 | self.device = torch.device("cuda") 29 | else: 30 | self.device = torch.device("cpu") 31 | 32 | self.model = model.to(self.device) 33 | 34 | 35 | self.scaler = torch.cuda.amp.GradScaler() 36 | ############################################################ 37 | 38 | 39 | # loss 40 | self.loss_func = getLossFunc(self.device, cfg) 41 | 42 | 43 | 44 | # optimizer 45 | self.optimizer = getOptimizer(self.cfg['optimizer'], 46 | self.model, 47 | self.cfg['learning_rate'], 48 | self.cfg['weight_decay']) 49 | 50 | 51 | # scheduler 52 | self.scheduler = getSchedu(self.cfg['scheduler'], self.optimizer) 53 | 54 | if self.cfg['warmup_epoch']: 55 | self.scheduler = GradualWarmupScheduler(optimizer, 56 | multiplier=1, 57 | total_epoch=self.cfg['warmup_epoch'], 58 | after_scheduler=self.scheduler) 59 | 60 | if self.cfg['show_heatmap']: 61 | self.extractor = ModelOutputs(self.model, self.model.features[12], ['0']) 62 | 63 | 64 | def freezeBeforeLinear(self, epoch, freeze_epochs = 2): 65 | if epoch0: 226 | new_id = max(exp_nums)+1 227 | exp_dir = os.path.join(self.cfg['save_dir'], 'exp'+str(new_id)) 228 | 229 | firelog("i", "save to %s" % exp_dir) 230 | #if not os.path.exists(exp_dir): 231 | os.makedirs(exp_dir) 232 | # os.system("cp -r fire %s/" % exp_dir) 233 | # delete_all_pycache_folders(exp_dir) 234 | # os.system("cp config.py %s/" % exp_dir) 235 | return exp_dir 236 | 237 | ################ 238 | 239 | def onTrainStart(self): 240 | 241 | self.last_best_value = 0 242 | self.last_best_dist = 0 243 | self.last_save_path = None 244 | 245 | self.earlystop = False 246 | self.best_epoch = 0 247 | 248 | # log 249 | self.log_time = time.strftime('%Y-%m-%d_%H-%M-%S',time.localtime(time.time())) 250 | 251 | self.exp_dir = self.make_save_dir() 252 | 253 | def onTrainStep(self,train_loader, epoch): 254 | 255 | self.model.train() 256 | correct = 0 257 | count = 0 258 | batch_time = 0 259 | total_loss = 0 260 | for batch_idx, (data, target, img_names) in enumerate(train_loader): 261 | 262 | one_batch_time_start = time.time() 263 | 264 | target = target.to(self.device) 265 | 266 | data = data.to(self.device) 267 | 268 | with torch.cuda.amp.autocast(): 269 | output = self.model(data) 270 | #all_linear2_params = torch.cat([x.view(-1) for x in model.model_feature._fc.parameters()]) 271 | #l2_regularization = 0.0003 * torch.norm(all_linear2_params, 2) 272 | loss = self.loss_func(output[0], target, 
self.cfg['sample_weights'],sample_weight_img_names=img_names)# + l2_regularization.item() 273 | 274 | 275 | total_loss += loss.item() 276 | if self.cfg['clip_gradient']: 277 | clipGradient(self.optimizer, self.cfg['clip_gradient']) 278 | 279 | 280 | 281 | self.optimizer.zero_grad()#把梯度置零 282 | # loss.backward() #计算梯度 283 | # self.optimizer.step() #更新参数 284 | self.scaler.scale(loss).backward() 285 | self.scaler.step(self.optimizer) 286 | self.scaler.update() 287 | 288 | ### train acc 289 | pred_score = nn.Softmax(dim=1)(output[0]) 290 | pred = output[0].max(1, keepdim=True)[1] # get the index of the max log-probability 291 | if len(target.shape)>1: 292 | target = target.max(1, keepdim=True)[1] 293 | correct += pred.eq(target.view_as(pred)).sum().item() 294 | count += len(data) 295 | 296 | train_acc = correct / count 297 | train_loss = total_loss/count 298 | #print(train_acc) 299 | one_batch_time = time.time() - one_batch_time_start 300 | batch_time+=one_batch_time 301 | # print(batch_time/(batch_idx+1), len(train_loader), batch_idx, 302 | # int(one_batch_time*(len(train_loader)-batch_idx))) 303 | eta = int((batch_time/(batch_idx+1))*(len(train_loader)-batch_idx-1)) 304 | 305 | 306 | print_epoch = ''.join([' ']*(4-len(str(epoch+1))))+str(epoch+1) 307 | print_epoch_total = str(self.cfg['epochs'])+''.join([' ']*(4-len(str(self.cfg['epochs'])))) 308 | 309 | log_interval = 10 310 | if batch_idx % log_interval== 0: 311 | print('\r', 312 | '{}/{} [{}/{} ({:.0f}%)] - ETA: {}, loss: {:.4f}, acc: {:.4f} LR: {:f}'.format( 313 | print_epoch, print_epoch_total, batch_idx * len(data), len(train_loader.dataset), 314 | 100. * batch_idx / len(train_loader), 315 | datetime.timedelta(seconds=eta), 316 | train_loss,train_acc, 317 | self.optimizer.param_groups[0]["lr"]), 318 | end="",flush=True) 319 | 320 | 321 | 322 | 323 | def onTrainEnd(self): 324 | save_name = 'last.pt' 325 | self.last_save_path = os.path.join(self.exp_dir, save_name) 326 | self.modelSave(self.last_save_path) 327 | 328 | del self.model 329 | gc.collect() 330 | torch.cuda.empty_cache() 331 | 332 | if self.cfg["cfg_verbose"]: 333 | printDash() 334 | print(self.cfg) 335 | printDash() 336 | 337 | 338 | def onValidation(self, val_loader, epoch): 339 | 340 | self.model.eval() 341 | self.val_loss = 0 342 | self.correct = 0 343 | 344 | 345 | with torch.no_grad(): 346 | pres = [] 347 | labels = [] 348 | for (data, target, img_names) in val_loader: 349 | data, target = data.to(self.device), target.to(self.device) 350 | 351 | with torch.cuda.amp.autocast(): 352 | output = self.model(data) 353 | self.val_loss += self.loss_func(output[0], target).item() # sum up batch loss 354 | 355 | #print(output.shape) 356 | pred_score = nn.Softmax(dim=1)(output[0]) 357 | #print(pred_score.shape) 358 | pred = output[0].max(1, keepdim=True)[1] # get the index of the max log-probability 359 | if self.cfg['use_distill']: 360 | target = target.max(1, keepdim=True)[1] 361 | self.correct += pred.eq(target.view_as(pred)).sum().item() 362 | 363 | 364 | batch_pred_score = pred_score.data.cpu().numpy().tolist() 365 | batch_label_score = target.data.cpu().numpy().tolist() 366 | pres.extend(batch_pred_score) 367 | labels.extend(batch_label_score) 368 | 369 | #print('\n',output[0],img_names[0]) 370 | pres = np.array(pres) 371 | labels = np.array(labels) 372 | #print(pres.shape, labels.shape) 373 | 374 | self.val_loss /= len(val_loader.dataset) 375 | self.val_acc = self.correct / len(val_loader.dataset) 376 | self.best_score = self.val_acc 377 | 378 | if 'F1' in 
self.cfg['metrics']: 379 | #print(labels) 380 | precision, recall, f1_score = getF1(pres, labels) 381 | print(' \n [VAL] loss: {:.5f}, acc: {:.3f}%, precision: {:.5f}, recall: {:.5f}, f1_score: {:.5f}\n'.format( 382 | self.val_loss, 100. * self.val_acc, precision, recall, f1_score)) 383 | self.best_score = f1_score 384 | 385 | else: 386 | print(' \n [VAL] loss: {:.5f}, acc: {:.3f}% \n'.format( 387 | self.val_loss, 100. * self.val_acc)) 388 | 389 | 390 | if self.cfg['warmup_epoch']: 391 | self.scheduler.step(epoch) 392 | else: 393 | if 'default' in self.cfg['scheduler']: 394 | self.scheduler.step(self.best_score) 395 | else: 396 | self.scheduler.step() 397 | 398 | 399 | self.checkpoint(epoch) 400 | self.earlyStop(epoch) 401 | 402 | 403 | 404 | 405 | def onTest(self, test_loader): 406 | self.model.eval() 407 | 408 | #predict 409 | res_list = [] 410 | with torch.no_grad(): 411 | #end = time.time() 412 | for i, (inputs, target, img_names) in enumerate(test_loader): 413 | print("\r",str(i)+"/"+str(test_loader.__len__()),end="",flush=True) 414 | 415 | inputs = inputs.to(self.device) 416 | 417 | output = self.model(inputs) 418 | output = output.data.cpu().numpy() 419 | 420 | for j in range(output.shape[0]): 421 | 422 | output_one = output[j][np.newaxis, :] 423 | output_one = np.argmax(output_one) 424 | 425 | res_list.append(output_one) 426 | return res_list 427 | 428 | 429 | 430 | def earlyStop(self, epoch): 431 | ### earlystop 432 | if self.best_score>self.last_best_value: 433 | self.last_best_value = self.best_score 434 | self.last_best_dist = 0 435 | 436 | self.last_best_dist+=1 437 | if self.last_best_dist>self.cfg['early_stop_patient']: 438 | self.best_epoch = epoch-self.cfg['early_stop_patient']+1 439 | print("[INFO] Early Stop with patience %d , best is Epoch - %d :%f" % (self.cfg['early_stop_patient'],self.best_epoch,self.last_best_value)) 440 | self.earlystop = True 441 | if epoch+1==self.cfg['epochs']: 442 | self.best_epoch = epoch-self.last_best_dist+2 443 | print("[INFO] Finish training , best is Epoch - %d :%f" % (self.best_epoch,self.last_best_value)) 444 | self.earlystop = True 445 | 446 | def checkpoint(self, epoch): 447 | 448 | if self.best_score<=self.last_best_value: 449 | pass 450 | else: 451 | save_name = 'best.pt' 452 | self.last_save_path = os.path.join(self.exp_dir, save_name) 453 | self.modelSave(self.last_save_path) 454 | 455 | 456 | 457 | 458 | def modelLoad(self,model_path, data_parallel = False): 459 | self.model.load_state_dict(torch.load(model_path), strict=True) 460 | 461 | if data_parallel: 462 | self.model = torch.nn.DataParallel(self.model) 463 | 464 | def modelSave(self, save_name): 465 | torch.save(self.model.state_dict(), save_name) 466 | 467 | def toOnnx(self, save_name= "model.onnx"): 468 | dummy_input = torch.randn(1, 3, self.cfg['img_size'][0], self.cfg['img_size'][1]).to(self.device) 469 | 470 | torch.onnx.export(self.model, 471 | dummy_input, 472 | os.path.join(self.cfg['save_dir'],save_name), 473 | verbose=True) 474 | 475 | 476 | -------------------------------------------------------------------------------- /fire/runnertools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.optim as optim 5 | 6 | from fire.loss import FocalLoss, CrossEntropyLoss,CrossEntropyLossV2 7 | from adabelief_pytorch import AdaBelief # assumed source of the AdaBelief optimizer used below (pip install adabelief-pytorch) 8 | 9 | def getSchedu(schedu, optimizer): 10 | if 'default' in schedu: 11 | factor = float(schedu.strip().split('-')[1]) 12 | patience = int(schedu.strip().split('-')[2]) 13 | scheduler =
optim.lr_scheduler.ReduceLROnPlateau(optimizer, 14 | mode='max', factor=factor, patience=patience,min_lr=0.000001) 15 | elif 'step' in schedu: 16 | step_size = int(schedu.strip().split('-')[1]) 17 | gamma = float(schedu.strip().split('-')[2]) 18 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma, last_epoch=-1) 19 | elif 'SGDR' in schedu: 20 | T_0 = int(schedu.strip().split('-')[1]) 21 | T_mult = int(schedu.strip().split('-')[2]) 22 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 23 | T_0=T_0, 24 | T_mult=T_mult) 25 | elif 'multi' in schedu: 26 | milestones = [int(x) for x in schedu.strip().split('-')[1].split(',')] 27 | gamma = float(schedu.strip().split('-')[2]) 28 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=gamma, last_epoch=-1) 29 | else: 30 | raise Exception("Unknown scheduler: ", schedu) 31 | return scheduler 32 | 33 | 34 | def getOptimizer(optims, model, learning_rate, weight_decay): 35 | if optims=='Adam': 36 | optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) 37 | elif optims=='AdamW': 38 | optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) 39 | elif optims=='SGD': 40 | optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay) 41 | elif optims=='AdaBelief': 42 | optimizer = AdaBelief(model.parameters(), lr=learning_rate, eps=1e-12, betas=(0.9,0.999)) 43 | # elif optims=='Ranger': 44 | # optimizer = Ranger(model.parameters(), lr=learning_rate, weight_decay=weight_decay) 45 | else: 46 | raise Exception("Unknown optimizer: ", optims) 47 | return optimizer 48 | 49 | 50 | def getLossFunc(device, cfg): 51 | # loss 52 | 53 | if cfg['class_weight']: 54 | class_weight = torch.DoubleTensor(cfg['class_weight']).to(device) 55 | else: 56 | class_weight = None 57 | 58 | if 'Focalloss' in cfg['loss']: 59 | gamma = float(cfg['loss'].strip().split('-')[1]) 60 | loss_func = FocalLoss(label_smooth=cfg['label_smooth'], 61 | gamma=gamma, 62 | class_weight=class_weight).to(device) 63 | 64 | elif 'CEV2' in cfg['loss']: 65 | gamma = float(cfg['loss'].strip().split('-')[1]) # parsed for symmetry; not passed to CrossEntropyLossV2 below 66 | loss_func = CrossEntropyLossV2(label_smooth=cfg['label_smooth'], 67 | class_weight=class_weight).to(device) 68 | else: 69 | ### origin CE 70 | # loss_func = torch.nn.CrossEntropyLoss(weight=class_weight).to(device) 71 | loss_func = CrossEntropyLoss(label_smooth=cfg['label_smooth'], 72 | class_weight=class_weight).to(device) 73 | #self.loss_func = CrossEntropyLossOneHot().to(self.device) 74 | 75 | return loss_func 76 | 77 | 78 | 79 | 80 | ############### Tools 81 | 82 | def clipGradient(optimizer, grad_clip=1): 83 | """ 84 | Clips gradients computed during backpropagation to avoid explosion of gradients. 85 | 86 | :param optimizer: optimizer with the gradients to be clipped 87 | :param grad_clip: clip value 88 | """ 89 | for group in optimizer.param_groups: 90 | for param in group["params"]: 91 | if param.grad is not None: 92 | param.grad.data.clamp_(-grad_clip, grad_clip) 93 | 94 | 95 | -------------------------------------------------------------------------------- /fire/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler,ReduceLROnPlateau 2 | 3 | 4 | class GradualWarmupScheduler(_LRScheduler): 5 | """ Gradually warm-up(increasing) learning rate in optimizer.
6 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 7 | 8 | Args: 9 | optimizer (Optimizer): Wrapped optimizer. 10 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 11 | total_epoch: target learning rate is reached at total_epoch, gradually 12 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 13 | """ 14 | 15 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 16 | self.multiplier = multiplier 17 | if self.multiplier < 1.: 18 | raise ValueError('multiplier should be greater than or equal to 1.') 19 | self.total_epoch = total_epoch 20 | self.after_scheduler = after_scheduler 21 | self.finished = False 22 | super(GradualWarmupScheduler, self).__init__(optimizer) 23 | 24 | def get_lr(self): 25 | if self.last_epoch > self.total_epoch: 26 | if self.after_scheduler: 27 | if not self.finished: 28 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 29 | self.finished = True 30 | #return self.after_scheduler.get_last_lr() 31 | return [group['lr'] for group in self.optimizer.param_groups] 32 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 33 | 34 | if self.multiplier == 1.0: 35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 36 | else: 37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 38 | 39 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 40 | if epoch is None: 41 | epoch = self.last_epoch + 1 42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 43 | if self.last_epoch <= self.total_epoch: 44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 46 | param_group['lr'] = lr 47 | else: 48 | if epoch is None: 49 | self.after_scheduler.step(metrics, None) 50 | else: 51 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 52 | 53 | def step(self, epoch=None, metrics=None): 54 | if type(self.after_scheduler) != ReduceLROnPlateau: 55 | if self.finished and self.after_scheduler: 56 | if epoch is None: 57 | self.after_scheduler.step(None) 58 | else: 59 | self.after_scheduler.step(epoch - self.total_epoch) 60 | #self._last_lr = self.after_scheduler.get_last_lr() 61 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups] 62 | else: 63 | return super(GradualWarmupScheduler, self).step(epoch) 64 | else: 65 | self.step_ReduceLROnPlateau(metrics, epoch) 66 | -------------------------------------------------------------------------------- /fire/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import random 5 | import numpy as np 6 | 7 | 8 | VERSION = "1.1" 9 | 10 | def setRandomSeed(seed=42): 11 | """Reproducer for pytorch experiment. 12 | 13 | Parameters 14 | ---------- 15 | seed: int, optional (default = 2019) 16 | Radnom seed. 17 | 18 | Example 19 | ------- 20 | setRandomSeed(seed=2019). 
21 | """ 22 | random.seed(seed) 23 | os.environ["PYTHONHASHSEED"] = str(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | if torch.cuda.is_available(): 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | torch.backends.cudnn.enabled = True 32 | 33 | 34 | def printDash(num = 50): 35 | print(''.join(['-']*num)) 36 | 37 | 38 | def initFire(cfg): 39 | 40 | if cfg["cfg_verbose"]: 41 | printDash() 42 | print(cfg) 43 | printDash() 44 | 45 | print("[INFO] Fire verison: "+VERSION) 46 | 47 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg['GPU_ID'] 48 | setRandomSeed(cfg['random_seed']) 49 | 50 | if not os.path.exists(cfg['save_dir']): 51 | os.makedirs(cfg['save_dir']) 52 | 53 | 54 | 55 | def npSoftmax(x): 56 | x_row_max = x.max(axis=-1) 57 | x_row_max = x_row_max.reshape(list(x.shape)[:-1]+[1]) 58 | x = x - x_row_max 59 | x_exp = np.exp(x) 60 | x_exp_row_sum = x_exp.sum(axis=-1).reshape(list(x.shape)[:-1]+[1]) 61 | softmax = x_exp / x_exp_row_sum 62 | return softmax 63 | 64 | 65 | def firelog(mode='i',text=''): 66 | if mode=='i': 67 | print("[INFO] ",text) 68 | 69 | 70 | def delete_all_pycache_folders(dir_path): 71 | for dirpath, dirnames, filenames in os.walk(dir_path): 72 | for dirname in dirnames: 73 | if dirname == "__pycache__": 74 | #os.rmdir(os.path.join(dirpath, dirname)) 75 | os.system("rm -rf %s" % os.path.join(dirpath, dirname)) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os,argparse 2 | import random 3 | 4 | from fire import initFire, FireModel, FireRunner, FireData 5 | 6 | from config import cfg 7 | import pandas as pd 8 | 9 | 10 | def softmax(x): 11 | x_row_max = x.max(axis=-1) 12 | x_row_max = x_row_max.reshape(list(x.shape)[:-1]+[1]) 13 | x = x - x_row_max 14 | x_exp = np.exp(x) 15 | x_exp_row_sum = x_exp.sum(axis=-1).reshape(list(x.shape)[:-1]+[1]) 16 | softmax = x_exp / x_exp_row_sum 17 | return softmax 18 | 19 | 20 | 21 | 22 | def predict(cfg): 23 | 24 | initFire(cfg) 25 | 26 | 27 | model = FireModel(cfg) 28 | 29 | 30 | 31 | data = FireData(cfg) 32 | # data.showTrainData() 33 | # b 34 | 35 | test_loader = data.getTestDataloader() 36 | 37 | 38 | runner = FireRunner(cfg, model) 39 | 40 | #print(model) 41 | runner.modelLoad(cfg['model_path']) 42 | 43 | 44 | res_dict = runner.predict(test_loader) 45 | print(len(res_dict)) 46 | 47 | # to csv 48 | res_df = pd.DataFrame.from_dict(res_dict, orient='index', columns=['label']) 49 | res_df = res_df.reset_index().rename(columns={'index':'image_id'}) 50 | res_df.to_csv(os.path.join(cfg['save_dir'], 'pre.csv'), 51 | index=False,header=True) 52 | 53 | 54 | def predictMerge(cfg): 55 | initFire(cfg) 56 | 57 | 58 | model = FireModel(cfg) 59 | 60 | 61 | 62 | data = FireData(cfg) 63 | # data.showTrainData() 64 | # b 65 | 66 | test_loader = data.getTestDataloader() 67 | runner1 = FireRunner(cfg, model) 68 | runner1.modelLoad('output/efficientnet-b6_e17_fold0_0.93368.pth') 69 | print("load model1, start running.") 70 | res_dict1 = runner1.predictRaw(test_loader) 71 | print(len(res_dict1)) 72 | 73 | test_loader = data.getTestDataloader() 74 | runner2 = FireRunner(cfg, model) 75 | runner2.modelLoad('output/efficientnet-b6_e18_fold1_0.94537.pth') 76 | print("load model2, start running.") 77 | res_dict2 = runner2.predictRaw(test_loader) 78 | 79 | test_loader = data.getTestDataloader() 80 | 
runner3 = FireRunner(cfg, model) 81 | runner3.modelLoad('output/efficientnet-b6_e14_fold2_0.91967.pth') 82 | print("load model3, start running.") 83 | res_dict3 = runner3.predictRaw(test_loader) 84 | 85 | test_loader = data.getTestDataloader() 86 | runner4 = FireRunner(cfg, model) 87 | runner4.modelLoad('output/efficientnet-b6_e18_fold3_0.92239.pth') 88 | print("load model4, start running.") 89 | res_dict4 = runner4.predictRaw(test_loader) 90 | 91 | # test_loader = data.getTestDataloader() 92 | # runner5 = FireRunner(cfg, model) 93 | # runner5.modelLoad('output/efficientnet-b6_e17_fold0_0.93368.pth') 94 | # print("load model5, start running.") 95 | # res_dict5 = runner5.predictRaw(test_loader) 96 | 97 | 98 | res_dict = {} 99 | for k,v in res_dict1.items(): 100 | #print(k,v) 101 | v1 =np.argmax(v+res_dict2[k]+res_dict3[k]+res_dict4[k]) 102 | res_dict[k] = v1 103 | 104 | res_list = sorted(res_dict.items(), key = lambda kv: int(kv[0].split("_")[-1].split('.')[0])) 105 | print(len(res_list), res_list[0]) 106 | 107 | # to csv 108 | # res_list_final = [] 109 | # for res in res_list: 110 | # res_list_final.append([res[0]]+res[1]) 111 | # #res_df = pd.DataFrame.from_dict(res_dict, orient='index', columns=['type']) 112 | # #res_df = res_df.reset_index().rename(columns={'index':'id'}) 113 | # res_df = DataFrame(res_list_final, columns=['id','type','color','toward']) 114 | 115 | 116 | # res_df.to_csv(os.path.join(cfg['save_dir'], 'result.csv'), 117 | # index=False,header=True) 118 | 119 | with open('result.csv', 'w', encoding='utf-8') as f: 120 | f.write('file,label\n') 121 | for i in range(len(res_list)): 122 | line = [res_list[i][0], str(res_list[i][1])] 123 | line = ','.join(line) 124 | f.write(line+"\n") 125 | 126 | 127 | def predictTTA(cfg): 128 | 129 | pass 130 | 131 | 132 | def predictMergeTTA(cfg): 133 | 134 | pass 135 | 136 | 137 | def main(cfg): 138 | 139 | if cfg["merge"]: 140 | if cfg["TTA"]: 141 | predictMergeTTA(cfg) 142 | else: 143 | predictMerge(cfg) 144 | else: 145 | if cfg["TTA"]: 146 | predictTTA(cfg) 147 | else: 148 | predict(cfg) 149 | 150 | 151 | 152 | 153 | 154 | 155 | if __name__ == '__main__': 156 | main(cfg) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | opencv-python 4 | torch 5 | torchvision 6 | scikit-learn 7 | pretrainedmodels 8 | albumentations 9 | timm 10 | yacs -------------------------------------------------------------------------------- /scripts/cleanData.py: -------------------------------------------------------------------------------- 1 | import os,argparse 2 | import random 3 | 4 | from fire import initFire, FireModel, FireRunner, FireData 5 | 6 | from config import cfg 7 | 8 | 9 | 10 | 11 | def main(cfg): 12 | 13 | 14 | initFire(cfg) 15 | 16 | 17 | model = FireModel(cfg) 18 | 19 | 20 | 21 | data = FireData(cfg) 22 | # data.showTrainData() 23 | # b 24 | 25 | train_loader = data.getTrainDataloader() 26 | 27 | 28 | runner = FireRunner(cfg, model) 29 | 30 | runner.modelLoad(cfg['model_path']) 31 | 32 | move_dir = '../data/dataset/d_trainval/v8/tmp' 33 | #"../data/dataset/a_raw_data/3_our/0_bigimg/add/antispoof20201214/true" 34 | #"../data/dataset/d_trainval/v8/tmp" 35 | target_label = 1 36 | runner.cleanData(train_loader, target_label, move_dir) 37 | 38 | 39 | 40 | if __name__ == '__main__': 41 | main(cfg) -------------------------------------------------------------------------------- 
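A note on predict.py above: predictMerge() repeats one nearly identical runner block per checkpoint and calls np.argmax even though numpy is never imported in that file. The loop below is a minimal sketch of the same ensembling idea; unlike predictMerge(), it averages softmax probabilities rather than summing raw outputs, which is a design choice of this sketch, not the project's fixed method. predict_merge_loop and the placeholder checkpoint paths are illustrative assumptions; FireModel/FireRunner/FireData, predictRaw (assumed to return {image_name: raw score vector}) and npSoftmax are used only as they already appear in the files above.

import numpy as np

from fire import initFire, FireModel, FireRunner, FireData
from fire.utils import npSoftmax
from config import cfg


def predict_merge_loop(cfg, weight_paths):
    # Hypothetical helper: one loop instead of one hand-written block per model.
    initFire(cfg)
    model = FireModel(cfg)
    data = FireData(cfg)

    summed = {}
    for path in weight_paths:
        test_loader = data.getTestDataloader()   # fresh loader per checkpoint, as in predictMerge()
        runner = FireRunner(cfg, model)
        runner.modelLoad(path)
        for name, raw in runner.predictRaw(test_loader).items():
            summed[name] = summed.get(name, 0) + npSoftmax(np.array(raw))

    # highest averaged softmax score wins
    return {name: int(np.argmax(score)) for name, score in summed.items()}


if __name__ == '__main__':
    weight_paths = ['output/fold0.pth', 'output/fold1.pth']   # placeholder paths
    print(len(predict_merge_loop(cfg, weight_paths)))
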
/scripts/convert_onnx.py: -------------------------------------------------------------------------------- 1 | import os,argparse 2 | import random 3 | import torch 4 | from fire import initFire, FireModel, FireRunner, FireData 5 | 6 | from config import cfg 7 | import pandas as pd 8 | 9 | 10 | 11 | def main(cfg): 12 | 13 | 14 | initFire(cfg) 15 | 16 | 17 | model = FireModel(cfg) 18 | 19 | 20 | 21 | data = FireData(cfg) 22 | # data.showTrainData() 23 | # b 24 | 25 | test_loader = data.getTestDataloader() 26 | 27 | 28 | runner = FireRunner(cfg, model) 29 | 30 | #print(model) 31 | runner.modelLoad(cfg['model_path']) 32 | 33 | 34 | runner.model.eval() 35 | runner.model.to("cuda") 36 | 37 | #data type nchw 38 | dummy_input1 = torch.randn(1, 3, 224, 224).to("cuda") 39 | input_names = [ "input1"] #自己命名 40 | output_names = [ "output1"] 41 | # torch.onnx.export(model, (dummy_input1, dummy_input2, dummy_input3), "C3AE.onnx", verbose=True, input_names=input_names, output_names=output_names) 42 | torch.onnx.export(model, dummy_input1, "output/model.onnx", 43 | verbose=True, input_names=input_names, output_names=output_names, 44 | do_constant_folding=True) 45 | 46 | 47 | 48 | 49 | if __name__ == '__main__': 50 | main(cfg) -------------------------------------------------------------------------------- /scripts/heatmap.py: -------------------------------------------------------------------------------- 1 | import os,argparse 2 | import random 3 | 4 | from fire import initFire, FireModel, FireRunner, FireData 5 | 6 | from config import cfg 7 | import pandas as pd 8 | 9 | 10 | 11 | def main(cfg): 12 | 13 | cfg['test_batch_size'] = 1 14 | initFire(cfg) 15 | 16 | 17 | model = FireModel(cfg) 18 | # print(model) 19 | # b 20 | 21 | data = FireData(cfg) 22 | # data.showTrainData() 23 | # b 24 | 25 | test_loader = data.getTestDataloader() 26 | 27 | 28 | runner = FireRunner(cfg, model) 29 | 30 | #print(model) 31 | runner.modelLoad(cfg['model_path'], data_parallel = False) 32 | 33 | show_count = 1 34 | runner.heatmap(test_loader, cfg["save_dir"], show_count) 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | if __name__ == '__main__': 43 | main(cfg) -------------------------------------------------------------------------------- /scripts/make_fashionmnist.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import gzip 5 | import numpy as np 6 | import cv2 7 | 8 | 9 | 10 | 11 | data_dir = "./data" 12 | 13 | 14 | train = 'train' 15 | val = 'val' 16 | test = 'test' 17 | 18 | 19 | 20 | def load_mnist(path, kind='train'): 21 | """Load MNIST data from `path`""" 22 | labels_path = os.path.join(path, 23 | '%s-labels-idx1-ubyte.gz' 24 | % kind) 25 | images_path = os.path.join(path, 26 | '%s-images-idx3-ubyte.gz' 27 | % kind) 28 | 29 | with gzip.open(labels_path, 'rb') as lbpath: 30 | labels = np.frombuffer(lbpath.read(), dtype=np.uint8, 31 | offset=8) 32 | 33 | with gzip.open(images_path, 'rb') as imgpath: 34 | images = np.frombuffer(imgpath.read(), dtype=np.uint8, 35 | offset=16).reshape(len(labels), 784) 36 | 37 | return images, labels 38 | 39 | 40 | def save_imgs(data, label, save_dir='train'): 41 | print(data.shape, label.shape)#(60000, 784) (60000,) 42 | data_save_dir = os.path.join(data_dir, save_dir) 43 | if not os.path.exists(data_save_dir): 44 | os.mkdir(data_save_dir) 45 | 46 | #make dir for each categray 47 | for i in range(10): 48 | cate_dir = os.path.join(data_save_dir, str(i)) 49 | if not os.path.exists(cate_dir): 50 | os.mkdir(cate_dir) 
51 | 52 | for i in range(len(data)): 53 | img = np.reshape(data[i], (28,28)) 54 | img = cv2.resize(img, (224,224)) 55 | gt = label[i] 56 | img_path = os.path.join(data_save_dir, str(gt), str(i)+".jpg") 57 | cv2.imwrite(img_path, img) 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | if __name__ == '__main__': 67 | 68 | X_train, y_train = load_mnist(data_dir, kind='train') 69 | X_test, y_test = load_mnist(data_dir, kind='t10k') 70 | 71 | X_val, y_val = X_train[59000:], y_train[59000:] 72 | X_train, y_train = X_train[:59000], y_train[:59000] 73 | 74 | save_imgs(X_train, y_train, save_dir=train) 75 | save_imgs(X_val, y_val, save_dir=val) 76 | save_imgs(X_test, y_test, save_dir=test) 77 | -------------------------------------------------------------------------------- /scripts/nohup_train.sh: -------------------------------------------------------------------------------- 1 | nohup python -u train.py > nohup.log 2>&1 & -------------------------------------------------------------------------------- /scripts/predictTTA.py: -------------------------------------------------------------------------------- 1 | import os,argparse 2 | import random 3 | 4 | from fire import initFire, FireModel, FireRunner, FireData 5 | 6 | from config import cfg 7 | import pandas as pd 8 | from pandas.core.frame import DataFrame 9 | import numpy as np 10 | 11 | 12 | CATES1 = ['LongSleeve', 'ShortSleeve', 'NoSleeve'] 13 | CATES2 = ['Solidcolor', 'multicolour', 'lattice'] 14 | CATES3 = ['Short', 'Long', 'middle', 'Bald'] 15 | CATES4 = ['Skirt', 'Trousers', 'Shorts'] 16 | CATES5 = ['Solidcolor', 'multicolour', 'lattice'] 17 | CATES6 = ['Sandals', 'Sneaker', 'LeatherShoes', 'else'] 18 | CATES7 = ['left', 'right', 'back', 'front'] 19 | 20 | 21 | def flipRes(d): 22 | for k,v in d.items(): 23 | # print(d[k][2]) 24 | d[k][6][0],d[k][6][1] = d[k][6][1],d[k][6][0] 25 | # print(d[k][2]) 26 | # b 27 | return d 28 | 29 | 30 | def _colorDecode(color): 31 | for j in range(len(color)): 32 | new_c = int(color[j]*10+0.5) 33 | color[j]=new_c 34 | if sum(color)<10: 35 | #print(color[np.argmax(color)],sum(color),color) 36 | color[np.argmax(color)] += 10-int(sum(color)) 37 | #print(color[np.argmax(color)]) 38 | 39 | color /=10 40 | return color 41 | 42 | 43 | def softmax(x): 44 | x_row_max = x.max(axis=-1) 45 | x_row_max = x_row_max.reshape(list(x.shape)[:-1]+[1]) 46 | x = x - x_row_max 47 | x_exp = np.exp(x) 48 | x_exp_row_sum = x_exp.sum(axis=-1).reshape(list(x.shape)[:-1]+[1]) 49 | softmax = x_exp / x_exp_row_sum 50 | return softmax 51 | 52 | 53 | def main(cfg): 54 | 55 | 56 | initFire(cfg) 57 | 58 | 59 | model = FireModel(cfg) 60 | data = FireData(cfg) 61 | # data.showTrainData() 62 | # b 63 | 64 | test_loader = data.getTestDataloader() 65 | test_loader2 = data.getTestDataloaderFlip() 66 | runner = FireRunner(cfg, model) 67 | runner.modelLoad('output/convnext_large_22k_1k_384_e5_fold0_0.75868.pth') 68 | print("load model1, start running.") 69 | model1_res_dict1 = runner.predictRaw(test_loader) 70 | model1_res_dict2 = runner.predictRaw(test_loader2) 71 | model1_res_dict2 = flipRes(model1_res_dict2) 72 | 73 | test_loader = data.getTestDataloader() 74 | test_loader2 = data.getTestDataloaderFlip() 75 | runner = FireRunner(cfg, model) 76 | runner.modelLoad('output/convnext_large_22k_1k_384_e14_fold0_0.75514.pth') 77 | print("load model2, start running.") 78 | model2_res_dict1 = runner.predictRaw(test_loader) 79 | model2_res_dict2 = runner.predictRaw(test_loader2) 80 | model2_res_dict2 = flipRes(model2_res_dict2) 81 | 82 | test_loader = 
data.getTestDataloader() 83 | test_loader2 = data.getTestDataloaderFlip() 84 | runner = FireRunner(cfg, model) 85 | runner.modelLoad('output/convnext_large_22k_1k_384_e15_fold0_0.75059.pth') 86 | print("load model3, start running.") 87 | model3_res_dict1 = runner.predictRaw(test_loader) 88 | model3_res_dict2 = runner.predictRaw(test_loader2) 89 | model3_res_dict2 = flipRes(model3_res_dict2) 90 | 91 | 92 | 93 | 94 | res_dict = {} 95 | for k,v in model1_res_dict1.items(): 96 | #print(k,v) 97 | v1 = CATES1[np.argmax(v[0]+model1_res_dict2[k][0]+model2_res_dict1[k][0]+model2_res_dict2[k][0]+model3_res_dict1[k][0]+model3_res_dict2[k][0])] 98 | v2 = CATES2[np.argmax(v[1]+model1_res_dict2[k][1]+model2_res_dict1[k][1]+model2_res_dict2[k][1]+model3_res_dict1[k][1]+model3_res_dict2[k][1])] 99 | v3 = CATES3[np.argmax(v[2]+model1_res_dict2[k][2]+model2_res_dict1[k][2]+model2_res_dict2[k][2]+model3_res_dict1[k][2]+model3_res_dict2[k][2])] 100 | v4 = CATES4[np.argmax(v[3]+model1_res_dict2[k][3]+model2_res_dict1[k][3]+model2_res_dict2[k][3]+model3_res_dict1[k][3]+model3_res_dict2[k][3])] 101 | v5 = CATES5[np.argmax(v[4]+model1_res_dict2[k][4]+model2_res_dict1[k][4]+model2_res_dict2[k][4]+model3_res_dict1[k][4]+model3_res_dict2[k][4])] 102 | v6 = CATES6[np.argmax(v[5]+model1_res_dict2[k][5]+model2_res_dict1[k][5]+model2_res_dict2[k][5]+model3_res_dict1[k][5]+model3_res_dict2[k][5])] 103 | v7 = CATES7[np.argmax(v[6]+model1_res_dict2[k][6]+model2_res_dict1[k][6]+model2_res_dict2[k][6]+model3_res_dict1[k][6]+model3_res_dict2[k][6])] 104 | v8 = softmax(v[7]+model1_res_dict2[k][7]+model2_res_dict1[k][7]+model2_res_dict2[k][7]+model3_res_dict1[k][7]+model3_res_dict2[k][7]) 105 | v9 = softmax(v[8]+model1_res_dict2[k][8]+model2_res_dict1[k][8]+model2_res_dict2[k][8]+model3_res_dict1[k][8]+model3_res_dict2[k][8]) 106 | 107 | v8 = _colorDecode(v8) 108 | v9 = _colorDecode(v9) 109 | res_dict[k] = [v1,v2,v3, 110 | v4,v5,v6, 111 | v7,*v8,*v9] 112 | 113 | res_list = sorted(res_dict.items(), key = lambda kv: int(kv[0].split("_")[-1].split('.')[0])) 114 | #print(len(res_list), res_list[0]) 115 | # to csv 116 | with open('result.csv', 'w', encoding='utf-8') as f: 117 | f.write('name,upperLength,clothesStyles,hairStyles,lowerLength,lowerStyles,shoesStyles,towards,upperBlack,upperBrown,upperBlue,upperGreen,upperGray,upperOrange,upperPink,upperPurple,upperRed,upperWhite,upperYellow,lowerBlack,lowerBrown,lowerBlue,lowerGreen,lowerGray,lowerOrange,lowerPink,lowerPurple,lowerRed,lowerWhite,lowerYellow\n') 118 | for i in range(len(res_list)): 119 | line = [res_list[i][0],res_list[i][1][0],res_list[i][1][1],res_list[i][1][2], 120 | res_list[i][1][3],res_list[i][1][4],res_list[i][1][5], 121 | res_list[i][1][6], 122 | ','.join([str(x) if x>0 else '' for x in res_list[i][1][7:18]]), 123 | ','.join([str(x) if x>0 else '' for x in res_list[i][1][18:]])] 124 | line = ','.join(line) 125 | line = line.replace('1.0','1') 126 | f.write(line+"\n") 127 | 128 | 129 | 130 | if __name__ == '__main__': 131 | main(cfg) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os,argparse 2 | import random 3 | 4 | from fire import initFire, FireModel, FireRunner, FireData 5 | 6 | from config import cfg 7 | 8 | 9 | 10 | 11 | def main(cfg): 12 | 13 | initFire(cfg) 14 | 15 | model = FireModel(cfg) 16 | 17 | data = FireData(cfg) 18 | 19 | if cfg['show_data']: 20 | data.showTrainData() 21 | else: 22 | train_loader, val_loader = 
data.getTrainValDataloader() 23 | 24 | runner = FireRunner(cfg, model) 25 | runner.train(train_loader, val_loader) 26 | 27 | 28 | if __name__ == '__main__': 29 | main(cfg) --------------------------------------------------------------------------------
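
For reference, the pieces above are wired together entirely through cfg: FireRunner builds its optimizer, scheduler and loss from dash-separated strings parsed in fire/runnertools.py, and train.py only calls runner.train(). The snippet below is a minimal sketch of a run with a few explicit overrides; the override values are illustrative, only the key names and string formats come from the code shown above.

from fire import initFire, FireModel, FireRunner, FireData
from config import cfg

# Illustrative overrides; string formats follow getSchedu/getOptimizer/getLossFunc in fire/runnertools.py.
cfg['optimizer'] = 'AdamW'        # one of: Adam / AdamW / SGD / AdaBelief
cfg['scheduler'] = 'SGDR-5-2'     # also: 'default-<factor>-<patience>', 'step-<size>-<gamma>', 'multi-10,20-0.1'
cfg['loss'] = 'Focalloss-2'       # 'Focalloss-<gamma>', 'CEV2-<g>'; anything else falls back to CrossEntropyLoss

initFire(cfg)
model = FireModel(cfg)
data = FireData(cfg)
train_loader, val_loader = data.getTrainValDataloader()

runner = FireRunner(cfg, model)
runner.train(train_loader, val_loader)   # best.pt / last.pt are written to an exp* folder under cfg['save_dir']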