├── LICENSE
├── README.md
├── config.py
├── .gitignore
├── model.py
├── preprocess.py
└── main.py
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Myeongjun Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Exploring Simple Siamese Representation Learning (SimSiam)

## Network

SimSiam learns representations by maximizing the cosine similarity between two augmented views of the same image. A ResNet-18 backbone feeds a three-layer projection MLP; a two-layer prediction MLP is applied on one branch and a stop-gradient on the other, which is what prevents the network from collapsing to a constant output.
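
The loss below is the one `model.py` implements, shown as a self-contained sketch (the function name here is illustrative; the repo wraps this in an `nn.Module` called `D`):

```
import torch.nn.functional as F

def neg_cosine_sim(p, z):
    # stop-gradient: the projection z is treated as a constant target
    z = z.detach()
    p = F.normalize(p, p=2, dim=1)
    z = F.normalize(z, p=2, dim=1)
    return -(p * z).sum(dim=1).mean()

# symmetric loss over the two views:
#   L = neg_cosine_sim(p1, z2) / 2 + neg_cosine_sim(p2, z1) / 2
```
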
## Experiments

| Model | Pre-training Epochs | Batch Size | Projection Dim | Linear Evaluation | Acc (%) |
|:-:|:-:|:-:|:-:|:-:|:-:|
| ResNet-18 (paper) | 800 | 512 | 2048 | Yes | 91.8 |
| ResNet-18 (ours) | 300 | 512 | 1024 | Yes | 72.49 |
| ResNet-18 (ours) | 800 | 256 | 1024 | Yes | 83.93 |
| ResNet-18 (ours) | | 512 | 2048 | Yes | WIP |

- Training-loss plot: written to `test.png` during pre-training.

## Usage

- Dataset (CIFAR-10)
  - [Data Link](https://www.cs.toronto.edu/~kriz/cifar.html)
```
data
└── cifar-10-batches-py
    ├── batches.meta
    ├── data_batch_1
    ├── data_batch_2
    ├── data_batch_3
    ├── data_batch_4
    ├── data_batch_5
    ├── readme.html
    └── test_batch
```
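
If you prefer not to download the archive manually, `torchvision` can fetch and extract it into the same layout (a convenience sketch, not part of this repo):

```
from torchvision.datasets import CIFAR10

# downloads cifar-10-python.tar.gz and extracts it to ./data/cifar-10-batches-py
CIFAR10(root='./data', download=True)
```
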
1. Pre-training
```
python main.py --pretrain True
```

2. Downstream Task (Linear Evaluation)
```
python main.py --checkpoints checkpoints/checkpoint_pretrain_model.pth --pretrain False
```

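## Requirements

- Inferred from the imports (no versions are pinned in this repo): `torch`, `torchvision`, `numpy`, `Pillow`, `matplotlib`.
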
## Reference

- [Paper Link](https://arxiv.org/abs/2011.10566)
  - Authors: Xinlei Chen, Kaiming He
  - Organization: Facebook AI Research (FAIR)
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import argparse


def str2bool(v):
    # argparse's type=bool treats any non-empty string (including 'False') as
    # True, so boolean flags such as `--pretrain False` need explicit parsing
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def load_args():
    parser = argparse.ArgumentParser()

    # Pre-training
    parser.add_argument('--base_dir', type=str, default='./data/cifar-10-batches-py')
    parser.add_argument('--img_size', type=int, default=32)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--cuda', type=str2bool, default=True)
    parser.add_argument('--epochs', type=int, default=801)
    parser.add_argument('--lr', type=float, default=0.03)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--checkpoints', type=str, default=None)
    parser.add_argument('--pretrain', type=str2bool, default=True)
    parser.add_argument('--device_num', type=int, default=1)
    parser.add_argument('--print_intervals', type=int, default=100)

    # Network (projection / prediction MLP sizes)
    parser.add_argument('--proj_hidden', type=int, default=2048)
    parser.add_argument('--proj_out', type=int, default=2048)
    parser.add_argument('--pred_hidden', type=int, default=512)
    parser.add_argument('--pred_out', type=int, default=2048)

    # Downstream task (linear evaluation)
    parser.add_argument('--down_lr', type=float, default=0.03)
    parser.add_argument('--down_epochs', type=int, default=810)
    parser.add_argument('--down_batch_size', type=int, default=256)

    args = parser.parse_args()

    return args
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class D(nn.Module):
    """Negative cosine similarity with stop-gradient (Eqs. (1)-(2) of the SimSiam paper)."""
    def __init__(self):
        super(D, self).__init__()

    def forward(self, p, z):
        z = z.detach()  # stop-gradient: z is treated as a constant target

        p = F.normalize(p, p=2, dim=1)
        z = F.normalize(z, p=2, dim=1)
        return -(p * z).sum(dim=1).mean()

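# Illustrative note (not part of the original code): D computes the batch mean
# of the negative cosine similarity, i.e. for tensors p, z of shape (N, C),
#   D()(p, z) == -F.cosine_similarity(p, z.detach(), dim=1).mean()
# up to floating-point tolerance.
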
class Model(nn.Module):
    def __init__(self, args, downstream=False):
        super(Model, self).__init__()
        resnet18 = models.resnet18(pretrained=False)
        proj_hid, proj_out = args.proj_hidden, args.proj_out
        pred_hid, pred_out = args.pred_hidden, args.pred_out

        # backbone: ResNet-18 up to (and including) global average pooling
        self.backbone = nn.Sequential(*list(resnet18.children())[:-1])
        backbone_in_channels = resnet18.fc.in_features

        # projection MLP: 3 layers with BN on every layer, no ReLU after the last
        self.projection = nn.Sequential(
            nn.Linear(backbone_in_channels, proj_hid),
            nn.BatchNorm1d(proj_hid),
            nn.ReLU(),
            nn.Linear(proj_hid, proj_hid),
            nn.BatchNorm1d(proj_hid),
            nn.ReLU(),
            nn.Linear(proj_hid, proj_out),
            nn.BatchNorm1d(proj_out)
        )

        # prediction MLP: 2 layers with a bottleneck hidden dimension
        self.prediction = nn.Sequential(
            nn.Linear(proj_out, pred_hid),
            nn.BatchNorm1d(pred_hid),
            nn.ReLU(),
            nn.Linear(pred_hid, pred_out),
        )

        self.d = D()

        if args.checkpoints is not None and downstream:
            checkpoint = torch.load(args.checkpoints, map_location='cpu')
            self.load_state_dict(checkpoint['model_state_dict'])

    def forward(self, x1, x2):
        # flatten(1) keeps the batch dimension, unlike squeeze(), which would
        # also drop it for a batch of size 1
        out1 = self.backbone(x1).flatten(1)
        z1 = self.projection(out1)
        p1 = self.prediction(z1)

        out2 = self.backbone(x2).flatten(1)
        z2 = self.projection(out2)
        p2 = self.prediction(z2)

        # symmetric loss: L = D(p1, z2) / 2 + D(p2, z1) / 2
        d1 = self.d(p1, z2) / 2.
        d2 = self.d(p2, z1) / 2.

        return d1, d2

class DownStreamModel(nn.Module):
    def __init__(self, args, n_classes=10):
        super(DownStreamModel, self).__init__()
        self.simsiam = Model(args, downstream=True)
        hidden = 512

        # frozen pre-trained backbone
        self.net_backbone = self.simsiam.backbone
        for param in self.net_backbone.parameters():
            param.requires_grad = False

        # frozen pre-trained projection MLP
        self.net_projection = self.simsiam.projection
        for param in self.net_projection.parameters():
            param.requires_grad = False

        # trainable classification head
        self.out = nn.Sequential(
            nn.Linear(args.proj_out, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        out = self.net_backbone(x).flatten(1)
        out = self.net_projection(out)
        out = self.out(out)

        return out
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import os
import pickle
import numpy as np
from PIL import Image


# reference
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py
class SimSiamDataset(Dataset):
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    # md5 checksums are carried over from torchvision for reference; they are
    # not verified when the batches are loaded below
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, args, mode='train', downstream=False):
        if mode == 'train':
            data_list = self.train_list
        else:
            data_list = self.test_list
        self.targets = []
        self.data = []
        self.args = args
        self.downstream = downstream

        for file_name, _checksum in data_list:
            file_path = os.path.join(args.base_dir, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        # augmentations follow the paper's CIFAR setup; Gaussian blur is
        # disabled for the small 32x32 images
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(self.args.img_size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(0.2),
            # transforms.GaussianBlur(kernel_size=int(self.args.img_size * 0.1), sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        if downstream:
            if mode == 'train':
                self.transform1 = train_transform
            else:
                self.transform1 = test_transform
        else:
            # pre-training: two independently augmented views of the same image
            self.transform1 = train_transform
            self.transform2 = train_transform

    def __getitem__(self, index: int):
        img1, target = self.data[index], self.targets[index]
        img2 = img1.copy()

        img1 = Image.fromarray(img1)
        img1 = self.transform1(img1)

        if self.downstream:
            return img1, target

        img2 = Image.fromarray(img2)
        img2 = self.transform2(img2)

        return img1, img2, target

    def __len__(self) -> int:
        return len(self.data)


def load_data(args):
    train_data = SimSiamDataset(args)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    test_data = SimSiamDataset(args, mode='test')
    test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    down_train_data = SimSiamDataset(args, downstream=True)
    down_train_loader = DataLoader(down_train_data, batch_size=args.down_batch_size, shuffle=True, num_workers=args.num_workers)

    down_test_data = SimSiamDataset(args, mode='test', downstream=True)
    down_test_loader = DataLoader(down_test_data, batch_size=args.down_batch_size, shuffle=False, num_workers=args.num_workers)

    return train_loader, test_loader, down_train_loader, down_test_loader
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from config import load_args
from model import Model, DownStreamModel
from preprocess import load_data

import os
import matplotlib.pyplot as plt


def save_checkpoint(model, optimizer, args, epoch):
    print('\nModel Saving...')
    if args.device_num > 1:
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()

    torch.save({
        'model_state_dict': model_state_dict,
        'global_epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
    }, os.path.join('checkpoints', 'checkpoint_pretrain_model.pth'))

def pre_train(epoch, train_loader, model, optimizer, args):
    model.train()

    losses, step = 0., 0.
    for x1, x2, target in train_loader:
        if args.cuda:
            x1, x2 = x1.cuda(), x2.cuda()

        d1, d2 = model(x1, x2)

        optimizer.zero_grad()
        loss = d1 + d2  # symmetric loss: D(p1, z2) / 2 + D(p2, z1) / 2
        loss.backward()
        optimizer.step()
        losses += loss.item()

        step += 1

    print('[Epoch: {0:4d}], loss: {1:.3f}'.format(epoch, losses / step))
    return losses / step

def _train(epoch, train_loader, model, optimizer, criterion, args):
    model.train()

    losses, acc, step, total = 0., 0., 0., 0.
    for data, target in train_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()

        logits = model(data)

        optimizer.zero_grad()
        loss = criterion(logits, target)
        loss.backward()
        losses += loss.item()
        optimizer.step()

        # softmax is monotonic, so the argmax of the logits gives the same predictions
        pred = logits.argmax(dim=-1)
        acc += pred.eq(target).sum().item()

        step += 1
        total += target.size(0)

    print('[Down Task Train Epoch: {0:4d}], loss: {1:.3f}, acc: {2:.3f}'.format(epoch, losses / step, acc / total * 100.))

def _eval(epoch, test_loader, model, criterion, args):
    model.eval()

    losses, acc, step, total = 0., 0., 0., 0.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()

            logits = model(data)
            loss = criterion(logits, target)
            losses += loss.item()
            pred = logits.argmax(dim=-1)
            acc += pred.eq(target).sum().item()

            step += 1
            total += target.size(0)
    print('[Down Task Test Epoch: {0:4d}], loss: {1:.3f}, acc: {2:.3f}'.format(epoch, losses / step, acc / total * 100.))

def train_eval_down_task(down_model, down_train_loader, down_test_loader, args):
    down_optimizer = optim.SGD(down_model.parameters(), lr=args.down_lr, weight_decay=args.weight_decay, momentum=args.momentum)
    down_criterion = nn.CrossEntropyLoss()
    down_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(down_optimizer, T_max=args.down_epochs)
    for epoch in range(1, args.down_epochs + 1):
        _train(epoch, down_train_loader, down_model, down_optimizer, down_criterion, args)
        _eval(epoch, down_test_loader, down_model, down_criterion, args)
        down_lr_scheduler.step()

def main(args):
    train_loader, test_loader, down_train_loader, down_test_loader = load_data(args)

    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')

    model = Model(args)
    down_model = DownStreamModel(args)
    if args.cuda:
        model = model.cuda()
        down_model = down_model.cuda()

    if args.pretrain:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
        # cosine decay over the full pre-training schedule
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

        train_losses, epoch_list = [], []
        for epoch in range(1, args.epochs + 1):
            train_loss = pre_train(epoch, train_loader, model, optimizer, args)
            if epoch % args.print_intervals == 0:
                save_checkpoint(model, optimizer, args, epoch)
                # sync the latest pre-trained weights into the frozen
                # evaluation model; without this, the periodic downstream
                # check would keep scoring a randomly initialized backbone
                down_model.simsiam.load_state_dict(model.state_dict())
                args.down_epochs = 1
                train_eval_down_task(down_model, down_train_loader, down_test_loader, args)
            lr_scheduler.step()
            train_losses.append(train_loss)
            epoch_list.append(epoch)
            print(' Cur lr: {0:.5f}'.format(lr_scheduler.get_last_lr()[0]))
        plt.plot(epoch_list, train_losses)
        plt.savefig('test.png', dpi=300)
    else:
        args.down_epochs = 810
        train_eval_down_task(down_model, down_train_loader, down_test_loader, args)


if __name__ == '__main__':
    args = load_args()
    main(args)
--------------------------------------------------------------------------------