├── Dockerfile
├── LICENSE
├── README.md
├── __init__.py
├── __pycache__
│   ├── arguments.cpython-38.pyc
│   ├── linear_cub_eval.cpython-38.pyc
│   ├── linear_eval.cpython-38.pyc
│   └── linear_stanfordcars_eval.cpython-38.pyc
├── arguments.py
├── augmentations
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── byol_aug.cpython-38.pyc
│   │   ├── eval_aug.cpython-38.pyc
│   │   └── gaussian_blur.cpython-38.pyc
│   ├── byol_aug.py
│   ├── eval_aug.py
│   └── gaussian_blur.py
├── classifier.py
├── configs
│   ├── __init__.py
│   ├── byol_aircrafts
│   ├── byol_aircrafts.yaml
│   ├── byol_aircrafts_eval.yaml
│   ├── byol_cifar.yaml
│   ├── byol_cub200.yaml
│   └── byol_stanfordcars.yaml
├── datasets
│   ├── CUB200.py
│   ├── CUB200_val.py
│   ├── CUB2011.py
│   ├── ImageNet100.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── CUB2011.cpython-38.pyc
│   │   ├── ImageNet100.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   └── random_dataset.cpython-38.pyc
│   └── random_dataset.py
├── examples
│   └── framework.png
├── hand_detector.py
├── linear_cub_eval.py
├── linear_eval.py
├── linear_imagenet100_eval.py
├── linear_stanfordcars_eval.py
├── loader.py
├── main.py
├── main_lincls.py
├── main_moco.py
├── moco
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── builder.cpython-38.pyc
│   │   └── loader.cpython-38.pyc
│   ├── builder.py
│   └── loader.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── byol.cpython-38.pyc
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── cifar_resnet_1.cpython-38.pyc
│   │   │   ├── cifar_resnet_2.cpython-38.pyc
│   │   │   └── cub_resnet_1.cpython-38.pyc
│   │   ├── cifar_resnet_1.py
│   │   ├── cifar_resnet_2.py
│   │   └── cub_resnet_1.py
│   └── byol.py
├── optimizers
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── larc.cpython-38.pyc
│   │   ├── lars.cpython-38.pyc
│   │   └── lr_scheduler.cpython-38.pyc
│   ├── larc.py
│   ├── lars.py
│   ├── lars_simclr.py
│   └── lr_scheduler.py
├── requirements.txt
├── resnet_output.py
├── run_all.sh
├── simsiam-800e90.83acc.svg
└── tools
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-38.pyc
    │   ├── accuracy.cpython-38.pyc
    │   ├── average_meter.cpython-38.pyc
    │   ├── file_exist_fn.cpython-38.pyc
    │   ├── knn_monitor.cpython-38.pyc
    │   ├── logger.cpython-38.pyc
    │   └── plotter.cpython-38.pyc
    ├── accuracy.py
    ├── average_meter.py
    ├── file_exist_fn.py
    ├── knn_monitor.py
    ├── logger.py
    └── plotter.py
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime
2 | RUN apt-get update && apt-get install -y git
3 | RUN pip install tensorboardX
4 | RUN pip install ttach
5 | RUN pip install pandas
6 | RUN pip install matplotlib
7 | RUN pip install opencv-python
8 | RUN pip install google-cloud-storage
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Common Rationale to Improve Self-Supervised Representation for Fine-Grained Visual Recognition Problems
2 |
3 | This project contains the implementation of Learning Common Rationale (LCR), a method for improving self-supervised representations on fine-grained visual recognition problems, as presented in our paper:
4 |
5 | > Learning Common Rationale to Improve Self-Supervised Representation for Fine-Grained Visual Recognition Problems,
6 | > Yangyang Shu, Anton van den Hengel and Lingqiao Liu*
7 | > *CVPR 2023*
8 |
9 | ## Datasets
10 | | Dataset | Download Link |
11 | | -- | -- |
12 | | CUB-200-2011 | https://paperswithcode.com/dataset/cub-200-2011 |
13 | | Stanford Cars | http://ai.stanford.edu/~jkrause/cars/car_dataset.html |
14 | | FGVC Aircraft | http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/ |
15 |
16 |
17 | Please download and organize the datasets in this structure:
18 | ```
19 | LCR
20 | ├── CUB200/
21 | │   ├── train/
22 | │   └── test/
23 | ├── StanfordCars/
24 | │   ├── train/
25 | │   └── test/
26 | └── Aircraft/
27 |     ├── train/
28 |     └── test/
29 | ```
30 |
31 | ## For BYOL
32 | Install the required packages:
33 | ```
34 | pip install -r requirements.txt
35 | ```
36 |
37 | - The running commands for several datasets are shown below. You can also refer to ``run_all.sh``.
38 | ```
39 | python main.py --data_dir ./CUB200 --log_dir ./logs/ -c configs/byol_cub200.yaml --ckpt_dir ./.cache/ --hide_progress
40 | python main.py --data_dir ./StanfordCars --log_dir ./logs/ -c configs/byol_stanfordcars.yaml --ckpt_dir ./.cache/ --hide_progress
41 | python main.py --data_dir ./Aircraft --log_dir ./logs/ -c configs/byol_aircrafts.yaml --ckpt_dir ./.cache/ --hide_progress
42 |
43 | ```
44 |
45 | ## For MoCo v2
46 |
47 | - The running commands for pre-training and retrieval:
48 | ```
49 | python main_moco.py --epochs 100 -a resnet50 --lr 0.03 --batch-size 128 --multiprocessing-distributed --world-size 1 --rank 0 Aircraft --mlp --moco-t 0.2 --aug-plus --cos
50 | python main_moco.py --epochs 100 -a resnet50 --lr 0.03 --batch-size 128 --multiprocessing-distributed --world-size 1 --rank 0 StanfordCars --mlp --moco-t 0.2 --aug-plus --cos
51 | python main_moco.py --epochs 100 -a resnet50 --lr 0.03 --batch-size 128 --multiprocessing-distributed --world-size 1 --rank 0 CUB200 --mlp --moco-t 0.2 --aug-plus --cos
52 | ```
53 |
54 | - The running commands for linear probing:
55 | ```
56 | python main_lincls.py -a resnet50 --lr 30.0 --batch-size 256 --pretrained [your checkpoint path]/checkpoint_****.pth.tar --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 Aircraft --class_num 100
57 | python main_lincls.py -a resnet50 --lr 30.0 --batch-size 256 --pretrained [your checkpoint path]/checkpoint_****.pth.tar --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 StanfordCars --class_num 196
58 | python main_lincls.py -a resnet50 --lr 30.0 --batch-size 256 --pretrained [your checkpoint path]/checkpoint_****.pth.tar --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 CUB200 --class_num 200
59 | ```
60 |
61 | ## Citation
62 | If you find this code or idea useful, please cite our work:
63 | ```
64 | @inproceedings{shu2023learning,
65 | title={Learning Common Rationale to Improve Self-Supervised Representation for Fine-Grained Visual Recognition Problems},
66 | author={Shu, Yangyang and van den Hengel, Anton and Liu, Lingqiao},
67 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
68 | pages={11392--11401},
69 | year={2023}
70 | }
71 | ```
72 |
73 |
74 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | ##
2 |
--------------------------------------------------------------------------------
/__pycache__/arguments.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/__pycache__/arguments.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/linear_cub_eval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/__pycache__/linear_cub_eval.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/linear_eval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/__pycache__/linear_eval.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/linear_stanfordcars_eval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/__pycache__/linear_stanfordcars_eval.cpython-38.pyc
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 |
5 | import numpy as np
6 | import random
7 |
8 |
9 | import re
10 | import yaml
11 |
12 | import shutil
13 | import warnings
14 |
15 | from datetime import datetime
16 |
17 |
18 | class Namespace(object):
19 | def __init__(self, somedict):
20 | for key, value in somedict.items():
21 | assert isinstance(key, str) and re.match("[A-Za-z_-]", key)
22 | if isinstance(value, dict):
23 | self.__dict__[key] = Namespace(value)
24 | else:
25 | self.__dict__[key] = value
26 |
27 | def __getattr__(self, attribute):
28 |
29 |         raise AttributeError(f"Cannot find '{attribute}' in namespace. Please add '{attribute}' to your config file (xxx.yaml)!")
30 |
31 |
32 | def set_deterministic(seed):
33 | # seed by default is None
34 | if seed is not None:
35 | print(f"Deterministic with seed = {seed}")
36 | random.seed(seed)
37 | np.random.seed(seed)
38 | torch.manual_seed(seed)
39 | torch.cuda.manual_seed(seed)
40 | torch.backends.cudnn.deterministic = True
41 | torch.backends.cudnn.benchmark = False
42 |
43 | def get_args():
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument('-c', '--config-file', required=True, type=str, help="xxx.yaml")
46 | parser.add_argument('--debug', action='store_true')
47 | parser.add_argument('--debug_subset_size', type=int, default=8)
48 | parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web")
49 | parser.add_argument('--data_dir', type=str, default=os.getenv('DATA'))
50 | parser.add_argument('--log_dir', type=str, default=os.getenv('LOG'))
51 | parser.add_argument('--ckpt_dir', type=str, default=os.getenv('CHECKPOINT'))
52 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
53 | parser.add_argument('--eval_from', type=str, default=None)
54 | parser.add_argument('--hide_progress', action='store_true')
55 | args = parser.parse_args()
56 |
57 |
58 | with open(args.config_file, 'r') as f:
59 | for key, value in Namespace(yaml.load(f, Loader=yaml.FullLoader)).__dict__.items():
60 | vars(args)[key] = value
61 |
62 | if args.debug:
63 | if args.train:
64 | args.train.batch_size = 2
65 | args.train.num_epochs = 1
66 | args.train.stop_at_epoch = 1
67 | if args.eval:
68 | args.eval.batch_size = 2
69 | args.eval.num_epochs = 1 # train only one epoch
70 | args.dataset.num_workers = 0
71 |
72 |
73 |     assert None not in [args.log_dir, args.data_dir, args.ckpt_dir, args.name]
74 |
75 | args.log_dir = os.path.join(args.log_dir, 'in-progress_'+datetime.now().strftime('%m%d%H%M%S_')+args.name)
76 |
77 | os.makedirs(args.log_dir, exist_ok=False)
78 |     print(f'creating log directory {args.log_dir}')
79 | os.makedirs(args.ckpt_dir, exist_ok=True)
80 |
81 | shutil.copy2(args.config_file, args.log_dir)
82 | set_deterministic(args.seed)
83 |
84 |
85 | vars(args)['aug_kwargs'] = {
86 | 'name':args.model.name,
87 | 'image_size': args.dataset.image_size
88 | }
89 | vars(args)['dataset_kwargs'] = {
90 | 'dataset':args.dataset.name,
91 | 'data_dir': args.data_dir,
92 | 'download':args.download,
93 | 'debug_subset_size': args.debug_subset_size if args.debug else None,
94 | }
95 | vars(args)['dataloader_kwargs'] = {
96 | 'drop_last': True,
97 | 'pin_memory': True,
98 | 'num_workers': args.dataset.num_workers,
99 | }
100 |
101 | return args
102 |
--------------------------------------------------------------------------------
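A note on how `get_args` behaves: the YAML config is merged into the argparse namespace, so nested YAML keys become attribute chains such as `args.train.base_lr`. Below is a minimal sketch of that flattening (not part of the repository), using a hypothetical in-memory config instead of a file; `Namespace` mirrors the class above.

```
import yaml

# Mirrors the recursive Namespace wrapper defined in arguments.py
class Namespace:
    def __init__(self, somedict):
        for key, value in somedict.items():
            self.__dict__[key] = Namespace(value) if isinstance(value, dict) else value

# Hypothetical config snippet shaped like configs/byol_cub200.yaml
cfg_text = """
model:
  name: byol
  backbone: resnet50_cub200
train:
  base_lr: 0.3
  batch_size: 128
"""

args = Namespace(yaml.safe_load(cfg_text))
print(args.model.backbone)  # resnet50_cub200
print(args.train.base_lr)   # 0.3
```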
/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | from .eval_aug import Transform_single
2 | from .byol_aug import BYOL_transform
3 |
4 | def get_aug(name='byol', image_size=224, train=True, train_classifier=None):
5 |
6 | if train==True:
7 | if name == 'byol':
8 | augmentation = BYOL_transform(image_size)
9 | else:
10 | raise NotImplementedError
11 | elif train==False:
12 | if train_classifier is None:
13 | raise Exception
14 | augmentation = Transform_single(image_size, train=train_classifier)
15 | else:
16 | raise Exception
17 |
18 | return augmentation
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
--------------------------------------------------------------------------------
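For reference, a short usage sketch of `get_aug` (assuming the repo root is on `PYTHONPATH`): `train=True` yields the two-view BYOL augmentation for pre-training, while `train=False` yields the single-view transform used for linear evaluation, with `train_classifier` switching between its train and test variants.

```
from augmentations import get_aug

# Two-view BYOL augmentation for pre-training
train_aug = get_aug(name='byol', image_size=224, train=True)

# Single-view transforms for the linear-evaluation stage
clf_train_aug = get_aug(name='byol', image_size=224, train=False, train_classifier=True)
clf_test_aug = get_aug(name='byol', image_size=224, train=False, train_classifier=False)
```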
/augmentations/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/augmentations/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/augmentations/__pycache__/byol_aug.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/augmentations/__pycache__/byol_aug.cpython-38.pyc
--------------------------------------------------------------------------------
/augmentations/__pycache__/eval_aug.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/augmentations/__pycache__/eval_aug.cpython-38.pyc
--------------------------------------------------------------------------------
/augmentations/__pycache__/gaussian_blur.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/augmentations/__pycache__/gaussian_blur.cpython-38.pyc
--------------------------------------------------------------------------------
/augmentations/byol_aug.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from PIL import Image, ImageOps
3 | try:
4 | from torchvision.transforms import GaussianBlur
5 | except ImportError:
6 | from .gaussian_blur import GaussianBlur
7 |     transforms.GaussianBlur = GaussianBlur  # patch the fallback into the transforms namespace
8 |
9 | imagenet_norm = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]]
10 |
11 | class BYOL_transform: # Table 6
12 | def __init__(self, image_size, normalize=imagenet_norm):
13 |
14 | self.transform1 = transforms.Compose([
15 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC),
16 | transforms.RandomHorizontalFlip(p=0.5),
17 | transforms.RandomApply([transforms.ColorJitter(0.4,0.4,0.2,0.1)], p=0.8),
18 | transforms.RandomGrayscale(p=0.2),
19 | transforms.GaussianBlur(kernel_size=image_size//20*2+1, sigma=(0.1, 2.0)), # simclr paper gives the kernel size. Kernel size has to be odd positive number with torchvision
20 | transforms.ToTensor(),
21 | transforms.Normalize(*normalize)
22 | ])
23 | self.transform2 = transforms.Compose([
24 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC),
25 | transforms.RandomHorizontalFlip(p=0.5),
26 | transforms.RandomApply([transforms.ColorJitter(0.4,0.4,0.2,0.1)], p=0.8),
27 | transforms.RandomGrayscale(p=0.2),
28 | # transforms.RandomApply([GaussianBlur(kernel_size=int(0.1 * image_size))], p=0.1),
29 | transforms.RandomApply([transforms.GaussianBlur(kernel_size=image_size//20*2+1, sigma=(0.1, 2.0))], p=0.1),
30 | transforms.RandomApply([Solarization()], p=0.2),
31 |
32 | transforms.ToTensor(),
33 | transforms.Normalize(*normalize)
34 | ])
35 |
36 |
37 | def __call__(self, x):
38 | x1 = self.transform1(x)
39 | x2 = self.transform2(x)
40 | return x1, x2
41 |
42 |
43 | class Transform_single:
44 | def __init__(self, image_size, train, normalize=imagenet_norm):
45 |         # self.denormalize = Denormalize(*imagenet_norm)  # Denormalize is undefined here; the active Transform_single lives in eval_aug.py
46 | if train == True:
47 | self.transform = transforms.Compose([
48 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC),
49 | transforms.RandomHorizontalFlip(),
50 | transforms.ToTensor(),
51 | transforms.Normalize(*normalize)
52 | ])
53 | else:
54 | self.transform = transforms.Compose([
55 | transforms.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256
56 | transforms.CenterCrop(image_size),
57 | transforms.ToTensor(),
58 | transforms.Normalize(*normalize)
59 | ])
60 |
61 | def __call__(self, x):
62 | return self.transform(x)
63 |
64 |
65 |
66 | class Solarization():
67 | # ImageFilter
68 | def __init__(self, threshold=128):
69 | self.threshold = threshold
70 | def __call__(self, image):
71 | return ImageOps.solarize(image, self.threshold)
72 |
73 |
74 |
--------------------------------------------------------------------------------
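`BYOL_transform` applies two asymmetric pipelines (transform1 always blurs; transform2 blurs with p=0.1 and solarizes with p=0.2, following the BYOL paper) and returns a pair of views. A minimal sketch with a synthetic stand-in image:

```
from PIL import Image
from augmentations.byol_aug import BYOL_transform

transform = BYOL_transform(image_size=224)
img = Image.new('RGB', (256, 256))  # stand-in for a dataset image
x1, x2 = transform(img)             # two independently augmented views
print(x1.shape, x2.shape)           # torch.Size([3, 224, 224]) for both
```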
/augmentations/eval_aug.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from PIL import Image
3 |
4 | imagenet_norm = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]]
5 |
6 | class Transform_single():
7 | def __init__(self, image_size, train, normalize=imagenet_norm):
8 | if train == True:
9 | self.transform = transforms.Compose([
10 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC),
11 | transforms.RandomHorizontalFlip(),
12 | transforms.ToTensor(),
13 | transforms.Normalize(*normalize)
14 | ])
15 | else:
16 | self.transform = transforms.Compose([
17 | transforms.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256
18 | transforms.CenterCrop(image_size),
19 | transforms.ToTensor(),
20 | transforms.Normalize(*normalize)
21 | ])
22 |
23 | def __call__(self, x):
24 | return self.transform(x)
25 |
--------------------------------------------------------------------------------
/augmentations/gaussian_blur.py:
--------------------------------------------------------------------------------
1 | """
2 | Only recent torchvision releases ship a GaussianBlur transform,
3 | so the relevant functions are copied here as a fallback.
4 |
5 | """
6 |
7 |
8 | import torch
9 | from torch import Tensor
10 | from torchvision.transforms.functional import to_pil_image, to_tensor
11 | from torch.nn.functional import conv2d, pad as torch_pad
12 | from typing import Any, List, Sequence, Optional
13 | import numbers
14 | import numpy as np
15 | import torch
16 | from PIL import Image
17 | from typing import Tuple
18 |
19 | class GaussianBlur(torch.nn.Module):
20 | """Blurs image with randomly chosen Gaussian blur.
21 | The image can be a PIL Image or a Tensor, in which case it is expected
22 | to have [..., C, H, W] shape, where ... means an arbitrary number of leading
23 | dimensions
24 |
25 | Args:
26 | kernel_size (int or sequence): Size of the Gaussian kernel.
27 | sigma (float or tuple of float (min, max)): Standard deviation to be used for
28 | creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
29 | of float (min, max), sigma is chosen uniformly at random to lie in the
30 | given range.
31 |
32 | Returns:
33 | PIL Image or Tensor: Gaussian blurred version of the input image.
34 |
35 | """
36 |
37 | def __init__(self, kernel_size, sigma=(0.1, 2.0)):
38 | super().__init__()
39 | self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
40 | for ks in self.kernel_size:
41 | if ks <= 0 or ks % 2 == 0:
42 | raise ValueError("Kernel size value should be an odd and positive number.")
43 |
44 | if isinstance(sigma, numbers.Number):
45 | if sigma <= 0:
46 | raise ValueError("If sigma is a single number, it must be positive.")
47 | sigma = (sigma, sigma)
48 | elif isinstance(sigma, Sequence) and len(sigma) == 2:
49 | if not 0. < sigma[0] <= sigma[1]:
50 | raise ValueError("sigma values should be positive and of the form (min, max).")
51 | else:
52 | raise ValueError("sigma should be a single number or a list/tuple with length 2.")
53 |
54 | self.sigma = sigma
55 |
56 | @staticmethod
57 | def get_params(sigma_min: float, sigma_max: float) -> float:
58 | """Choose sigma for random gaussian blurring.
59 |
60 | Args:
61 | sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
62 | sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.
63 |
64 | Returns:
65 | float: Standard deviation to be passed to calculate kernel for gaussian blurring.
66 | """
67 | return torch.empty(1).uniform_(sigma_min, sigma_max).item()
68 |
69 | def forward(self, img: Tensor) -> Tensor:
70 | """
71 | Args:
72 | img (PIL Image or Tensor): image to be blurred.
73 |
74 | Returns:
75 | PIL Image or Tensor: Gaussian blurred image
76 | """
77 | sigma = self.get_params(self.sigma[0], self.sigma[1])
78 | return gaussian_blur(img, self.kernel_size, [sigma, sigma])
79 |
80 | def __repr__(self):
81 | s = '(kernel_size={}, '.format(self.kernel_size)
82 | s += 'sigma={})'.format(self.sigma)
83 | return self.__class__.__name__ + s
84 |
85 | @torch.jit.unused
86 | def _is_pil_image(img: Any) -> bool:
87 | return isinstance(img, Image.Image)
88 | def _setup_size(size, error_msg):
89 | if isinstance(size, numbers.Number):
90 | return int(size), int(size)
91 |
92 | if isinstance(size, Sequence) and len(size) == 1:
93 | return size[0], size[0]
94 |
95 | if len(size) != 2:
96 | raise ValueError(error_msg)
97 |
98 | return size
99 | def _is_tensor_a_torch_image(x: Tensor) -> bool:
100 | return x.ndim >= 2
101 | def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
102 | ksize_half = (kernel_size - 1) * 0.5
103 |
104 | x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
105 | pdf = torch.exp(-0.5 * (x / sigma).pow(2))
106 | kernel1d = pdf / pdf.sum()
107 |
108 | return kernel1d
109 |
110 | def _cast_squeeze_in(img: Tensor, req_dtype: torch.dtype) -> Tuple[Tensor, bool, bool, torch.dtype]:
111 | need_squeeze = False
112 | # make image NCHW
113 | if img.ndim < 4:
114 | img = img.unsqueeze(dim=0)
115 | need_squeeze = True
116 |
117 | out_dtype = img.dtype
118 | need_cast = False
119 | if out_dtype != req_dtype:
120 | need_cast = True
121 | img = img.to(req_dtype)
122 | return img, need_cast, need_squeeze, out_dtype
123 | def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype):
124 | if need_squeeze:
125 | img = img.squeeze(dim=0)
126 |
127 | if need_cast:
128 | # it is better to round before cast
129 | img = torch.round(img).to(out_dtype)
130 |
131 | return img
132 | def _get_gaussian_kernel2d(
133 | kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
134 | ) -> Tensor:
135 | kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
136 | kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
137 | kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
138 | return kernel2d
139 | def _gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
140 | """PRIVATE METHOD. Performs Gaussian blurring on the img by given kernel.
141 |
142 | .. warning::
143 |
144 | Module ``transforms.functional_tensor`` is private and should not be used in user application.
145 | Please, consider instead using methods from `transforms.functional` module.
146 |
147 | Args:
148 | img (Tensor): Image to be blurred
149 | kernel_size (sequence of int or int): Kernel size of the Gaussian kernel ``(kx, ky)``.
150 | sigma (sequence of float or float, optional): Standard deviation of the Gaussian kernel ``(sx, sy)``.
151 |
152 | Returns:
153 | Tensor: An image that is blurred using gaussian kernel of given parameters
154 | """
155 | if not (isinstance(img, torch.Tensor) or _is_tensor_a_torch_image(img)):
156 | raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
157 |
158 | dtype = img.dtype if torch.is_floating_point(img) else torch.float32
159 | kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
160 | kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
161 |
162 | img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, kernel.dtype)
163 |
164 | # padding = (left, right, top, bottom)
165 | padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
166 | img = torch_pad(img, padding, mode="reflect")
167 | img = conv2d(img, kernel, groups=img.shape[-3])
168 |
169 | img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
170 | return img
171 |
172 | def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
173 | """Performs Gaussian blurring on the img by given kernel.
174 | The image can be a PIL Image or a Tensor, in which case it is expected
175 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
176 |
177 | Args:
178 | img (PIL Image or Tensor): Image to be blurred
179 | kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
180 | like ``(kx, ky)`` or a single integer for square kernels.
181 | In torchscript mode kernel_size as single int is not supported, use a tuple or
182 | list of length 1: ``[ksize, ]``.
183 | sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
184 | sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
185 | same sigma in both X/Y directions. If None, then it is computed using
186 | ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
187 | Default, None. In torchscript mode sigma as single float is
188 | not supported, use a tuple or list of length 1: ``[sigma, ]``.
189 |
190 | Returns:
191 | PIL Image or Tensor: Gaussian Blurred version of the image.
192 | """
193 | if not isinstance(kernel_size, (int, list, tuple)):
194 | raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size)))
195 | if isinstance(kernel_size, int):
196 | kernel_size = [kernel_size, kernel_size]
197 | if len(kernel_size) != 2:
198 | raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size)))
199 | for ksize in kernel_size:
200 | if ksize % 2 == 0 or ksize < 0:
201 | raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size))
202 |
203 | if sigma is None:
204 | sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
205 |
206 | if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
207 | raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma)))
208 | if isinstance(sigma, (int, float)):
209 | sigma = [float(sigma), float(sigma)]
210 | if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
211 | sigma = [sigma[0], sigma[0]]
212 | if len(sigma) != 2:
213 | raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma)))
214 | for s in sigma:
215 | if s <= 0.:
216 | raise ValueError('sigma should have positive values. Got {}'.format(sigma))
217 |
218 | t_img = img
219 | if not isinstance(img, torch.Tensor):
220 | if not _is_pil_image(img):
221 | raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img)))
222 |
223 | t_img = to_tensor(img)
224 |
225 | output = _gaussian_blur(t_img, kernel_size, sigma)
226 |
227 | if not isinstance(img, torch.Tensor):
228 | output = to_pil_image(output)
229 | return output
230 |
231 |
232 |
233 |
234 | # if __name__ == "__main__":
235 | # gaussian_blur = GaussianBlur(kernel_size=23)
--------------------------------------------------------------------------------
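This fallback mirrors torchvision's `transforms.GaussianBlur`: a sigma is drawn uniformly from the given range on every call, and a separable Gaussian kernel is convolved over the image. A quick sketch on a random tensor:

```
import torch
from augmentations.gaussian_blur import GaussianBlur

blur = GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))
img = torch.rand(3, 224, 224)  # random CHW image tensor
out = blur(img)                # sigma is resampled on each call
print(out.shape)               # torch.Size([3, 224, 224])
```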
/classifier.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class Classifier(nn.Module):
4 | def __init__(self, inputs, class_num):
5 | super(Classifier, self).__init__()
6 | self.classifier_layer = nn.Linear(inputs, class_num)
7 | self.classifier_layer.weight.data.normal_(0, 0.01)
8 | self.classifier_layer.bias.data.fill_(0.0)
9 |
10 | def forward(self, x):
11 | return self.classifier_layer(x)
--------------------------------------------------------------------------------
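`Classifier` is the linear probe head trained on top of frozen backbone features. A minimal sketch (the 2048-d ResNet-50 feature size and 200 CUB classes are illustrative choices):

```
import torch
from classifier import Classifier

head = Classifier(inputs=2048, class_num=200)
feats = torch.randn(8, 2048)  # a batch of backbone features
logits = head(feats)
print(logits.shape)           # torch.Size([8, 200])
```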
/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/configs/__init__.py
--------------------------------------------------------------------------------
/configs/byol_aircrafts:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/configs/byol_aircrafts
--------------------------------------------------------------------------------
/configs/byol_aircrafts.yaml:
--------------------------------------------------------------------------------
1 | name: byol-aircrafts-resnet50-experiment
2 | dataset:
3 | name: aircrafts
4 | image_size: 224
5 | num_workers: 8
6 |
7 | model:
8 | name: byol
9 | backbone: resnet50_aircrafts
10 |
11 | train:
12 | optimizer:
13 | name: lars_simclr
14 | weight_decay: 1.5e-6
15 | momentum: 0.9
16 | warmup_epochs: 10
17 | warmup_lr: 0
18 | base_lr: 0.3
19 | final_lr: 0
20 |     num_epochs: 100 # this parameter influences the lr decay
21 |     stop_at_epoch: 100 # must not exceed num_epochs
22 | batch_size: 128
23 | knn_monitor: False # knn monitor will take more time
24 | knn_interval: 3
25 | knn_k: 200
26 | eval: # linear evaluation, False will turn off automatic evaluation after training
27 | optimizer:
28 | name: sgd
29 | weight_decay: 0
30 | momentum: 0.9
31 | warmup_lr: 0
32 | warmup_epochs: 0
33 | base_lr: 30
34 | final_lr: 0
35 | batch_size: 128
36 | num_epochs: 100
37 |
38 | logger:
39 | tensorboard: True
40 | matplotlib: True
41 |
42 | seed: null # None type for yaml file
43 | # two things might lead to stochastic behavior other than seed:
44 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
45 | # (keep this in mind if you want to achieve 100% deterministic)
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/configs/byol_aircrafts_eval.yaml:
--------------------------------------------------------------------------------
1 | name: byol-aircrafts-resnet50-experiment
2 | dataset:
3 | name: aircrafts
4 | image_size: 224
5 | num_workers: 4
6 |
7 | model:
8 | name: byol
9 | backbone: resnet50_aircrafts
10 |
11 | train: null
12 | eval: # linear evaluation, False will turn off automatic evaluation after training
13 | optimizer:
14 | name: sgd
15 | weight_decay: 0
16 | momentum: 0.9
17 | warmup_lr: 0
18 | warmup_epochs: 0
19 | base_lr: 30
20 | final_lr: 0
21 | batch_size: 32
22 | num_epochs: 30
23 |
24 | logger:
25 | tensorboard: True
26 | matplotlib: True
27 |
28 | seed: null # None type for yaml file
29 | # two things might lead to stochastic behavior other than seed:
30 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
31 | # (keep this in mind if you want to achieve 100% deterministic)
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/configs/byol_cifar.yaml:
--------------------------------------------------------------------------------
1 | name: byol-cifar10-experiment
2 | dataset:
3 | name: cifar10
4 | image_size: 32
5 | num_workers: 4
6 |
7 | model:
8 | name: byol
9 | backbone: resnet18_cifar_variant1
10 |
11 | train:
12 | optimizer:
13 | name: lars_simclr
14 | weight_decay: 1.5e-6
15 | warmup_epochs: 10
16 | warmup_lr: 0
17 | base_lr: 0.3
18 | final_lr: 0
19 |     num_epochs: 800 # this parameter influences the lr decay
20 | stop_at_epoch: 100 # has to be smaller than num_epochs
21 | batch_size: 256
22 | knn_monitor: False # knn monitor will take more time
23 | knn_interval: 1
24 | knn_k: 200
25 | eval: # linear evaluation, False will turn off automatic evaluation after training
26 | optimizer:
27 | name: sgd
28 | weight_decay: 0
29 | momentum: 0.9
30 | warmup_lr: 0
31 | warmup_epochs: 0
32 | base_lr: 30
33 | final_lr: 0
34 | batch_size: 256
35 | num_epochs: 30
36 |
37 | logger:
38 | tensorboard: True
39 | matplotlib: True
40 |
41 | seed: null # None type for yaml file
42 | # two things might lead to stochastic behavior other than seed:
43 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
44 | # (keep this in mind if you want to achieve 100% deterministic)
45 |
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/configs/byol_cub200.yaml:
--------------------------------------------------------------------------------
1 | name: byol-cub200-experiment-resnet50_cub200_variant1
2 | dataset:
3 | name: cub200
4 | image_size: 224
5 | num_workers: 8
6 |
7 | model:
8 | name: byol
9 | backbone: resnet50_cub200
10 |
11 | train:
12 | optimizer:
13 | name: lars_simclr
14 | weight_decay: 1.5e-6
15 | momentum: 0.9
16 | warmup_epochs: 10
17 | warmup_lr: 0
18 | base_lr: 0.3
19 | final_lr: 0
20 |     num_epochs: 800 # this parameter influences the lr decay
21 | stop_at_epoch: 100 # has to be smaller than num_epochs
22 | batch_size: 128
23 | knn_monitor: True # knn monitor will take more time
24 | knn_interval: 1
25 | knn_k: 200
26 | eval: # linear evaluation, False will turn off automatic evaluation after training
27 | optimizer:
28 | name: sgd
29 | weight_decay: 0
30 | momentum: 0.9
31 | warmup_lr: 0
32 | warmup_epochs: 0
33 | base_lr: 30
34 | final_lr: 0
35 | batch_size: 128
36 | num_epochs: 100
37 |
38 | logger:
39 | tensorboard: True
40 | matplotlib: True
41 |
42 | seed: null # None type for yaml file
43 | # two things might lead to stochastic behavior other than seed:
44 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
45 | # (keep this in mind if you want to achieve 100% deterministic)
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/configs/byol_stanfordcars.yaml:
--------------------------------------------------------------------------------
1 | name: byol-stanfordcars-resnet50-experiment
2 | dataset:
3 | name: stanfordcars
4 | image_size: 224
5 | num_workers: 4
6 |
7 | model:
8 | name: byol
9 | backbone: resnet50_stanfordcars
10 |
11 | train:
12 | optimizer:
13 | name: lars_simclr
14 | weight_decay: 1.5e-6
15 | momentum: 0.9
16 | warmup_epochs: 10
17 | warmup_lr: 0
18 | base_lr: 0.3
19 | final_lr: 0
20 |     num_epochs: 100 # this parameter influences the lr decay
21 |     stop_at_epoch: 100 # must not exceed num_epochs
22 | batch_size: 128
23 | knn_monitor: False # knn monitor will take more time
24 | knn_interval: 3
25 | knn_k: 200
26 | eval: # linear evaluation, False will turn off automatic evaluation after training
27 | optimizer:
28 | name: sgd
29 | weight_decay: 0
30 | momentum: 0.9
31 | warmup_lr: 0
32 | warmup_epochs: 0
33 | base_lr: 30
34 | final_lr: 0
35 | batch_size: 128
36 | num_epochs: 100
37 |
38 | logger:
39 | tensorboard: True
40 | matplotlib: True
41 |
42 | seed: null # None type for yaml file
43 | # two things might lead to stochastic behavior other than seed:
44 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
45 | # (keep this in mind if you want to achieve 100% deterministic)
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/datasets/CUB200.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | # Reading data
3 | import scipy.misc
4 | from matplotlib.pyplot import imread
5 | import cv2
6 | import os
7 | from PIL import Image
8 | from torchvision import transforms
9 | import torch
10 |
11 | class CUB():
12 | def __init__(self, root, is_train=True, data_len=None,transform=None, target_transform=None):
13 | self.root = root
14 | self.is_train = is_train
15 | self.transform = transform
16 | self.target_transform = target_transform
17 | img_txt_file = open(os.path.join(self.root, 'images.txt'))
18 | label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
19 | train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
20 | train_class_file = open(os.path.join(self.root, 'classes.txt'))
21 |         # Image file names
22 | img_name_list = []
23 | for line in img_txt_file:
24 | # The last character is a newline character
25 | img_name_list.append(line[:-1].split(' ')[-1])
26 |
27 |         # Label index: subtract 1 from every label so class labels start at 0
28 | label_list = []
29 | for line in label_txt_file:
30 | label_list.append(int(line[:-1].split(' ')[-1]) - 1)
31 | train_class_list=[]
32 | for line in train_class_file:
33 | train_class_list.append(line[:-1].split('.')[-1] )
34 |
35 | # Set up training and test sets
36 | train_test_list = []
37 | for line in train_val_file:
38 | train_test_list.append(int(line[:-1].split(' ')[-1]))
39 |
40 |         # zip() pairs each image name with its train/test flag so the two
41 |         # lists can be filtered together, tuple by tuple, without copies.
42 |         # A flag of 1 marks a training image; 0 marks a test image.
43 |
49 | train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
50 | test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
51 |
52 | train_label_list = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
53 | test_label_list = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
54 | if self.is_train:
55 |             # read_pic (cv2.imread) returns each image as a numpy array
56 | self.train_img = [self.read_pic(os.path.join(self.root, 'images', train_file)) for train_file in
57 | train_file_list[:data_len]]
58 | # Read the training set label
59 | self.train_label = train_label_list
60 | self.targets=train_label_list
61 | self.classes=train_class_list
62 | if not self.is_train:
63 | self.test_img = [self.read_pic(os.path.join(self.root, 'images', test_file)) for test_file in
64 | test_file_list[:data_len]]
65 | self.test_label = test_label_list
66 | self.targets = test_label_list
67 | self.classes = train_class_list
68 |
69 |     # Fetch one sample and apply the data augmentation
70 | def __getitem__(self,index):
71 | # Training set
72 | if self.is_train:
73 | img, target = self.train_img[index], self.train_label[index]
74 | # Test set
75 | else:
76 | img, target = self.test_img[index], self.test_label[index]
77 |
78 | if len(img.shape) == 2:
79 | # Gray images are converted to three channels
80 | img = np.stack([img]*3,2)
81 |         # Convert to an RGB PIL image
82 | img = Image.fromarray(img,mode='RGB')
83 | if self.transform is not None:
84 | img = self.transform(img)
85 |
86 | if self.target_transform is not None:
87 | target = self.target_transform(target)
88 |
89 | return img, target
90 |
91 | def __len__(self):
92 | if self.is_train:
93 | return len(self.train_label)
94 | else:
95 | return len(self.test_label)
96 |
97 | def read_pic(self, train_file ):
98 | try:
99 |             img = cv2.imread(train_file)
100 |             return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # cv2 loads BGR; convert to RGB
101 | except Exception as e:
102 | print(train_file)
103 | print(e)
104 |
105 | if __name__ == '__main__':
106 |     # Example: dataset = CUB(root='./CUB_200_2011'); for data in dataset: print(data[0].size(), data[1])
107 |     # Read the dataset with a PyTorch DataLoader
108 | transform_train = transforms.Compose([
109 | transforms.Resize((224, 224)),
110 | transforms.RandomCrop(224, padding=4),
111 | transforms.RandomHorizontalFlip(),
112 | transforms.ToTensor(),
113 | transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
114 | ])
115 |
116 | dataset = CUB(root='./CUB_200_2011', is_train=False, transform=transform_train,)
117 | print(len(dataset))
118 | trainloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0,
119 | drop_last=True)
120 | print(len(trainloader))
--------------------------------------------------------------------------------
/datasets/CUB200_val.py:
--------------------------------------------------------------------------------
1 | # Reading data
2 | import os
3 | import pathlib
4 | import shutil
5 |
6 | class CUB():
7 | def __init__(self, root, is_train=True):
8 | self.root = root
9 | self.is_train = is_train
10 |
11 | def convert_bird(data_root):
12 | images_txt = os.path.join(data_root, 'images.txt')
13 | train_val_txt = os.path.join(data_root, 'train_test_split.txt')
14 | labels_txt = os.path.join(data_root, 'image_class_labels.txt')
15 | image_folder=os.path.join(data_root, 'images')
16 | id_name_dict = {}
17 | id_class_dict = {}
18 | id_train_val = {}
19 | with open(images_txt, 'r', encoding='utf-8') as f:
20 | line = f.readline()
21 | while line:
22 | id, name = line.strip().split()
23 | id_name_dict[id] = name
24 | line = f.readline()
25 |
26 | with open(train_val_txt, 'r', encoding='utf-8') as f:
27 | line = f.readline()
28 | while line:
29 | id, trainval = line.strip().split()
30 | id_train_val[id] = trainval
31 | line = f.readline()
32 |
33 | with open(labels_txt, 'r', encoding='utf-8') as f:
34 | line = f.readline()
35 | while line:
36 | id, class_id = line.strip().split()
37 | id_class_dict[id] = int(class_id)
38 | line = f.readline()
39 |
40 | train_txt = os.path.join(data_root, 'bird_train.txt')
41 | test_txt = os.path.join(data_root, 'bird_test.txt')
42 |
43 | train_folder=os.path.join(data_root, 'train')
44 | test_folder = os.path.join(data_root, 'test')
45 | if os.path.exists(train_txt):
46 | os.remove(train_txt)
47 | if os.path.exists(test_txt):
48 | os.remove(test_txt)
49 | if os.path.exists(train_folder):
50 |             shutil.rmtree(train_folder)  # os.remove fails on directories
51 | if os.path.exists(test_folder):
52 |             shutil.rmtree(test_folder)
53 |
54 | f1 = open(train_txt, 'a', encoding='utf-8')
55 | f2 = open(test_txt, 'a', encoding='utf-8')
56 |
57 | for id, trainval in id_train_val.items():
58 | if trainval == '1':
59 | src_path=os.path.join(image_folder, id_name_dict[id])
60 | dst_path=os.path.join(train_folder,str(id_class_dict[id] - 1))
61 | pathlib.Path(dst_path).mkdir(parents=True, exist_ok=True)
62 | shutil.copy(src_path,os.path.join(dst_path,os.path.basename(src_path)))
63 | f1.write('%s %d\n' % (id_name_dict[id], id_class_dict[id] - 1))
64 | else:
65 |
66 | src_path=os.path.join(image_folder, id_name_dict[id])
67 | dst_path=os.path.join(test_folder,str(id_class_dict[id] - 1))
68 | pathlib.Path(dst_path).mkdir(parents=True, exist_ok=True)
69 | shutil.copy(src_path,os.path.join(dst_path,os.path.basename(src_path)))
70 | f2.write('%s %d\n' % (id_name_dict[id], id_class_dict[id] - 1))
71 | f1.close()
72 | f2.close()
73 |
74 | convert_bird(self.root)
--------------------------------------------------------------------------------
/datasets/CUB2011.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from torchvision.datasets.folder import default_loader
4 | from torchvision.datasets.utils import download_url
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class Cub2011(Dataset):
9 | base_folder = 'CUB_200_2011/images'
10 | url = 'https://data.caltech.edu/tindfiles/serve/1239ea37-e132-42ee-8c09-c383bb54e7ff/' #''http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
11 |
12 |
13 | filename = 'CUB_200_2011.tgz'
14 | tgz_md5 = '97eceeb196236b17998738112f37df78'
15 |
16 | def __init__(self, root, train=True, transform=None, loader=default_loader, download=True):
17 | self.root = os.path.expanduser(root)
18 | self.transform = transform
19 |         self.loader = loader  # honor the loader argument instead of hard-coding default_loader
20 | self.train = train
21 | self.classes=None
22 | self.targets = None
23 | if download:
24 | self._download()
25 |
26 | if not self._check_integrity():
27 | raise RuntimeError('Dataset not found or corrupted.' +
28 | ' You can use download=True to download it')
29 |
30 | def _load_metadata(self):
31 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
32 | names=['img_id', 'filepath'])
33 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
34 | sep=' ', names=['img_id', 'target'])
35 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
36 | sep=' ', names=['img_id', 'is_training_img'])
37 | image_classes=pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'),
38 | sep=' ', names=['cls_id', 'cls_name'])
39 |
40 | self.classes = image_classes['cls_name'].tolist()#image_classes.merge(classes, on='cls_id')
41 | data = images.merge(image_class_labels, on='img_id')
42 | self.data = data.merge(train_test_split, on='img_id')
43 |
44 | if self.train:
45 | self.data = self.data[self.data.is_training_img == 1]
46 | self.targets = self.data.target.tolist()
47 | self.targets=[x-1 for x in self.targets ]
48 | else:
49 | self.data = self.data[self.data.is_training_img == 0]
50 | self.targets = self.data.target.tolist()
51 | self.targets = [x - 1 for x in self.targets]
52 |
53 | def _check_integrity(self):
54 | try:
55 | self._load_metadata()
56 | except Exception:
57 | return False
58 |
59 | for index, row in self.data.iterrows():
60 | filepath = os.path.join(self.root, self.base_folder, row.filepath)
61 | if not os.path.isfile(filepath):
62 | print(filepath)
63 | return False
64 | return True
65 |
66 | def _download(self):
67 | import tarfile
68 |
69 | if self._check_integrity():
70 | print('Files already downloaded and verified')
71 | return
72 |
73 | download_url(self.url, self.root, self.filename, self.tgz_md5)
74 |
75 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
76 | tar.extractall(path=self.root)
77 |
78 | def __len__(self):
79 | return len(self.data)
80 |
81 | def __getitem__(self, idx):
82 | sample = self.data.iloc[idx]
83 | path = os.path.join(self.root, self.base_folder, sample.filepath)
84 | target = sample.target-1 # Targets start at 1 by default, so shift to 0
85 | img = self.loader(path)
86 |
87 | if self.transform is not None:
88 | img = self.transform(img)
89 |
90 | return img, target
--------------------------------------------------------------------------------
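A usage sketch for `Cub2011` (not part of the repository), assuming `CUB_200_2011/` has already been extracted under a hypothetical `./data` directory, with `download=False` to skip the fetch:

```
import torch
from torchvision import transforms
from datasets.CUB2011 import Cub2011

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = Cub2011(root='./data', train=True, transform=transform, download=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
images, targets = next(iter(loader))
print(images.shape)  # torch.Size([32, 3, 224, 224]); targets are 0-indexed
```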
/datasets/ImageNet100.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from torchvision.datasets.folder import default_loader
4 | from torchvision.datasets.utils import download_url
5 | from torch.utils.data import Dataset
6 | import torch
7 | import torchvision
8 | #https://blog.csdn.net/Andrew_SJ/article/details/111335319
9 | class ImageNet100():
10 | def __init__(self, data_dir, train=True, transform=None):
11 | self.root = os.path.expanduser(data_dir)
12 | self.transform = transform
13 | #self.loader = default_loader
14 | self.train = train
15 | # self.classes=None
16 | # self.targets = None
17 | l=[]
18 | l.append(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train.X1'),transform=transform))
19 | l.append(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train.X2'), transform=transform))
20 | l.append(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train.X3'), transform=transform))
21 | l.append(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train.X4'), transform=transform))
22 | train_datasets=torch.utils.data.ConcatDataset(l)
23 | train_size = int(0.8 * len(train_datasets))
24 | test_size = len(train_datasets) - train_size
25 | train_dataset, test_dataset = torch.utils.data.random_split(train_datasets, [train_size, test_size])
26 |         self.train_dataset = train_dataset  # keep the Subset; .dataset would hand the full ConcatDataset to both splits
27 |         self.test_dataset = test_dataset
28 |
29 | #base_folder = 'CUB_200_2011/images'
30 | #url = 'https://data.caltech.edu/tindfiles/serve/1239ea37-e132-42ee-8c09-c383bb54e7ff/' #''http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
31 | def getdata(self):
32 | if self.train:
33 | return self.train_dataset
34 | else: return self.test_dataset
35 |
36 | #filename = 'CUB_200_2011.tgz'
37 | #tgz_md5 = '97eceeb196236b17998738112f37df78'
38 |
39 |
40 | class train_dataset_transformed(Dataset):
41 | def __init__(self, subset, transform=None):
42 | self.subset = subset
43 | self.transform = transform
44 |
45 | def __getitem__(self, index):
46 | x, y = self.subset[index]
47 | if self.transform:
48 | x = self.transform(x)
49 | return x, y
50 |
51 | def __len__(self):
52 | return len(self.subset)
53 |
54 | #train_dataset = train_dataset_transformed(train_dataset, transform=train_transform)
55 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import torch
4 | import torchvision
5 | from .random_dataset import RandomDataset
6 | from .CUB2011 import Cub2011 as CUB
7 | from .ImageNet100 import ImageNet100
8 | #from .StandfordCars import StandfordCars
9 |
10 |
11 | def get_dataset(dataset, data_dir, transform, train=True, download=False, debug_subset_size=None):
12 | if dataset == 'mnist':
13 | dataset = torchvision.datasets.MNIST(data_dir, train=train, transform=transform, download=download)
14 | elif dataset == 'stl10':
15 | dataset = torchvision.datasets.STL10(data_dir, split='train+unlabeled' if train else 'test', transform=transform, download=download)
16 | elif dataset == 'cifar10':
17 | dataset = torchvision.datasets.CIFAR10(data_dir, train=train, transform=transform, download=download)
18 | elif dataset == 'cifar100':
19 | dataset = torchvision.datasets.CIFAR100(data_dir, train=train, transform=transform, download=download)
20 | elif dataset == 'imagenet':
21 | dataset = torchvision.datasets.ImageNet(data_dir, split='train' if train == True else 'val', transform=transform, download=download)
22 | elif dataset == 'imagenet100':
23 | data_builder = ImageNet100(data_dir, train = train,transform=transform)# torchvision.datasets.ImageFolder(data_dir , transform=transform) #+ 'train' if train == True else 'val',
24 | dataset=data_builder.getdata()
25 | elif dataset =='cub200':
26 |         # Alternative loaders kept for reference:
27 |         # dataset = torchvision.datasets.ImageFolder(data_dir + ('/train' if train else '/test'), transform=transform)
28 |         # dataset = CUB(data_dir, train=train, transform=transform, download=download)
30 | dataset = torchvision.datasets.ImageFolder(
31 | os.path.join(data_dir, 'train') if train == True else os.path.join(data_dir, 'test'), transform=transform)
32 | # print('')
33 | # trainloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0,
34 | # drop_last=True)
35 | elif dataset=='stanfordcars':
36 | dataset= torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train') if train == True else os.path.join(data_dir, 'test'), transform=transform)#StandfordCars(data_dir,train=train,transform=transform)
37 | elif dataset == 'aircrafts':
38 | dataset = torchvision.datasets.ImageFolder(
39 | os.path.join(data_dir, 'train') if train == True else os.path.join(data_dir, 'test'),
40 | transform=transform)
41 | elif dataset == 'random':
42 | dataset = RandomDataset()
43 | else:
44 | raise NotImplementedError
45 | #debug_subset_size=50
46 | if debug_subset_size is not None:
47 | dataset = torch.utils.data.Subset(dataset, range(0, debug_subset_size)) # take only one batch
48 | dataset.classes = dataset.dataset.classes
49 | dataset.targets = dataset.dataset.targets
50 | return dataset
--------------------------------------------------------------------------------
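`get_dataset` resolves a dataset name to a concrete `Dataset`; for the fine-grained datasets it simply wraps `ImageFolder` over the `train/` and `test/` layout from the README. A brief sketch, using the README's example paths:

```
from torchvision import transforms
from datasets import get_dataset

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Expects CUB200/train and CUB200/test as described in the README
train_set = get_dataset('cub200', './CUB200', transform, train=True)
test_set = get_dataset('cub200', './CUB200', transform, train=False)
print(len(train_set), len(test_set))
```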
/datasets/__pycache__/CUB2011.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/datasets/__pycache__/CUB2011.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/ImageNet100.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/datasets/__pycache__/ImageNet100.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/random_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/datasets/__pycache__/random_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/random_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class RandomDataset(torch.utils.data.Dataset):
4 | def __init__(self, root=None, train=True, transform=None, target_transform=None):
5 | self.transform = transform
6 | self.target_transform = target_transform
7 |
8 | self.size = 1000
9 | def __getitem__(self, idx):
10 | if idx < self.size:
11 | return [torch.randn((3, 224, 224)), torch.randn((3, 224, 224))], [0,0,0]
12 | else:
13 |             raise IndexError
14 |
15 | def __len__(self):
16 | return self.size
17 |
--------------------------------------------------------------------------------
/examples/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/examples/framework.png
--------------------------------------------------------------------------------
/hand_detector.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from cvzone.HandTrackingModule import HandDetector
3 |
4 | cap = cv2.VideoCapture(0)
5 | cap.set(3, 1280)
6 | cap.set(4, 720)
7 | detector = HandDetector(detectionCon=0.8)
8 |
9 | while True:
10 | success, img = cap.read()
11 |     if not success: continue  # skip frames the camera failed to deliver
12 | hands, img = detector.findHands(img)
13 | cv2.imshow('Hands', img)
14 | cv2.waitKey(1)
15 |
--------------------------------------------------------------------------------
/linear_cub_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 | from tqdm import tqdm
7 | from arguments import get_args
8 | from augmentations import get_aug
9 | from models import get_model, get_backbone
10 | from tools import AverageMeter
11 | from datasets import get_dataset
12 | from optimizers import get_optimizer, LR_Scheduler
13 |
14 | def main(args):
15 |
16 | train_loader = torch.utils.data.DataLoader(
17 | dataset=get_dataset(
18 | transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs),
19 | train=True,
20 | **args.dataset_kwargs
21 | ),
22 | batch_size=args.eval.batch_size,
23 | shuffle=True,
24 | **args.dataloader_kwargs
25 | )
26 | test_loader = torch.utils.data.DataLoader(
27 | dataset=get_dataset(
28 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
29 | train=False,
30 | **args.dataset_kwargs
31 | ),
32 | batch_size=args.eval.batch_size,
33 | shuffle=False,
34 | **args.dataloader_kwargs
35 | )
36 |
37 |
38 | #model = get_backbone(args.model.backbone)
39 | model = get_model(args.model)
40 | classifier = nn.Linear(in_features=2048, out_features=200, bias=True).to(args.device)
41 |
42 | assert args.eval_from is not None
43 | save_dict = torch.load(args.eval_from, map_location='cpu')
44 | #msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True)
45 |
46 | # print(msg)
47 | model = model.to(args.device)
48 | model = torch.nn.DataParallel(model)
49 |
50 | # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
51 | classifier = torch.nn.DataParallel(classifier)
52 | # define optimizer
53 | optimizer = get_optimizer(
54 | args.eval.optimizer.name, classifier,
55 | lr=args.eval.base_lr*args.eval.batch_size/256,
56 | momentum=args.eval.optimizer.momentum,
57 | weight_decay=args.eval.optimizer.weight_decay)
58 |
59 | # define lr scheduler
60 | lr_scheduler = LR_Scheduler(
61 | optimizer,
62 | args.eval.warmup_epochs, args.eval.warmup_lr*args.eval.batch_size/256,
63 | args.eval.num_epochs, args.eval.base_lr*args.eval.batch_size/256, args.eval.final_lr*args.eval.batch_size/256,
64 | len(train_loader),
65 | )
66 |
67 | loss_meter = AverageMeter(name='Loss')
68 | acc_meter = AverageMeter(name='Accuracy')
69 |
70 | # Start training
71 | global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
72 | for epoch in global_progress:
73 | loss_meter.reset()
74 | model.eval()
75 | classifier.train()
76 | local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True)
77 |
78 | for idx, (images, labels) in enumerate(local_progress):
79 |
80 | classifier.zero_grad()
81 | with torch.no_grad():
82 | feat, bp_out_feat = model.module.inference(images.to(args.device))
83 | preds = classifier(bp_out_feat)
84 |
85 | loss = F.cross_entropy(preds, labels.to(args.device))
86 |
87 | loss.backward()
88 | optimizer.step()
89 | loss_meter.update(loss.item())
90 | lr = lr_scheduler.step()
91 | local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg})
92 |
93 | classifier.eval()
94 | correct, total = 0, 0
95 | acc_meter.reset()
96 | for idx, (images, labels) in enumerate(test_loader):
97 | with torch.no_grad():
98 | feat,bp_out_feat = model.module.inference(images.to(args.device))
99 |
100 | preds = classifier(bp_out_feat).argmax(dim=1)
101 | correct = (preds == labels.to(args.device)).sum().item()
102 | acc_meter.update(correct/preds.shape[0])
103 | print(f'Accuracy = {acc_meter.avg*100:.2f}')
104 |
105 |
106 |
107 |
108 | if __name__ == "__main__":
109 | main(args=get_args())
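110 | # Note: acc_meter averages per-batch accuracy, so a short final test batch is
111 | # slightly over-weighted. An exact alternative (sketch, using the unused
112 | # `correct, total` counters above):
113 | #   correct += (preds == labels.to(args.device)).sum().item()
114 | #   total += labels.size(0)
115 | #   print(f'Accuracy = {100 * correct / total:.2f}')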
--------------------------------------------------------------------------------
/linear_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 | from tqdm import tqdm
7 | from arguments import get_args
8 | from augmentations import get_aug
9 | from models import get_model, get_backbone
10 | from tools import AverageMeter
11 | from datasets import get_dataset
12 | from optimizers import get_optimizer, LR_Scheduler
13 |
14 | def main(args):
15 |
16 | train_loader = torch.utils.data.DataLoader(
17 | dataset=get_dataset(
18 | transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs),
19 | train=True,
20 | **args.dataset_kwargs
21 | ),
22 | batch_size=args.eval.batch_size,
23 | shuffle=True,
24 | **args.dataloader_kwargs
25 | )
26 | test_loader = torch.utils.data.DataLoader(
27 | dataset=get_dataset(
28 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
29 | train=False,
30 | **args.dataset_kwargs
31 | ),
32 | batch_size=args.eval.batch_size,
33 | shuffle=False,
34 | **args.dataloader_kwargs
35 | )
36 |
37 |
38 | model = get_backbone(args.model.backbone)
39 |     classifier = nn.Linear(in_features=model.output_dim, out_features=10, bias=True).to(args.device)  # 10 classes (CIFAR-10)
40 |
41 | assert args.eval_from is not None
42 | save_dict = torch.load(args.eval_from, map_location='cpu')
43 | msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True)
44 |
45 | # print(msg)
46 | model = model.to(args.device)
47 | model = torch.nn.DataParallel(model)
48 |
49 | # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
50 | classifier = torch.nn.DataParallel(classifier)
51 | # define optimizer
52 | optimizer = get_optimizer(
53 | args.eval.optimizer.name, classifier,
54 | lr=args.eval.base_lr*args.eval.batch_size/256,
55 | momentum=args.eval.optimizer.momentum,
56 | weight_decay=args.eval.optimizer.weight_decay)
57 |
58 | # define lr scheduler
59 | lr_scheduler = LR_Scheduler(
60 | optimizer,
61 | args.eval.warmup_epochs, args.eval.warmup_lr*args.eval.batch_size/256,
62 | args.eval.num_epochs, args.eval.base_lr*args.eval.batch_size/256, args.eval.final_lr*args.eval.batch_size/256,
63 | len(train_loader),
64 | )
65 |
66 | loss_meter = AverageMeter(name='Loss')
67 | acc_meter = AverageMeter(name='Accuracy')
68 |
69 | # Start training
70 | global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
71 | for epoch in global_progress:
72 | loss_meter.reset()
73 | model.eval()
74 | classifier.train()
75 | local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True)
76 |
77 | for idx, (images, labels) in enumerate(local_progress):
78 |
79 | classifier.zero_grad()
80 | with torch.no_grad():
81 | feature = model(images.to(args.device))
82 |
83 | preds = classifier(feature)
84 |
85 | loss = F.cross_entropy(preds, labels.to(args.device))
86 |
87 | loss.backward()
88 | optimizer.step()
89 | loss_meter.update(loss.item())
90 | lr = lr_scheduler.step()
91 | local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg})
92 |
93 | classifier.eval()
94 | correct, total = 0, 0
95 | acc_meter.reset()
96 | for idx, (images, labels) in enumerate(test_loader):
97 | with torch.no_grad():
98 | feature = model(images.to(args.device))
99 | preds = classifier(feature).argmax(dim=1)
100 | correct = (preds == labels.to(args.device)).sum().item()
101 | acc_meter.update(correct/preds.shape[0])
102 | print(f'Accuracy = {acc_meter.avg*100:.2f}')
103 |
104 |
105 |
106 |
107 | if __name__ == "__main__":
108 | main(args=get_args())
109 |
--------------------------------------------------------------------------------
/linear_imagenet100_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 | from tqdm import tqdm
7 | from arguments import get_args
8 | from augmentations import get_aug
9 | from models import get_model, get_backbone
10 | from tools import AverageMeter
11 | from datasets import get_dataset
12 | from optimizers import get_optimizer, LR_Scheduler
13 |
14 | def main(args):
15 |
16 | train_loader = torch.utils.data.DataLoader(
17 | dataset=get_dataset(
18 | transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs),
19 | train=True,
20 | **args.dataset_kwargs
21 | ),
22 | batch_size=args.eval.batch_size,
23 | shuffle=True,
24 | **args.dataloader_kwargs
25 | )
26 | test_loader = torch.utils.data.DataLoader(
27 | dataset=get_dataset(
28 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
29 | train=False,
30 | **args.dataset_kwargs
31 | ),
32 | batch_size=args.eval.batch_size,
33 | shuffle=False,
34 | **args.dataloader_kwargs
35 | )
36 |
37 |
38 | model = get_backbone(args.model.backbone)
39 |     classifier = nn.Linear(in_features=model.output_dim, out_features=100, bias=True).to(args.device)  # 100 classes (ImageNet-100)
40 |
41 | assert args.eval_from is not None
42 | save_dict = torch.load(args.eval_from, map_location='cpu')
43 | msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True)
44 |
45 | # print(msg)
46 | model = model.to(args.device)
47 | model = torch.nn.DataParallel(model)
48 |
49 | # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
50 | classifier = torch.nn.DataParallel(classifier)
51 | # define optimizer
52 | optimizer = get_optimizer(
53 | args.eval.optimizer.name, classifier,
54 | lr=args.eval.base_lr*args.eval.batch_size/256,
55 | momentum=args.eval.optimizer.momentum,
56 | weight_decay=args.eval.optimizer.weight_decay)
57 |
58 | # define lr scheduler
59 | lr_scheduler = LR_Scheduler(
60 | optimizer,
61 | args.eval.warmup_epochs, args.eval.warmup_lr*args.eval.batch_size/256,
62 | args.eval.num_epochs, args.eval.base_lr*args.eval.batch_size/256, args.eval.final_lr*args.eval.batch_size/256,
63 | len(train_loader),
64 | )
65 |
66 | loss_meter = AverageMeter(name='Loss')
67 | acc_meter = AverageMeter(name='Accuracy')
68 |
69 | # Start training
70 | global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
71 | for epoch in global_progress:
72 | loss_meter.reset()
73 | model.eval()
74 | classifier.train()
75 | local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True)
76 |
77 | for idx, (images, labels) in enumerate(local_progress):
78 |
79 | classifier.zero_grad()
80 | with torch.no_grad():
81 | feature = model(images.to(args.device))
82 |
83 | preds = classifier(feature)
84 |
85 | loss = F.cross_entropy(preds, labels.to(args.device))
86 |
87 | loss.backward()
88 | optimizer.step()
89 | loss_meter.update(loss.item())
90 | lr = lr_scheduler.step()
91 | local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg})
92 |
93 | classifier.eval()
94 | correct, total = 0, 0
95 | acc_meter.reset()
96 | for idx, (images, labels) in enumerate(test_loader):
97 | with torch.no_grad():
98 | feature = model(images.to(args.device))
99 | preds = classifier(feature).argmax(dim=1)
100 | correct = (preds == labels.to(args.device)).sum().item()
101 | acc_meter.update(correct/preds.shape[0])
102 | print(f'Accuracy = {acc_meter.avg*100:.2f}')
103 |
104 |
105 |
106 |
107 | if __name__ == "__main__":
108 | main(args=get_args())
109 |
--------------------------------------------------------------------------------
/linear_stanfordcars_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 | from tqdm import tqdm
7 | from arguments import get_args
8 | from augmentations import get_aug
9 | from models import get_model, get_backbone
10 | from tools import AverageMeter
11 | from datasets import get_dataset
12 | from optimizers import get_optimizer, LR_Scheduler
13 |
14 | def main(args,class_num):
15 | classes=class_num
16 | train_loader = torch.utils.data.DataLoader(
17 | dataset=get_dataset(
18 | transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs),
19 | train=True,
20 | **args.dataset_kwargs
21 | ),
22 | batch_size=args.eval.batch_size,
23 | shuffle=True,
24 | **args.dataloader_kwargs
25 | )
26 | test_loader = torch.utils.data.DataLoader(
27 | dataset=get_dataset(
28 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
29 | train=False,
30 | **args.dataset_kwargs
31 | ),
32 | batch_size=args.eval.batch_size,
33 | shuffle=False,
34 | **args.dataloader_kwargs
35 | )
36 |
37 |
38 | #model = get_backbone(args.model.backbone)
39 | model = get_model(args.model)
40 | classifier = nn.Linear(in_features=2048, out_features=classes, bias=True).to(args.device)
41 |
42 | assert args.eval_from is not None
43 | save_dict = torch.load(args.eval_from, map_location='cpu')
44 |     # restore the pretrained weights saved by main.py (a full-model checkpoint)
45 |     msg = model.load_state_dict(save_dict['state_dict'], strict=False)
46 |     # print(msg)
47 | model = model.to(args.device)
48 | model = torch.nn.DataParallel(model)
49 |
50 | # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
51 | classifier = torch.nn.DataParallel(classifier)
52 | # define optimizer
53 | optimizer = get_optimizer(
54 | args.eval.optimizer.name, classifier,
55 | lr=args.eval.base_lr*args.eval.batch_size/256,
56 | momentum=args.eval.optimizer.momentum,
57 | weight_decay=args.eval.optimizer.weight_decay)
58 |
59 | # define lr scheduler
60 | lr_scheduler = LR_Scheduler(
61 | optimizer,
62 | args.eval.warmup_epochs, args.eval.warmup_lr*args.eval.batch_size/256,
63 | args.eval.num_epochs, args.eval.base_lr*args.eval.batch_size/256, args.eval.final_lr*args.eval.batch_size/256,
64 | len(train_loader),
65 | )
66 |
67 | loss_meter = AverageMeter(name='Loss')
68 | acc_meter = AverageMeter(name='Accuracy')
69 |
70 | # Start training
71 | global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
72 | for epoch in global_progress:
73 | loss_meter.reset()
74 | model.eval()
75 | classifier.train()
76 | local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True)
77 |
78 | for idx, (images, labels) in enumerate(local_progress):
79 |
80 | classifier.zero_grad()
81 | with torch.no_grad():
82 | feature, bp_out_feat = model.module.inference(images.to(args.device))
83 |
84 | preds = classifier(bp_out_feat)
85 |
86 | loss = F.cross_entropy(preds, labels.to(args.device))
87 |
88 | loss.backward()
89 | optimizer.step()
90 | loss_meter.update(loss.item())
91 | lr = lr_scheduler.step()
92 | local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg})
93 |
94 | classifier.eval()
95 | correct, total = 0, 0
96 | acc_meter.reset()
97 | for idx, (images, labels) in enumerate(test_loader):
98 | with torch.no_grad():
99 | feature, bp_out_feat = model.module.inference(images.to(args.device))
100 |
101 |
102 | preds = classifier(bp_out_feat).argmax(dim=1)
103 | correct = (preds == labels.to(args.device)).sum().item()
104 | acc_meter.update(correct/preds.shape[0])
105 | print(f'Accuracy = {acc_meter.avg*100:.2f}')
106 |
107 |
108 |
109 |
110 | if __name__ == "__main__":
111 |     main(args=get_args(), class_num=196)  # Stanford Cars has 196 classes
112 |
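113 | # When dispatched from main.py, class_num is taken from the evaluation dataset
114 | # (e.g. 196 classes for Stanford Cars, 100 variant classes for FGVC-Aircraft).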
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from PIL import ImageFilter
8 | import random
9 |
10 |
11 | class TwoCropsTransform:
12 | """Take two random crops of one image as the query and key."""
13 |
14 | def __init__(self, base_transform):
15 | self.base_transform = base_transform
16 |
17 | def __call__(self, x):
18 | q = self.base_transform(x)
19 | k = self.base_transform(x)
20 | return [q, k]
21 |
22 |
23 | class GaussianBlur(object):
24 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
25 |
26 | def __init__(self, sigma=[.1, 2.]):
27 | self.sigma = sigma
28 |
29 | def __call__(self, x):
30 | sigma = random.uniform(self.sigma[0], self.sigma[1])
31 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
32 | return x
33 |
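34 | # Usage sketch (mirrors the augmentation pipeline built in main_moco.py;
35 | # pil_image is any PIL image):
36 | #   from torchvision import transforms
37 | #   two_crop = TwoCropsTransform(transforms.Compose([
38 | #       transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
39 | #       transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
40 | #       transforms.ToTensor(),
41 | #   ]))
42 | #   q, k = two_crop(pil_image)  # two independent augmented views of one image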
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 | import numpy as np
7 | from tqdm import tqdm
8 | from arguments import get_args
9 | from augmentations import get_aug
10 | from models import get_model
11 | from tools import AverageMeter, knn_monitor, Logger, file_exist_check
12 | from datasets import get_dataset
13 | from optimizers import get_optimizer, LR_Scheduler
14 | #from linear_eval import main as linear_eval
15 | from linear_cub_eval import main as cub_linear_eval
16 | from linear_stanfordcars_eval import main as linear_stanfordcars_eval
17 | from linear_eval import main as linear_eval
18 | from datetime import datetime
19 |
20 | def main(device, args):
21 |
22 | train_loader = torch.utils.data.DataLoader(
23 | dataset=get_dataset(
24 | transform=get_aug(train=True, **args.aug_kwargs),
25 | train=True,
26 | **args.dataset_kwargs),
27 | shuffle=True,
28 | batch_size=args.train.batch_size,
29 | **args.dataloader_kwargs
30 | )
31 |
32 | memory_loader = torch.utils.data.DataLoader(
33 | dataset=get_dataset(
34 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
35 | train=True,
36 | **args.dataset_kwargs),
37 | shuffle=False,
38 | batch_size=args.train.batch_size,
39 | **args.dataloader_kwargs
40 | )
41 | #test_loader=memory_loader
42 | test_loader = torch.utils.data.DataLoader(
43 | dataset=get_dataset(
44 | transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
45 | train=False,
46 | **args.dataset_kwargs),
47 | shuffle=False,
48 | batch_size=args.train.batch_size,
49 | **args.dataloader_kwargs
50 | )
51 |
52 | # define model
53 | model = get_model(args.model).to(device)
54 | model = torch.nn.DataParallel(model)
55 | #model = torch.nn.parallel.DistributedDataParallel(model)
56 |
57 | # define optimizer
58 | optimizer = get_optimizer(
59 | args.train.optimizer.name, model,
60 | lr=args.train.base_lr*args.train.batch_size/256,
61 | momentum=args.train.optimizer.momentum,
62 | weight_decay=args.train.optimizer.weight_decay)
63 |
64 | lr_scheduler = LR_Scheduler(
65 | optimizer,
66 | args.train.warmup_epochs, args.train.warmup_lr*args.train.batch_size/256,
67 | args.train.num_epochs, args.train.base_lr*args.train.batch_size/256, args.train.final_lr*args.train.batch_size/256,
68 | len(train_loader),
69 | constant_predictor_lr=True # see the end of section 4.2 predictor
70 | )
71 |
72 | logger = Logger(tensorboard=args.logger.tensorboard, matplotlib=args.logger.matplotlib, log_dir=args.log_dir)
73 | accuracy = 0
74 | # Start training
75 | global_progress = tqdm(range(0, args.train.stop_at_epoch), desc=f'Training')
76 | for epoch in global_progress:
77 | model.train()
78 |
79 | local_progress=tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}', disable=args.hide_progress)
80 | for idx, ((images1, images2), labels) in enumerate(local_progress):
81 |
82 | model.zero_grad()
83 | data_dict = model.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True))
84 | loss = data_dict['loss'].mean() # ddp
85 |             # Common-rationale term: align the learnable attention map with the
86 |             # GradCAM map via KL divergence between temperature-softened softmaxes.
87 |             t = 0.4  # softening temperature
88 | grad_cam = data_dict['gradcam'].view(data_dict['gradcam'].size(0), -1)
89 | grad_cam = (grad_cam / t).float()
90 | att_max = data_dict['attmap'].view(data_dict['attmap'].size(0), -1)
91 | att_max = (att_max / t).float()
92 | loss_cam = F.kl_div(att_max.softmax(dim=-1).log(), grad_cam.softmax(dim=-1),
93 | reduction='sum')
94 |
95 |             loss = loss + 0.01 * loss_cam  # 0.01 weights the rationale term
96 |
97 | loss.backward()
98 | optimizer.step()
99 | lr_scheduler.step()
100 | data_dict.update({'lr':lr_scheduler.get_lr()})
101 |
102 | local_progress.set_postfix(data_dict)
103 | #logger.update_scalers(data_dict)
104 |
105 | if args.train.knn_monitor and epoch % args.train.knn_interval == 0:
106 | accuracy = knn_monitor(model.module.backbone, memory_loader, test_loader, device, k=min(args.train.knn_k, len(memory_loader.dataset)), hide_progress=args.hide_progress)
107 |
108 | epoch_dict = {"epoch":epoch, "knn_monitor: accuracy":accuracy}
109 | global_progress.set_postfix(epoch_dict)
110 | #logger.update_scalers(epoch_dict)
111 |
112 | # Save checkpoint
113 | model_path = os.path.join(args.ckpt_dir, f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth") # datetime.now().strftime('%Y%m%d_%H%M%S')
114 | torch.save({
115 | 'epoch': epoch+1,
116 | 'state_dict':model.module.state_dict()
117 | }, model_path)
118 | print(f"Model saved to {model_path}")
119 | with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f:
120 | f.write(f'{model_path}')
121 |
122 |
123 | if args.eval is not False:
124 | args.eval_from = model_path
125 | classes=len(test_loader.dataset.classes)
126 | dataset_name=args.dataset.name
127 |         if dataset_name=='cub200':
128 |             cub_linear_eval(args)
129 |         elif dataset_name in ('stanfordcars', 'aircrafts'):
130 |             # both fine-grained probes share one evaluation script
131 |             linear_stanfordcars_eval(args, classes)
132 |         elif dataset_name=='cifar10':
133 |             linear_eval(args)
134 |
135 |
136 |
137 |
138 | if __name__ == "__main__":
139 | args = get_args()
140 |
141 | main(device=args.device, args=args)
142 |
143 | completed_log_dir = args.log_dir.replace('in-progress', 'debug' if args.debug else 'completed')
144 |
145 |
146 |
147 | os.rename(args.log_dir, completed_log_dir)
148 | print(f'Log file has been saved to {completed_log_dir}')
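149 | # Self-contained sketch of the rationale term used in the training loop above
150 | # (F is torch.nn.functional; any attention/GradCAM pair of shape (B, C, H, W)
151 | # works):
152 | #   def rationale_kl(att_map, grad_cam, t=0.4):
153 | #       att = (att_map.view(att_map.size(0), -1) / t).float()
154 | #       cam = (grad_cam.view(grad_cam.size(0), -1) / t).float()
155 | #       return F.kl_div(att.softmax(dim=-1).log(), cam.softmax(dim=-1),
156 | #                       reduction='sum')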
--------------------------------------------------------------------------------
/main_lincls.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import builtins
4 | import os
5 | import random
6 | import shutil
7 | import time
8 | import warnings
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel
13 | import torch.backends.cudnn as cudnn
14 | import torch.distributed as dist
15 | import torch.optim
16 | import torch.multiprocessing as mp
17 | import torch.utils.data
18 | import torch.utils.data.distributed
19 | import torchvision.transforms as transforms
20 | import torchvision.datasets as datasets
21 | import torchvision.models as models
22 |
23 | import moco.loader
24 | import moco.builder
25 | from classifier import Classifier
26 |
27 |
28 | model_names = sorted(name for name in models.__dict__
29 | if name.islower() and not name.startswith("__")
30 | and callable(models.__dict__[name]))
31 |
32 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
33 | parser.add_argument('data', metavar='DIR',
34 | help='path to dataset')
35 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
36 | choices=model_names,
37 | help='model architecture: ' +
38 | ' | '.join(model_names) +
39 | ' (default: resnet50)')
40 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
41 | help='number of data loading workers (default: 32)')
42 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
43 | help='number of total epochs to run')
44 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
45 | help='manual epoch number (useful on restarts)')
46 | parser.add_argument('-b', '--batch-size', default=256, type=int,
47 | metavar='N',
48 | help='mini-batch size (default: 256), this is the total '
49 | 'batch size of all GPUs on the current node when '
50 | 'using Data Parallel or Distributed Data Parallel')
51 | parser.add_argument('--lr', '--learning-rate', default=30., type=float,
52 | metavar='LR', help='initial learning rate', dest='lr')
53 | parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int,
54 | help='learning rate schedule (when to drop lr by a ratio)')
55 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
56 | help='momentum')
57 | parser.add_argument('--wd', '--weight-decay', default=0., type=float,
58 | metavar='W', help='weight decay (default: 0.)',
59 | dest='weight_decay')
60 | parser.add_argument('-p', '--print-freq', default=10, type=int,
61 | metavar='N', help='print frequency (default: 10)')
62 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
63 | help='path to latest checkpoint (default: none)')
64 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
65 | help='evaluate model on validation set')
66 | parser.add_argument('--world-size', default=-1, type=int,
67 | help='number of nodes for distributed training')
68 | parser.add_argument('--rank', default=-1, type=int,
69 | help='node rank for distributed training')
70 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
71 | help='url used to set up distributed training')
72 | parser.add_argument('--dist-backend', default='nccl', type=str,
73 | help='distributed backend')
74 | parser.add_argument('--seed', default=None, type=int,
75 | help='seed for initializing training. ')
76 | parser.add_argument('--gpu', default=None, type=int,
77 | help='GPU id to use.')
78 | parser.add_argument('--multiprocessing-distributed', action='store_true',
79 | help='Use multi-processing distributed training to launch '
80 | 'N processes per node, which has N GPUs. This is the '
81 | 'fastest way to use PyTorch for either single node or '
82 | 'multi node data parallel training')
83 |
84 | parser.add_argument('--pretrained', default='', type=str,
85 | help='path to moco pretrained checkpoint')
86 |
87 | # moco specific configs:
88 | parser.add_argument('--moco-dim', default=256, type=int,
89 |                     help='feature dimension (default: 256)')
90 | parser.add_argument('--moco-k', default=65536, type=int,
91 | help='queue size; number of negative keys (default: 65536)')
92 | parser.add_argument('--moco-m', default=0.999, type=float,
93 | help='moco momentum of updating key encoder (default: 0.999)')
94 | parser.add_argument('--moco-t', default=0.07, type=float,
95 | help='softmax temperature (default: 0.07)')
96 | parser.add_argument('--mlp', action='store_true',
97 | help='use mlp head')
98 | parser.add_argument('--class_num', type=int, default=200)
99 |
100 | best_acc1 = 0
101 |
102 |
103 | def main():
104 | args = parser.parse_args()
105 |
106 | if args.seed is not None:
107 | random.seed(args.seed)
108 | torch.manual_seed(args.seed)
109 | cudnn.deterministic = True
110 | warnings.warn('You have chosen to seed training. '
111 | 'This will turn on the CUDNN deterministic setting, '
112 | 'which can slow down your training considerably! '
113 | 'You may see unexpected behavior when restarting '
114 | 'from checkpoints.')
115 |
116 | if args.gpu is not None:
117 | warnings.warn('You have chosen a specific GPU. This will completely '
118 | 'disable data parallelism.')
119 |
120 | if args.dist_url == "env://" and args.world_size == -1:
121 | args.world_size = int(os.environ["WORLD_SIZE"])
122 |
123 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
124 |
125 | ngpus_per_node = torch.cuda.device_count()
126 | if args.multiprocessing_distributed:
127 | # Since we have ngpus_per_node processes per node, the total world_size
128 | # needs to be adjusted accordingly
129 | args.world_size = ngpus_per_node * args.world_size
130 | # Use torch.multiprocessing.spawn to launch distributed processes: the
131 | # main_worker process function
132 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
133 | else:
134 | # Simply call main_worker function
135 | main_worker(args.gpu, ngpus_per_node, args)
136 |
137 |
138 | def main_worker(gpu, ngpus_per_node, args):
139 | global best_acc1
140 | args.gpu = gpu
141 |
142 | # suppress printing if not master
143 | if args.multiprocessing_distributed and args.gpu != 0:
144 | def print_pass(*args):
145 | pass
146 |
147 | builtins.print = print_pass
148 |
149 | if args.gpu is not None:
150 | print("Use GPU: {} for training".format(args.gpu))
151 |
152 | if args.distributed:
153 | if args.dist_url == "env://" and args.rank == -1:
154 | args.rank = int(os.environ["RANK"])
155 | if args.multiprocessing_distributed:
156 | # For multiprocessing distributed training, rank needs to be the
157 | # global rank among all the processes
158 | args.rank = args.rank * ngpus_per_node + gpu
159 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
160 | world_size=args.world_size, rank=args.rank)
161 | # create model
162 | print("=> creating model '{}'".format(args.arch))
163 |
164 | model = moco.builder.MoCo(
165 | models.__dict__[args.arch],
166 | args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp)
167 |
168 | classifier = Classifier(2048, args.class_num).cuda(args.gpu)
169 |
170 | # freeze all layers but the last fc
171 | for name, param in model.named_parameters():
172 | if name not in ['fc.weight', 'fc.bias']:
173 | param.requires_grad = False
174 | # init the fc layer
175 | model.fc.weight.data.normal_(mean=0.0, std=0.01)
176 | model.fc.bias.data.zero_()
177 |
178 | # load from pre-trained, before DistributedDataParallel constructor
179 | if args.pretrained:
180 | if os.path.isfile(args.pretrained):
181 | print("=> loading checkpoint '{}'".format(args.pretrained))
182 | checkpoint = torch.load(args.pretrained, map_location="cpu")
183 |
184 | # rename moco pre-trained keys
185 | state_dict = checkpoint['state_dict']
186 |
187 | for k in list(state_dict.keys()):
188 | state_dict[k[len("module."):]] = state_dict[k]
189 |
190 | args.start_epoch = 0
191 | msg = model.load_state_dict(state_dict, strict=False)
192 |
193 | print("=> loaded pre-trained model '{}'".format(args.pretrained))
194 | else:
195 | print("=> no checkpoint found at '{}'".format(args.pretrained))
196 |
197 | if args.distributed:
198 | # For multiprocessing distributed, DistributedDataParallel constructor
199 | # should always set the single device scope, otherwise,
200 | # DistributedDataParallel will use all available devices.
201 | if args.gpu is not None:
202 | torch.cuda.set_device(args.gpu)
203 | model.cuda(args.gpu)
204 | # When using a single GPU per process and per
205 | # DistributedDataParallel, we need to divide the batch size
206 | # ourselves based on the total number of GPUs we have
207 | args.batch_size = int(args.batch_size / ngpus_per_node)
208 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
209 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
210 | else:
211 | model.cuda()
212 | # DistributedDataParallel will divide and allocate batch_size to all
213 | # available GPUs if device_ids are not set
214 | model = torch.nn.parallel.DistributedDataParallel(model)
215 | elif args.gpu is not None:
216 | torch.cuda.set_device(args.gpu)
217 | model = model.cuda(args.gpu)
218 | else:
219 | # DataParallel will divide and allocate batch_size to all available GPUs
220 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
221 | model.features = torch.nn.DataParallel(model.features)
222 | model.cuda()
223 | else:
224 | model = torch.nn.DataParallel(model).cuda()
225 |
226 | # define loss function (criterion) and optimizer
227 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
228 |
229 | # optimize only the linear classifier
230 | parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
231 | assert len(parameters) == 2 # fc.weight, fc.bias
232 |
233 | optimizer = torch.optim.SGD(classifier.parameters(), args.lr, momentum=args.momentum,
234 | weight_decay=args.weight_decay)
235 |
236 |
237 | # optionally resume from a checkpoint
238 | if args.resume:
239 | if os.path.isfile(args.resume):
240 | print("=> loading checkpoint '{}'".format(args.resume))
241 | if args.gpu is None:
242 | checkpoint = torch.load(args.resume)
243 | else:
244 | # Map model to be loaded to specified single gpu.
245 | loc = 'cuda:{}'.format(args.gpu)
246 | checkpoint = torch.load(args.resume, map_location=loc)
247 | args.start_epoch = checkpoint['epoch']
248 | best_acc1 = checkpoint['best_acc1']
249 | if args.gpu is not None:
250 | # best_acc1 may be from a checkpoint from a different GPU
251 | best_acc1 = best_acc1.to(args.gpu)
252 | model.load_state_dict(checkpoint['state_dict'])
253 | optimizer.load_state_dict(checkpoint['optimizer'])
254 | print("=> loaded checkpoint '{}' (epoch {})"
255 | .format(args.resume, checkpoint['epoch']))
256 | else:
257 | print("=> no checkpoint found at '{}'".format(args.resume))
258 |
259 | cudnn.benchmark = True
260 |
261 | # Data loading code
262 | traindir = os.path.join(args.data, 'train')
263 | valdir = os.path.join(args.data, 'test')
264 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
265 | std=[0.229, 0.224, 0.225])
266 |
267 | train_dataset = datasets.ImageFolder(
268 | traindir,
269 | transforms.Compose([
270 | transforms.RandomResizedCrop(224),
271 | transforms.RandomHorizontalFlip(),
272 | transforms.ToTensor(),
273 | normalize,
274 | ]))
275 |
276 | if args.distributed:
277 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
278 | else:
279 | train_sampler = None
280 |
281 | train_loader = torch.utils.data.DataLoader(
282 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
283 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
284 |
285 | val_loader = torch.utils.data.DataLoader(
286 | datasets.ImageFolder(valdir, transforms.Compose([
287 | transforms.Resize(256),
288 | transforms.CenterCrop(224),
289 | transforms.ToTensor(),
290 | normalize,
291 | ])),
292 | batch_size=args.batch_size, shuffle=False,
293 | num_workers=args.workers, pin_memory=True)
294 |
295 | if args.evaluate:
296 | validate(val_loader, model, classifier, criterion, args)
297 | return
298 |
299 | for epoch in range(args.start_epoch, args.epochs):
300 | if args.distributed:
301 | train_sampler.set_epoch(epoch)
302 | adjust_learning_rate(optimizer, epoch, args)
303 |
304 | # train for one epoch
305 | train(train_loader, model, classifier, criterion, optimizer, epoch, args)
306 |
307 | # evaluate on validation set
308 | acc1 = validate(val_loader, model, classifier, criterion, args)
309 |
310 | # remember best acc@1 and save checkpoint
311 | is_best = acc1 > best_acc1
312 | best_acc1 = max(acc1, best_acc1)
313 |
314 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
315 | and args.rank % ngpus_per_node == 0):
316 | save_checkpoint({
317 | 'epoch': epoch + 1,
318 | 'arch': args.arch,
319 | 'state_dict': model.state_dict(),
320 | 'best_acc1': best_acc1,
321 | 'optimizer': optimizer.state_dict(),
322 | }, is_best)
323 | # if epoch == args.start_epoch:
324 | # sanity_check(model.state_dict(), args.pretrained)
325 |
326 |
327 | def train(train_loader, model, classifier, criterion, optimizer, epoch, args):
328 | batch_time = AverageMeter('Time', ':6.3f')
329 | data_time = AverageMeter('Data', ':6.3f')
330 | losses = AverageMeter('Loss', ':.4e')
331 | top1 = AverageMeter('Acc@1', ':6.2f')
332 | top5 = AverageMeter('Acc@5', ':6.2f')
333 | progress = ProgressMeter(
334 | len(train_loader),
335 | [batch_time, data_time, losses, top1, top5],
336 | prefix="Epoch: [{}]".format(epoch))
337 |
338 | """
339 | Switch to eval mode:
340 | Under the protocol of linear classification on frozen features/models,
341 | it is not legitimate to change any part of the pre-trained model.
342 | BatchNorm in train mode may revise running mean/std (even if it receives
343 | no gradient), which are part of the model parameters too.
344 | """
345 | model.eval()
346 | classifier.train(True)
347 | optimizer.zero_grad()
348 |
349 | end = time.time()
350 | for i, (images, target) in enumerate(train_loader):
351 | # measure data loading time
352 | data_time.update(time.time() - end)
353 |
354 | if args.gpu is not None:
355 | images = images.cuda(args.gpu, non_blocking=True)
356 | target = target.cuda(args.gpu, non_blocking=True)
357 |
358 | # compute output
359 | _, _, _, bp_out_feat = model.module.inference(images)
360 |
361 | output = classifier(bp_out_feat)
362 | loss = criterion(output, target)
363 |
364 | # measure accuracy and record loss
365 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
366 | losses.update(loss.item(), images.size(0))
367 | top1.update(acc1[0], images.size(0))
368 | top5.update(acc5[0], images.size(0))
369 |
370 | # compute gradient and do SGD step
371 | optimizer.zero_grad()
372 | loss.backward()
373 | optimizer.step()
374 |
375 | # measure elapsed time
376 | batch_time.update(time.time() - end)
377 | end = time.time()
378 |
379 | if i % args.print_freq == 0:
380 | progress.display(i)
381 |
382 |
383 | def validate(val_loader, model, classifier, criterion, args):
384 | batch_time = AverageMeter('Time', ':6.3f')
385 | losses = AverageMeter('Loss', ':.4e')
386 | top1 = AverageMeter('Acc@1', ':6.2f')
387 | top5 = AverageMeter('Acc@5', ':6.2f')
388 | progress = ProgressMeter(
389 | len(val_loader),
390 | [batch_time, losses, top1, top5],
391 | prefix='Test: ')
392 |
393 | # switch to evaluate mode
394 | model.eval()
395 | classifier.eval()
396 |
397 | with torch.no_grad():
398 | end = time.time()
399 | for i, (images, target) in enumerate(val_loader):
400 | if args.gpu is not None:
401 | images = images.cuda(args.gpu, non_blocking=True)
402 | target = target.cuda(args.gpu, non_blocking=True)
403 |
404 | # compute output
405 | _, _, _, bp_out_feat = model.module.inference(images)
406 | output = classifier(bp_out_feat)
407 | loss = criterion(output, target)
408 |
409 | # measure accuracy and record loss
410 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
411 | losses.update(loss.item(), images.size(0))
412 | top1.update(acc1[0], images.size(0))
413 | top5.update(acc5[0], images.size(0))
414 |
415 | # measure elapsed time
416 | batch_time.update(time.time() - end)
417 | end = time.time()
418 |
419 | if i % args.print_freq == 0:
420 | progress.display(i)
421 |
422 | # TODO: this should also be done with the ProgressMeter
423 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
424 | .format(top1=top1, top5=top5))
425 |
426 | return top1.avg
427 |
428 |
429 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
430 | torch.save(state, filename)
431 | if is_best:
432 | shutil.copyfile(filename, 'model_best.pth.tar')
433 |
434 |
435 | def sanity_check(state_dict, pretrained_weights):
436 | """
437 | Linear classifier should not change any weights other than the linear layer.
438 | This sanity check asserts nothing wrong happens (e.g., BN stats updated).
439 | """
440 | print("=> loading '{}' for sanity check".format(pretrained_weights))
441 | checkpoint = torch.load(pretrained_weights, map_location="cpu")
442 | state_dict_pre = checkpoint['state_dict']
443 |
444 | for k in list(state_dict.keys()):
445 | # only ignore fc layer
446 | if 'fc.weight' in k or 'fc.bias' in k:
447 | continue
448 |
449 | # name in pretrained model
450 | k_pre = 'module.encoder_q.' + k[len('module.'):] \
451 | if k.startswith('module.') else 'module.encoder_q.' + k
452 |
453 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
454 | '{} is changed in linear classifier training.'.format(k)
455 |
456 | print("=> sanity check passed.")
457 |
458 |
459 | class AverageMeter(object):
460 | """Computes and stores the average and current value"""
461 |
462 | def __init__(self, name, fmt=':f'):
463 | self.name = name
464 | self.fmt = fmt
465 | self.reset()
466 |
467 | def reset(self):
468 | self.val = 0
469 | self.avg = 0
470 | self.sum = 0
471 | self.count = 0
472 |
473 | def update(self, val, n=1):
474 | self.val = val
475 | self.sum += val * n
476 | self.count += n
477 | self.avg = self.sum / self.count
478 |
479 | def __str__(self):
480 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
481 | return fmtstr.format(**self.__dict__)
482 |
483 |
484 | class ProgressMeter(object):
485 | def __init__(self, num_batches, meters, prefix=""):
486 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
487 | self.meters = meters
488 | self.prefix = prefix
489 |
490 | def display(self, batch):
491 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
492 | entries += [str(meter) for meter in self.meters]
493 | print('\t'.join(entries))
494 |
495 | def _get_batch_fmtstr(self, num_batches):
496 |         num_digits = len(str(num_batches))
497 | fmt = '{:' + str(num_digits) + 'd}'
498 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
499 |
500 |
501 | def adjust_learning_rate(optimizer, epoch, args):
502 | """Decay the learning rate based on schedule"""
503 | lr = args.lr
504 | for milestone in args.schedule:
505 | lr *= 0.1 if epoch >= milestone else 1.
506 | for param_group in optimizer.param_groups:
507 | param_group['lr'] = lr
508 |
509 |
510 | def accuracy(output, target, topk=(1,)):
511 | """Computes the accuracy over the k top predictions for the specified values of k"""
512 | with torch.no_grad():
513 | maxk = max(topk)
514 | batch_size = target.size(0)
515 |
516 | _, pred = output.topk(maxk, 1, True, True)
517 | pred = pred.t()
518 | correct = pred.eq(target.view(1, -1).expand_as(pred))
519 |
520 | res = []
521 | for k in topk:
522 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
523 | res.append(correct_k.mul_(100.0 / batch_size))
524 | return res
525 |
526 |
527 | if __name__ == '__main__':
528 | main()
529 |
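530 | # Hypothetical single-GPU invocation (paths are placeholders; flags as defined
531 | # in the argparse block above):
532 | #   python main_lincls.py --pretrained checkpoint_0099.pth.tar --class_num 200 \
533 | #       --lr 30.0 --batch-size 256 --gpu 0 /path/to/CUB200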
--------------------------------------------------------------------------------
/main_moco.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import builtins
4 | import math
5 | import os
6 | import random
7 | import shutil
8 | import time
9 | import warnings
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.parallel
14 | import torch.backends.cudnn as cudnn
15 | import torch.distributed as dist
16 | import torch.optim
17 | import torch.multiprocessing as mp
18 | import torch.utils.data
19 | import torch.utils.data.distributed
20 | import torchvision.transforms as transforms
21 | import torchvision.datasets as datasets
22 | import torchvision.models as models
23 |
24 | import moco.loader
25 | import moco.builder
26 |
27 | import numpy as np
28 |
29 |
30 | import torch.nn.functional as F
31 |
32 | torch.multiprocessing.set_sharing_strategy('file_system')
33 |
34 | os.environ[
35 | "TORCH_DISTRIBUTED_DEBUG"
36 | ] = "DETAIL"
37 |
38 | model_names = sorted(name for name in models.__dict__
39 | if name.islower() and not name.startswith("__")
40 | and callable(models.__dict__[name]))
41 |
42 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
43 | parser.add_argument('data', metavar='DIR',
44 | help='path to dataset')
45 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
46 | help='evaluate model on validation set')
47 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
48 | choices=model_names,
49 | help='model architecture: ' +
50 | ' | '.join(model_names) +
51 | ' (default: resnet50)')
52 | parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
53 |                     help='number of data loading workers (default: 0)')
54 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
55 | help='number of total epochs to run')
56 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
57 | help='manual epoch number (useful on restarts)')
58 | parser.add_argument('-b', '--batch-size', default=256, type=int,
59 | metavar='N',
60 | help='mini-batch size (default: 256), this is the total '
61 | 'batch size of all GPUs on the current node when '
62 | 'using Data Parallel or Distributed Data Parallel')
63 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
64 | metavar='LR', help='initial learning rate', dest='lr')
65 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
66 | help='learning rate schedule (when to drop lr by 10x)')
67 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
68 | help='momentum of SGD solver')
69 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
70 | metavar='W', help='weight decay (default: 1e-4)',
71 | dest='weight_decay')
72 | parser.add_argument('-p', '--print-freq', default=10, type=int,
73 | metavar='N', help='print frequency (default: 10)')
74 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
75 | help='path to latest checkpoint (default: none)')
76 | parser.add_argument('--world-size', default=-1, type=int,
77 | help='number of nodes for distributed training')
78 | parser.add_argument('--nu', default=0.01, type=float,
79 | help='weight of loss function')
80 | parser.add_argument('--rank', default=-1, type=int,
81 | help='node rank for distributed training')
82 | parser.add_argument('--dist-url', default='tcp://localhost:10001', type=str, # tcp://224.66.41.62:23456
83 | help='url used to set up distributed training')
84 | parser.add_argument('--dist-backend', default='nccl', type=str,
85 | help='distributed backend')
86 | parser.add_argument('--seed', default=None, type=int,
87 | help='seed for initializing training. ')
88 | parser.add_argument('--gpu', default=None, type=int,
89 | help='GPU id to use.')
90 | parser.add_argument('--multiprocessing-distributed', action='store_true',
91 | help='Use multi-processing distributed training to launch '
92 | 'N processes per node, which has N GPUs. This is the '
93 | 'fastest way to use PyTorch for either single node or '
94 | 'multi node data parallel training')
95 |
96 | # moco specific configs:
97 | parser.add_argument('--moco-dim', default=256, type=int,
98 |                     help='feature dimension (default: 256)')
99 | parser.add_argument('--moco-k', default=65536, type=int,
100 | help='queue size; number of negative keys (default: 65536)')
101 | parser.add_argument('--moco-m', default=0.999, type=float,
102 | help='moco momentum of updating key encoder (default: 0.999)')
103 | parser.add_argument('--moco-t', default=0.07, type=float,
104 | help='softmax temperature (default: 0.07)')
105 |
106 | # options for moco v2
107 | parser.add_argument('--mlp', action='store_true',
108 | help='use mlp head')
109 | parser.add_argument('--aug-plus', action='store_true',
110 | help='use moco v2 data augmentation')
111 | parser.add_argument('--cos', action='store_true',
112 | help='use cosine lr schedule')
113 |
114 |
115 | best_top1 = 0
116 | best_top5 = 0
117 | best_mAP = 0
118 |
119 |
120 | def main():
121 | args = parser.parse_args()
122 |
123 | if args.seed is not None:
124 | random.seed(args.seed)
125 | torch.manual_seed(args.seed)
126 | cudnn.deterministic = True
127 | warnings.warn('You have chosen to seed training. '
128 | 'This will turn on the CUDNN deterministic setting, '
129 | 'which can slow down your training considerably! '
130 | 'You may see unexpected behavior when restarting '
131 | 'from checkpoints.')
132 |
133 | if args.gpu is not None:
134 | warnings.warn('You have chosen a specific GPU. This will completely '
135 | 'disable data parallelism.')
136 |
137 | if args.dist_url == "env://" and args.world_size == -1:
138 | args.world_size = int(os.environ["WORLD_SIZE"])
139 |
140 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
141 |
142 | ngpus_per_node = torch.cuda.device_count()
143 |
144 | if args.multiprocessing_distributed:
145 | # Since we have ngpus_per_node processes per node, the total world_size
146 | # needs to be adjusted accordingly
147 | args.world_size = ngpus_per_node * args.world_size
148 | # Use torch.multiprocessing.spawn to launch distributed processes: the
149 | # main_worker process function
150 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
151 | else:
152 | # Simply call main_worker function
153 | main_worker(args.gpu, ngpus_per_node, args)
154 |
155 |
156 | def main_worker(gpu, ngpus_per_node, args):
157 | global best_top1, best_top5, best_mAP
158 | args.gpu = gpu
159 |
160 | # suppress printing if not master
161 | if args.multiprocessing_distributed and args.gpu != 0:
162 | def print_pass(*args):
163 | pass
164 |
165 | builtins.print = print_pass
166 |
167 | if args.gpu is not None:
168 | print("Use GPU: {} for training".format(args.gpu))
169 |
170 | if args.distributed:
171 | if args.dist_url == "env://" and args.rank == -1:
172 | args.rank = int(os.environ["RANK"])
173 | if args.multiprocessing_distributed:
174 | # For multiprocessing distributed training, rank needs to be the
175 | # global rank among all the processes
176 | args.rank = args.rank * ngpus_per_node + gpu
177 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
178 | world_size=args.world_size, rank=args.rank)
179 |
180 | # create model
181 | print("=> creating model '{}'".format(args.arch))
182 |
183 | model = moco.builder.MoCo(
184 | # models.__dict__[args.arch],
185 | args.arch,
186 | args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp)
187 |
188 |
189 | if args.distributed:
190 | # For multiprocessing distributed, DistributedDataParallel constructor
191 | # should always set the single device scope, otherwise,
192 | # DistributedDataParallel will use all available devices.
193 | if args.gpu is not None:
194 | torch.cuda.set_device(args.gpu)
195 | model.cuda(args.gpu)
196 | # When using a single GPU per process and per
197 | # DistributedDataParallel, we need to divide the batch size
198 | # ourselves based on the total number of GPUs we have
199 | args.batch_size = int(args.batch_size / ngpus_per_node)
200 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
201 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
202 | else:
203 | model.cuda()
204 | # DistributedDataParallel will divide and allocate batch_size to all
205 | # available GPUs if device_ids are not set
206 | model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
207 | elif args.gpu is not None:
208 | torch.cuda.set_device(args.gpu)
209 | model = model.cuda(args.gpu)
210 | # comment out the following line for debugging
211 | raise NotImplementedError("Only DistributedDataParallel is supported.")
212 | else:
213 | # AllGather implementation (batch shuffle, queue update, etc.) in
214 | # this code only supports DistributedDataParallel.
215 | raise NotImplementedError("Only DistributedDataParallel is supported.")
216 |
217 | # define loss function (criterion) and optimizer
218 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
219 |
220 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
221 | momentum=args.momentum,
222 | weight_decay=args.weight_decay)
223 |
224 | # optionally resume from a checkpoint
225 | if args.resume:
226 | if os.path.isfile(args.resume):
227 | print("=> loading checkpoint '{}'".format(args.resume))
228 | if args.gpu is None:
229 | checkpoint = torch.load(args.resume)
230 | else:
231 | # Map model to be loaded to specified single gpu.
232 | loc = 'cuda:{}'.format(args.gpu)
233 | checkpoint = torch.load(args.resume, map_location=loc)
234 | args.start_epoch = checkpoint['epoch']
235 | model.load_state_dict(checkpoint['state_dict'])
236 | optimizer.load_state_dict(checkpoint['optimizer'])
237 | print("=> loaded checkpoint '{}' (epoch {})"
238 | .format(args.resume, checkpoint['epoch']))
239 | else:
240 | print("=> no checkpoint found at '{}'".format(args.resume))
241 |
242 | cudnn.benchmark = True
243 |
244 | # Data loading code
245 | traindir = os.path.join(args.data, 'train')
246 | valdir = os.path.join(args.data, 'test')
247 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
248 | std=[0.229, 0.224, 0.225])
249 | if args.aug_plus:
250 | # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
251 | augmentation = [
252 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
253 | transforms.RandomApply([
254 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
255 | ], p=0.8),
256 | transforms.RandomGrayscale(p=0.2),
257 | transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.5),
258 | transforms.RandomHorizontalFlip(),
259 | transforms.ToTensor(),
260 | normalize
261 | ]
262 | else:
263 | # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978
264 | augmentation = [
265 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
266 | transforms.RandomGrayscale(p=0.2),
267 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
268 | transforms.RandomHorizontalFlip(),
269 | transforms.ToTensor(),
270 | normalize
271 | ]
272 |
273 | train_dataset = datasets.ImageFolder(
274 | traindir,
275 | moco.loader.TwoCropsTransform(transforms.Compose(augmentation)))
276 |
277 | if args.distributed:
278 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
279 | else:
280 | train_sampler = None
281 |
282 | train_loader = torch.utils.data.DataLoader(
283 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
284 | num_workers=args.workers, pin_memory=False, sampler=train_sampler, drop_last=True)
285 |
286 | val_loader = torch.utils.data.DataLoader(
287 | datasets.ImageFolder(valdir, transforms.Compose([
288 | transforms.Resize(256),
289 | transforms.CenterCrop(224),
290 | transforms.ToTensor(),
291 | normalize,
292 | ])),
293 | batch_size=args.batch_size, shuffle=False,
294 | num_workers=args.workers, pin_memory=False)
295 |
296 | if args.evaluate:
297 | validate(val_loader, model, criterion, args)
298 | return
299 |
300 |     indx = torch.arange(256)  # fixed index tensor handed to model.forward each step
301 | indx_train = indx
302 | for epoch in range(args.start_epoch, args.epochs):
303 | if args.distributed:
304 | train_sampler.set_epoch(epoch)
305 | adjust_learning_rate(optimizer, epoch, args)
306 |
307 | # train for one epoch
308 |
309 | backbone = train(train_loader, model, criterion, optimizer, epoch, args, indx_train)
310 | # evaluate on validation set
311 | top1, top5, mAP = validate(val_loader, model, criterion, args, epoch)
312 |
313 | best_top1 = max(top1, best_top1)
314 |
315 | best_top5 = max(top5, best_top5)
316 |
317 | best_mAP = max(mAP, best_mAP)
318 |
319 |
320 | print("best Top@1: %.4f" % (best_top1))
321 | print("best Top@5: %.4f" % (best_top5))
322 | print("best mAP: %.4f" % (best_mAP))
323 |
324 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
325 | and args.rank % ngpus_per_node == 0):
326 | save_checkpoint({
327 | 'epoch': epoch + 1,
328 | 'arch': args.arch,
329 | 'state_dict': model.state_dict(),
330 | 'optimizer': optimizer.state_dict(),
331 | }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch))
332 |
333 |
334 | def train(train_loader, model, criterion, optimizer, epoch, args, indx):
335 | batch_time = AverageMeter('Time', ':6.3f')
336 | data_time = AverageMeter('Data', ':6.3f')
337 | losses = AverageMeter('Loss', ':.4e')
338 | top1 = AverageMeter('Acc@1', ':6.2f')
339 | top5 = AverageMeter('Acc@5', ':6.2f')
340 | progress = ProgressMeter(
341 | len(train_loader),
342 | [batch_time, data_time, losses, top1, top5],
343 | prefix="Epoch: [{}]".format(epoch))
344 | # switch to train mode
345 | model.train()
346 |
347 | end = time.time()
348 | for i, (images, _) in enumerate(train_loader):
349 | # measure data loading time
350 | data_time.update(time.time() - end)
351 |
352 | if args.gpu is not None:
353 | images[0] = images[0].cuda(args.gpu, non_blocking=True)
354 | images[1] = images[1].cuda(args.gpu, non_blocking=True)
355 |
356 | output, target, output_max, target_max, backbone, featcov16, featmap, att_max, gradcam = model(im_q=images[0], im_k=images[1], epoch=epoch, iter=i, indx=indx)
357 |
358 |
359 |
360 | loss = criterion(output, target)
361 | loss_cam = F.kl_div(att_max.softmax(dim=-1).log(), gradcam.softmax(dim=-1),
362 | reduction='sum')
363 |
364 | total_loss = loss + args.nu * loss_cam
365 |
366 | # acc1/acc5 # (K+1)-way contrast classifier accuracy
367 | # measure accuracy and record loss
368 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
369 | losses.update(loss.item(), images[0].size(0))
370 | top1.update(acc1[0], images[0].size(0))
371 | top5.update(acc5[0], images[0].size(0))
372 |
373 | # compute gradient and do SGD step
374 | optimizer.zero_grad()
375 | total_loss.backward()
376 | optimizer.step()
377 |
378 | # measure elapsed time
379 | batch_time.update(time.time() - end)
380 | end = time.time()
381 |
382 | if i % args.print_freq == 0:
383 | progress.display(i)
384 | return backbone
385 |
386 |
387 |
388 |
389 | def validate(val_loader, model, criterion, args, epoch):
390 | batch_time = AverageMeter('Time', ':6.3f')
391 | losses = AverageMeter('Loss', ':.4e')
392 | top1 = AverageMeter('top1', ':6.2f')
393 | top5 = AverageMeter('top5', ':6.2f')
394 | progress = ProgressMeter(
395 | len(val_loader),
396 | [batch_time, losses, top1, top5],
397 | prefix='Test: ')
398 |
399 | # switch to evaluate mode
400 | model.eval()
401 |
402 | with torch.no_grad():
403 | end = time.time()
404 |         data = torch.zeros(1, 2048)         # placeholder first rows; stripped
405 |         label = torch.zeros(1)              # after the loop via the "!= 0" masks
406 |         att_map = torch.zeros(1, 32, 7, 7)
407 | for i, (images, target) in enumerate(val_loader):
408 | if args.gpu is not None:
409 | images = images.cuda(args.gpu, non_blocking=True)
410 | target = target.cuda(args.gpu, non_blocking=True)
411 |
412 | # compute output
413 | _, _, featcov16, output = model.module.inference(images)
414 |
415 | data = torch.cat((data.cuda(args.gpu), output.cuda(args.gpu)), 0)
416 |
417 | label = torch.cat((label.cuda(args.gpu), target), 0)
418 |
419 | att_map = torch.cat((att_map.cuda(args.gpu), featcov16.cuda(args.gpu)), 0)
420 |
421 |
422 | data = data[torch.arange(data.size(0)) != 0]
423 | label = label[torch.arange(label.size(0)) != 0]
424 |
425 |
426 | att_map = att_map[torch.arange(att_map.size(0)) != 0]
427 | max_pre = torch.var(att_map, dim=(2, 3), keepdim=False)
428 | max_pre = torch.mean(max_pre, dim=0, keepdim=False)
429 | a, idx1 = torch.sort(max_pre, descending=True)
430 |
431 |
432 | topN1 = []
433 | topN5 = []
434 | MAP = []
435 | for j in range(data.size(0)):
436 | query_feat = data[j, :]
437 | query_label = label[j].item()
438 |
439 |             gallery = data[torch.arange(data.size(0)) != j]  # all features except the query
440 |             sim_label = label[torch.arange(label.size(0)) != j]
441 |
442 |             similarity = torch.mv(gallery, query_feat)
443 |
444 | table = torch.zeros(similarity.size(0), 2)
445 | table[:, 0] = similarity
446 | table[:, 1] = sim_label
447 | table = table.cpu().detach().numpy()
448 |
449 | index = np.argsort(table[:, 0])[::-1]
450 |
451 | T = table[index]
452 | #top-1
453 | if T[0,1] == query_label:
454 | topN1.append(1)
455 | else:
456 | topN1.append(0)
457 | #top-5
458 | if np.sum(T[:5, -1] == query_label) > 0:
459 | topN5.append(1)
460 | else:
461 | topN5.append(0)
462 |
463 | #mAP
464 | check = np.where(T[:, 1] == query_label)
465 | check = check[0]
466 | AP = 0
467 | for k in range(len(check)):
468 | temp = (k+1)/(check[k]+1)
469 | AP = AP + temp
470 |                 AP = AP / len(check)  # average precision for this query
471 | MAP.append(AP)
472 |
473 | top1 = np.mean(topN1)
474 | top5 = np.mean(topN5)
475 | mAP = np.mean(MAP)
476 |
477 | # TODO: this should also be done with the ProgressMeter
478 | print(' * Top@1 {top1:.3f} Top@5 {top5:.3f} mAP {mAP:.3f}'
479 | .format(top1=top1, top5=top5, mAP=mAP))
480 |
481 | return top1, top5, mAP
482 |
483 |
484 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
485 | torch.save(state, filename)
486 | if is_best:
487 | shutil.copyfile(filename, 'model_best.pth.tar')
488 |
489 |
490 | class AverageMeter(object):
491 | """Computes and stores the average and current value"""
492 |
493 | def __init__(self, name, fmt=':f'):
494 | self.name = name
495 | self.fmt = fmt
496 | self.reset()
497 |
498 | def reset(self):
499 | self.val = 0
500 | self.avg = 0
501 | self.sum = 0
502 | self.count = 0
503 |
504 | def update(self, val, n=1):
505 | self.val = val
506 | self.sum += val * n
507 | self.count += n
508 | self.avg = self.sum / self.count
509 |
510 | def __str__(self):
511 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
512 | return fmtstr.format(**self.__dict__)
513 |
514 |
515 | class ProgressMeter(object):
516 | def __init__(self, num_batches, meters, prefix=""):
517 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
518 | self.meters = meters
519 | self.prefix = prefix
520 |
521 | def display(self, batch):
522 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
523 | entries += [str(meter) for meter in self.meters]
524 | print('\t'.join(entries))
525 |
526 | def _get_batch_fmtstr(self, num_batches):
527 |         num_digits = len(str(num_batches))
528 | fmt = '{:' + str(num_digits) + 'd}'
529 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
530 |
531 |
532 | def adjust_learning_rate(optimizer, epoch, args):
533 | """Decay the learning rate based on schedule"""
534 | lr = args.lr
535 | if args.cos: # cosine lr schedule
536 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
537 | else: # stepwise lr schedule
538 | for milestone in args.schedule:
539 | lr *= 0.1 if epoch >= milestone else 1.
540 | for param_group in optimizer.param_groups:
541 | param_group['lr'] = lr
542 |
543 |
544 | def accuracy(output, target, topk=(1,)):
545 | """Computes the accuracy over the k top predictions for the specified values of k"""
546 | with torch.no_grad():
547 | maxk = max(topk)
548 | batch_size = target.size(0)
549 |
550 | _, pred = output.topk(maxk, 1, True, True)
551 | pred = pred.t()
552 | correct = pred.eq(target.view(1, -1).expand_as(pred))
553 |
554 | res = []
555 | for k in topk:
556 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
557 | res.append(correct_k.mul_(100.0 / batch_size))
558 | return res
559 |
560 |
561 | if __name__ == '__main__':
562 | main()
563 |
--------------------------------------------------------------------------------
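The training loop above optimizes two terms: the standard contrastive cross-entropy on the MoCo logits, plus a KL term (weighted by `args.nu`) that pulls the learned attention map `att_max` toward the Grad-CAM map of the contrastive loss. A minimal, self-contained sketch of that combination; the shapes and the `nu` value are illustrative assumptions, not values from the repository's configs:

```
import torch
import torch.nn.functional as F

batch, queue_size, h, w = 8, 128, 7, 7
nu = 1.0                                       # illustrative weight (args.nu)

logits = torch.randn(batch, 1 + queue_size)    # positive logit + K negatives
target = torch.zeros(batch, dtype=torch.long)  # the positive is always index 0
att_max = torch.rand(batch, h, w)              # learned attention map
gradcam = torch.rand(batch, h, w)              # Grad-CAM of the contrastive loss

ce = F.cross_entropy(logits, target)
# F.kl_div(log_input, target) penalizes attention that diverges from the
# Grad-CAM distribution (softmax over the last spatial axis, as in train()).
kl = F.kl_div(att_max.softmax(dim=-1).log(), gradcam.softmax(dim=-1),
              reduction='sum')
total_loss = ce + nu * kl
```

And a tiny worked example of the per-query average precision computed in validate(): if the query's class appears at ranks 1 and 4 of the sorted gallery (0-based hit indices 0 and 3), then AP = ((1/1) + (2/4)) / 2 = 0.75.

```
import numpy as np

ranked_labels = np.array([1, 0, 0, 1, 0])  # 1 = same class as the query
check = np.where(ranked_labels == 1)[0]    # hit positions: [0, 3]
AP = sum((k + 1) / (check[k] + 1) for k in range(len(check))) / len(check)
print(AP)  # 0.75
```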
/moco/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
--------------------------------------------------------------------------------
/moco/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/moco/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/moco/__pycache__/builder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/moco/__pycache__/builder.cpython-38.pyc
--------------------------------------------------------------------------------
/moco/__pycache__/loader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/moco/__pycache__/loader.cpython-38.pyc
--------------------------------------------------------------------------------
/moco/builder.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from resnet_output import resnet50
5 | import numpy as np
6 |
7 |
8 | class MoCo(nn.Module):
9 | """
10 | Build a MoCo model with: a query encoder, a key encoder, and a queue
11 | https://arxiv.org/abs/1911.05722
12 | """
13 |
14 | def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
15 | """
16 | dim: feature dimension (default: 128)
17 | K: queue size; number of negative keys (default: 65536)
18 | m: moco momentum of updating key encoder (default: 0.999)
19 | T: softmax temperature (default: 0.07)
20 | """
21 | super(MoCo, self).__init__()
22 |
23 | self.K = K
24 | self.m = m
25 | self.T = T
26 |
27 | # create the encoders
28 | self.encoder_q = resnet50(pretrained=True, num_classes=1000)
29 | # self.encoder_q = base_encoder(num_classes=dim)
30 | self.encoder_k = resnet50(pretrained=True, num_classes=1000)
31 | # self.encoder_k = base_encoder(num_classes=dim)
32 |
33 | self.encoder_q.fc = nn.Linear(2048, dim)
34 | self.encoder_k.fc = nn.Linear(2048, dim)
35 |
36 | if mlp: # hack: brute-force replacement
37 | dim_mlp = self.encoder_q.fc.weight.shape[1]
38 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
39 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)
40 |
41 |
42 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
43 | param_k.data.copy_(param_q.data) # initialize
44 | param_k.requires_grad = False # not update by gradient
45 |
46 | # create the queue
47 | self.register_buffer("queue", torch.randn(dim,
48 | K))
49 | self.queue = nn.functional.normalize(self.queue, dim=0)
50 |
51 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
52 |
53 | # create the queue_max
54 | self.register_buffer("queue_max", torch.randn(dim,
55 | K))
56 | self.queue_max = nn.functional.normalize(self.queue_max, dim=0)
57 |
58 | self.register_buffer("queue_ptr_max", torch.zeros(1, dtype=torch.long))
59 | # add projection
60 |
61 | self.bilinear = 32
62 |
63 | self.conv16 = nn.Conv2d(2048, self.bilinear, kernel_size=1, stride=1, padding=0,
64 | bias=False)
65 |
66 | self.bn16 = nn.BatchNorm2d(self.bilinear)
67 |
68 | self.conv16_2 = nn.Conv2d(self.bilinear, 1, kernel_size=1, stride=1, padding=0,
69 | bias=False)
70 | self.bn16_2 = nn.BatchNorm2d(1)
71 |
72 |
73 | self.relu = nn.ReLU(inplace=True)
74 |
75 | self.avgpool = nn.AvgPool2d(7, stride=1)
76 |
77 | self.fc = nn.Linear(2048, dim)
78 | self.qmax = nn.Sequential(nn.Linear(2048, 2048), nn.ReLU(), self.fc)
79 | self.kmax = nn.Sequential(nn.Linear(2048, 2048), nn.ReLU(), self.fc)
80 |
81 |
82 |
83 | @torch.no_grad()
84 | def _momentum_update_key_encoder(self):
85 | """
86 | Momentum update of the key encoder
87 | """
88 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
89 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
90 |
91 | for param_q, param_k in zip(self.qmax.parameters(), self.kmax.parameters()):
92 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
93 |
94 | @torch.no_grad()
95 | def _dequeue_and_enqueue(self, keys):
96 | # gather keys before updating queue
97 | keys = concat_all_gather(keys)
98 |
99 | batch_size = keys.shape[0]
100 |
101 | ptr = int(self.queue_ptr)
102 | assert self.K % batch_size == 0 # for simplicity
103 |
104 | # replace the keys at ptr (dequeue and enqueue)
105 | self.queue[:, ptr:ptr + batch_size] = keys.T
106 | ptr = (ptr + batch_size) % self.K # move pointer
107 |
108 | self.queue_ptr[0] = ptr
109 |
110 | def _dequeue_and_enqueue_max(self, keys):
111 | # gather keys before updating queue
112 | keys = concat_all_gather(keys)
113 |
114 | batch_size = keys.shape[0]
115 |
116 | ptr = int(self.queue_ptr_max)
117 | assert self.K % batch_size == 0 # for simplicity
118 |
119 | # replace the keys at ptr (dequeue and enqueue)
120 | self.queue_max[:, ptr:ptr + batch_size] = keys.T
121 | ptr = (ptr + batch_size) % self.K # move pointer
122 |
123 | self.queue_ptr_max[0] = ptr
124 |
125 | def max_mask(self, featmap, indx):
126 | featcov16 = self.conv16(featmap)
127 | featcov16 = self.bn16(featcov16)
128 | featcov16 = self.relu(featcov16)
129 |
130 |
131 |
132 | img, _ = torch.max(featcov16, axis=1)
133 | img = img - torch.min(img)
134 | att_max = img / (1e-7 + torch.max(img))
135 |
136 | img = att_max[:, None, :, :]
137 | img = img.repeat(1, 2048, 1, 1)
138 |
139 |
140 |
141 | PFM = featmap.cuda() * img.cuda()
142 | aa = self.avgpool(PFM)
143 | bp_out_feat = aa.view(aa.size(0), -1)
144 | bp_out_feat_max = nn.functional.normalize(bp_out_feat, dim=1)
145 |
146 |
147 |
148 | return bp_out_feat_max, att_max
149 |
150 | def feat_bilinear(self, featmap):
151 | featcov16 = self.conv16(featmap)
152 | featcov16 = self.bn16(featcov16)
153 | featcov16 = self.relu(featcov16)
154 |
155 | feat_matrix = torch.zeros(featcov16.size(0), self.bilinear, 2048)
156 | for i in range(self.bilinear):
157 | matrix = featcov16[:, i, :, :]
158 | matrix = matrix[:, None, :, :]
159 | matrix = matrix.repeat(1, 2048, 1, 1)
160 | PFM = featmap * matrix
161 | aa = self.avgpool(PFM)
162 |
163 | feat_matrix[:, i, :] = aa.view(aa.size(0), -1)
164 |
165 | bp_out_feat = feat_matrix.view(feat_matrix.size(0), -1)
166 |
167 | return bp_out_feat
168 |
169 | @torch.no_grad()
170 | def _batch_shuffle_ddp(self, x):
171 | """
172 | Batch shuffle, for making use of BatchNorm.
173 | *** Only support DistributedDataParallel (DDP) model. ***
174 | """
175 | # gather from all gpus
176 | batch_size_this = x.shape[0]
177 | x_gather = concat_all_gather(x)
178 | batch_size_all = x_gather.shape[0]
179 |
180 | num_gpus = batch_size_all // batch_size_this
181 |
182 | # random shuffle index
183 | idx_shuffle = torch.randperm(batch_size_all).cuda()
184 |
185 | # broadcast to all gpus
186 | torch.distributed.broadcast(idx_shuffle, src=0)
187 |
188 | # index for restoring
189 | idx_unshuffle = torch.argsort(idx_shuffle)
190 |
191 | # shuffled index for this gpu
192 | gpu_idx = torch.distributed.get_rank()
193 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
194 |
195 | return x_gather[idx_this], idx_unshuffle
196 |
197 | @torch.no_grad()
198 | def _batch_unshuffle_ddp(self, x, idx_unshuffle):
199 | """
200 | Undo batch shuffle.
201 | *** Only support DistributedDataParallel (DDP) model. ***
202 | """
203 | # gather from all gpus
204 | batch_size_this = x.shape[0]
205 | x_gather = concat_all_gather(x)
206 | batch_size_all = x_gather.shape[0]
207 |
208 | num_gpus = batch_size_all // batch_size_this
209 |
210 | # restored index for this gpu
211 | gpu_idx = torch.distributed.get_rank()
212 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
213 |
214 | return x_gather[idx_this]
215 |
216 | def forward(self, im_q, im_k, epoch, iter, indx):
217 | """
218 | Input:
219 | im_q: a batch of query images
220 | im_k: a batch of key images
221 | Output:
222 | logits, targets
223 | """
224 |
225 | # compute query features
226 | q, _, featmap = self.encoder_q(im_q) # queries: NxC
227 | q = nn.functional.normalize(q, dim=1)
228 |
229 |
230 | # max bilinear q
231 | q_max, att_max = self.max_mask(featmap, indx)
232 | embedding_q = self.qmax(q_max.cuda())
233 | q_max_proj = nn.functional.normalize(embedding_q, dim=1)
234 |
235 | # compute key features
236 | with torch.no_grad(): # no gradient to keys
237 | self._momentum_update_key_encoder() # update the key encoder
238 |
239 | # shuffle for making use of BN
240 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
241 |
242 | k, _, featmap_k = self.encoder_k(im_k) # keys: NxC
243 | k = nn.functional.normalize(k, dim=1)
244 |
245 | # max bilinear k
246 | k_max, _ = self.max_mask(featmap_k, indx)
247 | embedding_k = self.kmax(k_max.cuda())
248 | k_max_proj = nn.functional.normalize(embedding_k, dim=1)
249 |
250 | # undo shuffle
251 | k = self._batch_unshuffle_ddp(k, idx_unshuffle).detach()
252 |
253 | k_max_proj = self._batch_unshuffle_ddp(k_max_proj, idx_unshuffle).detach()
254 |
255 | # compute logits
256 | # Einstein sum is more intuitive
257 | # positive logits: Nx1
258 |
259 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
260 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
261 |
262 | logits = torch.cat([l_pos, l_neg], dim=1) / self.T
263 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
264 |
265 | #max
266 | l_pos_max = torch.einsum('nc,nc->n', [q_max_proj, k_max_proj]).unsqueeze(-1)
267 | l_neg_max = torch.einsum('nc,ck->nk', [q_max_proj, self.queue_max.clone().detach()])
268 |
269 | logits_max = torch.cat([l_pos_max, l_neg_max], dim=1) / self.T
270 | labels_max = torch.zeros(logits_max.shape[0], dtype=torch.long).cuda()
271 |
272 | # dequeue and enqueue
273 | self._dequeue_and_enqueue(k)
274 |
275 | self._dequeue_and_enqueue_max(k_max_proj)
276 |
277 | featcov16 = self.conv16(featmap)
278 | featcov16 = self.bn16(featcov16)
279 | featcov16 = self.relu(featcov16)
280 |
281 | criterion = nn.CrossEntropyLoss()
282 | CE = criterion(logits, labels)
283 |
284 | grad_wrt_act1 = torch.autograd.grad(outputs=CE, inputs=featmap,
285 | grad_outputs=torch.ones_like(CE), retain_graph=True,
286 | allow_unused=True)[0]
287 |
288 | gradcam = torch.relu((featmap * grad_wrt_act1).sum(dim=1))
289 |
290 |
291 | return logits, labels, logits_max, labels_max, self.encoder_q, featcov16, featmap, att_max, gradcam
292 |
293 | def inference(self, img):
294 | projfeat, feat, featmap = self.encoder_q(img)
295 |
296 | featcov16 = self.conv16(featmap)
297 | featcov16 = self.bn16(featcov16)
298 | featcov16 = self.relu(featcov16)
299 |
300 |
301 | img = featcov16.cpu().detach().numpy()
302 | img = np.max(img, axis=1)
303 | img = img - np.min(img)
304 | img = img / (1e-7 + np.max(img))
305 | img = torch.from_numpy(img)
306 |
307 |
308 |
309 | img = img[:, None, :, :]
310 | img = img.repeat(1, 2048, 1, 1)
311 | PFM = featmap.cuda() * img.cuda()
312 | aa = self.avgpool(PFM)
313 | bp_out_feat = aa.view(aa.size(0), -1)
314 | bp_out_feat = nn.functional.normalize(bp_out_feat, dim=1)
315 |
316 | feat = nn.functional.normalize(feat, dim=1)
317 |
318 | return projfeat, feat, featcov16, bp_out_feat
319 |
320 |
321 | # utils
322 | @torch.no_grad()
323 | def concat_all_gather(tensor):
324 | """
325 | Performs all_gather operation on the provided tensors.
326 | *** Warning ***: torch.distributed.all_gather has no gradient.
327 | """
328 | tensors_gather = [torch.ones_like(tensor)
329 | for _ in range(torch.distributed.get_world_size())]
330 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
331 |
332 | output = torch.cat(tensors_gather, dim=0)
333 | return output
334 |
--------------------------------------------------------------------------------
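`_dequeue_and_enqueue` above maintains the MoCo negative queue as a ring buffer: each batch of keys overwrites the oldest `batch_size` columns and the pointer advances modulo `K` (hence the `K % batch_size == 0` assertion). A toy single-process sketch of the same bookkeeping, with illustrative sizes and without the multi-GPU all-gather:

```
import torch
import torch.nn.functional as F

dim, K, batch = 4, 8, 2
queue = F.normalize(torch.randn(dim, K), dim=0)
ptr = 0

def enqueue(keys):
    global ptr
    assert K % keys.shape[0] == 0               # same simplification as the source
    queue[:, ptr:ptr + keys.shape[0]] = keys.T  # overwrite the oldest slots
    ptr = (ptr + keys.shape[0]) % K             # advance the pointer, wrap around

for _ in range(5):
    enqueue(F.normalize(torch.randn(batch, dim), dim=1))
print(ptr)  # 2: the pointer wrapped after all 8 slots were filled once
```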
/moco/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from PIL import ImageFilter
3 | import random
4 |
5 |
6 | class TwoCropsTransform:
7 | """Take two random crops of one image as the query and key."""
8 |
9 | def __init__(self, base_transform):
10 | self.base_transform = base_transform
11 |
12 | def __call__(self, x):
13 | q = self.base_transform(x)
14 | k = self.base_transform(x)
15 | return [q, k]
16 |
17 |
18 | class GaussianBlur(object):
19 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
20 |
21 | def __init__(self, sigma=[.1, 2.]):
22 | self.sigma = sigma
23 |
24 | def __call__(self, x):
25 | sigma = random.uniform(self.sigma[0], self.sigma[1])
26 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
27 | return x
28 |
--------------------------------------------------------------------------------
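A short usage sketch for the two classes above, assuming a torchvision-style pipeline; the crop size, blur probability, and other parameters are illustrative, not taken from the repository's configs:

```
from PIL import Image
import torchvision.transforms as transforms
from moco.loader import TwoCropsTransform, GaussianBlur

base = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

img = Image.new('RGB', (256, 256))   # stand-in image
q, k = TwoCropsTransform(base)(img)  # two independent views of the same image
print(q.shape, k.shape)              # torch.Size([3, 224, 224]) twice
```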
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .byol import BYOL
2 | from torchvision.models import resnet50, resnet18
3 | import torch
4 | from .backbones import resnet50_cub200,resnet50_stanfordcars, resnet18_cifar_variant2,resnet18_cifar_variant1,resnet50_aircrafts
5 |
6 | def get_backbone(backbone, castrate=True):
7 | backbone = eval(f"{backbone}(pretrained=True)")
8 |
9 | if castrate:
10 | backbone.output_dim = backbone.fc.in_features
11 | backbone.fc = torch.nn.Identity()
12 |
13 | return backbone
14 |
15 |
16 | def get_model(model_cfg):
17 | if model_cfg.name == 'byol':
18 | model = BYOL(get_backbone(model_cfg.backbone))
19 | else:
20 | raise NotImplementedError
21 | return model
--------------------------------------------------------------------------------
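A hedged usage sketch for `get_backbone` above: with `castrate=True` the classification head is replaced by `nn.Identity`, so the backbone returns pooled features, and `output_dim` records their width for the BYOL projector. Note that `get_backbone` always builds the network with `pretrained=True`, so the first call downloads a torchvision checkpoint:

```
import torch
from models import get_backbone

backbone = get_backbone('resnet50')      # evaluated as resnet50(pretrained=True)
x = torch.randn(2, 3, 224, 224)
feats = backbone(x)
print(backbone.output_dim, feats.shape)  # 2048 torch.Size([2, 2048])
```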
/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/byol.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/__pycache__/byol.cpython-38.pyc
--------------------------------------------------------------------------------
/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .cifar_resnet_1 import resnet18 as resnet18_cifar_variant1
2 | from .cifar_resnet_2 import ResNet18 as resnet18_cifar_variant2
3 |
4 | from .cub_resnet_1 import resnet50 as resnet50_cub200
5 | from .cub_resnet_1 import resnet50 as resnet50_stanfordcars
6 | from .cub_resnet_1 import resnet50 as resnet50_aircrafts
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/models/backbones/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/backbones/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/backbones/__pycache__/cifar_resnet_1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/backbones/__pycache__/cifar_resnet_1.cpython-38.pyc
--------------------------------------------------------------------------------
/models/backbones/__pycache__/cifar_resnet_2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/backbones/__pycache__/cifar_resnet_2.cpython-38.pyc
--------------------------------------------------------------------------------
/models/backbones/__pycache__/cub_resnet_1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/models/backbones/__pycache__/cub_resnet_1.cpython-38.pyc
--------------------------------------------------------------------------------
/models/backbones/cifar_resnet_1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import torch.utils.model_zoo as model_zoo
5 | # https://raw.githubusercontent.com/huyvnphan/PyTorch_CIFAR10/master/cifar10_models/resnet.py
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | }
16 |
17 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
18 | """3x3 convolution with padding"""
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20 | padding=dilation, groups=groups, bias=False, dilation=dilation)
21 |
22 |
23 | def conv1x1(in_planes, out_planes, stride=1):
24 | """1x1 convolution"""
25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
26 |
27 |
28 | class BasicBlock(nn.Module):
29 | expansion = 1
30 |
31 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
32 | base_width=64, dilation=1, norm_layer=None):
33 | super(BasicBlock, self).__init__()
34 | if norm_layer is None:
35 | norm_layer = nn.BatchNorm2d
36 | if groups != 1 or base_width != 64:
37 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
38 | if dilation > 1:
39 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
40 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
41 | self.conv1 = conv3x3(inplanes, planes, stride)
42 | self.bn1 = norm_layer(planes)
43 | self.relu = nn.ReLU(inplace=True)
44 | self.conv2 = conv3x3(planes, planes)
45 | self.bn2 = norm_layer(planes)
46 | self.downsample = downsample
47 | self.stride = stride
48 |
49 | def forward(self, x):
50 | identity = x
51 |
52 | out = self.conv1(x)
53 | out = self.bn1(out)
54 | out = self.relu(out)
55 |
56 | out = self.conv2(out)
57 | out = self.bn2(out)
58 |
59 | if self.downsample is not None:
60 | identity = self.downsample(x)
61 |
62 | out += identity
63 | out = self.relu(out)
64 |
65 | return out
66 |
67 |
68 | class Bottleneck(nn.Module):
69 | expansion = 4
70 |
71 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
72 | base_width=64, dilation=1, norm_layer=None):
73 | super(Bottleneck, self).__init__()
74 | if norm_layer is None:
75 | norm_layer = nn.BatchNorm2d
76 | width = int(planes * (base_width / 64.)) * groups
77 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
78 | self.conv1 = conv1x1(inplanes, width)
79 | self.bn1 = norm_layer(width)
80 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
81 | self.bn2 = norm_layer(width)
82 | self.conv3 = conv1x1(width, planes * self.expansion)
83 | self.bn3 = norm_layer(planes * self.expansion)
84 | self.relu = nn.ReLU(inplace=True)
85 | self.downsample = downsample
86 | self.stride = stride
87 |
88 | def forward(self, x):
89 | identity = x
90 |
91 | out = self.conv1(x)
92 | out = self.bn1(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv2(out)
96 | out = self.bn2(out)
97 | out = self.relu(out)
98 |
99 | out = self.conv3(out)
100 | out = self.bn3(out)
101 |
102 | if self.downsample is not None:
103 | identity = self.downsample(x)
104 |
105 | out += identity
106 | out = self.relu(out)
107 |
108 | return out
109 |
110 |
111 | class ResNet(nn.Module):
112 |
113 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
114 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
115 | norm_layer=None):
116 | super(ResNet, self).__init__()
117 | if norm_layer is None:
118 | norm_layer = nn.BatchNorm2d
119 | self._norm_layer = norm_layer
120 |
121 | self.inplanes = 64
122 | self.dilation = 1
123 | if replace_stride_with_dilation is None:
124 | # each element in the tuple indicates if we should replace
125 | # the 2x2 stride with a dilated convolution instead
126 | replace_stride_with_dilation = [False, False, False]
127 | if len(replace_stride_with_dilation) != 3:
128 | raise ValueError("replace_stride_with_dilation should be None "
129 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
130 | self.groups = groups
131 | self.base_width = width_per_group
132 |
133 |         ## NOTE: the usual CIFAR10 stem change (kernel_size 7 -> 3, stride 2 -> 1, padding 3 -> 1)
134 |         ## is NOT applied here; the ImageNet 7x7 stride-2 stem is kept, and the maxpool below is commented out instead.
135 |         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
136 |
137 | self.bn1 = norm_layer(self.inplanes)
138 | self.relu = nn.ReLU(inplace=True)
139 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
140 | self.layer1 = self._make_layer(block, 64, layers[0])
141 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
142 | dilate=replace_stride_with_dilation[0])
143 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
144 | dilate=replace_stride_with_dilation[1])
145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
146 | dilate=replace_stride_with_dilation[2])
147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
148 | self.fc = nn.Linear(512 * block.expansion, num_classes)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
154 | nn.init.constant_(m.weight, 1)
155 | nn.init.constant_(m.bias, 0)
156 |
157 | # Zero-initialize the last BN in each residual branch,
158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
159 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
160 | if zero_init_residual:
161 | for m in self.modules():
162 | if isinstance(m, Bottleneck):
163 | nn.init.constant_(m.bn3.weight, 0)
164 | elif isinstance(m, BasicBlock):
165 | nn.init.constant_(m.bn2.weight, 0)
166 |
167 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
168 | norm_layer = self._norm_layer
169 | downsample = None
170 | previous_dilation = self.dilation
171 | if dilate:
172 | self.dilation *= stride
173 | stride = 1
174 | if stride != 1 or self.inplanes != planes * block.expansion:
175 | downsample = nn.Sequential(
176 | conv1x1(self.inplanes, planes * block.expansion, stride),
177 | norm_layer(planes * block.expansion),
178 | )
179 |
180 | layers = []
181 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
182 | self.base_width, previous_dilation, norm_layer))
183 | self.inplanes = planes * block.expansion
184 | for _ in range(1, blocks):
185 | layers.append(block(self.inplanes, planes, groups=self.groups,
186 | base_width=self.base_width, dilation=self.dilation,
187 | norm_layer=norm_layer))
188 |
189 | return nn.Sequential(*layers)
190 |
191 | def forward(self, x):
192 | x = self.conv1(x)
193 | x = self.bn1(x)
194 | x = self.relu(x)
195 | # x = self.maxpool(x)
196 |
197 | x = self.layer1(x)
198 | x = self.layer2(x)
199 | x = self.layer3(x)
200 | x = self.layer4(x)
201 |
202 | x = self.avgpool(x)
203 | x = x.reshape(x.size(0), -1)
204 | x = self.fc(x)
205 |
206 | return x
207 |
208 |
209 | def _resnet(arch, block, layers, pretrained, progress, device, **kwargs):
210 | model = ResNet(block, layers, **kwargs)
211 |
212 | if pretrained:
213 | print("pretrain resnet arch:")
214 | print(arch)
215 | model.load_state_dict(model_zoo.load_url(model_urls[arch]))
216 | # script_dir = os.path.dirname(__file__)
217 | # state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device)
218 | # model.load_state_dict(state_dict)
219 | return model
220 |
221 |
222 | def resnet18(pretrained=False, progress=True, device='cpu', **kwargs):
223 | """Constructs a ResNet-18 model.
224 |
225 | Args:
226 | pretrained (bool): If True, returns a model pre-trained on ImageNet
227 | progress (bool): If True, displays a progress bar of the download to stderr
228 | """
229 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, device,
230 | **kwargs)
231 |
232 |
233 | def resnet34(pretrained=False, progress=True, device='cpu', **kwargs):
234 | """Constructs a ResNet-34 model.
235 |
236 | Args:
237 | pretrained (bool): If True, returns a model pre-trained on ImageNet
238 | progress (bool): If True, displays a progress bar of the download to stderr
239 | """
240 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, device,
241 | **kwargs)
242 |
243 |
244 | def resnet50(pretrained=False, progress=True, device='cpu', **kwargs):
245 | """Constructs a ResNet-50 model.
246 |
247 | Args:
248 | pretrained (bool): If True, returns a model pre-trained on ImageNet
249 | progress (bool): If True, displays a progress bar of the download to stderr
250 | """
251 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, device,
252 | **kwargs)
253 |
254 |
255 | def resnet101(pretrained=False, progress=True, device='cpu', **kwargs):
256 | """Constructs a ResNet-101 model.
257 |
258 | Args:
259 | pretrained (bool): If True, returns a model pre-trained on ImageNet
260 | progress (bool): If True, displays a progress bar of the download to stderr
261 | """
262 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, device,
263 | **kwargs)
264 |
265 |
266 | def resnet152(pretrained=False, progress=True, device='cpu', **kwargs):
267 | """Constructs a ResNet-152 model.
268 |
269 | Args:
270 | pretrained (bool): If True, returns a model pre-trained on ImageNet
271 | progress (bool): If True, displays a progress bar of the download to stderr
272 | """
273 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, device,
274 | **kwargs)
275 |
276 |
277 | def resnext50_32x4d(pretrained=False, progress=True, device='cpu', **kwargs):
278 | """Constructs a ResNeXt-50 32x4d model.
279 |
280 | Args:
281 | pretrained (bool): If True, returns a model pre-trained on ImageNet
282 | progress (bool): If True, displays a progress bar of the download to stderr
283 | """
284 | kwargs['groups'] = 32
285 | kwargs['width_per_group'] = 4
286 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
287 | pretrained, progress, device, **kwargs)
288 |
289 |
290 | def resnext101_32x8d(pretrained=False, progress=True, device='cpu', **kwargs):
291 | """Constructs a ResNeXt-101 32x8d model.
292 |
293 | Args:
294 | pretrained (bool): If True, returns a model pre-trained on ImageNet
295 | progress (bool): If True, displays a progress bar of the download to stderr
296 | """
297 | kwargs['groups'] = 32
298 | kwargs['width_per_group'] = 8
299 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
300 | pretrained, progress, device, **kwargs)
301 |
302 |
303 | if __name__ == "__main__":
304 | model = resnet18()
305 | print(sum(p.numel() for p in model.parameters() if p.requires_grad)) # 11173962
--------------------------------------------------------------------------------
/models/backbones/cifar_resnet_2.py:
--------------------------------------------------------------------------------
1 | # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
2 | '''ResNet in PyTorch.
3 |
4 | For Pre-activation ResNet, see 'preact_resnet.py'.
5 |
6 | Reference:
7 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
8 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
9 | '''
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 |
18 | def __init__(self, in_planes, planes, stride=1):
19 | super(BasicBlock, self).__init__()
20 | self.conv1 = nn.Conv2d(
21 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
22 | self.bn1 = nn.BatchNorm2d(planes)
23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
24 | stride=1, padding=1, bias=False)
25 | self.bn2 = nn.BatchNorm2d(planes)
26 |
27 | self.shortcut = nn.Sequential()
28 | if stride != 1 or in_planes != self.expansion*planes:
29 | self.shortcut = nn.Sequential(
30 | nn.Conv2d(in_planes, self.expansion*planes,
31 | kernel_size=1, stride=stride, bias=False),
32 | nn.BatchNorm2d(self.expansion*planes)
33 | )
34 |
35 | def forward(self, x):
36 | out = F.relu(self.bn1(self.conv1(x)))
37 | out = self.bn2(self.conv2(out))
38 | out += self.shortcut(x)
39 | out = F.relu(out)
40 | return out
41 |
42 |
43 | class Bottleneck(nn.Module):
44 | expansion = 4
45 |
46 | def __init__(self, in_planes, planes, stride=1):
47 | super(Bottleneck, self).__init__()
48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(planes)
50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
51 | stride=stride, padding=1, bias=False)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 | self.conv3 = nn.Conv2d(planes, self.expansion *
54 | planes, kernel_size=1, bias=False)
55 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
56 |
57 | self.shortcut = nn.Sequential()
58 | if stride != 1 or in_planes != self.expansion*planes:
59 | self.shortcut = nn.Sequential(
60 | nn.Conv2d(in_planes, self.expansion*planes,
61 | kernel_size=1, stride=stride, bias=False),
62 | nn.BatchNorm2d(self.expansion*planes)
63 | )
64 |
65 | def forward(self, x):
66 | out = F.relu(self.bn1(self.conv1(x)))
67 | out = F.relu(self.bn2(self.conv2(out)))
68 | out = self.bn3(self.conv3(out))
69 | out += self.shortcut(x)
70 | out = F.relu(out)
71 | return out
72 |
73 |
74 | class ResNet(nn.Module):
75 | def __init__(self, block, num_blocks, num_classes=10):
76 | super(ResNet, self).__init__()
77 | self.in_planes = 64
78 |
79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
80 | stride=1, padding=1, bias=False)
81 | self.bn1 = nn.BatchNorm2d(64)
82 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
83 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
84 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
85 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
86 | self.fc = nn.Linear(512*block.expansion, num_classes)
87 |
88 | def _make_layer(self, block, planes, num_blocks, stride):
89 | strides = [stride] + [1]*(num_blocks-1)
90 | layers = []
91 | for stride in strides:
92 | layers.append(block(self.in_planes, planes, stride))
93 | self.in_planes = planes * block.expansion
94 | return nn.Sequential(*layers)
95 |
96 | def forward(self, x):
97 | out = F.relu(self.bn1(self.conv1(x)))
98 | out = self.layer1(out)
99 | out = self.layer2(out)
100 | out = self.layer3(out)
101 | out = self.layer4(out)
102 | out = F.avg_pool2d(out, 4)
103 | out = out.view(out.size(0), -1)
104 | out = self.fc(out)
105 | return out
106 |
107 |
108 | def ResNet18():
109 | return ResNet(BasicBlock, [2, 2, 2, 2])
110 |
111 |
112 | def ResNet34():
113 | return ResNet(BasicBlock, [3, 4, 6, 3])
114 |
115 |
116 | def ResNet50():
117 | return ResNet(Bottleneck, [3, 4, 6, 3])
118 |
119 |
120 | def ResNet101():
121 | return ResNet(Bottleneck, [3, 4, 23, 3])
122 |
123 |
124 | def ResNet152():
125 | return ResNet(Bottleneck, [3, 8, 36, 3])
126 |
127 |
128 | def test():
129 | net = ResNet18()
130 | print(sum(p.numel() for p in net.parameters() if p.requires_grad))
131 | import torchvision
132 | net2 = torchvision.models.resnet18()
133 | print(sum(p.numel() for p in net2.parameters() if p.requires_grad))
134 | # y = net(torch.randn(1, 3, 32, 32))
135 | # print(y.size())
136 | # 11173962
137 | # 11689512
138 | # test()
--------------------------------------------------------------------------------
/models/backbones/cub_resnet_1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import torch.utils.model_zoo as model_zoo
5 | # https://raw.githubusercontent.com/huyvnphan/PyTorch_CIFAR10/master/cifar10_models/resnet.py
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | }
16 |
17 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
18 | """3x3 convolution with padding"""
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20 | padding=dilation, groups=groups, bias=False, dilation=dilation)
21 |
22 |
23 | def conv1x1(in_planes, out_planes, stride=1):
24 | """1x1 convolution"""
25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
26 |
27 |
28 | class BasicBlock(nn.Module):
29 | expansion = 1
30 |
31 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
32 | base_width=64, dilation=1, norm_layer=None):
33 | super(BasicBlock, self).__init__()
34 | if norm_layer is None:
35 | norm_layer = nn.BatchNorm2d
36 | if groups != 1 or base_width != 64:
37 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
38 | if dilation > 1:
39 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
40 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
41 | self.conv1 = conv3x3(inplanes, planes, stride)
42 | self.bn1 = norm_layer(planes)
43 | self.relu = nn.ReLU(inplace=True)
44 | self.conv2 = conv3x3(planes, planes)
45 | self.bn2 = norm_layer(planes)
46 | self.downsample = downsample
47 | self.stride = stride
48 |
49 | def forward(self, x):
50 | identity = x
51 |
52 | out = self.conv1(x)
53 | out = self.bn1(out)
54 | out = self.relu(out)
55 |
56 | out = self.conv2(out)
57 | out = self.bn2(out)
58 |
59 | if self.downsample is not None:
60 | identity = self.downsample(x)
61 |
62 | out += identity
63 | out = self.relu(out)
64 |
65 | return out
66 |
67 |
68 | class Bottleneck(nn.Module):
69 | expansion = 4
70 |
71 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
72 | base_width=64, dilation=1, norm_layer=None):
73 | super(Bottleneck, self).__init__()
74 | if norm_layer is None:
75 | norm_layer = nn.BatchNorm2d
76 | width = int(planes * (base_width / 64.)) * groups
77 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
78 | self.conv1 = conv1x1(inplanes, width)
79 | self.bn1 = norm_layer(width)
80 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
81 | self.bn2 = norm_layer(width)
82 | self.conv3 = conv1x1(width, planes * self.expansion)
83 | self.bn3 = norm_layer(planes * self.expansion)
84 | self.relu = nn.ReLU(inplace=True)
85 | self.downsample = downsample
86 | self.stride = stride
87 |
88 | def forward(self, x):
89 | identity = x
90 |
91 | out = self.conv1(x)
92 | out = self.bn1(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv2(out)
96 | out = self.bn2(out)
97 | out = self.relu(out)
98 |
99 | out = self.conv3(out)
100 | out = self.bn3(out)
101 |
102 | if self.downsample is not None:
103 | identity = self.downsample(x)
104 |
105 | out += identity
106 | out = self.relu(out)
107 |
108 | return out
109 |
110 |
111 | class ResNet(nn.Module):
112 |
113 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
114 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
115 | norm_layer=None):
116 | super(ResNet, self).__init__()
117 | if norm_layer is None:
118 | norm_layer = nn.BatchNorm2d
119 | self._norm_layer = norm_layer
120 |
121 | self.inplanes = 64
122 | self.dilation = 1
123 | if replace_stride_with_dilation is None:
124 | # each element in the tuple indicates if we should replace
125 | # the 2x2 stride with a dilated convolution instead
126 | replace_stride_with_dilation = [False, False, False]
127 | if len(replace_stride_with_dilation) != 3:
128 | raise ValueError("replace_stride_with_dilation should be None "
129 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
130 | self.groups = groups
131 | self.base_width = width_per_group
132 |
133 |         ## NOTE: the usual CIFAR10 stem change (kernel_size 7 -> 3, stride 2 -> 1, padding 3 -> 1)
134 |         ## is NOT applied here; the ImageNet 7x7 stride-2 stem is kept, and forward() skips self.maxpool instead.
135 |         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
136 |
137 | self.bn1 = norm_layer(self.inplanes)
138 | self.relu = nn.ReLU(inplace=True)
139 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
140 | self.layer1 = self._make_layer(block, 64, layers[0])
141 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
142 | dilate=replace_stride_with_dilation[0])
143 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
144 | dilate=replace_stride_with_dilation[1])
145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
146 | dilate=replace_stride_with_dilation[2])
147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
148 | self.fc = nn.Linear(512 * block.expansion, num_classes)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
154 | nn.init.constant_(m.weight, 1)
155 | nn.init.constant_(m.bias, 0)
156 |
157 | # Zero-initialize the last BN in each residual branch,
158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
159 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
160 | if zero_init_residual:
161 | for m in self.modules():
162 | if isinstance(m, Bottleneck):
163 | nn.init.constant_(m.bn3.weight, 0)
164 | elif isinstance(m, BasicBlock):
165 | nn.init.constant_(m.bn2.weight, 0)
166 |
167 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
168 | norm_layer = self._norm_layer
169 | downsample = None
170 | previous_dilation = self.dilation
171 | if dilate:
172 | self.dilation *= stride
173 | stride = 1
174 | if stride != 1 or self.inplanes != planes * block.expansion:
175 | downsample = nn.Sequential(
176 | conv1x1(self.inplanes, planes * block.expansion, stride),
177 | norm_layer(planes * block.expansion),
178 | )
179 |
180 | layers = []
181 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
182 | self.base_width, previous_dilation, norm_layer))
183 | self.inplanes = planes * block.expansion
184 | for _ in range(1, blocks):
185 | layers.append(block(self.inplanes, planes, groups=self.groups,
186 | base_width=self.base_width, dilation=self.dilation,
187 | norm_layer=norm_layer))
188 |
189 | return nn.Sequential(*layers)
190 |
191 | def forward(self, x):
192 | x = self.conv1(x)
193 | x = self.bn1(x)
194 | x = self.relu(x)
195 | #x = self.maxpool(x)
196 |
197 | x = self.layer1(x)
198 | x = self.layer2(x)
199 | x = self.layer3(x)
200 | feat = self.layer4(x)
201 |
202 | x = self.avgpool(feat)
203 | x = x.reshape(x.size(0), -1)
204 | x = self.fc(x)
205 |
206 | return x
207 |
208 |
209 | def _resnet(arch, block, layers, pretrained, progress, device, **kwargs):
210 | model = ResNet(block, layers, **kwargs)
211 |
212 | if pretrained:
213 | print("pretrain resnet arch:")
214 | print(arch)
215 | model.load_state_dict(model_zoo.load_url(model_urls[arch]))
216 | # script_dir = os.path.dirname(__file__)
217 | # state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device)
218 | # model.load_state_dict(state_dict)
219 | return model
220 |
221 |
222 | def resnet18(pretrained=False, progress=True, device='cpu', **kwargs):
223 | """Constructs a ResNet-18 model.
224 |
225 | Args:
226 | pretrained (bool): If True, returns a model pre-trained on ImageNet
227 | progress (bool): If True, displays a progress bar of the download to stderr
228 | """
229 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, device,
230 | **kwargs)
231 |
232 |
233 | def resnet34(pretrained=False, progress=True, device='cpu', **kwargs):
234 | """Constructs a ResNet-34 model.
235 |
236 | Args:
237 | pretrained (bool): If True, returns a model pre-trained on ImageNet
238 | progress (bool): If True, displays a progress bar of the download to stderr
239 | """
240 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, device,
241 | **kwargs)
242 |
243 |
244 | def resnet50(pretrained=False, progress=True, device='cpu', **kwargs):
245 | """Constructs a ResNet-50 model.
246 |
247 | Args:
248 | pretrained (bool): If True, returns a model pre-trained on ImageNet
249 | progress (bool): If True, displays a progress bar of the download to stderr
250 | """
251 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, device,
252 | **kwargs)
253 |
254 |
255 | def resnet101(pretrained=False, progress=True, device='cpu', **kwargs):
256 | """Constructs a ResNet-101 model.
257 |
258 | Args:
259 | pretrained (bool): If True, returns a model pre-trained on ImageNet
260 | progress (bool): If True, displays a progress bar of the download to stderr
261 | """
262 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, device,
263 | **kwargs)
264 |
265 |
266 | def resnet152(pretrained=False, progress=True, device='cpu', **kwargs):
267 | """Constructs a ResNet-152 model.
268 |
269 | Args:
270 | pretrained (bool): If True, returns a model pre-trained on ImageNet
271 | progress (bool): If True, displays a progress bar of the download to stderr
272 | """
273 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, device,
274 | **kwargs)
275 |
276 |
277 | def resnext50_32x4d(pretrained=False, progress=True, device='cpu', **kwargs):
278 | """Constructs a ResNeXt-50 32x4d model.
279 |
280 | Args:
281 | pretrained (bool): If True, returns a model pre-trained on ImageNet
282 | progress (bool): If True, displays a progress bar of the download to stderr
283 | """
284 | kwargs['groups'] = 32
285 | kwargs['width_per_group'] = 4
286 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
287 | pretrained, progress, device, **kwargs)
288 |
289 |
290 | def resnext101_32x8d(pretrained=False, progress=True, device='cpu', **kwargs):
291 | """Constructs a ResNeXt-101 32x8d model.
292 |
293 | Args:
294 | pretrained (bool): If True, returns a model pre-trained on ImageNet
295 | progress (bool): If True, displays a progress bar of the download to stderr
296 | """
297 | kwargs['groups'] = 32
298 | kwargs['width_per_group'] = 8
299 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
300 | pretrained, progress, device, **kwargs)
301 |
302 |
303 | if __name__ == "__main__":
304 | model = resnet18()
305 | print(sum(p.numel() for p in model.parameters() if p.requires_grad)) # 11173962
306 |
--------------------------------------------------------------------------------
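Because `forward()` above skips `self.maxpool`, a 224x224 input reaches `layer4` at 14x14 rather than the usual 7x7: the stride-2 stem gives 112, and layers 2-4 halve it three times (112 -> 56 -> 28 -> 14). This matches the `avgpool14` used in models/byol.py. A quick hook-based check, using `pretrained=False` to avoid downloading weights:

```
import torch
from models.backbones.cub_resnet_1 import resnet50

net = resnet50(pretrained=False)
shapes = {}
net.layer4.register_forward_hook(lambda m, i, o: shapes.update(layer4=o.shape))
net(torch.randn(1, 3, 224, 224))
print(shapes['layer4'])  # torch.Size([1, 2048, 14, 14])
```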
/models/byol.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from torchvision import transforms
7 | from math import pi, cos
8 | from torchvision.models import resnet50
9 | from collections import OrderedDict
10 | HPS = dict(
11 |     max_steps=int(1000. * 1281167 / 4096), # 1000 epochs * (1281167 samples / 4096 batch size)
12 |     # i.e. total_epochs * len(dataloader)
13 | mlp_hidden_size=4096,
14 | projection_size=256,
15 | base_target_ema=4e-3,
16 | optimizer_config=dict(
17 | optimizer_name='lars',
18 | beta=0.9,
19 | trust_coef=1e-3,
20 | weight_decay=1.5e-6,
21 | exclude_bias_from_adaption=True),
22 | learning_rate_schedule=dict(
23 | base_learning_rate=0.2,
24 | warmup_steps=int(10.0 * 1281167 / 4096), # 10 epochs * N of steps/epoch = 10 epochs * len(dataloader)
25 | anneal_schedule='cosine'),
26 | batchnorm_kwargs=dict(
27 | decay_rate=0.9,
28 | eps=1e-5),
29 | seed=1337,
30 | )
31 |
32 |
33 |
34 | def D(p, z, version='simplified'): # negative cosine similarity
35 | if version == 'original':
36 | z = z.detach() # stop gradient
37 | p = F.normalize(p, dim=1) # l2-normalize
38 | z = F.normalize(z, dim=1) # l2-normalize
39 | return -(p*z).sum(dim=1).mean()
40 |
41 |     elif version == 'simplified':  # equivalent to 'original', but faster
42 |         return - F.cosine_similarity(p, z.detach(), dim=-1).mean()
43 |     else:
44 |         raise ValueError(f"unknown D(p, z) version: {version}")
45 |
46 |
47 | class MLP(nn.Module):
48 | def __init__(self, in_dim):
49 | super().__init__()
50 |
51 | self.layer1 = nn.Sequential(
52 | nn.Linear(in_dim, HPS['mlp_hidden_size']),
53 | nn.BatchNorm1d(HPS['mlp_hidden_size'], eps=HPS['batchnorm_kwargs']['eps'], momentum=1-HPS['batchnorm_kwargs']['decay_rate']),
54 | nn.ReLU(inplace=True)
55 | )
56 | self.layer2 = nn.Linear(HPS['mlp_hidden_size'], HPS['projection_size'])
57 |
58 | def forward(self, x):
59 | x = self.layer1(x)
60 | x = self.layer2(x)
61 | return x
62 |
63 | class BYOL(nn.Module):
64 | def __init__(self, backbone):
65 | super().__init__()
66 |
67 | self.backbone = backbone
68 | self.projector = MLP(backbone.output_dim)
69 | self.online_encoder = nn.Sequential(
70 | self.backbone,
71 | self.projector
72 | )
73 |
74 | self.target_encoder = copy.deepcopy(self.online_encoder)
75 | self.online_predictor = MLP(HPS['projection_size'])
76 | #raise NotImplementedError('Please put update_moving_average to training')
77 |
78 |
79 | self.bilinear = 32
80 | self.conv16 = nn.Conv2d(2048, self.bilinear, kernel_size=1, stride=1, padding=0, bias=False)
81 | self.bn16 = nn.BatchNorm2d(self.bilinear)
82 | self.relu = nn.ReLU(inplace=True)
83 | self.avgpool = nn.AvgPool2d(7, stride=1)
84 | self.avgpool14 = nn.AvgPool2d(14, stride=1)
85 |
86 |
87 |
88 | def target_ema(self, k, K, base_ema=HPS['base_target_ema']):
89 | # tau_base = 0.996
90 |         # base_ema = 1 - tau_base = 4e-3
91 | return 1 - base_ema * (cos(pi*k/K)+1)/2
92 | # return 1 - (1-self.tau_base) * (cos(pi*k/K)+1)/2
93 |
94 | @torch.no_grad()
95 | def update_moving_average(self): #, global_step, max_steps
96 | #tau = self.target_ema(global_step, max_steps)
97 | tau = 0.996
98 | for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
99 | target.data = tau * target.data + (1 - tau) * online.data
100 |
101 | def forward(self, x1, x2):
102 | f_o, h_o = self.online_encoder, self.online_predictor
103 | f_t = self.target_encoder
104 |
105 |
106 | x = self.online_encoder[0].conv1(x1)
107 | x = self.online_encoder[0].bn1(x)
108 | x = self.online_encoder[0].relu(x)
109 | x = self.online_encoder[0].layer1(x)
110 | x = self.online_encoder[0].layer2(x)
111 | x = self.online_encoder[0].layer3(x)
112 | feat_map1 = self.online_encoder[0].layer4(x)
113 | x = self.online_encoder[0].avgpool(feat_map1)
114 | x = x.reshape(x.size(0), -1)
115 | z1_o = self.online_encoder[1](x)
116 |
117 | x = self.online_encoder[0].conv1(x2)
118 | x = self.online_encoder[0].bn1(x)
119 | x = self.online_encoder[0].relu(x)
120 | x = self.online_encoder[0].layer1(x)
121 | x = self.online_encoder[0].layer2(x)
122 | x = self.online_encoder[0].layer3(x)
123 | feat_map2 = self.online_encoder[0].layer4(x)
124 | x = self.online_encoder[0].avgpool(feat_map2)
125 | x = x.reshape(x.size(0), -1)
126 | z2_o = self.online_encoder[1](x)
127 |
128 |
129 |
130 | p1_o = h_o(z1_o)
131 | p2_o = h_o(z2_o)
132 |
133 | with torch.no_grad():
134 | self.update_moving_average()
135 | z1_t = f_t(x1)
136 | z2_t = f_t(x2)
137 |
138 | L = D(p1_o, z2_t) / 2 + D(p2_o, z1_t) / 2
139 |
140 |
141 | grad_wrt_act1 = torch.autograd.grad(outputs=L, inputs=feat_map1,
142 | grad_outputs=torch.ones_like(L), retain_graph=True,
143 | allow_unused=True)[0]
144 |
145 | gradcam = torch.relu((feat_map1 * grad_wrt_act1).sum(dim=1))
146 |
147 | featcov16 = self.conv16(feat_map1)
148 | featcov16 = self.bn16(featcov16)
149 | featcov16 = self.relu(featcov16)
150 | img, _ = torch.max(featcov16, axis=1)
151 | img = img - torch.min(img)
152 | att_max = img / (1e-7 + torch.max(img)) #batch*7*7
153 |
154 | return {'loss': L, 'attmap': att_max, 'gradcam': gradcam}
155 |
156 |
157 |
158 | def inference(self, img):
159 | x = self.online_encoder[0].conv1(img)
160 | x = self.online_encoder[0].bn1(x)
161 | x = self.online_encoder[0].relu(x)
162 | x = self.online_encoder[0].layer1(x)
163 | x = self.online_encoder[0].layer2(x)
164 | x = self.online_encoder[0].layer3(x)
165 | feat_map1 = self.online_encoder[0].layer4(x)
166 | x = self.online_encoder[0].avgpool(feat_map1)
167 | x = x.reshape(x.size(0), -1)
168 |
169 | featcov16 = self.conv16(feat_map1)
170 | featcov16 = self.bn16(featcov16)
171 | featcov16 = self.relu(featcov16)
172 |
173 |
174 |
175 | img, _ = torch.max(featcov16, axis=1)
176 | img = img - torch.min(img)
177 | att_max = img / (1e-7 + torch.max(img))
178 |
179 | img = att_max[:, None, :, :]
180 | img = img.repeat(1, 2048, 1, 1)
181 | PFM = feat_map1.cuda() * img.cuda()
182 | aa = self.avgpool14(PFM)
183 | bp_out_feat = aa.view(aa.size(0), -1)
184 | bp_out_feat = nn.functional.normalize(bp_out_feat, dim=1)
185 |
186 |
187 | return x, bp_out_feat
188 |
189 |
190 |
191 | if __name__ == "__main__":
192 | pass
193 |
--------------------------------------------------------------------------------
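A small numeric sketch of the `target_ema` schedule above: the EMA coefficient starts at 1 - base_target_ema = 0.996 and anneals to 1 over K steps on a cosine schedule. (Note that `update_moving_average` in this file currently hard-codes tau = 0.996 rather than calling this schedule.)

```
from math import cos, pi

base_ema = 4e-3  # HPS['base_target_ema']

def target_ema(k, K):
    return 1 - base_ema * (cos(pi * k / K) + 1) / 2

K = 1000  # illustrative total number of steps
for k in (0, K // 2, K):
    print(k, round(target_ema(k, K), 6))  # 0 0.996 / 500 0.998 / 1000 1.0
```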
/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .lars import LARS
2 | from .lars_simclr import LARS_simclr
3 | from .larc import LARC
4 | import torch
5 | from .lr_scheduler import LR_Scheduler
6 |
7 |
8 | def get_optimizer(name, model, lr, momentum, weight_decay):
9 |
10 | predictor_prefix = ('module.predictor', 'predictor')
11 | parameters = [{
12 | 'name': 'base',
13 | 'params': [param for name, param in model.named_parameters() if not name.startswith(predictor_prefix)],
14 | 'lr': lr
15 | },{
16 | 'name': 'predictor',
17 | 'params': [param for name, param in model.named_parameters() if name.startswith(predictor_prefix)],
18 | 'lr': lr
19 | }]
20 | if name == 'lars':
21 | optimizer = LARS(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
22 | elif name == 'sgd':
23 | optimizer = torch.optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
24 | elif name == 'lars_simclr': # Careful
25 | optimizer = LARS_simclr(model.named_modules(), lr=lr, momentum=momentum, weight_decay=weight_decay)
26 | elif name == 'larc':
27 | optimizer = LARC(
28 | torch.optim.SGD(
29 | parameters,
30 | lr=lr,
31 | momentum=momentum,
32 | weight_decay=weight_decay
33 | ),
34 | trust_coefficient=0.001,
35 | clip=False
36 | )
37 | else:
38 | raise NotImplementedError
39 | return optimizer
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
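A hedged usage sketch for `get_optimizer` above (the hyperparameters are illustrative): the parameters are split into a 'base' group and a 'predictor' group so a scheduler can treat the BYOL predictor differently from the backbone and projector. For a plain ResNet the predictor group is simply empty:

```
import torchvision
from optimizers import get_optimizer

model = torchvision.models.resnet18()
opt = get_optimizer('sgd', model, lr=0.03, momentum=0.9, weight_decay=1e-4)
print([g['name'] for g in opt.param_groups])  # ['base', 'predictor']
```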
/optimizers/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/optimizers/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/optimizers/__pycache__/larc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/optimizers/__pycache__/larc.cpython-38.pyc
--------------------------------------------------------------------------------
/optimizers/__pycache__/lars.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/optimizers/__pycache__/lars.cpython-38.pyc
--------------------------------------------------------------------------------
/optimizers/__pycache__/lr_scheduler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/optimizers/__pycache__/lr_scheduler.cpython-38.pyc
--------------------------------------------------------------------------------
/optimizers/larc.py:
--------------------------------------------------------------------------------
1 | """SwAV use larc instead of lars optimizer"""
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn.parameter import Parameter
6 | from torch.optim.optimizer import Optimizer
7 |
8 | def main(): # Example
9 | import torchvision
10 | model = torchvision.models.resnet18(pretrained=False)
11 | # optim = torch.optim.Adam(model.parameters(), lr=0.0001)
12 | optim = torch.optim.SGD(model.parameters(),lr=0.2, momentum=0.9, weight_decay=1.5e-6)
13 | optim = LARC(optim)
14 |
15 | class LARC(Optimizer):
16 | """
17 |     :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARS,
18 | in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive
19 | local learning rate for each individual parameter. The algorithm is designed to improve
20 | convergence of large batch training.
21 |
22 | See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.
23 | In practice it modifies the gradients of parameters as a proxy for modifying the learning rate
24 | of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.
25 | ```
26 | model = ...
27 | optim = torch.optim.Adam(model.parameters(), lr=...)
28 | optim = LARC(optim)
29 | ```
30 | It can even be used in conjunction with apex.fp16_utils.FP16_optimizer.
31 | ```
32 | model = ...
33 | optim = torch.optim.Adam(model.parameters(), lr=...)
34 | optim = LARC(optim)
35 | optim = apex.fp16_utils.FP16_Optimizer(optim)
36 | ```
37 | Args:
38 | optimizer: Pytorch optimizer to wrap and modify learning rate for.
39 | trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888
40 | clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
41 | eps: epsilon kludge to help with numerical stability while calculating adaptive_lr
42 | """
43 |
44 | def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
45 | self.optim = optimizer
46 | self.trust_coefficient = trust_coefficient
47 | self.eps = eps
48 | self.clip = clip
49 |
50 | def __getstate__(self):
51 | return self.optim.__getstate__()
52 |
53 | def __setstate__(self, state):
54 | self.optim.__setstate__(state)
55 |
56 | @property
57 | def state(self):
58 | return self.optim.state
59 |
60 | def __repr__(self):
61 | return self.optim.__repr__()
62 |
63 | @property
64 | def param_groups(self):
65 | return self.optim.param_groups
66 |
67 | @param_groups.setter
68 | def param_groups(self, value):
69 | self.optim.param_groups = value
70 |
71 | def state_dict(self):
72 | return self.optim.state_dict()
73 |
74 | def load_state_dict(self, state_dict):
75 | self.optim.load_state_dict(state_dict)
76 |
77 | def zero_grad(self):
78 | self.optim.zero_grad()
79 |
80 | def add_param_group(self, param_group):
81 | self.optim.add_param_group( param_group)
82 |
83 | def step(self):
84 | with torch.no_grad():
85 | weight_decays = []
86 | for group in self.optim.param_groups:
87 | # absorb weight decay control from optimizer
88 | weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
89 | weight_decays.append(weight_decay)
90 | group['weight_decay'] = 0
91 | for p in group['params']:
92 | if p.grad is None:
93 | continue
94 | param_norm = torch.norm(p.data)
95 | grad_norm = torch.norm(p.grad.data)
96 |
97 | if param_norm != 0 and grad_norm != 0:
98 | # calculate adaptive lr + weight decay
99 | adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps)
100 |
101 | # clip learning rate for LARC
102 | if self.clip:
103 | # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
104 | adaptive_lr = min(adaptive_lr/group['lr'], 1)
105 |
106 | p.grad.data += weight_decay * p.data
107 | p.grad.data *= adaptive_lr
108 |
109 | self.optim.step()
110 | # return weight decay control to optimizer
111 | for i, group in enumerate(self.optim.param_groups):
112 | group['weight_decay'] = weight_decays[i]
113 |
114 |
115 |
116 | if __name__ == "__main__":
117 | main()
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
--------------------------------------------------------------------------------
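The effect of `clip` in `LARC.step` is easiest to see numerically; a small sketch of the adaptive-LR arithmetic with made-up norms:
```
trust_coefficient, eps = 0.001, 1e-8
param_norm, grad_norm, weight_decay, lr = 10.0, 2.0, 1e-4, 0.2

adaptive_lr = trust_coefficient * param_norm / (grad_norm + param_norm * weight_decay + eps)
print(adaptive_lr)                      # ~0.005: the local learning rate

# clip=True: gradients are scaled by min(adaptive_lr/lr, 1), so the effective
# step uses min(lr, adaptive_lr).
print(lr * min(adaptive_lr / lr, 1))    # ~0.005
# clip=False: the effective step uses adaptive_lr * lr.
print(adaptive_lr * lr)                 # ~0.001
```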
/optimizers/lars.py:
--------------------------------------------------------------------------------
1 | """ Layer-wise adaptive rate scaling for SGD in PyTorch! """
2 | import torch
3 | from torch.optim.optimizer import Optimizer, required
4 |
5 | class LARS(Optimizer):
6 | r"""Implements layer-wise adaptive rate scaling for SGD.
7 |
8 | Args:
9 | params (iterable): iterable of parameters to optimize or dicts defining
10 | parameter groups
11 | lr (float): base learning rate (\gamma_0)
12 | momentum (float, optional): momentum factor (default: 0) ("m")
13 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
14 | ("\beta")
15 | eta (float, optional): LARS coefficient
16 | max_epoch: maximum training epoch to determine polynomial LR decay.
17 |
18 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
19 | Large Batch Training of Convolutional Networks:
20 | https://arxiv.org/abs/1708.03888
21 |
22 | Example:
23 | >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3)
24 | >>> optimizer.zero_grad()
25 | >>> loss_fn(model(input), target).backward()
26 | >>> optimizer.step()
27 | """
28 | def __init__(self, params, lr=required, momentum=.9,
29 | weight_decay=.0005, eta=0.001, max_epoch=200):
30 | if lr is not required and lr < 0.0:
31 | raise ValueError("Invalid learning rate: {}".format(lr))
32 | if momentum < 0.0:
33 | raise ValueError("Invalid momentum value: {}".format(momentum))
34 | if weight_decay < 0.0:
35 | raise ValueError("Invalid weight_decay value: {}"
36 | .format(weight_decay))
37 | if eta < 0.0:
38 | raise ValueError("Invalid LARS coefficient value: {}".format(eta))
39 |
40 | self.epoch = 0
41 | defaults = dict(lr=lr, momentum=momentum,
42 | weight_decay=weight_decay,
43 | eta=eta, max_epoch=max_epoch)
44 | super(LARS, self).__init__(params, defaults)
45 |
46 | def step(self, epoch=None, closure=None):
47 | """Performs a single optimization step.
48 |
49 | Arguments:
50 | closure (callable, optional): A closure that reevaluates the model
51 | and returns the loss.
52 | epoch: current epoch to calculate polynomial LR decay schedule.
53 | if None, uses self.epoch and increments it.
54 | """
55 | loss = None
56 | if closure is not None:
57 | loss = closure()
58 |
59 | if epoch is None:
60 | epoch = self.epoch
61 | self.epoch += 1
62 |
63 | for group in self.param_groups:
64 | weight_decay = group['weight_decay']
65 | momentum = group['momentum']
66 | eta = group['eta']
67 | lr = group['lr']
68 | max_epoch = group['max_epoch']
69 |
70 | for p in group['params']:
71 | if p.grad is None:
72 | continue
73 |
74 | param_state = self.state[p]
75 | d_p = p.grad.data
76 |
77 | weight_norm = torch.norm(p.data)
78 | grad_norm = torch.norm(d_p)
79 |
80 | # Global LR computed on polynomial decay schedule
81 | decay = (1 - float(epoch) / max_epoch) ** 2
82 | global_lr = lr * decay
83 |
84 | # Compute local learning rate for this layer
85 | local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm)
86 |
87 | # Update the momentum term
88 | actual_lr = local_lr * global_lr
89 |
90 | if 'momentum_buffer' not in param_state:
91 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
92 | else:
93 | buf = param_state['momentum_buffer']
94 | buf.mul_(momentum).add_(d_p + weight_decay * p.data, alpha=actual_lr)
95 | p.data.add_(-buf)
96 |
97 | return loss
--------------------------------------------------------------------------------
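A short usage sketch for `LARS` (illustrative values; note that `step` takes the current epoch to drive the polynomial decay of the global LR):
```
import torch
from optimizers.lars import LARS

model = torch.nn.Linear(10, 10)
optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3, max_epoch=200)

loss = model(torch.randn(4, 10)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step(epoch=0)

# Global LR decays as lr * (1 - epoch/max_epoch)**2:
for epoch in (0, 100, 199):
    print(epoch, 0.1 * (1 - epoch / 200) ** 2)
```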
/optimizers/lars_simclr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch.optim.optimizer import Optimizer
4 | import torch.nn as nn
5 |
6 | class LARS_simclr(Optimizer):
7 | def __init__(self,
8 | named_modules,
9 | lr,
10 |                  momentum=0.9, # momentum coefficient ("beta")
11 | trust_coef=1e-3,
12 | weight_decay=1.5e-6,
13 | exclude_bias_from_adaption=True):
14 |         '''BYOL: as in SimCLR and the official LARS implementation, we exclude bias and batch-norm parameters from the LARS adaptation and from weight decay.'''
15 | defaults = dict(momentum=momentum,
16 | lr=lr,
17 | weight_decay=weight_decay,
18 | trust_coef=trust_coef)
19 | parameters = self.exclude_from_model(named_modules, exclude_bias_from_adaption)
20 | super(LARS_simclr, self).__init__(parameters, defaults)
21 |
22 | @torch.no_grad()
23 | def step(self):
24 | for group in self.param_groups: # only 1 group in most cases
25 | weight_decay = group['weight_decay']
26 | momentum = group['momentum']
27 | lr = group['lr']
28 |
29 | trust_coef = group['trust_coef']
30 | # print(group['name'])
31 | # eps = group['eps']
32 | for p in group['params']:
33 | # breakpoint()
34 | if p.grad is None:
35 | continue
36 | global_lr = lr
37 | velocity = self.state[p].get('velocity', 0)
38 | # if name in self.exclude_from_layer_adaptation:
39 | if self._use_weight_decay(group):
40 | p.grad.data += weight_decay * p.data
41 |
42 | trust_ratio = 1.0
43 | if self._do_layer_adaptation(group):
44 | w_norm = torch.norm(p.data, p=2)
45 | g_norm = torch.norm(p.grad.data, p=2)
46 | trust_ratio = trust_coef * w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1.0
47 | scaled_lr = global_lr * trust_ratio # trust_ratio is the local_lr
48 | next_v = momentum * velocity + scaled_lr * p.grad.data
49 | update = next_v
50 | p.data = p.data - update
51 |
52 |
53 | def _use_weight_decay(self, group):
54 | return False if group['name'] == 'exclude' else True
55 | def _do_layer_adaptation(self, group):
56 | return False if group['name'] == 'exclude' else True
57 |
58 | def exclude_from_model(self, named_modules, exclude_bias_from_adaption=True):
59 | base = []
60 | exclude = []
61 | for name, module in named_modules:
62 | if type(module) in [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]:
63 | # if isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
64 | for name2, param in module.named_parameters():
65 | exclude.append(param)
66 | else:
67 | for name2, param in module.named_parameters():
68 | if name2 == 'bias':
69 | exclude.append(param)
70 | elif name2 == 'weight':
71 | base.append(param)
72 | else:
73 | pass # non leaf modules
74 | return [{
75 | 'name': 'base',
76 | 'params': base
77 | },{
78 | 'name': 'exclude',
79 | 'params': exclude
80 | }] if exclude_bias_from_adaption == True else [{
81 | 'name': 'base',
82 | 'params': base+exclude
83 | }]
84 |
85 | if __name__ == "__main__":
86 |
87 | resnet = torchvision.models.resnet18(pretrained=False)
88 | model = resnet
89 |
90 | optimizer = LARS_simclr(model.named_modules(), lr=0.1)
91 | # print()
92 | # out = optimizer.exclude_from_model(model.named_modules(),exclude_bias_from_adaption=False)
93 | # print(len(out[0]['params']))
94 | # exit()
95 |
96 | criterion = torch.nn.CrossEntropyLoss()
97 | for i in range(100):
98 | model.zero_grad()
99 | pred = model(torch.randn((2,3,32,32)))
100 | loss = pred.mean()
101 | loss.backward()
102 | optimizer.step()
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
--------------------------------------------------------------------------------
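A quick check of the bias/batch-norm exclusion performed by `exclude_from_model` (run from the repository root):
```
import torchvision
from optimizers.lars_simclr import LARS_simclr

model = torchvision.models.resnet18()
optimizer = LARS_simclr(model.named_modules(), lr=0.1)

# 'base' holds conv/linear weights; 'exclude' holds biases and all batch-norm
# parameters, which skip both weight decay and the layer-wise adaptation.
for group in optimizer.param_groups:
    print(group['name'], len(group['params']))
```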
/optimizers/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | # from torch.optim.lr_scheduler import _LRScheduler
4 |
5 |
6 | class LR_Scheduler(object):
7 | def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
8 | self.base_lr = base_lr
9 | self.constant_predictor_lr = constant_predictor_lr
10 | warmup_iter = iter_per_epoch * warmup_epochs
11 | warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
12 | decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
13 | cosine_lr_schedule = final_lr+0.5*(base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter))
14 |
15 | self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
16 | self.optimizer = optimizer
17 | self.iter = 0
18 | self.current_lr = 0
19 | def step(self):
20 | for param_group in self.optimizer.param_groups:
21 |
22 | if self.constant_predictor_lr and param_group['name'] == 'predictor':
23 | param_group['lr'] = self.base_lr
24 | else:
25 | lr = param_group['lr'] = self.lr_schedule[self.iter]
26 |
27 | self.iter += 1
28 | self.current_lr = lr
29 | return lr
30 | def get_lr(self):
31 | return self.current_lr
32 |
33 | if __name__ == "__main__":
34 | import torchvision
35 | model = torchvision.models.resnet50()
36 | optimizer = torch.optim.SGD(model.parameters(), lr=999)
37 | epochs = 100
38 | n_iter = 1000
39 | scheduler = LR_Scheduler(optimizer, 10, 1, epochs, 3, 0, n_iter)
40 | import matplotlib.pyplot as plt
41 | lrs = []
42 | for epoch in range(epochs):
43 | for it in range(n_iter):
44 | lr = scheduler.step()
45 | lrs.append(lr)
46 | plt.plot(lrs)
47 | plt.show()
--------------------------------------------------------------------------------
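A numeric sanity check of the warmup-then-cosine schedule (illustrative hyper-parameters):
```
import torch
from optimizers.lr_scheduler import LR_Scheduler

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
scheduler = LR_Scheduler(optimizer, warmup_epochs=10, warmup_lr=0, num_epochs=100,
                         base_lr=0.05, final_lr=0, iter_per_epoch=100)

lrs = [scheduler.step() for _ in range(100 * 100)]
print(lrs[0], max(lrs), lrs[-1])   # ramps from 0 up to base_lr, decays to ~final_lr
```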
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | tqdm
4 | Pillow
5 | numpy
6 | matplotlib
7 | PyYAML==5.3.1
8 | tensorboardx==2.1
9 |
10 |
--------------------------------------------------------------------------------
/resnet_output.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | #from torchvision._internally_replaced_utils import load_state_dict_from_url
5 | from torch.hub import load_state_dict_from_url
6 | from typing import Type, Any, Callable, Union, List, Optional, Tuple
7 |
8 |
9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
11 | 'wide_resnet50_2', 'wide_resnet101_2']
12 |
13 |
14 | model_urls = {
15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
20 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
21 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
22 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
23 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
24 | }
25 |
26 |
27 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
28 | """3x3 convolution with padding"""
29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
30 | padding=dilation, groups=groups, bias=False, dilation=dilation)
31 |
32 |
33 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
34 | """1x1 convolution"""
35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
36 |
37 |
38 | class BasicBlock(nn.Module):
39 | expansion: int = 1
40 |
41 | def __init__(
42 | self,
43 | inplanes: int,
44 | planes: int,
45 | stride: int = 1,
46 | downsample: Optional[nn.Module] = None,
47 | groups: int = 1,
48 | base_width: int = 64,
49 | dilation: int = 1,
50 | norm_layer: Optional[Callable[..., nn.Module]] = None
51 | ) -> None:
52 | super(BasicBlock, self).__init__()
53 | if norm_layer is None:
54 | norm_layer = nn.BatchNorm2d
55 | if groups != 1 or base_width != 64:
56 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
57 | if dilation > 1:
58 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
59 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
60 | self.conv1 = conv3x3(inplanes, planes, stride)
61 | self.bn1 = norm_layer(planes)
62 | self.relu = nn.ReLU(inplace=True)
63 | self.conv2 = conv3x3(planes, planes)
64 | self.bn2 = norm_layer(planes)
65 | self.downsample = downsample
66 | self.stride = stride
67 |
68 | def forward(self, x: Tensor) -> Tensor:
69 | identity = x
70 |
71 | out = self.conv1(x)
72 | out = self.bn1(out)
73 | out = self.relu(out)
74 |
75 | out = self.conv2(out)
76 | out = self.bn2(out)
77 |
78 | if self.downsample is not None:
79 | identity = self.downsample(x)
80 |
81 | out += identity
82 | out = self.relu(out)
83 |
84 | return out
85 |
86 |
87 | class Bottleneck(nn.Module):
88 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
89 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
90 |     # according to "Deep residual learning for image recognition", https://arxiv.org/abs/1512.03385.
91 | # This variant is also known as ResNet V1.5 and improves accuracy according to
92 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
93 |
94 | expansion: int = 4
95 |
96 | def __init__(
97 | self,
98 | inplanes: int,
99 | planes: int,
100 | stride: int = 1,
101 | downsample: Optional[nn.Module] = None,
102 | groups: int = 1,
103 | base_width: int = 64,
104 | dilation: int = 1,
105 | norm_layer: Optional[Callable[..., nn.Module]] = None
106 | ) -> None:
107 | super(Bottleneck, self).__init__()
108 | if norm_layer is None:
109 | norm_layer = nn.BatchNorm2d
110 | width = int(planes * (base_width / 64.)) * groups
111 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
112 | self.conv1 = conv1x1(inplanes, width)
113 | self.bn1 = norm_layer(width)
114 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
115 | self.bn2 = norm_layer(width)
116 | self.conv3 = conv1x1(width, planes * self.expansion)
117 | self.bn3 = norm_layer(planes * self.expansion)
118 | self.relu = nn.ReLU(inplace=True)
119 | self.downsample = downsample
120 | self.stride = stride
121 |
122 | def forward(self, x: Tensor) -> Tensor:
123 | identity = x
124 |
125 | out = self.conv1(x)
126 | out = self.bn1(out)
127 | out = self.relu(out)
128 |
129 | out = self.conv2(out)
130 | out = self.bn2(out)
131 | out = self.relu(out)
132 |
133 | out = self.conv3(out)
134 | out = self.bn3(out)
135 |
136 | if self.downsample is not None:
137 | identity = self.downsample(x)
138 |
139 | out += identity
140 | out = self.relu(out)
141 |
142 | return out
143 |
144 |
145 | class ResNet(nn.Module):
146 |
147 | def __init__(
148 | self,
149 | block: Type[Union[BasicBlock, Bottleneck]],
150 | layers: List[int],
151 | num_classes: int = 1000,
152 | zero_init_residual: bool = False,
153 | groups: int = 1,
154 | width_per_group: int = 64,
155 | replace_stride_with_dilation: Optional[List[bool]] = None,
156 | norm_layer: Optional[Callable[..., nn.Module]] = None
157 | ) -> None:
158 | super(ResNet, self).__init__()
159 | if norm_layer is None:
160 | norm_layer = nn.BatchNorm2d
161 | self._norm_layer = norm_layer
162 |
163 | self.inplanes = 64
164 | self.dilation = 1
165 | if replace_stride_with_dilation is None:
166 | # each element in the tuple indicates if we should replace
167 | # the 2x2 stride with a dilated convolution instead
168 | replace_stride_with_dilation = [False, False, False]
169 | if len(replace_stride_with_dilation) != 3:
170 | raise ValueError("replace_stride_with_dilation should be None "
171 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
172 | self.groups = groups
173 | self.base_width = width_per_group
174 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
175 | bias=False)
176 | self.bn1 = norm_layer(self.inplanes)
177 | self.relu = nn.ReLU(inplace=True)
178 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
179 | self.layer1 = self._make_layer(block, 64, layers[0])
180 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
181 | dilate=replace_stride_with_dilation[0])
182 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
183 | dilate=replace_stride_with_dilation[1])
184 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
185 | dilate=replace_stride_with_dilation[2])
186 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
187 | self.fc = nn.Linear(512 * block.expansion, num_classes)
188 |
189 | for m in self.modules():
190 | if isinstance(m, nn.Conv2d):
191 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
192 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
193 | nn.init.constant_(m.weight, 1)
194 | nn.init.constant_(m.bias, 0)
195 |
196 | # Zero-initialize the last BN in each residual branch,
197 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
198 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
199 | if zero_init_residual:
200 | for m in self.modules():
201 | if isinstance(m, Bottleneck):
202 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
203 | elif isinstance(m, BasicBlock):
204 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
205 |
206 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
207 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
208 | norm_layer = self._norm_layer
209 | downsample = None
210 | previous_dilation = self.dilation
211 | if dilate:
212 | self.dilation *= stride
213 | stride = 1
214 | if stride != 1 or self.inplanes != planes * block.expansion:
215 | downsample = nn.Sequential(
216 | conv1x1(self.inplanes, planes * block.expansion, stride),
217 | norm_layer(planes * block.expansion),
218 | )
219 |
220 | layers = []
221 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
222 | self.base_width, previous_dilation, norm_layer))
223 | self.inplanes = planes * block.expansion
224 | for _ in range(1, blocks):
225 | layers.append(block(self.inplanes, planes, groups=self.groups,
226 | base_width=self.base_width, dilation=self.dilation,
227 | norm_layer=norm_layer))
228 |
229 | return nn.Sequential(*layers)
230 |
231 |     def _forward_impl(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
232 | # See note [TorchScript super()]
233 | x = self.conv1(x)
234 | x = self.bn1(x)
235 | x = self.relu(x)
236 | x = self.maxpool(x)
237 |
238 | x = self.layer1(x)
239 | x = self.layer2(x)
240 | x = self.layer3(x)
241 | featmap = self.layer4(x)
242 |
243 | x = self.avgpool(featmap)
244 | x = torch.flatten(x, 1)
245 | y = self.fc(x)
246 |
247 | return y, x, featmap
248 |
249 |     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
250 | return self._forward_impl(x)
251 |
252 |
253 | def _resnet(
254 | arch: str,
255 | block: Type[Union[BasicBlock, Bottleneck]],
256 | layers: List[int],
257 | pretrained: bool,
258 | progress: bool,
259 | **kwargs: Any
260 | ) -> ResNet:
261 | model = ResNet(block, layers, **kwargs)
262 | if pretrained:
263 | state_dict = load_state_dict_from_url(model_urls[arch],
264 | progress=progress)
265 | model.load_state_dict(state_dict)
266 | return model
267 |
268 |
269 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
270 | r"""ResNet-18 model from
271 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
272 |
273 | Args:
274 | pretrained (bool): If True, returns a model pre-trained on ImageNet
275 | progress (bool): If True, displays a progress bar of the download to stderr
276 | """
277 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
278 | **kwargs)
279 |
280 |
281 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
282 | r"""ResNet-34 model from
283 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
284 |
285 | Args:
286 | pretrained (bool): If True, returns a model pre-trained on ImageNet
287 | progress (bool): If True, displays a progress bar of the download to stderr
288 | """
289 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
290 | **kwargs)
291 |
292 |
293 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
294 | r"""ResNet-50 model from
295 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
296 |
297 | Args:
298 | pretrained (bool): If True, returns a model pre-trained on ImageNet
299 | progress (bool): If True, displays a progress bar of the download to stderr
300 | """
301 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
302 | **kwargs)
303 |
304 |
305 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
306 | r"""ResNet-101 model from
307 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
308 |
309 | Args:
310 | pretrained (bool): If True, returns a model pre-trained on ImageNet
311 | progress (bool): If True, displays a progress bar of the download to stderr
312 | """
313 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
314 | **kwargs)
315 |
316 |
317 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
318 | r"""ResNet-152 model from
319 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
320 |
321 | Args:
322 | pretrained (bool): If True, returns a model pre-trained on ImageNet
323 | progress (bool): If True, displays a progress bar of the download to stderr
324 | """
325 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
326 | **kwargs)
327 |
328 |
329 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
330 | r"""ResNeXt-50 32x4d model from
331 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
332 |
333 | Args:
334 | pretrained (bool): If True, returns a model pre-trained on ImageNet
335 | progress (bool): If True, displays a progress bar of the download to stderr
336 | """
337 | kwargs['groups'] = 32
338 | kwargs['width_per_group'] = 4
339 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
340 | pretrained, progress, **kwargs)
341 |
342 |
343 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
344 | r"""ResNeXt-101 32x8d model from
345 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
346 |
347 | Args:
348 | pretrained (bool): If True, returns a model pre-trained on ImageNet
349 | progress (bool): If True, displays a progress bar of the download to stderr
350 | """
351 | kwargs['groups'] = 32
352 | kwargs['width_per_group'] = 8
353 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
354 | pretrained, progress, **kwargs)
355 |
356 |
357 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
358 | r"""Wide ResNet-50-2 model from
359 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
360 |
361 | The model is the same as ResNet except for the bottleneck number of channels
362 | which is twice larger in every block. The number of channels in outer 1x1
363 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
364 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
365 |
366 | Args:
367 | pretrained (bool): If True, returns a model pre-trained on ImageNet
368 | progress (bool): If True, displays a progress bar of the download to stderr
369 | """
370 | kwargs['width_per_group'] = 64 * 2
371 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
372 | pretrained, progress, **kwargs)
373 |
374 |
375 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
376 | r"""Wide ResNet-101-2 model from
377 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
378 |
379 | The model is the same as ResNet except for the bottleneck number of channels
380 | which is twice larger in every block. The number of channels in outer 1x1
381 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
382 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
383 |
384 | Args:
385 | pretrained (bool): If True, returns a model pre-trained on ImageNet
386 | progress (bool): If True, displays a progress bar of the download to stderr
387 | """
388 | kwargs['width_per_group'] = 64 * 2
389 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
390 | pretrained, progress, **kwargs)
391 |
--------------------------------------------------------------------------------
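This file is torchvision's ResNet with `_forward_impl` changed to also return the pooled feature and the `layer4` feature map; a shape check:
```
import torch
from resnet_output import resnet50

model = resnet50(pretrained=False)
logits, feat, featmap = model(torch.randn(2, 3, 224, 224))
print(logits.shape, feat.shape, featmap.shape)
# torch.Size([2, 1000]) torch.Size([2, 2048]) torch.Size([2, 2048, 7, 7])
```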
/run_all.sh:
--------------------------------------------------------------------------------
1 | #export PATH=/home/ubuntu/miniconda3/envs/dl/bin:$PATH
2 | #
3 |
4 | python main.py --data_dir ./CUB200 --log_dir ./logs/ -c configs/byol_cub200.yaml --ckpt_dir ./.cache/ --hide_progress
5 | python main.py --data_dir ./StanfordCars --log_dir ./logs/ -c configs/byol_stanfordcars.yaml --ckpt_dir ./.cache/ --hide_progress
6 | python main.py --data_dir ./Aircraft --log_dir ./logs/ -c configs/byol_aircrafts.yaml --ckpt_dir ./.cache/ --hide_progress
7 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .average_meter import AverageMeter
2 | from .accuracy import accuracy
3 | from .knn_monitor import knn_monitor
4 | from .logger import Logger
5 | from .file_exist_fn import file_exist_check
--------------------------------------------------------------------------------
/tools/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/accuracy.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/accuracy.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/average_meter.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/average_meter.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/file_exist_fn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/file_exist_fn.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/knn_monitor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/knn_monitor.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/plotter.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GANPerf/LCR/fe591d4ee9b6dccba8710f9f73ef41b9a129f013/tools/__pycache__/plotter.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/accuracy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def accuracy(output, target, topk=(1,)):
5 |     """Computes the accuracy over the k top predictions for the specified values of k"""
6 |     with torch.no_grad():
7 |         maxk = max(topk)
8 |         batch_size = target.size(0)
9 |
10 |         _, pred = output.topk(maxk, 1, True, True)
11 |         pred = pred.t()
12 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
13 |
14 |         res = []
15 |         for k in topk:
16 |             # reshape(-1), not view(-1): `correct` can be non-contiguous after the transpose
17 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
18 |             res.append(correct_k.mul_(100.0 / batch_size))
19 |         return res
20 |
--------------------------------------------------------------------------------
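A quick usage check for `accuracy` (hand-picked logits so top-1 is 50% and top-2 is 100%):
```
import torch
from tools.accuracy import accuracy

logits = torch.tensor([[0.1, 0.9, 0.0],
                       [0.5, 0.1, 0.4]])
target = torch.tensor([1, 2])
top1, top2 = accuracy(logits, target, topk=(1, 2))
print(top1.item(), top2.item())   # 50.0 100.0
```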
/tools/average_meter.py:
--------------------------------------------------------------------------------
1 | class AverageMeter():
2 | """Computes and stores the average and current value"""
3 | def __init__(self, name, fmt=':f'):
4 | self.name = name
5 | self.fmt = fmt
6 | self.log = []
7 | self.val = 0
8 | self.avg = 0
9 | self.sum = 0
10 | self.count = 0
11 |
12 | def reset(self):
13 | self.log.append(self.avg)
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
24 |
25 | def __str__(self):
26 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
27 | return fmtstr.format(**self.__dict__)
28 |
29 | if __name__ == "__main__":
30 | meter = AverageMeter('sldk')
31 | print(meter.log)
32 |
33 |
--------------------------------------------------------------------------------
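`AverageMeter.update` takes the batch size as `n`, so the running average is sample-weighted; for example:
```
from tools.average_meter import AverageMeter

meter = AverageMeter('loss', ':.4f')
meter.update(0.9, n=32)   # batch of 32, mean loss 0.9
meter.update(0.5, n=16)   # batch of 16, mean loss 0.5
print(meter)              # loss 0.5000 (0.7667)
```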
/tools/file_exist_fn.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 |
5 | def file_exist_check(file_dir):
6 |
7 | if os.path.isdir(file_dir):
8 | for i in range(2, 1000):
9 | if not os.path.isdir(file_dir + f'({i})'):
10 | file_dir += f'({i})'
11 | break
12 | return file_dir
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tools/knn_monitor.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch.nn.functional as F
3 | import torch
4 | # code copied from https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb#scrollTo=RI1Y8bSImD7N
5 | # test using a knn monitor
6 | def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1, hide_progress=False):
7 | net.eval()
8 | classes = len(memory_data_loader.dataset.classes)
9 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
10 | with torch.no_grad():
11 | # generate feature bank
12 | for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=hide_progress):
13 | feature = net(data.cuda(non_blocking=True))
14 | feature = F.normalize(feature, dim=1)
15 | feature_bank.append(feature)
16 | # [D, N]
17 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
18 | # [N]
19 | feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
20 | # loop test data to predict the label by weighted knn search
21 | test_bar = tqdm(test_data_loader, desc='kNN', disable=hide_progress)
22 | for data, target in test_bar:
23 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
24 | feature = net(data)
25 | feature = F.normalize(feature, dim=1)
26 |
27 | pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t)
28 |
29 | total_num += data.size(0)
30 | total_top1 += (pred_labels[:, 0] == target).float().sum().item()
31 | test_bar.set_postfix({'Accuracy':total_top1 / total_num * 100})
32 | return total_top1 / total_num * 100
33 |
34 | # knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
35 | # implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
36 | def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
37 | # compute cos similarity between each feature vector and feature bank ---> [B, N]
38 | sim_matrix = torch.mm(feature, feature_bank)
39 | # [B, K]
40 | sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
41 | # [B, K]
42 | sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
43 | sim_weight = (sim_weight / knn_t).exp()
44 |
45 | # counts for each class
46 | one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
47 | # [B*K, C]
48 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
49 | # weighted score ---> [B, C]
50 | pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
51 |
52 | pred_labels = pred_scores.argsort(dim=-1, descending=True)
53 | return pred_labels
54 |
--------------------------------------------------------------------------------
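A self-contained sketch of `knn_predict` on a toy memory bank (two classes, one query drawn near class 0; note the bank is passed as [D, N]):
```
import torch
import torch.nn.functional as F
from tools.knn_monitor import knn_predict

bank = F.normalize(torch.randn(6, 8), dim=1)       # [N, D] features
labels = torch.tensor([0, 0, 0, 1, 1, 1])          # [N]
query = F.normalize(bank[0:1] + 0.05 * torch.randn(1, 8), dim=1)

pred = knn_predict(query, bank.t().contiguous(), labels,
                   classes=2, knn_k=3, knn_t=0.1)
print(pred[:, 0])   # most-likely label; almost always tensor([0]) here
```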
/tools/logger.py:
--------------------------------------------------------------------------------
1 | # import torch
2 | # try:
3 | # from torch.utils.tensorboard import SummaryWriter
4 | # except ImportError:
5 | from tensorboardX import SummaryWriter
6 |
7 | from torch import Tensor
8 | from collections import OrderedDict
9 | import os
10 | from .plotter import Plotter
11 |
12 |
13 | class Logger(object):
14 | def __init__(self, log_dir, tensorboard=True, matplotlib=True):
15 |
16 | self.reset(log_dir, tensorboard, matplotlib)
17 |
18 | def reset(self, log_dir=None, tensorboard=True, matplotlib=True):
19 |
20 | if log_dir is not None: self.log_dir=log_dir
21 | self.writer = SummaryWriter(log_dir=self.log_dir) if tensorboard else None
22 | self.plotter = Plotter() if matplotlib else None
23 | self.counter = OrderedDict()
24 |
25 | def update_scalers(self, ordered_dict):
26 |
27 | for key, value in ordered_dict.items():
28 | if isinstance(value, Tensor):
29 | ordered_dict[key] = value.item()
30 | if self.counter.get(key) is None:
31 | self.counter[key] = 1
32 | else:
33 | self.counter[key] += 1
34 |
35 | if self.writer:
36 | self.writer.add_scalar(key, value, self.counter[key])
37 |
38 |
39 | if self.plotter:
40 | self.plotter.update(ordered_dict)
41 | self.plotter.save(os.path.join(self.log_dir, 'plotter.svg'))
42 |
43 |
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
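A minimal `Logger` usage sketch (writes tensorboard events plus `plotter.svg` under the given directory; assumes tensorboardX and matplotlib from requirements.txt are installed):
```
from collections import OrderedDict
from tools.logger import Logger

logger = Logger('./logs/demo', tensorboard=True, matplotlib=True)
for step in range(3):
    # each key keeps its own counter, used as the tensorboard global step
    logger.update_scalers(OrderedDict(loss=1.0 / (step + 1), lr=0.05))
```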
/tools/plotter.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg') #https://stackoverflow.com/questions/49921721/runtimeerror-main-thread-is-not-in-main-loop-with-matplotlib-and-flask
3 | import matplotlib.pyplot as plt
4 | from collections import OrderedDict
5 | from torch import Tensor
6 |
7 | class Plotter(object):
8 | def __init__(self):
9 | self.logger = OrderedDict()
10 | def update(self, ordered_dict):
11 | for key, value in ordered_dict.items():
12 | if isinstance(value, Tensor):
13 | ordered_dict[key] = value.item()
14 | if self.logger.get(key) is None:
15 | self.logger[key] = [value]
16 | else:
17 | self.logger[key].append(value)
18 |
19 | def save(self, file, **kwargs):
20 |         fig, axes = plt.subplots(nrows=len(self.logger), ncols=1, figsize=(8,2*len(self.logger)), squeeze=False)
21 |         fig.tight_layout()
22 |         for ax, (key, value) in zip(axes.flatten(), self.logger.items()):  # squeeze=False keeps axes iterable even with a single metric
23 | ax.plot(value)
24 | ax.set_title(key)
25 |
26 | plt.savefig(file, **kwargs)
27 | plt.close()
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------