├── .gitignore ├── LICENSE ├── README.md ├── assets └── models.png ├── configs ├── baseline.json ├── pretrain.json └── train.json ├── pretrain.py ├── requirements.txt ├── simsiam ├── __init__.py ├── losses.py ├── models.py └── transforms.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .mypy_cache/ 3 | __pycache__/ 4 | runs/ 5 | data/ 6 | models/ 7 | *.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 isaac 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # simsiam-pytorch 2 | Minimal PyTorch Implementation of SimSiam from ["Exploring Simple Siamese Representation Learning" by Chen et al.](https://arxiv.org/abs/2011.10566) 3 | 4 |

5 | 
6 | 
7 | ### Load and train on a custom dataset
8 | 
9 | ```python
10 | from simsiam.models import SimSiam
11 | from simsiam.losses import negative_cosine_similarity
12 | 
13 | model = SimSiam(
14 | backbone="resnet50", # encoder network
15 | latent_dim=2048, # predictor network output size
16 | proj_hidden_dim=2048, # projection mlp hidden layer size
17 | pred_hidden_dim=512 # predictor mlp hidden layer size
18 | )
19 | model = model.to("cuda") # use all the parallels
20 | model.train()
21 | 
22 | transforms = ...
23 | dataset = ...
24 | dataloader = ...
25 | opt = ...
26 | 
27 | for epoch in range(epochs):
28 | for batch, (x, y) in enumerate(dataloader):
29 | opt.zero_grad()
30 | 
31 | x1, x2 = transforms(x), transforms(x) # Augment
32 | e1, e2 = model.encode(x1), model.encode(x2) # Encode
33 | z1, z2 = model.project(e1), model.project(e2) # Project
34 | p1, p2 = model.predict(z1), model.predict(z2) # Predict
35 | 
36 | # Compute the symmetrized loss: each predictor output is compared
37 | # against the *other* view's (stop-gradient) projection, as in the paper.
37 | loss1 = negative_cosine_similarity(p1, z2)
38 | loss2 = negative_cosine_similarity(p2, z1)
39 | loss = loss1/2 + loss2/2
40 | loss.backward()
41 | opt.step()
42 | 
43 | # Save encoder weights for later
44 | torch.save(model.encoder.state_dict(), "pretrained.pt")
45 | ```
46 | 
47 | ### Use pretrained weights in a classifier
48 | 
49 | ```python
50 | from simsiam.models import ResNet
51 | 
52 | # just a wrapper around encoder + linear classifier networks
53 | model = ResNet(
54 | backbone="resnet50", # Same as during pretraining
55 | num_classes=10, # number of output neurons
56 | pretrained=False, # Whether to load pretrained imagenet weights
57 | freeze=True # Freeze the encoder weights (or not)
58 | )
59 | 
60 | # Load the pretrained weights from SimSiam
61 | model.encoder.load_state_dict(torch.load("pretrained.pt"))
62 | 
63 | model = model.to("cuda")
64 | model.train()
65 | 
66 | transforms = ...
67 | dataset = ...
68 | dataloader = ...
69 | opt = optim.SGD(model.parameters()) 70 | loss_func = nn.CrossEntropyLoss() 71 | 72 | # Train on your small labeled train set 73 | for epoch in range(epochs): 74 | for batch, (x, y) in enumerate(dataloader): 75 | opt.zero_grad() 76 | y_pred = model(x) 77 | loss = loss_func(y_pred, y) 78 | loss.backward() 79 | opt.step() 80 | 81 | ``` 82 | 83 | ### Install dependencies 84 | 85 | ```bash 86 | pip install -r requirements.txt 87 | 88 | ``` 89 | 90 | ### Train on STL-10 dataset 91 | 92 | Modify pretrain.yaml to your liking and run 93 | 94 | ```python 95 | python pretrain.py --cfg configs/pretrain.json 96 | 97 | ``` 98 | 99 | ### View logs in tensorboard 100 | 101 | ```python 102 | tensorboard --logdir=logs 103 | 104 | ``` 105 | -------------------------------------------------------------------------------- /assets/models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/simsiam-pytorch/2570656caa1101e2839abdcff94ed546c61dd7f3/assets/models.png -------------------------------------------------------------------------------- /configs/baseline.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "baseline", 4 | "backbone": "resnet50", 5 | "freeze": false, 6 | "weights_path": "" 7 | }, 8 | "data": { 9 | "path": "data/", 10 | "input_shape": [ 11 | 96, 12 | 96 13 | ], 14 | "num_classes": 10 15 | }, 16 | "train": { 17 | "epochs": 100, 18 | "batch_size": 128, 19 | "lr": 3E-4, 20 | "momentum": 0.9, 21 | "weight_decay": 0.0001, 22 | "log_interval": 100 23 | }, 24 | "device": "cuda" 25 | } -------------------------------------------------------------------------------- /configs/pretrain.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "pretrained", 4 | "backbone": "resnet50", 5 | "latent_dim": 2048, 6 | "proj_hidden_dim": 2048, 7 | "pred_hidden_dim": 512 8 | 
}, 9 | "data": { 10 | "path": "data/", 11 | "input_shape": [ 12 | 96, 13 | 96 14 | ], 15 | "num_classes": 10 16 | }, 17 | "train": { 18 | "epochs": 100, 19 | "batch_size": 128, 20 | "lr": 0.05, 21 | "momentum": 0.9, 22 | "weight_decay": 0.0001, 23 | "log_interval": 100 24 | }, 25 | "device": "cuda" 26 | } -------------------------------------------------------------------------------- /configs/train.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "fine-tuned", 4 | "backbone": "resnet50", 5 | "freeze": true, 6 | "weights_path": "models/pretrained.pt" 7 | }, 8 | "data": { 9 | "path": "data/", 10 | "input_shape": [ 11 | 96, 12 | 96 13 | ], 14 | "num_classes": 10 15 | }, 16 | "train": { 17 | "epochs": 90, 18 | "batch_size": 128, 19 | "lr": 3E-4, 20 | "momentum": 0.9, 21 | "weight_decay": 0.0, 22 | "log_interval": 100 23 | }, 24 | "device": "cuda" 25 | } -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from types import SimpleNamespace 5 | 6 | import torch 7 | import torchvision 8 | from tqdm.auto import tqdm 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from simsiam.models import SimSiam 12 | from simsiam.losses import negative_cosine_similarity 13 | from simsiam.transforms import load_transforms, augment_transforms 14 | 15 | 16 | def main(cfg: SimpleNamespace) -> None: 17 | 18 | model = SimSiam( 19 | backbone=cfg.model.backbone, 20 | latent_dim=cfg.model.latent_dim, 21 | proj_hidden_dim=cfg.model.proj_hidden_dim, 22 | pred_hidden_dim=cfg.model.pred_hidden_dim 23 | ) 24 | model = model.to(cfg.device) 25 | model.train() 26 | 27 | opt = torch.optim.SGD( 28 | params=model.parameters(), 29 | lr=cfg.train.lr, 30 | momentum=cfg.train.momentum, 31 | weight_decay=cfg.train.weight_decay 32 | ) 33 | 34 | dataset 
= torchvision.datasets.STL10( 35 | root=cfg.data.path, 36 | split="train", 37 | transform=load_transforms(input_shape=cfg.data.input_shape), 38 | download=True 39 | ) 40 | 41 | dataloader = torch.utils.data.DataLoader( 42 | dataset=dataset, 43 | batch_size=cfg.train.batch_size, 44 | shuffle=True, 45 | drop_last=True, 46 | pin_memory=True, 47 | num_workers=torch.multiprocessing.cpu_count() 48 | ) 49 | 50 | transforms = augment_transforms( 51 | input_shape=cfg.data.input_shape, 52 | device=cfg.device 53 | ) 54 | 55 | writer = SummaryWriter() 56 | 57 | n_iter = 0 58 | for epoch in range(cfg.train.epochs): 59 | 60 | pbar = tqdm(enumerate(dataloader), total=len(dataloader), position=0, leave=False) 61 | for batch, (x, y) in pbar: 62 | 63 | opt.zero_grad() 64 | 65 | x = x.to(cfg.device) 66 | 67 | # augment 68 | x1, x2 = transforms(x), transforms(x) 69 | 70 | # encode 71 | e1, e2 = model.encode(x1), model.encode(x2) 72 | 73 | # project 74 | z1, z2 = model.project(e1), model.project(e2) 75 | 76 | # predict 77 | p1, p2 = model.predict(z1), model.predict(z2) 78 | 79 | # compute loss 80 | loss1 = negative_cosine_similarity(p1, z1) 81 | loss2 = negative_cosine_similarity(p2, z2) 82 | loss = loss1/2 + loss2/2 83 | loss.backward() 84 | opt.step() 85 | 86 | pbar.set_description("Epoch {}, Loss: {:.4f}".format(epoch, float(loss))) 87 | 88 | if n_iter % cfg.train.log_interval == 0: 89 | writer.add_scalar(tag="loss", scalar_value=float(loss), global_step=n_iter) 90 | 91 | n_iter += 1 92 | 93 | # save checkpoint 94 | torch.save(model.encoder.state_dict(), os.path.join(writer.log_dir, cfg.model.name + ".pt")) 95 | 96 | 97 | if __name__ == "__main__": 98 | 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument("--cfg", type=str, required=True, help="Path to config json file") 101 | args = parser.parse_args() 102 | 103 | with open(args.cfg, "r") as f: 104 | cfg = json.loads(f.read(), object_hook=lambda d: SimpleNamespace(**d)) 105 | 106 | main(cfg) 107 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | Pillow
3 | torch
4 | torchvision
5 | kornia
--------------------------------------------------------------------------------
/simsiam/__init__.py:
--------------------------------------------------------------------------------
1 | from . import losses
2 | from . import models
3 | from . import transforms
4 | 
--------------------------------------------------------------------------------
/simsiam/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | 
5 | def negative_cosine_similarity(
6 | p: torch.Tensor,
7 | z: torch.Tensor
8 | ) -> torch.Tensor:
9 | """ D(p, z) = -cos(p, stopgrad(z)).mean(): negative cosine similarity with a stop-gradient on z (z is detached, so no gradient flows through the target branch). """
10 | return - F.cosine_similarity(p, z.detach(), dim=-1).mean()
11 | 
--------------------------------------------------------------------------------
/simsiam/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | 
5 | 
6 | class LinearClassifier(nn.Module):
7 | 
8 | def __init__(
9 | self,
10 | input_dim: int,
11 | num_classes: int,
12 | ) -> None:
13 | super().__init__()
14 | self.model = nn.Sequential(
15 | nn.Linear(input_dim, num_classes)
16 | )
17 | 
18 | def forward(self, x: torch.Tensor) -> torch.Tensor:
19 | return self.model(x)
20 | 
21 | 
22 | class ProjectionMLP(nn.Module):
23 | 
24 | def __init__(
25 | self,
26 | input_dim: int,
27 | hidden_dim: int,
28 | output_dim: int
29 | ) -> None:
30 | super().__init__()
31 | self.model = nn.Sequential(
32 | nn.Linear(input_dim, hidden_dim),
33 | nn.BatchNorm1d(hidden_dim),
34 | nn.ReLU(inplace=True),
35 | nn.Linear(hidden_dim, hidden_dim),
36 | nn.BatchNorm1d(hidden_dim),
37 | nn.ReLU(inplace=True),
38 | nn.Linear(hidden_dim,
output_dim), 39 | nn.BatchNorm1d(output_dim), 40 | ) 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | return self.model(x) 44 | 45 | 46 | class PredictorMLP(nn.Module): 47 | 48 | def __init__( 49 | self, 50 | input_dim: int, 51 | hidden_dim: int, 52 | output_dim: int 53 | ) -> None: 54 | super().__init__() 55 | self.model = nn.Sequential( 56 | nn.Linear(output_dim, hidden_dim), 57 | nn.BatchNorm1d(hidden_dim), 58 | nn.ReLU(inplace=True), 59 | nn.Linear(hidden_dim, output_dim) 60 | ) 61 | 62 | def forward(self, x: torch.Tensor) -> torch.Tensor: 63 | return self.model(x) 64 | 65 | 66 | class Encoder(nn.Module): 67 | 68 | def __init__( 69 | self, 70 | backbone: str, 71 | pretrained: bool 72 | ): 73 | super().__init__() 74 | resnet = getattr(torchvision.models, backbone)(pretrained=pretrained) 75 | self.emb_dim = resnet.fc.in_features 76 | self.model = nn.Sequential(*list(resnet.children())[:-1]) 77 | 78 | def forward(self, x: torch.Tensor) -> torch.Tensor: 79 | return self.model(x).squeeze() 80 | 81 | 82 | class SimSiam(nn.Module): 83 | 84 | def __init__( 85 | self, 86 | backbone: str, 87 | latent_dim: int, 88 | proj_hidden_dim: int, 89 | pred_hidden_dim: int, 90 | ) -> None: 91 | 92 | super().__init__() 93 | 94 | # Encoder network 95 | self.encoder = Encoder(backbone=backbone, pretrained=False) 96 | 97 | # Projection (mlp) network 98 | self.projection_mlp = ProjectionMLP( 99 | input_dim=self.encoder.emb_dim, 100 | hidden_dim=proj_hidden_dim, 101 | output_dim=latent_dim 102 | ) 103 | 104 | # Predictor network (h) 105 | self.predictor_mlp = PredictorMLP( 106 | input_dim=latent_dim, 107 | hidden_dim=pred_hidden_dim, 108 | output_dim=latent_dim 109 | ) 110 | 111 | def forward(self, x: torch.Tensor): 112 | return self.encode(x) 113 | 114 | def encode(self, x: torch.Tensor) -> torch.Tensor: 115 | return self.encoder(x) 116 | 117 | def project(self, e: torch.Tensor) -> torch.Tensor: 118 | return self.projection_mlp(e) 119 | 120 | def predict(self, z: 
torch.Tensor) -> torch.Tensor: 121 | return self.predictor_mlp(z) 122 | 123 | 124 | class ResNet(nn.Module): 125 | 126 | def __init__( 127 | self, 128 | backbone: str, 129 | num_classes: int, 130 | pretrained: bool, 131 | freeze: bool 132 | ) -> None: 133 | 134 | super().__init__() 135 | 136 | # Encoder network 137 | self.encoder = Encoder(backbone=backbone, pretrained=pretrained) 138 | 139 | if freeze: 140 | for param in self.encoder.parameters(): 141 | param.requres_grad = False 142 | 143 | # Linear classifier 144 | self.classifier = LinearClassifier(self.encoder.emb_dim, num_classes) 145 | 146 | def forward(self, x: torch.Tensor) -> torch.Tensor: 147 | e = self.encoder(x) 148 | return self.classifier(e) 149 | -------------------------------------------------------------------------------- /simsiam/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict 2 | 3 | import kornia 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.transforms as T 7 | from PIL import Image 8 | 9 | 10 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 11 | IMAGENET_STD = (0.229, 0.224, 0.225) 12 | 13 | 14 | class RandomGaussianBlur2D(kornia.augmentation.AugmentationBase2D): 15 | 16 | def __init__( 17 | self, 18 | kernel_size: Tuple[int, int], 19 | sigma: Tuple[float, float], 20 | border_type: str = "reflect", 21 | return_transform: bool = False, 22 | same_on_batch: bool = False, 23 | p: float = 0.1 24 | ) -> None: 25 | super(RandomGaussianBlur2D, self).__init__( 26 | p=p, 27 | return_transform=return_transform, 28 | same_on_batch=same_on_batch 29 | ) 30 | 31 | self.kernel_size = kernel_size 32 | self.sigma = sigma 33 | self.border_type = border_type 34 | 35 | def __repr__(self) -> str: 36 | return self.__class__.__name__ + f"({super().__repr__()})" 37 | 38 | def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]: 39 | return dict() 40 | 41 | def compute_transformation(self, input: 
torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor: 42 | return None 43 | 44 | def apply_transform(self, input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor: 45 | return kornia.filters.gaussian_blur2d( 46 | input=input, 47 | kernel_size=self.kernel_size, 48 | sigma=self.sigma, 49 | border_type=self.border_type 50 | ) 51 | 52 | 53 | def augment_transforms( 54 | input_shape: Tuple[int, int], 55 | device: str = "cuda" if torch.cuda.is_available() else "cpu" 56 | ) -> nn.Sequential: 57 | 58 | augs = nn.Sequential( 59 | kornia.augmentation.ColorJitter( 60 | brightness=0.4, 61 | contrast=0.4, 62 | saturation=0.4, 63 | hue=0.1, 64 | p=0.8 65 | ), 66 | kornia.augmentation.RandomGrayscale(p=0.2), 67 | RandomGaussianBlur2D( 68 | kernel_size=(3, 3), 69 | sigma=(0.1, 2.0), 70 | p=0.5 71 | ), 72 | kornia.augmentation.RandomResizedCrop( 73 | size=input_shape, 74 | scale=(0.2, 1.0), 75 | ratio=(0.75, 1.33), 76 | interpolation="bilinear", 77 | p=1.0 78 | ), 79 | kornia.augmentation.RandomHorizontalFlip(p=0.5), 80 | kornia.augmentation.Normalize( 81 | mean=torch.tensor(IMAGENET_MEAN), 82 | std=torch.tensor(IMAGENET_STD) 83 | ) 84 | ) 85 | augs = augs.to(device) 86 | return augs 87 | 88 | 89 | def load_transforms(input_shape: Tuple[int, int]) -> T.Compose: 90 | return T.Compose([ 91 | T.Resize(size=input_shape, interpolation=Image.BILINEAR), 92 | T.ToTensor(), 93 | ]) 94 | 95 | 96 | def test_transforms(input_shape: Tuple[int, int]) -> T.Compose: 97 | return T.Compose([ 98 | T.Resize(size=input_shape, interpolation=Image.BILINEAR), 99 | T.ToTensor(), 100 | T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 101 | ]) 102 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from types import SimpleNamespace 5 | 6 | import torch 7 | import torchvision 8 | from tqdm.auto import tqdm 9 | 
from torch.utils.tensorboard import SummaryWriter 10 | 11 | from simsiam.models import ResNet, LinearClassifier 12 | from simsiam.transforms import load_transforms, augment_transforms 13 | 14 | 15 | def main(cfg: SimpleNamespace) -> None: 16 | 17 | model = ResNet( 18 | backbone=cfg.model.backbone, 19 | num_classes=cfg.data.num_classes, 20 | pretrained=False, 21 | freeze=cfg.model.freeze 22 | ) 23 | 24 | if cfg.model.weights_path: 25 | model.encoder.load_state_dict(torch.load(cfg.model.weights_path)) 26 | 27 | model = model.to(cfg.device) 28 | 29 | opt = torch.optim.SGD( 30 | params=model.parameters(), 31 | lr=cfg.train.lr, 32 | momentum=cfg.train.momentum, 33 | weight_decay=cfg.train.weight_decay 34 | ) 35 | loss_func = torch.nn.CrossEntropyLoss() 36 | 37 | dataset = torchvision.datasets.STL10( 38 | root=cfg.data.path, 39 | split="train", 40 | transform=load_transforms(input_shape=cfg.data.input_shape), 41 | download=True 42 | ) 43 | 44 | dataloader = torch.utils.data.DataLoader( 45 | dataset=dataset, 46 | batch_size=cfg.train.batch_size, 47 | shuffle=True, 48 | drop_last=True, 49 | pin_memory=True, 50 | num_workers=torch.multiprocessing.cpu_count() 51 | ) 52 | 53 | transforms = augment_transforms( 54 | input_shape=cfg.data.input_shape, 55 | device=cfg.device 56 | ) 57 | 58 | writer = SummaryWriter() 59 | 60 | n_iter = 0 61 | for epoch in range(cfg.train.epochs): 62 | 63 | pbar = tqdm(enumerate(dataloader), total=len(dataloader)) 64 | for batch, (x, y) in pbar: 65 | 66 | opt.zero_grad() 67 | 68 | x, y = x.to(cfg.device), y.to(cfg.device) 69 | x = transforms(x) 70 | y_pred = model(x) 71 | loss = loss_func(y_pred, y) 72 | loss.backward() 73 | opt.step() 74 | 75 | pbar.set_description("Epoch {}, Loss: {:.4f}".format(epoch, float(loss))) 76 | 77 | if n_iter % cfg.train.log_interval == 0: 78 | writer.add_scalar(tag="loss", scalar_value=float(loss), global_step=n_iter) 79 | 80 | n_iter += 1 81 | 82 | # save checkpoint 83 | torch.save(model.state_dict(), 
os.path.join(writer.log_dir, cfg.model.name + ".pt")) 84 | 85 | 86 | if __name__ == "__main__": 87 | 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument("--cfg", type=str, required=True, help="Path to config json file") 90 | args = parser.parse_args() 91 | 92 | with open(args.cfg, "r") as f: 93 | cfg = json.loads(f.read(), object_hook=lambda d: SimpleNamespace(**d)) 94 | 95 | main(cfg) 96 | --------------------------------------------------------------------------------