├── README.md
├── args.py
├── compression
│   ├── pruner.py
│   └── speedup.py
├── config.py
├── config_helpers.py
├── data_utils.py
├── general_utils.py
├── main.py
├── models
│   └── modeling_mask2former.py
├── paths.py
├── requirements.txt
├── trainer_utils.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
 1 | # CVPR24 Paper
 2 | 
 3 | Code for the CVPR24 paper - Resource-Efficient Transformer Pruning for Finetuning of Large Models
 4 | 
 5 | Fatih Ilhan, Gong Su, Selim Furkan Tekin, Tiansheng Huang, Sihao Hu, and Ling Liu, "Resource-Efficient Transformer Pruning for Finetuning of Large Models," IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Seattle, USA, Jun. 17-21, 2024.
 6 | 
 7 | ## Setup
 8 | 
 9 | Python 3.10
10 | 
11 | PyTorch 2.0.1
12 | 
13 | Transformers 4.33
14 | 
15 | NNI: https://github.com/microsoft/nni
16 | 
17 | Please check requirements.txt for the list of other required packages.
18 | 
19 | ## Usage
20 | 
21 | The CIFAR, TinyImageNet, and GLUE datasets are downloaded automatically. You can download Cityscapes from https://www.cityscapes-dataset.com/ and KITTI from https://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015.
22 | 
23 | ### General usage for finetuning with RECAP:
24 | ```python main.py --task <task> --data <dataset> --arch <architecture> --init_sparse <pruning ratio> --iter_sparse <masking ratio> -num_pi <pruning iterations> -num_pr <pruning rounds>```
25 | 
26 | ### Example: Finetune ViT-base on CIFAR100 with 33% pruning and 87.5% masking in 10 rounds:
27 | ```python main.py --task img_class --data cifar100 --arch vit-base --init_sparse 0.33 --iter_sparse -0.875 -num_pi 2 -num_pr 10```
28 | 
29 | ### Example: Finetune Mask2Former on Cityscapes with 50% pruning and 50% masking in 20 rounds:
30 | ```python main.py --task img_seg --data cityscapes --arch m2f --init_sparse 0.5 --iter_sparse -0.5 -num_pi 3 -num_pr 20```
31 | 
32 | ### Example: Finetune BERT-base on CoLA with 17% pruning and 50% masking in 5 rounds:
33 | ```python main.py --task glue --data cola --arch bert-base-uncased --init_sparse 0.17 --iter_sparse -0.5 -num_pi 1 -num_pr 5```
34 | 
35 | ### Parameters
36 | 
37 | All pruning/finetuning parameters are controlled from ``config.py``.
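The sketch below is illustrative only (the values mirror the ViT example above, and it assumes the repository root is on the Python path): it shows how the sparsity flags defined in ``args.py`` end up in the pruning parameters returned by ``Config`` in ``config.py``.

```python
# Minimal sketch: how the CLI sparsity flags map to Config's pruning parameters.
# The Namespace fields mirror the argument names defined in args.py.
from argparse import Namespace

from config import Config

args = Namespace(arch='vit-base',
                 core_res=64,               # -res: sparsity resolution
                 init_sparse_ratio=0.33,    # --init_sparse: structured pruning ratio
                 iter_sparse_ratio=-0.875,  # --iter_sparse: 0.875 => 87.5% masking during finetuning
                 num_pruning_iters=2)       # -num_pi: prune gradually over 2 Taylor iterations

cfg = Config(args)
print(cfg.get_init_pruning_params('vit-base', 'cifar100')['attn'])
# -> {'sparse_ratio': 0.33, 'max_sparse_ratio': 0.85, 'granularity': [64, 768]}
```

As in the examples, a negative ``--iter_sparse`` value encodes the masking ratio applied during finetuning; the sign itself is resolved inside ``compression/pruner.py``.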
38 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | 5 | 6 | def modify_args(args): 7 | if args.device == 'gpu' and args.gpu_idx: 8 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_idx 9 | 10 | args.datetime = format(str(datetime.datetime.now())) 11 | args.mask_finetune_flag = args.iter_sparse_ratio != 0 12 | 13 | if args.task == 'glue': 14 | args.metric_name = "matthews_correlation" if args.data == "cola" else "accuracy" 15 | elif args.task == 'img_class': 16 | args.metric_name = 'accuracy' 17 | elif args.task == 'img_seg': 18 | args.metric_name = 'accuracy' 19 | else: 20 | raise NotImplementedError 21 | 22 | if 'cifar' in args.data: 23 | args.final_eval_split = 'test' 24 | else: 25 | args.final_eval_split = 'val' 26 | 27 | return args 28 | 29 | 30 | model_names = ['bert-base-uncased', 'bert-large-uncased', 'vit-base', 'vit-large', 'm2f'] 31 | 32 | arg_parser = argparse.ArgumentParser(description='Pruning main script') 33 | 34 | exp_group = arg_parser.add_argument_group('exp', 'experiment setting') 35 | exp_group.add_argument('--save_path', default='output', type=str, metavar='SAVE', 36 | help='path to the experiment logging directory') 37 | exp_group.add_argument('--evaluate_from', default=None, type=str, metavar='PATH', 38 | help='path to saved checkpoint (default: none)') 39 | exp_group.add_argument('--run_mode', default='train', type=str, choices=['train', 'evaluate'], help='Script mode') 40 | exp_group.add_argument('--seed', default=0, type=int, help='random seed') 41 | exp_group.add_argument('--gpu_idx', default=None, type=str, help='Index of available GPU') 42 | exp_group.add_argument('--device', default='cuda', type=str, choices=['cpu', 'cuda', 'mps'], help='Device type for finetuning') 43 | exp_group.add_argument('--comp_device', default='cuda', type=str, choices=['cpu', 'cuda', 'mps'], help='Device type for pruning/masking operations') 44 | 45 | # compression related 46 | comp_group = arg_parser.add_argument_group('comp', 'compression setting') 47 | comp_group.add_argument('--num_pruning_rounds', '-num_pr', default=10, type=int) 48 | comp_group.add_argument('--core_res', '-res', default=64, type=float, help='Sparsity resolution') 49 | comp_group.add_argument('--init_sparse_ratio', '-init_sparse', default=0.5, type=float, help='Pruning sparsity') 50 | comp_group.add_argument('--iter_sparse_ratio', '-iter_sparse', default=-0.75, type=float, help='Finetuning sparsity') 51 | comp_group.add_argument('--num_pruning_iters', '-num_pi', default=4, type=int, help='Gradually prune in x iters') 52 | 53 | # dataset related 54 | data_group = arg_parser.add_argument_group('data', 'dataset setting') 55 | data_group.add_argument('--task', metavar='D', default='glue', choices=['glue', 'qa', 'img_class', 'img_seg'], help='task to work on') 56 | data_group.add_argument('--data', metavar='D', default='cola', help='data to work on') 57 | data_group.add_argument('--data_root', metavar='DIR', default='data', help='path to dataset folder (default: data)') 58 | data_group.add_argument('-j', '--workers', default=1, type=int, help='number of data loading workers (default: 1)') 59 | 60 | # model arch related 61 | arch_group = arg_parser.add_argument_group('arch', 'model architecture setting') 62 | arch_group.add_argument('--arch', '-a', metavar='ARCH', default='bert-base-uncased', 63 | type=str, choices=model_names, 64 | 
help='model architecture: ' + 65 | ' | '.join(model_names) + 66 | ' (default: bert-base-uncased)') 67 | -------------------------------------------------------------------------------- /compression/pruner.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from copy import deepcopy 3 | 4 | from compression.speedup import speedup 5 | from config_helpers import * 6 | from general_utils import LogLevel 7 | from nni.contrib.compression.pruning import TaylorPruner 8 | from nni.contrib.compression.utils import TransformersEvaluator 9 | from trainer_utils import * 10 | from utils import get_model_param_keys 11 | 12 | pruner_dispatcher = {'taylor': TaylorPruner} 13 | 14 | 15 | def update_full_model(model, args, config, trainer_state, total_num_steps): 16 | init_model = torch.load(get_path(args, 'INIT_MODEL_PATH'), map_location='cpu') 17 | init_masks = torch.load(get_path(args, 'INIT_MASKS_PATH'), map_location='cpu') 18 | opt_found_flag = False 19 | keys = get_model_param_keys(model) 20 | keys = keys[0] + keys[1] 21 | 22 | if os.path.exists(get_path(args, 'OPT_STATE_PATH')): 23 | opt_states = torch.load(get_path(args, 'OPT_STATE_PATH'), map_location='cpu') 24 | opt_found_flag = True 25 | else: 26 | opt_states = dict([(i, {'step': 0}) for i in range(len(trainer_state.opt_state))]) 27 | model = model.to('cpu') 28 | 29 | if args.mask_finetune_flag: 30 | iter_masks = torch.load(get_path(args, 'ITER_MASKS_PATH'), map_location='cpu') 31 | else: 32 | iter_masks = None 33 | 34 | init_model_state_dict = init_model.state_dict() 35 | model_state_dict = model.state_dict() 36 | 37 | for key, val in model_state_dict.items(): 38 | key_ = '.'.join(key.split('.')[:-1]) 39 | _key = key.split('.')[-1] 40 | 41 | if key not in keys: 42 | continue 43 | 44 | opt_idx = keys.index(key) 45 | 46 | if 'embeddings.mask_token' in key: 47 | continue 48 | 49 | try: 50 | init_mask = init_masks[key_][_key] 51 | except: 52 | # print(f'Could not find init mask for {key}') 53 | init_mask = None 54 | 55 | if 'relative_position_bias_table' in key: 56 | init_mask = init_mask.repeat([model_state_dict[key].shape[0], 1]) 57 | 58 | try: 59 | iter_mask = iter_masks[key_][_key] # update these values 60 | except: 61 | # print(f'Could not find iter mask for {key}') 62 | iter_mask = torch.ones_like(model_state_dict[key]).bool() 63 | 64 | if init_mask is None: # check this 65 | if init_model_state_dict[key].shape != model_state_dict[key].shape: 66 | print(key) 67 | raise RuntimeError 68 | init_model_state_dict[key] = model_state_dict[key] 69 | 70 | if opt_found_flag: 71 | opt_states[opt_idx]['exp_avg'] = trainer_state.opt_state[opt_idx]['exp_avg'].to('cpu') 72 | opt_states[opt_idx]['exp_avg_sq'] = trainer_state.opt_state[opt_idx]['exp_avg_sq'].to('cpu') 73 | 74 | else: 75 | pad_idx = init_mask.flatten().nonzero().squeeze()[iter_mask.flatten() == 1] 76 | mask_padded = torch.zeros_like(init_mask).flatten() 77 | mask_padded[pad_idx] = 1 78 | mask_padded = mask_padded.reshape(init_mask.shape) 79 | 80 | try: 81 | init_model_state_dict[key][mask_padded] = model_state_dict[key][iter_mask].flatten() 82 | except: 83 | # print(f'Could not find update {key}') 84 | pass 85 | 86 | if opt_found_flag: 87 | try: 88 | opt_states[opt_idx]['exp_avg'][mask_padded] = trainer_state.opt_state[opt_idx]['exp_avg'].to('cpu')[iter_mask].flatten() 89 | opt_states[opt_idx]['exp_avg_sq'][mask_padded] = trainer_state.opt_state[opt_idx]['exp_avg_sq'].to('cpu')[iter_mask].flatten() 90 | 
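# Moments of entries that were not updated in this round are kept but decayed,
# using the AdamW default betas (0.9 for exp_avg, 0.999 for exp_avg_sq).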
opt_states[opt_idx]['exp_avg'][~mask_padded] *= 0.9 91 | opt_states[opt_idx]['exp_avg_sq'][~mask_padded] *= 0.999 92 | except: 93 | print(key) 94 | 95 | opt_states[opt_idx]['step'] = int(trainer_state.opt_state[opt_idx]['step'].item() + opt_states[opt_idx]['step']) 96 | if not opt_found_flag: 97 | opt_states[opt_idx]['exp_avg'] = torch.zeros_like(init_model_state_dict[key]) 98 | opt_states[opt_idx]['exp_avg_sq'] = torch.zeros_like(init_model_state_dict[key]) 99 | 100 | init_model.load_state_dict(init_model_state_dict) 101 | torch.save(init_model, get_path(args, 'INIT_MODEL_PATH')) # save the updated model 102 | torch.save(opt_states, get_path(args, 'OPT_STATE_PATH')) # save the updated model 103 | 104 | return init_model 105 | 106 | 107 | def init_pruning(model, args, config, data_content, tag='default', method=None, beta=-1): 108 | training_params = config.get_init_training_params(args.arch, args.data) 109 | pruning_params = config.get_init_pruning_params(args.arch, args.data) 110 | pruning_params['beta'] = beta 111 | 112 | full_masks = None 113 | cur_pruning_params = deepcopy(pruning_params) 114 | num_iters = pruning_params.get('num_iters', 1) 115 | for iter_idx in range(num_iters): 116 | cur_pruning_params['attn']['sparse_ratio'] = pruning_params['attn']['sparse_ratio'] \ 117 | / num_iters / (1 - iter_idx * pruning_params['attn']['sparse_ratio'] / num_iters) 118 | cur_pruning_params['ffn']['sparse_ratio'] = pruning_params['ffn']['sparse_ratio'] \ 119 | / num_iters / (1 - iter_idx * pruning_params['ffn']['sparse_ratio'] / num_iters) 120 | 121 | config_list = get_prune_config_for_attn(args, model, cur_pruning_params['attn']) \ 122 | + get_prune_config_for_ffn(args, model, cur_pruning_params['ffn']) 123 | 124 | for c in config_list: 125 | if not cur_pruning_params['global_flag']: 126 | del c['global_group_id'] 127 | 128 | if method is None: 129 | method = 'taylor' 130 | 131 | if model.device != args.comp_device: 132 | model = model.to(args.comp_device) 133 | 134 | model, masks, _ = prune(model, args, data_content, training_params, cur_pruning_params, 135 | config_list, method, tag=tag, device=args.comp_device) 136 | 137 | if full_masks is None: 138 | full_masks = deepcopy(masks) 139 | else: 140 | for k in masks.keys(): 141 | for k_ in masks[k].keys(): 142 | if full_masks[k][k_] is None: 143 | continue 144 | pad_idx = full_masks[k][k_].flatten().nonzero().squeeze()[masks[k][k_].flatten() == 1] 145 | mask_padded = torch.zeros_like(full_masks[k][k_]).flatten() 146 | mask_padded[pad_idx] = 1 147 | mask_padded = mask_padded.reshape(full_masks[k][k_].shape) 148 | full_masks[k][k_][mask_padded == 0] = False 149 | 150 | if iter_idx == num_iters - 1: 151 | torch.save(model, get_path(args, 'COMPRESSED_MODEL_PATH')) 152 | torch.save(full_masks, get_path(args, 'INIT_MASKS_PATH')) 153 | 154 | return model 155 | 156 | 157 | def iter_pruning(model, args, config, data_content, tag='default', method=None, sparsity_ratio_mul=0): 158 | training_params = config.get_iter_training_params(args.arch, args.data) 159 | pruning_params = config.get_iter_pruning_params(args.arch, args.data) 160 | init_pruning_params = config.get_init_pruning_params(args.arch, args.data) 161 | 162 | pruning_params['beta'] = 1 163 | cur_pruning_params = deepcopy(pruning_params) 164 | 165 | if sparsity_ratio_mul == 0: 166 | if cur_pruning_params['attn']['sparse_ratio'] < 0: 167 | cur_pruning_params['attn']['sparse_ratio'] *= -1 168 | if cur_pruning_params['ffn']['sparse_ratio'] < 0: 169 | cur_pruning_params['ffn']['sparse_ratio'] *= -1 
170 | else: 171 | if cur_pruning_params['attn']['sparse_ratio'] < 0: 172 | cur_pruning_params['attn']['sparse_ratio'] *= -1 173 | else: 174 | cur_pruning_params['attn']['sparse_ratio'] += \ 175 | (init_pruning_params['attn']['sparse_ratio'] - 176 | pruning_params['attn']['sparse_ratio']) * sparsity_ratio_mul 177 | if cur_pruning_params['ffn']['sparse_ratio'] < 0: 178 | cur_pruning_params['ffn']['sparse_ratio'] *= -1 179 | else: 180 | cur_pruning_params['ffn']['sparse_ratio'] += \ 181 | (init_pruning_params['ffn']['sparse_ratio'] - 182 | pruning_params['ffn']['sparse_ratio']) * sparsity_ratio_mul 183 | 184 | config_list = get_prune_config_for_attn(args, model, cur_pruning_params['attn']) \ 185 | + get_prune_config_for_ffn(args, model, cur_pruning_params['ffn']) 186 | 187 | for c in config_list: 188 | if 'dependency_group_id' in c.keys(): 189 | del c['dependency_group_id'] 190 | if not cur_pruning_params['global_flag']: 191 | del c['global_group_id'] 192 | 193 | if method is None: 194 | method = 'taylor' 195 | 196 | if model.device != args.comp_device: 197 | model = model.to(args.comp_device) 198 | 199 | model, masks, pruner = prune(model, args, data_content, training_params, cur_pruning_params, 200 | config_list, method, tag=tag, device=args.comp_device, speedup_flag=False) 201 | 202 | torch.save(masks, get_path(args, 'ITER_MASKS_PATH')) 203 | 204 | return model 205 | 206 | 207 | # @profile 208 | def prune(model, args, data_content, training_params, pruning_params, config_list, pruner_method, 209 | tag='default', device='cpu', speedup_flag=True): 210 | training_params = deepcopy(training_params) 211 | training_params['learning_rate'] = 0 212 | trainer = prepare_traced_trainer(model, args, data_content, training_params, for_train_flag=False, 213 | for_eval_flag=False, tag=tag, device=device, send_tag='train') 214 | evaluator = TransformersEvaluator(trainer) 215 | 216 | pruner_init_kwargs = {} 217 | pruner_compress_kwargs = {} 218 | if pruner_method == 'movement': 219 | pruner_init_kwargs = {'warmup_step': pruning_params['warmup_step'], 220 | 'cooldown_begin_step': pruning_params['cooldown_begin_step']} 221 | pruner_compress_kwargs = {'max_steps': pruning_params['cooldown_begin_step'], 222 | 'max_epochs': training_params.get('num_train_epochs', 3)} 223 | elif pruner_method == 'taylor': 224 | pruner_init_kwargs = {'training_steps': pruning_params['training_steps'], 225 | 'beta': pruning_params['beta'], 226 | 'global_flag': pruning_params['global_flag']} 227 | 228 | with LogLevel(logging.ERROR): 229 | pruner = pruner_dispatcher[pruner_method](model, config_list, evaluator, **pruner_init_kwargs) 230 | pruner.compress(**pruner_compress_kwargs) 231 | pruner.unwrap_model() 232 | 233 | masks = pruner.get_masks() 234 | 235 | if speedup_flag: 236 | pruned_model = speedup(args, model.to('cpu'), masks) 237 | else: 238 | pruned_model = None 239 | 240 | return pruned_model, masks, pruner 241 | -------------------------------------------------------------------------------- /compression/speedup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.pytorch_utils import prune_linear_layer 3 | 4 | from paths import get_path 5 | 6 | 7 | def speedup(args, model, masks): 8 | if 'bert' in args.arch: 9 | return speedup_bert(args, model, masks) 10 | elif 'vit' in args.arch: 11 | return speedup_vit(args, model, masks) 12 | elif 'm2f' in args.arch: 13 | return speedup_swin_m2f(args, model, masks) 14 | else: 15 | raise NotImplementedError 16 | 17 | 
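# The architecture-specific helpers below turn the pruner's binary masks into a physically
# smaller model: attention heads whose block of the query-weight mask is entirely zero are
# removed with prune_heads(); FFN linears are sliced with prune_linear_layer() (rows for the
# intermediate projection, columns for the output projection); and for Swin/Mask2Former the
# relative_position_bias_table columns of the removed heads are dropped as well.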
18 | def speedup_bert(args, model, masks): 19 | if isinstance(model, str): 20 | model = torch.load(model, map_location='cpu') 21 | 22 | if isinstance(masks, str): 23 | masks = torch.load(masks, map_location='cpu') 24 | 25 | def _prune_head_idxs(mask, num_heads): 26 | head_mask = (mask.reshape([num_heads, -1]).sum(-1) == 0.) 27 | return torch.arange(len(head_mask))[head_mask].long().tolist() 28 | 29 | # prune heads 30 | attention_modules = dict( 31 | [(name, module) for name, module in model.named_modules() if name.split('.')[-1] == 'attention']) 32 | for name, att_module in attention_modules.items(): 33 | mask = masks[name + '.self.query']['weight'].to('cpu') 34 | num_heads = att_module.self.num_attention_heads 35 | prune_head_idxs = _prune_head_idxs(mask, num_heads) 36 | att_module.prune_heads(prune_head_idxs) 37 | att_module.pruned_heads = set() 38 | 39 | # prune ffns 40 | module_names = [name for name, _ in model.named_modules()] 41 | for name in module_names: 42 | if name not in masks.keys(): 43 | continue 44 | if 'attention' not in name: 45 | module = model.get_submodule(name) 46 | if 'output' in name: 47 | module = prune_linear_layer(module, masks[name]['weight'].sum(dim=0).nonzero()[:, 0], dim=1) 48 | else: 49 | module = prune_linear_layer(module, masks[name]['weight'].sum(dim=1).nonzero()[:, 0]) 50 | setattr(model.get_submodule('.'.join(name.split('.')[:-1])), name.split('.')[-1], module) 51 | 52 | return model 53 | 54 | 55 | def speedup_vit(args, model, masks): 56 | if isinstance(model, str): 57 | model = torch.load(model, map_location='cpu') 58 | 59 | if isinstance(masks, str): 60 | masks = torch.load(masks, map_location='cpu') 61 | 62 | def _prune_head_idxs(mask, num_heads): 63 | head_mask = (mask.reshape([num_heads, -1]).sum(-1) == 0.) 64 | return torch.arange(len(head_mask))[head_mask].long().tolist() 65 | 66 | # prune heads 67 | attention_modules = dict( 68 | [(name, module) for name, module in model.named_modules() if name.split('.')[-1] == 'attention' and name.split('.')[-2] != 'attention']) 69 | for name, att_module in attention_modules.items(): 70 | mask = masks[name + '.attention.query']['weight'].to('cpu') 71 | num_heads = att_module.attention.num_attention_heads 72 | prune_head_idxs = _prune_head_idxs(mask, num_heads) 73 | att_module.prune_heads(prune_head_idxs) 74 | att_module.pruned_heads = set() 75 | 76 | # prune ffns 77 | module_names = [name for name, _ in model.named_modules()] 78 | for name in module_names: 79 | if name not in masks.keys(): 80 | continue 81 | if 'attention' not in name: 82 | module = model.get_submodule(name) 83 | if 'output' in name: 84 | module = prune_linear_layer(module, masks[name]['weight'].sum(dim=0).nonzero()[:, 0], dim=1) 85 | else: 86 | module = prune_linear_layer(module, masks[name]['weight'].sum(dim=1).nonzero()[:, 0]) 87 | setattr(model.get_submodule('.'.join(name.split('.')[:-1])), name.split('.')[-1], module) 88 | 89 | return model 90 | 91 | 92 | def speedup_swin_m2f(args, model, masks): 93 | if isinstance(model, str): 94 | model = torch.load(model, map_location='cpu') 95 | 96 | if isinstance(masks, str): 97 | masks = torch.load(masks, map_location='cpu') 98 | 99 | def _prune_head_idxs(mask, num_heads): 100 | head_mask = (mask.reshape([num_heads, -1]).sum(-1) == 0.) 
101 | return torch.arange(len(head_mask))[head_mask].long().tolist() 102 | 103 | # prune heads 104 | # attention_modules = dict( 105 | # [(name, module) for name, module in model.named_modules() if name.split('.')[-1] == 'attention']) 106 | 107 | attention_modules = dict( 108 | [(name, module) for name, module in model.named_modules() if name.split('.')[-1] == 'attention' and name.split('.')[-2] != 'attention']) 109 | for name, att_module in attention_modules.items(): 110 | mask = masks[name + '.self.query']['weight'].to('cpu') 111 | num_heads = att_module.self.num_attention_heads 112 | prune_head_idxs = _prune_head_idxs(mask, num_heads) 113 | att_module.prune_heads(prune_head_idxs) 114 | rem_heads = [i for i in range(att_module.self.relative_position_bias_table.shape[-1]) if i not in prune_head_idxs] 115 | att_module.self.relative_position_bias_table = torch.nn.Parameter(att_module.self.relative_position_bias_table[:, rem_heads]) 116 | att_module.pruned_heads = set() 117 | 118 | mask = torch.zeros(num_heads, dtype=bool) 119 | mask[rem_heads] = True 120 | masks[name +'.self'] = {'relative_position_bias_table': mask} 121 | 122 | # prune ffns 123 | module_names = [name for name, _ in model.named_modules()] 124 | for name in module_names: 125 | if name not in masks.keys(): 126 | continue 127 | if 'attention' not in name: 128 | module = model.get_submodule(name) 129 | if 'output' in name: 130 | module = prune_linear_layer(module, masks[name]['weight'].sum(dim=0).nonzero()[:, 0], dim=1) 131 | else: 132 | module = prune_linear_layer(module, masks[name]['weight'].sum(dim=1).nonzero()[:, 0]) 133 | setattr(model.get_submodule('.'.join(name.split('.')[:-1])), name.split('.')[-1], module) 134 | 135 | torch.save(masks, get_path(args, 'INIT_MASKS_PATH')) 136 | 137 | return model -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | def __init__(self, args): 3 | self.core_res = args.core_res 4 | self.init_sparse_ratio = args.init_sparse_ratio 5 | self.iter_sparse_ratio = args.iter_sparse_ratio 6 | self.num_pruning_iters = args.num_pruning_iters 7 | 8 | if any([k in args.arch for k in ['bert-base-uncased', 'vit-base', 'm2f']]): 9 | self.hidden_dim = 768 10 | elif any([k in args.arch for k in ['bert-large-uncased', 'vit-large']]): 11 | self.hidden_dim = 1024 12 | 13 | global_flag = True 14 | 15 | self.training_params = { 16 | 'model_default': { 17 | 'data_default': { 18 | 'init': { 19 | 'num_train_epochs': 2, 20 | 'learning_rate': 1e-5 21 | }, 22 | 'iter': { 23 | 'num_train_epochs': 2, 24 | 'learning_rate': 1e-5 25 | } 26 | } 27 | }, 28 | 'vit-base': { 29 | 'data_default': { 30 | 'init': { 31 | 'num_train_epochs': 2, 32 | 'learning_rate': 1e-4 33 | }, 34 | 'iter': { 35 | 'num_train_epochs': 2, 36 | 'learning_rate': 1e-4 37 | } 38 | } 39 | }, 40 | 'vit-large': { 41 | 'data_default': { 42 | 'init': { 43 | 'num_train_epochs': 2, 44 | 'learning_rate': 1e-4 45 | }, 46 | 'iter': { 47 | 'num_train_epochs': 2, 48 | 'learning_rate': 1e-4 49 | } 50 | } 51 | }, 52 | 'm2f': { 53 | 'data_default': { 54 | 'init': { 55 | 'num_train_epochs': 2, 56 | 'learning_rate': 1e-4 57 | }, 58 | 'iter': { 59 | 'num_train_epochs': 2, 60 | 'learning_rate': 1e-4 61 | } 62 | }, 63 | 'cityscapes': { 64 | 'init': { 65 | 'num_train_epochs': 2, 66 | 'learning_rate': 1e-4 67 | }, 68 | 'iter': { 69 | 'num_train_epochs': 2, 70 | 'learning_rate': 1e-4 71 | } 72 | }, 73 | 'kitti': { 74 | 'init': { 75 | 
'num_train_epochs': 2, 76 | 'learning_rate': 1e-4 77 | }, 78 | 'iter': { 79 | 'num_train_epochs': 2, 80 | 'learning_rate': 1e-4 81 | } 82 | }, 83 | }, 84 | 'bert-base-uncased': { 85 | 'data_default': { 86 | 'init': { 87 | 'num_train_epochs': 2, 88 | 'learning_rate': 2e-5 89 | }, 90 | 'iter': { 91 | 'num_train_epochs': 2, 92 | 'learning_rate': 2e-5 93 | } 94 | } 95 | }, 96 | 'bert-large-uncased': { 97 | 'data_default': { 98 | 'init': { 99 | 'num_train_epochs': 2, 100 | 'learning_rate': 2e-5 101 | }, 102 | 'iter': { 103 | 'num_train_epochs': 2, 104 | 'learning_rate': 2e-5 105 | } 106 | } 107 | } 108 | } 109 | self.pruning_params = { 110 | 'model_default': { 111 | 'data_default': { 112 | 'init': { 113 | 'training_steps': 10, # taylor 114 | 'global_flag': global_flag, # taylor 115 | 'num_iters': self.num_pruning_iters, # perform taylor in x iters 116 | 'attn': {'sparse_ratio': self.init_sparse_ratio, 117 | 'max_sparse_ratio': 0.85, 118 | 'granularity': [self.core_res, self.hidden_dim]}, 119 | 'ffn': {'sparse_ratio': self.init_sparse_ratio, 120 | 'max_sparse_ratio': 0.85, 121 | 'granularity': [1, self.hidden_dim]} 122 | }, 123 | 'iter': { 124 | 'training_steps': 10, # taylor 125 | 'global_flag': global_flag, # taylor 126 | 'num_iters': 1, # perform taylor in x iters 127 | 'attn': {'sparse_ratio': self.iter_sparse_ratio, 128 | 'granularity': [1, self.hidden_dim]}, 129 | 'ffn': {'sparse_ratio': self.iter_sparse_ratio, 130 | 'granularity': [1, self.hidden_dim]} 131 | } 132 | } 133 | } 134 | } 135 | 136 | def get_init_training_params(self, model_name, data_name): 137 | default_params = self.training_params.get(model_name, self.training_params['model_default'])['data_default']['init'] 138 | data_params = self.training_params.get(model_name, self.training_params['model_default']).get(data_name, {'init': {}})['init'] 139 | return default_params | data_params 140 | 141 | def get_iter_training_params(self, model_name, data_name): 142 | default_params = self.training_params.get(model_name, self.training_params['model_default'])['data_default']['iter'] 143 | data_params = self.training_params.get(model_name, self.training_params['model_default']).get(data_name, {'iter': {}})['iter'] 144 | return default_params | data_params 145 | 146 | def get_init_pruning_params(self, model_name, data_name): 147 | default_params = self.pruning_params.get(model_name, self.pruning_params['model_default'])['data_default']['init'] 148 | data_params = self.pruning_params.get(model_name, self.pruning_params['model_default']).get(data_name, {'init': {}})['init'] 149 | return default_params | data_params 150 | 151 | def get_iter_pruning_params(self, model_name, data_name): 152 | default_params = self.pruning_params.get(model_name, self.pruning_params['model_default'])['data_default']['iter'] 153 | data_params = self.pruning_params.get(model_name, self.pruning_params['model_default']).get(data_name, {'iter': {}})['iter'] 154 | return default_params | data_params 155 | -------------------------------------------------------------------------------- /config_helpers.py: -------------------------------------------------------------------------------- 1 | from transformers.models.bert.modeling_bert import BertLayer 2 | from transformers.models.vit.modeling_vit import ViTLayer 3 | from transformers.models.swin.modeling_swin import SwinLayer 4 | 5 | 6 | def get_prune_config_for_attn(args, model, prune_params_dict): 7 | sparse_ratio = prune_params_dict['sparse_ratio'] 8 | max_sparse_ratio = prune_params_dict.get('max_sparse_ratio', 1) 9 
| granularity = prune_params_dict['granularity'] 10 | config_list = [] 11 | 12 | if 'bert' in str(model.__class__): 13 | attention_qkv_str = '.attention.self*' 14 | attention_output_str = '.attention.output.dense' 15 | dep_id = -1 16 | elif 'vit' in str(model.__class__): 17 | attention_qkv_str = '.attention.attention*' 18 | attention_output_str = '.attention.output.dense' 19 | dep_id = -1 20 | elif 'mask2former' in str(model.__class__): 21 | attention_qkv_str = '.attention.self*' 22 | attention_output_str = '.attention.output.dense' 23 | dep_id = -3 24 | else: 25 | raise NotImplementedError 26 | 27 | for name, module in model.named_modules(): 28 | 29 | if 'encoder' in name: 30 | inc = 0 31 | else: 32 | inc = 100 33 | 34 | if isinstance(module, SwinLayer): 35 | if 'm2f' in args.arch: 36 | if '.0.' in name: 37 | granularity_ = [32, 128] 38 | elif '.1.' in name: 39 | granularity_ = [32, 256] 40 | elif '.2.' in name: 41 | granularity_ = [32, 512] 42 | else: 43 | granularity_ = [32, 1024] 44 | else: 45 | if '.0.' in name: 46 | granularity_ = [32, 192] 47 | elif '.1.' in name: 48 | granularity_ = [32, 384] 49 | elif '.2.' in name: 50 | granularity_ = [32, 768] 51 | else: 52 | granularity_ = [32, 1536] 53 | else: 54 | granularity_ = granularity 55 | 56 | if isinstance(module, BertLayer) or isinstance(module, ViTLayer) or isinstance(module, SwinLayer): 57 | config_list.append({'op_types': ['Linear'], 58 | 'op_names_re': [f'{name}{attention_qkv_str}'], 59 | 'dependency_group_id': int(name.split('.')[dep_id]) + inc, 60 | 'sparse_ratio': sparse_ratio, 61 | 'max_sparse_ratio': max_sparse_ratio, 62 | 'granularity': granularity_, 63 | 'global_group_id': inc 64 | }) 65 | config_list.append({'op_names': [f'{name}{attention_output_str}'], 66 | 'dependency_group_id': int(name.split('.')[dep_id]) + inc, 67 | 'sparse_ratio': sparse_ratio, 68 | 'max_sparse_ratio': max_sparse_ratio, 69 | 'granularity': list(reversed(granularity_)), 70 | 'global_group_id': inc 71 | }) 72 | return config_list 73 | 74 | 75 | def get_prune_config_for_ffn(args, model, prune_params_dict): 76 | sparse_ratio = prune_params_dict['sparse_ratio'] 77 | max_sparse_ratio = prune_params_dict.get('max_sparse_ratio', 1) 78 | granularity = prune_params_dict['granularity'] 79 | config_list = [] 80 | 81 | if 'bert' in str(model.__class__): 82 | intermediate_str = '.intermediate.dense' 83 | output_str = '.output.dense' 84 | dep_id = -1 85 | elif 'vit' in str(model.__class__): 86 | intermediate_str = '.intermediate.dense' 87 | output_str = '.output.dense' 88 | dep_id = -1 89 | elif 'mask2former' in str(model.__class__): 90 | intermediate_str = '.intermediate.dense' 91 | output_str = '.output.dense' 92 | dep_id = -3 93 | else: 94 | raise NotImplementedError 95 | 96 | for name, module in model.named_modules(): 97 | 98 | if 'encoder' in name: 99 | inc = 200 100 | else: 101 | inc = 300 102 | 103 | if isinstance(module, SwinLayer): 104 | if 'm2f' in args.arch: 105 | if '.0.' in name: 106 | granularity_ = [32, 128] 107 | elif '.1.' in name: 108 | granularity_ = [32, 256] 109 | elif '.2.' in name: 110 | granularity_ = [32, 512] 111 | else: 112 | granularity_ = [32, 1024] 113 | else: 114 | if '.0.' in name: 115 | granularity_ = [1, 192] 116 | elif '.1.' in name: 117 | granularity_ = [1, 384] 118 | elif '.2.' 
in name: 119 | granularity_ = [1, 768] 120 | else: 121 | granularity_ = [1, 1536] 122 | else: 123 | granularity_ = granularity 124 | 125 | if isinstance(module, BertLayer) or isinstance(module, ViTLayer) or isinstance(module, SwinLayer): 126 | config_list.append({'op_names': [f'{name}{intermediate_str}'], 127 | 'dependency_group_id': int(name.split('.')[dep_id]) + inc, 128 | 'sparse_ratio': sparse_ratio, 129 | 'max_sparse_ratio': max_sparse_ratio, 130 | 'granularity': granularity_, 131 | 'global_group_id': inc 132 | }) 133 | config_list.append({'op_names': [f'{name}{output_str}'], 134 | 'dependency_group_id': int(name.split('.')[dep_id]) + inc, 135 | 'sparse_ratio': sparse_ratio, 136 | 'max_sparse_ratio': max_sparse_ratio, 137 | 'granularity': list(reversed(granularity_)), 138 | 'global_group_id': inc 139 | }) 140 | return config_list 141 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from copy import copy 5 | 6 | import cv2 7 | from datasets import load_dataset 8 | from torch.utils import data 9 | from torchvision.transforms import (CenterCrop, 10 | Compose, 11 | Normalize, 12 | RandomHorizontalFlip, 13 | RandomResizedCrop, 14 | Resize, 15 | ToTensor) 16 | 17 | skip_exec = True 18 | 19 | 20 | def prepare_datasets(model_name: str, task_name: str, data_name: str, tokenizer, cache_dir: str, eval_key: str = 'val'): 21 | if task_name == 'glue': 22 | return prepare_datasets_glue(model_name, data_name, tokenizer, cache_dir, eval_key) 23 | elif task_name == 'img_class': 24 | if 'cifar' in data_name: 25 | return prepare_datasets_cifar(model_name, data_name, tokenizer, cache_dir, eval_key) 26 | elif data_name == 'tinyimagenet': 27 | return prepare_datasets_tinyimagenet(model_name, data_name, tokenizer, cache_dir, eval_key) 28 | else: 29 | raise NotImplementedError 30 | elif task_name == 'img_seg': 31 | if 'cityscapes' in data_name: 32 | return prepare_datasets_cityscapes(model_name, data_name, tokenizer, cache_dir, eval_key) 33 | elif 'kitti' in data_name: 34 | return prepare_datasets_kitti(model_name, data_name, tokenizer, cache_dir, eval_key) 35 | else: 36 | raise NotImplementedError 37 | 38 | 39 | def prepare_datasets_glue(model_name: str, data_name: str, tokenizer, cache_dir: str, eval_key: str = 'val'): 40 | task_to_keys = { 41 | 'cola': ('sentence', None), 42 | 'mnli': ('premise', 'hypothesis'), 43 | 'mrpc': ('sentence1', 'sentence2'), 44 | 'qnli': ('question', 'sentence'), 45 | 'qqp': ('question1', 'question2'), 46 | 'rte': ('sentence1', 'sentence2'), 47 | 'sst2': ('sentence', None), 48 | 'stsb': ('sentence1', 'sentence2'), 49 | 'wnli': ('sentence1', 'sentence2'), 50 | } 51 | sentence1_key, sentence2_key = task_to_keys[data_name] 52 | 53 | # used to preprocess the raw data 54 | def preprocess_function(examples): 55 | args = ( 56 | (examples[sentence1_key],) if sentence2_key is None else ( 57 | examples[sentence1_key], examples[sentence2_key]) 58 | ) 59 | result = tokenizer(*args, padding=False, max_length=128, truncation=True) 60 | 61 | if 'label' in examples: 62 | result['labels'] = examples['label'] 63 | return result 64 | 65 | raw_datasets = load_dataset('glue', data_name, cache_dir=cache_dir) 66 | 67 | if eval_key == 'val': 68 | for key in list(raw_datasets.keys()): 69 | if 'test' in key: 70 | raw_datasets.pop(key) 71 | 72 | column_names = raw_datasets['train'].column_names 73 | processed_datasets 
= raw_datasets.map(preprocess_function, batched=True, remove_columns=column_names) 74 | 75 | if data_name == 'mnli': 76 | if eval_key == 'test': 77 | validation_datasets = { 78 | 'test_matched': processed_datasets['validation_matched'], 79 | 'test_mismatched': processed_datasets['validation_mismatched'] 80 | } 81 | else: 82 | validation_datasets = { 83 | 'validation_matched': processed_datasets['validation_matched'], 84 | 'validation_mismatched': processed_datasets['validation_mismatched'] 85 | } 86 | else: 87 | if eval_key == 'test': 88 | validation_datasets = { 89 | 'test': processed_datasets['test'] 90 | } 91 | else: 92 | validation_datasets = { 93 | 'validation': processed_datasets['validation'] 94 | } 95 | 96 | return processed_datasets['train'], validation_datasets, None 97 | 98 | 99 | def prepare_datasets_cifar(model_name: str, data_name: str, tokenizer, cache_dir: str, eval_key: str = 'val'): 100 | train_ds, test_ds = load_dataset(data_name, cache_dir=cache_dir, split=['train', 'test']) 101 | # split up training into training + validation 102 | splits = train_ds.train_test_split(test_size=0.1) 103 | train_ds = splits['train'] 104 | val_ds = splits['test'] 105 | 106 | image_mean, image_std = tokenizer.image_mean, tokenizer.image_std 107 | size = 224 108 | 109 | normalize = Normalize(mean=image_mean, std=image_std) 110 | _train_transforms = Compose( 111 | [ 112 | RandomResizedCrop(size), 113 | RandomHorizontalFlip(), 114 | ToTensor(), 115 | normalize, 116 | ] 117 | ) 118 | 119 | _val_transforms = Compose( 120 | [ 121 | Resize(size), 122 | CenterCrop(size), 123 | ToTensor(), 124 | normalize, 125 | ] 126 | ) 127 | 128 | def train_transform(examples): 129 | examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['img']] 130 | return examples 131 | 132 | def val_transform(examples): 133 | examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['img']] 134 | return examples 135 | 136 | train_ds.set_transform(train_transform) 137 | val_ds.set_transform(val_transform) 138 | test_ds.set_transform(val_transform) 139 | 140 | return train_ds, val_ds, test_ds 141 | 142 | 143 | def prepare_datasets_tinyimagenet(model_name: str, data_name: str, tokenizer, cache_dir: str, eval_key: str = 'val'): 144 | train_ds, test_ds = load_dataset('Maysee/tiny-imagenet', cache_dir=cache_dir, split=['train', 'valid']) 145 | # split up training into training + validation 146 | splits = train_ds.train_test_split(test_size=0.1) 147 | train_ds = splits['train'] 148 | val_ds = splits['test'] 149 | 150 | image_mean, image_std = tokenizer.image_mean, tokenizer.image_std 151 | size = 224 152 | 153 | normalize = Normalize(mean=image_mean, std=image_std) 154 | _train_transforms = Compose( 155 | [ 156 | RandomResizedCrop(size), 157 | RandomHorizontalFlip(), 158 | ToTensor(), 159 | normalize, 160 | ] 161 | ) 162 | 163 | _val_transforms = Compose( 164 | [ 165 | Resize(size), 166 | CenterCrop(size), 167 | ToTensor(), 168 | normalize, 169 | ] 170 | ) 171 | 172 | def train_transform(examples): 173 | examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['image']] 174 | return examples 175 | 176 | def val_transform(examples): 177 | examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['image']] 178 | return examples 179 | 180 | train_ds.set_transform(train_transform) 181 | val_ds.set_transform(val_transform) 182 | test_ds.set_transform(val_transform) 183 | 184 | return train_ds, val_ds, test_ds 
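# Cityscapes and KITTI samples are read from .lst files under the data root via the Cityscapes
# dataset class below; the image processor passed in as `tokenizer` is copied and its size is
# set to dataset-specific resolutions for training vs. evaluation.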
185 | 186 | 187 | def prepare_datasets_cityscapes(model_name: str, data_name: str, tokenizer, cache_dir: str, eval_key: str = 'val'): 188 | num_classes = 19 189 | 190 | train_tokenizer = copy(tokenizer) 191 | train_mul = 1 192 | train_tokenizer.size = {'height': int(512 * train_mul), 'width': int(1024 * train_mul)} 193 | eval_tokenizer = copy(tokenizer) 194 | eval_mul = 2 195 | eval_tokenizer.size = {'height': int(512 * eval_mul), 'width': int(1024 * eval_mul)} 196 | 197 | train_ds = Cityscapes( 198 | root=cache_dir, 199 | list_path='/list/cityscapes/train.lst', 200 | tokenizer=train_tokenizer, 201 | num_classes=num_classes, 202 | ignore_label=255) 203 | 204 | val_ds = Cityscapes( 205 | root=cache_dir, 206 | list_path='/list/cityscapes/val.lst', 207 | tokenizer=eval_tokenizer, 208 | num_classes=num_classes, 209 | ignore_label=255) 210 | 211 | test_ds = Cityscapes( 212 | root=cache_dir, 213 | list_path='/list/cityscapes/test.lst', 214 | tokenizer=eval_tokenizer, 215 | num_classes=num_classes, 216 | ignore_label=255) 217 | 218 | return train_ds, val_ds, test_ds 219 | 220 | 221 | def prepare_datasets_kitti(model_name: str, data_name: str, tokenizer, cache_dir: str, eval_key: str = 'val'): 222 | num_classes = 19 223 | 224 | train_tokenizer = copy(tokenizer) 225 | train_mul = 1 226 | train_tokenizer.size = {'height': int(375 * train_mul), 'width': int(1242 * train_mul)} 227 | eval_tokenizer = copy(tokenizer) 228 | eval_mul = 1 229 | eval_tokenizer.size = {'height': int(375 * eval_mul), 'width': int(1242 * eval_mul)} 230 | 231 | train_ds = Cityscapes( 232 | root=cache_dir, 233 | list_path='/list/kitti/train.lst', 234 | tokenizer=train_tokenizer, 235 | num_classes=num_classes, 236 | ignore_label=255) 237 | 238 | val_ds = Cityscapes( 239 | root=cache_dir, 240 | list_path='/list/kitti/val.lst', 241 | tokenizer=eval_tokenizer, 242 | num_classes=num_classes, 243 | ignore_label=255) 244 | 245 | return train_ds, val_ds, val_ds 246 | 247 | 248 | class Cityscapes(data.Dataset): 249 | def __init__(self, 250 | root, 251 | list_path, 252 | tokenizer, 253 | num_classes=19, 254 | ignore_label=255): 255 | 256 | super(Cityscapes, self).__init__() 257 | 258 | self.tokenizer = tokenizer 259 | self.root = root 260 | self.list_path = list_path 261 | self.num_classes = num_classes 262 | self.img_list = [line.strip().split() for line in open(root + list_path)] 263 | self.files = self.read_files() 264 | 265 | self.label_mapping = {-1: ignore_label, 0: ignore_label, 266 | 1: ignore_label, 2: ignore_label, 267 | 3: ignore_label, 4: ignore_label, 268 | 5: ignore_label, 6: ignore_label, 269 | 7: 0, 8: 1, 9: ignore_label, 270 | 10: ignore_label, 11: 2, 12: 3, 271 | 13: 4, 14: ignore_label, 15: ignore_label, 272 | 16: ignore_label, 17: 5, 18: ignore_label, 273 | 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 274 | 25: 12, 26: 13, 27: 14, 28: 15, 275 | 29: ignore_label, 30: ignore_label, 276 | 31: 16, 32: 17, 33: 18} 277 | 278 | self.target_mode = False 279 | self.image_mode = False 280 | 281 | def __len__(self): 282 | return len(self.files) 283 | 284 | def read_files(self): 285 | files = [] 286 | if 'test' in self.list_path: 287 | for item in self.img_list: 288 | image_path = item 289 | name = os.path.splitext(os.path.basename(image_path[0]))[0] 290 | files.append({ 291 | "img": image_path[0], 292 | "name": name, 293 | }) 294 | else: 295 | for item in self.img_list: 296 | image_path, label_path = item 297 | name = os.path.splitext(os.path.basename(label_path))[0] 298 | files.append({ 299 | "img": image_path, 300 | "label": 
label_path, 301 | "name": name, 302 | "weight": 1 303 | }) 304 | return files 305 | 306 | def convert_label(self, label, inverse=False): 307 | temp = label.copy() 308 | if inverse: 309 | for v, k in self.label_mapping.items(): 310 | label[temp == k] = v 311 | else: 312 | for k, v in self.label_mapping.items(): 313 | label[temp == k] = v 314 | return label 315 | 316 | def __getitem__(self, index): 317 | item = self.files[index] 318 | if 'cityscapes' in self.list_path: 319 | folder = 'cityscapes' 320 | else: 321 | folder = 'kitti' 322 | 323 | image = cv2.imread(os.path.join(self.root, folder, item["img"]), cv2.IMREAD_COLOR) 324 | 325 | if 'test' in self.list_path: 326 | return self.tokenizer(image) 327 | 328 | label = cv2.imread(os.path.join(self.root, folder, item["label"]), cv2.IMREAD_GRAYSCALE) 329 | label = self.convert_label(label) 330 | return self.tokenizer(image, label) 331 | -------------------------------------------------------------------------------- /general_utils.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import logging 3 | 4 | 5 | class LogLevel: 6 | 7 | def __init__(self, level): 8 | self.loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 9 | self.old_levels = [logger.level for logger in self.loggers] 10 | self.level = level 11 | 12 | def __enter__(self): 13 | for logger in self.loggers: 14 | logger.setLevel(self.level) 15 | 16 | def __exit__(self, exc_type, exc_val, exc_tb): 17 | for i, logger in enumerate(self.loggers): 18 | logger.setLevel(self.old_levels[i]) 19 | 20 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import math 4 | import shutil 5 | import subprocess 6 | from copy import deepcopy 7 | 8 | import torch.optim 9 | from torch.utils.data import ConcatDataset 10 | from transformers import BertTokenizerFast, DataCollatorWithPadding, \ 11 | ViTImageProcessor, Mask2FormerImageProcessor 12 | 13 | import compression.pruner as compress_p 14 | from args import arg_parser, modify_args 15 | from config import * 16 | from data_utils import prepare_datasets 17 | from trainer_utils import * 18 | from utils import get_model_param_keys 19 | 20 | os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' 21 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 22 | 23 | args = arg_parser.parse_args() 24 | args = modify_args(args) 25 | torch.manual_seed(args.seed) 26 | 27 | tokenizer_dispatcher = { 28 | 'bert-base-uncased': BertTokenizerFast, 29 | 'bert-large-uncased': BertTokenizerFast, 30 | 'vit-base': ViTImageProcessor, 31 | 'vit-large': ViTImageProcessor, 32 | 'm2f': Mask2FormerImageProcessor 33 | } 34 | 35 | 36 | def finetune(model, args, data_content, training_params, model_path=None, for_eval_flag=True, tag='default'): 37 | trainer = prepare_traced_trainer(model, args, data_content, training_params, for_eval_flag=for_eval_flag, tag=tag) 38 | 39 | max_steps = math.ceil(training_params['num_train_epochs'] * len(data_content['train'])) 40 | prepare_masked_trainer(args, trainer, max_steps) 41 | 42 | if os.path.exists(get_path(args, 'OPT_STATE_PATH')): 43 | opt_states = torch.load(get_path(args, 'OPT_STATE_PATH')) 44 | init_masks = torch.load(get_path(args, 'INIT_MASKS_PATH')) 45 | keys = get_model_param_keys(trainer.model) 46 | keys = keys[0] + keys[1] 47 | opt_states_to_load = trainer.optimizer.state_dict() 48 | 49 | for i in 
range(len(keys)): 50 | 51 | if 'embeddings.mask_token' in keys[i]: 52 | continue 53 | 54 | key_ = '.'.join(keys[i].split('.')[:-1]) 55 | _key = keys[i].split('.')[-1] 56 | 57 | try: 58 | init_mask = init_masks[key_][_key].to('cpu') 59 | except: 60 | # print(f'Could not find init mask for {key}') 61 | init_mask = None 62 | 63 | if init_mask is not None: 64 | if _key == 'weight': 65 | if ('attention' in key_ and ('query' in key_ or 'key' in key_ or 'value' in key_)) or \ 66 | ('intermediate' in key_): 67 | init_mask = init_mask.sum(dim=1).nonzero()[:, 0] 68 | opt_states_to_load['state'][i] = { 69 | 'step': opt_states[i]['step'], 70 | 'exp_avg': opt_states[i]['exp_avg'][init_mask].bfloat16(), 71 | 'exp_avg_sq': opt_states[i]['exp_avg_sq'][init_mask].bfloat16()} 72 | elif 'output' in key_: 73 | init_mask = init_mask.sum(dim=0).nonzero()[:, 0] 74 | opt_states_to_load['state'][i] = { 75 | 'step': opt_states[i]['step'], 76 | 'exp_avg': opt_states[i]['exp_avg'][:, init_mask].bfloat16(), 77 | 'exp_avg_sq': opt_states[i]['exp_avg_sq'][:, init_mask].bfloat16()} 78 | else: 79 | raise NotImplementedError 80 | elif _key == 'relative_position_bias_table': 81 | opt_states_to_load['state'][i] = { 82 | 'step': opt_states[i]['step'], 83 | 'exp_avg': opt_states[i]['exp_avg'][:, init_mask].bfloat16(), 84 | 'exp_avg_sq': opt_states[i]['exp_avg_sq'][:, init_mask].bfloat16()} 85 | else: 86 | if ('attention' in key_ and ('query' in key_ or 'key' in key_ or 'value' in key_)) or \ 87 | ('intermediate' in key_): 88 | init_mask = init_mask.nonzero()[:, 0] 89 | opt_states_to_load['state'][i] = { 90 | 'step': opt_states[i]['step'], 91 | 'exp_avg': opt_states[i]['exp_avg'][init_mask].bfloat16(), 92 | 'exp_avg_sq': opt_states[i]['exp_avg_sq'][init_mask].bfloat16()} 93 | elif 'output' in key_: 94 | opt_states_to_load['state'][i] = { 95 | 'step': opt_states[i]['step'], 96 | 'exp_avg': opt_states[i]['exp_avg'].bfloat16(), 97 | 'exp_avg_sq': opt_states[i]['exp_avg_sq'].bfloat16()} 98 | else: 99 | raise NotImplementedError 100 | 101 | trainer.optimizer.load_state_dict(opt_states_to_load) 102 | 103 | trainer.train() 104 | 105 | trainer_state = trainer.state 106 | trainer_state.opt_state = trainer.optimizer.state_dict()['state'] 107 | 108 | print('Completed finetuning') 109 | if model_path: 110 | torch.save(model, model_path) 111 | print(f'Saved to {model_path}') 112 | 113 | del trainer 114 | 115 | return model, trainer_state 116 | 117 | 118 | def prepare_data(args, eval_key): 119 | if 'vit' in args.arch: 120 | tokenizer = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", cache_dir='cache') 121 | elif 'm2f' in args.arch: 122 | tokenizer = Mask2FormerImageProcessor.from_pretrained("facebook/mask2former-swin-base-IN21k-cityscapes-semantic", cache_dir='cache') 123 | else: 124 | tokenizer = tokenizer_dispatcher[args.arch].from_pretrained(args.arch, cache_dir='cache') 125 | train_dataset, validation_datasets, test_dataset = prepare_datasets(args.arch, args.task, args.data, tokenizer, 126 | args.data_root, eval_key) 127 | 128 | dtype = torch.float32 129 | 130 | if args.task == 'img_class': 131 | def collate_fn_cls(examples): 132 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 133 | if args.data == 'cifar100': 134 | labels = torch.tensor(np.array([example["fine_label"] for example in examples])) 135 | else: 136 | labels = torch.tensor(np.array([example["label"] for example in examples])) 137 | 138 | return {"pixel_values": pixel_values.to(dtype), "labels": labels} 139 | 140 | 
data_collator = collate_fn_cls 141 | elif args.task == 'img_seg': 142 | def collate_fn_seg(examples): 143 | data = [] 144 | for key in examples[0].keys(): 145 | if key == 'class_labels': 146 | key_ = 'labels' 147 | else: 148 | key_ = key 149 | 150 | if 'labels' in key: 151 | val = [torch.tensor(np.stack(e[key], 0))[0] for e in examples] 152 | else: 153 | val = np.concatenate([np.stack(e[key], 0) for e in examples]) 154 | val = torch.tensor(val).to(dtype) 155 | data.append((key_, val)) 156 | return dict(data) 157 | 158 | data_collator = collate_fn_seg 159 | else: 160 | validation_datasets = ConcatDataset([d for d in validation_datasets.values()]) 161 | data_collator = DataCollatorWithPadding(tokenizer) 162 | 163 | return {'train': train_dataset, 'val': validation_datasets, 'test': test_dataset, 164 | 'collator': data_collator, 'tokenizer': tokenizer} 165 | 166 | 167 | # @profile 168 | def execute_main(args): 169 | model_name = args.arch 170 | 171 | if os.path.exists(get_path(args, 'MAIN_FOLDER_DIR', temp=False)): 172 | shutil.rmtree(get_path(args, 'MAIN_FOLDER_DIR', temp=False)) 173 | Path(get_path(args, 'TRAINER_FOLDER_DIR')).mkdir(exist_ok=True, parents=True) 174 | Path(get_path(args, 'MODEL_FOLDER_DIR')).mkdir(exist_ok=True, parents=True) 175 | 176 | with open(get_path(args, 'ARGS_PATH'), "w") as f: 177 | json.dump(args.__dict__, f, indent=2) 178 | 179 | config = Config(args) 180 | data_content = prepare_data(args, 'val') 181 | 182 | if args.task == 'img_class': 183 | if args.data == 'cifar100': 184 | id2label = {id: label for id, label in enumerate(data_content['train'].features['fine_label'].names)} 185 | else: 186 | id2label = {id: label for id, label in enumerate(data_content['train'].features['label'].names)} 187 | 188 | label2id = {label: id for id, label in id2label.items()} 189 | model = build_model(model_name, args.task, args.data, id2label=id2label, label2id=label2id) 190 | else: 191 | model = build_model(model_name, args.task, args.data) 192 | 193 | torch.save(model, get_path(args, 'INIT_MODEL_PATH')) 194 | total_num_steps = 0 195 | 196 | print('init_prune_0 starts...') 197 | model = compress_p.init_pruning(model, args, config, data_content, tag='init_prune_0', beta=-1) 198 | if args.mask_finetune_flag: 199 | sparsity_ratio_mul = 1 200 | print('iter_prune_0 starts...') 201 | compress_p.iter_pruning(model, args, config, data_content, tag='iter_prune_0', sparsity_ratio_mul=sparsity_ratio_mul) 202 | model = torch.load(get_path(args, 'COMPRESSED_MODEL_PATH'), map_location=args.comp_device) 203 | else: 204 | model = model.to(args.comp_device) 205 | 206 | model_path = get_path(args, 'COMPRESSED_MODEL_PATH') 207 | 208 | print('finetune_0 starts') 209 | model = model.to(args.device) 210 | training_params = deepcopy(config.get_init_training_params(args.arch, args.data)) 211 | 212 | _, trainer_state = finetune(model, args, data_content, training_params, 213 | get_path(args, 'COMPRESSED_MODEL_PATH'), tag='finetune_0') 214 | total_num_steps += trainer_state.global_step 215 | 216 | Path(get_path(args, 'TRAINER_FOLDER_DIR', temp=False) + f'/runs').mkdir(exist_ok=True, parents=True) 217 | try: 218 | os.rename(get_path(args, 'TRAINER_FOLDER_DIR') + f'/runs/finetune_0', 219 | get_path(args, 'TRAINER_FOLDER_DIR', temp=False) + f'/runs/finetune') 220 | except: 221 | pass 222 | 223 | tag = 'validate_0' 224 | print(f'{tag} starts') 225 | val_output = predict(model_path, args, data_content, tag=tag) 226 | val_score = val_output.metrics[f'{tag}_{args.metric_name}'] 227 | 228 | best_val_score = 
val_score 229 | best_val_output = val_output 230 | subprocess.run(["cp", "-r", get_path(args, 'MODEL_FOLDER_DIR'), get_path(args, 'MAIN_FOLDER_DIR', temp=False)]) 231 | 232 | num_rounds = args.num_pruning_rounds 233 | for i in range(num_rounds): 234 | print(f'Round: {i + 1}/{num_rounds} - Starting full model update...') 235 | init_model = compress_p.update_full_model(model, args, config, trainer_state, total_num_steps) 236 | print(f'Round: {i + 1}/{num_rounds} - Starting init pruning...') 237 | beta_ = -1 238 | model = compress_p.init_pruning(init_model, args, config, data_content, 239 | tag=f'init_prune_{i + 1}', beta=beta_) 240 | del init_model 241 | 242 | if args.mask_finetune_flag: 243 | sparsity_ratio_mul = i / max(1, num_rounds - 1) 244 | print(f'Round: {i + 1}/{num_rounds} - Starting iter pruning with mul: {sparsity_ratio_mul}') 245 | compress_p.iter_pruning(model, args, config, data_content, 246 | tag=f'iter_prune_{i + 1}', 247 | sparsity_ratio_mul=sparsity_ratio_mul) # determine what to update 248 | model = torch.load(get_path(args, 'COMPRESSED_MODEL_PATH'), map_location=args.comp_device) 249 | 250 | training_params = deepcopy(config.get_iter_training_params(args.arch, args.data)) 251 | 252 | print(f'Round: {i + 1}/{num_rounds} - Starting finetuning with initial learning rate ' 253 | f'{training_params["learning_rate"]: .6f}') 254 | 255 | model = model.to(args.device) 256 | _, trainer_state = finetune(model, args, data_content, training_params, 257 | get_path(args, 'COMPRESSED_MODEL_PATH'), 258 | for_eval_flag=False, tag=f'finetune_{i + 1}') 259 | total_num_steps += trainer_state.global_step 260 | 261 | gc.collect() 262 | if args.device == 'mps': 263 | torch.mps.empty_cache() 264 | elif args.device == 'cuda': 265 | torch.cuda.empty_cache() 266 | gc.collect() 267 | 268 | print(f'Round: {i + 1}/{num_rounds} - Validating...') 269 | val_output = predict(model_path, args, data_content, tag=f'validate_{i + 1}') 270 | val_score = val_output.metrics[f'validate_{i + 1}_{args.metric_name}'] 271 | 272 | if val_score >= best_val_score: 273 | best_val_score = val_score 274 | best_val_output = val_output 275 | subprocess.run( 276 | ["cp", "-r", get_path(args, 'MODEL_FOLDER_DIR'), get_path(args, 'MAIN_FOLDER_DIR', temp=False)]) 277 | Path(get_path(args, 'TRAINER_FOLDER_DIR', temp=False) + f'/runs').mkdir(exist_ok=True, parents=True) 278 | 279 | subprocess.run(["rm", "-rf", get_path(args, 'TRAINER_FOLDER_DIR', temp=False) + f'/runs/finetune']) 280 | os.rename(get_path(args, 'TRAINER_FOLDER_DIR') + f'/runs/finetune_{i + 1}', 281 | get_path(args, 'TRAINER_FOLDER_DIR', temp=False) + f'/runs/finetune') 282 | else: 283 | subprocess.run(["rm", "-rf", get_path(args, 'TRAINER_FOLDER_DIR') + f'/runs/finetune_{i + 1}']) 284 | 285 | print('Testing the finetuned model') 286 | model_path = get_path(args, 'COMPRESSED_MODEL_PATH', temp=False) 287 | test_output = predict(model_path, args, data_content, tag=args.final_eval_split) 288 | test_metric = test_output.metrics 289 | 290 | output_metric_dict = {'val_metric': best_val_output.metrics, 291 | 'test_metric': test_metric} 292 | 293 | subprocess.run(["rm", "-rf", get_path(args, 'MODEL_FOLDER_DIR')]) 294 | 295 | return output_metric_dict 296 | 297 | 298 | if __name__ == '__main__': 299 | 300 | run_mode = args.run_mode 301 | 302 | if run_mode == 'train': 303 | output_metric_dict = execute_main(args) 304 | elif run_mode == 'evaluate': 305 | model_path = args.evaluate_from 306 | data_content = prepare_data(args, args.final_eval_split) 307 | output_metric_dict = 
predict(model_path, args, data_content, tag='test') 308 | else: 309 | raise NotImplementedError 310 | -------------------------------------------------------------------------------- /models/modeling_mask2former.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch Mask2Former model.""" 16 | 17 | import math 18 | import warnings 19 | from dataclasses import dataclass 20 | from typing import Dict, List, Optional, Tuple 21 | 22 | import numpy as np 23 | import torch 24 | from torch import Tensor, nn 25 | 26 | from transformers import AutoBackbone 27 | from transformers.activations import ACT2FN 28 | from transformers.file_utils import ( 29 | ModelOutput, 30 | add_start_docstrings, 31 | add_start_docstrings_to_model_forward, 32 | is_scipy_available, 33 | replace_return_docstrings, 34 | requires_backends, 35 | ) 36 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, SemanticSegmenterOutput 37 | from transformers.modeling_utils import PreTrainedModel 38 | from transformers.utils import logging 39 | from transformers.models.mask2former.configuration_mask2former import Mask2FormerConfig 40 | 41 | from utils import get_confusion_matrix, process_segmenter_output 42 | 43 | if is_scipy_available(): 44 | from scipy.optimize import linear_sum_assignment 45 | 46 | logger = logging.get_logger(__name__) 47 | 48 | 49 | _CONFIG_FOR_DOC = "Mask2FormerConfig" 50 | _CHECKPOINT_FOR_DOC = "facebook/mask2former-swin-small-coco-instance" 51 | _IMAGE_PROCESSOR_FOR_DOC = "Mask2FormerImageProcessor" 52 | 53 | MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ 54 | "facebook/mask2former-swin-small-coco-instance", 55 | # See all mask2former models at https://huggingface.co/models?filter=mask2former 56 | ] 57 | 58 | 59 | @dataclass 60 | class Mask2FormerPixelDecoderOutput(ModelOutput): 61 | """ 62 | Mask2Former's pixel decoder module output, practically a Multi-Scale Deformable Attention based decoder. It returns 63 | the mask features and the multiscale features. 64 | 65 | Args: 66 | multi_scale_features (`tuple(torch.FloatTensor)`): 67 | Tuple of multi-scale features of scales [1/8, 1/16, 1/32] and shape `(batch_size, num_channels, height, 68 | width)`from the Multi-Scale Deformable Attenntion based Pixel Decoder. 69 | mask_features (`torch.FloatTensor`): 70 | Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel Decoder 71 | Layer. 72 | attentions (`tuple(torch.FloatTensor)`, *optional*): 73 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 74 | sequence_length)`. Attentions weights from pixel decoder. 
Returned when `output_attentions=True` is passed 75 | or when `config.output_attentions=True` 76 | """ 77 | 78 | multi_scale_features: Tuple[torch.FloatTensor] = None 79 | mask_features: torch.FloatTensor = None 80 | attentions: Optional[Tuple[torch.FloatTensor]] = None 81 | 82 | 83 | @dataclass 84 | class Mask2FormerMaskedAttentionDecoderOutput(BaseModelOutputWithCrossAttentions): 85 | """ 86 | Base class for outputs of the Transformer decoder. This class adds two attributes to 87 | BaseModelOutputWithCrossAttentions for mask predictions logits and a tuple of intermediate decoder activations, 88 | i.e. the output of each decoder layer, each of them gone through a layernorm. 89 | 90 | Args: 91 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 92 | Sequence of hidden-states at the output of the last layer of the model. 93 | hidden_states (`tuple(torch.FloatTensor)`, *optional*): 94 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of 95 | shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer 96 | plus the initial embedding outputs. Returned when `output_hidden_states=True`. 97 | attentions (`tuple(torch.FloatTensor)`, *optional*): 98 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 99 | sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in 100 | the self-attention heads. Returned when `output_attentions=True`. 101 | masks_queries_logits (`tuple(torch.FloatTensor)` of shape `(batch_size, num_queries, height, width)`): 102 | Tuple of mask predictions from all layers of the transformer decoder. 103 | intermediate_hidden_states (`tuple(torch.FloatTensor)` of shape `(num_queries, 1, hidden_size)`): 104 | Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a 105 | layernorm. 106 | """ 107 | 108 | last_hidden_state: torch.FloatTensor = None 109 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 110 | attentions: Optional[torch.FloatTensor] = None 111 | masks_queries_logits: Tuple[torch.FloatTensor] = None 112 | intermediate_hidden_states: Tuple[torch.FloatTensor] = None 113 | 114 | 115 | @dataclass 116 | class Mask2FormerPixelLevelModuleOutput(ModelOutput): 117 | """ 118 | Mask2Former's pixel level module output. It returns the output of the encoder (optional) and all hidden states 119 | (multi-scale features) from the `decoder`. By default, the `encoder` is a Swin Backbone and the `decoder` is a 120 | Multi-Scale Deformable Attention based decoder. 121 | 122 | The `decoder_last_hidden_state` are the **per-pixel embeddings** while `decoder_hidden_states` refer to multi-scale 123 | feature maps produced using **multi-scaling strategy** defined in the paper. 124 | 125 | Args: 126 | encoder_last_hidden_state (`torch.FloatTensor`): 127 | Last hidden states (final feature map of shape `(batch_size, num_channels, height, width)`) of the last 128 | stage of the encoder. 129 | encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*): 130 | Tuple of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden states (also 131 | called feature maps) of the model at the output of each stage. Returned if output_hidden_states is set to 132 | True. 
133 | decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)): 134 | 1/4 scale features from the last Pixel Decoder Layer. 135 | decoder_hidden_states (`tuple(torch.FloatTensor)`): 136 | Tuple of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden states (also 137 | called feature maps) of the model at the output of each stage. 138 | """ 139 | 140 | encoder_last_hidden_state: torch.FloatTensor = None 141 | encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 142 | decoder_last_hidden_state: torch.FloatTensor = None 143 | decoder_hidden_states: Tuple[torch.FloatTensor] = None 144 | 145 | 146 | @dataclass 147 | class Mask2FormerModelOutput(ModelOutput): 148 | """ 149 | Class for outputs of [`Mask2FormerModel`]. This class returns all the needed hidden states to compute the logits. 150 | 151 | Args: 152 | encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): 153 | Last hidden states (final feature map) of the last stage of the encoder model (backbone). Returned when 154 | `output_hidden_states=True` is passed. 155 | encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*): 156 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of 157 | shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder 158 | model at the output of each stage. Returned when `output_hidden_states=True` is passed. 159 | pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): 160 | Last hidden states (final feature map) of the last stage of the pixel decoder model. 161 | pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, , *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 162 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of 163 | shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel 164 | decoder model at the output of each stage. Returned when `output_hidden_states=True` is passed. 165 | transformer_decoder_last_hidden_state (`tuple(torch.FloatTensor)`): 166 | Final output of the transformer decoder `(batch_size, sequence_length, hidden_size)`. 167 | transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*): 168 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of 169 | shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the 170 | transformer decoder at the output of each stage. Returned when `output_hidden_states=True` is passed. 171 | transformer_decoder_intermediate_states (`tuple(torch.FloatTensor)` of shape `(num_queries, 1, hidden_size)`): 172 | Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a 173 | layernorm. 174 | masks_queries_logits (`tuple(torch.FloatTensor)` of shape `(batch_size, num_queries, height, width)`) 175 | Mask Predictions from each layer in the transformer decoder. 176 | attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed): 177 | Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 178 | sequence_length)`. 
Self attentions weights from transformer decoder. 179 | """ 180 | 181 | encoder_last_hidden_state: torch.FloatTensor = None 182 | pixel_decoder_last_hidden_state: torch.FloatTensor = None 183 | transformer_decoder_last_hidden_state: torch.FloatTensor = None 184 | encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 185 | pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 186 | transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 187 | transformer_decoder_intermediate_states: Tuple[torch.FloatTensor] = None 188 | masks_queries_logits: Tuple[torch.FloatTensor] = None 189 | attentions: Optional[Tuple[torch.FloatTensor]] = None 190 | 191 | 192 | @dataclass 193 | class Mask2FormerForUniversalSegmentationOutput(ModelOutput): 194 | """ 195 | Class for outputs of [`Mask2FormerForUniversalSegmentationOutput`]. 196 | 197 | This output can be directly passed to [`~Mask2FormerImageProcessor.post_process_semantic_segmentation`] or 198 | [`~Mask2FormerImageProcessor.post_process_instance_segmentation`] or 199 | [`~Mask2FormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see 200 | [`~Mask2FormerImageProcessor] for details regarding usage. 201 | 202 | Args: 203 | loss (`torch.Tensor`, *optional*): 204 | The computed loss, returned when labels are present. 205 | class_queries_logits (`torch.FloatTensor`): 206 | A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each 207 | query. Note the `+ 1` is needed because we incorporate the null class. 208 | masks_queries_logits (`torch.FloatTensor`): 209 | A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each 210 | query. 211 | auxiliary_logits (`List[Dict(str, torch.FloatTensor)]`, *optional*): 212 | List of class and mask predictions from each layer of the transformer decoder. 213 | encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 214 | Last hidden states (final feature map) of the last stage of the encoder model (backbone). 215 | encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 216 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of 217 | shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder 218 | model at the output of each stage. 219 | pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 220 | Last hidden states (final feature map) of the last stage of the pixel decoder model. 221 | pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 222 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of 223 | shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel 224 | decoder model at the output of each stage. 225 | transformer_decoder_last_hidden_state (`tuple(torch.FloatTensor)`): 226 | Final output of the transformer decoder `(batch_size, sequence_length, hidden_size)`. 
227 | transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 228 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of 229 | shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the 230 | transformer decoder at the output of each stage. 231 | attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 232 | Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 233 | sequence_length)`. Self and Cross Attentions weights from transformer decoder. 234 | """ 235 | 236 | loss: Optional[torch.FloatTensor] = None 237 | class_queries_logits: torch.FloatTensor = None 238 | masks_queries_logits: torch.FloatTensor = None 239 | auxiliary_logits: Optional[List[Dict[str, torch.FloatTensor]]] = None 240 | encoder_last_hidden_state: torch.FloatTensor = None 241 | pixel_decoder_last_hidden_state: torch.FloatTensor = None 242 | transformer_decoder_last_hidden_state: torch.FloatTensor = None 243 | encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 244 | pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 245 | transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None 246 | attentions: Optional[Tuple[torch.FloatTensor]] = None 247 | 248 | 249 | # Copied from transformers.models.detr.modeling_detr._expand_mask 250 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): 251 | """ 252 | Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. 253 | """ 254 | batch_size, source_len = mask.size() 255 | target_len = target_len if target_len is not None else source_len 256 | 257 | expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) 258 | 259 | inverted_mask = 1.0 - expanded_mask 260 | 261 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) 262 | 263 | 264 | # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py 265 | def sample_point( 266 | input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs 267 | ) -> torch.Tensor: 268 | """ 269 | A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors. 270 | 271 | Args: 272 | input_features (`torch.Tensor` of shape (batch_size, channels, height, width)): 273 | A tensor that contains features map on a height * width grid 274 | point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,: 275 | 2)): 276 | A tensor that contains [0, 1] * [0, 1] normalized point coordinates 277 | add_dim (`bool`): 278 | boolean value to keep track of added dimension 279 | 280 | Returns: 281 | point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels, 282 | height_grid, width_grid): 283 | A tensor that contains features for points in `point_coordinates`. 
284 | """ 285 | if point_coordinates.dim() == 3: 286 | add_dim = True 287 | point_coordinates = point_coordinates.unsqueeze(2) 288 | 289 | # use nn.function.grid_sample to get features for points in `point_coordinates` via bilinear interpolation 290 | point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs) 291 | if add_dim: 292 | point_features = point_features.squeeze(3) 293 | 294 | return point_features 295 | 296 | 297 | # Copied from transformers.models.maskformer.modeling_maskformer.dice_loss 298 | def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor: 299 | r""" 300 | Compute the DICE loss, similar to generalized IOU for masks as follows: 301 | 302 | $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$ 303 | 304 | In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow 305 | 306 | $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$ 307 | 308 | Args: 309 | inputs (`torch.Tensor`): 310 | A tensor representing a mask. 311 | labels (`torch.Tensor`): 312 | A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs 313 | (0 for the negative class and 1 for the positive class). 314 | num_masks (`int`): 315 | The number of masks present in the current batch, used for normalization. 316 | 317 | Returns: 318 | `torch.Tensor`: The computed loss. 319 | """ 320 | probs = inputs.sigmoid().flatten(1) 321 | numerator = 2 * (probs * labels).sum(-1) 322 | denominator = probs.sum(-1) + labels.sum(-1) 323 | loss = 1 - (numerator + 1) / (denominator + 1) 324 | loss = loss.sum() / num_masks 325 | return loss 326 | 327 | 328 | def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor: 329 | r""" 330 | Args: 331 | inputs (`torch.Tensor`): 332 | A float tensor of arbitrary shape. 333 | labels (`torch.Tensor`): 334 | A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs 335 | (0 for the negative class and 1 for the positive class). 336 | 337 | Returns: 338 | loss (`torch.Tensor`): The computed loss. 339 | """ 340 | criterion = nn.BCEWithLogitsLoss(reduction="none") 341 | cross_entropy_loss = criterion(inputs, labels) 342 | 343 | loss = cross_entropy_loss.mean(1).sum() / num_masks 344 | return loss 345 | 346 | 347 | # Copied from transformers.models.maskformer.modeling_maskformer.pair_wise_dice_loss 348 | def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: 349 | """ 350 | A pair wise version of the dice loss, see `dice_loss` for usage. 351 | 352 | Args: 353 | inputs (`torch.Tensor`): 354 | A tensor representing a mask 355 | labels (`torch.Tensor`): 356 | A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs 357 | (0 for the negative class and 1 for the positive class). 358 | 359 | Returns: 360 | `torch.Tensor`: The computed loss between each pairs. 
361 | """ 362 | inputs = inputs.sigmoid().flatten(1) 363 | numerator = 2 * torch.matmul(inputs, labels.T) 364 | # using broadcasting to get a [num_queries, NUM_CLASSES] matrix 365 | denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] 366 | loss = 1 - (numerator + 1) / (denominator + 1) 367 | return loss 368 | 369 | 370 | def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 371 | r""" 372 | A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage. 373 | 374 | Args: 375 | inputs (`torch.Tensor`): 376 | A tensor representing a mask. 377 | labels (`torch.Tensor`): 378 | A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs 379 | (0 for the negative class and 1 for the positive class). 380 | 381 | Returns: 382 | loss (`torch.Tensor`): The computed loss between each pairs. 383 | """ 384 | 385 | height_and_width = inputs.shape[1] 386 | 387 | criterion = nn.BCEWithLogitsLoss(reduction="none") 388 | cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) 389 | cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) 390 | 391 | loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T) 392 | loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T) 393 | loss = loss_pos + loss_neg 394 | loss = loss / height_and_width 395 | return loss 396 | 397 | 398 | # Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py 399 | class Mask2FormerHungarianMatcher(nn.Module): 400 | """This class computes an assignment between the labels and the predictions of the network. 401 | 402 | For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more 403 | predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are 404 | un-matched (and thus treated as non-objects). 405 | """ 406 | 407 | def __init__( 408 | self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544 409 | ): 410 | """Creates the matcher 411 | 412 | Params: 413 | cost_class (`float`, *optional*, defaults to 1.0): 414 | Relative weight of the classification error in the matching cost. 415 | cost_mask (`float`, *optional*, defaults to 1.0): 416 | This is the relative weight of the focal loss of the binary mask in the matching cost. 417 | cost_dice (`float`, *optional*, defaults to 1.0): 418 | This is the relative weight of the dice loss of the binary mask in the matching cost. 419 | num_points (`int`, *optional*, defaults to 12544): 420 | No. of points to sample on which the mask loss will be calculated. The same set of K points are 421 | uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite 422 | matching. 
423 | """ 424 | super().__init__() 425 | if cost_class == 0 and cost_mask == 0 and cost_dice == 0: 426 | raise ValueError("All costs cant be 0") 427 | 428 | self.num_points = num_points 429 | self.cost_class = cost_class 430 | self.cost_mask = cost_mask 431 | self.cost_dice = cost_dice 432 | 433 | @torch.no_grad() 434 | def forward( 435 | self, 436 | masks_queries_logits: torch.Tensor, 437 | class_queries_logits: torch.Tensor, 438 | mask_labels: torch.Tensor, 439 | class_labels: torch.Tensor, 440 | ) -> List[Tuple[Tensor]]: 441 | """ 442 | Params: 443 | masks_queries_logits (`torch.Tensor`): 444 | A tensor of dim `batch_size, num_queries, num_labels` with the classification logits. 445 | class_queries_logits (`torch.Tensor`): 446 | A tensor of dim `batch_size, num_queries, height, width` with the predicted masks. 447 | class_labels (`torch.Tensor`): 448 | A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in the 449 | target) containing the class labels. 450 | mask_labels (`torch.Tensor`): 451 | A tensor of dim `num_target_boxes, height, width` containing the target masks. 452 | 453 | Returns: 454 | matched_indices (`List[Tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j) 455 | where: 456 | - index_i is the indices of the selected predictions (in order) 457 | - index_j is the indices of the corresponding selected labels (in order) 458 | For each batch element, it holds: 459 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes). 460 | """ 461 | indices: List[Tuple[np.array]] = [] 462 | 463 | # iterate through batch size 464 | batch_size = masks_queries_logits.shape[0] 465 | for i in range(batch_size): 466 | pred_probs = class_queries_logits[i].softmax(-1) 467 | pred_mask = masks_queries_logits[i] 468 | 469 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it in 1 - proba[target class]. The 1 is a constant that doesn't change the matching, it can be ommitted. 
470 |             cost_class = -pred_probs[:, class_labels[i]]
471 |             target_mask = mask_labels[i].to(pred_mask)
472 |             target_mask = target_mask[:, None]
473 |             pred_mask = pred_mask[:, None]
474 |
475 |             # Sample ground truth and predicted masks
476 |             point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device)
477 |
478 |             target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1)
479 |             target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1)
480 |
481 |             pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1)
482 |             pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1)
483 |
484 |             # compute the cross entropy loss between each mask pair -> shape (num_queries, num_labels)
485 |             cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
486 |             # Compute the dice loss between each mask pair -> shape (num_queries, num_labels)
487 |             cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
488 |             # final cost matrix
489 |             cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
490 |             # do the assignment using the Hungarian algorithm in scipy
491 |             assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
492 |             indices.append(assigned_indices)
493 |
494 |         # It could be stacked in one tensor
495 |         matched_indices = [
496 |             (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
497 |         ]
498 |         return matched_indices
499 |
500 |
501 | # Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
502 | class Mask2FormerLoss(nn.Module):
503 |     def __init__(self, config: Mask2FormerConfig, weight_dict: Dict[str, float]):
504 |         """
505 |         The Mask2Former Loss. The loss is computed very similarly to DETR. The process happens in two steps: 1) we
506 |         compute the Hungarian assignment between ground truth masks and the outputs of the model, 2) we supervise each
507 |         pair of matched ground-truth / prediction (supervising both class and mask).
508 |
509 |         Args:
510 |             config (`Mask2FormerConfig`):
511 |                 The configuration for the Mask2Former model, also containing loss-calculation-specific parameters.
512 |             weight_dict (`Dict[str, float]`):
513 |                 A dictionary of weights to be applied to the different losses.
514 | """ 515 | super().__init__() 516 | requires_backends(self, ["scipy"]) 517 | self.num_labels = config.num_labels 518 | self.weight_dict = weight_dict 519 | 520 | # Weight to apply to the null class 521 | self.eos_coef = config.no_object_weight 522 | empty_weight = torch.ones(self.num_labels + 1) 523 | empty_weight[-1] = self.eos_coef 524 | self.register_buffer("empty_weight", empty_weight) 525 | 526 | # pointwise mask loss parameters 527 | self.num_points = config.train_num_points 528 | self.oversample_ratio = config.oversample_ratio 529 | self.importance_sample_ratio = config.importance_sample_ratio 530 | 531 | self.matcher = Mask2FormerHungarianMatcher( 532 | cost_class=1.0, 533 | cost_dice=config.dice_weight, 534 | cost_mask=config.mask_weight, 535 | num_points=self.num_points, 536 | ) 537 | 538 | def _max_by_axis(self, sizes: List[List[int]]) -> List[int]: 539 | maxes = sizes[0] 540 | for sublist in sizes[1:]: 541 | for index, item in enumerate(sublist): 542 | maxes[index] = max(maxes[index], item) 543 | return maxes 544 | 545 | # Adapted from nested_tensor_from_tensor_list() in original implementation 546 | def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]: 547 | # get the maximum size in the batch 548 | max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors]) 549 | # compute final size 550 | batch_shape = [len(tensors)] + max_size 551 | batch_size, _, height, width = batch_shape 552 | dtype = tensors[0].dtype 553 | device = tensors[0].device 554 | padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device) 555 | padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) 556 | # pad the tensors to the size of the biggest one 557 | for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks): 558 | padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor) 559 | padding_mask[: tensor.shape[1], : tensor.shape[2]] = False 560 | 561 | return padded_tensors, padding_masks 562 | 563 | def loss_labels( 564 | self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array] 565 | ) -> Dict[str, Tensor]: 566 | """Compute the losses related to the labels using cross entropy. 567 | 568 | Args: 569 | class_queries_logits (`torch.Tensor`): 570 | A tensor of shape `batch_size, num_queries, num_labels` 571 | class_labels (`List[torch.Tensor]`): 572 | List of class labels of shape `(labels)`. 573 | indices (`Tuple[np.array])`: 574 | The indices computed by the Hungarian matcher. 575 | 576 | Returns: 577 | `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key: 578 | - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. 
579 | """ 580 | pred_logits = class_queries_logits 581 | batch_size, num_queries, _ = pred_logits.shape 582 | criterion = nn.CrossEntropyLoss(weight=self.empty_weight) 583 | idx = self._get_predictions_permutation_indices(indices) # shape of (batch_size, num_queries) 584 | target_classes_o = torch.cat( 585 | [target[j] for target, (_, j) in zip(class_labels, indices)] 586 | ) # shape of (batch_size, num_queries) 587 | target_classes = torch.full( 588 | (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device 589 | ) 590 | target_classes[idx] = target_classes_o 591 | # Permute target_classes (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries) 592 | pred_logits_transposed = pred_logits.transpose(1, 2) 593 | loss_ce = criterion(pred_logits_transposed, target_classes) 594 | losses = {"loss_cross_entropy": loss_ce} 595 | return losses 596 | 597 | def loss_masks( 598 | self, 599 | masks_queries_logits: torch.Tensor, 600 | mask_labels: List[torch.Tensor], 601 | indices: Tuple[np.array], 602 | num_masks: int, 603 | ) -> Dict[str, torch.Tensor]: 604 | """Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss. 605 | 606 | Args: 607 | masks_queries_logits (`torch.Tensor`): 608 | A tensor of shape `(batch_size, num_queries, height, width)`. 609 | mask_labels (`torch.Tensor`): 610 | List of mask labels of shape `(labels, height, width)`. 611 | indices (`Tuple[np.array])`: 612 | The indices computed by the Hungarian matcher. 613 | num_masks (`int)`: 614 | The number of masks, used for normalization. 615 | 616 | Returns: 617 | losses (`Dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys: 618 | - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth. 619 | masks. 620 | - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth, 621 | masks. 
622 | """ 623 | src_idx = self._get_predictions_permutation_indices(indices) 624 | tgt_idx = self._get_targets_permutation_indices(indices) 625 | # shape (batch_size * num_queries, height, width) 626 | pred_masks = masks_queries_logits[src_idx] 627 | # shape (batch_size, num_queries, height, width) 628 | # pad all and stack the targets to the num_labels dimension 629 | target_masks, _ = self._pad_images_to_max_in_batch(mask_labels) 630 | target_masks = target_masks[tgt_idx] 631 | 632 | # No need to upsample predictions as we are using normalized coordinates 633 | pred_masks = pred_masks[:, None] 634 | target_masks = target_masks[:, None] 635 | 636 | # Sample point coordinates 637 | with torch.no_grad(): 638 | point_coordinates = self.sample_points_using_uncertainty( 639 | pred_masks, 640 | lambda logits: self.calculate_uncertainty(logits), 641 | self.num_points, 642 | self.oversample_ratio, 643 | self.importance_sample_ratio, 644 | ) 645 | 646 | point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1) 647 | 648 | point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1) 649 | 650 | losses = { 651 | "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks), 652 | "loss_dice": dice_loss(point_logits, point_labels, num_masks), 653 | } 654 | 655 | del pred_masks 656 | del target_masks 657 | return losses 658 | 659 | def _get_predictions_permutation_indices(self, indices): 660 | # Permute predictions following indices 661 | batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 662 | predictions_indices = torch.cat([src for (src, _) in indices]) 663 | return batch_indices, predictions_indices 664 | 665 | def _get_targets_permutation_indices(self, indices): 666 | # Permute labels following indices 667 | batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 668 | target_indices = torch.cat([tgt for (_, tgt) in indices]) 669 | return batch_indices, target_indices 670 | 671 | def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor: 672 | """ 673 | In Mask2Former paper, uncertainty is estimated as L1 distance between 0.0 and the logit prediction in 'logits' 674 | for the foreground class in `classes`. 675 | 676 | Args: 677 | logits (`torch.Tensor`): 678 | A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is: 679 | the number of foreground classes. The values are logits. 680 | 681 | Returns: 682 | scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most 683 | uncertain locations having the highest uncertainty score. 684 | """ 685 | uncertainty_scores = -(torch.abs(logits)) 686 | return uncertainty_scores 687 | 688 | def sample_points_using_uncertainty( 689 | self, 690 | logits: torch.Tensor, 691 | uncertainty_function, 692 | num_points: int, 693 | oversample_ratio: int, 694 | importance_sample_ratio: float, 695 | ) -> torch.Tensor: 696 | """ 697 | This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The 698 | uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit 699 | prediction as input. 700 | 701 | Args: 702 | logits (`float`): 703 | Logit predictions for P points. 704 | uncertainty_function: 705 | A function that takes logit predictions for P points and returns their uncertainties. 
706 | num_points (`int`): 707 | The number of points P to sample. 708 | oversample_ratio (`int`): 709 | Oversampling parameter. 710 | importance_sample_ratio (`float`): 711 | Ratio of points that are sampled via importance sampling. 712 | 713 | Returns: 714 | point_coordinates (`torch.Tensor`): 715 | Coordinates for P sampled points. 716 | """ 717 | 718 | num_boxes = logits.shape[0] 719 | num_points_sampled = int(num_points * oversample_ratio) 720 | 721 | # Get random point coordinates 722 | point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device) 723 | # Get sampled prediction value for the point coordinates 724 | point_logits = sample_point(logits, point_coordinates, align_corners=False) 725 | # Calculate the uncertainties based on the sampled prediction values of the points 726 | point_uncertainties = uncertainty_function(point_logits) 727 | 728 | num_uncertain_points = int(importance_sample_ratio * num_points) 729 | num_random_points = num_points - num_uncertain_points 730 | 731 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 732 | shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) 733 | idx += shift[:, None] 734 | point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2) 735 | 736 | if num_random_points > 0: 737 | point_coordinates = torch.cat( 738 | [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)], 739 | dim=1, 740 | ) 741 | return point_coordinates 742 | 743 | def forward( 744 | self, 745 | masks_queries_logits: torch.Tensor, 746 | class_queries_logits: torch.Tensor, 747 | mask_labels: List[torch.Tensor], 748 | class_labels: List[torch.Tensor], 749 | auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None, 750 | ) -> Dict[str, torch.Tensor]: 751 | """ 752 | This performs the loss computation. 753 | 754 | Args: 755 | masks_queries_logits (`torch.Tensor`): 756 | A tensor of shape `(batch_size, num_queries, height, width)`. 757 | class_queries_logits (`torch.Tensor`): 758 | A tensor of shape `(batch_size, num_queries, num_labels)`. 759 | mask_labels (`torch.Tensor`): 760 | List of mask labels of shape `(labels, height, width)`. 761 | class_labels (`List[torch.Tensor]`): 762 | List of class labels of shape `(labels)`. 763 | auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*): 764 | if `use_auxiliary_loss` was set to `true` in [`Mask2FormerConfig`], then it contains the logits from 765 | the inner layers of the Mask2FormerMaskedAttentionDecoder. 766 | 767 | Returns: 768 | losses (`Dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys: 769 | - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. 770 | - **loss_mask** -- The loss computed using sigmoid cross_entropy loss on the predicted and ground truth 771 | masks. 772 | - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth 773 | masks. 774 | if `use_auxiliary_loss` was set to `true` in [`Mask2FormerConfig`], the dictionary contains additional 775 | losses for each auxiliary predictions. 
776 | """ 777 | 778 | # retrieve the matching between the outputs of the last layer and the labels 779 | indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels) 780 | # compute the average number of target masks for normalization purposes 781 | num_masks = self.get_num_masks(class_labels, device=class_labels[0].device) 782 | # get all the losses 783 | losses: Dict[str, Tensor] = { 784 | **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks), 785 | **self.loss_labels(class_queries_logits, class_labels, indices), 786 | } 787 | # in case of auxiliary losses, we repeat this process with the output of each intermediate layer. 788 | if auxiliary_predictions is not None: 789 | for idx, aux_outputs in enumerate(auxiliary_predictions): 790 | masks_queries_logits = aux_outputs["masks_queries_logits"] 791 | class_queries_logits = aux_outputs["class_queries_logits"] 792 | loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels) 793 | loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()} 794 | losses.update(loss_dict) 795 | 796 | return losses 797 | 798 | def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor: 799 | """ 800 | Computes the average number of target masks across the batch, for normalization purposes. 801 | """ 802 | num_masks = sum([len(classes) for classes in class_labels]) 803 | num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device) 804 | return num_masks_pt 805 | 806 | 807 | # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention 808 | def multi_scale_deformable_attention( 809 | value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor 810 | ) -> Tensor: 811 | batch_size, _, num_heads, hidden_dim = value.shape 812 | _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape 813 | value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1) 814 | sampling_grids = 2 * sampling_locations - 1 815 | sampling_value_list = [] 816 | for level_id, (height, width) in enumerate(value_spatial_shapes): 817 | # batch_size, height*width, num_heads, hidden_dim 818 | # -> batch_size, height*width, num_heads*hidden_dim 819 | # -> batch_size, num_heads*hidden_dim, height*width 820 | # -> batch_size*num_heads, hidden_dim, height, width 821 | value_l_ = ( 822 | value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width) 823 | ) 824 | # batch_size, num_queries, num_heads, num_points, 2 825 | # -> batch_size, num_heads, num_queries, num_points, 2 826 | # -> batch_size*num_heads, num_queries, num_points, 2 827 | sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) 828 | # batch_size*num_heads, hidden_dim, num_queries, num_points 829 | sampling_value_l_ = nn.functional.grid_sample( 830 | value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False 831 | ) 832 | sampling_value_list.append(sampling_value_l_) 833 | # (batch_size, num_queries, num_heads, num_levels, num_points) 834 | # -> (batch_size, num_heads, num_queries, num_levels, num_points) 835 | # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) 836 | attention_weights = attention_weights.transpose(1, 2).reshape( 837 | batch_size * num_heads, 1, num_queries, num_levels * num_points 838 | ) 839 | output = ( 840 | 
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) 841 | .sum(-1) 842 | .view(batch_size, num_heads * hidden_dim, num_queries) 843 | ) 844 | return output.transpose(1, 2).contiguous() 845 | 846 | 847 | # Copied from transformers.models.maskformer.modeling_maskformer.MaskFormerSinePositionEmbedding with MaskFormer->Mask2Former 848 | class Mask2FormerSinePositionEmbedding(nn.Module): 849 | """ 850 | This is a more standard version of the position embedding, very similar to the one used by the Attention is all you 851 | need paper, generalized to work on images. 852 | """ 853 | 854 | def __init__( 855 | self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None 856 | ): 857 | super().__init__() 858 | if scale is not None and normalize is False: 859 | raise ValueError("normalize should be True if scale is passed") 860 | self.num_pos_feats = num_pos_feats 861 | self.temperature = temperature 862 | self.normalize = normalize 863 | self.scale = 2 * math.pi if scale is None else scale 864 | 865 | def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: 866 | if mask is None: 867 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) 868 | not_mask = (~mask).to(x.dtype) 869 | y_embed = not_mask.cumsum(1) 870 | x_embed = not_mask.cumsum(2) 871 | if self.normalize: 872 | eps = 1e-6 873 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 874 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 875 | 876 | dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) 877 | dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) 878 | 879 | pos_x = x_embed[:, :, :, None] / dim_t 880 | pos_y = y_embed[:, :, :, None] / dim_t 881 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 882 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 883 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 884 | return pos 885 | 886 | 887 | # Modified from transformers.models.detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention 888 | class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module): 889 | """ 890 | Multiscale deformable attention as proposed in Deformable DETR. 891 | """ 892 | 893 | def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int): 894 | super().__init__() 895 | if embed_dim % num_heads != 0: 896 | raise ValueError( 897 | f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}" 898 | ) 899 | dim_per_head = embed_dim // num_heads 900 | # check if dim_per_head is power of 2 901 | if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): 902 | warnings.warn( 903 | "You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the" 904 | " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" 905 | " implementation." 
906 | ) 907 | 908 | self.im2col_step = 128 909 | 910 | self.d_model = embed_dim 911 | self.n_levels = n_levels 912 | self.n_heads = num_heads 913 | self.n_points = n_points 914 | 915 | self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2) 916 | self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points) 917 | self.value_proj = nn.Linear(embed_dim, embed_dim) 918 | self.output_proj = nn.Linear(embed_dim, embed_dim) 919 | 920 | def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): 921 | return tensor if position_embeddings is None else tensor + position_embeddings 922 | 923 | def forward( 924 | self, 925 | hidden_states: torch.Tensor, 926 | attention_mask: Optional[torch.Tensor] = None, 927 | encoder_hidden_states=None, 928 | encoder_attention_mask=None, 929 | position_embeddings: Optional[torch.Tensor] = None, 930 | reference_points=None, 931 | spatial_shapes=None, 932 | level_start_index=None, 933 | output_attentions: bool = False, 934 | ): 935 | # add position embeddings to the hidden states before projecting to queries and keys 936 | if position_embeddings is not None: 937 | hidden_states = self.with_pos_embed(hidden_states, position_embeddings) 938 | 939 | batch_size, num_queries, _ = hidden_states.shape 940 | batch_size, sequence_length, _ = encoder_hidden_states.shape 941 | if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: 942 | raise ValueError( 943 | "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" 944 | ) 945 | 946 | value = self.value_proj(encoder_hidden_states) 947 | if attention_mask is not None: 948 | # we invert the attention_mask 949 | value = value.masked_fill(attention_mask[..., None], float(0)) 950 | value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) 951 | sampling_offsets = self.sampling_offsets(hidden_states).view( 952 | batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 953 | ) 954 | attention_weights = self.attention_weights(hidden_states).view( 955 | batch_size, num_queries, self.n_heads, self.n_levels * self.n_points 956 | ) 957 | attention_weights = nn.functional.softmax(attention_weights, -1).view( 958 | batch_size, num_queries, self.n_heads, self.n_levels, self.n_points 959 | ) 960 | # batch_size, num_queries, n_heads, n_levels, n_points, 2 961 | if reference_points.shape[-1] == 2: 962 | offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) 963 | sampling_locations = ( 964 | reference_points[:, :, None, :, None, :] 965 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 966 | ) 967 | elif reference_points.shape[-1] == 4: 968 | sampling_locations = ( 969 | reference_points[:, :, None, :, None, :2] 970 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 971 | ) 972 | else: 973 | raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") 974 | 975 | output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) 976 | output = self.output_proj(output) 977 | 978 | return output, attention_weights 979 | 980 | 981 | class Mask2FormerPixelDecoderEncoderLayer(nn.Module): 982 | def __init__(self, config: Mask2FormerConfig): 983 | super().__init__() 984 | self.embed_dim = config.feature_size 985 | self.self_attn = Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention( 986 | 
embed_dim=self.embed_dim, 987 | num_heads=config.num_attention_heads, 988 | n_levels=3, 989 | n_points=4, 990 | ) 991 | 992 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 993 | self.dropout = config.dropout 994 | self.activation_fn = nn.functional.relu 995 | self.activation_dropout = config.dropout 996 | self.fc1 = nn.Linear(self.embed_dim, config.encoder_feedforward_dim) 997 | self.fc2 = nn.Linear(config.encoder_feedforward_dim, self.embed_dim) 998 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 999 | 1000 | def forward( 1001 | self, 1002 | hidden_states: torch.Tensor, 1003 | attention_mask: torch.Tensor, 1004 | position_embeddings: torch.Tensor = None, 1005 | reference_points=None, 1006 | spatial_shapes=None, 1007 | level_start_index=None, 1008 | output_attentions: bool = False, 1009 | ): 1010 | """ 1011 | Args: 1012 | hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 1013 | Input to the layer. 1014 | attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): 1015 | Attention mask. 1016 | position_embeddings (`torch.FloatTensor`, *optional*): 1017 | Position embeddings, to be added to `hidden_states`. 1018 | reference_points (`torch.FloatTensor`, *optional*): 1019 | Reference points. 1020 | spatial_shapes (`torch.LongTensor`, *optional*): 1021 | Spatial shapes of the backbone feature maps. 1022 | level_start_index (`torch.LongTensor`, *optional*): 1023 | Level start index. 1024 | output_attentions (`bool`, *optional*): 1025 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1026 | returned tensors for more detail. 1027 | """ 1028 | residual = hidden_states 1029 | 1030 | # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps. 
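        # Clarifying note (added): the deformable attention below runs in self-attention mode, i.e. `hidden_states` is
        # passed both as the queries and as `encoder_hidden_states`. Its sequence length is expected to equal the sum
        # of height * width over the flattened feature levels in `spatial_shapes`; the attention module validates this
        # before sampling.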
1031 | hidden_states, attn_weights = self.self_attn( 1032 | hidden_states=hidden_states, 1033 | attention_mask=attention_mask, 1034 | encoder_hidden_states=hidden_states, 1035 | encoder_attention_mask=attention_mask, 1036 | position_embeddings=position_embeddings, 1037 | reference_points=reference_points, 1038 | spatial_shapes=spatial_shapes, 1039 | level_start_index=level_start_index, 1040 | output_attentions=output_attentions, 1041 | ) 1042 | 1043 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1044 | hidden_states = residual + hidden_states 1045 | hidden_states = self.self_attn_layer_norm(hidden_states) 1046 | 1047 | residual = hidden_states 1048 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 1049 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 1050 | 1051 | hidden_states = self.fc2(hidden_states) 1052 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1053 | 1054 | hidden_states = residual + hidden_states 1055 | hidden_states = self.final_layer_norm(hidden_states) 1056 | 1057 | if self.training: 1058 | if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): 1059 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 1060 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 1061 | 1062 | outputs = (hidden_states,) 1063 | 1064 | if output_attentions: 1065 | outputs += (attn_weights.transpose(1, 0),) 1066 | 1067 | return outputs 1068 | 1069 | 1070 | # Modified from from transformers.models.detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetrEncoder->Mask2FormerPixelDecoderEncoderOnly 1071 | class Mask2FormerPixelDecoderEncoderOnly(nn.Module): 1072 | """ 1073 | Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a 1074 | [`Mask2FormerPixelDecoderEncoderLayer`]. The encoder updates the flattened multi-scale feature maps through 1075 | multiple deformable attention layers. 1076 | 1077 | Args: 1078 | config: Mask2FormerConfig 1079 | """ 1080 | 1081 | def __init__(self, config: Mask2FormerConfig): 1082 | super().__init__() 1083 | 1084 | self.config = config 1085 | self.dropout = config.dropout 1086 | self.layers = nn.ModuleList( 1087 | [Mask2FormerPixelDecoderEncoderLayer(config) for _ in range(config.encoder_layers)] 1088 | ) 1089 | 1090 | @staticmethod 1091 | def get_reference_points(spatial_shapes, valid_ratios, device): 1092 | """ 1093 | Get reference points for each feature map. Used in decoder. 1094 | 1095 | Args: 1096 | spatial_shapes (`torch.LongTensor`): 1097 | Spatial shapes of each feature map, has shape of `(num_feature_levels, 2)`. 1098 | valid_ratios (`torch.FloatTensor`): 1099 | Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`. 1100 | device (`torch.device`): 1101 | Device on which to create the tensors. 
1102 | Returns: 1103 | `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)` 1104 | """ 1105 | reference_points_list = [] 1106 | for lvl, (height, width) in enumerate(spatial_shapes): 1107 | ref_y, ref_x = torch.meshgrid( 1108 | torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device), 1109 | torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device), 1110 | indexing="ij", 1111 | ) 1112 | ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height) 1113 | ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width) 1114 | ref = torch.stack((ref_x, ref_y), -1) 1115 | reference_points_list.append(ref) 1116 | 1117 | reference_points = torch.cat(reference_points_list, 1) 1118 | reference_points = reference_points[:, :, None] * valid_ratios[:, None] 1119 | 1120 | return reference_points 1121 | 1122 | def forward( 1123 | self, 1124 | inputs_embeds=None, 1125 | attention_mask=None, 1126 | position_embeddings=None, 1127 | spatial_shapes=None, 1128 | level_start_index=None, 1129 | valid_ratios=None, 1130 | output_attentions=None, 1131 | output_hidden_states=None, 1132 | return_dict=None, 1133 | ): 1134 | r""" 1135 | Args: 1136 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 1137 | Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. 1138 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 1139 | Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: 1140 | - 1 for pixel features that are real (i.e. **not masked**), 1141 | - 0 for pixel features that are padding (i.e. **masked**). 1142 | [What are attention masks?](../glossary#attention-mask) 1143 | position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 1144 | Position embeddings that are added to the queries and keys in each self-attention layer. 1145 | spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): 1146 | Spatial shapes of each feature map. 1147 | level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): 1148 | Starting index of each feature map. 1149 | valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): 1150 | Ratio of valid area in each feature level. 1151 | output_attentions (`bool`, *optional*): 1152 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1153 | returned tensors for more detail. 1154 | output_hidden_states (`bool`, *optional*): 1155 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 1156 | for more detail. 1157 | return_dict (`bool`, *optional*): 1158 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
1159 | """ 1160 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1161 | output_hidden_states = ( 1162 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1163 | ) 1164 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1165 | 1166 | hidden_states = inputs_embeds 1167 | reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device) 1168 | 1169 | all_hidden_states = () if output_hidden_states else None 1170 | all_attentions = () if output_attentions else None 1171 | 1172 | for i, encoder_layer in enumerate(self.layers): 1173 | if output_hidden_states: 1174 | all_hidden_states += (hidden_states.transpose(1, 0),) 1175 | 1176 | layer_outputs = encoder_layer( 1177 | hidden_states, 1178 | attention_mask, 1179 | position_embeddings=position_embeddings, 1180 | reference_points=reference_points, 1181 | spatial_shapes=spatial_shapes, 1182 | level_start_index=level_start_index, 1183 | output_attentions=output_attentions, 1184 | ) 1185 | 1186 | hidden_states = layer_outputs[0] 1187 | 1188 | if output_attentions: 1189 | all_attentions = all_attentions + (layer_outputs[1],) 1190 | 1191 | if output_hidden_states: 1192 | all_hidden_states += (hidden_states.transpose(1, 0),) 1193 | 1194 | return BaseModelOutput( 1195 | last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions 1196 | ) 1197 | 1198 | 1199 | # Modified from from transformers.models.detr.modeling_deformable_detr.DeformableDetrModel with DeformableDetrModel->Mask2FormerPixelDecoder 1200 | class Mask2FormerPixelDecoder(nn.Module): 1201 | def __init__(self, config: Mask2FormerConfig, feature_channels): 1202 | super().__init__() 1203 | 1204 | self.config = config 1205 | 1206 | feature_dim = config.feature_size 1207 | mask_dim = config.mask_feature_size 1208 | num_pos_features = feature_dim // 2 1209 | 1210 | self.position_embedding = Mask2FormerSinePositionEmbedding(num_pos_feats=num_pos_features, normalize=True) 1211 | self.num_feature_levels = 3 1212 | transformer_in_channels = feature_channels[-self.num_feature_levels :] 1213 | 1214 | self.transformer_feature_strides = config.feature_strides[-self.num_feature_levels :] 1215 | self.feature_channels = feature_channels 1216 | self.level_embed = nn.Parameter(torch.Tensor(self.num_feature_levels, feature_dim)) 1217 | 1218 | # Create input projection layers 1219 | if self.num_feature_levels > 1: 1220 | input_projections_list = [] 1221 | for in_channels in transformer_in_channels[::-1]: 1222 | input_projections_list.append( 1223 | nn.Sequential( 1224 | nn.Conv2d(in_channels, feature_dim, kernel_size=1), 1225 | nn.GroupNorm(32, feature_dim), 1226 | ) 1227 | ) 1228 | self.input_projections = nn.ModuleList(input_projections_list) 1229 | else: 1230 | self.input_projections = nn.ModuleList( 1231 | [ 1232 | nn.Sequential( 1233 | nn.Conv2d(transformer_in_channels[-1], feature_dim, kernel_size=1), 1234 | nn.GroupNorm(32, feature_dim), 1235 | ) 1236 | ] 1237 | ) 1238 | 1239 | self.encoder = Mask2FormerPixelDecoderEncoderOnly(config) 1240 | self.mask_projection = nn.Conv2d(feature_dim, mask_dim, kernel_size=1, stride=1, padding=0) 1241 | 1242 | # Extra FPN levels 1243 | stride = min(self.transformer_feature_strides) 1244 | self.common_stride = config.common_stride 1245 | self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride)) 1246 | 1247 | lateral_convs = [] 1248 | output_convs = [] 
1249 | 1250 | for idx, in_channels in enumerate(self.feature_channels[: self.num_fpn_levels]): 1251 | lateral_conv = nn.Sequential( 1252 | nn.Conv2d(in_channels, feature_dim, kernel_size=1, bias=False), 1253 | nn.GroupNorm(32, feature_dim), 1254 | ) 1255 | 1256 | output_conv = nn.Sequential( 1257 | nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False), 1258 | nn.GroupNorm(32, feature_dim), 1259 | nn.ReLU(), 1260 | ) 1261 | self.add_module("adapter_{}".format(idx + 1), lateral_conv) 1262 | self.add_module("layer_{}".format(idx + 1), output_conv) 1263 | 1264 | lateral_convs.append(lateral_conv) 1265 | output_convs.append(output_conv) 1266 | 1267 | # Order convolutional layers from low to high resolution 1268 | self.lateral_convolutions = lateral_convs[::-1] 1269 | self.output_convolutions = output_convs[::-1] 1270 | 1271 | def get_valid_ratio(self, mask, dtype=torch.float32): 1272 | """Get the valid ratio of all feature maps.""" 1273 | 1274 | _, height, width = mask.shape 1275 | valid_height = torch.sum(~mask[:, :, 0], 1) 1276 | valid_width = torch.sum(~mask[:, 0, :], 1) 1277 | valid_ratio_heigth = valid_height.to(dtype) / height 1278 | valid_ratio_width = valid_width.to(dtype) / width 1279 | valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1) 1280 | return valid_ratio 1281 | 1282 | def forward( 1283 | self, 1284 | features, 1285 | encoder_outputs=None, 1286 | output_attentions=None, 1287 | output_hidden_states=None, 1288 | return_dict=None, 1289 | ): 1290 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1291 | output_hidden_states = ( 1292 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1293 | ) 1294 | 1295 | # Apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) 1296 | input_embeds = [] 1297 | position_embeddings = [] 1298 | for level, x in enumerate(features[::-1][: self.num_feature_levels]): 1299 | input_embeds.append(self.input_projections[level](x)) 1300 | position_embeddings.append(self.position_embedding(x)) 1301 | 1302 | masks = [ 1303 | torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in input_embeds 1304 | ] 1305 | 1306 | # Prepare encoder inputs (by flattening) 1307 | spatial_shapes = [(embed.shape[2], embed.shape[3]) for embed in input_embeds] 1308 | input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1) 1309 | spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=input_embeds_flat.device) 1310 | masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1) 1311 | 1312 | position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings] 1313 | level_pos_embed_flat = [x + self.level_embed[i].view(1, 1, -1) for i, x in enumerate(position_embeddings)] 1314 | level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1) 1315 | 1316 | level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) 1317 | valid_ratios = torch.stack([self.get_valid_ratio(mask, dtype=input_embeds_flat.dtype) for mask in masks], 1) 1318 | 1319 | # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder 1320 | if encoder_outputs is None: 1321 | encoder_outputs = self.encoder( 1322 | inputs_embeds=input_embeds_flat, 1323 | attention_mask=masks_flat, 1324 | position_embeddings=level_pos_embed_flat, 1325 | 
spatial_shapes=spatial_shapes, 1326 | level_start_index=level_start_index, 1327 | valid_ratios=valid_ratios, 1328 | output_attentions=output_attentions, 1329 | output_hidden_states=output_hidden_states, 1330 | return_dict=return_dict, 1331 | ) 1332 | 1333 | last_hidden_state = encoder_outputs.last_hidden_state 1334 | batch_size = last_hidden_state.shape[0] 1335 | 1336 | split_sizes = [None] * self.num_feature_levels 1337 | for i in range(self.num_feature_levels): 1338 | if i < self.num_feature_levels - 1: 1339 | split_sizes[i] = level_start_index[i + 1] - level_start_index[i] 1340 | else: 1341 | split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i] 1342 | 1343 | encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1) 1344 | 1345 | # Compute final features 1346 | outputs = [ 1347 | x.transpose(1, 2).view(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) 1348 | for i, x in enumerate(encoder_output) 1349 | ] 1350 | 1351 | # Append extra FPN levels to outputs, ordered from low to high resolution 1352 | for idx, feature in enumerate(features[: self.num_fpn_levels][::-1]): 1353 | lateral_conv = self.lateral_convolutions[idx] 1354 | output_conv = self.output_convolutions[idx] 1355 | current_fpn = lateral_conv(feature) 1356 | 1357 | # Following FPN implementation, we use nearest upsampling here 1358 | out = current_fpn + nn.functional.interpolate( 1359 | outputs[-1], size=current_fpn.shape[-2:], mode="bilinear", align_corners=False 1360 | ) 1361 | out = output_conv(out) 1362 | outputs.append(out) 1363 | 1364 | num_cur_levels = 0 1365 | multi_scale_features = [] 1366 | 1367 | for out in outputs: 1368 | if num_cur_levels < self.num_feature_levels: 1369 | multi_scale_features.append(out) 1370 | num_cur_levels += 1 1371 | 1372 | return Mask2FormerPixelDecoderOutput( 1373 | mask_features=self.mask_projection(outputs[-1]), 1374 | multi_scale_features=tuple(multi_scale_features), 1375 | attentions=encoder_outputs.attentions, 1376 | ) 1377 | 1378 | 1379 | class Mask2FormerPixelLevelModule(nn.Module): 1380 | def __init__(self, config: Mask2FormerConfig): 1381 | """ 1382 | Pixel Level Module proposed in [Masked-attention Mask Transformer for Universal Image 1383 | Segmentation](https://arxiv.org/abs/2112.01527). It runs the input image through a backbone and a pixel 1384 | decoder, generating multi-scale feature maps and pixel embeddings. 1385 | 1386 | Args: 1387 | config ([`Mask2FormerConfig`]): 1388 | The configuration used to instantiate this model. 
1389 | """ 1390 | super().__init__() 1391 | 1392 | self.encoder = AutoBackbone.from_config(config.backbone_config) 1393 | self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels) 1394 | 1395 | def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput: 1396 | backbone_features = self.encoder(pixel_values).feature_maps 1397 | decoder_output = self.decoder(backbone_features, output_hidden_states=output_hidden_states) 1398 | 1399 | return Mask2FormerPixelLevelModuleOutput( 1400 | encoder_last_hidden_state=backbone_features[-1], 1401 | encoder_hidden_states=tuple(backbone_features) if output_hidden_states else None, 1402 | decoder_last_hidden_state=decoder_output.mask_features, 1403 | decoder_hidden_states=decoder_output.multi_scale_features, 1404 | ) 1405 | 1406 | 1407 | # Modified from transformers.models.detr.modeling_detr.DetrAttention with Detr->Mask2Former 1408 | class Mask2FormerAttention(nn.Module): 1409 | """ 1410 | Multi-headed attention from 'Attention Is All You Need' paper. Here, we add position embeddings to the queries and 1411 | keys (as explained in the DETR paper). 1412 | """ 1413 | 1414 | def __init__( 1415 | self, 1416 | embed_dim: int, 1417 | num_heads: int, 1418 | dropout: float = 0.0, 1419 | is_decoder: bool = False, 1420 | bias: bool = True, 1421 | ): 1422 | super().__init__() 1423 | self.embed_dim = embed_dim 1424 | self.num_heads = num_heads 1425 | self.dropout = dropout 1426 | self.head_dim = embed_dim // num_heads 1427 | if self.head_dim * num_heads != self.embed_dim: 1428 | raise ValueError( 1429 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" 1430 | f" {num_heads})." 1431 | ) 1432 | self.scaling = self.head_dim**-0.5 1433 | 1434 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 1435 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 1436 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 1437 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 1438 | 1439 | def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): 1440 | return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 1441 | 1442 | def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): 1443 | return tensor if position_embeddings is None else tensor + position_embeddings 1444 | 1445 | def forward( 1446 | self, 1447 | hidden_states: torch.Tensor, 1448 | attention_mask: Optional[torch.Tensor] = None, 1449 | position_embeddings: Optional[torch.Tensor] = None, 1450 | key_value_states: Optional[torch.Tensor] = None, 1451 | key_value_position_embeddings: Optional[torch.Tensor] = None, 1452 | output_attentions: bool = False, 1453 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 1454 | """Input shape: Batch x Time x Channel""" 1455 | 1456 | hidden_states = hidden_states.permute(1, 0, 2) if hidden_states is not None else None 1457 | position_embeddings = position_embeddings.permute(1, 0, 2) if position_embeddings is not None else None 1458 | key_value_states = key_value_states.permute(1, 0, 2) if key_value_states is not None else None 1459 | key_value_position_embeddings = ( 1460 | key_value_position_embeddings.permute(1, 0, 2) if key_value_position_embeddings is not None else None 1461 | ) 1462 | 1463 | # if key_value_states are provided this layer is used as a cross-attention layer 1464 | # for the decoder 1465 | 
is_cross_attention = key_value_states is not None 1466 | batch_size, target_len, embed_dim = hidden_states.size() 1467 | 1468 | # add position embeddings to the hidden states before projecting to queries and keys 1469 | if position_embeddings is not None: 1470 | hidden_states_original = hidden_states 1471 | hidden_states = self.with_pos_embed(hidden_states, position_embeddings) 1472 | 1473 | # add key-value position embeddings to the key value states 1474 | if key_value_position_embeddings is not None: 1475 | key_value_states_original = key_value_states 1476 | key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings) 1477 | 1478 | # get query proj 1479 | query_states = self.q_proj(hidden_states) * self.scaling 1480 | # get key, value proj 1481 | if is_cross_attention: 1482 | # cross_attentions 1483 | key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) 1484 | value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) 1485 | else: 1486 | # self_attention 1487 | key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) 1488 | value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) 1489 | 1490 | proj_shape = (batch_size * self.num_heads, -1, self.head_dim) 1491 | query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) 1492 | key_states = key_states.view(*proj_shape) 1493 | value_states = value_states.view(*proj_shape) 1494 | 1495 | source_len = key_states.size(1) 1496 | 1497 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 1498 | 1499 | if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): 1500 | raise ValueError( 1501 | f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" 1502 | f" {attn_weights.size()}" 1503 | ) 1504 | 1505 | if attention_mask is not None: 1506 | if attention_mask.size() != (batch_size * self.num_heads, target_len, source_len): 1507 | raise ValueError( 1508 | f"Attention mask should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" 1509 | f" {attention_mask.size()}" 1510 | ) 1511 | attn_weights += attention_mask 1512 | 1513 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 1514 | 1515 | if output_attentions: 1516 | # this operation is a bit awkward, but it's required to 1517 | # make sure that attn_weights keeps its gradient.
1518 | # In order to do so, attn_weights have to be reshaped 1519 | # twice and have to be reused in the following 1520 | attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) 1521 | attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) 1522 | else: 1523 | attn_weights_reshaped = None 1524 | 1525 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 1526 | 1527 | attn_output = torch.bmm(attn_probs, value_states) 1528 | 1529 | if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): 1530 | raise ValueError( 1531 | f"`attn_output` should be of size {(batch_size * self.num_heads, target_len, self.head_dim)}, but is" 1532 | f" {attn_output.size()}" 1533 | ) 1534 | 1535 | attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) 1536 | attn_output = attn_output.transpose(1, 2) 1537 | attn_output = attn_output.reshape(batch_size, target_len, embed_dim) 1538 | 1539 | attn_output = self.out_proj(attn_output).permute(1, 0, 2) 1540 | 1541 | return attn_output, attn_weights_reshaped 1542 | 1543 | 1544 | class Mask2FormerMaskedAttentionDecoderLayer(nn.Module): 1545 | """ 1546 | The Mask2FormerMaskedAttentionDecoderLayer is made up of self-attention, cross (masked) attention as well as FFN 1547 | blocks. The cross attention block used as part of `Mask2FormerMaskedAttentionDecoderLayer` is actually a `masked 1548 | attention` block that restricts the attention to localized features centered around predicted segments which leads 1549 | to faster convergence and improved performance. The order of self and cross (i.e. masked) attention blocks has 1550 | also been swapped in Mask2FormerMaskedAttentionDecoder compared to a standard DetrDecoder as an optimization 1551 | improvement. 1552 | 1553 | Args: 1554 | config (`Mask2FormerConfig`): 1555 | The configuration used to initialize the Mask2FormerMaskedAttentionDecoder.
1556 | """ 1557 | 1558 | def __init__(self, config: Mask2FormerConfig): 1559 | super().__init__() 1560 | self.config = config 1561 | self.embed_dim = self.config.hidden_dim 1562 | self.pre_norm = self.config.pre_norm 1563 | self.self_attn = Mask2FormerAttention( 1564 | embed_dim=self.embed_dim, 1565 | num_heads=config.num_attention_heads, 1566 | dropout=config.dropout, 1567 | is_decoder=True, 1568 | ) 1569 | 1570 | self.dropout = self.config.dropout 1571 | self.activation_fn = ACT2FN[self.config.activation_function] 1572 | self.activation_dropout = self.config.dropout 1573 | 1574 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 1575 | self.cross_attn = nn.MultiheadAttention(self.embed_dim, self.config.num_attention_heads, self.config.dropout) 1576 | self.cross_attn_layer_norm = nn.LayerNorm(self.embed_dim) 1577 | self.fc1 = nn.Linear(self.embed_dim, self.config.dim_feedforward) 1578 | self.fc2 = nn.Linear(self.config.dim_feedforward, self.embed_dim) 1579 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 1580 | 1581 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 1582 | return tensor if pos is None else tensor + pos 1583 | 1584 | def forward_post( 1585 | self, 1586 | hidden_states: torch.Tensor, 1587 | level_index: int = None, 1588 | attention_mask: Optional[torch.Tensor] = None, 1589 | position_embeddings: Optional[torch.Tensor] = None, 1590 | query_position_embeddings: Optional[torch.Tensor] = None, 1591 | encoder_hidden_states: Optional[torch.Tensor] = None, 1592 | encoder_attention_mask: Optional[torch.Tensor] = None, 1593 | output_attentions: Optional[bool] = False, 1594 | ): 1595 | # Masked(Cross)-Attention Block 1596 | cross_attn_weights = None 1597 | self_attn_weights = None 1598 | 1599 | residual = hidden_states 1600 | 1601 | hidden_states, cross_attn_weights = self.cross_attn( 1602 | query=self.with_pos_embed(hidden_states, query_position_embeddings), 1603 | key=self.with_pos_embed(encoder_hidden_states[level_index], position_embeddings[level_index]), 1604 | value=encoder_hidden_states[level_index], 1605 | attn_mask=encoder_attention_mask, 1606 | key_padding_mask=None, 1607 | ) 1608 | 1609 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1610 | hidden_states = residual + hidden_states 1611 | hidden_states = self.cross_attn_layer_norm(hidden_states) 1612 | 1613 | # Self Attention Block 1614 | residual = hidden_states 1615 | 1616 | hidden_states, self_attn_weights = self.self_attn( 1617 | hidden_states=hidden_states, 1618 | position_embeddings=query_position_embeddings, 1619 | attention_mask=None, 1620 | output_attentions=True, 1621 | ) 1622 | 1623 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1624 | hidden_states = residual + hidden_states 1625 | hidden_states = self.self_attn_layer_norm(hidden_states) 1626 | 1627 | # Fully Connected 1628 | residual = hidden_states 1629 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 1630 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 1631 | hidden_states = self.fc2(hidden_states) 1632 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1633 | hidden_states = residual + hidden_states 1634 | hidden_states = self.final_layer_norm(hidden_states) 1635 | 1636 | outputs = (hidden_states,) 1637 | 1638 | if output_attentions: 1639 | outputs += (self_attn_weights, cross_attn_weights) 1640 | 1641 | return outputs 1642 | 1643 
| def forward_pre( 1644 | self, 1645 | hidden_states: torch.Tensor, 1646 | level_index: int = None, 1647 | attention_mask: Optional[torch.Tensor] = None, 1648 | position_embeddings: Optional[torch.Tensor] = None, 1649 | query_position_embeddings: Optional[torch.Tensor] = None, 1650 | encoder_hidden_states: Optional[torch.Tensor] = None, 1651 | encoder_attention_mask: Optional[torch.Tensor] = None, 1652 | output_attentions: Optional[bool] = False, 1653 | ): 1654 | # Masked(Cross)-Attention Block 1655 | cross_attn_weights = None 1656 | self_attn_weights = None 1657 | 1658 | residual = hidden_states 1659 | 1660 | hidden_states = self.cross_attn_layer_norm(hidden_states) 1661 | 1662 | hidden_states, cross_attn_weights = self.cross_attn( 1663 | query=self.with_pos_embed(hidden_states, query_position_embeddings), 1664 | key=self.with_pos_embed(encoder_hidden_states[level_index], position_embeddings[level_index]), 1665 | value=encoder_hidden_states[level_index], 1666 | attn_mask=encoder_attention_mask, 1667 | key_padding_mask=None, 1668 | ) 1669 | 1670 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1671 | hidden_states = residual + hidden_states 1672 | 1673 | # Self Attention Block 1674 | residual = hidden_states 1675 | 1676 | hidden_states = self.self_attn_layer_norm(hidden_states) 1677 | 1678 | hidden_states, self_attn_weights = self.self_attn( 1679 | hidden_states=hidden_states, 1680 | position_embeddings=query_position_embeddings, 1681 | attention_mask=None, 1682 | output_attentions=True, 1683 | ) 1684 | 1685 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1686 | hidden_states = residual + hidden_states 1687 | 1688 | # Fully Connected 1689 | residual = hidden_states 1690 | hidden_states = self.final_layer_norm(hidden_states) 1691 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 1692 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 1693 | hidden_states = self.fc2(hidden_states) 1694 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1695 | hidden_states = residual + hidden_states 1696 | 1697 | outputs = (hidden_states,) 1698 | 1699 | if output_attentions: 1700 | outputs += (self_attn_weights, cross_attn_weights) 1701 | 1702 | return outputs 1703 | 1704 | def forward( 1705 | self, 1706 | hidden_states: torch.Tensor, 1707 | level_index: int = None, 1708 | attention_mask: Optional[torch.Tensor] = None, 1709 | position_embeddings: Optional[torch.Tensor] = None, 1710 | query_position_embeddings: Optional[torch.Tensor] = None, 1711 | encoder_hidden_states: Optional[torch.Tensor] = None, 1712 | encoder_attention_mask: Optional[torch.Tensor] = None, 1713 | output_attentions: Optional[bool] = False, 1714 | ): 1715 | """ 1716 | Args: 1717 | hidden_states (`torch.FloatTensor`): 1718 | Input to the layer of shape `(seq_len, batch, embed_dim)`. 1719 | attention_mask (`torch.FloatTensor`): 1720 | Attention mask of shape `(1, seq_len, tgt_len, src_len)`. 1721 | position_embeddings (`torch.FloatTensor`, *optional*): 1722 | Position embeddings that are added to the keys in the masked-attention layer. 1723 | query_position_embeddings (`torch.FloatTensor`, *optional*): 1724 | Position embeddings that are added to the queries and keys in the self-attention layer. 1725 | encoder_hidden_states (`torch.FloatTensor`): 1726 | Cross attention input to the layer of shape `(seq_len, batch, embed_dim)`. 
1727 | encoder_attention_mask (`torch.FloatTensor`): 1728 | Encoder attention mask of size`(1, seq_len, tgt_len, src_len)`. 1729 | output_attentions (`bool`, *optional*): 1730 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1731 | returned tensors for more detail. 1732 | """ 1733 | 1734 | if self.pre_norm: 1735 | outputs = self.forward_pre( 1736 | hidden_states=hidden_states, 1737 | level_index=level_index, 1738 | position_embeddings=position_embeddings, 1739 | query_position_embeddings=query_position_embeddings, 1740 | encoder_hidden_states=encoder_hidden_states, 1741 | encoder_attention_mask=encoder_attention_mask, 1742 | output_attentions=output_attentions, 1743 | ) 1744 | else: 1745 | outputs = self.forward_post( 1746 | hidden_states=hidden_states, 1747 | level_index=level_index, 1748 | position_embeddings=position_embeddings, 1749 | query_position_embeddings=query_position_embeddings, 1750 | encoder_hidden_states=encoder_hidden_states, 1751 | encoder_attention_mask=encoder_attention_mask, 1752 | output_attentions=output_attentions, 1753 | ) 1754 | 1755 | return outputs 1756 | 1757 | 1758 | class Mask2FormerMaskedAttentionDecoder(nn.Module): 1759 | """ 1760 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a 1761 | [`Mask2FormerMaskedAttentionDecoderLayer`]. The decoder updates the query embeddings through multiple cross 1762 | (masked) and self-attention layers. The decoder uses a new **masked attention** mechanism instead of the standard 1763 | cross-attention, which extracts localized features by constraining cross-attention to within the foreground region 1764 | of the predicted mask for each query, instead of attending to the full feature map. 1765 | 1766 | Args: 1767 | config (`Mask2FormerConfig`): 1768 | Configuration used to instantiate Mask2FormerMaskedAttentionDecoder. 1769 | """ 1770 | 1771 | def __init__(self, config: Mask2FormerConfig): 1772 | super().__init__() 1773 | 1774 | self.config = config 1775 | self.mask_feature_size = config.mask_feature_size 1776 | self.dropout = config.dropout 1777 | self.layerdrop = config.dropout 1778 | self.num_feature_levels = 3 # level embedding (3 scales) 1779 | self.decoder_layers = config.decoder_layers - 1 1780 | 1781 | self.layers = nn.ModuleList( 1782 | [Mask2FormerMaskedAttentionDecoderLayer(self.config) for _ in range(self.decoder_layers)] 1783 | ) 1784 | self.layernorm = nn.LayerNorm(config.hidden_dim) 1785 | 1786 | self.mask_predictor = Mask2FormerMaskPredictor( 1787 | hidden_size=config.hidden_dim, 1788 | num_heads=config.num_attention_heads, 1789 | mask_feature_size=self.mask_feature_size, 1790 | ) 1791 | 1792 | self.gradient_checkpointing = False 1793 | 1794 | def forward( 1795 | self, 1796 | inputs_embeds: torch.Tensor = None, 1797 | multi_stage_positional_embeddings: torch.Tensor = None, 1798 | pixel_embeddings: torch.Tensor = None, 1799 | encoder_hidden_states: torch.Tensor = None, 1800 | query_position_embeddings: torch.Tensor = None, 1801 | feature_size_list: List = None, 1802 | output_attentions: Optional[bool] = None, 1803 | output_hidden_states: Optional[bool] = None, 1804 | return_dict: Optional[bool] = None, 1805 | ): 1806 | r""" 1807 | Args: 1808 | inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`): 1809 | The query embeddings that are passed into the decoder. 
1810 | multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`): 1811 | Position embeddings that are added to the keys in each cross(masked)-attention layer. 1812 | pixel_embeddings (`torch.FloatTensor`): 1813 | Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel 1814 | Decoder. 1815 | query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`): 1816 | , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer. 1817 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`): 1818 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the 1819 | cross(masked)-attention of the decoder. 1820 | feature_size_list (`List[torch.Size]` ): 1821 | This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder. 1822 | output_attentions (`bool`, *optional*): 1823 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1824 | returned tensors for more detail. 1825 | output_hidden_states (`bool`, *optional*): 1826 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 1827 | for more detail. 1828 | return_dict (`bool`, *optional*): 1829 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 1830 | """ 1831 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1832 | output_hidden_states = ( 1833 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1834 | ) 1835 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1836 | 1837 | if inputs_embeds is not None: 1838 | hidden_states = inputs_embeds 1839 | 1840 | # intermediate hidden states with layernorm applied - required for predicting class logits 1841 | intermediate = () 1842 | 1843 | # decoder layers 1844 | all_hidden_states = () if output_hidden_states else None 1845 | attentions = () if output_attentions else None 1846 | 1847 | # intermediate mask predictions from transformer decoder layers 1848 | intermediate_mask_predictions = () 1849 | 1850 | intermediate_hidden_states = self.layernorm(inputs_embeds) 1851 | intermediate += (intermediate_hidden_states,) 1852 | 1853 | predicted_mask, attention_mask = self.mask_predictor( 1854 | intermediate_hidden_states, pixel_embeddings, feature_size_list[0] 1855 | ) 1856 | intermediate_mask_predictions += (predicted_mask,) 1857 | 1858 | for idx, decoder_layer in enumerate(self.layers): 1859 | if output_hidden_states: 1860 | all_hidden_states += (hidden_states,) 1861 | 1862 | dropout_probability = torch.rand([]) 1863 | 1864 | if self.training and (dropout_probability < self.layerdrop): 1865 | continue 1866 | 1867 | if self.gradient_checkpointing and self.training: 1868 | 1869 | def create_custom_forward(module): 1870 | def custom_forward(*inputs): 1871 | return module(*inputs, output_attentions) 1872 | 1873 | return custom_forward 1874 | 1875 | layer_outputs = torch.utils.checkpoint.checkpoint( 1876 | create_custom_forward(decoder_layer), 1877 | hidden_states, 1878 | attention_mask, 1879 | encoder_hidden_states, 1880 | None, 1881 | None, 1882 | ) 1883 | 1884 | else: 1885 | level_index = idx % self.num_feature_levels 1886 | 1887 | attention_mask[torch.where(attention_mask.sum(-1) == 
attention_mask.shape[-1])] = False 1888 | 1889 | layer_outputs = decoder_layer( 1890 | hidden_states, 1891 | level_index=level_index, 1892 | position_embeddings=multi_stage_positional_embeddings, 1893 | query_position_embeddings=query_position_embeddings, 1894 | encoder_hidden_states=encoder_hidden_states, 1895 | encoder_attention_mask=attention_mask, 1896 | output_attentions=output_attentions, 1897 | ) 1898 | 1899 | intermediate_hidden_states = self.layernorm(layer_outputs[0]) 1900 | 1901 | predicted_mask, attention_mask = self.mask_predictor( 1902 | intermediate_hidden_states, 1903 | pixel_embeddings, 1904 | feature_size_list[(idx + 1) % self.num_feature_levels], 1905 | ) 1906 | 1907 | intermediate_mask_predictions += (predicted_mask,) 1908 | 1909 | # add intermediate hidden states with layer norm applied which will be used for predicting class logits 1910 | intermediate += (intermediate_hidden_states,) 1911 | 1912 | hidden_states = layer_outputs[0] 1913 | 1914 | if output_attentions: 1915 | attentions += (layer_outputs[1],) 1916 | 1917 | # add hidden states from the last decoder layer 1918 | if output_hidden_states: 1919 | all_hidden_states += (hidden_states,) 1920 | 1921 | hidden_states = hidden_states.transpose(1, 0) 1922 | if not return_dict: 1923 | outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions] 1924 | return tuple(v for v in outputs if v is not None) 1925 | 1926 | return Mask2FormerMaskedAttentionDecoderOutput( 1927 | last_hidden_state=hidden_states, 1928 | hidden_states=all_hidden_states, 1929 | attentions=attentions, 1930 | intermediate_hidden_states=intermediate, 1931 | masks_queries_logits=intermediate_mask_predictions, 1932 | ) 1933 | 1934 | 1935 | # Copied from transformers.models.maskformer.modeling_maskformer.PredictionBlock with MaskFormer->Mask2Former 1936 | class Mask2FormerPredictionBlock(nn.Module): 1937 | def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None: 1938 | super().__init__() 1939 | self.layers = [nn.Linear(in_dim, out_dim), activation] 1940 | # Maintain submodule indexing as if part of a Sequential block 1941 | for i, layer in enumerate(self.layers): 1942 | self.add_module(str(i), layer) 1943 | 1944 | def forward(self, input: Tensor) -> Tensor: 1945 | hidden_state = input 1946 | for layer in self.layers: 1947 | hidden_state = layer(hidden_state) 1948 | return hidden_state 1949 | 1950 | 1951 | class Mask2FormerMLPPredictionHead(nn.Module): 1952 | def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3): 1953 | """ 1954 | A classic Multi Layer Perceptron (MLP). 1955 | 1956 | Args: 1957 | input_dim (`int`): 1958 | The input dimensions. 1959 | hidden_dim (`int`): 1960 | The hidden dimensions. 1961 | output_dim (`int`): 1962 | The output dimensions. 1963 | num_layers (int, *optional*, defaults to 3): 1964 | The number of layers. 1965 | """ 1966 | super().__init__() 1967 | in_dims = [input_dim] + [hidden_dim] * (num_layers - 1) 1968 | out_dims = [hidden_dim] * (num_layers - 1) + [output_dim] 1969 | 1970 | self.layers = [] 1971 | for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)): 1972 | activation = nn.ReLU() if i < num_layers - 1 else nn.Identity() 1973 | layer = Mask2FormerPredictionBlock(in_dim, out_dim, activation=activation) 1974 | self.layers.append(layer) 1975 | # Provide backwards compatibility from when the class inherited from nn.Sequential 1976 | # In nn.Sequential subclasses, the name given to the layer is its index in the sequence. 
1977 | # In nn.Module subclasses they derived from the instance attribute they are assigned to e.g. 1978 | # self.my_layer_name = Layer() 1979 | # We can't give instance attributes integer names i.e. self.0 is not permitted and so need to register 1980 | # explicitly 1981 | self.add_module(str(i), layer) 1982 | 1983 | def forward(self, input: Tensor) -> Tensor: 1984 | hidden_state = input 1985 | for layer in self.layers: 1986 | hidden_state = layer(hidden_state) 1987 | return hidden_state 1988 | 1989 | 1990 | class Mask2FormerMaskPredictor(nn.Module): 1991 | def __init__(self, hidden_size: int, num_heads: int, mask_feature_size: torch.Tensor): 1992 | """ 1993 | This class is used to get the predicted mask for a given Mask2FormerMaskedAttentionDecoder layer. It also 1994 | generates the binarized attention mask associated with the given predicted mask. The attention mask obtained 1995 | using predicted mask of the (l-1)th decoder layer is fed to the cross(masked)-attention block of the next 1996 | decoder layer as input. 1997 | 1998 | Args: 1999 | hidden_size (`int`): 2000 | The feature dimension of the Mask2FormerMaskedAttentionDecoder 2001 | num_heads (`int`): 2002 | The number of heads used in the Mask2FormerMaskedAttentionDecoder 2003 | mask_feature_size (`torch.Tensor`): 2004 | one of the output dimensions of the predicted masks for each query 2005 | """ 2006 | super().__init__() 2007 | self.hidden_size = hidden_size 2008 | self.num_heads = num_heads 2009 | 2010 | self.mask_embedder = Mask2FormerMLPPredictionHead(self.hidden_size, self.hidden_size, mask_feature_size) 2011 | 2012 | def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None): 2013 | mask_embeddings = self.mask_embedder(outputs.transpose(0, 1)) 2014 | 2015 | # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly 2016 | batch_size, num_queries, num_channels = mask_embeddings.shape 2017 | _, _, height, width = pixel_embeddings.shape 2018 | outputs_mask = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device) 2019 | for c in range(num_channels): 2020 | outputs_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c] 2021 | 2022 | attention_mask = nn.functional.interpolate( 2023 | outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False 2024 | ) 2025 | 2026 | attention_mask = attention_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1) 2027 | attention_mask = (attention_mask.flatten(0, 1) < 0.5).bool() 2028 | attention_mask = attention_mask.detach() 2029 | 2030 | return outputs_mask, attention_mask 2031 | 2032 | 2033 | class Mask2FormerTransformerModule(nn.Module): 2034 | """ 2035 | The Mask2Former's transformer module. 
2036 | """ 2037 | 2038 | def __init__(self, in_features: int, config: Mask2FormerConfig): 2039 | super().__init__() 2040 | hidden_dim = config.hidden_dim 2041 | self.num_feature_levels = 3 2042 | self.position_embedder = Mask2FormerSinePositionEmbedding(num_pos_feats=hidden_dim // 2, normalize=True) 2043 | self.queries_embedder = nn.Embedding(config.num_queries, hidden_dim) 2044 | self.queries_features = nn.Embedding(config.num_queries, hidden_dim) 2045 | self.input_projections = [] 2046 | 2047 | for _ in range(self.num_feature_levels): 2048 | if in_features != hidden_dim or config.enforce_input_projection: 2049 | self.input_projections.append(nn.Conv2d(in_features, hidden_dim, kernel_size=1)) 2050 | else: 2051 | self.input_projections.append(nn.Sequential()) 2052 | 2053 | self.decoder = Mask2FormerMaskedAttentionDecoder(config=config) 2054 | self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) 2055 | 2056 | def forward( 2057 | self, 2058 | multi_scale_features: List[Tensor], 2059 | mask_features: Tensor, 2060 | output_hidden_states: bool = False, 2061 | output_attentions: bool = False, 2062 | ) -> Mask2FormerMaskedAttentionDecoderOutput: 2063 | multi_stage_features = [] 2064 | multi_stage_positional_embeddings = [] 2065 | size_list = [] 2066 | 2067 | for i in range(self.num_feature_levels): 2068 | size_list.append(multi_scale_features[i].shape[-2:]) 2069 | multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2)) 2070 | multi_stage_features.append( 2071 | self.input_projections[i](multi_scale_features[i]).flatten(2) 2072 | + self.level_embed.weight[i][None, :, None] 2073 | ) 2074 | 2075 | # Flatten (batch_size, num_channels, height, width) -> (height*width, batch_size, num_channels) 2076 | multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1) 2077 | multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1) 2078 | 2079 | _, batch_size, _ = multi_stage_features[0].shape 2080 | 2081 | # [num_queries, batch_size, num_channels] 2082 | query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1) 2083 | query_features = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1) 2084 | 2085 | decoder_output = self.decoder( 2086 | inputs_embeds=query_features, 2087 | multi_stage_positional_embeddings=multi_stage_positional_embeddings, 2088 | pixel_embeddings=mask_features, 2089 | encoder_hidden_states=multi_stage_features, 2090 | query_position_embeddings=query_embeddings, 2091 | feature_size_list=size_list, 2092 | output_hidden_states=output_hidden_states, 2093 | output_attentions=output_attentions, 2094 | return_dict=True, 2095 | ) 2096 | 2097 | return decoder_output 2098 | 2099 | 2100 | MASK2FORMER_START_DOCSTRING = r""" 2101 | This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use 2102 | it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and 2103 | behavior. 2104 | 2105 | Parameters: 2106 | config ([`Mask2FormerConfig`]): Model configuration class with all the parameters of the model. 2107 | Initializing with a config file does not load the weights associated with the model, only the 2108 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
2109 | """ 2110 | 2111 | MASK2FORMER_INPUTS_DOCSTRING = r""" 2112 | Args: 2113 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 2114 | Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See 2115 | [`AutoImageProcessor.preprocess`] for details. 2116 | pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): 2117 | Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`: 2118 | 2119 | - 1 for pixels that are real (i.e. **not masked**), 2120 | - 0 for pixels that are padding (i.e. **masked**). 2121 | 2122 | [What are attention masks?](../glossary#attention-mask) 2123 | output_hidden_states (`bool`, *optional*): 2124 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 2125 | more detail. 2126 | output_attentions (`bool`, *optional*): 2127 | Whether or not to return the attentions tensors of Detr's decoder attention layers. 2128 | return_dict (`bool`, *optional*): 2129 | Whether or not to return a [`~Mask2FormerModelOutput`] instead of a plain tuple. 2130 | """ 2131 | 2132 | 2133 | class Mask2FormerPreTrainedModel(PreTrainedModel): 2134 | config_class = Mask2FormerConfig 2135 | base_model_prefix = "model" 2136 | main_input_name = "pixel_values" 2137 | 2138 | def _init_weights(self, module: nn.Module): 2139 | xavier_std = self.config.init_xavier_std 2140 | std = self.config.init_std 2141 | 2142 | if isinstance(module, Mask2FormerTransformerModule): 2143 | if module.input_projections is not None: 2144 | for input_projection in module.input_projections: 2145 | if not isinstance(input_projection, nn.Sequential): 2146 | nn.init.xavier_uniform_(input_projection.weight, gain=xavier_std) 2147 | nn.init.constant_(input_projection.bias, 0) 2148 | 2149 | elif isinstance(module, Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention): 2150 | nn.init.constant_(module.sampling_offsets.weight.data, 0.0) 2151 | thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads) 2152 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 2153 | grid_init = ( 2154 | (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) 2155 | .view(module.n_heads, 1, 1, 2) 2156 | .repeat(1, module.n_levels, module.n_points, 1) 2157 | ) 2158 | for i in range(module.n_points): 2159 | grid_init[:, :, i, :] *= i + 1 2160 | with torch.no_grad(): 2161 | module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 2162 | 2163 | nn.init.constant_(module.attention_weights.weight.data, 0.0) 2164 | nn.init.constant_(module.attention_weights.bias.data, 0.0) 2165 | nn.init.xavier_uniform_(module.value_proj.weight.data) 2166 | nn.init.constant_(module.value_proj.bias.data, 0.0) 2167 | nn.init.xavier_uniform_(module.output_proj.weight.data) 2168 | nn.init.constant_(module.output_proj.bias.data, 0.0) 2169 | 2170 | elif isinstance(module, Mask2FormerMaskedAttentionDecoderLayer): 2171 | for p in module.parameters(): 2172 | if p.dim() > 1: 2173 | nn.init.xavier_uniform_(p, gain=xavier_std) 2174 | 2175 | elif isinstance(module, Mask2FormerPixelLevelModule): 2176 | for submodule in module.modules(): 2177 | if isinstance(submodule, (nn.Conv2d, nn.Linear)): 2178 | submodule.weight.data.normal_(mean=0.0, std=std) 2179 | if submodule.bias is not None: 2180 | submodule.bias.data.zero_() 2181 | 2182 | elif isinstance(module, Mask2FormerPixelDecoder): 2183 | for p in module.parameters(): 2184 | if p.dim() > 1: 2185 | 
nn.init.xavier_uniform_(p) 2186 | nn.init.normal_(module.level_embed, std=0) 2187 | 2188 | elif isinstance(module, Mask2FormerPixelDecoderEncoderOnly): 2189 | for p in module.parameters(): 2190 | if p.dim() > 1: 2191 | nn.init.xavier_uniform_(p) 2192 | 2193 | elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): 2194 | module.weight.data.normal_(mean=0.0, std=std) 2195 | if module.bias is not None: 2196 | module.bias.data.zero_() 2197 | 2198 | elif isinstance(module, nn.Embedding): 2199 | module.weight.data.normal_(mean=0.0, std=std) 2200 | if module.padding_idx is not None: 2201 | module.weight.data[module.padding_idx].zero_() 2202 | 2203 | if hasattr(module, "reference_points"): 2204 | nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) 2205 | nn.init.constant_(module.reference_points.bias.data, 0.0) 2206 | 2207 | 2208 | @add_start_docstrings( 2209 | "The bare Mask2Former Model outputting raw hidden-states without any specific head on top.", 2210 | MASK2FORMER_START_DOCSTRING, 2211 | ) 2212 | class Mask2FormerModel(Mask2FormerPreTrainedModel): 2213 | main_input_name = "pixel_values" 2214 | 2215 | def __init__(self, config: Mask2FormerConfig): 2216 | super().__init__(config) 2217 | self.pixel_level_module = Mask2FormerPixelLevelModule(config) 2218 | self.transformer_module = Mask2FormerTransformerModule(in_features=config.feature_size, config=config) 2219 | 2220 | self.post_init() 2221 | 2222 | @add_start_docstrings_to_model_forward(MASK2FORMER_INPUTS_DOCSTRING) 2223 | @replace_return_docstrings(output_type=Mask2FormerModelOutput, config_class=_CONFIG_FOR_DOC) 2224 | def forward( 2225 | self, 2226 | pixel_values: Tensor, 2227 | pixel_mask: Optional[Tensor] = None, 2228 | output_hidden_states: Optional[bool] = None, 2229 | output_attentions: Optional[bool] = None, 2230 | return_dict: Optional[bool] = None, 2231 | ) -> Mask2FormerModelOutput: 2232 | r""" 2233 | Returns: 2234 | `Mask2FormerModelOutput` 2235 | 2236 | Examples: 2237 | ```python 2238 | >>> import torch 2239 | >>> from PIL import Image 2240 | >>> import requests 2241 | >>> from transformers import AutoImageProcessor, Mask2FormerModel 2242 | 2243 | >>> # load image 2244 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" 2245 | >>> image = Image.open(requests.get(url, stream=True).raw) 2246 | 2247 | >>> # load image preprocessor and Mask2FormerModel trained on COCO instance segmentation dataset 2248 | >>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance") 2249 | >>> model = Mask2FormerModel.from_pretrained("facebook/mask2former-swin-small-coco-instance") 2250 | >>> inputs = image_processor(image, return_tensors="pt") 2251 | 2252 | >>> # forward pass 2253 | >>> with torch.no_grad(): 2254 | ... 
outputs = model(**inputs) 2255 | 2256 | >>> # model outputs last hidden states of shape (batch_size, num_queries, hidden_size) 2257 | >>> print(outputs.transformer_decoder_last_hidden_state.shape) 2258 | torch.Size([1, 100, 256]) 2259 | ``` 2260 | """ 2261 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 2262 | output_hidden_states = ( 2263 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 2264 | ) 2265 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 2266 | 2267 | batch_size, _, height, width = pixel_values.shape 2268 | 2269 | if pixel_mask is None: 2270 | pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) 2271 | 2272 | pixel_level_module_output = self.pixel_level_module( 2273 | pixel_values=pixel_values, output_hidden_states=output_hidden_states 2274 | ) 2275 | 2276 | transformer_module_output = self.transformer_module( 2277 | multi_scale_features=pixel_level_module_output.decoder_hidden_states, 2278 | mask_features=pixel_level_module_output.decoder_last_hidden_state, 2279 | output_hidden_states=True, 2280 | output_attentions=output_attentions, 2281 | ) 2282 | 2283 | encoder_hidden_states = None 2284 | pixel_decoder_hidden_states = None 2285 | transformer_decoder_hidden_states = None 2286 | transformer_decoder_intermediate_states = None 2287 | 2288 | if output_hidden_states: 2289 | encoder_hidden_states = pixel_level_module_output.encoder_hidden_states 2290 | pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states 2291 | transformer_decoder_hidden_states = transformer_module_output.hidden_states 2292 | transformer_decoder_intermediate_states = transformer_module_output.intermediate_hidden_states 2293 | 2294 | output = Mask2FormerModelOutput( 2295 | encoder_last_hidden_state=pixel_level_module_output.encoder_last_hidden_state, 2296 | pixel_decoder_last_hidden_state=pixel_level_module_output.decoder_last_hidden_state, 2297 | transformer_decoder_last_hidden_state=transformer_module_output.last_hidden_state, 2298 | encoder_hidden_states=encoder_hidden_states, 2299 | pixel_decoder_hidden_states=pixel_decoder_hidden_states, 2300 | transformer_decoder_hidden_states=transformer_decoder_hidden_states, 2301 | transformer_decoder_intermediate_states=transformer_decoder_intermediate_states, 2302 | attentions=transformer_module_output.attentions, 2303 | masks_queries_logits=transformer_module_output.masks_queries_logits, 2304 | ) 2305 | 2306 | if not return_dict: 2307 | output = tuple(v for v in output.values() if v is not None) 2308 | 2309 | return output 2310 | 2311 | 2312 | @add_start_docstrings( 2313 | "The Mask2Former Model with heads on top for instance/semantic/panoptic segmentation.", 2314 | MASK2FORMER_START_DOCSTRING, 2315 | ) 2316 | class Mask2FormerForUniversalSegmentation(Mask2FormerPreTrainedModel): 2317 | main_input_name = "pixel_values" 2318 | 2319 | def __init__(self, config: Mask2FormerConfig): 2320 | super().__init__(config) 2321 | self.model = Mask2FormerModel(config) 2322 | 2323 | self.weight_dict: Dict[str, float] = { 2324 | "loss_cross_entropy": config.class_weight, 2325 | "loss_mask": config.mask_weight, 2326 | "loss_dice": config.dice_weight, 2327 | } 2328 | 2329 | self.class_predictor = nn.Linear(config.hidden_dim, config.num_labels + 1) 2330 | 2331 | self.criterion = Mask2FormerLoss(config=config, weight_dict=self.weight_dict) 2332 | self.post_init() 2333 | 2334 | def 
get_loss_dict( 2335 | self, 2336 | masks_queries_logits: Tensor, 2337 | class_queries_logits: Tensor, 2338 | mask_labels: Tensor, 2339 | class_labels: Tensor, 2340 | auxiliary_predictions: Dict[str, Tensor], 2341 | ) -> Dict[str, Tensor]: 2342 | loss_dict: Dict[str, Tensor] = self.criterion( 2343 | masks_queries_logits=masks_queries_logits, 2344 | class_queries_logits=class_queries_logits, 2345 | mask_labels=mask_labels, 2346 | class_labels=class_labels, 2347 | auxiliary_predictions=auxiliary_predictions, 2348 | ) 2349 | 2350 | # weight each loss by `self.weight_dict[]` including auxiliary losses 2351 | for key, weight in self.weight_dict.items(): 2352 | for loss_key, loss in loss_dict.items(): 2353 | if key in loss_key: 2354 | loss *= weight 2355 | 2356 | return loss_dict 2357 | 2358 | def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor: 2359 | return sum(loss_dict.values()) 2360 | 2361 | def get_auxiliary_logits(self, classes: torch.Tensor, output_masks: torch.Tensor): 2362 | auxiliary_logits = [] 2363 | 2364 | for aux_binary_masks, aux_classes in zip(output_masks[:-1], classes[:-1]): 2365 | auxiliary_logits.append({"masks_queries_logits": aux_binary_masks, "class_queries_logits": aux_classes}) 2366 | 2367 | return auxiliary_logits 2368 | 2369 | @add_start_docstrings_to_model_forward(MASK2FORMER_INPUTS_DOCSTRING) 2370 | @replace_return_docstrings(output_type=Mask2FormerForUniversalSegmentationOutput, config_class=_CONFIG_FOR_DOC) 2371 | def forward( 2372 | self, 2373 | pixel_values: Tensor, 2374 | mask_labels: Optional[List[Tensor]] = None, 2375 | labels: Optional[List[Tensor]] = None, 2376 | pixel_mask: Optional[Tensor] = None, 2377 | output_hidden_states: Optional[bool] = None, 2378 | output_auxiliary_logits: Optional[bool] = None, 2379 | output_attentions: Optional[bool] = None, 2380 | return_dict: Optional[bool] = None, 2381 | ) -> SemanticSegmenterOutput: 2382 | r""" 2383 | mask_labels (`List[torch.Tensor]`, *optional*): 2384 | List of mask labels of shape `(num_labels, height, width)` to be fed to a model 2385 | class_labels (`List[torch.LongTensor]`, *optional*): 2386 | list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the 2387 | labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`. 
2388 | 2389 | Returns: 2390 | `Mask2FormerUniversalSegmentationOutput` 2391 | """ 2392 | 2393 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 2394 | output_hidden_states = ( 2395 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 2396 | ) 2397 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 2398 | 2399 | outputs = self.model( 2400 | pixel_values=pixel_values, 2401 | pixel_mask=pixel_mask, 2402 | output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss, 2403 | output_attentions=output_attentions, 2404 | return_dict=True, 2405 | ) 2406 | 2407 | loss, loss_dict, auxiliary_logits = None, None, None 2408 | class_queries_logits = () 2409 | 2410 | for decoder_output in outputs.transformer_decoder_intermediate_states: 2411 | class_prediction = self.class_predictor(decoder_output.transpose(0, 1)) 2412 | class_queries_logits += (class_prediction,) 2413 | 2414 | masks_queries_logits = outputs.masks_queries_logits 2415 | 2416 | auxiliary_logits = self.get_auxiliary_logits(class_queries_logits, masks_queries_logits) 2417 | 2418 | if mask_labels is None: 2419 | mask_labels = torch.ones((pixel_values.shape[0], pixel_values.shape[2], pixel_values.shape[3]), device=pixel_values.device) 2420 | 2421 | if labels is not None: 2422 | loss_dict = self.get_loss_dict( 2423 | masks_queries_logits=masks_queries_logits[-1], 2424 | class_queries_logits=class_queries_logits[-1], 2425 | mask_labels=mask_labels, 2426 | class_labels=labels, 2427 | auxiliary_predictions=auxiliary_logits, 2428 | ) 2429 | loss = self.get_loss(loss_dict)[0] 2430 | 2431 | encoder_hidden_states = None 2432 | pixel_decoder_hidden_states = None 2433 | transformer_decoder_hidden_states = None 2434 | 2435 | if output_hidden_states: 2436 | encoder_hidden_states = outputs.encoder_hidden_states 2437 | pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states 2438 | transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states 2439 | 2440 | output_auxiliary_logits = ( 2441 | self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits 2442 | ) 2443 | if not output_auxiliary_logits: 2444 | auxiliary_logits = None 2445 | 2446 | output = Mask2FormerForUniversalSegmentationOutput( 2447 | loss=loss, 2448 | class_queries_logits=class_queries_logits[-1].float(), # 1 2449 | masks_queries_logits=masks_queries_logits[-1], # 2 2450 | auxiliary_logits=auxiliary_logits, 2451 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, # 3 2452 | pixel_decoder_last_hidden_state=outputs.pixel_decoder_last_hidden_state.float(), 2453 | transformer_decoder_last_hidden_state=outputs.transformer_decoder_last_hidden_state, 2454 | encoder_hidden_states=encoder_hidden_states, 2455 | pixel_decoder_hidden_states=pixel_decoder_hidden_states, 2456 | transformer_decoder_hidden_states=transformer_decoder_hidden_states, 2457 | attentions=outputs.attentions, 2458 | ) 2459 | 2460 | logits = torch.stack(process_segmenter_output(output, target_sizes=[pixel_values[i].shape[-2:] for i in range(len(pixel_values))]), dim=0) 2461 | index_labels = torch.stack([torch.argmax(mask_labels[i], dim=0) - (mask_labels[i].sum(dim=0) == 0).float() for i in range(len(labels))], dim=0) 2462 | index_labels_ = torch.ones_like(index_labels.long()) * 255 2463 | for i in range(len(index_labels_)): 2464 | for j, l in enumerate(labels[i]): 2465 | index_labels_[i][index_labels[i] == j] 
= l 2466 | 2467 | confusion_matrix = get_confusion_matrix(index_labels_.detach().cpu().numpy(), logits.detach().cpu().numpy(), pixel_values.shape[-2:], 19, 255) 2468 | confusion_matrix = torch.tensor(confusion_matrix).unsqueeze(0).int() 2469 | 2470 | return SemanticSegmenterOutput( 2471 | loss=loss, 2472 | logits=confusion_matrix, 2473 | hidden_states=None, 2474 | attentions=None, 2475 | ) 2476 | -------------------------------------------------------------------------------- /paths.py: -------------------------------------------------------------------------------- 1 | def get_path(args, key: str, add_date=False, temp=True): 2 | 3 | method_tag = get_method_tag(args) 4 | 5 | if add_date: 6 | if method_tag: 7 | method_tag += '/' 8 | method_tag += args.datetime[:19] 9 | 10 | method_tag = method_tag.replace(' ', '_') 11 | 12 | folder_prefix = f'{args.save_path}/{args.task}/{args.data}/{args.arch}/{method_tag}' 13 | 14 | if temp: 15 | folder_prefix += '/temp' 16 | 17 | trainer_prefix = f'{folder_prefix}/trainer' 18 | model_prefix = f'{folder_prefix}/models' 19 | 20 | path_dict = {'MAIN_FOLDER_DIR': folder_prefix, 21 | 'TRAINER_FOLDER_DIR': trainer_prefix, 22 | 'MODEL_FOLDER_DIR': model_prefix, 23 | 'INIT_MODEL_PATH': f'{model_prefix}/init_model.pth', 24 | 'INIT_MASKS_PATH': f'{model_prefix}/init_masks.pth', 25 | 'ITER_MASKS_PATH': f'{model_prefix}/iter_masks.pth', 26 | 'OPT_STATE_PATH': f'{model_prefix}/opt_state.pth', 27 | 'COMPRESSED_MODEL_PATH': f'{model_prefix}/compressed_model.pth', 28 | 'ARGS_PATH': f'{folder_prefix}/args.json' 29 | } 30 | 31 | return path_dict[key] 32 | 33 | 34 | def get_method_tag(args): 35 | method_tag = [] 36 | method_tag.append(str(args.init_sparse_ratio)) 37 | method_tag.append(str(args.iter_sparse_ratio)) 38 | 39 | method_tag = '_'.join(method_tag) 40 | if not method_tag: 41 | method_tag = '_' 42 | 43 | return method_tag 44 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.20.3 2 | datasets==2.12.0 3 | dill==0.3.6 4 | huggingface-hub==0.15.1 5 | keras==2.13.1rc0 6 | pandas==1.5.3 7 | Pillow 8 | tensorboard 9 | tensorflow 10 | torch==2.0.1 11 | torchaudio==2.0.2 12 | torchvision==0.15.2 13 | tqdm==4.65.0 14 | transformers==4.33.1 -------------------------------------------------------------------------------- /trainer_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from datasets import load_metric 10 | from sklearn.metrics import accuracy_score 11 | from torch.optim import * 12 | from transformers import BertForSequenceClassification 13 | from transformers import EvalPrediction 14 | from transformers import ViTForImageClassification 15 | from transformers.trainer import Trainer 16 | from transformers.trainer_pt_utils import get_parameter_names 17 | from transformers.training_args import TrainingArguments 18 | 19 | import nni 20 | from models.modeling_mask2former import Mask2FormerForUniversalSegmentation 21 | from paths import get_path 22 | from utils import get_model_param_keys 23 | 24 | model_dispatcher = { 25 | 'bert-base-uncased': BertForSequenceClassification, 26 | 'bert-large-uncased': BertForSequenceClassification, 27 | 'vit-base': ViTForImageClassification, 28 | 'vit-large': ViTForImageClassification, 29 | 
'm2f': Mask2FormerForUniversalSegmentation 30 | } 31 | 32 | 33 | def build_model(pretrained_model_name_or_path: str, task_name: str, data_name: str, **kwargs): 34 | 35 | if data_name == 'cifar100': 36 | num_labels = 100 37 | elif data_name == 'tinyimagenet': 38 | num_labels = 200 39 | elif data_name == 'cityscapes' or data_name == 'kitti': 40 | num_labels = 19 41 | else: 42 | num_labels = 2 43 | 44 | if task_name == 'img_class': 45 | if 'vit' in pretrained_model_name_or_path: 46 | if pretrained_model_name_or_path == 'vit-base': 47 | model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', 48 | id2label=kwargs['id2label'], 49 | label2id=kwargs['label2id'], cache_dir='cache') 50 | elif pretrained_model_name_or_path == 'vit-large': 51 | model = ViTForImageClassification.from_pretrained('google/vit-large-patch16-224', 52 | id2label=kwargs['id2label'], 53 | label2id=kwargs['label2id'], 54 | ignore_mismatched_sizes=True, cache_dir='cache') 55 | else: 56 | raise NotImplementedError 57 | else: 58 | raise NotImplementedError 59 | elif task_name == 'img_seg': 60 | if 'm2f' in pretrained_model_name_or_path: 61 | model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-IN21k-cityscapes-semantic", cache_dir='cache') 62 | else: 63 | raise NotImplementedError 64 | else: 65 | model = model_dispatcher[pretrained_model_name_or_path].from_pretrained(pretrained_model_name_or_path, num_labels=num_labels, cache_dir='cache') 66 | return model 67 | 68 | 69 | def prepare_traced_trainer(model, args, data_content, training_params={}, for_train_flag=True, for_eval_flag=True, 70 | tag='default', device=None, send_tag='train'): 71 | 72 | if 'img' in args.task: 73 | save_strategy = 'no' if 'prune' in tag else 'epoch' 74 | evaluation_strategy = 'no' if 'prune' in tag else 'epoch' 75 | else: 76 | save_strategy = 'no' if 'prune' in tag else 'epoch' 77 | evaluation_strategy = 'no' if 'prune' in tag else 'epoch' 78 | 79 | def compute_metrics(p: EvalPrediction): 80 | if args.task == 'glue': 81 | metric = load_metric('glue', args.data) 82 | 83 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 84 | preds = np.argmax(preds, axis=1) 85 | result = metric.compute(predictions=preds, references=p.label_ids) 86 | 87 | elif args.task == 'img_class': 88 | predictions, labels = p.predictions, p.label_ids 89 | predictions = np.argmax(predictions, axis=1) 90 | result = dict(accuracy=accuracy_score(predictions, labels)) 91 | 92 | elif args.task == 'img_seg': 93 | predictions, labels = p.predictions, p.label_ids 94 | predictions = predictions.sum(0) 95 | pos = predictions.sum(1) 96 | res = predictions.sum(0) 97 | tp = np.diag(predictions) 98 | IoU_array = (tp / np.maximum(1.0, pos + res - tp)) 99 | score = IoU_array[pos + res - tp != 0].mean() 100 | result = dict(accuracy=score) 101 | else: 102 | raise NotImplementedError 103 | 104 | return result 105 | 106 | if tag == 'default': 107 | logging_dir = None 108 | else: 109 | logging_dir = get_path(args, 'TRAINER_FOLDER_DIR') + '/runs/' + tag 110 | 111 | if device is None: 112 | device = args.device 113 | 114 | if device == 'cpu': 115 | no_cuda = True 116 | else: 117 | no_cuda = False 118 | 119 | if for_train_flag and for_eval_flag and args.task == 'img_seg': 120 | for_eval_flag = False 121 | 122 | num_steps = min(int(training_params.get('num_train_epochs', 3) * len(data_content['train']) / training_params.get('batch_size', 8)) + 5, 10000) 123 | 124 | training_args = 
TrainingArguments(output_dir=get_path(args, 'TRAINER_FOLDER_DIR') + f'/runs/{tag}', 125 | do_train=for_train_flag, 126 | do_eval=for_eval_flag, 127 | evaluation_strategy=evaluation_strategy, 128 | save_strategy=save_strategy, 129 | logging_strategy='epoch', 130 | logging_dir=logging_dir, 131 | logging_steps=500, 132 | per_device_train_batch_size=training_params.get('batch_size', 32), 133 | per_device_eval_batch_size=32, 134 | max_steps=num_steps, 135 | weight_decay=training_params.get('weight_decay', 1e-2), 136 | lr_scheduler_type='linear', 137 | dataloader_num_workers=1, 138 | learning_rate=training_params.get('learning_rate', 1e-4), 139 | save_total_limit=1, 140 | metric_for_best_model=args.metric_name, 141 | load_best_model_at_end=True, 142 | greater_is_better=True, 143 | disable_tqdm=True, 144 | optim='adamw_torch', 145 | seed=1024, 146 | use_mps_device=device == 'mps', 147 | no_cuda=no_cuda, 148 | remove_unused_columns=False) 149 | 150 | trainer = nni.trace(Trainer)(model=model, 151 | args=training_args, 152 | data_collator=data_content['collator'], 153 | train_dataset=data_content[send_tag], 154 | eval_dataset=data_content['val'], 155 | tokenizer=data_content['tokenizer'], 156 | compute_metrics=compute_metrics) 157 | 158 | return trainer 159 | 160 | 161 | def predict(model_path, args, data_content, tag='default'): 162 | if not Path(model_path).exists(): 163 | print(f'Model does not exist at {model_path}, exiting...') 164 | return {} 165 | 166 | if args.task == 'img_class' and tag == 'test': 167 | send_tag = 'test' 168 | else: 169 | send_tag = 'val' 170 | 171 | model = torch.load(model_path) 172 | trainer = prepare_traced_trainer(model.to(args.device), args, data_content, {}, for_train_flag=False, tag=tag) 173 | 174 | output = trainer.predict(data_content[send_tag], metric_key_prefix=tag) 175 | 176 | print(f'Metric: {output.metrics}') 177 | return output 178 | 179 | 180 | def prepare_masked_trainer(args, trainer, max_steps, decay_zero=True): 181 | trainer.create_optimizer_and_scheduler(num_training_steps=max_steps) 182 | 183 | if os.path.exists(get_path(args, 'ITER_MASKS_PATH')): 184 | masks = torch.load(get_path(args, 'ITER_MASKS_PATH')) 185 | else: 186 | masks = 1 187 | 188 | keys = get_model_param_keys(trainer.model) 189 | 190 | decay_parameters = get_parameter_names(trainer.model, [nn.LayerNorm]) 191 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 192 | 193 | decay_val = 0 if decay_zero else trainer.args.weight_decay 194 | 195 | optimizer_grouped_parameters = [ 196 | { 197 | "params": [ 198 | p for n, p in trainer.model.named_parameters() if (n in decay_parameters and p.requires_grad) 199 | ], 200 | "weight_decay": decay_val, 201 | }, 202 | { 203 | "params": [ 204 | p for n, p in trainer.model.named_parameters() if (n not in decay_parameters and p.requires_grad) 205 | ], 206 | "weight_decay": 0, 207 | }, 208 | ] 209 | _, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(trainer.args) 210 | trainer.optimizer = CustomAdamW(keys, masks, optimizer_grouped_parameters, **optimizer_kwargs) 211 | 212 | 213 | class CustomAdamW(AdamW): 214 | def __init__(self, keys, masks, args, **kwargs): 215 | super().__init__(args, **kwargs) 216 | self.keys = keys 217 | self.masks = masks 218 | 219 | def step(self, closure=None): 220 | c = -1 221 | for i in range(len(self.param_groups)): 222 | for j, param in enumerate(self.param_groups[i]['params']): 223 | c += 1 224 | key = self.keys[i][j] 225 | 226 | key_ = '.'.join(key.split('.')[:-1]) 227 | _key = 
key.split('.')[-1] 228 | 229 | try: 230 | if isinstance(self.masks, dict): 231 | mask = self.masks[key_][_key] 232 | else: 233 | continue 234 | except: 235 | continue 236 | 237 | if param.grad is None: 238 | continue 239 | 240 | if mask.shape != param.grad.shape: 241 | print(key) 242 | raise RuntimeError 243 | 244 | param.grad *= mask.to(param.device) 245 | 246 | super(CustomAdamW, self).step(closure) 247 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from transformers.trainer_pt_utils import get_parameter_names 5 | 6 | 7 | def get_confusion_matrix(label, pred, size, num_class, ignore=255): 8 | """ 9 | Calcute the confusion matrix by given label and pred 10 | """ 11 | if pred.ndim == 4: 12 | output = pred.transpose(0, 2, 3, 1) 13 | seg_pred = np.asarray(np.argmax(output, axis=3), dtype=np.uint8) 14 | else: 15 | seg_pred = pred 16 | seg_gt = np.asarray(label[:, :size[-2], :size[-1]], dtype=int) 17 | 18 | ignore_index = seg_gt != ignore 19 | seg_gt = seg_gt[ignore_index] 20 | seg_pred = seg_pred[ignore_index] 21 | 22 | index = (seg_gt * num_class + seg_pred).astype('int32') 23 | label_count = np.bincount(index) 24 | confusion_matrix = np.zeros((num_class, num_class)) 25 | 26 | for i_label in range(num_class): 27 | for i_pred in range(num_class): 28 | cur_index = i_label * num_class + i_pred 29 | if cur_index < len(label_count): 30 | confusion_matrix[i_label, 31 | i_pred] = label_count[cur_index] 32 | return confusion_matrix 33 | 34 | 35 | def get_model_param_keys(model): 36 | decay_parameters = get_parameter_names(model, [nn.LayerNorm]) 37 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 38 | 39 | keys = [[n for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad)], 40 | [n for n, p in model.named_parameters() if (n not in decay_parameters and p.requires_grad)]] 41 | 42 | return keys 43 | 44 | 45 | def process_segmenter_output(outputs, target_sizes): 46 | class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] 47 | masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] 48 | 49 | # Scale back to preprocessed image size - (384, 384) for all models 50 | # masks_queries_logits = torch.nn.functional.interpolate( 51 | # masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False 52 | # ) 53 | 54 | # Remove the null class `[..., :-1]` 55 | masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] 56 | masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] 57 | 58 | # Semantic segmentation logits of shape (batch_size, num_classes, height, width) 59 | segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs).float() 60 | batch_size = class_queries_logits.shape[0] 61 | 62 | # Resize logits and compute semantic segmentation maps 63 | if target_sizes is not None: 64 | if batch_size != len(target_sizes): 65 | raise ValueError( 66 | "Make sure that you pass in as many target sizes as the batch dimension of the logits" 67 | ) 68 | 69 | semantic_segmentation = [] 70 | for idx in range(batch_size): 71 | resized_logits = torch.nn.functional.interpolate( 72 | segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False 73 | ) 74 | semantic_map = 
resized_logits[0].argmax(dim=0) 75 | semantic_segmentation.append(semantic_map) 76 | else: 77 | semantic_segmentation = segmentation.argmax(dim=1) 78 | semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] 79 | 80 | return semantic_segmentation 81 | --------------------------------------------------------------------------------
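
The `CustomAdamW` optimizer in `trainer_utils.py` enforces the iterative masking sparsity by multiplying each parameter's gradient by a fixed 0/1 mask before delegating to the standard AdamW update, so masked-out weights are never updated during finetuning. Below is a minimal self-contained sketch of the same idea, with a toy two-layer model and a made-up mask for a single weight tensor; the names and mask layout here are illustrative only, and the repo's own version additionally keeps per-group key lists aligned with its decay/no-decay parameter groups.

```python
import torch
import torch.nn as nn
from torch.optim import AdamW


class MaskedAdamW(AdamW):
    """AdamW that multiplies each gradient by a fixed 0/1 mask before stepping."""

    def __init__(self, masks, params, **kwargs):
        super().__init__(params, **kwargs)
        self.masks = masks  # parameter name -> 0/1 tensor with that parameter's shape

    def step(self, closure=None):
        for group in self.param_groups:
            # "names" is an extra key stored alongside "params" in each param group.
            for name, param in zip(group["names"], group["params"]):
                mask = self.masks.get(name)
                if mask is None or param.grad is None:
                    continue  # unmasked (or unused) parameters update normally
                param.grad *= mask.to(param.grad.device)
        return super().step(closure)


# Toy usage: only the first linear layer's weight carries a mask.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
names, params = zip(*model.named_parameters())
masks = {"0.weight": torch.ones(8, 4)}  # all-ones mask = no-op, just to show the plumbing
optimizer = MaskedAdamW(masks, [{"params": list(params), "names": list(names)}], lr=1e-3)

x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
```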
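
For semantic segmentation, the modified Mask2Former forward in `models/modeling_mask2former.py` packs a per-batch confusion matrix (built by `get_confusion_matrix` in `utils.py`) into the `logits` field of `SemanticSegmenterOutput`, so the Hugging Face `Trainer` only has to concatenate these matrices across evaluation batches and `compute_metrics` reduces them to mean IoU. The sketch below reproduces that pipeline on a hypothetical 3-class toy image; it uses a `bincount` + `reshape` in place of the file's explicit double loop, which yields the same matrix.

```python
import numpy as np

num_class, ignore = 3, 255

# Toy ground-truth and prediction maps for one 4x4 image; 255 marks ignored pixels.
gt = np.array([[0, 0, 1, 1],
               [0, 0, 1, 1],
               [2, 2, 255, 255],
               [2, 2, 255, 255]])
pred = np.array([[0, 0, 1, 2],
                 [0, 1, 1, 1],
                 [2, 2, 0, 0],
                 [2, 2, 1, 1]])

# Same encoding trick as get_confusion_matrix: each (gt, pred) pair becomes one index.
valid = gt != ignore
index = gt[valid] * num_class + pred[valid]
counts = np.bincount(index, minlength=num_class * num_class)
confusion = counts.reshape(num_class, num_class)  # rows = ground truth, cols = prediction

# Reduction used in compute_metrics for img_seg (after summing matrices over batches).
pos = confusion.sum(axis=1)                 # pixels per ground-truth class
res = confusion.sum(axis=0)                 # pixels per predicted class
tp = np.diag(confusion)                     # correctly classified pixels
iou = tp / np.maximum(1.0, pos + res - tp)
miou = iou[pos + res - tp != 0].mean()      # skip classes absent from both gt and pred
print(confusion, iou, round(float(miou), 4))
```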
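
`process_segmenter_output` in `utils.py` converts Mask2Former's per-query outputs into per-pixel semantic maps: class probabilities (with the null class dropped) are combined with sigmoid mask probabilities through an einsum, then upsampled to the target size and argmaxed. The following shape-level illustration runs the same core computation on random tensors; the batch size, query count, and image sizes are arbitrary choices, not values taken from the repo's configuration.

```python
import torch

batch_size, num_queries, num_classes, h, w = 2, 100, 19, 96, 96

class_queries_logits = torch.randn(batch_size, num_queries, num_classes + 1)  # +1 null class
masks_queries_logits = torch.randn(batch_size, num_queries, h, w)

masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]  # drop the null class
masks_probs = masks_queries_logits.sigmoid()

# Weight each query's mask by its class probabilities and sum over queries:
# (batch, queries, classes) x (batch, queries, h, w) -> (batch, classes, h, w)
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)

# Upsample to the original image size and take the per-pixel argmax.
logits = torch.nn.functional.interpolate(segmentation, size=(768, 768),
                                         mode="bilinear", align_corners=False)
semantic_map = logits.argmax(dim=1)  # (batch, 768, 768) tensor of class indices
print(semantic_map.shape)
```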