├── .github
│   └── ISSUE_TEMPLATE.md
├── .gitignore
├── README.md
├── data.py
├── eval_knn.py
├── main_lincls.py
├── main_moco.py
├── resnet.py
├── serve-data.py
└── tox.ini
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | The only goal of this project is to reproduce results in papers.
2 | We do not take feature requests or questions that are unrelated to this goal.
3 |
4 | If you encounter an unexpected problem when using the code,
5 | please include the following in your issue:
6 |
7 | * What you did: the command you ran.
8 |
9 | * What you observed: the full logs and other relevant information.
10 |
11 | * What you expected, if not obvious.
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # tensorpack-specific stuff
2 | fake_train_log
3 | train_log
4 | train_log_*
5 | logs
6 | *.npy
7 | *.npz
8 | *.caffemodel
9 | *.tfmodel
10 | *.meta
11 | *.log*
12 | *.bin
13 | *.png
14 | *.jpg
15 | checkpoint
16 | *.json
17 | *.prototxt
18 | *.txt
19 | *.tgz
20 | *.gz
21 |
22 | # my personal stuff
23 | snippet
24 | examples/private
25 | TODO.md
26 | .gitignore
27 | .vimrc.local
28 |
29 |
30 | # Byte-compiled / optimized / DLL files
31 | __pycache__/
32 | *.py[cod]
33 |
34 | # C extensions
35 | *.so
36 |
37 | # Distribution / packaging
38 | .Python
39 | env/
40 | build/
41 | develop-eggs/
42 | dist/
43 | downloads/
44 | eggs/
45 | .eggs/
46 | lib/
47 | lib64/
48 | parts/
49 | sdist/
50 | var/
51 | *.egg-info/
52 | .installed.cfg
53 | *.egg
54 |
55 | # PyInstaller
56 | # Usually these files are written by a python script from a template
57 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
58 | *.manifest
59 | *.spec
60 |
61 | # Installer logs
62 | pip-log.txt
63 | pip-delete-this-directory.txt
64 |
65 | # Unit test / coverage reports
66 | htmlcov/
67 | .tox/
68 | .coverage
69 | .coverage.*
70 | .cache
71 | nosetests.xml
72 | coverage.xml
73 | *,cover
74 |
75 | # Translations
76 | *.mo
77 | *.pot
78 |
79 | # Django stuff:
80 | *.log
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 | *.dat
88 |
89 | .idea/
90 | *.diff
91 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | This repo implements and __reproduces__ the results of the following papers:
3 |
4 | * [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/abs/1911.05722)
5 | * [Improved Baselines with Momentum Contrastive Learning](https://arxiv.org/abs/2003.04297)
6 |
7 | ## Dependencies:
8 |
9 | * TensorFlow 1.14 or 1.15, built with XLA support
10 | * [Tensorpack](https://github.com/tensorpack/tensorpack/) ≥ 0.10.1
11 | * [Horovod](https://github.com/horovod/horovod) ≥ 0.19 built with Gloo & NCCL support
12 | * TensorFlow [zmq_ops](https://github.com/tensorpack/zmq_ops)
13 | * OpenCV
14 | * the `taskset` command (from the `util-linux` package)
15 |
16 | ## Unsupervised Training:
17 |
18 | To run MoCo pre-training on a machine with 8 GPUs, use:
19 | ```
20 | horovodrun -np 8 --output-filename moco.log python main_moco.py --data /path/to/imagenet
21 | ```
22 |
23 | Add `--v2` to train MoCo v2,
24 | which uses an extra MLP layer, stronger augmentation (including Gaussian blur), a softmax temperature of 0.2, and a cosine LR schedule.
25 |
26 |
27 | ## Linear Classification:
28 | To train a linear classifier using the pre-trained features, run:
29 | ```
30 | ./main_lincls.py --load /path/to/pretrained/checkpoint --data /path/to/imagenet
31 | ```
32 |
33 | ## KNN Evaluation:
34 | As a cheaper but rougher alternative to linear classification, run a feature-space kNN
35 | classifier against the training set (add `--v2` for MoCo v2 checkpoints):
36 | ```
37 | horovodrun -np 8 python eval_knn.py --load /path/to/checkpoint --data /path/to/imagenet --top-k 200
38 | ```
39 |
40 | ## Results:
41 | Training was done on a machine with 8 V100s, >200GB RAM and 80 CPUs.
42 |
43 | The following results were obtained after
44 | 200 epochs of pre-training (~53h)
45 | and 100 epochs of linear classifier tuning (~8h).
46 | kNN evaluation takes 10 min per checkpoint.
47 |
48 | | | linear cls. accuracy | download (pretrained only) | tensorboard |
49 | | - | :-: | :-: | :-: |
50 | | MoCo v1 | 60.9% | [:arrow_down:](https://github.com/ppwwyyxx/moco.tensorflow/releases/download/v/MoCo_v1.npz) | N/A |
51 | | MoCo v2 | 67.7% | [:arrow_down:](https://github.com/ppwwyyxx/moco.tensorflow/releases/download/v/MoCo_v2.npz) | [pretrain](https://tensorboard.dev/experiment/MBL49FKLTLCbKGr7JMolWQ); [finetune](https://tensorboard.dev/experiment/s3ZOxbjbRCy3hMqgL0TzKQ) |
52 |
53 | ## Notes:
54 |
55 | * Horovod with Gloo is recommended. Horovod with MPI is not tested and may crash due to how we use forking.
56 | * If using TensorFlow without XLA support, you can modify `main_*.py` to replace `xla.compile` with a plain forward call (the commented-out line next to each `xla.compile` call).
57 | * Official PyTorch code is at [facebookresearch/moco](https://github.com/facebookresearch/moco).
58 |
--------------------------------------------------------------------------------
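The README lists a cosine LR schedule as one of the MoCo v2 changes. As a rough sketch (not code from this repo), the two pre-training schedules in `main_moco.py` can be written as plain Python functions. The constants mirror `BASE_LR`, the 1281167-image epoch and the 200-epoch budget in that file. The function names are hypothetical, and `v1_step_lr` only approximates how the `ScheduledHyperParamSetter` callbacks update the LR at epoch boundaries.

```python
import math

# Constants copied from main_moco.py.
BASE_LR = 0.03
IMAGES_PER_EPOCH = 1281167  # ImageNet-1k training images
MAX_EPOCH = 200


def v2_cosine_lr(step, total_batch=256):
    """MoCo v2 (--v2): half-cosine decay from BASE_LR to 0 over all steps.

    Like the TF graph, this uses BASE_LR without batch-size scaling.
    """
    total_steps = IMAGES_PER_EPOCH // total_batch * MAX_EPOCH
    return BASE_LR * 0.5 * (1 + math.cos(step / total_steps * math.pi))


def v1_step_lr(epoch, total_batch=256):
    """MoCo v1: linearly scaled LR, divided by 10 at epochs 120 and 160.

    When the scaled LR exceeds BASE_LR, it is warmed up linearly
    from BASE_LR over the first 5 epochs.
    """
    scaled = BASE_LR * total_batch / 256.0
    if scaled > BASE_LR and epoch < 5:
        return BASE_LR + (scaled - BASE_LR) * epoch / 5
    if epoch >= 160:
        return scaled * 1e-2
    if epoch >= 120:
        return scaled * 1e-1
    return scaled
```

With the default `--batch 256`, the v1 schedule starts at 0.03 and drops to 0.003 at epoch 120. The v2 schedule reaches half of `BASE_LR` exactly halfway through training.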
/data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 | import cv2
5 | import multiprocessing as mp
6 | import tensorflow as tf
7 |
8 | from tensorpack.dataflow import (
9 | BatchData, MultiProcessMapAndBatchDataZMQ, MultiProcessRunnerZMQ, MultiThreadMapData, dataset,
10 | imgaug)
11 |
12 |
13 | cv2.setNumThreads(0)
14 |
15 |
16 | def get_moco_v1_augmentor():
17 | augmentors = [
18 | imgaug.GoogleNetRandomCropAndResize(crop_area_fraction=(0.2, 1.)),
19 | imgaug.RandomApplyAug(imgaug.Grayscale(rgb=False, keepshape=True), 0.2),
20 | imgaug.ToFloat32(),
21 | imgaug.RandomOrderAug(
22 | [imgaug.BrightnessScale((0.6, 1.4)),
23 | imgaug.Contrast((0.6, 1.4), rgb=False),
24 | imgaug.Saturation(0.4, rgb=False),
25 | # 72 = 180*0.4
26 | imgaug.Hue(range=(-72, 72), rgb=False)
27 | ]),
28 | imgaug.ToUint8(),
29 | imgaug.Flip(horiz=True),
30 | ]
31 | return augmentors
32 |
33 |
34 | def get_moco_v2_augmentor():
35 | augmentors = [
36 | imgaug.GoogleNetRandomCropAndResize(crop_area_fraction=(0.2, 1.)),
37 | imgaug.ToFloat32(),
38 | imgaug.RandomApplyAug(
39 | imgaug.RandomOrderAug(
40 | [imgaug.BrightnessScale((0.6, 1.4)),
41 | imgaug.Contrast((0.6, 1.4), rgb=False),
42 | imgaug.Saturation(0.4, rgb=False),
43 | # 18 = 180*0.1
44 | imgaug.Hue(range=(-18, 18), rgb=False)
45 | ]), 0.8),
46 | imgaug.RandomApplyAug(imgaug.Grayscale(rgb=False, keepshape=True), 0.2),
47 | imgaug.RandomApplyAug(
48 | # 11 = 0.1*224//2
49 | imgaug.GaussianBlur(size_range=(11, 12), sigma_range=[0.1, 2.0]), 0.5),
50 | imgaug.ToUint8(),
51 | imgaug.Flip(horiz=True),
52 | ]
53 | return augmentors
54 |
55 |
56 | class MoCoMapper:
57 | def __init__(self, augs):
58 | self.augs = augs
59 |
60 | def __call__(self, dp):
61 | fname, _ = dp # throw away the label
62 | img = cv2.imread(fname)
63 | img1 = self.augs.augment(img)
64 | img2 = self.augs.augment(img)
65 | return [img1, img2]
66 |
67 |
68 | def get_moco_dataflow(datadir, batch_size, augmentors):
69 | """
70 | Dataflow for training MOCO.
71 | """
72 | augmentors = imgaug.AugmentorList(augmentors)
73 | parallel = min(30, mp.cpu_count()) # tuned on a 40-CPU 80-core machine
74 | ds = dataset.ILSVRC12Files(datadir, 'train', shuffle=True)
75 | ds = MultiProcessMapAndBatchDataZMQ(ds, parallel, MoCoMapper(augmentors), batch_size, buffer_size=5000)
76 | return ds
77 |
78 |
79 | def get_basic_augmentor(isTrain):
80 | interpolation = cv2.INTER_LINEAR
81 | if isTrain:
82 | augmentors = [
83 | imgaug.GoogleNetRandomCropAndResize(),
84 | imgaug.Flip(horiz=True),
85 | ]
86 | else:
87 | augmentors = [
88 | imgaug.ResizeShortestEdge(256, interp=interpolation),
89 | imgaug.CenterCrop((224, 224)),
90 | ]
91 | return augmentors
92 |
93 |
94 | def get_imagenet_dataflow(datadir, name, batch_size, parallel=None):
95 | """
96 | Get a standard imagenet training/evaluation dataflow, for linear classifier tuning.
97 | """
98 | assert name in ['train', 'val']
99 | isTrain = name == 'train'
100 | assert datadir is not None
101 | augmentors = get_basic_augmentor(isTrain)
102 | augmentors = imgaug.AugmentorList(augmentors)
103 | if parallel is None:
104 | parallel = min(50, mp.cpu_count())
105 |
106 | def mapper(dp):
107 | fname, label = dp
108 | img = cv2.imread(fname)
109 | img = augmentors.augment(img)
110 | return img, label
111 |
112 | if isTrain:
113 | ds = dataset.ILSVRC12Files(datadir, name, shuffle=True)
114 | ds = MultiProcessMapAndBatchDataZMQ(ds, parallel, mapper, batch_size, buffer_size=7000)
115 | else:
116 | ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
117 | ds = MultiThreadMapData(ds, parallel, mapper, buffer_size=2000, strict=True)
118 | ds = BatchData(ds, batch_size, remainder=True)
119 | ds = MultiProcessRunnerZMQ(ds, 1)
120 | return ds
121 |
122 |
123 | def tf_preprocess(image): # normalize BGR images
124 | with tf.name_scope('image_preprocess'):
125 | if image.dtype.base_dtype != tf.float32:
126 | image = tf.cast(image, tf.float32)
127 | mean = [0.485, 0.456, 0.406] # rgb
128 | std = [0.229, 0.224, 0.225]
129 | mean = mean[::-1]
130 | std = std[::-1]
131 | image_mean = tf.constant(mean, dtype=tf.float32) * 255.
132 | image_std = tf.constant(std, dtype=tf.float32) * 255.
133 | image = (image - image_mean) / image_std
134 | return image
135 |
136 |
137 | if __name__ == '__main__':
138 | from tensorpack.dataflow import TestDataSpeed
139 | import sys
140 | df = get_imagenet_dataflow(sys.argv[1], 'train', 32)
141 |
142 | TestDataSpeed(df, size=99999999, warmup=300).start()
143 |
--------------------------------------------------------------------------------
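`tf_preprocess` receives BGR images (the channel order returned by `cv2.imread`), so it reverses the RGB-ordered ImageNet mean and std and scales them to [0, 255] before normalizing. A minimal per-pixel sketch of the same arithmetic in plain Python follows. `normalize_bgr_pixel` is a hypothetical helper, not part of `data.py`.

```python
# ImageNet statistics in RGB order, as written in data.tf_preprocess.
MEAN_RGB = [0.485, 0.456, 0.406]
STD_RGB = [0.229, 0.224, 0.225]


def normalize_bgr_pixel(bgr):
    """Normalize one [0, 255] BGR pixel the same way tf_preprocess does.

    The RGB statistics are reversed to BGR order and scaled by 255.
    """
    mean = [m * 255.0 for m in MEAN_RGB[::-1]]
    std = [s * 255.0 for s in STD_RGB[::-1]]
    return [(x - m) / s for x, m, s in zip(bgr, mean, std)]


# A pixel equal to the dataset mean maps to (approximately) zero in every channel.
mean_pixel_bgr = [0.406 * 255, 0.456 * 255, 0.485 * 255]
print(normalize_bgr_pixel(mean_pixel_bgr))
```

Forgetting the `[::-1]` would silently pair the blue channel with the red statistics. The network would still train, but the inputs would no longer match the usual ImageNet normalization.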
/eval_knn.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 |
3 | import argparse
4 | import os
5 | import cv2
6 | import tqdm
7 | from collections import Counter
8 |
9 | from tensorpack import tfv1 as tf
10 | from tensorpack.utils.stats import Accuracy
11 | from tensorpack.utils import logger
12 | from tensorpack.tfutils import TowerContext, get_default_sess_config
13 | from tensorpack.tfutils.sessinit import SmartInit
14 | from tensorpack.tfutils.varmanip import get_checkpoint_path, get_all_checkpoints
15 | from tensorpack.dataflow import (
16 | imgaug, DataFromList, BatchData, MultiProcessMapDataZMQ, dataset)
17 |
18 | import horovod.tensorflow as hvd
19 | from resnet import ResNetModel
20 | from data import get_basic_augmentor, get_imagenet_dataflow
21 |
22 |
23 | def build_dataflow(files):
24 | train_ds = DataFromList(files)
25 | aug = imgaug.AugmentorList(get_basic_augmentor(isTrain=False))
26 |
27 | def mapper(dp):
28 | idx, fname, label = dp
29 | img = cv2.imread(fname)
30 | img = aug.augment(img)
31 | return img, idx
32 |
33 | train_ds = MultiProcessMapDataZMQ(train_ds, num_proc=8, map_func=mapper, strict=True)
34 | train_ds = BatchData(train_ds, local_batch_size)
35 | train_ds.reset_state()
36 | return train_ds
37 |
38 |
39 | if __name__ == "__main__":
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument('--data', help='imagenet data dir')
42 | parser.add_argument('--batch', default=512, type=int, help='total batch size')
43 | parser.add_argument('--load', required=True, help='file or directory to evaluate')
44 | parser.add_argument('--top-k', type=int, default=200, help='top-k in KNN')
45 | parser.add_argument('--v2', action='store_true', help='use mocov2')
46 | args = parser.parse_args()
47 |
48 | hvd.init()
49 | local_batch_size = args.batch // hvd.size()
50 |
51 | train_files = dataset.ILSVRC12Files(args.data, 'train', shuffle=True)
52 | train_files.reset_state()
53 | all_train_files = list(train_files)
54 | all_train_files = all_train_files[:len(all_train_files) // args.batch * args.batch] # truncate
55 | num_train_images = len(all_train_files)
56 | logger.info(f"Creating graph for KNN of {num_train_images} training images ...")
57 | local_train_files = [(idx, fname, label) for idx, (fname, label) in
58 | enumerate(all_train_files) if idx % hvd.size() == hvd.rank()]
59 |
60 | image_input = tf.placeholder(tf.uint8, [None, 224, 224, 3], "image")
61 | idx_input = tf.placeholder(tf.int64, [None], "image_idx")
62 |
63 | feat_buffer = tf.get_variable("feature_buffer", shape=[num_train_images, 128], trainable=False)
64 | net = ResNetModel(num_output=(2048, 128) if args.v2 else (128,))
65 | with TowerContext("", is_training=False):
66 | feat = net.forward(image_input)
67 | feat = tf.math.l2_normalize(feat, axis=1) # Nx128
68 | all_feat = hvd.allgather(feat) # GN x 128
69 | all_idx_input = hvd.allgather(idx_input) # GN
70 | update_buffer = tf.scatter_update(feat_buffer, all_idx_input, all_feat)
71 |
72 | dist = tf.matmul(feat, tf.transpose(feat_buffer)) # N x #DS
73 | _, topk_indices = tf.math.top_k(dist, k=args.top_k) # Nxtopk
74 |
75 | train_ds = build_dataflow(local_train_files)
76 |
77 | config = get_default_sess_config()
78 | config.gpu_options.visible_device_list = str(hvd.local_rank())
79 |
80 | def evaluate(checkpoint_file):
81 | result_file = get_checkpoint_path(checkpoint_file) + f".knn{args.top_k}.txt"
82 | if os.path.isfile(result_file):
83 | logger.info(f"Skipping evaluation of {result_file}.")
84 | return
85 | with tf.Session(config=config) as sess:
86 | sess.run(tf.global_variables_initializer())
87 | SmartInit(checkpoint_file).init(sess)
88 | for batch_img, batch_idx in tqdm.tqdm(train_ds, total=len(train_ds)):
89 | sess.run(update_buffer,
90 | feed_dict={image_input: batch_img, idx_input: batch_idx})
91 |
92 | if hvd.rank() == 0:
93 | acc = Accuracy()
94 | val_df = get_imagenet_dataflow(args.data, "val", local_batch_size)
95 | val_df.reset_state()
96 |
97 | for batch_img, batch_label in val_df:
98 | topk_indices_pred = sess.run(topk_indices, feed_dict={image_input: batch_img})
99 | for indices, gt in zip(topk_indices_pred, batch_label):
100 | pred = [all_train_files[k][1] for k in indices]
101 | top_pred = Counter(pred).most_common(1)[0]
102 | acc.feed(top_pred[0] == gt, total=1)
103 | logger.info(f"Accuracy of {checkpoint_file}: {acc.accuracy} out of {acc.total}")
104 | with open(result_file, "w") as f:
105 | f.write(str(acc.accuracy))
106 |
107 | if os.path.isdir(args.load):
108 | for fname, _ in get_all_checkpoints(args.load):
109 | logger.info(f"Evaluating {fname} ...")
110 | evaluate(fname)
111 | else:
112 | evaluate(args.load)
113 |
--------------------------------------------------------------------------------
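`eval_knn.py` fills a buffer with L2-normalized training features, takes the `top_k` most similar entries for each validation image, and predicts the most common label among them. Below is a single-process, pure-Python sketch of that vote on toy features. `l2_normalize` and `knn_predict` are hypothetical names, and there is no GPU, Horovod or feature buffer here.

```python
import math
from collections import Counter


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def knn_predict(query, bank_feats, bank_labels, k):
    """Majority vote over the k most similar training features.

    Features are L2-normalized, so the dot product is cosine similarity,
    as in tf.matmul(feat, tf.transpose(feat_buffer)) in eval_knn.py.
    """
    q = l2_normalize(query)
    sims = [sum(a * b for a, b in zip(q, l2_normalize(f))) for f in bank_feats]
    # Indices of the k largest similarities, most similar first (like tf.math.top_k).
    topk = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:k]
    # Unweighted vote; on a tie, Counter keeps first-seen order,
    # so the label of the nearer neighbor wins.
    return Counter(bank_labels[i] for i in topk).most_common(1)[0][0]


bank = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
labels = ["cat", "cat", "dog", "dog"]
print(knn_predict([1.0, 0.05], bank, labels, k=2))
```

The vote is unweighted, matching the `Counter(pred).most_common(1)` call in `evaluate()`. This is one reason the README calls kNN a rough evaluation compared with a trained linear classifier.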
/main_lincls.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 | import tensorflow as tf
6 |
7 | from tensorpack.callbacks import (
8 | ClassificationError, DataParallelInferenceRunner, EstimatedTimeLeft, InferenceRunner,
9 | ModelSaver, ScheduledHyperParamSetter, ThroughputTracker)
10 | from tensorpack.dataflow import FakeData
11 | from tensorpack.input_source import QueueInput
12 | from tensorpack.models import BatchNorm, FullyConnected
13 | from tensorpack.tfutils import SaverRestore, argscope, varreplace, SmartInit
14 | from tensorpack.tfutils.summary import add_moving_summary
15 | from tensorpack.train import (
16 | ModelDesc, SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config)
17 | from tensorpack.utils import logger
18 | from tensorpack.utils.gpu import get_num_gpu
19 |
20 | from data import get_imagenet_dataflow
21 | from resnet import ResNetModel
22 |
23 |
24 | class LinearModel(ModelDesc):
25 | def __init__(self):
26 | self.net = ResNetModel(num_output=None)
27 | self.image_shape = 224
28 |
29 | def inputs(self):
30 | return [tf.TensorSpec([None, self.image_shape, self.image_shape, 3], tf.uint8, 'input'),
31 | tf.TensorSpec([None], tf.int32, 'label')]
32 |
33 | def compute_loss_and_error(self, logits, label):
34 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
35 | loss = tf.reduce_mean(loss, name='xentropy-loss')
36 |
37 | def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
38 | with tf.name_scope('prediction_incorrect'):
39 | x = tf.logical_not(tf.nn.in_top_k(logits, label, topk))
40 | return tf.cast(x, tf.float32, name=name)
41 |
42 | wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
43 | add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))
44 |
45 | wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
46 | add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
47 | return loss
48 |
49 | def build_graph(self, input, label):
50 | with argscope(BatchNorm, training=False), \
51 | varreplace.freeze_variables(skip_collection=True):
52 | from tensorflow.python.compiler.xla import xla
53 | feature = xla.compile(lambda: self.net.forward(input))[0]
54 | # feature = self.net.forward(input)
55 | feature = tf.stop_gradient(feature) # redundant with freeze_variables above; kept as a safeguard
56 | logits = FullyConnected(
57 | 'linear_cls', feature, 1000,
58 | kernel_initializer=tf.random_normal_initializer(stddev=0.01),
59 | bias_initializer=tf.constant_initializer())
60 |
61 | tf.nn.softmax(logits, name='prob')
62 | loss = self.compute_loss_and_error(logits, label)
63 |
64 | # weight decay is 0
65 | add_moving_summary(loss)
66 | return loss
67 |
68 | def optimizer(self):
69 | lr = tf.get_variable('learning_rate', initializer=0.0, trainable=False)
70 | tf.summary.scalar('learning_rate-summary', lr)
71 | opt = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=False)
72 | return opt
73 |
74 |
75 | def get_config(model):
76 | nr_tower = max(get_num_gpu(), 1)
77 | batch = args.batch // nr_tower
78 |
79 | logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
80 |
81 | callbacks = [ThroughputTracker(args.batch)]
82 | if args.fake:
83 | data = QueueInput(FakeData(
84 | [[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8'))
85 | else:
86 | data = QueueInput(
87 | get_imagenet_dataflow(args.data, 'train', batch),
88 | # use a larger queue
89 | queue=tf.FIFOQueue(200, [tf.uint8, tf.int32], [[batch, 224, 224, 3], [batch]])
90 | )
91 |
92 | BASE_LR = 30
93 | SCALED_LR = BASE_LR * (args.batch / 256.0)
94 | callbacks.extend([
95 | ModelSaver(),
96 | EstimatedTimeLeft(),
97 | ScheduledHyperParamSetter(
98 | 'learning_rate', [
99 | (0, SCALED_LR),
100 | (60, SCALED_LR * 1e-1),
101 | (70, SCALED_LR * 1e-2),
102 | (80, SCALED_LR * 1e-3),
103 | (90, SCALED_LR * 1e-4),
104 | ]),
105 | ])
106 |
107 | dataset_val = get_imagenet_dataflow(args.data, 'val', 64)
108 | infs = [ClassificationError('wrong-top1', 'val-error-top1'),
109 | ClassificationError('wrong-top5', 'val-error-top5')]
110 | if nr_tower == 1:
111 | callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
112 | else:
113 | callbacks.append(DataParallelInferenceRunner(
114 | dataset_val, infs, list(range(nr_tower))))
115 |
116 | if args.load.endswith(".npz"):
117 | # a released model in npz format
118 | init = SmartInit(args.load)
119 | else:
120 | # a pre-trained checkpoint
121 | init = SaverRestore(args.load, ignore=("learning_rate", "global_step"))
122 | return TrainConfig(
123 | model=model,
124 | data=data,
125 | callbacks=callbacks,
126 | steps_per_epoch=100 if args.fake else 1281167 // args.batch,
127 | session_init=init,
128 | max_epoch=100,
129 | )
130 |
131 |
132 | if __name__ == "__main__":
133 | parser = argparse.ArgumentParser()
134 | parser.add_argument('--data', help='imagenet data dir')
135 | parser.add_argument('--load', required=True, help='path to pre-trained model')
136 | parser.add_argument('--fake', help='use FakeData to debug or benchmark this model', action='store_true')
137 | parser.add_argument('--batch', default=256, type=int, help='total batch size')
138 | parser.add_argument('--logdir')
139 | args = parser.parse_args()
140 |
141 | model = LinearModel()
142 |
143 | if args.fake:
144 | logger.set_logger_dir('fake_train_log', 'd')
145 | else:
146 | if args.logdir is None:
147 | args.logdir = './moco_lincls'
148 | logger.set_logger_dir(args.logdir)
149 |
150 | config = get_config(model)
151 | trainer = SyncMultiGPUTrainerReplicated(get_num_gpu())
152 | launch_train_with_config(config, trainer)
153 |
--------------------------------------------------------------------------------
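`compute_loss_and_error` in `main_lincls.py` reports top-1 and top-5 training error through `tf.nn.in_top_k`, and the `ClassificationError` callbacks report the same quantities on the validation set. A plain-Python sketch of the batch-averaged top-k error follows. `topk_error` is a hypothetical helper, and unlike `tf.nn.in_top_k` it breaks ties at the k-th position by class index.

```python
def topk_error(logits, labels, k):
    """Fraction of samples whose label is not among the k highest logits.

    Plain-Python analogue of prediction_incorrect() in main_lincls.py,
    averaged over the batch.
    """
    wrong = 0
    for row, label in zip(logits, labels):
        # Class indices sorted by descending logit; the sort is stable, so ties keep index order.
        top = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        wrong += label not in top
    return wrong / len(labels)


logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
labels = [2, 0]
print(topk_error(logits, labels, 1), topk_error(logits, labels, 2))
```

In the first sample the label (class 2) is the second-highest logit, so it counts as a top-1 error but not a top-2 error.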
/main_moco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 | import numpy as np
6 | import os
7 | import subprocess
8 | import tensorflow as tf
9 | from tensorflow.python.compiler.xla import xla
10 |
11 | from tensorpack.callbacks import (
12 | Callback, EstimatedTimeLeft, ModelSaver, ScheduledHyperParamSetter, ThroughputTracker)
13 | from tensorpack.dataflow import FakeData
14 | from tensorpack.input_source import QueueInput, TFDatasetInput, ZMQInput
15 | from tensorpack.models import BatchNorm, l2_regularizer, regularize_cost
16 | from tensorpack.tfutils import argscope, varreplace
17 | from tensorpack.tfutils.summary import add_moving_summary
18 | from tensorpack.train import (
19 | HorovodTrainer, ModelDesc, TrainConfig, launch_train_with_config)
20 | from tensorpack.utils import logger
21 |
22 | import horovod.tensorflow as hvd
23 | from resnet import ResNetModel
24 |
25 | BASE_LR = 0.03
26 |
27 |
28 | def allgather(tensor, name):
29 | tensor = tf.identity(tensor, name=name + "_HVD")
30 | return hvd.allgather(tensor)
31 |
32 |
33 | def batch_shuffle(tensor): # nx...
34 | total, rank = hvd.size(), hvd.rank()
35 | batch_size = tf.shape(tensor)[0]
36 | with tf.device('/cpu:0'):
37 | all_idx = tf.range(total * batch_size)
38 | shuffle_idx = tf.random.shuffle(all_idx)
39 | shuffle_idx = hvd.broadcast(shuffle_idx, 0)
40 | my_idxs = tf.slice(shuffle_idx, [rank * batch_size], [batch_size])
41 |
42 | all_tensor = allgather(tensor, 'batch_shuffle_key') # gn x ...
43 | return tf.gather(all_tensor, my_idxs), shuffle_idx
44 |
45 |
46 | def batch_unshuffle(key_feat, shuffle_idxs):
47 | rank = hvd.rank()
48 | inv_shuffle_idx = tf.argsort(shuffle_idxs)
49 | batch_size = tf.shape(key_feat)[0]
50 | my_idxs = tf.slice(inv_shuffle_idx, [rank * batch_size], [batch_size])
51 | all_key_feat = allgather(key_feat, "batch_unshuffle_feature") # gn x c
52 | return tf.gather(all_key_feat, my_idxs)
53 |
54 |
55 | class MOCOModel(ModelDesc):
56 | def __init__(self, batch_size, feature_dims=(128,), temp=0.07):
57 | self.batch_size = batch_size
58 | self.feature_dim = feature_dims[-1]
59 | # NOTE: implicitly assumes queue_size % (batch_size * num_GPUs) == 0
60 | self.queue_size = 65536
61 | self.temp = temp
62 |
63 | self.net = ResNetModel(num_output=feature_dims)
64 | self.image_shape = 224
65 |
66 | def inputs(self):
67 | return [tf.TensorSpec([self.batch_size, self.image_shape, self.image_shape, 3], tf.uint8, 'query'),
68 | tf.TensorSpec([self.batch_size, self.image_shape, self.image_shape, 3], tf.uint8, 'key')]
69 |
70 | def build_graph(self, query, key):
71 | # setup queue
72 | queue_init = tf.math.l2_normalize(
73 | tf.random.normal([self.queue_size, self.feature_dim]), axis=1)
74 | queue = tf.get_variable('queue', initializer=queue_init, trainable=False)
75 | queue_ptr = tf.get_variable(
76 | 'queue_ptr',
77 | [], initializer=tf.zeros_initializer(),
78 | dtype=tf.int64, trainable=False)
79 | tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, queue)
80 | tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, queue_ptr)
81 |
82 | # query encoder
83 | q_feat = self.net.forward(query) # NxC
84 | q_feat = tf.math.l2_normalize(q_feat, axis=1)
85 |
86 | # key encoder
87 | shuffled_key, shuffle_idxs = batch_shuffle(key)
88 | shuffled_key.set_shape([self.batch_size, None, None, None])
89 | with tf.variable_scope("momentum_encoder"), \
90 | varreplace.freeze_variables(skip_collection=True), \
91 | argscope(BatchNorm, ema_update='skip'): # don't maintain EMA (will not be used at all)
92 | key_feat = xla.compile(lambda: self.net.forward(shuffled_key))[0]
93 | # key_feat = self.net.forward(shuffled_key)
94 | key_feat = tf.math.l2_normalize(key_feat, axis=1) # NxC
95 | key_feat = batch_unshuffle(key_feat, shuffle_idxs)
96 | key_feat = tf.stop_gradient(key_feat)
97 |
98 | # loss
99 | l_pos = tf.reshape(tf.einsum('nc,nc->n', q_feat, key_feat), (-1, 1)) # nx1
100 | l_neg = tf.einsum('nc,kc->nk', q_feat, queue) # nxK
101 | logits = tf.concat([l_pos, l_neg], axis=1) # nx(1+k)
102 | logits = logits * (1 / self.temp)
103 | labels = tf.zeros(self.batch_size, dtype=tf.int64) # n
104 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
105 | loss = tf.reduce_mean(loss, name='xentropy-loss')
106 |
107 | acc = tf.reduce_mean(tf.cast(
108 | tf.equal(tf.math.argmax(logits, axis=1), labels), tf.float32), name='train-acc')
109 |
110 | # update the queue only after l_neg has been computed from its old contents
111 | with tf.control_dependencies([l_neg]):
112 | queue_push_op = self.push_queue(queue, queue_ptr, key_feat)
113 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, queue_push_op)
114 |
115 | wd_loss = regularize_cost(".*", l2_regularizer(1e-4), name='l2_regularize_loss')
116 | add_moving_summary(acc, loss, wd_loss)
117 | total_cost = tf.add_n([loss, wd_loss], name='cost')
118 | return total_cost
119 |
120 | def push_queue(self, queue, queue_ptr, item):
121 | # queue: KxC
122 | # item: NxC
123 | item = allgather(item, 'queue_gather') # GN x C
124 | batch_size = tf.shape(item, out_type=tf.int64)[0]
125 | end_queue_ptr = queue_ptr + batch_size
126 |
127 | inds = tf.range(queue_ptr, end_queue_ptr, dtype=tf.int64)
128 | with tf.control_dependencies([inds]):
129 | queue_ptr_update = tf.assign(queue_ptr, end_queue_ptr % self.queue_size)
130 | queue_update = tf.scatter_update(queue, inds, item)
131 | return tf.group(queue_update, queue_ptr_update)
132 |
133 | def optimizer(self):
134 | if args.v2:
135 | # cosine LR in v2
136 | gs = tf.train.get_or_create_global_step()
137 | total_steps = 1281167 // args.batch * 200
138 | lr = BASE_LR * 0.5 * (1 + tf.cos(gs / total_steps * np.pi))
139 | else:
140 | lr = tf.get_variable('learning_rate', initializer=0.0, trainable=False)
141 | tf.summary.scalar('learning_rate-summary', lr)
142 | opt = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
143 | return opt
144 |
145 |
146 | class UpdateMomentumEncoder(Callback):
147 | _chief_only = False # execute it in every worker
148 | momentum = 0.999
149 |
150 | def _setup_graph(self):
151 | nontrainable_vars = list(set(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)))
152 | all_vars = {v.name: v for v in tf.global_variables() + tf.local_variables()}
153 |
154 | # find variables of encoder & momentum encoder
155 | self._var_mapping = {} # var -> mom var
156 | momentum_prefix = "momentum_encoder/"
157 | for mom_var in nontrainable_vars:
158 | if momentum_prefix in mom_var.name:
159 | q_encoder_name = mom_var.name.replace(momentum_prefix, "")
160 | q_encoder_var = all_vars[q_encoder_name]
161 | assert q_encoder_var not in self._var_mapping
162 | if not q_encoder_var.trainable: # don't need to copy EMA
163 | continue
164 | self._var_mapping[q_encoder_var] = mom_var
165 |
166 | logger.info(f"Found {len(self._var_mapping)} pairs of matched variables.")
167 |
168 | assign_ops = [tf.assign(mom_var, var) for var, mom_var in self._var_mapping.items()]
169 | self.assign_op = tf.group(*assign_ops, name="initialize_momentum_encoder")
170 |
171 | update_ops = [tf.assign_add(mom_var, (var - mom_var) * (1 - self.momentum))
172 | for var, mom_var in self._var_mapping.items()]
173 | self.update_op = tf.group(*update_ops, name="update_momentum_encoder")
174 |
175 | def _before_train(self):
176 | logger.info("Copying encoder to momentum encoder ...")
177 | self.assign_op.run()
178 |
179 | def _trigger_step(self):
180 | self.update_op.run()
181 |
182 |
183 | def get_config(model):
184 | input_sig = model.get_input_signature()
185 | nr_tower = max(hvd.size(), 1)
186 | batch = args.batch // nr_tower
187 | logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
188 |
189 | callbacks = [
190 | ThroughputTracker(args.batch),
191 | UpdateMomentumEncoder()
192 | ]
193 |
194 | if args.fake:
195 | data = QueueInput(FakeData(
196 | [x.shape for x in input_sig], 1000, random=False, dtype='uint8'))
197 | else:
198 | zmq_addr = 'ipc://@imagenet-train-b{}'.format(batch)
199 | data = ZMQInput(zmq_addr, 25, bind=False)
200 |
201 | dataset = data.to_dataset(input_sig).repeat().prefetch(15)
202 | dataset = dataset.apply(tf.data.experimental.prefetch_to_device('/gpu:0'))
203 | data = TFDatasetInput(dataset)
204 |
205 | callbacks.extend([
206 | ModelSaver(),
207 | EstimatedTimeLeft(),
208 | ])
209 |
210 | if not args.v2:
211 | # step-wise LR in v1
212 | SCALED_LR = BASE_LR * (args.batch / 256.0)
213 | callbacks.append(
214 | ScheduledHyperParamSetter(
215 | 'learning_rate', [
216 | (0, min(BASE_LR, SCALED_LR)),
217 | (120, SCALED_LR * 1e-1),
218 | (160, SCALED_LR * 1e-2)
219 | ]))
220 | if SCALED_LR > BASE_LR:
221 | callbacks.append(
222 | ScheduledHyperParamSetter(
223 | 'learning_rate', [(0, BASE_LR), (5, SCALED_LR)], interp='linear'))
224 |
225 | return TrainConfig(
226 | model=model,
227 | data=data,
228 | callbacks=callbacks,
229 | steps_per_epoch=100 if args.fake else 1281167 // args.batch,
230 | max_epoch=200,
231 | )
232 |
233 |
234 | if __name__ == "__main__":
235 | parser = argparse.ArgumentParser()
236 | parser.add_argument('--data', help='imagenet data dir')
237 | parser.add_argument('--fake', help='use FakeData to debug or benchmark this model', action='store_true')
238 | parser.add_argument('--batch', default=256, type=int, help='total batch size')
239 | parser.add_argument('--v2', action='store_true', help='train mocov2')
240 | parser.add_argument('--logdir')
241 | args = parser.parse_args()
242 |
243 | hvd.init()
244 |
245 | local_batch_size = args.batch // hvd.size()
246 | if args.v2:
247 | model = MOCOModel(batch_size=local_batch_size, feature_dims=(2048, 128), temp=0.2)
248 | else:
249 | model = MOCOModel(batch_size=local_batch_size, feature_dims=(128,), temp=0.07)
250 |
251 | if hvd.rank() == 0:
252 | if args.fake:
253 | logger.set_logger_dir('fake_train_log', 'd')
254 | else:
255 | if args.logdir is None:
256 | args.logdir = './moco'
257 | logger.set_logger_dir(args.logdir, 'n')
258 | logger.info("Rank={}, Local Rank={}, Size={}".format(hvd.rank(), hvd.local_rank(), hvd.size()))
259 |
260 | if not args.fake and hvd.local_rank() == 0:
261 | # start data serving process
262 | script = os.path.realpath(os.path.join(os.path.dirname(__file__), "serve-data.py"))
263 | v2_flag = "--v2" if args.v2 else ""
264 | cmd = f"taskset --cpu-list 0-29 {script} --data {args.data} --batch {local_batch_size} {v2_flag}"
265 | log_prefix = os.path.join(args.logdir, "data." + str(hvd.rank()))
266 | logger.info("Launching command: " + cmd)
267 | pid = subprocess.Popen(
268 | cmd,
269 | shell=True,
270 | stdout=open(log_prefix + ".stdout", "w"),
271 | stderr=open(log_prefix + ".stderr", "w"))
272 |
273 | config = get_config(model)
274 | trainer = HorovodTrainer(average=True)
275 | launch_train_with_config(config, trainer)
276 |
--------------------------------------------------------------------------------
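`MOCOModel` combines three mechanisms:

* a contrastive loss whose positive logit sits at index 0 and whose negatives come from a queue;
* a FIFO update of that queue via a wrapping pointer;
* a shuffle and unshuffle of keys across GPUs, where `tf.argsort` of the shuffle indices gives the inverse permutation.

The sketch below reproduces each mechanism in single-process, pure Python. It is not the TensorFlow implementation. The helper names are hypothetical, the loss handles one query at a time, and a seeded `random.Random` stands in for the broadcast `tf.random.shuffle`.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def info_nce_loss(q, k_pos, queue, temp):
    """Softmax cross-entropy with label 0 over [q.k_pos, q.neg_1, ...] / temp."""
    logits = [dot(q, k_pos) / temp] + [dot(q, neg) / temp for neg in queue]
    m = max(logits)  # subtract the max for a numerically stable log-sum-exp
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[0]


def push_queue(queue, ptr, keys):
    """Overwrite the oldest entries in place and return the new pointer.

    Like MOCOModel.push_queue, this assumes len(queue) is a multiple of
    len(keys), so a batch never wraps around the end of the buffer.
    """
    for i, key in enumerate(keys):
        queue[ptr + i] = key
    return (ptr + len(keys)) % len(queue)


def shuffle_and_inverse(n, seed=0):
    """A random permutation and its inverse, computed as argsort(perm)."""
    perm = list(range(n))
    random.Random(seed).shuffle(perm)
    inv = sorted(range(n), key=lambda i: perm[i])
    return perm, inv


# Gathering with perm shuffles a batch; gathering with inv restores its order.
items = list("abcdef")
perm, inv = shuffle_and_inverse(len(items))
shuffled = [items[p] for p in perm]
assert [shuffled[j] for j in inv] == items
```

The inverse works because `perm[inv[i]] == i` for every `i`. `batch_unshuffle` relies on the same identity: each GPU slices its own rows out of `tf.argsort(shuffle_idxs)`, so the gathered key features line up with that GPU's queries again.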
/resnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import tensorflow as tf
4 |
5 | from tensorpack.models import (
6 | BatchNorm, Conv2D, FullyConnected, GlobalAvgPooling, LinearWrap, MaxPooling)
7 | from tensorpack.tfutils import argscope
8 |
9 | from data import tf_preprocess
10 |
11 |
12 | class ResNetModel:
13 |     def __init__(self, num_output=None):
14 |         """
15 |         num_output: int, list[int] or None: dims of the FC layers after global pooling (None: no FC)
16 |         """
17 |         self.data_format = "NCHW"
18 |         if num_output is not None:
19 |             if not isinstance(num_output, (list, tuple)):
20 |                 num_output = [num_output]
21 |         self.num_output = num_output
22 |
23 |     def forward(self, image):
24 |         # accept [0-255] BGR NHWC images (from dataflow)
25 |         image = tf_preprocess(image)
26 |         if self.data_format == "NCHW":
27 |             image = tf.transpose(image, [0, 3, 1, 2])
28 |         return self.get_logits(image)
29 |
30 |     def get_logits(self, image):
31 |         num_blocks = [3, 4, 6, 3]
32 |
33 |         with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format), \
34 |                 argscope(Conv2D, use_bias=False,
35 |                          kernel_initializer=tf.variance_scaling_initializer(
36 |                              scale=2.0, mode='fan_out', distribution='untruncated_normal')), \
37 |                 argscope(BatchNorm, epsilon=1.001e-5):
38 |             logits = (LinearWrap(image)
39 |                       .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]])
40 |                       .Conv2D('conv0', 64, 7, strides=2, padding='VALID')
41 |                       .apply(self.norm_func, 'conv0')
42 |                       .tf.nn.relu()
43 |                       .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]])
44 |                       .MaxPooling('pool0', shape=3, stride=2, padding='VALID')
45 |                       .apply(self.resnet_group, 'group0', 64, num_blocks[0], 1)
46 |                       .apply(self.resnet_group, 'group1', 128, num_blocks[1], 2)
47 |                       .apply(self.resnet_group, 'group2', 256, num_blocks[2], 2)
48 |                       .apply(self.resnet_group, 'group3', 512, num_blocks[3], 2)
49 |                       .GlobalAvgPooling('gap')())
50 |             if self.num_output is not None:
51 |                 for idx, no in enumerate(self.num_output):
52 |                     logits = FullyConnected(
53 |                         'linear{}_{}'.format(idx, no),
54 |                         logits, no,
55 |                         kernel_initializer=tf.random_normal_initializer(stddev=0.01))
56 |                     if idx != len(self.num_output) - 1:
57 |                         logits = tf.nn.relu(logits)
58 |         return logits
59 |
60 |     def norm_func(self, x, name, gamma_initializer=tf.constant_initializer(1.)):
61 |         return BatchNorm(name + '_bn', x, gamma_initializer=gamma_initializer)
62 |
63 |     def resnet_group(self, l, name, features, count, stride):
64 |         with tf.variable_scope(name):
65 |             for i in range(0, count):
66 |                 with tf.variable_scope('block{}'.format(i)):
67 |                     l = self.bottleneck_block(l, features, stride if i == 0 else 1)
68 |         return l
69 |
70 |     def bottleneck_block(self, l, ch_out, stride):
71 |         shortcut = l
72 |         l = Conv2D('conv1', l, ch_out, 1, strides=1)
73 |         l = self.norm_func(l, 'conv1')
74 |         l = tf.nn.relu(l)
75 |
76 |         if stride == 1:
77 |             l = Conv2D('conv2', l, ch_out, 3, strides=1)
78 |         else:
79 |             l = tf.pad(l, [[0, 0], [0, 0], [1, 1], [1, 1]])
80 |             l = Conv2D('conv2', l, ch_out, 3, strides=stride, padding='VALID')
81 |         l = self.norm_func(l, 'conv2')
82 |         l = tf.nn.relu(l)
83 |
84 |         l = Conv2D('conv3', l, ch_out * 4, 1)
85 |         l = self.norm_func(l, 'conv3')  # gamma init is 1, not 0: PyTorch does not zero-init by default
86 |         return tf.nn.relu(
87 |             l + self.shortcut(shortcut, ch_out * 4, stride), 'output')
88 |
89 |     def shortcut(self, l, n_out, stride):
90 |         n_in = l.get_shape().as_list()[1]  # channel axis; assumes NCHW
91 |         if n_in != n_out:  # project with a 1x1 conv when the channel count changes
92 |             l = Conv2D('convshortcut', l, n_out, 1, strides=stride)
93 |             l = self.norm_func(l, 'convshortcut')
94 |         return l
95 |
--------------------------------------------------------------------------------
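To make the padding choices in `ResNetModel` easier to check, here is a hedged pure-Python trace of channel counts and spatial sizes through the hard-coded ResNet-50 layout (`num_blocks = [3, 4, 6, 3]`, widths 64/128/256/512, bottleneck expansion 4). It assumes tensorpack's `Conv2D` defaults to SAME padding, which is what the stride-1 3x3 convs and the 1x1 projection shortcut rely on. `valid_out`, `same_out` and `resnet50_trace` are hypothetical names; `fc_dims=(2048, 128)` mirrors the MoCo v2 `feature_dims` in `main_moco.py`, assuming those are what reach `num_output`.

```python
def valid_out(size, kernel, stride, pad):
    """Spatial size after symmetric zero-padding `pad`, then a VALID conv/pool."""
    return (size + 2 * pad - kernel) // stride + 1


def same_out(size, stride):
    """Spatial size after a SAME-padded conv: ceil(size / stride)."""
    return -(-size // stride)


def resnet50_trace(image_size=224, num_blocks=(3, 4, 6, 3), fc_dims=()):
    """Hypothetical helper: (name, channels, spatial size) after each stage."""
    s = valid_out(image_size, 7, 2, 3)   # conv0: pad 3, 7x7 stride 2, VALID
    s = valid_out(s, 3, 2, 1)            # pool0: pad 1, 3x3 stride 2, VALID
    ch = 64
    trace = [("pool0", ch, s)]
    widths, strides = (64, 128, 256, 512), (1, 2, 2, 2)
    for g, (width, stride, count) in enumerate(zip(widths, strides, num_blocks)):
        for i in range(count):
            st = stride if i == 0 else 1
            # main path: 1x1 -> 3x3 (stride st) -> 1x1 with width * 4 outputs
            main = same_out(s, 1) if st == 1 else valid_out(s, 3, st, 1)
            if ch == width * 4:
                # identity shortcut: only valid when the resolution is unchanged
                assert st == 1, "identity shortcut cannot downsample"
                short = s
            else:
                # 1x1 projection shortcut with the block's stride (SAME padding)
                short = same_out(s, st)
            assert main == short, "residual branches disagree: {} vs {}".format(main, short)
            s, ch = main, width * 4
        trace.append(("group{}".format(g), ch, s))
    trace.append(("gap", ch, None))      # pooled: no spatial dims left
    for idx, no in enumerate(fc_dims):
        trace.append(("linear{}_{}".format(idx, no), no, None))
    return trace


if __name__ == "__main__":
    for name, channels, size in resnet50_trace(224, fc_dims=(2048, 128)):
        print(name, channels, size)
```

For a 224x224 input this gives 56x56 after `pool0`, then 56, 28, 14 and 7 after the four groups, with 256, 512, 1024 and 2048 channels. The explicit `tf.pad` before each strided VALID conv is presumably there to mirror PyTorch's symmetric padding: for a 3x3 stride-2 conv on a 56x56 input, TensorFlow's SAME padding puts its single pixel of padding only at the bottom/right, which yields the same 28x28 size but shifted sampling positions.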
/serve-data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 | import pprint
6 | import os
7 | import socket
8 | import multiprocessing as mp
9 | import cv2
10 |
11 | from tensorpack.dataflow import FakeData, MapData, TestDataSpeed, send_dataflow_zmq
12 | from tensorpack.utils import logger
13 |
14 | from data import get_moco_dataflow, get_moco_v1_augmentor, get_moco_v2_augmentor
15 | from zmq_ops import dump_arrays
16 |
17 |
18 | cv2.setNumThreads(0)
19 |
20 |
21 | if __name__ == '__main__':
22 |     mp.set_start_method('spawn')
23 |     parser = argparse.ArgumentParser()
24 |     parser.add_argument('--data', help='ILSVRC dataset dir')
25 |     parser.add_argument('--fake', action='store_true')
26 |     parser.add_argument('--batch', help='per-GPU batch size',
27 |                         default=32, type=int)
28 |     parser.add_argument('--benchmark', action='store_true')
29 |     parser.add_argument('--v2', action='store_true')
30 |     parser.add_argument('--no-zmq-ops', action='store_true')
31 |     args = parser.parse_args()
32 |
33 |     os.environ['CUDA_VISIBLE_DEVICES'] = ''  # the data server runs on CPU only; hide all GPUs
34 |
35 |     if args.fake:
36 |         ds = FakeData(
37 |             [[args.batch, 224, 224, 3], [args.batch, 224, 224, 3]],
38 |             9999999, random=False, dtype=['uint8', 'uint8'])
39 |     else:
40 |         aug = get_moco_v2_augmentor() if args.v2 else get_moco_v1_augmentor()
41 |         logger.info("Augmentation used: \n" + pprint.pformat(aug))
42 |         ds = get_moco_dataflow(args.data, args.batch, aug)
43 |
44 |     logger.info("Serving data on {}".format(socket.gethostname()))
45 |
46 |     if args.benchmark:
47 |         ds = MapData(ds, dump_arrays)
48 |         TestDataSpeed(ds, size=99999, warmup=300).start()
49 |     else:
50 |         fmt = None if args.no_zmq_ops else 'zmq_ops'  # avoid shadowing the builtin format()
51 |         send_dataflow_zmq(
52 |             ds, 'ipc://@imagenet-train-b{}'.format(args.batch),
53 |             hwm=200, format=fmt, bind=True)
54 |
--------------------------------------------------------------------------------
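`serve-data.py` pushes batches with `send_dataflow_zmq(..., hwm=200, bind=True)`. As we understand ZMQ, the high-water mark caps how many messages may queue up for a peer before a PUSH-style sender blocks, so roughly 200 batches at most pile up in memory when the trainer falls behind. The stdlib sketch below models that backpressure in-process with `queue.Queue(maxsize=hwm)`; it is an analogy, not tensorpack's IPC transport, and `producer`, `consumer` and `run_pipeline` are hypothetical names.

```python
import queue
import threading

_END = object()  # sentinel marking the end of the stream


def producer(batches, q):
    # q.put blocks while `maxsize` items are already waiting,
    # much like a PUSH socket that has reached its high-water mark.
    for b in batches:
        q.put(b)
    q.put(_END)


def consumer(q):
    received, peak = [], 0
    while True:
        peak = max(peak, q.qsize())  # never exceeds maxsize: the producer blocks first
        item = q.get()
        if item is _END:
            return received, peak
        received.append(item)


def run_pipeline(batches, hwm=200):
    """Hypothetical in-process stand-in for serve-data.py -> trainer."""
    q = queue.Queue(maxsize=hwm)  # maxsize=0 means unbounded, like a ZMQ hwm of 0
    t = threading.Thread(target=producer, args=(batches, q), daemon=True)
    t.start()
    result = consumer(q)
    t.join()
    return result


if __name__ == "__main__":
    batches, peak = run_pipeline(range(1000), hwm=200)
    print(len(batches), "batches received; peak backlog", peak)
```

The bound is what keeps a fast data server from exhausting host memory; the real setup additionally crosses a process boundary over the `ipc://` socket, which this sketch does not model.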
/tox.ini:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | # E/W codes: https://pep8.readthedocs.io/en/latest/intro.html#error-codes; C408 is from flake8-comprehensions, B007/B008 from flake8-bugbear
4 | ignore = E265,E741,E742,E743,W504,W605,C408,B007,B008
5 | exclude = .git,
6 |           __init__.py,
7 | show-source = true
8 |
9 | [isort]
10 | line_length=100
11 | multi_line_output=4
12 | known_tensorpack=tensorpack
13 | known_standard_library=numpy
14 | known_third_party=matplotlib,tensorflow,cv2,PIL
15 | no_lines_before=STDLIB,THIRDPARTY
16 | sections=FUTURE,STDLIB,THIRDPARTY,tensorpack,FIRSTPARTY,LOCALFOLDER
17 |
--------------------------------------------------------------------------------
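The `exclude` entry in `tox.ini` spans two lines, which only works if the continuation line is indented: to our understanding flake8 loads this file with Python's `configparser`, and an unindented `__init__.py,` would be read as a new option line with no `=`. The hedged sketch below shows both cases with `configparser.RawConfigParser`; `flake8_excludes` is a hypothetical helper, not flake8's own parsing code.

```python
import configparser

TOX_INI = """\
[flake8]
max-line-length = 120
exclude = .git,
          __init__.py,
show-source = true
"""


def flake8_excludes(text):
    """Hypothetical helper: split the [flake8] exclude value into entries."""
    cp = configparser.RawConfigParser()
    cp.read_string(text)
    raw = cp.get("flake8", "exclude")   # continuation lines are joined with "\n"
    return [item.strip() for item in raw.split(",") if item.strip()]


if __name__ == "__main__":
    print(flake8_excludes(TOX_INI))  # ['.git', '__init__.py']
    # Without indentation, "__init__.py," is not a continuation but a malformed option line.
    try:
        flake8_excludes("[flake8]\nexclude = .git,\n__init__.py,\n")
    except configparser.ParsingError as err:
        print("rejected:", type(err).__name__)
```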