├── README.md ├── config.conf ├── media └── CosineAnnealingLR.png ├── requirements.txt └── src ├── datasets.py ├── kill_zombie.sh ├── main.py ├── model ├── __init__.py ├── moco.py └── network.py ├── optimisers.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Unofficial Pytorch Implementation of MoCoV2 2 | 3 | Unofficial Pytorch implementation of [MoCoV2](https://arxiv.org/abs/2003.04297): "Improved Baselines with Momentum Contrastive Learning." 4 | 5 | This repo uses elements from [CMC](https://github.com/HobbitLong/CMC) and the official [MoCo Code](https://github.com/facebookresearch/moco), with the MoCoV2 model implemented into my existing Pytorch boilerplate and workflow. Additionally, this repo aims to align nicely with my implementation of [SimCLR](https://arxiv.org/pdf/2002.05709.pdf), found [here](https://github.com/AidenDurrant/SimCLR-Pytorch/). 6 | 7 | Work in progress: replicating results on ImageNet, TinyImageNet, CIFAR10, CIFAR100, and STL10. 8 | 9 | * **Author**: Aiden Durrant 10 | * **Email**: adurrant@lincoln.ac.uk 11 | 12 | ### Results: 13 | 14 | Top-1 Acc / Error of linear evaluation on CIFAR10: 15 | 16 | Testing is performed on the CIFAR10 test set, whilst the train set is split into train and valid for tuning. 17 | 18 | | Method | Batch Size | ResNet | Projection Head Dim. | Pre-train Epochs | Optimizer | Eval Epochs | Acc(%) | 19 | | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | 20 | | MoCoV2 + Linear eval. | 128 | ResNet18 | 128 | 1000 | Adam | 100 | - | 21 | | MoCoV2 + Linear eval. | 128 | ResNet18 | 128 | 1000 | SGD | 100 | - | 22 | | MoCoV2 + Linear eval. | 128 | ResNet34 | 128 | 1000 | Adam | 100 | - | 23 | | MoCoV2 + Linear eval. | 128 | ResNet34 | 128 | 1000 | SGD | 100 | - | 24 | | MoCoV2 + Linear eval. | 128 | ResNet50 | 128 | 1000 | Adam | 100 | - | 25 | | MoCoV2 + Linear eval. | 128 | ResNet50 | 128 | 1000 | SGD | 100 | - | 26 | | Supervised + Linear eval.| 128 | ResNet18 | 128 | 1000 | SGD | 100 | - | 27 | | Random Init + Linear eval.| 128 | ResNet18 | 128 | 1000 | SGD | 100 | - | 28 | 29 | **Note**: For Linear Evaluation the ResNet is frozen (all layers); training is only performed on the supervised Linear Evaluation layer. 30 | 31 | ### Plots: 32 | 33 | **ResNet-18** 34 | 35 | 36 | 37 | 38 | **ResNet-50** 39 | 40 | 41 | 42 | ## Usage / Run 43 | 44 | ### Contrastive Training and Linear Evaluation 45 | Launch the script from `src/main.py`: 46 | 47 | By default the CIFAR-10 dataset is used; use `--dataset` to select from: cifar10, cifar100, stl10, imagenet, tinyimagenet. For ImageNet and TinyImageNet, please define a path to the dataset. 48 | 49 | Training uses CosineAnnealingLR decay with linear warmup to replicate the training settings in https://arxiv.org/pdf/2002.05709.pdf. The learning rate schedule is plotted below: 50 | 51 | 52 | 53 | #### DistributedDataParallel 54 | 55 | To train with **Distributed** for a slight computational speedup with multiple GPUs, use: 56 | 57 | `python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=2 --use_env main.py` 58 | 59 | 60 | This will train on a single machine (`nnodes=1`), assigning 1 process per GPU, where `nproc_per_node=2` refers to training on 2 GPUs. To train on `N` GPUs simply launch `N` processes by setting `nproc_per_node=N`.
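For reference, a minimal sketch of the handshake each launched process performs (assumption: this mirrors `setup()` in `src/main.py`; with `--use_env` the launcher exports the environment variables read below):

```python
import os
import torch

# The launcher sets MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE and LOCAL_RANK
# for every process it spawns; env:// initialisation reads them directly.
torch.distributed.init_process_group(backend='nccl', init_method='env://')

local_rank = int(os.environ['LOCAL_RANK'])  # GPU index on this node
torch.cuda.set_device(local_rank)           # one process per GPU
```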
61 | 62 | The number of CPU threads to use per process is hard coded to `torch.set_num_threads(6)` in `src/main.py`, and can be changed to `your # cpu threads / nproc_per_node` for better performance. ([fabio-deep](https://github.com/fabio-deep/Distributed-Pytorch-Boilerplate)) 63 | 64 | For more info on **multi-node** and **multi-gpu** distributed training refer to https://github.com/hgrover/pytorchdistr/blob/master/README.md 65 | 66 | #### DataParallel 67 | 68 | To train with traditional **nn.DataParallel** with multiple GPUs, use: 69 | 70 | `python main.py --no_distributed` 71 | 72 | **Note:** The default config selects `--no_distributed`, therefore running `python main.py` runs the default hyperparameters without DistributedDataParallel. 73 | 74 | ### Linear Evaluation of a Pre-Trained Model 75 | 76 | To evaluate the performance of a pre-trained model on a linear classification task, include the flag `--finetune` and provide a path to the pretrained model via `--load_checkpoint_dir`. 77 | 78 | Example: 79 | 80 | `python main.py --no_distributed --finetune --load_checkpoint_dir ~/Documents/MoCo-Pytorch/experiments/yyyy-mm-dd_hh-mm-ss/checkpoint.pt` 81 | 82 | ### Hyperparameters 83 | 84 | The configuration / choice of hyperparameters for the script is handled either by command line arguments or config files. 85 | 86 | An example config file is given at `MoCo-Pytorch/config.conf`. Additionally, `.txt` or `.conf` files can be passed if you prefer; this is achieved using the flag `-c`. 87 | 88 | A list of arguments/options can be found below: 89 | 90 | ``` 91 | usage: main.py [-h] [-c MY_CONFIG] [--dataset DATASET] 92 | [--dataset_path DATASET_PATH] [--model MODEL] 93 | [--n_epochs N_EPOCHS] [--finetune_epochs FINETUNE_EPOCHS] 94 | [--warmup_epochs WARMUP_EPOCHS] [--batch_size BATCH_SIZE] 95 | [--learning_rate LEARNING_RATE] [--base_lr BASE_LR] 96 | [--finetune_learning_rate FINETUNE_LEARNING_RATE] 97 | [--weight_decay WEIGHT_DECAY] 98 | [--finetune_weight_decay FINETUNE_WEIGHT_DECAY] 99 | [--optimiser OPTIMISER] [--patience PATIENCE] 100 | [--queue_size QUEUE_SIZE] [--queue_momentum QUEUE_MOMENTUM] 101 | [--temperature TEMPERATURE] [--jitter_d JITTER_D] 102 | [--jitter_p JITTER_P] [--blur_sigma BLUR_SIGMA BLUR_SIGMA] 103 | [--blur_p BLUR_P] [--grey_p GREY_P] [--no_twocrop] 104 | [--load_checkpoint_dir LOAD_CHECKPOINT_DIR] [--no_distributed] 105 | [--finetune] [--supervised] 106 | 107 | Pytorch MocoV2 Args that start with '--' (eg. --dataset) can also be set in a 108 | config file (/MoCo-Pytorch/config.conf or specified via -c). Config 109 | file syntax allows: key=value, flag=true, stuff=[a,b,c] (for details, see 110 | syntax at https://goo.gl/R74nmi). If an arg is specified in more than one 111 | place, then commandline values override config file values which override 112 | defaults. 113 | 114 | optional arguments: 115 | -h, --help show this help message and exit 116 | -c MY_CONFIG, --my-config MY_CONFIG 117 | config file path 118 | --dataset DATASET Dataset, (Options: cifar10, cifar100, stl10, imagenet, 119 | tinyimagenet). 120 | --dataset_path DATASET_PATH 121 | Path to dataset, Not needed for TorchVision Datasets. 122 | --model MODEL Model, (Options: resnet18, resnet34, resnet50, 123 | resnet101, resnet152). 124 | --n_epochs N_EPOCHS Number of Epochs in Contrastive Training. 125 | --finetune_epochs FINETUNE_EPOCHS 126 | Number of Epochs in Linear Classification Training. 127 | --warmup_epochs WARMUP_EPOCHS 128 | Number of Warmup Epochs During Contrastive Training.
129 | --batch_size BATCH_SIZE 130 | Number of Samples Per Batch. 131 | --learning_rate LEARNING_RATE 132 | Starting Learning Rate for Contrastive Training. 133 | --base_lr BASE_LR Base / Minimum Learning Rate to Begin Linear Warmup. 134 | --finetune_learning_rate FINETUNE_LEARNING_RATE 135 | Starting Learning Rate for Linear Classification 136 | Training. 137 | --weight_decay WEIGHT_DECAY 138 | Contrastive Learning Weight Decay Regularisation 139 | Factor. 140 | --finetune_weight_decay FINETUNE_WEIGHT_DECAY 141 | Linear Classification Training Weight Decay 142 | Regularisation Factor. 143 | --optimiser OPTIMISER 144 | Optimiser, (Options: sgd, adam, lars). 145 | --patience PATIENCE Number of Epochs to Wait for Improvement. 146 | --queue_size QUEUE_SIZE 147 | Size of Memory Queue, Must be Divisible by batch_size. 148 | --queue_momentum QUEUE_MOMENTUM 149 | Momentum for the Key Encoder Update. 150 | --temperature TEMPERATURE 151 | InfoNCE Temperature Factor 152 | --jitter_d JITTER_D Distortion Factor for the Random Colour Jitter 153 | Augmentation 154 | --jitter_p JITTER_P Probability to Apply Random Colour Jitter Augmentation 155 | --blur_sigma BLUR_SIGMA BLUR_SIGMA 156 | Sigma Range for the Random Gaussian Blur Augmentation 157 | --blur_p BLUR_P Probability to Apply Gaussian Blur Augmentation 158 | --grey_p GREY_P Probability to Apply Random Grey Scale 159 | --no_twocrop Whether or Not to Use Two Crop Augmentation, Used to 160 | Create Two Views of the Input for Contrastive 161 | Learning. (Default: True) 162 | --load_checkpoint_dir LOAD_CHECKPOINT_DIR 163 | Path to Load Pre-trained Model From. 164 | --no_distributed Whether or Not to Use Distributed Training. (Default: 165 | True) 166 | --finetune Perform Only Linear Classification Training. (Default: 167 | False) 168 | --supervised Perform Supervised Pre-Training. (Default: False) 169 | ``` 170 | 171 | ## Dependencies 172 | 173 | Install dependencies with `requirements.txt` 174 | 175 | `pip install -r requirements.txt` 176 | 177 | ``` 178 | torch 179 | torchvision 180 | tensorboard 181 | tqdm 182 | configargparse 183 | ``` 184 | 185 | ## References 186 | * K. He, et al. [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/abs/1911.05722) 187 | 188 | 189 | * X. Chen, et al. [Improved Baselines with Momentum Contrastive Learning](https://arxiv.org/abs/2003.04297) 190 | 191 | 192 | * facebookresearch [MoCo Code](https://github.com/facebookresearch/moco) 193 | 194 | 195 | * HobbitLong [CMC](https://github.com/HobbitLong/CMC) 196 | 197 | 198 | * T. Chen, et al. [SimCLR Paper](https://arxiv.org/pdf/2002.05709.pdf) 199 | 200 | 201 | * noahgolmant [pytorch-lars](https://github.com/noahgolmant/pytorch-lars) 202 | 203 | 204 | * pytorch [torchvision ResNet](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) 205 | 206 | 207 | * fabio-deep [Distributed-Pytorch-Boilerplate](https://github.com/fabio-deep/Distributed-Pytorch-Boilerplate) 208 | 209 | ## TODO 210 | - [ ] Command Line Argument for MoCo V1 or V2 211 | - [ ] Research and Implement BatchNorm Shuffle for DistributedDataParallel 212 | - [ ] Run All Experiment Comparisons 213 | -------------------------------------------------------------------------------- /config.conf: -------------------------------------------------------------------------------- 1 | # Config File for MoCo 2 | 3 | # Dataset 4 | --dataset=cifar10 # Dataset 5 | --dataset_path=None # Path to dataset, Not needed for TorchVision Datasets.
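# For imagenet / tinyimagenet a dataset path must be supplied; an illustrative (hypothetical) path:
; --dataset_path=/data/imagenet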
6 | 7 | # Model 8 | --model=resnet18 # Model 9 | 10 | # Epochs 11 | --n_epochs=200 # Number of Epochs in Contrastive Training. 12 | --finetune_epochs=100 # Number of Epochs in Linear Classification Training. 13 | --warmup_epochs=10 # Number of Warmup Epochs During Contrastive Training. 14 | 15 | # Core Training Params 16 | --batch_size=128 # Number of Samples Per Batch. 17 | --learning_rate=0.015 # Starting Learning Rate for Contrastive Training. 18 | --base_lr=0.0001 # Base / Minimum Learning Rate to Begin Linear Warmup. 19 | --finetune_learning_rate=10.0 # Starting Learning Rate for Linear Classification 20 | 21 | # Regularisation 22 | --weight_decay=1e-6 # Contrastive Learning Weight Decay 23 | --finetune_weight_decay=0.0 # Linear Classification Training Weight Decay 24 | --patience=100 # Number of Epochs to Wait for Improvement. 25 | 26 | # Optimiser 27 | --optimiser=sgd # Optimiser 28 | 29 | # MoCo Options 30 | --queue_size=65536 # Size of Memory Queue, Must be Divisible by batch_size. 31 | --queue_momentum=0.99 # Momentum for the Key Encoder Update. 32 | --temperature=0.07 # InfoNCE Temperature Factor 33 | 34 | # Augmentation 35 | --jitter_d=0.5 # Distortion Factor for the Random Colour Jitter 36 | --jitter_p=0.8 # Probability to Apply Random Colour Jitter 37 | --blur_sigma=[0.1,2.0] # Sigma Range for the Random Gaussian Blur 38 | --blur_p=0.5 # Probability to Apply Gaussian Blur 39 | --grey_p=0.2 # Probability to Apply Random Grey Scale 40 | ; --no_twocrop # Whether or Not to Use Two Crop Augmentation 41 | 42 | 43 | # Distributed Options 44 | --no_distributed # Whether or Not to Use Distributed Training 45 | 46 | 47 | # Finetune Options 48 | ; --finetune # Perform Only Linear Classification Training 49 | ; --supervised # Perform Supervised Pre-Training 50 | ; --load_checkpoint_dir= # Path to Load Pre-trained Model 51 | -------------------------------------------------------------------------------- /media/CosineAnnealingLR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AidenDurrant/MoCo-Pytorch/373b5fafdbf51e4c6c19a984b7e8d41a296cf6d9/media/CosineAnnealingLR.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | tensorboard 4 | tqdm 5 | ConfigArgParse 6 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import numpy as np 4 | 5 | import torch 6 | from torch.utils.data import DataLoader, WeightedRandomSampler 7 | from torch.utils.data.distributed import DistributedSampler 8 | 9 | import torchvision 10 | from torchvision import transforms 11 | from torchvision.datasets import CIFAR10, MNIST, STL10, ImageNet, CIFAR100, ImageFolder 12 | 13 | from utils import * 14 | 15 | 16 | def get_dataloaders(args): 17 | ''' 18 | Retrieves the dataloaders for the dataset of choice. 19 | 20 | Initialises variables that correspond to the dataset of choice. 21 | 22 | args: 23 | args (dict): Program arguments/commandline arguments. 24 | 25 | returns: 26 | dataloaders (dict): pretrain,train,valid,train_valid,test set split dataloaders. 27 | 28 | args (dict): Updated and Additional program/commandline arguments dependent on dataset.
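        A minimal usage sketch (hypothetical variable names; split keys as listed above):

            dataloaders, args = get_dataloaders(args)
            pretrain_loader = dataloaders['pretrain']
            test_loader = dataloaders['test']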
29 | 30 | ''' 31 | if args.dataset == 'cifar10': 32 | dataset = 'CIFAR10' 33 | 34 | args.class_names = ( 35 | 'plane', 'car', 'bird', 'cat', 36 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck' 37 | ) # 0,1,2,3,4,5,6,7,8,9 labels 38 | 39 | args.crop_dim = 32 40 | args.n_channels, args.n_classes = 3, 10 41 | 42 | # Get and make dir to download dataset to. 43 | working_dir = os.path.join(os.path.split(os.getcwd())[0], 'data', dataset) 44 | 45 | if not os.path.exists(working_dir): 46 | os.makedirs(working_dir) 47 | 48 | dataset_paths = {'train': os.path.join(working_dir, 'train'), 49 | 'test': os.path.join(working_dir, 'test')} 50 | 51 | dataloaders = cifar_dataloader(args, dataset_paths) 52 | 53 | elif args.dataset == 'cifar100': 54 | dataset = 'CIFAR100' 55 | 56 | args.class_names = None 57 | 58 | args.crop_dim = 32 59 | args.n_channels, args.n_classes = 3, 100 60 | 61 | # Get and make dir to download dataset to. 62 | working_dir = os.path.join(os.path.split(os.getcwd())[0], 'data', dataset) 63 | 64 | if not os.path.exists(working_dir): 65 | os.makedirs(working_dir) 66 | 67 | dataset_paths = {'train': os.path.join(working_dir, 'train'), 68 | 'test': os.path.join(working_dir, 'test')} 69 | 70 | dataloaders = cifar_dataloader(args, dataset_paths) 71 | 72 | elif args.dataset == 'stl10': 73 | dataset = 'STL10' 74 | 75 | args.class_names = None 76 | 77 | args.crop_dim = 96 78 | args.n_channels, args.n_classes = 3, 10 79 | 80 | # Get and make dir to download dataset to. 81 | working_dir = os.path.join(os.path.split(os.getcwd())[0], 'data', dataset) 82 | 83 | if not os.path.exists(working_dir): 84 | os.makedirs(working_dir) 85 | 86 | dataset_paths = {'train': os.path.join(working_dir, 'train'), 87 | 'test': os.path.join(working_dir, 'test'), 88 | 'pretrain': os.path.join(working_dir, 'unlabeled')} 89 | 90 | dataloaders = stl10_dataloader(args, dataset_paths) 91 | 92 | elif args.dataset == 'imagenet': 93 | dataset = 'ImageNet' 94 | 95 | args.class_names = None 96 | 97 | args.crop_dim = 224 98 | args.n_channels, args.n_classes = 3, 1000 99 | 100 | # Get the user-supplied dataset path. 101 | target_dir = args.dataset_path 102 | 103 | if target_dir is not None: 104 | dataset_paths = {'train': os.path.join(target_dir, 'train'), 105 | 'test': os.path.join(target_dir, 'val')} 106 | 107 | dataloaders = imagenet_dataloader(args, dataset_paths) 108 | 109 | else: 110 | raise NotImplementedError('Please Select a path for the {} Dataset.'.format(args.dataset)) 111 | 112 | elif args.dataset == 'tinyimagenet': 113 | dataset = 'TinyImageNet' 114 | 115 | args.class_names = None 116 | 117 | args.crop_dim = 64 118 | args.n_channels, args.n_classes = 3, 200 119 | 120 | # Get the user-supplied dataset path. 121 | target_dir = args.dataset_path 122 | 123 | if target_dir is not None: 124 | dataset_paths = {'train': os.path.join(target_dir, 'train'), 125 | 'test': os.path.join(target_dir, 'val')} 126 | 127 | dataloaders = imagenet_dataloader(args, dataset_paths) 128 | 129 | else: 130 | raise NotImplementedError('Please Select a path for the {} Dataset.'.format(args.dataset)) 131 | else: 132 | raise NotImplementedError('{} dataset not available.'.format(args.dataset)) 133 | 134 | return dataloaders, args 135 | 136 | 137 | def imagenet_dataloader(args, dataset_paths): 138 | ''' 139 | Loads the ImageNet or TinyImageNet dataset performing augmentations. 140 | 141 | Generates splits of the training set to produce a validation set. 142 | 143 | args: 144 | args (dict): Program/commandline arguments.
145 | 146 | dataset_paths (dict): Paths to each dataset split. 147 | 148 | Returns: 149 | 150 | dataloaders (dict): pretrain,train,valid,train_valid,test set split dataloaders. 151 | ''' 152 | 153 | # gaussian_blur from https://github.com/facebookresearch/moco/ 154 | gaussian_blur = transforms.RandomApply([GaussianBlur(args.blur_sigma)], p=args.blur_p) 155 | 156 | color_jitter = transforms.ColorJitter( 157 | 0.8*args.jitter_d, 0.8*args.jitter_d, 0.8*args.jitter_d, 0.2*args.jitter_d) 158 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=args.jitter_p) 159 | 160 | rnd_grey = transforms.RandomGrayscale(p=args.grey_p) 161 | 162 | # Base train and test augmentations 163 | transf = { 164 | 'train': transforms.Compose([ 165 | transforms.RandomResizedCrop((args.crop_dim, args.crop_dim)), 166 | rnd_color_jitter, 167 | rnd_grey, 168 | gaussian_blur, 169 | transforms.RandomHorizontalFlip(), 170 | transforms.ToTensor(), 171 | transforms.Normalize((0.485, 0.456, 0.406), 172 | (0.229, 0.224, 0.225))]), 173 | 'test': transforms.Compose([ 174 | transforms.CenterCrop((args.crop_dim, args.crop_dim)), 175 | transforms.ToTensor(), 176 | transforms.Normalize((0.485, 0.456, 0.406), 177 | (0.229, 0.224, 0.225))]) 178 | } 179 | 180 | config = {'train': True, 'test': False} 181 | 182 | datasets = {i: ImageFolder(root=dataset_paths[i]) for i in config.keys()} 183 | 184 | # weighted sampler weights for full(f) training set 185 | f_s_weights = sample_weights(datasets['train'].targets) 186 | 187 | # return data, labels dicts for new train set and class-balanced valid set 188 | # 50 is the num of samples to be split into the val set for each class (1000) 189 | data, labels = random_split_image_folder(data=np.asarray(datasets['train'].samples), 190 | labels=datasets['train'].targets, 191 | n_classes=args.n_classes, 192 | n_samples_per_class=np.repeat(50, args.n_classes).reshape(-1)) 193 | 194 | # torch.from_numpy(np.stack(labels)) takes the list of class ids and turns them into a long tensor 195 | 196 | # original full training set 197 | datasets['train_valid'] = CustomDataset(data=np.asarray(datasets['train'].samples), 198 | labels=torch.from_numpy(np.stack(datasets['train'].targets)), transform=transf['train'], two_crop=args.twocrop) 199 | 200 | # original test set 201 | datasets['test'] = CustomDataset(data=np.asarray(datasets['test'].samples), 202 | labels=torch.from_numpy(np.stack(datasets['test'].targets)), transform=transf['test'], two_crop=False) 203 | 204 | # make new pretraining set without validation samples 205 | datasets['pretrain'] = CustomDataset(data=np.asarray(data['train']), 206 | labels=labels['train'], transform=transf['train'], two_crop=args.twocrop) 207 | 208 | # make new finetuning set without validation samples 209 | datasets['train'] = CustomDataset(data=np.asarray(data['train']), 210 | labels=labels['train'], transform=transf['train'], two_crop=False) 211 | 212 | # make class balanced validation set for finetuning 213 | datasets['valid'] = CustomDataset(data=np.asarray(data['valid']), 214 | labels=labels['valid'], transform=transf['test'], two_crop=False) 215 | 216 | # weighted sampler weights for new training set 217 | s_weights = sample_weights(datasets['pretrain'].labels) 218 | 219 | config = { 220 | 'pretrain': WeightedRandomSampler(s_weights, 221 | num_samples=len(s_weights), replacement=True), 222 | 'train': WeightedRandomSampler(s_weights, 223 | num_samples=len(s_weights), replacement=True), 224 | 'train_valid': WeightedRandomSampler(f_s_weights, 225 | 
num_samples=len(f_s_weights), replacement=True), 226 | 'valid': None, 'test': None 227 | } 228 | 229 | if args.distributed: 230 | config = {'pretrain': DistributedSampler(datasets['pretrain']), 231 | 'train': DistributedSampler(datasets['train']), 232 | 'train_valid': DistributedSampler(datasets['train_valid']), 233 | 'valid': None, 'test': None} 234 | 235 | dataloaders = {i: DataLoader(datasets[i], sampler=config[i], 236 | num_workers=8, pin_memory=True, drop_last=True, 237 | batch_size=args.batch_size) for i in config.keys()} 238 | 239 | return dataloaders 240 | 241 | 242 | def stl10_dataloader(args, dataset_paths): 243 | ''' 244 | Loads the STL10 dataset performing augmentations. 245 | 246 | Generates splits of the training set to produce a validation set. 247 | 248 | args: 249 | args (dict): Program/commandline arguments. 250 | 251 | dataset_paths (dict): Paths to each dataset split. 252 | 253 | Returns: 254 | 255 | dataloaders (dict): pretrain,train,valid,train_valid,test set split dataloaders. 256 | ''' 257 | 258 | # gaussian_blur from https://github.com/facebookresearch/moco/ 259 | gaussian_blur = transforms.RandomApply([GaussianBlur(args.blur_sigma)], p=args.blur_p) 260 | 261 | color_jitter = transforms.ColorJitter( 262 | 0.8*args.jitter_d, 0.8*args.jitter_d, 0.8*args.jitter_d, 0.2*args.jitter_d) 263 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=args.jitter_p) 264 | 265 | rnd_grey = transforms.RandomGrayscale(p=args.grey_p) 266 | 267 | # Base train and test augmentations 268 | transf = { 269 | 'train': transforms.Compose([ 270 | transforms.ToPILImage(), 271 | rnd_color_jitter, 272 | rnd_grey, 273 | gaussian_blur, 274 | transforms.RandomResizedCrop((args.crop_dim, args.crop_dim)), 275 | transforms.RandomHorizontalFlip(), 276 | transforms.ToTensor(), 277 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091), 278 | (0.24703223, 0.24348513, 0.26158784))]), 279 | 'valid': transforms.Compose([ 280 | transforms.ToPILImage(), 281 | transforms.ToTensor(), 282 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091), 283 | (0.24703223, 0.24348513, 0.26158784))]), 284 | 'test': transforms.Compose([ 285 | transforms.ToTensor(), 286 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091), 287 | (0.24703223, 0.24348513, 0.26158784))]) 288 | } 289 | 290 | transf['pretrain'] = transf['train'] 291 | 292 | config = {'train': 'train', 'test': 'test', 'pretrain': 'unlabeled'} 293 | 294 | datasets = {i: STL10(root=dataset_paths[i], transform=transf[i], 295 | split=config[i], download=True) for i in config.keys()} 296 | 297 | # weighted sampler weights for full(f) training set 298 | f_s_weights = sample_weights(datasets['train'].labels) 299 | 300 | # return data, labels dicts for new train set and class-balanced valid set 301 | # 50 is the num of samples to be split into the val set for each class (10) 302 | data, labels = random_split(data=datasets['train'].data, 303 | labels=datasets['train'].labels, 304 | n_classes=args.n_classes, 305 | n_samples_per_class=np.repeat(50, args.n_classes).reshape(-1)) 306 | 307 | # save original full training set 308 | datasets['train_valid'] = datasets['train'] 309 | 310 | # make new pretraining set without validation samples 311 | datasets['pretrain'] = CustomDataset(data=datasets['pretrain'].data, 312 | labels=None, transform=transf['pretrain'], two_crop=args.twocrop) 313 | 314 | # make new finetuning set without validation samples 315 | datasets['train'] = CustomDataset(data=data['train'], 316 | labels=labels['train'], 
transform=transf['train'], two_crop=False) 317 | 318 | # make class balanced validation set for finetuning 319 | datasets['valid'] = CustomDataset(data=data['valid'], 320 | labels=labels['valid'], transform=transf['valid'], two_crop=False) 321 | 322 | # weighted sampler weights for new training set 323 | s_weights = sample_weights(datasets['train'].labels) 324 | 325 | config = { 326 | 'pretrain': None, 327 | 'train': WeightedRandomSampler(s_weights, 328 | num_samples=len(s_weights), replacement=True), 329 | 'train_valid': WeightedRandomSampler(f_s_weights, 330 | num_samples=len(f_s_weights), replacement=True), 331 | 'valid': None, 'test': None 332 | } 333 | 334 | if args.distributed: 335 | config = {'pretrain': DistributedSampler(datasets['pretrain']), 336 | 'train': DistributedSampler(datasets['train']), 337 | 'train_valid': DistributedSampler(datasets['train_valid']), 338 | 'valid': None, 'test': None} 339 | 340 | dataloaders = {i: DataLoader(datasets[i], sampler=config[i], 341 | num_workers=8, pin_memory=True, drop_last=True, 342 | batch_size=args.batch_size) for i in config.keys()} 343 | 344 | return dataloaders 345 | 346 | 347 | def cifar_dataloader(args, dataset_paths): 348 | ''' 349 | Loads the CIFAR10 or CIFAR100 dataset performing augmentations. 350 | 351 | Generates splits of the training set to produce a validation set. 352 | 353 | args: 354 | args (dict): Program/commandline arguments. 355 | 356 | dataset_paths (dict): Paths to each dataset split. 357 | 358 | Returns: 359 | 360 | dataloaders (dict): pretrain,train,valid,train_valid,test set split dataloaders. 361 | ''' 362 | 363 | color_jitter = transforms.ColorJitter( 364 | 0.8*args.jitter_d, 0.8*args.jitter_d, 0.8*args.jitter_d, 0.2*args.jitter_d) 365 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=args.jitter_p) 366 | 367 | rnd_grey = transforms.RandomGrayscale(p=args.grey_p) 368 | 369 | # Base train and test augmentations 370 | transf = { 371 | 'train': transforms.Compose([ 372 | transforms.ToPILImage(), 373 | rnd_color_jitter, 374 | rnd_grey, 375 | transforms.RandomResizedCrop((args.crop_dim, args.crop_dim), scale=(0.25, 1.0)), 376 | transforms.RandomHorizontalFlip(), 377 | transforms.ToTensor(), 378 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091), 379 | (0.24703223, 0.24348513, 0.26158784))]), 380 | 'pretrain': transforms.Compose([ 381 | transforms.ToPILImage(), 382 | rnd_color_jitter, 383 | rnd_grey, 384 | transforms.RandomResizedCrop((args.crop_dim, args.crop_dim)), 385 | transforms.RandomHorizontalFlip(), 386 | transforms.ToTensor(), 387 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091), 388 | (0.24703223, 0.24348513, 0.26158784))]), 389 | 'test': transforms.Compose([ 390 | transforms.ToTensor(), 391 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091), 392 | (0.24703223, 0.24348513, 0.26158784))]) 393 | } 394 | 395 | config = {'train': True, 'test': False} 396 | 397 | if args.dataset == 'cifar10': 398 | 399 | datasets = {i: CIFAR10(root=dataset_paths[i], transform=transf[i], 400 | train=config[i], download=True) for i in config.keys()} 401 | val_samples = 500 402 | 403 | elif args.dataset == 'cifar100': 404 | 405 | datasets = {i: CIFAR100(root=dataset_paths[i], transform=transf[i], 406 | train=config[i], download=True) for i in config.keys()} 407 | 408 | val_samples = 100 409 | 410 | # weighted sampler weights for full(f) training set 411 | f_s_weights = sample_weights(datasets['train'].targets) 412 | 413 | # return data, labels dicts for new train set and class-balanced valid set
414 | # val_samples is the num of samples to be split into the val set for each class (500 for cifar10, 100 for cifar100) 415 | data, labels = random_split(data=datasets['train'].data, 416 | labels=datasets['train'].targets, 417 | n_classes=args.n_classes, 418 | n_samples_per_class=np.repeat(val_samples, args.n_classes).reshape(-1)) 419 | 420 | # save original full training set 421 | datasets['train_valid'] = datasets['train'] 422 | 423 | # make new pretraining set without validation samples 424 | datasets['pretrain'] = CustomDataset(data=data['train'], 425 | labels=labels['train'], transform=transf['pretrain'], two_crop=args.twocrop) 426 | 427 | # make new finetuning set without validation samples 428 | datasets['train'] = CustomDataset(data=data['train'], 429 | labels=labels['train'], transform=transf['train'], two_crop=False) 430 | 431 | # make class balanced validation set for finetuning 432 | datasets['valid'] = CustomDataset(data=data['valid'], 433 | labels=labels['valid'], transform=transf['test'], two_crop=False) 434 | 435 | # weighted sampler weights for new training set 436 | s_weights = sample_weights(datasets['pretrain'].labels) 437 | 438 | config = { 439 | 'pretrain': WeightedRandomSampler(s_weights, 440 | num_samples=len(s_weights), replacement=True), 441 | 'train': WeightedRandomSampler(s_weights, 442 | num_samples=len(s_weights), replacement=True), 443 | 'train_valid': WeightedRandomSampler(f_s_weights, 444 | num_samples=len(f_s_weights), replacement=True), 445 | 'valid': None, 'test': None 446 | } 447 | 448 | if args.distributed: 449 | config = {'pretrain': DistributedSampler(datasets['pretrain']), 450 | 'train': DistributedSampler(datasets['train']), 451 | 'train_valid': DistributedSampler(datasets['train_valid']), 452 | 'valid': None, 'test': None} 453 | 454 | dataloaders = {i: DataLoader(datasets[i], sampler=config[i], 455 | num_workers=8, pin_memory=True, drop_last=True, 456 | batch_size=args.batch_size) for i in config.keys()} 457 | 458 | return dataloaders 459 | -------------------------------------------------------------------------------- /src/kill_zombie.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | kill $(ps aux | grep "main.py" | grep -v grep | awk '{print $2}') 4 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import logging 5 | import random 6 | import configargparse 7 | import warnings 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn.parallel import DistributedDataParallel 13 | 14 | from train import finetune, evaluate, pretrain, supervised 15 | from datasets import get_dataloaders 16 | from utils import * 17 | import model.network as models 18 | from model.moco import MoCo_Model 19 | 20 | 21 | warnings.filterwarnings("ignore") 22 | 23 | 24 | default_config = os.path.join(os.path.split(os.getcwd())[0], 'config.conf') 25 | 26 | parser = configargparse.ArgumentParser( 27 | description='Pytorch MocoV2', default_config_files=[default_config]) 28 | parser.add_argument('-c', '--my-config', required=False, 29 | is_config_file=True, help='config file path') 30 | parser.add_argument('--dataset', default='cifar10', 31 | help='Dataset, (Options: cifar10, cifar100, stl10, imagenet, tinyimagenet).') 32 | parser.add_argument('--dataset_path', default=None, 33 | help='Path to dataset, Not needed for
TorchVision Datasets.') 34 | parser.add_argument('--model', default='resnet18', 35 | help='Model, (Options: resnet18, resnet34, resnet50, resnet101, resnet152).') 36 | parser.add_argument('--n_epochs', type=int, default=1000, 37 | help='Number of Epochs in Contrastive Training.') 38 | parser.add_argument('--finetune_epochs', type=int, default=100, 39 | help='Number of Epochs in Linear Classification Training.') 40 | parser.add_argument('--warmup_epochs', type=int, default=10, 41 | help='Number of Warmup Epochs During Contrastive Training.') 42 | parser.add_argument('--batch_size', type=int, default=256, 43 | help='Number of Samples Per Batch.') 44 | parser.add_argument('--learning_rate', type=float, default=1.0, 45 | help='Starting Learning Rate for Contrastive Training.') 46 | parser.add_argument('--base_lr', type=float, default=0.0001, 47 | help='Base / Minimum Learning Rate to Begin Linear Warmup.') 48 | parser.add_argument('--finetune_learning_rate', type=float, default=0.1, 49 | help='Starting Learning Rate for Linear Classification Training.') 50 | parser.add_argument('--weight_decay', type=float, default=1e-6, 51 | help='Contrastive Learning Weight Decay Regularisation Factor.') 52 | parser.add_argument('--finetune_weight_decay', type=float, default=0.0, 53 | help='Linear Classification Training Weight Decay Regularisation Factor.') 54 | parser.add_argument('--optimiser', default='sgd', 55 | help='Optimiser, (Options: sgd, adam, lars).') 56 | parser.add_argument('--patience', default=50, type=int, 57 | help='Number of Epochs to Wait for Improvement.') 58 | parser.add_argument('--queue_size', type=int, default=65536, 59 | help='Size of Memory Queue, Must be Divisible by batch_size.') 60 | parser.add_argument('--queue_momentum', type=float, default=0.999, 61 | help='Momentum for the Key Encoder Update.') 62 | parser.add_argument('--temperature', type=float, default=0.07, 63 | help='InfoNCE Temperature Factor') 64 | parser.add_argument('--jitter_d', type=float, default=1.0, 65 | help='Distortion Factor for the Random Colour Jitter Augmentation') 66 | parser.add_argument('--jitter_p', type=float, default=0.8, 67 | help='Probability to Apply Random Colour Jitter Augmentation') 68 | parser.add_argument('--blur_sigma', nargs=2, type=float, default=[0.1, 2.0], 69 | help='Sigma Range for the Random Gaussian Blur Augmentation') 70 | parser.add_argument('--blur_p', type=float, default=0.5, 71 | help='Probability to Apply Gaussian Blur Augmentation') 72 | parser.add_argument('--grey_p', type=float, default=0.2, 73 | help='Probability to Apply Random Grey Scale') 74 | parser.add_argument('--no_twocrop', dest='twocrop', action='store_false', 75 | help='Whether or Not to Use Two Crop Augmentation, Used to Create Two Views of the Input for Contrastive Learning. (Default: True)') 76 | parser.set_defaults(twocrop=True) 77 | parser.add_argument('--load_checkpoint_dir', default=None, 78 | help='Path to Load Pre-trained Model From.') 79 | parser.add_argument('--no_distributed', dest='distributed', action='store_false', 80 | help='Whether or Not to Use Distributed Training. (Default: True)') 81 | parser.set_defaults(distributed=True) 82 | parser.add_argument('--finetune', dest='finetune', action='store_true', 83 | help='Perform Only Linear Classification Training. (Default: False)') 84 | parser.set_defaults(finetune=False) 85 | parser.add_argument('--supervised', dest='supervised', action='store_true', 86 | help='Perform Supervised Pre-Training. 
(Default: False)') 87 | parser.set_defaults(supervised=False) 88 | 89 | 90 | def setup(distributed): 91 | """ Sets up for optional distributed training. 92 | For distributed training run as: 93 | python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=2 --use_env main.py 94 | To kill zombie processes use: 95 | kill $(ps aux | grep "main.py" | grep -v grep | awk '{print $2}') 96 | For data parallel training on GPUs or CPU training run as: 97 | python main.py --no_distributed 98 | 99 | Taken from https://github.com/fabio-deep/Distributed-Pytorch-Boilerplate 100 | 101 | args: 102 | distributed (bool): Flag whether or not to perform distributed training. 103 | 104 | returns: 105 | local_rank (int): rank of local machine / host to perform distributed training. 106 | 107 | device (string): Device and rank of device to perform training on. 108 | 109 | """ 110 | if distributed: 111 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 112 | local_rank = int(os.environ.get('LOCAL_RANK')) 113 | device = torch.device(f'cuda:{local_rank}') # unique on individual node 114 | 115 | print('World size: {} ; Rank: {} ; LocalRank: {} ; Master: {}:{}'.format( 116 | os.environ.get('WORLD_SIZE'), 117 | os.environ.get('RANK'), 118 | os.environ.get('LOCAL_RANK'), 119 | os.environ.get('MASTER_ADDR'), os.environ.get('MASTER_PORT'))) 120 | else: 121 | local_rank = None 122 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 123 | 124 | seed = 44 125 | random.seed(seed) 126 | np.random.seed(seed) 127 | torch.manual_seed(seed) 128 | torch.cuda.manual_seed(seed) 129 | torch.cuda.manual_seed_all(seed) 130 | 131 | torch.backends.cudnn.enabled = True 132 | torch.backends.cudnn.deterministic = True 133 | torch.backends.cudnn.benchmark = False # True 134 | 135 | return device, local_rank 136 | 137 | 138 | def main(): 139 | """ Main """ 140 | 141 | # Arguments 142 | args = parser.parse_args() 143 | 144 | # Setup Distributed Training 145 | device, local_rank = setup(distributed=args.distributed) 146 | 147 | # Get Dataloaders for Dataset of choice 148 | dataloaders, args = get_dataloaders(args) 149 | 150 | # Setup logging, saving models, summaries 151 | args = experiment_config(parser, args) 152 | 153 | ''' Base Encoder ''' 154 | 155 | # Get available models from /model/network.py 156 | model_names = sorted(name for name in models.__dict__ 157 | if name.islower() and not name.startswith("__") 158 | and callable(models.__dict__[name])) 159 | 160 | # If model exists 161 | if any(args.model in model_name for model_name in model_names): 162 | # Load model 163 | base_encoder = getattr(models, args.model)( 164 | args, num_classes=args.n_classes) # Encoder 165 | 166 | else: 167 | raise NotImplementedError("Model Not Implemented: {}".format(args.model)) 168 | 169 | if not args.supervised: 170 | # freeze all layers but the last fc 171 | for name, param in base_encoder.named_parameters(): 172 | if name not in ['fc.weight', 'fc.bias']: 173 | param.requires_grad = False 174 | # init the fc layer 175 | init_weights(base_encoder) 176 | 177 | ''' MoCo Model ''' 178 | moco = MoCo_Model(args, queue_size=args.queue_size, 179 | momentum=args.queue_momentum, temperature=args.temperature) 180 | 181 | # Place model onto GPU(s) 182 | if args.distributed: 183 | torch.cuda.set_device(device) 184 | torch.set_num_threads(6) # n cpu threads / n processes per node 185 | 186 | moco = DistributedDataParallel(moco.cuda(), 187 | device_ids=[local_rank], output_device=local_rank, 
find_unused_parameters=True, broadcast_buffers=False) 188 | base_encoder = DistributedDataParallel(base_encoder.cuda(), 189 | device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True, broadcast_buffers=False) 190 | 191 | # Only print from process (rank) 0 192 | args.print_progress = True if int(os.environ.get('RANK')) == 0 else False 193 | else: 194 | # If non Distributed use DataParallel 195 | if torch.cuda.device_count() > 1: 196 | moco = nn.DataParallel(moco) 197 | base_encoder = nn.DataParallel(base_encoder) 198 | 199 | print('\nUsing', torch.cuda.device_count(), 'GPU(s).\n') 200 | 201 | moco.to(device) 202 | base_encoder.to(device) 203 | 204 | args.print_progress = True 205 | 206 | # Print Network Structure and Params 207 | if args.print_progress: 208 | print_network(moco, args) # prints out the network architecture etc 209 | logging.info('\npretrain/train: {} - valid: {} - test: {}'.format( 210 | len(dataloaders['train'].dataset), len(dataloaders['valid'].dataset), 211 | len(dataloaders['test'].dataset))) 212 | 213 | # launch model training or inference 214 | if not args.finetune: 215 | 216 | ''' Pretraining / Finetuning / Evaluate ''' 217 | 218 | if not args.supervised: 219 | # Pretrain the encoder and projection head 220 | pretrain(moco, dataloaders, args) 221 | 222 | # Load the state_dict from query encoder and load it on finetune net 223 | base_encoder = load_moco(base_encoder, args) 224 | 225 | else: 226 | supervised(base_encoder, dataloaders, args) 227 | 228 | # Load the state_dict from query encoder and load it on finetune net 229 | base_encoder = load_sup(base_encoder, args) 230 | 231 | # Supervised Finetuning of the supervised classification head 232 | finetune(base_encoder, dataloaders, args) 233 | 234 | # Evaluate the pretrained model and trained supervised head 235 | test_loss, test_acc, test_acc_top5 = evaluate( 236 | base_encoder, dataloaders, 'test', args.finetune_epochs, args) 237 | 238 | print('[Test] loss {:.4f} - acc {:.4f} - acc_top5 {:.4f}'.format( 239 | test_loss, test_acc, test_acc_top5)) 240 | 241 | if args.distributed: # cleanup 242 | torch.distributed.destroy_process_group() 243 | else: 244 | 245 | ''' Finetuning / Evaluate ''' 246 | 247 | # Do not Pretrain, just finetune and inference 248 | # Load the state_dict from query encoder and load it on finetune net 249 | base_encoder = load_moco(base_encoder, args) 250 | 251 | # Supervised Finetuning of the supervised classification head 252 | finetune(base_encoder, dataloaders, args) 253 | 254 | # Evaluate the pretrained model and trained supervised head 255 | test_loss, test_acc, test_acc_top5 = evaluate( 256 | base_encoder, dataloaders, 'test', args.finetune_epochs, args) 257 | 258 | print('[Test] loss {:.4f} - acc {:.4f} - acc_top5 {:.4f}'.format( 259 | test_loss, test_acc, test_acc_top5)) 260 | 261 | if args.distributed: # cleanup 262 | torch.distributed.destroy_process_group() 263 | 264 | 265 | if __name__ == '__main__': 266 | main() 267 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AidenDurrant/MoCo-Pytorch/373b5fafdbf51e4c6c19a984b7e8d41a296cf6d9/src/model/__init__.py -------------------------------------------------------------------------------- /src/model/moco.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn 
as nn 4 | import torch.nn.functional as F 5 | 6 | import model.network as models 7 | 8 | 9 | class MoCo_Model(nn.Module): 10 | def __init__(self, args, queue_size=65536, momentum=0.999, temperature=0.07): 11 | ''' 12 | MoCoV2 model, taken from: https://github.com/facebookresearch/moco. 13 | 14 | Adapted for use in personal Boilerplate for unsupervised/self-supervised contrastive learning. 15 | 16 | Additionally, took inspiration from: https://github.com/HobbitLong/CMC. 17 | 18 | Args: 19 | init: 20 | args (dict): Program arguments/commandline arguments. 21 | 22 | queue_size (int): Length of the queue/memory, number of samples to store in memory. (default: 65536) 23 | 24 | momentum (float): Momentum value for updating the key_encoder. (default: 0.999) 25 | 26 | temperature (float): Temperature used in the InfoNCE / NT_Xent contrastive losses. (default: 0.07) 27 | 28 | forward: 29 | x_q (Tensor): Representation of the view intended for the query_encoder. 30 | 31 | x_k (Tensor): Representation of the view intended for the key_encoder. 32 | 33 | returns: 34 | 35 | logit (Tensor): Positive and negative logits computed by the InfoNCE loss. (bsz, queue_size + 1) 36 | 37 | label (Tensor): Labels of the positive and negative logits to be used in softmax cross entropy. (bsz, 1) 38 | 39 | ''' 40 | super(MoCo_Model, self).__init__() 41 | 42 | self.queue_size = queue_size 43 | self.momentum = momentum 44 | self.temperature = temperature 45 | 46 | assert self.queue_size % args.batch_size == 0 # for simplicity 47 | 48 | # Load model 49 | self.encoder_q = getattr(models, args.model)( 50 | args, num_classes=128) # Query Encoder 51 | 52 | self.encoder_k = getattr(models, args.model)( 53 | args, num_classes=128) # Key Encoder 54 | 55 | # Add the mlp head 56 | self.encoder_q.fc = models.projection_MLP(args) 57 | self.encoder_k.fc = models.projection_MLP(args) 58 | 59 | # Initialize the key encoder to have the same values as query encoder 60 | # Do not update the key encoder via gradient 61 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 62 | param_k.data.copy_(param_q.data) 63 | param_k.requires_grad = False 64 | 65 | # Create the queue to store negative samples 66 | self.register_buffer("queue", torch.randn(self.queue_size, 128)) 67 | 68 | # Create pointer to store current position in the queue when enqueue and dequeue 69 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 70 | 71 | @torch.no_grad() 72 | def momentum_update(self): 73 | ''' 74 | Update the key_encoder parameters through the momentum update: 75 | 76 | 77 | key_params = momentum * key_params + (1 - momentum) * query_params 78 | 79 | ''' 80 | 81 | # For each of the parameters in each encoder 82 | for p_q, p_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 83 | p_k.data = p_k.data * self.momentum + p_q.detach().data * (1. - self.momentum) 84 | 85 | @torch.no_grad() 86 | def shuffled_idx(self, batch_size): 87 | ''' 88 | Generation of the shuffled indexes for the implementation of ShuffleBN. 89 | 90 | https://github.com/HobbitLong/CMC. 
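        Shuffling exists to counter BatchNorm information leakage: encoded in their
        original batch order, the keys' per-batch BN statistics provide a shortcut
        signal that lets the network cheat the contrastive task, so the key minibatch
        is shuffled before the key encoder and un-shuffled afterwards (see the MoCo paper).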
91 | 92 | args: 93 | batch_size (Tensor.int()): Number of samples in a batch 94 | 95 | returns: 96 | shuffled_idxs (Tensor.long()): A random permutation index order for the shuffling of the current minibatch 97 | 98 | reverse_idxs (Tensor.long()): A reverse of the random permutation index order for the shuffling of the 99 | current minibatch to get back original sample order 100 | 101 | ''' 102 | 103 | # Generate shuffled indexes 104 | shuffled_idxs = torch.randperm(batch_size).long().cuda() 105 | 106 | reverse_idxs = torch.zeros(batch_size).long().cuda() 107 | 108 | value = torch.arange(batch_size).long().cuda() 109 | 110 | reverse_idxs.index_copy_(0, shuffled_idxs, value) 111 | 112 | return shuffled_idxs, reverse_idxs 113 | 114 | @torch.no_grad() 115 | def update_queue(self, feat_k): 116 | ''' 117 | Update the memory / queue. 118 | 119 | Add batch to end of most recent sample index and remove the oldest samples in the queue. 120 | 121 | Store location of most recent sample index (ptr). 122 | 123 | Taken from: https://github.com/facebookresearch/moco 124 | 125 | args: 126 | feat_k (Tensor): Feature representations of the view x_k computed by the key_encoder. 127 | ''' 128 | 129 | batch_size = feat_k.size(0) 130 | 131 | ptr = int(self.queue_ptr) 132 | 133 | # replace the keys at ptr (dequeue and enqueue) 134 | self.queue[ptr:ptr + batch_size, :] = feat_k 135 | 136 | # move pointer along to end of current batch 137 | ptr = (ptr + batch_size) % self.queue_size 138 | 139 | # Store queue pointer as register_buffer 140 | self.queue_ptr[0] = ptr 141 | 142 | def InfoNCE_logits(self, f_q, f_k): 143 | ''' 144 | Compute the similarity logits between the positive 145 | pair, and between the query and all negatives in the memory. 146 | 147 | args: 148 | f_q (Tensor): Feature representations of the view x_q computed by the query_encoder. 149 | 150 | f_k (Tensor): Feature representations of the view x_k computed by the key_encoder. 151 | 152 | returns: 153 | logit (Tensor): Positive and negative logits computed by the InfoNCE loss. (bsz, queue_size + 1) 154 | 155 | label (Tensor): Labels of the positive and negative logits to be used in softmax cross entropy. 
(bsz, 1) 156 | ''' 157 | 158 | f_k = f_k.detach() 159 | 160 | # Get queue from register_buffer 161 | f_mem = self.queue.clone().detach() 162 | 163 | # Normalize the feature representations 164 | f_q = nn.functional.normalize(f_q, dim=1) 165 | f_k = nn.functional.normalize(f_k, dim=1) 166 | f_mem = nn.functional.normalize(f_mem, dim=1) 167 | 168 | # Compute sim between positive views 169 | pos = torch.bmm(f_q.view(f_q.size(0), 1, -1), 170 | f_k.view(f_k.size(0), -1, 1)).squeeze(-1) 171 | 172 | # Compute sim between positive and all negatives in the memory 173 | neg = torch.mm(f_q, f_mem.transpose(1, 0)) 174 | 175 | logits = torch.cat((pos, neg), dim=1) 176 | 177 | logits /= self.temperature 178 | 179 | # Create labels, first logit is positive, all others are negative 180 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 181 | 182 | return logits, labels 183 | 184 | def forward(self, x_q, x_k): 185 | 186 | batch_size = x_q.size(0) 187 | 188 | # Feature representations of the query view from the query encoder 189 | feat_q = self.encoder_q(x_q) 190 | 191 | # TODO: shuffle ids with distributed data parallel 192 | # Get shuffled and reversed indexes for the current minibatch 193 | shuffled_idxs, reverse_idxs = self.shuffled_idx(batch_size) 194 | 195 | with torch.no_grad(): 196 | # Update the key encoder 197 | self.momentum_update() 198 | 199 | # Shuffle minibatch 200 | x_k = x_k[shuffled_idxs] 201 | 202 | # Feature representations of the shuffled key view from the key encoder 203 | feat_k = self.encoder_k(x_k) 204 | 205 | # reverse the shuffled samples to original position 206 | feat_k = feat_k[reverse_idxs] 207 | 208 | # Compute the logits for the InfoNCE contrastive loss. 209 | logit, label = self.InfoNCE_logits(feat_q, feat_k) 210 | 211 | # Update the queue/memory with the current key_encoder minibatch. 
212 | self.update_queue(feat_k) 213 | 214 | return logit, label 215 | -------------------------------------------------------------------------------- /src/model/network.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 7 | 'wide_resnet50_2', 'wide_resnet101_2', 'projection_MLP'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 19 | } 20 | 21 | '''Resnet Class 22 | 23 | Taken from: 24 | 25 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 26 | ''' 27 | 28 | 29 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 30 | """3x3 convolution with padding""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 32 | padding=dilation, groups=groups, bias=False, dilation=dilation) 33 | 34 | 35 | def conv1x1(in_planes, out_planes, stride=1): 36 | """1x1 convolution""" 37 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | expansion = 1 42 | 43 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 44 | base_width=64, dilation=1, norm_layer=None): 45 | super(BasicBlock, self).__init__() 46 | if norm_layer is None: 47 | norm_layer = nn.BatchNorm2d 48 | if groups != 1 or base_width != 64: 49 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 50 | if dilation > 1: 51 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 52 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 53 | self.conv1 = conv3x3(inplanes, planes, stride) 54 | self.bn1 = norm_layer(planes) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.conv2 = conv3x3(planes, planes) 57 | self.bn2 = norm_layer(planes) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | identity = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | 71 | if self.downsample is not None: 72 | identity = self.downsample(x) 73 | 74 | out += identity 75 | out = self.relu(out) 76 | 77 | return out 78 | 79 | 80 | class Bottleneck(nn.Module): 81 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 82 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 83 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 
84 | # This variant is also known as ResNet V1.5 and improves accuracy according to 85 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 86 | 87 | expansion = 4 88 | 89 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 90 | base_width=64, dilation=1, norm_layer=None): 91 | super(Bottleneck, self).__init__() 92 | if norm_layer is None: 93 | norm_layer = nn.BatchNorm2d 94 | width = int(planes * (base_width / 64.)) * groups 95 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 96 | self.conv1 = conv1x1(inplanes, width) 97 | self.bn1 = norm_layer(width) 98 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 99 | self.bn2 = norm_layer(width) 100 | self.conv3 = conv1x1(width, planes * self.expansion) 101 | self.bn3 = norm_layer(planes * self.expansion) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.downsample = downsample 104 | self.stride = stride 105 | 106 | def forward(self, x): 107 | identity = x 108 | 109 | out = self.conv1(x) 110 | out = self.bn1(out) 111 | out = self.relu(out) 112 | 113 | out = self.conv2(out) 114 | out = self.bn2(out) 115 | out = self.relu(out) 116 | 117 | out = self.conv3(out) 118 | out = self.bn3(out) 119 | 120 | if self.downsample is not None: 121 | identity = self.downsample(x) 122 | 123 | out += identity 124 | out = self.relu(out) 125 | 126 | return out 127 | 128 | 129 | class ResNet(nn.Module): 130 | def __init__(self, block, layers, args, num_classes=1000, zero_init_residual=False, 131 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 132 | norm_layer=None): 133 | super(ResNet, self).__init__() 134 | if norm_layer is None: 135 | norm_layer = nn.BatchNorm2d 136 | self._norm_layer = norm_layer 137 | 138 | self.inplanes = 64 139 | self.dilation = 1 140 | if replace_stride_with_dilation is None: 141 | # each element in the tuple indicates if we should replace 142 | # the 2x2 stride with a dilated convolution instead 143 | replace_stride_with_dilation = [False, False, False] 144 | if len(replace_stride_with_dilation) != 3: 145 | raise ValueError("replace_stride_with_dilation should be None " 146 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 147 | self.groups = groups 148 | self.base_width = width_per_group 149 | 150 | # Different model for smaller image size 151 | if args.dataset == 'cifar10' or args.dataset == 'cifar100': 152 | 153 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, 154 | bias=False) # For CIFAR 155 | 156 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=0) # For CIFAR 157 | 158 | # e.g. 
ImageNet 159 | else: 160 | 161 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 162 | bias=False) 163 | 164 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 165 | 166 | self.bn1 = norm_layer(self.inplanes) 167 | self.relu = nn.ReLU(inplace=True) 168 | 169 | self.layer1 = self._make_layer(block, 64, layers[0]) 170 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 171 | dilate=replace_stride_with_dilation[0]) 172 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 173 | dilate=replace_stride_with_dilation[1]) 174 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 175 | dilate=replace_stride_with_dilation[2]) 176 | 177 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 178 | 179 | self.fc = nn.Linear(512 * block.expansion, num_classes) 180 | 181 | for m in self.modules(): 182 | if isinstance(m, nn.Conv2d): 183 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 184 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 185 | nn.init.constant_(m.weight, 1) 186 | nn.init.constant_(m.bias, 0) 187 | 188 | # Zero-initialize the last BN in each residual branch, 189 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 190 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 191 | if zero_init_residual: 192 | for m in self.modules(): 193 | if isinstance(m, Bottleneck): 194 | nn.init.constant_(m.bn3.weight, 0) 195 | elif isinstance(m, BasicBlock): 196 | nn.init.constant_(m.bn2.weight, 0) 197 | 198 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 199 | norm_layer = self._norm_layer 200 | downsample = None 201 | previous_dilation = self.dilation 202 | if dilate: 203 | self.dilation *= stride 204 | stride = 1 205 | if stride != 1 or self.inplanes != planes * block.expansion: 206 | downsample = nn.Sequential( 207 | conv1x1(self.inplanes, planes * block.expansion, stride), 208 | norm_layer(planes * block.expansion), 209 | ) 210 | 211 | layers = [] 212 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 213 | self.base_width, previous_dilation, norm_layer)) 214 | self.inplanes = planes * block.expansion 215 | for _ in range(1, blocks): 216 | layers.append(block(self.inplanes, planes, groups=self.groups, 217 | base_width=self.base_width, dilation=self.dilation, 218 | norm_layer=norm_layer)) 219 | 220 | return nn.Sequential(*layers) 221 | 222 | def forward(self, x): 223 | x = self.conv1(x) 224 | x = self.bn1(x) 225 | x = self.relu(x) 226 | 227 | x = self.maxpool(x) 228 | 229 | x = self.layer1(x) 230 | x = self.layer2(x) 231 | x = self.layer3(x) 232 | x = self.layer4(x) 233 | 234 | x = self.avgpool(x) 235 | 236 | x = torch.flatten(x, 1) 237 | 238 | x = self.fc(x) 239 | 240 | return x 241 | 242 | 243 | def resnet18(args, **kwargs): 244 | r"""ResNet-18 model from 245 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 246 | Args: 247 | args: arguments 248 | """ 249 | return ResNet(BasicBlock, [2, 2, 2, 2], args, **kwargs) 250 | 251 | 252 | def resnet34(args, **kwargs): 253 | r"""ResNet-34 model from 254 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 255 | Args: 256 | args: arguments 257 | """ 258 | return ResNet(BasicBlock, [3, 4, 6, 3], args, **kwargs) 259 | 260 | 261 | def resnet50(args, **kwargs): 262 | r"""ResNet-50 model from 263 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 264 | Args: 265 | args: arguments 266 | """ 267 | return ResNet(Bottleneck, [3, 4, 6, 
268 | 
269 | 
270 | def resnet101(args, **kwargs):
271 |     r"""ResNet-101 model from
272 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
273 |     Args:
274 |         args: arguments
275 |     """
276 |     return ResNet(Bottleneck, [3, 4, 23, 3], args, **kwargs)
277 | 
278 | 
279 | def resnet152(args, **kwargs):
280 |     r"""ResNet-152 model from
281 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
282 |     Args:
283 |         args: arguments
284 |     """
285 |     return ResNet(Bottleneck, [3, 8, 36, 3], args, **kwargs)
286 | 
287 | 
288 | def resnext50_32x4d(args, **kwargs):
289 |     r"""ResNeXt-50 32x4d model from
290 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
291 |     Args:
292 |         args: arguments (argparse Namespace), passed through to the
293 |             ResNet constructor
294 |     """
295 |     kwargs['groups'] = 32
296 |     kwargs['width_per_group'] = 4
297 |     return ResNet(Bottleneck, [3, 4, 6, 3], args, **kwargs)
298 | 
299 | 
300 | def resnext101_32x8d(args, **kwargs):
301 |     r"""ResNeXt-101 32x8d model from
302 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
303 |     Args:
304 |         args: arguments (argparse Namespace), passed through to the
305 |             ResNet constructor
306 |     """
307 |     kwargs['groups'] = 32
308 |     kwargs['width_per_group'] = 8
309 |     return ResNet(Bottleneck, [3, 4, 23, 3], args, **kwargs)
310 | 
311 | 
312 | def wide_resnet50_2(args, **kwargs):
313 |     r"""Wide ResNet-50-2 model from
314 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
315 |     The model is the same as ResNet except for the bottleneck number of channels
316 |     which is twice larger in every block. The number of channels in outer 1x1
317 |     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
318 |     channels, and in Wide ResNet-50-2 has 2048-1024-2048.
319 |     Args:
320 |         args: arguments (argparse Namespace), passed through to the
321 |             ResNet constructor
322 |     """
323 |     kwargs['width_per_group'] = 64 * 2
324 |     return ResNet(Bottleneck, [3, 4, 6, 3], args, **kwargs)
325 | 
326 | 
327 | def wide_resnet101_2(args, **kwargs):
328 |     r"""Wide ResNet-101-2 model from
329 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
330 |     The model is the same as ResNet except for the bottleneck number of channels
331 |     which is twice larger in every block. The number of channels in outer 1x1
332 |     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
333 |     channels, and in Wide ResNet-50-2 has 2048-1024-2048.
334 |     Args:
335 |         args: arguments (argparse Namespace), passed through to the
336 |             ResNet constructor
337 |     """
338 |     kwargs['width_per_group'] = 64 * 2
339 |     return ResNet(Bottleneck, [3, 4, 23, 3], args, **kwargs)
340 | 
341 | 
342 | ''' SimCLR Projection Head '''
343 | 
344 | 
345 | class projection_MLP(nn.Module):
346 |     def __init__(self, args):
347 |         '''Projection head for the pretraining of the resnet encoder.
348 | 
349 |         - Uses the dataset and model size to determine encoder output
350 |         representation dimension.
351 |         - Outputs to a dimension of 128, and uses non-linear activation
352 |         as described in SimCLR paper: https://arxiv.org/pdf/2002.05709.pdf
353 |         '''
354 |         super(projection_MLP, self).__init__()
355 | 
356 |         if args.model == 'resnet18' or args.model == 'resnet34':
357 |             n_channels = 512
358 |         elif args.model == 'resnet50' or args.model == 'resnet101' or args.model == 'resnet152':
359 |             n_channels = 2048
360 |         else:
361 |             raise NotImplementedError('model not supported: {}'.format(args.model))
362 | 
363 |         self.projection_head = nn.Sequential()
364 | 
365 |         self.projection_head.add_module('W1', nn.Linear(
366 |             n_channels, n_channels))
367 |         self.projection_head.add_module('ReLU', nn.ReLU())
368 |         self.projection_head.add_module('W2', nn.Linear(
369 |             n_channels, 128))
370 | 
371 |     def forward(self, x):
372 |         return self.projection_head(x)
373 | 
--------------------------------------------------------------------------------
/src/optimisers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.optim as optim
4 | from torch.optim.optimizer import Optimizer, required
5 | 
6 | 
7 | def get_optimiser(models, mode, args):
8 |     '''Get the desired optimiser
9 | 
10 |     - Selects and initialises an optimiser with model params.
11 | 
12 |     - If 'LARS' is selected, the 'bn' and 'bias' parameters are separated out
13 |     so that weight decay and LARS adaptation are applied only to the remaining parameters.
14 | 
15 |     Args:
16 |         models (tuple): models which we want to optimise, (e.g. encoder and projection head)
17 | 
18 |         mode (string): the mode of training, (i.e. 'pretrain', 'finetune')
19 | 
20 |         args (Dictionary): Program Arguments
21 | 
22 |     Returns:
23 |         optimiser (torch.optim.Optimizer): the initialised optimiser
24 |     '''
25 |     params_models = []
26 |     reduced_params = []
27 | 
28 |     removed_params = []
29 | 
30 |     skip_lists = ['bn', 'bias']
31 | 
32 |     for m in models:
33 | 
34 |         m_skip = []
35 |         m_noskip = []
36 | 
37 |         params_models += list(m.parameters())
38 | 
39 |         for name, param in m.named_parameters():
40 |             if any(skip_name in name for skip_name in skip_lists):
41 |                 m_skip.append(param)
42 |             else:
43 |                 m_noskip.append(param)
44 |         reduced_params += list(m_noskip)
45 |         removed_params += list(m_skip)
46 |     # Set hyperparams depending on mode
47 |     if mode == 'pretrain':
48 |         lr = args.learning_rate
49 |         wd = args.weight_decay
50 |     else:
51 |         lr = args.finetune_learning_rate
52 |         wd = args.finetune_weight_decay
53 | 
54 |     # Select Optimiser
55 |     if args.optimiser == 'adam':
56 | 
57 |         optimiser = optim.Adam(params_models, lr=lr,
58 |                                weight_decay=wd)
59 | 
60 |     elif args.optimiser == 'sgd':
61 | 
62 |         optimiser = optim.SGD(params_models, lr=lr,
63 |                               weight_decay=wd, momentum=0.9)
64 | 
65 |     elif args.optimiser == 'lars':
66 | 
67 |         print("reduced_params len: {}".format(len(reduced_params)))
68 |         print("removed_params len: {}".format(len(removed_params)))
69 | 
70 |         optimiser = LARS(reduced_params+removed_params, lr=lr,
71 |                          weight_decay=wd, eta=0.001, use_nesterov=False, len_reduced=len(reduced_params))
72 |     else:
73 | 
74 |         raise NotImplementedError('{} not setup.'.format(args.optimiser))
75 | 
76 |     return optimiser
77 | 
78 | 
79 | class LARS(Optimizer):
80 |     """
81 |     Layer-wise adaptive rate scaling
82 | 
83 |     - Converted from Tensorflow to Pytorch from:
84 | 
85 |     https://github.com/google-research/simclr/blob/master/lars_optimizer.py
86 | 
87 |     - Based on:
88 | 
89 |     https://github.com/noahgolmant/pytorch-lars
90 | 
91 |     params (iterable): iterable of parameters to optimize or dicts defining
92 |         parameter groups
93 |     lr (float): base learning rate (\gamma_0)
94 | 
95 |     len_reduced (int): number of leading entries in params that weight decay is applied to; the trailing excluded ('bn' / 'bias') parameters are ignored
96 | 
97 |     momentum (float, optional): momentum factor (default: 0.9)
98 | 
99 |     use_nesterov (bool, optional): flag to use nesterov momentum (default: False)
100 | 
101 |     weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
102 |         ("\beta")
103 | 
104 |     eta (float, optional): LARS coefficient (default: 0.001)
105 | 
106 |     - Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
107 | 
108 |     - Large Batch Training of Convolutional Networks:
109 |         https://arxiv.org/abs/1708.03888
110 | 
111 |     """
112 | 
113 |     def __init__(self, params, lr, len_reduced, momentum=0.9, use_nesterov=False, weight_decay=0.0, classic_momentum=True, eta=0.001):
114 | 
115 |         self.epoch = 0
116 |         defaults = dict(
117 |             lr=lr,
118 |             momentum=momentum,
119 |             use_nesterov=use_nesterov,
120 |             weight_decay=weight_decay,
121 |             classic_momentum=classic_momentum,
122 |             eta=eta,
123 |             len_reduced=len_reduced
124 |         )
125 | 
126 |         super(LARS, self).__init__(params, defaults)
127 |         self.lr = lr
128 |         self.momentum = momentum
129 |         self.weight_decay = weight_decay
130 |         self.use_nesterov = use_nesterov
131 |         self.classic_momentum = classic_momentum
132 |         self.eta = eta
133 |         self.len_reduced = len_reduced
134 | 
135 |     def step(self, epoch=None, closure=None):
136 | 
137 |         loss = None
138 | 
139 |         if closure is not None:
140 |             loss = closure()
141 | 
142 |         if epoch is None:
143 |             epoch = self.epoch
144 |             self.epoch += 1
145 | 
146 |         for group in self.param_groups:
147 |             weight_decay = group['weight_decay']
148 |             momentum = group['momentum']
149 |             eta = group['eta']
150 |             learning_rate = group['lr']
151 | 
152 |             # TODO: Hacky
153 |             counter = 0
154 |             for p in group['params']:
155 |                 if p.grad is None:
156 |                     continue
157 | 
158 |                 param = p.data
159 |                 grad = p.grad.data
160 | 
161 |                 param_state = self.state[p]
162 | 
163 |                 # TODO: This really hacky way needs to be improved.
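                # The classic_momentum branch below follows Algorithm 1 of
                # https://arxiv.org/abs/1708.03888:
                #   trust_ratio = eta * ||w|| / ||grad + weight_decay * w||
                #   v           = momentum * v + (lr * trust_ratio) * (grad + weight_decay * w)
                #   w           = w - v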
164 | 
165 |                 # Note: the excluded 'bn' / 'bias' parameters sit at the end of the param list, so they are skipped here and receive no weight decay
166 |                 if counter < self.len_reduced:
167 |                     grad += weight_decay * param
168 | 
169 |                 # Create parameter for the momentum
170 |                 if "momentum_var" not in param_state:
171 |                     next_v = param_state["momentum_var"] = torch.zeros_like(
172 |                         p.data
173 |                     )
174 |                 else:
175 |                     next_v = param_state["momentum_var"]
176 | 
177 |                 if self.classic_momentum:
178 |                     trust_ratio = 1.0
179 | 
180 |                     # TODO: implementation of layer adaptation; fall back to 1.0 when either norm is zero
181 |                     w_norm = torch.norm(param)
182 |                     g_norm = torch.norm(grad)
183 | 
184 |                     device = grad.device
185 | 
186 |                     trust_ratio = torch.where(w_norm.gt(0), torch.where(
187 |                         g_norm.gt(0), (eta * w_norm / g_norm), torch.Tensor([1.0]).to(device)), torch.Tensor([1.0]).to(device)).item()
188 | 
189 |                     scaled_lr = learning_rate * trust_ratio
190 | 
191 |                     next_v.mul_(momentum).add_(grad, alpha=scaled_lr)
192 | 
193 |                     if self.use_nesterov:
194 |                         update = (self.momentum * next_v) + (scaled_lr * grad)
195 |                     else:
196 |                         update = next_v
197 | 
198 |                     p.data.add_(-update)
199 | 
200 |                 # Not classic_momentum
201 |                 else:
202 | 
203 |                     next_v.mul_(momentum).add_(grad)
204 | 
205 |                     if self.use_nesterov:
206 |                         update = (self.momentum * next_v) + (grad)
207 | 
208 |                     else:
209 |                         update = next_v
210 | 
211 |                     trust_ratio = 1.0
212 | 
213 |                     # TODO: implementation of layer adaptation; fall back to 1.0 when either norm is zero
214 |                     w_norm = torch.norm(param)
215 |                     v_norm = torch.norm(update)
216 | 
217 |                     device = update.device
218 | 
219 |                     trust_ratio = torch.where(w_norm.gt(0), torch.where(
220 |                         v_norm.gt(0), (eta * w_norm / v_norm), torch.Tensor([1.0]).to(device)), torch.Tensor([1.0]).to(device)).item()
221 | 
222 |                     scaled_lr = learning_rate * trust_ratio
223 | 
224 |                     p.data.add_(-scaled_lr * update)
225 | 
226 |                 counter += 1
227 | 
228 |         return loss
229 | 
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import gc
4 | import logging
5 | import numpy as np
6 | from tqdm import tqdm
7 | 
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import torch.optim.lr_scheduler as lr_scheduler
12 | import torch.nn.functional as F
13 | from torch.utils.tensorboard import SummaryWriter
14 | from optimisers import get_optimiser
15 | from PIL import Image
16 | 
17 | 
18 | def pretrain(encoder, dataloaders, args):
19 |     ''' Pretrain script - MoCo
20 | 
21 |     Pretrain the encoder and projection head with a Contrastive InfoNCE Loss.
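
    For reference (standard MoCo notation, not specific to this implementation):
    with query q, positive key k+, queue keys {k_i} and temperature T
    (--temperature), the loss for one sample is

        L = -log( exp(q . k+ / T) / sum_i exp(q . k_i / T) )

    i.e. an (N+1)-way softmax cross-entropy where the positive key is class 0,
    which is why the logits and labels below are fed to nn.CrossEntropyLoss.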
22 |     '''
23 |     mode = 'pretrain'
24 | 
25 |     ''' Optimisers '''
26 |     optimiser = get_optimiser((encoder,), mode, args)
27 | 
28 |     ''' Schedulers '''
29 |     # Warmup Scheduler
30 |     if args.warmup_epochs > 0:
31 |         for param_group in optimiser.param_groups:
32 |             param_group['lr'] = args.base_lr
33 | 
34 |         # Cosine LR Decay after the warmup epochs
35 |         lr_decay = lr_scheduler.CosineAnnealingLR(
36 |             optimiser, (args.n_epochs-args.warmup_epochs), eta_min=0.0001, last_epoch=-1)
37 |     else:
38 |         # Cosine LR Decay
39 |         lr_decay = lr_scheduler.CosineAnnealingLR(optimiser, args.n_epochs)
40 | 
41 |     ''' Loss / Criterion '''
42 |     criterion = nn.CrossEntropyLoss().cuda()
43 | 
44 |     # initialise variables
45 |     args.writer = SummaryWriter(args.summaries_dir)
46 |     best_valid_loss = np.inf
47 |     patience_counter = 0
48 | 
49 |     ''' Pretrain loop '''
50 |     for epoch in range(args.n_epochs):
51 | 
52 |         # Train models
53 |         encoder.train()
54 | 
55 |         sample_count = 0
56 |         run_loss = 0
57 | 
58 |         # Print setup, so that distributed training only prints on one node
59 |         if args.print_progress:
60 |             logging.info('\nEpoch {}/{}:\n'.format(epoch+1, args.n_epochs))
61 |             # tqdm for process (rank) 0 only when using distributed training
62 |             train_dataloader = tqdm(dataloaders['pretrain'])
63 |         else:
64 |             train_dataloader = dataloaders['pretrain']
65 | 
66 |         ''' epoch loop '''
67 |         for i, (inputs, _) in enumerate(train_dataloader):
68 | 
69 |             inputs = inputs.cuda(non_blocking=True)
70 | 
71 |             # Forward pass
72 |             optimiser.zero_grad()
73 | 
74 |             # retrieve the 2 views
75 |             x_i, x_j = torch.split(inputs, [3, 3], dim=1)
76 | 
77 |             # Get the encoder representation
78 |             logit, label = encoder(x_i, x_j)
79 | 
80 |             loss = criterion(logit, label)
81 | 
82 |             loss.backward()
83 | 
84 |             optimiser.step()
85 | 
86 |             torch.cuda.synchronize()
87 | 
88 |             sample_count += inputs.size(0)
89 | 
90 |             run_loss += loss.item()
91 | 
92 |         epoch_pretrain_loss = run_loss / len(dataloaders['pretrain'])
93 | 
94 |         ''' Update Schedulers '''
95 |         # TODO: Improve / add lr_scheduler for warmup (CosineAnnealingLR records the lr at construction time, i.e. base_lr, as its starting point)
96 |         if args.warmup_epochs > 0 and epoch+1 <= args.warmup_epochs:
97 |             wu_lr = (args.learning_rate - args.base_lr) * \
98 |                 (float(epoch+1) / args.warmup_epochs) + args.base_lr
99 |             save_lr = optimiser.param_groups[0]['lr']
100 |             optimiser.param_groups[0]['lr'] = wu_lr
101 |         else:
102 |             # After warmup, decay lr with CosineAnnealingLR
103 |             lr_decay.step()
104 | 
105 |         ''' Printing '''
106 |         if args.print_progress:  # only validate using process 0
107 |             logging.info('\n[Train] loss: {:.4f}'.format(epoch_pretrain_loss))
108 | 
109 |             args.writer.add_scalars('epoch_loss', {'pretrain': epoch_pretrain_loss}, epoch+1)
110 |             args.writer.add_scalars('lr', {'pretrain': optimiser.param_groups[0]['lr']}, epoch+1)
111 | 
112 |         # For the best performing epoch, reset patience and save model,
113 |         # else update patience.
114 |         if epoch_pretrain_loss <= best_valid_loss:
115 |             patience_counter = 0
116 |             best_epoch = epoch + 1
117 |             best_valid_loss = epoch_pretrain_loss
118 | 
119 |             # saving using process (rank) 0 only as all processes are in sync
120 | 
121 |             state = {
122 |                 #'args': args,
123 |                 'moco': encoder.state_dict(),
124 |                 'optimiser': optimiser.state_dict(),
125 |                 'epoch': epoch,
126 |             }
127 | 
128 |             torch.save(state, args.checkpoint_dir)
129 |         else:
130 |             patience_counter += 1
131 |             if patience_counter == (args.patience - 10):
132 |                 logging.info('\nPatience counter {}/{}.'.format(
133 |                     patience_counter, args.patience))
134 |             elif patience_counter == args.patience:
135 |                 logging.info('\nEarly stopping... no improvement after {} epochs.'.format(
136 |                     args.patience))
137 |                 break
138 | 
139 |         epoch_pretrain_loss = None  # reset loss
140 | 
141 |     del state
142 | 
143 |     torch.cuda.empty_cache()
144 | 
145 |     gc.collect()  # release unreferenced memory
146 | 
147 | 
148 | def supervised(encoder, dataloaders, args):
149 |     ''' Supervised Train script - MoCo
150 | 
151 |     Train the encoder and supervised classification head with a Cross Entropy Loss.
152 |     '''
153 | 
154 |     mode = 'pretrain'
155 | 
156 |     ''' Optimisers '''
157 |     # Optimise the full encoder and its classification head
158 |     optimiser = get_optimiser((encoder,), mode, args)
159 | 
160 |     ''' Schedulers '''
161 |     # Warmup Scheduler
162 |     if args.warmup_epochs > 0:
163 |         for param_group in optimiser.param_groups:
164 |             param_group['lr'] = args.base_lr
165 | 
166 |         # Cosine LR Decay after the warmup epochs
167 |         lr_decay = lr_scheduler.CosineAnnealingLR(
168 |             optimiser, (args.n_epochs-args.warmup_epochs), eta_min=0.0001, last_epoch=-1)
169 |     else:
170 |         # Cosine LR Decay
171 |         lr_decay = lr_scheduler.CosineAnnealingLR(optimiser, args.n_epochs)
172 | 
173 |     ''' Loss / Criterion '''
174 |     criterion = torch.nn.CrossEntropyLoss().cuda()
175 | 
176 |     # initialise variables
177 |     args.writer = SummaryWriter(args.summaries_dir)
178 |     best_valid_loss = np.inf
179 |     patience_counter = 0
180 | 
181 |     ''' Train loop '''
182 |     for epoch in range(args.n_epochs):
183 | 
184 |         # Train models
185 |         encoder.train()
186 | 
187 |         sample_count = 0
188 |         run_loss = 0
189 |         run_top1 = 0.0
190 |         run_top5 = 0.0
191 | 
192 |         # Print setup, so that distributed training only prints on one node
193 |         if args.print_progress:
194 |             logging.info('\nEpoch {}/{}:\n'.format(epoch+1, args.n_epochs))
195 |             # tqdm for process (rank) 0 only when using distributed training
196 |             train_dataloader = tqdm(dataloaders['train'])
197 |         else:
198 |             train_dataloader = dataloaders['train']
199 | 
200 |         ''' epoch loop '''
201 |         for i, (inputs, target) in enumerate(train_dataloader):
202 | 
203 |             inputs = inputs.cuda(non_blocking=True)
204 | 
205 |             target = target.cuda(non_blocking=True)
206 | 
207 |             # Forward pass
208 |             optimiser.zero_grad()
209 | 
210 |             output = encoder(inputs)
211 | 
212 |             loss = criterion(output, target)
213 | 
214 |             loss.backward()
215 | 
216 |             optimiser.step()
217 | 
218 |             torch.cuda.synchronize()
219 | 
220 |             sample_count += inputs.size(0)
221 | 
222 |             run_loss += loss.item()
223 | 
224 |             predicted = output.argmax(1)
225 | 
226 |             acc = (predicted == target).sum().item() / target.size(0)
227 | 
228 |             run_top1 += acc
229 | 
230 |             _, output_topk = output.topk(5, 1, True, True)
231 | 
232 |             acc_top5 = (output_topk == target.view(-1, 1).expand_as(output_topk)
233 |                         ).sum().item() / target.size(0)  # num correct
234 | 
235 |             run_top5 += acc_top5
236 | 
237 |         epoch_pretrain_loss = run_loss / len(dataloaders['train'])  # sample_count
238 | 
239 |         epoch_pretrain_acc = run_top1 / len(dataloaders['train'])
240 | 
241 |         epoch_pretrain_acc_top5 = run_top5 / len(dataloaders['train'])
242 | 
243 |         ''' Update Schedulers '''
244 |         # TODO: Improve / add lr_scheduler for warmup
245 |         if args.warmup_epochs > 0 and epoch+1 <= args.warmup_epochs:
246 |             wu_lr = (args.learning_rate - args.base_lr) * \
247 |                 (float(epoch+1) / args.warmup_epochs) + args.base_lr
248 |             save_lr = optimiser.param_groups[0]['lr']
249 |             optimiser.param_groups[0]['lr'] = wu_lr
250 |         else:
251 |             # After warmup, decay lr with CosineAnnealingLR
252 |             lr_decay.step()
253 | 
254 |         ''' Printing '''
255 |         if args.print_progress:  # only validate using process 0
256 |             logging.info('\n[Train] loss: {:.4f}'.format(epoch_pretrain_loss))
257 | 
258 |             args.writer.add_scalars('epoch_loss', {
259 |                 'pretrain': epoch_pretrain_loss}, epoch+1)
260 |             args.writer.add_scalars('supervised_epoch_acc', {
261 |                 'pretrain': epoch_pretrain_acc}, epoch+1)
262 |             args.writer.add_scalars('supervised_epoch_acc_top5', {
263 |                 'pretrain': epoch_pretrain_acc_top5}, epoch+1)
264 | 
265 |             args.writer.add_scalars('lr', {'pretrain': optimiser.param_groups[0]['lr']}, epoch+1)
266 | 
267 |         # For the best performing epoch, reset patience and save model,
268 |         # else update patience.
269 |         if epoch_pretrain_loss <= best_valid_loss:
270 |             patience_counter = 0
271 |             best_epoch = epoch + 1
272 |             best_valid_loss = epoch_pretrain_loss
273 | 
274 |             # saving using process (rank) 0 only as all processes are in sync
275 | 
276 |             state = {
277 |                 #'args': args,
278 |                 'encoder': encoder.state_dict(),
279 |                 'optimiser': optimiser.state_dict(),
280 |                 'epoch': epoch,
281 |             }
282 | 
283 |             torch.save(state, args.checkpoint_dir)
284 |         else:
285 |             patience_counter += 1
286 |             if patience_counter == (args.patience - 10):
287 |                 logging.info('\nPatience counter {}/{}.'.format(
288 |                     patience_counter, args.patience))
289 |             elif patience_counter == args.patience:
290 |                 logging.info('\nEarly stopping... no improvement after {} epochs.'.format(
291 |                     args.patience))
292 |                 break
293 | 
294 |         epoch_pretrain_loss = None  # reset loss
295 | 
296 |     del state
297 | 
298 |     torch.cuda.empty_cache()
299 | 
300 |     gc.collect()  # release unreferenced memory
301 | 
302 | 
303 | def finetune(encoder, dataloaders, args):
304 |     ''' Finetune script - MoCo
305 | 
306 |     Freeze the encoder and train the supervised Linear Evaluation head with a Cross Entropy Loss.
307 |     '''
308 | 
309 |     mode = 'finetune'
310 | 
311 |     ''' Optimisers '''
312 |     # Only optimise the supervised head
313 |     optimiser = get_optimiser((encoder,), mode, args)
314 | 
315 |     ''' Schedulers '''
316 |     # Cosine LR Decay
317 |     lr_decay = lr_scheduler.CosineAnnealingLR(optimiser, args.finetune_epochs)
318 | 
319 |     ''' Loss / Criterion '''
320 |     criterion = torch.nn.CrossEntropyLoss().cuda()
321 | 
322 |     # initialise variables
323 |     args.writer = SummaryWriter(args.summaries_dir)
324 |     best_valid_loss = np.inf
325 |     best_valid_acc = 0.0
326 |     patience_counter = 0
327 | 
328 |     ''' Finetune loop '''
329 |     for epoch in range(args.finetune_epochs):
330 | 
331 |         # Freeze the encoder, train classification head
332 |         encoder.eval()
333 | 
334 |         sample_count = 0
335 |         run_loss = 0
336 |         run_top1 = 0.0
337 |         run_top5 = 0.0
338 | 
339 |         # Print setup, so that distributed training only prints on one node
340 |         if args.print_progress:
341 |             logging.info('\nEpoch {}/{}:\n'.format(epoch+1, args.finetune_epochs))
342 |             # tqdm for process (rank) 0 only when using distributed training
343 |             train_dataloader = tqdm(dataloaders['train'])
344 |         else:
345 |             train_dataloader = dataloaders['train']
346 | 
347 |         ''' epoch loop '''
348 |         for i, (inputs, target) in enumerate(train_dataloader):
349 | 
350 |             inputs = inputs.cuda(non_blocking=True)
351 | 
352 |             target = target.cuda(non_blocking=True)
353 | 
354 |             # Forward pass
355 |             optimiser.zero_grad()
356 | 
357 |             # The frozen encoder runs in eval mode; only the Linear Evaluation head is optimised
358 |             output = encoder(inputs)
359 | 
360 |             # Take pretrained encoder representations
361 |             loss = criterion(output, target)
362 | 
363 |             loss.backward()
364 | 
365 |             optimiser.step()
366 | 
367 |             torch.cuda.synchronize()
368 | 
369 |             sample_count += inputs.size(0)
370 | 
371 |             run_loss += loss.item()
372 | 
373 |             predicted = output.argmax(1)
374 | 
375 |             acc = (predicted == target).sum().item() / target.size(0)
376 | 
377 |             run_top1 += acc
378 | 
379 |             _, output_topk = output.topk(5, 1, True, True)
380 | 
381 |             acc_top5 = (output_topk == target.view(-1, 1).expand_as(output_topk)
382 |                         ).sum().item() / target.size(0)  # num correct
383 | 
384 |             run_top5 += acc_top5
385 | 
386 |         epoch_finetune_loss = run_loss / len(dataloaders['train'])  # sample_count
387 | 
388 |         epoch_finetune_acc = run_top1 / len(dataloaders['train'])
389 | 
390 |         epoch_finetune_acc_top5 = run_top5 / len(dataloaders['train'])
391 | 
392 |         ''' Update Schedulers '''
393 |         # Decay lr with CosineAnnealingLR
394 |         lr_decay.step()
395 | 
396 |         ''' Printing '''
397 |         if args.print_progress:  # only validate using process 0
398 |             logging.info('\n[Finetune] loss: {:.4f},\t acc: {:.4f}, \t acc_top5: {:.4f}\n'.format(
399 |                 epoch_finetune_loss, epoch_finetune_acc, epoch_finetune_acc_top5))
400 | 
401 |             args.writer.add_scalars('finetune_epoch_loss', {'train': epoch_finetune_loss}, epoch+1)
402 |             args.writer.add_scalars('finetune_epoch_acc', {'train': epoch_finetune_acc}, epoch+1)
403 |             args.writer.add_scalars('finetune_epoch_acc_top5', {
404 |                 'train': epoch_finetune_acc_top5}, epoch+1)
405 |             args.writer.add_scalars(
406 |                 'finetune_lr', {'train': optimiser.param_groups[0]['lr']}, epoch+1)
407 | 
408 |         valid_loss, valid_acc, valid_acc_top5 = evaluate(
409 |             encoder, dataloaders, 'valid', epoch, args)
410 | 
411 |         # For the best performing epoch, reset patience and save model,
412 |         # else update patience.
413 |         if valid_acc >= best_valid_acc:
414 |             patience_counter = 0
415 |             best_epoch = epoch + 1
416 |             best_valid_acc = valid_acc
417 | 
418 |             # saving using process (rank) 0 only as all processes are in sync
419 | 
420 |             state = {
421 |                 #'args': args,
422 |                 'base_encoder': encoder.state_dict(),
423 |                 'optimiser': optimiser.state_dict(),
424 |                 'epoch': epoch
425 |             }
426 | 
427 |             torch.save(state, (args.checkpoint_dir[:-3] + "_finetune.pt"))
428 |         else:
429 |             patience_counter += 1
430 |             if patience_counter == (args.patience - 10):
431 |                 logging.info('\nPatience counter {}/{}.'.format(
432 |                     patience_counter, args.patience))
433 |             elif patience_counter == args.patience:
434 |                 logging.info('\nEarly stopping... no improvement after {} epochs.'.format(
435 |                     args.patience))
436 |                 break
437 | 
438 |         epoch_finetune_loss = None  # reset loss
439 |         epoch_finetune_acc = None
440 |         epoch_finetune_acc_top5 = None
441 | 
442 |     del state
443 | 
444 |     torch.cuda.empty_cache()
445 | 
446 |     gc.collect()  # release unreferenced memory
447 | 
448 | 
449 | def evaluate(encoder, dataloaders, mode, epoch, args):
450 |     ''' Evaluate script - MoCo
451 | 
452 |     Evaluate the encoder and Linear Evaluation head with Cross Entropy loss.
453 |     '''
454 | 
455 |     epoch_valid_loss = None  # reset loss
456 |     epoch_valid_acc = None  # reset acc
457 |     epoch_valid_acc_top5 = None
458 | 
459 |     ''' Loss / Criterion '''
460 |     criterion = nn.CrossEntropyLoss().cuda()
461 | 
462 |     # initialise variables
463 |     args.writer = SummaryWriter(args.summaries_dir)
464 | 
465 |     # Evaluate both encoder and class head
466 |     encoder.eval()
467 | 
468 |     # initialise running metrics
469 |     sample_count = 0
470 |     run_loss = 0
471 |     run_top1 = 0.0
472 |     run_top5 = 0.0
473 | 
474 |     # Print setup, so that distributed training only prints on one node
475 |     if args.print_progress:
476 |         # tqdm for process (rank) 0 only when using distributed training
477 |         eval_dataloader = tqdm(dataloaders[mode])
478 |     else:
479 |         eval_dataloader = dataloaders[mode]
480 | 
481 |     ''' epoch loop '''
482 |     for i, (inputs, target) in enumerate(eval_dataloader):
483 | 
484 |         # Clear any stale gradients; backward() is never called during evaluation
485 |         encoder.zero_grad()
486 | 
487 |         inputs = inputs.cuda(non_blocking=True)
488 | 
489 |         target = target.cuda(non_blocking=True)
490 | 
491 |         # Forward pass
492 | 
493 |         output = encoder(inputs)
494 | 
495 |         loss = criterion(output, target)
496 | 
497 |         torch.cuda.synchronize()
498 | 
499 |         sample_count += inputs.size(0)
500 | 
501 |         run_loss += loss.item()
502 | 
503 |         predicted = output.argmax(-1)
504 | 
505 |         acc = (predicted == target).sum().item() / target.size(0)
506 | 
507 |         run_top1 += acc
508 | 
509 |         _, output_topk = output.topk(5, 1, True, True)
510 | 
511 |         acc_top5 = (output_topk == target.view(-1, 1).expand_as(output_topk)
512 |                     ).sum().item() / target.size(0)  # num correct
513 | 
514 |         run_top5 += acc_top5
515 | 
516 |     epoch_valid_loss = run_loss / len(dataloaders[mode])  # sample_count
517 | 
518 |     epoch_valid_acc = run_top1 / len(dataloaders[mode])
519 | 
520 |     epoch_valid_acc_top5 = run_top5 / len(dataloaders[mode])
521 | 
522 |     ''' Printing '''
523 |     if args.print_progress:  # only validate using process 0
524 |         logging.info('\n[{}] loss: {:.4f},\t acc: {:.4f},\t acc_top5: {:.4f} \n'.format(
525 |             mode, epoch_valid_loss, epoch_valid_acc, epoch_valid_acc_top5))
526 | 
527 |         if mode != 'test':
528 |             args.writer.add_scalars('finetune_epoch_loss', {mode: epoch_valid_loss}, epoch+1)
529 |             args.writer.add_scalars('finetune_epoch_acc', {mode: epoch_valid_acc}, epoch+1)
530 |             args.writer.add_scalars('finetune_epoch_acc_top5', {
531 |                 mode: epoch_valid_acc_top5}, epoch+1)
532 | 
533 |     torch.cuda.empty_cache()
534 | 
535 |     gc.collect()  # release unreferenced memory
536 | 
537 |     return epoch_valid_loss, epoch_valid_acc, epoch_valid_acc_top5
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import logging
4 | import numpy as np
5 | import time
6 | import random
7 | 
8 | 
9 | import torch
10 | from torch.utils.data import Dataset
11 | import torch.nn as nn
12 | 
13 | from PIL import Image, ImageFilter
14 | 
15 | 
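# The GaussianBlur class below is one piece of the MoCoV2 augmentation recipe.
# As a rough sketch of how it is typically composed with torchvision transforms
# for the two-crop pipeline (the actual pipeline in this repo lives in
# src/datasets.py and is driven by --jitter_d/--jitter_p, --blur_sigma/--blur_p
# and --grey_p; the values below are illustrative assumptions only):
#
#   from torchvision import transforms
#   augment = transforms.Compose([
#       transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
#       transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
#       transforms.RandomGrayscale(p=0.2),
#       transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
#       transforms.RandomHorizontalFlip(),
#       transforms.ToTensor(),
#   ])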
16 | class GaussianBlur(object):
17 |     """Gaussian blur augmentation: https://github.com/facebookresearch/moco/"""
18 | 
19 |     def __init__(self, sigma=[.1, 2.]):
20 |         self.sigma = sigma
21 | 
22 |     def __call__(self, x):
23 |         sigma = random.uniform(self.sigma[0], self.sigma[1])
24 |         x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
25 |         return x
26 | 
27 | 
28 | def load_moco(base_encoder, args):
29 |     """ Loads the pre-trained MoCo model parameters.
30 | 
31 |     Applies the loaded pre-trained params to the base encoder used in Linear Evaluation,
32 |     retaining only the query encoder ('encoder_q') weights up to the projection head.
33 | 
34 |     Args:
35 |         base_encoder (model): Randomly Initialised base_encoder.
36 | 
37 |         args (dict): Program arguments/commandline arguments.
38 |     Returns:
39 |         base_encoder (model): Initialised base_encoder with parameters from the MoCo query_encoder.
40 |     """
41 |     print("\n\nLoading the model: {}\n\n".format(args.load_checkpoint_dir))
42 | 
43 |     # Load the pretrained model
44 |     checkpoint = torch.load(args.load_checkpoint_dir, map_location="cpu")
45 | 
46 |     # rename moco pre-trained keys
47 |     state_dict = checkpoint['moco']
48 |     for k in list(state_dict.keys()):
49 |         # retain only encoder_q up to before the embedding layer
50 |         if k.startswith('encoder_q') and not k.startswith('encoder_q.fc'):
51 |             # remove prefix
52 |             state_dict[k[len("encoder_q."):]] = state_dict[k]
53 |         # delete renamed or unused k
54 |         del state_dict[k]
55 | 
56 |     # Load the encoder parameters
57 |     base_encoder.load_state_dict(state_dict, strict=False)
58 | 
59 |     return base_encoder
60 | 
61 | 
62 | def load_sup(base_encoder, args):
63 |     """ Loads the pre-trained supervised model parameters.
64 | 
65 |     Applies the loaded pre-trained params to the base encoder used in Linear Evaluation,
66 |     freezing all layers except the Linear Evaluation layer/s.
67 | 
68 |     Args:
69 |         base_encoder (model): Randomly Initialised base_encoder.
70 | 
71 |         args (dict): Program arguments/commandline arguments.
72 |     Returns:
73 |         base_encoder (model): Initialised base_encoder with parameters from the supervised base_encoder.
74 |     """
75 |     print("\n\nLoading the model: {}\n\n".format(args.load_checkpoint_dir))
76 | 
77 |     # Load the pretrained model
78 |     checkpoint = torch.load(args.load_checkpoint_dir, map_location="cpu")
79 | 
80 |     # Load the encoder parameters
81 |     base_encoder.load_state_dict(checkpoint['encoder'])
82 | 
83 |     # freeze all layers but the last fc
84 |     for name, param in base_encoder.named_parameters():
85 |         if name not in ['fc.weight', 'fc.bias']:
86 |             param.requires_grad = False
87 | 
88 |     # init the fc layer
89 |     init_weights(base_encoder)
90 | 
91 |     return base_encoder
92 | 
93 | 
94 | def init_weights(m):
95 |     '''Initialise the fc layer with normally distributed weights (std=0.01) and a zeroed bias
96 |     '''
97 | 
98 |     # init the fc layer
99 |     m.fc.weight.data.normal_(mean=0.0, std=0.01)
100 |     m.fc.bias.data.zero_()
101 | 
102 | 
103 | class CustomDataset(Dataset):
104 |     """ Creates a custom pytorch dataset.
105 | 
106 |     - Creates two views of the same input used for unsupervised visual
107 |     representational learning. (SimCLR, Moco, MocoV2)
108 | 
109 |     Args:
110 |         data (array): Array / List of datasamples
111 | 
112 |         labels (array): Array / List of labels corresponding to the datasamples
113 | 
114 |         transform (callable, optional): The torchvision transformations
115 |         to make to the datasamples. (Default: None)
116 | 
117 |         target_transform (callable, optional): The torchvision transformations
118 |         to make to the labels. (Default: None)
119 | 
120 |         two_crop (bool, optional): Whether to perform and return two views
121 |         of the data input. (Default: False)
122 | 
123 |     Returns:
124 |         img (Tensor): Datasamples to feed to the model.
125 | 
126 |         labels (Tensor): Corresponding labels to the datasamples.
127 |     """
128 | 
129 |     def __init__(self, data, labels, transform=None, target_transform=None, two_crop=False):
130 | 
131 |         # shuffle the dataset
132 |         idx = np.random.permutation(data.shape[0])
133 | 
134 |         if isinstance(data, torch.Tensor):
135 |             data = data.numpy()  # to work with `ToPILImage'
136 | 
137 |         self.data = data[idx]
138 | 
139 |         # when STL10 'unlabelled'
140 |         if labels is not None:
141 |             self.labels = labels[idx]
142 |         else:
143 |             self.labels = labels
144 | 
145 |         self.transform = transform
146 |         self.target_transform = target_transform
147 |         self.two_crop = two_crop
148 | 
149 |     def __len__(self):
150 |         return self.data.shape[0]
151 | 
152 |     def __getitem__(self, index):
153 | 
154 |         # If the input data is in form from torchvision.datasets.ImageFolder
155 |         if isinstance(self.data[index][0], np.str_):
156 |             # Load image from path
157 |             image = Image.open(self.data[index][0]).convert('RGB')
158 |         else:
159 |             # Get image / numpy pixel values
160 |             image = self.data[index]
161 | 
162 |         if self.transform is not None:
163 |             # Data augmentation and normalisation
164 |             img = self.transform(image)
165 | 
166 |         # when STL10 'unlabelled', return a dummy target of 0
167 |         if self.labels is None:
168 |             target = torch.Tensor([0])
169 |         else:
170 |             target = self.labels[index].long()
171 | 
172 |         if self.target_transform is not None:
173 |             # Transforms the target, i.e. object detection, segmentation
174 |             target = self.target_transform(target)
175 | 
176 |         if self.two_crop:
177 | 
178 |             # Augments the images again to create a second view of the data
179 |             img2 = self.transform(image)
180 | 
181 |             # Combine the views to pass to the model
182 |             img = torch.cat([img, img2], dim=0)
183 | 
184 |         return img, target
185 | 
186 | 
187 | 
188 | def random_split_image_folder(data, labels, n_classes, n_samples_per_class):
189 |     """ Creates a class-balanced validation set from a training set.
190 | 
191 |     Specifically for the image folder class
192 |     """
193 | 
194 |     train_x, train_y, valid_x, valid_y = [], [], [], []
195 | 
196 |     if isinstance(labels, list):
197 |         labels = np.array(labels)
198 | 
199 |     for i in range(n_classes):
200 |         # get indices of all class 'c' samples
201 |         c_idx = (np.array(labels) == i).nonzero()[0]
202 |         # get n unique class 'c' samples
203 |         valid_samples = np.random.choice(c_idx, n_samples_per_class[i], replace=False)
204 |         # get remaining samples of class 'c'
205 |         train_samples = np.setdiff1d(c_idx, valid_samples)
206 |         # assign class c samples to validation, and remaining to training
207 |         train_x.extend(data[train_samples])
208 |         train_y.extend(labels[train_samples])
209 |         valid_x.extend(data[valid_samples])
210 |         valid_y.extend(labels[valid_samples])
211 | 
212 |     # np.stack + torch.from_numpy turns the list of class ids into a long tensor
213 | 
214 |     return {'train': train_x, 'valid': valid_x}, \
215 |         {'train': torch.from_numpy(np.stack(train_y)), 'valid': torch.from_numpy(np.stack(valid_y))}
216 | 
217 | 
218 | def random_split(data, labels, n_classes, n_samples_per_class):
219 |     """ Creates a class-balanced validation set from a training set.
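
    e.g. a hypothetical CIFAR-10 call reserving 500 validation samples per
    class (the variable names here are illustrative, not taken from datasets.py):

        data, labels = random_split(train_x, train_y, n_classes=10,
                                    n_samples_per_class=np.repeat(500, 10))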
220 |     """
221 | 
222 |     train_x, train_y, valid_x, valid_y = [], [], [], []
223 | 
224 |     if isinstance(labels, list):
225 |         labels = np.array(labels)
226 | 
227 |     for i in range(n_classes):
228 |         # get indices of all class 'c' samples
229 |         c_idx = (np.array(labels) == i).nonzero()[0]
230 |         # get n unique class 'c' samples
231 |         valid_samples = np.random.choice(c_idx, n_samples_per_class[i], replace=False)
232 |         # get remaining samples of class 'c'
233 |         train_samples = np.setdiff1d(c_idx, valid_samples)
234 |         # assign class c samples to validation, and remaining to training
235 |         train_x.extend(data[train_samples])
236 |         train_y.extend(labels[train_samples])
237 |         valid_x.extend(data[valid_samples])
238 |         valid_y.extend(labels[valid_samples])
239 | 
240 |     if isinstance(data, torch.Tensor):
241 |         # torch.stack transforms list of tensors to tensor
242 |         return {'train': torch.stack(train_x), 'valid': torch.stack(valid_x)}, \
243 |             {'train': torch.stack(train_y), 'valid': torch.stack(valid_y)}
244 |     # transforms list of np arrays to tensor
245 |     return {'train': torch.from_numpy(np.stack(train_x)),
246 |             'valid': torch.from_numpy(np.stack(valid_x))}, \
247 |         {'train': torch.from_numpy(np.stack(train_y)),
248 |             'valid': torch.from_numpy(np.stack(valid_y))}
249 | 
250 | 
251 | def sample_weights(labels):
252 |     """ Calculates per sample weights. """
253 |     class_sample_count = np.unique(labels, return_counts=True)[1]
254 |     class_weights = 1. / torch.Tensor(class_sample_count)
255 |     return class_weights[list(map(int, labels))]
256 | 
257 | 
258 | def experiment_config(parser, args):
259 |     """ Handles experiment configuration and creates new dirs for model.
260 |     """
261 |     # experiments are saved under an 'experiments' dir, one timestamped folder per run
262 |     run_dir = os.path.join(os.path.split(os.getcwd())[0], 'experiments')
263 | 
264 |     os.makedirs(run_dir, exist_ok=True)
265 | 
266 |     run_name = time.strftime("%Y-%m-%d_%H-%M-%S")
267 | 
268 |     # create all save dirs
269 |     model_dir = os.path.join(run_dir, run_name)
270 | 
271 |     os.makedirs(model_dir, exist_ok=True)
272 | 
273 |     args.summaries_dir = os.path.join(model_dir, 'summaries')
274 |     args.checkpoint_dir = os.path.join(model_dir, 'checkpoint.pt')
275 | 
276 |     if not args.finetune:
277 |         args.load_checkpoint_dir = args.checkpoint_dir
278 | 
279 |     os.makedirs(args.summaries_dir, exist_ok=True)
280 | 
281 |     # save hyperparameters in .txt file
282 |     with open(os.path.join(model_dir, 'hyperparams.txt'), 'w') as logs:
283 |         for key, value in vars(args).items():
284 |             logs.write('--{0}={1} \n'.format(str(key), str(value)))
285 | 
286 |     # save config file used in .txt file
287 |     with open(os.path.join(model_dir, 'config.txt'), 'w') as logs:
288 |         # Remove the quote characters from the blur_sigma value list
289 |         config = parser.format_values().replace("'", "")
290 |         # Remove the first line, path to original config file
291 |         config = config[config.find('\n')+1:]
292 |         logs.write('{}'.format(config))
293 | 
294 |     # reset root logger
295 |     [logging.root.removeHandler(handler) for handler in logging.root.handlers[:]]
296 |     # info logger for saving command line outputs during training
297 |     logging.basicConfig(level=logging.INFO, format='%(message)s',
298 |                         handlers=[logging.FileHandler(os.path.join(model_dir, 'trainlogs.txt')),
299 |                                   logging.StreamHandler()])
300 |     return args
301 | 
302 | 
303 | def print_network(model, args):
304 |     """ Utility for printing out a model's architecture.
305 |     """
306 |     logging.info('-'*70)  # print some info on architecture
307 |     logging.info('{:>25} {:>27} {:>15}'.format('Layer.Parameter', 'Shape', 'Param#'))
308 |     logging.info('-'*70)
309 | 
310 |     for param in model.state_dict():
311 |         p_name = param.split('.')[-2]+'.'+param.split('.')[-1]
312 |         # don't print batch norm layers for prettiness
313 |         if p_name[:2] != 'BN' and p_name[:2] != 'bn':
314 |             logging.info(
315 |                 '{:>25} {:>27} {:>15}'.format(
316 |                     p_name,
317 |                     str(list(model.state_dict()[param].squeeze().size())),
318 |                     '{0:,}'.format(np.prod(list(model.state_dict()[param].size())))
319 |                 )
320 |             )
321 |     logging.info('-'*70)
322 | 
323 |     logging.info('\nTotal params: {:,}\n\nSummaries dir: {}\n'.format(
324 |         sum(p.numel() for p in model.parameters()),
325 |         args.summaries_dir))
326 | 
327 |     for key, value in vars(args).items():
328 |         if str(key) != 'print_progress':
329 |             logging.info('--{0}: {1}'.format(str(key), str(value)))
330 | 
--------------------------------------------------------------------------------
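
As a closing note on `src/train.py`: the warmup-then-cosine learning-rate schedule that `pretrain` implements can be summarised with the minimal, self-contained sketch below. The hyperparameter values are assumptions for illustration; the real run takes them from `args` / `config.conf`.

```
import math

# assumed values for illustration only; the real run reads these from args
n_epochs, warmup_epochs = 1000, 10
learning_rate, base_lr, eta_min = 0.3, 1e-4, 1e-4

def lr_at(epoch):
    """Learning rate at a 1-indexed epoch: linear warmup, then cosine decay."""
    if warmup_epochs > 0 and epoch <= warmup_epochs:
        # linear warmup from base_lr up to learning_rate
        return (learning_rate - base_lr) * (epoch / warmup_epochs) + base_lr
    # closed form of CosineAnnealingLR over the remaining epochs
    t = (epoch - warmup_epochs) / (n_epochs - warmup_epochs)
    return eta_min + 0.5 * (learning_rate - eta_min) * (1 + math.cos(math.pi * t))
```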