├── common ├── __init__.py ├── fit.py └── data.py ├── model └── README ├── data └── README ├── train_mmd.sh ├── .gitignore ├── README.md ├── fine-tune.py ├── mmd.py └── adabn.py /common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/README: -------------------------------------------------------------------------------- 1 | #put models here 2 | -------------------------------------------------------------------------------- /data/README: -------------------------------------------------------------------------------- 1 | #put data lst here 2 | -------------------------------------------------------------------------------- /train_mmd.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export MXNET_CPU_WORKER_NTHREADS=15 4 | export MXNET_CUDNN_AUTOTUNE_DEFAULT=0 5 | export MXNET_ENGINE_TYPE=ThreadedEnginePerDevice 6 | python fine-tune.py --train-stage 0 --pretrained-model 'model/resnet-152' --pretrained-epoch 0 --model-prefix 'model/mmd' --num-classes 263 --lr 0.01 --lr-step-epochs '10,16' --num-epochs 18 --lr-factor 0.1 --gpus 0,1,2,3 --batch-size 64 7 | sleep 5 8 | python fine-tune.py --train-stage 1 --pretrained-model 'model/mmd' --pretrained-epoch 18 --model-prefix 'model/mmd1' --num-classes 263 --lr 0.0001 --lr-step-epochs '6,8,10,12' --num-epochs 14 --lr-factor 0.5 --gpus 0,1,2,3 --batch-size 64 9 | 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # transfer-mxnet 2 | Unsupervised transfer learning for image classification written in mxnet. 3 | 4 | This is a library for unsupervised transfer learning using mxnet. We mainly implemented three algorithms: 5 | 6 | - `mmd` described in paper "Learning Transferable Features with Deep Adaptation Networks". 7 | - `jmmd` described in paper "Deep Transfer Learning with Joint Adaptation Networks". 8 | - `AdaBN` described in paper "REVISITING BATCH NORMALIZATION FOR PRACTICAL DOMAIN ADAPTATION". 9 | 10 | For original caffe implementation of `mmd` and `jmmd`, please refer to [here](https://github.com/thuml/transfer-caffe) 11 | 12 | If you have any problem about this code, feel free to concact us with the following email: 13 | - guojia@gmail.com 14 | 15 | Note that this repo is only for unsupervised image classfication transfer learning. 16 | 17 | Experiments 18 | --------------- 19 | We introduce our experiments on cars dataset: 20 | - `Source dataset` is a high quality cars image dataset fetched from web with accurate annotated labels(car models). 21 | - `Target dataset` is a surveillance image dataset which is public available(compcars-sv). 22 | 23 | During training, we set all labels in `Target dataset` to null-label(9999 by default) then it becomes unsupervised TL problem. 24 | 25 | | Method | Accuracy | 26 | | ---------- | ----- | 27 | | CNN(no TL) | 68.7% | 28 | | AdaBN | 71.6% | 29 | | DAN(mmd) | 73.7% | 30 | | JAN(jmmd) | 78.9% | 31 | 32 | If you want to train your own models with mmd(especically JAN as it is the best approach), please use following steps. 33 | 34 | Data Preparation 35 | --------------- 36 | - Download mxnet resnet-152 imagenet-11k pretrained model to `model/` directory, from [here](http://data.mxnet.io/models/imagenet-11k/resnet-152/). 37 | - Prepare your source domain dataset to `data/source.lst` in mxnet `lst` format. 38 | - Prepare your target domain dataset to `data/target.lst`. Generally, label id in `target.lst` should be null-label(9999 by default). But semi-supervised TL is also allowed, you can choose hundreds of items to be their real valid label id. 39 | - Prepare your validation dataset in target domain to `data/val.lst`. It can be the same with `data/target.lst` but all with valid labels. 40 | 41 | 42 | Training Model 43 | --------------- 44 | Env variables setting. 45 | ``` 46 | export MXNET_CPU_WORKER_NTHREADS=15 47 | ``` 48 | 49 | Training stage 1, do softmax training on source dataset only. 50 | 51 | ``` 52 | python fine-tune.py --train-stage 0 --pretrained-model 'model/resnet-152' --pretrained-epoch 0 --model-prefix 'model/mmd' --num-classes 263 --lr 0.01 --lr-step-epochs '10,16' --num-epochs 18 --lr-factor 0.1 --gpus 0,1,2,3 --batch-size 64 53 | ``` 54 | 55 | Training stage 2, do softmax+mmd joint training to produce final model. 56 | 57 | ``` 58 | python fine-tune.py --train-stage 1 --pretrained-model 'model/mmd' --pretrained-epoch 18 --model-prefix 'model/mmd1' --num-classes 263 --lr 0.0001 --lr-step-epochs '6,8,10,12' --num-epochs 14 --lr-factor 0.5 --gpus 0,1,2,3 --batch-size 64 59 | ``` 60 | 61 | Parameter Tuning 62 | --------------- 63 | TODO, check them in source code now. 64 | 65 | Use AdaBN 66 | --------------- 67 | ``` 68 | python adabn.py --model --epoch --val 'data/val.lst' --gpu 0 69 | ``` 70 | It will firstly calculate BN statistics using target domain dataset then write back to preloaded model. Second, use this modified model to validate the classification accuracy on target dataset. You can change the corresponding BN layers name in source code. 71 | 72 | 73 | -------------------------------------------------------------------------------- /fine-tune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import logging 5 | logging.basicConfig(level=logging.DEBUG) 6 | #from common import find_mxnet 7 | from common import data 8 | from common import fit 9 | import mxnet as mx 10 | import mmd 11 | 12 | import os, urllib 13 | def download(url): 14 | filename = url.split("/")[-1] 15 | if not os.path.exists('model/'+filename): 16 | urllib.urlretrieve(url, 'model/'+ filename) 17 | 18 | def get_model(prefix, epoch): 19 | download(prefix+'-symbol.json') 20 | download(prefix+'-%04d.params' % (epoch,)) 21 | 22 | LABEL_WIDTH = 1 23 | 24 | def get_fine_tune_model(symbol, arg_params, args): 25 | """ 26 | symbol: the pre-trained network symbol 27 | arg_params: the argument parameters of the pre-trained model 28 | num_classes: the number of classes for the fine-tune datasets 29 | layer_name: the layer name before the last fully-connected layer 30 | """ 31 | layer_name = args.layer_before_fullc 32 | all_layers = symbol.get_internals() 33 | last_before = all_layers[layer_name+'_output'] 34 | lr_mult = 1 35 | feature = last_before 36 | fc = mx.symbol.FullyConnected(data=feature, num_hidden=args.num_classes, name='fc', lr_mult=lr_mult) #, lr_mult=10) 37 | net = mmd.mmd(feature, fc, args) 38 | if args.train_stage==0: 39 | new_args = dict({k:arg_params[k] for k in arg_params if 'fc' not in k}) 40 | else: 41 | new_args = arg_params 42 | return (net, new_args) 43 | 44 | 45 | if __name__ == "__main__": 46 | # parse args 47 | parser = argparse.ArgumentParser(description="fine-tune a dataset", 48 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 49 | train = fit.add_fit_args(parser) 50 | data.add_data_args(parser) 51 | aug = data.add_data_aug_args(parser) 52 | parser.add_argument('--pretrained-model', type=str, default='model/resnet-152', 53 | help='the pre-trained model') 54 | parser.add_argument('--pretrained-epoch', type=int, default=0, 55 | help='the pre-trained model epoch to load') 56 | parser.add_argument('--layer-before-fullc', type=str, default='flatten0', 57 | help='the name of the layer before the last fullc layer') 58 | parser.add_argument('--no-checkpoint', action="store_true", default=False, 59 | help='do not save checkpoints') 60 | parser.add_argument('--freeze', action="store_true", default=False, 61 | help='freeze lower layers') 62 | parser.add_argument('--train-stage', type=int, default=0, 63 | help='training stage, train softmax only in training stage0 and use mmd loss in training stage1') 64 | parser.add_argument('--null-label', type=int, default=9999, 65 | help='indicate the label id of invalid label') 66 | parser.add_argument('--use-dan', action="store_true", default=False, 67 | help='use DAN instead of JAN') 68 | # use less augmentations for fine-tune 69 | data.set_data_aug_level(parser, 2) 70 | parser.set_defaults(data_dir="./data", top_k=0, kv_store='local', data_nthreads=15) 71 | #parser.set_defaults(model_prefix="", data_nthreads=15, batch_size=64, num_classes=263, gpus='0,1,2,3') 72 | #parser.set_defaults(image_shape='3,320,320', num_epochs=32, 73 | # lr=.0001, lr_step_epochs='12,20,24,28', wd=0, mom=0.9, lr_factor=0.5) 74 | parser.set_defaults(image_shape='3,320,320', wd=0, mom=0.9) 75 | 76 | args = parser.parse_args() 77 | args.label_width = LABEL_WIDTH 78 | args.gpu_num = len(args.gpus.split(',')) 79 | args.batch_per_gpu = args.batch_size/args.gpu_num 80 | with open(args.data_dir+'/source.lst') as f: 81 | args.num_examples = sum(1 for _ in f) 82 | if args.train_stage==1: 83 | with open(args.data_dir+'/target.lst') as f: 84 | target_num_examples = sum(1 for _ in f) 85 | args.num_examples = min(args.num_examples, target_num_examples) 86 | print('num_examples', args.num_examples) 87 | print('gpu_num', args.gpu_num) 88 | 89 | # load pretrained model 90 | dir_path = os.path.dirname(os.path.realpath(__file__)) 91 | 92 | prefix = args.pretrained_model 93 | epoch = args.pretrained_epoch 94 | sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) 95 | 96 | fixed_params = None 97 | 98 | if args.freeze: 99 | fixed_params = [] 100 | active_list = ['bn1', 'fc', 'stage3', 'stage4'] 101 | for k in arg_params: 102 | is_active = False 103 | for a in active_list: 104 | if k.startswith(a): 105 | is_active = True 106 | break 107 | if not is_active: 108 | fixed_params.append(k) 109 | print(fixed_params) 110 | 111 | # remove the last fullc layer 112 | (new_sym, new_args) = get_fine_tune_model( 113 | sym, arg_params, args) 114 | 115 | 116 | # train 117 | fit.fit(args = args, 118 | network = new_sym, 119 | data_loader = data.get_rec_iter, 120 | arg_params = new_args, 121 | aux_params = aux_params, 122 | fixed_param_names = fixed_params) 123 | 124 | -------------------------------------------------------------------------------- /mmd.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0' 3 | import sys 4 | import numpy as np 5 | import cv2 6 | import random 7 | import json 8 | import mxnet as mx 9 | import math 10 | 11 | 12 | 13 | 14 | def calculate_distance2(data1, data2): 15 | #total_num = batch_size 16 | _data1 = mx.symbol.expand_dims(data1, axis=1) # B,1,C 17 | spread_distance2 = mx.symbol.broadcast_sub(_data1, data2) #B,B,C 18 | spread_distance2 = mx.symbol.reshape(spread_distance2, shape=(-3,-2)) #B*B,C 19 | spread_distance2 = mx.symbol.square(spread_distance2) 20 | distance2 = mx.symbol.sum(spread_distance2, axis=1) # B*B, 21 | return distance2 22 | 23 | def calculate_distance1(data1, data2): 24 | distance = data1 - data2 25 | #distance = mx.symbol.reshape(distance, shape=(1,-2)) 26 | distance2 = mx.symbol.square(distance) #1,C 27 | distance2 = mx.symbol.sum(distance2, axis=1) #1, 28 | return distance2 29 | 30 | def multi_kernel_distance(data1, data2, batch_size, kernel_num, base_gamma): 31 | assert batch_size==1 32 | #distance2 = calculate_distance2(data1, data2) 33 | distance2 = calculate_distance1(data1, data2) 34 | kernel_mul = 2.0 35 | coef = kernel_num*-0.5 36 | #times = math.pow(kernel_mul, kernel_num/2.0) 37 | ks = [] 38 | for i in xrange(kernel_num): 39 | kernel_gamma = base_gamma * math.pow(kernel_mul, coef) 40 | ks.append(mx.symbol.exp(distance2*-1.0/kernel_gamma)) 41 | coef += 1.0 42 | ret = mx.symbol.add_n(*ks) #(b*b,) 43 | return ret 44 | 45 | 46 | def multi_kernel_distance2(distance2, kernel_num, base_gamma): 47 | kernel_mul = 2.0 48 | coef = kernel_num*-0.5 49 | ks = [] 50 | for i in xrange(kernel_num): 51 | kernel_gamma = base_gamma * math.pow(kernel_mul, coef) 52 | ks.append(mx.symbol.exp(distance2*-1.0/kernel_gamma)) 53 | coef += 1.0 54 | ret = mx.symbol.add_n(*ks) #(b*b,) 55 | return ret 56 | 57 | 58 | def mmd(data, fc, args): 59 | batch_size = args.batch_per_gpu 60 | print('mmd batch_size', batch_size) 61 | assert batch_size%4==0 62 | 63 | #static params 64 | source_num = batch_size/2 65 | target_num = source_num 66 | total_num = batch_size 67 | data_kernel_num = 4 68 | label_kernel_num = 1 69 | group_num = source_num/2 70 | #data_gamma = 0.0 71 | label_gamma = 1.3 72 | 73 | 74 | softmax = mx.symbol.softmax(data=fc) 75 | gt_label = mx.symbol.Variable('softmax_label') 76 | #max_label = mx.symbol.max(gt_label) 77 | #source_data = mx.symbol.slice_axis(data, axis=0, begin=0, end=source_num) 78 | #target_data = mx.symbol.slice_axis(data, axis=0, begin=source_num, end=total_num) 79 | #source_softmax = mx.symbol.slice_axis(softmax, axis=0, begin=0, end=source_num) 80 | #target_softmax = mx.symbol.slice_axis(softmax, axis=0, begin=source_num, end=total_num) 81 | distance2 = calculate_distance2(data, data) 82 | bandwidth = mx.symbol.sum(distance2) 83 | #data_gamma = (total_num * total_num - total_num)/bandwidth 84 | data_gamma = bandwidth/(total_num * total_num - total_num) 85 | k_list = [] 86 | #unbiased mmd 87 | for i in xrange(group_num): 88 | xs1 = mx.symbol.slice_axis(data, axis=0, begin=i*2, end=i*2+1) 89 | xs2 = mx.symbol.slice_axis(data, axis=0, begin=i*2+1, end=i*2+2) 90 | xt1 = mx.symbol.slice_axis(data, axis=0, begin=source_num+i*2, end=source_num+i*2+1) 91 | xt2 = mx.symbol.slice_axis(data, axis=0, begin=source_num+i*2+1, end=source_num+i*2+2) 92 | ys1 = mx.symbol.slice_axis(softmax, axis=0, begin=i*2, end=i*2+1) 93 | ys2 = mx.symbol.slice_axis(softmax, axis=0, begin=i*2+1, end=i*2+2) 94 | yt1 = mx.symbol.slice_axis(softmax, axis=0, begin=source_num+i*2, end=source_num+i*2+1) 95 | yt2 = mx.symbol.slice_axis(softmax, axis=0, begin=source_num+i*2+1, end=source_num+i*2+2) 96 | 97 | k_x = multi_kernel_distance(xs1, xs2, 1, data_kernel_num, data_gamma) 98 | k_y = multi_kernel_distance(ys1, ys2, 1, label_kernel_num, label_gamma) 99 | k = k_x*k_y 100 | if args.use_dan: 101 | k_list.append(k_x) 102 | else: 103 | k_list.append(k) 104 | 105 | k_x = multi_kernel_distance(xt1, xt2, 1, data_kernel_num, data_gamma) 106 | k_y = multi_kernel_distance(yt1, yt2, 1, label_kernel_num, label_gamma) 107 | k = k_x*k_y 108 | if args.use_dan: 109 | k_list.append(k_x) 110 | else: 111 | k_list.append(k) 112 | 113 | k_x = multi_kernel_distance(xs1, xt2, 1, data_kernel_num, data_gamma)*-1.0 114 | k_y = multi_kernel_distance(ys1, yt2, 1, label_kernel_num, label_gamma) 115 | k = k_x*k_y 116 | if args.use_dan: 117 | k_list.append(k_x) 118 | else: 119 | k_list.append(k) 120 | 121 | k_x = multi_kernel_distance(xt1, xs2, 1, data_kernel_num, data_gamma)*-1.0 122 | k_y = multi_kernel_distance(yt1, ys2, 1, label_kernel_num, label_gamma) 123 | k = k_x*k_y 124 | if args.use_dan: 125 | k_list.append(k_x) 126 | else: 127 | k_list.append(k) 128 | 129 | mmd_loss = mx.symbol.add_n(*k_list)/group_num 130 | net = mx.symbol.SoftmaxOutput(data=fc, label = gt_label, use_ignore=True, ignore_label=args.null_label, name='softmax') 131 | if args.train_stage>0: 132 | grad_scale = 1.0 if args.use_dan else 2.0 133 | mmd = mx.symbol.MakeLoss(mmd_loss, grad_scale=grad_scale) 134 | net = mx.symbol.Group([net,mmd]) 135 | return net 136 | 137 | 138 | -------------------------------------------------------------------------------- /adabn.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0' 3 | import sys 4 | import argparse 5 | import numpy as np 6 | import cv2 7 | import random 8 | import json 9 | import mxnet as mx 10 | 11 | def get_bn_input_symbol(bn_symbol): 12 | for sym in bn_symbol.get_children(): 13 | #print(sym.name) 14 | if sym.name.startswith('_plus') or sym.name.find('conv')>=0: 15 | return sym 16 | return None 17 | 18 | def get_adabn_params(symbol, arg_params, aux_params, bn_layer_names = ['bn1','stage4_unit3_bn3']): 19 | ret = {} 20 | for _bn in bn_layer_names: 21 | ret[_bn] = None 22 | #print(aux_params[_bn]) 23 | bn_mean = aux_params[_bn+"_moving_mean"].asnumpy() 24 | bn_var = aux_params[_bn+"_moving_var"].asnumpy() 25 | bn_gamma = arg_params[_bn+"_gamma"].asnumpy() 26 | bn_beta = arg_params[_bn+"_beta"].asnumpy() 27 | print(bn_mean.shape) 28 | print(bn_var.shape) 29 | print(bn_gamma.shape) 30 | print(bn_beta.shape) 31 | #bn_layer_name = "bn1" 32 | #print(arg_params.__class__) 33 | #print(aux_params.__class__) 34 | 35 | all_layers = sym.get_internals() 36 | for layer in all_layers: 37 | #print(layer.name) 38 | if layer.name in ret: 39 | bn_input = get_bn_input_symbol(layer) 40 | assert bn_input is not None 41 | ret[layer.name] = (bn_input, layer) 42 | #for _sym in layer.get_children(): 43 | # if _sym.name in arg_params: 44 | # print(_sym.name, arg_params[_sym.name]) 45 | # else: 46 | # print(_sym.name) 47 | return ret 48 | 49 | 50 | def ch_dev(arg_params, aux_params, ctx): 51 | new_args = dict() 52 | new_auxs = dict() 53 | for k, v in arg_params.items(): 54 | new_args[k] = v.as_in_context(ctx) 55 | for k, v in aux_params.items(): 56 | new_auxs[k] = v.as_in_context(ctx) 57 | return new_args, new_auxs 58 | 59 | 60 | img_sz = 360 61 | crop_sz = 320 62 | batch_sz = 64 63 | 64 | def image_preprocess(img_full_path, loop): 65 | img = cv2.cvtColor(cv2.imread(img_full_path), cv2.COLOR_BGR2RGB) 66 | img = np.float32(img) 67 | ori_shape = img.shape 68 | assert img.shape[2]==3 69 | 70 | img = cv2.resize(img, (img_sz, img_sz), interpolation=cv2.INTER_CUBIC) 71 | h, w, _ = img.shape 72 | x0 = int((w - crop_sz) / 2) 73 | y0 = int((h - crop_sz) / 2) 74 | 75 | img = img[y0:y0+crop_sz, x0:x0+crop_sz] 76 | 77 | if loop%2==1: 78 | img = np.fliplr(img) 79 | 80 | img = np.swapaxes(img, 0, 2) 81 | img = np.swapaxes(img, 1, 2) # change to r,g,b order 82 | return img 83 | 84 | 85 | def apply_adabn(ctx, sym, arg_params, aux_params, imgs): 86 | batch_head = 0 87 | batch_num = 0 88 | adabn = get_adabn_params(sym, arg_params, aux_params) 89 | adabn_list = [] 90 | sym_list = [] 91 | X_list = [] 92 | for k,v in adabn.iteritems(): 93 | adabn_list.append( (k,v) ) 94 | sym_list.append(v[0]) 95 | X_list.append([]) 96 | #_adabn = (k,v) 97 | num_bn = len(X_list) 98 | req = mx.symbol.Group(sym_list) 99 | new_auxs = dict() 100 | for k, v in aux_params.items(): 101 | #print(v.__class__) 102 | new_auxs[k] = v.as_in_context(ctx) 103 | while batch_head= 1: 9 | return (args.lr, None) 10 | epoch_size = args.num_examples / args.batch_size 11 | #epoch_size = 400 12 | if 'dist' in args.kv_store: 13 | epoch_size /= kv.num_workers 14 | #return (args.lr, mx.lr_scheduler.FactorScheduler(step=epoch_size*2, factor=args.lr_factor)) 15 | begin_epoch = args.load_epoch if args.load_epoch else 0 16 | step_epochs = [int(l) for l in args.lr_step_epochs.split(',')] 17 | lr = args.lr 18 | for s in step_epochs: 19 | if begin_epoch >= s: 20 | lr *= args.lr_factor 21 | if lr != args.lr: 22 | logging.info('Adjust learning rate to %e for epoch %d' %(lr, begin_epoch)) 23 | 24 | steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0] 25 | return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor)) 26 | 27 | def _load_model(args, rank=0): 28 | if 'load_epoch' not in args or args.load_epoch is None: 29 | return (None, None, None) 30 | assert args.model_prefix is not None 31 | model_prefix = args.model_prefix 32 | if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)): 33 | model_prefix += "-%d" % (rank) 34 | sym, arg_params, aux_params = mx.model.load_checkpoint( 35 | model_prefix, args.load_epoch) 36 | logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch) 37 | return (sym, arg_params, aux_params) 38 | 39 | def _save_model(args, rank=0): 40 | if args.model_prefix is None: 41 | return None 42 | #args.epoch_id += 1 43 | #print("save model stat %d,%d" % (args.epoch_id, args.num_epochs)) 44 | #if args.epoch_id>=args.num_epochs: 45 | dst_dir = os.path.dirname(args.model_prefix) 46 | if not os.path.isdir(dst_dir): 47 | os.mkdir(dst_dir) 48 | return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % ( 49 | args.model_prefix, rank)) 50 | 51 | def add_fit_args(parser): 52 | """ 53 | parser : argparse.ArgumentParser 54 | return a parser added with args required by fit 55 | """ 56 | train = parser.add_argument_group('Training', 'model training') 57 | train.add_argument('--network', type=str, 58 | help='the neural network to use') 59 | train.add_argument('--num-layers', type=int, 60 | help='number of layers in the neural network, required by some networks such as resnet') 61 | train.add_argument('--gpus', type=str, 62 | help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu') 63 | train.add_argument('--kv-store', type=str, default='device', 64 | help='key-value store type') 65 | train.add_argument('--num-epochs', type=int, default=100, 66 | help='max num of epochs') 67 | train.add_argument('--lr', type=float, default=0.1, 68 | help='initial learning rate') 69 | train.add_argument('--lr-factor', type=float, default=0.1, 70 | help='the ratio to reduce lr on each step') 71 | train.add_argument('--lr-step-epochs', type=str, 72 | help='the epochs to reduce the lr, e.g. 30,60') 73 | train.add_argument('--optimizer', type=str, default='sgd', 74 | help='the optimizer type') 75 | train.add_argument('--mom', type=float, default=0.9, 76 | help='momentum for sgd') 77 | train.add_argument('--wd', type=float, default=0.0001, 78 | help='weight decay for sgd') 79 | train.add_argument('--batch-size', type=int, default=128, 80 | help='the batch size') 81 | train.add_argument('--disp-batches', type=int, default=10, 82 | help='show progress for every n batches') 83 | train.add_argument('--model-prefix', type=str, 84 | help='model prefix') 85 | parser.add_argument('--monitor', dest='monitor', type=int, default=0, 86 | help='log network parameters every N iters if larger than 0') 87 | train.add_argument('--load-epoch', type=int, 88 | help='load the model on an epoch using the model-load-prefix') 89 | train.add_argument('--top-k', type=int, default=5, 90 | help='report the top-k accuracy. 0 means no report.') 91 | train.add_argument('--test-io', type=int, default=0, 92 | help='1 means test reading speed without training') 93 | return train 94 | 95 | 96 | class JAccuracy(mx.metric.EvalMetric): 97 | def __init__(self): 98 | self.axis = 1 99 | super(JAccuracy, self).__init__( 100 | 'jaccuracy', axis=self.axis, 101 | output_names=None, label_names=None) 102 | self.loss = [] 103 | 104 | def update(self, labels, preds): 105 | #print('labels', len(labels), labels[0].shape) 106 | #print('preds', len(preds), preds[0].shape, preds[1].shape) 107 | #print('label', labels[0].asnumpy()) 108 | if len(preds)>1: 109 | self.loss.append(preds[1].asnumpy()[0]) 110 | if len(self.loss)==100: 111 | mmd_loss = sum(self.loss)/len(self.loss) 112 | print('mmd_loss', mmd_loss) 113 | self.loss = [] 114 | #print('preds',len(preds)) 115 | sys.stdout.flush() 116 | for label, pred_label in zip(labels, preds): 117 | #debug = mx.nd.slice_axis(pred_label, axis=0, begin=0, end=1) 118 | #debug = debug.asnumpy().flatten() 119 | #print(debug[0:20]) 120 | if pred_label.shape != label.shape: 121 | pred_label = mx.ndarray.argmax(pred_label, axis=self.axis) 122 | pred_label = pred_label.asnumpy().astype('int32') 123 | label = label.asnumpy().astype('int32') 124 | assert label.shape[0]==pred_label.shape[0] 125 | if len(label.shape)>1: 126 | label = label[:,0].flatten() 127 | pred_label = pred_label.flatten() 128 | label = label.flatten() 129 | stat = [0,0] 130 | for i in xrange(len(label)): 131 | if label[i]==9999: 132 | continue 133 | stat[1]+=1 134 | if label[i]==pred_label[i]: 135 | stat[0]+=1 136 | #print(label) 137 | #print(pred_label) 138 | #print(label.shape) 139 | 140 | #check_label_shapes(label, pred_label) 141 | 142 | #print('eval_stat', stat) 143 | self.sum_metric += stat[0] 144 | self.num_inst += stat[1] 145 | 146 | def fit(args, network, data_loader, **kwargs): 147 | """ 148 | train a model 149 | args : argparse returns 150 | network : the symbol definition of the nerual network 151 | data_loader : function that returns the train and val data iterators 152 | """ 153 | # kvstore 154 | kv = mx.kvstore.create(args.kv_store) 155 | 156 | # logging 157 | head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s' 158 | logging.basicConfig(level=logging.DEBUG, format=head) 159 | logging.info('start with arguments %s', args) 160 | 161 | # data iterators 162 | (train, val) = data_loader(args, kv) 163 | #print train.provide_label 164 | if args.test_io: 165 | tic = time.time() 166 | for i, batch in enumerate(train): 167 | for j in batch.data: 168 | j.wait_to_read() 169 | if (i+1) % args.disp_batches == 0: 170 | logging.info('Batch [%d]\tSpeed: %.2f samples/sec' % ( 171 | i, args.disp_batches*args.batch_size/(time.time()-tic))) 172 | tic = time.time() 173 | 174 | return 175 | 176 | 177 | # load model 178 | if 'arg_params' in kwargs and 'aux_params' in kwargs: 179 | arg_params = kwargs['arg_params'] 180 | aux_params = kwargs['aux_params'] 181 | else: 182 | sym, arg_params, aux_params = _load_model(args, kv.rank) 183 | if sym is not None: 184 | assert sym.tojson() == network.tojson() 185 | 186 | args.epoch_id = -1 187 | 188 | # save model 189 | checkpoint = _save_model(args, kv.rank) 190 | if args.no_checkpoint: 191 | checkpoint = None 192 | 193 | # devices for training 194 | devs = mx.cpu() if args.gpus is None or args.gpus is '' else [ 195 | mx.gpu(int(i)) for i in args.gpus.split(',')] 196 | 197 | # learning rate 198 | lr, lr_scheduler = _get_lr_scheduler(args, kv) 199 | fixed_param_names = None 200 | if 'fixed_param_names' in kwargs: 201 | fixed_param_names = kwargs['fixed_param_names'] 202 | 203 | if fixed_param_names is not None: 204 | print('fixed_param_names', len(fixed_param_names)) 205 | 206 | # create model 207 | model = mx.mod.Module( 208 | context = devs, 209 | symbol = network, 210 | fixed_param_names = fixed_param_names, 211 | #label_names = ['softmax_label', 'softmax2_label'] 212 | ) 213 | 214 | lr_scheduler = lr_scheduler 215 | optimizer_params = { 216 | 'learning_rate': lr, 217 | 'momentum' : args.mom, 218 | 'wd' : args.wd, 219 | 'lr_scheduler': lr_scheduler} 220 | 221 | monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None 222 | 223 | if args.network == 'alexnet': 224 | # AlexNet will not converge using Xavier 225 | initializer = mx.init.Normal() 226 | else: 227 | initializer = mx.init.Xavier( 228 | rnd_type='gaussian', factor_type="in", magnitude=2) 229 | # initializer = mx.init.Xavier(factor_type="in", magnitude=2.34), 230 | 231 | # evaluation metrices 232 | #eval_metrics = ['accuracy'] 233 | metric = JAccuracy() 234 | eval_metrics = [mx.metric.create(metric)] 235 | 236 | # callbacks that run after each batch 237 | batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)] 238 | if 'batch_end_callback' in kwargs: 239 | cbs = kwargs['batch_end_callback'] 240 | batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs] 241 | 242 | # run 243 | model.fit(train, 244 | begin_epoch = args.load_epoch if args.load_epoch else 0, 245 | num_epoch = args.num_epochs, 246 | eval_data = val, 247 | eval_metric = eval_metrics, 248 | kvstore = kv, 249 | optimizer = args.optimizer, 250 | optimizer_params = optimizer_params, 251 | initializer = initializer, 252 | arg_params = arg_params, 253 | aux_params = aux_params, 254 | batch_end_callback = batch_end_callbacks, 255 | epoch_end_callback = checkpoint, 256 | allow_missing = True, 257 | monitor = monitor) 258 | 259 | -------------------------------------------------------------------------------- /common/data.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import random 3 | import sys 4 | from mxnet.io import DataBatch, DataIter 5 | import numpy as np 6 | 7 | def add_data_args(parser): 8 | data = parser.add_argument_group('Data', 'the input images') 9 | data.add_argument('--data-dir', type=str, default='./data', 10 | help='the data dir') 11 | #data.add_argument('--data-train', type=str, help='the training data') 12 | #data.add_argument('--data-val', type=str, help='the validation data') 13 | data.add_argument('--rgb-mean', type=str, default='123.68,116.779,103.939', 14 | help='a tuple of size 3 for the mean rgb') 15 | data.add_argument('--pad-size', type=int, default=0, 16 | help='padding the input image') 17 | data.add_argument('--image-shape', type=str, 18 | help='the image shape feed into the network, e.g. (3,224,224)') 19 | data.add_argument('--num-classes', type=int, help='the number of classes') 20 | #data.add_argument('--num-examples', type=int, help='the number of training examples') 21 | data.add_argument('--data-nthreads', type=int, default=4, 22 | help='number of threads for data decoding') 23 | data.add_argument('--benchmark', type=int, default=0, 24 | help='if 1, then feed the network with synthetic data') 25 | data.add_argument('--dtype', type=str, default='float32', 26 | help='data type: float32 or float16') 27 | return data 28 | 29 | def add_data_aug_args(parser): 30 | aug = parser.add_argument_group( 31 | 'Image augmentations', 'implemented in src/io/image_aug_default.cc') 32 | aug.add_argument('--random-crop', type=int, default=1, 33 | help='if or not randomly crop the image') 34 | aug.add_argument('--random-mirror', type=int, default=1, 35 | help='if or not randomly flip horizontally') 36 | aug.add_argument('--max-random-h', type=int, default=0, 37 | help='max change of hue, whose range is [0, 180]') 38 | aug.add_argument('--max-random-s', type=int, default=0, 39 | help='max change of saturation, whose range is [0, 255]') 40 | aug.add_argument('--max-random-l', type=int, default=0, 41 | help='max change of intensity, whose range is [0, 255]') 42 | aug.add_argument('--max-random-aspect-ratio', type=float, default=0, 43 | help='max change of aspect ratio, whose range is [0, 1]') 44 | aug.add_argument('--max-random-rotate-angle', type=int, default=0, 45 | help='max angle to rotate, whose range is [0, 360]') 46 | aug.add_argument('--max-random-shear-ratio', type=float, default=0, 47 | help='max ratio to shear, whose range is [0, 1]') 48 | aug.add_argument('--max-random-scale', type=float, default=1, 49 | help='max ratio to scale') 50 | aug.add_argument('--min-random-scale', type=float, default=1, 51 | help='min ratio to scale, should >= img_size/input_shape. otherwise use --pad-size') 52 | return aug 53 | 54 | def set_data_aug_level(aug, level): 55 | if level >= 1: 56 | aug.set_defaults(random_crop=1, random_mirror=1) 57 | if level >= 2: 58 | aug.set_defaults(max_random_h=36, max_random_s=50, max_random_l=50) 59 | if level >= 3: 60 | aug.set_defaults(max_random_rotate_angle=10, max_random_shear_ratio=0.1, max_random_aspect_ratio=0.25) 61 | 62 | 63 | 64 | class StIter(mx.io.DataIter): 65 | def __init__(self, source_iter, target_iter, data_shape, batch_size, gpu_num, data_name='data', label_name='softmax_label'): 66 | #self.iters = iter_list 67 | self.source_iter = source_iter 68 | self.target_iter = target_iter 69 | self.data_name = data_name 70 | self.label_name = label_name 71 | self.data_shape = data_shape 72 | self.batch_size = batch_size 73 | self.gpu_num = gpu_num 74 | #print(source_iter.provide_data) 75 | self.provide_data = [(self.data_name, (batch_size,) + self.data_shape)] 76 | self.provide_label = [(self.label_name, (batch_size,) )] 77 | self.bs_in_each_part = self.batch_size/2/gpu_num 78 | 79 | 80 | def part_slice(self, d, i): 81 | _d = mx.ndarray.slice_axis(d, axis=0, begin=i*self.bs_in_each_part, end=(i+1)*self.bs_in_each_part) 82 | return _d 83 | 84 | def next(self): 85 | source_batch = self.source_iter.next() 86 | target_batch = self.target_iter.next() 87 | source_data = source_batch.data[0] 88 | target_data = target_batch.data[0] 89 | source_label = source_batch.label[0] 90 | target_label = target_batch.label[0] 91 | data_parts = [] 92 | label_parts = [] 93 | for i in xrange(self.gpu_num): 94 | _source_data = self.part_slice(source_data, i) 95 | _target_data = self.part_slice(target_data, i) 96 | _source_label = self.part_slice(source_label, i) 97 | _target_label = self.part_slice(target_label, i) 98 | #_target_label = mx.ndarray.ones_like(_target_label)*9999 99 | data_parts.append(_source_data) 100 | data_parts.append(_target_data) 101 | label_parts.append(_source_label) 102 | label_parts.append(_target_label) 103 | self.data = mx.ndarray.concat(*data_parts, dim=0) 104 | self.label = mx.ndarray.concat(*label_parts, dim=0) 105 | #self.data = mx.ndarray.concat( source_data, target_data, dim=0) 106 | #print(self.data.shape) 107 | #self.label = mx.ndarray.concat(source_batch.label[0], target_batch.label[0], dim=0) 108 | #provide_data = [(self.data_name, self.data.shape)] 109 | #provide_label = [(self.label_name, self.label.shape)] 110 | #print(self.label.shape) 111 | #return {self.data_name : [self.data], 112 | # self.label_name : [self.label]} 113 | return DataBatch(data=(self.data,), 114 | label=(self.label,)) 115 | 116 | def reset(self): 117 | self.source_iter.reset() 118 | self.target_iter.reset() 119 | #@property 120 | #def provide_data(self): 121 | # #return [(k, self.data.shape)] 122 | # return [mx.io.DataDesc(self.data_name, self.data.shape, self.data.dtype)] 123 | #@property 124 | #def provide_label(self): 125 | # #return [(k, self.label.shape)] 126 | # return [mx.io.DataDesc(self.label_name, self.label.shape, self.label.dtype)] 127 | 128 | def read_lst(file): 129 | ret = [] 130 | with open(file, 'r') as f: 131 | for line in f: 132 | vec = line.strip().split("\t") 133 | label = int(vec[1]) 134 | img_path = vec[2] 135 | ret.append( [label, img_path] ) 136 | return ret 137 | 138 | def get_rec_iter(args, kv=None): 139 | image_shape = tuple([int(l) for l in args.image_shape.split(',')]) 140 | dtype = np.float32; 141 | if 'dtype' in args: 142 | if args.dtype == 'float16': 143 | dtype = np.float16 144 | if kv: 145 | (rank, nworker) = (kv.rank, kv.num_workers) 146 | else: 147 | (rank, nworker) = (0, 1) 148 | rgb_mean = [float(i) for i in args.rgb_mean.split(',')] 149 | #print(rank, nworker, args.batch_size) 150 | train_resize = int(image_shape[1]*1.6) 151 | if args.train_stage==0: 152 | imglist = read_lst(args.data_dir+"/source.lst") 153 | print(len(imglist)) 154 | assert len(imglist)>0 155 | imglist2 = read_lst(args.data_dir+"/target.lst") 156 | assert len(imglist2)>0 157 | imglist2 = [x for x in imglist2 if x[0]!=args.null_label] 158 | imglist += imglist2 159 | print(len(imglist)) 160 | train = mx.img.ImageIter( 161 | label_width = args.label_width, 162 | path_root = '', 163 | #path_imglist = args.data_dir+"/train_part.lst", 164 | imglist = imglist, 165 | data_shape = image_shape, 166 | batch_size = args.batch_size, 167 | resize = train_resize, 168 | rand_crop = True, 169 | rand_resize = True, 170 | rand_mirror = True, 171 | shuffle = True, 172 | brightness = 0.4, 173 | contrast = 0.4, 174 | saturation = 0.4, 175 | pca_noise = 0.1, 176 | num_parts = nworker, 177 | part_index = rank) 178 | else: 179 | source = mx.img.ImageIter( 180 | label_width = args.label_width, 181 | path_root = '', 182 | path_imglist = args.data_dir+"/source.lst", 183 | data_shape = image_shape, 184 | batch_size = args.batch_size/2, 185 | resize = train_resize, 186 | rand_crop = True, 187 | rand_resize = True, 188 | rand_mirror = True, 189 | shuffle = True, 190 | brightness = 0.4, 191 | contrast = 0.4, 192 | saturation = 0.4, 193 | pca_noise = 0.1, 194 | data_name = 'data_source', 195 | label_name = 'label_source', 196 | num_parts = nworker, 197 | part_index = rank) 198 | target_file = args.data_dir+"/target.lst" 199 | target = mx.img.ImageIter( 200 | label_width = args.label_width, 201 | path_root = '', 202 | path_imglist = target_file, 203 | data_shape = image_shape, 204 | batch_size = args.batch_size/2, 205 | resize = train_resize, 206 | rand_crop = True, 207 | rand_resize = True, 208 | rand_mirror = True, 209 | shuffle = True, 210 | brightness = 0.4, 211 | contrast = 0.4, 212 | saturation = 0.4, 213 | pca_noise = 0.1, 214 | data_name = 'data_target', 215 | label_name = 'label_target', 216 | num_parts = nworker, 217 | part_index = rank) 218 | train = StIter(source, target, image_shape, args.batch_size, args.gpu_num) 219 | val = mx.img.ImageIter( 220 | label_width = args.label_width, 221 | path_root = '', 222 | path_imglist = args.data_dir+"/val.lst", 223 | batch_size = args.batch_size, 224 | data_shape = image_shape, 225 | resize = int(image_shape[1]*1.125), 226 | rand_crop = False, 227 | rand_resize = False, 228 | rand_mirror = False, 229 | num_parts = nworker, 230 | part_index = rank) 231 | return (train, val) 232 | 233 | def test_st(): 234 | image_shape = (3,320,320) 235 | source = mx.img.ImageIter( 236 | label_width = 1, 237 | #path_root = 'data/', 238 | #path_imglist = args.data_train, 239 | path_imgrec = './data-asv/train.rec', 240 | path_imgidx = './data-asv/train.idx', 241 | #data_shape = (3, 320, 320), 242 | data_shape = image_shape, 243 | batch_size = 16/2, 244 | rand_crop = True, 245 | rand_resize = True, 246 | rand_mirror = True, 247 | shuffle = True, 248 | brightness = 0.4, 249 | contrast = 0.4, 250 | saturation = 0.4, 251 | pca_noise = 0.1, 252 | data_name = 'data_source', 253 | label_name = 'label_source', 254 | num_parts = 1, 255 | part_index = 0) 256 | target = mx.img.ImageIter( 257 | label_width = 1, 258 | #path_root = 'data/', 259 | #path_imglist = args.data_train, 260 | path_imgrec = './data-asv/val.rec', 261 | path_imgidx = './data-asv/val.idx', 262 | #data_shape = (3, 320, 320), 263 | data_shape = image_shape, 264 | batch_size = 16/2, 265 | rand_crop = True, 266 | rand_resize = True, 267 | rand_mirror = True, 268 | shuffle = True, 269 | brightness = 0.4, 270 | contrast = 0.4, 271 | saturation = 0.4, 272 | pca_noise = 0.1, 273 | data_name = 'data_target', 274 | label_name = 'label_target', 275 | num_parts = 1, 276 | part_index = 0) 277 | train = StIter(source, target, image_shape, 64, 4) 278 | for batch in train: 279 | data = batch.data 280 | label = batch.label 281 | print(data) 282 | print(label) 283 | #print(label.asnumpy()) 284 | 285 | 286 | if __name__ == '__main__': 287 | test_st() 288 | 289 | 290 | --------------------------------------------------------------------------------