├── .gitignore
├── README.md
├── contrast
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── rand_augment.py
│   │   ├── sampler.py
│   │   ├── transform.py
│   │   └── zipreader.py
│   ├── lars.py
│   ├── logger.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── InstDisc.py
│   │   ├── MoCo.py
│   │   ├── PIC.py
│   │   ├── SimCLR.py
│   │   ├── __init__.py
│   │   └── base.py
│   ├── option.py
│   ├── resnet.py
│   └── util.py
├── main_linear.py
├── main_pretrain.py
└── scripts
    ├── InstDisc.sh
    ├── MoCov1.sh
    ├── MoCov2.sh
    ├── PIC.sh
    └── SimCLR.sh

/.gitignore:
--------------------------------------------------------------------------------
.idea/
data/
output*/
ckpts/
*.pth
*.t7
*.png
*.jpg
tmp*.py
# run*.sh
*.pdf


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.DS_Store
backup/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Parametric Instance Classification for Unsupervised Visual Feature Learning

By [Yue Cao](http://yue-cao.me)\*, [Zhenda Xie](https://scholar.google.com/citations?user=0C4cDloAAAAJ)\*, [Bin Liu](https://scholar.google.com/citations?user=-RYlJvYAAAAJ)\*, [Yutong Lin](https://scholar.google.com/citations?user=mjUgH44AAAAJ), [Zheng Zhang](https://www.microsoft.com/en-us/research/people/zhez/), [Han Hu](https://ancientmooner.github.io/).

This repo is an official implementation of ["Parametric Instance Classification for Unsupervised Visual Feature Learning"](https://arxiv.org/abs/2006.14618v1) in PyTorch.
It also contains unofficial implementations of several popular unsupervised visual feature learning methods, including [InstDisc](https://arxiv.org/abs/1805.01978.pdf), [MoCo](https://arxiv.org/abs/1911.05722), [MoCo v2](https://arxiv.org/abs/2003.04297) and [SimCLR](https://arxiv.org/abs/2002.05709).

*Update on 2020/09/26*

Our paper was accepted by NeurIPS 2020!

## Introduction

This paper presents parametric instance classification (PIC) for unsupervised visual feature learning. Unlike the state-of-the-art approaches, which perform instance discrimination in a dual-branch, non-parametric fashion, PIC directly performs one-branch parametric instance classification, revealing a simple framework similar to supervised classification that does not need to address the information leakage issue.

We show that the simple PIC framework can be as effective as the state-of-the-art approaches, i.e. SimCLR and MoCo v2, by adapting several common component settings used in those approaches.

We also propose two novel techniques to further improve the effectiveness and practicality of PIC:

1. A sliding-window data scheduler, instead of the previous epoch-based data scheduler, which addresses the extremely infrequent instance visiting issue in PIC and improves its effectiveness;
2. A negative sampling and weight update correction approach to reduce training time and GPU memory consumption, which also enables applying PIC to an almost unlimited number of training images.

We hope that the PIC framework can serve as a simple baseline to facilitate future study.

## Citation

```
@inproceedings{cao2020PIC,
  title={Parametric Instance Classification for Unsupervised Visual Feature Learning},
  author={Cao, Yue and Xie, Zhenda and Liu, Bin and Lin, Yutong and Zhang, Zheng and Hu, Han},
  booktitle={Advances in Neural Information Processing Systems},
  year={2020}
}
```

## Main Results

|            | #aug/iter X #epoch | Top-1 | Top-5 | Model |
| ---------- | ------------------ | ----- | ----- | ----- |
| InstDisc   | 1X200  | 60.6 | 82.6 | [download](https://drive.google.com/file/d/1ilTo2Lk0D8MIrLMY2FA9s2QtHnSfy35G/view?usp=sharing) |
| SimCLR     | 2X100  | 64.7 | 86.0 | |
| MoCo v2    | 2X100  | 64.6 | 85.9 | [download](https://drive.google.com/file/d/1dhOg2AZRhw42SOiXFmXedrRXhMY5gPOh/view?usp=sharing) |
| PIC (ours) | 1X200  | 68.6 | 88.8 | [download](https://drive.google.com/file/d/1eqtLv_RrBCgSEDhte6PueqAFlenaH50k/view?usp=sharing) |
| InstDisc   | 1X400  | 62.7 | 84.6 | [download](https://drive.google.com/file/d/1bWHvEZ9vyidtCVBZzkyTxGLmj7Vyla1j/view?usp=sharing) |
| SimCLR     | 2X200  | 66.6 | 87.3 | |
| MoCo v2    | 2X200  | 67.9 | 88.1 | [download](https://drive.google.com/file/d/1Y-PlmcFSLanIDjYr6Z2fPSYamWt4DO_7/view?usp=sharing) |
| PIC (ours) | 1X400  | 70.3 | 89.8 | [download](https://drive.google.com/file/d/1JdDfPr78BY_0MPeN2r_TX79KObyE-QO4/view?usp=sharing) |
| PIC (ours) | 1X1600 | 70.8 | 90.0 | |

**Notes**:

* InstDisc and MoCo v2 refer to our re-implementations of InstDisc and MoCo v2. All model checkpoints can be found on [Google Drive](https://drive.google.com/drive/folders/12ihazCK8iogX3pvNA5tRJiLpcxxCiXOc?usp=sharing).

* To achieve better performance, our PIC adopts the multi-crop strategy proposed in [SwAV](https://arxiv.org/abs/2006.09882). In each iteration, one 160 x 160 crop and three 96 x 96 crops of an image are fed into the model. With similar memory and computation cost, this multi-crop variant achieves better performance than the original single-crop PIC model (a sketch of the transform follows below).
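For reference, the sketch below shows roughly how such a multi-crop transform can be assembled; it is a condensed version of the `MultiCrop` branch of `get_transform` in [contrast/data/transform.py](./contrast/data/transform.py), not the repo's exact code. The crop sizes follow the note above; the full pipeline additionally applies color jitter, random grayscale and Gaussian blur to every crop, and the helper names here (`multi_crop_transform`, `size_large`, `size_small`) are illustrative.

```
import torch
from torchvision import transforms


class MultiTransform:
    """Apply the same transform to an image k times and stack the results."""

    def __init__(self, transform, k):
        self.transform = transform
        self.k = k

    def __call__(self, img):
        return torch.stack([self.transform(img) for _ in range(self.k)])


def multi_crop_transform(size_large=160, size_small=96, crop=0.08, crop2=0.14):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    def pipeline(size, scale):
        # color jitter / grayscale / blur omitted here for brevity
        return transforms.Compose([
            transforms.RandomResizedCrop(size, scale=scale),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    # one large crop per image, plus three small crops per image
    return (MultiTransform(pipeline(size_large, (crop2, 1.0)), k=1),
            MultiTransform(pipeline(size_small, (crop, crop2)), k=3))
```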
## Getting started

### Tested Environment

- `Anaconda` with `python >= 3.6`
- `pytorch=1.4, torchvision, cuda>=9.2`
- [Optional] `Apex` for automatic mixed precision: refer to https://github.com/NVIDIA/apex#quick-start
- Others: `pip install termcolor opencv-python tensorboard`

### Datasets

We use the standard ImageNet dataset to pre-train the model; download it from http://www.image-net.org/ and unzip it.

* For the standard folder dataset, move the validation images into labeled subfolders using [the following shell script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).

* Since reading images from massive numbers of small files can be slow, we also support zipped ImageNet, which consists of four files:
  * `train.zip`, `val.zip`: the zipped folders of the train and val splits.
  * `train_map.txt`, `val_map.txt`: the relative path of each image inside the corresponding zip file together with its ground-truth label.

Make sure the data folder looks like this:

```
$ tree data
data
└── ImageNet-Zip
    ├── train_map.txt
    ├── train.zip
    ├── val_map.txt
    └── val.zip

$ head -n 5 data/ImageNet-Zip/val_map.txt
ILSVRC2012_val_00000001.JPEG 65
ILSVRC2012_val_00000002.JPEG 970
ILSVRC2012_val_00000003.JPEG 230
ILSVRC2012_val_00000004.JPEG 809
ILSVRC2012_val_00000005.JPEG 516

$ head -n 5 data/ImageNet-Zip/train_map.txt
n01440764/n01440764_10026.JPEG 0
n01440764/n01440764_10027.JPEG 0
n01440764/n01440764_10029.JPEG 0
n01440764/n01440764_10040.JPEG 0
n01440764/n01440764_10042.JPEG 0
```

### Unsupervised training and linear evaluation

The implementation only supports **DistributedDataParallel** training with multiple GPUs.

To run PIC pre-training and linear evaluation of a ResNet-50 model on ImageNet on a 4-GPU machine, run:

```
epochs=200
data_dir="./data/ImageNet-Zip"
output_dir="./output/PIC/epochs-${epochs}"

python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \
    main_pretrain.py \
    --data-dir ${data_dir} \
    --aug SimCLR \
    --crop 0.08 \
    --contrast-temperature 0.2 \
    --use-sliding-window-sampler \
    --model PIC \
    --mlp-head \
    --epochs ${epochs} \
    --output-dir ${output_dir}

python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \
    main_linear.py \
    --data-dir ${data_dir} \
    --output-dir ${output_dir}/eval \
    --pretrained-model ${output_dir}/current.pth
```

**Notes**:

* To use zipped ImageNet instead of the folder dataset, add `--zip` to the parameters.

* To cache the dataset in memory instead of reading from files every time, add `--cache-mode part`, which shards the dataset into non-overlapping pieces, one per GPU, so that each process only loads from its own shard.

* To enable automatic mixed precision training, add `--amp-opt-level O1`.

* We provide scripts in `./scripts` to help reproduce the results of our PIC and the other methods. For example, to reproduce the results of PIC trained for 400 epochs, just run:

  ```
  bash scripts/PIC.sh 400
  ```

* For additional options, run `python main_pretrain.py --help` and `python main_linear.py --help` to get the help messages, or refer to [./contrast/option.py](./contrast/option.py).
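The `--use-sliding-window-sampler` flag above enables the sliding-window data scheduler from the paper. As a rough sketch of the idea behind `SubsetSlidingWindowSampler` in [contrast/data/sampler.py](./contrast/data/sampler.py) (condensed, not a line-for-line copy): each "epoch" iterates over a fixed-size window of a shuffled index list, and the window slides forward by a stride after every epoch, so instances are revisited far more regularly than with an epoch-based scheduler.

```
import numpy as np


class SlidingWindowSampler:
    """Each pass yields a shuffled, fixed-size window of the index list,
    then slides the window forward by `window_stride`."""

    def __init__(self, indices, window_size, window_stride):
        self.indices = np.asarray(indices).copy()
        np.random.shuffle(self.indices)
        self.window_size = window_size
        self.window_stride = window_stride
        self.start = 0

    def __iter__(self):
        # take window_size consecutive positions, wrapping around the end
        pos = (np.arange(self.window_size) + self.start) % len(self.indices)
        window = self.indices[pos].copy()
        np.random.shuffle(window)  # shuffle within the current window
        self.start = (self.start + self.window_stride) % len(self.indices)
        return iter(window.tolist())

    def __len__(self):
        return self.window_size
```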
## Known Issues

* For longer training (e.g. training for 1600 epochs), the current implementation is unstable and the loss may become *NaN* during training. We are still trying to figure out the cause of this phenomenon.
* The negative sampling and weight update correction techniques have not been implemented in this version. We will add them in the near future.

## References

Our testbed builds upon several existing publicly available codebases. Specifically, we have modified and integrated the following code into this project:

* https://github.com/zhirongw/lemniscate.pytorch
* https://github.com/facebookresearch/moco
* https://github.com/HobbitLong/CMC
--------------------------------------------------------------------------------
/contrast/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bl0/PIC/21fa56aa538689baa20ff4992204cb87fa8276ad/contrast/__init__.py
--------------------------------------------------------------------------------
/contrast/data/__init__.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import torch.distributed as dist
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.utils.data.distributed import DistributedSampler

from .transform import get_transform
from .dataset import ImageFolder
from .sampler import SubsetSlidingWindowSampler


def get_loader(aug_type, args, two_crop=False, prefix='train'):
    transform = get_transform(aug_type, args.crop, args.image_size, args.num_crop,
                              args.crop2, args.image_size2, args.num_crop2)

    # dataset
    if args.zip:
        train_ann_file = prefix + "_map.txt"
        train_prefix = prefix + ".zip@/"
        train_dataset = ImageFolder(args.data_dir, train_ann_file, train_prefix,
                                    transform, two_crop=two_crop, cache_mode=args.cache_mode)
    else:
        train_folder = os.path.join(args.data_dir, prefix)
        train_dataset = ImageFolder(train_folder, transform=transform, two_crop=two_crop)

    # sampler
    indices = np.arange(dist.get_rank(), len(train_dataset), dist.get_world_size())
    if args.use_sliding_window_sampler:
        sampler = SubsetSlidingWindowSampler(indices,
                                             window_stride=args.window_stride // dist.get_world_size(),
                                             window_size=args.window_size // dist.get_world_size(),
                                             shuffle_per_epoch=args.shuffle_per_epoch)
    elif args.zip and args.cache_mode == 'part':
        sampler = SubsetRandomSampler(indices)
    else:
        sampler = DistributedSampler(train_dataset)

    # dataloader
    return DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False,
                      num_workers=args.num_workers, pin_memory=True, sampler=sampler, drop_last=True)
--------------------------------------------------------------------------------
/contrast/data/dataset.py:
--------------------------------------------------------------------------------
import io
import logging
import os
import time

import torch.distributed as dist
import torch.utils.data as data
8 | from PIL import Image 9 | 10 | from .zipreader import is_zip_path, ZipReader 11 | 12 | 13 | def has_file_allowed_extension(filename, extensions): 14 | """Checks if a file is an allowed extension. 15 | Args: 16 | filename (string): path to a file 17 | Returns: 18 | bool: True if the filename ends with a known image extension 19 | """ 20 | filename_lower = filename.lower() 21 | return any(filename_lower.endswith(ext) for ext in extensions) 22 | 23 | 24 | def find_classes(dir): 25 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 26 | classes.sort() 27 | class_to_idx = {classes[i]: i for i in range(len(classes))} 28 | return classes, class_to_idx 29 | 30 | 31 | def make_dataset(dir, class_to_idx, extensions): 32 | images = [] 33 | dir = os.path.expanduser(dir) 34 | for target in sorted(os.listdir(dir)): 35 | d = os.path.join(dir, target) 36 | if not os.path.isdir(d): 37 | continue 38 | 39 | for root, _, fnames in sorted(os.walk(d)): 40 | for fname in sorted(fnames): 41 | if has_file_allowed_extension(fname, extensions): 42 | path = os.path.join(root, fname) 43 | item = (path, class_to_idx[target]) 44 | images.append(item) 45 | 46 | return images 47 | 48 | 49 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 50 | images = [] 51 | with open(ann_file, "r") as f: 52 | contents = f.readlines() 53 | for line_str in contents: 54 | path_contents = [c for c in line_str.split('\t')] 55 | im_file_name = path_contents[0] 56 | class_index = int(path_contents[1]) 57 | 58 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 59 | item = (os.path.join(img_prefix, im_file_name), class_index) 60 | 61 | images.append(item) 62 | 63 | return images 64 | 65 | 66 | class DatasetFolder(data.Dataset): 67 | """A generic data loader where the samples are arranged in this way: :: 68 | root/class_x/xxx.ext 69 | root/class_x/xxy.ext 70 | root/class_x/xxz.ext 71 | root/class_y/123.ext 72 | root/class_y/nsdf3.ext 73 | root/class_y/asd932_.ext 74 | Args: 75 | root (string): Root directory path. 76 | loader (callable): A function to load a sample given its path. 77 | extensions (list[string]): A list of allowed extensions. 78 | transform (callable, optional): A function/transform that takes in 79 | a sample and returns a transformed version. 80 | E.g, ``transforms.RandomCrop`` for images. 81 | target_transform (callable, optional): A function/transform that takes 82 | in the target and transforms it. 
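ann_file (string, optional): annotation file for the zip mode; each line
    holds an image path (relative to ``img_prefix``) and its class index,
    separated by a tab.
img_prefix (string, optional): prefix joined to every path listed in
    ``ann_file``, e.g. ``train.zip@/`` for the zip mode.
cache_mode (string): ``no`` (default, read from disk every time),
    ``part`` (each rank caches the bytes of the samples it owns), or
    ``full`` (every rank caches all samples).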
83 | Attributes: 84 | samples (list): List of (sample path, class_index) tuples 85 | """ 86 | 87 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 88 | cache_mode="no"): 89 | # image folder mode 90 | if ann_file == '': 91 | _, class_to_idx = find_classes(root) 92 | samples = make_dataset(root, class_to_idx, extensions) 93 | # zip mode 94 | else: 95 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 96 | os.path.join(root, img_prefix), 97 | extensions) 98 | 99 | if len(samples) == 0: 100 | raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" 101 | "Supported extensions are: " + ",".join(extensions))) 102 | 103 | self.root = root 104 | self.loader = loader 105 | self.extensions = extensions 106 | 107 | self.samples = samples 108 | self.labels = [y_1k for _, y_1k in samples] 109 | self.classes = list(set(self.labels)) 110 | 111 | self.transform = transform 112 | self.target_transform = target_transform 113 | 114 | self.cache_mode = cache_mode 115 | if self.cache_mode != "no": 116 | self.init_cache() 117 | 118 | def init_cache(self): 119 | assert self.cache_mode in ["part", "full"] 120 | n_sample = len(self.samples) 121 | global_rank = dist.get_rank() 122 | world_size = dist.get_world_size() 123 | 124 | samples_bytes = [None for _ in range(n_sample)] 125 | start_time = time.time() 126 | for index in range(n_sample): 127 | if index % (n_sample//10) == 0: 128 | t = time.time() - start_time 129 | logger = logging.getLogger(__name__) 130 | logger.info(f'cached {index}/{n_sample} takes {t:.2f}s per block') 131 | start_time = time.time() 132 | path, target = self.samples[index] 133 | if self.cache_mode == "full" or index % world_size == global_rank: 134 | samples_bytes[index] = (ZipReader.read(path), target) 135 | else: 136 | samples_bytes[index] = (path, target) 137 | self.samples = samples_bytes 138 | 139 | def __getitem__(self, index): 140 | """ 141 | Args: 142 | index (int): Index 143 | Returns: 144 | tuple: (sample, target) where target is class_index of the target class. 
145 | """ 146 | path, target = self.samples[index] 147 | sample = self.loader(path) 148 | if self.transform is not None: 149 | sample = self.transform(sample) 150 | if self.target_transform is not None: 151 | target = self.target_transform(target) 152 | 153 | return sample, target 154 | 155 | def __len__(self): 156 | return len(self.samples) 157 | 158 | def __repr__(self): 159 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 160 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 161 | fmt_str += ' Root Location: {}\n'.format(self.root) 162 | tmp = ' Transforms (if any): ' 163 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 164 | tmp = ' Target Transforms (if any): ' 165 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 166 | return fmt_str 167 | 168 | 169 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 170 | 171 | 172 | def pil_loader(path): 173 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 174 | if isinstance(path, bytes): 175 | img = Image.open(io.BytesIO(path)) 176 | elif is_zip_path(path): 177 | data = ZipReader.read(path) 178 | img = Image.open(io.BytesIO(data)) 179 | else: 180 | with open(path, 'rb') as f: 181 | img = Image.open(f) 182 | return img.convert('RGB') 183 | 184 | 185 | def accimage_loader(path): 186 | import accimage 187 | try: 188 | return accimage.Image(path) 189 | except IOError: 190 | # Potentially a decoding problem, fall back to PIL.Image 191 | return pil_loader(path) 192 | 193 | 194 | def default_img_loader(path): 195 | from torchvision import get_image_backend 196 | if get_image_backend() == 'accimage': 197 | return accimage_loader(path) 198 | else: 199 | return pil_loader(path) 200 | 201 | 202 | class ImageFolder(DatasetFolder): 203 | """A generic data loader where the images are arranged in this way: :: 204 | root/dog/xxx.png 205 | root/dog/xxy.png 206 | root/dog/xxz.png 207 | root/cat/123.png 208 | root/cat/nsdf3.png 209 | root/cat/asd932_.png 210 | Args: 211 | root (string): Root directory path. 212 | transform (callable, optional): A function/transform that takes in an PIL image 213 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 214 | target_transform (callable, optional): A function/transform that takes in the 215 | target and transforms it. 216 | loader (callable, optional): A function to load an image given its path. 217 | Attributes: 218 | imgs (list): List of (image path, class_index) tuples 219 | """ 220 | 221 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 222 | loader=default_img_loader, cache_mode="no", two_crop=False): 223 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 224 | ann_file=ann_file, img_prefix=img_prefix, 225 | transform=transform, target_transform=target_transform, 226 | cache_mode=cache_mode) 227 | self.imgs = self.samples 228 | self.two_crop = two_crop 229 | 230 | def __getitem__(self, index): 231 | """ 232 | Args: 233 | index (int): Index 234 | Returns: 235 | tuple: (image, target) where target is class_index of the target class. 
236 | """ 237 | path, target = self.samples[index] 238 | image = self.loader(path) 239 | if self.transform is not None: 240 | if isinstance(self.transform, tuple) and len(self.transform) == 2: 241 | img = self.transform[0](image) 242 | else: 243 | img = self.transform(image) 244 | else: 245 | img = image 246 | if self.target_transform is not None: 247 | target = self.target_transform(target) 248 | 249 | if self.two_crop: 250 | if isinstance(self.transform, tuple) and len(self.transform) == 2: 251 | img2 = self.transform[1](image) 252 | else: 253 | img2 = self.transform(image) 254 | return img, img2, index, target 255 | else: 256 | return img, index, target 257 | -------------------------------------------------------------------------------- /contrast/data/rand_augment.py: -------------------------------------------------------------------------------- 1 | """ AutoAugment and RandAugment 2 | Implementation adapted from: 3 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 4 | Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719 5 | Hacked together by Ross Wightman 6 | """ 7 | import random 8 | import math 9 | import re 10 | from PIL import Image, ImageOps, ImageEnhance 11 | import PIL 12 | import numpy as np 13 | 14 | 15 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) 16 | 17 | _FILL = (128, 128, 128) 18 | 19 | # This signifies the max integer that the controller RNN could predict for the 20 | # augmentation scheme. 21 | _MAX_LEVEL = 10. 22 | 23 | _HPARAMS_DEFAULT = dict( 24 | translate_const=250, 25 | img_mean=_FILL, 26 | ) 27 | 28 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 29 | 30 | 31 | def _interpolation(kwargs): 32 | interpolation = kwargs.pop('resample', Image.BILINEAR) 33 | if isinstance(interpolation, (list, tuple)): 34 | return random.choice(interpolation) 35 | else: 36 | return interpolation 37 | 38 | 39 | def _check_args_tf(kwargs): 40 | if 'fillcolor' in kwargs and _PIL_VER < (5, 0): 41 | kwargs.pop('fillcolor') 42 | kwargs['resample'] = _interpolation(kwargs) 43 | 44 | 45 | def shear_x(img, factor, **kwargs): 46 | _check_args_tf(kwargs) 47 | return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) 48 | 49 | 50 | def shear_y(img, factor, **kwargs): 51 | _check_args_tf(kwargs) 52 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) 53 | 54 | 55 | def translate_x_rel(img, pct, **kwargs): 56 | pixels = pct * img.size[0] 57 | _check_args_tf(kwargs) 58 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) 59 | 60 | 61 | def translate_y_rel(img, pct, **kwargs): 62 | pixels = pct * img.size[1] 63 | _check_args_tf(kwargs) 64 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) 65 | 66 | 67 | def translate_x_abs(img, pixels, **kwargs): 68 | _check_args_tf(kwargs) 69 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) 70 | 71 | 72 | def translate_y_abs(img, pixels, **kwargs): 73 | _check_args_tf(kwargs) 74 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) 75 | 76 | 77 | def rotate(img, degrees, **kwargs): 78 | _check_args_tf(kwargs) 79 | if _PIL_VER >= (5, 2): 80 | return img.rotate(degrees, **kwargs) 81 | elif _PIL_VER >= (5, 0): 82 | w, h = img.size 83 | post_trans = (0, 0) 84 | rotn_center = (w / 2.0, h / 2.0) 85 | angle = -math.radians(degrees) 86 | matrix = [ 87 | 
round(math.cos(angle), 15), 88 | round(math.sin(angle), 15), 89 | 0.0, 90 | round(-math.sin(angle), 15), 91 | round(math.cos(angle), 15), 92 | 0.0, 93 | ] 94 | 95 | def transform(x, y, matrix): 96 | (a, b, c, d, e, f) = matrix 97 | return a * x + b * y + c, d * x + e * y + f 98 | 99 | matrix[2], matrix[5] = transform( 100 | -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix 101 | ) 102 | matrix[2] += rotn_center[0] 103 | matrix[5] += rotn_center[1] 104 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 105 | else: 106 | return img.rotate(degrees, resample=kwargs['resample']) 107 | 108 | 109 | def auto_contrast(img, **__): 110 | return ImageOps.autocontrast(img) 111 | 112 | 113 | def invert(img, **__): 114 | return ImageOps.invert(img) 115 | 116 | 117 | def identity(img, **__): 118 | return img 119 | 120 | 121 | def equalize(img, **__): 122 | return ImageOps.equalize(img) 123 | 124 | 125 | def solarize(img, thresh, **__): 126 | return ImageOps.solarize(img, thresh) 127 | 128 | 129 | def solarize_add(img, add, thresh=128, **__): 130 | lut = [] 131 | for i in range(256): 132 | if i < thresh: 133 | lut.append(min(255, i + add)) 134 | else: 135 | lut.append(i) 136 | if img.mode in ("L", "RGB"): 137 | if img.mode == "RGB" and len(lut) == 256: 138 | lut = lut + lut + lut 139 | return img.point(lut) 140 | else: 141 | return img 142 | 143 | 144 | def posterize(img, bits_to_keep, **__): 145 | if bits_to_keep >= 8: 146 | return img 147 | return ImageOps.posterize(img, bits_to_keep) 148 | 149 | 150 | def contrast(img, factor, **__): 151 | return ImageEnhance.Contrast(img).enhance(factor) 152 | 153 | 154 | def color(img, factor, **__): 155 | return ImageEnhance.Color(img).enhance(factor) 156 | 157 | 158 | def brightness(img, factor, **__): 159 | return ImageEnhance.Brightness(img).enhance(factor) 160 | 161 | 162 | def sharpness(img, factor, **__): 163 | return ImageEnhance.Sharpness(img).enhance(factor) 164 | 165 | 166 | def _randomly_negate(v): 167 | """With 50% prob, negate the value""" 168 | return -v if random.random() > 0.5 else v 169 | 170 | 171 | def _rotate_level_to_arg(level, _hparams): 172 | # range [-30, 30] 173 | level = (level / _MAX_LEVEL) * 30. 
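# flip the sign with probability 0.5 so rotations are drawn from both directions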
174 | level = _randomly_negate(level) 175 | return level, 176 | 177 | 178 | def _enhance_level_to_arg(level, _hparams): 179 | # range [0.1, 1.9] 180 | return (level / _MAX_LEVEL) * 1.8 + 0.1, 181 | 182 | 183 | def _shear_level_to_arg(level, _hparams): 184 | # range [-0.3, 0.3] 185 | level = (level / _MAX_LEVEL) * 0.3 186 | level = _randomly_negate(level) 187 | return level, 188 | 189 | 190 | def _translate_abs_level_to_arg(level, hparams): 191 | translate_const = hparams['translate_const'] 192 | level = (level / _MAX_LEVEL) * float(translate_const) 193 | level = _randomly_negate(level) 194 | return level, 195 | 196 | 197 | def _translate_rel_level_to_arg(level, _hparams): 198 | # range [-0.45, 0.45] 199 | level = (level / _MAX_LEVEL) * 0.45 200 | level = _randomly_negate(level) 201 | return level, 202 | 203 | 204 | def _posterize_original_level_to_arg(level, _hparams): 205 | # As per original AutoAugment paper description 206 | # range [4, 8], 'keep 4 up to 8 MSB of image' 207 | return int((level / _MAX_LEVEL) * 4) + 4, 208 | 209 | 210 | def _posterize_research_level_to_arg(level, _hparams): 211 | # As per Tensorflow models research and UDA impl 212 | # range [4, 0], 'keep 4 down to 0 MSB of original image' 213 | return 4 - int((level / _MAX_LEVEL) * 4), 214 | 215 | 216 | def _posterize_tpu_level_to_arg(level, _hparams): 217 | # As per Tensorflow TPU EfficientNet impl 218 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 219 | return int((level / _MAX_LEVEL) * 4), 220 | 221 | 222 | def _solarize_level_to_arg(level, _hparams): 223 | # range [0, 256] 224 | return int((level / _MAX_LEVEL) * 256), 225 | 226 | 227 | def _solarize_add_level_to_arg(level, _hparams): 228 | # range [0, 110] 229 | return int((level / _MAX_LEVEL) * 110), 230 | 231 | 232 | LEVEL_TO_ARG = { 233 | 'AutoContrast': None, 234 | 'Equalize': None, 235 | 'Invert': None, 236 | 'Identity': None, 237 | 'Rotate': _rotate_level_to_arg, 238 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers 239 | 'PosterizeOriginal': _posterize_original_level_to_arg, 240 | 'PosterizeResearch': _posterize_research_level_to_arg, 241 | 'PosterizeTpu': _posterize_tpu_level_to_arg, 242 | 'Solarize': _solarize_level_to_arg, 243 | 'SolarizeAdd': _solarize_add_level_to_arg, 244 | 'Color': _enhance_level_to_arg, 245 | 'Contrast': _enhance_level_to_arg, 246 | 'Brightness': _enhance_level_to_arg, 247 | 'Sharpness': _enhance_level_to_arg, 248 | 'ShearX': _shear_level_to_arg, 249 | 'ShearY': _shear_level_to_arg, 250 | 'TranslateX': _translate_abs_level_to_arg, 251 | 'TranslateY': _translate_abs_level_to_arg, 252 | 'TranslateXRel': _translate_rel_level_to_arg, 253 | 'TranslateYRel': _translate_rel_level_to_arg, 254 | } 255 | 256 | 257 | NAME_TO_OP = { 258 | 'AutoContrast': auto_contrast, 259 | 'Equalize': equalize, 260 | 'Invert': invert, 261 | 'Identity': identity, 262 | 'Rotate': rotate, 263 | 'PosterizeOriginal': posterize, 264 | 'PosterizeResearch': posterize, 265 | 'PosterizeTpu': posterize, 266 | 'Solarize': solarize, 267 | 'SolarizeAdd': solarize_add, 268 | 'Color': color, 269 | 'Contrast': contrast, 270 | 'Brightness': brightness, 271 | 'Sharpness': sharpness, 272 | 'ShearX': shear_x, 273 | 'ShearY': shear_y, 274 | 'TranslateX': translate_x_abs, 275 | 'TranslateY': translate_y_abs, 276 | 'TranslateXRel': translate_x_rel, 277 | 'TranslateYRel': translate_y_rel, 278 | } 279 | 280 | 281 | class AutoAugmentOp: 282 | 283 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 284 | 
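# name: key into NAME_TO_OP / LEVEL_TO_ARG; prob: per-call probability of
# applying the op; magnitude: strength on the 0.._MAX_LEVEL scale;
# hparams: extra parameters such as translate_const, img_mean and
# magnitude_std.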
hparams = hparams or _HPARAMS_DEFAULT 285 | self.aug_fn = NAME_TO_OP[name] 286 | self.level_fn = LEVEL_TO_ARG[name] 287 | self.prob = prob 288 | self.magnitude = magnitude 289 | self.hparams = hparams.copy() 290 | self.kwargs = dict( 291 | fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, 292 | resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, 293 | ) 294 | 295 | # If magnitude_std is > 0, we introduce some randomness 296 | # in the usually fixed policy and sample magnitude from a normal distribution 297 | # with mean `magnitude` and std-dev of `magnitude_std`. 298 | # NOTE This is my own hack, being tested, not in papers or reference impls. 299 | self.magnitude_std = self.hparams.get('magnitude_std', 0) 300 | 301 | def __call__(self, img): 302 | if random.random() > self.prob: 303 | return img 304 | magnitude = self.magnitude 305 | if self.magnitude_std and self.magnitude_std > 0: 306 | magnitude = random.gauss(magnitude, self.magnitude_std) 307 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 308 | level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() 309 | return self.aug_fn(img, *level_args, **self.kwargs) 310 | 311 | 312 | _RAND_TRANSFORMS = [ 313 | 'AutoContrast', 314 | 'Equalize', 315 | 'Invert', 316 | 'Rotate', 317 | 'PosterizeTpu', 318 | 'Solarize', 319 | 'SolarizeAdd', 320 | 'Color', 321 | 'Contrast', 322 | 'Brightness', 323 | 'Sharpness', 324 | 'ShearX', 325 | 'ShearY', 326 | 'TranslateXRel', 327 | 'TranslateYRel', 328 | # 'Cutout' # FIXME I implement this as random erasing separately 329 | ] 330 | 331 | _RAND_TRANSFORMS_CMC = [ 332 | 'AutoContrast', 333 | 'Identity', 334 | 'Rotate', 335 | 'Sharpness', 336 | 'ShearX', 337 | 'ShearY', 338 | 'TranslateXRel', 339 | 'TranslateYRel', 340 | # 'Cutout' # FIXME I implement this as random erasing separately 341 | ] 342 | 343 | 344 | # These experimental weights are based loosely on the relative improvements mentioned in paper. 345 | # They may not result in increased performance, but could likely be tuned to so. 
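# The raw weights below are normalized into a probability distribution in
# _select_rand_weights() before being passed to np.random.choice.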
346 | _RAND_CHOICE_WEIGHTS_0 = { 347 | 'Rotate': 0.3, 348 | 'ShearX': 0.2, 349 | 'ShearY': 0.2, 350 | 'TranslateXRel': 0.1, 351 | 'TranslateYRel': 0.1, 352 | 'Color': .025, 353 | 'Sharpness': 0.025, 354 | 'AutoContrast': 0.025, 355 | 'Solarize': .005, 356 | 'SolarizeAdd': .005, 357 | 'Contrast': .005, 358 | 'Brightness': .005, 359 | 'Equalize': .005, 360 | 'PosterizeTpu': 0, 361 | 'Invert': 0, 362 | } 363 | 364 | 365 | def _select_rand_weights(weight_idx=0, transforms=None): 366 | transforms = transforms or _RAND_TRANSFORMS 367 | assert weight_idx == 0 # only one set of weights currently 368 | rand_weights = _RAND_CHOICE_WEIGHTS_0 369 | probs = [rand_weights[k] for k in transforms] 370 | probs /= np.sum(probs) 371 | return probs 372 | 373 | 374 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 375 | """rand augment ops for RGB images""" 376 | hparams = hparams or _HPARAMS_DEFAULT 377 | transforms = transforms or _RAND_TRANSFORMS 378 | return [AutoAugmentOp( 379 | name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] 380 | 381 | 382 | def rand_augment_ops_cmc(magnitude=10, hparams=None, transforms=None): 383 | """rand augment ops for CMC images (removing color ops)""" 384 | hparams = hparams or _HPARAMS_DEFAULT 385 | transforms = transforms or _RAND_TRANSFORMS_CMC 386 | return [AutoAugmentOp( 387 | name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] 388 | 389 | 390 | class RandAugment: 391 | def __init__(self, ops, num_layers=2, choice_weights=None): 392 | self.ops = ops 393 | self.num_layers = num_layers 394 | self.choice_weights = choice_weights 395 | 396 | def __call__(self, img): 397 | # no replacement when using weighted choice 398 | ops = np.random.choice( 399 | self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) 400 | for op in ops: 401 | img = op(img) 402 | return img 403 | 404 | 405 | def rand_augment_transform(config_str, hparams, use_cmc=False): 406 | """ 407 | Create a RandAugment transform 408 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by 409 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining 410 | sections, not order sepecific determine 411 | 'm' - integer magnitude of rand augment 412 | 'n' - integer num layers (number of transform ops selected per image) 413 | 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) 414 | 'mstd' - float std deviation of magnitude noise applied 415 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 416 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 417 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme 418 | :param use_cmc: Flag indicates removing augmentation for coloring ops. 
419 | :return: A PyTorch compatible Transform 420 | """ 421 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) 422 | num_layers = 2 # default to 2 ops per image 423 | weight_idx = None # default to no probability weights for op choice 424 | config = config_str.split('-') 425 | assert config[0] == 'rand' 426 | config = config[1:] 427 | for c in config: 428 | cs = re.split(r'(\d.*)', c) 429 | if len(cs) < 2: 430 | continue 431 | key, val = cs[:2] 432 | if key == 'mstd': 433 | # noise param injected via hparams for now 434 | hparams.setdefault('magnitude_std', float(val)) 435 | elif key == 'm': 436 | magnitude = int(val) 437 | elif key == 'n': 438 | num_layers = int(val) 439 | elif key == 'w': 440 | weight_idx = int(val) 441 | else: 442 | assert False, 'Unknown RandAugment config section' 443 | if use_cmc: 444 | ra_ops = rand_augment_ops_cmc(magnitude=magnitude, hparams=hparams) 445 | else: 446 | ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) 447 | choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) 448 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) 449 | -------------------------------------------------------------------------------- /contrast/data/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Sampler 3 | 4 | 5 | class SubsetSlidingWindowSampler(Sampler): 6 | r"""Samples elements randomly from a given list of indices, without replacement. 7 | 8 | Arguments: 9 | indices (sequence): a sequence of indices 10 | """ 11 | 12 | def __init__(self, indices, window_stride, window_size, shuffle_per_epoch=False): 13 | self.window_stride = window_stride 14 | self.window_size = window_size 15 | self.shuffle_per_epoch = shuffle_per_epoch 16 | self.indices = indices 17 | np.random.shuffle(self.indices) 18 | self.start_index = 0 19 | 20 | def __iter__(self): 21 | # optionally shuffle all indices per epoch 22 | if self.shuffle_per_epoch and self.start_index + self.window_size > len(self): 23 | np.random.shuffle(self.indices) 24 | 25 | # get indices of sampler in the current window 26 | indices = np.mod(np.arange(self.window_size, dtype=np.int) + self.start_index, len(self)) 27 | window_indices = self.indices[indices] 28 | 29 | # shuffle the current window 30 | np.random.shuffle(window_indices) 31 | 32 | # move start index to next window 33 | self.start_index = (self.start_index + self.window_stride) % len(self) 34 | 35 | return iter(window_indices.tolist()) 36 | 37 | def __len__(self): 38 | return self.window_size 39 | 40 | def state_dict(self): 41 | """Returns the state of the scheduler as a :class:`dict`. 42 | It contains an entry for every variable in self.__dict__ which 43 | is not the optimizer. 44 | """ 45 | return {"start_index": self.start_index} 46 | 47 | def load_state_dict(self, state_dict): 48 | """Loads the schedulers state. 49 | Arguments: 50 | state_dict (dict): scheduler state. Should be an object returned 51 | from a call to :meth:`state_dict`. 
52 | """ 53 | self.__dict__.update(state_dict) 54 | -------------------------------------------------------------------------------- /contrast/data/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import ImageFilter 4 | from torchvision import transforms 5 | from .rand_augment import rand_augment_transform 6 | 7 | 8 | class GaussianBlur(object): 9 | """Gaussian Blur version 2""" 10 | 11 | def __call__(self, x): 12 | sigma = np.random.uniform(0.1, 2.0) 13 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 14 | return x 15 | 16 | 17 | class MultiTransform(object): 18 | """apply transform to an image for k times""" 19 | def __init__(self, transform, k=3): 20 | self.transform = transform 21 | self.k = k 22 | 23 | def __call__(self, img): 24 | return torch.stack([self.transform(img) for i in range(self.k)]) 25 | 26 | 27 | def get_transform(aug_type, crop, image_size=224, num_crop=1, 28 | crop2=0.14, image_size2=96, num_crop2=3): 29 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 30 | 31 | if aug_type == "InstDisc": # used in InstDisc and MoCo v1 32 | transform = transforms.Compose([ 33 | transforms.RandomResizedCrop(image_size, scale=(crop, 1.)), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 36 | transforms.RandomGrayscale(p=0.2), 37 | transforms.ToTensor(), 38 | normalize, 39 | ]) 40 | elif aug_type == 'MoCov2': # used in MoCov2 41 | transform = transforms.Compose([ 42 | transforms.RandomResizedCrop(image_size, scale=(crop, 1.)), 43 | transforms.RandomHorizontalFlip(), 44 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 45 | transforms.RandomGrayscale(p=0.2), 46 | transforms.RandomApply([GaussianBlur()], p=0.5), 47 | transforms.ToTensor(), 48 | normalize 49 | ]) 50 | elif aug_type == 'SimCLR': # used in SimCLR and PIC 51 | transform = transforms.Compose([ 52 | transforms.RandomResizedCrop(image_size, scale=(crop, 1.)), 53 | transforms.RandomHorizontalFlip(), 54 | transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8), 55 | transforms.RandomGrayscale(p=0.2), 56 | transforms.RandomApply([GaussianBlur()], p=0.5), 57 | transforms.ToTensor(), 58 | normalize, 59 | ]) 60 | elif aug_type == 'MultiCrop': # used in PIC_MultiCrop 61 | assert crop < crop2 62 | transform1 = MultiTransform( 63 | transforms.Compose([ 64 | transforms.RandomResizedCrop(image_size, scale=(crop2, 1.)), 65 | transforms.RandomHorizontalFlip(), 66 | transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8), 67 | transforms.RandomGrayscale(p=0.2), 68 | transforms.RandomApply([GaussianBlur()], p=0.5), 69 | transforms.ToTensor(), 70 | normalize, 71 | ]), 72 | k=num_crop) 73 | transform2 = MultiTransform( 74 | transforms.Compose([ 75 | transforms.RandomResizedCrop(image_size2, scale=(crop, crop2)), 76 | transforms.RandomHorizontalFlip(), 77 | transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8), 78 | transforms.RandomGrayscale(p=0.2), 79 | transforms.RandomApply([GaussianBlur()], p=0.5), 80 | transforms.ToTensor(), 81 | normalize, 82 | ]), 83 | k=num_crop2) 84 | transform = (transform1, transform2) 85 | elif aug_type == 'RandAug': # used in InfoMin 86 | rgb_mean = (0.485, 0.456, 0.406) 87 | ra_params = dict( 88 | translate_const=int(224 * 0.45), 89 | img_mean=tuple([min(255, round(255 * x)) for x in rgb_mean]), 90 | ) 91 | transform = transforms.Compose([ 92 | 
transforms.RandomResizedCrop(224, scale=(crop, 1.)), 93 | transforms.RandomHorizontalFlip(), 94 | transforms.RandomApply([ 95 | transforms.ColorJitter(0.8, 0.8, 0.8, 0.2) 96 | ], p=0.8), 97 | transforms.RandomApply([GaussianBlur()], p=0.5), 98 | rand_augment_transform('rand-n{}-m{}-mstd0.5'.format(2, 10), ra_params), 99 | transforms.RandomGrayscale(p=0.2), 100 | transforms.ToTensor(), 101 | normalize, 102 | ]) 103 | elif aug_type == 'NULL': # used in linear evaluation 104 | transform = transforms.Compose([ 105 | transforms.RandomResizedCrop(image_size, scale=(crop, 1.)), 106 | transforms.RandomHorizontalFlip(), 107 | transforms.ToTensor(), 108 | normalize, 109 | ]) 110 | elif aug_type == 'val': # used in validate 111 | transform = transforms.Compose([ 112 | transforms.Resize(image_size + 32), 113 | transforms.CenterCrop(image_size), 114 | transforms.ToTensor(), 115 | normalize 116 | ]) 117 | else: 118 | supported = '[InstDisc, MoCov2, SimCLR, RandAug, Null, val]' 119 | raise NotImplementedError(f'aug_type "{aug_type}" not supported. Should in {supported}') 120 | 121 | return transform 122 | -------------------------------------------------------------------------------- /contrast/data/zipreader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | 4 | 5 | def is_zip_path(img_or_path): 6 | """judge if this is a zip path""" 7 | return '.zip@' in img_or_path 8 | 9 | 10 | class ZipReader(object): 11 | """A class to read zipped files""" 12 | zip_bank = dict() 13 | 14 | def __init__(self): 15 | super(ZipReader, self).__init__() 16 | 17 | @staticmethod 18 | def get_zipfile(path): 19 | zip_bank = ZipReader.zip_bank 20 | if path not in zip_bank: 21 | zfile = zipfile.ZipFile(path, 'r') 22 | zip_bank[path] = zfile 23 | return zip_bank[path] 24 | 25 | @staticmethod 26 | def split_zip_style_path(path): 27 | pos_at = path.index('@') 28 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 29 | 30 | zip_path = path[0: pos_at] 31 | folder_path = path[pos_at + 1:] 32 | folder_path = str.strip(folder_path, '/') 33 | return zip_path, folder_path 34 | 35 | @staticmethod 36 | def list_folder(path): 37 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 38 | 39 | zfile = ZipReader.get_zipfile(zip_path) 40 | folder_list = [] 41 | for file_foler_name in zfile.namelist(): 42 | file_foler_name = str.strip(file_foler_name, '/') 43 | if file_foler_name.startswith(folder_path) and \ 44 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 45 | file_foler_name != folder_path: 46 | if len(folder_path) == 0: 47 | folder_list.append(file_foler_name) 48 | else: 49 | folder_list.append(file_foler_name[len(folder_path)+1:]) 50 | 51 | return folder_list 52 | 53 | @staticmethod 54 | def list_files(path, extension=None): 55 | if extension is None: 56 | extension = ['.*'] 57 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 58 | 59 | zfile = ZipReader.get_zipfile(zip_path) 60 | file_lists = [] 61 | for file_foler_name in zfile.namelist(): 62 | file_foler_name = str.strip(file_foler_name, '/') 63 | if file_foler_name.startswith(folder_path) and \ 64 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 65 | if len(folder_path) == 0: 66 | file_lists.append(file_foler_name) 67 | else: 68 | file_lists.append(file_foler_name[len(folder_path)+1:]) 69 | 70 | return file_lists 71 | 72 | @staticmethod 73 | def read(path): 74 | zip_path, path_img = ZipReader.split_zip_style_path(path) 75 | zfile = 
ZipReader.get_zipfile(zip_path) 76 | data = zfile.read(path_img) 77 | return data 78 | -------------------------------------------------------------------------------- /contrast/lars.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.optim.optimizer import Optimizer 4 | 5 | __all__ = ['LARS'] 6 | 7 | 8 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()): 9 | """Splits param group into weight_decay / non-weight decay. 10 | Tweaked from https://bit.ly/3dzyqod 11 | :param model: the torch.nn model 12 | :param weight_decay: weight decay term 13 | :param skip_list: extra modules (besides BN/bias) to skip 14 | :returns: split param group into weight_decay/not-weight decay 15 | :rtype: list(dict) 16 | """ 17 | decay, no_decay = [], [] 18 | for name, param in model.named_parameters(): 19 | if not param.requires_grad: 20 | continue 21 | 22 | if len(param.shape) == 1 or name in skip_list: 23 | if dist.get_rank() == 0: 24 | print('Skip weight decay, ', name) 25 | no_decay.append(param) 26 | else: 27 | decay.append(param) 28 | return [ 29 | {'params': no_decay, 'weight_decay': 0, 'ignore': True}, 30 | {'params': decay, 'weight_decay': weight_decay, 'ignore': False}] 31 | 32 | class LARS(Optimizer): 33 | """Implements 'LARS (Layer-wise Adaptive Rate Scaling)'__ as Optimizer a 34 | :class:`~torch.optim.Optimizer` wrapper. 35 | 36 | __ : https://arxiv.org/abs/1708.03888 37 | 38 | Wraps an arbitrary optimizer like :class:`torch.optim.SGD` to use LARS. If 39 | you want to the same performance obtained with small-batch training when 40 | you use large-batch training, LARS will be helpful:: 41 | 42 | Args: 43 | optimizer (Optimizer): 44 | optimizer to wrap 45 | eps (float, optional): 46 | epsilon to help with numerical stability while calculating the 47 | adaptive learning rate 48 | trust_coef (float, optional): 49 | trust coefficient for calculating the adaptive learning rate 50 | 51 | Example:: 52 | base_optimizer = optim.SGD(model.parameters(), lr=0.1) 53 | optimizer = LARS(optimizer=base_optimizer) 54 | 55 | output = model(input) 56 | loss = loss_fn(output, target) 57 | loss.backward() 58 | 59 | optimizer.step() 60 | 61 | """ 62 | 63 | def __init__(self, optimizer, eps=1e-8, trust_coef=0.001): 64 | if eps < 0.0: 65 | raise ValueError('invalid epsilon value: , %f' % eps) 66 | 67 | if trust_coef < 0.0: 68 | raise ValueError("invalid trust coefficient: %f" % trust_coef) 69 | 70 | self.optim = optimizer 71 | self.eps = eps 72 | self.trust_coef = trust_coef 73 | 74 | def __getstate__(self): 75 | lars_dict = {} 76 | lars_dict['eps'] = self.eps 77 | lars_dict['trust_coef'] = self.trust_coef 78 | return (self.optim, lars_dict) 79 | 80 | def __setstate__(self, state): 81 | self.optim, lars_dict = state 82 | self.eps = lars_dict['eps'] 83 | self.trust_coef = lars_dict['trust_coef'] 84 | 85 | def __repr__(self): 86 | return '%s(%r)' % (self.__class__.__name__, self.optim) 87 | 88 | @property 89 | def param_groups(self): 90 | return self.optim.param_groups 91 | 92 | @property 93 | def state(self): 94 | return self.optim.state 95 | 96 | def state_dict(self): 97 | return self.optim.state_dict() 98 | 99 | def load_state_dict(self, state_dict): 100 | self.optim.load_state_dict(state_dict) 101 | 102 | def zero_grad(self): 103 | self.optim.zero_grad() 104 | 105 | def add_param_group(self, param_group): 106 | self.optim.add_param_group(param_group) 107 | 108 | def apply_adaptive_lrs(self): 109 | with 
torch.no_grad(): 110 | for group in self.optim.param_groups: 111 | weight_decay = group['weight_decay'] 112 | ignore = group.get('ignore', None) # NOTE: this is set by add_weight_decay 113 | 114 | for p in group['params']: 115 | if p.grad is None: 116 | continue 117 | 118 | # Add weight decay before computing adaptive LR 119 | # Seems to be pretty important in SIMclr style models. 120 | if weight_decay > 0: 121 | p.grad = p.grad.add(p, alpha=weight_decay) 122 | 123 | # Ignore bias / bn terms for LARS update 124 | if ignore is not None and not ignore: 125 | # compute the parameter and gradient norms 126 | param_norm = p.norm() 127 | grad_norm = p.grad.norm() 128 | 129 | # compute our adaptive learning rate 130 | adaptive_lr = 1.0 131 | if param_norm > 0 and grad_norm > 0: 132 | adaptive_lr = self.trust_coef * param_norm / (grad_norm + self.eps) 133 | 134 | # print("applying {} lr scaling to param of shape {}".format(adaptive_lr, p.shape)) 135 | p.grad = p.grad.mul(adaptive_lr) 136 | 137 | def step(self, *args, **kwargs): 138 | self.apply_adaptive_lrs() 139 | 140 | # Zero out weight decay as we do it in LARS 141 | weight_decay_orig = [group['weight_decay'] for group in self.optim.param_groups] 142 | for group in self.optim.param_groups: 143 | group['weight_decay'] = 0 144 | 145 | loss = self.optim.step(*args, **kwargs) # Normal optimizer 146 | 147 | # Restore weight decay 148 | for group, wo in zip(self.optim.param_groups, weight_decay_orig): 149 | group['weight_decay'] = wo 150 | 151 | return loss -------------------------------------------------------------------------------- /contrast/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import functools 3 | import logging 4 | import os 5 | import sys 6 | from termcolor import colored 7 | 8 | 9 | class _ColorfulFormatter(logging.Formatter): 10 | def __init__(self, *args, **kwargs): 11 | self._root_name = kwargs.pop("root_name") + "." 12 | self._abbrev_name = kwargs.pop("abbrev_name", "") 13 | if len(self._abbrev_name): 14 | self._abbrev_name = self._abbrev_name + "." 15 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 16 | 17 | def formatMessage(self, record): 18 | record.name = record.name.replace(self._root_name, self._abbrev_name) 19 | log = super(_ColorfulFormatter, self).formatMessage(record) 20 | if record.levelno == logging.WARNING: 21 | prefix = colored("WARNING", "red", attrs=["blink"]) 22 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 23 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 24 | else: 25 | return log 26 | return prefix + " " + log 27 | 28 | 29 | # so that calling setup_logger multiple times won't add many handlers 30 | @functools.lru_cache() 31 | def setup_logger( 32 | output=None, distributed_rank=0, *, color=True, name="contrast", abbrev_name=None 33 | ): 34 | """ 35 | Initialize the detectron2 logger and set its verbosity level to "INFO". 36 | 37 | Args: 38 | output (str): a file name or a directory to save log. If None, will not save log file. 39 | If ends with ".txt" or ".log", assumed to be a file name. 40 | Otherwise, logs will be saved to `output/log.txt`. 
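distributed_rank (int): rank of the current process; only rank 0 logs
    to stdout, while every rank writes its own log file (non-zero
    ranks get a ``.rank{k}`` suffix).
color (bool): whether to colorize stdout messages via termcolor.
abbrev_name (str): abbreviation of the module name, used to shorten
    logger names in messages; defaults to ``name``.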
41 | name (str): the root module name of this logger 42 | 43 | Returns: 44 | logging.Logger: a logger 45 | """ 46 | logger = logging.getLogger(name) 47 | logger.setLevel(logging.DEBUG) 48 | logger.propagate = False 49 | 50 | if abbrev_name is None: 51 | abbrev_name = name 52 | 53 | plain_formatter = logging.Formatter( 54 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 55 | ) 56 | # stdout logging: master only 57 | if distributed_rank == 0: 58 | ch = logging.StreamHandler(stream=sys.stdout) 59 | ch.setLevel(logging.DEBUG) 60 | if color: 61 | formatter = _ColorfulFormatter( 62 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 63 | datefmt="%m/%d %H:%M:%S", 64 | root_name=name, 65 | abbrev_name=str(abbrev_name), 66 | ) 67 | else: 68 | formatter = plain_formatter 69 | ch.setFormatter(formatter) 70 | logger.addHandler(ch) 71 | 72 | # file logging: all workers 73 | if output is not None: 74 | if output.endswith(".txt") or output.endswith(".log"): 75 | filename = output 76 | else: 77 | filename = os.path.join(output, "log.txt") 78 | if distributed_rank > 0: 79 | filename = filename + f".rank{distributed_rank}" 80 | os.makedirs(os.path.dirname(filename), exist_ok=True) 81 | 82 | fh = logging.StreamHandler(_cached_log_stream(filename)) 83 | fh.setLevel(logging.DEBUG) 84 | fh.setFormatter(plain_formatter) 85 | logger.addHandler(fh) 86 | 87 | return logger 88 | 89 | 90 | # cache the opened file object, so that different calls to `setup_logger` 91 | # with the same file name can safely write to the same file. 92 | @functools.lru_cache(maxsize=None) 93 | def _cached_log_stream(filename): 94 | return open(filename, "a") 95 | -------------------------------------------------------------------------------- /contrast/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # noinspection PyProtectedMember 2 | from torch.optim.lr_scheduler import _LRScheduler, MultiStepLR, CosineAnnealingLR 3 | 4 | 5 | # noinspection PyAttributeOutsideInit 6 | class GradualWarmupScheduler(_LRScheduler): 7 | """ Gradually warm-up(increasing) learning rate in optimizer. 8 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | multiplier: init learning rate = base lr / multiplier 12 | warmup_epoch: target learning rate is reached at warmup_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, multiplier, warmup_epoch, after_scheduler, last_epoch=-1): 17 | self.multiplier = multiplier 18 | if self.multiplier <= 1.: 19 | raise ValueError('multiplier should be greater than 1.') 20 | self.warmup_epoch = warmup_epoch 21 | self.after_scheduler = after_scheduler 22 | self.finished = False 23 | super().__init__(optimizer, last_epoch=last_epoch) 24 | 25 | def get_lr(self): 26 | if self.last_epoch > self.warmup_epoch: 27 | return self.after_scheduler.get_lr() 28 | else: 29 | return [base_lr / self.multiplier * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.) 
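# linear warm-up: ramps each LR from base_lr / multiplier at epoch 0
# up to base_lr at last_epoch == warmup_epoch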
30 | for base_lr in self.base_lrs] 31 | 32 | def step(self, epoch=None): 33 | if epoch is None: 34 | epoch = self.last_epoch + 1 35 | self.last_epoch = epoch 36 | if epoch > self.warmup_epoch: 37 | self.after_scheduler.step(epoch - self.warmup_epoch) 38 | else: 39 | super(GradualWarmupScheduler, self).step(epoch) 40 | 41 | def state_dict(self): 42 | """Returns the state of the scheduler as a :class:`dict`. 43 | 44 | It contains an entry for every variable in self.__dict__ which 45 | is not the optimizer. 46 | """ 47 | 48 | state = {key: value for key, value in self.__dict__.items() if key != 'optimizer' and key != 'after_scheduler'} 49 | state['after_scheduler'] = self.after_scheduler.state_dict() 50 | return state 51 | 52 | def load_state_dict(self, state_dict): 53 | """Loads the schedulers state. 54 | 55 | Arguments: 56 | state_dict (dict): scheduler state. Should be an object returned 57 | from a call to :meth:`state_dict`. 58 | """ 59 | 60 | after_scheduler_state = state_dict.pop('after_scheduler') 61 | self.__dict__.update(state_dict) 62 | self.after_scheduler.load_state_dict(after_scheduler_state) 63 | 64 | 65 | def get_scheduler(optimizer, n_iter_per_epoch, args): 66 | if "cosine" in args.lr_scheduler: 67 | scheduler = CosineAnnealingLR( 68 | optimizer=optimizer, 69 | eta_min=0.000001, 70 | T_max=(args.epochs - args.warmup_epoch) * n_iter_per_epoch) 71 | elif "step" in args.lr_scheduler: 72 | scheduler = MultiStepLR( 73 | optimizer=optimizer, 74 | gamma=args.lr_decay_rate, 75 | milestones=[(m - args.warmup_epoch)*n_iter_per_epoch for m in args.lr_decay_epochs]) 76 | else: 77 | raise NotImplementedError(f"scheduler {args.lr_scheduler} not supported") 78 | 79 | if args.warmup_epoch > 0: 80 | scheduler = GradualWarmupScheduler( 81 | optimizer, 82 | multiplier=args.warmup_multiplier, 83 | after_scheduler=scheduler, 84 | warmup_epoch=args.warmup_epoch * n_iter_per_epoch) 85 | return scheduler 86 | -------------------------------------------------------------------------------- /contrast/models/InstDisc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base import BaseModel 6 | 7 | 8 | class InstDisc(BaseModel): 9 | """ 10 | Build a InstDisc model with: a encoder, a memory bank 11 | """ 12 | 13 | def __init__(self, base_encoder, args): 14 | """ 15 | dim: feature dimension (default: 128) 16 | K: number of negative keys (default: 65536) 17 | m: momentum of updating memory bank (default: 0.999) 18 | T: softmax temperature (default: 0.07) 19 | """ 20 | super(InstDisc, self).__init__(base_encoder, args) 21 | 22 | self.contrast_num_negative = args.contrast_num_negative 23 | self.contrast_momentum = args.contrast_momentum 24 | self.contrast_temperature = args.contrast_temperature 25 | self.num_instances = args.num_instances 26 | 27 | # create the memory 28 | self.register_buffer('memory', torch.randn(args.num_instances, args.feature_dim)) 29 | self.memory = F.normalize(self.memory, dim=0) 30 | 31 | @torch.no_grad() 32 | def _momentum_update_memory(self, feature, y_idx): 33 | memory_pos = torch.index_select(self.memory, 0, y_idx.view(-1)) 34 | memory_pos.mul_(self.contrast_momentum).add_(feature.detach() * (1 - self.contrast_momentum)) 35 | updated_weight = F.normalize(memory_pos) 36 | self.memory.index_copy_(0, y_idx, updated_weight) 37 | 38 | def forward(self, x, y_idx): 39 | """ 40 | Input: 41 | x: a batch of images 42 | y_idx: index of images 43 | Output: 44 | 
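logits: a (bs, 1 + contrast_num_negative) tensor of temperature-scaled
    similarities against the memory bank, with the positive sample in
    column 0;
labels: a (bs,) tensor of zeros marking that positive column. In short: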
logits, targets 45 | """ 46 | 47 | feature = F.normalize(self.encoder(x), dim=1) 48 | 49 | bs = feature.shape[0] 50 | 51 | # get positive and negative features from memory 52 | with torch.no_grad(): 53 | # randomly generate indices of negative samples 54 | idx = torch.randint(self.num_instances, size=(bs, self.contrast_num_negative+1)).to(feature.device) 55 | 56 | # set the first element to be the positive sample 57 | idx[:, 0] = y_idx 58 | 59 | # get weight of positive and negative samples, shape [bs, K+1, dim] 60 | weight = self.memory[idx] 61 | 62 | # logits: (bs, K+1) 63 | logits = torch.einsum('bd,bkd->bk', [feature, weight]) 64 | 65 | # apply temperature 66 | logits /= self.contrast_temperature 67 | 68 | # labels: positive key indicators 69 | labels = torch.zeros(bs, dtype=torch.long).cuda() 70 | 71 | # momentum update memory 72 | self._momentum_update_memory(feature, y_idx) 73 | 74 | return logits, labels 75 | -------------------------------------------------------------------------------- /contrast/models/MoCo.py: -------------------------------------------------------------------------------- 1 | # adapted from https://raw.githubusercontent.com/facebookresearch/moco/master/moco/builder.py 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from contrast.util import dist_collect 8 | from .base import BaseModel 9 | 10 | 11 | class MoCo(BaseModel): 12 | """ 13 | Build a MoCo model with: a query encoder, a key encoder, and a queue 14 | https://arxiv.org/abs/1911.05722 15 | """ 16 | 17 | def __init__(self, base_encoder, args): 18 | """ 19 | dim: feature dimension (default: 128) 20 | K: queue size; number of negative keys (default: 65536) 21 | m: moco momentum of updating key encoder (default: 0.999) 22 | T: softmax temperature (default: 0.07) 23 | """ 24 | super(MoCo, self).__init__(base_encoder, args) 25 | 26 | self.contrast_num_negative = args.contrast_num_negative 27 | self.contrast_momentum = args.contrast_momentum 28 | self.contrast_temperature = args.contrast_temperature 29 | 30 | # create the encoder_k 31 | self.encoder_k = base_encoder(low_dim=args.feature_dim, mlp_head=args.mlp_head) 32 | 33 | for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()): 34 | param_k.data.copy_(param_q.data) # initialize 35 | param_k.requires_grad = False # not updated by gradient 36 | 37 | # create the queue 38 | self.register_buffer("queue", torch.randn(args.feature_dim, self.contrast_num_negative)) 39 | self.queue = F.normalize(self.queue, dim=0) 40 | 41 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 42 | 43 | @torch.no_grad() 44 | def _momentum_update_key_encoder(self): 45 | """ 46 | Momentum update of the key encoder 47 | """ 48 | for param_q, param_k in zip(self.encoder.parameters(), self.encoder_k.parameters()): 49 | param_k.data = param_k.data * self.contrast_momentum + param_q.data * (1.
- self.contrast_momentum) 50 | 51 | @torch.no_grad() 52 | def _dequeue_and_enqueue(self, keys): 53 | # gather keys before updating queue 54 | keys = dist_collect(keys) 55 | 56 | batch_size = keys.shape[0] 57 | 58 | ptr = int(self.queue_ptr) 59 | assert self.contrast_num_negative % batch_size == 0 # for simplicity 60 | 61 | # replace the keys at ptr (dequeue and enqueue) 62 | self.queue[:, ptr:ptr + batch_size] = keys.T 63 | ptr = (ptr + batch_size) % self.contrast_num_negative # move pointer 64 | 65 | self.queue_ptr[0] = ptr 66 | 67 | @torch.no_grad() 68 | def _batch_shuffle_ddp(self, x): 69 | """ 70 | Batch shuffle, for making use of BatchNorm. 71 | *** Only support DistributedDataParallel (DDP) model. *** 72 | """ 73 | # gather from all gpus 74 | batch_size_this = x.shape[0] 75 | x_gather = dist_collect(x) 76 | batch_size_all = x_gather.shape[0] 77 | 78 | num_gpus = batch_size_all // batch_size_this 79 | 80 | # random shuffle index 81 | idx_shuffle = torch.randperm(batch_size_all).cuda() 82 | 83 | # broadcast to all gpus 84 | torch.distributed.broadcast(idx_shuffle, src=0) 85 | 86 | # index for restoring 87 | idx_unshuffle = torch.argsort(idx_shuffle) 88 | 89 | # shuffled index for this gpu 90 | gpu_idx = torch.distributed.get_rank() 91 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 92 | 93 | return x_gather[idx_this], idx_unshuffle 94 | 95 | @torch.no_grad() 96 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 97 | """ 98 | Undo batch shuffle. 99 | *** Only support DistributedDataParallel (DDP) model. *** 100 | """ 101 | # gather from all gpus 102 | batch_size_this = x.shape[0] 103 | x_gather = dist_collect(x) 104 | batch_size_all = x_gather.shape[0] 105 | 106 | num_gpus = batch_size_all // batch_size_this 107 | 108 | # restored index for this gpu 109 | gpu_idx = torch.distributed.get_rank() 110 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 111 | 112 | return x_gather[idx_this] 113 | 114 | def forward(self, im_q, im_k): 115 | """ 116 | Input: 117 | im_q: a batch of query images 118 | im_k: a batch of key images 119 | Output: 120 | logits, targets 121 | """ 122 | 123 | # compute query features 124 | q = self.encoder(im_q) # queries: NxC 125 | q = F.normalize(q, dim=1) 126 | 127 | # compute key features 128 | with torch.no_grad(): # no gradient to keys 129 | self._momentum_update_key_encoder() # update the key encoder 130 | 131 | # shuffle for making use of BN 132 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k) 133 | 134 | k = self.encoder_k(im_k) # keys: NxC 135 | k = F.normalize(k, dim=1) 136 | 137 | # undo shuffle 138 | k = self._batch_unshuffle_ddp(k, idx_unshuffle) 139 | 140 | # compute logits 141 | # Einstein sum is more intuitive 142 | # positive logits: Nx1 143 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 144 | # negative logits: NxK 145 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 146 | 147 | # logits: Nx(1+K) 148 | logits = torch.cat([l_pos, l_neg], dim=1) 149 | 150 | # apply temperature 151 | logits /= self.contrast_temperature 152 | 153 | # labels: positive key indicators 154 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 155 | 156 | # dequeue and enqueue 157 | self._dequeue_and_enqueue(k) 158 | 159 | return logits, labels 160 | -------------------------------------------------------------------------------- /contrast/models/PIC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 
from .base import BaseModel 6 | 7 | 8 | class PIC(BaseModel): 9 | """ 10 | Build a simple PIC model with multi-crop 11 | """ 12 | 13 | def __init__(self, base_encoder, args): 14 | """ 15 | dim: feature dimension (default: 128) 16 | num_instances: number of training instances (rows of the classifier weight) 17 | T: softmax temperature (default: 0.07) 18 | """ 19 | super(PIC, self).__init__(base_encoder, args) 20 | 21 | self.contrast_temperature = args.contrast_temperature 22 | self.num_instances = args.num_instances 23 | 24 | self.sim_matrix = nn.Parameter(torch.randn(size=(args.num_instances, args.feature_dim)), requires_grad=True) 25 | nn.init.normal_(self.sim_matrix, 0, 0.01) 26 | 27 | def forward(self, x, x_small, y_idx): 28 | """ 29 | Input: 30 | x: a batch of images 31 | x_small: a batch of images (small crops) 32 | y_idx: indices of the images 33 | Output: 34 | logits, targets 35 | """ 36 | 37 | # large crops 38 | k = x.shape[1] 39 | if k == 1: 40 | x_list = [torch.squeeze(x, 1)] 41 | else: 42 | x_list = [torch.squeeze(x[:, i, ...], 1).contiguous() for i in range(k)] 43 | 44 | feature_list = [F.normalize(self.encoder(x), dim=1) for x in x_list] 45 | 46 | sim_matrix = F.normalize(self.sim_matrix, dim=1) 47 | 48 | logit_list = [torch.einsum('nc,kc->nk', [f, sim_matrix]) for f in feature_list] 49 | 50 | logit_list = [logit / self.contrast_temperature for logit in logit_list] 51 | 52 | # small crops 53 | k_small = x_small.shape[1] 54 | if k_small == 1: 55 | x_small_list = [torch.squeeze(x_small, 1)] 56 | else: 57 | x_small_list = [torch.squeeze(x_small[:, i, ...], 1).contiguous() for i in range(k_small)] 58 | 59 | feature_small_list = [F.normalize(self.encoder(x), dim=1) for x in x_small_list] 60 | 61 | # block gradient of sim_matrix 62 | sim_matrix_detach = sim_matrix.detach() 63 | 64 | logit_small_list = [torch.einsum('nc,kc->nk', [f, sim_matrix_detach]) for f in feature_small_list] 65 | 66 | logit_small_list = [logit / self.contrast_temperature for logit in logit_small_list] 67 | 68 | logit = torch.cat(logit_list + logit_small_list, dim=0) 69 | 70 | return logit, y_idx.repeat(k + k_small) 71 | -------------------------------------------------------------------------------- /contrast/models/SimCLR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from contrast.util import dist_collect 6 | from .base import BaseModel 7 | 8 | LARGE_NUM = 1e9 9 | 10 | 11 | class SimCLR(BaseModel): 12 | """ 13 | Build a SimCLR model with: an encoder and SyncBN 14 | """ 15 | 16 | def __init__(self, base_encoder, args): 17 | """ 18 | dim: feature dimension (default: 128) 19 | T: softmax temperature (default: 0.07) 20 | """ 21 | super(SimCLR, self).__init__(base_encoder, args) 22 | 23 | self.contrast_temperature = args.contrast_temperature 24 | 25 | nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder) 26 | 27 | def forward(self, x1, x2): 28 | """ 29 | Input: 30 | x1: a batch of images under the first augmentation 31 | x2: a batch of images under the second augmentation 32 | Output: 33 | logit, label 34 | """ 35 | 36 | # compute features 37 | f1 = F.normalize(self.encoder(x1), dim=1) 38 | f2 = F.normalize(self.encoder(x2), dim=1) 39 | 40 | # gather features from all gpus 41 | batch_size_this = f1.size(0) 42 | f1_gather = dist_collect(f1) 43 | f2_gather = dist_collect(f2) 44 | batch_size_all = f1_gather.size(0) 45 | 46 | # compute mask 47 | gpu_index = torch.distributed.get_rank() 48 | label_index = torch.arange(batch_size_this) + gpu_index
* batch_size_this 49 | mask = torch.zeros(batch_size_this, batch_size_all) 50 | mask.scatter_(1, label_index.view(-1, 1), 1) 51 | mask = mask.cuda() 52 | 53 | # compute logit 54 | logit_aa = torch.mm(f1, f1_gather.T) / self.contrast_temperature 55 | logit_aa = logit_aa - mask * LARGE_NUM 56 | logit_bb = torch.mm(f2, f2_gather.T) / self.contrast_temperature 57 | logit_bb = logit_bb - mask * LARGE_NUM 58 | logit_ab = torch.mm(f1, f2_gather.T) / self.contrast_temperature 59 | logit_ba = torch.mm(f2, f1_gather.T) / self.contrast_temperature 60 | 61 | logit_a = torch.cat([logit_ab, logit_aa], dim=1) 62 | logit_b = torch.cat([logit_ba, logit_bb], dim=1) 63 | logit = torch.cat([logit_a, logit_b], dim=0) 64 | 65 | # compute label 66 | label = label_index.repeat(2).cuda() 67 | 68 | return logit, label 69 | -------------------------------------------------------------------------------- /contrast/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .InstDisc import InstDisc 2 | from .MoCo import MoCo 3 | from .PIC import PIC 4 | from .SimCLR import SimCLR 5 | 6 | __all__ = ['InstDisc', 'MoCo', 'PIC', 'SimCLR'] 7 | -------------------------------------------------------------------------------- /contrast/models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base model with: an encoder 9 | """ 10 | 11 | def __init__(self, base_encoder, args): 12 | super(BaseModel, self).__init__() 13 | 14 | # create the encoders 15 | self.encoder = base_encoder(low_dim=args.feature_dim, mlp_head=args.mlp_head) 16 | 17 | def forward(self, x1, x2): 18 | """ 19 | Input: x1, x2 or x, y_idx 20 | Output: logits, labels 21 | """ 22 | raise NotImplementedError 23 | -------------------------------------------------------------------------------- /contrast/option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from contrast import resnet 5 | from contrast.util import MyHelpFormatter 6 | 7 | 8 | model_names = sorted(name for name in resnet.__all__ 9 | if name.islower() and callable(resnet.__dict__[name])) 10 | 11 | 12 | def parse_option(stage='pre-train'): 13 | """ configs for pre-train or linear stage 14 | """ 15 | parser = argparse.ArgumentParser(f'contrast {stage} stage', formatter_class=MyHelpFormatter) 16 | 17 | # dataset 18 | parser.add_argument('--data-dir', type=str, default='./data', help='dataset directory') 19 | parser.add_argument('--crop', type=float, default=0.2 if stage == 'pre-train' else 0.08, help='minimum crop scale') 20 | parser.add_argument('--aug', type=str, default='NULL', choices=['NULL', 'InstDisc', 'MoCov2', 'SimCLR', 'RandAug', 'MultiCrop'], 21 | help='which augmentation to use.') 22 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') 23 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 24 | help='no: no cache, ' 25 | 'full: cache all data, ' 26 | 'part: shard the dataset into non-overlapping pieces and cache only one piece') 27 | 28 | parser.add_argument('--num-workers', type=int, default=4, help='number of CPU workers to use per GPU') 29 | if stage == 'linear': 30 | parser.add_argument('--total-batch-size', type=int, default=256, help='total training batch size across all GPUs') 31 | else: 32 |
parser.add_argument('--batch-size', type=int, default=64, help='batch size per GPU') 33 | # sliding window sampler 34 | parser.add_argument('--window-size', type=int, default=131072, help='window size in sliding window sampler') 35 | parser.add_argument('--window-stride', type=int, default=16384, help='window stride in sliding window sampler') 36 | parser.add_argument('--use-sliding-window-sampler', action='store_true', 37 | help='whether to use sliding window sampler') 38 | parser.add_argument('--shuffle-per-epoch', action='store_true', 39 | help='shuffle indices in sliding window sampler per epoch') 40 | # multi crop 41 | parser.add_argument('--image-size', type=int, default=224, help='crop size') 42 | parser.add_argument('--image-size2', type=int, default=96, help='small crop size (for MultiCrop)') 43 | parser.add_argument('--crop2', type=float, default=0.14, 44 | help='minimum crop scale for large crops, maximum crop scale for small crops') 45 | parser.add_argument('--num-crop', type=int, default=1, help='number of large crops') 46 | parser.add_argument('--num-crop2', type=int, default=3, help='number of small crops') 47 | 48 | # model 49 | parser.add_argument('--arch', type=str, default='resnet50', choices=model_names, 50 | help="backbone architecture") 51 | if stage == 'pre-train': 52 | parser.add_argument('--model', type=str, default='PIC', choices=['PIC', 'MoCo', 'SimCLR', 'InstDisc'], 53 | help='which model to use') 54 | parser.add_argument('--contrast-temperature', type=float, default=0.07, help='temperature in the instance classification loss') 55 | parser.add_argument('--contrast-momentum', type=float, default=0.999, 56 | help='momentum parameter used in MoCo and InstDisc') 57 | parser.add_argument('--contrast-num-negative', type=int, default=65536, 58 | help='number of negative samples used in MoCo and InstDisc') 59 | parser.add_argument('--feature-dim', type=int, default=128, help='feature dimension') 60 | parser.add_argument('--mlp-head', action='store_true', help='use an MLP head') 61 | 62 | # optimization 63 | if stage == 'pre-train': 64 | parser.add_argument('--base-learning-rate', '--base-lr', type=float, default=0.03, 65 | help='base learning rate when batch size = 256; the final lr is determined by linear scaling') 66 | parser.add_argument('--optimizer', type=str, default='sgd', choices=['sgd', 'lars'], 67 | help='optimizer in pre-train stage') 68 | else: 69 | parser.add_argument('--learning-rate', type=float, default=30, help='learning rate') 70 | parser.add_argument('--lr-scheduler', type=str, default='cosine', 71 | choices=["step", "cosine"], help="learning rate scheduler") 72 | parser.add_argument('--warmup-epoch', type=int, default=5, help='warmup epoch') 73 | parser.add_argument('--warmup-multiplier', type=int, default=100, help='warmup multiplier') 74 | parser.add_argument('--lr-decay-epochs', type=int, default=[120, 160, 200], nargs='+', 75 | help='for the step scheduler: epochs at which to decay the lr, can be a list') 76 | parser.add_argument('--lr-decay-rate', type=float, default=0.1, 77 | help='for the step scheduler:
the decay rate of the learning rate') 78 | parser.add_argument('--weight-decay', type=float, default=1e-4 if stage == 'pre-train' else 0, help='weight decay') 79 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum for SGD') 80 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 81 | help='mixed precision opt level; if O0, no amp is used') 82 | parser.add_argument('--start-epoch', type=int, default=1, help='start epoch (used when resuming)') 83 | parser.add_argument('--epochs', type=int, default=100, help='number of training epochs') 84 | 85 | # misc 86 | parser.add_argument('--output-dir', type=str, default='./output', help='output directory') 87 | parser.add_argument('--auto-resume', action='store_true', help='whether to auto-resume from current.pth') 88 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 89 | help='path to latest checkpoint') 90 | parser.add_argument('--print-freq', type=int, default=100, help='print message frequency (iteration)') 91 | parser.add_argument('--save-freq', type=int, default=10, help='save checkpoint frequency (epoch)') 92 | parser.add_argument("--local_rank", type=int, required=True, 93 | help='local rank for DistributedDataParallel, required by pytorch DDP') 94 | if stage == 'linear': 95 | parser.add_argument('--pretrained-model', type=str, required=True, help="path to the pretrained model") 96 | parser.add_argument('-e', '--eval', action='store_true', help='only evaluate') 97 | 98 | args = parser.parse_args() 99 | 100 | return args 101 | -------------------------------------------------------------------------------- /contrast/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 6 | 'resnet18_d', 'resnet34_d', 'resnet50_d', 'resnet101_d', 'resnet152_d'] 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 12 | 13 | 14 | def conv3x3_bn_relu(in_planes, out_planes, stride=1): 15 | return nn.Sequential( 16 | conv3x3(in_planes, out_planes, stride), 17 | nn.BatchNorm2d(out_planes), 18 | nn.ReLU() 19 | ) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 62 | padding=1, bias=False) 63 | self.bn2 =
nn.BatchNorm2d(planes) 64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 65 | self.bn3 = nn.BatchNorm2d(planes * 4) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | if self.downsample is not None: 85 | residual = self.downsample(x) 86 | 87 | out += residual 88 | out = self.relu(out) 89 | 90 | return out 91 | 92 | 93 | class ResNet(nn.Module): 94 | 95 | def __init__(self, block, layers, in_channel=3, width=1, 96 | mlp_head=False, mid_dim=2048, low_dim=128, 97 | avg_down=False, deep_stem=False): 98 | self.avg_down = avg_down 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | 102 | if deep_stem: 103 | self.conv1 = nn.Sequential( 104 | conv3x3_bn_relu(in_channel, 32, stride=2), 105 | conv3x3_bn_relu(32, 32, stride=1), 106 | conv3x3(32, 64, stride=1) 107 | ) 108 | else: 109 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False) 110 | 111 | self.bn1 = nn.BatchNorm2d(64) 112 | self.relu = nn.ReLU(inplace=True) 113 | 114 | self.base = int(64 * width) 115 | 116 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 117 | self.layer1 = self._make_layer(block, self.base, layers[0]) 118 | self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2) 119 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) 120 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) 121 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 122 | self.mlp_head = mlp_head 123 | 124 | in_dim = self.base * 8 * block.expansion 125 | if self.mlp_head: 126 | self.fc1 = nn.Linear(in_dim, mid_dim) 127 | self.relu2 = nn.ReLU(inplace=True) 128 | self.fc2 = nn.Linear(mid_dim, low_dim) 129 | else: 130 | self.fc = nn.Linear(in_dim, low_dim) 131 | 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 135 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 136 | elif isinstance(m, nn.BatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | 140 | # zero gamma for batch norm: reference bag of tricks 141 | if block is Bottleneck: 142 | gamma_name = "bn3.weight" 143 | elif block is BasicBlock: 144 | gamma_name = "bn2.weight" 145 | else: 146 | raise RuntimeError(f"block {block} not supported") 147 | for name, value in self.named_parameters(): 148 | if name.endswith(gamma_name): 149 | value.data.zero_() 150 | 151 | def _make_layer(self, block, planes, blocks, stride=1): 152 | downsample = None 153 | if stride != 1 or self.inplanes != planes * block.expansion: 154 | if self.avg_down: 155 | downsample = nn.Sequential( 156 | nn.AvgPool2d(kernel_size=stride, stride=stride), 157 | nn.Conv2d(self.inplanes, planes * block.expansion, 158 | kernel_size=1, stride=1, bias=False), 159 | nn.BatchNorm2d(planes * block.expansion), 160 | ) 161 | else: 162 | downsample = nn.Sequential( 163 | nn.Conv2d(self.inplanes, planes * block.expansion, 164 | kernel_size=1, stride=stride, bias=False), 165 | nn.BatchNorm2d(planes * block.expansion), 166 | ) 167 | 168 | layers = [block(self.inplanes, planes, stride, downsample)] 169 | self.inplanes = planes * block.expansion 170 | for _ in range(1, blocks): 171 | layers.append(block(self.inplanes, planes)) 172 | 173 | return nn.Sequential(*layers) 174 | 175 | def forward(self, x): 176 | x = self.conv1(x) 177 | x = self.bn1(x) 178 | x = self.relu(x) 179 | x = self.maxpool(x) 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | x = self.layer4(x) 184 | x = self.avgpool(x) 185 | x = x.view(x.size(0), -1) 186 | 187 | if self.mlp_head: 188 | x = self.fc1(x) 189 | x = self.relu2(x) 190 | x = self.fc2(x) 191 | else: 192 | x = self.fc(x) 193 | 194 | return x 195 | 196 | 197 | def resnet18(**kwargs): 198 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 199 | 200 | 201 | def resnet18_d(**kwargs): 202 | return ResNet(BasicBlock, [2, 2, 2, 2], deep_stem=True, avg_down=True, **kwargs) 203 | 204 | 205 | def resnet34(**kwargs): 206 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 207 | 208 | 209 | def resnet34_d(**kwargs): 210 | return ResNet(BasicBlock, [3, 4, 6, 3], deep_stem=True, avg_down=True, **kwargs) 211 | 212 | 213 | def resnet50(**kwargs): 214 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 215 | 216 | 217 | def resnet50_d(**kwargs): 218 | return ResNet(Bottleneck, [3, 4, 6, 3], deep_stem=True, avg_down=True, **kwargs) 219 | 220 | 221 | def resnet101(**kwargs): 222 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 223 | 224 | 225 | def resnet101_d(**kwargs): 226 | return ResNet(Bottleneck, [3, 4, 23, 3], deep_stem=True, avg_down=True, **kwargs) 227 | 228 | 229 | def resnet152(**kwargs): 230 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 231 | 232 | 233 | def resnet152_d(**kwargs): 234 | return ResNet(Bottleneck, [3, 8, 36, 3], deep_stem=True, avg_down=True, **kwargs) 235 | -------------------------------------------------------------------------------- /contrast/util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class AverageMeter(object): 8 | """Computes and stores the average and current value""" 9 | 10 | def __init__(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def 
update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def accuracy(output, target, topk=(1,)): 31 | """Computes the accuracy over the k top predictions for the specified values of k""" 32 | with torch.no_grad(): 33 | maxk = max(topk) 34 | batch_size = target.size(0) 35 | 36 | _, pred = output.topk(maxk, 1, True, True) 37 | pred = pred.t() 38 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 39 | 40 | res = [] 41 | for k in topk: 42 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: correct may be non-contiguous after t() 43 | res.append(correct_k.mul_(100.0 / batch_size)) 44 | return res 45 | 46 | 47 | def dist_collect(x): 48 | """ collect all tensor from all GPUs 49 | args: 50 | x: shape (mini_batch, ...) 51 | returns: 52 | shape (mini_batch * num_gpu, ...) 53 | """ 54 | x = x.contiguous() 55 | out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype) 56 | for _ in range(dist.get_world_size())] 57 | dist.all_gather(out_list, x) 58 | return torch.cat(out_list, dim=0) 59 | 60 | 61 | def reduce_tensor(tensor): 62 | rt = tensor.clone() 63 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 64 | rt /= dist.get_world_size() 65 | return rt 66 | 67 | 68 | class MyHelpFormatter(argparse.MetavarTypeHelpFormatter, argparse.ArgumentDefaultsHelpFormatter): 69 | pass 70 | -------------------------------------------------------------------------------- /main_linear.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.distributed as dist 8 | import torch.nn.functional as F 9 | from torch.nn.parallel import DistributedDataParallel 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | from contrast import resnet 14 | from contrast.data import get_loader 15 | from contrast.logger import setup_logger 16 | from contrast.lr_scheduler import get_scheduler 17 | from contrast.option import parse_option 18 | from contrast.util import AverageMeter, accuracy, reduce_tensor 19 | 20 | try: 21 | # noinspection PyUnresolvedReferences 22 | from apex import amp 23 | except ImportError: 24 | amp = None 25 | 26 | 27 | def build_model(args, num_class): 28 | # create model 29 | model = resnet.__dict__[args.arch](low_dim=num_class).cuda() 30 | 31 | # freeze all parameters except the last fc layer 32 | for name, p in model.named_parameters(): 33 | if 'fc' not in name: 34 | p.requires_grad = False 35 | 36 | optimizer = torch.optim.SGD(model.fc.parameters(), 37 | lr=args.learning_rate, 38 | momentum=args.momentum, 39 | weight_decay=args.weight_decay) 40 | 41 | if args.amp_opt_level != "O0": 42 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level) 43 | 44 | model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False) 45 | 46 | return model, optimizer 47 | 48 | 49 | def load_pretrained(model, pretrained_model): 50 | ckpt = torch.load(pretrained_model, map_location='cpu') 51 | model_dict = model.state_dict() 52 | state_dict = {k.replace("module.encoder.", "module."): v 53 | for k, v in ckpt['model'].items() 54 | if k.startswith('module.encoder.')} 55 | state_dict = {k: v for k, v in state_dict.items() 56 | if k in model_dict and v.size() == model_dict[k].size()} 57 | 58 | model_dict.update(state_dict) 59 | model.load_state_dict(model_dict) 60 | logger.info(f"==>
loaded checkpoint '{pretrained_model}' (epoch {ckpt['epoch']})") 61 | 62 | 63 | def load_checkpoint(args, model, optimizer, scheduler): 64 | logger.info(f"=> loading checkpoint '{args.resume}'") 65 | 66 | checkpoint = torch.load(args.resume, map_location='cpu') 67 | 68 | global best_acc1 69 | best_acc1 = checkpoint['best_acc1'] 70 | args.start_epoch = checkpoint['epoch'] + 1 71 | model.load_state_dict(checkpoint['model']) 72 | optimizer.load_state_dict(checkpoint['optimizer']) 73 | scheduler.load_state_dict(checkpoint['scheduler']) 74 | if args.amp_opt_level != "O0" and checkpoint['opt'].amp_opt_level != "O0": 75 | amp.load_state_dict(checkpoint['amp']) 76 | 77 | logger.info(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") 78 | 79 | 80 | def save_checkpoint(args, epoch, model, test_acc, optimizer, scheduler): 81 | state = { 82 | 'args': args, 83 | 'epoch': epoch, 84 | 'model': model.state_dict(), 85 | 'best_acc1': test_acc, 86 | 'optimizer': optimizer.state_dict(), 87 | 'scheduler': scheduler.state_dict(), 88 | } 89 | if args.amp_opt_level != "O0": 90 | state['amp'] = amp.state_dict() 91 | torch.save(state, os.path.join(args.output_dir, f'ckpt_epoch_{epoch}.pth')) 92 | torch.save(state, os.path.join(args.output_dir, 'current.pth')) 93 | 94 | 95 | def main(args): 96 | global best_acc1 97 | 98 | args.batch_size = args.total_batch_size // dist.get_world_size() 99 | train_loader = get_loader(args.aug, args, prefix='train') 100 | val_loader = get_loader('val', args, prefix='val') 101 | logger.info(f"length of training dataset: {len(train_loader.dataset)}") 102 | 103 | model, optimizer = build_model(args, num_class=len(train_loader.dataset.classes)) 104 | scheduler = get_scheduler(optimizer, len(train_loader), args) 105 | 106 | # load pre-trained model 107 | load_pretrained(model, args.pretrained_model) 108 | 109 | # optionally resume from a checkpoint 110 | if args.auto_resume: 111 | resume_file = os.path.join(args.output_dir, "current.pth") 112 | if os.path.exists(resume_file): 113 | logger.info(f'auto resume from {resume_file}') 114 | args.resume = resume_file 115 | else: 116 | logger.info(f'no checkpoint found in {args.output_dir}, ignoring auto resume') 117 | if args.resume: 118 | assert os.path.isfile(args.resume), f"no checkpoint found at '{args.resume}'" 119 | load_checkpoint(args, model, optimizer, scheduler) 120 | 121 | if args.eval: 122 | logger.info("==> testing...") 123 | validate(val_loader, model, args) 124 | return 125 | 126 | # tensorboard 127 | if dist.get_rank() == 0: 128 | summary_writer = SummaryWriter(log_dir=args.output_dir) 129 | else: 130 | summary_writer = None 131 | 132 | # routine 133 | for epoch in range(args.start_epoch, args.epochs + 1): 134 | if isinstance(train_loader.sampler, DistributedSampler): 135 | train_loader.sampler.set_epoch(epoch) 136 | 137 | tic = time.time() 138 | train(epoch, train_loader, model, optimizer, scheduler, args) 139 | logger.info(f'epoch {epoch}, total time {time.time() - tic:.2f}') 140 | 141 | logger.info("==> testing...") 142 | test_acc, test_acc5, test_loss = validate(val_loader, model, args) 143 | if summary_writer is not None: 144 | summary_writer.add_scalar('test_acc', test_acc, epoch) 145 | summary_writer.add_scalar('test_acc5', test_acc5, epoch) 146 | summary_writer.add_scalar('test_loss', test_loss, epoch) 147 | 148 | # save model 149 | if dist.get_rank() == 0 and epoch % args.save_freq == 0: 150 | logger.info('==> Saving...') 151 | save_checkpoint(args, epoch, model, test_acc, optimizer, scheduler) 152 | 153
| 154 | def train(epoch, train_loader, model, optimizer, scheduler, args): 155 | """ 156 | one epoch training 157 | """ 158 | 159 | model.train() 160 | 161 | batch_time = AverageMeter() 162 | data_time = AverageMeter() 163 | loss_meter = AverageMeter() 164 | acc1_meter = AverageMeter() 165 | acc5_meter = AverageMeter() 166 | 167 | end = time.time() 168 | for idx, (x, _, y) in enumerate(train_loader): 169 | x = x.cuda(non_blocking=True) 170 | y = y.cuda(non_blocking=True) 171 | 172 | # measure data loading time 173 | data_time.update(time.time() - end) 174 | 175 | # forward 176 | output = model(x) 177 | loss = F.cross_entropy(output, y) 178 | 179 | # backward 180 | optimizer.zero_grad() 181 | if args.amp_opt_level != "O0": 182 | with amp.scale_loss(loss, optimizer) as scaled_loss: 183 | scaled_loss.backward() 184 | else: 185 | loss.backward() 186 | optimizer.step() 187 | scheduler.step() 188 | 189 | # update meters 190 | acc1, acc5 = accuracy(output, y, topk=(1, 5)) 191 | loss_meter.update(loss.item(), x.size(0)) 192 | acc1_meter.update(acc1[0], x.size(0)) 193 | acc5_meter.update(acc5[0], x.size(0)) 194 | batch_time.update(time.time() - end) 195 | end = time.time() 196 | 197 | # print info 198 | if idx % args.print_freq == 0: 199 | logger.info( 200 | f'Epoch: [{epoch}][{idx}/{len(train_loader)}]\t' 201 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 202 | f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 203 | f'Lr {optimizer.param_groups[0]["lr"]:.3f} \t' 204 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 205 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 206 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})') 207 | 208 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 209 | 210 | 211 | def validate(val_loader, model, args): 212 | batch_time = AverageMeter() 213 | loss_meter = AverageMeter() 214 | acc1_meter = AverageMeter() 215 | acc5_meter = AverageMeter() 216 | 217 | # switch to evaluate mode 218 | model.eval() 219 | 220 | with torch.no_grad(): 221 | end = time.time() 222 | for idx, (x, _, y) in enumerate(val_loader): 223 | x = x.cuda(non_blocking=True) 224 | y = y.cuda(non_blocking=True) 225 | 226 | # compute output 227 | output = model(x) 228 | loss = F.cross_entropy(output, y) 229 | 230 | # measure accuracy and record loss 231 | acc1, acc5 = accuracy(output, y, topk=(1, 5)) 232 | 233 | acc1 = reduce_tensor(acc1) 234 | acc5 = reduce_tensor(acc5) 235 | loss = reduce_tensor(loss) 236 | 237 | loss_meter.update(loss.item(), x.size(0)) 238 | acc1_meter.update(acc1[0], x.size(0)) 239 | acc5_meter.update(acc5[0], x.size(0)) 240 | 241 | # measure elapsed time 242 | batch_time.update(time.time() - end) 243 | end = time.time() 244 | 245 | if idx % args.print_freq == 0: 246 | logger.info( 247 | f'Test: [{idx}/{len(val_loader)}]\t' 248 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 249 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 250 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 251 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})') 252 | 253 | logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') 254 | 255 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 256 | 257 | 258 | if __name__ == '__main__': 259 | opt = parse_option(stage='linear') 260 | 261 | if opt.amp_opt_level != "O0": 262 | assert amp is not None, "amp not installed!" 
263 | 264 | torch.cuda.set_device(opt.local_rank) 265 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 266 | cudnn.benchmark = True 267 | best_acc1 = 0 268 | 269 | os.makedirs(opt.output_dir, exist_ok=True) 270 | logger = setup_logger(output=opt.output_dir, distributed_rank=dist.get_rank(), name="contrast") 271 | if dist.get_rank() == 0: 272 | path = os.path.join(opt.output_dir, "config.json") 273 | with open(path, "w") as f: 274 | json.dump(vars(opt), f, indent=2) 275 | logger.info("Full config saved to {}".format(path)) 276 | 277 | # print args 278 | logger.info(vars(opt)) 279 | 280 | main(opt) 281 | -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import time 5 | from shutil import copyfile 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn.functional as F 10 | from torch.backends import cudnn 11 | from torch.nn.parallel import DistributedDataParallel 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | from contrast import models 16 | from contrast import resnet 17 | from contrast.data import get_loader 18 | from contrast.logger import setup_logger 19 | from contrast.lr_scheduler import get_scheduler 20 | from contrast.option import parse_option 21 | from contrast.util import AverageMeter 22 | from contrast.lars import add_weight_decay, LARS 23 | 24 | try: 25 | # noinspection PyUnresolvedReferences 26 | from apex import amp 27 | except ImportError: 28 | amp = None 29 | 30 | 31 | def build_model(args): 32 | encoder = resnet.__dict__[args.arch] 33 | model = models.__dict__[args.model](encoder, args).cuda() 34 | 35 | lr = args.batch_size * dist.get_world_size() / 256 * args.base_learning_rate 36 | if args.optimizer == 'sgd': 37 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay) 38 | elif args.optimizer == 'lars': 39 | params = add_weight_decay(model, args.weight_decay) 40 | optimizer = torch.optim.SGD(params, lr=lr, momentum=args.momentum) 41 | optimizer = LARS(optimizer) 42 | 43 | if args.amp_opt_level != "O0": 44 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level) 45 | 46 | model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False) 47 | 48 | return model, optimizer 49 | 50 | 51 | def load_checkpoint(args, model, optimizer, scheduler, sampler=None): 52 | logger.info(f"=> loading checkpoint '{args.resume}'") 53 | 54 | checkpoint = torch.load(args.resume, map_location='cpu') 55 | args.start_epoch = checkpoint['epoch'] + 1 56 | model.load_state_dict(checkpoint['model']) 57 | optimizer.load_state_dict(checkpoint['optimizer']) 58 | scheduler.load_state_dict(checkpoint['scheduler']) 59 | if args.amp_opt_level != "O0" and checkpoint['opt'].amp_opt_level != "O0": 60 | amp.load_state_dict(checkpoint['amp']) 61 | if args.use_sliding_window_sampler: 62 | sampler.load_state_dict(checkpoint['sampler']) 63 | 64 | logger.info(f"=> loaded successfully '{args.resume}' (epoch {checkpoint['epoch']})") 65 | 66 | del checkpoint 67 | torch.cuda.empty_cache() 68 | 69 | 70 | def save_checkpoint(args, epoch, model, optimizer, scheduler, sampler=None): 71 | logger.info('==> Saving...') 72 | state = { 73 | 'opt': args, 74 | 'model': model.state_dict(), 75 | 'optimizer': 
optimizer.state_dict(), 76 | 'scheduler': scheduler.state_dict(), 77 | 'epoch': epoch, 78 | } 79 | if args.amp_opt_level != "O0": 80 | state['amp'] = amp.state_dict() 81 | if args.use_sliding_window_sampler: 82 | state['sampler'] = sampler.state_dict() 83 | file_name = os.path.join(args.output_dir, f'ckpt_epoch_{epoch}.pth') 84 | torch.save(state, file_name) 85 | copyfile(file_name, os.path.join(args.output_dir, 'current.pth')) 86 | 87 | 88 | def main(args): 89 | train_loader = get_loader(args.aug, args, 90 | two_crop=args.model in ['MoCo', 'SimCLR', 'PIC'], 91 | prefix='train') 92 | args.num_instances = len(train_loader.dataset) 93 | logger.info(f"length of training dataset: {args.num_instances}") 94 | 95 | if args.use_sliding_window_sampler: 96 | args.warmup_epoch = math.ceil(args.warmup_epoch * len(train_loader.dataset) / args.window_size) 97 | args.epochs = math.ceil(args.epochs * len(train_loader.dataset) / args.window_size) 98 | 99 | model, optimizer = build_model(args) 100 | scheduler = get_scheduler(optimizer, len(train_loader), args) 101 | 102 | # optionally resume from a checkpoint 103 | if args.auto_resume: 104 | resume_file = os.path.join(args.output_dir, "current.pth") 105 | if os.path.exists(resume_file): 106 | logger.info(f'auto resume from {resume_file}') 107 | args.resume = resume_file 108 | else: 109 | logger.info(f'no checkpoint found in {args.output_dir}, ignoring auto resume') 110 | if args.resume: 111 | assert os.path.isfile(args.resume) 112 | load_checkpoint(args, model, optimizer, scheduler, sampler=train_loader.sampler) 113 | 114 | # tensorboard 115 | if dist.get_rank() == 0: 116 | summary_writer = SummaryWriter(log_dir=args.output_dir) 117 | else: 118 | summary_writer = None 119 | 120 | for epoch in range(args.start_epoch, args.epochs + 1): 121 | if isinstance(train_loader.sampler, DistributedSampler): 122 | train_loader.sampler.set_epoch(epoch) 123 | 124 | train(epoch, train_loader, model, optimizer, scheduler, args, summary_writer) 125 | 126 | if dist.get_rank() == 0 and (epoch % args.save_freq == 0 or epoch == args.epochs): 127 | save_checkpoint(args, epoch, model, optimizer, scheduler, sampler=train_loader.sampler) 128 | 129 | 130 | def train(epoch, train_loader, model, optimizer, scheduler, args, summary_writer): 131 | """ 132 | one epoch training 133 | """ 134 | model.train() 135 | 136 | batch_time = AverageMeter() 137 | loss_meter = AverageMeter() 138 | 139 | end = time.time() 140 | for idx, data in enumerate(train_loader): 141 | data = [item.cuda(non_blocking=True) for item in data] 142 | 143 | if args.model == 'PIC': 144 | # for PIC with multi crop, data[0], data[1], data[2] are x_large, x_small, y_idx 145 | logit, label = model(data[0], data[1], data[2]) 146 | else: 147 | # for moco and SimCLR, two_crop=True, data[0] and data[1] are x1, x2 148 | # for InstDisc, two_crop=False, data[0] and data[1] are x, y_idx 149 | logit, label = model(data[0], data[1]) 150 | 151 | loss = F.cross_entropy(logit, label) 152 | 153 | # backward 154 | optimizer.zero_grad() 155 | if args.amp_opt_level != "O0": 156 | with amp.scale_loss(loss, optimizer) as scaled_loss: 157 | scaled_loss.backward() 158 | else: 159 | loss.backward() 160 | optimizer.step() 161 | scheduler.step() 162 | 163 | # update meters and print info 164 | loss_meter.update(loss.item(), data[0].size(0)) 165 | batch_time.update(time.time() - end) 166 | end = time.time() 167 | 168 | train_len = len(train_loader) 169 | if idx % args.print_freq == 0: 170 | lr = optimizer.param_groups[0]['lr'] 171 | logger.info( 
172 | f'Train: [{epoch}/{args.epochs}][{idx}/{train_len}] ' 173 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 174 | f'lr {lr:.3f} ' 175 | f'loss {loss_meter.val:.3f} ({loss_meter.avg:.3f})') 176 | 177 | # tensorboard logger 178 | if summary_writer is not None: 179 | step = (epoch - 1) * len(train_loader) + idx 180 | summary_writer.add_scalar('lr', lr, step) 181 | summary_writer.add_scalar('loss', loss_meter.val, step) 182 | 183 | 184 | if __name__ == '__main__': 185 | opt = parse_option(stage='pre-train') 186 | 187 | if opt.amp_opt_level != "O0": 188 | assert amp is not None, "amp not installed!" 189 | 190 | torch.cuda.set_device(opt.local_rank) 191 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 192 | cudnn.benchmark = True 193 | 194 | # setup logger 195 | os.makedirs(opt.output_dir, exist_ok=True) 196 | logger = setup_logger(output=opt.output_dir, distributed_rank=dist.get_rank(), name="contrast") 197 | if dist.get_rank() == 0: 198 | path = os.path.join(opt.output_dir, "config.json") 199 | with open(path, 'w') as f: 200 | json.dump(vars(opt), f, indent=2) 201 | logger.info("Full config saved to {}".format(path)) 202 | 203 | # print args 204 | logger.info(vars(opt)) 205 | 206 | main(opt) 207 | -------------------------------------------------------------------------------- /scripts/InstDisc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | epochs=${1:-200} 7 | warmup=$(( epochs / 40 )) 8 | 9 | data_dir="./data/ImageNet-Zip" 10 | output_dir="./output/InstDisc/epoch-${epochs}" 11 | 12 | python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ 13 | main_pretrain.py \ 14 | --data-dir ${data_dir} \ 15 | --crop 0.2 \ 16 | --aug InstDisc \ 17 | --zip --cache-mode part \ 18 | --model InstDisc \ 19 | --contrast-temperature 0.07 \ 20 | --contrast-momentum 0.5 \ 21 | --mlp-head \ 22 | --warmup-epoch ${warmup} \ 23 | --epochs ${epochs} \ 24 | --output-dir "${output_dir}" \ 25 | --save-freq 10 26 | 27 | python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ 28 | main_linear.py \ 29 | --data-dir "${data_dir}" \ 30 | --zip --cache-mode part \ 31 | --output-dir "${output_dir}/eval" \ 32 | --pretrained-model "${output_dir}/current.pth" \ 33 | -------------------------------------------------------------------------------- /scripts/MoCov1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | epochs=${1:-200} 7 | warmup=$(( epochs / 40 )) 8 | 9 | data_dir="./data/ImageNet-Zip" 10 | output_dir="./output/MoCov1/epoch-${epochs}" 11 | 12 | python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ 13 | main_pretrain.py \ 14 | --data-dir ${data_dir} \ 15 | --crop 0.2 \ 16 | --aug InstDisc \ 17 | --zip --cache-mode part \ 18 | --model MoCo \ 19 | --contrast-temperature 0.07 \ 20 | --lr-scheduler step \ 21 | --warmup-epoch ${warmup} \ 22 | --epochs ${epochs} \ 23 | --output-dir "${output_dir}" \ 24 | --save-freq 10 25 | 26 | python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ 27 | main_linear.py \ 28 | --data-dir "${data_dir}" \ 29 | --zip --cache-mode part \ 30 | --output-dir "${output_dir}/eval" \ 31 | --pretrained-model "${output_dir}/current.pth" \ 32 | -------------------------------------------------------------------------------- /scripts/MoCov2.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | epochs=${1:-200} 7 | warmup=$(( epochs / 40 )) 8 | 9 | data_dir="./data/ImageNet-Zip" 10 | output_dir="./output/MoCo/epoch-${epochs}" 11 | 12 | python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ 13 | main_pretrain.py \ 14 | --data-dir ${data_dir} \ 15 | --crop 0.2 \ 16 | --aug MoCov2 \ 17 | --zip --cache-mode part \ 18 | --model MoCo \ 19 | --contrast-temperature 0.2 \ 20 | --mlp-head \ 21 | --warmup-epoch ${warmup} \ 22 | --epochs ${epochs} \ 23 | --output-dir ${output_dir} \ 24 | --save-freq 10 25 | 26 | python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ 27 | main_linear.py \ 28 | --data-dir "${data_dir}" \ 29 | --zip --cache-mode part \ 30 | --output-dir "${output_dir}/eval" \ 31 | --pretrained-model "${output_dir}/current.pth" \ 32 | -------------------------------------------------------------------------------- /scripts/PIC.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # script for PIC 4 | 5 | set -e 6 | set -x 7 | 8 | epochs=${1:-400} 9 | warmup=$(( epochs / 40 )) 10 | 11 | data_dir="./data/ImageNet-Zip" 12 | output_dir="./output/PIC/epoch-${epochs}" 13 | 14 | python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ 15 | main_pretrain.py \ 16 | --data-dir ${data_dir} \ 17 | --crop 0.08 \ 18 | --aug MultiCrop \ 19 | --zip --cache-mode part \ 20 | --arch resnet50 \ 21 | --model PIC \ 22 | --contrast-temperature 0.2 \ 23 | --mlp-head \ 24 | --warmup-epoch ${warmup} \ 25 | --epochs ${epochs} \ 26 | --output-dir "${output_dir}" \ 27 | \ 28 | --window-size 131072 \ 29 | --window-stride 16384 \ 30 | --use-sliding-window-sampler \ 31 | --shuffle-per-epoch \ 32 | \ 33 | --crop2 0.14 \ 34 | --image-size 160 \ 35 | --image-size2 96 \ 36 | --num-crop 1 \ 37 | --num-crop2 3 \ 38 | 39 | python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ 40 | main_linear.py \ 41 | --data-dir "${data_dir}" \ 42 | --zip --cache-mode part \ 43 | --arch resnet50 \ 44 | --output-dir "${output_dir}/eval" \ 45 | --pretrained-model "${output_dir}/current.pth" \ 46 | -------------------------------------------------------------------------------- /scripts/SimCLR.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | epochs=${1:-200} 7 | warmup=$(( epochs / 40 )) 8 | 9 | data_dir="./data/ImageNet-Zip" 10 | output_dir="./output/SimCLR" 11 | 12 | python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ 13 | main_pretrain.py \ 14 | --data-dir ${data_dir} \ 15 | --crop 0.08 \ 16 | --aug SimCLR \ 17 | --zip --cache-mode part \ 18 | --model SimCLR \ 19 | --optimizer "lars" \ 20 | --contrast-temperature 0.1 \ 21 | --mlp-head \ 22 | --base-lr 0.3 \ 23 | --warmup-epoch ${warmup} \ 24 | --weight-decay 1e-6 \ 25 | --epochs ${epochs} \ 26 | --output-dir "${output_dir}" \ 27 | --save-freq 10 \ 28 | 29 | 30 | python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ 31 | main_linear.py \ 32 | --data-dir "${data_dir}" \ 33 | --zip --cache-mode part \ 34 | --learning-rate 0.1 \ 35 | --weight-decay 1e-6 \ 36 | --output-dir "${output_dir}/eval" \ 37 | --pretrained-model "${output_dir}/current.pth" \ 38 | --------------------------------------------------------------------------------
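
The listings above end the flattened repository. A few usage sketches follow; none of this code is part of the repo — it only illustrates, with toy or default-valued numbers, what the components above compute.

The schedule built by `get_scheduler` in `contrast/lr_scheduler.py` steps once per iteration: `GradualWarmupScheduler` ramps the learning rate linearly from `base_lr / multiplier` up to `base_lr` over the warmup iterations, then hands off to `CosineAnnealingLR`. Below is a minimal closed-form sketch of that schedule; `base_lr`, `multiplier`, and `eta_min` match the repo defaults (0.03, 100, 1e-6), while the iteration counts are purely illustrative.

```python
import math

base_lr, multiplier = 0.03, 100                        # repo defaults (option.py)
warmup_iters, total_iters, eta_min = 500, 10000, 1e-6  # iteration counts are illustrative

def lr_at(t):
    if t <= warmup_iters:
        # linear warmup from base_lr / multiplier up to base_lr
        return base_lr / multiplier * ((multiplier - 1.) * t / warmup_iters + 1.)
    # cosine annealing over the remaining iterations
    s = (t - warmup_iters) / (total_iters - warmup_iters)
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * s)) / 2

print(lr_at(0))             # 0.0003 = base_lr / multiplier
print(lr_at(warmup_iters))  # 0.03   = base_lr (warmup ends where the cosine starts)
print(lr_at(total_iters))   # ~1e-6  = eta_min
```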
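`InstDisc.forward` reduces instance discrimination to a (K+1)-way classification per image: K memory rows are sampled at random as negatives, column 0 is overwritten with the positive row, so the target is always class 0. A minimal CPU sketch of those shapes, with toy sizes and random features standing in for encoder outputs:

```python
import torch
import torch.nn.functional as F

bs, dim, num_instances, K, T = 4, 128, 1000, 16, 0.07  # toy sizes; T is the repo default
feature = F.normalize(torch.randn(bs, dim), dim=1)     # stand-in for encoder output
memory = F.normalize(torch.randn(num_instances, dim), dim=1)
y_idx = torch.randint(num_instances, (bs,))

idx = torch.randint(num_instances, (bs, K + 1))        # random negatives
idx[:, 0] = y_idx                                      # column 0 holds the positive instance
weight = memory[idx]                                   # (bs, K+1, dim)
logits = torch.einsum('bd,bkd->bk', [feature, weight]) / T
loss = F.cross_entropy(logits, torch.zeros(bs, dtype=torch.long))  # target is always column 0
```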
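The MoCo queue in `_dequeue_and_enqueue` is a ring buffer over columns: each step overwrites `batch_size` columns starting at `queue_ptr` and advances the pointer modulo the queue length, which is why the assert requires `contrast_num_negative` to be divisible by the (global) batch size. A single-process sketch with toy sizes (`dist_collect` is dropped, so the local batch plays the role of the gathered one):

```python
import torch
import torch.nn.functional as F

dim, K, bs = 128, 8, 4                       # toy sizes; K % bs == 0, as the assert requires
queue = F.normalize(torch.randn(dim, K), dim=0)
ptr = 0

for _ in range(3):
    keys = F.normalize(torch.randn(bs, dim), dim=1)
    queue[:, ptr:ptr + bs] = keys.T          # enqueue the newest keys ...
    ptr = (ptr + bs) % K                     # ... which implicitly dequeues the oldest
print(ptr)  # 4: the pointer wrapped once (0 -> 4 -> 0 -> 4)
```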
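`PIC.forward` concatenates the per-crop logits along the batch dimension, so an image seen in `k` large and `k_small` small crops contributes `k + k_small` rows that all share its instance index as target; that is exactly what `y_idx.repeat(k + k_small)` produces. A toy sketch of the label layout (hypothetical batch and instance counts; the crop counts match scripts/PIC.sh):

```python
import torch

bs, num_instances, k, k_small = 2, 10, 1, 3   # 1 large + 3 small crops, as in scripts/PIC.sh
y_idx = torch.tensor([3, 7])                  # instance indices of the two images
logit_list = [torch.randn(bs, num_instances) for _ in range(k + k_small)]  # one block per crop
logit = torch.cat(logit_list, dim=0)          # ((k + k_small) * bs, num_instances)
label = y_idx.repeat(k + k_small)             # tensor([3, 7, 3, 7, 3, 7, 3, 7])
```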
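In `SimCLR.forward`, each row's similarity to itself would dominate `logit_aa`/`logit_bb`, so those entries are pushed down by `LARGE_NUM` before the softmax; the positive for row `i` then sits at column `i` of `logit_ab`. A single-GPU sketch (one process, so the gathered features equal the local ones and the mask is just the identity); only the first of the two symmetric halves is shown:

```python
import torch
import torch.nn.functional as F

LARGE_NUM, T, bs, dim = 1e9, 0.1, 4, 128      # T as in scripts/SimCLR.sh; toy bs and dim
f1 = F.normalize(torch.randn(bs, dim), dim=1)
f2 = F.normalize(torch.randn(bs, dim), dim=1)

mask = torch.eye(bs)                          # self-pairs of the same view
logit_aa = f1 @ f1.T / T - mask * LARGE_NUM   # same-view similarities, self masked out
logit_ab = f1 @ f2.T / T                      # cross-view similarities, positives on the diagonal
logit = torch.cat([logit_ab, logit_aa], dim=1)
label = torch.arange(bs)                      # positive of row i is column i (inside logit_ab)
loss = F.cross_entropy(logit, label)
```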
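Finally, the linear learning-rate scaling in `build_model` of main_pretrain.py: the base rate is defined at a total batch size of 256 and scaled proportionally with the actual global batch. With the setup used throughout `scripts/` (4 processes, 64 images per GPU), the scale factor is exactly 1:

```python
batch_size, world_size, base_lr = 64, 4, 0.03  # per-GPU batch, nproc_per_node, default base lr
lr = batch_size * world_size / 256 * base_lr   # same formula as build_model in main_pretrain.py
print(lr)  # 0.03
```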