├── README.md ├── classification ├── Sampling.py ├── main_arb.py ├── main_cifar10.py ├── main_cifar100.py ├── main_imagenet.py ├── test_arb.py ├── test_imagenet.py └── vit │ ├── .ipynb_checkpoints │ ├── __init__-checkpoint.py │ ├── configs-checkpoint.py │ ├── model-checkpoint.py │ ├── transformer-checkpoint.py │ └── utils-checkpoint.py │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── configs.cpython-36.pyc │ ├── configs.cpython-37.pyc │ ├── configs.cpython-38.pyc │ ├── model.cpython-36.pyc │ ├── model.cpython-37.pyc │ ├── model.cpython-38.pyc │ ├── transformer.cpython-36.pyc │ ├── transformer.cpython-37.pyc │ ├── transformer.cpython-38.pyc │ ├── utils.cpython-36.pyc │ ├── utils.cpython-37.pyc │ └── utils.cpython-38.pyc │ ├── configs.py │ ├── model.py │ ├── transformer.py │ └── utils.py ├── figs ├── README.md └── network.PNG └── segmentation └── README.md /README.md: -------------------------------------------------------------------------------- 1 | # 📖TransCL: Transformer Makes Strong and Flexible Compressive Learning (TPAMI 2022) 2 | >[![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://ieeexplore.ieee.org/document/9841016)
3 | >[Chong Mou](https://scholar.google.com.hk/citations?user=SYQoDk0AAAAJ&hl=en), [Jian Zhang](https://jianzhang.tech/)
4 | 5 | 6 | <p align="center"><img src="figs/network.PNG"></p> 7 | 8 |
9 | 10 | ## 🔧 Dependencies and Installation 11 | 12 | - Python >= 3.6 (we recommend [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html)) 13 | - [PyTorch >= 1.4](https://pytorch.org/) 14 | - At least two V100 GPUs are required. 15 | 16 | ### Installation 17 | 18 | 1. Clone the repo 19 | 20 | ```bash 21 | git clone https://github.com/MC-E/TransCL.git 22 | cd TransCL 23 | ``` 24 | 25 | 2. Install dependencies 26 | 27 | ```bash 28 | pip install tensorboardX 29 | conda install pytorch=1.6.0 torchvision cudatoolkit=10.1 -c pytorch -y 30 | pip install mmcv-full==1.2.2 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html 31 | ``` 32 | 33 | ## Training for Classification 34 | Run the following commands from the `classification` folder.<br>
35 | 1. Prepare the training data of [ImageNet1K](https://image-net.org/download-images.php) 36 | 2. Download the pre-trained checkpoints of vision transformer from [link-ImageNet](https://disk.pku.edu.cn:443/link/9B8CAF903895E2BEDBA1E58641A3C4E3) and [link-CIFAR](https://disk.pku.edu.cn:443/link/1437785D0ECDE48C86AEC4EF13E61939). 37 | 3. The training support ViT-base with patch size being 16 (`-a B_16_imagenet1k`) and 32 (`-a B_32_imagenet1k`). 38 | #### Training on ImageNet with a fixed CS ratio 39 | ```bash 40 | python main_imagenet.py -a 'B_32_imagenet1k' -b 128 --image_size 384 --gpu 0 --lr 1e-3 --log_dir logs/transcl_384_imagenet_p32_01 --cs=1 --mm=1 --save_path=transcl_384_imagenet_p32_01 --devices=4 --rat 0.1 --data /group/30042/public_datasets/imagenet1k 41 | ``` 42 | 43 | #### Training on ImageNet with arbitrary CS ratios 44 | ```bash 45 | python main_arb.py -a 'B_32_imagenet1k' -b 128 --image_size 384 --gpu 0 --lr 1e-3 --log_dir logs/transcl_384_imagenet_p32_01 --cs=1 --mm=1 --save_path=transcl_384_imagenet_p32_01 --devices=4 --rat 0.1 --data /group/30042/public_datasets/imagenet1k 46 | ``` 47 | 48 | #### Training on Cifar10 with a fixed CS ratio 49 | ```bash 50 | python main_cifar10.py -a 'B_32_imagenet1k' -b 128 --image_size 384 --gpu 0 --lr 1e-3 --log_dir logs/transcl_384_cifar10_p32_01 --cs=1 --mm=1 --save_path=transcl_384_cifar10_p32_01 --devices=4 --rat 0.1 51 | ``` 52 | 53 | #### Training on Cifar100 with a fixed CS ratio 54 | ```bash 55 | python main_cifar100.py -a 'B_32_imagenet1k' -b 128 --image_size 384 --gpu 0 --lr 1e-3 --log_dir logs/transcl_384_cifar100_p32_01 --cs=1 --mm=1 --save_path=transcl_384_cifar100_p32_01 --devices=4 --rat 0.1 56 | ``` 57 | 58 | ## Training for Segmentation 59 | Coming soon 60 | 61 | ## Testing for Classification 62 | You can download the pre-trained checkpoints from our model zoo. 
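For reference, testing follows the same pipeline as training: the CS sampling module turns an image into block-wise measurements and folds them back into an image-shaped proxy, which the ViT backbone then classifies. The minimal sketch below illustrates this flow with the classes in `classification/Sampling.py` and `classification/vit`; it assumes the checkpoint layout written by the training scripts (a dict with a `state_dict` entry saved from an `nn.DataParallel` wrapper), and the checkpoint file names are placeholders.

```python
import torch
from Sampling import CS_Sampling
from vit import ViT

def strip_module_prefix(state_dict):
    # Training wraps the modules in nn.DataParallel, so saved keys carry a
    # leading 'module.' that has to be dropped before loading into bare modules.
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

# 'ckp_vit.pth' and 'ckp_cs.pth' are placeholder names for downloaded checkpoints.
model = ViT('B_32_imagenet1k', pretrained=False, image_size=(384, 384)).cuda().eval()
cs_sampling = CS_Sampling(n_channels=3, cs_ratio=0.1, blocksize=32, im_size=384).cuda().eval()
model.load_state_dict(strip_module_prefix(torch.load('ckp_vit.pth')['state_dict']))
cs_sampling.load_state_dict(strip_module_prefix(torch.load('ckp_cs.pth')['state_dict']))

with torch.no_grad():
    x = torch.randn(1, 3, 384, 384).cuda()  # stands in for an image normalized to [-1, 1]
    proxy = cs_sampling(x)                  # block-wise measurements folded back to image space
    logits = model(proxy)                   # ViT classifies the measurement proxy
print(logits.shape)                         # torch.Size([1, 1000]) for ImageNet-1K
```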
63 | ### Testing on ImageNet with a fixed CS ratio 64 | ```bash 65 | 66 | python test_imagenet.py -a 'B_32_imagenet1k' -b 128 --image_size 384 67 | ``` 68 | 69 | ### Testing on ImageNet with arbitrary CS ratios 70 | ```bash 71 | 72 | python test_arb.py -a 'B_32_imagenet1k' -b 128 --image_size 384 73 | ``` 74 | 75 | ## Testing for Segmentation 76 | Coming soon 77 | 78 | ## :european_castle: Model Zoo 79 | ### Classification 80 | 81 | | Mode | Download link | 82 | | :------------------- | :--------------------------------------------: | 83 | | Pre-trained ViT | [ImageNet](https://disk.pku.edu.cn:443/link/9B8CAF903895E2BEDBA1E58641A3C4E3), [CIFAR](https://disk.pku.edu.cn:443/link/1437785D0ECDE48C86AEC4EF13E61939) | 84 | | ImageNet classification (patch size=16, ratio={0.1, 0.05, 0.025, 0.01}) | [URL](https://disk.pku.edu.cn:443/link/750ECBBCE56BC0A60A81CA2C9B09DEE1) | 85 | | ImageNet classification (patch size=32, ratio={0.1, 0.05, 0.025, 0.01}) | [URL](https://disk.pku.edu.cn:443/link/2713FF6650438ACDAB39411A74FF9334) | 86 | | Distilled ImageNet classification (patch size=16, ratio={0.1, 0.05, 0.025, 0.01}) | URL | 87 | | Cifar10 classification | URL | 88 | | Cifar100 classification | URL | 89 | | Arbitrary ratio classification (patch size=32) | [URL](https://disk.pku.edu.cn:443/link/8A1FC9F83F364F3CEAA0F6C048DAB362) | 90 | | Binary sampling classification | URL | 91 | | Shuffled classification | URL | 92 | 93 | ### Segmentation 94 | 95 | | Mode | Download link | 96 | | :------------------- | :--------------------------------------------: | 97 | | Pre-trained ViT-large | URL | 98 | | Segmentation on ADE20K with fixed ratio (patch size=16, ratio={0.1, 0.05, 0.025, 0.01}) | URL | 99 | | Segmentation on Cityscapes with fixed ratio (patch size=16, ratio={0.1, 0.05, 0.025, 0.01}) | URL | 100 | | Segmentation on Pascal Context with fixed ratio (patch size=16, ratio={0.1, 0.05, 0.025, 0.01}) | URL | 101 | | Segmentation with arbitrary ratios | URL | 102 | | Segmentation with binary sampling | URL | 103 | 104 | ## BibTeX 105 | 106 | @article{mou2022transcl, 107 | title={TransCL: Transformer makes strong and flexible compressive learning}, 108 | author={Mou, Chong and Zhang, Jian}, 109 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 110 | year={2022}, 111 | publisher={IEEE} 112 | } 113 | -------------------------------------------------------------------------------- /classification/Sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import cv2 5 | import numpy as np 6 | import random 7 | import torch.nn.functional as F 8 | 9 | 10 | class MySign(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, input): 13 | ctx.save_for_backward(input) 14 | output = input.new(input.size()) 15 | output[input >= 0] = 1 16 | output[input < 0] = -1 17 | return output 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | input_, = ctx.saved_tensors 22 | # print(input_.shape) 23 | grad_input = grad_output.clone() 24 | grad_input[input_ < -1] = 0 25 | grad_input[input_ > 1] = 0 26 | return grad_input 27 | 28 | 29 | MyBinarize = MySign.apply 30 | 31 | 32 | class CS_Sampling_bin(torch.nn.Module): 33 | def __init__(self, n_channels=3, cs_ratio=0.25, blocksize=32): 34 | super(CS_Sampling_bin, self).__init__() 35 | 36 | print('CS ratio: ', cs_ratio) 37 | 38 | n_output = int(blocksize ** 2) 39 | n_input = int(cs_ratio * n_output) 40 | 41 | self.PhiR = 
nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 42 | self.PhiG = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 43 | self.PhiB = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 44 | 45 | self.Phi_scaleR = nn.Parameter(torch.Tensor([0.01])) 46 | self.Phi_scaleG = nn.Parameter(torch.Tensor([0.01])) 47 | self.Phi_scaleB = nn.Parameter(torch.Tensor([0.01])) 48 | 49 | self.n_channels = n_channels 50 | self.n_input = n_input 51 | self.n_output = n_output 52 | self.blocksize = blocksize 53 | 54 | def forward(self, x): 55 | Phi_R = MyBinarize(self.PhiR) * self.Phi_scaleR 56 | Phi_G = MyBinarize(self.PhiG) * self.Phi_scaleG 57 | Phi_B = MyBinarize(self.PhiB) * self.Phi_scaleB 58 | 59 | PhiWeight_R = Phi_R.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 60 | PhiWeight_G = Phi_G.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 61 | PhiWeight_B = Phi_B.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 62 | 63 | Phix_R = F.conv2d(x[:, 0:1, :, :], PhiWeight_R, padding=0, stride=self.blocksize, bias=None) # Get measurements 64 | Phix_G = F.conv2d(x[:, 1:2, :, :], PhiWeight_G, padding=0, stride=self.blocksize, bias=None) # Get measurements 65 | Phix_B = F.conv2d(x[:, 2:3, :, :], PhiWeight_B, padding=0, stride=self.blocksize, bias=None) # Get measurements 66 | 67 | # Initialization-subnet 68 | PhiTWeight_R = Phi_R.t().contiguous().view(self.n_output, self.n_input, 1, 1) 69 | PhiTb_R = F.conv2d(Phix_R, PhiTWeight_R, padding=0, bias=None) 70 | PhiTb_R = torch.nn.PixelShuffle(self.blocksize)(PhiTb_R) 71 | x_R = PhiTb_R # Conduct initialization 72 | 73 | PhiTWeight_G = Phi_G.t().contiguous().view(self.n_output, self.n_input, 1, 1) 74 | PhiTb_G = F.conv2d(Phix_G, PhiTWeight_G, padding=0, bias=None) 75 | PhiTb_G = torch.nn.PixelShuffle(self.blocksize)(PhiTb_G) 76 | x_G = PhiTb_G 77 | 78 | PhiTWeight_B = Phi_B.t().contiguous().view(self.n_output, self.n_input, 1, 1) 79 | PhiTb_B = F.conv2d(Phix_B, PhiTWeight_B, padding=0, bias=None) 80 | PhiTb_B = torch.nn.PixelShuffle(self.blocksize)(PhiTb_B) 81 | x_B = PhiTb_B 82 | 83 | x = torch.cat([x_R, x_G, x_B], dim=1) 84 | 85 | return x 86 | 87 | 88 | class CS_Sampling(torch.nn.Module): 89 | def __init__(self, n_channels=3, cs_ratio=0.25, blocksize=32, im_size=384): 90 | super(CS_Sampling, self).__init__() 91 | print('bcs') 92 | 93 | n_output = int(blocksize ** 2) 94 | n_input = int(cs_ratio * n_output) 95 | 96 | self.PhiR = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 97 | self.PhiG = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 98 | self.PhiB = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 99 | 100 | self.n_channels = n_channels 101 | self.n_input = n_input 102 | self.n_output = n_output 103 | self.blocksize = blocksize 104 | 105 | self.im_size = im_size 106 | 107 | def forward(self, x): 108 | Phi_R = self.PhiR 109 | Phi_G = self.PhiG 110 | Phi_B = self.PhiB 111 | 112 | PhiWeight_R = Phi_R.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 113 | PhiWeight_G = Phi_G.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 114 | PhiWeight_B = Phi_B.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 115 | 116 | Phix_R = F.conv2d(x[:, 0:1, :, :], PhiWeight_R, padding=0, stride=self.blocksize, bias=None) # Get measurements 117 | Phix_G = F.conv2d(x[:, 1:2, :, :], PhiWeight_G, padding=0, stride=self.blocksize, bias=None) # Get 
measurements 118 | Phix_B = F.conv2d(x[:, 2:3, :, :], PhiWeight_B, padding=0, stride=self.blocksize, bias=None) # Get measurements 119 | 120 | # Initialization-subnet 121 | PhiTWeight_R = Phi_R.t().contiguous().view(self.n_output, self.n_input, 1, 1) 122 | PhiTb_R = F.conv2d(Phix_R, PhiTWeight_R, padding=0, bias=None) 123 | PhiTb_R = torch.nn.PixelShuffle(self.blocksize)(PhiTb_R) 124 | x_R = PhiTb_R # Conduct initialization 125 | 126 | PhiTWeight_G = Phi_G.t().contiguous().view(self.n_output, self.n_input, 1, 1) 127 | PhiTb_G = F.conv2d(Phix_G, PhiTWeight_G, padding=0, bias=None) 128 | PhiTb_G = torch.nn.PixelShuffle(self.blocksize)(PhiTb_G) 129 | x_G = PhiTb_G 130 | 131 | PhiTWeight_B = Phi_B.t().contiguous().view(self.n_output, self.n_input, 1, 1) 132 | PhiTb_B = F.conv2d(Phix_B, PhiTWeight_B, padding=0, bias=None) 133 | PhiTb_B = torch.nn.PixelShuffle(self.blocksize)(PhiTb_B) 134 | x_B = PhiTb_B 135 | 136 | x = torch.cat([x_R, x_G, x_B], dim=1) 137 | x = F.interpolate(x, size=(self.im_size, self.im_size), mode='bilinear') 138 | 139 | return x 140 | 141 | 142 | class CS_Sampling_rm(torch.nn.Module): 143 | def __init__(self, n_channels=3, cs_ratio=0.25, blocksize=32, image_size=384, rate_rm=0.1, random=True): 144 | super(CS_Sampling_rm, self).__init__() 145 | 146 | n_output = int(blocksize ** 2) 147 | n_input = int(cs_ratio * n_output) 148 | 149 | self.PhiR = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 150 | self.PhiG = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 151 | self.PhiB = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 152 | 153 | self.n_channels = n_channels 154 | self.n_input = n_input 155 | self.n_output = n_output 156 | self.blocksize = blocksize 157 | 158 | self.random = random 159 | self.num_blocks = (image_size // blocksize) ** 2 160 | self.rate_rm = rate_rm 161 | self.image_size = image_size 162 | self.blocksize = blocksize 163 | 164 | num_x = image_size // blocksize 165 | self.pos = [[i, j] for j in range(num_x) for i in range(num_x)] 166 | 167 | def generate_mask(self, size, psize, num_rm): 168 | mask = torch.ones(size=(size, size)) 169 | random.shuffle(self.pos) 170 | pos_rm = self.pos[:num_rm] 171 | for pos_rm_i in pos_rm: 172 | mask[pos_rm_i[0] * psize:(pos_rm_i[0] + 1) * psize, pos_rm_i[1] * psize:(pos_rm_i[1] + 1) * psize] = 0. 173 | return mask.unsqueeze(0).unsqueeze(0) 174 | 175 | def updata_rat(self, rat): 176 | self.rate_rm = rat 177 | 178 | def forward(self, x): 179 | # print(self.rate_rm) 180 | img = x 181 | if self.random == True: 182 | rate_rm = torch.rand(1) * 0.5 183 | else: 184 | rate_rm = self.rate_rm 185 | 186 | num_rm = int(rate_rm * self.num_blocks) 187 | # print(rate_rm,num_rm,self.num_blocks) 188 | masks = [] 189 | for b in range(x.shape[0]): 190 | masks.append(self.generate_mask(size=self.image_size, psize=self.blocksize, num_rm=num_rm)) 191 | masks = torch.cat(masks, dim=0).cuda() 192 | x = x * masks 193 | img = ((((img + 1.) / 2. 
* masks)[0].permute(1, 2, 0).cpu().data.numpy()) * 255.).astype(np.uint8) 194 | cv2.imwrite('mask.png', img) 195 | exit(0) 196 | 197 | Phi_R = self.PhiR 198 | Phi_G = self.PhiG 199 | Phi_B = self.PhiB 200 | 201 | PhiWeight_R = Phi_R.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 202 | PhiWeight_G = Phi_G.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 203 | PhiWeight_B = Phi_B.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 204 | 205 | Phix_R = F.conv2d(x[:, 0:1, :, :], PhiWeight_R, padding=0, stride=self.blocksize, bias=None) # Get measurements 206 | Phix_G = F.conv2d(x[:, 1:2, :, :], PhiWeight_G, padding=0, stride=self.blocksize, bias=None) # Get measurements 207 | Phix_B = F.conv2d(x[:, 2:3, :, :], PhiWeight_B, padding=0, stride=self.blocksize, bias=None) # Get measurements 208 | 209 | # Initialization-subnet 210 | PhiTWeight_R = Phi_R.t().contiguous().view(self.n_output, self.n_input, 1, 1) 211 | PhiTb_R = F.conv2d(Phix_R, PhiTWeight_R, padding=0, bias=None) 212 | PhiTb_R = torch.nn.PixelShuffle(self.blocksize)(PhiTb_R) 213 | x_R = PhiTb_R # Conduct initialization 214 | 215 | PhiTWeight_G = Phi_G.t().contiguous().view(self.n_output, self.n_input, 1, 1) 216 | PhiTb_G = F.conv2d(Phix_G, PhiTWeight_G, padding=0, bias=None) 217 | PhiTb_G = torch.nn.PixelShuffle(self.blocksize)(PhiTb_G) 218 | x_G = PhiTb_G 219 | 220 | PhiTWeight_B = Phi_B.t().contiguous().view(self.n_output, self.n_input, 1, 1) 221 | PhiTb_B = F.conv2d(Phix_B, PhiTWeight_B, padding=0, bias=None) 222 | PhiTb_B = torch.nn.PixelShuffle(self.blocksize)(PhiTb_B) 223 | x_B = PhiTb_B 224 | 225 | x = torch.cat([x_R, x_G, x_B], dim=1) 226 | 227 | return x 228 | 229 | 230 | class CS_Sampling_shuffle(torch.nn.Module): 231 | def __init__(self, n_channels=3, cs_ratio=0.25, blocksize=32, image_size=384, rate_rm=0.1, random=True): 232 | super(CS_Sampling_shuffle, self).__init__() 233 | 234 | n_output = int(blocksize ** 2) 235 | n_input = int(cs_ratio * n_output) 236 | 237 | self.PhiR = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 238 | self.PhiG = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 239 | self.PhiB = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, n_output))) 240 | 241 | self.n_channels = n_channels 242 | self.n_input = n_input 243 | self.n_output = n_output 244 | self.blocksize = blocksize 245 | 246 | self.random = random 247 | self.num_blocks = (image_size // blocksize) ** 2 248 | self.rate_rm = rate_rm 249 | self.image_size = image_size 250 | self.blocksize = blocksize 251 | 252 | num_x = image_size // blocksize 253 | self.pos = [[i, j] for j in range(num_x) for i in range(num_x)] 254 | 255 | def shuffle(self, x): 256 | h, w = x.shape[-2], x.shape[-1] 257 | blocks = F.unfold(x, kernel_size=32, stride=32).permute(0, 2, 1) 258 | l = blocks.shape[1] 259 | idxes = list(range(l)) 260 | random.shuffle(idxes) 261 | blocks = blocks[:, idxes, :] 262 | blocks = blocks.permute(0, 2, 1) 263 | return F.fold(blocks, output_size=(h, w), kernel_size=32, stride=32).contiguous() 264 | 265 | def forward(self, x): 266 | x = self.shuffle(x) 267 | Phi_R = self.PhiR 268 | Phi_G = self.PhiG 269 | Phi_B = self.PhiB 270 | 271 | PhiWeight_R = Phi_R.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 272 | PhiWeight_G = Phi_G.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 273 | PhiWeight_B = Phi_B.contiguous().view(int(self.n_input), 1, self.blocksize, self.blocksize) 274 | 275 | Phix_R = 
F.conv2d(x[:, 0:1, :, :], PhiWeight_R, padding=0, stride=self.blocksize, bias=None) # Get measurements 276 | Phix_G = F.conv2d(x[:, 1:2, :, :], PhiWeight_G, padding=0, stride=self.blocksize, bias=None) # Get measurements 277 | Phix_B = F.conv2d(x[:, 2:3, :, :], PhiWeight_B, padding=0, stride=self.blocksize, bias=None) # Get measurements 278 | 279 | # Initialization-subnet 280 | PhiTWeight_R = Phi_R.t().contiguous().view(self.n_output, self.n_input, 1, 1) 281 | PhiTb_R = F.conv2d(Phix_R, PhiTWeight_R, padding=0, bias=None) 282 | PhiTb_R = torch.nn.PixelShuffle(self.blocksize)(PhiTb_R) 283 | x_R = PhiTb_R # Conduct initialization 284 | 285 | PhiTWeight_G = Phi_G.t().contiguous().view(self.n_output, self.n_input, 1, 1) 286 | PhiTb_G = F.conv2d(Phix_G, PhiTWeight_G, padding=0, bias=None) 287 | PhiTb_G = torch.nn.PixelShuffle(self.blocksize)(PhiTb_G) 288 | x_G = PhiTb_G 289 | 290 | PhiTWeight_B = Phi_B.t().contiguous().view(self.n_output, self.n_input, 1, 1) 291 | PhiTb_B = F.conv2d(Phix_B, PhiTWeight_B, padding=0, bias=None) 292 | PhiTb_B = torch.nn.PixelShuffle(self.blocksize)(PhiTb_B) 293 | x_B = PhiTb_B 294 | 295 | x = torch.cat([x_R, x_G, x_B], dim=1) 296 | 297 | return x 298 | 299 | 300 | class CS_Sampling_arb(torch.nn.Module): 301 | def __init__(self, n_channels=3, cs_ratio=0.25, blocksize=32): 302 | super(CS_Sampling_arb, self).__init__() 303 | 304 | n_output = int(blocksize ** 2) 305 | 306 | self.PhiR = nn.Parameter(init.xavier_normal_(torch.Tensor(blocksize * blocksize, blocksize * blocksize))) 307 | self.PhiG = nn.Parameter(init.xavier_normal_(torch.Tensor(blocksize * blocksize, blocksize * blocksize))) 308 | self.PhiB = nn.Parameter(init.xavier_normal_(torch.Tensor(blocksize * blocksize, blocksize * blocksize))) 309 | 310 | self.n_channels = n_channels 311 | self.n_output = n_output 312 | self.blocksize = blocksize 313 | 314 | def forward(self, x, num_rows=None): 315 | if num_rows is None: 316 | num_rows = np.random.randint(1, 1024) 317 | 318 | Phi_R = self.PhiR[:num_rows, :] 319 | Phi_G = self.PhiG[:num_rows, :] 320 | Phi_B = self.PhiB[:num_rows, :] 321 | 322 | PhiWeight_R = Phi_R.contiguous().view(num_rows, 1, self.blocksize, self.blocksize) 323 | PhiWeight_G = Phi_G.contiguous().view(num_rows, 1, self.blocksize, self.blocksize) 324 | PhiWeight_B = Phi_B.contiguous().view(num_rows, 1, self.blocksize, self.blocksize) 325 | 326 | Phix_R = F.conv2d(x[:, 0:1, :, :], PhiWeight_R, padding=0, stride=self.blocksize, bias=None) # Get measurements 327 | Phix_G = F.conv2d(x[:, 1:2, :, :], PhiWeight_G, padding=0, stride=self.blocksize, bias=None) # Get measurements 328 | Phix_B = F.conv2d(x[:, 2:3, :, :], PhiWeight_B, padding=0, stride=self.blocksize, bias=None) # Get measurements 329 | 330 | # Initialization-subnet 331 | PhiTWeight_R = Phi_R.t().contiguous().view(self.n_output, num_rows, 1, 1) 332 | PhiTb_R = F.conv2d(Phix_R, PhiTWeight_R, padding=0, bias=None) 333 | PhiTb_R = torch.nn.PixelShuffle(self.blocksize)(PhiTb_R) 334 | x_R = PhiTb_R # Conduct initialization 335 | 336 | PhiTWeight_G = Phi_G.t().contiguous().view(self.n_output, num_rows, 1, 1) 337 | PhiTb_G = F.conv2d(Phix_G, PhiTWeight_G, padding=0, bias=None) 338 | PhiTb_G = torch.nn.PixelShuffle(self.blocksize)(PhiTb_G) 339 | x_G = PhiTb_G 340 | 341 | PhiTWeight_B = Phi_B.t().contiguous().view(self.n_output, num_rows, 1, 1) 342 | PhiTb_B = F.conv2d(Phix_B, PhiTWeight_B, padding=0, bias=None) 343 | PhiTb_B = torch.nn.PixelShuffle(self.blocksize)(PhiTb_B) 344 | x_B = PhiTb_B 345 | 346 | x = torch.cat([x_R, x_G, x_B], dim=1) 347 | 
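# x now holds the R/G/B initialization images concatenated along the channel dimension;
# unlike CS_Sampling, no bilinear resize to a fixed im_size is applied before returning.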
348 | return x 349 | 350 | if __name__ == '__main__': 351 | cs_sampling = CS_Sampling(n_channels=3, cs_ratio=0.25, blocksize=16, im_size=384) 352 | input_img = torch.randn(2, 3, 32, 32) 353 | output_img = cs_sampling(input_img) 354 | print(output_img.shape) -------------------------------------------------------------------------------- /classification/main_arb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import PIL 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | from Sampling import CS_Sampling_arb as CS_Sampling 22 | from tensorboardX import SummaryWriter 23 | 24 | from vit import ViT, load_pretrained_weights 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 27 | parser.add_argument('--data', default='/gdata/ImageNet2012', 28 | help='path to dataset') 29 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 30 | help='model architecture (default: resnet18)') 31 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 32 | help='number of data loading workers (default: 4)') 33 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 34 | help='number of total epochs to run') 35 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 36 | help='manual epoch number (useful on restarts)') 37 | parser.add_argument('-b', '--batch-size', default=256, type=int, 38 | metavar='N', 39 | help='mini-batch size (default: 256), this is the total ' 40 | 'batch size of all GPUs on the current node when ' 41 | 'using Data Parallel or Distributed Data Parallel') 42 | # parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 43 | # metavar='LR', help='initial learning rate', dest='lr') 44 | parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, 45 | metavar='LR', help='initial learning rate', dest='lr') 46 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 47 | help='momentum') 48 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 49 | metavar='W', help='weight decay (default: 1e-4)', 50 | dest='weight_decay') 51 | parser.add_argument('-p', '--print-freq', default=10, type=int, 52 | metavar='N', help='print frequency (default: 10)') 53 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 54 | help='path to latest checkpoint (default: none)') 55 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 56 | help='evaluate model on validation set') 57 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 58 | help='use pre-trained model') 59 | parser.add_argument('--world-size', default=-1, type=int, 60 | help='number of nodes for distributed training') 61 | parser.add_argument('--rank', default=-1, type=int, 62 | help='node rank for distributed training') 63 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 64 | help='url used to set up distributed training') 65 | parser.add_argument('--dist-backend', 
default='nccl', type=str, 66 | help='distributed backend') 67 | parser.add_argument('--seed', default=None, type=int, 68 | help='seed for initializing training. ') 69 | parser.add_argument('--gpu', default=None, type=int, 70 | help='GPU id to use.') 71 | parser.add_argument('--image_size', default=224, type=int, 72 | help='image size') 73 | parser.add_argument('--vit', default=True, help='use ViT model') 74 | parser.add_argument('--multiprocessing-distributed', action='store_true', 75 | help='Use multi-processing distributed training to launch ' 76 | 'N processes per node, which has N GPUs. This is the ' 77 | 'fastest way to use PyTorch for either single node or ' 78 | 'multi node data parallel training') 79 | parser.add_argument('--cs',type=int,default=0) 80 | parser.add_argument('--log_dir',default='logs') 81 | parser.add_argument('--mm',type=int,default=0) 82 | parser.add_argument('--save_path',type=str,default='ckp') 83 | parser.add_argument('--rat',type=float,default=0.1) 84 | parser.add_argument('--devices',type=int,default=4) 85 | parser.add_argument('--psize',type=int,default=32) 86 | best_acc1 = 0 87 | 88 | 89 | def main(): 90 | args = parser.parse_args() 91 | 92 | if args.seed is not None: 93 | random.seed(args.seed) 94 | torch.manual_seed(args.seed) 95 | cudnn.deterministic = True 96 | warnings.warn('You have chosen to seed training. ' 97 | 'This will turn on the CUDNN deterministic setting, ' 98 | 'which can slow down your training considerably! ' 99 | 'You may see unexpected behavior when restarting ' 100 | 'from checkpoints.') 101 | 102 | if args.gpu is not None: 103 | warnings.warn('You have chosen a specific GPU. This will completely ' 104 | 'disable data parallelism.') 105 | 106 | if args.dist_url == "env://" and args.world_size == -1: 107 | args.world_size = int(os.environ["WORLD_SIZE"]) 108 | 109 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 110 | 111 | ngpus_per_node = torch.cuda.device_count() 112 | if args.multiprocessing_distributed: 113 | # Since we have ngpus_per_node processes per node, the total world_size 114 | # needs to be adjusted accordingly 115 | args.world_size = ngpus_per_node * args.world_size 116 | # Use torch.multiprocessing.spawn to launch distributed processes: the 117 | # main_worker process function 118 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 119 | else: 120 | # Simply call main_worker function 121 | main_worker(args.gpu, ngpus_per_node, args) 122 | 123 | 124 | def main_worker(gpu, ngpus_per_node, args): 125 | # print(args.cs,args.lr) 126 | # exit(0) 127 | global best_acc1 128 | args.gpu = gpu 129 | patch_size = args.image_size//16 130 | writer_loss = SummaryWriter(args.log_dir) 131 | writer_acc = SummaryWriter(args.log_dir) 132 | if os.path.isdir(args.save_path)==False: 133 | os.mkdir(args.save_path) 134 | # print('patch_size:', patch_size) 135 | 136 | if args.gpu is not None: 137 | print("Use GPU: {} for training".format(args.gpu)) 138 | 139 | if args.distributed: 140 | if args.dist_url == "env://" and args.rank == -1: 141 | args.rank = int(os.environ["RANK"]) 142 | if args.multiprocessing_distributed: 143 | # For multiprocessing distributed training, rank needs to be the 144 | # global rank among all the processes 145 | args.rank = args.rank * ngpus_per_node + gpu 146 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 147 | world_size=args.world_size, rank=args.rank) 148 | 149 | # NEW 150 | if args.vit: 151 | model = ViT(args.arch, 
pretrained=args.pretrained,image_size=(args.image_size,args.image_size)) 152 | 153 | else: 154 | model = models.__dict__[args.arch](pretrained=args.pretrained) 155 | model.load_state_dict(torch.load('pretrain/B_32_imagenet1k.pth')) 156 | if args.cs==1: 157 | print('Now Use CS.') 158 | print('CS ratio=',args.rat) 159 | cs_sampling = CS_Sampling(n_channels=3, cs_ratio=args.rat, blocksize=args.psize).cuda() 160 | cs_sampling = torch.nn.DataParallel(cs_sampling,range(args.devices)) 161 | else: 162 | cs_sampling = None 163 | print("=> using model '{}' (pretrained={})".format(args.arch, args.pretrained)) 164 | 165 | if args.distributed: 166 | # For multiprocessing distributed, DistributedDataParallel constructor 167 | # should always set the single device scope, otherwise, 168 | # DistributedDataParallel will use all available devices. 169 | if args.gpu is not None: 170 | torch.cuda.set_device(args.gpu) 171 | model.cuda(args.gpu) 172 | # When using a single GPU per process and per 173 | # DistributedDataParallel, we need to divide the batch size 174 | # ourselves based on the total number of GPUs we have 175 | args.batch_size = int(args.batch_size / ngpus_per_node) 176 | args.workers = int(args.workers / ngpus_per_node) 177 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.devices])#[args.gpu]) 178 | else: 179 | model.cuda() 180 | # DistributedDataParallel will divide and allocate batch_size to all 181 | # available GPUs if device_ids are not set 182 | model = torch.nn.parallel.DistributedDataParallel(model) 183 | elif args.gpu is not None: 184 | print('No distribut') 185 | #torch.cuda.set_device(args.gpu) 186 | model.cuda() 187 | model = torch.nn.DataParallel(model,range(args.devices)) 188 | # model = model.cuda(args.gpu) 189 | else: 190 | # DataParallel will divide and allocate batch_size to all available GPUs 191 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 192 | model.features = torch.nn.DataParallel(model.features) 193 | model.cuda() 194 | else: 195 | model = torch.nn.DataParallel(model).cuda() 196 | 197 | # define loss function (criterion) and optimizer 198 | criterion = nn.CrossEntropyLoss().cuda()#(args.gpu) 199 | if args.cs==1: 200 | if args.mm==1: 201 | print('MM!') 202 | optimizer = torch.optim.SGD( 203 | [{'params': model.parameters()}, 204 | {'params': cs_sampling.parameters()}] 205 | ,args.lr, 206 | momentum=args.momentum, 207 | weight_decay=args.weight_decay) 208 | else: 209 | optimizer = torch.optim.SGD(cs_sampling.parameters(), args.lr, 210 | momentum=args.momentum, 211 | weight_decay=args.weight_decay) 212 | else: 213 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 214 | momentum=args.momentum, 215 | weight_decay=args.weight_decay) 216 | 217 | # optionally resume from a checkpoint 218 | if args.resume: 219 | if os.path.isfile(args.resume): 220 | print("=> loading checkpoint '{}'".format(args.resume)) 221 | checkpoint = torch.load(args.resume) 222 | args.start_epoch = checkpoint['epoch'] 223 | best_acc1 = checkpoint['best_acc1'] 224 | if args.gpu is not None: 225 | # best_acc1 may be from a checkpoint from a different GPU 226 | best_acc1 = best_acc1.cuda()#to(args.gpu) 227 | model.load_state_dict(checkpoint['state_dict']) 228 | optimizer.load_state_dict(checkpoint['optimizer']) 229 | print("=> loaded checkpoint '{}' (epoch {})" 230 | .format(args.resume, checkpoint['epoch'])) 231 | else: 232 | print("=> no checkpoint found at '{}'".format(args.resume)) 233 | 234 | cudnn.benchmark = True 235 | 236 | # Data loading code 
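# The ImageNet root passed via --data must follow the torchvision ImageFolder layout,
# i.e. contain 'train' and 'val' sub-directories. Normalize(0.5, 0.5) below maps the
# ToTensor output from [0, 1] to [-1, 1], matching the pre-processing of the
# pre-trained ViT weights.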
237 | traindir = os.path.join(args.data, 'train') 238 | valdir = os.path.join(args.data, 'val') 239 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 240 | normalize = transforms.Normalize(0.5, 0.5) 241 | 242 | train_dataset = datasets.ImageFolder( 243 | traindir, 244 | transforms.Compose([ 245 | transforms.RandomResizedCrop(args.image_size), 246 | transforms.RandomHorizontalFlip(), 247 | transforms.ToTensor(), 248 | normalize, 249 | ])) 250 | 251 | if args.distributed: 252 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 253 | else: 254 | train_sampler = None 255 | 256 | train_loader = torch.utils.data.DataLoader( 257 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 258 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 259 | 260 | val_transforms = transforms.Compose([ 261 | transforms.Resize(args.image_size, interpolation=PIL.Image.BICUBIC), 262 | transforms.CenterCrop(args.image_size), 263 | transforms.ToTensor(), 264 | normalize, 265 | ]) 266 | print('Using image size', args.image_size) 267 | 268 | val_loader = torch.utils.data.DataLoader( 269 | datasets.ImageFolder(valdir, val_transforms), 270 | batch_size=args.batch_size, shuffle=False, 271 | num_workers=args.workers, pin_memory=True) 272 | 273 | if args.evaluate: 274 | res = validate(val_loader, model, criterion, args, cs_sampling) 275 | with open('res.txt', 'w') as f: 276 | print(res, file=f) 277 | return 278 | 279 | for epoch in range(args.start_epoch, args.epochs): 280 | if args.distributed: 281 | train_sampler.set_epoch(epoch) 282 | adjust_learning_rate(optimizer, epoch, args) 283 | 284 | # train for one epoch 285 | train(train_loader, model, criterion, optimizer, epoch, args,writer_loss,writer_acc,cs_sampling) 286 | 287 | # evaluate on validation set 288 | acc1 = validate(val_loader, model, criterion, args, cs_sampling) 289 | 290 | # remember best acc@1 and save checkpoint 291 | is_best = acc1 > best_acc1 292 | best_acc1 = max(acc1, best_acc1) 293 | 294 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 295 | and args.rank % ngpus_per_node == 0): 296 | save_checkpoint({ 297 | 'epoch': epoch + 1, 298 | 'arch': args.arch, 299 | 'state_dict': model.state_dict(), 300 | 'best_acc1': best_acc1, 301 | 'optimizer' : optimizer.state_dict(), 302 | }, is_best,args.save_path+'/ckp_'+str(epoch)) 303 | if args.cs==1: 304 | save_checkpoint({ 305 | 'epoch': epoch + 1, 306 | 'arch': args.arch, 307 | 'state_dict': cs_sampling.state_dict(), 308 | 'best_acc1': best_acc1, 309 | 'optimizer' : optimizer.state_dict(), 310 | }, is_best,args.save_path+'/ckp_CS_'+str(epoch)) 311 | 312 | 313 | def train(train_loader, model, criterion, optimizer, epoch, args,writer_loss,writer_acc,cs_sampling): 314 | print('Start training.') 315 | batch_time = AverageMeter('Time', ':6.3f') 316 | data_time = AverageMeter('Data', ':6.3f') 317 | losses = AverageMeter('Loss', ':.4e') 318 | top1 = AverageMeter('Acc@1', ':6.2f') 319 | top5 = AverageMeter('Acc@5', ':6.2f') 320 | progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, 321 | top5, prefix="Epoch: [{}]".format(epoch)) 322 | 323 | # switch to train mode 324 | model.train() 325 | 326 | end = time.time() 327 | step = 0 328 | l_ = len(train_loader) 329 | for i, (images, target) in enumerate(train_loader): 330 | # print(images.shape) 331 | # measure data loading time 332 | data_time.update(time.time() - end) 333 | 334 | if args.gpu is not None: 335 | 
images = images.cuda()#(args.gpu, non_blocking=True) 336 | if cs_sampling: 337 | # print(cs_sampling) 338 | images = cs_sampling(images) 339 | # exit(0) 340 | target = target.cuda()#(args.gpu, non_blocking=True) 341 | 342 | # compute output 343 | output = model(images) 344 | loss = criterion(output, target) 345 | 346 | # measure accuracy and record loss 347 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 348 | losses.update(loss.item(), images.size(0)) 349 | top1.update(acc1[0], images.size(0)) 350 | top5.update(acc5[0], images.size(0)) 351 | # print('test:',acc1[0],acc5[0],losses) 352 | # exit(0) 353 | 354 | # compute gradient and do SGD step 355 | optimizer.zero_grad() 356 | loss.backward() 357 | optimizer.step() 358 | 359 | # measure elapsed time 360 | batch_time.update(time.time() - end) 361 | end = time.time() 362 | 363 | if i % args.print_freq == 0: 364 | writer_loss.add_scalar('Loss',loss.item(),step+(1000*128//args.batch_size)*epoch) 365 | writer_loss.add_scalar('ACC',acc1[0],step+(1000*128//args.batch_size)*epoch) 366 | writer_loss.add_scalar('ACC_avg',top1.avg,step+(1000*128//args.batch_size)*epoch) 367 | progress.print(i) 368 | step+=1 369 | 370 | 371 | def validate(val_loader, model, criterion, args,cs_sampling): 372 | batch_time = AverageMeter('Time', ':6.3f') 373 | losses = AverageMeter('Loss', ':.4e') 374 | top1 = AverageMeter('Acc@1', ':6.2f') 375 | top5 = AverageMeter('Acc@5', ':6.2f') 376 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 377 | prefix='Test: ') 378 | 379 | # switch to evaluate mode 380 | model.eval() 381 | 382 | with torch.no_grad(): 383 | end = time.time() 384 | for i, (images, target) in enumerate(val_loader): 385 | if args.gpu is not None: 386 | images = images.cuda()#(args.gpu, non_blocking=True) 387 | target = target.cuda()#(args.gpu,non_blocking=True) 388 | 389 | # compute output 390 | if cs_sampling: 391 | images = cs_sampling(images) 392 | output = model(images) 393 | loss = criterion(output, target) 394 | 395 | # measure accuracy and record loss 396 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 397 | losses.update(loss.item(), images.size(0)) 398 | top1.update(acc1[0], images.size(0)) 399 | top5.update(acc5[0], images.size(0)) 400 | 401 | # measure elapsed time 402 | batch_time.update(time.time() - end) 403 | end = time.time() 404 | 405 | if i % args.print_freq == 0: 406 | progress.print(i) 407 | 408 | # TODO: this should also be done with the ProgressMeter 409 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 410 | .format(top1=top1, top5=top5)) 411 | 412 | return top1.avg 413 | 414 | 415 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 416 | torch.save(state, filename) 417 | if is_best: 418 | shutil.copyfile(filename, 'model_best.pth.tar') 419 | 420 | 421 | class AverageMeter(object): 422 | """Computes and stores the average and current value""" 423 | def __init__(self, name, fmt=':f'): 424 | self.name = name 425 | self.fmt = fmt 426 | self.reset() 427 | 428 | def reset(self): 429 | self.val = 0 430 | self.avg = 0 431 | self.sum = 0 432 | self.count = 0 433 | 434 | def update(self, val, n=1): 435 | self.val = val 436 | self.sum += val * n 437 | self.count += n 438 | self.avg = self.sum / self.count 439 | 440 | def __str__(self): 441 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 442 | return fmtstr.format(**self.__dict__) 443 | 444 | 445 | class ProgressMeter(object): 446 | def __init__(self, num_batches, *meters, prefix=""): 447 | self.batch_fmtstr = 
self._get_batch_fmtstr(num_batches) 448 | self.meters = meters 449 | self.prefix = prefix 450 | 451 | def print(self, batch): 452 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 453 | entries += [str(meter) for meter in self.meters] 454 | print('\t'.join(entries)) 455 | 456 | def _get_batch_fmtstr(self, num_batches): 457 | num_digits = len(str(num_batches // 1)) 458 | fmt = '{:' + str(num_digits) + 'd}' 459 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 460 | 461 | 462 | def adjust_learning_rate(optimizer, epoch, args): 463 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 464 | lr = args.lr * (0.1 ** (epoch // 30)) 465 | for param_group in optimizer.param_groups: 466 | param_group['lr'] = lr 467 | 468 | 469 | def accuracy(output, target, topk=(1,)): 470 | """Computes the accuracy over the k top predictions for the specified values of k""" 471 | with torch.no_grad(): 472 | maxk = max(topk) 473 | batch_size = target.size(0) 474 | 475 | _, pred = output.topk(maxk, 1, True, True) 476 | pred = pred.t() 477 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 478 | 479 | res = [] 480 | for k in topk: 481 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 482 | res.append(correct_k.mul_(100.0 / batch_size)) 483 | return res 484 | 485 | 486 | if __name__ == '__main__': 487 | main() 488 | -------------------------------------------------------------------------------- /classification/main_cifar10.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import PIL 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | from Sampling import CS_Sampling 22 | from tensorboardX import SummaryWriter 23 | import torchvision 24 | from vit import ViT, load_pretrained_weights 25 | # from vit_my import ViT, load_pretrained_weights 26 | 27 | # from pytorch_pretrained_vit import ViT, load_pretrained_weights 28 | 29 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 30 | parser.add_argument('--data', default='/gdata/ImageNet2012', 31 | help='path to dataset') 32 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 33 | help='model architecture (default: resnet18)') 34 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 35 | help='number of data loading workers (default: 4)') 36 | parser.add_argument('--epochs', default=1000, type=int, metavar='N', 37 | help='number of total epochs to run') 38 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 39 | help='manual epoch number (useful on restarts)') 40 | parser.add_argument('-b', '--batch-size', default=256, type=int, 41 | metavar='N', 42 | help='mini-batch size (default: 256), this is the total ' 43 | 'batch size of all GPUs on the current node when ' 44 | 'using Data Parallel or Distributed Data Parallel') 45 | # parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 46 | # metavar='LR', help='initial learning rate', dest='lr') 47 | parser.add_argument('--lr', 
'--learning-rate', default=1e-4, type=float, 48 | metavar='LR', help='initial learning rate', dest='lr') 49 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 50 | help='momentum') 51 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 52 | metavar='W', help='weight decay (default: 1e-4)', 53 | dest='weight_decay') 54 | parser.add_argument('-p', '--print-freq', default=40, type=int, 55 | metavar='N', help='print frequency (default: 10)') 56 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 57 | help='path to latest checkpoint (default: none)') 58 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 59 | help='evaluate model on validation set') 60 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 61 | help='use pre-trained model') 62 | parser.add_argument('--world-size', default=-1, type=int, 63 | help='number of nodes for distributed training') 64 | parser.add_argument('--rank', default=-1, type=int, 65 | help='node rank for distributed training') 66 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 67 | help='url used to set up distributed training') 68 | parser.add_argument('--dist-backend', default='nccl', type=str, 69 | help='distributed backend') 70 | parser.add_argument('--seed', default=None, type=int, 71 | help='seed for initializing training. ') 72 | parser.add_argument('--gpu', default=None, type=int, 73 | help='GPU id to use.') 74 | parser.add_argument('--image_size', default=224, type=int, 75 | help='image size') 76 | parser.add_argument('--vit', default=True, help='use ViT model') 77 | parser.add_argument('--multiprocessing-distributed', action='store_true', 78 | help='Use multi-processing distributed training to launch ' 79 | 'N processes per node, which has N GPUs. This is the ' 80 | 'fastest way to use PyTorch for either single node or ' 81 | 'multi node data parallel training') 82 | parser.add_argument('--cs', type=int, default=0) 83 | parser.add_argument('--log_dir', default='logs') 84 | parser.add_argument('--mm', type=int, default=0) 85 | parser.add_argument('--save_path', type=str, default='ckp') 86 | parser.add_argument('--rat', type=float, default=0.1) 87 | parser.add_argument('--devices', type=int, default=4) 88 | parser.add_argument('--psize', type=int, default=32) 89 | parser.add_argument('--weights_path', type=str, default='/group/30042/chongmou/ft_local/TransCL/TransCL/classification/ft_local/pretrain_cifar/ckp_cf10_pre/ckp_89') 90 | parser.add_argument('--cs_mode', type=str, default='mcl') 91 | best_acc1 = 0 92 | 93 | 94 | def main(): 95 | args = parser.parse_args() 96 | 97 | if args.seed is not None: 98 | random.seed(args.seed) 99 | torch.manual_seed(args.seed) 100 | cudnn.deterministic = True 101 | warnings.warn('You have chosen to seed training. ' 102 | 'This will turn on the CUDNN deterministic setting, ' 103 | 'which can slow down your training considerably! ' 104 | 'You may see unexpected behavior when restarting ' 105 | 'from checkpoints.') 106 | 107 | if args.gpu is not None: 108 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 109 | 'disable data parallelism.') 110 | 111 | if args.dist_url == "env://" and args.world_size == -1: 112 | args.world_size = int(os.environ["WORLD_SIZE"]) 113 | 114 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 115 | 116 | ngpus_per_node = torch.cuda.device_count() 117 | if args.multiprocessing_distributed: 118 | # Since we have ngpus_per_node processes per node, the total world_size 119 | # needs to be adjusted accordingly 120 | args.world_size = ngpus_per_node * args.world_size 121 | # Use torch.multiprocessing.spawn to launch distributed processes: the 122 | # main_worker process function 123 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 124 | else: 125 | # Simply call main_worker function 126 | main_worker(args.gpu, ngpus_per_node, args) 127 | 128 | 129 | def main_worker(gpu, ngpus_per_node, args): 130 | # print(args.cs,args.lr) 131 | # exit(0) 132 | global best_acc1 133 | args.gpu = gpu 134 | patch_size = args.image_size // 16 135 | writer_loss = SummaryWriter(args.log_dir) 136 | writer_acc = SummaryWriter(args.log_dir) 137 | if os.path.isdir(args.save_path) == False: 138 | os.mkdir(args.save_path) 139 | # print('patch_size:', patch_size) 140 | 141 | if args.gpu is not None: 142 | print("Use GPU: {} for training".format(args.gpu)) 143 | 144 | if args.distributed: 145 | if args.dist_url == "env://" and args.rank == -1: 146 | args.rank = int(os.environ["RANK"]) 147 | if args.multiprocessing_distributed: 148 | # For multiprocessing distributed training, rank needs to be the 149 | # global rank among all the processes 150 | args.rank = args.rank * ngpus_per_node + gpu 151 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 152 | world_size=args.world_size, rank=args.rank) 153 | 154 | # NEW 155 | if args.vit: 156 | model = ViT(args.arch, pretrained=True, image_size=(args.image_size, args.image_size), num_classes=10, 157 | weights_path=args.weights_path) 158 | # model = ViT(args.arch, pretrained=False, image_size=(args.image_size, args.image_size), num_classes=10) 159 | # st = torch.load(args.weights_path) 160 | # model.load_state_dict(st['state_dict']) 161 | 162 | else: 163 | model = models.__dict__[args.arch](pretrained=args.pretrained) 164 | 165 | if args.cs == 1: 166 | print('Now Use CS.') 167 | print('CS ratio=', args.rat) 168 | cs_sampling = CS_Sampling(n_channels=3, cs_ratio=args.rat, blocksize=32, im_size=args.image_size).cuda() 169 | cs_sampling = torch.nn.DataParallel(cs_sampling, range(args.devices)) 170 | else: 171 | cs_sampling = None 172 | print("=> using model '{}' (pretrained={})".format(args.arch, args.pretrained)) 173 | 174 | if args.distributed: 175 | # For multiprocessing distributed, DistributedDataParallel constructor 176 | # should always set the single device scope, otherwise, 177 | # DistributedDataParallel will use all available devices. 
178 | if args.gpu is not None: 179 | torch.cuda.set_device(args.gpu) 180 | model.cuda(args.gpu) 181 | # When using a single GPU per process and per 182 | # DistributedDataParallel, we need to divide the batch size 183 | # ourselves based on the total number of GPUs we have 184 | args.batch_size = int(args.batch_size / ngpus_per_node) 185 | args.workers = int(args.workers / ngpus_per_node) 186 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.devices]) # [args.gpu]) 187 | else: 188 | model.cuda() 189 | # DistributedDataParallel will divide and allocate batch_size to all 190 | # available GPUs if device_ids are not set 191 | model = torch.nn.parallel.DistributedDataParallel(model) 192 | elif args.gpu is not None: 193 | print('No distribut') 194 | # torch.cuda.set_device(args.gpu) 195 | model.cuda() 196 | model = torch.nn.DataParallel(model, range(args.devices)) 197 | # model = model.cuda(args.gpu) 198 | else: 199 | # DataParallel will divide and allocate batch_size to all available GPUs 200 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 201 | model.features = torch.nn.DataParallel(model.features) 202 | model.cuda() 203 | else: 204 | model = torch.nn.DataParallel(model).cuda() 205 | 206 | # define loss function (criterion) and optimizer 207 | criterion = nn.CrossEntropyLoss().cuda() # (args.gpu) 208 | if args.cs == 1: 209 | if args.mm == 1: 210 | print('MM!') 211 | # optimizer = torch.optim.Adam( 212 | # [{'params': model.parameters()}, 213 | # {'params': cs_sampling.parameters()}] 214 | # , args.lr, 215 | # weight_decay=args.weight_decay, 216 | # betas=(0.9, 0.999) 217 | # ) 218 | optimizer = torch.optim.SGD( 219 | [{'params': model.parameters()}, 220 | {'params': cs_sampling.parameters()}] 221 | , args.lr, 222 | momentum=args.momentum, 223 | weight_decay=args.weight_decay) 224 | print('scheduler') 225 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5, 226 | last_epoch=-1) # MultiStepLR(optimizer, milestones=[80,120], gamma=0.1) 227 | else: 228 | optimizer = torch.optim.SGD(cs_sampling.parameters(), args.lr, 229 | momentum=args.momentum, 230 | weight_decay=args.weight_decay) 231 | else: 232 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 233 | momentum=args.momentum, 234 | weight_decay=args.weight_decay) 235 | 236 | # optionally resume from a checkpoint 237 | if args.resume: 238 | if os.path.isfile(args.resume): 239 | print("=> loading checkpoint '{}'".format(args.resume)) 240 | checkpoint = torch.load(args.resume) 241 | args.start_epoch = checkpoint['epoch'] 242 | best_acc1 = checkpoint['best_acc1'] 243 | if args.gpu is not None: 244 | # best_acc1 may be from a checkpoint from a different GPU 245 | best_acc1 = best_acc1.cuda() # to(args.gpu) 246 | model.load_state_dict(checkpoint['state_dict']) 247 | optimizer.load_state_dict(checkpoint['optimizer']) 248 | print("=> loaded checkpoint '{}' (epoch {})" 249 | .format(args.resume, checkpoint['epoch'])) 250 | else: 251 | print("=> no checkpoint found at '{}'".format(args.resume)) 252 | 253 | cudnn.benchmark = True 254 | 255 | # Data loading code 256 | # traindir = os.path.join(args.data, 'train') 257 | # valdir = os.path.join(args.data, 'val') 258 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 259 | normalize = transforms.Normalize(0.5, 0.5) 260 | 261 | transform = transforms.Compose([ 262 | transforms.RandomResizedCrop(32), 263 | transforms.RandomHorizontalFlip(), 264 | transforms.ToTensor(), 265 | normalize, 
266 | ]) 267 | 268 | val_transforms = transforms.Compose([ 269 | transforms.Resize(32, interpolation=PIL.Image.BICUBIC), 270 | transforms.CenterCrop(32), 271 | transforms.ToTensor(), 272 | normalize, 273 | ]) 274 | 275 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 276 | download=True, transform=transform) 277 | 278 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, 279 | download=True, transform=val_transforms) 280 | 281 | # train_dataset = datasets.ImageFolder( 282 | # traindir, 283 | # transforms.Compose([ 284 | # transforms.RandomResizedCrop(args.image_size), 285 | # transforms.RandomHorizontalFlip(), 286 | # transforms.ToTensor(), 287 | # normalize, 288 | # ])) 289 | 290 | if args.distributed: 291 | train_sampler = torch.utils.data.distributed.DistributedSampler(trainset) 292 | else: 293 | train_sampler = None 294 | 295 | train_loader = torch.utils.data.DataLoader( 296 | trainset, batch_size=args.batch_size, shuffle=(train_sampler is None), 297 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 298 | 299 | # val_transforms = transforms.Compose([ 300 | # transforms.Resize(args.image_size, interpolation=PIL.Image.BICUBIC), 301 | # transforms.CenterCrop(args.image_size), 302 | # transforms.ToTensor(), 303 | # normalize, 304 | # ]) 305 | print('Using image size', args.image_size) 306 | 307 | val_loader = torch.utils.data.DataLoader( 308 | testset, 309 | batch_size=args.batch_size, shuffle=False, 310 | num_workers=args.workers, pin_memory=True) 311 | 312 | if args.evaluate: 313 | res = validate(val_loader, model, criterion, args, cs_sampling) 314 | with open('res.txt', 'w') as f: 315 | print(res, file=f) 316 | return 317 | 318 | for epoch in range(args.start_epoch, args.epochs): 319 | scheduler.step() 320 | print(epoch, 'Learning rate: ', optimizer.param_groups[0]['lr']) 321 | if args.distributed: 322 | train_sampler.set_epoch(epoch) 323 | # adjust_learning_rate(optimizer, epoch, args) 324 | 325 | # train for one epoch 326 | train(train_loader, model, criterion, optimizer, epoch, args, writer_loss, writer_acc, cs_sampling) 327 | 328 | # evaluate on validation set 329 | acc1 = validate(val_loader, model, criterion, args, cs_sampling) 330 | 331 | # remember best acc@1 and save checkpoint 332 | is_best = acc1 > best_acc1 333 | best_acc1 = max(acc1, best_acc1) 334 | 335 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 336 | and args.rank % ngpus_per_node == 0): 337 | save_checkpoint({ 338 | 'epoch': epoch + 1, 339 | 'arch': args.arch, 340 | 'state_dict': model.state_dict(), 341 | 'best_acc1': best_acc1, 342 | 'optimizer': optimizer.state_dict(), 343 | }, is_best, args.save_path + '/ckp_' + str(epoch)) 344 | if args.cs == 1: 345 | save_checkpoint({ 346 | 'epoch': epoch + 1, 347 | 'arch': args.arch, 348 | 'state_dict': cs_sampling.state_dict(), 349 | 'best_acc1': best_acc1, 350 | 'optimizer': optimizer.state_dict(), 351 | }, is_best, args.save_path + '/ckp_CS_' + str(epoch)) 352 | 353 | 354 | def train(train_loader, model, criterion, optimizer, epoch, args, writer_loss, writer_acc, cs_sampling): 355 | print('Start training.') 356 | batch_time = AverageMeter('Time', ':6.3f') 357 | data_time = AverageMeter('Data', ':6.3f') 358 | losses = AverageMeter('Loss', ':.4e') 359 | top1 = AverageMeter('Acc@1', ':6.2f') 360 | top5 = AverageMeter('Acc@5', ':6.2f') 361 | progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, 362 | top5, prefix="Epoch: [{}]".format(epoch)) 363 | 364 | # switch to 
train mode 365 | model.train() 366 | 367 | end = time.time() 368 | step = 0 369 | l_ = len(train_loader) 370 | for i, (images, target) in enumerate(train_loader): 371 | # print(images.shape) 372 | # measure data loading time 373 | data_time.update(time.time() - end) 374 | 375 | if args.gpu is not None: 376 | images = images.cuda() # (args.gpu, non_blocking=True) 377 | if cs_sampling: 378 | # print(cs_sampling) 379 | images = cs_sampling(images) 380 | # exit(0) 381 | target = target.cuda() # (args.gpu, non_blocking=True) 382 | 383 | # compute output 384 | output = model(images) 385 | loss = criterion(output, target) 386 | 387 | # measure accuracy and record loss 388 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 389 | losses.update(loss.item(), images.size(0)) 390 | top1.update(acc1[0], images.size(0)) 391 | top5.update(acc5[0], images.size(0)) 392 | # print('test:',acc1[0],acc5[0],losses) 393 | # exit(0) 394 | 395 | # compute gradient and do SGD step 396 | optimizer.zero_grad() 397 | loss.backward() 398 | optimizer.step() 399 | 400 | # measure elapsed time 401 | batch_time.update(time.time() - end) 402 | end = time.time() 403 | 404 | if i % args.print_freq == 0: 405 | writer_loss.add_scalar('Loss', loss.item(), step + (1000 * 128 // args.batch_size) * epoch) 406 | writer_loss.add_scalar('ACC', acc1[0], step + (1000 * 128 // args.batch_size) * epoch) 407 | writer_loss.add_scalar('ACC_avg', top1.avg, step + (1000 * 128 // args.batch_size) * epoch) 408 | progress.print(i) 409 | step += 1 410 | 411 | 412 | def validate(val_loader, model, criterion, args, cs_sampling): 413 | batch_time = AverageMeter('Time', ':6.3f') 414 | losses = AverageMeter('Loss', ':.4e') 415 | top1 = AverageMeter('Acc@1', ':6.2f') 416 | top5 = AverageMeter('Acc@5', ':6.2f') 417 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 418 | prefix='Test: ') 419 | 420 | # switch to evaluate mode 421 | model.eval() 422 | 423 | with torch.no_grad(): 424 | end = time.time() 425 | for i, (images, target) in enumerate(val_loader): 426 | if args.gpu is not None: 427 | images = images.cuda() # (args.gpu, non_blocking=True) 428 | target = target.cuda() # (args.gpu,non_blocking=True) 429 | 430 | # compute output 431 | if cs_sampling: 432 | images = cs_sampling(images) 433 | output = model(images) 434 | loss = criterion(output, target) 435 | 436 | # measure accuracy and record loss 437 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 438 | losses.update(loss.item(), images.size(0)) 439 | top1.update(acc1[0], images.size(0)) 440 | top5.update(acc5[0], images.size(0)) 441 | 442 | # measure elapsed time 443 | batch_time.update(time.time() - end) 444 | end = time.time() 445 | 446 | if i % args.print_freq == 0: 447 | progress.print(i) 448 | 449 | # TODO: this should also be done with the ProgressMeter 450 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 451 | .format(top1=top1, top5=top5)) 452 | 453 | return top1.avg 454 | 455 | 456 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 457 | torch.save(state, filename) 458 | if is_best: 459 | shutil.copyfile(filename, 'model_best.pth.tar') 460 | 461 | 462 | class AverageMeter(object): 463 | """Computes and stores the average and current value""" 464 | 465 | def __init__(self, name, fmt=':f'): 466 | self.name = name 467 | self.fmt = fmt 468 | self.reset() 469 | 470 | def reset(self): 471 | self.val = 0 472 | self.avg = 0 473 | self.sum = 0 474 | self.count = 0 475 | 476 | def update(self, val, n=1): 477 | self.val = val 478 | self.sum += 
val * n 479 | self.count += n 480 | self.avg = self.sum / self.count 481 | 482 | def __str__(self): 483 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 484 | return fmtstr.format(**self.__dict__) 485 | 486 | 487 | class ProgressMeter(object): 488 | def __init__(self, num_batches, *meters, prefix=""): 489 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 490 | self.meters = meters 491 | self.prefix = prefix 492 | 493 | def print(self, batch): 494 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 495 | entries += [str(meter) for meter in self.meters] 496 | print('\t'.join(entries)) 497 | 498 | def _get_batch_fmtstr(self, num_batches): 499 | num_digits = len(str(num_batches // 1)) 500 | fmt = '{:' + str(num_digits) + 'd}' 501 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 502 | 503 | 504 | def adjust_learning_rate(optimizer, epoch, args): 505 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 506 | lr = args.lr * (0.1 ** (epoch // 30)) 507 | for param_group in optimizer.param_groups: 508 | param_group['lr'] = lr 509 | 510 | 511 | def accuracy(output, target, topk=(1,)): 512 | """Computes the accuracy over the k top predictions for the specified values of k""" 513 | with torch.no_grad(): 514 | maxk = max(topk) 515 | batch_size = target.size(0) 516 | 517 | _, pred = output.topk(maxk, 1, True, True) 518 | pred = pred.t() 519 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 520 | 521 | res = [] 522 | for k in topk: 523 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 524 | res.append(correct_k.mul_(100.0 / batch_size)) 525 | return res 526 | 527 | 528 | if __name__ == '__main__': 529 | main() -------------------------------------------------------------------------------- /classification/main_cifar100.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import PIL 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | from Sampling import CS_Sampling 22 | from tensorboardX import SummaryWriter 23 | import torchvision 24 | from vit import ViT, load_pretrained_weights 25 | # from pytorch_pretrained_vit import ViT, load_pretrained_weights 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 28 | parser.add_argument('--data', default='/gdata/ImageNet2012', 29 | help='path to dataset') 30 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 31 | help='model architecture (default: resnet18)') 32 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 33 | help='number of data loading workers (default: 4)') 34 | parser.add_argument('--epochs', default=1000, type=int, metavar='N', 35 | help='number of total epochs to run') 36 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 37 | help='manual epoch number (useful on restarts)') 38 | parser.add_argument('-b', '--batch-size', default=256, type=int, 39 | metavar='N', 40 | help='mini-batch size (default: 256), this is the total ' 41 | 'batch size of 
all GPUs on the current node when ' 42 | 'using Data Parallel or Distributed Data Parallel') 43 | # parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 44 | # metavar='LR', help='initial learning rate', dest='lr') 45 | parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, 46 | metavar='LR', help='initial learning rate', dest='lr') 47 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 48 | help='momentum') 49 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 50 | metavar='W', help='weight decay (default: 1e-4)', 51 | dest='weight_decay') 52 | parser.add_argument('-p', '--print-freq', default=40, type=int, 53 | metavar='N', help='print frequency (default: 10)') 54 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 55 | help='path to latest checkpoint (default: none)') 56 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 57 | help='evaluate model on validation set') 58 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 59 | help='use pre-trained model') 60 | parser.add_argument('--world-size', default=-1, type=int, 61 | help='number of nodes for distributed training') 62 | parser.add_argument('--rank', default=-1, type=int, 63 | help='node rank for distributed training') 64 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 65 | help='url used to set up distributed training') 66 | parser.add_argument('--dist-backend', default='nccl', type=str, 67 | help='distributed backend') 68 | parser.add_argument('--seed', default=None, type=int, 69 | help='seed for initializing training. ') 70 | parser.add_argument('--gpu', default=None, type=int, 71 | help='GPU id to use.') 72 | parser.add_argument('--image_size', default=224, type=int, 73 | help='image size') 74 | parser.add_argument('--vit', default=True, help='use ViT model') 75 | parser.add_argument('--multiprocessing-distributed', action='store_true', 76 | help='Use multi-processing distributed training to launch ' 77 | 'N processes per node, which has N GPUs. This is the ' 78 | 'fastest way to use PyTorch for either single node or ' 79 | 'multi node data parallel training') 80 | parser.add_argument('--cs', type=int, default=0) 81 | parser.add_argument('--log_dir', default='logs') 82 | parser.add_argument('--mm', type=int, default=0) 83 | parser.add_argument('--save_path', type=str, default='ckp') 84 | parser.add_argument('--rat', type=float, default=0.1) 85 | parser.add_argument('--devices', type=int, default=4) 86 | parser.add_argument('--psize', type=int, default=32) 87 | parser.add_argument('--weights_path',type=str,default='ckp_vit_cifar100_p32/ckp_89') 88 | best_acc1 = 0 89 | 90 | 91 | def main(): 92 | args = parser.parse_args() 93 | 94 | if args.seed is not None: 95 | random.seed(args.seed) 96 | torch.manual_seed(args.seed) 97 | cudnn.deterministic = True 98 | warnings.warn('You have chosen to seed training. ' 99 | 'This will turn on the CUDNN deterministic setting, ' 100 | 'which can slow down your training considerably! ' 101 | 'You may see unexpected behavior when restarting ' 102 | 'from checkpoints.') 103 | 104 | if args.gpu is not None: 105 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 106 | 'disable data parallelism.') 107 | 108 | if args.dist_url == "env://" and args.world_size == -1: 109 | args.world_size = int(os.environ["WORLD_SIZE"]) 110 | 111 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 112 | 113 | ngpus_per_node = torch.cuda.device_count() 114 | if args.multiprocessing_distributed: 115 | # Since we have ngpus_per_node processes per node, the total world_size 116 | # needs to be adjusted accordingly 117 | args.world_size = ngpus_per_node * args.world_size 118 | # Use torch.multiprocessing.spawn to launch distributed processes: the 119 | # main_worker process function 120 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 121 | else: 122 | # Simply call main_worker function 123 | main_worker(args.gpu, ngpus_per_node, args) 124 | 125 | 126 | def main_worker(gpu, ngpus_per_node, args): 127 | # print(args.cs,args.lr) 128 | # exit(0) 129 | global best_acc1 130 | args.gpu = gpu 131 | patch_size = args.image_size // 16 132 | writer_loss = SummaryWriter(args.log_dir) 133 | writer_acc = SummaryWriter(args.log_dir) 134 | if os.path.isdir(args.save_path) == False: 135 | os.mkdir(args.save_path) 136 | # print('patch_size:', patch_size) 137 | 138 | if args.gpu is not None: 139 | print("Use GPU: {} for training".format(args.gpu)) 140 | 141 | if args.distributed: 142 | if args.dist_url == "env://" and args.rank == -1: 143 | args.rank = int(os.environ["RANK"]) 144 | if args.multiprocessing_distributed: 145 | # For multiprocessing distributed training, rank needs to be the 146 | # global rank among all the processes 147 | args.rank = args.rank * ngpus_per_node + gpu 148 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 149 | world_size=args.world_size, rank=args.rank) 150 | 151 | # NEW 152 | if args.vit: 153 | model = ViT(args.arch, pretrained=True, image_size=(args.image_size, args.image_size),num_classes=100,weights_path=args.weights_path) 154 | else: 155 | model = models.__dict__[args.arch](pretrained=args.pretrained) 156 | 157 | if args.cs == 1: 158 | print('Now Use CS.') 159 | print('CS ratio=', args.rat) 160 | cs_sampling = CS_Sampling(n_channels=3, cs_ratio=args.rat, blocksize=32,im_size=args.image_size).cuda()#args.psize).cuda() 161 | cs_sampling = torch.nn.DataParallel(cs_sampling, range(args.devices)) 162 | else: 163 | cs_sampling = None 164 | print("=> using model '{}' (pretrained={})".format(args.arch, args.pretrained)) 165 | 166 | if args.distributed: 167 | # For multiprocessing distributed, DistributedDataParallel constructor 168 | # should always set the single device scope, otherwise, 169 | # DistributedDataParallel will use all available devices. 
170 | if args.gpu is not None: 171 | torch.cuda.set_device(args.gpu) 172 | model.cuda(args.gpu) 173 | # When using a single GPU per process and per 174 | # DistributedDataParallel, we need to divide the batch size 175 | # ourselves based on the total number of GPUs we have 176 | args.batch_size = int(args.batch_size / ngpus_per_node) 177 | args.workers = int(args.workers / ngpus_per_node) 178 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.devices]) # [args.gpu]) 179 | else: 180 | model.cuda() 181 | # DistributedDataParallel will divide and allocate batch_size to all 182 | # available GPUs if device_ids are not set 183 | model = torch.nn.parallel.DistributedDataParallel(model) 184 | elif args.gpu is not None: 185 | print('No distribut') 186 | # torch.cuda.set_device(args.gpu) 187 | model.cuda() 188 | model = torch.nn.DataParallel(model, range(args.devices)) 189 | # model = model.cuda(args.gpu) 190 | else: 191 | # DataParallel will divide and allocate batch_size to all available GPUs 192 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 193 | model.features = torch.nn.DataParallel(model.features) 194 | model.cuda() 195 | else: 196 | model = torch.nn.DataParallel(model).cuda() 197 | 198 | # define loss function (criterion) and optimizer 199 | criterion = nn.CrossEntropyLoss().cuda() # (args.gpu) 200 | if args.cs == 1: 201 | if args.mm == 1: 202 | print('MM!') 203 | optimizer = torch.optim.SGD( 204 | [{'params': model.parameters()}, 205 | {'params': cs_sampling.parameters()}] 206 | , args.lr, 207 | momentum=args.momentum, 208 | weight_decay=args.weight_decay) 209 | print('scheduler') 210 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5, last_epoch=-1) 211 | else: 212 | optimizer = torch.optim.SGD(cs_sampling.parameters(), args.lr, 213 | momentum=args.momentum, 214 | weight_decay=args.weight_decay) 215 | else: 216 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 217 | momentum=args.momentum, 218 | weight_decay=args.weight_decay) 219 | 220 | # optionally resume from a checkpoint 221 | if args.resume: 222 | if os.path.isfile(args.resume): 223 | print("=> loading checkpoint '{}'".format(args.resume)) 224 | checkpoint = torch.load(args.resume) 225 | args.start_epoch = checkpoint['epoch'] 226 | best_acc1 = checkpoint['best_acc1'] 227 | if args.gpu is not None: 228 | # best_acc1 may be from a checkpoint from a different GPU 229 | best_acc1 = best_acc1.cuda() # to(args.gpu) 230 | model.load_state_dict(checkpoint['state_dict']) 231 | optimizer.load_state_dict(checkpoint['optimizer']) 232 | print("=> loaded checkpoint '{}' (epoch {})" 233 | .format(args.resume, checkpoint['epoch'])) 234 | else: 235 | print("=> no checkpoint found at '{}'".format(args.resume)) 236 | 237 | cudnn.benchmark = True 238 | 239 | # Data loading code 240 | # traindir = os.path.join(args.data, 'train') 241 | # valdir = os.path.join(args.data, 'val') 242 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 243 | normalize = transforms.Normalize(0.5, 0.5) 244 | 245 | transform=transforms.Compose([ 246 | transforms.RandomResizedCrop(32), 247 | transforms.RandomHorizontalFlip(), 248 | transforms.ToTensor(), 249 | normalize, 250 | ]) 251 | 252 | val_transforms = transforms.Compose([ 253 | transforms.Resize(32, interpolation=PIL.Image.BICUBIC), 254 | transforms.CenterCrop(32), 255 | transforms.ToTensor(), 256 | normalize, 257 | ]) 258 | 259 | trainset = torchvision.datasets.CIFAR100(root='./data', 
train=True, 260 | download=True, transform=transform) 261 | 262 | 263 | testset = torchvision.datasets.CIFAR100(root='./data', train=False, 264 | download=True, transform=val_transforms) 265 | 266 | # train_dataset = datasets.ImageFolder( 267 | # traindir, 268 | # transforms.Compose([ 269 | # transforms.RandomResizedCrop(args.image_size), 270 | # transforms.RandomHorizontalFlip(), 271 | # transforms.ToTensor(), 272 | # normalize, 273 | # ])) 274 | 275 | if args.distributed: 276 | train_sampler = torch.utils.data.distributed.DistributedSampler(trainset) 277 | else: 278 | train_sampler = None 279 | 280 | train_loader = torch.utils.data.DataLoader( 281 | trainset, batch_size=args.batch_size, shuffle=(train_sampler is None), 282 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 283 | 284 | # val_transforms = transforms.Compose([ 285 | # transforms.Resize(args.image_size, interpolation=PIL.Image.BICUBIC), 286 | # transforms.CenterCrop(args.image_size), 287 | # transforms.ToTensor(), 288 | # normalize, 289 | # ]) 290 | print('Using image size', args.image_size) 291 | 292 | val_loader = torch.utils.data.DataLoader( 293 | testset, 294 | batch_size=args.batch_size, shuffle=False, 295 | num_workers=args.workers, pin_memory=True) 296 | 297 | if args.evaluate: 298 | res = validate(val_loader, model, criterion, args, cs_sampling) 299 | with open('res.txt', 'w') as f: 300 | print(res, file=f) 301 | return 302 | 303 | for epoch in range(args.start_epoch, args.epochs): 304 | scheduler.step() 305 | print(epoch,'Learning rate: ', optimizer.param_groups[0]['lr']) 306 | if args.distributed: 307 | train_sampler.set_epoch(epoch) 308 | # adjust_learning_rate(optimizer, epoch, args) 309 | 310 | # train for one epoch 311 | train(train_loader, model, criterion, optimizer, epoch, args, writer_loss, writer_acc, cs_sampling) 312 | 313 | # evaluate on validation set 314 | acc1 = validate(val_loader, model, criterion, args, cs_sampling) 315 | 316 | # remember best acc@1 and save checkpoint 317 | is_best = acc1 > best_acc1 318 | best_acc1 = max(acc1, best_acc1) 319 | 320 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 321 | and args.rank % ngpus_per_node == 0): 322 | save_checkpoint({ 323 | 'epoch': epoch + 1, 324 | 'arch': args.arch, 325 | 'state_dict': model.state_dict(), 326 | 'best_acc1': best_acc1, 327 | 'optimizer': optimizer.state_dict(), 328 | }, is_best, args.save_path + '/ckp_' + str(epoch)) 329 | if args.cs == 1: 330 | save_checkpoint({ 331 | 'epoch': epoch + 1, 332 | 'arch': args.arch, 333 | 'state_dict': cs_sampling.state_dict(), 334 | 'best_acc1': best_acc1, 335 | 'optimizer': optimizer.state_dict(), 336 | }, is_best, args.save_path + '/ckp_CS_' + str(epoch)) 337 | 338 | 339 | def train(train_loader, model, criterion, optimizer, epoch, args, writer_loss, writer_acc, cs_sampling): 340 | print('Start training.') 341 | batch_time = AverageMeter('Time', ':6.3f') 342 | data_time = AverageMeter('Data', ':6.3f') 343 | losses = AverageMeter('Loss', ':.4e') 344 | top1 = AverageMeter('Acc@1', ':6.2f') 345 | top5 = AverageMeter('Acc@5', ':6.2f') 346 | progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, 347 | top5, prefix="Epoch: [{}]".format(epoch)) 348 | 349 | # switch to train mode 350 | model.train() 351 | 352 | end = time.time() 353 | step = 0 354 | l_ = len(train_loader) 355 | for i, (images, target) in enumerate(train_loader): 356 | # print(images.shape) 357 | # measure data loading time 358 | data_time.update(time.time() - end) 359 | 
360 | if args.gpu is not None: 361 | images = images.cuda() # (args.gpu, non_blocking=True) 362 | if cs_sampling: 363 | # print(cs_sampling) 364 | images = cs_sampling(images) 365 | # exit(0) 366 | target = target.cuda() # (args.gpu, non_blocking=True) 367 | 368 | # compute output 369 | output = model(images) 370 | loss = criterion(output, target) 371 | 372 | # measure accuracy and record loss 373 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 374 | losses.update(loss.item(), images.size(0)) 375 | top1.update(acc1[0], images.size(0)) 376 | top5.update(acc5[0], images.size(0)) 377 | # print('test:',acc1[0],acc5[0],losses) 378 | # exit(0) 379 | 380 | # compute gradient and do SGD step 381 | optimizer.zero_grad() 382 | loss.backward() 383 | optimizer.step() 384 | 385 | # measure elapsed time 386 | batch_time.update(time.time() - end) 387 | end = time.time() 388 | 389 | if i % args.print_freq == 0: 390 | writer_loss.add_scalar('Loss', loss.item(), step + (1000 * 128 // args.batch_size) * epoch) 391 | writer_loss.add_scalar('ACC', acc1[0], step + (1000 * 128 // args.batch_size) * epoch) 392 | writer_loss.add_scalar('ACC_avg', top1.avg, step + (1000 * 128 // args.batch_size) * epoch) 393 | progress.print(i) 394 | step += 1 395 | 396 | 397 | def validate(val_loader, model, criterion, args, cs_sampling): 398 | batch_time = AverageMeter('Time', ':6.3f') 399 | losses = AverageMeter('Loss', ':.4e') 400 | top1 = AverageMeter('Acc@1', ':6.2f') 401 | top5 = AverageMeter('Acc@5', ':6.2f') 402 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 403 | prefix='Test: ') 404 | 405 | # switch to evaluate mode 406 | model.eval() 407 | 408 | with torch.no_grad(): 409 | end = time.time() 410 | for i, (images, target) in enumerate(val_loader): 411 | if args.gpu is not None: 412 | images = images.cuda() # (args.gpu, non_blocking=True) 413 | target = target.cuda() # (args.gpu,non_blocking=True) 414 | 415 | # compute output 416 | if cs_sampling: 417 | images = cs_sampling(images) 418 | output = model(images) 419 | loss = criterion(output, target) 420 | 421 | # measure accuracy and record loss 422 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 423 | losses.update(loss.item(), images.size(0)) 424 | top1.update(acc1[0], images.size(0)) 425 | top5.update(acc5[0], images.size(0)) 426 | 427 | # measure elapsed time 428 | batch_time.update(time.time() - end) 429 | end = time.time() 430 | 431 | if i % args.print_freq == 0: 432 | progress.print(i) 433 | 434 | # TODO: this should also be done with the ProgressMeter 435 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 436 | .format(top1=top1, top5=top5)) 437 | 438 | return top1.avg 439 | 440 | 441 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 442 | torch.save(state, filename) 443 | if is_best: 444 | shutil.copyfile(filename, 'model_best.pth.tar') 445 | 446 | 447 | class AverageMeter(object): 448 | """Computes and stores the average and current value""" 449 | 450 | def __init__(self, name, fmt=':f'): 451 | self.name = name 452 | self.fmt = fmt 453 | self.reset() 454 | 455 | def reset(self): 456 | self.val = 0 457 | self.avg = 0 458 | self.sum = 0 459 | self.count = 0 460 | 461 | def update(self, val, n=1): 462 | self.val = val 463 | self.sum += val * n 464 | self.count += n 465 | self.avg = self.sum / self.count 466 | 467 | def __str__(self): 468 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 469 | return fmtstr.format(**self.__dict__) 470 | 471 | 472 | class ProgressMeter(object): 473 | def 
__init__(self, num_batches, *meters, prefix=""): 474 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 475 | self.meters = meters 476 | self.prefix = prefix 477 | 478 | def print(self, batch): 479 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 480 | entries += [str(meter) for meter in self.meters] 481 | print('\t'.join(entries)) 482 | 483 | def _get_batch_fmtstr(self, num_batches): 484 | num_digits = len(str(num_batches // 1)) 485 | fmt = '{:' + str(num_digits) + 'd}' 486 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 487 | 488 | 489 | def adjust_learning_rate(optimizer, epoch, args): 490 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 491 | lr = args.lr * (0.1 ** (epoch // 30)) 492 | for param_group in optimizer.param_groups: 493 | param_group['lr'] = lr 494 | 495 | 496 | def accuracy(output, target, topk=(1,)): 497 | """Computes the accuracy over the k top predictions for the specified values of k""" 498 | with torch.no_grad(): 499 | maxk = max(topk) 500 | batch_size = target.size(0) 501 | 502 | _, pred = output.topk(maxk, 1, True, True) 503 | pred = pred.t() 504 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 505 | 506 | res = [] 507 | for k in topk: 508 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 509 | res.append(correct_k.mul_(100.0 / batch_size)) 510 | return res 511 | 512 | 513 | if __name__ == '__main__': 514 | main() 515 | -------------------------------------------------------------------------------- /classification/main_imagenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import PIL 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | from Sampling import CS_Sampling 22 | from tensorboardX import SummaryWriter 23 | import torchvision 24 | from vit import ViT, load_pretrained_weights 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 27 | parser.add_argument('--data', default='/gdata/ImageNet2012', 28 | help='path to dataset') 29 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 30 | help='model architecture (default: resnet18)') 31 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 32 | help='number of data loading workers (default: 4)') 33 | parser.add_argument('--epochs', default=1000, type=int, metavar='N', 34 | help='number of total epochs to run') 35 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 36 | help='manual epoch number (useful on restarts)') 37 | parser.add_argument('-b', '--batch-size', default=256, type=int, 38 | metavar='N', 39 | help='mini-batch size (default: 256), this is the total ' 40 | 'batch size of all GPUs on the current node when ' 41 | 'using Data Parallel or Distributed Data Parallel') 42 | # parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 43 | # metavar='LR', help='initial learning rate', dest='lr') 44 | parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, 45 | metavar='LR', 
help='initial learning rate', dest='lr') 46 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 47 | help='momentum') 48 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 49 | metavar='W', help='weight decay (default: 1e-4)', 50 | dest='weight_decay') 51 | parser.add_argument('-p', '--print-freq', default=40, type=int, 52 | metavar='N', help='print frequency (default: 10)') 53 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 54 | help='path to latest checkpoint (default: none)') 55 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 56 | help='evaluate model on validation set') 57 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 58 | help='use pre-trained model') 59 | parser.add_argument('--world-size', default=-1, type=int, 60 | help='number of nodes for distributed training') 61 | parser.add_argument('--rank', default=-1, type=int, 62 | help='node rank for distributed training') 63 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 64 | help='url used to set up distributed training') 65 | parser.add_argument('--dist-backend', default='nccl', type=str, 66 | help='distributed backend') 67 | parser.add_argument('--seed', default=None, type=int, 68 | help='seed for initializing training. ') 69 | parser.add_argument('--gpu', default=None, type=int, 70 | help='GPU id to use.') 71 | parser.add_argument('--image_size', default=224, type=int, 72 | help='image size') 73 | parser.add_argument('--vit', default=True, help='use ViT model') 74 | parser.add_argument('--multiprocessing-distributed', action='store_true', 75 | help='Use multi-processing distributed training to launch ' 76 | 'N processes per node, which has N GPUs. This is the ' 77 | 'fastest way to use PyTorch for either single node or ' 78 | 'multi node data parallel training') 79 | parser.add_argument('--cs', type=int, default=0) 80 | parser.add_argument('--log_dir', default='logs') 81 | parser.add_argument('--mm', type=int, default=0) 82 | parser.add_argument('--save_path', type=str, default='ckp') 83 | parser.add_argument('--rat', type=float, default=0.1) 84 | parser.add_argument('--devices', type=int, default=4) 85 | parser.add_argument('--psize', type=int, default=32) 86 | parser.add_argument('--weights_path', type=str, default='pretrain/B_32_imagenet1k.pth') 87 | parser.add_argument('--cs_mode', type=str, default='mcl') 88 | best_acc1 = 0 89 | 90 | 91 | def main(): 92 | args = parser.parse_args() 93 | 94 | if args.seed is not None: 95 | random.seed(args.seed) 96 | torch.manual_seed(args.seed) 97 | cudnn.deterministic = True 98 | warnings.warn('You have chosen to seed training. ' 99 | 'This will turn on the CUDNN deterministic setting, ' 100 | 'which can slow down your training considerably! ' 101 | 'You may see unexpected behavior when restarting ' 102 | 'from checkpoints.') 103 | 104 | if args.gpu is not None: 105 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 106 | 'disable data parallelism.') 107 | 108 | if args.dist_url == "env://" and args.world_size == -1: 109 | args.world_size = int(os.environ["WORLD_SIZE"]) 110 | 111 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 112 | 113 | ngpus_per_node = torch.cuda.device_count() 114 | if args.multiprocessing_distributed: 115 | # Since we have ngpus_per_node processes per node, the total world_size 116 | # needs to be adjusted accordingly 117 | args.world_size = ngpus_per_node * args.world_size 118 | # Use torch.multiprocessing.spawn to launch distributed processes: the 119 | # main_worker process function 120 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 121 | else: 122 | # Simply call main_worker function 123 | main_worker(args.gpu, ngpus_per_node, args) 124 | 125 | 126 | def main_worker(gpu, ngpus_per_node, args): 127 | # print(args.cs,args.lr) 128 | # exit(0) 129 | global best_acc1 130 | args.gpu = gpu 131 | patch_size = args.image_size // 16 132 | writer_loss = SummaryWriter(args.log_dir) 133 | writer_acc = SummaryWriter(args.log_dir) 134 | if os.path.isdir(args.save_path) == False: 135 | os.mkdir(args.save_path) 136 | # print('patch_size:', patch_size) 137 | 138 | if args.gpu is not None: 139 | print("Use GPU: {} for training".format(args.gpu)) 140 | 141 | if args.distributed: 142 | if args.dist_url == "env://" and args.rank == -1: 143 | args.rank = int(os.environ["RANK"]) 144 | if args.multiprocessing_distributed: 145 | # For multiprocessing distributed training, rank needs to be the 146 | # global rank among all the processes 147 | args.rank = args.rank * ngpus_per_node + gpu 148 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 149 | world_size=args.world_size, rank=args.rank) 150 | 151 | # NEW 152 | if args.vit: 153 | model = ViT(args.arch, pretrained=True, image_size=(args.image_size, args.image_size), num_classes=1000, 154 | weights_path=args.weights_path) 155 | else: 156 | model = models.__dict__[args.arch](pretrained=args.pretrained) 157 | 158 | if args.cs == 1: 159 | print('CS ratio=', args.rat) 160 | cs_sampling = CS_Sampling(n_channels=3, cs_ratio=args.rat, blocksize=32, im_size=args.image_size).cuda() 161 | cs_sampling = torch.nn.DataParallel(cs_sampling, range(args.devices)) 162 | else: 163 | cs_sampling = None 164 | print("=> using model '{}' (pretrained={})".format(args.arch, args.pretrained)) 165 | 166 | if args.distributed: 167 | # For multiprocessing distributed, DistributedDataParallel constructor 168 | # should always set the single device scope, otherwise, 169 | # DistributedDataParallel will use all available devices. 
170 | if args.gpu is not None: 171 | torch.cuda.set_device(args.gpu) 172 | model.cuda(args.gpu) 173 | # When using a single GPU per process and per 174 | # DistributedDataParallel, we need to divide the batch size 175 | # ourselves based on the total number of GPUs we have 176 | args.batch_size = int(args.batch_size / ngpus_per_node) 177 | args.workers = int(args.workers / ngpus_per_node) 178 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.devices]) # [args.gpu]) 179 | else: 180 | model.cuda() 181 | # DistributedDataParallel will divide and allocate batch_size to all 182 | # available GPUs if device_ids are not set 183 | model = torch.nn.parallel.DistributedDataParallel(model) 184 | elif args.gpu is not None: 185 | print('No distribut') 186 | # torch.cuda.set_device(args.gpu) 187 | model.cuda() 188 | model = torch.nn.DataParallel(model, range(args.devices)) 189 | # model = model.cuda(args.gpu) 190 | else: 191 | # DataParallel will divide and allocate batch_size to all available GPUs 192 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 193 | model.features = torch.nn.DataParallel(model.features) 194 | model.cuda() 195 | else: 196 | model = torch.nn.DataParallel(model).cuda() 197 | 198 | # define loss function (criterion) and optimizer 199 | criterion = nn.CrossEntropyLoss().cuda() # (args.gpu) 200 | if args.cs == 1: 201 | if args.mm == 1: 202 | print('MM!') 203 | optimizer = torch.optim.SGD( 204 | [{'params': model.parameters()}, 205 | {'params': cs_sampling.parameters()}] 206 | , args.lr, 207 | momentum=args.momentum, 208 | weight_decay=args.weight_decay) 209 | print('scheduler') 210 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5, 211 | last_epoch=-1) # MultiStepLR(optimizer, milestones=[80,120], gamma=0.1) 212 | else: 213 | optimizer = torch.optim.SGD(cs_sampling.parameters(), args.lr, 214 | momentum=args.momentum, 215 | weight_decay=args.weight_decay) 216 | else: 217 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 218 | momentum=args.momentum, 219 | weight_decay=args.weight_decay) 220 | 221 | # optionally resume from a checkpoint 222 | if args.resume: 223 | if os.path.isfile(args.resume): 224 | print("=> loading checkpoint '{}'".format(args.resume)) 225 | checkpoint = torch.load(args.resume) 226 | args.start_epoch = checkpoint['epoch'] 227 | best_acc1 = checkpoint['best_acc1'] 228 | if args.gpu is not None: 229 | # best_acc1 may be from a checkpoint from a different GPU 230 | best_acc1 = best_acc1.cuda() # to(args.gpu) 231 | model.load_state_dict(checkpoint['state_dict']) 232 | optimizer.load_state_dict(checkpoint['optimizer']) 233 | print("=> loaded checkpoint '{}' (epoch {})" 234 | .format(args.resume, checkpoint['epoch'])) 235 | else: 236 | print("=> no checkpoint found at '{}'".format(args.resume)) 237 | 238 | cudnn.benchmark = True 239 | 240 | # Data loading code 241 | traindir = os.path.join(args.data, 'train') 242 | valdir = os.path.join(args.data, 'val') 243 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 244 | normalize = transforms.Normalize(0.5, 0.5) 245 | 246 | train_dataset = datasets.ImageFolder( 247 | traindir, 248 | transforms.Compose([ 249 | transforms.RandomResizedCrop(args.image_size), 250 | transforms.RandomHorizontalFlip(), 251 | transforms.ToTensor(), 252 | normalize, 253 | ])) 254 | 255 | if args.distributed: 256 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 257 | else: 258 | train_sampler = 
None 259 | 260 | train_loader = torch.utils.data.DataLoader( 261 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 262 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 263 | 264 | val_transforms = transforms.Compose([ 265 | transforms.Resize(args.image_size, interpolation=PIL.Image.BICUBIC), 266 | transforms.CenterCrop(args.image_size), 267 | transforms.ToTensor(), 268 | normalize, 269 | ]) 270 | print('Using image size', args.image_size) 271 | 272 | val_loader = torch.utils.data.DataLoader( 273 | datasets.ImageFolder(valdir, val_transforms), 274 | batch_size=args.batch_size, shuffle=False, 275 | num_workers=args.workers, pin_memory=True) 276 | 277 | if args.evaluate: 278 | res = validate(val_loader, model, criterion, args, cs_sampling) 279 | with open('res.txt', 'w') as f: 280 | print(res, file=f) 281 | return 282 | 283 | for epoch in range(args.start_epoch, args.epochs): 284 | scheduler.step() 285 | print(epoch, 'Learning rate: ', optimizer.param_groups[0]['lr']) 286 | if args.distributed: 287 | train_sampler.set_epoch(epoch) 288 | # adjust_learning_rate(optimizer, epoch, args) 289 | 290 | # train for one epoch 291 | train(train_loader, model, criterion, optimizer, epoch, args, writer_loss, writer_acc, cs_sampling) 292 | 293 | # evaluate on validation set 294 | acc1 = validate(val_loader, model, criterion, args, cs_sampling) 295 | 296 | # remember best acc@1 and save checkpoint 297 | is_best = acc1 > best_acc1 298 | best_acc1 = max(acc1, best_acc1) 299 | 300 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 301 | and args.rank % ngpus_per_node == 0): 302 | save_checkpoint({ 303 | 'epoch': epoch + 1, 304 | 'arch': args.arch, 305 | 'state_dict': model.state_dict(), 306 | 'best_acc1': best_acc1, 307 | 'optimizer': optimizer.state_dict(), 308 | }, is_best, args.save_path + '/ckp_' + str(epoch)) 309 | if args.cs == 1: 310 | save_checkpoint({ 311 | 'epoch': epoch + 1, 312 | 'arch': args.arch, 313 | 'state_dict': cs_sampling.state_dict(), 314 | 'best_acc1': best_acc1, 315 | 'optimizer': optimizer.state_dict(), 316 | }, is_best, args.save_path + '/ckp_CS_' + str(epoch)) 317 | 318 | 319 | def train(train_loader, model, criterion, optimizer, epoch, args, writer_loss, writer_acc, cs_sampling): 320 | print('Start training.') 321 | batch_time = AverageMeter('Time', ':6.3f') 322 | data_time = AverageMeter('Data', ':6.3f') 323 | losses = AverageMeter('Loss', ':.4e') 324 | top1 = AverageMeter('Acc@1', ':6.2f') 325 | top5 = AverageMeter('Acc@5', ':6.2f') 326 | progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, 327 | top5, prefix="Epoch: [{}]".format(epoch)) 328 | 329 | # switch to train mode 330 | model.train() 331 | 332 | end = time.time() 333 | step = 0 334 | l_ = len(train_loader) 335 | for i, (images, target) in enumerate(train_loader): 336 | # print(images.shape) 337 | # measure data loading time 338 | data_time.update(time.time() - end) 339 | 340 | if args.gpu is not None: 341 | images = images.cuda() # (args.gpu, non_blocking=True) 342 | if cs_sampling: 343 | # print(cs_sampling) 344 | images = cs_sampling(images) 345 | # exit(0) 346 | target = target.cuda() # (args.gpu, non_blocking=True) 347 | 348 | # compute output 349 | output = model(images) 350 | loss = criterion(output, target) 351 | 352 | # measure accuracy and record loss 353 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 354 | losses.update(loss.item(), images.size(0)) 355 | top1.update(acc1[0], images.size(0)) 356 | 
top5.update(acc5[0], images.size(0)) 357 | # print('test:',acc1[0],acc5[0],losses) 358 | # exit(0) 359 | 360 | # compute gradient and do SGD step 361 | optimizer.zero_grad() 362 | loss.backward() 363 | optimizer.step() 364 | 365 | # measure elapsed time 366 | batch_time.update(time.time() - end) 367 | end = time.time() 368 | 369 | if i % args.print_freq == 0: 370 | writer_loss.add_scalar('Loss', loss.item(), step + (1000 * 128 // args.batch_size) * epoch) 371 | writer_loss.add_scalar('ACC', acc1[0], step + (1000 * 128 // args.batch_size) * epoch) 372 | writer_loss.add_scalar('ACC_avg', top1.avg, step + (1000 * 128 // args.batch_size) * epoch) 373 | progress.print(i) 374 | step += 1 375 | 376 | 377 | def validate(val_loader, model, criterion, args, cs_sampling): 378 | batch_time = AverageMeter('Time', ':6.3f') 379 | losses = AverageMeter('Loss', ':.4e') 380 | top1 = AverageMeter('Acc@1', ':6.2f') 381 | top5 = AverageMeter('Acc@5', ':6.2f') 382 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 383 | prefix='Test: ') 384 | 385 | # switch to evaluate mode 386 | model.eval() 387 | 388 | with torch.no_grad(): 389 | end = time.time() 390 | for i, (images, target) in enumerate(val_loader): 391 | if args.gpu is not None: 392 | images = images.cuda() # (args.gpu, non_blocking=True) 393 | target = target.cuda() # (args.gpu,non_blocking=True) 394 | 395 | # compute output 396 | if cs_sampling: 397 | images = cs_sampling(images) 398 | output = model(images) 399 | loss = criterion(output, target) 400 | 401 | # measure accuracy and record loss 402 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 403 | losses.update(loss.item(), images.size(0)) 404 | top1.update(acc1[0], images.size(0)) 405 | top5.update(acc5[0], images.size(0)) 406 | 407 | # measure elapsed time 408 | batch_time.update(time.time() - end) 409 | end = time.time() 410 | 411 | if i % args.print_freq == 0: 412 | progress.print(i) 413 | 414 | # TODO: this should also be done with the ProgressMeter 415 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 416 | .format(top1=top1, top5=top5)) 417 | 418 | return top1.avg 419 | 420 | 421 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 422 | torch.save(state, filename) 423 | if is_best: 424 | shutil.copyfile(filename, 'model_best.pth.tar') 425 | 426 | 427 | class AverageMeter(object): 428 | """Computes and stores the average and current value""" 429 | 430 | def __init__(self, name, fmt=':f'): 431 | self.name = name 432 | self.fmt = fmt 433 | self.reset() 434 | 435 | def reset(self): 436 | self.val = 0 437 | self.avg = 0 438 | self.sum = 0 439 | self.count = 0 440 | 441 | def update(self, val, n=1): 442 | self.val = val 443 | self.sum += val * n 444 | self.count += n 445 | self.avg = self.sum / self.count 446 | 447 | def __str__(self): 448 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 449 | return fmtstr.format(**self.__dict__) 450 | 451 | 452 | class ProgressMeter(object): 453 | def __init__(self, num_batches, *meters, prefix=""): 454 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 455 | self.meters = meters 456 | self.prefix = prefix 457 | 458 | def print(self, batch): 459 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 460 | entries += [str(meter) for meter in self.meters] 461 | print('\t'.join(entries)) 462 | 463 | def _get_batch_fmtstr(self, num_batches): 464 | num_digits = len(str(num_batches // 1)) 465 | fmt = '{:' + str(num_digits) + 'd}' 466 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 
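The `AverageMeter` and `ProgressMeter` helpers above are shared verbatim by all training scripts in this folder. A minimal usage sketch with made-up numbers (no GPU or dataset involved; it only assumes the two classes defined directly above):

```python
# Illustrative only: exercises AverageMeter/ProgressMeter with toy values.
losses = AverageMeter('Loss', ':.4e')      # keeps a batch-size-weighted running average
top1 = AverageMeter('Acc@1', ':6.2f')
progress = ProgressMeter(10, losses, top1, prefix="Epoch: [0]")

for step, (loss_val, acc_val, bs) in enumerate([(0.91, 55.0, 128), (0.72, 60.5, 128)]):
    losses.update(loss_val, bs)            # weighted by batch size, as in train()
    top1.update(acc_val, bs)
    progress.print(step)                   # prints e.g. "Epoch: [0][ 1/10]  Loss ...  Acc@1 ..."
```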
467 | 468 | 469 | def adjust_learning_rate(optimizer, epoch, args): 470 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 471 | lr = args.lr * (0.1 ** (epoch // 30)) 472 | for param_group in optimizer.param_groups: 473 | param_group['lr'] = lr 474 | 475 | 476 | def accuracy(output, target, topk=(1,)): 477 | """Computes the accuracy over the k top predictions for the specified values of k""" 478 | with torch.no_grad(): 479 | maxk = max(topk) 480 | batch_size = target.size(0) 481 | 482 | _, pred = output.topk(maxk, 1, True, True) 483 | pred = pred.t() 484 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 485 | 486 | res = [] 487 | for k in topk: 488 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 489 | res.append(correct_k.mul_(100.0 / batch_size)) 490 | return res 491 | 492 | 493 | if __name__ == '__main__': 494 | main() -------------------------------------------------------------------------------- /classification/test_arb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import PIL 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | from Sampling import CS_Sampling_arb as CS_Sampling 22 | 23 | from vit import ViT, load_pretrained_weights 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 26 | parser.add_argument('--data', default='/group/30042/public_datasets/imagenet1k', 27 | help='path to dataset') 28 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 29 | help='model architecture (default: resnet18)') 30 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 31 | help='number of data loading workers (default: 4)') 32 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 33 | help='number of total epochs to run') 34 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 35 | help='manual epoch number (useful on restarts)') 36 | parser.add_argument('-b', '--batch-size', default=256, type=int, 37 | metavar='N', 38 | help='mini-batch size (default: 256), this is the total ' 39 | 'batch size of all GPUs on the current node when ' 40 | 'using Data Parallel or Distributed Data Parallel') 41 | # parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 42 | # metavar='LR', help='initial learning rate', dest='lr') 43 | parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, 44 | metavar='LR', help='initial learning rate', dest='lr') 45 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 46 | help='momentum') 47 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 48 | metavar='W', help='weight decay (default: 1e-4)', 49 | dest='weight_decay') 50 | parser.add_argument('-p', '--print-freq', default=10, type=int, 51 | metavar='N', help='print frequency (default: 10)') 52 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 53 | help='path to latest checkpoint (default: none)') 54 | parser.add_argument('-e', 
'--evaluate', dest='evaluate', action='store_true', 55 | help='evaluate model on validation set') 56 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 57 | help='use pre-trained model') 58 | parser.add_argument('--world-size', default=-1, type=int, 59 | help='number of nodes for distributed training') 60 | parser.add_argument('--rank', default=-1, type=int, 61 | help='node rank for distributed training') 62 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 63 | help='url used to set up distributed training') 64 | parser.add_argument('--dist-backend', default='nccl', type=str, 65 | help='distributed backend') 66 | parser.add_argument('--seed', default=None, type=int, 67 | help='seed for initializing training. ') 68 | parser.add_argument('--gpu', default=None, type=int, 69 | help='GPU id to use.') 70 | parser.add_argument('--image_size', default=224, type=int, 71 | help='image size') 72 | parser.add_argument('--vit', default=True, help='use ViT model') 73 | parser.add_argument('--multiprocessing-distributed', action='store_true', 74 | help='Use multi-processing distributed training to launch ' 75 | 'N processes per node, which has N GPUs. This is the ' 76 | 'fastest way to use PyTorch for either single node or ' 77 | 'multi node data parallel training') 78 | parser.add_argument('--rat',type=float,default=0.1) 79 | parser.add_argument('--devices',type=int,default=4) 80 | parser.add_argument('--ckp_vit',type=str,default='/group/30042/chongmou/ft_local/TransCL/TransCL/ft_local/arb/ckp_85') 81 | parser.add_argument('--ckp_cs',type=str,default='/group/30042/chongmou/ft_local/TransCL/TransCL/ft_local/arb/ckp_CS_85') 82 | parser.add_argument('--psize',type=int,default=32) 83 | best_acc1 = 0 84 | 85 | 86 | def main(): 87 | args = parser.parse_args() 88 | 89 | if args.seed is not None: 90 | random.seed(args.seed) 91 | torch.manual_seed(args.seed) # Neural networks are randomly initialized by default; fixing the seed makes the initialization identical on every run 92 | cudnn.deterministic = True 93 | warnings.warn('You have chosen to seed training. ' 94 | 'This will turn on the CUDNN deterministic setting, ' 95 | 'which can slow down your training considerably! ' 96 | 'You may see unexpected behavior when restarting ' 97 | 'from checkpoints.') 98 | 99 | if args.gpu is not None: 100 | warnings.warn('You have chosen a specific GPU.
This will completely ' 101 | 'disable data parallelism.') 102 | 103 | if args.dist_url == "env://" and args.world_size == -1: 104 | args.world_size = int(os.environ["WORLD_SIZE"]) 105 | 106 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 107 | 108 | ngpus_per_node = torch.cuda.device_count() 109 | if args.multiprocessing_distributed: 110 | # Since we have ngpus_per_node processes per node, the total world_size 111 | # needs to be adjusted accordingly 112 | args.world_size = ngpus_per_node * args.world_size 113 | # Use torch.multiprocessing.spawn to launch distributed processes: the 114 | # main_worker process function 115 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 116 | else: 117 | # Simply call main_worker function 118 | main_worker(args.gpu, ngpus_per_node, args) 119 | 120 | 121 | def main_worker(gpu, ngpus_per_node, args): 122 | # print(args.cs,args.lr) 123 | # exit(0) 124 | global best_acc1 125 | args.gpu = gpu 126 | patch_size = args.image_size//16 127 | # print('patch_size:', patch_size) 128 | 129 | if args.gpu is not None: 130 | print("Use GPU: {} for testing".format(args.gpu)) 131 | 132 | if args.distributed: 133 | if args.dist_url == "env://" and args.rank == -1: 134 | args.rank = int(os.environ["RANK"]) 135 | if args.multiprocessing_distributed: 136 | # For multiprocessing distributed training, rank needs to be the 137 | # global rank among all the processes 138 | args.rank = args.rank * ngpus_per_node + gpu 139 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 140 | world_size=args.world_size, rank=args.rank) 141 | 142 | # NEW 143 | if args.vit: 144 | # print(args.image_size) 145 | model = ViT(args.arch, pretrained=False,image_size=(args.image_size,args.image_size)) 146 | 147 | # # NOTE: This is for debugging 148 | # model = ViT('B_16_imagenet1k', pretrained=False) 149 | # load_pretrained_weights(model, weights_path='/home/luke/projects/experiments/ViT-PyTorch/jax_to_pytorch/weights/B_16_imagenet1k.pth') 150 | 151 | else: 152 | model = models.__dict__[args.arch](pretrained=args.pretrained) 153 | 154 | list_rate_rm=[0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9] 155 | 156 | print('Create CS model.') 157 | print('CS ratio=', args.rat) 158 | cs_sampling = CS_Sampling(n_channels=3, cs_ratio=args.rat, blocksize=args.psize).cuda() 159 | cs_sampling = torch.nn.DataParallel(cs_sampling,range(args.devices)) 160 | cs_sampling.load_state_dict(torch.load(args.ckp_cs)['state_dict']) 161 | print('Load pretrained cs model from: ', args.ckp_cs) 162 | 163 | model.cuda() 164 | model = torch.nn.DataParallel(model,range(args.devices)) 165 | model.load_state_dict(torch.load(args.ckp_vit)['state_dict']) 166 | print('Load pretrained vit model from: ', args.ckp_vit) 167 | 168 | criterion = nn.CrossEntropyLoss().cuda() 169 | 170 | cudnn.benchmark = True 171 | 172 | # Data loading code 173 | valdir = os.path.join(args.data, 'val') 174 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 175 | normalize = transforms.Normalize(0.5, 0.5) 176 | 177 | print('Load test set.') 178 | val_transforms = transforms.Compose([ 179 | transforms.Resize(args.image_size, interpolation=PIL.Image.BICUBIC), 180 | transforms.CenterCrop(args.image_size), 181 | transforms.ToTensor(), 182 | normalize, 183 | ]) 184 | print('Using image size', args.image_size) 185 | 186 | val_loader = torch.utils.data.DataLoader( 187 | datasets.ImageFolder(valdir, val_transforms), 188 | batch_size=args.batch_size, 
shuffle=False, 189 | num_workers=args.workers, pin_memory=True) 190 | 191 | print('Start test.') 192 | for num_rows in range(1,1024): 193 | # print('rate_rm: ', rate_rm) 194 | # cs_sampling.rate_rm=rate_rm 195 | # cs_sampling.module.updata_rat(rate_rm) 196 | # print(cs_sampling.module.rate_rm) 197 | acc1 = validate(val_loader, model, criterion, args, cs_sampling,num_rows) 198 | 199 | 200 | def validate(val_loader, model, criterion, args,cs_sampling,num_rows): 201 | batch_time = AverageMeter('Time', ':6.3f') 202 | losses = AverageMeter('Loss', ':.4e') 203 | top1 = AverageMeter('Acc@1', ':6.2f') 204 | top5 = AverageMeter('Acc@5', ':6.2f') 205 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 206 | prefix='Test: ') 207 | 208 | # switch to evaluate mode 209 | model.eval() 210 | file = open('test_arb.txt','a') 211 | with torch.no_grad(): 212 | end = time.time() 213 | for i, (images, target) in enumerate(val_loader): 214 | # file.write(' * Acc@1'+'\n') 215 | if args.gpu is not None: 216 | images = images.cuda()#(args.gpu, non_blocking=True) 217 | target = target.cuda()#(args.gpu,non_blocking=True) 218 | 219 | # compute output 220 | # if cs_sampling: 221 | images = cs_sampling(images, num_rows) 222 | output = model(images) 223 | loss = criterion(output, target) 224 | 225 | # measure accuracy and record loss 226 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 227 | losses.update(loss.item(), images.size(0)) 228 | top1.update(acc1[0], images.size(0)) 229 | top5.update(acc5[0], images.size(0)) 230 | 231 | # measure elapsed time 232 | batch_time.update(time.time() - end) 233 | end = time.time() 234 | 235 | if i % args.print_freq == 0: 236 | progress.print(i) 237 | 238 | # TODO: this should also be done with the ProgressMeter 239 | print(str(num_rows)+': * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 240 | .format(top1=top1, top5=top5)) 241 | file.write(str(num_rows)+': * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 242 | .format(top1=top1, top5=top5)+'\n') 243 | file.close() 244 | 245 | return top1.avg 246 | 247 | 248 | class AverageMeter(object): 249 | """Computes and stores the average and current value""" 250 | def __init__(self, name, fmt=':f'): 251 | self.name = name 252 | self.fmt = fmt 253 | self.reset() 254 | 255 | def reset(self): 256 | self.val = 0 257 | self.avg = 0 258 | self.sum = 0 259 | self.count = 0 260 | 261 | def update(self, val, n=1): 262 | self.val = val 263 | self.sum += val * n 264 | self.count += n 265 | self.avg = self.sum / self.count 266 | 267 | def __str__(self): 268 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 269 | return fmtstr.format(**self.__dict__) 270 | 271 | 272 | class ProgressMeter(object): 273 | def __init__(self, num_batches, *meters, prefix=""): 274 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 275 | self.meters = meters 276 | self.prefix = prefix 277 | 278 | def print(self, batch): 279 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 280 | entries += [str(meter) for meter in self.meters] 281 | print('\t'.join(entries)) 282 | 283 | def _get_batch_fmtstr(self, num_batches): 284 | num_digits = len(str(num_batches // 1)) 285 | fmt = '{:' + str(num_digits) + 'd}' 286 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 287 | 288 | 289 | def adjust_learning_rate(optimizer, epoch, args): 290 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 291 | lr = args.lr * (0.1 ** (epoch // 30)) 292 | for param_group in optimizer.param_groups: 293 | param_group['lr'] = lr 
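Note that `test_arb.py` above sweeps `num_rows` from 1 to 1023 and passes it to `cs_sampling(images, num_rows)`, so the effective sampling ratio is set by the row count rather than by `--rat`. The sketch below is a hedged helper for picking a row count from a target ratio: `rows_for_ratio` is not part of the repository, and it assumes `num_rows / blocksize**2` is the effective ratio (consistent with the 1..1023 sweep for a 32x32 block); the exact convention lives in `Sampling.py`, which is not shown in this listing.

```python
# Illustrative only: maps a target CS ratio to a measurement-row count for a
# B x B sampling block (B=32 gives 1024 candidate rows, matching the sweep above).
def rows_for_ratio(ratio, blocksize=32):
    total = blocksize * blocksize
    return max(1, min(total - 1, round(ratio * total)))

# e.g. rows_for_ratio(0.10) -> 102, rows_for_ratio(0.25) -> 256
```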
294 | 295 | 296 | def accuracy(output, target, topk=(1,)): 297 | """Computes the accuracy over the k top predictions for the specified values of k""" 298 | with torch.no_grad(): 299 | maxk = max(topk) 300 | batch_size = target.size(0) 301 | 302 | _, pred = output.topk(maxk, 1, True, True) 303 | pred = pred.t() 304 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 305 | 306 | res = [] 307 | for k in topk: 308 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 309 | res.append(correct_k.mul_(100.0 / batch_size)) 310 | return res 311 | 312 | 313 | if __name__ == '__main__': 314 | main() 315 | -------------------------------------------------------------------------------- /classification/test_imagenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import PIL 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | from Sampling import CS_Sampling 22 | from vit import ViT, load_pretrained_weights 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 25 | parser.add_argument('--data', default='/group/30042/public_datasets/imagenet1k', 26 | help='path to dataset') 27 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 28 | help='model architecture (default: resnet18)') 29 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 30 | help='number of data loading workers (default: 4)') 31 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 32 | help='number of total epochs to run') 33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 34 | help='manual epoch number (useful on restarts)') 35 | parser.add_argument('-b', '--batch-size', default=256, type=int, 36 | metavar='N', 37 | help='mini-batch size (default: 256), this is the total ' 38 | 'batch size of all GPUs on the current node when ' 39 | 'using Data Parallel or Distributed Data Parallel') 40 | # parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 41 | # metavar='LR', help='initial learning rate', dest='lr') 42 | parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, 43 | metavar='LR', help='initial learning rate', dest='lr') 44 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 45 | help='momentum') 46 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 47 | metavar='W', help='weight decay (default: 1e-4)', 48 | dest='weight_decay') 49 | parser.add_argument('-p', '--print-freq', default=10, type=int, 50 | metavar='N', help='print frequency (default: 10)') 51 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 52 | help='path to latest checkpoint (default: none)') 53 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 54 | help='evaluate model on validation set') 55 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 56 | help='use pre-trained model') 57 | parser.add_argument('--world-size', default=-1, type=int, 58 | help='number of 
nodes for distributed training') 59 | parser.add_argument('--rank', default=-1, type=int, 60 | help='node rank for distributed training') 61 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 62 | help='url used to set up distributed training') 63 | parser.add_argument('--dist-backend', default='nccl', type=str, 64 | help='distributed backend') 65 | parser.add_argument('--seed', default=None, type=int, 66 | help='seed for initializing training. ') 67 | parser.add_argument('--gpu', default=None, type=int, 68 | help='GPU id to use.') 69 | parser.add_argument('--image_size', default=224, type=int, 70 | help='image size') 71 | parser.add_argument('--vit', default=True, help='use ViT model') 72 | parser.add_argument('--multiprocessing-distributed', action='store_true', 73 | help='Use multi-processing distributed training to launch ' 74 | 'N processes per node, which has N GPUs. This is the ' 75 | 'fastest way to use PyTorch for either single node or ' 76 | 'multi node data parallel training') 77 | parser.add_argument('--cs',type=int,default=0) 78 | parser.add_argument('--mm',type=int,default=0) 79 | parser.add_argument('--save_path',type=str,default='ckp') 80 | parser.add_argument('--rat',type=float,default=0.1) 81 | parser.add_argument('--devices',type=int,default=4) 82 | parser.add_argument('--ckp_vit',type=str,default='ckp/ckp_89') 83 | parser.add_argument('--ckp_cs',type=str,default='ckp/ckp_CS_89') 84 | best_acc1 = 0 85 | 86 | 87 | def main(): 88 | args = parser.parse_args() 89 | 90 | if args.seed is not None: 91 | random.seed(args.seed) 92 | torch.manual_seed(args.seed) 93 | cudnn.deterministic = True 94 | warnings.warn('You have chosen to seed training. ' 95 | 'This will turn on the CUDNN deterministic setting, ' 96 | 'which can slow down your training considerably! ' 97 | 'You may see unexpected behavior when restarting ' 98 | 'from checkpoints.') 99 | 100 | if args.gpu is not None: 101 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 102 | 'disable data parallelism.') 103 | 104 | if args.dist_url == "env://" and args.world_size == -1: 105 | args.world_size = int(os.environ["WORLD_SIZE"]) 106 | 107 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 108 | 109 | ngpus_per_node = torch.cuda.device_count() 110 | if args.multiprocessing_distributed: 111 | # Since we have ngpus_per_node processes per node, the total world_size 112 | # needs to be adjusted accordingly 113 | args.world_size = ngpus_per_node * args.world_size 114 | # Use torch.multiprocessing.spawn to launch distributed processes: the 115 | # main_worker process function 116 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 117 | else: 118 | # Simply call main_worker function 119 | main_worker(args.gpu, ngpus_per_node, args) 120 | 121 | 122 | def main_worker(gpu, ngpus_per_node, args): 123 | global best_acc1 124 | args.gpu = gpu 125 | if os.path.isdir(args.save_path)==False: 126 | os.mkdir(args.save_path) 127 | 128 | if args.gpu is not None: 129 | print("Use GPU: {} for training".format(args.gpu)) 130 | 131 | if args.distributed: 132 | if args.dist_url == "env://" and args.rank == -1: 133 | args.rank = int(os.environ["RANK"]) 134 | if args.multiprocessing_distributed: 135 | # For multiprocessing distributed training, rank needs to be the 136 | # global rank among all the processes 137 | args.rank = args.rank * ngpus_per_node + gpu 138 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 139 | world_size=args.world_size, rank=args.rank) 140 | 141 | # NEW 142 | if args.vit: 143 | # print(args.image_size) 144 | model = ViT(args.arch, pretrained=False,image_size=(args.image_size,args.image_size)) 145 | 146 | # # NOTE: This is for debugging 147 | # model = ViT('B_16_imagenet1k', pretrained=False) 148 | # load_pretrained_weights(model, weights_path='/home/luke/projects/experiments/ViT-PyTorch/jax_to_pytorch/weights/B_16_imagenet1k.pth') 149 | 150 | else: 151 | model = models.__dict__[args.arch](pretrained=args.pretrained) 152 | 153 | print('Create CS model.') 154 | print('CS ratio=',args.rat) 155 | cs_sampling = CS_Sampling(n_channels=3, cs_ratio=args.rat, blocksize=32).cuda() 156 | cs_sampling = torch.nn.DataParallel(cs_sampling,range(args.devices)) 157 | cs_sampling.load_state_dict(torch.load(args.ckp_cs)['state_dict']) 158 | print('Load pretrained cs model from: ', args.ckp_cs) 159 | 160 | model.cuda() 161 | model = torch.nn.DataParallel(model,range(args.devices)) 162 | model.load_state_dict(torch.load(args.ckp_vit)['state_dict']) 163 | print('Load pretrained vit model from: ', args.ckp_vit) 164 | 165 | criterion = nn.CrossEntropyLoss().cuda() 166 | 167 | cudnn.benchmark = True 168 | 169 | # Data loading code 170 | valdir = os.path.join(args.data, 'val') 171 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 172 | normalize = transforms.Normalize(0.5, 0.5) 173 | 174 | print('Load test set.') 175 | val_transforms = transforms.Compose([ 176 | transforms.Resize(args.image_size, interpolation=PIL.Image.BICUBIC), 177 | transforms.CenterCrop(args.image_size), 178 | transforms.ToTensor(), 179 | normalize, 180 | ]) 181 | print('Using image size', args.image_size) 182 | 183 | val_loader = torch.utils.data.DataLoader( 184 | datasets.ImageFolder(valdir, val_transforms), 185 | batch_size=args.batch_size, shuffle=False, 186 | num_workers=args.workers, pin_memory=True) 187 | 188 | print('Start test.') 189 | acc1 = validate(val_loader, 
model, criterion, args, cs_sampling) 190 | 191 | 192 | def validate(val_loader, model, criterion, args,cs_sampling): 193 | batch_time = AverageMeter('Time', ':6.3f') 194 | losses = AverageMeter('Loss', ':.4e') 195 | top1 = AverageMeter('Acc@1', ':6.2f') 196 | top5 = AverageMeter('Acc@5', ':6.2f') 197 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 198 | prefix='Test: ') 199 | 200 | # switch to evaluate mode 201 | model.eval() 202 | 203 | with torch.no_grad(): 204 | end = time.time() 205 | for i, (images, target) in enumerate(val_loader): 206 | if args.gpu is not None: 207 | images = images.cuda()#(args.gpu, non_blocking=True) 208 | target = target.cuda()#(args.gpu,non_blocking=True) 209 | 210 | # compute output 211 | if cs_sampling: 212 | images = cs_sampling(images) 213 | output = model(images) 214 | loss = criterion(output, target) 215 | 216 | # measure accuracy and record loss 217 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 218 | losses.update(loss.item(), images.size(0)) 219 | top1.update(acc1[0], images.size(0)) 220 | top5.update(acc5[0], images.size(0)) 221 | 222 | # measure elapsed time 223 | batch_time.update(time.time() - end) 224 | end = time.time() 225 | 226 | if i % args.print_freq == 0: 227 | progress.print(i) 228 | 229 | # TODO: this should also be done with the ProgressMeter 230 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 231 | .format(top1=top1, top5=top5)) 232 | 233 | return top1.avg 234 | 235 | 236 | class AverageMeter(object): 237 | """Computes and stores the average and current value""" 238 | def __init__(self, name, fmt=':f'): 239 | self.name = name 240 | self.fmt = fmt 241 | self.reset() 242 | 243 | def reset(self): 244 | self.val = 0 245 | self.avg = 0 246 | self.sum = 0 247 | self.count = 0 248 | 249 | def update(self, val, n=1): 250 | self.val = val 251 | self.sum += val * n 252 | self.count += n 253 | self.avg = self.sum / self.count 254 | 255 | def __str__(self): 256 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 257 | return fmtstr.format(**self.__dict__) 258 | 259 | 260 | class ProgressMeter(object): 261 | def __init__(self, num_batches, *meters, prefix=""): 262 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 263 | self.meters = meters 264 | self.prefix = prefix 265 | 266 | def print(self, batch): 267 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 268 | entries += [str(meter) for meter in self.meters] 269 | print('\t'.join(entries)) 270 | 271 | def _get_batch_fmtstr(self, num_batches): 272 | num_digits = len(str(num_batches // 1)) 273 | fmt = '{:' + str(num_digits) + 'd}' 274 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 275 | 276 | 277 | def adjust_learning_rate(optimizer, epoch, args): 278 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 279 | lr = args.lr * (0.1 ** (epoch // 30)) 280 | for param_group in optimizer.param_groups: 281 | param_group['lr'] = lr 282 | 283 | 284 | def accuracy(output, target, topk=(1,)): 285 | """Computes the accuracy over the k top predictions for the specified values of k""" 286 | with torch.no_grad(): 287 | maxk = max(topk) 288 | batch_size = target.size(0) 289 | 290 | _, pred = output.topk(maxk, 1, True, True) 291 | pred = pred.t() 292 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 293 | 294 | res = [] 295 | for k in topk: 296 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 297 | res.append(correct_k.mul_(100.0 / batch_size)) 298 | return res 299 | 300 | 301 | 
if __name__ == '__main__': 302 | main() 303 | -------------------------------------------------------------------------------- /classification/vit/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.6" 2 | 3 | from .model import ViT 4 | from .configs import * 5 | from .utils import load_pretrained_weights -------------------------------------------------------------------------------- /classification/vit/.ipynb_checkpoints/configs-checkpoint.py: -------------------------------------------------------------------------------- 1 | """configs.py - ViT model configurations, based on: 2 | https://github.com/google-research/vision_transformer/blob/master/vit_jax/configs.py 3 | """ 4 | 5 | def get_base_config(): 6 | """Base ViT config ViT""" 7 | return dict( 8 | dim=768, 9 | ff_dim=3072, 10 | num_heads=12, 11 | num_layers=12, 12 | attention_dropout_rate=0.0, 13 | dropout_rate=0.1, 14 | representation_size=768, 15 | classifier='token' 16 | ) 17 | 18 | def get_b16_config(): 19 | """Returns the ViT-B/16 configuration.""" 20 | config = get_base_config() 21 | config.update(dict(patches=(16, 16))) 22 | return config 23 | 24 | def get_b32_config(): 25 | """Returns the ViT-B/32 configuration.""" 26 | config = get_b16_config() 27 | config.update(dict(patches=(32, 32))) 28 | return config 29 | 30 | def get_l16_config(): 31 | """Returns the ViT-L/16 configuration.""" 32 | config = get_base_config() 33 | config.update(dict( 34 | patches=(16, 16), 35 | dim=1024, 36 | ff_dim=4096, 37 | num_heads=16, 38 | num_layers=24, 39 | attention_dropout_rate=0.0, 40 | dropout_rate=0.1, 41 | representation_size=1024 42 | )) 43 | return config 44 | 45 | def get_l32_config(): 46 | """Returns the ViT-L/32 configuration.""" 47 | config = get_l16_config() 48 | config.update(dict(patches=(32, 32))) 49 | return config 50 | 51 | def drop_head_variant(config): 52 | config.update(dict(representation_size=None)) 53 | return config 54 | 55 | 56 | PRETRAINED_MODELS = { 57 | 'B_16': { 58 | 'config': get_b16_config(), 59 | 'num_classes': 21843, 60 | 'image_size': (224, 224), 61 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16.pth" 62 | }, 63 | 'B_32': { 64 | 'config': get_b32_config(), 65 | 'num_classes': 21843, 66 | 'image_size': (224, 224), 67 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32.pth" 68 | }, 69 | 'L_16': { 70 | 'config': get_l16_config(), 71 | 'num_classes': 21843, 72 | 'image_size': (224, 224), 73 | 'url': None 74 | }, 75 | 'L_32': { 76 | 'config': get_l32_config(), 77 | 'num_classes': 21843, 78 | 'image_size': (224, 224), 79 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32.pth" 80 | }, 81 | 'B_16_imagenet1k': { 82 | 'config': drop_head_variant(get_b16_config()), 83 | 'num_classes': 1000, 84 | 'image_size': (384, 384), 85 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16_imagenet1k.pth" 86 | }, 87 | 'B_32_imagenet1k': { 88 | 'config': drop_head_variant(get_b32_config()), 89 | 'num_classes': 1000, 90 | 'image_size': (384, 384), 91 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32_imagenet1k.pth" 92 | }, 93 | 'L_16_imagenet1k': { 94 | 'config': drop_head_variant(get_l16_config()), 95 | 'num_classes': 1000, 96 | 'image_size': (384, 384), 97 | 'url': 
"https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_16_imagenet1k.pth" 98 | }, 99 | 'L_32_imagenet1k': { 100 | 'config': drop_head_variant(get_l32_config()), 101 | 'num_classes': 1000, 102 | 'image_size': (384, 384), 103 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32_imagenet1k.pth" 104 | }, 105 | } 106 | -------------------------------------------------------------------------------- /classification/vit/.ipynb_checkpoints/model-checkpoint.py: -------------------------------------------------------------------------------- 1 | """model.py - Model and module class for ViT. 2 | They are built to mirror those in the official Jax implementation. 3 | """ 4 | 5 | from typing import Optional 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from .transformer import Transformer 11 | from .utils import load_pretrained_weights, as_tuple 12 | from .configs import PRETRAINED_MODELS 13 | import numpy as np 14 | import os 15 | 16 | class PositionalEmbedding1D(nn.Module): 17 | """Adds (optionally learned) positional embeddings to the inputs.""" 18 | 19 | def __init__(self, seq_len, dim): 20 | super().__init__() 21 | self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim)) 22 | 23 | def forward(self, x): 24 | """Input has shape `(batch_size, seq_len, emb_dim)`""" 25 | return x + self.pos_embedding 26 | 27 | 28 | class ViT(nn.Module): 29 | """ 30 | Args: 31 | name (str): Model name, e.g. 'B_16' 32 | pretrained (bool): Load pretrained weights 33 | in_channels (int): Number of channels in input data 34 | num_classes (int): Number of classes, default 1000 35 | 36 | References: 37 | [1] https://openreview.net/forum?id=YicbFdNTTy 38 | """ 39 | 40 | def __init__( 41 | self, 42 | name: Optional[str] = None, 43 | pretrained: bool = False, 44 | patches: int = 16, 45 | dim: int = 768, 46 | ff_dim: int = 3072, 47 | num_heads: int = 12, 48 | num_layers: int = 12, 49 | attention_dropout_rate: float = 0.0, 50 | dropout_rate: float = 0.1, 51 | representation_size: Optional[int] = None, 52 | load_repr_layer: bool = False, 53 | classifier: str = 'token', 54 | positional_embedding: str = '1d', 55 | in_channels: int = 3, 56 | image_size: Optional[int] = None, 57 | num_classes: Optional[int] = None, 58 | weights_path = None 59 | ): 60 | super().__init__() 61 | 62 | # Configuration 63 | if name is None: 64 | check_msg = 'must specify name of pretrained model' 65 | assert not pretrained, check_msg 66 | assert not resize_positional_embedding, check_msg 67 | if num_classes is None: 68 | num_classes = 1000 69 | if image_size is None: 70 | image_size = 384 71 | else: # load pretrained model 72 | assert name in PRETRAINED_MODELS.keys(), \ 73 | 'name should be in: ' + ', '.join(PRETRAINED_MODELS.keys()) 74 | config = PRETRAINED_MODELS[name]['config'] 75 | patches = config['patches'] 76 | dim = config['dim'] 77 | ff_dim = config['ff_dim'] 78 | num_heads = config['num_heads'] 79 | num_layers = config['num_layers'] 80 | attention_dropout_rate = config['attention_dropout_rate'] 81 | dropout_rate = config['dropout_rate'] 82 | representation_size = config['representation_size'] 83 | classifier = config['classifier'] 84 | if image_size is None: 85 | image_size = PRETRAINED_MODELS[name]['image_size'] 86 | if num_classes is None: 87 | num_classes = PRETRAINED_MODELS[name]['num_classes'] 88 | self.image_size = image_size 89 | 90 | # Image and patch sizes 91 | h, w = as_tuple(image_size) # image sizes 92 | fh, fw = 
as_tuple(patches) # patch sizes 93 | gh, gw = h // fh, w // fw # number of patches 94 | seq_len = gh * gw 95 | # print('seq_len:', seq_len,h,w,fh,fw) 96 | # exit(0) 97 | 98 | # Patch embedding 99 | self.patch_embedding = nn.Conv2d(in_channels, dim, kernel_size=(fh, fw), stride=(fh, fw)) 100 | 101 | # Class token 102 | if classifier == 'token': 103 | self.class_token = nn.Parameter(torch.zeros(1, 1, dim)) 104 | seq_len += 1 105 | 106 | # Positional embedding 107 | if positional_embedding.lower() == '1d': 108 | self.positional_embedding = PositionalEmbedding1D(seq_len, dim) 109 | else: 110 | raise NotImplementedError() 111 | 112 | # Transformer 113 | self.transformer = Transformer(num_layers=num_layers, dim=dim, num_heads=num_heads, 114 | ff_dim=ff_dim, dropout=dropout_rate) 115 | 116 | # Representation layer 117 | if representation_size and load_repr_layer: 118 | self.pre_logits = nn.Linear(dim, representation_size) 119 | pre_logits_size = representation_size 120 | else: 121 | pre_logits_size = dim 122 | 123 | # Classifier head 124 | self.norm = nn.LayerNorm(pre_logits_size, eps=1e-6) 125 | self.fc = nn.Linear(pre_logits_size, num_classes) 126 | 127 | # Load pretrained model 128 | if pretrained: 129 | pretrained_num_channels = 3 130 | pretrained_num_classes = 1000 #PRETRAINED_MODELS[name]['num_classes'] #edit for new 131 | pretrained_image_size = PRETRAINED_MODELS[name]['image_size'] 132 | # print(image_size,pretrained_image_size) 133 | # exit(0) 134 | load_pretrained_weights( 135 | self, name, 136 | weights_path= weights_path,#'pretrain/B_32_imagenet1k.pth', 137 | load_first_conv=(in_channels == pretrained_num_channels), 138 | load_fc=(num_classes == pretrained_num_classes), 139 | load_repr_layer=load_repr_layer, 140 | resize_positional_embedding=(image_size != pretrained_image_size), 141 | ) 142 | print('load_first_conv:',in_channels == pretrained_num_channels,'load_fc:',num_classes == pretrained_num_classes) 143 | # self.idx=[0]*1000 144 | # self.dict_={} 145 | 146 | # # Modify model as specified. NOTE: We do not do this earlier because 147 | # # it's easier to load only part of a pretrained model in this manner. 148 | # if in_channels != 3: 149 | # self.embedding = nn.Conv2d(in_channels, patches, kernel_size=patches, stride=patches) 150 | # if num_classes is not None and num_classes != num_classes_init: 151 | # self.fc = nn.Linear(dim, num_classes) 152 | def forward(self, x): 153 | """Breaks image into patches, applies transformer, applies MLP head. 154 | 155 | Args: 156 | x (tensor): `b,c,fh,fw` 157 | """ 158 | b, c, fh, fw = x.shape 159 | x = self.patch_embedding(x) # b,d,gh,gw 160 | x = x.flatten(2).transpose(1, 2) # b,gh*gw,d 161 | if hasattr(self, 'class_token'): 162 | x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1) # b,gh*gw+1,d 163 | if hasattr(self, 'positional_embedding'): 164 | x = self.positional_embedding(x) # b,gh*gw+1,d 165 | x = self.transformer(x) # b,gh*gw+1,d 166 | if hasattr(self, 'pre_logits'): 167 | x = self.pre_logits(x) 168 | x = torch.tanh(x) 169 | if hasattr(self, 'fc'): 170 | x = self.norm(x)[:, 0] # b,d 171 | x = self.fc(x) # b,num_classes 172 | return x 173 | # def forward(self, x):#,y): 174 | # """Breaks image into patches, applies transformer, applies MLP head. 
175 | 176 | # Args: 177 | # x (tensor): `b,c,fh,fw` 178 | # """ 179 | # # blocks=F.unfold(x,kernel_size=32,stride=32).permute(0, 2, 1).contiguous() 180 | # # np.save('feature_TCL_org.npy',blocks[0,:,:].cpu().data.numpy()) 181 | # # exit(0) 182 | # b, c, fh, fw = x.shape 183 | # x = self.patch_embedding(x) # b,d,gh,gw 184 | # x = x.flatten(2).transpose(1, 2) # b,gh*gw,d 185 | # if hasattr(self, 'class_token'): 186 | # x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1) # b,gh*gw+1,d 187 | # if hasattr(self, 'positional_embedding'): 188 | # x = self.positional_embedding(x) # b,gh*gw+1,d 189 | # x = self.transformer(x) # b,gh*gw+1,d 190 | # # np.save('feature_vit_b16_2.npy',x[2,:,:].cpu().data.numpy()) 191 | 192 | # # for i in range(len(y)): 193 | # # class_=str(y[i]) 194 | # # if os.path.isdir(os.path.join('feature_rp/TCL_2',class_))==False: 195 | # # os.mkdir(os.path.join('feature_rp/TCL_2',class_)) 196 | # # np.save(os.path.join('feature_rp/TCL_2',class_,'feature_TCL_%d.npy'%self.idx[y[i]]),x[i,:,:].cpu().data.numpy()) 197 | # # np.save(os.path.join('feature_rp/TCL_2',class_,'feature_TCL_%d_org.npy'%self.idx[y[i]]),blocks[i,:,:].cpu().data.numpy()) 198 | # # self.idx[y[i]]=self.idx[y[i]]+1 199 | 200 | # # for i in range(len(y)): 201 | # # class_=str(y[i]) 202 | # # if os.path.isdir(os.path.join('feature_rp/TCL',class_))==False: 203 | # # os.mkdir(os.path.join('feature_rp/TCL',class_)) 204 | # # np.save(os.path.join('feature_rp/TCL',class_,'feature_TCL_%d.npy'%self.idx[y[i]]),x[i,:,:].cpu().data.numpy()) 205 | # # self.idx[y[i]]=self.idx[y[i]]+1 206 | 207 | # # print(x[:,0].shape) 208 | # # exit(0) 209 | # if hasattr(self, 'pre_logits'): 210 | # x = self.pre_logits(x) 211 | # x = torch.tanh(x) 212 | # if hasattr(self, 'fc'): 213 | # x = self.norm(x)[:, 0] # b,d 214 | # x = self.fc(x) # b,num_classes 215 | # return x 216 | 217 | -------------------------------------------------------------------------------- /classification/vit/.ipynb_checkpoints/transformer-checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/lukemelas/simple-bert 3 | """ 4 | 5 | import numpy as np 6 | from torch import nn 7 | from torch import Tensor 8 | from torch.nn import functional as F 9 | 10 | 11 | def split_last(x, shape): 12 | "split the last dimension to given shape" 13 | shape = list(shape) 14 | assert shape.count(-1) <= 1 15 | if -1 in shape: 16 | shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) 17 | return x.view(*x.size()[:-1], *shape) 18 | 19 | 20 | def merge_last(x, n_dims): 21 | "merge the last n_dims to a dimension" 22 | s = x.size() 23 | assert n_dims > 1 and n_dims < len(s) 24 | return x.view(*s[:-n_dims], -1) 25 | 26 | 27 | class MultiHeadedSelfAttention(nn.Module): 28 | """Multi-Headed Dot Product Attention""" 29 | def __init__(self, dim, num_heads, dropout): 30 | super().__init__() 31 | self.proj_q = nn.Linear(dim, dim) 32 | self.proj_k = nn.Linear(dim, dim) 33 | self.proj_v = nn.Linear(dim, dim) 34 | self.drop = nn.Dropout(dropout) 35 | self.n_heads = num_heads 36 | self.scores = None # for visualization 37 | 38 | def forward(self, x, mask): 39 | """ 40 | x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) 41 | mask : (B(batch_size) x S(seq_len)) 42 | * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W 43 | """ 44 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 45 | q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) 46 | q, k, v = 
(split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v]) 47 | # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) 48 | scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1)) 49 | if mask is not None: 50 | mask = mask[:, None, None, :].float() 51 | scores -= 10000.0 * (1.0 - mask) 52 | scores = self.drop(F.softmax(scores, dim=-1)) 53 | # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) 54 | h = (scores @ v).transpose(1, 2).contiguous() 55 | # -merge-> (B, S, D) 56 | h = merge_last(h, 2) 57 | self.scores = scores 58 | return h 59 | 60 | 61 | class PositionWiseFeedForward(nn.Module): 62 | """FeedForward Neural Networks for each position""" 63 | def __init__(self, dim, ff_dim): 64 | super().__init__() 65 | self.fc1 = nn.Linear(dim, ff_dim) 66 | self.fc2 = nn.Linear(ff_dim, dim) 67 | 68 | def forward(self, x): 69 | # (B, S, D) -> (B, S, D_ff) -> (B, S, D) 70 | return self.fc2(F.gelu(self.fc1(x))) 71 | 72 | 73 | class Block(nn.Module): 74 | """Transformer Block""" 75 | def __init__(self, dim, num_heads, ff_dim, dropout): 76 | super().__init__() 77 | self.attn = MultiHeadedSelfAttention(dim, num_heads, dropout) 78 | self.proj = nn.Linear(dim, dim) 79 | self.norm1 = nn.LayerNorm(dim, eps=1e-6) 80 | self.pwff = PositionWiseFeedForward(dim, ff_dim) 81 | self.norm2 = nn.LayerNorm(dim, eps=1e-6) 82 | self.drop = nn.Dropout(dropout) 83 | 84 | def forward(self, x, mask): 85 | h = self.drop(self.proj(self.attn(self.norm1(x), mask))) 86 | x = x + h 87 | h = self.drop(self.pwff(self.norm2(x))) 88 | x = x + h 89 | return x 90 | 91 | 92 | class Transformer(nn.Module): 93 | """Transformer with Self-Attentive Blocks""" 94 | def __init__(self, num_layers, dim, num_heads, ff_dim, dropout): 95 | super().__init__() 96 | self.blocks = nn.ModuleList([ 97 | Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]) 98 | 99 | def forward(self, x, mask=None): 100 | for block in self.blocks: 101 | x = block(x, mask) 102 | return x 103 | -------------------------------------------------------------------------------- /classification/vit/.ipynb_checkpoints/utils-checkpoint.py: -------------------------------------------------------------------------------- 1 | """utils.py - Helper functions 2 | """ 3 | 4 | import torch 5 | from torch.utils import model_zoo 6 | import numpy as np 7 | from .configs import PRETRAINED_MODELS 8 | 9 | 10 | def load_pretrained_weights( 11 | model, 12 | model_name=None, 13 | weights_path=None, 14 | load_first_conv=True, 15 | load_fc=True, 16 | load_repr_layer=False, 17 | resize_positional_embedding=False, 18 | verbose=True 19 | ): 20 | """Loads pretrained weights from weights path or download using url. 21 | 22 | Args: 23 | model (Module): Full model (a nn.Module) 24 | model_name (str): Model name (e.g. B_16) 25 | weights_path (None or str): 26 | str: path to pretrained weights file on the local disk. 27 | None: use pretrained weights downloaded from the Internet. 28 | load_first_conv (bool): Whether to load patch embedding. 29 | load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. 
30 | resize_positional_embedding=False, 31 | verbose (bool): Whether to print on completion 32 | """ 33 | # assert bool(model_name) ^ bool(weights_path), 'Expected exactly one of model_name or weights_path' 34 | 35 | # Load or download weights 36 | if weights_path is None: 37 | url = PRETRAINED_MODELS[model_name]['url'] 38 | if url: 39 | state_dict = model_zoo.load_url(url) 40 | else: 41 | raise ValueError(f'Pretrained model for {model_name} has not yet been released') 42 | else: 43 | print('Have weight: ', weights_path) 44 | state_dict = torch.load(weights_path) 45 | 46 | # Modifications to load partial state dict 47 | expected_missing_keys = [] 48 | if not load_first_conv and 'patch_embedding.weight' in state_dict: 49 | expected_missing_keys += ['patch_embedding.weight', 'patch_embedding.bias'] 50 | if not load_fc and 'fc.weight' in state_dict: 51 | expected_missing_keys += ['fc.weight', 'fc.bias'] 52 | if not load_repr_layer and 'pre_logits.weight' in state_dict: 53 | expected_missing_keys += ['pre_logits.weight', 'pre_logits.bias'] 54 | for key in expected_missing_keys: 55 | state_dict.pop(key) 56 | 57 | # Change size of positional embeddings 58 | if resize_positional_embedding: 59 | posemb = state_dict['state_dict']['module.positional_embedding.pos_embedding'].cpu() #state_dict['positional_embedding.pos_embedding'] # edit for new 60 | posemb_new = model.state_dict()['positional_embedding.pos_embedding'] 61 | # print(posemb,posemb_new) 62 | state_dict['state_dict']['module.positional_embedding.pos_embedding'] = \ 63 | resize_positional_embedding_(posemb=posemb, posemb_new=posemb_new, 64 | has_class_token=hasattr(model, 'class_token'))# edit for new 65 | # state_dict['positional_embedding.pos_embedding'] = \ 66 | # resize_positional_embedding_(posemb=posemb, posemb_new=posemb_new, 67 | # has_class_token=hasattr(model, 'class_token')) 68 | if verbose: 69 | print('Resized positional embeddings from {} to {}'.format( 70 | posemb.shape, posemb_new.shape)) 71 | # Load state dict 72 | # state_dict = state_dict['state_dict'] # edit for new 73 | # ret = model.load_state_dict({k.replace('module.',''):state_dict[k] for k in state_dict}, strict=False) # edit for new 74 | ret = model.load_state_dict(state_dict, strict=False) 75 | 76 | 77 | # print(state_dict) 78 | # for k in state_dict: 79 | # print(k) 80 | # exit(0) 81 | # 82 | # ret = model.load_state_dict(state_dict, strict=False) 83 | # print(state_dict.keys()) 84 | # exit(0) 85 | assert set(ret.missing_keys) == set(expected_missing_keys), \ 86 | 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) 87 | assert not ret.unexpected_keys, \ 88 | 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) 89 | 90 | if verbose: 91 | print('Loaded pretrained weights.') 92 | 93 | 94 | def as_tuple(x): 95 | return x if isinstance(x, tuple) else (x, x) 96 | 97 | 98 | # def resize_positional_embedding_(posemb, posemb_new, has_class_token=True): 99 | # """Rescale the grid of position embeddings in a sensible manner""" 100 | # from scipy.ndimage import zoom 101 | 102 | # # Deal with class token 103 | # ntok_new = posemb_new.shape[1] 104 | # if has_class_token: # this means classifier == 'token' 105 | # posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 106 | # ntok_new -= 1 107 | # else: 108 | # posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 109 | 110 | # # Get old and new grid sizes 111 | # gs_old = int(np.sqrt(len(posemb_grid))) 112 | # gs_new = int(np.sqrt(ntok_new)) 113 | # posemb_grid = 
posemb_grid.reshape(gs_old, gs_old, -1) 114 | 115 | # # Rescale grid 116 | # zoom = (gs_new / gs_old, gs_new / gs_old, 1) 117 | # posemb_grid = zoom(posemb_grid, zoom, order=1) 118 | # posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 119 | 120 | # # Deal with class token and return 121 | # posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 122 | # return posemb 123 | 124 | def resize_positional_embedding_(posemb, posemb_new, has_class_token=True): 125 | """Rescale the grid of position embeddings in a sensible manner""" 126 | from scipy.ndimage import zoom 127 | 128 | # Deal with class token 129 | ntok_new = posemb_new.shape[1] 130 | if has_class_token: # this means classifier == 'token' 131 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 132 | ntok_new -= 1 133 | else: 134 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 135 | 136 | # Get old and new grid sizes 137 | gs_old = int(np.sqrt(len(posemb_grid))) 138 | gs_new = int(np.sqrt(ntok_new)) 139 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 140 | 141 | # Rescale grid 142 | zoom_factor = (gs_new / gs_old, gs_new / gs_old, 1) 143 | posemb_grid = zoom(posemb_grid, zoom_factor, order=1) 144 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 145 | posemb_grid = torch.from_numpy(posemb_grid) 146 | 147 | # Deal with class token and return 148 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 149 | return posemb -------------------------------------------------------------------------------- /classification/vit/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.6" 2 | 3 | from .model import ViT 4 | from .configs import * 5 | from .utils import load_pretrained_weights -------------------------------------------------------------------------------- /classification/vit/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/configs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/configs.cpython-36.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/configs.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/configs.cpython-37.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/configs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/configs.cpython-38.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /classification/vit/__pycache__/utils.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/classification/vit/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /classification/vit/configs.py: -------------------------------------------------------------------------------- 1 | """configs.py - ViT model configurations, based on: 2 | https://github.com/google-research/vision_transformer/blob/master/vit_jax/configs.py 3 | """ 4 | 5 | def get_base_config(): 6 | """Base ViT config ViT""" 7 | return dict( 8 | dim=768, 9 | ff_dim=3072, 10 | num_heads=12, 11 | num_layers=12, 12 | attention_dropout_rate=0.0, 13 | dropout_rate=0.1, 14 | representation_size=768, 15 | classifier='token' 16 | ) 17 | 18 | def get_b16_config(): 19 | """Returns the ViT-B/16 configuration.""" 20 | config = get_base_config() 21 | config.update(dict(patches=(16, 16))) 22 | return config 23 | 24 | def get_b32_config(): 25 | """Returns the ViT-B/32 configuration.""" 26 | config = get_b16_config() 27 | config.update(dict(patches=(32, 32))) 28 | return config 29 | 30 | def get_l16_config(): 31 | """Returns the ViT-L/16 configuration.""" 32 | config = get_base_config() 33 | config.update(dict( 34 | patches=(16, 16), 35 | dim=1024, 36 | ff_dim=4096, 37 | num_heads=16, 38 | num_layers=24, 39 | attention_dropout_rate=0.0, 40 | dropout_rate=0.1, 41 | representation_size=1024 42 | )) 43 | return config 44 | 45 | def get_l32_config(): 46 | """Returns the ViT-L/32 configuration.""" 47 | config = get_l16_config() 48 | config.update(dict(patches=(32, 32))) 49 | return config 50 | 51 | def drop_head_variant(config): 52 | config.update(dict(representation_size=None)) 53 | return config 54 | 55 | 56 | PRETRAINED_MODELS = { 57 | 'B_16': { 58 | 'config': get_b16_config(), 59 | 'num_classes': 21843, 60 | 'image_size': (224, 224), 61 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16.pth" 62 | }, 63 | 'B_32': { 64 | 'config': get_b32_config(), 65 | 'num_classes': 21843, 66 | 'image_size': (224, 224), 67 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32.pth" 68 | }, 69 | 'L_16': { 70 | 'config': get_l16_config(), 71 | 'num_classes': 21843, 72 | 'image_size': (224, 224), 73 | 'url': None 74 | }, 75 | 'L_32': { 76 | 'config': get_l32_config(), 77 | 'num_classes': 21843, 78 | 'image_size': (224, 224), 79 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32.pth" 80 | }, 81 | 'B_16_imagenet1k': { 82 | 'config': drop_head_variant(get_b16_config()), 83 | 'num_classes': 1000, 84 | 'image_size': (384, 384), 85 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16_imagenet1k.pth" 86 | }, 87 | 'B_32_imagenet1k': { 88 | 'config': drop_head_variant(get_b32_config()), 89 | 'num_classes': 1000, 90 | 'image_size': (384, 384), 91 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32_imagenet1k.pth" 92 | }, 93 | 'L_16_imagenet1k': { 94 | 'config': drop_head_variant(get_l16_config()), 95 | 'num_classes': 1000, 96 | 'image_size': (384, 384), 97 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_16_imagenet1k.pth" 98 | }, 99 | 'L_32_imagenet1k': { 100 | 'config': drop_head_variant(get_l32_config()), 101 | 'num_classes': 1000, 102 | 'image_size': (384, 384), 103 | 'url': 
"https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32_imagenet1k.pth" 104 | }, 105 | } 106 | -------------------------------------------------------------------------------- /classification/vit/model.py: -------------------------------------------------------------------------------- 1 | """model.py - Model and module class for ViT. 2 | They are built to mirror those in the official Jax implementation. 3 | """ 4 | 5 | from typing import Optional 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from .transformer import Transformer 11 | from .utils import load_pretrained_weights, as_tuple 12 | from .configs import PRETRAINED_MODELS 13 | import numpy as np 14 | import os 15 | 16 | class PositionalEmbedding1D(nn.Module): 17 | """Adds (optionally learned) positional embeddings to the inputs.""" 18 | 19 | def __init__(self, seq_len, dim): 20 | super().__init__() 21 | self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim)) 22 | 23 | def forward(self, x): 24 | """Input has shape `(batch_size, seq_len, emb_dim)`""" 25 | return x + self.pos_embedding 26 | 27 | 28 | class ViT(nn.Module): 29 | """ 30 | Args: 31 | name (str): Model name, e.g. 'B_16' 32 | pretrained (bool): Load pretrained weights 33 | in_channels (int): Number of channels in input data 34 | num_classes (int): Number of classes, default 1000 35 | 36 | References: 37 | [1] https://openreview.net/forum?id=YicbFdNTTy 38 | """ 39 | 40 | def __init__( 41 | self, 42 | name: Optional[str] = None, 43 | pretrained: bool = False, 44 | patches: int = 16, 45 | dim: int = 768, 46 | ff_dim: int = 3072, 47 | num_heads: int = 12, 48 | num_layers: int = 12, 49 | attention_dropout_rate: float = 0.0, 50 | dropout_rate: float = 0.1, 51 | representation_size: Optional[int] = None, 52 | load_repr_layer: bool = False, 53 | classifier: str = 'token', 54 | positional_embedding: str = '1d', 55 | in_channels: int = 3, 56 | image_size: Optional[int] = None, 57 | num_classes: Optional[int] = None, 58 | weights_path = None 59 | ): 60 | super().__init__() 61 | 62 | # Configuration 63 | if name is None: 64 | check_msg = 'must specify name of pretrained model' 65 | assert not pretrained, check_msg 66 | assert not resize_positional_embedding, check_msg 67 | if num_classes is None: 68 | num_classes = 1000 69 | if image_size is None: 70 | image_size = 384 71 | else: # load pretrained model 72 | assert name in PRETRAINED_MODELS.keys(), \ 73 | 'name should be in: ' + ', '.join(PRETRAINED_MODELS.keys()) 74 | config = PRETRAINED_MODELS[name]['config'] 75 | patches = config['patches'] 76 | dim = config['dim'] 77 | ff_dim = config['ff_dim'] 78 | num_heads = config['num_heads'] 79 | num_layers = config['num_layers'] 80 | attention_dropout_rate = config['attention_dropout_rate'] 81 | dropout_rate = config['dropout_rate'] 82 | representation_size = config['representation_size'] 83 | classifier = config['classifier'] 84 | if image_size is None: 85 | image_size = PRETRAINED_MODELS[name]['image_size'] 86 | if num_classes is None: 87 | num_classes = PRETRAINED_MODELS[name]['num_classes'] 88 | self.image_size = image_size 89 | 90 | # Image and patch sizes 91 | h, w = as_tuple(image_size) # image sizes 92 | fh, fw = as_tuple(patches) # patch sizes 93 | gh, gw = h // fh, w // fw # number of patches 94 | seq_len = gh * gw 95 | # print('seq_len:', seq_len,h,w,fh,fw) 96 | # exit(0) 97 | 98 | # Patch embedding 99 | self.patch_embedding = nn.Conv2d(in_channels, dim, kernel_size=(fh, fw), stride=(fh, fw)) 100 | 101 
| # Class token 102 | if classifier == 'token': 103 | self.class_token = nn.Parameter(torch.zeros(1, 1, dim)) 104 | seq_len += 1 105 | 106 | # Positional embedding 107 | if positional_embedding.lower() == '1d': 108 | self.positional_embedding = PositionalEmbedding1D(seq_len, dim) 109 | else: 110 | raise NotImplementedError() 111 | 112 | # Transformer 113 | self.transformer = Transformer(num_layers=num_layers, dim=dim, num_heads=num_heads, 114 | ff_dim=ff_dim, dropout=dropout_rate) 115 | 116 | # Representation layer 117 | if representation_size and load_repr_layer: 118 | self.pre_logits = nn.Linear(dim, representation_size) 119 | pre_logits_size = representation_size 120 | else: 121 | pre_logits_size = dim 122 | 123 | # Classifier head 124 | self.norm = nn.LayerNorm(pre_logits_size, eps=1e-6) 125 | self.fc = nn.Linear(pre_logits_size, num_classes) 126 | 127 | # Load pretrained model 128 | if pretrained: 129 | pretrained_num_channels = 3 130 | pretrained_num_classes = 1000 #PRETRAINED_MODELS[name]['num_classes'] #edit for new 131 | pretrained_image_size = PRETRAINED_MODELS[name]['image_size'] 132 | # print(image_size,pretrained_image_size) 133 | # exit(0) 134 | load_pretrained_weights( 135 | self, name, 136 | weights_path= weights_path,#'pretrain/B_32_imagenet1k.pth', 137 | load_first_conv=(in_channels == pretrained_num_channels), 138 | load_fc=(num_classes == pretrained_num_classes), 139 | load_repr_layer=load_repr_layer, 140 | resize_positional_embedding=(image_size != pretrained_image_size), 141 | ) 142 | print('load_first_conv:',in_channels == pretrained_num_channels,'load_fc:',num_classes == pretrained_num_classes) 143 | # self.idx=[0]*1000 144 | # self.dict_={} 145 | 146 | # # Modify model as specified. NOTE: We do not do this earlier because 147 | # # it's easier to load only part of a pretrained model in this manner. 148 | # if in_channels != 3: 149 | # self.embedding = nn.Conv2d(in_channels, patches, kernel_size=patches, stride=patches) 150 | # if num_classes is not None and num_classes != num_classes_init: 151 | # self.fc = nn.Linear(dim, num_classes) 152 | def forward(self, x): 153 | """Breaks image into patches, applies transformer, applies MLP head. 
154 | 155 | Args: 156 | x (tensor): `b,c,fh,fw` 157 | """ 158 | b, c, fh, fw = x.shape 159 | x = self.patch_embedding(x) # b,d,gh,gw 160 | x = x.flatten(2).transpose(1, 2) # b,gh*gw,d 161 | if hasattr(self, 'class_token'): 162 | x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1) # b,gh*gw+1,d 163 | if hasattr(self, 'positional_embedding'): 164 | x = self.positional_embedding(x) # b,gh*gw+1,d 165 | x = self.transformer(x) # b,gh*gw+1,d 166 | if hasattr(self, 'pre_logits'): 167 | x = self.pre_logits(x) 168 | x = torch.tanh(x) 169 | if hasattr(self, 'fc'): 170 | x = self.norm(x)[:, 0] # b,d 171 | x = self.fc(x) # b,num_classes 172 | return x 173 | 174 | -------------------------------------------------------------------------------- /classification/vit/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/lukemelas/simple-bert 3 | """ 4 | 5 | import numpy as np 6 | from torch import nn 7 | from torch import Tensor 8 | from torch.nn import functional as F 9 | 10 | 11 | def split_last(x, shape): 12 | "split the last dimension to given shape" 13 | shape = list(shape) 14 | assert shape.count(-1) <= 1 15 | if -1 in shape: 16 | shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) 17 | return x.view(*x.size()[:-1], *shape) 18 | 19 | 20 | def merge_last(x, n_dims): 21 | "merge the last n_dims to a dimension" 22 | s = x.size() 23 | assert n_dims > 1 and n_dims < len(s) 24 | return x.view(*s[:-n_dims], -1) 25 | 26 | 27 | class MultiHeadedSelfAttention(nn.Module): 28 | """Multi-Headed Dot Product Attention""" 29 | def __init__(self, dim, num_heads, dropout): 30 | super().__init__() 31 | self.proj_q = nn.Linear(dim, dim) 32 | self.proj_k = nn.Linear(dim, dim) 33 | self.proj_v = nn.Linear(dim, dim) 34 | self.drop = nn.Dropout(dropout) 35 | self.n_heads = num_heads 36 | self.scores = None # for visualization 37 | 38 | def forward(self, x, mask): 39 | """ 40 | x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) 41 | mask : (B(batch_size) x S(seq_len)) 42 | * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W 43 | """ 44 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 45 | q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) 46 | q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v]) 47 | # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) 48 | scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1)) 49 | if mask is not None: 50 | mask = mask[:, None, None, :].float() 51 | scores -= 10000.0 * (1.0 - mask) 52 | scores = self.drop(F.softmax(scores, dim=-1)) 53 | # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) 54 | h = (scores @ v).transpose(1, 2).contiguous() 55 | # -merge-> (B, S, D) 56 | h = merge_last(h, 2) 57 | self.scores = scores 58 | return h 59 | 60 | 61 | class PositionWiseFeedForward(nn.Module): 62 | """FeedForward Neural Networks for each position""" 63 | def __init__(self, dim, ff_dim): 64 | super().__init__() 65 | self.fc1 = nn.Linear(dim, ff_dim) 66 | self.fc2 = nn.Linear(ff_dim, dim) 67 | 68 | def forward(self, x): 69 | # (B, S, D) -> (B, S, D_ff) -> (B, S, D) 70 | return self.fc2(F.gelu(self.fc1(x))) 71 | 72 | 73 | class Block(nn.Module): 74 | """Transformer Block""" 75 | def __init__(self, dim, num_heads, ff_dim, dropout): 76 | super().__init__() 77 | self.attn = MultiHeadedSelfAttention(dim, num_heads, dropout) 78 | self.proj = nn.Linear(dim, dim) 79 | self.norm1 = 
nn.LayerNorm(dim, eps=1e-6) 80 | self.pwff = PositionWiseFeedForward(dim, ff_dim) 81 | self.norm2 = nn.LayerNorm(dim, eps=1e-6) 82 | self.drop = nn.Dropout(dropout) 83 | 84 | def forward(self, x, mask): 85 | h = self.drop(self.proj(self.attn(self.norm1(x), mask))) 86 | x = x + h 87 | h = self.drop(self.pwff(self.norm2(x))) 88 | x = x + h 89 | return x 90 | 91 | 92 | class Transformer(nn.Module): 93 | """Transformer with Self-Attentive Blocks""" 94 | def __init__(self, num_layers, dim, num_heads, ff_dim, dropout): 95 | super().__init__() 96 | self.blocks = nn.ModuleList([ 97 | Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]) 98 | 99 | def forward(self, x, mask=None): 100 | for block in self.blocks: 101 | x = block(x, mask) 102 | return x 103 | -------------------------------------------------------------------------------- /classification/vit/utils.py: -------------------------------------------------------------------------------- 1 | """utils.py - Helper functions 2 | """ 3 | 4 | import torch 5 | from torch.utils import model_zoo 6 | import numpy as np 7 | from .configs import PRETRAINED_MODELS 8 | 9 | 10 | def load_pretrained_weights( 11 | model, 12 | model_name=None, 13 | weights_path=None, 14 | load_first_conv=True, 15 | load_fc=True, 16 | load_repr_layer=False, 17 | resize_positional_embedding=False, 18 | verbose=True 19 | ): 20 | """Loads pretrained weights from weights path or download using url. 21 | 22 | Args: 23 | model (Module): Full model (a nn.Module) 24 | model_name (str): Model name (e.g. B_16) 25 | weights_path (None or str): 26 | str: path to pretrained weights file on the local disk. 27 | None: use pretrained weights downloaded from the Internet. 28 | load_first_conv (bool): Whether to load patch embedding. 29 | load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. 
30 | resize_positional_embedding=False, 31 | verbose (bool): Whether to print on completion 32 | """ 33 | # assert bool(model_name) ^ bool(weights_path), 'Expected exactly one of model_name or weights_path' 34 | 35 | # Load or download weights 36 | if weights_path is None: 37 | url = PRETRAINED_MODELS[model_name]['url'] 38 | if url: 39 | state_dict = model_zoo.load_url(url) 40 | else: 41 | raise ValueError(f'Pretrained model for {model_name} has not yet been released') 42 | else: 43 | print('Have weight: ', weights_path) 44 | state_dict = torch.load(weights_path) 45 | 46 | # Modifications to load partial state dict 47 | expected_missing_keys = [] 48 | if not load_first_conv and 'patch_embedding.weight' in state_dict: 49 | expected_missing_keys += ['patch_embedding.weight', 'patch_embedding.bias'] 50 | if not load_fc and 'fc.weight' in state_dict: 51 | expected_missing_keys += ['fc.weight', 'fc.bias'] 52 | if not load_repr_layer and 'pre_logits.weight' in state_dict: 53 | expected_missing_keys += ['pre_logits.weight', 'pre_logits.bias'] 54 | for key in expected_missing_keys: 55 | state_dict.pop(key) 56 | 57 | # Change size of positional embeddings 58 | if resize_positional_embedding: 59 | posemb = state_dict['state_dict']['module.positional_embedding.pos_embedding'].cpu() #state_dict['positional_embedding.pos_embedding'] # edit for new 60 | posemb_new = model.state_dict()['positional_embedding.pos_embedding'] 61 | # print(posemb,posemb_new) 62 | state_dict['state_dict']['module.positional_embedding.pos_embedding'] = \ 63 | resize_positional_embedding_(posemb=posemb, posemb_new=posemb_new, 64 | has_class_token=hasattr(model, 'class_token'))# edit for new 65 | # state_dict['positional_embedding.pos_embedding'] = \ 66 | # resize_positional_embedding_(posemb=posemb, posemb_new=posemb_new, 67 | # has_class_token=hasattr(model, 'class_token')) 68 | if verbose: 69 | print('Resized positional embeddings from {} to {}'.format( 70 | posemb.shape, posemb_new.shape)) 71 | # Load state dict 72 | # state_dict = state_dict['state_dict'] # edit for new 73 | if 'state_dict' in state_dict: 74 | ret = model.load_state_dict({k.replace('module.',''):state_dict['state_dict'][k] for k in state_dict['state_dict']}, strict=False) # edit for new 75 | else: 76 | ret = model.load_state_dict(state_dict, strict=False) 77 | 78 | 79 | # print(state_dict) 80 | # for k in state_dict: 81 | # print(k) 82 | # exit(0) 83 | # 84 | # ret = model.load_state_dict(state_dict, strict=False) 85 | # print(state_dict.keys()) 86 | # exit(0) 87 | assert set(ret.missing_keys) == set(expected_missing_keys), \ 88 | 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) 89 | assert not ret.unexpected_keys, \ 90 | 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) 91 | 92 | if verbose: 93 | print('Loaded pretrained weights.') 94 | 95 | 96 | def as_tuple(x): 97 | return x if isinstance(x, tuple) else (x, x) 98 | 99 | 100 | # def resize_positional_embedding_(posemb, posemb_new, has_class_token=True): 101 | # """Rescale the grid of position embeddings in a sensible manner""" 102 | # from scipy.ndimage import zoom 103 | 104 | # # Deal with class token 105 | # ntok_new = posemb_new.shape[1] 106 | # if has_class_token: # this means classifier == 'token' 107 | # posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 108 | # ntok_new -= 1 109 | # else: 110 | # posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 111 | 112 | # # Get old and new grid sizes 113 | # gs_old = int(np.sqrt(len(posemb_grid))) 114 
| # gs_new = int(np.sqrt(ntok_new)) 115 | # posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 116 | 117 | # # Rescale grid 118 | # zoom = (gs_new / gs_old, gs_new / gs_old, 1) 119 | # posemb_grid = zoom(posemb_grid, zoom, order=1) 120 | # posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 121 | 122 | # # Deal with class token and return 123 | # posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 124 | # return posemb 125 | 126 | def resize_positional_embedding_(posemb, posemb_new, has_class_token=True): 127 | """Rescale the grid of position embeddings in a sensible manner""" 128 | from scipy.ndimage import zoom 129 | 130 | # Deal with class token 131 | ntok_new = posemb_new.shape[1] 132 | if has_class_token: # this means classifier == 'token' 133 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 134 | ntok_new -= 1 135 | else: 136 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 137 | 138 | # Get old and new grid sizes 139 | gs_old = int(np.sqrt(len(posemb_grid))) 140 | gs_new = int(np.sqrt(ntok_new)) 141 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 142 | 143 | # Rescale grid 144 | zoom_factor = (gs_new / gs_old, gs_new / gs_old, 1) 145 | posemb_grid = zoom(posemb_grid, zoom_factor, order=1) 146 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 147 | posemb_grid = torch.from_numpy(posemb_grid) 148 | 149 | # Deal with class token and return 150 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 151 | return posemb -------------------------------------------------------------------------------- /figs/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /figs/network.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/TransCL/7f34330b94d61cc2abc408d13f68853d4da4c6bc/figs/network.PNG -------------------------------------------------------------------------------- /segmentation/README.md: -------------------------------------------------------------------------------- 1 | # Coming soon. 2 | --------------------------------------------------------------------------------
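
For reference, the snippet below sketches how the pieces dumped above (`classification/Sampling.py`, `classification/vit/model.py`, and the evaluation flow of `classification/test_imagenet.py`) fit together for a single image: the learned `CS_Sampling` module compressively samples the input and returns a tensor that the ViT-B/32 backbone then classifies. This is a minimal sketch only, not the authors' script: the checkpoint paths `ckp/ckp_89` and `ckp/ckp_CS_89` are simply the argparse defaults and assume the model-zoo weights have been downloaded to those locations, `example.jpg` is a hypothetical placeholder image, and a single GPU is used instead of the four devices the test script defaults to.

```python
# Minimal single-image evaluation sketch (assumptions: checkpoints at the
# default paths ckp/ckp_89 and ckp/ckp_CS_89, one visible GPU, and a local
# image file example.jpg -- all placeholders, see the note above).
import PIL.Image
import torch
import torchvision.transforms as transforms

from Sampling import CS_Sampling   # run from the classification/ folder
from vit import ViT

device_ids = [0]  # single-GPU example; test_imagenet.py uses --devices=4

# ViT-B/32 backbone at 384x384, wrapped in DataParallel so the 'module.'
# prefixes in the saved state_dict line up with the model's keys.
model = ViT('B_32_imagenet1k', pretrained=False, image_size=(384, 384)).cuda()
model = torch.nn.DataParallel(model, device_ids)
model.load_state_dict(torch.load('ckp/ckp_89')['state_dict'])
model.eval()

# Learned block-based compressive sampling at CS ratio 0.1 with block size 32,
# matching the defaults of test_imagenet.py.
cs_sampling = CS_Sampling(n_channels=3, cs_ratio=0.1, blocksize=32).cuda()
cs_sampling = torch.nn.DataParallel(cs_sampling, device_ids)
cs_sampling.load_state_dict(torch.load('ckp/ckp_CS_89')['state_dict'])
cs_sampling.eval()

# Same preprocessing as the test script: bicubic resize, center crop,
# and normalization to roughly [-1, 1].
preprocess = transforms.Compose([
    transforms.Resize(384, interpolation=PIL.Image.BICUBIC),
    transforms.CenterCrop(384),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

img = preprocess(PIL.Image.open('example.jpg').convert('RGB')).unsqueeze(0).cuda()
with torch.no_grad():
    logits = model(cs_sampling(img))  # CS measurements -> ViT -> class logits
    pred = logits.argmax(dim=1)
print('predicted ImageNet class index:', pred.item())
```

The `DataParallel` wrapping before `load_state_dict` mirrors what `test_imagenet.py` does: the released checkpoints were saved from a `DataParallel`-wrapped model, so their keys carry the `module.` prefix and load cleanly only onto a wrapped model.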