├── .gitignore
├── LICENSE
├── README.md
├── assets
└── models.png
├── configs
├── baseline.json
├── pretrain.json
└── train.json
├── pretrain.py
├── requirements.txt
├── simsiam
├── __init__.py
├── losses.py
├── models.py
└── transforms.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | .mypy_cache/
3 | __pycache__/
4 | runs/
5 | data/
6 | models/
7 | *.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 isaac
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # simsiam-pytorch
2 | Minimal PyTorch Implementation of SimSiam from ["Exploring Simple Siamese Representation Learning" by Chen et al.](https://arxiv.org/abs/2011.10566)
3 |
4 |

5 |
6 |
7 | ### Load and train on a custom dataset
8 |
9 | ```python
10 | from simsiam.models import SimSiam
11 | from simsiam.losses import negative_cosine_similarity
12 |
13 | model = SimSiam(
14 | backbone="resnet50", # encoder network
15 | latent_dim=2048, # predictor network output size
16 |     proj_hidden_dim=2048, # projection mlp hidden layer size
17 | pred_hidden_dim=512 # predictor mlp hidden layer size
18 | )
19 | model = model.to("cuda") # use all the parallels
20 | model.train()
21 |
22 | transforms = ...
23 | dataset = ...
24 | dataloader = ...
25 | opt = ...
26 |
27 | for epoch in range(epochs):
28 | for batch, (x, y) in enumerate(dataloader):
29 | opt.zero_grad()
30 |
31 | x1, x2 = transforms(x), transforms(x) # Augment
32 | e1, e2 = model.encode(x1), model.encode(x2) # Encode
33 | z1, z2 = model.project(e1), model.project(e2) # Project
34 | p1, p2 = model.predict(z1), model.predict(z2) # Predict
35 |
36 | # Compute loss
37 |         loss1 = negative_cosine_similarity(p1, z2)
38 |         loss2 = negative_cosine_similarity(p2, z1)
39 | loss = loss1/2 + loss2/2
40 | loss.backward()
41 | opt.step()
42 |
43 | # Save encoder weights for later
44 | torch.save(model.encoder.state_dict(), "pretrained.pt")
45 | ```
46 |
47 | ### Use pretrained weights in a classifier
48 |
49 | ```python
50 | from simsiam.models import ResNet
51 |
52 | # just a wrapper around encoder + linear classifier networks
53 | model = ResNet(
54 | backbone="resnet50", # Same as during pretraining
55 | num_classes=10, # number of output neurons
56 | pretrained=False, # Whether to load pretrained imagenet weights
57 | freeze=True # Freeze the encoder weights (or not)
58 | )
59 |
60 | # Load the pretrained weights from SimSiam
61 | model.encoder.load_state_dict(torch.load("pretrained.pt"))
62 |
63 | model = model.to("cuda")
64 | model.train()
65 |
66 | transforms = ...
67 | dataset = ...
68 | dataloader = ...
69 | opt = optim.SGD(model.parameters())
70 | loss_func = nn.CrossEntropyLoss()
71 |
72 | # Train on your small labeled train set
73 | for epoch in range(epochs):
74 | for batch, (x, y) in enumerate(dataloader):
75 | opt.zero_grad()
76 | y_pred = model(x)
77 | loss = loss_func(y_pred, y)
78 | loss.backward()
79 | opt.step()
80 |
81 | ```
82 |
83 | ### Install dependencies
84 |
85 | ```bash
86 | pip install -r requirements.txt
87 |
88 | ```
89 |
90 | ### Train on STL-10 dataset
91 |
92 | Modify configs/pretrain.json to your liking and run
93 |
94 | ```bash
95 | python pretrain.py --cfg configs/pretrain.json
96 |
97 | ```
98 |
99 | ### View logs in tensorboard
100 |
101 | ```bash
102 | tensorboard --logdir=runs
103 |
104 | ```
105 |
--------------------------------------------------------------------------------
/assets/models.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isaaccorley/simsiam-pytorch/2570656caa1101e2839abdcff94ed546c61dd7f3/assets/models.png
--------------------------------------------------------------------------------
/configs/baseline.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": {
3 | "name": "baseline",
4 | "backbone": "resnet50",
5 | "freeze": false,
6 | "weights_path": ""
7 | },
8 | "data": {
9 | "path": "data/",
10 | "input_shape": [
11 | 96,
12 | 96
13 | ],
14 | "num_classes": 10
15 | },
16 | "train": {
17 | "epochs": 100,
18 | "batch_size": 128,
19 | "lr": 3E-4,
20 | "momentum": 0.9,
21 | "weight_decay": 0.0001,
22 | "log_interval": 100
23 | },
24 | "device": "cuda"
25 | }
--------------------------------------------------------------------------------
/configs/pretrain.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": {
3 | "name": "pretrained",
4 | "backbone": "resnet50",
5 | "latent_dim": 2048,
6 | "proj_hidden_dim": 2048,
7 | "pred_hidden_dim": 512
8 | },
9 | "data": {
10 | "path": "data/",
11 | "input_shape": [
12 | 96,
13 | 96
14 | ],
15 | "num_classes": 10
16 | },
17 | "train": {
18 | "epochs": 100,
19 | "batch_size": 128,
20 | "lr": 0.05,
21 | "momentum": 0.9,
22 | "weight_decay": 0.0001,
23 | "log_interval": 100
24 | },
25 | "device": "cuda"
26 | }
--------------------------------------------------------------------------------
/configs/train.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": {
3 | "name": "fine-tuned",
4 | "backbone": "resnet50",
5 | "freeze": true,
6 | "weights_path": "models/pretrained.pt"
7 | },
8 | "data": {
9 | "path": "data/",
10 | "input_shape": [
11 | 96,
12 | 96
13 | ],
14 | "num_classes": 10
15 | },
16 | "train": {
17 | "epochs": 90,
18 | "batch_size": 128,
19 | "lr": 3E-4,
20 | "momentum": 0.9,
21 | "weight_decay": 0.0,
22 | "log_interval": 100
23 | },
24 | "device": "cuda"
25 | }
--------------------------------------------------------------------------------
/pretrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from types import SimpleNamespace
5 |
6 | import torch
7 | import torchvision
8 | from tqdm.auto import tqdm
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | from simsiam.models import SimSiam
12 | from simsiam.losses import negative_cosine_similarity
13 | from simsiam.transforms import load_transforms, augment_transforms
14 |
15 |
def main(cfg: SimpleNamespace) -> None:
    """Pretrain a SimSiam model on STL-10 and save the encoder weights.

    Builds the model, optimizer, dataloader and GPU-side augmentations from
    ``cfg``, runs the symmetric stop-gradient training loop, logs the loss to
    TensorBoard, and writes ``<log_dir>/<cfg.model.name>.pt`` (encoder
    state_dict only) after every epoch.
    """

    model = SimSiam(
        backbone=cfg.model.backbone,
        latent_dim=cfg.model.latent_dim,
        proj_hidden_dim=cfg.model.proj_hidden_dim,
        pred_hidden_dim=cfg.model.pred_hidden_dim
    )
    model = model.to(cfg.device)
    model.train()

    opt = torch.optim.SGD(
        params=model.parameters(),
        lr=cfg.train.lr,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay
    )

    dataset = torchvision.datasets.STL10(
        root=cfg.data.path,
        split="train",
        transform=load_transforms(input_shape=cfg.data.input_shape),
        download=True
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        drop_last=True,  # keep batch size constant for the BatchNorm layers
        pin_memory=True,
        num_workers=torch.multiprocessing.cpu_count()
    )

    # Batched kornia augmentations living on cfg.device; applied twice per
    # batch below to produce the two random views.
    transforms = augment_transforms(
        input_shape=cfg.data.input_shape,
        device=cfg.device
    )

    writer = SummaryWriter()

    n_iter = 0
    for epoch in range(cfg.train.epochs):

        pbar = tqdm(enumerate(dataloader), total=len(dataloader), position=0, leave=False)
        for batch, (x, y) in pbar:  # labels are unused in self-supervised pretraining

            opt.zero_grad()

            x = x.to(cfg.device)

            # two independently augmented views of the same batch
            x1, x2 = transforms(x), transforms(x)

            # encode
            e1, e2 = model.encode(x1), model.encode(x2)

            # project
            z1, z2 = model.project(e1), model.project(e2)

            # predict
            p1, p2 = model.predict(z1), model.predict(z2)

            # Symmetric SimSiam loss: each prediction is matched to the
            # projection of the *other* view (stop-grad applied inside the
            # loss). Bug fix: the original paired p1 with z1 and p2 with z2
            # (same view); the paper's loss is L = D(p1, z2)/2 + D(p2, z1)/2.
            loss1 = negative_cosine_similarity(p1, z2)
            loss2 = negative_cosine_similarity(p2, z1)
            loss = loss1 / 2 + loss2 / 2
            loss.backward()
            opt.step()

            pbar.set_description("Epoch {}, Loss: {:.4f}".format(epoch, float(loss)))

            if n_iter % cfg.train.log_interval == 0:
                writer.add_scalar(tag="loss", scalar_value=float(loss), global_step=n_iter)

            n_iter += 1

        # save checkpoint (encoder only; projection/prediction heads are
        # discarded after pretraining)
        torch.save(model.encoder.state_dict(), os.path.join(writer.log_dir, cfg.model.name + ".pt"))
