├── common
├── __init__.py
├── find_mxnet.py
├── util.py
├── modelzoo.py
├── data.py
└── fit.py
├── model
└── PUT_YOUR_MODEL_HERE
├── data
├── gen_rec.sh
└── mx_list.py
├── run.sh
├── .gitignore
├── fine-tune.py
├── README.md
└── sub.py
/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/PUT_YOUR_MODEL_HERE:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/gen_rec.sh:
--------------------------------------------------------------------------------
# im2rec.py is available at https://github.com/dmlc/mxnet/tree/master/tools
# Pack the validation list first, then the training list, with identical options.
for split in val train; do
    python -u im2rec.py --resize 512 --quality 95 --num-thread 20 "$split" ./
done
4 |
--------------------------------------------------------------------------------
/common/find_mxnet.py:
--------------------------------------------------------------------------------
import os, sys
# Only provide a default: a value already present in the environment (for
# example one exported by the launching shell script, or assigned by the
# importing script before this module is imported) must not be overwritten.
os.environ.setdefault("MXNET_CUDNN_AUTOTUNE_DEFAULT", "1")
try:
    import mxnet as mx
except ImportError:
    # Fall back to a source checkout located three directories above this file.
    curr_path = os.path.abspath(os.path.dirname(__file__))
    sys.path.append(os.path.join(curr_path, "../../../python"))
    import mxnet as mx
9 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | export MXNET_CPU_WORKER_NTHREADS=48
2 | export MXNET_CUDNN_AUTOTUNE_DEFAULT=0
3 | python fine-tune.py --pretrained-model model/resnet-152 \
4 | --load-epoch 0 --gpus 0,1,2,3 \
5 | --model-prefix model/iNat-resnet-152 \
6 | --data-nthreads 48 \
7 | --batch-size 48 --num-classes 5089 --num-examples 579184
8 |
--------------------------------------------------------------------------------
/common/util.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | import errno
4 |
def download_file(url, local_fname=None, force_write=False):
    """
    download `url` to `local_fname` and return the local file name

    url : address of the file to fetch
    local_fname : destination path; defaults to the last component of `url`
    force_write : when False, an existing `local_fname` is returned as-is
                  without touching the network

    raises AssertionError when the server does not answer with status 200
    """
    if local_fname is None:
        local_fname = url.split('/')[-1]
    if not force_write and os.path.exists(local_fname):
        return local_fname

    # requests is not installed by default; import it only when a download
    # is really needed so that cached files work without it
    import requests

    dir_name = os.path.dirname(local_fname)
    if dir_name != "" and not os.path.exists(dir_name):
        try:  # try to create the directory if it doesn't exist
            os.makedirs(dir_name)
        except OSError as exc:
            if exc.errno != errno.EEXIST:
                raise

    # Stream into a temporary sibling file and move it into place only after
    # the whole body arrived; otherwise an interrupted download would leave a
    # truncated file that the existence check above treats as complete.
    tmp_fname = local_fname + '.part'
    r = requests.get(url, stream=True)
    try:
        if r.status_code != 200:
            # explicit raise (not `assert`) so the check survives `python -O`
            raise AssertionError("failed to open %s" % url)
        try:
            with open(tmp_fname, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        except BaseException:
            if os.path.exists(tmp_fname):
                os.remove(tmp_fname)
            raise
    finally:
        r.close()

    if os.path.exists(local_fname):
        # os.rename does not overwrite an existing file on every platform
        os.remove(local_fname)
    os.rename(tmp_fname, local_fname)
    return local_fname
32 |
def get_gpus():
    """
    return a list of GPU indices, one per line of `nvidia-smi -L` that
    mentions 'GPU'; the list is empty when nvidia-smi is missing or fails
    """
    try:
        listing = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True)
    except (OSError, subprocess.CalledProcessError):
        # OSError: executable not found; CalledProcessError: the tool exists
        # but exited with a non-zero status
        return []
    gpu_lines = [line for line in listing.split('\n') if 'GPU' in line]
    # materialise the range so callers get a real list on Python 3 as well
    return list(range(len(gpu_lines)))
42 |
--------------------------------------------------------------------------------
/data/mx_list.py:
--------------------------------------------------------------------------------
1 | # iNaturalist image list generator
2 |
3 |
4 | from PIL import Image
5 | import os
6 | import json
7 | import numpy as np
8 |
def default_loader(path):
    """Open the image stored at `path` and return it converted to RGB mode."""
    image = Image.open(path)
    return image.convert('RGB')
