├── LICENSE ├── README.md ├── config.py ├── .gitignore ├── model.py ├── preprocess.py └── main.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Myeongjun Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Exploring Simple Siamese Representation Learning (SimSiam) 2 | 3 | ## Network 4 | 5 | simsiam 6 | 7 | 8 | ## Experiments 9 | | Model | Pre-training Epochs | Batch size | Dim | Linear Evaluation | Acc (%) | 10 | |:-:|:-:|:-:|:-:|:-:|:-:| 11 | | ResNet-18 (Paper) | 800 | 512 | 2048 | O | 91.8 | 12 | | ResNet-18 (Our) | 300 | 512 | 1024 | O | 72.49 | 13 | | ResNet-18 | 800 | 256 | 1024 | O| 83.93 | 14 | | ResNet-18 | | 512 | 2048 | O | wip | 15 | 16 | - plot 17 | 18 | 19 | ## Usage 20 | - Dataset (CIFAR-10) 21 | - [Data Link](https://www.cs.toronto.edu/~kriz/cifar.html) 22 | ``` 23 | data 24 | └── cifar-10-batches-py 25 | ├── batches.meta 26 | ├── data_batch_1 27 | ├── data_batch_2 28 | ├── data_batch_3 29 | ├── data_batch_4 30 | ├── data_batch_5 31 | ├── readme.html 32 | └── test_batch 33 | ``` 34 | 1. Pre-training 35 | ``` 36 | python main.py --pretrain True 37 | ``` 38 | 39 | 2. 
import argparse


