├── .vscode
    ├── settings.json
    └── launch.json
├── models
    ├── layers.py
    ├── mdconv.py
    ├── utils.py
    └── mixnet.py
├── readme.md
├── loader.py
├── .gitignore
├── logger.py
├── main.py
├── runner.py
└── ema_runner.py

/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |     "workbench.colorCustomizations": {}
3 | }
--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Swish(nn.Module):
6 |     """Swish / SiLU activation: x * sigmoid(x)."""
7 |
8 |     def forward(self, x):
9 |         return x * torch.sigmoid(x)
10 |
11 |
12 | class Flatten(nn.Module):
13 |     """Flatten every dimension except the leading (batch) one."""
14 |
15 |     def forward(self, x):
16 |         # reshape (not view) so non-contiguous inputs work too
17 |         return x.reshape(x.shape[0], -1)
18 |
19 |
20 | class SEModule(nn.Module):
21 |     """Squeeze-and-excitation: rescale each channel by a learned gate.
22 |
23 |     Args:
24 |         ch: number of input (and output) channels.
25 |         squeeze_ch: width of the bottleneck between the two 1x1 convs.
26 |     """
27 |
28 |     def __init__(self, ch, squeeze_ch):
29 |         super().__init__()
30 |         self.se = nn.Sequential(
31 |             nn.AdaptiveAvgPool2d(1),
32 |             nn.Conv2d(ch, squeeze_ch, 1, 1, 0, bias=True),
33 |             Swish(),
34 |             nn.Conv2d(squeeze_ch, ch, 1, 1, 0, bias=True),
35 |         )
36 |
37 |     def forward(self, x):
38 |         # the pooled (N, ch, 1, 1) gate broadcasts over the spatial dims
39 |         return x * torch.sigmoid(self.se(x))
40 |
--------------------------------------------------------------------------------
/models/mdconv.py:
--------------------------------------------------------------------------------
1 | # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def _split_channels(total_filters, num_groups):
8 |     """Split `total_filters` channels into `num_groups` near-equal groups.
9 |
10 |     The remainder goes to the first group, e.g. (10, 3) -> [4, 3, 3].
11 |     https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py#L33
12 |     """
13 |     split = [total_filters // num_groups for _ in range(num_groups)]
14 |     split[0] += total_filters - sum(split)
15 |     return split
16 |
17 |
18 | class MDConv(nn.Module):
19 |     """Mixed depthwise convolution (MixConv).
20 |
21 |     The channels are split into one group per kernel size; each group gets
22 |     its own depthwise conv and the outputs are concatenated again.
23 |
24 |     Args:
25 |         in_channels: number of input (and output) channels.
26 |         kernel_sizes: one kernel size, or a list/tuple with one per group.
27 |         stride: an int or an (h, w) pair.
28 |         dilatied: for stride-1 convs, replace each k x k kernel (k > 2) by a
29 |             3 x 3 kernel with dilation (k - 1) // 2, which keeps the
30 |             receptive field, as in the TF reference.
31 |         bias: whether the depthwise convs have a bias.
32 |     """
33 |
34 |     def __init__(self, in_channels, kernel_sizes, stride, dilatied=False, bias=False):
35 |         super().__init__()
36 |
37 |         if not isinstance(kernel_sizes, (list, tuple)):
38 |             kernel_sizes = [kernel_sizes]
39 |         if isinstance(stride, int):
40 |             stride = (stride, stride)
41 |
42 |         self.in_channels = _split_channels(in_channels, len(kernel_sizes))
43 |
44 |         self.convs = nn.ModuleList()
45 |         for ch, k in zip(self.in_channels, kernel_sizes):
46 |             dilation = 1
47 |             if stride[0] == 1 and dilatied and k > 2:
48 |                 # the *kernel* shrinks to 3; the stride must stay a pair
49 |                 dilation, k = (k - 1) // 2, 3
50 |                 print("Use dilated conv with dilation rate = {}".format(dilation))
51 |             # "same"-style padding for odd kernel sizes
52 |             pad = ((stride[0] - 1) + dilation * (k - 1)) // 2
53 |
54 |             conv = nn.Conv2d(ch, ch, k, stride, pad, dilation,
55 |                              groups=ch, bias=bias)
56 |             self.convs.append(conv)
57 |
58 |     def forward(self, x):
59 |         chunks = torch.split(x, self.in_channels, 1)
60 |         return torch.cat([conv(c) for conv, c in zip(self.convs, chunks)], 1)
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Mixnet
2 |
3 | A PyTorch implementation of `MixNet: Mixed Depthwise Convolutional Kernels.`
4 |
5 |
6 | ### [[arxiv]](https://arxiv.org/abs/1907.09595) [[Official TF Repo]](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet)
7 |
8 |
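9 |
10 | ## Notes
11 |
12 | The EMA copy of the model currently runs on the CPU, so training with `--ema` is slower than with the normal runner.
13 |
14 | If you are training on a GPU, you can keep the EMA copy on the GPU instead by changing [`EMARunner.__init__`](ema_runner.py) and [`EMARunner.update_ema`](ema_runner.py).
15 |
16 |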
17 | 18 | ## How to use: 19 | 20 | ``` 21 | python3 main.py -h 22 | usage: main.py [-h] --save_dir SAVE_DIR [--root ROOT] [--gpus GPUS] 23 | [--num_workers NUM_WORKERS] [--model {mixs}] [--epoch EPOCH] 24 | [--batch_size BATCH_SIZE] [--test] [--ema_decay EMA_DECAY] 25 | [--optim {rmsprop,adam}] [--lr LR] [--beta [BETA [BETA ...]]] 26 | [--momentum MOMENTUM] [--eps EPS] [--decay DECAY] 27 | [--scheduler {exp,cosine,none}] 28 | 29 | Pytorch Mixnet 30 | 31 | optional arguments: 32 | -h, --help show this help message and exit 33 | --save_dir SAVE_DIR Directory name to save the model 34 | --root ROOT The Directory of data path. 35 | --gpus GPUS Select GPU Numbers | 0,1,2,3 | 36 | --num_workers NUM_WORKERS 37 | Select CPU Number workers 38 | --model {mixs} The type of mixnet. 39 | --epoch EPOCH The number of epochs 40 | --batch_size BATCH_SIZE 41 | The size of batch 42 | --test Only Test 43 | --ema_decay EMA_DECAY 44 | Exponential Moving Average Term 45 | --optim {rmsprop,adam} 46 | --lr LR Base learning rate when train batch size is 256. 
47 |   --beta [BETA [BETA ...]]
48 |   --momentum MOMENTUM
49 |   --eps EPS
50 |   --decay DECAY
51 |   --scheduler {exp,cosine,none}
52 |                         Learning rate scheduler type
53 | ```
54 |
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import MNIST
2 | from torchvision.datasets import CIFAR10, CIFAR100
3 | from torchvision.datasets import ImageFolder
4 |
5 | from torch.utils.data import DataLoader
6 |
7 | from torchvision import transforms as T
8 |
9 |
10 | # dtype -> (dataset class, per-channel mean, per-channel std)
11 | _TORCHVISION_DATASETS = {
12 |     "mnist": (MNIST, (0.1307,), (0.3081,)),
13 |     "cifar10": (CIFAR10, (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
14 |     "cifar100": (CIFAR100, (0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
15 | }
16 |
17 |
18 | def get_dataset(root, dtype="cifar10", resl=224):
19 |     """Build the (train, valid) datasets.
20 |
21 |     Args:
22 |         root: data directory; torchvision datasets are downloaded here, for
23 |             "imagenet" it must contain `train/` and `val/` image folders.
24 |         dtype: "mnist", "cifar10", "cifar100" or "imagenet".
25 |         resl: image resolution, only used for "imagenet".
26 |
27 |     Returns:
28 |         (train, valid) datasets.
29 |
30 |     Raises:
31 |         ValueError: if `dtype` is not supported.
32 |     """
33 |     if dtype == "imagenet":
34 |         return imagenet(root, resl)
35 |     if dtype not in _TORCHVISION_DATASETS:
36 |         raise ValueError("Unknown dtype %r, expected one of %s"
37 |                          % (dtype, ["imagenet"] + sorted(_TORCHVISION_DATASETS)))
38 |
39 |     dset, mean, std = _TORCHVISION_DATASETS[dtype]
40 |     # normalize with the statistics of the selected dataset
41 |     tr = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
42 |
43 |     train = dset(root, True, transform=tr, download=True)
44 |     valid = dset(root, False, transform=tr)
45 |     return train, valid
46 |
47 |
48 | def imagenet(root, resl):
49 |     """ImageFolder (train, valid) datasets from `root/train` and `root/val`."""
50 |     normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
51 |                             std=[0.229, 0.224, 0.225])
52 |     train = ImageFolder(
53 |         root + "/train",
54 |         T.Compose([
55 |             T.Resize([resl, resl]),
56 |             T.RandomResizedCrop(resl),
57 |             T.RandomHorizontalFlip(),
58 |             T.ToTensor(),
59 |             normalize,
60 |         ])
61 |     )
62 |
63 |     valid = ImageFolder(
64 |         root + "/val",
65 |         T.Compose([
66 |             T.Resize([resl, resl]),
67 |             T.ToTensor(),
68 |             normalize,
69 |         ])
70 |     )
71 |
72 |     return train, valid
73 |
74 |
75 | def get_loaders(root, batch_size, num_workers=32, dtype="cifar10", resl=224):
76 |     """Return (train_loader, val_loader); only the train loader shuffles."""
77 |     train, valid = get_dataset(root, dtype, resl)
78 |
79 |     train_loader = DataLoader(train,
80 |                               batch_size=batch_size, shuffle=True,
81 |                               num_workers=num_workers, pin_memory=True
82 |                               )
83 |
84 |     val_loader = DataLoader(valid,
85 |                             batch_size=batch_size, shuffle=False,
86 |                             num_workers=num_workers, pin_memory=True
87 |                             )
88 |     return train_loader, val_loader
89 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 |     // Use IntelliSense to learn about possible attributes.
3 |     // Hover to view descriptions of existing attributes.
4 |     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 |     "version": "0.2.0",
6 |     "configurations": [
7 |         {
8 |             "name": "Python: Current File (Integrated Terminal)",
9 |             "type": "python",
10 |             "request": "launch",
11 |             "program": "${file}",
12 |             "console": "integratedTerminal",
13 |         },
14 |         {
15 |             "name": "Python: Remote Attach",
16 |             "type": "python",
17 |             "request": "attach",
18 |             "port": 5678,
19 |             "host": "localhost",
20 |             "pathMappings": [
21 |                 {
22 |                     "localRoot": "${workspaceFolder}",
23 |                     "remoteRoot": "."
24 |                 }
25 |             ]
26 |         },
27 |         {
28 |             "name": "Python: Module",
29 |             "type": "python",
30 |             "request": "launch",
31 |             "module": "enter-your-module-name-here",
32 |             "console": "integratedTerminal"
33 |         },
34 |         {
35 |             "name": "Python: Django",
36 |             "type": "python",
37 |             "request": "launch",
38 |             "program": "${workspaceFolder}/manage.py",
39 |             "console": "integratedTerminal",
40 |             "args": [
41 |                 "runserver",
42 |                 "--noreload",
43 |                 "--nothreading"
44 |             ],
45 |             "django": true
46 |         },
47 |         {
48 |             "name": "Python: Flask",
49 |             "type": "python",
50 |             "request": "launch",
51 |             "module": "flask",
52 |             "env": {
53 |                 "FLASK_APP": "app.py"
54 |             },
55 |             "args": [
56 |                 "run",
57 |                 "--no-debugger",
58 |                 "--no-reload"
59 |             ],
60 |             "jinja": true
61 |         },
62 |         {
63 |             "name": "Python: Current File (External Terminal)",
64 |             "type": "python",
65 |             "request": "launch",
66 |             "program": "${file}",
67 |             "console": "externalTerminal"
68 |         }
69 |     ]
70 | }
--------------------------------------------------------------------------------
/.gitignore:
-------------------------------------------------------------------------------- 1 | tmp.md 2 | tmp 3 | tmp/* 4 | data 5 | data/* 6 | outs 7 | outs/* 8 | 9 | .DS_Store 10 | *.ipynb 11 | *.npy 12 | *.jpg 13 | *.mat 14 | *.png 15 | *.pyc 16 | *.ubuntu 17 | 18 | ################################################################## 19 | # github python ignore 20 | # https://github.com/github/gitignore/blob/master/Python.gitignore 21 | 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 30 | # Distribution / packaging 31 | .Python 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | pip-wheel-metadata/ 45 | share/python-wheels/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | MANIFEST 50 | 51 | # PyInstaller 52 | # Usually these files are written by a python script from a template 53 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 54 | *.manifest 55 | *.spec 56 | 57 | # Installer logs 58 | pip-log.txt 59 | pip-delete-this-directory.txt 60 | 61 | # Unit test / coverage reports 62 | htmlcov/ 63 | .tox/ 64 | .nox/ 65 | .coverage 66 | .coverage.* 67 | .cache 68 | nosetests.xml 69 | coverage.xml 70 | *.cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # pipenv 107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
109 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
110 | #   install all needed dependencies.
111 | #Pipfile.lock
112 |
113 | # celery beat schedule file
114 | celerybeat-schedule
115 |
116 | # SageMath parsed files
117 | *.sage.py
118 |
119 | # Environments
120 | .env
121 | .venv
122 | env/
123 | venv/
124 | ENV/
125 | env.bak/
126 | venv.bak/
127 |
128 | # Spyder project settings
129 | .spyderproject
130 | .spyproject
131 |
132 | # Rope project settings
133 | .ropeproject
134 |
135 | # mkdocs documentation
136 | /site
137 |
138 | # mypy
139 | .mypy_cache/
140 | .dmypy.json
141 | dmypy.json
142 |
143 | # Pyre type checker
144 | .pyre/
145 | ################################################################
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from cycler import cycler
4 | from collections import OrderedDict
5 |
6 | import matplotlib.pyplot as plt
7 |
8 | import numpy as np
9 |
10 |
11 | # x axis of each plot type
12 | LOG_KEYS = {
13 |     "train":"epoch",
14 |     "valid":"epoch",
15 |     "test": "fname"
16 | }
17 |
18 | # values (y axis) that may be logged / plotted for each type,
19 | # e.g. loss, accuracy, f1-score, PSNR, SSIM; several per type are allowed
20 | LOG_VALUES = {
21 |     "train":["loss", ],
22 |     "valid":["acc","valid_acc"],
23 |     "test": ["train_acc", "valid_acc"]
24 | }
25 |
26 |
27 | class Logger:
28 |     """Print log lines and append them to `<save_dir>/log.txt`.
29 |
30 |     `will_write` only buffers a line; `flush` (or `write`) appends the
31 |     buffered lines to the file.
32 |     """
33 |
34 |     def __init__(self, save_dir):
35 |         self.save_dir = save_dir
36 |         self.log_file = save_dir + "/log.txt"
37 |         self.buffers = []
38 |
39 |     def will_write(self, line):
40 |         """Print `line` and buffer it until the next `flush`."""
41 |         print(line)
42 |         self.buffers.append(line)
43 |
44 |     def flush(self):
45 |         """Append the buffered lines to the log file and clear the buffer."""
46 |         with open(self.log_file, "a", encoding="utf-8") as f:
47 |             f.write("\n".join(self.buffers))
48 |             f.write("\n")
49 |         self.buffers = []
50 |
51 |     def write(self, line):
52 |         """Print `line` and append it to the log file right away."""
53 |         self.will_write(line)
54 |         self.flush()
55 |
56 |     def log_write(self, learn_type, **values):
57 |         """Log one line as `[learn_type] {json of values}`.
58 |
59 |         ex ) log_write("train", epoch=1, loss=0.3)
60 |
61 |         Parameters:
62 |             learn_type : "train", "valid" or "test"
63 |             values : each key must be LOG_KEYS[learn_type] or listed in
64 |                 LOG_VALUES[learn_type]
65 |
66 |         Raises:
67 |             KeyError: if a key is not allowed for `learn_type`.
68 |
69 |         "train" lines stay buffered; other types are flushed immediately.
70 |         """
71 |         for k in values.keys():
72 |             if k not in LOG_VALUES[learn_type] and k != LOG_KEYS[learn_type]:
73 |                 raise KeyError("%s Log %s keys not in log" % (learn_type, k))
74 |
75 |         log = "[%s] %s" % (learn_type, json.dumps(values))
76 |         self.will_write(log)
77 |         if learn_type != "train":
78 |             self.flush()
79 |
80 |     def log_parse(self, log_key):
81 |         """Read the `[log_key]` lines back from the log file.
82 |
83 |         Returns:
84 |             OrderedDict mapping each x value (the LOG_KEYS[log_key] field)
85 |             to the remaining logged values; a later line with the same x
86 |             replaces an earlier one.
87 |         """
88 |         log_dict = OrderedDict()
89 |         with open(self.log_file, "r", encoding="utf-8") as f:
90 |             for line in f.readlines():
91 |                 if len(line) == 1 or not line.startswith("[%s]" % (log_key)):
92 |                     continue
93 |                 # '[train] {"epoch": 0, "loss": 2.3}' -> '{"epoch": 0, "loss": 2.3}'
94 |                 line = line[line.find("] ") + 2:]
95 |                 line_log = json.loads(line)
96 |
97 |                 train_log_key = line_log[LOG_KEYS[log_key]]
98 |                 line_log.pop(LOG_KEYS[log_key], None)
99 |                 log_dict[train_log_key] = line_log
100 |         return log_dict
101 |
102 |     def log_plot(self, log_key,
103 |                  figsize=(12, 12), title="plot", colors=None):
104 |         """Plot each LOG_VALUES[log_key] series against LOG_KEYS[log_key].
105 |
106 |         Args:
107 |             log_key: "train", "valid" or "test".
108 |             figsize: matplotlib figure size.
109 |             title: plot title.
110 |             colors: optional colors cycled over the series; by default
111 |                 evenly spaced `nipy_spectral` colors are used.
112 |         """
113 |         log_dict = self.log_parse(log_key)
114 |         if not log_dict:
115 |             print("No %s logs to plot" % (log_key))
116 |             return
117 |
118 |         plt.figure(figsize=figsize)
119 |         ax = plt.subplot(111)
120 |         ax.set_title(title)
121 |
122 |         if colors is None:
123 |             colors = plt.cm.nipy_spectral(np.linspace(0.1, 0.9, len(LOG_VALUES[log_key])))
124 |         ax.set_prop_cycle(cycler('color', colors))
125 |
126 |         x = list(log_dict.keys())
127 |         first = next(iter(log_dict.values()))
128 |         for keys in LOG_VALUES[log_key]:
129 |             if keys not in first:
130 |                 continue
131 |             y = [v[keys] for v in log_dict.values()]
132 |
133 |             label = keys + ", max : %f" % (max(y))
134 |             ax.plot(x, y, marker="o", linestyle="solid", label=label)
135 |             if max(y) > 1:
136 |                 ax.set_ylim([min(y) - 1, max(y) + 1])
137 |         ax.legend(fontsize=30)
138 |
139 |         plt.show()
140 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import argparse
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from torch.optim.lr_scheduler import StepLR
9 | from torch.optim.lr_scheduler import CosineAnnealingLR
10 |
11 | from models.mixnet import mixnet_s
12 |
13 | from ema_runner import EMARunner
14 | from runner import Runner
15 |
16 | from loader import get_loaders
17 |
18 | from logger import Logger
19 |
20 |
21 | def arg_parse():
22 |     """Parse the command-line arguments."""
23 |     # projects description
24 |     desc = "Pytorch Mixnet"
25 |     parser = argparse.ArgumentParser(description=desc)
26 |     parser.add_argument('--save_dir', type=str, required=True,
27 |                         help='Directory name to save the model')
28 |
29 |     # the keyword is `choices`; argparse raises TypeError on `choice`
30 |     parser.add_argument('--dtype', type=str, default="cifar10", choices=["cifar10", "cifar100", "imagenet"],
31 |                         help="Dataset type")
32 |     parser.add_argument('--ema', action="store_true", help="Exponential Moving Average")
33 |
34 |     parser.add_argument('--root', type=str, default="/data1/imagenet",
35 |                         help="The Directory of data path.")
36 |     parser.add_argument('--gpus', type=str, default="0,1,2,3",
37 |                         help="Select GPU Numbers | 0,1,2,3 | ")
38 |     parser.add_argument('--num_workers', type=int, default=32,
39 |                         help="Select CPU Number workers")
40 |
41 |     parser.add_argument('--model', type=str, default='mixs', help='The type of mixnet.')
42 |
43 |     parser.add_argument('--epoch', type=int, default=350, help='The number of epochs')
44 |     parser.add_argument('--batch_size', type=int, default=1024, help='The size of batch')
45 |     parser.add_argument('--test', action="store_true", help='Only Test')
46 |
47 |     parser.add_argument('--optim', type=str, default='adam', choices=["rmsprop", "adam"])
48 |     parser.add_argument('--lr', type=float, default=0.016, help="Base learning rate when train batch size is 256.")
49 |     # Adam Optimizer
50 |     parser.add_argument('--beta', nargs="*", type=float, default=(0.5, 0.999))
51 |
52 |     parser.add_argument('--momentum', type=float, default=0.9)
53 |     parser.add_argument('--eps', type=float, default=0.001)
54 |     parser.add_argument('--decay', type=float, default=1e-5)
55 |
56 |     parser.add_argument('--scheduler', type=str, default='exp', choices=["exp", "cosine", "none"],
57 |                         help="Learning rate scheduler type")
58 |
59 |     return parser.parse_args()
60 |
61 |
62 | def get_scheduler(optim, sche_type, step_size, t_max):
63 |     """Build the learning-rate scheduler (the runners step it once per iteration).
64 |
65 |     Args:
66 |         optim: optimizer to schedule.
67 |         sche_type: "exp" (lr *= 0.97 every `step_size` steps), "cosine"
68 |             (cosine annealing over `t_max` steps), anything else for none.
69 |         step_size: decay period for "exp".
70 |         t_max: annealing length for "cosine".
71 |
72 |     Returns:
73 |         The scheduler, or None.
74 |     """
75 |     if sche_type == "exp":
76 |         return StepLR(optim, step_size, 0.97)
77 |     elif sche_type == "cosine":
78 |         return CosineAnnealingLR(optim, t_max)
79 |     else:
80 |         print("No Scheduler")
81 |         return None
82 |
83 |
84 | if __name__ == "__main__":
85 |     arg = arg_parse()
86 |
87 |     arg.save_dir = "%s/outs/%s" % (os.getcwd(), arg.save_dir)
88 |     # makedirs also creates `outs/` itself on a fresh checkout
89 |     os.makedirs(arg.save_dir, exist_ok=True)
90 |
91 |     logger = Logger(arg.save_dir)
92 |     logger.will_write(str(arg) + "\n")
93 |
94 |     os.environ["CUDA_VISIBLE_DEVICES"] = arg.gpus
95 |     device = torch.device("cuda")
96 |     train_loader, val_loader = get_loaders(arg.root, arg.batch_size, arg.num_workers,
97 |                                            dtype=arg.dtype)
98 |
99 |     if arg.model == "mixs":
100 |         net = mixnet_s(num_classes=len(train_loader.dataset.classes))
101 |     elif arg.model == "rw":
102 |         import sys
103 |         sys.path.append("rwightman")
104 |         from timm.models.gen_efficientnet import mixnet_s
105 |         net = mixnet_s(num_classes=len(train_loader.dataset.classes))
106 |     else:
107 |         from torchvision.models import resnet50
108 |         net = resnet50(num_classes=len(train_loader.dataset.classes))
109 |
110 |     net = nn.DataParallel(net)
111 |     loss = nn.CrossEntropyLoss()
112 |
113 |     scaled_lr = arg.lr * arg.batch_size / 256
114 |     optim = {
115 |         # NOTE(review): Adam ignores --lr/--beta/--decay and uses PyTorch's
116 |         # defaults. Passing `scaled_lr` would change the default run, since
117 |         # 0.016 * batch_size / 256 is an RMSprop-sized rate -- confirm intent.
118 |         "adam" : lambda : torch.optim.Adam(net.parameters()),
119 |         "rmsprop" : lambda : torch.optim.RMSprop(net.parameters(), lr=scaled_lr, momentum=arg.momentum, eps=arg.eps, weight_decay=arg.decay)
120 |     }[arg.optim]()
121 |
122 |     scheduler = get_scheduler(optim, arg.scheduler, int(2.4 * len(train_loader)), arg.epoch * len(train_loader))
123 |
124 |     if arg.ema:
125 |         Runner = EMARunner
126 |
127 |     run = Runner(arg.model, arg.save_dir, arg.epoch,
128 |                  net, optim, device, loss, logger, scheduler)
129 |     if arg.test is False:
130 |         run.train(train_loader, val_loader)
131 |     run.test(train_loader, val_loader)
132 |
--------------------------------------------------------------------------------
/runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import time
4 | from glob import glob
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class Runner():
14 |     """Train / validate / test loop for an `nn.DataParallel`-wrapped net.
15 |
16 |     Validation accuracy is tracked in `best_metric`; every improvement is
17 |     checkpointed to `save_dir`, and the last checkpoint found there (by
18 |     file name) is loaded on construction to resume training.
19 |     """
20 |
21 |     def __init__(self, model_type, save_dir, epochs, net, optim, device, loss, logger, scheduler=None):
22 |         self.model_type = model_type
23 |         self.save_dir = save_dir
24 |         self.epochs = epochs
25 |
26 |         self.logger = logger
27 |
28 |         self.device = device
29 |
30 |         self.net = net.to(device)
31 |
32 |         self.loss = loss
33 |         self.optim = optim
34 |         self.scheduler = scheduler
35 |
36 |         self.start_epoch = 0
37 |         self.best_metric = -1
38 |
39 |         self.load()
40 |
41 |     def save(self, epoch, filename="train"):
42 |         """Save current epoch model
43 |
44 |         Save Elements:
45 |             model_type : arg.model
46 |             start_epoch : epoch + 1, where a resumed run continues
47 |             network : network parameters
48 |             optimizer : optimizer parameters
49 |             scheduler : scheduler state (None without a scheduler)
50 |             best_metric : current best score
51 |
52 |         Parameters:
53 |             epoch : current epoch
54 |             filename : model save file name, without ".pth.tar"
55 |         """
56 |         torch.save({"model_type": self.model_type,
57 |                     "start_epoch": epoch + 1,
58 |                     "network": self.net.module.state_dict(),
59 |                     "optimizer": self.optim.state_dict(),
60 |                     "scheduler": self.scheduler.state_dict() if self.scheduler else None,
61 |                     "best_metric": self.best_metric
62 |                     }, self.save_dir + "/%s.pth.tar" % (filename))
63 |         print("Model saved %d epoch" % (epoch))
64 |
65 |     def load(self, filename=""):
66 |         """Load a checkpoint written by `save`.
67 |
68 |         Parameters:
69 |             filename : file name inside save_dir; "" picks the last
70 |                 "*.pth.tar" file in sorted order (nothing is loaded if
71 |                 there is none).
72 |
73 |         Raises:
74 |             ValueError: if the checkpoint belongs to another model_type.
75 |         """
76 |         if filename == "":
77 |             # load last epoch model
78 |             filenames = sorted(glob(self.save_dir + "/*.pth.tar"))
79 |             if len(filenames) == 0:
80 |                 print("Not Load")
81 |                 return
82 |             else:
83 |                 filename = os.path.basename(filenames[-1])
84 |
85 |         file_path = self.save_dir + "/" + filename
86 |         if os.path.exists(file_path) is True:
87 |             print("Load %s to %s File" % (self.save_dir, filename))
88 |             ckpoint = torch.load(file_path)
89 |             if ckpoint["model_type"] != self.model_type:
90 |                 raise ValueError("Ckpoint Model Type is %s" %
91 |                                  (ckpoint["model_type"]))
92 |
93 |             self.net.module.load_state_dict(ckpoint['network'])
94 |             self.optim.load_state_dict(ckpoint['optimizer'])
95 |             # checkpoints written before the scheduler was saved lack the key
96 |             if self.scheduler and ckpoint.get("scheduler") is not None:
97 |                 self.scheduler.load_state_dict(ckpoint["scheduler"])
98 |             self.start_epoch = ckpoint['start_epoch']
99 |             self.best_metric = ckpoint["best_metric"]
100 |             print("Load Model Type : %s, epoch : %d acc : %f" %
101 |                   (ckpoint["model_type"], self.start_epoch, self.best_metric))
102 |         else:
103 |             print("Load Failed, not exists file")
104 |
105 |     def train(self, train_loader, val_loader=None):
106 |         """Train from `start_epoch` to `epochs`, validating after each epoch."""
107 |         print("\nStart Train len :", len(train_loader.dataset))
108 |         for epoch in range(self.start_epoch, self.epochs):
109 |             self.net.train()
110 |             for i, (input_, target_) in enumerate(train_loader):
111 |                 input_ = input_.to(self.device, non_blocking=True)
112 |                 target_ = target_.to(self.device, non_blocking=True)
113 |
114 |                 out = self.net(input_)
115 |                 loss = self.loss(out, target_)
116 |
117 |                 self.optim.zero_grad()
118 |                 loss.backward()
119 |                 self.optim.step()
120 |                 # per-iteration schedule; must run after optim.step()
121 |                 if self.scheduler:
122 |                     self.scheduler.step()
123 |
124 |                 if (i % 50) == 0:
125 |                     self.logger.log_write("train", epoch=epoch, loss=loss.item())
126 |
127 |             if val_loader is not None:
128 |                 self.valid(epoch, val_loader)
129 |
130 |     @torch.no_grad()
131 |     def _get_acc(self, loader):
132 |         """Top-1 accuracy of the network over `loader`."""
133 |         correct = 0
134 |         self.net.eval()
135 |         for input_, target_ in loader:
136 |             out = self.net(input_.to(self.device, non_blocking=True))
137 |             # argmax of the logits == argmax of their softmax
138 |             idx = out.argmax(dim=1).cpu()
139 |             correct += (target_ == idx).sum().item()
140 |
141 |         return correct / len(loader.dataset)
142 |
143 |     def valid(self, epoch, val_loader):
144 |         """Log validation accuracy; checkpoint when it beats `best_metric`."""
145 |         acc = self._get_acc(val_loader)
146 |         self.logger.log_write("valid", epoch=epoch, acc=acc)
147 |
148 |         if acc > self.best_metric:
149 |             self.best_metric = acc
150 |             self.save(epoch, "epoch[%05d]_acc[%.4f]" % (epoch, acc))
151 |
152 |     def test(self, train_loader, val_loader):
153 |         """Reload the last checkpoint and log its train/valid accuracy."""
154 |         print("\n Start Test")
155 |         self.load()
156 |         train_acc = self._get_acc(train_loader)
157 |         valid_acc = self._get_acc(val_loader)
158 |         self.logger.log_write("test", fname="test", train_acc=train_acc, valid_acc=valid_acc)
159 |         return train_acc, valid_acc
160 |
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from collections import namedtuple
3 |
4 | BlockArgs = namedtuple('BlockArgs', [
5 |     'dw_ksize', 'expand_ksize', 'project_ksize', 'num_repeat',
6 |     'in_channels', 'out_channels', 'expand_ratio', 'id_skip',
7 |     'strides', 'se_ratio', 'swish', 'dilated',
8 | ])
9 |
10 |
11 | def round_filters(filters, depth_multiplier, depth_divisor, min_depth):
12 |     """Round number of filters based on depth depth_multiplier.
13 |
14 |     `filters * depth_multiplier` is rounded to the nearest multiple of
15 |     `depth_divisor` (at least `min_depth`, or `depth_divisor` when
16 |     `min_depth` is falsy) and bumped up one step if rounding down lost
17 |     more than 10%. A falsy `depth_multiplier` returns `filters` as is.
18 |     Mirrors `round_filters` of the official TF MixNet implementation
19 |     (tensorflow/tpu, models/official/mnasnet/mixnet).
20 |     """
21 |     if not depth_multiplier:
22 |         return filters
23 |
24 |     filters *= depth_multiplier
25 |     min_depth = min_depth or depth_divisor
26 |     new_filters = max(min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor)
27 |     # Make sure that round down does not go down by more than 10%.
28 |     if new_filters < 0.9 * filters:
29 |         new_filters += depth_divisor
30 |     return new_filters
31 |
32 |
33 | class MixnetDecoder:
34 |     """A class of Mixnet decoder to get model configuration."""
35 |
36 |     @staticmethod
37 |     def _decode_block_string(block_string, depth_multiplier, depth_divisor, min_depth):
38 |         """Gets a mixnet block through a string notation of arguments.
39 |
40 |         E.g. r2_k3_a1_p1_s22_e1_i32_o16_se0.25_noskip: r - number of repeat
41 |         blocks, k - depthwise kernel size(s) ("3.5.7" mixes several),
42 |         a - expand conv kernel size, p - project conv kernel size,
43 |         s - strides (two digits, 1-9 each), e - expansion ratio,
44 |         i - input filters, o - output filters, se - squeeze/excitation
45 |         ratio, noskip - no identity skip, sw - swish, dilated - dilated
46 |         depthwise conv.
47 |
48 |         Args:
49 |             block_string: a string, a string representation of block arguments.
50 |             depth_multiplier, depth_divisor, min_depth: passed to
51 |                 `round_filters` for the input/output filter counts.
52 |
53 |         Returns:
54 |             A BlockArgs instance.
55 |         Raises:
56 |             ValueError: if the strides option is not correctly specified.
57 |         """
58 |         assert isinstance(block_string, str)
59 |
60 |         ops = block_string.split('_')
61 |         options = {}
62 |         for op in ops:
63 |             splits = re.split(r'(\d.*)', op)
64 |             if len(splits) >= 2:
65 |                 key, value = splits[:2]
66 |                 options[key] = value
67 |
68 |         if 's' not in options or len(options['s']) != 2:
69 |             raise ValueError('Strides options should be a pair of integers.')
70 |
71 |         def _parse_ksize(ss):
72 |             ks = [int(k) for k in ss.split('.')]
73 |             return ks if len(ks) > 1 else ks[0]
74 |
75 |         return BlockArgs(num_repeat=int(options['r']),
76 |                          dw_ksize=_parse_ksize(options['k']),
77 |                          expand_ksize=_parse_ksize(options['a']),
78 |                          project_ksize=_parse_ksize(options['p']),
79 |                          strides=[int(options['s'][0]), int(options['s'][1])],
80 |                          expand_ratio=int(options['e']),
81 |                          in_channels=round_filters(int(options['i']), depth_multiplier, depth_divisor, min_depth),
82 |                          out_channels=round_filters(int(options['o']), depth_multiplier, depth_divisor, min_depth),
83 |                          id_skip=('noskip' not in block_string),
84 |                          se_ratio=float(options['se']) if 'se' in options else 0,
85 |                          swish=('sw' in block_string),
86 |                          dilated=('dilated' in block_string)
87 |                          )
88 |
89 |     @staticmethod
90 |     def _encode_block_string(block):
91 |         """Encodes a Mixnet block to a string."""
92 |
93 |         def _encode_ksize(arr):
94 |             # a single kernel size is decoded to a bare int, not a list
95 |             if isinstance(arr, int):
96 |                 return str(arr)
97 |             return '.'.join([str(k) for k in arr])
98 |
99 |         args = [
100 |             'r%d' % block.num_repeat,
101 |             'k%s' % _encode_ksize(block.dw_ksize),
102 |             'a%s' % _encode_ksize(block.expand_ksize),
103 |             'p%s' % _encode_ksize(block.project_ksize),
104 |             's%d%d' % (block.strides[0], block.strides[1]),
105 |             'e%s' % block.expand_ratio,
106 |             'i%d' % block.in_channels,
107 |             'o%d' % block.out_channels
108 |         ]
109 |
110 |         if (block.se_ratio is not None and block.se_ratio > 0 and block.se_ratio <= 1):
111 |             args.append('se%s' % block.se_ratio)
112 |         if block.id_skip is False:
113 |             args.append('noskip')
114 |         if block.swish:
115 |             args.append('sw')
116 |         if block.dilated:
117 |             args.append('dilated')
118 |         return '_'.join(args)
119 |
120 |     @staticmethod
121 |     def decode(string_list, depth_multiplier, depth_divisor, min_depth):
122 |         """Decodes a list of string notations to specify blocks inside the network.
123 |
124 |         Args:
125 |             string_list: a list of strings, each string is a notation of Mixnet
126 |                 block.
127 |             depth_multiplier, depth_divisor, min_depth: see `round_filters`.
128 |
129 |         Returns:
130 |             A list of namedtuples to represent Mixnet blocks arguments.
131 |         """
132 |         assert isinstance(string_list, list)
133 |         blocks_args = []
134 |         for block_string in string_list:
135 |             blocks_args.append(MixnetDecoder._decode_block_string(block_string, depth_multiplier, depth_divisor, min_depth))
136 |         return blocks_args
137 |
138 |     @staticmethod
139 |     def encode(blocks_args):
140 |         """Encodes a list of Mixnet Blocks to a list of strings.
141 |
142 |         Args:
143 |             blocks_args: A list of namedtuples to represent Mixnet blocks arguments.
144 |         Returns:
145 |             a list of strings, each string is a notation of Mixnet block.
146 |         """
147 |         block_strings = []
148 |         for block in blocks_args:
149 |             block_strings.append(MixnetDecoder._encode_block_string(block))
150 |         return block_strings
151 |
--------------------------------------------------------------------------------
/ema_runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import time
4 | from glob import glob
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class EMARunner:
14 |     """Runner that also keeps an exponential moving average (EMA) of the
15 |     weights and uses it for evaluation.
16 |
17 |     The EMA copy lives on the CPU and is updated after every optimizer
18 |     step; checkpoints store both the trained network and the EMA copy.
19 |     Arguments are those of `runner.Runner`, plus:
20 |         ema_decay : weight kept from the old EMA value at each update.
21 |     """
22 |
23 |     def __init__(self, model_type, save_dir, epochs, net, optim, device, loss, logger, scheduler=None,
24 |                  ema_decay=0.999):
25 |         self.save_dir = save_dir
26 |         self.model_type = model_type
27 |         self.epochs = epochs
28 |
29 |         self.logger = logger
30 |
31 |         self.device = device
32 |
33 |         # `net.module.cpu()` moves the wrapped module to the CPU in place;
34 |         # the frozen copy stays there and `net` goes back to `device` below.
35 |         self.ema = copy.deepcopy(net.module.cpu())
36 |         self.ema.eval()
37 |         for p in self.ema.parameters():
38 |             p.requires_grad_(False)
39 |         self.ema_decay = ema_decay
40 |
41 |         self.net = net.to(device)
42 |         self.loss = loss
43 |         self.optim = optim
44 |         self.scheduler = scheduler
45 |
46 |         self.start_epoch = 0
47 |         self.best_metric = -1
48 |
49 |         self.load()
50 |
51 |     def save(self, epoch, filename="train"):
52 |         """Save current epoch model
53 |
54 |         Save Elements:
55 |             model_type : arg.model
56 |             start_epoch : epoch + 1, where a resumed run continues
57 |             network : network parameters
58 |             ema : EMA copy parameters
59 |             optimizer : optimizer parameters
60 |             scheduler : scheduler state (None without a scheduler)
61 |             best_metric : current best score
62 |         Parameters:
63 |             epoch : current epoch
64 |             filename : model save file name, without ".pth.tar"
65 |         """
66 |         torch.save({"model_type": self.model_type,
67 |                     "start_epoch": epoch + 1,
68 |                     "network": self.net.module.state_dict(),
69 |                     "ema": self.ema.state_dict(),
70 |                     "optimizer": self.optim.state_dict(),
71 |                     "scheduler": self.scheduler.state_dict() if self.scheduler else None,
72 |                     "best_metric": self.best_metric
73 |                     }, self.save_dir + "/%s.pth.tar" % (filename))
74 |         print("Model saved %d epoch" % (epoch))
75 |
76 |     def load(self, filename=""):
77 |         """Load a checkpoint written by `save`.
78 |
79 |         Parameters:
80 |             filename : file name inside save_dir; "" picks the last
81 |                 "*.pth.tar" file in sorted order (nothing is loaded if
82 |                 there is none).
83 |
84 |         Raises:
85 |             ValueError: if the checkpoint belongs to another model_type.
86 |         """
87 |         if filename == "":
88 |             # load last epoch model
89 |             filenames = sorted(glob(self.save_dir + "/*.pth.tar"))
90 |             if len(filenames) == 0:
91 |                 print("Not Load")
92 |                 return
93 |             else:
94 |                 filename = os.path.basename(filenames[-1])
95 |
96 |         file_path = self.save_dir + "/" + filename
97 |         if os.path.exists(file_path) is True:
98 |             print("Load %s to %s File" % (self.save_dir, filename))
99 |             ckpoint = torch.load(file_path)
100 |             if ckpoint["model_type"] != self.model_type:
101 |                 raise ValueError("Ckpoint Model Type is %s" %
102 |                                  (ckpoint["model_type"]))
103 |
104 |             self.net.module.load_state_dict(ckpoint['network'])
105 |             # a checkpoint written without EMA starts the EMA from the net
106 |             self.ema.load_state_dict(ckpoint.get('ema', ckpoint['network']))
107 |             self.optim.load_state_dict(ckpoint['optimizer'])
108 |             # checkpoints written before the scheduler was saved lack the key
109 |             if self.scheduler and ckpoint.get("scheduler") is not None:
110 |                 self.scheduler.load_state_dict(ckpoint["scheduler"])
111 |             self.start_epoch = ckpoint['start_epoch']
112 |             self.best_metric = ckpoint["best_metric"]
113 |             print("Load Model Type : %s, epoch : %d acc : %f" %
114 |                   (ckpoint["model_type"], self.start_epoch, self.best_metric))
115 |         else:
116 |             print("Load Failed, not exists file")
117 |
118 |     @torch.no_grad()
119 |     def update_ema(self):
120 |         """ema = decay * ema + (1 - decay) * net, entry by entry.
121 |
122 |         Floating-point entries (weights, BN running stats) are averaged;
123 |         integer buffers such as BN's `num_batches_tracked` are copied.
124 |         """
125 |         net_state = self.net.module.state_dict()
126 |         ema_state = self.ema.state_dict()
127 |         for k, v in ema_state.items():
128 |             net_v = net_state[k].detach().cpu()
129 |             if v.dtype.is_floating_point:
130 |                 v.copy_(v * self.ema_decay + net_v * (1 - self.ema_decay))
131 |             else:
132 |                 v.copy_(net_v)
133 |
134 |     def train(self, train_loader, val_loader=None):
135 |         """Train from `start_epoch` to `epochs`, updating the EMA every step."""
136 |         print("\nStart Train len :", len(train_loader.dataset))
137 |         for epoch in range(self.start_epoch, self.epochs):
138 |             self.net.train()
139 |             for i, (input_, target_) in enumerate(train_loader):
140 |                 target_ = target_.to(self.device, non_blocking=True)
141 |                 input_ = input_.to(self.device)
142 |
143 |                 out = self.net(input_)
144 |                 loss = self.loss(out, target_)
145 |
146 |                 self.optim.zero_grad()
147 |                 loss.backward()
148 |                 self.optim.step()
149 |                 # per-iteration schedule; must run after optim.step()
150 |                 if self.scheduler:
151 |                     self.scheduler.step()
152 |                 self.update_ema()
153 |
154 |                 if (i % 50) == 0:
155 |                     self.logger.log_write("train", epoch=epoch, loss=loss.item())
156 |
157 |             if val_loader is not None:
158 |                 self.valid(epoch, val_loader)
159 |
160 |     @torch.no_grad()
161 |     def _get_acc(self, loader):
162 |         """Top-1 accuracy of the EMA model (evaluated on the CPU) over `loader`."""
163 |         correct = 0
164 |         self.ema.eval()
165 |         for input_, target_ in loader:
166 |             out = self.ema(input_)
167 |             _, idx = out.max(dim=1)
168 |             correct += (target_ == idx).sum().item()
169 |
170 |         return correct / len(loader.dataset)
171 |
172 |     def valid(self, epoch, val_loader):
173 |         """Log EMA validation accuracy; checkpoint when it beats `best_metric`."""
174 |         acc = self._get_acc(val_loader)
175 |         self.logger.log_write("valid", epoch=epoch, acc=acc)
176 |
177 |         if acc > self.best_metric:
178 |             self.best_metric = acc
179 |             self.save(epoch, "epoch[%05d]_acc[%.4f]" % (
180 |                 epoch, acc))
181 |
182 |     def test(self, train_loader, val_loader):
183 |         """Reload the last checkpoint and log the EMA train/valid accuracy."""
184 |         print("\n Start Test")
185 |         self.load()
186 |         train_acc = self._get_acc(train_loader)
187 |         valid_acc = self._get_acc(val_loader)
188 |         self.logger.log_write("test", fname="test", train_acc=train_acc, valid_acc=valid_acc)
189 |         return train_acc, valid_acc
190 |
-------------------------------------------------------------------------------- /models/mixnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.utils import BlockArgs 5 | from models.utils import round_filters 6 | from models.utils import MixnetDecoder 7 | 8 | from models.layers import SEModule 9 | from models.layers import Swish 10 | from models.layers import Flatten 11 | 12 | from models.mdconv import MDConv 13 | 14 | 15 | class MixBlock(nn.Module): 16 | def __init__(self, dw_ksize, expand_ksize, project_ksize, 17 | in_channels, out_channels, expand_ratio, id_skip, 18 | strides, se_ratio, swish, dilated): 19 | super().__init__() 20 | 21 | self.id_skip = id_skip and all(s == 1 for s in strides) and in_channels == out_channels 22 | 23 | act_fn = lambda : Swish() if swish else nn.ReLU(True) 24 | 25 | layers = [] 26 | expaned_ch = in_channels * expand_ratio 27 | if expand_ratio != 1: 28 | expand = nn.Sequential( 29 | nn.Conv2d(in_channels, expaned_ch, expand_ksize, bias=False), 30 | nn.BatchNorm2d(expaned_ch), 31 | act_fn(), 32 | ) 33 | layers.append(expand) 34 | 35 | depthwise = nn.Sequential( 36 | MDConv(expaned_ch, dw_ksize, strides, bias=False), 37 | nn.BatchNorm2d(expaned_ch), 38 | act_fn(), 39 | ) 40 | layers.append(depthwise) 41 | 42 | if se_ratio > 0: 43 | se = SEModule(expaned_ch, int(expaned_ch * se_ratio)) 44 | layers.append(se) 45 | 46 | project = nn.Sequential( 47 | nn.Conv2d(expaned_ch, out_channels, project_ksize, bias=False), 48 | nn.BatchNorm2d(out_channels), 49 | ) 50 | layers.append(project) 51 | 52 | self.layers = nn.Sequential(*layers) 53 | 54 | def forward(self, x): 55 | out = self.layers(x) 56 | if self.id_skip: 57 | out = out + x 58 | return out 59 | 60 | 61 | class MixModule(nn.Module): 62 | def __init__(self, dw_ksize, expand_ksize, project_ksize, num_repeat, 63 | in_channels, out_channels, expand_ratio, id_skip, 64 | strides, se_ratio, swish, 
dilated): 65 | super().__init__() 66 | layers = [MixBlock(dw_ksize, expand_ksize, project_ksize, 67 | in_channels, out_channels, expand_ratio, id_skip, 68 | strides, se_ratio, swish, dilated)] 69 | 70 | for _ in range(num_repeat - 1): 71 | layers.append(MixBlock(dw_ksize, expand_ksize, project_ksize, 72 | in_channels, out_channels, expand_ratio, id_skip, 73 | [1, 1], se_ratio, swish, dilated)) 74 | self.layers = nn.Sequential(*layers) 75 | 76 | def forward(self, x): 77 | return self.layers(x) 78 | 79 | 80 | class MixNet(nn.Module): 81 | def __init__(self, stem, blocks_args, head, dropout_rate, num_classes=1000): 82 | super().__init__() 83 | 84 | self.stem = nn.Sequential( 85 | nn.Conv2d(3, stem, 3, 2, 1, bias=False), 86 | nn.BatchNorm2d(stem), 87 | nn.ReLU(True) 88 | ) 89 | 90 | self.blocks = nn.Sequential(*[MixModule(*args) for args in blocks_args]) 91 | 92 | self.classifier = nn.Sequential( 93 | nn.Conv2d(blocks_args[-1].out_channels, head, 1, bias=False), 94 | nn.BatchNorm2d(head), 95 | nn.ReLU(True), 96 | nn.AdaptiveAvgPool2d(1), 97 | Flatten(), 98 | nn.Dropout(dropout_rate), 99 | nn.Linear(head, num_classes) 100 | ) 101 | 102 | self.init_weights() 103 | 104 | def init_weights(self): 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu') 108 | elif isinstance(m, nn.Linear): 109 | nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') 110 | 111 | def forward(self, x): 112 | # print("Input : ", x.shape) 113 | stem = self.stem(x) 114 | # print("Stem : ", x.shape) 115 | feature = self.blocks(stem) 116 | # print("feature : ", feature.shape) 117 | out = self.classifier(feature) 118 | return out 119 | 120 | 121 | def mixnet_s(depth_multiplier=1, depth_divisor=8, min_depth=None, num_classes=1000): 122 | """ 123 | Creates mixnet-s model. 124 | 125 | Args: 126 | depth_multiplier: depth_multiplier to number of filters per layer. 
127 | """ 128 | stem = round_filters(16, depth_multiplier, depth_divisor, min_depth) 129 | head = round_filters(1536, depth_multiplier, depth_divisor, min_depth) 130 | dropout = 0.2 131 | 132 | blocks_args = [ 133 | 'r1_k3_a1_p1_s11_e1_i16_o16', 134 | 'r1_k3_a1.1_p1.1_s22_e6_i16_o24', 135 | 'r1_k3_a1.1_p1.1_s11_e3_i24_o24', 136 | 137 | 'r1_k3.5.7_a1_p1_s22_e6_i24_o40_se0.5_sw', 138 | 'r3_k3.5_a1.1_p1.1_s11_e6_i40_o40_se0.5_sw', 139 | 140 | 'r1_k3.5.7_a1_p1.1_s22_e6_i40_o80_se0.25_sw', 141 | 'r2_k3.5_a1_p1.1_s11_e6_i80_o80_se0.25_sw', 142 | 143 | 'r1_k3.5.7_a1.1_p1.1_s11_e6_i80_o120_se0.5_sw', 144 | 'r2_k3.5.7.9_a1.1_p1.1_s11_e3_i120_o120_se0.5_sw', 145 | 146 | 'r1_k3.5.7.9.11_a1_p1_s22_e6_i120_o200_se0.5_sw', 147 | 'r2_k3.5.7.9_a1_p1.1_s11_e6_i200_o200_se0.5_sw', 148 | ] 149 | 150 | blocks_args = MixnetDecoder.decode(blocks_args, depth_multiplier, depth_divisor, min_depth) 151 | print("-----------") 152 | print("Mixnet S") 153 | for a in blocks_args: 154 | print(a) 155 | print("-----------") 156 | return MixNet(stem, blocks_args, head, dropout, num_classes=num_classes) 157 | 158 | 159 | if __name__ == "__main__": 160 | mixnet_s() 161 | --------------------------------------------------------------------------------