11 |
def gen_list(prefix):
    """Write the image list `<prefix>.lst` from `<prefix>2017.json`.

    The output has one tab-separated line per image:
    ``row index, class label, image file name``. Rows are shuffled and the
    row index runs from 0 upwards in the shuffled order.

    prefix : name of the split (e.g. 'train'); it selects both the
             annotation file that is read and the list file that is written

    When the annotation file has no 'annotations' entry every label is 0.
    NOTE(review): labels are paired with images purely by position -- this
    assumes 'annotations' is stored in the same order as 'images'; confirm
    against the dataset files.
    """
    import pandas as pd

    ann_file = '%s2017.json' % prefix
    train_out = '%s.lst' % prefix

    # load annotations
    print('Loading annotations from: ' + os.path.basename(ann_file))
    with open(ann_file) as data_file:
        ann_data = json.load(data_file)

    # set up the filenames and annotations
    imgs = [aa['file_name'] for aa in ann_data['images']]
    im_ids = [aa['id'] for aa in ann_data['images']]
    if 'annotations' in ann_data:
        # if we have class labels
        classes = [aa['category_id'] for aa in ann_data['annotations']]
    else:
        # otherwise we don't have class info, so set every label to 0
        classes = [0] * len(im_ids)

    idx_to_class = {cc['id']: cc['name'] for cc in ann_data['categories']}

    print('\t' + str(len(imgs)) + ' images')
    print('\t' + str(len(idx_to_class)) + ' classes')

    # preview at most the first ten entries (short lists no longer fail)
    for im_id, target, path in list(zip(im_ids, classes, imgs))[:10]:
        print(str(im_id - 1) + '\t' + str(target) + '\t' + path)

    # column 0 = label, column 1 = file name; mismatched lengths raise here
    df = pd.DataFrame({0: classes, 1: imgs})
    # shuffle the rows, then renumber them so the written index is 0..n-1
    df = df.sample(frac=1).reset_index(drop=True)
    df.to_csv(train_out, sep='\t', header=False)
51 |
if __name__ == '__main__':
    # Generate one list file per split, in this order.
    for split in ('train', 'val', 'test'):
        gen_list(split)
56 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/common/modelzoo.py:
--------------------------------------------------------------------------------
import os
try:
    # imported as part of the `common` package (e.g. `from common import modelzoo`)
    from .util import download_file
except (ImportError, ValueError):
    # loaded as a top-level module with the `common/` directory on sys.path
    from util import download_file
3 |
_base_model_url = 'http://data.mxnet.io/models/'


def _model_urls(stem, epoch=0):
    """Return the symbol/params URL pair for the checkpoint `stem` below the base URL."""
    return {'symbol': _base_model_url + stem + '-symbol.json',
            'params': _base_model_url + stem + '-%04d.params' % epoch}


# Download locations of the known pretrained models, keyed by model name.
_default_model_info = {
    'imagenet1k-inception-bn': _model_urls('imagenet/inception-bn/Inception-BN', 126),
    'imagenet1k-resnet-18': _model_urls('imagenet/resnet/18-layers/resnet-18'),
    'imagenet1k-resnet-34': _model_urls('imagenet/resnet/34-layers/resnet-34'),
    'imagenet1k-resnet-50': _model_urls('imagenet/resnet/50-layers/resnet-50'),
    'imagenet1k-resnet-101': _model_urls('imagenet/resnet/101-layers/resnet-101'),
    'imagenet1k-resnet-152': _model_urls('imagenet/resnet/152-layers/resnet-152'),
    'imagenet1k-resnext-50': _model_urls('imagenet/resnext/50-layers/resnext-50'),
    'imagenet1k-resnext-101': _model_urls('imagenet/resnext/101-layers/resnext-101'),
    'imagenet11k-resnet-152': _model_urls('imagenet-11k/resnet-152/resnet-152'),
    'imagenet11k-place365ch-resnet-152': _model_urls('imagenet-11k-place365-ch/resnet-152'),
    'imagenet11k-place365ch-resnet-50': _model_urls('imagenet-11k-place365-ch/resnet-50'),
}
29 |
def download_model(model_name, dst_dir='./', meta_info=None):
    """
    download the symbol and parameter files of a pretrained model

    model_name : key into `meta_info`
    dst_dir : directory the files are stored in (created when missing)
    meta_info : mapping of model name to {'symbol': url, 'params': url};
                defaults to the module-level table

    returns (prefix, 0) where prefix is `dst_dir` joined with `model_name`,
    or (None, 0) when `model_name` is not listed in `meta_info`
    raises AssertionError when the entry lacks a 'symbol' or 'params' url
    """
    if meta_info is None:
        meta_info = _default_model_info
    meta_info = dict(meta_info)
    if model_name not in meta_info:
        return (None, 0)
    meta = dict(meta_info[model_name])
    # validate the whole entry before touching the disk or the network, so a
    # broken entry does not leave a half-downloaded model behind
    assert 'symbol' in meta, "missing symbol url"
    assert 'params' in meta, "missing parameter file url"
    if not os.path.isdir(dst_dir):
        # makedirs also copes with nested destination directories
        os.makedirs(dst_dir)
    model_name = os.path.join(dst_dir, model_name)
    download_file(meta['symbol'], model_name+'-symbol.json')
    download_file(meta['params'], model_name+'-0000.params')
    return (model_name, 0)
