├── LICENSE
├── README.md
├── VOC_CLF
│   ├── dataset.py
│   ├── main.py
│   ├── test.py
│   ├── train.py
│   └── utils.py
├── cmd
│   ├── __init__.py
│   ├── run_multi.sh
│   └── run_single.sh
├── data_processing
│   ├── Image_ops.py
│   ├── Multi_FixTransform.py
│   ├── RandAugment.py
│   └── __init__.py
├── detection
│   ├── configs
│   │   ├── Base-RCNN-C4-BN.yaml
│   │   ├── coco_R_50_C4_2x.yaml
│   │   ├── coco_R_50_C4_2x_clsa.yaml
│   │   ├── pascal_voc_R_50_C4_24k.yaml
│   │   └── pascal_voc_R_50_C4_24k_CLSA.yaml
│   ├── convert-pretrain-to-detectron2.py
│   └── train_net.py
├── lincls.py
├── main_clsa.py
├── model
│   ├── CLSA.py
│   └── __init__.py
├── ops
│   ├── Config_Envrionment.py
│   ├── __init__.py
│   ├── argparser.py
│   └── os_operation.py
├── requirements.txt
└── training
    ├── __init__.py
    ├── main_worker.py
    ├── train.py
    └── train_utils.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Lab for MAchine Perception and LEarning (MAPLE)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CLSA

CLSA is a self-supervised learning method that focuses on learning patterns from strong augmentations.

Copyright (C) 2020 Xiao Wang, Guo-Jun Qi

License: MIT for academic use.

Contact: Guo-Jun Qi (guojunq@gmail.com)


## Introduction
Representation learning has improved greatly with the advance of contrastive learning methods. These methods benefit from data augmentations that are carefully designed to preserve image identity, so that images transformed from the same instance can still be retrieved. However, such carefully designed transformations prevent us from exploring the novel patterns exposed by other, stronger transformations. To bridge this gap, we propose a general framework called Contrastive Learning with Stronger Augmentations (CLSA) that complements current contrastive learning approaches. As our experiments show, the distortions induced by stronger augmentations can make the transformed images no longer recognizable as the same instance. We therefore minimize the distribution divergence between the weakly and strongly augmented images over the representation bank, which supervises the retrieval of strongly augmented queries from a pool of candidates. Experiments on ImageNet and downstream datasets show that the information from strongly augmented images can greatly boost performance.
For example, CLSA achieves 76.2% top-1 accuracy on ImageNet with a standard ResNet-50 architecture when only a single-layer classifier is fine-tuned on top, almost matching the 76.5% of the supervised baseline.

## Installation
CUDA 10.1 or higher is required.
### 1. [`Install git`](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git)
### 2. Clone the repository to your computer
```
git clone git@github.com:maple-research-lab/CLSA.git && cd CLSA
```

### 3. Build dependencies.
You have two options to install the dependencies:
#### 3.1 Install with pip and Python (version 3.6.9).
##### 3.1.1 [`install pip`](https://pip.pypa.io/en/stable/installing/).
##### 3.1.2 Install the dependencies from the command line.
```
pip install -r requirements.txt --user
```
If you encounter any errors, you can install each library one by one:
```
pip install torch==1.7.1
pip install torchvision==0.8.2
pip install numpy==1.19.5
pip install Pillow==5.1.0
pip install tensorboard==1.14.0
pip install tensorboardX==1.7
```

#### 3.2 Install with Anaconda
##### 3.2.1 [`install conda`](https://docs.conda.io/projects/conda/en/latest/user-guide/install/macos.html).
##### 3.2.2 Install the dependencies from the command line
```
conda create -n CLSA python=3.6.9
conda activate CLSA
pip install -r requirements.txt
```
Each time you want to run the code, activate the environment with
```
conda activate CLSA
conda deactivate  # run this when you want to leave the environment
```
### 4. Prepare the ImageNet dataset
##### 4.1 Download the [ImageNet2012 Dataset](http://image-net.org/challenges/LSVRC/2012/) into "./datasets/imagenet2012".
##### 4.2 Go to the path "./datasets/imagenet2012/val".
##### 4.3 Move the validation images into labeled subfolders using [this shell script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).

## Usage

### Unsupervised Training
This implementation only supports multi-GPU DistributedDataParallel training, which is faster and simpler; single-GPU and DataParallel training are not supported.
#### Single Crop
##### 1 Without symmetrical loss
```
python3 main_clsa.py --data=[data_path] --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 --size_crops 224 96 --min_scale_crops 0.2 0.086 --max_scale_crops 1.0 0.429 --pick_strong 1 --pick_weak 0 --clsa_t 0.2 --sym 0
```
Here [data_path] should be the root directory of the ImageNet dataset.

##### 2 With symmetrical loss (not verified)
```
python3 main_clsa.py --data=[data_path] --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 --size_crops 224 96 --min_scale_crops 0.2 0.086 --max_scale_crops 1.0 0.429 --pick_strong 1 --pick_weak 0 --clsa_t 0.2 --sym 1
```
Here [data_path] should be the root directory of the ImageNet dataset.
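
The commands above combine a MoCo-style contrastive loss with CLSA's distribution-divergence term which, as the Introduction describes, pulls the strong view's similarity distribution over the memory bank towards the weak view's. The NumPy sketch below illustrates that term only; it is not the code in `model/CLSA.py`. It assumes L2-normalized features, a positive key in column 0 followed by the `--moco_k` bank entries, that `--moco_t` and `--clsa_t` are the temperatures of the two losses, that `--alpha` weights the divergence term, and that the weak-view distribution acts as a fixed target. All function names are hypothetical.

```python
import numpy as np


def log_softmax(x):
    # Numerically stable log-softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def bank_logits(q, k_pos, queue, t):
    # Column 0: similarity to the positive key; columns 1..K: similarity
    # to every entry of the memory bank (queue). Scaled by temperature t.
    pos = np.sum(q * k_pos, axis=1, keepdims=True)  # (N, 1)
    neg = q @ queue.T                               # (N, K)
    return np.concatenate([pos, neg], axis=1) / t   # (N, 1 + K)


def infonce_loss(q, k_pos, queue, t=0.2):
    # MoCo-style contrastive loss: cross-entropy with the positive in column 0.
    return float(-log_softmax(bank_logits(q, k_pos, queue, t))[:, 0].mean())


def ddm_loss(q_weak, q_strong, k_pos, queue, t=0.2):
    # Assumed target: the weak view's distribution over the bank
    # (in a PyTorch version this tensor would be detached).
    p_weak = np.exp(log_softmax(bank_logits(q_weak, k_pos, queue, t)))
    # Cross-entropy of the strong view's distribution against that target.
    log_p_strong = log_softmax(bank_logits(q_strong, k_pos, queue, t))
    return float(-(p_weak * log_p_strong).sum(axis=1).mean())


def l2_normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, dim, bank_size = 4, 128, 1024  # batch, --moco_dim, small stand-in for --moco_k
    key = l2_normalize(rng.normal(size=(n, dim)))
    queue = l2_normalize(rng.normal(size=(bank_size, dim)))
    q_weak = l2_normalize(key + 0.1 * rng.normal(size=(n, dim)))
    q_strong = l2_normalize(key + 0.5 * rng.normal(size=(n, dim)))
    alpha = 1.0  # assumed role of --alpha
    total = infonce_loss(q_weak, key, queue, t=0.2) + alpha * ddm_loss(q_weak, q_strong, key, queue, t=0.2)
    print(total)
```

Because the divergence term is a cross-entropy, it can never fall below the entropy of the weak-view distribution, and it reaches that floor only when both views produce the same distribution over the bank.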

#### Multi Crop

##### 1 Without symmetrical loss
```
python3 main_clsa.py --data=[data_path] --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 1 1 1 --size_crops 224 192 160 128 96 --min_scale_crops 0.2 0.172 0.143 0.114 0.086 --max_scale_crops 1.0 0.86 0.715 0.571 0.429 --pick_strong 0 1 2 3 4 --pick_weak 0 1 2 3 4 --clsa_t 0.2 --sym 0
```
Here [data_path] should be the root directory of the ImageNet dataset.

##### 2 With symmetrical loss (not verified)
```
python3 main_clsa.py --data=[data_path] --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 1 1 1 --size_crops 224 192 160 128 96 --min_scale_crops 0.2 0.172 0.143 0.114 0.086 --max_scale_crops 1.0 0.86 0.715 0.571 0.429 --pick_strong 0 1 2 3 4 --pick_weak 0 1 2 3 4 --clsa_t 0.2 --sym 1
```
Here [data_path] should be the root directory of the ImageNet dataset.

### Linear Classification
Given a pre-trained model, you can evaluate its linear classification performance on ImageNet with:
```
python3 lincls.py --data=./datasets/imagenet2012 --dist-url=tcp://localhost:10001 --pretrained=[pretrained_model_path]
```
Here [pretrained_model_path] is the path to the ImageNet-pretrained model.

Performance:
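
| Pre-train network | Pre-train epochs | Crop | CLSA top-1 acc. (%) | Model link |
|---|---|---|---|---|
| ResNet-50 | 200 | Single | 69.4 | model |
| ResNet-50 | 200 | Multi | 73.3 | model |
| ResNet-50 | 800 | Single | 72.2 | model |
| ResNet-50 | 800 | Multi | 76.2 | None |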
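
To reuse one of these checkpoints outside this repository, the backbone weights first have to be pulled out of the training state dict. The sketch below mirrors the key handling in `VOC_CLF/main.py` (keep `module.encoder_q.*`, drop the `fc` projection head, strip the prefix) but returns a new dict instead of editing the checkpoint in place. It assumes the weights are stored under `checkpoint["state_dict"]`, as `VOC_CLF/main.py` expects; `extract_backbone_weights` is a hypothetical helper and the toy keys are illustrative only.

```python
def extract_backbone_weights(state_dict, prefix="module.encoder_q."):
    """Return the query-encoder weights with the training prefix removed.

    Keys that do not start with ``prefix`` (e.g. the momentum encoder) and the
    projection head (``<prefix>fc...``) are dropped, so a fresh classifier can
    be attached afterwards. The input dict is left unchanged.
    """
    backbone = {}
    for key, value in state_dict.items():
        if key.startswith(prefix) and not key.startswith(prefix + "fc"):
            backbone[key[len(prefix):]] = value
    return backbone


if __name__ == "__main__":
    # Hypothetical toy stand-in for checkpoint["state_dict"]; real values are tensors.
    toy = {
        "module.encoder_q.conv1.weight": "w0",
        "module.encoder_q.fc.0.weight": "w1",
        "module.encoder_k.conv1.weight": "w2",
        "module.queue": "w3",
    }
    print(extract_backbone_weights(toy))  # {'conv1.weight': 'w0'}
```

With PyTorch, the returned dict can be passed to `torchvision.models.resnet50().load_state_dict(weights, strict=False)`; as in `VOC_CLF/main.py`, the only missing keys should then be `fc.weight` and `fc.bias`.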

Sorry that we cannot provide the 800-epoch CLSA* (multi-crop) model: it was trained on 32 internal GPUs, and company regulations prevent us from downloading it. For downstream tasks, we found that the 200-epoch multi-crop model performs similarly, so we suggest using this [model](https://purdue0-my.sharepoint.com/:u:/g/personal/wang3702_purdue_edu/Ed8IVMBAvp1GmqABFMskEbYBz6B1vq65kp2IQlukFiS6mw?e=K0G5H6) for downstream purposes.

### Transferring to VOC07 Classification
#### 1 Download the [Dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar) into "./datasets/voc"
#### 2 Linear Evaluation:
```
cd VOC_CLF
python3 main.py --data=[VOC_dataset_dir] --pretrained=[pretrained_model_path]
```
Here [VOC_dataset_dir] is the VOC dataset path, i.e. the directory that contains the "VOCdevkit" directory, and [pretrained_model_path] is the path to the ImageNet-pretrained model.

### Transfer to Object Detection
#### 1. Install [detectron2](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md).

#### 2. Convert a pre-trained CLSA model to detectron2's format:
```
# in the detection folder
python3 convert-pretrain-to-detectron2.py input.pth.tar output.pkl
```

#### 3. Download the [VOC Dataset](http://places.csail.mit.edu/user/index.php) and the [COCO Dataset](https://cocodataset.org/#download) into the "./detection/datasets" directory,
following the [directory structure](https://github.com/facebookresearch/detectron2/tree/master/datasets) required by detectron2.

#### 4. Run training:
##### 4.1 Pascal VOC detection
```
cd detection
python train_net.py --config-file configs/pascal_voc_R_50_C4_24k_CLSA.yaml --num-gpus 8 MODEL.WEIGHTS ./output.pkl
```
##### 4.2 COCO detection
```
cd detection
python train_net.py --config-file configs/coco_R_50_C4_2x_clsa.yaml --num-gpus 8 MODEL.WEIGHTS ./output.pkl
```


## Citation:
[Contrastive Learning with Stronger Augmentations](https://arxiv.org/abs/2104.07713)
```
@article{wang2021contrastive,
  title={Contrastive learning with stronger augmentations},
  author={Wang, Xiao and Qi, Guo-Jun},
  journal={arXiv preprint arXiv:2104.07713},
  year={2021}
}
```

--------------------------------------------------------------------------------
/VOC_CLF/dataset.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 12 23:23:51 2019

@author: Keshik
"""
import torchvision.datasets.voc as voc

class PascalVOC_Dataset(voc.VOCDetection):
    """`Pascal VOC `_ Detection Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
            (default: alphabetic indexing of VOC's 20 classes).
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    def __init__(self, root, year='2012', image_set='train', download=False, transform=None, target_transform=None):

        super().__init__(
            root,
            year=year,
            image_set=image_set,
            download=download,
            transform=transform,
            target_transform=target_transform)


    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the parsed XML annotation,
                or its transformed version if ``target_transform`` is set.
        """
        return super().__getitem__(index)


    def __len__(self):
        """
        Returns:
            size of the dataset
        """
        return len(self.images)

--------------------------------------------------------------------------------
/VOC_CLF/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
adapted from https://github.com/keshik6/pascal-voc-classification
Created on Wed Mar 13 10:50:25 2019

@author: Keshik
"""

import torch
import numpy as np
from torchvision import transforms
import torchvision.models as models
from torch.utils.data import DataLoader
from dataset import PascalVOC_Dataset
import torch.optim as optim
from train import train_model, test
from utils import encode_labels, plot_history
import os
import utils

def main(args, data_dir, model_name, num, lr, epochs, batch_size = 16, download_data = False):
    """
    Main function

    Args:
        args: parsed command-line arguments (uses args.arch, args.pretrained and args.log)
        data_dir: directory to download Pascal VOC data
        model_name: architecture name, used only in the log header (the network is built from args.arch)
        num: model_num for file management purposes (can be any positive integer; stored results get this number as suffix)
        lr: initial learning rate of the linear classifier (the backbone is frozen)
        epochs: number of training epochs
        batch_size: batch size. Default=16
        download_data: Boolean. If true, downloads the Pascal VOC 2007 data as tar into data_dir.
            Set this to True only the first time you run it, and then set to False. Default False

    Returns:
        test-time loss and average precision

    Example way of running this function:
        if __name__ == '__main__':
            main(args, args.data, args.arch, num=args.run_num, lr=args.lr, epochs=args.epochs, batch_size=args.batch_size)
    """

    # Initialize cuda parameters
    use_cuda = torch.cuda.is_available()
    np.random.seed(2019)
    torch.manual_seed(2019)
    device = torch.device("cuda" if use_cuda else "cpu")

    print("Available device = ", device)
    model = models.__dict__[args.arch]()
    # Freeze every layer except the final fully connected classifier
    for name, param in model.named_parameters():
        if name not in ['fc.weight', 'fc.bias']:
            param.requires_grad = False
    #model.avgpool = torch.nn.AdaptiveAvgPool2d(1)

    #model.load_state_dict(model_zoo.load_url(model_urls[model_name]))
    checkpoint = torch.load(args.pretrained, map_location="cpu")
    state_dict = checkpoint['state_dict']
    for k in list(state_dict.keys()):
        # retain only encoder_q up to before the embedding layer
        if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
            # remove prefix
            state_dict[k[len("module.encoder_q."):]] = state_dict[k]
        # delete renamed or unused k
        del state_dict[k]
    msg = model.load_state_dict(state_dict, strict=False)
    assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

    print("=> loaded pre-trained model '{}'".format(args.pretrained))
    # Replace the ImageNet head with a 20-way classifier for the VOC classes
    num_ftrs = model.fc.in_features
    model.fc = torch.nn.Linear(num_ftrs, 20)
    model.fc.weight.data.normal_(mean=0.0, std=0.01)
    model.fc.bias.data.zero_()

    model.to(device)
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    print("optimized parameters", parameters)
    optimizer = optim.SGD([
        {'params': parameters, 'lr': lr, 'momentum': 0.9}
    ])

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 12, eta_min=0, last_epoch=-1)

    # Per-channel normalization statistics (the standard ImageNet values are commented out below)
    mean=[0.457342265910642, 0.4387686270106377, 0.4073427106250871]
    std=[0.26753769276329037, 0.2638145880487105, 0.2776826934044154]

    # mean=[0.485, 0.456, 0.406]
    # std=[0.229, 0.224, 0.225]

    transformations = transforms.Compose([transforms.Resize((300, 300)),
    #                                      transforms.RandomChoice([
    #                                          transforms.CenterCrop(300),
    #                                          transforms.RandomResizedCrop(300, scale=(0.80, 1.0)),
    #                                      ]),
                                          transforms.RandomChoice([
                                              transforms.ColorJitter(brightness=(0.80, 1.20)),
                                              transforms.RandomGrayscale(p = 0.25)
                                          ]),
                                          transforms.RandomHorizontalFlip(p = 0.25),
                                          transforms.RandomRotation(25),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean = mean, std = std),
                                          ])

    transformations_valid = transforms.Compose([transforms.Resize(330),
                                                transforms.CenterCrop(300),
                                                transforms.ToTensor(),
                                                transforms.Normalize(mean = mean, std = std),
                                                ])

    # Create train dataloader
    dataset_train = PascalVOC_Dataset(data_dir,
                                      year='2007',
                                      image_set='train',
                                      download=download_data,
                                      transform=transformations,
                                      target_transform=encode_labels)

    train_loader = DataLoader(dataset_train, batch_size=batch_size, num_workers=4, shuffle=True)

    # Create validation dataloader
    dataset_valid = PascalVOC_Dataset(data_dir,
                                      year='2007',
                                      image_set='val',
                                      download=download_data,
                                      transform=transformations_valid,
                                      target_transform=encode_labels)

    valid_loader = DataLoader(dataset_valid, batch_size=batch_size, num_workers=4)

    # Create the log and model directories
    if not os.path.exists(args.log):
        os.mkdir(args.log)

    log_file = open(os.path.join(args.log, "log-{}.txt".format(num)), "w+")
    model_dir = os.path.join(args.log, "model")
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    log_file.write("----------Experiment {} - {}-----------\n".format(num, model_name))
    log_file.write("transformations == {}\n".format(transformations.__str__()))
    trn_hist, val_hist = train_model(model, device, optimizer, scheduler, train_loader, valid_loader, model_dir, num, epochs, log_file)
    torch.cuda.empty_cache()

    plot_history(trn_hist[0], val_hist[0], "Loss", os.path.join(model_dir, "loss-{}".format(num)))
    plot_history(trn_hist[1], val_hist[1], "Average precision", os.path.join(model_dir, "accuracy-{}".format(num)))
    log_file.close()

    #---------------Test your model here---------------------------------------
    # Load the best weights before testing
    print("Evaluating model on test set")
    weights_file_path = os.path.join(model_dir, "model-{}.pth".format(num))
    assert os.path.isfile(weights_file_path)
    print("Loading best weights")

    model.load_state_dict(torch.load(weights_file_path))
    # Five-crop test-time augmentation: each sample becomes a stack of 5 normalized crops
    transformations_test = transforms.Compose([transforms.Resize(330),
                                               transforms.FiveCrop(300),
                                               transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                                               transforms.Lambda(lambda crops: torch.stack([transforms.Normalize(mean = mean, std = std)(crop) for crop in crops])),
                                               ])


    dataset_test = PascalVOC_Dataset(data_dir,
                                     year='2007',
                                     image_set='test',
                                     download=download_data,
                                     transform=transformations_test,
                                     target_transform=encode_labels)


    test_loader = DataLoader(dataset_test, batch_size=batch_size, num_workers=0, shuffle=False)

    loss, ap, scores, gt = test(model, device, test_loader, returnAllScores=True)

    gt_path = os.path.join(model_dir, "gt-{}.csv".format(num))
    scores_path = os.path.join(model_dir, "scores-{}.csv".format(num))
    scores_with_gt_path = os.path.join(model_dir, "scores_wth_gt-{}.csv".format(num))

    utils.save_results(test_loader.dataset.images, gt, utils.object_categories, gt_path)
    utils.save_results(test_loader.dataset.images, scores, utils.object_categories, scores_path)
    utils.append_gt(gt_path, scores_path, scores_with_gt_path)

    utils.get_classification_accuracy(gt_path, scores_path, os.path.join(model_dir, "clf_vs_threshold-{}.png".format(num)))

    return loss, ap


model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
# Execute main function here
import argparse
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Linear evaluation on Pascal VOC 2007 classification')
    parser.add_argument('--data', type=str, metavar='DIR',
                        help='path to dataset')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet50)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--batch-size', default=16, type=int,
                        metavar='N',
                        help='mini-batch size (default: 16)')
    parser.add_argument('--lr', '--learning-rate', default=0.05, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--gpu', default=None, type=str,
                        help='GPU id to use.')
    parser.add_argument('--pretrained', default='', type=str,
                        help='path to moco pretrained checkpoint')
    parser.add_argument("--log", default="train_log", type=str, help="log path for training")
    parser.add_argument("--run_num", default=1, type=int, help="run number, used as suffix for saved logs, weights and results")
    args = parser.parse_args()
    choose = args.gpu
    if choose is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = choose
    main(args, args.data, args.arch, num=args.run_num, lr=args.lr, epochs=args.epochs, batch_size=args.batch_size)

#if __name__ == '__main__':
#    main('../data/', "resnet34", num=1, lr = [1.5e-4, 5e-2], epochs = 1, batch_size=16, download_data=False, save_results=True)

--------------------------------------------------------------------------------
/VOC_CLF/test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 13 11:22:28 2019

@author: Keshik
"""
import torch
from tqdm import tqdm
import gc
from sklearn.metrics import average_precision_score

--------------------------------------------------------------------------------
/VOC_CLF/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 13 10:37:39 2019

@author: Keshik
"""

from tqdm import tqdm
import torch
import gc
import os
from utils import get_ap_score
import numpy as np

def train_model(model, device, optimizer, scheduler, train_loader, valid_loader, save_dir, model_num, epochs, log_file):
    """
    Train a deep neural network model

    Args:
        model: pytorch model object
        device: cuda or cpu
        optimizer: pytorch optimizer object
        scheduler: learning rate scheduler object that wraps the optimizer
        train_loader: training images dataloader
        valid_loader: validation images dataloader
        save_dir: Location to save model weights, plots and log_file
        model_num: run number, used as suffix of the saved weights file
        epochs: number of training epochs
        log_file: text file instance to record training and validation history

    Returns:
        Training history and Validation history (loss and average precision)
    """

    tr_loss, tr_map = [], []
    val_loss, val_map = [], []
    best_val_map = 0.0

    # Each epoch has a training and validation phase
    for epoch in range(epochs):
        print("-------Epoch {}----------".format(epoch+1))
        log_file.write("Epoch {} >>".format(epoch+1))

        for phase in ['train', 'valid']:
            running_loss = 0.0
            running_ap = 0.0

            criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
            m = torch.nn.Sigmoid()

            if phase == 'train':
                model.train(True)  # Set model to training mode

                for data, target in tqdm(train_loader):
                    #print(data)
                    target = target.float()
                    data, target = data.to(device), target.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    output = model(data)

                    loss = criterion(output, target)

                    # Get metrics here; .item() avoids keeping the autograd graph alive
                    running_loss += loss.item()  # sum up batch loss
                    running_ap += get_ap_score(torch.Tensor.cpu(target).detach().numpy(), torch.Tensor.cpu(m(output)).detach().numpy())

                    # Backpropagate the loss to determine the gradients
                    loss.backward()

                    # Update the parameters of the model
                    optimizer.step()

                    # clear variables
                    del data, target, output
                    gc.collect()
                    torch.cuda.empty_cache()

                    #print("loss = ", running_loss)

                num_samples = float(len(train_loader.dataset))
                tr_loss_ = running_loss/num_samples
                tr_map_ = running_ap/num_samples

                print('train_loss: {:.4f}, train_avg_precision:{:.3f}'.format(
                    tr_loss_, tr_map_))

                log_file.write('train_loss: {:.4f}, train_avg_precision:{:.3f}, '.format(
                    tr_loss_, tr_map_))

                # Append the values to global arrays
                tr_loss.append(tr_loss_)
                tr_map.append(tr_map_)


            else:
                model.train(False)  # Set model to evaluate mode

                # torch.no_grad is for memory savings
                with torch.no_grad():
                    for data, target in tqdm(valid_loader):
                        target = target.float()
                        data, target = data.to(device), target.to(device)
                        output = model(data)

                        loss = criterion(output, target)

                        running_loss += loss.item()  # sum up batch loss
                        running_ap += get_ap_score(torch.Tensor.cpu(target).detach().numpy(), torch.Tensor.cpu(m(output)).detach().numpy())

                        del data, target, output
                        gc.collect()
                        torch.cuda.empty_cache()

                num_samples = float(len(valid_loader.dataset))
                val_loss_ = running_loss/num_samples
                val_map_ = running_ap/num_samples

                # Append the values to global arrays
                val_loss.append(val_loss_)
                val_map.append(val_map_)

                print('val_loss: {:.4f}, val_avg_precision:{:.3f}'.format(
                    val_loss_, val_map_))

                log_file.write('val_loss: {:.4f}, val_avg_precision:{:.3f}\n'.format(
                    val_loss_, val_map_))

                # Save the weights with the best validation average precision
                if val_map_ >= best_val_map:
                    best_val_map = val_map_
                    log_file.write("saving best weights...\n")
                    torch.save(model.state_dict(), os.path.join(save_dir, "model-{}.pth".format(model_num)))

        # Step the learning-rate schedule once per epoch, after the optimizer updates
        scheduler.step()

    return ([tr_loss, tr_map], [val_loss, val_map])



def test(model, device, test_loader, returnAllScores=False):
    """
    Evaluate a deep neural network model

    Args:
        model: pytorch model object
        device: cuda or cpu
        test_loader: test images dataloader; each sample is a stack of crops (e.g. from FiveCrop)
        returnAllScores: If true additionally return all confidence scores and ground truth

    Returns:
        test loss and average precision. If returnAllScores = True, also the confidence scores and ground truth
    """
    model.train(False)

    running_loss = 0.0
    running_ap = 0.0

    criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
    m = torch.nn.Sigmoid()

    if returnAllScores:
        all_scores = np.empty((0, 20), float)
        ground_scores = np.empty((0, 20), float)

    with torch.no_grad():
        for data, target in tqdm(test_loader):
            #print(data.size(), target.size())
            target = target.float()
            data, target = data.to(device), target.to(device)
            bs, ncrops, c, h, w = data.size()

            # Score every crop, then average the logits over the crops
            output = model(data.view(-1, c, h, w))
            output = output.view(bs, ncrops, -1).mean(1)

            loss = criterion(output, target)

            running_loss += loss.item()  # sum up batch loss
            running_ap += get_ap_score(torch.Tensor.cpu(target).detach().numpy(), torch.Tensor.cpu(m(output)).detach().numpy())

            if returnAllScores:
                all_scores = np.append(all_scores, torch.Tensor.cpu(m(output)).detach().numpy(), axis=0)
                ground_scores = np.append(ground_scores, torch.Tensor.cpu(target).detach().numpy(), axis=0)

            del data, target, output
            gc.collect()
            torch.cuda.empty_cache()

    num_samples = float(len(test_loader.dataset))
    avg_test_loss = running_loss/num_samples
    test_map = running_ap/num_samples

    print('test_loss: {:.4f}, test_avg_precision:{:.3f}'.format(
        avg_test_loss, test_map))


    if not returnAllScores:
        return avg_test_loss, test_map

    return avg_test_loss, test_map, all_scores, ground_scores



--------------------------------------------------------------------------------
/VOC_CLF/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 12 20:52:33 2019

@author: Keshik
"""
import os
import math
from tqdm import tqdm 10 | import torch 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | from sklearn.metrics import average_precision_score, accuracy_score 14 | import pandas as pd 15 | 16 | object_categories = ['aeroplane', 'bicycle', 'bird', 'boat', 17 | 'bottle', 'bus', 'car', 'cat', 'chair', 18 | 'cow', 'diningtable', 'dog', 'horse', 19 | 'motorbike', 'person', 'pottedplant', 20 | 'sheep', 'sofa', 'train', 'tvmonitor'] 21 | 22 | 23 | def get_categories(labels_dir): 24 | """ 25 | Get the object categories 26 | 27 | Args: 28 | label_dir: Directory that contains object specific label as .txt files 29 | Raises: 30 | FileNotFoundError: If the label directory does not exist 31 | Returns: 32 | Object categories as a list 33 | """ 34 | 35 | if not os.path.isdir(labels_dir): 36 | raise FileNotFoundError 37 | 38 | else: 39 | categories = [] 40 | 41 | for file in os.listdir(labels_dir): 42 | if file.endswith("_train.txt"): 43 | categories.append(file.split("_")[0]) 44 | 45 | return categories 46 | 47 | 48 | def encode_labels(target): 49 | """ 50 | Encode multiple labels using 1/0 encoding 51 | 52 | Args: 53 | target: xml tree file 54 | Returns: 55 | torch tensor encoding labels as 1/0 vector 56 | """ 57 | 58 | ls = target['annotation']['object'] 59 | 60 | j = [] 61 | if type(ls) == dict: 62 | if int(ls['difficult']) == 0: 63 | j.append(object_categories.index(ls['name'])) 64 | 65 | else: 66 | for i in range(len(ls)): 67 | if int(ls[i]['difficult']) == 0: 68 | j.append(object_categories.index(ls[i]['name'])) 69 | 70 | k = np.zeros(len(object_categories)) 71 | k[j] = 1 72 | 73 | return torch.from_numpy(k) 74 | 75 | 76 | def get_nrows(file_name): 77 | """ 78 | Get the number of rows of a csv file 79 | 80 | Args: 81 | file_path: path of the csv file 82 | Raises: 83 | FileNotFoundError: If the csv file does not exist 84 | Returns: 85 | number of rows 86 | """ 87 | 88 | if not os.path.isfile(file_name): 89 | raise FileNotFoundError 90 | 91 | s = 0 92 | with 
open(file_name) as f: 93 | s = sum(1 for line in f) 94 | return s 95 | 96 | 97 | def get_mean_and_std(dataloader): 98 | """ 99 | Get the mean and std of a 3-channel image dataset 100 | 101 | Args: 102 | dataloader: pytorch dataloader 103 | Returns: 104 | per-channel mean and std of the dataset 105 | """ 106 | mean = [] 107 | std = [] 108 | 109 | total = 0 110 | r_running, g_running, b_running = 0, 0, 0 111 | r2_running, g2_running, b2_running = 0, 0, 0 112 | 113 | with torch.no_grad(): 114 | for data, target in tqdm(dataloader): 115 | r, g, b = data[:,0 ,:, :], data[:, 1, :, :], data[:, 2, :, :] 116 | r2, g2, b2 = r**2, g**2, b**2 117 | 118 | # Sum up values to find mean 119 | r_running += r.sum().item() 120 | g_running += g.sum().item() 121 | b_running += b.sum().item() 122 | 123 | # Sum up squared values to find standard deviation 124 | r2_running += r2.sum().item() 125 | g2_running += g2.sum().item() 126 | b2_running += b2.sum().item() 127 | 128 | total += data.size(0)*data.size(2)*data.size(3) 129 | 130 | # Append the mean values 131 | mean.extend([r_running/total, 132 | g_running/total, 133 | b_running/total]) 134 | 135 | # Calculate standard deviation and append 136 | std.extend([ 137 | math.sqrt((r2_running/total) - mean[0]**2), 138 | math.sqrt((g2_running/total) - mean[1]**2), 139 | math.sqrt((b2_running/total) - mean[2]**2) 140 | ]) 141 | 142 | return mean, std 143 | 144 | 145 | def plot_history(train_hist, val_hist, y_label, filename, labels=["train", "validation"]): 146 | """ 147 | Plot training and validation history 148 | 149 | Args: 150 | train_hist: numpy array consisting of train history values (loss/accuracy metrics) 151 | val_hist: numpy array consisting of validation history values (loss/accuracy metrics) 152 | y_label: label for the y-axis 153 | filename: filename to store the resulting plot 154 | labels: legend for the plot 155 | 156 | Returns: 157 | None 158 | """ 159 | # Plot loss and accuracy 160 | xi = [i for i in range(0, len(train_hist), 2)] 161 |
plt.plot(train_hist, label = labels[0]) 162 | plt.plot(val_hist, label = labels[1]) 163 | plt.xticks(xi) 164 | plt.legend() 165 | plt.xlabel("Epoch") 166 | plt.ylabel(y_label) 167 | plt.savefig(filename) 168 | plt.show() 169 | 170 | 171 | def get_ap_score(y_true, y_scores): 172 | """ 173 | Sum the per-sample average precision over a batch of 2-d numpy arrays (one row per sample) 174 | 175 | Args: 176 | y_true: batch of true 1/0 labels, shape (batch, num_classes) 177 | y_scores: batch of confidence scores, shape (batch, num_classes) 178 | 179 | Returns: 180 | sum of the per-sample average precision over the batch 181 | """ 182 | scores = 0.0 183 | 184 | for i in range(y_true.shape[0]): 185 | scores += average_precision_score(y_true = y_true[i], y_score = y_scores[i]) 186 | 187 | return scores 188 | 189 | def save_results(images, scores, columns, filename): 190 | """ 191 | Save inference results as csv 192 | 193 | Args: 194 | images: inferred image list 195 | scores: confidence scores for the inferred images 196 | columns: object categories 197 | filename: name and location to save resulting csv 198 | """ 199 | df_scores = pd.DataFrame(scores, columns=columns) 200 | df_scores['image'] = images 201 | df_scores.set_index('image', inplace=True) 202 | df_scores.to_csv(filename) 203 | 204 | 205 | def append_gt(gt_csv_path, scores_csv_path, store_filename): 206 | """ 207 | Append ground truth to confidence score csv 208 | 209 | Args: 210 | gt_csv_path: Ground truth csv location 211 | scores_csv_path: Confidence scores csv path 212 | store_filename: name and location to save resulting csv 213 | """ 214 | gt_df = pd.read_csv(gt_csv_path) 215 | scores_df = pd.read_csv(scores_csv_path) 216 | 217 | gt_label_list = [] 218 | for index, row in gt_df.iterrows(): 219 | arr = np.array(gt_df.iloc[index,1:], dtype=int) 220 | target_idx = np.ravel(np.where(arr == 1)) 221 | j = [object_categories[i] for i in target_idx] 222 | gt_label_list.append(j) 223 | 224 | scores_df.insert(1, "gt", gt_label_list) 225 | scores_df.to_csv(store_filename, index=False) 226 | 227 | 228 | 229 | def
get_classification_accuracy(gt_csv_path, scores_csv_path, store_filename): 230 | """ 231 | Plot the mean tail accuracy (accuracy over each class's top_num highest-scoring images), averaged across all classes, against the decision threshold 232 | 233 | Args: 234 | gt_csv_path: Ground truth csv location 235 | scores_csv_path: Confidence scores csv path 236 | store_filename: name and location to save resulting plot 237 | """ 238 | gt_df = pd.read_csv(gt_csv_path) 239 | scores_df = pd.read_csv(scores_csv_path) 240 | 241 | # Keep only the top_num highest-scoring images per class 242 | top_num = 2800 243 | image_num = 2 244 | num_threshold = 10 245 | results = [] 246 | 247 | for image_num in range(1, 21): 248 | clf = np.sort(np.array(scores_df.iloc[:,image_num], dtype=float))[-top_num:] 249 | ls = np.linspace(0.0, 1.0, num=num_threshold) 250 | 251 | class_results = [] 252 | for i in ls: 253 | clf = np.sort(np.array(scores_df.iloc[:,image_num], dtype=float))[-top_num:] 254 | clf_ind = np.argsort(np.array(scores_df.iloc[:,image_num], dtype=float))[-top_num:] 255 | 256 | # Read ground truth (unsorted, so that it stays row-aligned with clf_ind) 257 | gt = np.array(gt_df.iloc[:,image_num], dtype=int) 258 | 259 | # Now get the ground truth corresponding to the top_num scores 260 | gt = gt[clf_ind] 261 | clf[clf >= i] = 1 262 | clf[clf < i] = 0 263 | 264 | score = accuracy_score(y_true=gt, y_pred=clf, normalize=False)/clf.shape[0] 265 | class_results.append(score) 266 | 267 | results.append(class_results) 268 | 269 | results = np.asarray(results) 270 | 271 | ls = np.linspace(0.0, 1.0, num=num_threshold) 272 | plt.plot(ls, results.mean(0)) 273 | plt.title("Mean Tail Accuracy vs Threshold") 274 | plt.xlabel("Threshold") 275 | plt.ylabel("Mean Tail Accuracy") 276 | plt.savefig(store_filename) 277 | plt.show() 278 | 279 | 280 | #get_classification_accuracy("../models/resnet18/results.csv", "../models/resnet18/gt.csv", "roc-curve.png") 281 | -------------------------------------------------------------------------------- /cmd/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CLSA/37df76cf5cb032683e57b70a3a4090f0d524c8fd/cmd/__init__.py -------------------------------------------------------------------------------- /cmd/run_multi.sh: -------------------------------------------------------------------------------- 1 | python3 main_clsa.py --data=[data_path] --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 1 1 1 --size_crops 224 192 160 128 96 --min_scale_crops 0.2 0.172 0.143 0.114 0.086 --max_scale_crops 1.0 0.86 0.715 0.571 0.429 --pick_strong 0 1 2 3 4 --pick_weak 0 1 2 3 4 --clsa_t 0.2 --sym 0 2 | -------------------------------------------------------------------------------- /cmd/run_single.sh: -------------------------------------------------------------------------------- 1 | python3 main_clsa.py --data=/data/imagenet --workers=32 --epochs=200 --start_epoch=0 --batch_size=256 --lr=0.03 --weight_decay=1e-4 --print_freq=100 --world_size=1 --rank=0 --dist_url=tcp://localhost:10001 --moco_dim=128 --moco_k=65536 --moco_m=0.999 --moco_t=0.2 --alpha=1 --aug_times=5 --nmb_crops 1 1 --size_crops 224 96 --min_scale_crops 0.2 0.086 --max_scale_crops 1.0 0.429 --pick_strong 1 --pick_weak 0 --clsa_t 0.2 --sym 0 2 | -------------------------------------------------------------------------------- /data_processing/Image_ops.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageFilter 2 | import random 3 | class TwoCropsTransform: 4 | """Take two random crops of one image as the query and key.""" 5 | 6 | def __init__(self, base_transform): 7 | self.base_transform = base_transform 8 | 9 | def __call__(self, x): 10 | q = self.base_transform(x) 11 | k = self.base_transform(x) 12 | return [q, k] 13 | 14 | 15 | class 
GaussianBlur(object): 16 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 17 | 18 | def __init__(self, sigma=[.1, 2.]): 19 | self.sigma = sigma 20 | 21 | def __call__(self, x): 22 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 23 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 24 | return x -------------------------------------------------------------------------------- /data_processing/Multi_FixTransform.py: -------------------------------------------------------------------------------- 1 | #modified from https://github.com/facebookresearch/swav/blob/master/src/multicropdataset.py 2 | from torchvision import transforms 3 | from data_processing.RandAugment import RandAugment 4 | from data_processing.Image_ops import GaussianBlur 5 | class Multi_Fixtransform(object): 6 | def __init__(self, 7 | size_crops, 8 | nmb_crops, 9 | min_scale_crops, 10 | max_scale_crops,normalize, 11 | aug_times,init_size=224): 12 | """ 13 | :param size_crops: list of crops with crop output img size 14 | :param nmb_crops: number of output cropped image 15 | :param min_scale_crops: minimum scale for corresponding crop 16 | :param max_scale_crops: maximum scale for corresponding crop 17 | :param normalize: normalize operation 18 | :param aug_times: strong augmentation times 19 | :param init_size: key image size 20 | """ 21 | assert len(size_crops) == len(nmb_crops) 22 | assert len(min_scale_crops) == len(nmb_crops) 23 | assert len(max_scale_crops) == len(nmb_crops) 24 | trans=[] 25 | #key image transform 26 | self.weak = transforms.Compose([ 27 | transforms.RandomResizedCrop(init_size, scale=(0.2, 1.)), 28 | transforms.RandomApply([ 29 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 30 | ], p=0.8), 31 | transforms.RandomGrayscale(p=0.2), 32 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 33 | transforms.RandomHorizontalFlip(), 34 | transforms.ToTensor(), 35 | normalize 36 | ]) 37 | trans.append(self.weak) 38 | 
self.aug_times=aug_times 39 | trans_weak=[] 40 | trans_strong=[] 41 | for i in range(len(size_crops)): 42 | randomresizedcrop = transforms.RandomResizedCrop( 43 | size_crops[i], 44 | scale=(min_scale_crops[i], max_scale_crops[i]), 45 | ) 46 | 47 | strong = transforms.Compose([ 48 | randomresizedcrop, 49 | transforms.RandomApply([ 50 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 51 | ], p=0.8), 52 | transforms.RandomGrayscale(p=0.2), 53 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 54 | transforms.RandomHorizontalFlip(), 55 | RandAugment(n=self.aug_times, m=10), 56 | transforms.ToTensor(), 57 | normalize 58 | ]) 59 | weak=transforms.Compose([ 60 | randomresizedcrop, 61 | transforms.RandomApply([ 62 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 63 | ], p=0.8), 64 | transforms.RandomGrayscale(p=0.2), 65 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 66 | transforms.RandomHorizontalFlip(), 67 | transforms.ToTensor(), 68 | normalize 69 | ]) 70 | trans_weak.extend([weak]*nmb_crops[i]) 71 | trans_strong.extend([strong]*nmb_crops[i]) 72 | trans.extend(trans_weak) 73 | trans.extend(trans_strong) 74 | self.trans=trans 75 | def __call__(self, x): 76 | multi_crops = list(map(lambda trans: trans(x), self.trans)) 77 | return multi_crops 78 | -------------------------------------------------------------------------------- /data_processing/RandAugment.py: -------------------------------------------------------------------------------- 1 | 2 | # code in this file is adapted from 3 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 4 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 5 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 6 | import logging 7 | import random 8 | 9 | import numpy as np 10 | import PIL 11 | import PIL.ImageOps 12 | import PIL.ImageEnhance 13 | import PIL.ImageDraw 14 | from
PIL import Image 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | PARAMETER_MAX = 10 19 | 20 | 21 | def AutoContrast(img, **kwarg): 22 | return PIL.ImageOps.autocontrast(img) 23 | 24 | 25 | def Brightness(img, v, max_v, bias=0): 26 | v = _float_parameter(v, max_v) + bias 27 | return PIL.ImageEnhance.Brightness(img).enhance(v) 28 | 29 | 30 | def Color(img, v, max_v, bias=0): 31 | v = _float_parameter(v, max_v) + bias 32 | return PIL.ImageEnhance.Color(img).enhance(v) 33 | 34 | 35 | def Contrast(img, v, max_v, bias=0): 36 | v = _float_parameter(v, max_v) + bias 37 | return PIL.ImageEnhance.Contrast(img).enhance(v) 38 | 39 | 40 | def Cutout(img, v, max_v, bias=0): 41 | if v == 0: 42 | return img 43 | v = _float_parameter(v, max_v) + bias 44 | v = int(v * min(img.size)) 45 | return CutoutAbs(img, v) 46 | 47 | 48 | def CutoutAbs(img, v, **kwarg): 49 | w, h = img.size 50 | x0 = np.random.uniform(0, w) 51 | y0 = np.random.uniform(0, h) 52 | x0 = int(max(0, x0 - v / 2.)) 53 | y0 = int(max(0, y0 - v / 2.)) 54 | x1 = int(min(w, x0 + v)) 55 | y1 = int(min(h, y0 + v)) 56 | xy = (x0, y0, x1, y1) 57 | # gray 58 | color = (127, 127, 127) 59 | img = img.copy() 60 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 61 | return img 62 | 63 | 64 | def Equalize(img, **kwarg): 65 | return PIL.ImageOps.equalize(img) 66 | 67 | 68 | def Identity(img, **kwarg): 69 | return img 70 | 71 | 72 | def Invert(img, **kwarg): 73 | return PIL.ImageOps.invert(img) 74 | 75 | 76 | def Posterize(img, v, max_v, bias=0): 77 | v = _int_parameter(v, max_v) + bias 78 | return PIL.ImageOps.posterize(img, v) 79 | 80 | 81 | def Rotate(img, v, max_v, bias=0): 82 | v = _int_parameter(v, max_v) + bias 83 | if random.random() < 0.5: 84 | v = -v 85 | return img.rotate(v) 86 | 87 | 88 | def Sharpness(img, v, max_v, bias=0): 89 | v = _float_parameter(v, max_v) + bias 90 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 91 | 92 | 93 | def ShearX(img, v, max_v, bias=0): 94 | v = _float_parameter(v, max_v) + bias 95 
| if random.random() < 0.5: 96 | v = -v 97 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 98 | 99 | 100 | def ShearY(img, v, max_v, bias=0): 101 | v = _float_parameter(v, max_v) + bias 102 | if random.random() < 0.5: 103 | v = -v 104 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 105 | 106 | 107 | def Solarize(img, v, max_v, bias=0): 108 | v = _int_parameter(v, max_v) + bias 109 | return PIL.ImageOps.solarize(img, 256 - v) 110 | 111 | 112 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128): 113 | v = _int_parameter(v, max_v) + bias 114 | if random.random() < 0.5: 115 | v = -v 116 | img_np = np.array(img).astype(int) # np.int was only an alias of int and was removed in NumPy 1.24 117 | img_np = img_np + v 118 | img_np = np.clip(img_np, 0, 255) 119 | img_np = img_np.astype(np.uint8) 120 | img = Image.fromarray(img_np) 121 | return PIL.ImageOps.solarize(img, threshold) 122 | 123 | 124 | def TranslateX(img, v, max_v, bias=0): 125 | v = _float_parameter(v, max_v) + bias 126 | if random.random() < 0.5: 127 | v = -v 128 | v = int(v * img.size[0]) 129 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 130 | 131 | 132 | def TranslateY(img, v, max_v, bias=0): 133 | v = _float_parameter(v, max_v) + bias 134 | if random.random() < 0.5: 135 | v = -v 136 | v = int(v * img.size[1]) 137 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 138 | 139 | 140 | def _float_parameter(v, max_v): 141 | return float(v) * max_v / PARAMETER_MAX 142 | 143 | 144 | def _int_parameter(v, max_v): 145 | return int(v * max_v / PARAMETER_MAX) 146 | 147 | 148 | def fixmatch_augment_pool(): 149 | # FixMatch paper 150 | augs = [(AutoContrast, None, None), 151 | (Brightness, 0.9, 0.05), 152 | (Color, 0.9, 0.05), 153 | (Contrast, 0.9, 0.05), 154 | (Equalize, None, None), 155 | (Identity, None, None), 156 | (Posterize, 4, 4), 157 | (Rotate, 30, 0), 158 | (Sharpness, 0.9, 0.05), 159 | (ShearX, 0.3, 0), 160 | (ShearY, 0.3, 0), 161 | (Solarize, 256, 0), 162 | (TranslateX, 0.3, 0),
163 | (TranslateY, 0.3, 0)] 164 | return augs 165 | 166 | 167 | def my_augment_pool(): 168 | # Test 169 | augs = [(AutoContrast, None, None), 170 | (Brightness, 1.8, 0.1), 171 | (Color, 1.8, 0.1), 172 | (Contrast, 1.8, 0.1), 173 | (Cutout, 0.2, 0), 174 | (Equalize, None, None), 175 | (Invert, None, None), 176 | (Posterize, 4, 4), 177 | (Rotate, 30, 0), 178 | (Sharpness, 1.8, 0.1), 179 | (ShearX, 0.3, 0), 180 | (ShearY, 0.3, 0), 181 | (Solarize, 256, 0), 182 | (SolarizeAdd, 110, 0), 183 | (TranslateX, 0.45, 0), 184 | (TranslateY, 0.45, 0)] 185 | return augs 186 | 187 | 188 | class RandAugmentPC(object): 189 | def __init__(self, n, m): 190 | assert n >= 1 191 | assert 1 <= m <= 10 192 | self.n = n 193 | self.m = m 194 | self.augment_pool = my_augment_pool() 195 | 196 | def __call__(self, img): 197 | ops = random.choices(self.augment_pool, k=self.n) 198 | for op, max_v, bias in ops: 199 | prob = np.random.uniform(0.2, 0.8) 200 | if random.random() + prob >= 1: 201 | img = op(img, v=self.m, max_v=max_v, bias=bias) 202 | img = CutoutAbs(img, 16) 203 | return img 204 | 205 | 206 | class RandAugment(object): 207 | def __init__(self, n, m): 208 | assert n >= 0 209 | assert 1 <= m <= 10 210 | self.n = n 211 | self.m = m 212 | self.augment_pool = fixmatch_augment_pool() 213 | def __call__(self, img): 214 | ops = random.choices(self.augment_pool, k=self.n) 215 | for op, max_v, bias in ops: 216 | v = np.random.randint(1, self.m) 217 | if random.random() < 0.5: 218 | img = op(img, v=v, max_v=max_v, bias=bias) 219 | return img 220 | 221 | -------------------------------------------------------------------------------- /data_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maple-research-lab/CLSA/37df76cf5cb032683e57b70a3a4090f0d524c8fd/data_processing/__init__.py -------------------------------------------------------------------------------- /detection/configs/Base-RCNN-C4-BN.yaml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | RPN: 4 | PRE_NMS_TOPK_TEST: 6000 5 | POST_NMS_TOPK_TEST: 1000 6 | ROI_HEADS: 7 | NAME: "Res5ROIHeadsExtraNorm" 8 | BACKBONE: 9 | FREEZE_AT: 0 10 | RESNETS: 11 | NORM: "SyncBN" 12 | TEST: 13 | PRECISE_BN: 14 | ENABLED: True 15 | SOLVER: 16 | IMS_PER_BATCH: 16 17 | BASE_LR: 0.02 18 | -------------------------------------------------------------------------------- /detection/configs/coco_R_50_C4_2x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-C4-BN.yaml" 2 | MODEL: 3 | MASK_ON: True 4 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 5 | INPUT: 6 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 7 | MIN_SIZE_TEST: 800 8 | DATASETS: 9 | TRAIN: ("coco_2017_train",) 10 | TEST: ("coco_2017_val",) 11 | SOLVER: 12 | STEPS: (120000, 160000) 13 | MAX_ITER: 180000 14 | -------------------------------------------------------------------------------- /detection/configs/coco_R_50_C4_2x_clsa.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "coco_R_50_C4_2x.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "See Instructions" 6 | RESNETS: 7 | STRIDE_IN_1X1: False 8 | INPUT: 9 | FORMAT: "RGB" 10 | -------------------------------------------------------------------------------- /detection/configs/pascal_voc_R_50_C4_24k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-C4-BN.yaml" 2 | MODEL: 3 | MASK_ON: False 4 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 5 | ROI_HEADS: 6 | NUM_CLASSES: 20 7 | INPUT: 8 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 9 | MIN_SIZE_TEST: 800 10 | DATASETS: 11 | TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') 12 | TEST: 
('voc_2007_test',) 13 | SOLVER: 14 | STEPS: (18000, 22000) 15 | MAX_ITER: 24000 16 | WARMUP_ITERS: 100 17 | -------------------------------------------------------------------------------- /detection/configs/pascal_voc_R_50_C4_24k_CLSA.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "pascal_voc_R_50_C4_24k.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | WEIGHTS: "See Instructions" 6 | RESNETS: 7 | STRIDE_IN_1X1: False 8 | INPUT: 9 | FORMAT: "RGB" 10 | -------------------------------------------------------------------------------- /detection/convert-pretrain-to-detectron2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | adopted from https://github.com/facebookresearch/moco/tree/master/detection 4 | Created on Thur Mar 4 14:37:25 2021 5 | 6 | @author: Facebook 7 | """ 8 | import pickle as pkl 9 | import sys 10 | import torch 11 | 12 | if __name__ == "__main__": 13 | input = sys.argv[1] 14 | 15 | obj = torch.load(input, map_location="cpu") 16 | obj = obj["state_dict"] 17 | 18 | newmodel = {} 19 | for k, v in obj.items(): 20 | if not k.startswith("module.encoder_q."): 21 | continue 22 | old_k = k 23 | k = k.replace("module.encoder_q.", "") 24 | if "layer" not in k: 25 | k = "stem." 
+ k 26 | for t in [1, 2, 3, 4]: 27 | k = k.replace("layer{}".format(t), "res{}".format(t + 1)) 28 | for t in [1, 2, 3]: 29 | k = k.replace("bn{}".format(t), "conv{}.norm".format(t)) 30 | k = k.replace("downsample.0", "shortcut") 31 | k = k.replace("downsample.1", "shortcut.norm") 32 | print(old_k, "->", k) 33 | newmodel[k] = v.numpy() 34 | 35 | res = {"model": newmodel, "__author__": "MOCO", "matching_heuristics": True} 36 | 37 | with open(sys.argv[2], "wb") as f: 38 | pkl.dump(res, f) 39 | -------------------------------------------------------------------------------- /detection/train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | adopted from https://github.com/facebookresearch/moco/tree/master/detection 4 | Created on Thur Mar 4 14:37:25 2021 5 | 6 | @author: Facebook 7 | """ 8 | 9 | import os 10 | 11 | from detectron2.checkpoint import DetectionCheckpointer 12 | from detectron2.config import get_cfg 13 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch 14 | from detectron2.evaluation import COCOEvaluator, PascalVOCDetectionEvaluator 15 | from detectron2.layers import get_norm 16 | from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads 17 | 18 | 19 | @ROI_HEADS_REGISTRY.register() 20 | class Res5ROIHeadsExtraNorm(Res5ROIHeads): 21 | """ 22 | As described in the MOCO paper, there is an extra BN layer 23 | following the res5 stage. 
24 | """ 25 | def _build_res5_block(self, cfg): 26 | seq, out_channels = super()._build_res5_block(cfg) 27 | norm = cfg.MODEL.RESNETS.NORM 28 | norm = get_norm(norm, out_channels) 29 | seq.add_module("norm", norm) 30 | return seq, out_channels 31 | 32 | 33 | class Trainer(DefaultTrainer): 34 | @classmethod 35 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 36 | if output_folder is None: 37 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 38 | if "coco" in dataset_name: 39 | return COCOEvaluator(dataset_name, cfg, True, output_folder) 40 | else: 41 | assert "voc" in dataset_name 42 | return PascalVOCDetectionEvaluator(dataset_name) 43 | 44 | 45 | def setup(args): 46 | cfg = get_cfg() 47 | cfg.merge_from_file(args.config_file) 48 | cfg.merge_from_list(args.opts) 49 | cfg.freeze() 50 | default_setup(cfg, args) 51 | return cfg 52 | 53 | 54 | def main(args): 55 | cfg = setup(args) 56 | 57 | if args.eval_only: 58 | model = Trainer.build_model(cfg) 59 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 60 | cfg.MODEL.WEIGHTS, resume=args.resume 61 | ) 62 | res = Trainer.test(cfg, model) 63 | return res 64 | 65 | trainer = Trainer(cfg) 66 | trainer.resume_or_load(resume=args.resume) 67 | return trainer.train() 68 | 69 | 70 | if __name__ == "__main__": 71 | args = default_argument_parser().parse_args() 72 | print("Command Line Args:", args) 73 | launch( 74 | main, 75 | args.num_gpus, 76 | num_machines=args.num_machines, 77 | machine_rank=args.machine_rank, 78 | dist_url=args.dist_url, 79 | args=(args,), 80 | ) 81 | -------------------------------------------------------------------------------- /lincls.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | import warnings 4 | 5 | warnings.filterwarnings('ignore') 6 | import argparse 7 | import builtins 8 | import os 9 | import random 10 | import shutil 11 | import time 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.parallel 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | import torch.optim 19 | import torch.multiprocessing as mp 20 | import torch.utils.data 21 | import torch.utils.data.distributed 22 | import torchvision.transforms as transforms 23 | import torchvision.datasets as datasets 24 | import torchvision.models as models 25 | 26 | from data_processing.Image_ops import GaussianBlur # GaussianBlur is defined in data_processing/Image_ops.py; there is no data_processing/loader.py 27 | from ops.os_operation import mkdir 28 | from training.train_utils import accuracy 29 | 30 | model_names = sorted(name for name in models.__dict__ 31 | if name.islower() and not name.startswith("__") 32 | and callable(models.__dict__[name])) 33 | 34 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 35 | parser.add_argument('--data', type=str, metavar='DIR', 36 | help='path to dataset') 37 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 38 | choices=model_names, 39 | help='model architecture: ' + 40 | ' | '.join(model_names) + 41 | ' (default: resnet50)') 42 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 43 | help='number of data loading workers (default: 32)') 44 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 45 | help='number of total epochs to run') 46 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 47 | help='manual epoch number (useful on restarts)') 48 | parser.add_argument('--batch-size', default=256, type=int, 49 | metavar='N', 50 | help='mini-batch size (default: 256), this is the total ' 51 | 'batch size of all GPUs on the current node when ' 52 | 'using Data Parallel or Distributed Data Parallel') 53 | parser.add_argument('--lr', '--learning-rate', default=10.,
type=float, 54 | metavar='LR', help='initial learning rate', dest='lr') 55 | parser.add_argument('--schedule', default=[15, 25, 30], nargs='*', type=int, 56 | help='learning rate schedule (when to drop lr by a ratio)') # default is for places205 57 | parser.add_argument('--cos', type=int, default=1, 58 | help='use cosine lr schedule') 59 | 60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 61 | help='momentum') 62 | parser.add_argument('--wd', '--weight-decay', default=0., type=float, 63 | metavar='W', help='weight decay (default: 0.)', 64 | dest='weight_decay') 65 | parser.add_argument('-p', '--print-freq', default=10, type=int, 66 | metavar='N', help='print frequency (default: 10)') 67 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 68 | help='path to latest checkpoint (default: none)') 69 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 70 | help='evaluate model on validation set') 71 | parser.add_argument('--world-size', default=1, type=int, 72 | help='number of nodes for distributed training') 73 | parser.add_argument('--rank', default=0, type=int, 74 | help='node rank for distributed training') 75 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 76 | help='url used to set up distributed training') 77 | parser.add_argument('--dist-backend', default='nccl', type=str, 78 | help='distributed backend') 79 | parser.add_argument('--seed', default=None, type=int, 80 | help='seed for initializing training. ') 81 | parser.add_argument('--gpu', default=None, type=int, 82 | help='GPU id to use.') 83 | parser.add_argument('--multiprocessing-distributed', type=int, default=1, 84 | help='Use multi-processing distributed training to launch ' 85 | 'N processes per node, which has N GPUs. 
This is the ' 86 | 'fastest way to use PyTorch for either single node or ' 87 | 'multi node data parallel training') 88 | 89 | parser.add_argument('--pretrained', default='', type=str, 90 | help='path to moco pretrained checkpoint') 91 | parser.add_argument('--choose', type=str, default=None, help="choose gpu for training") 92 | parser.add_argument("--dataset", type=str, default="ImageNet", help="which dataset is used to finetune") 93 | parser.add_argument("--aug", type=int, default=0, help="use augmentation or not during fine tuning") 94 | parser.add_argument("--size_crops", type=int, default=[224, 192, 160, 128, 96], nargs="+", 95 | help="crops resolutions (example: [224, 96])") 96 | parser.add_argument("--min_scale_crops", type=float, default=[0.2, 0.172, 0.143, 0.114, 0.086], nargs="+", 97 | help="argument in RandomResizedCrop (example: [0.14, 0.05])") 98 | parser.add_argument("--max_scale_crops", type=float, default=[1.0, 0.86, 0.715, 0.571, 0.429], nargs="+", 99 | help="argument in RandomResizedCrop (example: [1., 0.14])") 100 | parser.add_argument("--add_crop", type=int, default=0, help="use crop or not in our training dataset") 101 | parser.add_argument("--strong", type=int, default=0, help="use strong augmentation or not") 102 | parser.add_argument("--final_lr", type=float, default=0.01, help="ending learning rate for training") 103 | parser.add_argument("--aug_type", type=int, default=0, help="augmentation type for our condition") 104 | parser.add_argument('--save_path', default="", type=str, help="model and record save path") 105 | parser.add_argument('--log_path', type=str, default="train_log", help="log path for saving models") 106 | parser.add_argument("--nodes_num", type=int, default=1, help="number of nodes to use") 107 | parser.add_argument("--ngpu", type=int, default=8, help="number of gpus per node") 108 | parser.add_argument("--master_addr", type=str, default="127.0.0.1", help="addr for master node") 109 | parser.add_argument("--master_port", 
                    type=str, default="1234", help="port for master node")
parser.add_argument('--node_rank', type=int, default=0, help='rank of machine, 0 to nodes_num-1')
parser.add_argument("--final", default=0, type=int, help="use the final specified augment or not")
parser.add_argument("--avg_pool", default=1, type=int, help="average pool output size")
parser.add_argument("--crop_scale", type=float, default=[0.2, 1.0], nargs="+",
                    help="argument in RandomResizedCrop (example: [0.2, 1.0])")
parser.add_argument("--train_strong", type=int, default=0, help="use stronger augmentation during training or not")
parser.add_argument("--sgdr", type=int, default=0,
                    help="lr scheduler: 0 = adjust_learning_rate (step or cosine), "
                         "1 = cosine annealing, 2 = cosine annealing with warm restarts")
parser.add_argument("--sgdr_t0", type=int, default=10, help="T_0 (epochs in the first cycle) for --sgdr 2")
parser.add_argument("--sgdr_t_mult", type=int, default=1, help="T_mult (cycle length multiplier) for --sgdr 2")
parser.add_argument("--dropout", type=float, default=0.0, help="dropout rate before the linear classifier (0 disables it)")
parser.add_argument("--randcrop", type=int, default=0, help="use RandomCrop instead of RandomResizedCrop for training")
best_acc1 = 0


def main():
    args = parser.parse_args()
    choose = args.choose
    if choose is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = choose
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    params = vars(args)
    data_path = args.data  # the path stored
    args.data = data_path
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu
    params = vars(args)
    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.dataset == "Place205":
        num_classes = 205
    else:
        num_classes = 1000
    model = models.__dict__[args.arch](num_classes=num_classes)

    # freeze all layers but the last fc
    for name, param in model.named_parameters():
        if name not in ['fc.weight', 'fc.bias']:
            param.requires_grad = False

    # init the fc layer
    model.fc.weight.data.normal_(mean=0.0, std=0.01)
    model.fc.bias.data.zero_()

    # load from pre-trained, before DistributedDataParallel constructor
    if args.pretrained:
        if os.path.isfile(args.pretrained):
            print("=> loading checkpoint '{}'".format(args.pretrained))

            checkpoint = torch.load(args.pretrained, map_location="cpu")
            state_dict = checkpoint['state_dict']
            for k in list(state_dict.keys()):
                # retain only encoder_q up to before the embedding layer
                if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                    # remove prefix
                    state_dict[k[len("module.encoder_q."):]] = state_dict[k]
                # delete renamed or unused k
                del state_dict[k]

            args.start_epoch = 0
            msg = model.load_state_dict(state_dict, strict=False)
            assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

            print("=> loaded pre-trained model '{}'".format(args.pretrained))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrained))

    if args.dropout != 0.0:
        model.fc = nn.Sequential(nn.Dropout(args.dropout), model.fc)
    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # optimize only the linear classifier
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    assert len(parameters) == 2  # fc.weight, fc.bias
    optimizer = torch.optim.SGD(parameters, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)

            args.start_epoch = checkpoint['epoch']
            best_acc1 = torch.tensor(checkpoint['best_acc1'])
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if args.dataset == "ImageNet":
        data_path = args.data
        traindir = os.path.join(data_path, 'train')
        valdir = os.path.join(data_path, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if args.train_strong:
            transform_train = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize
            ])
        elif args.randcrop:
            transform_train = transforms.Compose([
                transforms.RandomCrop(224, pad_if_needed=True),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize, ])
        else:
            transform_train = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize, ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.ImageFolder(traindir, transform_train)
        val_dataset = datasets.ImageFolder(valdir, transform_test)

        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            # each process evaluates its own shard; validate() averages the accuracies
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=True)
            # val_sampler=None
        else:
            train_sampler = None
            val_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(
            val_dataset, sampler=val_sampler,
            batch_size=args.batch_size, shuffle=(val_sampler is None),
            # shuffle only without a sampler: DataLoader rejects shuffle=True together with a sampler
            num_workers=args.workers, pin_memory=True)

    elif args.dataset == "Place205":
        from data_processing.Place205_Dataset import Places205
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if args.train_strong:
            if args.randcrop:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(224),
                    transforms.RandomApply([
                        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
                    ], p=0.8),
                    transforms.RandomGrayscale(p=0.2),
                    transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize
                ])
            else:
                transform_train = transforms.Compose([
                    transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
                    transforms.RandomApply([
                        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
                    ], p=0.8),
                    transforms.RandomGrayscale(p=0.2),
                    transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize
                ])
        else:
            if args.randcrop:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize, ])
            else:
                transform_train = transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize, ])
        # TODO: add 10-crop evaluation
        transform_valid = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

        train_dataset = Places205(args.data, 'train', transform_train)
        valid_dataset = Places205(args.data, 'val', transform_valid)
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            val_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset, shuffle=False)
            # val_sampler = None
        else:
            train_sampler = None
            val_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(
            valid_dataset, sampler=val_sampler,
            batch_size=args.batch_size,
            num_workers=args.workers, pin_memory=True)

    else:
        print("dataset %s is not supported for linear evaluation yet (use ImageNet or Place205)" % args.dataset)
        exit()

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return
    import datetime
    today = datetime.date.today()
    formatted_today = today.strftime('%y%m%d')
    now = time.strftime("%H:%M:%S")

    save_path = os.path.join(args.save_path, args.log_path)
    log_path = os.path.join(save_path, 'Finetune_log')
    mkdir(log_path)
    log_path = os.path.join(log_path, formatted_today + now)
    mkdir(log_path)
    # model_path=os.path.join(log_path,'checkpoint.pth.tar')
    lr_scheduler = None
    if args.sgdr == 1:
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 12)
    elif args.sgdr == 2:
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, args.sgdr_t0, args.sgdr_t_mult)
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        if args.sgdr == 0:
            adjust_learning_rate(optimizer, epoch, args)
        train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler)
        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)
        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            # latest checkpoint, saved inside the timestamped log directory
            tmp_save_path = os.path.join(log_path, 'checkpoint.pth.tar')
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best, filename=tmp_save_path)

            if abs(args.epochs - epoch) <= 20:
                tmp_save_path = os.path.join(log_path, 'model_%d.pth.tar' % epoch)
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, False, filename=tmp_save_path)


def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    mAP = AverageMeter("mAP", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5, mAP],
        prefix="Epoch: [{}]".format(epoch))

    """
    Switch to eval mode:
    Under the protocol of linear classification on frozen features/models,
    it is not legitimate to change any part of the pre-trained model.
    BatchNorm in train mode may revise running mean/std (even if it receives
    no gradient), which are part of the model parameters too.
    """
    model.eval()
    batch_total = len(train_loader)
    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        # adjust_batch_learning_rate(optimizer, epoch, i, batch_total, args)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)

        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1.item(), images.size(0))
        top5.update(acc5.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if args.sgdr != 0:
            lr_scheduler.step(epoch + i / batch_total)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def train2(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    mAP = AverageMeter("mAP", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5, mAP],
        prefix="Epoch: [{}]".format(epoch))

    """
    Switch to eval mode:
    Under the protocol of linear classification on frozen features/models,
    it is not legitimate to change any part of the pre-trained model.
    BatchNorm in train mode may revise running mean/std (even if it receives
    no gradient), which are part of the model parameters too.
    """
    model.eval()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            for k in range(len(images)):
                images[k] = images[k].cuda(args.gpu, non_blocking=True)

        target = target.cuda(args.gpu, non_blocking=True)
        len_images = len(images)

        first_output = -1
        for k in range(len_images):
            # compute gradient and do SGD step
            optimizer.zero_grad()
            output = model(images[k])
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            losses.update(loss.item(), images[k].size(0))
            if k == 0:
                first_output = output

        images = images[0]
        output = first_output

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        top1.update(acc1.item(), images.size(0))
        top5.update(acc5.item(), images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    mAP = AverageMeter("mAP", ":6.2f")
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5, mAP],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # move inputs to the same device as the model (as train() does)
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            output = model(images)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # average the per-process accuracies over all processes
            acc1 = torch.mean(concat_all_gather(acc1.unsqueeze(0)), dim=0, keepdim=True)
            acc5 = torch.mean(concat_all_gather(acc5.unsqueeze(0)), dim=0, keepdim=True)
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            loss = criterion(output, target)
            losses.update(loss.item(), images.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} mAP {mAP.avg:.3f} '
              .format(top1=top1, top5=top5, mAP=mAP))

    return top1.avg


def testing(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    mAP = AverageMeter("mAP", ":6.2f")
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5, mAP],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    correct_count = 0
    count_all = 0
    # implement our own random crop
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            target = target.cuda(args.gpu,
                                 non_blocking=True)
            output_list = []
            for image in images:
                if args.gpu is not None:
                    image = image.cuda(args.gpu, non_blocking=True)
                output = model(image)
                output = torch.softmax(output, dim=1)
                output_list.append(output)
            # keep, per class, the highest probability over all crops
            output_list = torch.stack(output_list, dim=0)
            output_list, max_index = torch.max(output_list, dim=0)
            output = output_list
            images = images[0]
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            acc1 = torch.mean(concat_all_gather(acc1.unsqueeze(0)), dim=0, keepdim=True)
            acc5 = torch.mean(concat_all_gather(acc5.unsqueeze(0)), dim=0, keepdim=True)
            correct_count += float(acc1[0]) * images.size(0)
            count_all += images.size(0)
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            loss = criterion(output, target)
            losses.update(loss.item(), images.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} mAP {mAP.avg:.3f} '
          .format(top1=top1, top5=top5, mAP=mAP))
    final_accu = correct_count / count_all
    print("$$our final calculated accuracy %.7f" % final_accu)
    return top1.avg


def testing2(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    mAP = AverageMeter("mAP", ":6.2f")
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5, mAP],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    correct_count = 0
    count_all = 0
    # implement our own random crop
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            target = target.cuda(args.gpu, non_blocking=True)
            output_list = []
            for image in images:
                if args.gpu is not None:
                    image = image.cuda(args.gpu, non_blocking=True)
                output = model(image)
                output = torch.softmax(output, dim=1)
                output_list.append(output)
            # average the class probabilities over all crops
            output_list = torch.stack(output_list, dim=0)
            output_list = torch.mean(output_list, dim=0)
            output = output_list
            images = images[0]
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            acc1 = torch.mean(concat_all_gather(acc1), dim=0, keepdim=True)
            acc5 = torch.mean(concat_all_gather(acc5), dim=0, keepdim=True)
            correct_count += float(acc1[0]) * images.size(0)
            count_all += images.size(0)
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            loss = criterion(output, target)
            losses.update(loss.item(), images.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} mAP {mAP.avg:.3f} '
          .format(top1=top1, top5=top5, mAP=mAP))
    final_accu = correct_count / count_all
    print("$$our final average accuracy %.7f" % final_accu)
    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        root_path = os.path.split(filename)[0]
        best_path = os.path.join(root_path, "model_best.pth.tar")
        shutil.copyfile(filename, best_path)


def sanity_check(state_dict, pretrained_weights):
    """
    Linear classifier should not change any weights other than the linear layer.
    This sanity check asserts nothing wrong happens (e.g., BN stats updated).
    """
    print("=> loading '{}' for sanity check".format(pretrained_weights))
    checkpoint = torch.load(pretrained_weights, map_location="cpu")
    state_dict_pre = checkpoint['state_dict']

    for k in list(state_dict.keys()):
        # only ignore fc layer
        if 'fc.weight' in k or 'fc.bias' in k:
            continue

        # name in pretrained model
        k_pre = 'module.encoder_q.' + k[len('module.'):] \
            if k.startswith('module.') else 'module.encoder_q.' + k

        assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
            '{} is changed in linear classifier training.'.format(k)

    print("=> sanity check passed.")


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


import math


def adjust_learning_rate(optimizer,
                         epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    end_lr = args.final_lr
    # update on cos scheduler
    # this scheduler is not proper enough
    if args.cos:
        lr = 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (lr - end_lr) + end_lr
    else:
        for milestone in args.schedule:
            lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def adjust_batch_learning_rate(optimizer, cur_epoch, cur_batch, batch_total, args):
    """Set the learning rate per batch: decay by 0.1 at each passed milestone,
    then quarter-cosine decay to 10% of that value within the current stage"""
    init_lr = args.lr
    # end_lr=args.final_lr
    # update on cos scheduler
    # this scheduler is not proper enough
    current_schedule = 0
    # use_epoch=cur_epoch
    last_milestone = 0
    for milestone in args.schedule:
        if cur_epoch > milestone:
            current_schedule += 1
            init_lr *= 0.1
            last_milestone = milestone
        else:
            break
    # epoch index relative to the start of the current stage
    # (also applied when cur_epoch is already past the last milestone)
    cur_epoch -= last_milestone
    if current_schedule < len(args.schedule):
        all_epochs = args.schedule[current_schedule]
    else:
        all_epochs = args.epochs
    end_lr = init_lr * 0.1
    progress = (cur_batch + cur_epoch * batch_total) / ((all_epochs - last_milestone) * batch_total)
    lr = math.cos(0.5 * math.pi * progress) * (init_lr - end_lr) + end_lr
    if cur_batch % 50 == 0:
        print("[%d] %d/%d learning rate %.9f" % (cur_epoch, cur_batch, batch_total, lr))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        # single-process run: there is nothing to gather
        return tensor
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/main_clsa.py:
--------------------------------------------------------------------------------
#Copyright (C) 2020 Xiao Wang
#License: MIT for academic use.
#Contact: Xiao Wang (wang3702@purdue.edu, xiaowang20140001@gmail.com)

#Some code adapted from https://github.com/facebookresearch/moco

from ops.argparser import argparser
from ops.Config_Envrionment import Config_Environment
import torch.multiprocessing as mp
from training.main_worker import main_worker


def main(args):
    # config environment
    ngpus_per_node = Config_Environment(args)

    # call training main control function
    if args.multiprocessing_distributed == 1:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


if __name__ == '__main__':
    #use_cuda = torch.cuda.is_available()
    #print("starting check cuda status",use_cuda)
    #if use_cuda:
    args, params = argparser()
    main(args)
--------------------------------------------------------------------------------
/model/CLSA.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class CLSA(nn.Module):

    def __init__(self, base_encoder, args, dim=128, K=65536, m=0.999, T=0.2, mlp=True):
        """
        :param base_encoder: encoder model
        :param args: config parameters
        :param dim: feature dimension (default: 128)
        :param K: queue size; number of negative keys (default: 65536)
        :param m: momentum of updating key encoder (default: 0.999)
        :param T: softmax temperature (default: 0.2)
        :param mlp: use MLP layer to process encoder output or not (default: True)
        """
        super(CLSA, self).__init__()
        self.args = args
        self.K = K
        self.m = m
        self.T = T
        self.T2 = self.args.clsa_t

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not updated by gradient
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)  # unit-normalize each queued key (column)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        # config parameters for CLSA stronger augmentation and multi-crop
        self.weak_pick = args.pick_weak
        self.strong_pick = args.pick_strong
        self.weak_pick = set(self.weak_pick)
        self.strong_pick = set(self.strong_pick)
        self.gpu = args.gpu
        self.sym = self.args.sym

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, queue, queue_ptr, keys):
        # keys arrive already gathered from all GPUs (see forward), so no
        # concat_all_gather(keys) is needed here

        batch_size = keys.shape[0]

        ptr = int(queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q_list, im_k, im_strong_list):
        """
        :param im_q_list: query image list
        :param im_k: key image
        :param im_strong_list: query strong image list
        :return:
            weak: logits0_list, labels0_list (contrastive logits and positive-key indices)
            strong: logits1_list, labels1_list (softmax outputs of the strong views and
                    their soft targets from the weak views)
        """
        if self.sym:
            q_list = []
            for k, im_q in enumerate(im_q_list):  # weak forward
                if k not in self.weak_pick:
                    continue
                # queries are not shuffled: the shuffle gathers across GPUs, which
                # stops the gradient, so it can only be applied to the keys
                # im_q, idx_unshuffle = self._batch_shuffle_ddp(im_q)
                q = self.encoder_q(im_q)  # queries: NxC
                q = nn.functional.normalize(q, dim=1)
                # q = self._batch_unshuffle_ddp(q, idx_unshuffle)
                q_list.append(q)
            # also encode im_k with the query encoder as an extra weak view
            q = self.encoder_q(im_k)
            q = nn.functional.normalize(q, dim=1)
            q_list.append(q)

            q_strong_list = []
            for k, im_strong in enumerate(im_strong_list):
                # im_strong, idx_unshuffle = self._batch_shuffle_ddp(im_strong)
                if k not in self.strong_pick:
                    continue
                q_strong = self.encoder_q(im_strong)  # queries: NxC
                q_strong = nn.functional.normalize(q_strong, dim=1)
                # q_strong = self._batch_unshuffle_ddp(q_strong, idx_unshuffle)
                q_strong_list.append(q_strong)
            with torch.no_grad():  # no gradient to keys
                # if update_key_encoder:
                self._momentum_update_key_encoder()  # update the key encoder

                # shuffle for making use of BN
                im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

                k = self.encoder_k(im_k)  # keys: NxC
                k = nn.functional.normalize(k, dim=1)
                # undo shuffle
                k = self._batch_unshuffle_ddp(k, idx_unshuffle)
                k = k.detach()
                k = concat_all_gather(k)

                # shuffle im_q_list[0] with its own permutation before encoding:
                # un-shuffling an unshuffled batch with im_k's idx_unshuffle
                # would misalign k2 with its queries
                im_k2, idx_unshuffle2 = self._batch_shuffle_ddp(im_q_list[0])
                k2 = self.encoder_k(im_k2)  # keys: NxC
                k2 = nn.functional.normalize(k2, dim=1)
                # undo shuffle
                k2 = self._batch_unshuffle_ddp(k2, idx_unshuffle2)
                k2 = k2.detach()
                k2 = concat_all_gather(k2)
            # N_all below = total batch size gathered from all processes
            logits0_list = []
            labels0_list = []
            logits1_list = []
            labels1_list = []
            # first pass: every weak query except the im_k encoding uses k as positives
            for choose_idx in range(len(q_list) - 1):
                q = q_list[choose_idx]
                # positive logits: N x N_all
                l_pos = torch.einsum('nc,ck->nk', [q, k.T])
                # negative logits: NxK
                l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
                # logits: N x (N_all + K)
                logits = torch.cat([l_pos, l_neg], dim=1)

                # apply temperature
                logits /= self.T

                # labels: positive key indicators; concat_all_gather stacks keys in
                # global-rank order (the local GPU id only matches it on one node)
                cur_batch_size = logits.shape[0]
                cur_rank = torch.distributed.get_rank()
                choose_match = cur_rank * cur_batch_size
                labels = torch.arange(choose_match, choose_match + cur_batch_size, dtype=torch.long).cuda()

                logits0_list.append(logits)
                labels0_list.append(labels)

                # the weak query's distribution, re-tempered with T2, supervises the strong views
                labels0 = logits.clone().detach()
                labels0 = labels0 * self.T / self.T2
                labels0 = torch.softmax(labels0, dim=1)
                labels0 = labels0.detach()
                for choose_idx2 in range(len(q_strong_list)):
                    q_strong = q_strong_list[choose_idx2]
                    # weak strong loss

                    l_pos = torch.einsum('nc,ck->nk', [q_strong, k.T])
                    # negative logits: NxK
                    l_neg = torch.einsum('nc,ck->nk', [q_strong, self.queue.clone().detach()])

                    # logits: N x (N_all + K)
                    logits0 = torch.cat([l_pos, l_neg], dim=1)

                    # apply temperature
                    logits0 /= self.T2
                    logits0 = torch.softmax(logits0, dim=1)

                    logits1_list.append(logits0)
                    labels1_list.append(labels0)
            # second pass (symmetrized): every weak query except the first uses k2 as positives
            k = k2
            for choose_idx in range(1, len(q_list)):
                q = q_list[choose_idx]
                # positive logits: N x N_all
                l_pos = torch.einsum('nc,ck->nk', [q, k.T])
                # negative logits: NxK
                l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
                # logits: N x (N_all + K)
                logits = torch.cat([l_pos, l_neg], dim=1)

                # apply temperature
                logits /= self.T

                # labels: positive key indicators (global-rank block of the gathered keys)
                cur_batch_size = logits.shape[0]
                cur_rank = torch.distributed.get_rank()
                choose_match = cur_rank * cur_batch_size
                labels = torch.arange(choose_match, choose_match + cur_batch_size, dtype=torch.long).cuda()

                logits0_list.append(logits)
                labels0_list.append(labels)

                # the weak query's distribution, re-tempered with T2, supervises the strong views
                labels0 = logits.clone().detach()
                labels0 = labels0 * self.T / self.T2
                labels0 = torch.softmax(labels0, dim=1)
                labels0 = labels0.detach()
                for choose_idx2 in range(len(q_strong_list)):
                    q_strong = q_strong_list[choose_idx2]
                    # weak strong loss

                    l_pos = torch.einsum('nc,ck->nk', [q_strong, k.T])
                    # negative logits: NxK
                    l_neg = torch.einsum('nc,ck->nk', [q_strong, self.queue.clone().detach()])

                    # logits: N x (N_all + K)
                    logits0 = torch.cat([l_pos, l_neg], dim=1)

                    # apply temperature
                    logits0 /= self.T2
                    logits0 = torch.softmax(logits0, dim=1)

                    logits1_list.append(logits0)
                    labels1_list.append(labels0)

            # dequeue and enqueue
            # if update_key_encoder==False:
            self._dequeue_and_enqueue(self.queue, self.queue_ptr, k)

            return logits0_list, labels0_list, logits1_list, labels1_list
        else:
            q_list = []
            for k, im_q in enumerate(im_q_list):  # weak forward
                if k not in self.weak_pick:
                    continue
                # queries are not shuffled: the shuffle gathers across GPUs, which
                # stops the gradient, so it can only be applied to the keys
                # im_q, idx_unshuffle = self._batch_shuffle_ddp(im_q)
                q = self.encoder_q(im_q)  # queries: NxC
                q = nn.functional.normalize(q, dim=1)
                # q = self._batch_unshuffle_ddp(q, idx_unshuffle)
                q_list.append(q)

            q_strong_list = []
            for k, im_strong in enumerate(im_strong_list):
                # im_strong, idx_unshuffle = self._batch_shuffle_ddp(im_strong)
                if k not in self.strong_pick:
                    continue
                q_strong = self.encoder_q(im_strong)  # queries: NxC
                q_strong = nn.functional.normalize(q_strong, dim=1)
                # q_strong = self._batch_unshuffle_ddp(q_strong, idx_unshuffle)
                q_strong_list.append(q_strong)

            # compute key features
            with torch.no_grad():  # no gradient to keys
                # if update_key_encoder:
                self._momentum_update_key_encoder()  # update the key encoder

                # shuffle for making use of BN
                im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

                k = self.encoder_k(im_k)  # keys: NxC
                k = nn.functional.normalize(k, dim=1)

                # undo shuffle
                k = self._batch_unshuffle_ddp(k, idx_unshuffle)
                k = k.detach()
                k = concat_all_gather(k)

            # compute logits
            # Einstein sum is more intuitive
            # N_all below = total batch size gathered from all processes

            logits0_list = []
            labels0_list = []
            logits1_list = []
            labels1_list = []
            for choose_idx in range(len(q_list)):
                q = q_list[choose_idx]

                # positive logits: N x N_all
                l_pos = torch.einsum('nc,ck->nk', [q, k.T])
                # negative logits: NxK
                l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

                # logits: N x (N_all + K)
                logits = torch.cat([l_pos, l_neg], dim=1)

                # apply temperature
                logits /= self.T

                # labels: positive key indicators; concat_all_gather stacks keys in
                # global-rank order (the local GPU id only matches it on one node)
                cur_batch_size = logits.shape[0]
                cur_rank = torch.distributed.get_rank()
                choose_match = cur_rank * cur_batch_size
                labels = torch.arange(choose_match, choose_match + cur_batch_size,
                                      dtype=torch.long).cuda()

                logits0_list.append(logits)
                labels0_list.append(labels)

                # the weak query's distribution, re-tempered with T2, supervises the strong views
                labels0 = logits.clone().detach()
                labels0 = labels0 * self.T / self.T2
                labels0 = torch.softmax(labels0, dim=1)
                labels0 = labels0.detach()
                for choose_idx2 in range(len(q_strong_list)):
                    q_strong = q_strong_list[choose_idx2]
                    # weak strong loss

                    l_pos = torch.einsum('nc,ck->nk', [q_strong, k.T])
                    # negative logits: NxK
                    l_neg = torch.einsum('nc,ck->nk', [q_strong, self.queue.clone().detach()])

                    # logits: N x (N_all + K)
                    logits0 = torch.cat([l_pos, l_neg], dim=1)

                    # apply temperature
                    logits0 /= self.T2
                    logits0 = torch.softmax(logits0, dim=1)

                    logits1_list.append(logits0)
                    labels1_list.append(labels0)

            # dequeue and enqueue
            # if update_key_encoder==False:
            self._dequeue_and_enqueue(self.queue, self.queue_ptr, k)

            return logits0_list, labels0_list, logits1_list, labels1_list


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CLSA/37df76cf5cb032683e57b70a3a4090f0d524c8fd/model/__init__.py
--------------------------------------------------------------------------------
/ops/Config_Envrionment.py:
--------------------------------------------------------------------------------
import os
import resource
import torch
import warnings
import random
import torch.backends.cudnn as cudnn

def Config_Environment(args):
    # raise the soft limit on open file descriptors to 2048
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))

    # config gpu settings
    choose = args.choose
    if choose is not None and args.nodes_num == 1:
        os.environ['CUDA_VISIBLE_DEVICES'] = choose
        print("Currently chosen GPU(s): %s" % choose)
    use_cuda = torch.cuda.is_available()
    print("Cuda status ", use_cuda)
    ngpus_per_node = torch.cuda.device_count()
    print("in total we have ", ngpus_per_node, " gpu")
    if ngpus_per_node <= 0:
        print("No GPU available, exiting.")
        exit()
    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    # init random seed
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    return ngpus_per_node
--------------------------------------------------------------------------------
/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CLSA/37df76cf5cb032683e57b70a3a4090f0d524c8fd/ops/__init__.py
--------------------------------------------------------------------------------
/ops/argparser.py:
--------------------------------------------------------------------------------
import argparse

def argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default="data", type=str, metavar='DIR',
                        help='path to dataset')
    parser.add_argument('--log_path', type=str, default="train_log", help="log path for saving models and logs")
    parser.add_argument('--arch', metavar='ARCH', default='resnet50',
                        type=str,
                        help='model architecture: (default: resnet50)')
    parser.add_argument('--workers', default=32, type=int, metavar='N',
                        help='number of data loading workers (default: 32)')
    parser.add_argument('--epochs', default=200, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch_size', default=256, type=int,
                        metavar='N',
                        help='mini-batch size (default: 256), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('--lr', '--learning_rate', default=0.03, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum of SGD solver')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--print_freq', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--world_size', default=-1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int,
                        help='node rank for distributed training; rank of total threads, 0 to args.world_size-1')
    parser.add_argument('--dist_url', default='tcp://localhost:10001', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')
    parser.add_argument('--multiprocessing_distributed', type=int, default=1,
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    parser.add_argument("--nodes_num", type=int, default=1, help="number of nodes to use")
    parser.add_argument('--dataset', type=str, default="ImageNet", help="Specify dataset: default: ImageNet")

    # Baseline: moco specific configs:
    parser.add_argument('--moco_dim', default=128, type=int,
                        help='feature dimension (default: 128)')
    parser.add_argument('--moco_k', default=65536, type=int,
                        help='queue size; number of negative keys (default: 65536)')
    parser.add_argument('--moco_m', default=0.999, type=float,
                        help='moco momentum of updating key encoder (default: 0.999)')
    parser.add_argument('--moco_t', default=0.2, type=float,
                        help='softmax temperature (default: 0.2)')
    parser.add_argument('--mlp', type=int, default=1,
                        help='use mlp head')
    parser.add_argument('--cos', type=int, default=1,
                        help='lr schedule: 1 = cosine, 2 = quarter-period cosine, other = constant')
    parser.add_argument('--choose', type=str, default=None,
                        help="choose gpu for training, default:None(Use all available GPUs)")

    # clsa parameter configuration
    parser.add_argument('--alpha', type=float, default=1,
                        help="coefficient (weight) of the DDM loss")
    parser.add_argument('--aug_times', type=int, default=5,
                        help="random augmentation times in strong augmentation")
    # multi-crop idea borrowed from SwAV
    parser.add_argument("--nmb_crops", type=int, default=[1, 1, 1, 1, 1], nargs="+",
                        help="list of number of crops (example: [2, 6])")  # a value of 0 means multi-crop is not applied
    parser.add_argument("--size_crops", type=int, default=[224, 192, 160, 128, 96], nargs="+",
                        help="crops resolutions (example: [224, 96])")
    parser.add_argument("--min_scale_crops", type=float, default=[0.2, 0.172, 0.143, 0.114, 0.086], nargs="+",
                        help="min scale crop argument in RandomResizedCrop ")
    parser.add_argument("--max_scale_crops", type=float, default=[1.0, 0.86, 0.715, 0.571, 0.429], nargs="+",
                        help="max scale crop argument in RandomResizedCrop ")
    parser.add_argument("--pick_strong", type=int, default=[0, 1, 2, 3, 4], nargs="+",
                        help="specify the strong augmentation that will be used ")
    parser.add_argument("--pick_weak", type=int, default=[0, 1, 2, 3, 4], nargs="+",
                        help="specify the weak augmentation that will be used ")
    parser.add_argument("--clsa_t", type=float, default=0.2, help="temperature used for ddm loss")
    parser.add_argument("--sym", type=int, default=0, help="apply the symmetrized loss or not (default: 0, off)")
    args = parser.parse_args()
    params = vars(args)
    return args, params
--------------------------------------------------------------------------------
/ops/os_operation.py:
--------------------------------------------------------------------------------
# Publication: "Protein Docking Model Evaluation by Graph Neural Networks", Xiao Wang, Sean T Flannery and Daisuke Kihara, (2020)

#GNN-Dove is a computational tool using graph neural network that can evaluate the quality of docking protein-complexes.

#Copyright (C) 2020 Xiao Wang, Sean T Flannery, Daisuke Kihara, and Purdue University.

#License: GPL v3 for academic use. (For commercial use, please contact us for different licensing.)

#Contact: Daisuke Kihara (dkihara@purdue.edu)

#

# This program is free software: you can redistribute it and/or modify

# it under the terms of the GNU General Public License as published by

# the Free Software Foundation, version 3.

#

# This program is distributed in the hope that it will be useful,

# but WITHOUT ANY WARRANTY; without even the implied warranty of

# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the

# GNU General Public License V3 for more details.
#

# You should have received a copy of the GNU v3.0 General Public License

# along with this program.  If not, see https://www.gnu.org/licenses/gpl-3.0.en.html.

import os
def mkdir(path):
    path = path.strip()
    path = path.rstrip("\\")
    isExists = os.path.exists(path)
    if not isExists:
        print(path + " created")
        os.makedirs(path)
        return True
    else:
        print(path + ' existed')
        return False
def execCmd(cmd):
    r = os.popen(cmd)
    text = r.read()
    r.close()
    return text
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.7.1
torchvision==0.8.2
numpy==1.19.5
Pillow==5.1.0
tensorboard==1.14.0
tensorboardX==1.7
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CLSA/37df76cf5cb032683e57b70a3a4090f0d524c8fd/training/__init__.py
--------------------------------------------------------------------------------
/training/main_worker.py:
--------------------------------------------------------------------------------
import builtins
import torch.distributed as dist
import os
import torchvision.models as models
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import datetime
import time

from model.CLSA import CLSA
from ops.os_operation import mkdir
from data_processing.Multi_FixTransform import Multi_Fixtransform
from training.train_utils import adjust_learning_rate, save_checkpoint
from training.train import train


def init_log_path(args):
    """
    :param args: config parameters
    :return:
        path of the directory where models and logs are saved
    """
    save_path = os.path.join(os.getcwd(), args.log_path)
    mkdir(save_path)
    save_path = os.path.join(save_path, args.dataset)
    mkdir(save_path)
    save_path = os.path.join(save_path, "Alpha_" + str(args.alpha))
    mkdir(save_path)
    save_path = os.path.join(save_path, "Aug_" + str(args.aug_times))
    mkdir(save_path)
    save_path = os.path.join(save_path, "lr_" + str(args.lr))
    mkdir(save_path)
    save_path = os.path.join(save_path, "cos_" + str(args.cos))
    mkdir(save_path)
    today = datetime.date.today()
    formatted_today = today.strftime('%y%m%d')
    now = time.strftime("%H:%M:%S")
    save_path = os.path.join(save_path, formatted_today + now)
    mkdir(save_path)
    return save_path


def main_worker(gpu, ngpus_per_node, args):
    """
    :param gpu: current gpu id
    :param ngpus_per_node: number of gpus in one node
    :param args: config parameter
    :return:
        sets up (distributed) training and runs the training loop
    """
    params = vars(args)
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    print("=> creating model '{}'".format(args.arch))
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # init model
    model = CLSA(models.__dict__[args.arch], args,
                 args.moco_dim,
                 args.moco_k, args.moco_m, args.moco_t, args.mlp)
    print(model)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)

            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            exit()

    cudnn.benchmark = True
    # config data loader
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    fix_transform = Multi_Fixtransform(args.size_crops,
                                       args.nmb_crops,
                                       args.min_scale_crops,
                                       args.max_scale_crops, normalize, args.aug_times)
    traindir = os.path.join(args.data, 'train')
    train_dataset = datasets.ImageFolder(
        traindir,
        fix_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    save_path = init_log_path(args)  # config model save path and log path
    log_path = os.path.join(save_path, "train.log")
    best_Acc = 0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)
        acc1 = train(train_loader, model, criterion, optimizer, epoch, args, log_path)
        # a new best is reached when this epoch's accuracy exceeds the previous best
        is_best = acc1 > best_Acc
        best_Acc = max(best_Acc, acc1)
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            save_dict = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'best_acc': best_Acc,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }

            if epoch % 10 == 9:
                tmp_save_path = os.path.join(save_path, 'checkpoint_{:04d}.pth.tar'.format(epoch))
                save_checkpoint(save_dict, is_best=False, filename=tmp_save_path)
            tmp_save_path = os.path.join(save_path, 'checkpoint_best.pth.tar')
            save_checkpoint(save_dict, is_best=is_best, filename=tmp_save_path)
--------------------------------------------------------------------------------
/training/train.py:
--------------------------------------------------------------------------------
import time
import torch.nn as nn
import torch

from training.train_utils import AverageMeter, ProgressMeter, accuracy

def train(train_loader, model, criterion, optimizer, epoch, args, log_path):
    """
    :param train_loader: data loader
    :param model: training model
    :param criterion: loss function
    :param optimizer: SGD optimizer
    :param epoch: current epoch
    :param args: config parameter
    :param log_path: file that progress records are appended to
    :return: average top-1 accuracy of the contrastive classifier over the epoch
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    mse_criterion = nn.MSELoss().cuda(args.gpu)
    for i, (images, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            len_images = len(images)
            for k in range(len(images)):
                images[k] = images[k].cuda(args.gpu, non_blocking=True)
            crop_copy_length = int((len_images - 1) / 2)
            image_k = images[0]
            image_q = images[1:1 + crop_copy_length]
            image_strong = images[1 + crop_copy_length:]

        output, target, output2, target2 = model(image_q, image_k, image_strong)
        loss_contrastive = 0
        loss_weak_strong = 0
        if epoch == 0 and i == 0:
            print("-" * 100)
            print("contrastive loss count %d" % len(output))
            print("weak strong loss count %d" % len(output2))
            print("-" * 100)
        for k in range(len(output)):
            loss1 = criterion(output[k], target[k])
            loss_contrastive += loss1
        for k in range(len(output2)):
            loss2 = -torch.mean(torch.sum(torch.log(output2[k]) * target2[k], dim=1))  # DDM loss
            loss_weak_strong += loss2
        loss = loss_contrastive + args.alpha * loss_weak_strong
        # acc1/acc5 are (K+1)-way contrast classifier accuracy
        # measure accuracy and record loss
        acc1, acc5 = accuracy(output[0], target[0], topk=(1, 5))
        losses.update(loss.item(), images[0].size(0))
        top1.update(acc1[0], images[0].size(0))
        top5.update(acc5[0], images[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            progress.write_record(i, log_path)
    return top1.avg
--------------------------------------------------------------------------------
/training/train_utils.py:
--------------------------------------------------------------------------------
import math
import torch
import shutil
import os

def adjust_learning_rate(optimizer, epoch, args):
    """
    :param optimizer: SGD optimizer
    :param epoch: current epoch
    :param args: args
    :return:
        Decay the learning rate based on schedule
    """

    lr = args.lr
    if args.cos == 1:  # cosine lr schedule (half period: lr -> 0 over args.epochs)
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    elif args.cos == 2:  # quarter-period cosine: lr * cos(pi/2 * epoch / args.epochs)
        lr *= math.cos(math.pi * epoch / (args.epochs * 2))
    else:  # constant lr (no decay)
        lr = args.lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def save_checkpoint(state, is_best, filename):
    torch.save(state, filename)
    if is_best:
        root_path = os.path.split(filename)[0]
        best_model_path = os.path.join(root_path, "model_best.pth.tar")
        shutil.copyfile(filename, best_model_path)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

    def write_record(self, batch, filename):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        with open(filename, "a+") as file:
            file.write('\t'.join(entries) + "\n")

def accuracy(output, target, topk=(1,)):
    """
    :param output: predicted scores (logits)
    :param target: ground truth
    :param topk: top k predictions considered
    :return:
        Computes the accuracy over the k top predictions for the specified values of k
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
--------------------------------------------------------------------------------
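The `_dequeue_and_enqueue` method in `model/CLSA.py` treats the `(dim, K)` queue as a ring buffer: each step overwrites the oldest `batch_size` columns starting at `queue_ptr`, and the `K % batch_size == 0` assertion ensures a write never runs past column `K`. The sketch below isolates just that pointer arithmetic, assuming a plain Python list of key ids in place of feature columns and a constant batch size; the `RingQueue` class is hypothetical and not part of the repository.

```python
class RingQueue:
    """Fixed-size FIFO of keys, mirroring the pointer logic of CLSA._dequeue_and_enqueue."""

    def __init__(self, K):
        self.K = K
        self.slots = [None] * K  # stands in for the (dim, K) queue buffer
        self.ptr = 0  # stands in for queue_ptr[0]

    def enqueue(self, keys):
        batch_size = len(keys)
        # same simplification as the original: with a constant batch size,
        # batches tile the queue exactly, so a write never wraps past the end
        assert self.K % batch_size == 0
        # overwrite the oldest batch_size slots (dequeue and enqueue in one step)
        self.slots[self.ptr:self.ptr + batch_size] = keys
        # advance the pointer, wrapping back to 0 after the last slot
        self.ptr = (self.ptr + batch_size) % self.K


q = RingQueue(6)
for batch in ([1, 2], [3, 4], [5, 6], [7, 8]):
    q.enqueue(batch)
print(q.slots, q.ptr)  # [7, 8, 3, 4, 5, 6] 2
```

In `CLSA.forward` the keys handed to `_dequeue_and_enqueue` have already been gathered from every process, so the relevant `batch_size` is the global one (per-GPU batch times world size), and `K` has to be divisible by that number.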
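In `CLSA.forward`, every query is scored against all keys gathered by `concat_all_gather` (so `l_pos` has `world_size * N` columns) followed by the `K` queue entries. The positive for local query `i` is therefore column `rank * N + i`, where `rank` is the process's slot in the all-gather output (on a single node this coincides with the local GPU id). The sketch below reproduces only that index bookkeeping in pure Python; `positive_columns` and `column_owner` are hypothetical helper names used for illustration.

```python
def positive_columns(rank, batch_size):
    # columns of [l_pos | l_neg] holding the positive key for local queries
    # 0..batch_size-1: this process's keys occupy block `rank` of the gathered keys
    start = rank * batch_size
    return list(range(start, start + batch_size))


def column_owner(column, batch_size, world_size):
    # which process contributed a given column; columns at or past
    # world_size * batch_size are queue negatives, not keys from this step
    if column >= world_size * batch_size:
        return None
    return column // batch_size


# 3 processes with 4 samples each: 12 gathered key columns, then the queue
print(positive_columns(2, 4))  # [8, 9, 10, 11]
print(column_owner(9, 4, 3), column_owner(12, 4, 3))  # 2 None
```

Every other gathered key in the same row acts as an extra in-batch negative, which is why the label is an index into the full `N x (world_size * N + K)` logit matrix rather than a fixed column 0.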
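The weak-to-strong term in `training/train.py` (`loss2`, commented as the DDM loss) is a cross-entropy between two distributions over the same candidates: the detached weak-view softmax `labels0` and the strong-view softmax `logits0` built in `CLSA.forward`. Because `labels0` rescales the weak logits by `T / T2`, both distributions end up at temperature `T2`. A minimal sketch for a single query row follows, assuming plain lists of cosine similarities in place of tensors; `softmax` and `ddm_loss_row` are hypothetical helpers, not repository functions.

```python
import math


def softmax(xs):
    # subtract the max before exponentiating for numerical stability
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def ddm_loss_row(weak_sims, strong_sims, T=0.2, T2=0.2):
    """Weak-to-strong cross-entropy for one query row.

    weak_sims / strong_sims: cosine similarities of the weak and strong views
    of one image against the same candidates (gathered keys, then the queue).
    """
    # target, as in CLSA.forward: logits = sims / T, then rescaled by T / T2
    target = softmax([(s / T) * T / T2 for s in weak_sims])
    # strong-view prediction at temperature T2
    pred = softmax([s / T2 for s in strong_sims])
    # -sum(target * log(pred)); train.py averages this over the batch rows
    return -sum(p * math.log(q) for p, q in zip(target, pred))


weak = [0.9, 0.1, -0.2]
print(ddm_loss_row(weak, weak) < ddm_loss_row(weak, [0.1, 0.9, -0.2]))  # True
```

By Gibbs' inequality the value never drops below the entropy of the weak-view target, and it reaches that floor only when the strong view reproduces the weak view's distribution, which is the retrieval behavior the term is meant to encourage.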
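The decay curves in `adjust_learning_rate` (selected by `--cos`) can be compared without a GPU by evaluating the same expressions in plain Python. The `lr_at` function below is a hypothetical standalone restatement for illustration, with `base_lr=0.03` and `epochs=200` taken from the `--lr` and `--epochs` defaults in `ops/argparser.py`.

```python
import math


def lr_at(epoch, base_lr=0.03, epochs=200, cos=1):
    # --cos 1: half-period cosine, from base_lr at epoch 0 toward 0 at `epochs`
    if cos == 1:
        return base_lr * 0.5 * (1. + math.cos(math.pi * epoch / epochs))
    # --cos 2: quarter-period cosine, base_lr * cos(pi/2 * epoch / epochs)
    if cos == 2:
        return base_lr * math.cos(math.pi * epoch / (epochs * 2))
    # any other value: constant learning rate
    return base_lr


for e in (0, 50, 100, 150, 199):
    print(e, round(lr_at(e, cos=1), 5), round(lr_at(e, cos=2), 5))
```

With x = pi * epoch / (2 * epochs), the `--cos 1` factor 0.5 * (1 + cos 2x) equals cos^2(x) while the `--cos 2` factor is cos(x); since 0 <= cos(x) <= 1 on this range, `--cos 2` keeps a learning rate at least as high at every epoch before both reach zero.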