├── .gitignore
├── LICENSE
├── README.md
├── network.py
├── requirements.txt
├── test.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 peisuke

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Momentum Contrast

This repository is a reproduction of "Momentum Contrast for Unsupervised Visual Representation Learning" (MoCo).

The paper is available [here](https://arxiv.org/abs/1911.05722).

# How to use

Run the train script:

```
python train.py
```
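
Training writes the query encoder's weights to `result/model.pth` by default. A minimal sketch of loading that checkpoint to embed an image (the input tensor here is placeholder data, not a real MNIST sample):

```
import torch
from network import Net

model = Net()
model.load_state_dict(torch.load('result/model.pth', map_location='cpu'))
model.eval()

with torch.no_grad():
    x = torch.randn(1, 1, 28, 28)  # placeholder for one normalized MNIST image
    feat = model(x)                # L2-normalized 128-d feature
print(feat.shape)                  # torch.Size([1, 128])
```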

# Visualizing the features with t-SNE

The figure below visualizes the features extracted from MNIST, embedded with t-SNE. Reproduce it by running the test script:

```
python test.py
```

![image](https://user-images.githubusercontent.com/14243883/71716313-a1372680-2e57-11ea-8bbd-a284be180bd6.png)
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small CNN encoder mapping a 1x28x28 image to a 128-d unit vector."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after pooling

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.normalize(x)  # L2-normalize so dot products are cosine similarities
        return x
--------------------------------------------------------------------------------
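The `9216` in `fc1` is the flattened conv output: 28→26 after `conv1`, 26→24 after `conv2`, 24→12 after the 2×2 max-pool, giving 64 × 12 × 12 = 9216. A standalone sanity check (not part of the repository):

```
import torch
from network import Net

out = Net()(torch.randn(2, 1, 28, 28))                  # two dummy images
assert out.shape == (2, 128)                            # fc1 input was 64 * 12 * 12 = 9216
assert torch.allclose(out.norm(dim=1), torch.ones(2))   # features are unit-norm
```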
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.3.1
torchvision==0.4.2
Pillow
numpy
tqdm
scikit-learn
matplotlib
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
import tqdm
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from network import Net

def show(mnist, targets, ret):
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'violet', 'orange', 'purple']

    plt.figure(figsize=(12, 10))

    ax = plt.subplot(aspect='equal')
    for label in set(targets):
        idx = np.where(np.array(targets) == label)[0]
        plt.scatter(ret[idx, 0], ret[idx, 1], c=colors[label], label=label)

    # Overlay a thumbnail of every 250th digit at its t-SNE location,
    # un-normalizing with the MNIST mean/std used in the transform.
    for i in range(0, len(targets), 250):
        img = (mnist[i][0] * 0.3081 + 0.1307).numpy()[0]
        img = OffsetImage(img, cmap=plt.cm.gray_r, zoom=0.5)
        ax.add_artist(AnnotationBbox(img, ret[i]))

    plt.legend()
    plt.show()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MoCo example: MNIST')
    parser.add_argument('--model', '-m', default='result/model.pth',
                        help='Model file')
    args = parser.parse_args()
    model_path = args.model

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])

    mnist = datasets.MNIST('./', train=False, download=True, transform=transform)

    model = Net()
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()

    # Embed every test image with the trained query encoder.
    data = []
    targets = []
    with torch.no_grad():
        for m in tqdm.tqdm(mnist):
            target = m[1]
            targets.append(target)
            x = m[0]
            x = x.view(1, *x.shape)
            feat = model(x)
            data.append(feat.data.numpy()[0])

    ret = TSNE(n_components=2, random_state=0).fit_transform(data)

    show(mnist, targets, ret)
--------------------------------------------------------------------------------
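`test.py` opens an interactive window; on a headless machine `plt.show()` will display nothing. A hypothetical one-line substitution at the end of `show()` writes the figure to disk instead:

```
plt.savefig('tsne.png', dpi=150, bbox_inches='tight')  # instead of plt.show()
```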
/train.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from PIL import Image
import numpy as np
import copy

from network import Net

class DuplicatedCompose(object):
    """Apply the same transform pipeline twice to get two augmented views."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        img1 = img.copy()
        img2 = img.copy()
        for t in self.transforms:
            img1 = t(img1)
            img2 = t(img2)
        return img1, img2

def momentum_update(model_q, model_k, beta=0.999):
    # Exponential moving average: theta_k <- beta * theta_k + (1 - beta) * theta_q
    param_k = model_k.state_dict()
    param_q = model_q.named_parameters()
    for n, q in param_q:
        if n in param_k:
            param_k[n].data.copy_(beta * param_k[n].data + (1 - beta) * q.data)
    model_k.load_state_dict(param_k)

def queue_data(data, k):
    return torch.cat([data, k], dim=0)

def dequeue_data(data, K=4096):
    # Keep only the K most recent keys (FIFO queue of negatives).
    if len(data) > K:
        return data[-K:]
    else:
        return data

def initialize_queue(model_k, device, train_loader):
    queue = torch.zeros((0, 128), dtype=torch.float)
    queue = queue.to(device)

    # Seed the queue with keys from a single batch.
    for batch_idx, (data, target) in enumerate(train_loader):
        x_k = data[1]
        x_k = x_k.to(device)
        k = model_k(x_k)
        k = k.detach()
        queue = queue_data(queue, k)
        queue = dequeue_data(queue, K=10)
        break
    return queue

def train(model_q, model_k, device, train_loader, queue, optimizer, epoch, temp=0.07):
    model_q.train()
    total_loss = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        x_q = data[0]  # query view
        x_k = data[1]  # key view

        x_q, x_k = x_q.to(device), x_k.to(device)
        q = model_q(x_q)
        k = model_k(x_k)
        k = k.detach()  # no gradients flow into the key encoder

        N = data[0].shape[0]
        K = queue.shape[0]
        # Positive logits: similarity of each query with its own key, shape (N, 1).
        l_pos = torch.bmm(q.view(N, 1, -1), k.view(N, -1, 1))
        # Negative logits: similarity of each query with all queued keys, shape (N, K).
        l_neg = torch.mm(q.view(N, -1), queue.T.view(-1, K))

        logits = torch.cat([l_pos.view(N, 1), l_neg], dim=1)

        # The positive key always sits at index 0 of the logits.
        labels = torch.zeros(N, dtype=torch.long)
        labels = labels.to(device)

        cross_entropy_loss = nn.CrossEntropyLoss()
        loss = cross_entropy_loss(logits / temp, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        momentum_update(model_q, model_k)

        queue = queue_data(queue, k)
        queue = dequeue_data(queue)

    total_loss /= len(train_loader.dataset)

    print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, total_loss))

    # Return the updated queue so the negatives persist across epochs
    # (without this, each epoch would restart from the tiny initial queue).
    return queue

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MoCo example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epochs', '-e', type=int, default=50,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    args = parser.parse_args()

    batchsize = args.batchsize
    epochs = args.epochs
    out_dir = args.out

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 4, 'pin_memory': True}

    transform = DuplicatedCompose([
        transforms.RandomRotation(20),
        transforms.RandomResizedCrop(28, scale=(0.9, 1.1), ratio=(0.9, 1.1), interpolation=2),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])

    train_mnist = datasets.MNIST('./', train=True, download=True, transform=transform)
    test_mnist = datasets.MNIST('./', train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_mnist, batch_size=batchsize, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_mnist, batch_size=batchsize, shuffle=True, **kwargs)

    # The key encoder starts as a copy of the query encoder and is updated by EMA.
    model_q = Net().to(device)
    model_k = copy.deepcopy(model_q)
    optimizer = optim.SGD(model_q.parameters(), lr=0.01, weight_decay=0.0001)

    queue = initialize_queue(model_k, device, train_loader)

    for epoch in range(1, epochs + 1):
        queue = train(model_q, model_k, device, train_loader, queue, optimizer, epoch)

    os.makedirs(out_dir, exist_ok=True)
    torch.save(model_q.state_dict(), os.path.join(out_dir, 'model.pth'))
--------------------------------------------------------------------------------
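
A standalone sketch (not part of the repository) of the logit construction in `train()`: the positive similarity comes from a batched dot product between each query and its own key, the negatives come from the queue, and the target class is always index 0:

```
import torch
import torch.nn.functional as F

N, D, K = 4, 128, 16                    # batch size, feature dim, queue length
q = F.normalize(torch.randn(N, D))      # query features
k = F.normalize(torch.randn(N, D))      # key features
queue = F.normalize(torch.randn(K, D))  # queued negative keys

l_pos = torch.bmm(q.view(N, 1, -1), k.view(N, -1, 1))  # (N, 1, 1)
l_neg = torch.mm(q.view(N, -1), queue.T.view(-1, K))   # (N, K)
logits = torch.cat([l_pos.view(N, 1), l_neg], dim=1)   # (N, 1 + K)

assert logits.shape == (N, 1 + K)
labels = torch.zeros(N, dtype=torch.long)              # positive is class 0
```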