45 |
--------------------------------------------------------------------------------
/fine-tune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | logging.basicConfig(level=logging.DEBUG)
5 | from common import find_mxnet
6 | from common import data, fit, modelzoo
7 | import mxnet as mx
8 |
9 | import os, urllib
def download(url):
    """
    fetch `url` into the `model/` directory, keeping the last path component
    of the url as file name; nothing is downloaded when that file exists
    """
    # urlretrieve lives in urllib.request on Python 3 and in urllib on Python 2
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve
    filename = url.split("/")[-1]
    target = 'model/' + filename
    if not os.path.exists(target):
        urlretrieve(url, target)
14 |
def get_model(prefix, epoch):
    """
    download the symbol file and the parameter file of checkpoint `epoch`
    for the model whose url starts with `prefix`
    """
    symbol_url = prefix + '-symbol.json'
    download(symbol_url)
    params_url = prefix + '-%04d.params' % (epoch,)
    download(params_url)
18 |
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name):
    """
    symbol: the pre-trained network symbol
    arg_params: the argument parameters of the pre-trained model
    num_classes: the number of classes for the fine-tune datasets
    layer_name: the layer name before the last fully-connected layer

    returns (net, new_args): the network cut at `layer_name` with a fresh
    fully-connected layer named 'fc' and a softmax output on top, and the
    pre-trained parameters without any entry whose name contains 'fc'
    """
    # use the symbol that was passed in rather than a module-level global
    all_layers = symbol.get_internals()
    net = all_layers[layer_name+'_output']
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc') #, lr_mult=10)
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # drop the old classifier weights so the new 'fc' layer is re-initialised
    new_args = {k: v for k, v in arg_params.items() if 'fc' not in k}
    return (net, new_args)