def str2bool(v):
    """Parse a command-line boolean flag.

    argparse's ``type=bool`` is a classic trap: any non-empty string is
    truthy, so ``--pretrain False`` would yield ``bool('False') == True`` and
    silently keep pre-training enabled.  This parser accepts the usual
    spellings of true/false and raises on anything else.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got: {0!r}'.format(v))


def load_args():
    """Build and parse the command-line arguments for SimSiam pre-training
    and the downstream (linear evaluation) task."""
    parser = argparse.ArgumentParser()

    # Pre-training
    parser.add_argument('--base_dir', type=str, default='./data/cifar-10-batches-py')
    parser.add_argument('--img_size', type=int, default=32)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--num_workers', type=int, default=2)
    # str2bool (not bool) so '--cuda False' / '--pretrain False' actually work.
    parser.add_argument('--cuda', type=str2bool, default=True)
    parser.add_argument('--epochs', type=int, default=801)
    parser.add_argument('--lr', type=float, default=0.03)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--checkpoints', type=str, default=None)
    parser.add_argument('--pretrain', type=str2bool, default=True)
    parser.add_argument('--device_num', type=int, default=1)
    parser.add_argument('--print_intervals', type=int, default=100)

    # Network
    parser.add_argument('--proj_hidden', type=int, default=2048)
    parser.add_argument('--proj_out', type=int, default=2048)
    parser.add_argument('--pred_hidden', type=int, default=512)
    parser.add_argument('--pred_out', type=int, default=2048)

    # Downstream task (linear evaluation)
    parser.add_argument('--down_lr', type=float, default=0.03)
    parser.add_argument('--down_epochs', type=int, default=810)
    parser.add_argument('--down_batch_size', type=int, default=256)

    args = parser.parse_args()

    return args
36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class D(nn.Module):
    """Negative cosine similarity D(p, z) with stop-gradient on z (SimSiam)."""

    def __init__(self):
        super(D, self).__init__()

    def forward(self, p, z):
        # Stop-gradient on the target branch: the key ingredient that
        # prevents representation collapse in SimSiam.
        z = z.detach()

        p = F.normalize(p, p=2, dim=1)
        z = F.normalize(z, p=2, dim=1)
        # Per-sample cosine similarity, negated and averaged over the batch.
        return -(p * z).sum(dim=1).mean()


class Model(nn.Module):
    """SimSiam network: ResNet-18 encoder + projection MLP + prediction MLP.

    ``forward(x1, x2)`` returns the two halves of the symmetric loss,
    ``(D(p1, z2)/2, D(p2, z1)/2)``; the caller sums them.
    """

    def __init__(self, args, downstream=False):
        super(Model, self).__init__()
        resnet18 = models.resnet18(pretrained=False)
        proj_hid, proj_out = args.proj_hidden, args.proj_out
        pred_hid, pred_out = args.pred_hidden, args.pred_out

        # Keep everything up to (and including) global average pooling;
        # the final FC classification head is dropped.
        self.backbone = nn.Sequential(*list(resnet18.children())[:-1])
        backbone_in_channels = resnet18.fc.in_features

        # 3-layer projection MLP: BN after every linear, no ReLU on the output.
        self.projection = nn.Sequential(
            nn.Linear(backbone_in_channels, proj_hid),
            nn.BatchNorm1d(proj_hid),
            nn.ReLU(),
            nn.Linear(proj_hid, proj_hid),
            nn.BatchNorm1d(proj_hid),
            nn.ReLU(),
            nn.Linear(proj_hid, proj_out),
            nn.BatchNorm1d(proj_out)
        )

        # 2-layer bottleneck prediction MLP: no BN/activation on the output.
        self.prediction = nn.Sequential(
            nn.Linear(proj_out, pred_hid),
            nn.BatchNorm1d(pred_hid),
            nn.ReLU(),
            nn.Linear(pred_hid, pred_out),
        )

        self.d = D()

        if args.checkpoints is not None and downstream:
            # map_location='cpu' so checkpoints saved on GPU also load on
            # CPU-only hosts; the caller moves the model to its device later.
            state = torch.load(args.checkpoints, map_location='cpu')
            self.load_state_dict(state['model_state_dict'])

    def forward(self, x1, x2):
        # flatten(_, 1) instead of .squeeze(): squeeze() would also remove
        # the batch dimension when batch size is 1, breaking the 1-D MLPs.
        out1 = torch.flatten(self.backbone(x1), 1)
        z1 = self.projection(out1)
        p1 = self.prediction(z1)

        out2 = torch.flatten(self.backbone(x2), 1)
        z2 = self.projection(out2)
        p2 = self.prediction(z2)

        d1 = self.d(p1, z2) / 2.
        d2 = self.d(p2, z1) / 2.

        return d1, d2


class DownStreamModel(nn.Module):
    """Linear-evaluation classifier on top of a frozen SimSiam encoder.

    ``downstream=True`` makes the inner ``Model`` load ``args.checkpoints``
    when one is provided; backbone and projection are then frozen and only
    the small classifier head below is trained.
    """

    def __init__(self, args, n_classes=10):
        super(DownStreamModel, self).__init__()
        self.simsiam = Model(args, downstream=True)
        hidden = 512

        self.net_backbone = nn.Sequential(
            self.simsiam.backbone,
        )
        # Freeze the encoder for linear evaluation.
        for name, param in self.net_backbone.named_parameters():
            param.requires_grad = False

        self.net_projection = nn.Sequential(
            self.simsiam.projection,
        )
        for name, param in self.net_projection.named_parameters():
            param.requires_grad = False

        self.out = nn.Sequential(
            nn.Linear(args.proj_out, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        # Same squeeze -> flatten fix as Model.forward (safe for batch size 1).
        out = torch.flatten(self.net_backbone(x), 1)
        out = self.net_projection(out)
        out = self.out(out)

        return out
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import os
import pickle
import numpy as np
from PIL import Image


# reference
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py
class SimSiamDataset(Dataset):
    """CIFAR-10 dataset for SimSiam.

    Pre-training mode (``downstream=False``) returns two independently
    augmented views of the same image plus the label; downstream mode
    returns a single (image, label) pair for linear evaluation.
    """

    # CIFAR-10 batch-file names and MD5 checksums, mirroring torchvision's
    # CIFAR10 dataset.  NOTE(review): the checksums/url are kept for
    # reference only — nothing below downloads or verifies them.
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, args, mode='train', downstream=False):
        """Load raw CIFAR-10 pickles from ``args.base_dir`` and build transforms.

        mode: 'train' loads the five training batches; anything else loads
        the test batch.
        downstream: if True, __getitem__ yields a single view (supervised).
        """
        if mode == 'train':
            data_list = self.train_list
        else:
            data_list = self.test_list
        self.targets = []
        self.data = []
        self.args = args
        self.downstream = downstream

        # Each batch file is a latin1-pickled dict with 'data' (uint8 rows)
        # and either 'labels' (CIFAR-10) or 'fine_labels' (CIFAR-100 layout).
        for file_name, checksum in data_list:
            file_path = os.path.join(args.base_dir, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        # SimSiam augmentation pipeline (crop, flip, color jitter, grayscale);
        # normalization uses the standard CIFAR-10 channel statistics.
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(self.args.img_size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(0.2),
            # transforms.GaussianBlur(kernel_size=int(self.args.img_size * 0.1), sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        # Downstream uses a single transform; pre-training uses two so each
        # call to __getitem__ produces two independent augmented views.
        # (transform2 is deliberately left unset in downstream mode.)
        if downstream:
            if mode == 'train':
                self.transform1 = train_transform
            else:
                self.transform1 = test_transform
        else:
            self.transform1 = train_transform
            self.transform2 = train_transform

    def __getitem__(self, index: int):
        """Return (view1, target) in downstream mode, else (view1, view2, target)."""
        img1, target = self.data[index], self.targets[index]
        img2 = img1.copy()

        img1 = Image.fromarray(img1)
        img1 = self.transform1(img1)

        if self.downstream:
            return img1, target

        img2 = Image.fromarray(img2)
        img2 = self.transform2(img2)

        return img1, img2, target

    def __len__(self) -> int:
        return len(self.data)


def load_data(args):
    """Build the four DataLoaders used by main.py.

    Returns (pretrain train loader, pretrain test loader,
    downstream train loader, downstream test loader).
    """
    train_data = SimSiamDataset(args)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    test_data = SimSiamDataset(args, mode='test')
    test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    down_train_data = SimSiamDataset(args, downstream=True)
    down_train_loader = DataLoader(down_train_data, batch_size=args.down_batch_size, shuffle=True, num_workers=args.num_workers)

    down_test_data = SimSiamDataset(args, mode='test', downstream=True)
    down_test_loader = DataLoader(down_test_data, batch_size=args.down_batch_size, shuffle=False, num_workers=args.num_workers)

    return train_loader, test_loader, down_train_loader, down_test_loader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from config import load_args
from model import Model, DownStreamModel
from preprocess import load_data

import os
import matplotlib.pyplot as plt


def save_checkpoint(model, optimizer, args, epoch):
    """Persist model/optimizer state to checkpoints/checkpoint_pretrain_model.pth.

    Unwraps ``model.module`` when more than one device is configured so the
    saved state_dict keys carry no DataParallel prefix.
    """
    print('\nModel Saving...')
    if args.device_num > 1:
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()

    torch.save({
        'model_state_dict': model_state_dict,
        'global_epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
    }, os.path.join('checkpoints', 'checkpoint_pretrain_model.pth'))


def pre_train(epoch, train_loader, model, optimizer, args):
    """Run one SimSiam pre-training epoch; return the mean loss over batches."""
    model.train()

    losses, step = 0., 0.
    for x1, x2, target in train_loader:  # target is unused during pre-training
        if args.cuda:
            x1, x2 = x1.cuda(), x2.cuda()

        d1, d2 = model(x1, x2)

        optimizer.zero_grad()
        loss = d1 + d2  # each half is already divided by 2 inside the model
        loss.backward()
        optimizer.step()
        losses += loss.item()

        step += 1

    # Fixed format string: the closing ']' was missing, now matching the
    # '[Down Task ...]' messages below.
    print('[Epoch: {0:4d}], loss: {1:.3f}'.format(epoch, losses / step))
    return losses / step


def _train(epoch, train_loader, model, optimizer, criterion, args):
    """One supervised training epoch for the downstream (linear eval) head."""
    model.train()

    losses, acc, step, total = 0., 0., 0., 0.
    for data, target in train_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()

        logits = model(data)

        optimizer.zero_grad()
        loss = criterion(logits, target)
        loss.backward()
        losses += loss.item()
        optimizer.step()

        # argmax over class probabilities; softmax kept for parity with _eval.
        pred = F.softmax(logits, dim=-1).max(-1)[1]
        acc += pred.eq(target).sum().item()

        step += 1
        total += target.size(0)

    print('[Down Task Train Epoch: {0:4d}], loss: {1:.3f}, acc: {2:.3f}'.format(epoch, losses / step, acc / total * 100.))


def _eval(epoch, test_loader, model, criterion, args):
    """Evaluate the downstream head on the test set (no gradient tracking)."""
    model.eval()

    losses, acc, step, total = 0., 0., 0., 0.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()

            logits = model(data)
            loss = criterion(logits, target)
            losses += loss.item()
            pred = F.softmax(logits, dim=-1).max(-1)[1]
            acc += pred.eq(target).sum().item()

            step += 1
            total += target.size(0)
        print('[Down Task Test Epoch: {0:4d}], loss: {1:.3f}, acc: {2:.3f}'.format(epoch, losses / step, acc / total * 100.))


def train_eval_down_task(down_model, down_train_loader, down_test_loader, args):
    """Train and evaluate the linear head for ``args.down_epochs`` epochs."""
    down_optimizer = optim.SGD(down_model.parameters(), lr=args.down_lr, weight_decay=args.weight_decay, momentum=args.momentum)
    down_criterion = nn.CrossEntropyLoss()
    down_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(down_optimizer, T_max=args.down_epochs)
    for epoch in range(1, args.down_epochs + 1):
        _train(epoch, down_train_loader, down_model, down_optimizer, down_criterion, args)
        _eval(epoch, down_test_loader, down_model, down_criterion, args)
        down_lr_scheduler.step()


def main(args):
    train_loader, test_loader, down_train_loader, down_test_loader = load_data(args)

    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')

    model = Model(args)
    down_model = DownStreamModel(args)
    if args.cuda:
        model = model.cuda()
        down_model = down_model.cuda()

    if args.pretrain:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
        # Follow the configured schedule length instead of a hard-coded T_max=800.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

        train_losses, epoch_list = [], []
        for epoch in range(1, args.epochs + 1):
            train_loss = pre_train(epoch, train_loader, model, optimizer, args)
            if epoch % args.print_intervals == 0:
                save_checkpoint(model, optimizer, args, epoch)
                # Rebuild the linear probe from the checkpoint just saved so
                # the 1-epoch probe measures the *current* representation.
                # (Previously the same probe model — initialized before any
                # pre-training — was reused, evaluating stale weights.)
                args.checkpoints = os.path.join('checkpoints', 'checkpoint_pretrain_model.pth')
                down_model = DownStreamModel(args)
                if args.cuda:
                    down_model = down_model.cuda()
                args.down_epochs = 1
                train_eval_down_task(down_model, down_train_loader, down_test_loader, args)
            lr_scheduler.step()
            train_losses.append(train_loss)
            epoch_list.append(epoch)
            print(' Cur lr: {0:.5f}'.format(lr_scheduler.get_last_lr()[0]))
        plt.plot(epoch_list, train_losses)
        plt.savefig('test.png', dpi=300)
    else:
        args.down_epochs = 810
        train_eval_down_task(down_model, down_train_loader, down_test_loader, args)


if __name__ == '__main__':
    args = load_args()
    main(args)