├── lib ├── __init__.py ├── config │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ └── default.cpython-39.pyc │ └── default.py ├── core │ ├── __init__.py │ ├── __pycache__ │ │ ├── utils.cpython-39.pyc │ │ └── __init__.cpython-39.pyc │ └── utils.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── hrir.cpython-39.pyc │ │ ├── resnet.cpython-39.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── vit_pixel.cpython-39.pyc │ │ ├── model_builder.cpython-39.pyc │ │ └── shufflenetv2.cpython-39.pyc │ ├── utils.py │ ├── hrir.py │ ├── model_builder.py │ ├── shufflenetv2.py │ ├── resnet.py │ ├── vit_pixel.py │ └── vit.py ├── __pycache__ │ ├── utils.cpython-39.pyc │ └── __init__.cpython-39.pyc ├── datasets │ ├── __pycache__ │ │ ├── cub.cpython-39.pyc │ │ ├── __init__.cpython-39.pyc │ │ └── cub_2scale.cpython-39.pyc │ ├── __init__.py │ ├── aircraft.py │ ├── cub.py │ ├── aircraft_2scale.py │ └── cub_2scale.py ├── losses │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── at_loss.cpython-39.pyc │ │ ├── dkd_loss.cpython-39.pyc │ │ ├── ickd_loss.cpython-39.pyc │ │ ├── kd_loss.cpython-39.pyc │ │ └── similarity_loss.cpython-39.pyc │ ├── __init__.py │ ├── icfup_loss.py │ ├── at_loss.py │ ├── similarity_loss.py │ ├── kd_loss.py │ ├── ickd_loss.py │ └── dkd_loss.py └── utils.py ├── tools_1_ts ├── __init__.py ├── train_isrd_5runs.py └── train_kd_5runs.py ├── figures ├── isrd.png └── introduction.png ├── configs ├── base │ └── cub │ │ ├── cub_resnet_single_112.yaml │ │ ├── cub_resnet_single_56.yaml │ │ ├── cub_resnet_single_224.yaml │ │ ├── cub_vit_single_112.yaml │ │ ├── cub_vit_single_56.yaml │ │ └── cub_vit_single_224.yaml └── ts │ ├── cub │ ├── cub_resnet50_resnet_isrd.yaml │ └── cub_resnet50_resnet_kd.yaml │ └── aircraft │ ├── aircraft_resnet50_resnet_isrd.yaml │ └── aircraft_resnet50_resnet_kd.yaml ├── LICENSE ├── environment.yaml ├── README.md └── tools_0_base ├── train_teacher_1run.py └── train_student_5runs.py /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools_1_ts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/isrd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/figures/isrd.png -------------------------------------------------------------------------------- /figures/introduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/figures/introduction.png -------------------------------------------------------------------------------- /lib/__pycache__/utils.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /lib/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lib/core/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/core/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/hrir.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/models/__pycache__/hrir.cpython-39.pyc -------------------------------------------------------------------------------- /lib/core/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/core/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/cub.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/datasets/__pycache__/cub.cpython-39.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/resnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/models/__pycache__/resnet.cpython-39.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/config/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/default.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/config/__pycache__/default.cpython-39.pyc -------------------------------------------------------------------------------- /lib/losses/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/losses/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lib/losses/__pycache__/at_loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/losses/__pycache__/at_loss.cpython-39.pyc -------------------------------------------------------------------------------- /lib/losses/__pycache__/dkd_loss.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/losses/__pycache__/dkd_loss.cpython-39.pyc -------------------------------------------------------------------------------- /lib/losses/__pycache__/ickd_loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/losses/__pycache__/ickd_loss.cpython-39.pyc -------------------------------------------------------------------------------- /lib/losses/__pycache__/kd_loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/losses/__pycache__/kd_loss.cpython-39.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/vit_pixel.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/models/__pycache__/vit_pixel.cpython-39.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lib/datasets/__pycache__/cub_2scale.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/datasets/__pycache__/cub_2scale.cpython-39.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/model_builder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/models/__pycache__/model_builder.cpython-39.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/shufflenetv2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/models/__pycache__/shufflenetv2.cpython-39.pyc -------------------------------------------------------------------------------- /lib/losses/__pycache__/similarity_loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyguo/PixelDistillation/HEAD/lib/losses/__pycache__/similarity_loss.cpython-39.pyc -------------------------------------------------------------------------------- /lib/models/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | https://github.com/pytorch/vision/blob/master/torchvision/models/utils.py 5 | """ 6 | 7 | try: 8 | from torch.hub import load_state_dict_from_url 9 | except ImportError: 10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 11 | 
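The file above is only a compatibility shim that re-exports `load_state_dict_from_url` across PyTorch versions. As a minimal, hypothetical usage sketch (editorial illustration, not part of the repository; the checkpoint URL is assumed to be torchvision's published ResNet-18 weight file), the shim can be used to pull pretrained backbone weights like this:

```python
import torchvision.models as tvm

try:
    from torch.hub import load_state_dict_from_url
except ImportError:  # older PyTorch releases expose it under model_zoo
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Build a ResNet-18 without weights and load an ImageNet checkpoint through
# the same helper that lib/models/utils.py exposes.
model = tvm.resnet18(num_classes=1000)
url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"  # assumed torchvision checkpoint URL
state_dict = load_state_dict_from_url(url, progress=True)
model.load_state_dict(state_dict)
print("loaded", len(state_dict), "parameter tensors")
```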
-------------------------------------------------------------------------------- /configs/base/cub/cub_resnet_single_112.yaml: -------------------------------------------------------------------------------- 1 | BASIC: 2 | GPU_ID: [0] 3 | NUM_WORKERS: 10 4 | BACKUP_CODES: True 5 | BACKUP_LIST: ['lib', 'tools_base', 'configs'] 6 | 7 | MODEL: 8 | TYPE: '5runs' 9 | ARCH: 'resnet18' 10 | PRETRAIN: True 11 | 12 | DATA: 13 | DATASET: cub 14 | DATADIR: data/cub 15 | NUM_CLASSES: 200 16 | RESIZE_SIZE: 128 17 | CROP_SIZE: 112 18 | 19 | TRAIN: 20 | BATCH_SIZE: 64 21 | 22 | TEST: 23 | BATCH_SIZE: 64 24 | 25 | SOLVER: 26 | START_LR: 0.01 27 | LR_STEPS: [ 30, 60, 90 ] 28 | NUM_EPOCHS: 120 29 | LR_DECAY_FACTOR: 0.1 30 | MUMENTUM: 0.9 31 | WEIGHT_DECAY: 0.0005 32 | -------------------------------------------------------------------------------- /configs/base/cub/cub_resnet_single_56.yaml: -------------------------------------------------------------------------------- 1 | BASIC: 2 | GPU_ID: [0] 3 | NUM_WORKERS: 10 4 | BACKUP_CODES: True 5 | BACKUP_LIST: ['lib', 'tools_0_base', 'configs'] 6 | 7 | MODEL: 8 | TYPE: '5runs' 9 | ARCH: 'resnet18' 10 | PRETRAIN: True 11 | 12 | DATA: 13 | DATASET: cub 14 | DATADIR: data/cub 15 | NUM_CLASSES: 200 16 | RESIZE_SIZE: 64 17 | CROP_SIZE: 56 18 | 19 | TRAIN: 20 | BATCH_SIZE: 64 21 | 22 | TEST: 23 | BATCH_SIZE: 64 24 | 25 | SOLVER: 26 | START_LR: 0.01 27 | LR_STEPS: [ 30, 60, 90 ] 28 | NUM_EPOCHS: 120 29 | LR_DECAY_FACTOR: 0.1 30 | MUMENTUM: 0.9 31 | WEIGHT_DECAY: 0.0005 32 | -------------------------------------------------------------------------------- /configs/base/cub/cub_resnet_single_224.yaml: -------------------------------------------------------------------------------- 1 | BASIC: 2 | GPU_ID: [0] 3 | NUM_WORKERS: 10 4 | BACKUP_CODES: True 5 | BACKUP_LIST: ['lib', 'tools_0_base', 'configs'] 6 | 7 | MODEL: 8 | TYPE: '1runs' 9 | ARCH: 'resnet50' 10 | PRETRAIN: True 11 | 12 | DATA: 13 | DATASET: cub 14 | DATADIR: data/cub 15 | NUM_CLASSES: 200 16 | RESIZE_SIZE: 256 17 | CROP_SIZE: 224 18 | 19 | TRAIN: 20 | BATCH_SIZE: 64 21 | 22 | TEST: 23 | BATCH_SIZE: 64 24 | 25 | SOLVER: 26 | START_LR: 0.01 27 | LR_STEPS: [ 30, 60, 90 ] 28 | NUM_EPOCHS: 120 29 | LR_DECAY_FACTOR: 0.1 30 | MUMENTUM: 0.9 31 | WEIGHT_DECAY: 0.0005 32 | -------------------------------------------------------------------------------- /configs/base/cub/cub_vit_single_112.yaml: -------------------------------------------------------------------------------- 1 | BASIC: 2 | SEED: 0 3 | GPU_ID: [0] 4 | NUM_WORKERS: 10 5 | BACKUP_CODES: True 6 | BACKUP_LIST: ['lib', 'tools_0_base', 'configs'] 7 | 8 | MODEL: 9 | TYPE: '5runs' 10 | ARCH: 'vit_tiny_patch16_112' 11 | PRETRAIN: True 12 | 13 | DATA: 14 | DATASET: cub 15 | DATADIR: data/cub 16 | NUM_CLASSES: 200 17 | RESIZE_SIZE: 128 18 | CROP_SIZE: 112 19 | 20 | TRAIN: 21 | BATCH_SIZE: 64 22 | 23 | TEST: 24 | BATCH_SIZE: 64 25 | 26 | SOLVER: 27 | START_LR: 0.001 28 | LR_STEPS: [30, 60] 29 | NUM_EPOCHS: 90 30 | LR_DECAY_FACTOR: 0.1 31 | MUMENTUM: 0.9 32 | WEIGHT_DECAY: 0.0005 33 | -------------------------------------------------------------------------------- /configs/base/cub/cub_vit_single_56.yaml: -------------------------------------------------------------------------------- 1 | BASIC: 2 | SEED: 0 3 | GPU_ID: [0] 4 | NUM_WORKERS: 10 5 | BACKUP_CODES: True 6 | BACKUP_LIST: ['lib', 'tools_0_base', 'configs'] 7 | 8 | MODEL: 9 | TYPE: '5runs' 10 | ARCH: 'vit_tiny_patch16_56' 11 | PRETRAIN: True 12 | 13 | DATA: 14 | DATASET: cub 15 | DATADIR: data/cub 16 | 
NUM_CLASSES: 200 17 | RESIZE_SIZE: 64 18 | CROP_SIZE: 56 19 | 20 | TRAIN: 21 | BATCH_SIZE: 64 22 | 23 | TEST: 24 | BATCH_SIZE: 64 25 | 26 | SOLVER: 27 | START_LR: 0.001 28 | LR_STEPS: [30, 60] 29 | NUM_EPOCHS: 90 30 | LR_DECAY_FACTOR: 0.1 31 | MUMENTUM: 0.9 32 | WEIGHT_DECAY: 0.0005 33 | -------------------------------------------------------------------------------- /configs/base/cub/cub_vit_single_224.yaml: -------------------------------------------------------------------------------- 1 | BASIC: 2 | SEED: 0 3 | GPU_ID: [0] 4 | NUM_WORKERS: 10 5 | BACKUP_CODES: True 6 | BACKUP_LIST: ['lib', 'tools_0_base', 'configs'] 7 | 8 | MODEL: 9 | TYPE: '1runs' 10 | ARCH: 'vit_base_patch16_224' 11 | PRETRAIN: True 12 | 13 | DATA: 14 | DATASET: cub 15 | DATADIR: data/cub 16 | NUM_CLASSES: 200 17 | RESIZE_SIZE: 256 18 | CROP_SIZE: 224 19 | 20 | TRAIN: 21 | BATCH_SIZE: 64 22 | 23 | TEST: 24 | BATCH_SIZE: 64 25 | 26 | SOLVER: 27 | START_LR: 0.001 28 | LR_STEPS: [30, 60] 29 | NUM_EPOCHS: 90 30 | LR_DECAY_FACTOR: 0.1 31 | MUMENTUM: 0.9 32 | WEIGHT_DECAY: 0.0005 33 | -------------------------------------------------------------------------------- /lib/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .kd_loss import KDLoss 2 | from .similarity_loss import SPLoss 3 | from .at_loss import ATLoss 4 | from .ickd_loss import ICKDLoss 5 | from .dkd_loss import DKDLoss 6 | 7 | __factory = { 8 | 'kd': KDLoss, 9 | 'sp': SPLoss, 10 | 'at': ATLoss, 11 | 'ickd': ICKDLoss, 12 | 'dkd': DKDLoss, 13 | } 14 | 15 | 16 | def names(): 17 | return sorted(__factory.keys()) 18 | 19 | 20 | def build_criterion(cfg): 21 | """ 22 | Create a distillation loss criterion instance. 23 | Parameters 24 | ---------- 25 | cfg : AttrDict 26 | Global config; cfg.MODEL.KDTYPE selects the loss and 27 | must be one of the keys in __factory 28 | ('kd', 'sp', 'at', 'ickd', 'dkd').
29 | """ 30 | name = cfg.MODEL.KDTYPE 31 | if name not in __factory: 32 | raise NotImplementedError('The method does not have its own loss calculation method.') 33 | return __factory[name](cfg) 34 | -------------------------------------------------------------------------------- /configs/ts/cub/cub_resnet50_resnet_isrd.yaml: -------------------------------------------------------------------------------- 1 | BASIC: 2 | GPU_ID: [0] 3 | NUM_WORKERS: 10 4 | BACKUP_CODES: True 5 | BACKUP_LIST: ['lib', 'tools_1_ts', 'configs'] 6 | 7 | MODEL: 8 | TYPE: '5runs' 9 | KDTYPE: 'kd' 10 | ARCH_T: 'resnet50' 11 | MODELDICT_T: 'ckpt/cub/1runs_resnet50_224_seed0_0.01/ckpt/model_best.pth' 12 | ARCH_S: 'resnet18' 13 | PRETRAIN_S: True 14 | 15 | FSR: 16 | ETA: 50.0 # 10 then LESSEN_RATIO=2 17 | POSITION: 0 18 | 19 | KD: 20 | TEMP: 4 21 | ALPHA: 0.9 22 | 23 | DATA: 24 | DATASET: cub 25 | DATADIR: data/cub 26 | NUM_CLASSES: 200 27 | RESIZE_SIZE: 256 28 | CROP_SIZE: 224 29 | LESSEN_RATIO: 4.0 30 | LESSEN_TYPE: 2 31 | 32 | 33 | TRAIN: 34 | BATCH_SIZE: 64 35 | 36 | TEST: 37 | BATCH_SIZE: 64 38 | 39 | SOLVER: 40 | START_LR: 0.01 41 | LR_STEPS: [ 30, 60, 90 ] 42 | NUM_EPOCHS: 120 43 | LR_DECAY_FACTOR: 0.1 44 | MUMENTUM: 0.9 45 | WEIGHT_DECAY: 0.0005 46 | -------------------------------------------------------------------------------- /configs/ts/aircraft/aircraft_resnet50_resnet_isrd.yaml: -------------------------------------------------------------------------------- 1 | BASIC: 2 | GPU_ID: [0] 3 | NUM_WORKERS: 10 4 | BACKUP_CODES: True 5 | BACKUP_LIST: ['lib', 'tools_1_ts', 'configs'] 6 | 7 | MODEL: 8 | TYPE: '5runs' 9 | KDTYPE: 'kd' 10 | ARCH_T: 'resnet50' 11 | MODELDICT_T: 'ckpt/aircraft/1runs_resnet50_224_seed0_0.01/ckpt/model_best.pth' 12 | ARCH_S: 'resnet18' 13 | PRETRAIN_S: True 14 | 15 | KD: 16 | TEMP: 4 17 | ALPHA: 0.9 18 | 19 | 20 | FSR: 21 | ETA: 20.0 22 | POSITION: 0 23 | 24 | 25 | DATA: 26 | DATASET: aircraft 27 | DATADIR: data/aircraft 28 | NUM_CLASSES: 100 29 | RESIZE_SIZE: 256 30 | CROP_SIZE: 224 31 | LESSEN_RATIO: 4.0 32 | LESSEN_TYPE: 2 33 | 34 | 35 | TRAIN: 36 | BATCH_SIZE: 64 37 | 38 | TEST: 39 | BATCH_SIZE: 64 40 | 41 | SOLVER: 42 | START_LR: 0.01 43 | LR_STEPS: [ 30, 60, 90 ] 44 | NUM_EPOCHS: 120 45 | LR_DECAY_FACTOR: 0.1 46 | MUMENTUM: 0.9 47 | WEIGHT_DECAY: 0.0005 48 | -------------------------------------------------------------------------------- /lib/losses/icfup_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | feature distillation loss in input compression stage, using upsampled feature maps 3 | """ 4 | from __future__ import print_function 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ICFUPLoss(nn.Module): 11 | def __init__(self, cfg): 12 | super(ICFUPLoss, self).__init__() 13 | self.mse_criterion = nn.MSELoss() 14 | 15 | def forward(self, g_s, g_t): 16 | return [self.icfup_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)] 17 | 18 | def icfup_loss(self, f_s, f_t): 19 | s_H, t_H = f_s.shape[2], f_t.shape[2] 20 | f_s = F.interpolate(f_s, size=(t_H, t_H), mode='bilinear', align_corners=True) 21 | return self.mse_criterion(f_s, f_t) 22 | 23 | 24 | 25 | if __name__ == "__main__": 26 | import torch 27 | feat1 = torch.ones(2, 512, 4, 4) 28 | feat2 = torch.ones(2, 512, 28, 28) 29 | 30 | icfup_loss = ICFUPLoss(None) 31 | 32 | loss = icfup_loss([feat1], [feat2]) 33 | 34 | -------------------------------------------------------------------------------- /lib/losses/at_loss.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class ATLoss(nn.Module): 8 | """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks 9 | via Attention Transfer 10 | code: https://github.com/szagoruyko/attention-transfer""" 11 | def __init__(self, cfg): 12 | super(ATLoss, self).__init__() 13 | self.p = 2 14 | 15 | def forward(self, g_s, g_t): 16 | return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)] 17 | 18 | def at_loss(self, f_s, f_t): 19 | s_H, t_H = f_s.shape[2], f_t.shape[2] 20 | if s_H > t_H: 21 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) 22 | elif s_H < t_H: 23 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) 24 | else: 25 | pass 26 | return (self.at(f_s) - self.at(f_t)).pow(2).mean() 27 | 28 | def at(self, f): 29 | return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1)) 30 | -------------------------------------------------------------------------------- /lib/losses/similarity_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class SPLoss(nn.Module): 9 | """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author""" 10 | def __init__(self, cfg): 11 | super(SPLoss, self).__init__() 12 | 13 | def forward(self, g_s, g_t): 14 | return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)] 15 | 16 | def similarity_loss(self, f_s, f_t): 17 | bsz = f_s.shape[0] 18 | f_s = f_s.view(bsz, -1) 19 | f_t = f_t.view(bsz, -1) 20 | 21 | G_s = torch.mm(f_s, torch.t(f_s)) 22 | # G_s = G_s / G_s.norm(2) 23 | G_s = torch.nn.functional.normalize(G_s) 24 | G_t = torch.mm(f_t, torch.t(f_t)) 25 | # G_t = G_t / G_t.norm(2) 26 | G_t = torch.nn.functional.normalize(G_t) 27 | 28 | G_diff = G_t - G_s 29 | loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz) 30 | return loss 31 | -------------------------------------------------------------------------------- /lib/losses/kd_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DistillKL(torch.nn.Module): 7 | """Distilling the Knowledge in a Neural Network""" 8 | def __init__(self, temp): 9 | super(DistillKL, self).__init__() 10 | self.temp = temp 11 | 12 | def forward(self, y_s, y_t): 13 | p_s = F.log_softmax(y_s/self.temp, dim=1) 14 | p_t = F.softmax(y_t/self.temp, dim=1) 15 | loss = F.kl_div(p_s, p_t, size_average=False) * (self.temp**2) / y_s.shape[0] 16 | return loss 17 | 18 | 19 | class KDLoss(nn.Module): 20 | def __init__(self, cfg): 21 | super(KDLoss, self).__init__() 22 | self.alpha = cfg.KD.ALPHA 23 | self.DistillKL = DistillKL(cfg.KD.TEMP) 24 | self.cls_criterion = torch.nn.CrossEntropyLoss() 25 | 26 | def forward(self, output_s, output_t, target): 27 | cls_loss = self.cls_criterion(output_s, target) 28 | kd_loss = self.DistillKL(output_s, output_t) 29 | loss = (1 - self.alpha) * cls_loss + self.alpha * kd_loss 30 | return loss -------------------------------------------------------------------------------- /configs/ts/cub/cub_resnet50_resnet_kd.yaml: -------------------------------------------------------------------------------- 1 | BASIC: 2 | GPU_ID: [0] 3 | NUM_WORKERS: 10 4 | BACKUP_CODES: True 5 
| BACKUP_LIST: ['lib', 'tools_1_ts', 'configs'] 6 | 7 | MODEL: 8 | TYPE: '5runs' 9 | KDTYPE: 'kd' # 'at' 'dkd' 'ickd' 'sp' 10 | ARCH_T: 'resnet50' 11 | MODELDICT_T: 'ckpt/cub/1runs_resnet50_224_seed0_0.01/ckpt/model_best.pth' 12 | ARCH_S: 'resnet18' 13 | PRETRAIN_S: True 14 | 15 | KD: 16 | TEMP: 4 17 | ALPHA: 0.9 18 | 19 | ICKD: 20 | BETA: 2.5 21 | 22 | AT: 23 | BETA: 1000.0 24 | 25 | 26 | SP: 27 | BETA: 3000.0 28 | 29 | DKD: 30 | ALPHA: 1.0 31 | BETA: 2.0 32 | TEMP: 4 33 | WARMUP: 20 34 | 35 | DATA: 36 | DATASET: cub 37 | DATADIR: data/cub 38 | NUM_CLASSES: 200 39 | RESIZE_SIZE: 256 40 | CROP_SIZE: 224 41 | LESSEN_RATIO: 1.0 42 | LESSEN_TYPE: 2 43 | 44 | 45 | TRAIN: 46 | BATCH_SIZE: 64 47 | 48 | TEST: 49 | BATCH_SIZE: 64 50 | 51 | SOLVER: 52 | START_LR: 0.01 53 | LR_STEPS: [ 30, 60, 90 ] 54 | NUM_EPOCHS: 120 55 | LR_DECAY_FACTOR: 0.1 56 | MUMENTUM: 0.9 57 | WEIGHT_DECAY: 0.0005 58 | -------------------------------------------------------------------------------- /configs/ts/aircraft/aircraft_resnet50_resnet_kd.yaml: -------------------------------------------------------------------------------- 1 | BASIC: 2 | GPU_ID: [0] 3 | NUM_WORKERS: 10 4 | BACKUP_CODES: True 5 | BACKUP_LIST: ['lib', 'tools_1_ts', 'configs'] 6 | 7 | 8 | MODEL: 9 | TYPE: '5runs' 10 | KDTYPE: 'kd' # 'at' 'dkd' 'ickd' 'sp' 11 | ARCH_T: 'resnet50' 12 | MODELDICT_T: 'ckpt/aircraft/1runs_resnet50_224_seed0_0.01/ckpt/model_best.pth' 13 | ARCH_S: 'resnet18' 14 | PRETRAIN_S: True 15 | 16 | KD: 17 | TEMP: 4 18 | ALPHA: 0.9 19 | 20 | ICKD: 21 | BETA: 2.5 22 | 23 | AT: 24 | BETA: 1000.0 25 | 26 | 27 | SP: 28 | BETA: 3000.0 29 | 30 | DKD: 31 | ALPHA: 1.0 32 | BETA: 2.0 33 | TEMP: 4 34 | WARMUP: 20 35 | 36 | DATA: 37 | DATASET: aircraft 38 | DATADIR: data/aircraft 39 | NUM_CLASSES: 100 40 | RESIZE_SIZE: 256 41 | CROP_SIZE: 224 42 | LESSEN_RATIO: 4.0 43 | LESSEN_TYPE: 2 44 | 45 | TRAIN: 46 | BATCH_SIZE: 64 47 | 48 | TEST: 49 | BATCH_SIZE: 64 50 | 51 | SOLVER: 52 | START_LR: 0.01 53 | LR_STEPS: [ 30, 60, 90 ] 54 | NUM_EPOCHS: 120 55 | LR_DECAY_FACTOR: 0.1 56 | MUMENTUM: 0.9 57 | WEIGHT_DECAY: 0.0005 58 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Guangyu Guo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /lib/losses/ickd_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class Normalize(nn.Module): 9 | """normalization layer""" 10 | def __init__(self, power=2, dim = 1): 11 | super(Normalize, self).__init__() 12 | self.power = power 13 | self.dim = dim 14 | 15 | def forward(self, x): 16 | norm = x.pow(self.power).sum(self.dim, keepdim=True).pow(1. / self.power) 17 | out = x.div(norm) 18 | return out 19 | 20 | 21 | class Embed(nn.Module): 22 | def __init__(self, dim_in=256, dim_out=128): 23 | super(Embed, self).__init__() 24 | self.conv2d = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) 25 | self.l2norm = nn.BatchNorm2d(dim_out)#Normalize(2) 26 | 27 | def forward(self, x): 28 | x = self.conv2d(x) 29 | x = self.l2norm(x) 30 | return x 31 | 32 | 33 | class ICKDLoss(nn.Module): 34 | """Inter-Channel Correlation""" 35 | def __init__(self, cfg): 36 | super(ICKDLoss, self).__init__() 37 | self.embed_s = Embed(cfg.ICKD.FEATDIM_S[1], cfg.ICKD.FEATDIM_T[1]) 38 | self.embed_t = Embed(cfg.ICKD.FEATDIM_S[1], cfg.ICKD.FEATDIM_T[1]) 39 | 40 | def forward(self, g_s, g_t): 41 | loss = [self.batch_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)] 42 | return loss 43 | 44 | def batch_loss(self, f_s, f_t): 45 | f_s = self.embed_s(f_s) 46 | bsz, ch = f_s.shape[0], f_s.shape[1] 47 | 48 | f_s = f_s.view(bsz, ch, -1) 49 | f_t = f_t.view(bsz, ch, -1) 50 | 51 | emd_s = torch.bmm(f_s, f_s.permute(0,2,1)) 52 | emd_s = torch.nn.functional.normalize(emd_s, dim = 2) 53 | 54 | emd_t = torch.bmm(f_t, f_t.permute(0,2,1)) 55 | emd_t = torch.nn.functional.normalize(emd_t, dim = 2) 56 | 57 | G_diff = emd_s - emd_t 58 | loss = (G_diff * G_diff).view(bsz, -1).sum() / (ch*bsz) 59 | return loss 60 | 61 | 62 | if __name__ == "__main__": 63 | x1 = torch.randn(2, 64, 112, 112) 64 | x2 = torch.randn(2, 32, 64, 64) 65 | s_dim = x1.shape[1] 66 | feat_dim = x2.shape[1] 67 | kd = ICKDLoss(s_dim, feat_dim) 68 | kd_loss = kd([x1], [x2]) 69 | print(kd_loss) -------------------------------------------------------------------------------- /lib/models/hrir.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SR1x1(nn.Module): 7 | 8 | def __init__(self, cfg, feat_size): 9 | super(SR1x1, self).__init__() 10 | 11 | self.image_size = cfg.DATA.CROP_SIZE 12 | self.in_size = feat_size[2] 13 | if self.image_size % self.in_size == 0: 14 | self.scale_factor = int(self.image_size / self.in_size) 15 | else: 16 | self.scale_factor = int(self.image_size / self.in_size) +1 17 | 18 | self.outplanes = self.scale_factor ** 2 * 3 19 | self.inplanes = feat_size[1] 20 | 21 | self.conv = nn.Conv2d(self.inplanes, self.outplanes, kernel_size=1, stride=1, padding=0, bias=False) 22 | self.pixel_shuffle = nn.PixelShuffle(self.scale_factor) 23 | self.prelu = nn.PReLU(3) 24 | 25 | def forward(self, feat_s): 26 | 27 | feat = self.conv(feat_s) 28 | 29 | image_sr = self.pixel_shuffle(feat) 30 | 31 | image_sr = self.prelu(image_sr) 32 | 33 | if self.image_size % self.in_size == 0: 34 | return image_sr 35 | else: 36 | return image_sr[:, :, 0:self.image_size, 0:self.image_size] 37 | 38 | if __name__ == "__main__": 39 | import os, sys 40 | 41 | sys.path.insert(0, 
'../../lib') 42 | 43 | import argparse 44 | from config.default import update_config 45 | from config.default import config as cfg 46 | 47 | parser = argparse.ArgumentParser(description='knowledge distillation') 48 | parser.add_argument('--config_file', type=str, default='../../configs/kd/cub/cub_resnet50_resnet_pixel.yaml', 49 | required=False, help='Optional config file for params') 50 | parser.add_argument('opts', help='see config.py for all options', 51 | default='BASIC.GPU_ID [0]', nargs=argparse.REMAINDER) 52 | args = parser.parse_args() 53 | update_config(args) 54 | 55 | cfg.DATA.CROP_SIZE = 224 56 | cfg.DATA.LESSEN_RATIO = 4.0 57 | input_s = torch.randn(64, 3, 56, 56) 58 | feat_s = torch.randn(64, 64, 28, 28) 59 | feat_t = torch.randn(64, 64, 112, 112) 60 | weight_t = torch.randn(64, 3, 7, 7) 61 | 62 | model = SR1x1(cfg, list(feat_s.shape)) 63 | image_sr = model(feat_s) 64 | print(feat_s.shape) 65 | print(image_sr.shape) -------------------------------------------------------------------------------- /lib/core/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import shutil 4 | import torch 5 | from sklearn.metrics import average_precision_score 6 | 7 | 8 | def str_gpus(ids): 9 | str_ids = '' 10 | for id in ids: 11 | str_ids = str_ids + str(id) 12 | str_ids = str_ids + ',' 13 | 14 | return str_ids 15 | 16 | 17 | def map_sklearn(labels, results): 18 | map = average_precision_score(labels, results, average="micro") 19 | return map 20 | 21 | 22 | def adjust_learning_rate(optimizer, epoch, cfg): 23 | """Sets the learning rate to the initial LR decayed by LR_DECAY_FACTOR""" 24 | lr_decay = cfg.SOLVER.LR_DECAY_FACTOR**(sum(epoch > np.array(cfg.SOLVER.LR_STEPS))) 25 | lr = cfg.SOLVER.START_LR * lr_decay 26 | for param_group in optimizer.param_groups: 27 | param_group['lr'] = lr * param_group['lr_mult'] 28 | 29 | 30 | def save_checkpoint(state, save_dir, epoch, is_best): 31 | filename = os.path.join(save_dir, 'ckpt_'+str(epoch)+'.pth.tar') 32 | torch.save(state, filename) 33 | if is_best: 34 | best_name = os.path.join(save_dir, 'model_best.pth.tar') 35 | shutil.copyfile(filename, best_name) 36 | 37 | 38 | class AverageMeter(object): 39 | """Computes and stores the average and current value""" 40 | def __init__(self): 41 | self.reset() 42 | 43 | def reset(self): 44 | self.val = 0 45 | self.avg = 0 46 | self.sum = 0 47 | self.count = 0 48 | 49 | def update(self, val, n=1): 50 | self.val = val 51 | self.sum += val * n 52 | self.count += n 53 | self.avg = self.sum / self.count 54 | 55 | 56 | def accuracy(output, target, topk=(1,)): 57 | """Computes the precision@k for the specified values of k""" 58 | maxk = max(topk) 59 | batch_size = target.size(0) 60 | 61 | _, pred = output.topk(maxk, 1, True, True) 62 | pred = pred.t() 63 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 64 | 65 | res = [] 66 | for k in topk: 67 | correct_k = correct[:k].reshape(-1).float().sum(0) 68 | res.append(correct_k.mul_(100.0 / batch_size)) 69 | return res 70 | 71 | 72 | def list2acc(results_list): 73 | """ 74 | :param results_list: list contains 0 and 1 75 | :return: accuracy 76 | """ 77 | accuracy = results_list.count(1)/len(results_list) 78 | return accuracy -------------------------------------------------------------------------------- /lib/losses/dkd_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def 
dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature): 7 | gt_mask = _get_gt_mask(logits_student, target) 8 | other_mask = _get_other_mask(logits_student, target) 9 | pred_student = F.softmax(logits_student / temperature, dim=1) 10 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1) 11 | pred_student = cat_mask(pred_student, gt_mask, other_mask) 12 | pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask) 13 | log_pred_student = torch.log(pred_student) 14 | tckd_loss = ( 15 | F.kl_div(log_pred_student, pred_teacher, size_average=False) 16 | * (temperature**2) 17 | / target.shape[0] 18 | ) 19 | pred_teacher_part2 = F.softmax( 20 | logits_teacher / temperature - 1000.0 * gt_mask, dim=1 21 | ) 22 | log_pred_student_part2 = F.log_softmax( 23 | logits_student / temperature - 1000.0 * gt_mask, dim=1 24 | ) 25 | nckd_loss = ( 26 | F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False) 27 | * (temperature**2) 28 | / target.shape[0] 29 | ) 30 | return alpha * tckd_loss + beta * nckd_loss 31 | 32 | 33 | def _get_gt_mask(logits, target): 34 | target = target.reshape(-1) 35 | mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool() 36 | return mask 37 | 38 | 39 | def _get_other_mask(logits, target): 40 | target = target.reshape(-1) 41 | mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool() 42 | return mask 43 | 44 | 45 | def cat_mask(t, mask1, mask2): 46 | t1 = (t * mask1).sum(dim=1, keepdims=True) 47 | t2 = (t * mask2).sum(1, keepdims=True) 48 | rt = torch.cat([t1, t2], dim=1) 49 | return rt 50 | 51 | 52 | class DKDLoss(nn.Module): 53 | """Decoupled Knowledge Distillation(CVPR 2022)""" 54 | 55 | def __init__(self, cfg): 56 | super(DKDLoss, self).__init__() 57 | self.alpha = cfg.DKD.ALPHA 58 | self.beta = cfg.DKD.BETA 59 | self.temperature = cfg.DKD.TEMP 60 | self.warmup = cfg.DKD.WARMUP 61 | 62 | def forward(self, logits_student, logits_teacher, target, epoch): 63 | loss_dkd = min(epoch / self.warmup, 1.0) * dkd_loss( 64 | logits_student, 65 | logits_teacher, 66 | target, 67 | self.alpha, 68 | self.beta, 69 | self.temperature, 70 | ) 71 | return loss_dkd -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: PixelDistill 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8=h5eee18b_5 10 | - ca-certificates=2024.3.11=h06a4308_0 11 | - certifi=2024.2.2=py39h06a4308_0 12 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 13 | - cudatoolkit=11.3.1=h2bc3f7f_2 14 | - ffmpeg=4.3=hf484d3e_0 15 | - freetype=2.12.1=h4a9f257_0 16 | - gmp=6.2.1=h295c915_3 17 | - gnutls=3.6.15=he1e5248_0 18 | - idna=3.4=py39h06a4308_0 19 | - intel-openmp=2023.1.0=hdb19cb5_46306 20 | - jpeg=9e=h5eee18b_1 21 | - lame=3.100=h7b6447c_0 22 | - lcms2=2.12=h3be6417_0 23 | - ld_impl_linux-64=2.38=h1181459_1 24 | - lerc=3.0=h295c915_0 25 | - libdeflate=1.17=h5eee18b_1 26 | - libffi=3.4.4=h6a678d5_0 27 | - libgcc-ng=11.2.0=h1234567_1 28 | - libgomp=11.2.0=h1234567_1 29 | - libiconv=1.16=h7f8727e_2 30 | - libidn2=2.3.4=h5eee18b_0 31 | - libpng=1.6.39=h5eee18b_0 32 | - libstdcxx-ng=11.2.0=h1234567_1 33 | - libtasn1=4.19.0=h5eee18b_0 34 | - libtiff=4.5.1=h6a678d5_0 35 | - libunistring=0.9.10=h27cfd23_0 36 | - libwebp-base=1.3.2=h5eee18b_0 37 | - lz4-c=1.9.4=h6a678d5_0 38 | - mkl=2023.1.0=h213fc3f_46344 39 | - 
mkl-service=2.4.0=py39h5eee18b_1 40 | - mkl_fft=1.3.8=py39h5eee18b_0 41 | - mkl_random=1.2.4=py39hdb19cb5_0 42 | - ncurses=6.4=h6a678d5_0 43 | - nettle=3.7.3=hbbd107a_1 44 | - numpy=1.26.4=py39h5f9d8c6_0 45 | - numpy-base=1.26.4=py39hb5e798b_0 46 | - openh264=2.1.1=h4ff587b_0 47 | - openjpeg=2.4.0=h3ad879b_0 48 | - openssl=3.0.13=h7f8727e_0 49 | - pillow=10.2.0=py39h5eee18b_0 50 | - pip=23.3.1=py39h06a4308_0 51 | - python=3.9.19=h955ad1f_0 52 | - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0 53 | - pytorch-mutex=1.0=cuda 54 | - readline=8.2=h5eee18b_0 55 | - requests=2.31.0=py39h06a4308_1 56 | - setuptools=68.2.2=py39h06a4308_0 57 | - sqlite=3.41.2=h5eee18b_0 58 | - tbb=2021.8.0=hdb19cb5_0 59 | - tk=8.6.12=h1ccaba5_0 60 | - torchvision=0.13.1=py39_cu113 61 | - typing_extensions=4.9.0=py39h06a4308_1 62 | - tzdata=2024a=h04d1e81_0 63 | - wheel=0.41.2=py39h06a4308_0 64 | - xz=5.4.6=h5eee18b_0 65 | - zlib=1.2.13=h5eee18b_0 66 | - zstd=1.5.5=hc292b87_0 67 | - pip: 68 | - absl-py==2.1.0 69 | - clip==1.0 70 | - ftfy==6.2.0 71 | - grpcio==1.64.1 72 | - importlib-metadata==8.0.0 73 | - joblib==1.4.2 74 | - markdown==3.6 75 | - markupsafe==2.1.5 76 | - protobuf==4.25.3 77 | - pyyaml==6.0.1 78 | - regex==2024.4.16 79 | - scikit-learn==1.5.1 80 | - scipy==1.13.0 81 | - six==1.16.0 82 | - tensorboard==2.17.0 83 | - tensorboard-data-server==0.7.2 84 | - threadpoolctl==3.5.0 85 | - timm==0.6.5 86 | - tqdm==4.66.2 87 | - urllib3==1.26.18 88 | - wcwidth==0.2.13 89 | - werkzeug==3.0.3 90 | - zipp==3.19.2 91 | prefix: /mnt/workspace/envs/Mine 92 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import sys 4 | import errno 5 | import numpy as np 6 | import torch 7 | import random 8 | import torch.backends.cudnn as cudnn 9 | 10 | 11 | def rm(path): 12 | try: 13 | shutil.rmtree(path) 14 | except OSError as e: 15 | if e.errno != errno.ENOENT: 16 | raise 17 | 18 | 19 | def mkdir(path): 20 | try: 21 | os.makedirs(path) 22 | except OSError as e: 23 | if e.errno != errno.EEXIST: 24 | raise 25 | 26 | 27 | class Logger(object): 28 | def __init__(self,filename="Default.log"): 29 | self.terminal = sys.stdout 30 | self.log = open(filename,'a') 31 | 32 | def write(self,message): 33 | self.terminal.write(message) 34 | self.log.write(message) 35 | 36 | def flush(self): 37 | pass 38 | 39 | 40 | def fix_random_seed(seed): 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed(seed) 43 | torch.cuda.manual_seed_all(seed) 44 | 45 | np.random.seed(seed) 46 | random.seed(seed) 47 | 48 | 49 | def fix_seed_all(cfg): 50 | # fix seed 51 | fix_random_seed(cfg.BASIC.SEED) 52 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 53 | cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 54 | cudnn.enabled = cfg.CUDNN.ENABLE 55 | 56 | 57 | def backup_codes(root_dir, res_dir, backup_list): 58 | if os.path.exists(res_dir): 59 | shutil.rmtree(res_dir) # delete 60 | os.makedirs(res_dir) 61 | for name in backup_list: 62 | shutil.copytree(os.path.join(root_dir, name), os.path.join(res_dir, name)) 63 | print('codes backup at {}'.format(os.path.join(res_dir, name))) 64 | 65 | 66 | def prepare_env_noseed(cfg): 67 | # backup codes 68 | if cfg.BASIC.BACKUP_CODES: 69 | backup_dir = os.path.join(cfg.BASIC.SAVE_DIR, 'backup') 70 | rm(backup_dir) 71 | backup_codes(cfg.BASIC.ROOT_DIR, backup_dir, cfg.BASIC.BACKUP_LIST) 72 | 73 | # create save directory 74 | cfg.BASIC.CKPT_DIR = os.path.join(cfg.BASIC.SAVE_DIR, 
'ckpt') 75 | mkdir(cfg.BASIC.CKPT_DIR) 76 | cfg.BASIC.LOG_DIR = os.path.join(cfg.BASIC.SAVE_DIR, 'log') 77 | mkdir(cfg.BASIC.LOG_DIR) 78 | cfg.BASIC.LOG_FILE = os.path.join(cfg.BASIC.SAVE_DIR, 'Log_' + cfg.BASIC.TIME + '.txt') 79 | 80 | def prepare_env(cfg): 81 | # fix random seed 82 | fix_random_seed(cfg.BASIC.SEED) 83 | # cudnn 84 | cudnn.benchmark = cfg.CUDNN.BENCHMARK # Benchmark will impove the speed 85 | cudnn.deterministic = cfg.CUDNN.DETERMINISTIC # 86 | cudnn.enabled = cfg.CUDNN.ENABLE # Enables benchmark mode in cudnn, to enable the inbuilt cudnn auto-tuner 87 | 88 | # backup codes 89 | if cfg.BASIC.BACKUP_CODES: 90 | backup_dir = os.path.join(cfg.BASIC.SAVE_DIR, 'backup') 91 | rm(backup_dir) 92 | backup_codes(cfg.BASIC.ROOT_DIR, backup_dir, cfg.BASIC.BACKUP_LIST) 93 | 94 | # create save directory 95 | cfg.BASIC.CKPT_DIR = os.path.join(cfg.BASIC.SAVE_DIR, 'ckpt') 96 | mkdir(cfg.BASIC.CKPT_DIR) 97 | cfg.BASIC.LOG_DIR = os.path.join(cfg.BASIC.SAVE_DIR, 'log') 98 | mkdir(cfg.BASIC.LOG_DIR) 99 | cfg.BASIC.LOG_FILE = os.path.join(cfg.BASIC.SAVE_DIR, 'Log_' + cfg.BASIC.TIME + '.txt') -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def creat_data_loader(cfg): 5 | print('==> Preparing data...') 6 | if cfg.DATA.DATASET == 'cub': 7 | from .cub import CUBDataset 8 | train_loader = torch.utils.data.DataLoader( 9 | CUBDataset(cfg=cfg, is_train=True), batch_size=cfg.TRAIN.BATCH_SIZE, 10 | shuffle=True, num_workers=cfg.BASIC.NUM_WORKERS, pin_memory=True) 11 | val_loader = torch.utils.data.DataLoader( 12 | CUBDataset(cfg=cfg, is_train=False), batch_size=cfg.TEST.BATCH_SIZE, 13 | shuffle=False, num_workers=cfg.BASIC.NUM_WORKERS, pin_memory=True) 14 | elif cfg.DATA.DATASET == 'aircraft': 15 | from .aircraft import AircraftDataset 16 | train_loader = torch.utils.data.DataLoader( 17 | AircraftDataset(cfg=cfg, is_train=True), batch_size=cfg.TRAIN.BATCH_SIZE, 18 | shuffle=True, num_workers=cfg.BASIC.NUM_WORKERS, pin_memory=True) 19 | val_loader = torch.utils.data.DataLoader( 20 | AircraftDataset(cfg=cfg, is_train=False), batch_size=cfg.TEST.BATCH_SIZE, 21 | shuffle=False, num_workers=cfg.BASIC.NUM_WORKERS, pin_memory=True) 22 | else: 23 | raise ValueError('Please set correct dataset.') 24 | 25 | print('done!') 26 | return train_loader, val_loader 27 | 28 | 29 | def creat_data_loader_2scale(cfg): 30 | print('==> Preparing data...') 31 | if cfg.DATA.DATASET == 'cub': 32 | from .cub_2scale import CUBDataset 33 | train_loader = torch.utils.data.DataLoader( 34 | CUBDataset(cfg=cfg, is_train=True), batch_size=cfg.TRAIN.BATCH_SIZE, 35 | shuffle=True, num_workers=cfg.BASIC.NUM_WORKERS, pin_memory=True) 36 | val_loader = torch.utils.data.DataLoader( 37 | CUBDataset(cfg=cfg, is_train=False), batch_size=cfg.TEST.BATCH_SIZE, 38 | shuffle=False, num_workers=cfg.BASIC.NUM_WORKERS, pin_memory=True) 39 | elif cfg.DATA.DATASET == 'aircraft': 40 | from .aircraft_2scale import AircraftDataset 41 | train_loader = torch.utils.data.DataLoader( 42 | AircraftDataset(cfg=cfg, is_train=True), batch_size=cfg.TRAIN.BATCH_SIZE, 43 | shuffle=True, num_workers=cfg.BASIC.NUM_WORKERS, pin_memory=True) 44 | val_loader = torch.utils.data.DataLoader( 45 | AircraftDataset(cfg=cfg, is_train=False), batch_size=cfg.TEST.BATCH_SIZE, 46 | shuffle=False, num_workers=cfg.BASIC.NUM_WORKERS, pin_memory=True) 47 | else: 48 | raise ValueError('Please set correct dataset.') 49 | 50 
| print('done!') 51 | return train_loader, val_loader 52 | 53 | 54 | def creat_data_loader_2scale_sr(cfg): 55 | print('==> Preparing data...') 56 | if cfg.DATA.DATASET == 'cub': 57 | from .cub_2scale_sr import CUBDataset 58 | train_loader = torch.utils.data.DataLoader( 59 | CUBDataset(cfg=cfg, is_train=True), batch_size=cfg.TRAIN.BATCH_SIZE, 60 | shuffle=True, num_workers=cfg.BASIC.NUM_WORKERS, pin_memory=True) 61 | val_loader = torch.utils.data.DataLoader( 62 | CUBDataset(cfg=cfg, is_train=False), batch_size=cfg.TEST.BATCH_SIZE, 63 | shuffle=False, num_workers=cfg.BASIC.NUM_WORKERS, pin_memory=True) 64 | else: 65 | raise ValueError('Please set correct dataset.') 66 | 67 | print('done!') 68 | return train_loader, val_loader 69 | 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## The official code for the paper ["Pixel Distillation: Cost-flexible Distillation across Image Sizes and Heterogeneous Networks", TPAMI 2024](https://ieeexplore.ieee.org/document/9437331) 2 | --- 3 | 4 | ## 1. INTRODUCTION 5 | Pixel Distillation is a cost-flexible distillation framework that accommodates diverse image sizes and heterogeneous network architectures, enabling adaptive and efficient cost reduction across varied settings. 6 | 7 | ![arch image](./figures/introduction.png) 8 | 9 | 10 | ### News 11 | The pixel distillation paradigm has also been applied to **Whole Slide Images Glomerulus Detection** to reduce scanner equipment cost. The paper **[MHKD: Multi-step Hybrid Knowledge Distillation for Low-resolution Whole Slide Images Glomerulus Detection](https://ieeexplore.ieee.org/document/10786212)** has been accepted by IEEE JBHI. 12 | 13 | --- 14 | 15 | ## 2. Training 16 | ```bash 17 | # List of packages 18 | python=3.9 19 | pytorch=1.12.1 20 | torchvision=0.13.1 21 | timm=0.6.5 22 | tensorboard=2.17.0 23 | pillow=10.2.0 24 | pyyaml=6.0.1 25 | scikit-learn=1.5.1 26 | ``` 27 | 28 | ### 2.1. Baseline (Teacher/Student) 29 | **Teacher networks** are trained only once, with a fixed seed of 0. For example, to train a ResNet50 teacher network on the CUB dataset: 30 | ```bash 31 | python tools_0_base/train_teacher_1run.py --config_file configs/base/cub/cub_resnet_single_224.yaml BASIC.SEED 0 BASIC.GPU_ID [0] 32 | ``` 33 | **Student networks** are trained 5 times. For example, to train a ViT-Tiny student network with input size 56 on the CUB dataset: 34 | ```bash 35 | python tools_0_base/train_student_5runs.py --config_file configs/base/cub/cub_vit_single_56.yaml MODEL.ARCH vit_tiny_patch16_56 BASIC.GPU_ID [0] 36 | ``` 37 | 38 | Note that for ViT models, changing the input resolution requires changing the corresponding model definition in `lib/models/vit.py`. 39 | 40 | All available models can be found in `lib/models/model_builder.py`. 41 | 42 | ### 2.2. One-stage Pixel Distillation (Teacher-Student) 43 | 44 | 45 | ![arch image](./figures/isrd.png) 46 | 47 | #### Previous KD methods 48 | 49 | **KD methods** are trained 5 times. 
For example, using **AT** to train resnet50-teacher and resnet18-student with 56 size (LESSEN_RATIO=224/56=4.0) student network on CUB dataset 50 | ```bash 51 | python tools_1_ts/train_kd_5runs.py --config_file configs/ts/cub/cub_resnet50_resnet_kd.yaml MODEL.KDTYPE 'at' MODEL.ARCH_T 'resnet50' MODEL.MODELDICT_T 'ckpt/cub/1runs_resnet50_224_seed0_0.01/ckpt/model_best.pth' MODEL.ARCH_S 'resnet18' DATA.LESSEN_RATIO 4.0 BASIC.GPU_ID [0] 52 | ``` 53 | 54 | 55 | #### Our method 56 | 57 | **Our methods** are trained 5 times. For example, to train resnet50-teacher and resnet18-student with 56 size student network on CUB dataset: LESSEN_RATIO=224/56=4.0, ETA=50 58 | ```bash 59 | python tools_1_ts/train_isrd_5runs.py --config_file configs/ts/cub/cub_resnet50_resnet_isrd.yaml MODEL.ARCH_T 'resnet50' MODEL.MODELDICT_T 'ckpt/cub/1runs_resnet50_224_seed0_0.01/ckpt/model_best.pth' MODEL.ARCH_S 'resnet18' DATA.LESSEN_RATIO 4.0 FSR.ETA 50.0 BASIC.GPU_ID [0] 60 | ``` 61 | 62 | 63 | ### Citation 64 | ``` 65 | @ARTICLE{guo2024pixel, 66 | author={Guo, Guangyu and Zhang, Dingwen and Han, Longfei and Liu, Nian and Cheng, Ming-Ming and Han, Junwei}, 67 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 68 | title={Pixel Distillation: Cost-flexible Distillation across Image Sizes and Heterogeneous Networks}, 69 | year={2024}, 70 | volume={}, 71 | number={}, 72 | pages={1-15}, 73 | doi={10.1109/TPAMI.2024.3421277}} 74 | ``` 75 | ``` 76 | @inproceedings{ 77 | zhang2024mhkd, 78 | title={{MHKD}: Multi-step Hybrid Knowledge Distillation for Low-resolution Whole Slide Images Glomerulus Detection}, 79 | author={Xiangsen Zhang and Longfei Han and Chenchu Xu and Zhaohui Zheng and Jin Ding and Xianghui Fu and Dingwen Zhang and Junwei Han}, 80 | booktitle={IEEE-EMBS International Conference on Biomedical and Health Informatics}, 81 | year={2024}, 82 | } 83 | 84 | ``` 85 | 86 | 87 | -------------------------------------------------------------------------------- /lib/datasets/aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as transforms 5 | 6 | 7 | def get_transforms(cfg): 8 | train_transform = transforms.Compose([ 9 | transforms.RandomResizedCrop(cfg.DATA.CROP_SIZE), 10 | transforms.RandomHorizontalFlip(), 11 | transforms.ToTensor(), 12 | transforms.Normalize(cfg.DATA.IMAGE_MEAN, cfg.DATA.IMAGE_STD) 13 | ]) 14 | test_transform = transforms.Compose([ 15 | transforms.Resize(cfg.DATA.RESIZE_SIZE), 16 | transforms.CenterCrop(cfg.DATA.CROP_SIZE), 17 | transforms.ToTensor(), 18 | transforms.Normalize(cfg.DATA.IMAGE_MEAN, cfg.DATA.IMAGE_STD) 19 | ]) 20 | return train_transform, test_transform 21 | 22 | 23 | class AircraftDataset(Dataset): 24 | def __init__(self, cfg, is_train): 25 | 26 | self.root = os.path.join(cfg.BASIC.ROOT_DIR, cfg.DATA.DATADIR) 27 | self.cfg = cfg 28 | self.is_train = is_train 29 | self.resize_size = cfg.DATA.RESIZE_SIZE 30 | self.crop_size = cfg.DATA.CROP_SIZE 31 | 32 | self.image_folder = os.path.join(self.root, "data", "images") 33 | train_transform, test_transform = get_transforms(cfg) 34 | # "train", "val", "trainval", "test" 35 | if is_train: 36 | self._split = "trainval" 37 | self.transform = train_transform 38 | else: 39 | self._split = "test" 40 | self.transform = test_transform 41 | 42 | self._annotation_level = "variant" # "variant", "family", "manufacturer" 43 | 44 | annotation_file = os.path.join( 45 | self.root, 
46 | "data", 47 | { 48 | "variant": "variants.txt", 49 | "family": "families.txt", 50 | "manufacturer": "manufacturers.txt", 51 | }[self._annotation_level], 52 | ) 53 | with open(annotation_file, "r") as f: 54 | self.classes = [line.strip() for line in f] 55 | 56 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) 57 | labels_file = os.path.join(self.root, "data", f"images_{self._annotation_level}_{self._split}.txt") 58 | 59 | self._image_files = [] 60 | self._labels = [] 61 | 62 | with open(labels_file, "r") as f: 63 | for line in f: 64 | image_name, label_name = line.strip().split(" ", 1) 65 | self._image_files.append(os.path.join(self.image_folder, f"{image_name}.jpg")) 66 | self._labels.append(self.class_to_idx[label_name]) 67 | 68 | def __len__(self): 69 | return len(self._image_files) 70 | 71 | def __getitem__(self, idx): 72 | image_file, target = self._image_files[idx], self._labels[idx] 73 | image = Image.open(image_file).convert("RGB") 74 | 75 | image = self.transform(image) 76 | 77 | # if self.target_transform: 78 | # label = self.target_transform(label) 79 | 80 | return image, target, image_file 81 | 82 | 83 | if __name__ == "__main__": 84 | import os, sys 85 | 86 | sys.path.insert(0, '../../lib') 87 | 88 | import argparse 89 | from config.default import update_config 90 | from config.default import config as cfg 91 | 92 | parser = argparse.ArgumentParser(description='knowledge distillation') 93 | parser.add_argument('--config_file', type=str, default='../../configs/aircraft/aircraft_resnet_single.yaml', 94 | required=False, help='Optional config file for params') 95 | parser.add_argument('opts', help='see config.py for all options', 96 | default='BASIC.GPU_ID [7]', nargs=argparse.REMAINDER) 97 | args = parser.parse_args() 98 | update_config(args) 99 | 100 | cfg.BASIC.ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..') 101 | 102 | import torch 103 | 104 | train_loader = torch.utils.data.DataLoader( 105 | AircraftDataset(cfg=cfg, is_train=True), 106 | batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 107 | val_loader = torch.utils.data.DataLoader( 108 | AircraftDataset(cfg=cfg, is_train=False), 109 | batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 110 | 111 | for image, label, name in train_loader: 112 | print(label) 113 | print(image.shape) -------------------------------------------------------------------------------- /lib/datasets/cub.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as transforms 5 | 6 | 7 | def get_transforms(cfg): 8 | train_transform = transforms.Compose([ 9 | transforms.RandomResizedCrop(cfg.DATA.CROP_SIZE), 10 | transforms.RandomHorizontalFlip(), 11 | transforms.ToTensor(), 12 | transforms.Normalize(cfg.DATA.IMAGE_MEAN, cfg.DATA.IMAGE_STD) 13 | ]) 14 | test_transform = transforms.Compose([ 15 | transforms.Resize(cfg.DATA.RESIZE_SIZE), 16 | transforms.CenterCrop(cfg.DATA.CROP_SIZE), 17 | transforms.ToTensor(), 18 | transforms.Normalize(cfg.DATA.IMAGE_MEAN, cfg.DATA.IMAGE_STD) 19 | ]) 20 | return train_transform, test_transform 21 | 22 | 23 | class CUBDataset(Dataset): 24 | def __init__(self, cfg, is_train): 25 | 26 | self.root = os.path.join(cfg.BASIC.ROOT_DIR, cfg.DATA.DATADIR) 27 | self.cfg = cfg 28 | self.is_train = is_train 29 | self.resize_size = cfg.DATA.RESIZE_SIZE 30 | self.crop_size = cfg.DATA.CROP_SIZE 31 | 32 | self.image_list = 
self.remove_1st_column(open( 33 | os.path.join(self.root, 'images.txt'), 'r').readlines()) 34 | self.label_list = self.remove_1st_column(open( 35 | os.path.join(self.root, 'image_class_labels.txt'), 'r').readlines()) 36 | self.split_list = self.remove_1st_column(open( 37 | os.path.join(self.root, 'train_test_split.txt'), 'r').readlines()) 38 | self.bbox_list = self.remove_1st_column(open( 39 | os.path.join(self.root, 'bounding_boxes.txt'), 'r').readlines()) 40 | 41 | train_transform, test_transform = get_transforms(cfg) 42 | 43 | if is_train: 44 | self.index_list = self.get_index(self.split_list, '1') 45 | self.transform = train_transform 46 | else: 47 | self.index_list = self.get_index(self.split_list, '0') 48 | self.transform = test_transform 49 | 50 | def get_index(self, list, value): 51 | index = [] 52 | for i in range(len(list)): 53 | if list[i] == value: 54 | index.append(i) 55 | return index 56 | 57 | def remove_1st_column(self, input_list): 58 | output_list = [] 59 | for i in range(len(input_list)): 60 | if len(input_list[i][:-1].split(' '))==2: 61 | output_list.append(input_list[i][:-1].split(' ')[1]) 62 | else: 63 | output_list.append(input_list[i][:-1].split(' ')[1:]) 64 | return output_list 65 | 66 | def __getitem__(self, idx): 67 | name = self.image_list[self.index_list[idx]] 68 | image_path = os.path.join(self.root, 'images', name) 69 | image = Image.open(image_path).convert('RGB') 70 | label = int(self.label_list[self.index_list[idx]])-1 71 | 72 | image = self.transform(image) 73 | 74 | return image, label, name[:-4] 75 | 76 | def __len__(self): 77 | return len(self.index_list) 78 | 79 | 80 | if __name__ == "__main__": 81 | import os, sys 82 | 83 | sys.path.insert(0, '../../lib') 84 | 85 | import argparse 86 | from config.default import update_config 87 | from config.default import config as cfg 88 | 89 | parser = argparse.ArgumentParser(description='knowledge distillation') 90 | parser.add_argument('--config_file', type=str, default='../../configs/cub/cub_resnet_single.yaml', 91 | required=False, help='Optional config file for params') 92 | parser.add_argument('opts', help='see config.py for all options', 93 | default='BASIC.GPU_ID [7]', nargs=argparse.REMAINDER) 94 | args = parser.parse_args() 95 | update_config(args) 96 | 97 | cfg.BASIC.ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..') 98 | 99 | import torch 100 | 101 | train_loader = torch.utils.data.DataLoader( 102 | CUBDataset(cfg=cfg, is_train=True), 103 | batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 104 | val_loader = torch.utils.data.DataLoader( 105 | CUBDataset(cfg=cfg, is_train=False), 106 | batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 107 | 108 | for image, label, name in val_loader: 109 | print(label) 110 | print(image.shape) 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /lib/config/default.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import datetime 3 | logger = logging.getLogger(__name__) 4 | 5 | 6 | class AttrDict(dict): 7 | """ 8 | Subclass dict and define getter-setter. 9 | This behaves as both dict and obj. 
10 | """ 11 | 12 | def __getattr__(self, key): 13 | return self[key] 14 | 15 | def __setattr__(self, key, value): 16 | if key in self.__dict__: 17 | self.__dict__[key] = value 18 | else: 19 | self[key] = value 20 | 21 | 22 | __C = AttrDict() 23 | config = __C 24 | 25 | __C.BASIC = AttrDict() 26 | __C.BASIC.TIME = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M') 27 | __C.BASIC.GPU_ID = [0] 28 | __C.BASIC.NUM_WORKERS = 10 29 | __C.BASIC.DISP_FREQ = 10 # frequency to display 30 | __C.BASIC.SEED = 0 31 | __C.BASIC.SAVE_DIR = '' 32 | __C.BASIC.ROOT_DIR = '' 33 | __C.BASIC.BACKUP_CODES = True 34 | __C.BASIC.BACKUP_LIST = ['lib', 'tools'] 35 | 36 | 37 | # Model options 38 | __C.MODEL = AttrDict() 39 | __C.MODEL.TYPE = '' # single, kd, pixel(ours) 40 | __C.MODEL.ARCH = '' 41 | __C.MODEL.PRETRAIN = False 42 | __C.MODEL.KDTYPE = 'kd' # kd, at 43 | __C.MODEL.ARCH_T = '' # teacher 44 | __C.MODEL.MODELDICT_T = '' 45 | __C.MODEL.PRETRAIN_T = False 46 | __C.MODEL.ARCH_A = '' # assistant 47 | __C.MODEL.MODELDICT_A = '' 48 | __C.MODEL.PRETRAIN_A = False 49 | __C.MODEL.ARCH_S = '' # student 50 | __C.MODEL.MODELDICT_S = '' 51 | __C.MODEL.PRETRAIN_S = False 52 | __C.MODEL.PRERELU = False 53 | 54 | __C.KD = AttrDict() 55 | __C.KD.TEMP = 4 56 | __C.KD.ALPHA = 0.9 57 | 58 | __C.AT = AttrDict() 59 | __C.AT.BETA = 1000.0 60 | 61 | __C.SP = AttrDict() 62 | __C.SP.BETA = 3000.0 63 | 64 | __C.ICKD = AttrDict() 65 | __C.ICKD.BETA = 2.5 66 | __C.ICKD.FEATDIM_T = [512] 67 | __C.ICKD.FEATDIM_S = [512] 68 | 69 | __C.DKD = AttrDict() 70 | __C.DKD.ALPHA = 1.0 71 | __C.DKD.BETA = 8.0 72 | __C.DKD.TEMP = 4 73 | __C.DKD.WARMUP = 20 74 | 75 | __C.FSR = AttrDict() 76 | __C.FSR.BETA = 0.0 77 | __C.FSR.GAMMA = 0.0 78 | __C.FSR.ETA = 0.0 79 | __C.FSR.BETA1 = 0.0 80 | __C.FSR.BETA2 = 0.0 81 | __C.FSR.RESIDUAL = False 82 | __C.FSR.POSITION = 0 83 | 84 | 85 | # Data options 86 | __C.DATA = AttrDict() 87 | __C.DATA.DATASET = '' 88 | __C.DATA.DATADIR = '' 89 | __C.DATA.NUM_CLASSES = 200 90 | __C.DATA.RESIZE_SIZE = 256 91 | __C.DATA.CROP_SIZE = 224 92 | __C.DATA.LESSEN_RATIO = 1.0 93 | __C.DATA.LESSEN_TYPE = 2 # 1 Nearest 2 Bilinear 3 Bicubic 4 Antialias 94 | __C.DATA.IMAGE_MEAN = [0.485, 0.456, 0.406] 95 | __C.DATA.IMAGE_STD = [0.229, 0.224, 0.225] 96 | 97 | # solver options 98 | __C.SOLVER = AttrDict() 99 | __C.SOLVER.START_LR = 0.1 100 | __C.SOLVER.LR_STEPS = [60, 120, 160] 101 | __C.SOLVER.LR_DECAY_FACTOR = 0.1 102 | __C.SOLVER.NUM_EPOCHS = 140 103 | __C.SOLVER.WEIGHT_DECAY = 5e-4 104 | __C.SOLVER.MUMENTUM = 0.9 105 | 106 | 107 | # Training options. 108 | __C.TRAIN = AttrDict() 109 | __C.TRAIN.BATCH_SIZE = 32 110 | 111 | # Testing options. 112 | __C.TEST = AttrDict() 113 | __C.TEST.BATCH_SIZE = 32 114 | 115 | # Cudnn related setting 116 | __C.CUDNN = AttrDict() 117 | __C.CUDNN.BENCHMARK = False 118 | __C.CUDNN.DETERMINISTIC = True 119 | __C.CUDNN.ENABLE = True 120 | 121 | 122 | def merge_dicts(dict_a, dict_b): 123 | from ast import literal_eval 124 | for key, value in dict_a.items(): 125 | if key not in dict_b: 126 | raise KeyError('Invalid key in config file: {}'.format(key)) 127 | if type(value) is dict: 128 | dict_a[key] = value = AttrDict(value) 129 | if isinstance(value, str): 130 | try: 131 | value = literal_eval(value) 132 | except BaseException: 133 | pass 134 | # The types must match, too. 135 | old_type = type(dict_b[key]) 136 | if old_type is not type(value) and value is not None: 137 | raise ValueError( 138 | 'Type mismatch ({} vs. 
{}) for config key: {}'.format( 139 | type(dict_b[key]), type(value), key) 140 | ) 141 | # Recursively merge dicts. 142 | if isinstance(value, AttrDict): 143 | try: 144 | merge_dicts(dict_a[key], dict_b[key]) 145 | except BaseException: 146 | raise Exception('Error under config key: {}'.format(key)) 147 | else: 148 | dict_b[key] = value 149 | 150 | 151 | def cfg_from_file(filename): 152 | """Load a config file and merge it into the default options.""" 153 | import yaml 154 | with open(filename, 'r') as fopen: 155 | yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.FullLoader)) 156 | merge_dicts(yaml_config, __C) 157 | 158 | 159 | def cfg_from_list(args_list): 160 | """Set config keys via list (e.g., from command line).""" 161 | from ast import literal_eval 162 | assert len(args_list) % 2 == 0, 'Specify values or keys for args' 163 | for key, value in zip(args_list[0::2], args_list[1::2]): 164 | key_list = key.split('.') 165 | cfg = __C 166 | for subkey in key_list[:-1]: 167 | assert subkey in cfg, 'Config key {} not found'.format(subkey) 168 | cfg = cfg[subkey] 169 | subkey = key_list[-1] 170 | assert subkey in cfg, 'Config key {} not found'.format(subkey) 171 | try: 172 | # Handle the case when v is a string literal. 173 | val = literal_eval(value) 174 | except BaseException: 175 | val = value 176 | assert isinstance(val, type(cfg[subkey])) or cfg[subkey] is None, \ 177 | 'type {} does not match original type {}'.format( 178 | type(val), type(cfg[subkey])) 179 | cfg[subkey] = val 180 | 181 | 182 | def update_config(args): 183 | if args.config_file is not None: 184 | cfg_from_file(args.config_file) 185 | if args.opts is not None: 186 | cfg_from_list(args.opts) 187 | -------------------------------------------------------------------------------- /lib/datasets/aircraft_2scale.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as transforms 5 | 6 | 7 | def get_transforms_2scale(cfg): 8 | interpolation = cfg.DATA.LESSEN_TYPE 9 | crop_size_large = cfg.DATA.CROP_SIZE 10 | crop_size_small = int(cfg.DATA.CROP_SIZE/cfg.DATA.LESSEN_RATIO) 11 | resize_size_small = int(cfg.DATA.RESIZE_SIZE / cfg.DATA.LESSEN_RATIO) 12 | 13 | train_transform_large = transforms.Compose([ 14 | transforms.RandomResizedCrop(crop_size_large), 15 | transforms.RandomHorizontalFlip(), 16 | ]) 17 | 18 | train_transform_small = transforms.Compose([ 19 | transforms.Resize((crop_size_small, crop_size_small), interpolation=interpolation) 20 | ]) 21 | 22 | train_transform_normalize = transforms.Compose([ 23 | transforms.ToTensor(), 24 | transforms.Normalize(cfg.DATA.IMAGE_MEAN, cfg.DATA.IMAGE_STD) 25 | ]) 26 | 27 | test_transform = transforms.Compose([ 28 | transforms.Resize(resize_size_small), 29 | transforms.CenterCrop(crop_size_small), 30 | transforms.ToTensor(), 31 | transforms.Normalize(cfg.DATA.IMAGE_MEAN, cfg.DATA.IMAGE_STD) 32 | ]) 33 | return train_transform_large, train_transform_small, train_transform_normalize, test_transform 34 | 35 | 36 | 37 | 38 | class AircraftDataset(Dataset): 39 | def __init__(self, cfg, is_train): 40 | 41 | self.root = os.path.join(cfg.BASIC.ROOT_DIR, cfg.DATA.DATADIR) 42 | self.cfg = cfg 43 | self.is_train = is_train 44 | self.resize_size = cfg.DATA.RESIZE_SIZE 45 | self.crop_size = cfg.DATA.CROP_SIZE 46 | 47 | self.image_folder = os.path.join(self.root, "data", "images") 48 | self.train_transform_large, self.train_transform_small, \ 49 | 
self.train_transform_normalize, self.test_transform = get_transforms_2scale(cfg) 50 | # "train", "val", "trainval", "test" 51 | if is_train: 52 | self._split = "trainval" 53 | else: 54 | self._split = "test" 55 | 56 | self._annotation_level = "variant" # "variant", "family", "manufacturer" 57 | 58 | annotation_file = os.path.join( 59 | self.root, 60 | "data", 61 | { 62 | "variant": "variants.txt", 63 | "family": "families.txt", 64 | "manufacturer": "manufacturers.txt", 65 | }[self._annotation_level], 66 | ) 67 | with open(annotation_file, "r") as f: 68 | self.classes = [line.strip() for line in f] 69 | 70 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) 71 | labels_file = os.path.join(self.root, "data", f"images_{self._annotation_level}_{self._split}.txt") 72 | 73 | self._image_files = [] 74 | self._labels = [] 75 | 76 | with open(labels_file, "r") as f: 77 | for line in f: 78 | image_name, label_name = line.strip().split(" ", 1) 79 | self._image_files.append(os.path.join(self.image_folder, f"{image_name}.jpg")) 80 | self._labels.append(self.class_to_idx[label_name]) 81 | 82 | def __len__(self): 83 | return len(self._image_files) 84 | 85 | def __getitem__(self, idx): 86 | image_file, target = self._image_files[idx], self._labels[idx] 87 | image = Image.open(image_file).convert("RGB") 88 | 89 | if self.is_train: 90 | image_large = self.train_transform_large(image) 91 | toTensor = transforms.ToTensor() 92 | image_large_tensor = toTensor(image_large) 93 | if self.cfg.DATA.LESSEN_RATIO == 1: 94 | image_large = self.train_transform_normalize(image_large) 95 | image_small = image_large 96 | else: 97 | image_small = self.train_transform_small(image_large) 98 | image_large = self.train_transform_normalize(image_large) 99 | image_small = self.train_transform_normalize(image_small) 100 | return image_large, image_small, target, image_large_tensor 101 | else: 102 | image = self.test_transform(image) 103 | return image, target, image_file 104 | 105 | 106 | if __name__ == "__main__": 107 | import os, sys 108 | 109 | sys.path.insert(0, '../../lib') 110 | 111 | import argparse 112 | from config.default import update_config 113 | from config.default import config as cfg 114 | 115 | parser = argparse.ArgumentParser(description='knowledge distillation') 116 | parser.add_argument('--config_file', type=str, default='../../configs/aircraft/aircraft_resnet50_32x4d_resnet_kd.yaml', 117 | required=False, help='Optional config file for params') 118 | parser.add_argument('opts', help='see config.py for all options', 119 | default='BASIC.GPU_ID [7]', nargs=argparse.REMAINDER) 120 | args = parser.parse_args() 121 | update_config(args) 122 | 123 | cfg.BASIC.ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..') 124 | 125 | import torch 126 | 127 | train_loader = torch.utils.data.DataLoader( 128 | AircraftDataset(cfg=cfg, is_train=True), 129 | batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 130 | val_loader = torch.utils.data.DataLoader( 131 | AircraftDataset(cfg=cfg, is_train=False), 132 | batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 133 | 134 | for image_l, image_s, label, name in train_loader: 135 | print(label) 136 | print(image_l.shape) 137 | print(image_s.shape) 138 | print(name) 139 | break 140 | 141 | for image, label, name in val_loader: 142 | print(label) 143 | print(image.shape) 144 | print(name) 145 | break -------------------------------------------------------------------------------- /lib/datasets/cub_2scale.py: 
-------------------------------------------------------------------------------- 1 | import os,sys 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as transforms 5 | 6 | 7 | def get_transforms_2scale(cfg): 8 | interpolation = cfg.DATA.LESSEN_TYPE 9 | crop_size_large = cfg.DATA.CROP_SIZE 10 | crop_size_small = int(cfg.DATA.CROP_SIZE/cfg.DATA.LESSEN_RATIO) 11 | resize_size_small = int(cfg.DATA.RESIZE_SIZE / cfg.DATA.LESSEN_RATIO) 12 | 13 | train_transform_large = transforms.Compose([ 14 | transforms.RandomResizedCrop(crop_size_large), 15 | transforms.RandomHorizontalFlip(), 16 | ]) 17 | 18 | train_transform_small = transforms.Compose([ 19 | transforms.Resize((crop_size_small, crop_size_small), interpolation=interpolation) 20 | ]) 21 | 22 | train_transform_normalize = transforms.Compose([ 23 | transforms.ToTensor(), 24 | transforms.Normalize(cfg.DATA.IMAGE_MEAN, cfg.DATA.IMAGE_STD) 25 | ]) 26 | 27 | test_transform = transforms.Compose([ 28 | transforms.Resize(resize_size_small), 29 | transforms.CenterCrop(crop_size_small), 30 | transforms.ToTensor(), 31 | transforms.Normalize(cfg.DATA.IMAGE_MEAN, cfg.DATA.IMAGE_STD) 32 | ]) 33 | return train_transform_large, train_transform_small, train_transform_normalize, test_transform 34 | 35 | 36 | class CUBDataset(Dataset): 37 | def __init__(self, cfg, is_train): 38 | 39 | self.root = os.path.join(cfg.BASIC.ROOT_DIR, cfg.DATA.DATADIR) 40 | self.cfg = cfg 41 | self.is_train = is_train 42 | self.resize_size = cfg.DATA.RESIZE_SIZE 43 | self.crop_size = cfg.DATA.CROP_SIZE 44 | 45 | self.image_list = self.remove_1st_column(open( 46 | os.path.join(self.root, 'images.txt'), 'r').readlines()) 47 | self.label_list = self.remove_1st_column(open( 48 | os.path.join(self.root, 'image_class_labels.txt'), 'r').readlines()) 49 | self.split_list = self.remove_1st_column(open( 50 | os.path.join(self.root, 'train_test_split.txt'), 'r').readlines()) 51 | self.bbox_list = self.remove_1st_column(open( 52 | os.path.join(self.root, 'bounding_boxes.txt'), 'r').readlines()) 53 | 54 | self.train_transform_large, self.train_transform_small, \ 55 | self.train_transform_normalize, self.test_transform = get_transforms_2scale(cfg) 56 | 57 | if is_train: 58 | self.index_list = self.get_index(self.split_list, '1') 59 | else: 60 | self.index_list = self.get_index(self.split_list, '0') 61 | 62 | def get_index(self, list, value): 63 | index = [] 64 | for i in range(len(list)): 65 | if list[i] == value: 66 | index.append(i) 67 | return index 68 | 69 | def remove_1st_column(self, input_list): 70 | output_list = [] 71 | for i in range(len(input_list)): 72 | if len(input_list[i][:-1].split(' '))==2: 73 | output_list.append(input_list[i][:-1].split(' ')[1]) 74 | else: 75 | output_list.append(input_list[i][:-1].split(' ')[1:]) 76 | return output_list 77 | 78 | def __getitem__(self, idx): 79 | name = self.image_list[self.index_list[idx]] 80 | image_path = os.path.join(self.root, 'images', name) 81 | image = Image.open(image_path).convert('RGB') 82 | target = int(self.label_list[self.index_list[idx]])-1 83 | 84 | if self.is_train: 85 | image_large = self.train_transform_large(image) 86 | toTensor = transforms.ToTensor() 87 | image_large_tensor = toTensor(image_large) 88 | if self.cfg.DATA.LESSEN_RATIO == 1: 89 | image_large = self.train_transform_normalize(image_large) 90 | image_small = image_large 91 | else: 92 | image_small = self.train_transform_small(image_large) 93 | image_large = self.train_transform_normalize(image_large) 94 | image_small 
= self.train_transform_normalize(image_small) 95 | return image_large, image_small, target, image_large_tensor 96 | else: 97 | image = self.test_transform(image) 98 | return image, target, name[:-4] 99 | 100 | # x = transforms.ToTensor() 101 | # a = x(image_large) 102 | 103 | def __len__(self): 104 | return len(self.index_list) 105 | 106 | 107 | if __name__ == "__main__": 108 | import os, sys 109 | 110 | sys.path.insert(0, '../../lib') 111 | 112 | import argparse 113 | from config.default import update_config 114 | from config.default import config as cfg 115 | 116 | parser = argparse.ArgumentParser(description='knowledge distillation') 117 | parser.add_argument('--config_file', type=str, default='../../configs/cub/cub_resnet50_resnet_kd.yaml', 118 | required=False, help='Optional config file for params') 119 | parser.add_argument('opts', help='see config.py for all options', 120 | default='BASIC.GPU_ID [7]', nargs=argparse.REMAINDER) 121 | args = parser.parse_args() 122 | update_config(args) 123 | 124 | cfg.BASIC.ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..') 125 | 126 | import torch 127 | 128 | train_loader = torch.utils.data.DataLoader( 129 | CUBDataset(cfg=cfg, is_train=True), 130 | batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 131 | val_loader = torch.utils.data.DataLoader( 132 | CUBDataset(cfg=cfg, is_train=False), 133 | batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 134 | 135 | for image_l, image_s, label, name in train_loader: 136 | print(label) 137 | print(image_l.shape) 138 | print(image_s.shape) 139 | print(name) 140 | break 141 | 142 | for image, label, name in val_loader: 143 | print(label) 144 | print(image.shape) 145 | print(name) 146 | break 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /tools_0_base/train_teacher_1run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, './') 4 | import datetime 5 | import pprint 6 | import argparse 7 | from lib.core.utils import str_gpus, AverageMeter, accuracy, list2acc, save_checkpoint, map_sklearn 8 | from lib.config.default import update_config 9 | from lib.config.default import config as cfg 10 | from lib.datasets import creat_data_loader 11 | from lib.utils import mkdir, Logger, prepare_env 12 | import torch 13 | from torch.utils.tensorboard import SummaryWriter 14 | import warnings 15 | warnings.filterwarnings('ignore') 16 | 17 | 18 | def args_parser(): 19 | parser = argparse.ArgumentParser(description='knowledge distillation') 20 | parser.add_argument('--config_file', type=str, 21 | default='', 22 | required=False, help='Optional config file for params') 23 | parser.add_argument('opts', help='see config.py for all options', 24 | default=None, nargs=argparse.REMAINDER) 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def creat_model(cfg): 30 | print('==> Preparing networks for baseline...') 31 | # use gpu 32 | os.environ["CUDA_VISIBLE_DEVICES"] = str_gpus(cfg.BASIC.GPU_ID) 33 | device = torch.device("cuda") 34 | assert torch.cuda.is_available(), "CUDA is not available" 35 | 36 | from lib.models.model_builder import build_model 37 | model = build_model(arch=cfg.MODEL.ARCH, num_classes=cfg.DATA.NUM_CLASSES, pretrained=cfg.MODEL.PRETRAIN) 38 | 39 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 40 | lr=cfg.SOLVER.START_LR, momentum=cfg.SOLVER.MUMENTUM, 41 | 
weight_decay=cfg.SOLVER.WEIGHT_DECAY) 42 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.SOLVER.LR_STEPS, 43 | gamma=cfg.SOLVER.LR_DECAY_FACTOR) 44 | 45 | model = torch.nn.DataParallel(model).to(device) 46 | 47 | # loss 48 | criterion = torch.nn.CrossEntropyLoss().to(device) 49 | print('Preparing networks done!') 50 | return device, model, optimizer, scheduler, criterion 51 | 52 | 53 | def main(): 54 | # update parameters 55 | args = args_parser() 56 | update_config(args) 57 | 58 | # create checkpoint directory 59 | cfg.BASIC.ROOT_DIR = os.path.join(os.path.dirname(__file__), '..') 60 | cfg.BASIC.SAVE_DIR = os.path.join(cfg.BASIC.ROOT_DIR, 'ckpt', cfg.DATA.DATASET, '{}_{}_{}_seed{}_{}_{}'.format( 61 | cfg.MODEL.TYPE, cfg.MODEL.ARCH, cfg.DATA.CROP_SIZE, cfg.BASIC.SEED, cfg.SOLVER.START_LR, cfg.BASIC.TIME)) 62 | # prepare running environment for the whole project 63 | prepare_env(cfg) 64 | 65 | # start loging 66 | sys.stdout = Logger(cfg.BASIC.LOG_FILE) 67 | pprint.pprint(cfg) 68 | logger = SummaryWriter(cfg.BASIC.LOG_DIR) 69 | 70 | device, model, optimizer, scheduler, criterion = creat_model(cfg) 71 | train_loader, val_loader = creat_data_loader(cfg) 72 | 73 | best_acc = 0 74 | update_train_step = 0 75 | update_val_step = 0 76 | for epoch in range(1, cfg.SOLVER.NUM_EPOCHS+1): 77 | update_train_step = train_one_epoch(train_loader, model, device, criterion, optimizer, 78 | epoch, logger, cfg, update_train_step) 79 | scheduler.step() 80 | acc, update_val_step = val_one_epoch(val_loader, model, device, criterion, epoch, 81 | logger, cfg, update_val_step) 82 | 83 | # remember best accuracy and save checkpoint 84 | if acc > best_acc: 85 | best_acc = max(acc, best_acc) 86 | torch.save({ 87 | 'epoch': epoch, 88 | 'state_dict': model.module.state_dict(), 89 | 'best_acc': best_acc, 90 | }, os.path.join(cfg.BASIC.CKPT_DIR, 'model_best.pth')) 91 | print("Best epoch: {}".format(epoch)) 92 | print("Best accuracy: {}".format(best_acc)) 93 | print(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M')) 94 | 95 | 96 | def train_one_epoch(train_loader, model, device, criterion, optimizer, epoch, logger, cfg, update_train_step): 97 | losses = AverageMeter() 98 | eval = AverageMeter() 99 | 100 | model.train() 101 | for i, (input, target, name) in enumerate(train_loader): 102 | # update iteration steps 103 | update_train_step += 1 104 | 105 | target = target.to(device) 106 | input = input.to(device) 107 | 108 | cls_logits = model(input) 109 | loss = criterion(cls_logits, target) 110 | optimizer.zero_grad() 111 | loss.backward() 112 | optimizer.step() 113 | 114 | eval_res = accuracy(cls_logits.data, target, topk=(1,))[0] 115 | eval_res = eval_res.item() 116 | losses.update(loss.item(), input.size(0)) 117 | eval.update(eval_res, input.size(0)) 118 | logger.add_scalar('loss_iter/train', loss.item(), update_train_step) 119 | logger.add_scalar('eval_iter/train_eval', eval_res, update_train_step) 120 | 121 | if i % cfg.BASIC.DISP_FREQ == 0 or i == len(train_loader)-1: 122 | print(('Train Epoch: [{0}][{1}/{2}],lr: {lr:.5f}\t' 123 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 124 | 'ACC@1 {eval.val:.3f} ({eval.avg:.3f})'.format( 125 | epoch, i+1, len(train_loader), loss=losses, 126 | eval=eval, lr=optimizer.param_groups[-1]['lr']))) 127 | 128 | return update_train_step 129 | 130 | 131 | def val_one_epoch(val_loader, model, device, criterion, epoch, logger, cfg, update_val_step): 132 | losses = AverageMeter() 133 | eval = AverageMeter() 134 | 135 | with torch.no_grad(): 136 | model.eval() 137 | for 
i, (input, target, name) in enumerate(val_loader): 138 | # update iteration steps 139 | update_val_step += 1 140 | 141 | target = target.to(device) 142 | input = input.to(device) 143 | 144 | cls_logits = model(input) 145 | loss = criterion(cls_logits, target) 146 | 147 | eval_res = accuracy(cls_logits.data, target, topk=(1,))[0] 148 | eval_res = eval_res.item() 149 | losses.update(loss.item(), input.size(0)) 150 | eval.update(eval_res, input.size(0)) 151 | logger.add_scalar('loss_iter/val', loss.item(), update_val_step) 152 | logger.add_scalar('eval_iter/val_eval', eval_res, update_val_step) 153 | 154 | if i % cfg.BASIC.DISP_FREQ == 0 or i == len(val_loader)-1: 155 | print(('VAL Epoch: [{0}][{1}/{2}]\t' 156 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 157 | 'ACC@1 {eval.val:.3f} ({eval.avg:.3f})'.format( 158 | epoch, i+1, len(val_loader), loss=losses, eval=eval))) 159 | 160 | return eval.avg, update_val_step 161 | 162 | 163 | if __name__ == "__main__": 164 | main() -------------------------------------------------------------------------------- /tools_0_base/train_student_5runs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, './') 4 | import datetime 5 | import pprint 6 | import argparse 7 | from lib.core.utils import str_gpus, AverageMeter, accuracy, list2acc, save_checkpoint, map_sklearn 8 | from lib.config.default import update_config 9 | from lib.config.default import config as cfg 10 | from lib.datasets import creat_data_loader 11 | from lib.utils import mkdir, Logger, prepare_env_noseed 12 | import torch 13 | from torch.utils.tensorboard import SummaryWriter 14 | import warnings 15 | warnings.filterwarnings('ignore') 16 | 17 | 18 | def args_parser(): 19 | parser = argparse.ArgumentParser(description='knowledge distillation') 20 | parser.add_argument('--config_file', type=str, 21 | default='', 22 | required=False, help='Optional config file for params') 23 | parser.add_argument('opts', help='see config.py for all options', 24 | default=None, nargs=argparse.REMAINDER) 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def creat_model(cfg): 30 | print('==> Preparing networks for baseline...') 31 | # use gpu 32 | os.environ["CUDA_VISIBLE_DEVICES"] = str_gpus(cfg.BASIC.GPU_ID) 33 | device = torch.device("cuda") 34 | assert torch.cuda.is_available(), "CUDA is not available" 35 | 36 | from lib.models.model_builder import build_model 37 | model = build_model(arch=cfg.MODEL.ARCH, num_classes=cfg.DATA.NUM_CLASSES, pretrained=cfg.MODEL.PRETRAIN) 38 | 39 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 40 | lr=cfg.SOLVER.START_LR, momentum=cfg.SOLVER.MUMENTUM, 41 | weight_decay=cfg.SOLVER.WEIGHT_DECAY) 42 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.SOLVER.LR_STEPS, 43 | gamma=cfg.SOLVER.LR_DECAY_FACTOR) 44 | 45 | model = torch.nn.DataParallel(model).to(device) 46 | 47 | # loss 48 | criterion = torch.nn.CrossEntropyLoss().to(device) 49 | print('Preparing networks done!') 50 | return device, model, optimizer, scheduler, criterion 51 | 52 | 53 | def main(): 54 | # update parameters 55 | args = args_parser() 56 | update_config(args) 57 | 58 | # create checkpoint directory 59 | cfg.BASIC.ROOT_DIR = os.path.join(os.path.dirname(__file__), '..') 60 | cfg.BASIC.SAVE_DIR = os.path.join(cfg.BASIC.ROOT_DIR, 'ckpt', cfg.DATA.DATASET, '{}_{}_{}_{}_{}'.format( 61 | cfg.MODEL.TYPE, cfg.MODEL.ARCH, cfg.DATA.CROP_SIZE, cfg.SOLVER.START_LR, 
cfg.BASIC.TIME)) 62 | # prepare running environment for the whole project 63 | prepare_env_noseed(cfg) 64 | 65 | # start loging 66 | sys.stdout = Logger(cfg.BASIC.LOG_FILE) 67 | pprint.pprint(cfg) 68 | 69 | best_list = [] 70 | for irun in range(1, 6): 71 | 72 | log_dir_irun = os.path.join(cfg.BASIC.LOG_DIR, str(irun)) 73 | mkdir(log_dir_irun) 74 | logger_irun = SummaryWriter(log_dir_irun) 75 | 76 | ckpt_dir_irun = os.path.join(cfg.BASIC.CKPT_DIR, str(irun)) 77 | mkdir(ckpt_dir_irun) 78 | 79 | device, model, optimizer, scheduler, criterion = creat_model(cfg) 80 | train_loader, val_loader = creat_data_loader(cfg) 81 | 82 | best_acc = 0 83 | update_train_step = 0 84 | update_val_step = 0 85 | for epoch in range(1, cfg.SOLVER.NUM_EPOCHS+1): 86 | update_train_step = train_one_epoch(train_loader, model, device, criterion, optimizer, 87 | epoch, irun, logger_irun, cfg, update_train_step) 88 | scheduler.step() 89 | acc, update_val_step = val_one_epoch(val_loader, model, device, criterion, epoch, irun, 90 | logger_irun, cfg, update_val_step) 91 | 92 | # remember best accuracy and save checkpoint 93 | if acc > best_acc: 94 | best_acc = max(acc, best_acc) 95 | torch.save({ 96 | 'epoch': epoch, 97 | 'state_dict': model.module.state_dict(), 98 | 'best_acc': best_acc, 99 | }, os.path.join(ckpt_dir_irun, 'model_best.pth')) 100 | print("Best epoch: {}".format(epoch)) 101 | print("Best accuracy: {}".format(best_acc)) 102 | print(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M')) 103 | 104 | best_list.append(best_acc) 105 | print(best_list) 106 | 107 | print('Mean: {}'.format(sum(best_list)/len(best_list))) 108 | 109 | 110 | def train_one_epoch(train_loader, model, device, criterion, optimizer, epoch, irun, logger, cfg, update_train_step): 111 | losses = AverageMeter() 112 | eval = AverageMeter() 113 | 114 | model.train() 115 | for i, (input, target, name) in enumerate(train_loader): 116 | # update iteration steps 117 | update_train_step += 1 118 | 119 | target = target.to(device) 120 | input = input.to(device) 121 | 122 | cls_logits = model(input) 123 | loss = criterion(cls_logits, target) 124 | optimizer.zero_grad() 125 | loss.backward() 126 | optimizer.step() 127 | 128 | eval_res = accuracy(cls_logits.data, target, topk=(1,))[0] 129 | eval_res = eval_res.item() 130 | losses.update(loss.item(), input.size(0)) 131 | eval.update(eval_res, input.size(0)) 132 | logger.add_scalar('loss_iter/train', loss.item(), update_train_step) 133 | logger.add_scalar('eval_iter/train_eval', eval_res, update_train_step) 134 | 135 | if i % cfg.BASIC.DISP_FREQ == 0 or i == len(train_loader)-1: 136 | print(('Train Run [{0}] Epoch [{1}]: [{2}/{3}],lr: {lr:.5f}\t' 137 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 138 | 'ACC@1 {eval.val:.3f} ({eval.avg:.3f})'.format( 139 | irun, epoch, i+1, len(train_loader), loss=losses, 140 | eval=eval, lr=optimizer.param_groups[-1]['lr']))) 141 | 142 | return update_train_step 143 | 144 | 145 | def val_one_epoch(val_loader, model, device, criterion, epoch, irun, logger, cfg, update_val_step): 146 | losses = AverageMeter() 147 | eval = AverageMeter() 148 | 149 | with torch.no_grad(): 150 | model.eval() 151 | for i, (input, target, name) in enumerate(val_loader): 152 | # update iteration steps 153 | update_val_step += 1 154 | 155 | target = target.to(device) 156 | input = input.to(device) 157 | 158 | cls_logits = model(input) 159 | loss = criterion(cls_logits, target) 160 | 161 | eval_res = accuracy(cls_logits.data, target, topk=(1,))[0] 162 | eval_res = eval_res.item() 163 | 
losses.update(loss.item(), input.size(0)) 164 | eval.update(eval_res, input.size(0)) 165 | logger.add_scalar('loss_iter/val', loss.item(), update_val_step) 166 | logger.add_scalar('eval_iter/val_eval', eval_res, update_val_step) 167 | 168 | if i % cfg.BASIC.DISP_FREQ == 0 or i == len(val_loader)-1: 169 | print(('VAL Run [{0}] Epoch [{1}]: [{2}/{3}]\t' 170 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 171 | 'ACC@1 {eval.val:.3f} ({eval.avg:.3f})'.format( 172 | irun, epoch, i+1, len(val_loader), loss=losses, eval=eval))) 173 | 174 | return eval.avg, update_val_step 175 | 176 | 177 | if __name__ == "__main__": 178 | main() -------------------------------------------------------------------------------- /lib/models/model_builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pdb 5 | 6 | 7 | # init for different models 8 | def build_model(arch, num_classes, pretrained): 9 | if num_classes!=1000: 10 | from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d 11 | from .shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0 12 | from .vit_pixel import vit_base_patch16_224, vit_tiny_patch16_224, vit_tiny_patch16_112, vit_tiny_patch16_56, \ 13 | vit_small_patch16_112, vit_small_patch16_56 14 | 15 | model_dict = { 16 | 'resnet18': resnet18, 17 | 'resnet34': resnet34, 18 | 'resnet50': resnet50, 19 | 'resnet101': resnet101, 20 | 'resnet152': resnet152, 21 | 'resnext50_32x4d': resnext50_32x4d, 22 | 'shufflenet_v2_x0_5': shufflenet_v2_x0_5, 23 | 'shufflenet_v2_x1_0': shufflenet_v2_x1_0, 24 | "vit_tiny_patch16_224": vit_tiny_patch16_224, 25 | "vit_tiny_patch16_112": vit_tiny_patch16_112, 26 | "vit_tiny_patch16_56": vit_tiny_patch16_56, 27 | "vit_base_patch16_224": vit_base_patch16_224, 28 | "vit_small_patch16_112": vit_small_patch16_112, 29 | "vit_small_patch16_56": vit_small_patch16_56 30 | } 31 | 32 | if arch not in model_dict: 33 | raise NotImplementedError('The model {} is not implemented!'.format(arch)) 34 | elif arch in ['vit_tiny_patch16_224', 'vit_base_patch16_224']: 35 | from timm.models import create_model 36 | model = create_model( 37 | arch, 38 | pretrained=pretrained, 39 | num_classes=num_classes) 40 | return model 41 | else: 42 | return model_dict[arch](num_classes=num_classes, pretrained=pretrained) 43 | 44 | 45 | class NaiveKDModelBuilder(nn.Module): 46 | def __init__(self, model_s, model_t): 47 | super(NaiveKDModelBuilder, self).__init__() 48 | self.model_s = model_s 49 | self.model_t = model_t 50 | 51 | # freeze the teacher model 52 | for p in self.model_t.parameters(): 53 | p.requires_grad = False 54 | 55 | def forward(self, input1, input2=None, return_feat=False, preReLU=False, return_final=False): 56 | if input2 is None: 57 | input_small = input1 58 | else: 59 | input_small = input2 60 | 61 | if return_feat: 62 | logits_s, feats_s = self.model_s(input_small, return_feat=return_feat, preReLU=preReLU, return_final=return_final) 63 | else: 64 | logits_s = self.model_s(input_small) 65 | 66 | if self.training == False: 67 | return logits_s 68 | else: 69 | if return_feat: 70 | with torch.no_grad(): 71 | logits_t, feats_t = self.model_t(input1, return_feat=return_feat, preReLU=preReLU, return_final=return_final) 72 | return logits_s, feats_s, logits_t, feats_t 73 | else: 74 | with torch.no_grad(): 75 | logits_t = self.model_t(input1) 76 | return logits_s, logits_t 77 | 78 | 79 | class ISRKDModelBuilder(nn.Module): 80 | def 
__init__(self, model_s, model_t, model_sr, position): 81 | super(ISRKDModelBuilder, self).__init__() 82 | self.model_s = model_s 83 | self.model_t = model_t 84 | self.model_sr = model_sr 85 | self.position = position 86 | 87 | # freeze the teacher model 88 | for p in self.model_t.parameters(): 89 | p.requires_grad = False 90 | 91 | def forward(self, input1, input2=None, preReLU=False, return_final=False): 92 | if input2 is None: 93 | input_small = input1 94 | else: 95 | input_small = input2 96 | 97 | logits_s, feats_s = self.model_s(input_small, return_feat=True, 98 | preReLU=preReLU, return_final=return_final) 99 | 100 | if self.training == False: 101 | return logits_s 102 | else: 103 | image_sr = self.model_sr(feats_s[self.position]) 104 | with torch.no_grad(): 105 | logits_t = self.model_t(input1, return_feat=False, preReLU=preReLU, return_final=return_final) 106 | return logits_s, image_sr, logits_t 107 | 108 | 109 | class ISRKDModelBuilderEval(nn.Module): 110 | def __init__(self, model_s, model_t, model_sr, position): 111 | super(ISRKDModelBuilderEval, self).__init__() 112 | self.model_s = model_s 113 | self.model_t = model_t 114 | self.model_sr = model_sr 115 | self.position = position 116 | 117 | # freeze the teacher model 118 | for p in self.model_t.parameters(): 119 | p.requires_grad = False 120 | 121 | def forward(self, input1, input2=None, preReLU=False, return_final=False): 122 | if input2 is None: 123 | input_small = input1 124 | else: 125 | input_small = input2 126 | 127 | logits_s, feats_s = self.model_s(input_small, return_feat=True, 128 | preReLU=preReLU, return_final=return_final) 129 | 130 | image_sr = self.model_sr(feats_s[self.position]) 131 | if self.training == False: 132 | return logits_s, image_sr 133 | else: 134 | with torch.no_grad(): 135 | logits_t = self.model_t(input1, return_feat=False, preReLU=preReLU, return_final=return_final) 136 | return logits_s, image_sr, logits_t 137 | 138 | 139 | class ISRKD_FKDModelBuilder(nn.Module): 140 | def __init__(self, model_s, model_t, model_sr, position): 141 | super(ISRKD_FKDModelBuilder, self).__init__() 142 | self.model_s = model_s 143 | self.model_t = model_t 144 | self.model_sr = model_sr 145 | self.position = position 146 | 147 | # freeze the teacher model 148 | for p in self.model_t.parameters(): 149 | p.requires_grad = False 150 | 151 | def forward(self, input1, input2=None, preReLU=False, return_final=False): 152 | if input2 is None: 153 | input_small = input1 154 | else: 155 | input_small = input2 156 | 157 | logits_s, feats_s = self.model_s(input_small, return_feat=True, 158 | preReLU=preReLU, return_final=return_final) 159 | 160 | if self.training == False: 161 | return logits_s 162 | else: 163 | image_sr = self.model_sr(feats_s[self.position]) 164 | with torch.no_grad(): 165 | logits_t, feats_t = self.model_t(input1, return_feat=True, preReLU=preReLU, return_final=return_final) 166 | return logits_s, feats_s, image_sr, logits_t, feats_t 167 | 168 | 169 | if __name__ == "__main__": 170 | import os, sys 171 | 172 | sys.path.insert(0, '../../lib') 173 | 174 | import argparse 175 | from config.default import update_config 176 | from config.default import config as cfg 177 | 178 | parser = argparse.ArgumentParser(description='knowledge distillation') 179 | parser.add_argument('--config_file', type=str, default='../../configs/cub/cub_resnet50_resnet_kd.yaml', 180 | required=False, help='Optional config file for params') 181 | parser.add_argument('opts', help='see config.py for all options', 182 | 
default='BASIC.GPU_ID [7]', nargs=argparse.REMAINDER) 183 | args = parser.parse_args() 184 | update_config(args) 185 | 186 | cfg.BASIC.ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..') 187 | 188 | import torch 189 | 190 | # fsr = FSRModelBuilder(cfg=cfg) 191 | # input = torch.ones(1, 3, 56, 56) 192 | # logits_s_lr, feats_s_lr, logits_s_hr, feat_s_hr = fsr(input) 193 | # print(feats_s_lr[-1].shape) 194 | # print(feat_s_hr.shape) 195 | 196 | -------------------------------------------------------------------------------- /tools_1_ts/train_isrd_5runs.py: -------------------------------------------------------------------------------- 1 | """ 2 | one stage 3 | image sr from feature of input conv 4 | """ 5 | import os 6 | import sys 7 | sys.path.insert(0, './') 8 | import datetime 9 | import pprint 10 | import argparse 11 | from lib.core.utils import str_gpus, AverageMeter, accuracy, list2acc, save_checkpoint, map_sklearn 12 | from lib.config.default import cfg_from_list, cfg_from_file, update_config 13 | from lib.config.default import config as cfg 14 | from lib.datasets import creat_data_loader_2scale 15 | from lib.utils import mkdir, Logger, prepare_env_noseed 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.utils.tensorboard import SummaryWriter 19 | import warnings 20 | warnings.filterwarnings('ignore') 21 | 22 | 23 | def args_parser(): 24 | parser = argparse.ArgumentParser(description='knowledge distillation') 25 | parser.add_argument('--config_file', type=str, 26 | default='', 27 | required=False, help='Optional config file for params') 28 | parser.add_argument('opts', help='see config.py for all options', 29 | default=None, nargs=argparse.REMAINDER) 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | def creat_model_kd(cfg): 35 | print('==> Preparing networks for baseline...') 36 | # use gpu 37 | os.environ["CUDA_VISIBLE_DEVICES"] = str_gpus(cfg.BASIC.GPU_ID) 38 | device = torch.device("cuda") 39 | assert torch.cuda.is_available(), "CUDA is not available" 40 | 41 | from lib.models.model_builder import build_model, ISRKDModelBuilder 42 | model_t = build_model(arch=cfg.MODEL.ARCH_T, num_classes=cfg.DATA.NUM_CLASSES, pretrained=cfg.MODEL.PRETRAIN_T) 43 | model_s = build_model(arch=cfg.MODEL.ARCH_S, num_classes=cfg.DATA.NUM_CLASSES, pretrained=cfg.MODEL.PRETRAIN_S) 44 | 45 | from lib.models.hrir import SR1x1 46 | crop_size_small = int(cfg.DATA.CROP_SIZE / cfg.DATA.LESSEN_RATIO) 47 | input_s = torch.ones(1, 3, crop_size_small, crop_size_small) 48 | _, feats_s = model_s(input_s, return_feat=True) 49 | model_sr = SR1x1(cfg, list(feats_s[cfg.FSR.POSITION].shape)) 50 | 51 | model = ISRKDModelBuilder(model_s, model_t, model_sr, cfg.FSR.POSITION) 52 | 53 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 54 | lr=cfg.SOLVER.START_LR, momentum=cfg.SOLVER.MUMENTUM, 55 | weight_decay=cfg.SOLVER.WEIGHT_DECAY) 56 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.SOLVER.LR_STEPS, 57 | gamma=cfg.SOLVER.LR_DECAY_FACTOR) 58 | if cfg.DATA.DATASET != 'imagenet': 59 | state_dict_t = torch.load(os.path.join(cfg.BASIC.ROOT_DIR, cfg.MODEL.MODELDICT_T))['state_dict'] 60 | model.model_t.load_state_dict(state_dict_t) 61 | model = torch.nn.DataParallel(model).to(device) 62 | 63 | # loss 64 | cls_criterion = torch.nn.CrossEntropyLoss().to(device) 65 | from lib.losses import build_criterion 66 | from lib.losses.kd_loss import KDLoss 67 | pkd_criterion = KDLoss(cfg).to(device) 68 | isr_criterion = 
torch.nn.L1Loss().to(device) 69 | print('Preparing networks done!') 70 | return device, model, optimizer, scheduler, cls_criterion, pkd_criterion, isr_criterion 71 | 72 | 73 | def main(): 74 | # update parameters 75 | args = args_parser() 76 | update_config(args) 77 | 78 | # create checkpoint directory 79 | cfg.BASIC.ROOT_DIR = os.path.join(os.path.dirname(__file__), '..') 80 | cfg.BASIC.SAVE_DIR = os.path.join( 81 | cfg.BASIC.ROOT_DIR, 'ckpt', cfg.DATA.DATASET, 82 | '{}_{}'.format(cfg.MODEL.ARCH_S, int(cfg.DATA.CROP_SIZE / cfg.DATA.LESSEN_RATIO)), 83 | 'isrd-p{}_{}_{}_{}_k{}_eta-{}_seed{}_{}_{}'.format( 84 | cfg.FSR.POSITION, cfg.MODEL.ARCH_T, cfg.MODEL.ARCH_S, 85 | cfg.DATA.CROP_SIZE, cfg.DATA.LESSEN_RATIO, cfg.FSR.ETA, 86 | cfg.BASIC.SEED, cfg.SOLVER.START_LR, cfg.BASIC.TIME)) 87 | 88 | # prepare running environment for the whole project 89 | prepare_env_noseed(cfg) 90 | 91 | # start loging 92 | sys.stdout = Logger(cfg.BASIC.LOG_FILE) 93 | pprint.pprint(cfg) 94 | 95 | best_list = [] 96 | for irun in range(1, 6): 97 | log_dir_irun = os.path.join(cfg.BASIC.LOG_DIR, str(irun)) 98 | mkdir(log_dir_irun) 99 | logger_irun = SummaryWriter(log_dir_irun) 100 | 101 | ckpt_dir_irun = os.path.join(cfg.BASIC.CKPT_DIR, str(irun)) 102 | mkdir(ckpt_dir_irun) 103 | 104 | device, model, optimizer, scheduler, cls_criterion, pkd_criterion, isr_criterion = creat_model_kd(cfg) 105 | train_loader, val_loader = creat_data_loader_2scale(cfg) 106 | 107 | best_acc = 0 108 | update_train_step = 0 109 | update_val_step = 0 110 | for epoch in range(1, cfg.SOLVER.NUM_EPOCHS+1): 111 | update_train_step = train_one_epoch(train_loader, model, device, cls_criterion, pkd_criterion, isr_criterion, 112 | optimizer, epoch, irun, logger_irun, cfg, update_train_step) 113 | scheduler.step() 114 | acc, update_val_step = val_one_epoch(val_loader, model, device, cls_criterion, epoch, irun, 115 | logger_irun, cfg, update_val_step) 116 | 117 | # best accuracy and save checkpoint 118 | if acc > best_acc: 119 | best_acc = max(acc, best_acc) 120 | # torch.save({ 121 | # 'epoch': epoch, 122 | # 'state_dict': model.module.model_s.state_dict(), 123 | # 'best_acc': best_acc, 124 | # }, os.path.join(ckpt_dir_irun, 'model_best.pth')) 125 | torch.save({ 126 | 'epoch': epoch, 127 | 'state_dict': model.state_dict(), 128 | 'best_acc': best_acc, 129 | }, os.path.join(ckpt_dir_irun, 'model_best.pth')) 130 | 131 | print("Best epoch: {}".format(epoch)) 132 | print("Best accuracy: {}".format(best_acc)) 133 | best_list.append(best_acc) 134 | 135 | print(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M')) 136 | print(best_list) 137 | print('Mean: {}'.format(sum(best_list) / len(best_list))) 138 | 139 | 140 | def train_one_epoch(train_loader, model, device, cls_criterion, pkd_criterion, isr_criterion, optimizer, 141 | epoch, irun, logger, cfg, update_train_step): 142 | losses = AverageMeter() 143 | eval = AverageMeter() 144 | 145 | model.train() 146 | model.module.model_t.eval() 147 | for i, (input_large, input_small, target, image_large) in enumerate(train_loader): 148 | # update iteration steps 149 | update_train_step += 1 150 | 151 | target = target.to(device) 152 | input_large = input_large.to(device) 153 | input_small = input_small.to(device) 154 | 155 | cls_logits_s, image_sr, cls_logits_t= model(input_large, input_small, preReLU=True) 156 | kd_loss = pkd_criterion(cls_logits_s, cls_logits_t, target) 157 | isr_loss = isr_criterion(image_sr, input_large) 158 | loss = kd_loss + isr_loss*cfg.FSR.ETA 159 | 160 | optimizer.zero_grad() 161 | 
loss.backward() 162 | optimizer.step() 163 | 164 | eval_res = accuracy(cls_logits_s.data, target, topk=(1,))[0] 165 | eval_res = eval_res.item() 166 | losses.update(loss.item(), input_large.size(0)) 167 | eval.update(eval_res, input_large.size(0)) 168 | logger.add_scalar('loss_iter/train', loss.item(), update_train_step) 169 | logger.add_scalar('eval_iter/train_eval', eval_res, update_train_step) 170 | 171 | if i % cfg.BASIC.DISP_FREQ == 0 or i == len(train_loader)-1: 172 | print(('Train Run [{0}] Epoch [{1}]: [{2}/{3}], lr: {lr:.5f} ' 173 | 'Loss {loss.val:.4f} ({loss.avg:.4f}) ' 174 | 'ACC@1 {eval.val:.3f} ({eval.avg:.3f})'.format( 175 | irun, epoch, i+1, len(train_loader), loss=losses, 176 | eval=eval, lr=optimizer.param_groups[-1]['lr']))) 177 | 178 | return update_train_step 179 | 180 | 181 | def val_one_epoch(val_loader, model, device, criterion, epoch, irun, logger, cfg, update_val_step): 182 | losses = AverageMeter() 183 | eval = AverageMeter() 184 | 185 | with torch.no_grad(): 186 | model.eval() 187 | for i, (input, target, name) in enumerate(val_loader): 188 | # update iteration steps 189 | update_val_step += 1 190 | 191 | target = target.to(device) 192 | input = input.to(device) 193 | 194 | cls_logits = model(input) 195 | loss = criterion(cls_logits, target) 196 | 197 | eval_res = accuracy(cls_logits.data, target, topk=(1,))[0] 198 | eval_res = eval_res.item() 199 | losses.update(loss.item(), input.size(0)) 200 | eval.update(eval_res, input.size(0)) 201 | logger.add_scalar('loss_iter/val', loss.item(), update_val_step) 202 | logger.add_scalar('eval_iter/val_eval', eval_res, update_val_step) 203 | 204 | if i % cfg.BASIC.DISP_FREQ == 0 or i == len(val_loader)-1: 205 | print(('VAL Run [{0}] Epoch [{1}]: [{2}/{3}] ' 206 | 'Loss {loss.val:.4f} ({loss.avg:.4f}) ' 207 | 'ACC@1 {eval.val:.3f} ({eval.avg:.3f})'.format( 208 | irun, epoch, i+1, len(val_loader), loss=losses, eval=eval))) 209 | 210 | return eval.avg, update_val_step 211 | 212 | 213 | if __name__ == "__main__": 214 | main() -------------------------------------------------------------------------------- /lib/models/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | ShuffleNetV2. 
https://github.com/pytorch/vision/blob/master/torchvision/models/shufflenetv2.py 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 11 | from collections import OrderedDict 12 | 13 | __all__ = [ 14 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 15 | ] 16 | 17 | model_urls = { 18 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', 19 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', 20 | 'shufflenetv2_x1.5': None, 21 | 'shufflenetv2_x2.0': None, 22 | } 23 | 24 | def channel_shuffle(x, groups): 25 | batchsize, num_channels, height, width = x.data.size() 26 | channels_per_group = num_channels // groups 27 | x = x.view(batchsize, groups, 28 | channels_per_group, height, width) 29 | x = torch.transpose(x, 1, 2).contiguous() 30 | x = x.view(batchsize, -1, height, width) 31 | return x 32 | 33 | class InvertedResidual(nn.Module): 34 | def __init__(self, inp, oup, stride, pre_act=True): 35 | super(InvertedResidual, self).__init__() 36 | 37 | if not (1 <= stride <= 3): 38 | raise ValueError('illegal stride value') 39 | self.stride = stride 40 | 41 | branch_features = oup // 2 42 | assert (self.stride != 1) or (inp == branch_features << 1) 43 | 44 | if self.stride > 1: 45 | if pre_act: 46 | self.branch1 = nn.Sequential( 47 | nn.ReLU(inplace=False), 48 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 49 | nn.BatchNorm2d(inp), 50 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 51 | nn.BatchNorm2d(branch_features), 52 | ) 53 | else: 54 | self.branch1 = nn.Sequential( 55 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 56 | nn.BatchNorm2d(inp), 57 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 58 | nn.BatchNorm2d(branch_features), 59 | ) 60 | else: 61 | self.branch1 = nn.Sequential() 62 | 63 | if pre_act: 64 | self.branch2 = nn.Sequential( 65 | nn.ReLU(inplace=False), 66 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 67 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 68 | nn.BatchNorm2d(branch_features), 69 | nn.ReLU(inplace=True), 70 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 71 | nn.BatchNorm2d(branch_features), 72 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 73 | nn.BatchNorm2d(branch_features), 74 | ) 75 | else: 76 | self.branch2 = nn.Sequential( 77 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 78 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 79 | nn.BatchNorm2d(branch_features), 80 | nn.ReLU(inplace=True), 81 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 82 | nn.BatchNorm2d(branch_features), 83 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 84 | nn.BatchNorm2d(branch_features), 85 | ) 86 | 87 | @staticmethod 88 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): 89 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 90 | 91 | def forward(self, x): 92 | if self.stride == 1: 93 | x1, x2 = x.chunk(2, dim=1) 94 | out = torch.cat((x1, self.branch2(x2)), dim=1) 95 | else: 96 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 97 | 98 
| out = channel_shuffle(out, 2) 99 | 100 | return out 101 | 102 | 103 | class ShuffleNetV2(nn.Module): 104 | def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual): 105 | super(ShuffleNetV2, self).__init__() 106 | 107 | if len(stages_repeats) != 3: 108 | raise ValueError('expected stages_repeats as list of 3 positive ints') 109 | if len(stages_out_channels) != 5: 110 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 111 | self._stage_out_channels = stages_out_channels 112 | 113 | input_channels = 3 114 | output_channels = self._stage_out_channels[0] 115 | self.conv1 = nn.Sequential( 116 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), 117 | nn.BatchNorm2d(output_channels), 118 | ) 119 | input_channels = output_channels 120 | 121 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 122 | 123 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 124 | for name, repeats, output_channels in zip( 125 | stage_names, stages_repeats, self._stage_out_channels[1:]): 126 | pre_act = False if name == 'stage2' else True 127 | seq = [inverted_residual(input_channels, output_channels, 2, pre_act=pre_act)] 128 | for i in range(repeats - 1): 129 | seq.append(inverted_residual(output_channels, output_channels, 1)) 130 | setattr(self, name, nn.Sequential(*seq)) 131 | input_channels = output_channels 132 | 133 | output_channels = self._stage_out_channels[-1] 134 | self.conv5 = nn.Sequential( 135 | nn.ReLU(inplace=True), 136 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 137 | nn.BatchNorm2d(output_channels), 138 | nn.ReLU(inplace=True), 139 | ) 140 | 141 | self.fc_cls = nn.Linear(output_channels, num_classes) 142 | 143 | def forward(self, x, return_feat=False, preReLU=False, return_final=False): 144 | feat_input = self.conv1(x) # 56 x 56 145 | x = F.relu(feat_input) 146 | x = self.maxpool(x) 147 | 148 | feat1 = self.stage2(x) 149 | feat2 = self.stage3(feat1) 150 | feat3 = self.stage4(feat2) 151 | feat4 = self.conv5(feat3) 152 | 153 | x = F.adaptive_avg_pool2d(feat4, (1, 1)) 154 | x = x.view(x.size(0), x.size(1)) 155 | if return_final: 156 | finall_feat = x 157 | x = self.fc_cls(x) 158 | 159 | if return_feat: 160 | if not preReLU: 161 | feat_input = F.relu(feat_input) 162 | feat1 = F.relu(feat1) 163 | feat2 = F.relu(feat2) 164 | feat3 = F.relu(feat3) 165 | feat4 = F.relu(feat4) 166 | 167 | if return_final: 168 | return x, [feat_input, feat1, feat2, feat3, feat4, finall_feat] 169 | else: 170 | return x, [feat_input, feat1, feat2, feat3, feat4] 171 | else: 172 | return x 173 | 174 | def get_feat_size(self, x): 175 | """ 176 | :param x: input 177 | :return: size of final feat 178 | """ 179 | _, feats = self.forward(x, return_feat=True) 180 | feat_size = feats[-1].shape 181 | return list(feat_size) 182 | 183 | def get_input_feat_size(self, x): 184 | """ 185 | :param x: input 186 | :return: size of input feat 187 | """ 188 | _, feats = self.forward(x, return_feat=True) 189 | feat_size = feats[0].shape 190 | return list(feat_size) 191 | 192 | 193 | 194 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): 195 | model = ShuffleNetV2(*args, **kwargs) 196 | 197 | if pretrained: 198 | model_url = model_urls[arch] 199 | if model_url is None: 200 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 201 | else: 202 | state_dict = load_state_dict_from_url(model_url, progress=progress) 203 | _pretrained_dict = OrderedDict() 204 | for idx, (k, v) 
in enumerate(state_dict.items()): 205 | splitted_k = k.split('.') 206 | # special for 1.0x 207 | if 29 < idx < 280: 208 | splitted_k[-2] = str(int(splitted_k[-2]) + 1) 209 | _pretrained_dict['.'.join(splitted_k)] = v 210 | else: 211 | _pretrained_dict[k] = v 212 | 213 | model_dict = model.state_dict() 214 | _pretrained_dict = {k: v for k, v in _pretrained_dict.items() if k in model_dict} 215 | model_dict.update(_pretrained_dict) 216 | model.load_state_dict(model_dict) 217 | # release 218 | del _pretrained_dict 219 | del state_dict 220 | return model 221 | 222 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): 223 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, 224 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) 225 | 226 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): 227 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, 228 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) 229 | 230 | 231 | if __name__ == "__main__": 232 | 233 | input = torch.ones(1, 3, 56, 56) 234 | model = shufflenet_v2_x1_0(pretrained=True) 235 | x, feats = model(input, return_feat=True, preReLU=False, return_final=True) 236 | 237 | print(feats[0].shape) 238 | print(feats[1].shape) 239 | print(feats[2].shape) 240 | print(feats[3].shape) 241 | print(feats[4].shape) 242 | print(feats[5].shape) 243 | -------------------------------------------------------------------------------- /tools_1_ts/train_kd_5runs.py: -------------------------------------------------------------------------------- 1 | """ 2 | one stage 3 | student classification trained by teacher, using kd loss 4 | """ 5 | import os 6 | import sys 7 | sys.path.insert(0, './') 8 | import datetime 9 | import pprint 10 | import argparse 11 | from lib.core.utils import str_gpus, AverageMeter, accuracy 12 | from lib.config.default import update_config 13 | from lib.config.default import config as cfg 14 | from lib.datasets import creat_data_loader_2scale 15 | from lib.utils import mkdir, Logger, prepare_env_noseed 16 | import torch 17 | from torch.utils.tensorboard import SummaryWriter 18 | import warnings 19 | warnings.filterwarnings('ignore') 20 | 21 | 22 | def args_parser(): 23 | parser = argparse.ArgumentParser(description='knowledge distillation') 24 | parser.add_argument('--config_file', type=str, 25 | default='', 26 | required=False, help='Optional config file for params') 27 | parser.add_argument('opts', help='see config.py for all options', 28 | default=None, nargs=argparse.REMAINDER) 29 | args = parser.parse_args() 30 | return args 31 | 32 | 33 | def creat_model_kd(cfg): 34 | print('==> Preparing networks for baseline...') 35 | # use gpu 36 | os.environ["CUDA_VISIBLE_DEVICES"] = str_gpus(cfg.BASIC.GPU_ID) 37 | device = torch.device("cuda") 38 | assert torch.cuda.is_available(), "CUDA is not available" 39 | 40 | from lib.models.model_builder import build_model, NaiveKDModelBuilder 41 | model_t = build_model(arch=cfg.MODEL.ARCH_T, num_classes=cfg.DATA.NUM_CLASSES, pretrained=cfg.MODEL.PRETRAIN_T) 42 | model_s = build_model(arch=cfg.MODEL.ARCH_S, num_classes=cfg.DATA.NUM_CLASSES, pretrained=cfg.MODEL.PRETRAIN_S) 43 | 44 | if cfg.MODEL.KDTYPE == 'ickd': 45 | crop_size_large = cfg.DATA.CROP_SIZE 46 | crop_size_small = int(cfg.DATA.CROP_SIZE / cfg.DATA.LESSEN_RATIO) 47 | input_l = torch.ones(1, 3, crop_size_large, crop_size_large) 48 | input_s = torch.ones(1, 3, crop_size_small, crop_size_small) 49 | cfg.ICKD.FEATDIM_T = list(model_t.get_feat_size(input_l)) 50 | cfg.ICKD.FEATDIM_S = 
list(model_s.get_feat_size(input_s)) 51 | 52 | if cfg.MODEL.KDTYPE == 'kd': 53 | model = NaiveKDModelBuilder(model_s, model_t) 54 | else: 55 | model = NaiveKDModelBuilder(model_s, model_t) 56 | 57 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 58 | lr=cfg.SOLVER.START_LR, momentum=cfg.SOLVER.MUMENTUM, 59 | weight_decay=cfg.SOLVER.WEIGHT_DECAY) 60 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.SOLVER.LR_STEPS, 61 | gamma=cfg.SOLVER.LR_DECAY_FACTOR) 62 | 63 | if cfg.DATA.DATASET != 'imagenet': 64 | state_dict_t = torch.load(os.path.join(cfg.BASIC.ROOT_DIR, cfg.MODEL.MODELDICT_T))['state_dict'] 65 | model.model_t.load_state_dict(state_dict_t) 66 | model = torch.nn.DataParallel(model).to(device) 67 | 68 | # loss 69 | cls_criterion = torch.nn.CrossEntropyLoss().to(device) 70 | from lib.losses import build_criterion 71 | from lib.losses.kd_loss import KDLoss 72 | pkd_criterion = KDLoss(cfg).to(device) 73 | kd_criterion_this = build_criterion(cfg).to(device) 74 | print('Preparing networks done!') 75 | return device, model, optimizer, scheduler, cls_criterion, pkd_criterion, kd_criterion_this 76 | 77 | 78 | def main(): 79 | # update parameters 80 | args = args_parser() 81 | update_config(args) 82 | 83 | # create checkpoint directory 84 | cfg.BASIC.ROOT_DIR = os.path.join(os.path.dirname(__file__), '..') 85 | cfg.BASIC.SAVE_DIR = os.path.join( 86 | cfg.BASIC.ROOT_DIR, 'ckpt', cfg.DATA.DATASET, 87 | '{}_{}'.format(cfg.MODEL.ARCH_S, int(cfg.DATA.CROP_SIZE/cfg.DATA.LESSEN_RATIO)), 88 | '{}_{}_{}_{}_k{}_seed{}_{}_{}'.format(cfg.MODEL.KDTYPE, cfg.MODEL.ARCH_T, cfg.MODEL.ARCH_S, 89 | cfg.DATA.CROP_SIZE, cfg.DATA.LESSEN_RATIO, cfg.BASIC.SEED, 90 | cfg.SOLVER.START_LR, cfg.BASIC.TIME)) 91 | 92 | # prepare running environment for the whole project 93 | prepare_env_noseed(cfg) 94 | 95 | # start loging 96 | sys.stdout = Logger(cfg.BASIC.LOG_FILE) 97 | pprint.pprint(cfg) 98 | 99 | best_list = [] 100 | for irun in range(1, 6): 101 | log_dir_irun = os.path.join(cfg.BASIC.LOG_DIR, str(irun)) 102 | mkdir(log_dir_irun) 103 | logger_irun = SummaryWriter(log_dir_irun) 104 | 105 | ckpt_dir_irun = os.path.join(cfg.BASIC.CKPT_DIR, str(irun)) 106 | mkdir(ckpt_dir_irun) 107 | 108 | device, model, optimizer, scheduler, cls_criterion, pkd_criterion, kd_criterion_this = creat_model_kd(cfg) 109 | train_loader, val_loader = creat_data_loader_2scale(cfg) 110 | 111 | best_acc = 0 112 | update_train_step = 0 113 | update_val_step = 0 114 | for epoch in range(1, cfg.SOLVER.NUM_EPOCHS+1): 115 | update_train_step = train_one_epoch(train_loader, model, device, cls_criterion, pkd_criterion, 116 | kd_criterion_this, optimizer, epoch, irun, logger_irun, cfg, 117 | update_train_step) 118 | scheduler.step() 119 | acc, update_val_step = val_one_epoch(val_loader, model, device, cls_criterion, epoch, irun, 120 | logger_irun, cfg, update_val_step) 121 | 122 | # remember best accuracy and save checkpoint 123 | if acc > best_acc: 124 | best_acc = max(acc, best_acc) 125 | torch.save({ 126 | 'epoch': epoch, 127 | 'state_dict': model.module.model_s.state_dict(), 128 | 'best_acc': best_acc, 129 | }, os.path.join(ckpt_dir_irun, 'model_best.pth')) 130 | print("Best epoch: {}".format(epoch)) 131 | print("Best accuracy: {}".format(best_acc)) 132 | print(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M')) 133 | best_list.append(best_acc) 134 | 135 | print('Mean: {}'.format(sum(best_list) / len(best_list))) 136 | 137 | 138 | def train_one_epoch(train_loader, model, device, cls_criterion, 
pkd_criterion, kd_criterion_this, optimizer, 139 | epoch, irun, logger, cfg, update_train_step): 140 | losses = AverageMeter() 141 | eval = AverageMeter() 142 | 143 | model.train() 144 | model.module.model_t.eval() 145 | for i, (input_large, input_small, target, name) in enumerate(train_loader): 146 | # update iteration steps 147 | update_train_step += 1 148 | 149 | target = target.to(device) 150 | input_large = input_large.to(device) 151 | input_small = input_small.to(device) 152 | 153 | if cfg.MODEL.KDTYPE == 'kd' or cfg.MODEL.KDTYPE == 'dkd': 154 | if cfg.MODEL.KDTYPE == 'kd': 155 | cls_logits_s, cls_logits_t = model(input_large, input_small) 156 | loss = pkd_criterion(cls_logits_s, cls_logits_t, target) 157 | elif cfg.MODEL.KDTYPE == 'dkd': 158 | cls_logits_s, cls_logits_t = model(input_large, input_small) 159 | cls_loss = cls_criterion(cls_logits_s, target) 160 | dkd_loss = kd_criterion_this(cls_logits_s, cls_logits_t, target, epoch+1) 161 | loss = cls_loss + dkd_loss 162 | else: 163 | cls_logits_s, feats_s, cls_logits_t, feats_t = model(input_large, input_small, return_feat=True) 164 | cls_loss = cls_criterion(cls_logits_s, target) 165 | if cfg.MODEL.KDTYPE == 'sp': 166 | fkd_loss = sum(kd_criterion_this([feats_s[-1]], [feats_t[-1]])) 167 | loss = cls_loss + fkd_loss * cfg.SP.BETA 168 | elif cfg.MODEL.KDTYPE == 'at': 169 | fkd_loss = sum(kd_criterion_this(feats_s[1:], feats_t[1:])) 170 | loss = cls_loss + fkd_loss * cfg.AT.BETA 171 | elif cfg.MODEL.KDTYPE == 'ickd': 172 | fkd_loss = sum(kd_criterion_this([feats_s[-1]], [feats_t[-1]])) 173 | loss = cls_loss + fkd_loss * cfg.ICKD.BETA 174 | else: 175 | raise ValueError('Please set correct knowledge distillation method.') 176 | 177 | 178 | optimizer.zero_grad() 179 | loss.backward() 180 | optimizer.step() 181 | 182 | eval_res = accuracy(cls_logits_s.data, target, topk=(1,))[0] 183 | eval_res = eval_res.item() 184 | losses.update(loss.item(), input_large.size(0)) 185 | eval.update(eval_res, input_large.size(0)) 186 | logger.add_scalar('loss_iter/train', loss.item(), update_train_step) 187 | logger.add_scalar('eval_iter/train_eval', eval_res, update_train_step) 188 | 189 | if i % cfg.BASIC.DISP_FREQ == 0 or i == len(train_loader)-1: 190 | print(('Train Run [{0}] Epoch [{1}]: [{2}/{3}], lr: {lr:.5f} ' 191 | 'Loss {loss.val:.4f} ({loss.avg:.4f}) ' 192 | 'ACC@1 {eval.val:.3f} ({eval.avg:.3f})'.format( 193 | irun, epoch, i+1, len(train_loader), loss=losses, 194 | eval=eval, lr=optimizer.param_groups[-1]['lr']))) 195 | 196 | return update_train_step 197 | 198 | 199 | def val_one_epoch(val_loader, model, device, criterion, epoch, irun, logger, cfg, update_val_step): 200 | losses = AverageMeter() 201 | eval = AverageMeter() 202 | 203 | with torch.no_grad(): 204 | model.eval() 205 | for i, (input, target, name) in enumerate(val_loader): 206 | # update iteration steps 207 | update_val_step += 1 208 | 209 | target = target.to(device) 210 | input = input.to(device) 211 | 212 | cls_logits = model(input) 213 | loss = criterion(cls_logits, target) 214 | 215 | eval_res = accuracy(cls_logits.data, target, topk=(1,))[0] 216 | eval_res = eval_res.item() 217 | losses.update(loss.item(), input.size(0)) 218 | eval.update(eval_res, input.size(0)) 219 | logger.add_scalar('loss_iter/val', loss.item(), update_val_step) 220 | logger.add_scalar('eval_iter/val_eval', eval_res, update_val_step) 221 | 222 | if i % cfg.BASIC.DISP_FREQ == 0 or i == len(val_loader)-1: 223 | print(('VAL Run [{0}] Epoch [{1}]: [{2}/{3}] ' 224 | 'Loss {loss.val:.4f} ({loss.avg:.4f}) ' 225 | 
'ACC@1 {eval.val:.3f} ({eval.avg:.3f})'.format( 226 | irun, epoch, i+1, len(val_loader), loss=losses, eval=eval))) 227 | 228 | return eval.avg, update_val_step 229 | 230 | 231 | if __name__ == "__main__": 232 | main() -------------------------------------------------------------------------------- /lib/models/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | ResNet. https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | try: 11 | from torch.hub import load_state_dict_from_url 12 | except ImportError: 13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 14 | 15 | 16 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 17 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 18 | 'wide_resnet50_2', 'wide_resnet101_2'] 19 | 20 | 21 | model_urls = { 22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 27 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 28 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 29 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 30 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 31 | } 32 | 33 | 34 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 35 | """3x3 convolution with padding""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 37 | padding=dilation, groups=groups, bias=False, dilation=dilation) 38 | 39 | 40 | def conv1x1(in_planes, out_planes, stride=1): 41 | """1x1 convolution""" 42 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 43 | 44 | 45 | class BasicBlock(nn.Module): 46 | expansion = 1 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 49 | base_width=64, dilation=1, norm_layer=None): 50 | super(BasicBlock, self).__init__() 51 | if norm_layer is None: 52 | norm_layer = nn.BatchNorm2d 53 | if groups != 1 or base_width != 64: 54 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 55 | if dilation > 1: 56 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 57 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 58 | self.conv1 = conv3x3(inplanes, planes, stride) 59 | self.bn1 = norm_layer(planes) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.conv2 = conv3x3(planes, planes) 62 | self.bn2 = norm_layer(planes) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | x = F.relu(x) 68 | identity = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | 77 | if self.downsample is not None: 78 | identity = self.downsample(x) 79 | 80 | out += identity 81 | # out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class Bottleneck(nn.Module): 87 | # Bottleneck in torchvision places the stride for downsampling at 3x3 
convolution(self.conv2) 88 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 89 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 90 | # This variant is also known as ResNet V1.5 and improves accuracy according to 91 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 92 | 93 | expansion = 4 94 | 95 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 96 | base_width=64, dilation=1, norm_layer=None): 97 | super(Bottleneck, self).__init__() 98 | if norm_layer is None: 99 | norm_layer = nn.BatchNorm2d 100 | width = int(planes * (base_width / 64.)) * groups 101 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 102 | self.conv1 = conv1x1(inplanes, width) 103 | self.bn1 = norm_layer(width) 104 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 105 | self.bn2 = norm_layer(width) 106 | self.conv3 = conv1x1(width, planes * self.expansion) 107 | self.bn3 = norm_layer(planes * self.expansion) 108 | self.relu = nn.ReLU(inplace=True) 109 | self.downsample = downsample 110 | self.stride = stride 111 | 112 | def forward(self, x): 113 | x = F.relu(x) 114 | identity = x 115 | 116 | out = self.conv1(x) 117 | out = self.bn1(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv2(out) 121 | out = self.bn2(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv3(out) 125 | out = self.bn3(out) 126 | 127 | if self.downsample is not None: 128 | identity = self.downsample(x) 129 | 130 | out += identity 131 | # out = self.relu(out) 132 | 133 | return out 134 | 135 | 136 | class ResNet(nn.Module): 137 | 138 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 139 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 140 | norm_layer=None): 141 | super(ResNet, self).__init__() 142 | if norm_layer is None: 143 | norm_layer = nn.BatchNorm2d 144 | self._norm_layer = norm_layer 145 | 146 | self.inplanes = 64 147 | self.dilation = 1 148 | if replace_stride_with_dilation is None: 149 | # each element in the tuple indicates if we should replace 150 | # the 2x2 stride with a dilated convolution instead 151 | replace_stride_with_dilation = [False, False, False] 152 | if len(replace_stride_with_dilation) != 3: 153 | raise ValueError("replace_stride_with_dilation should be None " 154 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 155 | self.groups = groups 156 | self.base_width = width_per_group 157 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 158 | bias=False) 159 | self.bn1 = norm_layer(self.inplanes) 160 | self.relu = nn.ReLU(inplace=True) 161 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 162 | self.layer1 = self._make_layer(block, 64, layers[0]) 163 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 164 | dilate=replace_stride_with_dilation[0]) 165 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 166 | dilate=replace_stride_with_dilation[1]) 167 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 168 | dilate=replace_stride_with_dilation[2]) 169 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 170 | self.fc_cls = nn.Linear(512 * block.expansion, num_classes) 171 | 172 | for m in self.modules(): 173 | if isinstance(m, nn.Conv2d): 174 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 175 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 176 | 
nn.init.constant_(m.weight, 1) 177 | nn.init.constant_(m.bias, 0) 178 | 179 | # Zero-initialize the last BN in each residual branch, 180 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 181 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 182 | if zero_init_residual: 183 | for m in self.modules(): 184 | if isinstance(m, Bottleneck): 185 | nn.init.constant_(m.bn3.weight, 0) 186 | elif isinstance(m, BasicBlock): 187 | nn.init.constant_(m.bn2.weight, 0) 188 | 189 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 190 | norm_layer = self._norm_layer 191 | downsample = None 192 | previous_dilation = self.dilation 193 | if dilate: 194 | self.dilation *= stride 195 | stride = 1 196 | if stride != 1 or self.inplanes != planes * block.expansion: 197 | downsample = nn.Sequential( 198 | conv1x1(self.inplanes, planes * block.expansion, stride), 199 | norm_layer(planes * block.expansion), 200 | ) 201 | 202 | layers = [] 203 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 204 | self.base_width, previous_dilation, norm_layer)) 205 | self.inplanes = planes * block.expansion 206 | for _ in range(1, blocks): 207 | layers.append(block(self.inplanes, planes, groups=self.groups, 208 | base_width=self.base_width, dilation=self.dilation, 209 | norm_layer=norm_layer)) 210 | 211 | return nn.Sequential(*layers) 212 | 213 | def _forward_impl(self, x): 214 | # See note [TorchScript super()] 215 | x = self.conv1(x) 216 | x = self.bn1(x) 217 | x = self.relu(x) 218 | x = self.maxpool(x) 219 | 220 | x = self.layer1(x) 221 | x = self.layer2(x) 222 | x = self.layer3(x) 223 | x = F.relu(self.layer4(x)) 224 | 225 | x = self.avgpool(x) 226 | x = x.view(x.size(0), -1) 227 | x = self.fc_cls(x) 228 | 229 | return x 230 | 231 | def forward(self, x, return_feat=False, preReLU=False, return_final=False): 232 | # See note [TorchScript super()] 233 | feat_input = self.conv1(x) 234 | x = self.bn1(feat_input) 235 | x = self.relu(x) 236 | x = self.maxpool(x) 237 | 238 | feat1 = self.layer1(x) 239 | feat2 = self.layer2(feat1) 240 | feat3 = self.layer3(feat2) 241 | feat4 = self.layer4(feat3) 242 | 243 | x = self.avgpool(F.relu(feat4)) 244 | x = x.view(x.size(0), -1) 245 | if return_final: 246 | finall_feat = x 247 | x = self.fc_cls(x) 248 | 249 | if return_feat: 250 | if not preReLU: 251 | feat_input = F.relu(feat_input) 252 | feat1 = F.relu(feat1) 253 | feat2 = F.relu(feat2) 254 | feat3 = F.relu(feat3) 255 | feat4 = F.relu(feat4) 256 | 257 | if return_final: 258 | return x, [feat_input, feat1, feat2, feat3, feat4, finall_feat] 259 | else: 260 | return x, [feat_input, feat1, feat2, feat3, feat4] 261 | else: 262 | return x 263 | 264 | def get_feat_size(self, x): 265 | """ 266 | :param x: input 267 | :return: size of final feat 268 | """ 269 | _, feats = self.forward(x, return_feat=True) 270 | feat_size = feats[-1].shape 271 | return list(feat_size) 272 | 273 | def get_input_feat_size(self, x): 274 | """ 275 | :param x: input 276 | :return: size of input feat 277 | """ 278 | _, feats = self.forward(x, return_feat=True) 279 | feat_size = feats[0].shape 280 | return list(feat_size) 281 | 282 | def get_input_weights(self,): 283 | """ 284 | :return: weights of input conv 285 | """ 286 | return self.conv1.weight 287 | 288 | 289 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 290 | model = ResNet(block, layers, **kwargs) 291 | if pretrained: 292 | pretrained_dict = 
load_state_dict_from_url(model_urls[arch], progress=progress) 293 | model_dict = model.state_dict() 294 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict} 295 | model_dict.update(pretrained_dict) 296 | model.load_state_dict(model_dict) 297 | return model 298 | 299 | 300 | def resnet18(pretrained=False, progress=True, **kwargs): 301 | r"""ResNet-18 model from 302 | `"Deep Residual Learning for Image Recognition" `_ 303 | 304 | Args: 305 | pretrained (bool): If True, returns a model pre-trained on ImageNet 306 | progress (bool): If True, displays a progress bar of the download to stderr 307 | """ 308 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 309 | **kwargs) 310 | 311 | 312 | def resnet34(pretrained=False, progress=True, **kwargs): 313 | r"""ResNet-34 model from 314 | `"Deep Residual Learning for Image Recognition" `_ 315 | 316 | Args: 317 | pretrained (bool): If True, returns a model pre-trained on ImageNet 318 | progress (bool): If True, displays a progress bar of the download to stderr 319 | """ 320 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 321 | **kwargs) 322 | 323 | 324 | def resnet50(pretrained=False, progress=True, **kwargs): 325 | r"""ResNet-50 model from 326 | `"Deep Residual Learning for Image Recognition" `_ 327 | 328 | Args: 329 | pretrained (bool): If True, returns a model pre-trained on ImageNet 330 | progress (bool): If True, displays a progress bar of the download to stderr 331 | """ 332 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 333 | **kwargs) 334 | 335 | 336 | def resnet101(pretrained=False, progress=True, **kwargs): 337 | r"""ResNet-101 model from 338 | `"Deep Residual Learning for Image Recognition" `_ 339 | 340 | Args: 341 | pretrained (bool): If True, returns a model pre-trained on ImageNet 342 | progress (bool): If True, displays a progress bar of the download to stderr 343 | """ 344 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 345 | **kwargs) 346 | 347 | 348 | def resnet152(pretrained=False, progress=True, **kwargs): 349 | r"""ResNet-152 model from 350 | `"Deep Residual Learning for Image Recognition" `_ 351 | 352 | Args: 353 | pretrained (bool): If True, returns a model pre-trained on ImageNet 354 | progress (bool): If True, displays a progress bar of the download to stderr 355 | """ 356 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 357 | **kwargs) 358 | 359 | 360 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 361 | r"""ResNeXt-50 32x4d model from 362 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 363 | 364 | Args: 365 | pretrained (bool): If True, returns a model pre-trained on ImageNet 366 | progress (bool): If True, displays a progress bar of the download to stderr 367 | """ 368 | kwargs['groups'] = 32 369 | kwargs['width_per_group'] = 4 370 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 371 | pretrained, progress, **kwargs) 372 | 373 | 374 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 375 | r"""ResNeXt-101 32x8d model from 376 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 377 | 378 | Args: 379 | pretrained (bool): If True, returns a model pre-trained on ImageNet 380 | progress (bool): If True, displays a progress bar of the download to stderr 381 | """ 382 | kwargs['groups'] = 32 383 | kwargs['width_per_group'] = 8 384 | return _resnet('resnext101_32x8d', Bottleneck, [3, 
4, 23, 3], 385 | pretrained, progress, **kwargs) 386 | 387 | 388 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 389 | r"""Wide ResNet-50-2 model from 390 | `"Wide Residual Networks" `_ 391 | 392 | The model is the same as ResNet except for the bottleneck number of channels 393 | which is twice larger in every block. The number of channels in outer 1x1 394 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 395 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 396 | 397 | Args: 398 | pretrained (bool): If True, returns a model pre-trained on ImageNet 399 | progress (bool): If True, displays a progress bar of the download to stderr 400 | """ 401 | kwargs['width_per_group'] = 64 * 2 402 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 403 | pretrained, progress, **kwargs) 404 | 405 | 406 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 407 | r"""Wide ResNet-101-2 model from 408 | `"Wide Residual Networks" `_ 409 | 410 | The model is the same as ResNet except for the bottleneck number of channels 411 | which is twice larger in every block. The number of channels in outer 1x1 412 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 413 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 414 | 415 | Args: 416 | pretrained (bool): If True, returns a model pre-trained on ImageNet 417 | progress (bool): If True, displays a progress bar of the download to stderr 418 | """ 419 | kwargs['width_per_group'] = 64 * 2 420 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 421 | pretrained, progress, **kwargs) 422 | 423 | 424 | if __name__ == "__main__": 425 | 426 | input = torch.ones(1, 3, 56, 56) 427 | model = resnet18(pretrained=True) 428 | x, feats = model(input, return_feat=True, preReLU=False, return_final=True) 429 | 430 | print(feats[0].shape) 431 | print(feats[1].shape) 432 | print(feats[2].shape) 433 | print(feats[3].shape) 434 | print(feats[4].shape) 435 | print(feats[5].shape) 436 | 437 | # feat_size = model.get_feat_size(input) 438 | # print(feat_size) -------------------------------------------------------------------------------- /lib/models/vit_pixel.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of Vision Transformers as described in: 4 | 5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' 6 | - https://arxiv.org/abs/2010.11929 7 | 8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` 9 | - https://arxiv.org/abs/2106.10270 10 | 11 | The official jax code is released and available at https://github.com/google-research/vision_transformer 12 | 13 | DeiT model defs and weights from https://github.com/facebookresearch/deit, 14 | paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 15 | 16 | Acknowledgments: 17 | * The paper authors for releasing code and weights, thanks! 18 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out 19 | for some einops/einsum fun 20 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 21 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 22 | 23 | Hacked together by / Copyright 2021 Ross Wightman 24 | """ 25 | import math 26 | import logging 27 | from functools import partial 28 | from collections import OrderedDict 29 | from copy import deepcopy 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | 35 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 36 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 37 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 38 | from timm.models.registry import register_model 39 | 40 | _logger = logging.getLogger(__name__) 41 | 42 | 43 | def _cfg(url='', **kwargs): 44 | return { 45 | 'url': url, 46 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 47 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 48 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 49 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 50 | **kwargs 51 | } 52 | 53 | 54 | default_cfgs = { 55 | # patch models (weights from official Google JAX impl) 56 | 'vit_tiny_patch16_224': _cfg( 57 | url='https://storage.googleapis.com/vit_models/augreg/' 58 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 59 | 'vit_tiny_patch16_149': _cfg( 60 | url='https://storage.googleapis.com/vit_models/augreg/' 61 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 62 | input_size=(3, 149, 149)), 63 | 'vit_tiny_patch16_112': _cfg( 64 | url='https://storage.googleapis.com/vit_models/augreg/' 65 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 66 | input_size=(3, 112, 112)), 67 | 'vit_tiny_patch16_89': _cfg( 68 | url='https://storage.googleapis.com/vit_models/augreg/' 69 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 70 | input_size=(3, 89, 89)), 71 | 'vit_tiny_patch16_74': _cfg( 72 | url='https://storage.googleapis.com/vit_models/augreg/' 73 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 74 | input_size=(3, 74, 74)), 75 | 'vit_tiny_patch16_64': _cfg( 76 | url='https://storage.googleapis.com/vit_models/augreg/' 77 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 78 | input_size=(3, 64, 64)), 79 | 'vit_tiny_patch16_56': _cfg( 80 | url='https://storage.googleapis.com/vit_models/augreg/' 81 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 82 | input_size=(3, 56, 56)), 83 | 'vit_base_patch16_224': _cfg( 84 | url='https://storage.googleapis.com/vit_models/augreg/' 85 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 86 | 87 | 'vit_small_patch16_224': _cfg( 88 | url='https://storage.googleapis.com/vit_models/augreg/' 89 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 90 | 'vit_small_patch16_112': _cfg( 91 | url='https://storage.googleapis.com/vit_models/augreg/' 92 | 
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 93 | input_size=(3, 112, 112)), 94 | 'vit_small_patch16_56': _cfg( 95 | url='https://storage.googleapis.com/vit_models/augreg/' 96 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 97 | input_size=(3, 56, 56)), 98 | } 99 | 100 | 101 | class Attention(nn.Module): 102 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 103 | super().__init__() 104 | self.num_heads = num_heads 105 | head_dim = dim // num_heads 106 | self.scale = head_dim ** -0.5 107 | 108 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 109 | self.attn_drop = nn.Dropout(attn_drop) 110 | self.proj = nn.Linear(dim, dim) 111 | self.proj_drop = nn.Dropout(proj_drop) 112 | 113 | def forward(self, x): 114 | B, N, C = x.shape 115 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 116 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 117 | 118 | attn = (q @ k.transpose(-2, -1)) * self.scale 119 | attn = attn.softmax(dim=-1) 120 | attn = self.attn_drop(attn) 121 | 122 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 123 | x = self.proj(x) 124 | x = self.proj_drop(x) 125 | return x 126 | 127 | 128 | class Block(nn.Module): 129 | 130 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 131 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 132 | super().__init__() 133 | self.norm1 = norm_layer(dim) 134 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 135 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 136 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 137 | self.norm2 = norm_layer(dim) 138 | mlp_hidden_dim = int(dim * mlp_ratio) 139 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 140 | 141 | def forward(self, x): 142 | x = x + self.drop_path(self.attn(self.norm1(x))) 143 | x = x + self.drop_path(self.mlp(self.norm2(x))) 144 | return x 145 | 146 | 147 | class VisionTransformer(nn.Module): 148 | """ Vision Transformer 149 | 150 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 151 | - https://arxiv.org/abs/2010.11929 152 | 153 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 154 | - https://arxiv.org/abs/2012.12877 155 | """ 156 | 157 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 158 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 159 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 160 | act_layer=None, weight_init=''): 161 | """ 162 | Args: 163 | img_size (int, tuple): input image size 164 | patch_size (int, tuple): patch size 165 | in_chans (int): number of input channels 166 | num_classes (int): number of classes for classification head 167 | embed_dim (int): embedding dimension 168 | depth (int): depth of transformer 169 | num_heads (int): number of attention heads 170 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 171 | qkv_bias (bool): enable bias for qkv if True 172 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 173 | distilled (bool): model includes a distillation token and head as in DeiT models 174 | drop_rate (float): dropout rate 175 | attn_drop_rate (float): attention dropout rate 176 | drop_path_rate (float): stochastic depth rate 177 | embed_layer (nn.Module): patch embedding layer 178 | norm_layer: (nn.Module): normalization layer 179 | weight_init: (str): weight init scheme 180 | """ 181 | super().__init__() 182 | self.num_classes = num_classes 183 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 184 | self.num_tokens = 2 if distilled else 1 185 | self.depth = depth 186 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 187 | act_layer = act_layer or nn.GELU 188 | 189 | self.patch_embed = embed_layer( 190 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 191 | num_patches = self.patch_embed.num_patches 192 | 193 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 194 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 195 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 196 | self.pos_drop = nn.Dropout(p=drop_rate) 197 | 198 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 199 | self.blocks = nn.Sequential(*[ 200 | Block( 201 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 202 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 203 | for i in range(depth)]) 204 | self.norm = norm_layer(embed_dim) 205 | 206 | # Representation layer 207 | if representation_size and not distilled: 208 | self.num_features = representation_size 209 | self.pre_logits = nn.Sequential(OrderedDict([ 210 | ('fc', nn.Linear(embed_dim, representation_size)), 211 | 
('act', nn.Tanh()) 212 | ])) 213 | else: 214 | self.pre_logits = nn.Identity() 215 | 216 | # Classifier head(s) 217 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 218 | self.head_dist = None 219 | if distilled: 220 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 221 | 222 | self.init_weights(weight_init) 223 | 224 | def init_weights(self, mode=''): 225 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 226 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 227 | trunc_normal_(self.pos_embed, std=.02) 228 | if self.dist_token is not None: 229 | trunc_normal_(self.dist_token, std=.02) 230 | if mode.startswith('jax'): 231 | # leave cls token as zeros to match jax impl 232 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) 233 | else: 234 | trunc_normal_(self.cls_token, std=.02) 235 | self.apply(_init_vit_weights) 236 | 237 | def _init_weights(self, m): 238 | # this fn left here for compat with downstream users 239 | _init_vit_weights(m) 240 | 241 | @torch.jit.ignore() 242 | def load_pretrained(self, checkpoint_path, prefix=''): 243 | _load_weights(self, checkpoint_path, prefix) 244 | 245 | @torch.jit.ignore 246 | def no_weight_decay(self): 247 | return {'pos_embed', 'cls_token', 'dist_token'} 248 | 249 | def get_classifier(self): 250 | if self.dist_token is None: 251 | return self.head 252 | else: 253 | return self.head, self.head_dist 254 | 255 | def reset_classifier(self, num_classes, global_pool=''): 256 | self.num_classes = num_classes 257 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 258 | if self.num_tokens == 2: 259 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 260 | 261 | def forward_features(self, x, return_feat=False): 262 | x = self.patch_embed(x) 263 | 264 | if return_feat: 265 | feats = [] 266 | patch_tokens = x.permute(0, 2, 1).view(x.size(0), x.size(2), int(math.sqrt(x.size(1))), int(math.sqrt(x.size(1)))) 267 | feats.append(patch_tokens) 268 | 269 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 270 | x = torch.cat((cls_token, x), dim=1) 271 | x = self.pos_drop(x + self.pos_embed) 272 | 273 | if return_feat: 274 | for d in range(self.depth): 275 | x = self.blocks[d](x) 276 | patch_tokens = x[:, 1:, :].permute(0, 2, 1).view(x.size(0), x.size(2), int(math.sqrt(x.size(1))), 277 | int(math.sqrt(x.size(1)))) 278 | feats.append(patch_tokens) 279 | else: 280 | x = self.blocks(x) 281 | 282 | x = self.norm(x) 283 | if return_feat: 284 | return self.pre_logits(x[:, 0]), feats 285 | else: 286 | return self.pre_logits(x[:, 0]) 287 | 288 | def forward(self, x, return_feat=False, preReLU=False, return_final=False): 289 | 290 | if return_feat: 291 | x, feats = self.forward_features(x, return_feat=return_feat) 292 | x = self.head(x) 293 | return x, feats 294 | else: 295 | x = self.forward_features(x, return_feat=return_feat) 296 | x = self.head(x) 297 | return x 298 | 299 | def get_feat_size(self, x): 300 | """ 301 | :param x: input 302 | :return: size of final feat 303 | """ 304 | _, feats = self.forward(x, return_feat=True) 305 | feat_size = feats[-1].shape 306 | return list(feat_size) 307 | 308 | def get_input_feat_size(self, x): 309 | """ 310 | :param x: input 311 | :return: size of input feat 312 | """ 313 | _, feats = self.forward(x, return_feat=True) 314 | feat_size = feats[0].shape 315 | 
return list(feat_size) 316 | 317 | 318 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): 319 | """ ViT weight initialization 320 | * When called without n, head_bias, jax_impl args it will behave exactly the same 321 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 322 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl 323 | """ 324 | if isinstance(module, nn.Linear): 325 | if name.startswith('head'): 326 | nn.init.zeros_(module.weight) 327 | nn.init.constant_(module.bias, head_bias) 328 | elif name.startswith('pre_logits'): 329 | lecun_normal_(module.weight) 330 | nn.init.zeros_(module.bias) 331 | else: 332 | if jax_impl: 333 | nn.init.xavier_uniform_(module.weight) 334 | if module.bias is not None: 335 | if 'mlp' in name: 336 | nn.init.normal_(module.bias, std=1e-6) 337 | else: 338 | nn.init.zeros_(module.bias) 339 | else: 340 | trunc_normal_(module.weight, std=.02) 341 | if module.bias is not None: 342 | nn.init.zeros_(module.bias) 343 | elif jax_impl and isinstance(module, nn.Conv2d): 344 | # NOTE conv was left to pytorch default in my original init 345 | lecun_normal_(module.weight) 346 | if module.bias is not None: 347 | nn.init.zeros_(module.bias) 348 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): 349 | nn.init.zeros_(module.bias) 350 | nn.init.ones_(module.weight) 351 | 352 | 353 | @torch.no_grad() 354 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 355 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 356 | """ 357 | import numpy as np 358 | 359 | def _n2p(w, t=True): 360 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 361 | w = w.flatten() 362 | if t: 363 | if w.ndim == 4: 364 | w = w.transpose([3, 2, 0, 1]) 365 | elif w.ndim == 3: 366 | w = w.transpose([2, 0, 1]) 367 | elif w.ndim == 2: 368 | w = w.transpose([1, 0]) 369 | return torch.from_numpy(w) 370 | 371 | w = np.load(checkpoint_path) 372 | if not prefix and 'opt/target/embedding/kernel' in w: 373 | prefix = 'opt/target/' 374 | 375 | if hasattr(model.patch_embed, 'backbone'): 376 | # hybrid 377 | backbone = model.patch_embed.backbone 378 | stem_only = not hasattr(backbone, 'stem') 379 | stem = backbone if stem_only else backbone.stem 380 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 381 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 382 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 383 | if not stem_only: 384 | for i, stage in enumerate(backbone.stages): 385 | for j, block in enumerate(stage.blocks): 386 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 387 | for r in range(3): 388 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 389 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 390 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 391 | if block.downsample is not None: 392 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 393 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 394 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 395 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 396 | else: 397 | embed_conv_w = adapt_input_conv( 398 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 399 | 
model.patch_embed.proj.weight.copy_(embed_conv_w) 400 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 401 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 402 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 403 | if pos_embed_w.shape != model.pos_embed.shape: 404 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 405 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 406 | model.pos_embed.copy_(pos_embed_w) 407 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 408 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 409 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 410 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 411 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 412 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 413 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 414 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 415 | for i, block in enumerate(model.blocks.children()): 416 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 417 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 418 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 419 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 420 | block.attn.qkv.weight.copy_(torch.cat([ 421 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 422 | block.attn.qkv.bias.copy_(torch.cat([ 423 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 424 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 425 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 426 | for r in range(2): 427 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 428 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 429 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 430 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 431 | 432 | 433 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 434 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 435 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 436 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 437 | ntok_new = posemb_new.shape[1] 438 | if num_tokens: 439 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 440 | ntok_new -= num_tokens 441 | else: 442 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 443 | gs_old = int(math.sqrt(len(posemb_grid))) 444 | if not len(gs_new): # backwards compatibility 445 | gs_new = [int(math.sqrt(ntok_new))] * 2 446 | assert len(gs_new) >= 2 447 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) 448 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 449 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') 450 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 451 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 452 | return posemb 453 | 454 | 455 | def checkpoint_filter_fn(state_dict, model): 456 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 457 | out_dict = {} 458 | if 'model' in state_dict: 459 | # For deit models 460 | state_dict = state_dict['model'] 461 | for k, v in state_dict.items(): 462 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 463 | # For old models that I trained prior to conv based patchification 464 | O, I, H, W = model.patch_embed.proj.weight.shape 465 | v = v.reshape(O, -1, H, W) 466 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 467 | # To resize pos embedding when using model at different size from pretrained weights 468 | v = resize_pos_embed( 469 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 470 | out_dict[k] = v 471 | return out_dict 472 | 473 | 474 | def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): 475 | default_cfg = default_cfg or default_cfgs[variant] 476 | if kwargs.get('features_only', None): 477 | raise RuntimeError('features_only not implemented for Vision Transformer models.') 478 | 479 | # NOTE this extra code to support handling of repr size for in21k pretrained models 480 | default_num_classes = default_cfg['num_classes'] 481 | num_classes = kwargs.get('num_classes', default_num_classes) 482 | repr_size = kwargs.pop('representation_size', None) 483 | if repr_size is not None and num_classes != default_num_classes: 484 | # Remove representation layer if fine-tuning. This may not always be the desired action, 485 | # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? 
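# NOTE: concretely, this branch means that when a caller both requests a representation
# (pre-logits) layer and overrides num_classes away from the pretrained default, the
# representation layer is dropped with a warning and VisionTransformer falls back to
# pre_logits = nn.Identity(). An illustrative call (not one made in this repo) such as
#     _create_vision_transformer('vit_tiny_patch16_112', pretrained=True,
#                                num_classes=200, representation_size=192)
# would therefore load the backbone weights but build the model without the pre-logits head.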
486 | _logger.warning("Removing representation layer for fine-tuning.") 487 | repr_size = None 488 | 489 | model = build_model_with_cfg( 490 | VisionTransformer, variant, pretrained, 491 | default_cfg=default_cfg, 492 | representation_size=repr_size, 493 | pretrained_filter_fn=checkpoint_filter_fn, 494 | pretrained_custom_load='npz' in default_cfg['url'], 495 | **kwargs) 496 | return model 497 | 498 | 499 | @register_model 500 | def vit_tiny_patch16_224(pretrained=False, **kwargs): 501 | """ ViT-Tiny (Vit-Ti/16) 502 | """ 503 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 504 | model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 505 | return model 506 | @register_model 507 | def vit_tiny_patch16_112(pretrained=False, **kwargs): 508 | """ ViT-Tiny (Vit-Ti/16) 509 | """ 510 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 511 | model = _create_vision_transformer('vit_tiny_patch16_112', pretrained=pretrained, **model_kwargs) 512 | return model 513 | @register_model 514 | def vit_tiny_patch16_149(pretrained=False, **kwargs): 515 | """ ViT-Tiny (Vit-Ti/16) 516 | """ 517 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 518 | model = _create_vision_transformer('vit_tiny_patch16_149', pretrained=pretrained, **model_kwargs) 519 | return model 520 | @register_model 521 | def vit_tiny_patch16_89(pretrained=False, **kwargs): 522 | """ ViT-Tiny (Vit-Ti/16) 523 | """ 524 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 525 | model = _create_vision_transformer('vit_tiny_patch16_89', pretrained=pretrained, **model_kwargs) 526 | return model 527 | @register_model 528 | def vit_tiny_patch16_74(pretrained=False, **kwargs): 529 | """ ViT-Tiny (Vit-Ti/16) 530 | """ 531 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 532 | model = _create_vision_transformer('vit_tiny_patch16_74', pretrained=pretrained, **model_kwargs) 533 | return model 534 | 535 | @register_model 536 | def vit_tiny_patch16_64(pretrained=False, **kwargs): 537 | """ ViT-Tiny (Vit-Ti/16) 538 | """ 539 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 540 | model = _create_vision_transformer('vit_tiny_patch16_64', pretrained=pretrained, **model_kwargs) 541 | return model 542 | 543 | 544 | @register_model 545 | def vit_tiny_patch16_56(pretrained=False, **kwargs): 546 | """ ViT-Tiny (Vit-Ti/16) 547 | """ 548 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 549 | model = _create_vision_transformer('vit_tiny_patch16_56', pretrained=pretrained, **model_kwargs) 550 | return model 551 | 552 | 553 | @register_model 554 | def vit_base_patch16_224(pretrained=False, **kwargs): 555 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 556 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 
557 | """ 558 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 559 | model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) 560 | return model 561 | 562 | 563 | @register_model 564 | def vit_small_patch16_224(pretrained=False, **kwargs): 565 | """ ViT-Small (ViT-S/16) 566 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 567 | """ 568 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 569 | model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) 570 | return model 571 | 572 | @register_model 573 | def vit_small_patch16_112(pretrained=False, **kwargs): 574 | """ ViT-Small (ViT-S/16) 575 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 576 | """ 577 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 578 | model = _create_vision_transformer('vit_small_patch16_112', pretrained=pretrained, **model_kwargs) 579 | return model 580 | 581 | @register_model 582 | def vit_small_patch16_56(pretrained=False, **kwargs): 583 | """ ViT-Small (ViT-S/16) 584 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 585 | """ 586 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 587 | model = _create_vision_transformer('vit_small_patch16_56', pretrained=pretrained, **model_kwargs) 588 | return model 589 | 590 | 591 | if __name__ == "__main__": 592 | import torchvision.transforms as transforms 593 | from PIL import Image 594 | from timm.models import create_model 595 | # import os, sys 596 | # sys.path.insert(0, '../../') 597 | # import argparse 598 | # import torch.backends.cudnn as cudnn 599 | # from lib.config.default import update_config 600 | # from lib.config.default import config as cfg 601 | # from lib.utils import fix_random_seed 602 | 603 | # parser = argparse.ArgumentParser(description='knowledge distillation') 604 | # parser.add_argument('--config_file', type=str, default='../../configs/cub/cub_vit_single_224.yaml', 605 | # required=False, help='Optional config file for params') 606 | # parser.add_argument('opts', help='see config.py for all options', 607 | # default='BASIC.GPU_ID [7]', nargs=argparse.REMAINDER) 608 | # args = parser.parse_args() 609 | # update_config(args) 610 | 611 | # # fix random seed 612 | # fix_random_seed(cfg.BASIC.SEED) 613 | # # cudnn 614 | # cudnn.benchmark = cfg.CUDNN.BENCHMARK # Benchmark will impove the speed 615 | # cudnn.deterministic = cfg.CUDNN.DETERMINISTIC # 616 | # cudnn.enabled = cfg.CUDNN.ENABLE # Enables benchmark mode in cudnn, to enable the inbuilt cudnn auto-tuner 617 | 618 | # cfg.BASIC.ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..') 619 | 620 | input = torch.ones(2, 3, 112, 112) 621 | model = create_model( 622 | 'vit_tiny_patch16_112', 623 | pretrained=True, 624 | num_classes=100) 625 | 626 | output = model(input, return_feat=True) 627 | print(output.shape) 628 | print(output) -------------------------------------------------------------------------------- /lib/models/vit.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of Vision Transformers as described in: 4 | 5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image 
Recognition at Scale' 6 | - https://arxiv.org/abs/2010.11929 7 | 8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` 9 | - https://arxiv.org/abs/2106.10270 10 | 11 | The official jax code is released and available at https://github.com/google-research/vision_transformer 12 | 13 | DeiT model defs and weights from https://github.com/facebookresearch/deit, 14 | paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 15 | 16 | Acknowledgments: 17 | * The paper authors for releasing code and weights, thanks! 18 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out 19 | for some einops/einsum fun 20 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 21 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 22 | 23 | Hacked together by / Copyright 2021 Ross Wightman 24 | """ 25 | import math 26 | import logging 27 | from functools import partial 28 | from collections import OrderedDict 29 | from copy import deepcopy 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | 35 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 36 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 37 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 38 | from timm.models.registry import register_model 39 | 40 | _logger = logging.getLogger(__name__) 41 | 42 | 43 | def _cfg(url='', **kwargs): 44 | return { 45 | 'url': url, 46 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 47 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 48 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 49 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 50 | **kwargs 51 | } 52 | 53 | 54 | default_cfgs = { 55 | # patch models (weights from official Google JAX impl) 56 | 'vit_tiny_patch16_224': _cfg( 57 | url='https://storage.googleapis.com/vit_models/augreg/' 58 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 59 | 'vit_tiny_patch16_112': _cfg( 60 | url='https://storage.googleapis.com/vit_models/augreg/' 61 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 62 | input_size=(3, 112, 112)), 63 | 'vit_tiny_patch16_56': _cfg( 64 | url='https://storage.googleapis.com/vit_models/augreg/' 65 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 66 | input_size=(3, 56, 56)), 67 | 68 | 69 | 'vit_tiny_patch16_384': _cfg( 70 | url='https://storage.googleapis.com/vit_models/augreg/' 71 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 72 | input_size=(3, 384, 384), crop_pct=1.0), 73 | 'vit_small_patch32_224': _cfg( 74 | url='https://storage.googleapis.com/vit_models/augreg/' 75 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 76 | 'vit_small_patch32_384': _cfg( 77 | url='https://storage.googleapis.com/vit_models/augreg/' 78 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 79 | input_size=(3, 384, 384), crop_pct=1.0), 80 | 'vit_small_patch16_224': _cfg( 81 | 
url='https://storage.googleapis.com/vit_models/augreg/' 82 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 83 | 'vit_small_patch16_384': _cfg( 84 | url='https://storage.googleapis.com/vit_models/augreg/' 85 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 86 | input_size=(3, 384, 384), crop_pct=1.0), 87 | 'vit_base_patch32_224': _cfg( 88 | url='https://storage.googleapis.com/vit_models/augreg/' 89 | 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 90 | 'vit_base_patch32_384': _cfg( 91 | url='https://storage.googleapis.com/vit_models/augreg/' 92 | 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 93 | input_size=(3, 384, 384), crop_pct=1.0), 94 | 'vit_base_patch16_224': _cfg( 95 | url='https://storage.googleapis.com/vit_models/augreg/' 96 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 97 | 'vit_base_patch16_384': _cfg( 98 | url='https://storage.googleapis.com/vit_models/augreg/' 99 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', 100 | input_size=(3, 384, 384), crop_pct=1.0), 101 | 'vit_large_patch32_224': _cfg( 102 | url='', # no official model weights for this combo, only for in21k 103 | ), 104 | 'vit_large_patch32_384': _cfg( 105 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 106 | input_size=(3, 384, 384), crop_pct=1.0), 107 | 'vit_large_patch16_224': _cfg( 108 | url='https://storage.googleapis.com/vit_models/augreg/' 109 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 110 | 'vit_large_patch16_384': _cfg( 111 | url='https://storage.googleapis.com/vit_models/augreg/' 112 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', 113 | input_size=(3, 384, 384), crop_pct=1.0), 114 | 115 | # patch models, imagenet21k (weights from official Google JAX impl) 116 | 'vit_tiny_patch16_224_in21k': _cfg( 117 | url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', 118 | num_classes=21843), 119 | 'vit_small_patch32_224_in21k': _cfg( 120 | url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 121 | num_classes=21843), 122 | 'vit_small_patch16_224_in21k': _cfg( 123 | url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 124 | num_classes=21843), 125 | 'vit_base_patch32_224_in21k': _cfg( 126 | url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', 127 | num_classes=21843), 128 | 'vit_base_patch16_224_in21k': _cfg( 129 | url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 130 | num_classes=21843), 131 | 'vit_large_patch32_224_in21k': _cfg( 132 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', 133 | num_classes=21843), 134 | 'vit_large_patch16_224_in21k': _cfg( 135 | url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', 136 
| num_classes=21843), 137 | 'vit_huge_patch14_224_in21k': _cfg( 138 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', 139 | hf_hub='timm/vit_huge_patch14_224_in21k', 140 | num_classes=21843), 141 | 142 | # deit models (FB weights) 143 | 'deit_tiny_patch16_224': _cfg( 144 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', 145 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 146 | 'deit_small_patch16_224': _cfg( 147 | url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', 148 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 149 | 'deit_base_patch16_224': _cfg( 150 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', 151 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 152 | 'deit_base_patch16_384': _cfg( 153 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', 154 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), 155 | 'deit_tiny_distilled_patch16_224': _cfg( 156 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', 157 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 158 | 'deit_small_distilled_patch16_224': _cfg( 159 | url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', 160 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 161 | 'deit_base_distilled_patch16_224': _cfg( 162 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', 163 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 164 | 'deit_base_distilled_patch16_384': _cfg( 165 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', 166 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, 167 | classifier=('head', 'head_dist')), 168 | 169 | # ViT ImageNet-21K-P pretraining by MILL 170 | 'vit_base_patch16_224_miil_in21k': _cfg( 171 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', 172 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, 173 | ), 174 | 'vit_base_patch16_224_miil': _cfg( 175 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' 176 | '/vit_base_patch16_224_1k_miil_84_4.pth', 177 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', 178 | ), 179 | 180 | 181 | # size 448, use pretrained model of size 384 182 | 'vit_base_patch16_448': _cfg( 183 | url='https://storage.googleapis.com/vit_models/augreg/' 184 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', 185 | input_size=(3, 448, 448), crop_pct=1.0), 186 | 'deit_base_patch16_448': _cfg( 187 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', 188 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 448, 448), crop_pct=1.0), 189 | 'deit_base_distilled_patch16_448': _cfg( 190 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', 191 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 448, 448), crop_pct=1.0, 192 | classifier=('head', 'head_dist')), 193 | } 194 | 195 | 196 | class Attention(nn.Module): 197 | def __init__(self, dim, 
num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 198 | super().__init__() 199 | self.num_heads = num_heads 200 | head_dim = dim // num_heads 201 | self.scale = head_dim ** -0.5 202 | 203 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 204 | self.attn_drop = nn.Dropout(attn_drop) 205 | self.proj = nn.Linear(dim, dim) 206 | self.proj_drop = nn.Dropout(proj_drop) 207 | 208 | def forward(self, x): 209 | B, N, C = x.shape 210 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 211 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 212 | 213 | attn = (q @ k.transpose(-2, -1)) * self.scale 214 | attn = attn.softmax(dim=-1) 215 | attn = self.attn_drop(attn) 216 | 217 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 218 | x = self.proj(x) 219 | x = self.proj_drop(x) 220 | return x 221 | 222 | 223 | class Block(nn.Module): 224 | 225 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 226 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 227 | super().__init__() 228 | self.norm1 = norm_layer(dim) 229 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 230 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 231 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 232 | self.norm2 = norm_layer(dim) 233 | mlp_hidden_dim = int(dim * mlp_ratio) 234 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 235 | 236 | def forward(self, x): 237 | x = x + self.drop_path(self.attn(self.norm1(x))) 238 | x = x + self.drop_path(self.mlp(self.norm2(x))) 239 | return x 240 | 241 | 242 | class VisionTransformer(nn.Module): 243 | """ Vision Transformer 244 | 245 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 246 | - https://arxiv.org/abs/2010.11929 247 | 248 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 249 | - https://arxiv.org/abs/2012.12877 250 | """ 251 | 252 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 253 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 254 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 255 | act_layer=None, weight_init=''): 256 | """ 257 | Args: 258 | img_size (int, tuple): input image size 259 | patch_size (int, tuple): patch size 260 | in_chans (int): number of input channels 261 | num_classes (int): number of classes for classification head 262 | embed_dim (int): embedding dimension 263 | depth (int): depth of transformer 264 | num_heads (int): number of attention heads 265 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 266 | qkv_bias (bool): enable bias for qkv if True 267 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 268 | distilled (bool): model includes a distillation token and head as in DeiT models 269 | drop_rate (float): dropout rate 270 | attn_drop_rate (float): attention dropout rate 271 | drop_path_rate (float): stochastic depth rate 272 | embed_layer (nn.Module): patch embedding layer 273 | norm_layer: (nn.Module): normalization layer 274 | weight_init: (str): weight init scheme 275 | """ 276 | super().__init__() 277 | self.num_classes = num_classes 278 | 
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 279 | self.num_tokens = 2 if distilled else 1 280 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 281 | act_layer = act_layer or nn.GELU 282 | 283 | self.patch_embed = embed_layer( 284 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 285 | num_patches = self.patch_embed.num_patches 286 | 287 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 288 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 289 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 290 | self.pos_drop = nn.Dropout(p=drop_rate) 291 | 292 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 293 | self.blocks = nn.Sequential(*[ 294 | Block( 295 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 296 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 297 | for i in range(depth)]) 298 | self.norm = norm_layer(embed_dim) 299 | 300 | # Representation layer 301 | if representation_size and not distilled: 302 | self.num_features = representation_size 303 | self.pre_logits = nn.Sequential(OrderedDict([ 304 | ('fc', nn.Linear(embed_dim, representation_size)), 305 | ('act', nn.Tanh()) 306 | ])) 307 | else: 308 | self.pre_logits = nn.Identity() 309 | 310 | # Classifier head(s) 311 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 312 | self.head_dist = None 313 | if distilled: 314 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 315 | 316 | self.init_weights(weight_init) 317 | 318 | def init_weights(self, mode=''): 319 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 320 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
321 | trunc_normal_(self.pos_embed, std=.02) 322 | if self.dist_token is not None: 323 | trunc_normal_(self.dist_token, std=.02) 324 | if mode.startswith('jax'): 325 | # leave cls token as zeros to match jax impl 326 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) 327 | else: 328 | trunc_normal_(self.cls_token, std=.02) 329 | self.apply(_init_vit_weights) 330 | 331 | def _init_weights(self, m): 332 | # this fn left here for compat with downstream users 333 | _init_vit_weights(m) 334 | 335 | @torch.jit.ignore() 336 | def load_pretrained(self, checkpoint_path, prefix=''): 337 | _load_weights(self, checkpoint_path, prefix) 338 | 339 | @torch.jit.ignore 340 | def no_weight_decay(self): 341 | return {'pos_embed', 'cls_token', 'dist_token'} 342 | 343 | def get_classifier(self): 344 | if self.dist_token is None: 345 | return self.head 346 | else: 347 | return self.head, self.head_dist 348 | 349 | def reset_classifier(self, num_classes, global_pool=''): 350 | self.num_classes = num_classes 351 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 352 | if self.num_tokens == 2: 353 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 354 | 355 | def forward_features(self, x): 356 | x = self.patch_embed(x) 357 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 358 | if self.dist_token is None: 359 | x = torch.cat((cls_token, x), dim=1) 360 | else: 361 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 362 | x = self.pos_drop(x + self.pos_embed) 363 | x = self.blocks(x) 364 | x = self.norm(x) 365 | if self.dist_token is None: 366 | return self.pre_logits(x[:, 0]) 367 | else: 368 | return x[:, 0], x[:, 1] 369 | 370 | def forward(self, x, return_feat=False, preReLU=False, return_final=False): 371 | x = self.forward_features(x) 372 | if self.head_dist is not None: 373 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 374 | if self.training and not torch.jit.is_scripting(): 375 | # during training, return both classifier predictions so each head can be supervised separately 376 | return x, x_dist 377 | else: 378 | return (x + x_dist) / 2 # during inference, return the average of both classifier predictions 379 | else: 380 | x = self.head(x) 381 | return x 382 | 383 | 384 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): 385 | """ ViT weight initialization 386 | * When called without n, head_bias, jax_impl args it will behave exactly the same 387 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
388 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl 389 | """ 390 | if isinstance(module, nn.Linear): 391 | if name.startswith('head'): 392 | nn.init.zeros_(module.weight) 393 | nn.init.constant_(module.bias, head_bias) 394 | elif name.startswith('pre_logits'): 395 | lecun_normal_(module.weight) 396 | nn.init.zeros_(module.bias) 397 | else: 398 | if jax_impl: 399 | nn.init.xavier_uniform_(module.weight) 400 | if module.bias is not None: 401 | if 'mlp' in name: 402 | nn.init.normal_(module.bias, std=1e-6) 403 | else: 404 | nn.init.zeros_(module.bias) 405 | else: 406 | trunc_normal_(module.weight, std=.02) 407 | if module.bias is not None: 408 | nn.init.zeros_(module.bias) 409 | elif jax_impl and isinstance(module, nn.Conv2d): 410 | # NOTE conv was left to pytorch default in my original init 411 | lecun_normal_(module.weight) 412 | if module.bias is not None: 413 | nn.init.zeros_(module.bias) 414 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): 415 | nn.init.zeros_(module.bias) 416 | nn.init.ones_(module.weight) 417 | 418 | 419 | @torch.no_grad() 420 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 421 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 422 | """ 423 | import numpy as np 424 | 425 | def _n2p(w, t=True): 426 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 427 | w = w.flatten() 428 | if t: 429 | if w.ndim == 4: 430 | w = w.transpose([3, 2, 0, 1]) 431 | elif w.ndim == 3: 432 | w = w.transpose([2, 0, 1]) 433 | elif w.ndim == 2: 434 | w = w.transpose([1, 0]) 435 | return torch.from_numpy(w) 436 | 437 | w = np.load(checkpoint_path) 438 | if not prefix and 'opt/target/embedding/kernel' in w: 439 | prefix = 'opt/target/' 440 | 441 | if hasattr(model.patch_embed, 'backbone'): 442 | # hybrid 443 | backbone = model.patch_embed.backbone 444 | stem_only = not hasattr(backbone, 'stem') 445 | stem = backbone if stem_only else backbone.stem 446 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 447 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 448 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 449 | if not stem_only: 450 | for i, stage in enumerate(backbone.stages): 451 | for j, block in enumerate(stage.blocks): 452 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 453 | for r in range(3): 454 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 455 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 456 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 457 | if block.downsample is not None: 458 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 459 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 460 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 461 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 462 | else: 463 | embed_conv_w = adapt_input_conv( 464 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 465 | model.patch_embed.proj.weight.copy_(embed_conv_w) 466 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 467 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 468 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 469 | if pos_embed_w.shape != model.pos_embed.shape: 470 | pos_embed_w = resize_pos_embed( # resize pos 
embedding when different size from pretrained weights 471 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 472 | model.pos_embed.copy_(pos_embed_w) 473 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 474 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 475 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 476 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 477 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 478 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 479 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 480 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 481 | for i, block in enumerate(model.blocks.children()): 482 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 483 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 484 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 485 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 486 | block.attn.qkv.weight.copy_(torch.cat([ 487 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 488 | block.attn.qkv.bias.copy_(torch.cat([ 489 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 490 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 491 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 492 | for r in range(2): 493 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 494 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 495 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 496 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 497 | 498 | 499 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 500 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 501 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 502 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 503 | ntok_new = posemb_new.shape[1] 504 | if num_tokens: 505 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 506 | ntok_new -= num_tokens 507 | else: 508 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 509 | gs_old = int(math.sqrt(len(posemb_grid))) 510 | if not len(gs_new): # backwards compatibility 511 | gs_new = [int(math.sqrt(ntok_new))] * 2 512 | assert len(gs_new) >= 2 513 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) 514 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 515 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') 516 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 517 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 518 | return posemb 519 | 520 | 521 | def checkpoint_filter_fn(state_dict, model): 522 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 523 | out_dict = {} 524 | if 'model' in state_dict: 525 | # For deit models 526 | state_dict = state_dict['model'] 527 | for k, v in state_dict.items(): 528 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 529 | # For old models that I trained prior to conv based patchification 530 | O, I, H, W = model.patch_embed.proj.weight.shape 531 | v = v.reshape(O, -1, H, W) 532 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 533 | # To resize pos embedding when using model at different size from pretrained weights 534 | v = resize_pos_embed( 535 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 536 | out_dict[k] = v 537 | return out_dict 538 | 539 | 540 | def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): 541 | default_cfg = default_cfg or default_cfgs[variant] 542 | if kwargs.get('features_only', None): 543 | raise RuntimeError('features_only not implemented for Vision Transformer models.') 544 | 545 | # NOTE this extra code to support handling of repr size for in21k pretrained models 546 | default_num_classes = default_cfg['num_classes'] 547 | num_classes = kwargs.get('num_classes', default_num_classes) 548 | repr_size = kwargs.pop('representation_size', None) 549 | if repr_size is not None and num_classes != default_num_classes: 550 | # Remove representation layer if fine-tuning. This may not always be the desired action, 551 | # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? 
552 | _logger.warning("Removing representation layer for fine-tuning.") 553 | repr_size = None 554 | 555 | model = build_model_with_cfg( 556 | VisionTransformer, variant, pretrained, 557 | default_cfg=default_cfg, 558 | representation_size=repr_size, 559 | pretrained_filter_fn=checkpoint_filter_fn, 560 | pretrained_custom_load='npz' in default_cfg['url'], 561 | **kwargs) 562 | return model 563 | 564 | 565 | @register_model 566 | def vit_tiny_patch16_224(pretrained=False, **kwargs): 567 | """ ViT-Tiny (Vit-Ti/16) 568 | """ 569 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 570 | model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 571 | return model 572 | @register_model 573 | def vit_tiny_patch16_112(pretrained=False, **kwargs): 574 | """ ViT-Tiny (Vit-Ti/16) 575 | """ 576 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 577 | model = _create_vision_transformer('vit_tiny_patch16_112', pretrained=pretrained, **model_kwargs) 578 | return model 579 | @register_model 580 | def vit_tiny_patch16_56(pretrained=False, **kwargs): 581 | """ ViT-Tiny (Vit-Ti/16) 582 | """ 583 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 584 | model = _create_vision_transformer('vit_tiny_patch16_56', pretrained=pretrained, **model_kwargs) 585 | return model 586 | 587 | 588 | @register_model 589 | def vit_tiny_patch16_384(pretrained=False, **kwargs): 590 | """ ViT-Tiny (Vit-Ti/16) @ 384x384. 591 | """ 592 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 593 | model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) 594 | return model 595 | 596 | 597 | @register_model 598 | def vit_small_patch32_224(pretrained=False, **kwargs): 599 | """ ViT-Small (ViT-S/32) 600 | """ 601 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 602 | model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) 603 | return model 604 | 605 | 606 | @register_model 607 | def vit_small_patch32_384(pretrained=False, **kwargs): 608 | """ ViT-Small (ViT-S/32) at 384x384. 
609 | """ 610 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 611 | model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) 612 | return model 613 | 614 | 615 | @register_model 616 | def vit_small_patch16_224(pretrained=False, **kwargs): 617 | """ ViT-Small (ViT-S/16) 618 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 619 | """ 620 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 621 | model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) 622 | return model 623 | 624 | 625 | @register_model 626 | def vit_small_patch16_384(pretrained=False, **kwargs): 627 | """ ViT-Small (ViT-S/16) 628 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 629 | """ 630 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 631 | model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) 632 | return model 633 | 634 | 635 | @register_model 636 | def vit_base_patch32_224(pretrained=False, **kwargs): 637 | """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 638 | """ 639 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 640 | model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) 641 | return model 642 | 643 | 644 | @register_model 645 | def vit_base_patch32_384(pretrained=False, **kwargs): 646 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 647 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 648 | """ 649 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 650 | model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) 651 | return model 652 | 653 | 654 | @register_model 655 | def vit_base_patch16_224(pretrained=False, **kwargs): 656 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 657 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 658 | """ 659 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 660 | model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) 661 | return model 662 | 663 | 664 | @register_model 665 | def vit_base_patch16_384(pretrained=False, **kwargs): 666 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 667 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 668 | """ 669 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 670 | model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) 671 | return model 672 | 673 | 674 | @register_model 675 | def vit_base_patch16_448(pretrained=False, **kwargs): 676 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 677 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
678 | """ 679 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 680 | model = _create_vision_transformer('vit_base_patch16_448', pretrained=pretrained, **model_kwargs) 681 | return model 682 | 683 | 684 | @register_model 685 | def vit_large_patch32_224(pretrained=False, **kwargs): 686 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 687 | """ 688 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 689 | model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) 690 | return model 691 | 692 | 693 | @register_model 694 | def vit_large_patch32_384(pretrained=False, **kwargs): 695 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 696 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 697 | """ 698 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 699 | model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) 700 | return model 701 | 702 | 703 | @register_model 704 | def vit_large_patch16_224(pretrained=False, **kwargs): 705 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 706 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 707 | """ 708 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 709 | model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) 710 | return model 711 | 712 | 713 | @register_model 714 | def vit_large_patch16_384(pretrained=False, **kwargs): 715 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 716 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 717 | """ 718 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 719 | model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) 720 | return model 721 | 722 | 723 | @register_model 724 | def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): 725 | """ ViT-Tiny (Vit-Ti/16). 726 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 727 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 728 | """ 729 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 730 | model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 731 | return model 732 | 733 | 734 | @register_model 735 | def vit_small_patch32_224_in21k(pretrained=False, **kwargs): 736 | """ ViT-Small (ViT-S/16) 737 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
738 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 739 | """ 740 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 741 | model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 742 | return model 743 | 744 | 745 | @register_model 746 | def vit_small_patch16_224_in21k(pretrained=False, **kwargs): 747 | """ ViT-Small (ViT-S/16) 748 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 749 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 750 | """ 751 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 752 | model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 753 | return model 754 | 755 | 756 | @register_model 757 | def vit_base_patch32_224_in21k(pretrained=False, **kwargs): 758 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 759 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 760 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 761 | """ 762 | model_kwargs = dict( 763 | patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 764 | model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 765 | return model 766 | 767 | 768 | @register_model 769 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs): 770 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 771 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 772 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 773 | """ 774 | model_kwargs = dict( 775 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 776 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 777 | return model 778 | 779 | 780 | @register_model 781 | def vit_large_patch32_224_in21k(pretrained=False, **kwargs): 782 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 783 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 784 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights 785 | """ 786 | model_kwargs = dict( 787 | patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) 788 | model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 789 | return model 790 | 791 | 792 | @register_model 793 | def vit_large_patch16_224_in21k(pretrained=False, **kwargs): 794 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 795 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
796 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 797 | """ 798 | model_kwargs = dict( 799 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 800 | model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 801 | return model 802 | 803 | 804 | @register_model 805 | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): 806 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 807 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 808 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights 809 | """ 810 | model_kwargs = dict( 811 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) 812 | model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) 813 | return model 814 | 815 | 816 | @register_model 817 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 818 | """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 819 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 820 | """ 821 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 822 | model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 823 | return model 824 | 825 | 826 | @register_model 827 | def deit_small_patch16_224(pretrained=False, **kwargs): 828 | """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 829 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 830 | """ 831 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 832 | model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) 833 | return model 834 | 835 | 836 | @register_model 837 | def deit_base_patch16_224(pretrained=False, **kwargs): 838 | """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 839 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 840 | """ 841 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 842 | model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) 843 | return model 844 | 845 | 846 | @register_model 847 | def deit_base_patch16_384(pretrained=False, **kwargs): 848 | """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). 849 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 850 | """ 851 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 852 | model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs) 853 | return model 854 | 855 | 856 | @register_model 857 | def deit_base_patch16_448(pretrained=False, **kwargs): 858 | """ DeiT base model @ 448x448, reusing the 384x384 pretrained weights, from paper (https://arxiv.org/abs/2012.12877). 859 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
860 | """ 861 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 862 | model = _create_vision_transformer('deit_base_patch16_448', pretrained=pretrained, **model_kwargs) 863 | return model 864 | 865 | 866 | @register_model 867 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): 868 | """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 869 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 870 | """ 871 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 872 | model = _create_vision_transformer( 873 | 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 874 | return model 875 | 876 | 877 | @register_model 878 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs): 879 | """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 880 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 881 | """ 882 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 883 | model = _create_vision_transformer( 884 | 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 885 | return model 886 | 887 | 888 | @register_model 889 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs): 890 | """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 891 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 892 | """ 893 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 894 | model = _create_vision_transformer( 895 | 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 896 | return model 897 | 898 | 899 | @register_model 900 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs): 901 | """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). 902 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 903 | """ 904 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 905 | model = _create_vision_transformer( 906 | 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) 907 | return model 908 | 909 | 910 | @register_model 911 | def deit_base_distilled_patch16_448(pretrained=False, **kwargs): 912 | """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). 913 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 914 | """ 915 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 916 | model = _create_vision_transformer( 917 | 'deit_base_distilled_patch16_448', pretrained=pretrained, distilled=True, **model_kwargs) 918 | return model 919 | 920 | 921 | @register_model 922 | def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs): 923 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
924 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K 925 | """ 926 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) 927 | model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) 928 | return model 929 | 930 | 931 | @register_model 932 | def vit_base_patch16_224_miil(pretrained=False, **kwargs): 933 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 934 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K 935 | """ 936 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) 937 | model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) 938 | return model 939 | 940 | 941 | 942 | 943 | if __name__ == "__main__": 944 | import torchvision.transforms as transforms 945 | from PIL import Image 946 | from timm.models import create_model 947 | import os, sys 948 | sys.path.insert(0, '../../lib') 949 | import argparse 950 | import torch.backends.cudnn as cudnn 951 | from config.default import update_config 952 | from config.default import config as cfg 953 | from utils import fix_random_seed 954 | 955 | parser = argparse.ArgumentParser(description='knowledge distillation') 956 | parser.add_argument('--config_file', type=str, default='../../configs/cub/cub_vit.yaml', 957 | required=False, help='Optional config file for params') 958 | parser.add_argument('opts', help='see config.py for all options', 959 | default='BASIC.GPU_ID [7]', nargs=argparse.REMAINDER) 960 | args = parser.parse_args() 961 | update_config(args) 962 | 963 | # fix random seed 964 | fix_random_seed(cfg.BASIC.SEED) 965 | # cudnn 966 | cudnn.benchmark = cfg.CUDNN.BENCHMARK # benchmark mode lets cuDNN auto-tune convolution algorithms, which can improve speed 967 | cudnn.deterministic = cfg.CUDNN.DETERMINISTIC # force deterministic cuDNN algorithms when enabled 968 | cudnn.enabled = cfg.CUDNN.ENABLE # enable/disable the cuDNN backend 969 | 970 | cfg.BASIC.ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..') 971 | 972 | input = torch.ones(2, 3, 448, 448) 973 | test_transform = transforms.Compose([ 974 | transforms.Resize((448, 448)), 975 | transforms.ToTensor(), 976 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 977 | ]) 978 | image = Image.open('../../demo/demo.JPEG').convert('RGB') 979 | image = test_transform(image) 980 | input[0] = image 981 | input[1] = image 982 | 983 | model = create_model( 984 | 'vit_base_patch16_448', 985 | pretrained=cfg.MODEL.PRETRAIN, 986 | num_classes=cfg.DATA.NUM_CLASSES) 987 | 988 | 989 | output = model(input) 990 | print(output.shape) 991 | print(output) --------------------------------------------------------------------------------
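Illustrative note (not part of the repository source): the *_448 model entries above reuse 384-resolution checkpoints, which works because resize_pos_embed bilinearly interpolates the patch-position grid to the new input size at load time. The sketch below uses a hypothetical helper name, demo_resize_pos_embed, and mirrors that logic for ViT-B/16 going from 384x384 input (24x24 patches) to 448x448 (28x28 patches), assuming a single class token.

import math
import torch
import torch.nn.functional as F

def demo_resize_pos_embed(posemb, gs_new, num_tokens=1):
    # posemb: (1, num_tokens + gs_old*gs_old, dim) position table from a pretrained checkpoint
    posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
    gs_old = int(math.sqrt(posemb_grid.shape[0]))
    # reshape the flat patch grid to 2-D, bilinearly interpolate, then flatten back
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear', align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    return torch.cat([posemb_tok, posemb_grid], dim=1)

pretrained_posemb = torch.zeros(1, 1 + 24 * 24, 768)  # ViT-B/16 @ 384: class token + 24x24 patch positions
resized_posemb = demo_resize_pos_embed(pretrained_posemb, gs_new=(28, 28))
print(resized_posemb.shape)  # torch.Size([1, 785, 768]): class token + 28x28 patch positions @ 448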