32 |
33 |
if __name__ == "__main__":
    # Script entry point: build the argument parser, fetch and load the
    # pretrained checkpoint, replace its classifier and start training.
    # parse args
    parser = argparse.ArgumentParser(description="fine-tune a dataset",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    train = fit.add_fit_args(parser)
    data.add_data_args(parser)
    aug = data.add_data_aug_args(parser)
    # NOTE(review): --pretrained-model is parsed, but the checkpoint prefix
    # used below is a fixed string -- confirm whether the option should be honoured.
    parser.add_argument('--pretrained-model', type=str,
                        help='the pre-trained model')
    parser.add_argument('--layer-before-fullc', type=str, default='flatten0',
                        help='the name of the layer before the last fullc layer')
    # use less augmentations for fine-tune
    data.set_data_aug_level(parser, 1)
    # use a small learning rate and less regularizations
    # when training comes to 10th and 20th epoch
    # see http://mxnet.io/how_to/finetune.html and Mu's thesis
    # http://www.cs.cmu.edu/~muli/file/mu-thesis.pdf
    parser.set_defaults(image_shape='3,320,320', num_epochs=30,
                        lr=.01, lr_step_epochs='10,20', wd=0, mom=0)

    args = parser.parse_args()

    # load pretrained model
    # (dir_path is computed here but not read again within this block)
    dir_path = os.path.dirname(os.path.realpath(__file__))

    # get the pretrained resnet 152 from official MXNet model zoo
    # 1k imagenet pretrained
    #get_model('http://data.mxnet.io/models/imagenet/resnet/152-layers/resnet-152', 0)
    # 11k imagenet resnet 152 has stronger classification power
    get_model('http://data.mxnet.io/models/imagenet-11k/resnet-152/resnet-152', 0)
    # the files fetched above are stored under model/, so load them from there
    prefix = 'model/resnet-152'
    epoch = 0
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

    # remove the last fullc layer
    (new_sym, new_args) = get_fine_tune_model(
        sym, arg_params, args.num_classes, args.layer_before_fullc)


    # train
    fit.fit(args = args,
            network = new_sym,
            data_loader = data.get_rec_iter,
            arg_params = new_args,
            aux_params = aux_params)
79 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # iNaturalist
2 | MXNet fine-tune baseline script (resnet 152 layers) for iNaturalist Challenge at FGVC 2017, public LB score 0.117 from a single 21st epoch submission without ensemble.
3 |
4 | ## How to use
5 |
6 | ### Install MXNet
7 |
8 | Run `pip install mxnet-cu80` after installing the CUDA driver, or get the latest version from the MXNet repository on GitHub.
9 |
10 | Windows user? No CUDA 8.0? No GPU? Run `pip search mxnet` to find the right package for your platform.
11 |
12 | ### Generate lists
13 |
14 | After downloading and unzipping the train and test sets into `data`, along with the necessary `.json` annotation files, run `python mx_list.py` under `data` to generate `train.lst`, `val.lst` and `test.lst`.
15 |
16 | ### Generate rec files
17 |
18 | A good way to speed up training is to maximize IO throughput by using the `.rec` format, which also makes data augmentation convenient. In the `data/` directory, `gen_rec.sh` generates `train.rec` and `val.rec` for the training and validation datasets; `im2rec.py` can be obtained from the MXNet repo (https://github.com/dmlc/mxnet/tree/master/tools). You can lower the `--quality 95` parameter to save disk space, at the risk of losing training precision.
19 |
20 | ### Train
21 |
22 | Run `sh run.sh` which looks like (a 4 GTX 1080 machine for example):
23 |
24 | ```
25 | python fine-tune.py --pretrained-model model/resnet-152 \
26 | --load-epoch 0 --gpus 0,1,2,3 \
27 | --model-prefix model/iNat-resnet-152 \
28 | --data-nthreads 48 \
29 | --batch-size 48 --num-classes 5089 --num-examples 579184
30 | ```
31 |
32 | please adjust `--gpus` and `--batch-size` according to the machine configuration. A sample calculation: `batch-size = 12` can use 8 GB memory on a GTX 1080, so `--batch-size 48` is good for a 4-GPU machine.
33 |
34 | An internet connection is required for the first run because the script needs to download the pretrained model from the MXNet model zoo (`http://data.mxnet.io/models/`). If the machine has no internet connection, download the corresponding model files on another machine and copy them into the `model/` directory.
35 |
36 | ### Generate submission file
37 |
38 | After a long run of some epochs, e.g. 30 epochs, we can select some epochs for the submission file. Run `sub.py`, which takes two parameters, the epoch number and the GPU id, like:
39 |
40 | ```
41 | python sub.py 21 0
42 | ```
43 |
44 | This selects the checkpoint of the 21st epoch and runs inference on GPU `#0`. One can merge the results of multiple epochs, computed on different GPUs, into an ensemble for a better submission file.
45 |
46 | ## How 'fine-tune' works
47 |
48 | Fine-tuning starts by loading a pretrained 152-layer ResNet (ImageNet 11k classes) from the MXNet model zoo, a model that has already gained some prediction power, and then continues training it on the provided data.
49 |
50 | The key technique is `lr_step_epochs`, which lowers the learning rate once training reaches certain epochs. In this example `lr_step_epochs='10,20'` means the learning rate is multiplied by `lr_factor` (0.1 by default) at the 10th and 20th epochs, so the fine-tune procedure can converge while still learning from the provided new samples. The defaults also use less regularization (`wd=0`, `mom=0`), and a similar idea is applied to data augmentation, where fine-tuning uses less augmentation. This technique is described in Mu's thesis.
51 |
52 | This pipeline is not limited to the ResNet-152 pretrained model. Please experiment with fine-tuning other models, such as ResNet-101 or Inception, from MXNet's model zoo by following the MXNet fine-tune tutorial and sample code (see Reference). Please feel free to submit issues and/or pull requests, or to discuss on the Kaggle forum, if you get better results.
53 |
54 | ## Reference
55 |
56 | * MXNet's model zoo
57 | * MXNet fine tune
58 | * Mu Li's thesis
59 | * iNaturalist Challenge at FGVC 2017
--------------------------------------------------------------------------------
/common/data.py:
--------------------------------------------------------------------------------
1 | import mxnet as mx
2 | import random
3 | from mxnet.io import DataBatch, DataIter
4 | import numpy as np
5 |
def add_data_args(parser):
    """
    parser : argparse.ArgumentParser
    add the 'Data' argument group (input image options) and return the group
    """
    group = parser.add_argument_group('Data', 'the input images')
    add = group.add_argument
    # --data-train / --data-val are not registered by this function
    add('--rgb-mean', default='123.68,116.779,103.939', type=str,
        help='a tuple of size 3 for the mean rgb')
    add('--pad-size', default=0, type=int,
        help='padding the input image')
    add('--image-shape', type=str,
        help='the image shape feed into the network, e.g. (3,224,224)')
    add('--num-classes', type=int,
        help='the number of classes')
    add('--num-examples', type=int,
        help='the number of training examples')
    add('--data-nthreads', default=4, type=int,
        help='number of threads for data decoding')
    add('--benchmark', default=0, type=int,
        help='if 1, then feed the network with synthetic data')
    add('--dtype', default='float32', type=str,
        help='data type: float32 or float16')
    return group
25 |
def add_data_aug_args(parser):
    """
    parser : argparse.ArgumentParser
    add the 'Image augmentations' argument group and return the group
    """
    group = parser.add_argument_group(
        'Image augmentations', 'implemented in src/io/image_aug_default.cc')
    add = group.add_argument
    add('--random-crop', default=1, type=int,
        help='if or not randomly crop the image')
    add('--random-mirror', default=1, type=int,
        help='if or not randomly flip horizontally')
    add('--max-random-h', default=0, type=int,
        help='max change of hue, whose range is [0, 180]')
    add('--max-random-s', default=0, type=int,
        help='max change of saturation, whose range is [0, 255]')
    add('--max-random-l', default=0, type=int,
        help='max change of intensity, whose range is [0, 255]')
    add('--max-random-aspect-ratio', default=0, type=float,
        help='max change of aspect ratio, whose range is [0, 1]')
    add('--max-random-rotate-angle', default=0, type=int,
        help='max angle to rotate, whose range is [0, 360]')
    add('--max-random-shear-ratio', default=0, type=float,
        help='max ratio to shear, whose range is [0, 1]')
    add('--max-random-scale', default=1, type=float,
        help='max ratio to scale')
    add('--min-random-scale', default=1, type=float,
        help='min ratio to scale, should >= img_size/input_shape. otherwise use --pad-size')
    return group
50 |
def set_data_aug_level(aug, level):
    """
    aug : object with a `set_defaults` method (parser or argument group)
    level : augmentation level; every tier whose threshold is <= level
            installs its defaults, so higher levels include the lower ones
    """
    tiers = (
        (1, dict(random_crop=1, random_mirror=1)),
        (2, dict(max_random_h=36, max_random_s=50, max_random_l=50)),
        (3, dict(max_random_rotate_angle=10, max_random_shear_ratio=0.1,
                 max_random_aspect_ratio=0.25)),
    )
    for threshold, defaults in tiers:
        if level >= threshold:
            aug.set_defaults(**defaults)
58 |
59 |
class SyntheticDataIter(DataIter):
    """
    data iterator that returns the same randomly generated batch
    `max_iter` times per pass (used when benchmarking without real data)
    """

    def __init__(self, num_classes, data_shape, max_iter, dtype):
        # the first entry of data_shape is the batch size
        self.batch_size = data_shape[0]
        self.cur_iter = 0
        self.max_iter = max_iter
        self.dtype = dtype
        # labels are drawn before the data, once, at construction time
        random_label = np.random.randint(0, num_classes, [self.batch_size,])
        random_data = np.random.uniform(-1, 1, data_shape)
        self.data = mx.nd.array(random_data, dtype=self.dtype)
        self.label = mx.nd.array(random_label, dtype=self.dtype)

    def __iter__(self):
        return self

    @property
    def provide_data(self):
        return [mx.io.DataDesc('data', self.data.shape, self.dtype)]

    @property
    def provide_label(self):
        return [mx.io.DataDesc('softmax_label', (self.batch_size,), self.dtype)]

    def next(self):
        # the counter advances even on the call that ends the pass
        self.cur_iter += 1
        if self.cur_iter > self.max_iter:
            raise StopIteration
        return DataBatch(data=(self.data,),
                         label=(self.label,),
                         pad=0,
                         index=None,
                         provide_data=self.provide_data,
                         provide_label=self.provide_label)

    def __next__(self):
        return self.next()

    def reset(self):
        self.cur_iter = 0
93 |
def get_rec_iter(args, kv=None):
    """
    build the training and validation iterators

    args : parsed arguments; reads image_shape, batch_size and, when present,
           dtype, benchmark and num_classes
    kv : optional kvstore; its rank / num_workers select the data partition
         (a single partition is used when it is None)

    returns (train, val); in benchmark mode train is a SyntheticDataIter and
    val is None

    The record files are read from the fixed paths data/train.rec and
    data/val.rec (with their .idx files). args.rgb_mean is not applied here.
    """
    image_shape = tuple([int(l) for l in args.image_shape.split(',')])
    dtype = np.float32
    if 'dtype' in args:
        if args.dtype == 'float16':
            dtype = np.float16
    if 'benchmark' in args and args.benchmark:
        data_shape = (args.batch_size,) + image_shape
        train = SyntheticDataIter(args.num_classes, data_shape, 50, dtype)
        return (train, None)
    if kv:
        (rank, nworker) = (kv.rank, kv.num_workers)
    else:
        (rank, nworker) = (0, 1)
    train = mx.img.ImageIter(
        label_width = 1,
        path_root = 'data/',
        path_imgrec = 'data/train.rec',
        path_imgidx = 'data/train.idx',
        # honour --image-shape instead of a fixed (3, 320, 320), so the real
        # iterators agree with the synthetic one above
        data_shape = image_shape,
        batch_size = args.batch_size,
        rand_crop = True,
        rand_resize = True,
        rand_mirror = True,
        shuffle = True,
        brightness = 0.4,
        contrast = 0.4,
        saturation = 0.4,
        pca_noise = 0.1,
        num_parts = nworker,
        part_index = rank)
    val = mx.img.ImageIter(
        label_width = 1,
        path_root = 'data/',
        path_imgrec = 'data/val.rec',
        path_imgidx = 'data/val.idx',
        batch_size = args.batch_size,
        data_shape = image_shape,
        # NOTE(review): 360 matches the default 320 crop; revisit it when a
        # larger --image-shape is requested
        resize = 360,
        rand_crop = False,
        rand_resize = False,
        rand_mirror = False,
        num_parts = nworker,
        part_index = rank)
    return (train, val)
144 |
--------------------------------------------------------------------------------
/sub.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0'
3 | import sys
4 | import numpy as np
5 | import cv2
6 | import json
7 | from common import find_mxnet
8 | import mxnet as mx
9 |
def ch_dev(arg_params, aux_params, ctx):
    """
    return copies of both parameter dicts in which every value has been
    passed through `as_in_context(ctx)`; the input dicts are not modified
    """
    moved_args = {name: value.as_in_context(ctx) for name, value in arg_params.items()}
    moved_auxs = {name: value.as_in_context(ctx) for name, value in aux_params.items()}
    return moved_args, moved_auxs
18 |
def oversample(images, crop_dims):
    """
    cut ten crops out of one image: the four corners and the center,
    followed by the same five crops flipped along the width axis

    images : a single image array indexed as (row, column, channel)
    crop_dims : (height, width) of each crop

    returns a float32 array of shape (10, height, width, channels)
    """
    shape = np.array(images.shape)
    crop_dims = np.array(crop_dims)
    crop_h, crop_w = crop_dims[0], crop_dims[1]
    center = shape[:2] / 2.0

    # rows 0-3: corner boxes as (top, left, bottom, right); row 4: center box
    boxes = np.empty((5, 4), dtype=int)
    corners = [(top, left)
               for top in (0, shape[0] - crop_h)
               for left in (0, shape[1] - crop_w)]
    for row, (top, left) in enumerate(corners):
        boxes[row] = (top, left, top + crop_h, left + crop_w)
    half = crop_dims / 2.0
    # fractional coordinates are truncated when stored in the integer array
    boxes[4] = np.tile(center, (1, 2)) + np.concatenate([-half, half])
    # repeat the five boxes; the second copy is mirrored below
    boxes = np.tile(boxes, (2, 1))

    # Extract crops
    crops = np.empty((10, crop_h, crop_w, shape[-1]), dtype=np.float32)
    for slot, (top, left, bottom, right) in enumerate(boxes):
        crops[slot] = images[top:bottom, left:right, :]
    crops[5:10] = crops[5:10, :, ::-1, :]
    return crops
58 |
def _load_center_crop(img_name, img_sz, crop_sz):
    """Read a test image as float32 RGB, resize its shorter side to `img_sz`,
    cut the central `crop_sz` x `crop_sz` window and return it channel-first."""
    img_full_name = 'data/test2017/' + img_name
    bgr = cv2.imread(img_full_name)
    if bgr is None:
        # cv2.imread signals an unreadable file by returning None
        raise IOError('failed to read image %s' % img_full_name)
    img = np.float32(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

    rows, cols = img.shape[:2]
    # floor division keeps the target size integral (cv2.resize rejects floats)
    if cols < rows:
        resize_width = img_sz
        resize_height = resize_width * rows // cols
    else:
        resize_height = img_sz
        resize_width = resize_height * cols // rows
    img = cv2.resize(img, (resize_width, resize_height), interpolation=cv2.INTER_CUBIC)

    h, w, _ = img.shape
    x0 = int((w - crop_sz) / 2)
    y0 = int((h - crop_sz) / 2)
    img = img[y0:y0+crop_sz, x0:x0+crop_sz]

    # (row, column, channel) -> (channel, row, column)
    img = np.swapaxes(img, 0, 2)
    img = np.swapaxes(img, 1, 2)
    return img


def _forward(net, args, auxs, dev, batch):
    """Bind `net` to `batch` on `dev`, run one inference pass and return the
    first output as a numpy array (`args` receives the data/label entries)."""
    args["data"] = mx.nd.array(batch, dev)
    args["softmax_label"] = mx.nd.empty((batch.shape[0],), dev)
    exe = net.bind(dev, args, args_grad=None, grad_req="null", aux_states=auxs)
    exe.forward(is_train=False)
    return exe.outputs[0].asnumpy()


def _top5(score):
    """Return the indices of the five largest scores, best first."""
    # plain `int` replaces the `np.int` alias that newer NumPy releases removed
    return np.argsort(score)[::-1][0:5].astype(int)


prefix = 'model/iNat-resnet-152'
epoch = int(sys.argv[1])   # checkpoint epoch to load
gpu_id = int(sys.argv[2])  # GPU ID for inference
ctx = mx.gpu(gpu_id)
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)

ann_file = 'data/test2017.json'
print('Loading annotations from: ' + os.path.basename(ann_file))
with open(ann_file) as data_file:
    ann_data = json.load(data_file)

imgs = [aa['file_name'] for aa in ann_data['images']]
im_ids = [aa['id'] for aa in ann_data['images']]

img_sz = 360    # length of the shorter image side after resizing
crop_sz = 320   # side of the central crop fed to the network
batch_sz = 256

preds = []
im_idxs = []
num_batches = len(imgs) // batch_sz

# full batches
for batch_head in range(0, batch_sz*num_batches, batch_sz):
    input_blob = np.zeros((batch_sz, 3, crop_sz, crop_sz))
    for slot, index in enumerate(range(batch_head, batch_head+batch_sz)):
        im_idxs.append(int(im_ids[index]))
        input_blob[slot, :, :, :] = _load_center_crop(imgs[index], img_sz, crop_sz)

    net_out = _forward(sym, arg_params, aux_params, ctx, input_blob)

    for bz in range(batch_sz):
        preds.append(_top5(np.squeeze(net_out[bz, :])))
        print(preds[-1], batch_head+bz)

# remaining images that do not fill a whole batch, one image at a time
for index in range(batch_sz*num_batches, len(imgs)):
    im_idxs.append(int(im_ids[index]))
    img = _load_center_crop(imgs[index], img_sz, crop_sz)
    probs = _forward(sym, arg_params, aux_params, ctx, img[np.newaxis, :])
    preds.append(_top5(np.squeeze(probs.mean(axis=0))))
    print(preds[-1], im_idxs[-1])

im_idxs = np.hstack(im_idxs)
preds = np.vstack(preds)

# one line per image: id, then the five predicted class indices
with open("submission_epoch_%d.csv"%(epoch), 'w') as opfile:
    opfile.write('id,predicted\n')
    for ii in range(len(im_idxs)):
        opfile.write(str(im_idxs[ii]) + ',' + ' '.join(str(x) for x in preds[ii,:])+'\n')
211 |
212 |
213 |
214 |
215 |
--------------------------------------------------------------------------------
/common/fit.py:
--------------------------------------------------------------------------------
1 | import mxnet as mx
2 | import logging
3 | import os
4 | import time
5 |
def _get_lr_scheduler(args, kv):
    """
    compute the starting learning rate and the step scheduler

    args : parsed arguments (lr, lr_factor, lr_step_epochs, num_examples,
           batch_size, kv_store, load_epoch)
    kv : kvstore; only its num_workers is read, and only for 'dist' stores

    returns (lr, scheduler). scheduler is None when no lr_factor below 1 is
    configured, or when every step epoch lies at or before the epoch
    training resumes from.
    """
    if 'lr_factor' not in args or args.lr_factor >= 1:
        return (args.lr, None)
    # number of updates per epoch; floor division keeps the step counts
    # integral on Python 3, and at least one update per epoch is assumed
    epoch_size = args.num_examples // args.batch_size
    if 'dist' in args.kv_store:
        epoch_size //= kv.num_workers
    epoch_size = max(1, epoch_size)
    begin_epoch = args.load_epoch if args.load_epoch else 0
    step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
    lr = args.lr
    # apply the reductions that already happened before the resume epoch
    for s in step_epochs:
        if begin_epoch >= s:
            lr *= args.lr_factor
    if lr != args.lr:
        logging.info('Adjust learning rate to %e for epoch %d', lr, begin_epoch)

    steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0]
    if not steps:
        # every step is already behind us; a scheduler without steps is invalid
        return (lr, None)
    return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))
