├── .vscode
│   ├── settings.json
│   └── launch.json
├── models
│   ├── layers.py
│   ├── mdconv.py
│   ├── utils.py
│   └── mixnet.py
├── readme.md
├── loader.py
├── .gitignore
├── logger.py
├── main.py
├── runner.py
└── ema_runner.py
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "workbench.colorCustomizations": {}
3 | }
--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Swish(nn.Module):
6 | def forward(self, x):
7 | return x * torch.sigmoid(x)
8 |
9 |
10 | class Flatten(nn.Module):
11 | def forward(self, x):
12 | return x.view(x.shape[0], -1)
13 |
14 |
15 | class SEModule(nn.Module):
16 | def __init__(self, ch, squeeze_ch):
17 | super().__init__()
18 | self.se = nn.Sequential(
19 | nn.AdaptiveAvgPool2d(1),
20 | nn.Conv2d(ch, squeeze_ch, 1, 1, 0, bias=True),
21 | Swish(),
22 | nn.Conv2d(squeeze_ch, ch, 1, 1, 0, bias=True),
23 | )
24 |
25 | def forward(self, x):
26 | return x * torch.sigmoid(self.se(x))
27 |
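28 | if __name__ == "__main__":
29 |     # Hedged smoke test (not part of the reference implementation): SE rescales
30 |     # channels while preserving shape; Flatten collapses all dims after batch.
31 |     x = torch.randn(2, 16, 8, 8)
32 |     assert SEModule(16, 4)(x).shape == x.shape
33 |     assert Flatten()(x).shape == (2, 16 * 8 * 8)
34 |     print("layers.py smoke test passed")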
--------------------------------------------------------------------------------
/models/mdconv.py:
--------------------------------------------------------------------------------
1 | # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def _split_channels(total_filters, num_groups):
8 | """
9 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py#L33
10 | """
11 | split = [total_filters // num_groups for _ in range(num_groups)]
12 | split[0] += total_filters - sum(split)
13 | return split
14 |
15 |
16 | class MDConv(nn.Module):
17 |     def __init__(self, in_channels, kernel_sizes, stride, dilated=False, bias=False):
18 | super().__init__()
19 |
20 | if not isinstance(kernel_sizes, list):
21 | kernel_sizes = [kernel_sizes]
22 |
23 | self.in_channels = _split_channels(in_channels, len(kernel_sizes))
24 |
25 | self.convs = nn.ModuleList()
26 | for ch, k in zip(self.in_channels, kernel_sizes):
27 |             dilation = 1
28 |             if stride[0] == 1 and dilated:
29 |                 dilation, k = (k - 1) // 2, 3  # swap the large kernel for an equivalent dilated 3x3
30 |                 print("Use dilated conv with dilation rate = {}".format(dilation))
31 |             pad = ((stride[0] - 1) + dilation * (k - 1)) // 2
32 |
33 | conv = nn.Conv2d(ch, ch, k, stride, pad, dilation,
34 | groups=ch, bias=bias)
35 | self.convs.append(conv)
36 |
37 | def forward(self, x):
38 | xs = torch.split(x, self.in_channels, 1)
39 | return torch.cat([conv(x) for conv, x in zip(self.convs, xs)], 1)
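40 |
41 |
42 | if __name__ == "__main__":
43 |     # Hedged smoke test: 24 channels are split over kernels 3/5/7 as [8, 8, 8],
44 |     # each group gets its own depthwise conv, and the outputs are concatenated.
45 |     print(_split_channels(24, 3))  # [8, 8, 8]
46 |     conv = MDConv(24, [3, 5, 7], stride=[1, 1])
47 |     x = torch.randn(2, 24, 16, 16)
48 |     print(conv(x).shape)  # torch.Size([2, 24, 16, 16])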
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Mixnet
2 |
3 | A PyTorch implementation of `MixNet: Mixed Depthwise Convolutional Kernels`.
4 |
5 |
6 | ### [[arxiv]](https://arxiv.org/abs/1907.09595) [[Official TF Repo]](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet)
7 |
8 |
9 |
10 | ## Note
11 |
12 | EMA currently runs on the CPU, so the EMA runner is slower than the normal runner.
13 |
14 | If you are running on a GPU, change these lines: [init](ema_runner.py#23), [update_ema](ema_runner.py#96) (see the commented sketch at the end of `ema_runner.py`).
15 |
16 |
17 |
18 | ## How to use:
19 |
20 | ```
21 | python3 main.py -h
22 | usage: main.py [-h] --save_dir SAVE_DIR [--dtype {cifar10,cifar100,imagenet}]
23 |                [--ema] [--root ROOT] [--gpus GPUS]
24 |                [--num_workers NUM_WORKERS] [--model MODEL] [--epoch EPOCH]
25 |                [--batch_size BATCH_SIZE] [--test] [--optim {rmsprop,adam}]
26 |                [--lr LR] [--beta [BETA [BETA ...]]] [--momentum MOMENTUM]
27 |                [--eps EPS] [--decay DECAY] [--scheduler {exp,cosine,none}]
28 |
29 | Pytorch Mixnet
30 |
31 | optional arguments:
32 |   -h, --help            show this help message and exit
33 |   --save_dir SAVE_DIR   Directory name to save the model
34 |   --dtype {cifar10,cifar100,imagenet}
35 |   --ema                 Exponential Moving Average
36 |   --root ROOT           The Directory of data path.
37 |   --gpus GPUS           Select GPU Numbers | 0,1,2,3 |
38 |   --num_workers NUM_WORKERS
39 |                         Select CPU Number workers
40 |   --model MODEL         The type of mixnet.
41 |   --epoch EPOCH         The number of epochs
42 |   --batch_size BATCH_SIZE
43 |                         The size of batch
44 |   --test                Only Test
45 |   --optim {rmsprop,adam}
46 |   --lr LR               Base learning rate when train batch size is 256.
47 |   --beta [BETA [BETA ...]]
48 |   --momentum MOMENTUM
49 |   --eps EPS
50 |   --decay DECAY
51 |   --scheduler {exp,cosine,none}
52 |                         Learning rate scheduler type
53 | ```
54 |
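55 | For example, a single-GPU CIFAR-10 run with EMA (all flags as defined in `main.py`):
56 |
57 | ```
58 | python3 main.py --save_dir mixs_cifar10 --root data --dtype cifar10 --gpus 0 \
59 |                 --batch_size 128 --epoch 100 --optim rmsprop --scheduler exp --ema
60 | ```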
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import MNIST
2 | from torchvision.datasets import CIFAR10, CIFAR100
3 | from torchvision.datasets import ImageFolder
4 |
5 | from torch.utils.data import DataLoader
6 |
7 | from torchvision import transforms as T
8 |
9 |
10 | def get_dataset(root, dtype="cifar10", resl=224):
11 |     tr = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])  # MNIST stats, reused as-is for CIFAR
12 | if dtype == "mnist":
13 | dset = MNIST
14 | elif dtype == "cifar10":
15 | dset = CIFAR10
16 | elif dtype == "cifar100":
17 | dset = CIFAR100
18 | elif dtype == "imagenet":
19 | return imagenet(root, resl)
20 |     else: raise ValueError("Unknown dtype: %s" % dtype)
21 | train = dset(root, True, transform=tr, download=True)
22 | valid = dset(root, False, transform=tr)
23 | return train, valid
24 |
25 |
26 | def imagenet(root, resl):
27 | normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
28 | std=[0.229, 0.224, 0.225])
29 | train = ImageFolder(
30 | root + "/train",
31 | T.Compose([
32 | T.Resize([resl, resl]),
33 | T.RandomResizedCrop(resl),
34 | T.RandomHorizontalFlip(),
35 | T.ToTensor(),
36 | normalize,
37 | ])
38 | )
39 |
40 | valid = ImageFolder(
41 | root + "/val",
42 | T.Compose([
43 | T.Resize([resl, resl]),
44 | T.ToTensor(),
45 | normalize,
46 | ])
47 | )
48 |
49 | return train, valid
50 |
51 |
52 | def get_loaders(root, batch_size, num_workers=32, dtype="cifar10", resl=224):
53 | train, valid = get_dataset(root, dtype, resl)
54 |
55 | train_loader = DataLoader(train,
56 | batch_size=batch_size, shuffle=True,
57 | num_workers=num_workers, pin_memory=True
58 | )
59 |
60 | val_loader = DataLoader(valid,
61 | batch_size=batch_size, shuffle=False,
62 | num_workers=num_workers, pin_memory=True
63 | )
64 | return train_loader, val_loader
65 |
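66 | if __name__ == "__main__":
67 |     # Hedged smoke test: CIFAR-10 loaders (downloads to ./data on first run).
68 |     train_loader, val_loader = get_loaders("data", batch_size=8, num_workers=0)
69 |     images, labels = next(iter(train_loader))
70 |     print(images.shape, labels.shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8])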
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: Current File (Integrated Terminal)",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "${file}",
12 |             "console": "integratedTerminal"
13 | },
14 | {
15 | "name": "Python: Remote Attach",
16 | "type": "python",
17 | "request": "attach",
18 | "port": 5678,
19 | "host": "localhost",
20 | "pathMappings": [
21 | {
22 | "localRoot": "${workspaceFolder}",
23 | "remoteRoot": "."
24 | }
25 | ]
26 | },
27 | {
28 | "name": "Python: Module",
29 | "type": "python",
30 | "request": "launch",
31 | "module": "enter-your-module-name-here",
32 | "console": "integratedTerminal"
33 | },
34 | {
35 | "name": "Python: Django",
36 | "type": "python",
37 | "request": "launch",
38 | "program": "${workspaceFolder}/manage.py",
39 | "console": "integratedTerminal",
40 | "args": [
41 | "runserver",
42 | "--noreload",
43 | "--nothreading"
44 | ],
45 | "django": true
46 | },
47 | {
48 | "name": "Python: Flask",
49 | "type": "python",
50 | "request": "launch",
51 | "module": "flask",
52 | "env": {
53 | "FLASK_APP": "app.py"
54 | },
55 | "args": [
56 | "run",
57 | "--no-debugger",
58 | "--no-reload"
59 | ],
60 | "jinja": true
61 | },
62 | {
63 | "name": "Python: Current File (External Terminal)",
64 | "type": "python",
65 | "request": "launch",
66 | "program": "${file}",
67 | "console": "externalTerminal"
68 | }
69 | ]
70 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | tmp.md
2 | tmp
3 | tmp/*
4 | data
5 | data/*
6 | outs
7 | outs/*
8 |
9 | .DS_Store
10 | *.ipynb
11 | *.npy
12 | *.jpg
13 | *.mat
14 | *.png
15 | *.pyc
16 | *.ubuntu
17 |
18 | ##################################################################
19 | # github python ignore
20 | # https://github.com/github/gitignore/blob/master/Python.gitignore
21 |
22 | # Byte-compiled / optimized / DLL files
23 | __pycache__/
24 | *.py[cod]
25 | *$py.class
26 |
27 | # C extensions
28 | *.so
29 |
30 | # Distribution / packaging
31 | .Python
32 | build/
33 | develop-eggs/
34 | dist/
35 | downloads/
36 | eggs/
37 | .eggs/
38 | lib/
39 | lib64/
40 | parts/
41 | sdist/
42 | var/
43 | wheels/
44 | pip-wheel-metadata/
45 | share/python-wheels/
46 | *.egg-info/
47 | .installed.cfg
48 | *.egg
49 | MANIFEST
50 |
51 | # PyInstaller
52 | # Usually these files are written by a python script from a template
53 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
54 | *.manifest
55 | *.spec
56 |
57 | # Installer logs
58 | pip-log.txt
59 | pip-delete-this-directory.txt
60 |
61 | # Unit test / coverage reports
62 | htmlcov/
63 | .tox/
64 | .nox/
65 | .coverage
66 | .coverage.*
67 | .cache
68 | nosetests.xml
69 | coverage.xml
70 | *.cover
71 | .hypothesis/
72 | .pytest_cache/
73 |
74 | # Translations
75 | *.mo
76 | *.pot
77 |
78 | # Django stuff:
79 | *.log
80 | local_settings.py
81 | db.sqlite3
82 |
83 | # Flask stuff:
84 | instance/
85 | .webassets-cache
86 |
87 | # Scrapy stuff:
88 | .scrapy
89 |
90 | # Sphinx documentation
91 | docs/_build/
92 |
93 | # PyBuilder
94 | target/
95 |
96 | # Jupyter Notebook
97 | .ipynb_checkpoints
98 |
99 | # IPython
100 | profile_default/
101 | ipython_config.py
102 |
103 | # pyenv
104 | .python-version
105 |
106 | # pipenv
107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
109 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
110 | # install all needed dependencies.
111 | #Pipfile.lock
112 |
113 | # celery beat schedule file
114 | celerybeat-schedule
115 |
116 | # SageMath parsed files
117 | *.sage.py
118 |
119 | # Environments
120 | .env
121 | .venv
122 | env/
123 | venv/
124 | ENV/
125 | env.bak/
126 | venv.bak/
127 |
128 | # Spyder project settings
129 | .spyderproject
130 | .spyproject
131 |
132 | # Rope project settings
133 | .ropeproject
134 |
135 | # mkdocs documentation
136 | /site
137 |
138 | # mypy
139 | .mypy_cache/
140 | .dmypy.json
141 | dmypy.json
142 |
143 | # Pyre type checker
144 | .pyre/
145 | ################################################################
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from cycler import cycler
4 | from collections import OrderedDict
5 |
6 | import matplotlib.pyplot as plt
7 |
8 | import numpy as np
9 |
10 |
11 | # x axis of plot
12 | LOG_KEYS = {
13 | "train":"epoch",
14 | "valid":"epoch",
15 | "test": "fname"
16 | }
17 |
18 | # y axis of plot
19 | # stores values such as loss, f1-score, PSNR, SSIM, ...
20 | # multiple values per learn type are allowed
21 | LOG_VALUES = {
22 | "train":["loss", ],
23 | "valid":["acc","valid_acc"],
24 | "test": ["train_acc", "valid_acc"]
25 | }
26 |
27 |
28 | class Logger:
29 | def __init__(self, save_dir):
30 | self.save_dir = save_dir
31 | self.log_file = save_dir + "/log.txt"
32 | self.buffers = []
33 |
34 | def will_write(self, line):
35 | print(line)
36 | self.buffers.append(line)
37 |
38 | def flush(self):
39 | with open(self.log_file, "a", encoding="utf-8") as f:
40 | f.write("\n".join(self.buffers))
41 | f.write("\n")
42 | self.buffers = []
43 |
44 | def write(self, line):
45 | self.will_write(line)
46 | self.flush()
47 |
48 | def log_write(self, learn_type, **values):
49 | """log write in buffers
50 |
51 | ex ) log_write("train", epoch=1, loss=0.3)
52 |
53 |         Parameters:
54 |             learn_type : one of "train", "valid" or "test"
55 |             values : keyword arguments whose keys appear in LOG_VALUES[learn_type]
56 | """
57 | for k in values.keys():
58 | if k not in LOG_VALUES[learn_type] and k != LOG_KEYS[learn_type]:
59 |                 raise KeyError("%s log: unexpected key %s" % (learn_type, k))
60 |
61 | log = "[%s] %s" % (learn_type, json.dumps(values))
62 | self.will_write(log)
63 | if learn_type != "train":
64 | self.flush()
65 |
66 | def log_parse(self, log_key):
67 | log_dict = OrderedDict()
68 | with open(self.log_file, "r", encoding="utf-8") as f:
69 | for line in f.readlines():
70 | if len(line) == 1 or not line.startswith("[%s]" % (log_key)):
71 | continue
72 |                 # line looks like '[train] {"epoch": 1, "loss": 0.3}'
73 |                 line = line[line.find("] ") + 2:]  # keep only the JSON payload
74 | line_log = json.loads(line)
75 |
76 | train_log_key = line_log[LOG_KEYS[log_key]]
77 | line_log.pop(LOG_KEYS[log_key], None)
78 | log_dict[train_log_key] = line_log
79 | return log_dict
80 |
81 |     def log_plot(self, log_key,
82 |                  figsize=(12, 12), title="plot"):
83 | fig = plt.figure(figsize=figsize)
84 | plt.title(title)
85 |         # the legend is drawn per-axis below, once the lines exist
86 |
87 | ax = plt.subplot(111)
88 | colors = plt.cm.nipy_spectral(np.linspace(0.1, 0.9, len(LOG_VALUES[log_key])))
89 | ax.set_prop_cycle(cycler('color', colors))
90 |
91 | log_dict = self.log_parse(log_key)
92 |         x = list(log_dict.keys())
93 | for keys in LOG_VALUES[log_key]:
94 | if keys not in list(log_dict.values())[0]:
95 | continue
96 | y = [v[keys] for v in log_dict.values()]
97 |
98 | label = keys + ", max : %f" % (max(y))
99 | ax.plot(x, y, marker="o", linestyle="solid", label=label)
100 | if max(y) > 1:
101 | ax.set_ylim([min(y) - 1, y[0] + 1])
102 | ax.legend(fontsize=30)
103 |
104 | plt.show()
105 |
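106 | if __name__ == "__main__":
107 |     # Hedged usage sketch: append two train entries to ./log.txt, then parse them back.
108 |     logger = Logger(".")
109 |     logger.log_write("train", epoch=0, loss=0.9)
110 |     logger.log_write("train", epoch=1, loss=0.3)
111 |     logger.flush()
112 |     print(logger.log_parse("train"))  # OrderedDict([(0, {'loss': 0.9}), (1, {'loss': 0.3})])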
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import argparse
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from torch.optim.lr_scheduler import StepLR
9 | from torch.optim.lr_scheduler import CosineAnnealingLR
10 |
11 | from models.mixnet import mixnet_s
12 |
13 | from ema_runner import EMARunner
14 | from runner import Runner
15 |
16 | from loader import get_loaders
17 |
18 | from logger import Logger
19 |
20 |
21 | def arg_parse():
22 | # projects description
23 | desc = "Pytorch Mixnet"
24 | parser = argparse.ArgumentParser(description=desc)
25 | parser.add_argument('--save_dir', type=str, required=True,
26 | help='Directory name to save the model')
27 |
28 |     parser.add_argument('--dtype', type=str, default="cifar10", choices=["cifar10", "cifar100", "imagenet"])
29 | parser.add_argument('--ema', action="store_true", help="Exponential Moving Average")
30 |
31 | parser.add_argument('--root', type=str, default="/data1/imagenet",
32 | help="The Directory of data path.")
33 | parser.add_argument('--gpus', type=str, default="0,1,2,3",
34 | help="Select GPU Numbers | 0,1,2,3 | ")
35 | parser.add_argument('--num_workers', type=int, default=32,
36 | help="Select CPU Number workers")
37 |
38 | parser.add_argument('--model', type=str, default='mixs', help='The type of mixnet.')
39 |
40 | parser.add_argument('--epoch', type=int, default=350, help='The number of epochs')
41 | parser.add_argument('--batch_size', type=int, default=1024, help='The size of batch')
42 | parser.add_argument('--test', action="store_true", help='Only Test')
43 |
44 | parser.add_argument('--optim', type=str, default='adam', choices=["rmsprop", "adam"])
45 | parser.add_argument('--lr', type=float, default=0.016, help="Base learning rate when train batch size is 256.")
46 | # Adam Optimizer
47 | parser.add_argument('--beta', nargs="*", type=float, default=(0.5, 0.999))
48 |
49 | parser.add_argument('--momentum', type=float, default=0.9)
50 | parser.add_argument('--eps', type=float, default=0.001)
51 | parser.add_argument('--decay', type=float, default=1e-5)
52 |
53 | parser.add_argument('--scheduler', type=str, default='exp', choices=["exp", "cosine", "none"],
54 | help="Learning rate scheduler type")
55 |
56 | return parser.parse_args()
57 |
58 |
59 | def get_scheduler(optim, sche_type, step_size, t_max):
60 |     # "exp" decays the LR by 0.97 every `step_size` steps;
61 |     # "cosine" anneals it over `t_max` steps.
62 |
63 | if sche_type == "exp":
64 | return StepLR(optim, step_size, 0.97)
65 | elif sche_type == "cosine":
66 | return CosineAnnealingLR(optim, t_max)
67 | else:
68 | return None
69 |
70 |
71 | if __name__ == "__main__":
72 | arg = arg_parse()
73 |
74 | arg.save_dir = "%s/outs/%s" % (os.getcwd(), arg.save_dir)
75 | if os.path.exists(arg.save_dir) is False:
76 |         os.makedirs(arg.save_dir)
77 |
78 | logger = Logger(arg.save_dir)
79 | logger.will_write(str(arg) + "\n")
80 |
81 | os.environ["CUDA_VISIBLE_DEVICES"] = arg.gpus
82 | device = torch.device("cuda")
83 | train_loader, val_loader = get_loaders(arg.root, arg.batch_size, arg.num_workers,
84 | dtype=arg.dtype)
85 |
86 | if arg.model == "mixs":
87 | net = mixnet_s(num_classes=len(train_loader.dataset.classes))
88 | elif arg.model == "rw":
89 | import sys
90 | sys.path.append("rwightman")
91 | from timm.models.gen_efficientnet import mixnet_s
92 | net = mixnet_s(num_classes=len(train_loader.dataset.classes))
93 | else:
94 | from torchvision.models import resnet50
95 | net = resnet50(num_classes=len(train_loader.dataset.classes))
96 |
97 | net = nn.DataParallel(net)
98 | loss = nn.CrossEntropyLoss()
99 |
100 | scaled_lr = arg.lr * arg.batch_size / 256
101 | optim = {
102 |         "adam" : lambda : torch.optim.Adam(net.parameters(), lr=scaled_lr, betas=arg.beta, eps=arg.eps, weight_decay=arg.decay),
103 | "rmsprop" : lambda : torch.optim.RMSprop(net.parameters(), lr=scaled_lr, momentum=arg.momentum, eps=arg.eps, weight_decay=arg.decay)
104 | }[arg.optim]()
105 |
106 | scheduler = get_scheduler(optim, arg.scheduler, int(2.4 * len(train_loader)), arg.epoch * len(train_loader))
107 |
108 | if arg.ema:
109 | Runner = EMARunner
110 |
111 | run = Runner(arg.model, arg.save_dir, arg.epoch,
112 | net, optim, device, loss, logger, scheduler)
113 | if arg.test is False:
114 | run.train(train_loader, val_loader)
115 | run.test(train_loader, val_loader)
116 |
--------------------------------------------------------------------------------
/runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import time
4 | from glob import glob
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class Runner():
14 | def __init__(self, model_type, save_dir, epochs, net, optim, device, loss, logger, scheduler=None):
15 | self.model_type = model_type
16 | self.save_dir = save_dir
17 | self.epochs = epochs
18 |
19 | self.logger = logger
20 |
21 | self.device = device
22 |
23 | self.net = net.to(device)
24 |
25 | self.loss = loss
26 | self.optim = optim
27 | self.scheduler = scheduler
28 |
29 | self.start_epoch = 0
30 | self.best_metric = -1
31 |
32 | self.load()
33 |
34 | def save(self, epoch, filename="train"):
35 | """Save current epoch model
36 |
37 | Save Elements:
38 | model_type : arg.model
39 | start_epoch : current epoch
40 | network : network parameters
41 | optimizer: optimizer parameters
42 | best_metric : current best score
43 |
44 | Parameters:
45 | epoch : current epoch
46 | filename : model save file name
47 | """
48 |         # persist enough state to resume: epoch counter, network and optimizer
49 |         # parameters, and the best metric so far
50 | torch.save({"model_type": self.model_type,
51 | "start_epoch": epoch + 1,
52 | "network": self.net.module.state_dict(),
53 | "optimizer": self.optim.state_dict(),
54 | "best_metric": self.best_metric
55 | }, self.save_dir + "/%s.pth.tar" % (filename))
56 | print("Model saved %d epoch" % (epoch))
57 |
58 | def load(self, filename=""):
59 |         """Load a checkpoint written by save()."""
60 | if filename == "":
61 | # load last epoch model
62 | filenames = sorted(glob(self.save_dir + "/*.pth.tar"))
63 | if len(filenames) == 0:
64 | print("Not Load")
65 | return
66 | else:
67 | filename = os.path.basename(filenames[-1])
68 |
69 | file_path = self.save_dir + "/" + filename
70 | if os.path.exists(file_path) is True:
71 |             print("Load %s from %s" % (filename, self.save_dir))
72 | ckpoint = torch.load(file_path)
73 | if ckpoint["model_type"] != self.model_type:
74 | raise ValueError("Ckpoint Model Type is %s" %
75 | (ckpoint["model_type"]))
76 |
77 | self.net.module.load_state_dict(ckpoint['network'])
78 | self.optim.load_state_dict(ckpoint['optimizer'])
79 | self.start_epoch = ckpoint['start_epoch']
80 | self.best_metric = ckpoint["best_metric"]
81 | print("Load Model Type : %s, epoch : %d acc : %f" %
82 | (ckpoint["model_type"], self.start_epoch, self.best_metric))
83 | else:
84 | print("Load Failed, not exists file")
85 |
86 | def train(self, train_loader, val_loader=None):
87 | print("\nStart Train len :", len(train_loader.dataset))
88 | for epoch in range(self.start_epoch, self.epochs):
89 | self.net.train()
90 | for i, (input_, target_) in enumerate(train_loader):
91 |                 input_, target_ = input_.to(self.device), target_.to(self.device, non_blocking=True)
92 |
93 | if self.scheduler:
94 | self.scheduler.step()
95 |
96 | out = self.net(input_)
97 | loss = self.loss(out, target_)
98 |
99 | self.optim.zero_grad()
100 | loss.backward()
101 | self.optim.step()
102 |
103 | if (i % 50) == 0:
104 | self.logger.log_write("train", epoch=epoch, loss=loss.item())
105 |
106 | if val_loader is not None:
107 | self.valid(epoch, val_loader)
108 |
109 | @torch.no_grad()
110 | def _get_acc(self, loader):
111 | correct = 0
112 | self.net.eval()
113 | for input_, target_ in loader:
114 |             out = self.net(input_.to(self.device))
115 | out = F.softmax(out, dim=1).cpu()
116 |
117 | _, idx = out.max(dim=1)
118 | correct += (target_ == idx).sum().item()
119 |
120 | return correct / len(loader.dataset)
121 |
122 | def valid(self, epoch, val_loader):
123 | acc = self._get_acc(val_loader)
124 | self.logger.log_write("valid", epoch=epoch, acc=acc)
125 |
126 | if acc > self.best_metric:
127 | self.best_metric = acc
128 | self.save(epoch, "epoch[%05d]_acc[%.4f]" % (epoch, acc))
129 |
130 | def test(self, train_loader, val_loader):
131 | print("\n Start Test")
132 | self.load()
133 | train_acc = self._get_acc(train_loader)
134 | valid_acc = self._get_acc(val_loader)
135 | self.logger.log_write("test", fname="test", train_acc=train_acc, valid_acc=valid_acc)
136 | return train_acc, valid_acc
137 |
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from collections import namedtuple
3 |
4 | BlockArgs = namedtuple('BlockArgs', [
5 | 'dw_ksize', 'expand_ksize', 'project_ksize', 'num_repeat',
6 | 'in_channels', 'out_channels', 'expand_ratio', 'id_skip',
7 | 'strides', 'se_ratio', 'swish', 'dilated',
8 | ])
9 |
10 |
11 | def round_filters(filters, depth_multiplier, depth_divisor, min_depth):
12 |     """Round number of filters based on the depth multiplier.
13 | TODO : ref link
14 | """
15 | if not depth_multiplier:
16 | return filters
17 |
18 | filters *= depth_multiplier
19 | min_depth = min_depth or depth_divisor
20 | new_filters = max(min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor)
21 | # Make sure that round down does not go down by more than 10%.
22 | if new_filters < 0.9 * filters:
23 | new_filters += depth_divisor
24 | return new_filters
25 |
26 |
27 | class MixnetDecoder:
28 | """A class of Mixnet decoder to get model configuration."""
29 |
30 | @staticmethod
31 | def _decode_block_string(block_string, depth_multiplier, depth_divisor, min_depth):
32 | """Gets a mixnet block through a string notation of arguments.
33 |
34 |         E.g. r2_k3_a1_p1_s22_e1_i32_o16_se0.25_noskip: r - repeats, k - kernel size(s),
35 |         a - expand ksize, p - project ksize, s - strides (two digits), e - expansion
36 |         ratio, i - input filters, o - output filters, se - squeeze/excitation ratio
37 |
38 | Args:
39 | block_string: a string, a string representation of block arguments.
40 |
41 | Returns:
42 | A BlockArgs instance.
43 | Raises:
44 | ValueError: if the strides option is not correctly specified.
45 | """
46 | assert isinstance(block_string, str)
47 |
48 | ops = block_string.split('_')
49 | options = {}
50 | for op in ops:
51 | splits = re.split(r'(\d.*)', op)
52 | if len(splits) >= 2:
53 | key, value = splits[:2]
54 | options[key] = value
55 |
56 | if 's' not in options or len(options['s']) != 2:
57 | raise ValueError('Strides options should be a pair of integers.')
58 |
59 | def _parse_ksize(ss):
60 | ks = [int(k) for k in ss.split('.')]
61 | return ks if len(ks) > 1 else ks[0]
62 |
63 | return BlockArgs(num_repeat=int(options['r']),
64 | dw_ksize=_parse_ksize(options['k']),
65 | expand_ksize=_parse_ksize(options['a']),
66 | project_ksize=_parse_ksize(options['p']),
67 | strides=[int(options['s'][0]), int(options['s'][1])],
68 | expand_ratio=int(options['e']),
69 | in_channels=round_filters(int(options['i']), depth_multiplier, depth_divisor, min_depth),
70 | out_channels=round_filters(int(options['o']), depth_multiplier, depth_divisor, min_depth),
71 | id_skip=('noskip' not in block_string),
72 | se_ratio=float(options['se']) if 'se' in options else 0,
73 | swish=('sw' in block_string),
74 | dilated=('dilated' in block_string)
75 | )
76 |
77 | @staticmethod
78 | def _encode_block_string(block):
79 | """Encodes a Mixnet block to a string."""
80 |
81 | def _encode_ksize(arr):
82 | return '.'.join([str(k) for k in arr])
83 |
84 | args = [
85 | 'r%d' % block.num_repeat,
86 | 'k%s' % _encode_ksize(block.dw_ksize),
87 | 'a%s' % _encode_ksize(block.expand_ksize),
88 | 'p%s' % _encode_ksize(block.project_ksize),
89 | 's%d%d' % (block.strides[0], block.strides[1]),
90 | 'e%s' % block.expand_ratio,
91 | 'i%d' % block.in_channels,
92 | 'o%d' % block.out_channels
93 | ]
94 |
95 | if (block.se_ratio is not None and block.se_ratio > 0 and block.se_ratio <= 1):
96 | args.append('se%s' % block.se_ratio)
97 | if block.id_skip is False:
98 | args.append('noskip')
99 | if block.swish:
100 | args.append('sw')
101 | if block.dilated:
102 | args.append('dilated')
103 | return '_'.join(args)
104 |
105 | @staticmethod
106 | def decode(string_list, depth_multiplier, depth_divisor, min_depth):
107 | """Decodes a list of string notations to specify blocks inside the network.
108 |
109 | Args:
110 | string_list: a list of strings, each string is a notation of Mixnet
111 | block.build_model_base
112 |
113 | Returns:
114 | A list of namedtuples to represent Mixnet blocks arguments.
115 | """
116 | assert isinstance(string_list, list)
117 | blocks_args = []
118 | for block_string in string_list:
119 | blocks_args.append(MixnetDecoder._decode_block_string(block_string, depth_multiplier, depth_divisor, min_depth))
120 | return blocks_args
121 |
122 | @staticmethod
123 | def encode(blocks_args):
124 | """Encodes a list of Mixnet Blocks to a list of strings.
125 |
126 | Args:
127 | blocks_args: A list of namedtuples to represent Mixnet blocks arguments.
128 | Returns:
129 | a list of strings, each string is a notation of Mixnet block.
130 | """
131 | block_strings = []
132 | for block in blocks_args:
133 | block_strings.append(MixnetDecoder._encode_block_string(block))
134 | return block_strings
135 |
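136 | if __name__ == "__main__":
137 |     # Worked example of the block-string notation documented in _decode_block_string.
138 |     s = 'r1_k3.5_a1.1_p1.1_s22_e6_i16_o24_se0.5_sw'
139 |     args = MixnetDecoder.decode([s], 1, 8, None)[0]
140 |     print(args)
141 |     assert MixnetDecoder.encode([args]) == [s]  # the notation round-trips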
--------------------------------------------------------------------------------
/ema_runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import time
4 | from glob import glob
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class EMARunner:
14 | def __init__(self, model_type, save_dir, epochs, net, optim, device, loss, logger, scheduler=None):
15 | self.save_dir = save_dir
16 | self.model_type = model_type
17 | self.epochs = epochs
18 |
19 | self.logger = logger
20 |
21 | self.device = device
22 |
23 | self.ema = copy.deepcopy(net.module.cpu())
24 | self.ema.eval()
25 | for p in self.ema.parameters():
26 | p.requires_grad_(False)
27 | self.ema_decay = 0.999
28 |
29 | self.net = net.to(device)
30 | self.loss = loss
31 | self.optim = optim
32 | self.scheduler = scheduler
33 |
34 | self.start_epoch = 0
35 | self.best_metric = -1
36 |
37 | self.load()
38 |
39 | def save(self, epoch, filename="train"):
40 | """Save current epoch model
41 | Save Elements:
42 | model_type : arg.model
43 | start_epoch : current epoch
44 | network : network parameters
45 | optimizer: optimizer parameters
46 | best_metric : current best score
47 | Parameters:
48 | epoch : current epoch
49 | filename : model save file name
50 | """
51 | torch.save({"model_type": self.model_type,
52 | "start_epoch": epoch + 1,
53 | "network": self.net.module.state_dict(),
54 | "ema": self.ema.state_dict(),
55 | "optimizer": self.optim.state_dict(),
56 | "best_metric": self.best_metric
57 | }, self.save_dir + "/%s.pth.tar" % (filename))
58 | print("Model saved %d epoch" % (epoch))
59 |
60 | def load(self, filename=""):
61 |         """Load a checkpoint written by save()."""
62 | if filename == "":
63 | # load last epoch model
64 | filenames = sorted(glob(self.save_dir + "/*.pth.tar"))
65 | if len(filenames) == 0:
66 | print("Not Load")
67 | return
68 | else:
69 | filename = os.path.basename(filenames[-1])
70 |
71 | file_path = self.save_dir + "/" + filename
72 | if os.path.exists(file_path) is True:
73 |             print("Load %s from %s" % (filename, self.save_dir))
74 | ckpoint = torch.load(file_path)
75 | if ckpoint["model_type"] != self.model_type:
76 | raise ValueError("Ckpoint Model Type is %s" %
77 | (ckpoint["model_type"]))
78 |
79 | self.net.module.load_state_dict(ckpoint['network'])
80 | self.ema.load_state_dict(ckpoint['ema'])
81 | self.optim.load_state_dict(ckpoint['optimizer'])
82 | self.start_epoch = ckpoint['start_epoch']
83 | self.best_metric = ckpoint["best_metric"]
84 | print("Load Model Type : %s, epoch : %d acc : %f" %
85 | (ckpoint["model_type"], self.start_epoch, self.best_metric))
86 | else:
87 | print("Load Failed, not exists file")
88 |
89 | @torch.no_grad()
90 | def update_ema(self):
91 | net_state = self.net.module.state_dict()
92 | ema_state = self.ema.state_dict()
93 | for k, v in ema_state.items():
94 | net_v = net_state[k].detach().cpu()
95 | v.copy_(v * self.ema_decay + net_v * (1 - self.ema_decay))
96 |
97 | def train(self, train_loader, val_loader=None):
98 | print("\nStart Train len :", len(train_loader.dataset))
99 | for epoch in range(self.start_epoch, self.epochs):
100 | self.net.train()
101 | for i, (input_, target_) in enumerate(train_loader):
102 | target_ = target_.to(self.device, non_blocking=True)
103 | input_ = input_.to(self.device)
104 |
105 | if self.scheduler:
106 | self.scheduler.step()
107 |
108 | out = self.net(input_)
109 | loss = self.loss(out, target_)
110 |
111 | self.optim.zero_grad()
112 | loss.backward()
113 | self.optim.step()
114 | self.update_ema()
115 |
116 | if (i % 50) == 0:
117 | self.logger.log_write("train", epoch=epoch, loss=loss.item())
118 |
119 | if val_loader is not None:
120 | self.valid(epoch, val_loader)
121 |
122 | def _get_acc(self, loader):
123 | correct = 0
124 | with torch.no_grad():
125 | self.net.eval()
126 | for input_, target_ in loader:
127 | out = self.ema(input_)
128 | out = F.softmax(out, dim=1).cpu()
129 |
130 | _, idx = out.max(dim=1)
131 | correct += (target_ == idx).sum().item()
132 |
133 | return correct / len(loader.dataset)
134 |
135 | def valid(self, epoch, val_loader):
136 | acc = self._get_acc(val_loader)
137 | self.logger.log_write("valid", epoch=epoch, acc=acc)
138 |
139 | if acc > self.best_metric:
140 | self.best_metric = acc
141 | self.save(epoch, "epoch[%05d]_acc[%.4f]" % (
142 | epoch, acc))
143 |
144 | def test(self, train_loader, val_loader):
145 | print("\n Start Test")
146 | self.load()
147 | train_acc = self._get_acc(train_loader)
148 | valid_acc = self._get_acc(val_loader)
149 | self.logger.log_write("test", fname="test", train_acc=train_acc, valid_acc=valid_acc)
150 | return train_acc, valid_acc
151 |
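152 |
153 | # Hedged sketch of the GPU variant mentioned in the readme: keep the shadow
154 | # copy on the training device instead of the CPU, i.e. in __init__ use
155 | #     self.ema = copy.deepcopy(net.module).to(device)
156 | # and in update_ema skip the host round-trip:
157 | #     net_v = net_state[k].detach()
158 | # _get_acc must then also move input_ to the same device before calling self.ema.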
--------------------------------------------------------------------------------
/models/mixnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models.utils import BlockArgs
5 | from models.utils import round_filters
6 | from models.utils import MixnetDecoder
7 |
8 | from models.layers import SEModule
9 | from models.layers import Swish
10 | from models.layers import Flatten
11 |
12 | from models.mdconv import MDConv
13 |
14 |
15 | class MixBlock(nn.Module):
16 | def __init__(self, dw_ksize, expand_ksize, project_ksize,
17 | in_channels, out_channels, expand_ratio, id_skip,
18 | strides, se_ratio, swish, dilated):
19 | super().__init__()
20 |
21 | self.id_skip = id_skip and all(s == 1 for s in strides) and in_channels == out_channels
22 |
23 | act_fn = lambda : Swish() if swish else nn.ReLU(True)
24 |
25 | layers = []
26 |         expanded_ch = in_channels * expand_ratio
27 | if expand_ratio != 1:
28 | expand = nn.Sequential(
29 |                 nn.Conv2d(in_channels, expanded_ch, expand_ksize, bias=False),
30 |                 nn.BatchNorm2d(expanded_ch),
31 | act_fn(),
32 | )
33 | layers.append(expand)
34 |
35 | depthwise = nn.Sequential(
36 |             MDConv(expanded_ch, dw_ksize, strides, bias=False),
37 |             nn.BatchNorm2d(expanded_ch),
38 | act_fn(),
39 | )
40 | layers.append(depthwise)
41 |
42 | if se_ratio > 0:
43 |             se = SEModule(expanded_ch, int(expanded_ch * se_ratio))
44 | layers.append(se)
45 |
46 | project = nn.Sequential(
47 |             nn.Conv2d(expanded_ch, out_channels, project_ksize, bias=False),
48 | nn.BatchNorm2d(out_channels),
49 | )
50 | layers.append(project)
51 |
52 | self.layers = nn.Sequential(*layers)
53 |
54 | def forward(self, x):
55 | out = self.layers(x)
56 | if self.id_skip:
57 | out = out + x
58 | return out
59 |
60 |
61 | class MixModule(nn.Module):
62 | def __init__(self, dw_ksize, expand_ksize, project_ksize, num_repeat,
63 | in_channels, out_channels, expand_ratio, id_skip,
64 | strides, se_ratio, swish, dilated):
65 | super().__init__()
66 | layers = [MixBlock(dw_ksize, expand_ksize, project_ksize,
67 | in_channels, out_channels, expand_ratio, id_skip,
68 | strides, se_ratio, swish, dilated)]
69 |
70 | for _ in range(num_repeat - 1):
71 | layers.append(MixBlock(dw_ksize, expand_ksize, project_ksize,
72 | in_channels, out_channels, expand_ratio, id_skip,
73 | [1, 1], se_ratio, swish, dilated))
74 | self.layers = nn.Sequential(*layers)
75 |
76 | def forward(self, x):
77 | return self.layers(x)
78 |
79 |
80 | class MixNet(nn.Module):
81 | def __init__(self, stem, blocks_args, head, dropout_rate, num_classes=1000):
82 | super().__init__()
83 |
84 | self.stem = nn.Sequential(
85 | nn.Conv2d(3, stem, 3, 2, 1, bias=False),
86 | nn.BatchNorm2d(stem),
87 | nn.ReLU(True)
88 | )
89 |
90 | self.blocks = nn.Sequential(*[MixModule(*args) for args in blocks_args])
91 |
92 | self.classifier = nn.Sequential(
93 | nn.Conv2d(blocks_args[-1].out_channels, head, 1, bias=False),
94 | nn.BatchNorm2d(head),
95 | nn.ReLU(True),
96 | nn.AdaptiveAvgPool2d(1),
97 | Flatten(),
98 | nn.Dropout(dropout_rate),
99 | nn.Linear(head, num_classes)
100 | )
101 |
102 | self.init_weights()
103 |
104 | def init_weights(self):
105 | for m in self.modules():
106 | if isinstance(m, nn.Conv2d):
107 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')
108 | elif isinstance(m, nn.Linear):
109 | nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
110 |
111 | def forward(self, x):
112 | # print("Input : ", x.shape)
113 | stem = self.stem(x)
114 |         # print("Stem : ", stem.shape)
115 | feature = self.blocks(stem)
116 | # print("feature : ", feature.shape)
117 | out = self.classifier(feature)
118 | return out
119 |
120 |
121 | def mixnet_s(depth_multiplier=1, depth_divisor=8, min_depth=None, num_classes=1000):
122 | """
123 | Creates mixnet-s model.
124 |
125 | Args:
126 |         depth_multiplier: multiplier applied to the number of filters per layer.
127 | """
128 | stem = round_filters(16, depth_multiplier, depth_divisor, min_depth)
129 | head = round_filters(1536, depth_multiplier, depth_divisor, min_depth)
130 | dropout = 0.2
131 |
132 | blocks_args = [
133 | 'r1_k3_a1_p1_s11_e1_i16_o16',
134 | 'r1_k3_a1.1_p1.1_s22_e6_i16_o24',
135 | 'r1_k3_a1.1_p1.1_s11_e3_i24_o24',
136 |
137 | 'r1_k3.5.7_a1_p1_s22_e6_i24_o40_se0.5_sw',
138 | 'r3_k3.5_a1.1_p1.1_s11_e6_i40_o40_se0.5_sw',
139 |
140 | 'r1_k3.5.7_a1_p1.1_s22_e6_i40_o80_se0.25_sw',
141 | 'r2_k3.5_a1_p1.1_s11_e6_i80_o80_se0.25_sw',
142 |
143 | 'r1_k3.5.7_a1.1_p1.1_s11_e6_i80_o120_se0.5_sw',
144 | 'r2_k3.5.7.9_a1.1_p1.1_s11_e3_i120_o120_se0.5_sw',
145 |
146 | 'r1_k3.5.7.9.11_a1_p1_s22_e6_i120_o200_se0.5_sw',
147 | 'r2_k3.5.7.9_a1_p1.1_s11_e6_i200_o200_se0.5_sw',
148 | ]
149 |
150 | blocks_args = MixnetDecoder.decode(blocks_args, depth_multiplier, depth_divisor, min_depth)
151 | print("-----------")
152 | print("Mixnet S")
153 | for a in blocks_args:
154 | print(a)
155 | print("-----------")
156 | return MixNet(stem, blocks_args, head, dropout, num_classes=num_classes)
157 |
158 |
159 | if __name__ == "__main__":
160 |     net = mixnet_s()
161 |
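162 |     # Hedged shape check: the stride-2 stem plus four stride-2 stages downsample
163 |     # by 32x, so a 224x224 input yields (1, 1000) logits.
164 |     out = net(torch.randn(1, 3, 224, 224))
165 |     print(out.shape)  # torch.Size([1, 1000])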
--------------------------------------------------------------------------------