├── README.md
├── config.conf
├── media
│   └── CosineAnnealingLR.png
├── requirements.txt
└── src
    ├── datasets.py
    ├── kill_zombie.sh
    ├── main.py
    ├── model
    │   ├── __init__.py
    │   ├── moco.py
    │   └── network.py
    ├── optimisers.py
    ├── train.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Unofficial Pytorch Implementation of MoCoV2
2 |
3 | Unofficial Pytorch implementation of [MoCoV2](https://arxiv.org/abs/2003.04297): "Improved Baselines with Momentum Contrastive Learning."
4 |
5 | This repo uses elements from [CMC](https://github.com/HobbitLong/CMC) and [MoCo Code](https://github.com/facebookresearch/moco), in which the MoCoV2 model is implemented into my existing PyTorch boilerplate and workflow. Additionally, this repo aims to align nicely with my implementation of [SimCLR](https://arxiv.org/pdf/2002.05709.pdf) found [here](https://github.com/AidenDurrant/SimCLR-Pytorch/).
6 |
7 | Work in progress: replicating results on ImageNet, TinyImageNet, CIFAR10, CIFAR100, and STL10.
8 |
9 | * **Author**: Aiden Durrant
10 | * **Email**: adurrant@lincoln.ac.uk
11 |
12 | ### Results:
13 |
14 | Top-1 Acc / Error of linear evaluation on CIFAR10:
15 |
16 | Testing is performed on the CIFAR10 test set, whilst the train set is split into train and validation sets for tuning.
17 |
18 | | Method | Batch Size | ResNet | Projection Head Dim. | Pre-train Epochs | Optimizer | Eval Epochs | Acc(%) |
19 | | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
20 | | MoCoV2 + Linear eval. | 128 | ResNet18 | 128 | 1000 | Adam | 100 | - |
21 | | MoCoV2 + Linear eval. | 128 | ResNet18 | 128 | 1000 | SGD | 100 | - |
22 | | MoCoV2 + Linear eval. | 128 | ResNet34 | 128 | 1000 | Adam | 100 | - |
23 | | MoCoV2 + Linear eval. | 128 | ResNet34 | 128 | 1000 | SGD | 100 | - |
24 | | MoCoV2 + Linear eval. | 128 | ResNet50 | 128 | 1000 | Adam | 100 | - |
25 | | MoCoV2 + Linear eval. | 128 | ResNet50 | 128 | 1000 | SGD | 100 | - |
26 | | Supervised + Linear eval.| 128 | ResNet18 | 128 | 1000 | SGD | 100 | - |
27 | | Random Init + Linear eval.| 128 | ResNet18 | 128 | 1000 | SGD | 100 | - |
28 |
29 | **Note**: For Linear Evaluation the ResNet is frozen (all layers); training is only performed on the supervised linear evaluation layer.
30 |
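This mirrors the freezing logic in `src/main.py` (`base_encoder` is the ResNet built there):

```python
# Freeze the full encoder; only the final linear layer remains trainable.
for name, param in base_encoder.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False
```
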
31 | ### Plots:
32 |
33 | **ResNet-18**
34 |
35 |
36 |
37 |
38 | **ResNet-50**
39 |
40 |
41 |
42 | ## Usage / Run
43 |
44 | ### Contrastive Training and Linear Evaluation
45 | Launch the script from `src/main.py`:
46 |
47 | By default the CIFAR-10 dataset is used; use `--dataset` to select from: cifar10, cifar100, stl10, imagenet, tinyimagenet. For ImageNet and TinyImageNet, please provide a path to the dataset with `--dataset_path`.
48 |
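For example, to pre-train on TinyImageNet (the dataset path here is illustrative):

`python main.py --dataset tinyimagenet --dataset_path /data/tiny-imagenet-200`
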
49 | Training uses CosineAnnealingLR decay with linear warmup to replicate the training settings in https://arxiv.org/pdf/2002.05709.pdf. The learning rate schedule is plotted below:
50 |
51 | ![CosineAnnealingLR](media/CosineAnnealingLR.png)
52 |
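The schedule itself is implemented in `src/optimisers.py`; as a rough sketch using the `config.conf` defaults (`lr_at_epoch` is an illustrative helper, not part of the repo):

```python
import math

def lr_at_epoch(epoch, n_epochs=200, warmup_epochs=10,
                base_lr=1e-4, learning_rate=0.015):
    """Linear warmup to learning_rate, then cosine annealing towards zero."""
    if epoch < warmup_epochs:
        # Linear ramp from base_lr up to learning_rate.
        return base_lr + (learning_rate - base_lr) * epoch / warmup_epochs
    # Cosine decay over the remaining epochs.
    progress = (epoch - warmup_epochs) / (n_epochs - warmup_epochs)
    return 0.5 * learning_rate * (1 + math.cos(math.pi * progress))
```
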
53 | #### DistributedDataParallel
54 |
55 | To train with **Distributed** for a slight computational speedup with multiple GPUs, use:
56 |
57 | `python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=2 --use_env main.py`
58 |
59 |
60 | This will train on a single machine (`nnodes=1`), assigning 1 process per GPU where `nproc_per_node=2` refers to training on 2 GPUs. To train on `N` GPUs simply launch `N` processes by setting `nproc_per_node=N`.
61 |
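On recent PyTorch releases the equivalent `torchrun` launcher can be used instead: `torchrun --nnodes=1 --nproc_per_node=2 main.py`.
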
62 | The number of CPU threads each process may use is hard coded via `torch.set_num_threads()` in `src/main.py`, and can be changed to `your # cpu threads / nproc_per_node` for better performance. ([fabio-deep](https://github.com/fabio-deep/Distributed-Pytorch-Boilerplate))
63 |
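To derive the thread count from the machine automatically (a sketch; `nproc_per_node` is assumed to match the launcher setting):

```python
import os
import torch

nproc_per_node = 2  # assumption: matches the value passed to the launcher
torch.set_num_threads(max(1, (os.cpu_count() or 1) // nproc_per_node))
```
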
64 | For more info on **multi-node** and **multi-gpu** distributed training refer to https://github.com/hgrover/pytorchdistr/blob/master/README.md
65 |
66 | #### DataParallel
67 |
68 | To train with traditional **nn.DataParallel** with multiple GPUs, use:
69 |
70 | `python main.py --no_distributed`
71 |
72 | **Note:** The default config sets `--no_distributed`, therefore running `python main.py` uses the default hyperparameters without DistributedDataParallel.
73 |
74 | ### Linear Evaluation of a Pre-Trained Model
75 |
76 | To evaluate the performance of a pre-trained model on a linear classification task, include the flag `--finetune` and provide a path to the pretrained model with `--load_checkpoint_dir`.
77 |
78 | Example:
79 |
80 | `python main.py --no_distributed --finetune --load_checkpoint_dir ~/Documents/MoCo-Pytorch/experiments/yyyy-mm-dd_hh-mm-ss/checkpoint.pt`
81 |
82 | ### Hyperparameters
83 |
84 | The configuration / choice of hyperparameters for the script is handled either by command line arguments or config files.
85 |
86 | An example config file is given at `MoCo-Pytorch/config.conf`. Additionally, `.txt` or `.conf` files can be passed if you prefer; this is achieved using the flag `-c` (or `--my-config`).
87 |
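For example, when launching from `src/`:

`python main.py -c ../config.conf`
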
88 | A list of arguments/options can be found below:
89 |
90 | ```
91 | usage: main.py [-h] [-c MY_CONFIG] [--dataset DATASET]
92 | [--dataset_path DATASET_PATH] [--model MODEL]
93 | [--n_epochs N_EPOCHS] [--finetune_epochs FINETUNE_EPOCHS]
94 | [--warmup_epochs WARMUP_EPOCHS] [--batch_size BATCH_SIZE]
95 | [--learning_rate LEARNING_RATE] [--base_lr BASE_LR]
96 | [--finetune_learning_rate FINETUNE_LEARNING_RATE]
97 | [--weight_decay WEIGHT_DECAY]
98 | [--finetune_weight_decay FINETUNE_WEIGHT_DECAY]
99 | [--optimiser OPTIMISER] [--patience PATIENCE]
100 | [--queue_size QUEUE_SIZE] [--queue_momentum QUEUE_MOMENTUM]
101 | [--temperature TEMPERATURE] [--jitter_d JITTER_D]
102 | [--jitter_p JITTER_P] [--blur_sigma BLUR_SIGMA BLUR_SIGMA]
103 | [--blur_p BLUR_P] [--grey_p GREY_P] [--no_twocrop]
104 | [--load_checkpoint_dir LOAD_CHECKPOINT_DIR] [--no_distributed]
105 | [--finetune] [--supervised]
106 |
107 | Pytorch MocoV2 Args that start with '--' (eg. --dataset) can also be set in a
108 | config file (/MoCo-Pytorch/config.conf or specified via -c). Config
109 | file syntax allows: key=value, flag=true, stuff=[a,b,c] (for details, see
110 | syntax at https://goo.gl/R74nmi). If an arg is specified in more than one
111 | place, then commandline values override config file values which override
112 | defaults.
113 |
114 | optional arguments:
115 | -h, --help show this help message and exit
116 | -c MY_CONFIG, --my-config MY_CONFIG
117 | config file path
118 | --dataset DATASET Dataset, (Options: cifar10, cifar100, stl10, imagenet,
119 | tinyimagenet).
120 | --dataset_path DATASET_PATH
121 | Path to dataset, Not needed for TorchVision Datasets.
122 | --model MODEL Model, (Options: resnet18, resnet34, resnet50,
123 | resnet101, resnet152).
124 | --n_epochs N_EPOCHS Number of Epochs in Contrastive Training.
125 | --finetune_epochs FINETUNE_EPOCHS
126 | Number of Epochs in Linear Classification Training.
127 | --warmup_epochs WARMUP_EPOCHS
128 | Number of Warmup Epochs During Contrastive Training.
129 | --batch_size BATCH_SIZE
130 | Number of Samples Per Batch.
131 | --learning_rate LEARNING_RATE
132 | Starting Learning Rate for Contrastive Training.
133 | --base_lr BASE_LR Base / Minimum Learning Rate to Begin Linear Warmup.
134 | --finetune_learning_rate FINETUNE_LEARNING_RATE
135 | Starting Learning Rate for Linear Classification
136 | Training.
137 | --weight_decay WEIGHT_DECAY
138 | Contrastive Learning Weight Decay Regularisation
139 | Factor.
140 | --finetune_weight_decay FINETUNE_WEIGHT_DECAY
141 | Linear Classification Training Weight Decay
142 | Regularisation Factor.
143 | --optimiser OPTIMISER
144 | Optimiser, (Options: sgd, adam, lars).
145 | --patience PATIENCE Number of Epochs to Wait for Improvement.
146 | --queue_size QUEUE_SIZE
147 | Size of Memory Queue, Must be Divisible by batch_size.
148 | --queue_momentum QUEUE_MOMENTUM
149 | Momentum for the Key Encoder Update.
150 | --temperature TEMPERATURE
151 | InfoNCE Temperature Factor
152 | --jitter_d JITTER_D Distortion Factor for the Random Colour Jitter
153 | Augmentation
154 | --jitter_p JITTER_P Probability to Apply Random Colour Jitter Augmentation
155 | --blur_sigma BLUR_SIGMA BLUR_SIGMA
156 | Sigma Range for the Gaussian Blur Augmentation
157 | --blur_p BLUR_P Probability to Apply Gaussian Blur Augmentation
158 | --grey_p GREY_P Probability to Apply Random Grey Scale
159 | --no_twocrop Whether or Not to Use Two Crop Augmentation, Used to
160 | Create Two Views of the Input for Contrastive
161 | Learning. (Default: True)
162 | --load_checkpoint_dir LOAD_CHECKPOINT_DIR
163 | Path to Load Pre-trained Model From.
164 | --no_distributed Whether or Not to Use Distributed Training. (Default:
165 | True)
166 | --finetune Perform Only Linear Classification Training. (Default:
167 | False)
168 | --supervised Perform Supervised Pre-Training. (Default: False)
169 | ```
170 |
171 | ## Dependencies
172 |
173 | Install dependencies with `requirements.txt`:
174 |
175 | `pip install -r requirements.txt`
176 |
177 | ```
178 | torch
179 | torchvision
180 | tensorboard
181 | tqdm
182 | configargparse
183 | ```
184 |
185 | ## References
186 | * K. He, et al. [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/abs/1911.05722)
187 |
188 |
189 | * X. Chen, et al. [Improved Baselines with Momentum Contrastive Learning](https://arxiv.org/abs/2003.04297)
190 |
191 |
192 | * facebookresearch [MoCo Code](https://github.com/facebookresearch/moco)
193 |
194 |
195 | * HobbitLong [CMC](https://github.com/HobbitLong/CMC)
196 |
197 |
198 | * T. Chen, et al. [SimCLR Paper](https://arxiv.org/pdf/2002.05709.pdf)
199 |
200 |
201 | * noahgolmant [pytorch-lars](https://github.com/noahgolmant/pytorch-lars)
202 |
203 |
204 | * pytorch [torchvision ResNet](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py)
205 |
206 |
207 | * fabio-deep [Distributed-Pytorch-Boilerplate](https://github.com/fabio-deep/Distributed-Pytorch-Boilerplate)
208 |
209 | ## TODO
210 | - [ ] Command Line Argument for MoCo V1 or V2
211 | - [ ] Research and Implement BatchNorm Shuffle for DistributedDataParallel
212 | - [ ] Run All Experiment Comparisons
213 |
--------------------------------------------------------------------------------
/config.conf:
--------------------------------------------------------------------------------
1 | # Config File for MoCo
2 |
3 | # Dataset
4 | --dataset=cifar10 # Dataset
5 | --dataset_path=None # Path to dataset, Not needed for TorchVision Datasets.
6 |
7 | # Model
8 | --model=resnet18 # Model
9 |
10 | # Epochs
11 | --n_epochs=200 # Number of Epochs in Contrastive Training.
12 | --finetune_epochs=100 # Number of Epochs in Linear Classification Training.
13 | --warmup_epochs=10 # Number of Warmup Epochs During Contrastive Training.
14 |
15 | # Core Training Params
16 | --batch_size=128 # Number of Samples Per Batch.
17 | --learning_rate=0.015 # Starting Learning Rate for Contrastive Training.
18 | --base_lr=0.0001 # Base / Minimum Learning Rate to Begin Linear Warmup.
19 | --finetune_learning_rate=10.0 # Starting Learning Rate for Linear Classification
20 |
21 | # Regularisation
22 | --weight_decay=1e-6 # Contrastive Learning Weight Decay
23 | --finetune_weight_decay=0.0 # Linear Classification Training Weight Decay
24 | --patience=100 # Number of Epochs to Wait for Improvement.
25 |
26 | # Optimiser
27 | --optimiser=sgd # Optimiser
28 |
29 | # MoCo Options
30 | --queue_size=65536 # Size of Memory Queue, Must be Divisible by batch_size.
31 | --queue_momentum=0.99 # Momentum for the Key Encoder Update.
32 | --temperature=0.07 # InfoNCE Temperature Factor
33 |
34 | # Augmentation
35 | --jitter_d=0.5 # Distortion Factor for the Random Colour Jitter
36 | --jitter_p=0.8 # Probability to Apply Random Colour Jitter
37 | --blur_sigma=[0.1,2.0] # Sigma Range for the Gaussian Blur
38 | --blur_p=0.5 # Probability to Apply Gaussian Blur
39 | --grey_p=0.2 # Probability to Apply Random Grey Scale
40 | ; --no_twocrop # Whether or Not to Use Two Crop Augmentation
41 |
42 |
43 | # Distributed Options
44 | --no_distributed # Whether or Not to Use Distributed Training
45 |
46 |
47 | # Finetune Options
48 | ; --finetune # Perform Only Linear Classification Training
49 | ; --supervised # Perform Supervised Pre-Training
50 | ; --load_checkpoint_dir= # Path to Load Pre-trained Model
51 |
--------------------------------------------------------------------------------
/media/CosineAnnealingLR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AidenDurrant/MoCo-Pytorch/373b5fafdbf51e4c6c19a984b7e8d41a296cf6d9/media/CosineAnnealingLR.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | tensorboard
4 | tqdm
5 | ConfigArgParse
6 |
--------------------------------------------------------------------------------
/src/datasets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import numpy as np
4 |
5 | import torch
6 | from torch.utils.data import DataLoader, WeightedRandomSampler
7 | from torch.utils.data.distributed import DistributedSampler
8 |
9 | import torchvision
10 | from torchvision import transforms
11 | from torchvision.datasets import CIFAR10, MNIST, STL10, ImageNet, CIFAR100, ImageFolder
12 |
13 | from utils import *
14 |
15 |
16 | def get_dataloaders(args):
17 | '''
18 |     Retrieves the dataloaders for the dataset of choice.
19 |
20 |     Initialises variables that correspond to the dataset of choice.
21 |
22 | args:
23 | args (dict): Program arguments/commandline arguments.
24 |
25 | returns:
26 | dataloaders (dict): pretrain,train,valid,train_valid,test set split dataloaders.
27 |
28 | args (dict): Updated and Additional program/commandline arguments dependent on dataset.
29 |
30 | '''
31 | if args.dataset == 'cifar10':
32 | dataset = 'CIFAR10'
33 |
34 | args.class_names = (
35 | 'plane', 'car', 'bird', 'cat',
36 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'
37 | ) # 0,1,2,3,4,5,6,7,8,9 labels
38 |
39 | args.crop_dim = 32
40 | args.n_channels, args.n_classes = 3, 10
41 |
42 | # Get and make dir to download dataset to.
43 | working_dir = os.path.join(os.path.split(os.getcwd())[0], 'data', dataset)
44 |
45 | if not os.path.exists(working_dir):
46 | os.makedirs(working_dir)
47 |
48 | dataset_paths = {'train': os.path.join(working_dir, 'train'),
49 | 'test': os.path.join(working_dir, 'test')}
50 |
51 | dataloaders = cifar_dataloader(args, dataset_paths)
52 |
53 | elif args.dataset == 'cifar100':
54 | dataset = 'CIFAR100'
55 |
56 | args.class_names = None
57 |
58 | args.crop_dim = 32
59 | args.n_channels, args.n_classes = 3, 100
60 |
61 | # Get and make dir to download dataset to.
62 | working_dir = os.path.join(os.path.split(os.getcwd())[0], 'data', dataset)
63 |
64 | if not os.path.exists(working_dir):
65 | os.makedirs(working_dir)
66 |
67 | dataset_paths = {'train': os.path.join(working_dir, 'train'),
68 | 'test': os.path.join(working_dir, 'test')}
69 |
70 | dataloaders = cifar_dataloader(args, dataset_paths)
71 |
72 | elif args.dataset == 'stl10':
73 | dataset = 'STL10'
74 |
75 | args.class_names = None
76 |
77 | args.crop_dim = 96
78 | args.n_channels, args.n_classes = 3, 10
79 |
80 | # Get and make dir to download dataset to.
81 | working_dir = os.path.join(os.path.split(os.getcwd())[0], 'data', dataset)
82 |
83 | if not os.path.exists(working_dir):
84 | os.makedirs(working_dir)
85 |
86 | dataset_paths = {'train': os.path.join(working_dir, 'train'),
87 | 'test': os.path.join(working_dir, 'test'),
88 | 'pretrain': os.path.join(working_dir, 'unlabeled')}
89 |
90 | dataloaders = stl10_dataloader(args, dataset_paths)
91 |
92 | elif args.dataset == 'imagenet':
93 | dataset = 'ImageNet'
94 |
95 | args.class_names = None
96 |
97 | args.crop_dim = 224
98 | args.n_channels, args.n_classes = 3, 1000
99 |
100 | # Get and make dir to download dataset to.
101 | target_dir = args.dataset_path
102 |
103 |         if target_dir is not None:
104 | dataset_paths = {'train': os.path.join(target_dir, 'train'),
105 | 'test': os.path.join(target_dir, 'val')}
106 |
107 | dataloaders = imagenet_dataloader(args, dataset_paths)
108 |
109 | else:
110 |             raise NotImplementedError('Please Select a path for the {} Dataset.'.format(args.dataset))
111 |
112 | elif args.dataset == 'tinyimagenet':
113 | dataset = 'TinyImageNet'
114 |
115 | args.class_names = None
116 |
117 | args.crop_dim = 64
118 | args.n_channels, args.n_classes = 3, 200
119 |
120 | # Get and make dir to download dataset to.
121 | target_dir = args.dataset_path
122 |
123 |         if target_dir is not None:
124 | dataset_paths = {'train': os.path.join(target_dir, 'train'),
125 | 'test': os.path.join(target_dir, 'val')}
126 |
127 | dataloaders = imagenet_dataloader(args, dataset_paths)
128 |
129 | else:
130 |             raise NotImplementedError('Please Select a path for the {} Dataset.'.format(args.dataset))
131 | else:
132 |         raise NotImplementedError('{} dataset not available.'.format(args.dataset))
133 |
134 | return dataloaders, args
135 |
136 |
137 | def imagenet_dataloader(args, dataset_paths):
138 | '''
139 |     Loads the ImageNet or TinyImageNet dataset performing augmentations.
140 |
141 | Generates splits of the training set to produce a validation set.
142 |
143 | args:
144 | args (dict): Program/commandline arguments.
145 |
146 |         dataset_paths (dict): Paths to each dataset split.
147 |
148 | Returns:
149 |
150 | dataloaders (): pretrain,train,valid,train_valid,test set split dataloaders.
151 | '''
152 |
153 |     # gaussian_blur from https://github.com/facebookresearch/moco/
154 |     gaussian_blur = transforms.RandomApply([GaussianBlur(args.blur_sigma)], p=args.blur_p)
155 |
156 | color_jitter = transforms.ColorJitter(
157 | 0.8*args.jitter_d, 0.8*args.jitter_d, 0.8*args.jitter_d, 0.2*args.jitter_d)
158 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=args.jitter_p)
159 |
160 | rnd_grey = transforms.RandomGrayscale(p=args.grey_p)
161 |
162 |     # Base train and test augmentations
163 | transf = {
164 | 'train': transforms.Compose([
165 | transforms.RandomResizedCrop((args.crop_dim, args.crop_dim)),
166 | rnd_color_jitter,
167 | rnd_grey,
168 |             gaussian_blur,
169 | transforms.RandomHorizontalFlip(),
170 | transforms.ToTensor(),
171 | transforms.Normalize((0.485, 0.456, 0.406),
172 | (0.229, 0.224, 0.225))]),
173 | 'test': transforms.Compose([
174 | transforms.CenterCrop((args.crop_dim, args.crop_dim)),
175 | transforms.ToTensor(),
176 | transforms.Normalize((0.485, 0.456, 0.406),
177 | (0.229, 0.224, 0.225))])
178 | }
179 |
180 | config = {'train': True, 'test': False}
181 |
182 | datasets = {i: ImageFolder(root=dataset_paths[i]) for i in config.keys()}
183 |
184 | # weighted sampler weights for full(f) training set
185 | f_s_weights = sample_weights(datasets['train'].targets)
186 |
187 | # return data, labels dicts for new train set and class-balanced valid set
188 |     # 50 samples per class are split into the val set (1000 classes)
189 | data, labels = random_split_image_folder(data=np.asarray(datasets['train'].samples),
190 | labels=datasets['train'].targets,
191 | n_classes=args.n_classes,
192 | n_samples_per_class=np.repeat(50, args.n_classes).reshape(-1))
193 |
194 | # torch.from_numpy(np.stack(labels)) this takes the list of class ids and turns them to tensor.long
195 |
196 | # original full training set
197 | datasets['train_valid'] = CustomDataset(data=np.asarray(datasets['train'].samples),
198 | labels=torch.from_numpy(np.stack(datasets['train'].targets)), transform=transf['train'], two_crop=args.twocrop)
199 |
200 | # original test set
201 | datasets['test'] = CustomDataset(data=np.asarray(datasets['test'].samples),
202 | labels=torch.from_numpy(np.stack(datasets['test'].targets)), transform=transf['test'], two_crop=False)
203 |
204 | # make new pretraining set without validation samples
205 | datasets['pretrain'] = CustomDataset(data=np.asarray(data['train']),
206 | labels=labels['train'], transform=transf['train'], two_crop=args.twocrop)
207 |
208 | # make new finetuning set without validation samples
209 | datasets['train'] = CustomDataset(data=np.asarray(data['train']),
210 | labels=labels['train'], transform=transf['train'], two_crop=False)
211 |
212 | # make class balanced validation set for finetuning
213 | datasets['valid'] = CustomDataset(data=np.asarray(data['valid']),
214 | labels=labels['valid'], transform=transf['test'], two_crop=False)
215 |
216 | # weighted sampler weights for new training set
217 | s_weights = sample_weights(datasets['pretrain'].labels)
218 |
219 | config = {
220 | 'pretrain': WeightedRandomSampler(s_weights,
221 | num_samples=len(s_weights), replacement=True),
222 | 'train': WeightedRandomSampler(s_weights,
223 | num_samples=len(s_weights), replacement=True),
224 | 'train_valid': WeightedRandomSampler(f_s_weights,
225 | num_samples=len(f_s_weights), replacement=True),
226 | 'valid': None, 'test': None
227 | }
228 |
229 | if args.distributed:
230 | config = {'pretrain': DistributedSampler(datasets['pretrain']),
231 | 'train': DistributedSampler(datasets['train']),
232 | 'train_valid': DistributedSampler(datasets['train_valid']),
233 | 'valid': None, 'test': None}
234 |
235 | dataloaders = {i: DataLoader(datasets[i], sampler=config[i],
236 | num_workers=8, pin_memory=True, drop_last=True,
237 | batch_size=args.batch_size) for i in config.keys()}
238 |
239 | return dataloaders
240 |
241 |
242 | def stl10_dataloader(args, dataset_paths):
243 | '''
244 |     Loads the STL10 dataset performing augmentations.
245 |
246 | Generates splits of the training set to produce a validation set.
247 |
248 | args:
249 | args (dict): Program/commandline arguments.
250 |
251 |         dataset_paths (dict): Paths to each dataset split.
252 |
253 | Returns:
254 |
255 | dataloaders (): pretrain,train,valid,train_valid,test set split dataloaders.
256 | '''
257 |
258 |     # gaussian_blur from https://github.com/facebookresearch/moco/
259 |     gaussian_blur = transforms.RandomApply([GaussianBlur(args.blur_sigma)], p=args.blur_p)
260 |
261 | color_jitter = transforms.ColorJitter(
262 | 0.8*args.jitter_d, 0.8*args.jitter_d, 0.8*args.jitter_d, 0.2*args.jitter_d)
263 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=args.jitter_p)
264 |
265 | rnd_grey = transforms.RandomGrayscale(p=args.grey_p)
266 |
267 |     # Base train and test augmentations
268 | transf = {
269 | 'train': transforms.Compose([
270 | transforms.ToPILImage(),
271 | rnd_color_jitter,
272 | rnd_grey,
273 |             gaussian_blur,
274 | transforms.RandomResizedCrop((args.crop_dim, args.crop_dim)),
275 | transforms.RandomHorizontalFlip(),
276 | transforms.ToTensor(),
277 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
278 | (0.24703223, 0.24348513, 0.26158784))]),
279 | 'valid': transforms.Compose([
280 | transforms.ToPILImage(),
281 | transforms.ToTensor(),
282 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
283 | (0.24703223, 0.24348513, 0.26158784))]),
284 | 'test': transforms.Compose([
285 | transforms.ToTensor(),
286 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
287 | (0.24703223, 0.24348513, 0.26158784))])
288 | }
289 |
290 | transf['pretrain'] = transf['train']
291 |
292 | config = {'train': 'train', 'test': 'test', 'pretrain': 'unlabeled'}
293 |
294 | datasets = {i: STL10(root=dataset_paths[i], transform=transf[i],
295 | split=config[i], download=True) for i in config.keys()}
296 |
297 | # weighted sampler weights for full(f) training set
298 | f_s_weights = sample_weights(datasets['train'].labels)
299 |
300 | # return data, labels dicts for new train set and class-balanced valid set
301 |     # 50 samples per class are split into the val set (10 classes)
302 | data, labels = random_split(data=datasets['train'].data,
303 | labels=datasets['train'].labels,
304 | n_classes=args.n_classes,
305 | n_samples_per_class=np.repeat(50, args.n_classes).reshape(-1))
306 |
307 | # save original full training set
308 | datasets['train_valid'] = datasets['train']
309 |
310 | # make new pretraining set without validation samples
311 | datasets['pretrain'] = CustomDataset(data=datasets['pretrain'].data,
312 | labels=None, transform=transf['pretrain'], two_crop=args.twocrop)
313 |
314 | # make new finetuning set without validation samples
315 | datasets['train'] = CustomDataset(data=data['train'],
316 | labels=labels['train'], transform=transf['train'], two_crop=False)
317 |
318 | # make class balanced validation set for finetuning
319 | datasets['valid'] = CustomDataset(data=data['valid'],
320 | labels=labels['valid'], transform=transf['valid'], two_crop=False)
321 |
322 | # weighted sampler weights for new training set
323 | s_weights = sample_weights(datasets['train'].labels)
324 |
325 | config = {
326 | 'pretrain': None,
327 | 'train': WeightedRandomSampler(s_weights,
328 | num_samples=len(s_weights), replacement=True),
329 | 'train_valid': WeightedRandomSampler(f_s_weights,
330 | num_samples=len(f_s_weights), replacement=True),
331 | 'valid': None, 'test': None
332 | }
333 |
334 | if args.distributed:
335 | config = {'pretrain': DistributedSampler(datasets['pretrain']),
336 | 'train': DistributedSampler(datasets['train']),
337 | 'train_valid': DistributedSampler(datasets['train_valid']),
338 | 'valid': None, 'test': None}
339 |
340 | dataloaders = {i: DataLoader(datasets[i], sampler=config[i],
341 | num_workers=8, pin_memory=True, drop_last=True,
342 | batch_size=args.batch_size) for i in config.keys()}
343 |
344 | return dataloaders
345 |
346 |
347 | def cifar_dataloader(args, dataset_paths):
348 | '''
349 |     Loads the CIFAR10 or CIFAR100 dataset performing augmentations.
350 |
351 | Generates splits of the training set to produce a validation set.
352 |
353 | args:
354 | args (dict): Program/commandline arguments.
355 |
356 |         dataset_paths (dict): Paths to each dataset split.
357 |
358 | Returns:
359 |
360 | dataloaders (): pretrain,train,valid,train_valid,test set split dataloaders.
361 | '''
362 |
363 | color_jitter = transforms.ColorJitter(
364 | 0.8*args.jitter_d, 0.8*args.jitter_d, 0.8*args.jitter_d, 0.2*args.jitter_d)
365 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=args.jitter_p)
366 |
367 | rnd_grey = transforms.RandomGrayscale(p=args.grey_p)
368 |
369 |     # Base train and test augmentations
370 | transf = {
371 | 'train': transforms.Compose([
372 | transforms.ToPILImage(),
373 | rnd_color_jitter,
374 | rnd_grey,
375 | transforms.RandomResizedCrop((args.crop_dim, args.crop_dim), scale=(0.25, 1.0)),
376 | transforms.RandomHorizontalFlip(),
377 | transforms.ToTensor(),
378 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
379 | (0.24703223, 0.24348513, 0.26158784))]),
380 | 'pretrain': transforms.Compose([
381 | transforms.ToPILImage(),
382 | rnd_color_jitter,
383 | rnd_grey,
384 | transforms.RandomResizedCrop((args.crop_dim, args.crop_dim)),
385 | transforms.RandomHorizontalFlip(),
386 | transforms.ToTensor(),
387 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
388 | (0.24703223, 0.24348513, 0.26158784))]),
389 | 'test': transforms.Compose([
390 | transforms.ToTensor(),
391 | transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
392 | (0.24703223, 0.24348513, 0.26158784))])
393 | }
394 |
395 | config = {'train': True, 'test': False}
396 |
397 | if args.dataset == 'cifar10':
398 |
399 | datasets = {i: CIFAR10(root=dataset_paths[i], transform=transf[i],
400 | train=config[i], download=True) for i in config.keys()}
401 | val_samples = 500
402 |
403 | elif args.dataset == 'cifar100':
404 |
405 | datasets = {i: CIFAR100(root=dataset_paths[i], transform=transf[i],
406 | train=config[i], download=True) for i in config.keys()}
407 |
408 | val_samples = 100
409 |
410 | # weighted sampler weights for full(f) training set
411 | f_s_weights = sample_weights(datasets['train'].targets)
412 |
413 | # return data, labels dicts for new train set and class-balanced valid set
414 |     # val_samples per class are split into the val set (500 for CIFAR10, 100 for CIFAR100)
415 | data, labels = random_split(data=datasets['train'].data,
416 | labels=datasets['train'].targets,
417 | n_classes=args.n_classes,
418 | n_samples_per_class=np.repeat(val_samples, args.n_classes).reshape(-1))
419 |
420 | # save original full training set
421 | datasets['train_valid'] = datasets['train']
422 |
423 | # make new pretraining set without validation samples
424 | datasets['pretrain'] = CustomDataset(data=data['train'],
425 | labels=labels['train'], transform=transf['pretrain'], two_crop=args.twocrop)
426 |
427 | # make new finetuning set without validation samples
428 | datasets['train'] = CustomDataset(data=data['train'],
429 | labels=labels['train'], transform=transf['train'], two_crop=False)
430 |
431 | # make class balanced validation set for finetuning
432 | datasets['valid'] = CustomDataset(data=data['valid'],
433 | labels=labels['valid'], transform=transf['test'], two_crop=False)
434 |
435 | # weighted sampler weights for new training set
436 | s_weights = sample_weights(datasets['pretrain'].labels)
437 |
438 | config = {
439 | 'pretrain': WeightedRandomSampler(s_weights,
440 | num_samples=len(s_weights), replacement=True),
441 | 'train': WeightedRandomSampler(s_weights,
442 | num_samples=len(s_weights), replacement=True),
443 | 'train_valid': WeightedRandomSampler(f_s_weights,
444 | num_samples=len(f_s_weights), replacement=True),
445 | 'valid': None, 'test': None
446 | }
447 |
448 | if args.distributed:
449 | config = {'pretrain': DistributedSampler(datasets['pretrain']),
450 | 'train': DistributedSampler(datasets['train']),
451 | 'train_valid': DistributedSampler(datasets['train_valid']),
452 | 'valid': None, 'test': None}
453 |
454 | dataloaders = {i: DataLoader(datasets[i], sampler=config[i],
455 | num_workers=8, pin_memory=True, drop_last=True,
456 | batch_size=args.batch_size) for i in config.keys()}
457 |
458 | return dataloaders
459 |
--------------------------------------------------------------------------------
/src/kill_zombie.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | kill $(ps aux | grep "main.py" | grep -v grep | awk '{print $2}')
4 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import logging
5 | import random
6 | import configargparse
7 | import warnings
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn.parallel import DistributedDataParallel
13 |
14 | from train import finetune, evaluate, pretrain, supervised
15 | from datasets import get_dataloaders
16 | from utils import *
17 | import model.network as models
18 | from model.moco import MoCo_Model
19 |
20 |
21 | warnings.filterwarnings("ignore")
22 |
23 |
24 | default_config = os.path.join(os.path.split(os.getcwd())[0], 'config.conf')
25 |
26 | parser = configargparse.ArgumentParser(
27 | description='Pytorch MocoV2', default_config_files=[default_config])
28 | parser.add_argument('-c', '--my-config', required=False,
29 | is_config_file=True, help='config file path')
30 | parser.add_argument('--dataset', default='cifar10',
31 | help='Dataset, (Options: cifar10, cifar100, stl10, imagenet, tinyimagenet).')
32 | parser.add_argument('--dataset_path', default=None,
33 | help='Path to dataset, Not needed for TorchVision Datasets.')
34 | parser.add_argument('--model', default='resnet18',
35 | help='Model, (Options: resnet18, resnet34, resnet50, resnet101, resnet152).')
36 | parser.add_argument('--n_epochs', type=int, default=1000,
37 | help='Number of Epochs in Contrastive Training.')
38 | parser.add_argument('--finetune_epochs', type=int, default=100,
39 | help='Number of Epochs in Linear Classification Training.')
40 | parser.add_argument('--warmup_epochs', type=int, default=10,
41 | help='Number of Warmup Epochs During Contrastive Training.')
42 | parser.add_argument('--batch_size', type=int, default=256,
43 | help='Number of Samples Per Batch.')
44 | parser.add_argument('--learning_rate', type=float, default=1.0,
45 |                     help='Starting Learning Rate for Contrastive Training.')
46 | parser.add_argument('--base_lr', type=float, default=0.0001,
47 |                     help='Base / Minimum Learning Rate to Begin Linear Warmup.')
48 | parser.add_argument('--finetune_learning_rate', type=float, default=0.1,
49 |                     help='Starting Learning Rate for Linear Classification Training.')
50 | parser.add_argument('--weight_decay', type=float, default=1e-6,
51 | help='Contrastive Learning Weight Decay Regularisation Factor.')
52 | parser.add_argument('--finetune_weight_decay', type=float, default=0.0,
53 | help='Linear Classification Training Weight Decay Regularisation Factor.')
54 | parser.add_argument('--optimiser', default='sgd',
55 | help='Optimiser, (Options: sgd, adam, lars).')
56 | parser.add_argument('--patience', default=50, type=int,
57 | help='Number of Epochs to Wait for Improvement.')
58 | parser.add_argument('--queue_size', type=int, default=65536,
59 | help='Size of Memory Queue, Must be Divisible by batch_size.')
60 | parser.add_argument('--queue_momentum', type=float, default=0.999,
61 | help='Momentum for the Key Encoder Update.')
62 | parser.add_argument('--temperature', type=float, default=0.07,
63 | help='InfoNCE Temperature Factor')
64 | parser.add_argument('--jitter_d', type=float, default=1.0,
65 | help='Distortion Factor for the Random Colour Jitter Augmentation')
66 | parser.add_argument('--jitter_p', type=float, default=0.8,
67 | help='Probability to Apply Random Colour Jitter Augmentation')
68 | parser.add_argument('--blur_sigma', nargs=2, type=float, default=[0.1, 2.0],
69 |                     help='Sigma Range for the Gaussian Blur Augmentation')
70 | parser.add_argument('--blur_p', type=float, default=0.5,
71 | help='Probability to Apply Gaussian Blur Augmentation')
72 | parser.add_argument('--grey_p', type=float, default=0.2,
73 | help='Probability to Apply Random Grey Scale')
74 | parser.add_argument('--no_twocrop', dest='twocrop', action='store_false',
75 | help='Whether or Not to Use Two Crop Augmentation, Used to Create Two Views of the Input for Contrastive Learning. (Default: True)')
76 | parser.set_defaults(twocrop=True)
77 | parser.add_argument('--load_checkpoint_dir', default=None,
78 | help='Path to Load Pre-trained Model From.')
79 | parser.add_argument('--no_distributed', dest='distributed', action='store_false',
80 | help='Whether or Not to Use Distributed Training. (Default: True)')
81 | parser.set_defaults(distributed=True)
82 | parser.add_argument('--finetune', dest='finetune', action='store_true',
83 | help='Perform Only Linear Classification Training. (Default: False)')
84 | parser.set_defaults(finetune=False)
85 | parser.add_argument('--supervised', dest='supervised', action='store_true',
86 | help='Perform Supervised Pre-Training. (Default: False)')
87 | parser.set_defaults(supervised=False)
88 |
89 |
90 | def setup(distributed):
91 | """ Sets up for optional distributed training.
92 | For distributed training run as:
93 | python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=2 --use_env main.py
94 | To kill zombie processes use:
95 | kill $(ps aux | grep "main.py" | grep -v grep | awk '{print $2}')
96 | For data parallel training on GPUs or CPU training run as:
97 | python main.py --no_distributed
98 |
99 | Taken from https://github.com/fabio-deep/Distributed-Pytorch-Boilerplate
100 |
101 | args:
102 | distributed (bool): Flag whether or not to perform distributed training.
103 |
104 | returns:
105 |         device (string): Device and rank of the device to perform training on.
106 |
107 |         local_rank (int): Rank of the local process on the node, used for distributed training.
108 |
109 | """
110 | if distributed:
111 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
112 | local_rank = int(os.environ.get('LOCAL_RANK'))
113 | device = torch.device(f'cuda:{local_rank}') # unique on individual node
114 |
115 | print('World size: {} ; Rank: {} ; LocalRank: {} ; Master: {}:{}'.format(
116 | os.environ.get('WORLD_SIZE'),
117 | os.environ.get('RANK'),
118 | os.environ.get('LOCAL_RANK'),
119 | os.environ.get('MASTER_ADDR'), os.environ.get('MASTER_PORT')))
120 | else:
121 | local_rank = None
122 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
123 |
124 | seed = 44
125 | random.seed(seed)
126 | np.random.seed(seed)
127 | torch.manual_seed(seed)
128 | torch.cuda.manual_seed(seed)
129 | torch.cuda.manual_seed_all(seed)
130 |
131 | torch.backends.cudnn.enabled = True
132 | torch.backends.cudnn.deterministic = True
133 |     torch.backends.cudnn.benchmark = False  # set True for speed at the cost of reproducibility
134 |
135 | return device, local_rank
136 |
137 |
138 | def main():
139 | """ Main """
140 |
141 | # Arguments
142 | args = parser.parse_args()
143 |
144 | # Setup Distributed Training
145 | device, local_rank = setup(distributed=args.distributed)
146 |
147 | # Get Dataloaders for Dataset of choice
148 | dataloaders, args = get_dataloaders(args)
149 |
150 | # Setup logging, saving models, summaries
151 | args = experiment_config(parser, args)
152 |
153 | ''' Base Encoder '''
154 |
155 | # Get available models from /model/network.py
156 | model_names = sorted(name for name in models.__dict__
157 | if name.islower() and not name.startswith("__")
158 | and callable(models.__dict__[name]))
159 |
160 | # If model exists
161 | if any(args.model in model_name for model_name in model_names):
162 | # Load model
163 | base_encoder = getattr(models, args.model)(
164 | args, num_classes=args.n_classes) # Encoder
165 |
166 | else:
167 | raise NotImplementedError("Model Not Implemented: {}".format(args.model))
168 |
169 | if not args.supervised:
170 | # freeze all layers but the last fc
171 | for name, param in base_encoder.named_parameters():
172 | if name not in ['fc.weight', 'fc.bias']:
173 | param.requires_grad = False
174 | # init the fc layer
175 | init_weights(base_encoder)
176 |
177 | ''' MoCo Model '''
178 | moco = MoCo_Model(args, queue_size=args.queue_size,
179 | momentum=args.queue_momentum, temperature=args.temperature)
180 |
181 | # Place model onto GPU(s)
182 | if args.distributed:
183 | torch.cuda.set_device(device)
184 | torch.set_num_threads(6) # n cpu threads / n processes per node
185 |
186 | moco = DistributedDataParallel(moco.cuda(),
187 | device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True, broadcast_buffers=False)
188 | base_encoder = DistributedDataParallel(base_encoder.cuda(),
189 | device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True, broadcast_buffers=False)
190 |
191 | # Only print from process (rank) 0
192 |         args.print_progress = int(os.environ.get('RANK')) == 0
193 | else:
194 | # If non Distributed use DataParallel
195 | if torch.cuda.device_count() > 1:
196 | moco = nn.DataParallel(moco)
197 | base_encoder = nn.DataParallel(base_encoder)
198 |
199 | print('\nUsing', torch.cuda.device_count(), 'GPU(s).\n')
200 |
201 | moco.to(device)
202 | base_encoder.to(device)
203 |
204 | args.print_progress = True
205 |
206 | # Print Network Structure and Params
207 | if args.print_progress:
208 | print_network(moco, args) # prints out the network architecture etc
209 | logging.info('\npretrain/train: {} - valid: {} - test: {}'.format(
210 | len(dataloaders['train'].dataset), len(dataloaders['valid'].dataset),
211 | len(dataloaders['test'].dataset)))
212 |
213 | # launch model training or inference
214 | if not args.finetune:
215 |
216 | ''' Pretraining / Finetuning / Evaluate '''
217 |
218 | if not args.supervised:
219 | # Pretrain the encoder and projection head
220 | pretrain(moco, dataloaders, args)
221 |
222 | # Load the state_dict from query encoder and load it on finetune net
223 | base_encoder = load_moco(base_encoder, args)
224 |
225 | else:
226 | supervised(base_encoder, dataloaders, args)
227 |
228 | # Load the state_dict from query encoder and load it on finetune net
229 | base_encoder = load_sup(base_encoder, args)
230 |
231 | # Supervised Finetuning of the supervised classification head
232 | finetune(base_encoder, dataloaders, args)
233 |
234 | # Evaluate the pretrained model and trained supervised head
235 | test_loss, test_acc, test_acc_top5 = evaluate(
236 | base_encoder, dataloaders, 'test', args.finetune_epochs, args)
237 |
238 | print('[Test] loss {:.4f} - acc {:.4f} - acc_top5 {:.4f}'.format(
239 | test_loss, test_acc, test_acc_top5))
240 |
241 | if args.distributed: # cleanup
242 | torch.distributed.destroy_process_group()
243 | else:
244 |
245 | ''' Finetuning / Evaluate '''
246 |
247 | # Do not Pretrain, just finetune and inference
248 | # Load the state_dict from query encoder and load it on finetune net
249 | base_encoder = load_moco(base_encoder, args)
250 |
251 | # Supervised Finetuning of the supervised classification head
252 | finetune(base_encoder, dataloaders, args)
253 |
254 | # Evaluate the pretrained model and trained supervised head
255 | test_loss, test_acc, test_acc_top5 = evaluate(
256 | base_encoder, dataloaders, 'test', args.finetune_epochs, args)
257 |
258 | print('[Test] loss {:.4f} - acc {:.4f} - acc_top5 {:.4f}'.format(
259 | test_loss, test_acc, test_acc_top5))
260 |
261 | if args.distributed: # cleanup
262 | torch.distributed.destroy_process_group()
263 |
264 |
265 | if __name__ == '__main__':
266 | main()
267 |
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AidenDurrant/MoCo-Pytorch/373b5fafdbf51e4c6c19a984b7e8d41a296cf6d9/src/model/__init__.py
--------------------------------------------------------------------------------
/src/model/moco.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import model.network as models
7 |
8 |
9 | class MoCo_Model(nn.Module):
10 | def __init__(self, args, queue_size=65536, momentum=0.999, temperature=0.07):
11 | '''
12 | MoCoV2 model, taken from: https://github.com/facebookresearch/moco.
13 |
14 | Adapted for use in personal Boilerplate for unsupervised/self-supervised contrastive learning.
15 |
16 |         Additionally, takes inspiration from: https://github.com/HobbitLong/CMC.
17 |
18 | Args:
19 | init:
20 | args (dict): Program arguments/commandline arguments.
21 |
22 | queue_size (int): Length of the queue/memory, number of samples to store in memory. (default: 65536)
23 |
24 | momentum (float): Momentum value for updating the key_encoder. (default: 0.999)
25 |
26 | temperature (float): Temperature used in the InfoNCE / NT_Xent contrastive losses. (default: 0.07)
27 |
28 | forward:
29 |             x_q (Tensor): Representation of the view intended for the query_encoder.
30 |
31 |             x_k (Tensor): Representation of the view intended for the key_encoder.
32 |
33 | returns:
34 |
35 |             logit (Tensor): Positive and negative logits computed as in the InfoNCE loss. (bsz, queue_size + 1)
36 |
37 |             label (Tensor): Labels of the positive and negative logits to be used in softmax cross entropy. (bsz)
38 |
39 | '''
40 | super(MoCo_Model, self).__init__()
41 |
42 | self.queue_size = queue_size
43 | self.momentum = momentum
44 | self.temperature = temperature
45 |
46 | assert self.queue_size % args.batch_size == 0 # for simplicity
47 |
48 | # Load model
49 | self.encoder_q = getattr(models, args.model)(
50 | args, num_classes=128) # Query Encoder
51 |
52 | self.encoder_k = getattr(models, args.model)(
53 | args, num_classes=128) # Key Encoder
54 |
55 | # Add the mlp head
56 | self.encoder_q.fc = models.projection_MLP(args)
57 | self.encoder_k.fc = models.projection_MLP(args)
58 |
59 | # Initialize the key encoder to have the same values as query encoder
60 | # Do not update the key encoder via gradient
61 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
62 | param_k.data.copy_(param_q.data)
63 | param_k.requires_grad = False
64 |
65 | # Create the queue to store negative samples
66 | self.register_buffer("queue", torch.randn(self.queue_size, 128))
67 |
68 | # Create pointer to store current position in the queue when enqueue and dequeue
69 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
70 |
71 | @torch.no_grad()
72 | def momentum_update(self):
73 | '''
74 | Update the key_encoder parameters through the momentum update:
75 |
76 |
77 | key_params = momentum * key_params + (1 - momentum) * query_params
78 |
79 | '''
80 |
81 | # For each of the parameters in each encoder
82 | for p_q, p_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
83 | p_k.data = p_k.data * self.momentum + p_q.detach().data * (1. - self.momentum)
84 |
85 | @torch.no_grad()
86 | def shuffled_idx(self, batch_size):
87 | '''
88 | Generation of the shuffled indexes for the implementation of ShuffleBN.
89 |
90 | https://github.com/HobbitLong/CMC.
91 |
92 | args:
93 | batch_size (Tensor.int()): Number of samples in a batch
94 |
95 | returns:
96 | shuffled_idxs (Tensor.long()): A random permutation index order for the shuffling of the current minibatch
97 |
98 | reverse_idxs (Tensor.long()): A reverse of the random permutation index order for the shuffling of the
99 | current minibatch to get back original sample order
100 |
101 | '''
102 |
103 | # Generate shuffled indexes
104 | shuffled_idxs = torch.randperm(batch_size).long().cuda()
105 |
106 | reverse_idxs = torch.zeros(batch_size).long().cuda()
107 |
108 | value = torch.arange(batch_size).long().cuda()
109 |
110 | reverse_idxs.index_copy_(0, shuffled_idxs, value)
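        # Invariant: x[shuffled_idxs][reverse_idxs] == x, i.e. reverse_idxs undoes the shuffle.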
111 |
112 | return shuffled_idxs, reverse_idxs
113 |
114 | @torch.no_grad()
115 | def update_queue(self, feat_k):
116 | '''
117 | Update the memory / queue.
118 |
119 | Add batch to end of most recent sample index and remove the oldest samples in the queue.
120 |
121 | Store location of most recent sample index (ptr).
122 |
123 | Taken from: https://github.com/facebookresearch/moco
124 |
125 | args:
126 |             feat_k (Tensor): Feature representations of the view x_k computed by the key_encoder.
127 | '''
128 |
129 | batch_size = feat_k.size(0)
130 |
131 | ptr = int(self.queue_ptr)
132 |
133 | # replace the keys at ptr (dequeue and enqueue)
134 | self.queue[ptr:ptr + batch_size, :] = feat_k
135 |
136 | # move pointer along to end of current batch
137 | ptr = (ptr + batch_size) % self.queue_size
138 |
139 | # Store queue pointer as register_buffer
140 | self.queue_ptr[0] = ptr
141 |
142 | def InfoNCE_logits(self, f_q, f_k):
143 | '''
144 |         Compute the similarity logits between the positive pair
145 |         and between the query and all negatives in the memory.
146 |
147 | args:
148 |             f_q (Tensor): Feature representations of the view x_q computed by the query_encoder.
149 |
150 |             f_k (Tensor): Feature representations of the view x_k computed by the key_encoder.
151 |
152 |         returns:
153 |             logit (Tensor): Positive and negative logits computed as in the InfoNCE loss. (bsz, queue_size + 1)
154 |
155 |             label (Tensor): Labels of the positive and negative logits to be used in softmax cross entropy. (bsz)
156 | '''
157 |
158 | f_k = f_k.detach()
159 |
160 | # Get queue from register_buffer
161 | f_mem = self.queue.clone().detach()
162 |
163 | # Normalize the feature representations
164 | f_q = nn.functional.normalize(f_q, dim=1)
165 | f_k = nn.functional.normalize(f_k, dim=1)
166 | f_mem = nn.functional.normalize(f_mem, dim=1)
167 |
168 | # Compute sim between positive views
169 | pos = torch.bmm(f_q.view(f_q.size(0), 1, -1),
170 | f_k.view(f_k.size(0), -1, 1)).squeeze(-1)
171 |
172 |         # Compute sim between positive and all negatives in the memory
173 | neg = torch.mm(f_q, f_mem.transpose(1, 0))
174 |
175 | logits = torch.cat((pos, neg), dim=1)
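        # logits: (batch_size, 1 + queue_size); column 0 is the positive-pair logit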
176 |
177 | logits /= self.temperature
178 |
179 |         # Create labels; the first logit is positive, all others are negative
180 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
181 |
182 | return logits, labels
183 |
184 | def forward(self, x_q, x_k):
185 |
186 | batch_size = x_q.size(0)
187 |
188 | # Feature representations of the query view from the query encoder
189 | feat_q = self.encoder_q(x_q)
190 |
191 | # TODO: shuffle ids with distributed data parallel
192 | # Get shuffled and reversed indexes for the current minibatch
193 | shuffled_idxs, reverse_idxs = self.shuffled_idx(batch_size)
194 |
195 | with torch.no_grad():
196 | # Update the key encoder
197 | self.momentum_update()
198 |
199 | # Shuffle minibatch
200 | x_k = x_k[shuffled_idxs]
201 |
202 | # Feature representations of the shuffled key view from the key encoder
203 | feat_k = self.encoder_k(x_k)
204 |
205 | # reverse the shuffled samples to original position
206 | feat_k = feat_k[reverse_idxs]
207 |
208 | # Compute the logits for the InfoNCE contrastive loss.
209 | logit, label = self.InfoNCE_logits(feat_q, feat_k)
210 |
211 | # Update the queue/memory with the current key_encoder minibatch.
212 | self.update_queue(feat_k)
213 |
214 | return logit, label
215 |
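# Usage sketch (the actual training loop lives in src/train.py; names are illustrative):
#
#     logit, label = moco(x_q, x_k)           # logit: (bsz, 1 + queue_size)
#     loss = F.cross_entropy(logit, label)    # InfoNCE as softmax cross entropy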
--------------------------------------------------------------------------------
/src/model/network.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 |
5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
7 | 'wide_resnet50_2', 'wide_resnet101_2', 'projection_MLP']
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
19 | }
20 |
21 | '''Resnet Class
22 |
23 | Taken from:
24 |
25 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
26 | '''
27 |
28 |
29 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
30 | """3x3 convolution with padding"""
31 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
32 | padding=dilation, groups=groups, bias=False, dilation=dilation)
33 |
34 |
35 | def conv1x1(in_planes, out_planes, stride=1):
36 | """1x1 convolution"""
37 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
38 |
39 |
40 | class BasicBlock(nn.Module):
41 | expansion = 1
42 |
43 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
44 | base_width=64, dilation=1, norm_layer=None):
45 | super(BasicBlock, self).__init__()
46 | if norm_layer is None:
47 | norm_layer = nn.BatchNorm2d
48 | if groups != 1 or base_width != 64:
49 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
50 | if dilation > 1:
51 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
52 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
53 | self.conv1 = conv3x3(inplanes, planes, stride)
54 | self.bn1 = norm_layer(planes)
55 | self.relu = nn.ReLU(inplace=True)
56 | self.conv2 = conv3x3(planes, planes)
57 | self.bn2 = norm_layer(planes)
58 | self.downsample = downsample
59 | self.stride = stride
60 |
61 | def forward(self, x):
62 | identity = x
63 |
64 | out = self.conv1(x)
65 | out = self.bn1(out)
66 | out = self.relu(out)
67 |
68 | out = self.conv2(out)
69 | out = self.bn2(out)
70 |
71 | if self.downsample is not None:
72 | identity = self.downsample(x)
73 |
74 | out += identity
75 | out = self.relu(out)
76 |
77 | return out
78 |
79 |
80 | class Bottleneck(nn.Module):
81 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
82 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
83 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
84 | # This variant is also known as ResNet V1.5 and improves accuracy according to
85 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
86 |
87 | expansion = 4
88 |
89 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
90 | base_width=64, dilation=1, norm_layer=None):
91 | super(Bottleneck, self).__init__()
92 | if norm_layer is None:
93 | norm_layer = nn.BatchNorm2d
94 | width = int(planes * (base_width / 64.)) * groups
95 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
96 | self.conv1 = conv1x1(inplanes, width)
97 | self.bn1 = norm_layer(width)
98 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
99 | self.bn2 = norm_layer(width)
100 | self.conv3 = conv1x1(width, planes * self.expansion)
101 | self.bn3 = norm_layer(planes * self.expansion)
102 | self.relu = nn.ReLU(inplace=True)
103 | self.downsample = downsample
104 | self.stride = stride
105 |
106 | def forward(self, x):
107 | identity = x
108 |
109 | out = self.conv1(x)
110 | out = self.bn1(out)
111 | out = self.relu(out)
112 |
113 | out = self.conv2(out)
114 | out = self.bn2(out)
115 | out = self.relu(out)
116 |
117 | out = self.conv3(out)
118 | out = self.bn3(out)
119 |
120 | if self.downsample is not None:
121 | identity = self.downsample(x)
122 |
123 | out += identity
124 | out = self.relu(out)
125 |
126 | return out
127 |
128 |
129 | class ResNet(nn.Module):
130 | def __init__(self, block, layers, args, num_classes=1000, zero_init_residual=False,
131 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
132 | norm_layer=None):
133 | super(ResNet, self).__init__()
134 | if norm_layer is None:
135 | norm_layer = nn.BatchNorm2d
136 | self._norm_layer = norm_layer
137 |
138 | self.inplanes = 64
139 | self.dilation = 1
140 | if replace_stride_with_dilation is None:
141 | # each element in the tuple indicates if we should replace
142 | # the 2x2 stride with a dilated convolution instead
143 | replace_stride_with_dilation = [False, False, False]
144 | if len(replace_stride_with_dilation) != 3:
145 | raise ValueError("replace_stride_with_dilation should be None "
146 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
147 | self.groups = groups
148 | self.base_width = width_per_group
149 |
150 | # Different model for smaller image size
151 | if args.dataset == 'cifar10' or args.dataset == 'cifar100':
152 |
153 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
154 | bias=False) # For CIFAR
155 |
156 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=0) # For CIFAR
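            # For 32x32 inputs the 3x3 stride-1 stem preserves resolution; this stride-1 max-pool only trims the border (32 -> 30).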
157 |
158 | # e.g. ImageNet
159 | else:
160 |
161 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
162 | bias=False)
163 |
164 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
165 |
166 | self.bn1 = norm_layer(self.inplanes)
167 | self.relu = nn.ReLU(inplace=True)
168 |
169 | self.layer1 = self._make_layer(block, 64, layers[0])
170 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
171 | dilate=replace_stride_with_dilation[0])
172 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
173 | dilate=replace_stride_with_dilation[1])
174 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
175 | dilate=replace_stride_with_dilation[2])
176 |
177 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
178 |
179 | self.fc = nn.Linear(512 * block.expansion, num_classes)
180 |
181 | for m in self.modules():
182 | if isinstance(m, nn.Conv2d):
183 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
184 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
185 | nn.init.constant_(m.weight, 1)
186 | nn.init.constant_(m.bias, 0)
187 |
188 | # Zero-initialize the last BN in each residual branch,
189 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
190 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
191 | if zero_init_residual:
192 | for m in self.modules():
193 | if isinstance(m, Bottleneck):
194 | nn.init.constant_(m.bn3.weight, 0)
195 | elif isinstance(m, BasicBlock):
196 | nn.init.constant_(m.bn2.weight, 0)
197 |
198 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
199 | norm_layer = self._norm_layer
200 | downsample = None
201 | previous_dilation = self.dilation
202 | if dilate:
203 | self.dilation *= stride
204 | stride = 1
205 | if stride != 1 or self.inplanes != planes * block.expansion:
206 | downsample = nn.Sequential(
207 | conv1x1(self.inplanes, planes * block.expansion, stride),
208 | norm_layer(planes * block.expansion),
209 | )
210 |
211 | layers = []
212 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
213 | self.base_width, previous_dilation, norm_layer))
214 | self.inplanes = planes * block.expansion
215 | for _ in range(1, blocks):
216 | layers.append(block(self.inplanes, planes, groups=self.groups,
217 | base_width=self.base_width, dilation=self.dilation,
218 | norm_layer=norm_layer))
219 |
220 | return nn.Sequential(*layers)
221 |
222 | def forward(self, x):
223 | x = self.conv1(x)
224 | x = self.bn1(x)
225 | x = self.relu(x)
226 |
227 | x = self.maxpool(x)
228 |
229 | x = self.layer1(x)
230 | x = self.layer2(x)
231 | x = self.layer3(x)
232 | x = self.layer4(x)
233 |
234 | x = self.avgpool(x)
235 |
236 | x = torch.flatten(x, 1)
237 |
238 | x = self.fc(x)
239 |
240 | return x
241 |
242 |
243 | def resnet18(args, **kwargs):
244 | r"""ResNet-18 model from
245 | `"Deep Residual Learning for Image Recognition" `_
246 | Args:
247 | args: arguments
248 | """
249 | return ResNet(BasicBlock, [2, 2, 2, 2], args, **kwargs)
250 |
251 |
252 | def resnet34(args, **kwargs):
253 | r"""ResNet-34 model from
254 | `"Deep Residual Learning for Image Recognition" `_
255 | Args:
256 | args: arguments
257 | """
258 | return ResNet(BasicBlock, [3, 4, 6, 3], args, **kwargs)
259 |
260 |
261 | def resnet50(args, **kwargs):
262 | r"""ResNet-50 model from
263 | `"Deep Residual Learning for Image Recognition" `_
264 | Args:
265 | args: arguments
266 | """
267 | return ResNet(Bottleneck, [3, 4, 6, 3], args, **kwargs)
268 |
269 |
270 | def resnet101(args, **kwargs):
271 | r"""ResNet-101 model from
272 | `"Deep Residual Learning for Image Recognition" `_
273 | Args:
274 | args: arguments
275 | """
276 | return ResNet(Bottleneck, [3, 4, 23, 3], args, **kwargs)
277 |
278 |
279 | def resnet152(args, **kwargs):
280 | r"""ResNet-152 model from
281 | `"Deep Residual Learning for Image Recognition" `_
282 | Args:
283 | args: arguments
284 | """
285 | return ResNet(Bottleneck, [3, 8, 36, 3], args, **kwargs)
286 |
287 |
288 | def resnext50_32x4d(args, **kwargs):
289 |     r"""ResNeXt-50 32x4d model from
290 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
291 |     Args:
292 |         args: arguments
293 |     """
294 |     kwargs['groups'] = 32
295 |     kwargs['width_per_group'] = 4
296 |     return ResNet(Bottleneck, [3, 4, 6, 3], args, **kwargs)
298 |
299 |
300 | def resnext101_32x8d(args, **kwargs):
301 |     r"""ResNeXt-101 32x8d model from
302 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
303 |     Args:
304 |         args: arguments
305 |     """
306 |     kwargs['groups'] = 32
307 |     kwargs['width_per_group'] = 8
308 |     return ResNet(Bottleneck, [3, 4, 23, 3], args, **kwargs)
310 |
311 |
312 | def wide_resnet50_2(args, **kwargs):
313 |     r"""Wide ResNet-50-2 model from
314 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
315 |     The model is the same as ResNet except that the number of bottleneck channels
316 |     is twice as large in every block. The number of channels in the outer 1x1
317 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
318 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
319 |     Args:
320 |         args: arguments
321 |     """
322 |     kwargs['width_per_group'] = 64 * 2
323 |     return ResNet(Bottleneck, [3, 4, 6, 3], args, **kwargs)
325 |
326 |
327 | def wide_resnet101_2(args, **kwargs):
328 |     r"""Wide ResNet-101-2 model from
329 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
330 |     The model is the same as ResNet except that the number of bottleneck channels
331 |     is twice as large in every block. The number of channels in the outer 1x1
332 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
333 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
334 |     Args:
335 |         args: arguments
336 |     """
337 |     kwargs['width_per_group'] = 64 * 2
338 |     return ResNet(Bottleneck, [3, 4, 23, 3], args, **kwargs)
340 |
341 |
342 | ''' SimCLR Projection Head '''
343 |
344 |
345 | class projection_MLP(nn.Module):
346 | def __init__(self, args):
347 | '''Projection head for the pretraining of the resnet encoder.
348 |
349 | - Uses the dataset and model size to determine encoder output
350 | representation dimension.
351 | - Outputs to a dimension of 128, and uses non-linear activation
352 | as described in SimCLR paper: https://arxiv.org/pdf/2002.05709.pdf
353 | '''
354 | super(projection_MLP, self).__init__()
355 |
356 | if args.model == 'resnet18' or args.model == 'resnet34':
357 | n_channels = 512
358 | elif args.model == 'resnet50' or args.model == 'resnet101' or args.model == 'resnet152':
359 | n_channels = 2048
360 | else:
361 | raise NotImplementedError('model not supported: {}'.format(args.model))
362 |
363 | self.projection_head = nn.Sequential()
364 |
365 | self.projection_head.add_module('W1', nn.Linear(
366 | n_channels, n_channels))
367 | self.projection_head.add_module('ReLU', nn.ReLU())
368 | self.projection_head.add_module('W2', nn.Linear(
369 | n_channels, 128))
370 |
371 | def forward(self, x):
372 | return self.projection_head(x)
373 |
--------------------------------------------------------------------------------
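The encoder constructors and `projection_MLP` above are composed by the MoCo wrapper in `src/model/moco.py` (not shown in this excerpt). As a minimal standalone sketch of how the two pieces fit together — the `args` fields, the `num_classes` value, and the `fc`-to-`nn.Identity` swap are illustrative assumptions, not the repo's exact wiring:

```python
import argparse
import torch
import torch.nn as nn

# Hypothetical args namespace; the real program parses these from config.conf.
args = argparse.Namespace(dataset='cifar10', model='resnet18')

encoder = resnet18(args, num_classes=10)
encoder.fc = nn.Identity()       # expose the 512-d pooled features instead of logits
head = projection_MLP(args)      # 512 -> 512 -> ReLU -> 128 for resnet18/34

x = torch.randn(4, 3, 32, 32)    # a CIFAR-sized batch
z = head(encoder(x))             # (4, 128) embeddings for the contrastive loss
```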
/src/optimisers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.optim as optim
4 | from torch.optim.optimizer import Optimizer
5 |
6 |
7 | def get_optimiser(models, mode, args):
8 | '''Get the desired optimiser
9 |
10 | - Selects and initialises an optimiser with model params.
11 |
12 | - if 'LARS' is selected, the 'bn' and 'bias' parameters are removed from
13 | model optimisation, only passing the parameters we want.
14 |
15 | Args:
16 |         models (tuple): models which we want to optimise, (e.g. encoder and projection head)
17 |
18 | mode (string): the mode of training, (i.e. 'pretrain', 'finetune')
19 |
20 | args (Dictionary): Program Arguments
21 |
22 | Returns:
23 | optimiser (torch.optim.optimizer):
24 | '''
25 | params_models = []
26 | reduced_params = []
27 |
28 | removed_params = []
29 |
30 | skip_lists = ['bn', 'bias']
31 |
32 | for m in models:
33 |
34 | m_skip = []
35 | m_noskip = []
36 |
37 | params_models += list(m.parameters())
38 |
39 | for name, param in m.named_parameters():
40 | if (any(skip_name in name for skip_name in skip_lists)):
41 | m_skip.append(param)
42 | else:
43 | m_noskip.append(param)
44 | reduced_params += list(m_noskip)
45 | removed_params += list(m_skip)
46 | # Set hyperparams depending on mode
47 | if mode == 'pretrain':
48 | lr = args.learning_rate
49 | wd = args.weight_decay
50 | else:
51 | lr = args.finetune_learning_rate
52 | wd = args.finetune_weight_decay
53 |
54 | # Select Optimiser
55 | if args.optimiser == 'adam':
56 |
57 | optimiser = optim.Adam(params_models, lr=lr,
58 | weight_decay=wd)
59 |
60 | elif args.optimiser == 'sgd':
61 |
62 | optimiser = optim.SGD(params_models, lr=lr,
63 | weight_decay=wd, momentum=0.9)
64 |
65 | elif args.optimiser == 'lars':
66 |
67 | print("reduced_params len: {}".format(len(reduced_params)))
68 | print("removed_params len: {}".format(len(removed_params)))
69 |
70 | optimiser = LARS(reduced_params+removed_params, lr=lr,
71 | weight_decay=wd, eta=0.001, use_nesterov=False, len_reduced=len(reduced_params))
72 | else:
73 |
74 | raise NotImplementedError('{} not setup.'.format(args.optimiser))
75 |
76 | return optimiser
77 |
78 |
79 | class LARS(Optimizer):
80 | """
81 | Layer-wise adaptive rate scaling
82 |
83 | - Converted from Tensorflow to Pytorch from:
84 |
85 | https://github.com/google-research/simclr/blob/master/lars_optimizer.py
86 |
87 | - Based on:
88 |
89 | https://github.com/noahgolmant/pytorch-lars
90 |
91 | params (iterable): iterable of parameters to optimize or dicts defining
92 | parameter groups
93 | lr (float): base learning rate (\gamma_0)
94 |
95 |     len_reduced (int): number of leading entries in `params` that receive weight decay; the bn/bias params appended after them are skipped
96 |
97 | momentum (float, optional): momentum factor (default: 0.9)
98 |
99 | use_nesterov (bool, optional): flag to use nesterov momentum (default: False)
100 |
101 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
102 | ("\beta")
103 |
104 | eta (float, optional): LARS coefficient (default: 0.001)
105 |
106 | - Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
107 |
108 | - Large Batch Training of Convolutional Networks:
109 | https://arxiv.org/abs/1708.03888
110 |
111 | """
112 |
113 | def __init__(self, params, lr, len_reduced, momentum=0.9, use_nesterov=False, weight_decay=0.0, classic_momentum=True, eta=0.001):
114 |
115 | self.epoch = 0
116 | defaults = dict(
117 | lr=lr,
118 | momentum=momentum,
119 | use_nesterov=use_nesterov,
120 | weight_decay=weight_decay,
121 | classic_momentum=classic_momentum,
122 | eta=eta,
123 | len_reduced=len_reduced
124 | )
125 |
126 | super(LARS, self).__init__(params, defaults)
127 | self.lr = lr
128 | self.momentum = momentum
129 | self.weight_decay = weight_decay
130 | self.use_nesterov = use_nesterov
131 | self.classic_momentum = classic_momentum
132 | self.eta = eta
133 | self.len_reduced = len_reduced
134 |
135 | def step(self, epoch=None, closure=None):
136 |
137 | loss = None
138 |
139 | if closure is not None:
140 | loss = closure()
141 |
142 | if epoch is None:
143 | epoch = self.epoch
144 | self.epoch += 1
145 |
146 | for group in self.param_groups:
147 | weight_decay = group['weight_decay']
148 | momentum = group['momentum']
149 | eta = group['eta']
150 | learning_rate = group['lr']
151 |
152 | # TODO: Hacky
153 | counter = 0
154 | for p in group['params']:
155 | if p.grad is None:
156 | continue
157 |
158 | param = p.data
159 | grad = p.grad.data
160 |
161 | param_state = self.state[p]
162 |
163 | # TODO: This really hacky way needs to be improved.
164 |
165 |                 # Note: the excluded (bn/bias) params sit at the end of the list, so weight decay is skipped for them
166 | if counter < self.len_reduced:
167 | grad += self.weight_decay * param
168 |
169 | # Create parameter for the momentum
170 | if "momentum_var" not in param_state:
171 | next_v = param_state["momentum_var"] = torch.zeros_like(
172 | p.data
173 | )
174 | else:
175 | next_v = param_state["momentum_var"]
176 |
177 | if self.classic_momentum:
178 | trust_ratio = 1.0
179 |
180 | # TODO: implementation of layer adaptation
181 | w_norm = torch.norm(param)
182 | g_norm = torch.norm(grad)
183 |
185 |                     # Guard against zero norms: fall back to a trust ratio of 1.0
186 |                     trust_ratio = torch.where(w_norm.gt(0), torch.where(
187 |                         g_norm.gt(0), (self.eta * w_norm / g_norm), torch.ones_like(w_norm)), torch.ones_like(w_norm)).item()
188 |
189 | scaled_lr = learning_rate * trust_ratio
190 |
191 |                     next_v.mul_(momentum).add_(grad, alpha=scaled_lr)
192 |
193 | if self.use_nesterov:
194 | update = (self.momentum * next_v) + (scaled_lr * grad)
195 | else:
196 | update = next_v
197 |
198 | p.data.add_(-update)
199 |
200 | # Not classic_momentum
201 | else:
202 |
203 | next_v.mul_(momentum).add_(grad)
204 |
205 | if self.use_nesterov:
206 | update = (self.momentum * next_v) + (grad)
207 |
208 | else:
209 | update = next_v
210 |
211 | trust_ratio = 1.0
212 |
213 | # TODO: implementation of layer adaptation
214 | w_norm = torch.norm(param)
215 | v_norm = torch.norm(update)
216 |
218 |                     # Guard against zero norms: fall back to a trust ratio of 1.0
219 |                     trust_ratio = torch.where(w_norm.gt(0), torch.where(
220 |                         v_norm.gt(0), (self.eta * w_norm / v_norm), torch.ones_like(w_norm)), torch.ones_like(w_norm)).item()
221 |
222 | scaled_lr = learning_rate * trust_ratio
223 |
224 | p.data.add_(-scaled_lr * update)
225 |
226 | counter += 1
227 |
228 | return loss
229 |
--------------------------------------------------------------------------------
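For reference, a minimal sketch of driving `LARS` directly, mirroring the bn/bias exclusion in `get_optimiser`: decayed params first, excluded params appended after, with `len_reduced` marking the boundary. `TinyNet` and the hyperparameter values are illustrative assumptions:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.bn = nn.BatchNorm1d(8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.bn(self.fc(x)))

model = TinyNet()

# Split params: weight decay applies only to the first len(decay) entries.
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if any(s in name for s in ['bn', 'bias']) else decay).append(param)

optimiser = LARS(decay + no_decay, lr=0.1, weight_decay=1e-6,
                 eta=0.001, use_nesterov=False, len_reduced=len(decay))

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimiser.step()
```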
/src/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import gc
4 | import logging
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import torch.optim.lr_scheduler as lr_scheduler
12 | import torch.nn.functional as F
13 | from torch.utils.tensorboard import SummaryWriter
14 | from optimisers import get_optimiser
15 | from PIL import Image
16 |
17 |
18 | def pretrain(encoder, dataloaders, args):
19 | ''' Pretrain script - MoCo
20 |
21 | Pretrain the encoder and projection head with a Contrastive InfoNCE Loss.
22 | '''
23 | mode = 'pretrain'
24 |
25 | ''' Optimisers '''
26 | optimiser = get_optimiser((encoder,), mode, args)
27 |
28 | ''' Schedulers '''
29 | # Warmup Scheduler
30 | if args.warmup_epochs > 0:
31 | for param_group in optimiser.param_groups:
32 | param_group['lr'] = args.base_lr
33 |
34 | # Cosine LR Decay after the warmup epochs
35 | lr_decay = lr_scheduler.CosineAnnealingLR(
36 | optimiser, (args.n_epochs-args.warmup_epochs), eta_min=0.0001, last_epoch=-1)
37 | else:
38 | # Cosine LR Decay
39 | lr_decay = lr_scheduler.CosineAnnealingLR(optimiser, args.n_epochs)
40 |
41 | ''' Loss / Criterion '''
42 | criterion = nn.CrossEntropyLoss().cuda()
43 |
44 |     # initialize variables
45 | args.writer = SummaryWriter(args.summaries_dir)
46 | best_valid_loss = np.inf
47 | patience_counter = 0
48 |
49 | ''' Pretrain loop '''
50 | for epoch in range(args.n_epochs):
51 |
52 | # Train models
53 | encoder.train()
54 |
55 | sample_count = 0
56 | run_loss = 0
57 |
58 | # Print setup for distributed only printing on one node.
59 | if args.print_progress:
60 | logging.info('\nEpoch {}/{}:\n'.format(epoch+1, args.n_epochs))
61 | # tqdm for process (rank) 0 only when using distributed training
62 | train_dataloader = tqdm(dataloaders['pretrain'])
63 | else:
64 | train_dataloader = dataloaders['pretrain']
65 |
66 | ''' epoch loop '''
67 | for i, (inputs, _) in enumerate(train_dataloader):
68 |
69 | inputs = inputs.cuda(non_blocking=True)
70 |
71 | # Forward pass
72 | optimiser.zero_grad()
73 |
74 | # retrieve the 2 views
75 | x_i, x_j = torch.split(inputs, [3, 3], dim=1)
76 |
77 | # Get the encoder representation
78 | logit, label = encoder(x_i, x_j)
79 |
80 | loss = criterion(logit, label)
81 |
82 | loss.backward()
83 |
84 | optimiser.step()
85 |
86 | torch.cuda.synchronize()
87 |
88 | sample_count += inputs.size(0)
89 |
90 | run_loss += loss.item()
91 |
92 | epoch_pretrain_loss = run_loss / len(dataloaders['pretrain'])
93 |
94 | ''' Update Schedulers '''
95 | # TODO: Improve / add lr_scheduler for warmup
96 | if args.warmup_epochs > 0 and epoch+1 <= args.warmup_epochs:
97 | wu_lr = (args.learning_rate - args.base_lr) * \
98 | (float(epoch+1) / args.warmup_epochs) + args.base_lr
100 | optimiser.param_groups[0]['lr'] = wu_lr
101 | else:
102 | # After warmup, decay lr with CosineAnnealingLR
103 | lr_decay.step()
104 |
105 | ''' Printing '''
106 | if args.print_progress: # only validate using process 0
107 | logging.info('\n[Train] loss: {:.4f}'.format(epoch_pretrain_loss))
108 |
109 | args.writer.add_scalars('epoch_loss', {'pretrain': epoch_pretrain_loss}, epoch+1)
110 | args.writer.add_scalars('lr', {'pretrain': optimiser.param_groups[0]['lr']}, epoch+1)
111 |
112 | # For the best performing epoch, reset patience and save model,
113 | # else update patience.
114 | if epoch_pretrain_loss <= best_valid_loss:
115 | patience_counter = 0
116 | best_epoch = epoch + 1
117 | best_valid_loss = epoch_pretrain_loss
118 |
119 | # saving using process (rank) 0 only as all processes are in sync
120 |
121 | state = {
122 | #'args': args,
123 | 'moco': encoder.state_dict(),
124 | 'optimiser': optimiser.state_dict(),
125 | 'epoch': epoch,
126 | }
127 |
128 | torch.save(state, args.checkpoint_dir)
129 | else:
130 | patience_counter += 1
131 | if patience_counter == (args.patience - 10):
132 | logging.info('\nPatience counter {}/{}.'.format(
133 | patience_counter, args.patience))
134 | elif patience_counter == args.patience:
135 | logging.info('\nEarly stopping... no improvement after {} Epochs.'.format(
136 | args.patience))
137 | break
138 |
139 | epoch_pretrain_loss = None # reset loss
140 |
141 |         state = None  # drop the checkpoint reference; safe even on epochs where no checkpoint was saved
142 |
143 | torch.cuda.empty_cache()
144 |
145 | gc.collect() # release unreferenced memory
146 |
147 |
148 | def supervised(encoder, dataloaders, args):
149 | ''' Supervised Train script - MoCo
150 |
151 |     Train the encoder and the supervised classification head with a Cross Entropy Loss.
152 | '''
153 |
154 | mode = 'pretrain'
155 |
156 | ''' Optimisers '''
157 |     # Optimise the full encoder and its classification head
158 | optimiser = get_optimiser((encoder,), mode, args)
159 |
160 | ''' Schedulers '''
161 | # Warmup Scheduler
162 | if args.warmup_epochs > 0:
163 | for param_group in optimiser.param_groups:
164 | param_group['lr'] = args.base_lr
165 |
166 | # Cosine LR Decay after the warmup epochs
167 | lr_decay = lr_scheduler.CosineAnnealingLR(
168 | optimiser, (args.n_epochs-args.warmup_epochs), eta_min=0.0001, last_epoch=-1)
169 | else:
170 | # Cosine LR Decay
171 | lr_decay = lr_scheduler.CosineAnnealingLR(optimiser, args.n_epochs)
172 |
173 | ''' Loss / Criterion '''
174 | criterion = torch.nn.CrossEntropyLoss().cuda()
175 |
176 |     # initialize variables
177 | args.writer = SummaryWriter(args.summaries_dir)
178 | best_valid_loss = np.inf
179 | patience_counter = 0
180 |
181 |     ''' Train loop '''
182 | for epoch in range(args.n_epochs):
183 |
184 | # Train models
185 | encoder.train()
186 |
187 | sample_count = 0
188 | run_loss = 0
189 | run_top1 = 0.0
190 | run_top5 = 0.0
191 |
192 | # Print setup for distributed only printing on one node.
193 | if args.print_progress:
194 | logging.info('\nEpoch {}/{}:\n'.format(epoch+1, args.n_epochs))
195 | # tqdm for process (rank) 0 only when using distributed training
196 | train_dataloader = tqdm(dataloaders['train'])
197 | else:
198 | train_dataloader = dataloaders['train']
199 |
200 | ''' epoch loop '''
201 | for i, (inputs, target) in enumerate(train_dataloader):
202 |
203 | inputs = inputs.cuda(non_blocking=True)
204 |
205 | target = target.cuda(non_blocking=True)
206 |
207 | # Forward pass
208 | optimiser.zero_grad()
209 |
210 | output = encoder(inputs)
211 |
212 | loss = criterion(output, target)
213 |
214 | loss.backward()
215 |
216 | optimiser.step()
217 |
218 | torch.cuda.synchronize()
219 |
220 | sample_count += inputs.size(0)
221 |
222 | run_loss += loss.item()
223 |
224 | predicted = output.argmax(1)
225 |
226 | acc = (predicted == target).sum().item() / target.size(0)
227 |
228 | run_top1 += acc
229 |
230 | _, output_topk = output.topk(5, 1, True, True)
231 |
232 | acc_top5 = (output_topk == target.view(-1, 1).expand_as(output_topk)
233 |                         ).sum().item() / target.size(0)  # fraction of the batch with target in top-5
234 |
235 | run_top5 += acc_top5
236 |
237 | epoch_pretrain_loss = run_loss / len(dataloaders['train']) # sample_count
238 |
239 | epoch_pretrain_acc = run_top1 / len(dataloaders['train'])
240 |
241 | epoch_pretrain_acc_top5 = run_top5 / len(dataloaders['train'])
242 |
243 | ''' Update Schedulers '''
244 | # TODO: Improve / add lr_scheduler for warmup
245 | if args.warmup_epochs > 0 and epoch+1 <= args.warmup_epochs:
246 | wu_lr = (args.learning_rate - args.base_lr) * \
247 | (float(epoch+1) / args.warmup_epochs) + args.base_lr
249 | optimiser.param_groups[0]['lr'] = wu_lr
250 | else:
251 | # After warmup, decay lr with CosineAnnealingLR
252 | lr_decay.step()
253 |
254 | ''' Printing '''
255 | if args.print_progress: # only validate using process 0
256 | logging.info('\n[Train] loss: {:.4f}'.format(epoch_pretrain_loss))
257 |
258 | args.writer.add_scalars('epoch_loss', {
259 | 'pretrain': epoch_pretrain_loss}, epoch+1)
260 | args.writer.add_scalars('supervised_epoch_acc', {
261 | 'pretrain': epoch_pretrain_acc}, epoch+1)
262 | args.writer.add_scalars('supervised_epoch_acc_top5', {
263 | 'pretrain': epoch_pretrain_acc_top5}, epoch+1)
265 | args.writer.add_scalars('lr', {'pretrain': optimiser.param_groups[0]['lr']}, epoch+1)
266 |
267 | # For the best performing epoch, reset patience and save model,
268 | # else update patience.
269 | if epoch_pretrain_loss <= best_valid_loss:
270 | patience_counter = 0
271 | best_epoch = epoch + 1
272 | best_valid_loss = epoch_pretrain_loss
273 |
274 | # saving using process (rank) 0 only as all processes are in sync
275 |
276 | state = {
277 | #'args': args,
278 | 'encoder': encoder.state_dict(),
279 | 'optimiser': optimiser.state_dict(),
280 | 'epoch': epoch,
281 | }
282 |
283 | torch.save(state, args.checkpoint_dir)
284 | else:
285 | patience_counter += 1
286 | if patience_counter == (args.patience - 10):
287 | logging.info('\nPatience counter {}/{}.'.format(
288 | patience_counter, args.patience))
289 | elif patience_counter == args.patience:
290 | logging.info('\nEarly stopping... no improvement after {} Epochs.'.format(
291 | args.patience))
292 | break
293 |
294 | epoch_pretrain_loss = None # reset loss
295 |
296 |         state = None  # drop the checkpoint reference; safe even on epochs where no checkpoint was saved
297 |
298 | torch.cuda.empty_cache()
299 |
300 | gc.collect() # release unreferenced memory
301 |
302 |
303 | def finetune(encoder, dataloaders, args):
304 | ''' Finetune script - MoCo
305 |
306 | Freeze the encoder and train the supervised Linear Evaluation head with a Cross Entropy Loss.
307 | '''
308 |
309 | mode = 'finetune'
310 |
311 | ''' Optimisers '''
312 | # Only optimise the supervised head
313 | optimiser = get_optimiser((encoder,), mode, args)
314 |
315 | ''' Schedulers '''
316 | # Cosine LR Decay
317 | lr_decay = lr_scheduler.CosineAnnealingLR(optimiser, args.finetune_epochs)
318 |
319 | ''' Loss / Criterion '''
320 | criterion = torch.nn.CrossEntropyLoss().cuda()
321 |
322 | # initilize Variables
323 | args.writer = SummaryWriter(args.summaries_dir)
324 | best_valid_loss = np.inf
325 | best_valid_acc = 0.0
326 | patience_counter = 0
327 |
328 |     ''' Finetune loop '''
329 | for epoch in range(args.finetune_epochs):
330 |
331 | # Freeze the encoder, train classification head
332 | encoder.eval()
333 |
334 | sample_count = 0
335 | run_loss = 0
336 | run_top1 = 0.0
337 | run_top5 = 0.0
338 |
339 | # Print setup for distributed only printing on one node.
340 | if args.print_progress:
341 | logging.info('\nEpoch {}/{}:\n'.format(epoch+1, args.finetune_epochs))
342 | # tqdm for process (rank) 0 only when using distributed training
343 | train_dataloader = tqdm(dataloaders['train'])
344 | else:
345 | train_dataloader = dataloaders['train']
346 |
347 | ''' epoch loop '''
348 | for i, (inputs, target) in enumerate(train_dataloader):
349 |
350 | inputs = inputs.cuda(non_blocking=True)
351 |
352 | target = target.cuda(non_blocking=True)
353 |
354 | # Forward pass
355 | optimiser.zero_grad()
356 |
357 |             # Frozen layers (requires_grad=False) receive no gradients; only the linear head trains
358 | output = encoder(inputs)
359 |
360 | # Take pretrained encoder representations
361 | loss = criterion(output, target)
362 |
363 | loss.backward()
364 |
365 | optimiser.step()
366 |
367 | torch.cuda.synchronize()
368 |
369 | sample_count += inputs.size(0)
370 |
371 | run_loss += loss.item()
372 |
373 | predicted = output.argmax(1)
374 |
375 | acc = (predicted == target).sum().item() / target.size(0)
376 |
377 | run_top1 += acc
378 |
379 | _, output_topk = output.topk(5, 1, True, True)
380 |
381 | acc_top5 = (output_topk == target.view(-1, 1).expand_as(output_topk)
382 |                         ).sum().item() / target.size(0)  # fraction of the batch with target in top-5
383 |
384 | run_top5 += acc_top5
385 |
386 | epoch_finetune_loss = run_loss / len(dataloaders['train']) # sample_count
387 |
388 | epoch_finetune_acc = run_top1 / len(dataloaders['train'])
389 |
390 | epoch_finetune_acc_top5 = run_top5 / len(dataloaders['train'])
391 |
392 | ''' Update Schedulers '''
393 | # Decay lr with CosineAnnealingLR
394 | lr_decay.step()
395 |
396 | ''' Printing '''
397 | if args.print_progress: # only validate using process 0
398 | logging.info('\n[Finetune] loss: {:.4f},\t acc: {:.4f}, \t acc_top5: {:.4f}\n'.format(
399 | epoch_finetune_loss, epoch_finetune_acc, epoch_finetune_acc_top5))
400 |
401 | args.writer.add_scalars('finetune_epoch_loss', {'train': epoch_finetune_loss}, epoch+1)
402 | args.writer.add_scalars('finetune_epoch_acc', {'train': epoch_finetune_acc}, epoch+1)
403 | args.writer.add_scalars('finetune_epoch_acc_top5', {
404 | 'train': epoch_finetune_acc_top5}, epoch+1)
405 | args.writer.add_scalars(
406 | 'finetune_lr', {'train': optimiser.param_groups[0]['lr']}, epoch+1)
407 |
408 | valid_loss, valid_acc, valid_acc_top5 = evaluate(
409 | encoder, dataloaders, 'valid', epoch, args)
410 |
411 | # For the best performing epoch, reset patience and save model,
412 | # else update patience.
413 | if valid_acc >= best_valid_acc:
414 | patience_counter = 0
415 | best_epoch = epoch + 1
416 | best_valid_acc = valid_acc
417 |
418 | # saving using process (rank) 0 only as all processes are in sync
419 |
420 | state = {
421 | #'args': args,
422 | 'base_encoder': encoder.state_dict(),
423 | 'optimiser': optimiser.state_dict(),
424 | 'epoch': epoch
425 | }
426 |
427 | torch.save(state, (args.checkpoint_dir[:-3] + "_finetune.pt"))
428 | else:
429 | patience_counter += 1
430 | if patience_counter == (args.patience - 10):
431 | logging.info('\nPatience counter {}/{}.'.format(
432 | patience_counter, args.patience))
433 | elif patience_counter == args.patience:
434 | logging.info('\nEarly stopping... no improvement after {} Epochs.'.format(
435 | args.patience))
436 | break
437 |
438 | epoch_finetune_loss = None # reset loss
439 | epoch_finetune_acc = None
440 | epoch_finetune_acc_top5 = None
441 |
442 |         state = None  # drop the checkpoint reference; safe even on epochs where no checkpoint was saved
443 |
444 | torch.cuda.empty_cache()
445 |
446 | gc.collect() # release unreferenced memory
447 |
448 |
449 | def evaluate(encoder, dataloaders, mode, epoch, args):
450 | ''' Evaluate script - MoCo
451 |
452 | Evaluate the encoder and Linear Evaluation head with Cross Entropy loss.
453 | '''
454 |
455 | epoch_valid_loss = None # reset loss
456 | epoch_valid_acc = None # reset acc
457 | epoch_valid_acc_top5 = None
458 |
459 | ''' Loss / Criterion '''
460 | criterion = nn.CrossEntropyLoss().cuda()
461 |
462 |     # initialize variables
463 | args.writer = SummaryWriter(args.summaries_dir)
464 |
465 | # Evaluate both encoder and class head
466 | encoder.eval()
467 |
468 |     # initialize variables
469 | sample_count = 0
470 | run_loss = 0
471 | run_top1 = 0.0
472 | run_top5 = 0.0
473 |
474 | # Print setup for distributed only printing on one node.
475 | if args.print_progress:
476 | # tqdm for process (rank) 0 only when using distributed training
477 | eval_dataloader = tqdm(dataloaders[mode])
478 | else:
479 | eval_dataloader = dataloaders[mode]
480 |
481 | ''' epoch loop '''
482 | for i, (inputs, target) in enumerate(eval_dataloader):
483 |
484 | # Do not compute gradient for encoder and classification head
485 | encoder.zero_grad()
486 |
487 | inputs = inputs.cuda(non_blocking=True)
488 |
489 | target = target.cuda(non_blocking=True)
490 |
491 | # Forward pass
492 |         with torch.no_grad():
493 |             output = encoder(inputs)
494 |
495 |             loss = criterion(output, target)
496 |
497 | torch.cuda.synchronize()
498 |
499 | sample_count += inputs.size(0)
500 |
501 | run_loss += loss.item()
502 |
503 | predicted = output.argmax(-1)
504 |
505 | acc = (predicted == target).sum().item() / target.size(0)
506 |
507 | run_top1 += acc
508 |
509 | _, output_topk = output.topk(5, 1, True, True)
510 |
511 | acc_top5 = (output_topk == target.view(-1, 1).expand_as(output_topk)
512 |                     ).sum().item() / target.size(0)  # fraction of the batch with target in top-5
513 |
514 | run_top5 += acc_top5
515 |
516 | epoch_valid_loss = run_loss / len(dataloaders[mode]) # sample_count
517 |
518 | epoch_valid_acc = run_top1 / len(dataloaders[mode])
519 |
520 | epoch_valid_acc_top5 = run_top5 / len(dataloaders[mode])
521 |
522 | ''' Printing '''
523 | if args.print_progress: # only validate using process 0
524 | logging.info('\n[{}] loss: {:.4f},\t acc: {:.4f},\t acc_top5: {:.4f} \n'.format(
525 | mode, epoch_valid_loss, epoch_valid_acc, epoch_valid_acc_top5))
526 |
527 | if mode != 'test':
528 | args.writer.add_scalars('finetune_epoch_loss', {mode: epoch_valid_loss}, epoch+1)
529 | args.writer.add_scalars('finetune_epoch_acc', {mode: epoch_valid_acc}, epoch+1)
530 |         args.writer.add_scalars('finetune_epoch_acc_top5', {
531 |             mode: epoch_valid_acc_top5}, epoch+1)
532 |
533 | torch.cuda.empty_cache()
534 |
535 | gc.collect() # release unreferenced memory
536 |
537 | return epoch_valid_loss, epoch_valid_acc, epoch_valid_acc_top5
538 |
--------------------------------------------------------------------------------
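The hand-rolled warmup in `pretrain`/`supervised` above pairs a manual linear ramp with `CosineAnnealingLR`. A minimal standalone sketch of that schedule — the hyperparameter values here are assumptions for illustration, not the repo defaults:

```python
import torch
import torch.optim.lr_scheduler as lr_scheduler

base_lr, learning_rate, warmup_epochs, n_epochs = 0.01, 0.1, 10, 100

param = torch.nn.Parameter(torch.zeros(1))
optimiser = torch.optim.SGD([param], lr=learning_rate)
lr_decay = lr_scheduler.CosineAnnealingLR(
    optimiser, n_epochs - warmup_epochs, eta_min=0.0001)

lrs = []
for epoch in range(n_epochs):
    if epoch + 1 <= warmup_epochs:
        # Linear warmup from base_lr up to learning_rate.
        optimiser.param_groups[0]['lr'] = (learning_rate - base_lr) * \
            (float(epoch + 1) / warmup_epochs) + base_lr
    else:
        # Cosine decay over the remaining epochs.
        lr_decay.step()
    lrs.append(optimiser.param_groups[0]['lr'])
```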
/src/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import logging
4 | import numpy as np
5 | import time
6 | import random
7 |
8 |
9 | import torch
10 | from torch.utils.data import Dataset
11 | import torch.nn as nn
12 |
13 | from PIL import Image, ImageFilter
14 |
15 |
16 | class GaussianBlur(object):
17 | """Gaussian blur augmentation: https://github.com/facebookresearch/moco/"""
18 |
19 | def __init__(self, sigma=[.1, 2.]):
20 | self.sigma = sigma
21 |
22 | def __call__(self, x):
23 | sigma = random.uniform(self.sigma[0], self.sigma[1])
24 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
25 | return x
26 |
27 |
28 | def load_moco(base_encoder, args):
29 | """ Loads the pre-trained MoCo model parameters.
30 |
31 |     Applies the loaded pre-trained params to the base encoder used in Linear Evaluation,
32 |     keeping only the query encoder weights up to (but not including) its projection head.
33 |
34 | Args:
35 | base_encoder (model): Randomly Initialised base_encoder.
36 |
37 | args (dict): Program arguments/commandline arguments.
38 | Returns:
39 | base_encoder (model): Initialised base_encoder with parameters from the MoCo query_encoder.
40 | """
41 | print("\n\nLoading the model: {}\n\n".format(args.load_checkpoint_dir))
42 |
43 | # Load the pretrained model
44 | checkpoint = torch.load(args.load_checkpoint_dir, map_location="cpu")
45 |
46 | # rename moco pre-trained keys
47 | state_dict = checkpoint['moco']
48 | for k in list(state_dict.keys()):
49 | # retain only encoder_q up to before the embedding layer
50 | if k.startswith('encoder_q') and not k.startswith('encoder_q.fc'):
51 | # remove prefix
52 | state_dict[k[len("encoder_q."):]] = state_dict[k]
53 | # delete renamed or unused k
54 | del state_dict[k]
55 |
56 | # Load the encoder parameters
57 | base_encoder.load_state_dict(state_dict, strict=False)
58 |
59 | return base_encoder
60 |
61 |
62 | def load_sup(base_encoder, args):
63 | """ Loads the pre-trained supervised model parameters.
64 |
65 | Applies the loaded pre-trained params to the base encoder used in Linear Evaluation,
66 | freezing all layers except the Linear Evaluation layer/s.
67 |
68 | Args:
69 | base_encoder (model): Randomly Initialised base_encoder.
70 |
71 | args (dict): Program arguments/commandline arguments.
72 | Returns:
73 | base_encoder (model): Initialised base_encoder with parameters from the supervised base_encoder.
74 | """
75 | print("\n\nLoading the model: {}\n\n".format(args.load_checkpoint_dir))
76 |
77 | # Load the pretrained model
78 | checkpoint = torch.load(args.load_checkpoint_dir)
79 |
80 | # Load the encoder parameters
81 | base_encoder.load_state_dict(checkpoint['encoder'])
82 |
83 | # freeze all layers but the last fc
84 | for name, param in base_encoder.named_parameters():
85 | if name not in ['fc.weight', 'fc.bias']:
86 | param.requires_grad = False
87 |
88 | # init the fc layer
89 | init_weights(base_encoder)
90 |
91 | return base_encoder
92 |
93 |
94 | def init_weights(m):
95 |     '''Initialise the fc layer: weights ~ N(0, 0.01), zero bias.
96 |     '''
97 |
98 | # init the fc layer
99 | m.fc.weight.data.normal_(mean=0.0, std=0.01)
100 | m.fc.bias.data.zero_()
101 |
102 |
103 | class CustomDataset(Dataset):
104 | """ Creates a custom pytorch dataset.
105 |
106 | - Creates two views of the same input used for unsupervised visual
107 | representational learning. (SimCLR, Moco, MocoV2)
108 |
109 | Args:
110 | data (array): Array / List of datasamples
111 |
112 | labels (array): Array / List of labels corresponding to the datasamples
113 |
114 | transforms (Dictionary, optional): The torchvision transformations
115 | to make to the datasamples. (Default: None)
116 |
117 | target_transform (Dictionary, optional): The torchvision transformations
118 | to make to the labels. (Default: None)
119 |
120 | two_crop (bool, optional): Whether to perform and return two views
121 | of the data input. (Default: False)
122 |
123 | Returns:
124 | img (Tensor): Datasamples to feed to the model.
125 |
126 | labels (Tensor): Corresponding lables to the datasamples.
127 | """
128 |
129 | def __init__(self, data, labels, transform=None, target_transform=None, two_crop=False):
130 |
131 | # shuffle the dataset
132 | idx = np.random.permutation(data.shape[0])
133 |
134 | if isinstance(data, torch.Tensor):
135 |             data = data.numpy()  # to work with 'ToPILImage'
136 |
137 | self.data = data[idx]
138 |
139 | # when STL10 'unlabelled'
140 |         if labels is not None:
141 | self.labels = labels[idx]
142 | else:
143 | self.labels = labels
144 |
145 | self.transform = transform
146 | self.target_transform = target_transform
147 | self.two_crop = two_crop
148 |
149 | def __len__(self):
150 | return self.data.shape[0]
151 |
152 | def __getitem__(self, index):
153 |
154 | # If the input data is in form from torchvision.datasets.ImageFolder
155 | if isinstance(self.data[index][0], np.str_):
156 | # Load image from path
157 | image = Image.open(self.data[index][0]).convert('RGB')
158 |
159 | else:
160 | # Get image / numpy pixel values
161 | image = self.data[index]
162 |
163 | if self.transform is not None:
164 |
165 | # Data augmentation and normalisation
166 | img = self.transform(image)
167 |
168 |         if self.two_crop:
169 |
170 |             # Augments the image again to create a second view of the data
171 |             img2 = self.transform(image)
172 |
173 |             # Combine the views to pass to the model
174 |             img = torch.cat([img, img2], dim=0)
175 |
176 |         # when STL10 'unlabelled'
177 |         if self.labels is None:
178 |             return img, torch.Tensor([0])
179 |
180 |         target = self.labels[index].long()
181 |         if self.target_transform is not None:
182 |             # Transforms the target, e.g. for detection or segmentation labels
183 |             target = self.target_transform(target)
184 |
185 |         return img, target
186 |
187 |
188 | def random_split_image_folder(data, labels, n_classes, n_samples_per_class):
189 | """ Creates a class-balanced validation set from a training set.
190 |
191 | Specifically for the image folder class
192 | """
193 |
194 | train_x, train_y, valid_x, valid_y = [], [], [], []
195 |
196 | if isinstance(labels, list):
197 | labels = np.array(labels)
198 |
199 | for i in range(n_classes):
200 | # get indices of all class 'c' samples
201 | c_idx = (np.array(labels) == i).nonzero()[0]
202 | # get n unique class 'c' samples
203 | valid_samples = np.random.choice(c_idx, n_samples_per_class[i], replace=False)
204 | # get remaining samples of class 'c'
205 | train_samples = np.setdiff1d(c_idx, valid_samples)
206 | # assign class c samples to validation, and remaining to training
207 | train_x.extend(data[train_samples])
208 | train_y.extend(labels[train_samples])
209 | valid_x.extend(data[valid_samples])
210 | valid_y.extend(labels[valid_samples])
211 |
212 |     # np.stack + torch.from_numpy turns the lists of class ids into LongTensors
213 |
214 | return {'train': train_x, 'valid': valid_x}, \
215 | {'train': torch.from_numpy(np.stack(train_y)), 'valid': torch.from_numpy(np.stack(valid_y))}
216 |
217 |
218 | def random_split(data, labels, n_classes, n_samples_per_class):
219 | """ Creates a class-balanced validation set from a training set.
220 | """
221 |
222 | train_x, train_y, valid_x, valid_y = [], [], [], []
223 |
224 | if isinstance(labels, list):
225 | labels = np.array(labels)
226 |
227 | for i in range(n_classes):
228 | # get indices of all class 'c' samples
229 | c_idx = (np.array(labels) == i).nonzero()[0]
230 | # get n unique class 'c' samples
231 | valid_samples = np.random.choice(c_idx, n_samples_per_class[i], replace=False)
232 | # get remaining samples of class 'c'
233 | train_samples = np.setdiff1d(c_idx, valid_samples)
234 | # assign class c samples to validation, and remaining to training
235 | train_x.extend(data[train_samples])
236 | train_y.extend(labels[train_samples])
237 | valid_x.extend(data[valid_samples])
238 | valid_y.extend(labels[valid_samples])
239 |
240 | if isinstance(data, torch.Tensor):
241 | # torch.stack transforms list of tensors to tensor
242 | return {'train': torch.stack(train_x), 'valid': torch.stack(valid_x)}, \
243 | {'train': torch.stack(train_y), 'valid': torch.stack(valid_y)}
244 | # transforms list of np arrays to tensor
245 | return {'train': torch.from_numpy(np.stack(train_x)),
246 | 'valid': torch.from_numpy(np.stack(valid_x))}, \
247 | {'train': torch.from_numpy(np.stack(train_y)),
248 | 'valid': torch.from_numpy(np.stack(valid_y))}
249 |
250 |
251 | def sample_weights(labels):
252 | """ Calculates per sample weights. """
253 | class_sample_count = np.unique(labels, return_counts=True)[1]
254 | class_weights = 1. / torch.Tensor(class_sample_count)
255 | return class_weights[list(map(int, labels))]
256 |
257 |
258 | def experiment_config(parser, args):
259 | """ Handles experiment configuration and creates new dirs for model.
260 | """
261 | # check number of models already saved in 'experiments' dir, add 1 to get new model number
262 | run_dir = os.path.join(os.path.split(os.getcwd())[0], 'experiments')
263 |
264 | os.makedirs(run_dir, exist_ok=True)
265 |
266 | run_name = time.strftime("%Y-%m-%d_%H-%M-%S")
267 |
268 | # create all save dirs
269 | model_dir = os.path.join(run_dir, run_name)
270 |
271 | os.makedirs(model_dir, exist_ok=True)
272 |
273 | args.summaries_dir = os.path.join(model_dir, 'summaries')
274 | args.checkpoint_dir = os.path.join(model_dir, 'checkpoint.pt')
275 |
276 | if not args.finetune:
277 | args.load_checkpoint_dir = args.checkpoint_dir
278 |
279 | os.makedirs(args.summaries_dir, exist_ok=True)
280 |
281 | # save hyperparameters in .txt file
282 | with open(os.path.join(model_dir, 'hyperparams.txt'), 'w') as logs:
283 | for key, value in vars(args).items():
284 | logs.write('--{0}={1} \n'.format(str(key), str(value)))
285 |
286 | # save config file used in .txt file
287 | with open(os.path.join(model_dir, 'config.txt'), 'w') as logs:
288 | # Remove the string from the blur_sigma value list
289 | config = parser.format_values().replace("'", "")
290 | # Remove the first line, path to original config file
291 | config = config[config.find('\n')+1:]
292 | logs.write('{}'.format(config))
293 |
294 | # reset root logger
295 | [logging.root.removeHandler(handler) for handler in logging.root.handlers[:]]
296 | # info logger for saving command line outputs during training
297 | logging.basicConfig(level=logging.INFO, format='%(message)s',
298 | handlers=[logging.FileHandler(os.path.join(model_dir, 'trainlogs.txt')),
299 | logging.StreamHandler()])
300 | return args
301 |
302 |
303 | def print_network(model, args):
304 | """ Utility for printing out a model's architecture.
305 | """
306 | logging.info('-'*70) # print some info on architecture
307 | logging.info('{:>25} {:>27} {:>15}'.format('Layer.Parameter', 'Shape', 'Param#'))
308 | logging.info('-'*70)
309 |
310 | for param in model.state_dict():
311 | p_name = param.split('.')[-2]+'.'+param.split('.')[-1]
312 |         # don't print batch norm layers for prettiness
313 | if p_name[:2] != 'BN' and p_name[:2] != 'bn':
314 | logging.info(
315 | '{:>25} {:>27} {:>15}'.format(
316 | p_name,
317 | str(list(model.state_dict()[param].squeeze().size())),
318 |                     '{0:,}'.format(np.prod(list(model.state_dict()[param].size())))
319 | )
320 | )
321 | logging.info('-'*70)
322 |
323 | logging.info('\nTotal params: {:,}\n\nSummaries dir: {}\n'.format(
324 | sum(p.numel() for p in model.parameters()),
325 | args.summaries_dir))
326 |
327 | for key, value in vars(args).items():
328 | if str(key) != 'print_progress':
329 | logging.info('--{0}: {1}'.format(str(key), str(value)))
330 |
--------------------------------------------------------------------------------
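The `two_crop` path in `CustomDataset` is what feeds the `torch.split(inputs, [3, 3], dim=1)` call in `pretrain`. A minimal sketch of that round trip — the toy data and the transform pipeline here are assumptions; the real augmentations live in `src/datasets.py`:

```python
import numpy as np
import torch
from torchvision import transforms

# Toy array standing in for CIFAR-style (N, H, W, C) uint8 data.
data = np.random.randint(0, 255, size=(16, 32, 32, 3), dtype=np.uint8)
labels = torch.arange(16)

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(32),
    transforms.ToTensor(),
])

dataset = CustomDataset(data, labels, transform=transform, two_crop=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=4)

inputs, _ = next(iter(loader))                  # (4, 6, 32, 32): both views stacked channel-wise
x_i, x_j = torch.split(inputs, [3, 3], dim=1)   # recover the two (4, 3, 32, 32) views, as in pretrain
```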