├── doctor ├── __init__.py ├── create_mask.py ├── labelme_to_mask.py ├── get_all_activation.py ├── get_linear.py ├── get_class_activation.py ├── test.py ├── get_all_attention.py ├── sample_sift.py ├── get_attention.py ├── train.py ├── get_all_channel.py ├── get_grads.py ├── test_mask.py ├── train_1k.py ├── train_lrp.py ├── train_att.py ├── train_grad.py ├── constraint.py └── grad_calculate.py ├── utils ├── __init__.py ├── file_util.py ├── train_util.py ├── image_util.py ├── dist.py ├── fig_util.py └── cam.py ├── framewrok.png ├── metrics ├── __init__.py ├── accuracy.py └── ildv │ └── __init__.py ├── get_linear.sh ├── get_all_activation.sh ├── get_attention.sh ├── get_class_activation.sh ├── loaders ├── __init__.py ├── image_enhance.py ├── image_dataset.py └── image_loader.py ├── get_grads.sh ├── get_all_attention.sh ├── test_mask.sh ├── test.sh ├── train_1k.sh ├── get_all_channel.sh ├── grad_calculate.sh ├── train.sh ├── sample_sift.sh ├── criterions └── cross_entropy.py ├── train_lrp.sh ├── train_plus.sh ├── train_att.sh ├── train_grad.sh ├── evaluator └── default.py ├── README.md ├── feature_sift.sh ├── engines ├── test.py └── train.py ├── models ├── __init__.py ├── vit.py └── tnt.py └── .gitignore /doctor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /framewrok.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaconghu/Transformer-Doctor/HEAD/framewrok.png -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from metrics.accuracy import accuracy 2 | from metrics.accuracy import ClassAccuracy 3 | import torch 4 | 5 | class Metric: 6 | def __init__(self): 7 | super().__init__() 8 | self.device = torch.device('cpu') 9 | 10 | def __call__(self, *args, **kwargs): 11 | return self 12 | 13 | def update(self, outputs, labels): 14 | pass 15 | 16 | def compute(self): 17 | pass 18 | 19 | def to(self, device): 20 | return self -------------------------------------------------------------------------------- /get_linear.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=[Your custom-defined path] 3 | export CUDA_VISIBLE_DEVICES=1 4 | export result_path='[Your custom-defined path]' 5 | #export exp_name='vgg16d_cifar10_mu' 6 | #export model_name='vgg16d' 7 | export exp_name='vit_high_test' 8 | export model_name='vit_linear' 9 | export data_name='cifar10' 10 | export num_classes=10 11 | export model_path='[Your custom-defined path]' 12 | export data_dir='[Your custom-defined path]' 13 | export save_dir=${result_path}'/'${exp_name} 14 | python doctor/get_linear.py \ 15 | --model_name ${model_name} \ 16 | --data_name ${data_name} \ 17 | --num_classes ${num_classes} \ 18 | --model_path ${model_path} \ 19 | --data_dir ${data_dir} \ 20 | --save_dir ${save_dir} 21 | -------------------------------------------------------------------------------- /utils/file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | def walk_file(path): 6 | count = 0 
7 | for root, dirs, files in os.walk(path): 8 | print(root) 9 | 10 | for f in files: 11 | count += 1 12 | # print(os.path.join(root, f)) 13 | 14 | for d in dirs: 15 | print(os.path.join(root, d)) 16 | print(count) 17 | 18 | 19 | def count_files(path): 20 | for root, dirs, files in os.walk(path): 21 | print(root, len(files)) 22 | 23 | 24 | def copy_file(src, dst): 25 | path, name = os.path.split(dst) 26 | if not os.path.exists(path): 27 | os.makedirs(path) 28 | shutil.copyfile(src, dst) 29 | 30 | 31 | if __name__ == '__main__': 32 | count_files('') 33 | -------------------------------------------------------------------------------- /get_all_activation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=[Your custom-defined path] 3 | export CUDA_VISIBLE_DEVICES=0 4 | export result_path='[Your custom-defined path]' 5 | export exp_name='ideal' 6 | export model_name='tvit' 7 | export data_name='imagenet' 8 | export num_classes=1000 9 | export model_path=${result_path}'/train_models/pth/'${data_name}'/tvit_'${data_name}'_base.pth' 10 | export data_dir=${result_path}'/'${exp_name}'/low_images/'${model_name}'/'${data_name} 11 | export save_dir=${result_path}'/tmp/garbage' 12 | python doctor/get_all_activation.py \ 13 | --model_name ${model_name} \ 14 | --data_name ${data_name} \ 15 | --num_classes ${num_classes} \ 16 | --model_path ${model_path} \ 17 | --data_dir ${data_dir} \ 18 | --save_dir ${save_dir} 19 | -------------------------------------------------------------------------------- /get_attention.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=[Your custom-defined path] 3 | export CUDA_VISIBLE_DEVICES=1 4 | export result_path='[Your custom-defined path]' 5 | export exp_name='ideal' 6 | export model_name='tvit' 7 | export data_name='imagenet' 8 | export num_classes=1000 9 | export model_path=${result_path}'/train_models/pth/'${data_name}'/tvit_'${data_name}'_all.pth' 10 | # export data_dir=${result_path}'/'${exp_name}'/low_images/tvit/'${data_name} 11 | export data_dir='[Your custom-defined path]' 12 | export save_dir=${result_path}'/tmp/attentions/tvit/npys' 13 | python doctor/get_attention.py \ 14 | --model_name ${model_name} \ 15 | --data_name ${data_name} \ 16 | --num_classes ${num_classes} \ 17 | --model_path ${model_path} \ 18 | --data_dir ${data_dir} \ 19 | --save_dir ${save_dir} -------------------------------------------------------------------------------- /get_class_activation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=[Your custom-defined path] 3 | export CUDA_VISIBLE_DEVICES=0 4 | export result_path='[Your custom-defined path]' 5 | export exp_name='ideal' 6 | export model_name='tvit' 7 | export data_name='imagenet' 8 | export num_classes=1000 9 | export model_path=${result_path}'/train_models/pth/'${data_name}'/'${model_name}'_'${data_name}'_base.pth' 10 | export data_dir=${result_path}'/'${exp_name}'/low_images/'${model_name}'/'${data_name} 11 | export save_dir=${result_path}'/tmp/class_activations/'${model_name}'/low/npys' 12 | python doctor/get_class_activation.py \ 13 | --model_name ${model_name} \ 14 | --data_name ${data_name} \ 15 | --num_classes ${num_classes} \ 16 | --model_path ${model_path} \ 17 | --data_dir ${data_dir} \ 18 | --save_dir ${save_dir} 19 | -------------------------------------------------------------------------------- 
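The launcher scripts above (and those that follow) all share one pattern: shell variables compose the experiment paths, then everything is forwarded to a `doctor/*.py` entry point as `argparse` flags. A minimal sketch of the receiving side is below; the flag names match the scripts, while the defaults and the final print are illustrative assumptions, not code from this repository.

```python
# Minimal argparse skeleton matching the flags passed by the *.sh launchers.
# Defaults and the trailing print are assumptions for illustration only.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='doctor entry point')
    parser.add_argument('--model_name', default='', type=str, help='model name')
    parser.add_argument('--data_name', default='', type=str, help='data name')
    parser.add_argument('--num_classes', default=10, type=int, help='num classes')
    parser.add_argument('--model_path', default='', type=str, help='model path')
    parser.add_argument('--data_dir', default='', type=str, help='data path')
    parser.add_argument('--save_dir', default='', type=str, help='save path')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    print(args.model_name, args.data_name, args.save_dir)  # the real scripts run their pipeline here
```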
/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from loaders.image_loader import load_images 2 | from loaders.image_loader import load_images_masks 3 | 4 | 5 | def load_data(data_dir, data_name, data_type, batch_size): 6 | print('-' * 50) 7 | print('DATA PATH:', data_dir) 8 | print('DATA NAME:', data_name) 9 | print('DATA TYPE:', data_type) 10 | print('-' * 50) 11 | 12 | return load_images(data_dir, data_name, data_type, batch_size) 13 | 14 | 15 | def load_data_mask(data_dir, data_name, mask_dir, data_type, batch_size): 16 | print('-' * 50) 17 | print('DATA PATH:', data_dir) 18 | print('DATA NAME:', data_name) 19 | print('MASK PATH:', mask_dir) 20 | print('DATA TYPE:', data_type) 21 | 22 | print('-' * 50) 23 | 24 | return load_images_masks(data_dir, data_name, mask_dir, data_type, batch_size) -------------------------------------------------------------------------------- /get_grads.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export result_path='[Your custom-defined path]' 3 | export exp_name='ideal' 4 | export model_name='tvit' 5 | export data_name='imagenet50' 6 | export in_channels=3 7 | export num_classes=50 8 | export model_path='[Your custom-defined path]' 9 | export data_dir='[Your custom-defined path]' 10 | # export grad_path=${result_path}${exp_name}'/visualize/'${data_name}'/grad' 11 | export grad_path='[Your custom-defined path]' 12 | export theta=0.4 13 | export device_index='0' 14 | python doctor/get_grads.py \ 15 | --model_name ${model_name} \ 16 | --data_name ${data_name} \ 17 | --in_channels ${in_channels} \ 18 | --num_classes ${num_classes} \ 19 | --model_path ${model_path} \ 20 | --data_path ${data_dir} \ 21 | --grad_path ${grad_path} \ 22 | --theta ${theta} \ 23 | --device_index ${device_index} 24 | -------------------------------------------------------------------------------- /get_all_attention.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=[Your custom-defined path] 3 | export CUDA_VISIBLE_DEVICES=1 4 | #export result_path='[Your custom-defined path]' 5 | #export exp_name='ideal' 6 | export model_name='tvit' 7 | export data_name='imagenet' 8 | export num_classes=1000 9 | export model_path=[Your custom-defined path] 10 | #export data_dir=[Your custom-defined path] 11 | #export save_dir=[Your custom-defined path] 12 | #export data_dir=[Your custom-defined path] 13 | #export save_dir=[Your custom-defined path] 14 | export data_dir=[Your custom-defined path] 15 | export save_dir=[Your custom-defined path] 16 | python doctor/get_all_attention.py \ 17 | --model_name ${model_name} \ 18 | --data_name ${data_name} \ 19 | --num_classes ${num_classes} \ 20 | --model_path ${model_path} \ 21 | --data_dir ${data_dir} \ 22 | --save_dir ${save_dir} 23 | -------------------------------------------------------------------------------- /test_mask.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="[Your custom-defined path]" 2 | export CUDA_VISIBLE_DEVICES=1 3 | export result_dir='[Your custom-defined path]/output' 4 | export exp_name='hjc_test' 5 | export data_name='imagenet10' 6 | export num_classes=10 7 | #export data_dir='[Your custom-defined path]/dataset/'${data_name}'/test' 8 | export data_dir='[Your custom-defined path]/ideal/low_test/tvit/imagenet10' 9 | export save_dir=${result_dir}'/'${exp_name} 10 | export model_names='tvit' 
11 | #export model_names='deit' 12 | export model_paths='[Your custom-defined path]/output/train_models/pth/'${data_name}'/'${model_names}'_'${data_name}'_base.pth' 13 | 14 | python doctor/test_mask.py \ 15 | --model_name ${model_names} \ 16 | --data_name ${data_name} \ 17 | --num_classes ${num_classes} \ 18 | --model_path ${model_paths} \ 19 | --data_dir ${data_dir} \ 20 | --save_dir ${save_dir} 21 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="[Your custom-defined path]" 2 | export CUDA_VISIBLE_DEVICES=1 3 | export result_dir='[Your custom-defined path]/output' 4 | export exp_name='hjc_test' 5 | export data_name='imagenet10' 6 | export num_classes=10 7 | # export data_dir='[Your custom-defined path]/output/ideal/high_images/vit/'${data_name} 8 | export data_dir='[Your custom-defined path]/dataset/'${data_name}'/test' 9 | #export data_dir=${result_dir}'/'${exp_name}'/low_images/'${model_name}'/'${data_name} 10 | export save_dir=${result_dir}'/'${exp_name} 11 | #export model_names='tvit' 12 | export model_names='deit' 13 | export model_paths='[Your custom-defined path]/output/train_models/pth/'${data_name}'/'${model_names}'_'${data_name}'_base.pth' 14 | 15 | python doctor/test.py \ 16 | --model_name ${model_names} \ 17 | --data_name ${data_name} \ 18 | --num_classes ${num_classes} \ 19 | --model_path ${model_paths} \ 20 | --data_dir ${data_dir} \ 21 | --save_dir ${save_dir} 22 | -------------------------------------------------------------------------------- /train_1k.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="[Your custom-defined path]" 2 | export CUDA_VISIBLE_DEVICES=0 3 | export result_dir='[Your custom-defined path]/output' 4 | export exp_name='train_models' 5 | export data_name='imagenet10' 6 | export num_classes=10 7 | export num_epochs=300 8 | export model_name='tvit' 9 | export train_data_dir='[Your custom-defined path]/dataset/'${data_name}'/train' 10 | export test_data_dir='[Your custom-defined path]/dataset/'${data_name}'/test' 11 | export log_dir=${result_dir}'/'${exp_name}'/tensorboard/'${data_name}'/'${model_name}'_'${data_name}'_baseline_test' 12 | export model_path=${result_dir}'/'${exp_name}'/pth/'${data_name} 13 | 14 | python doctor/train_1k.py \ 15 | --model_name ${model_name} \ 16 | --data_name ${data_name} \ 17 | --num_classes ${num_classes} \ 18 | --num_epochs ${num_epochs} \ 19 | --model_dir ${model_path} \ 20 | --train_data_dir ${train_data_dir} \ 21 | --test_data_dir ${test_data_dir} \ 22 | --log_dir ${log_dir} 23 | -------------------------------------------------------------------------------- /get_all_channel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=[Your custom-defined path] 3 | export CUDA_VISIBLE_DEVICES=1 4 | #export result_path='[Your custom-defined path]' 5 | #export exp_name='tmp' 6 | export model_name='tvit' 7 | export data_name='imagenet' 8 | export in_channels=3 9 | export num_classes=1000 10 | export model_path=[Your custom-defined path] 11 | #export data_dir=[Your custom-defined path] 12 | #export save_dir=[Your custom-defined path] 13 | export data_dir=[Your custom-defined path] 14 | export save_dir=[Your custom-defined path] 15 | #export data_dir=[Your custom-defined path] 16 | #export save_dir=[Your custom-defined path] 17 | export theta=0.15 18 | export
device_index='0' 19 | python doctor/get_all_channel.py \ 20 | --model_name ${model_name} \ 21 | --data_name ${data_name} \ 22 | --in_channels ${in_channels} \ 23 | --num_classes ${num_classes} \ 24 | --model_path ${model_path} \ 25 | --data_path ${data_dir} \ 26 | --grad_path ${save_dir} \ 27 | --theta ${theta} \ 28 | --device_index ${device_index} 29 | -------------------------------------------------------------------------------- /grad_calculate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export result_path='[Your custom-defined path]' 3 | export device_index='2' 4 | export exp_name='ideal' 5 | export model_name='tvit' 6 | export data_name='cifar10' 7 | export in_channels=3 8 | export num_classes=10 9 | export batch_size=512 10 | export model_path=${result_path}'/train_models/pth/'${data_name}'/'${model_name}'_'${data_name}'_base.pth' 11 | # export model_path=${result_path}'/baseline/'${model_name}'/checkpoint.pth' 12 | export data_dir=${result_path}'/'${exp_name}'/high_images/'${model_name}'/'${data_name} 13 | export grad_path=${result_path}'/'${exp_name}'/grads/'${model_name}'/'${data_name} 14 | export theta=0.15 15 | python doctor/grad_calculate.py \ 16 | --model_name ${model_name} \ 17 | --data_name ${data_name} \ 18 | --in_channels ${in_channels} \ 19 | --num_classes ${num_classes} \ 20 | --model_path ${model_path} \ 21 | --data_path ${data_dir} \ 22 | --grad_path ${grad_path} \ 23 | --theta ${theta} \ 24 | --batch_size ${batch_size} \ 25 | --device_index ${device_index} 26 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="[Your custom-defined path]" 2 | export CUDA_VISIBLE_DEVICES=2 3 | export result_dir='[Your custom-defined path]/output' 4 | export exp_name='train_models' 5 | export data_name='cifar10' 6 | export num_classes=10 7 | export num_epochs=300 8 | export batch_size=512 9 | export model_name='tvit' 10 | export train_data_dir='[Your custom-defined path]/dataset/'${data_name}'/train' 11 | export test_data_dir='[Your custom-defined path]/dataset/'${data_name}'/test' 12 | export result_name=${model_name}'_'${data_name}'_base' 13 | export log_dir=${result_dir}'/'${exp_name}'/tensorboard/'${data_name}'/'${result_name} 14 | export model_path=${result_dir}'/'${exp_name}'/pth/'${data_name} 15 | 16 | python doctor/train.py \ 17 | --model_name ${model_name} \ 18 | --data_name ${data_name} \ 19 | --num_classes ${num_classes} \ 20 | --num_epochs ${num_epochs} \ 21 | --batch_size ${batch_size} \ 22 | --model_dir ${model_path} \ 23 | --train_data_dir ${train_data_dir} \ 24 | --test_data_dir ${test_data_dir} \ 25 | --result_name ${result_name} \ 26 | --log_dir ${log_dir} 27 | -------------------------------------------------------------------------------- /sample_sift.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=/vipa-nfs/homes/hjc/projects/TD/codes 3 | export CUDA_VISIBLE_DEVICES=1 4 | #export result_dir='/mnt/nfs/hjc/project/td/output' 5 | export exp_name='ideal' 6 | export model_name='tvit' 7 | export data_name='imagenet10' 8 | export num_classes=10 9 | export batch_size=512 10 | export model_path='/mnt/nfs/hjc/project/td/output/train_models/pth/'${data_name}'/'${model_name}'_'${data_name}'_base.pth' 11 | export data_dir='/mnt/nfs/hjc/project/td/dataset/'${data_name}'/train' 12 | #export
save_dir='/vipa-nfs/homes/hjc/projects/TD/outputs/'${exp_name}'/high_train/'${model_name}'/'${data_name} 13 | export save_dir='/vipa-nfs/homes/hjc/projects/TD/outputs/'${exp_name}'/low_train/'${model_name}'/'${data_name} 14 | export num_samples=10 15 | export is_high_confidence=0 16 | python doctor/sample_sift.py \ 17 | --model_name ${model_name} \ 18 | --data_name ${data_name} \ 19 | --num_classes ${num_classes} \ 20 | --model_path ${model_path} \ 21 | --data_dir ${data_dir} \ 22 | --save_dir ${save_dir} \ 23 | --num_samples ${num_samples} \ 24 | --batch_size ${batch_size} \ 25 | --is_high_confidence ${is_high_confidence} 26 | -------------------------------------------------------------------------------- /utils/train_util.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | def __init__(self, name, fmt=':f'): 3 | self.name = name 4 | self.fmt = fmt 5 | self.reset() 6 | 7 | def reset(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def update(self, val, n=1): 14 | self.val = val 15 | self.sum += val * n 16 | self.count += n 17 | self.avg = self.sum / self.count 18 | 19 | def __str__(self): 20 | fmtstr = '{name}[VAL:{val' + self.fmt + '} AVG:{avg' + self.fmt + '}]' 21 | return fmtstr.format(**self.__dict__) 22 | 23 | 24 | class ProgressMeter(object): 25 | def __init__(self, total, step, prefix, meters): 26 | self._fmtstr = self._get_fmtstr(total) 27 | self.meters = meters 28 | self.prefix = prefix 29 | 30 | self.step = step 31 | 32 | def display(self, running): 33 | if running % self.step == 0: 34 | entries = [self.prefix + self._fmtstr.format(running)] # [prefix xx.xx/xx.xx] 35 | entries += [str(meter) for meter in self.meters] 36 | print(' '.join(entries)) 37 | 38 | def _get_fmtstr(self, total): 39 | num_digits = len(str(total // 1)) 40 | fmt = '{:' + str(num_digits) + 'd}' 41 | return '[' + fmt + '/' + fmt.format(total) + ']' # [prefix xx.xx/xx.xx] 42 | -------------------------------------------------------------------------------- /criterions/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) QIU Tian. All rights reserved. 2 | 3 | from typing import List, Dict 4 | 5 | import torch.nn.functional as F 6 | 7 | from utils.misc import accuracy 8 | 9 | from torch import nn 10 | 11 | 12 | class BaseCriterion(nn.Module): 13 | def __init__(self, losses: List[str], weight_dict: Dict[str, float]): 14 | super().__init__() 15 | self.losses = losses 16 | self.weight_dict = weight_dict 17 | 18 | def forward(self, outputs, targets, **kwargs): 19 | losses = {} 20 | for loss in self.losses: 21 | losses.update(getattr(self, f'loss_{loss}')(outputs, targets, **kwargs)) 22 | return losses 23 | 24 | class CrossEntropy(BaseCriterion): 25 | def __init__(self, losses: List[str], weight_dict: Dict[str, float]): 26 | super().__init__(losses, weight_dict) 27 | 28 | def loss_labels(self, outputs, targets, **kwargs): 29 | if isinstance(outputs, dict): 30 | assert 'logits' in outputs.keys(), \ 31 | f"When using 'loss_labels(self, outputs, targets, **kwargs)' in '{self.__class__.__name__}', " \ 32 | f"if 'outputs' is a dict, 'logits' MUST be the key."
33 | outputs = outputs["logits"] 34 | 35 | loss_ce = F.cross_entropy(outputs, targets, reduction='mean') 36 | losses = {'loss_ce': loss_ce, 'class_error': 100 - accuracy(outputs, targets)[0]} 37 | 38 | return losses 39 | -------------------------------------------------------------------------------- /train_lrp.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="[Your custom-defined path]" 2 | export CUDA_VISIBLE_DEVICES=1 3 | export result_dir='[Your custom-defined path]/output' 4 | export exp_name='train_models' 5 | export data_name='imagenet10' 6 | export num_classes=10 7 | export num_epochs=100 8 | export model_names='evit' 9 | export train_data_dir='[Your custom-defined path]/dataset/'${data_name}'/train' 10 | export test_data_dir='[Your custom-defined path]/dataset/'${data_name}'/test' 11 | export origin_path=${result_dir}'/'${exp_name}'/pth/'${data_name}'/'${model_names}'_'${data_name}'_base.pth' 12 | export log_dir=${result_dir}'/'${exp_name}'/tensorboard/'${data_name}'/'${model_names}'_'${data_name}'_lrp' 13 | export model_paths=${result_dir}'/'${exp_name}'/pth/'${data_name} 14 | export grad_path=${result_dir}'/ideal/grads/'${model_names}'/'${data_name}'/layer_2.npy' 15 | export lrp_path=${result_dir}'/ideal/lrps/'${model_names}'/'${data_name}'/result.npy' 16 | export mask_path=${result_dir}'/ideal/atts/'${data_name} 17 | 18 | python doctor/train_lrp.py \ 19 | --model_name ${model_names} \ 20 | --data_name ${data_name} \ 21 | --num_classes ${num_classes} \ 22 | --num_epochs ${num_epochs} \ 23 | --origin_dir ${origin_path}\ 24 | --model_dir ${model_paths} \ 25 | --train_data_dir ${train_data_dir} \ 26 | --test_data_dir ${test_data_dir} \ 27 | --log_dir ${log_dir} \ 28 | --grad_dir ${grad_path}\ 29 | --lrp_dir ${lrp_path}\ 30 | --mask_dir ${mask_path} 31 | -------------------------------------------------------------------------------- /train_plus.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="[Your custom-defined path]" 2 | export CUDA_VISIBLE_DEVICES=0 3 | export result_dir='[Your custom-defined path]/output' 4 | export exp_name='train_models' 5 | export data_name='imagenet50' 6 | export num_classes=50 7 | export num_epochs=100 8 | export model_names='vit' 9 | export origin_path=${result_dir}'/'${exp_name}'/pth/'${data_name}'/'${model_names}'_'${data_name}'_base.pth' 10 | export train_data_dir='[Your custom-defined path]/dataset/'${data_name}'/train' 11 | export test_data_dir='[Your custom-defined path]/dataset/'${data_name}'/test' 12 | export log_dir=${result_dir}'/'${exp_name}'/tensorboard/'${data_name}'/'${model_names}'_'${data_name}'_attlrp' 13 | export pth_path=${result_dir}'/'${exp_name}'/pth/'${data_name} 14 | export grad_path=${result_dir}'/ideal/grads/'${model_names}'/'${data_name}'/layer_2.npy' 15 | export lrp_path=${result_dir}'/ideal/lrps/'${model_names}'/'${data_name}'/result.npy' 16 | export mask_path=${result_dir}'/ideal/atts/'${data_name} 17 | 18 | python doctor/train_plus.py \ 19 | --model_name ${model_names} \ 20 | --data_name ${data_name} \ 21 | --num_classes ${num_classes} \ 22 | --num_epochs ${num_epochs} \ 23 | --origin_dir ${origin_path}\ 24 | --model_dir ${pth_path} \ 25 | --train_data_dir ${train_data_dir} \ 26 | --test_data_dir ${test_data_dir} \ 27 | --log_dir ${log_dir}\ 28 | --grad_dir ${grad_path}\ 29 | --lrp_dir ${lrp_path}\ 30 | --mask_dir ${mask_path} 31 | 
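For reference, the `CrossEntropy` criterion in criterions/cross_entropy.py returns a dict of raw loss terms, and `weight_dict` is stored but never applied inside `forward`, so the caller is presumably expected to weight and sum the terms itself. Below is a minimal usage sketch under that assumption; the weighted-sum step follows the common convention for this criterion style and is not taken from this repository.

```python
# Hedged usage sketch for CrossEntropy; the weighted-sum step is an assumption.
import torch
from criterions.cross_entropy import CrossEntropy

criterion = CrossEntropy(losses=['labels'], weight_dict={'loss_ce': 1.0})

outputs = torch.randn(8, 10, requires_grad=True)  # [batch, num_classes] logits
targets = torch.randint(0, 10, (8,))              # integer class labels

loss_dict = criterion(outputs, targets)           # {'loss_ce': ..., 'class_error': ...}
total_loss = sum(loss_dict[k] * w                 # weight only the terms listed in weight_dict
                 for k, w in criterion.weight_dict.items() if k in loss_dict)
total_loss.backward()
print({k: float(v) for k, v in loss_dict.items()})
```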
-------------------------------------------------------------------------------- /doctor/create_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | def save_cv(img, path): 8 | print(path) 9 | img_dir, _ = os.path.split(path) 10 | if not os.path.exists(img_dir): 11 | os.makedirs(img_dir) 12 | cv2.imwrite(path, img) 13 | 14 | 15 | def to_mask(img_shape): 16 | """ 17 | Generate a mask from boundary points 18 | :param img_shape: [h,w] 19 | :param polygons: boundary points from the labelme JSON, formatted as [[x1,y1],[x2,y2],[x3,y3],...[xn,yn]] 20 | :return: mask 0-1 21 | """ 22 | mask = np.ones(img_shape, dtype=np.uint8) 23 | return mask 24 | 25 | 26 | def main(): 27 | mask = to_mask([256, 256]) * 255 28 | mask_path = os.path.join('/nfs/ch/project/td/output/ideal/default', 'default.png') 29 | save_cv(mask, mask_path) 30 | # images_dir = '/nfs3-p1/hjc/datasets/imagenet/ch_select_test/' 31 | # for root, _, files in os.walk(images_dir): 32 | # for file in files: 33 | # if os.path.splitext(file)[1] == '.JPEG': 34 | # img_path = os.path.join(root, file) 35 | # img = cv2.imread(img_path) 36 | 37 | # json_name = os.path.splitext(file)[0] 38 | 39 | # mask = to_mask([img.shape[0], img.shape[1]]) * 255 40 | # # print(mask.shape) 41 | # mask_path = os.path.join('/nfs/ch/project/td/output/ideal/default', 'default.png') 42 | # save_cv(mask, mask_path) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() -------------------------------------------------------------------------------- /doctor/labelme_to_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | def save_cv(img, path): 8 | print(path) 9 | img_dir, _ = os.path.split(path) 10 | if not os.path.exists(img_dir): 11 | os.makedirs(img_dir) 12 | cv2.imwrite(path, img) 13 | 14 | 15 | def parse_json(json_path): 16 | data = json.load(open(json_path)) 17 | shapes = data['shapes'][0] 18 | points = shapes['points'] 19 | return points 20 | 21 | 22 | def polygons_to_mask(img_shape, polygons): 23 | mask = np.zeros(img_shape, dtype=np.uint8) 24 | polygons = np.asarray([polygons], np.int32) 25 | cv2.fillPoly(mask, polygons, 1) 26 | return mask 27 | 28 | 29 | def main(): 30 | images_dir = 'xxx' 31 | for root, _, files in os.walk(images_dir): 32 | for file in files: 33 | if os.path.splitext(file)[1] == '.json': 34 | img_path = os.path.join(root, file.replace('json', 'JPEG')) 35 | img = cv2.imread(img_path) 36 | 37 | json_name = os.path.splitext(file)[0] 38 | json_path = os.path.join(root, file) 39 | 40 | mask = polygons_to_mask([img.shape[0], img.shape[1]], parse_json(json_path)) * 255 41 | print(json_name) 42 | print(mask.shape) 43 | mask_path = os.path.join('xxx', json_name + '.png') 44 | print(mask) 45 | save_cv(mask, mask_path) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() -------------------------------------------------------------------------------- /train_att.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="[Your custom-defined path]" 2 | export CUDA_VISIBLE_DEVICES=3 3 | export result_dir='[Your custom-defined path]/output' 4 | export exp_name='train_models' 5 | export data_name='imagenet10' 6 | export num_classes=10 7 | export num_epochs=100 8 | export batch_size=512 9 | export model_names='eva' 10 | export train_data_dir='[Your custom-defined path]/dataset/'${data_name}'/train' 11 | # export
train_data_dir=${result_dir}'/ideal/high_images/tvit/'${data_name} 12 | export test_data_dir='[Your custom-defined path]/dataset/'${data_name}'/test' 13 | export origin_path=${result_dir}'/'${exp_name}'/pth/'${data_name}'/'${model_names}'_'${data_name}'_base.pth' 14 | export result_name=${model_names}'_'${data_name}'_att' 15 | export log_dir=${result_dir}'/'${exp_name}'/tensorboard/'${data_name}'/'${result_name} 16 | export model_path=${result_dir}'/'${exp_name}'/pth/'${data_name} 17 | export grad_path=${result_dir}'/ideal/grads/'${model_names}'/'${data_name}'/layer_3.npy' 18 | export lrp_path=${result_dir}'/ideal/lrps/'${model_names}'/'${data_name}'/result.npy' 19 | export mask_path=${result_dir}'/ideal/atts/'${data_name} 20 | 21 | python doctor/train_att.py \ 22 | --model_name ${model_names} \ 23 | --data_name ${data_name} \ 24 | --num_classes ${num_classes} \ 25 | --num_epochs ${num_epochs} \ 26 | --batch_size ${batch_size} \ 27 | --result_name ${result_name} \ 28 | --origin_dir ${origin_path}\ 29 | --model_dir ${model_path} \ 30 | --train_data_dir ${train_data_dir} \ 31 | --test_data_dir ${test_data_dir} \ 32 | --log_dir ${log_dir} \ 33 | --grad_dir ${grad_path}\ 34 | --lrp_dir ${lrp_path}\ 35 | --mask_dir ${mask_path} 36 | -------------------------------------------------------------------------------- /metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def accuracy(outputs, labels, topk=(1,)): 5 | with torch.no_grad(): 6 | maxk = max(topk) 7 | batch_size = labels.size(0) 8 | 9 | _, pred = outputs.topk(maxk, 1, True, True) # [batch_size, topk] 10 | pred = pred.t() # [topk, batch_size] 11 | correct = pred.eq(labels.view(1, -1).expand_as(pred)) # [topk, batch_size] 12 | 13 | res = [] 14 | for k in topk: 15 | correct_k = correct[:k].float().sum() 16 | res.append(correct_k.mul_(100.0 / batch_size)) 17 | return res 18 | 19 | 20 | class ClassAccuracy: 21 | def __init__(self): 22 | self.sum = {} 23 | self.count = {} 24 | 25 | def update(self, outputs, labels): 26 | _, pred = outputs.max(dim=1) 27 | correct = pred.eq(labels) 28 | 29 | for b, label in enumerate(labels): 30 | label = label.item() 31 | if label not in self.sum.keys(): 32 | self.sum[label] = 0 33 | self.count[label] = 0 34 | self.sum[label] += correct[b].item() 35 | self.count[label] += 1 36 | 37 | def __call__(self): 38 | self.sum = dict(sorted(self.sum.items())) 39 | self.count = dict(sorted(self.count.items())) 40 | return [s / c * 100 for s, c in zip(self.sum.values(), self.count.values())] 41 | 42 | def __getitem__(self, item): 43 | return self.__call__()[item] 44 | 45 | def list(self): 46 | return self.__call__() 47 | 48 | def __str__(self): 49 | fmtstr = '{}:{:6.2f}' 50 | result = '\n'.join([fmtstr.format(l, a) for l, a in enumerate(self.__call__())]) 51 | return result 52 | -------------------------------------------------------------------------------- /train_grad.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="[Your custom-defined path]" 2 | export CUDA_VISIBLE_DEVICES=2 3 | export result_dir='[Your custom-defined path]/output' 4 | export exp_name='train_models' 5 | export data_name='cifar10' 6 | export num_classes=10 7 | export num_epochs=100 8 | export batch_size=512 9 | export model_names='tvit' 10 | export train_data_dir='[Your custom-defined path]/dataset/'${data_name}'/train' 11 | export test_data_dir='[Your custom-defined path]/dataset/'${data_name}'/test' 12 | export 
origin_path=${result_dir}'/'${exp_name}'/pth/'${data_name}'/'${model_names}'_'${data_name}'_base.pth' 13 | # export origin_path=${result_dir}'/baseline/'${model_names}'/checkpoint.pth' 14 | export i=0 15 | export layer=$((i*4+3)) 16 | export result_name=${model_names}'_'${data_name}'_grad' 17 | export log_dir=${result_dir}'/'${exp_name}'/tensorboard/'${data_name}'/'${result_name} 18 | export model_path=${result_dir}'/'${exp_name}'/pth/'${data_name} 19 | export grad_path=${result_dir}'/ideal/grads/'${model_names}'/'${data_name}'/layer_3.npy' 20 | export lrp_path=${result_dir}'/ideal/lrps/'${model_names}'/'${data_name}'/result.npy' 21 | export mask_path=${result_dir}'/ideal/atts/'${data_name} 22 | 23 | python doctor/train_grad.py \ 24 | --model_name ${model_names} \ 25 | --data_name ${data_name} \ 26 | --num_classes ${num_classes} \ 27 | --num_epochs ${num_epochs} \ 28 | --batch_size ${batch_size} \ 29 | --origin_dir ${origin_path}\ 30 | --result_name ${result_name} \ 31 | --model_dir ${model_path} \ 32 | --train_data_dir ${train_data_dir} \ 33 | --test_data_dir ${test_data_dir} \ 34 | --log_dir ${log_dir} \ 35 | --grad_dir ${grad_path}\ 36 | --lrp_dir ${lrp_path}\ 37 | --mask_dir ${mask_path} 38 | -------------------------------------------------------------------------------- /utils/image_util.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('AGG') 4 | import os 5 | import matplotlib.pyplot as plt 6 | import torchvision 7 | import numpy as np 8 | import cv2 9 | 10 | import seaborn as sns 11 | 12 | 13 | def show_torch_images(images, mode=None): 14 | img = torchvision.utils.make_grid(images, pad_value=1) 15 | np_img = img.numpy() 16 | plt.imshow(np.transpose(np_img, (1, 2, 0)), cmap=mode) 17 | plt.show() 18 | 19 | 20 | def show_cv(image, end=0, name=None): 21 | cv2.imshow('test', image) 22 | if end: 23 | cv2.waitKey(0) 24 | 25 | 26 | def save_cv(img, path): 27 | print(path) 28 | img_dir, _ = os.path.split(path) 29 | if not os.path.exists(img_dir): 30 | os.makedirs(img_dir) 31 | cv2.imwrite(path, img) 32 | 33 | 34 | # def deprocess_image(img): 35 | # """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """ 36 | # img = img - np.mean(img) 37 | # img = img / (np.std(img) + 1e-5) 38 | # img = img * 0.1 39 | # img = img + 0.5 40 | # img = np.clip(img, 0, 1) 41 | # return np.uint8(img * 255) 42 | 43 | def deprocess_image(img, std, mean): 44 | import torch 45 | t_mean = torch.FloatTensor(mean).view(3, 1, 1).expand(3, 32, 32).numpy() 46 | t_std = torch.FloatTensor(std).view(3, 1, 1).expand(3, 32, 32).numpy() 47 | img = img * t_std + t_mean 48 | img = np.clip(img, 0, 1) 49 | return np.uint8(img * 255) 50 | 51 | 52 | def view_grads(grads, fig_w, fig_h, fig_path): 53 | f, ax = plt.subplots(figsize=(fig_w, fig_h), ncols=1) 54 | ax.set_xlabel('convolutional kernel') 55 | ax.set_ylabel('category') 56 | sns.heatmap(grads, annot=False, ax=ax) 57 | plt.savefig(fig_path, bbox_inches='tight') 58 | # plt.show() 59 | plt.clf() 60 | -------------------------------------------------------------------------------- /evaluator/default.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import warnings 3 | from typing import List 4 | 5 | from sklearn import metrics as sklearn_metrics 6 | 7 | from utils.misc import all_gather 8 | 9 | warnings.filterwarnings('ignore') 10 | 11 | 12 | class DefaultEvaluator: 13 | def __init__(self, metrics: List[str]): 14 | 
self.metrics = metrics 15 | self.outputs = [] 16 | self.targets = [] 17 | self.eval = {metric: None for metric in metrics} 18 | 19 | def update(self, outputs, targets): 20 | if isinstance(outputs, dict): 21 | assert 'logits' in outputs.keys(), \ 22 | f"When using 'update(self, outputs, targets)' in '{self.__class__.__name__}', " \ 23 | f"if 'outputs' is a dict, 'logits' MUST be the key." 24 | outputs = outputs['logits'] 25 | outputs = outputs.max(1)[1].tolist() 26 | targets = targets.tolist() 27 | self.outputs += outputs 28 | self.targets += targets 29 | 30 | def synchronize_between_processes(self): 31 | self.outputs = list(itertools.chain(*all_gather(self.outputs))) 32 | self.targets = list(itertools.chain(*all_gather(self.targets))) 33 | 34 | @staticmethod 35 | def metric_acc(outputs, targets, **kwargs): 36 | return sklearn_metrics.accuracy_score(targets, outputs) 37 | 38 | @staticmethod 39 | def metric_recall(outputs, targets, **kwargs): 40 | return sklearn_metrics.recall_score(targets, outputs, average='macro') 41 | 42 | @staticmethod 43 | def metric_precision(outputs, targets, **kwargs): 44 | return sklearn_metrics.precision_score(targets, outputs, average='macro') 45 | 46 | @staticmethod 47 | def metric_f1(outputs, targets, **kwargs): 48 | return sklearn_metrics.f1_score(targets, outputs, average='macro') 49 | 50 | def summarize(self): 51 | print('Classification Metrics:') 52 | for metric in self.metrics: 53 | value = getattr(self, f'metric_{metric}')(self.outputs, self.targets) 54 | self.eval[metric] = value 55 | print(f'{metric}: {value:.3f}', end=' ') 56 | print('\n') 57 | -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | def setup_for_distributed(is_main): 10 | """disables printing when not in main process""" 11 | import builtins as __builtin__ 12 | 13 | builtin_print = __builtin__.print 14 | 15 | def print(*args, **kwargs): 16 | force = kwargs.pop("force", False) 17 | if is_main or force: 18 | builtin_print(*args, **kwargs) 19 | 20 | __builtin__.print = print 21 | 22 | 23 | def is_dist_avail_and_initialized(): 24 | if not dist.is_available(): 25 | return False 26 | if not dist.is_initialized(): 27 | return False 28 | return True 29 | 30 | 31 | def get_world_size(): 32 | if not is_dist_avail_and_initialized(): 33 | return 1 34 | return dist.get_world_size() 35 | 36 | 37 | def get_rank(): 38 | if not is_dist_avail_and_initialized(): 39 | return 0 40 | return dist.get_rank() 41 | 42 | 43 | def is_main_process(): 44 | return get_rank() == 0 45 | 46 | 47 | def init_distributed_mode(args): 48 | if args.no_dist: # qt + 49 | args.distributed = False 50 | return 51 | 52 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 53 | args.rank = int(os.environ["RANK"]) 54 | args.world_size = int(os.environ['WORLD_SIZE']) 55 | args.gpu = int(os.environ['LOCAL_RANK']) 56 | elif 'SLURM_PROCID' in os.environ: 57 | args.rank = int(os.environ['SLURM_PROCID']) 58 | args.gpu = args.rank % torch.cuda.device_count() 59 | else: 60 | print('Not using distributed mode') 61 | args.distributed = False 62 | return 63 | 64 | args.distributed = True 65 | 66 | torch.cuda.set_device(args.gpu) 67 | print('Distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True) 68 | 69 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 70 | world_size=args.world_size, rank=args.rank) 71 | dist.barrier() 72 | setup_for_distributed(is_main=args.rank == 0) 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformer Doctor: Diagnosing and Treating Vision Transformers 2 | 3 | ![Transformer Doctor](framewrok.png) 4 | 5 | This repository contains the implementation for Transformer Doctor, a toolkit designed to diagnose and treat Transformer models for improved performance. Below are the details on setting up the environment and utilizing the provided scripts for training and refinement steps. For more information, 6 | please visit https://transformer-doctor.github.io/. 7 | 8 | ## Environment Setup 9 | 10 | - **GPU**: NVIDIA A6000 11 | - **Python Environment**: 12 | - Python 3.9 13 | 14 | ## Key Steps 15 | 16 | ### 1. Training Baseline Model 17 | 18 | To train a baseline model, execute the following command: 19 | 20 | ```bash 21 | bash train.sh 22 | ``` 23 | ### 2. Sifting High Confidence Samples 24 | To obtain high-confidence samples based on the trained model, run: 25 | ```bash 26 | bash sample_sift.sh 27 | ``` 28 | Key Parameter: 29 | 30 | - is_high_confidence: Flag to filter high (1) or low (0) confidence samples 31 | 32 | ### 3. Static Integration Rules for Intra-Token Information 33 | Compute static integration rules for intra-token information using: 34 | ```bash 35 | bash grad_calculate.sh 36 | ``` 37 | Key Parameters: 38 | 39 | - theta: Threshold value 40 | - data_dir: Path to directory containing high-confidence samples 41 | 42 | ### 4. 
Treatment of Conjunction Errors in Dynamic Integration of Inter-Token Information 43 | To treat conjunction errors in dynamic integration of inter-token information, execute: 44 | ```bash 45 | bash train_att.sh 46 | ``` 47 | Key Parameters: 48 | 49 | - origin_path: Path to the original model 50 | - mask_path: Path to foreground annotation 51 | 52 | ### 5. Treatment of Conjunction Errors in Static Integration of Intra-Token Information 53 | For treating conjunction errors in static integration of intra-token information, use: 54 | ```bash 55 | bash train_grad.sh 56 | ``` 57 | Key Parameters: 58 | 59 | - origin_path: Path to the original model 60 | - grad_path: Path to static aggregation rules 61 | 62 | ## Further Steps 63 | Following these key steps will aid in diagnosing and treating Transformer models using the Transformer Doctor toolkit. Stay tuned for more detailed code explanations and enhancements. 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /feature_sift.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export PYTHONPATH=/nfs3-p1/ch/project/td/codes 3 | export CUDA_VISIBLE_DEVICES=0 4 | # export CUDA_LAUNCH_BLOCKING=1 5 | # export TORCH_USE_CUDA_DSA=1 6 | export result_path='/nfs3-p1/ch/project/td/output' 7 | #export exp_name='vgg16_cifar10_pure' 8 | #export model_name='vgg16' 9 | export exp_name='vitb16' 10 | export model_name='vitb16py' 11 | export data_name='imagenet' 12 | export num_classes=1 13 | export model_path=${result_path}'/'${exp_name}'/models/model.safetensors' 14 | export data_dir='/nfs3-p1/ch/project/td/output/vitb16/images/chtrain_vitb16' 15 | #export data_dir='/nfs3-p1/hjc/datasets/cifar10/test' 16 | #export data_dir=${result_path}'/'${exp_name}'/adv_images/PGD/test' 17 | export save_dir=${result_path}'/'${exp_name}'/features' 18 | export num_samples=50 19 | python doctor/feature_sift.py \ 20 | --model_name ${model_name} \ 21 | --data_name ${data_name} \ 22 | --num_classes ${num_classes} \ 23 | --model_path ${model_path} \ 24 | --data_dir ${data_dir} \ 25 | --save_dir ${save_dir} \ 26 | --num_samples ${num_samples} 27 | 28 | #echo "STRAT !!!" 29 | #export PYTHONPATH=/nfs3-p1/hjc/projects/RO/codes 30 | #export CUDA_VISIBLE_DEVICES=0 31 | #export result_path='/nfs3-p1/hjc/projects/RO/outputs' 32 | #export model_name='vgg16' 33 | #export data_name='cifar10' 34 | ##export data_dir='/nfs3-p1/hjc/datasets/cifar10/train' 35 | #export data_dir=${result_path}'/'${exp_name}'/adv_images/PGD/train' 36 | #export num_classes=10 37 | #export num_samples=5000 38 | #exp_names=( 39 | # #'vgg16d_cifar10_$04041006b' 40 | # #'vgg16d_cifar10_$04041006c' 41 | # #'vgg16d_cifar10_$04041006f' 42 | # #'vgg16d_cifar10_$04041006g' 43 | # #'vgg16d_cifar10_$04042324a' 44 | # #'vgg16d_cifar10_$04042324b' 45 | # 'vgg16d_cifar10_04162018c7' 46 | # #'vgg16d_cifar10_04162018c8' 47 | #) 48 | #for exp_name in ${exp_names[*]}; do 49 | # export model_path=${result_path}'/'${exp_name}'/models/model_ori.pth' 50 | # export save_dir=${result_path}'/'${exp_name}'/features' 51 | # python doctor/feature_sift.py \ 52 | # --model_name ${model_name} \ 53 | # --data_name ${data_name} \ 54 | # --num_classes ${num_classes} \ 55 | # --model_path ${model_path} \ 56 | # --data_dir ${data_dir} \ 57 | # --save_dir ${save_dir} \ 58 | # --num_samples ${num_samples} & 59 | #done 60 | #wait 61 | #echo "ACCOMPLISH !!!" 
62 | -------------------------------------------------------------------------------- /loaders/image_enhance.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFilter, ImageStat 2 | 3 | 4 | class _Enhance: 5 | def enhance(self, factor): 6 | """ 7 | Returns an enhanced image. 8 | 9 | :param factor: A floating point value controlling the enhancement. 10 | Factor 1.0 always returns a copy of the original image, 11 | lower factors mean less color (brightness, contrast, 12 | etc), and higher values more. There are no restrictions 13 | on this value. 14 | :rtype: :py:class:`~PIL.Image.Image` 15 | """ 16 | trans = [Image.blend(self.degenerate, self.images[0], factor), self.images[1]] 17 | return trans 18 | 19 | 20 | class MyColor(_Enhance): 21 | """Adjust image color balance. 22 | 23 | This class can be used to adjust the colour balance of an image, in 24 | a manner similar to the controls on a colour TV set. An enhancement 25 | factor of 0.0 gives a black and white image. A factor of 1.0 gives 26 | the original image. 27 | """ 28 | 29 | def __init__(self, images): 30 | self.images = images 31 | self.intermediate_mode = "L" 32 | if "A" in images[0].getbands(): 33 | self.intermediate_mode = "LA" 34 | 35 | self.degenerate = images[0].convert(self.intermediate_mode).convert(images[0].mode) 36 | 37 | 38 | class MyContrast(_Enhance): 39 | """Adjust image contrast. 40 | 41 | This class can be used to control the contrast of an image, similar 42 | to the contrast control on a TV set. An enhancement factor of 0.0 43 | gives a solid grey image. A factor of 1.0 gives the original image. 44 | """ 45 | 46 | def __init__(self, images): 47 | self.images = images 48 | mean = int(ImageStat.Stat(images[0].convert("L")).mean[0] + 0.5) 49 | self.degenerate = Image.new("L", images[0].size, mean).convert(images[0].mode) 50 | 51 | if "A" in images[0].getbands(): 52 | self.degenerate.putalpha(images[0].getchannel("A")) 53 | 54 | 55 | class MyBrightness(_Enhance): 56 | """Adjust image brightness. 57 | 58 | This class can be used to control the brightness of an image. An 59 | enhancement factor of 0.0 gives a black image. A factor of 1.0 gives the 60 | original image. 61 | """ 62 | 63 | def __init__(self, images): 64 | self.images = images 65 | self.degenerate = Image.new(images[0].mode, images[0].size, 0) 66 | 67 | if "A" in images[0].getbands(): 68 | self.degenerate.putalpha(images[0].getchannel("A")) 69 | 70 | 71 | class MySharpness(_Enhance): 72 | """Adjust image sharpness. 73 | 74 | This class can be used to adjust the sharpness of an image. An 75 | enhancement factor of 0.0 gives a blurred image, a factor of 1.0 gives the 76 | original image, and a factor of 2.0 gives a sharpened image. 
77 | """ 78 | 79 | def __init__(self, images): 80 | self.images = images 81 | self.degenerate = images[0].filter(ImageFilter.SMOOTH) 82 | 83 | if "A" in images[0].getbands(): 84 | self.degenerate.putalpha(images[0].getchannel("A")) 85 | -------------------------------------------------------------------------------- /engines/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | 5 | import torch 6 | from torch import nn 7 | import collections 8 | 9 | import loaders 10 | import models 11 | import metrics 12 | from utils.train_util import AverageMeter, ProgressMeter 13 | from tqdm import tqdm 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser(description='') 18 | parser.add_argument('--model_name', default='', type=str, help='model name') 19 | parser.add_argument('--data_name', default='', type=str, help='data name') 20 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 21 | parser.add_argument('--model_path', default='', type=str, help='model path') 22 | parser.add_argument('--data_dir', default='', type=str, help='data directory') 23 | args = parser.parse_args() 24 | 25 | # ---------------------------------------- 26 | # basic configuration 27 | # ---------------------------------------- 28 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 29 | 30 | print('-' * 50) 31 | print('TEST ON:', device) 32 | print('MODEL PATH:', args.model_path) 33 | print('DATA PATH:', args.data_dir) 34 | print('-' * 50) 35 | 36 | # ---------------------------------------- 37 | # trainer configuration 38 | # ---------------------------------------- 39 | state = torch.load(args.model_path) 40 | if isinstance(state, collections.OrderedDict): 41 | model = models.load_model(args.model_name, args.data_name, num_classes=args.num_classes)  # data_name is required by load_model 42 | model.load_state_dict(state) 43 | else: 44 | model = state 45 | model.to(device) 46 | 47 | test_loader = loaders.load_data(args.data_dir, args.data_name, data_type='test', batch_size=256)  # batch_size is required by load_data; 256 is an assumed value 48 | 49 | criterion = nn.CrossEntropyLoss() 50 | 51 | # ---------------------------------------- 52 | # each epoch 53 | # ---------------------------------------- 54 | since = time.time() 55 | 56 | loss, acc1, acc5, class_acc = test(test_loader, model, criterion, device) 57 | 58 | print('-' * 50) 59 | print('COMPLETE !!!') 60 | print(class_acc) 61 | print('AVG:', acc1.avg) 62 | print('TIME CONSUMED', time.time() - since) 63 | 64 | 65 | def test(test_loader, model, criterion, device): 66 | loss_meter = AverageMeter('Loss', ':.4e') 67 | acc1_meter = AverageMeter('Acc@1', ':6.2f') 68 | acc5_meter = AverageMeter('Acc@5', ':6.2f') 69 | progress = ProgressMeter(total=len(test_loader), step=20, prefix='Test', 70 | meters=[loss_meter, acc1_meter, acc5_meter]) 71 | class_acc = metrics.ClassAccuracy() 72 | model.eval() 73 | 74 | for i, samples in tqdm(enumerate(test_loader)): 75 | inputs, labels, _ = samples 76 | inputs = inputs.to(device) 77 | labels = labels.to(device) 78 | 79 | with torch.set_grad_enabled(False): 80 | outputs = model(inputs) 81 | loss = criterion(outputs, labels) 82 | acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5))  # top-5 to match the Acc@5 meter 83 | class_acc.update(outputs, labels) 84 | 85 | loss_meter.update(loss.item(), inputs.size(0)) 86 | acc1_meter.update(acc1.item(), inputs.size(0)) 87 | acc5_meter.update(acc5.item(), inputs.size(0)) 88 | 89 | progress.display(i) 90 | 91 | return loss_meter, acc1_meter, acc5_meter, class_acc 92 | 93 | 94 | if __name__ == '__main__': 95 |
main() 96 | -------------------------------------------------------------------------------- /utils/fig_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import torchvision 5 | 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | 10 | import numpy as np 11 | import pandas as pd 12 | 13 | matplotlib.use('AGG') 14 | 15 | 16 | def heatmap(vals, fig_path, fig_w=None, fig_h=None, annot=False): 17 | if fig_w is None: 18 | fig_w = vals.shape[1] 19 | if fig_h is None: 20 | fig_h = vals.shape[0] 21 | 22 | f, ax = plt.subplots(figsize=(fig_w, fig_h), ncols=1) 23 | sns.heatmap(vals, ax=ax, annot=annot) 24 | plt.savefig(fig_path, bbox_inches='tight') 25 | plt.clf() 26 | 27 | 28 | def draw_group(data, fig_path): 29 | # features = features[:, :, 0:50] 30 | 31 | t = data.shape[0] 32 | s = data.shape[1] 33 | d = data.shape[2] 34 | 35 | x = np.tile(np.arange(0, d), t * s) 36 | y = data.flatten() 37 | z = np.repeat(np.arange(0, t), s * d) 38 | 39 | plt.figure(figsize=(15, 10)) 40 | sns.scatterplot(x=x, y=y, hue=z, s=20, alpha=0.5, style=z, palette=sns.color_palette("hls", t)) 41 | # sns.lineplot(x="channel", y="value", data=data_plot, hue=z) 42 | 43 | plt.savefig(fig_path, bbox_inches='tight') 44 | plt.clf() 45 | 46 | 47 | def draw_single(data, fig_path): 48 | np.set_printoptions(threshold=np.inf) 49 | n = data.shape[0] # number of data series per coordinate 50 | # s = data.shape[1] 51 | d = data.shape[1] # length of the coordinate axis 52 | 53 | x = np.tile(np.arange(0, d), n) # [0-d,0-d,...,0-d], outer repeat n 54 | # print(x) 55 | y = data.flatten() 56 | z = np.repeat(np.arange(0, n), d) # [0-0,1-1,...,n-n], inner repeat d 57 | # print(z) 58 | 59 | plt.figure(figsize=(15, 10)) 60 | # sns.lineplot(x=x, y=y, hue=z, style=z, palette=sns.color_palette("hls", n)) 61 | sns.scatterplot(x=x, y=y, hue=z, style=z, palette=sns.color_palette("hls", n)) 62 | 63 | plt.savefig(fig_path, bbox_inches='tight') 64 | plt.clf() 65 | 66 | 67 | def draw_box(data, fig_path): 68 | C, N, D = data.shape 69 | 70 | x = np.tile(np.arange(0, D), N * C) 71 | z = np.repeat(np.arange(0, C), N * D) 72 | y = data.flatten() 73 | 74 | plt.figure(figsize=(15, 10)) 75 | sns.boxplot(x=x, y=y, hue=z, palette=sns.color_palette("hls", C)) 76 | # sns.scatterplot(x=x, y=y, hue=z, s=20, alpha=0.5, style=z, palette=sns.color_palette("hls", C)) 77 | # plt.show() 78 | plt.savefig(fig_path, bbox_inches='tight') 79 | plt.clf() 80 | 81 | 82 | # def feature_distribute(features, fig_path): 83 | # plt.figure(figsize=(15, 10)) 84 | # 85 | # s = features.shape[0] 86 | # d = features.shape[1] 87 | # 88 | # x = np.tile(np.arange(0, d), s) 89 | # y = features.flatten() 90 | # 91 | # sns.scatterplot(x=x, y=y, s=20, alpha=0.5) 92 | # # sns.lineplot(x=x, y=y) 93 | # 94 | # plt.savefig(fig_path, bbox_inches='tight') 95 | # plt.clf() 96 | 97 | 98 | def imshow(img, title, fig_path): 99 | img = torchvision.utils.make_grid(img.cpu().data, normalize=True, nrow=10) 100 | npimg = img.numpy() 101 | # fig = plt.figure(figsize=(5, 15)) 102 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 103 | # plt.title(title) 104 | # plt.show() 105 | 106 | plt.title(title) 107 | plt.savefig(fig_path, bbox_inches='tight') 108 | plt.clf() 109 | 110 | 111 | def save_img_by_cv2(img, path): 112 | img_dir, _ = os.path.split(path) 113 | if not os.path.exists(img_dir): 114 | os.makedirs(img_dir) 115 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 116 | cv2.imwrite(path, img) 117 | --------------------------------------------------------------------------------
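The helpers in utils/fig_util.py share one contract: take a NumPy array plus an output path, render headlessly (the module forces the AGG backend), save the figure, and clear it. A small usage sketch follows; the array shapes are inferred from the indexing inside `heatmap` and `draw_box`, and the output paths are illustrative.

```python
# Usage sketch for utils/fig_util.py; shapes inferred from the functions' indexing.
import numpy as np
from utils.fig_util import heatmap, draw_box

grads = np.random.rand(10, 64)    # e.g. a [category, kernel] matrix
heatmap(grads, fig_path='/tmp/grads_heatmap.png', annot=False)

acts = np.random.rand(3, 20, 16)  # [C groups, N samples, D positions], per draw_box
draw_box(acts, fig_path='/tmp/acts_box.png')
```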
/doctor/get_all_activation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | 5 | import os 6 | import argparse 7 | import numpy as np 8 | from tqdm import tqdm 9 | import torch 10 | import torch.nn as nn 11 | 12 | import models 13 | import loaders 14 | import pickle 15 | 16 | 17 | class HookModule: 18 | def __init__(self, module): 19 | self.module = module 20 | self.inputs = None 21 | self.outputs = None 22 | module.register_forward_hook(self._hook) 23 | 24 | def _hook(self, module, inputs, outputs): 25 | self.inputs = inputs[0] 26 | self.outputs = outputs 27 | 28 | 29 | class ActivationGet: 30 | def __init__(self, modules): 31 | self.modules = [HookModule(module) for module in modules] 32 | 33 | def __call__(self, names, outputs, labels): 34 | nll_loss = torch.nn.NLLLoss()(outputs, labels) 35 | 36 | for layer in range(12): 37 | activations = self.modules[int(4 * layer + 3)].outputs 38 | activations = activations.cpu().detach().numpy() 39 | folder_path = '/mnt/nfs/hjc/project/td/output/tmp/activations/npys/{}'.format(layer)  # hard-coded output root (must already exist); args.save_dir is unused here 40 | count = len([f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))]) 41 | save_path = os.path.join(folder_path, 'img_{}'.format(count)) 42 | np.save(save_path, activations) 43 | print(save_path) 44 | 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser(description='') 48 | parser.add_argument('--model_name', default='', type=str, help='model name') 49 | parser.add_argument('--data_name', default='', type=str, help='data name') 50 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 51 | parser.add_argument('--model_path', default='', type=str, help='model path') 52 | parser.add_argument('--data_dir', default='', type=str, help='data path') 53 | parser.add_argument('--save_dir', default='', type=str, help='locating path') 54 | args = parser.parse_args() 55 | 56 | # ---------------------------------------- 57 | # basic configuration 58 | # ---------------------------------------- 59 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 60 | 61 | if not os.path.exists(args.save_dir): 62 | os.makedirs(args.save_dir) 63 | 64 | print('-' * 50) 65 | print('TRAIN ON:', device) 66 | print('DATA DIR:', args.data_dir) 67 | print('SAVE DIR:', args.save_dir) 68 | print('-' * 50) 69 | 70 | # ---------------------------------------- 71 | # model/data configuration 72 | # ---------------------------------------- 73 | model = models.load_model(model_name=args.model_name, data_name=args.data_name, num_classes=args.num_classes) 74 | model.load_state_dict(torch.load(args.model_path)) 75 | model.to(device) 76 | model.eval() 77 | 78 | mask_dir = '/mnt/nfs/hjc/project/td/output/ideal/atts/imagenet' 79 | data_loader = loaders.load_data_mask(args.data_dir, args.data_name, mask_dir, data_type='test', batch_size=1)  # batch_size is required by load_data_mask; 1 is an assumed value 80 | 81 | modules = models.load_modules(model=model) 82 | 83 | activation_get = ActivationGet(modules) 84 | 85 | # ---------------------------------------- 86 | # forward 87 | # ---------------------------------------- 88 | for i, samples in enumerate(tqdm(data_loader)): 89 | inputs, labels, names, masks = samples 90 | inputs = inputs.to(device) 91 | labels = labels.to(device) 92 | outputs = model(inputs) 93 | activation_get(names, outputs, labels) 94 | 95 | 96 | if __name__ == '__main__': 97 | np.set_printoptions(threshold=np.inf) 98 | main() --------------------------------------------------------------------------------
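`HookModule` above is the instrumentation core of the diagnosis scripts: a forward hook keeps each module's latest inputs/outputs readable as plain attributes after every forward pass. Below is a self-contained sketch of the same pattern on a toy model; the two-layer model is hypothetical and not part of this repository.

```python
# Standalone illustration of the HookModule pattern; the toy model is hypothetical.
import torch
import torch.nn as nn


class HookModule:
    def __init__(self, module):
        self.inputs = None
        self.outputs = None
        module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, outputs):
        self.inputs = inputs[0]  # first positional input tensor
        self.outputs = outputs   # activation captured on every forward pass


model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4))
hooks = [HookModule(m) for m in model if isinstance(m, nn.Linear)]

_ = model(torch.randn(2, 8))
for i, h in enumerate(hooks):
    print(i, tuple(h.outputs.shape))  # (2, 16) then (2, 4)
```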
/loaders/image_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import PIL.Image as Image 5 | from torch.utils.data import Dataset 6 | 7 | def _img_loader(path, mode='RGB'): 8 | assert mode in ['RGB', 'L'] 9 | 10 | # default_path = '/nfs/ch/project/td/output/ideal/default/default.png' 11 | default_path = '/mnt/nfs/hjc/project/td/output/ideal/default/default.png' 12 | if not os.path.exists(path): 13 | path = default_path 14 | 15 | with open(path, 'rb') as f: 16 | img = Image.open(f) 17 | return img.convert(mode) 18 | 19 | 20 | def _find_classes(root): 21 | class_names = [d.name for d in os.scandir(root) if d.is_dir()] 22 | class_names.sort() 23 | classes_indices = {class_names[i]: i for i in range(len(class_names))} 24 | # print(classes_indices) 25 | return class_names, classes_indices # 'class_name':index 26 | 27 | 28 | def _make_dataset(image_dir): 29 | samples = [] # image_path, class_idx 30 | 31 | class_names, class_indices = _find_classes(image_dir) 32 | 33 | for class_name in sorted(class_names): 34 | class_idx = class_indices[class_name] 35 | target_dir = os.path.join(image_dir, class_name) 36 | 37 | if not os.path.isdir(target_dir): 38 | continue 39 | 40 | for root, _, files in sorted(os.walk(target_dir)): 41 | for file in sorted(files): 42 | image_path = os.path.join(root, file) 43 | item = image_path, class_idx 44 | samples.append(item) 45 | return samples 46 | 47 | 48 | def _make_dataset_mask(image_dir, mask_dir): 49 | samples = [] # image_path, mask_path, class_idx 50 | 51 | class_names, class_indices = _find_classes(image_dir) 52 | 53 | for class_name in sorted(class_names): 54 | class_idx = class_indices[class_name] 55 | target_dir = os.path.join(image_dir, class_name) 56 | 57 | if not os.path.isdir(target_dir): 58 | continue 59 | 60 | for root, _, files in sorted(os.walk(target_dir)): 61 | for file in sorted(files): 62 | image_path = os.path.join(root, file) 63 | mask_path = os.path.join(mask_dir, file.replace('JPEG', 'png')) 64 | item = image_path, mask_path, class_idx 65 | samples.append(item) 66 | return samples 67 | 68 | 69 | class ImageDataset(Dataset): 70 | def __init__(self, image_dir, transform=None): 71 | self.image_dir = image_dir 72 | self.transform = transform 73 | self.samples = _make_dataset(self.image_dir) 74 | self.targets = [s[1] for s in self.samples] 75 | 76 | def __getitem__(self, index): 77 | image_path, target = self.samples[index] 78 | image = _img_loader(image_path, mode='RGB') 79 | name = os.path.split(image_path)[1] 80 | 81 | if self.transform is not None: 82 | image = self.transform(image) 83 | 84 | return image, target, name 85 | 86 | def __len__(self): 87 | return len(self.samples) 88 | 89 | 90 | class ImageMaskDataset(Dataset): 91 | def __init__(self, image_dir, mask_dir, transform=None): 92 | self.image_dir = image_dir 93 | self.mask_dir = mask_dir 94 | self.transform = transform 95 | self.samples = _make_dataset_mask(self.image_dir, self.mask_dir) 96 | self.targets = [s[2] for s in self.samples] 97 | 98 | def __getitem__(self, index): 99 | image_path, mask_path, target = self.samples[index] 100 | image = _img_loader(path=image_path, mode='RGB') 101 | mask = _img_loader(path=mask_path, mode='L') 102 | name = os.path.split(image_path)[1] 103 | 104 | images = [image, mask] 105 | if self.transform is not None: 106 | images = self.transform(images) 107 | 108 | return images[0], target, name, images[1] 109 | 110 | def __len__(self): 111 | return 
len(self.samples) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import vit 3 | from models.cait import cait_xxs24_imagenet, cait_xxs24_cifar 4 | from models.deit import deit_tiny_imagenet, deit_tiny_cifar 5 | from models.pvt import pvt_tiny_imagenet, pvt_tiny_cifar 6 | from models.tnt import tnt_s_imagenet, tnt_s_cifar 7 | from models.tvit import vit_tiny_imagenet, vit_tiny_cifar 8 | from models.beit import beit_base_patch16_224, beit_base_patch16_cifar 9 | from models.eva import eva02_tiny_patch14_224, eva02_tiny_patch14_cifar 10 | 11 | import random 12 | import numpy as np 13 | # import timm 14 | from models.tvit import MyIdentity 15 | 16 | 17 | def setup_seed(seed): 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | np.random.seed(seed) 21 | random.seed(seed) 22 | torch.backends.cudnn.deterministic = True 23 | 24 | 25 | def load_model(model_name, data_name, in_channels=3, num_classes=10): 26 | print('-' * 50) 27 | print('MODEL NAME:', model_name) 28 | print('NUM CLASSES:', num_classes) 29 | print('-' * 50) 30 | 31 | model_kwargs = dict(drop_path_rate=0.1) 32 | model_kwargs['num_classes'] = num_classes 33 | print(model_kwargs) 34 | 35 | model = None 36 | if model_name == 'tvit' and ('cifar' in data_name): 37 | model = vit_tiny_cifar(**model_kwargs) 38 | elif model_name == 'tvit' and ('imagenet' in data_name): 39 | model = vit_tiny_imagenet(**model_kwargs) 40 | elif model_name == 'cait' and ('cifar' in data_name): 41 | model = cait_xxs24_cifar(**model_kwargs) 42 | elif model_name == 'cait' and ('imagenet' in data_name): 43 | model = cait_xxs24_imagenet(**model_kwargs) 44 | elif model_name == 'pvt' and ('cifar' in data_name): 45 | model = pvt_tiny_cifar(**model_kwargs) 46 | elif model_name == 'pvt' and ('imagenet' in data_name): 47 | model = pvt_tiny_imagenet(**model_kwargs) 48 | elif model_name == 'deit' and ('cifar' in data_name): 49 | model = deit_tiny_cifar(**model_kwargs) 50 | elif model_name == 'deit' and ('imagenet' in data_name): 51 | model = deit_tiny_imagenet(**model_kwargs) 52 | elif model_name == 'tnt' and ('cifar' in data_name): 53 | model = tnt_s_cifar(**model_kwargs) 54 | elif model_name == 'tnt' and ('imagenet' in data_name): 55 | model = tnt_s_imagenet(**model_kwargs) 56 | elif model_name == 'beit' and ('imagenet' in data_name): 57 | model = beit_base_patch16_224(**model_kwargs) 58 | elif model_name == 'beit' and ('cifar' in data_name): 59 | model = beit_base_patch16_cifar(**model_kwargs) 60 | elif model_name == 'eva' and ('imagenet' in data_name): 61 | model = eva02_tiny_patch14_224(**model_kwargs) 62 | elif model_name == 'eva' and ('cifar' in data_name): 63 | model = eva02_tiny_patch14_cifar(**model_kwargs) 64 | 65 | return model 66 | 67 | 68 | def load_modules(model, model_layers=None): 69 | assert model_layers is None or type(model_layers) is list 70 | 71 | modules = [] 72 | for module in model.modules(): 73 | if isinstance(module, torch.nn.Linear): 74 | modules.append(module) 75 | elif isinstance(module, MyIdentity): 76 | modules.append(module) 77 | 78 | modules.reverse() # reverse order 79 | if model_layers is None: 80 | model_modules = modules 81 | # model_modules = modules[:10] 82 | else: 83 | model_modules = [] 84 | for layer in model_layers: 85 | model_modules.append(modules[layer]) 86 | 87 | print('-' * 50) 88 | print('Model Layers:', model_layers) 89 | print('Model 
Modules:', model_modules) 90 | print('Model Modules Length:', len(model_modules)) 91 | print('-' * 50) 92 | 93 | return model_modules 94 | 95 | 96 | if __name__ == '__main__': 97 | from torchsummary import summary 98 | 99 | model = load_model('tvit', 'cifar10') # load_model requires a data_name; this pair selects vit_tiny_cifar ('vgg16' alone matched no branch) 100 | print(model) 101 | # summary(model, (3, 224, 224)) 102 | 103 | modules = load_modules(model) 104 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Example user template template 2 | ### Example user template 3 | 4 | # IntelliJ project files 5 | .idea 6 | *.iml 7 | out 8 | gen 9 | ### Python template 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/#use-with-ide 119 | .pdm.toml 120 | 121 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 169 | #.idea/ 170 | 171 | -------------------------------------------------------------------------------- /doctor/get_linear.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | 5 | import os 6 | import argparse 7 | import numpy as np 8 | from tqdm import tqdm 9 | import torch 10 | import torch.nn as nn 11 | 12 | import models 13 | import loaders 14 | import pickle 15 | 16 | 17 | class LinearProcess: 18 | def __init__(self): 19 | self.linears = [[] for _ in range(10)] # 10 classes (CIFAR-10) 20 | self.linear_names = [[] for _ in range(10)] 21 | 22 | def __call__(self, linears, labels, names): 23 | for i, linear in enumerate(linears): 24 | label = labels[i] 25 | name = names[i] 26 | self.linears[label].append(linear) 27 | self.linear_names[label].append(name) 28 | 29 | def save(self): 30 | data_path = '/nfs3-p1/ch/project/td/output/vit_high_cifar10/train_images' 31 | class_names = sorted([d.name for d in os.scandir(data_path) if d.is_dir()]) 32 | for i in range(10): 33 | linears = self.linears[i] 34 | result = np.empty([0, 0]) 35 | for j, linear in enumerate(linears): 36 | linear = linear.cpu().numpy() 37 | for k in range(64): 38 | if k == 0: 39 | result = linear[k + 1][:50] 40 | else: 41 | result = np.vstack([result, linear[k + 1][:50]]) 42 | class_name = class_names[i] 43 | sample_name = self.linear_names[i][j] 44 | original_path = '/nfs3-p1/ch/project/td/output/vit_high_cifar10/linears' 45 | save_path = os.path.join(original_path, class_name) 46 | file_path = os.path.join(save_path, sample_name) 47 | print(file_path) 48 | print(result.shape) 49 | np.save(file_path, result) 50 | break # debug shortcut: only the first class is saved 51 | 52 | # TODO: visualize result 53 | 54 | 55 | def main(): 56 | parser = argparse.ArgumentParser(description='') 57 | parser.add_argument('--model_name', default='', type=str, help='model name') 58 | parser.add_argument('--data_name', default='', type=str, help='data name') 59 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 60 | parser.add_argument('--model_path', default='', type=str, help='model path') 61 | parser.add_argument('--data_dir', default='', type=str, help='data path') 62 | parser.add_argument('--save_dir', default='', type=str, help='locating path') 63 | args = parser.parse_args() 64 | 65 | # ---------------------------------------- 66 | # basic configuration 67 | # 
---------------------------------------- 68 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 69 | 70 | if not os.path.exists(args.save_dir): 71 | os.makedirs(args.save_dir) 72 | 73 | print('-' * 50) 74 | print('TRAIN ON:', device) 75 | print('DATA DIR:', args.data_dir) 76 | print('SAVE DIR:', args.save_dir) 77 | print('-' * 50) 78 | 79 | # ---------------------------------------- 80 | # model/data configuration 81 | # ---------------------------------------- 82 | model = models.load_model(model_name=args.model_name, data_name=args.data_name, num_classes=args.num_classes) 83 | model.load_state_dict(torch.load(args.model_path)) 84 | model.to(device) 85 | model.eval() 86 | 87 | data_loader = loaders.load_data(args.data_dir, args.data_name, data_type='train') 88 | 89 | my_linear = LinearProcess() 90 | 91 | 92 | # ---------------------------------------- 93 | # forward 94 | # ---------------------------------------- 95 | for i, samples in enumerate(tqdm(data_loader)): 96 | inputs, labels, names = samples 97 | inputs = inputs.to(device) 98 | labels = labels.to(device) 99 | with torch.no_grad(): 100 | outputs, linears = model(inputs) 101 | my_linear(linears, labels, names) 102 | my_linear.save() 103 | 104 | if __name__ == '__main__': 105 | np.set_printoptions(threshold=np.inf) 106 | main() -------------------------------------------------------------------------------- /doctor/get_class_activation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | 5 | import os 6 | import argparse 7 | import numpy as np 8 | from tqdm import tqdm 9 | import torch 10 | import torch.nn as nn 11 | 12 | import models 13 | import loaders 14 | 15 | 16 | class HookModule: 17 | def __init__(self, module): 18 | self.module = module 19 | self.inputs = None 20 | self.outputs = None 21 | module.register_forward_hook(self._hook) 22 | 23 | def _hook(self, module, inputs, outputs): 24 | self.inputs = inputs[0] 25 | self.outputs = outputs 26 | 27 | 28 | class ComponentLocating: 29 | def __init__(self, modules, num_classes): 30 | self.modules = [HookModule(module) for module in modules] 31 | self.values = [[[] for _ in range(num_classes)] for _ in range(len(modules))] 32 | self.num_classes = num_classes 33 | 34 | def __call__(self, outputs, labels): 35 | for layer in range(12): 36 | values = self.modules[int(4 * layer + 3)].outputs 37 | values = torch.relu(values) 38 | values = values.detach().cpu().numpy() 39 | for b in range(len(labels)): 40 | self.values[layer][labels[b]].append(values[b]) 41 | print('layer: ', layer, ' shape: ', np.shape(self.values[layer][0])) 42 | 43 | 44 | def sift(self, result_path): 45 | for layer in range(12): 46 | for label in range(self.num_classes): # was hard-coded range(100) 47 | values = self.values[layer][label] # (num_images, channels) 48 | save_path = os.path.join(result_path, 'layer_{}_label_{}.npy'.format(layer, label)) 49 | np.save(save_path, values) 50 | print(save_path) 51 | 52 | 53 | 54 | 55 | 56 | def main(): 57 | parser = argparse.ArgumentParser(description='') 58 | parser.add_argument('--model_name', default='', type=str, help='model name') 59 | parser.add_argument('--data_name', default='', type=str, help='data name') 60 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 61 | parser.add_argument('--model_path', default='', type=str, help='model path') 62 | parser.add_argument('--data_dir', default='', type=str, help='data path') 63 | parser.add_argument('--save_dir', default='', type=str, 
help='locating path') 64 | args = parser.parse_args() 65 | 66 | # ---------------------------------------- 67 | # basic configuration 68 | # ---------------------------------------- 69 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 70 | 71 | if not os.path.exists(args.save_dir): 72 | os.makedirs(args.save_dir) 73 | 74 | print('-' * 50) 75 | print('TRAIN ON:', device) 76 | print('DATA DIR:', args.data_dir) 77 | print('SAVE DIR:', args.save_dir) 78 | print('-' * 50) 79 | 80 | # ---------------------------------------- 81 | # model/data configuration 82 | # ---------------------------------------- 83 | model = models.load_model(model_name=args.model_name, data_name=args.data_name, num_classes=args.num_classes) 84 | model.load_state_dict(torch.load(args.model_path)) 85 | model.to(device) 86 | model.eval() 87 | 88 | mask_dir = '/mnt/nfs/hjc/project/td/output/ideal/atts/imagenet' 89 | data_loader = loaders.load_data_mask(args.data_dir, args.data_name, mask_dir, data_type='test') 90 | 91 | modules = models.load_modules(model=model) 92 | 93 | component_locating = ComponentLocating(modules=modules, num_classes=args.num_classes) 94 | 95 | # ---------------------------------------- 96 | # forward 97 | # ---------------------------------------- 98 | for i, samples in enumerate(tqdm(data_loader)): 99 | inputs, labels, names, masks = samples 100 | inputs = inputs.to(device) 101 | labels = labels.to(device) 102 | outputs = model(inputs) 103 | component_locating(outputs, labels) 104 | 105 | component_locating.sift(result_path=args.save_dir) 106 | 107 | 108 | if __name__ == '__main__': 109 | np.set_printoptions(threshold=np.inf) 110 | main() 111 | -------------------------------------------------------------------------------- /doctor/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path 3 | import re 4 | import time 5 | from collections import OrderedDict 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | import torch 11 | from torch import nn 12 | 13 | import loaders 14 | import models 15 | from metrics import ildv 16 | 17 | # from utils.misc import reduce_dict, update, MetricLogger, SmoothedValue # unused here, and utils/misc.py is not in this repo 18 | from evaluator.default import DefaultEvaluator 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--model_name', default='', type=str, help='model name') 24 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 25 | parser.add_argument('--model_path', default='', type=str, help='model path') 26 | parser.add_argument('--data_name', default='', type=str, help='data name') 27 | parser.add_argument('--data_dir', default='', type=str, help='data directory') 28 | parser.add_argument('--save_dir', default='', type=str, help='save directory') 29 | args = parser.parse_args() 30 | 31 | # ---------------------------------------- 32 | # basic configuration 33 | # ---------------------------------------- 34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | 36 | if not os.path.exists(args.save_dir): 37 | os.makedirs(args.save_dir) 38 | 39 | print('-' * 50) 40 | print('DEVICE:', device) 41 | print('MODEL PATH:', args.model_path) 42 | print('DATA PATH:', args.data_dir) 43 | print('-' * 50) 44 | 45 | # ---------------------------------------- 46 | # test configuration 47 | # ---------------------------------------- 48 | model = models.load_model(model_name=args.model_name, data_name=args.data_name, 
num_classes=args.num_classes) 49 | # model.load_state_dict(torch.load(args.model_path)['model'],strict=False) 50 | print(model.load_state_dict(torch.load(args.model_path), strict=False)) 51 | 52 | model.to(device) 53 | 54 | test_loader = loaders.load_data(data_dir=args.data_dir, data_name=args.data_name, data_type='test', batch_size=512) 55 | 56 | # evaluation metrics (the ildv variants below are kept for reference) 57 | # evaluates = { 58 | # 'A': [ildv.MulticlassAccuracy(average='macro', num_classes=args.num_classes).to(device), 59 | # ildv.MulticlassF1Score(average='macro', num_classes=args.num_classes).to(device)], 60 | # 'B': [ildv.MulticlassFalseNegativeRate(average='macro', num_classes=args.num_classes).to(device), 61 | # ildv.MulticlassFalseDiscoveryRate(average='macro', num_classes=args.num_classes).to(device)], 62 | # 'C': [ildv.MulticlassBalancedAccuracy(args.num_classes).to(device)] 63 | # } 64 | evaluates = DefaultEvaluator(metrics=['acc', 'recall', 'precision', 'f1']) 65 | 66 | # ---------------------------------------- 67 | # each epoch 68 | # ---------------------------------------- 69 | since = time.time() 70 | 71 | scores = test(test_loader, model, evaluates, device) 72 | save_path = os.path.join(args.save_dir, '{}_{}.npy'.format(args.model_name, args.data_name)) 73 | np.save(save_path, scores) 74 | 75 | print('-' * 50) 76 | print('TIME CONSUMED', time.time() - since) 77 | 78 | 79 | # def test(test_loader, model, evaluates, device): 80 | # model.eval() 81 | 82 | # for i, samples in enumerate(tqdm(test_loader)): 83 | # # print(samples) 84 | # tmp = samples 85 | # inputs, labels, names = tmp 86 | # inputs = inputs.to(device) 87 | # labels = labels.to(device) 88 | 89 | # with torch.no_grad(): 90 | # outputs = model(inputs) 91 | 92 | # for value in evaluates.values(): 93 | # for evaluate in value: 94 | # evaluate.update(outputs, labels) 95 | 96 | # # calculate result 97 | # scores = {} 98 | # for key, value in zip(evaluates.keys(), evaluates.values()): 99 | # scores[key] = [] 100 | # for evaluate in value: 101 | # score = evaluate.compute().cpu().item() 102 | # scores[key].append(score) 103 | 104 | 105 | # print(scores) 106 | 107 | # return scores 108 | 109 | def test(test_loader, model, evaluates, device): 110 | model.eval() 111 | 112 | for i, samples in enumerate(tqdm(test_loader)): 113 | # print(samples) 114 | tmp = samples 115 | inputs, labels, names = tmp 116 | inputs = inputs.to(device) 117 | labels = labels.to(device) 118 | 119 | with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # no_grad: inference only, no autograd graph 120 | outputs = model(inputs) 121 | 122 | evaluates.update(outputs, labels) 123 | print('ACC:', evaluates.metric_acc(evaluates.outputs, evaluates.targets)) 124 | 125 | return evaluates 126 | 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /doctor/get_all_attention.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | 5 | import os 6 | import argparse 7 | import numpy as np 8 | from tqdm import tqdm 9 | import torch 10 | import torch.nn as nn 11 | 12 | import models 13 | import loaders 14 | import pickle 15 | 16 | 17 | class HookModule: 18 | def __init__(self, module): 19 | self.module = module 20 | self.inputs = None 21 | self.outputs = None 22 | module.register_forward_hook(self._hook) 23 | 24 | def _hook(self, module, inputs, outputs): 25 | self.inputs = inputs[0] 26 | self.outputs = outputs 27 | 28 | 29 | class AttentionGet: 30 | def __init__(self, modules): 31 | self.modules = 
[HookModule(module) for module in modules] 32 | 33 | def __call__(self, save_dir, outputs, labels, names): 34 | nll_loss = torch.nn.NLLLoss()(outputs, labels) 35 | 36 | # attentions = self.modules[6].outputs 37 | # grads = torch.autograd.grad(outputs=-nll_loss, inputs=attentions, retain_graph=True, create_graph=True)[0] 38 | # grads (100, 16, 65, 65) 39 | # for i, attention in enumerate(attentions): 40 | # grad = grads[i] 41 | # grad = torch.relu(grad) 42 | # attention = torch.relu(attention) 43 | # attention = attention * grad 44 | # attention = torch.mean(attention, dim=0) 45 | # attention = attention.cpu().detach().numpy() 46 | # folder_path = '/mnt/nfs/hjc/project/td/output/tmp/att_compare/npys/before_new' 47 | # count = len([f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))]) 48 | # save_path = os.path.join(folder_path, '{}'.format(count)) 49 | # np.save(save_path, attention) 50 | # print(save_path) 51 | 52 | for layer in range(12): 53 | attentions = self.modules[5 * layer + 6].outputs 54 | # grads = torch.autograd.grad(outputs=-nll_loss, inputs=attentions, retain_graph=True, create_graph=True)[0] 55 | # grads (100, 16, 65, 65) 56 | print(attentions.shape) # [100, 3, 197, 197] 57 | for i, attention in enumerate(attentions): 58 | # grad = grads[i] 59 | # grad = torch.relu(grad) 60 | # attention = torch.relu(attention) 61 | # attention = attention * grad 62 | # attention = torch.mean(attention, dim=0) 63 | attention = attention.cpu().detach().numpy() 64 | 65 | label = labels[i] 66 | fig_name = os.path.splitext(names[i])[0] 67 | # count = len([f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))]) 68 | save_path = os.path.join(save_dir, '{}_{}_{}.npy'.format(label, fig_name, layer)) 69 | np.save(save_path, attention) 70 | print(save_path, attention.shape) 71 | 72 | 73 | def main(): 74 | parser = argparse.ArgumentParser(description='') 75 | parser.add_argument('--model_name', default='', type=str, help='model name') 76 | parser.add_argument('--data_name', default='', type=str, help='data name') 77 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 78 | parser.add_argument('--model_path', default='', type=str, help='model path') 79 | parser.add_argument('--data_dir', default='', type=str, help='data path') 80 | parser.add_argument('--save_dir', default='', type=str, help='locating path') 81 | args = parser.parse_args() 82 | 83 | # ---------------------------------------- 84 | # basic configuration 85 | # ---------------------------------------- 86 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 87 | 88 | if not os.path.exists(args.save_dir): 89 | os.makedirs(args.save_dir) 90 | 91 | print('-' * 50) 92 | print('TRAIN ON:', device) 93 | print('MODEL DIR:', args.model_path) 94 | print('DATA DIR:', args.data_dir) 95 | print('SAVE DIR:', args.save_dir) 96 | print('-' * 50) 97 | 98 | # ---------------------------------------- 99 | # model/data configuration 100 | # ---------------------------------------- 101 | model = models.load_model(model_name=args.model_name, data_name=args.data_name, num_classes=args.num_classes) 102 | model.load_state_dict(torch.load(args.model_path)) 103 | model.to(device) 104 | model.eval() 105 | 106 | # mask_dir = '/mnt/nfs/hjc/project/td/output/ideal/atts/imagenet' 107 | data_loader = loaders.load_data(args.data_dir, args.data_name, data_type='test', batch_size=512) 108 | 109 | modules = models.load_modules(model=model) 110 | 111 | class_names = 
sorted([d.name for d in os.scandir(args.data_dir) if d.is_dir()]) 112 | print('class_names:', class_names) 113 | 114 | attention_get = AttentionGet(modules) 115 | 116 | # ---------------------------------------- 117 | # forward 118 | # ---------------------------------------- 119 | for i, samples in enumerate(tqdm(data_loader)): 120 | inputs, labels, names = samples 121 | inputs = inputs.to(device) 122 | labels = labels.to(device) 123 | outputs = model(inputs) 124 | attention_get(args.save_dir, outputs, labels, names) 125 | 126 | 127 | if __name__ == '__main__': 128 | np.set_printoptions(threshold=np.inf) 129 | main() 130 | -------------------------------------------------------------------------------- /doctor/sample_sift.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from tqdm import tqdm 5 | # import timm 6 | 7 | import loaders 8 | import models 9 | from utils import file_util 10 | 11 | 12 | class SampleSift: 13 | def __init__(self, num_classes, num_samples, is_high_confidence=True): 14 | self.names = [[None for j in range(num_samples)] for i in range(num_classes)] 15 | self.scores = torch.zeros((num_classes, num_samples)) 16 | self.nums = torch.zeros(num_classes, dtype=torch.long) 17 | self.is_high_confidence = is_high_confidence 18 | 19 | def __call__(self, outputs, labels, names): 20 | softmax = torch.nn.Softmax(dim=1)(outputs.detach()) 21 | scores, predicts = torch.max(softmax, dim=1) 22 | # print(scores) 23 | 24 | for i, label in enumerate(labels): 25 | if self.is_high_confidence == 1: 26 | print('-->high') 27 | if self.nums[label] == self.scores.shape[1]: 28 | score_min, index = torch.min(self.scores[label], dim=0) 29 | if scores[i] > score_min: 30 | self.scores[label][index] = scores[i] 31 | self.names[label.item()][index.item()] = names[i] 32 | else: 33 | self.scores[label][self.nums[label]] = scores[i] 34 | self.names[label.item()][self.nums[label].item()] = names[i] 35 | self.nums[label] += 1 36 | else: # sift low confidence 37 | print('-->low') 38 | if self.nums[label] == self.scores.shape[1]: 39 | score_max, index = torch.max(self.scores[label], dim=0) 40 | if scores[i] < score_max: 41 | self.scores[label][index] = scores[i] 42 | self.names[label.item()][index.item()] = names[i] 43 | else: 44 | self.scores[label][self.nums[label]] = scores[i] 45 | self.names[label.item()][self.nums[label].item()] = names[i] 46 | self.nums[label] += 1 47 | 48 | def save_image(self, input_path, output_path): 49 | 50 | class_names = sorted([d.name for d in os.scandir(input_path) if d.is_dir()]) 51 | print(self.names) 52 | 53 | for label, image_list in enumerate(self.names): 54 | for image in tqdm(image_list): 55 | class_name = class_names[label] 56 | 57 | src_path = os.path.join(input_path, class_name, str(image)) 58 | dst_path = os.path.join(output_path, class_name, str(image)) 59 | file_util.copy_file(src_path, dst_path) 60 | 61 | 62 | def main(): 63 | parser = argparse.ArgumentParser(description='') 64 | parser.add_argument('--model_name', default='', type=str, help='model name') 65 | parser.add_argument('--data_name', default='', type=str, help='data name') 66 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 67 | parser.add_argument('--batch_size', default=512, type=int, help='batch size') 68 | parser.add_argument('--model_path', default='', type=str, help='model path') 69 | parser.add_argument('--data_dir', default='', type=str, help='data dir') 70 | 
parser.add_argument('--save_dir', default='', type=str, help='sift dir') 71 | parser.add_argument('--num_samples', default=10, type=int, help='num samples') 72 | parser.add_argument('--is_high_confidence', default=1, type=int, help='is high confidence') 73 | args = parser.parse_args() 74 | 75 | # ---------------------------------------- 76 | # basic configuration 77 | # ---------------------------------------- 78 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 79 | 80 | if not os.path.exists(args.save_dir): 81 | os.makedirs(args.save_dir) 82 | 83 | print('-' * 50) 84 | print('TRAIN ON:', device) 85 | print('MODEL PATH:', args.model_path) 86 | print('DATA PATH:', args.data_dir) 87 | print('RESULT PATH:', args.save_dir) 88 | print('-' * 50) 89 | 90 | # ---------------------------------------- 91 | # model/data configuration 92 | # ---------------------------------------- 93 | model = models.load_model(model_name=args.model_name, data_name=args.data_name, num_classes=args.num_classes) 94 | model.load_state_dict(torch.load(args.model_path)) 95 | model.to(device) 96 | model.eval() 97 | 98 | data_loader = loaders.load_data(data_dir=args.data_dir, data_name=args.data_name, data_type='train', 99 | batch_size=args.batch_size) 100 | print('DATA LOAD DONE') 101 | 102 | sample_sift = SampleSift(num_classes=args.num_classes, num_samples=args.num_samples, 103 | is_high_confidence=args.is_high_confidence) 104 | 105 | print('SAMPLE SIFT DONE') 106 | 107 | # ---------------------------------------- 108 | # forward 109 | # ---------------------------------------- 110 | for samples in tqdm(data_loader): 111 | inputs, labels, names = samples 112 | inputs = inputs.to(device) 113 | labels = labels.to(device) 114 | with torch.no_grad(): 115 | outputs = model(inputs) 116 | sample_sift(outputs=outputs, labels=labels, names=names) 117 | 118 | sample_sift.save_image(args.data_dir, args.save_dir) 119 | 120 | 121 | if __name__ == '__main__': 122 | main() 123 | -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | copy from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 3 | ''' 4 | import torch 5 | from torch import nn 6 | 7 | from einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | import numpy as np 10 | 11 | 12 | # helpers 13 | 14 | def pair(t): 15 | return t if isinstance(t, tuple) else (t, t) 16 | 17 | 18 | # classes 19 | class MyIdentity(nn.Identity): 20 | def __init__(self): 21 | super().__init__() 22 | 23 | 24 | class PreNorm(nn.Module): 25 | def __init__(self, dim, fn): 26 | super().__init__() 27 | self.norm = nn.LayerNorm(dim) 28 | self.fn = fn 29 | 30 | def forward(self, x, **kwargs): 31 | return self.fn(self.norm(x), **kwargs) 32 | 33 | 34 | class FeedForward(nn.Module): 35 | def __init__(self, dim, hidden_dim, dropout=0.): 36 | super().__init__() 37 | self.net = nn.Sequential( 38 | nn.Linear(dim, hidden_dim), 39 | nn.GELU(), 40 | nn.Dropout(dropout), 41 | nn.Linear(hidden_dim, dim), 42 | nn.Dropout(dropout) 43 | ) 44 | 45 | def forward(self, x): 46 | return self.net(x) # b, n+1, d 47 | 48 | 49 | class Attention(nn.Module): 50 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.): 51 | super().__init__() 52 | self.my_identity = MyIdentity() 53 | inner_dim = dim_head * heads 54 | project_out = not (heads == 1 and dim_head == dim) 55 | 56 | self.heads = heads 57 | self.scale = dim_head ** -0.5 
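# (editor's note) dim_head ** -0.5 above is the standard 1/sqrt(d_k) attention scaling:
# it keeps the q.k^T logits near unit variance so the softmax below does not saturate.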
58 | 59 | self.attend = nn.Softmax(dim=-1) 60 | self.dropout = nn.Dropout(dropout) 61 | 62 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 63 | self.identify = nn.Identity() 64 | 65 | self.to_out = nn.Sequential( 66 | nn.Linear(inner_dim, dim), 67 | nn.Dropout(dropout) 68 | ) if project_out else nn.Identity() 69 | 70 | def forward(self, x): 71 | qkv = self.to_qkv(x).chunk(3, dim=-1) # b, n+1, d -> b, n+1, (h * d_h) * 3 72 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) # b, h, n+1, d_h 73 | 74 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale # b, h, n+1, n+1 75 | 76 | # dots = self.identify(dots) 77 | attn = self.attend(dots) # b, h, n+1, n+1 78 | 79 | attn = self.my_identity(attn) 80 | 81 | attn = self.dropout(attn) 82 | 83 | out = torch.matmul(attn, v) # b, h, n+1, d_h 84 | # out = self.identify(out) 85 | out = rearrange(out, 'b h n d -> b n (h d)') # b, n+1, (h * d_h) 86 | return self.to_out(out) # b, n+1, d 87 | 88 | 89 | class Transformer(nn.Module): 90 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.): 91 | super().__init__() 92 | self.layers = nn.ModuleList([]) 93 | for _ in range(depth): 94 | self.layers.append(nn.ModuleList([ 95 | PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), 96 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)) 97 | ])) 98 | 99 | def forward(self, x): 100 | for attn, ff in self.layers: 101 | x = attn(x) + x # residual 102 | x = ff(x) + x 103 | return x # b, n+1, d 104 | 105 | 106 | class ViT(nn.Module): 107 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, 108 | dim_head=64, dropout=0., emb_dropout=0.): 109 | super().__init__() 110 | image_height, image_width = pair(image_size) 111 | patch_height, patch_width = pair(patch_size) 112 | 113 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
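# (editor's note) worked example for the shapes that appear in the doctor/ scripts:
# image_size=224, patch_size=16 -> num_patches = (224 // 16) * (224 // 16) = 196;
# with the cls token prepended, the sequence length is 197, hence the
# [batch, heads, 197, 197] attention maps logged by get_all_attention.py.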
114 | 115 | num_patches = (image_height // patch_height) * (image_width // patch_width) # n 116 | patch_dim = channels * patch_height * patch_width 117 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 118 | 119 | self.to_patch_embedding = nn.Sequential( 120 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width), 121 | nn.LayerNorm(patch_dim), 122 | nn.Linear(patch_dim, dim), 123 | nn.LayerNorm(dim), 124 | ) 125 | 126 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) # 1, n+1, d 127 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 128 | self.dropout = nn.Dropout(emb_dropout) 129 | 130 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 131 | 132 | self.pool = pool 133 | # self.to_latent = nn.Identity() 134 | 135 | self.mlp_head = nn.Sequential( 136 | nn.LayerNorm(dim), 137 | nn.Linear(dim, num_classes) 138 | ) 139 | 140 | def forward(self, img): 141 | x = self.to_patch_embedding(img) # b, n, d 142 | b, n, _ = x.shape 143 | 144 | cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b) # b, 1, d 145 | x = torch.cat((cls_tokens, x), dim=1) # b, n+1, d 146 | x += self.pos_embedding[:, :(n + 1)] # b, n+1, d 147 | x = self.dropout(x) # b, n+1, d 148 | 149 | x = self.transformer(x) # b, n+1, d 150 | 151 | x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0] # b, 1, d 152 | 153 | # x = self.to_latent(x) 154 | return self.mlp_head(x) # n, c -------------------------------------------------------------------------------- /engines/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import shutil 4 | import time 5 | from tqdm import tqdm 6 | 7 | import torch 8 | from torch import nn 9 | from torch import optim 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | import loaders 13 | import models 14 | import metrics 15 | from utils.train_util import AverageMeter, ProgressMeter 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser(description='') 20 | parser.add_argument('--model_name', default='', type=str, help='model name') 21 | parser.add_argument('--data_name', default='', type=str, help='data name') 22 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 23 | parser.add_argument('--num_epochs', default=200, type=int, help='num epochs') 24 | parser.add_argument('--model_dir', default='', type=str, help='model dir') 25 | parser.add_argument('--data_dir', default='', type=str, help='data dir') 26 | parser.add_argument('--log_dir', default='', type=str, help='log dir') 27 | args = parser.parse_args() 28 | 29 | # ---------------------------------------- 30 | # basic configuration 31 | # ---------------------------------------- 32 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 33 | 34 | # train_dir = os.path.join(args.data_dir, 'train') 35 | # test_dir = os.path.join(args.data_dir, 'test') 36 | train_dir = '/datasets/ILSVRC2012/train' 37 | test_dir = '/nfs3/hjc/datasets/imagenet1k/val' 38 | 39 | if not os.path.exists(args.model_dir): 40 | os.makedirs(args.model_dir) 41 | if os.path.exists(args.log_dir): 42 | shutil.rmtree(args.log_dir) 43 | 44 | print('-' * 50) 45 | print('TRAIN ON:', device) 46 | print('DATA DIR:', args.data_dir) 47 | print('MODEL DIR:', args.model_dir) 48 | print('LOG DIR:', args.log_dir) 49 | print('-' * 50) 50 | 51 | # ---------------------------------------- 52 | # trainer configuration 53 
| # ---------------------------------------- 54 | model = models.load_model(args.model_name, num_classes=args.num_classes) 55 | model.to(device) 56 | 57 | train_loader = loaders.load_data(train_dir, args.data_name, data_type='train') 58 | test_loader = loaders.load_data(test_dir, args.data_name, data_type='test') 59 | 60 | criterion = nn.CrossEntropyLoss() 61 | optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) 62 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.num_epochs) 63 | 64 | writer = SummaryWriter(args.log_dir) 65 | 66 | # ---------------------------------------- 67 | # each epoch 68 | # ---------------------------------------- 69 | since = time.time() 70 | 71 | best_acc = None 72 | best_epoch = None 73 | 74 | for epoch in tqdm(range(args.num_epochs)): 75 | loss, acc1, acc5 = train(train_loader, model, criterion, optimizer, device) 76 | writer.add_scalar(tag='training loss', scalar_value=loss.avg, global_step=epoch) 77 | writer.add_scalar(tag='training acc1', scalar_value=acc1.avg, global_step=epoch) 78 | loss, acc1, acc5 = test(test_loader, model, criterion, device) 79 | writer.add_scalar(tag='test loss', scalar_value=loss.avg, global_step=epoch) 80 | writer.add_scalar(tag='test acc1', scalar_value=acc1.avg, global_step=epoch) 81 | 82 | # ---------------------------------------- 83 | # save best model 84 | # ---------------------------------------- 85 | if best_acc is None or best_acc < acc1.avg: 86 | best_acc = acc1.avg 87 | best_epoch = epoch 88 | torch.save(model.state_dict(), os.path.join(args.model_dir, 'model_ori.pth')) 89 | 90 | scheduler.step() 91 | 92 | print('COMPLETE !!!') 93 | print('BEST ACC', best_acc) 94 | print('BEST EPOCH', best_epoch) 95 | print('TIME CONSUMED', time.time() - since) 96 | 97 | 98 | def train(train_loader, model, criterion, optimizer, device): 99 | loss_meter = AverageMeter('Loss', ':.4e') 100 | acc1_meter = AverageMeter('Acc@1', ':6.2f') 101 | acc5_meter = AverageMeter('Acc@5', ':6.2f') 102 | progress = ProgressMeter(total=len(train_loader), step=20, prefix='Training', 103 | meters=[loss_meter, acc1_meter, acc5_meter]) 104 | 105 | model.train() 106 | 107 | for i, samples in enumerate(train_loader): 108 | inputs, labels, _ = samples 109 | inputs = inputs.to(device) 110 | labels = labels.to(device) 111 | 112 | outputs = model(inputs) 113 | loss = criterion(outputs, labels) 114 | acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5)) 115 | 116 | loss_meter.update(loss.item(), inputs.size(0)) 117 | acc1_meter.update(acc1.item(), inputs.size(0)) 118 | acc5_meter.update(acc5.item(), inputs.size(0)) 119 | 120 | optimizer.zero_grad() # 1 121 | loss.backward() # 2 122 | optimizer.step() # 3 123 | 124 | progress.display(i) 125 | 126 | return loss_meter, acc1_meter, acc5_meter 127 | 128 | 129 | def test(test_loader, model, criterion, device): 130 | loss_meter = AverageMeter('Loss', ':.4e') 131 | acc1_meter = AverageMeter('Acc@1', ':6.2f') 132 | acc5_meter = AverageMeter('Acc@5', ':6.2f') 133 | progress = ProgressMeter(total=len(test_loader), step=20, prefix='Test', 134 | meters=[loss_meter, acc1_meter, acc5_meter]) 135 | model.eval() 136 | 137 | for i, samples in enumerate(test_loader): 138 | inputs, labels, _ = samples 139 | inputs = inputs.to(device) 140 | labels = labels.to(device) 141 | 142 | with torch.set_grad_enabled(False): 143 | outputs = model(inputs) 144 | loss = criterion(outputs, labels) 145 | acc1, acc5 = metrics.accuracy(outputs, labels, topk=(1, 5)) 146 | 147 | 
loss_meter.update(loss.item(), inputs.size(0)) 148 | acc1_meter.update(acc1.item(), inputs.size(0)) 149 | acc5_meter.update(acc5.item(), inputs.size(0)) 150 | 151 | progress.display(i) 152 | 153 | return loss_meter, acc1_meter, acc5_meter 154 | 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /doctor/get_attention.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | 5 | import os 6 | import argparse 7 | import numpy as np 8 | from tqdm import tqdm 9 | import torch 10 | import torch.nn as nn 11 | 12 | import models 13 | import loaders 14 | import pickle 15 | 16 | 17 | class HookModule: 18 | def __init__(self, module): 19 | self.module = module 20 | self.inputs = None 21 | self.outputs = None 22 | module.register_forward_hook(self._hook) 23 | 24 | def _hook(self, module, inputs, outputs): 25 | self.inputs = inputs[0] 26 | self.outputs = outputs 27 | 28 | 29 | class AttentionGet: 30 | def __init__(self, modules): 31 | self.modules = [HookModule(module) for module in modules] 32 | 33 | def __call__(self, names, outputs, labels): 34 | nll_loss = torch.nn.NLLLoss()(outputs, labels) 35 | attentions = self.modules[6].outputs # (batch, heads, tokens, tokens); tokens = 197 here 36 | 37 | # wall = 0.015 38 | # heads_num = 3 39 | # for i, attention in enumerate(attentions): 40 | # attention_value = attention[:, 0, :] #(16, 65, 65) -> (16, 65) 41 | # attention_value = attention_value[:, 1:65] #(16, 65) -> (16, 64) 42 | # # attention_value = torch.where(attention_value < wall, 0, attention_value) 43 | # attention_value = torch.sum(attention_value, dim=-1) #(16, 64) -> (16) 44 | # sorted_value, indices = torch.sort(attention_value) 45 | # print(indices[0], ' ', indices[1], ' ', indices[2]) 46 | # tmp_attention_heads = [] 47 | # for j in range(heads_num): 48 | # tmp = attention[indices[j]] 49 | # # tmp = torch.where(tmp < wall, 0, tmp) 50 | # tmp_attention_heads.append(tmp) 51 | # attention_max = torch.stack(tmp_attention_heads, dim=0) 52 | # attention_max = attention_max[:, 0, :] 53 | # attention_max = torch.mean(attention_max, dim=0) 54 | # attention_max = attention_max[1:65] 55 | # attention_max = attention_max.reshape((8, 8)) 56 | # attention_max = attention_max.cpu().detach().numpy() 57 | # save_path = os.path.join('/nfs1/ch/project/td/output/tmp/high/npys/min_npys', 'img_{}'.format(i)) 58 | # np.save(save_path, attention_max) 59 | 60 | grads = torch.autograd.grad(outputs=-nll_loss, inputs=attentions, retain_graph=True, create_graph=True)[0] 61 | # grads: (batch, heads, 197, 197), same shape as attentions 62 | for i, attention in enumerate(attentions): 63 | # if i == 20: 64 | # break 65 | grad = grads[i] 66 | grad = torch.relu(grad) #(heads, 197, 197) 67 | grad = grad[:, 0, :] #(heads, 197, 197) -> (heads, 197): keep the cls-token row 68 | grad = grad[:, 1:197] #(heads, 196): drop the cls column 69 | 70 | # grad_sum = torch.sum(grad, dim=-1) #(16, 64) -> (16) 71 | # sorted_grad, indices = torch.sort(grad_sum, descending=True) # descending, largest first 72 | #(heads, 197, 197) 73 | attention = attention[:, 0, :] #(heads, 197) 74 | attention = attention[:, 1:197] #(heads, 196) 75 | attention = attention * grad #(heads, 196) 76 | attention = torch.mean(attention, dim=0) 77 | attention = attention.reshape((14, 14)) #(14, 14) patch grid 78 | attention = attention.cpu().detach().numpy() 79 | save_path = os.path.join('/nfs/ch/project/td/output/tmp/low/all/low_npys', 'img_{}'.format(i)) 80 | np.save(save_path, attention) 81 | print(save_path) 82 | # for j, head_attention in enumerate(attention): 83 | # save_path = 
os.path.join('/nfs/ch/project/td/output/tmp/high/npys/img_{}'.format(i), 'rank_{}_head_{}.npy'.format(indices[j], j)) 84 | # print(save_path) 85 | # head_attention = head_attention[1:65] 86 | # head_attention = head_attention.reshape((8, 8)) 87 | # head_attention = head_attention.cpu().detach().numpy() 88 | # np.save(save_path, head_attention) 89 | 90 | 91 | def main(): 92 | parser = argparse.ArgumentParser(description='') 93 | parser.add_argument('--model_name', default='', type=str, help='model name') 94 | parser.add_argument('--data_name', default='', type=str, help='data name') 95 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 96 | parser.add_argument('--model_path', default='', type=str, help='model path') 97 | parser.add_argument('--data_dir', default='', type=str, help='data path') 98 | parser.add_argument('--save_dir', default='', type=str, help='locating path') 99 | args = parser.parse_args() 100 | 101 | # ---------------------------------------- 102 | # basic configuration 103 | # ---------------------------------------- 104 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 105 | 106 | if not os.path.exists(args.save_dir): 107 | os.makedirs(args.save_dir) 108 | 109 | print('-' * 50) 110 | print('TRAIN ON:', device) 111 | print('DATA DIR:', args.data_dir) 112 | print('SAVE DIR:', args.save_dir) 113 | print('-' * 50) 114 | 115 | # ---------------------------------------- 116 | # model/data configuration 117 | # ---------------------------------------- 118 | model = models.load_model(model_name=args.model_name, data_name=args.data_name, num_classes=args.num_classes) 119 | model.load_state_dict(torch.load(args.model_path)) 120 | model.to(device) 121 | model.eval() 122 | 123 | mask_dir = '/mnt/nfs/hjc/project/td/output/ideal/atts/imagenet' 124 | data_loader = loaders.load_data_mask(args.data_dir, args.data_name, mask_dir, data_type='test') 125 | 126 | modules = models.load_modules(model=model) 127 | 128 | attention_get = AttentionGet(modules) 129 | 130 | # ---------------------------------------- 131 | # forward 132 | # ---------------------------------------- 133 | for i, samples in enumerate(tqdm(data_loader)): 134 | inputs, labels, names, masks = samples 135 | inputs = inputs.to(device) 136 | labels = labels.to(device) 137 | outputs = model(inputs) 138 | attention_get(names, outputs, labels) 139 | 140 | 141 | if __name__ == '__main__': 142 | np.set_printoptions(threshold=np.inf) 143 | main() -------------------------------------------------------------------------------- /doctor/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | from torch import nn 9 | from torch import optim 10 | from torch.utils.tensorboard import SummaryWriter 11 | import shutil 12 | import timm.scheduler as timm_scheduler 13 | 14 | import loaders 15 | import models 16 | from metrics import ildv 17 | import math 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser(description='') 21 | parser.add_argument('--model_name', default='', type=str, help='model name') 22 | parser.add_argument('--data_name', default='', type=str, help='data name') 23 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 24 | parser.add_argument('--num_epochs', default=200, type=int, help='num epochs') 25 | parser.add_argument('--batch_size', default=512, type=int, help='batch size') 26 
| parser.add_argument('--model_dir', default='', type=str, help='model dir') 27 | parser.add_argument('--result_name', default='', type=str, help='result name') 28 | parser.add_argument('--train_data_dir', default='', type=str, help='train_data_dir') 29 | parser.add_argument('--test_data_dir', default='', type=str, help='test_data_dir') 30 | parser.add_argument('--log_dir', default='', type=str, help='log dir') 31 | args = parser.parse_args() 32 | 33 | # ---------------------------------------- 34 | # basic configuration 35 | # ---------------------------------------- 36 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | 38 | train_dir = args.train_data_dir 39 | test_dir = args.test_data_dir 40 | 41 | if not os.path.exists(args.model_dir): 42 | os.makedirs(args.model_dir) 43 | if os.path.exists(args.log_dir): 44 | shutil.rmtree(args.log_dir) 45 | 46 | print('-' * 50) 47 | print('TRAIN ON:', device) 48 | print('TRAIN DATA DIR:', args.train_data_dir) 49 | print('TEST DATA DIR:', args.test_data_dir) 50 | print('MODEL DIR:', args.model_dir) 51 | print('LOG DIR:', args.log_dir) 52 | print('RESULT_NAME:', args.result_name + '.pth') 53 | print('-' * 50) 54 | 55 | # ---------------------------------------- 56 | # trainer configuration 57 | # ---------------------------------------- 58 | model = models.load_model(args.model_name, args.data_name, num_classes=args.num_classes) 59 | # model.load_state_dict(torch.load('/mnt/nfs/hjc/project/td/output/train_models/pth/imagenet/deit_imagenet_baseline.pth'), strict=False) 60 | model.to(device) 61 | print('MODEL LOAD DONE') 62 | train_loader = loaders.load_data(train_dir, args.data_name, data_type='train', batch_size=args.batch_size) 63 | print('TRAIN LOAD DONE') 64 | test_loader = loaders.load_data(test_dir, args.data_name, data_type='test', batch_size=args.batch_size) 65 | print('TEST LOAD DONE') 66 | modules = models.load_modules(model=model) 67 | 68 | criterion = nn.CrossEntropyLoss() 69 | 70 | batch_size = args.batch_size 71 | max_learn_rate = 0.0005 * (batch_size / 512) 72 | optimizer = optim.AdamW(params=model.parameters(), lr=max_learn_rate, weight_decay=0.05, eps=1e-8) 73 | num_steps = int(args.num_epochs * len(train_loader)) 74 | warmup_steps = 0 75 | warmup_lr = 1e-06 76 | scheduler = timm_scheduler.CosineLRScheduler( 77 | optimizer, 78 | t_initial=(num_steps - warmup_steps), 79 | lr_min=1e-05, 80 | warmup_lr_init=warmup_lr, 81 | warmup_t=warmup_steps, 82 | cycle_limit=1, 83 | t_in_epochs=False, 84 | ) 85 | 86 | # optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-2) 87 | # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.num_epochs) 88 | 89 | writer = SummaryWriter(args.log_dir) 90 | 91 | # ---------------------------------------- 92 | # each epoch 93 | # ---------------------------------------- 94 | since = time.time() 95 | 96 | best_acc = None 97 | best_epoch = None 98 | 99 | for epoch in tqdm(range(args.num_epochs)): 100 | acc1, loss = train(train_loader, model, criterion, optimizer, device, args) 101 | writer.add_scalar(tag='training acc1', scalar_value=acc1, global_step=epoch) 102 | writer.add_scalar(tag='training loss', scalar_value=loss, global_step=epoch) 103 | acc1, loss = test(test_loader, model, criterion, device, args) 104 | writer.add_scalar(tag='test acc1', scalar_value=acc1, global_step=epoch) 105 | writer.add_scalar(tag='test loss', scalar_value=loss, global_step=epoch) 106 | 107 | # ---------------------------------------- 108 | # save 
best model 109 | # ---------------------------------------- 110 | if best_acc is None or best_acc < acc1: 111 | best_acc = acc1 112 | best_epoch = epoch 113 | torch.save(model.state_dict(), os.path.join(args.model_dir, args.result_name + '.pth')) 114 | print(best_acc) 115 | 116 | scheduler.step(epoch) 117 | 118 | print('COMPLETE !!!') 119 | print('BEST ACC', best_acc) 120 | print('BEST EPOCH', best_epoch) 121 | print('TIME CONSUMED', time.time() - since) 122 | 123 | 124 | def train(train_loader, model, criterion, optimizer, device, args): 125 | 126 | model.train() 127 | 128 | acc1 = ildv.MulticlassAccuracy(average='macro', num_classes=args.num_classes).to(device) 129 | 130 | print('LEN OF TRAIN_LOADER') 131 | print(len(train_loader)) 132 | for samples in tqdm(enumerate(train_loader)): 133 | _, tmp = samples 134 | inputs, labels, names = tmp 135 | inputs = inputs.to(device) 136 | labels = labels.to(device) 137 | outputs = model(inputs) 138 | loss = criterion(outputs, labels) 139 | acc1.update(outputs, labels) 140 | optimizer.zero_grad() # 1 141 | loss.backward() # 2 142 | optimizer.step() # 3 143 | 144 | return acc1.compute().item(), loss.item() 145 | 146 | 147 | def test(test_loader, model, criterion, device, args): 148 | 149 | model.eval() 150 | 151 | acc1 = ildv.MulticlassAccuracy(average='macro', num_classes=args.num_classes).to(device) 152 | 153 | for samples in enumerate(test_loader): 154 | _, tmp = samples 155 | inputs, labels, names = tmp 156 | inputs = inputs.to(device) 157 | labels = labels.to(device) 158 | 159 | with torch.set_grad_enabled(False): 160 | outputs = model(inputs) 161 | loss = criterion(outputs, labels) 162 | 163 | acc1.update(outputs, labels) 164 | 165 | return acc1.compute().item(), loss.item() 166 | 167 | 168 | if __name__ == '__main__': 169 | main() 170 | -------------------------------------------------------------------------------- /doctor/get_all_channel.py: -------------------------------------------------------------------------------- 1 | """ 2 | activation: 3 | each layer 4 | 5 | gradient: 6 | output to each layer 7 | """ 8 | import sys 9 | 10 | sys.path.append('.') 11 | 12 | import argparse 13 | import torch 14 | import numpy as np 15 | import os 16 | from tqdm import tqdm 17 | 18 | import models 19 | import loaders 20 | 21 | 22 | class HookModule: 23 | def __init__(self, module): 24 | self.module = module 25 | self.inputs = None 26 | self.outputs = None 27 | module.register_forward_hook(self._hook) 28 | 29 | def grads(self, outputs, inputs=None, retain_graph=True, create_graph=True): 30 | if inputs is None: 31 | inputs = self.outputs # default the output dim 32 | 33 | return torch.autograd.grad(outputs=outputs, 34 | inputs=inputs, 35 | retain_graph=retain_graph, 36 | create_graph=create_graph)[0] 37 | 38 | def _hook(self, module, inputs, outputs): 39 | self.inputs = inputs[0] 40 | self.outputs = outputs 41 | 42 | 43 | def _normalization(data, axis=None, bot=False): 44 | assert axis in [None, 0, 1] 45 | _max = np.max(data, axis=axis) 46 | if bot: 47 | _min = np.zeros(_max.shape) 48 | else: 49 | _min = np.min(data, axis=axis) 50 | _range = _max - _min 51 | if axis == 1: 52 | _norm = ((data.T - _min) / (_range + 1e-5)).T 53 | else: 54 | _norm = (data - _min) / (_range + 1e-5) 55 | return _norm 56 | 57 | 58 | class getChannels: 59 | def __init__(self, modules, num_classes): 60 | self.modules = [HookModule(module) for module in modules] 61 | # [num_modules, num_labels, num_images, channels] 62 | self.grads = [[[] for _ in range(num_classes)] for _ in 
range(len(modules))] 63 | self.activations = [[[] for _ in range(num_classes)] for _ in range(len(modules))] 64 | 65 | def __call__(self, outputs, labels): 66 | # nll_loss = torch.nn.NLLLoss()(outputs, labels) 67 | for layer, module in enumerate(self.modules): 68 | # grads = module.grads(-nll_loss, module.outputs) 69 | # grads = torch.relu(grads) 70 | # grads = grads.detach().cpu().numpy() 71 | activations = module.outputs 72 | activations = torch.relu(activations) 73 | activations = activations.detach().cpu().numpy() 74 | for b in range(len(labels)): 75 | # self.grads[layer][labels[b]].append(grads[b]) 76 | self.activations[layer][labels[b]].append(activations[b]) 77 | 78 | def sift(self, result_path, threshold): 79 | for layer in range(12): 80 | # grads = self.grads[3 + layer * 5] 81 | activations = self.activations[3 + layer * 5] 82 | # grads = np.asarray(grads) 83 | activations = np.asarray(activations) 84 | # if len(grads.shape) == 4: 85 | # grads = np.squeeze(grads[:, :, 0, :]) # [num_labels, num_images, channels] 86 | if len(activations.shape) == 4: 87 | activations = np.squeeze(activations[:, :, 0, :]) # [num_labels, num_images, channels] 88 | 89 | values = activations 90 | 91 | for label, value in enumerate(values): 92 | value = _normalization(value, axis=1) # [num_images, channels] 93 | 94 | value_path = os.path.join(result_path, '{}_{}.npy'.format(label, layer)) 95 | np.save(value_path, value) 96 | print(value_path) 97 | 98 | # mask = np.zeros(value.shape) # [num_images, channels] 99 | # mask[np.where(value > threshold)] = 1 # [num_images, channels] 100 | # mask = np.sum(mask, axis=0) # [channels] 101 | # mask = np.where(mask > 2, 1, 0) # [channels] 102 | 103 | # mask_path = os.path.join(result_path, '{}.npy'.format(label)) 104 | # np.save(mask_path, mask) 105 | # print(mask_path) 106 | 107 | 108 | def main(): 109 | parser = argparse.ArgumentParser(description='') 110 | parser.add_argument('--model_name', default='', type=str, help='model name') 111 | parser.add_argument('--data_name', default='', type=str, help='data name') 112 | parser.add_argument('--in_channels', default='', type=int, help='in channels') 113 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 114 | parser.add_argument('--model_path', default='', type=str, help='model path') 115 | parser.add_argument('--data_path', default='', type=str, help='data path') 116 | parser.add_argument('--grad_path', default='', type=str, help='grad path') 117 | parser.add_argument('--theta', default='', type=float, help='theta') 118 | parser.add_argument('--device_index', default='0', type=str, help='device index') 119 | args = parser.parse_args() 120 | 121 | # ---------------------------------------- 122 | # basic configuration 123 | # ---------------------------------------- 124 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index 125 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 126 | 127 | if not os.path.exists(args.grad_path): 128 | os.makedirs(args.grad_path) 129 | 130 | print('-' * 50) 131 | print('TRAIN ON:', device) 132 | print('DATA PATH:', args.data_path) 133 | print('RESULT PATH:', args.grad_path) 134 | print('-' * 50) 135 | 136 | # ---------------------------------------- 137 | # model/data configuration 138 | # ---------------------------------------- 139 | model = models.load_model(model_name=args.model_name, data_name=args.data_name, num_classes=args.num_classes) 140 | model.load_state_dict(torch.load(args.model_path)) 141 | model.to(device) 142 | 
model.eval()
143 |
144 |     data_loader = loaders.load_data(data_dir=args.data_path, data_name=args.data_name, data_type='test', batch_size=512)
145 |
146 |     modules = models.load_modules(model=model)
147 |     print(modules)
148 |
149 |     get_channels = getChannels(modules=modules, num_classes=args.num_classes)  # was hard-coded to 10
150 |
151 |     # ----------------------------------------
152 |     # forward
153 |     # ----------------------------------------
154 |     for i, samples in enumerate(tqdm(data_loader)):
155 |         inputs, labels, names = samples
156 |         inputs = inputs.to(device)
157 |         labels = labels.to(device)
158 |         outputs = model(inputs)
159 |
160 |         get_channels(outputs, labels)
161 |
162 |     get_channels.sift(result_path=args.grad_path, threshold=args.theta)
163 |
164 |
165 | if __name__ == '__main__':
166 |     np.set_printoptions(threshold=np.inf)
167 |     main()
--------------------------------------------------------------------------------
/doctor/get_grads.py:
--------------------------------------------------------------------------------
1 | """
2 | activation:
3 |     each layer
4 |
5 | gradient:
6 |     output to each layer
7 | """
8 | import sys
9 |
10 | sys.path.append('.')
11 |
12 | import argparse
13 | import torch
14 | import numpy as np
15 | import os
16 | from tqdm import tqdm
17 |
18 | import models
19 | import loaders
20 |
21 |
22 | class HookModule:
23 |     def __init__(self, module):
24 |         self.inputs = None
25 |         self.outputs = None
26 |         module.register_forward_hook(self._hook)
27 |
28 |     def grads(self, outputs, inputs=None, retain_graph=True, create_graph=True):
29 |         if inputs is None:
30 |             inputs = self.outputs  # default to the hooked output tensor
31 |
32 |         return torch.autograd.grad(outputs=outputs,
33 |                                    inputs=inputs,
34 |                                    retain_graph=retain_graph,
35 |                                    create_graph=create_graph)[0]
36 |
37 |     def _hook(self, module, inputs, outputs):
38 |         self.inputs = inputs[0]
39 |         self.outputs = outputs
40 |
41 |
42 | def _normalization(data, axis=None, bot=False):
43 |     assert axis in [None, 0, 1]
44 |     _max = np.max(data, axis=axis)
45 |     if bot:
46 |         _min = np.zeros(_max.shape)
47 |     else:
48 |         _min = np.min(data, axis=axis)
49 |     _range = _max - _min
50 |     if axis == 1:
51 |         _norm = ((data.T - _min) / (_range + 1e-5)).T
52 |     else:
53 |         _norm = (data - _min) / (_range + 1e-5)
54 |     return _norm
55 |
56 |
57 | class GradCalculate:
58 |     def __init__(self, modules, num_classes):
59 |         self.modules = [HookModule(module) for module in modules]
60 |         self.values = [[[] for _ in range(num_classes)] for _ in range(12)]  # was a hard-coded 50
61 |         self.num_classes = num_classes
62 |         # [num_classes, num_images, channels]
63 |
64 |     def __call__(self, outputs, labels):
65 |         nll_loss = torch.nn.NLLLoss()(outputs, labels)
66 |         for block in range(12):
67 |             layer = int((11 - block) * 5 + 2)
68 |             module = self.modules[layer]
69 |
70 |             values = module.grads(-nll_loss, module.outputs)
71 |             values = torch.relu(values)
72 |
73 |             values = values.detach().cpu().numpy()
74 |
75 |             for b in range(len(labels)):
76 |                 self.values[block][labels[b]].append(values[b])
77 |
78 |     def sift(self, result_path, threshold):
79 |         # layer = self.layer
80 |         # for label in range(self.num_classes):
81 |         #     values = self.values[label]  # (num_images, tokens, channels)
82 |         #     values = np.asarray(values)
83 |         #     # values = np.sum(values, axis=1)  # (num_images, channels)
84 |         #     if len(values.shape) > 2:
85 |         #         values = np.squeeze(values[:, 0, :])  # [num_classes, num_images, channels]
86 |
87 |         #     values = _normalization(values, axis=1)
88 |
89 |         #     mask = np.zeros(values.shape)
90 |         #     mask[np.where(values > threshold)] = 1
91 |
92 |         #     result_path = os.path.join(result_path, block)
93 |         #     mask_path = os.path.join(result_path, 'block_{}_layer_{}_class_{}.npy'.format(block, layer, label))
94 |         #     np.save(mask_path, mask)
95 |         #     print(mask_path)
96 |
97 |         for block in range(12):
98 |             layer = int((11 - block) * 5 + 2)
99 |             for label in range(self.num_classes):  # was a hard-coded 10
100 |                 values = self.values[block][label]  # (num_images, tokens, channels)
101 |                 values = np.asarray(values)
102 |                 # values = np.sum(values, axis=1)  # (num_images, channels)
103 |                 if len(values.shape) > 2:
104 |                     values = np.squeeze(values[:, 0, :])  # [num_classes, num_images, channels]
105 |
106 |                 values = _normalization(values, axis=1)
107 |
108 |                 mask = np.zeros(values.shape)
109 |                 mask[np.where(values > threshold)] = 1
110 |
111 |                 result_path_block = os.path.join(result_path, str(block))
112 |                 os.makedirs(result_path_block, exist_ok=True)  # main() only creates grad_path itself, not the per-block subdirectories
113 |                 mask_path = os.path.join(result_path_block, 'block_{}_layer_{}_class_{}.npy'.format(block, layer, label))
114 |                 np.save(mask_path, mask)
115 |                 print(mask_path)
116 |
117 |
118 | def main():
119 |     parser = argparse.ArgumentParser(description='')
120 |     parser.add_argument('--model_name', default='', type=str, help='model name')
121 |     parser.add_argument('--data_name', default='', type=str, help='data name')
122 |     parser.add_argument('--in_channels', default='', type=int, help='in channels')
123 |     parser.add_argument('--num_classes', default='', type=int, help='num classes')
124 |     parser.add_argument('--model_path', default='', type=str, help='model path')
125 |     parser.add_argument('--data_path', default='', type=str, help='data path')
126 |     parser.add_argument('--grad_path', default='', type=str, help='grad path')
127 |     parser.add_argument('--theta', default='', type=float, help='theta')
128 |     parser.add_argument('--device_index', default='0', type=str, help='device index')
129 |     args = parser.parse_args()
130 |
131 |     # ----------------------------------------
132 |     # basic configuration
133 |     # ----------------------------------------
134 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index
135 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
136 |
137 |     if not os.path.exists(args.grad_path):
138 |         os.makedirs(args.grad_path)
139 |
140 |     print('-' * 50)
141 |     print('TRAIN ON:', device)
142 |     print('DATA PATH:', args.data_path)
143 |     print('RESULT PATH:', args.grad_path)
144 |     print('-' * 50)
145 |
146 |     # ----------------------------------------
147 |     # model/data configuration
148 |     # ----------------------------------------
149 |     model = models.load_model(model_name=args.model_name, data_name=args.data_name, num_classes=args.num_classes)
150 |     model.load_state_dict(torch.load(args.model_path))
151 |     model.to(device)
152 |     model.eval()
153 |
154 |     data_loader = loaders.load_data(data_dir=args.data_path, data_name=args.data_name, data_type='test')
155 |
156 |     modules = models.load_modules(model=model)
157 |
158 |     grad_calculate = GradCalculate(modules=modules, num_classes=args.num_classes)
159 |
160 |     # ----------------------------------------
161 |     # forward
162 |     # ----------------------------------------
163 |     for i, samples in enumerate(tqdm(data_loader)):
164 |         inputs, labels, _ = samples
165 |         inputs = inputs.to(device)
166 |         labels = labels.to(device)
167 |         outputs = model(inputs)
168 |
169 |         grad_calculate(outputs, labels)
170 |
171 |     grad_calculate.sift(result_path=args.grad_path, threshold=args.theta)
172 |
173 |
174 | if __name__ == '__main__':
175 |     np.set_printoptions(threshold=np.inf)
176 |     main()
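The hook-plus-autograd pattern shared by get_all_channel.py and get_grads.py above is easy to lose inside the full scripts, so here is a minimal, self-contained sketch of it; the toy model, the hooked layer, and the 0.5 threshold are illustrative assumptions, not values from this repository. Note that NLLLoss applied to raw logits reduces to the negative mean true-class logit, so differentiating its negation with respect to a hooked activation yields per-channel gradients of the target logit, which is what GradCalculate thresholds into class-wise masks.

import torch
from torch import nn

# Toy stand-in for the ViT blocks (illustrative only).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 10))

hooked = {}
# Cache the hidden activation on every forward pass, as HookModule._hook does.
model[1].register_forward_hook(lambda module, inputs, outputs: hooked.update(out=outputs))

inputs = torch.randn(4, 8)
labels = torch.randint(0, 10, (4,))
outputs = model(inputs)

# On raw logits, NLLLoss is -mean(true-class logit), so -nll_loss is the mean
# true-class logit, matching GradCalculate.__call__.
nll_loss = nn.NLLLoss()(outputs, labels)
grads = torch.autograd.grad(outputs=-nll_loss, inputs=hooked['out'], retain_graph=True)[0]
grads = torch.relu(grads)       # keep only positive (class-supporting) gradients
mask = (grads > 0.5).float()    # sift() min-max normalizes per image before thresholding
print(grads.shape, mask.shape)  # torch.Size([4, 16]) for both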
--------------------------------------------------------------------------------
/doctor/test_mask.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path
3 | import re
4 | import time
5 | from collections import OrderedDict
6 |
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 | import torch
11 | from torch import nn
12 |
13 | import loaders
14 | import models
15 | from metrics import ildv
16 |
17 | from utils.misc import reduce_dict, update, MetricLogger, SmoothedValue
18 | from evaluator.default import DefaultEvaluator
19 | from models.tvit_mask import vit_tiny_imagenet
20 |
21 |
22 | def main():
23 |     parser = argparse.ArgumentParser()
24 |     parser.add_argument('--model_name', default='', type=str, help='model name')
25 |     parser.add_argument('--num_classes', default='', type=int, help='num classes')
26 |     parser.add_argument('--model_path', default='', type=str, help='model path')
27 |     parser.add_argument('--data_name', default='', type=str, help='data name')
28 |     parser.add_argument('--data_dir', default='', type=str, help='data directory')
29 |     parser.add_argument('--save_dir', default='', type=str, help='save directory')
30 |     args = parser.parse_args()
31 |
32 |     # ----------------------------------------
33 |     # basic configuration
34 |     # ----------------------------------------
35 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
36 |
37 |     if not os.path.exists(args.save_dir):
38 |         os.makedirs(args.save_dir)
39 |
40 |     print('-' * 50)
41 |     print('DEVICE:', device)
42 |     print('MODEL PATH:', args.model_path)
43 |     print('DATA PATH:', args.data_dir)
44 |     print('-' * 50)
45 |
46 |     # ----------------------------------------
47 |     # test configuration
48 |     # ----------------------------------------
49 |     model_kwargs = dict(drop_path_rate=0.1)
50 |     model_kwargs['num_classes'] = 10
51 |     print(model_kwargs)
52 |     model = vit_tiny_imagenet(**model_kwargs)
53 |     # model = models.load_model(model_name=args.model_name, data_name=args.data_name, num_classes=args.num_classes)
54 |     # model.load_state_dict(torch.load(args.model_path)['model'], strict=False)
55 |     print(model.load_state_dict(torch.load(args.model_path), strict=False))
56 |
57 |     model.to(device)
58 |
59 |     test_loader = loaders.load_data_mask(data_dir=args.data_dir,
60 |                                          data_name=args.data_name,
61 |                                          data_type='test',
62 |                                          # mask_dir='/mnt/nfs/hjc/project/td/output/ideal/atts/imagenet10',
63 |                                          mask_dir='/vipa-nfs/homes/hjc/projects/TD/outputs/ideal/low_images/tvit/imagenet10_mask',
64 |                                          batch_size=512)
65 |
66 |     # evaluation metrics
67 |     # evaluates = {
68 |     #     'A': [ildv.MulticlassAccuracy(average='macro', num_classes=args.num_classes).to(device),
69 |     #           ildv.MulticlassF1Score(average='macro', num_classes=args.num_classes).to(device)],
70 |     #     'B': [ildv.MulticlassFalseNegativeRate(average='macro', num_classes=args.num_classes).to(device),
71 |     #           ildv.MulticlassFalseDiscoveryRate(average='macro', num_classes=args.num_classes).to(device)],
72 |     #     'C': [ildv.MulticlassBalancedAccuracy(args.num_classes).to(device)]
73 |     # }
74 |     evaluates = DefaultEvaluator(metrics=['acc', 'recall', 'precision', 'f1'])
75 |
76 |     # ----------------------------------------
77 |     # each epoch
78 |     # ----------------------------------------
79 |     since = time.time()
80 |
81 |     scores = test(test_loader, model, evaluates, device)
82 |     save_path = os.path.join(args.save_dir, '{}_{}.npy'.format(args.model_name, args.data_name))
83 |     np.save(save_path, scores)
84 |
85 |     print('-' * 50)
86 |
print('TIME CONSUMED', time.time() - since) 87 | 88 | 89 | # def test(test_loader, model, evaluates, device): 90 | # model.eval() 91 | 92 | # for i, samples in enumerate(tqdm(test_loader)): 93 | # # print(samples) 94 | # tmp = samples 95 | # inputs, labels, names = tmp 96 | # inputs = inputs.to(device) 97 | # labels = labels.to(device) 98 | 99 | # with torch.no_grad(): 100 | # outputs = model(inputs) 101 | 102 | # for value in evaluates.values(): 103 | # for evaluate in value: 104 | # evaluate.update(outputs, labels) 105 | 106 | # # calculate result 107 | # scores = {} 108 | # for key, value in zip(evaluates.keys(), evaluates.values()): 109 | # scores[key] = [] 110 | # for evaluate in value: 111 | # score = evaluate.compute().cpu().item() 112 | # scores[key].append(score) 113 | 114 | 115 | # print(scores) 116 | 117 | # return scores 118 | 119 | def test(test_loader, model, evaluates, device): 120 | model.eval() 121 | 122 | for i, samples in enumerate(tqdm(test_loader)): 123 | # print(len(samples)) 124 | inputs, labels, names, masks = samples 125 | inputs = inputs.to(device) 126 | labels = labels.to(device) 127 | masks = masks.to(device) 128 | # print(inputs.shape) 129 | # print(labels.shape) 130 | # print('====================================') 131 | print(masks.shape) 132 | masks = torch.max_pool2d(masks, kernel_size=16, stride=16, padding=0) 133 | # masks = torch.nn.AvgPool2d(kernel_size=16, stride=16, padding=0)(masks) 134 | # masks = torch.where(masks > 0.5, 1, 0) 135 | masks = torch.flatten(masks, start_dim=1, end_dim=3) 136 | ones = torch.ones((masks.shape[0], 1)).to(masks.device) 137 | masks = torch.cat((ones, masks), dim=1) 138 | masks = masks.unsqueeze(1).unsqueeze(2) 139 | # masks = 1 - masks 140 | print(masks.shape) 141 | print('masks', torch.min(masks), torch.max(masks), torch.mean(masks)) 142 | # print('====================================') 143 | grads = [] 144 | layers = [(11 - block) * 4 + 3 for block in range(12)] 145 | print(layers) 146 | for layer in layers: 147 | print('==>layer', layer) 148 | # np_path = '/mnt/nfs/hjc/project/td/output/ideal/grads/tvit/imagenet10/ac_grad_layer_{}.npy'.format(layer) 149 | np_path = '/mnt/nfs/hjc/project/td/output/ideal/grads/tvit/imagenet10/layer_{}.npy'.format(layer) 150 | grad = np.load(np_path) 151 | grads.append(grad) 152 | grads = torch.asarray(grads).to(inputs.dtype).cuda() 153 | grads = torch.transpose(grads, 0, 1) 154 | grads = torch.index_select(grads, dim=0, index=labels) 155 | print(grads.shape) 156 | print('grads', torch.min(grads), torch.max(grads), torch.mean(grads)) 157 | 158 | # print('====================================') 159 | 160 | with torch.cuda.amp.autocast(enabled=False): 161 | outputs = model(inputs, None, grads) 162 | 163 | evaluates.update(outputs, labels) 164 | print('ACC:', evaluates.metric_acc(evaluates.outputs, evaluates.targets)) 165 | 166 | return evaluates 167 | 168 | 169 | if __name__ == '__main__': 170 | main() 171 | -------------------------------------------------------------------------------- /doctor/train_1k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | from torch import nn 9 | from torch import optim 10 | from torch.utils.tensorboard import SummaryWriter 11 | import shutil 12 | import timm.scheduler as timm_scheduler 13 | 14 | import loaders 15 | import models 16 | from metrics import ildv 17 | from utils.misc import reduce_dict, 
update, MetricLogger, SmoothedValue 18 | from evaluator.default import DefaultEvaluator 19 | from criterions import cross_entropy 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser(description='') 23 | parser.add_argument('--model_name', default='', type=str, help='model name') 24 | parser.add_argument('--data_name', default='', type=str, help='data name') 25 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 26 | parser.add_argument('--num_epochs', default=200, type=int, help='num epochs') 27 | parser.add_argument('--model_dir', default='', type=str, help='model dir') 28 | parser.add_argument('--train_data_dir', default='', type=str, help='train_data_dir') 29 | parser.add_argument('--test_data_dir', default='', type=str, help='test_data_dir') 30 | parser.add_argument('--log_dir', default='', type=str, help='log dir') 31 | args = parser.parse_args() 32 | 33 | # ---------------------------------------- 34 | # basic configuration 35 | # ---------------------------------------- 36 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | 38 | train_dir = args.train_data_dir 39 | test_dir = args.test_data_dir 40 | 41 | if not os.path.exists(args.model_dir): 42 | os.makedirs(args.model_dir) 43 | if os.path.exists(args.log_dir): 44 | shutil.rmtree(args.log_dir) 45 | 46 | print('-' * 50) 47 | print('TRAIN ON:', device) 48 | print('TRAIN DATA DIR:', args.train_data_dir) 49 | print('TEST DATA DIR:', args.test_data_dir) 50 | print('MODEL DIR:', args.model_dir) 51 | print('LOG DIR:', args.log_dir) 52 | print('-' * 50) 53 | 54 | # ---------------------------------------- 55 | # trainer configuration 56 | # ---------------------------------------- 57 | model = models.load_model(args.model_name, args.data_name, num_classes=args.num_classes) 58 | model.to(device) 59 | print('MODEL LOAD DONE') 60 | train_loader = loaders.load_data(train_dir, args.data_name, data_type='train') 61 | print('TRAIN LOAD DONE') 62 | test_loader = loaders.load_data(test_dir, args.data_name, data_type='test') 63 | print('TEST LOAD DONE') 64 | 65 | # criterion = nn.CrossEntropyLoss() 66 | criterion = cross_entropy.CrossEntropy(losses=['labels'], weight_dict={'loss_ce': 1}) 67 | print(criterion) 68 | 69 | batch_size = 256 70 | learn_rate = 0.0005 * (batch_size / 512) 71 | optimizer = optim.AdamW(params=model.parameters(), lr=learn_rate, weight_decay=0.05, eps=1e-8) 72 | num_steps = int(args.num_epochs * len(train_loader)) 73 | warmup_epochs = 5 74 | warmup_steps = 0 75 | warmup_lr = 1e-06 76 | scheduler = timm_scheduler.CosineLRScheduler( 77 | optimizer, 78 | t_initial=(num_steps - warmup_steps), 79 | lr_min=1e-05, 80 | warmup_lr_init=warmup_lr, 81 | warmup_t=warmup_steps, 82 | cycle_limit=1, 83 | t_in_epochs=False, 84 | ) 85 | 86 | # optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-2) 87 | # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.num_epochs) 88 | 89 | writer = SummaryWriter(args.log_dir) 90 | 91 | # ---------------------------------------- 92 | # each epoch 93 | # ---------------------------------------- 94 | since = time.time() 95 | 96 | best_acc = None 97 | best_epoch = None 98 | 99 | for epoch in tqdm(range(args.num_epochs)): 100 | acc1, loss = train(train_loader, model, criterion, optimizer, device, args) 101 | writer.add_scalar(tag='training acc1', scalar_value=acc1, global_step=epoch) 102 | writer.add_scalar(tag='training loss', scalar_value=loss, global_step=epoch) 103 | acc1, loss = 
test(test_loader, model, criterion, device, args)
104 |         writer.add_scalar(tag='test acc1', scalar_value=acc1, global_step=epoch)
105 |         writer.add_scalar(tag='test loss', scalar_value=loss, global_step=epoch)
106 |
107 |         # ----------------------------------------
108 |         # save best model
109 |         # ----------------------------------------
110 |         if best_acc is None or best_acc < acc1:
111 |             best_acc = acc1
112 |             best_epoch = epoch
113 |             torch.save(model.state_dict(), os.path.join(args.model_dir, 'tvit_imagenet10_baseline_test.pth'))
114 |         print(best_acc)
115 |
116 |         scheduler.step(epoch)
117 |
118 |     print('COMPLETE !!!')
119 |     print('BEST ACC', best_acc)
120 |     print('BEST EPOCH', best_epoch)
121 |     print('TIME CONSUMED', time.time() - since)
122 |
123 |
124 | def train(train_loader, model, criterion, optimizer, device, args):
125 |
126 |     model.train()
127 |     criterion.train()
128 |
129 |     evaluator = DefaultEvaluator(metrics=['acc', 'recall', 'precision', 'f1'])
130 |
131 |     print('LEN OF TRAIN_LOADER')
132 |     print(len(train_loader))
133 |     for samples in tqdm(enumerate(train_loader)):
134 |         _, tmp = samples
135 |         inputs, labels, names = tmp
136 |         inputs = inputs.to(device)
137 |         labels = labels.to(device)
138 |         outputs = model(inputs)
139 |
140 |         # loss = criterion(outputs, labels)
141 |         loss_dict = criterion(outputs, labels)
142 |         print(loss_dict)
143 |         weight_dict = criterion.weight_dict
144 |         loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
145 |
146 |         evaluator.update(outputs, labels)
147 |
148 |         optimizer.zero_grad()  # 1
149 |         # if scaler:
150 |         #     scaler.scale(loss).backward()
151 |         #     if max_norm > 0:
152 |         #         scaler.unscale_(optimizer)
153 |         #         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
154 |         #     scaler.step(optimizer)
155 |         #     scaler.update()
156 |         loss.backward()  # 2
157 |         optimizer.step()  # 3
158 |
159 |     acc = evaluator.metric_acc(evaluator.outputs, evaluator.targets)
160 |
161 |     return acc, loss.item()
162 |
163 |
164 | def test(test_loader, model, criterion, device, args):
165 |
166 |     model.eval()
167 |
168 |     evaluator = DefaultEvaluator(metrics=['acc', 'recall', 'precision', 'f1'])
169 |
170 |     for samples in enumerate(test_loader):
171 |         _, tmp = samples
172 |         inputs, labels, names = tmp
173 |         inputs = inputs.to(device)
174 |         labels = labels.to(device)
175 |
176 |         with torch.no_grad():
177 |             outputs = model(inputs)
178 |             loss_dict = criterion(outputs, labels)  # the criterion returns a loss dict, as in train()
179 |             loss = sum(loss_dict[k] * criterion.weight_dict[k] for k in loss_dict if k in criterion.weight_dict)
180 |         evaluator.update(outputs, labels)
181 |
182 |     acc = evaluator.metric_acc(evaluator.outputs, evaluator.targets)
183 |
184 |     return acc, loss.item()
185 |
186 |
187 | if __name__ == '__main__':
188 |     main()
189 |
--------------------------------------------------------------------------------
/loaders/image_loader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from torch.utils.data.distributed import DistributedSampler
3 | from torchvision import transforms
4 | from torchvision import datasets
5 | # from torchvision.datasets import ImageFolder
6 | from loaders.image_dataset import ImageDataset
7 | from loaders.image_dataset import ImageMaskDataset
8 | from loaders.image_masks_transforms import Normalize, ToTensor, Resize, RandomHorizontalFlip, Compose
9 | from loaders.data_enhance import create_transform, create_transform_mask
10 |
11 | def _build_timm_aug_kwargs(image_size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
12 |     train_aug_kwargs = dict(input_size=image_size,
is_training=True, use_prefetcher=False, no_aug=False, 13 | scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), hflip=0.5, vflip=0., color_jitter=0.4, 14 | auto_augment='rand-m9-mstd0.5-inc1', interpolation='random', mean=mean, std=std, 15 | re_prob=0.25, re_mode='pixel', re_count=1, re_num_splits=0, separate=False) 16 | 17 | eval_aug_kwargs = dict(input_size=image_size, is_training=False, use_prefetcher=False, no_aug=False, crop_pct=0.875, 18 | interpolation='bilinear', mean=mean, std=std) 19 | 20 | return { 21 | 'train_aug_kwargs': train_aug_kwargs, 22 | 'eval_aug_kwargs': eval_aug_kwargs 23 | } 24 | 25 | cifar10_train_transform = transforms.Compose([ 26 | transforms.RandomCrop(32, padding=4), 27 | transforms.Resize((32, 32)), 28 | transforms.RandomHorizontalFlip(), 29 | transforms.ToTensor(), 30 | # transforms.Normalize((0.4913, 0.4821, 0.4465), 31 | # (0.2470, 0.2434, 0.2615)), 32 | transforms.Normalize((0.4914, 0.4822, 0.4465), 33 | (0.2023, 0.1994, 0.2010)), 34 | ]) 35 | 36 | cifar10_test_transform = transforms.Compose([ 37 | transforms.Resize((32, 32)), 38 | transforms.ToTensor(), 39 | # transforms.Normalize((0.4913, 0.4821, 0.4465), 40 | # (0.2470, 0.2434, 0.2615)), 41 | transforms.Normalize((0.4914, 0.4822, 0.4465), 42 | (0.2023, 0.1994, 0.2010)), 43 | ]) 44 | 45 | cifar100_train_transform = transforms.Compose([ 46 | transforms.RandomCrop(32, padding=4), 47 | transforms.Resize((32, 32)), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.4913, 0.4821, 0.4465), 51 | (0.2470, 0.2434, 0.2615)), 52 | ]) 53 | 54 | cifar100_test_transform = transforms.Compose([ 55 | transforms.Resize((32, 32)), 56 | transforms.ToTensor(), 57 | transforms.Normalize((0.4913, 0.4821, 0.4465), 58 | (0.2470, 0.2434, 0.2615)), 59 | ]) 60 | 61 | imagenet_train_transform = transforms.Compose([ 62 | transforms.Resize((224, 224)), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 66 | ]) 67 | 68 | imagenet_test_transform = transforms.Compose([ 69 | transforms.Resize((224, 224)), 70 | transforms.ToTensor(), 71 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 72 | ]) 73 | 74 | imagenet_mask_train_transform = Compose([ 75 | Resize((224, 224)), 76 | RandomHorizontalFlip(), 77 | ToTensor(), 78 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 79 | ]) 80 | 81 | imagenet_mask_test_transform = Compose([ 82 | Resize((224, 224)), 83 | ToTensor(), 84 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 85 | ]) 86 | 87 | mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 88 | aug_kwargs = _build_timm_aug_kwargs(224, mean, std) 89 | imagenet_train_transform_enhance = create_transform(**aug_kwargs['train_aug_kwargs']) 90 | imagenet_test_transform_enhance = create_transform(**aug_kwargs['eval_aug_kwargs']) 91 | imagenet_mask_train_transform_enhance = create_transform_mask(**aug_kwargs['train_aug_kwargs']) 92 | imagenet_mask_test_transform_enhance = create_transform(**aug_kwargs['eval_aug_kwargs']) 93 | 94 | 95 | def _get_set(data_path, transform): 96 | return ImageDataset(image_dir=data_path, 97 | transform=transform) 98 | 99 | def _get_set_mask(data_path, transform, mask_path): 100 | return ImageMaskDataset(image_dir=data_path, 101 | mask_dir=mask_path, 102 | transform=transform) 103 | 104 | 105 | def load_images(data_dir, data_name, data_type=None, batch_size=512): 106 | assert data_name in ['cifar10', 'cifar100', 'imagenet', 'imagenet50', 'imagenet10', 
'imagenet300', 'imagenet100'] 107 | assert data_type is None or data_type in ['train', 'test'] 108 | 109 | data_transform = None 110 | if data_name == 'cifar10' and data_type == 'train': 111 | data_transform = cifar10_train_transform 112 | elif data_name == 'cifar10' and data_type == 'test': 113 | data_transform = cifar10_test_transform 114 | elif data_name == 'cifar100' and data_type == 'train': 115 | data_transform = cifar100_train_transform 116 | elif data_name == 'cifar100' and data_type == 'test': 117 | data_transform = cifar100_test_transform 118 | elif ('imagenet' in data_name) and data_type == 'train': 119 | data_transform = imagenet_train_transform 120 | elif ('imagenet' in data_name) and data_type == 'test': 121 | data_transform = imagenet_test_transform 122 | assert data_transform is not None 123 | 124 | print(data_transform) 125 | 126 | data_set = _get_set(data_dir, transform=data_transform) 127 | # data_set = datasets.ImageFolder(root=data_dir, 128 | # transform=imagenet_test_transform) 129 | data_loader = DataLoader(dataset=data_set, 130 | batch_size=batch_size, 131 | num_workers=8, 132 | shuffle=True) 133 | 134 | return data_loader 135 | 136 | 137 | def load_images_masks(data_dir, data_name, mask_dir, data_type=None, batch_size=512): 138 | assert data_name in ['cifar10', 'cifar100', 'imagenet', 'imagenet50', 'imagenet10', 'imagenet300', 'imagenet100'] 139 | assert data_type is None or data_type in ['train', 'test'] 140 | 141 | data_transform = None 142 | if data_name == 'cifar10' and data_type == 'train': 143 | data_transform = cifar10_train_transform 144 | elif data_name == 'cifar10' and data_type == 'test': 145 | data_transform = cifar10_test_transform 146 | elif ('imagenet' in data_name) and data_type == 'train': 147 | data_transform = imagenet_mask_train_transform 148 | elif ('imagenet' in data_name) and data_type == 'test': 149 | data_transform = imagenet_mask_test_transform 150 | assert data_transform is not None 151 | 152 | data_set = _get_set_mask(data_dir, transform=data_transform, mask_path=mask_dir) 153 | # data_set = datasets.ImageFolder(root=data_dir, 154 | # transform=imagenet_test_transform) 155 | data_loader = DataLoader(dataset=data_set, 156 | batch_size=batch_size, 157 | num_workers=8, 158 | shuffle=True) 159 | 160 | return data_loader 161 | -------------------------------------------------------------------------------- /metrics/ildv/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from metrics import Metric 3 | 4 | from torcheval.metrics import MulticlassAccuracy, MulticlassF1Score 5 | from torchmetrics import Accuracy, Recall, Specificity 6 | from torchmetrics.classification.stat_scores import MulticlassStatScores, Tensor 7 | 8 | 9 | def my_safe_divide(num: Tensor, denom: Tensor) -> Tensor: 10 | denom[denom == 0.0] = 1 11 | num = num if num.is_floating_point() else num.float() 12 | denom = denom if denom.is_floating_point() else denom.float() 13 | return num / denom 14 | 15 | 16 | class MulticlassFalsePositiveRate(Metric): 17 | def __init__(self, num_classes): 18 | super().__init__() 19 | self.false_positives = 1 - Specificity(task='multiclass', average='macro', num_classes=num_classes) 20 | 21 | def update(self, outputs, targets): 22 | self.false_positives.update(outputs, targets) 23 | 24 | def compute(self): 25 | FPR = self.false_positives.compute() 26 | return FPR 27 | 28 | def to(self, device): 29 | self.false_positives.to(device) 30 | return self 31 | 32 | 33 | class 
MulticlassFNR(MulticlassStatScores):
34 |     def compute(self) -> Tensor:
35 |         """Computes the per-class false negative rate from inputs passed in to ``update``."""
36 |         tp, fp, tn, fn = self._final_state()
37 |         return my_safe_divide(fn, fn + tp)
38 |
39 |
40 | class MulticlassFalseNegativeRate(Metric):
41 |     def __init__(self, average, num_classes):
42 |         super().__init__()
43 |         self.avg = average
44 |         # self.false_negatives = 1 - Recall(task='multiclass', average='micro', num_classes=num_classes)
45 |         self.false_negatives = MulticlassFNR(num_classes=num_classes, average=average)
46 |
47 |     def update(self, outputs, targets):
48 |         self.false_negatives.update(outputs, targets)
49 |
50 |     def compute(self):
51 |         FNR = self.false_negatives.compute()
52 |         if self.avg == 'macro':
53 |             tmp = 0.0
54 |             for i in FNR:
55 |                 tmp += i
56 |             tmp /= len(FNR)  # average over the actual number of classes (was hard-coded to 1000)
57 |             return tmp
58 |         else:
59 |             return FNR
60 |
61 |     def to(self, device):
62 |         self.false_negatives.to(device)
63 |         return self
64 |
65 |
66 | class MulticlassFOR(MulticlassStatScores):
67 |     def compute(self) -> Tensor:
68 |         """Computes the per-class false omission rate from inputs passed in to ``update``."""
69 |         tp, fp, tn, fn = self._final_state()
70 |         return my_safe_divide(fn, fn + tn)
71 |
72 |
73 | class MulticlassFalseOmissionRate(Metric):
74 |     def __init__(self, num_classes):
75 |         super().__init__()
76 |         self.false_omission = MulticlassFOR(num_classes=num_classes, average='macro')
77 |
78 |     def update(self, outputs, targets):
79 |         self.false_omission.update(outputs, targets)
80 |
81 |     def compute(self):
82 |         FOR = self.false_omission.compute()
83 |         return FOR
84 |
85 |     def to(self, device):
86 |         self.false_omission.to(device)
87 |         return self
88 |
89 |
90 | class MulticlassFDR(MulticlassStatScores):
91 |     def compute(self) -> Tensor:
92 |         """Computes the per-class false discovery rate from inputs passed in to ``update``."""
93 |         tp, fp, tn, fn = self._final_state()
94 |         return my_safe_divide(fp, fp + tp)
95 |
96 |
97 | class MulticlassFalseDiscoveryRate(Metric):
98 |     def __init__(self, average, num_classes):
99 |         super().__init__()
100 |         self.avg = average
101 |         self.false_discovery = MulticlassFDR(num_classes=num_classes, average=average)
102 |
103 |     def update(self, outputs, targets):
104 |         self.false_discovery.update(outputs, targets)
105 |
106 |     def compute(self):
107 |         FDR = self.false_discovery.compute()
108 |         if self.avg == 'macro':
109 |             tmp = 0.0
110 |             for i in FDR:
111 |                 tmp += i
112 |             tmp /= len(FDR)  # average over the actual number of classes (was hard-coded to 1000)
113 |             return tmp
114 |         else:
115 |             return FDR
116 |
117 |     def to(self, device):
118 |         self.false_discovery.to(device)
119 |         return self
120 |
121 |
122 | # class MulticlassFalsePositiveRate(Metric):
123 | #     def __init__(self):
124 | #         super().__init__()
125 | #         self.false_positives = torch.tensor(0)
126 | #         self.true_negatives = torch.tensor(0)
127 |
128 | #     def update(self, outputs, targets):
129 | #         targets = torch.reshape(targets, outputs.shape)
130 | #         false_positives = torch.logical_and(outputs == 1, targets == 0).sum()
131 | #         true_negatives = torch.logical_and(outputs == 0, targets == 0).sum()
132 | #         self.false_positives += false_positives
133 | #         self.true_negatives += true_negatives
134 |
135 | #     def compute(self):
136 | #         return self.false_positives / (self.false_positives + self.true_negatives)
137 |
138 |
139 | # class MulticlassFalseNegativeRate(Metric):
140 | #     def __init__(self):
141 | #         super().__init__()
142 | #         self.false_negatives = torch.tensor(0)
143 | #         self.true_positives = torch.tensor(0)
144 |
145 | #     def update(self, outputs, targets):
146 | #         targets =
torch.reshape(targets, outputs.shape) 147 | # false_negatives = torch.logical_and(outputs == 0, targets == 1).sum() 148 | # true_positives = torch.logical_and(outputs == 1, targets == 1).sum() 149 | # self.false_negatives += false_negatives 150 | # self.true_positives += true_positives 151 | 152 | # def compute(self): 153 | # return self.false_negatives / (self.false_negatives + self.true_positives) 154 | 155 | 156 | class MulticlassBalancedAccuracy(Metric): 157 | def __init__(self, num_classes): 158 | super().__init__() 159 | self.sensitivity = Recall(task='multiclass', average='macro', num_classes=num_classes) 160 | self.specificity = Specificity(task='multiclass', average='macro', num_classes=num_classes) 161 | 162 | def update(self, outputs, targets): 163 | self.sensitivity.update(outputs, targets) 164 | self.specificity.update(outputs, targets) 165 | 166 | def compute(self): 167 | ba = (self.sensitivity.compute() + self.specificity.compute()) / 2 168 | return ba 169 | 170 | def to(self, device): 171 | self.sensitivity.to(device) 172 | self.specificity.to(device) 173 | return self 174 | 175 | 176 | class MulticlassOptimizedPrecision(Metric): 177 | def __init__(self, num_classes): 178 | super().__init__() 179 | self.accuracy = Accuracy(task='multiclass', average='macro', num_classes=num_classes) 180 | self.sensitivity = Recall(task='multiclass', average='macro', num_classes=num_classes) 181 | self.specificity = Specificity(task='multiclass', average='macro', num_classes=num_classes) 182 | 183 | def update(self, outputs, targets): 184 | self.accuracy.update(outputs, targets) 185 | self.sensitivity.update(outputs, targets) 186 | self.specificity.update(outputs, targets) 187 | 188 | def compute(self): 189 | sensitivity = self.sensitivity.compute() 190 | specificity = self.specificity.compute() 191 | op = self.accuracy.compute() - torch.abs(sensitivity - specificity) / (sensitivity + specificity) 192 | return op 193 | 194 | def to(self, device): 195 | self.accuracy.to(device) 196 | self.sensitivity.to(device) 197 | self.specificity.to(device) 198 | return self 199 | -------------------------------------------------------------------------------- /doctor/train_lrp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | from tqdm import tqdm 5 | import numpy as np 6 | import pickle 7 | 8 | import torch 9 | from torch import nn 10 | from torch import optim 11 | from torch.utils.tensorboard import SummaryWriter 12 | from constraint import Constraint 13 | import shutil 14 | from metrics import ildv 15 | 16 | import timm.scheduler as timm_scheduler 17 | 18 | import loaders 19 | import models 20 | from PIL import Image 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser(description='') 25 | parser.add_argument('--model_name', default='', type=str, help='model name') 26 | parser.add_argument('--data_name', default='', type=str, help='data name') 27 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 28 | parser.add_argument('--num_epochs', default=200, type=int, help='num epochs') 29 | parser.add_argument('--model_dir', default='', type=str, help='model dir') 30 | parser.add_argument('--origin_dir', default='', type=str, help='origin dir') 31 | parser.add_argument('--train_data_dir', default='', type=str, help='train_data_dir') 32 | parser.add_argument('--test_data_dir', default='', type=str, help='test_data_dir') 33 | parser.add_argument('--log_dir', default='', type=str, 
help='log dir') 34 | parser.add_argument('--grad_dir', default='', type=str, help='grad dir') 35 | parser.add_argument('--lrp_dir', default='', type=str, help='lrp dir') 36 | parser.add_argument('--mask_dir', default='', type=str, help='mask dir') 37 | args = parser.parse_args() 38 | 39 | # ---------------------------------------- 40 | # basic configuration 41 | # ---------------------------------------- 42 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 43 | 44 | print('DEVICE: ', device) 45 | 46 | train_dir = args.train_data_dir 47 | test_dir = args.test_data_dir 48 | 49 | 50 | if not os.path.exists(args.model_dir): 51 | os.makedirs(args.model_dir) 52 | if os.path.exists(args.log_dir): 53 | shutil.rmtree(args.log_dir) 54 | 55 | print('-' * 50) 56 | print('TRAIN ON:', device) 57 | print('TRAIN DATA DIR:', args.train_data_dir) 58 | print('TEST DATA DIR:', args.test_data_dir) 59 | print('MODEL DIR:', args.model_dir) 60 | print('LOG DIR:', args.log_dir) 61 | print('ORIGIN DIR:', args.origin_dir) 62 | print('-' * 50) 63 | 64 | # ---------------------------------------- 65 | # trainer configuration 66 | # ---------------------------------------- 67 | model = models.load_model(args.model_name, args.data_name, num_classes=args.num_classes) 68 | model.load_state_dict(torch.load(args.origin_dir), strict=False) 69 | model.to(device) 70 | print('MODEL LOAD DONE') 71 | train_mask_dir = args.mask_dir 72 | lrp_dir = args.lrp_dir 73 | grad_dir = args.grad_dir 74 | train_loader = loaders.load_data(train_dir, args.data_name, data_type='train') 75 | print('TRAIN LOAD DONE') 76 | test_loader = loaders.load_data(test_dir, args.data_name, data_type='test') 77 | print('TEST LOAD DONE') 78 | modules = models.load_modules(model=model) 79 | print('MODULE LODE DONE') 80 | 81 | criterion = nn.CrossEntropyLoss() 82 | 83 | # optimizer = optim.AdamW(params=model.parameters(), lr=0.0005, weight_decay=0.05, eps=1e-8) 84 | # num_steps = int(args.num_epochs * len(train_loader)) 85 | # warmup_steps = 0 86 | # warmup_lr = 1e-06 87 | # scheduler = timm_scheduler.CosineLRScheduler( 88 | # optimizer, 89 | # t_initial=(num_steps - warmup_steps), 90 | # lr_min=1e-05, 91 | # warmup_lr_init=warmup_lr, 92 | # warmup_t=warmup_steps, 93 | # cycle_limit=1, 94 | # t_in_epochs=False, 95 | # ) 96 | 97 | optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) 98 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.num_epochs) 99 | 100 | grad_ratio = 1 101 | att_ratio = 1 102 | lrp_ratio = 0.1 103 | constraint = Constraint(model_name = args.model_name, modules=modules, grad_path=grad_dir, lrp_path=lrp_dir, alpha=grad_ratio, beta=att_ratio, gamma=lrp_ratio) 104 | 105 | writer = SummaryWriter(args.log_dir) 106 | 107 | # ---------------------------------------- 108 | # each epoch 109 | # ---------------------------------------- 110 | since = time.time() 111 | 112 | best_acc = None 113 | best_epoch = None 114 | 115 | for epoch in tqdm(range(args.num_epochs)): 116 | acc1, loss, lrp_loss, cr_loss = train(train_loader, model, criterion, optimizer, device, args, constraint) 117 | writer.add_scalar(tag='training acc1', scalar_value=acc1, global_step=epoch) 118 | writer.add_scalar(tag='training loss', scalar_value=loss, global_step=epoch) 119 | writer.add_scalar(tag='training lrp_loss', scalar_value=lrp_loss, global_step=epoch) 120 | writer.add_scalar(tag='training cr_loss', scalar_value=cr_loss, global_step=epoch) 121 | acc1, loss, lrp_loss, cr_loss = 
test(test_loader, model, criterion, device, args, constraint) 122 | writer.add_scalar(tag='test acc1', scalar_value=acc1, global_step=epoch) 123 | writer.add_scalar(tag='test loss', scalar_value=loss, global_step=epoch) 124 | writer.add_scalar(tag='test lrp_loss', scalar_value=lrp_loss, global_step=epoch) 125 | writer.add_scalar(tag='test cr_loss', scalar_value=cr_loss, global_step=epoch) 126 | 127 | # ---------------------------------------- 128 | # save best model 129 | # ---------------------------------------- 130 | if best_acc is None or best_acc < acc1: 131 | best_acc = acc1 132 | best_epoch = epoch 133 | torch.save(model.state_dict(), os.path.join(args.model_dir, 'evit_imagenet10_lrp.pth')) 134 | print(best_acc) 135 | 136 | scheduler.step(epoch) 137 | 138 | print('COMPLETE !!!') 139 | print('BEST ACC', best_acc) 140 | print('BEST EPOCH', best_epoch) 141 | print('TIME CONSUMED', time.time() - since) 142 | 143 | 144 | def train(train_loader, model, criterion, optimizer, device, args, constraint): 145 | 146 | model.train() 147 | 148 | acc1 = ildv.MulticlassAccuracy(average='macro', num_classes=args.num_classes).to(device) 149 | 150 | print('LEN OF TRAIN_LOADER') 151 | print(len(train_loader)) 152 | 153 | for samples in tqdm(enumerate(train_loader)): 154 | _, tmp = samples 155 | inputs, labels, names = tmp 156 | inputs = inputs.to(device) 157 | labels = labels.to(device) 158 | outputs = model(inputs) 159 | cr_loss = criterion(outputs, labels) 160 | lrp_loss = constraint.loss_lrp(labels) 161 | # print('lrp_loss: ', lrp_loss) 162 | # print('cr_loss: ', cr_loss) 163 | loss = cr_loss + lrp_loss 164 | 165 | acc1.update(outputs, labels) 166 | optimizer.zero_grad() # 1 167 | loss.backward() # 2 168 | optimizer.step() # 3 169 | 170 | return acc1.compute().item(), loss.item(), lrp_loss.item(), cr_loss.item() 171 | 172 | 173 | def test(test_loader, model, criterion, device, args, constraint): 174 | 175 | model.eval() 176 | 177 | acc1 = ildv.MulticlassAccuracy(average='macro', num_classes=args.num_classes).to(device) 178 | 179 | for samples in enumerate(test_loader): 180 | _, tmp = samples 181 | inputs, labels, names = tmp 182 | inputs = inputs.to(device) 183 | labels = labels.to(device) 184 | 185 | outputs = model(inputs) 186 | cr_loss = criterion(outputs, labels) 187 | lrp_loss = constraint.loss_lrp(labels) 188 | loss = cr_loss + lrp_loss 189 | 190 | acc1.update(outputs, labels) 191 | 192 | return acc1.compute().item(), loss.item(), lrp_loss.item(), cr_loss.item() 193 | 194 | 195 | if __name__ == '__main__': 196 | main() -------------------------------------------------------------------------------- /utils/cam.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | class ActivationsAndGradients: 6 | """ Class for extracting activations and 7 | registering gradients from targeted intermediate layers """ 8 | 9 | def __init__(self, model, target_layers, reshape_transform): 10 | self.model = model 11 | self.gradients = [] 12 | self.activations = [] 13 | self.reshape_transform = reshape_transform 14 | self.handles = [] 15 | for target_layer in target_layers: 16 | self.handles.append( 17 | target_layer.register_forward_hook( 18 | self.save_activation)) 19 | # Backward compatibility with older pytorch versions: 20 | if hasattr(target_layer, 'register_full_backward_hook'): 21 | self.handles.append( 22 | target_layer.register_full_backward_hook( 23 | self.save_gradient)) 24 | else: 25 | self.handles.append( 26 | 
target_layer.register_backward_hook( 27 | self.save_gradient)) 28 | 29 | def save_activation(self, module, input, output): 30 | activation = output 31 | if self.reshape_transform is not None: 32 | activation = self.reshape_transform(activation) 33 | self.activations.append(activation.cpu().detach()) 34 | 35 | def save_gradient(self, module, grad_input, grad_output): 36 | # Gradients are computed in reverse order 37 | grad = grad_output[0] 38 | if self.reshape_transform is not None: 39 | grad = self.reshape_transform(grad) 40 | self.gradients = [grad.cpu().detach()] + self.gradients 41 | 42 | def __call__(self, x): 43 | self.gradients = [] 44 | self.activations = [] 45 | return self.model(x) 46 | 47 | def release(self): 48 | for handle in self.handles: 49 | handle.remove() 50 | 51 | 52 | class GradCAM: 53 | def __init__(self, 54 | model, 55 | target_layers, 56 | reshape_transform=None, 57 | use_cuda=False): 58 | self.model = model.eval() 59 | self.target_layers = target_layers 60 | self.reshape_transform = reshape_transform 61 | self.cuda = use_cuda 62 | if self.cuda: 63 | self.model = model.cuda() 64 | self.activations_and_grads = ActivationsAndGradients( 65 | self.model, target_layers, reshape_transform) 66 | print('self.activations_and_grads: ', self.activations_and_grads) 67 | 68 | """ Get a vector of weights for every channel in the target layer. 69 | Methods that return weights channels, 70 | will typically need to only implement this function. """ 71 | 72 | @staticmethod 73 | def get_cam_weights(grads): 74 | return np.mean(grads, axis=(2, 3), keepdims=True) 75 | 76 | @staticmethod 77 | def get_loss(output, target_category): 78 | loss = 0 79 | for i in range(len(target_category)): 80 | loss = loss + output[i, target_category[i]] 81 | return loss 82 | 83 | def get_cam_image(self, activations, grads): 84 | weights = self.get_cam_weights(grads) 85 | weighted_activations = weights * activations 86 | cam = weighted_activations.sum(axis=1) 87 | 88 | return cam 89 | 90 | @staticmethod 91 | def get_target_width_height(input_tensor): 92 | width, height = input_tensor.size(-1), input_tensor.size(-2) 93 | return width, height 94 | 95 | def compute_cam_per_layer(self, input_tensor): 96 | activations_list = [a.cpu().data.numpy() 97 | for a in self.activations_and_grads.activations] 98 | grads_list = [g.cpu().data.numpy() 99 | for g in self.activations_and_grads.gradients] 100 | target_size = self.get_target_width_height(input_tensor) 101 | 102 | cam_per_target_layer = [] 103 | # Loop over the saliency image from every layer 104 | 105 | for layer_activations, layer_grads in zip(activations_list, grads_list): 106 | cam = self.get_cam_image(layer_activations, layer_grads) 107 | cam[cam < 0] = 0 # works like mute the min-max scale in the function of scale_cam_image 108 | scaled = self.scale_cam_image(cam, target_size) 109 | cam_per_target_layer.append(scaled[:, None, :]) 110 | 111 | return cam_per_target_layer 112 | 113 | def aggregate_multi_layers(self, cam_per_target_layer): 114 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) 115 | cam_per_target_layer = np.maximum(cam_per_target_layer, 0) 116 | result = np.mean(cam_per_target_layer, axis=1) 117 | return self.scale_cam_image(result) 118 | 119 | @staticmethod 120 | def scale_cam_image(cam, target_size=None): 121 | result = [] 122 | for img in cam: 123 | img = img - np.min(img) 124 | img = img / (1e-7 + np.max(img)) 125 | if target_size is not None: 126 | img = cv2.resize(img, target_size) 127 | result.append(img) 128 | result = 
np.float32(result)
129 |
130 |         return result
131 |
132 |     def __call__(self, input_tensor, target_category=None):
133 |
134 |         if self.cuda:
135 |             input_tensor = input_tensor.cuda()
136 |
137 |         # forward pass: get the network's output logits (before softmax)
138 |         output = self.activations_and_grads(input_tensor)
139 |         if isinstance(target_category, int):
140 |             target_category = [target_category] * input_tensor.size(0)
141 |
142 |         if target_category is None:
143 |             target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
144 |             print(f"category id: {target_category}")
145 |         else:
146 |             assert (len(target_category) == input_tensor.size(0))
147 |
148 |         self.model.zero_grad()
149 |         loss = self.get_loss(output, target_category)
150 |         loss.backward(retain_graph=True)
151 |
152 |         # In most of the saliency attribution papers, the saliency is
153 |         # computed with a single target layer.
154 |         # Commonly it is the last convolutional layer.
155 |         # Here we support passing a list with multiple target layers.
156 |         # It will compute the saliency image for every image,
157 |         # and then aggregate them (with a default mean aggregation).
158 |         # This gives you more flexibility in case you just want to
159 |         # use all conv layers for example, all Batchnorm layers,
160 |         # or something else.
161 |         cam_per_layer = self.compute_cam_per_layer(input_tensor)
162 |         return self.aggregate_multi_layers(cam_per_layer)
163 |
164 |     def __del__(self):
165 |         self.activations_and_grads.release()
166 |
167 |     def __enter__(self):
168 |         return self
169 |
170 |     def __exit__(self, exc_type, exc_value, exc_tb):
171 |         self.activations_and_grads.release()
172 |         if isinstance(exc_value, IndexError):
173 |             # Handle IndexError here...
174 |             print(
175 |                 f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
176 |             return True
177 |
178 |
179 | def show_cam_on_image(img: np.ndarray,
180 |                       mask: np.ndarray,
181 |                       use_rgb: bool = False,
182 |                       colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
183 |     """ This function overlays the cam mask on the image as a heatmap.
184 |     By default the heatmap is in BGR format.
185 |
186 |     :param img: The base image in RGB or BGR format.
187 |     :param mask: The cam mask.
188 |     :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
189 |     :param colormap: The OpenCV colormap to be used.
190 |     :returns: The input image with the cam overlay.
191 | """ 192 | 193 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) 194 | if use_rgb: 195 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 196 | heatmap = np.float32(heatmap) / 255 197 | 198 | if np.max(img) > 1: 199 | raise Exception( 200 | "The input image should np.float32 in the range [0, 1]") 201 | -------------------------------------------------------------------------------- /doctor/train_att.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | from torch import nn 9 | from torch import optim 10 | from torch.utils.tensorboard import SummaryWriter 11 | import shutil 12 | from metrics import ildv 13 | 14 | import timm.scheduler as timm_scheduler 15 | 16 | import loaders 17 | import models 18 | from constraint import Constraint 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser(description='') 23 | parser.add_argument('--model_name', default='', type=str, help='model name') 24 | parser.add_argument('--data_name', default='', type=str, help='data name') 25 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 26 | parser.add_argument('--num_epochs', default=200, type=int, help='num epochs') 27 | parser.add_argument('--batch_size', default=512, type=int, help='batch size') 28 | parser.add_argument('--model_dir', default='', type=str, help='model dir') 29 | parser.add_argument('--result_name', default='', type=str, help='result name') 30 | parser.add_argument('--origin_dir', default='', type=str, help='origin dir') 31 | parser.add_argument('--train_data_dir', default='', type=str, help='train_data_dir') 32 | parser.add_argument('--test_data_dir', default='', type=str, help='test_data_dir') 33 | parser.add_argument('--log_dir', default='', type=str, help='log dir') 34 | parser.add_argument('--grad_dir', default='', type=str, help='grad dir') 35 | parser.add_argument('--lrp_dir', default='', type=str, help='lrp dir') 36 | parser.add_argument('--mask_dir', default='', type=str, help='mask dir') 37 | args = parser.parse_args() 38 | 39 | # ---------------------------------------- 40 | # basic configuration 41 | # ---------------------------------------- 42 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 43 | 44 | print('DEVICE: ', device) 45 | 46 | train_dir = args.train_data_dir 47 | test_dir = args.test_data_dir 48 | 49 | 50 | if not os.path.exists(args.model_dir): 51 | os.makedirs(args.model_dir) 52 | if os.path.exists(args.log_dir): 53 | shutil.rmtree(args.log_dir) 54 | 55 | print('-' * 50) 56 | print('TRAIN ON:', device) 57 | print('TRAIN DATA DIR:', args.train_data_dir) 58 | print('TEST DATA DIR:', args.test_data_dir) 59 | print('MODEL DIR:', args.model_dir) 60 | print('LOG DIR:', args.log_dir) 61 | print('ORIGIN DIR:', args.origin_dir) 62 | print('RESULT NAME:', args.result_name + '.pth') 63 | print('-' * 50) 64 | 65 | # ---------------------------------------- 66 | # trainer configuration 67 | # ---------------------------------------- 68 | model = models.load_model(args.model_name, args.data_name, num_classes=args.num_classes) 69 | model.load_state_dict(torch.load(args.origin_dir), strict=False) 70 | # print(model.load_state_dict(torch.load(args.origin_dir, map_location=torch.device('cpu'))['model'], strict=False)) 71 | model.to(device) 72 | print('MODEL LOAD DONE') 73 | lrp_dir = args.lrp_dir 74 | grad_dir = args.grad_dir 75 | train_mask_dir = 
args.mask_dir 76 | # data_dir, data_name, mask_dir, data_type 77 | train_loader = loaders.load_data_mask(train_dir, args.data_name, train_mask_dir, data_type='train', batch_size=args.batch_size) 78 | print('TRAIN LOAD DONE') 79 | test_loader = loaders.load_data(test_dir, args.data_name, data_type='test', batch_size=args.batch_size) 80 | print('TEST LOAD DONE') 81 | modules = models.load_modules(model=model) 82 | print('MODULE LODE DONE') 83 | 84 | criterion = nn.CrossEntropyLoss() 85 | 86 | batch_size = args.batch_size 87 | learn_rate = 0.0005 * (batch_size / 512) 88 | optimizer = optim.AdamW(params=model.parameters(), lr=learn_rate, weight_decay=0.05, eps=1e-8) 89 | num_steps = int(args.num_epochs * len(train_loader)) 90 | warmup_epochs = 5 91 | warmup_steps = 0 92 | warmup_lr = 1e-06 93 | scheduler = timm_scheduler.CosineLRScheduler( 94 | optimizer, 95 | t_initial=(num_steps - warmup_steps), 96 | lr_min=1e-05, 97 | warmup_lr_init=warmup_lr, 98 | warmup_t=warmup_steps, 99 | cycle_limit=1, 100 | t_in_epochs=False, 101 | ) 102 | 103 | # optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) 104 | # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.num_epochs) 105 | 106 | grad_ratio = 1 107 | att_ratio = 10 108 | lrp_ratio = 1 109 | constraint = Constraint(model_name = args.model_name, modules=modules, grad_path=grad_dir, lrp_path=lrp_dir, alpha=grad_ratio, beta=att_ratio, gamma=lrp_ratio) 110 | 111 | writer = SummaryWriter(args.log_dir) 112 | 113 | # ---------------------------------------- 114 | # each epoch 115 | # ---------------------------------------- 116 | since = time.time() 117 | 118 | best_acc = None 119 | best_epoch = None 120 | 121 | acc1_first, loss_first = test(test_loader, model, criterion, device, args, constraint) 122 | print(acc1_first) 123 | 124 | for epoch in tqdm(range(args.num_epochs)): 125 | acc1, loss, att_loss, cr_loss = train(train_loader, model, criterion, optimizer, device, args, constraint) 126 | writer.add_scalar(tag='training acc1', scalar_value=acc1, global_step=epoch) 127 | writer.add_scalar(tag='training loss', scalar_value=loss, global_step=epoch) 128 | writer.add_scalar(tag='training att_loss', scalar_value=att_loss, global_step=epoch) 129 | writer.add_scalar(tag='training cr_loss', scalar_value=cr_loss, global_step=epoch) 130 | acc1, loss = test(test_loader, model, criterion, device, args, constraint) 131 | writer.add_scalar(tag='test acc1', scalar_value=acc1, global_step=epoch) 132 | writer.add_scalar(tag='test loss', scalar_value=loss, global_step=epoch) 133 | 134 | # ---------------------------------------- 135 | # save best model 136 | # ---------------------------------------- 137 | if best_acc is None or best_acc < acc1: 138 | best_acc = acc1 139 | best_epoch = epoch 140 | torch.save(model.state_dict(), os.path.join(args.model_dir, args.result_name + '.pth')) 141 | 142 | print(best_acc) 143 | 144 | scheduler.step(epoch) 145 | 146 | print('COMPLETE !!!') 147 | print('BEST ACC', best_acc) 148 | print('BEST EPOCH', best_epoch) 149 | print('TIME CONSUMED', time.time() - since) 150 | 151 | 152 | def train(train_loader, model, criterion, optimizer, device, args, constraint): 153 | 154 | model.train() 155 | 156 | acc1 = ildv.MulticlassAccuracy(average='macro', num_classes=args.num_classes).to(device) 157 | 158 | print('LEN OF TRAIN_LOADER') 159 | print(len(train_loader)) 160 | 161 | for samples in tqdm(enumerate(train_loader)): 162 | _, tmp = samples 163 | inputs, labels, names, masks = 
152 | def train(train_loader, model, criterion, optimizer, device, args, constraint): 153 | 154 | model.train() 155 | 156 | acc1 = ildv.MulticlassAccuracy(average='macro', num_classes=args.num_classes).to(device) 157 | 158 | print('LEN OF TRAIN_LOADER') 159 | print(len(train_loader)) 160 | 161 | for samples in tqdm(enumerate(train_loader), total=len(train_loader)): 162 | _, tmp = samples 163 | inputs, labels, names, masks = tmp 164 | inputs = inputs.to(device) 165 | labels = labels.to(device) 166 | patch_size = 16 167 | if args.model_name == 'pvt': 168 | patch_size = 32 169 | elif args.model_name == 'eva': 170 | patch_size = 14 171 | 172 | outputs = model(inputs) 173 | cr_loss = criterion(outputs, labels) 174 | att_loss = constraint.loss_attention(inputs, outputs, labels, masks, patch_size, device) 175 | # print('cr_loss: ', cr_loss) 176 | # print('att_loss: ', att_loss) 177 | loss = cr_loss + att_loss 178 | 179 | acc1.update(outputs, labels) 180 | optimizer.zero_grad() # 1 181 | loss.backward() # 2 182 | optimizer.step() # 3 183 | 184 | return acc1.compute().item(), loss.item(), att_loss.item(), cr_loss.item()  # note: the reported losses come from the last batch of the epoch 185 | 186 | 187 | def test(test_loader, model, criterion, device, args, constraint): 188 | 189 | model.eval() 190 | 191 | acc1 = ildv.MulticlassAccuracy(average='macro', num_classes=args.num_classes).to(device) 192 | 193 | for samples in enumerate(test_loader): 194 | _, tmp = samples 195 | inputs, labels, names = tmp 196 | inputs = inputs.to(device) 197 | labels = labels.to(device) 198 | 199 | with torch.set_grad_enabled(False): 200 | outputs = model(inputs) 201 | cr_loss = criterion(outputs, labels) 202 | loss = cr_loss 203 | 204 | acc1.update(outputs, labels) 205 | 206 | return acc1.compute().item(), loss.item() 207 | 208 | 209 | if __name__ == '__main__': 210 | main() 211 | -------------------------------------------------------------------------------- /doctor/train_grad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | from tqdm import tqdm 5 | import numpy as np 6 | import pickle 7 | 8 | import torch 9 | from torch import nn 10 | from torch import optim 11 | from torch.utils.tensorboard import SummaryWriter 12 | from constraint import Constraint 13 | import shutil 14 | from metrics import ildv 15 | 16 | import timm.scheduler as timm_scheduler 17 | 18 | import loaders 19 | import models 20 | from PIL import Image 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser(description='') 25 | parser.add_argument('--model_name', default='', type=str, help='model name') 26 | parser.add_argument('--data_name', default='', type=str, help='data name') 27 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 28 | parser.add_argument('--num_epochs', default=200, type=int, help='num epochs') 29 | parser.add_argument('--batch_size', default=512, type=int, help='batch size') 30 | parser.add_argument('--model_dir', default='', type=str, help='model dir') 31 | parser.add_argument('--origin_dir', default='', type=str, help='origin dir') 32 | parser.add_argument('--result_name', default='', type=str, help='result name') 33 | parser.add_argument('--train_data_dir', default='', type=str, help='train_data_dir') 34 | parser.add_argument('--test_data_dir', default='', type=str, help='test_data_dir') 35 | parser.add_argument('--log_dir', default='', type=str, help='log dir') 36 | parser.add_argument('--grad_dir', default='', type=str, help='grad dir') 37 | parser.add_argument('--lrp_dir', default='', type=str, help='lrp dir') 38 | parser.add_argument('--mask_dir', default='', type=str, help='mask dir') 39 | args = parser.parse_args() 40 | 41 | # ---------------------------------------- 42 | # basic configuration 43 | # ---------------------------------------- 44 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 45 | 46 | print('DEVICE: ', device) 47 | 48 | train_dir = 
args.train_data_dir 49 | test_dir = args.test_data_dir 50 | 51 | 52 | if not os.path.exists(args.model_dir): 53 | os.makedirs(args.model_dir) 54 | if os.path.exists(args.log_dir): 55 | shutil.rmtree(args.log_dir) 56 | 57 | print('-' * 50) 58 | print('TRAIN ON:', device) 59 | print('TRAIN DATA DIR:', args.train_data_dir) 60 | print('TEST DATA DIR:', args.test_data_dir) 61 | print('MODEL DIR:', args.model_dir) 62 | print('LOG DIR:', args.log_dir) 63 | print('ORIGIN DIR:', args.origin_dir) 64 | print('RESULT NAME:', args.result_name + '.pth') 65 | print('-' * 50) 66 | 67 | # ---------------------------------------- 68 | # trainer configuration 69 | # ---------------------------------------- 70 | model = models.load_model(args.model_name, args.data_name, num_classes=args.num_classes) 71 | print(model.load_state_dict(torch.load(args.origin_dir), strict=False)) 72 | # print(model.load_state_dict(torch.load(args.origin_dir, map_location=torch.device('cpu'))['model'], strict=False)) 73 | 74 | model.to(device) 75 | print('MODEL LOAD DONE') 76 | train_mask_dir = args.mask_dir 77 | lrp_dir = args.lrp_dir 78 | grad_dir = args.grad_dir 79 | train_loader = loaders.load_data(train_dir, args.data_name, data_type='train', batch_size=args.batch_size) 80 | print('TRAIN LOAD DONE') 81 | test_loader = loaders.load_data(test_dir, args.data_name, data_type='test', batch_size=args.batch_size) 82 | print('TEST LOAD DONE') 83 | modules = models.load_modules(model=model) 84 | print('MODULE LOAD DONE') 85 | 86 | criterion = nn.CrossEntropyLoss() 87 | 88 | batch_size = args.batch_size 89 | learn_rate = 0.0005 * (batch_size / 512) 90 | optimizer = optim.AdamW(params=model.parameters(), lr=learn_rate, weight_decay=0.05, eps=1e-8) 91 | num_steps = int(args.num_epochs * len(train_loader)) 92 | warmup_epochs = 5  # unused 93 | warmup_steps = 0 94 | warmup_lr = 1e-06 95 | scheduler = timm_scheduler.CosineLRScheduler( 96 | optimizer, 97 | t_initial=(num_steps - warmup_steps), 98 | lr_min=1e-05, 99 | warmup_lr_init=warmup_lr, 100 | warmup_t=warmup_steps, 101 | cycle_limit=1, 102 | t_in_epochs=False,  # expects per-update stepping; see the editor's sketch after main() in train_att.py above 103 | ) 104 | 105 | # optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-2) 106 | # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.num_epochs) 107 | 108 | grad_ratio = 10 109 | att_ratio = 1 110 | lrp_ratio = 1 111 | constraint = Constraint(model_name=args.model_name, modules=modules, grad_path=grad_dir, lrp_path=lrp_dir, alpha=grad_ratio, beta=att_ratio, gamma=lrp_ratio) 112 | 113 | writer = SummaryWriter(args.log_dir) 114 | 115 | # ---------------------------------------- 116 | # each epoch 117 | # ---------------------------------------- 118 | since = time.time() 119 | 120 | best_acc = None 121 | best_epoch = None 122 | 123 | acc1_first, loss_first, grad_loss_first, cr_loss_first = test(test_loader, model, criterion, device, args, constraint) 124 | print(acc1_first) 125 | 126 | for epoch in tqdm(range(args.num_epochs)): 127 | acc1, loss, grad_loss, cr_loss = train(train_loader, model, criterion, optimizer, device, args, constraint) 128 | writer.add_scalar(tag='training acc1', scalar_value=acc1, global_step=epoch) 129 | writer.add_scalar(tag='training loss', scalar_value=loss, global_step=epoch) 130 | writer.add_scalar(tag='training grad_loss', scalar_value=grad_loss, global_step=epoch) 131 | writer.add_scalar(tag='training cr_loss', scalar_value=cr_loss, global_step=epoch) 132 | acc1, loss, grad_loss, cr_loss = test(test_loader, model, criterion, device, args, 
constraint) 133 | writer.add_scalar(tag='test acc1', scalar_value=acc1, global_step=epoch) 134 | writer.add_scalar(tag='test loss', scalar_value=loss, global_step=epoch) 135 | writer.add_scalar(tag='test grad_loss', scalar_value=grad_loss, global_step=epoch) 136 | writer.add_scalar(tag='test cr_loss', scalar_value=cr_loss, global_step=epoch) 137 | 138 | # ---------------------------------------- 139 | # save best model 140 | # ---------------------------------------- 141 | if best_acc is None or best_acc < acc1: 142 | best_acc = acc1 143 | best_epoch = epoch 144 | torch.save(model.state_dict(), os.path.join(args.model_dir, args.result_name + '.pth')) 145 | 146 | print(best_acc) 147 | 148 | scheduler.step(epoch) 149 | 150 | print('COMPLETE !!!') 151 | print('BEST ACC', best_acc) 152 | print('BEST EPOCH', best_epoch) 153 | print('TIME CONSUMED', time.time() - since) 154 | 155 | 156 | def train(train_loader, model, criterion, optimizer, device, args, constraint): 157 | 158 | model.train() 159 | 160 | acc1 = ildv.MulticlassAccuracy(average='macro', num_classes=args.num_classes).to(device) 161 | 162 | print('LEN OF TRAIN_LOADER') 163 | print(len(train_loader)) 164 | 165 | for samples in tqdm(enumerate(train_loader)): 166 | _, tmp = samples 167 | inputs, labels, names = tmp 168 | inputs = inputs.to(device) 169 | labels = labels.to(device) 170 | outputs = model(inputs) 171 | cr_loss = criterion(outputs, labels) 172 | grad_loss = constraint.loss_grad(outputs, labels) 173 | # print('grad_loss: ', grad_loss) 174 | # print('cr_loss: ', cr_loss) 175 | loss = cr_loss + grad_loss 176 | 177 | acc1.update(outputs, labels) 178 | optimizer.zero_grad() # 1 179 | loss.backward() # 2 180 | optimizer.step() # 3 181 | 182 | return acc1.compute().item(), loss.item(), grad_loss.item(), cr_loss.item() 183 | 184 | 185 | def test(test_loader, model, criterion, device, args, constraint): 186 | 187 | model.eval() 188 | 189 | acc1 = ildv.MulticlassAccuracy(average='macro', num_classes=args.num_classes).to(device) 190 | 191 | for samples in enumerate(test_loader): 192 | _, tmp = samples 193 | inputs, labels, names = tmp 194 | inputs = inputs.to(device) 195 | labels = labels.to(device) 196 | 197 | outputs = model(inputs) 198 | cr_loss = criterion(outputs, labels) 199 | grad_loss = constraint.loss_grad(outputs, labels) 200 | loss = cr_loss + grad_loss 201 | 202 | acc1.update(outputs, labels) 203 | 204 | return acc1.compute().item(), loss.item(), grad_loss.item(), cr_loss.item() 205 | 206 | 207 | if __name__ == '__main__': 208 | main() 209 | -------------------------------------------------------------------------------- /doctor/constraint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torchvision import transforms 5 | from doctor.grad_calculate import HookModule 6 | import os 7 | 8 | def partial_linear(linear: nn.Linear, inp: torch.Tensor): 9 | weight = linear.weight.to(inp.device) # (o, i) 10 | # bias = linear.bias.to(inp.device) # (o) 11 | 12 | out = torch.einsum('bi,oi->boi', inp, weight) # (b, o, i) 13 | 14 | # out = torch.sum(out, dim=-1) 15 | # out = out + bias 16 | 17 | return out 18 | 19 | class Constraint: 20 | 21 | def __init__(self, model_name, modules, grad_path, lrp_path, alpha, beta, gamma): 22 | # [channel loss, choose linear layers] 23 | if (model_name == 'beit' or model_name == 'pvt' or model_name == 'eva'): 24 | self.grad_module = HookModule(modules[2]) 25 | else: 26 | self.grad_module = 
HookModule(modules[3]) 27 | # [attention loss, choose MyIdentity to get attention map] 28 | if (model_name == 'cait'): 29 | self.attention_module = HookModule(modules[7]) 30 | elif (model_name == 'eva' or model_name == 'beit'): 31 | self.attention_module = HookModule(modules[5]) 32 | else: 33 | self.attention_module = HookModule(modules[6]) 34 | # [lrp loss] 35 | self.lrp_mlp_module = HookModule(modules[0]) 36 | self.lrp_linear_module = HookModule(modules[1]) 37 | # [channel mask, obtained from high-credit images] 38 | self.channels_masks = None 39 | if os.path.exists(grad_path): 40 | self.channels_masks = torch.from_numpy(np.load(grad_path)).cuda() 41 | print("CHANNEL MASK LOAD SUCCEED") 42 | else: 43 | print("CHANNEL MASK LOAD FAIL !!!") 44 | self.lrps = torch.from_numpy(np.load(lrp_path)).cuda() if os.path.exists(lrp_path) else None  # guarded like grad_path above; required by loss_lrp 45 | self.grad_ratio = alpha 46 | self.attention_ratio = beta 47 | self.lrp_ratio = gamma 48 | 49 | 50 | def loss_grad(self, outputs, labels): 51 | # [low response channel loss] 52 | nll_loss = torch.nn.NLLLoss()(outputs, labels)  # outputs are raw logits, so this reduces to the negative mean target logit (a Grad-CAM-style target), not a true NLL 53 | grads = self.grad_module.grads(outputs=-nll_loss) 54 | grads = torch.relu(grads) 55 | activations = self.grad_module.outputs 56 | activations = torch.relu(activations) 57 | # channel_grads = torch.sum(grads, dim=1) 58 | channel_grads = grads[:, 0, :] # grad 59 | channel_ac = activations[:, 0, :] # activation 60 | channel_ac_grads = (grads * activations)[:, 0, :] # activation*grad 61 | # [choose one to calculate channel loss: channel_grads, channel_ac, channel_ac_grads] 62 | loss = torch.sum(torch.sum(channel_grads * torch.index_select(self.channels_masks, 0, labels), dim=-1), dim=-1) 63 | loss = loss / len(labels) 64 | return loss * self.grad_ratio 65 | 66 | 67 | def loss_attention(self, inputs, outputs, labels, masks, patch_size, device): 68 | masks = masks.squeeze(1) 69 | masks = torch.where(masks > 0, 1, 0) 70 | count = 0 71 | for mask in masks: 72 | if not torch.all(mask == 1): 73 | count = count + 1 74 | 75 | # [Through MyIdentity Layer to get attention data] 76 | sample_attentions = self.attention_module.outputs 77 | 78 | image_size = inputs.shape[2] 79 | xy_num = image_size // patch_size 80 | tokens = xy_num * xy_num 81 | 82 | # attention_actual = sample_attentions[:, :, 0, :] # (b, h, t+1, t+1) -> (b, h, t+1) 83 | # attention_actual = attention_actual[:, :, 1:] # (b, h, t+1) -> (b, h, t) 84 | 85 | # ==================== 86 | # [attention mean] 87 | # ==================== 88 | # attention_mean = torch.mean(attention_actual, dim=1) 89 | # ==================== 90 | # [attentions * grads] 91 | # ==================== 92 | nll_loss = torch.nn.NLLLoss()(outputs, labels) 93 | grads = torch.autograd.grad(outputs=-nll_loss, inputs=sample_attentions, retain_graph=True, create_graph=True)[0] #(b, h, t+1, t+1) 94 | grads = torch.relu(grads) 95 | # grads = grads[:, :, 0, :] # (b, h, t+1, t+1) -> (b, h, t+1) 96 | # grads = grads[:, :, 1:] # (b, h, t+1) -> (b, h, t) 97 | attention_weight = sample_attentions * grads 98 | attention_weight = torch.sum(attention_weight, dim=2) 99 | attention_weight = torch.mean(attention_weight, dim=1) 100 | # ==================== 101 | # [attention max] 102 | # ==================== 103 | # attention_actual = torch.where(attention_actual < 0.015, 0, attention_actual) 104 | # attention_value = torch.sum(attention_actual, dim=-1) # (b, h, t) -> (b, h) 105 | # sorted_value, indices = torch.sort(attention_value, dim=-1, descending=True) 106 | # indices = indices[:, 0:3] 107 | # attention_max_0 = attention_actual[torch.arange(attention_actual.size(0)), indices[:, 0], :] # (b, h, t) -> (b, t) 108 | # attention_max_1 = attention_actual[torch.arange(attention_actual.size(0)), indices[:, 1], :] # (b, h, t) -> (b, t) 109 | # attention_max_2 = attention_actual[torch.arange(attention_actual.size(0)), indices[:, 2], :] # (b, h, t) -> (b, t) 110 | # attention_max = torch.stack([attention_max_0, attention_max_1, attention_max_2], dim=0) # (b, t) -> (n, b, t) 111 | # attention_max = torch.mean(attention_max, dim=0) # (n, b, t) -> (b, t) 112 | 113 | # masks 114 | # (b, t, p, p) 115 | sub_masks = masks.reshape(inputs.shape[0], xy_num, patch_size, xy_num, patch_size).swapaxes(2, 3).reshape(inputs.shape[0], tokens, patch_size, patch_size) 116 | # (b, t) 117 | masks_inside = torch.sum(torch.sum(sub_masks, dim=-1), dim=-1) 118 | # (b, t) 119 | attention_ideal = torch.where(masks_inside > 0, 0, 1).to(device) 120 | if (torch.all(attention_ideal == 1)): 121 | attention_ideal = torch.zeros_like(attention_ideal)  # zeros_like keeps the tensor on the same device (torch.zeros(...) would fall back to CPU) 122 | 123 | # real_imgs = [] 124 | # for i, attention_save_tmp in enumerate(attention_ideal): 125 | # if (torch.all(attention_save_tmp == 0) == False): 126 | # real_imgs.append(i) 127 | # attention_save = attention_save_tmp.reshape((7, 7)) 128 | # attention_save = attention_save.cpu().detach().numpy() 129 | # folder_path = '/nfs1/ch/project/td/output/tmp/high/masks/npys' 130 | # my_count = len([f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))]) 131 | # save_path = folder_path + '/' + str(my_count) + '.npy' 132 | # np.save(save_path, attention_save) 133 | # print(real_imgs) 134 | # print('DONE') 135 | 136 | loss = torch.sum(torch.sum(attention_ideal * attention_weight, dim=-1), dim=-1) 137 | if (count != 0): 138 | loss = loss / count 139 | return loss * self.attention_ratio 140 | 
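    # [editor's example] The reshape/swapaxes in loss_attention above pools a pixel mask
    # into per-token counts; a self-contained check with toy sizes (8x8 image, 4x4 patches,
    # hence 2x2 = 4 tokens; all names are illustrative):
    #
    #     m = torch.zeros(1, 8, 8); m[0, :4, :4] = 1     # object covers the top-left patch
    #     t = m.reshape(1, 2, 4, 2, 4).swapaxes(2, 3).reshape(1, 4, 4, 4)
    #     t.sum((-1, -2))                                # tensor([[16., 0., 0., 0.]])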
141 | def loss_lrp(self, labels):  # requires self.lrps (loaded in __init__ when lrp_path exists) 142 | mlp = self.lrp_mlp_module.module 143 | mlp_inputs = self.lrp_mlp_module.inputs 144 | mlp_head = partial_linear(mlp, mlp_inputs) # mlp_head: (b, o, i) 145 | mlp_head = torch.relu(mlp_head) 146 | linear2 = self.lrp_linear_module.module 147 | linear_inputs = self.lrp_linear_module.inputs 148 | linear_layer = partial_linear(linear2, linear_inputs) # linear_layer: (128, 1024, 2048); currently unused, see the commented 'second linear' block below 149 | batch_size = len(labels) 150 | 151 | # only correct label to calculate mlp_inputs lrp 152 | mlp_lrp_inputs = mlp_head[torch.arange(mlp_head.size(0)), labels, :] # (b, o, i) ---(labels(b))--> (b, i) 153 | 154 | # lrps: (labels_num, i) max-min normalization 155 | max_lrp = torch.max(self.lrps, dim=-1).values 156 | min_lrp = torch.min(self.lrps, dim=-1).values 157 | tmp = max_lrp - min_lrp 158 | my_lrp = (self.lrps - min_lrp.unsqueeze(-1)) / tmp.unsqueeze(-1) # my_lrp (labels_num, i) 159 | 160 | ideal_lrps = my_lrp[labels] # (labels_num, i) -> (b, i) 161 | # sift the lrp data < 0.2 to 0 162 | ideal_lrps = torch.where(ideal_lrps < 0.20, 0, ideal_lrps) 163 | # reverse 1 and 0 164 | ideal_lrps = torch.where(ideal_lrps == 0, 1, 0) 165 | loss_tmp = torch.sum(mlp_lrp_inputs * ideal_lrps, dim=-1) 166 | loss_tmp = loss_tmp / batch_size 167 | loss = torch.sum(loss_tmp, dim=-1) 168 | return loss * self.lrp_ratio 169 | 170 | 
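# [editor's example] partial_linear (defined at the top of this file) keeps the per-input
# contributions that nn.Linear would otherwise sum away; a quick sanity check with
# hypothetical shapes (batch 2, in_features 3, out_features 4):
#
#     lin = nn.Linear(3, 4)
#     x = torch.randn(2, 3)
#     contrib = partial_linear(lin, x)               # (2, 4, 3)
#     assert torch.allclose(contrib.sum(-1) + lin.bias, lin(x), atol=1e-6)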
171 | def calculate_loss_channel(channels, grads, labels): 172 | grads = torch.relu(grads) 173 | channel_grads = torch.sum(grads, dim=1) # [batch_size, channels] 174 | loss = torch.sum(torch.sum(channel_grads * torch.index_select(channels, 0, labels), dim=-1), dim=-1) 175 | loss = loss / len(labels) 176 | return loss 177 | 178 | # second linear (commented exploration; the names below refer to locals of Constraint.loss_lrp above) 179 | # linear_sum = torch.sum(linear_layer, dim=-1) #(128, 1024) 180 | # linear_sum = linear_sum.unsqueeze(2) #(128, 1024,) 181 | # linear_sum = linear_sum.repeat(1, 1, linear_layer.shape[2]) #(128, 1024, 2048) 182 | # linear_ratio = linear_layer / linear_sum #(128, 1024, 2048) 183 | # mlp_inputs_repeat = mlp_lrp_inputs.unsqueeze(2).repeat(1, 1, linear_layer.shape[2]) #(128, 1024, 2048) 184 | # linear_lrp_inputs = torch.sum(linear_ratio * mlp_inputs_repeat, dim=1) #(128, 2048) 185 | # second linear 186 | -------------------------------------------------------------------------------- /doctor/grad_calculate.py: -------------------------------------------------------------------------------- 1 | """ 2 | activation: 3 | recorded at each hooked layer 4 | 5 | gradient: 6 | of the objective with respect to each hooked layer's output 7 | """ 8 | import sys 9 | 10 | sys.path.append('.') 11 | 12 | import argparse 13 | import torch 14 | import numpy as np 15 | import os 16 | from tqdm import tqdm 17 | 18 | import models 19 | import loaders 20 | 21 | 22 | class HookModule: 23 | def __init__(self, module): 24 | self.module = module 25 | self.inputs = None 26 | self.outputs = None 27 | module.register_forward_hook(self._hook) 28 | 29 | def grads(self, outputs, inputs=None, retain_graph=True, create_graph=True): 30 | if inputs is None: 31 | inputs = self.outputs  # default: differentiate w.r.t. the hooked module's output 32 | 33 | return torch.autograd.grad(outputs=outputs, 34 | inputs=inputs, 35 | retain_graph=retain_graph, 36 | create_graph=create_graph)[0] 37 | 38 | def _hook(self, module, inputs, outputs): 39 | self.inputs = inputs[0] 40 | self.outputs = outputs 41 | 42 | 43 | def _normalization(data, axis=None, bot=False): 44 | assert axis in [None, 0, 1] 45 | _max = np.max(data, axis=axis) 46 | if bot: 47 | _min = np.zeros(_max.shape) 48 | else: 49 | _min = np.min(data, axis=axis) 50 | _range = _max - _min 51 | if axis == 1: 52 | _norm = ((data.T - _min) / (_range + 1e-5)).T 53 | else: 54 | _norm = (data - _min) / (_range + 1e-5) 55 | return _norm 56 | 57 | 58 | class GradCalculate: 59 | def __init__(self, modules, num_classes): 60 | self.modules = [HookModule(module) for module in modules] 61 | # [num_modules, num_labels, num_images, channels] 62 | self.grads = [[[] for _ in range(num_classes)] for _ in range(len(modules))] 63 | self.activations = [[[] for _ in range(num_classes)] for _ in range(len(modules))] 64 | 65 | def __call__(self, outputs, labels): 66 | nll_loss = torch.nn.NLLLoss()(outputs, labels) 67 | for layer, module in enumerate(self.modules): 68 | if module.outputs is None: 69 | continue 70 | grads = module.grads(-nll_loss, module.outputs) 71 | grads = torch.relu(grads) 72 | grads = grads.detach().cpu().numpy() 73 | activations = module.outputs 74 | activations = torch.relu(activations) 75 | activations = activations.detach().cpu().numpy() 76 | for b in range(len(labels)): 77 | self.grads[layer][labels[b]].append(grads[b]) 78 | self.activations[layer][labels[b]].append(activations[b]) 79 | 80 | def sift_ac_grad(self, result_path, threshold): 81 | for layer, _ in enumerate(tqdm(self.activations)): 82 | if (layer == 2 or layer == 3): 83 | print('----', layer, '----activation * grad----') 84 | grads = self.grads[layer] 85 | activations = self.activations[layer] 86 | grads = np.asarray(grads) 87 | activations = np.asarray(activations) 88 | if len(grads.shape) == 4: 89 | grads = np.squeeze(grads[:, :, 0, :]) # [num_labels, num_images, channels] 90 | if len(activations.shape) == 4: 91 | activations = np.squeeze(activations[:, :, 0, :]) # [num_labels, num_images, channels] 92 | 93 | 
values = grads * activations 94 | 95 | masks_array = [] 96 | for value in values: 97 | value = _normalization(value, axis=1) # [num_images, channels] 98 | mask = np.zeros(value.shape) # [num_images, channels] 99 | mask[np.where(value > threshold)] = 1 # [num_images, channels] 100 | mask = np.sum(mask, axis=0) # [channels] 101 | mask = np.where(mask > 3, 1, 0) # [channels] 102 | print(np.sum(mask)) 103 | masks_array.append(mask) 104 | masks = np.stack(masks_array, axis=0) 105 | print('masks: ', masks.shape) 106 | masks_path = os.path.join(result_path, 'att_ac_grad_layer_{}.npy'.format(layer)) 107 | np.save(masks_path, masks) 108 | 109 | def sift_ac(self, result_path, threshold): 110 | for layer, activations in enumerate(tqdm(self.activations)): 111 | if (layer == 2 or layer == 3): 112 | print('----', layer, '----activation----') 113 | activations = np.asarray(activations) 114 | print('activations_shape: ', np.shape(activations)) 115 | if len(activations.shape) == 4: 116 | activations = np.sum(activations, axis=2) # [num_classes, num_images, channels] 117 | 118 | masks_array = [] 119 | for activation in activations: 120 | activation = _normalization(activation, axis=1) # [num_images, channels] 121 | mask = np.zeros(activation.shape) # [num_images, channels] 122 | mask[np.where(activation > threshold)] = 1 # [num_images, channels] 123 | mask = np.sum(mask, axis=0) # [channels] 124 | mask = np.where(mask > 3, 1, 0) # [channels] 125 | print(np.sum(mask)) 126 | masks_array.append(mask) 127 | masks = np.stack(masks_array, axis=0) 128 | print('masks: ', masks.shape) 129 | masks_path = os.path.join(result_path, 'ac_layer_{}.npy'.format(layer)) 130 | np.save(masks_path, masks) 131 | 132 | def sift_grad(self, result_path, threshold): 133 | for layer, grads in enumerate(tqdm(self.grads)): 134 | # if (layer == 2 or layer == 3): 135 | print('----', layer, '----grad----') 136 | grads = np.asarray(grads) 137 | if grads.shape[1] == 0: 138 | continue 139 | print('grads_shape: ', np.shape(grads)) 140 | if len(grads.shape) == 4: 141 | grads = np.sum(grads, axis=2) # [num_classes, num_images, channels] 142 | 143 | # grads = np.sum(grads, axis=1) # [num_classes, channels] 144 | # grads = _normalization(grads, axis=1) 145 | # masks = np.zeros(grads.shape) 146 | # masks[np.where(grads > threshold)] = 1 147 | 148 | masks_array = [] 149 | for grad in grads: 150 | grad = _normalization(grad, axis=1) # [num_images, channels] 151 | mask = np.zeros(grad.shape) # [num_images, channels] 152 | mask[np.where(grad > threshold)] = 1 # [num_images, channels] 153 | mask = np.sum(mask, axis=0) # [channels] 154 | mask = np.where(mask > 3, 1, 0) # [channels] 155 | print(np.sum(mask)) 156 | masks_array.append(mask) 157 | masks = np.stack(masks_array, axis=0) 158 | print('masks: ', masks.shape) 159 | masks_path = os.path.join(result_path, 'layer_{}.npy'.format(layer)) 160 | np.save(masks_path, masks) 161 | 162 | 163 | def main(): 164 | parser = argparse.ArgumentParser(description='') 165 | parser.add_argument('--model_name', default='', type=str, help='model name') 166 | parser.add_argument('--data_name', default='', type=str, help='data name') 167 | parser.add_argument('--in_channels', default='', type=int, help='in channels') 168 | parser.add_argument('--num_classes', default='', type=int, help='num classes') 169 | parser.add_argument('--batch_size', default=512, type=int, help='batch size') 170 | parser.add_argument('--model_path', default='', type=str, help='model path') 171 | parser.add_argument('--data_path', default='', 
type=str, help='data path') 172 | parser.add_argument('--grad_path', default='', type=str, help='grad path') 173 | parser.add_argument('--theta', default='', type=float, help='theta') 174 | parser.add_argument('--device_index', default='0', type=str, help='device index') 175 | args = parser.parse_args() 176 | 177 | # ---------------------------------------- 178 | # basic configuration 179 | # ---------------------------------------- 180 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index 181 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 182 | 183 | if not os.path.exists(args.grad_path): 184 | os.makedirs(args.grad_path) 185 | 186 | print('-' * 50) 187 | print('TRAIN ON:', device) 188 | print('DATA PATH:', args.data_path) 189 | print('RESULT PATH:', args.grad_path) 190 | print('-' * 50) 191 | 192 | # ---------------------------------------- 193 | # model/data configuration 194 | # ---------------------------------------- 195 | model = models.load_model(model_name=args.model_name, data_name=args.data_name, num_classes=args.num_classes) 196 | # print(model.load_state_dict(torch.load(args.model_path, map_location=torch.device('cpu'))['model'], strict=False)) 197 | model.load_state_dict(torch.load(args.model_path)) 198 | model.to(device) 199 | model.eval() 200 | print(model) 201 | data_loader = loaders.load_data(data_dir=args.data_path, data_name=args.data_name, data_type='train', batch_size=args.batch_size) 202 | 203 | modules = models.load_modules(model=model) 204 | 205 | grad_calculate = GradCalculate(modules=modules, num_classes=args.num_classes) 206 | 207 | # ---------------------------------------- 208 | # forward 209 | # ---------------------------------------- 210 | for i, samples in enumerate(tqdm(data_loader)): 211 | inputs, labels, names = samples 212 | inputs = inputs.to(device) 213 | labels = labels.to(device) 214 | outputs = model(inputs) 215 | 216 | grad_calculate(outputs, labels) 217 | 218 | # grad_calculate.sift_ac_grad(result_path=args.grad_path, threshold=args.theta) 219 | # grad_calculate.sift_ac(result_path=args.grad_path, threshold=args.theta) 220 | grad_calculate.sift_grad(result_path=args.grad_path, threshold=args.theta) 221 | 222 | 223 | if __name__ == '__main__': 224 | np.set_printoptions(threshold=np.inf) 225 | main() 226 | -------------------------------------------------------------------------------- /models/tnt.py: -------------------------------------------------------------------------------- 1 | # -------------------------------- 2 | # Modified from timm by QIU Tian 3 | # -------------------------------- 4 | 5 | """ Transformer in Transformer (TNT) in PyTorch 6 | 7 | A PyTorch implementation of TNT as described in 8 | 'Transformer in Transformer' - https://arxiv.org/abs/2103.00112 9 | 10 | The official mindspore code is released and available at 11 | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT 12 | """ 13 | 14 | __all__ = [ 15 | 'TNT', 16 | 'tnt_s_imagenet', 17 | 'tnt_s_cifar' 18 | ] 19 | 20 | import math 21 | 22 | import torch 23 | import torch.nn as nn 24 | from timm.layers import _assert, to_2tuple, DropPath, Mlp, trunc_normal_ 25 | from models.tvit import MyIdentity  # NOTE: models/tvit.py is not listed in this repo's tree; MyIdentity may live in models/vit.py 26 | 27 | 28 | class Attention(nn.Module): 29 | """ Multi-Head Attention 30 | """ 31 | 32 | def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 33 | super().__init__() 34 | self.my_identity = MyIdentity() 35 | self.hidden_dim = hidden_dim 36 | self.num_heads = num_heads 37 | head_dim = 
hidden_dim // num_heads 38 | self.head_dim = head_dim 39 | self.scale = head_dim ** -0.5 40 | 41 | self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias) 42 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 43 | self.attn_drop = nn.Dropout(attn_drop, inplace=True) 44 | self.proj = nn.Linear(dim, dim) 45 | self.proj_drop = nn.Dropout(proj_drop, inplace=True) 46 | 47 | def forward(self, x): 48 | B, N, C = x.shape 49 | qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 50 | q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple) 51 | v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) 52 | 53 | attn = (q @ k.transpose(-2, -1)) * self.scale 54 | attn = attn.softmax(dim=-1) 55 | 56 | attn = self.my_identity(attn) 57 | 58 | attn = self.attn_drop(attn) 59 | 60 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 61 | x = self.proj(x) 62 | x = self.proj_drop(x) 63 | return x 64 | 65 | 66 | class Block(nn.Module): 67 | """ TNT Block 68 | """ 69 | 70 | def __init__( 71 | self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., 72 | qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 73 | super().__init__() 74 | # Inner transformer 75 | self.norm_in = norm_layer(in_dim) 76 | self.attn_in = Attention( 77 | in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias, 78 | attn_drop=attn_drop, proj_drop=drop) 79 | 80 | self.norm_mlp_in = norm_layer(in_dim) 81 | self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4), 82 | out_features=in_dim, act_layer=act_layer, drop=drop) 83 | 84 | self.norm1_proj = norm_layer(in_dim) 85 | self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True) 86 | # Outer transformer 87 | self.norm_out = norm_layer(dim) 88 | self.attn_out = Attention( 89 | dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, 90 | attn_drop=attn_drop, proj_drop=drop) 91 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 92 | 93 | self.norm_mlp = norm_layer(dim) 94 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), 95 | out_features=dim, act_layer=act_layer, drop=drop) 96 | 97 | def forward(self, pixel_embed, patch_embed): 98 | # inner 99 | pixel_embed = pixel_embed + self.drop_path(self.attn_in(self.norm_in(pixel_embed))) 100 | pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) 101 | # outer 102 | B, N, C = patch_embed.size() 103 | patch_embed = torch.cat( 104 | [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))], 105 | dim=1) 106 | patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed))) 107 | patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed))) 108 | return pixel_embed, patch_embed 109 | 110 | 111 | class PixelEmbed(nn.Module): 112 | """ Image to Pixel Embedding 113 | """ 114 | 115 | def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4): 116 | super().__init__() 117 | img_size = to_2tuple(img_size) 118 | patch_size = to_2tuple(patch_size) 119 | # grid_size property necessary for resizing positional embedding 120 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 121 | num_patches = (self.grid_size[0]) * (self.grid_size[1]) 122 | self.img_size = img_size 123 | self.num_patches = num_patches 124 | self.in_dim = in_dim 125 | new_patch_size = [math.ceil(ps / stride) for ps in patch_size] 126 | self.new_patch_size = new_patch_size 127 | 128 | self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride) 129 | self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size) 130 | 131 | def forward(self, x, pixel_pos): 132 | B, C, H, W = x.shape 133 | _assert(H == self.img_size[0], 134 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") 135 | _assert(W == self.img_size[1], 136 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") 137 | x = self.proj(x) 138 | x = self.unfold(x) 139 | x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1]) 140 | x = x + pixel_pos 141 | x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2) 142 | return x 143 | 144 | 145 | class TNT(nn.Module): 146 | """ Transformer in Transformer - https://arxiv.org/abs/2103.00112 147 | """ 148 | 149 | def __init__( 150 | self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', 151 | embed_dim=768, in_dim=48, depth=12, num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, 152 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): 153 | super().__init__() 154 | assert global_pool in ('', 'token', 'avg') 155 | self.num_classes = num_classes 156 | self.global_pool = global_pool 157 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 158 | 159 | self.pixel_embed = PixelEmbed( 160 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride) 161 | num_patches = self.pixel_embed.num_patches 162 | self.num_patches = num_patches 163 | new_patch_size = self.pixel_embed.new_patch_size 164 | num_pixel = new_patch_size[0] * new_patch_size[1] 165 | 166 | self.norm1_proj = norm_layer(num_pixel * in_dim) 167 | self.proj = nn.Linear(num_pixel * in_dim, embed_dim) 168 | self.norm2_proj 
= norm_layer(embed_dim) 169 | 170 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 171 | self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 172 | self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1])) 173 | self.pos_drop = nn.Dropout(p=drop_rate) 174 | 175 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 176 | blocks = [] 177 | for i in range(depth): 178 | blocks.append(Block( 179 | dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head, 180 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, 181 | drop_path=dpr[i], norm_layer=norm_layer)) 182 | self.blocks = nn.ModuleList(blocks) 183 | self.norm = norm_layer(embed_dim) 184 | 185 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 186 | 187 | trunc_normal_(self.cls_token, std=.02) 188 | trunc_normal_(self.patch_pos, std=.02) 189 | trunc_normal_(self.pixel_pos, std=.02) 190 | self.apply(self._init_weights) 191 | 192 | def _init_weights(self, m): 193 | if isinstance(m, nn.Linear): 194 | trunc_normal_(m.weight, std=.02) 195 | if isinstance(m, nn.Linear) and m.bias is not None: 196 | nn.init.constant_(m.bias, 0) 197 | elif isinstance(m, nn.LayerNorm): 198 | nn.init.constant_(m.bias, 0) 199 | nn.init.constant_(m.weight, 1.0) 200 | 201 | def forward_features(self, x): 202 | B = x.shape[0] 203 | pixel_embed = self.pixel_embed(x, self.pixel_pos) 204 | 205 | patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1)))) 206 | patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1) 207 | patch_embed = patch_embed + self.patch_pos 208 | patch_embed = self.pos_drop(patch_embed) 209 | 210 | for blk in self.blocks: 211 | pixel_embed, patch_embed = blk(pixel_embed, patch_embed) 212 | 213 | patch_embed = self.norm(patch_embed) 214 | return patch_embed 215 | 216 | def forward_head(self, x, pre_logits: bool = False): 217 | if self.global_pool: 218 | x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] 219 | return x if pre_logits else self.head(x) 220 | 221 | def forward(self, x): 222 | x = self.forward_features(x) 223 | x = self.forward_head(x) 224 | return x 225 | 226 | 227 | def tnt_s_imagenet(**kwargs): 228 | model = TNT(img_size=224, 229 | in_chans=3, 230 | patch_size=16, 231 | embed_dim=384, 232 | in_dim=24, 233 | depth=12, 234 | num_heads=6, 235 | in_num_head=4, 236 | qkv_bias=False, 237 | **kwargs) 238 | return model 239 | 240 | def tnt_s_cifar(**kwargs): 241 | model = TNT(img_size=32, 242 | in_chans=3, 243 | patch_size=4, 244 | embed_dim=384, 245 | in_dim=24, 246 | depth=12, 247 | num_heads=6, 248 | in_num_head=4, 249 | qkv_bias=False, 250 | **kwargs) 251 | return model --------------------------------------------------------------------------------
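A quick smoke test for the TNT variants defined in models/tnt.py above -- a minimal sketch, assuming the MyIdentity import at the top of that file resolves in your checkout (num_classes=10 is arbitrary here):

import torch
from models.tnt import tnt_s_cifar

model = tnt_s_cifar(num_classes=10)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32))  # CIFAR-sized batch of two images
print(logits.shape)  # torch.Size([2, 10])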