├── .github
    └── ISSUE_TEMPLATE.md
├── .gitignore
├── README.md
├── data.py
├── eval_knn.py
├── main_lincls.py
├── main_moco.py
├── resnet.py
├── serve-data.py
└── tox.ini


/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
The only goal of this project is to reproduce the results in the papers.
We do not take feature requests or questions that are unrelated to this goal.

If you encounter an unexpected problem when using the code,
please include the following in your issue:

* What you did: the command you ran.

* What you observed: the full logs and other relevant information.

* What you expected, if not obvious.

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# tensorpack-specific stuff
fake_train_log
train_log
train_log_*
logs
*.npy
*.npz
*.caffemodel
*.tfmodel
*.meta
*.log*
*.bin
*.png
*.jpg
checkpoint
*.json
*.prototxt
*.txt
*.tgz
*.gz

# my personal stuff
snippet
examples/private
TODO.md
.gitignore
.vimrc.local


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/
*.dat

.idea/
*.diff

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

Implement and __reproduce__ results of the following papers:

* [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/abs/1911.05722)
* [Improved Baselines with Momentum Contrastive Learning](https://arxiv.org/abs/2003.04297)

## Dependencies:

* TensorFlow 1.14 or 1.15, built with XLA support
* [Tensorpack](https://github.com/tensorpack/tensorpack/) ≥ 0.10.1
* [Horovod](https://github.com/horovod/horovod) ≥ 0.19 built with Gloo & NCCL support
* TensorFlow [zmq_ops](https://github.com/tensorpack/zmq_ops)
* OpenCV
* the `taskset` command (from the `util-linux` package)

## Unsupervised Training:

To run MoCo pre-training on a machine with 8 GPUs, use:
```
horovodrun -np 8 --output-filename moco.log python main_moco.py --data /path/to/imagenet
```

Add `--v2` to train MoCov2,
which uses an extra MLP layer, extra augmentations, and a cosine LR schedule.
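For reference, the cosine schedule enabled by `--v2` (in `MOCOModel.optimizer` of `main_moco.py`) anneals the learning rate from `BASE_LR = 0.03` to 0 over 200 epochs of `1281167 // batch` steps, with no warmup and no batch-size scaling. The snippet below is a minimal plain-Python sketch of that formula, handy for inspecting the schedule offline; `cosine_lr` is a hypothetical helper and is not part of this repo.

```python
import math

# Constants copied from main_moco.py; cosine_lr itself is a hypothetical helper.
BASE_LR = 0.03
NUM_TRAIN_IMAGES = 1281167  # ImageNet-1k training images
NUM_EPOCHS = 200


def cosine_lr(step, total_batch=256, base_lr=BASE_LR, epochs=NUM_EPOCHS):
    """Learning rate at a global step under the v2 cosine schedule:
    lr = base_lr * 0.5 * (1 + cos(pi * step / total_steps))."""
    total_steps = NUM_TRAIN_IMAGES // total_batch * epochs
    return base_lr * 0.5 * (1 + math.cos(step / total_steps * math.pi))


if __name__ == "__main__":
    total = NUM_TRAIN_IMAGES // 256 * NUM_EPOCHS
    # Print the LR at a few points of a default (batch 256) run.
    for frac in (0.0, 0.25, 0.5, 0.75, 1.0):
        step = int(frac * total)
        print(f"step {step:>8d}: lr = {cosine_lr(step):.5f}")
```

Because the schedule covers half a cosine period, the LR reaches 0 exactly at the last step. Unlike the step-wise v1 schedule, the v2 path does not rescale `BASE_LR` by `batch / 256`.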


## Linear Classification:
To train a linear classifier using the pre-trained features, run:
```
./main_lincls.py --load /path/to/pretrained/checkpoint --data /path/to/imagenet
```

## KNN Evaluation:
Instead of linear classification, a cheaper but rougher evaluation
is to run a feature-space kNN against the training set:
```
horovodrun -np 8 ./eval_knn.py --load /path/to/checkpoint --data /path/to/imagenet --top-k 200
```

## Results:
Training was done on a machine with 8 V100s, >200GB RAM and 80 CPUs.

The following results were obtained after
200 epochs of pre-training (~53h)
and 100 epochs of linear classifier tuning (~8h).
KNN evaluation takes 10 min per checkpoint.

| | linear cls.<br>accuracy | download<br>(pretrained only) | tensorboard |
| - | :-: | :-: | :-: |
| MoCo v1 | 60.9% | [:arrow_down:](https://github.com/ppwwyyxx/moco.tensorflow/releases/download/v/MoCo_v1.npz) | N/A |
| MoCo v2 | 67.7% | [:arrow_down:](https://github.com/ppwwyyxx/moco.tensorflow/releases/download/v/MoCo_v2.npz) | [pretrain](https://tensorboard.dev/experiment/MBL49FKLTLCbKGr7JMolWQ); [finetune](https://tensorboard.dev/experiment/s3ZOxbjbRCy3hMqgL0TzKQ) |

## Notes:

* Horovod with Gloo is recommended. Horovod with MPI is not tested and may crash due to how we use forking.
* If using TensorFlow without XLA support, you can modify `main_*.py` to replace `xla.compile` with a plain forward pass (the commented-out line right after each `xla.compile` call).
* The official PyTorch code is at [facebookresearch/moco](https://github.com/facebookresearch/moco).

--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import numpy as np
import cv2
import multiprocessing as mp
import tensorflow as tf

from tensorpack.dataflow import (
    BatchData, MultiProcessMapAndBatchDataZMQ, MultiProcessRunnerZMQ, MultiThreadMapData, dataset,
    imgaug)


cv2.setNumThreads(0)


def get_moco_v1_augmentor():
    augmentors = [
        imgaug.GoogleNetRandomCropAndResize(crop_area_fraction=(0.2, 1.)),
        imgaug.RandomApplyAug(imgaug.Grayscale(rgb=False, keepshape=True), 0.2),
        imgaug.ToFloat32(),
        imgaug.RandomOrderAug(
            [imgaug.BrightnessScale((0.6, 1.4)),
             imgaug.Contrast((0.6, 1.4), rgb=False),
             imgaug.Saturation(0.4, rgb=False),
             # 72 = 180*0.4
             imgaug.Hue(range=(-72, 72), rgb=False)
             ]),
        imgaug.ToUint8(),
        imgaug.Flip(horiz=True),
    ]
    return augmentors


def get_moco_v2_augmentor():
    augmentors = [
        imgaug.GoogleNetRandomCropAndResize(crop_area_fraction=(0.2, 1.)),
        imgaug.ToFloat32(),
        imgaug.RandomApplyAug(
            imgaug.RandomOrderAug(
                [imgaug.BrightnessScale((0.6, 1.4)),
                 imgaug.Contrast((0.6, 1.4), rgb=False),
                 imgaug.Saturation(0.4, rgb=False),
                 # 18 = 180*0.1
                 imgaug.Hue(range=(-18, 18), rgb=False)
                 ]), 0.8),
        imgaug.RandomApplyAug(imgaug.Grayscale(rgb=False, keepshape=True), 0.2),
        imgaug.RandomApplyAug(
            # 11 = 0.1*224//2
            imgaug.GaussianBlur(size_range=(11, 12), sigma_range=[0.1, 2.0]), 0.5),
        imgaug.ToUint8(),
        imgaug.Flip(horiz=True),
    ]
    return augmentors


class MoCoMapper:
    def __init__(self, augs):
        self.augs = augs

    def __call__(self, dp):
        fname, _ = dp  # throw away the label
        img = cv2.imread(fname)
        img1 = self.augs.augment(img)
        img2 = self.augs.augment(img)
        return [img1, img2]


def get_moco_dataflow(datadir, batch_size, augmentors):
    """
    Dataflow for training MOCO.
    """
    augmentors = imgaug.AugmentorList(augmentors)
    parallel = min(30, mp.cpu_count())  # tuned on a 40-CPU 80-core machine
    ds = dataset.ILSVRC12Files(datadir, 'train', shuffle=True)
    ds = MultiProcessMapAndBatchDataZMQ(ds, parallel, MoCoMapper(augmentors), batch_size, buffer_size=5000)
    return ds


def get_basic_augmentor(isTrain):
    interpolation = cv2.INTER_LINEAR
    if isTrain:
        augmentors = [
            imgaug.GoogleNetRandomCropAndResize(),
            imgaug.Flip(horiz=True),
        ]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256, interp=interpolation),
            imgaug.CenterCrop((224, 224)),
        ]
    return augmentors


def get_imagenet_dataflow(datadir, name, batch_size, parallel=None):
    """
    Get a standard imagenet training/evaluation dataflow, for linear classifier tuning.
    """
    assert name in ['train', 'val']
    isTrain = name == 'train'
    assert datadir is not None
    augmentors = get_basic_augmentor(isTrain)
    augmentors = imgaug.AugmentorList(augmentors)
    if parallel is None:
        parallel = min(50, mp.cpu_count())

    def mapper(dp):
        fname, label = dp
        img = cv2.imread(fname)
        img = augmentors.augment(img)
        return img, label

    if isTrain:
        ds = dataset.ILSVRC12Files(datadir, name, shuffle=True)
        ds = MultiProcessMapAndBatchDataZMQ(ds, parallel, mapper, batch_size, buffer_size=7000)
    else:
        ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
        ds = MultiThreadMapData(ds, parallel, mapper, buffer_size=2000, strict=True)
        ds = BatchData(ds, batch_size, remainder=True)
        ds = MultiProcessRunnerZMQ(ds, 1)
    return ds


def tf_preprocess(image):  # normalize BGR images
    with tf.name_scope('image_preprocess'):
        if image.dtype.base_dtype != tf.float32:
            image = tf.cast(image, tf.float32)
        mean = [0.485, 0.456, 0.406]  # rgb
        std = [0.229, 0.224, 0.225]
        mean = mean[::-1]
        std = std[::-1]
        image_mean = tf.constant(mean, dtype=tf.float32) * 255.
        image_std = tf.constant(std, dtype=tf.float32) * 255.
        image = (image - image_mean) / image_std
        return image


if __name__ == '__main__':
    from tensorpack.dataflow import TestDataSpeed
    import sys
    df = get_imagenet_dataflow(sys.argv[1], 'train', 32)

    TestDataSpeed(df, size=99999999, warmup=300).start()

--------------------------------------------------------------------------------
/eval_knn.py:
--------------------------------------------------------------------------------
#-*- coding: utf-8 -*-

import argparse
import os
import cv2
import tqdm
from collections import Counter

from tensorpack import tfv1 as tf
from tensorpack.utils.stats import Accuracy
from tensorpack.utils import logger
from tensorpack.tfutils import TowerContext, get_default_sess_config
from tensorpack.tfutils.sessinit import SmartInit
from tensorpack.tfutils.varmanip import get_checkpoint_path, get_all_checkpoints
from tensorpack.dataflow import (
    imgaug, DataFromList, BatchData, MultiProcessMapDataZMQ, dataset)

import horovod.tensorflow as hvd
from resnet import ResNetModel
from data import get_basic_augmentor, get_imagenet_dataflow


def build_dataflow(files):
    train_ds = DataFromList(files)
    aug = imgaug.AugmentorList(get_basic_augmentor(isTrain=False))

    def mapper(dp):
        idx, fname, label = dp
        img = cv2.imread(fname)
        img = aug.augment(img)
        return img, idx

    train_ds = MultiProcessMapDataZMQ(train_ds, num_proc=8, map_func=mapper, strict=True)
    train_ds = BatchData(train_ds, local_batch_size)
    train_ds.reset_state()
    return train_ds


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', help='imagenet data dir')
    parser.add_argument('--batch', default=512, type=int, help='total batch size')
    parser.add_argument('--load', required=True, help='file or directory to evaluate')
    parser.add_argument('--top-k', type=int, default=200, help='top-k in KNN')
    parser.add_argument('--v2', action='store_true', help='use mocov2')
    args = parser.parse_args()

    hvd.init()
    local_batch_size = args.batch // hvd.size()

    train_files = dataset.ILSVRC12Files(args.data, 'train', shuffle=True)
    train_files.reset_state()
    all_train_files = list(train_files)
    all_train_files = all_train_files[:len(all_train_files) // args.batch * args.batch]  # truncate
    num_train_images = len(all_train_files)
    logger.info(f"Creating graph for KNN of {num_train_images} training images ...")
    local_train_files = [(idx, fname, label) for idx, (fname, label) in
                         enumerate(all_train_files) if idx % hvd.size() == hvd.rank()]

    image_input = tf.placeholder(tf.uint8, [None, 224, 224, 3], "image")
    idx_input = tf.placeholder(tf.int64, [None], "image_idx")

    feat_buffer = tf.get_variable("feature_buffer", shape=[num_train_images, 128], trainable=False)
    net = ResNetModel(num_output=(2048, 128) if args.v2 else (128,))
    with TowerContext("", is_training=False):
        feat = net.forward(image_input)
        feat = tf.math.l2_normalize(feat, axis=1)  # Nx128
    all_feat = hvd.allgather(feat)  # GN x 128
    all_idx_input = hvd.allgather(idx_input)  # GN
    update_buffer = tf.scatter_update(feat_buffer, all_idx_input, all_feat)

    dist = tf.matmul(feat, tf.transpose(feat_buffer))  # N x #DS
    _, topk_indices = tf.math.top_k(dist, k=args.top_k)  # Nxtopk

    train_ds = build_dataflow(local_train_files)

    config = get_default_sess_config()
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    def evaluate(checkpoint_file):
        result_file = get_checkpoint_path(checkpoint_file) + f".knn{args.top_k}.txt"
        if os.path.isfile(result_file):
            logger.info(f"Skipping evaluation of {result_file}.")
            return
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            SmartInit(checkpoint_file).init(sess)
            for batch_img, batch_idx in tqdm.tqdm(train_ds, total=len(train_ds)):
                sess.run(update_buffer,
                         feed_dict={image_input: batch_img, idx_input: batch_idx})

            if hvd.rank() == 0:
                acc = Accuracy()
                val_df = get_imagenet_dataflow(args.data, "val", local_batch_size)
                val_df.reset_state()

                for batch_img, batch_label in val_df:
                    topk_indices_pred = sess.run(topk_indices, feed_dict={image_input: batch_img})
                    for indices, gt in zip(topk_indices_pred, batch_label):
                        pred = [all_train_files[k][1] for k in indices]
                        top_pred = Counter(pred).most_common(1)[0]
                        acc.feed(top_pred[0] == gt, total=1)
                logger.info(f"Accuracy of {checkpoint_file}: {acc.accuracy} out of {acc.total}")
                with open(result_file, "w") as f:
                    f.write(str(acc.accuracy))

    if os.path.isdir(args.load):
        for fname, _ in get_all_checkpoints(args.load):
            logger.info(f"Evaluating {fname} ...")
            evaluate(fname)
    else:
        evaluate(args.load)
--------------------------------------------------------------------------------
/main_lincls.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import tensorflow as tf

from tensorpack.callbacks import (
    ClassificationError, DataParallelInferenceRunner, EstimatedTimeLeft, InferenceRunner,
    ModelSaver, ScheduledHyperParamSetter, ThroughputTracker)
from tensorpack.dataflow import FakeData
from tensorpack.input_source import QueueInput
from tensorpack.models import BatchNorm, FullyConnected
from tensorpack.tfutils import SaverRestore, argscope, varreplace, SmartInit
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.train import (
    ModelDesc, SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config)
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu

from data import get_imagenet_dataflow
from resnet import ResNetModel


class LinearModel(ModelDesc):
    def __init__(self):
        self.net = ResNetModel(num_output=None)
        self.image_shape = 224

    def inputs(self):
        return [tf.TensorSpec([None, self.image_shape, self.image_shape, 3], tf.uint8, 'input'),
                tf.TensorSpec([None], tf.int32, 'label')]

    def compute_loss_and_error(self, logits, label):
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
        loss = tf.reduce_mean(loss, name='xentropy-loss')

        def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
            with tf.name_scope('prediction_incorrect'):
                x = tf.logical_not(tf.nn.in_top_k(logits, label, topk))
            return tf.cast(x, tf.float32, name=name)

        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))

        wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
        return loss

    def build_graph(self, input, label):
        with argscope(BatchNorm, training=False), \
                varreplace.freeze_variables(skip_collection=True):
            from tensorflow.python.compiler.xla import xla
            feature = xla.compile(lambda: self.net.forward(input))[0]
            # feature = self.net.forward(input)
            feature = tf.stop_gradient(feature)  # double safe
        logits = FullyConnected(
            'linear_cls', feature, 1000,
            kernel_initializer=tf.random_normal_initializer(stddev=0.01),
            bias_initializer=tf.constant_initializer())

        tf.nn.softmax(logits, name='prob')
        loss = self.compute_loss_and_error(logits, label)

        # weight decay is 0
        add_moving_summary(loss)
        return loss

    def optimizer(self):
        lr = tf.get_variable('learning_rate', initializer=0.0, trainable=False)
        tf.summary.scalar('learning_rate-summary', lr)
        opt = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=False)
        return opt


def get_config(model):
    nr_tower = max(get_num_gpu(), 1)
    batch = args.batch // nr_tower

    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))

    callbacks = [ThroughputTracker(args.batch)]
    if args.fake:
        data = QueueInput(FakeData(
            [[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8'))
    else:
        data = QueueInput(
            get_imagenet_dataflow(args.data, 'train', batch),
            # use a larger queue
            queue=tf.FIFOQueue(200, [tf.uint8, tf.int32], [[batch, 224, 224, 3], [batch]])
        )

    BASE_LR = 30
    SCALED_LR = BASE_LR * (args.batch / 256.0)
    callbacks.extend([
        ModelSaver(),
        EstimatedTimeLeft(),
        ScheduledHyperParamSetter(
            'learning_rate', [
                (0, SCALED_LR),
                (60, SCALED_LR * 1e-1),
                (70, SCALED_LR * 1e-2),
                (80, SCALED_LR * 1e-3),
                (90, SCALED_LR * 1e-4),
            ]),
    ])

    dataset_val = get_imagenet_dataflow(args.data, 'val', 64)
    infs = [ClassificationError('wrong-top1', 'val-error-top1'),
            ClassificationError('wrong-top5', 'val-error-top5')]
    if nr_tower == 1:
        callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
    else:
        callbacks.append(DataParallelInferenceRunner(
            dataset_val, infs, list(range(nr_tower))))

    if args.load.endswith(".npz"):
        # a released model in npz format
        init = SmartInit(args.load)
    else:
        # a pre-trained checkpoint
        init = SaverRestore(args.load, ignore=("learning_rate", "global_step"))
    return TrainConfig(
        model=model,
        data=data,
        callbacks=callbacks,
        steps_per_epoch=100 if args.fake else 1281167 // args.batch,
        session_init=init,
        max_epoch=100,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', help='imagenet data dir')
    parser.add_argument('--load', required=True, help='path to pre-trained model')
    parser.add_argument('--fake', help='use FakeData to debug or benchmark this model', action='store_true')
    parser.add_argument('--batch', default=256, type=int, help='total batch size')
    parser.add_argument('--logdir')
    args = parser.parse_args()

    model = LinearModel()

    if args.fake:
        logger.set_logger_dir('fake_train_log', 'd')
    else:
        if args.logdir is None:
            args.logdir = './moco_lincls'
        logger.set_logger_dir(args.logdir)

    config = get_config(model)
    trainer = SyncMultiGPUTrainerReplicated(get_num_gpu())
    launch_train_with_config(config, trainer)

--------------------------------------------------------------------------------
/main_moco.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import numpy as np
import os
import subprocess
import tensorflow as tf
from tensorflow.python.compiler.xla import xla

from tensorpack.callbacks import (
    Callback, EstimatedTimeLeft, ModelSaver, ScheduledHyperParamSetter, ThroughputTracker)
from tensorpack.dataflow import FakeData
from tensorpack.input_source import QueueInput, TFDatasetInput, ZMQInput
from tensorpack.models import BatchNorm, l2_regularizer, regularize_cost
from tensorpack.tfutils import argscope, varreplace
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.train import (
    HorovodTrainer, ModelDesc, TrainConfig, launch_train_with_config)
from tensorpack.utils import logger

import horovod.tensorflow as hvd
from resnet import ResNetModel

BASE_LR = 0.03


def allgather(tensor, name):
    tensor = tf.identity(tensor, name=name + "_HVD")
    return hvd.allgather(tensor)


def batch_shuffle(tensor):  # nx...
    total, rank = hvd.size(), hvd.rank()
    batch_size = tf.shape(tensor)[0]
    with tf.device('/cpu:0'):
        all_idx = tf.range(total * batch_size)
        shuffle_idx = tf.random.shuffle(all_idx)
        shuffle_idx = hvd.broadcast(shuffle_idx, 0)
        my_idxs = tf.slice(shuffle_idx, [rank * batch_size], [batch_size])

    all_tensor = allgather(tensor, 'batch_shuffle_key')  # gn x ...
    return tf.gather(all_tensor, my_idxs), shuffle_idx


def batch_unshuffle(key_feat, shuffle_idxs):
    rank = hvd.rank()
    inv_shuffle_idx = tf.argsort(shuffle_idxs)
    batch_size = tf.shape(key_feat)[0]
    my_idxs = tf.slice(inv_shuffle_idx, [rank * batch_size], [batch_size])
    all_key_feat = allgather(key_feat, "batch_unshuffle_feature")  # gn x c
    return tf.gather(all_key_feat, my_idxs)


class MOCOModel(ModelDesc):
    def __init__(self, batch_size, feature_dims=(128,), temp=0.07):
        self.batch_size = batch_size
        self.feature_dim = feature_dims[-1]
        # NOTE: implicitly assumes queue_size % (batch_size * #GPUs) == 0
        self.queue_size = 65536
        self.temp = temp

        self.net = ResNetModel(num_output=feature_dims)
        self.image_shape = 224

    def inputs(self):
        return [tf.TensorSpec([self.batch_size, self.image_shape, self.image_shape, 3], tf.uint8, 'query'),
                tf.TensorSpec([self.batch_size, self.image_shape, self.image_shape, 3], tf.uint8, 'key')]

    def build_graph(self, query, key):
        # setup queue
        queue_init = tf.math.l2_normalize(
            tf.random.normal([self.queue_size, self.feature_dim]), axis=1)
        queue = tf.get_variable('queue', initializer=queue_init, trainable=False)
        queue_ptr = tf.get_variable(
            'queue_ptr',
            [], initializer=tf.zeros_initializer(),
            dtype=tf.int64, trainable=False)
        tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, queue)
        tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, queue_ptr)

        # query encoder
        q_feat = self.net.forward(query)  # NxC
        q_feat = tf.math.l2_normalize(q_feat, axis=1)

        # key encoder
        shuffled_key, shuffle_idxs = batch_shuffle(key)
        shuffled_key.set_shape([self.batch_size, None, None, None])
        with tf.variable_scope("momentum_encoder"), \
                varreplace.freeze_variables(skip_collection=True), \
                argscope(BatchNorm, ema_update='skip'):  # don't maintain EMA (will not be used at all)
            key_feat = xla.compile(lambda: self.net.forward(shuffled_key))[0]
            # key_feat = self.net.forward(shuffled_key)
        key_feat = tf.math.l2_normalize(key_feat, axis=1)  # NxC
        key_feat = batch_unshuffle(key_feat, shuffle_idxs)
        key_feat = tf.stop_gradient(key_feat)

        # loss
        l_pos = tf.reshape(tf.einsum('nc,nc->n', q_feat, key_feat), (-1, 1))  # nx1
        l_neg = tf.einsum('nc,kc->nk', q_feat, queue)  # nxK
        logits = tf.concat([l_pos, l_neg], axis=1)  # nx(1+k)
        logits = logits * (1 / self.temp)
        labels = tf.zeros(self.batch_size, dtype=tf.int64)  # n
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
        loss = tf.reduce_mean(loss, name='xentropy-loss')

        acc = tf.reduce_mean(tf.cast(
            tf.equal(tf.math.argmax(logits, axis=1), labels), tf.float32), name='train-acc')

        # update queue (depend on l_neg)
        with tf.control_dependencies([l_neg]):
            queue_push_op = self.push_queue(queue, queue_ptr, key_feat)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, queue_push_op)

        wd_loss = regularize_cost(".*", l2_regularizer(1e-4), name='l2_regularize_loss')
        add_moving_summary(acc, loss, wd_loss)
        total_cost = tf.add_n([loss, wd_loss], name='cost')
        return total_cost

    def push_queue(self, queue, queue_ptr, item):
        # queue: KxC
        # item: NxC
        item = allgather(item, 'queue_gather')  # GN x C
        batch_size = tf.shape(item, out_type=tf.int64)[0]
        end_queue_ptr = queue_ptr + batch_size

        inds = tf.range(queue_ptr, end_queue_ptr, dtype=tf.int64)
        with tf.control_dependencies([inds]):
            queue_ptr_update = tf.assign(queue_ptr, end_queue_ptr % self.queue_size)
        queue_update = tf.scatter_update(queue, inds, item)
        return tf.group(queue_update, queue_ptr_update)

    def optimizer(self):
        if args.v2:
            # cosine LR in v2
            gs = tf.train.get_or_create_global_step()
            total_steps = 1281167 // args.batch * 200
            lr = BASE_LR * 0.5 * (1 + tf.cos(gs / total_steps * np.pi))
        else:
            lr = tf.get_variable('learning_rate', initializer=0.0, trainable=False)
        tf.summary.scalar('learning_rate-summary', lr)
        opt = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
        return opt


class UpdateMomentumEncoder(Callback):
    _chief_only = False  # execute it in every worker
    momentum = 0.999

    def _setup_graph(self):
        nontrainable_vars = list(set(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)))
        all_vars = {v.name: v for v in tf.global_variables() + tf.local_variables()}

        # find variables of encoder & momentum encoder
        self._var_mapping = {}  # var -> mom var
        momentum_prefix = "momentum_encoder/"
        for mom_var in nontrainable_vars:
            if momentum_prefix in mom_var.name:
                q_encoder_name = mom_var.name.replace(momentum_prefix, "")
                q_encoder_var = all_vars[q_encoder_name]
                assert q_encoder_var not in self._var_mapping
                if not q_encoder_var.trainable:  # don't need to copy EMA
                    continue
                self._var_mapping[q_encoder_var] = mom_var

        logger.info(f"Found {len(self._var_mapping)} pairs of matched variables.")

        assign_ops = [tf.assign(mom_var, var) for var, mom_var in self._var_mapping.items()]
        self.assign_op = tf.group(*assign_ops, name="initialize_momentum_encoder")

        update_ops = [tf.assign_add(mom_var, (var - mom_var) * (1 - self.momentum))
                      for var, mom_var in self._var_mapping.items()]
        self.update_op = tf.group(*update_ops, name="update_momentum_encoder")

    def _before_train(self):
        logger.info("Copying encoder to momentum encoder ...")
        self.assign_op.run()

    def _trigger_step(self):
        self.update_op.run()


def get_config(model):
    input_sig = model.get_input_signature()
    nr_tower = max(hvd.size(), 1)
    batch = args.batch // nr_tower
    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))

    callbacks = [
        ThroughputTracker(args.batch),
        UpdateMomentumEncoder()
    ]

    if args.fake:
        data = QueueInput(FakeData(
            [x.shape for x in input_sig], 1000, random=False, dtype='uint8'))
    else:
        zmq_addr = 'ipc://@imagenet-train-b{}'.format(batch)
        data = ZMQInput(zmq_addr, 25, bind=False)

        dataset = data.to_dataset(input_sig).repeat().prefetch(15)
        dataset = dataset.apply(tf.data.experimental.prefetch_to_device('/gpu:0'))
        data = TFDatasetInput(dataset)

    callbacks.extend([
        ModelSaver(),
        EstimatedTimeLeft(),
    ])

    if not args.v2:
        # step-wise LR in v1
        SCALED_LR = BASE_LR * (args.batch / 256.0)
        callbacks.append(
            ScheduledHyperParamSetter(
                'learning_rate', [
                    (0, min(BASE_LR, SCALED_LR)),
                    (120, SCALED_LR * 1e-1),
                    (160, SCALED_LR * 1e-2)
                ]))
        if SCALED_LR > BASE_LR:
            callbacks.append(
                ScheduledHyperParamSetter(
                    'learning_rate', [(0, BASE_LR), (5, SCALED_LR)], interp='linear'))

    return TrainConfig(
        model=model,
        data=data,
        callbacks=callbacks,
        steps_per_epoch=100 if args.fake else 1281167 // args.batch,
        max_epoch=200,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', help='imagenet data dir')
    parser.add_argument('--fake', help='use FakeData to debug or benchmark this model', action='store_true')
    parser.add_argument('--batch', default=256, type=int, help='total batch size')
    parser.add_argument('--v2', action='store_true', help='train mocov2')
    parser.add_argument('--logdir')
    args = parser.parse_args()

    hvd.init()

    local_batch_size = args.batch // hvd.size()
    if args.v2:
        model = MOCOModel(batch_size=local_batch_size, feature_dims=(2048, 128), temp=0.2)
    else:
        model = MOCOModel(batch_size=local_batch_size, feature_dims=(128,), temp=0.07)

    if hvd.rank() == 0:
        if args.fake:
            logger.set_logger_dir('fake_train_log', 'd')
        else:
            if args.logdir is None:
                args.logdir = './moco'
            logger.set_logger_dir(args.logdir, 'n')
    logger.info("Rank={}, Local Rank={}, Size={}".format(hvd.rank(), hvd.local_rank(), hvd.size()))

    if not args.fake and hvd.local_rank() == 0:
        # start data serving process
        script = os.path.realpath(os.path.join(os.path.dirname(__file__), "serve-data.py"))
        v2_flag = "--v2" if args.v2 else ""
        cmd = f"taskset --cpu-list 0-29 {script} --data {args.data} --batch {local_batch_size} {v2_flag}"
        log_prefix = os.path.join(args.logdir, "data." + str(hvd.rank()))
        logger.info("Launching command: " + cmd)
        pid = subprocess.Popen(
            cmd,
            shell=True,
            stdout=open(log_prefix + ".stdout", "w"),
            stderr=open(log_prefix + ".stderr", "w"))

    config = get_config(model)
    trainer = HorovodTrainer(average=True)
    launch_train_with_config(config, trainer)

--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import tensorflow as tf

from tensorpack.models import (
    BatchNorm, Conv2D, FullyConnected, GlobalAvgPooling, LinearWrap, MaxPooling)
from tensorpack.tfutils import argscope

from data import tf_preprocess


class ResNetModel:
    def __init__(self, num_output=None):
        """
        num_output: int or list[int]: dimension(s) of FC layers in the end
        """
        self.data_format = "NCHW"
        if num_output is not None:
            if not isinstance(num_output, (list, tuple)):
                num_output = [num_output]
        self.num_output = num_output

    def forward(self, image):
        # accept [0-255] BGR NHWC images (from dataflow)
        image = tf_preprocess(image)
        if self.data_format == "NCHW":
            image = tf.transpose(image, [0, 3, 1, 2])
        return self.get_logits(image)

    def get_logits(self, image):
        num_blocks = [3, 4, 6, 3]

        with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format), \
                argscope(Conv2D, use_bias=False,
                         kernel_initializer=tf.variance_scaling_initializer(
                             scale=2.0, mode='fan_out', distribution='untruncated_normal')), \
                argscope(BatchNorm, epsilon=1.001e-5):
            logits = (LinearWrap(image)
                      .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]])
                      .Conv2D('conv0', 64, 7, strides=2, padding='VALID')
                      .apply(self.norm_func, 'conv0')
                      .tf.nn.relu()
                      .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]])
                      .MaxPooling('pool0', shape=3, stride=2, padding='VALID')
                      .apply(self.resnet_group, 'group0', 64, num_blocks[0], 1)
                      .apply(self.resnet_group, 'group1', 128, num_blocks[1], 2)
                      .apply(self.resnet_group, 'group2', 256, num_blocks[2], 2)
                      .apply(self.resnet_group, 'group3', 512, num_blocks[3], 2)
                      .GlobalAvgPooling('gap')())
        if self.num_output is not None:
            for idx, no in enumerate(self.num_output):
                logits = FullyConnected(
                    'linear{}_{}'.format(idx, no),
                    logits, no,
                    kernel_initializer=tf.random_normal_initializer(stddev=0.01))
                if idx != len(self.num_output) - 1:
                    logits = tf.nn.relu(logits)
        return logits

    def norm_func(self, x, name, gamma_initializer=tf.constant_initializer(1.)):
        return BatchNorm(name + '_bn', x, gamma_initializer=gamma_initializer)

    def resnet_group(self, l, name, features, count, stride):
        with tf.variable_scope(name):
            for i in range(0, count):
                with tf.variable_scope('block{}'.format(i)):
                    l = self.bottleneck_block(l, features, stride if i == 0 else 1)
        return l

    def bottleneck_block(self, l, ch_out, stride):
        shortcut = l
        l = Conv2D('conv1', l, ch_out, 1, strides=1)
        l = self.norm_func(l, 'conv1')
        l = tf.nn.relu(l)

        if stride == 1:
            l = Conv2D('conv2', l, ch_out, 3, strides=1)
        else:
            l = tf.pad(l, [[0, 0], [0, 0], [1, 1], [1, 1]])
            l = Conv2D('conv2', l, ch_out, 3, strides=stride, padding='VALID')
        l = self.norm_func(l, 'conv2')
        l = tf.nn.relu(l)

        l = Conv2D('conv3', l, ch_out * 4, 1)
        l = self.norm_func(l, 'conv3')  # pt does not use 0init
        return tf.nn.relu(
            l + self.shortcut(shortcut, ch_out * 4, stride), 'output')

    def shortcut(self, l, n_out, stride):
        n_in = l.get_shape().as_list()[1]
        if n_in != n_out:  # change dimension when channel is not the same
            l = Conv2D('convshortcut', l, n_out, 1, strides=stride)
            l =
self.norm_func(l, 'convshortcut') 94 | return l 95 | -------------------------------------------------------------------------------- /serve-data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import pprint 6 | import os 7 | import socket 8 | import multiprocessing as mp 9 | import cv2 10 | 11 | from tensorpack.dataflow import FakeData, MapData, TestDataSpeed, send_dataflow_zmq 12 | from tensorpack.utils import logger 13 | 14 | from data import get_moco_dataflow, get_moco_v1_augmentor, get_moco_v2_augmentor 15 | from zmq_ops import dump_arrays 16 | 17 | 18 | cv2.setNumThreads(0) 19 | 20 | 21 | if __name__ == '__main__': 22 | mp.set_start_method('spawn') 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--data', help='ILSVRC dataset dir') 25 | parser.add_argument('--fake', action='store_true') 26 | parser.add_argument('--batch', help='per-GPU batch size', 27 | default=32, type=int) 28 | parser.add_argument('--benchmark', action='store_true') 29 | parser.add_argument('--v2', action='store_true') 30 | parser.add_argument('--no-zmq-ops', action='store_true') 31 | args = parser.parse_args() 32 | 33 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 34 | 35 | if args.fake: 36 | ds = FakeData( 37 | [[args.batch, 224, 224, 3], [args.batch, 224, 224, 3]], 38 | 9999999, random=False, dtype=['uint8', 'uint8']) 39 | else: 40 | aug = get_moco_v2_augmentor() if args.v2 else get_moco_v1_augmentor() 41 | logger.info("Augmentation used: \n" + pprint.pformat(aug)) 42 | ds = get_moco_dataflow(args.data, args.batch, aug) 43 | 44 | logger.info("Serving data on {}".format(socket.gethostname())) 45 | 46 | if args.benchmark: 47 | ds = MapData(ds, dump_arrays) 48 | TestDataSpeed(ds, size=99999, warmup=300).start() 49 | else: 50 | format = None if args.no_zmq_ops else 'zmq_ops' 51 | send_dataflow_zmq( 52 | ds, 'ipc://@imagenet-train-b{}'.format(args.batch), 53 | 
hwm=200, format=format, bind=True) 54 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | # See https://pep8.readthedocs.io/en/latest/intro.html#error-codes 4 | ignore = E265,E741,E742,E743,W504,W605,C408,B007,B008 5 | exclude = .git, 6 | __init__.py, 7 | show-source = true 8 | 9 | [isort] 10 | line_length=100 11 | multi_line_output=4 12 | known_tensorpack=tensorpack 13 | known_standard_library=numpy 14 | known_third_party=matplotlib,tensorflow,cv2,PIL 15 | no_lines_before=STDLIB,THIRDPARTY 16 | sections=FUTURE,STDLIB,THIRDPARTY,tensorpack,FIRSTPARTY,LOCALFOLDER 17 | --------------------------------------------------------------------------------
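`resnet.py` does not rely on `padding='SAME'` for `conv0`, `pool0`, or the strided 3x3 `conv2` of each bottleneck. Instead it pads explicitly with `tf.pad` and then uses `padding='VALID'`. A likely reason, which is an assumption since the source only hints at PyTorch parity through its BN-init comment, is symmetry: TensorFlow's `'SAME'` padding puts any odd leftover pixel on the bottom/right when the stride is 2, whereas torchvision's ResNet pads symmetrically.

The plain-Python sketch below compares the two schemes and traces the spatial size through the network for a 224x224 input. It does not import TensorFlow or tensorpack, and the helper names `tf_same_padding`, `valid_out_size` and `trace_resnet_shapes` are hypothetical.

```python
import math


def tf_same_padding(in_size, kernel, stride):
    """(pad_before, pad_after) that TensorFlow's 'SAME' padding uses on one spatial axis."""
    out_size = math.ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + kernel - in_size, 0)
    pad_before = pad_total // 2  # any odd leftover pixel goes to the end (bottom/right)
    return pad_before, pad_total - pad_before


def valid_out_size(in_size, kernel, stride, pad_before, pad_after):
    """Output length of a 'VALID' conv/pool applied after explicit padding."""
    return (in_size + pad_before + pad_after - kernel) // stride + 1


def trace_resnet_shapes(image_size=224):
    """Hypothetical helper: (layer, channels, spatial size) after each stage of ResNetModel.get_logits."""
    shapes = []
    size = valid_out_size(image_size, 7, 2, 3, 3)  # tf.pad [3, 3] + 7x7/2 'VALID' conv0
    shapes.append(('conv0', 64, size))
    size = valid_out_size(size, 3, 2, 1, 1)  # tf.pad [1, 1] + 3x3/2 'VALID' max-pool
    shapes.append(('pool0', 64, size))
    for idx, features in enumerate((64, 128, 256, 512)):
        stride = 1 if idx == 0 else 2
        # Only the first block of a group may downsample, via its 3x3 conv2.
        # With stride 1, 'SAME' padding is 1 per side, so the size is unchanged.
        size = valid_out_size(size, 3, stride, 1, 1)
        shapes.append(('group{}'.format(idx), features * 4, size))  # bottleneck expands 4x
    return shapes


if __name__ == '__main__':
    print('TF SAME, 7x7/2 on 224:', tf_same_padding(224, 7, 2))  # (2, 3): asymmetric
    print('explicit pad, 7x7/2 on 224:', valid_out_size(224, 7, 2, 3, 3))  # 112
    for name, channels, size in trace_resnet_shapes():
        print(name, channels, size)
```

Both schemes give a 112x112 map after `conv0`. However, TensorFlow's `(2, 3)` split shifts the sampling grid by one pixel relative to the symmetric `(3, 3)` padding. Zero-padding before `pool0` is harmless because it follows a ReLU: all inputs are non-negative, so a padded zero never exceeds the real maximum in a window.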