├── Dockerfile
├── LICENSE
├── README.md
├── __init__.py
├── __pycache__
│   ├── arguments.cpython-38.pyc
│   ├── linear_cub_eval.cpython-38.pyc
│   ├── linear_eval.cpython-38.pyc
│   └── linear_stanfordcars_eval.cpython-38.pyc
├── arguments.py
├── augmentations
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── byol_aug.cpython-38.pyc
│   │   ├── eval_aug.cpython-38.pyc
│   │   └── gaussian_blur.cpython-38.pyc
│   ├── byol_aug.py
│   ├── eval_aug.py
│   └── gaussian_blur.py
├── classifier.py
├── configs
│   ├── __init__.py
│   ├── byol_aircrafts
│   ├── byol_aircrafts.yaml
│   ├── byol_aircrafts_eval.yaml
│   ├── byol_cifar.yaml
│   ├── byol_cub200.yaml
│   └── byol_stanfordcars.yaml
├── datasets
│   ├── CUB200.py
│   ├── CUB200_val.py
│   ├── CUB2011.py
│   ├── ImageNet100.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── CUB2011.cpython-38.pyc
│   │   ├── ImageNet100.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   └── random_dataset.cpython-38.pyc
│   └── random_dataset.py
├── examples
│   └── framework.png
├── hand_detector.py
├── linear_cub_eval.py
├── linear_eval.py
├── linear_imagenet100_eval.py
├── linear_stanfordcars_eval.py
├── loader.py
├── main.py
├── main_lincls.py
├── main_moco.py
├── moco
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── builder.cpython-38.pyc
│   │   └── loader.cpython-38.pyc
│   ├── builder.py
│   └── loader.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── byol.cpython-38.pyc
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── cifar_resnet_1.cpython-38.pyc
│   │   │   ├── cifar_resnet_2.cpython-38.pyc
│   │   │   └── cub_resnet_1.cpython-38.pyc
│   │   ├── cifar_resnet_1.py
│   │   ├── cifar_resnet_2.py
│   │   └── cub_resnet_1.py
│   └── byol.py
├── optimizers
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── larc.cpython-38.pyc
│   │   ├── lars.cpython-38.pyc
│   │   └── lr_scheduler.cpython-38.pyc
│   ├── larc.py
│   ├── lars.py
│   ├── lars_simclr.py
│   └── lr_scheduler.py
├── requirements.txt
├── resnet_output.py
├── run_all.sh
├── simsiam-800e90.83acc.svg
└── tools
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-38.pyc
    │   ├── accuracy.cpython-38.pyc
    │   ├── average_meter.cpython-38.pyc
    │   ├── file_exist_fn.cpython-38.pyc
    │   ├── knn_monitor.cpython-38.pyc
    │   ├── logger.cpython-38.pyc
    │   └── plotter.cpython-38.pyc
    ├── accuracy.py
    ├── average_meter.py
    ├── file_exist_fn.py
    ├── knn_monitor.py
    ├── logger.py
    └── plotter.py
/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime 2 | RUN apt-get update && apt-get install -y git 3 | RUN pip install tensorboardX 4 | RUN pip install ttach 5 | RUN pip install pandas 6 | RUN pip install matplotlib 7 | RUN pip install opencv-python 8 | RUN pip install google-cloud-storage -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Common Rationale to Improve Self-Supervised Representation for Fine-Grained Visual Recognition Problems 2 | 3 | This project contains the implementation of learning common rationale to improve self-supervised representation for fine-grained visual recognition, as presented in our paper: 4 | 5 | > Learning Common Rationale to Improve Self-Supervised Representation for Fine-Grained Visual Recognition Problems, 6 | > Yangyang Shu, Anton van den Hengel and Lingqiao Liu* 7 | >
*CVPR 2023* 8 | 9 | ## Datasets 10 | | Dataset | Download Link | 11 | | -- | -- | 12 | | CUB-200-2011 | https://paperswithcode.com/dataset/cub-200-2011 | 13 | | Stanford Cars | http://ai.stanford.edu/~jkrause/cars/car_dataset.html | 14 | | FGVC Aircraft | http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/ | 15 | 16 | 17 | Please download and organize the datasets in this structure: 18 | ``` 19 | LCR 20 | ├── CUB200/ 21 | │   ├── train/ 22 | │   └── test/ 23 | ├── StanfordCars/ 24 | │   ├── train/ 25 | │   └── test/ 26 | └── Aircraft/ 27 |     ├── train/ 28 |     └── test/ 29 | ``` 30 | 31 | ## For BYOL 32 | Install the required packages: 33 | ``` 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | - The running commands for several datasets are shown below. You can also refer to ``run_all.sh``. 38 | ``` 39 | python main.py --data_dir ./CUB200 --log_dir ./logs/ -c configs/byol_cub200.yaml --ckpt_dir ./.cache/ --hide_progress 40 | python main.py --data_dir ./StanfordCars --log_dir ./logs/ -c configs/byol_stanfordcars.yaml --ckpt_dir ./.cache/ --hide_progress 41 | python main.py --data_dir ./Aircraft --log_dir ./logs/ -c configs/byol_aircrafts.yaml --ckpt_dir ./.cache/ --hide_progress 42 | 43 | ``` 44 | 45 | ## For MoCo v2 46 | 47 | - The running commands for pre-training and retrieval: 48 | ``` 49 | python main_moco.py --epochs 100 -a resnet50 --lr 0.03 --batch-size 128 --multiprocessing-distributed --world-size 1 --rank 0 Aircraft --mlp --moco-t 0.2 --aug-plus --cos 50 | python main_moco.py --epochs 100 -a resnet50 --lr 0.03 --batch-size 128 --multiprocessing-distributed --world-size 1 --rank 0 StanfordCars --mlp --moco-t 0.2 --aug-plus --cos 51 | python main_moco.py --epochs 100 -a resnet50 --lr 0.03 --batch-size 128 --multiprocessing-distributed --world-size 1 --rank 0 CUB200 --mlp --moco-t 0.2 --aug-plus --cos 52 | ``` 53 | 54 | - The running commands for linear probing: 55 | ``` 56 | python main_lincls.py -a resnet50 --lr 30.0 --batch-size 256 --pretrained [your checkpoint path]/checkpoint_****.pth.tar --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 Aircraft --class_num 100 57 | python main_lincls.py -a resnet50 --lr 30.0 --batch-size 256 --pretrained [your checkpoint path]/checkpoint_****.pth.tar --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 StanfordCars --class_num 196 58 | python main_lincls.py -a resnet50 --lr 30.0 --batch-size 256 --pretrained [your checkpoint path]/checkpoint_****.pth.tar --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 CUB200 --class_num 200 59 | ``` 60 | 61 | ## Citation 62 | If you find this code or idea useful, please cite our work: 63 | ``` 64 | @inproceedings{shu2023learning, 65 | title={Learning Common Rationale to Improve Self-Supervised Representation for Fine-Grained Visual Recognition Problems}, 66 | author={Shu, Yangyang and van den Hengel, Anton and Liu, Lingqiao}, 67 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 68 | pages={11392--11401}, 69 | year={2023} 70 | } 71 | ``` 72 | 73 | 74 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | ## 2 | -------------------------------------------------------------------------------- /__pycache__/arguments.cpython-38.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/__pycache__/arguments.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/linear_cub_eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/__pycache__/linear_cub_eval.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/linear_eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/__pycache__/linear_eval.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/linear_stanfordcars_eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/__pycache__/linear_stanfordcars_eval.cpython-38.pyc -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | 5 | import numpy as np 6 | import torch 7 | import random 8 | 9 | import re 10 | import yaml 11 | 12 | import shutil 13 | import warnings 14 | 15 | from datetime import datetime 16 | 17 | 18 | class Namespace(object): 19 | def __init__(self, somedict): 20 | for key, value in somedict.items(): 21 | assert isinstance(key, str) and re.match("[A-Za-z_-]", key) 22 | if isinstance(value, dict): 23 | self.__dict__[key] = Namespace(value) 24 | else: 25 | self.__dict__[key] = value 26 | 27 | def __getattr__(self, attribute): 28 | 29 | raise AttributeError(f"Cannot find {attribute} in namespace. Please set {attribute} in your config file (xxx.yaml)!") 30 | 31 | 32 | def set_deterministic(seed): 33 | # seed by default is None 34 | if seed is not None: 35 | print(f"Deterministic with seed = {seed}") 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed(seed) 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | 43 | def get_args(): 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('-c', '--config-file', required=True, type=str, help="xxx.yaml") 46 | parser.add_argument('--debug', action='store_true') 47 | parser.add_argument('--debug_subset_size', type=int, default=8) 48 | parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web") 49 | parser.add_argument('--data_dir', type=str, default=os.getenv('DATA')) 50 | parser.add_argument('--log_dir', type=str, default=os.getenv('LOG')) 51 | parser.add_argument('--ckpt_dir', type=str, default=os.getenv('CHECKPOINT')) 52 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 53 | parser.add_argument('--eval_from', type=str, default=None) 54 | parser.add_argument('--hide_progress', action='store_true') 55 | args = parser.parse_args() 56 | 57 | 58 | with open(args.config_file, 'r') as f: 59 | for key, value in Namespace(yaml.load(f, Loader=yaml.FullLoader)).__dict__.items(): 60 | vars(args)[key] = value 61 | 62 | if args.debug: 63 | if args.train: 64 | args.train.batch_size = 2 65 | args.train.num_epochs = 1 66 | args.train.stop_at_epoch = 1 67 | if args.eval: 68 | args.eval.batch_size = 2 69 | args.eval.num_epochs = 1 # train only one epoch 70 | args.dataset.num_workers = 0 71 | 72 | 73 | assert None not in [args.log_dir, args.data_dir, args.ckpt_dir, args.name] 74 | 75 | args.log_dir = os.path.join(args.log_dir, 'in-progress_'+datetime.now().strftime('%m%d%H%M%S_')+args.name) 76 | 77 | os.makedirs(args.log_dir, exist_ok=False) 78 | print(f'creating log directory {args.log_dir}') 79 | os.makedirs(args.ckpt_dir, exist_ok=True) 80 | 81 | shutil.copy2(args.config_file, args.log_dir) 82 | set_deterministic(args.seed) 83 | 84 | 85 | vars(args)['aug_kwargs'] = { 86 | 'name':args.model.name, 87 | 'image_size': args.dataset.image_size 88 | } 89 | vars(args)['dataset_kwargs'] = { 90 | 'dataset':args.dataset.name, 91 | 'data_dir': args.data_dir, 92 | 'download':args.download, 93 | 'debug_subset_size': args.debug_subset_size if args.debug else None, 94 | } 95 | vars(args)['dataloader_kwargs'] = { 96 | 'drop_last': True, 97 | 'pin_memory': True, 98 | 'num_workers': args.dataset.num_workers, 99 | } 100 | 101 | return args 102 | -------------------------------------------------------------------------------- /augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval_aug import Transform_single 2 | from .byol_aug import BYOL_transform 3 | 4 | def get_aug(name='byol', image_size=224, train=True, train_classifier=None): 5 | 6 | if train==True: 7 | if name == 'byol': 8 | augmentation = BYOL_transform(image_size) 9 | else: 10 | raise NotImplementedError 11 | elif train==False: 12 | if train_classifier is None: 13 | raise Exception 14 | augmentation = Transform_single(image_size, train=train_classifier) 15 | else: 16 | raise Exception 17 | 18 | return augmentation 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | --------------------------------------------------------------------------------
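For orientation, here is a minimal sketch of how `get_aug` above combines with `get_dataset` and a `DataLoader`. It mirrors the pattern used in `linear_eval.py` further below and assumes the parsed `args` object produced by `arguments.get_args()`:

```python
# Usage sketch only; `args` is assumed to come from arguments.get_args().
import torch
from arguments import get_args
from augmentations import get_aug
from datasets import get_dataset

args = get_args()  # parses the CLI flags and merges in the YAML config

train_loader = torch.utils.data.DataLoader(
    dataset=get_dataset(
        # single-view eval transform; get_aug(train=True, ...) would instead
        # return the two-view BYOL_transform used for pre-training
        transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs),
        train=True,
        **args.dataset_kwargs,
    ),
    batch_size=args.eval.batch_size,
    shuffle=True,
    **args.dataloader_kwargs,  # drop_last, pin_memory, num_workers from get_args()
)
```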
/augmentations/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/augmentations/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /augmentations/__pycache__/byol_aug.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/augmentations/__pycache__/byol_aug.cpython-38.pyc -------------------------------------------------------------------------------- /augmentations/__pycache__/eval_aug.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/augmentations/__pycache__/eval_aug.cpython-38.pyc -------------------------------------------------------------------------------- /augmentations/__pycache__/gaussian_blur.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/augmentations/__pycache__/gaussian_blur.cpython-38.pyc -------------------------------------------------------------------------------- /augmentations/byol_aug.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image, ImageOps 3 | try: 4 | from torchvision.transforms import GaussianBlur 5 | except ImportError: 6 | from .gaussian_blur import GaussianBlur 7 | transforms.GaussianBlur = GaussianBlur # register the fallback on the imported transforms module (the bare name `torchvision` is never imported in this file) 8 | 9 | imagenet_norm = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]] 10 | 11 | class BYOL_transform: # Table 6 12 | def __init__(self, image_size, normalize=imagenet_norm): 13 | 14 | self.transform1 = transforms.Compose([ 15 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 16 | transforms.RandomHorizontalFlip(p=0.5), 17 | transforms.RandomApply([transforms.ColorJitter(0.4,0.4,0.2,0.1)], p=0.8), 18 | transforms.RandomGrayscale(p=0.2), 19 | transforms.GaussianBlur(kernel_size=image_size//20*2+1, sigma=(0.1, 2.0)), # the SimCLR paper gives the kernel size;
the kernel size must be an odd positive number in torchvision 20 | transforms.ToTensor(), 21 | transforms.Normalize(*normalize) 22 | ]) 23 | self.transform2 = transforms.Compose([ 24 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 25 | transforms.RandomHorizontalFlip(p=0.5), 26 | transforms.RandomApply([transforms.ColorJitter(0.4,0.4,0.2,0.1)], p=0.8), 27 | transforms.RandomGrayscale(p=0.2), 28 | # transforms.RandomApply([GaussianBlur(kernel_size=int(0.1 * image_size))], p=0.1), 29 | transforms.RandomApply([transforms.GaussianBlur(kernel_size=image_size//20*2+1, sigma=(0.1, 2.0))], p=0.1), 30 | transforms.RandomApply([Solarization()], p=0.2), 31 | 32 | transforms.ToTensor(), 33 | transforms.Normalize(*normalize) 34 | ]) 35 | 36 | 37 | def __call__(self, x): 38 | x1 = self.transform1(x) 39 | x2 = self.transform2(x) 40 | return x1, x2 41 | 42 | 43 | class Transform_single: 44 | def __init__(self, image_size, train, normalize=imagenet_norm): 45 | # self.denormalize = Denormalize(*imagenet_norm) # disabled: Denormalize is not defined anywhere in this repo and would raise a NameError 46 | if train == True: 47 | self.transform = transforms.Compose([ 48 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 49 | transforms.RandomHorizontalFlip(), 50 | transforms.ToTensor(), 51 | transforms.Normalize(*normalize) 52 | ]) 53 | else: 54 | self.transform = transforms.Compose([ 55 | transforms.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256 56 | transforms.CenterCrop(image_size), 57 | transforms.ToTensor(), 58 | transforms.Normalize(*normalize) 59 | ]) 60 | 61 | def __call__(self, x): 62 | return self.transform(x) 63 | 64 | 65 | 66 | class Solarization(): 67 | # ImageFilter 68 | def __init__(self, threshold=128): 69 | self.threshold = threshold 70 | def __call__(self, image): 71 | return ImageOps.solarize(image, self.threshold) 72 | 73 | 74 | -------------------------------------------------------------------------------- /augmentations/eval_aug.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | 4 | imagenet_norm = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]] 5 | 6 | class Transform_single(): 7 | def __init__(self, image_size, train, normalize=imagenet_norm): 8 | if train == True: 9 | self.transform = transforms.Compose([ 10 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.ToTensor(), 13 | transforms.Normalize(*normalize) 14 | ]) 15 | else: 16 | self.transform = transforms.Compose([ 17 | transforms.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256 18 | transforms.CenterCrop(image_size), 19 | transforms.ToTensor(), 20 | transforms.Normalize(*normalize) 21 | ]) 22 | 23 | def __call__(self, x): 24 | return self.transform(x) 25 | -------------------------------------------------------------------------------- /augmentations/gaussian_blur.py: -------------------------------------------------------------------------------- 1 | """ 2 | Only recent torchvision releases include a GaussianBlur transform,
3 | so the needed functions are copied here. 4 | 5 | """ 6 | 7 | 8 | import torch 9 | from torch import Tensor 10 | from torchvision.transforms.functional import to_pil_image, to_tensor 11 | from torch.nn.functional import conv2d, pad as torch_pad 12 | from typing import Any, List, Sequence, Optional 13 | import numbers 14 | import numpy as np 15 | import torch 16 | from PIL import Image 17 | from typing import Tuple 18 | 19 | class GaussianBlur(torch.nn.Module): 20 | """Blurs image with randomly chosen Gaussian blur. 21 | The image can be a PIL Image or a Tensor, in which case it is expected 22 | to have [..., C, H, W] shape, where ... means an arbitrary number of leading 23 | dimensions 24 | 25 | Args: 26 | kernel_size (int or sequence): Size of the Gaussian kernel. 27 | sigma (float or tuple of float (min, max)): Standard deviation to be used for 28 | creating kernel to perform blurring. If float, sigma is fixed. If it is tuple 29 | of float (min, max), sigma is chosen uniformly at random to lie in the 30 | given range. 31 | 32 | Returns: 33 | PIL Image or Tensor: Gaussian blurred version of the input image. 34 | 35 | """ 36 | 37 | def __init__(self, kernel_size, sigma=(0.1, 2.0)): 38 | super().__init__() 39 | self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") 40 | for ks in self.kernel_size: 41 | if ks <= 0 or ks % 2 == 0: 42 | raise ValueError("Kernel size value should be an odd and positive number.") 43 | 44 | if isinstance(sigma, numbers.Number): 45 | if sigma <= 0: 46 | raise ValueError("If sigma is a single number, it must be positive.") 47 | sigma = (sigma, sigma) 48 | elif isinstance(sigma, Sequence) and len(sigma) == 2: 49 | if not 0. < sigma[0] <= sigma[1]: 50 | raise ValueError("sigma values should be positive and of the form (min, max).") 51 | else: 52 | raise ValueError("sigma should be a single number or a list/tuple with length 2.") 53 | 54 | self.sigma = sigma 55 | 56 | @staticmethod 57 | def get_params(sigma_min: float, sigma_max: float) -> float: 58 | """Choose sigma for random gaussian blurring. 59 | 60 | Args: 61 | sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel. 62 | sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel. 63 | 64 | Returns: 65 | float: Standard deviation to be passed to calculate kernel for gaussian blurring. 66 | """ 67 | return torch.empty(1).uniform_(sigma_min, sigma_max).item() 68 | 69 | def forward(self, img: Tensor) -> Tensor: 70 | """ 71 | Args: 72 | img (PIL Image or Tensor): image to be blurred.
73 | 74 | Returns: 75 | PIL Image or Tensor: Gaussian blurred image 76 | """ 77 | sigma = self.get_params(self.sigma[0], self.sigma[1]) 78 | return gaussian_blur(img, self.kernel_size, [sigma, sigma]) 79 | 80 | def __repr__(self): 81 | s = '(kernel_size={}, '.format(self.kernel_size) 82 | s += 'sigma={})'.format(self.sigma) 83 | return self.__class__.__name__ + s 84 | 85 | @torch.jit.unused 86 | def _is_pil_image(img: Any) -> bool: 87 | return isinstance(img, Image.Image) 88 | def _setup_size(size, error_msg): 89 | if isinstance(size, numbers.Number): 90 | return int(size), int(size) 91 | 92 | if isinstance(size, Sequence) and len(size) == 1: 93 | return size[0], size[0] 94 | 95 | if len(size) != 2: 96 | raise ValueError(error_msg) 97 | 98 | return size 99 | def _is_tensor_a_torch_image(x: Tensor) -> bool: 100 | return x.ndim >= 2 101 | def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: 102 | ksize_half = (kernel_size - 1) * 0.5 103 | 104 | x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) 105 | pdf = torch.exp(-0.5 * (x / sigma).pow(2)) 106 | kernel1d = pdf / pdf.sum() 107 | 108 | return kernel1d 109 | 110 | def _cast_squeeze_in(img: Tensor, req_dtype: torch.dtype) -> Tuple[Tensor, bool, bool, torch.dtype]: 111 | need_squeeze = False 112 | # make image NCHW 113 | if img.ndim < 4: 114 | img = img.unsqueeze(dim=0) 115 | need_squeeze = True 116 | 117 | out_dtype = img.dtype 118 | need_cast = False 119 | if out_dtype != req_dtype: 120 | need_cast = True 121 | img = img.to(req_dtype) 122 | return img, need_cast, need_squeeze, out_dtype 123 | def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype): 124 | if need_squeeze: 125 | img = img.squeeze(dim=0) 126 | 127 | if need_cast: 128 | # it is better to round before cast 129 | img = torch.round(img).to(out_dtype) 130 | 131 | return img 132 | def _get_gaussian_kernel2d( 133 | kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device 134 | ) -> Tensor: 135 | kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype) 136 | kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype) 137 | kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :]) 138 | return kernel2d 139 | def _gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor: 140 | """PRIVATE METHOD. Performs Gaussian blurring on the img by given kernel. 141 | 142 | .. warning:: 143 | 144 | Module ``transforms.functional_tensor`` is private and should not be used in user application. 145 | Please, consider instead using methods from `transforms.functional` module. 146 | 147 | Args: 148 | img (Tensor): Image to be blurred 149 | kernel_size (sequence of int or int): Kernel size of the Gaussian kernel ``(kx, ky)``. 150 | sigma (sequence of float or float, optional): Standard deviation of the Gaussian kernel ``(sx, sy)``. 151 | 152 | Returns: 153 | Tensor: An image that is blurred using gaussian kernel of given parameters 154 | """ 155 | if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): # 'and', not 'or': with 'or', a PIL input would hit .ndim and crash before the intended TypeError 156 | raise TypeError('img should be Tensor Image.
Got {}'.format(type(img))) 157 | 158 | dtype = img.dtype if torch.is_floating_point(img) else torch.float32 159 | kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device) 160 | kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) 161 | 162 | img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, kernel.dtype) 163 | 164 | # padding = (left, right, top, bottom) 165 | padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] 166 | img = torch_pad(img, padding, mode="reflect") 167 | img = conv2d(img, kernel, groups=img.shape[-3]) 168 | 169 | img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) 170 | return img 171 | 172 | def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor: 173 | """Performs Gaussian blurring on the img by given kernel. 174 | The image can be a PIL Image or a Tensor, in which case it is expected 175 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions 176 | 177 | Args: 178 | img (PIL Image or Tensor): Image to be blurred 179 | kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers 180 | like ``(kx, ky)`` or a single integer for square kernels. 181 | In torchscript mode kernel_size as single int is not supported, use a tuple or 182 | list of length 1: ``[ksize, ]``. 183 | sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a 184 | sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the 185 | same sigma in both X/Y directions. If None, then it is computed using 186 | ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``. 187 | Default, None. In torchscript mode sigma as single float is 188 | not supported, use a tuple or list of length 1: ``[sigma, ]``. 189 | 190 | Returns: 191 | PIL Image or Tensor: Gaussian Blurred version of the image. 192 | """ 193 | if not isinstance(kernel_size, (int, list, tuple)): 194 | raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size))) 195 | if isinstance(kernel_size, int): 196 | kernel_size = [kernel_size, kernel_size] 197 | if len(kernel_size) != 2: 198 | raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size))) 199 | for ksize in kernel_size: 200 | if ksize % 2 == 0 or ksize < 0: 201 | raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size)) 202 | 203 | if sigma is None: 204 | sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] 205 | 206 | if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): 207 | raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma))) 208 | if isinstance(sigma, (int, float)): 209 | sigma = [float(sigma), float(sigma)] 210 | if isinstance(sigma, (list, tuple)) and len(sigma) == 1: 211 | sigma = [sigma[0], sigma[0]] 212 | if len(sigma) != 2: 213 | raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma))) 214 | for s in sigma: 215 | if s <= 0.: 216 | raise ValueError('sigma should have positive values. Got {}'.format(sigma)) 217 | 218 | t_img = img 219 | if not isinstance(img, torch.Tensor): 220 | if not _is_pil_image(img): 221 | raise TypeError('img should be PIL Image or Tensor.
Got {}'.format(type(img))) 222 | 223 | t_img = to_tensor(img) 224 | 225 | output = _gaussian_blur(t_img, kernel_size, sigma) 226 | 227 | if not isinstance(img, torch.Tensor): 228 | output = to_pil_image(output) 229 | return output 230 | 231 | 232 | 233 | 234 | # if __name__ == "__main__": 235 | # gaussian_blur = GaussianBlur(kernel_size=23) -------------------------------------------------------------------------------- /classifier.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Classifier(nn.Module): 4 | def __init__(self, inputs, class_num): 5 | super(Classifier, self).__init__() 6 | self.classifier_layer = nn.Linear(inputs, class_num) 7 | self.classifier_layer.weight.data.normal_(0, 0.01) 8 | self.classifier_layer.bias.data.fill_(0.0) 9 | 10 | def forward(self, x): 11 | return self.classifier_layer(x) -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/configs/__init__.py -------------------------------------------------------------------------------- /configs/byol_aircrafts: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/configs/byol_aircrafts -------------------------------------------------------------------------------- /configs/byol_aircrafts.yaml: -------------------------------------------------------------------------------- 1 | name: byol-aircrafts-resnet50-experiment 2 | dataset: 3 | name: aircrafts 4 | image_size: 224 5 | num_workers: 8 6 | 7 | model: 8 | name: byol 9 | backbone: resnet50_aircrafts 10 | 11 | train: 12 | optimizer: 13 | name: lars_simclr 14 | weight_decay: 1.5e-6 15 | momentum: 0.9 16 | warmup_epochs: 10 17 | warmup_lr: 0 18 | base_lr: 0.3 19 | final_lr: 0 20 | num_epochs: 100 # this parameter influences the lr decay 21 | stop_at_epoch: 100 # must not exceed num_epochs 22 | batch_size: 128 23 | knn_monitor: False # knn monitor will take more time 24 | knn_interval: 3 25 | knn_k: 200 26 | eval: # linear evaluation; setting this to False turns off automatic evaluation after training 27 | optimizer: 28 | name: sgd 29 | weight_decay: 0 30 | momentum: 0.9 31 | warmup_lr: 0 32 | warmup_epochs: 0 33 | base_lr: 30 # scaled by the eval scripts as base_lr * batch_size / 256 (here 30 * 128 / 256 = 15) 34 | final_lr: 0 35 | batch_size: 128 36 | num_epochs: 100 37 | 38 | logger: 39 | tensorboard: True 40 | matplotlib: True 41 | 42 | seed: null # None type for yaml file 43 | # two things might lead to stochastic behavior other than seed: 44 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 45 | # (keep this in mind if you want to achieve 100% determinism) 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /configs/byol_aircrafts_eval.yaml: -------------------------------------------------------------------------------- 1 | name: byol-aircrafts-resnet50-experiment 2 | dataset: 3 | name: aircrafts 4 | image_size: 224 5 | num_workers: 4 6 | 7 | model: 8 | name: byol 9 | backbone: resnet50_aircrafts 10 | 11 | train: null 12 | eval: # linear evaluation; setting this to False turns off automatic evaluation after training 13 | optimizer: 14 | name: sgd 15 | weight_decay: 0 16 | momentum: 0.9 17 | warmup_lr: 0 18 | warmup_epochs: 0 19 | base_lr: 30 # likewise scaled by batch_size / 256 (30 * 32 / 256 = 3.75 here) 20 | final_lr: 0 21 | batch_size: 32 22 |
num_epochs: 30 23 | 24 | logger: 25 | tensorboard: True 26 | matplotlib: True 27 | 28 | seed: null # None type for yaml file 29 | # two things might lead to stochastic behavior other than seed: 30 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 31 | # (keep this in mind if you want to achieve 100% determinism) 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /configs/byol_cifar.yaml: -------------------------------------------------------------------------------- 1 | name: byol-cifar10-experiment 2 | dataset: 3 | name: cifar10 4 | image_size: 32 5 | num_workers: 4 6 | 7 | model: 8 | name: byol 9 | backbone: resnet18_cifar_variant1 10 | 11 | train: 12 | optimizer: 13 | name: lars_simclr 14 | weight_decay: 1.5e-6 15 | warmup_epochs: 10 16 | warmup_lr: 0 17 | base_lr: 0.3 18 | final_lr: 0 19 | num_epochs: 800 # this parameter influences the lr decay 20 | stop_at_epoch: 100 # must not exceed num_epochs 21 | batch_size: 256 22 | knn_monitor: False # knn monitor will take more time 23 | knn_interval: 1 24 | knn_k: 200 25 | eval: # linear evaluation; setting this to False turns off automatic evaluation after training 26 | optimizer: 27 | name: sgd 28 | weight_decay: 0 29 | momentum: 0.9 30 | warmup_lr: 0 31 | warmup_epochs: 0 32 | base_lr: 30 33 | final_lr: 0 34 | batch_size: 256 35 | num_epochs: 30 36 | 37 | logger: 38 | tensorboard: True 39 | matplotlib: True 40 | 41 | seed: null # None type for yaml file 42 | # two things might lead to stochastic behavior other than seed: 43 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 44 | # (keep this in mind if you want to achieve 100% determinism) 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /configs/byol_cub200.yaml: -------------------------------------------------------------------------------- 1 | name: byol-cub200-experiment-resnet50_cub200_variant1 2 | dataset: 3 | name: cub200 4 | image_size: 224 5 | num_workers: 8 6 | 7 | model: 8 | name: byol 9 | backbone: resnet50_cub200 10 | 11 | train: 12 | optimizer: 13 | name: lars_simclr 14 | weight_decay: 1.5e-6 15 | momentum: 0.9 16 | warmup_epochs: 10 17 | warmup_lr: 0 18 | base_lr: 0.3 19 | final_lr: 0 20 | num_epochs: 800 # this parameter influences the lr decay 21 | stop_at_epoch: 100 # must not exceed num_epochs 22 | batch_size: 128 23 | knn_monitor: True # knn monitor will take more time 24 | knn_interval: 1 25 | knn_k: 200 26 | eval: # linear evaluation; setting this to False turns off automatic evaluation after training 27 | optimizer: 28 | name: sgd 29 | weight_decay: 0 30 | momentum: 0.9 31 | warmup_lr: 0 32 | warmup_epochs: 0 33 | base_lr: 30 34 | final_lr: 0 35 | batch_size: 128 36 | num_epochs: 100 37 | 38 | logger: 39 | tensorboard: True 40 | matplotlib: True 41 | 42 | seed: null # None type for yaml file 43 | # two things might lead to stochastic behavior other than seed: 44 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 45 | # (keep this in mind if you want to achieve 100% determinism) 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /configs/byol_stanfordcars.yaml: -------------------------------------------------------------------------------- 1 | name: byol-stanfordcars-resnet50-experiment 2 | dataset: 3 | name: stanfordcars 4 | image_size: 224 5 | num_workers: 4 6 | 7 | model: 8 | name: byol 9 | backbone: resnet50_stanfordcars 10 | 11 | train: 12 |
optimizer: 13 | name: lars_simclr 14 | weight_decay: 1.5e-6 15 | momentum: 0.9 16 | warmup_epochs: 10 17 | warmup_lr: 0 18 | base_lr: 0.3 19 | final_lr: 0 20 | num_epochs: 100 # this parameter influences the lr decay 21 | stop_at_epoch: 100 # must not exceed num_epochs 22 | batch_size: 128 23 | knn_monitor: False # knn monitor will take more time 24 | knn_interval: 3 25 | knn_k: 200 26 | eval: # linear evaluation; setting this to False turns off automatic evaluation after training 27 | optimizer: 28 | name: sgd 29 | weight_decay: 0 30 | momentum: 0.9 31 | warmup_lr: 0 32 | warmup_epochs: 0 33 | base_lr: 30 34 | final_lr: 0 35 | batch_size: 128 36 | num_epochs: 100 37 | 38 | logger: 39 | tensorboard: True 40 | matplotlib: True 41 | 42 | seed: null # None type for yaml file 43 | # two things might lead to stochastic behavior other than seed: 44 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 45 | # (keep this in mind if you want to achieve 100% determinism) 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /datasets/CUB200.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # Reading data 3 | # import scipy.misc  # unused; images are read with cv2 below 4 | # from matplotlib.pyplot import imread  # unused 5 | import cv2 6 | import os 7 | from PIL import Image 8 | from torchvision import transforms 9 | import torch 10 | 11 | class CUB(): 12 | def __init__(self, root, is_train=True, data_len=None,transform=None, target_transform=None): 13 | self.root = root 14 | self.is_train = is_train 15 | self.transform = transform 16 | self.target_transform = target_transform 17 | img_txt_file = open(os.path.join(self.root, 'images.txt')) 18 | label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt')) 19 | train_val_file = open(os.path.join(self.root, 'train_test_split.txt')) 20 | train_class_file = open(os.path.join(self.root, 'classes.txt')) 21 | # Image name index 22 | img_name_list = [] 23 | for line in img_txt_file: 24 | # The last character is a newline character 25 | img_name_list.append(line[:-1].split(' ')[-1]) 26 | 27 | # Label index: subtract 1 from each label so class ids start from 0 28 | label_list = [] 29 | for line in label_txt_file: 30 | label_list.append(int(line[:-1].split(' ')[-1]) - 1) 31 | train_class_list=[] 32 | for line in train_class_file: 33 | train_class_list.append(line[:-1].split('.')[-1] ) 34 | 35 | # Set up training and test sets 36 | train_test_list = [] 37 | for line in train_val_file: 38 | train_test_list.append(int(line[:-1].split(' ')[-1])) 39 | 40 | # zip() pairs the train/test flags with the image names and labels. 41 | # It takes iterables as arguments and packs their corresponding elements into tuples, 42 | # returning a lazy iterator over those tuples, which saves memory.
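# For example, given hypothetical values:
#   list(zip([1, 0, 1], ['a.jpg', 'b.jpg', 'c.jpg']))
#   -> [(1, 'a.jpg'), (0, 'b.jpg'), (1, 'c.jpg')]
# from which the comprehensions below keep 'a.jpg' and 'c.jpg' as training files.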
43 | # Use list() to materialize the iterator when a full list is needed. 44 | 45 | # A flag of 1 marks a training image, 0 a test image. 46 | # Keep an entry when its flag i is truthy (training set) and 47 | # drop it otherwise (test set); the same filtering is applied 48 | # to the label lists below. 49 | train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i] 50 | test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i] 51 | 52 | train_label_list = [x for i, x in zip(train_test_list, label_list) if i][:data_len] 53 | test_label_list = [x for i, x in zip(train_test_list, label_list) if not i][:data_len] 54 | if self.is_train: 55 | # cv2.imread returns the image as a numpy array (in BGR channel order) 56 | self.train_img = [self.read_pic(os.path.join(self.root, 'images', train_file)) for train_file in 57 | train_file_list[:data_len]] 58 | # Read the training set labels 59 | self.train_label = train_label_list 60 | self.targets=train_label_list 61 | self.classes=train_class_list 62 | if not self.is_train: 63 | self.test_img = [self.read_pic(os.path.join(self.root, 'images', test_file)) for test_file in 64 | test_file_list[:data_len]] 65 | self.test_label = test_label_list 66 | self.targets = test_label_list 67 | self.classes = train_class_list 68 | 69 | # Fetch one (image, label) pair, applying the transforms 70 | def __getitem__(self,index): 71 | # Training set 72 | if self.is_train: 73 | img, target = self.train_img[index], self.train_label[index] 74 | # Test set 75 | else: 76 | img, target = self.test_img[index], self.test_label[index] 77 | 78 | if len(img.shape) == 2: 79 | # Grayscale images are converted to three channels 80 | img = np.stack([img]*3,2) 81 | # Convert the array to a PIL image 82 | img = Image.fromarray(img,mode='RGB') 83 | if self.transform is not None: 84 | img = self.transform(img) 85 | 86 | if self.target_transform is not None: 87 | target = self.target_transform(target) 88 | 89 | return img, target 90 | 91 | def __len__(self): 92 | if self.is_train: 93 | return len(self.train_label) 94 | else: 95 | return len(self.test_label) 96 | 97 | def read_pic(self, train_file ): 98 | try: 99 | img= cv2.imread(train_file) 100 | return img 101 | except Exception as e: 102 | print(train_file) 103 | print(e) 104 | 105 | if __name__ == '__main__': 106 | ''' dataset = CUB(root='./CUB_200_2011') for data in dataset: print(data[0].size(),data[1]) ''' 107 | # Read the dataset with a PyTorch DataLoader 108 | transform_train = transforms.Compose([ 109 | transforms.Resize((224, 224)), 110 | transforms.RandomCrop(224, padding=4), 111 | transforms.RandomHorizontalFlip(), 112 | transforms.ToTensor(), 113 | transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]), 114 | ]) 115 | 116 | dataset = CUB(root='./CUB_200_2011', is_train=False, transform=transform_train,) 117 | print(len(dataset)) 118 | trainloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0, 119 | drop_last=True) 120 | print(len(trainloader)) -------------------------------------------------------------------------------- /datasets/CUB200_val.py: -------------------------------------------------------------------------------- 1 | # Reading data 2 | import os 3 | import pathlib 4 | import shutil 5 | 6 | class CUB(): 7 | def __init__(self, root, is_train=True): 8 | self.root = root 9 | self.is_train = is_train 10 | 11 | def convert_bird(data_root): 12 | images_txt = os.path.join(data_root, 'images.txt') 13 | train_val_txt = os.path.join(data_root, 'train_test_split.txt') 14 | labels_txt =
os.path.join(data_root, 'image_class_labels.txt') 15 | image_folder=os.path.join(data_root, 'images') 16 | id_name_dict = {} 17 | id_class_dict = {} 18 | id_train_val = {} 19 | with open(images_txt, 'r', encoding='utf-8') as f: 20 | line = f.readline() 21 | while line: 22 | id, name = line.strip().split() 23 | id_name_dict[id] = name 24 | line = f.readline() 25 | 26 | with open(train_val_txt, 'r', encoding='utf-8') as f: 27 | line = f.readline() 28 | while line: 29 | id, trainval = line.strip().split() 30 | id_train_val[id] = trainval 31 | line = f.readline() 32 | 33 | with open(labels_txt, 'r', encoding='utf-8') as f: 34 | line = f.readline() 35 | while line: 36 | id, class_id = line.strip().split() 37 | id_class_dict[id] = int(class_id) 38 | line = f.readline() 39 | 40 | train_txt = os.path.join(data_root, 'bird_train.txt') 41 | test_txt = os.path.join(data_root, 'bird_test.txt') 42 | 43 | train_folder=os.path.join(data_root, 'train') 44 | test_folder = os.path.join(data_root, 'test') 45 | if os.path.exists(train_txt): 46 | os.remove(train_txt) 47 | if os.path.exists(test_txt): 48 | os.remove(test_txt) 49 | if os.path.exists(train_folder): 50 | shutil.rmtree(train_folder) # os.remove cannot delete a directory 51 | if os.path.exists(test_folder): 52 | shutil.rmtree(test_folder) 53 | 54 | f1 = open(train_txt, 'a', encoding='utf-8') 55 | f2 = open(test_txt, 'a', encoding='utf-8') 56 | 57 | for id, trainval in id_train_val.items(): 58 | if trainval == '1': 59 | src_path=os.path.join(image_folder, id_name_dict[id]) 60 | dst_path=os.path.join(train_folder,str(id_class_dict[id] - 1)) 61 | pathlib.Path(dst_path).mkdir(parents=True, exist_ok=True) 62 | shutil.copy(src_path,os.path.join(dst_path,os.path.basename(src_path))) 63 | f1.write('%s %d\n' % (id_name_dict[id], id_class_dict[id] - 1)) 64 | else: 65 | 66 | src_path=os.path.join(image_folder, id_name_dict[id]) 67 | dst_path=os.path.join(test_folder,str(id_class_dict[id] - 1)) 68 | pathlib.Path(dst_path).mkdir(parents=True, exist_ok=True) 69 | shutil.copy(src_path,os.path.join(dst_path,os.path.basename(src_path))) 70 | f2.write('%s %d\n' % (id_name_dict[id], id_class_dict[id] - 1)) 71 | f1.close() 72 | f2.close() 73 | 74 | convert_bird(self.root) -------------------------------------------------------------------------------- /datasets/CUB2011.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from torchvision.datasets.folder import default_loader 4 | from torchvision.datasets.utils import download_url 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class Cub2011(Dataset): 9 | base_folder = 'CUB_200_2011/images' 10 | url = 'https://data.caltech.edu/tindfiles/serve/1239ea37-e132-42ee-8c09-c383bb54e7ff/' #''http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' 11 | 12 | 13 | filename = 'CUB_200_2011.tgz' 14 | tgz_md5 = '97eceeb196236b17998738112f37df78' 15 | 16 | def __init__(self, root, train=True, transform=None, loader=default_loader, download=True): 17 | self.root = os.path.expanduser(root) 18 | self.transform = transform 19 | self.loader = loader # respect the loader argument instead of always using default_loader 20 | self.train = train 21 | self.classes=None 22 | self.targets = None 23 | if download: 24 | self._download() 25 | 26 | if not self._check_integrity(): 27 | raise RuntimeError('Dataset not found or corrupted.'
+ ' You can use download=True to download it') 28 | 29 | 30 | def _load_metadata(self): 31 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', 32 | names=['img_id', 'filepath']) 33 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), 34 | sep=' ', names=['img_id', 'target']) 35 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), 36 | sep=' ', names=['img_id', 'is_training_img']) 37 | image_classes=pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'), 38 | sep=' ', names=['cls_id', 'cls_name']) 39 | 40 | self.classes = image_classes['cls_name'].tolist()#image_classes.merge(classes, on='cls_id') 41 | data = images.merge(image_class_labels, on='img_id') 42 | self.data = data.merge(train_test_split, on='img_id') 43 | 44 | if self.train: 45 | self.data = self.data[self.data.is_training_img == 1] 46 | self.targets = self.data.target.tolist() 47 | self.targets=[x-1 for x in self.targets ] 48 | else: 49 | self.data = self.data[self.data.is_training_img == 0] 50 | self.targets = self.data.target.tolist() 51 | self.targets = [x - 1 for x in self.targets] 52 | 53 | def _check_integrity(self): 54 | try: 55 | self._load_metadata() 56 | except Exception: 57 | return False 58 | 59 | for index, row in self.data.iterrows(): 60 | filepath = os.path.join(self.root, self.base_folder, row.filepath) 61 | if not os.path.isfile(filepath): 62 | print(filepath) 63 | return False 64 | return True 65 | 66 | def _download(self): 67 | import tarfile 68 | 69 | if self._check_integrity(): 70 | print('Files already downloaded and verified') 71 | return 72 | 73 | download_url(self.url, self.root, self.filename, self.tgz_md5) 74 | 75 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: 76 | tar.extractall(path=self.root) 77 | 78 | def __len__(self): 79 | return len(self.data) 80 | 81 | def __getitem__(self, idx): 82 | sample = self.data.iloc[idx] 83 | path = os.path.join(self.root, self.base_folder, sample.filepath) 84 | target = sample.target-1 # Targets start at 1 by default, so shift to 0 85 | img = self.loader(path) 86 | 87 | if self.transform is not None: 88 | img = self.transform(img) 89 | 90 | return img, target -------------------------------------------------------------------------------- /datasets/ImageNet100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from torchvision.datasets.folder import default_loader 4 | from torchvision.datasets.utils import download_url 5 | from torch.utils.data import Dataset 6 | import torch 7 | import torchvision 8 | #https://blog.csdn.net/Andrew_SJ/article/details/111335319 9 | class ImageNet100(): 10 | def __init__(self, data_dir, train=True, transform=None): 11 | self.root = os.path.expanduser(data_dir) 12 | self.transform = transform 13 | #self.loader = default_loader 14 | self.train = train 15 | # self.classes=None 16 | # self.targets = None 17 | l=[] 18 | l.append(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train.X1'),transform=transform)) 19 | l.append(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train.X2'), transform=transform)) 20 | l.append(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train.X3'), transform=transform)) 21 | l.append(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train.X4'), transform=transform)) 22 | train_datasets=torch.utils.data.ConcatDataset(l) 23 | train_size =
int(0.8 * len(train_datasets)) 24 | test_size = len(train_datasets) - train_size 25 | train_dataset, test_dataset = torch.utils.data.random_split(train_datasets, [train_size, test_size]) 26 | self.train_dataset = train_dataset # keep the Subset objects: taking .dataset would unwrap back to the full ConcatDataset and make the train and test splits identical 27 | self.test_dataset = test_dataset 28 | 29 | #base_folder = 'CUB_200_2011/images' 30 | #url = 'https://data.caltech.edu/tindfiles/serve/1239ea37-e132-42ee-8c09-c383bb54e7ff/' #''http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' 31 | def getdata(self): 32 | if self.train: 33 | return self.train_dataset 34 | else: return self.test_dataset 35 | 36 | #filename = 'CUB_200_2011.tgz' 37 | #tgz_md5 = '97eceeb196236b17998738112f37df78' 38 | 39 | 40 | class train_dataset_transformed(Dataset): 41 | def __init__(self, subset, transform=None): 42 | self.subset = subset 43 | self.transform = transform 44 | 45 | def __getitem__(self, index): 46 | x, y = self.subset[index] 47 | if self.transform: 48 | x = self.transform(x) 49 | return x, y 50 | 51 | def __len__(self): 52 | return len(self.subset) 53 | 54 | #train_dataset = train_dataset_transformed(train_dataset, transform=train_transform) 55 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import torch 4 | import torchvision 5 | from .random_dataset import RandomDataset 6 | from .CUB2011 import Cub2011 as CUB 7 | from .ImageNet100 import ImageNet100 8 | #from .StandfordCars import StandfordCars 9 | 10 | 11 | def get_dataset(dataset, data_dir, transform, train=True, download=False, debug_subset_size=None): 12 | if dataset == 'mnist': 13 | dataset = torchvision.datasets.MNIST(data_dir, train=train, transform=transform, download=download) 14 | elif dataset == 'stl10': 15 | dataset = torchvision.datasets.STL10(data_dir, split='train+unlabeled' if train else 'test', transform=transform, download=download) 16 | elif dataset == 'cifar10': 17 | dataset = torchvision.datasets.CIFAR10(data_dir, train=train, transform=transform, download=download) 18 | elif dataset == 'cifar100': 19 | dataset = torchvision.datasets.CIFAR100(data_dir, train=train, transform=transform, download=download) 20 | elif dataset == 'imagenet': 21 | dataset = torchvision.datasets.ImageNet(data_dir, split='train' if train == True else 'val', transform=transform, download=download) 22 | elif dataset == 'imagenet100': 23 | data_builder = ImageNet100(data_dir, train = train,transform=transform)# torchvision.datasets.ImageFolder(data_dir , transform=transform) #+ 'train' if train == True else 'val', 24 | dataset=data_builder.getdata() 25 | elif dataset =='cub200': 26 | # dataset = elif dataset == 'imagenet': 27 | # dataset= torchvision.datasets.ImageFolder(data_dir+ '/train' if train == True else data_dir+ '/test', transform=transform) 28 | # print('from datasets.imagefolders') 29 | #############dataset = CUB(data_dir, train = train,transform=transform,download=download) 30 | dataset = torchvision.datasets.ImageFolder( 31 | os.path.join(data_dir, 'train') if train == True else os.path.join(data_dir, 'test'), transform=transform) 32 | # print('') 33 | # trainloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0, 34 | # drop_last=True) 35 | elif dataset=='stanfordcars': 36 | dataset= torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train') if train == True else os.path.join(data_dir, 'test'),
transform=transform)#StandfordCars(data_dir,train=train,transform=transform) 37 | elif dataset == 'aircrafts': 38 | dataset = torchvision.datasets.ImageFolder( 39 | os.path.join(data_dir, 'train') if train == True else os.path.join(data_dir, 'test'), 40 | transform=transform) 41 | elif dataset == 'random': 42 | dataset = RandomDataset() 43 | else: 44 | raise NotImplementedError 45 | #debug_subset_size=50 46 | if debug_subset_size is not None: 47 | dataset = torch.utils.data.Subset(dataset, range(0, debug_subset_size)) # take only one batch 48 | dataset.classes = dataset.dataset.classes 49 | dataset.targets = dataset.dataset.targets 50 | return dataset -------------------------------------------------------------------------------- /datasets/__pycache__/CUB2011.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/datasets/__pycache__/CUB2011.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/ImageNet100.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/datasets/__pycache__/ImageNet100.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/random_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/datasets/__pycache__/random_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/random_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class RandomDataset(torch.utils.data.Dataset): 4 | def __init__(self, root=None, train=True, transform=None, target_transform=None): 5 | self.transform = transform 6 | self.target_transform = target_transform 7 | 8 | self.size = 1000 9 | def __getitem__(self, idx): 10 | if idx < self.size: 11 | return [torch.randn((3, 224, 224)), torch.randn((3, 224, 224))], [0,0,0] 12 | else: 13 | raise IndexError(idx) 14 | 15 | def __len__(self): 16 | return self.size 17 | -------------------------------------------------------------------------------- /examples/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/examples/framework.png -------------------------------------------------------------------------------- /hand_detector.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from cvzone.HandTrackingModule import HandDetector 3 | 4 | cap = cv2.VideoCapture(0) 5 | cap.set(3, 1280) 6 | cap.set(4, 720) 7 | detector = HandDetector(detectionCon=0.8) 8 | 9 | while True: 10 | success, img = cap.read() 11 | if not success: break # stop if the camera frame could not be read (replaces a stray no-op `pass`) 12 | hands, img = detector.findHands(img) 13 | cv2.imshow('Hands', img) 14 | cv2.waitKey(1) 15 |
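# NOTE: this is a standalone webcam demo; it assumes the cvzone package (whose
# HandDetector wraps the mediapipe hand tracker) and a camera at index 0.
# cvzone is not among the packages installed by the Dockerfile above.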
-------------------------------------------------------------------------------- /linear_cub_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from tqdm import tqdm 7 | from arguments import get_args 8 | from augmentations import get_aug 9 | from models import get_model, get_backbone 10 | from tools import AverageMeter 11 | from datasets import get_dataset 12 | from optimizers import get_optimizer, LR_Scheduler 13 | 14 | def main(args): 15 | 16 | train_loader = torch.utils.data.DataLoader( 17 | dataset=get_dataset( 18 | transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs), 19 | train=True, 20 | **args.dataset_kwargs 21 | ), 22 | batch_size=args.eval.batch_size, 23 | shuffle=True, 24 | **args.dataloader_kwargs 25 | ) 26 | test_loader = torch.utils.data.DataLoader( 27 | dataset=get_dataset( 28 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs), 29 | train=False, 30 | **args.dataset_kwargs 31 | ), 32 | batch_size=args.eval.batch_size, 33 | shuffle=False, 34 | **args.dataloader_kwargs 35 | ) 36 | 37 | 38 | #model = get_backbone(args.model.backbone) 39 | model = get_model(args.model) 40 | classifier = nn.Linear(in_features=2048, out_features=200, bias=True).to(args.device) 41 | 42 | assert args.eval_from is not None 43 | save_dict = torch.load(args.eval_from, map_location='cpu') 44 | #msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True) 45 | 46 | # print(msg) 47 | model = model.to(args.device) 48 | model = torch.nn.DataParallel(model) 49 | 50 | # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier) 51 | classifier = torch.nn.DataParallel(classifier) 52 | # define optimizer 53 | optimizer = get_optimizer( 54 | args.eval.optimizer.name, classifier, 55 | lr=args.eval.base_lr*args.eval.batch_size/256, 56 | momentum=args.eval.optimizer.momentum, 57 | weight_decay=args.eval.optimizer.weight_decay) 58 | 59 | # define lr scheduler 60 | lr_scheduler = LR_Scheduler( 61 | optimizer, 62 | args.eval.warmup_epochs, args.eval.warmup_lr*args.eval.batch_size/256, 63 | args.eval.num_epochs, args.eval.base_lr*args.eval.batch_size/256, args.eval.final_lr*args.eval.batch_size/256, 64 | len(train_loader), 65 | ) 66 | 67 | loss_meter = AverageMeter(name='Loss') 68 | acc_meter = AverageMeter(name='Accuracy') 69 | 70 | # Start training 71 | global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating') 72 | for epoch in global_progress: 73 | loss_meter.reset() 74 | model.eval() 75 | classifier.train() 76 | local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True) 77 | 78 | for idx, (images, labels) in enumerate(local_progress): 79 | 80 | classifier.zero_grad() 81 | with torch.no_grad(): 82 | feat, bp_out_feat = model.module.inference(images.to(args.device)) 83 | preds = classifier(bp_out_feat) 84 | 85 | loss = F.cross_entropy(preds, labels.to(args.device)) 86 | 87 | loss.backward() 88 | optimizer.step() 89 | loss_meter.update(loss.item()) 90 | lr = lr_scheduler.step() 91 | local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg}) 92 | 93 | classifier.eval() 94 | correct, total = 0, 0 95 | acc_meter.reset() 96 | for idx, (images, labels) in enumerate(test_loader): 97 | with torch.no_grad(): 98 | feat, bp_out_feat =
model.module.inference(images.to(args.device)) 99 | 100 | preds = classifier(bp_out_feat).argmax(dim=1) 101 | correct = (preds == labels.to(args.device)).sum().item() 102 | acc_meter.update(correct/preds.shape[0]) 103 | print(f'Accuracy = {acc_meter.avg*100:.2f}') 104 | 105 | 106 | 107 | 108 | if __name__ == "__main__": 109 | main(args=get_args()) 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /linear_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from tqdm import tqdm 7 | from arguments import get_args 8 | from augmentations import get_aug 9 | from models import get_model, get_backbone 10 | from tools import AverageMeter 11 | from datasets import get_dataset 12 | from optimizers import get_optimizer, LR_Scheduler 13 | 14 | def main(args): 15 | 16 | train_loader = torch.utils.data.DataLoader( 17 | dataset=get_dataset( 18 | transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs), 19 | train=True, 20 | **args.dataset_kwargs 21 | ), 22 | batch_size=args.eval.batch_size, 23 | shuffle=True, 24 | **args.dataloader_kwargs 25 | ) 26 | test_loader = torch.utils.data.DataLoader( 27 | dataset=get_dataset( 28 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs), 29 | train=False, 30 | **args.dataset_kwargs 31 | ), 32 | batch_size=args.eval.batch_size, 33 | shuffle=False, 34 | **args.dataloader_kwargs 35 | ) 36 | 37 | 38 | model = get_backbone(args.model.backbone) 39 | classifier = nn.Linear(in_features=model.output_dim, out_features=10, bias=True).to(args.device) 40 | 41 | assert args.eval_from is not None 42 | save_dict = torch.load(args.eval_from, map_location='cpu') 43 | msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True) 44 | 45 | # print(msg) 46 | model = model.to(args.device) 47 | model = torch.nn.DataParallel(model) 48 | 49 | # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier) 50 | classifier = torch.nn.DataParallel(classifier) 51 | # define optimizer 52 | optimizer = get_optimizer( 53 | args.eval.optimizer.name, classifier, 54 | lr=args.eval.base_lr*args.eval.batch_size/256, 55 | momentum=args.eval.optimizer.momentum, 56 | weight_decay=args.eval.optimizer.weight_decay) 57 | 58 | # define lr scheduler 59 | lr_scheduler = LR_Scheduler( 60 | optimizer, 61 | args.eval.warmup_epochs, args.eval.warmup_lr*args.eval.batch_size/256, 62 | args.eval.num_epochs, args.eval.base_lr*args.eval.batch_size/256, args.eval.final_lr*args.eval.batch_size/256, 63 | len(train_loader), 64 | ) 65 | 66 | loss_meter = AverageMeter(name='Loss') 67 | acc_meter = AverageMeter(name='Accuracy') 68 | 69 | # Start training 70 | global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating') 71 | for epoch in global_progress: 72 | loss_meter.reset() 73 | model.eval() 74 | classifier.train() 75 | local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True) 76 | 77 | for idx, (images, labels) in enumerate(local_progress): 78 | 79 | classifier.zero_grad() 80 | with torch.no_grad(): 81 | feature = model(images.to(args.device)) 82 | 83 | preds = classifier(feature) 84 | 85 | loss = F.cross_entropy(preds, labels.to(args.device)) 86 
| 87 | loss.backward() 88 | optimizer.step() 89 | loss_meter.update(loss.item()) 90 | lr = lr_scheduler.step() 91 | local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg}) 92 | 93 | classifier.eval() 94 | correct, total = 0, 0 95 | acc_meter.reset() 96 | for idx, (images, labels) in enumerate(test_loader): 97 | with torch.no_grad(): 98 | feature = model(images.to(args.device)) 99 | preds = classifier(feature).argmax(dim=1) 100 | correct = (preds == labels.to(args.device)).sum().item() 101 | acc_meter.update(correct/preds.shape[0]) 102 | print(f'Accuracy = {acc_meter.avg*100:.2f}') 103 | 104 | 105 | 106 | 107 | if __name__ == "__main__": 108 | main(args=get_args()) 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /linear_imagenet100_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from tqdm import tqdm 7 | from arguments import get_args 8 | from augmentations import get_aug 9 | from models import get_model, get_backbone 10 | from tools import AverageMeter 11 | from datasets import get_dataset 12 | from optimizers import get_optimizer, LR_Scheduler 13 | 14 | def main(args): 15 | 16 | train_loader = torch.utils.data.DataLoader( 17 | dataset=get_dataset( 18 | transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs), 19 | train=True, 20 | **args.dataset_kwargs 21 | ), 22 | batch_size=args.eval.batch_size, 23 | shuffle=True, 24 | **args.dataloader_kwargs 25 | ) 26 | test_loader = torch.utils.data.DataLoader( 27 | dataset=get_dataset( 28 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs), 29 | train=False, 30 | **args.dataset_kwargs 31 | ), 32 | batch_size=args.eval.batch_size, 33 | shuffle=False, 34 | **args.dataloader_kwargs 35 | ) 36 | 37 | 38 | model = get_backbone(args.model.backbone) 39 | classifier = nn.Linear(in_features=model.output_dim, out_features=100, bias=True).to(args.device) 40 | 41 | assert args.eval_from is not None 42 | save_dict = torch.load(args.eval_from, map_location='cpu') 43 | msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True) 44 | 45 | # print(msg) 46 | model = model.to(args.device) 47 | model = torch.nn.DataParallel(model) 48 | 49 | # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier) 50 | classifier = torch.nn.DataParallel(classifier) 51 | # define optimizer 52 | optimizer = get_optimizer( 53 | args.eval.optimizer.name, classifier, 54 | lr=args.eval.base_lr*args.eval.batch_size/256, 55 | momentum=args.eval.optimizer.momentum, 56 | weight_decay=args.eval.optimizer.weight_decay) 57 | 58 | # define lr scheduler 59 | lr_scheduler = LR_Scheduler( 60 | optimizer, 61 | args.eval.warmup_epochs, args.eval.warmup_lr*args.eval.batch_size/256, 62 | args.eval.num_epochs, args.eval.base_lr*args.eval.batch_size/256, args.eval.final_lr*args.eval.batch_size/256, 63 | len(train_loader), 64 | ) 65 | 66 | loss_meter = AverageMeter(name='Loss') 67 | acc_meter = AverageMeter(name='Accuracy') 68 | 69 | # Start training 70 | global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating') 71 | for epoch in global_progress: 72 | loss_meter.reset() 73 | model.eval() 74 | classifier.train() 75 | 
local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True) 76 | 77 | for idx, (images, labels) in enumerate(local_progress): 78 | 79 | classifier.zero_grad() 80 | with torch.no_grad(): 81 | feature = model(images.to(args.device)) 82 | 83 | preds = classifier(feature) 84 | 85 | loss = F.cross_entropy(preds, labels.to(args.device)) 86 | 87 | loss.backward() 88 | optimizer.step() 89 | loss_meter.update(loss.item()) 90 | lr = lr_scheduler.step() 91 | local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg}) 92 | 93 | classifier.eval() 94 | correct, total = 0, 0 95 | acc_meter.reset() 96 | for idx, (images, labels) in enumerate(test_loader): 97 | with torch.no_grad(): 98 | feature = model(images.to(args.device)) 99 | preds = classifier(feature).argmax(dim=1) 100 | correct = (preds == labels.to(args.device)).sum().item() 101 | acc_meter.update(correct/preds.shape[0]) 102 | print(f'Accuracy = {acc_meter.avg*100:.2f}') 103 | 104 | 105 | 106 | 107 | if __name__ == "__main__": 108 | main(args=get_args()) 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /linear_stanfordcars_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from tqdm import tqdm 7 | from arguments import get_args 8 | from augmentations import get_aug 9 | from models import get_model, get_backbone 10 | from tools import AverageMeter 11 | from datasets import get_dataset 12 | from optimizers import get_optimizer, LR_Scheduler 13 | 14 | def main(args,class_num): 15 | classes=class_num 16 | train_loader = torch.utils.data.DataLoader( 17 | dataset=get_dataset( 18 | transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs), 19 | train=True, 20 | **args.dataset_kwargs 21 | ), 22 | batch_size=args.eval.batch_size, 23 | shuffle=True, 24 | **args.dataloader_kwargs 25 | ) 26 | test_loader = torch.utils.data.DataLoader( 27 | dataset=get_dataset( 28 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs), 29 | train=False, 30 | **args.dataset_kwargs 31 | ), 32 | batch_size=args.eval.batch_size, 33 | shuffle=False, 34 | **args.dataloader_kwargs 35 | ) 36 | 37 | 38 | #model = get_backbone(args.model.backbone) 39 | model = get_model(args.model) 40 | classifier = nn.Linear(in_features=2048, out_features=classes, bias=True).to(args.device) 41 | 42 | assert args.eval_from is not None 43 | save_dict = torch.load(args.eval_from, map_location='cpu') 44 | #msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True) 45 | 46 | # print(msg) 47 | model = model.to(args.device) 48 | model = torch.nn.DataParallel(model) 49 | 50 | # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier) 51 | classifier = torch.nn.DataParallel(classifier) 52 | # define optimizer 53 | optimizer = get_optimizer( 54 | args.eval.optimizer.name, classifier, 55 | lr=args.eval.base_lr*args.eval.batch_size/256, 56 | momentum=args.eval.optimizer.momentum, 57 | weight_decay=args.eval.optimizer.weight_decay) 58 | 59 | # define lr scheduler 60 | lr_scheduler = LR_Scheduler( 61 | optimizer, 62 | args.eval.warmup_epochs, args.eval.warmup_lr*args.eval.batch_size/256, 63 | args.eval.num_epochs, 
args.eval.base_lr*args.eval.batch_size/256, args.eval.final_lr*args.eval.batch_size/256,
64 |         len(train_loader),
65 |     )
66 |
67 |     loss_meter = AverageMeter(name='Loss')
68 |     acc_meter = AverageMeter(name='Accuracy')
69 |
70 |     # Start training
71 |     global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
72 |     for epoch in global_progress:
73 |         loss_meter.reset()
74 |         model.eval()
75 |         classifier.train()
76 |         local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True)
77 |
78 |         for idx, (images, labels) in enumerate(local_progress):
79 |
80 |             classifier.zero_grad()
81 |             with torch.no_grad():
82 |                 feature, bp_out_feat = model.module.inference(images.to(args.device))
83 |
84 |             preds = classifier(bp_out_feat)
85 |
86 |             loss = F.cross_entropy(preds, labels.to(args.device))
87 |
88 |             loss.backward()
89 |             optimizer.step()
90 |             loss_meter.update(loss.item())
91 |             lr = lr_scheduler.step()
92 |             local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg})
93 |
94 |     classifier.eval()
95 |     correct, total = 0, 0
96 |     acc_meter.reset()
97 |     for idx, (images, labels) in enumerate(test_loader):
98 |         with torch.no_grad():
99 |             feature, bp_out_feat = model.module.inference(images.to(args.device))
100 |
101 |
102 |             preds = classifier(bp_out_feat).argmax(dim=1)
103 |             correct = (preds == labels.to(args.device)).sum().item()
104 |             acc_meter.update(correct/preds.shape[0])
105 |     print(f'Accuracy = {acc_meter.avg*100:.2f}')
106 |
107 |
108 |
109 |
110 | if __name__ == "__main__":
111 |     main(args=get_args(), class_num=196)  # class_num was missing here; Stanford Cars has 196 classes (cf. README)
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
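#
# Usage sketch (added note; the real call site is in main_moco.py below): the
# TwoCropsTransform class defined in this file wraps a torchvision pipeline so
# that every image yields a (query, key) pair for contrastive training, e.g.
#
#     aug = transforms.Compose([transforms.RandomResizedCrop(224),
#                               transforms.ToTensor()])
#     train_dataset = datasets.ImageFolder(traindir, TwoCropsTransform(aug))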
6 | 7 | from PIL import ImageFilter 8 | import random 9 | 10 | 11 | class TwoCropsTransform: 12 | """Take two random crops of one image as the query and key.""" 13 | 14 | def __init__(self, base_transform): 15 | self.base_transform = base_transform 16 | 17 | def __call__(self, x): 18 | q = self.base_transform(x) 19 | k = self.base_transform(x) 20 | return [q, k] 21 | 22 | 23 | class GaussianBlur(object): 24 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 25 | 26 | def __init__(self, sigma=[.1, 2.]): 27 | self.sigma = sigma 28 | 29 | def __call__(self, x): 30 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 31 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 32 | return x 33 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import numpy as np 7 | from tqdm import tqdm 8 | from arguments import get_args 9 | from augmentations import get_aug 10 | from models import get_model 11 | from tools import AverageMeter, knn_monitor, Logger, file_exist_check 12 | from datasets import get_dataset 13 | from optimizers import get_optimizer, LR_Scheduler 14 | #from linear_eval import main as linear_eval 15 | from linear_cub_eval import main as cub_linear_eval 16 | from linear_stanfordcars_eval import main as linear_stanfordcars_eval 17 | from linear_eval import main as linear_eval 18 | from datetime import datetime 19 | 20 | def main(device, args): 21 | 22 | train_loader = torch.utils.data.DataLoader( 23 | dataset=get_dataset( 24 | transform=get_aug(train=True, **args.aug_kwargs), 25 | train=True, 26 | **args.dataset_kwargs), 27 | shuffle=True, 28 | batch_size=args.train.batch_size, 29 | **args.dataloader_kwargs 30 | ) 31 | 32 | memory_loader = torch.utils.data.DataLoader( 33 | dataset=get_dataset( 34 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs), 35 | train=True, 36 | **args.dataset_kwargs), 37 | shuffle=False, 38 | batch_size=args.train.batch_size, 39 | **args.dataloader_kwargs 40 | ) 41 | #test_loader=memory_loader 42 | test_loader = torch.utils.data.DataLoader( 43 | dataset=get_dataset( 44 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs), 45 | train=False, 46 | **args.dataset_kwargs), 47 | shuffle=False, 48 | batch_size=args.train.batch_size, 49 | **args.dataloader_kwargs 50 | ) 51 | 52 | # define model 53 | model = get_model(args.model).to(device) 54 | model = torch.nn.DataParallel(model) 55 | #model = torch.nn.parallel.DistributedDataParallel(model) 56 | 57 | # define optimizer 58 | optimizer = get_optimizer( 59 | args.train.optimizer.name, model, 60 | lr=args.train.base_lr*args.train.batch_size/256, 61 | momentum=args.train.optimizer.momentum, 62 | weight_decay=args.train.optimizer.weight_decay) 63 | 64 | lr_scheduler = LR_Scheduler( 65 | optimizer, 66 | args.train.warmup_epochs, args.train.warmup_lr*args.train.batch_size/256, 67 | args.train.num_epochs, args.train.base_lr*args.train.batch_size/256, args.train.final_lr*args.train.batch_size/256, 68 | len(train_loader), 69 | constant_predictor_lr=True # see the end of section 4.2 predictor 70 | ) 71 | 72 | logger = Logger(tensorboard=args.logger.tensorboard, matplotlib=args.logger.matplotlib, log_dir=args.log_dir) 73 | accuracy = 0 74 | # Start training 75 | global_progress = tqdm(range(0, 
args.train.stop_at_epoch), desc=f'Training') 76 | for epoch in global_progress: 77 | model.train() 78 | 79 | local_progress=tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}', disable=args.hide_progress) 80 | for idx, ((images1, images2), labels) in enumerate(local_progress): 81 | 82 | model.zero_grad() 83 | data_dict = model.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True)) 84 | loss = data_dict['loss'].mean() # ddp 85 | 86 | 87 | t = 0.4 88 | grad_cam = data_dict['gradcam'].view(data_dict['gradcam'].size(0), -1) 89 | grad_cam = (grad_cam / t).float() 90 | att_max = data_dict['attmap'].view(data_dict['attmap'].size(0), -1) 91 | att_max = (att_max / t).float() 92 | loss_cam = F.kl_div(att_max.softmax(dim=-1).log(), grad_cam.softmax(dim=-1), 93 | reduction='sum') 94 | 95 | loss = loss + 0.01 * loss_cam 96 | 97 | loss.backward() 98 | optimizer.step() 99 | lr_scheduler.step() 100 | data_dict.update({'lr':lr_scheduler.get_lr()}) 101 | 102 | local_progress.set_postfix(data_dict) 103 | #logger.update_scalers(data_dict) 104 | 105 | if args.train.knn_monitor and epoch % args.train.knn_interval == 0: 106 | accuracy = knn_monitor(model.module.backbone, memory_loader, test_loader, device, k=min(args.train.knn_k, len(memory_loader.dataset)), hide_progress=args.hide_progress) 107 | 108 | epoch_dict = {"epoch":epoch, "knn_monitor: accuracy":accuracy} 109 | global_progress.set_postfix(epoch_dict) 110 | #logger.update_scalers(epoch_dict) 111 | 112 | # Save checkpoint 113 | model_path = os.path.join(args.ckpt_dir, f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth") # datetime.now().strftime('%Y%m%d_%H%M%S') 114 | torch.save({ 115 | 'epoch': epoch+1, 116 | 'state_dict':model.module.state_dict() 117 | }, model_path) 118 | print(f"Model saved to {model_path}") 119 | with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f: 120 | f.write(f'{model_path}') 121 | 122 | 123 | if args.eval is not False: 124 | args.eval_from = model_path 125 | classes=len(test_loader.dataset.classes) 126 | dataset_name=args.dataset.name 127 | if dataset_name=='cub200': 128 | cub_linear_eval(args) 129 | elif dataset_name=='stanfordcars': 130 | linear_stanfordcars_eval(args,classes) 131 | elif dataset_name=='aircrafts': 132 | linear_stanfordcars_eval(args,classes) 133 | elif dataset_name=='cifar10': 134 | linear_eval(args) 135 | 136 | 137 | 138 | if __name__ == "__main__": 139 | args = get_args() 140 | 141 | main(device=args.device, args=args) 142 | 143 | completed_log_dir = args.log_dir.replace('in-progress', 'debug' if args.debug else 'completed') 144 | 145 | 146 | 147 | os.rename(args.log_dir, completed_log_dir) 148 | print(f'Log file has been saved to {completed_log_dir}') 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | -------------------------------------------------------------------------------- /main_lincls.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import builtins 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | import warnings 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.distributed as dist 15 | import torch.optim 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as 
datasets 21 | import torchvision.models as models 22 | 23 | import moco.loader 24 | import moco.builder 25 | from classifier import Classifier 26 | 27 | 28 | model_names = sorted(name for name in models.__dict__ 29 | if name.islower() and not name.startswith("__") 30 | and callable(models.__dict__[name])) 31 | 32 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 33 | parser.add_argument('data', metavar='DIR', 34 | help='path to dataset') 35 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 36 | choices=model_names, 37 | help='model architecture: ' + 38 | ' | '.join(model_names) + 39 | ' (default: resnet50)') 40 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 41 | help='number of data loading workers (default: 32)') 42 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 43 | help='number of total epochs to run') 44 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 45 | help='manual epoch number (useful on restarts)') 46 | parser.add_argument('-b', '--batch-size', default=256, type=int, 47 | metavar='N', 48 | help='mini-batch size (default: 256), this is the total ' 49 | 'batch size of all GPUs on the current node when ' 50 | 'using Data Parallel or Distributed Data Parallel') 51 | parser.add_argument('--lr', '--learning-rate', default=30., type=float, 52 | metavar='LR', help='initial learning rate', dest='lr') 53 | parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int, 54 | help='learning rate schedule (when to drop lr by a ratio)') 55 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 56 | help='momentum') 57 | parser.add_argument('--wd', '--weight-decay', default=0., type=float, 58 | metavar='W', help='weight decay (default: 0.)', 59 | dest='weight_decay') 60 | parser.add_argument('-p', '--print-freq', default=10, type=int, 61 | metavar='N', help='print frequency (default: 10)') 62 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 63 | help='path to latest checkpoint (default: none)') 64 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 65 | help='evaluate model on validation set') 66 | parser.add_argument('--world-size', default=-1, type=int, 67 | help='number of nodes for distributed training') 68 | parser.add_argument('--rank', default=-1, type=int, 69 | help='node rank for distributed training') 70 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 71 | help='url used to set up distributed training') 72 | parser.add_argument('--dist-backend', default='nccl', type=str, 73 | help='distributed backend') 74 | parser.add_argument('--seed', default=None, type=int, 75 | help='seed for initializing training. ') 76 | parser.add_argument('--gpu', default=None, type=int, 77 | help='GPU id to use.') 78 | parser.add_argument('--multiprocessing-distributed', action='store_true', 79 | help='Use multi-processing distributed training to launch ' 80 | 'N processes per node, which has N GPUs. 
This is the '
81 |                          'fastest way to use PyTorch for either single node or '
82 |                          'multi node data parallel training')
83 |
84 | parser.add_argument('--pretrained', default='', type=str,
85 |                     help='path to moco pretrained checkpoint')
86 |
87 | # moco specific configs:
88 | parser.add_argument('--moco-dim', default=256, type=int,
89 |                     help='feature dimension (default: 256)')
90 | parser.add_argument('--moco-k', default=65536, type=int,
91 |                     help='queue size; number of negative keys (default: 65536)')
92 | parser.add_argument('--moco-m', default=0.999, type=float,
93 |                     help='moco momentum of updating key encoder (default: 0.999)')
94 | parser.add_argument('--moco-t', default=0.07, type=float,
95 |                     help='softmax temperature (default: 0.07)')
96 | parser.add_argument('--mlp', action='store_true',
97 |                     help='use mlp head')
98 | parser.add_argument('--class_num', type=int, default=200)
99 |
100 | best_acc1 = 0
101 |
102 |
103 | def main():
104 |     args = parser.parse_args()
105 |
106 |     if args.seed is not None:
107 |         random.seed(args.seed)
108 |         torch.manual_seed(args.seed)
109 |         cudnn.deterministic = True
110 |         warnings.warn('You have chosen to seed training. '
111 |                       'This will turn on the CUDNN deterministic setting, '
112 |                       'which can slow down your training considerably! '
113 |                       'You may see unexpected behavior when restarting '
114 |                       'from checkpoints.')
115 |
116 |     if args.gpu is not None:
117 |         warnings.warn('You have chosen a specific GPU. This will completely '
118 |                       'disable data parallelism.')
119 |
120 |     if args.dist_url == "env://" and args.world_size == -1:
121 |         args.world_size = int(os.environ["WORLD_SIZE"])
122 |
123 |     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
124 |
125 |     ngpus_per_node = torch.cuda.device_count()
126 |     if args.multiprocessing_distributed:
127 |         # Since we have ngpus_per_node processes per node, the total world_size
128 |         # needs to be adjusted accordingly
129 |         args.world_size = ngpus_per_node * args.world_size
130 |         # Use torch.multiprocessing.spawn to launch distributed processes: the
131 |         # main_worker process function
132 |         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
133 |     else:
134 |         # Simply call main_worker function
135 |         main_worker(args.gpu, ngpus_per_node, args)
136 |
137 |
138 | def main_worker(gpu, ngpus_per_node, args):
139 |     global best_acc1
140 |     args.gpu = gpu
141 |
142 |     # suppress printing if not master
143 |     if args.multiprocessing_distributed and args.gpu != 0:
144 |         def print_pass(*args):
145 |             pass
146 |
147 |         builtins.print = print_pass
148 |
149 |     if args.gpu is not None:
150 |         print("Use GPU: {} for training".format(args.gpu))
151 |
152 |     if args.distributed:
153 |         if args.dist_url == "env://" and args.rank == -1:
154 |             args.rank = int(os.environ["RANK"])
155 |         if args.multiprocessing_distributed:
156 |             # For multiprocessing distributed training, rank needs to be the
157 |             # global rank among all the processes
158 |             args.rank = args.rank * ngpus_per_node + gpu
159 |         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
160 |                                 world_size=args.world_size, rank=args.rank)
161 |     # create model
162 |     print("=> creating model '{}'".format(args.arch))
163 |
164 |     model = moco.builder.MoCo(
165 |         models.__dict__[args.arch],
166 |         args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp)
167 |
168 |     classifier = Classifier(2048, args.class_num).cuda(args.gpu)
169 |
170 |     # freeze all layers but the last fc
171 |     for name, param in model.named_parameters():
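        # (added note) on this MoCo wrapper, 'fc.weight' and 'fc.bias' name the
        # wrapper's own projection layer (self.fc in moco/builder.py), so both
        # encoders and the attention convs stay frozen for the linear probe.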
172 | if name not in ['fc.weight', 'fc.bias']: 173 | param.requires_grad = False 174 | # init the fc layer 175 | model.fc.weight.data.normal_(mean=0.0, std=0.01) 176 | model.fc.bias.data.zero_() 177 | 178 | # load from pre-trained, before DistributedDataParallel constructor 179 | if args.pretrained: 180 | if os.path.isfile(args.pretrained): 181 | print("=> loading checkpoint '{}'".format(args.pretrained)) 182 | checkpoint = torch.load(args.pretrained, map_location="cpu") 183 | 184 | # rename moco pre-trained keys 185 | state_dict = checkpoint['state_dict'] 186 | 187 | for k in list(state_dict.keys()): 188 | state_dict[k[len("module."):]] = state_dict[k] 189 | 190 | args.start_epoch = 0 191 | msg = model.load_state_dict(state_dict, strict=False) 192 | 193 | print("=> loaded pre-trained model '{}'".format(args.pretrained)) 194 | else: 195 | print("=> no checkpoint found at '{}'".format(args.pretrained)) 196 | 197 | if args.distributed: 198 | # For multiprocessing distributed, DistributedDataParallel constructor 199 | # should always set the single device scope, otherwise, 200 | # DistributedDataParallel will use all available devices. 201 | if args.gpu is not None: 202 | torch.cuda.set_device(args.gpu) 203 | model.cuda(args.gpu) 204 | # When using a single GPU per process and per 205 | # DistributedDataParallel, we need to divide the batch size 206 | # ourselves based on the total number of GPUs we have 207 | args.batch_size = int(args.batch_size / ngpus_per_node) 208 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 209 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 210 | else: 211 | model.cuda() 212 | # DistributedDataParallel will divide and allocate batch_size to all 213 | # available GPUs if device_ids are not set 214 | model = torch.nn.parallel.DistributedDataParallel(model) 215 | elif args.gpu is not None: 216 | torch.cuda.set_device(args.gpu) 217 | model = model.cuda(args.gpu) 218 | else: 219 | # DataParallel will divide and allocate batch_size to all available GPUs 220 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 221 | model.features = torch.nn.DataParallel(model.features) 222 | model.cuda() 223 | else: 224 | model = torch.nn.DataParallel(model).cuda() 225 | 226 | # define loss function (criterion) and optimizer 227 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 228 | 229 | # optimize only the linear classifier 230 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 231 | assert len(parameters) == 2 # fc.weight, fc.bias 232 | 233 | optimizer = torch.optim.SGD(classifier.parameters(), args.lr, momentum=args.momentum, 234 | weight_decay=args.weight_decay) 235 | 236 | 237 | # optionally resume from a checkpoint 238 | if args.resume: 239 | if os.path.isfile(args.resume): 240 | print("=> loading checkpoint '{}'".format(args.resume)) 241 | if args.gpu is None: 242 | checkpoint = torch.load(args.resume) 243 | else: 244 | # Map model to be loaded to specified single gpu. 
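                # (added note) map_location remaps every storage in the
                # checkpoint onto this process's own GPU, so a checkpoint
                # written from cuda:0 can be resumed on any device.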
245 | loc = 'cuda:{}'.format(args.gpu) 246 | checkpoint = torch.load(args.resume, map_location=loc) 247 | args.start_epoch = checkpoint['epoch'] 248 | best_acc1 = checkpoint['best_acc1'] 249 | if args.gpu is not None: 250 | # best_acc1 may be from a checkpoint from a different GPU 251 | best_acc1 = best_acc1.to(args.gpu) 252 | model.load_state_dict(checkpoint['state_dict']) 253 | optimizer.load_state_dict(checkpoint['optimizer']) 254 | print("=> loaded checkpoint '{}' (epoch {})" 255 | .format(args.resume, checkpoint['epoch'])) 256 | else: 257 | print("=> no checkpoint found at '{}'".format(args.resume)) 258 | 259 | cudnn.benchmark = True 260 | 261 | # Data loading code 262 | traindir = os.path.join(args.data, 'train') 263 | valdir = os.path.join(args.data, 'test') 264 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 265 | std=[0.229, 0.224, 0.225]) 266 | 267 | train_dataset = datasets.ImageFolder( 268 | traindir, 269 | transforms.Compose([ 270 | transforms.RandomResizedCrop(224), 271 | transforms.RandomHorizontalFlip(), 272 | transforms.ToTensor(), 273 | normalize, 274 | ])) 275 | 276 | if args.distributed: 277 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 278 | else: 279 | train_sampler = None 280 | 281 | train_loader = torch.utils.data.DataLoader( 282 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 283 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 284 | 285 | val_loader = torch.utils.data.DataLoader( 286 | datasets.ImageFolder(valdir, transforms.Compose([ 287 | transforms.Resize(256), 288 | transforms.CenterCrop(224), 289 | transforms.ToTensor(), 290 | normalize, 291 | ])), 292 | batch_size=args.batch_size, shuffle=False, 293 | num_workers=args.workers, pin_memory=True) 294 | 295 | if args.evaluate: 296 | validate(val_loader, model, classifier, criterion, args) 297 | return 298 | 299 | for epoch in range(args.start_epoch, args.epochs): 300 | if args.distributed: 301 | train_sampler.set_epoch(epoch) 302 | adjust_learning_rate(optimizer, epoch, args) 303 | 304 | # train for one epoch 305 | train(train_loader, model, classifier, criterion, optimizer, epoch, args) 306 | 307 | # evaluate on validation set 308 | acc1 = validate(val_loader, model, classifier, criterion, args) 309 | 310 | # remember best acc@1 and save checkpoint 311 | is_best = acc1 > best_acc1 312 | best_acc1 = max(acc1, best_acc1) 313 | 314 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 315 | and args.rank % ngpus_per_node == 0): 316 | save_checkpoint({ 317 | 'epoch': epoch + 1, 318 | 'arch': args.arch, 319 | 'state_dict': model.state_dict(), 320 | 'best_acc1': best_acc1, 321 | 'optimizer': optimizer.state_dict(), 322 | }, is_best) 323 | # if epoch == args.start_epoch: 324 | # sanity_check(model.state_dict(), args.pretrained) 325 | 326 | 327 | def train(train_loader, model, classifier, criterion, optimizer, epoch, args): 328 | batch_time = AverageMeter('Time', ':6.3f') 329 | data_time = AverageMeter('Data', ':6.3f') 330 | losses = AverageMeter('Loss', ':.4e') 331 | top1 = AverageMeter('Acc@1', ':6.2f') 332 | top5 = AverageMeter('Acc@5', ':6.2f') 333 | progress = ProgressMeter( 334 | len(train_loader), 335 | [batch_time, data_time, losses, top1, top5], 336 | prefix="Epoch: [{}]".format(epoch)) 337 | 338 | """ 339 | Switch to eval mode: 340 | Under the protocol of linear classification on frozen features/models, 341 | it is not legitimate to change any part of the pre-trained model. 
342 | BatchNorm in train mode may revise running mean/std (even if it receives 343 | no gradient), which are part of the model parameters too. 344 | """ 345 | model.eval() 346 | classifier.train(True) 347 | optimizer.zero_grad() 348 | 349 | end = time.time() 350 | for i, (images, target) in enumerate(train_loader): 351 | # measure data loading time 352 | data_time.update(time.time() - end) 353 | 354 | if args.gpu is not None: 355 | images = images.cuda(args.gpu, non_blocking=True) 356 | target = target.cuda(args.gpu, non_blocking=True) 357 | 358 | # compute output 359 | _, _, _, bp_out_feat = model.module.inference(images) 360 | 361 | output = classifier(bp_out_feat) 362 | loss = criterion(output, target) 363 | 364 | # measure accuracy and record loss 365 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 366 | losses.update(loss.item(), images.size(0)) 367 | top1.update(acc1[0], images.size(0)) 368 | top5.update(acc5[0], images.size(0)) 369 | 370 | # compute gradient and do SGD step 371 | optimizer.zero_grad() 372 | loss.backward() 373 | optimizer.step() 374 | 375 | # measure elapsed time 376 | batch_time.update(time.time() - end) 377 | end = time.time() 378 | 379 | if i % args.print_freq == 0: 380 | progress.display(i) 381 | 382 | 383 | def validate(val_loader, model, classifier, criterion, args): 384 | batch_time = AverageMeter('Time', ':6.3f') 385 | losses = AverageMeter('Loss', ':.4e') 386 | top1 = AverageMeter('Acc@1', ':6.2f') 387 | top5 = AverageMeter('Acc@5', ':6.2f') 388 | progress = ProgressMeter( 389 | len(val_loader), 390 | [batch_time, losses, top1, top5], 391 | prefix='Test: ') 392 | 393 | # switch to evaluate mode 394 | model.eval() 395 | classifier.eval() 396 | 397 | with torch.no_grad(): 398 | end = time.time() 399 | for i, (images, target) in enumerate(val_loader): 400 | if args.gpu is not None: 401 | images = images.cuda(args.gpu, non_blocking=True) 402 | target = target.cuda(args.gpu, non_blocking=True) 403 | 404 | # compute output 405 | _, _, _, bp_out_feat = model.module.inference(images) 406 | output = classifier(bp_out_feat) 407 | loss = criterion(output, target) 408 | 409 | # measure accuracy and record loss 410 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 411 | losses.update(loss.item(), images.size(0)) 412 | top1.update(acc1[0], images.size(0)) 413 | top5.update(acc5[0], images.size(0)) 414 | 415 | # measure elapsed time 416 | batch_time.update(time.time() - end) 417 | end = time.time() 418 | 419 | if i % args.print_freq == 0: 420 | progress.display(i) 421 | 422 | # TODO: this should also be done with the ProgressMeter 423 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 424 | .format(top1=top1, top5=top5)) 425 | 426 | return top1.avg 427 | 428 | 429 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 430 | torch.save(state, filename) 431 | if is_best: 432 | shutil.copyfile(filename, 'model_best.pth.tar') 433 | 434 | 435 | def sanity_check(state_dict, pretrained_weights): 436 | """ 437 | Linear classifier should not change any weights other than the linear layer. 438 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 
439 | """ 440 | print("=> loading '{}' for sanity check".format(pretrained_weights)) 441 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 442 | state_dict_pre = checkpoint['state_dict'] 443 | 444 | for k in list(state_dict.keys()): 445 | # only ignore fc layer 446 | if 'fc.weight' in k or 'fc.bias' in k: 447 | continue 448 | 449 | # name in pretrained model 450 | k_pre = 'module.encoder_q.' + k[len('module.'):] \ 451 | if k.startswith('module.') else 'module.encoder_q.' + k 452 | 453 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \ 454 | '{} is changed in linear classifier training.'.format(k) 455 | 456 | print("=> sanity check passed.") 457 | 458 | 459 | class AverageMeter(object): 460 | """Computes and stores the average and current value""" 461 | 462 | def __init__(self, name, fmt=':f'): 463 | self.name = name 464 | self.fmt = fmt 465 | self.reset() 466 | 467 | def reset(self): 468 | self.val = 0 469 | self.avg = 0 470 | self.sum = 0 471 | self.count = 0 472 | 473 | def update(self, val, n=1): 474 | self.val = val 475 | self.sum += val * n 476 | self.count += n 477 | self.avg = self.sum / self.count 478 | 479 | def __str__(self): 480 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 481 | return fmtstr.format(**self.__dict__) 482 | 483 | 484 | class ProgressMeter(object): 485 | def __init__(self, num_batches, meters, prefix=""): 486 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 487 | self.meters = meters 488 | self.prefix = prefix 489 | 490 | def display(self, batch): 491 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 492 | entries += [str(meter) for meter in self.meters] 493 | print('\t'.join(entries)) 494 | 495 | def _get_batch_fmtstr(self, num_batches): 496 | num_digits = len(str(num_batches // 1)) 497 | fmt = '{:' + str(num_digits) + 'd}' 498 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 499 | 500 | 501 | def adjust_learning_rate(optimizer, epoch, args): 502 | """Decay the learning rate based on schedule""" 503 | lr = args.lr 504 | for milestone in args.schedule: 505 | lr *= 0.1 if epoch >= milestone else 1. 
506 |     for param_group in optimizer.param_groups:
507 |         param_group['lr'] = lr
508 |
509 |
510 | def accuracy(output, target, topk=(1,)):
511 |     """Computes the accuracy over the k top predictions for the specified values of k"""
512 |     with torch.no_grad():
513 |         maxk = max(topk)
514 |         batch_size = target.size(0)
515 |
516 |         _, pred = output.topk(maxk, 1, True, True)
517 |         pred = pred.t()
518 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
519 |
520 |         res = []
521 |         for k in topk:
522 |             correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
523 |             res.append(correct_k.mul_(100.0 / batch_size))
524 |         return res
525 |
526 |
527 | if __name__ == '__main__':
528 |     main()
529 |
--------------------------------------------------------------------------------
/main_moco.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import builtins
4 | import math
5 | import os
6 | import random
7 | import shutil
8 | import time
9 | import warnings
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.parallel
14 | import torch.backends.cudnn as cudnn
15 | import torch.distributed as dist
16 | import torch.optim
17 | import torch.multiprocessing as mp
18 | import torch.utils.data
19 | import torch.utils.data.distributed
20 | import torchvision.transforms as transforms
21 | import torchvision.datasets as datasets
22 | import torchvision.models as models
23 |
24 | import moco.loader
25 | import moco.builder
26 |
27 | import numpy as np
28 |
29 |
30 | import torch.nn.functional as F
31 |
32 | torch.multiprocessing.set_sharing_strategy('file_system')
33 |
34 | os.environ[
35 |     "TORCH_DISTRIBUTED_DEBUG"
36 | ] = "DETAIL"
37 |
38 | model_names = sorted(name for name in models.__dict__
39 |                      if name.islower() and not name.startswith("__")
40 |                      and callable(models.__dict__[name]))
41 |
42 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
43 | parser.add_argument('data', metavar='DIR',
44 |                     help='path to dataset')
45 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
46 |                     help='evaluate model on validation set')
47 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
48 |                     choices=model_names,
49 |                     help='model architecture: ' +
50 |                          ' | '.join(model_names) +
51 |                          ' (default: resnet50)')
52 | parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
53 |                     help='number of data loading workers (default: 0)')
54 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
55 |                     help='number of total epochs to run')
56 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
57 |                     help='manual epoch number (useful on restarts)')
58 | parser.add_argument('-b', '--batch-size', default=256, type=int,
59 |                     metavar='N',
60 |                     help='mini-batch size (default: 256), this is the total '
61 |                          'batch size of all GPUs on the current node when '
62 |                          'using Data Parallel or Distributed Data Parallel')
63 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
64 |                     metavar='LR', help='initial learning rate', dest='lr')
65 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
66 |                     help='learning rate schedule (when to drop lr by 10x)')
67 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
68 |                     help='momentum of SGD solver')
69 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
70 |                     metavar='W', help='weight decay (default: 1e-4)',
71 |                     dest='weight_decay')
72 | parser.add_argument('-p', '--print-freq', default=10, type=int,
73 |                     metavar='N', help='print frequency (default: 10)')
74 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
75 |                     help='path to latest checkpoint (default: none)')
76 | parser.add_argument('--world-size', default=-1, type=int,
77 |                     help='number of nodes for distributed training')
78 | parser.add_argument('--nu', default=0.01, type=float,
79 |                     help='weight of the attention-to-Grad-CAM KL consistency loss')
80 | parser.add_argument('--rank', default=-1, type=int,
81 |                     help='node rank for distributed training')
82 | parser.add_argument('--dist-url', default='tcp://localhost:10001', type=str,  # tcp://224.66.41.62:23456
83 |                     help='url used to set up distributed training')
84 | parser.add_argument('--dist-backend', default='nccl', type=str,
85 |                     help='distributed backend')
86 | parser.add_argument('--seed', default=None, type=int,
87 |                     help='seed for initializing training. ')
88 | parser.add_argument('--gpu', default=None, type=int,
89 |                     help='GPU id to use.')
90 | parser.add_argument('--multiprocessing-distributed', action='store_true',
91 |                     help='Use multi-processing distributed training to launch '
92 |                          'N processes per node, which has N GPUs. This is the '
93 |                          'fastest way to use PyTorch for either single node or '
94 |                          'multi node data parallel training')
95 |
96 | # moco specific configs:
97 | parser.add_argument('--moco-dim', default=256, type=int,
98 |                     help='feature dimension (default: 256)')
99 | parser.add_argument('--moco-k', default=65536, type=int,
100 |                     help='queue size; number of negative keys (default: 65536)')
101 | parser.add_argument('--moco-m', default=0.999, type=float,
102 |                     help='moco momentum of updating key encoder (default: 0.999)')
103 | parser.add_argument('--moco-t', default=0.07, type=float,
104 |                     help='softmax temperature (default: 0.07)')
105 |
106 | # options for moco v2
107 | parser.add_argument('--mlp', action='store_true',
108 |                     help='use mlp head')
109 | parser.add_argument('--aug-plus', action='store_true',
110 |                     help='use moco v2 data augmentation')
111 | parser.add_argument('--cos', action='store_true',
112 |                     help='use cosine lr schedule')
113 |
114 |
115 | best_top1 = 0
116 | best_top5 = 0
117 | best_mAP = 0
118 |
119 |
120 | def main():
121 |     args = parser.parse_args()
122 |
123 |     if args.seed is not None:
124 |         random.seed(args.seed)
125 |         torch.manual_seed(args.seed)
126 |         cudnn.deterministic = True
127 |         warnings.warn('You have chosen to seed training. '
128 |                       'This will turn on the CUDNN deterministic setting, '
129 |                       'which can slow down your training considerably! '
130 |                       'You may see unexpected behavior when restarting '
131 |                       'from checkpoints.')
132 |
133 |     if args.gpu is not None:
134 |         warnings.warn('You have chosen a specific GPU. 
This will completely ' 135 | 'disable data parallelism.') 136 | 137 | if args.dist_url == "env://" and args.world_size == -1: 138 | args.world_size = int(os.environ["WORLD_SIZE"]) 139 | 140 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 141 | 142 | ngpus_per_node = torch.cuda.device_count() 143 | 144 | if args.multiprocessing_distributed: 145 | # Since we have ngpus_per_node processes per node, the total world_size 146 | # needs to be adjusted accordingly 147 | args.world_size = ngpus_per_node * args.world_size 148 | # Use torch.multiprocessing.spawn to launch distributed processes: the 149 | # main_worker process function 150 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 151 | else: 152 | # Simply call main_worker function 153 | main_worker(args.gpu, ngpus_per_node, args) 154 | 155 | 156 | def main_worker(gpu, ngpus_per_node, args): 157 | global best_top1, best_top5, best_mAP 158 | args.gpu = gpu 159 | 160 | # suppress printing if not master 161 | if args.multiprocessing_distributed and args.gpu != 0: 162 | def print_pass(*args): 163 | pass 164 | 165 | builtins.print = print_pass 166 | 167 | if args.gpu is not None: 168 | print("Use GPU: {} for training".format(args.gpu)) 169 | 170 | if args.distributed: 171 | if args.dist_url == "env://" and args.rank == -1: 172 | args.rank = int(os.environ["RANK"]) 173 | if args.multiprocessing_distributed: 174 | # For multiprocessing distributed training, rank needs to be the 175 | # global rank among all the processes 176 | args.rank = args.rank * ngpus_per_node + gpu 177 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 178 | world_size=args.world_size, rank=args.rank) 179 | 180 | # create model 181 | print("=> creating model '{}'".format(args.arch)) 182 | 183 | model = moco.builder.MoCo( 184 | # models.__dict__[args.arch], 185 | args.arch, 186 | args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp) 187 | 188 | 189 | if args.distributed: 190 | # For multiprocessing distributed, DistributedDataParallel constructor 191 | # should always set the single device scope, otherwise, 192 | # DistributedDataParallel will use all available devices. 193 | if args.gpu is not None: 194 | torch.cuda.set_device(args.gpu) 195 | model.cuda(args.gpu) 196 | # When using a single GPU per process and per 197 | # DistributedDataParallel, we need to divide the batch size 198 | # ourselves based on the total number of GPUs we have 199 | args.batch_size = int(args.batch_size / ngpus_per_node) 200 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 201 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 202 | else: 203 | model.cuda() 204 | # DistributedDataParallel will divide and allocate batch_size to all 205 | # available GPUs if device_ids are not set 206 | model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) 207 | elif args.gpu is not None: 208 | torch.cuda.set_device(args.gpu) 209 | model = model.cuda(args.gpu) 210 | # comment out the following line for debugging 211 | raise NotImplementedError("Only DistributedDataParallel is supported.") 212 | else: 213 | # AllGather implementation (batch shuffle, queue update, etc.) in 214 | # this code only supports DistributedDataParallel. 
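        # (added note) the batch shuffle and queue update gather tensors across
        # processes with torch.distributed collectives, so a plain single-GPU
        # run has no process group to fall back on.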
215 | raise NotImplementedError("Only DistributedDataParallel is supported.") 216 | 217 | # define loss function (criterion) and optimizer 218 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 219 | 220 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 221 | momentum=args.momentum, 222 | weight_decay=args.weight_decay) 223 | 224 | # optionally resume from a checkpoint 225 | if args.resume: 226 | if os.path.isfile(args.resume): 227 | print("=> loading checkpoint '{}'".format(args.resume)) 228 | if args.gpu is None: 229 | checkpoint = torch.load(args.resume) 230 | else: 231 | # Map model to be loaded to specified single gpu. 232 | loc = 'cuda:{}'.format(args.gpu) 233 | checkpoint = torch.load(args.resume, map_location=loc) 234 | args.start_epoch = checkpoint['epoch'] 235 | model.load_state_dict(checkpoint['state_dict']) 236 | optimizer.load_state_dict(checkpoint['optimizer']) 237 | print("=> loaded checkpoint '{}' (epoch {})" 238 | .format(args.resume, checkpoint['epoch'])) 239 | else: 240 | print("=> no checkpoint found at '{}'".format(args.resume)) 241 | 242 | cudnn.benchmark = True 243 | 244 | # Data loading code 245 | traindir = os.path.join(args.data, 'train') 246 | valdir = os.path.join(args.data, 'test') 247 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 248 | std=[0.229, 0.224, 0.225]) 249 | if args.aug_plus: 250 | # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709 251 | augmentation = [ 252 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 253 | transforms.RandomApply([ 254 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 255 | ], p=0.8), 256 | transforms.RandomGrayscale(p=0.2), 257 | transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.5), 258 | transforms.RandomHorizontalFlip(), 259 | transforms.ToTensor(), 260 | normalize 261 | ] 262 | else: 263 | # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978 264 | augmentation = [ 265 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 266 | transforms.RandomGrayscale(p=0.2), 267 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 268 | transforms.RandomHorizontalFlip(), 269 | transforms.ToTensor(), 270 | normalize 271 | ] 272 | 273 | train_dataset = datasets.ImageFolder( 274 | traindir, 275 | moco.loader.TwoCropsTransform(transforms.Compose(augmentation))) 276 | 277 | if args.distributed: 278 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 279 | else: 280 | train_sampler = None 281 | 282 | train_loader = torch.utils.data.DataLoader( 283 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 284 | num_workers=args.workers, pin_memory=False, sampler=train_sampler, drop_last=True) 285 | 286 | val_loader = torch.utils.data.DataLoader( 287 | datasets.ImageFolder(valdir, transforms.Compose([ 288 | transforms.Resize(256), 289 | transforms.CenterCrop(224), 290 | transforms.ToTensor(), 291 | normalize, 292 | ])), 293 | batch_size=args.batch_size, shuffle=False, 294 | num_workers=args.workers, pin_memory=False) 295 | 296 | if args.evaluate: 297 | validate(val_loader, model, criterion, args) 298 | return 299 | 300 | indx = torch.arange(256) 301 | indx_train = indx 302 | for epoch in range(args.start_epoch, args.epochs): 303 | if args.distributed: 304 | train_sampler.set_epoch(epoch) 305 | adjust_learning_rate(optimizer, epoch, args) 306 | 307 | # train for one epoch 308 | 309 | backbone = train(train_loader, model, criterion, optimizer, epoch, args, indx_train) 310 | # evaluate on validation set 311 | top1, top5, mAP = 
validate(val_loader, model, criterion, args, epoch) 312 | 313 | best_top1 = max(top1, best_top1) 314 | 315 | best_top5 = max(top5, best_top5) 316 | 317 | best_mAP = max(mAP, best_mAP) 318 | 319 | 320 | print("best Top@1: %.4f" % (best_top1)) 321 | print("best Top@5: %.4f" % (best_top5)) 322 | print("best mAP: %.4f" % (best_mAP)) 323 | 324 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 325 | and args.rank % ngpus_per_node == 0): 326 | save_checkpoint({ 327 | 'epoch': epoch + 1, 328 | 'arch': args.arch, 329 | 'state_dict': model.state_dict(), 330 | 'optimizer': optimizer.state_dict(), 331 | }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch)) 332 | 333 | 334 | def train(train_loader, model, criterion, optimizer, epoch, args, indx): 335 | batch_time = AverageMeter('Time', ':6.3f') 336 | data_time = AverageMeter('Data', ':6.3f') 337 | losses = AverageMeter('Loss', ':.4e') 338 | top1 = AverageMeter('Acc@1', ':6.2f') 339 | top5 = AverageMeter('Acc@5', ':6.2f') 340 | progress = ProgressMeter( 341 | len(train_loader), 342 | [batch_time, data_time, losses, top1, top5], 343 | prefix="Epoch: [{}]".format(epoch)) 344 | # switch to train mode 345 | model.train() 346 | 347 | end = time.time() 348 | for i, (images, _) in enumerate(train_loader): 349 | # measure data loading time 350 | data_time.update(time.time() - end) 351 | 352 | if args.gpu is not None: 353 | images[0] = images[0].cuda(args.gpu, non_blocking=True) 354 | images[1] = images[1].cuda(args.gpu, non_blocking=True) 355 | 356 | output, target, output_max, target_max, backbone, featcov16, featmap, att_max, gradcam = model(im_q=images[0], im_k=images[1], epoch=epoch, iter=i, indx=indx) 357 | 358 | 359 | 360 | loss = criterion(output, target) 361 | loss_cam = F.kl_div(att_max.softmax(dim=-1).log(), gradcam.softmax(dim=-1), 362 | reduction='sum') 363 | 364 | total_loss = loss + args.nu * loss_cam 365 | 366 | # acc1/acc5 # (K+1)-way contrast classifier accuracy 367 | # measure accuracy and record loss 368 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 369 | losses.update(loss.item(), images[0].size(0)) 370 | top1.update(acc1[0], images[0].size(0)) 371 | top5.update(acc5[0], images[0].size(0)) 372 | 373 | # compute gradient and do SGD step 374 | optimizer.zero_grad() 375 | total_loss.backward() 376 | optimizer.step() 377 | 378 | # measure elapsed time 379 | batch_time.update(time.time() - end) 380 | end = time.time() 381 | 382 | if i % args.print_freq == 0: 383 | progress.display(i) 384 | return backbone 385 | 386 | 387 | 388 | 389 | def validate(val_loader, model, criterion, args, epoch): 390 | batch_time = AverageMeter('Time', ':6.3f') 391 | losses = AverageMeter('Loss', ':.4e') 392 | top1 = AverageMeter('top1', ':6.2f') 393 | top5 = AverageMeter('top5', ':6.2f') 394 | progress = ProgressMeter( 395 | len(val_loader), 396 | [batch_time, losses, top1, top5], 397 | prefix='Test: ') 398 | 399 | # switch to evaluate mode 400 | model.eval() 401 | 402 | with torch.no_grad(): 403 | end = time.time() 404 | data = torch.zeros(1, 2048) 405 | label = torch.zeros(1) 406 | att_map = torch.zeros(1, 32, 7, 7) 407 | for i, (images, target) in enumerate(val_loader): 408 | if args.gpu is not None: 409 | images = images.cuda(args.gpu, non_blocking=True) 410 | target = target.cuda(args.gpu, non_blocking=True) 411 | 412 | # compute output 413 | _, _, featcov16, output = model.module.inference(images) 414 | 415 | data = torch.cat((data.cuda(args.gpu), output.cuda(args.gpu)), 0) 416 | 417 | label = 
torch.cat((label.cuda(args.gpu), target), 0) 418 | 419 | att_map = torch.cat((att_map.cuda(args.gpu), featcov16.cuda(args.gpu)), 0) 420 | 421 | 422 | data = data[torch.arange(data.size(0)) != 0] 423 | label = label[torch.arange(label.size(0)) != 0] 424 | 425 | 426 | att_map = att_map[torch.arange(att_map.size(0)) != 0] 427 | max_pre = torch.var(att_map, dim=(2, 3), keepdim=False) 428 | max_pre = torch.mean(max_pre, dim=0, keepdim=False) 429 | a, idx1 = torch.sort(max_pre, descending=True) 430 | 431 | 432 | topN1 = [] 433 | topN5 = [] 434 | MAP = [] 435 | for j in range(data.size(0)): 436 | query_feat = data[j, :] 437 | query_label = label[j].item() 438 | 439 | dict = data[torch.arange(data.size(0)) != j] 440 | sim_label = label[torch.arange(label.size(0)) != j] 441 | 442 | similarity = torch.mv(dict, query_feat) 443 | 444 | table = torch.zeros(similarity.size(0), 2) 445 | table[:, 0] = similarity 446 | table[:, 1] = sim_label 447 | table = table.cpu().detach().numpy() 448 | 449 | index = np.argsort(table[:, 0])[::-1] 450 | 451 | T = table[index] 452 | #top-1 453 | if T[0,1] == query_label: 454 | topN1.append(1) 455 | else: 456 | topN1.append(0) 457 | #top-5 458 | if np.sum(T[:5, -1] == query_label) > 0: 459 | topN5.append(1) 460 | else: 461 | topN5.append(0) 462 | 463 | #mAP 464 | check = np.where(T[:, 1] == query_label) 465 | check = check[0] 466 | AP = 0 467 | for k in range(len(check)): 468 | temp = (k+1)/(check[k]+1) 469 | AP = AP + temp 470 | AP = AP/(len(check)) 471 | MAP.append(AP) 472 | 473 | top1 = np.mean(topN1) 474 | top5 = np.mean(topN5) 475 | mAP = np.mean(MAP) 476 | 477 | # TODO: this should also be done with the ProgressMeter 478 | print(' * Top@1 {top1:.3f} Top@5 {top5:.3f} mAP {mAP:.3f}' 479 | .format(top1=top1, top5=top5, mAP=mAP)) 480 | 481 | return top1, top5, mAP 482 | 483 | 484 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 485 | torch.save(state, filename) 486 | if is_best: 487 | shutil.copyfile(filename, 'model_best.pth.tar') 488 | 489 | 490 | class AverageMeter(object): 491 | """Computes and stores the average and current value""" 492 | 493 | def __init__(self, name, fmt=':f'): 494 | self.name = name 495 | self.fmt = fmt 496 | self.reset() 497 | 498 | def reset(self): 499 | self.val = 0 500 | self.avg = 0 501 | self.sum = 0 502 | self.count = 0 503 | 504 | def update(self, val, n=1): 505 | self.val = val 506 | self.sum += val * n 507 | self.count += n 508 | self.avg = self.sum / self.count 509 | 510 | def __str__(self): 511 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 512 | return fmtstr.format(**self.__dict__) 513 | 514 | 515 | class ProgressMeter(object): 516 | def __init__(self, num_batches, meters, prefix=""): 517 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 518 | self.meters = meters 519 | self.prefix = prefix 520 | 521 | def display(self, batch): 522 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 523 | entries += [str(meter) for meter in self.meters] 524 | print('\t'.join(entries)) 525 | 526 | def _get_batch_fmtstr(self, num_batches): 527 | num_digits = len(str(num_batches // 1)) 528 | fmt = '{:' + str(num_digits) + 'd}' 529 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 530 | 531 | 532 | def adjust_learning_rate(optimizer, epoch, args): 533 | """Decay the learning rate based on schedule""" 534 | lr = args.lr 535 | if args.cos: # cosine lr schedule 536 | lr *= 0.5 * (1. 
+ math.cos(math.pi * epoch / args.epochs)) 537 | else: # stepwise lr schedule 538 | for milestone in args.schedule: 539 | lr *= 0.1 if epoch >= milestone else 1. 540 | for param_group in optimizer.param_groups: 541 | param_group['lr'] = lr 542 | 543 | 544 | def accuracy(output, target, topk=(1,)): 545 | """Computes the accuracy over the k top predictions for the specified values of k""" 546 | with torch.no_grad(): 547 | maxk = max(topk) 548 | batch_size = target.size(0) 549 | 550 | _, pred = output.topk(maxk, 1, True, True) 551 | pred = pred.t() 552 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 553 | 554 | res = [] 555 | for k in topk: 556 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 557 | res.append(correct_k.mul_(100.0 / batch_size)) 558 | return res 559 | 560 | 561 | if __name__ == '__main__': 562 | main() 563 | -------------------------------------------------------------------------------- /moco/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /moco/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/moco/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /moco/__pycache__/builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/moco/__pycache__/builder.cpython-38.pyc -------------------------------------------------------------------------------- /moco/__pycache__/loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/moco/__pycache__/loader.cpython-38.pyc -------------------------------------------------------------------------------- /moco/builder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from resnet_output import resnet50 5 | import numpy as np 6 | 7 | 8 | class MoCo(nn.Module): 9 | """ 10 | Build a MoCo model with: a query encoder, a key encoder, and a queue 11 | https://arxiv.org/abs/1911.05722 12 | """ 13 | 14 | def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False): 15 | """ 16 | dim: feature dimension (default: 128) 17 | K: queue size; number of negative keys (default: 65536) 18 | m: moco momentum of updating key encoder (default: 0.999) 19 | T: softmax temperature (default: 0.07) 20 | """ 21 | super(MoCo, self).__init__() 22 | 23 | self.K = K 24 | self.m = m 25 | self.T = T 26 | 27 | # create the encoders 28 | self.encoder_q = resnet50(pretrained=True, num_classes=1000) 29 | # self.encoder_q = base_encoder(num_classes=dim) 30 | self.encoder_k = resnet50(pretrained=True, num_classes=1000) 31 | # self.encoder_k = base_encoder(num_classes=dim) 32 | 33 | self.encoder_q.fc = nn.Linear(2048, dim) 34 | self.encoder_k.fc = nn.Linear(2048, dim) 35 | 36 | if mlp: # hack: brute-force replacement 37 | dim_mlp = self.encoder_q.fc.weight.shape[1] 38 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) 39 | 
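# With --mlp, dim_mlp is 2048 (fc was replaced by nn.Linear(2048, dim) above),
# so each head becomes the MoCo v2 projection: Linear(2048, 2048) -> ReLU ->
# Linear(2048, dim). The key encoder needs the identical head so that the
# parameter-wise momentum copy below lines up one-to-one.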
self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) 40 | 41 | 42 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 43 | param_k.data.copy_(param_q.data) # initialize 44 | param_k.requires_grad = False # not update by gradient 45 | 46 | # create the queue 47 | self.register_buffer("queue", torch.randn(dim, 48 | K)) 49 | self.queue = nn.functional.normalize(self.queue, dim=0) 50 | 51 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 52 | 53 | # create the queue_max 54 | self.register_buffer("queue_max", torch.randn(dim, 55 | K)) 56 | self.queue_max = nn.functional.normalize(self.queue_max, dim=0) 57 | 58 | self.register_buffer("queue_ptr_max", torch.zeros(1, dtype=torch.long)) 59 | # add projection 60 | 61 | self.bilinear = 32 62 | 63 | self.conv16 = nn.Conv2d(2048, self.bilinear, kernel_size=1, stride=1, padding=0, 64 | bias=False) 65 | 66 | self.bn16 = nn.BatchNorm2d(self.bilinear) 67 | 68 | self.conv16_2 = nn.Conv2d(self.bilinear, 1, kernel_size=1, stride=1, padding=0, 69 | bias=False) 70 | self.bn16_2 = nn.BatchNorm2d(1) 71 | 72 | 73 | self.relu = nn.ReLU(inplace=True) 74 | 75 | self.avgpool = nn.AvgPool2d(7, stride=1) 76 | 77 | self.fc = nn.Linear(2048, dim) 78 | self.qmax = nn.Sequential(nn.Linear(2048, 2048), nn.ReLU(), self.fc) 79 | self.kmax = nn.Sequential(nn.Linear(2048, 2048), nn.ReLU(), self.fc) 80 | 81 | 82 | 83 | @torch.no_grad() 84 | def _momentum_update_key_encoder(self): 85 | """ 86 | Momentum update of the key encoder 87 | """ 88 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 89 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 90 | 91 | for param_q, param_k in zip(self.qmax.parameters(), self.kmax.parameters()): 92 | param_k.data = param_k.data * self.m + param_q.data * (1. 
- self.m) 93 | 94 | @torch.no_grad() 95 | def _dequeue_and_enqueue(self, keys): 96 | # gather keys before updating queue 97 | keys = concat_all_gather(keys) 98 | 99 | batch_size = keys.shape[0] 100 | 101 | ptr = int(self.queue_ptr) 102 | assert self.K % batch_size == 0 # for simplicity 103 | 104 | # replace the keys at ptr (dequeue and enqueue) 105 | self.queue[:, ptr:ptr + batch_size] = keys.T 106 | ptr = (ptr + batch_size) % self.K # move pointer 107 | 108 | self.queue_ptr[0] = ptr 109 | 110 | def _dequeue_and_enqueue_max(self, keys): 111 | # gather keys before updating queue 112 | keys = concat_all_gather(keys) 113 | 114 | batch_size = keys.shape[0] 115 | 116 | ptr = int(self.queue_ptr_max) 117 | assert self.K % batch_size == 0 # for simplicity 118 | 119 | # replace the keys at ptr (dequeue and enqueue) 120 | self.queue_max[:, ptr:ptr + batch_size] = keys.T 121 | ptr = (ptr + batch_size) % self.K # move pointer 122 | 123 | self.queue_ptr_max[0] = ptr 124 | 125 | def max_mask(self, featmap, indx): 126 | featcov16 = self.conv16(featmap) 127 | featcov16 = self.bn16(featcov16) 128 | featcov16 = self.relu(featcov16) 129 | 130 | 131 | 132 | img, _ = torch.max(featcov16, axis=1) 133 | img = img - torch.min(img) 134 | att_max = img / (1e-7 + torch.max(img)) 135 | 136 | img = att_max[:, None, :, :] 137 | img = img.repeat(1, 2048, 1, 1) 138 | 139 | 140 | 141 | PFM = featmap.cuda() * img.cuda() 142 | aa = self.avgpool(PFM) 143 | bp_out_feat = aa.view(aa.size(0), -1) 144 | bp_out_feat_max = nn.functional.normalize(bp_out_feat, dim=1) 145 | 146 | 147 | 148 | return bp_out_feat_max, att_max 149 | 150 | def feat_bilinear(self, featmap): 151 | featcov16 = self.conv16(featmap) 152 | featcov16 = self.bn16(featcov16) 153 | featcov16 = self.relu(featcov16) 154 | 155 | feat_matrix = torch.zeros(featcov16.size(0), self.bilinear, 2048) 156 | for i in range(self.bilinear): 157 | matrix = featcov16[:, i, :, :] 158 | matrix = matrix[:, None, :, :] 159 | matrix = matrix.repeat(1, 2048, 1, 1) 160 | PFM = featmap * matrix 161 | aa = self.avgpool(PFM) 162 | 163 | feat_matrix[:, i, :] = aa.view(aa.size(0), -1) 164 | 165 | bp_out_feat = feat_matrix.view(feat_matrix.size(0), -1) 166 | 167 | return bp_out_feat 168 | 169 | @torch.no_grad() 170 | def _batch_shuffle_ddp(self, x): 171 | """ 172 | Batch shuffle, for making use of BatchNorm. 173 | *** Only support DistributedDataParallel (DDP) model. *** 174 | """ 175 | # gather from all gpus 176 | batch_size_this = x.shape[0] 177 | x_gather = concat_all_gather(x) 178 | batch_size_all = x_gather.shape[0] 179 | 180 | num_gpus = batch_size_all // batch_size_this 181 | 182 | # random shuffle index 183 | idx_shuffle = torch.randperm(batch_size_all).cuda() 184 | 185 | # broadcast to all gpus 186 | torch.distributed.broadcast(idx_shuffle, src=0) 187 | 188 | # index for restoring 189 | idx_unshuffle = torch.argsort(idx_shuffle) 190 | 191 | # shuffled index for this gpu 192 | gpu_idx = torch.distributed.get_rank() 193 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 194 | 195 | return x_gather[idx_this], idx_unshuffle 196 | 197 | @torch.no_grad() 198 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 199 | """ 200 | Undo batch shuffle. 201 | *** Only support DistributedDataParallel (DDP) model. 
*** 202 | """ 203 | # gather from all gpus 204 | batch_size_this = x.shape[0] 205 | x_gather = concat_all_gather(x) 206 | batch_size_all = x_gather.shape[0] 207 | 208 | num_gpus = batch_size_all // batch_size_this 209 | 210 | # restored index for this gpu 211 | gpu_idx = torch.distributed.get_rank() 212 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 213 | 214 | return x_gather[idx_this] 215 | 216 | def forward(self, im_q, im_k, epoch, iter, indx): 217 | """ 218 | Input: 219 | im_q: a batch of query images 220 | im_k: a batch of key images 221 | Output: 222 | logits, targets 223 | """ 224 | 225 | # compute query features 226 | q, _, featmap = self.encoder_q(im_q) # queries: NxC 227 | q = nn.functional.normalize(q, dim=1) 228 | 229 | 230 | # max bilinear q 231 | q_max, att_max = self.max_mask(featmap, indx) 232 | embedding_q = self.qmax(q_max.cuda()) 233 | q_max_proj = nn.functional.normalize(embedding_q, dim=1) 234 | 235 | # compute key features 236 | with torch.no_grad(): # no gradient to keys 237 | self._momentum_update_key_encoder() # update the key encoder 238 | 239 | # shuffle for making use of BN 240 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k) 241 | 242 | k, _, featmap_k = self.encoder_k(im_k) # keys: NxC 243 | k = nn.functional.normalize(k, dim=1) 244 | 245 | # max bilinear k 246 | k_max, _ = self.max_mask(featmap_k, indx) 247 | embedding_k = self.kmax(k_max.cuda()) 248 | k_max_proj = nn.functional.normalize(embedding_k, dim=1) 249 | 250 | # undo shuffle 251 | k = self._batch_unshuffle_ddp(k, idx_unshuffle).detach() 252 | 253 | k_max_proj = self._batch_unshuffle_ddp(k_max_proj, idx_unshuffle).detach() 254 | 255 | # compute logits 256 | # Einstein sum is more intuitive 257 | # positive logits: Nx1 258 | 259 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 260 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 261 | 262 | logits = torch.cat([l_pos, l_neg], dim=1) / self.T 263 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 264 | 265 | #max 266 | l_pos_max = torch.einsum('nc,nc->n', [q_max_proj, k_max_proj]).unsqueeze(-1) 267 | l_neg_max = torch.einsum('nc,ck->nk', [q_max_proj, self.queue_max.clone().detach()]) 268 | 269 | logits_max = torch.cat([l_pos_max, l_neg_max], dim=1) / self.T 270 | labels_max = torch.zeros(logits_max.shape[0], dtype=torch.long).cuda() 271 | 272 | # dequeue and enqueue 273 | self._dequeue_and_enqueue(k) 274 | 275 | self._dequeue_and_enqueue_max(k_max_proj) 276 | 277 | featcov16 = self.conv16(featmap) 278 | featcov16 = self.bn16(featcov16) 279 | featcov16 = self.relu(featcov16) 280 | 281 | criterion = nn.CrossEntropyLoss() 282 | CE = criterion(logits, labels) 283 | 284 | grad_wrt_act1 = torch.autograd.grad(outputs=CE, inputs=featmap, 285 | grad_outputs=torch.ones_like(CE), retain_graph=True, 286 | allow_unused=True)[0] 287 | 288 | gradcam = torch.relu((featmap * grad_wrt_act1).sum(dim=1)) 289 | 290 | 291 | return logits, labels, logits_max, labels_max, self.encoder_q, featcov16, featmap, att_max, gradcam 292 | 293 | def inference(self, img): 294 | projfeat, feat, featmap = self.encoder_q(img) 295 | 296 | featcov16 = self.conv16(featmap) 297 | featcov16 = self.bn16(featcov16) 298 | featcov16 = self.relu(featcov16) 299 | 300 | 301 | img = featcov16.cpu().detach().numpy() 302 | img = np.max(img, axis=1) 303 | img = img - np.min(img) 304 | img = img / (1e-7 + np.max(img)) 305 | img = torch.from_numpy(img) 306 | 307 | 308 | 309 | img = img[:, None, :, :] 310 | img = img.repeat(1, 2048, 1, 1) 311 | PFM = 
featmap.cuda() * img.cuda() 312 | aa = self.avgpool(PFM) 313 | bp_out_feat = aa.view(aa.size(0), -1) 314 | bp_out_feat = nn.functional.normalize(bp_out_feat, dim=1) 315 | 316 | feat = nn.functional.normalize(feat, dim=1) 317 | 318 | return projfeat, feat, featcov16, bp_out_feat 319 | 320 | 321 | # utils 322 | @torch.no_grad() 323 | def concat_all_gather(tensor): 324 | """ 325 | Performs all_gather operation on the provided tensors. 326 | *** Warning ***: torch.distributed.all_gather has no gradient. 327 | """ 328 | tensors_gather = [torch.ones_like(tensor) 329 | for _ in range(torch.distributed.get_world_size())] 330 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 331 | 332 | output = torch.cat(tensors_gather, dim=0) 333 | return output 334 | -------------------------------------------------------------------------------- /moco/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from PIL import ImageFilter 3 | import random 4 | 5 | 6 | class TwoCropsTransform: 7 | """Take two random crops of one image as the query and key.""" 8 | 9 | def __init__(self, base_transform): 10 | self.base_transform = base_transform 11 | 12 | def __call__(self, x): 13 | q = self.base_transform(x) 14 | k = self.base_transform(x) 15 | return [q, k] 16 | 17 | 18 | class GaussianBlur(object): 19 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 20 | 21 | def __init__(self, sigma=[.1, 2.]): 22 | self.sigma = sigma 23 | 24 | def __call__(self, x): 25 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 26 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 27 | return x 28 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .byol import BYOL 2 | from torchvision.models import resnet50, resnet18 3 | import torch 4 | from .backbones import resnet50_cub200,resnet50_stanfordcars, resnet18_cifar_variant2,resnet18_cifar_variant1,resnet50_aircrafts 5 | 6 | def get_backbone(backbone, castrate=True): 7 | backbone = eval(f"{backbone}(pretrained=True)") 8 | 9 | if castrate: 10 | backbone.output_dim = backbone.fc.in_features 11 | backbone.fc = torch.nn.Identity() 12 | 13 | return backbone 14 | 15 | 16 | def get_model(model_cfg): 17 | if model_cfg.name == 'byol': 18 | model = BYOL(get_backbone(model_cfg.backbone)) 19 | else: 20 | raise NotImplementedError 21 | return model -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/byol.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/__pycache__/byol.cpython-38.pyc -------------------------------------------------------------------------------- /models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar_resnet_1 import resnet18 as resnet18_cifar_variant1 2 | from .cifar_resnet_2 import 
ResNet18 as resnet18_cifar_variant2 3 | 4 | from .cub_resnet_1 import resnet50 as resnet50_cub200 5 | from .cub_resnet_1 import resnet50 as resnet50_stanfordcars 6 | from .cub_resnet_1 import resnet50 as resnet50_aircrafts 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /models/backbones/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/backbones/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/backbones/__pycache__/cifar_resnet_1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/backbones/__pycache__/cifar_resnet_1.cpython-38.pyc -------------------------------------------------------------------------------- /models/backbones/__pycache__/cifar_resnet_2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/backbones/__pycache__/cifar_resnet_2.cpython-38.pyc -------------------------------------------------------------------------------- /models/backbones/__pycache__/cub_resnet_1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/backbones/__pycache__/cub_resnet_1.cpython-38.pyc -------------------------------------------------------------------------------- /models/backbones/cifar_resnet_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import torch.utils.model_zoo as model_zoo 5 | # https://raw.githubusercontent.com/huyvnphan/PyTorch_CIFAR10/master/cifar10_models/resnet.py 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=dilation, groups=groups, bias=False, dilation=dilation) 21 | 22 | 23 | def conv1x1(in_planes, out_planes, stride=1): 24 | """1x1 convolution""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 32 | base_width=64, dilation=1, norm_layer=None): 33 | super(BasicBlock, self).__init__() 34 | if norm_layer is None: 35 | norm_layer = nn.BatchNorm2d 36 | if groups != 1 or base_width != 64: 37 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 38 | if 
dilation > 1: 39 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 40 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 41 | self.conv1 = conv3x3(inplanes, planes, stride) 42 | self.bn1 = norm_layer(planes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = norm_layer(planes) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | identity = x 51 | 52 | out = self.conv1(x) 53 | out = self.bn1(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv2(out) 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | identity = self.downsample(x) 61 | 62 | out += identity 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | 71 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 72 | base_width=64, dilation=1, norm_layer=None): 73 | super(Bottleneck, self).__init__() 74 | if norm_layer is None: 75 | norm_layer = nn.BatchNorm2d 76 | width = int(planes * (base_width / 64.)) * groups 77 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 78 | self.conv1 = conv1x1(inplanes, width) 79 | self.bn1 = norm_layer(width) 80 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 81 | self.bn2 = norm_layer(width) 82 | self.conv3 = conv1x1(width, planes * self.expansion) 83 | self.bn3 = norm_layer(planes * self.expansion) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | identity = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv3(out) 100 | out = self.bn3(out) 101 | 102 | if self.downsample is not None: 103 | identity = self.downsample(x) 104 | 105 | out += identity 106 | out = self.relu(out) 107 | 108 | return out 109 | 110 | 111 | class ResNet(nn.Module): 112 | 113 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 114 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 115 | norm_layer=None): 116 | super(ResNet, self).__init__() 117 | if norm_layer is None: 118 | norm_layer = nn.BatchNorm2d 119 | self._norm_layer = norm_layer 120 | 121 | self.inplanes = 64 122 | self.dilation = 1 123 | if replace_stride_with_dilation is None: 124 | # each element in the tuple indicates if we should replace 125 | # the 2x2 stride with a dilated convolution instead 126 | replace_stride_with_dilation = [False, False, False] 127 | if len(replace_stride_with_dilation) != 3: 128 | raise ValueError("replace_stride_with_dilation should be None " 129 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 130 | self.groups = groups 131 | self.base_width = width_per_group 132 | 133 | ## CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1 134 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 135 | ## END 136 | 137 | self.bn1 = norm_layer(self.inplanes) 138 | self.relu = nn.ReLU(inplace=True) 139 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 140 | self.layer1 = self._make_layer(block, 64, layers[0]) 141 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 142 | dilate=replace_stride_with_dilation[0]) 143 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 144 | 
dilate=replace_stride_with_dilation[1]) 145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 146 | dilate=replace_stride_with_dilation[2]) 147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 148 | self.fc = nn.Linear(512 * block.expansion, num_classes) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 154 | nn.init.constant_(m.weight, 1) 155 | nn.init.constant_(m.bias, 0) 156 | 157 | # Zero-initialize the last BN in each residual branch, 158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 159 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 160 | if zero_init_residual: 161 | for m in self.modules(): 162 | if isinstance(m, Bottleneck): 163 | nn.init.constant_(m.bn3.weight, 0) 164 | elif isinstance(m, BasicBlock): 165 | nn.init.constant_(m.bn2.weight, 0) 166 | 167 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 168 | norm_layer = self._norm_layer 169 | downsample = None 170 | previous_dilation = self.dilation 171 | if dilate: 172 | self.dilation *= stride 173 | stride = 1 174 | if stride != 1 or self.inplanes != planes * block.expansion: 175 | downsample = nn.Sequential( 176 | conv1x1(self.inplanes, planes * block.expansion, stride), 177 | norm_layer(planes * block.expansion), 178 | ) 179 | 180 | layers = [] 181 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 182 | self.base_width, previous_dilation, norm_layer)) 183 | self.inplanes = planes * block.expansion 184 | for _ in range(1, blocks): 185 | layers.append(block(self.inplanes, planes, groups=self.groups, 186 | base_width=self.base_width, dilation=self.dilation, 187 | norm_layer=norm_layer)) 188 | 189 | return nn.Sequential(*layers) 190 | 191 | def forward(self, x): 192 | x = self.conv1(x) 193 | x = self.bn1(x) 194 | x = self.relu(x) 195 | # x = self.maxpool(x) 196 | 197 | x = self.layer1(x) 198 | x = self.layer2(x) 199 | x = self.layer3(x) 200 | x = self.layer4(x) 201 | 202 | x = self.avgpool(x) 203 | x = x.reshape(x.size(0), -1) 204 | x = self.fc(x) 205 | 206 | return x 207 | 208 | 209 | def _resnet(arch, block, layers, pretrained, progress, device, **kwargs): 210 | model = ResNet(block, layers, **kwargs) 211 | 212 | if pretrained: 213 | print("pretrain resnet arch:") 214 | print(arch) 215 | model.load_state_dict(model_zoo.load_url(model_urls[arch])) 216 | # script_dir = os.path.dirname(__file__) 217 | # state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device) 218 | # model.load_state_dict(state_dict) 219 | return model 220 | 221 | 222 | def resnet18(pretrained=False, progress=True, device='cpu', **kwargs): 223 | """Constructs a ResNet-18 model. 224 | 225 | Args: 226 | pretrained (bool): If True, returns a model pre-trained on ImageNet 227 | progress (bool): If True, displays a progress bar of the download to stderr 228 | """ 229 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, device, 230 | **kwargs) 231 | 232 | 233 | def resnet34(pretrained=False, progress=True, device='cpu', **kwargs): 234 | """Constructs a ResNet-34 model. 
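Note that despite the `kernel_size 7 -> 3` comment above, only the stem max-pool is actually disabled in this variant; `conv1` keeps its 7x7/stride-2 kernel. A minimal smoke test of the constructors defined here, assuming the repo root is on `PYTHONPATH`:
```python
# Sketch: shape-check the CIFAR variant above (illustrative, not part of the file).
import torch
from models.backbones.cifar_resnet_1 import resnet18

model = resnet18(pretrained=False, num_classes=10)
x = torch.randn(2, 3, 32, 32)   # CIFAR-sized input
logits = model(x)               # stem halves 32 -> 16; layers 2-4 halve again
print(logits.shape)             # -> torch.Size([2, 10])
```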
235 | 236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | progress (bool): If True, displays a progress bar of the download to stderr 239 | """ 240 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, device, 241 | **kwargs) 242 | 243 | 244 | def resnet50(pretrained=False, progress=True, device='cpu', **kwargs): 245 | """Constructs a ResNet-50 model. 246 | 247 | Args: 248 | pretrained (bool): If True, returns a model pre-trained on ImageNet 249 | progress (bool): If True, displays a progress bar of the download to stderr 250 | """ 251 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, device, 252 | **kwargs) 253 | 254 | 255 | def resnet101(pretrained=False, progress=True, device='cpu', **kwargs): 256 | """Constructs a ResNet-101 model. 257 | 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If True, displays a progress bar of the download to stderr 261 | """ 262 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, device, 263 | **kwargs) 264 | 265 | 266 | def resnet152(pretrained=False, progress=True, device='cpu', **kwargs): 267 | """Constructs a ResNet-152 model. 268 | 269 | Args: 270 | pretrained (bool): If True, returns a model pre-trained on ImageNet 271 | progress (bool): If True, displays a progress bar of the download to stderr 272 | """ 273 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, device, 274 | **kwargs) 275 | 276 | 277 | def resnext50_32x4d(pretrained=False, progress=True, device='cpu', **kwargs): 278 | """Constructs a ResNeXt-50 32x4d model. 279 | 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | progress (bool): If True, displays a progress bar of the download to stderr 283 | """ 284 | kwargs['groups'] = 32 285 | kwargs['width_per_group'] = 4 286 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 287 | pretrained, progress, device, **kwargs) 288 | 289 | 290 | def resnext101_32x8d(pretrained=False, progress=True, device='cpu', **kwargs): 291 | """Constructs a ResNeXt-101 32x8d model. 292 | 293 | Args: 294 | pretrained (bool): If True, returns a model pre-trained on ImageNet 295 | progress (bool): If True, displays a progress bar of the download to stderr 296 | """ 297 | kwargs['groups'] = 32 298 | kwargs['width_per_group'] = 8 299 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 300 | pretrained, progress, device, **kwargs) 301 | 302 | 303 | if __name__ == "__main__": 304 | model = resnet18() 305 | print(sum(p.numel() for p in model.parameters() if p.requires_grad)) # 11173962 -------------------------------------------------------------------------------- /models/backbones/cifar_resnet_2.py: -------------------------------------------------------------------------------- 1 | # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 2 | '''ResNet in PyTorch. 3 | 4 | For Pre-activation ResNet, see 'preact_resnet.py'. 5 | 6 | Reference: 7 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 8 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 9 | ''' 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, in_planes, planes, stride=1): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = nn.Conv2d( 21 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 24 | stride=1, padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride != 1 or in_planes != self.expansion*planes: 29 | self.shortcut = nn.Sequential( 30 | nn.Conv2d(in_planes, self.expansion*planes, 31 | kernel_size=1, stride=stride, bias=False), 32 | nn.BatchNorm2d(self.expansion*planes) 33 | ) 34 | 35 | def forward(self, x): 36 | out = F.relu(self.bn1(self.conv1(x))) 37 | out = self.bn2(self.conv2(out)) 38 | out += self.shortcut(x) 39 | out = F.relu(out) 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | expansion = 4 45 | 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(Bottleneck, self).__init__() 48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 51 | stride=stride, padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, self.expansion * 54 | planes, kernel_size=1, bias=False) 55 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 56 | 57 | self.shortcut = nn.Sequential() 58 | if stride != 1 or in_planes != self.expansion*planes: 59 | self.shortcut = nn.Sequential( 60 | nn.Conv2d(in_planes, self.expansion*planes, 61 | kernel_size=1, stride=stride, bias=False), 62 | nn.BatchNorm2d(self.expansion*planes) 63 | ) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = F.relu(self.bn2(self.conv2(out))) 68 | out = self.bn3(self.conv3(out)) 69 | out += self.shortcut(x) 70 | out = F.relu(out) 71 | return out 72 | 73 | 74 | class ResNet(nn.Module): 75 | def __init__(self, block, num_blocks, num_classes=10): 76 | super(ResNet, self).__init__() 77 | self.in_planes = 64 78 | 79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 80 | stride=1, padding=1, bias=False) 81 | self.bn1 = nn.BatchNorm2d(64) 82 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 83 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 84 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 85 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 86 | self.fc = nn.Linear(512*block.expansion, num_classes) 87 | 88 | def _make_layer(self, block, planes, num_blocks, stride): 89 | strides = [stride] + [1]*(num_blocks-1) 90 | layers = [] 91 | for stride in strides: 92 | layers.append(block(self.in_planes, planes, stride)) 93 | self.in_planes = planes * block.expansion 94 | return nn.Sequential(*layers) 95 | 96 | def forward(self, x): 97 | out = F.relu(self.bn1(self.conv1(x))) 98 | out = self.layer1(out) 99 | out = self.layer2(out) 100 | out = self.layer3(out) 101 | out = self.layer4(out) 102 | out = F.avg_pool2d(out, 4) 103 | out = out.view(out.size(0), -1) 104 | out = self.fc(out) 105 | return out 106 | 107 | 108 | def ResNet18(): 109 | return ResNet(BasicBlock, [2, 2, 2, 2]) 110 | 111 | 112 | def ResNet34(): 113 | return ResNet(BasicBlock, [3, 4, 6, 3]) 114 | 115 | 116 | def ResNet50(): 117 | return ResNet(Bottleneck, [3, 
4, 6, 3]) 118 | 119 | 120 | def ResNet101(): 121 | return ResNet(Bottleneck, [3, 4, 23, 3]) 122 | 123 | 124 | def ResNet152(): 125 | return ResNet(Bottleneck, [3, 8, 36, 3]) 126 | 127 | 128 | def test(): 129 | net = ResNet18() 130 | print(sum(p.numel() for p in net.parameters() if p.requires_grad)) 131 | import torchvision 132 | net2 = torchvision.models.resnet18() 133 | print(sum(p.numel() for p in net2.parameters() if p.requires_grad)) 134 | # y = net(torch.randn(1, 3, 32, 32)) 135 | # print(y.size()) 136 | # 11173962 137 | # 11689512 138 | # test() -------------------------------------------------------------------------------- /models/backbones/cub_resnet_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import torch.utils.model_zoo as model_zoo 5 | # https://raw.githubusercontent.com/huyvnphan/PyTorch_CIFAR10/master/cifar10_models/resnet.py 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=dilation, groups=groups, bias=False, dilation=dilation) 21 | 22 | 23 | def conv1x1(in_planes, out_planes, stride=1): 24 | """1x1 convolution""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 32 | base_width=64, dilation=1, norm_layer=None): 33 | super(BasicBlock, self).__init__() 34 | if norm_layer is None: 35 | norm_layer = nn.BatchNorm2d 36 | if groups != 1 or base_width != 64: 37 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 38 | if dilation > 1: 39 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 40 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 41 | self.conv1 = conv3x3(inplanes, planes, stride) 42 | self.bn1 = norm_layer(planes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = norm_layer(planes) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | identity = x 51 | 52 | out = self.conv1(x) 53 | out = self.bn1(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv2(out) 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | identity = self.downsample(x) 61 | 62 | out += identity 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | 71 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 72 | base_width=64, dilation=1, norm_layer=None): 73 | super(Bottleneck, self).__init__() 74 | if norm_layer is None: 75 | norm_layer = nn.BatchNorm2d 76 | width = int(planes * (base_width / 64.)) * groups 77 | # Both self.conv2 and self.downsample layers 
downsample the input when stride != 1 78 | self.conv1 = conv1x1(inplanes, width) 79 | self.bn1 = norm_layer(width) 80 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 81 | self.bn2 = norm_layer(width) 82 | self.conv3 = conv1x1(width, planes * self.expansion) 83 | self.bn3 = norm_layer(planes * self.expansion) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | identity = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv3(out) 100 | out = self.bn3(out) 101 | 102 | if self.downsample is not None: 103 | identity = self.downsample(x) 104 | 105 | out += identity 106 | out = self.relu(out) 107 | 108 | return out 109 | 110 | 111 | class ResNet(nn.Module): 112 | 113 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 114 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 115 | norm_layer=None): 116 | super(ResNet, self).__init__() 117 | if norm_layer is None: 118 | norm_layer = nn.BatchNorm2d 119 | self._norm_layer = norm_layer 120 | 121 | self.inplanes = 64 122 | self.dilation = 1 123 | if replace_stride_with_dilation is None: 124 | # each element in the tuple indicates if we should replace 125 | # the 2x2 stride with a dilated convolution instead 126 | replace_stride_with_dilation = [False, False, False] 127 | if len(replace_stride_with_dilation) != 3: 128 | raise ValueError("replace_stride_with_dilation should be None " 129 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 130 | self.groups = groups 131 | self.base_width = width_per_group 132 | 133 | ## CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1 134 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 135 | ## END 136 | 137 | self.bn1 = norm_layer(self.inplanes) 138 | self.relu = nn.ReLU(inplace=True) 139 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 140 | self.layer1 = self._make_layer(block, 64, layers[0]) 141 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 142 | dilate=replace_stride_with_dilation[0]) 143 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 144 | dilate=replace_stride_with_dilation[1]) 145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 146 | dilate=replace_stride_with_dilation[2]) 147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 148 | self.fc = nn.Linear(512 * block.expansion, num_classes) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 154 | nn.init.constant_(m.weight, 1) 155 | nn.init.constant_(m.bias, 0) 156 | 157 | # Zero-initialize the last BN in each residual branch, 158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
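# (With the last BN's weight at zero, each residual branch contributes nothing
# at initialization, so forward() reduces to relu(identity) and every block
# starts out as an identity-like mapping.)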
159 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 160 | if zero_init_residual: 161 | for m in self.modules(): 162 | if isinstance(m, Bottleneck): 163 | nn.init.constant_(m.bn3.weight, 0) 164 | elif isinstance(m, BasicBlock): 165 | nn.init.constant_(m.bn2.weight, 0) 166 | 167 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 168 | norm_layer = self._norm_layer 169 | downsample = None 170 | previous_dilation = self.dilation 171 | if dilate: 172 | self.dilation *= stride 173 | stride = 1 174 | if stride != 1 or self.inplanes != planes * block.expansion: 175 | downsample = nn.Sequential( 176 | conv1x1(self.inplanes, planes * block.expansion, stride), 177 | norm_layer(planes * block.expansion), 178 | ) 179 | 180 | layers = [] 181 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 182 | self.base_width, previous_dilation, norm_layer)) 183 | self.inplanes = planes * block.expansion 184 | for _ in range(1, blocks): 185 | layers.append(block(self.inplanes, planes, groups=self.groups, 186 | base_width=self.base_width, dilation=self.dilation, 187 | norm_layer=norm_layer)) 188 | 189 | return nn.Sequential(*layers) 190 | 191 | def forward(self, x): 192 | x = self.conv1(x) 193 | x = self.bn1(x) 194 | x = self.relu(x) 195 | #x = self.maxpool(x) 196 | 197 | x = self.layer1(x) 198 | x = self.layer2(x) 199 | x = self.layer3(x) 200 | feat = self.layer4(x) 201 | 202 | x = self.avgpool(feat) 203 | x = x.reshape(x.size(0), -1) 204 | x = self.fc(x) 205 | 206 | return x 207 | 208 | 209 | def _resnet(arch, block, layers, pretrained, progress, device, **kwargs): 210 | model = ResNet(block, layers, **kwargs) 211 | 212 | if pretrained: 213 | print("pretrain resnet arch:") 214 | print(arch) 215 | model.load_state_dict(model_zoo.load_url(model_urls[arch])) 216 | # script_dir = os.path.dirname(__file__) 217 | # state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device) 218 | # model.load_state_dict(state_dict) 219 | return model 220 | 221 | 222 | def resnet18(pretrained=False, progress=True, device='cpu', **kwargs): 223 | """Constructs a ResNet-18 model. 224 | 225 | Args: 226 | pretrained (bool): If True, returns a model pre-trained on ImageNet 227 | progress (bool): If True, displays a progress bar of the download to stderr 228 | """ 229 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, device, 230 | **kwargs) 231 | 232 | 233 | def resnet34(pretrained=False, progress=True, device='cpu', **kwargs): 234 | """Constructs a ResNet-34 model. 235 | 236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | progress (bool): If True, displays a progress bar of the download to stderr 239 | """ 240 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, device, 241 | **kwargs) 242 | 243 | 244 | def resnet50(pretrained=False, progress=True, device='cpu', **kwargs): 245 | """Constructs a ResNet-50 model. 246 | 247 | Args: 248 | pretrained (bool): If True, returns a model pre-trained on ImageNet 249 | progress (bool): If True, displays a progress bar of the download to stderr 250 | """ 251 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, device, 252 | **kwargs) 253 | 254 | 255 | def resnet101(pretrained=False, progress=True, device='cpu', **kwargs): 256 | """Constructs a ResNet-101 model. 
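The commented-out max-pool in `forward()` above is what distinguishes this fine-grained variant: layer4 keeps a 14x14 map for 224x224 inputs, which is why `models/byol.py` pools with `AvgPool2d(14)`. A rough shape check, assuming the repo root is importable:
```python
# Sketch: verify the 14x14 layer4 map of the variant above (illustrative only).
import torch
from models.backbones.cub_resnet_1 import resnet50

m = resnet50(pretrained=False)
x = torch.randn(1, 3, 224, 224)
x = m.relu(m.bn1(m.conv1(x)))                     # stride-2 stem: 224 -> 112, maxpool skipped
feat = m.layer4(m.layer3(m.layer2(m.layer1(x))))  # layer1 keeps 112; layers 2-4 halve: 56 -> 28 -> 14
print(feat.shape)                                 # -> torch.Size([1, 2048, 14, 14])
```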
257 | 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If True, displays a progress bar of the download to stderr 261 | """ 262 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, device, 263 | **kwargs) 264 | 265 | 266 | def resnet152(pretrained=False, progress=True, device='cpu', **kwargs): 267 | """Constructs a ResNet-152 model. 268 | 269 | Args: 270 | pretrained (bool): If True, returns a model pre-trained on ImageNet 271 | progress (bool): If True, displays a progress bar of the download to stderr 272 | """ 273 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, device, 274 | **kwargs) 275 | 276 | 277 | def resnext50_32x4d(pretrained=False, progress=True, device='cpu', **kwargs): 278 | """Constructs a ResNeXt-50 32x4d model. 279 | 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | progress (bool): If True, displays a progress bar of the download to stderr 283 | """ 284 | kwargs['groups'] = 32 285 | kwargs['width_per_group'] = 4 286 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 287 | pretrained, progress, device, **kwargs) 288 | 289 | 290 | def resnext101_32x8d(pretrained=False, progress=True, device='cpu', **kwargs): 291 | """Constructs a ResNeXt-101 32x8d model. 292 | 293 | Args: 294 | pretrained (bool): If True, returns a model pre-trained on ImageNet 295 | progress (bool): If True, displays a progress bar of the download to stderr 296 | """ 297 | kwargs['groups'] = 32 298 | kwargs['width_per_group'] = 8 299 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 300 | pretrained, progress, device, **kwargs) 301 | 302 | 303 | if __name__ == "__main__": 304 | model = resnet18() 305 | print(sum(p.numel() for p in model.parameters() if p.requires_grad)) # 11173962 306 | -------------------------------------------------------------------------------- /models/byol.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torchvision import transforms 7 | from math import pi, cos 8 | from torchvision.models import resnet50 9 | from collections import OrderedDict 10 | HPS = dict( 11 | max_steps=int(1000. * 1281167 / 4096), # 1000 epochs * 1281167 samples / batch size = 100 epochs * N of step/epoch 12 | # = total_epochs * len(dataloader) 13 | mlp_hidden_size=4096, 14 | projection_size=256, 15 | base_target_ema=4e-3, 16 | optimizer_config=dict( 17 | optimizer_name='lars', 18 | beta=0.9, 19 | trust_coef=1e-3, 20 | weight_decay=1.5e-6, 21 | exclude_bias_from_adaption=True), 22 | learning_rate_schedule=dict( 23 | base_learning_rate=0.2, 24 | warmup_steps=int(10.0 * 1281167 / 4096), # 10 epochs * N of steps/epoch = 10 epochs * len(dataloader) 25 | anneal_schedule='cosine'), 26 | batchnorm_kwargs=dict( 27 | decay_rate=0.9, 28 | eps=1e-5), 29 | seed=1337, 30 | ) 31 | 32 | 33 | 34 | def D(p, z, version='simplified'): # negative cosine similarity 35 | if version == 'original': 36 | z = z.detach() # stop gradient 37 | p = F.normalize(p, dim=1) # l2-normalize 38 | z = F.normalize(z, dim=1) # l2-normalize 39 | return -(p*z).sum(dim=1).mean() 40 | 41 | elif version == 'simplified':# same thing, much faster. 
Scroll down, speed test in __main__ 42 | return - F.cosine_similarity(p, z.detach(), dim=-1).mean() 43 | else: 44 | raise Exception 45 | 46 | 47 | class MLP(nn.Module): 48 | def __init__(self, in_dim): 49 | super().__init__() 50 | 51 | self.layer1 = nn.Sequential( 52 | nn.Linear(in_dim, HPS['mlp_hidden_size']), 53 | nn.BatchNorm1d(HPS['mlp_hidden_size'], eps=HPS['batchnorm_kwargs']['eps'], momentum=1-HPS['batchnorm_kwargs']['decay_rate']), 54 | nn.ReLU(inplace=True) 55 | ) 56 | self.layer2 = nn.Linear(HPS['mlp_hidden_size'], HPS['projection_size']) 57 | 58 | def forward(self, x): 59 | x = self.layer1(x) 60 | x = self.layer2(x) 61 | return x 62 | 63 | class BYOL(nn.Module): 64 | def __init__(self, backbone): 65 | super().__init__() 66 | 67 | self.backbone = backbone 68 | self.projector = MLP(backbone.output_dim) 69 | self.online_encoder = nn.Sequential( 70 | self.backbone, 71 | self.projector 72 | ) 73 | 74 | self.target_encoder = copy.deepcopy(self.online_encoder) 75 | self.online_predictor = MLP(HPS['projection_size']) 76 | #raise NotImplementedError('Please put update_moving_average to training') 77 | 78 | 79 | self.bilinear = 32 80 | self.conv16 = nn.Conv2d(2048, self.bilinear, kernel_size=1, stride=1, padding=0, bias=False) 81 | self.bn16 = nn.BatchNorm2d(self.bilinear) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.avgpool = nn.AvgPool2d(7, stride=1) 84 | self.avgpool14 = nn.AvgPool2d(14, stride=1) 85 | 86 | 87 | 88 | def target_ema(self, k, K, base_ema=HPS['base_target_ema']): 89 | # tau_base = 0.996 90 | # base_ema = 1 - tau_base = 0.996 91 | return 1 - base_ema * (cos(pi*k/K)+1)/2 92 | # return 1 - (1-self.tau_base) * (cos(pi*k/K)+1)/2 93 | 94 | @torch.no_grad() 95 | def update_moving_average(self): #, global_step, max_steps 96 | #tau = self.target_ema(global_step, max_steps) 97 | tau = 0.996 98 | for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()): 99 | target.data = tau * target.data + (1 - tau) * online.data 100 | 101 | def forward(self, x1, x2): 102 | f_o, h_o = self.online_encoder, self.online_predictor 103 | f_t = self.target_encoder 104 | 105 | 106 | x = self.online_encoder[0].conv1(x1) 107 | x = self.online_encoder[0].bn1(x) 108 | x = self.online_encoder[0].relu(x) 109 | x = self.online_encoder[0].layer1(x) 110 | x = self.online_encoder[0].layer2(x) 111 | x = self.online_encoder[0].layer3(x) 112 | feat_map1 = self.online_encoder[0].layer4(x) 113 | x = self.online_encoder[0].avgpool(feat_map1) 114 | x = x.reshape(x.size(0), -1) 115 | z1_o = self.online_encoder[1](x) 116 | 117 | x = self.online_encoder[0].conv1(x2) 118 | x = self.online_encoder[0].bn1(x) 119 | x = self.online_encoder[0].relu(x) 120 | x = self.online_encoder[0].layer1(x) 121 | x = self.online_encoder[0].layer2(x) 122 | x = self.online_encoder[0].layer3(x) 123 | feat_map2 = self.online_encoder[0].layer4(x) 124 | x = self.online_encoder[0].avgpool(feat_map2) 125 | x = x.reshape(x.size(0), -1) 126 | z2_o = self.online_encoder[1](x) 127 | 128 | 129 | 130 | p1_o = h_o(z1_o) 131 | p2_o = h_o(z2_o) 132 | 133 | with torch.no_grad(): 134 | self.update_moving_average() 135 | z1_t = f_t(x1) 136 | z2_t = f_t(x2) 137 | 138 | L = D(p1_o, z2_t) / 2 + D(p2_o, z1_t) / 2 139 | 140 | 141 | grad_wrt_act1 = torch.autograd.grad(outputs=L, inputs=feat_map1, 142 | grad_outputs=torch.ones_like(L), retain_graph=True, 143 | allow_unused=True)[0] 144 | 145 | gradcam = torch.relu((feat_map1 * grad_wrt_act1).sum(dim=1)) 146 | 147 | featcov16 = self.conv16(feat_map1) 148 | featcov16 = 
self.bn16(featcov16) 149 | featcov16 = self.relu(featcov16) 150 | img, _ = torch.max(featcov16, axis=1) 151 | img = img - torch.min(img) 152 | att_max = img / (1e-7 + torch.max(img)) #batch*7*7 153 | 154 | return {'loss': L, 'attmap': att_max, 'gradcam': gradcam} 155 | 156 | 157 | 158 | def inference(self, img): 159 | x = self.online_encoder[0].conv1(img) 160 | x = self.online_encoder[0].bn1(x) 161 | x = self.online_encoder[0].relu(x) 162 | x = self.online_encoder[0].layer1(x) 163 | x = self.online_encoder[0].layer2(x) 164 | x = self.online_encoder[0].layer3(x) 165 | feat_map1 = self.online_encoder[0].layer4(x) 166 | x = self.online_encoder[0].avgpool(feat_map1) 167 | x = x.reshape(x.size(0), -1) 168 | 169 | featcov16 = self.conv16(feat_map1) 170 | featcov16 = self.bn16(featcov16) 171 | featcov16 = self.relu(featcov16) 172 | 173 | 174 | 175 | img, _ = torch.max(featcov16, axis=1) 176 | img = img - torch.min(img) 177 | att_max = img / (1e-7 + torch.max(img)) 178 | 179 | img = att_max[:, None, :, :] 180 | img = img.repeat(1, 2048, 1, 1) 181 | PFM = feat_map1.cuda() * img.cuda() 182 | aa = self.avgpool14(PFM) 183 | bp_out_feat = aa.view(aa.size(0), -1) 184 | bp_out_feat = nn.functional.normalize(bp_out_feat, dim=1) 185 | 186 | 187 | return x, bp_out_feat 188 | 189 | 190 | 191 | if __name__ == "__main__": 192 | pass 193 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .lars import LARS 2 | from .lars_simclr import LARS_simclr 3 | from .larc import LARC 4 | import torch 5 | from .lr_scheduler import LR_Scheduler 6 | 7 | 8 | def get_optimizer(name, model, lr, momentum, weight_decay): 9 | 10 | predictor_prefix = ('module.predictor', 'predictor') 11 | parameters = [{ 12 | 'name': 'base', 13 | 'params': [param for name, param in model.named_parameters() if not name.startswith(predictor_prefix)], 14 | 'lr': lr 15 | },{ 16 | 'name': 'predictor', 17 | 'params': [param for name, param in model.named_parameters() if name.startswith(predictor_prefix)], 18 | 'lr': lr 19 | }] 20 | if name == 'lars': 21 | optimizer = LARS(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay) 22 | elif name == 'sgd': 23 | optimizer = torch.optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay) 24 | elif name == 'lars_simclr': # Careful 25 | optimizer = LARS_simclr(model.named_modules(), lr=lr, momentum=momentum, weight_decay=weight_decay) 26 | elif name == 'larc': 27 | optimizer = LARC( 28 | torch.optim.SGD( 29 | parameters, 30 | lr=lr, 31 | momentum=momentum, 32 | weight_decay=weight_decay 33 | ), 34 | trust_coefficient=0.001, 35 | clip=False 36 | ) 37 | else: 38 | raise NotImplementedError 39 | return optimizer 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /optimizers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/optimizers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /optimizers/__pycache__/larc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/optimizers/__pycache__/larc.cpython-38.pyc 
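The parameter split in `get_optimizer` above is what lets the predictor receive its own learning-rate entry, separate from the backbone and projector. A toy sketch of the same grouping logic (the `Toy` module and its submodule names are hypothetical):
```python
# Sketch: replicate get_optimizer's base/predictor parameter split.
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.predictor = nn.Linear(8, 8)

model = Toy()
predictor_prefix = ('module.predictor', 'predictor')
base = [p for n, p in model.named_parameters() if not n.startswith(predictor_prefix)]
pred = [p for n, p in model.named_parameters() if n.startswith(predictor_prefix)]
opt = torch.optim.SGD([{'params': base, 'lr': 0.05},
                       {'params': pred, 'lr': 0.05}],
                      momentum=0.9, weight_decay=1e-4)
print(len(base), len(pred))  # -> 2 2 (weight + bias in each group)
```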
-------------------------------------------------------------------------------- /optimizers/__pycache__/lars.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/optimizers/__pycache__/lars.cpython-38.pyc -------------------------------------------------------------------------------- /optimizers/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/optimizers/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /optimizers/larc.py: -------------------------------------------------------------------------------- 1 | """SwAV use larc instead of lars optimizer""" 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.parameter import Parameter 6 | from torch.optim.optimizer import Optimizer 7 | 8 | def main(): # Example 9 | import torchvision 10 | model = torchvision.models.resnet18(pretrained=False) 11 | # optim = torch.optim.Adam(model.parameters(), lr=0.0001) 12 | optim = torch.optim.SGD(model.parameters(),lr=0.2, momentum=0.9, weight_decay=1.5e-6) 13 | optim = LARC(optim) 14 | 15 | class LARC(Optimizer): 16 | """ 17 | :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, 18 | in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive 19 | local learning rate for each individual parameter. The algorithm is designed to improve 20 | convergence of large batch training. 21 | 22 | See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. 23 | In practice it modifies the gradients of parameters as a proxy for modifying the learning rate 24 | of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. 25 | ``` 26 | model = ... 27 | optim = torch.optim.Adam(model.parameters(), lr=...) 28 | optim = LARC(optim) 29 | ``` 30 | It can even be used in conjunction with apex.fp16_utils.FP16_optimizer. 31 | ``` 32 | model = ... 33 | optim = torch.optim.Adam(model.parameters(), lr=...) 34 | optim = LARC(optim) 35 | optim = apex.fp16_utils.FP16_Optimizer(optim) 36 | ``` 37 | Args: 38 | optimizer: Pytorch optimizer to wrap and modify learning rate for. 39 | trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 40 | clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. 
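To make the adaptive rate concrete, here is a small numeric sketch of what `step()` below computes per parameter (all values illustrative):
```python
# Sketch of the LARC update rule (illustrative values, not from the repo).
import torch

p = torch.full((10,), 0.1)   # parameter, ||p|| ~ 0.316
g = torch.full((10,), 1.0)   # gradient,  ||g|| ~ 3.162
lr, trust_coefficient, weight_decay, eps = 0.2, 0.02, 1e-4, 1e-8

adaptive_lr = trust_coefficient * p.norm() / (g.norm() + p.norm() * weight_decay + eps)
print(float(adaptive_lr))    # ~ 0.002: large gradients get scaled down

# clipping mode: rescale so that lr * adaptive_lr equals min(adaptive_lr, lr)
adaptive_lr = min(float(adaptive_lr) / lr, 1.0)
g = (g + weight_decay * p) * adaptive_lr   # gradient modified as a proxy for lr
```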
41 | eps: epsilon kludge to help with numerical stability while calculating adaptive_lr 42 | """ 43 | 44 | def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8): 45 | self.optim = optimizer 46 | self.trust_coefficient = trust_coefficient 47 | self.eps = eps 48 | self.clip = clip 49 | 50 | def __getstate__(self): 51 | return self.optim.__getstate__() 52 | 53 | def __setstate__(self, state): 54 | self.optim.__setstate__(state) 55 | 56 | @property 57 | def state(self): 58 | return self.optim.state 59 | 60 | def __repr__(self): 61 | return self.optim.__repr__() 62 | 63 | @property 64 | def param_groups(self): 65 | return self.optim.param_groups 66 | 67 | @param_groups.setter 68 | def param_groups(self, value): 69 | self.optim.param_groups = value 70 | 71 | def state_dict(self): 72 | return self.optim.state_dict() 73 | 74 | def load_state_dict(self, state_dict): 75 | self.optim.load_state_dict(state_dict) 76 | 77 | def zero_grad(self): 78 | self.optim.zero_grad() 79 | 80 | def add_param_group(self, param_group): 81 | self.optim.add_param_group( param_group) 82 | 83 | def step(self): 84 | with torch.no_grad(): 85 | weight_decays = [] 86 | for group in self.optim.param_groups: 87 | # absorb weight decay control from optimizer 88 | weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 89 | weight_decays.append(weight_decay) 90 | group['weight_decay'] = 0 91 | for p in group['params']: 92 | if p.grad is None: 93 | continue 94 | param_norm = torch.norm(p.data) 95 | grad_norm = torch.norm(p.grad.data) 96 | 97 | if param_norm != 0 and grad_norm != 0: 98 | # calculate adaptive lr + weight decay 99 | adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps) 100 | 101 | # clip learning rate for LARC 102 | if self.clip: 103 | # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` 104 | adaptive_lr = min(adaptive_lr/group['lr'], 1) 105 | 106 | p.grad.data += weight_decay * p.data 107 | p.grad.data *= adaptive_lr 108 | 109 | self.optim.step() 110 | # return weight decay control to optimizer 111 | for i, group in enumerate(self.optim.param_groups): 112 | group['weight_decay'] = weight_decays[i] 113 | 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /optimizers/lars.py: -------------------------------------------------------------------------------- 1 | """ Layer-wise adaptive rate scaling for SGD in PyTorch! """ 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class LARS(Optimizer): 6 | r"""Implements layer-wise adaptive rate scaling for SGD. 7 | 8 | Args: 9 | params (iterable): iterable of parameters to optimize or dicts defining 10 | parameter groups 11 | lr (float): base learning rate (\gamma_0) 12 | momentum (float, optional): momentum factor (default: 0) ("m") 13 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 14 | ("\beta") 15 | eta (float, optional): LARS coefficient 16 | max_epoch: maximum training epoch to determine polynomial LR decay. 17 | 18 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 
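A worked numeric example of the per-layer rate that `step()` below computes, combining the polynomial global decay with the layer-wise local rate (values illustrative):
```python
# Sketch of one LARS step (illustrative values, not from the repo).
import torch

w = torch.full((100,), 0.05)   # weights,   ||w|| = 0.5
g = torch.full((100,), 0.2)    # gradients, ||g|| = 2.0
lr, eta, weight_decay, momentum = 0.1, 1e-3, 5e-4, 0.9
epoch, max_epoch = 10, 200

global_lr = lr * (1 - epoch / max_epoch) ** 2                 # polynomial decay
local_lr = eta * w.norm() / (g.norm() + weight_decay * w.norm())
actual_lr = float(local_lr * global_lr)
buf = torch.zeros_like(w)                                     # momentum buffer
buf.mul_(momentum).add_(g + weight_decay * w, alpha=actual_lr)
w = w - buf
print(actual_lr)   # ~ 2.3e-05: each layer gets its own effective step size
```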
19 | Large Batch Training of Convolutional Networks: 20 | https://arxiv.org/abs/1708.03888 21 | 22 | Example: 23 | >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3) 24 | >>> optimizer.zero_grad() 25 | >>> loss_fn(model(input), target).backward() 26 | >>> optimizer.step() 27 | """ 28 | def __init__(self, params, lr=required, momentum=.9, 29 | weight_decay=.0005, eta=0.001, max_epoch=200): 30 | if lr is not required and lr < 0.0: 31 | raise ValueError("Invalid learning rate: {}".format(lr)) 32 | if momentum < 0.0: 33 | raise ValueError("Invalid momentum value: {}".format(momentum)) 34 | if weight_decay < 0.0: 35 | raise ValueError("Invalid weight_decay value: {}" 36 | .format(weight_decay)) 37 | if eta < 0.0: 38 | raise ValueError("Invalid LARS coefficient value: {}".format(eta)) 39 | 40 | self.epoch = 0 41 | defaults = dict(lr=lr, momentum=momentum, 42 | weight_decay=weight_decay, 43 | eta=eta, max_epoch=max_epoch) 44 | super(LARS, self).__init__(params, defaults) 45 | 46 | def step(self, epoch=None, closure=None): 47 | """Performs a single optimization step. 48 | 49 | Arguments: 50 | closure (callable, optional): A closure that reevaluates the model 51 | and returns the loss. 52 | epoch: current epoch to calculate polynomial LR decay schedule. 53 | if None, uses self.epoch and increments it. 54 | """ 55 | loss = None 56 | if closure is not None: 57 | loss = closure() 58 | 59 | if epoch is None: 60 | epoch = self.epoch 61 | self.epoch += 1 62 | 63 | for group in self.param_groups: 64 | weight_decay = group['weight_decay'] 65 | momentum = group['momentum'] 66 | eta = group['eta'] 67 | lr = group['lr'] 68 | max_epoch = group['max_epoch'] 69 | 70 | for p in group['params']: 71 | if p.grad is None: 72 | continue 73 | 74 | param_state = self.state[p] 75 | d_p = p.grad.data 76 | 77 | weight_norm = torch.norm(p.data) 78 | grad_norm = torch.norm(d_p) 79 | 80 | # Global LR computed on polynomial decay schedule 81 | decay = (1 - float(epoch) / max_epoch) ** 2 82 | global_lr = lr * decay 83 | 84 | # Compute local learning rate for this layer 85 | local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm) 86 | 87 | # Update the momentum term 88 | actual_lr = local_lr * global_lr 89 | 90 | if 'momentum_buffer' not in param_state: 91 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 92 | else: 93 | buf = param_state['momentum_buffer'] 94 | buf.mul_(momentum).add_(d_p + weight_decay * p.data, alpha=actual_lr) 95 | p.data.add_(-buf) 96 | 97 | return loss -------------------------------------------------------------------------------- /optimizers/lars_simclr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.optim.optimizer import Optimizer 4 | import torch.nn as nn 5 | 6 | class LARS_simclr(Optimizer): 7 | def __init__(self, 8 | named_modules, 9 | lr, 10 | momentum=0.9, # beta? 
YES
11 |                  trust_coef=1e-3,
12 |                  weight_decay=1.5e-6,
13 |                  exclude_bias_from_adaption=True):
14 |         '''BYOL: as in SimCLR and the official LARS implementation, we exclude bias and batch-norm weights from the LARS adaptation and weight decay.'''
15 |         defaults = dict(momentum=momentum,
16 |                         lr=lr,
17 |                         weight_decay=weight_decay,
18 |                         trust_coef=trust_coef)
19 |         parameters = self.exclude_from_model(named_modules, exclude_bias_from_adaption)
20 |         super(LARS_simclr, self).__init__(parameters, defaults)
21 | 
22 |     @torch.no_grad()
23 |     def step(self):
24 |         for group in self.param_groups: # only 1 group in most cases
25 |             weight_decay = group['weight_decay']
26 |             momentum = group['momentum']
27 |             lr = group['lr']
28 | 
29 |             trust_coef = group['trust_coef']
30 |             # print(group['name'])
31 |             # eps = group['eps']
32 |             for p in group['params']:
33 |                 # breakpoint()
34 |                 if p.grad is None:
35 |                     continue
36 |                 global_lr = lr
37 |                 velocity = self.state[p].get('velocity', 0)
38 |                 # if name in self.exclude_from_layer_adaptation:
39 |                 if self._use_weight_decay(group):
40 |                     p.grad.data += weight_decay * p.data
41 | 
42 |                 trust_ratio = 1.0
43 |                 if self._do_layer_adaptation(group):
44 |                     w_norm = torch.norm(p.data, p=2)
45 |                     g_norm = torch.norm(p.grad.data, p=2)
46 |                     trust_ratio = trust_coef * w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1.0
47 |                 scaled_lr = global_lr * trust_ratio # trust_ratio is the local_lr
48 |                 next_v = momentum * velocity + scaled_lr * p.grad.data
49 |                 self.state[p]['velocity'] = update = next_v # write the buffer back; the original never stored it, so momentum was a no-op
50 |                 p.data = p.data - update
51 | 
52 | 
53 |     def _use_weight_decay(self, group):
54 |         return False if group['name'] == 'exclude' else True
55 |     def _do_layer_adaptation(self, group):
56 |         return False if group['name'] == 'exclude' else True
57 | 
58 |     def exclude_from_model(self, named_modules, exclude_bias_from_adaption=True):
59 |         base = []
60 |         exclude = []
61 |         for name, module in named_modules:
62 |             if type(module) in [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]:
63 |                 # if isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
64 |                 for name2, param in module.named_parameters():
65 |                     exclude.append(param)
66 |             else:
67 |                 for name2, param in module.named_parameters():
68 |                     if name2 == 'bias':
69 |                         exclude.append(param)
70 |                     elif name2 == 'weight':
71 |                         base.append(param)
72 |                     else:
73 |                         pass # non-leaf modules
74 |         return [{
75 |             'name': 'base',
76 |             'params': base
77 |         },{
78 |             'name': 'exclude',
79 |             'params': exclude
80 |         }] if exclude_bias_from_adaption == True else [{
81 |             'name': 'base',
82 |             'params': base+exclude
83 |         }]
84 | 
85 | if __name__ == "__main__":
86 | 
87 |     resnet = torchvision.models.resnet18(pretrained=False)
88 |     model = resnet
89 | 
90 |     optimizer = LARS_simclr(model.named_modules(), lr=0.1)
91 |     # print()
92 |     # out = optimizer.exclude_from_model(model.named_modules(),exclude_bias_from_adaption=False)
93 |     # print(len(out[0]['params']))
94 |     # exit()
95 | 
96 |     criterion = torch.nn.CrossEntropyLoss()
97 |     for i in range(100):
98 |         model.zero_grad()
99 |         pred = model(torch.randn((2,3,32,32)))
100 |         loss = pred.mean()
101 |         loss.backward()
102 |         optimizer.step()
103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 
-------------------------------------------------------------------------------- /optimizers/lr_scheduler.py: --------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | # from torch.optim.lr_scheduler import _LRScheduler
4 | 
5 | 
6 | class LR_Scheduler(object):
7 |     def __init__(self, optimizer, warmup_epochs, warmup_lr,
num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
8 |         self.base_lr = base_lr
9 |         self.constant_predictor_lr = constant_predictor_lr
10 |         warmup_iter = iter_per_epoch * warmup_epochs
11 |         warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
12 |         decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
13 |         cosine_lr_schedule = final_lr + 0.5 * (base_lr - final_lr) * (1 + np.cos(np.pi * np.arange(decay_iter) / decay_iter))
14 | 
15 |         self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
16 |         self.optimizer = optimizer
17 |         self.iter = 0
18 |         self.current_lr = 0
19 |     def step(self):
20 |         lr = self.lr_schedule[self.iter]  # look up once so `lr` is always bound, even if every group uses the constant predictor lr
21 |         for param_group in self.optimizer.param_groups:
22 |             if self.constant_predictor_lr and param_group['name'] == 'predictor':
23 |                 param_group['lr'] = self.base_lr
24 |             else:
25 |                 param_group['lr'] = lr
26 | 
27 |         self.iter += 1
28 |         self.current_lr = lr
29 |         return lr
30 |     def get_lr(self):
31 |         return self.current_lr
32 | 
33 | if __name__ == "__main__":
34 |     import torchvision
35 |     model = torchvision.models.resnet50()
36 |     optimizer = torch.optim.SGD(model.parameters(), lr=999)
37 |     epochs = 100
38 |     n_iter = 1000
39 |     scheduler = LR_Scheduler(optimizer, 10, 1, epochs, 3, 0, n_iter)
40 |     import matplotlib.pyplot as plt
41 |     lrs = []
42 |     for epoch in range(epochs):
43 |         for it in range(n_iter):
44 |             lr = scheduler.step()
45 |             lrs.append(lr)
46 |     plt.plot(lrs)
47 |     plt.show()
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | tqdm
4 | Pillow
5 | numpy
6 | matplotlib
7 | pyyaml==5.3.1  # the PyPI package is pyyaml (PyYAML); plain "yaml" is not installable
8 | tensorboardx==2.1
9 | 10 | 
-------------------------------------------------------------------------------- /resnet_output.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | #from torchvision._internally_replaced_utils import load_state_dict_from_url
5 | from torch.hub import load_state_dict_from_url
6 | from typing import Type, Any, Callable, Union, List, Optional, Tuple
7 | 
8 | 
9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 |            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
11 |            'wide_resnet50_2', 'wide_resnet101_2']
12 | 
13 | 
14 | model_urls = {
15 |     'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
16 |     'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
17 |     'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
18 |     'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
19 |     'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
20 |     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
21 |     'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
22 |     'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
23 |     'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
24 | }
25 | 
26 | 
27 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
28 |     """3x3 convolution with padding"""
29 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
30 |                      padding=dilation, groups=groups, bias=False, dilation=dilation)
31 | 
32 | 
33 | def conv1x1(in_planes: int,
out_planes: int, stride: int = 1) -> nn.Conv2d: 34 | """1x1 convolution""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion: int = 1 40 | 41 | def __init__( 42 | self, 43 | inplanes: int, 44 | planes: int, 45 | stride: int = 1, 46 | downsample: Optional[nn.Module] = None, 47 | groups: int = 1, 48 | base_width: int = 64, 49 | dilation: int = 1, 50 | norm_layer: Optional[Callable[..., nn.Module]] = None 51 | ) -> None: 52 | super(BasicBlock, self).__init__() 53 | if norm_layer is None: 54 | norm_layer = nn.BatchNorm2d 55 | if groups != 1 or base_width != 64: 56 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 57 | if dilation > 1: 58 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 59 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 60 | self.conv1 = conv3x3(inplanes, planes, stride) 61 | self.bn1 = norm_layer(planes) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.conv2 = conv3x3(planes, planes) 64 | self.bn2 = norm_layer(planes) 65 | self.downsample = downsample 66 | self.stride = stride 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | identity = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | 78 | if self.downsample is not None: 79 | identity = self.downsample(x) 80 | 81 | out += identity 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class Bottleneck(nn.Module): 88 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 89 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 90 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 91 | # This variant is also known as ResNet V1.5 and improves accuracy according to 92 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
93 | 94 | expansion: int = 4 95 | 96 | def __init__( 97 | self, 98 | inplanes: int, 99 | planes: int, 100 | stride: int = 1, 101 | downsample: Optional[nn.Module] = None, 102 | groups: int = 1, 103 | base_width: int = 64, 104 | dilation: int = 1, 105 | norm_layer: Optional[Callable[..., nn.Module]] = None 106 | ) -> None: 107 | super(Bottleneck, self).__init__() 108 | if norm_layer is None: 109 | norm_layer = nn.BatchNorm2d 110 | width = int(planes * (base_width / 64.)) * groups 111 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 112 | self.conv1 = conv1x1(inplanes, width) 113 | self.bn1 = norm_layer(width) 114 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 115 | self.bn2 = norm_layer(width) 116 | self.conv3 = conv1x1(width, planes * self.expansion) 117 | self.bn3 = norm_layer(planes * self.expansion) 118 | self.relu = nn.ReLU(inplace=True) 119 | self.downsample = downsample 120 | self.stride = stride 121 | 122 | def forward(self, x: Tensor) -> Tensor: 123 | identity = x 124 | 125 | out = self.conv1(x) 126 | out = self.bn1(out) 127 | out = self.relu(out) 128 | 129 | out = self.conv2(out) 130 | out = self.bn2(out) 131 | out = self.relu(out) 132 | 133 | out = self.conv3(out) 134 | out = self.bn3(out) 135 | 136 | if self.downsample is not None: 137 | identity = self.downsample(x) 138 | 139 | out += identity 140 | out = self.relu(out) 141 | 142 | return out 143 | 144 | 145 | class ResNet(nn.Module): 146 | 147 | def __init__( 148 | self, 149 | block: Type[Union[BasicBlock, Bottleneck]], 150 | layers: List[int], 151 | num_classes: int = 1000, 152 | zero_init_residual: bool = False, 153 | groups: int = 1, 154 | width_per_group: int = 64, 155 | replace_stride_with_dilation: Optional[List[bool]] = None, 156 | norm_layer: Optional[Callable[..., nn.Module]] = None 157 | ) -> None: 158 | super(ResNet, self).__init__() 159 | if norm_layer is None: 160 | norm_layer = nn.BatchNorm2d 161 | self._norm_layer = norm_layer 162 | 163 | self.inplanes = 64 164 | self.dilation = 1 165 | if replace_stride_with_dilation is None: 166 | # each element in the tuple indicates if we should replace 167 | # the 2x2 stride with a dilated convolution instead 168 | replace_stride_with_dilation = [False, False, False] 169 | if len(replace_stride_with_dilation) != 3: 170 | raise ValueError("replace_stride_with_dilation should be None " 171 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 172 | self.groups = groups 173 | self.base_width = width_per_group 174 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 175 | bias=False) 176 | self.bn1 = norm_layer(self.inplanes) 177 | self.relu = nn.ReLU(inplace=True) 178 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 179 | self.layer1 = self._make_layer(block, 64, layers[0]) 180 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 181 | dilate=replace_stride_with_dilation[0]) 182 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 183 | dilate=replace_stride_with_dilation[1]) 184 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 185 | dilate=replace_stride_with_dilation[2]) 186 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 187 | self.fc = nn.Linear(512 * block.expansion, num_classes) 188 | 189 | for m in self.modules(): 190 | if isinstance(m, nn.Conv2d): 191 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 192 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 193 | nn.init.constant_(m.weight, 1) 
194 |                 nn.init.constant_(m.bias, 0)
195 | 
196 |         # Zero-initialize the last BN in each residual branch,
197 |         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
198 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
199 |         if zero_init_residual:
200 |             for m in self.modules():
201 |                 if isinstance(m, Bottleneck):
202 |                     nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
203 |                 elif isinstance(m, BasicBlock):
204 |                     nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
205 | 
206 |     def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
207 |                     stride: int = 1, dilate: bool = False) -> nn.Sequential:
208 |         norm_layer = self._norm_layer
209 |         downsample = None
210 |         previous_dilation = self.dilation
211 |         if dilate:
212 |             self.dilation *= stride
213 |             stride = 1
214 |         if stride != 1 or self.inplanes != planes * block.expansion:
215 |             downsample = nn.Sequential(
216 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
217 |                 norm_layer(planes * block.expansion),
218 |             )
219 | 
220 |         layers = []
221 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
222 |                             self.base_width, previous_dilation, norm_layer))
223 |         self.inplanes = planes * block.expansion
224 |         for _ in range(1, blocks):
225 |             layers.append(block(self.inplanes, planes, groups=self.groups,
226 |                                 base_width=self.base_width, dilation=self.dilation,
227 |                                 norm_layer=norm_layer))
228 | 
229 |         return nn.Sequential(*layers)
230 | 
231 |     def _forward_impl(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
232 |         # See note [TorchScript super()]
233 |         x = self.conv1(x)
234 |         x = self.bn1(x)
235 |         x = self.relu(x)
236 |         x = self.maxpool(x)
237 | 
238 |         x = self.layer1(x)
239 |         x = self.layer2(x)
240 |         x = self.layer3(x)
241 |         featmap = self.layer4(x)
242 | 
243 |         x = self.avgpool(featmap)
244 |         x = torch.flatten(x, 1)
245 |         y = self.fc(x)
246 | 
247 |         return y, x, featmap  # logits, pooled embedding, and the pre-pooling layer4 feature map
248 | 
249 |     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
250 |         return self._forward_impl(x)
251 | 
252 | 
253 | def _resnet(
254 |     arch: str,
255 |     block: Type[Union[BasicBlock, Bottleneck]],
256 |     layers: List[int],
257 |     pretrained: bool,
258 |     progress: bool,
259 |     **kwargs: Any
260 | ) -> ResNet:
261 |     model = ResNet(block, layers, **kwargs)
262 |     if pretrained:
263 |         state_dict = load_state_dict_from_url(model_urls[arch],
264 |                                               progress=progress)
265 |         model.load_state_dict(state_dict)
266 |     return model
267 | 
268 | 
269 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
270 |     r"""ResNet-18 model from
271 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
272 | 
273 |     Args:
274 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
275 |         progress (bool): If True, displays a progress bar of the download to stderr
276 |     """
277 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
278 |                    **kwargs)
279 | 
280 | 
281 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
282 |     r"""ResNet-34 model from
283 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
284 | 
285 |     Args:
286 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
287 |         progress (bool): If True, displays a progress bar of the download to stderr
288 |     """
289 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
290 |                    **kwargs)
291 | 
292 | 
293 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
294 |     r"""ResNet-50 model from
295 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
296 | 
297 |     Args:
298 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
299 |         progress (bool): If True, displays a progress bar of the download to stderr
300 |     """
301 |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
302 |                    **kwargs)
303 | 
304 | 
305 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
306 |     r"""ResNet-101 model from
307 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
308 | 
309 |     Args:
310 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
311 |         progress (bool): If True, displays a progress bar of the download to stderr
312 |     """
313 |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
314 |                    **kwargs)
315 | 
316 | 
317 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
318 |     r"""ResNet-152 model from
319 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
320 | 
321 |     Args:
322 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
323 |         progress (bool): If True, displays a progress bar of the download to stderr
324 |     """
325 |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
326 |                    **kwargs)
327 | 
328 | 
329 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
330 |     r"""ResNeXt-50 32x4d model from
331 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
332 | 
333 |     Args:
334 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
335 |         progress (bool): If True, displays a progress bar of the download to stderr
336 |     """
337 |     kwargs['groups'] = 32
338 |     kwargs['width_per_group'] = 4
339 |     return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
340 |                    pretrained, progress, **kwargs)
341 | 
342 | 
343 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
344 |     r"""ResNeXt-101 32x8d model from
345 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
346 | 
347 |     Args:
348 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
349 |         progress (bool): If True, displays a progress bar of the download to stderr
350 |     """
351 |     kwargs['groups'] = 32
352 |     kwargs['width_per_group'] = 8
353 |     return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
354 |                    pretrained, progress, **kwargs)
355 | 
356 | 
357 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
358 |     r"""Wide ResNet-50-2 model from
359 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
360 | 
361 |     The model is the same as ResNet except for the bottleneck number of channels,
362 |     which is twice as large in every block. The number of channels in outer 1x1
363 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
364 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
365 | 
366 |     Args:
367 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
368 |         progress (bool): If True, displays a progress bar of the download to stderr
369 |     """
370 |     kwargs['width_per_group'] = 64 * 2
371 |     return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
372 |                    pretrained, progress, **kwargs)
373 | 
374 | 
375 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
376 |     r"""Wide ResNet-101-2 model from
377 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
378 | 
379 |     The model is the same as ResNet except for the bottleneck number of channels,
380 |     which is twice as large in every block. The number of channels in outer 1x1
381 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
382 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
383 | 
384 |     Args:
385 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
386 |         progress (bool): If True, displays a progress bar of the download to stderr
387 |     """
388 |     kwargs['width_per_group'] = 64 * 2
389 |     return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
390 |                    pretrained, progress, **kwargs)
391 | 
-------------------------------------------------------------------------------- /run_all.sh: --------------------------------------------------------------------------------
1 | #export PATH=/home/ubuntu/miniconda3/envs/dl/bin:$PATH
2 | #
3 | 
4 | python main.py --data_dir ./CUB200 --log_dir ./logs/ -c configs/byol_cub200.yaml --ckpt_dir ./.cache/ --hide_progress
5 | python main.py --data_dir ./StanfordCars --log_dir ./logs/ -c configs/byol_stanfordcars.yaml --ckpt_dir ./.cache/ --hide_progress
6 | python main.py --data_dir ./Aircraft --log_dir ./logs/ -c configs/byol_aircrafts.yaml --ckpt_dir ./.cache/ --hide_progress
7 | 
-------------------------------------------------------------------------------- /tools/__init__.py: --------------------------------------------------------------------------------
1 | from .average_meter import AverageMeter
2 | from .accuracy import accuracy
3 | from .knn_monitor import knn_monitor
4 | from .logger import Logger
5 | from .file_exist_fn import file_exist_check
-------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/__init__.cpython-38.pyc
-------------------------------------------------------------------------------- /tools/__pycache__/accuracy.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/accuracy.cpython-38.pyc
-------------------------------------------------------------------------------- /tools/__pycache__/average_meter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/average_meter.cpython-38.pyc
-------------------------------------------------------------------------------- /tools/__pycache__/file_exist_fn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/file_exist_fn.cpython-38.pyc
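A note on `resnet_output.py` above: unlike torchvision's ResNet, its forward pass returns a triple — the classifier logits, the globally pooled embedding, and the pre-pooling `layer4` feature map — so callers must unpack three values. Below is a minimal usage sketch; the batch size, input resolution, and 200-class head are illustrative assumptions, not values fixed by this repo:

```python
# Minimal sketch: consuming the (logits, embedding, featmap) triple
# returned by the models in resnet_output.py. Batch size, input size
# and num_classes below are illustrative assumptions.
import torch
from resnet_output import resnet50

model = resnet50(pretrained=False, num_classes=200)
model.eval()

x = torch.randn(4, 3, 224, 224)
with torch.no_grad():
    logits, embedding, featmap = model(x)

print(logits.shape)     # torch.Size([4, 200])        - classifier output
print(embedding.shape)  # torch.Size([4, 2048])       - pooled features
print(featmap.shape)    # torch.Size([4, 2048, 7, 7]) - layer4 feature map
```

Returning the feature map alongside the logits spares downstream code a second forward pass whenever it needs spatial evidence (e.g. CAM-style maps) in addition to the prediction.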
-------------------------------------------------------------------------------- /tools/__pycache__/knn_monitor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/knn_monitor.cpython-38.pyc
-------------------------------------------------------------------------------- /tools/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/logger.cpython-38.pyc
-------------------------------------------------------------------------------- /tools/__pycache__/plotter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/plotter.cpython-38.pyc
-------------------------------------------------------------------------------- /tools/accuracy.py: --------------------------------------------------------------------------------
1 | import torch  # needed for torch.no_grad(); the import was missing in the original
2 | 
3 | def accuracy(output, target, topk=(1,)):
4 |     """Computes the accuracy over the k top predictions for the specified values of k"""
5 |     with torch.no_grad():
6 |         maxk = max(topk)
7 |         batch_size = target.size(0)
8 | 
9 |         _, pred = output.topk(maxk, 1, True, True)
10 |         pred = pred.t()
11 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
12 | 
13 |         res = []
14 |         for k in topk:
15 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape() also works where view() fails on non-contiguous tensors
16 |             res.append(correct_k.mul_(100.0 / batch_size))
17 |         return res
18 | 
-------------------------------------------------------------------------------- /tools/average_meter.py: --------------------------------------------------------------------------------
1 | class AverageMeter():
2 |     """Computes and stores the average and current value"""
3 |     def __init__(self, name, fmt=':f'):
4 |         self.name = name
5 |         self.fmt = fmt
6 |         self.log = []
7 |         self.val = 0
8 |         self.avg = 0
9 |         self.sum = 0
10 |         self.count = 0
11 | 
12 |     def reset(self):
13 |         self.log.append(self.avg)
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
24 | 
25 |     def __str__(self):
26 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
27 |         return fmtstr.format(**self.__dict__)
28 | 
29 | if __name__ == "__main__":
30 |     meter = AverageMeter('sldk')
31 |     print(meter.log)
32 | 
-------------------------------------------------------------------------------- /tools/file_exist_fn.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | 
5 | def file_exist_check(file_dir):
6 | 
7 |     if os.path.isdir(file_dir):
8 |         for i in range(2, 1000):
9 |             if not os.path.isdir(file_dir + f'({i})'):
10 |                 file_dir += f'({i})'
11 |                 break
12 |     return file_dir
13 | 14 | 15 | 16 | 17 | 18 | 
-------------------------------------------------------------------------------- /tools/knn_monitor.py: --------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch.nn.functional as F
3 | import torch
4 | # code copied from https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb#scrollTo=RI1Y8bSImD7N
5 | # test
using a knn monitor 6 | def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1, hide_progress=False): 7 | net.eval() 8 | classes = len(memory_data_loader.dataset.classes) 9 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, [] 10 | with torch.no_grad(): 11 | # generate feature bank 12 | for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=hide_progress): 13 | feature = net(data.cuda(non_blocking=True)) 14 | feature = F.normalize(feature, dim=1) 15 | feature_bank.append(feature) 16 | # [D, N] 17 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() 18 | # [N] 19 | feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device) 20 | # loop test data to predict the label by weighted knn search 21 | test_bar = tqdm(test_data_loader, desc='kNN', disable=hide_progress) 22 | for data, target in test_bar: 23 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 24 | feature = net(data) 25 | feature = F.normalize(feature, dim=1) 26 | 27 | pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t) 28 | 29 | total_num += data.size(0) 30 | total_top1 += (pred_labels[:, 0] == target).float().sum().item() 31 | test_bar.set_postfix({'Accuracy':total_top1 / total_num * 100}) 32 | return total_top1 / total_num * 100 33 | 34 | # knn monitor as in InstDisc https://arxiv.org/abs/1805.01978 35 | # implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR 36 | def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): 37 | # compute cos similarity between each feature vector and feature bank ---> [B, N] 38 | sim_matrix = torch.mm(feature, feature_bank) 39 | # [B, K] 40 | sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) 41 | # [B, K] 42 | sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices) 43 | sim_weight = (sim_weight / knn_t).exp() 44 | 45 | # counts for each class 46 | one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device) 47 | # [B*K, C] 48 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0) 49 | # weighted score ---> [B, C] 50 | pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1) 51 | 52 | pred_labels = pred_scores.argsort(dim=-1, descending=True) 53 | return pred_labels 54 | -------------------------------------------------------------------------------- /tools/logger.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # try: 3 | # from torch.utils.tensorboard import SummaryWriter 4 | # except ImportError: 5 | from tensorboardX import SummaryWriter 6 | 7 | from torch import Tensor 8 | from collections import OrderedDict 9 | import os 10 | from .plotter import Plotter 11 | 12 | 13 | class Logger(object): 14 | def __init__(self, log_dir, tensorboard=True, matplotlib=True): 15 | 16 | self.reset(log_dir, tensorboard, matplotlib) 17 | 18 | def reset(self, log_dir=None, tensorboard=True, matplotlib=True): 19 | 20 | if log_dir is not None: self.log_dir=log_dir 21 | self.writer = SummaryWriter(log_dir=self.log_dir) if tensorboard else None 22 | self.plotter = Plotter() if matplotlib else None 23 | self.counter = OrderedDict() 24 | 25 | def update_scalers(self, ordered_dict): 26 | 27 | for key, value in ordered_dict.items(): 28 | if 
isinstance(value, Tensor):
29 |                 value = ordered_dict[key] = value.item()  # keep the local in sync so add_scalar below logs a float, not a Tensor
30 |             if self.counter.get(key) is None:
31 |                 self.counter[key] = 1
32 |             else:
33 |                 self.counter[key] += 1
34 | 
35 |             if self.writer:
36 |                 self.writer.add_scalar(key, value, self.counter[key])
37 | 
38 | 
39 |         if self.plotter: 
40 |             self.plotter.update(ordered_dict)
41 |             self.plotter.save(os.path.join(self.log_dir, 'plotter.svg'))
42 | 43 | 44 | 45 | 46 | 47 | 
-------------------------------------------------------------------------------- /tools/plotter.py: --------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg') #https://stackoverflow.com/questions/49921721/runtimeerror-main-thread-is-not-in-main-loop-with-matplotlib-and-flask
3 | import matplotlib.pyplot as plt
4 | from collections import OrderedDict
5 | from torch import Tensor
6 | 
7 | class Plotter(object):
8 |     def __init__(self):
9 |         self.logger = OrderedDict()
10 |     def update(self, ordered_dict):
11 |         for key, value in ordered_dict.items():
12 |             if isinstance(value, Tensor):
13 |                 value = ordered_dict[key] = value.item()  # keep the local in sync so the history stores floats, not Tensors
14 |             if self.logger.get(key) is None:
15 |                 self.logger[key] = [value]
16 |             else:
17 |                 self.logger[key].append(value)
18 | 
19 |     def save(self, file, **kwargs):
20 |         fig, axes = plt.subplots(nrows=len(self.logger), ncols=1, figsize=(8,2*len(self.logger)), squeeze=False)  # squeeze=False keeps axes indexable even with a single metric
21 |         for ax, (key, value) in zip(axes[:, 0], self.logger.items()):
22 |             ax.plot(value)
23 |             ax.set_title(key)
24 |         fig.tight_layout()  # lay out after plotting so titles are accounted for
25 | 
26 |         plt.savefig(file, **kwargs)
27 |         plt.close()
28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 
--------------------------------------------------------------------------------
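Taken together, `Logger` and `Plotter` give the training loop a single entry point for scalar logging: each `update_scalers` call appends one point per metric to TensorBoard and rewrites `plotter.svg` under the log directory. A minimal sketch of that call pattern (the log directory and metric names are illustrative assumptions):

```python
# Minimal sketch of the Logger/Plotter call pattern; the log directory
# and metric names are illustrative assumptions, not repo defaults.
import torch
from collections import OrderedDict
from tools import Logger

logger = Logger(log_dir='./logs/demo', tensorboard=True, matplotlib=True)

for epoch in range(3):
    # Values may be python floats or 0-dim tensors; update_scalers()
    # converts tensors with .item() before logging them.
    logger.update_scalers(OrderedDict([
        ('loss', torch.tensor(1.0 / (epoch + 1))),
        ('lr', 0.03),
    ]))

# After each call, ./logs/demo/plotter.svg holds one subplot per metric,
# and the same scalars are visible under `tensorboard --logdir ./logs/demo`.
```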