23 |
def _load_model(args, rank=0):
    """
    load the checkpoint selected by args.model_prefix and args.load_epoch

    returns (sym, arg_params, aux_params), or (None, None, None) when no
    load_epoch is given. For rank > 0 a rank-specific checkpoint
    "<prefix>-<rank>" is preferred when its symbol file exists.
    """
    if 'load_epoch' not in args or args.load_epoch is None:
        return (None, None, None)
    assert args.model_prefix is not None
    model_prefix = args.model_prefix
    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
        model_prefix += "-%d" % (rank)
    sym, arg_params, aux_params = mx.model.load_checkpoint(
        model_prefix, args.load_epoch)
    # the message now matches the "<prefix>-<epoch>.params" naming used by
    # the files in this project (it previously printed an underscore)
    logging.info('Loaded model %s-%04d.params', model_prefix, args.load_epoch)
    return (sym, arg_params, aux_params)
35 |
def _save_model(args, rank=0):
    """
    build the checkpoint callback for args.model_prefix

    returns None when no model_prefix is set; otherwise the value of
    mx.callback.do_checkpoint for the prefix (suffixed with "-<rank>" for
    rank > 0). The directory part of the prefix is created when missing.
    """
    if args.model_prefix is None:
        return None
    dst_dir = os.path.dirname(args.model_prefix)
    # a bare prefix such as "net" has no directory part: nothing to create;
    # makedirs also handles nested directories
    if dst_dir and not os.path.isdir(dst_dir):
        os.makedirs(dst_dir)
    return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
        args.model_prefix, rank))
