├── assets ├── runs_train-acc.png ├── runs_valid-acc.png ├── runs_valid-loss.png ├── runs_valid-acc.svg └── runs_train-acc.svg ├── utils.py ├── LICENSE ├── README.md ├── loss.py ├── .gitignore ├── requirements.txt ├── data.py ├── train_classifier.py ├── train_features.py └── models.py /assets/runs_train-acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdiephuis/SimCLR/HEAD/assets/runs_train-acc.png -------------------------------------------------------------------------------- /assets/runs_valid-acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdiephuis/SimCLR/HEAD/assets/runs_valid-acc.png -------------------------------------------------------------------------------- /assets/runs_valid-loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdiephuis/SimCLR/HEAD/assets/runs_valid-loss.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Return the legacy double-precision tensor constructor: torch.cuda.DoubleTensor when use_cuda, else torch.DoubleTensor. 5 | def type_tdouble(use_cuda=False): 6 | return torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor 7 | 8 | # Build an [N x n_class] double tensor with 1 in each label's column and 0 elsewhere. labels may be [N] or [N x 1]; they are used as scatter_ indices, so presumably they must be an int64 tensor on the same device as the mask - TODO confirm at call sites. 9 | def one_hot(labels, n_class, use_cuda=False): 10 | # Ensure labels are [N x 1] 11 | if len(list(labels.size())) == 1: 12 | labels = labels.unsqueeze(1) 13 | mask = type_tdouble(use_cuda)(labels.size(0), n_class).fill_(0) 14 | # scatter dimension, position indices, fill_value 15 | return mask.scatter_(1, labels, 1) 16 | 17 | # Re-initialise every submodule of `module` in place: Kaiming-normal (fan_out, relu) for Conv2d weights, weight=1 / bias=0 for BatchNorm2d and GroupNorm. Conv2d biases and all other layer types are left unchanged. 18 | def init_weights(module): 19 | for m in module.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 22 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 23 | nn.init.constant_(m.weight, 1) 24 | nn.init.constant_(m.bias, 0) 25 |
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Maurits Diephuis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SimCLR 2 | PyTorch implementation of the paper 3 | [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709) 4 | 5 | * Adam optimizer 6 | * ExponentialLR scheduler.
No warmup or other extras. 7 | * Effective batch size of 256 via gradient accumulation (batch size 64 x 4 accumulation steps) 8 | 9 | ## Feature model 10 | * [ResNet50](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py), where the first convolutional layer has a filter size of 3 instead of 7. 11 | * h() feature dimensionality: 2048 12 | * z() projection head output dimensionality: 128 13 | 14 | ## Classifier model 15 | * Single-layer network mapping the 2048-dimensional h() features to num_classes 16 | 17 | ## Classification Results 18 | 19 | | Epochs | 100 | 200 | 20 | | ------ |-----| ------| 21 | | Paper | 83.9| 89.2 | 22 | | This repo |87.49 | 88.16 | 23 | 24 | ## Run 25 | Train the feature-extraction model (ResNet). Note that CIFAR10C inherits from datasets.CIFAR10 and returns two augmented views of each image (plus its label). 26 | 27 | python train_features.py --batch-size=64 --accumulation-steps=4 --tau=0.5 28 | --feature-size=128 --dataset-name=CIFAR10C --data-dir=path/to/your/data 29 | 30 | Train the classifier model. It needs a saved feature model, which is used to extract features from the images.
31 | 32 | python train_classifier.py --load-model=models/modelname_timestamp.pt --dataset-name=CIFAR10 33 | --data-dir=path/to/your/data 34 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class contrastive_loss(nn.Module): 6 | def __init__(self, tau=1, normalize=False): 7 | super(contrastive_loss, self).__init__() 8 | self.tau = tau 9 | self.normalize = normalize 10 | 11 | def forward(self, xi, xj): 12 | 13 | x = torch.cat((xi, xj), dim=0) 14 | 15 | is_cuda = x.is_cuda 16 | sim_mat = torch.mm(x, x.T) 17 | if self.normalize: 18 | sim_mat_denom = torch.mm(torch.norm(x, dim=1).unsqueeze(1), torch.norm(x, dim=1).unsqueeze(1).T) 19 | sim_mat = sim_mat / sim_mat_denom.clamp(min=1e-16) 20 | 21 | sim_mat = torch.exp(sim_mat / self.tau) 22 | 23 | # no diag because it's not diffrentiable -> sum - exp(1 / tau) 24 | # diag_ind = torch.eye(xi.size(0) * 2).bool() 25 | # diag_ind = diag_ind.cuda() if use_cuda else diag_ind 26 | 27 | # sim_mat = sim_mat.masked_fill_(diag_ind, 0) 28 | 29 | # top 30 | if self.normalize: 31 | sim_mat_denom = torch.norm(xi, dim=1) * torch.norm(xj, dim=1) 32 | sim_match = torch.exp(torch.sum(xi * xj, dim=-1) / sim_mat_denom / self.tau) 33 | else: 34 | sim_match = torch.exp(torch.sum(xi * xj, dim=-1) / self.tau) 35 | 36 | sim_match = torch.cat((sim_match, sim_match), dim=0) 37 | 38 | norm_sum = torch.exp(torch.ones(x.size(0)) / self.tau) 39 | norm_sum = norm_sum.cuda() if is_cuda else norm_sum 40 | loss = torch.mean(-torch.log(sim_match / (torch.sum(sim_mat, dim=-1) - norm_sum))) 41 | 42 | return loss 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Maurits 2 | .ipynb_checkpoints/ 3 | results/ 4 | graphs/ 5 | latent_temp_dir/ 6 | todo.txt 7 
| data/ 8 | models/ 9 | output/ 10 | runs/ 11 | logs/ 12 | *.pkl 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # Pycharm 20 | .idea/ 21 | 22 | # tensorboard 23 | runs/ 24 | logs/ 25 | 26 | # C extensions 27 | *.so 28 | 29 | # vscode 30 | .vscode/ 31 | 32 | #shideh 33 | .DS_Store 34 | 35 | 36 | # Distribution / packaging 37 | .Python 38 | env/ 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | *.egg-info/ 52 | .installed.cfg 53 | *.egg 54 | 55 | # PyInstaller 56 | # Usually these files are written by a python script from a template 57 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 58 | *.manifest 59 | *.spec 60 | 61 | # Installer logs 62 | pip-log.txt 63 | pip-delete-this-directory.txt 64 | 65 | # Unit test / coverage reports 66 | htmlcov/ 67 | .tox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *.cover 74 | .hypothesis/ 75 | 76 | # Translations 77 | *.mo 78 | *.pot 79 | 80 | # Django stuff: 81 | *.log 82 | local_settings.py 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # pyenv 101 | .python-version 102 | 103 | # celery beat schedule file 104 | celerybeat-schedule 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # dotenv 110 | .env 111 | 112 | # virtualenv 113 | .venv 114 | venv/ 115 | ENV/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | 130 | # custom 131 | .datasets 132 | .models 133 | 134 | 135 | 136 | 137 | # Sohrab 138 | 
paper/rho_VAE.aux 139 | paper/rho_VAE.bbl 140 | paper/rho_VAE.blg 141 | paper/rho_VAE.log 142 | paper/rho_VAE.out 143 | paper/rho_VAE.synctex.gz 144 | paper/figs/curves/*[!.pdf] 145 | paper/figs/curves/dat 146 | 147 | 148 | notebooks/data/* 149 | notebooks/.ipynb_checkpoints 150 | notebooks/samples/vanilla_VAE/* 151 | notebooks/samples/rho_VAE/* 152 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.0 2 | alabaster==0.7.12 3 | appdirs==1.4.4 4 | astor==0.8.0 5 | Babel==2.8.0 6 | backcall==0.1.0 7 | bleach==1.5.0 8 | certifi==2019.9.11 9 | cffi==1.14.0 10 | chardet==3.0.4 11 | click==7.1.2 12 | cloudpickle==1.3.0 13 | cycler==0.10.0 14 | Cython==0.29.16 15 | decorator==4.4.0 16 | defusedxml==0.6.0 17 | docutils==0.16 18 | entrypoints==0.3 19 | fasteners==0.15 20 | future==0.18.2 21 | gast==0.2.2 22 | glfw==1.12.0 23 | google-pasta==0.1.7 24 | grpcio==1.23.0 25 | h5py==2.10.0 26 | html5lib==0.9999999 27 | idna==2.8 28 | ImageHash==4.1.0 29 | imageio==2.5.0 30 | imagesize==1.2.0 31 | imgaug==0.4.0 32 | importlib-metadata==1.7.0 33 | ipdb==0.13.3 34 | ipykernel==5.1.2 35 | ipython==7.7.0 36 | ipython-genutils==0.2.0 37 | ipywidgets==7.5.1 38 | jedi==0.15.1 39 | Jinja2==2.10.1 40 | joblib==0.13.2 41 | jsonpatch==1.24 42 | jsonpointer==2.0 43 | jsonschema==3.0.2 44 | jupyter==1.0.0 45 | jupyter-client==5.3.1 46 | jupyter-console==6.0.0 47 | jupyter-core==4.5.0 48 | Keras-Applications==1.0.8 49 | Keras-Preprocessing==1.1.0 50 | kiwisolver==1.1.0 51 | lockfile==0.12.2 52 | lxml==4.5.2 53 | Markdown==3.1.1 54 | MarkupSafe==1.1.1 55 | matplotlib==3.1.1 56 | mistune==0.8.4 57 | mock==4.0.2 58 | monotonic==1.5 59 | more-itertools==8.4.0 60 | mpmath==1.1.0 61 | munch==2.5.0 62 | nbconvert==5.6.0 63 | nbformat==4.4.0 64 | networkx==2.3 65 | notebook==6.0.0 66 | numpy==1.18.4 67 | numpydoc==1.1.0 68 | opencv-python==4.1.1.26 69 | 
opt-einsum==3.1.0 70 | packaging==20.4 71 | pandas==0.25.0 72 | pandocfilters==1.4.2 73 | parso==0.5.1 74 | pexpect==4.7.0 75 | pickleshare==0.7.5 76 | Pillow==6.1.0 77 | pluggy==0.13.1 78 | portalocker==1.7.0 79 | pretrainedmodels==0.7.4 80 | progressbar==2.5 81 | progressbar2==3.47.0 82 | prometheus-client==0.7.1 83 | prompt-toolkit==2.0.9 84 | protobuf==3.9.1 85 | ptflops==0.6 86 | ptyprocess==0.6.0 87 | py==1.9.0 88 | pycparser==2.20 89 | pydot==1.4.1 90 | pyglet==1.5.0 91 | Pygments==2.4.2 92 | PyOpenGL==3.1.5 93 | pyparsing==2.4.2 94 | pyrsistent==0.15.4 95 | pytest==5.4.3 96 | pytest-instafail==0.3.0 97 | python-dateutil==2.8.0 98 | python-utils==2.3.0 99 | pytz==2019.2 100 | pyvacy==0.0.32 101 | PyWavelets==1.0.3 102 | PyYAML==5.1 103 | pyzmq==18.1.0 104 | qtconsole==4.5.3 105 | scikit-image==0.15.0 106 | scikit-learn==0.21.3 107 | scipy==1.3.1 108 | seaborn==0.9.0 109 | Send2Trash==1.5.0 110 | Shapely==1.7.0 111 | six==1.12.0 112 | sklearn==0.0 113 | tensorboard==2.0.0 114 | tensorboardX==2.0 115 | tensorflow==2.0.0 116 | tensorflow-estimator==2.0.0 117 | termcolor==1.1.0 118 | terminado==0.8.2 119 | testpath==0.4.2 120 | torch==1.6.0 121 | torchfile==0.1.0 122 | torchlars==0.1.2 123 | torchsummary==1.5.1 124 | torchvision==0.7.0 125 | tornado==6.0.3 126 | tqdm==4.34.0 127 | traitlets==4.3.2 128 | urllib3==1.25.7 129 | visdom==0.1.8.9 130 | vizdoom==1.1.7 131 | wcwidth==0.1.7 132 | webencodings==0.5.1 133 | websocket-client==0.56.0 134 | Werkzeug==0.15.5 135 | wrapt==1.11.2 136 | yacs==0.1.6 137 | zipp==3.1.0 138 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torchvision import datasets, transforms 3 | 4 | from PIL import ImageFilter 5 | from PIL import Image 6 | 7 | 8 | class GaussianSmoothing(object): 9 | def __init__(self, radius): 10 | self.radius = radius 11 | 12 | def 
__call__(self, image): 13 | return image.filter(ImageFilter.GaussianBlur(self.radius)) 14 | 15 | 16 | def cifar_train_transforms(): 17 | all_transforms = transforms.Compose([ 18 | transforms.RandomResizedCrop(32), 19 | transforms.RandomHorizontalFlip(p=0.5), 20 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 21 | transforms.RandomGrayscale(p=0.2), 22 | transforms.ToTensor(), 23 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 24 | ]) 25 | return all_transforms 26 | 27 | 28 | def cifar_test_transforms(): 29 | all_transforms = transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 32 | ]) 33 | return all_transforms 34 | 35 | 36 | def mnist_train_transforms(): 37 | # Defining the augmentations 38 | all_transforms = transforms.Compose([ 39 | transforms.RandomAffine(degrees=15, 40 | translate=[0.1, 0.1], 41 | scale=[0.9, 1.1], 42 | shear=15), 43 | transforms.ToTensor() 44 | ]) 45 | return all_transforms 46 | 47 | 48 | def mnist_test_transforms(): 49 | all_transforms = transforms.Compose([ 50 | transforms.ToTensor() 51 | ]) 52 | return all_transforms 53 | 54 | 55 | class CIFAR10C(datasets.CIFAR10): 56 | def __init__(self, *args, **kwargs): 57 | super(CIFAR10C, self).__init__(*args, **kwargs) 58 | 59 | def __getitem__(self, index): 60 | img, target = self.data[index], self.targets[index] 61 | 62 | # return a PIL Image 63 | img = Image.fromarray(img) 64 | 65 | if self.transform is not None: 66 | xi = self.transform(img) 67 | xj = self.transform(img) 68 | 69 | if self.target_transform is not None: 70 | target = self.target_transform(target) 71 | 72 | return xi, xj, target 73 | 74 | 75 | class MNISTC(datasets.MNIST): 76 | def __init__(self, *args, **kwargs): 77 | super(MNISTC, self).__init__(*args, **kwargs) 78 | 79 | def __getitem__(self, index): 80 | img, target = self.data[index], int(self.targets[index]) 81 | 82 | # return a PIL Image 83 | img 
= Image.fromarray(img.numpy(), mode='L') 84 | 85 | if self.transform is not None: 86 | xi = self.transform(img) 87 | xj = self.transform(img) 88 | 89 | if self.target_transform is not None: 90 | target = self.target_transform(target) 91 | 92 | return xi, xj, target 93 | 94 | 95 | class Loader(object): 96 | def __init__(self, dataset_ident, file_path, download, batch_size, train_transform, test_transform, target_transform, use_cuda): 97 | 98 | kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {} 99 | 100 | loader_map = { 101 | 'CIFAR10C': CIFAR10C, 102 | 'CIFAR10': datasets.CIFAR10, 103 | 'MNIST': datasets.MNIST, 104 | 'MNISTC': MNISTC 105 | } 106 | 107 | num_class = { 108 | 'CIFAR10C': 10, 109 | 'CIFAR10': 10, 110 | 'MNIST': 10, 111 | 'MNISTC': 10 112 | } 113 | 114 | # Get the datasets 115 | train_dataset, test_dataset = self.get_dataset(loader_map[dataset_ident], file_path, download, 116 | train_transform, test_transform, target_transform) 117 | # Set the loaders 118 | self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs) 119 | self.test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **kwargs) 120 | 121 | tmp_batch = self.train_loader.__iter__().__next__()[0] 122 | self.img_shape = list(tmp_batch.size())[1:] 123 | self.num_class = num_class[dataset_ident] 124 | 125 | @staticmethod 126 | def get_dataset(dataset, file_path, download, train_transform, test_transform, target_transform): 127 | 128 | # Training and Validation datasets 129 | train_dataset = dataset(file_path, train=True, download=download, 130 | transform=train_transform, 131 | target_transform=target_transform) 132 | 133 | test_dataset = dataset(file_path, train=False, download=download, 134 | transform=test_transform, 135 | target_transform=target_transform) 136 | 137 | return train_dataset, test_dataset 138 | -------------------------------------------------------------------------------- /train_classifier.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.optim as optim 4 | import numpy as np 5 | 6 | from tensorboardX import SummaryWriter 7 | from tqdm import tqdm 8 | import os 9 | 10 | 11 | from models import * 12 | from utils import * 13 | from data import * 14 | from loss import * 15 | 16 | 17 | parser = argparse.ArgumentParser(description='SIMCLR-CLASSI') 18 | 19 | parser.add_argument('--uid', type=str, default='SimCLR-CLASSI', 20 | help='Staging identifier (default: SimCLR-CLASSI)') 21 | parser.add_argument('--load-model', type=str, default=None, 22 | help='Load model for feature extraction (default None)') 23 | 24 | parser.add_argument('--dataset-name', type=str, default='CIFAR10', 25 | help='Name of dataset (default: CIFAR10') 26 | parser.add_argument('--data-dir', type=str, default='data', 27 | help='Path to dataset (default: data') 28 | parser.add_argument('--feature-size', type=int, default=128, 29 | help='Feature output size (default: 128') 30 | parser.add_argument('--batch-size', type=int, default=32, metavar='N', 31 | help='input training batch-size') 32 | parser.add_argument('--epochs', type=int, default=200, metavar='N', 33 | help='number of training epochs (default: 150)') 34 | parser.add_argument('--lr', type=float, default=1e-3, 35 | help='learning rate (default: 1e-3') 36 | parser.add_argument("--decay-lr", default=1e-6, action="store", type=float, 37 | help='Learning rate decay (default: 1e-6') 38 | parser.add_argument('--log-dir', type=str, default='runs', 39 | help='logging directory (default: runs)') 40 | parser.add_argument('--no-cuda', action='store_true', default=False, 41 | help='disables cuda (default: False') 42 | parser.add_argument('--device-id', type=int, default=0, 43 | help='GPU device id (default: 0') 44 | 45 | args = parser.parse_args() 46 | 47 | # Set cuda 48 | use_cuda = not args.no_cuda and torch.cuda.is_available() 49 | 50 | if use_cuda: 51 | 
dtype = torch.cuda.FloatTensor 52 | device = torch.device("cuda") 53 | torch.cuda.set_device(args.device_id) 54 | print('GPU') 55 | else: 56 | dtype = torch.FloatTensor 57 | device = torch.device("cpu") 58 | 59 | # Setup tensorboard 60 | use_tb = args.log_dir is not None 61 | log_dir = args.log_dir 62 | 63 | # Setup asset directories 64 | if not os.path.exists('models'): 65 | os.makedirs('models') 66 | 67 | if not os.path.exists('runs'): 68 | os.makedirs('runs') 69 | 70 | # Logger 71 | if use_tb: 72 | logger = SummaryWriter(comment='_' + args.uid + '_' + args.dataset_name) 73 | 74 | # Datasets 75 | train_transforms = transforms.Compose([ 76 | transforms.ToTensor(), 77 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 78 | ]) 79 | 80 | test_transforms = transforms.Compose([ 81 | transforms.ToTensor(), 82 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 83 | ]) 84 | 85 | 86 | loader = Loader(args.dataset_name, args.data_dir, True, args.batch_size, train_transforms, test_transforms, None, use_cuda) 87 | train_loader = loader.train_loader 88 | test_loader = loader.test_loader 89 | 90 | 91 | # train validate 92 | def train_validate(classifier_model, feature_model, loader, optimizer, is_train, epoch, use_cuda): 93 | 94 | loss_func = nn.CrossEntropyLoss() 95 | 96 | data_loader = loader.train_loader if is_train else loader.test_loader 97 | 98 | classifier_model.train() if is_train else classifier_model.eval() 99 | desc = 'Train' if is_train else 'Validation' 100 | 101 | total_loss = 0 102 | total_acc = 0 103 | 104 | tqdm_bar = tqdm(data_loader) 105 | for batch_idx, (x, y) in enumerate(tqdm_bar): 106 | batch_loss = 0 107 | batch_acc = 0 108 | 109 | x = x.cuda() if use_cuda else x 110 | y = y.cuda() if use_cuda else y 111 | 112 | # Get features 113 | f_x, _ = feature_model(x) 114 | f_x = f_x.detach() 115 | 116 | # Classify features 117 | y_hat = classifier_model(f_x) 118 | 119 | loss = loss_func(y_hat, y) 120 | 121 | if 
is_train: 122 | classifier_model.zero_grad() 123 | loss.backward() 124 | optimizer.step() 125 | 126 | # Reporting 127 | batch_loss = loss.item() / x.size(0) 128 | total_loss += loss.item() 129 | 130 | pred = y_hat.max(dim=1)[1] 131 | correct = pred.eq(y).sum().item() 132 | correct /= y.size(0) 133 | batch_acc = (correct * 100) 134 | total_acc += batch_acc 135 | 136 | tqdm_bar.set_description('{} Epoch: [{}] Batch Loss: {:.4f} Batch Acc: {:.4f}'.format(desc, epoch, batch_loss, batch_acc)) 137 | 138 | return total_loss / (batch_idx + 1), total_acc / (batch_idx + 1) 139 | 140 | 141 | def execute_graph(classifier_model, feature_model, loader, optimizer, epoch, use_cuda): 142 | t_loss, t_acc = train_validate(classifier_model, feature_model, loader, optimizer, True, epoch, use_cuda) 143 | v_loss, v_acc = train_validate(classifier_model, feature_model, loader, optimizer, False, epoch, use_cuda) 144 | 145 | if use_tb: 146 | logger.add_scalar(log_dir + '/train-loss', t_loss, epoch) 147 | logger.add_scalar(log_dir + '/valid-loss', v_loss, epoch) 148 | 149 | logger.add_scalar(log_dir + '/train-acc', t_acc, epoch) 150 | logger.add_scalar(log_dir + '/valid-acc', v_acc, epoch) 151 | 152 | # print('Epoch: {} Train loss {}'.format(epoch, t_loss)) 153 | # print('Epoch: {} Valid loss {}'.format(epoch, v_loss)) 154 | 155 | return v_loss 156 | 157 | 158 | # 159 | # Load feature extraction model 160 | feature_model = resnet50_cifar(args.feature_size).type(dtype) 161 | feature_model.eval() 162 | 163 | if os.path.isfile(args.load_model): 164 | checkpoint = torch.load(args.load_model) 165 | feature_model.load_state_dict(checkpoint['model']) 166 | epoch = checkpoint['epoch'] 167 | print('Loading model: {}, from epoch: {}'.format(args.load_model, epoch)) 168 | else: 169 | print('Model: {} not found'.format(args.load_model)) 170 | 171 | # 172 | # Define linear classification model 173 | classifier_model = SimpleNet().type(dtype) 174 | optimizer = optim.Adam(classifier_model.parameters(), 
lr=args.lr, weight_decay=args.decay_lr) 175 | 176 | 177 | # Main training loop 178 | best_loss = np.inf 179 | 180 | for epoch in range(args.epochs): 181 | execute_graph(classifier_model, feature_model, loader, optimizer, epoch, use_cuda) 182 | 183 | # TensorboardX logger 184 | logger.close() 185 | 186 | # save model / restart training 187 | -------------------------------------------------------------------------------- /train_features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.optim as optim 4 | from torch.optim.lr_scheduler import ExponentialLR 5 | import numpy as np 6 | 7 | from tensorboardX import SummaryWriter 8 | from tqdm import tqdm 9 | import os 10 | import time 11 | 12 | 13 | from models import * 14 | from utils import * 15 | from data import * 16 | from loss import * 17 | 18 | parser = argparse.ArgumentParser(description='SIMCLR') 19 | 20 | parser.add_argument('--uid', type=str, default='SimCLR', 21 | help='Staging identifier (default: SimCLR)') 22 | parser.add_argument('--dataset-name', type=str, default='CIFAR10C', 23 | help='Name of dataset (default: CIFAR10C') 24 | parser.add_argument('--data-dir', type=str, default='data', 25 | help='Path to dataset (default: data') 26 | parser.add_argument('--feature-size', type=int, default=128, 27 | help='Feature output size (default: 128') 28 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 29 | help='input training batch-size') 30 | parser.add_argument('--accumulation-steps', type=int, default=4, metavar='N', 31 | help='Gradient accumulation steps (default: 4') 32 | parser.add_argument('--epochs', type=int, default=150, metavar='N', 33 | help='number of training epochs (default: 150)') 34 | parser.add_argument('--lr', type=float, default=1e-3, 35 | help='learning rate (default: 1e-3') 36 | parser.add_argument("--decay-lr", default=1e-6, action="store", type=float, 37 | help='Learning rate 
decay (default: 1e-6') 38 | parser.add_argument('--tau', default=0.5, type=float, 39 | help='Tau temperature smoothing (default 0.5)') 40 | parser.add_argument('--log-dir', type=str, default='runs', 41 | help='logging directory (default: runs)') 42 | parser.add_argument('--no-cuda', action='store_true', default=False, 43 | help='disables cuda (default: False') 44 | parser.add_argument('--load-model', type=str, default=None, 45 | help='Load model to resume training for (default None)') 46 | parser.add_argument('--device-id', type=int, default=0, 47 | help='GPU device id (default: 0') 48 | 49 | args = parser.parse_args() 50 | 51 | # Set cuda 52 | use_cuda = not args.no_cuda and torch.cuda.is_available() 53 | 54 | if use_cuda: 55 | dtype = torch.cuda.FloatTensor 56 | device = torch.device("cuda") 57 | torch.cuda.set_device(args.device_id) 58 | print('GPU') 59 | else: 60 | dtype = torch.FloatTensor 61 | device = torch.device("cpu") 62 | 63 | # Setup tensorboard 64 | use_tb = args.log_dir is not None 65 | log_dir = args.log_dir 66 | 67 | # Setup asset directories 68 | if not os.path.exists('models'): 69 | os.makedirs('models') 70 | 71 | if not os.path.exists('runs'): 72 | os.makedirs('runs') 73 | 74 | # Logger 75 | if use_tb: 76 | logger = SummaryWriter(comment='_' + args.uid + '_' + args.dataset_name) 77 | 78 | if args.dataset_name == 'CIFAR10C': 79 | in_channels = 3 80 | # Get train and test loaders for dataset 81 | train_transforms = cifar_train_transforms() 82 | test_transforms = cifar_test_transforms() 83 | target_transforms = None 84 | 85 | loader = Loader(args.dataset_name, args.data_dir, True, args.batch_size, train_transforms, test_transforms, target_transforms, use_cuda) 86 | train_loader = loader.train_loader 87 | test_loader = loader.test_loader 88 | 89 | 90 | # train validate 91 | def train_validate(model, loader, optimizer, is_train, epoch, use_cuda): 92 | 93 | loss_func = contrastive_loss(tau=args.tau) 94 | 95 | data_loader = loader.train_loader if 
is_train else loader.test_loader 96 | 97 | if is_train: 98 | model.train() 99 | model.zero_grad() 100 | else: 101 | model.eval() 102 | 103 | desc = 'Train' if is_train else 'Validation' 104 | 105 | total_loss = 0.0 106 | 107 | tqdm_bar = tqdm(data_loader) 108 | for i, (x_i, x_j, _) in enumerate(tqdm_bar): 109 | 110 | x_i = x_i.cuda() if use_cuda else x_i 111 | x_j = x_j.cuda() if use_cuda else x_j 112 | 113 | _, z_i = model(x_i) 114 | _, z_j = model(x_j) 115 | 116 | loss = loss_func(z_i, z_j) 117 | loss /= args.accumulation_steps 118 | 119 | if is_train: 120 | loss.backward() 121 | 122 | if (i + 1) % args.accumulation_steps == 0 and is_train: 123 | optimizer.step() 124 | model.zero_grad() 125 | 126 | total_loss += loss.item() 127 | 128 | tqdm_bar.set_description('{} Epoch: [{}] Loss: {:.4f}'.format(desc, epoch, loss.item())) 129 | 130 | return total_loss / (len(data_loader.dataset)) 131 | 132 | 133 | def execute_graph(model, loader, optimizer, scheduler, epoch, use_cuda): 134 | t_loss = train_validate(model, loader, optimizer, True, epoch, use_cuda) 135 | v_loss = train_validate(model, loader, optimizer, False, epoch, use_cuda) 136 | 137 | scheduler.step(v_loss) 138 | 139 | if use_tb: 140 | logger.add_scalar(log_dir + '/train-loss', t_loss, epoch) 141 | logger.add_scalar(log_dir + '/valid-loss', v_loss, epoch) 142 | 143 | return v_loss 144 | 145 | 146 | model = resnet50_cifar(args.feature_size).type(dtype) 147 | 148 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay_lr) 149 | scheduler = ExponentialLR(optimizer, gamma=args.decay_lr) 150 | 151 | 152 | # Main training loop 153 | best_loss = np.inf 154 | 155 | # Resume training 156 | if args.load_model is not None: 157 | if os.path.isfile(args.load_model): 158 | checkpoint = torch.load(args.load_model) 159 | model.load_state_dict(checkpoint['model']) 160 | optimizer.load_state_dict(checkpoint['optimizer']) 161 | scheduler.load_state_dict(checkpoint['scheduler']) 162 | best_loss = 
checkpoint['val_loss'] 163 | epoch = checkpoint['epoch'] 164 | print('Loading model: {}. Resuming from epoch: {}'.format(args.load_model, epoch)) 165 | else: 166 | print('Model: {} not found'.format(args.load_model)) 167 | 168 | for epoch in range(args.epochs): 169 | v_loss = execute_graph(model, loader, optimizer, scheduler, epoch, use_cuda) 170 | 171 | if v_loss < best_loss: 172 | best_loss = v_loss 173 | print('Writing model checkpoint') 174 | state = { 175 | 'epoch': epoch, 176 | 'model': model.state_dict(), 177 | 'optimizer': optimizer.state_dict(), 178 | 'scheduler': scheduler.state_dict(), 179 | 'val_loss': v_loss 180 | } 181 | t = time.localtime() 182 | timestamp = time.strftime('%b-%d-%Y_%H%M', t) 183 | file_name = 'models/{}_{}_{}_{:04.4f}.pt'.format(timestamp, args.uid, epoch, v_loss) 184 | 185 | torch.save(state, file_name) 186 | 187 | 188 | # TensorboardX logger 189 | logger.close() 190 | 191 | # save model / restart training 192 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.hub import load_state_dict_from_url 4 | import torch.nn.functional as F 5 | from torch.distributions import Normal, Independent 6 | 7 | 8 | __all__ = ['resnet50_cifar', 'resnet18_cifar', 'SimpleNet'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=dilation, groups=groups, bias=False, dilation=dilation) 21 | 22 | 23 | def conv1x1(in_planes, out_planes, stride=1): 24 | """1x1 convolution""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, 
stride=stride, bias=False) 26 | 27 | # Residual block of two 3x3 conv + norm layers (expansion = 1); the shortcut is the input itself, or downsample(x) when a downsample module is given. 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | __constants__ = ['downsample'] 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 33 | base_width=64, dilation=1, norm_layer=None): 34 | super(BasicBlock, self).__init__() 35 | if norm_layer is None: 36 | norm_layer = nn.BatchNorm2d 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | if dilation > 1: 40 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 41 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 42 | self.conv1 = conv3x3(inplanes, planes, stride) 43 | self.bn1 = norm_layer(planes) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.conv2 = conv3x3(planes, planes) 46 | self.bn2 = norm_layer(planes) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | 60 | if self.downsample is not None: 61 | identity = self.downsample(x) 62 | 63 | out += identity 64 | out = self.relu(out) 65 | 66 | return out 67 | 68 | # Bottleneck residual block: 1x1 conv to `width`, 3x3 conv (carries stride/groups/dilation), 1x1 conv up to planes * expansion (4), each followed by a norm layer; shortcut handled as in BasicBlock. 69 | class Bottleneck(nn.Module): 70 | expansion = 4 71 | __constants__ = ['downsample'] 72 | 73 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 74 | base_width=64, dilation=1, norm_layer=None): 75 | super(Bottleneck, self).__init__() 76 | if norm_layer is None: 77 | norm_layer = nn.BatchNorm2d 78 | width = int(planes * (base_width / 64.)) * groups 79 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 80 | self.conv1 = conv1x1(inplanes, width) 81 | self.bn1 = norm_layer(width) 82 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 83 | self.bn2 = norm_layer(width) 84 | self.conv3 = conv1x1(width, planes * self.expansion) 85 | self.bn3 =
norm_layer(planes * self.expansion) 86 | self.relu = nn.ReLU(inplace=True) 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x): 91 | identity = x 92 | 93 | out = self.conv1(x) 94 | out = self.bn1(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv2(out) 98 | out = self.bn2(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv3(out) 102 | out = self.bn3(out) 103 | 104 | if self.downsample is not None: 105 | identity = self.downsample(x) 106 | 107 | out += identity 108 | out = self.relu(out) 109 | 110 | return out 111 | 112 | 113 | class ResNet(nn.Module): 114 | 115 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 116 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 117 | norm_layer=None): 118 | super(ResNet, self).__init__() 119 | if norm_layer is None: 120 | norm_layer = nn.BatchNorm2d 121 | self._norm_layer = norm_layer 122 | 123 | self.inplanes = 64 124 | self.dilation = 1 125 | if replace_stride_with_dilation is None: 126 | # each element in the tuple indicates if we should replace 127 | # the 2x2 stride with a dilated convolution instead 128 | replace_stride_with_dilation = [False, False, False] 129 | if len(replace_stride_with_dilation) != 3: 130 | raise ValueError("replace_stride_with_dilation should be None " 131 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 132 | self.groups = groups 133 | self.base_width = width_per_group 134 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 135 | bias=False) 136 | self.bn1 = norm_layer(self.inplanes) 137 | self.relu = nn.ReLU(inplace=True) 138 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 139 | self.layer1 = self._make_layer(block, 64, layers[0]) 140 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 141 | dilate=replace_stride_with_dilation[0]) 142 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 143 |
dilate=replace_stride_with_dilation[1]) 144 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 145 | dilate=replace_stride_with_dilation[2]) 146 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 147 | self.fc = nn.Linear(512 * block.expansion, num_classes) 148 | 149 | for m in self.modules(): 150 | if isinstance(m, nn.Conv2d): 151 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 152 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 153 | nn.init.constant_(m.weight, 1) 154 | nn.init.constant_(m.bias, 0) 155 | 156 | # Zero-initialize the last BN in each residual branch, 157 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 158 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 159 | if zero_init_residual: 160 | for m in self.modules(): 161 | if isinstance(m, Bottleneck): 162 | nn.init.constant_(m.bn3.weight, 0) 163 | elif isinstance(m, BasicBlock): 164 | nn.init.constant_(m.bn2.weight, 0) 165 | 166 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 167 | norm_layer = self._norm_layer 168 | downsample = None 169 | previous_dilation = self.dilation 170 | if dilate: 171 | self.dilation *= stride 172 | stride = 1 173 | if stride != 1 or self.inplanes != planes * block.expansion: 174 | downsample = nn.Sequential( 175 | conv1x1(self.inplanes, planes * block.expansion, stride), 176 | norm_layer(planes * block.expansion), 177 | ) 178 | 179 | layers = [] 180 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 181 | self.base_width, previous_dilation, norm_layer)) 182 | self.inplanes = planes * block.expansion 183 | for _ in range(1, blocks): 184 | layers.append(block(self.inplanes, planes, groups=self.groups, 185 | base_width=self.base_width, dilation=self.dilation, 186 | norm_layer=norm_layer)) 187 | 188 | return nn.Sequential(*layers) 189 | 190 | def _forward_impl(self, x): 191 | # See note [TorchScript 
super()] 192 | x = self.conv1(x) 193 | x = self.bn1(x) 194 | x = self.relu(x) 195 | x = self.maxpool(x) 196 | 197 | x = self.layer1(x) 198 | x = self.layer2(x) 199 | x = self.layer3(x) 200 | x = self.layer4(x) 201 | 202 | x = self.avgpool(x) 203 | x = torch.flatten(x, 1) 204 | x = self.fc(x) 205 | 206 | return x 207 | 208 | def forward(self, x): 209 | return self._forward_impl(x) 210 | 211 | 212 | class Bottleneck_CIFAR(nn.Module): 213 | expansion = 4 214 | 215 | def __init__(self, in_planes, planes, stride=1): 216 | super(Bottleneck_CIFAR, self).__init__() 217 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 218 | self.bn1 = nn.BatchNorm2d(planes) 219 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 220 | self.bn2 = nn.BatchNorm2d(planes) 221 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 222 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 223 | 224 | self.shortcut = nn.Sequential() 225 | if stride != 1 or in_planes != self.expansion * planes: 226 | self.shortcut = nn.Sequential( 227 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 228 | nn.BatchNorm2d(self.expansion * planes) 229 | ) 230 | 231 | def forward(self, x): 232 | out = F.relu(self.bn1(self.conv1(x))) 233 | out = F.relu(self.bn2(self.conv2(out))) 234 | out = self.bn3(self.conv3(out)) 235 | out += self.shortcut(x) 236 | out = F.relu(out) 237 | return out 238 | 239 | 240 | class BasicBlock_CIFAR(nn.Module): 241 | expansion = 1 242 | 243 | def __init__(self, in_planes, planes, stride=1): 244 | super(BasicBlock_CIFAR, self).__init__() 245 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 246 | self.bn1 = nn.BatchNorm2d(planes) 247 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 248 | self.bn2 = nn.BatchNorm2d(planes) 249 | 250 | self.shortcut = nn.Sequential() 251 | if 
stride != 1 or in_planes != self.expansion * planes: 252 | self.shortcut = nn.Sequential( 253 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 254 | nn.BatchNorm2d(self.expansion * planes) 255 | ) 256 | 257 | def forward(self, x): 258 | out = F.relu(self.bn1(self.conv1(x))) 259 | out = self.bn2(self.conv2(out)) 260 | out += self.shortcut(x) 261 | out = F.relu(out) 262 | return out 263 | 264 | 265 | class ResNet_CIFAR(nn.Module): 266 | 267 | def __init__(self, block, layers, num_features=128, reparam=False): 268 | super(ResNet_CIFAR, self).__init__() 269 | self.in_planes = 64 270 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 271 | self.bn1 = nn.BatchNorm2d(64) 272 | self.num_features = num_features 273 | self.reparam = reparam 274 | 275 | self.layer_block = nn.Sequential( 276 | self._make_layer(block, 64, layers[0], stride=1), 277 | self._make_layer(block, 128, layers[1], stride=2), 278 | self._make_layer(block, 256, layers[2], stride=2), 279 | self._make_layer(block, 512, layers[2], stride=2) 280 | ) 281 | 282 | self.linear = nn.Linear(512 * block.expansion, num_features) 283 | 284 | self.learning_head = nn.Sequential( 285 | nn.Linear(512 * block.expansion, 512), 286 | nn.BatchNorm1d(512), 287 | nn.ReLU(), 288 | nn.Linear(512, num_features) 289 | ) 290 | 291 | def _make_layer(self, block, planes, num_blocks, stride=1): 292 | strides = [stride] + [1] * (num_blocks - 1) 293 | layers = [] 294 | for stride in strides: 295 | layers.append(block(self.in_planes, planes, stride)) 296 | self.in_planes = planes * block.expansion 297 | return nn.Sequential(*layers) 298 | 299 | def forward(self, x): 300 | x = self.conv1(x) 301 | x = F.relu(self.bn1(x)) 302 | 303 | x = self.layer_block(x) 304 | 305 | x = F.avg_pool2d(x, 4) 306 | 307 | # Feature representation, 2048 308 | h_out = x.view(x.size(0), -1) 309 | 310 | # Learning output 311 | z_out = self.learning_head(h_out) 312 | 313 | if self.reparam: 314 | 
mu, sigma = z_out[:, :self.num_features // 2], z_out[:, self.num_features // 2:] 315 | sigma = F.softplus(sigma) + 1e-7 316 | return F.normalize(h_out, dim=-1), Independent(Normal(loc=mu, scale=sigma), 1) 317 | else: 318 | return F.normalize(h_out, dim=-1), F.normalize(z_out, dim=-1) 319 | 320 | 321 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 322 | model = ResNet(block, layers, **kwargs) 323 | if pretrained: 324 | state_dict = load_state_dict_from_url(model_urls[arch], 325 | progress=progress) 326 | model.load_state_dict(state_dict) 327 | return model 328 | 329 | 330 | def resnet18(pretrained=False, progress=True, **kwargs): 331 | """ResNet-18 model from 332 | `"Deep Residual Learning for Image Recognition" `_ 333 | Args: 334 | pretrained (bool): If True, returns a model pre-trained on ImageNet 335 | progress (bool): If True, displays a progress bar of the download to stderr 336 | """ 337 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 338 | **kwargs) 339 | 340 | 341 | def resnet50(pretrained=False, progress=True, **kwargs): 342 | """ResNet-50 model from 343 | `"Deep Residual Learning for Image Recognition" `_ 344 | Args: 345 | pretrained (bool): If True, returns a model pre-trained on ImageNet 346 | progress (bool): If True, displays a progress bar of the download to stderr 347 | """ 348 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 349 | **kwargs) 350 | 351 | 352 | def resnet50_cifar(num_features=128, reparam=False): 353 | return ResNet_CIFAR(Bottleneck_CIFAR, [3, 4, 6, 3], num_features, reparam) 354 | 355 | 356 | def resnet18_cifar(num_features=128, reparam=False): 357 | return ResNet_CIFAR(BasicBlock_CIFAR, [2, 2, 2, 2], num_features, reparam) 358 | 359 | 360 | class SimpleNet(nn.Module): 361 | def __init__(self, input_dim=2048, num_classes=10): 362 | super(SimpleNet, self).__init__() 363 | self.input_dim = input_dim 364 | self.num_classes = num_classes 365 | 366 | self.fc = 
nn.Linear(self.input_dim, self.num_classes) 367 | 368 | def forward(self, x): 369 | x = self.fc(x) 370 | return x 371 | -------------------------------------------------------------------------------- /assets/runs_valid-acc.svg: -------------------------------------------------------------------------------- 1 | 8383.58484.58585.58686.587-20020406080100120140160180200220240260280 -------------------------------------------------------------------------------- /assets/runs_train-acc.svg: -------------------------------------------------------------------------------- 1 | 8484.58585.58686.58787.58888.58989.590-20020406080100120140160180200220240260280 --------------------------------------------------------------------------------