95 |
96 |
if __name__ == "__main__":

    # Parse the path to the JSON config file from the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg", type=str, required=True, help="Path to config json file")
    args = parser.parse_args()

    # Load the config as nested SimpleNamespace objects so values can be
    # read with attribute access (cfg.train.lr instead of cfg["train"]["lr"]).
    with open(args.cfg, "r") as f:
        cfg = json.loads(f.read(), object_hook=lambda d: SimpleNamespace(**d))

    main(cfg)
107 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | Pillow
3 | torch
4 | torchvision
5 | kornia
--------------------------------------------------------------------------------
/simsiam/__init__.py:
--------------------------------------------------------------------------------
1 | from . import losses
2 | from . import models
3 | from . import transforms
4 |
--------------------------------------------------------------------------------
/simsiam/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def negative_cosine_similarity(
    p: torch.Tensor,
    z: torch.Tensor
) -> torch.Tensor:
    """Negative cosine similarity with stop-gradient on ``z``.

    Computes D(p, z) = -mean(cos_sim(p, stopgrad(z))) over the batch, the
    SimSiam objective: gradients flow only through the prediction ``p``.
    """
    target = z.detach()  # stop-gradient: the target branch gets no gradient
    similarity = F.cosine_similarity(p, target, dim=-1)
    return -similarity.mean()
