├── LICENSE
├── README.md
├── VOC_CLF
│   ├── dataset.py
│   ├── main.py
│   ├── test.py
│   ├── train.py
│   └── utils.py
├── cmd
│   ├── __init__.py
│   ├── run_multi.sh
│   └── run_single.sh
├── data_processing
│   ├── Image_ops.py
│   ├── Multi_FixTransform.py
│   ├── RandAugment.py
│   └── __init__.py
├── detection
│   ├── configs
│   │   ├── Base-RCNN-C4-BN.yaml
│   │   ├── coco_R_50_C4_2x.yaml
│   │   ├── coco_R_50_C4_2x_clsa.yaml
│   │   ├── pascal_voc_R_50_C4_24k.yaml
│   │   └── pascal_voc_R_50_C4_24k_CLSA.yaml
│   ├── convert-pretrain-to-detectron2.py
│   └── train_net.py
├── lincls.py
├── main_clsa.py
├── model
│   ├── CLSA.py
│   └── __init__.py
├── ops
│   ├── Config_Envrionment.py
│   ├── __init__.py
│   ├── argparser.py
│   └── os_operation.py
├── requirements.txt
└── training
    ├── __init__.py
    ├── main_worker.py
    ├── train.py
    └── train_utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Lab for MAchine Perception and LEarning (MAPLE)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CLSA
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | CLSA is a self-supervised learning method that focuses on learning patterns from strongly augmented images.
12 |
13 | Copyright (C) 2020 Xiao Wang, Guo-Jun Qi
14 |
15 | License: MIT for academic use.
16 |
17 | Contact: Guo-Jun Qi (guojunq@gmail.com)
18 |
19 |
20 | ## Introduction
21 | Representation learning has been greatly improved with the advance of contrastive learning methods. Those methods have greatly benefited from various data augmentations that are carefully designed to maintain image identities, so that images transformed from the same instance can still be retrieved. However, those carefully designed transformations limit further exploration of the novel patterns carried by other transformations. To bridge this gap, we propose a general framework called Contrastive Learning with Stronger Augmentations (CLSA) to complement current contrastive learning approaches. As found in our experiments, the distortions induced by stronger augmentations mean the transformed images can no longer be viewed as the same instance. Thus, we propose to minimize the distribution divergence between the weakly and strongly augmented images over the representation bank to supervise the retrieval of strongly augmented queries from a pool of candidates. Experiments on the ImageNet dataset and downstream datasets show that the information from the strongly augmented images can greatly boost performance. For example, CLSA achieves a top-1 accuracy of 76.2% on ImageNet with a standard ResNet-50 architecture with a single-layer classifier fine-tuned, which is almost the same level as the 76.5% of supervised results.
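
At the heart of CLSA is a distributional divergence (DDM) term in which the weakly augmented view supervises the strongly augmented one over the memory bank. Below is a minimal sketch of that idea in PyTorch; it is illustrative only: the tensor names (`q_weak`, `q_strong`, `queue`) and the exact loss form here are ours, while the repository's implementation lives in `model/CLSA.py`.

```
import torch
import torch.nn.functional as F

def ddm_loss(q_weak, q_strong, k, queue, t=0.2):
    # q_weak / q_strong: L2-normalized query embeddings of the weakly / strongly
    # augmented views, shape (N, D); k: positive keys from the momentum encoder,
    # shape (N, D); queue: memory bank of past keys, shape (D, K); t: --clsa_t.
    logits_weak = torch.cat(
        [torch.einsum("nd,nd->n", q_weak, k).unsqueeze(-1),
         torch.einsum("nd,dk->nk", q_weak, queue)], dim=1) / t
    logits_strong = torch.cat(
        [torch.einsum("nd,nd->n", q_strong, k).unsqueeze(-1),
         torch.einsum("nd,dk->nk", q_strong, queue)], dim=1) / t
    # The weak view's similarity distribution supervises the strong view's
    # (stop-gradient on the weak side): a cross-entropy between distributions.
    p_weak = F.softmax(logits_weak, dim=1).detach()
    log_p_strong = F.log_softmax(logits_strong, dim=1)
    return -(p_weak * log_p_strong).sum(dim=1).mean()
```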
22 |
23 | ## Installation
24 | CUDA version should be 10.1 or higher.
25 | ### 1. [`Install git`](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git)
26 | ### 2. Clone the repository to your computer
27 | ```
28 | git clone git@github.com:maple-research-lab/CLSA.git && cd CLSA
29 | ```
30 |
31 | ### 3. Build dependencies.
32 | You have two options for installing the dependencies on your computer:
33 | #### 3.1 Install with pip and Python 3.6.9.
34 | ##### 3.1.1 [`install pip`](https://pip.pypa.io/en/stable/installing/).
35 | ##### 3.1.2 Install dependency in command line.
36 | ```
37 | pip install -r requirements.txt --user
38 | ```
39 | If you encounter any errors, you can install each library one by one:
40 | ```
41 | pip install torch==1.7.1
42 | pip install torchvision==0.8.2
43 | pip install numpy==1.19.5
44 | pip install Pillow==5.1.0
45 | pip install tensorboard==1.14.0
46 | pip install tensorboardX==1.7
47 | ```
48 |
49 | #### 3.2 Install with anaconda
50 | ##### 3.2.1 [`install conda`](https://docs.conda.io/projects/conda/en/latest/user-guide/install/macos.html).
51 | ##### 3.2.2 Install dependency in command line
52 | ```
53 | conda create -n CLSA python=3.6.9
54 | conda activate CLSA
55 | pip install -r requirements.txt
56 | ```
57 | Each time you want to run the code, simply activate the environment with
58 | ```
59 | conda activate CLSA
60 | conda deactivate   # run this when you want to exit the environment
61 | ```
62 | ### 4. Prepare the ImageNet dataset
63 | ##### 4.1 Download the [ImageNet2012 Dataset](http://image-net.org/challenges/LSVRC/2012/) under "./datasets/imagenet2012".
64 | ##### 4.2 Go to path "./datasets/imagenet2012/val"
65 | ##### 4.3 Move the validation images into labeled subfolders using [the following shell script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
66 |
67 | ## Usage
68 |
69 | ### Unsupervised Training
70 | This implementation only supports multi-gpu, DistributedDataParallel training, which is faster and simpler; single-gpu or DataParallel training is not supported.
71 | #### Single Crop
72 | ##### 1 Without symmetrical loss
73 | ```
74 | python3 main_clsa.py --data=[data_path] --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 --size_crops 224 96 --min_scale_crops 0.2 0.086 --max_scale_crops 1.0 0.429 --pick_strong 1 --pick_weak 0 --clsa_t 0.2 --sym 0
75 | ```
76 | Here [data_path] should be the root directory of the ImageNet dataset.
77 |
78 | ##### 2 With symmetrical loss (Not verified)
79 | ```
80 | python3 main_clsa.py --data=[data_path] --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 --size_crops 224 96 --min_scale_crops 0.2 0.086 --max_scale_crops 1.0 0.429 --pick_strong 1 --pick_weak 0 --clsa_t 0.2 --sym 1
81 | ```
82 | Here [data_path] should be the root directory of the ImageNet dataset.
83 |
84 | #### Multi Crop
85 |
86 | ##### 1 Without symmetrical loss
87 | ```
88 | python3 main_clsa.py --data=[data_path] --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 1 1 1 --size_crops 224 192 160 128 96 --min_scale_crops 0.2 0.172 0.143 0.114 0.086 --max_scale_crops 1.0 0.86 0.715 0.571 0.429 --pick_strong 0 1 2 3 4 --pick_weak 0 1 2 3 4 --clsa_t 0.2 --sym 0
89 | ```
90 | Here [data_path] should be the root directory of the ImageNet dataset.
91 |
92 | ##### 2 With symmetrical loss (Not verified)
93 | ```
94 | python3 main_clsa.py --data=[data_path] --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 1 1 1 --size_crops 224 192 160 128 96 --min_scale_crops 0.2 0.172 0.143 0.114 0.086 --max_scale_crops 1.0 0.86 0.715 0.571 0.429 --pick_strong 0 1 2 3 4 --pick_weak 0 1 2 3 4 --clsa_t 0.2 --sym 1
95 | ```
96 | Here [data_path] should be the root directory of the ImageNet dataset.
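
The four crop flags are parallel lists: entry *i* of `--nmb_crops`, `--size_crops`, `--min_scale_crops`, and `--max_scale_crops` together describes one crop resolution. They are consumed by `Multi_Fixtransform` in `data_processing/Multi_FixTransform.py`; a sketch of how the multi-crop command above maps onto that class (the `normalize` statistics here are the usual ImageNet values, and `pil_image` stands for any PIL image):

```
from torchvision import transforms
from data_processing.Multi_FixTransform import Multi_Fixtransform

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Mirrors --nmb_crops 1 1 1 1 1 --size_crops 224 192 160 128 96 ...
multi_fix = Multi_Fixtransform(
    size_crops=[224, 192, 160, 128, 96],
    nmb_crops=[1, 1, 1, 1, 1],
    min_scale_crops=[0.2, 0.172, 0.143, 0.114, 0.086],
    max_scale_crops=[1.0, 0.86, 0.715, 0.571, 0.429],
    normalize=normalize,
    aug_times=5)  # --aug_times: RandAugment ops applied per strong view

crops = multi_fix(pil_image)  # [key view] + 5 weak crops + 5 strong crops
```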
97 |
98 | ### Linear Classification
99 | With a pre-trained model, we can easily evaluate its performance on ImageNet with:
100 | ```
101 | python3 lincls.py --data=./datasets/imagenet2012 --dist-url=tcp://localhost:10001 --pretrained=[pretrained_model_path]
102 | ```
103 | [pretrained_model_path] should be the path to the ImageNet-pretrained model.
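
The checkpoint stores the full MoCo-style model; only the `module.encoder_q.*` weights (minus the projection head) map onto a plain torchvision ResNet-50. A sketch of that loading step, mirroring the code in `VOC_CLF/main.py` (the checkpoint path here is a placeholder):

```
import torch
import torchvision.models as models

model = models.resnet50()
checkpoint = torch.load("clsa_pretrained.pth.tar", map_location="cpu")
state_dict = checkpoint["state_dict"]
for k in list(state_dict.keys()):
    # keep encoder_q up to (but not including) the projection head
    if k.startswith("module.encoder_q") and not k.startswith("module.encoder_q.fc"):
        state_dict[k[len("module.encoder_q."):]] = state_dict[k]
    del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
```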
104 |
105 | Performance:
106 |
107 | | pre-train network | pre-train epochs | Crop | CLSA top-1 acc. | Model Link |
108 | | --- | --- | --- | --- | --- |
109 | | ResNet-50 | 200 | Single | 69.4 | model |
110 | | ResNet-50 | 200 | Multi | 73.3 | model |
111 | | ResNet-50 | 800 | Single | 72.2 | model |
112 | | ResNet-50 | 800 | Multi | 76.2 | None |
113 |
141 | We are sorry that we cannot provide the CLSA* 800-epoch model: it was trained on 32 internal GPUs, and company regulations prevent us from downloading it. For downstream tasks, we found that the multi-crop 200-epoch model has similar performance, so we suggest using this [model](https://purdue0-my.sharepoint.com/:u:/g/personal/wang3702_purdue_edu/Ed8IVMBAvp1GmqABFMskEbYBz6B1vq65kp2IQlukFiS6mw?e=K0G5H6) for downstream purposes.
142 |
143 | ### Transferring to VOC07 Classification
144 | #### 1 Download [Dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar) under "./datasets/voc"
145 | #### 2 Linear Evaluation:
146 | ```
147 | cd VOC_CLF
148 | python3 main.py --data=[VOC_dataset_dir] --pretrained=[pretrained_model_path]
149 | ```
150 | Here [VOC_dataset_dir] is the VOC dataset path and should be the directory that contains the "VOCdevkit" folder; [pretrained_model_path] is the ImageNet-pretrained model path.
151 |
152 | ### Transferring to Object Detection
153 | #### 1. Install [detectron2](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md).
154 |
155 | #### 2. Convert a pre-trained CLSA model to detectron2's format:
156 | ```
157 | # in detection folder
158 | python3 convert-pretrain-to-detectron2.py input.pth.tar output.pkl
159 | ```
160 |
161 | #### 3. Download the [VOC Dataset](http://host.robots.ox.ac.uk/pascal/VOC/) and [COCO Dataset](https://cocodataset.org/#download) under the "./detection/datasets" directory,
162 | following the [directory structure](https://github.com/facebookresearch/detectron2/tree/master/datasets) required by detectron2.
163 |
164 | #### 4. Run training:
165 | ##### 4.1 Pascal detection
166 | ```
167 | cd detection
168 | python train_net.py --config-file configs/pascal_voc_R_50_C4_24k_CLSA.yaml --num-gpus 8 MODEL.WEIGHTS ./output.pkl
169 | ```
170 | ##### 4.2 COCO detection
171 | ```
172 | cd detection
173 | python train_net.py --config-file configs/coco_R_50_C4_2x_clsa.yaml --num-gpus 8 MODEL.WEIGHTS ./output.pkl
174 | ```
175 |
176 |
177 | ## Citation:
178 | [Contrastive Learning with Stronger Augmentations](https://arxiv.org/abs/2104.07713)
179 | ```
180 | @article{wang2021contrastive,
181 | title={Contrastive learning with stronger augmentations},
182 | author={Wang, Xiao and Qi, Guo-Jun},
183 | journal={arXiv preprint arXiv:2104.07713},
184 | year={2021}
185 | }
186 | ```
187 |
188 |
189 |
--------------------------------------------------------------------------------
/VOC_CLF/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Tue Mar 12 23:23:51 2019
4 |
5 | @author: Keshik
6 | """
7 | import torchvision.datasets.voc as voc
8 |
9 | class PascalVOC_Dataset(voc.VOCDetection):
10 | """`Pascal VOC `_ Detection Dataset.
11 |
12 | Args:
13 | root (string): Root directory of the VOC Dataset.
14 | year (string, optional): The dataset year, supports years 2007 to 2012.
15 | image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
16 | download (bool, optional): If true, downloads the dataset from the internet and
17 | puts it in root directory. If dataset is already downloaded, it is not
18 | downloaded again.
19 |
20 | transform (callable, optional): A function/transform that takes in an PIL image
21 | and returns a transformed version. E.g, ``transforms.RandomCrop``
22 | target_transform (callable, optional): A function/transform that takes in the
23 | target and transforms it.
24 | """
25 | def __init__(self, root, year='2012', image_set='train', download=False, transform=None, target_transform=None):
26 |
27 | super().__init__(
28 | root,
29 | year=year,
30 | image_set=image_set,
31 | download=download,
32 | transform=transform,
33 | target_transform=target_transform)
34 |
35 |
36 | def __getitem__(self, index):
37 | """
38 | Args:
39 | index (int): Index
40 |
41 | Returns:
42 | tuple: (image, target) where target is the annotation of the image (transformed by target_transform if given).
43 | """
44 | return super().__getitem__(index)
45 |
46 |
47 | def __len__(self):
48 | """
49 | Returns:
50 | size of the dataset
51 | """
52 | return len(self.images)
53 |
--------------------------------------------------------------------------------
/VOC_CLF/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | adapted from https://github.com/keshik6/pascal-voc-classification
4 | Created on Wed Mar 13 10:50:25 2019
5 |
6 | @author: Keshik
7 | """
8 |
9 | import torch
10 | import numpy as np
11 | from torchvision import transforms
12 | import torchvision.models as models
13 | from torch.utils.data import DataLoader
14 | from dataset import PascalVOC_Dataset
15 | import torch.optim as optim
16 | from train import train_model, test
17 | from utils import encode_labels, plot_history
18 | import os
19 | import utils
20 |
21 | def main(args,data_dir, model_name, num, lr, epochs, batch_size = 16, download_data = False):
22 | """
23 | Main function
24 |
25 | Args:
26 | data_dir: directory to download Pascal VOC data
27 | model_name: resnet18, resnet34 or resnet50
28 | num: model_num for file management purposes (any positive integer; your stored results will have this number as a suffix)
29 | lr: initial learning rate for the trainable (classifier) parameters
30 | epochs: number of training epochs
31 | batch_size: batch size. Default=16
32 | download_data: Boolean. If true will download the entire 2012 pascal VOC data as tar to the specified data_dir.
33 | Set this to True only the first time you run it, and then set to False. Default False
34 | save_results: Store results (boolean). Default False
35 |
36 | Returns:
37 | test-time loss and average precision
38 |
39 | Example way of running this function:
40 | if __name__ == '__main__':
41 | main('../data/', "resnet34", num=1, lr = [1.5e-4, 5e-2], epochs = 15, batch_size=16, download_data=False, save_results=True)
42 | """
43 |
44 |
45 |
46 | # Initialize cuda parameters
47 | use_cuda = torch.cuda.is_available()
48 | np.random.seed(2019)
49 | torch.manual_seed(2019)
50 | device = torch.device("cuda" if use_cuda else "cpu")
51 |
52 | print("Available device = ", device)
53 | model = models.__dict__[args.arch]()
54 | for name, param in model.named_parameters():
55 | if name not in ['fc.weight', 'fc.bias']:
56 | param.requires_grad = False
57 | #model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
58 |
59 | #model.load_state_dict(model_zoo.load_url(model_urls[model_name]))
60 | checkpoint = torch.load(args.pretrained, map_location="cpu")
61 | state_dict = checkpoint['state_dict']
62 | for k in list(state_dict.keys()):
63 | # retain only encoder_q up to before the embedding layer
64 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
65 | # remove prefix
66 | state_dict[k[len("module.encoder_q."):]] = state_dict[k]
67 | # delete renamed or unused k
68 | del state_dict[k]
69 | msg = model.load_state_dict(state_dict, strict=False)
70 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
71 |
72 | print("=> loaded pre-trained model '{}'".format(args.pretrained))
73 | num_ftrs = model.fc.in_features
74 | model.fc = torch.nn.Linear(num_ftrs, 20)
75 | model.fc.weight.data.normal_(mean=0.0, std=0.01)
76 | model.fc.bias.data.zero_()
77 |
78 | model.to(device)
79 | parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
80 | print("optimized parameters",parameters)
81 | optimizer = optim.SGD([
82 | {'params': parameters, 'lr': lr, 'momentum': 0.9}
83 | ])
84 |
85 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 12, eta_min=0, last_epoch=-1)
86 |
87 | # ImageNet values
88 | mean=[0.457342265910642, 0.4387686270106377, 0.4073427106250871]
89 | std=[0.26753769276329037, 0.2638145880487105, 0.2776826934044154]
90 |
91 | # mean=[0.485, 0.456, 0.406]
92 | # std=[0.229, 0.224, 0.225]
93 |
94 | transformations = transforms.Compose([transforms.Resize((300, 300)),
95 | # transforms.RandomChoice([
96 | # transforms.CenterCrop(300),
97 | # transforms.RandomResizedCrop(300, scale=(0.80, 1.0)),
98 | # ]),
99 | transforms.RandomChoice([
100 | transforms.ColorJitter(brightness=(0.80, 1.20)),
101 | transforms.RandomGrayscale(p = 0.25)
102 | ]),
103 | transforms.RandomHorizontalFlip(p = 0.25),
104 | transforms.RandomRotation(25),
105 | transforms.ToTensor(),
106 | transforms.Normalize(mean = mean, std = std),
107 | ])
108 |
109 | transformations_valid = transforms.Compose([transforms.Resize(330),
110 | transforms.CenterCrop(300),
111 | transforms.ToTensor(),
112 | transforms.Normalize(mean = mean, std = std),
113 | ])
114 |
115 | # Create train dataloader
116 | dataset_train = PascalVOC_Dataset(data_dir,
117 | year='2007',
118 | image_set='train',
119 | download=download_data,
120 | transform=transformations,
121 | target_transform=encode_labels)
122 |
123 | train_loader = DataLoader(dataset_train, batch_size=batch_size, num_workers=4, shuffle=True)
124 |
125 | # Create validation dataloader
126 | dataset_valid = PascalVOC_Dataset(data_dir,
127 | year='2007',
128 | image_set='val',
129 | download=download_data,
130 | transform=transformations_valid,
131 | target_transform=encode_labels)
132 |
133 | valid_loader = DataLoader(dataset_valid, batch_size=batch_size, num_workers=4)
134 |
135 | # Load the best weights before testing
136 | if not os.path.exists(args.log):
137 | os.mkdir(args.log)
138 |
139 | log_file = open(os.path.join(args.log, "log-{}.txt".format(num)), "w+")
140 | model_dir=os.path.join(args.log,"model")
141 | if not os.path.exists(model_dir):
142 | os.mkdir(model_dir)
143 | log_file.write("----------Experiment {} - {}-----------\n".format(num, model_name))
144 | log_file.write("transformations == {}\n".format(transformations.__str__()))
145 | trn_hist, val_hist = train_model(model, device, optimizer, scheduler, train_loader, valid_loader, model_dir, num, epochs, log_file)
146 | torch.cuda.empty_cache()
147 |
148 | plot_history(trn_hist[0], val_hist[0], "Loss", os.path.join(model_dir, "loss-{}".format(num)))
149 | plot_history(trn_hist[1], val_hist[1], "Accuracy", os.path.join(model_dir, "accuracy-{}".format(num)))
150 | log_file.close()
151 |
152 | #---------------Test your model here---------------------------------------
153 | # Load the best weights before testing
154 | print("Evaluating model on test set")
155 | print("Loading best weights")
156 | weights_file_path = os.path.join(model_dir, "model-{}.pth".format(num))
157 | assert os.path.isfile(weights_file_path)
158 | print("Loading best weights")
159 |
160 | model.load_state_dict(torch.load(weights_file_path))
161 | transformations_test = transforms.Compose([transforms.Resize(330),
162 | transforms.FiveCrop(300),
163 | transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
164 | transforms.Lambda(lambda crops: torch.stack([transforms.Normalize(mean = mean, std = std)(crop) for crop in crops])),
165 | ])
166 |
167 |
168 | dataset_test = PascalVOC_Dataset(data_dir,
169 | year='2007',
170 | image_set='test',
171 | download=download_data,
172 | transform=transformations_test,
173 | target_transform=encode_labels)
174 |
175 |
176 | test_loader = DataLoader(dataset_test, batch_size=batch_size, num_workers=0, shuffle=False)
177 |
178 | loss, ap, scores, gt = test(model, device, test_loader, returnAllScores=True)
179 |
180 | gt_path, scores_path, scores_with_gt_path = os.path.join(model_dir, "gt-{}.csv".format(num)), os.path.join(model_dir, "scores-{}.csv".format(num)), os.path.join(model_dir, "scores_wth_gt-{}.csv".format(num))
181 |
182 | utils.save_results(test_loader.dataset.images, gt, utils.object_categories, gt_path)
183 | utils.save_results(test_loader.dataset.images, scores, utils.object_categories, scores_path)
184 | utils.append_gt(gt_path, scores_path, scores_with_gt_path)
185 |
186 | utils.get_classification_accuracy(gt_path, scores_path, os.path.join(model_dir, "clf_vs_threshold-{}.png".format(num)))
187 |
188 | return loss, ap
189 |
190 |
191 | model_names = sorted(name for name in models.__dict__
192 | if name.islower() and not name.startswith("__")
193 | and callable(models.__dict__[name]))
194 | # Execute main function here
195 | import argparse
196 | if __name__ == '__main__':
197 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
198 | parser.add_argument('--data', type=str, metavar='DIR',
199 | help='path to dataset')
200 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
201 | choices=model_names,
202 | help='model architecture: ' +
203 | ' | '.join(model_names) +
204 | ' (default: resnet50)')
205 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
206 | help='number of data loading workers (default: 4)')
207 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
208 | help='number of total epochs to run')
209 | parser.add_argument('--batch-size', default=16, type=int,
210 | metavar='N',
211 | help='mini-batch size (default: 16), this is the total '
212 | 'batch size of all GPUs on the current node when '
213 | 'using Data Parallel or Distributed Data Parallel')
214 | parser.add_argument('--lr', '--learning-rate', default=0.05, type=float,
215 | metavar='LR', help='initial learning rate', dest='lr')
216 | parser.add_argument('--gpu', default=None, type=str,
217 | help='GPU id to use.')
218 | parser.add_argument('--pretrained', default='', type=str,
219 | help='path to moco pretrained checkpoint')
220 | parser.add_argument("--log",default="train_log",type=str,help="log path for training")
221 | parser.add_argument("--run_num",default=1, type=int, help="specify the training saving path")
222 | args = parser.parse_args()
223 | choose = args.gpu
224 | if choose is not None:
225 | os.environ['CUDA_VISIBLE_DEVICES'] = choose
226 | main(args,args.data, args.arch, num=args.run_num, lr=args.lr, epochs=args.epochs, batch_size=args.batch_size)
227 |
228 | #if __name__ == '__main__':
229 | # main('../data/', "resnet34", num=1, lr = [1.5e-4, 5e-2], epochs = 1, batch_size=16, download_data=False, save_results=True)
--------------------------------------------------------------------------------
/VOC_CLF/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Mar 13 11:22:28 2019
4 |
5 | @author: Keshik
6 | """
7 | import torch
8 | from tqdm import tqdm
9 | import gc
10 | from sklearn.metrics import average_precision_score
11 |
--------------------------------------------------------------------------------
/VOC_CLF/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Mar 13 10:37:39 2019
4 |
5 | @author: Keshik
6 | """
7 |
8 | from tqdm import tqdm
9 | import torch
10 | import gc
11 | import os
12 | from utils import get_ap_score
13 | import numpy as np
14 |
15 | def train_model(model, device, optimizer, scheduler, train_loader, valid_loader, save_dir, model_num, epochs, log_file):
16 | """
17 | Train a deep neural network model
18 |
19 | Args:
20 | model: pytorch model object
21 | device: cuda or cpu
22 | optimizer: pytorch optimizer object
23 | scheduler: learning rate scheduler object that wraps the optimizer
24 | train_dataloader: training images dataloader
25 | valid_dataloader: validation images dataloader
26 | save_dir: Location to save model weights, plots and log_file
27 | epochs: number of training epochs
28 | log_file: text file instance to record training and validation history
29 |
30 | Returns:
31 | Training history and Validation history (loss and average precision)
32 | """
33 |
34 | tr_loss, tr_map = [], []
35 | val_loss, val_map = [], []
36 | best_val_map = 0.0
37 |
38 | # Each epoch has a training and validation phase
39 | for epoch in range(epochs):
40 | print("-------Epoch {}----------".format(epoch+1))
41 | log_file.write("Epoch {} >>".format(epoch+1))
42 | scheduler.step()
43 |
44 | for phase in ['train', 'valid']:
45 | running_loss = 0.0
46 | running_ap = 0.0
47 |
48 | criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
49 | m = torch.nn.Sigmoid()
50 |
51 | if phase == 'train':
52 | model.train(True) # Set model to training mode
53 |
54 | for data, target in tqdm(train_loader):
55 | #print(data)
56 | target = target.float()
57 | data, target = data.to(device), target.to(device)
58 |
59 | # zero the parameter gradients
60 | optimizer.zero_grad()
61 |
62 | output = model(data)
63 |
64 | loss = criterion(output, target)
65 |
66 | # Get metrics here
67 | running_loss += loss # sum up batch loss
68 | running_ap += get_ap_score(torch.Tensor.cpu(target).detach().numpy(), torch.Tensor.cpu(m(output)).detach().numpy())
69 |
70 | # Backpropagate to compute the gradients
71 | loss.backward()
72 |
73 | # Update the parameters of the model
74 | optimizer.step()
75 |
76 | # clear variables
77 | del data, target, output
78 | gc.collect()
79 | torch.cuda.empty_cache()
80 |
81 | #print("loss = ", running_loss)
82 |
83 | num_samples = float(len(train_loader.dataset))
84 | tr_loss_ = running_loss.item()/num_samples
85 | tr_map_ = running_ap/num_samples
86 |
87 | print('train_loss: {:.4f}, train_avg_precision:{:.3f}'.format(
88 | tr_loss_, tr_map_))
89 |
90 | log_file.write('train_loss: {:.4f}, train_avg_precision:{:.3f}, '.format(
91 | tr_loss_, tr_map_))
92 |
93 | # Append the values to global arrays
94 | tr_loss.append(tr_loss_), tr_map.append(tr_map_)
95 |
96 |
97 | else:
98 | model.train(False) # Set model to evaluate mode
99 |
100 | # torch.no_grad is for memory savings
101 | with torch.no_grad():
102 | for data, target in tqdm(valid_loader):
103 | target = target.float()
104 | data, target = data.to(device), target.to(device)
105 | output = model(data)
106 |
107 | loss = criterion(output, target)
108 |
109 | running_loss += loss # sum up batch loss
110 | running_ap += get_ap_score(torch.Tensor.cpu(target).detach().numpy(), torch.Tensor.cpu(m(output)).detach().numpy())
111 |
112 | del data, target, output
113 | gc.collect()
114 | torch.cuda.empty_cache()
115 |
116 | num_samples = float(len(valid_loader.dataset))
117 | val_loss_ = running_loss.item()/num_samples
118 | val_map_ = running_ap/num_samples
119 |
120 | # Append the values to global arrays
121 | val_loss.append(val_loss_), val_map.append(val_map_)
122 |
123 | print('val_loss: {:.4f}, val_avg_precision:{:.3f}'.format(
124 | val_loss_, val_map_))
125 |
126 | log_file.write('val_loss: {:.4f}, val_avg_precision:{:.3f}\n'.format(
127 | val_loss_, val_map_))
128 |
129 | # Save model using val_acc
130 | if val_map_ >= best_val_map:
131 | best_val_map = val_map_
132 | log_file.write("saving best weights...\n")
133 | torch.save(model.state_dict(), os.path.join(save_dir,"model-{}.pth".format(model_num)))
134 |
135 | return ([tr_loss, tr_map], [val_loss, val_map])
136 |
137 |
138 |
139 | def test(model, device, test_loader, returnAllScores=False):
140 | """
141 | Evaluate a deep neural network model
142 |
143 | Args:
144 | model: pytorch model object
145 | device: cuda or cpu
146 | test_dataloader: test images dataloader
147 | returnAllScores: If true, additionally return all confidence scores and ground truth
148 |
149 | Returns:
150 | test loss and average precision. If returnAllScores = True, check Args
151 | """
152 | model.train(False)
153 |
154 | running_loss = 0
155 | running_ap = 0
156 |
157 | criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
158 | m = torch.nn.Sigmoid()
159 |
160 | if returnAllScores == True:
161 | all_scores = np.empty((0, 20), float)
162 | ground_scores = np.empty((0, 20), float)
163 |
164 | with torch.no_grad():
165 | for data, target in tqdm(test_loader):
166 | #print(data.size(), target.size())
167 | target = target.float()
168 | data, target = data.to(device), target.to(device)
169 | bs, ncrops, c, h, w = data.size()
170 |
171 | output = model(data.view(-1, c, h, w))
172 | output = output.view(bs, ncrops, -1).mean(1)
173 |
174 | loss = criterion(output, target)
175 |
176 | running_loss += loss # sum up batch loss
177 | running_ap += get_ap_score(torch.Tensor.cpu(target).detach().numpy(), torch.Tensor.cpu(m(output)).detach().numpy())
178 |
179 | if returnAllScores == True:
180 | all_scores = np.append(all_scores, torch.Tensor.cpu(m(output)).detach().numpy() , axis=0)
181 | ground_scores = np.append(ground_scores, torch.Tensor.cpu(target).detach().numpy() , axis=0)
182 |
183 | del data, target, output
184 | gc.collect()
185 | torch.cuda.empty_cache()
186 |
187 | num_samples = float(len(test_loader.dataset))
188 | avg_test_loss = running_loss.item()/num_samples
189 | test_map = running_ap/num_samples
190 |
191 | print('test_loss: {:.4f}, test_avg_precision:{:.3f}'.format(
192 | avg_test_loss, test_map))
193 |
194 |
195 | if returnAllScores == False:
196 | return avg_test_loss, test_map  # normalized mAP, consistent with the printed value
197 |
198 | return avg_test_loss, test_map, all_scores, ground_scores
199 |
200 |
201 |
--------------------------------------------------------------------------------
/VOC_CLF/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Tue Mar 12 20:52:33 2019
4 |
5 | @author: Keshik
6 | """
7 | import os
8 | import math
9 | from tqdm import tqdm
10 | import torch
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | from sklearn.metrics import average_precision_score, accuracy_score
14 | import pandas as pd
15 |
16 | object_categories = ['aeroplane', 'bicycle', 'bird', 'boat',
17 | 'bottle', 'bus', 'car', 'cat', 'chair',
18 | 'cow', 'diningtable', 'dog', 'horse',
19 | 'motorbike', 'person', 'pottedplant',
20 | 'sheep', 'sofa', 'train', 'tvmonitor']
21 |
22 |
23 | def get_categories(labels_dir):
24 | """
25 | Get the object categories
26 |
27 | Args:
28 | label_dir: Directory that contains object specific label as .txt files
29 | Raises:
30 | FileNotFoundError: If the label directory does not exist
31 | Returns:
32 | Object categories as a list
33 | """
34 |
35 | if not os.path.isdir(labels_dir):
36 | raise FileNotFoundError
37 |
38 | else:
39 | categories = []
40 |
41 | for file in os.listdir(labels_dir):
42 | if file.endswith("_train.txt"):
43 | categories.append(file.split("_")[0])
44 |
45 | return categories
46 |
47 |
48 | def encode_labels(target):
49 | """
50 | Encode multiple labels using 1/0 encoding
51 |
52 | Args:
53 | target: xml tree file
54 | Returns:
55 | torch tensor encoding labels as 1/0 vector
56 | """
57 |
58 | ls = target['annotation']['object']
59 |
60 | j = []
61 | if type(ls) == dict:
62 | if int(ls['difficult']) == 0:
63 | j.append(object_categories.index(ls['name']))
64 |
65 | else:
66 | for i in range(len(ls)):
67 | if int(ls[i]['difficult']) == 0:
68 | j.append(object_categories.index(ls[i]['name']))
69 |
70 | k = np.zeros(len(object_categories))
71 | k[j] = 1
72 |
73 | return torch.from_numpy(k)
74 |
75 |
76 | def get_nrows(file_name):
77 | """
78 | Get the number of rows of a csv file
79 |
80 | Args:
81 | file_path: path of the csv file
82 | Raises:
83 | FileNotFoundError: If the csv file does not exist
84 | Returns:
85 | number of rows
86 | """
87 |
88 | if not os.path.isfile(file_name):
89 | raise FileNotFoundError
90 |
91 | s = 0
92 | with open(file_name) as f:
93 | s = sum(1 for line in f)
94 | return s
95 |
96 |
97 | def get_mean_and_std(dataloader):
98 | """
99 | Get the mean and std of a 3-channel image dataset
100 |
101 | Args:
102 | dataloader: pytorch dataloader
103 | Returns:
104 | mean and std of the dataset
105 | """
106 | mean = []
107 | std = []
108 |
109 | total = 0
110 | r_running, g_running, b_running = 0, 0, 0
111 | r2_running, g2_running, b2_running = 0, 0, 0
112 |
113 | with torch.no_grad():
114 | for data, target in tqdm(dataloader):
115 | r, g, b = data[:,0 ,:, :], data[:, 1, :, :], data[:, 2, :, :]
116 | r2, g2, b2 = r**2, g**2, b**2
117 |
118 | # Sum up values to find mean
119 | r_running += r.sum().item()
120 | g_running += g.sum().item()
121 | b_running += b.sum().item()
122 |
123 | # Sum up squared values to find standard deviation
124 | r2_running += r2.sum().item()
125 | g2_running += g2.sum().item()
126 | b2_running += b2.sum().item()
127 |
128 | total += data.size(0)*data.size(2)*data.size(3)
129 |
130 | # Append the mean values
131 | mean.extend([r_running/total,
132 | g_running/total,
133 | b_running/total])
134 |
135 | # Calculate standard deviation and append
136 | std.extend([
137 | math.sqrt((r2_running/total) - mean[0]**2),
138 | math.sqrt((g2_running/total) - mean[1]**2),
139 | math.sqrt((b2_running/total) - mean[2]**2)
140 | ])
141 |
142 | return mean, std
143 |
144 |
145 | def plot_history(train_hist, val_hist, y_label, filename, labels=["train", "validation"]):
146 | """
147 | Plot training and validation history
148 |
149 | Args:
150 | train_hist: numpy array consisting of train history values (loss/ accuracy metrics)
151 | valid_hist: numpy array consisting of validation history values (loss/ accuracy metrics)
152 | y_label: label for y_axis
153 | filename: filename to store the resulting plot
154 | labels: legend for the plot
155 |
156 | Returns:
157 | None
158 | """
159 | # Plot loss and accuracy
160 | xi = [i for i in range(0, len(train_hist), 2)]
161 | plt.plot(train_hist, label = labels[0])
162 | plt.plot(val_hist, label = labels[1])
163 | plt.xticks(xi)
164 | plt.legend()
165 | plt.xlabel("Epoch")
166 | plt.ylabel(y_label)
167 | plt.savefig(filename)
168 | plt.show()
169 |
170 |
171 | def get_ap_score(y_true, y_scores):
172 | """
173 | Get the summed average precision score over a batch of label/score array pairs
174 |
175 | Args:
176 | y_true: batch of true labels
177 | y_scores: batch of confidence scores
178 |
179 | Returns:
180 | sum of batch average precision
181 | """
182 | scores = 0.0
183 |
184 | for i in range(y_true.shape[0]):
185 | scores += average_precision_score(y_true = y_true[i], y_score = y_scores[i])
186 |
187 | return scores
188 |
189 | def save_results(images, scores, columns, filename):
190 | """
191 | Save inference results as csv
192 |
193 | Args:
194 | images: inferred image list
195 | scores: confidence score for inferred images
196 | columns: object categories
197 | filename: name and location to save resulting csv
198 | """
199 | df_scores = pd.DataFrame(scores, columns=columns)
200 | df_scores['image'] = images
201 | df_scores.set_index('image', inplace=True)
202 | df_scores.to_csv(filename)
203 |
204 |
205 | def append_gt(gt_csv_path, scores_csv_path, store_filename):
206 | """
207 | Append ground truth to confidence score csv
208 |
209 | Args:
210 | gt_csv_path: Ground truth csv location
211 | scores_csv_path: Confidence scores csv path
212 | store_filename: name and location to save resulting csv
213 | """
214 | gt_df = pd.read_csv(gt_csv_path)
215 | scores_df = pd.read_csv(scores_csv_path)
216 |
217 | gt_label_list = []
218 | for index, row in gt_df.iterrows():
219 | arr = np.array(gt_df.iloc[index,1:], dtype=int)
220 | target_idx = np.ravel(np.where(arr == 1))
221 | j = [object_categories[i] for i in target_idx]
222 | gt_label_list.append(j)
223 |
224 | scores_df.insert(1, "gt", gt_label_list)
225 | scores_df.to_csv(store_filename, index=False)
226 |
227 |
228 |
229 | def get_classification_accuracy(gt_csv_path, scores_csv_path, store_filename):
230 | """
231 | Plot mean tail accuracy across all classes for threshold values
232 |
233 | Args:
234 | gt_csv_path: Ground truth csv location
235 | scores_csv_path: Confidence scores csv path
236 | store_filename: name and location to save resulting plot
237 | """
238 | gt_df = pd.read_csv(gt_csv_path)
239 | scores_df = pd.read_csv(scores_csv_path)
240 |
241 | # Evaluate only the top_num highest-scoring images per class
242 | top_num = 2800
244 | num_threshold = 10
245 | results = []
246 |
247 | for image_num in range(1, 21):
248 | clf = np.sort(np.array(scores_df.iloc[:,image_num], dtype=float))[-top_num:]
249 | ls = np.linspace(0.0, 1.0, num=num_threshold)
250 |
251 | class_results = []
252 | for i in ls:
253 | clf = np.sort(np.array(scores_df.iloc[:,image_num], dtype=float))[-top_num:]
254 | clf_ind = np.argsort(np.array(scores_df.iloc[:,image_num], dtype=float))[-top_num:]
255 |
256 | # Read ground truth (unsorted, so clf_ind indexes the correct rows)
257 | gt = np.array(gt_df.iloc[:,image_num], dtype=int)
258 |
259 | # Now get the ground truth corresponding to the top scores
260 | gt = gt[clf_ind]
261 | clf[clf >= i] = 1
262 | clf[clf < i] = 0
263 |
264 | score = accuracy_score(y_true=gt, y_pred=clf, normalize=False)/clf.shape[0]
265 | class_results.append(score)
266 |
267 | results.append(class_results)
268 |
269 | results = np.asarray(results)
270 |
271 | ls = np.linspace(0.0, 1.0, num=num_threshold)
272 | plt.plot(ls, results.mean(0))
273 | plt.title("Mean Tail Accuracy vs Threshold")
274 | plt.xlabel("Threshold")
275 | plt.ylabel("Mean Tail Accuracy")
276 | plt.savefig(store_filename)
277 | plt.show()
278 |
279 |
280 | #get_classification_accuracy("../models/resnet18/results.csv", "../models/resnet18/gt.csv", "roc-curve.png")
281 |
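282 | # Example (a sketch): encode_labels consumes the annotation dict produced by
283 | # torchvision's VOCDetection, e.g.
284 | #   target = {'annotation': {'object': [{'name': 'dog', 'difficult': '0'}]}}
285 | #   encode_labels(target)  # -> 20-d 0/1 tensor with the 'dog' index set to 1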
--------------------------------------------------------------------------------
/cmd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CLSA/37df76cf5cb032683e57b70a3a4090f0d524c8fd/cmd/__init__.py
--------------------------------------------------------------------------------
/cmd/run_multi.sh:
--------------------------------------------------------------------------------
1 | python3 main_clsa.py --data=[data_path] --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 1 1 1 --size_crops 224 192 160 128 96 --min_scale_crops 0.2 0.172 0.143 0.114 0.086 --max_scale_crops 1.0 0.86 0.715 0.571 0.429 --pick_strong 0 1 2 3 4 --pick_weak 0 1 2 3 4 --clsa_t 0.2 --sym 0
2 |
--------------------------------------------------------------------------------
/cmd/run_single.sh:
--------------------------------------------------------------------------------
1 | python3 main_clsa.py --data=/data/imagenet --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 --size_crops 224 96 --min_scale_crops 0.2 0.086 --max_scale_crops 1.0 0.429 --pick_strong 1 --pick_weak 0 --clsa_t 0.2 --sym 0
2 |
--------------------------------------------------------------------------------
/data_processing/Image_ops.py:
--------------------------------------------------------------------------------
1 | from PIL import ImageFilter
2 | import random
3 | class TwoCropsTransform:
4 | """Take two random crops of one image as the query and key."""
5 |
6 | def __init__(self, base_transform):
7 | self.base_transform = base_transform
8 |
9 | def __call__(self, x):
10 | q = self.base_transform(x)
11 | k = self.base_transform(x)
12 | return [q, k]
13 |
14 |
15 | class GaussianBlur(object):
16 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
17 |
18 | def __init__(self, sigma=[.1, 2.]):
19 | self.sigma = sigma
20 |
21 | def __call__(self, x):
22 | sigma = random.uniform(self.sigma[0], self.sigma[1])
23 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
24 | return x
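25 |
26 | # Example usage (a sketch; base_transform is any torchvision transform pipeline
27 | # and pil_image any PIL image):
28 | #   aug = TwoCropsTransform(base_transform)
29 | #   q, k = aug(pil_image)  # two independently augmented views of the same image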
--------------------------------------------------------------------------------
/data_processing/Multi_FixTransform.py:
--------------------------------------------------------------------------------
1 | #modified from https://github.com/facebookresearch/swav/blob/master/src/multicropdataset.py
2 | from torchvision import transforms
3 | from data_processing.RandAugment import RandAugment
4 | from data_processing.Image_ops import GaussianBlur
5 | class Multi_Fixtransform(object):
6 | def __init__(self,
7 | size_crops,
8 | nmb_crops,
9 | min_scale_crops,
10 | max_scale_crops,normalize,
11 | aug_times,init_size=224):
12 | """
13 | :param size_crops: list of output sizes, one entry per crop resolution
14 | :param nmb_crops: number of crops to generate at each resolution
15 | :param min_scale_crops: minimum RandomResizedCrop scale for the corresponding resolution
16 | :param max_scale_crops: maximum RandomResizedCrop scale for the corresponding resolution
17 | :param normalize: normalization transform
18 | :param aug_times: number of strong augmentation (RandAugment) operations per strong view
19 | :param init_size: output size of the key (weak) image
20 | """
21 | assert len(size_crops) == len(nmb_crops)
22 | assert len(min_scale_crops) == len(nmb_crops)
23 | assert len(max_scale_crops) == len(nmb_crops)
24 | trans=[]
25 | #key image transform
26 | self.weak = transforms.Compose([
27 | transforms.RandomResizedCrop(init_size, scale=(0.2, 1.)),
28 | transforms.RandomApply([
29 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
30 | ], p=0.8),
31 | transforms.RandomGrayscale(p=0.2),
32 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
33 | transforms.RandomHorizontalFlip(),
34 | transforms.ToTensor(),
35 | normalize
36 | ])
37 | trans.append(self.weak)
38 | self.aug_times=aug_times
39 | trans_weak=[]
40 | trans_strong=[]
41 | for i in range(len(size_crops)):
42 | randomresizedcrop = transforms.RandomResizedCrop(
43 | size_crops[i],
44 | scale=(min_scale_crops[i], max_scale_crops[i]),
45 | )
46 |
47 | strong = transforms.Compose([
48 | randomresizedcrop,
49 | transforms.RandomApply([
50 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
51 | ], p=0.8),
52 | transforms.RandomGrayscale(p=0.2),
53 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
54 | transforms.RandomHorizontalFlip(),
55 | RandAugment(n=self.aug_times, m=10),
56 | transforms.ToTensor(),
57 | normalize
58 | ])
59 | weak=transforms.Compose([
60 | randomresizedcrop,
61 | transforms.RandomApply([
62 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
63 | ], p=0.8),
64 | transforms.RandomGrayscale(p=0.2),
65 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
66 | transforms.RandomHorizontalFlip(),
67 | transforms.ToTensor(),
68 | normalize
69 | ])
70 | trans_weak.extend([weak]*nmb_crops[i])
71 | trans_strong.extend([strong]*nmb_crops[i])
72 | trans.extend(trans_weak)
73 | trans.extend(trans_strong)
74 | self.trans=trans
75 | def __call__(self, x):
76 | multi_crops = list(map(lambda trans: trans(x), self.trans))
77 | return multi_crops
78 |
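79 | # Note on output ordering: __call__ returns [the key view from self.weak,
80 | # then all weak multi-crops, then all strong multi-crops]; with
81 | # nmb_crops=[1, 1, 1, 1, 1] that is 1 + 5 + 5 = 11 tensors.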
--------------------------------------------------------------------------------
/data_processing/RandAugment.py:
--------------------------------------------------------------------------------
1 |
2 | # code in this file is adapted from
3 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
4 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py
5 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py
6 | import logging
7 | import random
8 |
9 | import numpy as np
10 | import PIL
11 | import PIL.ImageOps
12 | import PIL.ImageEnhance
13 | import PIL.ImageDraw
14 | from PIL import Image
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 | PARAMETER_MAX = 10
19 |
20 |
21 | def AutoContrast(img, **kwarg):
22 | return PIL.ImageOps.autocontrast(img)
23 |
24 |
25 | def Brightness(img, v, max_v, bias=0):
26 | v = _float_parameter(v, max_v) + bias
27 | return PIL.ImageEnhance.Brightness(img).enhance(v)
28 |
29 |
30 | def Color(img, v, max_v, bias=0):
31 | v = _float_parameter(v, max_v) + bias
32 | return PIL.ImageEnhance.Color(img).enhance(v)
33 |
34 |
35 | def Contrast(img, v, max_v, bias=0):
36 | v = _float_parameter(v, max_v) + bias
37 | return PIL.ImageEnhance.Contrast(img).enhance(v)
38 |
39 |
40 | def Cutout(img, v, max_v, bias=0):
41 | if v == 0:
42 | return img
43 | v = _float_parameter(v, max_v) + bias
44 | v = int(v * min(img.size))
45 | return CutoutAbs(img, v)
46 |
47 |
48 | def CutoutAbs(img, v, **kwarg):
49 | w, h = img.size
50 | x0 = np.random.uniform(0, w)
51 | y0 = np.random.uniform(0, h)
52 | x0 = int(max(0, x0 - v / 2.))
53 | y0 = int(max(0, y0 - v / 2.))
54 | x1 = int(min(w, x0 + v))
55 | y1 = int(min(h, y0 + v))
56 | xy = (x0, y0, x1, y1)
57 | # gray
58 | color = (127, 127, 127)
59 | img = img.copy()
60 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
61 | return img
62 |
63 |
64 | def Equalize(img, **kwarg):
65 | return PIL.ImageOps.equalize(img)
66 |
67 |
68 | def Identity(img, **kwarg):
69 | return img
70 |
71 |
72 | def Invert(img, **kwarg):
73 | return PIL.ImageOps.invert(img)
74 |
75 |
76 | def Posterize(img, v, max_v, bias=0):
77 | v = _int_parameter(v, max_v) + bias
78 | return PIL.ImageOps.posterize(img, v)
79 |
80 |
81 | def Rotate(img, v, max_v, bias=0):
82 | v = _int_parameter(v, max_v) + bias
83 | if random.random() < 0.5:
84 | v = -v
85 | return img.rotate(v)
86 |
87 |
88 | def Sharpness(img, v, max_v, bias=0):
89 | v = _float_parameter(v, max_v) + bias
90 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
91 |
92 |
93 | def ShearX(img, v, max_v, bias=0):
94 | v = _float_parameter(v, max_v) + bias
95 | if random.random() < 0.5:
96 | v = -v
97 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
98 |
99 |
100 | def ShearY(img, v, max_v, bias=0):
101 | v = _float_parameter(v, max_v) + bias
102 | if random.random() < 0.5:
103 | v = -v
104 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
105 |
106 |
107 | def Solarize(img, v, max_v, bias=0):
108 | v = _int_parameter(v, max_v) + bias
109 | return PIL.ImageOps.solarize(img, 256 - v)
110 |
111 |
112 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
113 | v = _int_parameter(v, max_v) + bias
114 | if random.random() < 0.5:
115 | v = -v
116 | img_np = np.array(img).astype(int)  # np.int is deprecated/removed in newer NumPy
117 | img_np = img_np + v
118 | img_np = np.clip(img_np, 0, 255)
119 | img_np = img_np.astype(np.uint8)
120 | img = Image.fromarray(img_np)
121 | return PIL.ImageOps.solarize(img, threshold)
122 |
123 |
124 | def TranslateX(img, v, max_v, bias=0):
125 | v = _float_parameter(v, max_v) + bias
126 | if random.random() < 0.5:
127 | v = -v
128 | v = int(v * img.size[0])
129 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
130 |
131 |
132 | def TranslateY(img, v, max_v, bias=0):
133 | v = _float_parameter(v, max_v) + bias
134 | if random.random() < 0.5:
135 | v = -v
136 | v = int(v * img.size[1])
137 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
138 |
139 |
140 | def _float_parameter(v, max_v):
141 | return float(v) * max_v / PARAMETER_MAX
142 |
143 |
144 | def _int_parameter(v, max_v):
145 | return int(v * max_v / PARAMETER_MAX)
146 |
147 |
148 | def fixmatch_augment_pool():
149 | # FixMatch paper
150 | augs = [(AutoContrast, None, None),
151 | (Brightness, 0.9, 0.05),
152 | (Color, 0.9, 0.05),
153 | (Contrast, 0.9, 0.05),
154 | (Equalize, None, None),
155 | (Identity, None, None),
156 | (Posterize, 4, 4),
157 | (Rotate, 30, 0),
158 | (Sharpness, 0.9, 0.05),
159 | (ShearX, 0.3, 0),
160 | (ShearY, 0.3, 0),
161 | (Solarize, 256, 0),
162 | (TranslateX, 0.3, 0),
163 | (TranslateY, 0.3, 0)]
164 | return augs
165 |
166 |
167 | def my_augment_pool():
168 | # Test
169 | augs = [(AutoContrast, None, None),
170 | (Brightness, 1.8, 0.1),
171 | (Color, 1.8, 0.1),
172 | (Contrast, 1.8, 0.1),
173 | (Cutout, 0.2, 0),
174 | (Equalize, None, None),
175 | (Invert, None, None),
176 | (Posterize, 4, 4),
177 | (Rotate, 30, 0),
178 | (Sharpness, 1.8, 0.1),
179 | (ShearX, 0.3, 0),
180 | (ShearY, 0.3, 0),
181 | (Solarize, 256, 0),
182 | (SolarizeAdd, 110, 0),
183 | (TranslateX, 0.45, 0),
184 | (TranslateY, 0.45, 0)]
185 | return augs
186 |
187 |
188 | class RandAugmentPC(object):
189 | def __init__(self, n, m):
190 | assert n >= 1
191 | assert 1 <= m <= 10
192 | self.n = n
193 | self.m = m
194 | self.augment_pool = my_augment_pool()
195 |
196 | def __call__(self, img):
197 | ops = random.choices(self.augment_pool, k=self.n)
198 | for op, max_v, bias in ops:
199 | prob = np.random.uniform(0.2, 0.8)
200 | if random.random() + prob >= 1:
201 | img = op(img, v=self.m, max_v=max_v, bias=bias)
202 | img = CutoutAbs(img, 16)
203 | return img
204 |
205 |
206 | class RandAugment(object):
207 | def __init__(self, n, m):
208 | assert n >= 0
209 | assert 1 <= m <= 10
210 | self.n = n
211 | self.m = m
212 | self.augment_pool = fixmatch_augment_pool()
213 | def __call__(self, img):
214 | ops = random.choices(self.augment_pool, k=self.n)
215 | for op, max_v, bias in ops:
216 | v = np.random.randint(1, self.m)
217 | if random.random() < 0.5:
218 | img = op(img, v=v, max_v=max_v, bias=bias)
219 | return img
220 |
221 |
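222 | # Example usage (a sketch; "example.jpg" is a placeholder path):
223 | #   from PIL import Image
224 | #   img = Image.open("example.jpg")
225 | #   strong = RandAugment(n=5, m=10)  # n matches --aug_times=5 in the README
226 | #   img_aug = strong(img)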
--------------------------------------------------------------------------------
/data_processing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CLSA/37df76cf5cb032683e57b70a3a4090f0d524c8fd/data_processing/__init__.py
--------------------------------------------------------------------------------
/detection/configs/Base-RCNN-C4-BN.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "GeneralizedRCNN"
3 | RPN:
4 | PRE_NMS_TOPK_TEST: 6000
5 | POST_NMS_TOPK_TEST: 1000
6 | ROI_HEADS:
7 | NAME: "Res5ROIHeadsExtraNorm"
8 | BACKBONE:
9 | FREEZE_AT: 0
10 | RESNETS:
11 | NORM: "SyncBN"
12 | TEST:
13 | PRECISE_BN:
14 | ENABLED: True
15 | SOLVER:
16 | IMS_PER_BATCH: 16
17 | BASE_LR: 0.02
18 |
--------------------------------------------------------------------------------
/detection/configs/coco_R_50_C4_2x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-C4-BN.yaml"
2 | MODEL:
3 | MASK_ON: True
4 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
5 | INPUT:
6 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
7 | MIN_SIZE_TEST: 800
8 | DATASETS:
9 | TRAIN: ("coco_2017_train",)
10 | TEST: ("coco_2017_val",)
11 | SOLVER:
12 | STEPS: (120000, 160000)
13 | MAX_ITER: 180000
14 |
--------------------------------------------------------------------------------
/detection/configs/coco_R_50_C4_2x_clsa.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "coco_R_50_C4_2x.yaml"
2 | MODEL:
3 | PIXEL_MEAN: [123.675, 116.280, 103.530]
4 | PIXEL_STD: [58.395, 57.120, 57.375]
5 | WEIGHTS: "See Instructions"
6 | RESNETS:
7 | STRIDE_IN_1X1: False
8 | INPUT:
9 | FORMAT: "RGB"
10 |
--------------------------------------------------------------------------------
/detection/configs/pascal_voc_R_50_C4_24k.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-C4-BN.yaml"
2 | MODEL:
3 | MASK_ON: False
4 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
5 | ROI_HEADS:
6 | NUM_CLASSES: 20
7 | INPUT:
8 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
9 | MIN_SIZE_TEST: 800
10 | DATASETS:
11 | TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
12 | TEST: ('voc_2007_test',)
13 | SOLVER:
14 | STEPS: (18000, 22000)
15 | MAX_ITER: 24000
16 | WARMUP_ITERS: 100
17 |
--------------------------------------------------------------------------------
/detection/configs/pascal_voc_R_50_C4_24k_CLSA.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "pascal_voc_R_50_C4_24k.yaml"
2 | MODEL:
3 | PIXEL_MEAN: [123.675, 116.280, 103.530]
4 | PIXEL_STD: [58.395, 57.120, 57.375]
5 | WEIGHTS: "See Instructions"
6 | RESNETS:
7 | STRIDE_IN_1X1: False
8 | INPUT:
9 | FORMAT: "RGB"
10 |
--------------------------------------------------------------------------------
/detection/convert-pretrain-to-detectron2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | adapted from https://github.com/facebookresearch/moco/tree/master/detection
4 | Created on Thur Mar 4 14:37:25 2021
5 |
6 | @author: Facebook
7 | """
8 | import pickle as pkl
9 | import sys
10 | import torch
11 |
12 | if __name__ == "__main__":
13 | input = sys.argv[1]
14 |
15 | obj = torch.load(input, map_location="cpu")
16 | obj = obj["state_dict"]
17 |
18 | newmodel = {}
19 | for k, v in obj.items():
20 | if not k.startswith("module.encoder_q."):
21 | continue
22 | old_k = k
23 | k = k.replace("module.encoder_q.", "")
24 | if "layer" not in k:
25 | k = "stem." + k
26 | for t in [1, 2, 3, 4]:
27 | k = k.replace("layer{}".format(t), "res{}".format(t + 1))
28 | for t in [1, 2, 3]:
29 | k = k.replace("bn{}".format(t), "conv{}.norm".format(t))
30 | k = k.replace("downsample.0", "shortcut")
31 | k = k.replace("downsample.1", "shortcut.norm")
32 | print(old_k, "->", k)
33 | newmodel[k] = v.numpy()
34 |
35 | res = {"model": newmodel, "__author__": "MOCO", "matching_heuristics": True}
36 |
37 | with open(sys.argv[2], "wb") as f:
38 | pkl.dump(res, f)
39 |
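40 | # Usage (from the README): python3 convert-pretrain-to-detectron2.py input.pth.tar output.pkl
41 | # The script keeps only the encoder_q backbone weights and renames them to
42 | # detectron2's stem/res{2..5} and conv{n}.norm conventions.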
--------------------------------------------------------------------------------
/detection/train_net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | adapted from https://github.com/facebookresearch/moco/tree/master/detection
4 | Created on Thur Mar 4 14:37:25 2021
5 |
6 | @author: Facebook
7 | """
8 |
9 | import os
10 |
11 | from detectron2.checkpoint import DetectionCheckpointer
12 | from detectron2.config import get_cfg
13 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
14 | from detectron2.evaluation import COCOEvaluator, PascalVOCDetectionEvaluator
15 | from detectron2.layers import get_norm
16 | from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
17 |
18 |
19 | @ROI_HEADS_REGISTRY.register()
20 | class Res5ROIHeadsExtraNorm(Res5ROIHeads):
21 | """
22 | As described in the MOCO paper, there is an extra BN layer
23 | following the res5 stage.
24 | """
25 | def _build_res5_block(self, cfg):
26 | seq, out_channels = super()._build_res5_block(cfg)
27 | norm = cfg.MODEL.RESNETS.NORM
28 | norm = get_norm(norm, out_channels)
29 | seq.add_module("norm", norm)
30 | return seq, out_channels
31 |
32 |
33 | class Trainer(DefaultTrainer):
34 | @classmethod
35 | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
36 | if output_folder is None:
37 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
38 | if "coco" in dataset_name:
39 | return COCOEvaluator(dataset_name, cfg, True, output_folder)
40 | else:
41 | assert "voc" in dataset_name
42 | return PascalVOCDetectionEvaluator(dataset_name)
43 |
44 |
45 | def setup(args):
46 | cfg = get_cfg()
47 | cfg.merge_from_file(args.config_file)
48 | cfg.merge_from_list(args.opts)
49 | cfg.freeze()
50 | default_setup(cfg, args)
51 | return cfg
52 |
53 |
54 | def main(args):
55 | cfg = setup(args)
56 |
57 | if args.eval_only:
58 | model = Trainer.build_model(cfg)
59 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
60 | cfg.MODEL.WEIGHTS, resume=args.resume
61 | )
62 | res = Trainer.test(cfg, model)
63 | return res
64 |
65 | trainer = Trainer(cfg)
66 | trainer.resume_or_load(resume=args.resume)
67 | return trainer.train()
68 |
69 |
70 | if __name__ == "__main__":
71 | args = default_argument_parser().parse_args()
72 | print("Command Line Args:", args)
73 | launch(
74 | main,
75 | args.num_gpus,
76 | num_machines=args.num_machines,
77 | machine_rank=args.machine_rank,
78 | dist_url=args.dist_url,
79 | args=(args,),
80 | )
81 |
--------------------------------------------------------------------------------
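Res5ROIHeadsExtraNorm above only appends one extra normalization layer to whatever res5 block the parent class builds. Here is a torch-only sketch of the same add_module pattern; BatchNorm2d stands in for whatever get_norm(cfg.MODEL.RESNETS.NORM, ...) would return, and detectron2 is not needed to see the mechanics:

import torch
import torch.nn as nn

# a stand-in "res5" stage; any nn.Sequential works the same way
seq = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
out_channels = 8

# append the extra norm layer, exactly as _build_res5_block does
seq.add_module("norm", nn.BatchNorm2d(out_channels))

x = torch.randn(2, 3, 16, 16)
print(seq(x).shape)  # torch.Size([2, 8, 16, 16])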
/lincls.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | import warnings
4 |
5 | warnings.filterwarnings('ignore')
6 | import argparse
7 | import builtins
8 | import os
9 | import random
10 | import shutil
11 | import time
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.parallel
16 | import torch.backends.cudnn as cudnn
17 | import torch.distributed as dist
18 | import torch.optim
19 | import torch.multiprocessing as mp
20 | import torch.utils.data
21 | import torch.utils.data.distributed
22 | import torchvision.transforms as transforms
23 | import torchvision.datasets as datasets
24 | import torchvision.models as models
25 |
26 | from data_processing.loader import GaussianBlur
27 | from ops.os_operation import mkdir
28 | from training.train_utils import accuracy
29 |
30 | model_names = sorted(name for name in models.__dict__
31 | if name.islower() and not name.startswith("__")
32 | and callable(models.__dict__[name]))
33 |
34 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
35 | parser.add_argument('--data', type=str, metavar='DIR',
36 | help='path to dataset')
37 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
38 | choices=model_names,
39 | help='model architecture: ' +
40 | ' | '.join(model_names) +
41 | ' (default: resnet50)')
42 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
43 | help='number of data loading workers (default: 32)')
44 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
45 | help='number of total epochs to run')
46 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
47 | help='manual epoch number (useful on restarts)')
48 | parser.add_argument('--batch-size', default=256, type=int,
49 | metavar='N',
50 | help='mini-batch size (default: 256), this is the total '
51 | 'batch size of all GPUs on the current node when '
52 | 'using Data Parallel or Distributed Data Parallel')
53 | parser.add_argument('--lr', '--learning-rate', default=10., type=float,
54 | metavar='LR', help='initial learning rate', dest='lr')
55 | parser.add_argument('--schedule', default=[15, 25, 30], nargs='*', type=int,
56 | help='learning rate schedule (when to drop lr by a ratio)') # default is for places205
57 | parser.add_argument('--cos', type=int, default=1,
58 | help='use cosine lr schedule')
59 |
60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
61 | help='momentum')
62 | parser.add_argument('--wd', '--weight-decay', default=0., type=float,
63 | metavar='W', help='weight decay (default: 0.)',
64 | dest='weight_decay')
65 | parser.add_argument('-p', '--print-freq', default=10, type=int,
66 | metavar='N', help='print frequency (default: 10)')
67 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
68 | help='path to latest checkpoint (default: none)')
69 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
70 | help='evaluate model on validation set')
71 | parser.add_argument('--world-size', default=1, type=int,
72 | help='number of nodes for distributed training')
73 | parser.add_argument('--rank', default=0, type=int,
74 | help='node rank for distributed training')
75 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
76 | help='url used to set up distributed training')
77 | parser.add_argument('--dist-backend', default='nccl', type=str,
78 | help='distributed backend')
79 | parser.add_argument('--seed', default=None, type=int,
80 | help='seed for initializing training. ')
81 | parser.add_argument('--gpu', default=None, type=int,
82 | help='GPU id to use.')
83 | parser.add_argument('--multiprocessing-distributed', type=int, default=1,
84 | help='Use multi-processing distributed training to launch '
85 | 'N processes per node, which has N GPUs. This is the '
86 | 'fastest way to use PyTorch for either single node or '
87 | 'multi node data parallel training')
88 |
89 | parser.add_argument('--pretrained', default='', type=str,
90 | help='path to moco pretrained checkpoint')
91 | parser.add_argument('--choose', type=str, default=None, help="choose gpu for training")
92 | parser.add_argument("--dataset", type=str, default="ImageNet", help="which dataset is used to finetune")
93 | parser.add_argument("--aug", type=int, default=0, help="use augmentation or not during fine tuning")
94 | parser.add_argument("--size_crops", type=int, default=[224, 192, 160, 128, 96], nargs="+",
95 | help="crops resolutions (example: [224, 96])")
96 | parser.add_argument("--min_scale_crops", type=float, default=[0.2, 0.172, 0.143, 0.114, 0.086], nargs="+",
97 | help="argument in RandomResizedCrop (example: [0.14, 0.05])")
98 | parser.add_argument("--max_scale_crops", type=float, default=[1.0, 0.86, 0.715, 0.571, 0.429], nargs="+",
99 | help="argument in RandomResizedCrop (example: [1., 0.14])")
100 | parser.add_argument("--add_crop", type=int, default=0, help="use crop or not in our training dataset")
101 | parser.add_argument("--strong", type=int, default=0, help="use strong augmentation or not")
102 | parser.add_argument("--final_lr", type=float, default=0.01, help="ending learning rate for training")
103 | parser.add_argument("--aug_type", type=int, default=0, help="augmentation type to use")
104 | parser.add_argument('--save_path', default="", type=str, help="model and record save path")
105 | parser.add_argument('--log_path', type=str, default="train_log", help="log path for saving models")
106 | parser.add_argument("--nodes_num", type=int, default=1, help="number of nodes to use")
107 | parser.add_argument("--ngpu", type=int, default=8, help="number of gpus per node")
108 | parser.add_argument("--master_addr", type=str, default="127.0.0.1", help="addr for master node")
109 | parser.add_argument("--master_port", type=str, default="1234", help="port for master node")
110 | parser.add_argument('--node_rank', type=int, default=0, help='rank of machine, 0 to nodes_num-1')
111 | parser.add_argument("--final", default=0, type=int, help="use the final specified augmentation or not")
112 | parser.add_argument("--avg_pool", default=1, type=int, help="average pool output size")
113 | parser.add_argument("--crop_scale", type=float, default=[0.2, 1.0], nargs="+",
114 | help="argument in RandomResizedCrop (example: [1., 0.14])")
115 | parser.add_argument("--train_strong", type=int, default=0, help="use stronger augmentation during training or not")
116 | parser.add_argument("--sgdr", type=int, default=0, help="lr scheduler: 0 use --cos/--schedule, 1 CosineAnnealingLR, 2 CosineAnnealingWarmRestarts")
117 | parser.add_argument("--sgdr_t0", type=int, default=10, help="sgdr t0")
118 | parser.add_argument("--sgdr_t_mult", type=int, default=1, help="sgdr t mult")
119 | parser.add_argument("--dropout", type=float, default=0.0, help="dropout layer settings")
120 | parser.add_argument("--randcrop", type=int, default=0, help="use random crop or not")
121 | best_acc1 = 0
122 |
123 |
124 | def main():
125 | args = parser.parse_args()
126 | choose = args.choose
127 | if choose is not None:
128 | os.environ['CUDA_VISIBLE_DEVICES'] = choose
129 | if args.seed is not None:
130 | random.seed(args.seed)
131 | torch.manual_seed(args.seed)
132 | cudnn.deterministic = True
133 | warnings.warn('You have chosen to seed training. '
134 | 'This will turn on the CUDNN deterministic setting, '
135 | 'which can slow down your training considerably! '
136 | 'You may see unexpected behavior when restarting '
137 | 'from checkpoints.')
138 |
139 | if args.gpu is not None:
140 | warnings.warn('You have chosen a specific GPU. This will completely '
141 | 'disable data parallelism.')
142 |
143 | if args.dist_url == "env://" and args.world_size == -1:
144 | args.world_size = int(os.environ["WORLD_SIZE"])
145 |
146 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
147 | params = vars(args)
148 |     data_path = args.data  # dataset root path
149 | args.data = data_path
150 | ngpus_per_node = torch.cuda.device_count()
151 | if args.multiprocessing_distributed:
152 | # Since we have ngpus_per_node processes per node, the total world_size
153 | # needs to be adjusted accordingly
154 | args.world_size = ngpus_per_node * args.world_size
155 | # Use torch.multiprocessing.spawn to launch distributed processes: the
156 | # main_worker process function
157 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
158 | else:
159 | # Simply call main_worker function
160 | main_worker(args.gpu, ngpus_per_node, args)
161 |
162 |
163 | def main_worker(gpu, ngpus_per_node, args):
164 | global best_acc1
165 | args.gpu = gpu
166 | params = vars(args)
167 | # suppress printing if not master
168 | if args.multiprocessing_distributed and args.gpu != 0:
169 | def print_pass(*args):
170 | pass
171 |
172 | builtins.print = print_pass
173 |
174 | if args.gpu is not None:
175 | print("Use GPU: {} for training".format(args.gpu))
176 |
177 | if args.distributed:
178 | if args.dist_url == "env://" and args.rank == -1:
179 | args.rank = int(os.environ["RANK"])
180 | if args.multiprocessing_distributed:
181 | # For multiprocessing distributed training, rank needs to be the
182 | # global rank among all the processes
183 | args.rank = args.rank * ngpus_per_node + gpu
184 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
185 | world_size=args.world_size, rank=args.rank)
186 | # create model
187 | print("=> creating model '{}'".format(args.arch))
188 | if args.dataset == "Place205":
189 | num_classes = 205
190 | else:
191 | num_classes = 1000
192 |
193 | model = models.__dict__[args.arch](num_classes=num_classes)
194 |
195 | # freeze all layers but the last fc
196 | for name, param in model.named_parameters():
197 | if name not in ['fc.weight', 'fc.bias']:
198 | param.requires_grad = False
199 |
200 | # init the fc layer
201 | model.fc.weight.data.normal_(mean=0.0, std=0.01)
202 | model.fc.bias.data.zero_()
203 |
204 | # load from pre-trained, before DistributedDataParallel constructor
205 | if args.pretrained:
206 |
207 | if os.path.isfile(args.pretrained):
208 | print("=> loading checkpoint '{}'".format(args.pretrained))
209 |
210 | checkpoint = torch.load(args.pretrained, map_location="cpu")
211 | state_dict = checkpoint['state_dict']
212 | for k in list(state_dict.keys()):
213 | # retain only encoder_q up to before the embedding layer
214 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
215 | # remove prefix
216 | state_dict[k[len("module.encoder_q."):]] = state_dict[k]
217 | # delete renamed or unused k
218 | del state_dict[k]
219 |
220 | args.start_epoch = 0
221 | msg = model.load_state_dict(state_dict, strict=False)
222 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
223 |
224 | print("=> loaded pre-trained model '{}'".format(args.pretrained))
225 | else:
226 | print("=> no checkpoint found at '{}'".format(args.pretrained))
227 |
228 | if args.dropout != 0.0:
229 | model.fc = nn.Sequential(nn.Dropout(args.dropout), model.fc)
230 | if args.distributed:
231 | # For multiprocessing distributed, DistributedDataParallel constructor
232 | # should always set the single device scope, otherwise,
233 | # DistributedDataParallel will use all available devices.
234 | if args.gpu is not None:
235 | torch.cuda.set_device(args.gpu)
236 | model.cuda(args.gpu)
237 | # When using a single GPU per process and per
238 | # DistributedDataParallel, we need to divide the batch size
239 | # ourselves based on the total number of GPUs we have
240 | args.batch_size = int(args.batch_size / ngpus_per_node)
241 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
242 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
243 | else:
244 | model.cuda()
245 | # DistributedDataParallel will divide and allocate batch_size to all
246 | # available GPUs if device_ids are not set
247 | model = torch.nn.parallel.DistributedDataParallel(model)
248 | elif args.gpu is not None:
249 | torch.cuda.set_device(args.gpu)
250 | model = model.cuda(args.gpu)
251 | else:
252 | # DataParallel will divide and allocate batch_size to all available GPUs
253 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
254 | model.features = torch.nn.DataParallel(model.features)
255 | model.cuda()
256 | else:
257 | model = torch.nn.DataParallel(model).cuda()
258 |
259 | # define loss function (criterion) and optimizer
260 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
261 |
262 | # optimize only the linear classifier
263 | parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
264 | assert len(parameters) == 2 # fc.weight, fc.bias
265 | optimizer = torch.optim.SGD(parameters, args.lr,
266 | momentum=args.momentum,
267 | weight_decay=args.weight_decay)
268 |
269 | # optionally resume from a checkpoint
270 | if args.resume:
271 | if os.path.isfile(args.resume):
272 | print("=> loading checkpoint '{}'".format(args.resume))
273 | if args.gpu is None:
274 | checkpoint = torch.load(args.resume)
275 | else:
276 | # Map model to be loaded to specified single gpu.
277 | loc = 'cuda:{}'.format(args.gpu)
278 | checkpoint = torch.load(args.resume, map_location=loc)
279 |
280 | args.start_epoch = checkpoint['epoch']
281 | best_acc1 = torch.tensor(checkpoint['best_acc1'])
282 | if args.gpu is not None:
283 | # best_acc1 may be from a checkpoint from a different GPU
284 | best_acc1 = best_acc1.to(args.gpu)
285 | model.load_state_dict(checkpoint['state_dict'])
286 | optimizer.load_state_dict(checkpoint['optimizer'])
287 | print("=> loaded checkpoint '{}' (epoch {})"
288 | .format(args.resume, checkpoint['epoch']))
289 | else:
290 | print("=> no checkpoint found at '{}'".format(args.resume))
291 |
292 | cudnn.benchmark = True
293 |
294 | # Data loading code
295 | if args.dataset == "ImageNet":
296 | data_path = args.data
297 | traindir = os.path.join(data_path, 'train')
298 | valdir = os.path.join(data_path, 'val')
299 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
300 | std=[0.229, 0.224, 0.225])
301 | if args.train_strong:
302 | transform_train = transforms.Compose([
303 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
304 | transforms.RandomApply([
305 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
306 | ], p=0.8),
307 | transforms.RandomGrayscale(p=0.2),
308 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
309 | transforms.RandomHorizontalFlip(),
310 | transforms.ToTensor(),
311 | normalize
312 | ])
313 | elif args.randcrop:
314 | transform_train = transforms.Compose([
315 | transforms.RandomCrop(224, pad_if_needed=True),
316 | transforms.RandomHorizontalFlip(),
317 | transforms.ToTensor(),
318 | normalize, ])
319 |
320 | else:
321 | transform_train = transforms.Compose([
322 | transforms.RandomResizedCrop(224),
323 | transforms.RandomHorizontalFlip(),
324 | transforms.ToTensor(),
325 | normalize, ])
326 | transform_test = transforms.Compose([
327 | transforms.Resize(256),
328 | transforms.CenterCrop(224),
329 | transforms.ToTensor(),
330 | normalize,
331 | ])
332 | train_dataset = datasets.ImageFolder(traindir, transform_train)
333 | val_dataset = datasets.ImageFolder(valdir, transform_test)
334 |
335 | if args.distributed:
336 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
337 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset,
338 |                                                                        shuffle=True)  # each GPU evaluates its own shard; per-GPU accuracies are averaged via all_gather
339 | # val_sampler=None
340 | else:
341 | train_sampler = None
342 | val_sampler = None
343 |
344 | train_loader = torch.utils.data.DataLoader(
345 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
346 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
347 |
348 | val_loader = torch.utils.data.DataLoader(
349 | val_dataset, sampler=val_sampler,
350 | batch_size=args.batch_size, shuffle=(val_sampler is None),
351 |         # shuffle only when no distributed sampler shards the data across GPUs
352 | num_workers=args.workers, pin_memory=True)
353 |
354 |
355 | elif args.dataset == "Place205":
356 | from data_processing.Place205_Dataset import Places205
357 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
358 | std=[0.229, 0.224, 0.225])
359 | if args.train_strong:
360 | if args.randcrop:
361 | transform_train = transforms.Compose([
362 | transforms.RandomCrop(224),
363 | transforms.RandomApply([
364 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
365 | ], p=0.8),
366 | transforms.RandomGrayscale(p=0.2),
367 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
368 | transforms.RandomHorizontalFlip(),
369 | transforms.ToTensor(),
370 | normalize
371 | ])
372 | else:
373 | transform_train = transforms.Compose([
374 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
375 | transforms.RandomApply([
376 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
377 | ], p=0.8),
378 | transforms.RandomGrayscale(p=0.2),
379 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
380 | transforms.RandomHorizontalFlip(),
381 | transforms.ToTensor(),
382 | normalize
383 | ])
384 | else:
385 | if args.randcrop:
386 | transform_train = transforms.Compose([
387 | transforms.RandomCrop(224),
388 | transforms.RandomHorizontalFlip(),
389 | transforms.ToTensor(),
390 | normalize, ])
391 |
392 | else:
393 | transform_train = transforms.Compose([
394 | transforms.RandomResizedCrop(224),
395 | transforms.RandomHorizontalFlip(),
396 | transforms.ToTensor(),
397 | normalize, ])
398 |         # TODO: add ten-crop evaluation
399 | transform_valid = transforms.Compose([
400 | transforms.Resize([256, 256]),
401 | transforms.CenterCrop(224),
402 | transforms.ToTensor(),
403 | normalize,
404 | ])
405 |
406 | train_dataset = Places205(args.data, 'train', transform_train)
407 | valid_dataset = Places205(args.data, 'val', transform_valid)
408 | if args.distributed:
409 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
410 | val_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset, shuffle=False)
411 | # val_sampler = None
412 | else:
413 | train_sampler = None
414 | val_sampler = None
415 |
416 | train_loader = torch.utils.data.DataLoader(
417 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
418 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
419 |
420 | val_loader = torch.utils.data.DataLoader(
421 | valid_dataset, sampler=val_sampler,
422 | batch_size=args.batch_size,
423 | num_workers=args.workers, pin_memory=True)
424 |
425 | else:
426 |         print("dataset %s is not supported for fine-tuning yet" % args.dataset)
427 | exit()
428 |
429 | if args.evaluate:
430 | validate(val_loader, model, criterion, args)
431 | return
432 | import datetime
433 | today = datetime.date.today()
434 | formatted_today = today.strftime('%y%m%d')
435 | now = time.strftime("%H:%M:%S")
436 |
437 | save_path = os.path.join(args.save_path, args.log_path)
438 | log_path = os.path.join(save_path, 'Finetune_log')
439 | mkdir(log_path)
440 | log_path = os.path.join(log_path, formatted_today + now)
441 | mkdir(log_path)
442 | # model_path=os.path.join(log_path,'checkpoint.pth.tar')
443 | lr_scheduler = None
444 | if args.sgdr == 1:
445 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 12)
446 | elif args.sgdr == 2:
447 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, args.sgdr_t0, args.sgdr_t_mult)
448 | for epoch in range(args.start_epoch, args.epochs):
449 | if args.distributed:
450 | train_sampler.set_epoch(epoch)
451 | if args.sgdr == 0:
452 | adjust_learning_rate(optimizer, epoch, args)
453 | train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler)
454 | # evaluate on validation set
455 | acc1 = validate(val_loader, model, criterion, args)
456 | # remember best acc@1 and save checkpoint
457 | is_best = acc1 > best_acc1
458 | best_acc1 = max(acc1, best_acc1)
459 |
460 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
461 | and args.rank % ngpus_per_node == 0):
462 | # add timestamp
463 | tmp_save_path = os.path.join(log_path, 'checkpoint.pth.tar')
464 | save_checkpoint({
465 | 'epoch': epoch + 1,
466 | 'arch': args.arch,
467 | 'state_dict': model.state_dict(),
468 | 'best_acc1': best_acc1,
469 | 'optimizer': optimizer.state_dict(),
470 | }, is_best, filename=tmp_save_path)
471 |
472 | if abs(args.epochs - epoch) <= 20:
473 | tmp_save_path = os.path.join(log_path, 'model_%d.pth.tar' % epoch)
474 | save_checkpoint({
475 | 'epoch': epoch + 1,
476 | 'arch': args.arch,
477 | 'state_dict': model.state_dict(),
478 | 'best_acc1': best_acc1,
479 | 'optimizer': optimizer.state_dict(),
480 | }, False, filename=tmp_save_path)
481 |
482 |
483 |
484 | def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler):
485 | batch_time = AverageMeter('Time', ':6.3f')
486 | data_time = AverageMeter('Data', ':6.3f')
487 | losses = AverageMeter('Loss', ':.4e')
488 | top1 = AverageMeter('Acc@1', ':6.2f')
489 | top5 = AverageMeter('Acc@5', ':6.2f')
490 | mAP = AverageMeter("mAP", ":6.2f")
491 | progress = ProgressMeter(
492 | len(train_loader),
493 | [batch_time, data_time, losses, top1, top5, mAP],
494 | prefix="Epoch: [{}]".format(epoch))
495 |
496 | """
497 | Switch to eval mode:
498 | Under the protocol of linear classification on frozen features/models,
499 | it is not legitimate to change any part of the pre-trained model.
500 | BatchNorm in train mode may revise running mean/std (even if it receives
501 | no gradient), which are part of the model parameters too.
502 | """
503 | model.eval()
504 | batch_total = len(train_loader)
505 | end = time.time()
506 | for i, (images, target) in enumerate(train_loader):
507 | # measure data loading time
508 | data_time.update(time.time() - end)
509 | # adjust_batch_learning_rate(optimizer, epoch, i, batch_total, args)
510 |
511 | if args.gpu is not None:
512 | images = images.cuda(args.gpu, non_blocking=True)
513 |
514 | target = target.cuda(args.gpu, non_blocking=True)
515 |
516 | # compute output
517 | output = model(images)
518 | loss = criterion(output, target)
519 |
520 | # measure accuracy and record loss
521 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
522 | losses.update(loss.item(), images.size(0))
523 | top1.update(acc1.item(), images.size(0))
524 | top5.update(acc5.item(), images.size(0))
525 |
526 | # compute gradient and do SGD step
527 | optimizer.zero_grad()
528 | loss.backward()
529 | optimizer.step()
530 |
531 | if args.sgdr != 0:
532 | lr_scheduler.step(epoch + i / batch_total)
533 |
534 | # measure elapsed time
535 | batch_time.update(time.time() - end)
536 | end = time.time()
537 |
538 | if i % args.print_freq == 0:
539 | progress.display(i)
540 |
541 |
542 | def train2(train_loader, model, criterion, optimizer, epoch, args):
543 | batch_time = AverageMeter('Time', ':6.3f')
544 | data_time = AverageMeter('Data', ':6.3f')
545 | losses = AverageMeter('Loss', ':.4e')
546 | top1 = AverageMeter('Acc@1', ':6.2f')
547 | top5 = AverageMeter('Acc@5', ':6.2f')
548 | mAP = AverageMeter("mAP", ":6.2f")
549 | progress = ProgressMeter(
550 | len(train_loader),
551 | [batch_time, data_time, losses, top1, top5, mAP],
552 | prefix="Epoch: [{}]".format(epoch))
553 |
554 | """
555 | Switch to eval mode:
556 | Under the protocol of linear classification on frozen features/models,
557 | it is not legitimate to change any part of the pre-trained model.
558 | BatchNorm in train mode may revise running mean/std (even if it receives
559 | no gradient), which are part of the model parameters too.
560 | """
561 | model.eval()
562 |
563 | end = time.time()
564 | for i, (images, target) in enumerate(train_loader):
565 | # measure data loading time
566 | data_time.update(time.time() - end)
567 |
568 | if args.gpu is not None:
569 | len_images = len(images)
570 | for k in range(len(images)):
571 | images[k] = images[k].cuda(args.gpu, non_blocking=True)
572 |
573 | target = target.cuda(args.gpu, non_blocking=True)
574 | len_images = len(images)
575 |
576 | first_output = -1
577 | for k in range(len_images):
578 | # compute gradient and do SGD step
579 | optimizer.zero_grad()
580 | output = model(images[k])
581 | loss = criterion(output, target)
582 | loss.backward()
583 | optimizer.step()
584 | losses.update(loss.item(), images[k].size(0))
585 | if k == 0:
586 | first_output = output
587 |
588 | images = images[0]
589 | output = first_output
590 |
591 | # measure accuracy and record loss
592 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
593 |
594 | top1.update(acc1.item(), images.size(0))
595 | top5.update(acc5.item(), images.size(0))
596 |
597 | # measure elapsed time
598 | batch_time.update(time.time() - end)
599 | end = time.time()
600 |
601 | if i % args.print_freq == 0:
602 | progress.display(i)
603 |
604 |
605 | def validate(val_loader, model, criterion, args):
606 | batch_time = AverageMeter('Time', ':6.3f')
607 | losses = AverageMeter('Loss', ':.4e')
608 | top1 = AverageMeter('Acc@1', ':6.2f')
609 | top5 = AverageMeter('Acc@5', ':6.2f')
610 | mAP = AverageMeter("mAP", ":6.2f")
611 | progress = ProgressMeter(
612 | len(val_loader),
613 | [batch_time, losses, top1, top5, mAP],
614 | prefix='Test: ')
615 |
616 | # switch to evaluate mode
617 | model.eval()
618 | with torch.no_grad():
619 | end = time.time()
620 | for i, (images, target) in enumerate(val_loader):
621 | target = target.cuda(args.gpu, non_blocking=True)
622 | output = model(images)
623 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
624 | acc1 = torch.mean(concat_all_gather(acc1.unsqueeze(0)), dim=0, keepdim=True)
625 | acc5 = torch.mean(concat_all_gather(acc5.unsqueeze(0)), dim=0, keepdim=True)
626 | top1.update(acc1.item(), images.size(0))
627 | top5.update(acc5.item(), images.size(0))
628 | loss = criterion(output, target)
629 | losses.update(loss.item(), images.size(0))
630 | # measure elapsed time
631 | batch_time.update(time.time() - end)
632 | end = time.time()
633 |
634 | if i % args.print_freq == 0:
635 | progress.display(i)
636 |
637 | # TODO: this should also be done with the ProgressMeter
638 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} mAP {mAP.avg:.3f} '
639 | .format(top1=top1, top5=top5, mAP=mAP))
640 |
641 | return top1.avg
642 |
643 |
644 | def testing(val_loader, model, criterion, args):
645 | batch_time = AverageMeter('Time', ':6.3f')
646 | losses = AverageMeter('Loss', ':.4e')
647 | top1 = AverageMeter('Acc@1', ':6.2f')
648 | top5 = AverageMeter('Acc@5', ':6.2f')
649 | mAP = AverageMeter("mAP", ":6.2f")
650 | progress = ProgressMeter(
651 | len(val_loader),
652 | [batch_time, losses, top1, top5, mAP],
653 | prefix='Test: ')
654 |
655 | # switch to evaluate mode
656 | model.eval()
657 | correct_count = 0
658 | count_all = 0
659 | # implement our own random crop
660 | with torch.no_grad():
661 | end = time.time()
662 | for i, (images, target) in enumerate(val_loader):
663 | target = target.cuda(args.gpu, non_blocking=True)
664 | output_list = []
665 | for image in images:
666 | output = model(image)
667 | output = torch.softmax(output, dim=1)
668 | output_list.append(output)
669 | output_list = torch.stack(output_list, dim=0)
670 | output_list, max_index = torch.max(output_list, dim=0)
671 | output = output_list
672 | images = images[0]
673 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
674 | acc1 = torch.mean(concat_all_gather(acc1.unsqueeze(0)), dim=0, keepdim=True)
675 | acc5 = torch.mean(concat_all_gather(acc5.unsqueeze(0)), dim=0, keepdim=True)
676 | correct_count += float(acc1[0]) * images.size(0)
677 | count_all += images.size(0)
678 | top1.update(acc1.item(), images.size(0))
679 | top5.update(acc5.item(), images.size(0))
680 | loss = criterion(output, target)
681 | losses.update(loss.item(), images.size(0))
682 | # measure elapsed time
683 | batch_time.update(time.time() - end)
684 | end = time.time()
685 |
686 | if i % args.print_freq == 0:
687 | progress.display(i)
688 |
689 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} mAP {mAP.avg:.3f} '
690 | .format(top1=top1, top5=top5, mAP=mAP))
691 | final_accu = correct_count / count_all
692 | print("$$our final calculated accuracy %.7f" % final_accu)
693 | return top1.avg
694 |
695 |
696 | def testing2(val_loader, model, criterion, args):
697 | batch_time = AverageMeter('Time', ':6.3f')
698 | losses = AverageMeter('Loss', ':.4e')
699 | top1 = AverageMeter('Acc@1', ':6.2f')
700 | top5 = AverageMeter('Acc@5', ':6.2f')
701 | mAP = AverageMeter("mAP", ":6.2f")
702 | progress = ProgressMeter(
703 | len(val_loader),
704 | [batch_time, losses, top1, top5, mAP],
705 | prefix='Test: ')
706 |
707 | # switch to evaluate mode
708 | model.eval()
709 | correct_count = 0
710 | count_all = 0
711 | # implement our own random crop
712 | with torch.no_grad():
713 | end = time.time()
714 | for i, (images, target) in enumerate(val_loader):
715 | target = target.cuda(args.gpu, non_blocking=True)
716 | output_list = []
717 | for image in images:
718 | output = model(image)
719 | output = torch.softmax(output, dim=1)
720 | output_list.append(output)
721 | output_list = torch.stack(output_list, dim=0)
722 | output_list = torch.mean(output_list, dim=0)
723 | output = output_list
724 | images = images[0]
725 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
726 | acc1 = torch.mean(concat_all_gather(acc1), dim=0, keepdim=True)
727 | acc5 = torch.mean(concat_all_gather(acc5), dim=0, keepdim=True)
728 | correct_count += float(acc1[0]) * images.size(0)
729 | count_all += images.size(0)
730 | top1.update(acc1.item(), images.size(0))
731 | top5.update(acc5.item(), images.size(0))
732 | loss = criterion(output, target)
733 | losses.update(loss.item(), images.size(0))
734 | # measure elapsed time
735 | batch_time.update(time.time() - end)
736 | end = time.time()
737 |
738 | if i % args.print_freq == 0:
739 | progress.display(i)
740 |
741 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} mAP {mAP.avg:.3f} '
742 | .format(top1=top1, top5=top5, mAP=mAP))
743 | final_accu = correct_count / count_all
744 | print("$$our final average accuracy %.7f" % final_accu)
745 | return top1.avg
746 |
747 |
748 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
749 | torch.save(state, filename)
750 | if is_best:
751 | root_path = os.path.split(filename)[0]
752 | best_path = os.path.join(root_path, "model_best.pth.tar")
753 | shutil.copyfile(filename, best_path)
754 |
755 |
756 | def sanity_check(state_dict, pretrained_weights):
757 | """
758 | Linear classifier should not change any weights other than the linear layer.
759 | This sanity check asserts nothing wrong happens (e.g., BN stats updated).
760 | """
761 | print("=> loading '{}' for sanity check".format(pretrained_weights))
762 | checkpoint = torch.load(pretrained_weights, map_location="cpu")
763 | state_dict_pre = checkpoint['state_dict']
764 |
765 | for k in list(state_dict.keys()):
766 | # only ignore fc layer
767 | if 'fc.weight' in k or 'fc.bias' in k:
768 | continue
769 |
770 | # name in pretrained model
771 | k_pre = 'module.encoder_q.' + k[len('module.'):] \
772 | if k.startswith('module.') else 'module.encoder_q.' + k
773 |
774 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
775 | '{} is changed in linear classifier training.'.format(k)
776 |
777 | print("=> sanity check passed.")
778 |
779 |
780 | class AverageMeter(object):
781 | """Computes and stores the average and current value"""
782 |
783 | def __init__(self, name, fmt=':f'):
784 | self.name = name
785 | self.fmt = fmt
786 | self.reset()
787 |
788 | def reset(self):
789 | self.val = 0
790 | self.avg = 0
791 | self.sum = 0
792 | self.count = 0
793 |
794 | def update(self, val, n=1):
795 | self.val = val
796 | self.sum += val * n
797 | self.count += n
798 | self.avg = self.sum / self.count
799 |
800 | def __str__(self):
801 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
802 | return fmtstr.format(**self.__dict__)
803 |
804 |
805 | class ProgressMeter(object):
806 | def __init__(self, num_batches, meters, prefix=""):
807 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
808 | self.meters = meters
809 | self.prefix = prefix
810 |
811 | def display(self, batch):
812 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
813 | entries += [str(meter) for meter in self.meters]
814 | print('\t'.join(entries))
815 |
816 | def _get_batch_fmtstr(self, num_batches):
817 |         num_digits = len(str(num_batches))
818 | fmt = '{:' + str(num_digits) + 'd}'
819 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
820 |
821 |
822 | import math
823 |
824 |
825 | def adjust_learning_rate(optimizer, epoch, args):
826 | """Decay the learning rate based on schedule"""
827 | lr = args.lr
828 | end_lr = args.final_lr
829 | # update on cos scheduler
830 | # this scheduler is not proper enough
831 | if args.cos:
832 | lr = 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (lr - end_lr) + end_lr
833 | else:
834 | for milestone in args.schedule:
835 | lr *= 0.1 if epoch >= milestone else 1.
836 | for param_group in optimizer.param_groups:
837 | param_group['lr'] = lr
838 |
839 |
840 | def adjust_batch_learning_rate(optimizer, cur_epoch, cur_batch, batch_total, args):
841 | """Decay the learning rate based on schedule"""
842 | init_lr = args.lr
843 | # end_lr=args.final_lr
844 | # update on cos scheduler
845 | # this scheduler is not proper enough
846 |     current_schedule = 0
847 |     # use_epoch=cur_epoch
848 |     last_milestone = 0
849 |     for milestone in args.schedule:
850 |         if cur_epoch > milestone:
851 |             current_schedule += 1
852 |             init_lr *= 0.1
853 |             last_milestone = milestone
854 |         else:
855 |             cur_epoch -= last_milestone
856 |             break
857 |     if current_schedule < len(args.schedule):
858 |         all_epochs = args.schedule[current_schedule]
859 | else:
860 | all_epochs = args.epochs
861 | end_lr = init_lr * 0.1
862 | lr = math.cos(
863 | 0.5 * math.pi * (cur_batch + cur_epoch * batch_total) / ((all_epochs - last_milestone) * batch_total)) * (
864 | init_lr - end_lr) + end_lr
865 | if cur_batch % 50 == 0:
866 |         print("[%d] %d/%d learning rate %.9f" % (cur_epoch, cur_batch, batch_total, lr))
867 | for param_group in optimizer.param_groups:
868 | param_group['lr'] = lr
869 |
870 |
871 |
872 |
873 |
874 | # utils
875 | @torch.no_grad()
876 | def concat_all_gather(tensor):
877 | """
878 | Performs all_gather operation on the provided tensors.
879 | *** Warning ***: torch.distributed.all_gather has no gradient.
880 | """
881 | tensors_gather = [torch.ones_like(tensor)
882 | for _ in range(torch.distributed.get_world_size())]
883 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
884 |
885 | output = torch.cat(tensors_gather, dim=0)
886 | return output
887 |
888 |
889 | if __name__ == '__main__':
890 | main()
891 |
--------------------------------------------------------------------------------
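When --cos is set, adjust_learning_rate above anneals from --lr down to --final_lr along a half cosine; otherwise it decays by 10x at each --schedule milestone. A standalone sketch of the cosine branch, using the script's defaults (--lr 10.0, --final_lr 0.01):

import math

def cosine_lr(epoch, epochs, lr=10.0, end_lr=0.01):
    # mirrors the cosine branch of adjust_learning_rate in lincls.py
    return 0.5 * (1. + math.cos(math.pi * epoch / epochs)) * (lr - end_lr) + end_lr

for epoch in (0, 25, 50, 75, 99):
    print(epoch, round(cosine_lr(epoch, 100), 4))
# 10.0 at epoch 0, ~5.0 halfway, approaching 0.01 near the end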
/main_clsa.py:
--------------------------------------------------------------------------------
1 | #Copyright (C) 2020 Xiao Wang
2 | #License: MIT for academic use.
3 | #Contact: Xiao Wang (wang3702@purdue.edu, xiaowang20140001@gmail.com)
4 |
5 | #Some codes adopted from https://github.com/facebookresearch/moco
6 |
7 | from ops.argparser import argparser
8 | from ops.Config_Envrionment import Config_Environment
9 | import torch.multiprocessing as mp
10 | from training.main_worker import main_worker
11 | def main(args):
12 |     # config environment
13 |     ngpus_per_node = Config_Environment(args)
14 |
15 |     # call training main control function
16 |     if args.multiprocessing_distributed == 1:
17 |         # Since we have ngpus_per_node processes per node, the total world_size
18 |         # needs to be adjusted accordingly
19 |         args.world_size = ngpus_per_node * args.world_size
20 |         # Use torch.multiprocessing.spawn to launch distributed processes: the
21 |         # main_worker process function
22 |         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
23 |     else:
24 |         # Simply call main_worker function
25 |         main_worker(args.gpu, ngpus_per_node, args)
26 |
27 |
28 | if __name__ == '__main__':
29 | #use_cuda = torch.cuda.is_available()
30 | #print("starting check cuda status",use_cuda)
31 | #if use_cuda:
32 |     args, params = argparser()
33 | main(args)
--------------------------------------------------------------------------------
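main above multiplies world_size by ngpus_per_node and spawns one worker per GPU; inside each worker the global rank is then derived from the node rank (see main_worker in lincls.py for the same pattern). A pure-Python sketch of that arithmetic, no GPUs required:

def global_rank(node_rank, ngpus_per_node, local_gpu):
    # matches args.rank = args.rank * ngpus_per_node + gpu in main_worker
    return node_rank * ngpus_per_node + local_gpu

# 2 nodes x 4 GPUs -> world_size 8, global ranks 0..7
for node in range(2):
    print([global_rank(node, 4, g) for g in range(4)])
# [0, 1, 2, 3]
# [4, 5, 6, 7]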
/model/CLSA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class CLSA(nn.Module):
5 |
6 | def __init__(self, base_encoder, args, dim=128, K=65536, m=0.999, T=0.2, mlp=True):
7 | """
8 | :param base_encoder: encoder model
9 | :param args: config parameters
10 | :param dim: feature dimension (default: 128)
11 | :param K: queue size; number of negative keys (default: 65536)
12 | :param m: momentum of updating key encoder (default: 0.999)
13 | :param T: softmax temperature (default: 0.2)
14 | :param mlp: use MLP layer to process encoder output or not (default: True)
15 | """
16 | super(CLSA, self).__init__()
17 | self.args = args
18 | self.K = K
19 | self.m = m
20 | self.T = T
21 | self.T2 = self.args.clsa_t
22 |
23 | # create the encoders
24 | # num_classes is the output fc dimension
25 | self.encoder_q = base_encoder(num_classes=dim)
26 | self.encoder_k = base_encoder(num_classes=dim)
27 |
28 | if mlp: # hack: brute-force replacement
29 | dim_mlp = self.encoder_q.fc.weight.shape[1]
30 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
31 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)
32 |
33 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
34 | param_k.data.copy_(param_q.data) # initialize
35 | param_k.requires_grad = False # not update by gradient
36 | self.register_buffer("queue", torch.randn(dim, K))
37 |         self.queue = nn.functional.normalize(self.queue, dim=0)  # normalize each key embedding along the feature dimension
38 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
39 | # config parameters for CLSA stronger augmentation and multi-crop
40 | self.weak_pick = args.pick_weak
41 | self.strong_pick = args.pick_strong
42 | self.weak_pick = set(self.weak_pick)
43 | self.strong_pick = set(self.strong_pick)
44 | self.gpu = args.gpu
45 | self.sym = self.args.sym
46 |
47 | @torch.no_grad()
48 | def _momentum_update_key_encoder(self):
49 | """
50 | Momentum update of the key encoder
51 | """
52 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
53 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
54 |
55 | @torch.no_grad()
56 | def _dequeue_and_enqueue(self, queue, queue_ptr, keys):
57 | # gather keys before updating queue
58 | #keys = concat_all_gather(keys) #already concatenated before
59 |
60 | batch_size = keys.shape[0]
61 |
62 | ptr = int(queue_ptr)
63 | assert self.K % batch_size == 0 # for simplicity
64 |
65 | # replace the keys at ptr (dequeue and enqueue)
66 | queue[:, ptr:ptr + batch_size] = keys.T
67 | ptr = (ptr + batch_size) % self.K # move pointer
68 |
69 | queue_ptr[0] = ptr
70 |
71 | @torch.no_grad()
72 | def _batch_shuffle_ddp(self, x):
73 | """
74 | Batch shuffle, for making use of BatchNorm.
75 | *** Only support DistributedDataParallel (DDP) model. ***
76 | """
77 | # gather from all gpus
78 | batch_size_this = x.shape[0]
79 | x_gather = concat_all_gather(x)
80 | batch_size_all = x_gather.shape[0]
81 |
82 | num_gpus = batch_size_all // batch_size_this
83 |
84 | # random shuffle index
85 | idx_shuffle = torch.randperm(batch_size_all).cuda()
86 |
87 | # broadcast to all gpus
88 | torch.distributed.broadcast(idx_shuffle, src=0)
89 |
90 | # index for restoring
91 | idx_unshuffle = torch.argsort(idx_shuffle)
92 |
93 | # shuffled index for this gpu
94 | gpu_idx = torch.distributed.get_rank()
95 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
96 |
97 | return x_gather[idx_this], idx_unshuffle
98 |
99 | @torch.no_grad()
100 | def _batch_unshuffle_ddp(self, x, idx_unshuffle):
101 | """
102 | Undo batch shuffle.
103 | *** Only support DistributedDataParallel (DDP) model. ***
104 | """
105 | # gather from all gpus
106 | batch_size_this = x.shape[0]
107 | x_gather = concat_all_gather(x)
108 | batch_size_all = x_gather.shape[0]
109 |
110 | num_gpus = batch_size_all // batch_size_this
111 |
112 | # restored index for this gpu
113 | gpu_idx = torch.distributed.get_rank()
114 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
115 |
116 | return x_gather[idx_this]
117 |     def forward(self, im_q_list, im_k, im_strong_list):
118 | """
119 | :param im_q_list: query image list
120 | :param im_k: key image
121 | :param im_strong_list: query strong image list
122 | :return:
123 | weak: logit_list, label_list
124 | strong: logit_list, label_list
125 | """
126 | if self.sym:
127 | q_list = []
128 | for k, im_q in enumerate(im_q_list): # weak forward
129 | if k not in self.weak_pick:
130 | continue
131 |                 # cannot shuffle the query branch: batch shuffling stops gradients, so it is only applied to the key encoder
132 | # im_q, idx_unshuffle = self._batch_shuffle_ddp(im_q)
133 | q = self.encoder_q(im_q) # queries: NxC
134 | q = nn.functional.normalize(q, dim=1)
135 | # q = self._batch_unshuffle_ddp(q, idx_unshuffle)
136 | q_list.append(q)
137 | # add the encoding of im_k as one of weakly supervised
138 | q = self.encoder_q(im_k)
139 | q = nn.functional.normalize(q, dim=1)
140 | q_list.append(q)
141 |
142 | q_strong_list = []
143 | for k, im_strong in enumerate(im_strong_list):
144 | # im_strong, idx_unshuffle = self._batch_shuffle_ddp(im_strong)
145 | if k not in self.strong_pick:
146 | continue
147 | q_strong = self.encoder_q(im_strong) # queries: NxC
148 | q_strong = nn.functional.normalize(q_strong, dim=1)
149 | # q_strong = self._batch_unshuffle_ddp(q_strong, idx_unshuffle)
150 | q_strong_list.append(q_strong)
151 | with torch.no_grad(): # no gradient to keys
152 | # if update_key_encoder:
153 | self._momentum_update_key_encoder() # update the key encoder
154 |
155 | # shuffle for making use of BN
156 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
157 |
158 | k = self.encoder_k(im_k) # keys: NxC
159 | k = nn.functional.normalize(k, dim=1)
160 | # undo shuffle
161 | k = self._batch_unshuffle_ddp(k, idx_unshuffle)
162 | k = k.detach()
163 | k = concat_all_gather(k)
164 |
165 | k2 = self.encoder_k(im_q_list[0]) # keys: NxC
166 | k2 = nn.functional.normalize(k2, dim=1)
167 | # undo shuffle
168 | k2 = self._batch_unshuffle_ddp(k2, idx_unshuffle)
169 | k2 = k2.detach()
170 | k2 = concat_all_gather(k2)
171 | logits0_list = []
172 | labels0_list = []
173 | logits1_list = []
174 | labels1_list = []
175 |             # first pass: keys from im_k serve as supervision
176 | for choose_idx in range(len(q_list) - 1):
177 | q = q_list[choose_idx]
178 | # positive logits: NxN
179 | l_pos = torch.einsum('nc,ck->nk', [q, k.T])
180 | # negative logits: NxK
181 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
182 | # logits: Nx(1+K)
183 | logits = torch.cat([l_pos, l_neg], dim=1)
184 |
185 | # apply temperature
186 | logits /= self.T
187 |
188 | # labels: positive key indicators
189 |
190 | cur_batch_size = logits.shape[0]
191 | cur_gpu = self.gpu
192 | choose_match = cur_gpu * cur_batch_size
193 | labels = torch.arange(choose_match, choose_match + cur_batch_size, dtype=torch.long).cuda()
194 |
195 | logits0_list.append(logits)
196 | labels0_list.append(labels)
197 |
198 |                 labels0 = logits.clone().detach()  # weak-view logits serve as the DDM supervision target
199 | labels0 = labels0 * self.T / self.T2
200 | labels0 = torch.softmax(labels0, dim=1)
201 | labels0 = labels0.detach()
202 | for choose_idx2 in range(len(q_strong_list)):
203 | q_strong = q_strong_list[choose_idx2]
204 | # weak strong loss
205 |
206 | l_pos = torch.einsum('nc,ck->nk', [q_strong, k.T])
207 | # negative logits: NxK
208 | l_neg = torch.einsum('nc,ck->nk', [q_strong, self.queue.clone().detach()])
209 |
210 | # logits: Nx(1+K)
211 | logits0 = torch.cat([l_pos, l_neg], dim=1) # N*(K+1)
212 |
213 | # apply temperature
214 | logits0 /= self.T2
215 | logits0 = torch.softmax(logits0, dim=1)
216 |
217 | logits1_list.append(logits0)
218 | labels1_list.append(labels0)
219 |             # symmetric pass: keys from the first weak view serve as supervision
220 | k = k2
221 | for choose_idx in range(1, len(q_list)):
222 | q = q_list[choose_idx]
223 | # positive logits: NxN
224 | l_pos = torch.einsum('nc,ck->nk', [q, k.T])
225 | # negative logits: NxK
226 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
227 | # logits: Nx(1+K)
228 | logits = torch.cat([l_pos, l_neg], dim=1)
229 |
230 | # apply temperature
231 | logits /= self.T
232 |
233 | # labels: positive key indicators
234 |
235 | cur_batch_size = logits.shape[0]
236 | cur_gpu = self.gpu
237 | choose_match = cur_gpu * cur_batch_size
238 | labels = torch.arange(choose_match, choose_match + cur_batch_size, dtype=torch.long).cuda()
239 |
240 | logits0_list.append(logits)
241 | labels0_list.append(labels)
242 |
243 |                 labels0 = logits.clone().detach()  # weak-view logits serve as the DDM supervision target
244 | labels0 = labels0 * self.T / self.T2
245 | labels0 = torch.softmax(labels0, dim=1)
246 | labels0 = labels0.detach()
247 | for choose_idx2 in range(len(q_strong_list)):
248 | q_strong = q_strong_list[choose_idx2]
249 | # weak strong loss
250 |
251 | l_pos = torch.einsum('nc,ck->nk', [q_strong, k.T])
252 | # negative logits: NxK
253 | l_neg = torch.einsum('nc,ck->nk', [q_strong, self.queue.clone().detach()])
254 |
255 | # logits: Nx(1+K)
256 | logits0 = torch.cat([l_pos, l_neg], dim=1) # N*(K+1)
257 |
258 | # apply temperature
259 | logits0 /= self.T2
260 | logits0 = torch.softmax(logits0, dim=1)
261 |
262 | logits1_list.append(logits0)
263 | labels1_list.append(labels0)
264 |
265 | # dequeue and enqueue
266 | # if update_key_encoder==False:
267 | self._dequeue_and_enqueue(self.queue, self.queue_ptr, k)
268 |
269 | return logits0_list, labels0_list, logits1_list, labels1_list
270 | else:
271 | q_list = []
272 | for k, im_q in enumerate(im_q_list): # weak forward
273 | if k not in self.weak_pick:
274 | continue
275 |                 # cannot shuffle the query branch: batch shuffling stops gradients, so it is only applied to the key encoder
276 | # im_q, idx_unshuffle = self._batch_shuffle_ddp(im_q)
277 | q = self.encoder_q(im_q) # queries: NxC
278 | q = nn.functional.normalize(q, dim=1)
279 | # q = self._batch_unshuffle_ddp(q, idx_unshuffle)
280 | q_list.append(q)
281 |
282 | q_strong_list = []
283 | for k, im_strong in enumerate(im_strong_list):
284 | # im_strong, idx_unshuffle = self._batch_shuffle_ddp(im_strong)
285 | if k not in self.strong_pick:
286 | continue
287 | q_strong = self.encoder_q(im_strong) # queries: NxC
288 | q_strong = nn.functional.normalize(q_strong, dim=1)
289 | # q_strong = self._batch_unshuffle_ddp(q_strong, idx_unshuffle)
290 | q_strong_list.append(q_strong)
291 |
292 | # compute key features
293 | with torch.no_grad(): # no gradient to keys
294 | # if update_key_encoder:
295 | self._momentum_update_key_encoder() # update the key encoder
296 |
297 | # shuffle for making use of BN
298 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
299 |
300 | k = self.encoder_k(im_k) # keys: NxC
301 | k = nn.functional.normalize(k, dim=1)
302 |
303 | # undo shuffle
304 | k = self._batch_unshuffle_ddp(k, idx_unshuffle)
305 | k = k.detach()
306 | k = concat_all_gather(k)
307 |
308 | # compute logits
309 | # Einstein sum is more intuitive
310 |
311 | logits0_list = []
312 | labels0_list = []
313 | logits1_list = []
314 | labels1_list = []
315 | for choose_idx in range(len(q_list)):
316 | q = q_list[choose_idx]
317 |
318 |                 # positive logits: NxN (queries against all gathered keys)
319 | l_pos = torch.einsum('nc,ck->nk', [q, k.T])
320 | # negative logits: NxK
321 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
322 |
323 | # logits: Nx(1+K)
324 | logits = torch.cat([l_pos, l_neg], dim=1)
325 |
326 | # apply temperature
327 | logits /= self.T
328 |
329 | # labels: positive key indicators
330 | cur_batch_size = logits.shape[0]
331 | cur_gpu = self.gpu
332 | choose_match = cur_gpu * cur_batch_size
333 | labels = torch.arange(choose_match, choose_match + cur_batch_size, dtype=torch.long).cuda()
334 |
335 | logits0_list.append(logits)
336 | labels0_list.append(labels)
337 |
338 |                 labels0 = logits.clone().detach()  # weak-view logits serve as the DDM supervision target
339 |                 labels0 = labels0 * self.T / self.T2
340 | labels0 = torch.softmax(labels0, dim=1)
341 | labels0 = labels0.detach()
342 | for choose_idx2 in range(len(q_strong_list)):
343 | q_strong = q_strong_list[choose_idx2]
344 | # weak strong loss
345 |
346 | l_pos = torch.einsum('nc,ck->nk', [q_strong, k.T])
347 | # negative logits: NxK
348 | l_neg = torch.einsum('nc,ck->nk', [q_strong, self.queue.clone().detach()])
349 |
350 | # logits: Nx(1+K)
351 | logits0 = torch.cat([l_pos, l_neg], dim=1) # N*(K+1)
352 |
353 | # apply temperature
354 | logits0 /= self.T2
355 | logits0 = torch.softmax(logits0, dim=1)
356 |
357 | logits1_list.append(logits0)
358 | labels1_list.append(labels0)
359 |
360 | # dequeue and enqueue
361 | # if update_key_encoder==False:
362 | self._dequeue_and_enqueue(self.queue, self.queue_ptr, k)
363 |
364 | return logits0_list, labels0_list, logits1_list, labels1_list
365 |
366 |
367 |
368 |
369 |
370 | @torch.no_grad()
371 | def concat_all_gather(tensor):
372 | """
373 | Performs all_gather operation on the provided tensors.
374 | *** Warning ***: torch.distributed.all_gather has no gradient.
375 | """
376 | tensors_gather = [torch.ones_like(tensor)
377 | for _ in range(torch.distributed.get_world_size())]
378 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
379 |
380 | output = torch.cat(tensors_gather, dim=0)
381 | return output
382 |
--------------------------------------------------------------------------------
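forward above returns the strong-view distributions (logits1_list, already softmaxed at temperature T2) paired with weak-view targets (labels1_list). The training loop is outside this section, but the DDM objective the README describes amounts to a cross-entropy between these two distributions; a minimal sketch under that assumption, with toy shapes:

import torch

N, K1 = 4, 9  # batch size, and positives + queue size (toy numbers)
strong = torch.softmax(torch.randn(N, K1), dim=1)  # stands in for an entry of logits1_list
target = torch.softmax(torch.randn(N, K1), dim=1)  # stands in for the matching labels1_list entry

# distribution divergence between weak target and strong prediction
ddm_loss = -(target * torch.log(strong + 1e-8)).sum(dim=1).mean()
print(ddm_loss.item())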
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CLSA/37df76cf5cb032683e57b70a3a4090f0d524c8fd/model/__init__.py
--------------------------------------------------------------------------------
/ops/Config_Envrionment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import resource
3 | import torch
4 | import warnings
5 | import random
6 | import torch.backends.cudnn as cudnn
7 |
8 | def Config_Environment(args):
9 | # increase the limit of resources to make sure it can run under any conditions
10 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
11 | resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
12 |
13 | # config gpu settings
14 | choose = args.choose
15 | if choose is not None and args.nodes_num == 1:
16 | os.environ['CUDA_VISIBLE_DEVICES'] = choose
17 |         print("Currently chosen GPU(s): %s" % choose)
18 |     use_cuda = torch.cuda.is_available()
19 |     print("CUDA available:", use_cuda)
20 |     ngpus_per_node = torch.cuda.device_count()
21 |     print("in total we have", ngpus_per_node, "GPU(s)")
22 |     if ngpus_per_node <= 0:
23 |         print("No GPU available, exiting!")
24 | exit()
25 | if args.gpu is not None:
26 | warnings.warn('You have chosen a specific GPU. This will completely '
27 | 'disable data parallelism.')
28 |
29 | if args.dist_url == "env://" and args.world_size == -1:
30 | args.world_size = int(os.environ["WORLD_SIZE"])
31 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
32 |
33 | #init random seed
34 | if args.seed is not None:
35 | random.seed(args.seed)
36 | torch.manual_seed(args.seed)
37 | cudnn.deterministic = True
38 | warnings.warn('You have chosen to seed training. '
39 | 'This will turn on the CUDNN deterministic setting, '
40 | 'which can slow down your training considerably! '
41 | 'You may see unexpected behavior when restarting '
42 | 'from checkpoints.')
43 | return ngpus_per_node
--------------------------------------------------------------------------------
/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CLSA/37df76cf5cb032683e57b70a3a4090f0d524c8fd/ops/__init__.py
--------------------------------------------------------------------------------
/ops/argparser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def argparser():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--data', default="data", type=str, metavar='DIR',
7 | help='path to dataset')
8 | parser.add_argument('--log_path', type=str, default="train_log", help="log path for saving models and logs")
9 | parser.add_argument('--arch', metavar='ARCH', default='resnet50',
10 | type=str,
11 | help='model architecture: (default: resnet50)')
12 | parser.add_argument('--workers', default=32, type=int, metavar='N',
13 | help='number of data loading workers (default: 32)')
14 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
15 | help='number of total epochs to run')
16 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
17 | help='manual epoch number (useful on restarts)')
18 | parser.add_argument('-b', '--batch_size', default=256, type=int,
19 | metavar='N',
20 | help='mini-batch size (default: 256), this is the total '
21 | 'batch size of all GPUs on the current node when '
22 | 'using Data Parallel or Distributed Data Parallel')
23 | parser.add_argument('--lr', '--learning_rate', default=0.03, type=float,
24 | metavar='LR', help='initial learning rate', dest='lr')
25 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
26 | help='momentum of SGD solver')
27 | parser.add_argument('--weight_decay', default=1e-4, type=float,
28 | help='weight decay (default: 1e-4)')
29 | parser.add_argument('--print_freq', default=10, type=int,
30 | metavar='N', help='print frequency (default: 10)')
31 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
32 | help='path to latest checkpoint (default: none)')
33 | parser.add_argument('--world_size', default=-1, type=int,
34 | help='number of nodes for distributed training')
35 |     parser.add_argument('--rank', default=-1, type=int,
36 |                         help='node rank for distributed training (global rank runs from 0 to world_size-1)')
37 | parser.add_argument('--dist_url', default='tcp://localhost:10001', type=str,
38 | help='url used to set up distributed training')
39 | parser.add_argument('--dist_backend', default='nccl', type=str,
40 | help='distributed backend')
41 | parser.add_argument('--seed', default=None, type=int,
42 | help='seed for initializing training. ')
43 | parser.add_argument('--gpu', default=None, type=int,
44 | help='GPU id to use.')
45 | parser.add_argument('--multiprocessing_distributed', type=int, default=1,
46 | help='Use multi-processing distributed training to launch '
47 | 'N processes per node, which has N GPUs. This is the '
48 | 'fastest way to use PyTorch for either single node or '
49 | 'multi node data parallel training')
50 | parser.add_argument("--nodes_num", type=int, default=1, help="number of nodes to use")
51 | parser.add_argument('--dataset', type=str, default="ImageNet", help="Specify dataset: default: ImageNet")
52 |
53 | # Baseline: moco specific configs:
54 | parser.add_argument('--moco_dim', default=128, type=int,
55 | help='feature dimension (default: 128)')
56 | parser.add_argument('--moco_k', default=65536, type=int,
57 | help='queue size; number of negative keys (default: 65536)')
58 | parser.add_argument('--moco_m', default=0.999, type=float,
59 | help='moco momentum of updating key encoder (default: 0.999)')
60 | parser.add_argument('--moco_t', default=0.2, type=float,
61 | help='softmax temperature (default: 0.2)')
62 | parser.add_argument('--mlp', type=int, default=1,
63 | help='use mlp head')
64 | parser.add_argument('--cos', type=int, default=1,
65 | help='use cosine lr schedule')
66 | parser.add_argument('--choose', type=str, default=None,
67 | help="choose gpu for training, default:None(Use all available GPUs)")
68 |
69 | #clsa parameter configuration
70 | parser.add_argument('--alpha', type=float, default=1,
71 | help="coefficients for DDM loss")
72 | parser.add_argument('--aug_times', type=int, default=5,
73 | help="random augmentation times in strong augmentation")
74 | # idea from swav#adds crops for it
75 | parser.add_argument("--nmb_crops", type=int, default=[1, 1, 1, 1, 1], nargs="+",
76 |                     help="list of number of crops (example: [2, 6])")  # an entry of 0 disables that crop
77 | parser.add_argument("--size_crops", type=int, default=[224, 192, 160, 128, 96], nargs="+",
78 | help="crops resolutions (example: [224, 96])")
79 | parser.add_argument("--min_scale_crops", type=float, default=[0.2, 0.172, 0.143, 0.114, 0.086], nargs="+",
80 |                     help="min scale argument for RandomResizedCrop")
81 | parser.add_argument("--max_scale_crops", type=float, default=[1.0, 0.86, 0.715, 0.571, 0.429], nargs="+",
82 |                     help="max scale argument for RandomResizedCrop")
83 | parser.add_argument("--pick_strong", type=int, default=[0, 1, 2, 3, 4], nargs="+",
84 |                     help="indices of the strong augmentations to use")
85 | parser.add_argument("--pick_weak", type=int, default=[0, 1, 2, 3, 4], nargs="+",
86 |                     help="indices of the weak augmentations to use")
87 | parser.add_argument("--clsa_t", type=float, default=0.2, help="temperature used for ddm loss")
88 | parser.add_argument("--sym", type=int, default=0, help="apply symmetric loss (default: 0, i.e. False)")
89 | args = parser.parse_args()
90 | params = vars(args)
91 | return args,params
--------------------------------------------------------------------------------
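A note on the multi-crop arguments above: position i of `--size_crops`, `--min_scale_crops`, and `--max_scale_crops` configures one RandomResizedCrop view, and `--nmb_crops[i]` gives the number of copies of that view. Below is a minimal sketch of how the defaults map onto torchvision transforms (illustration only; the actual pipeline in data_processing/Multi_FixTransform.py layers the weak/strong augmentations on top):

```python
import torchvision.transforms as transforms

size_crops = [224, 192, 160, 128, 96]
nmb_crops = [1, 1, 1, 1, 1]
min_scale_crops = [0.2, 0.172, 0.143, 0.114, 0.086]
max_scale_crops = [1.0, 0.86, 0.715, 0.571, 0.429]

crops = []
for size, n, lo, hi in zip(size_crops, nmb_crops, min_scale_crops, max_scale_crops):
    # each position defines one resolution; take n copies of that view
    crops.extend([transforms.RandomResizedCrop(size, scale=(lo, hi))] * n)

# applying every transform in `crops` to one PIL image yields five views,
# from 224x224 down to 96x96
```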
/ops/os_operation.py:
--------------------------------------------------------------------------------
1 | # Publication: "Protein Docking Model Evaluation by Graph Neural Networks", Xiao Wang, Sean T Flannery and Daisuke Kihara, (2020)
2 |
3 | #GNN-Dove is a computational tool using graph neural network that can evaluate the quality of docking protein-complexes.
4 |
5 | #Copyright (C) 2020 Xiao Wang, Sean T Flannery, Daisuke Kihara, and Purdue University.
6 |
7 | #License: GPL v3 for academic use. (For commercial use, please contact us for different licensing.)
8 |
9 | #Contact: Daisuke Kihara (dkihara@purdue.edu)
10 |
11 | #
12 |
13 | # This program is free software: you can redistribute it and/or modify
14 |
15 | # it under the terms of the GNU General Public License as published by
16 |
17 | # the Free Software Foundation, version 3.
18 |
19 | #
20 |
21 | # This program is distributed in the hope that it will be useful,
22 |
23 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
24 |
25 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26 |
27 | # GNU General Public License V3 for more details.
28 |
29 | #
30 |
31 | # You should have received a copy of the GNU v3.0 General Public License
32 |
33 | # along with this program. If not, see https://www.gnu.org/licenses/gpl-3.0.en.html.
34 |
35 | import os
36 | def mkdir(path):
37 |     path = path.strip()
38 |     path = path.rstrip("\\")
39 |     isExists = os.path.exists(path)
40 |     if not isExists:
41 |         print(path + " created")
42 |         os.makedirs(path)
43 |         return True
44 |     else:
45 |         print(path + " already exists")
46 |         return False
47 | def execCmd(cmd):
48 | r = os.popen(cmd)
49 | text = r.read()
50 | r.close()
51 | return text
--------------------------------------------------------------------------------
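`execCmd` above relies on os.popen, which silently discards the command's exit status and stderr. For reference, a sketch of a subprocess-based equivalent (assumes Python 3.7+ for `capture_output`; the repo itself ships the os.popen version):

```python
import subprocess

def exec_cmd(cmd):
    # run through the shell like os.popen does, but surface failures
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError("command failed (%d): %s\n%s"
                           % (result.returncode, cmd, result.stderr))
    return result.stdout
```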
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1
2 | torchvision==0.8.2
3 | numpy==1.19.5
4 | Pillow==5.1.0
5 | tensorboard==1.14.0
6 | tensorboardX==1.7
7 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CLSA/37df76cf5cb032683e57b70a3a4090f0d524c8fd/training/__init__.py
--------------------------------------------------------------------------------
/training/main_worker.py:
--------------------------------------------------------------------------------
1 | import builtins
2 | import torch.distributed as dist
3 | import os
4 | import torchvision.models as models
5 | import torch
6 | import torch.nn as nn
7 | import torch.backends.cudnn as cudnn
8 | import torchvision.transforms as transforms
9 | import torchvision.datasets as datasets
10 | import datetime
11 | import time
12 |
13 | from model.CLSA import CLSA
14 | from ops.os_operation import mkdir
15 | from data_processing.Multi_FixTransform import Multi_Fixtransform
16 | from training.train_utils import adjust_learning_rate,save_checkpoint
17 | from training.train import train
18 |
19 |
20 | def init_log_path(args):
21 | """
22 |     :param args: config parameters
23 |     :return:
24 |     directory path for saving models and logs
25 | """
26 | save_path = os.path.join(os.getcwd(), args.log_path)
27 | mkdir(save_path)
28 | save_path = os.path.join(save_path, args.dataset)
29 | mkdir(save_path)
30 | save_path = os.path.join(save_path, "Alpha_" + str(args.alpha))
31 | mkdir(save_path)
32 | save_path = os.path.join(save_path, "Aug_" + str(args.aug_times))
33 | mkdir(save_path)
34 | save_path = os.path.join(save_path, "lr_" + str(args.lr))
35 | mkdir(save_path)
36 | save_path = os.path.join(save_path, "cos_" + str(args.cos))
37 | mkdir(save_path)
38 | today = datetime.date.today()
39 | formatted_today = today.strftime('%y%m%d')
40 | now = time.strftime("%H:%M:%S")
41 | save_path = os.path.join(save_path, formatted_today + now)
42 | mkdir(save_path)
43 | return save_path
44 |
45 |
46 | def main_worker(gpu, ngpus_per_node, args):
47 | """
48 | :param gpu: current gpu id
49 | :param ngpus_per_node: number of gpus in one node
50 | :param args: config parameter
51 | :return:
52 |     initialize the training setup and run the training loop
53 | """
54 | params = vars(args)
55 | args.gpu = gpu
56 |
57 | # suppress printing if not master
58 | if args.multiprocessing_distributed and args.gpu != 0:
59 | def print_pass(*args):
60 | pass
61 |
62 | builtins.print = print_pass
63 |
64 | if args.gpu is not None:
65 | print("Use GPU: {} for training".format(args.gpu))
66 | print("=> creating model '{}'".format(args.arch))
67 | if args.distributed:
68 | if args.dist_url == "env://" and args.rank == -1:
69 | args.rank = int(os.environ["RANK"])
70 | if args.multiprocessing_distributed:
71 | # For multiprocessing distributed training, rank needs to be the
72 | # global rank among all the processes
73 | args.rank = args.rank * ngpus_per_node + gpu
74 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
75 | world_size=args.world_size, rank=args.rank)
76 | #init model
77 | model = CLSA(models.__dict__[args.arch], args,
78 | args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp)
79 | print(model)
80 |
81 |
82 | if args.distributed:
83 | # For multiprocessing distributed, DistributedDataParallel constructor
84 | # should always set the single device scope, otherwise,
85 | # DistributedDataParallel will use all available devices.
86 | if args.gpu is not None:
87 | torch.cuda.set_device(args.gpu)
88 | model.cuda(args.gpu)
89 |
90 | # When using a single GPU per process and per
91 | # DistributedDataParallel, we need to divide the batch size
92 | # ourselves based on the total number of GPUs we have
93 | args.batch_size = int(args.batch_size / ngpus_per_node)
94 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
95 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
96 | else:
97 | model.cuda()
98 | # DistributedDataParallel will divide and allocate batch_size to all
99 | # available GPUs if device_ids are not set
100 | model = torch.nn.parallel.DistributedDataParallel(model)
101 | elif args.gpu is not None:
102 | torch.cuda.set_device(args.gpu)
103 | model = model.cuda(args.gpu)
104 | # comment out the following line for debugging
105 | raise NotImplementedError("Only DistributedDataParallel is supported.")
106 | else:
107 | # AllGather implementation (batch shuffle, queue update, etc.) in
108 | # this code only supports DistributedDataParallel.
109 | raise NotImplementedError("Only DistributedDataParallel is supported.")
110 |
111 | # define loss function (criterion) and optimizer
112 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
113 |
114 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
115 | momentum=args.momentum,
116 | weight_decay=args.weight_decay)
117 |
118 | # optionally resume from a checkpoint
119 | if args.resume:
120 | if os.path.isfile(args.resume):
121 | print("=> loading checkpoint '{}'".format(args.resume))
122 | if args.gpu is None:
123 | checkpoint = torch.load(args.resume)
124 | else:
125 | # Map model to be loaded to specified single gpu.
126 | loc = 'cuda:{}'.format(args.gpu)
127 | checkpoint = torch.load(args.resume, map_location=loc)
128 | args.start_epoch = checkpoint['epoch']
129 | model.load_state_dict(checkpoint['state_dict'])
130 | optimizer.load_state_dict(checkpoint['optimizer'])
131 | print("=> loaded checkpoint '{}' (epoch {})"
132 | .format(args.resume, checkpoint['epoch']))
133 | else:
134 | print("=> no checkpoint found at '{}'".format(args.resume))
135 | exit()
136 |
137 | cudnn.benchmark = True
138 | # config data loader
139 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
140 | std=[0.229, 0.224, 0.225])
141 |
142 | fix_transform = Multi_Fixtransform(args.size_crops,
143 | args.nmb_crops,
144 | args.min_scale_crops,
145 | args.max_scale_crops, normalize, args.aug_times)
146 | traindir = os.path.join(args.data, 'train')
147 | train_dataset = datasets.ImageFolder(
148 | traindir,
149 | fix_transform)
150 | if args.distributed:
151 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
152 | else:
153 | train_sampler = None
154 | train_loader = torch.utils.data.DataLoader(
155 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
156 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
157 | save_path=init_log_path(args) #config model save path and log path
158 | log_path = os.path.join(save_path,"train.log")
159 | best_Acc = 0
160 | for epoch in range(args.start_epoch, args.epochs):
161 | if args.distributed:
162 | train_sampler.set_epoch(epoch)
163 | adjust_learning_rate(optimizer, epoch, args)
164 | acc1 = train(train_loader, model, criterion, optimizer, epoch, args,log_path)
165 |         is_best = acc1 > best_Acc
166 | best_Acc = max(best_Acc, acc1)
167 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
168 | and args.rank % ngpus_per_node == 0):
169 | save_dict = {
170 | 'epoch': epoch + 1,
171 | 'arch': args.arch,
172 | 'best_acc': best_Acc,
173 | 'state_dict': model.state_dict(),
174 | 'optimizer': optimizer.state_dict(),
175 | }
176 |
177 | if epoch % 10 == 9:
178 | tmp_save_path = os.path.join(save_path, 'checkpoint_{:04d}.pth.tar'.format(epoch))
179 | save_checkpoint(save_dict, is_best=False, filename=tmp_save_path)
180 | tmp_save_path = os.path.join(save_path, 'checkpoint_best.pth.tar')
181 | save_checkpoint(save_dict, is_best=is_best, filename=tmp_save_path)
182 |
183 |
--------------------------------------------------------------------------------
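`main_worker` above is written to run once per GPU: it receives a local `gpu` index, derives the global rank as `args.rank * ngpus_per_node + gpu`, and splits the per-node batch size across GPUs. A minimal sketch of how a launcher might spawn it, assuming torch.multiprocessing (the repo's actual entry point is main_clsa.py, not reproduced here):

```python
import torch
import torch.multiprocessing as mp

from training.main_worker import main_worker

def launch(args):
    ngpus_per_node = torch.cuda.device_count()
    # total number of processes across all nodes
    args.world_size = ngpus_per_node * args.nodes_num
    args.distributed = True
    # one process per GPU; each is invoked as main_worker(i, ngpus_per_node, args)
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
```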
/training/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch.nn as nn
3 | import torch
4 |
5 | from training.train_utils import AverageMeter,ProgressMeter,accuracy
6 |
7 | def train(train_loader, model, criterion, optimizer, epoch, args,log_path):
8 | """
9 | :param train_loader: data loader
10 | :param model: training model
11 | :param criterion: loss function
12 | :param optimizer: SGD optimizer
13 | :param epoch: current epoch
14 | :param args: config parameter
15 |     :return: average top-1 accuracy (top1.avg) over the epoch
16 | """
17 | batch_time = AverageMeter('Time', ':6.3f')
18 | data_time = AverageMeter('Data', ':6.3f')
19 | losses = AverageMeter('Loss', ':.4e')
20 | top1 = AverageMeter('Acc@1', ':6.2f')
21 | top5 = AverageMeter('Acc@5', ':6.2f')
22 | progress = ProgressMeter(
23 | len(train_loader),
24 | [batch_time, data_time, losses, top1, top5],
25 | prefix="Epoch: [{}]".format(epoch))
26 |
27 | # switch to train mode
28 | model.train()
29 |
30 | end = time.time()
31 | mse_criterion=nn.MSELoss().cuda(args.gpu)
32 | for i, (images, _) in enumerate(train_loader):
33 | # measure data loading time
34 | data_time.update(time.time() - end)
35 |
36 | if args.gpu is not None:
37 | len_images = len(images)
38 | for k in range(len(images)):
39 | images[k] = images[k].cuda(args.gpu, non_blocking=True)
40 |             crop_copy_length = int((len_images - 1) / 2)  # half of the remaining views are weak, half strong
41 |             image_k = images[0]  # key view
42 |             image_q = images[1:1 + crop_copy_length]  # weakly augmented query views
43 |             image_strong = images[1 + crop_copy_length:]  # strongly augmented views
44 |
45 | output, target, output2, target2 = model(image_q, image_k, image_strong)
46 | loss_contrastive = 0
47 | loss_weak_strong = 0
48 | if epoch == 0 and i == 0:
49 | print("-" * 100)
50 | print("contrastive loss count %d" % len(output))
51 | print("weak strong loss count %d" % len(output2))
52 | print("-" * 100)
53 | for k in range(len(output)):
54 | loss1 = criterion(output[k], target[k])
55 | loss_contrastive += loss1
56 | for k in range(len(output2)):
57 | loss2 = -torch.mean(torch.sum(torch.log(output2[k]) * target2[k], dim=1)) # DDM loss
58 | loss_weak_strong += loss2
59 | loss = loss_contrastive + args.alpha * loss_weak_strong
60 | # acc1/acc5 are (K+1)-way contrast classifier accuracy
61 | # measure accuracy and record loss
62 | acc1, acc5 = accuracy(output[0], target[0], topk=(1, 5))
63 | losses.update(loss.item(), images[0].size(0))
64 | top1.update(acc1[0], images[0].size(0))
65 | top5.update(acc5[0], images[0].size(0))
66 |
67 | # compute gradient and do SGD step
68 | optimizer.zero_grad()
69 | loss.backward()
70 | optimizer.step()
71 |
72 | # measure elapsed time
73 | batch_time.update(time.time() - end)
74 | end = time.time()
75 |
76 | if i % args.print_freq == 0:
77 | progress.display(i)
78 | progress.write_record(i,log_path)
79 | return top1.avg
--------------------------------------------------------------------------------
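The DDM term on line 57 above is a cross-entropy between two distributions over the memory bank: `target2[k]` is the target distribution from the weakly augmented view and `output2[k]` is the predicted distribution of the strongly augmented view. A toy check with random tensors (shapes are illustrative only):

```python
import torch

batch_size, bank_size = 4, 8
# stand-ins for output2[k] (strong view) and target2[k] (weak view)
p_strong = torch.softmax(torch.randn(batch_size, bank_size), dim=1)
p_weak = torch.softmax(torch.randn(batch_size, bank_size), dim=1)

# H(p_weak, p_strong) = -sum_j p_weak[j] * log(p_strong[j]), batch-averaged;
# minimized exactly when the strong distribution matches the weak one
ddm = -torch.mean(torch.sum(torch.log(p_strong) * p_weak, dim=1))
print(ddm.item())
```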
/training/train_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import shutil
4 | import os
5 |
6 | def adjust_learning_rate(optimizer, epoch, args):
7 | """
8 | :param optimizer: SGD optimizer
9 | :param epoch: current epoch
10 | :param args: args
11 | :return:
12 | Decay the learning rate based on schedule
13 | """
14 |
15 | lr = args.lr
16 |     if args.cos == 1:  # cosine lr schedule, annealing to 0
17 |         lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
18 |     elif args.cos == 2:  # half-period cosine schedule
19 |         lr *= math.cos(math.pi * epoch / (args.epochs * 2))
20 |     else:  # no decay: keep the initial lr
21 |         lr = args.lr
22 | for param_group in optimizer.param_groups:
23 | param_group['lr'] = lr
24 |
25 | def save_checkpoint(state, is_best, filename):
26 | torch.save(state, filename)
27 | if is_best:
28 | root_path=os.path.split(filename)[0]
29 | best_model_path=os.path.join(root_path,"model_best.pth.tar")
30 | shutil.copyfile(filename, best_model_path)
31 |
32 | class AverageMeter(object):
33 | """Computes and stores the average and current value"""
34 | def __init__(self, name, fmt=':f'):
35 | self.name = name
36 | self.fmt = fmt
37 | self.reset()
38 |
39 | def reset(self):
40 | self.val = 0
41 | self.avg = 0
42 | self.sum = 0
43 | self.count = 0
44 |
45 | def update(self, val, n=1):
46 | self.val = val
47 | self.sum += val * n
48 | self.count += n
49 | self.avg = self.sum / self.count
50 |
51 | def __str__(self):
52 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
53 | return fmtstr.format(**self.__dict__)
54 |
55 |
56 | class ProgressMeter(object):
57 | def __init__(self, num_batches, meters, prefix=""):
58 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
59 | self.meters = meters
60 | self.prefix = prefix
61 |
62 | def display(self, batch):
63 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
64 | entries += [str(meter) for meter in self.meters]
65 | print('\t'.join(entries))
66 |
67 | def _get_batch_fmtstr(self, num_batches):
68 |         num_digits = len(str(num_batches))
69 | fmt = '{:' + str(num_digits) + 'd}'
70 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
71 |     def write_record(self, batch, filename):
72 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
73 |         entries += [str(meter) for meter in self.meters]
74 |         with open(filename, "a+") as file:
75 |             file.write('\t'.join(entries) + "\n")
76 |
77 | def accuracy(output, target, topk=(1,)):
78 | """
79 | :param output: predicted prob vectors
80 | :param target: ground truth
81 | :param topk: top k predictions considered
82 | :return:
83 | Computes the accuracy over the k top predictions for the specified values of k
84 | """
85 | with torch.no_grad():
86 | maxk = max(topk)
87 | batch_size = target.size(0)
88 |
89 | _, pred = output.topk(maxk, 1, True, True)
90 | pred = pred.t()
91 | correct = pred.eq(target.view(1, -1).expand_as(pred))
92 |
93 | res = []
94 | for k in topk:
95 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
96 | res.append(correct_k.mul_(100.0 / batch_size))
97 | return res
--------------------------------------------------------------------------------
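For reference, the `cos==1` branch of `adjust_learning_rate` anneals the learning rate from `args.lr` down to (nearly) 0 over training. A quick check with the defaults (lr=0.03, epochs=200):

```python
import math

lr0, epochs = 0.03, 200
for epoch in [0, 50, 100, 150, 199]:
    lr = lr0 * 0.5 * (1.0 + math.cos(math.pi * epoch / epochs))
    print(epoch, round(lr, 5))
# 0 -> 0.03, 50 -> 0.02561, 100 -> 0.015, 150 -> 0.00439, 199 -> ~0.0
```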