├── .gitignore ├── README.md ├── augmentations ├── __init__.py ├── augmentations.py └── autoaugment.py ├── cifar100.py ├── dataset ├── __init__.py └── dataset.py ├── input └── .gitkeep ├── metrics ├── __init__.py └── metrics.py └── model ├── __init__.py ├── efficientnet.py └── swish.py /.gitignore: -------------------------------------------------------------------------------- 1 | ./**/*.pyc 2 | ./**/*.ipynb 3 | ./**/*.pth 4 | ./**/__pycache__ 5 | .idea 6 | .DS_Store 7 | .secret.json 8 | ./input/cifar-100-python 9 | ./input/cifar-100-python.tar.gz -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EfficientNet 2 | PyTorch implementation for: 3 | 4 | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) 5 | 6 | Note: 7 | - I have not yet got the result like a paper. 8 | - And I publish the experimental results of CIFAR-100, so please take a look [here](https://www.comet.ml/katsura-jp/efficientnet) . 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /augmentations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katsura-jp/efficientnet-pytorch/df0506c0b23e920b69f74b793f8ba1ddc1df54de/augmentations/__init__.py -------------------------------------------------------------------------------- /augmentations/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | import albumentations as albu 5 | from albumentations.pytorch import ToTensor 6 | 7 | class TrainAugment(object): 8 | def __init__(self): 9 | self.Compose = albu.Compose([ 10 | albu.PadIfNeeded(min_height=40, min_width=40, border_mode=0, value=[0,0,0], always_apply=True), 11 | albu.Cutout(num_holes=3, max_h_size=4, max_w_size=4, p=0.5), 12 | albu.HorizontalFlip(p=0.5), 13 | albu.RandomCrop(height=32, width=32, always_apply=True), 14 | albu.ToFloat(max_value=None, always_apply=True), 15 | ToTensor(normalize={'mean': [0.5071, 0.4867, 0.4408], 'std': [0.2675, 0.2565, 0.2761]}) 16 | ]) 17 | def __call__(self, image): 18 | transformed = self.Compose(image=image) 19 | image = transformed['image'] 20 | return image 21 | 22 | 23 | class TestAugment(object): 24 | def __init__(self): 25 | self.Compose = albu.Compose([ 26 | albu.HorizontalFlip(p=0), 27 | albu.ToFloat(max_value=None, always_apply=True), 28 | ToTensor(normalize={'mean': [0.5071, 0.4867, 0.4408], 'std': [0.2675, 0.2565, 0.2761]}) 29 | ]) 30 | def __call__(self, image): 31 | transformed = self.Compose(image=image) 32 | image = transformed['image'] 33 | return image -------------------------------------------------------------------------------- /augmentations/autoaugment.py: -------------------------------------------------------------------------------- 1 | # Reference 2 | # URL : https://github.com/DeepVoltaire/AutoAugment/raw/master/autoaugment.py 3 | ##### 4 | 5 | from PIL import Image, ImageEnhance, ImageOps 6 | import numpy as np 7 | import random 8 | 9 | 10 | class ImageNetPolicy(object): 11 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 12 | 13 | Example: 14 | >>> policy = ImageNetPolicy() 15 | >>> transformed = policy(image) 16 | 17 | Example as a PyTorch Transform: 18 | >>> transform=transforms.Compose([ 19 | >>> transforms.Resize(256), 20 | >>> ImageNetPolicy(), 21 | >>> transforms.ToTensor()]) 22 | """ 23 | def __init__(self, fillcolor=(128, 128, 128)): 24 | self.policies = [ 25 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 26 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 27 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 28 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 29 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 30 | 31 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 32 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 33 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 34 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 35 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 36 | 37 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 38 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 39 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 40 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 41 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 42 | 43 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 44 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 45 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 46 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 47 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 48 | 49 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 50 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 51 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 52 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 53 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 54 | ] 55 | 56 | 57 | def __call__(self, img): 58 | policy_idx = random.randint(0, len(self.policies) - 1) 59 | return self.policies[policy_idx](img) 60 | 61 | def __repr__(self): 62 | return "AutoAugment ImageNet Policy" 63 | 64 | 65 | class CIFAR10Policy(object): 66 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 67 | 68 | Example: 69 | >>> policy = CIFAR10Policy() 70 | >>> transformed = policy(image) 71 | 72 | Example as a PyTorch Transform: 73 | >>> transform=transforms.Compose([ 74 | >>> transforms.Resize(256), 75 | >>> CIFAR10Policy(), 76 | >>> transforms.ToTensor()]) 77 | """ 78 | def __init__(self, fillcolor=(128, 128, 128)): 79 | self.policies = [ 80 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 81 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 82 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 83 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 84 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 85 | 86 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 87 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 88 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 89 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 90 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 91 | 92 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 93 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 94 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 95 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 96 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 97 | 98 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 99 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 100 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 101 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 102 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 103 | 104 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 105 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 106 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 107 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 108 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 109 | ] 110 | 111 | 112 | def __call__(self, img): 113 | policy_idx = random.randint(0, len(self.policies) - 1) 114 | return self.policies[policy_idx](img) 115 | 116 | def __repr__(self): 117 | return "AutoAugment CIFAR10 Policy" 118 | 119 | 120 | class SVHNPolicy(object): 121 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 122 | 123 | Example: 124 | >>> policy = SVHNPolicy() 125 | >>> transformed = policy(image) 126 | 127 | Example as a PyTorch Transform: 128 | >>> transform=transforms.Compose([ 129 | >>> transforms.Resize(256), 130 | >>> SVHNPolicy(), 131 | >>> transforms.ToTensor()]) 132 | """ 133 | def __init__(self, fillcolor=(128, 128, 128)): 134 | self.policies = [ 135 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 136 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 137 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 138 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 139 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 140 | 141 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 142 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 143 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 144 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 145 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 146 | 147 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 148 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 149 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 150 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 151 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 152 | 153 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 154 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 155 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 156 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 157 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 158 | 159 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 160 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 161 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 162 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 163 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 164 | ] 165 | 166 | 167 | def __call__(self, img): 168 | policy_idx = random.randint(0, len(self.policies) - 1) 169 | return self.policies[policy_idx](img) 170 | 171 | def __repr__(self): 172 | return "AutoAugment SVHN Policy" 173 | 174 | 175 | class SubPolicy(object): 176 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 177 | ranges = { 178 | "shearX": np.linspace(0, 0.3, 10), 179 | "shearY": np.linspace(0, 0.3, 10), 180 | "translateX": np.linspace(0, 150 / 331, 10), 181 | "translateY": np.linspace(0, 150 / 331, 10), 182 | "rotate": np.linspace(0, 30, 10), 183 | "color": np.linspace(0.0, 0.9, 10), 184 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 185 | "solarize": np.linspace(256, 0, 10), 186 | "contrast": np.linspace(0.0, 0.9, 10), 187 | "sharpness": np.linspace(0.0, 0.9, 10), 188 | "brightness": np.linspace(0.0, 0.9, 10), 189 | "autocontrast": [0] * 10, 190 | "equalize": [0] * 10, 191 | "invert": [0] * 10 192 | } 193 | 194 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 195 | def rotate_with_fill(img, magnitude): 196 | rot = img.convert("RGBA").rotate(magnitude) 197 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 198 | 199 | func = { 200 | "shearX": lambda img, magnitude: img.transform( 201 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 202 | Image.BICUBIC, fillcolor=fillcolor), 203 | "shearY": lambda img, magnitude: img.transform( 204 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 205 | Image.BICUBIC, fillcolor=fillcolor), 206 | "translateX": lambda img, magnitude: img.transform( 207 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 208 | fillcolor=fillcolor), 209 | "translateY": lambda img, magnitude: img.transform( 210 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 211 | fillcolor=fillcolor), 212 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 213 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 214 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 215 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 216 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 217 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 218 | 1 + magnitude * random.choice([-1, 1])), 219 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 220 | 1 + magnitude * random.choice([-1, 1])), 221 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 222 | 1 + magnitude * random.choice([-1, 1])), 223 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 224 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 225 | "invert": lambda img, magnitude: ImageOps.invert(img) 226 | } 227 | 228 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( 229 | # operation1, ranges[operation1][magnitude_idx1], 230 | # operation2, ranges[operation2][magnitude_idx2]) 231 | self.p1 = p1 232 | self.operation1 = func[operation1] 233 | self.magnitude1 = ranges[operation1][magnitude_idx1] 234 | self.p2 = p2 235 | self.operation2 = func[operation2] 236 | self.magnitude2 = ranges[operation2][magnitude_idx2] 237 | 238 | 239 | def __call__(self, img): 240 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 241 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 242 | return img 243 | -------------------------------------------------------------------------------- /cifar100.py: -------------------------------------------------------------------------------- 1 | from dataset import Cifar100Dataset 2 | from model import * 3 | from metrics.metrics import accuracy 4 | from augmentations.augmentations import TrainAugment, TestAugment 5 | 6 | 7 | # TODO: 8 | # stochastic depth (Huang et al., 2016) with drop connect ratio 0.3. 9 | 10 | 11 | def main(): 12 | 13 | torch.manual_seed(2019) 14 | if torch.cuda.is_available(): 15 | torch.cuda.manual_seed_all(2019) 16 | 17 | device = 'cuda:0' 18 | ################################### 19 | ### Dataset 20 | ################################### 21 | train_dataset = Cifar100Dataset(root='./input/', train=True, download=True, transform=TrainAugment()) 22 | test_dataset = Cifar100Dataset(root='./input/', train=False, download=True, transform=TestAugment()) 23 | 24 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, 25 | shuffle=True, num_workers=8, 26 | pin_memory=False) 27 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=128, 28 | shuffle=False, num_workers=8, 29 | pin_memory=False) 30 | 31 | ######################################### 32 | # model 33 | model = efficientnet_b0(num_classes=100).to(device) 34 | # Optimizer 35 | # torch.optim : https://pytorch.org/docs/stable/optim.html 36 | optimizer = torch.optim.RMSprop(model.parameters(), lr=0.016, momentum=0.9, weight_decay=0.9) 37 | # optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5) 38 | # Scheduler 39 | # torch.optim.lr_scheduler : https://pytorch.org/docs/stable/optim.html?highlight=lr_scheduler#torch.optim.lr_scheduler.LambdaLR 40 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(2.4*len(train_dataloader)), gamma=0.97) 41 | # Loss and Evalaiton function 42 | loss_fn = torch.nn.CrossEntropyLoss().to(device) 43 | eval_fn = accuracy 44 | 45 | epoch = 500 46 | 47 | for i in range(epoch): 48 | train_loss, train_accuracy = train(model, optimizer, train_dataloader, 49 | device, loss_fn, eval_fn, i, scheduler) 50 | test_loss, test_accuracy = test(model, test_dataloader, 51 | device, loss_fn, eval_fn) 52 | print(f'''======== epoch {i:>3} (lr: {scheduler.get_lr()[0]:.5f}) ======== 53 | train loss = {train_loss:.5f} | train err = {1-train_accuracy:.2%} | 54 | test loss = {test_loss:.5f} | test err = {1-test_accuracy:.2%}''') 55 | 56 | print('=== Success ===') 57 | 58 | 59 | def l2_loss(model): 60 | loss = 0.0 61 | for m in model.modules(): 62 | if isinstance(m, (nn.Conv2d, nn.Linear)): 63 | for p in m.parameters(): 64 | loss += (p ** 2).sum() / 2 #p.norm(2) 65 | 66 | return loss 67 | 68 | def train(model, optimizer, dataloader, device, loss_fn, eval_fn, epoch, scheduler=None): 69 | model.train() 70 | avg_loss = 0 71 | avg_accuracy = 0 72 | for step, (inputs, targets) in enumerate(dataloader): 73 | inputs = inputs.to(device) 74 | targets = targets.to(device) 75 | optimizer.zero_grad() 76 | logits = model(inputs) 77 | preds = logits.softmax(dim=1) 78 | loss = loss_fn(logits, targets.argmax(dim=1)) 79 | loss += 1e-5 * l2_loss(model) 80 | loss.backward() 81 | optimizer.step() 82 | avg_loss += loss.item() 83 | avg_accuracy += eval_fn(preds, targets) 84 | if scheduler is not None: 85 | scheduler.step() 86 | avg_loss /= len(dataloader) 87 | avg_accuracy /= len(dataloader) 88 | return avg_loss, avg_accuracy 89 | 90 | 91 | def test(model, dataloader, device, loss_fn, eval_fn): 92 | model.eval() 93 | avg_loss = 0 94 | avg_accuracy = 0 95 | with torch.no_grad(): 96 | for inputs, targets in dataloader: 97 | inputs = inputs.to(device) 98 | targets = targets.to(device) 99 | logits = model(inputs) 100 | preds = logits.softmax(dim=1) 101 | loss = loss_fn(logits, targets.argmax(dim=1)) 102 | loss += 1e-5 * l2_loss(model) 103 | avg_loss += loss.item() 104 | avg_accuracy += eval_fn(preds, targets) 105 | 106 | avg_loss /= len(dataloader) 107 | avg_accuracy /= len(dataloader) 108 | return avg_loss, avg_accuracy 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | 6 | 7 | __all__ = ['Cifar100Dataset'] 8 | 9 | class Cifar100Dataset(Dataset): 10 | def __init__(self, root, download=False, train=True, transform=None): 11 | 12 | data = torchvision.datasets.CIFAR100(root=root,train=train, download=download) 13 | # n x 32 x 32 x 3 (uint8, np.array) 14 | self.images = data.data 15 | # n (list) 16 | self.labels = data.targets 17 | self.transform = transform 18 | self.train = train 19 | 20 | def __getitem__(self, index): 21 | if self.train: 22 | image = self.images[index] 23 | target = np.zeros((100),dtype=np.float32) 24 | target[self.labels[index]] = 1.0 25 | 26 | if self.transform is not None: 27 | image = self.transform(image) 28 | 29 | else: 30 | image = self.images[index] 31 | target = np.zeros((100), dtype=np.float32) 32 | target[self.labels[index]] = 1.0 33 | if self.transform is not None: 34 | image = self.transform(image) 35 | 36 | return image, target 37 | 38 | def __len__(self): 39 | return len(self.images) -------------------------------------------------------------------------------- /input/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katsura-jp/efficientnet-pytorch/df0506c0b23e920b69f74b793f8ba1ddc1df54de/input/.gitkeep -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katsura-jp/efficientnet-pytorch/df0506c0b23e920b69f74b793f8ba1ddc1df54de/metrics/__init__.py -------------------------------------------------------------------------------- /metrics/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def accuracy(preds,target): 4 | # labels = target.argmax(dim=1) 5 | acc = preds.argmax(dim=1).eq(target.argmax(dim=1)).float().mean() 6 | return acc 7 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .efficientnet import * -------------------------------------------------------------------------------- /model/efficientnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .swish import Swish 6 | 7 | 8 | class SqeezeExcitation(nn.Module): 9 | def __init__(self, inplanes, se_ratio): 10 | super(SqeezeExcitation, self).__init__() 11 | hidden_dim = int(inplanes*se_ratio) 12 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 13 | self.fc1 = nn.Linear(inplanes, hidden_dim, bias=False) 14 | self.fc2 = nn.Linear(hidden_dim, inplanes, bias=False) 15 | self.swish = Swish() 16 | self.sigmoid = nn.Sigmoid() 17 | def forward(self, x): 18 | out = self.avg_pool(x).view(x.size(0), -1) 19 | out = self.fc1(out) 20 | out = self.swish(out) 21 | out = self.fc2(out) 22 | out = self.sigmoid(out) 23 | out = out.unsqueeze(2).unsqueeze(3) 24 | out = x * out.expand_as(x) 25 | return out 26 | 27 | 28 | class Bottleneck(nn.Module): 29 | def __init__(self,inplanes, planes, kernel_size, stride, expand, se_ratio, prob=1.0): 30 | super(Bottleneck, self).__init__() 31 | if expand == 1: 32 | self.conv2 = nn.Conv2d(inplanes*expand, inplanes*expand, kernel_size=kernel_size, stride=stride, 33 | padding=kernel_size//2, groups=inplanes*expand, bias=False) 34 | self.bn2 = nn.BatchNorm2d(inplanes*expand, momentum=0.99, eps=1e-3) 35 | self.se = SqeezeExcitation(inplanes*expand, se_ratio) 36 | self.conv3 = nn.Conv2d(inplanes*expand, planes, kernel_size=1, bias=False) 37 | self.bn3 = nn.BatchNorm2d(planes, momentum=0.99, eps=1e-3) 38 | else: 39 | self.conv1 = nn.Conv2d(inplanes, inplanes*expand, kernel_size=1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(inplanes*expand, momentum=0.99, eps=1e-3) 41 | self.conv2 = nn.Conv2d(inplanes*expand, inplanes*expand, kernel_size=kernel_size, stride=stride, 42 | padding=kernel_size//2, groups=inplanes*expand, bias=False) 43 | self.bn2 = nn.BatchNorm2d(inplanes*expand, momentum=0.99, eps=1e-3) 44 | self.se = SqeezeExcitation(inplanes*expand, se_ratio) 45 | self.conv3 = nn.Conv2d(inplanes*expand, planes, kernel_size=1, bias=False) 46 | self.bn3 = nn.BatchNorm2d(planes, momentum=0.99, eps=1e-3) 47 | 48 | self.swish = Swish() 49 | self.correct_dim = (stride == 1) and (inplanes == planes) 50 | self.prob = torch.Tensor([prob]) 51 | 52 | def forward(self, x): 53 | if self.training: 54 | if not torch.bernoulli(self.prob): 55 | # drop 56 | return x 57 | 58 | if hasattr(self, 'conv1'): 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.swish(out) 62 | else: 63 | out = x 64 | 65 | out = self.conv2(out) # depth wise conv 66 | out = self.bn2(out) 67 | out = self.swish(out) 68 | 69 | out = self.se(out) 70 | 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.correct_dim: 76 | out += x 77 | 78 | return out 79 | 80 | 81 | class MBConv(nn.Module): 82 | def __init__(self, inplanes, planes, repeat, kernel_size, stride, expand, se_ratio, sum_layer, count_layer=None, pl=0.5): 83 | super(MBConv, self).__init__() 84 | layer = [] 85 | 86 | # not drop(stchastic depth) 87 | layer.append(Bottleneck(inplanes, planes, kernel_size, stride, expand, se_ratio)) 88 | 89 | for l in range(1, repeat): 90 | if count_layer is None: 91 | layer.append(Bottleneck(planes, planes, kernel_size, 1, expand, se_ratio)) 92 | else: 93 | # stochastic depth 94 | prob = 1.0 - (count_layer + l) / sum_layer * (1 - pl) 95 | layer.append(Bottleneck(planes, planes, kernel_size, 1, expand, se_ratio, prob=prob)) 96 | 97 | self.layer = nn.Sequential(*layer) 98 | 99 | def forward(self, x): 100 | out = self.layer(x) 101 | return out 102 | 103 | 104 | class Upsample(nn.Module): 105 | def __init__(self, scale): 106 | super(Upsample, self).__init__() 107 | self.scale = scale 108 | 109 | def forward(self, x): 110 | return F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False) 111 | 112 | 113 | class Flatten(nn.Module): 114 | def __init(self): 115 | super(Flatten, self).__init__() 116 | def forward(self, x): 117 | return x.view(x.size(0), -1) 118 | 119 | 120 | class EfficientNet(nn.Module): 121 | def __init__(self, num_classes=1000, width_coef=1., depth_coef=1., scale=1., 122 | dropout_ratio=0.2, se_ratio=0.25, stochastic_depth=False, pl=0.5): 123 | 124 | super(EfficientNet, self).__init__() 125 | channels = [32, 16, 24, 40, 80, 112, 192, 320, 1280] 126 | expands = [1, 6, 6, 6, 6, 6, 6] 127 | repeats = [1, 2, 2, 3, 3, 4, 1] 128 | strides = [1, 2, 2, 2, 1, 2, 1] 129 | kernel_sizes = [3, 3, 5, 3, 5, 5, 3] 130 | depth = depth_coef 131 | width = width_coef 132 | 133 | 134 | channels = [round(x*width) for x in channels] # [int(x*width) for x in channels] 135 | repeats = [round(x*depth) for x in repeats] # [int(x*width) for x in repeats] 136 | 137 | sum_layer = sum(repeats) 138 | 139 | self.upsample = Upsample(scale) 140 | self.swish = Swish() 141 | 142 | self.stage1 = nn.Sequential( 143 | nn.Conv2d(3, channels[0], kernel_size=3, stride=2, padding=1, bias=False), 144 | nn.BatchNorm2d(channels[0], momentum=0.99, eps=1e-3)) 145 | 146 | if stochastic_depth: 147 | # stochastic depth 148 | self.stage2 = MBConv(channels[0], channels[1], repeats[0], kernel_size=kernel_sizes[0], 149 | stride=strides[0], expand=expands[0], se_ratio=se_ratio, sum_layer=sum_layer, 150 | count_layer=sum(repeats[:0]), pl=pl) 151 | self.stage3 = MBConv(channels[1], channels[2], repeats[1], kernel_size=kernel_sizes[1], 152 | stride=strides[1], expand=expands[1], se_ratio=se_ratio, sum_layer=sum_layer, 153 | count_layer=sum(repeats[:1]), pl=pl) 154 | self.stage4 = MBConv(channels[2], channels[3], repeats[2], kernel_size=kernel_sizes[2], 155 | stride=strides[2], expand=expands[2], se_ratio=se_ratio, sum_layer=sum_layer, 156 | count_layer=sum(repeats[:2]), pl=pl) 157 | self.stage5 = MBConv(channels[3], channels[4], repeats[3], kernel_size=kernel_sizes[3], 158 | stride=strides[3], expand=expands[3], se_ratio=se_ratio, sum_layer=sum_layer, 159 | count_layer=sum(repeats[:3]), pl=pl) 160 | self.stage6 = MBConv(channels[4], channels[5], repeats[4], kernel_size=kernel_sizes[4], 161 | stride=strides[4], expand=expands[4], se_ratio=se_ratio, sum_layer=sum_layer, 162 | count_layer=sum(repeats[:4]), pl=pl) 163 | self.stage7 = MBConv(channels[5], channels[6], repeats[5], kernel_size=kernel_sizes[5], 164 | stride=strides[5], expand=expands[5], se_ratio=se_ratio, sum_layer=sum_layer, 165 | count_layer=sum(repeats[:5]), pl=pl) 166 | self.stage8 = MBConv(channels[6], channels[7], repeats[6], kernel_size=kernel_sizes[6], 167 | stride=strides[6], expand=expands[6], se_ratio=se_ratio, sum_layer=sum_layer, 168 | count_layer=sum(repeats[:6]), pl=pl) 169 | else: 170 | self.stage2 = MBConv(channels[0], channels[1], repeats[0], kernel_size=kernel_sizes[0], 171 | stride=strides[0], expand=expands[0], se_ratio=se_ratio, sum_layer=sum_layer) 172 | self.stage3 = MBConv(channels[1], channels[2], repeats[1], kernel_size=kernel_sizes[1], 173 | stride=strides[1], expand=expands[1], se_ratio=se_ratio, sum_layer=sum_layer) 174 | self.stage4 = MBConv(channels[2], channels[3], repeats[2], kernel_size=kernel_sizes[2], 175 | stride=strides[2], expand=expands[2], se_ratio=se_ratio, sum_layer=sum_layer) 176 | self.stage5 = MBConv(channels[3], channels[4], repeats[3], kernel_size=kernel_sizes[3], 177 | stride=strides[3], expand=expands[3], se_ratio=se_ratio, sum_layer=sum_layer) 178 | self.stage6 = MBConv(channels[4], channels[5], repeats[4], kernel_size=kernel_sizes[4], 179 | stride=strides[4], expand=expands[4], se_ratio=se_ratio, sum_layer=sum_layer) 180 | self.stage7 = MBConv(channels[5], channels[6], repeats[5], kernel_size=kernel_sizes[5], 181 | stride=strides[5], expand=expands[5], se_ratio=se_ratio, sum_layer=sum_layer) 182 | self.stage8 = MBConv(channels[6], channels[7], repeats[6], kernel_size=kernel_sizes[6], 183 | stride=strides[6], expand=expands[6], se_ratio=se_ratio, sum_layer=sum_layer) 184 | 185 | self.stage9 = nn.Sequential( 186 | nn.Conv2d(channels[7], channels[8], kernel_size=1, bias=False), 187 | nn.BatchNorm2d(channels[8], momentum=0.99, eps=1e-3), 188 | Swish(), 189 | nn.AdaptiveAvgPool2d((1, 1)), 190 | Flatten(), 191 | nn.Dropout(p=dropout_ratio), 192 | nn.Linear(channels[8], num_classes)) 193 | 194 | for m in self.modules(): 195 | if isinstance(m, nn.Conv2d): 196 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 197 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='sigmoid') 198 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 199 | nn.init.constant_(m.weight, 1) 200 | nn.init.constant_(m.bias, 0) 201 | 202 | 203 | def forward(self, x): 204 | x = self.upsample(x) 205 | x = self.swish(self.stage1(x)) 206 | x = self.swish(self.stage2(x)) 207 | x = self.swish(self.stage3(x)) 208 | x = self.swish(self.stage4(x)) 209 | x = self.swish(self.stage5(x)) 210 | x = self.swish(self.stage6(x)) 211 | x = self.swish(self.stage7(x)) 212 | x = self.swish(self.stage8(x)) 213 | logit = self.stage9(x) 214 | 215 | return logit 216 | 217 | 218 | 219 | def efficientnet_b0(num_classes=1000): 220 | return EfficientNet(num_classes=num_classes, width_coef=1.0, depth_coef=1.0, scale=1.0,dropout_ratio=0.2, se_ratio=0.25) 221 | 222 | def efficientnet_b1(num_classes=1000): 223 | return EfficientNet(num_classes=num_classes, width_coef=1.0, depth_coef=1.1, scale=240/224, dropout_ratio=0.2, se_ratio=0.25) 224 | 225 | def efficientnet_b2(num_classes=1000): 226 | return EfficientNet(num_classes=num_classes, width_coef=1.1, depth_coef=1.2, scale=260/224., dropout_ratio=0.3, se_ratio=0.25) 227 | 228 | def efficientnet_b3(num_classes=1000): 229 | return EfficientNet(num_classes=num_classes, width_coef=1.2, depth_coef=1.4, scale=300/224, dropout_ratio=0.3, se_ratio=0.25) 230 | 231 | def efficientnet_b4(num_classes=1000): 232 | return EfficientNet(num_classes=num_classes, width_coef=1.4, depth_coef=1.8, scale=380/224, dropout_ratio=0.4, se_ratio=0.25) 233 | 234 | def efficientnet_b5(num_classes=1000): 235 | return EfficientNet(num_classes=num_classes, width_coef=1.6, depth_coef=2.2, scale=456/224, dropout_ratio=0.4, se_ratio=0.25) 236 | 237 | def efficientnet_b6(num_classes=1000): 238 | return EfficientNet(num_classes=num_classes, width_coef=1.8, depth_coef=2.6, scale=528/224, dropout_ratio=0.5, se_ratio=0.25) 239 | 240 | def efficientnet_b7(num_classes=1000): 241 | return EfficientNet(num_classes=num_classes, width_coef=2.0, depth_coef=3.1, scale=600/224, dropout_ratio=0.5, se_ratio=0.25) 242 | 243 | def test(): 244 | x = torch.FloatTensor(64, 3, 224, 224) 245 | model = EfficientNet(num_classes=100, width_coef=1.0, depth_coef=1.0, scale=1.0,dropout_ratio=0.2, 246 | se_ratio=0.25, stochastic_depth=True) 247 | logit = model(x) 248 | print(logit.size()) 249 | 250 | if __name__ == '__main__': 251 | test() 252 | -------------------------------------------------------------------------------- /model/swish.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | from torch._jit_internal import weak_module, weak_script_method 5 | 6 | 7 | @weak_module 8 | class Swish(nn.Module): 9 | def __init__(self, train_beta=False): 10 | super(Swish, self).__init__() 11 | if train_beta: 12 | self.weight = Parameter(torch.Tensor([1.])) 13 | else: 14 | self.weight = 1.0 15 | 16 | @weak_script_method 17 | def forward(self, input): 18 | return input * torch.sigmoid(self.weight * input) 19 | 20 | 21 | def test(): 22 | x = torch.FloatTensor(16, 128, 16, 16) 23 | swish = Swish(train_beta=True) 24 | print(swish(x).size()) 25 | 26 | if __name__ == '__main__': 27 | test() 28 | --------------------------------------------------------------------------------