11 |
--------------------------------------------------------------------------------
/simsiam/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 |
class LinearClassifier(nn.Module):
    """A single fully-connected layer mapping feature vectors to class logits."""

    def __init__(
        self,
        input_dim: int,
        num_classes: int,
    ) -> None:
        super().__init__()
        # Wrapped in a Sequential so the state_dict layout stays compatible
        # with existing checkpoints.
        head = [nn.Linear(input_dim, num_classes)]
        self.model = nn.Sequential(*head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores (logits) for ``x``."""
        return self.model(x)
20 |
21 |
class ProjectionMLP(nn.Module):
    """SimSiam projection head.

    Three linear layers, each followed by batch norm; ReLU after the first
    two only — the output layer keeps BN but has no activation.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int
    ) -> None:
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
            nn.BatchNorm1d(output_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project encoder features ``x`` into the latent space."""
        return self.model(x)
44 |
45 |
class PredictorMLP(nn.Module):
    """SimSiam prediction head h.

    Two-layer MLP: Linear-BN-ReLU followed by a plain Linear. The configs
    use a hidden layer narrower than the input (a bottleneck), per the
    SimSiam paper.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int
    ) -> None:
        super().__init__()
        self.model = nn.Sequential(
            # Bug fix: the first layer previously used ``output_dim`` as its
            # input size, silently ignoring ``input_dim``. It only worked
            # because all callers pass input_dim == output_dim.
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict the projection of the other augmented view from ``x``."""
        return self.model(x)
64 |
65 |
class Encoder(nn.Module):
    """ResNet backbone with the final classification layer removed.

    Wraps any ``torchvision.models`` resnet constructor named by ``backbone``
    and exposes the pooled feature vector; ``emb_dim`` records its size.
    """

    def __init__(
        self,
        backbone: str,
        pretrained: bool
    ) -> None:
        super().__init__()
        resnet = getattr(torchvision.models, backbone)(pretrained=pretrained)
        # Feature size produced by the backbone (input size of the fc we drop).
        self.emb_dim = resnet.fc.in_features
        # All layers except the final fc; ends with the global average pool.
        self.model = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return features of shape (batch, emb_dim).

        Bug fix: uses ``flatten(1)`` instead of ``squeeze()`` — ``squeeze()``
        also dropped the batch dimension for batch size 1, which breaks the
        downstream BatchNorm/Linear layers.
        """
        return torch.flatten(self.model(x), 1)
80 |
81 |
class SimSiam(nn.Module):
    """SimSiam model: encoder f, projection MLP, and predictor MLP h.

    Each stage is exposed separately (``encode`` / ``project`` / ``predict``)
    so the training loop can assemble the symmetric stop-gradient loss.
    """

    def __init__(
        self,
        backbone: str,
        latent_dim: int,
        proj_hidden_dim: int,
        pred_hidden_dim: int,
    ) -> None:
        super().__init__()

        # Backbone encoder, trained from scratch (no ImageNet weights).
        self.encoder = Encoder(backbone=backbone, pretrained=False)

        # Projection head: encoder features -> latent space.
        self.projection_mlp = ProjectionMLP(
            input_dim=self.encoder.emb_dim,
            hidden_dim=proj_hidden_dim,
            output_dim=latent_dim,
        )

        # Predictor head h, operating within the latent space.
        self.predictor_mlp = PredictorMLP(
            input_dim=latent_dim,
            hidden_dim=pred_hidden_dim,
            output_dim=latent_dim,
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Backbone features for a batch of images."""
        return self.encoder(x)

    def project(self, e: torch.Tensor) -> torch.Tensor:
        """Map encoder features into the latent space."""
        return self.projection_mlp(e)

    def predict(self, z: torch.Tensor) -> torch.Tensor:
        """Predict the latent of the other augmented view."""
        return self.predictor_mlp(z)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`encode`."""
        return self.encode(x)
122 |
123 |
class ResNet(nn.Module):
    """ResNet encoder plus a linear classifier for supervised training.

    Used to evaluate SimSiam pretraining: load pretrained weights into
    ``self.encoder``, optionally freeze them, and train the linear head.
    """

    def __init__(
        self,
        backbone: str,
        num_classes: int,
        pretrained: bool,
        freeze: bool
    ) -> None:
        super().__init__()

        # Feature extractor (optionally initialized with ImageNet weights).
        self.encoder = Encoder(backbone=backbone, pretrained=pretrained)

        if freeze:
            # Bug fix: the attribute was misspelled ``requres_grad``, so
            # freeze=True silently did nothing and the encoder kept training.
            for param in self.encoder.parameters():
                param.requires_grad = False

        # Linear classification head on top of the (possibly frozen) features.
        self.classifier = LinearClassifier(self.encoder.emb_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        e = self.encoder(x)
        return self.classifier(e)
149 |
--------------------------------------------------------------------------------
/simsiam/transforms.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict
2 |
3 | import kornia
4 | import torch
5 | import torch.nn as nn
6 | import torchvision.transforms as T
7 | from PIL import Image
8 |
9 |
10 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
11 | IMAGENET_STD = (0.229, 0.224, 0.225)
12 |
13 |
class RandomGaussianBlur2D(kornia.augmentation.AugmentationBase2D):
    """Randomly apply Gaussian blur to a batch of images.

    Subclasses kornia's ``AugmentationBase2D`` (which handles the per-sample
    probability ``p``) and delegates the actual filtering to
    ``kornia.filters.gaussian_blur2d``.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int],
        sigma: Tuple[float, float],
        border_type: str = "reflect",
        return_transform: bool = False,
        same_on_batch: bool = False,
        p: float = 0.1
    ) -> None:
        """
        Args:
            kernel_size: (height, width) of the Gaussian kernel.
            sigma: standard deviation of the Gaussian per kernel dimension.
            border_type: padding mode forwarded to ``gaussian_blur2d``.
            return_transform: kornia convention — also return the transform
                matrix from the call if True.
            same_on_batch: apply the same random decision to the whole batch.
            p: probability of blurring a given sample.
        """
        super(RandomGaussianBlur2D, self).__init__(
            p=p,
            return_transform=return_transform,
            same_on_batch=same_on_batch
        )

        # Blur parameters are fixed at construction time.
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.border_type = border_type

    def __repr__(self) -> str:
        return self.__class__.__name__ + f"({super().__repr__()})"

    def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
        # The blur is fully parameterized in __init__, so there are no
        # per-sample random parameters to generate.
        return dict()

    def compute_transformation(self, input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
        # A blur has no associated geometric transformation matrix.
        # NOTE(review): returns None despite the Tensor annotation — the
        # kornia base class appears to tolerate this; confirm against the
        # installed kornia version.
        return None

    def apply_transform(self, input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Delegate the filtering itself to kornia's functional API.
        return kornia.filters.gaussian_blur2d(
            input=input,
            kernel_size=self.kernel_size,
            sigma=self.sigma,
            border_type=self.border_type
        )
51 |
52 |
def augment_transforms(
    input_shape: Tuple[int, int],
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
) -> nn.Sequential:
    """Build the random augmentation pipeline as a kornia ``nn.Sequential``.

    Operates on batched image tensors and is moved to ``device`` so the
    augmentations can run on the GPU. The training scripts call the returned
    module twice per batch to produce two augmented views.

    Args:
        input_shape: (height, width) of the output crop.
        device: device the augmentation module is moved to.
    """

    augs = nn.Sequential(
        # Color jitter, applied with probability 0.8.
        kornia.augmentation.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.1,
            p=0.8
        ),
        kornia.augmentation.RandomGrayscale(p=0.2),
        # Custom blur augmentation defined above (kornia has none built in here).
        RandomGaussianBlur2D(
            kernel_size=(3, 3),
            sigma=(0.1, 2.0),
            p=0.5
        ),
        # Random crop covering 20%-100% of the image area, resized back
        # to input_shape; always applied (p=1.0).
        kornia.augmentation.RandomResizedCrop(
            size=input_shape,
            scale=(0.2, 1.0),
            ratio=(0.75, 1.33),
            interpolation="bilinear",
            p=1.0
        ),
        kornia.augmentation.RandomHorizontalFlip(p=0.5),
        # Normalize with ImageNet statistics; inputs are expected in [0, 1]
        # (load_transforms ends with ToTensor and no normalization).
        kornia.augmentation.Normalize(
            mean=torch.tensor(IMAGENET_MEAN),
            std=torch.tensor(IMAGENET_STD)
        )
    )
    augs = augs.to(device)
    return augs
87 |
88 |
def load_transforms(input_shape: Tuple[int, int]) -> T.Compose:
    """Deterministic load-time pipeline: resize to ``input_shape``, to tensor.

    No normalization here — the kornia pipeline in ``augment_transforms``
    normalizes after its on-device augmentations.
    """
    steps = [
        T.Resize(size=input_shape, interpolation=Image.BILINEAR),
        T.ToTensor(),
    ]
    return T.Compose(steps)
94 |
95 |
def test_transforms(input_shape: Tuple[int, int]) -> T.Compose:
    """Evaluation pipeline: resize, convert to tensor, ImageNet-normalize."""
    steps = [
        T.Resize(size=input_shape, interpolation=Image.BILINEAR),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
102 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from types import SimpleNamespace
5 |
6 | import torch
7 | import torchvision
8 | from tqdm.auto import tqdm
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | from simsiam.models import ResNet, LinearClassifier
12 | from simsiam.transforms import load_transforms, augment_transforms
13 |
14 |
def main(cfg: SimpleNamespace) -> None:
    """Train (or fine-tune) a ResNet classifier on STL-10.

    Optionally loads SimSiam-pretrained encoder weights and/or freezes the
    encoder per ``cfg.model``, then trains with SGD + cross-entropy, logging
    the loss to TensorBoard and saving the full model state_dict to
    ``<log_dir>/<cfg.model.name>.pt`` after every epoch.
    """

    # Classifier = encoder backbone + linear head. ImageNet weights are
    # disabled because the encoder is expected to come from pretrain.py.
    model = ResNet(
        backbone=cfg.model.backbone,
        num_classes=cfg.data.num_classes,
        pretrained=False,
        freeze=cfg.model.freeze
    )

    # Load SimSiam-pretrained encoder weights (skipped when path is empty).
    if cfg.model.weights_path:
        model.encoder.load_state_dict(torch.load(cfg.model.weights_path))

    model = model.to(cfg.device)

    opt = torch.optim.SGD(
        params=model.parameters(),
        lr=cfg.train.lr,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay
    )
    loss_func = torch.nn.CrossEntropyLoss()

    # Labeled STL-10 training split; load_transforms only resizes and
    # converts to tensor — augmentation happens on-device below.
    dataset = torchvision.datasets.STL10(
        root=cfg.data.path,
        split="train",
        transform=load_transforms(input_shape=cfg.data.input_shape),
        download=True
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=torch.multiprocessing.cpu_count()
    )

    # Kornia augmentations executed on cfg.device.
    transforms = augment_transforms(
        input_shape=cfg.data.input_shape,
        device=cfg.device
    )

    writer = SummaryWriter()

    n_iter = 0
    for epoch in range(cfg.train.epochs):

        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for batch, (x, y) in pbar:

            opt.zero_grad()

            x, y = x.to(cfg.device), y.to(cfg.device)
            x = transforms(x)
            y_pred = model(x)
            loss = loss_func(y_pred, y)
            loss.backward()
            opt.step()

            pbar.set_description("Epoch {}, Loss: {:.4f}".format(epoch, float(loss)))

            # Log periodically rather than every step to keep event files small.
            if n_iter % cfg.train.log_interval == 0:
                writer.add_scalar(tag="loss", scalar_value=float(loss), global_step=n_iter)

            n_iter += 1

        # save checkpoint (entire model: encoder + classifier) each epoch
        torch.save(model.state_dict(), os.path.join(writer.log_dir, cfg.model.name + ".pt"))
84 |
85 |
if __name__ == "__main__":

    # Parse the path to the JSON config file from the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg", type=str, required=True, help="Path to config json file")
    args = parser.parse_args()

    # Load the config as nested SimpleNamespace objects so values can be
    # read with attribute access (cfg.train.lr instead of cfg["train"]["lr"]).
    with open(args.cfg, "r") as f:
        cfg = json.loads(f.read(), object_hook=lambda d: SimpleNamespace(**d))

    main(cfg)
96 |
--------------------------------------------------------------------------------