├── .gitignore ├── LICENSE ├── README.md ├── augmentation.py ├── configs.py ├── critic.py ├── dataset.py ├── evaluate ├── __init__.py ├── checkpoint.py └── lbfgs.py ├── gradient_linear_clf.py ├── lbfgs_linear_clf.py ├── models ├── __init__.py └── resnet.py ├── requirements.txt ├── scheduler.py └── simclr.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | checkpoint 3 | .idea 4 | clan 5 | Pipfile 6 | dataset-paths.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Adam Foster 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reproducing SimCLR in PyTorch 2 | 3 | ## Introduction 4 | This is an *unofficial* [PyTorch](https://github.com/pytorch/pytorch) implementation of the recent 5 | paper ['A Simple Framework for Contrastive Learning of Visual 6 | Representations'](https://arxiv.org/pdf/2002.05709.pdf). The arXiv version of this paper can be cited as 7 | ``` 8 | @article{chen2020simple, 9 | title={A simple framework for contrastive learning of visual representations}, 10 | author={Chen, Ting and Kornblith, Simon and Norouzi, Mohammad and Hinton, Geoffrey}, 11 | journal={arXiv preprint arXiv:2002.05709}, 12 | year={2020} 13 | } 14 | ``` 15 | The focus of this repository is to accurately reproduce the results in the paper using PyTorch. We use the original 16 | paper and the official [tensorflow repo](https://github.com/google-research/simclr) as our sources. 17 | 18 | 19 | ## SimCLR algorithm 20 | ### Data 21 | For comparison with the original paper, we use the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) and 22 | [ILSVRC2012](http://image-net.org/challenges/LSVRC/2012/) datasets. This PyTorch version also supports 23 | [CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html) and [STL-10](https://cs.stanford.edu/~acoates/stl10/). 24 | 25 | ### Augmentation 26 | The following augmentations are used on the training set 27 | - Random crop and resize. We use `RandomResizedCrop` in PyTorch with cubic interpolation for the resizing. 28 | CIFAR images are resized to 32×32, STL-10 images to 96×96 and ILSVRC2012 images to 224×224. 
29 | - Random horizontal flip 30 | - Colour distortion. We use the code provided in the paper (Appendix A) 31 | - Gaussian blur is not yet included 32 | 33 | ### Encoder 34 | We use a ResNet50 as our base architecture. We use the ResNet50 implementation included in `torchvision` with the 35 | following changes: 36 | - Stem adapted to the dataset, for details see `models/resnet.py`. We adapt the stem for CIFAR in the same way as 37 | the original paper: replacing the first 7×7 Conv of stride 2 with a 3×3 Conv of stride 1, and also removing the 38 | first max pooling operation. For STL, we use the 3×3 Conv of stride 1 but include a max pool. 39 | - We do not make special adjustments to sync the batch norm means and variances across GPU nodes. 40 | - Remove the final fully connected layer, giving a representation of dimension 2048. 41 | 42 | ### Projection head 43 | The projection head consists of the following: 44 | - MLP projection with one hidden layer with dimensions 2048 -> 2048 -> 128 45 | - Following the tensorflow code, we also include batch norm in the projection head 46 | 47 | ### Loss 48 | We use the NT-Xent loss of the original paper. Specifically, we calculate the `CosineSimilarity` (using PyTorch's 49 | implementation) between each of the `2N` projected representations `z`. We rescale these similarities by temperature. 50 | We set the diagonal similarities to `-inf` and treat the single positive example as the correct class in a 51 | `2N`-way classification task. The scores are fed directly into `CrossEntropyLoss`. 52 | 53 | ### Optimizer 54 | We use the LARS optimizer with `trust_coef=1e-3` to match the tensorflow code. We set the weight decay to `1e-6`. 55 | The 10-epoch linear ramp and cosine annealing of the original paper are implemented and can be activated using 56 | `--cosine-anneal`; otherwise a constant learning rate is used. 57 | 58 | ### Evaluation 59 | On CIFAR-10, we fit the downstream classifier using L-BFGS with no augmentation on the training set. This is the 60 | approach used in the original paper for transfer learning (and is substantially faster for small datasets). 61 | For ImageNet, we use SGD with the same random resized crop and random flip as for the original training, but no 62 | colour distortion or other augmentations. This is as in the original paper. 63 | 64 | 65 | 66 | ## Running the code 67 | ### Requirements 68 | See `requirements.txt`. Note that we require the [torchlars](https://github.com/kakaobrain/torchlars) package. 
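The dependencies can typically be installed with pip (note that `torchlars` builds a CUDA extension, so a working CUDA toolchain is generally needed):
```
$ pip install -r requirements.txt
```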
69 | 70 | ### Dataset locations 71 | The dataset file locations should be specified in a JSON file named `dataset-paths.json`, of the following form: 72 | ``` 73 | dataset-paths.json 74 | { 75 | "cifar10": "/data/cifar10/", 76 | "cifar100": "/data/cifar100/", 77 | "stl10": "/data/stl10/", 78 | "imagenet": "/data/imagenet/2012/" 79 | } 80 | ``` 81 | 82 | ### Running with CIFAR-10 83 | Use the following command to train an encoder from scratch on CIFAR-10: 84 | ``` 85 | $ python3 simclr.py --num-epochs 1000 --cosine-anneal --filename output.pth --base-lr 1.5 86 | ``` 87 | To evaluate the trained encoder using L-BFGS across a range of regularization parameters: 88 | ``` 89 | $ python3 lbfgs_linear_clf.py --load-from output.pth 90 | ``` 91 | 92 | ### Running with ImageNet 93 | Use the following command to train an encoder from scratch on ILSVRC2012: 94 | ``` 95 | $ python3 simclr.py --num-epochs 1000 --cosine-anneal --filename output.pth --test-freq 0 --num-workers 32 --dataset imagenet 96 | ``` 97 | To evaluate the trained encoder, use 98 | ``` 99 | $ python3 gradient_linear_clf.py --load-from output.pth --nesterov --num-workers 32 100 | ``` 101 | 102 | 103 | ## Outstanding differences with the original paper 104 | - We do not synchronize the batch norm between multiple GPUs. (To use PyTorch's `SyncBatchNorm`, we would need to 105 | change from using `DataParallel` to `DistributedDataParallel`.) 106 | - We do not use Gaussian blur for any datasets, including ILSVRC2012. 107 | - We are not aware of any other discrepancies with the original work, but any correction is more than welcome and 108 | should be suggested by opening an Issue in this repo. 109 | 110 | 111 | ## Reproduction results 112 | ### CIFAR-10 with ResNet50 (1000 epochs) 113 | Method | Test accuracy 114 | --- | --- 115 | SimCLR quoted | 94.0% 116 | SimCLR reproduced (this repo) | 93.5% 117 | 118 | 119 | ## Acknowledgements 120 | The basis for this repository was [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar). 121 | We make use of [torchlars](https://github.com/kakaobrain/torchlars). 122 | -------------------------------------------------------------------------------- /augmentation.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageFilter 2 | from torchvision import transforms 3 | 4 | 5 | def ColourDistortion(s=1.0): 6 | # s is the strength of color distortion. 
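# Per Appendix A of the SimCLR paper: brightness, contrast and saturation are jittered
# with strength 0.8*s and hue with strength 0.2*s; the jitter is applied with
# probability 0.8 and random grayscale with probability 0.2.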
7 | color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) 8 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) 9 | rnd_gray = transforms.RandomGrayscale(p=0.2) 10 | color_distort = transforms.Compose([rnd_color_jitter, rnd_gray]) 11 | return color_distort 12 | 13 | 14 | def BlurOrSharpen(radius=2.): 15 | blur = GaussianBlur(radius=radius) 16 | full_transform = transforms.RandomApply([blur], p=.5) 17 | return full_transform 18 | 19 | 20 | class ImageFilterTransform(object): 21 | 22 | def __init__(self): 23 | raise NotImplementedError 24 | 25 | def __call__(self, img): 26 | return img.filter(self.filter) 27 | 28 | 29 | class GaussianBlur(ImageFilterTransform): 30 | 31 | def __init__(self, radius=2.): 32 | self.filter = ImageFilter.GaussianBlur(radius=radius) 33 | 34 | 35 | class Sharpen(ImageFilterTransform): 36 | 37 | def __init__(self): 38 | self.filter = ImageFilter.SHARPEN 39 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | from augmentation import ColourDistortion 7 | from dataset import * 8 | from models import * 9 | 10 | 11 | def get_datasets(dataset, augment_clf_train=False, add_indices_to_data=False, num_positive=None): 12 | 13 | CACHED_MEAN_STD = { 14 | 'cifar10': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 15 | 'cifar100': ((0.5071, 0.4865, 0.4409), (0.2009, 0.1984, 0.2023)), 16 | 'stl10': ((0.4409, 0.4279, 0.3868), (0.2309, 0.2262, 0.2237)), 17 | 'imagenet': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 18 | } 19 | 20 | PATHS = { 21 | 'cifar10': '/data/cifar10/', 22 | 'cifar100': '/data/cifar100/', 23 | 'stl10': '/data/stl10/', 24 | 'imagenet': '/data/imagenet/2012/' 25 | } 26 | try: 27 | with open('dataset-paths.json', 'r') as f: 28 | local_paths = json.load(f) 29 | PATHS.update(local_paths) 30 | except FileNotFoundError: 31 | pass 32 | root = PATHS[dataset] 33 | 34 | # Data 35 | if dataset == 'stl10': 36 | img_size = 96 37 | elif dataset == 'imagenet': 38 | img_size = 224 39 | else: 40 | img_size = 32 41 | 42 | transform_train = transforms.Compose([ 43 | transforms.RandomResizedCrop(img_size, interpolation=Image.BICUBIC), 44 | transforms.RandomHorizontalFlip(), 45 | ColourDistortion(s=0.5), 46 | transforms.ToTensor(), 47 | transforms.Normalize(*CACHED_MEAN_STD[dataset]), 48 | ]) 49 | 50 | if dataset == 'imagenet': 51 | transform_test = transforms.Compose([ 52 | transforms.Resize(256), 53 | transforms.CenterCrop(224), 54 | transforms.ToTensor(), 55 | transforms.Normalize(*CACHED_MEAN_STD[dataset]), 56 | ]) 57 | else: 58 | transform_test = transforms.Compose([ 59 | transforms.ToTensor(), 60 | transforms.Normalize(*CACHED_MEAN_STD[dataset]), 61 | ]) 62 | 63 | if augment_clf_train: 64 | transform_clftrain = transforms.Compose([ 65 | transforms.RandomResizedCrop(img_size, interpolation=Image.BICUBIC), 66 | transforms.RandomHorizontalFlip(), 67 | transforms.ToTensor(), 68 | transforms.Normalize(*CACHED_MEAN_STD[dataset]), 69 | ]) 70 | else: 71 | transform_clftrain = transform_test 72 | 73 | if dataset == 'cifar100': 74 | if add_indices_to_data: 75 | dset = add_indices(torchvision.datasets.CIFAR100) 76 | else: 77 | dset = torchvision.datasets.CIFAR100 78 | if num_positive is None: 79 | trainset = CIFAR100Biaugment(root=root, train=True, download=True, transform=transform_train) 80 | else: 81 | trainset = 
CIFAR100Multiaugment(root=root, train=True, download=True, transform=transform_train, 82 | n_augmentations=num_positive) 83 | testset = dset(root=root, train=False, download=True, transform=transform_test) 84 | clftrainset = dset(root=root, train=True, download=True, transform=transform_clftrain) 85 | num_classes = 100 86 | stem = StemCIFAR 87 | elif dataset == 'cifar10': 88 | if add_indices_to_data: 89 | dset = add_indices(torchvision.datasets.CIFAR10) 90 | else: 91 | dset = torchvision.datasets.CIFAR10 92 | if num_positive is None: 93 | trainset = CIFAR10Biaugment(root=root, train=True, download=True, transform=transform_train) 94 | else: 95 | trainset = CIFAR10Multiaugment(root=root, train=True, download=True, transform=transform_train, 96 | n_augmentations=num_positive) 97 | testset = dset(root=root, train=False, download=True, transform=transform_test) 98 | clftrainset = dset(root=root, train=True, download=True, transform=transform_clftrain) 99 | num_classes = 10 100 | stem = StemCIFAR 101 | elif dataset == 'stl10': 102 | if add_indices_to_data: 103 | dset = add_indices(torchvision.datasets.STL10) 104 | else: 105 | dset = torchvision.datasets.STL10 106 | if num_positive is None: 107 | trainset = STL10Biaugment(root=root, split='unlabeled', download=True, transform=transform_train) 108 | else: 109 | raise NotImplementedError 110 | testset = dset(root=root, split='train', download=True, transform=transform_test) 111 | clftrainset = dset(root=root, split='test', download=True, transform=transform_clftrain) 112 | num_classes = 10 113 | stem = StemSTL 114 | elif dataset == 'imagenet': 115 | if add_indices_to_data: 116 | dset = add_indices(torchvision.datasets.ImageNet) 117 | else: 118 | dset = torchvision.datasets.ImageNet 119 | if num_positive is None: 120 | trainset = ImageNetBiaugment(root=root, split='train', transform=transform_train) 121 | else: 122 | raise NotImplementedError 123 | testset = dset(root=root, split='val', transform=transform_test) 124 | clftrainset = dset(root=root, split='train', transform=transform_clftrain) 125 | num_classes = len(testset.classes) 126 | stem = StemImageNet 127 | else: 128 | raise ValueError("Bad dataset value: {}".format(dataset)) 129 | 130 | return trainset, testset, clftrainset, num_classes, stem 131 | -------------------------------------------------------------------------------- /critic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LinearCritic(nn.Module): 6 | 7 | def __init__(self, latent_dim, temperature=1.): 8 | super(LinearCritic, self).__init__() 9 | self.temperature = temperature 10 | self.projection_dim = 128 11 | self.w1 = nn.Linear(latent_dim, latent_dim, bias=False) 12 | self.bn1 = nn.BatchNorm1d(latent_dim) 13 | self.relu = nn.ReLU() 14 | self.w2 = nn.Linear(latent_dim, self.projection_dim, bias=False) 15 | self.bn2 = nn.BatchNorm1d(self.projection_dim, affine=False) 16 | self.cossim = nn.CosineSimilarity(dim=-1) 17 | 18 | def project(self, h): 19 | return self.bn2(self.w2(self.relu(self.bn1(self.w1(h))))) 20 | 21 | def forward(self, h1, h2): 22 | z1, z2 = self.project(h1), self.project(h2) 23 | sim11 = self.cossim(z1.unsqueeze(-2), z1.unsqueeze(-3)) / self.temperature 24 | sim22 = self.cossim(z2.unsqueeze(-2), z2.unsqueeze(-3)) / self.temperature 25 | sim12 = self.cossim(z1.unsqueeze(-2), z2.unsqueeze(-3)) / self.temperature 26 | d = sim12.shape[-1] 27 | sim11[..., range(d), range(d)] = float('-inf') 28 | sim22[..., range(d), 
range(d)] = float('-inf') 29 | raw_scores1 = torch.cat([sim12, sim11], dim=-1) 30 | raw_scores2 = torch.cat([sim22, sim12.transpose(-1, -2)], dim=-1) 31 | raw_scores = torch.cat([raw_scores1, raw_scores2], dim=-2) 32 | targets = torch.arange(2 * d, dtype=torch.long, device=raw_scores.device) 33 | return raw_scores, targets 34 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | from PIL import Image 5 | 6 | 7 | class CIFAR10Biaugment(torchvision.datasets.CIFAR10): 8 | 9 | def __getitem__(self, index): 10 | """ 11 | Args: 12 | index (int): Index 13 | 14 | Returns: 15 | tuple: (image, target) where target is index of the target class. 16 | """ 17 | img, target = self.data[index], self.targets[index] 18 | 19 | # doing this so that it is consistent with all other datasets 20 | # to return a PIL Image 21 | pil_img = Image.fromarray(img) 22 | 23 | if self.transform is not None: 24 | img = self.transform(pil_img) 25 | img2 = self.transform(pil_img) 26 | else: 27 | img2 = img = pil_img 28 | 29 | if self.target_transform is not None: 30 | target = self.target_transform(target) 31 | 32 | return (img, img2), target, index 33 | 34 | 35 | class CIFAR100Biaugment(CIFAR10Biaugment): 36 | """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. 37 | 38 | This is a subclass of the `CIFAR10Biaugment` Dataset. 39 | """ 40 | base_folder = 'cifar-100-python' 41 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 42 | filename = "cifar-100-python.tar.gz" 43 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 44 | train_list = [ 45 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 46 | ] 47 | 48 | test_list = [ 49 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 50 | ] 51 | meta = { 52 | 'filename': 'meta', 53 | 'key': 'fine_label_names', 54 | 'md5': '7973b15100ade9c7d40fb424638fde48', 55 | } 56 | 57 | 58 | class STL10Biaugment(torchvision.datasets.STL10): 59 | 60 | def __getitem__(self, index): 61 | """ 62 | Args: 63 | index (int): Index 64 | 65 | Returns: 66 | tuple: (image, target) where target is index of the target class. 67 | """ 68 | if self.labels is not None: 69 | img, target = self.data[index], int(self.labels[index]) 70 | else: 71 | img, target = self.data[index], None 72 | 73 | # doing this so that it is consistent with all other datasets 74 | # to return a PIL Image 75 | pil_img = Image.fromarray(np.transpose(img, (1, 2, 0))) 76 | 77 | if self.transform is not None: 78 | img = self.transform(pil_img) 79 | img2 = self.transform(pil_img) 80 | else: 81 | img2 = img = pil_img 82 | 83 | if self.target_transform is not None: 84 | target = self.target_transform(target) 85 | 86 | return (img, img2), target, index 87 | 88 | 89 | class CIFAR10Multiaugment(torchvision.datasets.CIFAR10): 90 | 91 | def __init__(self, *args, n_augmentations=8, **kwargs): 92 | super(CIFAR10Multiaugment, self).__init__(*args, **kwargs) 93 | self.n_augmentations = n_augmentations 94 | assert self.transform is not None  # __getitem__ applies the transform repeatedly 95 | 96 | def __getitem__(self, index): 97 | """ 98 | Args: 99 | index (int): Index 100 | 101 | Returns: 102 | tuple: (image, target) where target is index of the target class. 
103 | """ 104 | img, target = self.data[index], self.targets[index] 105 | 106 | # doing this so that it is consistent with all other datasets 107 | # to return a PIL Image 108 | pil_img = Image.fromarray(img) 109 | 110 | imgs = [self.transform(pil_img) for _ in range(self.n_augmentations)] 111 | 112 | if self.target_transform is not None: 113 | target = self.target_transform(target) 114 | 115 | return torch.stack(imgs, dim=0), target, index 116 | 117 | 118 | class CIFAR100Multiaugment(CIFAR10Multiaugment): 119 | """`CIFAR100 `_ Dataset. 120 | 121 | This is a subclass of the `CIFAR10Biaugment` Dataset. 122 | """ 123 | base_folder = 'cifar-100-python' 124 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 125 | filename = "cifar-100-python.tar.gz" 126 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 127 | train_list = [ 128 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 129 | ] 130 | 131 | test_list = [ 132 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 133 | ] 134 | meta = { 135 | 'filename': 'meta', 136 | 'key': 'fine_label_names', 137 | 'md5': '7973b15100ade9c7d40fb424638fde48', 138 | } 139 | 140 | 141 | class ImageNetBiaugment(torchvision.datasets.ImageNet): 142 | 143 | def __getitem__(self, index): 144 | """ 145 | Args: 146 | index (int): Index 147 | Returns: 148 | tuple: (sample, target) where target is class_index of the target class. 149 | """ 150 | path, target = self.samples[index] 151 | sample = self.loader(path) 152 | if self.transform is not None: 153 | img = self.transform(sample) 154 | img2 = self.transform(sample) 155 | else: 156 | img2 = img = sample 157 | if self.target_transform is not None: 158 | target = self.target_transform(target) 159 | 160 | return (img, img2), target, index 161 | 162 | 163 | def add_indices(dataset_cls): 164 | class NewClass(dataset_cls): 165 | def __getitem__(self, item): 166 | output = super(NewClass, self).__getitem__(item) 167 | return (*output, item) 168 | 169 | return NewClass 170 | -------------------------------------------------------------------------------- /evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint import * 2 | from .lbfgs import * 3 | -------------------------------------------------------------------------------- /evaluate/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | 6 | def save_checkpoint(net, clf, critic, epoch, args, script_name): 7 | # Save checkpoint. 
8 | print('Saving..') 9 | state = { 10 | 'net': net.state_dict(), 11 | 'clf': clf.state_dict() if clf is not None else None,  # clf is None when --test-freq is 0 12 | 'critic': critic.state_dict(), 13 | 'epoch': epoch, 14 | 'args': vars(args), 15 | 'script': script_name 16 | } 17 | if not os.path.isdir('checkpoint'): 18 | os.mkdir('checkpoint') 19 | destination = os.path.join('./checkpoint', args.filename) 20 | torch.save(state, destination) -------------------------------------------------------------------------------- /evaluate/lbfgs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from tqdm import tqdm 5 | 6 | 7 | def encode_train_set(clftrainloader, device, net): 8 | net.eval() 9 | 10 | store = [] 11 | with torch.no_grad(): 12 | t = tqdm(enumerate(clftrainloader), desc='Encoded: **/** ', total=len(clftrainloader), 13 | bar_format='{desc}{bar}{r_bar}') 14 | for batch_idx, (inputs, targets) in t: 15 | inputs, targets = inputs.to(device), targets.to(device) 16 | representation = net(inputs) 17 | store.append((representation, targets)) 18 | 19 | t.set_description('Encoded %d/%d' % (batch_idx + 1, len(clftrainloader))) 20 | 21 | X, y = zip(*store) 22 | X, y = torch.cat(X, dim=0), torch.cat(y, dim=0) 23 | return X, y 24 | 25 | 26 | def train_clf(X, y, representation_dim, num_classes, device, reg_weight=1e-3): 27 | print('\nL2 Regularization weight: %g' % reg_weight) 28 | 29 | criterion = nn.CrossEntropyLoss() 30 | n_lbfgs_steps = 500 31 | 32 | # Should be reset after each epoch for a completely independent evaluation 33 | clf = nn.Linear(representation_dim, num_classes).to(device) 34 | clf_optimizer = optim.LBFGS(clf.parameters()) 35 | clf.train() 36 | 37 | t = tqdm(range(n_lbfgs_steps), desc='Loss: **** | Train Acc: ****% ', bar_format='{desc}{bar}{r_bar}') 38 | for _ in t: 39 | def closure(): 40 | clf_optimizer.zero_grad() 41 | raw_scores = clf(X) 42 | loss = criterion(raw_scores, y) 43 | loss += reg_weight * clf.weight.pow(2).sum() 44 | loss.backward() 45 | 46 | _, predicted = raw_scores.max(1) 47 | correct = predicted.eq(y).sum().item() 48 | 49 | t.set_description('Loss: %.3f | Train Acc: %.3f%% ' % (loss, 100. * correct / y.shape[0])) 50 | 51 | return loss 52 | 53 | clf_optimizer.step(closure) 54 | 55 | return clf 56 | 57 | 58 | def test(testloader, device, net, clf): 59 | criterion = nn.CrossEntropyLoss() 60 | net.eval() 61 | clf.eval() 62 | test_clf_loss = 0 63 | correct = 0 64 | total = 0 65 | with torch.no_grad(): 66 | t = tqdm(enumerate(testloader), total=len(testloader), desc='Loss: **** | Test Acc: ****% ', 67 | bar_format='{desc}{bar}{r_bar}') 68 | for batch_idx, (inputs, targets) in t: 69 | inputs, targets = inputs.to(device), targets.to(device) 70 | representation = net(inputs) 71 | # test_repr_loss = criterion(representation, targets) 72 | raw_scores = clf(representation) 73 | clf_loss = criterion(raw_scores, targets) 74 | 75 | test_clf_loss += clf_loss.item() 76 | _, predicted = raw_scores.max(1) 77 | total += targets.size(0) 78 | correct += predicted.eq(targets).sum().item() 79 | 80 | t.set_description('Loss: %.3f | Test Acc: %.3f%% ' % (test_clf_loss / (batch_idx + 1), 100. * correct / total)) 81 | 82 | acc = 100. 
* correct / total 83 | return acc 84 | -------------------------------------------------------------------------------- /gradient_linear_clf.py: -------------------------------------------------------------------------------- 1 | '''This script trains the downstream classifier using gradients (for large datasets).''' 2 | import argparse 3 | import os 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from tqdm import tqdm 10 | 11 | from configs import get_datasets 12 | from evaluate.lbfgs import test 13 | from models import * 14 | 15 | parser = argparse.ArgumentParser(description='Train downstream classifier with gradients.') 16 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 17 | parser.add_argument("--momentum", default=0.9, type=float, help='SGD momentum') 18 | parser.add_argument("--batch-size", type=int, default=512, help='Training batch size') 19 | parser.add_argument("--num-epochs", type=int, default=90, help='Number of training epochs') 20 | parser.add_argument("--num-workers", type=int, default=2, help='Number of threads for data loaders') 21 | parser.add_argument("--weight-decay", type=float, default=1e-6, help='Weight decay on the linear classifier') 22 | parser.add_argument("--nesterov", action="store_true", help="Turn on Nesterov style momentum") 23 | parser.add_argument("--load-from", type=str, default='ckpt.pth', help='File to load from') 24 | args = parser.parse_args() 25 | 26 | # Load checkpoint. 27 | print('==> Loading settings from checkpoint..') 28 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 29 | resume_from = os.path.join('./checkpoint', args.load_from) 30 | checkpoint = torch.load(resume_from) 31 | args.dataset = checkpoint['args']['dataset'] 32 | args.arch = checkpoint['args']['arch'] 33 | 34 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 35 | best_acc = 0 36 | 37 | # Data 38 | print('==> Preparing data..') 39 | _, testset, clftrainset, num_classes, stem = get_datasets(args.dataset, augment_clf_train=True) 40 | 41 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, 42 | num_workers=args.num_workers, pin_memory=True) 43 | clftrainloader = torch.utils.data.DataLoader(clftrainset, batch_size=args.batch_size, shuffle=True, 44 | num_workers=args.num_workers, pin_memory=True) 45 | 46 | # Model 47 | print('==> Building model..') 48 | ############################################################## 49 | # Encoder 50 | ############################################################## 51 | if args.arch == 'resnet18': 52 | net = ResNet18(stem=stem) 53 | elif args.arch == 'resnet34': 54 | net = ResNet34(stem=stem) 55 | elif args.arch == 'resnet50': 56 | net = ResNet50(stem=stem) 57 | else: 58 | raise ValueError("Bad architecture specification") 59 | net = net.to(device) 60 | 61 | ############################################################## 62 | # Classifier 63 | ############################################################## 64 | clf = nn.Linear(net.representation_dim, num_classes).to(device) 65 | 66 | if device == 'cuda': 67 | repr_dim = net.representation_dim 68 | net = torch.nn.DataParallel(net) 69 | net.representation_dim = repr_dim 70 | cudnn.benchmark = True 71 | 72 | print('==> Loading encoder from checkpoint..') 73 | net.load_state_dict(checkpoint['net']) 74 | 75 | criterion = nn.CrossEntropyLoss() 76 | clf_optimizer = optim.SGD(clf.parameters(), lr=args.lr, momentum=args.momentum, 
nesterov=args.nesterov, 77 | weight_decay=args.weight_decay) 78 | 79 | 80 | def train_clf(epoch): 81 | print('\nEpoch %d' % epoch) 82 | net.eval() 83 | clf.train() 84 | train_loss = 0 85 | t = tqdm(enumerate(clftrainloader), desc='Loss: **** ', total=len(clftrainloader), bar_format='{desc}{bar}{r_bar}') 86 | for batch_idx, (inputs, targets) in t: 87 | clf_optimizer.zero_grad() 88 | inputs, targets = inputs.to(device), targets.to(device) 89 | representation = net(inputs).detach() 90 | predictions = clf(representation) 91 | loss = criterion(predictions, targets) 92 | loss.backward() 93 | clf_optimizer.step() 94 | 95 | train_loss += loss.item() 96 | 97 | t.set_description('Loss: %.3f ' % (train_loss / (batch_idx + 1))) 98 | 99 | 100 | for epoch in range(args.num_epochs): 101 | train_clf(epoch) 102 | acc = test(testloader, device, net, clf) 103 | if acc > best_acc: 104 | best_acc = acc 105 | print("Best test accuracy", best_acc, "%") 106 | -------------------------------------------------------------------------------- /lbfgs_linear_clf.py: -------------------------------------------------------------------------------- 1 | '''This script tunes the L2 reg weight of the final classifier.''' 2 | import argparse 3 | import os 4 | import math 5 | 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | 9 | from configs import get_datasets 10 | from evaluate import encode_train_set, train_clf, test 11 | from models import * 12 | 13 | parser = argparse.ArgumentParser(description='Tune regularization coefficient of downstream classifier.') 14 | parser.add_argument("--num-workers", type=int, default=2, help='Number of threads for data loaders') 15 | parser.add_argument("--load-from", type=str, default='ckpt.pth', help='File to load from') 16 | args = parser.parse_args() 17 | 18 | # Load checkpoint. 19 | print('==> Loading settings from checkpoint..') 20 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
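# The checkpoint was written by evaluate.save_checkpoint, so the stored args tell us
# which dataset and encoder architecture to rebuild before loading the weights.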
21 | resume_from = os.path.join('./checkpoint', args.load_from) 22 | checkpoint = torch.load(resume_from) 23 | args.dataset = checkpoint['args']['dataset'] 24 | args.arch = checkpoint['args']['arch'] 25 | 26 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 27 | 28 | # Data 29 | print('==> Preparing data..') 30 | _, testset, clftrainset, num_classes, stem = get_datasets(args.dataset) 31 | 32 | testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=args.num_workers, 33 | pin_memory=True) 34 | clftrainloader = torch.utils.data.DataLoader(clftrainset, batch_size=1000, shuffle=False, num_workers=args.num_workers, 35 | pin_memory=True) 36 | 37 | # Model 38 | print('==> Building model..') 39 | ############################################################## 40 | # Encoder 41 | ############################################################## 42 | if args.arch == 'resnet18': 43 | net = ResNet18(stem=stem) 44 | elif args.arch == 'resnet34': 45 | net = ResNet34(stem=stem) 46 | elif args.arch == 'resnet50': 47 | net = ResNet50(stem=stem) 48 | else: 49 | raise ValueError("Bad architecture specification") 50 | net = net.to(device) 51 | 52 | if device == 'cuda': 53 | repr_dim = net.representation_dim 54 | net = torch.nn.DataParallel(net) 55 | net.representation_dim = repr_dim 56 | cudnn.benchmark = True 57 | 58 | print('==> Loading encoder from checkpoint..') 59 | net.load_state_dict(checkpoint['net']) 60 | 61 | 62 | best_acc = 0 63 | X, y = encode_train_set(clftrainloader, device, net) 64 | for reg_weight in torch.exp(math.log(10) * torch.linspace(-7, -3, 16, dtype=torch.float, device=device)): 65 | clf = train_clf(X, y, net.representation_dim, num_classes, device, reg_weight=reg_weight) 66 | acc = test(testloader, device, net, clf) 67 | if acc > best_acc: 68 | best_acc = acc 69 | print("Best test accuracy", best_acc, "%") 70 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 13 | 14 | 15 | class StemCIFAR(nn.Module): 16 | def __init__(self): 17 | super(StemCIFAR, self).__init__() 18 | self.in_planes = 64 19 | self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(self.in_planes) 21 | 22 | def forward(self, inputs): 23 | return F.relu(self.bn1(self.conv1(inputs))) 24 | 25 | 26 | class StemSTL(StemCIFAR): 27 | def __init__(self): 28 | super(StemSTL, self).__init__() 29 | self.maxpool = nn.MaxPool2d(kernel_size=3) 30 | 31 | def forward(self, inputs): 32 | out = F.relu(self.bn1(self.conv1(inputs))) 33 | out = self.maxpool(out) 34 | return out 35 | 36 | 37 | class StemImageNet(nn.Module): 38 | def __init__(self): 39 | super(StemImageNet, self).__init__() 40 | self.inplanes = 64 41 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 42 | self.bn1 = nn.BatchNorm2d(self.inplanes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 45 | 46 | def forward(self, x): 47 | x = self.conv1(x) 48 | x = self.bn1(x) 49 | x = self.relu(x) 50 | x = self.maxpool(x) 51 | return x 52 | 53 | 54 | class ResNet(nn.Module): 55 | 56 | def __init__(self, block, layers, num_classes=None, zero_init_residual=False, 57 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 58 | norm_layer=None, stem=StemCIFAR): 59 | super(ResNet, self).__init__() 60 | if norm_layer is None: 61 | norm_layer = nn.BatchNorm2d 62 | self._norm_layer = norm_layer 63 | 64 | self.inplanes = 64 65 | self.dilation = 1 66 | if replace_stride_with_dilation is None: 67 | # each element in the tuple indicates if we should replace 68 | # the 2x2 stride with a dilated convolution instead 69 | replace_stride_with_dilation = [False, False, False] 70 | if len(replace_stride_with_dilation) != 3: 71 | raise ValueError("replace_stride_with_dilation should be None " 72 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 73 | self.groups = groups 74 | self.base_width = width_per_group 75 | self.stem = stem() 76 | self.layer1 = self._make_layer(block, 64, layers[0]) 77 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 78 | dilate=replace_stride_with_dilation[0]) 79 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 80 | dilate=replace_stride_with_dilation[1]) 81 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 82 | dilate=replace_stride_with_dilation[2]) 83 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 84 | if num_classes is not None: 85 | self.fc = nn.Linear(512 * block.expansion, num_classes) 86 | else: 87 | self.fc = None 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 92 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 93 | nn.init.constant_(m.weight, 1) 94 | nn.init.constant_(m.bias, 0) 95 | 96 | # Zero-initialize the last BN in each residual branch, 97 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
98 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 99 | if zero_init_residual: 100 | for m in self.modules(): 101 | if isinstance(m, Bottleneck): 102 | nn.init.constant_(m.bn3.weight, 0) 103 | elif isinstance(m, BasicBlock): 104 | nn.init.constant_(m.bn2.weight, 0) 105 | 106 | self.representation_dim = 512 * block.expansion 107 | 108 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 109 | norm_layer = self._norm_layer 110 | downsample = None 111 | previous_dilation = self.dilation 112 | if dilate: 113 | self.dilation *= stride 114 | stride = 1 115 | if stride != 1 or self.inplanes != planes * block.expansion: 116 | downsample = nn.Sequential( 117 | conv1x1(self.inplanes, planes * block.expansion, stride), 118 | norm_layer(planes * block.expansion), 119 | ) 120 | 121 | layers = [] 122 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 123 | self.base_width, previous_dilation, norm_layer)) 124 | self.inplanes = planes * block.expansion 125 | for _ in range(1, blocks): 126 | layers.append(block(self.inplanes, planes, groups=self.groups, 127 | base_width=self.base_width, dilation=self.dilation, 128 | norm_layer=norm_layer)) 129 | 130 | return nn.Sequential(*layers) 131 | 132 | def _forward_impl(self, x): 133 | # See note [TorchScript super()] 134 | x = self.stem(x) 135 | 136 | x = self.layer1(x) 137 | x = self.layer2(x) 138 | x = self.layer3(x) 139 | x = self.layer4(x) 140 | 141 | x = self.avgpool(x) 142 | x = torch.flatten(x, 1) 143 | if self.fc: 144 | x = self.fc(x) 145 | 146 | return x 147 | 148 | def forward(self, x): 149 | return self._forward_impl(x) 150 | 151 | 152 | def ResNet18(**kwargs): 153 | return ResNet(BasicBlock, [2,2,2,2], **kwargs) 154 | 155 | 156 | def ResNet34(**kwargs): 157 | return ResNet(BasicBlock, [3,4,6,3], **kwargs) 158 | 159 | 160 | def ResNet50(**kwargs): 161 | return ResNet(Bottleneck, [3,4,6,3], **kwargs) 162 | 163 | 164 | def ResNet101(**kwargs): 165 | return ResNet(Bottleneck, [3,4,23,3], **kwargs) 166 | 167 | 168 | def ResNet152(**kwargs): 169 | return ResNet(Bottleneck, [3,8,36,3], **kwargs) 170 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.0 2 | torch>=1.2.0 3 | torchvision>=0.4.0 4 | Pillow>=7.1.1 5 | torchlars==0.1.2 6 | tqdm==4.66.3 7 | -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class CosineAnnealingWithLinearRampLR(_LRScheduler): 7 | 8 | def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, ramp_len=10): 9 | self.T_max = T_max 10 | self.eta_min = eta_min 11 | self.ramp_len = ramp_len 12 | super(CosineAnnealingWithLinearRampLR, self).__init__(optimizer, last_epoch) 13 | 14 | def get_lr(self): 15 | return self._get_closed_form_lr() 16 | 17 | def _get_closed_form_lr(self): 18 | cosine_lr = [self.eta_min + (base_lr - self.eta_min) * 19 | (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 20 | for base_lr in self.base_lrs] 21 | linear_lr = [base_lr * (1 + self.last_epoch) / self.ramp_len for base_lr in self.base_lrs] 22 | return [min(x, y) for x, y in zip(cosine_lr, linear_lr)] 23 | -------------------------------------------------------------------------------- /simclr.py: 
-------------------------------------------------------------------------------- 1 | '''Train an encoder using Contrastive Learning.''' 2 | import argparse 3 | import os 4 | import subprocess 5 | 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torchlars import LARS 11 | from tqdm import tqdm 12 | 13 | from configs import get_datasets 14 | from critic import LinearCritic 15 | from evaluate import save_checkpoint, encode_train_set, train_clf, test 16 | from models import * 17 | from scheduler import CosineAnnealingWithLinearRampLR 18 | 19 | parser = argparse.ArgumentParser(description='PyTorch Contrastive Learning.') 20 | parser.add_argument('--base-lr', default=0.25, type=float, help='base learning rate, rescaled by batch_size/256') 21 | parser.add_argument("--momentum", default=0.9, type=float, help='SGD momentum') 22 | parser.add_argument('--resume', '-r', type=str, default='', help='resume from checkpoint with this filename') 23 | parser.add_argument('--dataset', '-d', type=str, default='cifar10', help='dataset', 24 | choices=['cifar10', 'cifar100', 'stl10', 'imagenet']) 25 | parser.add_argument('--temperature', type=float, default=0.5, help='InfoNCE temperature') 26 | parser.add_argument("--batch-size", type=int, default=512, help='Training batch size') 27 | parser.add_argument("--num-epochs", type=int, default=100, help='Number of training epochs') 28 | parser.add_argument("--cosine-anneal", action='store_true', help="Use cosine annealing on the learning rate") 29 | parser.add_argument("--arch", type=str, default='resnet50', help='Encoder architecture', 30 | choices=['resnet18', 'resnet34', 'resnet50']) 31 | parser.add_argument("--num-workers", type=int, default=2, help='Number of threads for data loaders') 32 | parser.add_argument("--test-freq", type=int, default=10, help='Frequency to fit a linear clf with L-BFGS for testing. ' 33 | 'Not appropriate for large datasets. 
Set 0 to avoid ' 34 | 'classifier only training here.') 35 | parser.add_argument("--filename", type=str, default='ckpt.pth', help='Output file name') 36 | args = parser.parse_args() 37 | args.lr = args.base_lr * (args.batch_size / 256) 38 | 39 | args.git_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']) 40 | args.git_diff = subprocess.check_output(['git', 'diff']) 41 | 42 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 43 | best_acc = 0 # best test accuracy 44 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 45 | clf = None 46 | 47 | print('==> Preparing data..') 48 | trainset, testset, clftrainset, num_classes, stem = get_datasets(args.dataset) 49 | 50 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, 51 | num_workers=args.num_workers, pin_memory=True) 52 | testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=args.num_workers, 53 | pin_memory=True) 54 | clftrainloader = torch.utils.data.DataLoader(clftrainset, batch_size=1000, shuffle=False, num_workers=args.num_workers, 55 | pin_memory=True) 56 | 57 | # Model 58 | print('==> Building model..') 59 | ############################################################## 60 | # Encoder 61 | ############################################################## 62 | if args.arch == 'resnet18': 63 | net = ResNet18(stem=stem) 64 | elif args.arch == 'resnet34': 65 | net = ResNet34(stem=stem) 66 | elif args.arch == 'resnet50': 67 | net = ResNet50(stem=stem) 68 | else: 69 | raise ValueError("Bad architecture specification") 70 | net = net.to(device) 71 | 72 | ############################################################## 73 | # Critic 74 | ############################################################## 75 | critic = LinearCritic(net.representation_dim, temperature=args.temperature).to(device) 76 | 77 | if device == 'cuda': 78 | repr_dim = net.representation_dim 79 | net = torch.nn.DataParallel(net) 80 | net.representation_dim = repr_dim 81 | cudnn.benchmark = True 82 | 83 | if args.resume: 84 | # Load checkpoint. 85 | print('==> Resuming from checkpoint..') 86 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
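# Checkpoints are written to ./checkpoint/ by evaluate.save_checkpoint.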
87 | resume_from = os.path.join('./checkpoint', args.resume) 88 | checkpoint = torch.load(resume_from) 89 | net.load_state_dict(checkpoint['net']) 90 | critic.load_state_dict(checkpoint['critic']) 91 | best_acc = checkpoint.get('acc', 0)  # save_checkpoint does not store an accuracy 92 | start_epoch = checkpoint['epoch'] 93 | 94 | criterion = nn.CrossEntropyLoss() 95 | base_optimizer = optim.SGD(list(net.parameters()) + list(critic.parameters()), lr=args.lr, weight_decay=1e-6, 96 | momentum=args.momentum) 97 | if args.cosine_anneal: 98 | scheduler = CosineAnnealingWithLinearRampLR(base_optimizer, args.num_epochs) 99 | encoder_optimizer = LARS(base_optimizer, trust_coef=1e-3) 100 | 101 | 102 | # Training 103 | def train(epoch): 104 | print('\nEpoch: %d' % epoch) 105 | net.train() 106 | critic.train() 107 | train_loss = 0 108 | t = tqdm(enumerate(trainloader), desc='Loss: **** ', total=len(trainloader), bar_format='{desc}{bar}{r_bar}') 109 | for batch_idx, (inputs, _, _) in t: 110 | x1, x2 = inputs 111 | x1, x2 = x1.to(device), x2.to(device) 112 | encoder_optimizer.zero_grad() 113 | representation1, representation2 = net(x1), net(x2) 114 | raw_scores, pseudotargets = critic(representation1, representation2) 115 | loss = criterion(raw_scores, pseudotargets) 116 | loss.backward() 117 | encoder_optimizer.step() 118 | 119 | train_loss += loss.item() 120 | 121 | t.set_description('Loss: %.3f ' % (train_loss / (batch_idx + 1))) 122 | 123 | 124 | for epoch in range(start_epoch, start_epoch + args.num_epochs): 125 | train(epoch) 126 | if (args.test_freq > 0) and (epoch % args.test_freq == (args.test_freq - 1)): 127 | X, y = encode_train_set(clftrainloader, device, net) 128 | clf = train_clf(X, y, net.representation_dim, num_classes, device, reg_weight=1e-5) 129 | acc = test(testloader, device, net, clf) 130 | if acc > best_acc: 131 | best_acc = acc 132 | save_checkpoint(net, clf, critic, epoch, args, os.path.basename(__file__)) 133 | elif args.test_freq == 0: 134 | save_checkpoint(net, clf, critic, epoch, args, os.path.basename(__file__)) 135 | if args.cosine_anneal: 136 | scheduler.step() 137 | --------------------------------------------------------------------------------