44 |
def add_fit_args(parser):
    """
    parser : argparse.ArgumentParser
    register the options used for training and return the 'Training'
    argument group (--monitor is registered on the parser itself)
    """
    group = parser.add_argument_group('Training', 'model training')
    add = group.add_argument
    add('--network', type=str,
        help='the neural network to use')
    add('--num-layers', type=int,
        help='number of layers in the neural network, required by some networks such as resnet')
    add('--gpus', type=str,
        help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
    add('--kv-store', default='device', type=str,
        help='key-value store type')
    add('--num-epochs', default=100, type=int,
        help='max num of epochs')
    add('--lr', default=0.1, type=float,
        help='initial learning rate')
    add('--lr-factor', default=0.1, type=float,
        help='the ratio to reduce lr on each step')
    add('--lr-step-epochs', type=str,
        help='the epochs to reduce the lr, e.g. 30,60')
    add('--optimizer', default='sgd', type=str,
        help='the optimizer type')
    add('--mom', default=0.9, type=float,
        help='momentum for sgd')
    add('--wd', default=0.0001, type=float,
        help='weight decay for sgd')
    add('--batch-size', default=128, type=int,
        help='the batch size')
    add('--disp-batches', default=20, type=int,
        help='show progress for every n batches')
    add('--model-prefix', type=str,
        help='model prefix')
    # this one option goes on the parser, not on the group
    parser.add_argument('--monitor', dest='monitor', default=0, type=int,
                        help='log network parameters every N iters if larger than 0')
    add('--load-epoch', type=int,
        help='load the model on an epoch using the model-load-prefix')
    add('--top-k', default=5, type=int,
        help='report the top-k accuracy. 0 means no report.')
    add('--test-io', default=0, type=int,
        help='1 means test reading speed without training')
    return group
88 |
def fit(args, network, data_loader, **kwargs):
    """
    train a model
    args : argparse returns
    network : the symbol definition of the neural network
    data_loader : function that returns the train and val data iterators
    kwargs : optional extras
        arg_params, aux_params : when both are given they are handed to
            training instead of loading a checkpoint
        batch_end_callback : a callback, or a list of callbacks, run after
            each batch in addition to the speedometer
    """
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    # data iterators
    (train, val) = data_loader(args, kv)
    if args.test_io:
        # Only measure how fast the training iterator delivers batches.
        tic = time.time()
        for i, batch in enumerate(train):
            for j in batch.data:
                j.wait_to_read()
            if (i + 1) % args.disp_batches == 0:
                logging.info('Batch [%d]\tSpeed: %.2f samples/sec',
                             i, args.disp_batches * args.batch_size / (time.time() - tic))
                tic = time.time()

        return

    # load model
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        sym, arg_params, aux_params = _load_model(args, kv.rank)
        if sym is not None:
            # A resumed checkpoint must describe exactly the same network.
            assert sym.tojson() == network.tojson()

    # save model
    checkpoint = _save_model(args, kv.rank)

    # devices for training; None and the empty string both mean cpu
    devs = mx.cpu() if not args.gpus else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # learning rate
    lr, lr_scheduler = _get_lr_scheduler(args, kv)

    # create model
    model = mx.mod.Module(
        context=devs,
        symbol=network
    )

    optimizer_params = {
        'learning_rate': lr,
        'momentum': args.mom,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler}

    monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None

    if args.network == 'alexnet':
        # AlexNet will not converge using Xavier
        initializer = mx.init.Normal()
    else:
        initializer = mx.init.Xavier(
            rnd_type='gaussian', factor_type="in", magnitude=2)

    # evaluation metrics
    eval_metrics = ['accuracy']
    if args.top_k > 0:
        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=args.top_k))

    # callbacks that run after each batch
    batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)]
    if 'batch_end_callback' in kwargs:
        cbs = kwargs['batch_end_callback']
        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

    # run
    model.fit(train,
              begin_epoch=args.load_epoch if args.load_epoch else 0,
              num_epoch=args.num_epochs,
              eval_data=val,
              eval_metric=eval_metrics,
              kvstore=kv,
              optimizer=args.optimizer,
              optimizer_params=optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              batch_end_callback=batch_end_callbacks,
              epoch_end_callback=checkpoint,
              allow_missing=True,
              monitor=monitor)
188 |
--------------------------------------------------------------------------------